@fugood/llama.node 0.3.3 → 0.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +5 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +29 -1
- package/package.json +1 -1
- package/src/EmbeddingWorker.cpp +15 -5
- package/src/EmbeddingWorker.h +2 -1
- package/src/LlamaCompletionWorker.cpp +17 -1
- package/src/LlamaContext.cpp +86 -18
- package/src/LlamaContext.h +2 -0
- package/src/llama.cpp/.github/workflows/build.yml +197 -159
- package/src/llama.cpp/.github/workflows/docker.yml +5 -8
- package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
- package/src/llama.cpp/.github/workflows/server.yml +21 -14
- package/src/llama.cpp/CMakeLists.txt +11 -6
- package/src/llama.cpp/Sources/llama/llama.h +4 -0
- package/src/llama.cpp/cmake/common.cmake +33 -0
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
- package/src/llama.cpp/common/CMakeLists.txt +6 -2
- package/src/llama.cpp/common/arg.cpp +426 -245
- package/src/llama.cpp/common/common.cpp +143 -80
- package/src/llama.cpp/common/common.h +81 -24
- package/src/llama.cpp/common/sampling.cpp +53 -19
- package/src/llama.cpp/common/sampling.h +22 -1
- package/src/llama.cpp/common/speculative.cpp +274 -0
- package/src/llama.cpp/common/speculative.h +28 -0
- package/src/llama.cpp/docs/build.md +101 -148
- package/src/llama.cpp/examples/CMakeLists.txt +32 -13
- package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +5 -4
- package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
- package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
- package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
- package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
- package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
- package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +1 -1
- package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
- package/src/llama.cpp/examples/llava/clip.cpp +262 -66
- package/src/llama.cpp/examples/llava/clip.h +8 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
- package/src/llama.cpp/examples/llava/llava.cpp +46 -19
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
- package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
- package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
- package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
- package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/main/main.cpp +9 -5
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
- package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
- package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/run/run.cpp +911 -0
- package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
- package/src/llama.cpp/examples/server/server.cpp +1758 -886
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
- package/src/llama.cpp/examples/server/utils.hpp +94 -304
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +4 -0
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
- package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
- package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
- package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
- package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/tts/tts.cpp +932 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
- package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
- package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
- package/src/llama.cpp/ggml/include/ggml.h +106 -24
- package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
- package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
- package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
- package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
- package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
- package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
- package/src/llama.cpp/ggml/src/ggml.c +367 -207
- package/src/llama.cpp/include/llama-cpp.h +25 -0
- package/src/llama.cpp/include/llama.h +26 -19
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
- package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
- package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
- package/src/llama.cpp/src/CMakeLists.txt +2 -7
- package/src/llama.cpp/src/llama-grammar.cpp +15 -15
- package/src/llama.cpp/src/llama-grammar.h +2 -5
- package/src/llama.cpp/src/llama-sampling.cpp +35 -90
- package/src/llama.cpp/src/llama-vocab.cpp +6 -1
- package/src/llama.cpp/src/llama.cpp +1748 -640
- package/src/llama.cpp/src/unicode.cpp +62 -51
- package/src/llama.cpp/src/unicode.h +9 -10
- package/src/llama.cpp/tests/CMakeLists.txt +48 -37
- package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
- package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
- package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
- package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
- package/src/llama.cpp/tests/test-rope.cpp +61 -20
- package/src/llama.cpp/tests/test-sampling.cpp +2 -2
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
- package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
- package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
- package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
- package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
|
@@ -14,51 +14,51 @@
|
|
|
14
14
|
#include <vector>
|
|
15
15
|
|
|
16
16
|
struct ggml_opt_dataset {
|
|
17
|
-
struct ggml_context * ctx;
|
|
18
|
-
ggml_backend_buffer_t buf;
|
|
19
|
-
struct ggml_tensor * data;
|
|
20
|
-
struct ggml_tensor * labels;
|
|
17
|
+
struct ggml_context * ctx = nullptr;
|
|
18
|
+
ggml_backend_buffer_t buf = nullptr;
|
|
19
|
+
struct ggml_tensor * data = nullptr;
|
|
20
|
+
struct ggml_tensor * labels = nullptr;
|
|
21
21
|
|
|
22
|
-
int64_t ndata;
|
|
23
|
-
int64_t ndata_shard;
|
|
24
|
-
size_t nbs_data;
|
|
25
|
-
size_t nbs_labels;
|
|
22
|
+
int64_t ndata = -1;
|
|
23
|
+
int64_t ndata_shard = -1;
|
|
24
|
+
size_t nbs_data = -1;
|
|
25
|
+
size_t nbs_labels = -1;
|
|
26
26
|
|
|
27
27
|
std::vector<int64_t> permutation;
|
|
28
28
|
};
|
|
29
29
|
|
|
30
30
|
struct ggml_opt_context {
|
|
31
|
-
ggml_backend_sched_t backend_sched;
|
|
32
|
-
ggml_cgraph * allocated_graph;
|
|
33
|
-
ggml_cgraph * allocated_graph_copy;
|
|
34
|
-
struct ggml_context * ctx_static;
|
|
35
|
-
struct ggml_context * ctx_static_cpu;
|
|
36
|
-
struct ggml_context * ctx_compute;
|
|
37
|
-
struct ggml_context * ctx_copy;
|
|
38
|
-
ggml_backend_buffer_t buf_static;
|
|
39
|
-
ggml_backend_buffer_t buf_static_cpu;
|
|
31
|
+
ggml_backend_sched_t backend_sched = nullptr;
|
|
32
|
+
ggml_cgraph * allocated_graph = nullptr;
|
|
33
|
+
ggml_cgraph * allocated_graph_copy = nullptr;
|
|
34
|
+
struct ggml_context * ctx_static = nullptr;
|
|
35
|
+
struct ggml_context * ctx_static_cpu = nullptr;
|
|
36
|
+
struct ggml_context * ctx_compute = nullptr;
|
|
37
|
+
struct ggml_context * ctx_copy = nullptr;
|
|
38
|
+
ggml_backend_buffer_t buf_static = nullptr;
|
|
39
|
+
ggml_backend_buffer_t buf_static_cpu = nullptr;
|
|
40
40
|
std::mt19937 rng;
|
|
41
41
|
|
|
42
|
-
struct ggml_tensor * inputs;
|
|
43
|
-
struct ggml_tensor * outputs;
|
|
44
|
-
struct ggml_tensor * labels;
|
|
42
|
+
struct ggml_tensor * inputs = nullptr;
|
|
43
|
+
struct ggml_tensor * outputs = nullptr;
|
|
44
|
+
struct ggml_tensor * labels = nullptr;
|
|
45
45
|
|
|
46
|
-
struct ggml_tensor * loss;
|
|
47
|
-
struct ggml_tensor * pred;
|
|
48
|
-
struct ggml_tensor * ncorrect;
|
|
46
|
+
struct ggml_tensor * loss = nullptr;
|
|
47
|
+
struct ggml_tensor * pred = nullptr;
|
|
48
|
+
struct ggml_tensor * ncorrect = nullptr;
|
|
49
49
|
|
|
50
|
-
struct ggml_cgraph * gf;
|
|
51
|
-
struct ggml_cgraph * gb_grad;
|
|
52
|
-
struct ggml_cgraph * gb_opt;
|
|
50
|
+
struct ggml_cgraph * gf = nullptr;
|
|
51
|
+
struct ggml_cgraph * gb_grad = nullptr;
|
|
52
|
+
struct ggml_cgraph * gb_opt = nullptr;
|
|
53
53
|
|
|
54
|
-
int64_t iter;
|
|
55
|
-
int32_t opt_period;
|
|
56
|
-
int32_t opt_i;
|
|
57
|
-
bool loss_per_datapoint;
|
|
54
|
+
int64_t iter = 1;
|
|
55
|
+
int32_t opt_period = 1;
|
|
56
|
+
int32_t opt_i = 0;
|
|
57
|
+
bool loss_per_datapoint = false;
|
|
58
58
|
|
|
59
|
-
ggml_opt_get_optimizer_params get_opt_pars;
|
|
60
|
-
void * get_opt_pars_ud;
|
|
61
|
-
struct ggml_tensor * adamw_params;
|
|
59
|
+
ggml_opt_get_optimizer_params get_opt_pars = nullptr;
|
|
60
|
+
void * get_opt_pars_ud = nullptr;
|
|
61
|
+
struct ggml_tensor * adamw_params = nullptr;
|
|
62
62
|
};
|
|
63
63
|
|
|
64
64
|
struct ggml_opt_result {
|
|
@@ -67,8 +67,8 @@ struct ggml_opt_result {
|
|
|
67
67
|
std::vector<int32_t> pred;
|
|
68
68
|
int64_t ncorrect = 0;
|
|
69
69
|
|
|
70
|
-
|
|
71
|
-
|
|
70
|
+
int64_t opt_period = -1;
|
|
71
|
+
bool loss_per_datapoint = false;
|
|
72
72
|
};
|
|
73
73
|
|
|
74
74
|
// ====== Dataset ======
|
|
@@ -188,11 +188,11 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
|
|
|
188
188
|
}
|
|
189
189
|
|
|
190
190
|
struct ggml_opt_params ggml_opt_default_params(
|
|
191
|
-
ggml_backend_sched_t
|
|
192
|
-
struct ggml_context
|
|
193
|
-
struct ggml_tensor
|
|
194
|
-
struct ggml_tensor
|
|
195
|
-
enum ggml_opt_loss_type
|
|
191
|
+
ggml_backend_sched_t backend_sched,
|
|
192
|
+
struct ggml_context * ctx_compute,
|
|
193
|
+
struct ggml_tensor * inputs,
|
|
194
|
+
struct ggml_tensor * outputs,
|
|
195
|
+
enum ggml_opt_loss_type loss_type) {
|
|
196
196
|
return {
|
|
197
197
|
/*backend_sched =*/ backend_sched,
|
|
198
198
|
/*ctx_compute =*/ ctx_compute,
|
|
@@ -237,25 +237,33 @@ static ggml_tensor * map_tensor(std::map<ggml_tensor *, ggml_tensor *> & tensor_
|
|
|
237
237
|
return new_tensor;
|
|
238
238
|
}
|
|
239
239
|
|
|
240
|
-
static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph *
|
|
240
|
+
static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) {
|
|
241
241
|
std::map<ggml_tensor *, ggml_tensor *> tensor_map;
|
|
242
242
|
|
|
243
|
-
ggml_cgraph *
|
|
243
|
+
ggml_cgraph * dst = ggml_new_graph_custom(ctx, src->size, /*grads =*/ true);
|
|
244
244
|
|
|
245
|
-
for (int i = 0; i <
|
|
246
|
-
ggml_build_forward_expand(
|
|
245
|
+
for (int i = 0; i < src->n_leafs; i++) {
|
|
246
|
+
ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->leafs[i]));
|
|
247
247
|
}
|
|
248
|
-
|
|
249
|
-
|
|
248
|
+
GGML_ASSERT(dst->n_leafs == src->n_leafs);
|
|
249
|
+
for (int i = 0; i < src->n_nodes; i++) {
|
|
250
|
+
ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->nodes[i]));
|
|
250
251
|
}
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
const size_t
|
|
254
|
-
|
|
255
|
-
|
|
252
|
+
GGML_ASSERT(dst->n_nodes == src->n_nodes);
|
|
253
|
+
for (int i = 0; i < src->n_nodes; ++i) {
|
|
254
|
+
const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
|
|
255
|
+
const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);
|
|
256
|
+
|
|
257
|
+
GGML_ASSERT(igrad_src != GGML_HASHSET_FULL);
|
|
258
|
+
GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src));
|
|
259
|
+
GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL);
|
|
260
|
+
GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst));
|
|
261
|
+
|
|
262
|
+
dst->grads[igrad_dst] = src->grads[igrad_src];
|
|
263
|
+
dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
|
|
256
264
|
}
|
|
257
265
|
|
|
258
|
-
return
|
|
266
|
+
return dst;
|
|
259
267
|
}
|
|
260
268
|
|
|
261
269
|
static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) {
|
|
@@ -284,18 +292,13 @@ static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph
|
|
|
284
292
|
|
|
285
293
|
ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
|
|
286
294
|
ggml_opt_context_t result = new struct ggml_opt_context;
|
|
287
|
-
result->backend_sched
|
|
288
|
-
result->
|
|
289
|
-
result->
|
|
290
|
-
result->
|
|
291
|
-
result->
|
|
292
|
-
result->
|
|
293
|
-
result->
|
|
294
|
-
result->iter = 1;
|
|
295
|
-
result->opt_period = params.opt_period;
|
|
296
|
-
result->opt_i = 0;
|
|
297
|
-
result->get_opt_pars = params.get_opt_pars;
|
|
298
|
-
result->get_opt_pars_ud = params.get_opt_pars_ud;
|
|
295
|
+
result->backend_sched = params.backend_sched;
|
|
296
|
+
result->ctx_compute = params.ctx_compute;
|
|
297
|
+
result->inputs = params.inputs;
|
|
298
|
+
result->outputs = params.outputs;
|
|
299
|
+
result->opt_period = params.opt_period;
|
|
300
|
+
result->get_opt_pars = params.get_opt_pars;
|
|
301
|
+
result->get_opt_pars_ud = params.get_opt_pars_ud;
|
|
299
302
|
|
|
300
303
|
GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically");
|
|
301
304
|
GGML_ASSERT(result->opt_period >= 1);
|
|
@@ -348,7 +351,6 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
|
|
|
348
351
|
|
|
349
352
|
switch (params.loss_type) {
|
|
350
353
|
case GGML_OPT_LOSS_TYPE_MEAN: {
|
|
351
|
-
result->labels = nullptr;
|
|
352
354
|
result->loss = ggml_sum(result->ctx_static, result->outputs);
|
|
353
355
|
ggml_set_name(result->loss, "loss_sum");
|
|
354
356
|
const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
|
|
@@ -358,7 +360,6 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
|
|
|
358
360
|
break;
|
|
359
361
|
}
|
|
360
362
|
case GGML_OPT_LOSS_TYPE_SUM: {
|
|
361
|
-
result->labels = nullptr;
|
|
362
363
|
result->loss = ggml_sum(result->ctx_static, result->outputs);
|
|
363
364
|
ggml_set_name(result->loss, "loss_sum");
|
|
364
365
|
result->loss_per_datapoint = false;
|
|
@@ -413,14 +414,7 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
|
|
|
413
414
|
}
|
|
414
415
|
|
|
415
416
|
if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
|
|
416
|
-
result->gb_grad = nullptr;
|
|
417
|
-
result->gb_opt = nullptr;
|
|
418
|
-
|
|
419
417
|
result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
|
|
420
|
-
result->buf_static_cpu = nullptr;
|
|
421
|
-
|
|
422
|
-
ggml_opt_alloc_graph(result, result->gf);
|
|
423
|
-
|
|
424
418
|
return result;
|
|
425
419
|
}
|
|
426
420
|
|
|
@@ -429,14 +423,8 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
|
|
|
429
423
|
ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate);
|
|
430
424
|
|
|
431
425
|
if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) {
|
|
432
|
-
result->gb_opt = nullptr;
|
|
433
|
-
|
|
434
426
|
result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
|
|
435
|
-
result->buf_static_cpu = nullptr;
|
|
436
|
-
|
|
437
|
-
ggml_opt_alloc_graph(result, result->gb_grad);
|
|
438
427
|
ggml_graph_reset(result->gb_grad);
|
|
439
|
-
|
|
440
428
|
return result;
|
|
441
429
|
}
|
|
442
430
|
|
|
@@ -466,7 +454,6 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
|
|
|
466
454
|
|
|
467
455
|
result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type());
|
|
468
456
|
|
|
469
|
-
ggml_opt_alloc_graph(result, result->gb_opt);
|
|
470
457
|
ggml_graph_reset(result->gb_opt);
|
|
471
458
|
|
|
472
459
|
return result;
|
|
@@ -5220,15 +5220,6 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
|
|
|
5220
5220
|
{
|
|
5221
5221
|
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
|
|
5222
5222
|
} break;
|
|
5223
|
-
case GGML_TYPE_Q4_0_4_4:
|
|
5224
|
-
case GGML_TYPE_Q4_0_4_8:
|
|
5225
|
-
{
|
|
5226
|
-
VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x4, data, nbytes / sizeof(block_q4_0x4), 4);
|
|
5227
|
-
} break;
|
|
5228
|
-
case GGML_TYPE_Q4_0_8_8:
|
|
5229
|
-
{
|
|
5230
|
-
VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x8, data, nbytes / sizeof(block_q4_0x8), 8);
|
|
5231
|
-
} break;
|
|
5232
5223
|
|
|
5233
5224
|
case GGML_TYPE_I8:
|
|
5234
5225
|
case GGML_TYPE_I16:
|
|
@@ -1,10 +1,8 @@
|
|
|
1
1
|
message(STATUS "Using RPC backend")
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
target_link_libraries(ggml-rpc PRIVATE ggml-base)
|
|
7
|
-
target_include_directories(ggml-rpc PRIVATE . ..)
|
|
3
|
+
ggml_add_backend_library(ggml-rpc
|
|
4
|
+
ggml-rpc.cpp
|
|
5
|
+
)
|
|
8
6
|
|
|
9
7
|
if (WIN32)
|
|
10
8
|
target_link_libraries(ggml-rpc PRIVATE ws2_32)
|
|
@@ -1369,8 +1369,9 @@ static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
|
|
|
1369
1369
|
|
|
1370
1370
|
ggml_backend_reg_t ggml_backend_rpc_reg(void) {
|
|
1371
1371
|
static struct ggml_backend_reg ggml_backend_rpc_reg = {
|
|
1372
|
-
/* .
|
|
1373
|
-
/* .
|
|
1372
|
+
/* .api_version = */ GGML_BACKEND_API_VERSION,
|
|
1373
|
+
/* .iface = */ ggml_backend_rpc_reg_i,
|
|
1374
|
+
/* .context = */ NULL,
|
|
1374
1375
|
};
|
|
1375
1376
|
|
|
1376
1377
|
return &ggml_backend_rpc_reg;
|
|
@@ -1401,3 +1402,5 @@ ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
|
|
|
1401
1402
|
|
|
1402
1403
|
return dev;
|
|
1403
1404
|
}
|
|
1405
|
+
|
|
1406
|
+
GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)
|
|
@@ -16,12 +16,10 @@ endif()
|
|
|
16
16
|
message(STATUS "SYCL found")
|
|
17
17
|
#todo: AOT
|
|
18
18
|
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
target_link_libraries(ggml-sycl PRIVATE ggml-base)
|
|
24
|
-
target_include_directories(ggml-sycl PRIVATE . ..)
|
|
19
|
+
ggml_add_backend_library(ggml-sycl
|
|
20
|
+
ggml-sycl.cpp
|
|
21
|
+
../../include/ggml-sycl.h
|
|
22
|
+
)
|
|
25
23
|
|
|
26
24
|
if (GGML_SYCL_F16)
|
|
27
25
|
if (GGML_SYCL_TARGET STREQUAL "AMD")
|
|
@@ -70,12 +68,17 @@ else()
|
|
|
70
68
|
target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
|
|
71
69
|
elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
|
|
72
70
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
|
|
73
|
-
|
|
71
|
+
add_compile_definitions(GGML_SYCL_NVIDIA)
|
|
72
|
+
target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl_blas_cublas)
|
|
74
73
|
elseif (GGML_SYCL_TARGET STREQUAL "AMD")
|
|
75
|
-
if (
|
|
76
|
-
message(ERROR "Can't enable SYCL hip backend,
|
|
74
|
+
if (NOT GGML_SYCL_DEVICE_ARCH)
|
|
75
|
+
message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.")
|
|
77
76
|
endif()
|
|
78
|
-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=amdgcn-amd-amdhsa
|
|
77
|
+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=amdgcn-amd-amdhsa")
|
|
79
78
|
target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl)
|
|
80
79
|
endif()
|
|
80
|
+
|
|
81
|
+
if (GGML_SYCL_DEVICE_ARCH)
|
|
82
|
+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}")
|
|
83
|
+
endif()
|
|
81
84
|
endif()
|
|
@@ -11,6 +11,7 @@
|
|
|
11
11
|
//
|
|
12
12
|
|
|
13
13
|
#include "common.hpp"
|
|
14
|
+
#include "ggml-impl.h"
|
|
14
15
|
|
|
15
16
|
int get_current_device_id() {
|
|
16
17
|
return dpct::dev_mgr::instance().current_device_id();
|
|
@@ -28,11 +29,7 @@ void* ggml_sycl_host_malloc(size_t size) try {
|
|
|
28
29
|
|
|
29
30
|
if (err != 0) {
|
|
30
31
|
// clear the error
|
|
31
|
-
|
|
32
|
-
stderr,
|
|
33
|
-
"WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
|
|
34
|
-
size / 1024.0 / 1024.0,
|
|
35
|
-
"syclGetErrorString is not supported");
|
|
32
|
+
GGML_LOG_ERROR("WARNING: failed to allocate %.2f MB of pinned memory: %s\n", size / 1024.0 / 1024.0, "syclGetErrorString is not supported");
|
|
36
33
|
return nullptr;
|
|
37
34
|
}
|
|
38
35
|
|
|
@@ -66,18 +63,12 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block
|
|
|
66
63
|
void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
|
67
64
|
const ggml_tensor *src1, ggml_tensor *dst,
|
|
68
65
|
const ggml_sycl_op_flatten_t op) try {
|
|
69
|
-
const int64_t nrows0 = ggml_nrows(src0);
|
|
70
66
|
|
|
71
67
|
const bool use_src1 = src1 != nullptr;
|
|
72
|
-
const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
|
|
73
68
|
|
|
74
69
|
GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
|
|
75
70
|
GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
|
|
76
71
|
|
|
77
|
-
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
|
78
|
-
ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
|
|
79
|
-
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
|
80
|
-
|
|
81
72
|
// dd = data device
|
|
82
73
|
float * src0_ddf = (float *) src0->data;
|
|
83
74
|
float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
|
|
@@ -47,7 +47,7 @@ static void concat_f32_dim1(const float *x, const float *y, float *dst,
|
|
|
47
47
|
// operation
|
|
48
48
|
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
|
|
49
49
|
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
|
50
|
-
if (item_ct1.get_group(1) < ne01) { // src0
|
|
50
|
+
if (item_ct1.get_group(1) < (size_t) ne01) { // src0
|
|
51
51
|
int offset_src =
|
|
52
52
|
nidx + item_ct1.get_group(1) * ne0 + item_ct1.get_group(0) * ne0 * ne01;
|
|
53
53
|
dst[offset_dst] = x[offset_src];
|
|
@@ -70,7 +70,7 @@ static void concat_f32_dim2(const float *x, const float *y, float *dst,
|
|
|
70
70
|
// operation
|
|
71
71
|
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
|
|
72
72
|
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
|
73
|
-
if (item_ct1.get_group(0) < ne02) { // src0
|
|
73
|
+
if (item_ct1.get_group(0) < (size_t) ne02) { // src0
|
|
74
74
|
int offset_src = nidx + item_ct1.get_group(1) * ne0 +
|
|
75
75
|
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
|
76
76
|
dst[offset_dst] = x[offset_src];
|
|
@@ -424,7 +424,7 @@ static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y,
|
|
|
424
424
|
const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
|
|
425
425
|
|
|
426
426
|
// make each work-item deal with more elements since sycl global range can not exceed max int
|
|
427
|
-
const src_t * x = (src_t *) vx;
|
|
427
|
+
const src_t * x = (const src_t *) vx;
|
|
428
428
|
for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) {
|
|
429
429
|
y[i] = x[i];
|
|
430
430
|
}
|
|
@@ -1015,9 +1015,9 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
|
|
|
1015
1015
|
break;
|
|
1016
1016
|
}
|
|
1017
1017
|
|
|
1018
|
-
(
|
|
1019
|
-
(
|
|
1020
|
-
(
|
|
1021
|
-
(
|
|
1022
|
-
(
|
|
1018
|
+
GGML_UNUSED(src1);
|
|
1019
|
+
GGML_UNUSED(dst);
|
|
1020
|
+
GGML_UNUSED(src1_ddq_i);
|
|
1021
|
+
GGML_UNUSED(src1_ncols);
|
|
1022
|
+
GGML_UNUSED(src1_padded_row_size);
|
|
1023
1023
|
}
|
|
@@ -1237,7 +1237,7 @@ namespace dpct
|
|
|
1237
1237
|
|
|
1238
1238
|
std::map<byte_t *, allocation>::iterator get_map_iterator(const void *ptr)
|
|
1239
1239
|
{
|
|
1240
|
-
auto it = m_map.upper_bound((byte_t
|
|
1240
|
+
auto it = m_map.upper_bound(const_cast<byte_t *>(reinterpret_cast<const byte_t *>(ptr)));
|
|
1241
1241
|
if (it == m_map.end())
|
|
1242
1242
|
{
|
|
1243
1243
|
// Not a virtual pointer.
|
|
@@ -1689,9 +1689,14 @@ namespace dpct
|
|
|
1689
1689
|
auto data_a = get_memory<const Ta>(a);
|
|
1690
1690
|
auto data_b = get_memory<const Tb>(b);
|
|
1691
1691
|
auto data_c = get_memory<Tc>(c);
|
|
1692
|
-
|
|
1693
|
-
|
|
1694
|
-
|
|
1692
|
+
#ifdef GGML_SYCL_NVIDIA
|
|
1693
|
+
oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
|
|
1694
|
+
a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
|
|
1695
|
+
beta_value, data_c, ldc);
|
|
1696
|
+
#else
|
|
1697
|
+
oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
|
|
1698
|
+
beta_value, data_c, ldc);
|
|
1699
|
+
#endif
|
|
1695
1700
|
}
|
|
1696
1701
|
|
|
1697
1702
|
template <typename VecT, class BinaryOperation, class = void>
|
|
@@ -1754,14 +1759,22 @@ namespace dpct
|
|
|
1754
1759
|
matrix_info->ld_info[2] = ldc;
|
|
1755
1760
|
matrix_info->groupsize_info = batch_size;
|
|
1756
1761
|
|
|
1762
|
+
#ifdef GGML_SYCL_NVIDIA
|
|
1763
|
+
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
|
1764
|
+
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
|
|
1765
|
+
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
|
|
1766
|
+
matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
|
|
1767
|
+
matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
|
1768
|
+
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
|
|
1769
|
+
&(matrix_info->groupsize_info));
|
|
1770
|
+
#else
|
|
1757
1771
|
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
|
1758
|
-
q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
|
|
1759
|
-
matrix_info->size_info, matrix_info->size_info +
|
|
1760
|
-
|
|
1761
|
-
reinterpret_cast<
|
|
1762
|
-
reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
|
1763
|
-
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
|
|
1772
|
+
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
|
|
1773
|
+
matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
|
|
1774
|
+
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
|
1775
|
+
matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
|
|
1764
1776
|
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
|
1777
|
+
#endif
|
|
1765
1778
|
|
|
1766
1779
|
q.submit([&](sycl::handler &cgh)
|
|
1767
1780
|
{
|
|
@@ -1783,10 +1796,16 @@ namespace dpct
|
|
|
1783
1796
|
auto data_a = get_memory<const Ta>(a);
|
|
1784
1797
|
auto data_b = get_memory<const Tb>(b);
|
|
1785
1798
|
auto data_c = get_memory<Tc>(c);
|
|
1799
|
+
#ifdef GGML_SYCL_NVIDIA
|
|
1786
1800
|
oneapi::mkl::blas::column_major::gemm_batch(
|
|
1787
|
-
q, a_trans, b_trans, m, n, k,
|
|
1788
|
-
stride_a, data_b, ldb, stride_b, beta_value,
|
|
1789
|
-
|
|
1801
|
+
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
|
|
1802
|
+
alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
|
|
1803
|
+
batch_size);
|
|
1804
|
+
#else
|
|
1805
|
+
oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
|
|
1806
|
+
stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
|
|
1807
|
+
stride_c, batch_size);
|
|
1808
|
+
#endif
|
|
1790
1809
|
}
|
|
1791
1810
|
|
|
1792
1811
|
} // namespace detail
|