@fugood/llama.node 0.3.14 → 0.3.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +1 -1
- package/src/llama.cpp/.github/workflows/build.yml +30 -1
- package/src/llama.cpp/CMakeLists.txt +9 -1
- package/src/llama.cpp/cmake/common.cmake +2 -0
- package/src/llama.cpp/common/arg.cpp +20 -2
- package/src/llama.cpp/common/common.cpp +6 -3
- package/src/llama.cpp/common/speculative.cpp +4 -4
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +2 -2
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +4 -4
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -6
- package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
- package/src/llama.cpp/examples/main/main.cpp +6 -6
- package/src/llama.cpp/examples/parallel/parallel.cpp +5 -5
- package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
- package/src/llama.cpp/examples/run/run.cpp +91 -46
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
- package/src/llama.cpp/examples/server/server.cpp +37 -15
- package/src/llama.cpp/examples/server/utils.hpp +3 -1
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
- package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/tts/tts.cpp +20 -9
- package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
- package/src/llama.cpp/ggml/include/ggml.h +24 -0
- package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -28
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +0 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +1493 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +150 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +284 -29
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +95 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +35 -12
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +93 -27
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +12 -13
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +109 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +19 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +398 -158
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +7 -2
- package/src/llama.cpp/ggml/src/ggml.c +85 -2
- package/src/llama.cpp/include/llama.h +86 -22
- package/src/llama.cpp/src/CMakeLists.txt +5 -2
- package/src/llama.cpp/src/llama-adapter.cpp +19 -20
- package/src/llama.cpp/src/llama-adapter.h +11 -9
- package/src/llama.cpp/src/llama-arch.cpp +103 -16
- package/src/llama.cpp/src/llama-arch.h +18 -0
- package/src/llama.cpp/src/llama-batch.h +2 -2
- package/src/llama.cpp/src/llama-context.cpp +2253 -1222
- package/src/llama.cpp/src/llama-context.h +214 -77
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +1662 -0
- package/src/llama.cpp/src/llama-graph.h +574 -0
- package/src/llama.cpp/src/llama-hparams.cpp +8 -0
- package/src/llama.cpp/src/llama-hparams.h +9 -0
- package/src/llama.cpp/src/llama-io.cpp +15 -0
- package/src/llama.cpp/src/llama-io.h +35 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
- package/src/llama.cpp/src/llama-kv-cache.h +178 -110
- package/src/llama.cpp/src/llama-memory.cpp +1 -0
- package/src/llama.cpp/src/llama-memory.h +21 -0
- package/src/llama.cpp/src/llama-model.cpp +8244 -173
- package/src/llama.cpp/src/llama-model.h +34 -1
- package/src/llama.cpp/src/llama-quant.cpp +10 -1
- package/src/llama.cpp/src/llama.cpp +51 -9984
- package/src/llama.cpp/tests/test-backend-ops.cpp +145 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
#include "common.hpp"
|
|
2
2
|
#include "element_wise.hpp"
|
|
3
3
|
|
|
4
|
-
void acc_f32(const float * x, const float * y, float * dst, const int ne,
|
|
4
|
+
static void acc_f32(const float * x, const float * y, float * dst, const int ne,
|
|
5
5
|
const int ne10, const int ne11, const int ne12,
|
|
6
6
|
const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
|
|
7
7
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
@@ -20,7 +20,7 @@ void acc_f32(const float * x, const float * y, float * dst, const int ne,
|
|
|
20
20
|
}
|
|
21
21
|
}
|
|
22
22
|
|
|
23
|
-
void gelu_f32(const float * x, float * dst, const int k,
|
|
23
|
+
static void gelu_f32(const float * x, float * dst, const int k,
|
|
24
24
|
const sycl::nd_item<3> &item_ct1) {
|
|
25
25
|
const float GELU_COEF_A = 0.044715f;
|
|
26
26
|
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
|
@@ -37,7 +37,7 @@ void gelu_f32(const float * x, float * dst, const int k,
|
|
|
37
37
|
sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi)));
|
|
38
38
|
}
|
|
39
39
|
|
|
40
|
-
void silu_f32(const float * x, float * dst, const int k,
|
|
40
|
+
static void silu_f32(const float * x, float * dst, const int k,
|
|
41
41
|
const sycl::nd_item<3> &item_ct1) {
|
|
42
42
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
43
43
|
item_ct1.get_local_id(2);
|
|
@@ -48,7 +48,7 @@ void silu_f32(const float * x, float * dst, const int k,
|
|
|
48
48
|
dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i]));
|
|
49
49
|
}
|
|
50
50
|
|
|
51
|
-
void gelu_quick_f32(const float *x, float *dst, int k,
|
|
51
|
+
static void gelu_quick_f32(const float *x, float *dst, int k,
|
|
52
52
|
const sycl::nd_item<3> &item_ct1) {
|
|
53
53
|
const float GELU_QUICK_COEF = -1.702f;
|
|
54
54
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
@@ -59,7 +59,7 @@ void gelu_quick_f32(const float *x, float *dst, int k,
|
|
|
59
59
|
dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i])));
|
|
60
60
|
}
|
|
61
61
|
|
|
62
|
-
void tanh_f32(const float *x, float *dst, int k,
|
|
62
|
+
static void tanh_f32(const float *x, float *dst, int k,
|
|
63
63
|
const sycl::nd_item<3> &item_ct1) {
|
|
64
64
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
65
65
|
item_ct1.get_local_id(2);
|
|
@@ -69,7 +69,7 @@ void tanh_f32(const float *x, float *dst, int k,
|
|
|
69
69
|
dst[i] = sycl::tanh((float)(x[i]));
|
|
70
70
|
}
|
|
71
71
|
|
|
72
|
-
void relu_f32(const float * x, float * dst, const int k,
|
|
72
|
+
static void relu_f32(const float * x, float * dst, const int k,
|
|
73
73
|
const sycl::nd_item<3> &item_ct1) {
|
|
74
74
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
75
75
|
item_ct1.get_local_id(2);
|
|
@@ -80,7 +80,7 @@ void relu_f32(const float * x, float * dst, const int k,
|
|
|
80
80
|
dst[i] = sycl::fmax((float)(x[i]), (float)0);
|
|
81
81
|
}
|
|
82
82
|
|
|
83
|
-
void sigmoid_f32(const float * x, float * dst, const int k,
|
|
83
|
+
static void sigmoid_f32(const float * x, float * dst, const int k,
|
|
84
84
|
const sycl::nd_item<3> &item_ct1) {
|
|
85
85
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
86
86
|
item_ct1.get_local_id(2);
|
|
@@ -91,7 +91,7 @@ void sigmoid_f32(const float * x, float * dst, const int k,
|
|
|
91
91
|
dst[i] = 1.0f / (1.0f + sycl::native::exp(-x[i]));
|
|
92
92
|
}
|
|
93
93
|
|
|
94
|
-
void sqrt_f32(const float * x, float * dst, const int k,
|
|
94
|
+
static void sqrt_f32(const float * x, float * dst, const int k,
|
|
95
95
|
const sycl::nd_item<3> &item_ct1) {
|
|
96
96
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
97
97
|
item_ct1.get_local_id(2);
|
|
@@ -102,7 +102,7 @@ void sqrt_f32(const float * x, float * dst, const int k,
|
|
|
102
102
|
dst[i] = sycl::sqrt(x[i]);
|
|
103
103
|
}
|
|
104
104
|
|
|
105
|
-
void sin_f32(const float * x, float * dst, const int k,
|
|
105
|
+
static void sin_f32(const float * x, float * dst, const int k,
|
|
106
106
|
const sycl::nd_item<3> &item_ct1) {
|
|
107
107
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
108
108
|
item_ct1.get_local_id(2);
|
|
@@ -113,7 +113,7 @@ void sin_f32(const float * x, float * dst, const int k,
|
|
|
113
113
|
dst[i] = sycl::sin(x[i]);
|
|
114
114
|
}
|
|
115
115
|
|
|
116
|
-
void cos_f32(const float * x, float * dst, const int k,
|
|
116
|
+
static void cos_f32(const float * x, float * dst, const int k,
|
|
117
117
|
const sycl::nd_item<3> &item_ct1) {
|
|
118
118
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
119
119
|
item_ct1.get_local_id(2);
|
|
@@ -124,7 +124,7 @@ void cos_f32(const float * x, float * dst, const int k,
|
|
|
124
124
|
dst[i] = sycl::cos(x[i]);
|
|
125
125
|
}
|
|
126
126
|
|
|
127
|
-
void hardsigmoid_f32(const float * x, float * dst, const int k,
|
|
127
|
+
static void hardsigmoid_f32(const float * x, float * dst, const int k,
|
|
128
128
|
const sycl::nd_item<3> &item_ct1) {
|
|
129
129
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
130
130
|
item_ct1.get_local_id(2);
|
|
@@ -135,7 +135,7 @@ void hardsigmoid_f32(const float * x, float * dst, const int k,
|
|
|
135
135
|
dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
|
|
136
136
|
}
|
|
137
137
|
|
|
138
|
-
void hardswish_f32(const float * x, float * dst, const int k,
|
|
138
|
+
static void hardswish_f32(const float * x, float * dst, const int k,
|
|
139
139
|
const sycl::nd_item<3> &item_ct1) {
|
|
140
140
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
141
141
|
item_ct1.get_local_id(2);
|
|
@@ -146,7 +146,7 @@ void hardswish_f32(const float * x, float * dst, const int k,
|
|
|
146
146
|
dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
|
|
147
147
|
}
|
|
148
148
|
|
|
149
|
-
void exp_f32(const float * x, float * dst, const int k,
|
|
149
|
+
static void exp_f32(const float * x, float * dst, const int k,
|
|
150
150
|
const sycl::nd_item<3> &item_ct1) {
|
|
151
151
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
152
152
|
item_ct1.get_local_id(2);
|
|
@@ -157,7 +157,7 @@ void exp_f32(const float * x, float * dst, const int k,
|
|
|
157
157
|
dst[i] = sycl::exp(x[i]);
|
|
158
158
|
}
|
|
159
159
|
|
|
160
|
-
void log_f32(const float * x, float * dst, const int k,
|
|
160
|
+
static void log_f32(const float * x, float * dst, const int k,
|
|
161
161
|
const sycl::nd_item<3> &item_ct1) {
|
|
162
162
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
163
163
|
item_ct1.get_local_id(2);
|
|
@@ -173,7 +173,7 @@ void log_f32(const float * x, float * dst, const int k,
|
|
|
173
173
|
}
|
|
174
174
|
}
|
|
175
175
|
|
|
176
|
-
void neg_f32(const float * x, float * dst, const int k,
|
|
176
|
+
static void neg_f32(const float * x, float * dst, const int k,
|
|
177
177
|
const sycl::nd_item<3> &item_ct1) {
|
|
178
178
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
179
179
|
item_ct1.get_local_id(2);
|
|
@@ -184,7 +184,7 @@ void neg_f32(const float * x, float * dst, const int k,
|
|
|
184
184
|
dst[i] = -x[i];
|
|
185
185
|
}
|
|
186
186
|
|
|
187
|
-
void step_f32(const float * x, float * dst, const int k,
|
|
187
|
+
static void step_f32(const float * x, float * dst, const int k,
|
|
188
188
|
const sycl::nd_item<3> &item_ct1) {
|
|
189
189
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
190
190
|
item_ct1.get_local_id(2);
|
|
@@ -195,7 +195,7 @@ void step_f32(const float * x, float * dst, const int k,
|
|
|
195
195
|
dst[i] = x[i] > 0.0f;
|
|
196
196
|
}
|
|
197
197
|
|
|
198
|
-
void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
|
|
198
|
+
static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
|
|
199
199
|
const sycl::nd_item<3> &item_ct1) {
|
|
200
200
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
201
201
|
item_ct1.get_local_id(2);
|
|
@@ -206,7 +206,7 @@ void leaky_relu_f32(const float *x, float *dst, const int k, const float negativ
|
|
|
206
206
|
sycl::fmin((float)(x[i]), 0.0f) * negative_slope;
|
|
207
207
|
}
|
|
208
208
|
|
|
209
|
-
void sqr_f32(const float * x, float * dst, const int k,
|
|
209
|
+
static void sqr_f32(const float * x, float * dst, const int k,
|
|
210
210
|
const sycl::nd_item<3> &item_ct1) {
|
|
211
211
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
212
212
|
item_ct1.get_local_id(2);
|
|
@@ -217,7 +217,7 @@ void sqr_f32(const float * x, float * dst, const int k,
|
|
|
217
217
|
dst[i] = x[i] * x[i];
|
|
218
218
|
}
|
|
219
219
|
|
|
220
|
-
void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
|
|
220
|
+
static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
|
|
221
221
|
const int nb02, const int nb03, const int ne10, const int ne11,
|
|
222
222
|
const int ne12, const int ne13, const float sf0, const float sf1,
|
|
223
223
|
const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
|
|
@@ -240,7 +240,7 @@ void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
|
|
|
240
240
|
dst[index] = *(const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
|
|
241
241
|
}
|
|
242
242
|
|
|
243
|
-
void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
|
|
243
|
+
static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
|
|
244
244
|
const sycl::nd_item<3> &item_ct1) {
|
|
245
245
|
int nidx = item_ct1.get_local_id(2) +
|
|
246
246
|
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
|
@@ -262,7 +262,7 @@ void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const i
|
|
|
262
262
|
|
|
263
263
|
|
|
264
264
|
|
|
265
|
-
void acc_f32_sycl(const float *x, const float *y, float *dst,
|
|
265
|
+
static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
|
266
266
|
const int n_elements, const int ne10, const int ne11,
|
|
267
267
|
const int ne12, const int nb1, const int nb2,
|
|
268
268
|
const int offset, queue_ptr stream) {
|
|
@@ -277,7 +277,7 @@ void acc_f32_sycl(const float *x, const float *y, float *dst,
|
|
|
277
277
|
});
|
|
278
278
|
}
|
|
279
279
|
|
|
280
|
-
void gelu_f32_sycl(const float *x, float *dst, const int k,
|
|
280
|
+
static void gelu_f32_sycl(const float *x, float *dst, const int k,
|
|
281
281
|
queue_ptr stream) {
|
|
282
282
|
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
|
|
283
283
|
stream->parallel_for(
|
|
@@ -289,7 +289,7 @@ void gelu_f32_sycl(const float *x, float *dst, const int k,
|
|
|
289
289
|
});
|
|
290
290
|
}
|
|
291
291
|
|
|
292
|
-
void silu_f32_sycl(const float *x, float *dst, const int k,
|
|
292
|
+
static void silu_f32_sycl(const float *x, float *dst, const int k,
|
|
293
293
|
queue_ptr stream) {
|
|
294
294
|
const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
|
|
295
295
|
stream->parallel_for(
|
|
@@ -301,7 +301,7 @@ void silu_f32_sycl(const float *x, float *dst, const int k,
|
|
|
301
301
|
});
|
|
302
302
|
}
|
|
303
303
|
|
|
304
|
-
void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
|
|
304
|
+
static void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
|
|
305
305
|
queue_ptr stream) {
|
|
306
306
|
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
|
|
307
307
|
stream->parallel_for(
|
|
@@ -313,7 +313,7 @@ void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
|
|
|
313
313
|
});
|
|
314
314
|
}
|
|
315
315
|
|
|
316
|
-
void tanh_f32_sycl(const float *x, float *dst, const int k,
|
|
316
|
+
static void tanh_f32_sycl(const float *x, float *dst, const int k,
|
|
317
317
|
queue_ptr stream) {
|
|
318
318
|
const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
|
|
319
319
|
stream->parallel_for(
|
|
@@ -325,7 +325,7 @@ void tanh_f32_sycl(const float *x, float *dst, const int k,
|
|
|
325
325
|
});
|
|
326
326
|
}
|
|
327
327
|
|
|
328
|
-
void relu_f32_sycl(const float *x, float *dst, const int k,
|
|
328
|
+
static void relu_f32_sycl(const float *x, float *dst, const int k,
|
|
329
329
|
queue_ptr stream) {
|
|
330
330
|
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
|
|
331
331
|
stream->parallel_for(
|
|
@@ -337,7 +337,7 @@ void relu_f32_sycl(const float *x, float *dst, const int k,
|
|
|
337
337
|
});
|
|
338
338
|
}
|
|
339
339
|
|
|
340
|
-
void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
|
|
340
|
+
static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
|
|
341
341
|
queue_ptr stream) {
|
|
342
342
|
const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
|
|
343
343
|
stream->parallel_for(
|
|
@@ -349,7 +349,7 @@ void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
|
|
|
349
349
|
});
|
|
350
350
|
}
|
|
351
351
|
|
|
352
|
-
void hardswish_f32_sycl(const float *x, float *dst, const int k,
|
|
352
|
+
static void hardswish_f32_sycl(const float *x, float *dst, const int k,
|
|
353
353
|
queue_ptr stream) {
|
|
354
354
|
const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
|
|
355
355
|
stream->parallel_for(
|
|
@@ -361,7 +361,7 @@ void hardswish_f32_sycl(const float *x, float *dst, const int k,
|
|
|
361
361
|
});
|
|
362
362
|
}
|
|
363
363
|
|
|
364
|
-
void exp_f32_sycl(const float *x, float *dst, const int k,
|
|
364
|
+
static void exp_f32_sycl(const float *x, float *dst, const int k,
|
|
365
365
|
queue_ptr stream) {
|
|
366
366
|
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
|
|
367
367
|
stream->parallel_for(
|
|
@@ -373,7 +373,7 @@ void exp_f32_sycl(const float *x, float *dst, const int k,
|
|
|
373
373
|
});
|
|
374
374
|
}
|
|
375
375
|
|
|
376
|
-
void log_f32_sycl(const float *x, float *dst, const int k,
|
|
376
|
+
static void log_f32_sycl(const float *x, float *dst, const int k,
|
|
377
377
|
queue_ptr stream) {
|
|
378
378
|
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
|
|
379
379
|
stream->parallel_for(
|
|
@@ -385,7 +385,7 @@ void log_f32_sycl(const float *x, float *dst, const int k,
|
|
|
385
385
|
});
|
|
386
386
|
}
|
|
387
387
|
|
|
388
|
-
void neg_f32_sycl(const float *x, float *dst, const int k,
|
|
388
|
+
static void neg_f32_sycl(const float *x, float *dst, const int k,
|
|
389
389
|
queue_ptr stream) {
|
|
390
390
|
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
|
|
391
391
|
stream->parallel_for(
|
|
@@ -397,7 +397,7 @@ void neg_f32_sycl(const float *x, float *dst, const int k,
|
|
|
397
397
|
});
|
|
398
398
|
}
|
|
399
399
|
|
|
400
|
-
void step_f32_sycl(const float *x, float *dst, const int k,
|
|
400
|
+
static void step_f32_sycl(const float *x, float *dst, const int k,
|
|
401
401
|
queue_ptr stream) {
|
|
402
402
|
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
|
|
403
403
|
stream->parallel_for(
|
|
@@ -409,7 +409,7 @@ void step_f32_sycl(const float *x, float *dst, const int k,
|
|
|
409
409
|
});
|
|
410
410
|
}
|
|
411
411
|
|
|
412
|
-
void sigmoid_f32_sycl(const float *x, float *dst, const int k,
|
|
412
|
+
static void sigmoid_f32_sycl(const float *x, float *dst, const int k,
|
|
413
413
|
queue_ptr stream) {
|
|
414
414
|
const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
|
|
415
415
|
stream->parallel_for(
|
|
@@ -421,7 +421,7 @@ void sigmoid_f32_sycl(const float *x, float *dst, const int k,
|
|
|
421
421
|
});
|
|
422
422
|
}
|
|
423
423
|
|
|
424
|
-
void sqrt_f32_sycl(const float *x, float *dst, const int k,
|
|
424
|
+
static void sqrt_f32_sycl(const float *x, float *dst, const int k,
|
|
425
425
|
queue_ptr stream) {
|
|
426
426
|
const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
|
|
427
427
|
stream->parallel_for(
|
|
@@ -433,7 +433,7 @@ void sqrt_f32_sycl(const float *x, float *dst, const int k,
|
|
|
433
433
|
});
|
|
434
434
|
}
|
|
435
435
|
|
|
436
|
-
void sin_f32_sycl(const float *x, float *dst, const int k,
|
|
436
|
+
static void sin_f32_sycl(const float *x, float *dst, const int k,
|
|
437
437
|
queue_ptr stream) {
|
|
438
438
|
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
|
|
439
439
|
stream->parallel_for(
|
|
@@ -445,7 +445,7 @@ void sin_f32_sycl(const float *x, float *dst, const int k,
|
|
|
445
445
|
});
|
|
446
446
|
}
|
|
447
447
|
|
|
448
|
-
void cos_f32_sycl(const float *x, float *dst, const int k,
|
|
448
|
+
static void cos_f32_sycl(const float *x, float *dst, const int k,
|
|
449
449
|
queue_ptr stream) {
|
|
450
450
|
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
|
|
451
451
|
stream->parallel_for(
|
|
@@ -457,7 +457,7 @@ void cos_f32_sycl(const float *x, float *dst, const int k,
|
|
|
457
457
|
});
|
|
458
458
|
}
|
|
459
459
|
|
|
460
|
-
void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
|
|
460
|
+
static void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
|
|
461
461
|
const float negative_slope,
|
|
462
462
|
queue_ptr stream) {
|
|
463
463
|
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
|
|
@@ -470,7 +470,7 @@ void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
|
|
|
470
470
|
});
|
|
471
471
|
}
|
|
472
472
|
|
|
473
|
-
void sqr_f32_sycl(const float *x, float *dst, const int k,
|
|
473
|
+
static void sqr_f32_sycl(const float *x, float *dst, const int k,
|
|
474
474
|
queue_ptr stream) {
|
|
475
475
|
const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
|
|
476
476
|
stream->parallel_for(
|
|
@@ -482,7 +482,7 @@ void sqr_f32_sycl(const float *x, float *dst, const int k,
|
|
|
482
482
|
});
|
|
483
483
|
}
|
|
484
484
|
|
|
485
|
-
void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
|
|
485
|
+
static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
|
|
486
486
|
const int nb02, const int nb03, const int ne10, const int ne11,
|
|
487
487
|
const int ne12, const int ne13, const float sf0, const float sf1,
|
|
488
488
|
const float sf2, const float sf3, queue_ptr stream) {
|
|
@@ -496,7 +496,7 @@ void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01
|
|
|
496
496
|
});
|
|
497
497
|
}
|
|
498
498
|
|
|
499
|
-
void pad_f32_sycl(const float *x, float *dst, const int ne00,
|
|
499
|
+
static void pad_f32_sycl(const float *x, float *dst, const int ne00,
|
|
500
500
|
const int ne01, const int ne02, const int ne0,
|
|
501
501
|
const int ne1, const int ne2, queue_ptr stream) {
|
|
502
502
|
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
|
|
@@ -13,9 +13,6 @@
|
|
|
13
13
|
#ifndef GGML_SYCL_GEMM_HPP
|
|
14
14
|
#define GGML_SYCL_GEMM_HPP
|
|
15
15
|
|
|
16
|
-
#include <fstream>
|
|
17
|
-
#include <iostream>
|
|
18
|
-
|
|
19
16
|
#include "ggml-sycl.h"
|
|
20
17
|
|
|
21
18
|
#if GGML_SYCL_DNNL
|
|
@@ -35,62 +32,34 @@ public:
|
|
|
35
32
|
else static_assert(0);
|
|
36
33
|
}
|
|
37
34
|
|
|
38
|
-
static inline void row_gemm(
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
// Get the device associated with the queue
|
|
43
|
-
sycl::device dev = q.get_device();
|
|
44
|
-
// Get the context associated with the queue
|
|
45
|
-
sycl::context ctx = q.get_context();
|
|
46
|
-
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
|
|
47
|
-
const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
|
|
35
|
+
static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
|
|
36
|
+
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
|
|
37
|
+
auto stream = ctx.stream_dnnl(q);
|
|
38
|
+
auto eng = ctx.engine_dnnl(q);
|
|
48
39
|
dnnl::memory::dims a_dims = { m, k };
|
|
49
40
|
dnnl::memory::dims b_dims = { k, n };
|
|
50
41
|
dnnl::memory::dims c_dims = { m, n };
|
|
51
42
|
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
|
|
52
43
|
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
|
|
53
|
-
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
|
|
54
|
-
auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
|
|
55
|
-
auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
|
|
56
|
-
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
|
|
57
|
-
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
|
|
44
|
+
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
|
|
58
45
|
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
// Primitive arguments.
|
|
62
|
-
std::unordered_map<int, dnnl::memory> matmul_args;
|
|
63
|
-
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
|
|
64
|
-
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
|
|
65
|
-
matmul_args.insert({ DNNL_ARG_DST, c_mem });
|
|
46
|
+
dnnl::primitive_attr primitive_attr;
|
|
47
|
+
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
|
66
48
|
|
|
67
|
-
matmul_prim.execute(stream, matmul_args);
|
|
68
|
-
}
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
|
|
72
|
-
bool b_trans, int m, int n, int k,
|
|
73
|
-
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
|
|
74
|
-
{
|
|
75
|
-
auto const eng = stream.get_engine();
|
|
76
|
-
dnnl::memory::dims a_dims = { m, k };
|
|
77
|
-
dnnl::memory::dims b_dims = { k, n };
|
|
78
|
-
dnnl::memory::dims c_dims = { m, n };
|
|
79
|
-
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
|
|
80
|
-
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
|
|
81
|
-
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
|
|
82
49
|
auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
|
|
83
50
|
auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
|
|
84
|
-
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
|
|
51
|
+
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md, primitive_attr);
|
|
85
52
|
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
|
|
86
53
|
|
|
87
|
-
|
|
54
|
+
auto scratchpad_md = matmul_pd.scratchpad_desc();
|
|
55
|
+
auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
|
|
88
56
|
auto matmul_prim = dnnl::matmul(matmul_pd);
|
|
89
|
-
|
|
57
|
+
|
|
90
58
|
std::unordered_map<int, dnnl::memory> matmul_args;
|
|
91
59
|
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
|
|
92
60
|
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
|
|
93
61
|
matmul_args.insert({ DNNL_ARG_DST, c_mem });
|
|
62
|
+
matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
|
|
94
63
|
|
|
95
64
|
matmul_prim.execute(stream, matmul_args);
|
|
96
65
|
}
|
|
@@ -207,7 +207,7 @@ static void get_rows_sycl_reorder(ggml_backend_sycl_context & ctx, const ggml_te
|
|
|
207
207
|
const size_t nrows = ne01;
|
|
208
208
|
const sycl::half* src0_dq = (const sycl::half*)(src0_q + nrows * ncols / 2);
|
|
209
209
|
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
210
|
-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
|
|
210
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
|
|
211
211
|
k_get_rows_reorder<qk, qr, dq_reorder>(
|
|
212
212
|
src0_dd, src0_dq, src1_dd, dst_dd, ne00, ne12, s1, s2,
|
|
213
213
|
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
|
|
@@ -302,7 +302,6 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *s
|
|
|
302
302
|
// TODO: k-quants
|
|
303
303
|
GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
|
|
304
304
|
GGML_ABORT("fatal error");
|
|
305
|
-
break;
|
|
306
305
|
}
|
|
307
306
|
}
|
|
308
307
|
|