@fugood/llama.node 0.3.16 → 0.3.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202)
  1. package/CMakeLists.txt +3 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +5 -0
  19. package/package.json +1 -1
  20. package/src/LlamaCompletionWorker.cpp +8 -0
  21. package/src/LlamaCompletionWorker.h +1 -0
  22. package/src/LlamaContext.cpp +3 -2
  23. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +70 -27
  25. package/src/llama.cpp/.github/workflows/docker.yml +6 -6
  26. package/src/llama.cpp/.github/workflows/server.yml +7 -11
  27. package/src/llama.cpp/CMakeLists.txt +23 -1
  28. package/src/llama.cpp/common/CMakeLists.txt +6 -3
  29. package/src/llama.cpp/common/arg.cpp +809 -105
  30. package/src/llama.cpp/common/arg.h +9 -0
  31. package/src/llama.cpp/common/chat.cpp +1 -1
  32. package/src/llama.cpp/common/common.cpp +31 -521
  33. package/src/llama.cpp/common/common.h +17 -36
  34. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  35. package/src/llama.cpp/common/llguidance.cpp +30 -47
  36. package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
  37. package/src/llama.cpp/common/minja/minja.hpp +119 -93
  38. package/src/llama.cpp/common/sampling.cpp +3 -0
  39. package/src/llama.cpp/docs/build.md +122 -7
  40. package/src/llama.cpp/examples/CMakeLists.txt +0 -9
  41. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
  43. package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
  44. package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
  45. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
  46. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
  48. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
  50. package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
  51. package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
  52. package/src/llama.cpp/examples/llava/clip.h +39 -22
  53. package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
  54. package/src/llama.cpp/examples/llava/llava.cpp +64 -52
  55. package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
  56. package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
  57. package/src/llama.cpp/examples/llava/mtmd.h +168 -0
  58. package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
  59. package/src/llama.cpp/examples/main/main.cpp +16 -5
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
  64. package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
  65. package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
  66. package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
  67. package/src/llama.cpp/examples/run/run.cpp +14 -28
  68. package/src/llama.cpp/examples/server/httplib.h +313 -247
  69. package/src/llama.cpp/examples/server/server.cpp +238 -139
  70. package/src/llama.cpp/examples/server/utils.hpp +51 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  74. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  75. package/src/llama.cpp/examples/tts/tts.cpp +6 -9
  76. package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
  77. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  78. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  79. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  80. package/src/llama.cpp/ggml/include/ggml.h +66 -99
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
  82. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  83. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  84. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  85. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  87. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  88. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  89. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  90. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
  91. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  93. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -192
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1003 -13519
  99. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
  101. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
  102. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
  103. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
  105. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  106. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  107. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  108. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  109. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
  110. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  111. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  112. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  114. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  115. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
  116. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  117. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
  118. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
  119. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +96 -22
  120. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  124. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +2 -292
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  130. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +204 -280
  131. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  133. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  134. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  135. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  136. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  137. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  138. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
  139. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +646 -114
  140. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
  141. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +17 -8
  142. package/src/llama.cpp/ggml/src/ggml.c +141 -245
  143. package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
  144. package/src/llama.cpp/include/llama.h +30 -11
  145. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  146. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  147. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  148. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  149. package/src/llama.cpp/requirements/requirements-all.txt +2 -0
  150. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  151. package/src/llama.cpp/src/CMakeLists.txt +3 -2
  152. package/src/llama.cpp/src/llama-adapter.cpp +37 -1
  153. package/src/llama.cpp/src/llama-arch.cpp +160 -17
  154. package/src/llama.cpp/src/llama-arch.h +16 -0
  155. package/src/llama.cpp/src/llama-chat.cpp +82 -17
  156. package/src/llama.cpp/src/llama-chat.h +6 -2
  157. package/src/llama.cpp/src/llama-context.cpp +108 -92
  158. package/src/llama.cpp/src/llama-context.h +1 -2
  159. package/src/llama.cpp/src/llama-graph.cpp +189 -119
  160. package/src/llama.cpp/src/llama-graph.h +26 -6
  161. package/src/llama.cpp/src/llama-hparams.h +13 -0
  162. package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
  163. package/src/llama.cpp/src/llama-kv-cache.h +41 -115
  164. package/src/llama.cpp/src/llama-memory.h +1 -1
  165. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  166. package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
  167. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  168. package/src/llama.cpp/src/llama-model.cpp +1760 -534
  169. package/src/llama.cpp/src/llama-model.h +13 -1
  170. package/src/llama.cpp/src/llama-quant.cpp +29 -8
  171. package/src/llama.cpp/src/llama-sampling.cpp +7 -1
  172. package/src/llama.cpp/src/llama-vocab.cpp +44 -6
  173. package/src/llama.cpp/src/llama.cpp +1 -1
  174. package/src/llama.cpp/tests/CMakeLists.txt +43 -30
  175. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  176. package/src/llama.cpp/tests/test-backend-ops.cpp +82 -43
  177. package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
  178. package/src/llama.cpp/tests/test-chat.cpp +12 -2
  179. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  180. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  181. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  182. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  183. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  184. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  185. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  186. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  187. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  188. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  189. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  190. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  191. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  192. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  193. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  194. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  195. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  196. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  197. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  198. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  199. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  200. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  201. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  202. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp
@@ -12,115 +12,125 @@
 
 #include "im2col.hpp"
 
+#include <sycl/sycl.hpp>
+#include <type_traits> // For std::is_same_v
+
+#include "ggml.h"
+
 template <typename T>
-static void im2col_kernel(
-    const float *x, T *dst, int64_t batch_offset, int64_t offset_delta,
-    int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,
-    int64_t pelements, int64_t CHW, int s0, int s1, int p0, int p1, int d0, int d1,
-    const sycl::nd_item<3> &item_ct1) {
+static void im2col_kernel(const float * x, T * dst, int64_t batch_offset, int64_t offset_delta, int64_t IC, int64_t IW,
+                          int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
+                          int s0, int s1, int p0, int p1, int d0, int d1, const sycl::nd_item<3> & item_ct1) {
     const int64_t work_group_size = item_ct1.get_local_range(2);
-    const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
+    const int64_t global_id = item_ct1.get_local_id(2) + (work_group_size * item_ct1.get_group(2));
 
     // make each work-item deal with more elements since sycl global range can not exceed max int
-    for (int64_t i = global_id; i < pelements; i += work_group_size * item_ct1.get_group_range(2)) {
-
+    for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ct1.get_group_range(2))) {
         const int64_t ksize = OW * (KH > 1 ? KW : 1);
-        const int64_t kx = i / ksize;
-        const int64_t kd = kx * ksize;
-        const int64_t ky = (i - kd) / OW;
-        const int64_t ix = i % OW;
-
-        const int64_t oh = item_ct1.get_group(1);
-        const int64_t batch = item_ct1.get_group(0) / IC;
-        const int64_t ic = item_ct1.get_group(0) % IC;
-
-        const int64_t iiw = ix * s0 + kx * d0 - p0;
-        const int64_t iih = oh * s1 + ky * d1 - p1;
-
-        const int64_t offset_dst =
-            ((batch * OH + oh) * OW + ix) * CHW +
-            (ic * (KW * KH) + ky * KW + kx);
-
-        if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-            dst[offset_dst] =
-                sycl::vec<float, 1>(0.0f)
-                    .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
-        } else {
-            const int64_t offset_src = ic * offset_delta + batch * batch_offset;
-            dst[offset_dst] =
-                sycl::vec<float, 1>(x[offset_src + iih * IW + iiw])
-                    .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
+        const int64_t kx = i / ksize;
+        const int64_t kd = kx * ksize;
+        const int64_t ky = (i - kd) / OW;
+        const int64_t ix = i % OW;
+
+        const int64_t oh = item_ct1.get_group(1);
+        const int64_t batch = item_ct1.get_group(0) / IC;
+        const int64_t ic = item_ct1.get_group(0) % IC;
+
+        const int64_t iiw = (ix * s0) + (kx * d0) - p0;
+        const int64_t iih = (oh * s1) + (ky * d1) - p1;
+
+        const int64_t offset_dst = (((batch * OH + oh) * OW + ix) * CHW) + (ic * (KW * KH) + ky * KW + kx);
+
+        const int64_t offset_src_base = (ic * offset_delta) + (batch * batch_offset);
+        const int64_t offset_src = offset_src_base + (iih * IW) + iiw;
+
+        const bool out_of_bounds = (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW);
+        const float src_val = out_of_bounds ? 0.0f : x[offset_src];
+
+        if constexpr (std::is_same_v<T, sycl::half>) {
+            dst[offset_dst] = sycl::half(src_val);
+        } else if constexpr (std::is_same_v<T, float>) {
+            dst[offset_dst] = src_val;
         }
     }
 }
 
 template <typename T>
-static void im2col_sycl(
-    const float *x, T *dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW,
-    int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta,
-    int s0, int s1, int p0, int p1, int d0, int d1,
-    queue_ptr stream) {
+static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW,
+                                 int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta,
+                                 int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) {
     const int64_t parallel_elements = OW * KW * KH;
-    const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
+    const int64_t num_blocks        = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
 
     // decrease global range when it exceeds the max int
     int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE);
+
     sycl::range<3> block_nums(batch * IC, OH, num_blocks);
     sycl::range<3> local_range(1, 1, local_size);
 
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * local_range, local_range),
-            [=](sycl::nd_item<3> item_ct1) {
-                im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH,
-                              parallel_elements, (IC * KH * KW), s0, s1, p0,
-                              p1, d0, d1, item_ct1);
-            });
+    const int64_t CHW = IC * KH * KW;
+
+    stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) {
+        im2col_kernel<T>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1,
+                         p0, p1, d0, d1, item_ct1);
+    });
+}
+
+static void im2col_sycl_f16(const float * x, sycl::half * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH,
+                            int64_t KW, int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset,
+                            int64_t offset_delta, int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) {
+    if (!stream->get_device().has(sycl::aspect::fp16)) {
+        throw sycl::exception(sycl::make_error_code(sycl::errc::kernel_not_supported),
+                              "Device does not support half precision (fp16) operations!");
     }
+    im2col_sycl_internal<sycl::half>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0,
+                                     p1, d0, d1, stream);
 }
 
-void ggml_sycl_op_im2col(
-    ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-    ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
-    const queue_ptr &main_stream) {
+static void im2col_sycl_f32(const float * x, float * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW,
+                            int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, int s0,
+                            int s1, int p0, int p1, int d0, int d1, queue_ptr stream) {
+    im2col_sycl_internal<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1,
+                                d0, d1, stream);
+}
+
+void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
 
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+    const int32_t s0 = ((const int32_t *) (dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *) (dst->op_params))[1];
+    const int32_t p0 = ((const int32_t *) (dst->op_params))[2];
+    const int32_t p1 = ((const int32_t *) (dst->op_params))[3];
+    const int32_t d0 = ((const int32_t *) (dst->op_params))[4];
+    const int32_t d1 = ((const int32_t *) (dst->op_params))[5];
 
-    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+    const bool is_2D = ((const int32_t *) (dst->op_params))[6] == 1;
 
     const int64_t IC = src1->ne[is_2D ? 2 : 1];
     const int64_t IH = is_2D ? src1->ne[1] : 1;
-    const int64_t IW = src1->ne[0];
+    const int64_t IW =         src1->ne[0];
 
     const int64_t KH = is_2D ? src0->ne[1] : 1;
-    const int64_t KW = src0->ne[0];
+    const int64_t KW =         src0->ne[0];
 
     const int64_t OH = is_2D ? dst->ne[2] : 1;
-    const int64_t OW = dst->ne[1];
+    const int64_t OW =         dst->ne[1];
 
-    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
-    const int64_t batch = src1->ne[3];
-    const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
+    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / sizeof(float);
+    const int64_t batch = src1->ne[is_2D ? 3 : 2];
+    const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / sizeof(float);
+
+    queue_ptr stream = ctx.stream();
 
     if (dst->type == GGML_TYPE_F16) {
-        im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+        im2col_sycl_f16((const float *) src1->data, (sycl::half *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch,
+                        batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
     } else {
-        im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+        im2col_sycl_f32((const float *) src1->data, (float *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch,
+                        batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
     }
-
-    GGML_UNUSED(src0);
-    GGML_UNUSED(src0_dd);
-    GGML_UNUSED(ctx);
 }
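Note on the rewritten kernel above: the destination/source offset arithmetic is unchanged in spirit, only regrouped and made type-generic. The stand-alone sketch below replays that arithmetic on the host in plain C++ (no SYCL) for a single element; the geometry and the i/oh/ic/batch values picked here are arbitrary illustration, not values taken from the package.

// Host-side sketch of the offset arithmetic used by im2col_kernel above (illustrative only).
#include <cstdint>
#include <cstdio>

int main() {
    // Example 2D im2col geometry: 3x3 kernel over an 8x8 input, stride 1, no padding, no dilation.
    const int64_t IC = 3, IW = 8, IH = 8, KW = 3, KH = 3, OW = 6, OH = 6;
    const int     s0 = 1, s1 = 1, p0 = 0, p1 = 0, d0 = 1, d1 = 1;
    const int64_t CHW = IC * KH * KW;

    // One element handled by one work-item: a flattened index i, an output row oh, a channel ic, a batch.
    const int64_t i = 10, oh = 2, ic = 1, batch = 0;

    const int64_t ksize = OW * (KH > 1 ? KW : 1);
    const int64_t kx = i / ksize;
    const int64_t kd = kx * ksize;
    const int64_t ky = (i - kd) / OW;
    const int64_t ix = i % OW;

    const int64_t iiw = (ix * s0) + (kx * d0) - p0;   // input column sampled by this element
    const int64_t iih = (oh * s1) + (ky * d1) - p1;   // input row sampled by this element

    const int64_t offset_dst = (((batch * OH + oh) * OW + ix) * CHW) + (ic * (KW * KH) + ky * KW + kx);
    const bool    oob        = (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW);

    printf("dst offset = %lld, reads input (row=%lld, col=%lld)%s\n",
           (long long) offset_dst, (long long) iih, (long long) iiw,
           oob ? " [out of bounds -> 0.0f]" : "");
    return 0;
}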
package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp
@@ -16,8 +16,6 @@
 #include "common.hpp"
 
 void ggml_sycl_op_im2col(
-    ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-    ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
-    const queue_ptr &main_stream);
+    ggml_backend_sycl_context & ctx, ggml_tensor *dst);
 
 #endif // GGML_SYCL_IM2COL_HPP
package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp
@@ -367,7 +367,7 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
             sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                               block_dims),
             [=](sycl::nd_item<3> item_ct1)
-                [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                     l2_norm_f32(x, dst, ncols, eps, item_ct1,
                                 nullptr, WARP_SIZE);
                 });
@@ -389,7 +389,7 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
             sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                               block_dims),
             [=](sycl::nd_item<3> item_ct1)
-                [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                     l2_norm_f32(x, dst, ncols, eps, item_ct1,
                                 get_pointer(s_sum_acc_ct1), work_group_size);
                 });
@@ -397,90 +397,78 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
     }
 }
 
-void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
-                       ggml_tensor* dst, const float* src0_dd,
-                       const float* src1_dd, float* dst_dd,
-                       const queue_ptr& main_stream) {
+void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t ne00 = dst->src[0]->ne[0];
+    const int64_t nrows = ggml_nrows(dst->src[0]);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float * dst_dd = static_cast<float *>(dst->data);
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
     norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
-
-    (void)src1;
-    (void)dst;
-    (void)src1_dd;
 }
 
-void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
-                             const ggml_tensor* src1, ggml_tensor* dst,
-                             const float* src0_dd, const float* src1_dd,
-                             float* dst_dd,
-                             const queue_ptr& main_stream) {
+void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     int num_groups = dst->op_params[0];
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float * dst_dd = static_cast<float *>(dst->data);
 
     float eps;
     memcpy(&eps, dst->op_params + 1, sizeof(float));
 
-    int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
-
-    (void)src1;
-    (void)dst;
-    (void)src1_dd;
-    GGML_UNUSED(ctx);
+    int group_size = dst->src[0]->ne[0] * dst->src[0]->ne[1] * ((dst->src[0]->ne[2] + num_groups - 1) / num_groups);
+    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, dst->src[0]->ne[0] * dst->src[0]->ne[1] * dst->src[0]->ne[2], main_stream, ctx.device);
 }
 
-void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
-                           const ggml_tensor* src1, ggml_tensor* dst,
-                           const float* src0_dd, const float* src1_dd,
-                           float* dst_dd,
-                           const queue_ptr& main_stream) {
+void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t ne00 = dst->src[0]->ne[0];
+    const int64_t nrows = ggml_nrows(dst->src[0]);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float * dst_dd = static_cast<float *>(dst->data);
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
     rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
-
-    (void)src1;
-    (void)dst;
-    (void)src1_dd;
 }
 
-void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
-                          const ggml_tensor* src1, ggml_tensor* dst,
-                          const float* src0_dd, const float* src1_dd,
-                          float* dst_dd,
-                          const queue_ptr& main_stream) {
+void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+
+    const int64_t ne00 = dst->src[0]->ne[0];
+    const int64_t nrows = ggml_nrows(dst->src[0]);
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float * dst_dd = static_cast<float *>(dst->data);
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
     l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
 
-    (void)src1;
-    (void)dst;
-    (void)src1_dd;
 }
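All four entry points above now read their input from dst->src[0], their epsilon from dst->op_params, and their queue from ctx.stream(); the math they launch is untouched. As a reminder of what, for example, rms_norm_f32_sycl produces per row, here is a minimal host-side reference in plain C++; the helper name and sample values are illustrative, and the formula is the standard ggml RMS-norm definition (y = x / sqrt(mean(x^2) + eps)).

// Host-side reference for RMS normalization of one row of length n (illustrative only).
#include <cmath>
#include <cstdio>
#include <vector>

static void rms_norm_row_ref(const float * x, float * y, int n, float eps) {
    float sum_sq = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum_sq += x[i] * x[i];
    }
    const float mean  = sum_sq / n;                 // mean of squares over the row
    const float scale = 1.0f / std::sqrt(mean + eps);
    for (int i = 0; i < n; ++i) {
        y[i] = x[i] * scale;
    }
}

int main() {
    std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
    std::vector<float> y(x.size());
    rms_norm_row_ref(x.data(), y.data(), (int) x.size(), 1e-6f);
    for (float v : y) {
        printf("%f ", v);
    }
    printf("\n");
    return 0;
}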
package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp
@@ -15,27 +15,12 @@
 
 #include "common.hpp"
 
-void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
-                       ggml_tensor* dst, const float* src0_dd,
-                       const float* src1_dd, float* dst_dd,
-                       const queue_ptr& main_stream);
-
-void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
-                           const ggml_tensor* src1, ggml_tensor* dst,
-                           const float* src0_dd, const float* src1_dd,
-                           float* dst_dd,
-                           const queue_ptr& main_stream);
-
-void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
-                             const ggml_tensor* src1, ggml_tensor* dst,
-                             const float* src0_dd, const float* src1_dd,
-                             float* dst_dd,
-                             const queue_ptr& main_stream);
-
-void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
-                          const ggml_tensor* src1, ggml_tensor* dst,
-                          const float* src0_dd, const float* src1_dd,
-                          float* dst_dd,
-                          const queue_ptr& main_stream);
+void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
+
+void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
+
+void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
+
+void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
 
 #endif // GGML_SYCL_NORM_HPP
package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp
@@ -1,8 +1,5 @@
-#include <sycl/sycl.hpp>
-#include <oneapi/mkl.hpp>
 #include "outprod.hpp"
 
-
 void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
     const ggml_tensor *src0 = dst->src[0];
     const ggml_tensor *src1 = dst->src[1];
@@ -34,20 +31,13 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
 
     // Handle transposition of src1
     const bool src1_T = ggml_is_transposed(src1);
-    const oneapi::mkl::transpose src1_op =
-        src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
+    const oneapi::math::transpose src1_op = src1_T ? oneapi::math::transpose::nontrans : oneapi::math::transpose::trans;
     const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
 
     try {
-        // Perform matrix multiplication using oneMKL GEMM
-#ifdef GGML_SYCL_NVIDIA
-        oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
-                                              oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
-                                              ne00, src1_d, ldb, beta, dst_d, ne0);
-#else
-        oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
-                                              src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
-#endif
+        // Perform matrix multiplication using oneMath GEMM
+        oneapi::math::blas::column_major::gemm(get_onemath_backend(*stream), oneapi::math::transpose::nontrans, src1_op,
+                                               ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
     }
     catch (sycl::exception const& exc) {
         std::cerr << exc.what() << std::endl;
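The hunk above only swaps the oneMKL namespace and the per-vendor backend selection for oneMath's get_onemath_backend; the GEMM call itself keeps the classic column-major BLAS semantics C = alpha * op(A) * op(B) + beta * C. To read the argument order in that call against something concrete, here is a tiny host-side reference of the non-transposed column-major case; it is illustrative only and not part of the package.

// Reference semantics of column-major GEMM: C = alpha*op(A)*op(B) + beta*C (non-transposed case only).
#include <cstdio>
#include <vector>

// op(A) is m x k, op(B) is k x n, C is m x n, all column-major with leading dimensions lda/ldb/ldc.
static void gemm_colmajor_nn(int m, int n, int k, float alpha,
                             const float * A, int lda,
                             const float * B, int ldb,
                             float beta, float * C, int ldc) {
    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < m; ++i) {
            float acc = 0.0f;
            for (int p = 0; p < k; ++p) {
                acc += A[i + p * lda] * B[p + j * ldb];   // column-major: element (row, col) = ptr[row + col*ld]
            }
            C[i + j * ldc] = alpha * acc + beta * C[i + j * ldc];
        }
    }
}

int main() {
    // 2x2 example: C = A*B with alpha = 1, beta = 0; expected result [[19, 22], [43, 50]].
    std::vector<float> A = {1, 3, 2, 4};   // column-major [[1, 2], [3, 4]]
    std::vector<float> B = {5, 7, 6, 8};   // column-major [[5, 6], [7, 8]]
    std::vector<float> C(4, 0.0f);
    gemm_colmajor_nn(2, 2, 2, 1.0f, A.data(), 2, B.data(), 2, 0.0f, C.data(), 2);
    printf("C = [[%g, %g], [%g, %g]]\n", C[0], C[2], C[1], C[3]);
    return 0;
}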