whisper.rn 0.5.0-rc.9 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. package/cpp/ggml-alloc.c +1 -15
  2. package/cpp/ggml-backend-reg.cpp +17 -8
  3. package/cpp/ggml-backend.cpp +15 -22
  4. package/cpp/ggml-common.h +17 -0
  5. package/cpp/ggml-cpu/arch/arm/quants.c +132 -596
  6. package/cpp/ggml-cpu/arch/arm/repack.cpp +14 -286
  7. package/cpp/ggml-cpu/arch/x86/quants.c +184 -675
  8. package/cpp/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  9. package/cpp/ggml-cpu/arch-fallback.h +34 -0
  10. package/cpp/ggml-cpu/ggml-cpu.c +22 -1
  11. package/cpp/ggml-cpu/ggml-cpu.cpp +21 -24
  12. package/cpp/ggml-cpu/ops.cpp +870 -211
  13. package/cpp/ggml-cpu/ops.h +3 -8
  14. package/cpp/ggml-cpu/quants.c +35 -0
  15. package/cpp/ggml-cpu/quants.h +8 -0
  16. package/cpp/ggml-cpu/repack.cpp +458 -47
  17. package/cpp/ggml-cpu/repack.h +22 -0
  18. package/cpp/ggml-cpu/simd-mappings.h +1 -1
  19. package/cpp/ggml-cpu/traits.cpp +2 -2
  20. package/cpp/ggml-cpu/traits.h +1 -1
  21. package/cpp/ggml-cpu/vec.cpp +12 -9
  22. package/cpp/ggml-cpu/vec.h +107 -13
  23. package/cpp/ggml-impl.h +77 -0
  24. package/cpp/ggml-metal-impl.h +51 -12
  25. package/cpp/ggml-metal.m +610 -115
  26. package/cpp/ggml-opt.cpp +97 -41
  27. package/cpp/ggml-opt.h +25 -6
  28. package/cpp/ggml-quants.c +110 -16
  29. package/cpp/ggml-quants.h +6 -0
  30. package/cpp/ggml-whisper-sim.metallib +0 -0
  31. package/cpp/ggml-whisper.metallib +0 -0
  32. package/cpp/ggml.c +314 -88
  33. package/cpp/ggml.h +137 -11
  34. package/cpp/gguf.cpp +8 -1
  35. package/cpp/jsi/RNWhisperJSI.cpp +23 -6
  36. package/cpp/whisper.cpp +15 -6
  37. package/ios/RNWhisper.mm +6 -6
  38. package/ios/RNWhisperContext.mm +2 -0
  39. package/ios/RNWhisperVadContext.mm +2 -0
  40. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  44. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  45. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
  46. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  47. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  48. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
  49. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  53. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
  54. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  55. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  56. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
  57. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
  58. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
  59. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  61. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
  62. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  63. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  64. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
  65. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
  66. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
  67. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  68. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
  70. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  71. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  72. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +13 -0
  73. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  74. package/lib/module/realtime-transcription/RealtimeTranscriber.js +13 -0
  75. package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  76. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
  77. package/lib/typescript/realtime-transcription/types.d.ts +6 -0
  78. package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
  79. package/package.json +1 -1
  80. package/src/realtime-transcription/RealtimeTranscriber.ts +17 -0
  81. package/src/realtime-transcription/types.ts +6 -0
@@ -8,6 +8,7 @@
8
8
  #include "vec.h"
9
9
 
10
10
  #include <float.h>
11
+ #include <algorithm>
11
12
 
12
13
  // wsp_ggml_compute_forward_dup
13
14
 
@@ -1283,6 +1284,7 @@ void wsp_ggml_compute_forward_add(
1283
1284
  case WSP_GGML_TYPE_Q5_0:
1284
1285
  case WSP_GGML_TYPE_Q5_1:
1285
1286
  case WSP_GGML_TYPE_Q8_0:
1287
+ case WSP_GGML_TYPE_MXFP4:
1286
1288
  case WSP_GGML_TYPE_Q2_K:
1287
1289
  case WSP_GGML_TYPE_Q3_K:
1288
1290
  case WSP_GGML_TYPE_Q4_K:
@@ -1309,6 +1311,77 @@ void wsp_ggml_compute_forward_add(
1309
1311
  }
1310
1312
  }
1311
1313
 
1314
+ // wsp_ggml_compute_forward_add_id
1315
+
1316
+ static void wsp_ggml_compute_forward_add_id_f32(
1317
+ const wsp_ggml_compute_params * params,
1318
+ wsp_ggml_tensor * dst) {
1319
+
1320
+ const wsp_ggml_tensor * src0 = dst->src[0];
1321
+ const wsp_ggml_tensor * src1 = dst->src[1];
1322
+ const wsp_ggml_tensor * src2 = dst->src[2];
1323
+
1324
+ WSP_GGML_ASSERT(dst->type == WSP_GGML_TYPE_F32);
1325
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
1326
+ WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
1327
+ WSP_GGML_ASSERT(src2->type == WSP_GGML_TYPE_I32);
1328
+
1329
+ WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));
1330
+ WSP_GGML_ASSERT(src1->nb[0] == sizeof(float));
1331
+
1332
+ const int ith = params->ith;
1333
+ const int nth = params->nth;
1334
+
1335
+ const int nr = wsp_ggml_nrows(src0);
1336
+
1337
+ WSP_GGML_TENSOR_TERNARY_OP_LOCALS
1338
+
1339
+ WSP_GGML_ASSERT( nb0 == sizeof(float));
1340
+ WSP_GGML_ASSERT(nb10 == sizeof(float));
1341
+
1342
+ // rows per thread
1343
+ const int dr = (nr + nth - 1)/nth;
1344
+
1345
+ // row range for this thread
1346
+ const int ir0 = dr*ith;
1347
+ const int ir1 = MIN(ir0 + dr, nr);
1348
+
1349
+ for (int ir = ir0; ir < ir1; ++ir) {
1350
+ // src0 indices
1351
+ const int i3 = ir/(ne2*ne1);
1352
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
1353
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
1354
+
1355
+ // src1 indices
1356
+ const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
1357
+
1358
+ WSP_GGML_ASSERT(i11 >= 0 && i11 < ne11);
1359
+
1360
+ wsp_ggml_vec_add_f32(ne0,
1361
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
1362
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
1363
+ (float *) ((char *) src1->data + i11*nb11));
1364
+ }
1365
+ }
1366
+
1367
+ void wsp_ggml_compute_forward_add_id(
1368
+ const wsp_ggml_compute_params * params,
1369
+ wsp_ggml_tensor * dst) {
1370
+
1371
+ const wsp_ggml_tensor * src0 = dst->src[0];
1372
+
1373
+ switch (src0->type) {
1374
+ case WSP_GGML_TYPE_F32:
1375
+ {
1376
+ wsp_ggml_compute_forward_add_id_f32(params, dst);
1377
+ } break;
1378
+ default:
1379
+ {
1380
+ WSP_GGML_ABORT("unsupported type for wsp_ggml_compute_forward_add_id: %s", wsp_ggml_type_name(src0->type));
1381
+ }
1382
+ }
1383
+ }
1384
+
1312
1385
  // wsp_ggml_compute_forward_add1
1313
1386
 
1314
1387
  static void wsp_ggml_compute_forward_add1_f32(
@@ -1660,6 +1733,7 @@ void wsp_ggml_compute_forward_add1(
1660
1733
  case WSP_GGML_TYPE_Q5_1:
1661
1734
  case WSP_GGML_TYPE_Q8_0:
1662
1735
  case WSP_GGML_TYPE_Q8_1:
1736
+ case WSP_GGML_TYPE_MXFP4:
1663
1737
  case WSP_GGML_TYPE_Q2_K:
1664
1738
  case WSP_GGML_TYPE_Q3_K:
1665
1739
  case WSP_GGML_TYPE_Q4_K:
@@ -1787,6 +1861,7 @@ void wsp_ggml_compute_forward_acc(
1787
1861
  case WSP_GGML_TYPE_Q5_1:
1788
1862
  case WSP_GGML_TYPE_Q8_0:
1789
1863
  case WSP_GGML_TYPE_Q8_1:
1864
+ case WSP_GGML_TYPE_MXFP4:
1790
1865
  case WSP_GGML_TYPE_Q2_K:
1791
1866
  case WSP_GGML_TYPE_Q3_K:
1792
1867
  case WSP_GGML_TYPE_Q4_K:
@@ -3185,11 +3260,356 @@ void wsp_ggml_compute_forward_silu_back(
3185
3260
  }
3186
3261
  }
3187
3262
 
3188
- // wsp_ggml_compute_forward_reglu
3189
-
3190
- static void wsp_ggml_compute_forward_reglu_f32(
3191
- const wsp_ggml_compute_params * params,
3192
- wsp_ggml_tensor * dst) {
3263
+ // wsp_ggml_compute_forward_reglu
3264
+
3265
+ static void wsp_ggml_compute_forward_reglu_f32(
3266
+ const wsp_ggml_compute_params * params,
3267
+ wsp_ggml_tensor * dst) {
3268
+
3269
+ const wsp_ggml_tensor * src0 = dst->src[0];
3270
+ const wsp_ggml_tensor * src1 = dst->src[1];
3271
+ char * src0_d = (char *) src0->data;
3272
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3273
+ const size_t src0_o = src0->nb[1];
3274
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3275
+
3276
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
3277
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(dst));
3278
+
3279
+ if (src1) {
3280
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src1));
3281
+ WSP_GGML_ASSERT(src0->type == src1->type);
3282
+ }
3283
+
3284
+ const int ith = params->ith;
3285
+ const int nth = params->nth;
3286
+
3287
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3288
+ const int nr = wsp_ggml_nrows(src0);
3289
+
3290
+ WSP_GGML_ASSERT(dst->ne[0] == nc);
3291
+ WSP_GGML_ASSERT(wsp_ggml_nrows(dst) == nr);
3292
+
3293
+ const int32_t swapped = wsp_ggml_get_op_params_i32(dst, 1);
3294
+
3295
+ // rows per thread
3296
+ const int dr = (nr + nth - 1)/nth;
3297
+
3298
+ // row range for this thread
3299
+ const int ir0 = dr*ith;
3300
+ const int ir1 = MIN(ir0 + dr, nr);
3301
+
3302
+ for (int i1 = ir0; i1 < ir1; i1++) {
3303
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3304
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3305
+
3306
+ if (!src1) {
3307
+ src0_p += swapped ? nc : 0;
3308
+ src1_p += swapped ? 0 : nc;
3309
+ }
3310
+
3311
+ wsp_ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3312
+
3313
+ #ifndef NDEBUG
3314
+ for (int k = 0; k < nc; k++) {
3315
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3316
+ WSP_GGML_UNUSED(x);
3317
+ assert(!isnan(x));
3318
+ assert(!isinf(x));
3319
+ }
3320
+ #endif
3321
+ }
3322
+ }
3323
+
3324
+ static void wsp_ggml_compute_forward_reglu_f16(
3325
+ const wsp_ggml_compute_params * params,
3326
+ wsp_ggml_tensor * dst) {
3327
+
3328
+ const wsp_ggml_tensor * src0 = dst->src[0];
3329
+ const wsp_ggml_tensor * src1 = dst->src[1];
3330
+ char * src0_d = (char *) src0->data;
3331
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3332
+ const size_t src0_o = src0->nb[1];
3333
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3334
+
3335
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
3336
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(dst));
3337
+
3338
+ if (src1) {
3339
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src1));
3340
+ WSP_GGML_ASSERT(src0->type == src1->type);
3341
+ }
3342
+
3343
+ const int ith = params->ith;
3344
+ const int nth = params->nth;
3345
+
3346
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3347
+ const int nr = wsp_ggml_nrows(src0);
3348
+
3349
+ WSP_GGML_ASSERT(dst->ne[0] == nc);
3350
+ WSP_GGML_ASSERT(wsp_ggml_nrows(dst) == nr);
3351
+
3352
+ const int32_t swapped = wsp_ggml_get_op_params_i32(dst, 1);
3353
+
3354
+ // rows per thread
3355
+ const int dr = (nr + nth - 1)/nth;
3356
+
3357
+ // row range for this thread
3358
+ const int ir0 = dr*ith;
3359
+ const int ir1 = MIN(ir0 + dr, nr);
3360
+
3361
+ for (int i1 = ir0; i1 < ir1; i1++) {
3362
+ wsp_ggml_fp16_t * src0_p = (wsp_ggml_fp16_t *) (src0_d + i1*src0_o);
3363
+ wsp_ggml_fp16_t * src1_p = (wsp_ggml_fp16_t *) (src1_d + i1*src1_o);
3364
+
3365
+ if (!src1) {
3366
+ src0_p += swapped ? nc : 0;
3367
+ src1_p += swapped ? 0 : nc;
3368
+ }
3369
+
3370
+ wsp_ggml_vec_reglu_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3371
+
3372
+ #ifndef NDEBUG
3373
+ for (int k = 0; k < nc; k++) {
3374
+ const wsp_ggml_fp16_t x = ((wsp_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3375
+ const float v = WSP_GGML_FP16_TO_FP32(x);
3376
+ WSP_GGML_UNUSED(v);
3377
+ assert(!isnan(v));
3378
+ assert(!isinf(v));
3379
+ }
3380
+ #endif
3381
+ }
3382
+ }
3383
+
3384
+ static void wsp_ggml_compute_forward_reglu(
3385
+ const wsp_ggml_compute_params * params,
3386
+ wsp_ggml_tensor * dst) {
3387
+
3388
+ const wsp_ggml_tensor * src0 = dst->src[0];
3389
+
3390
+ switch (src0->type) {
3391
+ case WSP_GGML_TYPE_F32:
3392
+ {
3393
+ wsp_ggml_compute_forward_reglu_f32(params, dst);
3394
+ } break;
3395
+ case WSP_GGML_TYPE_F16:
3396
+ {
3397
+ wsp_ggml_compute_forward_reglu_f16(params, dst);
3398
+ } break;
3399
+ default:
3400
+ {
3401
+ WSP_GGML_ABORT("fatal error");
3402
+ }
3403
+ }
3404
+ }
3405
+
3406
+ // wsp_ggml_compute_forward_geglu
3407
+
3408
+ static void wsp_ggml_compute_forward_geglu_f32(
3409
+ const wsp_ggml_compute_params * params,
3410
+ wsp_ggml_tensor * dst) {
3411
+
3412
+ const wsp_ggml_tensor * src0 = dst->src[0];
3413
+ const wsp_ggml_tensor * src1 = dst->src[1];
3414
+ char * src0_d = (char *) src0->data;
3415
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3416
+ const size_t src0_o = src0->nb[1];
3417
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3418
+
3419
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
3420
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(dst));
3421
+
3422
+ if (src1) {
3423
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src1));
3424
+ WSP_GGML_ASSERT(src0->type == src1->type);
3425
+ }
3426
+
3427
+ const int ith = params->ith;
3428
+ const int nth = params->nth;
3429
+
3430
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3431
+ const int nr = wsp_ggml_nrows(src0);
3432
+
3433
+ WSP_GGML_ASSERT(dst->ne[0] == nc);
3434
+ WSP_GGML_ASSERT(wsp_ggml_nrows(dst) == nr);
3435
+
3436
+ const int32_t swapped = wsp_ggml_get_op_params_i32(dst, 1);
3437
+
3438
+ // rows per thread
3439
+ const int dr = (nr + nth - 1)/nth;
3440
+
3441
+ // row range for this thread
3442
+ const int ir0 = dr*ith;
3443
+ const int ir1 = MIN(ir0 + dr, nr);
3444
+
3445
+ for (int i1 = ir0; i1 < ir1; i1++) {
3446
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3447
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3448
+
3449
+ if (!src1) {
3450
+ src0_p += swapped ? nc : 0;
3451
+ src1_p += swapped ? 0 : nc;
3452
+ }
3453
+
3454
+ wsp_ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3455
+
3456
+ #ifndef NDEBUG
3457
+ for (int k = 0; k < nc; k++) {
3458
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3459
+ WSP_GGML_UNUSED(x);
3460
+ assert(!isnan(x));
3461
+ assert(!isinf(x));
3462
+ }
3463
+ #endif
3464
+ }
3465
+ }
3466
+
3467
+ static void wsp_ggml_compute_forward_geglu_f16(
3468
+ const wsp_ggml_compute_params * params,
3469
+ wsp_ggml_tensor * dst) {
3470
+
3471
+ const wsp_ggml_tensor * src0 = dst->src[0];
3472
+ const wsp_ggml_tensor * src1 = dst->src[1];
3473
+ char * src0_d = (char *) src0->data;
3474
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3475
+ const size_t src0_o = src0->nb[1];
3476
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3477
+
3478
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
3479
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(dst));
3480
+
3481
+ if (src1) {
3482
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src1));
3483
+ WSP_GGML_ASSERT(src0->type == src1->type);
3484
+ }
3485
+
3486
+ const int ith = params->ith;
3487
+ const int nth = params->nth;
3488
+
3489
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3490
+ const int nr = wsp_ggml_nrows(src0);
3491
+
3492
+ WSP_GGML_ASSERT(dst->ne[0] == nc);
3493
+ WSP_GGML_ASSERT(wsp_ggml_nrows(dst) == nr);
3494
+
3495
+ const int32_t swapped = wsp_ggml_get_op_params_i32(dst, 1);
3496
+
3497
+ // rows per thread
3498
+ const int dr = (nr + nth - 1)/nth;
3499
+
3500
+ // row range for this thread
3501
+ const int ir0 = dr*ith;
3502
+ const int ir1 = MIN(ir0 + dr, nr);
3503
+
3504
+ for (int i1 = ir0; i1 < ir1; i1++) {
3505
+ wsp_ggml_fp16_t * src0_p = (wsp_ggml_fp16_t *) (src0_d + i1*src0_o);
3506
+ wsp_ggml_fp16_t * src1_p = (wsp_ggml_fp16_t *) (src1_d + i1*src1_o);
3507
+
3508
+ if (!src1) {
3509
+ src0_p += swapped ? nc : 0;
3510
+ src1_p += swapped ? 0 : nc;
3511
+ }
3512
+
3513
+ wsp_ggml_vec_geglu_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3514
+
3515
+ #ifndef NDEBUG
3516
+ for (int k = 0; k < nc; k++) {
3517
+ const wsp_ggml_fp16_t x = ((wsp_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3518
+ const float v = WSP_GGML_FP16_TO_FP32(x);
3519
+ WSP_GGML_UNUSED(v);
3520
+ assert(!isnan(v));
3521
+ assert(!isinf(v));
3522
+ }
3523
+ #endif
3524
+ }
3525
+ }
3526
+
3527
+ static void wsp_ggml_compute_forward_geglu(
3528
+ const wsp_ggml_compute_params * params,
3529
+ wsp_ggml_tensor * dst) {
3530
+
3531
+ const wsp_ggml_tensor * src0 = dst->src[0];
3532
+
3533
+ switch (src0->type) {
3534
+ case WSP_GGML_TYPE_F32:
3535
+ {
3536
+ wsp_ggml_compute_forward_geglu_f32(params, dst);
3537
+ } break;
3538
+ case WSP_GGML_TYPE_F16:
3539
+ {
3540
+ wsp_ggml_compute_forward_geglu_f16(params, dst);
3541
+ } break;
3542
+ default:
3543
+ {
3544
+ WSP_GGML_ABORT("fatal error");
3545
+ }
3546
+ }
3547
+ }
3548
+
3549
+ // wsp_ggml_compute_forward_swiglu
3550
+
3551
+ static void wsp_ggml_compute_forward_swiglu_f32(
3552
+ const wsp_ggml_compute_params * params,
3553
+ wsp_ggml_tensor * dst) {
3554
+
3555
+ const wsp_ggml_tensor * src0 = dst->src[0];
3556
+ const wsp_ggml_tensor * src1 = dst->src[1];
3557
+ char * src0_d = (char *) src0->data;
3558
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3559
+ const size_t src0_o = src0->nb[1];
3560
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3561
+
3562
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
3563
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(dst));
3564
+
3565
+ if (src1) {
3566
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src1));
3567
+ WSP_GGML_ASSERT(src0->type == src1->type);
3568
+ }
3569
+
3570
+ const int ith = params->ith;
3571
+ const int nth = params->nth;
3572
+
3573
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3574
+ const int nr = wsp_ggml_nrows(src0);
3575
+
3576
+ WSP_GGML_ASSERT(dst->ne[0] == nc);
3577
+ WSP_GGML_ASSERT(wsp_ggml_nrows(dst) == nr);
3578
+
3579
+ const int32_t swapped = wsp_ggml_get_op_params_i32(dst, 1);
3580
+
3581
+ // rows per thread
3582
+ const int dr = (nr + nth - 1)/nth;
3583
+
3584
+ // row range for this thread
3585
+ const int ir0 = dr*ith;
3586
+ const int ir1 = MIN(ir0 + dr, nr);
3587
+
3588
+ for (int i1 = ir0; i1 < ir1; i1++) {
3589
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3590
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3591
+
3592
+ if (!src1) {
3593
+ src0_p += swapped ? nc : 0;
3594
+ src1_p += swapped ? 0 : nc;
3595
+ }
3596
+
3597
+ wsp_ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3598
+
3599
+ #ifndef NDEBUG
3600
+ for (int k = 0; k < nc; k++) {
3601
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3602
+ WSP_GGML_UNUSED(x);
3603
+ assert(!isnan(x));
3604
+ assert(!isinf(x));
3605
+ }
3606
+ #endif
3607
+ }
3608
+ }
3609
+
3610
+ static void wsp_ggml_compute_forward_swiglu_f16(
3611
+ const wsp_ggml_compute_params * params,
3612
+ wsp_ggml_tensor * dst) {
3193
3613
 
3194
3614
  const wsp_ggml_tensor * src0 = dst->src[0];
3195
3615
  const wsp_ggml_tensor * src1 = dst->src[1];
@@ -3225,30 +3645,55 @@ static void wsp_ggml_compute_forward_reglu_f32(
3225
3645
  const int ir1 = MIN(ir0 + dr, nr);
3226
3646
 
3227
3647
  for (int i1 = ir0; i1 < ir1; i1++) {
3228
- float * src0_p = (float *) (src0_d + i1*src0_o);
3229
- float * src1_p = (float *) (src1_d + i1*src1_o);
3648
+ wsp_ggml_fp16_t * src0_p = (wsp_ggml_fp16_t *) (src0_d + i1*src0_o);
3649
+ wsp_ggml_fp16_t * src1_p = (wsp_ggml_fp16_t *) (src1_d + i1*src1_o);
3230
3650
 
3231
3651
  if (!src1) {
3232
3652
  src0_p += swapped ? nc : 0;
3233
3653
  src1_p += swapped ? 0 : nc;
3234
3654
  }
3235
3655
 
3236
- wsp_ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3656
+ wsp_ggml_vec_swiglu_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3237
3657
 
3238
3658
  #ifndef NDEBUG
3239
3659
  for (int k = 0; k < nc; k++) {
3240
- const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3241
- WSP_GGML_UNUSED(x);
3242
- assert(!isnan(x));
3243
- assert(!isinf(x));
3660
+ const wsp_ggml_fp16_t x = ((wsp_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3661
+ const float v = WSP_GGML_FP16_TO_FP32(x);
3662
+ WSP_GGML_UNUSED(v);
3663
+ assert(!isnan(v));
3664
+ assert(!isinf(v));
3244
3665
  }
3245
3666
  #endif
3246
3667
  }
3247
3668
  }
3248
3669
 
3249
- static void wsp_ggml_compute_forward_reglu_f16(
3250
- const wsp_ggml_compute_params * params,
3251
- wsp_ggml_tensor * dst) {
3670
+ static void wsp_ggml_compute_forward_swiglu(
3671
+ const wsp_ggml_compute_params * params,
3672
+ wsp_ggml_tensor * dst) {
3673
+
3674
+ const wsp_ggml_tensor * src0 = dst->src[0];
3675
+
3676
+ switch (src0->type) {
3677
+ case WSP_GGML_TYPE_F32:
3678
+ {
3679
+ wsp_ggml_compute_forward_swiglu_f32(params, dst);
3680
+ } break;
3681
+ case WSP_GGML_TYPE_F16:
3682
+ {
3683
+ wsp_ggml_compute_forward_swiglu_f16(params, dst);
3684
+ } break;
3685
+ default:
3686
+ {
3687
+ WSP_GGML_ABORT("fatal error");
3688
+ }
3689
+ }
3690
+ }
3691
+
3692
+ // wsp_ggml_compute_forward_swiglu_oai
3693
+
3694
+ static void wsp_ggml_compute_forward_swiglu_oai_f32(
3695
+ const wsp_ggml_compute_params * params,
3696
+ wsp_ggml_tensor * dst) {
3252
3697
 
3253
3698
  const wsp_ggml_tensor * src0 = dst->src[0];
3254
3699
  const wsp_ggml_tensor * src1 = dst->src[1];
@@ -3275,6 +3720,8 @@ static void wsp_ggml_compute_forward_reglu_f16(
3275
3720
  WSP_GGML_ASSERT(wsp_ggml_nrows(dst) == nr);
3276
3721
 
3277
3722
  const int32_t swapped = wsp_ggml_get_op_params_i32(dst, 1);
3723
+ const float alpha = wsp_ggml_get_op_params_f32(dst, 2);
3724
+ const float limit = wsp_ggml_get_op_params_f32(dst, 3);
3278
3725
 
3279
3726
  // rows per thread
3280
3727
  const int dr = (nr + nth - 1)/nth;
@@ -3284,29 +3731,34 @@ static void wsp_ggml_compute_forward_reglu_f16(
3284
3731
  const int ir1 = MIN(ir0 + dr, nr);
3285
3732
 
3286
3733
  for (int i1 = ir0; i1 < ir1; i1++) {
3287
- wsp_ggml_fp16_t * src0_p = (wsp_ggml_fp16_t *) (src0_d + i1*src0_o);
3288
- wsp_ggml_fp16_t * src1_p = (wsp_ggml_fp16_t *) (src1_d + i1*src1_o);
3734
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3735
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3736
+ float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1]));
3289
3737
 
3290
3738
  if (!src1) {
3291
3739
  src0_p += swapped ? nc : 0;
3292
3740
  src1_p += swapped ? 0 : nc;
3293
3741
  }
3294
3742
 
3295
- wsp_ggml_vec_reglu_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3743
+ for (int k = 0; k < nc; k++) {
3744
+ const float x = std::min(src0_p[k], limit);
3745
+ const float y = std::clamp(src1_p[k], -limit, limit);
3746
+ const float out_glu = x / (1.f + expf(alpha * (-x)));
3747
+ dst_p[k] = out_glu * (y + 1.f);
3748
+ }
3296
3749
 
3297
3750
  #ifndef NDEBUG
3298
3751
  for (int k = 0; k < nc; k++) {
3299
- const wsp_ggml_fp16_t x = ((wsp_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3300
- const float v = WSP_GGML_FP16_TO_FP32(x);
3301
- WSP_GGML_UNUSED(v);
3302
- assert(!isnan(v));
3303
- assert(!isinf(v));
3752
+ const float x = dst_p[k];
3753
+ WSP_GGML_UNUSED(x);
3754
+ assert(!isnan(x));
3755
+ assert(!isinf(x));
3304
3756
  }
3305
3757
  #endif
3306
3758
  }
3307
3759
  }
3308
3760
 
3309
- static void wsp_ggml_compute_forward_reglu(
3761
+ static void wsp_ggml_compute_forward_swiglu_oai(
3310
3762
  const wsp_ggml_compute_params * params,
3311
3763
  wsp_ggml_tensor * dst) {
3312
3764
 
@@ -3315,11 +3767,7 @@ static void wsp_ggml_compute_forward_reglu(
3315
3767
  switch (src0->type) {
3316
3768
  case WSP_GGML_TYPE_F32:
3317
3769
  {
3318
- wsp_ggml_compute_forward_reglu_f32(params, dst);
3319
- } break;
3320
- case WSP_GGML_TYPE_F16:
3321
- {
3322
- wsp_ggml_compute_forward_reglu_f16(params, dst);
3770
+ wsp_ggml_compute_forward_swiglu_oai_f32(params, dst);
3323
3771
  } break;
3324
3772
  default:
3325
3773
  {
@@ -3328,9 +3776,9 @@ static void wsp_ggml_compute_forward_reglu(
3328
3776
  }
3329
3777
  }
3330
3778
 
3331
- // wsp_ggml_compute_forward_geglu
3779
+ // wsp_ggml_compute_forward_geglu_erf
3332
3780
 
3333
- static void wsp_ggml_compute_forward_geglu_f32(
3781
+ static void wsp_ggml_compute_forward_geglu_erf_f32(
3334
3782
  const wsp_ggml_compute_params * params,
3335
3783
  wsp_ggml_tensor * dst) {
3336
3784
 
@@ -3376,7 +3824,7 @@ static void wsp_ggml_compute_forward_geglu_f32(
3376
3824
  src1_p += swapped ? 0 : nc;
3377
3825
  }
3378
3826
 
3379
- wsp_ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3827
+ wsp_ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3380
3828
 
3381
3829
  #ifndef NDEBUG
3382
3830
  for (int k = 0; k < nc; k++) {
@@ -3389,7 +3837,7 @@ static void wsp_ggml_compute_forward_geglu_f32(
3389
3837
  }
3390
3838
  }
3391
3839
 
3392
- static void wsp_ggml_compute_forward_geglu_f16(
3840
+ static void wsp_ggml_compute_forward_geglu_erf_f16(
3393
3841
  const wsp_ggml_compute_params * params,
3394
3842
  wsp_ggml_tensor * dst) {
3395
3843
 
@@ -3435,7 +3883,7 @@ static void wsp_ggml_compute_forward_geglu_f16(
3435
3883
  src1_p += swapped ? 0 : nc;
3436
3884
  }
3437
3885
 
3438
- wsp_ggml_vec_geglu_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3886
+ wsp_ggml_vec_geglu_erf_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3439
3887
 
3440
3888
  #ifndef NDEBUG
3441
3889
  for (int k = 0; k < nc; k++) {
@@ -3449,7 +3897,7 @@ static void wsp_ggml_compute_forward_geglu_f16(
3449
3897
  }
3450
3898
  }
3451
3899
 
3452
- static void wsp_ggml_compute_forward_geglu(
3900
+ static void wsp_ggml_compute_forward_geglu_erf(
3453
3901
  const wsp_ggml_compute_params * params,
3454
3902
  wsp_ggml_tensor * dst) {
3455
3903
 
@@ -3458,11 +3906,11 @@ static void wsp_ggml_compute_forward_geglu(
3458
3906
  switch (src0->type) {
3459
3907
  case WSP_GGML_TYPE_F32:
3460
3908
  {
3461
- wsp_ggml_compute_forward_geglu_f32(params, dst);
3909
+ wsp_ggml_compute_forward_geglu_erf_f32(params, dst);
3462
3910
  } break;
3463
3911
  case WSP_GGML_TYPE_F16:
3464
3912
  {
3465
- wsp_ggml_compute_forward_geglu_f16(params, dst);
3913
+ wsp_ggml_compute_forward_geglu_erf_f16(params, dst);
3466
3914
  } break;
3467
3915
  default:
3468
3916
  {
@@ -3471,9 +3919,9 @@ static void wsp_ggml_compute_forward_geglu(
3471
3919
  }
3472
3920
  }
3473
3921
 
3474
- // wsp_ggml_compute_forward_swiglu
3922
+ // wsp_ggml_compute_forward_geglu_quick
3475
3923
 
3476
- static void wsp_ggml_compute_forward_swiglu_f32(
3924
+ static void wsp_ggml_compute_forward_geglu_quick_f32(
3477
3925
  const wsp_ggml_compute_params * params,
3478
3926
  wsp_ggml_tensor * dst) {
3479
3927
 
@@ -3519,7 +3967,7 @@ static void wsp_ggml_compute_forward_swiglu_f32(
3519
3967
  src1_p += swapped ? 0 : nc;
3520
3968
  }
3521
3969
 
3522
- wsp_ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3970
+ wsp_ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3523
3971
 
3524
3972
  #ifndef NDEBUG
3525
3973
  for (int k = 0; k < nc; k++) {
@@ -3532,7 +3980,7 @@ static void wsp_ggml_compute_forward_swiglu_f32(
3532
3980
  }
3533
3981
  }
3534
3982
 
3535
- static void wsp_ggml_compute_forward_swiglu_f16(
3983
+ static void wsp_ggml_compute_forward_geglu_quick_f16(
3536
3984
  const wsp_ggml_compute_params * params,
3537
3985
  wsp_ggml_tensor * dst) {
3538
3986
 
@@ -3578,7 +4026,7 @@ static void wsp_ggml_compute_forward_swiglu_f16(
3578
4026
  src1_p += swapped ? 0 : nc;
3579
4027
  }
3580
4028
 
3581
- wsp_ggml_vec_swiglu_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
4029
+ wsp_ggml_vec_geglu_quick_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3582
4030
 
3583
4031
  #ifndef NDEBUG
3584
4032
  for (int k = 0; k < nc; k++) {
@@ -3592,7 +4040,7 @@ static void wsp_ggml_compute_forward_swiglu_f16(
3592
4040
  }
3593
4041
  }
3594
4042
 
3595
- static void wsp_ggml_compute_forward_swiglu(
4043
+ static void wsp_ggml_compute_forward_geglu_quick(
3596
4044
  const wsp_ggml_compute_params * params,
3597
4045
  wsp_ggml_tensor * dst) {
3598
4046
 
@@ -3601,11 +4049,11 @@ static void wsp_ggml_compute_forward_swiglu(
3601
4049
  switch (src0->type) {
3602
4050
  case WSP_GGML_TYPE_F32:
3603
4051
  {
3604
- wsp_ggml_compute_forward_swiglu_f32(params, dst);
4052
+ wsp_ggml_compute_forward_geglu_quick_f32(params, dst);
3605
4053
  } break;
3606
4054
  case WSP_GGML_TYPE_F16:
3607
4055
  {
3608
- wsp_ggml_compute_forward_swiglu_f16(params, dst);
4056
+ wsp_ggml_compute_forward_geglu_quick_f16(params, dst);
3609
4057
  } break;
3610
4058
  default:
3611
4059
  {
@@ -3729,6 +4177,9 @@ static void wsp_ggml_compute_forward_rms_norm_f32(
3729
4177
 
3730
4178
  const float scale = 1.0f/sqrtf(mean + eps);
3731
4179
 
4180
+ // if you hit this, likely you got an inf somewhere earlier
4181
+ assert(scale > 0.0f);
4182
+
3732
4183
  wsp_ggml_vec_scale_f32(ne00, y, scale);
3733
4184
  }
3734
4185
  }
@@ -4310,6 +4761,7 @@ void wsp_ggml_compute_forward_out_prod(
4310
4761
  case WSP_GGML_TYPE_Q5_0:
4311
4762
  case WSP_GGML_TYPE_Q5_1:
4312
4763
  case WSP_GGML_TYPE_Q8_0:
4764
+ case WSP_GGML_TYPE_MXFP4:
4313
4765
  case WSP_GGML_TYPE_Q2_K:
4314
4766
  case WSP_GGML_TYPE_Q3_K:
4315
4767
  case WSP_GGML_TYPE_Q4_K:
@@ -4357,9 +4809,11 @@ static void wsp_ggml_compute_forward_scale_f32(
4357
4809
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst));
4358
4810
  WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst));
4359
4811
 
4360
- // scale factor
4361
- float v;
4362
- memcpy(&v, dst->op_params, sizeof(float));
4812
+ float s; // scale factor
4813
+ float b; // bias
4814
+
4815
+ memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
4816
+ memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
4363
4817
 
4364
4818
  const int ith = params->ith;
4365
4819
  const int nth = params->nth;
@@ -4378,12 +4832,22 @@ static void wsp_ggml_compute_forward_scale_f32(
4378
4832
 
4379
4833
  const size_t nb1 = dst->nb[1];
4380
4834
 
4381
- for (int i1 = ir0; i1 < ir1; i1++) {
4382
- if (dst->data != src0->data) {
4383
- // src0 is same shape as dst => same indices
4384
- memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
4835
+ if (b == 0.0f) {
4836
+ for (int i1 = ir0; i1 < ir1; i1++) {
4837
+ if (dst->data != src0->data) {
4838
+ // src0 is same shape as dst => same indices
4839
+ // TODO: add x parameter to wsp_ggml_vec_scale_f32 and remove this memcpy
4840
+ memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
4841
+ }
4842
+ wsp_ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
4843
+ }
4844
+ } else {
4845
+ for (int i1 = ir0; i1 < ir1; i1++) {
4846
+ wsp_ggml_vec_mad1_f32(nc,
4847
+ (float *) ((char *) dst->data + i1*nb1),
4848
+ (float *) ((char *) src0->data + i1*nb1),
4849
+ s, b);
4385
4850
  }
4386
- wsp_ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
4387
4851
  }
4388
4852
  }
4389
4853
 
@@ -4572,6 +5036,7 @@ void wsp_ggml_compute_forward_set(
4572
5036
  case WSP_GGML_TYPE_Q5_1:
4573
5037
  case WSP_GGML_TYPE_Q8_0:
4574
5038
  case WSP_GGML_TYPE_Q8_1:
5039
+ case WSP_GGML_TYPE_MXFP4:
4575
5040
  case WSP_GGML_TYPE_Q2_K:
4576
5041
  case WSP_GGML_TYPE_Q3_K:
4577
5042
  case WSP_GGML_TYPE_Q4_K:
@@ -4833,6 +5298,7 @@ void wsp_ggml_compute_forward_get_rows(
4833
5298
  case WSP_GGML_TYPE_Q5_1:
4834
5299
  case WSP_GGML_TYPE_Q8_0:
4835
5300
  case WSP_GGML_TYPE_Q8_1:
5301
+ case WSP_GGML_TYPE_MXFP4:
4836
5302
  case WSP_GGML_TYPE_Q2_K:
4837
5303
  case WSP_GGML_TYPE_Q3_K:
4838
5304
  case WSP_GGML_TYPE_Q4_K:
@@ -5222,6 +5688,7 @@ static void wsp_ggml_compute_forward_soft_max_f32(
5222
5688
 
5223
5689
  const wsp_ggml_tensor * src0 = dst->src[0];
5224
5690
  const wsp_ggml_tensor * src1 = dst->src[1];
5691
+ const wsp_ggml_tensor * src2 = dst->src[2];
5225
5692
 
5226
5693
  assert(wsp_ggml_is_contiguous(dst));
5227
5694
  assert(wsp_ggml_are_same_shape(src0, dst));
@@ -5232,14 +5699,17 @@ static void wsp_ggml_compute_forward_soft_max_f32(
5232
5699
  memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
5233
5700
  memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
5234
5701
 
5235
- // TODO: handle transposed/permuted matrices
5236
-
5237
5702
  const int ith = params->ith;
5238
5703
  const int nth = params->nth;
5239
5704
 
5240
5705
  WSP_GGML_TENSOR_UNARY_OP_LOCALS
5241
5706
 
5242
- //const int64_t ne11 = src1 ? src1->ne[1] : 1;
5707
+ const int64_t nb11 = src1 ? src1->nb[1] : 1;
5708
+ const int64_t nb12 = src1 ? src1->nb[2] : 1;
5709
+ const int64_t nb13 = src1 ? src1->nb[3] : 1;
5710
+
5711
+ const int64_t ne12 = src1 ? src1->ne[2] : 1;
5712
+ const int64_t ne13 = src1 ? src1->ne[3] : 1;
5243
5713
 
5244
5714
  // TODO: is this supposed to be ceil instead of floor?
5245
5715
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -5249,68 +5719,78 @@ static void wsp_ggml_compute_forward_soft_max_f32(
5249
5719
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
5250
5720
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
5251
5721
 
5252
- const int nc = src0->ne[0];
5253
- const int nr = wsp_ggml_nrows(src0);
5254
-
5255
- // rows per thread
5256
- const int dr = (nr + nth - 1)/nth;
5257
-
5258
- // row range for this thread
5259
- const int ir0 = dr*ith;
5260
- const int ir1 = MIN(ir0 + dr, nr);
5261
-
5262
- float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
5722
+ float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
5263
5723
 
5264
5724
  const bool use_f16 = (src1 && src1->type == WSP_GGML_TYPE_F16);
5265
5725
 
5266
- for (int i1 = ir0; i1 < ir1; i1++) {
5267
- // ALiBi
5268
- const uint32_t h = (i1/ne01)%ne02; // head
5269
- const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
5270
-
5271
- float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
5272
- float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
5726
+ // sinks
5727
+ const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
5273
5728
 
5274
- // broadcast the mask across rows
5275
- wsp_ggml_fp16_t * mp_f16 = src1 ? (wsp_ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
5276
- float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
5277
-
5278
- wsp_ggml_vec_cpy_f32 (nc, wp, sp);
5279
- wsp_ggml_vec_scale_f32(nc, wp, scale);
5280
- if (mp_f32) {
5281
- if (use_f16) {
5282
- for (int i = 0; i < nc; ++i) {
5283
- wp[i] += slope*WSP_GGML_CPU_FP16_TO_FP32(mp_f16[i]);
5284
- }
5285
- } else {
5286
- for (int i = 0; i < nc; ++i) {
5287
- wp[i] += slope*mp_f32[i];
5729
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5730
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5731
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
5732
+ const int64_t i11 = i01;
5733
+ const int64_t i12 = i02%ne12;
5734
+ const int64_t i13 = i03%ne13;
5735
+
5736
+ // ALiBi
5737
+ const uint32_t h = i02; // head
5738
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
5739
+
5740
+ float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5741
+ float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
5742
+
5743
+ // broadcast the mask across rows
5744
+ wsp_ggml_fp16_t * mp_f16 = src1 ? (wsp_ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
5745
+ float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
5746
+
5747
+ wsp_ggml_vec_cpy_f32 (ne00, wp, sp);
5748
+ wsp_ggml_vec_scale_f32(ne00, wp, scale);
5749
+ if (mp_f32) {
5750
+ if (use_f16) {
5751
+ for (int i = 0; i < ne00; ++i) {
5752
+ wp[i] += slope*WSP_GGML_CPU_FP16_TO_FP32(mp_f16[i]);
5753
+ }
5754
+ } else {
5755
+ for (int i = 0; i < ne00; ++i) {
5756
+ wp[i] += slope*mp_f32[i];
5757
+ }
5758
+ }
5288
5759
  }
5289
- }
5290
- }
5291
5760
 
5292
5761
  #ifndef NDEBUG
5293
- for (int i = 0; i < nc; ++i) {
5294
- //printf("p[%d] = %f\n", i, p[i]);
5295
- assert(!isnan(wp[i]));
5296
- }
5762
+ for (int i = 0; i < ne00; ++i) {
5763
+ //printf("p[%d] = %f\n", i, p[i]);
5764
+ assert(!isnan(wp[i]));
5765
+ }
5297
5766
  #endif
5298
5767
 
5299
- float max = -INFINITY;
5300
- wsp_ggml_vec_max_f32(nc, &max, wp);
5768
+ float max = -INFINITY;
5769
+ wsp_ggml_vec_max_f32(ne00, &max, wp);
5301
5770
 
5302
- wsp_ggml_float sum = wsp_ggml_vec_soft_max_f32(nc, dp, wp, max);
5303
- assert(sum > 0.0);
5771
+ // if we have sinks, make a correction as if they were included in the softmax
5772
+ if (sk) {
5773
+ max = MAX(max, sk[i02]);
5774
+ }
5775
+
5776
+ wsp_ggml_float sum = wsp_ggml_vec_soft_max_f32(ne00, dp, wp, max);
5777
+ assert(sum > 0.0);
5778
+
5779
+ if (sk) {
5780
+ sum += (wsp_ggml_float) expf(sk[i02] - max);
5781
+ }
5304
5782
 
5305
- sum = 1.0/sum;
5306
- wsp_ggml_vec_scale_f32(nc, dp, sum);
5783
+ sum = 1.0/sum;
5784
+ wsp_ggml_vec_scale_f32(ne00, dp, sum);
5307
5785
 
5308
5786
  #ifndef NDEBUG
5309
- for (int i = 0; i < nc; ++i) {
5310
- assert(!isnan(dp[i]));
5311
- assert(!isinf(dp[i]));
5312
- }
5787
+ for (int i = 0; i < ne00; ++i) {
5788
+ assert(!isnan(dp[i]));
5789
+ assert(!isinf(dp[i]));
5790
+ }
5313
5791
  #endif
5792
+ }
5793
+ }
5314
5794
  }
5315
5795
  }
5316
5796
 
@@ -5534,6 +6014,7 @@ void wsp_ggml_compute_forward_clamp(
5534
6014
  case WSP_GGML_TYPE_Q5_1:
5535
6015
  case WSP_GGML_TYPE_Q8_0:
5536
6016
  case WSP_GGML_TYPE_Q8_1:
6017
+ case WSP_GGML_TYPE_MXFP4:
5537
6018
  case WSP_GGML_TYPE_Q2_K:
5538
6019
  case WSP_GGML_TYPE_Q3_K:
5539
6020
  case WSP_GGML_TYPE_Q4_K:
@@ -7687,12 +8168,14 @@ void wsp_ggml_compute_forward_argsort(
7687
8168
 
7688
8169
  static void wsp_ggml_compute_forward_flash_attn_ext_f16(
7689
8170
  const wsp_ggml_compute_params * params,
7690
- const wsp_ggml_tensor * q,
7691
- const wsp_ggml_tensor * k,
7692
- const wsp_ggml_tensor * v,
7693
- const wsp_ggml_tensor * mask,
7694
8171
  wsp_ggml_tensor * dst) {
7695
8172
 
8173
+ const wsp_ggml_tensor * q = dst->src[0];
8174
+ const wsp_ggml_tensor * k = dst->src[1];
8175
+ const wsp_ggml_tensor * v = dst->src[2];
8176
+ const wsp_ggml_tensor * mask = dst->src[3];
8177
+ const wsp_ggml_tensor * sinks = dst->src[4];
8178
+
7696
8179
  WSP_GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
7697
8180
  WSP_GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
7698
8181
  WSP_GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
@@ -7766,7 +8249,7 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
7766
8249
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
7767
8250
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
7768
8251
 
7769
- wsp_ggml_type const k_vec_dot_type = wsp_ggml_get_type_traits_cpu(k->type)->vec_dot_type;
8252
+ wsp_ggml_type const k_vec_dot_type = wsp_ggml_get_type_traits_cpu(k->type)->vec_dot_type;
7770
8253
  wsp_ggml_from_float_t const q_to_vec_dot = wsp_ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
7771
8254
  wsp_ggml_vec_dot_t const kq_vec_dot = wsp_ggml_get_type_traits_cpu(k->type)->vec_dot;
7772
8255
  wsp_ggml_to_float_t const v_to_float = wsp_ggml_get_type_traits(v->type)->to_float;
@@ -7798,7 +8281,7 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
7798
8281
  memset(VKQ32, 0, DV*sizeof(float));
7799
8282
  }
7800
8283
 
7801
- const wsp_ggml_fp16_t * mp = mask ? (wsp_ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
8284
+ const wsp_ggml_fp16_t * mp = mask ? (wsp_ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
7802
8285
 
7803
8286
  // k indices
7804
8287
  const int ik3 = iq3 / rk3;
@@ -7887,6 +8370,23 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
7887
8370
  }
7888
8371
  }
7889
8372
 
8373
+ // sinks
8374
+ if (sinks) {
8375
+ const float s = ((float *)((char *) sinks->data))[h];
8376
+
8377
+ float ms = 1.0f;
8378
+ float vs = 1.0f;
8379
+
8380
+ if (s > M) {
8381
+ ms = expf(M - s);
8382
+ wsp_ggml_vec_scale_f32(DV, VKQ32, ms);
8383
+ } else {
8384
+ vs = expf(s - M);
8385
+ }
8386
+
8387
+ S = S*ms + vs;
8388
+ }
8389
+
7890
8390
  // V /= S
7891
8391
  const float S_inv = 1.0f/S;
7892
8392
  wsp_ggml_vec_scale_f32(DV, VKQ32, S_inv);
@@ -7906,17 +8406,13 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
7906
8406
 
7907
8407
  void wsp_ggml_compute_forward_flash_attn_ext(
7908
8408
  const wsp_ggml_compute_params * params,
7909
- const wsp_ggml_tensor * q,
7910
- const wsp_ggml_tensor * k,
7911
- const wsp_ggml_tensor * v,
7912
- const wsp_ggml_tensor * mask,
7913
8409
  wsp_ggml_tensor * dst) {
7914
8410
  switch (dst->op_params[3]) {
7915
8411
  case WSP_GGML_PREC_DEFAULT:
7916
8412
  case WSP_GGML_PREC_F32:
7917
8413
  {
7918
8414
  // uses F32 accumulators
7919
- wsp_ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
8415
+ wsp_ggml_compute_forward_flash_attn_ext_f16(params, dst);
7920
8416
  } break;
7921
8417
  default:
7922
8418
  {
@@ -8336,120 +8832,210 @@ void wsp_ggml_compute_forward_ssm_conv(
8336
8832
  static void wsp_ggml_compute_forward_ssm_scan_f32(
8337
8833
  const wsp_ggml_compute_params * params,
8338
8834
  wsp_ggml_tensor * dst) {
8339
- const wsp_ggml_tensor * src0 = dst->src[0]; // s
8340
- const wsp_ggml_tensor * src1 = dst->src[1]; // x
8341
- const wsp_ggml_tensor * src2 = dst->src[2]; // dt
8342
- const wsp_ggml_tensor * src3 = dst->src[3]; // A
8343
- const wsp_ggml_tensor * src4 = dst->src[4]; // B
8344
- const wsp_ggml_tensor * src5 = dst->src[5]; // C
8835
+ const wsp_ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
8836
+ const wsp_ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
8837
+ const wsp_ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
8838
+ const wsp_ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
8839
+ const wsp_ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
8840
+ const wsp_ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
8841
+ const wsp_ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
8345
8842
 
8346
8843
  const int ith = params->ith;
8347
8844
  const int nth = params->nth;
8348
8845
 
8349
- const int64_t nc = src0->ne[0]; // d_state
8350
- const int64_t nr = src0->ne[1]; // d_inner
8351
- const int64_t n_t = src1->ne[1]; // number of tokens per sequence
8352
- const int64_t n_s = src0->ne[2]; // number of sequences in the batch
8846
+ const int64_t nc = src0->ne[0]; // d_state
8847
+ const int64_t nr = src0->ne[1]; // dim
8848
+ const int64_t nh = src1->ne[1]; // n_head
8849
+ const int64_t ng = src4->ne[1];
8850
+ const int64_t nt = src1->ne[2]; // number of tokens per sequence
8851
+ const int64_t ns = src1->ne[3]; // number of sequences in the batch
8353
8852
 
8354
- WSP_GGML_ASSERT(wsp_ggml_nelements(src1) + wsp_ggml_nelements(src0) == wsp_ggml_nelements(dst));
8853
+ // can't use wsp_ggml_nbytes because src1 is not necessarily contiguous
8854
+ const int64_t s_off = wsp_ggml_nelements(src1) * wsp_ggml_element_size(src1);
8855
+
8856
+ WSP_GGML_ASSERT(wsp_ggml_nelements(src1) + nc*nr*nh*ns == wsp_ggml_nelements(dst));
8355
8857
  WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));
8356
8858
  WSP_GGML_ASSERT(src1->nb[0] == sizeof(float));
8357
8859
  WSP_GGML_ASSERT(src2->nb[0] == sizeof(float));
8358
8860
  WSP_GGML_ASSERT(src3->nb[0] == sizeof(float));
8359
8861
  WSP_GGML_ASSERT(src4->nb[0] == sizeof(float));
8360
8862
  WSP_GGML_ASSERT(src5->nb[0] == sizeof(float));
8361
- // required for the dot product between s and C
8362
- WSP_GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
8363
- // required for per-sequence offsets for states
8364
- WSP_GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
8365
- // required to get correct offset for state destination (i.e. src1->nb[3])
8366
- WSP_GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
8863
+ WSP_GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
8864
+ // allows optimizing the modulo since n_group should be a power of 2
8865
+ WSP_GGML_ASSERT((ng & -ng) == ng);
8367
8866
 
8368
- // rows per thread
8369
- const int dr = (nr + nth - 1)/nth;
8867
+ // heads per thread
8868
+ const int dh = (nh + nth - 1)/nth;
8370
8869
 
8371
- // row range for this thread
8372
- const int ir0 = dr*ith;
8373
- const int ir1 = MIN(ir0 + dr, nr);
8374
- const int ir = ir1 - ir0;
8870
+ // head range for this thread
8871
+ const int ih0 = dh*ith;
8872
+ const int ih1 = MIN(ih0 + dh, nh);
8873
+
8874
+ const int32_t * ids = (const int32_t *) src6->data;
8875
+
8876
+ for (int i3 = 0; i3 < ns; ++i3) {
8877
+ const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
8878
+ float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
8879
+
8880
+ for (int i2 = 0; i2 < nt; ++i2) {
8881
+ const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
8882
+ const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
8883
+ const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
8884
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
8885
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
8886
+ float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
8887
+
8888
+ if (src3->ne[0] == 1) {
8889
+ // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
8890
+
8891
+ // n_head
8892
+ for (int h = ih0; h < ih1; ++h) {
8893
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8894
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
8895
+ const float dA = expf(dt_soft_plus * A[h]);
8896
+
8897
+ // dim
8898
+ for (int i1 = 0; i1 < nr; ++i1) {
8899
+ const int ii = i1 + h*nr;
8900
+ const float x_dt = x[ii] * dt_soft_plus;
8901
+ float sumf = 0.0f;
8902
+ #if defined(WSP_GGML_SIMD)
8903
+ #if defined(__ARM_FEATURE_SVE)
8904
+ const int wsp_ggml_f32_epr = svcntw();
8905
+ const int wsp_ggml_f32_step = 1 * wsp_ggml_f32_epr;
8906
+
8907
+ const int np = (nc & ~(wsp_ggml_f32_step - 1));
8908
+
8909
+ WSP_GGML_F32_VEC sum = WSP_GGML_F32_VEC_ZERO;
8910
+
8911
+ WSP_GGML_F32_VEC adA = WSP_GGML_F32_VEC_SET1(dA);
8912
+ WSP_GGML_F32_VEC axdt = WSP_GGML_F32_VEC_SET1(x_dt);
8913
+
8914
+ for (int i = 0; i < np; i += wsp_ggml_f32_step) {
8915
+ // TODO: maybe unroll more?
8916
+ for (int j = 0; j < 1; j++) {
8917
+ WSP_GGML_F32_VEC t0 = WSP_GGML_F32_VEC_LOAD(s0 + i + j*wsp_ggml_f32_epr + ii*nc);
8918
+ WSP_GGML_F32_VEC t1 = WSP_GGML_F32_VEC_LOAD(B + i + j*wsp_ggml_f32_epr + (h & (ng - 1))*nc);
8919
+ WSP_GGML_F32_VEC t2 = WSP_GGML_F32_VEC_LOAD(C + i + j*wsp_ggml_f32_epr + (h & (ng - 1))*nc);
8920
+
8921
+ t0 = WSP_GGML_F32_VEC_MUL(t0, adA);
8922
+ t1 = WSP_GGML_F32_VEC_MUL(t1, axdt);
8923
+
8924
+ t0 = WSP_GGML_F32_VEC_ADD(t0, t1);
8925
+
8926
+ sum = WSP_GGML_F32_VEC_FMA(sum, t0, t2);
8927
+
8928
+ WSP_GGML_F32_VEC_STORE(s + i + j*wsp_ggml_f32_epr + ii*nc, t0);
8929
+ }
8930
+ }
8931
+
8932
+ sumf = WSP_GGML_F32xt_REDUCE_ONE(sum);
8933
+ #else
8934
+ const int np = (nc & ~(WSP_GGML_F32_STEP - 1));
8935
+
8936
+ WSP_GGML_F32_VEC sum[WSP_GGML_F32_ARR] = { WSP_GGML_F32_VEC_ZERO };
8937
+
8938
+ WSP_GGML_F32_VEC adA = WSP_GGML_F32_VEC_SET1(dA);
8939
+ WSP_GGML_F32_VEC axdt = WSP_GGML_F32_VEC_SET1(x_dt);
8940
+
8941
+ WSP_GGML_F32_VEC ax[WSP_GGML_F32_ARR];
8942
+ WSP_GGML_F32_VEC ay[WSP_GGML_F32_ARR];
8943
+ WSP_GGML_F32_VEC az[WSP_GGML_F32_ARR];
8944
+
8945
+ for (int i = 0; i < np; i += WSP_GGML_F32_STEP) {
8946
+ for (int j = 0; j < WSP_GGML_F32_ARR; j++) {
8947
+ ax[j] = WSP_GGML_F32_VEC_LOAD(s0 + i + j*WSP_GGML_F32_EPR + ii*nc);
8948
+ ay[j] = WSP_GGML_F32_VEC_LOAD(B + i + j*WSP_GGML_F32_EPR + (h & (ng - 1))*nc);
8949
+ az[j] = WSP_GGML_F32_VEC_LOAD(C + i + j*WSP_GGML_F32_EPR + (h & (ng - 1))*nc);
8950
+
8951
+ ax[j] = WSP_GGML_F32_VEC_MUL(ax[j], adA);
8952
+ ay[j] = WSP_GGML_F32_VEC_MUL(ay[j], axdt);
8953
+
8954
+ ax[j] = WSP_GGML_F32_VEC_ADD(ax[j], ay[j]);
8955
+
8956
+ sum[j] = WSP_GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
8957
+
8958
+ WSP_GGML_F32_VEC_STORE(s + i + j*WSP_GGML_F32_EPR + ii*nc, ax[j]);
8959
+ }
8960
+ }
8375
8961
 
8376
- #ifdef __ARM_FEATURE_SVE
8377
- for (int i3 = 0; i3 < n_s; ++i3) {
8378
- for (int i2 = 0; i2 < n_t; ++i2) {
8379
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
8380
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
8381
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
8382
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
8383
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
8384
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
8385
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
8386
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
8387
-
8388
- // use the output as the source for the next token-wise iterations
8389
- if (i2 > 0) { s0 = s; }
8390
-
8391
- // d_inner
8392
- for (int i1 = 0; i1 < ir; ++i1) {
8393
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
8394
- float x_dt = x[i1] * dt_soft_plus;
8395
- svfloat32_t vx_dt = WSP_GGML_F32_VEC_SET1(x_dt);
8396
- svfloat32_t vdt_soft_plus = WSP_GGML_F32_VEC_SET1(dt_soft_plus);
8397
- svfloat32_t r1_vector = WSP_GGML_F32_VEC_ZERO;
8398
-
8399
- for (int64_t k = 0; k < nc; k += svcntw()) {
8400
- svfloat32_t vA = WSP_GGML_F32_VEC_LOAD(&A[i1*nc + k]);
8401
- svfloat32_t vB = WSP_GGML_F32_VEC_LOAD(&B[k]);
8402
- svfloat32_t vC = WSP_GGML_F32_VEC_LOAD(&C[k]);
8403
- svfloat32_t vs0 = WSP_GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
8404
-
8405
- svfloat32_t t1 = WSP_GGML_F32_VEC_MUL(vdt_soft_plus, vA);
8406
- t1 = exp_ps_sve(svptrue_b32(), t1);
8407
- svfloat32_t t2 = WSP_GGML_F32_VEC_MUL(vx_dt, vB);
8408
-
8409
- vs0 = WSP_GGML_F32_VEC_FMA(vs0, t1, t2);
8410
- r1_vector = WSP_GGML_F32_VEC_ADD(WSP_GGML_F32_VEC_MUL(vs0, vC), r1_vector);
8411
-
8412
- WSP_GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
8962
+ // reduce sum0..sum3 to sum0
8963
+ WSP_GGML_F32_VEC_REDUCE(sumf, sum);
8964
+ #endif
8965
+ #else
8966
+ const int np = 0;
8967
+ #endif
8968
+ // d_state
8969
+ for (int i0 = np; i0 < nc; ++i0) {
8970
+ const int i = i0 + ii*nc;
8971
+ const int ig = i0 + (h & (ng - 1))*nc;
8972
+ // state = prev_state * dA + dB * x
8973
+ const float state = (s0[i] * dA) + (B[ig] * x_dt);
8974
+ // y = rowwise_dotprod(state, C)
8975
+ sumf += state * C[ig];
8976
+ s[i] = state;
8977
+ }
8978
+ y[ii] = sumf;
8413
8979
  }
8414
- y[i1] = WSP_GGML_F32xt_REDUCE_ONE(r1_vector);
8415
8980
  }
8416
- }
8417
- }
8418
- #else
8419
- for (int i3 = 0; i3 < n_s; ++i3) {
8420
- for (int i2 = 0; i2 < n_t; ++i2) {
8421
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
8422
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
8423
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
8424
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
8425
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
8426
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
8427
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
8428
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
8429
-
8430
- // use the output as the source for the next token-wise iterations
8431
- if (i2 > 0) { s0 = s; }
8432
-
8433
- // d_inner
8434
- for (int i1 = 0; i1 < ir; ++i1) {
8435
- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
8436
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
8437
- float x_dt = x[i1] * dt_soft_plus;
8438
- float sumf = 0.0f;
8439
- // d_state
8440
- for (int i0 = 0; i0 < nc; ++i0) {
8441
- int i = i0 + i1*nc;
8442
- // state = prev_state * dA + dB * x
8443
- float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
8444
- // y = rowwise_dotprod(state, C)
8445
- sumf += state * C[i0];
8446
- s[i] = state;
8981
+ } else {
8982
+ // Mamba-1 has an element-wise decay factor for the states
8983
+
8984
+ // n_head
8985
+ for (int h = ih0; h < ih1; ++h) {
8986
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8987
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
8988
+
8989
+ // dim
8990
+ for (int i1 = 0; i1 < nr; ++i1) {
8991
+ const int ii = i1 + h*nr;
8992
+ const float x_dt = x[ii] * dt_soft_plus;
8993
+ #if defined(__ARM_FEATURE_SVE)
8994
+ svfloat32_t vx_dt = WSP_GGML_F32_VEC_SET1(x_dt);
8995
+ svfloat32_t vdt_soft_plus = WSP_GGML_F32_VEC_SET1(dt_soft_plus);
8996
+ svfloat32_t r1_vector = WSP_GGML_F32_VEC_ZERO;
8997
+
8998
+ // d_state
8999
+ // TODO: what happens when (d_state % svcntw()) != 0?
9000
+ for (int64_t k = 0; k < nc; k += svcntw()) {
9001
+ svfloat32_t vA = WSP_GGML_F32_VEC_LOAD(&A[h*nc + k]);
9002
+ svfloat32_t vB = WSP_GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
9003
+ svfloat32_t vC = WSP_GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
9004
+ svfloat32_t vs0 = WSP_GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
9005
+
9006
+ svfloat32_t t1 = WSP_GGML_F32_VEC_MUL(vdt_soft_plus, vA);
9007
+ t1 = exp_ps_sve(svptrue_b32(), t1);
9008
+ svfloat32_t t2 = WSP_GGML_F32_VEC_MUL(vx_dt, vB);
9009
+
9010
+ vs0 = WSP_GGML_F32_VEC_FMA(t2, vs0, t1);
9011
+ r1_vector = WSP_GGML_F32_VEC_ADD(WSP_GGML_F32_VEC_MUL(vs0, vC), r1_vector);
9012
+
9013
+ WSP_GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
9014
+ }
9015
+ y[ii] = WSP_GGML_F32xt_REDUCE_ONE(r1_vector);
9016
+ #else
9017
+ float sumf = 0.0f;
9018
+ // NOTE: can't really use WSP_GGML_SIMD here because d_state is usually 16
9019
+ // and also because expf is used within the loop.
9020
+ // d_state
9021
+ for (int i0 = 0; i0 < nc; ++i0) {
9022
+ const int i = i0 + ii*nc;
9023
+ const int ig = i0 + (h & (ng - 1))*nc;
9024
+ // state = prev_state * dA + dB * x
9025
+ const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
9026
+ // y = rowwise_dotprod(state, C)
9027
+ sumf += state * C[ig];
9028
+ s[i] = state;
9029
+ }
9030
+ y[ii] = sumf;
9031
+ #endif
8447
9032
  }
8448
- y[i1] = sumf;
8449
9033
  }
8450
9034
  }
9035
+ // use the output as the source when it's not the first token-wise iteration
9036
+ s0 = s;
8451
9037
  }
8452
- #endif
9038
+ }
8453
9039
  }
8454
9040
 
8455
9041
  void wsp_ggml_compute_forward_ssm_scan(
@@ -8688,6 +9274,18 @@ void wsp_ggml_compute_forward_glu(
8688
9274
  {
8689
9275
  wsp_ggml_compute_forward_swiglu(params, dst);
8690
9276
  } break;
9277
+ case WSP_GGML_GLU_OP_SWIGLU_OAI:
9278
+ {
9279
+ wsp_ggml_compute_forward_swiglu_oai(params, dst);
9280
+ } break;
9281
+ case WSP_GGML_GLU_OP_GEGLU_ERF:
9282
+ {
9283
+ wsp_ggml_compute_forward_geglu_erf(params, dst);
9284
+ } break;
9285
+ case WSP_GGML_GLU_OP_GEGLU_QUICK:
9286
+ {
9287
+ wsp_ggml_compute_forward_geglu_quick(params, dst);
9288
+ } break;
8691
9289
  default:
8692
9290
  {
8693
9291
  WSP_GGML_ABORT("fatal error");
@@ -9732,6 +10330,7 @@ static void wsp_ggml_compute_forward_opt_step_adamw_f32(
9732
10330
  const int ir1 = MIN(ir0 + dr, nr);
9733
10331
 
9734
10332
  const float * adamw_params_ptr = wsp_ggml_get_data_f32(adamw_params);
10333
+
9735
10334
  const float alpha = adamw_params_ptr[0];
9736
10335
  const float beta1 = adamw_params_ptr[1];
9737
10336
  const float beta2 = adamw_params_ptr[2];
@@ -9739,7 +10338,7 @@ static void wsp_ggml_compute_forward_opt_step_adamw_f32(
9739
10338
  const float wd = adamw_params_ptr[4];
9740
10339
  const float beta1h = adamw_params_ptr[5];
9741
10340
  const float beta2h = adamw_params_ptr[6];
9742
-
10341
+ const float keep = 1.f - alpha * wd;
9743
10342
  for (int ir = ir0; ir < ir1; ++ir) {
9744
10343
  const int64_t i03 = ir/(ne02*ne01);
9745
10344
  const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -9762,7 +10361,7 @@ static void wsp_ggml_compute_forward_opt_step_adamw_f32(
9762
10361
  // The weight decay is applied independently of the Adam momenta m and v.
9763
10362
  // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
9764
10363
  // See: https://arxiv.org/pdf/1711.05101v3.pdf
9765
- w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
10364
+ w[i00] = w[i00] * keep - alpha * mh / vh;
9766
10365
  }
9767
10366
  }
9768
10367
  }
@@ -9784,3 +10383,63 @@ void wsp_ggml_compute_forward_opt_step_adamw(
9784
10383
  }
9785
10384
  }
9786
10385
  }
10386
+
10387
+ static void wsp_ggml_compute_forward_opt_step_sgd_f32(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
10388
+ const wsp_ggml_tensor * src0 = dst->src[0];
10389
+ const wsp_ggml_tensor * src0_grad = dst->src[1];
10390
+ const wsp_ggml_tensor * sgd_params = dst->src[2];
10391
+
10392
+ WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, src0_grad));
10393
+ WSP_GGML_ASSERT(wsp_ggml_nelements(sgd_params) == 2);
10394
+
10395
+ const int ith = params->ith;
10396
+ const int nth = params->nth;
10397
+
10398
+ const int nr = wsp_ggml_nrows(src0);
10399
+
10400
+ WSP_GGML_TENSOR_UNARY_OP_LOCALS
10401
+ WSP_GGML_ASSERT(nb00 == sizeof(float));
10402
+
10403
+ // rows per thread
10404
+ const int dr = (nr + nth - 1) / nth;
10405
+
10406
+ // row range for this thread
10407
+ const int ir0 = dr * ith;
10408
+ const int ir1 = MIN(ir0 + dr, nr);
10409
+
10410
+ // using adamw param subset we care about - alpha, wd - could have a separate struct
10411
+ const float * sgd_params_ptr = wsp_ggml_get_data_f32(sgd_params);
10412
+ const float alpha = sgd_params_ptr[0];
10413
+ const float keep = 1.f - alpha * sgd_params_ptr[1];
10414
+
10415
+ for (int ir = ir0; ir < ir1; ++ir) {
10416
+ const int64_t i03 = ir / (ne02 * ne01);
10417
+ const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
10418
+ const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
10419
+
10420
+ const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
10421
+
10422
+ float * w = (float *) ((char *) src0->data + offset); // weight
10423
+ const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
10424
+
10425
+ for (int i00 = 0; i00 < ne00; ++i00) {
10426
+ w[i00] = w[i00] * keep - alpha * g[i00];
10427
+ }
10428
+ }
10429
+ }
10430
+
10431
+ void wsp_ggml_compute_forward_opt_step_sgd(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
10432
+ const wsp_ggml_tensor * src0 = dst->src[0];
10433
+
10434
+ switch (src0->type) {
10435
+ case WSP_GGML_TYPE_F32:
10436
+ {
10437
+ wsp_ggml_compute_forward_opt_step_sgd_f32(params, dst);
10438
+ }
10439
+ break;
10440
+ default:
10441
+ {
10442
+ WSP_GGML_ABORT("fatal error - sgd is F32 only");
10443
+ }
10444
+ }
10445
+ }