@fugood/llama.node 0.3.14 → 0.3.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +1 -1
- package/src/llama.cpp/.github/workflows/build.yml +30 -1
- package/src/llama.cpp/CMakeLists.txt +9 -1
- package/src/llama.cpp/cmake/common.cmake +2 -0
- package/src/llama.cpp/common/arg.cpp +20 -2
- package/src/llama.cpp/common/common.cpp +6 -3
- package/src/llama.cpp/common/speculative.cpp +4 -4
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +2 -2
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +4 -4
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -6
- package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
- package/src/llama.cpp/examples/main/main.cpp +6 -6
- package/src/llama.cpp/examples/parallel/parallel.cpp +5 -5
- package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
- package/src/llama.cpp/examples/run/run.cpp +91 -46
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
- package/src/llama.cpp/examples/server/server.cpp +37 -15
- package/src/llama.cpp/examples/server/utils.hpp +3 -1
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
- package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/tts/tts.cpp +20 -9
- package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
- package/src/llama.cpp/ggml/include/ggml.h +24 -0
- package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -28
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +0 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +1493 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +150 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +284 -29
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +95 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +35 -12
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +93 -27
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +12 -13
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +109 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +19 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +398 -158
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +7 -2
- package/src/llama.cpp/ggml/src/ggml.c +85 -2
- package/src/llama.cpp/include/llama.h +86 -22
- package/src/llama.cpp/src/CMakeLists.txt +5 -2
- package/src/llama.cpp/src/llama-adapter.cpp +19 -20
- package/src/llama.cpp/src/llama-adapter.h +11 -9
- package/src/llama.cpp/src/llama-arch.cpp +103 -16
- package/src/llama.cpp/src/llama-arch.h +18 -0
- package/src/llama.cpp/src/llama-batch.h +2 -2
- package/src/llama.cpp/src/llama-context.cpp +2253 -1222
- package/src/llama.cpp/src/llama-context.h +214 -77
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +1662 -0
- package/src/llama.cpp/src/llama-graph.h +574 -0
- package/src/llama.cpp/src/llama-hparams.cpp +8 -0
- package/src/llama.cpp/src/llama-hparams.h +9 -0
- package/src/llama.cpp/src/llama-io.cpp +15 -0
- package/src/llama.cpp/src/llama-io.h +35 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
- package/src/llama.cpp/src/llama-kv-cache.h +178 -110
- package/src/llama.cpp/src/llama-memory.cpp +1 -0
- package/src/llama.cpp/src/llama-memory.h +21 -0
- package/src/llama.cpp/src/llama-model.cpp +8244 -173
- package/src/llama.cpp/src/llama-model.h +34 -1
- package/src/llama.cpp/src/llama-quant.cpp +10 -1
- package/src/llama.cpp/src/llama.cpp +51 -9984
- package/src/llama.cpp/tests/test-backend-ops.cpp +145 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c
@@ -8158,7 +8158,156 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     const int nb = n / QK_K;
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
+    float sum = 0;
+    svuint8_t m4b = svdup_n_u8(0xf);
+    svint32_t vzero = svdup_n_s32(0);
+    svuint8_t mone = svdup_n_u8(0x30);
+    svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
+    svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * GGML_RESTRICT q6 = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t * GGML_RESTRICT q8 = y[i].qs;
+
+        const int8_t * GGML_RESTRICT scale = x[i].scales;
+
+        const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
+        const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
+        const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
+        const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
+        const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
+        const svint64_t prod = svdup_n_s64(0);
+        int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
+                                                                                 svdot_s64(prod, q8sums_2, q6scales_2)));
+        int32_t isum = 0;
+
+        switch (vector_length) {
+            case 128:
+                {
+                    const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+                    const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; ++j) {
+                        svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
+                        svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
+                        qh += 32;
+                        svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
+                        svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
+                        svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
+                        svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
+                        q6 += 64;
+                        svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
+                        svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
+                        svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
+                        svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
+                        q8 += 64;
+
+                        q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
+                        q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
+                        q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
+                        q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+
+                        scale += 4;
+                        q8bytes_1 = svld1_s8(pg8_16, q8);
+                        q8bytes_2 = svld1_s8(pg8_16, q8+16);
+                        q8bytes_3 = svld1_s8(pg8_16, q8+32);
+                        q8bytes_4 = svld1_s8(pg8_16, q8+48);
+                        q8 += 64;
+
+                        q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
+                        q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
+                        q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
+                        q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+                        scale += 4;
+                    }
+                    isum += svaddv_s32(pg32_4, isum_tmp);
+                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
+                }
+                break;
+            case 256:
+            case 512:
+                {
+                    const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
+                    const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
+                    const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; j++) {
+                        svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
+                        qh += 32;
+                        svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
+                        svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
+                        q6 += 64;
+                        svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
+                        svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
+                        svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
+                        svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
+                        q8 += 128;
+                        q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
+                        q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
+                        q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
+                        q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
+
+                        svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
+                        scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+                        scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+                        svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
+                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+                        svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
+                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+                        svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
+                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+                        svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
+                        svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
+                        svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
+                        svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
+
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
+                        scale += 8;
+                    }
+                    isum += svaddv_s32(pg32_8, isum_tmp);
+                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
+                }
+                break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+
+    *s = sum;
+
+#elif __ARM_NEON
     float sum = 0;
 
     const uint8x16_t m4b = vdupq_n_u8(0xF);
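Note: the new __ARM_FEATURE_SVE branch above vectorizes the same 6-bit unpacking the existing NEON path performs: each q6_K weight is rebuilt from a low nibble in ql plus two high bits in qh, and the constant -32 offset is applied in bulk via the precomputed isum_mins term (the dot of bsums with the block scales). A scalar sketch of that reconstruction, with hypothetical names not taken from the diff:

    // Illustrative only: rebuild one 6-bit q6_K weight from its packed parts.
    // The SVE kernel does this many lanes at a time, with mone = 0x30 masking
    // the shifted high bits, and folds the "- 32" into isum_mins instead.
    static inline int8_t q6_k_unpack(uint8_t ql_nibble, uint8_t qh_2bits) {
        return (int8_t)((ql_nibble & 0x0F) | ((qh_2bits & 0x03) << 4)) - 32;
    }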
package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c
@@ -3110,17 +3110,17 @@ static void ggml_compute_forward_dup_same_cont(
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
 
-    // parallelize by elements
-    const int ne = ggml_nelements(dst);
-    const int dr = (ne + nth - 1) / nth;
-    const int ie0 = dr * ith;
-    const int ie1 = MIN(ie0 + dr, ne);
+    // parallelize by blocks
+    const int nk = ggml_nelements(src0)/ggml_blck_size(src0->type);
+    const int dr = (nk + nth - 1) / nth;
+    const int k0 = dr * ith;
+    const int k1 = MIN(k0 + dr, nk);
 
-    if (ie0 < ie1) {
+    if (k0 < k1) {
         memcpy(
-            ((char *) dst->data + ie0*nb0),
-            ((char *) src0->data + ie0*nb0),
-            (ie1 - ie0) * nb0);
+            ((char *) dst->data + k0*nb0),
+            ((char *) src0->data + k0*nb0),
+            (k1 - k0) * nb0);
     }
 }
 
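Note: this hunk, and the ggml_compute_forward_dup_bytes hunks that follow, change the work split from elements to blocks. For quantized types the smallest unit that can be copied or addressed is one block of ggml_blck_size() elements occupying ggml_type_size() bytes, so a per-element split could land mid-block. A minimal standalone sketch of the per-thread arithmetic, under those assumptions:

    #include <string.h>

    // Sketch of the block-wise split used above: thread `ith` of `nth`
    // copies whole blocks [k0, k1), never partial blocks.
    static void copy_blocks_for_thread(char * dst, const char * src,
                                       int nk, size_t block_bytes,
                                       int ith, int nth) {
        const int dr = (nk + nth - 1)/nth;            // blocks per thread
        const int k0 = dr*ith;
        const int k1 = k0 + dr < nk ? k0 + dr : nk;
        if (k0 < k1) {
            memcpy(dst + (size_t) k0*block_bytes,
                   src + (size_t) k0*block_bytes,
                   (size_t)(k1 - k0)*block_bytes);
        }
    }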
@@ -4055,7 +4055,6 @@ static void ggml_compute_forward_dup_f32(
 static void ggml_compute_forward_dup_bytes(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-
     const struct ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
@@ -4069,10 +4068,10 @@ static void ggml_compute_forward_dup_bytes(
     }
 
     const size_t type_size = ggml_type_size(src0->type);
+
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
 
-
     // parallelize by rows
     const int nr = ne01;
     // number of rows per thread
@@ -4082,10 +4081,10 @@ static void ggml_compute_forward_dup_bytes(
     const int ir1 = MIN(ir0 + dr, nr);
 
     if (src0->type == dst->type &&
-        ne00 == ne0 &&
+        ggml_are_same_shape(src0, dst) &&
         nb00 == type_size && nb0 == type_size) {
         // copy by rows
-        const size_t rs = ne00 * type_size;
+        const size_t rs = ggml_row_size(src0->type, ne00);
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 for (int64_t i01 = ir0; i01 < ir1; i01++) {
@@ -4140,17 +4139,20 @@ static void ggml_compute_forward_dup_bytes(
     }
 
     // dst counters
-
-    int64_t i10 = 0;
+    int64_t k10 = 0;
     int64_t i11 = 0;
     int64_t i12 = 0;
     int64_t i13 = 0;
 
+    // number of blocks in a row
+    const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
+    const int64_t nk0 = ne0 / ggml_blck_size(dst->type);
+
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
-            i10 += ne00 * ir0;
-            while (i10 >= ne0) {
-                i10 -= ne0;
+            k10 += nk00 * ir0;
+            while (k10 >= nk0) {
+                k10 -= nk0;
                 if (++i11 == ne1) {
                     i11 = 0;
                     if (++i12 == ne2) {
@@ -4162,14 +4164,14 @@ static void ggml_compute_forward_dup_bytes(
                 }
             }
             for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                    char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+                for (int64_t k00 = 0; k00 < nk00; k00++) {
+                    const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                    char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
 
                     memcpy(dst_ptr, src0_ptr, type_size);
 
-                    if (++i10 == ne0) {
-                        i10 = 0;
+                    if (++k10 == nk0) {
+                        k10 = 0;
                         if (++i11 == ne1) {
                             i11 = 0;
                             if (++i12 == ne2) {
@@ -4182,9 +4184,9 @@ static void ggml_compute_forward_dup_bytes(
                 }
             }
         }
-            i10 += ne00 * (ne01 - ir1);
-            while (i10 >= ne0) {
-                i10 -= ne0;
+            k10 += nk00 * (ne01 - ir1);
+            while (k10 >= nk0) {
+                k10 -= nk0;
                 if (++i11 == ne1) {
                     i11 = 0;
                     if (++i12 == ne2) {
@@ -8548,6 +8550,69 @@ static void ggml_compute_forward_group_norm(
     }
 }
 
+// ggml_compute_forward_l2_norm
+
+static void ggml_compute_forward_l2_norm_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    GGML_ASSERT(eps >= 0.0f);
+
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float sum = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum += (ggml_float)(x[i00] * x[i00]);
+                }
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                memcpy(y, x, ne00 * sizeof(float));
+
+                const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_l2_norm(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_l2_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_mul_mat
 
 static void ggml_compute_forward_mul_mat_one_chunk(
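Note: the new GGML_OP_L2_NORM forward pass above scales each row to unit Euclidean length, y = x / max(||x||_2, eps). A standalone scalar restatement of the per-row step (names not from the diff):

    #include <math.h>

    // Per-row L2 normalization as in ggml_compute_forward_l2_norm_f32:
    // accumulate the squared sum in double (like ggml_float), then scale.
    static void l2_norm_row(float * y, const float * x, int n, float eps) {
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            sum += (double) x[i]*x[i];
        }
        const float scale = 1.0f/fmaxf(sqrtf((float) sum), eps);
        for (int i = 0; i < n; i++) {
            y[i] = x[i]*scale;
        }
    }
    // e.g. x = {3.0f, 4.0f} gives y = {0.6f, 0.8f}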
@@ -13604,6 +13669,184 @@ static void ggml_compute_forward_gla(
     }
 }
 
+// ggml_compute_forward_rwkv_wkv7
+
+static void ggml_compute_forward_rwkv_wkv7_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[6]->ne[1];
+    const int64_t head_size = C / HEADS;
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * r = (float *) dst->src[0]->data;
+    float * w = (float *) dst->src[1]->data;
+    float * k = (float *) dst->src[2]->data;
+    float * v = (float *) dst->src[3]->data;
+    float * a = (float *) dst->src[4]->data;
+    float * b = (float *) dst->src[5]->data;
+
+    int64_t t_stride = HEADS * head_size; // Same to C
+
+    int64_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    int64_t h_stride_2d = head_size * head_size;
+
+    #if defined(GGML_SIMD)
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t ii = 0; ii < head_size; ii++) {
+                    int64_t t_h_i_offset = t_h_offset + ii;
+                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                    GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                    float sa = 0;
+                    {
+                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                        GGML_F32_VEC ax[GGML_F32_ARR];
+                        GGML_F32_VEC ay[GGML_F32_ARR];
+                        for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                                ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                                ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                                sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                            }
+                        }
+                        GGML_F32_VEC_REDUCE(sa, sum);
+                    }
+
+                    GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+
+                    int64_t j = 0;
+                    GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                    for (; j < head_size; j += GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                            int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+
+                            GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                            GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                            GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                            GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+
+                            k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+
+                            GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                            // kv + s * decay + sa * b
+                            state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                            state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                            GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+
+                            result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                        }
+                    }
+                    GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                    // There shouldn't be left-overs though.
+                    for (; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v[t_h_i_offset] * k_val;
+
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                    }
+                }
+            }
+        }
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    int64_t t_h_i_offset = t_h_offset + i;
+                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float v_val = v[t_h_i_offset];
+
+                    float sa = 0, result = 0;
+                    for (int64_t j = 0; j < head_size; j++) {
+                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+                    }
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        result += state_cur[h_2d_i_j_offset] * r_val;
+                    }
+                    dst_data[t_h_i_offset] = result;
+                }
+            }
+        }
+    #endif
+}
+
+
+static void ggml_compute_forward_rwkv_wkv7(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv7_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
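Note: both branches above implement the same RWKV-7 ("wkv7") recurrence; the SIMD branch only vectorizes the inner j loops. Condensed from the scalar fallback, the per-(head, row i) step is the following, with the state updated in place across timesteps within a sequence (illustrative names, not from the diff):

    // Condensed restatement of the scalar branch:
    //   sa     = dot(a, state)               -- dot of a-row with previous state
    //   state' = state*w + v[i]*k + sa*b     -- elementwise over j
    //   out[i] = dot(state', r)
    static float wkv7_row_step(float * state, const float * r, const float * w,
                               const float * k, const float * v, const float * a,
                               const float * b, int i, int n) {
        float sa = 0.0f;
        for (int j = 0; j < n; j++) {
            sa += a[j]*state[j];
        }
        float out = 0.0f;
        for (int j = 0; j < n; j++) {
            state[j] = state[j]*w[j] + v[i]*k[j] + sa*b[j];
            out += state[j]*r[j];
        }
        return out;
    }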
@@ -14067,7 +14310,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
     }
 
     // extra_buffer op?
-    if (ggml_cpu_extra_compute_forward(params, tensor)) return;
+    if (ggml_cpu_extra_compute_forward(params, tensor)) {
+        return;
+    }
 
     switch (tensor->op) {
         case GGML_OP_DUP:
@@ -14170,6 +14415,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_group_norm(params, tensor);
             } break;
+        case GGML_OP_L2_NORM:
+            {
+                ggml_compute_forward_l2_norm(params, tensor);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 ggml_compute_forward_mul_mat(params, tensor);
@@ -14357,6 +14606,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gla(params, tensor);
             } break;
+        case GGML_OP_RWKV_WKV7:
+            {
+                ggml_compute_forward_rwkv_wkv7(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -14582,6 +14835,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
         case GGML_OP_RMS_NORM_BACK:
+        case GGML_OP_L2_NORM:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_CONCAT:
         case GGML_OP_MUL_MAT:
@@ -14648,14 +14902,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_FLASH_ATTN_BACK:
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
+        case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
+        case GGML_OP_RWKV_WKV7:
             {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
-        case GGML_OP_RWKV_WKV6:
-        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h
@@ -112,7 +112,7 @@
 #define cudaGraphExecDestroy hipGraphExecDestroy
 #define cudaGraphLaunch hipGraphLaunch
 #define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
-#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
+#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
 #define cudaGraphNodeType hipGraphNodeType
 #define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
 #define cudaGraphInstantiate hipGraphInstantiate
@@ -129,6 +129,7 @@
 #define cudaGraph_t hipGraph_t
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
+#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
 #define __trap() do { abort(); __builtin_unreachable(); } while(0)
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h
@@ -119,7 +119,7 @@
 #define cudaGraphExecDestroy musaGraphExecDestroy
 #define cudaGraphExec_t musaGraphExec_t
 #define cudaGraphExecUpdate musaGraphExecUpdate
-#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
+#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
 #define cudaGraphGetNodes musaGraphGetNodes
 #define cudaGraphInstantiate musaGraphInstantiate
 #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
@@ -132,6 +132,8 @@
 #define cudaGraph_t musaGraph_t
 #define cudaKernelNodeParams musaKernelNodeParams
 #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
+#define cudaStreamBeginCapture musaStreamBeginCapture
 #define cudaStreamEndCapture musaStreamEndCapture
+#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
 
 typedef mt_bfloat16 nv_bfloat16;
package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -285,6 +285,13 @@ typedef struct {
     float eps;
 } ggml_metal_kargs_rms_norm;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne00_4;
+    uint64_t nb01;
+    float    eps;
+} ggml_metal_kargs_l2_norm;
+
 typedef struct {
     int64_t ne00;
     int64_t ne01;
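Note: the layout of the new ggml_metal_kargs_l2_norm mirrors ggml_metal_kargs_rms_norm above it; ne00_4 is presumably the row length in float4 units (ne00/4) for the vectorized kernel body. That convention is inferred from the rms_norm parallel, not shown in this diff. A hypothetical host-side fill:

    // Hypothetical: populating the new kargs struct host-side; the
    // ne00_4 = ne00/4 convention is an assumption, not shown in this diff.
    static ggml_metal_kargs_l2_norm make_l2_norm_args(int32_t ne00, uint64_t nb01, float eps) {
        ggml_metal_kargs_l2_norm args = {
            /*.ne00   =*/ ne00,
            /*.ne00_4 =*/ ne00/4,
            /*.nb01   =*/ nb01,
            /*.eps    =*/ eps,
        };
        return args;
    }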
package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt
@@ -67,10 +67,6 @@ if (MUSAToolkit_FOUND)
     add_compile_definitions(GGML_USE_MUSA)
    add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
 
-    if (GGML_CUDA_GRAPHS)
-        add_compile_definitions(GGML_CUDA_USE_GRAPHS)
-    endif()
-
     if (GGML_CUDA_FORCE_MMQ)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()