!===============================================================================
! Copyright (C) 2021 Intel Corporation
!
! This software and the related documents are Intel copyrighted  materials,  and
! your use of  them is  governed by the  express license  under which  they were
! provided to you (License).  Unless the License provides otherwise, you may not
! use, modify, copy, publish, distribute,  disclose or transmit this software or
! the related documents without Intel's prior written permission.
!
! This software and the related documents  are provided as  is,  with no express
! or implied  warranties,  other  than those  that are  expressly stated  in the
! License.
!===============================================================================

!  Content:
!      Intel(R) oneAPI Math Kernel Library (oneMKL)
!      FORTRAN OpenMP offload example for DGEMM_BATCH
!*******************************************************************************

include "mkl_omp_offload.f90"
include "common_blas.f90"

!> Run a grouped batch of DGEMM operations, C := alpha*op(A)*op(B) + beta*C,
!> twice and compare the results:
!>   1. on the host via dgemm_batch, writing into c_ref (reference result);
!>   2. on the OpenMP target device via `!$omp dispatch`, writing into c.
!> The comparison is done by dcheck_batch_matrix (from common_blas); a non-zero
!> return is treated as failure. The program exits with status 1 on a mismatch
!> or an allocation failure and prints "PASSED" otherwise.
program dgemm_batch_example
#if defined(MKL_ILP64)
use onemkl_blas_omp_offload_ilp64
#else
use onemkl_blas_omp_offload_lp64
#endif
use common_blas
use, intrinsic :: iso_c_binding, only: c_size_t
implicit none

! Number of groups; every per-group descriptor array below is sized by it so
! the array extents and group_count can never disagree.
integer, parameter :: ngroups = 2

character(len=1) :: ta(ngroups), tb(ngroups)
double precision :: alpha(ngroups), beta(ngroups)
integer :: m(ngroups), n(ngroups), k(ngroups)
integer :: lda(ngroups), cola(ngroups), ldb(ngroups), colb(ngroups)
integer :: ldc(ngroups), colc(ngroups)
integer :: group_size(ngroups), group_count
integer :: total_batch_size, i, alloc_stat, check_result
integer :: max_lda, max_ldb, max_ldc, max_cola, max_colb, max_colc
double precision, allocatable, target :: a(:,:,:), b(:,:,:), c(:,:,:)
double precision, allocatable :: c_ref(:,:,:)
! Host addresses of each matrix, consumed by the host reference call.
integer(kind=c_size_t), allocatable :: a_array(:), b_array(:), c_ref_array(:)
! Device addresses of each mapped matrix, consumed by the offloaded call.
integer(kind=c_size_t), allocatable :: a_array_dev(:), b_array_dev(:), c_array_dev(:)
! Temporaries used to obtain the device address of one slab at a time.
double precision, pointer :: tmp_a(:,:), tmp_b(:,:), tmp_c(:,:)

group_count = ngroups

! Describe each group: problem sizes, transpose flags, scalars, and the
! (leading dimension, column count) storage extents implied by the flags.
do i = 1, group_count
  m(i) = i + 10
  k(i) = i + 5
  n(i) = i + 8
  ta(i) = 'N'
  tb(i) = 'N'
  alpha(i) = 1.0d0
  beta(i) = 1.0d0

  ! op(A) is m x k: A is stored m x k when untransposed, k x m otherwise.
  if (ta(i) == 'N') then
    lda(i) = m(i)
    cola(i) = k(i)
  else
    lda(i) = k(i)
    cola(i) = m(i)
  end if

  ! op(B) is k x n: B is stored k x n when untransposed, n x k otherwise.
  if (tb(i) == 'N') then
    ldb(i) = k(i)
    colb(i) = n(i)
  else
    ldb(i) = n(i)
    colb(i) = k(i)
  end if

  ldc(i) = m(i)
  colc(i) = n(i)

  group_size(i) = 4 + i
end do

total_batch_size = sum(group_size)

! Every matrix in the batch lives in a 2-D slab sized for the largest group.
! NOTE(review): the per-group lda/ldb/ldc passed to dgemm_batch may be smaller
! than the slab's physical leading dimension; both the host and device calls
! use the same descriptors, so the comparison below is still consistent.
max_lda = maxval(lda)
max_ldb = maxval(ldb)
max_ldc = maxval(ldc)
max_cola = maxval(cola)
max_colb = maxval(colb)
max_colc = maxval(colc)

! Without stat=, a failed allocate aborts before any check could run.
allocate(a(max_lda, max_cola, total_batch_size), &
         b(max_ldb, max_colb, total_batch_size), &
         c(max_ldc, max_colc, total_batch_size), &
         c_ref(max_ldc, max_colc, total_batch_size), stat=alloc_stat)
if (alloc_stat /= 0) call alloc_failed("Cannot allocate matrices")

allocate(a_array(total_batch_size), b_array(total_batch_size), &
         c_ref_array(total_batch_size), stat=alloc_stat)
if (alloc_stat /= 0) call alloc_failed("Cannot allocate array of pointers")

allocate(a_array_dev(total_batch_size), b_array_dev(total_batch_size), &
         c_array_dev(total_batch_size), stat=alloc_stat)
if (alloc_stat /= 0) call alloc_failed("Cannot allocate array of device pointers")

call dinit_batch_matrix('N', max_lda, max_cola, max_lda, a, total_batch_size)
call dinit_batch_matrix('N', max_ldb, max_colb, max_ldb, b, total_batch_size)
call dinit_batch_matrix('N', max_ldc, max_colc, max_ldc, c, total_batch_size)
call dcopy_batch_matrix('N', max_ldc, max_colc, max_ldc, c, c_ref, total_batch_size)

! Host addresses of the first element of each slab.
do i = 1, total_batch_size
  a_array(i) = LOC(a(1,1,i))
  b_array(i) = LOC(b(1,1,i))
  c_ref_array(i) = LOC(c_ref(1,1,i))
end do

! Reference result computed on the host.
call dgemm_batch(ta, tb, m, n, k, alpha, a_array, lda, b_array, ldb, beta, c_ref_array, ldc, group_count, group_size)

! Map each matrix to the device and record its device address.
do i = 1, total_batch_size
!$omp target enter data map(to:a(:,:,i),b(:,:,i),c(:,:,i))
  tmp_a => a(:,:,i)
  tmp_b => b(:,:,i)
  tmp_c => c(:,:,i)
!$omp target data use_device_addr(tmp_a,tmp_b,tmp_c)
  a_array_dev(i) = LOC(tmp_a)
  b_array_dev(i) = LOC(tmp_b)
  c_array_dev(i) = LOC(tmp_c)
!$omp end target data
end do

! Offloaded computation; the pointer arrays themselves must also be on the device.
!$omp target data map(to:a_array_dev,b_array_dev) map(tofrom:c_array_dev)
!$omp dispatch
call dgemm_batch(ta, tb, m, n, k, alpha, a_array_dev, lda, b_array_dev, ldb, beta, c_array_dev, ldc, group_count, group_size)
!$omp end target data

! a and b are only read on the device, so drop them without copying back;
! only c carries results to the host.
do i = 1, total_batch_size
!$omp target exit data map(release:a(:,:,i),b(:,:,i)) map(from:c(:,:,i))
end do

! Non-zero means the device result differs from the host reference.
check_result = dcheck_batch_matrix(max_ldc, max_colc, max_ldc, c, c_ref, total_batch_size)

deallocate(a, b, c, c_ref)
deallocate(a_array, b_array, c_ref_array)
deallocate(a_array_dev, b_array_dev, c_array_dev)

if (check_result /= 0) error stop 1
print *, "PASSED"

contains

  !> Report which allocation failed and terminate with exit status 1.
  subroutine alloc_failed(what)
    character(len=*), intent(in) :: what
    print *, what
    print *, 'Error: cannot allocate memory'
    error stop 1
  end subroutine alloc_failed

end program dgemm_batch_example
