/*******************************************************************************
* Copyright 2021-2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates use of oneAPI Math Kernel Library (oneMKL)
*       DPCPP API oneapi::mkl::sparse::matmat to perform general
*       sparse matrix-sparse matrix multiplication on a SYCL device (CPU, GPU;
*       the host device is not currently supported by sparse::matmat).
*
*           C = op(A) * op(B)
*
*       where op() is defined by one of
*
*           oneapi::mkl::transpose::{nontrans,trans}
*
*       It uses the simplified API usage model where the user only handles memory
*       allocation for the final C Matrix, and any temporary workspaces are
*       handled by the library.
*
*       The supported floating point data types for matmat matrix data are:
*           float
*           double
*
*
*******************************************************************************/

// stl includes
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iterator>
#include <limits>
#include <list>
#include <vector>

#include "mkl.h"
#include "oneapi/mkl.hpp"
#include <CL/sycl.hpp>

// local includes
#include "../common/common_for_examples.hpp"
#include "common_for_sparse_examples.hpp"

//
// Main example for Sparse Matrix-Sparse Matrix Multiply consisting of
// initialization of A and B matrices through process of creating C matrix as
// the product
//
// C = op(A) * op(B)
//
// In this case A and B are built by the same generator with the same size and
// index base, i.e. B is a copy of the square symmetric A. With op(A) = A and
// op(B) = B^T (= B for a symmetric matrix), C is the square of A.
//
// Template parameters:
//   fpType  - value type of the matrix entries (float or double)
//   intType - index type of the CSR row-pointer / column-index arrays
//
// Returns 0 on success, or 1 if a synchronous exception was caught (after
// printing it and releasing the sparse handles and matmat descriptor).
//
template <typename fpType, typename intType>
int run_sparse_matrix_sparse_matrix_multiply_example(const cl::sycl::device &dev)
{
    // Initialize data for Sparse Matrix - Sparse Matrix Multiply
    oneapi::mkl::transpose opA = oneapi::mkl::transpose::nontrans;
    oneapi::mkl::transpose opB = oneapi::mkl::transpose::trans;

    oneapi::mkl::sparse::matrix_view_descr viewA = oneapi::mkl::sparse::matrix_view_descr::general;
    oneapi::mkl::sparse::matrix_view_descr viewB = oneapi::mkl::sparse::matrix_view_descr::general;
    oneapi::mkl::sparse::matrix_view_descr viewC = oneapi::mkl::sparse::matrix_view_descr::general;

    oneapi::mkl::index_base a_index = oneapi::mkl::index_base::one;
    oneapi::mkl::index_base b_index = oneapi::mkl::index_base::one;
    oneapi::mkl::index_base c_index = oneapi::mkl::index_base::one;

    //
    // set up dimensions of matrix products
    //
    intType size = 4;

    intType a_nrows = size * size * size;
    intType a_ncols = a_nrows;
    intType a_nnz   = 27 * a_nrows; // upper bound; replaced by the real count below
    intType b_nrows = size * size * size;
    intType b_ncols = b_nrows;
    intType b_nnz   = 27 * b_nrows; // upper bound; replaced by the real count below
    intType c_nrows = size * size * size;
    intType c_ncols = c_nrows;
    // c_nnz is unknown at this point

    //
    // setup A data locally in CSR format
    //
    std::vector<intType, mkl_allocator<intType, 64>> ia;
    std::vector<intType, mkl_allocator<intType, 64>> ja;
    std::vector<fpType, mkl_allocator<fpType, 64>> a;

    ia.resize(a_nrows + 1);
    ja.resize(27 * a_nrows);
    a.resize(27 * a_nrows);

    intType a_ind = a_index == oneapi::mkl::index_base::zero ? 0 : 1;
    generate_sparse_matrix<fpType, intType>(size, ia, ja, a, a_ind);
    a_nnz = ia[a_nrows] - a_ind;

    //
    // setup B data locally in CSR format
    //
    std::vector<intType, mkl_allocator<intType, 64>> ib;
    std::vector<intType, mkl_allocator<intType, 64>> jb;
    std::vector<fpType, mkl_allocator<fpType, 64>> b;

    ib.resize(b_nrows + 1);
    jb.resize(27 * b_nrows);
    b.resize(27 * b_nrows);

    intType b_ind = b_index == oneapi::mkl::index_base::zero ? 0 : 1;
    generate_sparse_matrix<fpType, intType>(size, ib, jb, b, b_ind);
    b_nnz = ib[b_nrows] - b_ind;

    // Catch asynchronous exceptions
    auto exception_handler = [](cl::sycl::exception_list exceptions) {
        for (std::exception_ptr const &e : exceptions) {
            try {
                std::rethrow_exception(e);
            }
            catch (cl::sycl::exception const &e) {
                std::cout << "Caught asynchronous SYCL "
                             "exception during sparse::matmat:\n"
                          << e.what() << std::endl;
            }
        }
    };

    std::cout << "\n\t\tsparse::matmat parameters:\n";
    std::cout << "\t\t\topA = " << opA << std::endl;
    std::cout << "\t\t\topB = " << opB << std::endl;

    std::cout << "\t\t\tviewA = " << viewA << std::endl;
    std::cout << "\t\t\tviewB = " << viewB << std::endl;
    std::cout << "\t\t\tviewC = " << viewC << std::endl;

    std::cout << "\t\t\tA_nrows = A_ncols = " << a_nrows << std::endl;
    std::cout << "\t\t\tB_nrows = B_ncols = " << b_nrows << std::endl;
    std::cout << "\t\t\tC_nrows = C_ncols = " << c_nrows << std::endl;

    std::cout << "\t\t\tA_index = " << a_index << std::endl;
    std::cout << "\t\t\tB_index = " << b_index << std::endl;
    std::cout << "\t\t\tC_index = " << c_index << std::endl;


    //
    // Execute Matrix Multiply
    //

    // Handles start out null so that the error paths below never hand an
    // uninitialized handle to release_matrix_handle if an exception is thrown
    // before all of them were initialized. NOTE: this assumes the release_*
    // functions accept null handles (the original error path already relied
    // on that for descr).
    oneapi::mkl::sparse::matrix_handle_t A = nullptr;
    oneapi::mkl::sparse::matrix_handle_t B = nullptr;
    oneapi::mkl::sparse::matrix_handle_t C = nullptr;

    oneapi::mkl::sparse::matmat_descr_t descr = nullptr;
    oneapi::mkl::sparse::matmat_request req;

    // Single cleanup sequence shared by the success and both error paths.
    auto release_all = [&]() {
        oneapi::mkl::sparse::release_matmat_descr(&descr);
        oneapi::mkl::sparse::release_matrix_handle(&A);
        oneapi::mkl::sparse::release_matrix_handle(&B);
        oneapi::mkl::sparse::release_matrix_handle(&C);
    };

    // create execution main_queue and buffers of matrix data
    cl::sycl::queue main_queue(dev, exception_handler);

    sycl::buffer<intType, 1> a_rowptr(ia.data(), a_nrows + 1);
    sycl::buffer<intType, 1> a_colind(ja.data(), a_nnz);
    sycl::buffer<fpType, 1> a_val(a.data(), a_nnz);
    sycl::buffer<intType, 1> b_rowptr(ib.data(), b_nrows + 1);
    sycl::buffer<intType, 1> b_colind(jb.data(), b_nnz);
    sycl::buffer<fpType, 1> b_val(b.data(), b_nnz);
    sycl::buffer<intType, 1> c_rowptr(c_nrows + 1);

    try {
        oneapi::mkl::sparse::init_matrix_handle(&A);
        oneapi::mkl::sparse::init_matrix_handle(&B);
        oneapi::mkl::sparse::init_matrix_handle(&C);

        oneapi::mkl::sparse::set_csr_data(A, a_nrows, a_ncols, a_index, a_rowptr, a_colind, a_val);

        oneapi::mkl::sparse::set_csr_data(B, b_nrows, b_ncols, b_index, b_rowptr, b_colind, b_val);

        //
        // only c_rowptr exists at this point in process so create some
        // dummy args for now
        //
        sycl::buffer<intType, 1> c_colind_dummy(0);
        sycl::buffer<fpType, 1> c_vals_dummy(0);
        oneapi::mkl::sparse::set_csr_data(C, c_nrows, c_ncols, c_index, c_rowptr, c_colind_dummy,
                                          c_vals_dummy);

        //
        // initialize the matmat descriptor
        //
        oneapi::mkl::sparse::init_matmat_descr(&descr);
        oneapi::mkl::sparse::set_matmat_data(descr, viewA, opA, viewB, opB, viewC);

        //
        // Stage 1:  work estimation
        //
        // Note we are using the simplified form (passing nullptr's) here and in compute
        // where we let the library handle the allocation and lifetime of temporary
        // workspaces internally.
        //
        req = oneapi::mkl::sparse::matmat_request::work_estimation;
        oneapi::mkl::sparse::matmat(main_queue, A, B, C, req, descr, nullptr, nullptr);

        //
        // Stage 2:  compute
        //
        req = oneapi::mkl::sparse::matmat_request::compute;
        oneapi::mkl::sparse::matmat(main_queue, A, B, C, req, descr, nullptr, nullptr);

        //
        // Stage 3:  finalize
        //

        // Step 3.1  get nnz
        req = oneapi::mkl::sparse::matmat_request::get_nnz;
        sycl::buffer<std::int64_t, 1> c_nnzBuf(1);
        oneapi::mkl::sparse::matmat(main_queue, A, B, C, req, descr, &c_nnzBuf, nullptr);

        // Step 3.2  allocate final c matrix arrays
        std::int64_t c_nnz = 0;
        {
            // host accessor: blocks until the get_nnz result is available
            auto c_nnz_acc = c_nnzBuf.template get_access<sycl::access::mode::read>();
            c_nnz          = c_nnz_acc[0];
        }
        sycl::buffer<intType, 1> c_colind(c_nnz);
        sycl::buffer<fpType, 1> c_vals(c_nnz);
        oneapi::mkl::sparse::set_csr_data(C, c_nrows, c_ncols, c_index, c_rowptr, c_colind, c_vals);

        // Step 3.3  finalize into C matrix
        req = oneapi::mkl::sparse::matmat_request::finalize;
        oneapi::mkl::sparse::matmat(main_queue, A, B, C, req, descr, nullptr, nullptr);

        // Print portion of C solution
        {
            auto ic       = c_rowptr.template get_access<sycl::access::mode::read>();
            auto jc       = c_colind.template get_access<sycl::access::mode::read>();
            auto c        = c_vals.template get_access<sycl::access::mode::read>();
            intType c_ind = c_index == oneapi::mkl::index_base::zero ? 0 : 1;
            std::cout << "C matrix [first two rows]:" << std::endl;
            for (intType row = 0; row < std::min(static_cast<intType>(2), c_nrows); ++row) {
                // row pointers are stored in c_ind-based form; shift to 0-based offsets
                for (intType j = ic[row] - c_ind; j < ic[row + 1] - c_ind; ++j) {
                    intType col = jc[j];
                    fpType val  = c[j];
                    std::cout << "C(" << row + c_ind << ", " << col << ") = " << val << std::endl;
                }
            }
        }

        // clean up
        release_all();
    }
    catch (cl::sycl::exception const &e) {
        std::cout << "\t\tCaught synchronous SYCL exception:\n" << e.what() << std::endl;
        release_all();

        return 1;
    }
    catch (std::exception const &e) {
        std::cout << "\t\tCaught std exception:\n" << e.what() << std::endl;
        release_all();

        return 1;
    }

    return 0;
}

//
// Print a description of the example setup: the operation performed, the
// oneMKL sparse APIs it uses and the supported floating point precisions.
//
void print_example_banner()
{

    std::cout << "" << std::endl;
    std::cout << "###############################################################"
                 "#########"
              << std::endl;
    std::cout << "# Sparse Matrix-Sparse Matrix Multiply Example: " << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "#    C = op(A) * op(B)" << std::endl;
    std::cout << "# " << std::endl;
    // No space after the embedded newline, so the second line starts with '#'
    // like every other banner line.
    std::cout << "# where A and B are sparse matrices in CSR format, and C is the\n"
                 "# sparse matrix product in CSR format"
              << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "# Using apis:" << std::endl;
    std::cout << "#   sparse::matmat" << std::endl;
    std::cout << "#   sparse::init_matmat_descr" << std::endl;
    std::cout << "#   sparse::set_matmat_data" << std::endl;
    std::cout << "#   sparse::release_matmat_descr" << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "#   sparse::init_matrix_handle" << std::endl;
    std::cout << "#   sparse::set_csr_data" << std::endl;
    std::cout << "#   sparse::release_matrix_handle" << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "# Supported floating point type precisions:" << std::endl;
    std::cout << "#   float" << std::endl;
    std::cout << "#   double" << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "###############################################################"
                 "#########"
              << std::endl;
    std::cout << std::endl;
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_host -- only runs host implementation
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU implementation
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- runs on all: host, cpu and gpu devices
//
//  For each device selected and each supported data type, MatrixMultiplyExample
//  runs is with all supported data types
//

int main(int argc, char **argv)
{

    print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    int status = 0;
    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {

        if (*it == my_sycl_device_types::host_device) {
            std:: cout << "Host device is not currently supported by sparse::matmat" << std::endl;
            break;
        }

        cl::sycl::device my_dev;
        bool my_dev_is_found = false;
        get_sycl_device(my_dev, my_dev_is_found, *it);

        if (my_dev_is_found) {
            std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";

            std::cout << "\tRunning with single precision real data type:" << std::endl;
            status = run_sparse_matrix_sparse_matrix_multiply_example<float, std::int32_t>(my_dev);
            if (status != 0)
                return status;

            if (my_dev.get_info<cl::sycl::info::device::double_fp_config>().size() != 0) {
                std::cout << "\tRunning with double precision real data type:" << std::endl;
                status = run_sparse_matrix_sparse_matrix_multiply_example<double, std::int32_t>(
                        my_dev);
                if (status != 0)
                    return status;
            }
        }
        else {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[*it]
                      << " devices found; Fail on missing devices "
                         "is enabled.\n";
            return 1;
#else
            std::cout << "No " << sycl_device_names[*it] << " devices found; skipping "
                      << sycl_device_names[*it] << " tests.\n";
#endif
        }
    }

    return status;
}
