/*******************************************************************************
* Copyright 2019-2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates use of oneAPI Math Kernel Library (oneMKL)
*       API oneapi::mkl::dft to perform 2-D Double Precision Real to Complex
*       Fast-Fourier Transform on a SYCL device (Host, CPU, GPU).
*
*       The supported floating point data types for data are:
*           double
*           std::complex<double>
*
*******************************************************************************/

#define _USE_MATH_DEFINES
#include <cfloat>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <limits>
#include <list>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <vector>

#include <CL/sycl.hpp>
#include "oneapi/mkl/dfti.hpp"
#include "mkl.h"

// local includes
#define NO_MATRIX_HELPERS
#include "common_for_examples.hpp"

// Double-precision, real-domain oneMKL DFT descriptor shared by the examples
// in this file.
using descriptor_t = oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::DOUBLE,
                                                  oneapi::mkl::dft::domain::REAL>;

// Status codes returned by the verification helpers and by each example.
constexpr int SUCCESS = 0;
constexpr int FAILURE = 1;

// 2*pi written out to (more than) double precision.
constexpr double TWOPI = 6.2831853071795864769;

// Compute (K*L)%M accurately
static double moda(int K, int L, int M)
{
    return (double)(((long long)K * L) % M);
}

// Fill the N2 x N1 row-major real array `data` with
//     factor * cos(2*pi*(n1*H1/N1 + n2*H2/N2)) / (N1*N2)
// so that its forward DFT has unit peaks at (H1, H2) and (-H1, -H2).
// When the harmonic is its own mirror (2*H == 0 modulo N in both
// dimensions), the two peaks coincide. In that case factor is 1, which still
// gives a single unit peak; otherwise factor is 2.
static void init_r(double *data, int N1, int N2, int H1, int H2)
{
    const bool self_mirrored = ((2 * (N1 - H1)) % N1 == 0) &&
                               ((2 * (N2 - H2)) % N2 == 0);
    const double factor = self_mirrored ? 1.0 : 2.0;

    // Rows are contiguous (stride N1), so the array is written sequentially.
    double *out = data;
    for (int n2 = 0; n2 < N2; ++n2) {
        const double row_phase = moda(n2, H2, N2) / N2;
        for (int n1 = 0; n1 < N1; ++n1) {
            const double phase = moda(n1, H1, N1) / N1 + row_phase;
            *out++ = factor * cos(TWOPI * phase) / (N2*N1);
        }
    }
}

// Check the conjugate-even output of the forward transform.
//
// `data` holds N2 rows of N1/2+1 complex values stored as interleaved
// (re, im) doubles. The expected value is 1 + 0i where the index matches the
// harmonic (H1, H2) or its mirror (-H1, -H2), both modulo (N1, N2), and 0
// elsewhere. The check stops at the first element whose absolute error is
// not below the threshold (a NaN error also fails) and reports it.
//
// Returns SUCCESS if every element matches, FAILURE otherwise.
static int verify_c(double* data,
                    int N1, int N2, int H1, int H2)
{
    // Note: this simple error bound doesn't take into account error of
    //       input data
    const double errthr = 2.5 * log(static_cast<double>(N2) * N1) / log(2.0) * DBL_EPSILON;
    std::cout << "\t\tVerify the result, errthr = " << errthr << std::endl;

    const int row_len = N1/2 + 1;   // complex elements stored per row
    double maxerr = 0.0;

    for (int n2 = 0; n2 < N2; ++n2) {
        const double *row = data + 2 * n2 * row_len;
        for (int n1 = 0; n1 < row_len; ++n1) {
            const bool at_harmonic = ((n1 - H1) % N1 == 0) && ((n2 - H2) % N2 == 0);
            const bool at_mirror   = ((-n1 - H1) % N1 == 0) && ((-n2 - H2) % N2 == 0);
            const double re_exp = (at_harmonic || at_mirror) ? 1.0 : 0.0;
            const double im_exp = 0.0;

            const double re_got = row[2 * n1];
            const double im_got = row[2 * n1 + 1];
            const double err = fabs(re_got - re_exp) + fabs(im_got - im_exp);
            maxerr = (err > maxerr) ? err : maxerr;
            if (err < errthr)
                continue;

            std::cout << "\t\tdata"
                      << "[" << n2 << "][" << n1 << "]: "
                      << " expected (" << re_exp << "," << im_exp << ")"
                      << " got (" << re_got << "," << im_got << ")"
                      << " err " << err << std::endl;
            std::cout << "\t\tVerification FAILED" << std::endl;
            return FAILURE;
        }
    }
    std::cout << "\t\tVerified, maximum error was " << maxerr << std::endl;
    return SUCCESS;
}

// Fill the conjugate-even half spectrum `data` (N2 rows of N1/2+1 complex
// values, interleaved as re, im) with
//     exp(-2*pi*i*(n1*H1/N1 + n2*H2/N2)) / (N1*N2),
// whose backward (complex-to-real) DFT has a unit peak at (H1, H2).
static void init_c(double *data, int N1, int N2, int H1, int H2)
{
    const int row_len = N1/2 + 1;      // complex elements stored per row
    const double norm = N2*N1;

    for (int n2 = 0; n2 < N2; ++n2) {
        const double row_phase = moda(n2, H2, N2) / N2;
        for (int n1 = 0; n1 < row_len; ++n1) {
            const double phase = moda(n1, H1, N1) / N1 + row_phase;
            double *z = data + 2 * (n2 * row_len + n1);
            z[0] =  cos(TWOPI * phase) / norm;
            z[1] = -sin(TWOPI * phase) / norm;
        }
    }
}

// Check the real output of the backward transform.
//
// `data` holds N2 rows of N1 real values (row stride N1). The expected value
// is 1.0 where the index matches the harmonic (H1, H2) modulo (N1, N2), and
// 0.0 elsewhere. The check stops at the first element whose absolute error is
// not below the threshold (a NaN error also fails) and reports it.
//
// Returns SUCCESS if every element matches, FAILURE otherwise.
static int verify_r(double *data,
                    int N1, int N2, int H1, int H2)
{
    // Note: this simple error bound doesn't take into account error of
    //       input data
    const double errthr = 2.5 * log(static_cast<double>(N2) * N1) / log(2.0) * DBL_EPSILON;
    std::cout << "\t\tVerify the result, errthr = " << errthr << std::endl;

    double maxerr = 0.0;
    for (int n2 = 0; n2 < N2; ++n2) {
        const double *row = data + n2 * N1;
        const bool row_matches = ((n2 - H2) % N2 == 0);
        for (int n1 = 0; n1 < N1; ++n1) {
            const bool at_harmonic = ((n1 - H1) % N1 == 0) && row_matches;
            const double re_exp = at_harmonic ? 1.0 : 0.0;
            const double re_got = row[n1];
            const double err = fabs(re_got - re_exp);
            maxerr = (err > maxerr) ? err : maxerr;
            if (err < errthr)
                continue;

            std::cout << "\t\tdata"
                      << "[" << n2 << "][" << n1 << "]: "
                      << " expected (" << re_exp << ")"
                      << " got (" << re_got << ")"
                      << " err " << err << std::endl;
            std::cout << "\t\tVerification FAILED" << std::endl;
            return FAILURE;
        }
    }
    std::cout << "\t\tVerified, maximum error was " << maxerr << std::endl;
    return SUCCESS;
}

// Check the real output of the forward+backward round trip.
//
// `data` holds N2 rows of N1 real values (row stride N1). Each value is
// multiplied by 1/(N1*N2) and compared against
//     factor * cos(2*pi*(n1*H1/N1 + n2*H2/N2)) / (N1*N2),
// the same signal init_r() builds for (H1, H2). The check stops at the first
// element whose absolute error is not below the threshold (a NaN error also
// fails) and reports it.
//
// Returns SUCCESS if every element matches, FAILURE otherwise.
static int verify_round(double *data, int N1, int N2, int H1, int H2)
{
    const double errthr = 2.5 * log(static_cast<double>(N2) * N1) / log(2.0) * DBL_EPSILON;
    std::cout << "\t\tVerify the result, errthr = " << errthr << std::endl;

    // Scale applied to every output value before it is compared.
    const double normalizationConstant = 1.0 / (N1 * N2);
    // 1 when the harmonic is its own mirror in both dimensions, 2 otherwise.
    const bool self_mirrored = ((2 * (N1 - H1)) % N1 == 0) &&
                               ((2 * (N2 - H2)) % N2 == 0);
    const double factor = self_mirrored ? 1.0 : 2.0;

    double maxerr = 0.0;
    for (int n2 = 0; n2 < N2; ++n2) {
        const double *row = data + n2 * N1;
        const double row_phase = moda(n2, H2, N2) / N2;
        for (int n1 = 0; n1 < N1; ++n1) {
            const double phase = moda(n1, H1, N1) / N1 + row_phase;
            const double re_exp = factor * cos(TWOPI * phase) / (N2*N1);
            const double re_got = normalizationConstant * row[n1];
            const double err = fabs(re_got - re_exp);
            maxerr = (err > maxerr) ? err : maxerr;
            if (err < errthr)
                continue;

            std::cout << "\t\tdata"
                        << "[" << n2 << "][" << n1 << "]: "
                        << " expected (" << re_exp << ")"
                        << " got (" << re_got << ")"
                        << " err " << err << std::endl;
            std::cout << "\t\tVerification FAILED" << std::endl;
            return FAILURE;
        }
    }

    std::cout << "\t\tVerified, maximum error was " << maxerr << std::endl;
    return SUCCESS;
}

// Forward (real-to-complex) 2-D DFT example.
//
// Builds an N2 x N1 row-major real signal with init_r() and runs an
// out-of-place forward transform on `dev` twice: once through SYCL buffers
// and once through USM shared allocations. Each conjugate-even result is
// checked with verify_c().
//
// Returns SUCCESS only if both variants verify. Returns FAILURE otherwise,
// including when an allocation fails or an exception is caught.
int run_dft_forward_example(cl::sycl::device &dev) {
    //
    // Initialize data for DFT
    //
    int N1 = 4, N2 = 5;
    // Arbitrary harmonic used to verify FFT
    int H1 = -1, H2 = 2;
    // Strides describe data layout in real and conjugate-even domain:
    // {offset, stride between rows (n2), stride between elements (n1)}
    std::int64_t rs[3] = {0, N1, 1};
    std::int64_t cs[3] = {0, (N1/2+1), 1};
    int buffer_result = FAILURE;
    int usm_result = FAILURE;
    int result = FAILURE;

    double* x = (double*) mkl_malloc(N2*N1*sizeof(double), 64);
    if (x == nullptr) {
        std::cout << "\t\tFailed to allocate host memory" << std::endl;
        return FAILURE;
    }
    init_r(x, N1, N2, H1, H2);

    //
    // Execute DFT
    //
    try {
        // Catch asynchronous exceptions
        auto exception_handler = [] (cl::sycl::exception_list exceptions) {
            for (std::exception_ptr const& e : exceptions) {
                try {
                    std::rethrow_exception(e);
                } catch(cl::sycl::exception const& e) {
                    std::cout << "Caught asynchronous SYCL exception:" << std::endl
                              << e.what() << std::endl;
                }
            }
        };

        // create execution queue with asynchronous error handling
        cl::sycl::queue queue(dev, exception_handler);

        // Setting up SYCL buffer and initialization
        cl::sycl::buffer<double, 1> xBuffer(x, cl::sycl::range<1>(N2*N1));
        cl::sycl::buffer<double, 1> yBuffer(cl::sycl::range<1>(N2*(N1/2 + 1)*2));
        xBuffer.set_write_back(false);
        yBuffer.set_write_back(false);

        // Setting up USM and initialization. unique_ptr owns the allocations,
        // so they are released on every exit path, including when an
        // exception propagates out of this try block.
        auto usm_free = [ctx = queue.get_context()](double *p) { cl::sycl::free(p, ctx); };
        std::unique_ptr<double, decltype(usm_free)> x_usm(
            static_cast<double*>(cl::sycl::malloc_shared(N2*N1*sizeof(double),
                                                         queue.get_device(), queue.get_context())),
            usm_free);
        std::unique_ptr<double, decltype(usm_free)> y_usm(
            static_cast<double*>(cl::sycl::malloc_shared((N2*(N1/2 + 1)*2)*sizeof(double),
                                                         queue.get_device(), queue.get_context())),
            usm_free);
        // malloc_shared reports failure by returning nullptr.
        if (!x_usm || !y_usm)
            throw std::runtime_error("USM shared allocation failed");
        init_r(x_usm.get(), N1, N2, H1, H2);

        descriptor_t desc({N2, N1});
        desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
        desc.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, rs);
        desc.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, cs);
        desc.commit(queue);

        // Using SYCL buffers
        std::cout<<"\tUsing SYCL buffers"<<std::endl;
        oneapi::mkl::dft::compute_forward(desc, xBuffer, yBuffer);

        // Creating the host accessor waits for the transform that writes yBuffer.
        auto yAcc = yBuffer.get_access<cl::sycl::access::mode::read>();
        buffer_result = verify_c(yAcc.get_pointer(), N1, N2, H1, H2);

        // Using USM
        std::cout<<"\tUsing USM"<<std::endl;
        auto fwd = oneapi::mkl::dft::compute_forward(desc, x_usm.get(), y_usm.get());
        fwd.wait();
        usm_result = verify_c(y_usm.get(), N1, N2, H1, H2);

        if ((buffer_result == SUCCESS) && (usm_result == SUCCESS))
            result = SUCCESS;
    }
    catch(cl::sycl::exception const& e) {
        std::cout << "\t\tSYCL exception during FFT" << std::endl;
        std::cout << "\t\t" << e.what() << std::endl;
        std::cout << "\t\tError code: " << get_error_code(e) << std::endl;
    }
    catch(std::runtime_error const& e) {
        std::cout << "\t\truntime exception during FFT" << std::endl;
        std::cout << "\t\t" << e.what() << std::endl;
    }
    catch(std::exception const& e) {
        // NOTE(review): oneMKL's oneapi::mkl::exception hierarchy is believed
        // to derive from std::exception rather than std::runtime_error.
        // Without this handler such errors would escape and skip mkl_free.
        std::cout << "\t\texception during FFT" << std::endl;
        std::cout << "\t\t" << e.what() << std::endl;
    }
    mkl_free(x);

    return result;
}

// Backward (complex-to-real) 2-D DFT example.
//
// Builds the conjugate-even half spectrum with init_c() and runs an
// out-of-place backward transform on `dev` twice: once through SYCL buffers
// and once through USM shared allocations. Each real result is checked with
// verify_r().
//
// Returns SUCCESS only if both variants verify. Returns FAILURE otherwise,
// including when an allocation fails or an exception is caught.
int run_dft_backward_example(cl::sycl::device &dev) {
    //
    // Initialize data for DFT
    //
    int N1 = 4, N2 = 5;
    // Arbitrary harmonic used to verify FFT
    int H1 = -1, H2 = 2;
    // Strides describe data layout in real and conjugate-even domain:
    // {offset, stride between rows (n2), stride between elements (n1)}
    std::int64_t rs[3] = {0, N1, 1};
    std::int64_t cs[3] = {0, (N1/2+1), 1};
    int buffer_result = FAILURE;
    int usm_result = FAILURE;
    int result = FAILURE;

    double* x = (double*) mkl_malloc((N2*(N1/2 + 1)*2)*sizeof(double), 64);
    if (x == nullptr) {
        std::cout << "\t\tFailed to allocate host memory" << std::endl;
        return FAILURE;
    }
    init_c(x, N1, N2, H1, H2);

    //
    // Execute DFT
    //
    try {
        // Catch asynchronous exceptions
        auto exception_handler = [] (cl::sycl::exception_list exceptions) {
            for (std::exception_ptr const& e : exceptions) {
                try {
                    std::rethrow_exception(e);
                } catch(cl::sycl::exception const& e) {
                    std::cout << "Caught asynchronous SYCL exception:" << std::endl
                              << e.what() << std::endl;
                }
            }
        };

        // create execution queue with asynchronous error handling
        cl::sycl::queue queue(dev, exception_handler);

        // Setting up SYCL buffer and initialization. Nothing needs to be
        // copied back to the host input array, consistent with the other
        // examples.
        cl::sycl::buffer<double, 1> xBuffer(x, cl::sycl::range<1>(N2*(N1/2 + 1)*2));
        cl::sycl::buffer<double, 1> yBuffer(cl::sycl::range<1>(N2*N1));
        xBuffer.set_write_back(false);
        yBuffer.set_write_back(false);

        // Setting up USM and initialization. unique_ptr owns the allocations,
        // so they are released on every exit path, including when an
        // exception propagates out of this try block.
        auto usm_free = [ctx = queue.get_context()](double *p) { cl::sycl::free(p, ctx); };
        std::unique_ptr<double, decltype(usm_free)> x_usm(
            static_cast<double*>(cl::sycl::malloc_shared((N2*(N1/2+1)*2)*sizeof(double),
                                                         queue.get_device(), queue.get_context())),
            usm_free);
        std::unique_ptr<double, decltype(usm_free)> y_usm(
            static_cast<double*>(cl::sycl::malloc_shared(N2*N1*sizeof(double),
                                                         queue.get_device(), queue.get_context())),
            usm_free);
        // malloc_shared reports failure by returning nullptr.
        if (!x_usm || !y_usm)
            throw std::runtime_error("USM shared allocation failed");
        init_c(x_usm.get(), N1, N2, H1, H2);

        descriptor_t desc({N2, N1});
        desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
        desc.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, cs);
        desc.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, rs);
        desc.commit(queue);

        // Using SYCL buffers
        std::cout<<"\tusing SYCL buffers"<<std::endl;
        oneapi::mkl::dft::compute_backward(desc, xBuffer, yBuffer);
        // Creating the host accessor waits for the transform that writes yBuffer.
        auto yAcc = yBuffer.get_access<cl::sycl::access::mode::read>();
        buffer_result = verify_r(yAcc.get_pointer(), N1, N2, H1, H2);

        // Using USM
        std::cout<<"\tUsing USM"<<std::endl;
        auto bwd = oneapi::mkl::dft::compute_backward(desc, x_usm.get(), y_usm.get());
        bwd.wait();
        usm_result = verify_r(y_usm.get(), N1, N2, H1, H2);

        if ((buffer_result == SUCCESS) && (usm_result == SUCCESS))
            result = SUCCESS;
    }
    catch(cl::sycl::exception const& e) {
        std::cout << "\t\tSYCL exception during FFT" << std::endl;
        std::cout << "\t\t" << e.what() << std::endl;
        std::cout << "\t\tError code: " << get_error_code(e) << std::endl;
    }
    catch(std::runtime_error const& e) {
        std::cout << "\t\truntime exception during FFT" << std::endl;
        std::cout << "\t\t" << e.what() << std::endl;
    }
    catch(std::exception const& e) {
        // NOTE(review): oneMKL's oneapi::mkl::exception hierarchy is believed
        // to derive from std::exception rather than std::runtime_error.
        // Without this handler such errors would escape and skip mkl_free.
        std::cout << "\t\texception during FFT" << std::endl;
        std::cout << "\t\t" << e.what() << std::endl;
    }
    mkl_free(x);

    return result;
}

// Round-trip (forward then backward) 2-D real DFT example.
//
// Builds an N2 x N1 real signal with init_r(). It then runs an out-of-place
// forward transform, re-commits the same descriptor with swapped strides, and
// runs the backward transform. This is done once through SYCL buffers and
// once through USM shared allocations. Each final real result is checked with
// verify_round(), which scales it by 1/(N1*N2) before comparing against the
// original signal.
//
// Returns SUCCESS only if both variants verify. Returns FAILURE otherwise,
// including when an allocation fails or an exception is caught.
int run_dft_roundtrip_example(cl::sycl::device &dev) {
    //
    // Initialize data for DFT
    //
    int N1 = 4, N2 = 5;
    // Arbitrary harmonic used to verify FFT
    int H1 = -1, H2 = 2;
    // Strides describe data layout in real and conjugate-even domain:
    // {offset, stride between rows (n2), stride between elements (n1)}
    std::int64_t cs[3] = {0, (N1/2+1), 1};
    std::int64_t rs[3] = {0, N1, 1};
    int buffer_result = FAILURE;
    int usm_result = FAILURE;
    int result = FAILURE;

    double* x = (double *) mkl_malloc(N2*N1*sizeof(double), 64);
    if (x == nullptr) {
        std::cout << "\t\tFailed to allocate host memory" << std::endl;
        return FAILURE;
    }
    // Use the same harmonic that verify_round() checks against.
    init_r(x, N1, N2, H1, H2);

    //
    // Execute DFT
    //
    try {
        // Catch asynchronous exceptions
        auto exception_handler = [] (cl::sycl::exception_list exceptions) {
            for (std::exception_ptr const& e : exceptions) {
                try {
                    std::rethrow_exception(e);
                } catch(cl::sycl::exception const& e) {
                    std::cout << "Caught asynchronous SYCL exception:" << std::endl
                              << e.what() << std::endl;
                }
            }
        };
        cl::sycl::queue queue(dev, exception_handler);

        // Setting up buffers and initialization
        cl::sycl::buffer<double, 1> realInBuffer(x, cl::sycl::range<1>(N2*N1));
        cl::sycl::buffer<double, 1> complexOutBuffer(cl::sycl::range<1>(N2*(N1/2 + 1)*2));
        cl::sycl::buffer<double, 1> realOutBuffer(cl::sycl::range<1>(N2*N1));
        realInBuffer.set_write_back(false);
        complexOutBuffer.set_write_back(false);
        realOutBuffer.set_write_back(false);

        // Setting up USM and initialization. unique_ptr owns the allocations,
        // so they are released on every exit path, including when an
        // exception propagates out of this try block.
        auto usm_free = [ctx = queue.get_context()](double *p) { cl::sycl::free(p, ctx); };
        std::unique_ptr<double, decltype(usm_free)> realInUSM(
            static_cast<double*>(cl::sycl::malloc_shared(N2*N1*sizeof(double),
                                                         queue.get_device(), queue.get_context())),
            usm_free);
        std::unique_ptr<double, decltype(usm_free)> complexOutUSM(
            static_cast<double*>(cl::sycl::malloc_shared((N2*(N1/2+1)*2)*sizeof(double),
                                                         queue.get_device(), queue.get_context())),
            usm_free);
        std::unique_ptr<double, decltype(usm_free)> realOutUSM(
            static_cast<double*>(cl::sycl::malloc_shared(N2*N1*sizeof(double),
                                                         queue.get_device(), queue.get_context())),
            usm_free);
        // malloc_shared reports failure by returning nullptr.
        if (!realInUSM || !complexOutUSM || !realOutUSM)
            throw std::runtime_error("USM shared allocation failed");
        init_r(realInUSM.get(), N1, N2, H1, H2);

        descriptor_t desc({N2, N1});

        //-----------------------------------------------------------------
        // Forward: real input -> conjugate-even output

        desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
        desc.set_value(oneapi::mkl::dft::config_param::CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX);
        desc.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, rs);
        desc.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, cs);
        desc.commit(queue);

        // Using SYCL buffers
        oneapi::mkl::dft::compute_forward(desc, realInBuffer, complexOutBuffer);

        // Using USM
        auto fwd = oneapi::mkl::dft::compute_forward(desc, realInUSM.get(), complexOutUSM.get());
        fwd.wait();

        //--------------------------------------------------------------------
        // Backward: conjugate-even input -> real output (strides swapped)

        desc.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, cs);
        desc.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, rs);
        desc.commit(queue);

        // Using SYCL buffers
        oneapi::mkl::dft::compute_backward(desc, complexOutBuffer, realOutBuffer);

        std::cout<<"\tusing SYCL buffers"<<std::endl;
        // The temporary host accessor waits for the backward transform and
        // stays alive until verify_round() returns.
        buffer_result = verify_round((realOutBuffer.get_access<cl::sycl::access::mode::read>()).
                                                get_pointer(), N1, N2, H1, H2);

        // Using USM
        auto bwd = oneapi::mkl::dft::compute_backward(desc, complexOutUSM.get(), realOutUSM.get());
        bwd.wait();

        std::cout<<"\tUsing USM"<<std::endl;
        usm_result = verify_round(realOutUSM.get(), N1, N2, H1, H2);

        if((buffer_result == SUCCESS) && (usm_result == SUCCESS))
            result = SUCCESS;
    }
    catch(cl::sycl::exception const& e) {
        std::cout << "\t\tSYCL exception during FFT" << std::endl;
        std::cout << "\t\t" << e.what() << std::endl;
        std::cout << "\t\tError code: " << get_error_code(e) << std::endl;
    }
    catch(std::runtime_error const& e) {
        std::cout << "\t\truntime exception during FFT" << std::endl;
        std::cout << "\t\t" << e.what() << std::endl;
    }
    catch(std::exception const& e) {
        // NOTE(review): oneMKL's oneapi::mkl::exception hierarchy is believed
        // to derive from std::exception rather than std::runtime_error.
        // Without this handler such errors would escape and skip mkl_free.
        std::cout << "\t\texception during FFT" << std::endl;
        std::cout << "\t\t" << e.what() << std::endl;
    }
    mkl_free(x);

    return result;
}

//
// Print a banner describing the example: the API it exercises (dft) and the
// floating point types it supports. Each line is terminated with std::endl.
//
void print_example_banner() {
    static const char *const banner_lines[] = {
        "",
        "########################################################################",
        "# 2D FFT Real-Complex Double-Precision Example: ",
        "# ",
        "# Using apis:",
        "#   dft",
        "# ",
        "# Supported floating point type precisions:",
        "#   double",
        "#   std::complex<double>",
        "# ",
        "########################################################################",
        "",
    };
    for (const char *line : banner_lines)
        std::cout << line << std::endl;
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_host -- only runs host implementation
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU implementation
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- runs on all: host, cpu and gpu devices
//
//  For each device selected and each supported data type, Basic_Dp_R2C_2D_FFTExample
//  runs is with all supported data types
//
int main() {
    print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    int returnCode = 0;
    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {
        cl::sycl::device my_dev;
        bool my_dev_is_found = false;
        get_sycl_device(my_dev, my_dev_is_found, *it);

        if (my_dev_is_found) {
            if (!isDoubleSupported(my_dev)) {
                std::cout << "Double precision not supported on this device " << std::endl;
                std::cout << std::endl;
                continue;
            }
            int status;
            std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";

            std::cout << "\tRunning with double precision real-to-complex 2-D FFT:" << std::endl;
            status = run_dft_forward_example(my_dev);
            if (status != SUCCESS) {
                std::cout << "\tTest Forward Failed" << std::endl << std::endl;
                returnCode = status;
            } else {
                std::cout << "\tTest Forward Passed" << std::endl << std::endl;
            }
            status = run_dft_backward_example(my_dev);
            if (status != SUCCESS) {
                std::cout << "\tTest Backward Failed" << std::endl << std::endl;
                returnCode = status;
            } else {
                std::cout << "\tTest Backward Passed" << std::endl << std::endl;
            }
            status = run_dft_roundtrip_example(my_dev);
            if (status != SUCCESS) {
                std::cout << "\tTest Roundtrip Failed" << std::endl << std::endl;
                returnCode = status;
            } else {
                std::cout << "\tTest Roundtrip Passed" << std::endl << std::endl;
            }
        } else {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[*it] << " devices found; Fail on missing devices is enabled." << std::endl;
            return 1;
#else
            std::cout << "No " << sycl_device_names[*it] << " devices found; skipping " << sycl_device_names[*it] << " tests." << std::endl << std::endl;
#endif
        }
    }

    return returnCode;
}
