whisper.rn 0.5.4 → 0.5.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +5 -0
- package/android/src/main/jni.cpp +13 -0
- package/cpp/ggml-alloc.c +78 -26
- package/cpp/ggml-alloc.h +9 -0
- package/cpp/ggml-backend-impl.h +1 -1
- package/cpp/ggml-backend-reg.cpp +19 -3
- package/cpp/ggml-backend.cpp +72 -20
- package/cpp/ggml-backend.h +2 -1
- package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- package/cpp/ggml-cpu/arch/arm/repack.cpp +1004 -0
- package/cpp/ggml-cpu/arch/x86/repack.cpp +6 -6
- package/cpp/ggml-cpu/arch-fallback.h +50 -2
- package/cpp/ggml-cpu/ggml-cpu-impl.h +1 -1
- package/cpp/ggml-cpu/ggml-cpu.c +139 -58
- package/cpp/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/ggml-cpu/ops.cpp +170 -18
- package/cpp/ggml-cpu/ops.h +1 -0
- package/cpp/ggml-cpu/repack.cpp +531 -5
- package/cpp/ggml-cpu/repack.h +14 -0
- package/cpp/ggml-cpu/simd-mappings.h +16 -18
- package/cpp/ggml-cpu/vec.cpp +41 -1
- package/cpp/ggml-cpu/vec.h +241 -138
- package/cpp/ggml-cpu.h +1 -0
- package/cpp/ggml-impl.h +0 -4
- package/cpp/ggml-metal/ggml-metal-context.m +26 -16
- package/cpp/ggml-metal/ggml-metal-device.cpp +452 -371
- package/cpp/ggml-metal/ggml-metal-device.h +87 -65
- package/cpp/ggml-metal/ggml-metal-device.m +263 -104
- package/cpp/ggml-metal/ggml-metal-impl.h +58 -4
- package/cpp/ggml-metal/ggml-metal-ops.cpp +415 -98
- package/cpp/ggml-metal/ggml-metal-ops.h +4 -0
- package/cpp/ggml-metal/ggml-metal.cpp +6 -5
- package/cpp/ggml-metal/ggml-metal.metal +404 -34
- package/cpp/ggml.c +110 -31
- package/cpp/ggml.h +51 -12
- package/cpp/jsi/RNWhisperJSI.cpp +1 -0
- package/cpp/whisper.cpp +16 -3
- package/ios/CMakeLists.txt +21 -1
- package/ios/RNWhisperContext.mm +5 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/jest-mock.js +2 -0
- package/lib/commonjs/jest-mock.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/jest-mock.js +2 -0
- package/lib/module/jest-mock.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +1 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +1 -0
- package/src/jest-mock.ts +2 -0
- package/src/version.json +1 -1
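For reference, the file-level diff above can be reproduced locally with npm's built-in differ (assuming npm 7 or newer): npm diff --diff=whisper.rn@0.5.4 --diff=whisper.rn@0.5.5. The hunks that follow are from the largest change, the new ARM kernels in package/cpp/ggml-cpu/arch/arm/repack.cpp (+1004 -0).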
@@ -24,6 +24,31 @@
 
 #define UNUSED WSP_GGML_UNUSED
 
+#if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD))
+static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
+                                             int16x8_t * out_mins,
+                                             int8_t * out_scales) {
+    constexpr uint32_t kmask1 = 0x3f3f3f3f;
+    constexpr uint32_t kmask2 = 0x0f0f0f0f;
+    constexpr uint32_t kmask3 = 0x03030303;
+    constexpr uint8_t scales_size = 12;
+
+    uint32_t sm[3];
+    memcpy(sm, scales_in, scales_size);
+
+    const uint32_t mins_0_3 = sm[1] & kmask1;
+    const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
+    const uint32x2_t mins_u32 = { mins_0_3, mins_4_7 };
+
+    *out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32)));
+
+    uint32_t scales_u32[2];
+    scales_u32[0] = sm[0] & kmask1;
+    scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
+    memcpy(out_scales, scales_u32, 8);
+}
+#endif
+
 void wsp_ggml_wsp_quantize_mat_q8_0_4x4(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
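For context on what decode_q4_Kx8_scales_mins is unpacking: a Q4_K super-block stores eight 6-bit sub-block scales and eight 6-bit mins packed into 12 bytes, and the helper recovers all sixteen values with three 32-bit masks. A minimal scalar sketch of the same per-sub-block decode, mirroring the well-known get_scale_min_k4 helper from upstream ggml (the function name here is illustrative and not part of the diff):

    // Illustrative scalar equivalent -- not part of the package diff.
    // q points at the 12 packed bytes of a Q4_K super-block; j selects sub-block 0..7.
    static inline void scalar_scale_min_k4(int j, const uint8_t * q, uint8_t * sc, uint8_t * mn) {
        if (j < 4) {
            *sc = q[j]     & 63;                               // low 6 bits of bytes 0..3
            *mn = q[j + 4] & 63;                               // low 6 bits of bytes 4..7
        } else {
            *sc = (q[j + 4] & 0x0f) | ((q[j - 4] >> 6) << 4);  // low nibble of bytes 8..11 + top bits of bytes 0..3
            *mn = (q[j + 4] >>  4)  | ((q[j]     >> 6) << 4);  // high nibble of bytes 8..11 + top bits of bytes 4..7
        }
    }

The vector helper produces the same eight scales and eight mins in one pass by applying kmask1/kmask2/kmask3 to the three packed 32-bit words.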
@@ -474,6 +499,420 @@ void wsp_ggml_gemv_iq4_nl_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs
     wsp_ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void wsp_ggml_gemv_q4_K_8x4_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc) {
+    constexpr int qk = QK_K;
+    const int nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[col_groups];
+
+    const block_q8_K * WSP_GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_Kx8 * WSP_GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+        for (int i = 0; i < col_groups; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));     // d0 d1 d2 d3
+            float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7
+            float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0123 = vmulq_f32(q4_d_0, q8_d);
+            float32x4_t sb_scale_4567 = vmulq_f32(q4_d_1, q8_d);
+            float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));     // dmin 0..3
+            float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7
+            float32x4_t sb_min_0123 = vmulq_f32(q4_dmin_0, q8_d);
+            float32x4_t sb_min_4567 = vmulq_f32(q4_dmin_1, q8_d);
+
+            // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
+            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+            int32x4_t acc_lo[col_groups];
+            int32x4_t acc_hi[col_groups];
+
+            // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
+            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+            int16_t bsums_arr[8];
+            vst1q_s16(bsums_arr, bsums);
+            for (int sb = 0; sb < QK_K / 64; sb++) {
+                for (int i = 0; i < col_groups; i++) {
+                    acc_lo[i] = vdupq_n_s32(0);
+                    acc_hi[i] = vdupq_n_s32(0);
+                }
+                // Need scales for the low and high nibbles
+                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                int16x8_t q4sb_mins[2];
+                int16x8_t q4sb_scales[2];
+                for (int i = 0; i < 2; i++) {
+                    int8_t aux_q4sb[8];
+                    const int offset = sb * 24 + i * 12;
+                    decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                    q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+                }
+
+                int8x16_t q8_qs[64 / 16];
+                for (int i = 0; i < 64 / 16; i++) {
+                    q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
+                }
+
+                for (int c = 0; c < col_groups; c++) {
+                    uint8x16_t q4_cols[8];
+                    for (int i = 0; i < 8; i++) {
+                        q4_cols[i] = vld1q_u8(q4_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
+                    }
+
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[0], m4b)), q8_qs[0], 0);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[1], m4b)), q8_qs[0], 1);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[2], m4b)), q8_qs[0], 2);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[3], m4b)), q8_qs[0], 3);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[4], m4b)), q8_qs[1], 0);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[5], m4b)), q8_qs[1], 1);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[6], m4b)), q8_qs[1], 2);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[7], m4b)), q8_qs[1], 3);
+
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[0], 4)), q8_qs[2], 0);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[1], 4)), q8_qs[2], 1);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[2], 4)), q8_qs[2], 2);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[3], 4)), q8_qs[2], 3);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[4], 4)), q8_qs[3], 0);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[5], 4)), q8_qs[3], 1);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[6], 4)), q8_qs[3], 2);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[7], 4)), q8_qs[3], 3);
+                }
+
+                // Scales
+                // row c0123 blk0 and blk1
+                const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
+                const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
+                const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
+                                                                      vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
+                acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
+                // row c4567 blk0 and blk1
+                const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
+                const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
+                const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
+                                                                      vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
+                acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
+
+                // Bias Correction
+                const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+                const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+            } // for sb
+
+            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
+            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
+        } // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    } // for x
+    return;
+#endif // #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    wsp_ggml_gemv_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
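The shape of this kernel falls out of the Q4_K dequantization formula: each stored 4-bit value decodes to d*scale*q - dmin*min, so a dot product against a Q8_K row splits into integer dot products weighted by d*d8*scale, minus a correction of dmin*d8*min times the per-sub-block sums of the Q8 quants (the bsums the kernel pre-adds pairwise). A scalar reference of the per-column accumulation the vector code performs, written against unpacked quants purely for illustration (it reuses the scalar_scale_min_k4 sketch shown after the first hunk and is not part of the diff):

    // Illustrative per-super-block reference for one output column -- not part of the diff.
    // q4/q8 are the 256 already-unpacked 4-bit and 8-bit quants of one column/row pair;
    // d and dmin come from the Q4_K block, d8 from the Q8_K block.
    static float q4K_column_dot_reference(const uint8_t * q4, const int8_t * q8,
                                          const uint8_t * packed_scales, // 12 bytes
                                          float d, float dmin, float d8) {
        float acc = 0.0f;
        for (int sb = 0; sb < 8; sb++) {              // 8 sub-blocks of 32 values
            uint8_t sc, mn;
            scalar_scale_min_k4(sb, packed_scales, &sc, &mn);
            int32_t dot = 0, bsum = 0;
            for (int i = 0; i < 32; i++) {
                dot  += (int32_t) q4[sb * 32 + i] * q8[sb * 32 + i];
                bsum += q8[sb * 32 + i];              // Q8_K precomputes these sums as bsums
            }
            acc += d * d8 * sc * dot - dmin * d8 * mn * bsum;
        }
        return acc;
    }

In the NEON code the first term is what acc_lo/acc_hi and q4sb_scales feed into acc_f32, and the second term is the bias_acc path that vmlsq_f32 subtracts at the end of each block.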
|
|
634
|
+
|
|
635
|
+
void wsp_ggml_gemv_q4_K_8x8_q8_K(int n,
|
|
636
|
+
float * WSP_GGML_RESTRICT s,
|
|
637
|
+
size_t bs,
|
|
638
|
+
const void * WSP_GGML_RESTRICT vx,
|
|
639
|
+
const void * WSP_GGML_RESTRICT vy,
|
|
640
|
+
int nr,
|
|
641
|
+
int nc) {
|
|
642
|
+
constexpr int qk = QK_K;
|
|
643
|
+
const int nb = n / qk;
|
|
644
|
+
|
|
645
|
+
constexpr int ncols_interleaved = 8;
|
|
646
|
+
constexpr int blocklen = 8;
|
|
647
|
+
|
|
648
|
+
assert(n % qk == 0);
|
|
649
|
+
assert(nc % ncols_interleaved == 0);
|
|
650
|
+
|
|
651
|
+
UNUSED(nb);
|
|
652
|
+
UNUSED(ncols_interleaved);
|
|
653
|
+
UNUSED(blocklen);
|
|
654
|
+
|
|
655
|
+
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
656
|
+
constexpr int col_pairs = ncols_interleaved / 2;
|
|
657
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
|
658
|
+
|
|
659
|
+
// 1x8 tile = 2 x 4
|
|
660
|
+
float32x4_t acc_f32[ncols_interleaved / 4];
|
|
661
|
+
|
|
662
|
+
const block_q8_K * WSP_GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
|
|
663
|
+
|
|
664
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
665
|
+
const block_q4_Kx8 * WSP_GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
|
666
|
+
|
|
667
|
+
for (int i = 0; i < ncols_interleaved / 4; i++) {
|
|
668
|
+
acc_f32[i] = vdupq_n_f32(0);
|
|
669
|
+
}
|
|
670
|
+
|
|
671
|
+
for (int b = 0; b < nb; b++) {
|
|
672
|
+
float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // d0 d1 d2 d3
|
|
673
|
+
float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7
|
|
674
|
+
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
|
|
675
|
+
float32x4_t sb_scale_0 = vmulq_f32(q4_d_0, q8_d);
|
|
676
|
+
float32x4_t sb_scale_1 = vmulq_f32(q4_d_1, q8_d);
|
|
677
|
+
float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin 0..3
|
|
678
|
+
float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7
|
|
679
|
+
float32x4_t sb_min_0 = vmulq_f32(q4_dmin_0, q8_d);
|
|
680
|
+
float32x4_t sb_min_1 = vmulq_f32(q4_dmin_1, q8_d);
|
|
681
|
+
|
|
682
|
+
// interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
|
|
683
|
+
int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
|
|
684
|
+
// 2 sb each iteration
|
|
685
|
+
int32x4_t acc_lo[col_pairs];
|
|
686
|
+
int32x4_t acc_hi[col_pairs];
|
|
687
|
+
|
|
688
|
+
// Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
|
|
689
|
+
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
|
|
690
|
+
int16_t bsums_arr[8];
|
|
691
|
+
vst1q_s16(bsums_arr, bsums);
|
|
692
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
|
693
|
+
for (int i = 0; i < col_pairs; i++) {
|
|
694
|
+
acc_lo[i] = vdupq_n_s32(0);
|
|
695
|
+
acc_hi[i] = vdupq_n_s32(0);
|
|
696
|
+
}
|
|
697
|
+
// Need scales for the low and high nibbles
|
|
698
|
+
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
|
699
|
+
int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later
|
|
700
|
+
int16x8_t q4sb_scales[2];
|
|
701
|
+
for (int i = 0; i < 2; i++) {
|
|
702
|
+
int8_t aux_q4sb[8];
|
|
703
|
+
const int offset = sb * 24 + i * 12;
|
|
704
|
+
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
|
705
|
+
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
|
706
|
+
}
|
|
707
|
+
|
|
708
|
+
const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;
|
|
709
|
+
|
|
710
|
+
// Load the 64 quants from q8K duplicated to use vecdots with the interelaved columns
|
|
711
|
+
// but still need the qs to use the low and hi bits from q4
|
|
712
|
+
const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
|
|
713
|
+
int8x16_t q8_qs[8];
|
|
714
|
+
for (int i = 0; i < 8; i++) {
|
|
715
|
+
q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
|
|
716
|
+
}
|
|
717
|
+
|
|
718
|
+
// Q4s columns iterated in pairs (01, 23, 45, 67)
|
|
719
|
+
for (int cp = 0; cp < col_pairs; cp++) {
|
|
720
|
+
uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_base + 16 * cp);
|
|
721
|
+
uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_base + 16 * cp + 64);
|
|
722
|
+
uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_base + 16 * cp + 128);
|
|
723
|
+
uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_base + 16 * cp + 192);
|
|
724
|
+
|
|
725
|
+
acc_lo[cp] =
|
|
726
|
+
wsp_ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), q8_qs[0]); // 0 .. 7
|
|
727
|
+
acc_lo[cp] =
|
|
728
|
+
wsp_ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), q8_qs[1]); // 8 ..15
|
|
729
|
+
acc_lo[cp] =
|
|
730
|
+
wsp_ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), q8_qs[2]); // 16..23
|
|
731
|
+
acc_lo[cp] =
|
|
732
|
+
wsp_ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), q8_qs[3]); // 24..31
|
|
733
|
+
|
|
734
|
+
acc_hi[cp] =
|
|
735
|
+
wsp_ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), q8_qs[4]); // 32..39
|
|
736
|
+
acc_hi[cp] =
|
|
737
|
+
wsp_ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), q8_qs[5]); // 40..47
|
|
738
|
+
acc_hi[cp] =
|
|
739
|
+
wsp_ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), q8_qs[6]); // 48..55
|
|
740
|
+
acc_hi[cp] =
|
|
741
|
+
wsp_ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), q8_qs[7]); // 56..63
|
|
742
|
+
}
|
|
743
|
+
|
|
744
|
+
// Iterates over a pair of column pairs (4 columns) to use a single 128 register
|
|
745
|
+
// p = 0 -> 0123 p2 -> 4567
|
|
746
|
+
for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
|
|
747
|
+
int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q4sb_scales[0]) : vget_high_s16(q4sb_scales[0]);
|
|
748
|
+
int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q4sb_scales[1]) : vget_high_s16(q4sb_scales[1]);
|
|
749
|
+
float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
|
|
750
|
+
|
|
751
|
+
// 0123 or 4567
|
|
752
|
+
float32x4_t sumf_0 =
|
|
753
|
+
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
|
|
754
|
+
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
|
|
755
|
+
|
|
756
|
+
float32x4_t sumf_1 =
|
|
757
|
+
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
|
|
758
|
+
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
|
|
759
|
+
}
|
|
760
|
+
|
|
761
|
+
// Multiply Acc bsum + mins
|
|
762
|
+
// Each pair of subblocks share the same bsums
|
|
763
|
+
// Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
|
|
764
|
+
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
|
|
765
|
+
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
|
|
766
|
+
|
|
767
|
+
// cols 0-3 bias
|
|
768
|
+
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
|
|
769
|
+
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
|
|
770
|
+
|
|
771
|
+
// cols 4-7 bias
|
|
772
|
+
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
|
|
773
|
+
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
|
|
774
|
+
} // for sb
|
|
775
|
+
|
|
776
|
+
acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0);
|
|
777
|
+
acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1);
|
|
778
|
+
} // for b
|
|
779
|
+
|
|
780
|
+
int base = x * ncols_interleaved;
|
|
781
|
+
vst1q_f32(s + base, acc_f32[0]);
|
|
782
|
+
vst1q_f32(s + base + 4, acc_f32[1]);
|
|
783
|
+
} // for x
|
|
784
|
+
return;
|
|
785
|
+
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
786
|
+
wsp_ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
|
787
|
+
}
|
|
788
|
+
|
|
789
|
+
+void wsp_ggml_gemv_q8_0_4x4_q8_0(int n,
+                                 float * WSP_GGML_RESTRICT s,
+                                 size_t bs,
+                                 const void * WSP_GGML_RESTRICT vx,
+                                 const void * WSP_GGML_RESTRICT vy,
+                                 int nr,
+                                 int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
+
+    for (int c = 0; c < nc; c += ncols_interleaved) {
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float32x4_t acc = vdupq_n_f32(0);
+        for (int b = 0; b < nb; b++) {
+            int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
+            int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
+            float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
+
+            int8x16x2_t a = vld1q_s8_x2(a_ptr->qs);
+            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+            int32x4_t ret = vdupq_n_s32(0);
+
+            ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0);
+            ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1);
+            ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2);
+            ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3);
+
+            ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0);
+            ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1);
+            ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2);
+            ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3);
+
+            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+            a_ptr++;
+            b_ptr++;
+        }
+        vst1q_f32(s, acc);
+        s += ncols_interleaved;
+    }
+    return;
+
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    wsp_ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
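Both Q8_0 GEMV paths lean on the by-lane SDOT form used above: vdotq_laneq_s32(acc, b, a, lane) multiplies each group of four int8 values in b by the four int8 values held in the selected 32-bit lane of a and accumulates into the corresponding int32 lane of acc. A scalar model of that single intrinsic (illustrative, not part of the diff):

    // Illustrative scalar model of vdotq_laneq_s32(acc, b, a, lane) -- not part of the diff.
    static inline void sdot_lane_reference(int32_t acc[4], const int8_t b[16],
                                           const int8_t a[16], int lane) {
        for (int i = 0; i < 4; i++) {                 // four int32 lanes of acc
            int32_t sum = 0;
            for (int k = 0; k < 4; k++) {
                sum += (int32_t) b[4 * i + k] * (int32_t) a[4 * lane + k];
            }
            acc[i] += sum;
        }
    }

Because block_q8_0x4 keeps four columns interleaved, one such instruction advances all four output columns at once, which is why a single ret accumulator per group of four columns is enough.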
|
|
846
|
+
|
|
847
|
+
void wsp_ggml_gemv_q8_0_4x8_q8_0(int n,
|
|
848
|
+
float * WSP_GGML_RESTRICT s,
|
|
849
|
+
size_t bs,
|
|
850
|
+
const void * WSP_GGML_RESTRICT vx,
|
|
851
|
+
const void * WSP_GGML_RESTRICT vy,
|
|
852
|
+
int nr,
|
|
853
|
+
int nc) {
|
|
854
|
+
const int qk = QK8_0;
|
|
855
|
+
const int nb = n / qk;
|
|
856
|
+
const int ncols_interleaved = 4;
|
|
857
|
+
const int blocklen = 8;
|
|
858
|
+
|
|
859
|
+
assert(n % qk == 0);
|
|
860
|
+
assert(nc % ncols_interleaved == 0);
|
|
861
|
+
|
|
862
|
+
UNUSED(nb);
|
|
863
|
+
UNUSED(ncols_interleaved);
|
|
864
|
+
UNUSED(blocklen);
|
|
865
|
+
|
|
866
|
+
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
867
|
+
const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
|
|
868
|
+
|
|
869
|
+
for (int c = 0; c < nc; c += ncols_interleaved) {
|
|
870
|
+
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
|
871
|
+
float32x4_t acc = vdupq_n_f32(0);
|
|
872
|
+
|
|
873
|
+
for (int b = 0; b < nb; b++) {
|
|
874
|
+
int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
|
|
875
|
+
int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
|
|
876
|
+
float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
|
|
877
|
+
|
|
878
|
+
int8x8x4_t a_chunks = vld1_s8_x4(a_ptr->qs);
|
|
879
|
+
int8x16_t a0 = vcombine_s8(a_chunks.val[0], a_chunks.val[0]);
|
|
880
|
+
int8x16_t a1 = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);
|
|
881
|
+
int8x16_t a2 = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);
|
|
882
|
+
int8x16_t a3 = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);
|
|
883
|
+
float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
|
|
884
|
+
|
|
885
|
+
int32x4_t ret0 = vdupq_n_s32(0);
|
|
886
|
+
int32x4_t ret1 = vdupq_n_s32(0);
|
|
887
|
+
|
|
888
|
+
// 0..7
|
|
889
|
+
ret0 = vdotq_s32(ret0, b_low.val[0], a0);
|
|
890
|
+
ret1 = vdotq_s32(ret1, b_low.val[1], a0);
|
|
891
|
+
// 8..15
|
|
892
|
+
ret0 = vdotq_s32(ret0, b_low.val[2], a1);
|
|
893
|
+
ret1 = vdotq_s32(ret1, b_low.val[3], a1);
|
|
894
|
+
// 16..23
|
|
895
|
+
ret0 = vdotq_s32(ret0, b_high.val[0], a2);
|
|
896
|
+
ret1 = vdotq_s32(ret1, b_high.val[1], a2);
|
|
897
|
+
// 24..31
|
|
898
|
+
ret0 = vdotq_s32(ret0, b_high.val[2], a3);
|
|
899
|
+
ret1 = vdotq_s32(ret1, b_high.val[3], a3);
|
|
900
|
+
|
|
901
|
+
int32x4_t ret = vpaddq_s32(ret0, ret1);
|
|
902
|
+
|
|
903
|
+
acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
|
|
904
|
+
a_ptr++;
|
|
905
|
+
b_ptr++;
|
|
906
|
+
}
|
|
907
|
+
vst1q_f32(s, acc);
|
|
908
|
+
s += ncols_interleaved;
|
|
909
|
+
}
|
|
910
|
+
return;
|
|
911
|
+
|
|
912
|
+
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
913
|
+
wsp_ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
914
|
+
}
|
|
915
|
+
|
|
477
916
|
void wsp_ggml_gemm_q4_0_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc) {
|
|
478
917
|
const int qk = QK8_0;
|
|
479
918
|
const int nb = n / qk;
|
|
@@ -1889,3 +2328,568 @@ void wsp_ggml_gemm_iq4_nl_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
     wsp_ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
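All of the GEMM kernels in this hunk read the right-hand operand through repacked, column-interleaved blocks (block_q4_Kx8, block_q8_Kx4, block_q8_0x4); strides such as qs + sb * QK_K + 16 * cp + 64 are walking that interleaved storage so that one 16-byte load always covers the same positions of several columns. As a rough, hypothetical illustration of the idea only (the real struct definitions and repack routines live in repack.h/repack.cpp and are not reproduced in this diff excerpt), interleaving four Q8_0 columns with a 4-byte block length could look like:

    // Hypothetical illustration of 4-column interleaving -- struct and field names here are
    // assumptions for the sketch, not the definitions used by repack.h.
    struct q8_0_col_block  { uint16_t d;    int8_t qs[32];  };  // one column: fp16 scale + 32 quants
    struct q8_0x4_il_block { uint16_t d[4]; int8_t qs[128]; };  // four columns interleaved

    static void interleave_q8_0x4(const q8_0_col_block in[4], q8_0x4_il_block * out) {
        const int blocklen = 4;
        for (int c = 0; c < 4; c++) {
            out->d[c] = in[c].d;
        }
        for (int chunk = 0; chunk < 32 / blocklen; chunk++) {
            for (int c = 0; c < 4; c++) {
                for (int k = 0; k < blocklen; k++) {
                    // blocklen consecutive quants of column c sit next to each other, so a
                    // 16-byte load picks up the same positions from all four columns
                    out->qs[(chunk * 4 + c) * blocklen + k] = in[c].qs[chunk * blocklen + k];
                }
            }
        }
    }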
|
|
2331
|
+
|
|
2332
|
+
void wsp_ggml_gemm_q4_K_8x4_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc) {
|
|
2333
|
+
constexpr int qk = QK_K;
|
|
2334
|
+
const int nb = n / qk;
|
|
2335
|
+
|
|
2336
|
+
constexpr int ncols_interleaved = 8;
|
|
2337
|
+
constexpr int blocklen = 4;
|
|
2338
|
+
|
|
2339
|
+
assert(n % qk == 0);
|
|
2340
|
+
assert(nr % 4 == 0);
|
|
2341
|
+
assert(nc % ncols_interleaved == 0);
|
|
2342
|
+
|
|
2343
|
+
UNUSED(nb);
|
|
2344
|
+
UNUSED(ncols_interleaved);
|
|
2345
|
+
UNUSED(blocklen);
|
|
2346
|
+
|
|
2347
|
+
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
2348
|
+
constexpr int q8_k_blocklen = 4;
|
|
2349
|
+
constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs
|
|
2350
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
|
2351
|
+
|
|
2352
|
+
// 8 accumulators: 2 row pairs × 4 col pairs
|
|
2353
|
+
float32x4_t acc_f32[acc_size];
|
|
2354
|
+
|
|
2355
|
+
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
|
2356
|
+
const block_q8_Kx4 * WSP_GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
|
2357
|
+
|
|
2358
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
2359
|
+
const block_q4_Kx8 * WSP_GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
|
2360
|
+
|
|
2361
|
+
for (int i = 0; i < acc_size; i++) {
|
|
2362
|
+
acc_f32[i] = vdupq_n_f32(0);
|
|
2363
|
+
}
|
|
2364
|
+
|
|
2365
|
+
for (int b = 0; b < nb; b++) {
|
|
2366
|
+
// d4 0 1 2 3, 4 5 6 7
|
|
2367
|
+
float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));
|
|
2368
|
+
float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
|
|
2369
|
+
// d8 0 1 2 3
|
|
2370
|
+
float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
|
|
2371
|
+
// mins
|
|
2372
|
+
float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
|
|
2373
|
+
float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
|
|
2374
|
+
|
|
2375
|
+
// Precomputation of scales and mins
|
|
2376
|
+
float32x4_t sbd_scale_0123[q8_k_blocklen];
|
|
2377
|
+
float32x4_t sbd_scale_4567[q8_k_blocklen];
|
|
2378
|
+
float32x4_t sbd_min_0123[q8_k_blocklen];
|
|
2379
|
+
float32x4_t sbd_min_4567[q8_k_blocklen];
|
|
2380
|
+
|
|
2381
|
+
sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
|
|
2382
|
+
sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
|
|
2383
|
+
sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
|
|
2384
|
+
sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
|
|
2385
|
+
|
|
2386
|
+
sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
|
|
2387
|
+
sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
|
|
2388
|
+
sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
|
|
2389
|
+
sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
|
|
2390
|
+
|
|
2391
|
+
sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
|
|
2392
|
+
sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
|
|
2393
|
+
sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
|
|
2394
|
+
sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
|
|
2395
|
+
|
|
2396
|
+
sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
|
|
2397
|
+
sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
|
|
2398
|
+
sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
|
|
2399
|
+
sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
|
|
2400
|
+
|
|
2401
|
+
// Precomputation of bsums, each vpaddq calcs all the bsums for each row
|
|
2402
|
+
const int16x8_t bsums[q8_k_blocklen] = {
|
|
2403
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
|
2404
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
|
2405
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
|
2406
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
|
2407
|
+
};
|
|
2408
|
+
int16_t bsums_arr[QK_K / 64][8];
|
|
2409
|
+
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
|
2410
|
+
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
|
|
2411
|
+
}
|
|
2412
|
+
|
|
2413
|
+
// interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
|
|
2414
|
+
int32x4_t bias_acc[acc_size];
|
|
2415
|
+
for (int i = 0; i < acc_size; i++) {
|
|
2416
|
+
bias_acc[i] = vdupq_n_s32(0);
|
|
2417
|
+
}
|
|
2418
|
+
|
|
2419
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
|
2420
|
+
// Int accumulators for qs vecdot (4 row x 2 col quartets)
|
|
2421
|
+
int32x4_t acc_lo[acc_size];
|
|
2422
|
+
int32x4_t acc_hi[acc_size];
|
|
2423
|
+
for (int i = 0; i < acc_size; i++) {
|
|
2424
|
+
acc_lo[i] = vdupq_n_s32(0);
|
|
2425
|
+
acc_hi[i] = vdupq_n_s32(0);
|
|
2426
|
+
}
|
|
2427
|
+
// Need scales for the low and high nibbles
|
|
2428
|
+
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
|
2429
|
+
int16x8_t q4sb_scales[2];
|
|
2430
|
+
int16x8_t q4sb_mins[2];
|
|
2431
|
+
for (int i = 0; i < 2; i++) {
|
|
2432
|
+
int8_t aux_q4sb[8];
|
|
2433
|
+
const int offset = sb * 24 + i * 12;
|
|
2434
|
+
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
|
2435
|
+
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
|
2436
|
+
}
|
|
2437
|
+
|
|
2438
|
+
constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
|
|
2439
|
+
for (int k = 0; k < reads_per_sb; k++) {
|
|
2440
|
+
const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
|
|
2441
|
+
const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
|
|
2442
|
+
|
|
2443
|
+
// 0..3 & 32..35
|
|
2444
|
+
const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);
|
|
2445
|
+
const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
|
|
2446
|
+
|
|
2447
|
+
const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));
|
|
2448
|
+
const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));
|
|
2449
|
+
|
|
2450
|
+
acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
|
|
2451
|
+
acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
|
|
2452
|
+
acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
|
|
2453
|
+
acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
|
|
2454
|
+
|
|
2455
|
+
acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
|
|
2456
|
+
acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
|
|
2457
|
+
acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
|
|
2458
|
+
acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
|
|
2459
|
+
|
|
2460
|
+
const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));
|
|
2461
|
+
const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));
|
|
2462
|
+
|
|
2463
|
+
acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
|
|
2464
|
+
acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
|
|
2465
|
+
acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
|
|
2466
|
+
acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
|
|
2467
|
+
|
|
2468
|
+
acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
|
|
2469
|
+
acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
|
|
2470
|
+
acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
|
|
2471
|
+
acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
|
|
2472
|
+
}
|
|
2473
|
+
|
|
2474
|
+
// Scale and bias application
|
|
2475
|
+
// acc is stored interleaved to match output layout
|
|
2476
|
+
const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
|
|
2477
|
+
const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
|
|
2478
|
+
const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
|
|
2479
|
+
const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
|
|
2480
|
+
for (int row = 0; row < q8_k_blocklen; row++) {
|
|
2481
|
+
// Bias correction
|
|
2482
|
+
// row c0123 blk0 and blk1
|
|
2483
|
+
const float32x4_t sumf_0123 =
|
|
2484
|
+
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
|
|
2485
|
+
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
|
|
2486
|
+
acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
|
|
2487
|
+
|
|
2488
|
+
// row c4567 blk0 and blk1
|
|
2489
|
+
const float32x4_t sumf_4567 =
|
|
2490
|
+
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
|
|
2491
|
+
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
|
|
2492
|
+
acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
|
|
2493
|
+
|
|
2494
|
+
// Bias
|
|
2495
|
+
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
|
|
2496
|
+
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
|
|
2497
|
+
|
|
2498
|
+
// row c0123 blk0 and blk1
|
|
2499
|
+
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
|
|
2500
|
+
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
|
|
2501
|
+
|
|
2502
|
+
// row c4567 blk0 and blk1
|
|
2503
|
+
bias_acc[2 * row + 1] =
|
|
2504
|
+
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
|
|
2505
|
+
bias_acc[2 * row + 1] =
|
|
2506
|
+
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
|
|
2507
|
+
}
|
|
2508
|
+
} // for sb
|
|
2509
|
+
|
|
2510
|
+
for (int row = 0; row < q8_k_blocklen; row++) {
|
|
2511
|
+
acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
|
|
2512
|
+
acc_f32[2 * row + 1] =
|
|
2513
|
+
vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
|
|
2514
|
+
}
|
|
2515
|
+
} // for b
|
|
2516
|
+
|
|
2517
|
+
for (int i = 0; i < q8_k_blocklen; i++) {
|
|
2518
|
+
int row = y * q8_k_blocklen + i;
|
|
2519
|
+
for (int j = 0; j < 2; j++) {
|
|
2520
|
+
int col = x * ncols_interleaved + j * 4;
|
|
2521
|
+
int offset = row * bs + col;
|
|
2522
|
+
vst1q_f32(s + offset, acc_f32[2 * i + j]);
|
|
2523
|
+
}
|
|
2524
|
+
}
|
|
2525
|
+
} // for x
|
|
2526
|
+
} // for y
|
|
2527
|
+
return;
|
|
2528
|
+
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
2529
|
+
wsp_ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
|
2530
|
+
}
|
|
2531
|
+
|
|
2532
|
+
void wsp_ggml_gemm_q4_K_8x8_q8_K(int n,
|
|
2533
|
+
float * WSP_GGML_RESTRICT s,
|
|
2534
|
+
size_t bs,
|
|
2535
|
+
const void * WSP_GGML_RESTRICT vx,
|
|
2536
|
+
const void * WSP_GGML_RESTRICT vy,
|
|
2537
|
+
int nr,
|
|
2538
|
+
int nc) {
|
|
2539
|
+
constexpr int qk = QK_K;
|
|
2540
|
+
const int nb = n / qk;
|
|
2541
|
+
|
|
2542
|
+
constexpr int ncols_interleaved = 8;
|
|
2543
|
+
constexpr int blocklen = 8;
|
|
2544
|
+
|
|
2545
|
+
assert(n % qk == 0);
|
|
2546
|
+
assert(nr % 4 == 0);
|
|
2547
|
+
assert(nc % ncols_interleaved == 0);
|
|
2548
|
+
|
|
2549
|
+
UNUSED(nb);
|
|
2550
|
+
UNUSED(ncols_interleaved);
|
|
2551
|
+
UNUSED(blocklen);
|
|
2552
|
+
|
|
2553
|
+
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
|
2554
|
+
constexpr int q8_k_blocklen = 4;
|
|
2555
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
|
2556
|
+
|
|
2557
|
+
// 8 accumulators: 2 row pairs × 4 col pairs
|
|
2558
|
+
float32x4_t acc_f32[blocklen];
|
|
2559
|
+
|
|
2560
|
+
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
|
2561
|
+
const block_q8_Kx4 * WSP_GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
|
2562
|
+
|
|
2563
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
2564
|
+
const block_q4_Kx8 * WSP_GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
|
2565
|
+
|
|
2566
|
+
for (int i = 0; i < blocklen; i++) {
|
|
2567
|
+
acc_f32[i] = vdupq_n_f32(0);
|
|
2568
|
+
}
|
|
2569
|
+
|
|
2570
|
+
for (int b = 0; b < nb; b++) {
|
|
2571
|
+
// bsums pairs belongs to the same q8_k subblock
|
|
2572
|
+
const int16x8_t bsums[4]{
|
|
2573
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
|
2574
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
|
2575
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
|
2576
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
|
2577
|
+
};
|
|
2578
|
+
int16_t bsums_arr[4][8];
|
|
2579
|
+
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
|
2580
|
+
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
|
|
2581
|
+
}
|
|
2582
|
+
|
|
2583
|
+
int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results
|
|
2584
|
+
int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
|
|
2585
|
+
int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
|
|
2586
|
+
for (int i = 0; i < 8; i++) {
|
|
2587
|
+
acc[i] = vdupq_n_s32(0);
|
|
2588
|
+
bias_acc[i] = vdupq_n_s32(0);
|
|
2589
|
+
}
|
|
2590
|
+
|
|
2591
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
|
2592
|
+
// Need scales for the low and high nibbles
|
|
2593
|
+
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
|
2594
|
+
int8_t q4sb_scales[2][8];
|
|
2595
|
+
int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later
|
|
2596
|
+
for (int i = 0; i < 2; i++) {
|
|
2597
|
+
const int offset = sb * 24 + i * 12;
|
|
2598
|
+
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
|
|
2599
|
+
}
|
|
2600
|
+
|
|
2601
|
+
// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
|
|
2602
|
+
const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
|
|
2603
|
+
|
|
2604
|
+
int8x16_t q8_qs_01[8];
|
|
2605
|
+
int8x16_t q8_qs_23[8];
|
|
2606
|
+
|
|
2607
|
+
// Load 32-byte per row pair, 1 subblock each time
|
|
2608
|
+
for (int i = 0; i < 8; i++) {
|
|
2609
|
+
const int offset = i * 32; // 16 for row 01, 16 for row 23
|
|
2610
|
+
q8_qs_01[i] = vld1q_s8(q8_base + offset);
|
|
2611
|
+
q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
|
|
2612
|
+
}
|
|
2613
|
+
|
|
2614
|
+
const int8x16_t q8s[2][8] = {
|
|
2615
|
+
{ q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
|
|
2616
|
+
q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
|
|
2617
|
+
{ q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
|
|
2618
|
+
q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
|
|
2619
|
+
};
|
|
2620
|
+
|
|
2621
|
+
// Q4s columns iterated in pairs (01, 23, 45, 67)
|
|
2622
|
+
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
|
|
2623
|
+
for (int i = 0; i < 4; i++) {
|
|
2624
|
+
sb_acc[i] = vdupq_n_s32(0);
|
|
2625
|
+
}
|
|
2626
|
+
|
|
2627
|
+
uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
|
|
2628
|
+
uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
|
|
2629
|
+
uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
|
|
2630
|
+
uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
|
|
2631
|
+
const int8x16_t q4_nibbles[2][4] = {
|
|
2632
|
+
{
|
|
2633
|
+
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
|
|
2634
|
+
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
|
|
2635
|
+
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
|
|
2636
|
+
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
|
|
2637
|
+
},
|
|
2638
|
+
{
|
|
2639
|
+
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
|
|
2640
|
+
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
|
|
2641
|
+
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
|
|
2642
|
+
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
|
|
2643
|
+
}
|
|
2644
|
+
};
|
|
2645
|
+
|
|
2646
|
+
// Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8
|
|
2647
|
+
// for each of the internal 32 qs subblock (blk)
|
|
2648
|
+
for (int rp = 0; rp < 2; rp++) {
|
|
2649
|
+
for (int blk = 0; blk < 2; blk++) {
|
|
2650
|
+
const int8x16_t * q8 = &q8s[rp][4 * blk];
|
|
2651
|
+
const int8x16_t * q4 = q4_nibbles[blk];
|
|
2652
|
+
int32x4_t acc = sb_acc[2 * rp + blk];
|
|
2653
|
+
// mul add for each qs in the same subblock
|
|
2654
|
+
for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
|
|
2655
|
+
acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
|
|
2656
|
+
}
|
|
2657
|
+
sb_acc[2 * rp + blk] = acc;
|
|
2658
|
+
}
|
|
2659
|
+
}
|
|
2660
|
+
|
|
2661
|
+
// Scales[i] corresponds to column i
|
|
2662
|
+
const int scale_offset = cp * 2;
|
|
2663
|
+
for (int blk = 0; blk < 2; blk++) {
|
|
2664
|
+
const int32x4_t block_scale = {
|
|
2665
|
+
(int32_t) q4sb_scales[blk][scale_offset],
|
|
2666
|
+
(int32_t) q4sb_scales[blk][scale_offset],
|
|
2667
|
+
(int32_t) q4sb_scales[blk][scale_offset + 1],
|
|
2668
|
+
(int32_t) q4sb_scales[blk][scale_offset + 1],
|
|
2669
|
+
};
|
|
2670
|
+
acc[cp] = vmlaq_s32(acc[cp], sb_acc[blk], block_scale);
|
|
2671
|
+
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale);
|
|
2672
|
+
}
|
|
2673
|
+
}
|
|
2674
|
+
|
|
2675
|
+
// Multiply Acc bsum + mins
|
|
2676
|
+
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
|
2677
|
+
// Each pair of subblocks share the same bsums
|
|
2678
|
+
// Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
|
|
2679
|
+
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
|
|
2680
|
+
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
|
|
2681
|
+
|
|
2682
|
+
bias_acc[2 * q8_row] =
|
|
2683
|
+
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
|
|
2684
|
+
bias_acc[2 * q8_row] =
|
|
2685
|
+
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
|
|
2686
|
+
bias_acc[2 * q8_row + 1] =
|
|
2687
|
+
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
|
|
2688
|
+
bias_acc[2 * q8_row + 1] =
|
|
2689
|
+
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
|
|
2690
|
+
}
|
|
2691
|
+
} // for sb
|
|
2692
|
+
|
|
2693
|
+
// Reorder of i8mm output with bias and output layout
|
|
2694
|
+
for (int i = 0; i < 8; i++) {
|
|
2695
|
+
int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
|
|
2696
|
+
acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
|
|
2697
|
+
}
|
|
2698
|
+
int32x4_t reorder_acc[8] = {
|
|
2699
|
+
vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
|
|
2700
|
+
vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
|
|
2701
|
+
vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
|
|
2702
|
+
vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
|
|
2703
|
+
vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
|
|
2704
|
+
vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
|
|
2705
|
+
vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
|
|
2706
|
+
vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
|
|
2707
|
+
};
|
|
2708
|
+
|
|
2709
|
+
for (int i = 0; i < q8_k_blocklen; i++) {
|
|
2710
|
+
for (int j = 0; j < 2; j++) {
|
|
2711
|
+
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
|
|
2712
|
+
float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
|
|
2713
|
+
const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d);
|
|
2714
|
+
|
|
2715
|
+
float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
|
|
2716
|
+
const float32x4_t scale = vmulq_f32(q4_d, q8_d);
|
|
2717
|
+
|
|
2718
|
+
acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
|
|
2719
|
+
acc_f32[2 * i + j] =
|
|
2720
|
+
vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
|
|
2721
|
+
}
|
|
2722
|
+
}
|
|
2723
|
+
} // for b
|
|
2724
|
+
|
|
2725
|
+
// With the previous reorder, the tile is already in the correct memory layout.
|
|
2726
|
+
for (int i = 0; i < q8_k_blocklen; i++) {
|
|
2727
|
+
int row = y * q8_k_blocklen + i;
|
|
2728
|
+
for (int j = 0; j < 2; j++) {
|
|
2729
|
+
int col = x * ncols_interleaved + j * 4;
|
|
2730
|
+
int offset = row * bs + col;
|
|
2731
|
+
vst1q_f32(s + offset, acc_f32[2 * i + j]);
|
|
2732
|
+
}
|
|
2733
|
+
}
|
|
2734
|
+
} // for x
|
|
2735
|
+
} // for y
|
|
2736
|
+
return;
|
|
2737
|
+
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
|
2738
|
+
wsp_ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
|
2739
|
+
}
|
|
2740
|
+
|
|
2741
|
+
|
|
2742
|
+
void wsp_ggml_gemm_q8_0_4x4_q8_0(int n,
|
|
2743
|
+
float * WSP_GGML_RESTRICT s,
|
|
2744
|
+
size_t bs,
|
|
2745
|
+
const void * WSP_GGML_RESTRICT vx,
|
|
2746
|
+
const void * WSP_GGML_RESTRICT vy,
|
|
2747
|
+
int nr,
|
|
2748
|
+
int nc) {
|
|
2749
|
+
const int qk = QK8_0;
|
|
2750
|
+
const int nb = n / qk;
|
|
2751
|
+
const int ncols_interleaved = 4;
|
|
2752
|
+
const int blocklen = 4;
|
|
2753
|
+
|
|
2754
|
+
assert(n % qk == 0);
|
|
2755
|
+
assert(nr % 4 == 0);
|
|
2756
|
+
assert(nc % ncols_interleaved == 0);
|
|
2757
|
+
|
|
2758
|
+
UNUSED(nb);
|
|
2759
|
+
UNUSED(ncols_interleaved);
|
|
2760
|
+
UNUSED(blocklen);
|
|
2761
|
+
|
|
2762
|
+
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
2763
|
+
for (int y = 0; y < nr / 4; y++) {
|
|
2764
|
+
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
|
2765
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
2766
|
+
const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
|
|
2767
|
+
|
|
2768
|
+
float32x4_t sumf[4];
|
|
2769
|
+
for (int m = 0; m < 4; m++) {
|
|
2770
|
+
sumf[m] = vdupq_n_f32(0);
|
|
2771
|
+
}
|
|
2772
|
+
|
|
2773
|
+
for (int l = 0; l < nb; l++) {
|
|
2774
|
+
float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *) a_ptr[l].d));
|
|
2775
|
+
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *) b_ptr[l].d));
|
|
2776
|
+
|
|
2777
|
+
int32x4_t sumi_0 = vdupq_n_s32(0);
|
|
2778
|
+
int32x4_t sumi_1 = vdupq_n_s32(0);
|
|
2779
|
+
int32x4_t sumi_2 = vdupq_n_s32(0);
|
|
2780
|
+
int32x4_t sumi_3 = vdupq_n_s32(0);
|
|
2781
|
+
|
|
2782
|
+
for (int k_group = 0; k_group < 8; k_group += 4) {
|
|
2783
|
+
int8x16x4_t a = vld1q_s8_x4(a_ptr[l].qs + 16 * k_group);
|
|
2784
|
+
int8x16x4_t b = vld1q_s8_x4(b_ptr[l].qs + 16 * k_group);
|
|
2785
|
+
|
|
2786
|
+
for (int k = 0; k < 4; k++) {
|
|
2787
|
+
sumi_0 = vdotq_laneq_s32(sumi_0, b.val[k], a.val[k], 0);
|
|
2788
|
+
sumi_1 = vdotq_laneq_s32(sumi_1, b.val[k], a.val[k], 1);
|
|
2789
|
+
sumi_2 = vdotq_laneq_s32(sumi_2, b.val[k], a.val[k], 2);
|
|
2790
|
+
sumi_3 = vdotq_laneq_s32(sumi_3, b.val[k], a.val[k], 3);
|
|
2791
|
+
}
|
|
2792
|
+
}
|
|
2793
|
+
|
|
2794
|
+
sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
|
|
2795
|
+
sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
|
|
2796
|
+
sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
|
|
2797
|
+
sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
|
|
2798
|
+
}
|
|
2799
|
+
|
|
2800
|
+
for (int m = 0; m < 4; m++) {
|
|
2801
|
+
vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
|
|
2802
|
+
}
|
|
2803
|
+
}
|
|
2804
|
+
}
|
|
2805
|
+
return;
|
|
2806
|
+
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
2807
|
+
wsp_ggml_gemm_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
2808
|
+
}
|
|
2809
|
+
|
|
2810
|
+
+void wsp_ggml_gemm_q8_0_4x8_q8_0(int n,
+                                 float * WSP_GGML_RESTRICT s,
+                                 size_t bs,
+                                 const void * WSP_GGML_RESTRICT vx,
+                                 const void * WSP_GGML_RESTRICT vy,
+                                 int nr,
+                                 int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
+
+    for (int y = 0; y < nr; y += 4) {
+        const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb;
+
+        for (int x = 0; x < nc; x += ncols_interleaved) {
+            const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb;
+            const block_q8_0x4 * a_ptr = a_ptr_base;
+
+            float32x4_t acc_f32[4];
+            for (int i = 0; i < 4; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
+            }
+
+            for (int b = 0; b < nb; b++) {
+                int32x4_t acc[4];
+                for (int i = 0; i < 4; i++) {
+                    acc[i] = vdupq_n_s32(0);
+                }
+
+                // Process 4 chunks of 8 positions each
+                for (int chunk = 0; chunk < 4; chunk++) {
+                    int8x16_t a01 = vld1q_s8(a_ptr->qs + chunk * 32);
+                    int8x16_t a23 = vld1q_s8(a_ptr->qs + chunk * 32 + 16);
+                    int8x16_t b01 = vld1q_s8(b_ptr->qs + chunk * 32);
+                    int8x16_t b23 = vld1q_s8(b_ptr->qs + chunk * 32 + 16);
+
+                    acc[0] = vmmlaq_s32(acc[0], a01, b01);
+                    acc[1] = vmmlaq_s32(acc[1], a01, b23);
+                    acc[2] = vmmlaq_s32(acc[2], a23, b01);
+                    acc[3] = vmmlaq_s32(acc[3], a23, b23);
+                }
+
+                // Reorder outputs from 2×2 tiles to row-major
+                // acc[0] = [r0c0, r0c1, r1c0, r1c1]
+                // acc[1] = [r0c2, r0c3, r1c2, r1c3]
+                // acc[2] = [r2c0, r2c1, r3c0, r3c1]
+                // acc[3] = [r2c2, r2c3, r3c2, r3c3]
+                int32x4_t row0 = vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1]));
+                int32x4_t row1 = vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1]));
+                int32x4_t row2 = vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3]));
+                int32x4_t row3 = vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3]));
+
+                // Scales
+                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const __fp16 *) a_ptr->d));
+                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const __fp16 *) b_ptr->d));
+
+                acc_f32[0] = vfmaq_f32(acc_f32[0], vcvtq_f32_s32(row0), vmulq_laneq_f32(b_d, a_d, 0));
+                acc_f32[1] = vfmaq_f32(acc_f32[1], vcvtq_f32_s32(row1), vmulq_laneq_f32(b_d, a_d, 1));
+                acc_f32[2] = vfmaq_f32(acc_f32[2], vcvtq_f32_s32(row2), vmulq_laneq_f32(b_d, a_d, 2));
+                acc_f32[3] = vfmaq_f32(acc_f32[3], vcvtq_f32_s32(row3), vmulq_laneq_f32(b_d, a_d, 3));
+
+                a_ptr++;
+                b_ptr++;
+            }
+
+            for (int row = 0; row < 4; row++) {
+                vst1q_f32(s + (y + row) * bs + x, acc_f32[row]);
+            }
+        }
+    }
+    return;
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    wsp_ggml_gemm_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
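The MATMUL_INT8 paths in this hunk (the function above and wsp_ggml_gemm_q4_K_8x8_q8_K earlier) are built on vmmlaq_s32, which treats each int8x16_t operand as a 2x8 matrix and accumulates their 2x2 int32 product, row-major, into the destination; that is where the [r0c0, r0c1, r1c0, r1c1] tile layout in the comments comes from, and why the kernels finish with vget_low/vget_high/vcombine shuffles to turn the 2x2 tiles into row-major output rows. A scalar model of the single intrinsic (illustrative, not part of the diff):

    // Illustrative scalar model of vmmlaq_s32(acc, a, b) -- not part of the diff.
    // a holds two rows of 8 int8 values; b holds two columns of the other operand,
    // also stored as two runs of 8 int8 values. acc is a row-major 2x2 int32 tile.
    static inline void smmla_reference(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
        for (int r = 0; r < 2; r++) {
            for (int c = 0; c < 2; c++) {
                int32_t sum = 0;
                for (int k = 0; k < 8; k++) {
                    sum += (int32_t) a[r * 8 + k] * (int32_t) b[c * 8 + k];
                }
                acc[r * 2 + c] += sum;
            }
        }
    }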