@fugood/llama.node 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. package/CMakeLists.txt +1 -10
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +6 -4
  17. package/src/LlamaCompletionWorker.cpp +6 -6
  18. package/src/LlamaContext.cpp +7 -9
  19. package/src/common.hpp +2 -1
  20. package/src/llama.cpp/.github/workflows/build.yml +98 -24
  21. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  22. package/src/llama.cpp/.github/workflows/docker.yml +43 -34
  23. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  24. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  25. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  26. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  27. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  28. package/src/llama.cpp/CMakeLists.txt +20 -8
  29. package/src/llama.cpp/common/CMakeLists.txt +12 -10
  30. package/src/llama.cpp/common/arg.cpp +2006 -0
  31. package/src/llama.cpp/common/arg.h +77 -0
  32. package/src/llama.cpp/common/common.cpp +496 -1632
  33. package/src/llama.cpp/common/common.h +161 -63
  34. package/src/llama.cpp/common/console.cpp +3 -0
  35. package/src/llama.cpp/common/log.cpp +401 -0
  36. package/src/llama.cpp/common/log.h +66 -698
  37. package/src/llama.cpp/common/ngram-cache.cpp +3 -0
  38. package/src/llama.cpp/common/sampling.cpp +348 -350
  39. package/src/llama.cpp/common/sampling.h +62 -139
  40. package/src/llama.cpp/common/stb_image.h +5990 -6398
  41. package/src/llama.cpp/common/train.cpp +2 -0
  42. package/src/llama.cpp/docs/build.md +36 -1
  43. package/src/llama.cpp/examples/CMakeLists.txt +0 -1
  44. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
  45. package/src/llama.cpp/examples/batched/batched.cpp +39 -55
  46. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
  47. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  48. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
  49. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  50. package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
  51. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
  52. package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
  53. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  54. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
  55. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  57. package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
  58. package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
  59. package/src/llama.cpp/examples/infill/infill.cpp +117 -132
  60. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
  61. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
  62. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  63. package/src/llama.cpp/examples/llava/clip.cpp +685 -150
  64. package/src/llama.cpp/examples/llava/clip.h +11 -2
  65. package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
  66. package/src/llama.cpp/examples/llava/llava.cpp +110 -24
  67. package/src/llama.cpp/examples/llava/llava.h +2 -3
  68. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  69. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  70. package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
  71. package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
  72. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
  73. package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
  74. package/src/llama.cpp/examples/main/main.cpp +210 -262
  75. package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
  76. package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
  77. package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
  78. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  80. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
  81. package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
  82. package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
  83. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
  84. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
  85. package/src/llama.cpp/examples/server/server.cpp +1027 -1073
  86. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  87. package/src/llama.cpp/examples/server/utils.hpp +107 -105
  88. package/src/llama.cpp/examples/simple/simple.cpp +35 -41
  89. package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
  90. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  91. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  92. package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
  93. package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
  94. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  95. package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
  96. package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
  97. package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
  98. package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
  99. package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
  100. package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
  101. package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
  102. package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
  103. package/src/llama.cpp/ggml/include/ggml.h +293 -186
  104. package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
  105. package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
  106. package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
  107. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
  108. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
  109. package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
  110. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  111. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  112. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  113. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  114. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  115. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  116. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  117. package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
  118. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
  120. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  121. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  122. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  123. package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
  124. package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
  125. package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
  126. package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
  127. package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
  128. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  129. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
  130. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
  131. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  132. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  133. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  134. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  135. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  136. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  137. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
  138. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  141. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
  142. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
  143. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
  144. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  145. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  146. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
  148. package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
  149. package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
  150. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
  151. package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
  152. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
  153. package/src/llama.cpp/include/llama.h +241 -264
  154. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  155. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  156. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  157. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  158. package/src/llama.cpp/src/llama-grammar.h +120 -15
  159. package/src/llama.cpp/src/llama-impl.h +156 -1
  160. package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
  161. package/src/llama.cpp/src/llama-sampling.h +20 -47
  162. package/src/llama.cpp/src/llama-vocab.cpp +343 -120
  163. package/src/llama.cpp/src/llama-vocab.h +33 -17
  164. package/src/llama.cpp/src/llama.cpp +4247 -1525
  165. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  166. package/src/llama.cpp/src/unicode-data.h +4 -4
  167. package/src/llama.cpp/src/unicode.cpp +15 -7
  168. package/src/llama.cpp/tests/CMakeLists.txt +3 -0
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
  171. package/src/llama.cpp/tests/test-barrier.cpp +93 -0
  172. package/src/llama.cpp/tests/test-grad0.cpp +187 -70
  173. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  174. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  175. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
  176. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  177. package/src/llama.cpp/tests/test-log.cpp +39 -0
  178. package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
  179. package/src/llama.cpp/tests/test-rope.cpp +1 -1
  180. package/src/llama.cpp/tests/test-sampling.cpp +157 -98
  181. package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
  182. package/patches/llama.patch +0 -22
  183. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  184. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  185. package/src/llama.cpp/common/grammar-parser.h +0 -29
  186. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  187. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
@@ -19,6 +19,10 @@
19
19
  #include "dpct/helper.hpp"
20
20
  #include "ggml-sycl.h"
21
21
  #include "presets.hpp"
22
+ #if GGML_SYCL_DNNL
23
+ #include "dnnl.hpp"
24
+ #include "dnnl_sycl.hpp"
25
+ #endif
22
26
 
23
27
  #define GGML_COMMON_DECL_SYCL
24
28
  #define GGML_COMMON_IMPL_SYCL
@@ -276,6 +280,52 @@ struct ggml_backend_sycl_context {
276
280
  return stream(device, 0);
277
281
  }
278
282
 
283
+ #if GGML_SYCL_DNNL
284
+ dnnl::engine make_engine(sycl::queue* q) {
285
+ // Get the device associated with the queue
286
+ sycl::device dev = q->get_device();
287
+ // Get the context associated with the queue
288
+ sycl::context ctx = q->get_context();
289
+ const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
290
+ return eng;
291
+ }
292
+
293
+ std::unordered_map<sycl::queue*, dnnl::stream> stream_map;
294
+ std::unordered_map<sycl::queue*, dnnl::engine> engine_map;
295
+ dnnl::stream stream_dnnl(int device, int _stream) {
296
+ auto q = stream(device, _stream);
297
+ return stream_dnnl(q);
298
+ }
299
+ dnnl::engine engine_dnnl(sycl::queue* qptr) {
300
+ auto it = engine_map.find(qptr);
301
+ if (it == engine_map.end()) {
302
+ auto eng = make_engine(qptr);
303
+ engine_map[qptr] = eng;
304
+ return eng;
305
+ }
306
+ else
307
+ {
308
+ return it->second;
309
+ }
310
+ }
311
+ dnnl::stream stream_dnnl(sycl::queue* qptr) {
312
+ auto it = stream_map.find(qptr);
313
+ if (it == stream_map.end()) {
314
+ auto eng = engine_dnnl(qptr);
315
+ auto stream = dnnl::sycl_interop::make_stream(eng, *qptr);
316
+ stream_map[qptr] = stream;
317
+ return stream;
318
+ }
319
+ else
320
+ {
321
+ return it->second;
322
+ }
323
+ }
324
+ dnnl::stream stream_dnnl() {
325
+ return stream_dnnl(device, 0);
326
+ }
327
+ #endif
328
+
279
329
  // pool
280
330
  std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
281
331
 
@@ -352,4 +402,6 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
352
402
  return acc.template get_multi_ptr<sycl::access::decorated::no>().get();
353
403
  }
354
404
 
405
+ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
406
+
355
407
  #endif // GGML_SYCL_COMMON_HPP
@@ -0,0 +1,99 @@
1
+ //
2
+ // MIT license
3
+ // Copyright (C) 2024 Intel Corporation
4
+ // SPDX-License-Identifier: MIT
5
+ //
6
+
7
+ //
8
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9
+ // See https://llvm.org/LICENSE.txt for license information.
10
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11
+ //
12
+
13
+ #include "conv.hpp"
14
+
15
+ static void conv_transpose_1d_kernel(
16
+ const int s0, const int output_size,
17
+ const int src0_ne0, const int src0_ne1, const int src0_ne2,
18
+ const int src1_ne0, const int dst_ne0,
19
+ const float * src0, const float * src1, float * dst,
20
+ const sycl::nd_item<3> &item_ct1) {
21
+ int global_index = item_ct1.get_local_id(2) +
22
+ item_ct1.get_group(2) * item_ct1.get_local_range(2);
23
+ if (global_index >= output_size) {
24
+ return;
25
+ }
26
+
27
+ int out_index = global_index / dst_ne0;
28
+
29
+ float accumulator = 0;
30
+
31
+ for (int c = 0; c < src0_ne2; c++) {
32
+ int idx = global_index % dst_ne0;
33
+
34
+ int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
35
+ int input_offset = src1_ne0 * c;
36
+
37
+ for (int i = 0; i < src1_ne0; i++) {
38
+ if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {
39
+ continue;
40
+ }
41
+ int weight_idx = idx - i*s0;
42
+
43
+ float kernel_weight = src0[kernel_offset + weight_idx];
44
+ float input_value = src1[input_offset+i];
45
+
46
+ accumulator += kernel_weight * input_value;
47
+ }
48
+ }
49
+ dst[global_index] = accumulator;
50
+ }
51
+
52
+ static void conv_transpose_1d_f32_f32_sycl(
53
+ const int s0, const int output_size,
54
+ const int src0_ne0, const int src0_ne1, const int src0_ne2,
55
+ const int src1_ne0, const int dst_ne0,
56
+ const float *src0, const float *src1, float *dst,
57
+ const queue_ptr& stream) {
58
+
59
+ const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
60
+ const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
61
+ const sycl::range<3> block_nums(1, 1, num_blocks);
62
+ stream->parallel_for(
63
+ sycl::nd_range<3>(
64
+ block_nums * block_dims, block_dims),
65
+ [=](sycl::nd_item<3> item_ct1) {
66
+ conv_transpose_1d_kernel(
67
+ s0, output_size,
68
+ src0_ne0, src0_ne1, src0_ne2,
69
+ src1_ne0, dst_ne0,
70
+ src0, src1, dst, item_ct1);
71
+ });
72
+ }
73
+
74
+ void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
75
+ const ggml_tensor *src1, ggml_tensor *dst) {
76
+ const float * src0_d = (const float *)src0->data;
77
+ const float * src1_d = (const float *)src1->data;
78
+
79
+ float * dst_d = (float *)dst->data;
80
+ dpct::queue_ptr stream = ctx.stream();
81
+
82
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
83
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
84
+
85
+ GGML_ASSERT(ggml_is_contiguous(src0));
86
+ GGML_ASSERT(ggml_is_contiguous(src1));
87
+
88
+ const int32_t * opts = (const int32_t *)dst->op_params;
89
+
90
+ const int s0 = opts[0];
91
+
92
+ const int64_t output_size = ggml_nelements(dst);
93
+
94
+ conv_transpose_1d_f32_f32_sycl(s0, output_size,
95
+ src0->ne[0], src0->ne[1], src0->ne[2],
96
+ src1->ne[0], dst->ne[0],
97
+ src0_d, src1_d, dst_d, stream);
98
+ }
99
+
@@ -0,0 +1,21 @@
1
+ //
2
+ // MIT license
3
+ // Copyright (C) 2024 Intel Corporation
4
+ // SPDX-License-Identifier: MIT
5
+ //
6
+
7
+ //
8
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9
+ // See https://llvm.org/LICENSE.txt for license information.
10
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11
+ //
12
+
13
+ #ifndef GGML_SYCL_CONV_HPP
14
+ #define GGML_SYCL_CONV_HPP
15
+
16
+ #include "common.hpp"
17
+
18
+ void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
19
+ const ggml_tensor *src1, ggml_tensor *dst);
20
+
21
+ #endif // GGML_SYCL_CONV_HPP
@@ -3,19 +3,19 @@
3
3
  #include "presets.hpp"
4
4
 
5
5
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
6
- static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
6
+ static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
7
7
  const sycl::nd_item<3> &item_ct1) {
8
- const int i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
8
+ const int64_t i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
9
9
  item_ct1.get_local_id(2));
10
10
 
11
11
  if (i >= k) {
12
12
  return;
13
13
  }
14
14
 
15
- const int ib = i/qk; // block index
16
- const int iqs = (i%qk)/qr; // quant index
17
- const int iybs = i - i%qk; // y block start index
18
- const int y_offset = qr == 1 ? 1 : qk/2;
15
+ const int64_t ib = i/qk; // block index
16
+ const int64_t iqs = (i%qk)/qr; // quant index
17
+ const int64_t iybs = i - i%qk; // y block start index
18
+ const int64_t y_offset = qr == 1 ? 1 : qk/2;
19
19
 
20
20
  // dequantize
21
21
  dfloat2 v;
@@ -27,9 +27,9 @@ static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__
27
27
 
28
28
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
29
29
  static void dequantize_block_sycl(const void *__restrict__ vx,
30
- dst_t *__restrict__ y, const int k,
30
+ dst_t *__restrict__ y, const int64_t k,
31
31
  dpct::queue_ptr stream) {
32
- const int num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
32
+ const int64_t num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
33
33
  {
34
34
  dpct::has_capability_or_fail(stream->get_device(),
35
35
  {sycl::aspect::fp16});
@@ -45,9 +45,9 @@ static void dequantize_block_sycl(const void *__restrict__ vx,
45
45
  }
46
46
 
47
47
  template <typename dst_t>
48
- static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
48
+ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
49
49
  dpct::queue_ptr stream) {
50
- const int nb = k / QK_K;
50
+ const int64_t nb = k / QK_K;
51
51
  #if QK_K == 256
52
52
  {
53
53
  dpct::has_capability_or_fail(stream->get_device(),
@@ -77,9 +77,9 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
77
77
  }
78
78
 
79
79
  template <typename dst_t>
80
- static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
80
+ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
81
81
  dpct::queue_ptr stream) {
82
- const int nb = k / QK_K;
82
+ const int64_t nb = k / QK_K;
83
83
  #if QK_K == 256
84
84
  {
85
85
  dpct::has_capability_or_fail(stream->get_device(),
@@ -108,10 +108,10 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
108
108
  }
109
109
 
110
110
  template <typename dst_t>
111
- static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
111
+ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
112
112
  dpct::queue_ptr stream) {
113
- const int nb32 = k / 32;
114
- const int nb = (k + 255) / 256;
113
+ const int64_t nb32 = k / 32;
114
+ const int64_t nb = (k + 255) / 256;
115
115
  {
116
116
  dpct::has_capability_or_fail(stream->get_device(),
117
117
  {sycl::aspect::fp16});
@@ -126,10 +126,10 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
126
126
  }
127
127
 
128
128
  template <typename dst_t>
129
- static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
129
+ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
130
130
  dpct::queue_ptr stream) {
131
- const int nb32 = k / 32;
132
- const int nb = (k + 255) / 256;
131
+ const int64_t nb32 = k / 32;
132
+ const int64_t nb = (k + 255) / 256;
133
133
  {
134
134
  dpct::has_capability_or_fail(stream->get_device(),
135
135
  {sycl::aspect::fp16});
@@ -145,9 +145,9 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
145
145
 
146
146
 
147
147
  template <typename dst_t>
148
- static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
148
+ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
149
149
  dpct::queue_ptr stream) {
150
- const int nb = k / QK_K;
150
+ const int64_t nb = k / QK_K;
151
151
  {
152
152
  dpct::has_capability_or_fail(stream->get_device(),
153
153
  {sycl::aspect::fp16});
@@ -165,9 +165,9 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
165
165
  }
166
166
 
167
167
  template <typename dst_t>
168
- static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
168
+ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
169
169
  dpct::queue_ptr stream) {
170
- const int nb = k / QK_K;
170
+ const int64_t nb = k / QK_K;
171
171
  #if QK_K == 256
172
172
  {
173
173
  dpct::has_capability_or_fail(stream->get_device(),
@@ -197,9 +197,9 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
197
197
  }
198
198
 
199
199
  template <typename dst_t>
200
- static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
200
+ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
201
201
  dpct::queue_ptr stream) {
202
- const int nb = k / QK_K;
202
+ const int64_t nb = k / QK_K;
203
203
  #if QK_K == 256
204
204
  {
205
205
  dpct::has_capability_or_fail(stream->get_device(),
@@ -229,9 +229,9 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
229
229
  }
230
230
 
231
231
  template <typename dst_t>
232
- static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
232
+ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
233
233
  dpct::queue_ptr stream) {
234
- const int nb = k / QK_K;
234
+ const int64_t nb = k / QK_K;
235
235
  {
236
236
  dpct::has_capability_or_fail(stream->get_device(),
237
237
  {sycl::aspect::fp16});
@@ -250,9 +250,9 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
250
250
  }
251
251
 
252
252
  template <typename dst_t>
253
- static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
253
+ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
254
254
  dpct::queue_ptr stream) {
255
- const int nb = k / QK_K;
255
+ const int64_t nb = k / QK_K;
256
256
  {
257
257
  dpct::has_capability_or_fail(stream->get_device(),
258
258
  {sycl::aspect::fp16});
@@ -271,9 +271,9 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
271
271
  }
272
272
 
273
273
  template <typename dst_t>
274
- static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
274
+ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
275
275
  dpct::queue_ptr stream) {
276
- const int nb = k / QK_K;
276
+ const int64_t nb = k / QK_K;
277
277
  {
278
278
  dpct::has_capability_or_fail(stream->get_device(),
279
279
  {sycl::aspect::fp16});
@@ -292,9 +292,9 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
292
292
  }
293
293
 
294
294
  template <typename dst_t>
295
- static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
295
+ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k,
296
296
  dpct::queue_ptr stream) {
297
- const int nb = k / QK_K;
297
+ const int64_t nb = k / QK_K;
298
298
  {
299
299
  dpct::has_capability_or_fail(stream->get_device(),
300
300
  {sycl::aspect::fp16});
@@ -313,9 +313,9 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
313
313
  }
314
314
 
315
315
  template <typename dst_t>
316
- static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
316
+ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
317
317
  dpct::queue_ptr stream) {
318
- const int nb = k / QK_K;
318
+ const int64_t nb = k / QK_K;
319
319
  {
320
320
  dpct::has_capability_or_fail(stream->get_device(),
321
321
  {sycl::aspect::fp16});
@@ -333,9 +333,9 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
333
333
 
334
334
 
335
335
  template <typename dst_t>
336
- static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
336
+ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
337
337
  dpct::queue_ptr stream) {
338
- const int nb = k / QK_K;
338
+ const int64_t nb = k / QK_K;
339
339
  {
340
340
  dpct::has_capability_or_fail(stream->get_device(),
341
341
  {sycl::aspect::fp16});
@@ -354,9 +354,9 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
354
354
  }
355
355
 
356
356
  template <typename dst_t>
357
- static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
357
+ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
358
358
  dpct::queue_ptr stream) {
359
- const int nb = k / QK_K;
359
+ const int64_t nb = k / QK_K;
360
360
  {
361
361
  dpct::has_capability_or_fail(stream->get_device(),
362
362
  {sycl::aspect::fp16});
@@ -374,9 +374,9 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
374
374
  }
375
375
 
376
376
  template <typename dst_t>
377
- static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
377
+ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k,
378
378
  dpct::queue_ptr stream) {
379
- const int nb = (k + QK_K - 1) / QK_K;
379
+ const int64_t nb = (k + QK_K - 1) / QK_K;
380
380
  #if QK_K == 64
381
381
  dequantize_row_iq4_nl_sycl(vx, y, k, stream);
382
382
  #else
@@ -398,9 +398,9 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
398
398
  }
399
399
 
400
400
  template <typename dst_t>
401
- static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
401
+ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k,
402
402
  dpct::queue_ptr stream) {
403
- const int nb = (k + QK_K - 1) / QK_K;
403
+ const int64_t nb = (k + QK_K - 1) / QK_K;
404
404
  {
405
405
  dpct::has_capability_or_fail(stream->get_device(),
406
406
  {sycl::aspect::fp16});
@@ -418,34 +418,34 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
418
418
  }
419
419
 
420
420
  template <typename src_t, typename dst_t>
421
- static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
421
+ static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
422
422
  const sycl::nd_item<3> &item_ct1) {
423
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
424
- item_ct1.get_local_id(2);
425
-
426
- if (i >= k) {
427
- return;
428
- }
423
+ const int64_t work_group_size = item_ct1.get_local_range(2);
424
+ const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
429
425
 
426
+ // make each work-item deal with more elements since sycl global range can not exceed max int
430
427
  const src_t * x = (src_t *) vx;
431
-
432
- y[i] = x[i];
428
+ for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) {
429
+ y[i] = x[i];
430
+ }
433
431
  }
434
432
 
435
433
  template <typename src_t, typename dst_t>
436
434
  static void convert_unary_sycl(const void *__restrict__ vx,
437
- dst_t *__restrict__ y, const int k,
435
+ dst_t *__restrict__ y, const int64_t k,
438
436
  dpct::queue_ptr stream) {
439
- const int num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
437
+ const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
438
+
439
+ // decrease global range when it exceeds the max int
440
+ int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE);
441
+ sycl::range<3> block_nums(1, 1, num_blocks);
442
+ sycl::range<3> local_range(1, 1, local_size);
440
443
  {
441
444
  dpct::has_capability_or_fail(stream->get_device(),
442
445
  {sycl::aspect::fp16});
443
446
 
444
447
  stream->parallel_for(
445
- sycl::nd_range<3>(
446
- sycl::range<3>(1, 1, num_blocks) *
447
- sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
448
- sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
448
+ sycl::nd_range<3>(block_nums * local_range, local_range),
449
449
  [=](sycl::nd_item<3> item_ct1) {
450
450
  convert_unary<src_t>(vx, y, k, item_ct1);
451
451
  });
@@ -17,7 +17,7 @@
17
17
 
18
18
  template <typename T>
19
19
  using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
20
- int k, dpct::queue_ptr stream);
20
+ int64_t k, dpct::queue_ptr stream);
21
21
  typedef to_t_sycl_t<float> to_fp32_sycl_t;
22
22
  typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
23
23