whisper.rn 0.5.4 → 0.5.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. package/android/src/main/java/com/rnwhisper/WhisperContext.java +5 -0
  2. package/android/src/main/jni.cpp +13 -0
  3. package/cpp/ggml-alloc.c +78 -26
  4. package/cpp/ggml-alloc.h +9 -0
  5. package/cpp/ggml-backend-impl.h +1 -1
  6. package/cpp/ggml-backend-reg.cpp +19 -3
  7. package/cpp/ggml-backend.cpp +72 -20
  8. package/cpp/ggml-backend.h +2 -1
  9. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  10. package/cpp/ggml-cpu/arch/arm/repack.cpp +1004 -0
  11. package/cpp/ggml-cpu/arch/x86/repack.cpp +6 -6
  12. package/cpp/ggml-cpu/arch-fallback.h +50 -2
  13. package/cpp/ggml-cpu/ggml-cpu-impl.h +1 -1
  14. package/cpp/ggml-cpu/ggml-cpu.c +139 -58
  15. package/cpp/ggml-cpu/ggml-cpu.cpp +4 -0
  16. package/cpp/ggml-cpu/ops.cpp +170 -18
  17. package/cpp/ggml-cpu/ops.h +1 -0
  18. package/cpp/ggml-cpu/repack.cpp +531 -5
  19. package/cpp/ggml-cpu/repack.h +14 -0
  20. package/cpp/ggml-cpu/simd-mappings.h +16 -18
  21. package/cpp/ggml-cpu/vec.cpp +41 -1
  22. package/cpp/ggml-cpu/vec.h +241 -138
  23. package/cpp/ggml-cpu.h +1 -0
  24. package/cpp/ggml-impl.h +0 -4
  25. package/cpp/ggml-metal/ggml-metal-context.m +26 -16
  26. package/cpp/ggml-metal/ggml-metal-device.cpp +452 -371
  27. package/cpp/ggml-metal/ggml-metal-device.h +87 -65
  28. package/cpp/ggml-metal/ggml-metal-device.m +263 -104
  29. package/cpp/ggml-metal/ggml-metal-impl.h +58 -4
  30. package/cpp/ggml-metal/ggml-metal-ops.cpp +415 -98
  31. package/cpp/ggml-metal/ggml-metal-ops.h +4 -0
  32. package/cpp/ggml-metal/ggml-metal.cpp +6 -5
  33. package/cpp/ggml-metal/ggml-metal.metal +404 -34
  34. package/cpp/ggml.c +110 -31
  35. package/cpp/ggml.h +51 -12
  36. package/cpp/jsi/RNWhisperJSI.cpp +1 -0
  37. package/cpp/whisper.cpp +16 -3
  38. package/ios/CMakeLists.txt +21 -1
  39. package/ios/RNWhisperContext.mm +5 -0
  40. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  44. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  45. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
  46. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  47. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
  48. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  49. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  53. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  54. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
  55. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  56. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  57. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
  58. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  59. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  61. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  62. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  63. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  64. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
  65. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  66. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
  67. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  68. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  70. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  71. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  72. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  73. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
  74. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  75. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  76. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
  77. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  78. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  79. package/lib/commonjs/jest-mock.js +2 -0
  80. package/lib/commonjs/jest-mock.js.map +1 -1
  81. package/lib/commonjs/version.json +1 -1
  82. package/lib/module/NativeRNWhisper.js.map +1 -1
  83. package/lib/module/jest-mock.js +2 -0
  84. package/lib/module/jest-mock.js.map +1 -1
  85. package/lib/module/version.json +1 -1
  86. package/lib/typescript/NativeRNWhisper.d.ts +1 -0
  87. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  88. package/package.json +1 -1
  89. package/src/NativeRNWhisper.ts +1 -0
  90. package/src/jest-mock.ts +2 -0
  91. package/src/version.json +1 -1
@@ -583,6 +583,10 @@ static wsp_ggml_backend_feature * wsp_ggml_backend_cpu_get_features(wsp_ggml_bac
583
583
  if (wsp_ggml_cpu_has_riscv_v()) {
584
584
  features.push_back({ "RISCV_V", "1" });
585
585
  }
586
+ if (wsp_ggml_cpu_get_rvv_vlen() > 0) {
587
+ static std::string rvv_vlen = std::to_string(wsp_ggml_cpu_get_rvv_vlen());
588
+ features.push_back({ "RVV_VLEN", rvv_vlen.c_str() });
589
+ }
586
590
  if (wsp_ggml_cpu_has_vsx()) {
587
591
  features.push_back({ "VSX", "1" });
588
592
  }
@@ -6383,7 +6383,7 @@ static void wsp_ggml_compute_forward_im2col_3d_f16(
6383
6383
  const int64_t iih = ioh*s1 + ikh*d1 - p1;
6384
6384
  const int64_t iid = iod*s2 + ikd*d2 - p2;
6385
6385
 
6386
- if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
6386
+ if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6387
6387
  dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
6388
6388
  } else {
6389
6389
  const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
@@ -6554,8 +6554,13 @@ static void wsp_ggml_call_mul_mat(wsp_ggml_type type, const wsp_ggml_compute_par
6554
6554
  wsp_ggml_compute_forward_mul_mat(params, &dst);
6555
6555
  }
6556
6556
 
6557
+ static inline int64_t wsp_ggml_wrap_around(int64_t coord, int64_t size) {
6558
+ return (coord + size) % size; // adding size avoids negative number weirdness
6559
+ }
6560
+
6557
6561
  // wsp_ggml_compute_forward_conv_2d
6558
6562
 
6563
+
6559
6564
  static void wsp_ggml_compute_forward_conv_2d_impl(const wsp_ggml_compute_params * params,
6560
6565
  const wsp_ggml_tensor * kernel, // [KW, KH, IC, OC]
6561
6566
  const wsp_ggml_tensor * src, // [W, H, C, N]
@@ -7420,6 +7425,65 @@ static void wsp_ggml_compute_forward_upscale_f32(
7420
7425
  }
7421
7426
  }
7422
7427
  }
7428
+ } else if (mode == WSP_GGML_SCALE_MODE_BILINEAR && (mode_flags & WSP_GGML_SCALE_FLAG_ANTIALIAS)) {
7429
+ // Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
7430
+ // https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
7431
+ auto triangle_filter = [](float x) -> float {
7432
+ return std::max(1.0f - fabsf(x), 0.0f);
7433
+ };
7434
+
7435
+ // support and invscale, minimum 1 pixel for bilinear
7436
+ const float support1 = std::max(1.0f, 1.0f / sf1);
7437
+ const float invscale1 = 1.0f / support1;
7438
+ const float support0 = std::max(1.0f, 1.0f / sf0);
7439
+ const float invscale0 = 1.0f / support0;
7440
+
7441
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
7442
+ const int64_t i03 = i3 / sf3;
7443
+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7444
+ const int64_t i02 = i2 / sf2;
7445
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
7446
+ const float y = ((float) i1 + pixel_offset) / sf1;
7447
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
7448
+ const float x = ((float) i0 + pixel_offset) / sf0;
7449
+
7450
+ // the range of source pixels that contribute
7451
+ const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);
7452
+ const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);
7453
+ const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);
7454
+ const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);
7455
+
7456
+ // bilinear filter with antialiasing
7457
+ float val = 0.0f;
7458
+ float total_weight = 0.0f;
7459
+
7460
+ for (int64_t sy = y_min; sy < y_max; sy++) {
7461
+ const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
7462
+
7463
+ for (int64_t sx = x_min; sx < x_max; sx++) {
7464
+ const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
7465
+ const float weight = weight_x * weight_y;
7466
+
7467
+ if (weight <= 0.0f) {
7468
+ continue;
7469
+ }
7470
+
7471
+ const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);
7472
+ val += pixel * weight;
7473
+ total_weight += weight;
7474
+ }
7475
+ }
7476
+
7477
+ if (total_weight > 0.0f) {
7478
+ val /= total_weight;
7479
+ }
7480
+
7481
+ float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7482
+ *dst_ptr = val;
7483
+ }
7484
+ }
7485
+ }
7486
+ }
7423
7487
  } else if (mode == WSP_GGML_SCALE_MODE_BILINEAR) {
7424
7488
  for (int64_t i3 = 0; i3 < ne3; i3++) {
7425
7489
  const int64_t i03 = i3 / sf3;
@@ -7532,6 +7596,7 @@ void wsp_ggml_compute_forward_upscale(
7532
7596
 
7533
7597
  // wsp_ggml_compute_forward_pad
7534
7598
 
7599
+ template<bool circular_t>
7535
7600
  static void wsp_ggml_compute_forward_pad_f32(
7536
7601
  const wsp_ggml_compute_params * params,
7537
7602
  wsp_ggml_tensor * dst) {
@@ -7556,23 +7621,40 @@ static void wsp_ggml_compute_forward_pad_f32(
7556
7621
  const int32_t lp3 = wsp_ggml_get_op_params_i32(dst, 6);
7557
7622
  const int32_t rp3 = wsp_ggml_get_op_params_i32(dst, 7);
7558
7623
 
7559
-
7560
7624
  // TODO: optimize
7561
7625
 
7562
7626
  for (int64_t i2 = 0; i2 < ne2; ++i2) {
7563
7627
  for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
7564
7628
  for (int64_t i0 = 0; i0 < ne0; ++i0) {
7565
7629
  for (int64_t i3 = 0; i3 < ne3; ++i3) {
7566
- const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
7567
- if ((i0 >= lp0 && i0 < ne0 - rp0) \
7568
- && (i1 >= lp1 && i1 < ne1 - rp1) \
7569
- && (i2 >= lp2 && i2 < ne2 - rp2) \
7570
- && (i3 >= lp3 && i3 < ne3 - rp3)) {
7571
- const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
7630
+ // circular means wrap around on a torus, so x and y loop around
7631
+ if constexpr (circular_t) {
7632
+ const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
7633
+ const int64_t src_i0 = wsp_ggml_wrap_around(i0 - lp0, ne00);
7634
+ const int64_t src_i1 = wsp_ggml_wrap_around(i1 - lp1, ne01);
7635
+ const int64_t src_i2 = wsp_ggml_wrap_around(i2 - lp2, ne02);
7636
+ const int64_t src_i3 = wsp_ggml_wrap_around(i3 - lp3, ne03);
7637
+
7638
+ const int64_t src_idx =
7639
+ src_i3*nb03 +
7640
+ src_i2*nb02 +
7641
+ src_i1*nb01 +
7642
+ src_i0*nb00;
7643
+
7572
7644
  const float * src_ptr = (const float *)((char *) src0->data + src_idx);
7573
7645
  dst_ptr[dst_idx] = *src_ptr;
7574
7646
  } else {
7575
- dst_ptr[dst_idx] = 0;
7647
+ const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
7648
+ if ((i0 >= lp0 && i0 < ne0 - rp0) \
7649
+ && (i1 >= lp1 && i1 < ne1 - rp1) \
7650
+ && (i2 >= lp2 && i2 < ne2 - rp2) \
7651
+ && (i3 >= lp3 && i3 < ne3 - rp3)) {
7652
+ const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
7653
+ const float * src_ptr = (const float *)((char *) src0->data + src_idx);
7654
+ dst_ptr[dst_idx] = *src_ptr;
7655
+ } else {
7656
+ dst_ptr[dst_idx] = 0;
7657
+ }
7576
7658
  }
7577
7659
  }
7578
7660
  }
@@ -7580,16 +7662,20 @@ static void wsp_ggml_compute_forward_pad_f32(
7580
7662
  }
7581
7663
  }
7582
7664
 
7665
+
7583
7666
  void wsp_ggml_compute_forward_pad(
7584
7667
  const wsp_ggml_compute_params * params,
7585
7668
  wsp_ggml_tensor * dst) {
7586
-
7587
7669
  const wsp_ggml_tensor * src0 = dst->src[0];
7588
-
7670
+ const bool circular = (bool) wsp_ggml_get_op_params_i32(dst, 8);
7589
7671
  switch (src0->type) {
7590
7672
  case WSP_GGML_TYPE_F32:
7591
7673
  {
7592
- wsp_ggml_compute_forward_pad_f32(params, dst);
7674
+ if (circular) {
7675
+ wsp_ggml_compute_forward_pad_f32<true>(params, dst);
7676
+ } else {
7677
+ wsp_ggml_compute_forward_pad_f32<false>(params, dst);
7678
+ }
7593
7679
  } break;
7594
7680
  default:
7595
7681
  {
@@ -7794,7 +7880,7 @@ void wsp_ggml_compute_forward_timestep_embedding(
7794
7880
  // wsp_ggml_compute_forward_argsort
7795
7881
 
7796
7882
  template<enum wsp_ggml_sort_order order>
7797
- struct argsort_cmp {
7883
+ struct cmp_argsort {
7798
7884
  const float * data;
7799
7885
  bool operator()(int32_t a, int32_t b) const {
7800
7886
  if constexpr (order == WSP_GGML_SORT_ORDER_ASC) {
@@ -7833,11 +7919,11 @@ static void wsp_ggml_compute_forward_argsort_f32(
7833
7919
 
7834
7920
  switch (order) {
7835
7921
  case WSP_GGML_SORT_ORDER_ASC:
7836
- std::sort(dst_data, dst_data + ne0, argsort_cmp<WSP_GGML_SORT_ORDER_ASC>{src_data});
7922
+ std::sort(dst_data, dst_data + ne0, cmp_argsort<WSP_GGML_SORT_ORDER_ASC>{src_data});
7837
7923
  break;
7838
7924
 
7839
7925
  case WSP_GGML_SORT_ORDER_DESC:
7840
- std::sort(dst_data, dst_data + ne0, argsort_cmp<WSP_GGML_SORT_ORDER_DESC>{src_data});
7926
+ std::sort(dst_data, dst_data + ne0, cmp_argsort<WSP_GGML_SORT_ORDER_DESC>{src_data});
7841
7927
  break;
7842
7928
 
7843
7929
  default:
@@ -7864,6 +7950,72 @@ void wsp_ggml_compute_forward_argsort(
7864
7950
  }
7865
7951
  }
7866
7952
 
7953
+ // wsp_ggml_compute_forward_top_k
7954
+
7955
+ struct cmp_top_k {
7956
+ const float * data;
7957
+ bool operator()(int32_t a, int32_t b) const {
7958
+ return data[a] > data[b];
7959
+ }
7960
+ };
7961
+
7962
+ static void wsp_ggml_compute_forward_top_k_f32(
7963
+ const wsp_ggml_compute_params * params,
7964
+ wsp_ggml_tensor * dst) {
7965
+
7966
+ const wsp_ggml_tensor * src0 = dst->src[0];
7967
+
7968
+ WSP_GGML_TENSOR_UNARY_OP_LOCALS
7969
+
7970
+ WSP_GGML_ASSERT(nb0 == sizeof(float));
7971
+
7972
+ const int ith = params->ith;
7973
+ const int nth = params->nth;
7974
+
7975
+ const int64_t nr = wsp_ggml_nrows(src0);
7976
+
7977
+ const int top_k = ne0;
7978
+
7979
+ int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
7980
+
7981
+ for (int64_t i = ith; i < nr; i += nth) {
7982
+ const float * src_data = (float *)((char *) src0->data + i*nb01);
7983
+
7984
+ for (int64_t j = 0; j < ne00; j++) {
7985
+ tmp[j] = j;
7986
+ }
7987
+
7988
+ std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
7989
+
7990
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
7991
+
7992
+ std::copy(tmp, tmp + top_k, dst_data);
7993
+
7994
+ // emphasize that the order is not important
7995
+ if (top_k > 1) {
7996
+ std::swap(dst_data[0], dst_data[1]);
7997
+ }
7998
+ }
7999
+ }
8000
+
8001
+ void wsp_ggml_compute_forward_top_k(
8002
+ const wsp_ggml_compute_params * params,
8003
+ wsp_ggml_tensor * dst) {
8004
+
8005
+ const wsp_ggml_tensor * src0 = dst->src[0];
8006
+
8007
+ switch (src0->type) {
8008
+ case WSP_GGML_TYPE_F32:
8009
+ {
8010
+ wsp_ggml_compute_forward_top_k_f32(params, dst);
8011
+ } break;
8012
+ default:
8013
+ {
8014
+ WSP_GGML_ABORT("fatal error");
8015
+ }
8016
+ }
8017
+ }
8018
+
7867
8019
  // wsp_ggml_compute_forward_flash_attn_ext
7868
8020
 
7869
8021
  static void wsp_ggml_compute_forward_flash_attn_ext_f16_one_chunk(
@@ -9696,13 +9848,13 @@ static void wsp_ggml_compute_forward_solve_tri_f32(const struct wsp_ggml_compute
9696
9848
  for (int64_t i00 = 0; i00 < n; ++i00) {
9697
9849
  float sum = 0.0f;
9698
9850
  for (int64_t t = 0; t < i00; ++t) {
9699
- sum += A_batch[i00 * n + t] * X_batch[i01 * n + t];
9851
+ sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
9700
9852
  }
9701
9853
 
9702
9854
  const float diag = A_batch[i00 * n + i00];
9703
- WSP_GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
9855
+ assert(diag != 0.0f && "Zero diagonal in triangular matrix");
9704
9856
 
9705
- X_batch[i01 * n + i00] = (B_batch[i00 * k + i01] - sum) / diag;
9857
+ X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
9706
9858
  }
9707
9859
  }
9708
9860
  }
@@ -81,6 +81,7 @@ void wsp_ggml_compute_forward_roll(const struct wsp_ggml_compute_params * params
81
81
  void wsp_ggml_compute_forward_arange(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
82
82
  void wsp_ggml_compute_forward_timestep_embedding(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
83
83
  void wsp_ggml_compute_forward_argsort(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
84
+ void wsp_ggml_compute_forward_top_k(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
84
85
  void wsp_ggml_compute_forward_leaky_relu(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
85
86
  void wsp_ggml_compute_forward_tri(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
86
87
  void wsp_ggml_compute_forward_fill(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);