@fugood/llama.node 1.3.6 → 1.3.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,6 +24,29 @@
 
  #define UNUSED GGML_UNUSED
 
+ static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
+                                              int16x8_t *     out_mins,
+                                              int8_t *        out_scales) {
+     constexpr uint32_t kmask1 = 0x3f3f3f3f;
+     constexpr uint32_t kmask2 = 0x0f0f0f0f;
+     constexpr uint32_t kmask3 = 0x03030303;
+     constexpr uint8_t  scales_size = 12;
+
+     uint32_t sm[3];
+     memcpy(sm, scales_in, scales_size);
+
+     const uint32_t   mins_0_3 = sm[1] & kmask1;
+     const uint32_t   mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
+     const uint32x2_t mins_u32 = { mins_0_3, mins_4_7 };
+
+     *out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32)));
+
+     uint32_t scales_u32[2];
+     scales_u32[0] = sm[0] & kmask1;
+     scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
+     memcpy(out_scales, scales_u32, 8);
+ }
+
  void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
      assert(QK8_0 == 32);
      assert(k % QK8_0 == 0);
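
Editor's note: decode_q4_Kx8_scales_mins above unpacks the q4_K scale/min block, where 8 six-bit sub-block scales and 8 six-bit mins are packed into 12 bytes. A minimal scalar sketch of the same unpacking follows, assuming that standard layout; the function and variable names are illustrative and are not part of the package.

#include <stdint.h>

// Scalar equivalent of the three masked 32-bit extractions above (illustrative only).
static void decode_scales_mins_scalar(const uint8_t q[12], uint8_t sc[8], uint8_t mn[8]) {
    for (int j = 0; j < 4; ++j) {
        sc[j] = q[j]     & 63;                              // low 6 bits of bytes 0..3
        mn[j] = q[j + 4] & 63;                              // low 6 bits of bytes 4..7
    }
    for (int j = 4; j < 8; ++j) {
        sc[j] = (q[j + 4] & 0x0F) | ((q[j - 4] >> 6) << 4); // low nibble of bytes 8..11 + top bits of bytes 0..3
        mn[j] = (q[j + 4] >>   4) | ((q[j]     >> 6) << 4); // high nibble of bytes 8..11 + top bits of bytes 4..7
    }
}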
@@ -474,6 +497,162 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
  ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
  }
 
+ void ggml_gemv_q4_K_8x8_q8_K(int n,
+                              float * GGML_RESTRICT s,
+                              size_t bs,
+                              const void * GGML_RESTRICT vx,
+                              const void * GGML_RESTRICT vy,
+                              int nr,
+                              int nc) {
+     constexpr int qk = QK_K;
+     const int nb = n / qk;
+
+     constexpr int ncols_interleaved = 8;
+     constexpr int blocklen = 8;
+
+     assert(n % qk == 0);
+     assert(nr % 4 == 0);
+     assert(nc % ncols_interleaved == 0);
+
+     UNUSED(nb);
+     UNUSED(ncols_interleaved);
+     UNUSED(blocklen);
+
+ #if defined(__aarch64__) && defined(__ARM_NEON)
+     constexpr int col_pairs = ncols_interleaved / 2;
+     const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+     // 1x8 tile = 2 x 4
+     float32x4_t acc_f32[ncols_interleaved / 4];
+
+     const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+     for (int x = 0; x < nc / ncols_interleaved; x++) {
+         const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+         for (int i = 0; i < ncols_interleaved / 4; i++) {
+             acc_f32[i] = vdupq_n_f32(0);
+         }
+
+         for (int b = 0; b < nb; b++) {
+             float32x4_t q4_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));        // d0 d1 d2 d3
+             float32x4_t q4_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));    // d4 d5 d6 d7
+             float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);
+             float32x4_t sb_scale_0 = vmulq_f32(q4_d_0, q8_d);
+             float32x4_t sb_scale_1 = vmulq_f32(q4_d_1, q8_d);
+             float32x4_t q4_dmin_0  = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));     // dmin 0..3
+             float32x4_t q4_dmin_1  = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7
+             float32x4_t sb_min_0   = vmulq_f32(q4_dmin_0, q8_d);
+             float32x4_t sb_min_1   = vmulq_f32(q4_dmin_1, q8_d);
+
+             // interleaved bias_acc: [0] -> r0 cols 0123, [1] -> r0 cols 4567
+             int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+             // 2 sub-blocks each iteration
+             int32x4_t acc_lo[col_pairs];
+             int32x4_t acc_hi[col_pairs];
+
+             // Each bsum covers 16 elements; a pairwise add leaves the 8 bsums of the entire block
+             const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+             int16_t bsums_arr[8];
+             vst1q_s16(bsums_arr, bsums);
+             for (int sb = 0; sb < QK_K / 64; sb++) {
+                 for (int i = 0; i < col_pairs; i++) {
+                     acc_lo[i] = vdupq_n_s32(0);
+                     acc_hi[i] = vdupq_n_s32(0);
+                 }
+                 // Need scales for the low and high nibbles
+                 // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                 int16x8_t q4sb_mins[2]; // int16, as it's needed for bias_acc later
+                 int16x8_t q4sb_scales[2];
+                 for (int i = 0; i < 2; i++) {
+                     int8_t aux_q4sb[8];
+                     const int offset = sb * 24 + i * 12;
+                     decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                     q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+                 }
+
+                 const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;
+
+                 // Load the 64 quants from q8_K duplicated, so the vector dot products can run against
+                 // the interleaved columns; the same qs are reused for the low and high nibbles of q4
+                 const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
+                 int8x16_t q8_qs[8];
+                 for (int i = 0; i < 8; i++) {
+                     q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
+                 }
+
+                 // Q4 columns iterated in pairs (01, 23, 45, 67)
+                 for (int cp = 0; cp < col_pairs; cp++) {
+                     uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_base + 16 * cp);
+                     uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_base + 16 * cp + 64);
+                     uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_base + 16 * cp + 128);
+                     uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_base + 16 * cp + 192);
+
+                     acc_lo[cp] = ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), q8_qs[0]); // 0 .. 7
+                     acc_lo[cp] = ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), q8_qs[1]); // 8 ..15
+                     acc_lo[cp] = ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), q8_qs[2]); // 16..23
+                     acc_lo[cp] = ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), q8_qs[3]); // 24..31
+
+                     acc_hi[cp] = ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), q8_qs[4]); // 32..39
+                     acc_hi[cp] = ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), q8_qs[5]); // 40..47
+                     acc_hi[cp] = ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), q8_qs[6]); // 48..55
+                     acc_hi[cp] = ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), q8_qs[7]); // 56..63
+                 }
+
+                 // Iterate over a pair of column pairs (4 columns) to use a single 128-bit register
+                 // p = 0 -> cols 0123, p = 2 -> cols 4567
+                 for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
+                     int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q4sb_scales[0]) : vget_high_s16(q4sb_scales[0]);
+                     int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q4sb_scales[1]) : vget_high_s16(q4sb_scales[1]);
+                     float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
+
+                     // cols 0123 or 4567
+                     // TODO: Single superblock mul at the end of the superblock
+                     float32x4_t sumf_0 =
+                         vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
+                     acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
+
+                     float32x4_t sumf_1 =
+                         vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
+                     acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
+                 }
+
+                 // Accumulate bsum * mins for the bias term
+                 // Each pair of sub-blocks shares the same bsums:
+                 // load the scalar bsum and broadcast it to a vector (vdup_n_s16).
+                 int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+                 int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+                 // cols 0-3 bias
+                 bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                 bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+
+                 // cols 4-7 bias
+                 bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+                 bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+             } // for sb
+
+             acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0);
+             acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1);
+         } // for b
+
+         int base = x * ncols_interleaved;
+         vst1q_f32(s + base, acc_f32[0]);
+         vst1q_f32(s + base + 4, acc_f32[1]);
+     } // for x
+     return;
+ #endif // defined(__aarch64__) && defined(__ARM_NEON)
+     ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+ }
+
  void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
      const int qk = QK8_0;
      const int nb = n / qk;
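
Editor's note: both q4_K kernels in this diff accumulate the same per-superblock expression. A scalar reference of that expression follows as an illustrative sketch; the name and signature are mine, not the package's, and it assumes the usual q4_K/q8_K block fields (d, dmin, per-sub-block 6-bit scales and mins, q8 block sums).

#include <stdint.h>

// For one output element and one 256-quant superblock:
//   result += d4*d8 * sum_sb(scale[sb] * dot[sb]) - dmin4*d8 * sum_sb(min[sb] * bsum[sb])
// where dot[sb] is the int32 q4*q8 dot product of sub-block sb and bsum[sb] is the sum
// of the q8 quants in that sub-block (illustrative reference only).
static float q4K_superblock_contribution(float d4, float dmin4, float d8,
                                         const uint8_t scale[8], const uint8_t min[8],
                                         const int32_t dot[8], const int32_t bsum[8]) {
    int32_t acc = 0, bias = 0;
    for (int sb = 0; sb < 8; ++sb) {
        acc  += (int32_t) scale[sb] * dot[sb];
        bias += (int32_t) min[sb]   * bsum[sb];
    }
    return d4 * d8 * (float) acc - dmin4 * d8 * (float) bias;
}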
@@ -1889,3 +2068,212 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
  #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
  ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
  }
+
+ void ggml_gemm_q4_K_8x8_q8_K(int n,
+                              float * GGML_RESTRICT s,
+                              size_t bs,
+                              const void * GGML_RESTRICT vx,
+                              const void * GGML_RESTRICT vy,
+                              int nr,
+                              int nc) {
+     constexpr int qk = QK_K;
+     const int nb = n / qk;
+
+     constexpr int ncols_interleaved = 8;
+     constexpr int blocklen = 8;
+
+     assert(n % qk == 0);
+     assert(nr % 4 == 0);
+     assert(nc % ncols_interleaved == 0);
+
+     UNUSED(nb);
+     UNUSED(ncols_interleaved);
+     UNUSED(blocklen);
+
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+     constexpr int q8_k_blocklen = 4;
+     const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+     // 8 accumulators: 2 row pairs × 4 col pairs
+     float32x4_t acc_f32[blocklen];
+
+     for (int y = 0; y < nr / q8_k_blocklen; y++) {
+         const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+         for (int x = 0; x < nc / ncols_interleaved; x++) {
+             const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+             for (int i = 0; i < blocklen; i++) {
+                 acc_f32[i] = vdupq_n_f32(0);
+             }
+
+             for (int b = 0; b < nb; b++) {
+                 // bsum pairs belong to the same q8_K sub-block
+                 const int16x8_t bsums[4]{
+                     vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+                     vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+                     vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+                     vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+                 };
+                 int16_t bsums_arr[4][8];
+                 for (int q8_row = 0; q8_row < 4; q8_row++) {
+                     vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+                 }
+
+                 int32x4_t sb_acc[4];   // Aux accumulators to store sub-block (partial) results
+                 int32x4_t acc[8];      // rows 01 stored in [0][1][2][3], rows 23 stored in [4][5][6][7]
+                 int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
+                 for (int i = 0; i < 8; i++) {
+                     acc[i]      = vdupq_n_s32(0);
+                     bias_acc[i] = vdupq_n_s32(0);
+                 }
+
+                 for (int sb = 0; sb < QK_K / 64; sb++) {
+                     // Need scales for the low and high nibbles
+                     // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                     int8_t    q4sb_scales[2][8];
+                     int16x8_t q4sb_mins[2]; // int16, as it's needed for bias_acc later
+                     for (int i = 0; i < 2; i++) {
+                         const int offset = sb * 24 + i * 12;
+                         decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
+                     }
+
+                     // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
+                     const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
+
+                     int8x16_t q8_qs_01[8];
+                     int8x16_t q8_qs_23[8];
+
+                     // Load 32 bytes per row pair, one sub-block at a time
+                     for (int i = 0; i < 8; i++) {
+                         const int offset = i * 32; // 16 for row pair 01, 16 for row pair 23
+                         q8_qs_01[i] = vld1q_s8(q8_base + offset);
+                         q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
+                     }
+
+                     const int8x16_t q8s[2][8] = {
+                         { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
+                           q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
+                         { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
+                           q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
+                     };
+
+                     // Q4 columns iterated in pairs (01, 23, 45, 67)
+                     for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
+                         for (int i = 0; i < 4; i++) {
+                             sb_acc[i] = vdupq_n_s32(0);
+                         }
+
+                         uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);   // 0 .. 7 & 32..39
+                         uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);  // 8 ..15 & 40..47
+                         uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
+                         uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
+                         const int8x16_t q4_nibbles[2][4] = {
+                             {
+                                 vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
+                                 vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
+                                 vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
+                                 vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
+                             },
+                             {
+                                 vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
+                                 vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
+                                 vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
+                                 vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
+                             }
+                         };
+
+                         // Calculate the qs mul-adds for every row pair (rp) — q8 rows 01 and 23 —
+                         // and for each of the two internal 32-quant sub-blocks (blk)
+                         for (int rp = 0; rp < 2; rp++) {
+                             for (int blk = 0; blk < 2; blk++) {
+                                 const int8x16_t * q8 = &q8s[rp][4 * blk];
+                                 const int8x16_t * q4 = q4_nibbles[blk];
+                                 int32x4_t acc = sb_acc[2 * rp + blk];
+                                 // mul-add for each qs in the same sub-block
+                                 for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
+                                     acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
+                                 }
+                                 sb_acc[2 * rp + blk] = acc;
+                             }
+                         }
+
+                         // Scales[i] corresponds to column i
+                         const int scale_offset = cp * 2;
+                         for (int blk = 0; blk < 2; blk++) {
+                             const int32x4_t block_scale = {
+                                 (int32_t) q4sb_scales[blk][scale_offset],
+                                 (int32_t) q4sb_scales[blk][scale_offset],
+                                 (int32_t) q4sb_scales[blk][scale_offset + 1],
+                                 (int32_t) q4sb_scales[blk][scale_offset + 1],
+                             };
+                             acc[cp]     = vmlaq_s32(acc[cp],     sb_acc[blk],     block_scale);
+                             acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale);
+                         }
+                     }
+
+                     // Accumulate bsum * mins for the bias term
+                     for (int q8_row = 0; q8_row < 4; q8_row++) {
+                         // Each pair of sub-blocks shares the same bsums:
+                         // load the scalar bsum and broadcast it to a vector (vdup_n_s16).
+                         int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
+                         int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
+
+                         bias_acc[2 * q8_row]     = vmlal_s16(bias_acc[2 * q8_row],     bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                         bias_acc[2 * q8_row]     = vmlal_s16(bias_acc[2 * q8_row],     bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+                         bias_acc[2 * q8_row + 1] = vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+                         bias_acc[2 * q8_row + 1] = vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+                     }
+                 } // for sb
+
+                 // Reorder the i8mm output to match the bias and output layout
+                 for (int i = 0; i < 8; i++) {
+                     int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
+                     acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
+                 }
+                 int32x4_t reorder_acc[8] = {
+                     vcombine_s32(vget_low_s32(acc[0]),  vget_low_s32(acc[1])),
+                     vcombine_s32(vget_low_s32(acc[2]),  vget_low_s32(acc[3])),
+                     vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
+                     vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
+                     vcombine_s32(vget_low_s32(acc[4]),  vget_low_s32(acc[5])),
+                     vcombine_s32(vget_low_s32(acc[6]),  vget_low_s32(acc[7])),
+                     vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
+                     vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
+                 };
+
+                 for (int i = 0; i < q8_k_blocklen; i++) {
+                     for (int j = 0; j < 2; j++) {
+                         float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
+                         float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
+                         const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d);
+
+                         float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
+                         const float32x4_t scale = vmulq_f32(q4_d, q8_d);
+
+                         acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
+                         acc_f32[2 * i + j] = vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
+                     }
+                 }
+             } // for b
+
+             // With the previous reorder, the tile is already in the correct memory layout.
+             for (int i = 0; i < q8_k_blocklen; i++) {
+                 int row = y * q8_k_blocklen + i;
+                 for (int j = 0; j < 2; j++) {
+                     int col = x * ncols_interleaved + j * 4;
+                     int offset = row * bs + col;
+                     vst1q_f32(s + offset, acc_f32[2 * i + j]);
+                 }
+             }
+         } // for x
+     } // for y
+     return;
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+     ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+ }
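
Editor's note: the i8mm path above leans on the SMMLA instruction (vmmlaq_s32), which treats each 128-bit operand as a 2×8 int8 matrix and accumulates a 2×2 int32 tile; the vzip_s32/reorder_acc shuffle afterwards untangles that tile layout. A scalar model of one such accumulation follows (illustrative only; the names are mine, not the package's).

#include <stdint.h>

// out lane (2*i + j) += dot(A row i, B row j), i.e. { A0·B0, A0·B1, A1·B0, A1·B1 }.
static void smmla_2x2_ref(int32_t out[4], const int8_t A[2][8], const int8_t B[2][8]) {
    for (int i = 0; i < 2; ++i) {
        for (int j = 0; j < 2; ++j) {
            int32_t dot = 0;
            for (int k = 0; k < 8; ++k) {
                dot += (int32_t) A[i][k] * (int32_t) B[j][k];
            }
            out[2 * i + j] += dot;
        }
    }
}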
@@ -0,0 +1,35 @@
+ #include "ggml-backend-impl.h"
+
+ #if defined(__riscv) && __riscv_xlen == 64
+ #include <sys/auxv.h>
+
+ // https://github.com/torvalds/linux/blob/master/arch/riscv/include/uapi/asm/hwcap.h#L24
+ #ifndef COMPAT_HWCAP_ISA_V
+ #define COMPAT_HWCAP_ISA_V (1 << ('V' - 'A'))
+ #endif
+
+ struct riscv64_features {
+     bool has_rvv = false;
+
+     riscv64_features() {
+         uint32_t hwcap = getauxval(AT_HWCAP);
+
+         has_rvv = !!(hwcap & COMPAT_HWCAP_ISA_V);
+     }
+ };
+
+ static int ggml_backend_cpu_riscv64_score() {
+     int score = 1;
+     riscv64_features rf;
+
+ #ifdef GGML_USE_RVV
+     if (!rf.has_rvv) { return 0; }
+     score += 1 << 1;
+ #endif
+
+     return score;
+ }
+
+ GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_riscv64_score)
+
+ #endif // __riscv && __riscv_xlen == 64
@@ -51,10 +51,8 @@
  #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
  // repack.cpp
  #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
- #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
  #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
  #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
- #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
  #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
  #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
  #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
@@ -1927,6 +1927,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
              {
                  ggml_compute_forward_argsort(params, tensor);
              } break;
+         case GGML_OP_TOP_K:
+             {
+                 ggml_compute_forward_top_k(params, tensor);
+             } break;
          case GGML_OP_LEAKY_RELU:
              {
                  ggml_compute_forward_leaky_relu(params, tensor);
@@ -2311,6 +2315,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
          case GGML_OP_ARANGE:
          case GGML_OP_TIMESTEP_EMBEDDING:
          case GGML_OP_ARGSORT:
+         case GGML_OP_TOP_K:
          case GGML_OP_FLASH_ATTN_EXT:
          case GGML_OP_FLASH_ATTN_BACK:
          case GGML_OP_SSM_CONV:
@@ -2834,6 +2839,10 @@ struct ggml_cplan ggml_graph_plan(
                  cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
                  cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
              } break;
+         case GGML_OP_TOP_K:
+             {
+                 cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks;
+             } break;
          case GGML_OP_FLASH_ATTN_EXT:
              {
                  const int64_t ne10 = node->src[1]->ne[0]; // DK
@@ -7794,7 +7794,7 @@ void ggml_compute_forward_timestep_embedding(
  // ggml_compute_forward_argsort
 
  template<enum ggml_sort_order order>
- struct argsort_cmp {
+ struct cmp_argsort {
      const float * data;
      bool operator()(int32_t a, int32_t b) const {
          if constexpr (order == GGML_SORT_ORDER_ASC) {
@@ -7833,11 +7833,11 @@ static void ggml_compute_forward_argsort_f32(
 
      switch (order) {
          case GGML_SORT_ORDER_ASC:
-             std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
+             std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
              break;
 
          case GGML_SORT_ORDER_DESC:
-             std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
+             std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
              break;
 
          default:
@@ -7864,6 +7864,72 @@ void ggml_compute_forward_argsort(
      }
  }
 
+ // ggml_compute_forward_top_k
+
+ struct cmp_top_k {
+     const float * data;
+     bool operator()(int32_t a, int32_t b) const {
+         return data[a] > data[b];
+     }
+ };
+
+ static void ggml_compute_forward_top_k_f32(
+         const ggml_compute_params * params,
+         ggml_tensor * dst) {
+
+     const ggml_tensor * src0 = dst->src[0];
+
+     GGML_TENSOR_UNARY_OP_LOCALS
+
+     GGML_ASSERT(nb0 == sizeof(float));
+
+     const int ith = params->ith;
+     const int nth = params->nth;
+
+     const int64_t nr = ggml_nrows(src0);
+
+     const int top_k = ne0;
+
+     int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+     for (int64_t i = ith; i < nr; i += nth) {
+         const float * src_data = (float *)((char *) src0->data + i*nb01);
+
+         for (int64_t j = 0; j < ne00; j++) {
+             tmp[j] = j;
+         }
+
+         std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
+
+         int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+
+         std::copy(tmp, tmp + top_k, dst_data);
+
+         // deliberately perturb the result to emphasize that the order of the
+         // returned top-k indices is not guaranteed
+         if (top_k > 1) {
+             std::swap(dst_data[0], dst_data[1]);
+         }
+     }
+ }
+
+ void ggml_compute_forward_top_k(
+         const ggml_compute_params * params,
+         ggml_tensor * dst) {
+
+     const ggml_tensor * src0 = dst->src[0];
+
+     switch (src0->type) {
+         case GGML_TYPE_F32:
+             {
+                 ggml_compute_forward_top_k_f32(params, dst);
+             } break;
+         default:
+             {
+                 GGML_ABORT("fatal error");
+             }
+     }
+ }
+
  // ggml_compute_forward_flash_attn_ext
 
  static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
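
Editor's note: the TOP_K kernel above reuses the argsort trick of sorting indices by the values they reference, but with std::partial_sort so only the k best positions are ordered. A standalone sketch of that pattern follows; the names are illustrative, not package code.

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Return the indices of the k largest values of v (the order of the returned
// indices is not meaningful, mirroring the kernel above).
static std::vector<int32_t> top_k_indices(const std::vector<float> & v, int k) {
    std::vector<int32_t> idx(v.size());
    std::iota(idx.begin(), idx.end(), 0); // 0, 1, 2, ...
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](int32_t a, int32_t b) { return v[a] > v[b]; });
    idx.resize(k); // keep only the k best indices
    return idx;
}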
@@ -81,6 +81,7 @@ void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct
  void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
  void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
  void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+ void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst);
  void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
  void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
  void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -1731,12 +1731,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
          nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
      }
 
-     if (nth == 1 || nchunk0 < nth || disable_chunking) {
+     int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+     // Only increase nchunk0 to nth if it won't make chunks too small
+     if (nth == 1 || ((nchunk0 < nth || disable_chunking) && (nr0 + nth - 1) / nth >= min_chunk_size)) {
          nchunk0 = nth;
+         dr0 = (nr0 + nchunk0 - 1) / nchunk0;
      }
 
-     const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
-
      // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
      // This prevents creating too many tiny chunks that could overlap after alignment
      const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
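
Editor's note: the new guard above only promotes nchunk0 to nth when each of the resulting per-thread chunks would still hold at least min_chunk_size rows; otherwise dr0 keeps the chunk size derived from min_chunk_size. A tiny illustration of that condition (hypothetical helper, not package code):

#include <cstdint>

// True when splitting nr0 rows into nth chunks still leaves chunks of a useful size.
static bool worth_one_chunk_per_thread(int64_t nr0, int64_t nth, int64_t min_chunk_size) {
    return (nr0 + nth - 1) / nth >= min_chunk_size;
}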
@@ -1961,6 +1962,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                  return &q4_K_8x8_q8_K;
              }
          }
+         if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+             if (cur->ne[1] % 8 == 0) {
+                 return &q4_K_8x8_q8_K;
+             }
+         }
      } else if (cur->type == GGML_TYPE_Q2_K) {
          if (ggml_cpu_has_avx512()) {
              if (cur->ne[1] % 8 == 0) {