whisper.rn 0.3.0-rc.2 → 0.3.0-rc.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/cpp/ggml.c CHANGED
@@ -512,7 +512,7 @@ static inline int hsum_i32_4(const __m128i a) {
  return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
  }

- #if __AVX2__ || __AVX512F__
+ #if defined(__AVX2__) || defined(__AVX512F__)
  // spread 32 bits to 32 bytes { 0x00, 0xFF }
  static inline __m256i bytes_from_bits_32(const uint8_t * x) {
  uint32_t x32;
@@ -543,12 +543,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
  return _mm256_cvtepi32_ps(summed_pairs);
  }

- // multiply int8_t, add results pairwise twice and return as float vector
- static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
- // Get absolute values of x vectors
- const __m256i ax = _mm256_sign_epi8(x, x);
- // Sign the values of the y vectors
- const __m256i sy = _mm256_sign_epi8(y, x);
+ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
  #if __AVXVNNI__
  const __m256i zero = _mm256_setzero_si256();
  const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
@@ -560,6 +555,21 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
  #endif
  }

+ // multiply int8_t, add results pairwise twice and return as float vector
+ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+ #if __AVXVNNIINT8__
+ const __m256i zero = _mm256_setzero_si256();
+ const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
+ return _mm256_cvtepi32_ps(summed_pairs);
+ #else
+ // Get absolute values of x vectors
+ const __m256i ax = _mm256_sign_epi8(x, x);
+ // Sign the values of the y vectors
+ const __m256i sy = _mm256_sign_epi8(y, x);
+ return mul_sum_us8_pairs_float(ax, sy);
+ #endif
+ }
+
  static inline __m128i packNibbles( __m256i bytes )
  {
  // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -619,6 +629,17 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
  return _mm256_cvtepi32_ps(summed_pairs);
  }

+ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+ const __m128i axl = _mm256_castsi256_si128(ax);
+ const __m128i axh = _mm256_extractf128_si256(ax, 1);
+ const __m128i syl = _mm256_castsi256_si128(sy);
+ const __m128i syh = _mm256_extractf128_si256(sy, 1);
+ // Perform multiplication and create 16-bit values
+ const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+ const __m128i doth = _mm_maddubs_epi16(axh, syh);
+ return sum_i16_pairs_float(doth, dotl);
+ }
+
  // multiply int8_t, add results pairwise twice and return as float vector
  static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
  const __m128i xl = _mm256_castsi256_si128(x);
@@ -667,7 +688,7 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
  #endif // __AVX__ || __AVX2__ || __AVX512F__
  #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)

- #if __ARM_NEON
+ #if defined(__ARM_NEON)

  #if !defined(__aarch64__)

@@ -719,19 +740,19 @@ inline static float vaddvq_f32(float32x4_t v) {
  return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
  }

- float vminvq_f32(float32x4_t v) {
+ inline static float vminvq_f32(float32x4_t v) {
  return
  MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
  MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
  }

- float vmaxvq_f32(float32x4_t v) {
+ inline static float vmaxvq_f32(float32x4_t v) {
  return
  MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
  MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
  }

- int32x4_t vcvtnq_s32_f32(float32x4_t v) {
+ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
  int32x4_t res;

  res[0] = roundf(vgetq_lane_f32(v, 0));
@@ -745,21 +766,20 @@ int32x4_t vcvtnq_s32_f32(float32x4_t v) {
  #endif
  #endif

-
  #define QK4_0 32
  typedef struct {
- float d; // delta
+ ggml_fp16_t d; // delta
  uint8_t qs[QK4_0 / 2]; // nibbles / quants
  } block_q4_0;
- static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
+ static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");

  #define QK4_1 32
  typedef struct {
- float d; // delta
- float m; // min
+ ggml_fp16_t d; // delta
+ ggml_fp16_t m; // min
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
  } block_q4_1;
- static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding");
+ static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");

  #define QK5_0 32
  typedef struct {
@@ -780,16 +800,16 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) +

  #define QK8_0 32
  typedef struct {
- float d; // delta
- int8_t qs[QK8_0]; // quants
+ ggml_fp16_t d; // delta
+ int8_t qs[QK8_0]; // quants
  } block_q8_0;
- static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");

  #define QK8_1 32
  typedef struct {
- float d; // delta
- float s; // d * sum(qs[i])
- int8_t qs[QK8_1]; // quants
+ float d; // delta
+ float s; // d * sum(qs[i])
+ int8_t qs[QK8_1]; // quants
  } block_q8_1;
  static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");

@@ -816,7 +836,7 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
  const float d = max / -8;
  const float id = d ? 1.0f/d : 0.0f;

- y[i].d = d;
+ y[i].d = GGML_FP32_TO_FP16(d);

  for (int j = 0; j < qk/2; ++j) {
  const float x0 = x[i*qk + 0 + j]*id;
@@ -856,8 +876,8 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r
  const float d = (max - min) / ((1 << 4) - 1);
  const float id = d ? 1.0f/d : 0.0f;

- y[i].d = d;
- y[i].m = min;
+ y[i].d = GGML_FP32_TO_FP16(d);
+ y[i].m = GGML_FP32_TO_FP16(min);

  for (int j = 0; j < qk/2; ++j) {
  const float x0 = (x[i*qk + 0 + j] - min)*id;
@@ -988,7 +1008,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
  const float d = amax / ((1 << 7) - 1);
  const float id = d ? 1.0f/d : 0.0f;

- y[i].d = d;
+ y[i].d = GGML_FP32_TO_FP16(d);

  for (int j = 0; j < QK8_0; ++j) {
  const float x0 = x[i*QK8_0 + j]*id;
@@ -1023,7 +1043,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
  const float d = amax / ((1 << 7) - 1);
  const float id = d ? 1.0f/d : 0.0f;

- y[i].d = d;
+ y[i].d = GGML_FP32_TO_FP16(d);

  for (int j = 0; j < 8; j++) {
  const float32x4_t v = vmulq_n_f32(srcv[j], id);
@@ -1035,6 +1055,39 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
  y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
  }
  }
+ #elif defined(__wasm_simd128__)
+ for (int i = 0; i < nb; i++) {
+ v128_t srcv [8];
+ v128_t asrcv[8];
+ v128_t amaxv[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
+ wasm_f32x4_extract_lane(amaxv[0], 1)),
+ MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
+ wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ for (int j = 0; j < 8; j++) {
+ const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
+ const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
+
+ y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
+ y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
+ y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
+ y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
+ }
+ }
  #elif defined(__AVX2__) || defined(__AVX__)
  for (int i = 0; i < nb; i++) {
  // Load elements into 4 AVX vectors
@@ -1058,7 +1111,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int

  // Quantize these floats
  const float d = maxScalar / 127.f;
- y[i].d = d;
+ y[i].d = GGML_FP32_TO_FP16(d);
  const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
  const __m256 mul = _mm256_set1_ps( id );

@@ -1157,7 +1210,7 @@ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * r
  sum += y[i].qs[QK8_1/2 + j];
  }

- y[i].s = d * sum;
+ y[i].s = sum*d;
  }
  }

@@ -1203,6 +1256,48 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int

  y[i].s = d * vaddvq_s32(accv);
  }
+ #elif defined(__wasm_simd128__)
+ for (int i = 0; i < nb; i++) {
+ v128_t srcv [8];
+ v128_t asrcv[8];
+ v128_t amaxv[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
+ wasm_f32x4_extract_lane(amaxv[0], 1)),
+ MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
+ wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = d;
+
+ v128_t accv = wasm_i32x4_splat(0);
+
+ for (int j = 0; j < 8; j++) {
+ const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
+ const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
+
+ y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
+ y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
+ y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
+ y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
+
+ accv = wasm_i32x4_add(accv, vi);
+ }
+
+ y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) +
+ wasm_i32x4_extract_lane(accv, 1) +
+ wasm_i32x4_extract_lane(accv, 2) +
+ wasm_i32x4_extract_lane(accv, 3));
+ }
  #elif defined(__AVX2__) || defined(__AVX__)
  for (int i = 0; i < nb; i++) {
  // Load elements into 4 AVX vectors
@@ -1309,7 +1404,7 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict
  const int nb = k / qk;

  for (int i = 0; i < nb; i++) {
- const float d = x[i].d;
+ const float d = GGML_FP16_TO_FP32(x[i].d);

  for (int j = 0; j < qk/2; ++j) {
  const int x0 = (x[i].qs[j] & 0x0F) - 8;
@@ -1329,8 +1424,8 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict
  const int nb = k / qk;

  for (int i = 0; i < nb; i++) {
- const float d = x[i].d;
- const float m = x[i].m;
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+ const float m = GGML_FP16_TO_FP32(x[i].m);

  for (int j = 0; j < qk/2; ++j) {
  const int x0 = (x[i].qs[j] & 0x0F);
@@ -1405,7 +1500,7 @@ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, in
  const block_q8_0 * restrict x = vx;

  for (int i = 0; i < nb; i++) {
- const float d = x[i].d;
+ const float d = GGML_FP16_TO_FP32(x[i].d);

  for (int j = 0; j < qk; ++j) {
  y[i*qk + j] = x[i].qs[j]*d;
@@ -1669,8 +1764,9 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
  static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
  float tmp[8];

- for (int i = 0; i < 8; i++)
+ for (int i = 0; i < 8; i++) {
  tmp[i] = GGML_FP16_TO_FP32(x[i]);
+ }

  return _mm256_loadu_ps(tmp);
  }
@@ -2090,8 +2186,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
  const block_q8_0 * restrict y0 = &y[i + 0];
  const block_q8_0 * restrict y1 = &y[i + 1];

- const uint8x16_t m4b = vdupq_n_u8(0x0F);
- const int8x16_t s8b = vdupq_n_s8(0x8);
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
+ const int8x16_t s8b = vdupq_n_s8(0x8);

  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
  const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2119,8 +2215,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
  const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
  const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);

- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
  #else
  const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
  const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
@@ -2137,8 +2233,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
  const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
  const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));

- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
  #endif
  }

@@ -2150,7 +2246,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
  // Main loop
  for (int i = 0; i < nb; ++i) {
  /* Compute combined scale for the block */
- const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+ const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );

  __m256i bx = bytes_from_nibbles_32(x[i].qs);

@@ -2174,7 +2270,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
  // Main loop
  for (int i = 0; i < nb; ++i) {
  // Compute combined scale for the block
- const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+ const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );

  const __m128i lowMask = _mm_set1_epi8(0xF);
  const __m128i off = _mm_set1_epi8(8);
@@ -2216,7 +2312,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
  _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);

  // Compute combined scale for the block 0 and 1
- const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[0].d ), _mm_set1_ps( y[0].d ) );
+ const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) );

  const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);

@@ -2234,7 +2330,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
  _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);

  // Compute combined scale for the block 2 and 3
- const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[1].d ), _mm_set1_ps( y[1].d ) );
+ const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );

  const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);

@@ -2267,7 +2363,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
  _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);

  // Compute combined scale for the block 0 and 1
- const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[i].d ), _mm_set1_ps( y[i].d ) );
+ const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );

  const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);

@@ -2285,7 +2381,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
  _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);

  // Compute combined scale for the block 2 and 3
- const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[i + 1].d ), _mm_set1_ps( y[i + 1].d ) );
+ const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );

  const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);

@@ -2333,7 +2429,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
  sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
  }

- sumf += (x[i].d*y[i].d)*sumi;
+ sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
  }

  *s = sumf;
@@ -2363,7 +2459,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
  const block_q8_1 * restrict y0 = &y[i + 0];
  const block_q8_1 * restrict y1 = &y[i + 1];

- summs += x0->m * y0->s + x1->m * y1->s;
+ summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s;

  const uint8x16_t m4b = vdupq_n_u8(0x0F);

@@ -2387,8 +2483,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
  const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
  const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);

- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
  #else
  const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
  const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
@@ -2405,8 +2501,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
  const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
  const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));

- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
  #endif
  }

@@ -2419,13 +2515,13 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *

  // Main loop
  for (int i = 0; i < nb; ++i) {
- const float * d0 = &x[i].d;
- const float * d1 = &y[i].d;
+ const float d0 = GGML_FP16_TO_FP32(x[i].d);
+ const float d1 = y[i].d;

- summs += x[i].m * y[i].s;
+ summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;

- const __m256 d0v = _mm256_broadcast_ss( d0 );
- const __m256 d1v = _mm256_broadcast_ss( d1 );
+ const __m256 d0v = _mm256_set1_ps( d0 );
+ const __m256 d1v = _mm256_set1_ps( d1 );

  // Compute combined scales
  const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
@@ -2434,7 +2530,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
  const __m256i bx = bytes_from_nibbles_32(x[i].qs);
  const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );

- const __m256 xy = mul_sum_i8_pairs_float(bx, by);
+ const __m256 xy = mul_sum_us8_pairs_float(bx, by);

  // Accumulate d0*d1*x*y
  #if defined(__AVX2__)
@@ -2459,7 +2555,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
  sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
  }

- sumf += (x[i].d*y[i].d)*sumi + x[i].m*y[i].s;
+ sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
  }

  *s = sumf;
@@ -2535,16 +2631,13 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
  const int8x16_t v1_1l = vld1q_s8(y1->qs);
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

- const float x0d = GGML_FP16_TO_FP32(x0->d);
- const float x1d = GGML_FP16_TO_FP32(x1->d);
-
  #if defined(__ARM_FEATURE_DOTPROD)
  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
  vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
- vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), x0d*y0->d);
+ vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
  vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
- vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), x1d*y1->d);
+ vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
  #else
  const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
  const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
@@ -2561,8 +2654,8 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
  const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
  const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));

- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1d*y1->d);
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
  #endif
  }

@@ -2579,7 +2672,6 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
  const block_q8_0 * restrict y0 = &y[i];

  const v128_t m4b = wasm_i8x16_splat(0x0F);
- const v128_t s16b = wasm_i8x16_splat(0x10);

  // extract the 5th bit
  memcpy(&qh, x0->qh, sizeof(qh));
@@ -2617,15 +2709,14 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
  const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
  const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);

- const float x0d = GGML_FP16_TO_FP32(x0->d);
-
  // dot product
  sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
  wasm_i32x4_add(
  wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
  wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
  wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
- wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(x0d*y0->d)));
+ wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+ wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
  }

  *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
@@ -2637,7 +2728,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
  // Main loop
  for (int i = 0; i < nb; i++) {
  /* Compute combined scale for the block */
- const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
+ const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));

  __m256i bx = bytes_from_nibbles_32(x[i].qs);
  __m256i bxhi = bytes_from_bits_32(x[i].qh);
@@ -2661,7 +2752,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
  // Main loop
  for (int i = 0; i < nb; i++) {
  /* Compute combined scale for the block */
- const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
+ const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));

  __m256i bx = bytes_from_nibbles_32(x[i].qs);
  const __m256i bxhi = bytes_from_bits_32(x[i].qh);
@@ -2704,7 +2795,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
  sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
  }

- sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi;
+ sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
  }

  *s = sumf;
@@ -2786,16 +2877,13 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
  const int8x16_t v1_1l = vld1q_s8(y1->qs);
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

- const float x0d = GGML_FP16_TO_FP32(x0->d);
- const float x1d = GGML_FP16_TO_FP32(x1->d);
-
  #if defined(__ARM_FEATURE_DOTPROD)
  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
  vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
- vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), x0d*y0->d);
+ vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
  vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
- vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), x1d*y1->d);
+ vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
  #else
  const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
  const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
@@ -2812,8 +2900,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
  const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
  const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));

- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1d*y1->d);
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
  #endif
  }

@@ -2852,8 +2940,6 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
  const v128_t v0l = wasm_v128_and (v0, m4b);
  const v128_t v0h = wasm_u8x16_shr(v0, 4);

- static bool x = true;
-
  // add high bit
  const v128_t v0lf = wasm_v128_or(v0l, qhl);
  const v128_t v0hf = wasm_v128_or(v0h, qhh);
@@ -2873,15 +2959,14 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
  const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
  const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);

- const float x0d = GGML_FP16_TO_FP32(x0->d);
-
  // dot product
- sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
- wasm_i32x4_add(
+ sumv = wasm_f32x4_add(sumv,
+ wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
  wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
  wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
  wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
- wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(x0d*y0->d)));
+ wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+ wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d)));
  }

  *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
@@ -2903,10 +2988,10 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
  bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
  bx = _mm256_or_si256(bx, bxhi);

- const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+ const __m256 dy = _mm256_set1_ps(y[i].d);
  const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);

- const __m256 q = mul_sum_i8_pairs_float(bx, by);
+ const __m256 q = mul_sum_us8_pairs_float(bx, by);

  acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
  }
@@ -2937,10 +3022,10 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
  bxh = _mm_or_si128(bxh, bxhih);
  bx = _mm256_set_m128i(bxh, bxl);

- const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+ const __m256 dy = _mm256_set1_ps(y[i].d);
  const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);

- const __m256 q = mul_sum_i8_pairs_float(bx, by);
+ const __m256 q = mul_sum_us8_pairs_float(bx, by);

  acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
  }
@@ -3007,11 +3092,11 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
  #if defined(__ARM_FEATURE_DOTPROD)
  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
  vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
- vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d);
+ vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));

  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
  vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
- vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d);
+ vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));

  #else
  const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
@@ -3029,8 +3114,8 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
  const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
  const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));

- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d);
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d);
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
  #endif
  }

@@ -3042,7 +3127,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
  // Main loop
  for (int i = 0; i < nb; ++i) {
  // Compute combined scale for the block
- const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+ const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
  __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
  __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);

@@ -3068,7 +3153,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
  sumi += x[i].qs[j]*y[i].qs[j];
  }

- sumf += (x[i].d*y[i].d)*sumi;
+ sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
  }

  *s = sumf;
@@ -3457,6 +3542,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
  "ROPE",
  "ROPE_BACK",
  "ALIBI",
+ "CLAMP",
  "CONV_1D_1S",
  "CONV_1D_2S",

@@ -3467,7 +3553,8 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
  "MAP_BINARY",
  };

- static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
+ static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
+

  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "none",
@@ -3517,6 +3604,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "rope(x)",
  "rope_back(x)",
  "alibi(x)",
+ "clamp(x)",
  "conv_1d_1s(x)",
  "conv_1d_2s(x)",

@@ -3527,7 +3615,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "f(x,y)",
  };

- static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
+ static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");

  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3761,6 +3849,12 @@ static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct g
  (t1->ne[3]%t0->ne[3] == 0);
  }

+ static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1);
+ }
+
  static inline int ggml_up32(int n) {
  return (n + 31) & ~31;
  }
@@ -4643,11 +4737,15 @@ struct ggml_tensor * ggml_mul_impl(
  struct ggml_tensor * a,
  struct ggml_tensor * b,
  bool inplace) {
- GGML_ASSERT(ggml_are_same_shape(a, b));
+ // TODO: support less-strict constraint
+ // GGML_ASSERT(ggml_can_repeat(b, a));
+ GGML_ASSERT(ggml_can_repeat_rows(b, a));

  bool is_node = false;

  if (!inplace && (a->grad || b->grad)) {
+ // TODO: support backward pass for broadcasting
+ GGML_ASSERT(ggml_are_same_shape(a, b));
  is_node = true;
  }

@@ -6189,7 +6287,8 @@ struct ggml_tensor * ggml_alibi(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
  int n_past,
- int n_head) {
+ int n_head,
+ float bias_max) {
  GGML_ASSERT(n_past >= 0);
  bool is_node = false;

@@ -6208,6 +6307,8 @@ struct ggml_tensor * ggml_alibi(

  ((int32_t *) b->data)[0] = n_past;
  ((int32_t *) b->data)[1] = n_head;
+ GGML_ASSERT(sizeof(float) == sizeof(int32_t));
+ (((float *) b->data)[2]) = bias_max;

  ggml_scratch_load(ctx);

@@ -6219,6 +6320,40 @@ struct ggml_tensor * ggml_alibi(
  return result;
  }

+ // ggml_clamp
+
+ struct ggml_tensor * ggml_clamp(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float min,
+ float max) {
+ bool is_node = false;
+
+ if (a->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ // TODO: when implement backward, fix this:
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+ ggml_scratch_save(ctx);
+
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+
+ ((float *) b->data)[0] = min;
+ ((float *) b->data)[1] = max;
+
+ ggml_scratch_load(ctx);
+
+ result->op = GGML_OP_CLAMP;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+ }
+
  // ggml_conv_1d_1s

  struct ggml_tensor * ggml_conv_1d_1s(
@@ -7945,7 +8080,7 @@ static void ggml_compute_forward_mul_f32(
  const struct ggml_tensor * src0,
  const struct ggml_tensor * src1,
  struct ggml_tensor * dst) {
- assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));

  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
  return;
@@ -7953,10 +8088,25 @@ static void ggml_compute_forward_mul_f32(
  const int ith = params->ith;
  const int nth = params->nth;

- const int nr = ggml_nrows(src0);
- const int64_t ne0 = src0->ne[0];
- const int64_t ne1 = src0->ne[1];
- const int64_t ne2 = src0->ne[2];
+ #ifdef GGML_USE_CUBLAS
+ if (src1->backend == GGML_BACKEND_CUDA) {
+ if (ith == 0) {
+ ggml_cuda_mul(src0, src1, dst);
+ }
+ return;
+ }
+ #endif
+
+ const int64_t nr = ggml_nrows(src0);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t ne02 = src0->ne[2];
+
+ const int64_t ne10 = src1->ne[0];
+ const int64_t ne11 = src1->ne[1];
+ const int64_t ne12 = src1->ne[2];
+ const int64_t ne13 = src1->ne[3];

  const size_t nb00 = src0->nb[0];
  const size_t nb01 = src0->nb[1];
@@ -7975,44 +8125,51 @@ static void ggml_compute_forward_mul_f32(

  GGML_ASSERT( nb0 == sizeof(float));
  GGML_ASSERT(nb00 == sizeof(float));
+ GGML_ASSERT(ne00 == ne10);

  if (nb10 == sizeof(float)) {
- for (int ir = ith; ir < nr; ir += nth) {
- // src0, src1 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+ for (int64_t ir = ith; ir < nr; ir += nth) {
+ // src0 and dst are same shape => same indices
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);

  #ifdef GGML_USE_ACCELERATE
  UNUSED(ggml_vec_mul_f32);

- vDSP_vmul(
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
- ne0);
+ vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
  #else
- ggml_vec_mul_f32(ne0,
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+ ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
  #endif
  // }
  // }
  }
  } else {
  // src1 is not contiguous
- for (int ir = ith; ir < nr; ir += nth) {
- // src0, src1 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+ for (int64_t ir = ith; ir < nr; ir += nth) {
+ // src0 and dst are same shape => same indices
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

- float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
- float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
- for (int i0 = 0; i0 < ne0; i0++) {
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+ for (int64_t i0 = 0; i0 < ne00; i0++) {
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);

  dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
  }
@@ -10501,34 +10658,29 @@ static void ggml_compute_forward_diag_mask_f32(
  assert(src1->type == GGML_TYPE_I32);
  assert(ggml_nelements(src1) == 2);

+ const int ith = params->ith;
+ const int nth = params->nth;
+
  const int n_past = ((int32_t *) src1->data)[0];
  const bool inplace = (bool)((int32_t *) src1->data)[1];

- if (params->type == GGML_TASK_INIT) {
- // TODO: this hack is not good, need a better way to handle this
- if (!inplace) {
- // use the init task to copy src -> dst
- struct ggml_compute_params params_cpy = *params;
-
- params_cpy.ith = 0;
- params_cpy.nth = 1;
- params_cpy.type = GGML_TASK_COMPUTE;
-
- ggml_compute_forward_dup_same_cont(&params_cpy, src0, dst);
- }
+ assert(n_past >= 0);

- return;
+ if (!inplace && (params->type == GGML_TASK_INIT)) {
+ // memcpy needs to be synchronized across threads to avoid race conditions.
+ // => do it in INIT phase
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+ memcpy(
+ ((char *) dst->data),
+ ((char *) src0->data),
+ ggml_nbytes(dst));
  }

- if (params->type == GGML_TASK_FINALIZE) {
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
  return;
  }

- const int ith = params->ith;
- const int nth = params->nth;
-
- assert(n_past >= 0);
-
  // TODO: handle transposed/permuted matrices

  const int n = ggml_nrows(src0);
@@ -10682,14 +10834,15 @@ static void ggml_compute_forward_alibi_f32(
  struct ggml_tensor * dst) {
  assert(params->ith == 0);
  assert(src1->type == GGML_TYPE_I32);
- assert(ggml_nelements(src1) == 2);
+ assert(ggml_nelements(src1) == 3);

  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
  return;
  }

- const int n_past = ((int32_t *) src1->data)[0];
- const int n_head = ((int32_t *) src1->data)[1];
+ const int n_past = ((int32_t *) src1->data)[0];
+ const int n_head = ((int32_t *) src1->data)[1];
+ const float max_bias = ((float *) src1->data)[2];

  assert(n_past >= 0);

@@ -10712,8 +10865,8 @@ static void ggml_compute_forward_alibi_f32(
  // add alibi to src0 (KQ_scaled)
  const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));

- const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
- const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
+ const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

  for (int i = 0; i < ne0; i++) {
  for (int j = 0; j < ne1; j++) {
@@ -10731,13 +10884,13 @@ static void ggml_compute_forward_alibi_f32(
  m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
  }

- pdst[0] = i * m_k + src[0];
+ pdst[0] = (i-ne0+1) * m_k + src[0];
+
  }
  }
  }
  }

-
  static void ggml_compute_forward_alibi_f16(
  const struct ggml_compute_params * params,
  const struct ggml_tensor * src0,
@@ -10745,14 +10898,15 @@ static void ggml_compute_forward_alibi_f16(
  struct ggml_tensor * dst) {
  assert(params->ith == 0);
  assert(src1->type == GGML_TYPE_I32);
- assert(ggml_nelements(src1) == 2);
+ assert(ggml_nelements(src1) == 3);

  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
  return;
  }

- const int n_past = ((int32_t *) src1->data)[0];
- const int n_head = ((int32_t *) src1->data)[1];
+ const int n_past = ((int32_t *) src1->data)[0];
+ const int n_head = ((int32_t *) src1->data)[1];
+ const float max_bias = ((float *) src1->data)[2];

  assert(n_past >= 0);

@@ -10775,8 +10929,8 @@ static void ggml_compute_forward_alibi_f16(
  // add alibi to src0 (KQ_scaled)
  const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));

- const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
- const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
+ const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

  for (int i = 0; i < ne0; i++) {
  for (int j = 0; j < ne1; j++) {
@@ -10795,7 +10949,7 @@ static void ggml_compute_forward_alibi_f16(
  }

  // we return F32
- pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
+ pdst[0] = (i-ne0+1) * m_k + GGML_FP16_TO_FP32(src[0]);
  }
  }
  }
@@ -10831,6 +10985,77 @@ static void ggml_compute_forward_alibi(
  }
  }

+
+ // ggml_compute_forward_clamp
+
+ static void ggml_compute_forward_clamp_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(src1->type == GGML_TYPE_I32);
+ assert(ggml_nelements(src1) == 2);
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int min = ((float *) src1->data)[0];
+ const int max = ((float *) src1->data)[1];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ const size_t nb00 = src0->nb[0];
+ const size_t nb01 = src0->nb[1];
+
+ const size_t nb0 = dst->nb[0];
+ const size_t nb1 = dst->nb[1];
+
+ GGML_ASSERT( nb0 == sizeof(float));
+ GGML_ASSERT(nb00 == sizeof(float));
+
+ for (int j = ith; j < n; j += nth) {
+ float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
+ float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
+
+ for (int i = 0; i < nc; i++) {
+ dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
+ }
+ }
+ }
+
+ static void ggml_compute_forward_clamp(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_clamp_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_F16:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q8_1:
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+ }
+
  // ggml_compute_forward_rope

  static void ggml_compute_forward_rope_f32(
@@ -12812,6 +13037,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
  {
  ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
  } break;
+ case GGML_OP_CLAMP:
+ {
+ ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor);
+ } break;
  case GGML_OP_CONV_1D_1S:
  {
  ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
@@ -13119,6 +13348,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  {
  GGML_ASSERT(false); // TODO: not implemented
  } break;
+ case GGML_OP_CLAMP:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
  case GGML_OP_SILU:
  {
  // necessary for llama
@@ -13998,6 +14231,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
  {
  node->n_tasks = 1; //TODO
  } break;
+ case GGML_OP_CLAMP:
+ {
+ node->n_tasks = 1; //TODO
+ } break;
  case GGML_OP_CONV_1D_1S:
  case GGML_OP_CONV_1D_2S:
  {
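For orientation (not part of the diff): the hunks above extend `ggml_alibi` with an explicit maximum-bias argument and introduce a new `GGML_OP_CLAMP` operator with the builder `ggml_clamp(ctx, a, min, max)`. A minimal sketch of how the two might be called follows; it assumes the public ggml API shipped with this package (`ggml_init`, `ggml_new_tensor_2d`, `ggml_free`), and the tensor shape, head count, and clamp bounds are purely illustrative.

```c
#include "ggml.h"

// Illustrative sketch only; shapes and values are hypothetical.
void clamp_alibi_example(void) {
    struct ggml_init_params params = {
        .mem_size   = 16*1024*1024, // scratch arena for this sketch
        .mem_buffer = NULL,
        .no_alloc   = false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // e.g. attention scores laid out as [n_kv, n_head]
    struct ggml_tensor * scores = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8);

    // new in this version: ggml_alibi takes the maximum bias explicitly,
    // and ggml_clamp bounds the result element-wise
    struct ggml_tensor * biased  = ggml_alibi(ctx, scores, /*n_past=*/0, /*n_head=*/8, /*bias_max=*/8.0f);
    struct ggml_tensor * clamped = ggml_clamp(ctx, biased, -30.0f, 30.0f);

    (void) clamped; // would normally feed into graph construction
    ggml_free(ctx);
}
```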