@fugood/llama.node 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. package/CMakeLists.txt +1 -8
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +4 -2
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +10 -10
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +14 -17
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +5 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +137 -29
  25. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  26. package/src/llama.cpp/.github/workflows/docker.yml +46 -34
  27. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  28. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  29. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  30. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  31. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  32. package/src/llama.cpp/CMakeLists.txt +26 -11
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/common/CMakeLists.txt +10 -10
  35. package/src/llama.cpp/common/arg.cpp +2041 -0
  36. package/src/llama.cpp/common/arg.h +77 -0
  37. package/src/llama.cpp/common/common.cpp +523 -1861
  38. package/src/llama.cpp/common/common.h +234 -106
  39. package/src/llama.cpp/common/console.cpp +3 -0
  40. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  41. package/src/llama.cpp/common/log.cpp +401 -0
  42. package/src/llama.cpp/common/log.h +66 -698
  43. package/src/llama.cpp/common/ngram-cache.cpp +39 -36
  44. package/src/llama.cpp/common/ngram-cache.h +19 -19
  45. package/src/llama.cpp/common/sampling.cpp +356 -350
  46. package/src/llama.cpp/common/sampling.h +62 -139
  47. package/src/llama.cpp/common/stb_image.h +5990 -6398
  48. package/src/llama.cpp/docs/build.md +72 -17
  49. package/src/llama.cpp/examples/CMakeLists.txt +1 -2
  50. package/src/llama.cpp/examples/batched/batched.cpp +49 -65
  51. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
  52. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  53. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
  54. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  55. package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
  56. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
  58. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  59. package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
  60. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  61. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  62. package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
  63. package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
  64. package/src/llama.cpp/examples/infill/infill.cpp +131 -192
  65. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
  66. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  67. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
  68. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  69. package/src/llama.cpp/examples/llava/clip.cpp +686 -150
  70. package/src/llama.cpp/examples/llava/clip.h +11 -2
  71. package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
  72. package/src/llama.cpp/examples/llava/llava.cpp +146 -26
  73. package/src/llama.cpp/examples/llava/llava.h +2 -3
  74. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  75. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  76. package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
  77. package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
  78. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  79. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
  80. package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
  81. package/src/llama.cpp/examples/main/main.cpp +216 -313
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
  83. package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
  84. package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  87. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
  88. package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
  89. package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
  90. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
  91. package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
  92. package/src/llama.cpp/examples/server/server.cpp +1347 -1531
  93. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  94. package/src/llama.cpp/examples/server/utils.hpp +396 -107
  95. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  96. package/src/llama.cpp/examples/simple/simple.cpp +132 -106
  97. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  98. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  99. package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
  100. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  101. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  102. package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
  103. package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
  104. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  105. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  106. package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
  107. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  108. package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
  109. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  110. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  111. package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
  112. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  113. package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
  114. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  115. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  116. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  117. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  118. package/src/llama.cpp/ggml/include/ggml.h +272 -505
  119. package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
  120. package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
  121. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  122. package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
  123. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  124. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  125. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  126. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  127. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  128. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
  129. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  130. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
  131. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  132. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
  133. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  134. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  135. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  136. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  137. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  138. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
  139. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  140. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  141. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  142. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  143. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  151. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
  152. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  153. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  155. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  156. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  157. package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
  158. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  159. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
  160. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  161. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  162. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  163. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  164. package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
  165. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  167. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  169. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
  172. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  173. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  174. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  175. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  176. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  177. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  178. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  179. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
  180. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  181. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  182. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  183. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
  184. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  185. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
  187. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
  188. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  192. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  195. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  197. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  198. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  199. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  200. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
  201. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
  202. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
  203. package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
  204. package/src/llama.cpp/include/llama.h +296 -285
  205. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  206. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  207. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  208. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  209. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  210. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  211. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  212. package/src/llama.cpp/src/llama-grammar.h +120 -15
  213. package/src/llama.cpp/src/llama-impl.h +156 -1
  214. package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
  215. package/src/llama.cpp/src/llama-sampling.h +39 -47
  216. package/src/llama.cpp/src/llama-vocab.cpp +390 -127
  217. package/src/llama.cpp/src/llama-vocab.h +60 -20
  218. package/src/llama.cpp/src/llama.cpp +6215 -3263
  219. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  220. package/src/llama.cpp/src/unicode-data.h +4 -4
  221. package/src/llama.cpp/src/unicode.cpp +15 -7
  222. package/src/llama.cpp/tests/CMakeLists.txt +4 -2
  223. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  224. package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
  225. package/src/llama.cpp/tests/test-barrier.cpp +94 -0
  226. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  227. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  228. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  229. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
  230. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  231. package/src/llama.cpp/tests/test-log.cpp +39 -0
  232. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  233. package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
  234. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  235. package/src/llama.cpp/tests/test-rope.cpp +2 -1
  236. package/src/llama.cpp/tests/test-sampling.cpp +226 -142
  237. package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
  238. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  239. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  240. package/patches/llama.patch +0 -22
  241. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  242. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  243. package/src/llama.cpp/common/grammar-parser.h +0 -29
  244. package/src/llama.cpp/common/train.cpp +0 -1513
  245. package/src/llama.cpp/common/train.h +0 -233
  246. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
  247. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  248. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
  249. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
  250. package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
  251. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  252. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
@@ -3,19 +3,19 @@
3
3
  #include "presets.hpp"
4
4
 
5
5
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
6
- static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
6
+ static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
7
7
  const sycl::nd_item<3> &item_ct1) {
8
- const int i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
8
+ const int64_t i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
9
9
  item_ct1.get_local_id(2));
10
10
 
11
11
  if (i >= k) {
12
12
  return;
13
13
  }
14
14
 
15
- const int ib = i/qk; // block index
16
- const int iqs = (i%qk)/qr; // quant index
17
- const int iybs = i - i%qk; // y block start index
18
- const int y_offset = qr == 1 ? 1 : qk/2;
15
+ const int64_t ib = i/qk; // block index
16
+ const int64_t iqs = (i%qk)/qr; // quant index
17
+ const int64_t iybs = i - i%qk; // y block start index
18
+ const int64_t y_offset = qr == 1 ? 1 : qk/2;
19
19
 
20
20
  // dequantize
21
21
  dfloat2 v;
@@ -27,9 +27,9 @@ static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__
27
27
 
28
28
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
29
29
  static void dequantize_block_sycl(const void *__restrict__ vx,
30
- dst_t *__restrict__ y, const int k,
30
+ dst_t *__restrict__ y, const int64_t k,
31
31
  dpct::queue_ptr stream) {
32
- const int num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
32
+ const int64_t num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
33
33
  {
34
34
  dpct::has_capability_or_fail(stream->get_device(),
35
35
  {sycl::aspect::fp16});
@@ -45,9 +45,9 @@ static void dequantize_block_sycl(const void *__restrict__ vx,
45
45
  }
46
46
 
47
47
  template <typename dst_t>
48
- static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
48
+ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
49
49
  dpct::queue_ptr stream) {
50
- const int nb = k / QK_K;
50
+ const int64_t nb = k / QK_K;
51
51
  #if QK_K == 256
52
52
  {
53
53
  dpct::has_capability_or_fail(stream->get_device(),
@@ -77,9 +77,9 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
77
77
  }
78
78
 
79
79
  template <typename dst_t>
80
- static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
80
+ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
81
81
  dpct::queue_ptr stream) {
82
- const int nb = k / QK_K;
82
+ const int64_t nb = k / QK_K;
83
83
  #if QK_K == 256
84
84
  {
85
85
  dpct::has_capability_or_fail(stream->get_device(),
@@ -108,10 +108,10 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
108
108
  }
109
109
 
110
110
  template <typename dst_t>
111
- static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
111
+ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
112
112
  dpct::queue_ptr stream) {
113
- const int nb32 = k / 32;
114
- const int nb = (k + 255) / 256;
113
+ const int64_t nb32 = k / 32;
114
+ const int64_t nb = (k + 255) / 256;
115
115
  {
116
116
  dpct::has_capability_or_fail(stream->get_device(),
117
117
  {sycl::aspect::fp16});
@@ -126,10 +126,10 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
126
126
  }
127
127
 
128
128
  template <typename dst_t>
129
- static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
129
+ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
130
130
  dpct::queue_ptr stream) {
131
- const int nb32 = k / 32;
132
- const int nb = (k + 255) / 256;
131
+ const int64_t nb32 = k / 32;
132
+ const int64_t nb = (k + 255) / 256;
133
133
  {
134
134
  dpct::has_capability_or_fail(stream->get_device(),
135
135
  {sycl::aspect::fp16});
@@ -145,9 +145,9 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
145
145
 
146
146
 
147
147
  template <typename dst_t>
148
- static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
148
+ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
149
149
  dpct::queue_ptr stream) {
150
- const int nb = k / QK_K;
150
+ const int64_t nb = k / QK_K;
151
151
  {
152
152
  dpct::has_capability_or_fail(stream->get_device(),
153
153
  {sycl::aspect::fp16});
@@ -165,9 +165,9 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
165
165
  }
166
166
 
167
167
  template <typename dst_t>
168
- static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
168
+ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
169
169
  dpct::queue_ptr stream) {
170
- const int nb = k / QK_K;
170
+ const int64_t nb = k / QK_K;
171
171
  #if QK_K == 256
172
172
  {
173
173
  dpct::has_capability_or_fail(stream->get_device(),
@@ -197,9 +197,9 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
197
197
  }
198
198
 
199
199
  template <typename dst_t>
200
- static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
200
+ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
201
201
  dpct::queue_ptr stream) {
202
- const int nb = k / QK_K;
202
+ const int64_t nb = k / QK_K;
203
203
  #if QK_K == 256
204
204
  {
205
205
  dpct::has_capability_or_fail(stream->get_device(),
@@ -229,9 +229,9 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
229
229
  }
230
230
 
231
231
  template <typename dst_t>
232
- static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
232
+ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
233
233
  dpct::queue_ptr stream) {
234
- const int nb = k / QK_K;
234
+ const int64_t nb = k / QK_K;
235
235
  {
236
236
  dpct::has_capability_or_fail(stream->get_device(),
237
237
  {sycl::aspect::fp16});
@@ -250,9 +250,9 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
250
250
  }
251
251
 
252
252
  template <typename dst_t>
253
- static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
253
+ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
254
254
  dpct::queue_ptr stream) {
255
- const int nb = k / QK_K;
255
+ const int64_t nb = k / QK_K;
256
256
  {
257
257
  dpct::has_capability_or_fail(stream->get_device(),
258
258
  {sycl::aspect::fp16});
@@ -271,9 +271,9 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
271
271
  }
272
272
 
273
273
  template <typename dst_t>
274
- static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
274
+ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
275
275
  dpct::queue_ptr stream) {
276
- const int nb = k / QK_K;
276
+ const int64_t nb = k / QK_K;
277
277
  {
278
278
  dpct::has_capability_or_fail(stream->get_device(),
279
279
  {sycl::aspect::fp16});
@@ -292,9 +292,9 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
292
292
  }
293
293
 
294
294
  template <typename dst_t>
295
- static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
295
+ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k,
296
296
  dpct::queue_ptr stream) {
297
- const int nb = k / QK_K;
297
+ const int64_t nb = k / QK_K;
298
298
  {
299
299
  dpct::has_capability_or_fail(stream->get_device(),
300
300
  {sycl::aspect::fp16});
@@ -313,9 +313,9 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
313
313
  }
314
314
 
315
315
  template <typename dst_t>
316
- static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
316
+ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
317
317
  dpct::queue_ptr stream) {
318
- const int nb = k / QK_K;
318
+ const int64_t nb = k / QK_K;
319
319
  {
320
320
  dpct::has_capability_or_fail(stream->get_device(),
321
321
  {sycl::aspect::fp16});
@@ -333,9 +333,9 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
333
333
 
334
334
 
335
335
  template <typename dst_t>
336
- static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
336
+ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
337
337
  dpct::queue_ptr stream) {
338
- const int nb = k / QK_K;
338
+ const int64_t nb = k / QK_K;
339
339
  {
340
340
  dpct::has_capability_or_fail(stream->get_device(),
341
341
  {sycl::aspect::fp16});
@@ -354,9 +354,9 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
354
354
  }
355
355
 
356
356
  template <typename dst_t>
357
- static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
357
+ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
358
358
  dpct::queue_ptr stream) {
359
- const int nb = k / QK_K;
359
+ const int64_t nb = k / QK_K;
360
360
  {
361
361
  dpct::has_capability_or_fail(stream->get_device(),
362
362
  {sycl::aspect::fp16});
@@ -374,9 +374,9 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
374
374
  }
375
375
 
376
376
  template <typename dst_t>
377
- static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
377
+ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k,
378
378
  dpct::queue_ptr stream) {
379
- const int nb = (k + QK_K - 1) / QK_K;
379
+ const int64_t nb = (k + QK_K - 1) / QK_K;
380
380
  #if QK_K == 64
381
381
  dequantize_row_iq4_nl_sycl(vx, y, k, stream);
382
382
  #else
@@ -398,9 +398,9 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
398
398
  }
399
399
 
400
400
  template <typename dst_t>
401
- static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
401
+ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k,
402
402
  dpct::queue_ptr stream) {
403
- const int nb = (k + QK_K - 1) / QK_K;
403
+ const int64_t nb = (k + QK_K - 1) / QK_K;
404
404
  {
405
405
  dpct::has_capability_or_fail(stream->get_device(),
406
406
  {sycl::aspect::fp16});
@@ -418,34 +418,34 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
418
418
  }
419
419
 
420
420
  template <typename src_t, typename dst_t>
421
- static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
421
+ static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
422
422
  const sycl::nd_item<3> &item_ct1) {
423
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
424
- item_ct1.get_local_id(2);
425
-
426
- if (i >= k) {
427
- return;
428
- }
423
+ const int64_t work_group_size = item_ct1.get_local_range(2);
424
+ const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
429
425
 
426
+ // make each work-item deal with more elements since sycl global range can not exceed max int
430
427
  const src_t * x = (src_t *) vx;
431
-
432
- y[i] = x[i];
428
+ for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) {
429
+ y[i] = x[i];
430
+ }
433
431
  }
434
432
 
435
433
  template <typename src_t, typename dst_t>
436
434
  static void convert_unary_sycl(const void *__restrict__ vx,
437
- dst_t *__restrict__ y, const int k,
435
+ dst_t *__restrict__ y, const int64_t k,
438
436
  dpct::queue_ptr stream) {
439
- const int num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
437
+ const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
438
+
439
+ // decrease global range when it exceeds the max int
440
+ int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE);
441
+ sycl::range<3> block_nums(1, 1, num_blocks);
442
+ sycl::range<3> local_range(1, 1, local_size);
440
443
  {
441
444
  dpct::has_capability_or_fail(stream->get_device(),
442
445
  {sycl::aspect::fp16});
443
446
 
444
447
  stream->parallel_for(
445
- sycl::nd_range<3>(
446
- sycl::range<3>(1, 1, num_blocks) *
447
- sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
448
- sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
448
+ sycl::nd_range<3>(block_nums * local_range, local_range),
449
449
  [=](sycl::nd_item<3> item_ct1) {
450
450
  convert_unary<src_t>(vx, y, k, item_ct1);
451
451
  });
@@ -17,7 +17,7 @@
17
17
 
18
18
  template <typename T>
19
19
  using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
20
- int k, dpct::queue_ptr stream);
20
+ int64_t k, dpct::queue_ptr stream);
21
21
  typedef to_t_sycl_t<float> to_fp32_sycl_t;
22
22
  typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
23
23