cui-llama.rn 1.1.4 → 1.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/cpp/ggml-quants.c CHANGED
@@ -230,6 +230,12 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
230
230
 
231
231
  return _mm_packus_epi16( bytes1, bytes2);
232
232
  }
233
+
234
// Signed int8 x int8 multiply, with adjacent products added into int16 lanes.
// _mm_maddubs_epi16 takes an unsigned first operand, so the magnitude |x| is
// passed as that operand and the sign of x is moved onto y instead:
// |x| * (sign(x) * y) == x * y. Where x is zero, y is zeroed as well.
static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
    return _mm_maddubs_epi16(_mm_sign_epi8(x, x),   // |x|, read as unsigned bytes
                             _mm_sign_epi8(y, x));  // y carrying the sign of x
}
233
239
  #endif
234
240
  #elif defined(__SSSE3__)
235
241
  // horizontally add 4x4 floats
@@ -4003,42 +4009,141 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
4003
4009
  float sumf = 0;
4004
4010
 
4005
4011
  #if defined(__ARM_FEATURE_SVE)
4006
- if (lm_ggml_sve_cnt_b == QK8_0) {
4007
- const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
4008
- const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
4009
-
4010
- svfloat32_t sumv0 = svdup_n_f32(0.0f);
4011
- svfloat32_t sumv1 = svdup_n_f32(0.0f);
4012
+ svfloat32_t sumv0 = svdup_n_f32(0.0f);
4013
+ svfloat32_t sumv1 = svdup_n_f32(0.0f);
4012
4014
 
4013
- for (; ib + 1 < nb; ib += 2) {
4014
- const block_q4_0 * restrict x0 = &x[ib + 0];
4015
- const block_q4_0 * restrict x1 = &x[ib + 1];
4016
- const block_q8_0 * restrict y0 = &y[ib + 0];
4017
- const block_q8_0 * restrict y1 = &y[ib + 1];
4015
+ const int vector_length = lm_ggml_sve_cnt_b*8;
4018
4016
 
4019
- // load x
4020
- const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
4021
- const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
4022
-
4023
- // 4-bit -> 8-bit
4024
- const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
4025
- const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
4026
-
4027
- // sub 8
4028
- const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
4029
- const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
4017
+ // VLA Implementation using switch case
4018
+ switch (vector_length) {
4019
+ case 128:
4020
+ {
4021
+ // predicate for activating higher lanes for 4 float32 elements
4022
+ const svbool_t ph4 = svptrue_pat_b32(SV_VL4);
4023
+
4024
+ for (; ib + 1 < nb; ib += 2) {
4025
+ const block_q4_0 * restrict x0 = &x[ib + 0];
4026
+ const block_q4_0 * restrict x1 = &x[ib + 1];
4027
+ const block_q8_0 * restrict y0 = &y[ib + 0];
4028
+ const block_q8_0 * restrict y1 = &y[ib + 1];
4029
+
4030
+ // load x
4031
+ const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
4032
+ const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
4033
+
4034
+ // 4-bit -> 8-bit
4035
+ const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F));
4036
+ const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04));
4037
+ const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F));
4038
+ const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04));
4039
+
4040
+ // sub 8
4041
+ const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8);
4042
+ const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8);
4043
+ const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8);
4044
+ const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8);
4045
+
4046
+ // load y
4047
+ const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs);
4048
+ const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16);
4049
+ const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs);
4050
+ const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16);
4051
+
4052
+ // dot product
4053
+ sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4,
4054
+ svdot_s32(svdup_n_s32(0), qx0ls, qy0l),
4055
+ svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d));
4056
+ sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4,
4057
+ svdot_s32(svdup_n_s32(0), qx1ls, qy1l),
4058
+ svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d));
4059
+ }
4030
4060
 
4031
- // load y
4032
- const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
4033
- const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
4061
+ sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
4062
+ } break;
4063
+ case 256:
4064
+ {
4065
+ // predicate for activating higher lanes for 16 int8 elements
4066
+ const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
4067
+ // predicate for activating lower lanes for 16 int8 elements
4068
+ const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
4069
+
4070
+ for (; ib + 1 < nb; ib += 2) {
4071
+ const block_q4_0 * restrict x0 = &x[ib + 0];
4072
+ const block_q4_0 * restrict x1 = &x[ib + 1];
4073
+ const block_q8_0 * restrict y0 = &y[ib + 0];
4074
+ const block_q8_0 * restrict y1 = &y[ib + 1];
4075
+
4076
+ // load x
4077
+ const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
4078
+ const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
4079
+
4080
+ // 4-bit -> 8-bit
4081
+ const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
4082
+ const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
4083
+
4084
+ // sub 8
4085
+ const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
4086
+ const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
4087
+
4088
+ // load y
4089
+ const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
4090
+ const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
4091
+
4092
+ // dot product
4093
+ sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
4094
+ svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d));
4095
+ sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
4096
+ svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d));
4097
+ }
4034
4098
 
4035
- // dot product
4036
- sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d));
4037
- sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d));
4038
- }
4099
+ sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
4100
+ } break;
4101
+ case 512:
4102
+ {
4103
+ // predicate for activating higher lanes for 32 int8 elements
4104
+ const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
4105
+
4106
+ // predicate for activating higher lanes for 16 int8 elements
4107
+ const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
4108
+ // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes
4109
+ const svbool_t pl16 = svnot_b_z(ph32, ph16);
4110
+
4111
+ for (; ib + 1 < nb; ib += 2) {
4112
+ const block_q4_0 * restrict x0 = &x[ib + 0];
4113
+ const block_q4_0 * restrict x1 = &x[ib + 1];
4114
+ const block_q8_0 * restrict y0 = &y[ib + 0];
4115
+ const block_q8_0 * restrict y1 = &y[ib + 1];
4116
+
4117
+ // load x
4118
+ const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
4119
+ const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);
4120
+
4121
+ // 4-bit -> 8-bit
4122
+ const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
4123
+ const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
4124
+
4125
+ // sub 8
4126
+ const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8);
4127
+ const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8);
4128
+
4129
+ // load y
4130
+ const svint8_t qy0 = svld1_s8(ph32, y0->qs);
4131
+ const svint8_t qy1 = svld1_s8(ph32, y1->qs);
4132
+
4133
+ // dot product
4134
+ sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32,
4135
+ svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d));
4136
+ sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32,
4137
+ svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d));
4138
+ }
4039
4139
 
4040
- sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
4140
+ sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));
4141
+ } break;
4142
+ default:
4143
+ assert(false && "Unsupported vector length");
4144
+ break;
4041
4145
  }
4146
+
4042
4147
  #elif defined(__ARM_NEON)
4043
4148
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
4044
4149
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -4107,37 +4212,37 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
4107
4212
 
4108
4213
  sumf = hsum_float_8(acc);
4109
4214
  #elif defined(__AVX__)
4110
- // Initialize accumulator with zeros
4111
- __m256 acc = _mm256_setzero_ps();
4112
-
4113
- // Main loop
4114
- for (; ib < nb; ++ib) {
4115
- // Compute combined scale for the block
4116
- const __m256 d = _mm256_set1_ps( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) );
4117
-
4118
- const __m128i lowMask = _mm_set1_epi8(0xF);
4119
- const __m128i off = _mm_set1_epi8(8);
4120
-
4121
- const __m128i tmp = _mm_loadu_si128((const __m128i *)x[ib].qs);
4122
-
4123
- __m128i bx_0 = _mm_and_si128(lowMask, tmp);
4124
- __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
4125
- bx_0 = _mm_sub_epi8(bx_0, off);
4126
- const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
4127
-
4128
- bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
4129
- by_0 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));
4130
- bx_0 = _mm_sub_epi8(bx_0, off);
4131
- const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0);
4215
+ const __m128i mone = _mm_set1_epi16(1);
4132
4216
 
4133
- // Convert int32_t to float
4134
- __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
4217
+ __m256 accum1 = _mm256_setzero_ps();
4218
+ __m256 accum2 = _mm256_setzero_ps();
4219
+ for (; ib + 1 < nb; ib += 2) {
4220
+ const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
4221
+ const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
4222
+ const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
4223
+ const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
4224
+ const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
4225
+ const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
4135
4226
 
4136
- // Apply the scale, and accumulate
4137
- acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
4227
+ const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8));
4228
+ const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8));
4229
+ const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8));
4230
+ const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8));
4231
+ const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
4232
+ const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
4233
+ const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
4234
+ const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
4235
+ const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
4236
+ const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
4237
+ const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
4238
+ const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
4239
+ accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib + 0].d)*LM_GGML_FP16_TO_FP32(x[ib + 0].d)),
4240
+ _mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1);
4241
+ accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib + 1].d)*LM_GGML_FP16_TO_FP32(x[ib + 1].d)),
4242
+ _mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2);
4138
4243
  }
4139
4244
 
4140
- sumf = hsum_float_8(acc);
4245
+ sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
4141
4246
  #elif defined(__SSSE3__)
4142
4247
  // set constants
4143
4248
  const __m128i lowMask = _mm_set1_epi8(0xF);
@@ -5488,29 +5593,124 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void
5488
5593
  float sumf = 0;
5489
5594
 
5490
5595
  #if defined(__ARM_FEATURE_SVE)
5491
- if (lm_ggml_sve_cnt_b == QK8_0) {
5492
- svfloat32_t sumv0 = svdup_n_f32(0.0f);
5493
- svfloat32_t sumv1 = svdup_n_f32(0.0f);
5596
+ svfloat32_t sumv0 = svdup_n_f32(0.0f);
5597
+ svfloat32_t sumv1 = svdup_n_f32(0.0f);
5494
5598
 
5495
- for (; ib + 1 < nb; ib += 2) {
5496
- const block_q8_0 * restrict x0 = &x[ib + 0];
5497
- const block_q8_0 * restrict x1 = &x[ib + 1];
5498
- const block_q8_0 * restrict y0 = &y[ib + 0];
5499
- const block_q8_0 * restrict y1 = &y[ib + 1];
5599
+ const int vector_length = lm_ggml_sve_cnt_b*8;
5500
5600
 
5501
- // load x
5502
- const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
5503
- const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
5601
+ //VLA Implemenation for SVE
5602
+ switch (vector_length) {
5603
+ case 128:
5604
+ {
5605
+ // predicate for activating lanes for 16 Int8 elements
5606
+ const svbool_t ph16 = svptrue_pat_b8 (SV_VL16);
5607
+ const svbool_t pl16 = svptrue_pat_b32(SV_VL4);
5608
+
5609
+ for (; ib + 1 < nb; ib += 2) {
5610
+ const block_q8_0 * restrict x0 = &x[ib + 0];
5611
+ const block_q8_0 * restrict x1 = &x[ib + 1];
5612
+ const block_q8_0 * restrict y0 = &y[ib + 0];
5613
+ const block_q8_0 * restrict y1 = &y[ib + 1];
5614
+
5615
+ // load x
5616
+ const svint8_t qx0_0 = svld1_s8(ph16, x0->qs);
5617
+ const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16);
5618
+ const svint8_t qx1_0 = svld1_s8(ph16, x1->qs);
5619
+ const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16);
5620
+
5621
+ // load y
5622
+ const svint8_t qy0_0 = svld1_s8(ph16, y0->qs);
5623
+ const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16);
5624
+ const svint8_t qy1_0 = svld1_s8(ph16, y1->qs);
5625
+ const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16);
5626
+
5627
+ sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16,
5628
+ svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),
5629
+ svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d));
5630
+ sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16,
5631
+ svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),
5632
+ svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d));
5633
+ }
5504
5634
 
5505
- // load y
5506
- const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
5507
- const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
5635
+ sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1));
5636
+ } break;
5637
+ case 256:
5638
+ {
5639
+ //printf("sve256");
5640
+ for (; ib + 1 < nb; ib += 2) {
5641
+ const block_q8_0 * restrict x0 = &x[ib + 0];
5642
+ const block_q8_0 * restrict x1 = &x[ib + 1];
5643
+ const block_q8_0 * restrict y0 = &y[ib + 0];
5644
+ const block_q8_0 * restrict y1 = &y[ib + 1];
5645
+
5646
+ // load x
5647
+ const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
5648
+ const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
5649
+
5650
+ // load y
5651
+ const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
5652
+ const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
5653
+
5654
+ sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
5655
+ svdot_s32(svdup_n_s32(0), qx0, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d));
5656
+ sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
5657
+ svdot_s32(svdup_n_s32(0), qx1, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d));
5658
+ }
5508
5659
 
5509
- sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d));
5510
- sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d));
5511
- }
5660
+ sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
5661
+ } break;
5662
+ case 512:
5663
+ {
5664
+ // predicate for activating high 256 bit
5665
+ const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
5666
+ // predicate for activating low 256 bit
5667
+ const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32);
5512
5668
 
5513
- sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
5669
+ // predicate for activating high lanes for 8 float32 elements
5670
+ const svbool_t ph8 = svptrue_pat_b32(SV_VL8);
5671
+ // predicate for activating low lanes for 8 float32 elements
5672
+ const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8);
5673
+
5674
+ svfloat32_t sumv00 = svdup_n_f32(0.0f);
5675
+
5676
+ for (; ib + 1 < nb; ib += 2) {
5677
+ const block_q8_0 * restrict x0 = &x[ib + 0];
5678
+ const block_q8_0 * restrict x1 = &x[ib + 1];
5679
+ const block_q8_0 * restrict y0 = &y[ib + 0];
5680
+ const block_q8_0 * restrict y1 = &y[ib + 1];
5681
+
5682
+ //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits
5683
+ // and add them to make one 64 element vector
5684
+ // load x
5685
+ const svint8_t qx_32 = svld1_s8(ph32, x0->qs);
5686
+ svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2);
5687
+
5688
+ qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64);
5689
+
5690
+ // load y
5691
+ const svint8_t qy_32 = svld1_s8(ph32, y0->qs);
5692
+ svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2);
5693
+
5694
+ qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64);
5695
+
5696
+ // scale creation
5697
+ const float32_t deq1 = LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d);
5698
+ const float32_t deq2 = LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d);
5699
+
5700
+ // duplicate deq1 in first half of vector and deq2 in second half of vector
5701
+ const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2);
5702
+
5703
+ const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64));
5704
+
5705
+ sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp);
5706
+ }
5707
+
5708
+ sumf = svaddv_f32(svptrue_b32(), sumv00);
5709
+ break;
5710
+ }
5711
+ default:
5712
+ assert(false && "Unsupported vector length");
5713
+ break;
5514
5714
  }
5515
5715
  #elif defined(__ARM_NEON)
5516
5716
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
@@ -11625,15 +11825,6 @@ void lm_ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const voi
11625
11825
  #endif
11626
11826
  }
11627
11827
 
11628
-
11629
- #if defined(__AVX__)
11630
- static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
11631
- const __m128i ax = _mm_sign_epi8(x, x);
11632
- const __m128i sy = _mm_sign_epi8(y, x);
11633
- return _mm_maddubs_epi16(ax, sy);
11634
- }
11635
- #endif
11636
-
11637
11828
  #if defined(__AVX2__)
11638
11829
  static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
11639
11830
  const __m256i ax = _mm256_sign_epi8(x, x);
package/cpp/ggml.c CHANGED
@@ -287,6 +287,7 @@ void lm_ggml_abort(const char * file, int line, const char * fmt, ...) {
287
287
  #define LM_GGML_DEBUG 0
288
288
  #define LM_GGML_GELU_FP16
289
289
  #define LM_GGML_GELU_QUICK_FP16
290
+ #define LM_GGML_N_TASKS_MAX (-1)
290
291
 
291
292
  #define LM_GGML_SOFT_MAX_UNROLL 4
292
293
  #define LM_GGML_VEC_DOT_UNROLL 2
@@ -1120,21 +1121,21 @@ lm_ggml_type_traits_t lm_ggml_internal_get_type_traits(enum lm_ggml_type type) {
1120
1121
  #define LM_GGML_F32x4_ADD vaddq_f32
1121
1122
  #define LM_GGML_F32x4_MUL vmulq_f32
1122
1123
  #define LM_GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
1123
// Horizontally reduces the LM_GGML_F32_ARR float32x4 accumulators in `x` into
// the scalar `res`. Accumulators are folded pairwise (x[i] += x[i + offset],
// halving offset on each pass) until only x[0] remains. x[0] is then summed
// across its lanes with LM_GGML_F32x4_REDUCE_ONE. `x` is modified in place.
//
// The body is wrapped in do { } while (0), like LM_GGML_F16x8_REDUCE, so that
// the macro is a single statement (safe inside an unbraced if/else). The fold
// continues until offset reaches 0, so any power-of-two LM_GGML_F32_ARR is
// fully reduced. For LM_GGML_F32_ARR <= 8 this performs exactly the same
// additions, in the same order, as three fixed halving passes.
#define LM_GGML_F32x4_REDUCE(res, x)                                     \
do {                                                                     \
    for (int offset = LM_GGML_F32_ARR >> 1; offset > 0; offset >>= 1) {  \
        for (int i = 0; i < offset; ++i) {                               \
            (x)[i] = vaddq_f32((x)[i], (x)[offset+i]);                   \
        }                                                                \
    }                                                                    \
    (res) = LM_GGML_F32x4_REDUCE_ONE((x)[0]);                            \
} while (0)
1139
1140
 
1140
1141
  #define LM_GGML_F32_VEC LM_GGML_F32x4
@@ -1161,30 +1162,30 @@ lm_ggml_type_traits_t lm_ggml_internal_get_type_traits(enum lm_ggml_type type) {
1161
1162
  #define LM_GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
1162
1163
  #define LM_GGML_F16x8_ADD vaddq_f16
1163
1164
  #define LM_GGML_F16x8_MUL vmulq_f16
1164
// Horizontally reduces the LM_GGML_F16_ARR float16x8 accumulators in `x` into
// the scalar `res`. The reduction runs in steps:
//   1. Three pairwise folding passes, x[k] += x[k + half], with half equal to
//      LM_GGML_F16_ARR/2, /4 and /8 in turn.
//   2. The low and high halves of x[0] are widened to float32.
//   3. Those halves are summed across lanes.
// `x` is modified in place.
#define LM_GGML_F16x8_REDUCE(res, x)                                     \
do {                                                                     \
    int half_ = LM_GGML_F16_ARR >> 1;                                    \
    for (int pass_ = 0; pass_ < 3; ++pass_, half_ >>= 1) {               \
        for (int k_ = 0; k_ < half_; ++k_) {                             \
            (x)[k_] = vaddq_f16((x)[k_], (x)[half_ + k_]);               \
        }                                                                \
    }                                                                    \
    const float32x4_t lo_ = vcvt_f32_f16(vget_low_f16 ((x)[0]));         \
    const float32x4_t hi_ = vcvt_f32_f16(vget_high_f16((x)[0]));         \
    (res) = (lm_ggml_float) vaddvq_f32(vaddq_f32(lo_, hi_));             \
} while (0)
1182
1183
 
1183
1184
  #define LM_GGML_F16_VEC LM_GGML_F16x8
1184
1185
  #define LM_GGML_F16_VEC_ZERO LM_GGML_F16x8_ZERO
1185
1186
  #define LM_GGML_F16_VEC_SET1 LM_GGML_F16x8_SET1
1186
1187
  #define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F16x8_LOAD(p)
1187
- #define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F16x8_STORE((lm_ggml_fp16_internal_t *)(p), r[i])
1188
+ #define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F16x8_STORE((lm_ggml_fp16_internal_t *)(p), (r)[i])
1188
1189
  #define LM_GGML_F16_VEC_FMA LM_GGML_F16x8_FMA
1189
1190
  #define LM_GGML_F16_VEC_ADD LM_GGML_F16x8_ADD
1190
1191
  #define LM_GGML_F16_VEC_MUL LM_GGML_F16x8_MUL
@@ -1893,6 +1894,23 @@ static inline void __lsx_f16x4_store(lm_ggml_fp16_t * x, __m128 y) {
1893
1894
  #define LM_GGML_F16_ARR (LM_GGML_F16_STEP/LM_GGML_F16_EPR)
1894
1895
  #endif
1895
1896
 
1897
+ //
1898
+ // ggml object
1899
+ //
1900
+
1901
+ struct lm_ggml_object {
1902
+ size_t offs;
1903
+ size_t size;
1904
+
1905
+ struct lm_ggml_object * next;
1906
+
1907
+ enum lm_ggml_object_type type;
1908
+
1909
+ char padding[4];
1910
+ };
1911
+
1912
+ static const size_t LM_GGML_OBJECT_SIZE = sizeof(struct lm_ggml_object);
1913
+
1896
1914
  //
1897
1915
  // ggml context
1898
1916
  //
@@ -3381,7 +3399,7 @@ double lm_ggml_type_sizef(enum lm_ggml_type type) {
3381
3399
  }
3382
3400
 
3383
3401
  LM_GGML_CALL const char * lm_ggml_type_name(enum lm_ggml_type type) {
3384
- return type_traits[type].type_name;
3402
+ return type < LM_GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE";
3385
3403
  }
3386
3404
 
3387
3405
  LM_GGML_CALL bool lm_ggml_is_quantized(enum lm_ggml_type type) {
@@ -3847,7 +3865,7 @@ static struct lm_ggml_object * lm_ggml_new_object(struct lm_ggml_context * ctx,
3847
3865
 
3848
3866
  if (cur_end + size_needed + LM_GGML_OBJECT_SIZE > ctx->mem_size) {
3849
3867
  LM_GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
3850
- __func__, cur_end + size_needed, ctx->mem_size);
3868
+ __func__, cur_end + size_needed + LM_GGML_OBJECT_SIZE, ctx->mem_size);
3851
3869
  assert(false);
3852
3870
  return NULL;
3853
3871
  }
@@ -19161,6 +19179,34 @@ void lm_ggml_graph_clear(struct lm_ggml_cgraph * cgraph) {
19161
19179
  lm_ggml_hash_set_reset(&cgraph->visited_hash_set);
19162
19180
  }
19163
19181
 
19182
+ int lm_ggml_graph_size(struct lm_ggml_cgraph * cgraph) {
19183
+ return cgraph->size;
19184
+ }
19185
+
19186
+ struct lm_ggml_tensor * lm_ggml_graph_node(struct lm_ggml_cgraph * cgraph, int i) {
19187
+ if (i < 0) {
19188
+ LM_GGML_ASSERT(cgraph->n_nodes + i >= 0);
19189
+ return cgraph->nodes[cgraph->n_nodes + i];
19190
+ }
19191
+
19192
+ LM_GGML_ASSERT(i < cgraph->n_nodes);
19193
+ return cgraph->nodes[i];
19194
+ }
19195
+
19196
+ struct lm_ggml_tensor ** lm_ggml_graph_nodes(struct lm_ggml_cgraph * cgraph) {
19197
+ return cgraph->nodes;
19198
+ }
19199
+
19200
+ int lm_ggml_graph_n_nodes(struct lm_ggml_cgraph * cgraph) {
19201
+ return cgraph->n_nodes;
19202
+ }
19203
+
19204
+ void lm_ggml_graph_add_node(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * tensor) {
19205
+ LM_GGML_ASSERT(cgraph->size > cgraph->n_nodes);
19206
+ cgraph->nodes[cgraph->n_nodes] = tensor;
19207
+ cgraph->n_nodes++;
19208
+ }
19209
+
19164
19210
  // Android's libc implementation "bionic" does not support setting affinity
19165
19211
  #if defined(__gnu_linux__)
19166
19212
  static void set_numa_thread_affinity(int thread_n) {
@@ -23242,6 +23288,14 @@ int lm_ggml_cpu_has_arm_fma(void) {
23242
23288
  #endif
23243
23289
  }
23244
23290
 
23291
// Reports whether this build targets the RISC-V vector extension.
// This is a compile-time answer: it returns 1 only if __riscv_v_intrinsic was
// defined at build time, and 0 otherwise.
int lm_ggml_cpu_has_riscv_v(void) {
    int has_rvv = 0;
#if defined(__riscv_v_intrinsic)
    has_rvv = 1;
#endif
    return has_rvv;
}
23298
+
23245
23299
  int lm_ggml_cpu_has_metal(void) {
23246
23300
  #if defined(LM_GGML_USE_METAL)
23247
23301
  return 1;