@fugood/llama.node 0.3.15 → 0.3.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. package/CMakeLists.txt +3 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +5 -0
  19. package/package.json +1 -1
  20. package/src/LlamaCompletionWorker.cpp +8 -0
  21. package/src/LlamaCompletionWorker.h +1 -0
  22. package/src/LlamaContext.cpp +3 -2
  23. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +70 -27
  25. package/src/llama.cpp/.github/workflows/docker.yml +6 -6
  26. package/src/llama.cpp/.github/workflows/server.yml +7 -11
  27. package/src/llama.cpp/CMakeLists.txt +23 -1
  28. package/src/llama.cpp/common/CMakeLists.txt +6 -3
  29. package/src/llama.cpp/common/arg.cpp +809 -105
  30. package/src/llama.cpp/common/arg.h +9 -0
  31. package/src/llama.cpp/common/chat.cpp +1 -1
  32. package/src/llama.cpp/common/common.cpp +31 -521
  33. package/src/llama.cpp/common/common.h +17 -36
  34. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  35. package/src/llama.cpp/common/llguidance.cpp +30 -47
  36. package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
  37. package/src/llama.cpp/common/minja/minja.hpp +119 -93
  38. package/src/llama.cpp/common/sampling.cpp +3 -0
  39. package/src/llama.cpp/docs/build.md +122 -7
  40. package/src/llama.cpp/examples/CMakeLists.txt +0 -9
  41. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
  43. package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
  44. package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
  45. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
  46. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
  48. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
  50. package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
  51. package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
  52. package/src/llama.cpp/examples/llava/clip.h +39 -22
  53. package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
  54. package/src/llama.cpp/examples/llava/llava.cpp +64 -52
  55. package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
  56. package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
  57. package/src/llama.cpp/examples/llava/mtmd.h +168 -0
  58. package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
  59. package/src/llama.cpp/examples/main/main.cpp +16 -5
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
  64. package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
  65. package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
  66. package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
  67. package/src/llama.cpp/examples/run/run.cpp +14 -28
  68. package/src/llama.cpp/examples/server/httplib.h +313 -247
  69. package/src/llama.cpp/examples/server/server.cpp +243 -139
  70. package/src/llama.cpp/examples/server/utils.hpp +51 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  74. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  75. package/src/llama.cpp/examples/tts/tts.cpp +14 -9
  76. package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
  77. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  78. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  79. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  80. package/src/llama.cpp/ggml/include/ggml.h +66 -99
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +15 -8
  82. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  83. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  84. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  85. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  87. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  88. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  89. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  90. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
  91. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  93. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2413 -228
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1004 -13516
  99. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
  101. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
  102. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
  103. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
  105. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  106. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  107. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  108. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  109. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
  110. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  111. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  112. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  114. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  115. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
  116. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  117. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
  118. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
  119. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +127 -33
  120. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  124. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +29 -293
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  130. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  131. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +210 -286
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  133. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  134. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  135. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  136. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  137. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  138. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  139. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
  140. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +692 -126
  141. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
  142. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +21 -10
  143. package/src/llama.cpp/ggml/src/ggml.c +141 -245
  144. package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
  145. package/src/llama.cpp/include/llama.h +30 -11
  146. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  147. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  148. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  149. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  150. package/src/llama.cpp/requirements/requirements-all.txt +2 -0
  151. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  152. package/src/llama.cpp/src/CMakeLists.txt +3 -2
  153. package/src/llama.cpp/src/llama-adapter.cpp +37 -1
  154. package/src/llama.cpp/src/llama-arch.cpp +161 -17
  155. package/src/llama.cpp/src/llama-arch.h +16 -0
  156. package/src/llama.cpp/src/llama-chat.cpp +82 -17
  157. package/src/llama.cpp/src/llama-chat.h +6 -2
  158. package/src/llama.cpp/src/llama-context.cpp +108 -92
  159. package/src/llama.cpp/src/llama-context.h +1 -2
  160. package/src/llama.cpp/src/llama-graph.cpp +189 -119
  161. package/src/llama.cpp/src/llama-graph.h +26 -6
  162. package/src/llama.cpp/src/llama-hparams.h +13 -0
  163. package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
  164. package/src/llama.cpp/src/llama-kv-cache.h +41 -115
  165. package/src/llama.cpp/src/llama-memory.h +1 -1
  166. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  167. package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
  168. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  169. package/src/llama.cpp/src/llama-model.cpp +1544 -291
  170. package/src/llama.cpp/src/llama-model.h +13 -1
  171. package/src/llama.cpp/src/llama-quant.cpp +29 -8
  172. package/src/llama.cpp/src/llama-sampling.cpp +7 -1
  173. package/src/llama.cpp/src/llama-vocab.cpp +44 -6
  174. package/src/llama.cpp/src/llama.cpp +1 -1
  175. package/src/llama.cpp/tests/CMakeLists.txt +43 -30
  176. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  177. package/src/llama.cpp/tests/test-backend-ops.cpp +139 -57
  178. package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
  179. package/src/llama.cpp/tests/test-chat.cpp +12 -2
  180. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  181. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  182. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  183. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  184. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  185. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  186. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  187. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  188. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  189. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  190. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  191. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  192. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  193. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  194. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  195. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  196. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  197. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  198. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  199. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  200. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  201. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  202. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  203. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp
@@ -1,9 +1,15 @@
  #include "rope.hpp"
+ #include "ggml-sycl/common.hpp"
+ #include "ggml.h"
 
  struct rope_corr_dims {
      float v[2];
  };
 
+ struct mrope_sections {
+     int v[4];
+ };
+
  static float rope_yarn_ramp(const float low, const float high, const int i0) {
      const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low);
      return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
@@ -28,23 +34,21 @@ static void rope_yarn(
      *sin_theta = sycl::sin(theta) * mscale;
  }
 
- template<typename T, bool has_ff>
- static void rope_norm(
-     const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-     float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
-     const sycl::nd_item<3> &item_ct1) {
-     const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                         item_ct1.get_local_id(1));
+ template <typename T, bool has_ff>
+ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+                       const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
+                       const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
+                       const sycl::nd_item<3> & item_ct1) {
+     const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
 
      if (i0 >= ne0) {
          return;
      }
 
-     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                     item_ct1.get_local_id(2);
+     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
 
      if (i0 >= n_dims) {
-         const int i = row*ne0 + i0;
+         const int i = row * ne0 + i0;
 
          dst[i + 0] = x[i + 0];
          dst[i + 1] = x[i + 1];
@@ -52,42 +56,43 @@ static void rope_norm(
          return;
      }
 
-     const int i = row*ne0 + i0;
-     const int i2 = row/p_delta_rows;
+     const int row0 = row % ne1;
+     const int channel0 = row / ne1;
 
-     const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
+     const int i = row * ne0 + i0;
+     const int i2 = channel0 * s2 + row0 * s1 + i0;
 
-     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+     const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
+
+     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
 
      float cos_theta;
      float sin_theta;
 
-     rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+     rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
-     const float x0 = x[i + 0];
-     const float x1 = x[i + 1];
+     const float x0 = x[i2 + 0];
+     const float x1 = x[i2 + 1];
 
-     dst[i + 0] = x0*cos_theta - x1*sin_theta;
-     dst[i + 1] = x0*sin_theta + x1*cos_theta;
+     dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
+     dst[i + 1] = x0 * sin_theta + x1 * cos_theta;
  }
 
- template<typename T, bool has_ff>
- static void rope_neox(
-     const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-     float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
-     const sycl::nd_item<3> &item_ct1) {
-     const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                         item_ct1.get_local_id(1));
+ template <typename T, bool has_ff>
+ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+                       const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+                       const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
+                       const sycl::nd_item<3> & item_ct1) {
+     const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
 
      if (i0 >= ne0) {
          return;
      }
 
-     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                     item_ct1.get_local_id(2);
+     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
 
      if (i0 >= n_dims) {
-         const int i = row*ne0 + i0;
+         const int i = row * ne0 + i0;
 
          dst[i + 0] = x[i + 0];
          dst[i + 1] = x[i + 1];
@@ -95,38 +100,83 @@ static void rope_neox(
          return;
      }
 
-     const int i = row*ne0 + i0/2;
-     const int i2 = row/p_delta_rows;
+     const int row0 = row % ne1;
+     const int channel0 = row / ne1;
+
+     const int i = row * ne0 + i0 / 2;
+     const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
 
-     const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
+     const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
 
-     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
 
      float cos_theta;
      float sin_theta;
 
-     rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+     rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+     const float x0 = x[i2 + 0];
+     const float x1 = x[i2 + n_dims / 2];
+
+     dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
+     dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
+ }
+
+ template <typename T, bool has_ff>
+ static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
+                         const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
+                         const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
+                         const float theta_scale, const float * freq_factors, const mrope_sections sections,
+                         const sycl::nd_item<3> & item_ct1) {
+     // get index pos
+     const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
+     if (i0 >= ne0) {
+         return;
+     }
+     const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
+     const int row_x = row_dst % ne1;
+     const int channel_x = row_dst / ne1;
+     const int idst = (row_dst * ne0) + (i0 / 2);
+     const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
+
+     const int sect_dims = sections.v[0] + sections.v[1];
+     const int sector = (i0 / 2) % sect_dims;
+
+     float theta_base = 0.0f;
+     if (sector < sections.v[0]) {
+         const int p = sector;
+         theta_base = pos[channel_x] * sycl::pow(theta_scale, (float) p);
+     } else {
+         // Simplified from CUDA backend code: if (sector >= sections.v[0] && sector < sec_w) which is just sector >= sections.v[0]
+         const int p = sector - sections.v[0];
+         theta_base = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p);
+     }
 
-     const float x0 = x[i + 0];
-     const float x1 = x[i + n_dims/2];
+     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
+     float cos_theta;
+     float sin_theta;
+     rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+     const float x0 = x[ix + 0];
+     const float x1 = x[ix + n_dims];
 
-     dst[i + 0] = x0*cos_theta - x1*sin_theta;
-     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+     // store results in dst
+     dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
+     dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta;
  }
 
  template <typename T>
- static void rope_norm_sycl(
-     const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
-     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
+ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
+                            const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
+                            const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
+                            const float * freq_factors, queue_ptr stream) {
      GGML_ASSERT(ne0 % 2 == 0);
      const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-     const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
+     const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
      const sycl::range<3> block_nums(1, num_blocks_x, nr);
 
-     const float theta_scale = powf(freq_base, -2.0f/n_dims);
+     const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
-     dpct::has_capability_or_fail(stream->get_device(),
-                                  {sycl::aspect::fp16});
+     dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
      if (freq_factors == nullptr) {
          /*
@@ -134,82 +184,102 @@ static void rope_norm_sycl(
          the limit. To get the device limit, query
          info::device::max_work_group_size. Adjust the work-group size if needed.
          */
-         stream->parallel_for(
-             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-             [=](sycl::nd_item<3> item_ct1) {
-                 rope_norm<T, false>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
-                                     ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
-                                     item_ct1);
-             });
+         stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+             rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                                 theta_scale, freq_factors, item_ct1);
+         });
      } else {
          /*
          DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
          the limit. To get the device limit, query
          info::device::max_work_group_size. Adjust the work-group size if needed.
          */
-         stream->parallel_for(
-             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-             [=](sycl::nd_item<3> item_ct1) {
-                 rope_norm<T, true>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
-                                    ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
-                                    item_ct1);
-             });
+         stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+             rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                                theta_scale, freq_factors, item_ct1);
+         });
      }
  }
 
  template <typename T>
- static void rope_neox_sycl(
-     const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
-     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
+ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
+                            const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
+                            const float freq_base, const float ext_factor, const float attn_factor,
+                            const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
      GGML_ASSERT(ne0 % 2 == 0);
      const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-     const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
+     const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
      const sycl::range<3> block_nums(1, num_blocks_x, nr);
 
-     const float theta_scale = powf(freq_base, -2.0f/n_dims);
+     const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
-     dpct::has_capability_or_fail(stream->get_device(),
-                                  {sycl::aspect::fp16});
+     dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
      if (freq_factors == nullptr) {
-         stream->parallel_for(
-             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-             [=](sycl::nd_item<3> item_ct1) {
-                 rope_neox<T, false>(x, dst, ne0, n_dims, pos, freq_scale,
-                                     p_delta_rows, ext_factor, attn_factor,
-                                     corr_dims, theta_scale, freq_factors,
-                                     item_ct1);
-             });
+         stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+             rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                                 theta_scale, freq_factors, item_ct1);
+         });
      } else {
-         stream->parallel_for(
-             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-             [=](sycl::nd_item<3> item_ct1) {
-                 rope_neox<T, true>(x, dst, ne0, n_dims, pos, freq_scale,
-                                    p_delta_rows, ext_factor, attn_factor,
-                                    corr_dims, theta_scale, freq_factors,
-                                    item_ct1);
-             });
+         stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+             rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                                theta_scale, freq_factors, item_ct1);
+         });
      }
  }
 
- void ggml_sycl_op_rope(
-     ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
-     const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) {
-     const ggml_tensor * src2 = dst->src[2];
+ // rope vision
+ template <typename T>
+ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
+                              const size_t s2, const int n_dims, const int nr, const int32_t * pos,
+                              const float freq_scale, const float freq_base, const float ext_factor,
+                              const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
+                              const mrope_sections sections, queue_ptr stream) {
+     GGML_ASSERT(ne0 % 2 == 0);
+     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
+     const int n_blocks_y = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
+     const sycl::range<3> grid_dims(1, n_blocks_y, nr);
+     const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
+
+     const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
+     // Add FP16 capability check if T could be sycl::half
+     if constexpr (std::is_same_v<T, sycl::half>) {
+         dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
+     }
+     // launch kernel
+     if (freq_factors == nullptr) {
+         stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
+             rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
+                                   corr_dims, theta_scale, freq_factors, sections, item_ct1);
+         });
+     } else {
+         stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
+             rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
+                                  corr_dims, theta_scale, freq_factors, sections, item_ct1);
+         });
+     }
+ }
+
+ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
 
-     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
      GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-     GGML_ASSERT(src0->type == dst->type);
+     GGML_ASSERT(dst->src[0]->type == dst->type);
+     const int64_t ne00 = dst->src[0]->ne[0]; // head dims
+     const int64_t ne01 = dst->src[0]->ne[1]; // num heads
+     const int64_t ne02 = dst->src[0]->ne[2]; // num heads
+     const int64_t nr = ggml_nrows(dst->src[0]);
+
+     const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type);
+     const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type);
 
-     const int64_t ne00 = src0->ne[0];
-     const int64_t ne01 = src0->ne[1];
-     const int64_t nr = ggml_nrows(src0);
 
      //const int n_past = ((int32_t *) dst->op_params)[0];
      const int n_dims = ((int32_t *) dst->op_params)[1];
      const int mode = ((int32_t *) dst->op_params)[2];
      //const int n_ctx = ((int32_t *) dst->op_params)[3];
      const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+     mrope_sections sections;
 
      // RoPE alteration for extended context
      float freq_base;
@@ -225,52 +295,68 @@ void ggml_sycl_op_rope(
      memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
      memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
      memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+     memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
 
      const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
-     const int32_t * pos = (const int32_t *) src1_dd;
+     const int32_t * pos = (const int32_t *) dst->src[1]->data;
 
      const float * freq_factors = nullptr;
-     if (src2 != nullptr) {
-         freq_factors = (const float *) src2->data;
+     if (dst->src[2] != nullptr) {
+         freq_factors = (const float *) dst->src[2]->data;
      }
 
      rope_corr_dims corr_dims;
      ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
 
+     dpct::queue_ptr main_stream = ctx.stream();
+     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+
      // compute
      if (is_neox) {
-         if (src0->type == GGML_TYPE_F32) {
-             rope_neox_sycl(
-                 (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                 attn_factor, corr_dims, freq_factors, main_stream
-             );
-         } else if (src0->type == GGML_TYPE_F16) {
-             rope_neox_sycl(
-                 (const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                 attn_factor, corr_dims, freq_factors, main_stream
-             );
+         GGML_SYCL_DEBUG("%s: neox path\n", __func__);
+         if (dst->src[0]->type == GGML_TYPE_F32) {
+             rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
+                            pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
+         } else if (dst->src[0]->type == GGML_TYPE_F16) {
+             rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
+                            n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
+                            main_stream);
          } else {
              GGML_ABORT("fatal error");
          }
+     } else if (is_vision) {
+         GGML_SYCL_DEBUG("%s: vision path\n", __func__);
+         if (dst->src[0]->type == GGML_TYPE_F16) {
+             rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01,
+                              s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                              freq_factors, sections, main_stream);
+         } else if (dst->src[0]->type == GGML_TYPE_F32) {
+             rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
+                              nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
+                              main_stream);
+         } else {
+             GGML_ABORT("Fatal error: Tensor type unsupported!");
+         }
      } else {
-         if (src0->type == GGML_TYPE_F32) {
-             rope_norm_sycl(
-                 (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                 attn_factor, corr_dims, freq_factors, main_stream
-             );
-         } else if (src0->type == GGML_TYPE_F16) {
-             rope_norm_sycl(
-                 (const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                 attn_factor, corr_dims, freq_factors, main_stream
-             );
+         GGML_SYCL_DEBUG("%s: norm path\n", __func__);
+         if (dst->src[0]->type == GGML_TYPE_F32) {
+             rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
+                            pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
+         } else if (dst->src[0]->type == GGML_TYPE_F16) {
+             rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
+                            n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
+                            main_stream);
          } else {
              GGML_ABORT("fatal error");
          }
      }
+ }
 
-     GGML_UNUSED(src1);
-     GGML_UNUSED(dst);
-     GGML_UNUSED(src1_dd);
-     GGML_UNUSED(ctx);
+ void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+     GGML_SYCL_DEBUG("call %s\n", __func__);
+     ggml_sycl_op_rope(ctx, dst);
+     GGML_SYCL_DEBUG("call %s done\n", __func__);
  }
+
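Note on the rope.cpp change above: the rope_norm/rope_neox kernels drop their contiguous-row assumption (the old p_delta_rows argument) in favor of explicit element strides s1/s2 derived from the tensor's nb[] byte strides, and a new rope_vision kernel adds the multi-section (mrope) RoPE path used by vision models, where the half-dimension index is mapped into sections.v[0] + sections.v[1] sectors and each sector draws its position from one of two position streams. As a reading aid only, here is a standalone scalar sketch of that sector-to-theta mapping; the program, the vision_theta helper, and the sample values are hypothetical, while the arithmetic mirrors the kernel above:

    // theta_sketch.cpp -- illustrative only; the SYCL kernel above is authoritative.
    #include <cmath>
    #include <cstdio>

    struct mrope_sections { int v[4]; };

    // pos holds ne2 positions per stream, stream after stream, matching the
    // kernel's pos[channel_x] / pos[channel_x + ne2] accesses.
    static float vision_theta(int i0, int channel_x, int ne2, const int * pos,
                              float theta_scale, const mrope_sections & sections) {
        const int sect_dims = sections.v[0] + sections.v[1];
        const int sector    = (i0 / 2) % sect_dims;
        if (sector < sections.v[0]) {
            // first section: position stream 0, exponent runs over its sectors
            return pos[channel_x] * std::pow(theta_scale, (float) sector);
        }
        // second section: position stream 1, exponent rebased to start at 0
        const int p = sector - sections.v[0];
        return pos[channel_x + ne2] * std::pow(theta_scale, (float) p);
    }

    int main() {
        const mrope_sections sections = { { 8, 8, 0, 0 } };  // hypothetical 8+8 split
        const int pos[2] = { 3, 5 };  // stream 0 and stream 1 positions for channel 0
        const float theta_scale = std::pow(10000.0f, -2.0f / 32.0f);
        for (int i0 = 0; i0 < 32; i0 += 2) {
            std::printf("i0=%2d theta=%.4f\n", i0,
                        vision_theta(i0, /*channel_x=*/0, /*ne2=*/1, pos, theta_scale, sections));
        }
        return 0;
    }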
package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp
@@ -15,8 +15,6 @@
 
  #include "common.hpp"
 
- void ggml_sycl_op_rope(
-     ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
-     const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream);
+ void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
 
  #endif // GGML_SYCL_ROPE_HPP
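This header change tracks the refactor above: the exported entry point shrinks from the pointer-threading ggml_sycl_op_rope(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream) to the dst-centric ggml_sycl_rope(ctx, dst). The op now recovers its inputs itself (dst->src[0] activations, dst->src[1] positions, dst->src[2] optional frequency factors) and takes its queue from ctx.stream(), so a call site goes, roughly, from

    ggml_sycl_op_rope(ctx, dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);

to

    ggml_sycl_rope(ctx, dst);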
package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt
@@ -32,8 +32,10 @@ if (Vulkan_FOUND)
 
      if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*")
          message(STATUS "GL_KHR_cooperative_matrix not supported by glslc")
+         set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF)
      else()
          message(STATUS "GL_KHR_cooperative_matrix supported by glslc")
+         set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON)
          add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
      endif()
 
@@ -46,11 +48,29 @@ if (Vulkan_FOUND)
 
      if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*")
          message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc")
+         set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT OFF)
      else()
          message(STATUS "GL_NV_cooperative_matrix2 supported by glslc")
+         set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT ON)
          add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
      endif()
 
+     # Compile a test shader to determine whether GL_EXT_integer_dot_product is supported.
+     # If it's not, there will be an error to stderr.
+     # If it's supported, set a define to indicate that we should compile those shaders
+     execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
+                     OUTPUT_VARIABLE glslc_output
+                     ERROR_VARIABLE glslc_error)
+
+     if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*")
+         message(STATUS "GL_EXT_integer_dot_product not supported by glslc")
+         set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT OFF)
+     else()
+         message(STATUS "GL_EXT_integer_dot_product supported by glslc")
+         set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT ON)
+         add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+     endif()
+
      target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
      target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
 
@@ -119,6 +139,9 @@ if (Vulkan_FOUND)
          SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
          CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}
                     -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
+                    -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT}
+                    -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT}
+                    -DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT}
          BUILD_COMMAND ${CMAKE_COMMAND} --build .
          INSTALL_COMMAND ${CMAKE_COMMAND} --install .
          INSTALL_DIR ${CMAKE_BINARY_DIR}
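Beyond adding the GL_EXT_integer_dot_product probe, the change above records each glslc probe result in a variable (the ..._GLSLC_SUPPORT ON/OFF settings) and forwards all three to the vulkan-shaders ExternalProject, so a cross-compiled shader generator sees the same support flags as the host probe. Because add_compile_definitions() also defines these names as preprocessor macros, backend sources can gate on them; a minimal sketch of that pattern, assuming (as the diff suggests but does not show here) that ggml-vulkan.cpp and vulkan-shaders-gen.cpp check the macro:

    #ifdef GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
        // glslc can compile GL_EXT_integer_dot_product shaders:
        // build/register the integer-dot shader variants here.
    #else
        // fall back to variants that do not require the extension.
    #endif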