mkl-devel-dpcpp 2024.0.0__py2.py3-none-manylinux1_x86_64.whl → 2024.2.0__py2.py3-none-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mkl-devel-dpcpp might be problematic. Click here for more details.

Files changed (81) hide show
  1. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/blas/buffer_decls.hpp +53 -15
  2. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/blas/usm_decls.hpp +186 -146
  3. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/dfti.hpp +3 -1
  4. mkl_devel_dpcpp-2024.2.0.data/data/include/oneapi/mkl/lapack/concepts.hpp +55 -0
  5. mkl_devel_dpcpp-2024.2.0.data/data/include/oneapi/mkl/lapack/exceptions.hpp +75 -0
  6. {mkl_devel_dpcpp-2024.0.0.data/data/include/oneapi/mkl → mkl_devel_dpcpp-2024.2.0.data/data/include/oneapi/mkl/lapack}/lapack.hpp +79 -315
  7. mkl_devel_dpcpp-2024.2.0.data/data/include/oneapi/mkl/lapack/scratchpad.hpp +106 -0
  8. mkl_devel_dpcpp-2024.2.0.data/data/include/oneapi/mkl/lapack.hpp +23 -0
  9. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/engines.hpp +20 -0
  10. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/functions.hpp +2 -0
  11. mkl_devel_dpcpp-2024.2.0.data/data/include/oneapi/mkl/spblas/sparse_auxiliary.hpp +68 -0
  12. mkl_devel_dpcpp-2024.2.0.data/data/include/oneapi/mkl/spblas/sparse_operations.hpp +383 -0
  13. mkl_devel_dpcpp-2024.2.0.data/data/include/oneapi/mkl/spblas/sparse_structures.hpp +194 -0
  14. mkl_devel_dpcpp-2024.2.0.data/data/include/oneapi/mkl/spblas.hpp +32 -0
  15. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/stats.hpp +2 -2
  16. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/vm/buffer.hpp +63 -1
  17. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/vm/decls.hpp +2 -2
  18. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/vm/device/detail/decls.hpp +1 -1
  19. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/vm/device/detail/dispatch.hpp +1 -1
  20. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/vm/device/detail/ep.hpp +1 -1
  21. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/vm/device/detail/ha.hpp +1 -1
  22. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/vm/device/detail/la.hpp +1 -1
  23. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/vm/device/detail/rts.hpp +1 -1
  24. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/vm/device/detail/scalar.hpp +1 -1
  25. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/vm/device/vm.hpp +1 -1
  26. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/vm/span.hpp +69 -1
  27. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/vm/usm.hpp +67 -1
  28. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/lib/libmkl_sycl.a +0 -0
  29. mkl_devel_dpcpp-2024.2.0.data/data/lib/libmkl_sycl.so +1 -0
  30. {mkl_devel_dpcpp-2024.0.0.dist-info → mkl_devel_dpcpp-2024.2.0.dist-info}/METADATA +3 -3
  31. mkl_devel_dpcpp-2024.2.0.dist-info/RECORD +79 -0
  32. mkl_devel_dpcpp-2024.0.0.data/data/include/oneapi/mkl/spblas.hpp +0 -963
  33. mkl_devel_dpcpp-2024.0.0.dist-info/RECORD +0 -71
  34. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/bfloat16.hpp +0 -0
  35. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/blas/buffer.hpp +0 -0
  36. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/blas/types.hpp +0 -0
  37. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/blas/usm.hpp +0 -0
  38. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/blas.hpp +0 -0
  39. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/exceptions.hpp +0 -0
  40. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/experimental/data_fitting/interpolate.hpp +0 -0
  41. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/experimental/data_fitting/spline_and_data_params.hpp +0 -0
  42. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/experimental/data_fitting/splines.hpp +0 -0
  43. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/experimental/data_fitting.hpp +0 -0
  44. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/export.hpp +0 -0
  45. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/detail/engine_base.hpp +0 -0
  46. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/bernoulli_impl.hpp +0 -0
  47. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/bits_impl.hpp +0 -0
  48. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/distribution_base.hpp +0 -0
  49. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/engine_base.hpp +0 -0
  50. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/engine_helpers_base.hpp +0 -0
  51. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp +0 -0
  52. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/gaussian_impl.hpp +0 -0
  53. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/lognormal_impl.hpp +0 -0
  54. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/mcg31m1_helpers_impl.hpp +0 -0
  55. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp +0 -0
  56. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/mcg59_helpers_impl.hpp +0 -0
  57. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/mcg59_impl.hpp +0 -0
  58. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/mrg32k3a_helpers_impl.hpp +0 -0
  59. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/mrg32k3a_impl.hpp +0 -0
  60. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/mrg32k3a_skip_ahead_matrix.hpp +0 -0
  61. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/philox4x32x10_helpers_impl.hpp +0 -0
  62. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/philox4x32x10_impl.hpp +0 -0
  63. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/poisson_impl.hpp +0 -0
  64. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/types.hpp +0 -0
  65. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/uniform_bits_impl.hpp +0 -0
  66. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp +0 -0
  67. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/detail/vm_wrappers.hpp +0 -0
  68. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/distributions.hpp +0 -0
  69. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/engine_helpers.hpp +0 -0
  70. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/engines.hpp +0 -0
  71. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/functions.hpp +0 -0
  72. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device/types.hpp +0 -0
  73. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/device.hpp +0 -0
  74. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng/distributions.hpp +0 -0
  75. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/rng.hpp +0 -0
  76. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/types.hpp +0 -0
  77. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl/vm.hpp +0 -0
  78. {mkl_devel_dpcpp-2024.0.0.data → mkl_devel_dpcpp-2024.2.0.data}/data/include/oneapi/mkl.hpp +0 -0
  79. {mkl_devel_dpcpp-2024.0.0.dist-info → mkl_devel_dpcpp-2024.2.0.dist-info}/LICENSE.txt +0 -0
  80. {mkl_devel_dpcpp-2024.0.0.dist-info → mkl_devel_dpcpp-2024.2.0.dist-info}/WHEEL +0 -0
  81. {mkl_devel_dpcpp-2024.0.0.dist-info → mkl_devel_dpcpp-2024.2.0.dist-info}/top_level.txt +0 -0
@@ -123,6 +123,21 @@ ONEMKL_DECLARE_BUF_TRMM(std::complex<double>)
123
123
 
124
124
  #undef ONEMKL_DECLARE_BUF_TRMM
125
125
 
126
+ #define ONEMKL_DECLARE_BUF_TRMM_OOP(T) \
127
+ DLL_EXPORT void trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, \
128
+ std::int64_t m, std::int64_t n, \
129
+ T alpha, sycl::buffer<T, 1> &a, std::int64_t lda, \
130
+ sycl::buffer<T, 1> &b, std::int64_t ldb, \
131
+ T beta, sycl::buffer<T, 1> &c, std::int64_t ldc, \
132
+ compute_mode mode = MKL_BLAS_COMPUTE_MODE);
133
+
134
+ ONEMKL_DECLARE_BUF_TRMM_OOP(float)
135
+ ONEMKL_DECLARE_BUF_TRMM_OOP(double)
136
+ ONEMKL_DECLARE_BUF_TRMM_OOP(std::complex<float>)
137
+ ONEMKL_DECLARE_BUF_TRMM_OOP(std::complex<double>)
138
+
139
+ #undef ONEMKL_DECLARE_BUF_TRMM_OOP
140
+
126
141
  #define ONEMKL_DECLARE_BUF_TRSM(T) \
127
142
  DLL_EXPORT void trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, \
128
143
  std::int64_t m, std::int64_t n, \
@@ -137,6 +152,21 @@ ONEMKL_DECLARE_BUF_TRSM(std::complex<double>)
137
152
 
138
153
  #undef ONEMKL_DECLARE_BUF_TRSM
139
154
 
155
+ #define ONEMKL_DECLARE_BUF_TRSM_OOP(T) \
156
+ DLL_EXPORT void trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, \
157
+ std::int64_t m, std::int64_t n, \
158
+ T alpha, sycl::buffer<T, 1> &a, std::int64_t lda, \
159
+ sycl::buffer<T, 1> &b, std::int64_t ldb, \
160
+ T beta, sycl::buffer<T, 1> &c, std::int64_t ldc, \
161
+ compute_mode mode = MKL_BLAS_COMPUTE_MODE);
162
+
163
+ ONEMKL_DECLARE_BUF_TRSM_OOP(float)
164
+ ONEMKL_DECLARE_BUF_TRSM_OOP(double)
165
+ ONEMKL_DECLARE_BUF_TRSM_OOP(std::complex<float>)
166
+ ONEMKL_DECLARE_BUF_TRSM_OOP(std::complex<double>)
167
+
168
+ #undef ONEMKL_DECLARE_BUF_TRSM_OOP
169
+
140
170
  // Level 2
141
171
 
142
172
  #define ONEMKL_DECLARE_BUF_DGMM(T) \
@@ -467,25 +497,33 @@ ONEMKL_DECLARE_BUF_DOTU(std::complex<double>)
467
497
 
468
498
  #undef ONEMKL_DECLARE_BUF_DOTU
469
499
 
470
- #define ONEMKL_DECLARE_BUF_IAMAX(T) \
471
- DLL_EXPORT void iamax(sycl::queue &queue, std::int64_t n, sycl::buffer<T, 1> &x, std::int64_t incx, \
472
- sycl::buffer<std::int64_t, 1> &result, index_base base=index_base::zero);
500
+ #define ONEMKL_DECLARE_BUF_IAMAX(Tf, Ti) \
501
+ DLL_EXPORT void iamax(sycl::queue &queue, std::int64_t n, sycl::buffer<Tf, 1> &x, std::int64_t incx, \
502
+ sycl::buffer<Ti, 1> &result, index_base base=index_base::zero);
473
503
 
474
- ONEMKL_DECLARE_BUF_IAMAX(float)
475
- ONEMKL_DECLARE_BUF_IAMAX(double)
476
- ONEMKL_DECLARE_BUF_IAMAX(std::complex<float>)
477
- ONEMKL_DECLARE_BUF_IAMAX(std::complex<double>)
504
+ ONEMKL_DECLARE_BUF_IAMAX(float, std::int64_t)
505
+ ONEMKL_DECLARE_BUF_IAMAX(float, std::int32_t)
506
+ ONEMKL_DECLARE_BUF_IAMAX(double, std::int64_t)
507
+ ONEMKL_DECLARE_BUF_IAMAX(double, std::int32_t)
508
+ ONEMKL_DECLARE_BUF_IAMAX(std::complex<float>, std::int64_t)
509
+ ONEMKL_DECLARE_BUF_IAMAX(std::complex<float>, std::int32_t)
510
+ ONEMKL_DECLARE_BUF_IAMAX(std::complex<double>, std::int64_t)
511
+ ONEMKL_DECLARE_BUF_IAMAX(std::complex<double>, std::int32_t)
478
512
 
479
513
  #undef ONEMKL_DECLARE_BUF_IAMAX
480
514
 
481
- #define ONEMKL_DECLARE_BUF_IAMIN(T) \
482
- DLL_EXPORT void iamin(sycl::queue &queue, std::int64_t n, sycl::buffer<T, 1> &x, std::int64_t incx, \
483
- sycl::buffer<std::int64_t, 1> &result, index_base base=index_base::zero);
484
-
485
- ONEMKL_DECLARE_BUF_IAMIN(float)
486
- ONEMKL_DECLARE_BUF_IAMIN(double)
487
- ONEMKL_DECLARE_BUF_IAMIN(std::complex<float>)
488
- ONEMKL_DECLARE_BUF_IAMIN(std::complex<double>)
515
+ #define ONEMKL_DECLARE_BUF_IAMIN(Tf, Ti) \
516
+ DLL_EXPORT void iamin(sycl::queue &queue, std::int64_t n, sycl::buffer<Tf, 1> &x, std::int64_t incx, \
517
+ sycl::buffer<Ti, 1> &result, index_base base=index_base::zero);
518
+
519
+ ONEMKL_DECLARE_BUF_IAMIN(float, std::int64_t)
520
+ ONEMKL_DECLARE_BUF_IAMIN(float, std::int32_t)
521
+ ONEMKL_DECLARE_BUF_IAMIN(double, std::int64_t)
522
+ ONEMKL_DECLARE_BUF_IAMIN(double, std::int32_t)
523
+ ONEMKL_DECLARE_BUF_IAMIN(std::complex<float>, std::int64_t)
524
+ ONEMKL_DECLARE_BUF_IAMIN(std::complex<float>, std::int32_t)
525
+ ONEMKL_DECLARE_BUF_IAMIN(std::complex<double>, std::int64_t)
526
+ ONEMKL_DECLARE_BUF_IAMIN(std::complex<double>, std::int32_t)
489
527
 
490
528
  #undef ONEMKL_DECLARE_BUF_IAMIN
491
529
 
@@ -188,6 +188,27 @@ ONEMKL_DECLARE_TRMM(std::complex<double>)
188
188
 
189
189
  #undef ONEMKL_DECLARE_TRMM
190
190
 
191
+ #define ONEMKL_DECLARE_TRMM_OOP(T) \
192
+ DLL_EXPORT sycl::event trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, \
193
+ diag unit_diag, std::int64_t m, std::int64_t n, value_or_pointer<T> alpha, const T *a, std::int64_t lda, \
194
+ const T *b, std::int64_t ldb, value_or_pointer<T> beta, T *c, std::int64_t ldc, \
195
+ compute_mode mode, const std::vector<sycl::event> &dependencies = {}); \
196
+ ONEMKL_INLINE_DECLARE sycl::event trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, \
197
+ diag unit_diag, std::int64_t m, std::int64_t n, value_or_pointer<T> alpha, const T *a, std::int64_t lda, \
198
+ const T *b, std::int64_t ldb, value_or_pointer<T> beta, T *c, std::int64_t ldc, \
199
+ const std::vector<sycl::event> &dependencies = {}) \
200
+ { \
201
+ return trmm(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, ldb, \
202
+ beta, c, ldc, MKL_BLAS_COMPUTE_MODE, dependencies); \
203
+ }
204
+
205
+ ONEMKL_DECLARE_TRMM_OOP(float)
206
+ ONEMKL_DECLARE_TRMM_OOP(double)
207
+ ONEMKL_DECLARE_TRMM_OOP(std::complex<float>)
208
+ ONEMKL_DECLARE_TRMM_OOP(std::complex<double>)
209
+
210
+ #undef ONEMKL_DECLARE_TRMM_OOP
211
+
191
212
  #define ONEMKL_DECLARE_TRSM(T) \
192
213
  DLL_EXPORT sycl::event trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, \
193
214
  std::int64_t m, std::int64_t n, \
@@ -210,6 +231,27 @@ ONEMKL_DECLARE_TRSM(std::complex<double>)
210
231
 
211
232
  #undef ONEMKL_DECLARE_TRSM
212
233
 
234
+ #define ONEMKL_DECLARE_TRSM_OOP(T) \
235
+ DLL_EXPORT sycl::event trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, \
236
+ diag unit_diag, std::int64_t m, std::int64_t n, value_or_pointer<T> alpha, const T *a, std::int64_t lda, \
237
+ const T *b, std::int64_t ldb, value_or_pointer<T> beta, T *c, std::int64_t ldc, \
238
+ compute_mode mode, const std::vector<sycl::event> &dependencies = {}); \
239
+ ONEMKL_INLINE_DECLARE sycl::event trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, \
240
+ diag unit_diag, std::int64_t m, std::int64_t n, value_or_pointer<T> alpha, const T *a, std::int64_t lda, \
241
+ const T *b, std::int64_t ldb, value_or_pointer<T> beta, T *c, std::int64_t ldc, \
242
+ const std::vector<sycl::event> &dependencies = {}) \
243
+ { \
244
+ return trsm(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, ldb, \
245
+ beta, c, ldc, MKL_BLAS_COMPUTE_MODE, dependencies); \
246
+ }
247
+
248
+ ONEMKL_DECLARE_TRSM_OOP(float)
249
+ ONEMKL_DECLARE_TRSM_OOP(double)
250
+ ONEMKL_DECLARE_TRSM_OOP(std::complex<float>)
251
+ ONEMKL_DECLARE_TRSM_OOP(std::complex<double>)
252
+
253
+ #undef ONEMKL_DECLARE_TRSM_OOP
254
+
213
255
  // Level 2
214
256
 
215
257
  #define ONEMKL_DECLARE_DGMM(T) \
@@ -543,39 +585,47 @@ ONEMKL_DECLARE_DOTU(std::complex<double>)
543
585
 
544
586
  #undef ONEMKL_DECLARE_DOTU
545
587
 
546
- #define ONEMKL_DECLARE_IAMAX(T) \
547
- DLL_EXPORT sycl::event iamax(sycl::queue &queue, std::int64_t n, const T *x, std::int64_t incx, \
548
- std::int64_t *result, index_base base, \
588
+ #define ONEMKL_DECLARE_IAMAX(Tf, Ti) \
589
+ DLL_EXPORT sycl::event iamax(sycl::queue &queue, std::int64_t n, const Tf *x, std::int64_t incx, \
590
+ Ti *result, index_base base, \
549
591
  const std::vector<sycl::event> &dependencies = {}); \
550
- ONEMKL_INLINE_DECLARE sycl::event iamax(sycl::queue &queue, std::int64_t n, const T *x, \
551
- std::int64_t incx, std::int64_t *result, \
592
+ ONEMKL_INLINE_DECLARE sycl::event iamax(sycl::queue &queue, std::int64_t n, const Tf *x, \
593
+ std::int64_t incx, Ti *result, \
552
594
  const std::vector<sycl::event> &dependencies = {}) \
553
595
  { \
554
596
  return iamax(queue, n, x, incx, result, index_base::zero, dependencies); \
555
597
  }
556
598
 
557
- ONEMKL_DECLARE_IAMAX(float)
558
- ONEMKL_DECLARE_IAMAX(double)
559
- ONEMKL_DECLARE_IAMAX(std::complex<float>)
560
- ONEMKL_DECLARE_IAMAX(std::complex<double>)
599
+ ONEMKL_DECLARE_IAMAX(float, std::int64_t)
600
+ ONEMKL_DECLARE_IAMAX(float, std::int32_t)
601
+ ONEMKL_DECLARE_IAMAX(double, std::int64_t)
602
+ ONEMKL_DECLARE_IAMAX(double, std::int32_t)
603
+ ONEMKL_DECLARE_IAMAX(std::complex<float>, std::int64_t)
604
+ ONEMKL_DECLARE_IAMAX(std::complex<float>, std::int32_t)
605
+ ONEMKL_DECLARE_IAMAX(std::complex<double>, std::int64_t)
606
+ ONEMKL_DECLARE_IAMAX(std::complex<double>, std::int32_t)
561
607
 
562
608
  #undef ONEMKL_DECLARE_IAMAX
563
609
 
564
- #define ONEMKL_DECLARE_IAMIN(T) \
565
- DLL_EXPORT sycl::event iamin(sycl::queue &queue, std::int64_t n, const T *x, std::int64_t incx, \
566
- std::int64_t *result, index_base base, \
610
+ #define ONEMKL_DECLARE_IAMIN(Tf, Ti) \
611
+ DLL_EXPORT sycl::event iamin(sycl::queue &queue, std::int64_t n, const Tf *x, std::int64_t incx, \
612
+ Ti *result, index_base base, \
567
613
  const std::vector<sycl::event> &dependencies = {}); \
568
- ONEMKL_INLINE_DECLARE sycl::event iamin(sycl::queue &queue, std::int64_t n, const T *x, \
569
- std::int64_t incx, std::int64_t *result, \
614
+ ONEMKL_INLINE_DECLARE sycl::event iamin(sycl::queue &queue, std::int64_t n, const Tf *x, \
615
+ std::int64_t incx, Ti *result, \
570
616
  const std::vector<sycl::event> &dependencies = {}) \
571
617
  { \
572
618
  return iamin(queue, n, x, incx, result, index_base::zero, dependencies); \
573
619
  }
574
620
 
575
- ONEMKL_DECLARE_IAMIN(float)
576
- ONEMKL_DECLARE_IAMIN(double)
577
- ONEMKL_DECLARE_IAMIN(std::complex<float>)
578
- ONEMKL_DECLARE_IAMIN(std::complex<double>)
621
+ ONEMKL_DECLARE_IAMIN(float, std::int64_t)
622
+ ONEMKL_DECLARE_IAMIN(float, std::int32_t)
623
+ ONEMKL_DECLARE_IAMIN(double, std::int64_t)
624
+ ONEMKL_DECLARE_IAMIN(double, std::int32_t)
625
+ ONEMKL_DECLARE_IAMIN(std::complex<float>, std::int64_t)
626
+ ONEMKL_DECLARE_IAMIN(std::complex<float>, std::int32_t)
627
+ ONEMKL_DECLARE_IAMIN(std::complex<double>, std::int64_t)
628
+ ONEMKL_DECLARE_IAMIN(std::complex<double>, std::int32_t)
579
629
 
580
630
  #undef ONEMKL_DECLARE_IAMIN
581
631
 
@@ -666,7 +716,7 @@ ONEMKL_DECLARE_ROT(std::complex<double>, double, std::complex<double>)
666
716
  #undef ONEMKL_DECLARE_ROT
667
717
 
668
718
  #define ONEMKL_DECLARE_ROT_EXPLICIT_SCALARS(T, Tc, Ts) \
669
- DLL_EXPORT sycl::event rot(sycl::queue &queue, std::int64_t n, T *x, std::int64_t incx, T *y, std::int64_t incy, Tc c, Ts s, const std::vector<sycl::event> &dependencies = {});
719
+ DLL_EXPORT sycl::event rot(sycl::queue &queue, std::int64_t n, T *x, std::int64_t incx, T *y, std::int64_t incy, value_or_pointer<Tc> c, Ts s, const std::vector<sycl::event> &dependencies = {});
670
720
 
671
721
  ONEMKL_DECLARE_ROT_EXPLICIT_SCALARS(std::complex<float>, float, float)
672
722
  ONEMKL_DECLARE_ROT_EXPLICIT_SCALARS(std::complex<float>, float, std::complex<float>)
@@ -731,7 +781,8 @@ ONEMKL_DECLARE_SWAP(std::complex<double>)
731
781
 
732
782
  #define ONEMKL_DECLARE_GEMM_BATCH(Ta, Tb, Tc, Ts) \
733
783
  ONEMKL_DECLARE_GEMM_BATCH_STRIDED(Ta, Tb, Tc, Ts) \
734
- ONEMKL_DECLARE_GEMM_BATCH_GROUP(Ta, Tb, Tc, Ts) \
784
+ ONEMKL_DECLARE_GEMM_BATCH_GROUP(Ta, Tb, Tc, Ts, std::int64_t) \
785
+ ONEMKL_DECLARE_GEMM_BATCH_GROUP(Ta, Tb, Tc, Ts, std::int32_t) \
735
786
  ONEMKL_DECLARE_GEMM_BATCH_SPAN(Ta, Tb, Tc, Ts)
736
787
 
737
788
  #define ONEMKL_DECLARE_GEMM_BATCH_STRIDED(Ta, Tb, Tc, Ts) \
@@ -753,20 +804,20 @@ ONEMKL_INLINE_DECLARE sycl::event gemm_batch(sycl::queue &queue, transpose trans
753
804
  return gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, MKL_BLAS_COMPUTE_MODE, dependencies); \
754
805
  }
755
806
 
756
- #define ONEMKL_DECLARE_GEMM_BATCH_GROUP(Ta, Tb, Tc, Ts) \
807
+ #define ONEMKL_DECLARE_GEMM_BATCH_GROUP(Ta, Tb, Tc, Ts, Ti) \
757
808
  DLL_EXPORT sycl::event gemm_batch(sycl::queue &queue, const transpose *transa, const transpose *transb, \
758
- const std::int64_t *m, const std::int64_t *n, const std::int64_t *k, const Ts *alpha, \
759
- const Ta **a, const std::int64_t *lda, \
760
- const Tb **b, const std::int64_t *ldb, \
761
- const Ts *beta, Tc **c, const std::int64_t *ldc, \
762
- std::int64_t group_count, const std::int64_t *groupsize, \
809
+ const Ti *m, const Ti *n, const Ti *k, const Ts *alpha, \
810
+ const Ta **a, const Ti *lda, \
811
+ const Tb **b, const Ti *ldb, \
812
+ const Ts *beta, Tc **c, const Ti *ldc, \
813
+ std::int64_t group_count, const Ti *groupsize, \
763
814
  compute_mode mode, const std::vector<sycl::event> &dependencies = {}); \
764
815
  ONEMKL_INLINE_DECLARE sycl::event gemm_batch(sycl::queue &queue, const transpose *transa, const transpose *transb, \
765
- const std::int64_t *m, const std::int64_t *n, const std::int64_t *k, const Ts *alpha, \
766
- const Ta **a, const std::int64_t *lda, \
767
- const Tb **b, const std::int64_t *ldb, \
768
- const Ts *beta, Tc **c, const std::int64_t *ldc, \
769
- std::int64_t group_count, const std::int64_t *groupsize, \
816
+ const Ti *m, const Ti *n, const Ti *k, const Ts *alpha, \
817
+ const Ta **a, const Ti *lda, \
818
+ const Tb **b, const Ti *ldb, \
819
+ const Ts *beta, Tc **c, const Ti *ldc, \
820
+ std::int64_t group_count, const Ti *groupsize, \
770
821
  const std::vector<sycl::event> &dependencies = {}) \
771
822
  { \
772
823
  return gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, groupsize, MKL_BLAS_COMPUTE_MODE, dependencies); \
@@ -815,7 +866,8 @@ ONEMKL_DECLARE_GEMM_BATCH(std::int8_t, std::int8_t, float, float)
815
866
 
816
867
  #define ONEMKL_DECLARE_SYRK_BATCH(T) \
817
868
  ONEMKL_DECLARE_SYRK_BATCH_STRIDED(T) \
818
- ONEMKL_DECLARE_SYRK_BATCH_GROUP(T)
869
+ ONEMKL_DECLARE_SYRK_BATCH_GROUP(T, std::int64_t) \
870
+ ONEMKL_DECLARE_SYRK_BATCH_GROUP(T, std::int32_t)
819
871
 
820
872
  #define ONEMKL_DECLARE_SYRK_BATCH_STRIDED(T) \
821
873
  DLL_EXPORT sycl::event syrk_batch(sycl::queue &queue, \
@@ -833,18 +885,18 @@ ONEMKL_INLINE_DECLARE sycl::event syrk_batch(sycl::queue &queue, \
833
885
  return syrk_batch(queue, upper_lower, trans, n, k, alpha, a, lda, stride_a, beta, c, ldc, stride_c, batch_size, MKL_BLAS_COMPUTE_MODE, dependencies); \
834
886
  }
835
887
 
836
- #define ONEMKL_DECLARE_SYRK_BATCH_GROUP(T) \
888
+ #define ONEMKL_DECLARE_SYRK_BATCH_GROUP(T, Ti) \
837
889
  DLL_EXPORT sycl::event syrk_batch(sycl::queue &queue, \
838
- const uplo *upper_lower, const transpose *trans, const std::int64_t *n, const std::int64_t *k, \
839
- const T *alpha, const T **a, const std::int64_t *lda, const T *beta, \
840
- T **c, const std::int64_t *ldc, \
841
- std::int64_t group_count, const std::int64_t *groupsize, \
842
- compute_mode mode, const std::vector<sycl::event> &dependencies = {}); \
890
+ const uplo *upper_lower, const transpose *trans, const Ti *n, const Ti *k, \
891
+ const T *alpha, const T **a, const Ti *lda, const T *beta, \
892
+ T **c, const Ti *ldc, \
893
+ std::int64_t group_count, const Ti *groupsize, \
894
+ compute_mode mode, const std::vector<sycl::event> &dependencies = {}); \
843
895
  ONEMKL_INLINE_DECLARE sycl::event syrk_batch(sycl::queue &queue, \
844
- const uplo *upper_lower, const transpose *trans, const std::int64_t *n, const std::int64_t *k, \
845
- const T *alpha, const T **a, const std::int64_t *lda, const T *beta, \
846
- T **c, const std::int64_t *ldc, \
847
- std::int64_t group_count, const std::int64_t *groupsize, \
896
+ const uplo *upper_lower, const transpose *trans, const Ti *n, const Ti *k, \
897
+ const T *alpha, const T **a, const Ti *lda, const T *beta, \
898
+ T **c, const Ti *ldc, \
899
+ std::int64_t group_count, const Ti *groupsize, \
848
900
  const std::vector<sycl::event> &dependencies = {}) \
849
901
  { \
850
902
  return syrk_batch(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc, group_count, groupsize, MKL_BLAS_COMPUTE_MODE, dependencies); \
@@ -859,7 +911,8 @@ ONEMKL_DECLARE_SYRK_BATCH(std::complex<double>)
859
911
 
860
912
  #define ONEMKL_DECLARE_TRSM_BATCH(T) \
861
913
  ONEMKL_DECLARE_TRSM_BATCH_STRIDED(T) \
862
- ONEMKL_DECLARE_TRSM_BATCH_GROUP(T)
914
+ ONEMKL_DECLARE_TRSM_BATCH_GROUP(T, std::int64_t) \
915
+ ONEMKL_DECLARE_TRSM_BATCH_GROUP(T, std::int32_t)
863
916
 
864
917
  #define ONEMKL_DECLARE_TRSM_BATCH_STRIDED(T) \
865
918
  DLL_EXPORT sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, \
@@ -882,20 +935,20 @@ ONEMKL_INLINE_DECLARE sycl::event trsm_batch(sycl::queue &queue, side left_right
882
935
  return trsm_batch(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, MKL_BLAS_COMPUTE_MODE, dependencies); \
883
936
  }
884
937
 
885
- #define ONEMKL_DECLARE_TRSM_BATCH_GROUP(T) \
938
+ #define ONEMKL_DECLARE_TRSM_BATCH_GROUP(T, Ti) \
886
939
  DLL_EXPORT sycl::event trsm_batch(sycl::queue &queue, const side *left_right, const uplo *upper_lower, \
887
940
  const transpose *trans, const diag *unit_diag, \
888
- const std::int64_t *m, const std::int64_t *n, \
889
- const T *alpha, const T **a, const std::int64_t *lda, \
890
- T **b, const std::int64_t *ldb, \
891
- std::int64_t group_count, const std::int64_t *group_size, \
941
+ const Ti *m, const Ti *n, \
942
+ const T *alpha, const T **a, const Ti *lda, \
943
+ T **b, const Ti *ldb, \
944
+ std::int64_t group_count, const Ti *group_size, \
892
945
  compute_mode mode, const std::vector<sycl::event> &dependencies = {}); \
893
946
  ONEMKL_INLINE_DECLARE sycl::event trsm_batch(sycl::queue &queue, const side *left_right, const uplo *upper_lower, \
894
947
  const transpose *trans, const diag *unit_diag, \
895
- const std::int64_t *m, const std::int64_t *n, \
896
- const T *alpha, const T **a, const std::int64_t *lda, \
897
- T **b, const std::int64_t *ldb, \
898
- std::int64_t group_count, const std::int64_t *group_size, \
948
+ const Ti *m, const Ti *n, \
949
+ const T *alpha, const T **a, const Ti *lda, \
950
+ T **b, const Ti *ldb, \
951
+ std::int64_t group_count, const Ti *group_size, \
899
952
  const std::vector<sycl::event> &dependencies = {}) \
900
953
  { \
901
954
  return trsm_batch(queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, ldb, group_count, group_size, MKL_BLAS_COMPUTE_MODE, dependencies); \
@@ -910,7 +963,8 @@ ONEMKL_DECLARE_TRSM_BATCH(std::complex<double>)
910
963
 
911
964
  #define ONEMKL_DECLARE_DGMM_BATCH(T) \
912
965
  ONEMKL_DECLARE_DGMM_BATCH_STRIDED(T) \
913
- ONEMKL_DECLARE_DGMM_BATCH_GROUP(T)
966
+ ONEMKL_DECLARE_DGMM_BATCH_GROUP(T, std::int64_t) \
967
+ ONEMKL_DECLARE_DGMM_BATCH_GROUP(T, std::int32_t)
914
968
 
915
969
  #define ONEMKL_DECLARE_DGMM_BATCH_STRIDED(T) \
916
970
  DLL_EXPORT sycl::event dgmm_batch(sycl::queue &queue, side left_right, \
@@ -920,30 +974,28 @@ DLL_EXPORT sycl::event dgmm_batch(sycl::queue &queue, side left_right, \
920
974
  T *c, std::int64_t ldc, std::int64_t stridec, std::int64_t batch_size, \
921
975
  const std::vector<sycl::event> &dependencies = {});
922
976
 
923
- ONEMKL_DECLARE_DGMM_BATCH_STRIDED(float)
924
- ONEMKL_DECLARE_DGMM_BATCH_STRIDED(double)
925
- ONEMKL_DECLARE_DGMM_BATCH_STRIDED(std::complex<float>)
926
- ONEMKL_DECLARE_DGMM_BATCH_STRIDED(std::complex<double>)
927
-
928
- #define ONEMKL_DECLARE_DGMM_BATCH_GROUP(T) \
977
+ #define ONEMKL_DECLARE_DGMM_BATCH_GROUP(Tf, Ti) \
929
978
  DLL_EXPORT sycl::event dgmm_batch(sycl::queue &queue, const side *left_right, \
930
- const std::int64_t *m, const std::int64_t *n, \
931
- const T **a, const std::int64_t *lda, \
932
- const T **x, const std::int64_t *incx, \
933
- T **c, const std::int64_t *ldc, \
934
- std::int64_t group_count, const std::int64_t *group_size, \
979
+ const Ti *m, const Ti *n, \
980
+ const Tf **a, const Ti *lda, \
981
+ const Tf **x, const Ti *incx, \
982
+ Tf **c, const Ti *ldc, \
983
+ std::int64_t group_count, const Ti *group_size, \
935
984
  const std::vector<sycl::event> &dependencies = {});
936
985
 
937
- ONEMKL_DECLARE_DGMM_BATCH_GROUP(float)
938
- ONEMKL_DECLARE_DGMM_BATCH_GROUP(double)
939
- ONEMKL_DECLARE_DGMM_BATCH_GROUP(std::complex<float>)
940
- ONEMKL_DECLARE_DGMM_BATCH_GROUP(std::complex<double>)
986
+ ONEMKL_DECLARE_DGMM_BATCH(float)
987
+ ONEMKL_DECLARE_DGMM_BATCH(double)
988
+ ONEMKL_DECLARE_DGMM_BATCH(std::complex<float>)
989
+ ONEMKL_DECLARE_DGMM_BATCH(std::complex<double>)
941
990
 
991
+ #undef ONEMKL_DECLARE_DGMM_BATCH_STRIDED
992
+ #undef ONEMKL_DECLARE_DGMM_BATCH_GROUP
942
993
  #undef ONEMKL_DECLARE_DGMM_BATCH
943
994
 
944
995
  #define ONEMKL_DECLARE_GEMV_BATCH(T) \
945
996
  ONEMKL_DECLARE_GEMV_BATCH_STRIDED(T) \
946
- ONEMKL_DECLARE_GEMV_BATCH_GROUP(T)
997
+ ONEMKL_DECLARE_GEMV_BATCH_GROUP(T, std::int64_t) \
998
+ ONEMKL_DECLARE_GEMV_BATCH_GROUP(T, std::int32_t)
947
999
 
948
1000
  #define ONEMKL_DECLARE_GEMV_BATCH_STRIDED(T) \
949
1001
  DLL_EXPORT sycl::event gemv_batch(sycl::queue &queue, transpose trans, \
@@ -953,30 +1005,28 @@ DLL_EXPORT sycl::event gemv_batch(sycl::queue &queue, transpose trans, \
953
1005
  T *y, std::int64_t incy, std::int64_t stridey, std::int64_t batch_size, \
954
1006
  const std::vector<sycl::event> &dependencies = {});
955
1007
 
956
- ONEMKL_DECLARE_GEMV_BATCH_STRIDED(float)
957
- ONEMKL_DECLARE_GEMV_BATCH_STRIDED(double)
958
- ONEMKL_DECLARE_GEMV_BATCH_STRIDED(std::complex<float>)
959
- ONEMKL_DECLARE_GEMV_BATCH_STRIDED(std::complex<double>)
960
-
961
- #define ONEMKL_DECLARE_GEMV_BATCH_GROUP(T) \
1008
+ #define ONEMKL_DECLARE_GEMV_BATCH_GROUP(Tf, Ti) \
962
1009
  DLL_EXPORT sycl::event gemv_batch(sycl::queue &queue, const transpose *trans, \
963
- const std::int64_t *m, const std::int64_t *n, const T *alpha, \
964
- const T **a, const std::int64_t *lda, \
965
- const T **x, const std::int64_t *incx, const T *beta, \
966
- T **y, const std::int64_t *incy, \
967
- std::int64_t group_count, const std::int64_t *group_size, \
1010
+ const Ti *m, const Ti *n, const Tf *alpha, \
1011
+ const Tf **a, const Ti *lda, \
1012
+ const Tf **x, const Ti *incx, const Tf *beta, \
1013
+ Tf **y, const Ti *incy, \
1014
+ std::int64_t group_count, const Ti *group_size, \
968
1015
  const std::vector<sycl::event> &dependencies = {});
969
1016
 
970
- ONEMKL_DECLARE_GEMV_BATCH_GROUP(float)
971
- ONEMKL_DECLARE_GEMV_BATCH_GROUP(double)
972
- ONEMKL_DECLARE_GEMV_BATCH_GROUP(std::complex<float>)
973
- ONEMKL_DECLARE_GEMV_BATCH_GROUP(std::complex<double>)
1017
+ ONEMKL_DECLARE_GEMV_BATCH(float)
1018
+ ONEMKL_DECLARE_GEMV_BATCH(double)
1019
+ ONEMKL_DECLARE_GEMV_BATCH(std::complex<float>)
1020
+ ONEMKL_DECLARE_GEMV_BATCH(std::complex<double>)
974
1021
 
1022
+ #undef ONEMKL_DECLARE_GEMV_BATCH_STRIDED
1023
+ #undef ONEMKL_DECLARE_GEMV_BATCH_GROUP
975
1024
  #undef ONEMKL_DECLARE_GEMV_BATCH
976
1025
 
977
- #define ONEMKL_DECLARE_AXPY_BATCH(T) \
978
- ONEMKL_DECLARE_AXPY_BATCH_STRIDED(T) \
979
- ONEMKL_DECLARE_AXPY_BATCH_GROUP(T)
1026
+ #define ONEMKL_DECLARE_AXPY_BATCH(T) \
1027
+ ONEMKL_DECLARE_AXPY_BATCH_STRIDED(T) \
1028
+ ONEMKL_DECLARE_AXPY_BATCH_GROUP(T, std::int64_t) \
1029
+ ONEMKL_DECLARE_AXPY_BATCH_GROUP(T, std::int32_t) \
980
1030
 
981
1031
  #define ONEMKL_DECLARE_AXPY_BATCH_STRIDED(T) \
982
1032
  DLL_EXPORT sycl::event axpy_batch(sycl::queue &queue, std::int64_t n, value_or_pointer<T> alpha, \
@@ -984,27 +1034,25 @@ DLL_EXPORT sycl::event axpy_batch(sycl::queue &queue, std::int64_t n, value_or_p
984
1034
  T *y, std::int64_t incy, std::int64_t stridey, std::int64_t batch_size, \
985
1035
  const std::vector<sycl::event> &dependencies = {});
986
1036
 
987
- ONEMKL_DECLARE_AXPY_BATCH_STRIDED(float)
988
- ONEMKL_DECLARE_AXPY_BATCH_STRIDED(double)
989
- ONEMKL_DECLARE_AXPY_BATCH_STRIDED(std::complex<float>)
990
- ONEMKL_DECLARE_AXPY_BATCH_STRIDED(std::complex<double>)
991
-
992
- #define ONEMKL_DECLARE_AXPY_BATCH_GROUP(T) \
993
- DLL_EXPORT sycl::event axpy_batch(sycl::queue &queue, const std::int64_t *n, const T *alpha, const T **x, \
994
- const std::int64_t *incx, T **y, const std::int64_t *incy, std::int64_t group_count, \
995
- const std::int64_t *group_size, \
1037
+ #define ONEMKL_DECLARE_AXPY_BATCH_GROUP(Tf, Ti) \
1038
+ DLL_EXPORT sycl::event axpy_batch(sycl::queue &queue, const Ti *n, const Tf *alpha, const Tf **x, \
1039
+ const Ti *incx, Tf **y, const Ti *incy, std::int64_t group_count, \
1040
+ const Ti *group_size, \
996
1041
  const std::vector<sycl::event> &dependencies = {});
997
1042
 
998
- ONEMKL_DECLARE_AXPY_BATCH_GROUP(float)
999
- ONEMKL_DECLARE_AXPY_BATCH_GROUP(double)
1000
- ONEMKL_DECLARE_AXPY_BATCH_GROUP(std::complex<float>)
1001
- ONEMKL_DECLARE_AXPY_BATCH_GROUP(std::complex<double>)
1043
+ ONEMKL_DECLARE_AXPY_BATCH(float)
1044
+ ONEMKL_DECLARE_AXPY_BATCH(double)
1045
+ ONEMKL_DECLARE_AXPY_BATCH(std::complex<float>)
1046
+ ONEMKL_DECLARE_AXPY_BATCH(std::complex<double>)
1002
1047
 
1048
+ #undef ONEMKL_DECLARE_AXPY_BATCH_STRIDED
1049
+ #undef ONEMKL_DECLARE_AXPY_BATCH_GROUP
1003
1050
  #undef ONEMKL_DECLARE_AXPY_BATCH
1004
1051
 
1005
1052
  #define ONEMKL_DECLARE_COPY_BATCH(T) \
1006
1053
  ONEMKL_DECLARE_COPY_BATCH_STRIDED(T) \
1007
- ONEMKL_DECLARE_COPY_BATCH_GROUP(T)
1054
+ ONEMKL_DECLARE_COPY_BATCH_GROUP(T, std::int64_t) \
1055
+ ONEMKL_DECLARE_COPY_BATCH_GROUP(T, std::int32_t)
1008
1056
 
1009
1057
  #define ONEMKL_DECLARE_COPY_BATCH_STRIDED(T) \
1010
1058
  DLL_EXPORT sycl::event copy_batch(sycl::queue &queue, std::int64_t n, \
@@ -1013,23 +1061,19 @@ DLL_EXPORT sycl::event copy_batch(sycl::queue &queue, std::int64_t n, \
1013
1061
  std::int64_t batch_size, \
1014
1062
  const std::vector<sycl::event> &dependencies = {});
1015
1063
 
1016
- ONEMKL_DECLARE_COPY_BATCH_STRIDED(float)
1017
- ONEMKL_DECLARE_COPY_BATCH_STRIDED(double)
1018
- ONEMKL_DECLARE_COPY_BATCH_STRIDED(std::complex<float>)
1019
- ONEMKL_DECLARE_COPY_BATCH_STRIDED(std::complex<double>)
1064
+ #define ONEMKL_DECLARE_COPY_BATCH_GROUP(Tf, Ti) \
1065
+ DLL_EXPORT sycl::event copy_batch(sycl::queue &queue, const Ti *n, \
1066
+ const Tf **x, const Ti *incx, Tf **y, const Ti *incy, \
1067
+ std::int64_t group_count, const Ti *group_size, \
1068
+ const std::vector<sycl::event> &dependencies = {});
1020
1069
 
1021
- #define ONEMKL_DECLARE_COPY_BATCH_GROUP(T) \
1022
- DLL_EXPORT sycl::event copy_batch(sycl::queue &queue, const std::int64_t *n, \
1023
- const T **x, const std::int64_t *incx, \
1024
- T **y, const std::int64_t *incy, \
1025
- std::int64_t group_count, const std::int64_t *group_size, \
1026
- const std::vector<sycl::event> &dependencies = {});
1027
-
1028
- ONEMKL_DECLARE_COPY_BATCH_GROUP(float)
1029
- ONEMKL_DECLARE_COPY_BATCH_GROUP(double)
1030
- ONEMKL_DECLARE_COPY_BATCH_GROUP(std::complex<float>)
1031
- ONEMKL_DECLARE_COPY_BATCH_GROUP(std::complex<double>)
1070
+ ONEMKL_DECLARE_COPY_BATCH(float)
1071
+ ONEMKL_DECLARE_COPY_BATCH(double)
1072
+ ONEMKL_DECLARE_COPY_BATCH(std::complex<float>)
1073
+ ONEMKL_DECLARE_COPY_BATCH(std::complex<double>)
1032
1074
 
1075
+ #undef ONEMKL_DECLARE_COPY_BATCH_STRIDED
1076
+ #undef ONEMKL_DECLARE_COPY_BATCH_GROUP
1033
1077
  #undef ONEMKL_DECLARE_COPY_BATCH
1034
1078
 
1035
1079
  // BLAS like
@@ -1121,7 +1165,8 @@ ONEMKL_DECLARE_OMATADD(std::complex<double>)
1121
1165
 
1122
1166
  #define ONEMKL_DECLARE_IMATCOPY_BATCH(T) \
1123
1167
  ONEMKL_DECLARE_IMATCOPY_BATCH_STRIDED(T) \
1124
- ONEMKL_DECLARE_IMATCOPY_BATCH_GROUP(T)
1168
+ ONEMKL_DECLARE_IMATCOPY_BATCH_GROUP(T, std::int64_t) \
1169
+ ONEMKL_DECLARE_IMATCOPY_BATCH_GROUP(T, std::int32_t)
1125
1170
 
1126
1171
  #define ONEMKL_DECLARE_IMATCOPY_BATCH_STRIDED(T) \
1127
1172
  DLL_EXPORT sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, \
@@ -1129,28 +1174,26 @@ DLL_EXPORT sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, \
1129
1174
  std::int64_t ldb, std::int64_t stride, std::int64_t batch_size, \
1130
1175
  const std::vector<sycl::event> &dependencies = {});
1131
1176
 
1132
- ONEMKL_DECLARE_IMATCOPY_BATCH_STRIDED(float)
1133
- ONEMKL_DECLARE_IMATCOPY_BATCH_STRIDED(double)
1134
- ONEMKL_DECLARE_IMATCOPY_BATCH_STRIDED(std::complex<float>)
1135
- ONEMKL_DECLARE_IMATCOPY_BATCH_STRIDED(std::complex<double>)
1136
-
1137
- #define ONEMKL_DECLARE_IMATCOPY_BATCH_GROUP(T) \
1177
+ #define ONEMKL_DECLARE_IMATCOPY_BATCH_GROUP(T, Ti) \
1138
1178
  DLL_EXPORT sycl::event imatcopy_batch(sycl::queue &queue, const transpose *trans, \
1139
- const std::int64_t *m, const std::int64_t *n, const T *alpha, T **ab, \
1140
- const std::int64_t *lda, const std::int64_t *ldb, std::int64_t group_count, \
1141
- const std::int64_t *groupsize, \
1142
- const std::vector<sycl::event> &dependencies = {});
1179
+ const Ti *m, const Ti *n, const T *alpha, T **ab, \
1180
+ const Ti *lda, const Ti *ldb, std::int64_t group_count, \
1181
+ const Ti *groupsize, \
1182
+ const std::vector<sycl::event> &dependencies = {});
1143
1183
 
1144
- ONEMKL_DECLARE_IMATCOPY_BATCH_GROUP(float)
1145
- ONEMKL_DECLARE_IMATCOPY_BATCH_GROUP(double)
1146
- ONEMKL_DECLARE_IMATCOPY_BATCH_GROUP(std::complex<float>)
1147
- ONEMKL_DECLARE_IMATCOPY_BATCH_GROUP(std::complex<double>)
1184
+ ONEMKL_DECLARE_IMATCOPY_BATCH(float)
1185
+ ONEMKL_DECLARE_IMATCOPY_BATCH(double)
1186
+ ONEMKL_DECLARE_IMATCOPY_BATCH(std::complex<float>)
1187
+ ONEMKL_DECLARE_IMATCOPY_BATCH(std::complex<double>)
1148
1188
 
1189
+ #undef ONEMKL_DECLARE_IMATCOPY_BATCH_GROUP
1190
+ #undef ONEMKL_DECLARE_IMATCOPY_BATCH_STRIDED
1149
1191
  #undef ONEMKL_DECLARE_IMATCOPY_BATCH
1150
1192
 
1151
1193
  #define ONEMKL_DECLARE_OMATCOPY_BATCH(T) \
1152
1194
  ONEMKL_DECLARE_OMATCOPY_BATCH_STRIDED(T) \
1153
- ONEMKL_DECLARE_OMATCOPY_BATCH_GROUP(T)
1195
+ ONEMKL_DECLARE_OMATCOPY_BATCH_GROUP(T, std::int64_t) \
1196
+ ONEMKL_DECLARE_OMATCOPY_BATCH_GROUP(T, std::int32_t)
1154
1197
 
1155
1198
  #define ONEMKL_DECLARE_OMATCOPY_BATCH_STRIDED(T) \
1156
1199
  DLL_EXPORT sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, \
@@ -1160,23 +1203,20 @@ DLL_EXPORT sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, \
1160
1203
  std::int64_t batch_size, \
1161
1204
  const std::vector<sycl::event> &dependencies = {});
1162
1205
 
1163
- ONEMKL_DECLARE_OMATCOPY_BATCH_STRIDED(float)
1164
- ONEMKL_DECLARE_OMATCOPY_BATCH_STRIDED(double)
1165
- ONEMKL_DECLARE_OMATCOPY_BATCH_STRIDED(std::complex<float>)
1166
- ONEMKL_DECLARE_OMATCOPY_BATCH_STRIDED(std::complex<double>)
1167
-
1168
- #define ONEMKL_DECLARE_OMATCOPY_BATCH_GROUP(T) \
1206
+ #define ONEMKL_DECLARE_OMATCOPY_BATCH_GROUP(T, Ti) \
1169
1207
  DLL_EXPORT sycl::event omatcopy_batch(sycl::queue &queue, const transpose *trans, \
1170
- const std::int64_t *m, const std::int64_t *n, const T *alpha, const T **a, \
1171
- const std::int64_t *lda, T **b, const std::int64_t *ldb, std::int64_t group_count, \
1172
- const std::int64_t *groupsize, \
1173
- const std::vector<sycl::event> &dependencies = {});
1208
+ const Ti *m, const Ti *n, const T *alpha, const T **a, \
1209
+ const Ti *lda, T **b, const Ti *ldb, std::int64_t group_count, \
1210
+ const Ti *groupsize, \
1211
+ const std::vector<sycl::event> &dependencies = {});
1174
1212
 
1175
- ONEMKL_DECLARE_OMATCOPY_BATCH_GROUP(float)
1176
- ONEMKL_DECLARE_OMATCOPY_BATCH_GROUP(double)
1177
- ONEMKL_DECLARE_OMATCOPY_BATCH_GROUP(std::complex<float>)
1178
- ONEMKL_DECLARE_OMATCOPY_BATCH_GROUP(std::complex<double>)
1213
+ ONEMKL_DECLARE_OMATCOPY_BATCH(float)
1214
+ ONEMKL_DECLARE_OMATCOPY_BATCH(double)
1215
+ ONEMKL_DECLARE_OMATCOPY_BATCH(std::complex<float>)
1216
+ ONEMKL_DECLARE_OMATCOPY_BATCH(std::complex<double>)
1179
1217
 
1218
+ #undef ONEMKL_DECLARE_OMATCOPY_BATCH_GROUP
1219
+ #undef ONEMKL_DECLARE_OMATCOPY_BATCH_STRIDED
1180
1220
  #undef ONEMKL_DECLARE_OMATCOPY_BATCH
1181
1221
 
1182
1222
  #define ONEMKL_DECLARE_OMATADD_BATCH(T) \
@@ -62,7 +62,9 @@ enum class config_param {
62
62
  THREAD_LIMIT = DFTI_THREAD_LIMIT,
63
63
  DESTROY_INPUT = DFTI_DESTROY_INPUT,
64
64
  WORKSPACE_ESTIMATE_BYTES,
65
- WORKSPACE_BYTES
65
+ WORKSPACE_BYTES,
66
+ FWD_STRIDES,
67
+ BWD_STRIDES
66
68
  };
67
69
 
68
70
  enum class config_value {