!===============================================================================
! Copyright 2020-2022 Intel Corporation.
!
! This software and the related documents are Intel copyrighted  materials,  and
! your use of  them is  governed by the  express license  under which  they were
! provided to you (License).  Unless the License provides otherwise, you may not
! use, modify, copy, publish, distribute,  disclose or transmit this software or
! the related documents without Intel's prior written permission.
!
! This software and the related documents  are provided as  is,  with no express
! or implied  warranties,  other  than those  that are  expressly stated  in the
! License.
!===============================================================================

! Content:
!       Example of using sfftw_plan_dft_1d function on a (GPU) device
!       using the OpenMP target (asynchronous offload) interface.
!
!*****************************************************************************

include "fftw/offload/fftw3_omp_offload.f90"

! Example driver: forward and backward single-precision complex 1D in-place
! FFT through the sfftw_* interface, dispatched with OpenMP
! "target variant dispatch ... nowait" and synchronized with "taskwait".
!
! Flow:
!   1. allocate and initialize x on the host;
!   2. map x (tofrom) for the lifetime of a "target data" region;
!   3. create forward and backward plans inside the region;
!   4. run the forward transform, copy the result back and verify it;
!   5. re-initialize x, push it to the device, run the backward transform;
!   6. leave the data region (x is copied back) and verify again.
!
! Every failure (allocation, plan creation, verification) is routed to the
! common clean-up code, which only releases resources that actually exist.
! The process exit code is 0 on success and 1 on any failure.
program sp_plan_dft_1d_async

  use FFTW3_OMP_OFFLOAD
  use omp_lib, ONLY : omp_get_num_devices
  use, intrinsic :: ISO_C_BINDING

  implicit none

  include 'fftw/fftw3.f'

  ! Size of 1D transform
  integer, parameter :: N = 7

  ! Arbitrary harmonic used to verify FFT
  integer :: H = 1

  ! Working precision is single precision (using sfftw_* functions)
  integer, parameter :: WP = selected_real_kind(6,37)

  ! Execution status
  integer :: statusf = 0, statusb = 0, status = 0

  ! FFTW plan variables (0 means "no plan")
  integer*8 :: fwd = 0, bwd = 0

  ! Set to .true. only when both plans have been created
  logical :: plans_ok = .false.

  ! The data array
  complex(WP), allocatable :: x(:)

  print *,"Example sp_plan_dft_1d_async"
  print *,"Forward and backward complex 1D in-place asynchronous  FFT"
  print *,"Configuration parameters:"
  print '("  N = ",I0)', N
  print '("  H = ",I0)', H

  print *,"Allocate data array"
  allocate ( x(N), STAT = status)

  if (0 == status) then

    print *,"Initialize data for forward FFT"
    call init(x, N, H)

    print *,"Create FFTW plan for forward in-place transform"
    !$omp target data map(tofrom:x)
    !$omp target variant dispatch use_device_ptr(x)
    call sfftw_plan_dft_1d(fwd, N, x, x, FFTW_FORWARD, FFTW_ESTIMATE)
    !$omp end target variant dispatch
    if (0 == fwd) then
      print *, "Call to sfftw_plan_dft_1d for forward transform has failed"
    end if

    print *,"Create FFTW plan for backward in-place transform"
    !$omp target variant dispatch use_device_ptr(x)
    call sfftw_plan_dft_1d(bwd, N, x, x, FFTW_BACKWARD, FFTW_ESTIMATE)
    !$omp end target variant dispatch
    if (0 == bwd) then
      print *, "Call to sfftw_plan_dft_1d for backward transform has failed"
    end if

    ! Never execute a transform with a null plan. The transforms are skipped
    ! with a block IF rather than a jump, because control must not branch
    ! out of the target data region.
    plans_ok = (0 /= fwd) .and. (0 /= bwd)

    if (plans_ok) then
      print *,"Compute forward transform"
      !$omp target variant dispatch use_device_ptr(x) nowait
      call sfftw_execute_dft(fwd, x, x)
      !$omp end target variant dispatch
      !$omp taskwait

      ! Update the host with the results from forward FFT
      !$omp target update from(x)

      print *,"Verify the result of the forward transform"
      statusf = verify(x, N, H)

      print *,"Initialize data for backward FFT"
      call init(x, N, -H)

      ! Update the device with input for backward FFT
      !$omp target update to(x)

      print *,"Compute backward transform"
      !$omp target variant dispatch use_device_ptr(x) nowait
      call sfftw_execute_dft(bwd, x, x)
      !$omp end target variant dispatch
      !$omp taskwait
    end if
    !$omp end target data

    ! x was mapped "tofrom", so the host copy now holds the backward result
    if (plans_ok) then
      print *,"Verify the result of the backward transform"
      statusb = verify(x, N, H)
    end if

  end if

  ! Fold every failure mode into the single overall status
  if ((0 /= status) .or. (.not. plans_ok) .or. &
      (0 /= statusf) .or. (0 /= statusb)) then
    print '("  Error, status forward = ",I0)', statusf
    print '(" Error, status backward = ",I0)', statusb
    status = 1
  end if

  print *,"Destroy FFTW plans"
  if (0 /= fwd) call sfftw_destroy_plan(fwd)
  if (0 /= bwd) call sfftw_destroy_plan(bwd)

  print *,"Deallocate data array"
  ! Guarded: x is unallocated when the initial allocation failed
  if (allocated(x)) deallocate(x)

  if (status == 0) then
    print *,"TEST PASSED"
    call exit(0)
  else
    print *,"TEST FAILED"
    call exit(1)
  endif

contains

  ! Compute mod(k*l, m) accurately.
  !
  ! The product k*l is formed in 64-bit integer arithmetic, so it does not
  ! overflow when k and l are 32-bit default integers. As with the MOD
  ! intrinsic, the result carries the sign of the product k*l; it is not
  ! folded into the range 0..m-1.
  !
  !   k, l : factors of the product
  !   m    : modulus (must be non-zero)
  pure integer*8 function moda(k,l,m)
    integer, intent(in) :: k,l,m
    integer*8 :: k8, m8
    ! MOD requires both arguments to have the same kind; k8*l is 64-bit,
    ! so the modulus is widened explicitly as well.
    k8 = k
    m8 = m
    moda = mod(k8*l, m8)
  end function moda

  ! Initialize array x(N) with harmonic H
  !
  ! Sets x(k) = exp(i*2*pi*moda(k-1,H,N)/N) / N for k = 1..N.
  !
  !   x : array to fill; it must hold at least N elements, and only
  !       x(1:N) is defined on return
  !   N : number of elements to initialize
  !   H : harmonic index (may be negative)
  subroutine init(x, N, H)
    integer, intent(in) :: N, H
    complex(WP), intent(out) :: x(:)

    integer :: k
    complex(WP), parameter :: I_TWOPI = (0.0_WP,6.2831853071795864769_WP)

    do k = 1, N
      ! The phase index (k-1)*H is reduced by moda before it is converted
      ! to the working real kind
      x(k) = exp( I_TWOPI * real(moda(k-1, H, N), WP)/N ) / N
    end do
  end subroutine init

  ! Verify that x(N) is unit peak at x(H)
  !
  ! The expected value is 1 at every k with mod(k-1-H, N) == 0 and 0
  ! elsewhere. Each element is compared against an error threshold of
  ! 5*log2(N)*epsilon; the first element that exceeds it is reported and
  ! the check stops.
  !
  !   x : data to check (elements 1..N are read)
  !   N : number of elements
  !   H : harmonic index locating the expected peak
  !
  ! Returns 0 when all N elements are within the threshold, 1 otherwise.
  ! Progress and diagnostics are written to standard output.
  integer function verify(x, N, H)
    integer, intent(in) :: N, H
    complex(WP), intent(in) :: x(:)

    integer :: k
    real(WP) :: err, errthr, maxerr
    complex(WP) :: res_exp, res_got

    ! Note, this simple error bound doesn't take into account error of
    ! input data
    errthr = 5.0_WP * log(real(N, WP)) / log(2.0_WP) * EPSILON(1.0_WP)
    print '("  Check if err is below errthr ",G10.3)', errthr

    maxerr = 0.0_WP
    do k = 1, N
      if (mod(k-1-H,N)==0) then
        res_exp = 1.0_WP
      else
        res_exp = 0.0_WP
      end if
      res_got = x(k)
      err = abs(res_got - res_exp)
      maxerr = max(err,maxerr)
      ! Written as .not.(err <= errthr) so that a NaN error is rejected too.
      ! "<=" rather than "<": for N = 1 the threshold is exactly zero and an
      ! exact result must still pass.
      if (.not.(err <= errthr)) then
        ! Non-advancing output keeps the whole diagnostic on one line; each
        ! complex value is written with one descriptor per component.
        write (*, '("  x(",I0,"): ")', advance='no') k
        write (*, '(" expected (",G14.7,",",G14.7,"),")', advance='no') res_exp
        write (*, '(" got (",G14.7,",",G14.7,"),")', advance='no') res_got
        write (*, '(" err ",G10.3)') err
        print *," Verification FAILED"
        verify = 1
        return
      end if
    end do
    print '("  Verified,  maximum error was ",G10.3)', maxerr
    verify = 0
  end function verify

end program sp_plan_dft_1d_async
