whisper.rn 0.5.2 → 0.5.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. package/README.md +1 -1
  2. package/cpp/ggml-alloc.c +11 -4
  3. package/cpp/ggml-backend-reg.cpp +8 -0
  4. package/cpp/ggml-backend.cpp +0 -2
  5. package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
  6. package/cpp/ggml-cpu/ggml-cpu-impl.h +3 -1
  7. package/cpp/ggml-cpu/ggml-cpu.c +50 -21
  8. package/cpp/ggml-cpu/ops.cpp +458 -349
  9. package/cpp/ggml-cpu/ops.h +4 -4
  10. package/cpp/ggml-cpu/repack.cpp +143 -29
  11. package/cpp/ggml-cpu/simd-mappings.h +25 -25
  12. package/cpp/ggml-cpu/unary-ops.cpp +16 -0
  13. package/cpp/ggml-cpu/unary-ops.h +2 -0
  14. package/cpp/ggml-cpu/vec.cpp +17 -0
  15. package/cpp/ggml-cpu/vec.h +10 -0
  16. package/cpp/ggml-impl.h +17 -1
  17. package/cpp/ggml-metal/ggml-metal-context.m +5 -6
  18. package/cpp/ggml-metal/ggml-metal-device.cpp +101 -4
  19. package/cpp/ggml-metal/ggml-metal-device.h +8 -1
  20. package/cpp/ggml-metal/ggml-metal-device.m +216 -14
  21. package/cpp/ggml-metal/ggml-metal-impl.h +90 -2
  22. package/cpp/ggml-metal/ggml-metal-ops.cpp +346 -85
  23. package/cpp/ggml-metal/ggml-metal-ops.h +2 -0
  24. package/cpp/ggml-metal/ggml-metal.cpp +5 -0
  25. package/cpp/ggml-metal/ggml-metal.metal +12436 -0
  26. package/cpp/ggml.c +154 -5
  27. package/cpp/ggml.h +73 -0
  28. package/cpp/whisper.cpp +6 -2
  29. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  30. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +73 -0
  31. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  32. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  33. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  34. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  35. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +73 -0
  36. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  37. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  38. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  39. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  40. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  41. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +73 -0
  42. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  43. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  44. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  45. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  46. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +73 -0
  47. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  48. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  49. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  50. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  51. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +156 -12
  52. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  53. package/lib/module/realtime-transcription/RealtimeTranscriber.js +155 -12
  54. package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  55. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts +29 -0
  56. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
  57. package/lib/typescript/realtime-transcription/types.d.ts +7 -0
  58. package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
  59. package/package.json +1 -1
  60. package/src/realtime-transcription/RealtimeTranscriber.ts +179 -9
  61. package/src/realtime-transcription/types.ts +9 -0
  62. package/whisper-rn.podspec +1 -1
  63. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  64. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  65. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  66. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  67. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  68. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
@@ -7,8 +7,10 @@
7
7
  #include "unary-ops.h"
8
8
  #include "vec.h"
9
9
 
10
- #include <float.h>
10
+ #include <cfloat>
11
11
  #include <algorithm>
12
+ #include <cmath>
13
+ #include <functional>
12
14
 
13
15
  // wsp_ggml_compute_forward_dup
14
16
 
@@ -1394,6 +1396,56 @@ void wsp_ggml_compute_forward_sum(
1394
1396
  }
1395
1397
  }
1396
1398
 
1399
+ // wsp_ggml_compute_forward_cumsum
1400
+
1401
+ static void wsp_ggml_compute_forward_cumsum_f32(
1402
+ const wsp_ggml_compute_params * params,
1403
+ wsp_ggml_tensor * dst) {
1404
+
1405
+ const wsp_ggml_tensor * src0 = dst->src[0];
1406
+
1407
+ WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));
1408
+ WSP_GGML_ASSERT(dst->nb[0] == sizeof(float));
1409
+
1410
+ WSP_GGML_TENSOR_UNARY_OP_LOCALS
1411
+
1412
+ WSP_GGML_ASSERT(ne0 == ne00);
1413
+ WSP_GGML_ASSERT(ne1 == ne01);
1414
+ WSP_GGML_ASSERT(ne2 == ne02);
1415
+ WSP_GGML_ASSERT(ne3 == ne03);
1416
+
1417
+ const auto [ir0, ir1] = get_thread_range(params, src0);
1418
+
1419
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
1420
+ const int64_t i03 = ir/(ne02*ne01);
1421
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
1422
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
1423
+
1424
+ float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
1425
+ float * dst_row = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
1426
+
1427
+ wsp_ggml_vec_cumsum_f32(ne00, dst_row, src_row);
1428
+ }
1429
+ }
1430
+
1431
+ void wsp_ggml_compute_forward_cumsum(
1432
+ const wsp_ggml_compute_params * params,
1433
+ wsp_ggml_tensor * dst) {
1434
+
1435
+ const wsp_ggml_tensor * src0 = dst->src[0];
1436
+
1437
+ switch (src0->type) {
1438
+ case WSP_GGML_TYPE_F32:
1439
+ {
1440
+ wsp_ggml_compute_forward_cumsum_f32(params, dst);
1441
+ } break;
1442
+ default:
1443
+ {
1444
+ WSP_GGML_ABORT("fatal error");
1445
+ }
1446
+ }
1447
+ }
1448
+
1397
1449
  // wsp_ggml_compute_forward_sum_rows
1398
1450
 
1399
1451
  static void wsp_ggml_compute_forward_sum_rows_f32(
@@ -2140,6 +2192,83 @@ static void wsp_ggml_compute_forward_gelu(
2140
2192
  }
2141
2193
  }
2142
2194
 
2195
+ // wsp_ggml_compute_fill
2196
+
2197
+ static void wsp_ggml_compute_forward_fill_f32(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
2198
+ const float c = wsp_ggml_get_op_params_f32(dst, 0);
2199
+
2200
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
2201
+ WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
2202
+
2203
+ const auto [ir0, ir1] = get_thread_range(params, dst);
2204
+
2205
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
2206
+ const int64_t i03 = ir/(ne2*ne1);
2207
+ const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
2208
+ const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
2209
+
2210
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
2211
+
2212
+ wsp_ggml_vec_set_f32(ne0, dst_ptr, c);
2213
+ }
2214
+ }
2215
+
2216
+ void wsp_ggml_compute_forward_fill(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
2217
+ wsp_ggml_compute_forward_fill_f32(params, dst);
2218
+ }
2219
+
2220
+ // wsp_ggml_compute_tri
2221
+
2222
+ static void wsp_ggml_compute_forward_tri_f32(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
2223
+ const wsp_ggml_tensor * src0 = dst->src[0];
2224
+
2225
+ const wsp_ggml_tri_type ttype = (wsp_ggml_tri_type) wsp_ggml_get_op_params_i32(dst, 0);
2226
+
2227
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
2228
+
2229
+ WSP_GGML_TENSOR_UNARY_OP_LOCALS
2230
+
2231
+ const auto [ir0, ir1] = get_thread_range(params, src0);
2232
+
2233
+ bool (*bipred)(int, int);
2234
+
2235
+ switch (ttype) {
2236
+ case WSP_GGML_TRI_TYPE_LOWER: bipred = [](int i, int r) { return i < r; }; break;
2237
+ case WSP_GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
2238
+ case WSP_GGML_TRI_TYPE_UPPER: bipred = [](int i, int r) { return i > r; }; break;
2239
+ case WSP_GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
2240
+ default: WSP_GGML_ABORT("invalid tri type");
2241
+ }
2242
+
2243
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
2244
+ const int64_t i03 = ir/(ne02*ne01);
2245
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
2246
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
2247
+
2248
+ const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
2249
+ float * dst_ptr = ( float *) (( char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
2250
+
2251
+ for (int i0 = 0; i0 < ne0; ++i0) {
2252
+ dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
2253
+ }
2254
+ }
2255
+ }
2256
+
2257
+ void wsp_ggml_compute_forward_tri(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
2258
+ const wsp_ggml_tensor * src0 = dst->src[0];
2259
+
2260
+ switch (src0->type) {
2261
+ case WSP_GGML_TYPE_F32:
2262
+ {
2263
+ wsp_ggml_compute_forward_tri_f32(params, dst);
2264
+ } break;
2265
+ default:
2266
+ {
2267
+ WSP_GGML_ABORT("fatal error");
2268
+ }
2269
+ }
2270
+ }
2271
+
2143
2272
  // wsp_ggml_compute_forward_gelu_erf
2144
2273
 
2145
2274
  static void wsp_ggml_compute_forward_gelu_erf_f32(
@@ -4455,46 +4584,6 @@ void wsp_ggml_compute_forward_cont(
4455
4584
  wsp_ggml_compute_forward_dup(params, dst);
4456
4585
  }
4457
4586
 
4458
- // wsp_ggml_compute_forward_reshape
4459
-
4460
- void wsp_ggml_compute_forward_reshape(
4461
- const wsp_ggml_compute_params * params,
4462
- wsp_ggml_tensor * dst) {
4463
- // NOP
4464
- WSP_GGML_UNUSED(params);
4465
- WSP_GGML_UNUSED(dst);
4466
- }
4467
-
4468
- // wsp_ggml_compute_forward_view
4469
-
4470
- void wsp_ggml_compute_forward_view(
4471
- const wsp_ggml_compute_params * params,
4472
- wsp_ggml_tensor * dst) {
4473
- // NOP
4474
- WSP_GGML_UNUSED(params);
4475
- WSP_GGML_UNUSED(dst);
4476
- }
4477
-
4478
- // wsp_ggml_compute_forward_permute
4479
-
4480
- void wsp_ggml_compute_forward_permute(
4481
- const wsp_ggml_compute_params * params,
4482
- wsp_ggml_tensor * dst) {
4483
- // NOP
4484
- WSP_GGML_UNUSED(params);
4485
- WSP_GGML_UNUSED(dst);
4486
- }
4487
-
4488
- // wsp_ggml_compute_forward_transpose
4489
-
4490
- void wsp_ggml_compute_forward_transpose(
4491
- const wsp_ggml_compute_params * params,
4492
- wsp_ggml_tensor * dst) {
4493
- // NOP
4494
- WSP_GGML_UNUSED(params);
4495
- WSP_GGML_UNUSED(dst);
4496
- }
4497
-
4498
4587
  // wsp_ggml_compute_forward_get_rows
4499
4588
 
4500
4589
  static void wsp_ggml_compute_forward_get_rows_q(
@@ -5474,7 +5563,7 @@ static void wsp_ggml_rope_cache_init(
5474
5563
  }
5475
5564
 
5476
5565
  static void wsp_ggml_mrope_cache_init(
5477
- float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
5566
+ float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
5478
5567
  float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
5479
5568
  float * cache, float sin_sign, float theta_scale) {
5480
5569
  // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
@@ -5509,14 +5598,26 @@ static void wsp_ggml_mrope_cache_init(
5509
5598
  }
5510
5599
 
5511
5600
  float theta = theta_t;
5512
- if (sector >= sections[0] && sector < sec_w) {
5513
- theta = theta_h;
5514
- }
5515
- else if (sector >= sec_w && sector < sec_w + sections[2]) {
5516
- theta = theta_w;
5517
- }
5518
- else if (sector >= sec_w + sections[2]) {
5519
- theta = theta_e;
5601
+ if (is_imrope) { // qwen3vl apply interleaved mrope
5602
+ if (sector % 3 == 1 && sector < 3 * sections[1]) {
5603
+ theta = theta_h;
5604
+ } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
5605
+ theta = theta_w;
5606
+ } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
5607
+ theta = theta_t;
5608
+ } else {
5609
+ theta = theta_e;
5610
+ }
5611
+ } else {
5612
+ if (sector >= sections[0] && sector < sec_w) {
5613
+ theta = theta_h;
5614
+ }
5615
+ else if (sector >= sec_w && sector < sec_w + sections[2]) {
5616
+ theta = theta_w;
5617
+ }
5618
+ else if (sector >= sec_w + sections[2]) {
5619
+ theta = theta_e;
5620
+ }
5520
5621
  }
5521
5622
 
5522
5623
  rope_yarn(
@@ -5531,193 +5632,28 @@ static void wsp_ggml_mrope_cache_init(
5531
5632
  }
5532
5633
  }
5533
5634
 
5534
- static void wsp_ggml_compute_forward_rope_f32(
5535
- const wsp_ggml_compute_params * params,
5536
- wsp_ggml_tensor * dst,
5537
- const bool forward) {
5538
-
5539
- const wsp_ggml_tensor * src0 = dst->src[0];
5540
- const wsp_ggml_tensor * src1 = dst->src[1];
5541
- const wsp_ggml_tensor * src2 = dst->src[2];
5542
-
5543
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5544
- int sections[4];
5545
-
5546
- //const int n_past = ((int32_t *) dst->op_params)[0];
5547
- const int n_dims = ((int32_t *) dst->op_params)[1];
5548
- const int mode = ((int32_t *) dst->op_params)[2];
5549
- //const int n_ctx = ((int32_t *) dst->op_params)[3];
5550
- const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
5551
-
5552
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
5553
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
5554
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
5555
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
5556
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
5557
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
5558
- memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
5559
-
5560
- WSP_GGML_TENSOR_UNARY_OP_LOCALS
5561
-
5562
- //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
5563
- //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
5564
-
5565
- WSP_GGML_ASSERT(nb00 == sizeof(float));
5566
-
5567
- const int ith = params->ith;
5568
- const int nth = params->nth;
5569
-
5570
- const int nr = wsp_ggml_nrows(dst);
5571
-
5572
- WSP_GGML_ASSERT(n_dims <= ne0);
5573
- WSP_GGML_ASSERT(n_dims % 2 == 0);
5574
-
5575
- // rows per thread
5576
- const int dr = (nr + nth - 1)/nth;
5577
-
5578
- // row range for this thread
5579
- const int ir0 = dr*ith;
5580
- const int ir1 = MIN(ir0 + dr, nr);
5581
-
5582
- // row index used to determine which thread to use
5583
- int ir = 0;
5584
-
5585
- const float theta_scale = powf(freq_base, -2.0f/n_dims);
5586
-
5587
- float corr_dims[2];
5588
- wsp_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
5589
-
5590
- const bool is_neox = mode & WSP_GGML_ROPE_TYPE_NEOX;
5591
- const bool is_mrope = mode & WSP_GGML_ROPE_TYPE_MROPE; // wsp_ggml_rope_multi, multimodal rotary position embedding
5592
- const bool is_vision = mode == WSP_GGML_ROPE_TYPE_VISION;
5593
-
5594
- if (is_mrope) {
5595
- WSP_GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
5596
- }
5597
-
5598
- if (is_vision) {
5599
- WSP_GGML_ASSERT(n_dims == ne0/2);
5600
- }
5601
-
5602
- const float * freq_factors = NULL;
5603
- if (src2 != NULL) {
5604
- WSP_GGML_ASSERT(src2->type == WSP_GGML_TYPE_F32);
5605
- WSP_GGML_ASSERT(src2->ne[0] >= n_dims / 2);
5606
- freq_factors = (const float *) src2->data;
5607
- }
5608
-
5609
- // backward process uses inverse rotation by cos and sin.
5610
- // cos and sin build a rotation matrix, where the inverse is the transpose.
5611
- // this essentially just switches the sign of sin.
5612
- const float sin_sign = forward ? 1.0f : -1.0f;
5613
5635
 
5614
- const int32_t * pos = (const int32_t *) src1->data;
5615
-
5616
- for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
5617
- for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
5618
-
5619
- float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5620
- if (!is_mrope) {
5621
- const int64_t p = pos[i2];
5622
- wsp_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5623
- }
5624
- else {
5625
- const int64_t p_t = pos[i2];
5626
- const int64_t p_h = pos[i2 + ne2];
5627
- const int64_t p_w = pos[i2 + ne2 * 2];
5628
- const int64_t p_e = pos[i2 + ne2 * 3];
5629
- wsp_ggml_mrope_cache_init(
5630
- p_t, p_h, p_w, p_e, sections, is_vision,
5631
- freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5632
- }
5633
-
5634
- for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
5635
- if (ir++ < ir0) continue;
5636
- if (ir > ir1) break;
5637
-
5638
- if (is_neox || is_mrope) {
5639
- if (is_vision){
5640
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5641
- const int64_t ic = i0/2;
5642
-
5643
- const float cos_theta = cache[i0 + 0];
5644
- const float sin_theta = cache[i0 + 1];
5645
-
5646
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5647
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5648
-
5649
- const float x0 = src[0];
5650
- const float x1 = src[n_dims];
5651
-
5652
- dst_data[0] = x0*cos_theta - x1*sin_theta;
5653
- dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
5654
- }
5655
- } else {
5656
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5657
- const int64_t ic = i0/2;
5636
+ template<typename T>
5637
+ static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
5638
+ for (int64_t i0 = 0; i0 < n; i0 += 2) {
5639
+ const int64_t ic = i0/scale; // hack for WSP_GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
5658
5640
 
5659
- const float cos_theta = cache[i0 + 0];
5660
- const float sin_theta = cache[i0 + 1];
5641
+ const float cos_theta = cache[i0 + 0];
5642
+ const float sin_theta = cache[i0 + 1];
5661
5643
 
5662
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5663
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5644
+ const T * const src = src_data + ic;
5645
+ T * dst = dst_data + ic;
5664
5646
 
5665
- const float x0 = src[0];
5666
- const float x1 = src[n_dims/2];
5647
+ const float x0 = type_conversion_table<T>::to_f32(src[0]);
5648
+ const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
5667
5649
 
5668
- dst_data[0] = x0*cos_theta - x1*sin_theta;
5669
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
5670
- }
5671
- }
5672
- } else {
5673
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5674
- const float cos_theta = cache[i0 + 0];
5675
- const float sin_theta = cache[i0 + 1];
5676
-
5677
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5678
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5679
-
5680
- const float x0 = src[0];
5681
- const float x1 = src[1];
5682
-
5683
- dst_data[0] = x0*cos_theta - x1*sin_theta;
5684
- dst_data[1] = x0*sin_theta + x1*cos_theta;
5685
- }
5686
- }
5687
-
5688
- if (is_vision) {
5689
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5690
- const int64_t ic = i0/2;
5691
-
5692
- const float cos_theta = cache[i0 + 0];
5693
- const float sin_theta = cache[i0 + 1];
5694
-
5695
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5696
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5697
-
5698
- const float x0 = src[0];
5699
- const float x1 = src[n_dims];
5700
-
5701
- dst_data[0] = x0*cos_theta - x1*sin_theta;
5702
- dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
5703
- }
5704
- } else {
5705
- // fill the remain channels with data from src tensor
5706
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5707
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5708
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5709
-
5710
- dst_data[0] = src[0];
5711
- dst_data[1] = src[1];
5712
- }
5713
- }
5714
- }
5715
- }
5716
- }
5650
+ dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
5651
+ dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
5652
+ }
5717
5653
  }
5718
5654
 
5719
- // TODO: deduplicate f16/f32 code
5720
- static void wsp_ggml_compute_forward_rope_f16(
5655
+ template<typename T> //float or wsp_ggml_fp16_t
5656
+ static void wsp_ggml_compute_forward_rope_flt(
5721
5657
  const wsp_ggml_compute_params * params,
5722
5658
  wsp_ggml_tensor * dst,
5723
5659
  const bool forward) {
@@ -5726,6 +5662,9 @@ static void wsp_ggml_compute_forward_rope_f16(
5726
5662
  const wsp_ggml_tensor * src1 = dst->src[1];
5727
5663
  const wsp_ggml_tensor * src2 = dst->src[2];
5728
5664
 
5665
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32 || src0->type == WSP_GGML_TYPE_F16);
5666
+ WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_I32);
5667
+
5729
5668
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5730
5669
  int sections[4];
5731
5670
 
@@ -5734,6 +5673,7 @@ static void wsp_ggml_compute_forward_rope_f16(
5734
5673
  const int mode = ((int32_t *) dst->op_params)[2];
5735
5674
  //const int n_ctx = ((int32_t *) dst->op_params)[3];
5736
5675
  const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
5676
+
5737
5677
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
5738
5678
  memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
5739
5679
  memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -5742,13 +5682,13 @@ static void wsp_ggml_compute_forward_rope_f16(
5742
5682
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
5743
5683
  memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
5744
5684
 
5745
-
5746
5685
  WSP_GGML_TENSOR_UNARY_OP_LOCALS
5747
5686
 
5748
5687
  //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
5749
5688
  //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
5750
5689
 
5751
- WSP_GGML_ASSERT(nb0 == sizeof(wsp_ggml_fp16_t));
5690
+ WSP_GGML_ASSERT(nb0 == nb00);
5691
+ WSP_GGML_ASSERT(nb0 == sizeof(T));
5752
5692
 
5753
5693
  const int ith = params->ith;
5754
5694
  const int nth = params->nth;
@@ -5773,11 +5713,11 @@ static void wsp_ggml_compute_forward_rope_f16(
5773
5713
  float corr_dims[2];
5774
5714
  wsp_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
5775
5715
 
5776
- const bool is_neox = mode & WSP_GGML_ROPE_TYPE_NEOX;
5777
- const bool is_mrope = mode & WSP_GGML_ROPE_TYPE_MROPE;
5716
+ const bool is_imrope = mode == WSP_GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
5717
+ const bool mrope_used = mode & WSP_GGML_ROPE_TYPE_MROPE; // wsp_ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
5778
5718
  const bool is_vision = mode == WSP_GGML_ROPE_TYPE_VISION;
5779
5719
 
5780
- if (is_mrope) {
5720
+ if (mrope_used) {
5781
5721
  WSP_GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
5782
5722
  }
5783
5723
 
@@ -5799,11 +5739,11 @@ static void wsp_ggml_compute_forward_rope_f16(
5799
5739
 
5800
5740
  const int32_t * pos = (const int32_t *) src1->data;
5801
5741
 
5802
- for (int64_t i3 = 0; i3 < ne3; i3++) {
5803
- for (int64_t i2 = 0; i2 < ne2; i2++) {
5742
+ for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
5743
+ for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
5804
5744
 
5805
5745
  float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5806
- if (!is_mrope) {
5746
+ if (!mrope_used) {
5807
5747
  const int64_t p = pos[i2];
5808
5748
  wsp_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5809
5749
  }
@@ -5813,90 +5753,44 @@ static void wsp_ggml_compute_forward_rope_f16(
5813
5753
  const int64_t p_w = pos[i2 + ne2 * 2];
5814
5754
  const int64_t p_e = pos[i2 + ne2 * 3];
5815
5755
  wsp_ggml_mrope_cache_init(
5816
- p_t, p_h, p_w, p_e, sections, is_vision,
5756
+ p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
5817
5757
  freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5818
5758
  }
5819
5759
 
5820
- for (int64_t i1 = 0; i1 < ne1; i1++) {
5760
+ for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
5821
5761
  if (ir++ < ir0) continue;
5822
5762
  if (ir > ir1) break;
5823
5763
 
5824
- if (is_neox || is_mrope) {
5825
- if (is_vision) {
5826
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5827
- const int64_t ic = i0/2;
5828
-
5829
- const float cos_theta = cache[i0 + 0];
5830
- const float sin_theta = cache[i0 + 1];
5831
-
5832
- const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5833
- wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5834
-
5835
- const float x0 = WSP_GGML_CPU_FP16_TO_FP32(src[0]);
5836
- const float x1 = WSP_GGML_CPU_FP16_TO_FP32(src[n_dims]);
5837
-
5838
- dst_data[0] = WSP_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5839
- dst_data[n_dims] = WSP_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5840
- }
5841
- } else {
5842
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5843
- const int64_t ic = i0/2;
5844
-
5845
- const float cos_theta = cache[i0 + 0];
5846
- const float sin_theta = cache[i0 + 1];
5847
-
5848
- const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5849
- wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5850
-
5851
- const float x0 = WSP_GGML_CPU_FP16_TO_FP32(src[0]);
5852
- const float x1 = WSP_GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
5853
-
5854
- dst_data[0] = WSP_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5855
- dst_data[n_dims/2] = WSP_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5856
- }
5857
- }
5858
- } else {
5859
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5860
- const float cos_theta = cache[i0 + 0];
5861
- const float sin_theta = cache[i0 + 1];
5862
-
5863
- const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5864
- wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5865
-
5866
- const float x0 = WSP_GGML_CPU_FP16_TO_FP32(src[0]);
5867
- const float x1 = WSP_GGML_CPU_FP16_TO_FP32(src[1]);
5868
-
5869
- dst_data[0] = WSP_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5870
- dst_data[1] = WSP_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5871
- }
5764
+ T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
5765
+ T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
5766
+
5767
+ switch (mode) {
5768
+ case WSP_GGML_ROPE_TYPE_NORMAL:
5769
+ rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
5770
+ break;
5771
+ case WSP_GGML_ROPE_TYPE_NEOX:
5772
+ case WSP_GGML_ROPE_TYPE_MROPE:
5773
+ case WSP_GGML_ROPE_TYPE_IMROPE:
5774
+ rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
5775
+ break;
5776
+ case WSP_GGML_ROPE_TYPE_VISION:
5777
+ rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
5778
+ break;
5779
+ default:
5780
+ WSP_GGML_ABORT("rope type not supported");
5872
5781
  }
5873
5782
 
5874
- if (is_vision) {
5875
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5876
- const int64_t ic = i0/2;
5877
-
5878
- const float cos_theta = cache[i0 + 0];
5879
- const float sin_theta = cache[i0 + 1];
5880
-
5881
- const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5882
- wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5883
-
5884
- const float x0 = WSP_GGML_CPU_FP16_TO_FP32(src[0]);
5885
- const float x1 = WSP_GGML_CPU_FP16_TO_FP32(src[n_dims]);
5886
-
5887
- dst_data[0] = WSP_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5888
- dst_data[n_dims] = WSP_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5889
- }
5890
- } else {
5783
+ if (!is_vision) {
5784
+ // fill the remain channels with data from src tensor
5891
5785
  for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5892
- const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5893
- wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5786
+ const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5787
+ T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5894
5788
 
5895
5789
  dst_data[0] = src[0];
5896
5790
  dst_data[1] = src[1];
5897
5791
  }
5898
5792
  }
5899
- }
5793
+ } //attn-heads
5900
5794
  }
5901
5795
  }
5902
5796
  }
@@ -5910,11 +5804,11 @@ void wsp_ggml_compute_forward_rope(
5910
5804
  switch (src0->type) {
5911
5805
  case WSP_GGML_TYPE_F16:
5912
5806
  {
5913
- wsp_ggml_compute_forward_rope_f16(params, dst, true);
5807
+ wsp_ggml_compute_forward_rope_flt<wsp_ggml_fp16_t>(params, dst, true);
5914
5808
  } break;
5915
5809
  case WSP_GGML_TYPE_F32:
5916
5810
  {
5917
- wsp_ggml_compute_forward_rope_f32(params, dst, true);
5811
+ wsp_ggml_compute_forward_rope_flt<float>(params, dst, true);
5918
5812
  } break;
5919
5813
  default:
5920
5814
  {
@@ -5934,11 +5828,11 @@ void wsp_ggml_compute_forward_rope_back(
5934
5828
  switch (src0->type) {
5935
5829
  case WSP_GGML_TYPE_F16:
5936
5830
  {
5937
- wsp_ggml_compute_forward_rope_f16(params, dst, false);
5831
+ wsp_ggml_compute_forward_rope_flt<wsp_ggml_fp16_t>(params, dst, false);
5938
5832
  } break;
5939
5833
  case WSP_GGML_TYPE_F32:
5940
5834
  {
5941
- wsp_ggml_compute_forward_rope_f32(params, dst, false);
5835
+ wsp_ggml_compute_forward_rope_flt<float>(params, dst, false);
5942
5836
  } break;
5943
5837
  default:
5944
5838
  {
@@ -7070,7 +6964,11 @@ static void wsp_ggml_compute_forward_conv_2d_dw_cwhn(
7070
6964
  const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
7071
6965
 
7072
6966
  #ifdef WSP_GGML_SIMD
7073
- const int64_t pkg_size = WSP_GGML_F32_EPR;
6967
+ #if defined(__ARM_FEATURE_SVE)
6968
+ const int64_t pkg_size = svcntw();
6969
+ #else
6970
+ const int64_t pkg_size = WSP_GGML_F32_EPR;
6971
+ #endif
7074
6972
  const int64_t pkg_count = c / pkg_size;
7075
6973
  const int64_t c_pkg_end = pkg_count * pkg_size;
7076
6974
  #else
@@ -7493,10 +7391,17 @@ static void wsp_ggml_compute_forward_upscale_f32(
7493
7391
  float sf1 = (float)ne1/src0->ne[1];
7494
7392
  float sf2 = (float)ne2/src0->ne[2];
7495
7393
  float sf3 = (float)ne3/src0->ne[3];
7394
+ float pixel_offset = 0.5f;
7496
7395
 
7497
7396
  const int32_t mode_flags = wsp_ggml_get_op_params_i32(dst, 0);
7498
7397
  const wsp_ggml_scale_mode mode = (wsp_ggml_scale_mode) (mode_flags & 0xFF);
7499
7398
 
7399
+ if (mode_flags & WSP_GGML_SCALE_FLAG_ALIGN_CORNERS) {
7400
+ pixel_offset = 0.0f;
7401
+ sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
7402
+ sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
7403
+ }
7404
+
7500
7405
  if (mode == WSP_GGML_SCALE_MODE_NEAREST) {
7501
7406
  for (int64_t i3 = 0; i3 < ne3; i3++) {
7502
7407
  const int64_t i03 = i3 / sf3;
@@ -7516,13 +7421,6 @@ static void wsp_ggml_compute_forward_upscale_f32(
7516
7421
  }
7517
7422
  }
7518
7423
  } else if (mode == WSP_GGML_SCALE_MODE_BILINEAR) {
7519
- float pixel_offset = 0.5f;
7520
- if (mode_flags & WSP_GGML_SCALE_FLAG_ALIGN_CORNERS) {
7521
- pixel_offset = 0.0f;
7522
- sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
7523
- sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
7524
- }
7525
-
7526
7424
  for (int64_t i3 = 0; i3 < ne3; i3++) {
7527
7425
  const int64_t i03 = i3 / sf3;
7528
7426
  for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
@@ -7557,6 +7455,51 @@ static void wsp_ggml_compute_forward_upscale_f32(
7557
7455
 
7558
7456
  const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
7559
7457
 
7458
+ float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7459
+ *y_dst = val;
7460
+ }
7461
+ }
7462
+ }
7463
+ }
7464
+ } else if (mode == WSP_GGML_SCALE_MODE_BICUBIC) {
7465
+ // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
7466
+ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
7467
+ auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
7468
+ auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
7469
+ auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
7470
+ const float w0 = weight2(x + 1);
7471
+ const float w1 = weight1(x + 0);
7472
+ const float w2 = weight1(1 - x);
7473
+ const float w3 = weight2(2 - x);
7474
+ return p0*w0 + p1*w1 + p2*w2 + p3*w3;
7475
+ };
7476
+
7477
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
7478
+ const int64_t i03 = i3 / sf3;
7479
+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7480
+ const int64_t i02 = i2 / sf2;
7481
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
7482
+ const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
7483
+ const int64_t y0 = (int64_t)floorf(y);
7484
+ const float dy = y - (float)y0;
7485
+
7486
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
7487
+ const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
7488
+ const int64_t x0 = (int64_t)floorf(x);
7489
+ const float dx = x - (float)x0;
7490
+
7491
+ auto p = [=](int64_t x_off, int64_t y_off) -> float {
7492
+ int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
7493
+ int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
7494
+ return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7495
+ };
7496
+
7497
+ const float val = bicubic(
7498
+ bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
7499
+ bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
7500
+ bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
7501
+ bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
7502
+
7560
7503
  float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7561
7504
  *y_dst = val;
7562
7505
  }
@@ -7850,6 +7793,18 @@ void wsp_ggml_compute_forward_timestep_embedding(
7850
7793
 
7851
7794
  // wsp_ggml_compute_forward_argsort
7852
7795
 
7796
+ template<enum wsp_ggml_sort_order order>
7797
+ struct argsort_cmp {
7798
+ const float * data;
7799
+ bool operator()(int32_t a, int32_t b) const {
7800
+ if constexpr (order == WSP_GGML_SORT_ORDER_ASC) {
7801
+ return data[a] < data[b];
7802
+ } else {
7803
+ return data[a] > data[b];
7804
+ }
7805
+ }
7806
+ };
7807
+
7853
7808
  static void wsp_ggml_compute_forward_argsort_f32(
7854
7809
  const wsp_ggml_compute_params * params,
7855
7810
  wsp_ggml_tensor * dst) {
@@ -7868,23 +7823,25 @@ static void wsp_ggml_compute_forward_argsort_f32(
7868
7823
  wsp_ggml_sort_order order = (wsp_ggml_sort_order) wsp_ggml_get_op_params_i32(dst, 0);
7869
7824
 
7870
7825
  for (int64_t i = ith; i < nr; i += nth) {
7871
- int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
7872
7826
  const float * src_data = (float *)((char *) src0->data + i*nb01);
7873
7827
 
7828
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
7829
+
7874
7830
  for (int64_t j = 0; j < ne0; j++) {
7875
7831
  dst_data[j] = j;
7876
7832
  }
7877
7833
 
7878
- // C doesn't have a functional sort, so we do a bubble sort instead
7879
- for (int64_t j = 0; j < ne0; j++) {
7880
- for (int64_t k = j + 1; k < ne0; k++) {
7881
- if ((order == WSP_GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
7882
- (order == WSP_GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
7883
- int32_t tmp = dst_data[j];
7884
- dst_data[j] = dst_data[k];
7885
- dst_data[k] = tmp;
7886
- }
7887
- }
7834
+ switch (order) {
7835
+ case WSP_GGML_SORT_ORDER_ASC:
7836
+ std::sort(dst_data, dst_data + ne0, argsort_cmp<WSP_GGML_SORT_ORDER_ASC>{src_data});
7837
+ break;
7838
+
7839
+ case WSP_GGML_SORT_ORDER_DESC:
7840
+ std::sort(dst_data, dst_data + ne0, argsort_cmp<WSP_GGML_SORT_ORDER_DESC>{src_data});
7841
+ break;
7842
+
7843
+ default:
7844
+ WSP_GGML_ABORT("invalid sort order");
7888
7845
  }
7889
7846
  }
7890
7847
  }
@@ -7909,10 +7866,10 @@ void wsp_ggml_compute_forward_argsort(
7909
7866
 
7910
7867
  // wsp_ggml_compute_forward_flash_attn_ext
7911
7868
 
7912
- static void wsp_ggml_compute_forward_flash_attn_ext_f16(
7869
+ static void wsp_ggml_compute_forward_flash_attn_ext_f16_one_chunk(
7913
7870
  const wsp_ggml_compute_params * params,
7914
- wsp_ggml_tensor * dst) {
7915
-
7871
+ wsp_ggml_tensor * dst,
7872
+ int ir0, int ir1) {
7916
7873
  const wsp_ggml_tensor * q = dst->src[0];
7917
7874
  const wsp_ggml_tensor * k = dst->src[1];
7918
7875
  const wsp_ggml_tensor * v = dst->src[2];
@@ -7928,9 +7885,6 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
7928
7885
  WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
7929
7886
  WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
7930
7887
 
7931
- const int ith = params->ith;
7932
- const int nth = params->nth;
7933
-
7934
7888
  const int64_t DK = nek0;
7935
7889
  const int64_t DV = nev0;
7936
7890
  const int64_t N = neq1;
@@ -7964,16 +7918,6 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
7964
7918
 
7965
7919
  // parallelize by q rows using wsp_ggml_vec_dot_f32
7966
7920
 
7967
- // total rows in q
7968
- const int nr = neq1*neq2*neq3;
7969
-
7970
- // rows per thread
7971
- const int dr = (nr + nth - 1)/nth;
7972
-
7973
- // row range for this thread
7974
- const int ir0 = dr*ith;
7975
- const int ir1 = MIN(ir0 + dr, nr);
7976
-
7977
7921
  float scale = 1.0f;
7978
7922
  float max_bias = 0.0f;
7979
7923
  float logit_softcap = 0.0f;
@@ -8000,6 +7944,8 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
8000
7944
  WSP_GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
8001
7945
  WSP_GGML_ASSERT((v->type == WSP_GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
8002
7946
 
7947
+ int ith = params->ith;
7948
+
8003
7949
  // loop over n_batch and n_head
8004
7950
  for (int ir = ir0; ir < ir1; ++ir) {
8005
7951
  // q indices
@@ -8147,6 +8093,91 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
8147
8093
  }
8148
8094
  }
8149
8095
 
8096
+ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
8097
+ const wsp_ggml_compute_params * params,
8098
+ wsp_ggml_tensor * dst) {
8099
+
8100
+ const wsp_ggml_tensor * q = dst->src[0];
8101
+ const wsp_ggml_tensor * k = dst->src[1];
8102
+ const wsp_ggml_tensor * v = dst->src[2];
8103
+
8104
+ WSP_GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
8105
+ WSP_GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
8106
+ WSP_GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
8107
+ WSP_GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
8108
+ WSP_GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
8109
+ WSP_GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
8110
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
8111
+ WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
8112
+
8113
+ const int64_t DK = nek0;
8114
+ const int64_t DV = nev0;
8115
+ const int64_t N = neq1;
8116
+
8117
+ WSP_GGML_ASSERT(ne0 == DV);
8118
+ WSP_GGML_ASSERT(ne2 == N);
8119
+
8120
+ // input tensor rows must be contiguous
8121
+ WSP_GGML_ASSERT(nbq0 == wsp_ggml_type_size(q->type));
8122
+ WSP_GGML_ASSERT(nbk0 == wsp_ggml_type_size(k->type));
8123
+ WSP_GGML_ASSERT(nbv0 == wsp_ggml_type_size(v->type));
8124
+
8125
+ WSP_GGML_ASSERT(neq0 == DK);
8126
+ WSP_GGML_ASSERT(nek0 == DK);
8127
+ WSP_GGML_ASSERT(nev0 == DV);
8128
+
8129
+ WSP_GGML_ASSERT(neq1 == N);
8130
+
8131
+ // dst cannot be transposed or permuted
8132
+ WSP_GGML_ASSERT(nb0 == sizeof(float));
8133
+ WSP_GGML_ASSERT(nb0 <= nb1);
8134
+ WSP_GGML_ASSERT(nb1 <= nb2);
8135
+ WSP_GGML_ASSERT(nb2 <= nb3);
8136
+
8137
+ // parallelize by q rows using wsp_ggml_vec_dot_f32
8138
+
8139
+ // total rows in q
8140
+ const int64_t nr = neq1*neq2*neq3;
8141
+
8142
+ // thread index and total number of threads
8143
+ const int ith = params->ith;
8144
+ const int nth = params->nth;
8145
+
8146
+ // disable for NUMA
8147
+ const bool disable_chunking = wsp_ggml_is_numa();
8148
+
8149
+ // 4x chunks per thread
8150
+ int nth_scaled = nth * 4;
8151
+ int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
8152
+ int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
8153
+
8154
+ if (nth == 1 || nchunk < nth || disable_chunking) {
8155
+ nchunk = nth;
8156
+ }
8157
+
8158
+ if (ith == 0) {
8159
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
8160
+ wsp_ggml_threadpool_chunk_set(params->threadpool, nth);
8161
+ }
8162
+
8163
+ wsp_ggml_barrier(params->threadpool);
8164
+
8165
+ // The number of elements in each chunk
8166
+ const int64_t dr = (nr + nchunk - 1) / nchunk;
8167
+
8168
+ // The first chunk comes from our thread_id, the rest will get auto-assigned.
8169
+ int current_chunk = ith;
8170
+
8171
+ while (current_chunk < nchunk) {
8172
+ const int64_t ir0 = dr * current_chunk;
8173
+ const int64_t ir1 = MIN(ir0 + dr, nr);
8174
+
8175
+ wsp_ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
8176
+
8177
+ current_chunk = wsp_ggml_threadpool_chunk_add(params->threadpool, 1);
8178
+ }
8179
+ }
8180
+
8150
8181
  void wsp_ggml_compute_forward_flash_attn_ext(
8151
8182
  const wsp_ggml_compute_params * params,
8152
8183
  wsp_ggml_tensor * dst) {
@@ -8633,7 +8664,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
8633
8664
  // n_head
8634
8665
  for (int h = ih0; h < ih1; ++h) {
8635
8666
  // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8636
- const float dt_soft_plus = wsp_ggml_softplus(dt[h]);
8667
+ const float dt_soft_plus = wsp_ggml_compute_softplus_f32(dt[h]);
8637
8668
  const float dA = expf(dt_soft_plus * A[h]);
8638
8669
  const int g = h / (nh / ng); // repeat_interleave
8639
8670
 
@@ -8730,7 +8761,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
8730
8761
  // n_head
8731
8762
  for (int h = ih0; h < ih1; ++h) {
8732
8763
  // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8733
- const float dt_soft_plus = wsp_ggml_softplus(dt[h]);
8764
+ const float dt_soft_plus = wsp_ggml_compute_softplus_f32(dt[h]);
8734
8765
  const int g = h / (nh / ng); // repeat_interleave
8735
8766
 
8736
8767
  // dim
@@ -9013,6 +9044,14 @@ void wsp_ggml_compute_forward_unary(
9013
9044
  {
9014
9045
  wsp_ggml_compute_forward_xielu(params, dst);
9015
9046
  } break;
9047
+ case WSP_GGML_UNARY_OP_EXPM1:
9048
+ {
9049
+ wsp_ggml_compute_forward_expm1(params, dst);
9050
+ } break;
9051
+ case WSP_GGML_UNARY_OP_SOFTPLUS:
9052
+ {
9053
+ wsp_ggml_compute_forward_softplus(params, dst);
9054
+ } break;
9016
9055
  default:
9017
9056
  {
9018
9057
  WSP_GGML_ABORT("fatal error");
@@ -9609,6 +9648,76 @@ void wsp_ggml_compute_forward_gla(
9609
9648
  }
9610
9649
  }
9611
9650
 
9651
+ static void wsp_ggml_compute_forward_solve_tri_f32(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst) {
9652
+ const struct wsp_ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
9653
+ const struct wsp_ggml_tensor * src1 = dst->src[1]; // B (RHS)
9654
+
9655
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS;
9656
+
9657
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
9658
+ WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
9659
+ WSP_GGML_ASSERT(dst->type == WSP_GGML_TYPE_F32);
9660
+
9661
+ WSP_GGML_ASSERT(ne00 == ne01); // A must be square
9662
+ WSP_GGML_ASSERT(ne0 == ne10); // solution cols == B cols
9663
+ WSP_GGML_ASSERT(ne1 == ne11); // solution rows == B rows
9664
+
9665
+ WSP_GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
9666
+ WSP_GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
9667
+
9668
+ const int ith = params->ith;
9669
+ const int nth = params->nth;
9670
+
9671
+ const int64_t k = ne10; // number of RHS columns
9672
+ const int64_t n = ne11; // A is n×n
9673
+ const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
9674
+
9675
+ // work units (batch x RHS column) per thread
9676
+ const int64_t dr = (nr + nth - 1)/nth;
9677
+
9678
+ // chunk range for this thread
9679
+ const int64_t ir0 = dr*ith;
9680
+ const int64_t ir1 = MIN(ir0 + dr, nr);
9681
+
9682
+ const float * A = (const float *) src0->data; // [n, n, B1, B2]
9683
+ const float * B = (const float *) src1->data; // [n, k, B1, B2]
9684
+ float * X = ( float *) dst->data; // [n, k, B1, B2]
9685
+
9686
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
9687
+ const int64_t i03 = ir/(ne02*k);
9688
+ const int64_t i02 = (ir - i03*ne02*k)/k;
9689
+ const int64_t i01 = (ir - i03*ne02*k - i02*k);
9690
+
9691
+ const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
9692
+ const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
9693
+
9694
+ float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
9695
+
9696
+ for (int64_t i00 = 0; i00 < n; ++i00) {
9697
+ float sum = 0.0f;
9698
+ for (int64_t t = 0; t < i00; ++t) {
9699
+ sum += A_batch[i00 * n + t] * X_batch[i01 * n + t];
9700
+ }
9701
+
9702
+ const float diag = A_batch[i00 * n + i00];
9703
+ WSP_GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
9704
+
9705
+ X_batch[i01 * n + i00] = (B_batch[i00 * k + i01] - sum) / diag;
9706
+ }
9707
+ }
9708
+ }
9709
+
9710
+ void wsp_ggml_compute_forward_solve_tri(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst) {
9711
+ const wsp_ggml_tensor * src0 = dst->src[0];
9712
+ const wsp_ggml_tensor * src1 = dst->src[1];
9713
+
9714
+ if (src0->type == WSP_GGML_TYPE_F32 && src1->type == WSP_GGML_TYPE_F32) {
9715
+ wsp_ggml_compute_forward_solve_tri_f32(params, dst);
9716
+ } else {
9717
+ WSP_GGML_ABORT("fatal error");
9718
+ }
9719
+ }
9720
+
9612
9721
  // wsp_ggml_compute_forward_rwkv_wkv7
9613
9722
 
9614
9723
  static void wsp_ggml_compute_forward_rwkv_wkv7_f32(