@fugood/llama.node 0.3.1 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +1 -8
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +4 -2
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +2 -2
- package/src/LlamaCompletionWorker.cpp +10 -10
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +14 -17
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +5 -4
- package/src/llama.cpp/.github/workflows/build.yml +137 -29
- package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
- package/src/llama.cpp/.github/workflows/docker.yml +46 -34
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
- package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
- package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
- package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +7 -0
- package/src/llama.cpp/CMakeLists.txt +26 -11
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/common/CMakeLists.txt +10 -10
- package/src/llama.cpp/common/arg.cpp +2041 -0
- package/src/llama.cpp/common/arg.h +77 -0
- package/src/llama.cpp/common/common.cpp +523 -1861
- package/src/llama.cpp/common/common.h +234 -106
- package/src/llama.cpp/common/console.cpp +3 -0
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +401 -0
- package/src/llama.cpp/common/log.h +66 -698
- package/src/llama.cpp/common/ngram-cache.cpp +39 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +356 -350
- package/src/llama.cpp/common/sampling.h +62 -139
- package/src/llama.cpp/common/stb_image.h +5990 -6398
- package/src/llama.cpp/docs/build.md +72 -17
- package/src/llama.cpp/examples/CMakeLists.txt +1 -2
- package/src/llama.cpp/examples/batched/batched.cpp +49 -65
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
- package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
- package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
- package/src/llama.cpp/examples/infill/infill.cpp +131 -192
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +686 -150
- package/src/llama.cpp/examples/llava/clip.h +11 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
- package/src/llama.cpp/examples/llava/llava.cpp +146 -26
- package/src/llama.cpp/examples/llava/llava.h +2 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
- package/src/llama.cpp/examples/llava/requirements.txt +1 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
- package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
- package/src/llama.cpp/examples/main/main.cpp +216 -313
- package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
- package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
- package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
- package/src/llama.cpp/examples/server/server.cpp +1347 -1531
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
- package/src/llama.cpp/examples/server/utils.hpp +396 -107
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +132 -106
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
- package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
- package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
- package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
- package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +272 -505
- package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
- package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
- package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
- package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
- package/src/llama.cpp/include/llama.h +296 -285
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +2 -1
- package/src/llama.cpp/src/llama-grammar.cpp +721 -122
- package/src/llama.cpp/src/llama-grammar.h +120 -15
- package/src/llama.cpp/src/llama-impl.h +156 -1
- package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
- package/src/llama.cpp/src/llama-sampling.h +39 -47
- package/src/llama.cpp/src/llama-vocab.cpp +390 -127
- package/src/llama.cpp/src/llama-vocab.h +60 -20
- package/src/llama.cpp/src/llama.cpp +6215 -3263
- package/src/llama.cpp/src/unicode-data.cpp +6 -4
- package/src/llama.cpp/src/unicode-data.h +4 -4
- package/src/llama.cpp/src/unicode.cpp +15 -7
- package/src/llama.cpp/tests/CMakeLists.txt +4 -2
- package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
- package/src/llama.cpp/tests/test-barrier.cpp +94 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
- package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
- package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
- package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
- package/src/llama.cpp/tests/test-log.cpp +39 -0
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +2 -1
- package/src/llama.cpp/tests/test-sampling.cpp +226 -142
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/patches/llama.patch +0 -22
- package/src/llama.cpp/.github/workflows/bench.yml +0 -310
- package/src/llama.cpp/common/grammar-parser.cpp +0 -536
- package/src/llama.cpp/common/grammar-parser.h +0 -29
- package/src/llama.cpp/common/train.cpp +0 -1513
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
- package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
|
@@ -3,19 +3,19 @@
|
|
|
3
3
|
#include "presets.hpp"
|
|
4
4
|
|
|
5
5
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
|
6
|
-
static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const
|
|
6
|
+
static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
|
|
7
7
|
const sycl::nd_item<3> &item_ct1) {
|
|
8
|
-
const
|
|
8
|
+
const int64_t i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
9
9
|
item_ct1.get_local_id(2));
|
|
10
10
|
|
|
11
11
|
if (i >= k) {
|
|
12
12
|
return;
|
|
13
13
|
}
|
|
14
14
|
|
|
15
|
-
const
|
|
16
|
-
const
|
|
17
|
-
const
|
|
18
|
-
const
|
|
15
|
+
const int64_t ib = i/qk; // block index
|
|
16
|
+
const int64_t iqs = (i%qk)/qr; // quant index
|
|
17
|
+
const int64_t iybs = i - i%qk; // y block start index
|
|
18
|
+
const int64_t y_offset = qr == 1 ? 1 : qk/2;
|
|
19
19
|
|
|
20
20
|
// dequantize
|
|
21
21
|
dfloat2 v;
|
|
@@ -27,9 +27,9 @@ static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__
|
|
|
27
27
|
|
|
28
28
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
|
29
29
|
static void dequantize_block_sycl(const void *__restrict__ vx,
|
|
30
|
-
dst_t *__restrict__ y, const
|
|
30
|
+
dst_t *__restrict__ y, const int64_t k,
|
|
31
31
|
dpct::queue_ptr stream) {
|
|
32
|
-
const
|
|
32
|
+
const int64_t num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
|
|
33
33
|
{
|
|
34
34
|
dpct::has_capability_or_fail(stream->get_device(),
|
|
35
35
|
{sycl::aspect::fp16});
|
|
@@ -45,9 +45,9 @@ static void dequantize_block_sycl(const void *__restrict__ vx,
|
|
|
45
45
|
}
|
|
46
46
|
|
|
47
47
|
template <typename dst_t>
|
|
48
|
-
static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
|
|
48
|
+
static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
49
49
|
dpct::queue_ptr stream) {
|
|
50
|
-
const
|
|
50
|
+
const int64_t nb = k / QK_K;
|
|
51
51
|
#if QK_K == 256
|
|
52
52
|
{
|
|
53
53
|
dpct::has_capability_or_fail(stream->get_device(),
|
|
@@ -77,9 +77,9 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
|
|
|
77
77
|
}
|
|
78
78
|
|
|
79
79
|
template <typename dst_t>
|
|
80
|
-
static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
|
|
80
|
+
static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
81
81
|
dpct::queue_ptr stream) {
|
|
82
|
-
const
|
|
82
|
+
const int64_t nb = k / QK_K;
|
|
83
83
|
#if QK_K == 256
|
|
84
84
|
{
|
|
85
85
|
dpct::has_capability_or_fail(stream->get_device(),
|
|
@@ -108,10 +108,10 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
|
|
|
108
108
|
}
|
|
109
109
|
|
|
110
110
|
template <typename dst_t>
|
|
111
|
-
static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
|
|
111
|
+
static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
112
112
|
dpct::queue_ptr stream) {
|
|
113
|
-
const
|
|
114
|
-
const
|
|
113
|
+
const int64_t nb32 = k / 32;
|
|
114
|
+
const int64_t nb = (k + 255) / 256;
|
|
115
115
|
{
|
|
116
116
|
dpct::has_capability_or_fail(stream->get_device(),
|
|
117
117
|
{sycl::aspect::fp16});
|
|
@@ -126,10 +126,10 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
|
|
|
126
126
|
}
|
|
127
127
|
|
|
128
128
|
template <typename dst_t>
|
|
129
|
-
static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
|
|
129
|
+
static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
130
130
|
dpct::queue_ptr stream) {
|
|
131
|
-
const
|
|
132
|
-
const
|
|
131
|
+
const int64_t nb32 = k / 32;
|
|
132
|
+
const int64_t nb = (k + 255) / 256;
|
|
133
133
|
{
|
|
134
134
|
dpct::has_capability_or_fail(stream->get_device(),
|
|
135
135
|
{sycl::aspect::fp16});
|
|
@@ -145,9 +145,9 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
|
|
|
145
145
|
|
|
146
146
|
|
|
147
147
|
template <typename dst_t>
|
|
148
|
-
static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
|
|
148
|
+
static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
149
149
|
dpct::queue_ptr stream) {
|
|
150
|
-
const
|
|
150
|
+
const int64_t nb = k / QK_K;
|
|
151
151
|
{
|
|
152
152
|
dpct::has_capability_or_fail(stream->get_device(),
|
|
153
153
|
{sycl::aspect::fp16});
|
|
@@ -165,9 +165,9 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
|
|
|
165
165
|
}
|
|
166
166
|
|
|
167
167
|
template <typename dst_t>
|
|
168
|
-
static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
|
|
168
|
+
static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
169
169
|
dpct::queue_ptr stream) {
|
|
170
|
-
const
|
|
170
|
+
const int64_t nb = k / QK_K;
|
|
171
171
|
#if QK_K == 256
|
|
172
172
|
{
|
|
173
173
|
dpct::has_capability_or_fail(stream->get_device(),
|
|
@@ -197,9 +197,9 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
|
|
|
197
197
|
}
|
|
198
198
|
|
|
199
199
|
template <typename dst_t>
|
|
200
|
-
static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
|
|
200
|
+
static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
201
201
|
dpct::queue_ptr stream) {
|
|
202
|
-
const
|
|
202
|
+
const int64_t nb = k / QK_K;
|
|
203
203
|
#if QK_K == 256
|
|
204
204
|
{
|
|
205
205
|
dpct::has_capability_or_fail(stream->get_device(),
|
|
@@ -229,9 +229,9 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
|
|
|
229
229
|
}
|
|
230
230
|
|
|
231
231
|
template <typename dst_t>
|
|
232
|
-
static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
|
|
232
|
+
static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
233
233
|
dpct::queue_ptr stream) {
|
|
234
|
-
const
|
|
234
|
+
const int64_t nb = k / QK_K;
|
|
235
235
|
{
|
|
236
236
|
dpct::has_capability_or_fail(stream->get_device(),
|
|
237
237
|
{sycl::aspect::fp16});
|
|
@@ -250,9 +250,9 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
|
|
|
250
250
|
}
|
|
251
251
|
|
|
252
252
|
template <typename dst_t>
|
|
253
|
-
static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
|
|
253
|
+
static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
254
254
|
dpct::queue_ptr stream) {
|
|
255
|
-
const
|
|
255
|
+
const int64_t nb = k / QK_K;
|
|
256
256
|
{
|
|
257
257
|
dpct::has_capability_or_fail(stream->get_device(),
|
|
258
258
|
{sycl::aspect::fp16});
|
|
@@ -271,9 +271,9 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
|
|
|
271
271
|
}
|
|
272
272
|
|
|
273
273
|
template <typename dst_t>
|
|
274
|
-
static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
|
|
274
|
+
static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
275
275
|
dpct::queue_ptr stream) {
|
|
276
|
-
const
|
|
276
|
+
const int64_t nb = k / QK_K;
|
|
277
277
|
{
|
|
278
278
|
dpct::has_capability_or_fail(stream->get_device(),
|
|
279
279
|
{sycl::aspect::fp16});
|
|
@@ -292,9 +292,9 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
|
|
|
292
292
|
}
|
|
293
293
|
|
|
294
294
|
template <typename dst_t>
|
|
295
|
-
static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
|
|
295
|
+
static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
296
296
|
dpct::queue_ptr stream) {
|
|
297
|
-
const
|
|
297
|
+
const int64_t nb = k / QK_K;
|
|
298
298
|
{
|
|
299
299
|
dpct::has_capability_or_fail(stream->get_device(),
|
|
300
300
|
{sycl::aspect::fp16});
|
|
@@ -313,9 +313,9 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
|
|
|
313
313
|
}
|
|
314
314
|
|
|
315
315
|
template <typename dst_t>
|
|
316
|
-
static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
|
|
316
|
+
static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
317
317
|
dpct::queue_ptr stream) {
|
|
318
|
-
const
|
|
318
|
+
const int64_t nb = k / QK_K;
|
|
319
319
|
{
|
|
320
320
|
dpct::has_capability_or_fail(stream->get_device(),
|
|
321
321
|
{sycl::aspect::fp16});
|
|
@@ -333,9 +333,9 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
|
|
|
333
333
|
|
|
334
334
|
|
|
335
335
|
template <typename dst_t>
|
|
336
|
-
static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
|
|
336
|
+
static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
337
337
|
dpct::queue_ptr stream) {
|
|
338
|
-
const
|
|
338
|
+
const int64_t nb = k / QK_K;
|
|
339
339
|
{
|
|
340
340
|
dpct::has_capability_or_fail(stream->get_device(),
|
|
341
341
|
{sycl::aspect::fp16});
|
|
@@ -354,9 +354,9 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
|
|
|
354
354
|
}
|
|
355
355
|
|
|
356
356
|
template <typename dst_t>
|
|
357
|
-
static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
|
|
357
|
+
static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
358
358
|
dpct::queue_ptr stream) {
|
|
359
|
-
const
|
|
359
|
+
const int64_t nb = k / QK_K;
|
|
360
360
|
{
|
|
361
361
|
dpct::has_capability_or_fail(stream->get_device(),
|
|
362
362
|
{sycl::aspect::fp16});
|
|
@@ -374,9 +374,9 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
|
|
|
374
374
|
}
|
|
375
375
|
|
|
376
376
|
template <typename dst_t>
|
|
377
|
-
static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
|
|
377
|
+
static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
378
378
|
dpct::queue_ptr stream) {
|
|
379
|
-
const
|
|
379
|
+
const int64_t nb = (k + QK_K - 1) / QK_K;
|
|
380
380
|
#if QK_K == 64
|
|
381
381
|
dequantize_row_iq4_nl_sycl(vx, y, k, stream);
|
|
382
382
|
#else
|
|
@@ -398,9 +398,9 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
|
|
|
398
398
|
}
|
|
399
399
|
|
|
400
400
|
template <typename dst_t>
|
|
401
|
-
static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
|
|
401
|
+
static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
402
402
|
dpct::queue_ptr stream) {
|
|
403
|
-
const
|
|
403
|
+
const int64_t nb = (k + QK_K - 1) / QK_K;
|
|
404
404
|
{
|
|
405
405
|
dpct::has_capability_or_fail(stream->get_device(),
|
|
406
406
|
{sycl::aspect::fp16});
|
|
@@ -418,34 +418,34 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
|
|
|
418
418
|
}
|
|
419
419
|
|
|
420
420
|
template <typename src_t, typename dst_t>
|
|
421
|
-
static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const
|
|
421
|
+
static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
|
|
422
422
|
const sycl::nd_item<3> &item_ct1) {
|
|
423
|
-
const
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
if (i >= k) {
|
|
427
|
-
return;
|
|
428
|
-
}
|
|
423
|
+
const int64_t work_group_size = item_ct1.get_local_range(2);
|
|
424
|
+
const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
|
|
429
425
|
|
|
426
|
+
// make each work-item deal with more elements since sycl global range can not exceed max int
|
|
430
427
|
const src_t * x = (src_t *) vx;
|
|
431
|
-
|
|
432
|
-
|
|
428
|
+
for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) {
|
|
429
|
+
y[i] = x[i];
|
|
430
|
+
}
|
|
433
431
|
}
|
|
434
432
|
|
|
435
433
|
template <typename src_t, typename dst_t>
|
|
436
434
|
static void convert_unary_sycl(const void *__restrict__ vx,
|
|
437
|
-
dst_t *__restrict__ y, const
|
|
435
|
+
dst_t *__restrict__ y, const int64_t k,
|
|
438
436
|
dpct::queue_ptr stream) {
|
|
439
|
-
const
|
|
437
|
+
const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
|
|
438
|
+
|
|
439
|
+
// decrease global range when it exceeds the max int
|
|
440
|
+
int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE);
|
|
441
|
+
sycl::range<3> block_nums(1, 1, num_blocks);
|
|
442
|
+
sycl::range<3> local_range(1, 1, local_size);
|
|
440
443
|
{
|
|
441
444
|
dpct::has_capability_or_fail(stream->get_device(),
|
|
442
445
|
{sycl::aspect::fp16});
|
|
443
446
|
|
|
444
447
|
stream->parallel_for(
|
|
445
|
-
sycl::nd_range<3>(
|
|
446
|
-
sycl::range<3>(1, 1, num_blocks) *
|
|
447
|
-
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
|
|
448
|
-
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
|
|
448
|
+
sycl::nd_range<3>(block_nums * local_range, local_range),
|
|
449
449
|
[=](sycl::nd_item<3> item_ct1) {
|
|
450
450
|
convert_unary<src_t>(vx, y, k, item_ct1);
|
|
451
451
|
});
|
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
|
|
18
18
|
template <typename T>
|
|
19
19
|
using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
|
|
20
|
-
|
|
20
|
+
int64_t k, dpct::queue_ptr stream);
|
|
21
21
|
typedef to_t_sycl_t<float> to_fp32_sycl_t;
|
|
22
22
|
typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
|
|
23
23
|
|