@fugood/llama.node 1.4.11 → 1.4.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. package/package.json +15 -15
  2. package/scripts/llama.cpp.patch +31 -31
  3. package/src/llama.cpp/common/arg.cpp +128 -59
  4. package/src/llama.cpp/common/arg.h +1 -0
  5. package/src/llama.cpp/common/chat-parser.cpp +11 -0
  6. package/src/llama.cpp/common/chat.cpp +36 -7
  7. package/src/llama.cpp/common/chat.h +1 -0
  8. package/src/llama.cpp/common/common.cpp +42 -23
  9. package/src/llama.cpp/common/common.h +11 -1
  10. package/src/llama.cpp/common/llguidance.cpp +10 -6
  11. package/src/llama.cpp/common/regex-partial.cpp +13 -13
  12. package/src/llama.cpp/common/sampling.cpp +58 -14
  13. package/src/llama.cpp/common/sampling.h +3 -1
  14. package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
  15. package/src/llama.cpp/ggml/include/ggml-backend.h +1 -1
  16. package/src/llama.cpp/ggml/src/CMakeLists.txt +23 -9
  17. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +12 -2
  18. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +1 -1
  19. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +86 -25
  20. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +15 -8
  21. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +768 -0
  22. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +0 -4
  23. package/src/llama.cpp/include/llama.h +100 -12
  24. package/src/llama.cpp/src/CMakeLists.txt +4 -0
  25. package/src/llama.cpp/src/llama-adapter.cpp +12 -3
  26. package/src/llama.cpp/src/llama-adapter.h +7 -1
  27. package/src/llama.cpp/src/llama-arch.cpp +78 -0
  28. package/src/llama.cpp/src/llama-arch.h +8 -0
  29. package/src/llama.cpp/src/llama-chat.cpp +11 -0
  30. package/src/llama.cpp/src/llama-chat.h +1 -0
  31. package/src/llama.cpp/src/llama-context.cpp +637 -49
  32. package/src/llama.cpp/src/llama-context.h +43 -1
  33. package/src/llama.cpp/src/llama-grammar.cpp +40 -13
  34. package/src/llama.cpp/src/llama-grammar.h +2 -0
  35. package/src/llama.cpp/src/llama-graph.cpp +173 -5
  36. package/src/llama.cpp/src/llama-graph.h +71 -6
  37. package/src/llama.cpp/src/llama-hparams.cpp +4 -0
  38. package/src/llama.cpp/src/llama-hparams.h +12 -5
  39. package/src/llama.cpp/src/llama-kv-cache.h +1 -1
  40. package/src/llama.cpp/src/llama-mmap.cpp +11 -4
  41. package/src/llama.cpp/src/llama-model-loader.cpp +23 -0
  42. package/src/llama.cpp/src/llama-model-loader.h +2 -0
  43. package/src/llama.cpp/src/llama-model-saver.cpp +3 -0
  44. package/src/llama.cpp/src/llama-model.cpp +337 -26
  45. package/src/llama.cpp/src/llama-model.h +13 -2
  46. package/src/llama.cpp/src/llama-sampling.cpp +1259 -186
  47. package/src/llama.cpp/src/llama-sampling.h +19 -7
  48. package/src/llama.cpp/src/llama-vocab.cpp +101 -33
  49. package/src/llama.cpp/src/llama-vocab.h +2 -0
  50. package/src/llama.cpp/src/llama.cpp +87 -64
  51. package/src/llama.cpp/src/models/afmoe.cpp +9 -5
  52. package/src/llama.cpp/src/models/bert.cpp +4 -2
  53. package/src/llama.cpp/src/models/cogvlm.cpp +5 -3
  54. package/src/llama.cpp/src/models/cohere2-iswa.cpp +3 -0
  55. package/src/llama.cpp/src/models/deepseek2.cpp +1 -1
  56. package/src/llama.cpp/src/models/gemma-embedding.cpp +2 -6
  57. package/src/llama.cpp/src/models/gemma2-iswa.cpp +5 -2
  58. package/src/llama.cpp/src/models/gemma3.cpp +3 -4
  59. package/src/llama.cpp/src/models/gemma3n-iswa.cpp +4 -7
  60. package/src/llama.cpp/src/models/llama-iswa.cpp +6 -2
  61. package/src/llama.cpp/src/models/llama.cpp +19 -6
  62. package/src/llama.cpp/src/models/maincoder.cpp +117 -0
  63. package/src/llama.cpp/src/models/mimo2-iswa.cpp +123 -0
  64. package/src/llama.cpp/src/models/models.h +18 -0
  65. package/src/llama.cpp/src/models/modern-bert.cpp +116 -0
  66. package/src/llama.cpp/src/models/openai-moe-iswa.cpp +5 -2
  67. package/src/llama.cpp/src/models/plamo3.cpp +128 -0
  68. package/src/llama.cpp/src/models/smallthinker.cpp +11 -5
  69. package/src/llama.cpp/src/unicode.cpp +23 -14
package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -69,6 +69,10 @@
  #define VECTOR_REGISTERS 16
  #endif

+ #if defined(__riscv_v_intrinsic)
+ #define LMUL 4
+ #endif
+
  #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

  namespace {
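Note on the new `#define LMUL 4` above: in the RISC-V Vector (RVV) extension, LMUL is the register-group multiplier, so a single vector operand spans LMUL architectural vector registers and carries LMUL times as many lanes. A minimal stand-alone sketch of the effect (illustration only, not part of the package; assumes an RVV-enabled C++ toolchain and the standard <riscv_vector.h> intrinsics header that the kernels below also rely on):

    // Illustration: vsetvlmax reports how many 32-bit lanes fit in one operand at a given LMUL.
    #include <riscv_vector.h>
    #include <cstdio>

    int main() {
        size_t lanes_m1 = __riscv_vsetvlmax_e32m1(); // f32 lanes per operand at LMUL=1
        size_t lanes_m4 = __riscv_vsetvlmax_e32m4(); // f32 lanes per operand at LMUL=4 (4x lanes_m1)
        std::printf("e32 lanes: m1=%zu, m4=%zu\n", lanes_m1, lanes_m4);
        return 0;
    }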
@@ -175,6 +179,46 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
  }
  #endif

+ #if defined(__riscv_zvfh)
+ template <>
+ inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
+ return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
+ }
+ inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
+ return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
+ }
+ inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
+ return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
+ }
+ inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
+ return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
+ }
+ inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
+ return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
+ }
+ inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
+ return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
+ }
+ inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
+ return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
+ }
+ inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
+ return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
+ }
+ #endif
+
+ #if defined(__riscv_zvfbfwma)
+ inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
+ return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
+ }
+ inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
+ return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
+ }
+ inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
+ return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
+ }
+ #endif
+
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // VECTORIZED HORIZONTAL SUM

@@ -227,6 +271,25 @@ inline float hsum(__m512 x) {
  }
  #endif // __AVX512F__

+ #if defined(__riscv_zvfh)
+ inline float hsum(vfloat32m1_t x) {
+ return __riscv_vfmv_f_s_f32m1_f32(
+ __riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1()));
+ }
+ inline float hsum(vfloat32m2_t x) {
+ return __riscv_vfmv_f_s_f32m1_f32(
+ __riscv_vfredusum_vs_f32m2_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m2()));
+ }
+ inline float hsum(vfloat32m4_t x) {
+ return __riscv_vfmv_f_s_f32m1_f32(
+ __riscv_vfredusum_vs_f32m4_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m4()));
+ }
+ inline float hsum(vfloat32m8_t x) {
+ return __riscv_vfmv_f_s_f32m1_f32(
+ __riscv_vfredusum_vs_f32m8_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m8()));
+ }
+ #endif
+
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // VECTORIZED MEMORY LOADING

@@ -315,6 +378,88 @@ template <> inline __m256bh load(const float *p) {
  }
  #endif

+ #if defined(__riscv_zvfh)
+ template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
+ return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());
+ }
+ template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
+ return __riscv_vle16_v_f16m1(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m1());
+ }
+ template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
+ return __riscv_vle16_v_f16m2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m2());
+ }
+ template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
+ return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
+ }
+ template <> inline vfloat32m1_t load(const float *p) {
+ return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1());
+ }
+ template <> inline vfloat32m2_t load(const float *p) {
+ return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2());
+ }
+ template <> inline vfloat32m4_t load(const float *p) {
+ return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4());
+ }
+ template <> inline vfloat32m8_t load(const float *p) {
+ return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8());
+ }
+ #endif
+
+ #if defined(__riscv_zvfbfwma)
+ template <> inline vbfloat16mf2_t load(const ggml_bf16_t *p) {
+ return __riscv_vle16_v_bf16mf2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16mf2());
+ }
+ template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) {
+ return __riscv_vle16_v_bf16m1(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m1());
+ }
+ template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) {
+ return __riscv_vle16_v_bf16m2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m2());
+ }
+ #endif
+
+ #if defined(__riscv_zvfh)
+ template <typename T> T set_zero();
+
+ template <> inline vfloat16mf2_t set_zero() {
+ return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2());
+ }
+ template <> inline vfloat16m1_t set_zero() {
+ return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1());
+ }
+ template <> inline vfloat16m2_t set_zero() {
+ return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2());
+ }
+ template <> inline vfloat16m4_t set_zero() {
+ return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4());
+ }
+ template <> inline vfloat32m1_t set_zero() {
+ return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1());
+ }
+ template <> inline vfloat32m2_t set_zero() {
+ return __riscv_vfmv_v_f_f32m2(0, __riscv_vsetvlmax_e32m2());
+ }
+ template <> inline vfloat32m4_t set_zero() {
+ return __riscv_vfmv_v_f_f32m4(0, __riscv_vsetvlmax_e32m4());
+ }
+ template <> inline vfloat32m8_t set_zero() {
+ return __riscv_vfmv_v_f_f32m8(0, __riscv_vsetvlmax_e32m8());
+ }
+ #endif
+
+ #if defined(__riscv_v_intrinsic)
+ template <typename T> size_t vlmax() {
+ if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
+ else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
+ else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
+ else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
+ else if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
+ else if constexpr (std::is_same_v<T, vfloat32m2_t>) { return __riscv_vsetvlmax_e32m2(); }
+ else if constexpr (std::is_same_v<T, vfloat32m4_t>) { return __riscv_vsetvlmax_e32m4(); }
+ else if constexpr (std::is_same_v<T, vfloat32m8_t>) { return __riscv_vsetvlmax_e32m8(); }
+ return 0;
+ }
+ #endif
+
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // FLOATING POINT MATRIX MULTIPLICATION

@@ -488,6 +633,573 @@ class tinyBLAS {
  const int64_t ldc;
  };

+ #if defined(__riscv_v_intrinsic)
+ template <typename D, typename V, typename TA, typename TB, typename TC>
+ class tinyBLAS_RVV {
+ public:
+ tinyBLAS_RVV(const ggml_compute_params * params, int64_t k,
+ const TA *A, int64_t lda,
+ const TB *B, int64_t ldb,
+ TC *C, int64_t ldc)
+ : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
+ }
+
+ bool matmul(int64_t m, int64_t n) {
+ if (k % vlmax<V>() != 0) {
+ return false;
+ }
+
+ #if LMUL == 1
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+ mnpack<4, 6, 4>(m, n, SIZE_N, 12);
+ return true;
+ }
+ if (m % 8 == 0 ) {
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+ mnpack<4, 6, 2>(m, n, SIZE_N, 12);
+ return true;
+ }
+ if (m % 4 == 0) {
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+ mnpack<4, 6, 1>(m, n, SIZE_N, 12);
+ return true;
+ }
+ #elif LMUL == 2
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+ mnpack<4, 3, 4>(m, n, SIZE_N, 24);
+ return true;
+ }
+ if (m % 8 == 0 ) {
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+ mnpack<4, 3, 2>(m, n, SIZE_N, 24);
+ return true;
+ }
+ if (m % 4 == 0) {
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+ mnpack<4, 3, 1>(m, n, SIZE_N, 24);
+ return true;
+ }
+ #else // LMUL = 4
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
+ const int64_t SIZE_N = BLOCK_SIZE<2>(n);
+ mnpack<2, 2, 8>(m, n, SIZE_N, 36);
+ return true;
+ }
+ if (m % 8 == 0 ) {
+ const int64_t SIZE_N = BLOCK_SIZE<2>(n);
+ mnpack<2, 2, 4>(m, n, SIZE_N, 36);
+ return true;
+ }
+ if (m % 4 == 0) {
+ const int64_t SIZE_N = BLOCK_SIZE<2>(n);
+ mnpack<2, 2, 2>(m, n, SIZE_N, 36);
+ return true;
+ }
+ #endif
+ return false;
+ }
+
+ private:
+ template<int RM, int RN, int BM>
+ inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
+ if (SIZE_N == RN) {
+ return gemm<RM, RN, BM>(m, n, BN);
+ }
+ if constexpr (RN > 1) {
+ return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
+ } else {
+ GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
+ GGML_ASSERT(false); // we have miss something.
+ }
+ }
+
+ inline void gemm_bloc_4x6(int64_t ii, int64_t jj) {
+ size_t vl = vlmax<V>();
+ D Cv00 = set_zero<D>();
+ D Cv01 = set_zero<D>();
+ D Cv02 = set_zero<D>();
+ D Cv03 = set_zero<D>();
+ D Cv10 = set_zero<D>();
+ D Cv11 = set_zero<D>();
+ D Cv12 = set_zero<D>();
+ D Cv13 = set_zero<D>();
+ D Cv20 = set_zero<D>();
+ D Cv21 = set_zero<D>();
+ D Cv22 = set_zero<D>();
+ D Cv23 = set_zero<D>();
+ D Cv30 = set_zero<D>();
+ D Cv31 = set_zero<D>();
+ D Cv32 = set_zero<D>();
+ D Cv33 = set_zero<D>();
+ D Cv40 = set_zero<D>();
+ D Cv41 = set_zero<D>();
+ D Cv42 = set_zero<D>();
+ D Cv43 = set_zero<D>();
+ D Cv50 = set_zero<D>();
+ D Cv51 = set_zero<D>();
+ D Cv52 = set_zero<D>();
+ D Cv53 = set_zero<D>();
+
+ for (int64_t l = 0; l < k; l += vl) {
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
+ V Bv3 = load<V>(B + ldb * (jj + 3) + l);
+ V Bv4 = load<V>(B + ldb * (jj + 4) + l);
+ V Bv5 = load<V>(B + ldb * (jj + 5) + l);
+
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
+ Cv00 = madd(Av0, Bv0, Cv00);
+ Cv10 = madd(Av0, Bv1, Cv10);
+ Cv20 = madd(Av0, Bv2, Cv20);
+ Cv30 = madd(Av0, Bv3, Cv30);
+ Cv40 = madd(Av0, Bv4, Cv40);
+ Cv50 = madd(Av0, Bv5, Cv50);
+
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
+ Cv01 = madd(Av1, Bv0, Cv01);
+ Cv11 = madd(Av1, Bv1, Cv11);
+ Cv21 = madd(Av1, Bv2, Cv21);
+ Cv31 = madd(Av1, Bv3, Cv31);
+ Cv41 = madd(Av1, Bv4, Cv41);
+ Cv51 = madd(Av1, Bv5, Cv51);
+
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
+ Cv02 = madd(Av2, Bv0, Cv02);
+ Cv12 = madd(Av2, Bv1, Cv12);
+ Cv22 = madd(Av2, Bv2, Cv22);
+ Cv32 = madd(Av2, Bv3, Cv32);
+ Cv42 = madd(Av2, Bv4, Cv42);
+ Cv52 = madd(Av2, Bv5, Cv52);
+
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
+ Cv03 = madd(Av3, Bv0, Cv03);
+ Cv13 = madd(Av3, Bv1, Cv13);
+ Cv23 = madd(Av3, Bv2, Cv23);
+ Cv33 = madd(Av3, Bv3, Cv33);
+ Cv43 = madd(Av3, Bv4, Cv43);
+ Cv53 = madd(Av3, Bv5, Cv53);
+ }
+
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
+ C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
+ C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
+ C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
+ C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
+ C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
+ C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
+ C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
+ C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
+ C[ldc * (jj + 5) + (ii + 0)] = hsum(Cv50);
+ C[ldc * (jj + 5) + (ii + 1)] = hsum(Cv51);
+ C[ldc * (jj + 5) + (ii + 2)] = hsum(Cv52);
+ C[ldc * (jj + 5) + (ii + 3)] = hsum(Cv53);
+ }
+
+ inline void gemm_bloc_4x5(int64_t ii, int64_t jj) {
+ size_t vl = vlmax<V>();
+ D Cv00 = set_zero<D>();
+ D Cv01 = set_zero<D>();
+ D Cv02 = set_zero<D>();
+ D Cv03 = set_zero<D>();
+ D Cv10 = set_zero<D>();
+ D Cv11 = set_zero<D>();
+ D Cv12 = set_zero<D>();
+ D Cv13 = set_zero<D>();
+ D Cv20 = set_zero<D>();
+ D Cv21 = set_zero<D>();
+ D Cv22 = set_zero<D>();
+ D Cv23 = set_zero<D>();
+ D Cv30 = set_zero<D>();
+ D Cv31 = set_zero<D>();
+ D Cv32 = set_zero<D>();
+ D Cv33 = set_zero<D>();
+ D Cv40 = set_zero<D>();
+ D Cv41 = set_zero<D>();
+ D Cv42 = set_zero<D>();
+ D Cv43 = set_zero<D>();
+
+ for (int64_t l = 0; l < k; l += vl) {
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
+ V Bv3 = load<V>(B + ldb * (jj + 3) + l);
+ V Bv4 = load<V>(B + ldb * (jj + 4) + l);
+
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
+ Cv00 = madd(Av0, Bv0, Cv00);
+ Cv10 = madd(Av0, Bv1, Cv10);
+ Cv20 = madd(Av0, Bv2, Cv20);
+ Cv30 = madd(Av0, Bv3, Cv30);
+ Cv40 = madd(Av0, Bv4, Cv40);
+
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
+ Cv01 = madd(Av1, Bv0, Cv01);
+ Cv11 = madd(Av1, Bv1, Cv11);
+ Cv21 = madd(Av1, Bv2, Cv21);
+ Cv31 = madd(Av1, Bv3, Cv31);
+ Cv41 = madd(Av1, Bv4, Cv41);
+
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
+ Cv02 = madd(Av2, Bv0, Cv02);
+ Cv12 = madd(Av2, Bv1, Cv12);
+ Cv22 = madd(Av2, Bv2, Cv22);
+ Cv32 = madd(Av2, Bv3, Cv32);
+ Cv42 = madd(Av2, Bv4, Cv42);
+
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
+ Cv03 = madd(Av3, Bv0, Cv03);
+ Cv13 = madd(Av3, Bv1, Cv13);
+ Cv23 = madd(Av3, Bv2, Cv23);
+ Cv33 = madd(Av3, Bv3, Cv33);
+ Cv43 = madd(Av3, Bv4, Cv43);
+ }
+
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
+ C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
+ C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
+ C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
+ C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
+ C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
+ C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
+ C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
+ C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
+ }
+
+ inline void gemm_bloc_4x4(int64_t ii, int64_t jj) {
+ size_t vl = vlmax<V>();
+ D Cv00 = set_zero<D>();
+ D Cv01 = set_zero<D>();
+ D Cv02 = set_zero<D>();
+ D Cv03 = set_zero<D>();
+ D Cv10 = set_zero<D>();
+ D Cv11 = set_zero<D>();
+ D Cv12 = set_zero<D>();
+ D Cv13 = set_zero<D>();
+ D Cv20 = set_zero<D>();
+ D Cv21 = set_zero<D>();
+ D Cv22 = set_zero<D>();
+ D Cv23 = set_zero<D>();
+ D Cv30 = set_zero<D>();
+ D Cv31 = set_zero<D>();
+ D Cv32 = set_zero<D>();
+ D Cv33 = set_zero<D>();
+
+ for (int64_t l = 0; l < k; l += vl) {
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
+
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
+ Cv00 = madd(Av0, Bv0, Cv00);
+ Cv01 = madd(Av1, Bv0, Cv01);
+ Cv02 = madd(Av2, Bv0, Cv02);
+ Cv03 = madd(Av3, Bv0, Cv03);
+
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
+ Cv10 = madd(Av0, Bv1, Cv10);
+ Cv11 = madd(Av1, Bv1, Cv11);
+ Cv12 = madd(Av2, Bv1, Cv12);
+ Cv13 = madd(Av3, Bv1, Cv13);
+
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
+ Cv20 = madd(Av0, Bv2, Cv20);
+ Cv21 = madd(Av1, Bv2, Cv21);
+ Cv22 = madd(Av2, Bv2, Cv22);
+ Cv23 = madd(Av3, Bv2, Cv23);
+
+ V Bv3 = load<V>(B + ldb * (jj + 3) + l);
+ Cv30 = madd(Av0, Bv3, Cv30);
+ Cv31 = madd(Av1, Bv3, Cv31);
+ Cv32 = madd(Av2, Bv3, Cv32);
+ Cv33 = madd(Av3, Bv3, Cv33);
+ }
+
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
+ C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
+ C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
+ C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
+ C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
+ }
+
+ inline void gemm_bloc_4x3(int64_t ii, int64_t jj) {
+ size_t vl = vlmax<V>();
+ D Cv00 = set_zero<D>();
+ D Cv01 = set_zero<D>();
+ D Cv02 = set_zero<D>();
+ D Cv03 = set_zero<D>();
+ D Cv10 = set_zero<D>();
+ D Cv11 = set_zero<D>();
+ D Cv12 = set_zero<D>();
+ D Cv13 = set_zero<D>();
+ D Cv20 = set_zero<D>();
+ D Cv21 = set_zero<D>();
+ D Cv22 = set_zero<D>();
+ D Cv23 = set_zero<D>();
+
+ for (int64_t l = 0; l < k; l += vl) {
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
+
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
+ Cv00 = madd(Av0, Bv0, Cv00);
+ Cv01 = madd(Av1, Bv0, Cv01);
+ Cv02 = madd(Av2, Bv0, Cv02);
+ Cv03 = madd(Av3, Bv0, Cv03);
+
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
+ Cv10 = madd(Av0, Bv1, Cv10);
+ Cv11 = madd(Av1, Bv1, Cv11);
+ Cv12 = madd(Av2, Bv1, Cv12);
+ Cv13 = madd(Av3, Bv1, Cv13);
+
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
+ Cv20 = madd(Av0, Bv2, Cv20);
+ Cv21 = madd(Av1, Bv2, Cv21);
+ Cv22 = madd(Av2, Bv2, Cv22);
+ Cv23 = madd(Av3, Bv2, Cv23);
+ }
+
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
+ }
+
+ inline void gemm_bloc_4x2(int64_t ii, int64_t jj) {
+ size_t vl = vlmax<V>();
+ D Cv00 = set_zero<D>();
+ D Cv01 = set_zero<D>();
+ D Cv02 = set_zero<D>();
+ D Cv03 = set_zero<D>();
+ D Cv10 = set_zero<D>();
+ D Cv11 = set_zero<D>();
+ D Cv12 = set_zero<D>();
+ D Cv13 = set_zero<D>();
+
+ for (int64_t l = 0; l < k; l += vl) {
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
+
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
+ Cv00 = madd(Av0, Bv0, Cv00);
+ Cv01 = madd(Av1, Bv0, Cv01);
+ Cv02 = madd(Av2, Bv0, Cv02);
+ Cv03 = madd(Av3, Bv0, Cv03);
+
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
+ Cv10 = madd(Av0, Bv1, Cv10);
+ Cv11 = madd(Av1, Bv1, Cv11);
+ Cv12 = madd(Av2, Bv1, Cv12);
+ Cv13 = madd(Av3, Bv1, Cv13);
+ }
+
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
+ }
+
+ inline void gemm_bloc_4x1(int64_t ii, int64_t jj) {
+ size_t vl = vlmax<V>();
+ D Cv00 = set_zero<D>();
+ D Cv01 = set_zero<D>();
+ D Cv02 = set_zero<D>();
+ D Cv03 = set_zero<D>();
+
+ for (int64_t l = 0; l < k; l += vl) {
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
+
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
+ Cv00 = madd(Av0, Bv0, Cv00);
+ Cv01 = madd(Av1, Bv0, Cv01);
+ Cv02 = madd(Av2, Bv0, Cv02);
+ Cv03 = madd(Av3, Bv0, Cv03);
+ }
+
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
+ }
+
+ inline void gemm_bloc_2x2(int64_t ii, int64_t jj) {
+ size_t vl = vlmax<V>();
+ D Cv00 = set_zero<D>();
+ D Cv01 = set_zero<D>();
+ D Cv10 = set_zero<D>();
+ D Cv11 = set_zero<D>();
+
+ for (int64_t l = 0; l < k; l += vl) {
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
+
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
+ Cv00 = madd(Av0, Bv0, Cv00);
+ Cv01 = madd(Av1, Bv0, Cv01);
+
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
+ Cv10 = madd(Av0, Bv1, Cv10);
+ Cv11 = madd(Av1, Bv1, Cv11);
+ }
+
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
+ }
+
+ inline void gemm_bloc_2x1(int64_t ii, int64_t jj) {
+ size_t vl = vlmax<V>();
+ D Cv00 = set_zero<D>();
+ D Cv01 = set_zero<D>();
+
+ for (int64_t l = 0; l < k; l += vl) {
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
+
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
+ Cv00 = madd(Av0, Bv0, Cv00);
+ Cv01 = madd(Av1, Bv0, Cv01);
+ }
+
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+ }
+
+ template <int RM, int RN>
+ inline void gemm_bloc(int64_t ii, int64_t jj) {
+ if constexpr (RM == 4) {
+ if constexpr (RN == 6) { return gemm_bloc_4x6(ii, jj); }
+ if constexpr (RN == 5) { return gemm_bloc_4x5(ii, jj); }
+ if constexpr (RN == 4) { return gemm_bloc_4x4(ii, jj); }
+ if constexpr (RN == 3) { return gemm_bloc_4x3(ii, jj); }
+ if constexpr (RN == 2) { return gemm_bloc_4x2(ii, jj); }
+ if constexpr (RN == 1) { return gemm_bloc_4x1(ii, jj); }
+ } else if constexpr (RM == 2) {
+ if constexpr (RN == 2) { return gemm_bloc_2x2(ii, jj); }
+ if constexpr (RN == 1) { return gemm_bloc_2x1(ii, jj); }
+ }
+ }
+
+ template <int RM, int RN, int BM>
+ NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
+ GGML_ASSERT(m % (RM * BM) == 0);
+ const int64_t ytiles = m / (RM * BM);
+ const int64_t xtiles = (n + RN -1) / RN;
+ const int64_t jj_RN = (xtiles - (xtiles * RN - n));
+
+ // "round" bloc_size to "nearest" BN
+ const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
+ const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
+ const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
+ const int64_t nb_job = ytiles * NB_BN;
+
+ if (params->ith == 0) {
+ GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
+ ggml_threadpool_chunk_set(params->threadpool, params->nth);
+ }
+
+ ggml_barrier(params->threadpool);
+
+ int64_t job = params->ith;
+ while (job < nb_job) {
+ const int64_t ii = (job % ytiles) * RM * BM;
+ const int64_t jb = job / ytiles;
+ const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN);
+ const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
+
+ const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
+ const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
+ const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
+
+ for (int64_t bi = 0; bi < BM * RM; bi += RM) {
+ int64_t jj = jj0;
+ for (; jj < jj1; jj += RN) {
+ gemm_bloc<RM, RN>(ii + bi, jj);
+ }
+ if constexpr (RN > 1) {
+ for (; jj < jj2; jj += RN - 1) {
+ gemm_bloc<RM, RN-1>(ii + bi, jj);
+ }
+ }
+ GGML_ASSERT(jj == jj2);
+ }
+
+ job = ggml_threadpool_chunk_add(params->threadpool, 1);
+ }
+
+ ggml_barrier(params->threadpool);
+ return;
+ }
+
+ const ggml_compute_params * params;
+ const TA *const A;
+ const TB *const B;
+ TC *const C;
+ const int64_t k;
+ const int64_t lda;
+ const int64_t ldb;
+ const int64_t ldc;
+ };
+ #endif
+
  //////////////////////////////////////////////////////////////////////////////////////////
  // QUANT ZERO MATRIX MULTIPLICATION

@@ -2657,6 +3369,24 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
  params->ith, params->nth};
  tb.matmul(m, n);
  return true;
+ #elif defined(__riscv_zvfh)
+ #if LMUL == 1
+ tinyBLAS_RVV<vfloat32m1_t, vfloat32m1_t, float, float, float> tb{ params,
+ k, (const float *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc};
+ #elif LMUL == 2
+ tinyBLAS_RVV<vfloat32m2_t, vfloat32m2_t, float, float, float> tb{ params,
+ k, (const float *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc};
+ #else // LMUL = 4
+ tinyBLAS_RVV<vfloat32m4_t, vfloat32m4_t, float, float, float> tb{ params,
+ k, (const float *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc};
+ #endif
+ return tb.matmul(m, n);
  #else
  return false;
  #endif
@@ -2699,6 +3429,24 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
  tb.matmul(m, n);
  return true;
  }
+ #elif defined(__riscv_zvfbfwma)
+ #if LMUL == 1
+ tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
+ k, (const ggml_bf16_t *)A, lda,
+ (const ggml_bf16_t *)B, ldb,
+ (float *)C, ldc};
+ #elif LMUL == 2
+ tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
+ k, (const ggml_bf16_t *)A, lda,
+ (const ggml_bf16_t *)B, ldb,
+ (float *)C, ldc};
+ #else // LMUL = 4
+ tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
+ k, (const ggml_bf16_t *)A, lda,
+ (const ggml_bf16_t *)B, ldb,
+ (float *)C, ldc};
+ #endif
+ return tb.matmul(m, n);
  #endif
  return false;
  }
@@ -2748,6 +3496,26 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
  (float *)C, ldc};
  return tb.matmul(m, n);
  }
+ #elif defined(__riscv_zvfh)
+ if (Btype == GGML_TYPE_F16) {
+ #if LMUL == 1
+ tinyBLAS_RVV<vfloat32m1_t, vfloat16mf2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
+ k, (const ggml_fp16_t *)A, lda,
+ (const ggml_fp16_t *)B, ldb,
+ (float *)C, ldc};
+ #elif LMUL == 2
+ tinyBLAS_RVV<vfloat32m2_t, vfloat16m1_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
+ k, (const ggml_fp16_t *)A, lda,
+ (const ggml_fp16_t *)B, ldb,
+ (float *)C, ldc};
+ #else // LMUL = 4
+ tinyBLAS_RVV<vfloat32m4_t, vfloat16m2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
+ k, (const ggml_fp16_t *)A, lda,
+ (const ggml_fp16_t *)B, ldb,
+ (float *)C, ldc};
+ #endif
+ return tb.matmul(m, n);
+ }
  #endif
  return false;
  }
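The new tinyBLAS_RVV kernels dispatched above combine the added load<>/set_zero<>/vlmax<> helpers with madd and hsum into fixed-stride fused multiply-accumulate loops, ending with one horizontal reduction per output element. A rough stand-alone sketch of that inner pattern for a single f32 dot product at LMUL=4 (raw intrinsics in place of the file's templates; illustration only, not part of the package; assumes k is a multiple of the maximum vector length, the same precondition matmul() checks):

    #include <riscv_vector.h>

    // Mirrors one accumulator of the gemm_bloc_* loops: strip-mined FMA, then a horizontal sum.
    static float dot_f32_rvv_lmul4(const float *a, const float *b, size_t k) {
        const size_t vl = __riscv_vsetvlmax_e32m4();             // vlmax<vfloat32m4_t>()
        vfloat32m4_t acc = __riscv_vfmv_v_f_f32m4(0.0f, vl);     // set_zero<vfloat32m4_t>()
        for (size_t l = 0; l < k; l += vl) {                     // k assumed to be a multiple of vl
            vfloat32m4_t av = __riscv_vle32_v_f32m4(a + l, vl);  // load<vfloat32m4_t>(a + l)
            vfloat32m4_t bv = __riscv_vle32_v_f32m4(b + l, vl);  // load<vfloat32m4_t>(b + l)
            acc = __riscv_vfmacc_vv_f32m4(acc, av, bv, vl);      // madd(av, bv, acc)
        }
        // hsum(acc): reduce the LMUL=4 accumulator to a scalar
        return __riscv_vfmv_f_s_f32m1_f32(
            __riscv_vfredusum_vs_f32m4_f32m1(acc, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl));
    }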