@fugood/llama.node 0.3.14 → 0.3.15
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +1 -1
- package/src/llama.cpp/.github/workflows/build.yml +30 -1
- package/src/llama.cpp/CMakeLists.txt +9 -1
- package/src/llama.cpp/cmake/common.cmake +2 -0
- package/src/llama.cpp/common/arg.cpp +20 -2
- package/src/llama.cpp/common/common.cpp +6 -3
- package/src/llama.cpp/common/speculative.cpp +4 -4
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +2 -2
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +4 -4
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -6
- package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
- package/src/llama.cpp/examples/main/main.cpp +6 -6
- package/src/llama.cpp/examples/parallel/parallel.cpp +5 -5
- package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
- package/src/llama.cpp/examples/run/run.cpp +91 -46
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
- package/src/llama.cpp/examples/server/server.cpp +32 -15
- package/src/llama.cpp/examples/server/utils.hpp +3 -1
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
- package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/tts/tts.cpp +12 -9
- package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
- package/src/llama.cpp/ggml/include/ggml.h +24 -0
- package/src/llama.cpp/ggml/src/CMakeLists.txt +5 -27
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +0 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +150 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +253 -2
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +95 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +66 -26
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +12 -13
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +103 -34
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +19 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +352 -146
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +3 -0
- package/src/llama.cpp/ggml/src/ggml.c +85 -2
- package/src/llama.cpp/include/llama.h +86 -22
- package/src/llama.cpp/src/CMakeLists.txt +5 -2
- package/src/llama.cpp/src/llama-adapter.cpp +19 -20
- package/src/llama.cpp/src/llama-adapter.h +11 -9
- package/src/llama.cpp/src/llama-arch.cpp +102 -16
- package/src/llama.cpp/src/llama-arch.h +18 -0
- package/src/llama.cpp/src/llama-batch.h +2 -2
- package/src/llama.cpp/src/llama-context.cpp +2253 -1222
- package/src/llama.cpp/src/llama-context.h +214 -77
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +1662 -0
- package/src/llama.cpp/src/llama-graph.h +574 -0
- package/src/llama.cpp/src/llama-hparams.cpp +8 -0
- package/src/llama.cpp/src/llama-hparams.h +9 -0
- package/src/llama.cpp/src/llama-io.cpp +15 -0
- package/src/llama.cpp/src/llama-io.h +35 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
- package/src/llama.cpp/src/llama-kv-cache.h +178 -110
- package/src/llama.cpp/src/llama-memory.cpp +1 -0
- package/src/llama.cpp/src/llama-memory.h +21 -0
- package/src/llama.cpp/src/llama-model.cpp +8207 -163
- package/src/llama.cpp/src/llama-model.h +34 -1
- package/src/llama.cpp/src/llama-quant.cpp +10 -1
- package/src/llama.cpp/src/llama.cpp +51 -9984
- package/src/llama.cpp/tests/test-backend-ops.cpp +88 -9
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
|
@@ -2790,10 +2790,14 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
|
|
|
2790
2790
|
(char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
|
|
2791
2791
|
output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND,
|
|
2792
2792
|
output_ne_offset);
|
|
2793
|
+
int64_t antiquantGroupSize = 0;
|
|
2794
|
+
if (src0->ne[0] > QK8_0) {
|
|
2795
|
+
antiquantGroupSize = QK8_0;
|
|
2796
|
+
}
|
|
2793
2797
|
|
|
2794
2798
|
ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
|
|
2795
2799
|
acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr,
|
|
2796
|
-
nullptr, nullptr, nullptr,
|
|
2800
|
+
nullptr, nullptr, nullptr, antiquantGroupSize, acl_output_tensor,
|
|
2797
2801
|
&workspaceSize, &executor));
|
|
2798
2802
|
if (workspaceAddr == nullptr) {
|
|
2799
2803
|
workspaceAddr = workspace_allocator.alloc(workspaceSize);
|
|
@@ -2833,7 +2837,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
|
|
|
2833
2837
|
|
|
2834
2838
|
ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
|
|
2835
2839
|
acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
|
|
2836
|
-
nullptr, nullptr, nullptr, nullptr,
|
|
2840
|
+
nullptr, nullptr, nullptr, nullptr, antiquantGroupSize,
|
|
2837
2841
|
acl_output_tensor, &workspaceSize, &executor));
|
|
2838
2842
|
ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
|
|
2839
2843
|
workspaceAddr, workspaceSize, executor, ctx.stream()));
|
|
@@ -1689,11 +1689,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
|
|
1689
1689
|
case GGML_OP_MUL_MAT: {
|
|
1690
1690
|
switch (op->src[0]->type) {
|
|
1691
1691
|
case GGML_TYPE_Q8_0:
|
|
1692
|
-
// Current groupsize should not be greater than k-1 in
|
|
1693
|
-
// aclnnWeightQuantBatchMatmulV2GetWorkspaceSize
|
|
1694
|
-
if (op->src[0]->ne[0] <= QK8_0) {
|
|
1695
|
-
return false;
|
|
1696
|
-
}
|
|
1697
1692
|
case GGML_TYPE_F16:
|
|
1698
1693
|
case GGML_TYPE_F32:
|
|
1699
1694
|
case GGML_TYPE_Q4_0:
|
|
@@ -287,17 +287,25 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
|
|
287
287
|
endif()
|
|
288
288
|
endif()
|
|
289
289
|
endif()
|
|
290
|
-
elseif (${CMAKE_SYSTEM_PROCESSOR}
|
|
290
|
+
elseif ("${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "ppc64le " OR "${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "powerpc ")
|
|
291
291
|
message(STATUS "PowerPC detected")
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
292
|
+
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
|
|
293
|
+
file(READ "/proc/cpuinfo" POWER10_M)
|
|
294
|
+
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "powerpc")
|
|
295
|
+
execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M)
|
|
296
|
+
endif()
|
|
297
|
+
|
|
298
|
+
string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}")
|
|
299
|
+
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
|
|
300
|
+
|
|
301
|
+
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
|
|
302
|
+
list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64)
|
|
303
|
+
elseif (EXTRACTED_NUMBER EQUAL 9)
|
|
304
|
+
list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64)
|
|
297
305
|
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
|
|
298
306
|
list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
|
|
299
307
|
else()
|
|
300
|
-
list(APPEND ARCH_FLAGS -mcpu=
|
|
308
|
+
list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64)
|
|
301
309
|
endif()
|
|
302
310
|
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
|
|
303
311
|
message(STATUS "loongarch64 detected")
|
|
@@ -8158,7 +8158,156 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
|
|
8158
8158
|
|
|
8159
8159
|
const int nb = n / QK_K;
|
|
8160
8160
|
|
|
8161
|
-
#ifdef
|
|
8161
|
+
#ifdef __ARM_FEATURE_SVE
|
|
8162
|
+
const int vector_length = ggml_cpu_get_sve_cnt()*8;
|
|
8163
|
+
float sum = 0;
|
|
8164
|
+
svuint8_t m4b = svdup_n_u8(0xf);
|
|
8165
|
+
svint32_t vzero = svdup_n_s32(0);
|
|
8166
|
+
svuint8_t mone = svdup_n_u8(0x30);
|
|
8167
|
+
svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
|
|
8168
|
+
svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
|
|
8169
|
+
|
|
8170
|
+
for (int i = 0; i < nb; ++i) {
|
|
8171
|
+
const float d_all = GGML_FP16_TO_FP32(x[i].d);
|
|
8172
|
+
|
|
8173
|
+
const uint8_t * GGML_RESTRICT q6 = x[i].ql;
|
|
8174
|
+
const uint8_t * GGML_RESTRICT qh = x[i].qh;
|
|
8175
|
+
const int8_t * GGML_RESTRICT q8 = y[i].qs;
|
|
8176
|
+
|
|
8177
|
+
const int8_t * GGML_RESTRICT scale = x[i].scales;
|
|
8178
|
+
|
|
8179
|
+
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
|
|
8180
|
+
const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
|
|
8181
|
+
const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
|
|
8182
|
+
const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
|
|
8183
|
+
const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
|
|
8184
|
+
const svint64_t prod = svdup_n_s64(0);
|
|
8185
|
+
int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
|
|
8186
|
+
svdot_s64(prod, q8sums_2, q6scales_2)));
|
|
8187
|
+
int32_t isum = 0;
|
|
8188
|
+
|
|
8189
|
+
switch (vector_length) {
|
|
8190
|
+
case 128:
|
|
8191
|
+
{
|
|
8192
|
+
const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
|
|
8193
|
+
const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
|
|
8194
|
+
svint32_t isum_tmp = svdup_n_s32(0);
|
|
8195
|
+
for (int j = 0; j < QK_K/128; ++j) {
|
|
8196
|
+
svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
|
|
8197
|
+
svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
|
|
8198
|
+
qh += 32;
|
|
8199
|
+
svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
|
|
8200
|
+
svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
|
|
8201
|
+
svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
|
|
8202
|
+
svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
|
|
8203
|
+
q6 += 64;
|
|
8204
|
+
svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
|
|
8205
|
+
svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
|
|
8206
|
+
svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
|
|
8207
|
+
svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
|
|
8208
|
+
q8 += 64;
|
|
8209
|
+
|
|
8210
|
+
q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
|
|
8211
|
+
q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
|
|
8212
|
+
q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
|
|
8213
|
+
q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
|
|
8214
|
+
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
|
|
8215
|
+
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
|
|
8216
|
+
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
|
|
8217
|
+
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
|
|
8218
|
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
|
|
8219
|
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
|
|
8220
|
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
|
|
8221
|
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
|
|
8222
|
+
|
|
8223
|
+
scale += 4;
|
|
8224
|
+
q8bytes_1 = svld1_s8(pg8_16, q8);
|
|
8225
|
+
q8bytes_2 = svld1_s8(pg8_16, q8+16);
|
|
8226
|
+
q8bytes_3 = svld1_s8(pg8_16, q8+32);
|
|
8227
|
+
q8bytes_4 = svld1_s8(pg8_16, q8+48);
|
|
8228
|
+
q8 += 64;
|
|
8229
|
+
|
|
8230
|
+
q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
|
|
8231
|
+
q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
|
|
8232
|
+
q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
|
|
8233
|
+
q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
|
|
8234
|
+
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
|
|
8235
|
+
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
|
|
8236
|
+
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
|
|
8237
|
+
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
|
|
8238
|
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
|
|
8239
|
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
|
|
8240
|
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
|
|
8241
|
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
|
|
8242
|
+
scale += 4;
|
|
8243
|
+
}
|
|
8244
|
+
isum += svaddv_s32(pg32_4, isum_tmp);
|
|
8245
|
+
sum += d_all * y[i].d * (isum - 32 * isum_mins);
|
|
8246
|
+
}
|
|
8247
|
+
break;
|
|
8248
|
+
case 256:
|
|
8249
|
+
case 512:
|
|
8250
|
+
{
|
|
8251
|
+
const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
|
|
8252
|
+
const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
|
|
8253
|
+
const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
|
|
8254
|
+
svint32_t isum_tmp = svdup_n_s32(0);
|
|
8255
|
+
for (int j = 0; j < QK_K/128; j++) {
|
|
8256
|
+
svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
|
|
8257
|
+
qh += 32;
|
|
8258
|
+
svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
|
|
8259
|
+
svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
|
|
8260
|
+
q6 += 64;
|
|
8261
|
+
svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
|
|
8262
|
+
svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
|
|
8263
|
+
svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
|
|
8264
|
+
svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
|
|
8265
|
+
q8 += 128;
|
|
8266
|
+
q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
|
|
8267
|
+
q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
|
|
8268
|
+
q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
|
|
8269
|
+
q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
|
|
8270
|
+
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
|
|
8271
|
+
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
|
|
8272
|
+
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
|
|
8273
|
+
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
|
|
8274
|
+
|
|
8275
|
+
svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
|
|
8276
|
+
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
|
|
8277
|
+
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
|
|
8278
|
+
svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
|
|
8279
|
+
scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
|
|
8280
|
+
scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
|
|
8281
|
+
svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
|
|
8282
|
+
scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
|
|
8283
|
+
scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
|
|
8284
|
+
svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
|
|
8285
|
+
scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
|
|
8286
|
+
scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
|
|
8287
|
+
svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
|
|
8288
|
+
svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
|
|
8289
|
+
svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
|
|
8290
|
+
svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
|
|
8291
|
+
|
|
8292
|
+
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
|
|
8293
|
+
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
|
|
8294
|
+
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
|
|
8295
|
+
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
|
|
8296
|
+
scale += 8;
|
|
8297
|
+
}
|
|
8298
|
+
isum += svaddv_s32(pg32_8, isum_tmp);
|
|
8299
|
+
sum += d_all * y[i].d * (isum - 32 * isum_mins);
|
|
8300
|
+
}
|
|
8301
|
+
break;
|
|
8302
|
+
default:
|
|
8303
|
+
assert(false && "Unsupported vector length");
|
|
8304
|
+
break;
|
|
8305
|
+
}
|
|
8306
|
+
}
|
|
8307
|
+
|
|
8308
|
+
*s = sum;
|
|
8309
|
+
|
|
8310
|
+
#elif __ARM_NEON
|
|
8162
8311
|
float sum = 0;
|
|
8163
8312
|
|
|
8164
8313
|
const uint8x16_t m4b = vdupq_n_u8(0xF);
|
|
@@ -8548,6 +8548,69 @@ static void ggml_compute_forward_group_norm(
|
|
|
8548
8548
|
}
|
|
8549
8549
|
}
|
|
8550
8550
|
|
|
8551
|
+
// ggml_compute_forward_l2_norm
|
|
8552
|
+
|
|
8553
|
+
static void ggml_compute_forward_l2_norm_f32(
|
|
8554
|
+
const struct ggml_compute_params * params,
|
|
8555
|
+
struct ggml_tensor * dst) {
|
|
8556
|
+
|
|
8557
|
+
const struct ggml_tensor * src0 = dst->src[0];
|
|
8558
|
+
|
|
8559
|
+
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
|
8560
|
+
|
|
8561
|
+
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
|
8562
|
+
|
|
8563
|
+
const int ith = params->ith;
|
|
8564
|
+
const int nth = params->nth;
|
|
8565
|
+
|
|
8566
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
|
8567
|
+
|
|
8568
|
+
float eps;
|
|
8569
|
+
memcpy(&eps, dst->op_params, sizeof(float));
|
|
8570
|
+
|
|
8571
|
+
GGML_ASSERT(eps >= 0.0f);
|
|
8572
|
+
|
|
8573
|
+
// TODO: optimize
|
|
8574
|
+
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
8575
|
+
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
8576
|
+
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
|
8577
|
+
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
8578
|
+
|
|
8579
|
+
ggml_float sum = 0.0;
|
|
8580
|
+
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
8581
|
+
sum += (ggml_float)(x[i00] * x[i00]);
|
|
8582
|
+
}
|
|
8583
|
+
|
|
8584
|
+
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
|
8585
|
+
|
|
8586
|
+
memcpy(y, x, ne00 * sizeof(float));
|
|
8587
|
+
|
|
8588
|
+
const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
|
|
8589
|
+
|
|
8590
|
+
ggml_vec_scale_f32(ne00, y, scale);
|
|
8591
|
+
}
|
|
8592
|
+
}
|
|
8593
|
+
}
|
|
8594
|
+
}
|
|
8595
|
+
|
|
8596
|
+
static void ggml_compute_forward_l2_norm(
|
|
8597
|
+
const struct ggml_compute_params * params,
|
|
8598
|
+
struct ggml_tensor * dst) {
|
|
8599
|
+
|
|
8600
|
+
const struct ggml_tensor * src0 = dst->src[0];
|
|
8601
|
+
|
|
8602
|
+
switch (src0->type) {
|
|
8603
|
+
case GGML_TYPE_F32:
|
|
8604
|
+
{
|
|
8605
|
+
ggml_compute_forward_l2_norm_f32(params, dst);
|
|
8606
|
+
} break;
|
|
8607
|
+
default:
|
|
8608
|
+
{
|
|
8609
|
+
GGML_ABORT("fatal error");
|
|
8610
|
+
}
|
|
8611
|
+
}
|
|
8612
|
+
}
|
|
8613
|
+
|
|
8551
8614
|
// ggml_compute_forward_mul_mat
|
|
8552
8615
|
|
|
8553
8616
|
static void ggml_compute_forward_mul_mat_one_chunk(
|
|
@@ -13604,6 +13667,184 @@ static void ggml_compute_forward_gla(
|
|
|
13604
13667
|
}
|
|
13605
13668
|
}
|
|
13606
13669
|
|
|
13670
|
+
// ggml_compute_forward_rwkv_wkv7
|
|
13671
|
+
|
|
13672
|
+
static void ggml_compute_forward_rwkv_wkv7_f32(
|
|
13673
|
+
const struct ggml_compute_params * params,
|
|
13674
|
+
struct ggml_tensor * dst) {
|
|
13675
|
+
const int64_t T = dst->src[1]->ne[2];
|
|
13676
|
+
const int64_t C = dst->ne[0];
|
|
13677
|
+
const int64_t HEADS = dst->src[1]->ne[1];
|
|
13678
|
+
const int64_t n_seqs = dst->src[6]->ne[1];
|
|
13679
|
+
const int64_t head_size = C / HEADS;
|
|
13680
|
+
|
|
13681
|
+
float * dst_data = (float *) dst->data;
|
|
13682
|
+
float * state = ((float *) dst->data) + C * T;
|
|
13683
|
+
|
|
13684
|
+
const int ith = params->ith;
|
|
13685
|
+
const int nth = params->nth;
|
|
13686
|
+
|
|
13687
|
+
if (ith >= HEADS) {
|
|
13688
|
+
return;
|
|
13689
|
+
}
|
|
13690
|
+
|
|
13691
|
+
const int h_start = (HEADS * ith) / nth;
|
|
13692
|
+
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
|
13693
|
+
(HEADS * (ith + 1)) / nth : HEADS;
|
|
13694
|
+
|
|
13695
|
+
float * r = (float *) dst->src[0]->data;
|
|
13696
|
+
float * w = (float *) dst->src[1]->data;
|
|
13697
|
+
float * k = (float *) dst->src[2]->data;
|
|
13698
|
+
float * v = (float *) dst->src[3]->data;
|
|
13699
|
+
float * a = (float *) dst->src[4]->data;
|
|
13700
|
+
float * b = (float *) dst->src[5]->data;
|
|
13701
|
+
|
|
13702
|
+
int64_t t_stride = HEADS * head_size; // Same to C
|
|
13703
|
+
|
|
13704
|
+
int64_t h_stride = C / HEADS;
|
|
13705
|
+
GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
|
|
13706
|
+
int64_t h_stride_2d = head_size * head_size;
|
|
13707
|
+
|
|
13708
|
+
#if defined(GGML_SIMD)
|
|
13709
|
+
for (int64_t t = 0; t < T; t++) {
|
|
13710
|
+
int64_t t_offset = t * t_stride;
|
|
13711
|
+
int64_t state_offset = head_size * C * (t / (T / n_seqs));
|
|
13712
|
+
float * state_cur = state + state_offset;
|
|
13713
|
+
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
|
|
13714
|
+
|
|
13715
|
+
for (int64_t h = h_start; h < h_end; h++) {
|
|
13716
|
+
int64_t h_offset = h * h_stride;
|
|
13717
|
+
int64_t t_h_offset = t_offset + h_offset;
|
|
13718
|
+
int64_t h_2d_offset = h * h_stride_2d;
|
|
13719
|
+
|
|
13720
|
+
for (int64_t ii = 0; ii < head_size; ii++) {
|
|
13721
|
+
int64_t t_h_i_offset = t_h_offset + ii;
|
|
13722
|
+
int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
|
|
13723
|
+
|
|
13724
|
+
GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
|
|
13725
|
+
|
|
13726
|
+
float sa = 0;
|
|
13727
|
+
{
|
|
13728
|
+
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
|
|
13729
|
+
GGML_F32_VEC ax[GGML_F32_ARR];
|
|
13730
|
+
GGML_F32_VEC ay[GGML_F32_ARR];
|
|
13731
|
+
for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
|
|
13732
|
+
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
|
|
13733
|
+
ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
|
|
13734
|
+
ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
|
|
13735
|
+
sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
|
|
13736
|
+
}
|
|
13737
|
+
}
|
|
13738
|
+
GGML_F32_VEC_REDUCE(sa, sum);
|
|
13739
|
+
}
|
|
13740
|
+
|
|
13741
|
+
GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
|
|
13742
|
+
|
|
13743
|
+
int64_t j = 0;
|
|
13744
|
+
GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
|
|
13745
|
+
for (; j < head_size; j += GGML_F32_STEP) {
|
|
13746
|
+
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
|
|
13747
|
+
int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
|
|
13748
|
+
int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
|
|
13749
|
+
|
|
13750
|
+
GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
|
|
13751
|
+
GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
|
|
13752
|
+
GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
|
|
13753
|
+
GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
|
|
13754
|
+
|
|
13755
|
+
k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
|
|
13756
|
+
|
|
13757
|
+
GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
|
|
13758
|
+
// kv + s * decay + sa * b
|
|
13759
|
+
state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
|
|
13760
|
+
state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
|
|
13761
|
+
GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
|
|
13762
|
+
|
|
13763
|
+
result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
|
|
13764
|
+
}
|
|
13765
|
+
}
|
|
13766
|
+
GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
|
|
13767
|
+
|
|
13768
|
+
// There shouldn't be left-overs though.
|
|
13769
|
+
for (; j < head_size; j++) {
|
|
13770
|
+
int64_t t_h_j_offset = t_h_offset + j;
|
|
13771
|
+
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
|
|
13772
|
+
|
|
13773
|
+
float r_val = r[t_h_j_offset];
|
|
13774
|
+
float w_val = w[t_h_j_offset];
|
|
13775
|
+
float k_val = k[t_h_j_offset];
|
|
13776
|
+
float b_val = b[t_h_j_offset];
|
|
13777
|
+
float kv_val = v[t_h_i_offset] * k_val;
|
|
13778
|
+
|
|
13779
|
+
float prev_state_val = state_prev[h_2d_i_j_offset];
|
|
13780
|
+
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
|
|
13781
|
+
dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
|
|
13782
|
+
}
|
|
13783
|
+
}
|
|
13784
|
+
}
|
|
13785
|
+
}
|
|
13786
|
+
#else
|
|
13787
|
+
for (int64_t t = 0; t < T; t++) {
|
|
13788
|
+
int64_t t_offset = t * t_stride;
|
|
13789
|
+
int64_t state_offset = head_size * C * (t / (T / n_seqs));
|
|
13790
|
+
float * state_cur = state + state_offset;
|
|
13791
|
+
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
|
|
13792
|
+
|
|
13793
|
+
for (int64_t h = h_start; h < h_end; h++) {
|
|
13794
|
+
int64_t h_offset = h * h_stride;
|
|
13795
|
+
int64_t t_h_offset = t_offset + h_offset;
|
|
13796
|
+
int64_t h_2d_offset = h * h_stride_2d;
|
|
13797
|
+
|
|
13798
|
+
for (int64_t i = 0; i < head_size; i++) {
|
|
13799
|
+
int64_t t_h_i_offset = t_h_offset + i;
|
|
13800
|
+
int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
|
|
13801
|
+
|
|
13802
|
+
float v_val = v[t_h_i_offset];
|
|
13803
|
+
|
|
13804
|
+
float sa = 0, result = 0;
|
|
13805
|
+
for (int64_t j = 0; j < head_size; j++) {
|
|
13806
|
+
sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
|
|
13807
|
+
}
|
|
13808
|
+
|
|
13809
|
+
for (int64_t j = 0; j < head_size; j++) {
|
|
13810
|
+
int64_t t_h_j_offset = t_h_offset + j;
|
|
13811
|
+
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
|
|
13812
|
+
|
|
13813
|
+
float r_val = r[t_h_j_offset];
|
|
13814
|
+
float w_val = w[t_h_j_offset];
|
|
13815
|
+
float k_val = k[t_h_j_offset];
|
|
13816
|
+
float b_val = b[t_h_j_offset];
|
|
13817
|
+
float kv_val = v_val * k_val;
|
|
13818
|
+
float prev_state_val = state_prev[h_2d_i_j_offset];
|
|
13819
|
+
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
|
|
13820
|
+
result += state_cur[h_2d_i_j_offset] * r_val;
|
|
13821
|
+
}
|
|
13822
|
+
dst_data[t_h_i_offset] = result;
|
|
13823
|
+
}
|
|
13824
|
+
}
|
|
13825
|
+
}
|
|
13826
|
+
#endif
|
|
13827
|
+
}
|
|
13828
|
+
|
|
13829
|
+
|
|
13830
|
+
static void ggml_compute_forward_rwkv_wkv7(
|
|
13831
|
+
const struct ggml_compute_params * params,
|
|
13832
|
+
struct ggml_tensor * dst) {
|
|
13833
|
+
|
|
13834
|
+
const struct ggml_tensor * src0 = dst->src[0];
|
|
13835
|
+
|
|
13836
|
+
switch (src0->type) {
|
|
13837
|
+
case GGML_TYPE_F32:
|
|
13838
|
+
{
|
|
13839
|
+
ggml_compute_forward_rwkv_wkv7_f32(params, dst);
|
|
13840
|
+
} break;
|
|
13841
|
+
default:
|
|
13842
|
+
{
|
|
13843
|
+
GGML_ABORT("fatal error");
|
|
13844
|
+
}
|
|
13845
|
+
}
|
|
13846
|
+
}
|
|
13847
|
+
|
|
13607
13848
|
// ggml_compute_forward_map_unary
|
|
13608
13849
|
|
|
13609
13850
|
static void ggml_compute_forward_map_unary_f32(
|
|
@@ -14170,6 +14411,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
|
14170
14411
|
{
|
|
14171
14412
|
ggml_compute_forward_group_norm(params, tensor);
|
|
14172
14413
|
} break;
|
|
14414
|
+
case GGML_OP_L2_NORM:
|
|
14415
|
+
{
|
|
14416
|
+
ggml_compute_forward_l2_norm(params, tensor);
|
|
14417
|
+
} break;
|
|
14173
14418
|
case GGML_OP_MUL_MAT:
|
|
14174
14419
|
{
|
|
14175
14420
|
ggml_compute_forward_mul_mat(params, tensor);
|
|
@@ -14357,6 +14602,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
|
14357
14602
|
{
|
|
14358
14603
|
ggml_compute_forward_gla(params, tensor);
|
|
14359
14604
|
} break;
|
|
14605
|
+
case GGML_OP_RWKV_WKV7:
|
|
14606
|
+
{
|
|
14607
|
+
ggml_compute_forward_rwkv_wkv7(params, tensor);
|
|
14608
|
+
} break;
|
|
14360
14609
|
case GGML_OP_MAP_UNARY:
|
|
14361
14610
|
{
|
|
14362
14611
|
ggml_unary_op_f32_t fun;
|
|
@@ -14582,6 +14831,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|
|
14582
14831
|
case GGML_OP_NORM:
|
|
14583
14832
|
case GGML_OP_RMS_NORM:
|
|
14584
14833
|
case GGML_OP_RMS_NORM_BACK:
|
|
14834
|
+
case GGML_OP_L2_NORM:
|
|
14585
14835
|
case GGML_OP_GROUP_NORM:
|
|
14586
14836
|
case GGML_OP_CONCAT:
|
|
14587
14837
|
case GGML_OP_MUL_MAT:
|
|
@@ -14648,14 +14898,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|
|
14648
14898
|
case GGML_OP_FLASH_ATTN_BACK:
|
|
14649
14899
|
case GGML_OP_SSM_CONV:
|
|
14650
14900
|
case GGML_OP_SSM_SCAN:
|
|
14901
|
+
case GGML_OP_RWKV_WKV6:
|
|
14902
|
+
case GGML_OP_GATED_LINEAR_ATTN:
|
|
14903
|
+
case GGML_OP_RWKV_WKV7:
|
|
14651
14904
|
{
|
|
14652
14905
|
n_tasks = n_threads;
|
|
14653
14906
|
} break;
|
|
14654
14907
|
case GGML_OP_WIN_PART:
|
|
14655
14908
|
case GGML_OP_WIN_UNPART:
|
|
14656
14909
|
case GGML_OP_GET_REL_POS:
|
|
14657
|
-
case GGML_OP_RWKV_WKV6:
|
|
14658
|
-
case GGML_OP_GATED_LINEAR_ATTN:
|
|
14659
14910
|
case GGML_OP_MAP_UNARY:
|
|
14660
14911
|
case GGML_OP_MAP_BINARY:
|
|
14661
14912
|
case GGML_OP_MAP_CUSTOM1_F32:
|
|
@@ -112,7 +112,7 @@
|
|
|
112
112
|
#define cudaGraphExecDestroy hipGraphExecDestroy
|
|
113
113
|
#define cudaGraphLaunch hipGraphLaunch
|
|
114
114
|
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
|
|
115
|
-
#define
|
|
115
|
+
#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
|
|
116
116
|
#define cudaGraphNodeType hipGraphNodeType
|
|
117
117
|
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
|
|
118
118
|
#define cudaGraphInstantiate hipGraphInstantiate
|
|
@@ -129,6 +129,7 @@
|
|
|
129
129
|
#define cudaGraph_t hipGraph_t
|
|
130
130
|
#define cudaStream_t hipStream_t
|
|
131
131
|
#define cudaSuccess hipSuccess
|
|
132
|
+
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
|
|
132
133
|
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
|
|
133
134
|
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
|
134
135
|
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
|
|
@@ -119,7 +119,7 @@
|
|
|
119
119
|
#define cudaGraphExecDestroy musaGraphExecDestroy
|
|
120
120
|
#define cudaGraphExec_t musaGraphExec_t
|
|
121
121
|
#define cudaGraphExecUpdate musaGraphExecUpdate
|
|
122
|
-
#define
|
|
122
|
+
#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
|
|
123
123
|
#define cudaGraphGetNodes musaGraphGetNodes
|
|
124
124
|
#define cudaGraphInstantiate musaGraphInstantiate
|
|
125
125
|
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
|
|
@@ -132,6 +132,8 @@
|
|
|
132
132
|
#define cudaGraph_t musaGraph_t
|
|
133
133
|
#define cudaKernelNodeParams musaKernelNodeParams
|
|
134
134
|
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
|
135
|
+
#define cudaStreamBeginCapture musaStreamBeginCapture
|
|
135
136
|
#define cudaStreamEndCapture musaStreamEndCapture
|
|
137
|
+
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
|
|
136
138
|
|
|
137
139
|
typedef mt_bfloat16 nv_bfloat16;
|
|
@@ -285,6 +285,13 @@ typedef struct {
|
|
|
285
285
|
float eps;
|
|
286
286
|
} ggml_metal_kargs_rms_norm;
|
|
287
287
|
|
|
288
|
+
typedef struct {
|
|
289
|
+
int32_t ne00;
|
|
290
|
+
int32_t ne00_4;
|
|
291
|
+
uint64_t nb01;
|
|
292
|
+
float eps;
|
|
293
|
+
} ggml_metal_kargs_l2_norm;
|
|
294
|
+
|
|
288
295
|
typedef struct {
|
|
289
296
|
int64_t ne00;
|
|
290
297
|
int64_t ne01;
|
|
@@ -67,10 +67,6 @@ if (MUSAToolkit_FOUND)
|
|
|
67
67
|
add_compile_definitions(GGML_USE_MUSA)
|
|
68
68
|
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
|
|
69
69
|
|
|
70
|
-
if (GGML_CUDA_GRAPHS)
|
|
71
|
-
add_compile_definitions(GGML_CUDA_USE_GRAPHS)
|
|
72
|
-
endif()
|
|
73
|
-
|
|
74
70
|
if (GGML_CUDA_FORCE_MMQ)
|
|
75
71
|
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
|
|
76
72
|
endif()
|