whisper.rn 0.5.1 → 0.5.3

Files changed (85)
  1. package/android/src/main/jni.cpp +12 -3
  2. package/cpp/ggml-alloc.c +49 -18
  3. package/cpp/ggml-backend-impl.h +0 -3
  4. package/cpp/ggml-backend-reg.cpp +8 -0
  5. package/cpp/ggml-backend.cpp +0 -2
  6. package/cpp/ggml-backend.h +2 -0
  7. package/cpp/ggml-cpu/amx/amx.cpp +1 -0
  8. package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
  9. package/cpp/ggml-cpu/ggml-cpu-impl.h +4 -2
  10. package/cpp/ggml-cpu/ggml-cpu.c +67 -24
  11. package/cpp/ggml-cpu/ops.cpp +489 -364
  12. package/cpp/ggml-cpu/ops.h +4 -4
  13. package/cpp/ggml-cpu/repack.cpp +143 -29
  14. package/cpp/ggml-cpu/simd-mappings.h +25 -25
  15. package/cpp/ggml-cpu/unary-ops.cpp +151 -0
  16. package/cpp/ggml-cpu/unary-ops.h +7 -0
  17. package/cpp/ggml-cpu/vec.cpp +83 -0
  18. package/cpp/ggml-cpu/vec.h +20 -8
  19. package/cpp/ggml-impl.h +67 -2
  20. package/cpp/ggml-metal/ggml-metal-common.cpp +2 -2
  21. package/cpp/ggml-metal/ggml-metal-context.m +5 -6
  22. package/cpp/ggml-metal/ggml-metal-device.cpp +300 -14
  23. package/cpp/ggml-metal/ggml-metal-device.h +26 -1
  24. package/cpp/ggml-metal/ggml-metal-device.m +243 -28
  25. package/cpp/ggml-metal/ggml-metal-impl.h +177 -9
  26. package/cpp/ggml-metal/ggml-metal-ops.cpp +843 -157
  27. package/cpp/ggml-metal/ggml-metal-ops.h +8 -0
  28. package/cpp/ggml-metal/ggml-metal.cpp +8 -3
  29. package/cpp/ggml-metal/ggml-metal.metal +12436 -0
  30. package/cpp/ggml.c +317 -4
  31. package/cpp/ggml.h +139 -0
  32. package/cpp/jsi/RNWhisperJSI.cpp +7 -2
  33. package/cpp/rn-whisper.h +1 -0
  34. package/cpp/whisper.cpp +8 -2
  35. package/ios/RNWhisperContext.mm +3 -1
  36. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  37. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  38. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  39. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
  40. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  44. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  45. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  46. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  47. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
  48. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  49. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  53. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  54. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  55. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  56. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
  57. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  58. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  59. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  61. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  62. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  63. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  64. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
  65. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  66. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  67. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  68. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  70. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  71. package/lib/commonjs/version.json +1 -1
  72. package/lib/module/NativeRNWhisper.js.map +1 -1
  73. package/lib/module/version.json +1 -1
  74. package/lib/typescript/NativeRNWhisper.d.ts +2 -0
  75. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  76. package/package.json +1 -1
  77. package/src/NativeRNWhisper.ts +2 -0
  78. package/src/version.json +1 -1
  79. package/whisper-rn.podspec +1 -1
  80. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  81. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  82. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  83. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  84. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  85. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
@@ -7,8 +7,10 @@
 #include "unary-ops.h"
 #include "vec.h"
 
-#include <float.h>
+#include <cfloat>
 #include <algorithm>
+#include <cmath>
+#include <functional>
 
 // wsp_ggml_compute_forward_dup
 
@@ -1394,6 +1396,56 @@ void wsp_ggml_compute_forward_sum(
     }
 }
 
+// wsp_ggml_compute_forward_cumsum
+
+static void wsp_ggml_compute_forward_cumsum_f32(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));
+    WSP_GGML_ASSERT(dst->nb[0]  == sizeof(float));
+
+    WSP_GGML_TENSOR_UNARY_OP_LOCALS
+
+    WSP_GGML_ASSERT(ne0 == ne00);
+    WSP_GGML_ASSERT(ne1 == ne01);
+    WSP_GGML_ASSERT(ne2 == ne02);
+    WSP_GGML_ASSERT(ne3 == ne03);
+
+    const auto [ir0, ir1] = get_thread_range(params, src0);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+        float * dst_row = (float *) ((char *) dst->data  + i01*nb1  + i02*nb2  + i03*nb3);
+
+        wsp_ggml_vec_cumsum_f32(ne00, dst_row, src_row);
+    }
+}
+
+void wsp_ggml_compute_forward_cumsum(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case WSP_GGML_TYPE_F32:
+            {
+                wsp_ggml_compute_forward_cumsum_f32(params, dst);
+            } break;
+        default:
+            {
+                WSP_GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // wsp_ggml_compute_forward_sum_rows
 
 static void wsp_ggml_compute_forward_sum_rows_f32(
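
Note: the new cumsum op delegates the per-row work to wsp_ggml_vec_cumsum_f32 from vec.h/vec.cpp (both updated in this release). A minimal scalar sketch of the semantics that helper is expected to provide, for reference only (the packaged version may be vectorized):

    // dst[i] = src[0] + src[1] + ... + src[i], computed left to right
    static void cumsum_f32_ref(const int n, float * dst, const float * src) {
        float acc = 0.0f;
        for (int i = 0; i < n; ++i) {
            acc   += src[i];
            dst[i] = acc;
        }
    }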
@@ -2140,6 +2192,83 @@ static void wsp_ggml_compute_forward_gelu(
     }
 }
 
+// wsp_ggml_compute_fill
+
+static void wsp_ggml_compute_forward_fill_f32(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    const float c = wsp_ggml_get_op_params_f32(dst, 0);
+
+    WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
+    WSP_GGML_TENSOR_LOCALS(size_t,  nb, dst, nb);
+
+    const auto [ir0, ir1] = get_thread_range(params, dst);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne2*ne1);
+        const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
+        const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
+
+        float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
+
+        wsp_ggml_vec_set_f32(ne0, dst_ptr, c);
+    }
+}
+
+void wsp_ggml_compute_forward_fill(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    wsp_ggml_compute_forward_fill_f32(params, dst);
+}
+
+// wsp_ggml_compute_tri
+
+static void wsp_ggml_compute_forward_tri_f32(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    const wsp_ggml_tri_type ttype = (wsp_ggml_tri_type) wsp_ggml_get_op_params_i32(dst, 0);
+
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
+
+    WSP_GGML_TENSOR_UNARY_OP_LOCALS
+
+    const auto [ir0, ir1] = get_thread_range(params, src0);
+
+    bool (*bipred)(int, int);
+
+    switch (ttype) {
+        case WSP_GGML_TRI_TYPE_LOWER:      bipred = [](int i, int r) { return i <  r; }; break;
+        case WSP_GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
+        case WSP_GGML_TRI_TYPE_UPPER:      bipred = [](int i, int r) { return i >  r; }; break;
+        case WSP_GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
+        default: WSP_GGML_ABORT("invalid tri type");
+    }
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+              float * dst_ptr = (      float *) ((      char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1);
+
+        for (int i0 = 0; i0 < ne0; ++i0) {
+            dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
+        }
+    }
+}
+
+void wsp_ggml_compute_forward_tri(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+    const wsp_ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case WSP_GGML_TYPE_F32:
+            {
+                wsp_ggml_compute_forward_tri_f32(params, dst);
+            } break;
+        default:
+            {
+                WSP_GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // wsp_ggml_compute_forward_gelu_erf
 
 static void wsp_ggml_compute_forward_gelu_erf_f32(
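
Note: wsp_ggml_compute_forward_tri_f32 keeps element (i01, i0) only when bipred(i0, i01) holds, so the four wsp_ggml_tri_type values carve out the usual triangular masks. For a 4x4 input, the kept (1) and zeroed (0) positions are:

    // LOWER (i <  r)   LOWER_DIAG (i <= r)   UPPER (i >  r)   UPPER_DIAG (i >= r)
    // 0 0 0 0          1 0 0 0               0 1 1 1          1 1 1 1
    // 1 0 0 0          1 1 0 0               0 0 1 1          0 1 1 1
    // 1 1 0 0          1 1 1 0               0 0 0 1          0 0 1 1
    // 1 1 1 0          1 1 1 1               0 0 0 0          0 0 0 1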
@@ -3467,31 +3596,27 @@ static void wsp_ggml_compute_forward_norm_f32(
 
     WSP_GGML_ASSERT(eps >= 0.0f);
 
-    // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
                 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 
-                wsp_ggml_float sum = 0.0;
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    sum += (wsp_ggml_float)x[i00];
-                }
-
+                float sum = 0.0;
+                wsp_ggml_vec_sum_f32(ne00, &sum, x);
                 float mean = sum/ne00;
 
                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+                float variance = 0;
 
-                wsp_ggml_float sum2 = 0.0;
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    float v = x[i00] - mean;
-                    y[i00] = v;
-                    sum2 += (wsp_ggml_float)(v*v);
-                }
+#ifdef WSP_GGML_USE_ACCELERATE
+                mean = -mean;
+                vDSP_vsadd(x, 1, &mean, y, 1, ne00);
+                vDSP_measqv(y, 1, &variance, ne00);
+#else
+                variance = wsp_ggml_vec_cvar_f32(ne00, y, x, mean);
+#endif //WSP_GGML_USE_ACCELERATE
 
-                float variance = sum2/ne00;
                 const float scale = 1.0f/sqrtf(variance + eps);
-
                 wsp_ggml_vec_scale_f32(ne00, y, scale);
             }
         }
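
Note: the norm rewrite replaces the two hand-rolled accumulation loops with wsp_ggml_vec_sum_f32 plus either Accelerate's vDSP (vsadd to center, measqv for the mean square) or the wsp_ggml_vec_cvar_f32 helper added to vec.cpp in this release. A scalar sketch of the semantics wsp_ggml_vec_cvar_f32 is expected to have (reference only):

    // centers x into y and returns the population variance of the row
    static float vec_cvar_f32_ref(const int n, float * y, const float * x, const float mean) {
        float sum2 = 0.0f;
        for (int i = 0; i < n; ++i) {
            const float v = x[i] - mean;
            y[i]  = v;       // centered values are reused by the caller
            sum2 += v*v;
        }
        return sum2/n;
    }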
@@ -4459,46 +4584,6 @@ void wsp_ggml_compute_forward_cont(
     wsp_ggml_compute_forward_dup(params, dst);
 }
 
-// wsp_ggml_compute_forward_reshape
-
-void wsp_ggml_compute_forward_reshape(
-        const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst) {
-    // NOP
-    WSP_GGML_UNUSED(params);
-    WSP_GGML_UNUSED(dst);
-}
-
-// wsp_ggml_compute_forward_view
-
-void wsp_ggml_compute_forward_view(
-        const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst) {
-    // NOP
-    WSP_GGML_UNUSED(params);
-    WSP_GGML_UNUSED(dst);
-}
-
-// wsp_ggml_compute_forward_permute
-
-void wsp_ggml_compute_forward_permute(
-        const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst) {
-    // NOP
-    WSP_GGML_UNUSED(params);
-    WSP_GGML_UNUSED(dst);
-}
-
-// wsp_ggml_compute_forward_transpose
-
-void wsp_ggml_compute_forward_transpose(
-        const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst) {
-    // NOP
-    WSP_GGML_UNUSED(params);
-    WSP_GGML_UNUSED(dst);
-}
-
 // wsp_ggml_compute_forward_get_rows
 
 static void wsp_ggml_compute_forward_get_rows_q(
@@ -5478,7 +5563,7 @@ static void wsp_ggml_rope_cache_init(
 }
 
 static void wsp_ggml_mrope_cache_init(
-        float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
+        float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
         float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
         float * cache, float sin_sign, float theta_scale) {
     // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
@@ -5513,14 +5598,26 @@ static void wsp_ggml_mrope_cache_init(
         }
 
         float theta = theta_t;
-        if (sector >= sections[0] && sector < sec_w) {
-            theta = theta_h;
-        }
-        else if (sector >= sec_w && sector < sec_w + sections[2]) {
-            theta = theta_w;
-        }
-        else if (sector >= sec_w + sections[2]) {
-            theta = theta_e;
+        if (is_imrope) { // qwen3vl apply interleaved mrope
+            if (sector % 3 == 1 && sector < 3 * sections[1]) {
+                theta = theta_h;
+            } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
+                theta = theta_w;
+            } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
+                theta = theta_t;
+            } else {
+                theta = theta_e;
+            }
+        } else {
+            if (sector >= sections[0] && sector < sec_w) {
+                theta = theta_h;
+            }
+            else if (sector >= sec_w && sector < sec_w + sections[2]) {
+                theta = theta_w;
+            }
+            else if (sector >= sec_w + sections[2]) {
+                theta = theta_e;
+            }
         }
 
         rope_yarn(
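
Note: with is_imrope set (Qwen3-VL), the t/h/w/e position components are interleaved across sectors instead of occupying contiguous blocks. An illustration, assuming sections = {2, 2, 2, 0}:

    // sector:    0  1  2  3  4  5  6  7 ...
    // standard:  t  t  h  h  w  w  e  e ...   (contiguous t/h/w blocks, then e)
    // imrope:    t  h  w  t  h  w  e  e ...   (t/h/w cycle with period 3, then e)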
@@ -5535,193 +5632,28 @@ static void wsp_ggml_mrope_cache_init(
     }
 }
 
-static void wsp_ggml_compute_forward_rope_f32(
-        const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst,
-        const bool forward) {
-
-    const wsp_ggml_tensor * src0 = dst->src[0];
-    const wsp_ggml_tensor * src1 = dst->src[1];
-    const wsp_ggml_tensor * src2 = dst->src[2];
-
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-    int sections[4];
-
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode = ((int32_t *) dst->op_params)[2];
-    //const int n_ctx = ((int32_t *) dst->op_params)[3];
-    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
-    memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
-    memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
-    memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
-    memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
-
-    WSP_GGML_TENSOR_UNARY_OP_LOCALS
-
-    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
-    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
-
-    WSP_GGML_ASSERT(nb00 == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = wsp_ggml_nrows(dst);
-
-    WSP_GGML_ASSERT(n_dims <= ne0);
-    WSP_GGML_ASSERT(n_dims % 2 == 0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    // row index used to determine which thread to use
-    int ir = 0;
-
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
-    float corr_dims[2];
-    wsp_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
-    const bool is_neox = mode & WSP_GGML_ROPE_TYPE_NEOX;
-    const bool is_mrope = mode & WSP_GGML_ROPE_TYPE_MROPE; // wsp_ggml_rope_multi, multimodal rotary position embedding
-    const bool is_vision = mode == WSP_GGML_ROPE_TYPE_VISION;
-
-    if (is_mrope) {
-        WSP_GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
-    }
-
-    if (is_vision) {
-        WSP_GGML_ASSERT(n_dims == ne0/2);
-    }
-
-    const float * freq_factors = NULL;
-    if (src2 != NULL) {
-        WSP_GGML_ASSERT(src2->type == WSP_GGML_TYPE_F32);
-        WSP_GGML_ASSERT(src2->ne[0] >= n_dims / 2);
-        freq_factors = (const float *) src2->data;
-    }
-
-    // backward process uses inverse rotation by cos and sin.
-    // cos and sin build a rotation matrix, where the inverse is the transpose.
-    // this essentially just switches the sign of sin.
-    const float sin_sign = forward ? 1.0f : -1.0f;
-
-    const int32_t * pos = (const int32_t *) src1->data;
-
-    for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
-        for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
-
-            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            if (!is_mrope) {
-                const int64_t p = pos[i2];
-                wsp_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
-            else {
-                const int64_t p_t = pos[i2];
-                const int64_t p_h = pos[i2 + ne2];
-                const int64_t p_w = pos[i2 + ne2 * 2];
-                const int64_t p_e = pos[i2 + ne2 * 3];
-                wsp_ggml_mrope_cache_init(
-                    p_t, p_h, p_w, p_e, sections, is_vision,
-                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
-
-            for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
-                if (ir++ < ir0) continue;
-                if (ir > ir1) break;
-
-                if (is_neox || is_mrope) {
-                    if (is_vision){
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
-
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-                            const float x0 = src[0];
-                            const float x1 = src[n_dims];
-
-                            dst_data[0] = x0*cos_theta - x1*sin_theta;
-                            dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
-                        }
-                    } else {
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
+template<typename T>
+static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
+    for (int64_t i0 = 0; i0 < n; i0 += 2) {
+        const int64_t ic = i0/scale; // hack for WSP_GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
 
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
+        const float cos_theta = cache[i0 + 0];
+        const float sin_theta = cache[i0 + 1];
 
-                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+        const T * const src = src_data + ic;
+        T * dst = dst_data + ic;
 
-                            const float x0 = src[0];
-                            const float x1 = src[n_dims/2];
+        const float x0 = type_conversion_table<T>::to_f32(src[0]);
+        const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
 
-                            dst_data[0] = x0*cos_theta - x1*sin_theta;
-                            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
-                        }
-                    }
-                } else {
-                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                        float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-                        const float x0 = src[0];
-                        const float x1 = src[1];
-
-                        dst_data[0] = x0*cos_theta - x1*sin_theta;
-                        dst_data[1] = x0*sin_theta + x1*cos_theta;
-                    }
-                }
-
-                if (is_vision) {
-                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                        const int64_t ic = i0/2;
-
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                        float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-                        const float x0 = src[0];
-                        const float x1 = src[n_dims];
-
-                        dst_data[0] = x0*cos_theta - x1*sin_theta;
-                        dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
-                    }
-                } else {
-                    // fill the remain channels with data from src tensor
-                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                        float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-                        dst_data[0] = src[0];
-                        dst_data[1] = src[1];
-                    }
-                }
-            }
-        }
-    }
+        dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
+        dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
+    }
 }
 
-// TODO: deduplicate f16/f32 code
-static void wsp_ggml_compute_forward_rope_f16(
+template<typename T> //float or wsp_ggml_fp16_t
+static void wsp_ggml_compute_forward_rope_flt(
         const wsp_ggml_compute_params * params,
         wsp_ggml_tensor * dst,
         const bool forward) {
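
Note: rotate_pairs collapses the previous per-mode loops into one rotation kernel; each pair (x0, x1) is rotated by the cached angle via dst[0] = x0*cos - x1*sin and dst[n_offset] = x0*sin + x1*cos. The (n, n_offset, scale) arguments reproduce the three pairing layouts that used to be separate branches:

    // NORMAL:            adjacent pairs   (x[i0], x[i0 + 1])        -> rotate_pairs<T>(n_dims, 1,        ..., /*scale=*/1)
    // NEOX/MROPE/IMROPE: half-split pairs (x[ic], x[ic + n_dims/2]) -> rotate_pairs<T>(n_dims, n_dims/2, ...)
    // VISION:            full-split pairs (x[ic], x[ic + n_dims]),
    //                    applied over all ne0 channels              -> rotate_pairs<T>(ne0,    n_dims,   ...)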
@@ -5730,6 +5662,9 @@ static void wsp_ggml_compute_forward_rope_f16(
     const wsp_ggml_tensor * src1 = dst->src[1];
     const wsp_ggml_tensor * src2 = dst->src[2];
 
+    WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32 || src0->type == WSP_GGML_TYPE_F16);
+    WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_I32);
+
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
     int sections[4];
 
@@ -5738,6 +5673,7 @@ static void wsp_ggml_compute_forward_rope_f16(
     const int mode = ((int32_t *) dst->op_params)[2];
     //const int n_ctx = ((int32_t *) dst->op_params)[3];
     const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
     memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
     memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
     memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -5746,13 +5682,13 @@ static void wsp_ggml_compute_forward_rope_f16(
     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
     memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
 
-
     WSP_GGML_TENSOR_UNARY_OP_LOCALS
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
 
-    WSP_GGML_ASSERT(nb0 == sizeof(wsp_ggml_fp16_t));
+    WSP_GGML_ASSERT(nb0 == nb00);
+    WSP_GGML_ASSERT(nb0 == sizeof(T));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -5777,11 +5713,11 @@ static void wsp_ggml_compute_forward_rope_f16(
     float corr_dims[2];
     wsp_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
-    const bool is_neox = mode & WSP_GGML_ROPE_TYPE_NEOX;
-    const bool is_mrope = mode & WSP_GGML_ROPE_TYPE_MROPE;
+    const bool is_imrope = mode == WSP_GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
+    const bool mrope_used = mode & WSP_GGML_ROPE_TYPE_MROPE; // wsp_ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
     const bool is_vision = mode == WSP_GGML_ROPE_TYPE_VISION;
 
-    if (is_mrope) {
+    if (mrope_used) {
         WSP_GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
     }
 
@@ -5803,11 +5739,11 @@ static void wsp_ggml_compute_forward_rope_f16(
 
     const int32_t * pos = (const int32_t *) src1->data;
 
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = 0; i2 < ne2; i2++) {
+    for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
+        for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
 
             float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            if (!is_mrope) {
+            if (!mrope_used) {
                 const int64_t p = pos[i2];
                 wsp_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
             }
@@ -5817,90 +5753,44 @@ static void wsp_ggml_compute_forward_rope_f16(
                 const int64_t p_w = pos[i2 + ne2 * 2];
                 const int64_t p_e = pos[i2 + ne2 * 3];
                 wsp_ggml_mrope_cache_init(
-                    p_t, p_h, p_w, p_e, sections, is_vision,
+                    p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
                     freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
             }
 
-            for (int64_t i1 = 0; i1 < ne1; i1++) {
+            for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
                 if (ir++ < ir0) continue;
                 if (ir > ir1) break;
 
-                if (is_neox || is_mrope) {
-                    if (is_vision) {
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
-
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-                            const float x0 = WSP_GGML_CPU_FP16_TO_FP32(src[0]);
-                            const float x1 = WSP_GGML_CPU_FP16_TO_FP32(src[n_dims]);
-
-                            dst_data[0] = WSP_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                            dst_data[n_dims] = WSP_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                        }
-                    } else {
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
-
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-                            const float x0 = WSP_GGML_CPU_FP16_TO_FP32(src[0]);
-                            const float x1 = WSP_GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
-
-                            dst_data[0] = WSP_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                            dst_data[n_dims/2] = WSP_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                        }
-                    }
-                } else {
-                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                        wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-                        const float x0 = WSP_GGML_CPU_FP16_TO_FP32(src[0]);
-                        const float x1 = WSP_GGML_CPU_FP16_TO_FP32(src[1]);
-
-                        dst_data[0] = WSP_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[1] = WSP_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    }
+                T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+
+                switch (mode) {
+                    case WSP_GGML_ROPE_TYPE_NORMAL:
+                        rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
+                        break;
+                    case WSP_GGML_ROPE_TYPE_NEOX:
+                    case WSP_GGML_ROPE_TYPE_MROPE:
+                    case WSP_GGML_ROPE_TYPE_IMROPE:
+                        rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
+                        break;
+                    case WSP_GGML_ROPE_TYPE_VISION:
+                        rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
+                        break;
+                    default:
+                        WSP_GGML_ABORT("rope type not supported");
                 }
 
-                if (is_vision) {
-                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                        const int64_t ic = i0/2;
-
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                        wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-                        const float x0 = WSP_GGML_CPU_FP16_TO_FP32(src[0]);
-                        const float x1 = WSP_GGML_CPU_FP16_TO_FP32(src[n_dims]);
-
-                        dst_data[0] = WSP_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[n_dims] = WSP_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    }
-                } else {
+                if (!is_vision) {
+                    // fill the remain channels with data from src tensor
                     for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                        const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                        wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                        T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
                         dst_data[0] = src[0];
                         dst_data[1] = src[1];
                     }
                 }
-            }
+            } //attn-heads
         }
     }
 }
@@ -5914,11 +5804,11 @@ void wsp_ggml_compute_forward_rope(
     switch (src0->type) {
        case WSP_GGML_TYPE_F16:
            {
-                wsp_ggml_compute_forward_rope_f16(params, dst, true);
+                wsp_ggml_compute_forward_rope_flt<wsp_ggml_fp16_t>(params, dst, true);
            } break;
        case WSP_GGML_TYPE_F32:
            {
-                wsp_ggml_compute_forward_rope_f32(params, dst, true);
+                wsp_ggml_compute_forward_rope_flt<float>(params, dst, true);
            } break;
        default:
            {
@@ -5938,11 +5828,11 @@ void wsp_ggml_compute_forward_rope_back(
     switch (src0->type) {
        case WSP_GGML_TYPE_F16:
            {
-                wsp_ggml_compute_forward_rope_f16(params, dst, false);
+                wsp_ggml_compute_forward_rope_flt<wsp_ggml_fp16_t>(params, dst, false);
            } break;
        case WSP_GGML_TYPE_F32:
            {
-                wsp_ggml_compute_forward_rope_f32(params, dst, false);
+                wsp_ggml_compute_forward_rope_flt<float>(params, dst, false);
            } break;
        default:
            {
@@ -7074,7 +6964,11 @@ static void wsp_ggml_compute_forward_conv_2d_dw_cwhn(
     const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
 
 #ifdef WSP_GGML_SIMD
-    const int64_t pkg_size = WSP_GGML_F32_EPR;
+    #if defined(__ARM_FEATURE_SVE)
+        const int64_t pkg_size = svcntw();
+    #else
+        const int64_t pkg_size = WSP_GGML_F32_EPR;
+    #endif
     const int64_t pkg_count = c / pkg_size;
     const int64_t c_pkg_end = pkg_count * pkg_size;
 #else
@@ -7497,10 +7391,17 @@ static void wsp_ggml_compute_forward_upscale_f32(
     float sf1 = (float)ne1/src0->ne[1];
     float sf2 = (float)ne2/src0->ne[2];
     float sf3 = (float)ne3/src0->ne[3];
+    float pixel_offset = 0.5f;
 
     const int32_t mode_flags = wsp_ggml_get_op_params_i32(dst, 0);
     const wsp_ggml_scale_mode mode = (wsp_ggml_scale_mode) (mode_flags & 0xFF);
 
+    if (mode_flags & WSP_GGML_SCALE_FLAG_ALIGN_CORNERS) {
+        pixel_offset = 0.0f;
+        sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
+        sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
+    }
+
     if (mode == WSP_GGML_SCALE_MODE_NEAREST) {
         for (int64_t i3 = 0; i3 < ne3; i3++) {
             const int64_t i03 = i3 / sf3;
@@ -7520,13 +7421,6 @@ static void wsp_ggml_compute_forward_upscale_f32(
             }
         }
     } else if (mode == WSP_GGML_SCALE_MODE_BILINEAR) {
-        float pixel_offset = 0.5f;
-        if (mode_flags & WSP_GGML_SCALE_FLAG_ALIGN_CORNERS) {
-            pixel_offset = 0.0f;
-            sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
-            sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
-        }
-
         for (int64_t i3 = 0; i3 < ne3; i3++) {
             const int64_t i03 = i3 / sf3;
             for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
@@ -7561,6 +7455,51 @@ static void wsp_ggml_compute_forward_upscale_f32(
 
                     const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
 
+                    float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+                    *y_dst = val;
+                }
+            }
+        }
+    }
+    } else if (mode == WSP_GGML_SCALE_MODE_BICUBIC) {
+        // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
+        const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
+        auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
+        auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
+        auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
+            const float w0 = weight2(x + 1);
+            const float w1 = weight1(x + 0);
+            const float w2 = weight1(1 - x);
+            const float w3 = weight2(2 - x);
+            return p0*w0 + p1*w1 + p2*w2 + p3*w3;
+        };
+
+        for (int64_t i3 = 0; i3 < ne3; i3++) {
+            const int64_t i03 = i3 / sf3;
+            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+                const int64_t i02 = i2 / sf2;
+                for (int64_t i1 = 0; i1 < ne1; i1++) {
+                    const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
+                    const int64_t y0 = (int64_t)floorf(y);
+                    const float dy = y - (float)y0;
+
+                    for (int64_t i0 = 0; i0 < ne0; i0++) {
+                        const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
+                        const int64_t x0 = (int64_t)floorf(x);
+                        const float dx = x - (float)x0;
+
+                        auto p = [=](int64_t x_off, int64_t y_off) -> float {
+                            int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
+                            int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
+                            return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                        };
+
+                        const float val = bicubic(
+                            bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
+                            bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
+                            bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
+                            bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
+
                         float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
                         *y_dst = val;
                     }
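
Note: weight1/weight2 are the two piecewise branches of the cubic convolution kernel with alpha = -0.75 (PyTorch's choice), evaluated at the four taps surrounding the sample point. As a quick sanity check of the lambdas above:

    // dx = 0.0  ->  weights { 0, 1, 0, 0 }                            (integer samples pass through exactly)
    // dx = 0.5  ->  weights { -0.09375, 0.59375, 0.59375, -0.09375 }  (sum is 1, as required)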
@@ -7854,6 +7793,18 @@ void wsp_ggml_compute_forward_timestep_embedding(
 
 // wsp_ggml_compute_forward_argsort
 
+template<enum wsp_ggml_sort_order order>
+struct argsort_cmp {
+    const float * data;
+    bool operator()(int32_t a, int32_t b) const {
+        if constexpr (order == WSP_GGML_SORT_ORDER_ASC) {
+            return data[a] < data[b];
+        } else {
+            return data[a] > data[b];
+        }
+    }
+};
+
 static void wsp_ggml_compute_forward_argsort_f32(
         const wsp_ggml_compute_params * params,
         wsp_ggml_tensor * dst) {
@@ -7872,23 +7823,25 @@ static void wsp_ggml_compute_forward_argsort_f32(
     wsp_ggml_sort_order order = (wsp_ggml_sort_order) wsp_ggml_get_op_params_i32(dst, 0);
 
     for (int64_t i = ith; i < nr; i += nth) {
-        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
         const float * src_data = (float *)((char *) src0->data + i*nb01);
 
+        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+
         for (int64_t j = 0; j < ne0; j++) {
             dst_data[j] = j;
         }
 
-        // C doesn't have a functional sort, so we do a bubble sort instead
-        for (int64_t j = 0; j < ne0; j++) {
-            for (int64_t k = j + 1; k < ne0; k++) {
-                if ((order == WSP_GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
-                    (order == WSP_GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
-                    int32_t tmp = dst_data[j];
-                    dst_data[j] = dst_data[k];
-                    dst_data[k] = tmp;
-                }
-            }
+        switch (order) {
+            case WSP_GGML_SORT_ORDER_ASC:
+                std::sort(dst_data, dst_data + ne0, argsort_cmp<WSP_GGML_SORT_ORDER_ASC>{src_data});
+                break;
+
+            case WSP_GGML_SORT_ORDER_DESC:
+                std::sort(dst_data, dst_data + ne0, argsort_cmp<WSP_GGML_SORT_ORDER_DESC>{src_data});
+                break;
+
+            default:
+                WSP_GGML_ABORT("invalid sort order");
         }
     }
 }
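
Note: the O(n^2) bubble sort is replaced by std::sort over an index array, with the comparison direction fixed at compile time through the templated argsort_cmp (hence the new <algorithm>/<functional> includes at the top of the file). A standalone sketch of the same indirect-sort pattern:

    #include <algorithm>
    #include <cstdint>

    // fills idx with 0..n-1, then orders it so that
    // values[idx[0]] <= values[idx[1]] <= ...
    static void argsort_asc_ref(const float * values, int32_t * idx, int64_t n) {
        for (int64_t j = 0; j < n; ++j) {
            idx[j] = (int32_t) j;
        }
        std::sort(idx, idx + n, [values](int32_t a, int32_t b) {
            return values[a] < values[b];   // compare through the indirection
        });
    }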
@@ -7913,10 +7866,10 @@ void wsp_ggml_compute_forward_argsort(
 
 // wsp_ggml_compute_forward_flash_attn_ext
 
-static void wsp_ggml_compute_forward_flash_attn_ext_f16(
+static void wsp_ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         const wsp_ggml_compute_params * params,
-        wsp_ggml_tensor * dst) {
-
+        wsp_ggml_tensor * dst,
+        int ir0, int ir1) {
     const wsp_ggml_tensor * q = dst->src[0];
     const wsp_ggml_tensor * k = dst->src[1];
     const wsp_ggml_tensor * v = dst->src[2];
@@ -7932,9 +7885,6 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
     WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
     WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
 
-    const int ith = params->ith;
-    const int nth = params->nth;
-
     const int64_t DK = nek0;
     const int64_t DV = nev0;
     const int64_t N = neq1;
@@ -7968,16 +7918,6 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
 
     // parallelize by q rows using wsp_ggml_vec_dot_f32
 
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
     float scale = 1.0f;
     float max_bias = 0.0f;
     float logit_softcap = 0.0f;
@@ -8004,6 +7944,8 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
     WSP_GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
     WSP_GGML_ASSERT((v->type == WSP_GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
 
+    int ith = params->ith;
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
@@ -8135,7 +8077,7 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
         }
 
         // V /= S
-        const float S_inv = 1.0f/S;
+        const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
         wsp_ggml_vec_scale_f32(DV, VKQ32, S_inv);
 
         // dst indices
@@ -8151,6 +8093,91 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
     }
 }
 
+static void wsp_ggml_compute_forward_flash_attn_ext_f16(
+        const wsp_ggml_compute_params * params,
+        wsp_ggml_tensor * dst) {
+
+    const wsp_ggml_tensor * q = dst->src[0];
+    const wsp_ggml_tensor * k = dst->src[1];
+    const wsp_ggml_tensor * v = dst->src[2];
+
+    WSP_GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
+    WSP_GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
+    WSP_GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
+    WSP_GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
+    WSP_GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
+    WSP_GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
+    WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+    WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N = neq1;
+
+    WSP_GGML_ASSERT(ne0 == DV);
+    WSP_GGML_ASSERT(ne2 == N);
+
+    // input tensor rows must be contiguous
+    WSP_GGML_ASSERT(nbq0 == wsp_ggml_type_size(q->type));
+    WSP_GGML_ASSERT(nbk0 == wsp_ggml_type_size(k->type));
+    WSP_GGML_ASSERT(nbv0 == wsp_ggml_type_size(v->type));
+
+    WSP_GGML_ASSERT(neq0 == DK);
+    WSP_GGML_ASSERT(nek0 == DK);
+    WSP_GGML_ASSERT(nev0 == DV);
+
+    WSP_GGML_ASSERT(neq1 == N);
+
+    // dst cannot be transposed or permuted
+    WSP_GGML_ASSERT(nb0 == sizeof(float));
+    WSP_GGML_ASSERT(nb0 <= nb1);
+    WSP_GGML_ASSERT(nb1 <= nb2);
+    WSP_GGML_ASSERT(nb2 <= nb3);
+
+    // parallelize by q rows using wsp_ggml_vec_dot_f32
+
+    // total rows in q
+    const int64_t nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // disable for NUMA
+    const bool disable_chunking = wsp_ggml_is_numa();
+
+    // 4x chunks per thread
+    int nth_scaled = nth * 4;
+    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+    int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
+
+    if (nth == 1 || nchunk < nth || disable_chunking) {
+        nchunk = nth;
+    }
+
+    if (ith == 0) {
+        // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
+        wsp_ggml_threadpool_chunk_set(params->threadpool, nth);
+    }
+
+    wsp_ggml_barrier(params->threadpool);
+
+    // The number of elements in each chunk
+    const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+    // The first chunk comes from our thread_id, the rest will get auto-assigned.
+    int current_chunk = ith;
+
+    while (current_chunk < nchunk) {
+        const int64_t ir0 = dr * current_chunk;
+        const int64_t ir1 = MIN(ir0 + dr, nr);
+
+        wsp_ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+
+        current_chunk = wsp_ggml_threadpool_chunk_add(params->threadpool, 1);
+    }
+}
+
 void wsp_ggml_compute_forward_flash_attn_ext(
     const wsp_ggml_compute_params * params,
     wsp_ggml_tensor * dst) {
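
Note: the flash-attention split replaces the static rows-per-thread partition with self-scheduled chunks: roughly 4x more chunks than threads, each thread seeding itself with chunk ith, and wsp_ggml_threadpool_chunk_add handing out the rest atomically so fast threads pick up the slack. A minimal sketch of that scheduling loop, with a plain std::atomic standing in for the threadpool counter (the helper names below are stand-ins, not the package API):

    #include <algorithm>
    #include <atomic>
    #include <cstdint>

    // next_chunk must be pre-set to the thread count before threads enter,
    // mirroring the wsp_ggml_threadpool_chunk_set(tp, nth) done by thread 0
    static void process_rows_chunked(std::atomic<int64_t> & next_chunk,
                                     int ith, int64_t nchunk, int64_t nr, int64_t dr,
                                     void (*process_chunk)(int64_t ir0, int64_t ir1)) {
        int64_t current = ith;                         // first chunk == own thread id
        while (current < nchunk) {
            const int64_t ir0 = dr * current;
            const int64_t ir1 = std::min(ir0 + dr, nr);
            process_chunk(ir0, ir1);
            current = next_chunk.fetch_add(1);         // claim the next unprocessed chunk
        }
    }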
@@ -8637,7 +8664,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
         // n_head
         for (int h = ih0; h < ih1; ++h) {
             // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-            const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+            const float dt_soft_plus = wsp_ggml_compute_softplus_f32(dt[h]);
             const float dA = expf(dt_soft_plus * A[h]);
             const int g = h / (nh / ng); // repeat_interleave
 
@@ -8734,7 +8761,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
         // n_head
        for (int h = ih0; h < ih1; ++h) {
            // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-            const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+            const float dt_soft_plus = wsp_ggml_compute_softplus_f32(dt[h]);
            const int g = h / (nh / ng); // repeat_interleave
 
            // dim
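
Note: both ssm_scan call sites now share one softplus helper; given the ggml-impl.h changes in this release it presumably keeps the same overflow guard the inline expression had. A sketch of those semantics:

    #include <math.h>

    // numerically-stable softplus: for large x, log1pf(expf(x)) ~= x
    // and expf(x) would overflow, so pass x through unchanged
    static inline float softplus_f32_ref(float x) {
        return x <= 20.0f ? log1pf(expf(x)) : x;
    }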
@@ -8997,6 +9024,34 @@ void wsp_ggml_compute_forward_unary(
            {
                wsp_ggml_compute_forward_exp(params, dst);
            } break;
+        case WSP_GGML_UNARY_OP_FLOOR:
+            {
+                wsp_ggml_compute_forward_floor(params, dst);
+            } break;
+        case WSP_GGML_UNARY_OP_CEIL:
+            {
+                wsp_ggml_compute_forward_ceil(params, dst);
+            } break;
+        case WSP_GGML_UNARY_OP_ROUND:
+            {
+                wsp_ggml_compute_forward_round(params, dst);
+            } break;
+        case WSP_GGML_UNARY_OP_TRUNC:
+            {
+                wsp_ggml_compute_forward_trunc(params, dst);
+            } break;
+        case WSP_GGML_UNARY_OP_XIELU:
+            {
+                wsp_ggml_compute_forward_xielu(params, dst);
+            } break;
+        case WSP_GGML_UNARY_OP_EXPM1:
+            {
+                wsp_ggml_compute_forward_expm1(params, dst);
+            } break;
+        case WSP_GGML_UNARY_OP_SOFTPLUS:
+            {
+                wsp_ggml_compute_forward_softplus(params, dst);
+            } break;
        default:
            {
                WSP_GGML_ABORT("fatal error");
@@ -9593,6 +9648,76 @@ void wsp_ggml_compute_forward_gla(
     }
 }
 
+static void wsp_ggml_compute_forward_solve_tri_f32(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst) {
+    const struct wsp_ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
+    const struct wsp_ggml_tensor * src1 = dst->src[1]; // B (RHS)
+
+    WSP_GGML_TENSOR_BINARY_OP_LOCALS;
+
+    WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(dst->type  == WSP_GGML_TYPE_F32);
+
+    WSP_GGML_ASSERT(ne00 == ne01); // A must be square
+    WSP_GGML_ASSERT(ne0 == ne10);  // solution cols == B cols
+    WSP_GGML_ASSERT(ne1 == ne11);  // solution rows == B rows
+
+    WSP_GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
+    WSP_GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t k = ne10; // number of RHS columns
+    const int64_t n = ne11; // A is n×n
+    const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
+
+    // chunks per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // chunk range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    const float * A = (const float *) src0->data; // [n, n, B1, B2]
+    const float * B = (const float *) src1->data; // [n, k, B1, B2]
+    float       * X = (      float *) dst->data;  // [n, k, B1, B2]
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*k);
+        const int64_t i02 = (ir - i03*ne02*k)/k;
+        const int64_t i01 = (ir - i03*ne02*k - i02*k);
+
+        const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
+        const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
+
+        float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
+
+        for (int64_t i00 = 0; i00 < n; ++i00) {
+            float sum = 0.0f;
+            for (int64_t t = 0; t < i00; ++t) {
+                sum += A_batch[i00 * n + t] * X_batch[i01 * n + t];
+            }
+
+            const float diag = A_batch[i00 * n + i00];
+            WSP_GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
+
+            X_batch[i01 * n + i00] = (B_batch[i00 * k + i01] - sum) / diag;
+        }
+    }
+}
+
+void wsp_ggml_compute_forward_solve_tri(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst) {
+    const wsp_ggml_tensor * src0 = dst->src[0];
+    const wsp_ggml_tensor * src1 = dst->src[1];
+
+    if (src0->type == WSP_GGML_TYPE_F32 && src1->type == WSP_GGML_TYPE_F32) {
+        wsp_ggml_compute_forward_solve_tri_f32(params, dst);
+    } else {
+        WSP_GGML_ABORT("fatal error");
+    }
+}
+
 // wsp_ggml_compute_forward_rwkv_wkv7
 
 static void wsp_ggml_compute_forward_rwkv_wkv7_f32(
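
Note: the new solve_tri op is batched forward substitution for a lower-triangular A: for each right-hand-side column it computes x_i = (b_i - sum_{t<i} A_it * x_t) / A_ii, row by row. A dense single-system sketch of the same recurrence:

    // solves L * x = b for one column, L lower triangular and n x n,
    // assuming a nonzero diagonal (asserted in the op above)
    static void solve_lower_tri_ref(const float * L, const float * b, float * x, const int n) {
        for (int i = 0; i < n; ++i) {
            float sum = 0.0f;
            for (int t = 0; t < i; ++t) {
                sum += L[i*n + t] * x[t];
            }
            x[i] = (b[i] - sum) / L[i*n + i];
        }
    }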