@fugood/llama.node 0.3.14 → 0.3.16

This diff shows the published contents of two package versions as they appear in their public registry. It is provided for informational purposes only and reflects only the changes between the released versions.
Files changed (110)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/package.json +1 -1
  18. package/src/llama.cpp/.github/workflows/build.yml +30 -1
  19. package/src/llama.cpp/CMakeLists.txt +9 -1
  20. package/src/llama.cpp/cmake/common.cmake +2 -0
  21. package/src/llama.cpp/common/arg.cpp +20 -2
  22. package/src/llama.cpp/common/common.cpp +6 -3
  23. package/src/llama.cpp/common/speculative.cpp +4 -4
  24. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
  25. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +1 -1
  26. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -1
  27. package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
  28. package/src/llama.cpp/examples/imatrix/imatrix.cpp +1 -1
  29. package/src/llama.cpp/examples/infill/infill.cpp +2 -2
  30. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
  31. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +4 -4
  32. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +1 -1
  33. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -6
  34. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  35. package/src/llama.cpp/examples/main/main.cpp +6 -6
  36. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -5
  37. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  38. package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
  39. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
  40. package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
  41. package/src/llama.cpp/examples/run/run.cpp +91 -46
  42. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
  43. package/src/llama.cpp/examples/server/server.cpp +37 -15
  44. package/src/llama.cpp/examples/server/utils.hpp +3 -1
  45. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
  46. package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
  47. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  48. package/src/llama.cpp/examples/tts/tts.cpp +20 -9
  49. package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
  50. package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
  51. package/src/llama.cpp/ggml/include/ggml.h +24 -0
  52. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -28
  53. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
  54. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +0 -5
  55. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -7
  56. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +1493 -12
  57. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +150 -1
  58. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +284 -29
  59. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
  60. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
  61. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
  62. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  63. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +95 -22
  64. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +35 -12
  65. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -1
  66. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +93 -27
  67. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  68. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +12 -13
  69. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
  70. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
  71. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -2
  72. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +109 -40
  73. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
  74. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +19 -20
  75. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
  76. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
  77. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -1
  78. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
  79. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
  80. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +398 -158
  81. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
  82. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +7 -2
  83. package/src/llama.cpp/ggml/src/ggml.c +85 -2
  84. package/src/llama.cpp/include/llama.h +86 -22
  85. package/src/llama.cpp/src/CMakeLists.txt +5 -2
  86. package/src/llama.cpp/src/llama-adapter.cpp +19 -20
  87. package/src/llama.cpp/src/llama-adapter.h +11 -9
  88. package/src/llama.cpp/src/llama-arch.cpp +103 -16
  89. package/src/llama.cpp/src/llama-arch.h +18 -0
  90. package/src/llama.cpp/src/llama-batch.h +2 -2
  91. package/src/llama.cpp/src/llama-context.cpp +2253 -1222
  92. package/src/llama.cpp/src/llama-context.h +214 -77
  93. package/src/llama.cpp/src/llama-cparams.h +1 -0
  94. package/src/llama.cpp/src/llama-graph.cpp +1662 -0
  95. package/src/llama.cpp/src/llama-graph.h +574 -0
  96. package/src/llama.cpp/src/llama-hparams.cpp +8 -0
  97. package/src/llama.cpp/src/llama-hparams.h +9 -0
  98. package/src/llama.cpp/src/llama-io.cpp +15 -0
  99. package/src/llama.cpp/src/llama-io.h +35 -0
  100. package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
  101. package/src/llama.cpp/src/llama-kv-cache.h +178 -110
  102. package/src/llama.cpp/src/llama-memory.cpp +1 -0
  103. package/src/llama.cpp/src/llama-memory.h +21 -0
  104. package/src/llama.cpp/src/llama-model.cpp +8244 -173
  105. package/src/llama.cpp/src/llama-model.h +34 -1
  106. package/src/llama.cpp/src/llama-quant.cpp +10 -1
  107. package/src/llama.cpp/src/llama.cpp +51 -9984
  108. package/src/llama.cpp/tests/test-backend-ops.cpp +145 -23
  109. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
  110. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
@@ -8158,7 +8158,156 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     const int nb = n / QK_K;
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
+    float sum = 0;
+    svuint8_t m4b = svdup_n_u8(0xf);
+    svint32_t vzero = svdup_n_s32(0);
+    svuint8_t mone = svdup_n_u8(0x30);
+    svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
+    svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * GGML_RESTRICT q6 = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t * GGML_RESTRICT q8 = y[i].qs;
+
+        const int8_t * GGML_RESTRICT scale = x[i].scales;
+
+        const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
+        const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
+        const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
+        const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
+        const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
+        const svint64_t prod = svdup_n_s64(0);
+        int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
+                                                                                 svdot_s64(prod, q8sums_2, q6scales_2)));
+        int32_t isum = 0;
+
+        switch (vector_length) {
+            case 128:
+                {
+                    const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+                    const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; ++j) {
+                        svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
+                        svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
+                        qh += 32;
+                        svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
+                        svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
+                        svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
+                        svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
+                        q6 += 64;
+                        svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
+                        svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
+                        svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
+                        svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
+                        q8 += 64;
+
+                        q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
+                        q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
+                        q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
+                        q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+
+                        scale += 4;
+                        q8bytes_1 = svld1_s8(pg8_16, q8);
+                        q8bytes_2 = svld1_s8(pg8_16, q8+16);
+                        q8bytes_3 = svld1_s8(pg8_16, q8+32);
+                        q8bytes_4 = svld1_s8(pg8_16, q8+48);
+                        q8 += 64;
+
+                        q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
+                        q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
+                        q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
+                        q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+                        scale += 4;
+                    }
+                    isum += svaddv_s32(pg32_4, isum_tmp);
+                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
+                }
+                break;
+            case 256:
+            case 512:
+                {
+                    const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
+                    const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
+                    const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; j++) {
+                        svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
+                        qh += 32;
+                        svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
+                        svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
+                        q6 += 64;
+                        svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
+                        svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
+                        svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
+                        svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
+                        q8 += 128;
+                        q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
+                        q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
+                        q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
+                        q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
+
+                        svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
+                        scale_lane_1_tmp = svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+                        scale_lane_1_tmp = svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+                        svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
+                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+                        svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
+                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+                        svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
+                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+                        svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
+                        svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
+                        svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
+                        svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
+
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
+                        scale += 8;
+                    }
+                    isum += svaddv_s32(pg32_8, isum_tmp);
+                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
+                }
+                break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+
+    *s = sum;
+
+#elif __ARM_NEON
     float sum = 0;
 
     const uint8x16_t m4b = vdupq_n_u8(0xF);
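
The hunk above (from ggml-cpu-quants.c) adds an ARM SVE path for the Q6_K · Q8_K dot product, selected ahead of the existing NEON path and specialized on the runtime vector length (128-bit vs 256/512-bit). For orientation, a minimal scalar sketch of what each super-block contributes, under the Q6_K layout (4 low bits in ql, 2 high bits in qh, per-sub-block int8 scales); this is an illustrative reference, not the shipped kernel:

    // Each 6-bit quant is (ql & 0xF) | (2 bits from qh << 4), giving 0..63.
    // The SVE code never subtracts the 32 bias per element: it accumulates
    // raw products (isum) and removes the bias once per super-block as
    // 32 * sum(scale[j] * bsums[j]) -- the isum_mins term in
    //   sum += d_all * y[i].d * (isum - 32 * isum_mins);
    static inline int32_t q6_times_q8(uint8_t ql4, uint8_t qh2, int8_t q8) {
        const int32_t q6 = (ql4 & 0x0F) | ((qh2 & 0x03) << 4); // 0..63, still biased
        return q6 * q8; // bias handled later via isum_mins
    }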
@@ -3110,17 +3110,17 @@ static void ggml_compute_forward_dup_same_cont(
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
 
-    // parallelize by elements
-    const int ne = ggml_nelements(dst);
-    const int dr = (ne + nth - 1) / nth;
-    const int ie0 = dr * ith;
-    const int ie1 = MIN(ie0 + dr, ne);
+    // parallelize by blocks
+    const int nk = ggml_nelements(src0)/ggml_blck_size(src0->type);
+    const int dr = (nk + nth - 1) / nth;
+    const int k0 = dr * ith;
+    const int k1 = MIN(k0 + dr, nk);
 
-    if (ie0 < ie1) {
+    if (k0 < k1) {
         memcpy(
-            ((char *) dst->data + ie0*nb0),
-            ((char *) src0->data + ie0*nb0),
-            (ie1 - ie0) * nb0);
+            ((char *) dst->data + k0*nb0),
+            ((char *) src0->data + k0*nb0),
+            (k1 - k0) * nb0);
     }
 }
 
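
This hunk (ggml-cpu.c) fixes ggml_compute_forward_dup_same_cont for quantized tensors: work is now split by quantization blocks rather than elements, so nb0 (the byte size of one block) multiplies a block index. The split itself is the usual ceil-divide partition; a minimal sketch, assuming nth threads over nk blocks:

    // Thread ith of nth copies blocks [k0, k1); dr rounds up so the
    // remainder is absorbed by the last thread(s).
    static void thread_range(int nk, int ith, int nth, int * k0, int * k1) {
        const int dr = (nk + nth - 1) / nth;     // ceil(nk / nth)
        *k0 = dr * ith;
        *k1 = (*k0 + dr < nk) ? (*k0 + dr) : nk; // MIN(k0 + dr, nk)
    }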
@@ -4055,7 +4055,6 @@ static void ggml_compute_forward_dup_f32(
 static void ggml_compute_forward_dup_bytes(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-
     const struct ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
@@ -4069,10 +4068,10 @@ static void ggml_compute_forward_dup_bytes(
     }
 
     const size_t type_size = ggml_type_size(src0->type);
+
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
 
-
     // parallelize by rows
     const int nr = ne01;
     // number of rows per thread
@@ -4082,10 +4081,10 @@ static void ggml_compute_forward_dup_bytes(
     const int ir1 = MIN(ir0 + dr, nr);
 
     if (src0->type == dst->type &&
-        ne00 == ne0 &&
+        ggml_are_same_shape(src0, dst) &&
         nb00 == type_size && nb0 == type_size) {
         // copy by rows
-        const size_t rs = ne00 * type_size;
+        const size_t rs = ggml_row_size(src0->type, ne00);
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 for (int64_t i01 = ir0; i01 < ir1; i01++) {
@@ -4140,17 +4139,20 @@ static void ggml_compute_forward_dup_bytes(
     }
 
     // dst counters
-
-    int64_t i10 = 0;
+    int64_t k10 = 0;
     int64_t i11 = 0;
     int64_t i12 = 0;
     int64_t i13 = 0;
 
+    // number of blocks in a row
+    const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
+    const int64_t nk0 = ne0 / ggml_blck_size(dst->type);
+
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
-            i10 += ne00 * ir0;
-            while (i10 >= ne0) {
-                i10 -= ne0;
+            k10 += nk00 * ir0;
+            while (k10 >= nk0) {
+                k10 -= nk0;
                 if (++i11 == ne1) {
                     i11 = 0;
                     if (++i12 == ne2) {
@@ -4162,14 +4164,14 @@ static void ggml_compute_forward_dup_bytes(
                 }
             }
             for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                    char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+                for (int64_t k00 = 0; k00 < nk00; k00++) {
+                    const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                    char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
 
                     memcpy(dst_ptr, src0_ptr, type_size);
 
-                    if (++i10 == ne0) {
-                        i10 = 0;
+                    if (++k10 == nk0) {
+                        k10 = 0;
                         if (++i11 == ne1) {
                             i11 = 0;
                             if (++i12 == ne2) {
@@ -4182,9 +4184,9 @@ static void ggml_compute_forward_dup_bytes(
                     }
                 }
             }
-            i10 += ne00 * (ne01 - ir1);
-            while (i10 >= ne0) {
-                i10 -= ne0;
+            k10 += nk00 * (ne01 - ir1);
+            while (k10 >= nk0) {
+                k10 -= nk0;
                 if (++i11 == ne1) {
                     i11 = 0;
                     if (++i12 == ne2) {
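
The i10 → k10 rename in ggml_compute_forward_dup_bytes is the same block-vs-element fix: for quantized types the unit of a byte-wise copy is one block of ggml_blck_size(type) elements occupying type_size bytes, so the row length and destination counters must be expressed in blocks (nk00, nk0), and the fast path compares full shapes via ggml_are_same_shape plus ggml_row_size. The underlying relationship, as a small sketch:

    // A row of ne00 elements of a quantized type stores ne00 / blck_size
    // blocks of type_size bytes each, so the byte size of a row,
    // ggml_row_size(type, ne00), equals blocks_per_row * type_size
    // (ggml asserts ne00 % blck_size == 0).
    static int64_t blocks_per_row(int64_t ne00, int64_t blck_size) {
        return ne00 / blck_size;
    }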
@@ -8548,6 +8550,69 @@ static void ggml_compute_forward_group_norm(
     }
 }
 
+// ggml_compute_forward_l2_norm
+
+static void ggml_compute_forward_l2_norm_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    GGML_ASSERT(eps >= 0.0f);
+
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float sum = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum += (ggml_float)(x[i00] * x[i00]);
+                }
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                memcpy(y, x, ne00 * sizeof(float));
+
+                const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_l2_norm(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_l2_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_mul_mat
 
 static void ggml_compute_forward_mul_mat_one_chunk(
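
The new GGML_OP_L2_NORM forward pass normalizes each row to unit Euclidean length, with eps acting as a floor on the denominator rather than an additive term. Row-wise, a self-contained sketch of the same computation:

    #include <math.h>

    // y = x / max(||x||_2, eps). Unlike RMS norm's sqrt(mean + eps), the
    // floor form maps an all-zero row to all zeros instead of dividing by ~0.
    void l2_norm_row(const float * x, float * y, int n, float eps) {
        double sum = 0.0;
        for (int j = 0; j < n; j++) sum += (double) x[j] * x[j];
        const float scale = 1.0f / fmaxf(sqrtf((float) sum), eps);
        for (int j = 0; j < n; j++) y[j] = x[j] * scale;
    }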
@@ -13604,6 +13669,184 @@ static void ggml_compute_forward_gla(
     }
 }
 
+// ggml_compute_forward_rwkv_wkv7
+
+static void ggml_compute_forward_rwkv_wkv7_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[6]->ne[1];
+    const int64_t head_size = C / HEADS;
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                        (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * r = (float *) dst->src[0]->data;
+    float * w = (float *) dst->src[1]->data;
+    float * k = (float *) dst->src[2]->data;
+    float * v = (float *) dst->src[3]->data;
+    float * a = (float *) dst->src[4]->data;
+    float * b = (float *) dst->src[5]->data;
+
+    int64_t t_stride = HEADS * head_size; // Same to C
+
+    int64_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    int64_t h_stride_2d = head_size * head_size;
+
+#if defined(GGML_SIMD)
+    for (int64_t t = 0; t < T; t++) {
+        int64_t t_offset = t * t_stride;
+        int64_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+        for (int64_t h = h_start; h < h_end; h++) {
+            int64_t h_offset = h * h_stride;
+            int64_t t_h_offset = t_offset + h_offset;
+            int64_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t ii = 0; ii < head_size; ii++) {
+                int64_t t_h_i_offset = t_h_offset + ii;
+                int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                float sa = 0;
+                {
+                    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                    GGML_F32_VEC ax[GGML_F32_ARR];
+                    GGML_F32_VEC ay[GGML_F32_ARR];
+                    for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                            ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                            ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                            sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                        }
+                    }
+                    GGML_F32_VEC_REDUCE(sa, sum);
+                }
+
+                GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+
+                int64_t j = 0;
+                GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                for (; j < head_size; j += GGML_F32_STEP) {
+                    for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                        int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+
+                        GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                        GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                        GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                        GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+
+                        k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+
+                        GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                        // kv + s * decay + sa * b
+                        state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                        state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                        GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+
+                        result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                    }
+                }
+                GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                // There shouldn't be left-overs though.
+                for (; j < head_size; j++) {
+                    int64_t t_h_j_offset = t_h_offset + j;
+                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float r_val = r[t_h_j_offset];
+                    float w_val = w[t_h_j_offset];
+                    float k_val = k[t_h_j_offset];
+                    float b_val = b[t_h_j_offset];
+                    float kv_val = v[t_h_i_offset] * k_val;
+
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                    dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                }
+            }
+        }
+    }
+#else
+    for (int64_t t = 0; t < T; t++) {
+        int64_t t_offset = t * t_stride;
+        int64_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+        for (int64_t h = h_start; h < h_end; h++) {
+            int64_t h_offset = h * h_stride;
+            int64_t t_h_offset = t_offset + h_offset;
+            int64_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t i = 0; i < head_size; i++) {
+                int64_t t_h_i_offset = t_h_offset + i;
+                int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float v_val = v[t_h_i_offset];
+
+                float sa = 0, result = 0;
+                for (int64_t j = 0; j < head_size; j++) {
+                    sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+                }
+
+                for (int64_t j = 0; j < head_size; j++) {
+                    int64_t t_h_j_offset = t_h_offset + j;
+                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float r_val = r[t_h_j_offset];
+                    float w_val = w[t_h_j_offset];
+                    float k_val = k[t_h_j_offset];
+                    float b_val = b[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                    result += state_cur[h_2d_i_j_offset] * r_val;
+                }
+                dst_data[t_h_i_offset] = result;
+            }
+        }
+    }
+#endif
+}
+
+
+static void ggml_compute_forward_rwkv_wkv7(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv7_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
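
Stripped of the SIMD and layout bookkeeping, the kernel above is the RWKV-7 ("wkv7") per-head recurrence: each head carries a head_size × head_size state S; per token, the a-channel first reads the state (sa), then the state decays by w, absorbs the outer product of v and k, and re-injects sa along b before being read out through r. A scalar sketch for one head and one token, mirroring the non-SIMD branch (illustrative, not the shipped code):

    // S is head_size x head_size, row i stored at S[i*n + j].
    //   sa_i    = sum_j a[j] * S[i][j]
    //   S[i][j] = S[i][j]*w[j] + v[i]*k[j] + sa_i*b[j]   // kv + s*decay + sa*b
    //   out[i]  = sum_j S[i][j] * r[j]
    void wkv7_head_step(int n, float * S, const float * r, const float * w,
                        const float * k, const float * v, const float * a,
                        const float * b, float * out) {
        for (int i = 0; i < n; i++) {
            float sa = 0.0f;
            for (int j = 0; j < n; j++) sa += a[j] * S[i*n + j];
            float result = 0.0f;
            for (int j = 0; j < n; j++) {
                S[i*n + j] = S[i*n + j] * w[j] + v[i] * k[j] + sa * b[j];
                result += S[i*n + j] * r[j];
            }
            out[i] = result;
        }
    }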
@@ -14067,7 +14310,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
     }
 
     // extra_buffer op?
-    if (ggml_cpu_extra_compute_forward(params, tensor)) return;
+    if (ggml_cpu_extra_compute_forward(params, tensor)) {
+        return;
+    }
 
     switch (tensor->op) {
         case GGML_OP_DUP:
@@ -14170,6 +14415,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_group_norm(params, tensor);
             } break;
+        case GGML_OP_L2_NORM:
+            {
+                ggml_compute_forward_l2_norm(params, tensor);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 ggml_compute_forward_mul_mat(params, tensor);
@@ -14357,6 +14606,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gla(params, tensor);
            } break;
+        case GGML_OP_RWKV_WKV7:
+            {
+                ggml_compute_forward_rwkv_wkv7(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -14582,6 +14835,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
         case GGML_OP_RMS_NORM_BACK:
+        case GGML_OP_L2_NORM:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_CONCAT:
         case GGML_OP_MUL_MAT:
@@ -14648,14 +14902,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_FLASH_ATTN_BACK:
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
+        case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
+        case GGML_OP_RWKV_WKV7:
             {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
-        case GGML_OP_RWKV_WKV6:
-        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
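
These dispatch and scheduling hunks are the boilerplate half of adding the new ops: ggml_compute_forward() gains cases routing GGML_OP_L2_NORM and GGML_OP_RWKV_WKV7 to their kernels, ggml_get_n_tasks() grants them n_tasks = n_threads, and GGML_OP_RWKV_WKV6 / GGML_OP_GATED_LINEAR_ATTN move out of the single-task group so they now run multi-threaded too. The WKV kernels then divide heads across those threads; a sketch of the h_start/h_end split used above:

    // Thread ith of nth handles heads [HEADS*ith/nth, HEADS*(ith+1)/nth);
    // the kernel also returns early when ith >= HEADS, since such threads
    // would receive an empty range anyway.
    static void head_range(int HEADS, int ith, int nth, int * h0, int * h1) {
        *h0 = (HEADS * ith) / nth;
        *h1 = (HEADS * (ith + 1)) / nth;
        if (*h1 > HEADS) *h1 = HEADS;
    }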
@@ -112,7 +112,7 @@
 #define cudaGraphExecDestroy hipGraphExecDestroy
 #define cudaGraphLaunch hipGraphLaunch
 #define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
-#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
+#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
 #define cudaGraphNodeType hipGraphNodeType
 #define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
 #define cudaGraphInstantiate hipGraphInstantiate
@@ -129,6 +129,7 @@
 #define cudaGraph_t hipGraph_t
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
+#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
 #define __trap() do { abort(); __builtin_unreachable(); } while(0)
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
@@ -119,7 +119,7 @@
 #define cudaGraphExecDestroy musaGraphExecDestroy
 #define cudaGraphExec_t musaGraphExec_t
 #define cudaGraphExecUpdate musaGraphExecUpdate
-#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
+#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
 #define cudaGraphGetNodes musaGraphGetNodes
 #define cudaGraphInstantiate musaGraphInstantiate
 #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
@@ -132,6 +132,8 @@
 #define cudaGraph_t musaGraph_t
 #define cudaKernelNodeParams musaKernelNodeParams
 #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
+#define cudaStreamBeginCapture musaStreamBeginCapture
 #define cudaStreamEndCapture musaStreamEndCapture
+#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
 
 typedef mt_bfloat16 nv_bfloat16;
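
The hip.h and musa.h hunks extend ggml's vendor shims: the CUDA backend sources are written against CUDA names, and each vendor header #defines those names to their HIP or MUSA equivalents so the same code compiles for AMD and Moore Threads GPUs. This release adds the occupancy query on both shims and stream-capture begin on MUSA. A condensed sketch of the pattern (guard macro names are illustrative, not the actual headers' guards):

    // One source tree, multiple runtimes: only the spelling changes.
    #if defined(USE_HIP_SHIM)            // illustrative guard
    #define cudaStream_t hipStream_t
    #define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
    #elif defined(USE_MUSA_SHIM)         // illustrative guard
    #define cudaStream_t musaStream_t
    #define cudaStreamBeginCapture musaStreamBeginCapture
    #define cudaStreamEndCapture musaStreamEndCapture
    #endif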
@@ -285,6 +285,13 @@ typedef struct {
     float eps;
 } ggml_metal_kargs_rms_norm;
 
+typedef struct {
+    int32_t ne00;
+    int32_t ne00_4;
+    uint64_t nb01;
+    float eps;
+} ggml_metal_kargs_l2_norm;
+
 typedef struct {
     int64_t ne00;
     int64_t ne01;
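
The ggml_metal_kargs_* structs are the argument blocks the Metal backend hands to its kernels, so this host-side layout must match the struct declared in the corresponding .metal shader field-for-field. A hedged sketch of filling the new block (the ne00_4 = ne00/4 convention is an assumption carried over from how the neighboring rms_norm kargs are used, and the helper name is hypothetical):

    #include <stdint.h>

    typedef struct {
        int32_t  ne00;   // row length in floats
        int32_t  ne00_4; // row length in float4 units for the vector loop
        uint64_t nb01;   // byte stride between rows
        float    eps;    // denominator floor, taken from dst->op_params
    } ggml_metal_kargs_l2_norm;

    static ggml_metal_kargs_l2_norm make_l2_norm_args(int64_t ne00, uint64_t nb01, float eps) {
        ggml_metal_kargs_l2_norm args;
        args.ne00   = (int32_t) ne00;
        args.ne00_4 = (int32_t) (ne00/4);
        args.nb01   = nb01;
        args.eps    = eps;
        return args;
    }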
@@ -67,10 +67,6 @@ if (MUSAToolkit_FOUND)
     add_compile_definitions(GGML_USE_MUSA)
     add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
 
-    if (GGML_CUDA_GRAPHS)
-        add_compile_definitions(GGML_CUDA_USE_GRAPHS)
-    endif()
-
     if (GGML_CUDA_FORCE_MMQ)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()