whisper.rn 0.4.0-rc.10 → 0.4.0-rc.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/cpp/ggml-cpu.c CHANGED
@@ -3,11 +3,10 @@
3
3
 
4
4
  #include "ggml-backend-impl.h"
5
5
  #include "ggml-backend.h"
6
- #include "ggml-cpu-aarch64.h"
6
+ #include "ggml-cpu-traits.h"
7
7
  #include "ggml-cpu-impl.h"
8
8
  #include "ggml-cpu.h"
9
9
  #include "ggml-impl.h"
10
- #include "ggml-quants.h"
11
10
  #include "ggml-cpu-quants.h"
12
11
  #include "ggml-threading.h"
13
12
  #include "ggml.h"
@@ -109,10 +108,12 @@ static wsp_ggml_fp16_t wsp_ggml_table_gelu_quick_f16[1 << 16];
109
108
  #if defined(__ARM_ARCH)
110
109
  struct wsp_ggml_arm_arch_features_type {
111
110
  int has_neon;
111
+ int has_dotprod;
112
112
  int has_i8mm;
113
113
  int has_sve;
114
114
  int sve_cnt;
115
- } wsp_ggml_arm_arch_features = {-1, -1, -1, 0};
115
+ int has_sme;
116
+ } wsp_ggml_arm_arch_features = {-1, -1, -1, -1, 0, -1};
116
117
  #endif
117
118
 
118
119
 
@@ -124,8 +125,7 @@ struct wsp_ggml_arm_arch_features_type {
124
125
  #endif
125
126
  #include <windows.h>
126
127
 
127
-
128
- #if !defined(__clang__)
128
+ #if defined(_MSC_VER) && !defined(__clang__)
129
129
  #define WSP_GGML_CACHE_ALIGN __declspec(align(WSP_GGML_CACHE_LINE))
130
130
 
131
131
  typedef volatile LONG atomic_int;
@@ -222,10 +222,6 @@ typedef void * thread_ret_t;
222
222
 
223
223
  typedef pthread_t wsp_ggml_thread_t;
224
224
 
225
- #ifdef WSP_GGML_USE_CPU_HBM
226
- #include <hbwmalloc.h>
227
- #endif
228
-
229
225
  #if defined(__APPLE__)
230
226
  #include <unistd.h>
231
227
  #include <mach/mach.h>
@@ -241,6 +237,8 @@ typedef pthread_t wsp_ggml_thread_t;
241
237
  #else
242
238
  #if defined(__POWER9_VECTOR__)
243
239
  #define CACHE_LINE_SIZE 128
240
+ #elif defined(__VXE__) || defined(__VXE2__)
241
+ #define CACHE_LINE_SIZE 256
244
242
  #else
245
243
  #define CACHE_LINE_SIZE 64
246
244
  #endif
@@ -299,7 +297,6 @@ static const struct wsp_ggml_type_traits_cpu type_traits_cpu[WSP_GGML_TYPE_COUNT
299
297
  },
300
298
  [WSP_GGML_TYPE_Q8_0] = {
301
299
  .from_float = wsp_quantize_row_q8_0,
302
- .from_float_to_mat = wsp_quantize_mat_q8_0,
303
300
  .vec_dot = wsp_ggml_vec_dot_q8_0_q8_0,
304
301
  .vec_dot_type = WSP_GGML_TYPE_Q8_0,
305
302
  #if defined (__ARM_FEATURE_MATMUL_INT8)
@@ -407,33 +404,6 @@ static const struct wsp_ggml_type_traits_cpu type_traits_cpu[WSP_GGML_TYPE_COUNT
407
404
  .vec_dot_type = WSP_GGML_TYPE_BF16,
408
405
  .nrows = 1,
409
406
  },
410
- [WSP_GGML_TYPE_Q4_0_4_4] = {
411
- .from_float = NULL,
412
- .vec_dot = NULL,
413
- .vec_dot_type = WSP_GGML_TYPE_Q8_0,
414
- .nrows = 1,
415
- .ncols = 4,
416
- .gemv = wsp_ggml_gemv_q4_0_4x4_q8_0,
417
- .gemm = wsp_ggml_gemm_q4_0_4x4_q8_0,
418
- },
419
- [WSP_GGML_TYPE_Q4_0_4_8] = {
420
- .from_float = NULL,
421
- .vec_dot = NULL,
422
- .vec_dot_type = WSP_GGML_TYPE_Q8_0,
423
- .nrows = 1,
424
- .ncols = 4,
425
- .gemv = wsp_ggml_gemv_q4_0_4x8_q8_0,
426
- .gemm = wsp_ggml_gemm_q4_0_4x8_q8_0,
427
- },
428
- [WSP_GGML_TYPE_Q4_0_8_8] = {
429
- .from_float = NULL,
430
- .vec_dot = NULL,
431
- .vec_dot_type = WSP_GGML_TYPE_Q8_0,
432
- .nrows = 1,
433
- .ncols = 8,
434
- .gemv = wsp_ggml_gemv_q4_0_8x8_q8_0,
435
- .gemm = wsp_ggml_gemm_q4_0_8x8_q8_0,
436
- },
437
407
  [WSP_GGML_TYPE_TQ1_0] = {
438
408
  .from_float = wsp_quantize_row_tq1_0,
439
409
  .vec_dot = wsp_ggml_vec_dot_tq1_0_q8_K,
@@ -485,21 +455,21 @@ const struct wsp_ggml_type_traits_cpu * wsp_ggml_get_type_traits_cpu(enum wsp_gg
485
455
  #define WSP_GGML_F32x4_ADD vaddq_f32
486
456
  #define WSP_GGML_F32x4_MUL vmulq_f32
487
457
  #define WSP_GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
488
- #define WSP_GGML_F32x4_REDUCE(res, x) \
489
- { \
490
- int offset = WSP_GGML_F32_ARR >> 1; \
491
- for (int i = 0; i < offset; ++i) { \
492
- (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
493
- } \
494
- offset >>= 1; \
495
- for (int i = 0; i < offset; ++i) { \
496
- (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
497
- } \
498
- offset >>= 1; \
499
- for (int i = 0; i < offset; ++i) { \
500
- (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
501
- } \
502
- (res) = WSP_GGML_F32x4_REDUCE_ONE((x)[0]); \
458
+ #define WSP_GGML_F32x4_REDUCE(res, x) \
459
+ { \
460
+ int offset = WSP_GGML_F32_ARR >> 1; \
461
+ for (int i = 0; i < offset; ++i) { \
462
+ (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
463
+ } \
464
+ offset >>= 1; \
465
+ for (int i = 0; i < offset; ++i) { \
466
+ (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
467
+ } \
468
+ offset >>= 1; \
469
+ for (int i = 0; i < offset; ++i) { \
470
+ (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
471
+ } \
472
+ (res) = (wsp_ggml_float) WSP_GGML_F32x4_REDUCE_ONE((x)[0]); \
503
473
  }
504
474
 
505
475
  #define WSP_GGML_F32_VEC WSP_GGML_F32x4
@@ -614,7 +584,7 @@ do { \
614
584
  for (int i = 0; i < offset; ++i) { \
615
585
  x[i] = _mm512_add_ps(x[i], x[offset+i]); \
616
586
  } \
617
- res = _mm512_reduce_add_ps(x[0]); \
587
+ res = (wsp_ggml_float) _mm512_reduce_add_ps(x[0]); \
618
588
  } while (0)
619
589
 
620
590
  // TODO: is this optimal ?
@@ -664,7 +634,7 @@ do { \
664
634
  for (int i = 0; i < offset; ++i) { \
665
635
  x[i] = _mm512_add_ps(x[i], x[offset+i]); \
666
636
  } \
667
- res = _mm512_reduce_add_ps(x[0]); \
637
+ res = (wsp_ggml_float) _mm512_reduce_add_ps(x[0]); \
668
638
  } while (0)
669
639
 
670
640
  #define WSP_GGML_F16_VEC WSP_GGML_F32Cx16
@@ -675,8 +645,8 @@ do { \
675
645
  #define WSP_GGML_F16_VEC_FMA WSP_GGML_F32Cx16_FMA
676
646
  #define WSP_GGML_F16_VEC_ADD WSP_GGML_F32Cx16_ADD
677
647
  #define WSP_GGML_F16_VEC_MUL WSP_GGML_F32Cx16_MUL
678
- #define WSP_GGML_F16_VEC_REDUCE WSP_GGML_F32Cx16_REDUCE
679
648
 
649
+ #define WSP_GGML_F16_VEC_REDUCE WSP_GGML_F32Cx16_REDUCE
680
650
  #elif defined(__AVX__)
681
651
 
682
652
  #define WSP_GGML_SIMD
@@ -745,7 +715,7 @@ do { \
745
715
  #define WSP_GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
746
716
  #define WSP_GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
747
717
  #else
748
- static inline __m256 __avx_f32cx8_load(wsp_ggml_fp16_t *x) {
718
+ static inline __m256 __avx_f32cx8_load(const wsp_ggml_fp16_t * x) {
749
719
  float tmp[8];
750
720
 
751
721
  for (int i = 0; i < 8; i++) {
@@ -1017,7 +987,7 @@ inline static void __wasm_f16x4_store(wsp_ggml_fp16_t * p, v128_t x) {
1017
987
  #define WSP_GGML_F16_STEP 32
1018
988
  #define WSP_GGML_F16_EPR 4
1019
989
 
1020
- static inline __m128 __sse_f16x4_load(wsp_ggml_fp16_t *x) {
990
+ static inline __m128 __sse_f16x4_load(const wsp_ggml_fp16_t * x) {
1021
991
  float tmp[4];
1022
992
 
1023
993
  tmp[0] = WSP_GGML_FP16_TO_FP32(x[0]);
@@ -1028,7 +998,7 @@ static inline __m128 __sse_f16x4_load(wsp_ggml_fp16_t *x) {
1028
998
  return _mm_loadu_ps(tmp);
1029
999
  }
1030
1000
 
1031
- static inline void __sse_f16x4_store(wsp_ggml_fp16_t *x, __m128 y) {
1001
+ static inline void __sse_f16x4_store(wsp_ggml_fp16_t * x, __m128 y) {
1032
1002
  float arr[4];
1033
1003
 
1034
1004
  _mm_storeu_ps(arr, y);
@@ -1109,29 +1079,23 @@ do { \
1109
1079
  #define WSP_GGML_F16_STEP 32
1110
1080
  #define WSP_GGML_F16_EPR 8
1111
1081
 
1112
- // F16 arithmetic is not supported by AVX, so we use F32 instead
1082
+ // F16 arithmetic is not supported by LASX, so we use F32 instead
1113
1083
 
1114
1084
  #define WSP_GGML_F32Cx8 __m256
1115
1085
  #define WSP_GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
1116
1086
  #define WSP_GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
1117
1087
 
1118
1088
  static inline __m256 __lasx_f32cx8_load(const wsp_ggml_fp16_t * x) {
1119
- float tmp[8];
1120
-
1121
- for (int i = 0; i < 8; i++) {
1122
- tmp[i] = WSP_GGML_FP16_TO_FP32(x[i]);
1123
- }
1124
-
1125
- return (__m256)__lasx_xvld(tmp, 0);
1089
+ __m256i a;
1090
+ memcpy(&a, x, sizeof(wsp_ggml_fp16_t) * 8);
1091
+ a = __lasx_xvpermi_d(a, 0 | (1 << 4));
1092
+ return __lasx_xvfcvtl_s_h(a);
1126
1093
  }
1127
- static inline void __lasx_f32cx8_store(wsp_ggml_fp16_t * x, __m256 y) {
1128
- float arr[8];
1129
-
1130
- __lasx_xvst(y, arr, 0);
1131
1094
 
1132
- for (int i = 0; i < 8; i++) {
1133
- x[i] = WSP_GGML_FP32_TO_FP16(arr[i]);
1134
- }
1095
+ static inline void __lasx_f32cx8_store(wsp_ggml_fp16_t * x, __m256 y) {
1096
+ __m256i a = __lasx_xvfcvt_h_s(y, y);
1097
+ a = __lasx_xvpermi_d(a, 0 | (2 << 2));
1098
+ memcpy(x, &a, sizeof(wsp_ggml_fp16_t) * 8);
1135
1099
  }
1136
1100
  #define WSP_GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
1137
1101
  #define WSP_GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
@@ -1168,28 +1132,28 @@ static inline void __lasx_f32cx8_store(wsp_ggml_fp16_t * x, __m256 y) {
1168
1132
  #define WSP_GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
1169
1133
  #define WSP_GGML_F32x4_ADD __lsx_vfadd_s
1170
1134
  #define WSP_GGML_F32x4_MUL __lsx_vfmul_s
1171
- #define WSP_GGML_F32x4_REDUCE(res, x) \
1172
- { \
1173
- int offset = WSP_GGML_F32_ARR >> 1; \
1174
- for (int i = 0; i < offset; ++i) { \
1175
- x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1176
- } \
1177
- offset >>= 1; \
1178
- for (int i = 0; i < offset; ++i) { \
1179
- x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1180
- } \
1181
- offset >>= 1; \
1182
- for (int i = 0; i < offset; ++i) { \
1183
- x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1184
- } \
1185
- __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \
1186
- tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \
1187
- tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1188
- const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
1189
- tmp = __lsx_vsrli_d((__m128i)t0, 32); \
1190
- tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \
1191
- tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1192
- res = (wsp_ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
1135
+ #define WSP_GGML_F32x4_REDUCE(res, x) \
1136
+ { \
1137
+ int offset = WSP_GGML_F32_ARR >> 1; \
1138
+ for (int i = 0; i < offset; ++i) { \
1139
+ x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
1140
+ } \
1141
+ offset >>= 1; \
1142
+ for (int i = 0; i < offset; ++i) { \
1143
+ x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
1144
+ } \
1145
+ offset >>= 1; \
1146
+ for (int i = 0; i < offset; ++i) { \
1147
+ x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
1148
+ } \
1149
+ __m128i tmp = __lsx_vsrli_d((__m128i) x[0], 32); \
1150
+ tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, x[0]); \
1151
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1152
+ const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
1153
+ tmp = __lsx_vsrli_d((__m128i) t0, 32); \
1154
+ tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, t0); \
1155
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1156
+ res = (wsp_ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
1193
1157
  }
1194
1158
 
1195
1159
  #define WSP_GGML_F32_VEC WSP_GGML_F32x4
@@ -1249,6 +1213,87 @@ static inline void __lsx_f16x4_store(wsp_ggml_fp16_t * x, __m128 y) {
1249
1213
  #define WSP_GGML_F16_VEC_MUL WSP_GGML_F32Cx4_MUL
1250
1214
  #define WSP_GGML_F16_VEC_REDUCE WSP_GGML_F32Cx4_REDUCE
1251
1215
 
1216
+ #elif defined(__VXE__) || defined(__VXE2__)
1217
+
1218
+ #define WSP_GGML_SIMD
1219
+
1220
+ // F32 s390x
1221
+
1222
+ #define WSP_GGML_F32_STEP 32
1223
+ #define WSP_GGML_F32_EPR 4
1224
+
1225
+ #define WSP_GGML_F32x4 __vector float
1226
+ #define WSP_GGML_F32x4_ZERO vec_splats(0.0f)
1227
+ #define WSP_GGML_F32x4_SET1 vec_splats
1228
+ #define WSP_GGML_F32x4_LOAD(p) vec_xl(0, p)
1229
+ #define WSP_GGML_F32x4_STORE(p, r) vec_xst(r, 0, p)
1230
+ #define WSP_GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
1231
+ #define WSP_GGML_F32x4_ADD vec_add
1232
+ #define WSP_GGML_F32x4_MUL vec_mul
1233
+ #define WSP_GGML_F32x4_REDUCE(res, x) \
1234
+ { \
1235
+ int offset = WSP_GGML_F32_ARR >> 1; \
1236
+ for (int i = 0; i < offset; ++i) { \
1237
+ x[i] = vec_add(x[i], x[offset + i]); \
1238
+ } \
1239
+ offset >>= 1; \
1240
+ for (int i = 0; i < offset; ++i) { \
1241
+ x[i] = vec_add(x[i], x[offset + i]); \
1242
+ } \
1243
+ offset >>= 1; \
1244
+ for (int i = 0; i < offset; ++i) { \
1245
+ x[i] = vec_add(x[i], x[offset + i]); \
1246
+ } \
1247
+ res = vec_extract(x[0], 0) + \
1248
+ vec_extract(x[0], 1) + \
1249
+ vec_extract(x[0], 2) + \
1250
+ vec_extract(x[0], 3); \
1251
+ }
1252
+
1253
+ #define WSP_GGML_F32_VEC WSP_GGML_F32x4
1254
+ #define WSP_GGML_F32_VEC_ZERO WSP_GGML_F32x4_ZERO
1255
+ #define WSP_GGML_F32_VEC_SET1 WSP_GGML_F32x4_SET1
1256
+ #define WSP_GGML_F32_VEC_LOAD WSP_GGML_F32x4_LOAD
1257
+ #define WSP_GGML_F32_VEC_STORE WSP_GGML_F32x4_STORE
1258
+ #define WSP_GGML_F32_VEC_FMA WSP_GGML_F32x4_FMA
1259
+ #define WSP_GGML_F32_VEC_ADD WSP_GGML_F32x4_ADD
1260
+ #define WSP_GGML_F32_VEC_MUL WSP_GGML_F32x4_MUL
1261
+ #define WSP_GGML_F32_VEC_REDUCE WSP_GGML_F32x4_REDUCE
1262
+
1263
+ // F16 s390x
1264
+ #define WSP_GGML_F16_STEP WSP_GGML_F32_STEP
1265
+ #define WSP_GGML_F16_EPR WSP_GGML_F32_EPR
1266
+
1267
+ static inline __vector float __lzs_f16cx4_load(const wsp_ggml_fp16_t * x) {
1268
+ float tmp[4];
1269
+
1270
+ for (int i = 0; i < 4; i++) {
1271
+ tmp[i] = WSP_GGML_FP16_TO_FP32(x[i]);
1272
+ }
1273
+
1274
+ return vec_xl(0, tmp);
1275
+ }
1276
+
1277
+ static inline void __lzs_f16cx4_store(wsp_ggml_fp16_t * x, __vector float y) {
1278
+ float arr[4];
1279
+
1280
+ vec_xst(y, 0, arr);
1281
+
1282
+ for (int i = 0; i < 4; i++) {
1283
+ x[i] = WSP_GGML_FP32_TO_FP16(arr[i]);
1284
+ }
1285
+ }
1286
+
1287
+ #define WSP_GGML_F16_VEC WSP_GGML_F32x4
1288
+ #define WSP_GGML_F16_VEC_ZERO WSP_GGML_F32x4_ZERO
1289
+ #define WSP_GGML_F16_VEC_SET1 WSP_GGML_F32x4_SET1
1290
+ #define WSP_GGML_F16_VEC_LOAD(p, i) __lzs_f16cx4_load(p)
1291
+ #define WSP_GGML_F16_VEC_STORE(p, r, i) __lzs_f16cx4_store(p, r[i])
1292
+ #define WSP_GGML_F16_VEC_FMA WSP_GGML_F32x4_FMA
1293
+ #define WSP_GGML_F16_VEC_ADD WSP_GGML_F32x4_ADD
1294
+ #define WSP_GGML_F16_VEC_MUL WSP_GGML_F32x4_MUL
1295
+ #define WSP_GGML_F16_VEC_REDUCE WSP_GGML_F32x4_REDUCE
1296
+
1252
1297
  #endif
1253
1298
 
1254
1299
  // WSP_GGML_F32_ARR / WSP_GGML_F16_ARR
@@ -1328,12 +1373,12 @@ struct wsp_ggml_threadpool {
1328
1373
  atomic_int n_graph; // incremented when there is work to be done (i.e each graph)
1329
1374
  atomic_int WSP_GGML_CACHE_ALIGN n_barrier;
1330
1375
  atomic_int WSP_GGML_CACHE_ALIGN n_barrier_passed;
1331
- atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
1376
+ atomic_int WSP_GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
1332
1377
 
1333
1378
  // these are atomic as an annotation for thread-sanitizer
1334
1379
  atomic_bool stop; // Used for stopping the threadpool altogether
1335
1380
  atomic_bool pause; // Used for pausing the threadpool or individual threads
1336
- atomic_bool abort; // Used for aborting processing of a graph
1381
+ atomic_int abort; // Used for aborting processing of a graph
1337
1382
 
1338
1383
  struct wsp_ggml_compute_state * workers; // per thread state
1339
1384
  int n_threads_max; // number of threads in the pool
@@ -1357,41 +1402,48 @@ struct wsp_ggml_compute_state {
1357
1402
  int ith;
1358
1403
  };
1359
1404
 
1360
- struct wsp_ggml_compute_params {
1361
- // ith = thread index, nth = number of threads
1362
- int ith, nth;
1363
-
1364
- // work buffer for all threads
1365
- size_t wsize;
1366
- void * wdata;
1367
-
1368
- struct wsp_ggml_threadpool * threadpool;
1369
- };
1370
-
1371
1405
  //
1372
1406
  // fundamental operations
1373
1407
  //
1374
1408
 
1375
1409
  inline static void wsp_ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1376
-
1377
1410
  inline static void wsp_ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1378
1411
 
1379
- inline static void wsp_ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1412
+ inline static void wsp_ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1413
+ inline static void wsp_ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
1380
1414
 
1381
1415
  inline static void wsp_ggml_vec_set_f16(const int n, wsp_ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1382
-
1383
1416
  inline static void wsp_ggml_vec_set_bf16(const int n, wsp_ggml_bf16_t * x, const wsp_ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1384
-
1385
1417
  inline static void wsp_ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
1418
+ inline static void wsp_ggml_vec_add_f16 (const int n, wsp_ggml_fp16_t * z, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * y) {
1419
+ for (int i = 0; i < n; ++i) {
1420
+ z[i] = WSP_GGML_FP32_TO_FP16(WSP_GGML_FP16_TO_FP32(x[i]) + WSP_GGML_FP16_TO_FP32(y[i]));
1421
+ }
1422
+ }
1386
1423
  inline static void wsp_ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
1387
1424
  inline static void wsp_ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
1388
1425
  inline static void wsp_ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
1389
1426
  inline static void wsp_ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; }
1427
+ inline static void wsp_ggml_vec_sub_f16 (const int n, wsp_ggml_fp16_t * z, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * y) {
1428
+ for (int i = 0; i < n; ++i) {
1429
+ z[i] = WSP_GGML_FP32_TO_FP16(WSP_GGML_FP16_TO_FP32(x[i]) - WSP_GGML_FP16_TO_FP32(y[i]));
1430
+ }
1431
+ }
1390
1432
  inline static void wsp_ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
1391
1433
  inline static void wsp_ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
1392
1434
  inline static void wsp_ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
1393
1435
  inline static void wsp_ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
1436
+ inline static void wsp_ggml_vec_mul_f16 (const int n, wsp_ggml_fp16_t * z, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * y) {
1437
+ for (int i = 0; i < n; ++i) {
1438
+ z[i] = WSP_GGML_FP32_TO_FP16(WSP_GGML_FP16_TO_FP32(x[i]) * WSP_GGML_FP16_TO_FP32(y[i]));
1439
+ }
1440
+ }
1394
1441
  inline static void wsp_ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
1442
+ inline static void wsp_ggml_vec_div_f16 (const int n, wsp_ggml_fp16_t * z, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * y) {
1443
+ for (int i = 0; i < n; ++i) {
1444
+ z[i] = WSP_GGML_FP32_TO_FP16(WSP_GGML_FP16_TO_FP32(x[i]) / WSP_GGML_FP16_TO_FP32(y[i]));
1445
+ }
1446
+ }
1395
1447
 
1396
1448
  static void wsp_ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
1397
1449
  assert(nrc == 1);
@@ -1868,7 +1920,7 @@ inline static float wsp_ggml_silu_f32(float x) {
1868
1920
 
1869
1921
  #if __FINITE_MATH_ONLY__
1870
1922
  #error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
1871
- #error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461"
1923
+ #error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461"
1872
1924
  #endif
1873
1925
 
1874
1926
  #if defined(__ARM_NEON) && defined(__aarch64__)
@@ -2276,7 +2328,7 @@ struct wsp_ggml_state {
2276
2328
 
2277
2329
  static struct wsp_ggml_state g_state = {0};
2278
2330
 
2279
- static void wsp_ggml_barrier(struct wsp_ggml_threadpool * tp) {
2331
+ void wsp_ggml_barrier(struct wsp_ggml_threadpool * tp) {
2280
2332
  int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed);
2281
2333
  if (n_threads == 1) {
2282
2334
  return;
@@ -2430,7 +2482,11 @@ bool wsp_ggml_is_numa(void) {
2430
2482
  #endif
2431
2483
 
2432
2484
  #if !defined(HWCAP2_I8MM)
2433
- #define HWCAP2_I8MM 0
2485
+ #define HWCAP2_I8MM (1 << 13)
2486
+ #endif
2487
+
2488
+ #if !defined(HWCAP2_SME)
2489
+ #define HWCAP2_SME (1 << 23)
2434
2490
  #endif
2435
2491
 
2436
2492
  static void wsp_ggml_init_arm_arch_features(void) {
@@ -2438,9 +2494,11 @@ static void wsp_ggml_init_arm_arch_features(void) {
2438
2494
  uint32_t hwcap = getauxval(AT_HWCAP);
2439
2495
  uint32_t hwcap2 = getauxval(AT_HWCAP2);
2440
2496
 
2441
- wsp_ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
2442
- wsp_ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
2443
- wsp_ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
2497
+ wsp_ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
2498
+ wsp_ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP);
2499
+ wsp_ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
2500
+ wsp_ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
2501
+ wsp_ggml_arm_arch_features.has_sme = !!(hwcap2 & HWCAP2_SME);
2444
2502
 
2445
2503
  #if defined(__ARM_FEATURE_SVE)
2446
2504
  wsp_ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
@@ -2453,11 +2511,21 @@ static void wsp_ggml_init_arm_arch_features(void) {
2453
2511
  }
2454
2512
  wsp_ggml_arm_arch_features.has_neon = oldp;
2455
2513
 
2514
+ if (sysctlbyname("hw.optional.arm.FEAT_DotProd", &oldp, &size, NULL, 0) != 0) {
2515
+ oldp = 0;
2516
+ }
2517
+ wsp_ggml_arm_arch_features.has_dotprod = oldp;
2518
+
2456
2519
  if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) {
2457
2520
  oldp = 0;
2458
2521
  }
2459
2522
  wsp_ggml_arm_arch_features.has_i8mm = oldp;
2460
2523
 
2524
+ if (sysctlbyname("hw.optional.arm.FEAT_SME", &oldp, &size, NULL, 0) != 0) {
2525
+ oldp = 0;
2526
+ }
2527
+ wsp_ggml_arm_arch_features.has_sme = oldp;
2528
+
2461
2529
  wsp_ggml_arm_arch_features.has_sve = 0;
2462
2530
  wsp_ggml_arm_arch_features.sve_cnt = 0;
2463
2531
  #else
@@ -2481,6 +2549,12 @@ static void wsp_ggml_init_arm_arch_features(void) {
2481
2549
  wsp_ggml_arm_arch_features.has_sve = 0;
2482
2550
  wsp_ggml_arm_arch_features.sve_cnt = 0;
2483
2551
  #endif
2552
+
2553
+ #if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_SME2)
2554
+ wsp_ggml_arm_arch_features.has_sme = 1;
2555
+ #else
2556
+ wsp_ggml_arm_arch_features.has_sme = 0;
2557
+ #endif
2484
2558
  #endif
2485
2559
  }
2486
2560
  #endif
@@ -4005,6 +4079,57 @@ static void wsp_ggml_compute_forward_dup_bytes(
4005
4079
  }
4006
4080
  }
4007
4081
 
4082
+ static void wsp_ggml_compute_forward_dup_q(
4083
+ const struct wsp_ggml_compute_params * params,
4084
+ struct wsp_ggml_tensor * dst) {
4085
+
4086
+ const struct wsp_ggml_tensor * src0 = dst->src[0];
4087
+ const struct wsp_ggml_tensor * src1 = dst->src[1];
4088
+
4089
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS
4090
+
4091
+ const enum wsp_ggml_type type = src0->type;
4092
+ wsp_ggml_to_float_t const wsp_dewsp_quantize_row_q = wsp_ggml_get_type_traits(type)->to_float;
4093
+
4094
+ size_t qk = wsp_ggml_blck_size(type);
4095
+ const int64_t nr = wsp_ggml_nelements(src1) / qk;
4096
+
4097
+ // destination must be contiguous in the first dimension
4098
+ WSP_GGML_ASSERT(nb10 == wsp_ggml_type_size(dst->type));
4099
+ // must either have first dimension large enough to hold a row, or fully contiguous
4100
+ WSP_GGML_ASSERT((ne10 % qk) == 0 || wsp_ggml_is_contiguous(dst));
4101
+
4102
+ const int ith = params->ith;
4103
+ const int nth = params->nth;
4104
+
4105
+ const int dr = (nr + nth - 1)/nth;
4106
+
4107
+ // row range for this thread
4108
+ const int ir0 = dr*ith;
4109
+ const int ir1 = MIN(ir0 + dr, nr);
4110
+
4111
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
4112
+
4113
+ uint32_t i = ir * qk;
4114
+
4115
+ const int64_t i03 = i/(ne00 * ne01 * ne02);
4116
+ const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
4117
+ const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
4118
+ const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
4119
+ const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
4120
+
4121
+ const int64_t i13 = i/(ne10 * ne11 * ne12);
4122
+ const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
4123
+ const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
4124
+ const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
4125
+ const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
4126
+
4127
+ wsp_dewsp_quantize_row_q(
4128
+ (const void *) ((char *) src0->data + x_offset),
4129
+ (float *) ((char *) dst->data + dst_offset), qk);
4130
+ }
4131
+ }
4132
+
4008
4133
  static void wsp_ggml_compute_forward_dup(
4009
4134
  const struct wsp_ggml_compute_params * params,
4010
4135
  struct wsp_ggml_tensor * dst) {
@@ -4031,6 +4156,10 @@ static void wsp_ggml_compute_forward_dup(
4031
4156
  } break;
4032
4157
  default:
4033
4158
  {
4159
+ if (wsp_ggml_is_quantized(src0->type) && dst->type == WSP_GGML_TYPE_F32) {
4160
+ wsp_ggml_compute_forward_dup_q(params, dst);
4161
+ break;
4162
+ }
4034
4163
  WSP_GGML_ABORT("fatal error");
4035
4164
  }
4036
4165
  }
@@ -4270,7 +4399,7 @@ static void wsp_ggml_compute_forward_add_f16_f16(
4270
4399
  const struct wsp_ggml_tensor * src0 = dst->src[0];
4271
4400
  const struct wsp_ggml_tensor * src1 = dst->src[1];
4272
4401
 
4273
- WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, src1) && wsp_ggml_are_same_shape(src0, dst));
4402
+ WSP_GGML_ASSERT(wsp_ggml_can_repeat(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
4274
4403
 
4275
4404
  const int ith = params->ith;
4276
4405
  const int nth = params->nth;
@@ -4295,17 +4424,22 @@ static void wsp_ggml_compute_forward_add_f16_f16(
4295
4424
 
4296
4425
  if (nb10 == sizeof(wsp_ggml_fp16_t)) {
4297
4426
  for (int ir = ir0; ir < ir1; ++ir) {
4298
- // src0, src1 and dst are same shape => same indices
4299
- const int i3 = ir/(ne2*ne1);
4300
- const int i2 = (ir - i3*ne2*ne1)/ne1;
4301
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
4427
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
4428
+ const int64_t i03 = ir/(ne02*ne01);
4429
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
4430
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
4431
+
4432
+ const int64_t i13 = i03 % ne13;
4433
+ const int64_t i12 = i02 % ne12;
4434
+ const int64_t i11 = i01 % ne11;
4435
+ const int64_t nr0 = ne00 / ne10;
4302
4436
 
4303
- wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
4304
- wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
4305
- wsp_ggml_fp16_t * src1_ptr = (wsp_ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
4437
+ wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
4438
+ wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
4439
+ wsp_ggml_fp16_t * src1_ptr = (wsp_ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
4306
4440
 
4307
- for (int i = 0; i < ne0; i++) {
4308
- dst_ptr[i] = WSP_GGML_FP32_TO_FP16(WSP_GGML_FP16_TO_FP32(src0_ptr[i]) + WSP_GGML_FP16_TO_FP32(src1_ptr[i]));
4441
+ for (int64_t r = 0; r < nr0; ++r) {
4442
+ wsp_ggml_vec_add_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
4309
4443
  }
4310
4444
  }
4311
4445
  }
@@ -4505,9 +4639,6 @@ static void wsp_ggml_compute_forward_add(
4505
4639
  case WSP_GGML_TYPE_IQ4_XS:
4506
4640
  case WSP_GGML_TYPE_IQ3_S:
4507
4641
  case WSP_GGML_TYPE_IQ2_S:
4508
- case WSP_GGML_TYPE_Q4_0_4_4:
4509
- case WSP_GGML_TYPE_Q4_0_4_8:
4510
- case WSP_GGML_TYPE_Q4_0_8_8:
4511
4642
  {
4512
4643
  wsp_ggml_compute_forward_add_q_f32(params, dst);
4513
4644
  } break;
@@ -4885,9 +5016,6 @@ static void wsp_ggml_compute_forward_add1(
4885
5016
  case WSP_GGML_TYPE_IQ4_XS:
4886
5017
  case WSP_GGML_TYPE_IQ3_S:
4887
5018
  case WSP_GGML_TYPE_IQ2_S:
4888
- case WSP_GGML_TYPE_Q4_0_4_4:
4889
- case WSP_GGML_TYPE_Q4_0_4_8:
4890
- case WSP_GGML_TYPE_Q4_0_8_8:
4891
5019
  {
4892
5020
  wsp_ggml_compute_forward_add1_q_f32(params, dst);
4893
5021
  } break;
@@ -5015,9 +5143,6 @@ static void wsp_ggml_compute_forward_acc(
5015
5143
  case WSP_GGML_TYPE_IQ4_XS:
5016
5144
  case WSP_GGML_TYPE_IQ3_S:
5017
5145
  case WSP_GGML_TYPE_IQ2_S:
5018
- case WSP_GGML_TYPE_Q4_0_4_4:
5019
- case WSP_GGML_TYPE_Q4_0_4_8:
5020
- case WSP_GGML_TYPE_Q4_0_8_8:
5021
5146
  default:
5022
5147
  {
5023
5148
  WSP_GGML_ABORT("fatal error");
@@ -5102,6 +5227,62 @@ static void wsp_ggml_compute_forward_sub_f32(
5102
5227
  }
5103
5228
  }
5104
5229
 
5230
+ static void wsp_ggml_compute_forward_sub_f16(
5231
+ const struct wsp_ggml_compute_params * params,
5232
+ struct wsp_ggml_tensor * dst) {
5233
+
5234
+ const struct wsp_ggml_tensor * src0 = dst->src[0];
5235
+ const struct wsp_ggml_tensor * src1 = dst->src[1];
5236
+
5237
+ assert(wsp_ggml_can_repeat(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
5238
+
5239
+ const int ith = params->ith;
5240
+ const int nth = params->nth;
5241
+
5242
+ const int nr = wsp_ggml_nrows(src0);
5243
+
5244
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS
5245
+
5246
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
5247
+ WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16);
5248
+ WSP_GGML_ASSERT(dst->type == WSP_GGML_TYPE_F16);
5249
+
5250
+ WSP_GGML_ASSERT( nb0 == sizeof(wsp_ggml_fp16_t));
5251
+ WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
5252
+
5253
+ // rows per thread
5254
+ const int dr = (nr + nth - 1)/nth;
5255
+
5256
+ // row range for this thread
5257
+ const int ir0 = dr*ith;
5258
+ const int ir1 = MIN(ir0 + dr, nr);
5259
+
5260
+ if (nb10 == sizeof(wsp_ggml_fp16_t)) {
5261
+ for (int ir = ir0; ir < ir1; ++ir) {
5262
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
5263
+ const int64_t i03 = ir/(ne02*ne01);
5264
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
5265
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
5266
+
5267
+ const int64_t i13 = i03 % ne13;
5268
+ const int64_t i12 = i02 % ne12;
5269
+ const int64_t i11 = i01 % ne11;
5270
+ const int64_t nr0 = ne00 / ne10;
5271
+
5272
+ wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
5273
+ wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
5274
+ wsp_ggml_fp16_t * src1_ptr = (wsp_ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
5275
+
5276
+ for (int64_t r = 0; r < nr0; ++r) {
5277
+ wsp_ggml_vec_sub_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
5278
+ }
5279
+ }
5280
+ } else {
5281
+ // src1 is not contiguous
5282
+ WSP_GGML_ABORT("unimplemented error");
5283
+ }
5284
+ }
5285
+
5105
5286
  static void wsp_ggml_compute_forward_sub(
5106
5287
  const struct wsp_ggml_compute_params * params,
5107
5288
  struct wsp_ggml_tensor * dst) {
@@ -5113,6 +5294,10 @@ static void wsp_ggml_compute_forward_sub(
5113
5294
  {
5114
5295
  wsp_ggml_compute_forward_sub_f32(params, dst);
5115
5296
  } break;
5297
+ case WSP_GGML_TYPE_F16:
5298
+ {
5299
+ wsp_ggml_compute_forward_sub_f16(params, dst);
5300
+ } break;
5116
5301
  default:
5117
5302
  {
5118
5303
  WSP_GGML_ABORT("fatal error");
@@ -5193,32 +5378,9 @@ static void wsp_ggml_compute_forward_mul_f32(
5193
5378
  }
5194
5379
  }
5195
5380
 
5196
- static void wsp_ggml_compute_forward_mul(
5197
- const struct wsp_ggml_compute_params * params,
5198
- struct wsp_ggml_tensor * dst) {
5199
-
5200
- const struct wsp_ggml_tensor * src0 = dst->src[0];
5201
- const struct wsp_ggml_tensor * src1 = dst->src[1];
5202
-
5203
- WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32 && "only f32 src1 supported for now");
5204
-
5205
- switch (src0->type) {
5206
- case WSP_GGML_TYPE_F32:
5207
- {
5208
- wsp_ggml_compute_forward_mul_f32(params, dst);
5209
- } break;
5210
- default:
5211
- {
5212
- WSP_GGML_ABORT("fatal error");
5213
- }
5214
- }
5215
- }
5216
-
5217
- // wsp_ggml_compute_forward_div
5218
-
5219
- static void wsp_ggml_compute_forward_div_f32(
5220
- const struct wsp_ggml_compute_params * params,
5221
- struct wsp_ggml_tensor * dst) {
5381
+ static void wsp_ggml_compute_forward_mul_f16(
5382
+ const struct wsp_ggml_compute_params * params,
5383
+ struct wsp_ggml_tensor * dst) {
5222
5384
 
5223
5385
  const struct wsp_ggml_tensor * src0 = dst->src[0];
5224
5386
  const struct wsp_ggml_tensor * src1 = dst->src[1];
@@ -5232,8 +5394,84 @@ static void wsp_ggml_compute_forward_div_f32(
5232
5394
 
5233
5395
  WSP_GGML_TENSOR_BINARY_OP_LOCALS
5234
5396
 
5235
- WSP_GGML_ASSERT( nb0 == sizeof(float));
5236
- WSP_GGML_ASSERT(nb00 == sizeof(float));
5397
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
5398
+ WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16);
5399
+ WSP_GGML_ASSERT(dst->type == WSP_GGML_TYPE_F16);
5400
+
5401
+ WSP_GGML_ASSERT( nb0 == sizeof(wsp_ggml_fp16_t));
5402
+ WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
5403
+
5404
+ if (nb10 == sizeof(wsp_ggml_fp16_t)) {
5405
+ for (int64_t ir = ith; ir < nr; ir += nth) {
5406
+ // src0 and dst are same shape => same indices
5407
+ const int64_t i03 = ir/(ne02*ne01);
5408
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
5409
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
5410
+
5411
+ const int64_t i13 = i03 % ne13;
5412
+ const int64_t i12 = i02 % ne12;
5413
+ const int64_t i11 = i01 % ne11;
5414
+ const int64_t nr0 = ne00 / ne10;
5415
+
5416
+ wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
5417
+ wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
5418
+ wsp_ggml_fp16_t * src1_ptr = (wsp_ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
5419
+
5420
+ for (int64_t r = 0 ; r < nr0; ++r) {
5421
+ wsp_ggml_vec_mul_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
5422
+ }
5423
+ }
5424
+ } else {
5425
+ // src1 is not contiguous
5426
+ WSP_GGML_ABORT("unimplemented error");
5427
+ }
5428
+ }
5429
+
5430
+ static void wsp_ggml_compute_forward_mul(
5431
+ const struct wsp_ggml_compute_params * params,
5432
+ struct wsp_ggml_tensor * dst) {
5433
+
5434
+ const struct wsp_ggml_tensor * src0 = dst->src[0];
5435
+ const struct wsp_ggml_tensor * src1 = dst->src[1];
5436
+
5437
+ WSP_GGML_ASSERT((src1->type == WSP_GGML_TYPE_F32 || src1->type == WSP_GGML_TYPE_F16) && "only f32/f16 src1 supported for now");
5438
+
5439
+ switch (src0->type) {
5440
+ case WSP_GGML_TYPE_F32:
5441
+ {
5442
+ wsp_ggml_compute_forward_mul_f32(params, dst);
5443
+ } break;
5444
+ case WSP_GGML_TYPE_F16:
5445
+ {
5446
+ wsp_ggml_compute_forward_mul_f16(params, dst);
5447
+ } break;
5448
+ default:
5449
+ {
5450
+ WSP_GGML_ABORT("fatal error");
5451
+ }
5452
+ }
5453
+ }
5454
+
5455
+ // wsp_ggml_compute_forward_div
5456
+
5457
+ static void wsp_ggml_compute_forward_div_f32(
5458
+ const struct wsp_ggml_compute_params * params,
5459
+ struct wsp_ggml_tensor * dst) {
5460
+
5461
+ const struct wsp_ggml_tensor * src0 = dst->src[0];
5462
+ const struct wsp_ggml_tensor * src1 = dst->src[1];
5463
+
5464
+ WSP_GGML_ASSERT(wsp_ggml_can_repeat(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
5465
+
5466
+ const int ith = params->ith;
5467
+ const int nth = params->nth;
5468
+
5469
+ const int64_t nr = wsp_ggml_nrows(src0);
5470
+
5471
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS
5472
+
5473
+ WSP_GGML_ASSERT( nb0 == sizeof(float));
5474
+ WSP_GGML_ASSERT(nb00 == sizeof(float));
5237
5475
 
5238
5476
  if (nb10 == sizeof(float)) {
5239
5477
  for (int64_t ir = ith; ir < nr; ir += nth) {
@@ -5287,6 +5525,55 @@ static void wsp_ggml_compute_forward_div_f32(
5287
5525
  }
5288
5526
  }
5289
5527
 
5528
+ static void wsp_ggml_compute_forward_div_f16(
5529
+ const struct wsp_ggml_compute_params * params,
5530
+ struct wsp_ggml_tensor * dst) {
5531
+
5532
+ const struct wsp_ggml_tensor * src0 = dst->src[0];
5533
+ const struct wsp_ggml_tensor * src1 = dst->src[1];
5534
+
5535
+ WSP_GGML_ASSERT(wsp_ggml_can_repeat(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
5536
+
5537
+ const int ith = params->ith;
5538
+ const int nth = params->nth;
5539
+
5540
+ const int64_t nr = wsp_ggml_nrows(src0);
5541
+
5542
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS
5543
+
5544
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
5545
+ WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16);
5546
+ WSP_GGML_ASSERT(dst->type == WSP_GGML_TYPE_F16);
5547
+
5548
+ WSP_GGML_ASSERT( nb0 == sizeof(wsp_ggml_fp16_t));
5549
+ WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
5550
+
5551
+ if (nb10 == sizeof(wsp_ggml_fp16_t)) {
5552
+ for (int64_t ir = ith; ir < nr; ir += nth) {
5553
+ // src0 and dst are same shape => same indices
5554
+ const int64_t i03 = ir/(ne02*ne01);
5555
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
5556
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
5557
+
5558
+ const int64_t i13 = i03 % ne13;
5559
+ const int64_t i12 = i02 % ne12;
5560
+ const int64_t i11 = i01 % ne11;
5561
+ const int64_t nr0 = ne00 / ne10;
5562
+
5563
+ wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
5564
+ wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
5565
+ wsp_ggml_fp16_t * src1_ptr = (wsp_ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
5566
+
5567
+ for (int64_t r = 0; r < nr0; ++r) {
5568
+ wsp_ggml_vec_div_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
5569
+ }
5570
+ }
5571
+ } else {
5572
+ // src1 is not contiguous
5573
+ WSP_GGML_ABORT("unimplemented error");
5574
+ }
5575
+ }
5576
+
5290
5577
  static void wsp_ggml_compute_forward_div(
5291
5578
  const struct wsp_ggml_compute_params * params,
5292
5579
  struct wsp_ggml_tensor * dst) {
@@ -5298,6 +5585,10 @@ static void wsp_ggml_compute_forward_div(
5298
5585
  {
5299
5586
  wsp_ggml_compute_forward_div_f32(params, dst);
5300
5587
  } break;
5588
+ case WSP_GGML_TYPE_F16:
5589
+ {
5590
+ wsp_ggml_compute_forward_div_f16(params, dst);
5591
+ } break;
5301
5592
  default:
5302
5593
  {
5303
5594
  WSP_GGML_ABORT("fatal error");
@@ -6738,20 +7029,20 @@ static void wsp_ggml_compute_forward_silu_back_f32(
6738
7029
  const struct wsp_ggml_compute_params * params,
6739
7030
  struct wsp_ggml_tensor * dst) {
6740
7031
 
6741
- const struct wsp_ggml_tensor * src0 = dst->src[0];
6742
- const struct wsp_ggml_tensor * grad = dst->src[1];
7032
+ const struct wsp_ggml_tensor * grad = dst->src[0];
7033
+ const struct wsp_ggml_tensor * src1 = dst->src[1];
6743
7034
 
6744
7035
  assert(wsp_ggml_is_contiguous_1(grad));
6745
- assert(wsp_ggml_is_contiguous_1(src0));
7036
+ assert(wsp_ggml_is_contiguous_1(src1));
6746
7037
  assert(wsp_ggml_is_contiguous_1(dst));
6747
- assert(wsp_ggml_are_same_shape(src0, dst));
6748
- assert(wsp_ggml_are_same_shape(src0, grad));
7038
+ assert(wsp_ggml_are_same_shape(src1, dst));
7039
+ assert(wsp_ggml_are_same_shape(src1, grad));
6749
7040
 
6750
7041
  const int ith = params->ith;
6751
7042
  const int nth = params->nth;
6752
7043
 
6753
- const int nc = src0->ne[0];
6754
- const int nr = wsp_ggml_nrows(src0);
7044
+ const int nc = src1->ne[0];
7045
+ const int nr = wsp_ggml_nrows(src1);
6755
7046
 
6756
7047
  // rows per thread
6757
7048
  const int dr = (nr + nth - 1)/nth;
@@ -6763,7 +7054,7 @@ static void wsp_ggml_compute_forward_silu_back_f32(
6763
7054
  for (int i1 = ir0; i1 < ir1; i1++) {
6764
7055
  wsp_ggml_vec_silu_backward_f32(nc,
6765
7056
  (float *) ((char *) dst->data + i1*( dst->nb[1])),
6766
- (float *) ((char *) src0->data + i1*(src0->nb[1])),
7057
+ (float *) ((char *) src1->data + i1*(src1->nb[1])),
6767
7058
  (float *) ((char *) grad->data + i1*(grad->nb[1])));
6768
7059
 
6769
7060
  #ifndef NDEBUG
@@ -6942,7 +7233,7 @@ static void wsp_ggml_compute_forward_norm_f32(
6942
7233
  float eps;
6943
7234
  memcpy(&eps, dst->op_params, sizeof(float));
6944
7235
 
6945
- WSP_GGML_ASSERT(eps > 0.0f);
7236
+ WSP_GGML_ASSERT(eps >= 0.0f);
6946
7237
 
6947
7238
  // TODO: optimize
6948
7239
  for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7013,7 +7304,7 @@ static void wsp_ggml_compute_forward_rms_norm_f32(
7013
7304
  float eps;
7014
7305
  memcpy(&eps, dst->op_params, sizeof(float));
7015
7306
 
7016
- WSP_GGML_ASSERT(eps > 0.0f);
7307
+ WSP_GGML_ASSERT(eps >= 0.0f);
7017
7308
 
7018
7309
  // TODO: optimize
7019
7310
  for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7065,12 +7356,13 @@ static void wsp_ggml_compute_forward_rms_norm_back_f32(
7065
7356
  const struct wsp_ggml_compute_params * params,
7066
7357
  struct wsp_ggml_tensor * dst) {
7067
7358
 
7068
- const struct wsp_ggml_tensor * src0 = dst->src[0];
7069
- const struct wsp_ggml_tensor * src1 = dst->src[1];
7359
+ const struct wsp_ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
7360
+ const struct wsp_ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
7070
7361
 
7071
7362
  WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst) && wsp_ggml_are_same_shape(src0, src1));
7072
7363
 
7073
7364
  WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));
7365
+ WSP_GGML_ASSERT(src1->nb[0] == sizeof(float));
7074
7366
 
7075
7367
  const int ith = params->ith;
7076
7368
  const int nth = params->nth;
@@ -7089,8 +7381,8 @@ static void wsp_ggml_compute_forward_rms_norm_back_f32(
7089
7381
  const int64_t i12 = i02;
7090
7382
  const int64_t i13 = i03;
7091
7383
 
7092
- const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7093
- const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
7384
+ const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7385
+ const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
7094
7386
 
7095
7387
  wsp_ggml_float sum_xx = 0.0;
7096
7388
  wsp_ggml_float sum_xdz = 0.0;
@@ -7113,9 +7405,9 @@ static void wsp_ggml_compute_forward_rms_norm_back_f32(
7113
7405
  {
7114
7406
  // z = rms_norm(x)
7115
7407
  //
7116
- // rms_norm(src0) =
7408
+ // rms_norm(src1) =
7117
7409
  // scale(
7118
- // src0,
7410
+ // src1,
7119
7411
  // div(
7120
7412
  // 1,
7121
7413
  // sqrt(
@@ -7123,13 +7415,13 @@ static void wsp_ggml_compute_forward_rms_norm_back_f32(
7123
7415
  // scale(
7124
7416
  // sum(
7125
7417
  // sqr(
7126
- // src0)),
7418
+ // src1)),
7127
7419
  // (1.0/N)),
7128
7420
  // eps))));
7129
7421
 
7130
7422
  // postorder:
7131
7423
  // ## op args grad
7132
- // 00 param src0 grad[#00]
7424
+ // 00 param src1 grad[#00]
7133
7425
  // 01 const 1
7134
7426
  // 02 sqr (#00) grad[#02]
7135
7427
  // 03 sum (#02) grad[#03]
@@ -7206,6 +7498,7 @@ static void wsp_ggml_compute_forward_rms_norm_back_f32(
7206
7498
  // dx := scale(dx, rrms)
7207
7499
  float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
7208
7500
 
7501
+ // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
7209
7502
  wsp_ggml_vec_cpy_f32 (ne00, dx, x);
7210
7503
  // wsp_ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
7211
7504
  wsp_ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
@@ -7433,20 +7726,9 @@ static void wsp_ggml_compute_forward_mul_mat(
7433
7726
  const int ith = params->ith;
7434
7727
  const int nth = params->nth;
7435
7728
 
7436
- enum wsp_ggml_type type = src0->type;
7437
-
7438
- if (src0->buffer && wsp_ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
7439
- type = (enum wsp_ggml_type)(intptr_t)src0->extra;
7440
- }
7441
-
7442
- enum wsp_ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
7729
+ enum wsp_ggml_type const vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
7443
7730
  wsp_ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
7444
- wsp_ggml_from_float_to_mat_t const from_float_to_mat = type_traits_cpu[vec_dot_type].from_float_to_mat;
7445
- int64_t const vec_dot_num_rows = type_traits_cpu[type].nrows;
7446
- int64_t const matmul_num_cols = type_traits_cpu[type].ncols;
7447
- int64_t const blck_size_interleave = wsp_ggml_get_type_traits(type)->blck_size_interleave;
7448
- wsp_ggml_gemv_t const gemv = type_traits_cpu[type].gemv;
7449
- wsp_ggml_gemm_t const gemm = type_traits_cpu[type].gemm;
7731
+ int64_t const vec_dot_num_rows = type_traits_cpu[src0->type].nrows;
7450
7732
 
7451
7733
  WSP_GGML_ASSERT(ne0 == ne01);
7452
7734
  WSP_GGML_ASSERT(ne1 == ne11);
@@ -7454,7 +7736,7 @@ static void wsp_ggml_compute_forward_mul_mat(
7454
7736
  WSP_GGML_ASSERT(ne3 == ne13);
7455
7737
 
7456
7738
  // we don't support permuted src0 or src1
7457
- WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type));
7739
+ WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(src0->type));
7458
7740
  WSP_GGML_ASSERT(nb10 == wsp_ggml_type_size(src1->type));
7459
7741
 
7460
7742
  // dst cannot be transposed or permuted
@@ -7466,6 +7748,7 @@ static void wsp_ggml_compute_forward_mul_mat(
7466
7748
  // nb01 >= nb00 - src0 is not transposed
7467
7749
  // compute by src0 rows
7468
7750
 
7751
+ // TODO: extract to "extra_op"
7469
7752
  #if WSP_GGML_USE_LLAMAFILE
7470
7753
  // broadcast factors
7471
7754
  const int64_t r2 = ne12 / ne02;
@@ -7476,15 +7759,15 @@ static void wsp_ggml_compute_forward_mul_mat(
7476
7759
  if (src1_cont) {
7477
7760
  for (int64_t i13 = 0; i13 < ne13; i13++)
7478
7761
  for (int64_t i12 = 0; i12 < ne12; i12++)
7479
- if (!llamafile_sgemm(ne01, ne11, ne00/wsp_ggml_blck_size(type),
7762
+ if (!llamafile_sgemm(params,
7763
+ ne01, ne11, ne00/wsp_ggml_blck_size(src0->type),
7480
7764
  (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
7481
- nb01/wsp_ggml_type_size(type),
7765
+ nb01/wsp_ggml_type_size(src0->type),
7482
7766
  (const char *)src1->data + i12*nb12 + i13*nb13,
7483
7767
  nb11/wsp_ggml_type_size(src1->type),
7484
7768
  (char *)dst->data + i12*nb2 + i13*nb3,
7485
7769
  nb1/wsp_ggml_type_size(dst->type),
7486
- ith, nth,
7487
- type,
7770
+ src0->type,
7488
7771
  src1->type,
7489
7772
  dst->type))
7490
7773
  goto UseGgmlGemm1;
@@ -7496,6 +7779,7 @@ UseGgmlGemm1:;
7496
7779
  if (src1->type != vec_dot_type) {
7497
7780
  char * wdata = params->wdata;
7498
7781
 
7782
+ const size_t nbw0 = wsp_ggml_type_size(vec_dot_type);
7499
7783
  const size_t nbw1 = wsp_ggml_row_size(vec_dot_type, ne10);
7500
7784
  const size_t nbw2 = nbw1*ne11;
7501
7785
  const size_t nbw3 = nbw2*ne12;
@@ -7503,24 +7787,30 @@ UseGgmlGemm1:;
7503
7787
  assert(params->wsize >= ne13*nbw3);
7504
7788
  WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
7505
7789
 
7790
+ #if 0
7506
7791
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
7507
7792
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
7508
- int64_t i11_processed = 0;
7509
- if ((wsp_ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
7510
- for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
7511
- from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
7512
- (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
7513
- 4, ne10, blck_size_interleave);
7514
- }
7515
- i11_processed = ne11 - ne11 % 4;
7516
- }
7517
- for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
7793
+ for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
7518
7794
  from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
7519
- (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
7520
- ne10);
7795
+ (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
7796
+ ne10);
7797
+ }
7798
+ }
7799
+ }
7800
+ #else
7801
+ for (int64_t i13 = 0; i13 < ne13; ++i13) {
7802
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
7803
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
7804
+ size_t bs = wsp_ggml_blck_size(vec_dot_type);
7805
+ int64_t ne10_block_start = (ith * ne10/bs) / nth;
7806
+ int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth;
7807
+ from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
7808
+ (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
7809
+ (ne10_block_end - ne10_block_start) * bs);
7521
7810
  }
7522
7811
  }
7523
7812
  }
7813
+ #endif
7524
7814
  }
7525
7815
 
7526
7816
  if (ith == 0) {
@@ -7537,15 +7827,15 @@ UseGgmlGemm1:;
7537
7827
 
7538
7828
  for (int64_t i13 = 0; i13 < ne13; i13++)
7539
7829
  for (int64_t i12 = 0; i12 < ne12; i12++)
7540
- if (!llamafile_sgemm(ne01, ne11, ne00/wsp_ggml_blck_size(type),
7830
+ if (!llamafile_sgemm(params,
7831
+ ne01, ne11, ne00/wsp_ggml_blck_size(src0->type),
7541
7832
  (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
7542
- nb01/wsp_ggml_type_size(type),
7833
+ nb01/wsp_ggml_type_size(src0->type),
7543
7834
  (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
7544
7835
  row_size/wsp_ggml_type_size(vec_dot_type),
7545
7836
  (char *)dst->data + i12*nb2 + i13*nb3,
7546
7837
  nb1/wsp_ggml_type_size(dst->type),
7547
- ith, nth,
7548
- type,
7838
+ src0->type,
7549
7839
  vec_dot_type,
7550
7840
  dst->type))
7551
7841
  goto UseGgmlGemm2;
@@ -7560,14 +7850,6 @@ UseGgmlGemm2:;
7560
7850
  // This is the size of the rest of the dimensions of the result
7561
7851
  const int64_t nr1 = ne1 * ne2 * ne3;
7562
7852
 
7563
- // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
7564
- int64_t num_rows_per_vec_dot = vec_dot_num_rows;
7565
- // TODO: currently the mmla kernels support only even numbered rows/cols.
7566
- // this check can be removed once they are extended to support odd numbered rows/cols too
7567
- if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
7568
- num_rows_per_vec_dot = 1;
7569
- }
7570
-
7571
7853
  // Now select a reasonable chunk size.
7572
7854
  int chunk_size = 16;
7573
7855
 
@@ -7583,7 +7865,7 @@ UseGgmlGemm2:;
7583
7865
  int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
7584
7866
 
7585
7867
  // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread.
7586
- // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915
7868
+ // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggml-org/llama.cpp/pull/6915
7587
7869
  // In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that.
7588
7870
  if (nchunk0 * nchunk1 < nth * 4 || wsp_ggml_is_numa()) {
7589
7871
  // distribute the thread work across the inner or outer loop based on which one is larger
@@ -7595,28 +7877,6 @@ UseGgmlGemm2:;
7595
7877
  const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
7596
7878
  const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
7597
7879
 
7598
- if ((wsp_ggml_n_dims(src0) == 2) && gemv) {
7599
- const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
7600
- const size_t src1_col_stride = wsp_ggml_is_contiguous(src1) || src1->type != vec_dot_type ? wsp_ggml_row_size(vec_dot_type, ne10) : nb11;
7601
- int64_t src0_start = (ith * ne01) / nth;
7602
- int64_t src0_end = ((ith + 1) * ne01) / nth;
7603
- src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start;
7604
- src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end;
7605
- if (src0_start >= src0_end) return;
7606
-
7607
- // If there are more than three rows in src1, use gemm; otherwise, use gemv.
7608
- if (gemm && (ne11 > 3)) {
7609
- gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
7610
- (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
7611
- }
7612
- for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) {
7613
- gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
7614
- (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
7615
- src0_end - src0_start);
7616
- }
7617
- return;
7618
- }
7619
-
7620
7880
  // The first chunk comes from our thread_id, the rest will get auto-assigned.
7621
7881
  int current_chunk = ith;
7622
7882
 
@@ -7630,7 +7890,15 @@ UseGgmlGemm2:;
7630
7890
  const int64_t ir1_start = dr1 * ith1;
7631
7891
  const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
7632
7892
 
7633
- wsp_ggml_compute_forward_mul_mat_one_chunk(params, dst, type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
7893
+ // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
7894
+ int64_t num_rows_per_vec_dot = vec_dot_num_rows;
7895
+
7896
+ // these checks are needed to avoid crossing dim1 boundaries
7897
+ // can be optimized, but the logic would become more complicated, so keeping it like this for simplicity
7898
+ if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
7899
+ num_rows_per_vec_dot = 1;
7900
+ }
7901
+ wsp_ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
7634
7902
 
7635
7903
  if (nth >= nchunk0 * nchunk1) {
7636
7904
  break;
@@ -7642,6 +7910,84 @@ UseGgmlGemm2:;
7642
7910
 
7643
7911
  // wsp_ggml_compute_forward_mul_mat_id
7644
7912
 
7913
+ #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ids->ne[0]*ids->ne[1] + (i1)]
7914
+
7915
+ struct mmid_row_mapping {
7916
+ int32_t i1;
7917
+ int32_t i2;
7918
+ };
7919
+
7920
+ static void wsp_ggml_compute_forward_mul_mat_id_one_chunk(
7921
+ struct wsp_ggml_tensor * dst,
7922
+ const struct wsp_ggml_tensor * src0,
7923
+ const struct wsp_ggml_tensor * src1,
7924
+ const struct wsp_ggml_tensor * ids,
7925
+ const int64_t cur_a,
7926
+ const int64_t ir0_start,
7927
+ const int64_t ir0_end,
7928
+ const int64_t ir1_start,
7929
+ const int64_t ir1_end,
7930
+ const char * src0_cur,
7931
+ const struct mmid_row_mapping * matrix_rows,
7932
+ const size_t row_size,
7933
+ const bool src1_cont,
7934
+ const void * wdata) {
7935
+
7936
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS
7937
+
7938
+ const enum wsp_ggml_type type = src0->type;
7939
+
7940
+ wsp_ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
7941
+ enum wsp_ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
7942
+
7943
+ const int64_t blck_0 = 16;
7944
+ const int64_t blck_1 = 16;
7945
+
7946
+ float tmp[16];
7947
+
7948
+ for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
7949
+ for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
7950
+ for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ++ir1) {
7951
+ const int64_t _i12 = ir1; // logical row index for this expert
7952
+
7953
+ struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
7954
+ const int id = row_mapping.i1; // selected expert index
7955
+
7956
+ const int64_t i11 = id % ne11;
7957
+ const int64_t i12 = row_mapping.i2; // row index in src1
7958
+
7959
+ const int64_t i1 = id; // selected expert index
7960
+ const int64_t i2 = i12; // row
7961
+
7962
+ // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
7963
+ // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
7964
+ // the original src1 data pointer, so we should index using the indices directly
7965
+ // TODO: this is a bit of a hack, we should probably have a better way to handle this
7966
+ const char * src1_col = (const char *) wdata +
7967
+ (src1_cont || src1->type != vec_dot_type
7968
+ ? (i11 + i12*ne11)*row_size
7969
+ : (i11*nb11 + i12*nb12));
7970
+
7971
+ float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
7972
+
7973
+ for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
7974
+ vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
7975
+ }
7976
+
7977
+ memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float));
7978
+ }
7979
+ }
7980
+ }
7981
+ }
7982
+
7983
+ static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
7984
+
7985
+ void * ptr = *p;
7986
+ ptr = (void *) WSP_GGML_PAD((uintptr_t) ptr, align);
7987
+ *p = (void *) ((char *) ptr + size);
7988
+ return ptr;
7989
+ }
7990
+
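
A minimal usage sketch of the incr_ptr_aligned helper added above (not library code; bump_aligned mirrors the helper and the sizes and names are illustrative): one scratch buffer is carved into aligned sub-arrays, and the consumed total is checked against the planned work-buffer size, which is how the MUL_MAT_ID path below lays out its row counts, row mappings, and per-expert chunk counters.

    #include <assert.h>
    #include <stddef.h>
    #include <stdint.h>

    // round p up to a multiple of align (align must be a power of two)
    #define ALIGN_UP(p, align) (((p) + (align) - 1) & ~((uintptr_t)(align) - 1))

    static void * bump_aligned(void ** p, size_t size, size_t align) {
        void * ptr = (void *) ALIGN_UP((uintptr_t) *p, align);
        *p = (char *) ptr + size;
        return ptr;
    }

    static void layout_example(void * wdata, size_t wsize, int n_as, int n_rows) {
        void * cur = wdata;

        // per-expert row counts, row mappings, and one cache line of chunk counters per expert
        int64_t * row_counts = bump_aligned(&cur, n_as*sizeof(int64_t), sizeof(int64_t));
        int32_t * row_map    = bump_aligned(&cur, (size_t)n_as*n_rows*2*sizeof(int32_t), sizeof(int64_t));
        char    * chunk_ctrs = bump_aligned(&cur, (size_t)n_as*64, 64);

        // the graph planner must have reserved at least this much
        assert((size_t)((char *) cur - (char *) wdata) <= wsize);
        (void) row_counts; (void) row_map; (void) chunk_ctrs;
    }
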
7645
7991
  static void wsp_ggml_compute_forward_mul_mat_id(
7646
7992
  const struct wsp_ggml_compute_params * params,
7647
7993
  struct wsp_ggml_tensor * dst) {
@@ -7659,11 +8005,8 @@ static void wsp_ggml_compute_forward_mul_mat_id(
7659
8005
 
7660
8006
  const bool src1_cont = wsp_ggml_is_contiguous(src1);
7661
8007
 
7662
- wsp_ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
7663
8008
  enum wsp_ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
7664
8009
  wsp_ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
7665
- int64_t const matmul_num_cols = type_traits_cpu[type].ncols;
7666
- wsp_ggml_gemv_t const gemv = type_traits_cpu[type].gemv;
7667
8010
 
7668
8011
  // we don't support permuted src0 or src1
7669
8012
  WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type));
@@ -7679,21 +8022,27 @@ static void wsp_ggml_compute_forward_mul_mat_id(
7679
8022
  const int n_ids = ids->ne[0]; // n_expert_used
7680
8023
  const int n_as = ne02; // n_expert
7681
8024
 
7682
- char * wdata_src1_end = (src1->type == vec_dot_type) ?
7683
- (char *) params->wdata :
7684
- (char *) params->wdata + WSP_GGML_PAD(wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(src1)), sizeof(int64_t));
8025
+ void * wdata_cur = params->wdata;
7685
8026
 
7686
- struct mmid_row_mapping {
7687
- int32_t i1;
7688
- int32_t i2;
7689
- };
8027
+ if (src1->type != vec_dot_type) {
8028
+ incr_ptr_aligned(&wdata_cur, wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(src1)), sizeof(int64_t));
8029
+ }
8030
+
8031
+ int64_t * matrix_row_counts = // [n_as]
8032
+ incr_ptr_aligned(&wdata_cur, n_as*sizeof(int64_t), sizeof(int64_t));
8033
+
8034
+ struct mmid_row_mapping * matrix_rows = // [n_as][ids->ne[0]*ids->ne[1]]
8035
+ incr_ptr_aligned(&wdata_cur, n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping), sizeof(int64_t));
7690
8036
 
7691
- int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
7692
- struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
8037
+ char (*atomic_current_chunk)[CACHE_LINE_SIZE] = // [n_as]
8038
+ incr_ptr_aligned(&wdata_cur, CACHE_LINE_SIZE * n_as, CACHE_LINE_SIZE);
8039
+
8040
+ WSP_GGML_ASSERT(params->wsize >= (size_t)((char *) wdata_cur - (char *) params->wdata));
7693
8041
 
7694
8042
  if (src1->type != vec_dot_type) {
7695
8043
  char * wdata = params->wdata;
7696
8044
 
8045
+ const size_t nbw0 = wsp_ggml_type_size(vec_dot_type);
7697
8046
  const size_t nbw1 = wsp_ggml_row_size(vec_dot_type, ne10);
7698
8047
  const size_t nbw2 = nbw1*ne11;
7699
8048
  const size_t nbw3 = nbw2*ne12;
@@ -7701,19 +8050,32 @@ static void wsp_ggml_compute_forward_mul_mat_id(
7701
8050
  assert(params->wsize >= ne13*nbw3);
7702
8051
  WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
7703
8052
 
8053
+ #if 0
7704
8054
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
7705
- for (int64_t i12 = 0; i12 < ne12; ++i12) {
7706
- for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
8055
+ for (int64_t i12 = ith; i12 < ne12; i12 += nth) {
8056
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
7707
8057
  from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
7708
8058
  (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
7709
8059
  ne10);
7710
8060
  }
7711
8061
  }
7712
8062
  }
8063
+ #else
8064
+ for (int64_t i13 = 0; i13 < ne13; ++i13) {
8065
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
8066
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
8067
+ size_t bs = wsp_ggml_blck_size(vec_dot_type);
8068
+ int64_t ne10_block_start = (ith * ne10/bs) / nth;
8069
+ int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth;
8070
+ from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
8071
+ (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
8072
+ (ne10_block_end - ne10_block_start) * bs);
8073
+ }
8074
+ }
8075
+ }
8076
+ #endif
7713
8077
  }
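
The new conversion loop above no longer hands whole rows to each thread; instead every src1 row is split between threads along ne10 in units of the quantization block size, so all threads stay busy even when the number of rows is small. A small sketch of that split (function name and signature are illustrative, not library code):

    #include <stdint.h>

    // [start, end) element range of a row that thread ith of nth should convert;
    // ne10 is assumed to be a multiple of blck_size
    static void row_split(int ith, int nth, int64_t ne10, int64_t blck_size,
                          int64_t * start, int64_t * end) {
        const int64_t nblocks = ne10/blck_size;
        *start = ((int64_t) ith     *nblocks)/nth*blck_size;
        *end   = ((int64_t)(ith + 1)*nblocks)/nth*blck_size;
    }
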
7714
8078
 
7715
- #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
7716
-
7717
8079
  if (ith == 0) {
7718
8080
  // initialize matrix_row_counts
7719
8081
  memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
@@ -7731,9 +8093,14 @@ static void wsp_ggml_compute_forward_mul_mat_id(
7731
8093
  }
7732
8094
  }
7733
8095
 
8096
+ // reset current_chunk
8097
+ for (int cur_a = ith; cur_a < n_as; cur_a += nth) {
8098
+ atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
8099
+ *current_chunk_ctr = nth;
8100
+ }
8101
+
7734
8102
  wsp_ggml_barrier(params->threadpool);
7735
8103
 
7736
- // compute each matrix multiplication in sequence
7737
8104
  for (int cur_a = 0; cur_a < n_as; ++cur_a) {
7738
8105
  const int64_t cne1 = matrix_row_counts[cur_a];
7739
8106
 
@@ -7741,112 +8108,64 @@ static void wsp_ggml_compute_forward_mul_mat_id(
7741
8108
  continue;
7742
8109
  }
7743
8110
 
7744
- const char * src0_cur = (const char *) src0->data + cur_a*nb02;
7745
-
7746
- const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
8111
+ const char * src0_cur = (const char *) src0->data + cur_a * nb02;
8112
+ const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
7747
8113
  const size_t row_size = wsp_ggml_row_size(vec_dot_type, ne10);
7748
8114
 
7749
- const int64_t nr0 = ne01; // src0 rows
7750
- const int64_t nr1 = cne1; // src1 rows
7751
-
7752
- if (((wsp_ggml_n_dims(src0) - 1) == 2) && gemv) {
7753
- int64_t src0_cur_start = (ith * ne01) / nth;
7754
- int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
7755
- src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start;
7756
- src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end;
7757
- if (src0_cur_start >= src0_cur_end) return;
7758
-
7759
- for (int ir1 = 0; ir1 < nr1; ir1++) {
7760
- struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
7761
- const int id = row_mapping.i1; // selected expert index
7762
-
7763
- const int64_t i11 = id % ne11;
7764
- const int64_t i12 = row_mapping.i2; // row index in src1
7765
-
7766
- const int64_t i1 = id; // selected expert index
7767
- const int64_t i2 = i12; // row
7768
-
7769
- const char * src1_col = (const char *) wdata +
7770
- (src1_cont || src1->type != vec_dot_type
7771
- ? (i11 + i12 * ne11) * row_size
7772
- : (i11 * nb11 + i12 * nb12));
8115
+ const int64_t nr0 = ne01;
8116
+ const int64_t nr1 = cne1;
7773
8117
 
7774
- gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
7775
- (const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
7776
- }
7777
- continue;
8118
+ int chunk_size = 16;
8119
+ if (nr0 == 1 || nr1 == 1) {
8120
+ chunk_size = 64;
7778
8121
  }
7779
8122
 
7780
- // distribute the thread work across the inner or outer loop based on which one is larger
7781
-
7782
- const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
7783
- const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
7784
-
7785
- const int64_t ith0 = ith % nth0;
7786
- const int64_t ith1 = ith / nth0;
7787
-
7788
- const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
7789
- const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
7790
-
7791
- const int64_t ir010 = dr0*ith0;
7792
- const int64_t ir011 = MIN(ir010 + dr0, nr0);
7793
-
7794
- const int64_t ir110 = dr1*ith1;
7795
- const int64_t ir111 = MIN(ir110 + dr1, nr1);
7796
-
7797
- // threads with no work simply yield (not sure if it helps)
7798
- //if (ir010 >= ir011 || ir110 >= ir111) {
7799
- // sched_yield();
7800
- // continue;
7801
- //}
7802
-
7803
- // block-tiling attempt
7804
- const int64_t blck_0 = 16;
7805
- const int64_t blck_1 = 16;
8123
+ #if defined(__aarch64__)
8124
+ // disable for ARM
8125
+ const bool disable_chunking = true;
8126
+ #else
8127
+ // disable for NUMA
8128
+ const bool disable_chunking = wsp_ggml_is_numa();
8129
+ #endif // defined(__aarch64__)
7806
8130
 
7807
- // attempt to reduce false-sharing (does not seem to make a difference)
7808
- float tmp[16];
8131
+ int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
8132
+ int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
7809
8133
 
7810
- for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
7811
- for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
7812
- for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
7813
- const int64_t _i12 = ir1; // logical row index for this expert
8134
+ if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) {
8135
+ nchunk0 = nr0 > nr1 ? nth : 1;
8136
+ nchunk1 = nr0 > nr1 ? 1 : nth;
8137
+ }
7814
8138
 
7815
- struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
7816
- const int id = row_mapping.i1; // selected expert index
8139
+ const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
8140
+ const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
7817
8141
 
7818
- const int64_t i11 = id % ne11;
7819
- const int64_t i12 = row_mapping.i2; // row index in src1
8142
+ int current_chunk = ith;
7820
8143
 
7821
- const int64_t i1 = id; // selected expert index
7822
- const int64_t i2 = i12; // row
8144
+ atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
7823
8145
 
7824
- // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
7825
- // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
7826
- // the original src1 data pointer, so we should index using the indices directly
7827
- // TODO: this is a bit of a hack, we should probably have a better way to handle this
7828
- const char * src1_col = (const char *) wdata +
7829
- (src1_cont || src1->type != vec_dot_type
7830
- ? (i11 + i12*ne11)*row_size
7831
- : (i11*nb11 + i12*nb12));
8146
+ while (current_chunk < nchunk0 * nchunk1) {
8147
+ const int64_t ith0 = current_chunk % nchunk0;
8148
+ const int64_t ith1 = current_chunk / nchunk0;
7832
8149
 
7833
- float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
8150
+ const int64_t ir0_start = dr0 * ith0;
8151
+ const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
7834
8152
 
7835
- //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
7836
- // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
7837
- //}
8153
+ const int64_t ir1_start = dr1 * ith1;
8154
+ const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
7838
8155
 
7839
- for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
7840
- vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
7841
- }
8156
+ wsp_ggml_compute_forward_mul_mat_id_one_chunk(
8157
+ dst, src0, src1, ids, cur_a,
8158
+ ir0_start, ir0_end, ir1_start, ir1_end,
8159
+ src0_cur, matrix_rows, row_size, src1_cont, wdata
8160
+ );
7842
8161
 
7843
- memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
7844
- }
8162
+ if (nth >= nchunk0 * nchunk1) {
8163
+ break;
7845
8164
  }
8165
+
8166
+ current_chunk = atomic_fetch_add_explicit(current_chunk_ctr, 1, memory_order_relaxed);
7846
8167
  }
7847
8168
  }
7848
-
7849
- #undef MMID_MATRIX_ROW
7850
8169
  }
7851
8170
 
7852
8171
  // wsp_ggml_compute_forward_out_prod
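
For context on the mul_mat_id changes above: the rough sketch below (not library code, and the population loop itself is outside this hunk, so this is an approximation based on the MMID_MATRIX_ROW layout) shows how src1 rows are grouped by expert before the per-expert multiplications. ids[i1][i0] selects the expert for slot i0 of token i1, and every (slot, token) pair is appended to that expert's bucket.

    #include <stdint.h>
    #include <string.h>

    struct row_mapping { int32_t i1; int32_t i2; };   // (expert slot, token index)

    static void group_rows_by_expert(const int32_t * ids, int n_ids, int n_tokens, int n_expert,
                                     int64_t * counts,             // [n_expert]
                                     struct row_mapping * rows) {  // [n_expert][n_ids*n_tokens]
        memset(counts, 0, n_expert*sizeof(int64_t));
        for (int i1 = 0; i1 < n_tokens; ++i1) {
            for (int i0 = 0; i0 < n_ids; ++i0) {
                const int32_t e = ids[i1*n_ids + i0];               // chosen expert
                rows[e*(n_ids*n_tokens) + counts[e]++] =
                    (struct row_mapping) { .i1 = i0, .i2 = i1 };
            }
        }
    }
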
@@ -7867,12 +8186,13 @@ static void wsp_ggml_compute_forward_out_prod_f32(
7867
8186
  const int ith = params->ith;
7868
8187
  const int nth = params->nth;
7869
8188
 
7870
- WSP_GGML_ASSERT(ne0 == ne00);
7871
- WSP_GGML_ASSERT(ne1 == ne10);
7872
- WSP_GGML_ASSERT(ne2 == ne02);
7873
- WSP_GGML_ASSERT(ne02 == ne12);
7874
- WSP_GGML_ASSERT(ne3 == ne13);
7875
- WSP_GGML_ASSERT(ne03 == ne13);
8189
+ WSP_GGML_ASSERT(ne0 == ne00);
8190
+ WSP_GGML_ASSERT(ne1 == ne10);
8191
+ WSP_GGML_ASSERT(ne2 == ne12);
8192
+ WSP_GGML_ASSERT(ne3 == ne13);
8193
+
8194
+ WSP_GGML_ASSERT(ne2 % ne02 == 0);
8195
+ WSP_GGML_ASSERT(ne3 % ne03 == 0);
7876
8196
 
7877
8197
  // we don't support permuted src0 or src1
7878
8198
  WSP_GGML_ASSERT(nb00 == sizeof(float));
@@ -7914,6 +8234,10 @@ static void wsp_ggml_compute_forward_out_prod_f32(
7914
8234
  const int64_t blck_0 = MAX(WSP_GGML_VEC_MAD_UNROLL, 32);
7915
8235
  const int64_t blck_1 = 16;
7916
8236
 
8237
+ // dps == dst per src0, used for group query attention
8238
+ const int64_t dps2 = ne2 / ne02;
8239
+ const int64_t dps3 = ne3 / ne03;
8240
+
7917
8241
  for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
7918
8242
  const int64_t bir1 = MIN(bir + blck_1, ir1);
7919
8243
  for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
@@ -7924,8 +8248,8 @@ static void wsp_ggml_compute_forward_out_prod_f32(
7924
8248
  const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
7925
8249
  const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
7926
8250
 
7927
- const int64_t i02 = i2;
7928
- const int64_t i03 = i3;
8251
+ const int64_t i02 = i2 / dps2;
8252
+ const int64_t i03 = i3 / dps3;
7929
8253
 
7930
8254
  //const int64_t i10 = i1;
7931
8255
  const int64_t i12 = i2;
@@ -7938,7 +8262,7 @@ static void wsp_ggml_compute_forward_out_prod_f32(
7938
8262
 
7939
8263
  float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
7940
8264
  float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
7941
- float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
8265
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
7942
8266
 
7943
8267
  wsp_ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
7944
8268
  }
@@ -7947,7 +8271,7 @@ static void wsp_ggml_compute_forward_out_prod_f32(
7947
8271
 
7948
8272
  float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
7949
8273
  float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
7950
- float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
8274
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
7951
8275
 
7952
8276
  wsp_ggml_vec_mad_f32(ne0, d, s0, *s1);
7953
8277
  }
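
A small illustration of the broadcast indexing introduced in this out_prod hunk (function name is illustrative): with dps2 = ne2/ne02 destination channels per src0 channel, every group of dps2 dst channels reads from the same src0 channel, which is what grouped-query attention needs when one K/V head serves several Q heads.

    #include <stdint.h>

    static inline int64_t src0_channel_for(int64_t i2, int64_t ne2, int64_t ne02) {
        const int64_t dps2 = ne2/ne02;   // requires ne2 % ne02 == 0, as asserted above
        return i2/dps2;
    }
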
@@ -8084,9 +8408,6 @@ static void wsp_ggml_compute_forward_out_prod(
8084
8408
  case WSP_GGML_TYPE_IQ4_XS:
8085
8409
  case WSP_GGML_TYPE_IQ3_S:
8086
8410
  case WSP_GGML_TYPE_IQ2_S:
8087
- case WSP_GGML_TYPE_Q4_0_4_4:
8088
- case WSP_GGML_TYPE_Q4_0_4_8:
8089
- case WSP_GGML_TYPE_Q4_0_8_8:
8090
8411
  {
8091
8412
  wsp_ggml_compute_forward_out_prod_q_f32(params, dst);
8092
8413
  } break;
@@ -8239,6 +8560,77 @@ static void wsp_ggml_compute_forward_set_f32(
8239
8560
  }
8240
8561
  }
8241
8562
 
8563
+ static void wsp_ggml_compute_forward_set_i32(
8564
+ const struct wsp_ggml_compute_params * params,
8565
+ struct wsp_ggml_tensor * dst) {
8566
+
8567
+ const struct wsp_ggml_tensor * src0 = dst->src[0];
8568
+ const struct wsp_ggml_tensor * src1 = dst->src[1];
8569
+
8570
+ WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst));
8571
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst) && wsp_ggml_is_contiguous(src0));
8572
+
8573
+ // view src0 and dst with these strides and data offset inbytes during set
8574
+ // nb0 is implicitly element_size because src0 and dst are contiguous
8575
+ size_t nb1 = ((int32_t *) dst->op_params)[0];
8576
+ size_t nb2 = ((int32_t *) dst->op_params)[1];
8577
+ size_t nb3 = ((int32_t *) dst->op_params)[2];
8578
+ size_t offset = ((int32_t *) dst->op_params)[3];
8579
+ bool inplace = (bool) ((int32_t *) dst->op_params)[4];
8580
+
8581
+ if (!inplace) {
8582
+ if (params->ith == 0) {
8583
+ // memcpy needs to be synchronized across threads to avoid race conditions.
8584
+ // => do it in INIT phase
8585
+ memcpy(
8586
+ ((char *) dst->data),
8587
+ ((char *) src0->data),
8588
+ wsp_ggml_nbytes(dst));
8589
+ }
8590
+ wsp_ggml_barrier(params->threadpool);
8591
+ }
8592
+
8593
+ const int ith = params->ith;
8594
+ const int nth = params->nth;
8595
+
8596
+ const int nr = wsp_ggml_nrows(src1);
8597
+ const int nc = src1->ne[0];
8598
+
8599
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
8600
+ WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
8601
+
8602
+ // src0 and dst as viewed during set
8603
+ const size_t nb0 = wsp_ggml_element_size(src0);
8604
+
8605
+ const int im0 = (ne10 == 0 ? 0 : ne10-1);
8606
+ const int im1 = (ne11 == 0 ? 0 : ne11-1);
8607
+ const int im2 = (ne12 == 0 ? 0 : ne12-1);
8608
+ const int im3 = (ne13 == 0 ? 0 : ne13-1);
8609
+
8610
+ WSP_GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= wsp_ggml_nbytes(dst));
8611
+
8612
+ WSP_GGML_ASSERT(nb10 == sizeof(int32_t));
8613
+
8614
+ // rows per thread
8615
+ const int dr = (nr + nth - 1)/nth;
8616
+
8617
+ // row range for this thread
8618
+ const int ir0 = dr*ith;
8619
+ const int ir1 = MIN(ir0 + dr, nr);
8620
+
8621
+ for (int ir = ir0; ir < ir1; ++ir) {
8622
+ // src0 and dst are viewed with shape of src1 and offset
8623
+ // => same indices
8624
+ const int i3 = ir/(ne12*ne11);
8625
+ const int i2 = (ir - i3*ne12*ne11)/ne11;
8626
+ const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
8627
+
8628
+ wsp_ggml_vec_cpy_i32(nc,
8629
+ (int32_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
8630
+ (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
8631
+ }
8632
+ }
8633
+
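
The new I32 path follows the same parameter layout as the existing F32 SET kernel: dst is viewed with the strides and byte offset stored in op_params and src1 is copied into that view. A sketch of the decode step (struct and function names are illustrative, not library code):

    #include <stdbool.h>
    #include <stddef.h>
    #include <stdint.h>

    struct set_view { size_t nb1, nb2, nb3, offset; bool inplace; };

    static struct set_view decode_set_params(const int32_t * op_params) {
        struct set_view v;
        v.nb1     = (size_t) op_params[0];   // view strides of dst, in bytes
        v.nb2     = (size_t) op_params[1];
        v.nb3     = (size_t) op_params[2];
        v.offset  = (size_t) op_params[3];   // byte offset of the view into dst
        v.inplace = op_params[4] != 0;       // if false, src0 is first copied into dst
        return v;
    }
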
8242
8634
  static void wsp_ggml_compute_forward_set(
8243
8635
  const struct wsp_ggml_compute_params * params,
8244
8636
  struct wsp_ggml_tensor * dst) {
@@ -8250,6 +8642,10 @@ static void wsp_ggml_compute_forward_set(
8250
8642
  {
8251
8643
  wsp_ggml_compute_forward_set_f32(params, dst);
8252
8644
  } break;
8645
+ case WSP_GGML_TYPE_I32:
8646
+ {
8647
+ wsp_ggml_compute_forward_set_i32(params, dst);
8648
+ } break;
8253
8649
  case WSP_GGML_TYPE_F16:
8254
8650
  case WSP_GGML_TYPE_BF16:
8255
8651
  case WSP_GGML_TYPE_Q4_0:
@@ -8274,9 +8670,6 @@ static void wsp_ggml_compute_forward_set(
8274
8670
  case WSP_GGML_TYPE_IQ4_XS:
8275
8671
  case WSP_GGML_TYPE_IQ3_S:
8276
8672
  case WSP_GGML_TYPE_IQ2_S:
8277
- case WSP_GGML_TYPE_Q4_0_4_4:
8278
- case WSP_GGML_TYPE_Q4_0_4_8:
8279
- case WSP_GGML_TYPE_Q4_0_8_8:
8280
8673
  default:
8281
8674
  {
8282
8675
  WSP_GGML_ABORT("fatal error");
@@ -8538,9 +8931,6 @@ static void wsp_ggml_compute_forward_get_rows(
8538
8931
  case WSP_GGML_TYPE_IQ4_XS:
8539
8932
  case WSP_GGML_TYPE_IQ3_S:
8540
8933
  case WSP_GGML_TYPE_IQ2_S:
8541
- case WSP_GGML_TYPE_Q4_0_4_4:
8542
- case WSP_GGML_TYPE_Q4_0_4_8:
8543
- case WSP_GGML_TYPE_Q4_0_8_8:
8544
8934
  {
8545
8935
  wsp_ggml_compute_forward_get_rows_q(params, dst);
8546
8936
  } break;
@@ -8957,9 +9347,9 @@ static void wsp_ggml_compute_forward_soft_max(
8957
9347
  }
8958
9348
 
8959
9349
 
8960
- // wsp_ggml_compute_forward_soft_max_back
9350
+ // wsp_ggml_compute_forward_soft_max_ext_back
8961
9351
 
8962
- static void wsp_ggml_compute_forward_soft_max_back_f32(
9352
+ static void wsp_ggml_compute_forward_soft_max_ext_back_f32(
8963
9353
  const struct wsp_ggml_compute_params * params,
8964
9354
  struct wsp_ggml_tensor * dst) {
8965
9355
 
@@ -8972,6 +9362,14 @@ static void wsp_ggml_compute_forward_soft_max_back_f32(
8972
9362
  WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst));
8973
9363
  WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src1, dst));
8974
9364
 
9365
+ float scale = 1.0f;
9366
+ float max_bias = 0.0f;
9367
+
9368
+ memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
9369
+ memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
9370
+
9371
+ WSP_GGML_ASSERT(max_bias == 0.0f);
9372
+
8975
9373
  // TODO: handle transposed/permuted matrices
8976
9374
 
8977
9375
  const int ith = params->ith;
@@ -9020,10 +9418,11 @@ static void wsp_ggml_compute_forward_soft_max_back_f32(
9020
9418
 
9021
9419
  // linear runtime, no additional memory
9022
9420
  float dot_y_dy = 0;
9023
- wsp_ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
9024
- wsp_ggml_vec_cpy_f32 (nc, dx, dy);
9025
- wsp_ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
9026
- wsp_ggml_vec_mul_f32 (nc, dx, dx, y);
9421
+ wsp_ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
9422
+ wsp_ggml_vec_cpy_f32 (nc, dx, dy);
9423
+ wsp_ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
9424
+ wsp_ggml_vec_mul_f32 (nc, dx, dx, y);
9425
+ wsp_ggml_vec_scale_f32(nc, dx, scale);
9027
9426
 
9028
9427
  #ifndef NDEBUG
9029
9428
  for (int i = 0; i < nc; ++i) {
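
In short, the renamed backward kernel now honors the scale stored in the forward op's parameters: with y = softmax(scale * x) and dy the incoming gradient, the lines above compute dx = scale * y * (dy - dot(y, dy)) row by row; max_bias (ALiBi) remains unsupported here, hence the assertion that it is zero.
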
@@ -9034,7 +9433,7 @@ static void wsp_ggml_compute_forward_soft_max_back_f32(
9034
9433
  }
9035
9434
  }
9036
9435
 
9037
- static void wsp_ggml_compute_forward_soft_max_back(
9436
+ static void wsp_ggml_compute_forward_soft_max_ext_back(
9038
9437
  const struct wsp_ggml_compute_params * params,
9039
9438
  struct wsp_ggml_tensor * dst) {
9040
9439
 
@@ -9043,7 +9442,7 @@ static void wsp_ggml_compute_forward_soft_max_back(
9043
9442
  switch (src0->type) {
9044
9443
  case WSP_GGML_TYPE_F32:
9045
9444
  {
9046
- wsp_ggml_compute_forward_soft_max_back_f32(params, dst);
9445
+ wsp_ggml_compute_forward_soft_max_ext_back_f32(params, dst);
9047
9446
  } break;
9048
9447
  default:
9049
9448
  {
@@ -9060,10 +9459,6 @@ static void wsp_ggml_compute_forward_clamp_f32(
9060
9459
 
9061
9460
  const struct wsp_ggml_tensor * src0 = dst->src[0];
9062
9461
 
9063
- if (params->ith != 0) {
9064
- return;
9065
- }
9066
-
9067
9462
  float min;
9068
9463
  float max;
9069
9464
  memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
@@ -9130,9 +9525,6 @@ static void wsp_ggml_compute_forward_clamp(
9130
9525
  case WSP_GGML_TYPE_IQ3_S:
9131
9526
  case WSP_GGML_TYPE_IQ2_S:
9132
9527
  case WSP_GGML_TYPE_Q8_K:
9133
- case WSP_GGML_TYPE_Q4_0_4_4:
9134
- case WSP_GGML_TYPE_Q4_0_4_8:
9135
- case WSP_GGML_TYPE_Q4_0_8_8:
9136
9528
  case WSP_GGML_TYPE_I8:
9137
9529
  case WSP_GGML_TYPE_I16:
9138
9530
  case WSP_GGML_TYPE_I32:
@@ -9187,6 +9579,64 @@ static void wsp_ggml_rope_cache_init(
9187
9579
  }
9188
9580
  }
9189
9581
 
9582
+ static void wsp_ggml_mrope_cache_init(
9583
+ float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
9584
+ float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
9585
+ float * cache, float sin_sign, float theta_scale) {
9586
+ // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
9587
+ float theta_t = theta_base_t;
9588
+ float theta_h = theta_base_h;
9589
+ float theta_w = theta_base_w;
9590
+ float theta_e = theta_base_e; // extra position id for vision encoder
9591
+ int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
9592
+ int sec_w = sections[1] + sections[0];
9593
+ int sec_e = sections[2] + sec_w;
9594
+ WSP_GGML_ASSERT(sect_dims <= ne0);
9595
+
9596
+ for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
9597
+ const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
9598
+
9599
+ int sector = (i0 / 2) % sect_dims;
9600
+ if (indep_sects) {
9601
+ // compute theta independently for each dim sections
9602
+ // (i.e. reset corresponding theta when `i0` go from one section to another)
9603
+ if (sector == 0) {
9604
+ theta_t = theta_base_t;
9605
+ }
9606
+ else if (sector == sections[0]) {
9607
+ theta_h = theta_base_h;;
9608
+ }
9609
+ else if (sector == sec_w) {
9610
+ theta_w = theta_base_w;
9611
+ }
9612
+ else if (sector == sec_e) {
9613
+ theta_e = theta_base_e;
9614
+ }
9615
+ }
9616
+
9617
+ float theta = theta_t;
9618
+ if (sector >= sections[0] && sector < sec_w) {
9619
+ theta = theta_h;
9620
+ }
9621
+ else if (sector >= sec_w && sector < sec_w + sections[2]) {
9622
+ theta = theta_w;
9623
+ }
9624
+ else if (sector >= sec_w + sections[2]) {
9625
+ theta = theta_e;
9626
+ }
9627
+
9628
+ rope_yarn(
9629
+ theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
9630
+ );
9631
+ cache[i0 + 1] *= sin_sign;
9632
+
9633
+ theta_t *= theta_scale;
9634
+ theta_w *= theta_scale;
9635
+ theta_h *= theta_scale;
9636
+ theta_e *= theta_scale;
9637
+ }
9638
+ }
9639
+
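
A sketch of the section lookup performed by wsp_ggml_mrope_cache_init above (enum and function names are illustrative, not library code). For M-RoPE the position input carries four streams laid out back to back, read in the callers below as pos[i2], pos[i2 + ne2], pos[i2 + 2*ne2] and pos[i2 + 3*ne2], and the rotary dimensions are split into four consecutive sections whose widths come from sections[4]; each dim pair picks its base angle from the section it falls into.

    #include <stdint.h>

    typedef enum { POS_T, POS_H, POS_W, POS_E } pos_kind;  // time, height, width, extra

    static pos_kind mrope_section_of(int64_t i0, const int sections[4]) {
        const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
        const int sec_w     = sections[0] + sections[1];
        const int sector    = (int)(i0/2) % sect_dims;

        if (sector < sections[0])         return POS_T;
        if (sector < sec_w)               return POS_H;
        if (sector < sec_w + sections[2]) return POS_W;
        return POS_E;
    }
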
9190
9640
  static void wsp_ggml_compute_forward_rope_f32(
9191
9641
  const struct wsp_ggml_compute_params * params,
9192
9642
  struct wsp_ggml_tensor * dst,
@@ -9197,6 +9647,7 @@ static void wsp_ggml_compute_forward_rope_f32(
9197
9647
  const struct wsp_ggml_tensor * src2 = dst->src[2];
9198
9648
 
9199
9649
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
9650
+ int sections[4];
9200
9651
 
9201
9652
  //const int n_past = ((int32_t *) dst->op_params)[0];
9202
9653
  const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -9210,6 +9661,7 @@ static void wsp_ggml_compute_forward_rope_f32(
9210
9661
  memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
9211
9662
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
9212
9663
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
9664
+ memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
9213
9665
 
9214
9666
  WSP_GGML_TENSOR_UNARY_OP_LOCALS
9215
9667
 
@@ -9242,6 +9694,16 @@ static void wsp_ggml_compute_forward_rope_f32(
9242
9694
  wsp_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
9243
9695
 
9244
9696
  const bool is_neox = mode & WSP_GGML_ROPE_TYPE_NEOX;
9697
+ const bool is_mrope = mode & WSP_GGML_ROPE_TYPE_MROPE; // wsp_ggml_rope_multi, multimodal rotary position embedding
9698
+ const bool is_vision = mode == WSP_GGML_ROPE_TYPE_VISION;
9699
+
9700
+ if (is_mrope) {
9701
+ WSP_GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
9702
+ }
9703
+
9704
+ if (is_vision) {
9705
+ WSP_GGML_ASSERT(n_dims == ne0/2);
9706
+ }
9245
9707
 
9246
9708
  const float * freq_factors = NULL;
9247
9709
  if (src2 != NULL) {
@@ -9257,18 +9719,63 @@ static void wsp_ggml_compute_forward_rope_f32(
9257
9719
 
9258
9720
  const int32_t * pos = (const int32_t *) src1->data;
9259
9721
 
9260
- for (int64_t i3 = 0; i3 < ne3; i3++) {
9261
- for (int64_t i2 = 0; i2 < ne2; i2++) {
9262
- const int64_t p = pos[i2];
9722
+ for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
9723
+ for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
9263
9724
 
9264
9725
  float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
9265
- wsp_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
9726
+ if (!is_mrope) {
9727
+ const int64_t p = pos[i2];
9728
+ wsp_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
9729
+ }
9730
+ else {
9731
+ const int64_t p_t = pos[i2];
9732
+ const int64_t p_h = pos[i2 + ne2];
9733
+ const int64_t p_w = pos[i2 + ne2 * 2];
9734
+ const int64_t p_e = pos[i2 + ne2 * 3];
9735
+ wsp_ggml_mrope_cache_init(
9736
+ p_t, p_h, p_w, p_e, sections, is_vision,
9737
+ freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
9738
+ }
9266
9739
 
9267
- for (int64_t i1 = 0; i1 < ne1; i1++) {
9740
+ for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
9268
9741
  if (ir++ < ir0) continue;
9269
9742
  if (ir > ir1) break;
9270
9743
 
9271
- if (!is_neox) {
9744
+ if (is_neox || is_mrope) {
9745
+ if (is_vision){
9746
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
9747
+ const int64_t ic = i0/2;
9748
+
9749
+ const float cos_theta = cache[i0 + 0];
9750
+ const float sin_theta = cache[i0 + 1];
9751
+
9752
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
9753
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
9754
+
9755
+ const float x0 = src[0];
9756
+ const float x1 = src[n_dims];
9757
+
9758
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
9759
+ dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
9760
+ }
9761
+ } else {
9762
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
9763
+ const int64_t ic = i0/2;
9764
+
9765
+ const float cos_theta = cache[i0 + 0];
9766
+ const float sin_theta = cache[i0 + 1];
9767
+
9768
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
9769
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
9770
+
9771
+ const float x0 = src[0];
9772
+ const float x1 = src[n_dims/2];
9773
+
9774
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
9775
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
9776
+ }
9777
+ }
9778
+ } else {
9272
9779
  for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
9273
9780
  const float cos_theta = cache[i0 + 0];
9274
9781
  const float sin_theta = cache[i0 + 1];
@@ -9282,8 +9789,10 @@ static void wsp_ggml_compute_forward_rope_f32(
9282
9789
  dst_data[0] = x0*cos_theta - x1*sin_theta;
9283
9790
  dst_data[1] = x0*sin_theta + x1*cos_theta;
9284
9791
  }
9285
- } else {
9286
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
9792
+ }
9793
+
9794
+ if (is_vision) {
9795
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
9287
9796
  const int64_t ic = i0/2;
9288
9797
 
9289
9798
  const float cos_theta = cache[i0 + 0];
@@ -9293,19 +9802,20 @@ static void wsp_ggml_compute_forward_rope_f32(
9293
9802
  float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
9294
9803
 
9295
9804
  const float x0 = src[0];
9296
- const float x1 = src[n_dims/2];
9805
+ const float x1 = src[n_dims];
9297
9806
 
9298
- dst_data[0] = x0*cos_theta - x1*sin_theta;
9299
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
9807
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
9808
+ dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
9300
9809
  }
9301
- }
9302
-
9303
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
9304
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
9305
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9810
+ } else {
9811
+ // fill the remain channels with data from src tensor
9812
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
9813
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
9814
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9306
9815
 
9307
- dst_data[0] = src[0];
9308
- dst_data[1] = src[1];
9816
+ dst_data[0] = src[0];
9817
+ dst_data[1] = src[1];
9818
+ }
9309
9819
  }
9310
9820
  }
9311
9821
  }
@@ -9323,6 +9833,7 @@ static void wsp_ggml_compute_forward_rope_f16(
9323
9833
  const struct wsp_ggml_tensor * src2 = dst->src[2];
9324
9834
 
9325
9835
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
9836
+ int sections[4];
9326
9837
 
9327
9838
  //const int n_past = ((int32_t *) dst->op_params)[0];
9328
9839
  const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -9335,6 +9846,8 @@ static void wsp_ggml_compute_forward_rope_f16(
9335
9846
  memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
9336
9847
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
9337
9848
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
9849
+ memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
9850
+
9338
9851
 
9339
9852
  WSP_GGML_TENSOR_UNARY_OP_LOCALS
9340
9853
 
@@ -9367,6 +9880,16 @@ static void wsp_ggml_compute_forward_rope_f16(
9367
9880
  wsp_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
9368
9881
 
9369
9882
  const bool is_neox = mode & WSP_GGML_ROPE_TYPE_NEOX;
9883
+ const bool is_mrope = mode & WSP_GGML_ROPE_TYPE_MROPE;
9884
+ const bool is_vision = mode == WSP_GGML_ROPE_TYPE_VISION;
9885
+
9886
+ if (is_mrope) {
9887
+ WSP_GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
9888
+ }
9889
+
9890
+ if (is_vision) {
9891
+ WSP_GGML_ASSERT(n_dims == ne0/2);
9892
+ }
9370
9893
 
9371
9894
  const float * freq_factors = NULL;
9372
9895
  if (src2 != NULL) {
@@ -9384,16 +9907,61 @@ static void wsp_ggml_compute_forward_rope_f16(
9384
9907
 
9385
9908
  for (int64_t i3 = 0; i3 < ne3; i3++) {
9386
9909
  for (int64_t i2 = 0; i2 < ne2; i2++) {
9387
- const int64_t p = pos[i2];
9388
9910
 
9389
9911
  float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
9390
- wsp_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
9912
+ if (!is_mrope) {
9913
+ const int64_t p = pos[i2];
9914
+ wsp_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
9915
+ }
9916
+ else {
9917
+ const int64_t p_t = pos[i2];
9918
+ const int64_t p_h = pos[i2 + ne2];
9919
+ const int64_t p_w = pos[i2 + ne2 * 2];
9920
+ const int64_t p_e = pos[i2 + ne2 * 3];
9921
+ wsp_ggml_mrope_cache_init(
9922
+ p_t, p_h, p_w, p_e, sections, is_vision,
9923
+ freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
9924
+ }
9391
9925
 
9392
9926
  for (int64_t i1 = 0; i1 < ne1; i1++) {
9393
9927
  if (ir++ < ir0) continue;
9394
9928
  if (ir > ir1) break;
9395
9929
 
9396
- if (!is_neox) {
9930
+ if (is_neox || is_mrope) {
9931
+ if (is_vision) {
9932
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
9933
+ const int64_t ic = i0/2;
9934
+
9935
+ const float cos_theta = cache[i0 + 0];
9936
+ const float sin_theta = cache[i0 + 1];
9937
+
9938
+ const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
9939
+ wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
9940
+
9941
+ const float x0 = WSP_GGML_FP16_TO_FP32(src[0]);
9942
+ const float x1 = WSP_GGML_FP16_TO_FP32(src[n_dims]);
9943
+
9944
+ dst_data[0] = WSP_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
9945
+ dst_data[n_dims] = WSP_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
9946
+ }
9947
+ } else {
9948
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
9949
+ const int64_t ic = i0/2;
9950
+
9951
+ const float cos_theta = cache[i0 + 0];
9952
+ const float sin_theta = cache[i0 + 1];
9953
+
9954
+ const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
9955
+ wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
9956
+
9957
+ const float x0 = WSP_GGML_FP16_TO_FP32(src[0]);
9958
+ const float x1 = WSP_GGML_FP16_TO_FP32(src[n_dims/2]);
9959
+
9960
+ dst_data[0] = WSP_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
9961
+ dst_data[n_dims/2] = WSP_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
9962
+ }
9963
+ }
9964
+ } else {
9397
9965
  for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
9398
9966
  const float cos_theta = cache[i0 + 0];
9399
9967
  const float sin_theta = cache[i0 + 1];
@@ -9407,8 +9975,10 @@ static void wsp_ggml_compute_forward_rope_f16(
9407
9975
  dst_data[0] = WSP_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
9408
9976
  dst_data[1] = WSP_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
9409
9977
  }
9410
- } else {
9411
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
9978
+ }
9979
+
9980
+ if (is_vision) {
9981
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
9412
9982
  const int64_t ic = i0/2;
9413
9983
 
9414
9984
  const float cos_theta = cache[i0 + 0];
@@ -9418,19 +9988,19 @@ static void wsp_ggml_compute_forward_rope_f16(
9418
9988
  wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
9419
9989
 
9420
9990
  const float x0 = WSP_GGML_FP16_TO_FP32(src[0]);
9421
- const float x1 = WSP_GGML_FP16_TO_FP32(src[n_dims/2]);
9991
+ const float x1 = WSP_GGML_FP16_TO_FP32(src[n_dims]);
9422
9992
 
9423
- dst_data[0] = WSP_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
9424
- dst_data[n_dims/2] = WSP_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
9993
+ dst_data[0] = WSP_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
9994
+ dst_data[n_dims] = WSP_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
9425
9995
  }
9426
- }
9427
-
9428
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
9429
- const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
9430
- wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9996
+ } else {
9997
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
9998
+ const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
9999
+ wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9431
10000
 
9432
- dst_data[0] = src[0];
9433
- dst_data[1] = src[1];
10001
+ dst_data[0] = src[0];
10002
+ dst_data[1] = src[1];
10003
+ }
9434
10004
  }
9435
10005
  }
9436
10006
  }
@@ -9861,9 +10431,10 @@ static void wsp_ggml_compute_forward_im2col_back_f32(
9861
10431
  const struct wsp_ggml_compute_params * params,
9862
10432
  struct wsp_ggml_tensor * dst) {
9863
10433
 
9864
- const struct wsp_ggml_tensor * src0 = dst->src[0];
9865
- const struct wsp_ggml_tensor * src1 = dst->src[1];
10434
+ const struct wsp_ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
10435
+ const struct wsp_ggml_tensor * src1 = dst->src[1]; // convolution kernel
9866
10436
 
10437
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
9867
10438
  WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
9868
10439
  WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
9869
10440
 
@@ -9885,11 +10456,11 @@ static void wsp_ggml_compute_forward_im2col_back_f32(
9885
10456
  const int64_t IH = is_2D ? ne1 : 1;
9886
10457
  const int64_t IW = ne0;
9887
10458
 
9888
- const int64_t KH = is_2D ? ne01 : 1;
9889
- const int64_t KW = ne00;
10459
+ const int64_t KH = is_2D ? ne11 : 1;
10460
+ const int64_t KW = ne10;
9890
10461
 
9891
- const int64_t OH = is_2D ? ne12 : 1;
9892
- const int64_t OW = ne11;
10462
+ const int64_t OH = is_2D ? ne02 : 1;
10463
+ const int64_t OW = ne01;
9893
10464
 
9894
10465
  int ofs0 = is_2D ? nb3 : nb2;
9895
10466
  int ofs1 = is_2D ? nb2 : nb1;
@@ -9935,9 +10506,9 @@ static void wsp_ggml_compute_forward_im2col_back_f32(
9935
10506
  continue;
9936
10507
  }
9937
10508
 
9938
- const float * const src_data = (const float *) src1->data
10509
+ const float * const grad_in = (const float *) src0->data
9939
10510
  + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
9940
- grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
10511
+ grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
9941
10512
  }
9942
10513
  }
9943
10514
  float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
@@ -10429,6 +11000,40 @@ static void wsp_ggml_compute_forward_pad(
10429
11000
  }
10430
11001
  }
10431
11002
 
11003
+ // wsp_ggml_compute_forward_pad_reflect_1d
11004
+
11005
+ static void wsp_ggml_compute_forward_pad_reflect_1d(
11006
+ const struct wsp_ggml_compute_params * params,
11007
+ struct wsp_ggml_tensor * dst) {
11008
+
11009
+ const struct wsp_ggml_tensor * src0 = dst->src[0];
11010
+
11011
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
11012
+ WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
11013
+
11014
+ const int ith = params->ith;
11015
+ const int nth = params->nth;
11016
+
11017
+ const int32_t * opts = (const int32_t *) dst->op_params;
11018
+ const int p0 = opts[0];
11019
+ const int p1 = opts[1];
11020
+
11021
+ WSP_GGML_TENSOR_UNARY_OP_LOCALS
11022
+
11023
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
11024
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
11025
+ for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
11026
+ float * left = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + p0*nb0);
11027
+ float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);
11028
+
11029
+ wsp_ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
11030
+
11031
+ for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0]; }
11032
+ for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
11033
+ }
11034
+ }
11035
+ }
11036
+ }
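
A standalone illustration of the reflect padding produced by the new op above (not library code): p0 samples are mirrored on the left and p1 on the right, excluding the edge sample itself, PyTorch ReflectionPad1d style. For p0 = 2, p1 = 1 a row [a b c d] becomes [c b a b c d c].

    #include <stdint.h>

    static void pad_reflect_1d_row(const float * src, int64_t n, int p0, int p1, float * dst) {
        float * left  = dst + p0;            // where src[0] lands
        float * right = dst + p0 + n - 1;    // where src[n-1] lands

        for (int64_t i = 0; i < n; i++) {
            left[i] = src[i];                // copy the row into the middle of dst
        }
        for (int i = 1; i <= p0; i++) { left[-i] = left[i];   }  // mirror the left edge
        for (int i = 1; i <= p1; i++) { right[i] = right[-i]; }  // mirror the right edge
    }
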
10432
11037
 
10433
11038
  // wsp_ggml_compute_forward_arange
10434
11039
 
@@ -11645,9 +12250,9 @@ static void wsp_ggml_compute_forward_add_rel_pos(
11645
12250
  static void wsp_ggml_compute_forward_rwkv_wkv6_f32(
11646
12251
  const struct wsp_ggml_compute_params * params,
11647
12252
  struct wsp_ggml_tensor * dst) {
11648
- const int64_t T = dst->src[1]->ne[3];
12253
+ const int64_t T = dst->src[1]->ne[2];
11649
12254
  const int64_t C = dst->ne[0];
11650
- const int64_t HEADS = dst->src[1]->ne[2];
12255
+ const int64_t HEADS = dst->src[1]->ne[1];
11651
12256
  const int64_t n_seqs = dst->src[5]->ne[1];
11652
12257
  const int64_t head_size = C / HEADS;
11653
12258
 
@@ -11842,6 +12447,197 @@ static void wsp_ggml_compute_forward_rwkv_wkv6(
11842
12447
  }
11843
12448
  }
11844
12449
 
12450
+ // wsp_ggml_compute_forward_gla
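
The kernel added below implements gated linear attention. As a reading aid, this scalar reference (not library code; single head, single sequence, names illustrative) shows the recurrence that the vectorized code realizes per head: at every timestep the state is updated as S[i][j] = g[i]*S[i][j] + k[i]*v[j] and the output accumulates out[j] += q[i]*scale*S[i][j].

    #include <stdint.h>

    static void gla_head_ref(int64_t T, int64_t D,              // timesteps, head size
                             const float * k, const float * v,
                             const float * q, const float * g,
                             float scale,
                             float * S,                          // [D][D] running state, updated in place
                             float * out) {                      // [T][D], assumed zero-initialized
        for (int64_t t = 0; t < T; t++) {
            for (int64_t i = 0; i < D; i++) {
                const float k_i = k[t*D + i];
                const float q_i = q[t*D + i]*scale;
                const float g_i = g[t*D + i];
                for (int64_t j = 0; j < D; j++) {
                    const float s = S[i*D + j]*g_i + k_i*v[t*D + j];  // gated state update
                    S[i*D + j]    = s;
                    out[t*D + j] += s*q_i;                            // accumulate the output row
                }
            }
        }
    }
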
12451
+
12452
+ static void wsp_ggml_compute_forward_gla_f32(
12453
+ const struct wsp_ggml_compute_params * params,
12454
+ struct wsp_ggml_tensor * dst) {
12455
+ const int64_t T = dst->src[1]->ne[2];
12456
+ const int64_t C = dst->ne[0];
12457
+ const int64_t HEADS = dst->src[1]->ne[1];
12458
+ const int64_t n_seqs = dst->src[4]->ne[1];
12459
+ const int64_t head_size = C / HEADS;
12460
+ const float scale = wsp_ggml_get_op_params_f32(dst, 0);
12461
+
12462
+ float * dst_data = (float *) dst->data;
12463
+ float * state = ((float *) dst->data) + C * T;
12464
+
12465
+ const int ith = params->ith;
12466
+ const int nth = params->nth;
12467
+
12468
+ if (ith >= HEADS) {
12469
+ return;
12470
+ }
12471
+
12472
+ const int h_start = (HEADS * ith) / nth;
12473
+ const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
12474
+ (HEADS * (ith + 1)) / nth : HEADS;
12475
+
12476
+ float * k = (float *) dst->src[0]->data;
12477
+ float * v = (float *) dst->src[1]->data;
12478
+ float * q = (float *) dst->src[2]->data;
12479
+ float * g = (float *) dst->src[3]->data;
12480
+
12481
+ size_t t_stride = HEADS * head_size; // Same to C
12482
+
12483
+ size_t h_stride = C / HEADS;
12484
+ WSP_GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
12485
+ size_t h_stride_2d = head_size * head_size;
12486
+
12487
+ if (ith == 0) {
12488
+ memset(dst_data, 0, T * C * sizeof(float));
12489
+ }
12490
+ wsp_ggml_barrier(params->threadpool);
12491
+
12492
+
12493
+ #if defined(__AVX__) && !defined(__AVX512F__)
12494
+ #define WSP_GGML_F32X WSP_GGML_F32x8
12495
+ #define WSP_GGML_F32X_SET1 WSP_GGML_F32x8_SET1
12496
+ #define WSP_GGML_F32X_LOAD WSP_GGML_F32x8_LOAD
12497
+ #define WSP_GGML_F32X_STORE WSP_GGML_F32x8_STORE
12498
+ #define WSP_GGML_F32X_MUL WSP_GGML_F32x8_MUL
12499
+ #define WSP_GGML_F32X_FMA WSP_GGML_F32x8_FMA
12500
+ #define GLA_VECTOR_SIZE 8
12501
+ #elif defined(__AVX512F__)
12502
+ #define WSP_GGML_F32X WSP_GGML_F32x16
12503
+ #define WSP_GGML_F32X_SET1 WSP_GGML_F32x16_SET1
12504
+ #define WSP_GGML_F32X_LOAD WSP_GGML_F32x16_LOAD
12505
+ #define WSP_GGML_F32X_STORE WSP_GGML_F32x16_STORE
12506
+ #define WSP_GGML_F32X_MUL WSP_GGML_F32x16_MUL
12507
+ #define WSP_GGML_F32X_FMA WSP_GGML_F32x16_FMA
12508
+ #define GLA_VECTOR_SIZE 16
12509
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
12510
+ #define WSP_GGML_F32X WSP_GGML_F32x4
12511
+ #define WSP_GGML_F32X_SET1 WSP_GGML_F32x4_SET1
12512
+ #define WSP_GGML_F32X_LOAD WSP_GGML_F32x4_LOAD
12513
+ #define WSP_GGML_F32X_STORE WSP_GGML_F32x4_STORE
12514
+ #define WSP_GGML_F32X_MUL WSP_GGML_F32x4_MUL
12515
+ #define WSP_GGML_F32X_FMA WSP_GGML_F32x4_FMA
12516
+ #define GLA_VECTOR_SIZE 4
12517
+ #endif
12518
+
12519
+ #ifdef GLA_VECTOR_SIZE
12520
+ const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
12521
+
12522
+ for (int64_t t = 0; t < T; t++) {
12523
+ size_t t_offset = t * t_stride;
12524
+ size_t state_offset = head_size * C * (t / (T / n_seqs));
12525
+ float * state_cur = state + state_offset;
12526
+ float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
12527
+
12528
+ for (int64_t h = h_start; h < h_end; h++) {
12529
+ size_t h_offset = h * h_stride;
12530
+ size_t t_h_offset = t_offset + h_offset;
12531
+ size_t h_2d_offset = h * h_stride_2d;
12532
+
12533
+ for (int64_t i = 0; i < head_size; i++) {
12534
+ size_t t_h_i_offset = t_h_offset + i;
12535
+ size_t h_2d_i_offset = h_2d_offset + i * h_stride;
12536
+
12537
+ float k_val = k[t_h_i_offset];
12538
+ float q_val = q[t_h_i_offset] * scale;
12539
+ float g_val = g[t_h_i_offset];
12540
+
12541
+ // Broadcast scalar values to vectors
12542
+ WSP_GGML_F32X k_vec = WSP_GGML_F32X_SET1(k_val);
12543
+ WSP_GGML_F32X q_vec = WSP_GGML_F32X_SET1(q_val);
12544
+ WSP_GGML_F32X g_vec = WSP_GGML_F32X_SET1(g_val);
12545
+
12546
+ for (int64_t j = 0; j < vec_count; j++) {
12547
+ size_t base_j = j * GLA_VECTOR_SIZE;
12548
+ size_t t_h_j_offset = t_h_offset + base_j;
12549
+ size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
12550
+
12551
+ // Load x elements at once
12552
+ WSP_GGML_F32X v_vec = WSP_GGML_F32X_LOAD(&v[t_h_j_offset]);
12553
+ WSP_GGML_F32X prev_state_vec = WSP_GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
12554
+ WSP_GGML_F32X dst_vec = WSP_GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
12555
+
12556
+ // Compute kv = v * k
12557
+ WSP_GGML_F32X kv_vec = WSP_GGML_F32X_MUL(v_vec, k_vec);
12558
+
12559
+ // Compute temp = prev_state * g + kv
12560
+ WSP_GGML_F32X temp_vec = WSP_GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
12561
+
12562
+ // Update dst: dst += temp * q
12563
+ dst_vec = WSP_GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
12564
+ WSP_GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
12565
+
12566
+ // Update state
12567
+ WSP_GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
12568
+ }
12569
+
12570
+ // Handle remaining elements, this will not be used.
12571
+ for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
12572
+ size_t t_h_j_offset = t_h_offset + j;
12573
+ size_t h_2d_i_j_offset = h_2d_i_offset + j;
12574
+ float v_val = v[t_h_j_offset];
12575
+ float kv_val = v_val * k_val;
12576
+ float prev_state_val = state_prev[h_2d_i_j_offset];
12577
+ float temp_val = kv_val + prev_state_val * g_val;
12578
+ dst_data[t_h_j_offset] += temp_val * q_val;
12579
+ state_cur[h_2d_i_j_offset] = temp_val;
12580
+ }
12581
+ }
12582
+ }
12583
+ }
12584
+
12585
+ #else
12586
+ for (int64_t t = 0; t < T; t++) {
12587
+ size_t t_offset = t * t_stride;
12588
+ size_t state_offset = head_size * C * (t / (T / n_seqs));
12589
+ float * state_cur = state + state_offset;
12590
+ float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
12591
+
12592
+ for (int64_t h = h_start; h < h_end; h++) {
12593
+ size_t h_offset = h * h_stride;
12594
+ size_t t_h_offset = t_offset + h_offset;
12595
+ size_t h_2d_offset = h * h_stride_2d;
12596
+
12597
+ for (int64_t i = 0; i < head_size; i++) {
12598
+ size_t t_h_i_offset = t_h_offset + i;
12599
+ size_t h_2d_i_offset = h_2d_offset + i * h_stride;
12600
+
12601
+ float k_val = k[t_h_i_offset];
12602
+ float q_val = q[t_h_i_offset] * scale;
12603
+ float g_val = g[t_h_i_offset];
12604
+
12605
+ for (int64_t j = 0; j < head_size; j++) {
12606
+ size_t t_h_j_offset = t_h_offset + j;
12607
+ size_t h_2d_i_j_offset = h_2d_i_offset + j;
12608
+
12609
+ float v_val = v[t_h_j_offset];
12610
+ float kv_val = v_val * k_val;
12611
+ float prev_state_val = state_prev[h_2d_i_j_offset];
12612
+ float temp_val = prev_state_val * g_val + kv_val;
12613
+ dst_data[t_h_j_offset] += temp_val * q_val;
12614
+ state_cur[h_2d_i_j_offset] = temp_val;
12615
+ }
12616
+ }
12617
+ }
12618
+ }
12619
+ #endif
12620
+ }
12621
+
12622
+
12623
+ static void wsp_ggml_compute_forward_gla(
12624
+ const struct wsp_ggml_compute_params * params,
12625
+ struct wsp_ggml_tensor * dst) {
12626
+
12627
+ const struct wsp_ggml_tensor * src0 = dst->src[0];
12628
+
12629
+ switch (src0->type) {
12630
+ case WSP_GGML_TYPE_F32:
12631
+ {
12632
+ wsp_ggml_compute_forward_gla_f32(params, dst);
12633
+ } break;
12634
+ default:
12635
+ {
12636
+ WSP_GGML_ABORT("fatal error");
12637
+ }
12638
+ }
12639
+ }
12640
+
11845
12641
  // wsp_ggml_compute_forward_map_unary
11846
12642
 
11847
12643
  static void wsp_ggml_compute_forward_map_unary_f32(
@@ -12135,22 +12931,22 @@ static void wsp_ggml_compute_forward_cross_entropy_loss_back_f32(
12135
12931
  const struct wsp_ggml_compute_params * params,
12136
12932
  struct wsp_ggml_tensor * dst) {
12137
12933
 
12138
- const struct wsp_ggml_tensor * src0 = dst->src[0];
12139
- const struct wsp_ggml_tensor * src1 = dst->src[1];
12140
- const struct wsp_ggml_tensor * opt0 = dst->src[2];
12934
+ const struct wsp_ggml_tensor * grad = dst->src[0]; // gradient of forward pass output
12935
+ const struct wsp_ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
12936
+ const struct wsp_ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
12141
12937
 
12142
12938
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst));
12143
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
12144
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
12145
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(opt0));
12146
- WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, src1) && wsp_ggml_are_same_shape(src0, dst));
12939
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0f));
12940
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1f));
12941
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(grad));
12942
+ WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0f, src1f) && wsp_ggml_are_same_shape(src0f, dst));
12147
12943
 
12148
12944
  const int64_t ith = params->ith;
12149
12945
  const int64_t nth = params->nth;
12150
12946
 
12151
12947
  // TODO: handle transposed/permuted matrices
12152
- const int64_t nc = src0->ne[0];
12153
- const int64_t nr = wsp_ggml_nrows(src0);
12948
+ const int64_t nc = src0f->ne[0];
12949
+ const int64_t nr = wsp_ggml_nrows(src0f);
12154
12950
 
12155
12951
  // rows per thread
12156
12952
  const int64_t dr = (nr + nth - 1)/nth;
@@ -12159,12 +12955,12 @@ static void wsp_ggml_compute_forward_cross_entropy_loss_back_f32(
12159
12955
  const int64_t ir0 = dr*ith;
12160
12956
  const int64_t ir1 = MIN(ir0 + dr, nr);
12161
12957
 
12162
- const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
12958
+ const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
12163
12959
 
12164
12960
  for (int64_t i1 = ir0; i1 < ir1; i1++) {
12165
- float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
12166
- float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
12167
- float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
12961
+ float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
12962
+ const float * s0 = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
12963
+ const float * s1 = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
12168
12964
 
12169
12965
  #ifndef NDEBUG
12170
12966
  for (int64_t i = 0; i < nc; ++i) {
@@ -12177,11 +12973,11 @@ static void wsp_ggml_compute_forward_cross_entropy_loss_back_f32(
12177
12973
  // soft_max
12178
12974
  float max = -INFINITY;
12179
12975
  wsp_ggml_vec_max_f32(nc, &max, s0);
12180
- wsp_ggml_float sum = wsp_ggml_vec_soft_max_f32(nc, ds0, s0, max);
12976
+ const wsp_ggml_float sum = wsp_ggml_vec_soft_max_f32(nc, ds0, s0, max);
12181
12977
  assert(sum > 0.0);
12182
12978
  wsp_ggml_vec_scale_f32(nc, ds0, 1.0/sum);
12183
12979
 
12184
- // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
12980
+ // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
12185
12981
  wsp_ggml_vec_sub_f32(nc, ds0, ds0, s1);
12186
12982
  wsp_ggml_vec_scale_f32(nc, ds0, d_by_nr);
12187
12983
 
@@ -12304,6 +13100,9 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
12304
13100
  return;
12305
13101
  }
12306
13102
 
13103
+ // extra_buffer op?
13104
+ if (wsp_ggml_cpu_extra_compute_forward(params, tensor)) return;
13105
+
12307
13106
  switch (tensor->op) {
12308
13107
  case WSP_GGML_OP_DUP:
12309
13108
  {
@@ -12475,7 +13274,7 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
12475
13274
  } break;
12476
13275
  case WSP_GGML_OP_SOFT_MAX_BACK:
12477
13276
  {
12478
- wsp_ggml_compute_forward_soft_max_back(params, tensor);
13277
+ wsp_ggml_compute_forward_soft_max_ext_back(params, tensor);
12479
13278
  } break;
12480
13279
  case WSP_GGML_OP_ROPE:
12481
13280
  {
@@ -12525,6 +13324,10 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
12525
13324
  {
12526
13325
  wsp_ggml_compute_forward_pad(params, tensor);
12527
13326
  } break;
13327
+ case WSP_GGML_OP_PAD_REFLECT_1D:
13328
+ {
13329
+ wsp_ggml_compute_forward_pad_reflect_1d(params, tensor);
13330
+ } break;
12528
13331
  case WSP_GGML_OP_ARANGE:
12529
13332
  {
12530
13333
  wsp_ggml_compute_forward_arange(params, tensor);
@@ -12584,6 +13387,10 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
12584
13387
  {
12585
13388
  wsp_ggml_compute_forward_rwkv_wkv6(params, tensor);
12586
13389
  } break;
13390
+ case WSP_GGML_OP_GATED_LINEAR_ATTN:
13391
+ {
13392
+ wsp_ggml_compute_forward_gla(params, tensor);
13393
+ } break;
12587
13394
  case WSP_GGML_OP_MAP_UNARY:
12588
13395
  {
12589
13396
  wsp_ggml_unary_op_f32_t fun;
@@ -12867,6 +13674,7 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
12867
13674
  } break;
12868
13675
  case WSP_GGML_OP_UPSCALE:
12869
13676
  case WSP_GGML_OP_PAD:
13677
+ case WSP_GGML_OP_PAD_REFLECT_1D:
12870
13678
  case WSP_GGML_OP_ARANGE:
12871
13679
  case WSP_GGML_OP_TIMESTEP_EMBEDDING:
12872
13680
  case WSP_GGML_OP_ARGSORT:
@@ -12881,6 +13689,7 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
12881
13689
  case WSP_GGML_OP_WIN_UNPART:
12882
13690
  case WSP_GGML_OP_GET_REL_POS:
12883
13691
  case WSP_GGML_OP_RWKV_WKV6:
13692
+ case WSP_GGML_OP_GATED_LINEAR_ATTN:
12884
13693
  case WSP_GGML_OP_MAP_UNARY:
12885
13694
  case WSP_GGML_OP_MAP_BINARY:
12886
13695
  case WSP_GGML_OP_MAP_CUSTOM1_F32:
@@ -12956,7 +13765,7 @@ static thread_ret_t wsp_ggml_graph_compute_secondary_thread(void* data);
12956
13765
  #include "windows.h"
12957
13766
 
12958
13767
  // TODO: support > 64 CPUs
12959
- bool wsp_ggml_thread_apply_affinity(bool * mask) {
13768
+ static bool wsp_ggml_thread_apply_affinity(bool * mask) {
12960
13769
  HANDLE h = GetCurrentThread();
12961
13770
  uint64_t bitmask = 0ULL;
12962
13771
 
@@ -13246,140 +14055,148 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(
 
  size_t cur = 0;
 
- switch (node->op) {
- case WSP_GGML_OP_CPY:
- case WSP_GGML_OP_DUP:
- {
- if (wsp_ggml_is_quantized(node->type) ||
- // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
- (node->src[0]->type == WSP_GGML_TYPE_F16 && node->src[1] && node->src[1]->type == WSP_GGML_TYPE_BF16) ||
- (node->src[0]->type == WSP_GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == WSP_GGML_TYPE_F16)) {
+ if (!wsp_ggml_cpu_extra_work_size(n_threads, node, &cur)) {
+
+ switch (node->op) {
+ case WSP_GGML_OP_CPY:
+ case WSP_GGML_OP_DUP:
+ {
+ if (wsp_ggml_is_quantized(node->type) ||
+ // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
+ (node->src[0]->type == WSP_GGML_TYPE_F16 && node->src[1] && node->src[1]->type == WSP_GGML_TYPE_BF16) ||
+ (node->src[0]->type == WSP_GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == WSP_GGML_TYPE_F16)) {
+ cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->ne[0] * n_tasks;
+ }
+ } break;
+ case WSP_GGML_OP_ADD:
+ case WSP_GGML_OP_ADD1:
+ {
+ if (wsp_ggml_is_quantized(node->src[0]->type)) {
+ cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
+ }
+ } break;
+ case WSP_GGML_OP_ACC:
+ {
+ if (wsp_ggml_is_quantized(node->src[0]->type)) {
+ cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
+ }
+ } break;
+ case WSP_GGML_OP_COUNT_EQUAL:
+ {
+ cur = wsp_ggml_type_size(node->type)*n_tasks;
+ } break;
+ case WSP_GGML_OP_MUL_MAT:
+ {
+ const enum wsp_ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type;
+
+ if (node->src[1]->type != vec_dot_type) {
+ cur = wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(node->src[1]));
+ }
+ } break;
+ case WSP_GGML_OP_MUL_MAT_ID:
+ {
+ cur = 0;
+ const struct wsp_ggml_tensor * src0 = node->src[0];
+ const struct wsp_ggml_tensor * src1 = node->src[1];
+ const struct wsp_ggml_tensor * ids = node->src[2];
+ const enum wsp_ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
+ const int n_as = src0->ne[2];
+ // src1
+ if (src1->type != vec_dot_type) {
+ cur += wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(src1)) + sizeof(int64_t);
+ }
+ // matrix_row_counts
+ cur += n_as * sizeof(int64_t) + sizeof(int64_t);
+ // matrix_rows
+ cur += n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping) + sizeof(int64_t);
+ // atomic_current_chunk
+ cur += CACHE_LINE_SIZE*n_as + CACHE_LINE_SIZE;
+ } break;
+ case WSP_GGML_OP_OUT_PROD:
+ {
+ if (wsp_ggml_is_quantized(node->src[0]->type)) {
+ cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
+ }
+ } break;
+ case WSP_GGML_OP_SOFT_MAX:
+ case WSP_GGML_OP_ROPE:
+ case WSP_GGML_OP_ROPE_BACK:
+ {
  cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->ne[0] * n_tasks;
- }
- } break;
- case WSP_GGML_OP_ADD:
- case WSP_GGML_OP_ADD1:
- {
- if (wsp_ggml_is_quantized(node->src[0]->type)) {
- cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
- }
- } break;
- case WSP_GGML_OP_ACC:
- {
- if (wsp_ggml_is_quantized(node->src[0]->type)) {
- cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
- }
- } break;
- case WSP_GGML_OP_COUNT_EQUAL:
- {
- cur = wsp_ggml_type_size(node->type)*n_tasks;
- } break;
- case WSP_GGML_OP_MUL_MAT:
- {
- const enum wsp_ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type;
+ } break;
+ case WSP_GGML_OP_CONV_TRANSPOSE_1D:
+ {
+ WSP_GGML_ASSERT(node->src[0]->ne[3] == 1);
+ WSP_GGML_ASSERT(node->src[1]->ne[2] == 1);
+ WSP_GGML_ASSERT(node->src[1]->ne[3] == 1);
+
+ const int64_t ne00 = node->src[0]->ne[0]; // K
+ const int64_t ne01 = node->src[0]->ne[1]; // Cout
+ const int64_t ne02 = node->src[0]->ne[2]; // Cin
+ const int64_t ne10 = node->src[1]->ne[0]; // L
+ const int64_t ne11 = node->src[1]->ne[1]; // Cin
+
+ if ((node->src[0]->type == WSP_GGML_TYPE_F16 ||
+ node->src[0]->type == WSP_GGML_TYPE_BF16) &&
+ node->src[1]->type == WSP_GGML_TYPE_F32) {
+ cur += sizeof(wsp_ggml_fp16_t)*ne00*ne01*ne02;
+ cur += sizeof(wsp_ggml_fp16_t)*ne10*ne11;
+ } else if (node->src[0]->type == WSP_GGML_TYPE_F32 &&
+ node->src[1]->type == WSP_GGML_TYPE_F32) {
+ cur += sizeof(float)*ne00*ne01*ne02;
+ cur += sizeof(float)*ne10*ne11;
+ } else {
+ WSP_GGML_ABORT("fatal error");
+ }
+ } break;
+ case WSP_GGML_OP_CONV_TRANSPOSE_2D:
+ {
+ const int64_t ne00 = node->src[0]->ne[0]; // W
+ const int64_t ne01 = node->src[0]->ne[1]; // H
+ const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
+ const int64_t ne03 = node->src[0]->ne[3]; // Channels In
 
- if (node->src[1]->type != vec_dot_type) {
- cur = wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(node->src[1]));
- }
- } break;
- case WSP_GGML_OP_MUL_MAT_ID:
- {
- cur = 0;
- const struct wsp_ggml_tensor * src0 = node->src[0];
- const struct wsp_ggml_tensor * src1 = node->src[1];
- const enum wsp_ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
- if (src1->type != vec_dot_type) {
- cur += wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(src1));
- }
- const int n_as = src0->ne[2];
- cur += WSP_GGML_PAD(cur, sizeof(int64_t)); // align
- cur += n_as * sizeof(int64_t); // matrix_row_counts
- cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
- } break;
- case WSP_GGML_OP_OUT_PROD:
- {
- if (wsp_ggml_is_quantized(node->src[0]->type)) {
- cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
- }
- } break;
- case WSP_GGML_OP_SOFT_MAX:
- case WSP_GGML_OP_ROPE:
- {
- cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->ne[0] * n_tasks;
- } break;
- case WSP_GGML_OP_CONV_TRANSPOSE_1D:
- {
- WSP_GGML_ASSERT(node->src[0]->ne[3] == 1);
- WSP_GGML_ASSERT(node->src[1]->ne[2] == 1);
- WSP_GGML_ASSERT(node->src[1]->ne[3] == 1);
-
- const int64_t ne00 = node->src[0]->ne[0]; // K
- const int64_t ne01 = node->src[0]->ne[1]; // Cout
- const int64_t ne02 = node->src[0]->ne[2]; // Cin
-
- const int64_t ne10 = node->src[1]->ne[0]; // L
- const int64_t ne11 = node->src[1]->ne[1]; // Cin
-
- if ((node->src[0]->type == WSP_GGML_TYPE_F16 ||
- node->src[0]->type == WSP_GGML_TYPE_BF16) &&
- node->src[1]->type == WSP_GGML_TYPE_F32) {
- cur += sizeof(wsp_ggml_fp16_t)*ne00*ne01*ne02;
- cur += sizeof(wsp_ggml_fp16_t)*ne10*ne11;
- } else if (node->src[0]->type == WSP_GGML_TYPE_F32 &&
- node->src[1]->type == WSP_GGML_TYPE_F32) {
- cur += sizeof(float)*ne00*ne01*ne02;
- cur += sizeof(float)*ne10*ne11;
- } else {
- WSP_GGML_ABORT("fatal error");
- }
- } break;
- case WSP_GGML_OP_CONV_TRANSPOSE_2D:
- {
- const int64_t ne00 = node->src[0]->ne[0]; // W
- const int64_t ne01 = node->src[0]->ne[1]; // H
- const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
- const int64_t ne03 = node->src[0]->ne[3]; // Channels In
-
- const int64_t ne10 = node->src[1]->ne[0]; // W
- const int64_t ne11 = node->src[1]->ne[1]; // H
- const int64_t ne12 = node->src[1]->ne[2]; // Channels In
-
- cur += sizeof(wsp_ggml_fp16_t)*ne00*ne01*ne02*ne03;
- cur += sizeof(wsp_ggml_fp16_t)*ne10*ne11*ne12;
- } break;
- case WSP_GGML_OP_FLASH_ATTN_EXT:
- {
- const int64_t ne00 = node->src[0]->ne[0]; // D
+ const int64_t ne10 = node->src[1]->ne[0]; // W
+ const int64_t ne11 = node->src[1]->ne[1]; // H
+ const int64_t ne12 = node->src[1]->ne[2]; // Channels In
 
- cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
- } break;
- case WSP_GGML_OP_FLASH_ATTN_BACK:
- {
- const int64_t D = node->src[0]->ne[0];
- const int64_t ne11 = wsp_ggml_up(node->src[1]->ne[1], WSP_GGML_SOFT_MAX_UNROLL);
- const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in wsp_ggml_compute_forward_flash_attn_back
- if (node->src[1]->type == WSP_GGML_TYPE_F32) {
- cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
- } else if (node->src[1]->type == WSP_GGML_TYPE_F16) {
- cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
- } else if (node->src[1]->type == WSP_GGML_TYPE_BF16) {
- cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
- }
- } break;
+ cur += sizeof(wsp_ggml_fp16_t)*ne00*ne01*ne02*ne03;
+ cur += sizeof(wsp_ggml_fp16_t)*ne10*ne11*ne12;
+ } break;
+ case WSP_GGML_OP_FLASH_ATTN_EXT:
+ {
+ const int64_t ne00 = node->src[0]->ne[0]; // D
 
- case WSP_GGML_OP_CROSS_ENTROPY_LOSS:
- {
- cur = wsp_ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
- } break;
- case WSP_GGML_OP_COUNT:
- {
- WSP_GGML_ABORT("fatal error");
- }
- default:
- break;
+ cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+ } break;
+ case WSP_GGML_OP_FLASH_ATTN_BACK:
+ {
+ const int64_t D = node->src[0]->ne[0];
+ const int64_t ne11 = wsp_ggml_up(node->src[1]->ne[1], WSP_GGML_SOFT_MAX_UNROLL);
+ const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in wsp_ggml_compute_forward_flash_attn_back
+ if (node->src[1]->type == WSP_GGML_TYPE_F32) {
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+ } else if (node->src[1]->type == WSP_GGML_TYPE_F16) {
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+ } else if (node->src[1]->type == WSP_GGML_TYPE_BF16) {
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+ }
+ } break;
+
+ case WSP_GGML_OP_CROSS_ENTROPY_LOSS:
+ {
+ cur = wsp_ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
+ } break;
+ case WSP_GGML_OP_COUNT:
+ {
+ WSP_GGML_ABORT("fatal error");
+ }
+ default:
+ break;
+ }
  }
 
  work_size = MAX(work_size, cur);
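For orientation, and not part of the package itself: the per-node scratch size cur computed by the switch above is folded into a single shared work buffer with work_size = MAX(work_size, cur). A minimal, self-contained sketch of that reduction, where the array of per-node sizes is a hypothetical stand-in for the real per-op calculation:

    #include <stddef.h>
    #include <stdio.h>

    /* Sketch only: fold per-node scratch requirements into one buffer size,
       mirroring the work_size = MAX(work_size, cur) pattern above. */
    static size_t plan_work_size(const size_t * per_node_cur, int n_nodes) {
        size_t work_size = 0;
        for (int i = 0; i < n_nodes; i++) {
            if (per_node_cur[i] > work_size) {
                work_size = per_node_cur[i];
            }
        }
        return work_size;
    }

    int main(void) {
        const size_t cur[] = {0, 4096, 1024}; // hypothetical per-node scratch sizes
        printf("work_size = %zu\n", plan_work_size(cur, 3));
        return 0;
    }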
@@ -13414,20 +14231,24 @@ static thread_ret_t wsp_ggml_graph_compute_thread(void * data) {
  /*.threadpool=*/ tp,
  };
 
- for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) {
+ for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
  struct wsp_ggml_tensor * node = cgraph->nodes[node_n];
 
  wsp_ggml_compute_forward(&params, node);
 
  if (state->ith == 0 && cplan->abort_callback &&
  cplan->abort_callback(cplan->abort_callback_data)) {
- tp->abort = true;
+ atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);
  tp->ec = WSP_GGML_STATUS_ABORTED;
  }
 
- wsp_ggml_barrier(state->threadpool);
+ if (node_n + 1 < cgraph->n_nodes) {
+ wsp_ggml_barrier(state->threadpool);
+ }
  }
 
+ wsp_ggml_barrier(state->threadpool);
+
  return 0;
  }
 
@@ -13578,29 +14399,6 @@ static void wsp_ggml_graph_compute_kickoff(struct wsp_ggml_threadpool * threadpo
 
  #endif // WSP_GGML_USE_OPENMP
 
- void wsp_ggml_threadpool_params_init(struct wsp_ggml_threadpool_params * p, int n_threads) {
- p->n_threads = n_threads;
- p->prio = 0; // default priority (usually means normal or inherited)
- p->poll = 50; // hybrid-polling enabled
- p->strict_cpu = false; // no strict placement (all threads share same cpumask)
- p->paused = false; // threads are ready to go
- memset(p->cpumask, 0, WSP_GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited)
- }
-
- struct wsp_ggml_threadpool_params wsp_ggml_threadpool_params_default(int n_threads) {
- struct wsp_ggml_threadpool_params p;
- wsp_ggml_threadpool_params_init(&p, n_threads);
- return p;
- }
-
- bool wsp_ggml_threadpool_params_match(const struct wsp_ggml_threadpool_params * p0, const struct wsp_ggml_threadpool_params * p1) {
- if (p0->n_threads != p1->n_threads ) return false;
- if (p0->prio != p1->prio ) return false;
- if (p0->poll != p1->poll ) return false;
- if (p0->strict_cpu != p1->strict_cpu ) return false;
- return memcmp(p0->cpumask, p1->cpumask, WSP_GGML_MAX_N_THREADS) == 0;
- }
-
  static struct wsp_ggml_threadpool * wsp_ggml_threadpool_new_impl(
  struct wsp_ggml_threadpool_params * tpp,
  struct wsp_ggml_cgraph * cgraph,
@@ -13617,7 +14415,7 @@ static struct wsp_ggml_threadpool * wsp_ggml_threadpool_new_impl(
  threadpool->current_chunk = 0;
  threadpool->stop = false;
  threadpool->pause = tpp->paused;
- threadpool->abort = false;
+ threadpool->abort = -1;
  threadpool->workers = NULL;
  threadpool->n_threads_max = tpp->n_threads;
  threadpool->n_threads_cur = tpp->n_threads;
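The two hunks above change threadpool->abort from a bool into an atomic node index: -1 means "keep going", and storing node_n + 1 lets the node in flight finish while stopping the loop before the next one starts. A minimal sketch of that pattern, assuming a standalone C11 program with a simplified single-threaded loop standing in for the worker threads:

    #include <stdatomic.h>
    #include <stdio.h>

    // Sketch only: the abort flag is an atomic node index instead of a bool,
    // read and written with relaxed ordering as in the diff above.
    static atomic_int abort_node = -1; // mirrors threadpool->abort = -1

    int main(void) {
        const int n_nodes = 8;
        for (int node_n = 0;
             node_n < n_nodes &&
             atomic_load_explicit(&abort_node, memory_order_relaxed) != node_n;
             node_n++) {
            printf("compute node %d\n", node_n); // stand-in for wsp_ggml_compute_forward
            if (node_n == 2) { // stand-in for the abort callback returning true
                atomic_store_explicit(&abort_node, node_n + 1, memory_order_relaxed);
            }
        }
        return 0;
    }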
@@ -13696,7 +14494,7 @@ enum wsp_ggml_status wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, str
  threadpool->cgraph = cgraph;
  threadpool->cplan = cplan;
  threadpool->current_chunk = 0;
- threadpool->abort = false;
+ threadpool->abort = -1;
  threadpool->ec = WSP_GGML_STATUS_SUCCESS;
  }
 
@@ -13895,16 +14693,32 @@ int wsp_ggml_cpu_has_vsx(void) {
  #endif
  }
 
+ int wsp_ggml_cpu_has_vxe(void) {
+ #if defined(__VXE__) || defined(__VXE2__)
+ return 1;
+ #else
+ return 0;
+ #endif
+ }
+
  int wsp_ggml_cpu_has_neon(void) {
- #if defined(__ARM_ARCH)
+ #if defined(__ARM_ARCH) && defined(__ARM_NEON)
  return wsp_ggml_arm_arch_features.has_neon;
  #else
  return 0;
  #endif
  }
 
+ int wsp_ggml_cpu_has_dotprod(void) {
+ #if defined(__ARM_ARCH) && defined(__ARM_FEATURE_DOTPROD)
+ return wsp_ggml_arm_arch_features.has_dotprod;
+ #else
+ return 0;
+ #endif
+ }
+
  int wsp_ggml_cpu_has_sve(void) {
- #if defined(__ARM_ARCH)
+ #if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE)
  return wsp_ggml_arm_arch_features.has_sve;
  #else
  return 0;
@@ -13912,7 +14726,7 @@ int wsp_ggml_cpu_has_sve(void) {
  }
 
  int wsp_ggml_cpu_has_matmul_int8(void) {
- #if defined(__ARM_ARCH)
+ #if defined(__ARM_ARCH) && defined(__ARM_FEATURE_MATMUL_INT8)
  return wsp_ggml_arm_arch_features.has_i8mm;
  #else
  return 0;
@@ -13920,13 +14734,21 @@ int wsp_ggml_cpu_has_matmul_int8(void) {
  }
 
  int wsp_ggml_cpu_get_sve_cnt(void) {
- #if defined(__ARM_ARCH)
+ #if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE)
  return wsp_ggml_arm_arch_features.sve_cnt;
  #else
  return 0;
  #endif
  }
 
+ int wsp_ggml_cpu_has_sme(void) {
+ #if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME)
+ return wsp_ggml_arm_arch_features.has_sme;
+ #else
+ return 0;
+ #endif
+ }
+
  void wsp_ggml_cpu_init(void) {
  // needed to initialize f16 tables
  {
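The feature probes above (wsp_ggml_cpu_has_dotprod, wsp_ggml_cpu_has_sme, and the tightened NEON/SVE/i8mm guards) report support only when the feature was both compiled in and detected at runtime. A minimal sketch of that pattern, where cpu_has_sme and runtime_flag are hypothetical stand-ins for wsp_ggml_cpu_has_sme() and the cached wsp_ggml_arm_arch_features.has_sme flag:

    // Sketch only: combine a compile-time guard with a runtime-detected flag.
    int cpu_has_sme(int runtime_flag) {
    #if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME)
        return runtime_flag;     // compiled for SME: trust the runtime probe
    #else
        (void) runtime_flag;     // not compiled in: never report support
        return 0;
    #endif
    }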