llama_cpp 0.3.8 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -77,6 +77,11 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
77
77
  }
78
78
  return 1/iscale;
79
79
  }
80
+ bool return_early = false;
81
+ if (rmse_type < 0) {
82
+ rmse_type = -rmse_type;
83
+ return_early = true;
84
+ }
80
85
  int weight_type = rmse_type%2;
81
86
  float sumlx = 0;
82
87
  float suml2 = 0;
@@ -89,56 +94,9 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
89
94
  suml2 += w*l*l;
90
95
  }
91
96
  float scale = sumlx/suml2;
97
+ if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
92
98
  float best = scale * sumlx;
93
- for (int itry = 0; itry < 3; ++itry) {
94
- iscale = 1/scale;
95
- float slx = 0;
96
- float sl2 = 0;
97
- bool changed = false;
98
- for (int i = 0; i < n; ++i) {
99
- int l = nearest_int(iscale * x[i]);
100
- l = MAX(-nmax, MIN(nmax-1, l));
101
- if (l + nmax != L[i]) { changed = true; }
102
- float w = weight_type == 1 ? x[i] * x[i] : 1.f;
103
- slx += w*x[i]*l;
104
- sl2 += w*l*l;
105
- }
106
- if (!changed || sl2 == 0 || slx*slx <= best*sl2) { break; }
107
- for (int i = 0; i < n; ++i) {
108
- int l = nearest_int(iscale * x[i]);
109
- L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
110
- }
111
- sumlx = slx; suml2 = sl2;
112
- scale = sumlx/suml2;
113
- best = scale * sumlx;
114
- }
115
- for (int itry = 0; itry < 5; ++itry) {
116
- int n_changed = 0;
117
- for (int i = 0; i < n; ++i) {
118
- float w = weight_type == 1 ? x[i]*x[i] : 1;
119
- int l = L[i] - nmax;
120
- float slx = sumlx - w*x[i]*l;
121
- if (slx > 0) {
122
- float sl2 = suml2 - w*l*l;
123
- int new_l = nearest_int(x[i] * sl2 / slx);
124
- new_l = MAX(-nmax, MIN(nmax-1, new_l));
125
- if (new_l != l) {
126
- slx += w*x[i]*new_l;
127
- sl2 += w*new_l*new_l;
128
- if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
129
- L[i] = nmax + new_l; sumlx = slx; suml2 = sl2;
130
- scale = sumlx / suml2; best = scale * sumlx;
131
- ++n_changed;
132
- }
133
- }
134
- }
135
- }
136
- if (!n_changed) { break; }
137
- }
138
- if (rmse_type < 3) {
139
- return scale;
140
- }
141
- for (int is = -4; is <= 4; ++is) {
99
+ for (int is = -9; is <= 9; ++is) {
142
100
  if (is == 0) {
143
101
  continue;
144
102
  }
@@ -221,12 +179,17 @@ static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t *
221
179
  return 1/iscale;
222
180
  }
223
181
 
224
- static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min, int ntry) {
182
+ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
183
+ int ntry, float alpha) {
225
184
  float min = x[0];
226
185
  float max = x[0];
186
+ float sum_x = 0;
187
+ float sum_x2 = 0;
227
188
  for (int i = 1; i < n; ++i) {
228
189
  if (x[i] < min) min = x[i];
229
190
  if (x[i] > max) max = x[i];
191
+ sum_x += x[i];
192
+ sum_x2 += x[i]*x[i];
230
193
  }
231
194
  if (max == min) {
232
195
  for (int i = 0; i < n; ++i) L[i] = 0;
@@ -254,7 +217,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
254
217
  for (int i = 0; i < n; ++i) {
255
218
  sum += x[i] - scale*L[i];
256
219
  }
257
- min = sum/n;
220
+ min = alpha*min + (1 - alpha)*sum/n;
258
221
  if (min > 0) min = 0;
259
222
  iscale = 1/scale;
260
223
  if (!did_change) break;
@@ -263,6 +226,82 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
263
226
  return scale;
264
227
  }
265
228
 
// make_qkx2_quants: quantize n floats x[0..n-1] to integer levels L[i] in
// [0, nmax] using an affine model  x[i] ~= scale * L[i] + min  (min <= 0).
//
//   n, nmax   number of values and the largest allowed level
//   x         input values; x[0] is read unconditionally, so n must be >= 1
//   weights   per-element weights applied to the error and the fit sums;
//             weights[0] is likewise read unconditionally
//   L         output levels (n entries)
//   the_min   receives -min, i.e. the negated offset of the chosen model
//   Laux      scratch buffer (n entries) holding the candidate levels
//   rmin, rdelta, nstep
//             candidate inverse scales are
//             (rmin + rdelta*is + nmax) / (max - min) for is = 0..nstep
//   use_mad   true: weighted absolute error; false: weighted squared error
//
// Returns the chosen scale (0.f when all values collapse to a single level).
229
+ static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
230
+ uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
231
+ float rmin, float rdelta, int nstep, bool use_mad) {
232
+ float min = x[0];
233
+ float max = x[0];
// Seed the weighted sums with element 0; the loop below starts at i = 1.
234
+ float sum_w = weights[0];
235
+ float sum_x = sum_w * x[0];
236
+ for (int i = 1; i < n; ++i) {
237
+ if (x[i] < min) min = x[i];
238
+ if (x[i] > max) max = x[i];
239
+ float w = weights[i];
240
+ sum_w += w;
241
+ sum_x += w * x[i];
242
+ }
// The offset is never allowed to be positive.
243
+ if (min > 0) min = 0;
// Degenerate range: emit all-zero levels and a zero scale.
244
+ if (max == min) {
245
+ for (int i = 0; i < n; ++i) L[i] = 0;
246
+ *the_min = -min;
247
+ return 0.f;
248
+ }
// Baseline: spread the levels evenly over [min, max] and record the
// weighted error of that choice as the value to beat.
249
+ float iscale = nmax/(max - min);
250
+ float scale = 1/iscale;
// Despite its name, best_mad holds a squared-error sum when use_mad is false.
251
+ float best_mad = 0;
252
+ for (int i = 0; i < n; ++i) {
253
+ int l = nearest_int(iscale*(x[i] - min));
254
+ L[i] = MAX(0, MIN(nmax, l));
255
+ float diff = scale * L[i] + min - x[i];
256
+ diff = use_mad ? fabsf(diff) : diff * diff;
257
+ float w = weights[i];
258
+ best_mad += w * diff;
259
+ }
// With no search steps requested, the baseline is the answer.
260
+ if (nstep < 1) {
261
+ *the_min = -min;
262
+ return scale;
263
+ }
// Search: try nstep+1 perturbed inverse scales. Note that `min` in the
// rounding below is the current best offset, which is updated whenever a
// candidate is accepted; `max` stays at its initial value.
264
+ for (int is = 0; is <= nstep; ++is) {
265
+ iscale = (rmin + rdelta*is + nmax)/(max - min);
266
+ float sum_l = 0, sum_l2 = 0, sum_xl = 0;
267
+ for (int i = 0; i < n; ++i) {
268
+ int l = nearest_int(iscale*(x[i] - min));
269
+ l = MAX(0, MIN(nmax, l));
270
+ Laux[i] = l;
271
+ float w = weights[i];
272
+ sum_l += w*l;
273
+ sum_l2 += w*l*l;
274
+ sum_xl += w*l*x[i];
275
+ }
// D is the determinant of the weighted least-squares normal equations for
// fitting x ~= this_scale*l + this_min; candidates with D <= 0 are skipped.
276
+ float D = sum_w * sum_l2 - sum_l * sum_l;
277
+ if (D > 0) {
278
+ float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
279
+ float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
// A positive offset is not allowed: pin it to 0 and refit the scale alone.
280
+ if (this_min > 0) {
281
+ this_min = 0;
282
+ this_scale = sum_xl / sum_l2;
283
+ }
// Weighted error of this candidate, using the same metric as the baseline.
284
+ float mad = 0;
285
+ for (int i = 0; i < n; ++i) {
286
+ float diff = this_scale * Laux[i] + this_min - x[i];
287
+ diff = use_mad ? fabsf(diff) : diff * diff;
288
+ float w = weights[i];
289
+ mad += w * diff;
290
+ }
// Accept the candidate only on a strict improvement.
291
+ if (mad < best_mad) {
292
+ for (int i = 0; i < n; ++i) {
293
+ L[i] = Laux[i];
294
+ }
295
+ best_mad = mad;
296
+ scale = this_scale;
297
+ min = this_min;
298
+ }
299
+ }
300
+ }
// Report the offset negated, as in the early-return paths above.
301
+ *the_min = -min;
302
+ return scale;
303
+ }
304
+
266
305
  #if QK_K == 256
267
306
  static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
268
307
  if (j < 4) {
@@ -281,6 +320,8 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
281
320
  const int nb = k / QK_K;
282
321
 
283
322
  uint8_t L[QK_K];
323
+ uint8_t Laux[16];
324
+ float weights[16];
284
325
  float mins[QK_K/16];
285
326
  float scales[QK_K/16];
286
327
 
@@ -291,7 +332,8 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
291
332
  float max_scale = 0; // as we are deducting the min, scales are always positive
292
333
  float max_min = 0;
293
334
  for (int j = 0; j < QK_K/16; ++j) {
294
- scales[j] = make_qkx1_quants(16, 3, x + 16*j, L + 16*j, &mins[j], 5);
335
+ for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);
336
+ scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);
295
337
  float scale = scales[j];
296
338
  if (scale > max_scale) {
297
339
  max_scale = scale;
@@ -637,6 +679,8 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
637
679
  const int nb = k / QK_K;
638
680
 
639
681
  uint8_t L[QK_K];
682
+ uint8_t Laux[32];
683
+ float weights[32];
640
684
  float mins[QK_K/32];
641
685
  float scales[QK_K/32];
642
686
 
@@ -645,7 +689,12 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
645
689
  float max_scale = 0; // as we are deducting the min, scales are always positive
646
690
  float max_min = 0;
647
691
  for (int j = 0; j < QK_K/32; ++j) {
648
- scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 5);
692
+ //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
693
+ float sum_x2 = 0;
694
+ for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
695
+ float av_x = sqrtf(sum_x2/32);
696
+ for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
697
+ scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
649
698
  float scale = scales[j];
650
699
  if (scale > max_scale) {
651
700
  max_scale = scale;
@@ -798,6 +847,8 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
798
847
  uint8_t L[QK_K];
799
848
  float mins[QK_K/32];
800
849
  float scales[QK_K/32];
850
+ float weights[32];
851
+ uint8_t Laux[32];
801
852
  #else
802
853
  int8_t L[QK_K];
803
854
  float scales[QK_K/16];
@@ -810,7 +861,12 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
810
861
  float max_scale = 0; // as we are deducting the min, scales are always positive
811
862
  float max_min = 0;
812
863
  for (int j = 0; j < QK_K/32; ++j) {
813
- scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 5);
864
+ //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
865
+ float sum_x2 = 0;
866
+ for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
867
+ float av_x = sqrtf(sum_x2/32);
868
+ for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
869
+ scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
814
870
  float scale = scales[j];
815
871
  if (scale > max_scale) {
816
872
  max_scale = scale;
@@ -2638,13 +2694,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
2638
2694
  const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2639
2695
  __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
2640
2696
  p16l = _mm256_madd_epi16(scale_l, p16l);
2641
- sumi = _mm256_add_epi32(sumi, p16l);
2642
2697
 
2643
2698
  const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2644
2699
  __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
2645
2700
  p16h = _mm256_madd_epi16(scale_h, p16h);
2646
- sumi = _mm256_add_epi32(sumi, p16h);
2701
+ const __m256i sumj = _mm256_add_epi32(p16l, p16h);
2647
2702
 
2703
+ sumi = _mm256_add_epi32(sumi, sumj);
2648
2704
  }
2649
2705
 
2650
2706
  __m256 vd = _mm256_set1_ps(d);