@fugood/llama.node 0.3.14 → 0.3.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/package.json +1 -1
  18. package/src/llama.cpp/.github/workflows/build.yml +30 -1
  19. package/src/llama.cpp/CMakeLists.txt +9 -1
  20. package/src/llama.cpp/cmake/common.cmake +2 -0
  21. package/src/llama.cpp/common/arg.cpp +20 -2
  22. package/src/llama.cpp/common/common.cpp +6 -3
  23. package/src/llama.cpp/common/speculative.cpp +4 -4
  24. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
  25. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +1 -1
  26. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -1
  27. package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
  28. package/src/llama.cpp/examples/imatrix/imatrix.cpp +1 -1
  29. package/src/llama.cpp/examples/infill/infill.cpp +2 -2
  30. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
  31. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +4 -4
  32. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +1 -1
  33. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -6
  34. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  35. package/src/llama.cpp/examples/main/main.cpp +6 -6
  36. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -5
  37. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  38. package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
  39. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
  40. package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
  41. package/src/llama.cpp/examples/run/run.cpp +91 -46
  42. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
  43. package/src/llama.cpp/examples/server/server.cpp +37 -15
  44. package/src/llama.cpp/examples/server/utils.hpp +3 -1
  45. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
  46. package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
  47. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  48. package/src/llama.cpp/examples/tts/tts.cpp +20 -9
  49. package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
  50. package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
  51. package/src/llama.cpp/ggml/include/ggml.h +24 -0
  52. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -28
  53. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
  54. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +0 -5
  55. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -7
  56. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +1493 -12
  57. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +150 -1
  58. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +284 -29
  59. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
  60. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
  61. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
  62. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  63. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +95 -22
  64. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +35 -12
  65. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -1
  66. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +93 -27
  67. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  68. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +12 -13
  69. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
  70. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
  71. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -2
  72. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +109 -40
  73. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
  74. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +19 -20
  75. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
  76. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
  77. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -1
  78. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
  79. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
  80. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +398 -158
  81. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
  82. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +7 -2
  83. package/src/llama.cpp/ggml/src/ggml.c +85 -2
  84. package/src/llama.cpp/include/llama.h +86 -22
  85. package/src/llama.cpp/src/CMakeLists.txt +5 -2
  86. package/src/llama.cpp/src/llama-adapter.cpp +19 -20
  87. package/src/llama.cpp/src/llama-adapter.h +11 -9
  88. package/src/llama.cpp/src/llama-arch.cpp +103 -16
  89. package/src/llama.cpp/src/llama-arch.h +18 -0
  90. package/src/llama.cpp/src/llama-batch.h +2 -2
  91. package/src/llama.cpp/src/llama-context.cpp +2253 -1222
  92. package/src/llama.cpp/src/llama-context.h +214 -77
  93. package/src/llama.cpp/src/llama-cparams.h +1 -0
  94. package/src/llama.cpp/src/llama-graph.cpp +1662 -0
  95. package/src/llama.cpp/src/llama-graph.h +574 -0
  96. package/src/llama.cpp/src/llama-hparams.cpp +8 -0
  97. package/src/llama.cpp/src/llama-hparams.h +9 -0
  98. package/src/llama.cpp/src/llama-io.cpp +15 -0
  99. package/src/llama.cpp/src/llama-io.h +35 -0
  100. package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
  101. package/src/llama.cpp/src/llama-kv-cache.h +178 -110
  102. package/src/llama.cpp/src/llama-memory.cpp +1 -0
  103. package/src/llama.cpp/src/llama-memory.h +21 -0
  104. package/src/llama.cpp/src/llama-model.cpp +8244 -173
  105. package/src/llama.cpp/src/llama-model.h +34 -1
  106. package/src/llama.cpp/src/llama-quant.cpp +10 -1
  107. package/src/llama.cpp/src/llama.cpp +51 -9984
  108. package/src/llama.cpp/tests/test-backend-ops.cpp +145 -23
  109. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
  110. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
@@ -1,7 +1,7 @@
1
1
  #include "common.hpp"
2
2
  #include "element_wise.hpp"
3
3
 
4
- void acc_f32(const float * x, const float * y, float * dst, const int ne,
4
+ static void acc_f32(const float * x, const float * y, float * dst, const int ne,
5
5
  const int ne10, const int ne11, const int ne12,
6
6
  const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
7
7
  const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -20,7 +20,7 @@ void acc_f32(const float * x, const float * y, float * dst, const int ne,
20
20
  }
21
21
  }
22
22
 
23
- void gelu_f32(const float * x, float * dst, const int k,
23
+ static void gelu_f32(const float * x, float * dst, const int k,
24
24
  const sycl::nd_item<3> &item_ct1) {
25
25
  const float GELU_COEF_A = 0.044715f;
26
26
  const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -37,7 +37,7 @@ void gelu_f32(const float * x, float * dst, const int k,
37
37
  sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi)));
38
38
  }
39
39
 
40
- void silu_f32(const float * x, float * dst, const int k,
40
+ static void silu_f32(const float * x, float * dst, const int k,
41
41
  const sycl::nd_item<3> &item_ct1) {
42
42
  const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
43
43
  item_ct1.get_local_id(2);
@@ -48,7 +48,7 @@ void silu_f32(const float * x, float * dst, const int k,
48
48
  dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i]));
49
49
  }
50
50
 
51
- void gelu_quick_f32(const float *x, float *dst, int k,
51
+ static void gelu_quick_f32(const float *x, float *dst, int k,
52
52
  const sycl::nd_item<3> &item_ct1) {
53
53
  const float GELU_QUICK_COEF = -1.702f;
54
54
  const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -59,7 +59,7 @@ void gelu_quick_f32(const float *x, float *dst, int k,
59
59
  dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i])));
60
60
  }
61
61
 
62
- void tanh_f32(const float *x, float *dst, int k,
62
+ static void tanh_f32(const float *x, float *dst, int k,
63
63
  const sycl::nd_item<3> &item_ct1) {
64
64
  const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
65
65
  item_ct1.get_local_id(2);
@@ -69,7 +69,7 @@ void tanh_f32(const float *x, float *dst, int k,
69
69
  dst[i] = sycl::tanh((float)(x[i]));
70
70
  }
71
71
 
72
- void relu_f32(const float * x, float * dst, const int k,
72
+ static void relu_f32(const float * x, float * dst, const int k,
73
73
  const sycl::nd_item<3> &item_ct1) {
74
74
  const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
75
75
  item_ct1.get_local_id(2);
@@ -80,7 +80,7 @@ void relu_f32(const float * x, float * dst, const int k,
80
80
  dst[i] = sycl::fmax((float)(x[i]), (float)0);
81
81
  }
82
82
 
83
- void sigmoid_f32(const float * x, float * dst, const int k,
83
+ static void sigmoid_f32(const float * x, float * dst, const int k,
84
84
  const sycl::nd_item<3> &item_ct1) {
85
85
  const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
86
86
  item_ct1.get_local_id(2);
@@ -91,7 +91,7 @@ void sigmoid_f32(const float * x, float * dst, const int k,
91
91
  dst[i] = 1.0f / (1.0f + sycl::native::exp(-x[i]));
92
92
  }
93
93
 
94
- void sqrt_f32(const float * x, float * dst, const int k,
94
+ static void sqrt_f32(const float * x, float * dst, const int k,
95
95
  const sycl::nd_item<3> &item_ct1) {
96
96
  const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
97
97
  item_ct1.get_local_id(2);
@@ -102,7 +102,7 @@ void sqrt_f32(const float * x, float * dst, const int k,
102
102
  dst[i] = sycl::sqrt(x[i]);
103
103
  }
104
104
 
105
- void sin_f32(const float * x, float * dst, const int k,
105
+ static void sin_f32(const float * x, float * dst, const int k,
106
106
  const sycl::nd_item<3> &item_ct1) {
107
107
  const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
108
108
  item_ct1.get_local_id(2);
@@ -113,7 +113,7 @@ void sin_f32(const float * x, float * dst, const int k,
113
113
  dst[i] = sycl::sin(x[i]);
114
114
  }
115
115
 
116
- void cos_f32(const float * x, float * dst, const int k,
116
+ static void cos_f32(const float * x, float * dst, const int k,
117
117
  const sycl::nd_item<3> &item_ct1) {
118
118
  const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
119
119
  item_ct1.get_local_id(2);
@@ -124,7 +124,7 @@ void cos_f32(const float * x, float * dst, const int k,
124
124
  dst[i] = sycl::cos(x[i]);
125
125
  }
126
126
 
127
- void hardsigmoid_f32(const float * x, float * dst, const int k,
127
+ static void hardsigmoid_f32(const float * x, float * dst, const int k,
128
128
  const sycl::nd_item<3> &item_ct1) {
129
129
  const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
130
130
  item_ct1.get_local_id(2);
@@ -135,7 +135,7 @@ void hardsigmoid_f32(const float * x, float * dst, const int k,
135
135
  dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
136
136
  }
137
137
 
138
- void hardswish_f32(const float * x, float * dst, const int k,
138
+ static void hardswish_f32(const float * x, float * dst, const int k,
139
139
  const sycl::nd_item<3> &item_ct1) {
140
140
  const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
141
141
  item_ct1.get_local_id(2);
@@ -146,7 +146,7 @@ void hardswish_f32(const float * x, float * dst, const int k,
146
146
  dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
147
147
  }
148
148
 
149
- void exp_f32(const float * x, float * dst, const int k,
149
+ static void exp_f32(const float * x, float * dst, const int k,
150
150
  const sycl::nd_item<3> &item_ct1) {
151
151
  const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
152
152
  item_ct1.get_local_id(2);
@@ -157,7 +157,7 @@ void exp_f32(const float * x, float * dst, const int k,
157
157
  dst[i] = sycl::exp(x[i]);
158
158
  }
159
159
 
160
- void log_f32(const float * x, float * dst, const int k,
160
+ static void log_f32(const float * x, float * dst, const int k,
161
161
  const sycl::nd_item<3> &item_ct1) {
162
162
  const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
163
163
  item_ct1.get_local_id(2);
@@ -173,7 +173,7 @@ void log_f32(const float * x, float * dst, const int k,
173
173
  }
174
174
  }
175
175
 
176
- void neg_f32(const float * x, float * dst, const int k,
176
+ static void neg_f32(const float * x, float * dst, const int k,
177
177
  const sycl::nd_item<3> &item_ct1) {
178
178
  const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
179
179
  item_ct1.get_local_id(2);
@@ -184,7 +184,7 @@ void neg_f32(const float * x, float * dst, const int k,
184
184
  dst[i] = -x[i];
185
185
  }
186
186
 
187
- void step_f32(const float * x, float * dst, const int k,
187
+ static void step_f32(const float * x, float * dst, const int k,
188
188
  const sycl::nd_item<3> &item_ct1) {
189
189
  const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
190
190
  item_ct1.get_local_id(2);
@@ -195,7 +195,7 @@ void step_f32(const float * x, float * dst, const int k,
195
195
  dst[i] = x[i] > 0.0f;
196
196
  }
197
197
 
198
- void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
198
+ static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
199
199
  const sycl::nd_item<3> &item_ct1) {
200
200
  const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
201
201
  item_ct1.get_local_id(2);
@@ -206,7 +206,7 @@ void leaky_relu_f32(const float *x, float *dst, const int k, const float negativ
206
206
  sycl::fmin((float)(x[i]), 0.0f) * negative_slope;
207
207
  }
208
208
 
209
- void sqr_f32(const float * x, float * dst, const int k,
209
+ static void sqr_f32(const float * x, float * dst, const int k,
210
210
  const sycl::nd_item<3> &item_ct1) {
211
211
  const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
212
212
  item_ct1.get_local_id(2);
@@ -217,7 +217,7 @@ void sqr_f32(const float * x, float * dst, const int k,
217
217
  dst[i] = x[i] * x[i];
218
218
  }
219
219
 
220
- void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
220
+ static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
221
221
  const int nb02, const int nb03, const int ne10, const int ne11,
222
222
  const int ne12, const int ne13, const float sf0, const float sf1,
223
223
  const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
@@ -240,7 +240,7 @@ void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
240
240
  dst[index] = *(const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
241
241
  }
242
242
 
243
- void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
243
+ static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
244
244
  const sycl::nd_item<3> &item_ct1) {
245
245
  int nidx = item_ct1.get_local_id(2) +
246
246
  item_ct1.get_group(2) * item_ct1.get_local_range(2);
@@ -262,7 +262,7 @@ void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const i
262
262
 
263
263
 
264
264
 
265
- void acc_f32_sycl(const float *x, const float *y, float *dst,
265
+ static void acc_f32_sycl(const float *x, const float *y, float *dst,
266
266
  const int n_elements, const int ne10, const int ne11,
267
267
  const int ne12, const int nb1, const int nb2,
268
268
  const int offset, queue_ptr stream) {
@@ -277,7 +277,7 @@ void acc_f32_sycl(const float *x, const float *y, float *dst,
277
277
  });
278
278
  }
279
279
 
280
- void gelu_f32_sycl(const float *x, float *dst, const int k,
280
+ static void gelu_f32_sycl(const float *x, float *dst, const int k,
281
281
  queue_ptr stream) {
282
282
  const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
283
283
  stream->parallel_for(
@@ -289,7 +289,7 @@ void gelu_f32_sycl(const float *x, float *dst, const int k,
289
289
  });
290
290
  }
291
291
 
292
- void silu_f32_sycl(const float *x, float *dst, const int k,
292
+ static void silu_f32_sycl(const float *x, float *dst, const int k,
293
293
  queue_ptr stream) {
294
294
  const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
295
295
  stream->parallel_for(
@@ -301,7 +301,7 @@ void silu_f32_sycl(const float *x, float *dst, const int k,
301
301
  });
302
302
  }
303
303
 
304
- void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
304
+ static void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
305
305
  queue_ptr stream) {
306
306
  const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
307
307
  stream->parallel_for(
@@ -313,7 +313,7 @@ void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
313
313
  });
314
314
  }
315
315
 
316
- void tanh_f32_sycl(const float *x, float *dst, const int k,
316
+ static void tanh_f32_sycl(const float *x, float *dst, const int k,
317
317
  queue_ptr stream) {
318
318
  const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
319
319
  stream->parallel_for(
@@ -325,7 +325,7 @@ void tanh_f32_sycl(const float *x, float *dst, const int k,
325
325
  });
326
326
  }
327
327
 
328
- void relu_f32_sycl(const float *x, float *dst, const int k,
328
+ static void relu_f32_sycl(const float *x, float *dst, const int k,
329
329
  queue_ptr stream) {
330
330
  const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
331
331
  stream->parallel_for(
@@ -337,7 +337,7 @@ void relu_f32_sycl(const float *x, float *dst, const int k,
337
337
  });
338
338
  }
339
339
 
340
- void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
340
+ static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
341
341
  queue_ptr stream) {
342
342
  const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
343
343
  stream->parallel_for(
@@ -349,7 +349,7 @@ void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
349
349
  });
350
350
  }
351
351
 
352
- void hardswish_f32_sycl(const float *x, float *dst, const int k,
352
+ static void hardswish_f32_sycl(const float *x, float *dst, const int k,
353
353
  queue_ptr stream) {
354
354
  const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
355
355
  stream->parallel_for(
@@ -361,7 +361,7 @@ void hardswish_f32_sycl(const float *x, float *dst, const int k,
361
361
  });
362
362
  }
363
363
 
364
- void exp_f32_sycl(const float *x, float *dst, const int k,
364
+ static void exp_f32_sycl(const float *x, float *dst, const int k,
365
365
  queue_ptr stream) {
366
366
  const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
367
367
  stream->parallel_for(
@@ -373,7 +373,7 @@ void exp_f32_sycl(const float *x, float *dst, const int k,
373
373
  });
374
374
  }
375
375
 
376
- void log_f32_sycl(const float *x, float *dst, const int k,
376
+ static void log_f32_sycl(const float *x, float *dst, const int k,
377
377
  queue_ptr stream) {
378
378
  const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
379
379
  stream->parallel_for(
@@ -385,7 +385,7 @@ void log_f32_sycl(const float *x, float *dst, const int k,
385
385
  });
386
386
  }
387
387
 
388
- void neg_f32_sycl(const float *x, float *dst, const int k,
388
+ static void neg_f32_sycl(const float *x, float *dst, const int k,
389
389
  queue_ptr stream) {
390
390
  const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
391
391
  stream->parallel_for(
@@ -397,7 +397,7 @@ void neg_f32_sycl(const float *x, float *dst, const int k,
397
397
  });
398
398
  }
399
399
 
400
- void step_f32_sycl(const float *x, float *dst, const int k,
400
+ static void step_f32_sycl(const float *x, float *dst, const int k,
401
401
  queue_ptr stream) {
402
402
  const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
403
403
  stream->parallel_for(
@@ -409,7 +409,7 @@ void step_f32_sycl(const float *x, float *dst, const int k,
409
409
  });
410
410
  }
411
411
 
412
- void sigmoid_f32_sycl(const float *x, float *dst, const int k,
412
+ static void sigmoid_f32_sycl(const float *x, float *dst, const int k,
413
413
  queue_ptr stream) {
414
414
  const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
415
415
  stream->parallel_for(
@@ -421,7 +421,7 @@ void sigmoid_f32_sycl(const float *x, float *dst, const int k,
421
421
  });
422
422
  }
423
423
 
424
- void sqrt_f32_sycl(const float *x, float *dst, const int k,
424
+ static void sqrt_f32_sycl(const float *x, float *dst, const int k,
425
425
  queue_ptr stream) {
426
426
  const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
427
427
  stream->parallel_for(
@@ -433,7 +433,7 @@ void sqrt_f32_sycl(const float *x, float *dst, const int k,
433
433
  });
434
434
  }
435
435
 
436
- void sin_f32_sycl(const float *x, float *dst, const int k,
436
+ static void sin_f32_sycl(const float *x, float *dst, const int k,
437
437
  queue_ptr stream) {
438
438
  const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
439
439
  stream->parallel_for(
@@ -445,7 +445,7 @@ void sin_f32_sycl(const float *x, float *dst, const int k,
445
445
  });
446
446
  }
447
447
 
448
- void cos_f32_sycl(const float *x, float *dst, const int k,
448
+ static void cos_f32_sycl(const float *x, float *dst, const int k,
449
449
  queue_ptr stream) {
450
450
  const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
451
451
  stream->parallel_for(
@@ -457,7 +457,7 @@ void cos_f32_sycl(const float *x, float *dst, const int k,
457
457
  });
458
458
  }
459
459
 
460
- void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
460
+ static void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
461
461
  const float negative_slope,
462
462
  queue_ptr stream) {
463
463
  const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
@@ -470,7 +470,7 @@ void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
470
470
  });
471
471
  }
472
472
 
473
- void sqr_f32_sycl(const float *x, float *dst, const int k,
473
+ static void sqr_f32_sycl(const float *x, float *dst, const int k,
474
474
  queue_ptr stream) {
475
475
  const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
476
476
  stream->parallel_for(
@@ -482,7 +482,7 @@ void sqr_f32_sycl(const float *x, float *dst, const int k,
482
482
  });
483
483
  }
484
484
 
485
- void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
485
+ static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
486
486
  const int nb02, const int nb03, const int ne10, const int ne11,
487
487
  const int ne12, const int ne13, const float sf0, const float sf1,
488
488
  const float sf2, const float sf3, queue_ptr stream) {
@@ -496,7 +496,7 @@ void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01
496
496
  });
497
497
  }
498
498
 
499
- void pad_f32_sycl(const float *x, float *dst, const int ne00,
499
+ static void pad_f32_sycl(const float *x, float *dst, const int ne00,
500
500
  const int ne01, const int ne02, const int ne0,
501
501
  const int ne1, const int ne2, queue_ptr stream) {
502
502
  int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
@@ -13,9 +13,6 @@
13
13
  #ifndef GGML_SYCL_GEMM_HPP
14
14
  #define GGML_SYCL_GEMM_HPP
15
15
 
16
- #include <fstream>
17
- #include <iostream>
18
-
19
16
  #include "ggml-sycl.h"
20
17
 
21
18
  #if GGML_SYCL_DNNL
@@ -35,62 +32,34 @@ public:
35
32
  else static_assert(0);
36
33
  }
37
34
 
38
- static inline void row_gemm(sycl::queue& q, bool a_trans,
39
- bool b_trans, int m, int n, int k,
40
- const void* a, dt at, const void* b, dt bt, void* c, dt ct)
41
- {
42
- // Get the device associated with the queue
43
- sycl::device dev = q.get_device();
44
- // Get the context associated with the queue
45
- sycl::context ctx = q.get_context();
46
- const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
47
- const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
35
+ static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
36
+ const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
37
+ auto stream = ctx.stream_dnnl(q);
38
+ auto eng = ctx.engine_dnnl(q);
48
39
  dnnl::memory::dims a_dims = { m, k };
49
40
  dnnl::memory::dims b_dims = { k, n };
50
41
  dnnl::memory::dims c_dims = { m, n };
51
42
  const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
52
43
  const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
53
- const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
54
- auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
55
- auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
56
- auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
57
- auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
44
+ const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
58
45
 
59
- // Create the primitive.
60
- auto matmul_prim = dnnl::matmul(matmul_pd);
61
- // Primitive arguments.
62
- std::unordered_map<int, dnnl::memory> matmul_args;
63
- matmul_args.insert({ DNNL_ARG_SRC, a_mem });
64
- matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
65
- matmul_args.insert({ DNNL_ARG_DST, c_mem });
46
+ dnnl::primitive_attr primitive_attr;
47
+ primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
66
48
 
67
- matmul_prim.execute(stream, matmul_args);
68
- }
69
-
70
-
71
- static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
72
- bool b_trans, int m, int n, int k,
73
- const void* a, dt at, const void* b, dt bt, void* c, dt ct)
74
- {
75
- auto const eng = stream.get_engine();
76
- dnnl::memory::dims a_dims = { m, k };
77
- dnnl::memory::dims b_dims = { k, n };
78
- dnnl::memory::dims c_dims = { m, n };
79
- const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
80
- const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
81
- const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
82
49
  auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
83
50
  auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
84
- auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
51
+ auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md, primitive_attr);
85
52
  auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
86
53
 
87
- // Create the primitive.
54
+ auto scratchpad_md = matmul_pd.scratchpad_desc();
55
+ auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
88
56
  auto matmul_prim = dnnl::matmul(matmul_pd);
89
- // Primitive arguments.
57
+
90
58
  std::unordered_map<int, dnnl::memory> matmul_args;
91
59
  matmul_args.insert({ DNNL_ARG_SRC, a_mem });
92
60
  matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
93
61
  matmul_args.insert({ DNNL_ARG_DST, c_mem });
62
+ matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
94
63
 
95
64
  matmul_prim.execute(stream, matmul_args);
96
65
  }
@@ -207,7 +207,7 @@ static void get_rows_sycl_reorder(ggml_backend_sycl_context & ctx, const ggml_te
207
207
  const size_t nrows = ne01;
208
208
  const sycl::half* src0_dq = (const sycl::half*)(src0_q + nrows * ncols / 2);
209
209
  stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
210
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
210
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
211
211
  k_get_rows_reorder<qk, qr, dq_reorder>(
212
212
  src0_dd, src0_dq, src1_dd, dst_dd, ne00, ne12, s1, s2,
213
213
  s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
@@ -302,7 +302,6 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *s
302
302
  // TODO: k-quants
303
303
  GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
304
304
  GGML_ABORT("fatal error");
305
- break;
306
305
  }
307
306
  }
308
307