cui-llama.rn 1.4.6 → 1.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +9 -2
- package/android/src/main/jni.cpp +52 -34
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/cpp/binary-ops.cpp +158 -0
- package/cpp/binary-ops.h +16 -0
- package/cpp/chat.cpp +1769 -1779
- package/cpp/chat.h +9 -1
- package/cpp/common.cpp +20 -522
- package/cpp/common.h +13 -36
- package/cpp/cpu-common.h +72 -0
- package/cpp/ggml-common.h +12 -6
- package/cpp/ggml-cpu-aarch64.cpp +1557 -80
- package/cpp/ggml-cpu-impl.h +2 -21
- package/cpp/ggml-cpu-quants.c +904 -405
- package/cpp/ggml-cpu.c +909 -13237
- package/cpp/ggml-impl.h +50 -23
- package/cpp/ggml-metal-impl.h +77 -3
- package/cpp/ggml-metal.m +794 -580
- package/cpp/ggml.c +92 -3
- package/cpp/ggml.h +29 -5
- package/cpp/gguf.cpp +1 -0
- package/cpp/llama-adapter.cpp +55 -20
- package/cpp/llama-adapter.h +11 -9
- package/cpp/llama-arch.cpp +217 -16
- package/cpp/llama-arch.h +25 -0
- package/cpp/llama-batch.h +2 -2
- package/cpp/llama-chat.cpp +54 -2
- package/cpp/llama-chat.h +3 -0
- package/cpp/llama-context.cpp +2294 -1238
- package/cpp/llama-context.h +214 -77
- package/cpp/llama-cparams.h +1 -0
- package/cpp/llama-graph.cpp +1695 -0
- package/cpp/llama-graph.h +592 -0
- package/cpp/llama-hparams.cpp +8 -0
- package/cpp/llama-hparams.h +17 -0
- package/cpp/llama-io.cpp +15 -0
- package/cpp/llama-io.h +35 -0
- package/cpp/llama-kv-cache.cpp +965 -303
- package/cpp/llama-kv-cache.h +145 -151
- package/cpp/llama-memory.cpp +1 -0
- package/cpp/llama-memory.h +21 -0
- package/cpp/llama-mmap.cpp +1 -1
- package/cpp/llama-model-loader.cpp +10 -5
- package/cpp/llama-model-loader.h +5 -3
- package/cpp/llama-model.cpp +9194 -201
- package/cpp/llama-model.h +40 -1
- package/cpp/llama-sampling.cpp +5 -0
- package/cpp/llama-vocab.cpp +36 -5
- package/cpp/llama.cpp +51 -9984
- package/cpp/llama.h +102 -22
- package/cpp/log.cpp +34 -0
- package/cpp/minja/chat-template.hpp +15 -7
- package/cpp/minja/minja.hpp +120 -94
- package/cpp/ops.cpp +8723 -0
- package/cpp/ops.h +128 -0
- package/cpp/rn-llama.cpp +44 -53
- package/cpp/rn-llama.h +2 -12
- package/cpp/sampling.cpp +3 -0
- package/cpp/sgemm.cpp +533 -88
- package/cpp/simd-mappings.h +888 -0
- package/cpp/speculative.cpp +4 -4
- package/cpp/unary-ops.cpp +186 -0
- package/cpp/unary-ops.h +28 -0
- package/cpp/vec.cpp +258 -0
- package/cpp/vec.h +802 -0
- package/ios/CMakeLists.txt +5 -2
- package/ios/RNLlama.mm +2 -2
- package/ios/RNLlamaContext.mm +40 -24
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +6 -4
- package/src/index.ts +3 -1
- package/cpp/chat-template.hpp +0 -529
- package/cpp/minja.hpp +0 -2915
package/cpp/ggml-cpu-quants.c
CHANGED
@@ -891,15 +891,15 @@ void quantize_row_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT
     }
 #elif defined(__riscv_v_intrinsic)

-    size_t vl =
+    size_t vl = QK8_0;

     for (int i = 0; i < nb; i++) {
         // load elements
-        ...
+        vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_0, vl);

-        ...
+        vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
         vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
-        vfloat32m1_t vmax =
+        vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
         float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);

         const float d = amax / ((1 << 7) - 1);
@@ -907,14 +907,14 @@ void quantize_row_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT

         y[i].d = LM_GGML_FP32_TO_FP16(d);

-        ...
+        vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);

         // convert to integer
-        ...
+        vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
+        vint8m2_t  vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);

         // store result
-        ...
+        __riscv_vse8_v_i8m2(y[i].qs , vs, vl);
     }

 #elif defined(__POWER9_VECTOR__)
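For readers who do not follow the RVV intrinsics, the block above is the usual absmax q8_0 quantization. A minimal scalar sketch of what it computes for one block is shown below; the helper name is illustrative and the scale is kept as a plain float here, whereas the real code stores it as FP16 via LM_GGML_FP32_TO_FP16 and assumes QK8_0 == 32.

#include <math.h>
#include <stdint.h>

// Scalar reference for one q8_0 block: 32 floats -> one scale + 32 int8 quants.
static void quantize_block_q8_0_ref(const float *x, float *d_out, int8_t *qs) {
    float amax = 0.0f;                      // max |x| over the block
    for (int j = 0; j < 32; ++j) {
        const float v = fabsf(x[j]);
        if (v > amax) amax = v;
    }
    const float d  = amax / ((1 << 7) - 1); // amax maps to 127
    const float id = d ? 1.0f / d : 0.0f;   // inverse scale, guard div-by-zero
    *d_out = d;
    for (int j = 0; j < 32; ++j) {
        qs[j] = (int8_t)roundf(x[j] * id);  // scale and round to int8
    }
}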
@@ -1229,15 +1229,15 @@ void quantize_row_q8_1(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT
     }
 #elif defined(__riscv_v_intrinsic)

-    size_t vl =
+    size_t vl = QK8_1;

     for (int i = 0; i < nb; i++) {
         // load elements
-        ...
+        vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_1, vl);

-        ...
+        vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
         vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
-        vfloat32m1_t vmax =
+        vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
         float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);

         const float d = amax / ((1 << 7) - 1);
@@ -1245,18 +1245,18 @@ void quantize_row_q8_1(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT

         y[i].d = LM_GGML_FP32_TO_FP16(d);

-        ...
+        vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);

         // convert to integer
-        ...
+        vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
+        vint8m2_t  vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);

         // store result
-        ...
+        __riscv_vse8_v_i8m2(y[i].qs , vs, vl);

         // compute sum for y[i].s
         vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
-        vint16m1_t vwrs =
+        vint16m1_t vwrs = __riscv_vwredsum_vs_i8m2_i16m1(vs, tmp2, vl);

         // set y[i].s
         int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
@@ -2391,33 +2391,31 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, con

     sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
 #elif defined(__riscv_v_intrinsic)
-    size_t vl =
+    size_t vl = qk / 2;

     for (; ib < nb; ++ib) {
         // load elements
-        ...
+        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);

-        ...
+        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
+        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);

         // mask and store lower part of x, and then upper part
-        ...
+        vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
+        vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);

-        ...
+        vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
+        vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);

         // subtract offset
-        ...
+        vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);
+        vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);

-        ...
+        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
+        vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);

         vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-        ...
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);

         int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);

@@ -2783,29 +2781,27 @@ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * LM_GGML_RESTRICT s, size_t bs, con

     sumf = hsum_float_8(acc) + summs;
 #elif defined(__riscv_v_intrinsic)
-    size_t vl =
+    size_t vl = qk / 2;

     for (; ib < nb; ++ib) {
         // load elements
-        ...
+        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);

-        ...
+        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
+        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);

         // mask and store lower part of x, and then upper part
-        ...
+        vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
+        vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);

-        ...
+        vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
+        vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);

-        ...
+        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
+        vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);

         vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-        ...
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);

         int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);

@@ -3132,65 +3128,33 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, con

     sumf = hsum_float_8(acc);
 #elif defined(__riscv_v_intrinsic)
-    ...
-    size_t vl = __riscv_vsetvl_e8m1(qk/2);
-    ...
-    // These temporary registers are for masking and shift operations
-    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
-    vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
-    ...
-    vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
-    vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+    size_t vl;
+    size_t vlenb = __riscv_vlenb();

     for (; ib < nb; ++ib) {
-        ...
-        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
-        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
-        ...
-        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
-        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
-        ...
-        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
-        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
-        ...
-        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
-        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
-        ...
-        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
-        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
-        ...
-        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-        ...
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
-        ...
-        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
-        ...
-        sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d)) * sumi;
+        vl = qk / 2;
+        vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
+        vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
+        vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
+        vint8m2_t v0c;
+        if (vlenb == 16) {
+            v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
+        } else {
+            v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
+            v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
+        }
+
+        vl = qk;
+        vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
+        qh = __riscv_vmnand_mm_b4(qh, qh, vl);
+        vint8m2_t v0f = __riscv_vsub_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
+        vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
+        vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
+        vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
+        vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
+        int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);
+
+        sumf += (LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d)) * sumi;
     }

 #elif defined(__POWER9_VECTOR__)
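The masked subtract in the new q5_0 path folds the fifth bit and the -16 offset into one operation: where the high bit is set, (nibble | 16) - 16 is just the nibble, so 16 only needs to be subtracted where the bit is clear (hence the inverted mask). A scalar sketch of the same decode, with an illustrative helper name and assuming the usual q5_0 layout of 32 weights (low/high nibbles in qs[16], fifth bits packed into a 32-bit qh):

#include <stdint.h>

// Decode weight j (0..31) of a q5_0 block, before applying the block scale.
static inline int8_t dequant_q5_0_weight(const uint8_t *qs, uint32_t qh, int j) {
    // 4 low bits come from the packed nibbles
    const uint8_t nib = (j < 16) ? (uint8_t)(qs[j] & 0x0F) : (uint8_t)(qs[j - 16] >> 4);
    // 5th bit decides whether 16 gets subtracted:
    //   bit set   -> (nib | 0x10) - 16 == nib
    //   bit clear -> nib - 16
    const int bit = (qh >> j) & 1;
    return (int8_t)(bit ? nib : nib - 16);
}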
@@ -3503,60 +3467,30 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * LM_GGML_RESTRICT s, size_t bs, con

     sumf = hsum_float_8(acc) + summs;
 #elif defined(__riscv_v_intrinsic)
-    ...
-    size_t vl = __riscv_vsetvl_e8m1(qk/2);
-    ...
-    // temporary registers for shift operations
-    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
-    vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+    size_t vl;
+    size_t vlenb = __riscv_vlenb();

     for (; ib < nb; ++ib) {
-        ...
-        // load
-        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
-        ...
-        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
-        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
-        ...
-        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
-        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
-        ...
-        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
-        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
-        ...
-        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
-        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
-        ...
-        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
-        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
-        ...
-        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-        ...
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
-        ...
-        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+        vl = qk / 2;
+        vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
+        vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
+        vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
+        vint8m2_t v0c;
+        if (vlenb == 16) {
+            v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
+        } else {
+            v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
+            v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
+        }
+
+        vl = qk;
+        vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
+        vint8m2_t v0f = __riscv_vor_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
+        vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
+        vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
+        vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
+        vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
+        int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);

         sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_FP16_TO_FP32(x[ib].m)*LM_GGML_FP16_TO_FP32(y[ib].s);
     }
@@ -3970,17 +3904,17 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, con

     sumf = hsum_float_8(accum);
 #elif defined(__riscv_v_intrinsic)
-    size_t vl =
+    size_t vl = qk;

     for (; ib < nb; ++ib) {
         // load elements
-        ...
+        vint8m2_t bx_0 = __riscv_vle8_v_i8m2(x[ib].qs, vl);
+        vint8m2_t by_0 = __riscv_vle8_v_i8m2(y[ib].qs, vl);

-        ...
+        vint16m4_t vw_mul = __riscv_vwmul_vv_i16m4(bx_0, by_0, vl);

         vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
-        vint32m1_t v_sum =
+        vint32m1_t v_sum = __riscv_vwredsum_vs_i16m4_i32m1(vw_mul, v_zero, vl);

         int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);

@@ -5174,84 +5108,182 @@ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con

 #elif defined __riscv_v_intrinsic

+    const int vector_length = __riscv_vlenb() * 8;
     float sumf = 0;
-    uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-                           1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
-    ...
-    for (int i = 0; i < nb; ++i) {

-    ...
+    uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
+    uint8_t atmp[16];

-    ...
+    switch (vector_length) {
+    case 256:
+        for (int i = 0; i < nb; ++i) {
+            const uint8_t * q2 = x[i].qs;
+            const int8_t *  q8 = y[i].qs;
+            const uint8_t * sc = x[i].scales;

-    ...
+            const float dall = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
+            const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin);

-    ...
-    vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
+            size_t vl = 16;

-    ...
+            vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
+            vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);

-    ...
-    vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
-    vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
-    vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
-    vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+            vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);

-    ...
+            vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
+            vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
+            vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
+            vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
+            vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);

-    ...
+            sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);

-    ...
-    vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);
+            vl = 32;

-    ...
+            vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+            vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);

-    ...
-    vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
+            uint8_t is = 0;
+            int isum = 0;

-    ...
-    vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl);
+            for (int j = 0; j < QK_K / 128; ++j) {
+                // load Q2
+                vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);

-    ...
-    vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl);
+                vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
+                vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl);
+                vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl);
+                vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl);

-    ...
+                // duplicate scale elements for product
+                vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl);
+                vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl);
+                vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl);
+                vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl);

-    ...
-    vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl);
+                vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
+                vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
+                vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
+                vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));

-    ...
+                // load Q8
+                vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+                vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl);
+                vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl);
+                vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl);

-    ...
+                vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
+                vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
+                vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
+                vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);

-    ...
+                vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
+                vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);

-    ...
+                isum += __riscv_vmv_x_s_i32m1_i32(isum1);

-    ...
+                q2 += 32;
+                q8 += 128;
+                is = 8;
+            }

-    ...
+            sumf += dall * isum;
+        }
+        break;
+    case 128:
+        for (int i = 0; i < nb; ++i) {
+            const uint8_t * q2 = x[i].qs;
+            const int8_t *  q8 = y[i].qs;
+            const uint8_t * sc = x[i].scales;
+            const float dall = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
+            const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin);
+            uint8_t *patmp = atmp;
+            int vsums;
+            int tmp;
+            __asm__ __volatile__(
+                "vsetivli zero, 16, e8, m1\n\t"
+                "vmv.v.x v8, zero\n\t"
+                "vle8.v v1, (%[sc])\n\t"
+                "vand.vi v0, v1, 0xF\n\t"
+                "vsrl.vi v1, v1, 4\n\t"
+                "vse8.v v0, (%[scale])\n\t"
+                "vsetivli zero, 16, e16, m2\n\t"
+                "vle16.v v2, (%[bsums])\n\t"
+                "vzext.vf2 v0, v1\n\t"
+                "vwmul.vv v4, v0, v2\n\t"
+                "vsetivli zero, 16, e32, m4\n\t"
+                "vredsum.vs v8, v4, v8\n\t"
+                "vmv.x.s %[vsums], v8"
+                : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums)
+                : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums)
+                : "memory"
+                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+            );
+            sumf += dmin * vsums;
+            int isum = 0;
+
+            for (int j = 0; j < QK_K/128; ++j) {
+                __asm__ __volatile__(
+                    "vsetvli zero, %[vl32], e8, m2\n\t"
+                    "vle8.v v0, (%[q2])\n\t"
+                    "vsrl.vi v2, v0, 2\n\t"
+                    "vsrl.vi v4, v0, 4\n\t"
+                    "vsrl.vi v6, v0, 6\n\t"
+                    "vand.vi v0, v0, 0x3\n\t"
+                    "vand.vi v2, v2, 0x3\n\t"
+                    "vand.vi v4, v4, 0x3\n\t"
+                    "vsetvli zero, %[vl128], e8, m8\n\t"
+                    "vle8.v v8, (%[q8])\n\t"
+                    "vsetvli zero, %[vl64], e8, m4\n\t"
+                    "vwmul.vv v16, v0, v8\n\t"
+                    "vwmul.vv v24, v4, v12\n\t"
+                    "vsetivli zero, 16, e16, m2\n\t"
+                    "vmv.v.x v0, zero\n\t"
+                    "vwredsum.vs v10, v16, v0\n\t"
+                    "vwredsum.vs v9, v18, v0\n\t"
+                    "vwredsum.vs v8, v20, v0\n\t"
+                    "vwredsum.vs v7, v22, v0\n\t"
+                    "vwredsum.vs v11, v24, v0\n\t"
+                    "vwredsum.vs v12, v26, v0\n\t"
+                    "vwredsum.vs v13, v28, v0\n\t"
+                    "vwredsum.vs v14, v30, v0\n\t"
+                    "vsetivli zero, 4, e32, m1\n\t"
+                    "vslideup.vi v10, v9, 1\n\t"
+                    "vslideup.vi v8, v7, 1\n\t"
+                    "vslideup.vi v11, v12, 1\n\t"
+                    "vslideup.vi v13, v14, 1\n\t"
+                    "vslideup.vi v10, v8, 2\n\t"
+                    "vslideup.vi v11, v13, 2\n\t"
+                    "vsetivli zero, 8, e32, m2\n\t"
+                    "vle8.v v15, (%[scale])\n\t"
+                    "vzext.vf4 v12, v15\n\t"
+                    "vmul.vv v10, v10, v12\n\t"
+                    "vredsum.vs v0, v10, v0\n\t"
+                    "vmv.x.s %[tmp], v0\n\t"
+                    "add %[isum], %[isum], %[tmp]"
+                    : [tmp] "=&r" (tmp), [isum] "+&r" (isum)
+                    : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
+                    , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
+                    : "memory"
+                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+                );
+                q2 += 32; q8 += 128; patmp += 8;
+            }

+            sumf += dall * isum;
+        }
+        break;
+    default:
+        assert(false && "Unsupported vector length");
+        break;
     }

     *s = sumf;
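The K-quant dot products above all follow the same dispatch pattern: query the hardware vector register width once and branch to a path written for that width. A minimal skeleton of that pattern, using only the __riscv_vlenb() intrinsic that the diff itself relies on (function and comment text are illustrative, not part of the package):

#include <assert.h>
#include <riscv_vector.h>

// Pick a code path from the actual RVV register width (VLEN).
void vec_dot_dispatch_example(void) {
    const int vector_length = __riscv_vlenb() * 8;  // vlenb() is VLEN in bytes
    switch (vector_length) {
        case 256:
            // intrinsics path whose vl/LMUL choices assume 256-bit registers
            break;
        case 128:
            // hand-scheduled inline-assembly path for 128-bit registers
            break;
        default:
            assert(false && "Unsupported vector length");
            break;
    }
}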
@@ -6116,97 +6148,221 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con
     uint32_t aux[3];
     uint32_t utmp[4];

+    const int vector_length = __riscv_vlenb() * 8;
     float sumf = 0;
-    for (int i = 0; i < nb; ++i) {

-    ...
+    switch (vector_length) {
+    case 256:
+        for (int i = 0; i < nb; ++i) {

-    ...
-    utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
-    utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+            const uint8_t * LM_GGML_RESTRICT q3 = x[i].qs;
+            const uint8_t * LM_GGML_RESTRICT qh = x[i].hmask;
+            const int8_t  * LM_GGML_RESTRICT q8 = y[i].qs;

-    ...
+            memcpy(aux, x[i].scales, 12);
+            utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+            utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+            utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+            utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);

+            int8_t * scale = (int8_t *)utmp;
+            for (int j = 0; j < 16; ++j) scale[j] -= 32;

-    size_t vl = 32;
-    uint8_t m = 1;

-    ...
+            size_t vl = 32;
+            uint8_t m = 1;

-    ...
+            vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+            vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);

-    ...
+            int sum_t = 0;

-    ...
+            for (int j = 0; j < QK_K; j += 128) {

-    ...
-    vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
+                vl = 32;

-    ...
-    vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
-    vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
+                // load Q3
+                vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);

-    ...
-    m <<= 1;
+                vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
+                vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
+                vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
+                vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));

-    ...
+                // compute mask for subtraction
+                vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
+                vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
+                vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl);
+                m <<= 1;

-    ...
+                vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+                vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
+                vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl);
+                m <<= 1;

-    ...
+                vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+                vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
+                vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl);
+                m <<= 1;

-    ...
-    vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+                vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
+                vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
+                vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl);
+                m <<= 1;

-    ...
+                // load Q8 and take product with Q3
+                vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
+                vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+                vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+                vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);

-    ...
-    vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
-    vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
-    vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
-    vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
-    vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
-    vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
-    vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
-    vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
+                vl = 16;

-    ...
+                // retrieve lane to multiply with scale
+                vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
+                vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
+                vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
+                vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
+                vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
+                vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
+                vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
+                vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);

-    ...
+                vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
+                vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
+                vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
+                vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);

-    ...
+                sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);

-    ...
+                q3 += 32; q8 += 128; scale += 8;

-    ...
+            }
+
+            const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+            sumf += d*sum_t;
+
+        }
+        break;
+    case 128:
+        for (int i = 0; i < nb; ++i) {
+            const uint8_t * restrict q3 = x[i].qs;
+            const uint8_t * restrict qh = x[i].hmask;
+            const int8_t  * restrict q8 = y[i].qs;
+
+            int8_t * scale = (int8_t *)utmp;
+            int tmp;
+            __asm__ __volatile__(
+                "vsetivli zero, 12, e8, m1\n\t"
+                "vle8.v v0, (%[s6b])\n\t"
+                "vmv1r.v v2, v0\n\t"
+                "vsetivli zero, 2, e64, m1\n\t"
+                "vmv.v.x v9, %[sh]\n\t"\
+                "vslidedown.vi v1, v0, 1\n\t"
+                "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4}
+                "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]}
+                "vsetivli zero, 4, e32, m1\n\t"
+                "vid.v v9\n\t"
+                "vmv.x.s %[tmp], v1\n\t"
+                "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6}
+                "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]}
+                "vsrl.vv v4, v1, v9\n\t"
+                "vsrl.vv v2, v0, v8\n\t"
+                "vand.vx v5, v4, %[kmask1]\n\t"
+                "vand.vx v3, v2, %[kmask2]\n\t"
+                "vsll.vi v6, v5, 4\n\t"
+                "vor.vv v7, v6, v3\n\t"
+                "vsetivli zero, 16, e8, m1\n\t"
+                "vsub.vx v0, v7, %[c]\n\t"
+                "vse8.v v0, (%[scale])"
+                : [tmp] "=&r" (tmp)
+                : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32)
+                , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2)
+                : "memory"
+                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+            );

-    ...
+            uint8_t m = 1;
+            int isum = 0;
+            for (int j = 0; j < QK_K; j += 128) {
+                __asm__ __volatile__(
+                    "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t"
+                    "vle8.v v8, (%[q3])\n\t"
+                    "vsrl.vi v10, v8, 2\n\t"
+                    "vsrl.vi v12, v8, 4\n\t"
+                    "vsrl.vi v14, v8, 6\n\t"
+                    "vand.vi v8, v8, 3\n\t"
+                    "vand.vi v10, v10, 3\n\t"
+                    "vand.vi v12, v12, 3\n\t"
+                    "vle8.v v2, (%[qh])\n\t"
+                    "vand.vx v4, v2, %[m]\n\t"
+                    "slli %[m], %[m], 1\n\t"
+                    "vmseq.vx v0, v4, zero\n\t"
+                    "vadd.vi v8, v8, -4, v0.t\n\t"
+                    "vand.vx v4, v2, %[m]\n\t"
+                    "slli %[m], %[m], 1\n\t"
+                    "vmseq.vx v0, v4, zero\n\t"
+                    "vadd.vi v10, v10, -4, v0.t\n\t"
+                    "vand.vx v4, v2, %[m]\n\t"
+                    "slli %[m], %[m], 1\n\t"
+                    "vmseq.vx v0, v4, zero\n\t"
+                    "vadd.vi v12, v12, -4, v0.t\n\t"
+                    "vand.vx v4, v2, %[m]\n\t"
+                    "slli %[m], %[m], 1\n\t"
+                    "vmseq.vx v0, v4, zero\n\t"
+                    "vadd.vi v14, v14, -4, v0.t\n\t"
+                    "vsetvli zero, %[vl128], e8, m8\n\t"
+                    "vle8.v v0, (%[q8])\n\t"
+                    "vsetvli zero, %[vl64], e8, m4\n\t"
+                    "vwmul.vv v16, v0, v8\n\t"
+                    "vwmul.vv v24, v4, v12\n\t"
+                    "vsetivli zero, 16, e16, m2\n\t"
+                    "vmv.v.x v0, zero\n\t"
+                    "vwredsum.vs v10, v16, v0\n\t"
+                    "vwredsum.vs v9, v18, v0\n\t"
+                    "vwredsum.vs v8, v20, v0\n\t"
+                    "vwredsum.vs v7, v22, v0\n\t"
+                    "vwredsum.vs v11, v24, v0\n\t"
+                    "vwredsum.vs v12, v26, v0\n\t"
+                    "vwredsum.vs v13, v28, v0\n\t"
+                    "vwredsum.vs v14, v30, v0\n\t"
+                    "vsetivli zero, 4, e32, m1\n\t"
+                    "vslideup.vi v10, v9, 1\n\t"
+                    "vslideup.vi v8, v7, 1\n\t"
+                    "vslideup.vi v11, v12, 1\n\t"
+                    "vslideup.vi v13, v14, 1\n\t"
+                    "vslideup.vi v10, v8, 2\n\t"
+                    "vslideup.vi v11, v13, 2\n\t"
+                    "vsetivli zero, 8, e32, m2\n\t"\
+                    "vle8.v v15, (%[scale])\n\t"
+                    "vsext.vf4 v12, v15\n\t"
+                    "vmul.vv v10, v10, v12\n\t"
+                    "vredsum.vs v0, v10, v0\n\t"
+                    "vmv.x.s %[tmp], v0\n\t"
+                    "add %[isum], %[isum], %[tmp]"
+                    : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
+                    : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
+                    , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
+                    : "memory"
+                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+                );
+                q3 += 32; q8 += 128; scale += 8;
+            }

+            const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+            sumf += d * isum;
+        }
+        break;
+    default:
+        assert(false && "Unsupported vector length");
+        break;
     }

     *s = sumf;
@@ -6924,69 +7080,181 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con
     const uint8_t * scales = (const uint8_t*)&utmp[0];
     const uint8_t * mins = (const uint8_t*)&utmp[2];

+    const int vector_length = __riscv_vlenb() * 8;
     float sumf = 0;

-    ...
+    switch (vector_length) {
+    case 256:
+        for (int i = 0; i < nb; ++i) {

-    ...
+            size_t vl = 8;

-    ...
+            const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
+            const float dmin = y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin);

-    ...
+            vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+            vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
+            vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);

-    ...
+            memcpy(utmp, x[i].scales, 12);
+            utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+            const uint32_t uaux = utmp[1] & kmask1;
+            utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+            utmp[2] = uaux;
+            utmp[0] &= kmask1;

-    ...
+            vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
+            vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+            vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);

-    ...
+            vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+            sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);

-    ...
+            const uint8_t * LM_GGML_RESTRICT q4 = x[i].qs;
+            const int8_t  * LM_GGML_RESTRICT q8 = y[i].qs;

-    ...
+            vl = 32;

-    ...
+            int32_t sum_1 = 0;
+            int32_t sum_2 = 0;

-    ...
+            vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);

-    ...
+            for (int j = 0; j < QK_K/64; ++j) {
+                // load Q4
+                vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);

-    ...
+                // load Q8 and multiply it with lower Q4 nibble
+                vint8m1_t  q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+                vint8m1_t  q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
+                vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
+                vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);

-    ...
+                sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];

-    ...
+                // load Q8 and multiply it with upper Q4 nibble
+                vint8m1_t  q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
+                vint8m1_t  q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
+                vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
+                vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);

-    ...
+                sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];

-    ...
+                q4 += 32; q8 += 64;

-    ...
+            }
+
+            sumf += d*(sum_1 + sum_2);
+
+        }
+        break;
+    case 128:
+        for (int i = 0; i < nb; ++i) {
+            const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
+            const float dmin = y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin);
+
+            int tmp, tmp2, sumi;
+            __asm__ __volatile__(
+                "vsetivli zero, 12, e8, m1\n\t"
+                "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
+                "vsetivli zero, 4, e32, m1\n\t"
+                "vslidedown.vi v2, v1, 2\n\t"
+                "vmv1r.v v3, v2\n\t"
+                "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
+                "vsetivli zero, 2, e32, m1\n\t"
+                "vmv.v.i v4, 4\n\t"
+                "vand.vx v8, v1, %[kmask1]\n\t"
+                "vslide1up.vx v5, v4, zero\n\t" // {0, 4}
+                "vsrl.vi v6, v1, 6\n\t"
+                "vsrl.vv v7, v2, v5\n\t"
+                "vand.vx v0, v6, %[kmask3]\n\t"
+                "vand.vx v2, v7, %[kmask2]\n\t"
+                "vsll.vi v6, v0, 4\n\t"
+                "li %[t2], 8\n\t"
+                "addi %[t1], %[utmp], 4\n\t"
+                "vor.vv v1, v6, v2\n\t"
+                "vsse32.v v8, (%[utmp]), %[t2]\n\t"
+                "vsse32.v v1, (%[t1]), %[t2]\n\t"
+                "vsetivli zero, 8, e16, m1\n\t"
+                "vle32.v v2, (%[bsums])\n\t"
+                "vnsrl.wi v0, v2, 0\n\t"
+                "vnsrl.wi v1, v2, 16\n\t"
+                "vadd.vv v2, v0, v1\n\t"
+                "vle8.v v3, (%[mins])\n\t"
+                "vzext.vf2 v4, v3\n\t"
+                "vwmul.vv v6, v4, v2\n\t"
+                "vmv.v.x v0, zero\n\t"
+                "vsetivli zero, 8, e32, m2\n\t"
+                "vredsum.vs v0, v6, v0\n\t"
+                "vmv.x.s %[sumi], v0"
+                : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
+                : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
+                , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
+                , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
+                : "memory"
+                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+            );
+            sumf -= dmin * sumi;
+
+            const uint8_t * restrict q4 = x[i].qs;
+            const int8_t  * restrict q8 = y[i].qs;
+
+            sumi = 0;
+            const uint8_t * scale = scales;
+
+            for (int j = 0; j < QK_K/128; ++j) {
+                int vl128 = 128, vl64 = 64, vl32 = 32;
+                __asm__ __volatile__(
+                    "vsetvli zero, %[vl128], e8, m8\n\t"
+                    "vle8.v v8, (%[q8])\n\t"
+                    "vsetvli zero, %[vl64], e8, m4\n\t"
+                    "vle8.v v0, (%[q4])\n\t"
+                    "vsrl.vi v4, v0, 4\n\t"
+                    "vand.vi v0, v0, 0xF\n\t"
+                    "vsetvli zero, %[vl32], e8, m2\n\t"
+                    "vwmul.vv v28, v6, v14\n\t"
+                    "vwmul.vv v20, v4, v10\n\t"
+                    "vwmul.vv v24, v2, v12\n\t"
+                    "vwmul.vv v16, v0, v8\n\t"
+                    "vsetivli zero, 4, e32, m1\n\t"
+                    "vle8.v v2, (%[scale])\n\t"
+                    "vmv.v.x v0, zero\n\t"
+                    "vzext.vf4 v1, v2\n\t"
+                    "vsetvli zero, %[vl32], e16, m4\n\t"
+                    "vwredsum.vs v6, v24, v0\n\t"
+                    "vwredsum.vs v7, v28, v0\n\t"
+                    "vwredsum.vs v4, v16, v0\n\t"
+                    "vwredsum.vs v5, v20, v0\n\t"
+                    "vsetivli zero, 4, e32, m1\n\t"
+                    "vslideup.vi v6, v7, 1\n\t"
+                    "vslideup.vi v4, v5, 1\n\t"
+                    "vslideup.vi v4, v6, 2\n\t"
+                    "vmul.vv v8, v4, v1\n\t"
+                    "vredsum.vs v0, v8, v0\n\t"
+                    "vmv.x.s %[tmp], v0\n\t"
+                    "add %[sumi], %[sumi], %[tmp]"
+                    : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
+                    : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
+                    , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
+                    : "memory"
+                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+                );

-    ...
+                q4 += 64; q8 += 128; scale += 4;
+            }

+            sumf += d * sumi;
+        }
+        break;
+    default:
+        assert(false && "Unsupported vector length");
+        break;
     }

     *s = sumf;
@@ -7722,9 +7990,9 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con
     const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
     const float dmin = LM_GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;

-    ...
+    vint16m1_t q8sums_0 = __riscv_vlse16_v_i16m1(y[i].bsums, 4, vl);
+    vint16m1_t q8sums_1 = __riscv_vlse16_v_i16m1(y[i].bsums+1, 4, vl);
+    vint16m1_t q8sums = __riscv_vadd_vv_i16m1(q8sums_0, q8sums_1, vl);

     memcpy(utmp, x[i].scales, 12);
     utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
@@ -7733,11 +8001,11 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con
     utmp[2] = uaux;
     utmp[0] &= kmask1;

-    ...
+    vuint8mf2_t mins8 = __riscv_vle8_v_u8mf2(mins, vl);
+    vint16m1_t v_mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
+    vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, v_mins, vl);

-    vint32m1_t sumi =
+    vint32m1_t sumi = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
     sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);

     vl = 32;
@@ -7746,43 +8014,42 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con

     uint8_t m = 1;
     vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
-    ...
+    vuint8m2_t vqh = __riscv_vle8_v_u8m2(hm, vl);

     for (int j = 0; j < QK_K/64; ++j) {
         // load Q5 and Q8
-        ...
+        vuint8m2_t q5_x = __riscv_vle8_v_u8m2(q5, vl);
+        vint8m2_t q8_y1 = __riscv_vle8_v_i8m2(q8, vl);
+        vint8m2_t q8_y2 = __riscv_vle8_v_i8m2(q8+32, vl);

         // compute mask for addition
-        ...
+        vint8m2_t q5_a = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q5_x, 0x0F, vl));
+        vuint8m2_t qh_m1 = __riscv_vand_vx_u8m2(vqh, m, vl);
+        vbool4_t vmask_1 = __riscv_vmsne_vx_u8m2_b4(qh_m1, 0, vl);
+        vint8m2_t q5_m1 = __riscv_vadd_vx_i8m2_mu(vmask_1, q5_a, q5_a, 16, vl);
         m <<= 1;

-        ...
+        vint8m2_t q5_l = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vsrl_vx_u8m2(q5_x, 0x04, vl));
+        vuint8m2_t qh_m2 = __riscv_vand_vx_u8m2(vqh, m, vl);
+        vbool4_t vmask_2 = __riscv_vmsne_vx_u8m2_b4(qh_m2, 0, vl);
+        vint8m2_t q5_m2 = __riscv_vadd_vx_i8m2_mu(vmask_2, q5_l, q5_l, 16, vl);
         m <<= 1;

-        ...
+        vint16m4_t v0 = __riscv_vwmul_vv_i16m4(q5_m1, q8_y1, vl);
+        vint16m4_t v1 = __riscv_vwmul_vv_i16m4(q5_m2, q8_y2, vl);

-        ...
+        vint32m8_t vs1 = __riscv_vwmul_vx_i32m8(v0, scales[is++], vl);
+        vint32m8_t vs2 = __riscv_vwmul_vx_i32m8(v1, scales[is++], vl);

-        vint32m1_t vacc1 =
-        vint32m1_t vacc2 =
+        vint32m1_t vacc1 = __riscv_vredsum_vs_i32m8_i32m1(vs1, vzero, vl);
+        vint32m1_t vacc2 = __riscv_vredsum_vs_i32m8_i32m1(vs2, vacc1, vl);

-        aux32 += __riscv_vmv_x_s_i32m1_i32(
+        aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2);
         q5 += 32; q8 += 64;

     }

-    ...
-    sums += __riscv_vfmv_f_s_f32m1_f32(vaux);
+    sums += aux32 * d;

     }

@@ -8158,7 +8425,156 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con

     const int nb = n / QK_K;

-#ifdef
+#ifdef __ARM_FEATURE_SVE
+    const int vector_length = lm_ggml_cpu_get_sve_cnt()*8;
+    float sum = 0;
+    svuint8_t m4b = svdup_n_u8(0xf);
+    svint32_t vzero = svdup_n_s32(0);
+    svuint8_t mone = svdup_n_u8(0x30);
+    svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
+    svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d_all = LM_GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * LM_GGML_RESTRICT q6 = x[i].ql;
+        const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
+        const int8_t  * LM_GGML_RESTRICT q8 = y[i].qs;
+
+        const int8_t * LM_GGML_RESTRICT scale = x[i].scales;
+
+        const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
+        const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
+        const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
+        const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
+        const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
+        const svint64_t prod = svdup_n_s64(0);
+        int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
+                                                                                 svdot_s64(prod, q8sums_2, q6scales_2)));
+        int32_t isum = 0;
+
+        switch (vector_length) {
+        case 128:
+            {
+                const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+                const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
+                svint32_t isum_tmp = svdup_n_s32(0);
+                for (int j = 0; j < QK_K/128; ++j) {
+                    svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
+                    svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
+                    qh += 32;
+                    svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
+                    svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
+                    svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
+                    svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
+                    q6 += 64;
+                    svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
+                    svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
+                    svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
+                    svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
+                    q8 += 64;
+
+                    q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
+                    q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
+                    q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
+                    q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
+                    q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
+                    q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
+                    q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
+                    q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
+                    isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+                    isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+                    isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+                    isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+
+                    scale += 4;
+                    q8bytes_1 = svld1_s8(pg8_16, q8);
+                    q8bytes_2 = svld1_s8(pg8_16, q8+16);
+                    q8bytes_3 = svld1_s8(pg8_16, q8+32);
+                    q8bytes_4 = svld1_s8(pg8_16, q8+48);
+                    q8 += 64;
+
+                    q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
+                    q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
+                    q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
+                    q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
+                    q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
+                    q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
+                    q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
+                    q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
+                    isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+                    isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+                    isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+                    isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+                    scale += 4;
+                }
+                isum += svaddv_s32(pg32_4, isum_tmp);
+                sum += d_all * y[i].d * (isum - 32 * isum_mins);
+            }
+            break;
+        case 256:
+        case 512:
+            {
+                const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
+                const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
+                const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
+                svint32_t isum_tmp = svdup_n_s32(0);
+                for (int j = 0; j < QK_K/128; j++) {
+                    svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
+                    qh += 32;
+                    svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
+                    svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
+                    q6 += 64;
+                    svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
+                    svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
+                    svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
+                    svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
+                    q8 += 128;
+                    q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
+                    q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
+                    q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
+                    q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
+                    q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
+                    q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
+                    q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
+                    q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
+
+                    svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
+                    scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+                    scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+                    svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
+                    scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+                    scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+                    svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
+                    scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+                    scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+                    svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
+                    scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+                    scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+                    svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
+                    svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
+                    svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
+                    svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
+
+                    isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
+                    isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
+                    isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
+                    isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
+                    scale += 8;
+                }
+                isum += svaddv_s32(pg32_8, isum_tmp);
+                sum += d_all * y[i].d * (isum - 32 * isum_mins);
+            }
+            break;
+        default:
+            assert(false && "Unsupported vector length");
+            break;
+        }
+    }
+
+    *s = sum;
+
+#elif __ARM_NEON
     float sum = 0;

     const uint8x16_t m4b = vdupq_n_u8(0xF);
@@ -8518,85 +8934,168 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con
 
 #elif defined __riscv_v_intrinsic
 
+    const int vector_length = __riscv_vlenb() * 8;
     float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
 
-        const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+    switch (vector_length) {
+    case 256:
+        for (int i = 0; i < nb; ++i) {
 
-        const uint8_t * LM_GGML_RESTRICT q6 = x[i].ql;
-        const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
-        const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
+            const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
 
-        const int8_t * LM_GGML_RESTRICT scale = x[i].scales;
+            const uint8_t * LM_GGML_RESTRICT q6 = x[i].ql;
+            const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
+            const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
 
-        size_t vl;
+            const int8_t * LM_GGML_RESTRICT scale = x[i].scales;
 
-        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+            size_t vl;
 
-        int sum_t = 0;
-        int is = 0;
+            vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
 
-        for (int j = 0; j < QK_K/128; ++j) {
+            int sum_t = 0;
+            int is = 0;
 
-            vl = 32;
+            for (int j = 0; j < QK_K/128; ++j) {
 
-            // load qh
-            vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
+                vl = 32;
 
-            // load Q6
-            vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
-            vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
+                // load qh
+                vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
 
-            vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
-            vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
-            vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
-            vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
+                // load Q6
+                vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
+                vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
 
-            vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
-            vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
-            vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
-            vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
+                vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
+                vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
+                vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
+                vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
 
-            vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
-            vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
-            vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
-            vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
+                vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
+                vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
+                vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
+                vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
 
-            vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
-            vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
-            vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
-            vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
+                vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
+                vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
+                vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
+                vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
 
-            // load Q8 and take product
-            vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
-            vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
-            vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
-            vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+                vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
+                vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
+                vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
+                vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
 
-            vl = 16;
+                // load Q8 and take product
+                vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
+                vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+                vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+                vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
 
-            vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
-            vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
-            vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
-            vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
-            vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
-            vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
-            vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
-            vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
+                vl = 16;
 
-            vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
-            vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
-            vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
-            vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
+                vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
+                vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
+                vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
+                vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
+                vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
+                vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
+                vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
+                vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
 
-            sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+                vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
+                vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
+                vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
+                vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
 
-            q6 += 64; qh += 32; q8 += 128; is=8;
+                sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
 
-        }
+                q6 += 64; qh += 32; q8 += 128; is=8;
 
-        sumf += d * sum_t;
+            }
 
+            sumf += d * sum_t;
+
+        }
+        break;
+    case 128:
+        for (int i = 0; i < nb; ++i) {
+
+            const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+            const uint8_t * restrict q6 = x[i].ql;
+            const uint8_t * restrict qh = x[i].qh;
+            const int8_t * restrict q8 = y[i].qs;
+
+            const int8_t * restrict scale = x[i].scales;
+
+            int sum_t = 0;
+            int t0;
+
+            for (int j = 0; j < QK_K/128; ++j) {
+                __asm__ __volatile__(
+                    "vsetvli zero, %[vl32], e8, m2\n\t"
+                    "vle8.v v4, (%[qh])\n\t"
+                    "vsll.vi v0, v4, 4\n\t"
+                    "vsll.vi v2, v4, 2\n\t"
+                    "vsrl.vi v6, v4, 2\n\t"
+                    "vsetvli zero, %[vl64], e8, m4\n\t"
+                    "vle8.v v8, (%[q6])\n\t"
+                    "vsrl.vi v12, v8, 4\n\t"
+                    "vand.vi v8, v8, 0xF\n\t"
+                    "vsetvli zero, %[vl128], e8, m8\n\t"
+                    "vand.vx v0, v0, %[mask]\n\t"
+                    "vor.vv v8, v8, v0\n\t"
+                    "vle8.v v0, (%[q8])\n\t"
+                    "vsub.vx v8, v8, %[vl32]\n\t"
+                    "vsetvli zero, %[vl64], e8, m4\n\t"
+                    "vwmul.vv v16, v0, v8\n\t"
+                    "vwmul.vv v24, v4, v12\n\t"
+                    "vsetivli zero, 16, e16, m2\n\t"
+                    "vmv.v.x v0, zero\n\t"
+                    "vwredsum.vs v10, v16, v0\n\t"
+                    "vwredsum.vs v9, v18, v0\n\t"
+                    "vwredsum.vs v8, v20, v0\n\t"
+                    "vwredsum.vs v7, v22, v0\n\t"
+                    "vwredsum.vs v11, v24, v0\n\t"
+                    "vwredsum.vs v12, v26, v0\n\t"
+                    "vwredsum.vs v13, v28, v0\n\t"
+                    "vwredsum.vs v14, v30, v0\n\t"
+                    "vsetivli zero, 4, e32, m1\n\t"
+                    "vslideup.vi v10, v9, 1\n\t"
+                    "vslideup.vi v8, v7, 1\n\t"
+                    "vslideup.vi v11, v12, 1\n\t"
+                    "vslideup.vi v13, v14, 1\n\t"
+                    "vslideup.vi v10, v8, 2\n\t"
+                    "vslideup.vi v11, v13, 2\n\t"
+                    "vsetivli zero, 8, e32, m2\n\t"
+                    "vle8.v v2, (%[scale])\n\t"
+                    "vsext.vf4 v4, v2\n\t"
+                    "vmul.vv v2, v4, v10\n\t"
+                    "vredsum.vs v0, v2, v0\n\t"
+                    "vmv.x.s %[t0], v0\n\t"
+                    "add %[sumi], %[sumi], %[t0]"
+                    : [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
+                    : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
+                    , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
+                    , [mask] "r" (0x30)
+                    : "memory"
+                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+                );
+                q6 += 64; qh += 32; q8 += 128; scale += 8;
+            }
+
+            sumf += d * sum_t;
+
+        }
+        break;
+    default:
+        assert(false && "Unsupported vector length");
+        break;
     }
 
     *s = sumf;
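
Annotation (not part of the upstream package diff): both new q6_K paths above select a kernel specialization by querying the hardware vector width once at run time and switching on it; the RVV hunk computes it as __riscv_vlenb() * 8. The following is a minimal, self-contained C sketch of that dispatch pattern only; the function and kernel names are hypothetical placeholders, not symbols from ggml-cpu-quants.c.

    #include <stdio.h>

    /* Hypothetical kernels standing in for the 128-bit and 256-bit code paths. */
    static void dot_q6_K_vl128(void) { puts("using 128-bit vector kernel"); }
    static void dot_q6_K_vl256(void) { puts("using 256-bit vector kernel"); }

    /* Placeholder for a hardware query such as __riscv_vlenb() * 8 on RVV. */
    static int detect_vector_length_bits(void) { return 256; }

    static void dot_q6_K_dispatch(void) {
        const int vector_length = detect_vector_length_bits();
        switch (vector_length) {
        case 128:
            dot_q6_K_vl128();
            break;
        case 256:
        case 512:               /* wider SVE implementations can reuse the 256-bit path */
            dot_q6_K_vl256();
            break;
        default:
            /* mirrors the diff: unsupported widths are rejected */
            puts("Unsupported vector length");
            break;
        }
    }

    int main(void) {
        dot_q6_K_dispatch();
        return 0;
    }

In the real kernels the case bodies contain the SVE intrinsic and RVV intrinsic/inline-assembly implementations shown in the hunks above.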