@fugood/llama.node 0.3.15 → 0.3.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203) hide show
  1. package/CMakeLists.txt +3 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +5 -0
  19. package/package.json +1 -1
  20. package/src/LlamaCompletionWorker.cpp +8 -0
  21. package/src/LlamaCompletionWorker.h +1 -0
  22. package/src/LlamaContext.cpp +3 -2
  23. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +70 -27
  25. package/src/llama.cpp/.github/workflows/docker.yml +6 -6
  26. package/src/llama.cpp/.github/workflows/server.yml +7 -11
  27. package/src/llama.cpp/CMakeLists.txt +23 -1
  28. package/src/llama.cpp/common/CMakeLists.txt +6 -3
  29. package/src/llama.cpp/common/arg.cpp +809 -105
  30. package/src/llama.cpp/common/arg.h +9 -0
  31. package/src/llama.cpp/common/chat.cpp +1 -1
  32. package/src/llama.cpp/common/common.cpp +31 -521
  33. package/src/llama.cpp/common/common.h +17 -36
  34. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  35. package/src/llama.cpp/common/llguidance.cpp +30 -47
  36. package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
  37. package/src/llama.cpp/common/minja/minja.hpp +119 -93
  38. package/src/llama.cpp/common/sampling.cpp +3 -0
  39. package/src/llama.cpp/docs/build.md +122 -7
  40. package/src/llama.cpp/examples/CMakeLists.txt +0 -9
  41. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
  43. package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
  44. package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
  45. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
  46. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
  48. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
  50. package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
  51. package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
  52. package/src/llama.cpp/examples/llava/clip.h +39 -22
  53. package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
  54. package/src/llama.cpp/examples/llava/llava.cpp +64 -52
  55. package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
  56. package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
  57. package/src/llama.cpp/examples/llava/mtmd.h +168 -0
  58. package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
  59. package/src/llama.cpp/examples/main/main.cpp +16 -5
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
  64. package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
  65. package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
  66. package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
  67. package/src/llama.cpp/examples/run/run.cpp +14 -28
  68. package/src/llama.cpp/examples/server/httplib.h +313 -247
  69. package/src/llama.cpp/examples/server/server.cpp +243 -139
  70. package/src/llama.cpp/examples/server/utils.hpp +51 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  74. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  75. package/src/llama.cpp/examples/tts/tts.cpp +14 -9
  76. package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
  77. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  78. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  79. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  80. package/src/llama.cpp/ggml/include/ggml.h +66 -99
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +15 -8
  82. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  83. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  84. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  85. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  87. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  88. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  89. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  90. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
  91. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  93. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2413 -228
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1004 -13516
  99. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
  101. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
  102. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
  103. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
  105. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  106. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  107. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  108. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  109. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
  110. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  111. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  112. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  114. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  115. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
  116. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  117. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
  118. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
  119. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +127 -33
  120. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  124. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +29 -293
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  130. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  131. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +210 -286
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  133. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  134. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  135. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  136. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  137. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  138. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  139. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
  140. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +692 -126
  141. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
  142. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +21 -10
  143. package/src/llama.cpp/ggml/src/ggml.c +141 -245
  144. package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
  145. package/src/llama.cpp/include/llama.h +30 -11
  146. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  147. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  148. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  149. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  150. package/src/llama.cpp/requirements/requirements-all.txt +2 -0
  151. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  152. package/src/llama.cpp/src/CMakeLists.txt +3 -2
  153. package/src/llama.cpp/src/llama-adapter.cpp +37 -1
  154. package/src/llama.cpp/src/llama-arch.cpp +161 -17
  155. package/src/llama.cpp/src/llama-arch.h +16 -0
  156. package/src/llama.cpp/src/llama-chat.cpp +82 -17
  157. package/src/llama.cpp/src/llama-chat.h +6 -2
  158. package/src/llama.cpp/src/llama-context.cpp +108 -92
  159. package/src/llama.cpp/src/llama-context.h +1 -2
  160. package/src/llama.cpp/src/llama-graph.cpp +189 -119
  161. package/src/llama.cpp/src/llama-graph.h +26 -6
  162. package/src/llama.cpp/src/llama-hparams.h +13 -0
  163. package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
  164. package/src/llama.cpp/src/llama-kv-cache.h +41 -115
  165. package/src/llama.cpp/src/llama-memory.h +1 -1
  166. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  167. package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
  168. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  169. package/src/llama.cpp/src/llama-model.cpp +1544 -291
  170. package/src/llama.cpp/src/llama-model.h +13 -1
  171. package/src/llama.cpp/src/llama-quant.cpp +29 -8
  172. package/src/llama.cpp/src/llama-sampling.cpp +7 -1
  173. package/src/llama.cpp/src/llama-vocab.cpp +44 -6
  174. package/src/llama.cpp/src/llama.cpp +1 -1
  175. package/src/llama.cpp/tests/CMakeLists.txt +43 -30
  176. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  177. package/src/llama.cpp/tests/test-backend-ops.cpp +139 -57
  178. package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
  179. package/src/llama.cpp/tests/test-chat.cpp +12 -2
  180. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  181. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  182. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  183. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  184. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  185. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  186. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  187. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  188. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  189. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  190. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  191. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  192. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  193. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  194. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  195. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  196. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  197. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  198. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  199. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  200. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  201. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  202. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  203. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
@@ -2,29 +2,28 @@
2
2
  #define GGML_SYCL_ELEMENTWISE_HPP
3
3
 
4
4
  #include "common.hpp"
5
+ #include "ggml.h"
6
+ #include <limits.h>
5
7
 
6
- static __dpct_inline__ float op_repeat(const float a, const float b) {
7
- return b;
8
- GGML_UNUSED(a);
8
+ template <typename T>
9
+ T neg_infinity() {
10
+ return -std::numeric_limits<T>::infinity();
9
11
  }
10
12
 
11
- static __dpct_inline__ float op_add(const float a, const float b) {
12
- return a + b;
13
+ template<typename T>
14
+ struct typed_data {
15
+ const T * src;
16
+ T * dst;
17
+ };
18
+
19
+ template<typename T>
20
+ typed_data<T> cast_data(ggml_tensor * dst) {
21
+ return {
22
+ /* .src = */ static_cast<const T *>(dst->src[0]->data),
23
+ /* .dst = */ static_cast<T *>(dst->data)
24
+ };
13
25
  }
14
26
 
15
- static __dpct_inline__ float op_sub(const float a, const float b) {
16
- return a - b;
17
- }
18
-
19
- static __dpct_inline__ float op_mul(const float a, const float b) {
20
- return a * b;
21
- }
22
-
23
- static __dpct_inline__ float op_div(const float a, const float b) {
24
- return a / b;
25
- }
26
-
27
-
28
27
  void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
29
28
 
30
29
  void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
@@ -65,12 +64,12 @@ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
65
64
 
66
65
  void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
67
66
 
68
- void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
67
+ void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
69
68
 
70
- void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
69
+ void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
71
70
 
72
- void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
73
-
74
- void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
71
+ void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
75
72
 
73
+ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
76
74
  #endif // GGML_SYCL_ELEMENTWISE_HPP
75
+
@@ -13,9 +13,6 @@
13
13
  #ifndef GGML_SYCL_GEMM_HPP
14
14
  #define GGML_SYCL_GEMM_HPP
15
15
 
16
- #include <fstream>
17
- #include <iostream>
18
-
19
16
  #include "ggml-sycl.h"
20
17
 
21
18
  #if GGML_SYCL_DNNL
@@ -35,62 +32,34 @@ public:
35
32
  else static_assert(0);
36
33
  }
37
34
 
38
- static inline void row_gemm(sycl::queue& q, bool a_trans,
39
- bool b_trans, int m, int n, int k,
40
- const void* a, dt at, const void* b, dt bt, void* c, dt ct)
41
- {
42
- // Get the device associated with the queue
43
- sycl::device dev = q.get_device();
44
- // Get the context associated with the queue
45
- sycl::context ctx = q.get_context();
46
- const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
47
- const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
35
+ static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
36
+ const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
37
+ auto stream = ctx.stream_dnnl(q);
38
+ auto eng = ctx.engine_dnnl(q);
48
39
  dnnl::memory::dims a_dims = { m, k };
49
40
  dnnl::memory::dims b_dims = { k, n };
50
41
  dnnl::memory::dims c_dims = { m, n };
51
42
  const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
52
43
  const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
53
- const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
54
- auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
55
- auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
56
- auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
57
- auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
44
+ const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
58
45
 
59
- // Create the primitive.
60
- auto matmul_prim = dnnl::matmul(matmul_pd);
61
- // Primitive arguments.
62
- std::unordered_map<int, dnnl::memory> matmul_args;
63
- matmul_args.insert({ DNNL_ARG_SRC, a_mem });
64
- matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
65
- matmul_args.insert({ DNNL_ARG_DST, c_mem });
46
+ dnnl::primitive_attr primitive_attr;
47
+ primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
66
48
 
67
- matmul_prim.execute(stream, matmul_args);
68
- }
69
-
70
-
71
- static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
72
- bool b_trans, int m, int n, int k,
73
- const void* a, dt at, const void* b, dt bt, void* c, dt ct)
74
- {
75
- auto const eng = stream.get_engine();
76
- dnnl::memory::dims a_dims = { m, k };
77
- dnnl::memory::dims b_dims = { k, n };
78
- dnnl::memory::dims c_dims = { m, n };
79
- const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
80
- const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
81
- const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
82
49
  auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
83
50
  auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
84
- auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
51
+ auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md, primitive_attr);
85
52
  auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
86
53
 
87
- // Create the primitive.
54
+ auto scratchpad_md = matmul_pd.scratchpad_desc();
55
+ auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
88
56
  auto matmul_prim = dnnl::matmul(matmul_pd);
89
- // Primitive arguments.
57
+
90
58
  std::unordered_map<int, dnnl::memory> matmul_args;
91
59
  matmul_args.insert({ DNNL_ARG_SRC, a_mem });
92
60
  matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
93
61
  matmul_args.insert({ DNNL_ARG_DST, c_mem });
62
+ matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
94
63
 
95
64
  matmul_prim.execute(stream, matmul_args);
96
65
  }
@@ -257,50 +257,54 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens
257
257
  GGML_UNUSED(ctx);
258
258
  }
259
259
 
260
- void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
261
- const ggml_tensor *src1, ggml_tensor *dst,
262
- const float *src0_d, const float *src1_d,
263
- float *dst_d, const queue_ptr &stream) {
260
+ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
264
261
 
265
- GGML_ASSERT(src1->type == GGML_TYPE_I32);
262
+ GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I32);
266
263
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
267
264
 
268
- GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
269
- GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
265
+ GGML_ASSERT(dst->src[0]->nb[0] == ggml_type_size(dst->src[0]->type));
266
+ GGML_ASSERT(dst->src[1]->nb[0] == ggml_type_size(dst->src[1]->type));
270
267
  GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
271
268
 
272
- const int32_t * src1_i32 = (const int32_t *) src1_d;
273
-
274
- switch (src0->type) {
269
+ const int32_t * src1_i32 = (const int32_t *) dst->src[1]->data;
270
+ /* TODO: Refactor and remove duplicates */
271
+ switch (dst->src[0]->type) {
275
272
  case GGML_TYPE_F16:
276
- get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
277
- src1_i32, dst_d, stream);
273
+ get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const sycl::half *)dst->src[0]->data,
274
+ src1_i32, (float *)dst->data, ctx.stream());
278
275
  break;
279
276
  case GGML_TYPE_F32:
280
- get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
277
+ get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
278
+ src1_i32, (float *)dst->data, ctx.stream());
281
279
  break;
282
280
  case GGML_TYPE_Q4_0:
283
281
  if (ctx.opt_feature.reorder && dst->op == GGML_OP_MUL_MAT) {
284
- get_rows_sycl_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
282
+ get_rows_sycl_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
283
+ src1_i32, (float *)dst->data, ctx.stream());
285
284
  } else {
286
- get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
285
+ get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
286
+ src1_i32, (float *)dst->data, ctx.stream());
287
287
  }
288
288
  break;
289
289
  case GGML_TYPE_Q4_1:
290
- get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
290
+ get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
291
+ src1_i32, (float *)dst->data, ctx.stream());
291
292
  break;
292
293
  case GGML_TYPE_Q5_0:
293
- get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
294
+ get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
295
+ src1_i32, (float *)dst->data, ctx.stream());
294
296
  break;
295
297
  case GGML_TYPE_Q5_1:
296
- get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
298
+ get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
299
+ src1_i32, (float *)dst->data, ctx.stream());
297
300
  break;
298
301
  case GGML_TYPE_Q8_0:
299
- get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
302
+ get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
303
+ src1_i32, (float *)dst->data, ctx.stream());
300
304
  break;
301
305
  default:
302
306
  // TODO: k-quants
303
- GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
307
+ GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(dst->src[0]->type));
304
308
  GGML_ABORT("fatal error");
305
309
  }
306
310
  }
@@ -15,9 +15,6 @@
15
15
 
16
16
  #include "common.hpp"
17
17
 
18
- void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
19
- const ggml_tensor *src1, ggml_tensor *dst,
20
- const float *src0_d, const float *src1_d,
21
- float *dst_d, const queue_ptr &stream);
18
+ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
22
19
 
23
20
  #endif // GGML_SYCL_GETROWS_HPP