@fugood/llama.node 0.3.13 → 0.3.15

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.ts +1 -1
  18. package/package.json +1 -1
  19. package/src/LlamaContext.cpp +98 -76
  20. package/src/LlamaContext.h +1 -1
  21. package/src/common.hpp +1 -2
  22. package/src/llama.cpp/.github/workflows/build.yml +89 -10
  23. package/src/llama.cpp/.github/workflows/server.yml +2 -0
  24. package/src/llama.cpp/CMakeLists.txt +9 -1
  25. package/src/llama.cpp/cmake/common.cmake +2 -0
  26. package/src/llama.cpp/common/CMakeLists.txt +3 -3
  27. package/src/llama.cpp/common/arg.cpp +132 -13
  28. package/src/llama.cpp/common/chat.cpp +960 -266
  29. package/src/llama.cpp/common/chat.h +135 -0
  30. package/src/llama.cpp/common/common.cpp +33 -174
  31. package/src/llama.cpp/common/common.h +27 -67
  32. package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
  33. package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
  34. package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +37 -5
  35. package/src/llama.cpp/common/ngram-cache.cpp +1 -0
  36. package/src/llama.cpp/common/sampling.cpp +45 -7
  37. package/src/llama.cpp/common/speculative.cpp +10 -9
  38. package/src/llama.cpp/common/speculative.h +1 -1
  39. package/src/llama.cpp/docs/build.md +45 -7
  40. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
  41. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +4 -2
  42. package/src/llama.cpp/examples/embedding/embedding.cpp +2 -1
  43. package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
  44. package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
  45. package/src/llama.cpp/examples/imatrix/imatrix.cpp +3 -4
  46. package/src/llama.cpp/examples/infill/infill.cpp +2 -2
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
  48. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +5 -5
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  50. package/src/llama.cpp/examples/llava/clip.cpp +373 -107
  51. package/src/llama.cpp/examples/llava/clip.h +19 -3
  52. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
  53. package/src/llama.cpp/examples/llava/llava.cpp +4 -2
  54. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
  55. package/src/llama.cpp/examples/lookahead/lookahead.cpp +7 -6
  56. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  57. package/src/llama.cpp/examples/main/main.cpp +79 -34
  58. package/src/llama.cpp/examples/parallel/parallel.cpp +6 -5
  59. package/src/llama.cpp/examples/passkey/passkey.cpp +15 -14
  60. package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
  61. package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
  62. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
  63. package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
  64. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
  65. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
  66. package/src/llama.cpp/examples/run/run.cpp +196 -108
  67. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
  68. package/src/llama.cpp/examples/server/server.cpp +113 -101
  69. package/src/llama.cpp/examples/server/utils.hpp +94 -105
  70. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
  74. package/src/llama.cpp/examples/tts/tts.cpp +263 -151
  75. package/src/llama.cpp/ggml/CMakeLists.txt +14 -1
  76. package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
  77. package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
  78. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
  79. package/src/llama.cpp/ggml/include/ggml-cpu.h +3 -0
  80. package/src/llama.cpp/ggml/include/ggml.h +29 -1
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +15 -34
  82. package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
  83. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
  84. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
  85. package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
  87. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -7
  88. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
  89. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +139 -16
  90. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
  91. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
  93. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +151 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1546 -387
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1645 -113
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +22 -0
  97. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  101. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +15 -2
  102. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
  103. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
  104. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
  105. package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
  106. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
  107. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +242 -0
  108. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -6
  109. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
  110. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +315 -138
  111. package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
  112. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
  113. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +5 -0
  114. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +2 -1
  115. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
  116. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +117 -36
  117. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
  118. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
  119. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
  120. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +147 -16
  123. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
  124. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +307 -0
  125. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
  126. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +262 -746
  127. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
  128. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -78
  129. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
  130. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
  131. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +4 -1
  132. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  133. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  134. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
  135. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
  136. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +498 -188
  137. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
  138. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +16 -3
  139. package/src/llama.cpp/ggml/src/ggml.c +93 -5
  140. package/src/llama.cpp/include/llama.h +105 -27
  141. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
  142. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
  143. package/src/llama.cpp/requirements/requirements-all.txt +1 -0
  144. package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
  145. package/src/llama.cpp/requirements.txt +1 -0
  146. package/src/llama.cpp/src/CMakeLists.txt +5 -2
  147. package/src/llama.cpp/src/llama-adapter.cpp +19 -20
  148. package/src/llama.cpp/src/llama-adapter.h +11 -9
  149. package/src/llama.cpp/src/llama-arch.cpp +123 -16
  150. package/src/llama.cpp/src/llama-arch.h +19 -0
  151. package/src/llama.cpp/src/llama-batch.h +2 -2
  152. package/src/llama.cpp/src/llama-chat.cpp +1 -0
  153. package/src/llama.cpp/src/llama-context.cpp +2253 -1222
  154. package/src/llama.cpp/src/llama-context.h +214 -77
  155. package/src/llama.cpp/src/llama-cparams.h +1 -0
  156. package/src/llama.cpp/src/llama-grammar.cpp +182 -182
  157. package/src/llama.cpp/src/llama-grammar.h +12 -3
  158. package/src/llama.cpp/src/llama-graph.cpp +1662 -0
  159. package/src/llama.cpp/src/llama-graph.h +574 -0
  160. package/src/llama.cpp/src/llama-hparams.cpp +8 -0
  161. package/src/llama.cpp/src/llama-hparams.h +9 -0
  162. package/src/llama.cpp/src/llama-io.cpp +15 -0
  163. package/src/llama.cpp/src/llama-io.h +35 -0
  164. package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
  165. package/src/llama.cpp/src/llama-kv-cache.h +178 -109
  166. package/src/llama.cpp/src/llama-memory.cpp +1 -0
  167. package/src/llama.cpp/src/llama-memory.h +21 -0
  168. package/src/llama.cpp/src/llama-mmap.cpp +11 -1
  169. package/src/llama.cpp/src/llama-model.cpp +8230 -122
  170. package/src/llama.cpp/src/llama-model.h +34 -1
  171. package/src/llama.cpp/src/llama-quant.cpp +10 -1
  172. package/src/llama.cpp/src/llama-sampling.cpp +43 -10
  173. package/src/llama.cpp/src/llama-vocab.cpp +12 -0
  174. package/src/llama.cpp/src/llama.cpp +51 -9837
  175. package/src/llama.cpp/tests/test-backend-ops.cpp +247 -112
  176. package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
  177. package/src/llama.cpp/tests/test-chat.cpp +593 -395
  178. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
  179. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
  180. package/src/llama.cpp/Sources/llama/llama.h +0 -4
  181. package/src/llama.cpp/common/chat.hpp +0 -55
  182. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
  183. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
  184. /package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +0 -0
@@ -0,0 +1,11 @@
1
+ #ifndef GGML_SYCL_CPY_HPP
2
+ #define GGML_SYCL_CPY_HPP
3
+
4
+ #include "common.hpp"
5
+
6
+ typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
7
+
8
+ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1);
9
+ void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
10
+
11
+ #endif // GGML_SYCL_CPY_HPP
@@ -16,6 +16,8 @@
16
16
  #include "common.hpp"
17
17
 
18
18
  typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
19
+ typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs,
20
+ const int iqs, dfloat2 &v);
19
21
 
20
22
  static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
21
23
  const int iqs, dfloat2 &v) {
@@ -40,6 +42,29 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
40
42
  #endif // GGML_SYCL_F16
41
43
  }
42
44
 
45
+ static __dpct_inline__ void dequantize_q4_0_reorder(const void *d_ptr, const int64_t ib, const void *qs,
46
+ const int iqs, dfloat2 &v) {
47
+ // const block_q4_0 * x = (const block_q4_0 *) vx;
48
+
49
+ const dfloat d = (const dfloat)*((const sycl::half*)d_ptr+ib);
50
+
51
+ const int vui = *((const uint8_t *)qs+iqs);
52
+
53
+ v.x() = vui & 0xF;
54
+ v.y() = vui >> 4;
55
+
56
+ #ifdef GGML_SYCL_F16
57
+ // v = v - {8.0f, 8.0f};
58
+ // v = v * {d, d};
59
+ v.s0() = (v.s0() - 8.0f) * d;
60
+ v.s1() = (v.s1() - 8.0f) * d;
61
+
62
+ #else
63
+ v.x() = (v.x() - 8.0f) * d;
64
+ v.y() = (v.y() - 8.0f) * d;
65
+ #endif // GGML_SYCL_F16
66
+ }
67
+
43
68
  static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib,
44
69
  const int iqs, dfloat2 &v) {
45
70
  const block_q4_1 * x = (const block_q4_1 *) vx;
@@ -167,6 +192,36 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restri
167
192
  }
168
193
  }
169
194
 
195
+ template<typename dst_t>
196
+ static void dequantize_block_q4_0_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
197
+ const sycl::nd_item<3> &item_ct1) {
198
+
199
+ const int64_t i = item_ct1.get_group(2);
200
+ auto k=nb32;
201
+ // assume 32 threads
202
+ const int64_t tid = item_ct1.get_local_id(2);
203
+ const int lane_ib = i * WARP_SIZE + tid;
204
+
205
+ if (lane_ib >= k / QK4_0) {
206
+ return;
207
+ }
208
+
209
+ dst_t * y_ptr = yy + lane_ib * QK4_0;
210
+
211
+ auto qs = (const uint8_t*)vx + lane_ib * QK4_0 / 2;
212
+ auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k / 2) + lane_ib;
213
+
214
+ const float d = float(*s_ptr);
215
+
216
+ #pragma unroll
217
+ for (int l = 0; l < QK4_0 / 2; ++l) {
218
+ int vq = qs[l];
219
+ y_ptr[l + 0] = d * ((vq & 0xF) - 8);
220
+ y_ptr[l + 16] = d * ((vq >> 4) - 8);
221
+ }
222
+
223
+ }
224
+
170
225
  template<typename dst_t>
171
226
  static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
172
227
  const sycl::nd_item<3> &item_ct1) {
@@ -3,7 +3,6 @@
3
3
  #include "dequantize.hpp"
4
4
  #include "presets.hpp"
5
5
 
6
-
7
6
  static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
8
7
  const sycl::half *x = (const sycl::half *)vx;
9
8
 
@@ -91,6 +90,112 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
91
90
  }
92
91
  }
93
92
 
93
+ template <int qk, int qr, dequantize_kernel_t_reorder dequantize_kernel_reorder>
94
+ static void dequantize_mul_mat_vec_reorder(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
95
+ const sycl::nd_item<3> &item_ct1) {
96
+ // qk = quantized weights per x block
97
+ // qr = number of quantized weights per data value in x block
98
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
99
+ item_ct1.get_local_id(1);
100
+
101
+ if (row >= nrows) {
102
+ return;
103
+ }
104
+
105
+ const int tid = item_ct1.get_local_id(2);
106
+
107
+
108
+ const int ncols_left = ncols % (QK4_0*WARP_SIZE);
109
+ const int ncols_align = ncols - ncols_left;
110
+ const int iter_stride = 8*2*GGML_SYCL_DMMV_X;
111
+ const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter //64/16=4, 512/16/2= 16
112
+ const int y_offset = qr == 1 ? 1 : qk/2;
113
+
114
+ // partial sum for each thread
115
+ #ifdef GGML_SYCL_F16
116
+ sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
117
+ #else
118
+ float tmp = 0.0f;
119
+ #endif // GGML_SYCL_F16
120
+ const char *d_ptr = (const char*)vx+ncols*nrows/2;
121
+ int i=0;
122
+ for (i = 0; i < ncols_align; i += iter_stride) {
123
+ const int col = i + vals_per_iter*tid;
124
+ const int ib = (row*ncols + col)/qk; // x block index
125
+ const int iqs = (col%qk)/qr; // x quant index
126
+ const int iybs = col - col%qk; // y block start index
127
+
128
+ // processing >2 values per i iter is faster for fast GPUs
129
+ #pragma unroll
130
+ for (int j = 0; j < vals_per_iter; j += 2) {
131
+ // process 2 vals per j iter
132
+
133
+ // dequantize
134
+ // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
135
+ dfloat2 v;
136
+ dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v);
137
+
138
+ // matrix multiplication
139
+ // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
140
+ #ifdef GGML_SYCL_F16
141
+ dfloat2 t1{y[iybs + iqs + j / qr + 0],
142
+ y[iybs + iqs + j / qr + y_offset]};
143
+
144
+ tmp += v * t1;
145
+ #else
146
+ tmp += v.x() * y[iybs + iqs + j / qr + 0];
147
+ tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
148
+ #endif // GGML_SYCL_F16
149
+ }
150
+ }
151
+
152
+ for (; i < ncols; i += iter_stride) {
153
+ if (tid>=ncols_left/QK4_0) continue;
154
+ const int col = i + vals_per_iter*tid;
155
+ const int ib = (row*ncols + col)/qk; // x block index
156
+ const int iqs = (col%qk)/qr; // x quant index
157
+ const int iybs = col - col%qk; // y block start index
158
+
159
+ // processing >2 values per i iter is faster for fast GPUs
160
+ #pragma unroll
161
+ for (int j = 0; j < vals_per_iter; j += 2) {
162
+ // process 2 vals per j iter
163
+
164
+ // dequantize
165
+ // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
166
+ dfloat2 v;
167
+ dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v);
168
+
169
+ // matrix multiplication
170
+ // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
171
+ #ifdef GGML_SYCL_F16
172
+ dfloat2 t1{y[iybs + iqs + j / qr + 0],
173
+ y[iybs + iqs + j / qr + y_offset]};
174
+
175
+ tmp += v * t1;
176
+ #else
177
+ tmp += v.x() * y[iybs + iqs + j / qr + 0];
178
+ tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
179
+ #endif // GGML_SYCL_F16
180
+ }
181
+ }
182
+
183
+ // sum up partial sums and write back result
184
+ const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
185
+ for (int mask = mask_start; mask > 0; mask >>= 1) {
186
+ tmp +=
187
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
188
+ }
189
+
190
+ if (tid == 0) {
191
+ #ifdef GGML_SYCL_F16
192
+ dst[row] = tmp.x() + tmp.y();
193
+ #else
194
+ dst[row] = tmp;
195
+ #endif // GGML_SYCL_F16
196
+ }
197
+ }
198
+
94
199
  static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
95
200
  float *dst, const int ncols,
96
201
  const int nrows,
@@ -105,7 +210,7 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
105
210
 
106
211
  stream->parallel_for(
107
212
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
108
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
213
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
109
214
  dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
110
215
  nrows, item_ct1);
111
216
  });
@@ -759,6 +864,28 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
759
864
  }
760
865
  }
761
866
 
867
+ static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloat *y,
868
+ float *dst, const int ncols,
869
+ const int nrows,
870
+ dpct::queue_ptr stream) {
871
+ GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
872
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
873
+ // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
874
+ const sycl::range<3> block_nums(1, 1, block_num_y);
875
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
876
+ {
877
+ dpct::has_capability_or_fail(stream->get_device(),
878
+ {sycl::aspect::fp16});
879
+
880
+ stream->parallel_for(
881
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
882
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
883
+ dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
884
+ vx, y, dst, ncols, nrows, item_ct1);
885
+ });
886
+ }
887
+ }
888
+
762
889
 
763
890
  static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
764
891
  float *dst, const int ncols,
@@ -775,7 +902,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
775
902
 
776
903
  stream->parallel_for(
777
904
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
778
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
905
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
779
906
  dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
780
907
  vx, y, dst, ncols, nrows, item_ct1);
781
908
  });
@@ -796,7 +923,7 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
796
923
 
797
924
  stream->parallel_for(
798
925
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
799
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
926
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
800
927
  dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
801
928
  vx, y, dst, ncols, nrows, item_ct1);
802
929
  });
@@ -817,7 +944,7 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
817
944
 
818
945
  stream->parallel_for(
819
946
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
820
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
947
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
821
948
  dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
822
949
  vx, y, dst, ncols, nrows, item_ct1);
823
950
  });
@@ -838,7 +965,7 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
838
965
 
839
966
  stream->parallel_for(
840
967
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
841
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
968
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
842
969
  dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
843
970
  vx, y, dst, ncols, nrows, item_ct1);
844
971
  });
@@ -859,7 +986,7 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
859
986
 
860
987
  stream->parallel_for(
861
988
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
862
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
989
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
863
990
  dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
864
991
  vx, y, dst, ncols, nrows, item_ct1);
865
992
  });
@@ -877,7 +1004,7 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
877
1004
  const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
878
1005
  stream->parallel_for(
879
1006
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
880
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
1007
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
881
1008
  dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
882
1009
  });
883
1010
  }
@@ -893,7 +1020,7 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
893
1020
  const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
894
1021
  stream->parallel_for(
895
1022
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
896
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
1023
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
897
1024
  dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
898
1025
  });
899
1026
  }
@@ -909,7 +1036,7 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
909
1036
  const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
910
1037
  stream->parallel_for(
911
1038
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
912
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
1039
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
913
1040
  dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
914
1041
  });
915
1042
  }
@@ -922,7 +1049,7 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
922
1049
  const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
923
1050
  stream->parallel_for(
924
1051
  sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
925
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
1052
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
926
1053
  dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
927
1054
  });
928
1055
  }
@@ -938,7 +1065,7 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
938
1065
  const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
939
1066
  stream->parallel_for(
940
1067
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
941
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
1068
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
942
1069
  dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
943
1070
  });
944
1071
  }
@@ -953,7 +1080,6 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
953
1080
 
954
1081
  const int64_t ne00 = src0->ne[0];
955
1082
  const int64_t row_diff = row_high - row_low;
956
-
957
1083
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
958
1084
  // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
959
1085
  #ifdef GGML_SYCL_F16
@@ -967,7 +1093,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
967
1093
 
968
1094
  if (src1_convert_f16) {
969
1095
  src1_dfloat = src1_dfloat_a.alloc(ne00);
970
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
1096
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
971
1097
  GGML_ASSERT(to_fp16_sycl != nullptr);
972
1098
  to_fp16_sycl(src1_ddf_i, src1_dfloat, ne00, stream);
973
1099
  }
@@ -977,7 +1103,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
977
1103
 
978
1104
  switch (src0->type) {
979
1105
  case GGML_TYPE_Q4_0:
980
- dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
1106
+ if ((ggml_tensor_extra_gpu*)dst->src[0]->extra &&
1107
+ ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
1108
+ dequantize_mul_mat_vec_q4_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
1109
+ } else {
1110
+ dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
1111
+ }
981
1112
  break;
982
1113
  case GGML_TYPE_Q4_1:
983
1114
  dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
@@ -1012,7 +1143,6 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
1012
1143
  default:
1013
1144
  printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
1014
1145
  GGML_ABORT("fatal error");
1015
- break;
1016
1146
  }
1017
1147
 
1018
1148
  GGML_UNUSED(src1);
@@ -1020,4 +1150,5 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
1020
1150
  GGML_UNUSED(src1_ddq_i);
1021
1151
  GGML_UNUSED(src1_ncols);
1022
1152
  GGML_UNUSED(src1_padded_row_size);
1153
+ GGML_UNUSED(ctx);
1023
1154
  }