@fugood/llama.node 1.3.7 → 1.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. package/lib/binding.js +18 -1
  2. package/lib/binding.ts +19 -1
  3. package/lib/index.js +3 -3
  4. package/lib/index.ts +1 -1
  5. package/package.json +15 -15
  6. package/scripts/llama.cpp.patch +7 -7
  7. package/src/LlamaCompletionWorker.cpp +2 -2
  8. package/src/llama.cpp/common/arg.cpp +27 -2
  9. package/src/llama.cpp/common/chat-parser.cpp +968 -0
  10. package/src/llama.cpp/common/chat.cpp +0 -952
  11. package/src/llama.cpp/common/common.cpp +55 -0
  12. package/src/llama.cpp/common/common.h +18 -0
  13. package/src/llama.cpp/common/json-schema-to-grammar.cpp +2 -2
  14. package/src/llama.cpp/ggml/CMakeLists.txt +6 -4
  15. package/src/llama.cpp/ggml/include/ggml-rpc.h +1 -1
  16. package/src/llama.cpp/ggml/include/ggml.h +12 -4
  17. package/src/llama.cpp/ggml/src/CMakeLists.txt +26 -4
  18. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +29 -15
  19. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +721 -0
  20. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  21. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +22 -2
  22. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +9 -0
  23. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +71 -4
  24. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  25. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +243 -4
  26. package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +6 -0
  27. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +84 -85
  28. package/src/llama.cpp/include/llama.h +18 -0
  29. package/src/llama.cpp/src/CMakeLists.txt +2 -0
  30. package/src/llama.cpp/src/llama-arch.cpp +95 -16
  31. package/src/llama.cpp/src/llama-arch.h +15 -0
  32. package/src/llama.cpp/src/llama-context.cpp +7 -3
  33. package/src/llama.cpp/src/llama-graph.cpp +3 -3
  34. package/src/llama.cpp/src/llama-hparams.h +1 -1
  35. package/src/llama.cpp/src/llama-model.cpp +141 -6
  36. package/src/llama.cpp/src/llama-model.h +4 -0
  37. package/src/llama.cpp/src/llama-quant.cpp +13 -5
  38. package/src/llama.cpp/src/models/lfm2.cpp +5 -3
  39. package/src/llama.cpp/src/models/models.h +55 -1
  40. package/src/llama.cpp/src/models/qwen3next.cpp +1042 -0
  41. package/src/llama.cpp/src/models/rnd1.cpp +126 -0
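The bulk of the new native code in this release is an ARM64 Q4_K "repack" path: a decode_q4_Kx8_scales_mins helper plus new q4_K 8x4/8x8 GEMV and GEMM kernels, shown in the hunks below. Each 12-byte scale group in this layout packs eight 6-bit scales and eight 6-bit mins. As a reading aid, here is a scalar sketch of the same unpacking that those kernels do with NEON; the helper name and loop are illustrative only (not part of the package), and little-endian byte order is assumed, as on AArch64.

// Illustrative scalar equivalent of the NEON decode_q4_Kx8_scales_mins shown in
// the diff below (assumption: 8 x 6-bit scales and 8 x 6-bit mins in 12 bytes,
// little-endian byte order).
#include <stdint.h>
#include <string.h>

static void decode_q4_Kx8_scales_mins_scalar(const uint8_t * scales_in,
                                             int16_t out_mins[8],
                                             int8_t  out_scales[8]) {
    uint32_t sm[3];
    memcpy(sm, scales_in, 12);

    const uint32_t kmask1 = 0x3f3f3f3f, kmask2 = 0x0f0f0f0f, kmask3 = 0x03030303;

    uint32_t mins[2], scales[2];
    mins[0]   = sm[1] & kmask1;                                            // mins 0..3
    mins[1]   = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);  // mins 4..7
    scales[0] = sm[0] & kmask1;                                            // scales 0..3
    scales[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);         // scales 4..7

    // Each packed 32-bit word holds four 6-bit values, one per byte.
    for (int i = 0; i < 8; i++) {
        out_mins[i]   = (int16_t) ((mins[i / 4]   >> (8 * (i % 4))) & 0xff);
        out_scales[i] = (int8_t)  ((scales[i / 4] >> (8 * (i % 4))) & 0xff);
    }
}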
@@ -24,6 +24,29 @@
 
 #define UNUSED GGML_UNUSED
 
+static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
+                                             int16x8_t *     out_mins,
+                                             int8_t *        out_scales) {
+    constexpr uint32_t kmask1 = 0x3f3f3f3f;
+    constexpr uint32_t kmask2 = 0x0f0f0f0f;
+    constexpr uint32_t kmask3 = 0x03030303;
+    constexpr uint8_t  scales_size = 12;
+
+    uint32_t sm[3];
+    memcpy(sm, scales_in, scales_size);
+
+    const uint32_t   mins_0_3 = sm[1] & kmask1;
+    const uint32_t   mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
+    const uint32x2_t mins_u32 = { mins_0_3, mins_4_7 };
+
+    *out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32)));
+
+    uint32_t scales_u32[2];
+    scales_u32[0] = sm[0] & kmask1;
+    scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
+    memcpy(out_scales, scales_u32, 8);
+}
+
 void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
@@ -474,6 +497,295 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int    col_groups = ncols_interleaved / 4;  // 0123 and 4567
+    const uint8x16_t m4b        = vdupq_n_u8(0x0f);
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[col_groups];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+        for (int i = 0; i < col_groups; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q4_d_0        = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));      // d0 d1 d2 d3
+            float32x4_t q4_d_1        = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));  // d4 d5 d6 d7
+            float32x4_t q8_d          = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0123 = vmulq_f32(q4_d_0, q8_d);
+            float32x4_t sb_scale_4567 = vmulq_f32(q4_d_1, q8_d);
+            float32x4_t q4_dmin_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));      // dmin 0..3
+            float32x4_t q4_dmin_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));  // dmin 4..7
+            float32x4_t sb_min_0123   = vmulq_f32(q4_dmin_0, q8_d);
+            float32x4_t sb_min_4567   = vmulq_f32(q4_dmin_1, q8_d);
+
+            // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
+            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+            int32x4_t acc_lo[col_groups];
+            int32x4_t acc_hi[col_groups];
+
+            // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
+            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+            int16_t         bsums_arr[8];
+            vst1q_s16(bsums_arr, bsums);
+            for (int sb = 0; sb < QK_K / 64; sb++) {
+                for (int i = 0; i < col_groups; i++) {
+                    acc_lo[i] = vdupq_n_s32(0);
+                    acc_hi[i] = vdupq_n_s32(0);
+                }
+                // Need scales for the low and high nibbles
+                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                int16x8_t q4sb_mins[2];
+                int16x8_t q4sb_scales[2];
+                for (int i = 0; i < 2; i++) {
+                    int8_t    aux_q4sb[8];
+                    const int offset = sb * 24 + i * 12;
+                    decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                    q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+                }
+
+                int8x16_t q8_qs[64 / 16];
+                for (int i = 0; i < 64 / 16; i++) {
+                    q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
+                }
+
+                for (int c = 0; c < col_groups; c++) {
+                    uint8x16_t q4_cols[8];
+                    for (int i = 0; i < 8; i++) {
+                        q4_cols[i] = vld1q_u8(q4_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
+                    }
+
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[0], m4b)), q8_qs[0], 0);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[1], m4b)), q8_qs[0], 1);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[2], m4b)), q8_qs[0], 2);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[3], m4b)), q8_qs[0], 3);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[4], m4b)), q8_qs[1], 0);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[5], m4b)), q8_qs[1], 1);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[6], m4b)), q8_qs[1], 2);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[7], m4b)), q8_qs[1], 3);
+
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[0], 4)), q8_qs[2], 0);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[1], 4)), q8_qs[2], 1);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[2], 4)), q8_qs[2], 2);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[3], 4)), q8_qs[2], 3);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[4], 4)), q8_qs[3], 0);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[5], 4)), q8_qs[3], 1);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[6], 4)), q8_qs[3], 2);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[7], 4)), q8_qs[3], 3);
+                }
+
+                // Scales
+                // row c0123 blk0 and blk1
+                const int16x4_t   sc_0123_lo = vget_low_s16(q4sb_scales[0]);
+                const int16x4_t   sc_0123_hi = vget_low_s16(q4sb_scales[1]);
+                const float32x4_t sumf_0123  = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
+                                                                       vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
+                acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
+                // row c4567 blk0 and blk1
+                const int16x4_t   sc_4567_lo = vget_high_s16(q4sb_scales[0]);
+                const int16x4_t   sc_4567_hi = vget_high_s16(q4sb_scales[1]);
+                const float32x4_t sumf_4567  = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
+                                                                       vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
+                acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
+
+                // Bias Correction
+                const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+                const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+            }  // for sb
+
+            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
+            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
+        }  // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    }  // for x
+    return;
+#endif  // #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q4_K_8x8_q8_K(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int    col_pairs = ncols_interleaved / 2;
+    const uint8x16_t m4b       = vdupq_n_u8(0x0f);
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[ncols_interleaved / 4];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+        for (int i = 0; i < ncols_interleaved / 4; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q4_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));      // d0 d1 d2 d3
+            float32x4_t q4_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));  // d4 d5 d6 d7
+            float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0 = vmulq_f32(q4_d_0, q8_d);
+            float32x4_t sb_scale_1 = vmulq_f32(q4_d_1, q8_d);
+            float32x4_t q4_dmin_0  = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));      // dmin 0..3
+            float32x4_t q4_dmin_1  = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));  // dmin 4..7
+            float32x4_t sb_min_0   = vmulq_f32(q4_dmin_0, q8_d);
+            float32x4_t sb_min_1   = vmulq_f32(q4_dmin_1, q8_d);
+
+            // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
+            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+            // 2 sb each iteration
+            int32x4_t acc_lo[col_pairs];
+            int32x4_t acc_hi[col_pairs];
+
+            // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
+            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+            int16_t         bsums_arr[8];
+            vst1q_s16(bsums_arr, bsums);
+            for (int sb = 0; sb < QK_K / 64; sb++) {
+                for (int i = 0; i < col_pairs; i++) {
+                    acc_lo[i] = vdupq_n_s32(0);
+                    acc_hi[i] = vdupq_n_s32(0);
+                }
+                // Need scales for the low and high nibbles
+                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                int16x8_t q4sb_mins[2];  // int16 as it's needed for bias_acc later
+                int16x8_t q4sb_scales[2];
+                for (int i = 0; i < 2; i++) {
+                    int8_t    aux_q4sb[8];
+                    const int offset = sb * 24 + i * 12;
+                    decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                    q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+                }
+
+                const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;
+
+                // Load the 64 quants from q8_K duplicated to use vecdots with the interleaved columns,
+                // but still need the qs to use the low and hi bits from q4
+                const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
+                int8x16_t      q8_qs[8];
+                for (int i = 0; i < 8; i++) {
+                    q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
+                }
+
+                // Q4s columns iterated in pairs (01, 23, 45, 67)
+                for (int cp = 0; cp < col_pairs; cp++) {
+                    uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_base + 16 * cp);
+                    uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_base + 16 * cp + 64);
+                    uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_base + 16 * cp + 128);
+                    uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_base + 16 * cp + 192);
+
+                    acc_lo[cp] =
+                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), q8_qs[0]);  // 0 .. 7
+                    acc_lo[cp] =
+                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), q8_qs[1]);  // 8 ..15
+                    acc_lo[cp] =
+                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), q8_qs[2]);  // 16..23
+                    acc_lo[cp] =
+                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), q8_qs[3]);  // 24..31
+
+                    acc_hi[cp] =
+                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), q8_qs[4]);  // 32..39
+                    acc_hi[cp] =
+                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), q8_qs[5]);  // 40..47
+                    acc_hi[cp] =
+                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), q8_qs[6]);  // 48..55
+                    acc_hi[cp] =
+                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), q8_qs[7]);  // 56..63
+                }
+
+                // Iterates over a pair of column pairs (4 columns) to use a single 128 register
+                // p = 0 -> 0123, p = 2 -> 4567
+                for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
+                    int16x4_t   group_scales_lo = p == 0 ? vget_low_s16(q4sb_scales[0]) : vget_high_s16(q4sb_scales[0]);
+                    int16x4_t   group_scales_hi = p == 0 ? vget_low_s16(q4sb_scales[1]) : vget_high_s16(q4sb_scales[1]);
+                    float32x4_t sb_scale        = p == 0 ? sb_scale_0 : sb_scale_1;
+
+                    // 0123 or 4567
+                    float32x4_t sumf_0 =
+                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
+                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
+
+                    float32x4_t sumf_1 =
+                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
+                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
+                }
+
+                // Multiply Acc bsum + mins
+                // Each pair of subblocks share the same bsums
+                // Load scalar bsum → broadcast to a vector (vdup_n_s16(s)).
+                int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+                int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+                // cols 0-3 bias
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+
+                // cols 4-7 bias
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+            }  // for sb
+
+            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0);
+            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1);
+        }  // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    }  // for x
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -1889,3 +2201,412 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
     ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
+
+void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen = 4;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int    q8_k_blocklen = 4;
+    constexpr int    acc_size      = 2 * 4;  // 2 row pairs × 4 col pairs
+    const uint8x16_t m4b           = vdupq_n_u8(0x0f);
+
+    // 8 accumulators: 2 row pairs × 4 col pairs
+    float32x4_t acc_f32[acc_size];
+
+    for (int y = 0; y < nr / q8_k_blocklen; y++) {
+        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+            for (int i = 0; i < acc_size; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
+            }
+
+            for (int b = 0; b < nb; b++) {
+                // d4 0 1 2 3, 4 5 6 7
+                float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));
+                float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
+                // d8 0 1 2 3
+                float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
+                // mins
+                float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
+                float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
+
+                // Precomputation of scales and mins
+                float32x4_t sbd_scale_0123[q8_k_blocklen];
+                float32x4_t sbd_scale_4567[q8_k_blocklen];
+                float32x4_t sbd_min_0123[q8_k_blocklen];
+                float32x4_t sbd_min_4567[q8_k_blocklen];
+
+                sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
+                sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
+                sbd_min_0123[0]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
+                sbd_min_4567[0]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
+
+                sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
+                sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
+                sbd_min_0123[1]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
+                sbd_min_4567[1]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
+
+                sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
+                sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
+                sbd_min_0123[2]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
+                sbd_min_4567[2]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
+
+                sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
+                sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
+                sbd_min_0123[3]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
+                sbd_min_4567[3]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
+
+                // Precomputation of bsums, each vpaddq calcs all the bsums for each row
+                const int16x8_t bsums[q8_k_blocklen] = {
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+                };
+                int16_t bsums_arr[QK_K / 64][8];
+                for (int q8_row = 0; q8_row < 4; q8_row++) {
+                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+                }
+
+                // interleaved bias_acc: [0]->r0 c0123, [1]->r0 c4567, [2]->r1 c0123, [3]->r1 c4567, ..
+                int32x4_t bias_acc[acc_size];
+                for (int i = 0; i < acc_size; i++) {
+                    bias_acc[i] = vdupq_n_s32(0);
+                }
+
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+                    // Int accumulators for qs vecdot (4 row x 2 col quartets)
+                    int32x4_t acc_lo[acc_size];
+                    int32x4_t acc_hi[acc_size];
+                    for (int i = 0; i < acc_size; i++) {
+                        acc_lo[i] = vdupq_n_s32(0);
+                        acc_hi[i] = vdupq_n_s32(0);
+                    }
+                    // Need scales for the low and high nibbles
+                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                    int16x8_t q4sb_scales[2];
+                    int16x8_t q4sb_mins[2];
+                    for (int i = 0; i < 2; i++) {
+                        int8_t    aux_q4sb[8];
+                        const int offset = sb * 24 + i * 12;
+                        decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                        q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+                    }
+
+                    constexpr int reads_per_sb = 8;  // 8 * 16 bytes each => 32 qs * 4 rows
+                    for (int k = 0; k < reads_per_sb; k++) {
+                        const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
+                        const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
+
+                        // 0..3 & 32..35
+                        const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);
+                        const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
+
+                        const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));
+                        const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));
+
+                        acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0);  // 0..3 r0 c0123
+                        acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1);  // 0..3 r1 c0123
+                        acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2);  // 0..3 r2 c0123
+                        acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3);  // 0..3 r3 c0123
+
+                        acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0);  // 32..35 r0 c0123
+                        acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1);  // 32..35 r1 c0123
+                        acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2);  // 32..35 r2 c0123
+                        acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3);  // 32..35 r3 c0123
+
+                        const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));
+                        const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));
+
+                        acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0);  // 0..3 r0 c4567
+                        acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1);  // 0..3 r1 c4567
+                        acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2);  // 0..3 r2 c4567
+                        acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3);  // 0..3 r3 c4567
+
+                        acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0);  // 32..35 r0 c4567
+                        acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1);  // 32..35 r1 c4567
+                        acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2);  // 32..35 r2 c4567
+                        acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3);  // 32..35 r3 c4567
+                    }
+
+                    // Scale and bias application
+                    // acc is stored interleaved to match output layout
+                    const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
+                    const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
+                    const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
+                    const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
+                    for (int row = 0; row < q8_k_blocklen; row++) {
+                        // row c0123 blk0 and blk1
+                        const float32x4_t sumf_0123 =
+                            vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
+                                                    vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
+                        acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
+
+                        // row c4567 blk0 and blk1
+                        const float32x4_t sumf_4567 =
+                            vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
+                                                    vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
+                        acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
+
+                        // Bias correction
+                        const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
+                        const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
+
+                        // row c0123 blk0 and blk1
+                        bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                        bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+
+                        // row c4567 blk0 and blk1
+                        bias_acc[2 * row + 1] =
+                            vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+                        bias_acc[2 * row + 1] =
+                            vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+                    }
+                }  // for sb
+
+                for (int row = 0; row < q8_k_blocklen; row++) {
+                    acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
+                    acc_f32[2 * row + 1] =
+                        vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
+                }
+            }  // for b
+
+            for (int i = 0; i < q8_k_blocklen; i++) {
+                int row = y * q8_k_blocklen + i;
+                for (int j = 0; j < 2; j++) {
+                    int col    = x * ncols_interleaved + j * 4;
+                    int offset = row * bs + col;
+                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
+                }
+            }
+        }  // for x
+    }  // for y
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q4_K_8x8_q8_K(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    constexpr int    q8_k_blocklen = 4;
+    const uint8x16_t m4b           = vdupq_n_u8(0x0f);
+
+    // 8 accumulators: 2 row pairs × 4 col pairs
+    float32x4_t acc_f32[blocklen];
+
+    for (int y = 0; y < nr / q8_k_blocklen; y++) {
+        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+            for (int i = 0; i < blocklen; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
+            }
+
+            for (int b = 0; b < nb; b++) {
+                // bsums pairs belong to the same q8_k subblock
+                const int16x8_t bsums[4]{
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+                };
+                int16_t bsums_arr[4][8];
+                for (int q8_row = 0; q8_row < 4; q8_row++) {
+                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+                }
+
+                int32x4_t sb_acc[4];    // Aux accumulators to store subblock (partial) results
+                int32x4_t acc[8];       // rows 01 stored in [0][1][2][3], rows 23 stored in [4][5][6][7]
+                int32x4_t bias_acc[8];  // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
+                for (int i = 0; i < 8; i++) {
+                    acc[i]      = vdupq_n_s32(0);
+                    bias_acc[i] = vdupq_n_s32(0);
+                }
+
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+                    // Need scales for the low and high nibbles
+                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                    int8_t    q4sb_scales[2][8];
+                    int16x8_t q4sb_mins[2];  // int16 as it's needed for bias_acc later
+                    for (int i = 0; i < 2; i++) {
+                        const int offset = sb * 24 + i * 12;
+                        decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
+                    }
+
+                    // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
+                    const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
+
+                    int8x16_t q8_qs_01[8];
+                    int8x16_t q8_qs_23[8];
+
+                    // Load 32 bytes per row pair, 1 subblock each time
+                    for (int i = 0; i < 8; i++) {
+                        const int offset = i * 32;  // 16 for row 01, 16 for row 23
+                        q8_qs_01[i] = vld1q_s8(q8_base + offset);
+                        q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
+                    }
+
+                    const int8x16_t q8s[2][8] = {
+                        { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
+                          q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
+                        { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
+                          q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
+                    };
+
+                    // Q4s columns iterated in pairs (01, 23, 45, 67)
+                    for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
+                        for (int i = 0; i < 4; i++) {
+                            sb_acc[i] = vdupq_n_s32(0);
+                        }
+
+                        uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);    // 0 .. 7 & 32..39
+                        uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);   // 8 ..15 & 40..47
+                        uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);  // 16..23 & 48..55
+                        uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);  // 24..31 & 56..63
+                        const int8x16_t q4_nibbles[2][4] = {
+                            {
+                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
+                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
+                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
+                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
+                            },
+                            {
+                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
+                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
+                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
+                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
+                            }
+                        };
+
+                        // Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8
+                        // for each of the internal 32 qs subblock (blk)
+                        for (int rp = 0; rp < 2; rp++) {
+                            for (int blk = 0; blk < 2; blk++) {
+                                const int8x16_t * q8  = &q8s[rp][4 * blk];
+                                const int8x16_t * q4  = q4_nibbles[blk];
+                                int32x4_t         acc = sb_acc[2 * rp + blk];
+                                // mul add for each qs in the same subblock
+                                for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
+                                    acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
+                                }
+                                sb_acc[2 * rp + blk] = acc;
+                            }
+                        }
+
+                        // Scales[i] corresponds to column i
+                        const int scale_offset = cp * 2;
+                        for (int blk = 0; blk < 2; blk++) {
+                            const int32x4_t block_scale = {
+                                (int32_t) q4sb_scales[blk][scale_offset],
+                                (int32_t) q4sb_scales[blk][scale_offset],
+                                (int32_t) q4sb_scales[blk][scale_offset + 1],
+                                (int32_t) q4sb_scales[blk][scale_offset + 1],
+                            };
+                            acc[cp]     = vmlaq_s32(acc[cp], sb_acc[blk], block_scale);
+                            acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale);
+                        }
+                    }
+
+                    // Multiply Acc bsum + mins
+                    for (int q8_row = 0; q8_row < 4; q8_row++) {
+                        // Each pair of subblocks share the same bsums
+                        // Load scalar bsum → broadcast to a vector (vdup_n_s16(s)).
+                        int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
+                        int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
+
+                        bias_acc[2 * q8_row] =
+                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                        bias_acc[2 * q8_row] =
+                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+                        bias_acc[2 * q8_row + 1] =
+                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+                        bias_acc[2 * q8_row + 1] =
+                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+                    }
+                }  // for sb
+
+                // Reorder of i8mm output with bias and output layout
+                for (int i = 0; i < 8; i++) {
+                    int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
+                    acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
+                }
+                int32x4_t reorder_acc[8] = {
+                    vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
+                    vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
+                    vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
+                    vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
+                    vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
+                    vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
+                    vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
+                    vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
+                };
+
+                for (int i = 0; i < q8_k_blocklen; i++) {
+                    for (int j = 0; j < 2; j++) {
+                        float32x4_t       q8_d    = vdupq_n_f32(q8_ptr[b].d[i]);
+                        float32x4_t       q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
+                        const float32x4_t dmins   = vmulq_f32(q4_dmin, q8_d);
+
+                        float32x4_t       q4_d  = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
+                        const float32x4_t scale = vmulq_f32(q4_d, q8_d);
+
+                        acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
+                        acc_f32[2 * i + j] =
+                            vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
+                    }
+                }
+            }  // for b
+
+            // With the previous reorder, the tile is already in the correct memory layout.
+            for (int i = 0; i < q8_k_blocklen; i++) {
+                int row = y * q8_k_blocklen + i;
+                for (int j = 0; j < 2; j++) {
+                    int col    = x * ncols_interleaved + j * 4;
+                    int offset = row * bs + col;
+                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
+                }
+            }
+        }  // for x
+    }  // for y
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}