@fugood/llama.node 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. package/CMakeLists.txt +1 -10
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +6 -4
  17. package/src/LlamaCompletionWorker.cpp +6 -6
  18. package/src/LlamaContext.cpp +7 -9
  19. package/src/common.hpp +2 -1
  20. package/src/llama.cpp/.github/workflows/build.yml +98 -24
  21. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  22. package/src/llama.cpp/.github/workflows/docker.yml +43 -34
  23. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  24. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  25. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  26. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  27. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  28. package/src/llama.cpp/CMakeLists.txt +20 -8
  29. package/src/llama.cpp/common/CMakeLists.txt +12 -10
  30. package/src/llama.cpp/common/arg.cpp +2006 -0
  31. package/src/llama.cpp/common/arg.h +77 -0
  32. package/src/llama.cpp/common/common.cpp +496 -1632
  33. package/src/llama.cpp/common/common.h +161 -63
  34. package/src/llama.cpp/common/console.cpp +3 -0
  35. package/src/llama.cpp/common/log.cpp +401 -0
  36. package/src/llama.cpp/common/log.h +66 -698
  37. package/src/llama.cpp/common/ngram-cache.cpp +3 -0
  38. package/src/llama.cpp/common/sampling.cpp +348 -350
  39. package/src/llama.cpp/common/sampling.h +62 -139
  40. package/src/llama.cpp/common/stb_image.h +5990 -6398
  41. package/src/llama.cpp/common/train.cpp +2 -0
  42. package/src/llama.cpp/docs/build.md +36 -1
  43. package/src/llama.cpp/examples/CMakeLists.txt +0 -1
  44. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
  45. package/src/llama.cpp/examples/batched/batched.cpp +39 -55
  46. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
  47. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  48. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
  49. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  50. package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
  51. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
  52. package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
  53. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  54. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
  55. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  57. package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
  58. package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
  59. package/src/llama.cpp/examples/infill/infill.cpp +117 -132
  60. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
  61. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
  62. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  63. package/src/llama.cpp/examples/llava/clip.cpp +685 -150
  64. package/src/llama.cpp/examples/llava/clip.h +11 -2
  65. package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
  66. package/src/llama.cpp/examples/llava/llava.cpp +110 -24
  67. package/src/llama.cpp/examples/llava/llava.h +2 -3
  68. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  69. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  70. package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
  71. package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
  72. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
  73. package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
  74. package/src/llama.cpp/examples/main/main.cpp +210 -262
  75. package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
  76. package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
  77. package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
  78. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  80. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
  81. package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
  82. package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
  83. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
  84. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
  85. package/src/llama.cpp/examples/server/server.cpp +1027 -1073
  86. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  87. package/src/llama.cpp/examples/server/utils.hpp +107 -105
  88. package/src/llama.cpp/examples/simple/simple.cpp +35 -41
  89. package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
  90. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  91. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  92. package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
  93. package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
  94. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  95. package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
  96. package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
  97. package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
  98. package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
  99. package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
  100. package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
  101. package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
  102. package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
  103. package/src/llama.cpp/ggml/include/ggml.h +293 -186
  104. package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
  105. package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
  106. package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
  107. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
  108. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
  109. package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
  110. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  111. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  112. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  113. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  114. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  115. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  116. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  117. package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
  118. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
  120. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  121. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  122. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  123. package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
  124. package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
  125. package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
  126. package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
  127. package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
  128. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  129. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
  130. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
  131. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  132. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  133. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  134. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  135. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  136. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  137. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
  138. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  141. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
  142. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
  143. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
  144. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  145. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  146. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
  148. package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
  149. package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
  150. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
  151. package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
  152. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
  153. package/src/llama.cpp/include/llama.h +241 -264
  154. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  155. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  156. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  157. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  158. package/src/llama.cpp/src/llama-grammar.h +120 -15
  159. package/src/llama.cpp/src/llama-impl.h +156 -1
  160. package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
  161. package/src/llama.cpp/src/llama-sampling.h +20 -47
  162. package/src/llama.cpp/src/llama-vocab.cpp +343 -120
  163. package/src/llama.cpp/src/llama-vocab.h +33 -17
  164. package/src/llama.cpp/src/llama.cpp +4247 -1525
  165. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  166. package/src/llama.cpp/src/unicode-data.h +4 -4
  167. package/src/llama.cpp/src/unicode.cpp +15 -7
  168. package/src/llama.cpp/tests/CMakeLists.txt +3 -0
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
  171. package/src/llama.cpp/tests/test-barrier.cpp +93 -0
  172. package/src/llama.cpp/tests/test-grad0.cpp +187 -70
  173. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  174. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  175. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
  176. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  177. package/src/llama.cpp/tests/test-log.cpp +39 -0
  178. package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
  179. package/src/llama.cpp/tests/test-rope.cpp +1 -1
  180. package/src/llama.cpp/tests/test-sampling.cpp +157 -98
  181. package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
  182. package/patches/llama.patch +0 -22
  183. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  184. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  185. package/src/llama.cpp/common/grammar-parser.h +0 -29
  186. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  187. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
@@ -15,9 +15,9 @@
15
15
 
16
16
  #include "common.hpp"
17
17
 
18
- typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
18
+ typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
19
19
 
20
- static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib,
20
+ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
21
21
  const int iqs, dfloat2 &v) {
22
22
  const block_q4_0 * x = (const block_q4_0 *) vx;
23
23
 
@@ -40,7 +40,7 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib,
40
40
  #endif // GGML_SYCL_F16
41
41
  }
42
42
 
43
- static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib,
43
+ static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib,
44
44
  const int iqs, dfloat2 &v) {
45
45
  const block_q4_1 * x = (const block_q4_1 *) vx;
46
46
 
@@ -55,16 +55,16 @@ static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib,
55
55
  #ifdef GGML_SYCL_F16
56
56
  // v = v * {d, d};
57
57
  // v = v + {m, m};
58
- v.s0() = (v.s0() * d) + m;
59
- v.s1() = (v.s1() * d) + m;
58
+ v.s0() = sycl::fma(v.s0(), d, m);
59
+ v.s1() = sycl::fma(v.s1(), d, m);
60
60
 
61
61
  #else
62
- v.x() = (v.x() * d) + m;
63
- v.y() = (v.y() * d) + m;
62
+ v.x() = sycl::fma(v.x(), d, m);
63
+ v.y() = sycl::fma(v.y(), d, m);
64
64
  #endif // GGML_SYCL_F16
65
65
  }
66
66
 
67
- static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib,
67
+ static __dpct_inline__ void dequantize_q5_0(const void *vx, const int64_t ib,
68
68
  const int iqs, dfloat2 &v) {
69
69
  const block_q5_0 * x = (const block_q5_0 *) vx;
70
70
 
@@ -91,7 +91,7 @@ static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib,
91
91
  #endif // GGML_SYCL_F16
92
92
  }
93
93
 
94
- static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib,
94
+ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib,
95
95
  const int iqs, dfloat2 &v) {
96
96
  const block_q5_1 * x = (const block_q5_1 *) vx;
97
97
 
@@ -110,15 +110,15 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib,
110
110
  #ifdef GGML_SYCL_F16
111
111
  // v = v * {d, d};
112
112
  // v = v + {m, m};
113
- v.s0() = (v.s0() * d) + m;
114
- v.s1() = (v.s1() * d) + m;
113
+ v.s0() = sycl::fma(v.s0(), d, m);
114
+ v.s1() = sycl::fma(v.s1(), d, m);
115
115
  #else
116
- v.x() = (v.x() * d) + m;
117
- v.y() = (v.y() * d) + m;
116
+ v.x() = sycl::fma(v.x(), d, m);
117
+ v.y() = sycl::fma(v.y(), d, m);
118
118
  #endif // GGML_SYCL_F16
119
119
  }
120
120
 
121
- static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib,
121
+ static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib,
122
122
  const int iqs, dfloat2 &v) {
123
123
  const block_q8_0 * x = (const block_q8_0 *) vx;
124
124
 
@@ -138,16 +138,16 @@ static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib,
138
138
  }
139
139
 
140
140
  template<typename dst_t>
141
- static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
141
+ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
142
142
  const sycl::nd_item<3> &item_ct1) {
143
143
 
144
- const int i = item_ct1.get_group(2);
144
+ const int64_t i = item_ct1.get_group(2);
145
145
 
146
146
  // assume 32 threads
147
- const int tid = item_ct1.get_local_id(2);
148
- const int il = tid/8;
149
- const int ir = tid%8;
150
- const int ib = 8*i + ir;
147
+ const int64_t tid = item_ct1.get_local_id(2);
148
+ const int64_t il = tid/8;
149
+ const int64_t ir = tid%8;
150
+ const int64_t ib = 8*i + ir;
151
151
  if (ib >= nb32) {
152
152
  return;
153
153
  }
@@ -168,16 +168,16 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restri
168
168
  }
169
169
 
170
170
  template<typename dst_t>
171
- static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
171
+ static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
172
172
  const sycl::nd_item<3> &item_ct1) {
173
173
 
174
- const int i = item_ct1.get_group(2);
174
+ const int64_t i = item_ct1.get_group(2);
175
175
 
176
176
  // assume 32 threads
177
- const int tid = item_ct1.get_local_id(2);
178
- const int il = tid/8;
179
- const int ir = tid%8;
180
- const int ib = 8*i + ir;
177
+ const int64_t tid = item_ct1.get_local_id(2);
178
+ const int64_t il = tid/8;
179
+ const int64_t ir = tid%8;
180
+ const int64_t ib = 8*i + ir;
181
181
  if (ib >= nb32) {
182
182
  return;
183
183
  }
@@ -203,14 +203,14 @@ template<typename dst_t>
203
203
  static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
204
204
  const sycl::nd_item<3> &item_ct1) {
205
205
 
206
- const int i = item_ct1.get_group(2);
206
+ const int64_t i = item_ct1.get_group(2);
207
207
  const block_q2_K * x = (const block_q2_K *) vx;
208
208
 
209
- const int tid = item_ct1.get_local_id(2);
209
+ const int64_t tid = item_ct1.get_local_id(2);
210
210
  #if QK_K == 256
211
- const int n = tid/32;
212
- const int l = tid - 32*n;
213
- const int is = 8*n + l/16;
211
+ const int64_t n = tid/32;
212
+ const int64_t l = tid - 32*n;
213
+ const int64_t is = 8*n + l/16;
214
214
 
215
215
  const uint8_t q = x[i].qs[32*n + l];
216
216
  dst_t * y = yy + i*QK_K + 128*n;
@@ -222,8 +222,8 @@ static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restri
222
222
  y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
223
223
  y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
224
224
  #else
225
- const int is = tid/16; // 0 or 1
226
- const int il = tid%16; // 0...15
225
+ const int64_t is = tid/16; // 0 or 1
226
+ const int64_t il = tid%16; // 0...15
227
227
  const uint8_t q = x[i].qs[il] >> (2*is);
228
228
  dst_t * y = yy + i*QK_K + 16*is + il;
229
229
 
@@ -239,19 +239,19 @@ template<typename dst_t>
239
239
  static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
240
240
  const sycl::nd_item<3> &item_ct1) {
241
241
 
242
- const int i = item_ct1.get_group(2);
242
+ const int64_t i = item_ct1.get_group(2);
243
243
  const block_q3_K * x = (const block_q3_K *) vx;
244
244
 
245
245
  #if QK_K == 256
246
- const int r = item_ct1.get_local_id(2) / 4;
247
- const int tid = r/2;
248
- const int is0 = r%2;
249
- const int l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
250
- const int n = tid / 4;
251
- const int j = tid - 4*n;
246
+ const int64_t r = item_ct1.get_local_id(2) / 4;
247
+ const int64_t tid = r/2;
248
+ const int64_t is0 = r%2;
249
+ const int64_t l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
250
+ const int64_t n = tid / 4;
251
+ const int64_t j = tid - 4*n;
252
252
 
253
253
  uint8_t m = 1 << (4*n + j);
254
- int is = 8*n + 2*j + is0;
254
+ int64_t is = 8*n + 2*j + is0;
255
255
  int shift = 2*j;
256
256
 
257
257
  int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
@@ -267,11 +267,11 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
267
267
 
268
268
  for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
269
269
  #else
270
- const int tid = item_ct1.get_local_id(2);
271
- const int is = tid/16; // 0 or 1
272
- const int il = tid%16; // 0...15
273
- const int im = il/8; // 0...1
274
- const int in = il%8; // 0...7
270
+ const int64_t tid = item_ct1.get_local_id(2);
271
+ const int64_t is = tid/16; // 0 or 1
272
+ const int64_t il = tid%16; // 0...15
273
+ const int64_t im = il/8; // 0...1
274
+ const int64_t in = il%8; // 0...7
275
275
 
276
276
  dst_t * y = yy + i*QK_K + 16*is + il;
277
277
 
@@ -307,15 +307,15 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
307
307
  uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
308
308
  const block_q4_K * x = (const block_q4_K *) vx;
309
309
 
310
- const int i = item_ct1.get_group(2);
310
+ const int64_t i = item_ct1.get_group(2);
311
311
 
312
312
  #if QK_K == 256
313
313
  // assume 32 threads
314
- const int tid = item_ct1.get_local_id(2);
315
- const int il = tid/8;
316
- const int ir = tid%8;
317
- const int is = 2*il;
318
- const int n = 4;
314
+ const int64_t tid = item_ct1.get_local_id(2);
315
+ const int64_t il = tid/8;
316
+ const int64_t ir = tid%8;
317
+ const int64_t is = 2*il;
318
+ const int64_t n = 4;
319
319
 
320
320
  dst_t * y = yy + i*QK_K + 64*il + n*ir;
321
321
 
@@ -341,7 +341,7 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
341
341
  y[l +32] = d2 * (q_vec[l] >> 4) - m2;
342
342
  }
343
343
  #else
344
- const int tid = item_ct1.get_local_id(2);
344
+ const int64_t tid = item_ct1.get_local_id(2);
345
345
  const uint8_t * q = x[i].qs;
346
346
  dst_t * y = yy + i*QK_K;
347
347
  const float d = (float)x[i].dm[0];
@@ -356,14 +356,14 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri
356
356
  const sycl::nd_item<3> &item_ct1) {
357
357
  const block_q5_K * x = (const block_q5_K *) vx;
358
358
 
359
- const int i = item_ct1.get_group(2);
359
+ const int64_t i = item_ct1.get_group(2);
360
360
 
361
361
  #if QK_K == 256
362
362
  // assume 64 threads - this is very slightly better than the one below
363
- const int tid = item_ct1.get_local_id(2);
364
- const int il = tid/16; // il is in 0...3
365
- const int ir = tid%16; // ir is in 0...15
366
- const int is = 2*il; // is is in 0...6
363
+ const int64_t tid = item_ct1.get_local_id(2);
364
+ const int64_t il = tid/16; // il is in 0...3
365
+ const int64_t ir = tid%16; // ir is in 0...15
366
+ const int64_t is = 2*il; // is is in 0...6
367
367
 
368
368
  dst_t * y = yy + i*QK_K + 64*il + 2*ir;
369
369
 
@@ -386,11 +386,11 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri
386
386
  y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
387
387
  y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
388
388
  #else
389
- const int tid = item_ct1.get_local_id(2);
389
+ const int64_t tid = item_ct1.get_local_id(2);
390
390
  const uint8_t q = x[i].qs[tid];
391
- const int im = tid/8; // 0...3
392
- const int in = tid%8; // 0...7
393
- const int is = tid/16; // 0 or 1
391
+ const int64_t im = tid/8; // 0...3
392
+ const int64_t in = tid%8; // 0...7
393
+ const int64_t is = tid/16; // 0 or 1
394
394
  const uint8_t h = x[i].qh[in] >> im;
395
395
  const float d = x[i].d;
396
396
  dst_t * y = yy + i*QK_K + tid;
@@ -404,14 +404,14 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
404
404
  const sycl::nd_item<3> &item_ct1) {
405
405
  const block_q6_K * x = (const block_q6_K *) vx;
406
406
 
407
- const int i = item_ct1.get_group(2);
407
+ const int64_t i = item_ct1.get_group(2);
408
408
  #if QK_K == 256
409
409
 
410
410
  // assume 64 threads - this is very slightly better than the one below
411
- const int tid = item_ct1.get_local_id(2);
412
- const int ip = tid/32; // ip is 0 or 1
413
- const int il = tid - 32*ip; // 0...32
414
- const int is = 8*ip + il/16;
411
+ const int64_t tid = item_ct1.get_local_id(2);
412
+ const int64_t ip = tid/32; // ip is 0 or 1
413
+ const int64_t il = tid - 32*ip; // 0...32
414
+ const int64_t is = 8*ip + il/16;
415
415
 
416
416
  dst_t * y = yy + i*QK_K + 128*ip + il;
417
417
 
@@ -428,9 +428,9 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
428
428
  #else
429
429
 
430
430
  // assume 32 threads
431
- const int tid = item_ct1.get_local_id(2);
432
- const int ip = tid/16; // 0 or 1
433
- const int il = tid - 16*ip; // 0...15
431
+ const int64_t tid = item_ct1.get_local_id(2);
432
+ const int64_t ip = tid/16; // 0 or 1
433
+ const int64_t il = tid - 16*ip; // 0...15
434
434
 
435
435
  dst_t * y = yy + i*QK_K + 16*ip + il;
436
436
 
@@ -452,13 +452,13 @@ static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __res
452
452
  const uint8_t *ksigns_iq2xs_ptr,
453
453
  const uint8_t *kmask_iq2xs_ptr) {
454
454
 
455
- const int i = item_ct1.get_group(2);
455
+ const int64_t i = item_ct1.get_group(2);
456
456
  const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
457
457
 
458
- const int tid = item_ct1.get_local_id(2);
458
+ const int64_t tid = item_ct1.get_local_id(2);
459
459
  #if QK_K == 256
460
- const int il = tid/8; // 0...3
461
- const int ib = tid%8; // 0...7
460
+ const int64_t il = tid/8; // 0...3
461
+ const int64_t ib = tid%8; // 0...7
462
462
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
463
463
  const uint16_t * q2 = x[i].qs + 4*ib;
464
464
  const uint8_t * aux8 = (const uint8_t *)q2;
@@ -480,13 +480,13 @@ static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __rest
480
480
  const uint8_t *ksigns_iq2xs,
481
481
  const uint8_t *kmask_iq2xs) {
482
482
 
483
- const int i = item_ct1.get_group(2);
483
+ const int64_t i = item_ct1.get_group(2);
484
484
  const block_iq2_xs * x = (const block_iq2_xs *) vx;
485
485
 
486
- const int tid = item_ct1.get_local_id(2);
486
+ const int64_t tid = item_ct1.get_local_id(2);
487
487
  #if QK_K == 256
488
- const int il = tid/8; // 0...3
489
- const int ib = tid%8; // 0...7
488
+ const int64_t il = tid/8; // 0...3
489
+ const int64_t ib = tid%8; // 0...7
490
490
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
491
491
  const uint16_t * q2 = x[i].qs + 4*ib;
492
492
  const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
@@ -504,13 +504,13 @@ __dpct_inline__ static void
504
504
  dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
505
505
  const sycl::nd_item<3> &item_ct1) {
506
506
 
507
- const int i = item_ct1.get_group(2);
507
+ const int64_t i = item_ct1.get_group(2);
508
508
  const block_iq2_s * x = (const block_iq2_s *) vx;
509
509
 
510
- const int tid = item_ct1.get_local_id(2);
510
+ const int64_t tid = item_ct1.get_local_id(2);
511
511
  #if QK_K == 256
512
- const int il = tid/8; // 0...3
513
- const int ib = tid%8; // 0...7
512
+ const int64_t il = tid/8; // 0...3
513
+ const int64_t ib = tid%8; // 0...7
514
514
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
515
515
  const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
516
516
  const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
@@ -532,13 +532,13 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __res
532
532
  const uint8_t *ksigns_iq2xs,
533
533
  const uint8_t *kmask_iq2xs) {
534
534
 
535
- const int i = item_ct1.get_group(2);
535
+ const int64_t i = item_ct1.get_group(2);
536
536
  const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
537
537
 
538
- const int tid = item_ct1.get_local_id(2);
538
+ const int64_t tid = item_ct1.get_local_id(2);
539
539
  #if QK_K == 256
540
- const int il = tid/8; // 0...3
541
- const int ib = tid%8; // 0...7
540
+ const int64_t il = tid/8; // 0...3
541
+ const int64_t ib = tid%8; // 0...7
542
542
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
543
543
  const uint8_t * q3 = x[i].qs + 8*ib;
544
544
  const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
@@ -563,13 +563,13 @@ dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
563
563
  const sycl::nd_item<3> &item_ct1,
564
564
  const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) {
565
565
 
566
- const int i = item_ct1.get_group(2);
566
+ const int64_t i = item_ct1.get_group(2);
567
567
  const block_iq3_s * x = (const block_iq3_s *) vx;
568
568
 
569
- const int tid = item_ct1.get_local_id(2);
569
+ const int64_t tid = item_ct1.get_local_id(2);
570
570
  #if QK_K == 256
571
- const int il = tid/8; // 0...3
572
- const int ib = tid%8; // 0...7
571
+ const int64_t il = tid/8; // 0...3
572
+ const int64_t ib = tid%8; // 0...7
573
573
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
574
574
  const uint8_t * qs = x[i].qs + 8*ib;
575
575
  const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
@@ -593,13 +593,13 @@ dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
593
593
  const sycl::nd_item<3> &item_ct1,
594
594
  const uint32_t *iq1s_grid_gpu) {
595
595
 
596
- const int i = item_ct1.get_group(2);
596
+ const int64_t i = item_ct1.get_group(2);
597
597
  const block_iq1_s * x = (const block_iq1_s *) vx;
598
598
 
599
- const int tid = item_ct1.get_local_id(2);
599
+ const int64_t tid = item_ct1.get_local_id(2);
600
600
  #if QK_K == 256
601
- const int il = tid/8; // 0...3
602
- const int ib = tid%8; // 0...7
601
+ const int64_t il = tid/8; // 0...3
602
+ const int64_t ib = tid%8; // 0...7
603
603
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
604
604
  const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
605
605
  const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
@@ -623,13 +623,13 @@ dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
623
623
  const sycl::nd_item<3> &item_ct1,
624
624
  const uint32_t *iq1s_grid_gpu) {
625
625
 
626
- const int i = item_ct1.get_group(2);
626
+ const int64_t i = item_ct1.get_group(2);
627
627
  const block_iq1_m * x = (const block_iq1_m *) vx;
628
628
 
629
- const int tid = item_ct1.get_local_id(2);
629
+ const int64_t tid = item_ct1.get_local_id(2);
630
630
  #if QK_K == 256
631
- const int il = tid/8; // 0...3
632
- const int ib = tid%8; // 0...7
631
+ const int64_t il = tid/8; // 0...3
632
+ const int64_t ib = tid%8; // 0...7
633
633
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
634
634
  const uint16_t * sc = (const uint16_t *)x[i].scales;
635
635
  iq1m_scale_t scale;
@@ -656,12 +656,12 @@ __dpct_inline__ static void
656
656
  dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy,
657
657
  const sycl::nd_item<3> &item_ct1) {
658
658
 
659
- const int i = item_ct1.get_group(2);
659
+ const int64_t i = item_ct1.get_group(2);
660
660
  const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
661
661
 
662
- const int tid = item_ct1.get_local_id(2);
663
- const int il = tid/8; // 0...3
664
- const int ib = tid%8; // 0...7
662
+ const int64_t tid = item_ct1.get_local_id(2);
663
+ const int64_t il = tid/8; // 0...3
664
+ const int64_t ib = tid%8; // 0...7
665
665
  dst_t * y = yy + i*QK_K + 32*ib + 4*il;
666
666
  const uint8_t * q4 = x[ib].qs + 4*il;
667
667
  const float d = (float)x[ib].d;
@@ -678,12 +678,12 @@ template <typename dst_t>
678
678
  __dpct_inline__ static void
679
679
  dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
680
680
  const sycl::nd_item<3> &item_ct1) {
681
- const int i = item_ct1.get_group(2);
681
+ const int64_t i = item_ct1.get_group(2);
682
682
  const block_iq4_xs * x = (const block_iq4_xs *)vx;
683
683
 
684
- const int tid = item_ct1.get_local_id(2);
685
- const int il = tid/8; // 0...3
686
- const int ib = tid%8; // 0...7
684
+ const int64_t tid = item_ct1.get_local_id(2);
685
+ const int64_t il = tid/8; // 0...3
686
+ const int64_t ib = tid%8; // 0...7
687
687
  dst_t * y = yy + i*QK_K + 32*ib + 4*il;
688
688
  const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
689
689
  const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
@@ -4,7 +4,7 @@
4
4
  #include "presets.hpp"
5
5
 
6
6
 
7
- static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
7
+ static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
8
8
  const sycl::half *x = (const sycl::half *)vx;
9
9
 
10
10
  // automatic half -> float type cast if dfloat == float
@@ -12,7 +12,7 @@ static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 &
12
12
  v.y() = x[ib + iqs + 1];
13
13
  }
14
14
 
15
- static void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
15
+ static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
16
16
  const float * x = (const float *) vx;
17
17
 
18
18
  // automatic half -> float type cast if dfloat == float
@@ -76,8 +76,8 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
76
76
  }
77
77
 
78
78
  // sum up partial sums and write back result
79
- #pragma unroll
80
- for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
79
+ const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
80
+ for (int mask = mask_start; mask > 0; mask >>= 1) {
81
81
  tmp +=
82
82
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
83
83
  }
@@ -874,7 +874,7 @@ namespace dpct
874
874
  inline std::string get_preferred_gpu_platform_name() {
875
875
  std::string result;
876
876
 
877
- std::string filter = "level-zero";
877
+ std::string filter = "";
878
878
  char* env = getenv("ONEAPI_DEVICE_SELECTOR");
879
879
  if (env) {
880
880
  if (std::strstr(env, "level_zero")) {
@@ -892,11 +892,24 @@ namespace dpct
892
892
  else {
893
893
  throw std::runtime_error("invalid device filter: " + std::string(env));
894
894
  }
895
+ } else {
896
+ auto default_device = sycl::device(sycl::default_selector_v);
897
+ auto default_platform_name = default_device.get_platform().get_info<sycl::info::platform::name>();
898
+
899
+ if (std::strstr(default_platform_name.c_str(), "Level-Zero") || default_device.is_cpu()) {
900
+ filter = "level-zero";
901
+ }
902
+ else if (std::strstr(default_platform_name.c_str(), "CUDA")) {
903
+ filter = "cuda";
904
+ }
905
+ else if (std::strstr(default_platform_name.c_str(), "HIP")) {
906
+ filter = "hip";
907
+ }
895
908
  }
896
909
 
897
- auto plaform_list = sycl::platform::get_platforms();
910
+ auto platform_list = sycl::platform::get_platforms();
898
911
 
899
- for (const auto& platform : plaform_list) {
912
+ for (const auto& platform : platform_list) {
900
913
  auto devices = platform.get_devices();
901
914
  auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {
902
915
  return d.is_gpu();
@@ -0,0 +1,101 @@
1
+ //
2
+ // MIT license
3
+ // Copyright (C) 2024 Intel Corporation
4
+ // SPDX-License-Identifier: MIT
5
+ //
6
+
7
+ //
8
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9
+ // See https://llvm.org/LICENSE.txt for license information.
10
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11
+ //
12
+
13
+ #ifndef GGML_SYCL_GEMM_HPP
14
+ #define GGML_SYCL_GEMM_HPP
15
+
16
+ #include <fstream>
17
+ #include <iostream>
18
+
19
+ #include "ggml-sycl.h"
20
+
21
+ #if GGML_SYCL_DNNL
22
+
23
+ #include "dnnl.hpp"
24
+ #include "dnnl_sycl.hpp"
25
+
26
+ class DnnlGemmWrapper {
27
+ public:
28
+ using dt = dnnl::memory::data_type;
29
+ using tag = dnnl::memory::format_tag;
30
+
31
+ template<typename T>
32
+ static constexpr dt to_dt() {
33
+ if constexpr (std::is_same_v<T, float>) return dt::f32;
34
+ else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16;
35
+ else static_assert(0);
36
+ }
37
+
38
+ static inline void row_gemm(sycl::queue& q, bool a_trans,
39
+ bool b_trans, int m, int n, int k,
40
+ const void* a, dt at, const void* b, dt bt, void* c, dt ct)
41
+ {
42
+ // Get the device associated with the queue
43
+ sycl::device dev = q.get_device();
44
+ // Get the context associated with the queue
45
+ sycl::context ctx = q.get_context();
46
+ const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
47
+ const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
48
+ dnnl::memory::dims a_dims = { m, k };
49
+ dnnl::memory::dims b_dims = { k, n };
50
+ dnnl::memory::dims c_dims = { m, n };
51
+ const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
52
+ const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
53
+ const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
54
+ auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
55
+ auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
56
+ auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
57
+ auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
58
+
59
+ // Create the primitive.
60
+ auto matmul_prim = dnnl::matmul(matmul_pd);
61
+ // Primitive arguments.
62
+ std::unordered_map<int, dnnl::memory> matmul_args;
63
+ matmul_args.insert({ DNNL_ARG_SRC, a_mem });
64
+ matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
65
+ matmul_args.insert({ DNNL_ARG_DST, c_mem });
66
+
67
+ matmul_prim.execute(stream, matmul_args);
68
+ }
69
+
70
+
71
+ static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
72
+ bool b_trans, int m, int n, int k,
73
+ const void* a, dt at, const void* b, dt bt, void* c, dt ct)
74
+ {
75
+ auto const eng = stream.get_engine();
76
+ dnnl::memory::dims a_dims = { m, k };
77
+ dnnl::memory::dims b_dims = { k, n };
78
+ dnnl::memory::dims c_dims = { m, n };
79
+ const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
80
+ const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
81
+ const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
82
+ auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
83
+ auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
84
+ auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
85
+ auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
86
+
87
+ // Create the primitive.
88
+ auto matmul_prim = dnnl::matmul(matmul_pd);
89
+ // Primitive arguments.
90
+ std::unordered_map<int, dnnl::memory> matmul_args;
91
+ matmul_args.insert({ DNNL_ARG_SRC, a_mem });
92
+ matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
93
+ matmul_args.insert({ DNNL_ARG_DST, c_mem });
94
+
95
+ matmul_prim.execute(stream, matmul_args);
96
+ }
97
+ };
98
+
99
+ #endif
100
+
101
+ #endif // GGML_SYCL_GEMM_HPP