@fugood/llama.node 0.3.14 → 0.3.15

This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (108)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/package.json +1 -1
  18. package/src/llama.cpp/.github/workflows/build.yml +30 -1
  19. package/src/llama.cpp/CMakeLists.txt +9 -1
  20. package/src/llama.cpp/cmake/common.cmake +2 -0
  21. package/src/llama.cpp/common/arg.cpp +20 -2
  22. package/src/llama.cpp/common/common.cpp +6 -3
  23. package/src/llama.cpp/common/speculative.cpp +4 -4
  24. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
  25. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +1 -1
  26. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -1
  27. package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
  28. package/src/llama.cpp/examples/imatrix/imatrix.cpp +1 -1
  29. package/src/llama.cpp/examples/infill/infill.cpp +2 -2
  30. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
  31. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +4 -4
  32. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +1 -1
  33. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -6
  34. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  35. package/src/llama.cpp/examples/main/main.cpp +6 -6
  36. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -5
  37. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  38. package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
  39. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
  40. package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
  41. package/src/llama.cpp/examples/run/run.cpp +91 -46
  42. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
  43. package/src/llama.cpp/examples/server/server.cpp +32 -15
  44. package/src/llama.cpp/examples/server/utils.hpp +3 -1
  45. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
  46. package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
  47. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  48. package/src/llama.cpp/examples/tts/tts.cpp +12 -9
  49. package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
  50. package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
  51. package/src/llama.cpp/ggml/include/ggml.h +24 -0
  52. package/src/llama.cpp/ggml/src/CMakeLists.txt +5 -27
  53. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
  54. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +0 -5
  55. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -7
  56. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +150 -1
  57. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +253 -2
  58. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
  59. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
  60. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
  61. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  62. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +95 -22
  63. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +3 -0
  64. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -1
  65. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +66 -26
  66. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  67. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +12 -13
  68. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
  69. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -2
  70. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +103 -34
  71. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
  72. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +19 -20
  73. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
  74. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
  75. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -1
  76. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
  77. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
  78. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +352 -146
  79. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
  80. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +3 -0
  81. package/src/llama.cpp/ggml/src/ggml.c +85 -2
  82. package/src/llama.cpp/include/llama.h +86 -22
  83. package/src/llama.cpp/src/CMakeLists.txt +5 -2
  84. package/src/llama.cpp/src/llama-adapter.cpp +19 -20
  85. package/src/llama.cpp/src/llama-adapter.h +11 -9
  86. package/src/llama.cpp/src/llama-arch.cpp +102 -16
  87. package/src/llama.cpp/src/llama-arch.h +18 -0
  88. package/src/llama.cpp/src/llama-batch.h +2 -2
  89. package/src/llama.cpp/src/llama-context.cpp +2253 -1222
  90. package/src/llama.cpp/src/llama-context.h +214 -77
  91. package/src/llama.cpp/src/llama-cparams.h +1 -0
  92. package/src/llama.cpp/src/llama-graph.cpp +1662 -0
  93. package/src/llama.cpp/src/llama-graph.h +574 -0
  94. package/src/llama.cpp/src/llama-hparams.cpp +8 -0
  95. package/src/llama.cpp/src/llama-hparams.h +9 -0
  96. package/src/llama.cpp/src/llama-io.cpp +15 -0
  97. package/src/llama.cpp/src/llama-io.h +35 -0
  98. package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
  99. package/src/llama.cpp/src/llama-kv-cache.h +178 -110
  100. package/src/llama.cpp/src/llama-memory.cpp +1 -0
  101. package/src/llama.cpp/src/llama-memory.h +21 -0
  102. package/src/llama.cpp/src/llama-model.cpp +8207 -163
  103. package/src/llama.cpp/src/llama-model.h +34 -1
  104. package/src/llama.cpp/src/llama-quant.cpp +10 -1
  105. package/src/llama.cpp/src/llama.cpp +51 -9984
  106. package/src/llama.cpp/tests/test-backend-ops.cpp +88 -9
  107. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
  108. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
@@ -2790,10 +2790,14 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
             (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
             output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND,
             output_ne_offset);
+        int64_t antiquantGroupSize = 0;
+        if (src0->ne[0] > QK8_0) {
+            antiquantGroupSize = QK8_0;
+        }
 
         ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
             acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr,
-            nullptr, nullptr, nullptr, QK8_0, acl_output_tensor,
+            nullptr, nullptr, nullptr, antiquantGroupSize, acl_output_tensor,
             &workspaceSize, &executor));
         if (workspaceAddr == nullptr) {
             workspaceAddr = workspace_allocator.alloc(workspaceSize);
@@ -2833,7 +2837,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
 
         ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
             acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
-            nullptr, nullptr, nullptr, nullptr, QK8_0,
+            nullptr, nullptr, nullptr, nullptr, antiquantGroupSize,
             acl_output_tensor, &workspaceSize, &executor));
         ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
             workspaceAddr, workspaceSize, executor, ctx.stream()));
@@ -1689,11 +1689,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_MUL_MAT: {
             switch (op->src[0]->type) {
                 case GGML_TYPE_Q8_0:
-                    // Current groupsize should not be greater than k-1 in
-                    // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize
-                    if (op->src[0]->ne[0] <= QK8_0) {
-                        return false;
-                    }
                 case GGML_TYPE_F16:
                 case GGML_TYPE_F32:
                 case GGML_TYPE_Q4_0:
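Note: the CANN hunks above are related. Per the removed comment, aclnnWeightQuantBatchMatmulV2 rejects group sizes greater than k-1, so the hardcoded QK8_0 group size is replaced by a per-tensor choice, which makes the old supports_op guard for small k unnecessary. A minimal sketch of the selection logic (the helper name is ours, not upstream's):

    #include <stdint.h>

    // k is the reduction dimension src0->ne[0]; 32 is QK8_0, the Q8_0 block size
    // in ggml. Returning 0 requests per-channel dequantization for the cases
    // where a group of QK8_0 would violate the API's "group size <= k-1" rule.
    static int64_t pick_antiquant_group_size(int64_t k) {
        const int64_t qk8_0 = 32;
        return k > qk8_0 ? qk8_0 : 0;
    }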
@@ -287,17 +287,25 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 endif()
             endif()
         endif()
-    elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
+    elseif ("${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "ppc64le " OR "${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "powerpc ")
         message(STATUS "PowerPC detected")
-        execute_process(COMMAND bash -c "grep POWER /proc/cpuinfo | head -n 1" OUTPUT_VARIABLE POWER_M)
-        if (${POWER_M} MATCHES "POWER10")
-            list(APPEND ARCH_FLAGS -mcpu=power10)
-        elseif (${POWER_M} MATCHES "POWER9")
-            list(APPEND ARCH_FLAGS -mcpu=power9)
+        if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
+            file(READ "/proc/cpuinfo" POWER10_M)
+        elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "powerpc")
+            execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M)
+        endif()
+
+        string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}")
+        string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
+
+        if (EXTRACTED_NUMBER GREATER_EQUAL 10)
+            list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64)
+        elseif (EXTRACTED_NUMBER EQUAL 9)
+            list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64)
         elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
             list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
         else()
-            list(APPEND ARCH_FLAGS -mcpu=powerpc64 -mtune=native)
+            list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64)
         endif()
     elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
         message(STATUS "loongarch64 detected")
@@ -8158,7 +8158,156 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     const int nb = n / QK_K;
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
+    float sum = 0;
+    svuint8_t m4b = svdup_n_u8(0xf);
+    svint32_t vzero = svdup_n_s32(0);
+    svuint8_t mone = svdup_n_u8(0x30);
+    svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
+    svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * GGML_RESTRICT q6 = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
+
+        const int8_t * GGML_RESTRICT scale = x[i].scales;
+
+        const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
+        const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
+        const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
+        const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
+        const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
+        const svint64_t prod = svdup_n_s64(0);
+        int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
+                                                                                 svdot_s64(prod, q8sums_2, q6scales_2)));
+        int32_t isum = 0;
+
+        switch (vector_length) {
+            case 128:
+                {
+                    const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+                    const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; ++j) {
+                        svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
+                        svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
+                        qh += 32;
+                        svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
+                        svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
+                        svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
+                        svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
+                        q6 += 64;
+                        svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
+                        svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
+                        svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
+                        svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
+                        q8 += 64;
+
+                        q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
+                        q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
+                        q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
+                        q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+
+                        scale += 4;
+                        q8bytes_1 = svld1_s8(pg8_16, q8);
+                        q8bytes_2 = svld1_s8(pg8_16, q8+16);
+                        q8bytes_3 = svld1_s8(pg8_16, q8+32);
+                        q8bytes_4 = svld1_s8(pg8_16, q8+48);
+                        q8 += 64;
+
+                        q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
+                        q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
+                        q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
+                        q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+                        scale += 4;
+                    }
+                    isum += svaddv_s32(pg32_4, isum_tmp);
+                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
+                }
+                break;
+            case 256:
+            case 512:
+                {
+                    const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
+                    const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
+                    const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; j++) {
+                        svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
+                        qh += 32;
+                        svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
+                        svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
+                        q6 += 64;
+                        svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
+                        svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
+                        svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
+                        svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
+                        q8 += 128;
+                        q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
+                        q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
+                        q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
+                        q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
+
+                        svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
+                        scale_lane_1_tmp = svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+                        scale_lane_1_tmp = svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+                        svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
+                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+                        svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
+                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+                        svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
+                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+                        svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
+                        svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
+                        svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
+                        svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
+
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
+                        scale += 8;
+                    }
+                    isum += svaddv_s32(pg32_8, isum_tmp);
+                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
+                }
+                break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+
+    *s = sum;
+
+#elif __ARM_NEON
     float sum = 0;
 
     const uint8x16_t m4b = vdupq_n_u8(0xF);
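Note: the new SVE path and the existing NEON path below compute the same per-weight reconstruction. A scalar sketch of it (our notation, not upstream code): a Q6_K weight is rebuilt from 4 low bits (ql) and 2 high bits (qh), then recentred by -32. The SIMD paths defer the recentring, accumulating dot products on the unshifted values and subtracting 32 * isum_mins (built from the block sums) at the end.

    #include <stdint.h>

    // ql_nibble: 4 low bits taken from x[i].ql; qh_2bits: 2 high bits from x[i].qh.
    static inline int8_t q6_k_reconstruct(uint8_t ql_nibble, uint8_t qh_2bits) {
        return (int8_t)((ql_nibble | (qh_2bits << 4)) - 32);
    }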
@@ -8548,6 +8548,69 @@ static void ggml_compute_forward_group_norm(
     }
 }
 
+// ggml_compute_forward_l2_norm
+
+static void ggml_compute_forward_l2_norm_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    GGML_ASSERT(eps >= 0.0f);
+
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float sum = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum += (ggml_float)(x[i00] * x[i00]);
+                }
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                memcpy(y, x, ne00 * sizeof(float));
+
+                const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_l2_norm(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_l2_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_mul_mat
 
 static void ggml_compute_forward_mul_mat_one_chunk(
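Note: in math terms, the new GGML_OP_L2_NORM forward pass above normalizes each row by its Euclidean norm, with eps (read from op_params) as a lower bound on the denominator:

$$ y_i = \frac{x_i}{\max\!\left(\sqrt{\sum_{j=0}^{ne_{00}-1} x_j^2}\,,\ \varepsilon\right)}, \qquad i = 0, \dots, ne_{00}-1 $$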
@@ -13604,6 +13667,184 @@ static void ggml_compute_forward_gla(
     }
 }
 
+// ggml_compute_forward_rwkv_wkv7
+
+static void ggml_compute_forward_rwkv_wkv7_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[6]->ne[1];
+    const int64_t head_size = C / HEADS;
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                       (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * r = (float *) dst->src[0]->data;
+    float * w = (float *) dst->src[1]->data;
+    float * k = (float *) dst->src[2]->data;
+    float * v = (float *) dst->src[3]->data;
+    float * a = (float *) dst->src[4]->data;
+    float * b = (float *) dst->src[5]->data;
+
+    int64_t t_stride = HEADS * head_size; // Same to C
+
+    int64_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    int64_t h_stride_2d = head_size * head_size;
+
+    #if defined(GGML_SIMD)
+    for (int64_t t = 0; t < T; t++) {
+        int64_t t_offset = t * t_stride;
+        int64_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+        for (int64_t h = h_start; h < h_end; h++) {
+            int64_t h_offset = h * h_stride;
+            int64_t t_h_offset = t_offset + h_offset;
+            int64_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t ii = 0; ii < head_size; ii++) {
+                int64_t t_h_i_offset = t_h_offset + ii;
+                int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                float sa = 0;
+                {
+                    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                    GGML_F32_VEC ax[GGML_F32_ARR];
+                    GGML_F32_VEC ay[GGML_F32_ARR];
+                    for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                            ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                            ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                            sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                        }
+                    }
+                    GGML_F32_VEC_REDUCE(sa, sum);
+                }
+
+                GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+
+                int64_t j = 0;
+                GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                for (; j < head_size; j += GGML_F32_STEP) {
+                    for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                        int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+
+                        GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                        GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                        GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                        GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+
+                        k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+
+                        GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                        // kv + s * decay + sa * b
+                        state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                        state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                        GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+
+                        result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                    }
+                }
+                GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                // There shouldn't be left-overs though.
+                for (; j < head_size; j++) {
+                    int64_t t_h_j_offset = t_h_offset + j;
+                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float r_val = r[t_h_j_offset];
+                    float w_val = w[t_h_j_offset];
+                    float k_val = k[t_h_j_offset];
+                    float b_val = b[t_h_j_offset];
+                    float kv_val = v[t_h_i_offset] * k_val;
+
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                    dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                }
+            }
+        }
+    }
+    #else
+    for (int64_t t = 0; t < T; t++) {
+        int64_t t_offset = t * t_stride;
+        int64_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+        for (int64_t h = h_start; h < h_end; h++) {
+            int64_t h_offset = h * h_stride;
+            int64_t t_h_offset = t_offset + h_offset;
+            int64_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t i = 0; i < head_size; i++) {
+                int64_t t_h_i_offset = t_h_offset + i;
+                int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float v_val = v[t_h_i_offset];
+
+                float sa = 0, result = 0;
+                for (int64_t j = 0; j < head_size; j++) {
+                    sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+                }
+
+                for (int64_t j = 0; j < head_size; j++) {
+                    int64_t t_h_j_offset = t_h_offset + j;
+                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float r_val = r[t_h_j_offset];
+                    float w_val = w[t_h_j_offset];
+                    float k_val = k[t_h_j_offset];
+                    float b_val = b[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                    result += state_cur[h_2d_i_j_offset] * r_val;
+                }
+                dst_data[t_h_i_offset] = result;
+            }
+        }
+    }
+    #endif
+}
+
+
+static void ggml_compute_forward_rwkv_wkv7(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv7_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
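Note: for reference, the scalar fallback above implements, per head and per timestep, the RWKV-7 state update. With S the head_size x head_size state and r, w, k, v, a, b the per-token vectors read from src[0..5]:

$$ sa_i = \sum_j a_j\, S^{\mathrm{prev}}_{ij}, \qquad S_{ij} = S^{\mathrm{prev}}_{ij}\, w_j + v_i\, k_j + sa_i\, b_j, \qquad \mathrm{out}_i = \sum_j S_{ij}\, r_j $$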
@@ -14170,6 +14411,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_group_norm(params, tensor);
             } break;
+        case GGML_OP_L2_NORM:
+            {
+                ggml_compute_forward_l2_norm(params, tensor);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 ggml_compute_forward_mul_mat(params, tensor);
@@ -14357,6 +14602,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gla(params, tensor);
             } break;
+        case GGML_OP_RWKV_WKV7:
+            {
+                ggml_compute_forward_rwkv_wkv7(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -14582,6 +14831,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
         case GGML_OP_RMS_NORM_BACK:
+        case GGML_OP_L2_NORM:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_CONCAT:
         case GGML_OP_MUL_MAT:
@@ -14648,14 +14898,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_FLASH_ATTN_BACK:
         case GGML_OP_SSM_CONV:
        case GGML_OP_SSM_SCAN:
+        case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
+        case GGML_OP_RWKV_WKV7:
            {
                n_tasks = n_threads;
            } break;
        case GGML_OP_WIN_PART:
        case GGML_OP_WIN_UNPART:
        case GGML_OP_GET_REL_POS:
-        case GGML_OP_RWKV_WKV6:
-        case GGML_OP_GATED_LINEAR_ATTN:
        case GGML_OP_MAP_UNARY:
        case GGML_OP_MAP_BINARY:
        case GGML_OP_MAP_CUSTOM1_F32:
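Note: the recurrent ops (WKV6, gated linear attention, WKV7) now run with n_tasks = n_threads instead of single-threaded; the wkv7 kernel above splits heads across threads with a proportional range per thread. A minimal sketch of that partitioning (our helper, mirroring the h_start/h_end computation in ggml_compute_forward_rwkv_wkv7_f32):

    // Thread ith of nth is assigned heads [h_start, h_end); ranges tile
    // [0, n_heads) without overlap. The kernel above additionally returns
    // early when ith >= n_heads, so at most n_heads threads do work.
    static void head_range(int ith, int nth, int n_heads, int * h_start, int * h_end) {
        *h_start = (n_heads * ith) / nth;
        *h_end   = (n_heads * (ith + 1)) / nth;
        if (*h_end > n_heads) {
            *h_end = n_heads;
        }
    }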
@@ -112,7 +112,7 @@
 #define cudaGraphExecDestroy hipGraphExecDestroy
 #define cudaGraphLaunch hipGraphLaunch
 #define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
-#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
+#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
 #define cudaGraphNodeType hipGraphNodeType
 #define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
 #define cudaGraphInstantiate hipGraphInstantiate
@@ -129,6 +129,7 @@
 #define cudaGraph_t hipGraph_t
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
+#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
 #define __trap() do { abort(); __builtin_unreachable(); } while(0)
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
@@ -119,7 +119,7 @@
 #define cudaGraphExecDestroy musaGraphExecDestroy
 #define cudaGraphExec_t musaGraphExec_t
 #define cudaGraphExecUpdate musaGraphExecUpdate
-#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
+#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
 #define cudaGraphGetNodes musaGraphGetNodes
 #define cudaGraphInstantiate musaGraphInstantiate
 #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
@@ -132,6 +132,8 @@
 #define cudaGraph_t musaGraph_t
 #define cudaKernelNodeParams musaKernelNodeParams
 #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
+#define cudaStreamBeginCapture musaStreamBeginCapture
 #define cudaStreamEndCapture musaStreamEndCapture
+#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
 
 typedef mt_bfloat16 nv_bfloat16;
@@ -285,6 +285,13 @@ typedef struct {
     float eps;
 } ggml_metal_kargs_rms_norm;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne00_4;
+    uint64_t nb01;
+    float    eps;
+} ggml_metal_kargs_l2_norm;
+
 typedef struct {
     int64_t ne00;
     int64_t ne01;
@@ -67,10 +67,6 @@ if (MUSAToolkit_FOUND)
     add_compile_definitions(GGML_USE_MUSA)
     add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
 
-    if (GGML_CUDA_GRAPHS)
-        add_compile_definitions(GGML_CUDA_USE_GRAPHS)
-    endif()
-
     if (GGML_CUDA_FORCE_MMQ)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()