@fugood/llama.node 0.3.1 → 0.3.3
This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between the two versions.
- package/CMakeLists.txt +1 -8
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +4 -2
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +2 -2
- package/src/LlamaCompletionWorker.cpp +10 -10
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +14 -17
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +5 -4
- package/src/llama.cpp/.github/workflows/build.yml +137 -29
- package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
- package/src/llama.cpp/.github/workflows/docker.yml +46 -34
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
- package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
- package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
- package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +7 -0
- package/src/llama.cpp/CMakeLists.txt +26 -11
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/common/CMakeLists.txt +10 -10
- package/src/llama.cpp/common/arg.cpp +2041 -0
- package/src/llama.cpp/common/arg.h +77 -0
- package/src/llama.cpp/common/common.cpp +523 -1861
- package/src/llama.cpp/common/common.h +234 -106
- package/src/llama.cpp/common/console.cpp +3 -0
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +401 -0
- package/src/llama.cpp/common/log.h +66 -698
- package/src/llama.cpp/common/ngram-cache.cpp +39 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +356 -350
- package/src/llama.cpp/common/sampling.h +62 -139
- package/src/llama.cpp/common/stb_image.h +5990 -6398
- package/src/llama.cpp/docs/build.md +72 -17
- package/src/llama.cpp/examples/CMakeLists.txt +1 -2
- package/src/llama.cpp/examples/batched/batched.cpp +49 -65
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
- package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
- package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
- package/src/llama.cpp/examples/infill/infill.cpp +131 -192
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +686 -150
- package/src/llama.cpp/examples/llava/clip.h +11 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
- package/src/llama.cpp/examples/llava/llava.cpp +146 -26
- package/src/llama.cpp/examples/llava/llava.h +2 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
- package/src/llama.cpp/examples/llava/requirements.txt +1 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
- package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
- package/src/llama.cpp/examples/main/main.cpp +216 -313
- package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
- package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
- package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
- package/src/llama.cpp/examples/server/server.cpp +1347 -1531
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
- package/src/llama.cpp/examples/server/utils.hpp +396 -107
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +132 -106
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
- package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
- package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
- package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
- package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +272 -505
- package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
- package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
- package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
- package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
- package/src/llama.cpp/include/llama.h +296 -285
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +2 -1
- package/src/llama.cpp/src/llama-grammar.cpp +721 -122
- package/src/llama.cpp/src/llama-grammar.h +120 -15
- package/src/llama.cpp/src/llama-impl.h +156 -1
- package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
- package/src/llama.cpp/src/llama-sampling.h +39 -47
- package/src/llama.cpp/src/llama-vocab.cpp +390 -127
- package/src/llama.cpp/src/llama-vocab.h +60 -20
- package/src/llama.cpp/src/llama.cpp +6215 -3263
- package/src/llama.cpp/src/unicode-data.cpp +6 -4
- package/src/llama.cpp/src/unicode-data.h +4 -4
- package/src/llama.cpp/src/unicode.cpp +15 -7
- package/src/llama.cpp/tests/CMakeLists.txt +4 -2
- package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
- package/src/llama.cpp/tests/test-barrier.cpp +94 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
- package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
- package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
- package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
- package/src/llama.cpp/tests/test-log.cpp +39 -0
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +2 -1
- package/src/llama.cpp/tests/test-sampling.cpp +226 -142
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/patches/llama.patch +0 -22
- package/src/llama.cpp/.github/workflows/bench.yml +0 -310
- package/src/llama.cpp/common/grammar-parser.cpp +0 -536
- package/src/llama.cpp/common/grammar-parser.h +0 -29
- package/src/llama.cpp/common/train.cpp +0 -1513
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
- package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp

@@ -8,7 +8,6 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
 
     const int nthreads = item_ct1.get_local_range(2);
     const int nwarps = nthreads / WARP_SIZE;
-    assert(nwarps % WARP_SIZE == 0);
    sycl::float2 mean_var = sycl::float2(0.f, 0.f);
 
     for (int col = tid; col < ncols; col += block_size) {
@@ -55,7 +54,6 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
     int end = start + group_size;
     const int nthreads = item_ct1.get_local_range(2);
     const int nwarps = nthreads / WARP_SIZE;
-    assert(nwarps % WARP_SIZE == 0);
     start += item_ct1.get_local_id(2);
     int nreduce = nwarps / WARP_SIZE;
 
@@ -144,7 +142,6 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
     const int tid = item_ct1.get_local_id(2);
     const int nthreads = item_ct1.get_local_range(2);
     const int nwarps = nthreads / WARP_SIZE;
-    assert(nwarps % WARP_SIZE == 0);
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += block_size) {
@@ -202,6 +199,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
         const sycl::range<3> block_dims(1, 1, work_group_size);
         /*
         DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
@@ -225,9 +223,8 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
 }
 
 static void group_norm_f32_sycl(const float* x, float* dst,
-                                const int num_groups, const int group_size,
+                                const int num_groups, const float eps, const int group_size,
                                 const int ne_elements, queue_ptr stream, int device) {
-    static const float eps = 1e-6f;
     if (group_size < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
         stream->submit([&](sycl::handler& cgh) {
@@ -245,6 +242,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
         const sycl::range<3> block_dims(1, 1, work_group_size);
         /*
         DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
@@ -291,6 +289,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
         const sycl::range<3> block_dims(1, 1, work_group_size);
         /*
         DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
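
The three removed asserts checked `nwarps % WARP_SIZE == 0` from inside the kernels, where a failure is unrecoverable; the three added asserts move the equivalent check to the launch sites, where `work_group_size` is actually chosen. A minimal sketch of why the two forms agree, assuming `WARP_SIZE` is 32 as in this backend and that the launched group size is a multiple of `WARP_SIZE` (illustrative only, not backend code):

    #include <cassert>

    constexpr int WARP_SIZE = 32;

    // With nwarps = work_group_size / WARP_SIZE, requiring
    // nwarps % WARP_SIZE == 0 is the same as requiring
    // work_group_size % (WARP_SIZE * WARP_SIZE) == 0.
    void launch_check(int work_group_size) {
        assert(work_group_size % WARP_SIZE == 0);  // assumed precondition
        const int nwarps = work_group_size / WARP_SIZE;
        assert((nwarps % WARP_SIZE == 0) ==
               (work_group_size % (WARP_SIZE * WARP_SIZE) == 0));
    }
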
@@ -343,8 +342,12 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     int num_groups = dst->op_params[0];
+
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
     int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
+    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
 
     (void)src1;
     (void)dst;
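
With this hunk, group norm stops hard-coding `eps = 1e-6f` and instead reads the value stored on the graph node: `op_params[0]` holds `num_groups` and the next four bytes hold `eps`, which is why the code unpacks it with `memcpy` rather than a cast. A minimal sketch of the packing convention, assuming ggml's usual `int32_t op_params[]` slot layout (the `node` struct is a hypothetical stand-in for a ggml tensor):

    #include <cstdint>
    #include <cstring>

    struct node { int32_t op_params[16]; };  // hypothetical stand-in

    // Pack the parameters the way the hunk above unpacks them.
    void set_group_norm_params(node &n, int32_t num_groups, float eps) {
        n.op_params[0] = num_groups;
        std::memcpy(&n.op_params[1], &eps, sizeof(float));  // bit-copy the float
    }

    float get_group_norm_eps(const node &n) {
        float eps;
        std::memcpy(&eps, &n.op_params[1], sizeof(float));
        return eps;
    }
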
package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp

@@ -0,0 +1,56 @@
+#include <sycl/sycl.hpp>
+#include <oneapi/mkl.hpp>
+#include "outprod.hpp"
+
+
+void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+                           const ggml_tensor* src1, ggml_tensor* dst) {
+
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    // Get SYCL queue
+    dpct::queue_ptr stream = ctx.stream();
+
+    // Dimension checks
+    GGML_ASSERT(ne01 == ne11); // Inner dimensions must match
+    GGML_ASSERT(ne0 == ne00); // Output rows match src0 rows
+    GGML_ASSERT(ne1 == ne10); // Output cols match src1 cols
+
+    // Get data pointers
+    const float* src0_d = (const float*)src0->data;
+    const float* src1_d = (const float*)src1->data;
+    float* dst_d = (float*)dst->data;
+
+    // GEMM parameters
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+
+    // Handle transposition of src1
+    const bool src1_T = ggml_is_transposed(src1);
+    const oneapi::mkl::transpose src1_op =
+        src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
+    const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
+
+    try {
+        // Perform matrix multiplication using oneMKL GEMM
+        oneapi::mkl::blas::column_major::gemm(*stream,
+            oneapi::mkl::transpose::nontrans, src1_op,
+            ne0, ne1, ne01,
+            alpha,
+            src0_d, ne00,
+            src1_d, ldb,
+            beta,
+            dst_d, ne0);
+    }
+    catch (sycl::exception const& exc) {
+        std::cerr << exc.what() << std::endl;
+        GGML_ASSERT(false);
+    }
+}
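
The `gemm` call above implements GGML's out-prod semantics: with ggml's dim-0-fastest layout, `dst(i, j) = sum_k src0(i, k) * src1(j, k)`, i.e. `dst = src0 * src1^T` (hence the `trans` flag on `src1` in the common non-transposed case). A plain CPU reference of the same computation for contiguous 2-D f32 tensors, illustrative only:

    #include <cstdint>

    // dst is ne00 x ne10, src0 is ne00 x ne01, src1 is ne10 x ne11 with
    // ne11 == ne01; dim 0 is the fastest-varying index, as in ggml.
    void out_prod_ref(const float *src0, const float *src1, float *dst,
                      int64_t ne00, int64_t ne01, int64_t ne10) {
        for (int64_t j = 0; j < ne10; ++j) {
            for (int64_t i = 0; i < ne00; ++i) {
                float acc = 0.0f;
                for (int64_t k = 0; k < ne01; ++k) {
                    acc += src0[k * ne00 + i] * src1[k * ne10 + j];
                }
                dst[j * ne00 + i] = acc;
            }
        }
    }
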
package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp

@@ -0,0 +1,11 @@
+#ifndef GGML_SYCL_OUTPROD_HPP
+#define GGML_SYCL_OUTPROD_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+                           const ggml_tensor* src1, ggml_tensor* dst);
+
+
+#endif // GGML_SYCL_OUTPROD_HPP
+
package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp

@@ -25,6 +25,11 @@
 #define SYCL_RELU_BLOCK_SIZE 256
 #define SYCL_HARDSIGMOID_BLOCK_SIZE 256
 #define SYCL_HARDSWISH_BLOCK_SIZE 256
+#define SYCL_EXP_BLOCK_SIZE 256
+#define SYCL_NEG_BLOCK_SIZE 256
+#define SYCL_SIGMOID_BLOCK_SIZE 256
+#define SYCL_SQRT_BLOCK_SIZE 256
+#define SYCL_SIN_BLOCK_SIZE 256
 #define SYCL_SQR_BLOCK_SIZE 256
 #define SYCL_CPY_BLOCK_SIZE 32
 #define SYCL_SCALE_BLOCK_SIZE 256
@@ -41,6 +46,9 @@
 #define SYCL_ACC_BLOCK_SIZE 256
 #define SYCL_IM2COL_BLOCK_SIZE 256
 #define SYCL_POOL2D_BLOCK_SIZE 256
+#define SYCL_ARGMAX_BLOCK_SIZE 256
+#define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
+#define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
 
 // dmmv = dequantize_mul_mat_vec
 #ifndef GGML_SYCL_DMMV_X
package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp

@@ -226,7 +226,7 @@ void ggml_sycl_op_rope(
     memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
 
-    const bool is_neox = mode & 2;
+    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
 
     const int32_t * pos = (const int32_t *) src1_dd;
 
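
The replaced line swaps a magic constant for the named flag; `GGML_ROPE_TYPE_NEOX` is defined as 2 in ggml.h, so the test selects the same mode bit as before. A one-function sketch, with the define inlined here only for illustration:

    #include <cstdint>

    #define GGML_ROPE_TYPE_NEOX 2  // value as defined in ggml.h

    static bool rope_is_neox(int32_t mode) {
        return (mode & GGML_ROPE_TYPE_NEOX) != 0;  // formerly: mode & 2
    }
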
package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp

@@ -0,0 +1,71 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#include "tsembd.hpp"
+
+static void timestep_embedding_f32(
+        const float * timesteps, float * dst, const int nb1,
+        const int dim, const int max_period, const sycl::nd_item<3> &item_ct1) {
+    // item_ct1.get_group(1)(blockIDx.y): idx of timesteps->ne[0]
+    // item_ct1.get_group(2) (blockIDx.x): idx of ((dim + 1) / 2) / BLOCK_SIZE
+    int i = item_ct1.get_group(1);
+    int j = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2);
+    float * embed_data = (float *)((char *)dst + i*nb1);
+
+    if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
+        embed_data[dim] = 0.f;
+    }
+
+    int half = dim / 2;
+    if (j >= half) {
+        return;
+    }
+
+    float timestep = timesteps[i];
+    float freq = (float)sycl::native::exp(-(sycl::log((float)max_period)) * j / half);
+    float arg = timestep * freq;
+    embed_data[j] = sycl::cos(arg);
+    embed_data[j + half] = sycl::sin(arg);
+}
+
+static void timestep_embedding_f32_sycl(
+        const float * x, float * dst, const int ne00, const int nb1,
+        const int dim, const int max_period, const queue_ptr& stream) {
+    // As the kernel returns when thread.idx is larger than dim/2, the half_ceil does not need to pad
+    int half_ceil = dim / 2;
+    int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE;
+    sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE);
+    sycl::range<3> gridDim(1, ne00, num_blocks);
+    stream->parallel_for(
+        sycl::nd_range<3>(
+            gridDim * block_dims, block_dims),
+        [=](sycl::nd_item<3> item_ct1) {
+            timestep_embedding_f32(
+                x, dst, nb1, dim, max_period, item_ct1
+            );
+        });
+}
+
+void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                     const ggml_tensor *src1, ggml_tensor * dst) {
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    dpct::queue_ptr stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    const int dim = dst->op_params[0];
+    const int max_period = dst->op_params[1];
+
+    timestep_embedding_f32_sycl(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream);
+}
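
The kernel computes the standard sinusoidal timestep embedding: for each `j < dim/2`, `freq = exp(-ln(max_period) * j / (dim/2))`, then `out[j] = cos(t*freq)` and `out[j + dim/2] = sin(t*freq)`, with one extra zero slot when `dim` is odd. A scalar CPU reference of the same math, illustrative only:

    #include <cmath>

    // One timestep; `embed` must hold dim floats (dim + 1 when dim is odd,
    // matching the kernel's zero-padded slot).
    void timestep_embedding_ref(float timestep, float *embed, int dim, int max_period) {
        const int half = dim / 2;
        for (int j = 0; j < half; ++j) {
            const float freq = std::exp(-std::log((float)max_period) * j / half);
            const float arg  = timestep * freq;
            embed[j]        = std::cos(arg);
            embed[j + half] = std::sin(arg);
        }
        if (dim % 2 != 0) {
            embed[dim] = 0.0f;
        }
    }
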
package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp

@@ -0,0 +1,21 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_TSEMBD_HPP
+#define GGML_SYCL_TSEMBD_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                     const ggml_tensor *src1, ggml_tensor * dst);
+
+#endif // GGML_SYCL_TSEMBD_HPP
package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp

@@ -968,8 +968,8 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
             grid1[0] ^ signs[0], signs[0], std::minus<>());
         const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
             grid2[0] ^ signs[1], signs[1], std::minus<>());
-        sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
-        sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
+        sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi);
+        sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi);
         q8 += 8;
         aux32 >>= 7;
     }
@@ -1009,8 +1009,8 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
             grid1[0] ^ signs0, signs0, std::minus<>());
         const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
             grid2[0] ^ signs1, signs1, std::minus<>());
-        sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
-        sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
+        sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi);
+        sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi);
         q8 += 8;
     }
     const float d =
package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp

@@ -0,0 +1,138 @@
+#include <sycl/sycl.hpp>
+#include "wkv6.hpp"
+
+constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
+
+// Helper function for the main kernel
+static void rwkv_wkv_f32_kernel(
+        const int B, const int T, const int C, const int H,
+        const float* k, const float* v, const float* r,
+        const float* tf, const float* td, const float* s,
+        float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
+
+    const int tid = item_ct1.get_local_id(2);
+    const int bid = item_ct1.get_group(2);
+
+    const int head_size = WKV_BLOCK_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    // Set up shared memory pointers
+    float* _k = shared_mem;
+    float* _r = _k + head_size;
+    float* _tf = _r + head_size;
+    float* _td = _tf + head_size;
+
+    // Local state array
+    float state[WKV_BLOCK_SIZE];
+
+    // Load initial state
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    // Sync threads before shared memory operations
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    // Load time-mixing parameters
+    _tf[tid] = tf[head_i * head_size + tid];
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    // Main sequence processing loop
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
+         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
+         t += C) {
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        // Load current timestep data to shared memory
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        const float _v = v[t];
+        float y = 0;
+
+        // Process in chunks of 4 for better vectorization
+        sycl::float4 k4, r4, tf4, td4, s4, kv4;
+        #pragma unroll
+        for (int j = 0; j < head_size; j += 4) {
+            // Load data in vec4 chunks
+            k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
+            td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
+            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            // Compute key-value product
+            sycl::float4 kv4 = k4 * _v;
+
+            // Accumulate weighted sum
+            y += sycl::dot(r4, tf4 * kv4 + s4);
+
+            // Update state
+            s4 = s4 * td4 + kv4;
+
+            // Store updated state
+            state[j] = s4.x();
+            state[j+1] = s4.y();
+            state[j+2] = s4.z();
+            state[j+3] = s4.w();
+        }
+
+        dst[t] = y;
+    }
+
+    // Save final state
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+                            const ggml_tensor* src1, ggml_tensor* dst) {
+
+    const float* k_d = (const float*)dst->src[0]->data;
+    const float* v_d = (const float*)dst->src[1]->data;
+    const float* r_d = (const float*)dst->src[2]->data;
+    const float* tf_d = (const float*)dst->src[3]->data;
+    const float* td_d = (const float*)dst->src[4]->data;
+    const float* s_d = (const float*)dst->src[5]->data;
+    float* dst_d = (float*)dst->data;
+
+    const int64_t B = dst->src[5]->ne[1];
+    const int64_t T = dst->src[0]->ne[3];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[2];
+
+    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
+
+    dpct::queue_ptr stream = ctx.stream();
+
+    // Calculate execution configuration
+    const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
+    sycl::range<3> block_dims(1, 1, C / H);
+    sycl::range<3> grid_dims(1, 1, B * H);
+
+    // Submit kernel
+    stream->submit([&](sycl::handler& cgh) {
+        sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+        cgh.parallel_for(
+            sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                rwkv_wkv_f32_kernel(
+                    B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
+                    item_ct1, shared_mem_acc.get_pointer()
+                );
+            });
+    });
+}
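
Stripped of the float4 packing and shared-memory staging, the per-channel recurrence the kernel evaluates is `y += r[j] * (tf[j]*k[j]*v + state[j])` followed by `state[j] = state[j]*td[j] + k[j]*v`. A scalar sketch of one timestep for one output channel, illustrative only (names mirror the kernel; `head_size` is 64 above):

    // One WKV6 timestep for a single output channel: `state` carries the
    // decayed key-value history across calls.
    void wkv6_step_ref(int head_size,
                       const float *k, float v, const float *r,
                       const float *tf, const float *td,
                       float *state, float *y_out) {
        float y = 0.0f;
        for (int j = 0; j < head_size; ++j) {
            const float kv = k[j] * v;            // key-value product
            y += r[j] * (tf[j] * kv + state[j]);  // receptance-weighted sum
            state[j] = state[j] * td[j] + kv;     // time-decayed state update
        }
        *y_out = y;
    }
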
package/src/llama.cpp/ggml/src/ggml-threading.cpp

@@ -0,0 +1,12 @@
+#include "ggml-threading.h"
+#include <mutex>
+
+std::mutex ggml_critical_section_mutex;
+
+void ggml_critical_section_start() {
+    ggml_critical_section_mutex.lock();
+}
+
+void ggml_critical_section_end(void) {
+    ggml_critical_section_mutex.unlock();
+}
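
These two functions back ggml's critical-section API with a plain `std::mutex`; callers bracket one-time initialization of shared state with them. A hedged usage sketch; the guarded setup below is hypothetical, not a call site from this diff:

    #include "ggml-threading.h"

    static bool initialized = false;  // hypothetical shared state

    void ensure_initialized(void) {
        ggml_critical_section_start();
        if (!initialized) {
            // ... one-time global setup ...
            initialized = true;
        }
        ggml_critical_section_end();
    }
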
package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt

@@ -0,0 +1,78 @@
+find_package(Vulkan COMPONENTS glslc REQUIRED)
+
+if (Vulkan_FOUND)
+    message(STATUS "Vulkan found")
+
+    add_library(ggml-vulkan
+                ggml-vulkan.cpp
+                ../../include/ggml-vulkan.h
+               )
+
+    target_link_libraries(ggml-vulkan PRIVATE ggml-base Vulkan::Vulkan)
+    target_include_directories(ggml-vulkan PRIVATE . .. ${CMAKE_CURRENT_BINARY_DIR})
+
+    # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build
+    # Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector
+    if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
+        add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0)
+    endif()
+
+    if (GGML_VULKAN_CHECK_RESULTS)
+        add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
+    endif()
+
+    if (GGML_VULKAN_DEBUG)
+        add_compile_definitions(GGML_VULKAN_DEBUG)
+    endif()
+
+    if (GGML_VULKAN_MEMORY_DEBUG)
+        add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG)
+    endif()
+
+    if (GGML_VULKAN_SHADER_DEBUG_INFO)
+        add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO)
+    endif()
+
+    if (GGML_VULKAN_PERF)
+        add_compile_definitions(GGML_VULKAN_PERF)
+    endif()
+
+    if (GGML_VULKAN_VALIDATE)
+        add_compile_definitions(GGML_VULKAN_VALIDATE)
+    endif()
+
+    if (GGML_VULKAN_RUN_TESTS)
+        add_compile_definitions(GGML_VULKAN_RUN_TESTS)
+    endif()
+
+    add_subdirectory(vulkan-shaders)
+
+    set (_ggml_vk_genshaders_cmd vulkan-shaders-gen)
+    set (_ggml_vk_header     ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp)
+    set (_ggml_vk_source     ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp)
+    set (_ggml_vk_input_dir  ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders)
+    set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv)
+
+    file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp")
+
+    add_custom_command(
+        OUTPUT ${_ggml_vk_header}
+               ${_ggml_vk_source}
+
+        COMMAND ${_ggml_vk_genshaders_cmd}
+            --glslc      ${Vulkan_GLSLC_EXECUTABLE}
+            --input-dir  ${_ggml_vk_input_dir}
+            --output-dir ${_ggml_vk_output_dir}
+            --target-hpp ${_ggml_vk_header}
+            --target-cpp ${_ggml_vk_source}
+            --no-clean
+
+        DEPENDS ${_ggml_vk_shader_deps}
+        COMMENT "Generate vulkan shaders"
+    )
+
+    target_sources(ggml-vulkan PRIVATE ${_ggml_vk_source} ${_ggml_vk_header})
+
+else()
+    message(WARNING "Vulkan not found")
+endif()