@fugood/llama.node 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +1 -10
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +6 -4
- package/src/LlamaCompletionWorker.cpp +6 -6
- package/src/LlamaContext.cpp +7 -9
- package/src/common.hpp +2 -1
- package/src/llama.cpp/.github/workflows/build.yml +98 -24
- package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
- package/src/llama.cpp/.github/workflows/docker.yml +43 -34
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
- package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
- package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
- package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +7 -0
- package/src/llama.cpp/CMakeLists.txt +20 -8
- package/src/llama.cpp/common/CMakeLists.txt +12 -10
- package/src/llama.cpp/common/arg.cpp +2006 -0
- package/src/llama.cpp/common/arg.h +77 -0
- package/src/llama.cpp/common/common.cpp +496 -1632
- package/src/llama.cpp/common/common.h +161 -63
- package/src/llama.cpp/common/console.cpp +3 -0
- package/src/llama.cpp/common/log.cpp +401 -0
- package/src/llama.cpp/common/log.h +66 -698
- package/src/llama.cpp/common/ngram-cache.cpp +3 -0
- package/src/llama.cpp/common/sampling.cpp +348 -350
- package/src/llama.cpp/common/sampling.h +62 -139
- package/src/llama.cpp/common/stb_image.h +5990 -6398
- package/src/llama.cpp/common/train.cpp +2 -0
- package/src/llama.cpp/docs/build.md +36 -1
- package/src/llama.cpp/examples/CMakeLists.txt +0 -1
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
- package/src/llama.cpp/examples/batched/batched.cpp +39 -55
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
- package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
- package/src/llama.cpp/examples/infill/infill.cpp +117 -132
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +685 -150
- package/src/llama.cpp/examples/llava/clip.h +11 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
- package/src/llama.cpp/examples/llava/llava.cpp +110 -24
- package/src/llama.cpp/examples/llava/llava.h +2 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
- package/src/llama.cpp/examples/llava/requirements.txt +1 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
- package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
- package/src/llama.cpp/examples/main/main.cpp +210 -262
- package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
- package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
- package/src/llama.cpp/examples/server/server.cpp +1027 -1073
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
- package/src/llama.cpp/examples/server/utils.hpp +107 -105
- package/src/llama.cpp/examples/simple/simple.cpp +35 -41
- package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
- package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
- package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
- package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
- package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
- package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
- package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
- package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
- package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
- package/src/llama.cpp/ggml/include/ggml.h +293 -186
- package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
- package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
- package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
- package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
- package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
- package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
- package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
- package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
- package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
- package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
- package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
- package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
- package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
- package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
- package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
- package/src/llama.cpp/include/llama.h +241 -264
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
- package/src/llama.cpp/src/llama-grammar.cpp +721 -122
- package/src/llama.cpp/src/llama-grammar.h +120 -15
- package/src/llama.cpp/src/llama-impl.h +156 -1
- package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
- package/src/llama.cpp/src/llama-sampling.h +20 -47
- package/src/llama.cpp/src/llama-vocab.cpp +343 -120
- package/src/llama.cpp/src/llama-vocab.h +33 -17
- package/src/llama.cpp/src/llama.cpp +4247 -1525
- package/src/llama.cpp/src/unicode-data.cpp +6 -4
- package/src/llama.cpp/src/unicode-data.h +4 -4
- package/src/llama.cpp/src/unicode.cpp +15 -7
- package/src/llama.cpp/tests/CMakeLists.txt +3 -0
- package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
- package/src/llama.cpp/tests/test-barrier.cpp +93 -0
- package/src/llama.cpp/tests/test-grad0.cpp +187 -70
- package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
- package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
- package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
- package/src/llama.cpp/tests/test-log.cpp +39 -0
- package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
- package/src/llama.cpp/tests/test-rope.cpp +1 -1
- package/src/llama.cpp/tests/test-sampling.cpp +157 -98
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
- package/patches/llama.patch +0 -22
- package/src/llama.cpp/.github/workflows/bench.yml +0 -310
- package/src/llama.cpp/common/grammar-parser.cpp +0 -536
- package/src/llama.cpp/common/grammar-parser.h +0 -29
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
|
@@ -15,9 +15,9 @@
|
|
|
15
15
|
|
|
16
16
|
#include "common.hpp"
|
|
17
17
|
|
|
18
|
-
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
|
|
18
|
+
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
|
|
19
19
|
|
|
20
|
-
static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib,
|
|
20
|
+
static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
|
|
21
21
|
const int iqs, dfloat2 &v) {
|
|
22
22
|
const block_q4_0 * x = (const block_q4_0 *) vx;
|
|
23
23
|
|
|
@@ -40,7 +40,7 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib,
|
|
|
40
40
|
#endif // GGML_SYCL_F16
|
|
41
41
|
}
|
|
42
42
|
|
|
43
|
-
static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib,
|
|
43
|
+
static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib,
|
|
44
44
|
const int iqs, dfloat2 &v) {
|
|
45
45
|
const block_q4_1 * x = (const block_q4_1 *) vx;
|
|
46
46
|
|
|
@@ -55,16 +55,16 @@ static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib,
|
|
|
55
55
|
#ifdef GGML_SYCL_F16
|
|
56
56
|
// v = v * {d, d};
|
|
57
57
|
// v = v + {m, m};
|
|
58
|
-
v.s0() = (v.s0()
|
|
59
|
-
v.s1() = (v.s1()
|
|
58
|
+
v.s0() = sycl::fma(v.s0(), d, m);
|
|
59
|
+
v.s1() = sycl::fma(v.s1(), d, m);
|
|
60
60
|
|
|
61
61
|
#else
|
|
62
|
-
v.x() = (v.x()
|
|
63
|
-
v.y() = (v.y()
|
|
62
|
+
v.x() = sycl::fma(v.x(), d, m);
|
|
63
|
+
v.y() = sycl::fma(v.y(), d, m);
|
|
64
64
|
#endif // GGML_SYCL_F16
|
|
65
65
|
}
|
|
66
66
|
|
|
67
|
-
static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib,
|
|
67
|
+
static __dpct_inline__ void dequantize_q5_0(const void *vx, const int64_t ib,
|
|
68
68
|
const int iqs, dfloat2 &v) {
|
|
69
69
|
const block_q5_0 * x = (const block_q5_0 *) vx;
|
|
70
70
|
|
|
@@ -91,7 +91,7 @@ static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib,
|
|
|
91
91
|
#endif // GGML_SYCL_F16
|
|
92
92
|
}
|
|
93
93
|
|
|
94
|
-
static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib,
|
|
94
|
+
static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib,
|
|
95
95
|
const int iqs, dfloat2 &v) {
|
|
96
96
|
const block_q5_1 * x = (const block_q5_1 *) vx;
|
|
97
97
|
|
|
@@ -110,15 +110,15 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib,
|
|
|
110
110
|
#ifdef GGML_SYCL_F16
|
|
111
111
|
// v = v * {d, d};
|
|
112
112
|
// v = v + {m, m};
|
|
113
|
-
v.s0() = (v.s0()
|
|
114
|
-
v.s1() = (v.s1()
|
|
113
|
+
v.s0() = sycl::fma(v.s0(), d, m);
|
|
114
|
+
v.s1() = sycl::fma(v.s1(), d, m);
|
|
115
115
|
#else
|
|
116
|
-
v.x() = (v.x()
|
|
117
|
-
v.y() = (v.y()
|
|
116
|
+
v.x() = sycl::fma(v.x(), d, m);
|
|
117
|
+
v.y() = sycl::fma(v.y(), d, m);
|
|
118
118
|
#endif // GGML_SYCL_F16
|
|
119
119
|
}
|
|
120
120
|
|
|
121
|
-
static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib,
|
|
121
|
+
static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib,
|
|
122
122
|
const int iqs, dfloat2 &v) {
|
|
123
123
|
const block_q8_0 * x = (const block_q8_0 *) vx;
|
|
124
124
|
|
|
@@ -138,16 +138,16 @@ static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib,
|
|
|
138
138
|
}
|
|
139
139
|
|
|
140
140
|
template<typename dst_t>
|
|
141
|
-
static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
|
141
|
+
static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
|
|
142
142
|
const sycl::nd_item<3> &item_ct1) {
|
|
143
143
|
|
|
144
|
-
const
|
|
144
|
+
const int64_t i = item_ct1.get_group(2);
|
|
145
145
|
|
|
146
146
|
// assume 32 threads
|
|
147
|
-
const
|
|
148
|
-
const
|
|
149
|
-
const
|
|
150
|
-
const
|
|
147
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
148
|
+
const int64_t il = tid/8;
|
|
149
|
+
const int64_t ir = tid%8;
|
|
150
|
+
const int64_t ib = 8*i + ir;
|
|
151
151
|
if (ib >= nb32) {
|
|
152
152
|
return;
|
|
153
153
|
}
|
|
@@ -168,16 +168,16 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restri
|
|
|
168
168
|
}
|
|
169
169
|
|
|
170
170
|
template<typename dst_t>
|
|
171
|
-
static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
|
171
|
+
static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
|
|
172
172
|
const sycl::nd_item<3> &item_ct1) {
|
|
173
173
|
|
|
174
|
-
const
|
|
174
|
+
const int64_t i = item_ct1.get_group(2);
|
|
175
175
|
|
|
176
176
|
// assume 32 threads
|
|
177
|
-
const
|
|
178
|
-
const
|
|
179
|
-
const
|
|
180
|
-
const
|
|
177
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
178
|
+
const int64_t il = tid/8;
|
|
179
|
+
const int64_t ir = tid%8;
|
|
180
|
+
const int64_t ib = 8*i + ir;
|
|
181
181
|
if (ib >= nb32) {
|
|
182
182
|
return;
|
|
183
183
|
}
|
|
@@ -203,14 +203,14 @@ template<typename dst_t>
|
|
|
203
203
|
static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
|
204
204
|
const sycl::nd_item<3> &item_ct1) {
|
|
205
205
|
|
|
206
|
-
const
|
|
206
|
+
const int64_t i = item_ct1.get_group(2);
|
|
207
207
|
const block_q2_K * x = (const block_q2_K *) vx;
|
|
208
208
|
|
|
209
|
-
const
|
|
209
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
210
210
|
#if QK_K == 256
|
|
211
|
-
const
|
|
212
|
-
const
|
|
213
|
-
const
|
|
211
|
+
const int64_t n = tid/32;
|
|
212
|
+
const int64_t l = tid - 32*n;
|
|
213
|
+
const int64_t is = 8*n + l/16;
|
|
214
214
|
|
|
215
215
|
const uint8_t q = x[i].qs[32*n + l];
|
|
216
216
|
dst_t * y = yy + i*QK_K + 128*n;
|
|
@@ -222,8 +222,8 @@ static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restri
|
|
|
222
222
|
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
|
|
223
223
|
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
|
|
224
224
|
#else
|
|
225
|
-
const
|
|
226
|
-
const
|
|
225
|
+
const int64_t is = tid/16; // 0 or 1
|
|
226
|
+
const int64_t il = tid%16; // 0...15
|
|
227
227
|
const uint8_t q = x[i].qs[il] >> (2*is);
|
|
228
228
|
dst_t * y = yy + i*QK_K + 16*is + il;
|
|
229
229
|
|
|
@@ -239,19 +239,19 @@ template<typename dst_t>
|
|
|
239
239
|
static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
|
240
240
|
const sycl::nd_item<3> &item_ct1) {
|
|
241
241
|
|
|
242
|
-
const
|
|
242
|
+
const int64_t i = item_ct1.get_group(2);
|
|
243
243
|
const block_q3_K * x = (const block_q3_K *) vx;
|
|
244
244
|
|
|
245
245
|
#if QK_K == 256
|
|
246
|
-
const
|
|
247
|
-
const
|
|
248
|
-
const
|
|
249
|
-
const
|
|
250
|
-
const
|
|
251
|
-
const
|
|
246
|
+
const int64_t r = item_ct1.get_local_id(2) / 4;
|
|
247
|
+
const int64_t tid = r/2;
|
|
248
|
+
const int64_t is0 = r%2;
|
|
249
|
+
const int64_t l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
|
|
250
|
+
const int64_t n = tid / 4;
|
|
251
|
+
const int64_t j = tid - 4*n;
|
|
252
252
|
|
|
253
253
|
uint8_t m = 1 << (4*n + j);
|
|
254
|
-
|
|
254
|
+
int64_t is = 8*n + 2*j + is0;
|
|
255
255
|
int shift = 2*j;
|
|
256
256
|
|
|
257
257
|
int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
|
|
@@ -267,11 +267,11 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
|
|
|
267
267
|
|
|
268
268
|
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
|
|
269
269
|
#else
|
|
270
|
-
const
|
|
271
|
-
const
|
|
272
|
-
const
|
|
273
|
-
const
|
|
274
|
-
const
|
|
270
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
271
|
+
const int64_t is = tid/16; // 0 or 1
|
|
272
|
+
const int64_t il = tid%16; // 0...15
|
|
273
|
+
const int64_t im = il/8; // 0...1
|
|
274
|
+
const int64_t in = il%8; // 0...7
|
|
275
275
|
|
|
276
276
|
dst_t * y = yy + i*QK_K + 16*is + il;
|
|
277
277
|
|
|
@@ -307,15 +307,15 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
|
|
|
307
307
|
uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
|
|
308
308
|
const block_q4_K * x = (const block_q4_K *) vx;
|
|
309
309
|
|
|
310
|
-
const
|
|
310
|
+
const int64_t i = item_ct1.get_group(2);
|
|
311
311
|
|
|
312
312
|
#if QK_K == 256
|
|
313
313
|
// assume 32 threads
|
|
314
|
-
const
|
|
315
|
-
const
|
|
316
|
-
const
|
|
317
|
-
const
|
|
318
|
-
const
|
|
314
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
315
|
+
const int64_t il = tid/8;
|
|
316
|
+
const int64_t ir = tid%8;
|
|
317
|
+
const int64_t is = 2*il;
|
|
318
|
+
const int64_t n = 4;
|
|
319
319
|
|
|
320
320
|
dst_t * y = yy + i*QK_K + 64*il + n*ir;
|
|
321
321
|
|
|
@@ -341,7 +341,7 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
|
|
|
341
341
|
y[l +32] = d2 * (q_vec[l] >> 4) - m2;
|
|
342
342
|
}
|
|
343
343
|
#else
|
|
344
|
-
const
|
|
344
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
345
345
|
const uint8_t * q = x[i].qs;
|
|
346
346
|
dst_t * y = yy + i*QK_K;
|
|
347
347
|
const float d = (float)x[i].dm[0];
|
|
@@ -356,14 +356,14 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri
|
|
|
356
356
|
const sycl::nd_item<3> &item_ct1) {
|
|
357
357
|
const block_q5_K * x = (const block_q5_K *) vx;
|
|
358
358
|
|
|
359
|
-
const
|
|
359
|
+
const int64_t i = item_ct1.get_group(2);
|
|
360
360
|
|
|
361
361
|
#if QK_K == 256
|
|
362
362
|
// assume 64 threads - this is very slightly better than the one below
|
|
363
|
-
const
|
|
364
|
-
const
|
|
365
|
-
const
|
|
366
|
-
const
|
|
363
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
364
|
+
const int64_t il = tid/16; // il is in 0...3
|
|
365
|
+
const int64_t ir = tid%16; // ir is in 0...15
|
|
366
|
+
const int64_t is = 2*il; // is is in 0...6
|
|
367
367
|
|
|
368
368
|
dst_t * y = yy + i*QK_K + 64*il + 2*ir;
|
|
369
369
|
|
|
@@ -386,11 +386,11 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri
|
|
|
386
386
|
y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
|
|
387
387
|
y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
|
|
388
388
|
#else
|
|
389
|
-
const
|
|
389
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
390
390
|
const uint8_t q = x[i].qs[tid];
|
|
391
|
-
const
|
|
392
|
-
const
|
|
393
|
-
const
|
|
391
|
+
const int64_t im = tid/8; // 0...3
|
|
392
|
+
const int64_t in = tid%8; // 0...7
|
|
393
|
+
const int64_t is = tid/16; // 0 or 1
|
|
394
394
|
const uint8_t h = x[i].qh[in] >> im;
|
|
395
395
|
const float d = x[i].d;
|
|
396
396
|
dst_t * y = yy + i*QK_K + tid;
|
|
@@ -404,14 +404,14 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
|
|
|
404
404
|
const sycl::nd_item<3> &item_ct1) {
|
|
405
405
|
const block_q6_K * x = (const block_q6_K *) vx;
|
|
406
406
|
|
|
407
|
-
const
|
|
407
|
+
const int64_t i = item_ct1.get_group(2);
|
|
408
408
|
#if QK_K == 256
|
|
409
409
|
|
|
410
410
|
// assume 64 threads - this is very slightly better than the one below
|
|
411
|
-
const
|
|
412
|
-
const
|
|
413
|
-
const
|
|
414
|
-
const
|
|
411
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
412
|
+
const int64_t ip = tid/32; // ip is 0 or 1
|
|
413
|
+
const int64_t il = tid - 32*ip; // 0...32
|
|
414
|
+
const int64_t is = 8*ip + il/16;
|
|
415
415
|
|
|
416
416
|
dst_t * y = yy + i*QK_K + 128*ip + il;
|
|
417
417
|
|
|
@@ -428,9 +428,9 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
|
|
|
428
428
|
#else
|
|
429
429
|
|
|
430
430
|
// assume 32 threads
|
|
431
|
-
const
|
|
432
|
-
const
|
|
433
|
-
const
|
|
431
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
432
|
+
const int64_t ip = tid/16; // 0 or 1
|
|
433
|
+
const int64_t il = tid - 16*ip; // 0...15
|
|
434
434
|
|
|
435
435
|
dst_t * y = yy + i*QK_K + 16*ip + il;
|
|
436
436
|
|
|
@@ -452,13 +452,13 @@ static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __res
|
|
|
452
452
|
const uint8_t *ksigns_iq2xs_ptr,
|
|
453
453
|
const uint8_t *kmask_iq2xs_ptr) {
|
|
454
454
|
|
|
455
|
-
const
|
|
455
|
+
const int64_t i = item_ct1.get_group(2);
|
|
456
456
|
const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
|
|
457
457
|
|
|
458
|
-
const
|
|
458
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
459
459
|
#if QK_K == 256
|
|
460
|
-
const
|
|
461
|
-
const
|
|
460
|
+
const int64_t il = tid/8; // 0...3
|
|
461
|
+
const int64_t ib = tid%8; // 0...7
|
|
462
462
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
|
463
463
|
const uint16_t * q2 = x[i].qs + 4*ib;
|
|
464
464
|
const uint8_t * aux8 = (const uint8_t *)q2;
|
|
@@ -480,13 +480,13 @@ static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __rest
|
|
|
480
480
|
const uint8_t *ksigns_iq2xs,
|
|
481
481
|
const uint8_t *kmask_iq2xs) {
|
|
482
482
|
|
|
483
|
-
const
|
|
483
|
+
const int64_t i = item_ct1.get_group(2);
|
|
484
484
|
const block_iq2_xs * x = (const block_iq2_xs *) vx;
|
|
485
485
|
|
|
486
|
-
const
|
|
486
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
487
487
|
#if QK_K == 256
|
|
488
|
-
const
|
|
489
|
-
const
|
|
488
|
+
const int64_t il = tid/8; // 0...3
|
|
489
|
+
const int64_t ib = tid%8; // 0...7
|
|
490
490
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
|
491
491
|
const uint16_t * q2 = x[i].qs + 4*ib;
|
|
492
492
|
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
|
|
@@ -504,13 +504,13 @@ __dpct_inline__ static void
|
|
|
504
504
|
dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
|
|
505
505
|
const sycl::nd_item<3> &item_ct1) {
|
|
506
506
|
|
|
507
|
-
const
|
|
507
|
+
const int64_t i = item_ct1.get_group(2);
|
|
508
508
|
const block_iq2_s * x = (const block_iq2_s *) vx;
|
|
509
509
|
|
|
510
|
-
const
|
|
510
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
511
511
|
#if QK_K == 256
|
|
512
|
-
const
|
|
513
|
-
const
|
|
512
|
+
const int64_t il = tid/8; // 0...3
|
|
513
|
+
const int64_t ib = tid%8; // 0...7
|
|
514
514
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
|
515
515
|
const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
|
|
516
516
|
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
|
|
@@ -532,13 +532,13 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __res
|
|
|
532
532
|
const uint8_t *ksigns_iq2xs,
|
|
533
533
|
const uint8_t *kmask_iq2xs) {
|
|
534
534
|
|
|
535
|
-
const
|
|
535
|
+
const int64_t i = item_ct1.get_group(2);
|
|
536
536
|
const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
|
|
537
537
|
|
|
538
|
-
const
|
|
538
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
539
539
|
#if QK_K == 256
|
|
540
|
-
const
|
|
541
|
-
const
|
|
540
|
+
const int64_t il = tid/8; // 0...3
|
|
541
|
+
const int64_t ib = tid%8; // 0...7
|
|
542
542
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
|
543
543
|
const uint8_t * q3 = x[i].qs + 8*ib;
|
|
544
544
|
const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
|
|
@@ -563,13 +563,13 @@ dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
|
|
|
563
563
|
const sycl::nd_item<3> &item_ct1,
|
|
564
564
|
const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) {
|
|
565
565
|
|
|
566
|
-
const
|
|
566
|
+
const int64_t i = item_ct1.get_group(2);
|
|
567
567
|
const block_iq3_s * x = (const block_iq3_s *) vx;
|
|
568
568
|
|
|
569
|
-
const
|
|
569
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
570
570
|
#if QK_K == 256
|
|
571
|
-
const
|
|
572
|
-
const
|
|
571
|
+
const int64_t il = tid/8; // 0...3
|
|
572
|
+
const int64_t ib = tid%8; // 0...7
|
|
573
573
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
|
574
574
|
const uint8_t * qs = x[i].qs + 8*ib;
|
|
575
575
|
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
|
|
@@ -593,13 +593,13 @@ dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
|
|
|
593
593
|
const sycl::nd_item<3> &item_ct1,
|
|
594
594
|
const uint32_t *iq1s_grid_gpu) {
|
|
595
595
|
|
|
596
|
-
const
|
|
596
|
+
const int64_t i = item_ct1.get_group(2);
|
|
597
597
|
const block_iq1_s * x = (const block_iq1_s *) vx;
|
|
598
598
|
|
|
599
|
-
const
|
|
599
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
600
600
|
#if QK_K == 256
|
|
601
|
-
const
|
|
602
|
-
const
|
|
601
|
+
const int64_t il = tid/8; // 0...3
|
|
602
|
+
const int64_t ib = tid%8; // 0...7
|
|
603
603
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
|
604
604
|
const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
|
|
605
605
|
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
|
|
@@ -623,13 +623,13 @@ dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
|
|
|
623
623
|
const sycl::nd_item<3> &item_ct1,
|
|
624
624
|
const uint32_t *iq1s_grid_gpu) {
|
|
625
625
|
|
|
626
|
-
const
|
|
626
|
+
const int64_t i = item_ct1.get_group(2);
|
|
627
627
|
const block_iq1_m * x = (const block_iq1_m *) vx;
|
|
628
628
|
|
|
629
|
-
const
|
|
629
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
630
630
|
#if QK_K == 256
|
|
631
|
-
const
|
|
632
|
-
const
|
|
631
|
+
const int64_t il = tid/8; // 0...3
|
|
632
|
+
const int64_t ib = tid%8; // 0...7
|
|
633
633
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
|
634
634
|
const uint16_t * sc = (const uint16_t *)x[i].scales;
|
|
635
635
|
iq1m_scale_t scale;
|
|
@@ -656,12 +656,12 @@ __dpct_inline__ static void
|
|
|
656
656
|
dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy,
|
|
657
657
|
const sycl::nd_item<3> &item_ct1) {
|
|
658
658
|
|
|
659
|
-
const
|
|
659
|
+
const int64_t i = item_ct1.get_group(2);
|
|
660
660
|
const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
|
|
661
661
|
|
|
662
|
-
const
|
|
663
|
-
const
|
|
664
|
-
const
|
|
662
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
663
|
+
const int64_t il = tid/8; // 0...3
|
|
664
|
+
const int64_t ib = tid%8; // 0...7
|
|
665
665
|
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
|
|
666
666
|
const uint8_t * q4 = x[ib].qs + 4*il;
|
|
667
667
|
const float d = (float)x[ib].d;
|
|
@@ -678,12 +678,12 @@ template <typename dst_t>
|
|
|
678
678
|
__dpct_inline__ static void
|
|
679
679
|
dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
|
|
680
680
|
const sycl::nd_item<3> &item_ct1) {
|
|
681
|
-
const
|
|
681
|
+
const int64_t i = item_ct1.get_group(2);
|
|
682
682
|
const block_iq4_xs * x = (const block_iq4_xs *)vx;
|
|
683
683
|
|
|
684
|
-
const
|
|
685
|
-
const
|
|
686
|
-
const
|
|
684
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
685
|
+
const int64_t il = tid/8; // 0...3
|
|
686
|
+
const int64_t ib = tid%8; // 0...7
|
|
687
687
|
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
|
|
688
688
|
const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
|
|
689
689
|
const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
|
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
#include "presets.hpp"
|
|
5
5
|
|
|
6
6
|
|
|
7
|
-
static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
|
7
|
+
static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
|
|
8
8
|
const sycl::half *x = (const sycl::half *)vx;
|
|
9
9
|
|
|
10
10
|
// automatic half -> float type cast if dfloat == float
|
|
@@ -12,7 +12,7 @@ static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 &
|
|
|
12
12
|
v.y() = x[ib + iqs + 1];
|
|
13
13
|
}
|
|
14
14
|
|
|
15
|
-
static void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
|
15
|
+
static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
|
|
16
16
|
const float * x = (const float *) vx;
|
|
17
17
|
|
|
18
18
|
// automatic half -> float type cast if dfloat == float
|
|
@@ -76,8 +76,8 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
|
|
|
76
76
|
}
|
|
77
77
|
|
|
78
78
|
// sum up partial sums and write back result
|
|
79
|
-
|
|
80
|
-
for (int mask =
|
|
79
|
+
const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
|
|
80
|
+
for (int mask = mask_start; mask > 0; mask >>= 1) {
|
|
81
81
|
tmp +=
|
|
82
82
|
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
|
83
83
|
}
|
|
@@ -874,7 +874,7 @@ namespace dpct
|
|
|
874
874
|
inline std::string get_preferred_gpu_platform_name() {
|
|
875
875
|
std::string result;
|
|
876
876
|
|
|
877
|
-
std::string filter = "
|
|
877
|
+
std::string filter = "";
|
|
878
878
|
char* env = getenv("ONEAPI_DEVICE_SELECTOR");
|
|
879
879
|
if (env) {
|
|
880
880
|
if (std::strstr(env, "level_zero")) {
|
|
@@ -892,11 +892,24 @@ namespace dpct
|
|
|
892
892
|
else {
|
|
893
893
|
throw std::runtime_error("invalid device filter: " + std::string(env));
|
|
894
894
|
}
|
|
895
|
+
} else {
|
|
896
|
+
auto default_device = sycl::device(sycl::default_selector_v);
|
|
897
|
+
auto default_platform_name = default_device.get_platform().get_info<sycl::info::platform::name>();
|
|
898
|
+
|
|
899
|
+
if (std::strstr(default_platform_name.c_str(), "Level-Zero") || default_device.is_cpu()) {
|
|
900
|
+
filter = "level-zero";
|
|
901
|
+
}
|
|
902
|
+
else if (std::strstr(default_platform_name.c_str(), "CUDA")) {
|
|
903
|
+
filter = "cuda";
|
|
904
|
+
}
|
|
905
|
+
else if (std::strstr(default_platform_name.c_str(), "HIP")) {
|
|
906
|
+
filter = "hip";
|
|
907
|
+
}
|
|
895
908
|
}
|
|
896
909
|
|
|
897
|
-
auto
|
|
910
|
+
auto platform_list = sycl::platform::get_platforms();
|
|
898
911
|
|
|
899
|
-
for (const auto& platform :
|
|
912
|
+
for (const auto& platform : platform_list) {
|
|
900
913
|
auto devices = platform.get_devices();
|
|
901
914
|
auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {
|
|
902
915
|
return d.is_gpu();
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
//
|
|
2
|
+
// MIT license
|
|
3
|
+
// Copyright (C) 2024 Intel Corporation
|
|
4
|
+
// SPDX-License-Identifier: MIT
|
|
5
|
+
//
|
|
6
|
+
|
|
7
|
+
//
|
|
8
|
+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
9
|
+
// See https://llvm.org/LICENSE.txt for license information.
|
|
10
|
+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
11
|
+
//
|
|
12
|
+
|
|
13
|
+
#ifndef GGML_SYCL_GEMM_HPP
|
|
14
|
+
#define GGML_SYCL_GEMM_HPP
|
|
15
|
+
|
|
16
|
+
#include <fstream>
|
|
17
|
+
#include <iostream>
|
|
18
|
+
|
|
19
|
+
#include "ggml-sycl.h"
|
|
20
|
+
|
|
21
|
+
#if GGML_SYCL_DNNL
|
|
22
|
+
|
|
23
|
+
#include "dnnl.hpp"
|
|
24
|
+
#include "dnnl_sycl.hpp"
|
|
25
|
+
|
|
26
|
+
class DnnlGemmWrapper {
|
|
27
|
+
public:
|
|
28
|
+
using dt = dnnl::memory::data_type;
|
|
29
|
+
using tag = dnnl::memory::format_tag;
|
|
30
|
+
|
|
31
|
+
template<typename T>
|
|
32
|
+
static constexpr dt to_dt() {
|
|
33
|
+
if constexpr (std::is_same_v<T, float>) return dt::f32;
|
|
34
|
+
else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16;
|
|
35
|
+
else static_assert(0);
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
static inline void row_gemm(sycl::queue& q, bool a_trans,
|
|
39
|
+
bool b_trans, int m, int n, int k,
|
|
40
|
+
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
|
|
41
|
+
{
|
|
42
|
+
// Get the device associated with the queue
|
|
43
|
+
sycl::device dev = q.get_device();
|
|
44
|
+
// Get the context associated with the queue
|
|
45
|
+
sycl::context ctx = q.get_context();
|
|
46
|
+
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
|
|
47
|
+
const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
|
|
48
|
+
dnnl::memory::dims a_dims = { m, k };
|
|
49
|
+
dnnl::memory::dims b_dims = { k, n };
|
|
50
|
+
dnnl::memory::dims c_dims = { m, n };
|
|
51
|
+
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
|
|
52
|
+
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
|
|
53
|
+
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
|
|
54
|
+
auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
|
|
55
|
+
auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
|
|
56
|
+
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
|
|
57
|
+
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
|
|
58
|
+
|
|
59
|
+
// Create the primitive.
|
|
60
|
+
auto matmul_prim = dnnl::matmul(matmul_pd);
|
|
61
|
+
// Primitive arguments.
|
|
62
|
+
std::unordered_map<int, dnnl::memory> matmul_args;
|
|
63
|
+
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
|
|
64
|
+
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
|
|
65
|
+
matmul_args.insert({ DNNL_ARG_DST, c_mem });
|
|
66
|
+
|
|
67
|
+
matmul_prim.execute(stream, matmul_args);
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
|
|
72
|
+
bool b_trans, int m, int n, int k,
|
|
73
|
+
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
|
|
74
|
+
{
|
|
75
|
+
auto const eng = stream.get_engine();
|
|
76
|
+
dnnl::memory::dims a_dims = { m, k };
|
|
77
|
+
dnnl::memory::dims b_dims = { k, n };
|
|
78
|
+
dnnl::memory::dims c_dims = { m, n };
|
|
79
|
+
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
|
|
80
|
+
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
|
|
81
|
+
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
|
|
82
|
+
auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
|
|
83
|
+
auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
|
|
84
|
+
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
|
|
85
|
+
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
|
|
86
|
+
|
|
87
|
+
// Create the primitive.
|
|
88
|
+
auto matmul_prim = dnnl::matmul(matmul_pd);
|
|
89
|
+
// Primitive arguments.
|
|
90
|
+
std::unordered_map<int, dnnl::memory> matmul_args;
|
|
91
|
+
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
|
|
92
|
+
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
|
|
93
|
+
matmul_args.insert({ DNNL_ARG_DST, c_mem });
|
|
94
|
+
|
|
95
|
+
matmul_prim.execute(stream, matmul_args);
|
|
96
|
+
}
|
|
97
|
+
};
|
|
98
|
+
|
|
99
|
+
#endif
|
|
100
|
+
|
|
101
|
+
#endif // GGML_SYCL_GEMM_HPP
|