whisper.rn 0.5.1 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. package/android/src/main/jni.cpp +12 -3
  2. package/cpp/ggml-alloc.c +49 -18
  3. package/cpp/ggml-backend-impl.h +0 -3
  4. package/cpp/ggml-backend-reg.cpp +8 -0
  5. package/cpp/ggml-backend.cpp +0 -2
  6. package/cpp/ggml-backend.h +2 -0
  7. package/cpp/ggml-cpu/amx/amx.cpp +1 -0
  8. package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
  9. package/cpp/ggml-cpu/ggml-cpu-impl.h +4 -2
  10. package/cpp/ggml-cpu/ggml-cpu.c +67 -24
  11. package/cpp/ggml-cpu/ops.cpp +489 -364
  12. package/cpp/ggml-cpu/ops.h +4 -4
  13. package/cpp/ggml-cpu/repack.cpp +143 -29
  14. package/cpp/ggml-cpu/simd-mappings.h +25 -25
  15. package/cpp/ggml-cpu/unary-ops.cpp +151 -0
  16. package/cpp/ggml-cpu/unary-ops.h +7 -0
  17. package/cpp/ggml-cpu/vec.cpp +83 -0
  18. package/cpp/ggml-cpu/vec.h +20 -8
  19. package/cpp/ggml-impl.h +67 -2
  20. package/cpp/ggml-metal/ggml-metal-common.cpp +2 -2
  21. package/cpp/ggml-metal/ggml-metal-context.m +5 -6
  22. package/cpp/ggml-metal/ggml-metal-device.cpp +300 -14
  23. package/cpp/ggml-metal/ggml-metal-device.h +26 -1
  24. package/cpp/ggml-metal/ggml-metal-device.m +243 -28
  25. package/cpp/ggml-metal/ggml-metal-impl.h +177 -9
  26. package/cpp/ggml-metal/ggml-metal-ops.cpp +843 -157
  27. package/cpp/ggml-metal/ggml-metal-ops.h +8 -0
  28. package/cpp/ggml-metal/ggml-metal.cpp +8 -3
  29. package/cpp/ggml-metal/ggml-metal.metal +12436 -0
  30. package/cpp/ggml.c +317 -4
  31. package/cpp/ggml.h +139 -0
  32. package/cpp/jsi/RNWhisperJSI.cpp +7 -2
  33. package/cpp/rn-whisper.h +1 -0
  34. package/cpp/whisper.cpp +8 -2
  35. package/ios/RNWhisperContext.mm +3 -1
  36. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  37. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  38. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  39. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
  40. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  44. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  45. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  46. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  47. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
  48. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  49. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  53. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  54. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  55. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  56. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
  57. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  58. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  59. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  61. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  62. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  63. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  64. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
  65. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  66. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  67. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  68. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  70. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  71. package/lib/commonjs/version.json +1 -1
  72. package/lib/module/NativeRNWhisper.js.map +1 -1
  73. package/lib/module/version.json +1 -1
  74. package/lib/typescript/NativeRNWhisper.d.ts +2 -0
  75. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  76. package/package.json +1 -1
  77. package/src/NativeRNWhisper.ts +2 -0
  78. package/src/version.json +1 -1
  79. package/whisper-rn.podspec +1 -1
  80. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  81. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  82. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  83. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  84. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  85. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
package/cpp/ggml-cpu/ops.h
@@ -34,6 +34,7 @@ void wsp_ggml_compute_forward_add1(const struct wsp_ggml_compute_params * params
  void wsp_ggml_compute_forward_acc(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_sum(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_sum_rows(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_cumsum(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_mean(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_argmax(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_count_equal(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
@@ -51,10 +52,6 @@ void wsp_ggml_compute_forward_scale(const struct wsp_ggml_compute_params * param
  void wsp_ggml_compute_forward_set(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_cpy(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_cont(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
- void wsp_ggml_compute_forward_reshape(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
- void wsp_ggml_compute_forward_view(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
- void wsp_ggml_compute_forward_permute(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
- void wsp_ggml_compute_forward_transpose(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_get_rows(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_get_rows_back(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_set_rows(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
@@ -85,6 +82,8 @@ void wsp_ggml_compute_forward_arange(const struct wsp_ggml_compute_params * para
  void wsp_ggml_compute_forward_timestep_embedding(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_argsort(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_leaky_relu(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_tri(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_fill(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_flash_attn_ext(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_flash_attn_back(
          const struct wsp_ggml_compute_params * params,
@@ -100,6 +99,7 @@ void wsp_ggml_compute_forward_get_rel_pos(const struct wsp_ggml_compute_params *
  void wsp_ggml_compute_forward_add_rel_pos(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_rwkv_wkv6(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_rwkv_wkv7(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_solve_tri(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_gla(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_map_custom1(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_map_custom2(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
package/cpp/ggml-cpu/repack.cpp
@@ -1600,6 +1600,55 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, wsp_ggml_type
      return false;
  }
 
+ void forward_mul_mat_one_chunk(wsp_ggml_compute_params * params,
+                                wsp_ggml_tensor * op,
+                                int64_t src0_start,
+                                int64_t src0_end,
+                                int64_t src1_start,
+                                int64_t src1_end) {
+     const wsp_ggml_tensor * src0 = op->src[0];
+     const wsp_ggml_tensor * src1 = op->src[1];
+     wsp_ggml_tensor * dst = op;
+
+     WSP_GGML_TENSOR_BINARY_OP_LOCALS
+
+     const size_t src1_col_stride = wsp_ggml_row_size(PARAM_TYPE, ne10);
+
+     WSP_GGML_ASSERT(ne03 == 1 && ne13 == 1);
+     WSP_GGML_ASSERT(ne12 % ne02 == 0);
+     const int64_t r2 = ne12 / ne02;
+
+     const int64_t i12 = src1_start / ne1;
+     const int64_t i11 = src1_start - i12 * ne1;
+
+     // Determine batch index
+     const int64_t i02 = i12 / r2;
+
+     const int64_t i1 = i11;
+     const int64_t i2 = i12;
+
+     const char * src0_ptr = (const char *) src0->data + i02 * nb02;
+     const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride;
+     char * dst_ptr = ((char *) dst->data + (i1 * nb1 + i2 * nb2));
+
+     const int64_t nrows = src1_end - src1_start;
+     const int64_t ncols = src0_end - src0_start;
+
+     WSP_GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize);
+
+     // If there are more than three rows in src1, use gemm; otherwise, use gemv.
+     if (nrows > 3) {
+         gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,
+                                                          src0_ptr + src0_start * nb01, src1_ptr,
+                                                          nrows - (nrows % 4), ncols);
+     }
+     for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {
+         gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start,
+                                                          ne01, src0_ptr + src0_start * nb01,
+                                                          src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols);
+     }
+ }
+
  void forward_mul_mat(wsp_ggml_compute_params * params, wsp_ggml_tensor * op) {
      const wsp_ggml_tensor * src0 = op->src[0];
      const wsp_ggml_tensor * src1 = op->src[1];
@@ -1621,6 +1670,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, wsp_ggml_type
      WSP_GGML_ASSERT(nb1 <= nb2);
      WSP_GGML_ASSERT(nb2 <= nb3);
 
+     // TODO: General batched mul mat for 4D tensors
+     // Currently only supports 3D tensors
+     WSP_GGML_ASSERT(ne03 == 1);
+     WSP_GGML_ASSERT(ne13 == 1);
+     WSP_GGML_ASSERT(ne3 == 1);
+
      WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
 
      WSP_GGML_ASSERT(wsp_ggml_n_dims(op->src[0]) == 2);
@@ -1628,46 +1683,101 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, wsp_ggml_type
 
      char * wdata = static_cast<char *>(params->wdata);
      const size_t nbw1 = wsp_ggml_row_size(PARAM_TYPE, ne10);
+     const size_t nbw2 = nbw1 * ne11;
 
-     assert(params->wsize >= nbw1 * ne11);
+     assert(params->wsize >= nbw2 * ne12);
 
      const wsp_ggml_from_float_t from_float = wsp_ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
 
-     int64_t i11_processed = 0;
-     for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
-         wsp_ggml_wsp_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
-     }
+     // INFO: Quantization is done in planes to avoid extra complexity in chunking.
+     // Flattening dimensions not multiple of INTER_SIZE would require extra handling depending on how
+     // the planes are broadcast.
+     for (int64_t i12 = 0; i12 < ne12; i12++) {
+         char * data_ptr  = (char *) src1->data + i12 * nb12;
+         char * wdata_ptr = wdata + i12 * nbw2;
+
+         for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
+             wsp_ggml_wsp_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
+                                                                 (void *) (wdata_ptr + i11 * nbw1), 4, ne10);
+         }
 
-     i11_processed = ne11 - ne11 % 4;
-     for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
-         from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
+         const int64_t i11_processed = ne11 - ne11 % 4;
+         for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
+             from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);
+         }
      }
 
-     wsp_ggml_barrier(params->threadpool);
+     // disable for NUMA
+     const bool disable_chunking = wsp_ggml_is_numa();
 
-     const void * src1_wdata = params->wdata;
-     const size_t src1_col_stride = wsp_ggml_row_size(PARAM_TYPE, ne10);
-     int64_t src0_start = (ith * ne01) / nth;
-     int64_t src0_end = ((ith + 1) * ne01) / nth;
-     src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
-     src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
-     if (src0_start >= src0_end) {
-         return;
+     // 4x chunks per thread
+     const int64_t nr0 = wsp_ggml_nrows(op->src[0]);
+
+     int nth_scaled = nth * 4;
+     int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
+     int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0;
+
+     // src1 is chunked only by full planes.
+     // When we flatten we need to address dimensions not multiple of the q8 INTER_SIZE
+     // to route them through GEMV.
+     // nchunk1 = ne12 also avoids messing the chunking for models with no 3d tensors
+     // to avoid affecting their performance
+     int64_t nchunk1 = ne12;
+
+     // Ensure minimum chunk size to avoid alignment issues with high thread counts
+     // Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
+     const int64_t min_chunk_size = NB_COLS;
+     if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
+         nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
      }
 
-     // If there are more than three rows in src1, use gemm; otherwise, use gemv.
-     if (ne11 > 3) {
-         gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
-             (float *) ((char *) dst->data) + src0_start, ne01,
-             (const char *) src0->data + src0_start * nb01,
-             (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+     if (nth == 1 || nchunk0 < nth || disable_chunking) {
+         nchunk0 = nth;
      }
-     for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
-         gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
-             (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
-             (const char *) src0->data + src0_start * nb01,
-             (const char *) src1_wdata + (src1_col_stride * iter), 1,
-             src0_end - src0_start);
+
+     const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+
+     // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
+     // This prevents creating too many tiny chunks that could overlap after alignment
+     const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
+     nchunk0 = MIN(nchunk0, max_nchunk);
+
+     if (ith == 0) {
+         // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+         wsp_ggml_threadpool_chunk_set(params->threadpool, nth);
+     }
+
+     wsp_ggml_barrier(params->threadpool);
+
+     // The first chunk comes from our thread_id, the rest will get auto-assigned.
+     int current_chunk = ith;
+
+     while (current_chunk < nchunk0 * nchunk1) {
+         const int64_t ith0 = current_chunk % nchunk0;
+         const int64_t ith1 = current_chunk / nchunk0;
+
+         int64_t src0_start = dr0 * ith0;
+         int64_t src0_end   = MIN(src0_start + dr0, nr0);
+
+         // full-plane range for src1
+         int64_t src1_start = ith1 * ne11;
+         int64_t src1_end   = (ith1 + 1) * ne11;
+
+         // Align boundaries to NB_COLS - round up to ensure all data is included
+         // The chunk size limiting above ensures chunks are large enough to prevent overlaps
+         src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
+         src0_end   = (src0_end   % NB_COLS) ? src0_end   + NB_COLS - (src0_end   % NB_COLS) : src0_end;
+         src0_end   = MIN(src0_end, ne01);
+
+         // Make sure current plane is the last one before exiting
+         if (src0_start >= src0_end) {
+             current_chunk = wsp_ggml_threadpool_chunk_add(params->threadpool, 1);
+             continue;
+         }
+
+         forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);
+
+         current_chunk = wsp_ggml_threadpool_chunk_add(params->threadpool, 1);
      }
  }
 
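Note: the rewritten forward_mul_mat above moves from a fixed per-thread row split to dynamic chunking driven by a shared counter (wsp_ggml_threadpool_chunk_set / wsp_ggml_threadpool_chunk_add): the counter is seeded with nth, each thread starts on the chunk matching its own id, and further chunks are claimed atomically. The sketch below is a minimal standalone model of that claiming pattern only, with std::atomic standing in for the threadpool helpers; it is an illustration, not the library's API.

    // Minimal model of the chunk-claiming loop, assuming std::atomic in place
    // of the wsp_ggml_threadpool_chunk_* helpers used in the real code.
    #include <atomic>
    #include <cstdio>
    #include <thread>
    #include <vector>

    static std::atomic<int> next_chunk; // first chunk nobody has claimed yet

    static void worker(int ith, int nchunk) {
        // Start on the chunk equal to our thread id; afterwards claim chunks
        // through the shared counter (fetch_add returns the chunk we now own).
        for (int chunk = ith; chunk < nchunk; chunk = next_chunk.fetch_add(1)) {
            std::printf("thread %d processes chunk %d\n", ith, chunk); // stand-in for the gemm/gemv work
        }
    }

    int main() {
        const int nth = 4, nchunk = 16;   // nchunk plays the role of nchunk0 * nchunk1
        next_chunk.store(nth);            // mirrors wsp_ggml_threadpool_chunk_set(tp, nth)
        std::vector<std::thread> pool;
        for (int ith = 0; ith < nth; ++ith) pool.emplace_back(worker, ith, nchunk);
        for (auto & t : pool) t.join();
        return 0;
    }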
@@ -1772,8 +1882,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, wsp_ggml_type
      int64_t src0_cur_start = (ith * ne01) / nth;
      int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
 
+     // Align boundaries to NB_COLS - round up to ensure all data is included
      src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
      src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
+     if (src0_cur_end > ne01) {
+         src0_cur_end = ne01;
+     }
 
      if (src0_cur_start >= src0_cur_end) {
          return;
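Note: the expressions of the form (x % NB_COLS) ? x + NB_COLS - (x % NB_COLS) : x round a chunk boundary up to the next multiple of NB_COLS, and the clamp added in this hunk keeps the rounded end inside ne01. A small worked example (the values below are chosen purely for illustration):

    #include <cassert>
    #include <cstdint>

    // Round x up to the next multiple of nb_cols; same arithmetic as the diff above.
    static int64_t round_up(int64_t x, int64_t nb_cols) {
        return (x % nb_cols) ? x + nb_cols - (x % nb_cols) : x;
    }

    int main() {
        const int64_t NB_COLS = 8, ne01 = 100;
        int64_t src0_cur_start = round_up(37, NB_COLS);   // -> 40
        int64_t src0_cur_end   = round_up(97, NB_COLS);   // -> 104, past the last row
        if (src0_cur_end > ne01) src0_cur_end = ne01;     // clamp added in 0.5.3 -> 100
        assert(src0_cur_start == 40 && src0_cur_end == 100);
        return 0;
    }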
package/cpp/ggml-cpu/simd-mappings.h
@@ -956,7 +956,7 @@ do { \
 
  #define WSP_GGML_F32Cx8 __m256
  #define WSP_GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
- #define WSP_GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
+ #define WSP_GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
 
  static inline __m256 __lasx_f32cx8_load(const wsp_ggml_fp16_t * x) {
      __m256i a;
@@ -999,34 +999,34 @@ static inline void __lasx_f32cx8_store(wsp_ggml_fp16_t * x, __m256 y) {
 
  #define WSP_GGML_F32x4 __m128
  #define WSP_GGML_F32x4_ZERO (__m128)__lsx_vldi(0)
- #define WSP_GGML_F32x4_SET1(x) (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+ #define WSP_GGML_F32x4_SET1(x) (__m128)__lsx_vreplfr2vr_s((x))
  #define WSP_GGML_F32x4_LOAD(x) (__m128)__lsx_vld((x), 0)
  #define WSP_GGML_F32x4_STORE(x, y) __lsx_vst(y, x, 0)
  #define WSP_GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
  #define WSP_GGML_F32x4_ADD __lsx_vfadd_s
  #define WSP_GGML_F32x4_MUL __lsx_vfmul_s
- #define WSP_GGML_F32x4_REDUCE(res, x) \
- { \
-     int offset = WSP_GGML_F32_ARR >> 1; \
-     for (int i = 0; i < offset; ++i) { \
-         x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
-     } \
-     offset >>= 1; \
-     for (int i = 0; i < offset; ++i) { \
-         x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
-     } \
-     offset >>= 1; \
-     for (int i = 0; i < offset; ++i) { \
-         x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
-     } \
-     __m128i tmp = __lsx_vsrli_d((__m128i) x[0], 32); \
-     tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, x[0]); \
-     tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
-     const __m128 t0 = (__m128)__lsx_vshuf4i_w(tmp, 0x88); \
-     tmp = __lsx_vsrli_d((__m128i) t0, 32); \
-     tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, t0); \
-     tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
-     res = (wsp_ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
+
+ #define WSP_GGML_F32x4_REDUCE(res, x) \
+ { \
+     int offset = WSP_GGML_F32_ARR >> 1; \
+     for (int i = 0; i < offset; ++i) { \
+         x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+     } \
+     offset >>= 1; \
+     for (int i = 0; i < offset; ++i) { \
+         x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+     } \
+     offset >>= 1; \
+     for (int i = 0; i < offset; ++i) { \
+         x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+     } \
+     __m128i t0 = __lsx_vpickev_w((__m128i)x[0], (__m128i)x[0]); \
+     __m128i t1 = __lsx_vpickod_w((__m128i)x[0], (__m128i)x[0]); \
+     __m128 t2 = __lsx_vfadd_s((__m128)t0, (__m128)t1); \
+     __m128i t3 = __lsx_vpickev_w((__m128i)t2, (__m128i)t2); \
+     __m128i t4 = __lsx_vpickod_w((__m128i)t2, (__m128i)t2); \
+     __m128 t5 = __lsx_vfadd_s((__m128)t3, (__m128)t4); \
+     res = (wsp_ggml_float) ((v4f32)t5)[0]; \
  }
 
  #define WSP_GGML_F32_VEC WSP_GGML_F32x4
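Note: the old and the rewritten WSP_GGML_F32x4_REDUCE compute the same quantity; they differ only in how the final horizontal add of the four lanes of x[0] is done (the new version pairs even/odd lanes with vpickev/vpickod). A scalar model of the intended result, not LSX code:

    #include <cassert>

    // Scalar model of WSP_GGML_F32x4_REDUCE: pairwise-fold an array of 4-lane
    // vectors, then horizontally add the lanes of vec[0].
    static float reduce_f32x4(float vec[][4], int nvec) {
        for (int offset = nvec >> 1; offset > 0; offset >>= 1) {
            for (int i = 0; i < offset; ++i) {
                for (int lane = 0; lane < 4; ++lane) {
                    vec[i][lane] += vec[offset + i][lane];   // x[i] = vfadd(x[i], x[offset+i])
                }
            }
        }
        // even/odd lane adds (vpickev/vpickod in the LSX version) collapse 4 lanes to 1
        return (vec[0][0] + vec[0][2]) + (vec[0][1] + vec[0][3]);
    }

    int main() {
        float v[2][4] = {{1, 2, 3, 4}, {5, 6, 7, 8}};
        assert(reduce_f32x4(v, 2) == 36.0f);   // 1 + 2 + ... + 8
        return 0;
    }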
@@ -1068,7 +1068,7 @@ static inline void __lsx_f16x4_store(wsp_ggml_fp16_t * x, __m128 y) {
 
  #define WSP_GGML_F32Cx4 __m128
  #define WSP_GGML_F32Cx4_ZERO (__m128)__lsx_vldi(0)
- #define WSP_GGML_F32Cx4_SET1(x) (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+ #define WSP_GGML_F32Cx4_SET1(x) (__m128)__lsx_vreplfr2vr_s((x))
  #define WSP_GGML_F32Cx4_LOAD(x) (__m128)__lsx_f16x4_load(x)
  #define WSP_GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
  #define WSP_GGML_F32Cx4_FMA WSP_GGML_F32x4_FMA
package/cpp/ggml-cpu/unary-ops.cpp
@@ -52,6 +52,15 @@ static inline float op_sqrt(float x) {
      return sqrtf(x);
  }
 
+ static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) {
+     if (x > 0.0f) {
+         return alpha_p * x * x + beta * x;
+     } else {
+         const float min_x_eps = fminf(x, eps);
+         return (expm1f(min_x_eps) - x) * alpha_n + beta * x;
+     }
+ }
+
  static inline float op_sin(float x) {
      return sinf(x);
  }
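Note: reading op_xielu above, the xIELU activation splits at zero: for x > 0 it evaluates alpha_p*x*x + beta*x, and for x <= 0 it evaluates alpha_n*(expm1(min(x, eps)) - x) + beta*x. A scalar restatement with a couple of sanity checks (my reading of the code above, not an authoritative definition of the op):

    #include <cassert>
    #include <cmath>

    // Scalar restatement of op_xielu from the diff above.
    static float xielu(float x, float alpha_n, float alpha_p, float beta, float eps) {
        if (x > 0.0f) {
            return alpha_p * x * x + beta * x;          // quadratic positive branch
        }
        const float m = fminf(x, eps);
        return (expm1f(m) - x) * alpha_n + beta * x;    // saturating negative branch
    }

    int main() {
        // Continuity check at 0: with these parameters both branches give 0.
        assert(xielu(0.0f, 0.8f, 0.8f, 0.5f, 0.0f) == 0.0f);
        // Positive branch: 0.8 * 2^2 + 0.5 * 2 = 4.2.
        assert(fabsf(xielu(2.0f, 0.8f, 0.8f, 0.5f, 0.0f) - 4.2f) < 1e-6f);
        return 0;
    }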
@@ -64,6 +73,30 @@ static inline float op_log(float x) {
      return logf(x);
  }
 
+ static inline float op_expm1(float x) {
+     return expf(x) - 1.0f;
+ }
+
+ static inline float op_softplus(float x) {
+     return (x > 20.0f) ? x : logf(1.0f + expf(x));
+ }
+
+ static inline float op_floor(float x) {
+     return floorf(x);
+ }
+
+ static inline float op_ceil(float x) {
+     return ceilf(x);
+ }
+
+ static inline float op_round(float x) {
+     return roundf(x);
+ }
+
+ static inline float op_trunc(float x) {
+     return truncf(x);
+ }
+
  template <float (*op)(float), typename src0_t, typename dst_t>
  static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
      constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
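Note: op_softplus above guards the naive formula with a cutoff: for x > 20, exp(x) is already so large that 1.0f + expf(x) rounds to expf(x) in float, so log(1 + exp(x)) and x agree to float precision, and returning x directly also avoids overflowing expf() for very large inputs. A quick check of that reasoning:

    #include <cassert>
    #include <cmath>

    // Same shape as op_softplus in the diff above.
    static float softplus(float x) {
        return (x > 20.0f) ? x : logf(1.0f + expf(x));
    }

    int main() {
        // Near the cutoff the exact formula already equals x to float precision:
        // log(1 + exp(20)) = 20 + log(1 + exp(-20)) ~ 20 + 2.1e-9.
        assert(fabsf(logf(1.0f + expf(19.9f)) - 19.9f) < 1e-3f);
        // Far above the cutoff the branch simply returns x, so expf() never overflows.
        assert(softplus(90.0f) == 90.0f);
        return 0;
    }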
@@ -121,6 +154,86 @@ static void unary_op(const wsp_ggml_compute_params * params, wsp_ggml_tensor * d
      }
  }
 
+ template <float (*op)(float, wsp_ggml_tensor *)>
+ static void unary_op_params(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+     const wsp_ggml_tensor * src0 = dst->src[0];
+
+     /* */ if (src0->type == WSP_GGML_TYPE_F32 && dst->type == WSP_GGML_TYPE_F32) { // all f32
+         apply_unary_op<op, float, float>(params, dst);
+     } else if (src0->type == WSP_GGML_TYPE_F16 && dst->type == WSP_GGML_TYPE_F16) { // all f16
+         apply_unary_op<op, wsp_ggml_fp16_t, wsp_ggml_fp16_t>(params, dst);
+     } else if (src0->type == WSP_GGML_TYPE_BF16 && dst->type == WSP_GGML_TYPE_BF16) { // all bf16
+         apply_unary_op<op, wsp_ggml_bf16_t, wsp_ggml_bf16_t>(params, dst);
+     } else if (src0->type == WSP_GGML_TYPE_BF16 && dst->type == WSP_GGML_TYPE_F32) {
+         apply_unary_op<op, wsp_ggml_bf16_t, float>(params, dst);
+     } else if (src0->type == WSP_GGML_TYPE_F16 && dst->type == WSP_GGML_TYPE_F32) {
+         apply_unary_op<op, wsp_ggml_fp16_t, float>(params, dst);
+     } else {
+         fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
+             wsp_ggml_type_name(dst->type), wsp_ggml_type_name(src0->type));
+         WSP_GGML_ABORT("fatal error");
+     }
+ }
+
+ // Extend vec_unary_op to support functors
+ template <typename Op, typename src0_t, typename dst_t>
+ static inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) {
+     constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
+     constexpr auto f32_to_dst  = type_conversion_table<dst_t >::from_f32;
+
+     for (int i = 0; i < n; i++) {
+         y[i] = f32_to_dst(op(src0_to_f32(x[i])));
+     }
+ }
+
+ // Extend apply_unary_op to support functors
+ template <typename Op, typename src0_t, typename dst_t>
+ static void apply_unary_op_functor(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst, Op op) {
+     const wsp_ggml_tensor * src0 = dst->src[0];
+
+     WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0) && wsp_ggml_is_contiguous_1(dst) && wsp_ggml_are_same_shape(src0, dst));
+
+     WSP_GGML_TENSOR_UNARY_OP_LOCALS
+
+     WSP_GGML_ASSERT( nb0 == sizeof(dst_t));
+     WSP_GGML_ASSERT(nb00 == sizeof(src0_t));
+
+     const auto [ir0, ir1] = get_thread_range(params, src0);
+
+     for (int64_t ir = ir0; ir < ir1; ++ir) {
+         const int64_t i03 = ir/(ne02*ne01);
+         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+         const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+         dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+         const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+         vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op);
+     }
+ }
+
+ // Generic dispatcher for functors
+ template <typename Op>
+ static void unary_op_functor(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst, Op op) {
+     const wsp_ggml_tensor * src0 = dst->src[0];
+
+     /* */ if (src0->type == WSP_GGML_TYPE_F32 && dst->type == WSP_GGML_TYPE_F32) { // all f32
+         apply_unary_op_functor<Op, float, float>(params, dst, op);
+     } else if (src0->type == WSP_GGML_TYPE_F16 && dst->type == WSP_GGML_TYPE_F16) { // all f16
+         apply_unary_op_functor<Op, wsp_ggml_fp16_t, wsp_ggml_fp16_t>(params, dst, op);
+     } else if (src0->type == WSP_GGML_TYPE_BF16 && dst->type == WSP_GGML_TYPE_BF16) { // all bf16
+         apply_unary_op_functor<Op, wsp_ggml_bf16_t, wsp_ggml_bf16_t>(params, dst, op);
+     } else if (src0->type == WSP_GGML_TYPE_BF16 && dst->type == WSP_GGML_TYPE_F32) {
+         apply_unary_op_functor<Op, wsp_ggml_bf16_t, float>(params, dst, op);
+     } else if (src0->type == WSP_GGML_TYPE_F16 && dst->type == WSP_GGML_TYPE_F32) {
+         apply_unary_op_functor<Op, wsp_ggml_fp16_t, float>(params, dst, op);
+     } else {
+         fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
+             wsp_ggml_type_name(dst->type), wsp_ggml_type_name(src0->type));
+         WSP_GGML_ABORT("fatal error");
+     }
+ }
+
  void wsp_ggml_compute_forward_abs(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
      unary_op<op_abs>(params, dst);
  }
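Note: the *_functor variants above duplicate the existing unary-op plumbing with a generic callable parameter: unary_op<op> takes a stateless function pointer as a template argument, while unary_op_functor accepts any callable, including a capturing lambda, which is what lets parameterized activations (see wsp_ggml_compute_forward_xielu in the next hunk) reuse the same type dispatch. A self-contained sketch of just that pattern, with made-up names that are not part of the diff:

    #include <cassert>
    #include <cstdint>

    // The pattern behind vec_unary_op vs. vec_unary_op_functor above:
    // a function-pointer template parameter cannot carry runtime state,
    // so parameterized ops are passed as a callable object instead.
    template <float (*op)(float)>
    void vec_unary_op(int64_t n, float * y, const float * x) {
        for (int64_t i = 0; i < n; ++i) y[i] = op(x[i]);
    }

    template <typename Op>
    void vec_unary_op_functor(int64_t n, float * y, const float * x, Op op) {
        for (int64_t i = 0; i < n; ++i) y[i] = op(x[i]);
    }

    float op_neg(float x) { return -x; }

    int main() {
        const float x[3] = {1.0f, 2.0f, 3.0f};
        float y[3];

        vec_unary_op<op_neg>(3, y, x);                 // stateless op as template parameter
        assert(y[0] == -1.0f);

        const float scale = 0.5f;                      // runtime parameter, like xIELU's alpha/beta/eps
        vec_unary_op_functor(3, y, x, [scale](float v) { return scale * v; });
        assert(y[2] == 1.5f);
        return 0;
    }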
@@ -184,3 +297,41 @@ void wsp_ggml_compute_forward_cos(const wsp_ggml_compute_params * params, wsp_gg
  void wsp_ggml_compute_forward_log(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
      unary_op<op_log>(params, dst);
  }
+
+ void wsp_ggml_compute_forward_expm1(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+     unary_op<op_expm1>(params, dst);
+ }
+
+ void wsp_ggml_compute_forward_softplus(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+     unary_op<op_softplus>(params, dst);
+ }
+
+ void wsp_ggml_compute_forward_floor(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+     unary_op<op_floor>(params, dst);
+ }
+
+ void wsp_ggml_compute_forward_ceil(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+     unary_op<op_ceil>(params, dst);
+ }
+
+ void wsp_ggml_compute_forward_round(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+     unary_op<op_round>(params, dst);
+ }
+
+ void wsp_ggml_compute_forward_trunc(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+     unary_op<op_trunc>(params, dst);
+ }
+
+ void wsp_ggml_compute_forward_xielu(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+     const float alpha_n = wsp_ggml_get_op_params_f32(dst, 1);
+     const float alpha_p = wsp_ggml_get_op_params_f32(dst, 2);
+     const float beta    = wsp_ggml_get_op_params_f32(dst, 3);
+     const float eps     = wsp_ggml_get_op_params_f32(dst, 4);
+
+     const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) {
+         return op_xielu(f, alpha_n, alpha_p, beta, eps);
+     };
+
+     unary_op_functor(params, dst, xielu_op_params);
+ }
+
package/cpp/ggml-cpu/unary-ops.h
@@ -22,6 +22,13 @@ void wsp_ggml_compute_forward_sqrt(const struct wsp_ggml_compute_params * params
  void wsp_ggml_compute_forward_sin(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_cos(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_log(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_expm1(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_softplus(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_floor(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_ceil(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_round(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_trunc(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_xielu(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
 
  #ifdef __cplusplus
  }
package/cpp/ggml-cpu/vec.cpp
@@ -360,6 +360,13 @@ void wsp_ggml_vec_silu_f32(const int n, float * y, const float * x) {
      for (; i + 3 < n; i += 4) {
          vst1q_f32(y + i, wsp_ggml_v_silu(vld1q_f32(x + i)));
      }
+ #elif defined(__riscv_v_intrinsic)
+     for (int vl; i < n; i += vl) {
+         vl = __riscv_vsetvl_e32m2(n - i);
+         vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
+         vfloat32m2_t vy = wsp_ggml_v_silu_m2(vx, vl);
+         __riscv_vse32_v_f32m2(&y[i], vy, vl);
+     }
  #endif
      for (; i < n; ++i) {
          y[i] = wsp_ggml_silu_f32(x[i]);
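Note: the new __riscv_v_intrinsic branch above follows the standard RVV strip-mining idiom: __riscv_vsetvl_e32m2(n - i) returns how many elements this pass will handle, the loop advances by that amount, and the final shorter pass absorbs the tail. A scalar model of that control flow (vlmax below is a made-up stand-in for the hardware vector length, and the body is a placeholder for the load/silu/store):

    #include <algorithm>
    #include <cassert>

    // Scalar model of the RVV strip-mining loop in wsp_ggml_vec_silu_f32:
    // each iteration handles up to vlmax elements; the last pass shrinks.
    int main() {
        const int n = 10, vlmax = 4;       // vlmax stands in for the hardware vector length
        int processed = 0;
        for (int vl, i = 0; i < n; i += vl) {
            vl = std::min(vlmax, n - i);   // models __riscv_vsetvl_e32m2(n - i)
            processed += vl;               // load / silu / store of vl elements goes here
        }
        assert(processed == n);            // 4 + 4 + 2, no scalar tail left over
        return 0;
    }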
@@ -404,6 +411,82 @@ void wsp_ggml_vec_swiglu_f32(const int n, float * y, const float * x, const floa
      }
  }
 
+ wsp_ggml_float wsp_ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) {
+     int i = 0;
+     wsp_ggml_float sum = 0;
+ // TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE
+ // ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344
+ #if defined(__AVX512F__) && defined(__AVX512DQ__)
+     for (; i + 15 < n; i += 16) {
+         __m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i),
+                                    _mm512_set1_ps(mean));
+         _mm512_storeu_ps(y + i, val);
+         sum += (wsp_ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val));
+     }
+ #elif defined(__AVX2__) && defined(__FMA__)
+     for (; i + 7 < n; i += 8) {
+         __m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i),
+                                    _mm256_set1_ps(mean));
+         _mm256_storeu_ps(y + i, val);
+         val = _mm256_mul_ps(val,val);
+         __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
+                                  _mm256_castps256_ps128(val));
+         val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
+         val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
+         sum += (wsp_ggml_float)_mm_cvtss_f32(val2);
+     }
+ #elif defined(__SSE2__)
+     for (; i + 3 < n; i += 4) {
+         __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i),
+                                 _mm_set1_ps(mean));
+         _mm_storeu_ps(y + i, val);
+         val = _mm_mul_ps(val, val);
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+         val = _mm_add_ps(val, _mm_movehl_ps(val, val));
+         val = _mm_add_ss(val, _mm_movehdup_ps(val));
+ #else
+         __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
+         val = _mm_add_ps(val, tmp);
+         tmp = _mm_movehl_ps(tmp, val);
+         val = _mm_add_ss(val, tmp);
+ #endif // __AVX__ || __AVX2__ || __AVX512F__
+         sum += (wsp_ggml_float)_mm_cvtss_f32(val);
+     }
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
+     for (; i + 3 < n; i += 4) {
+         float32x4_t val = vsubq_f32(vld1q_f32(x + i),
+                                     vdupq_n_f32(mean));
+         vst1q_f32(y + i, val);
+         val = vmulq_f32(val, val);
+         sum += (wsp_ggml_float)vaddvq_f32(val);
+     }
+ #elif defined(__VXE__) || defined(__VXE2__)
+     for (; i + 3 < n; i += 4) {
+         float32x4_t val = vec_sub(vec_xl(0, x + i), vec_splats(mean));
+         vec_xst(val, 0, y + i);
+         val = vec_mul(val, val);
+         sum += (wsp_ggml_float)vec_hsum_f32x4(val);
+     }
+ #elif defined(__riscv_v_intrinsic)
+     vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
+     for (int vl; i < n; i += vl) {
+         vl = __riscv_vsetvl_e32m2(n - i);
+         vfloat32m2_t val = __riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], vl), mean, vl);
+         __riscv_vse32_v_f32m2(&y[i], val, vl);
+         val = __riscv_vfmul_vv_f32m2(val, val, vl);
+         vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, vl);
+     }
+     sum = (wsp_ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
+ #endif
+     for (; i < n; ++i) {
+         float val = x[i] - mean;
+         y[i] = val;
+         val *= val;
+         sum += (wsp_ggml_float)val;
+     }
+     return sum/n;
+ }
+
  wsp_ggml_float wsp_ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
      int i = 0;
      wsp_ggml_float sum = 0;
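Note: wsp_ggml_vec_cvar_f32 above fuses two steps: it stores the centered values y[i] = x[i] - mean and returns sum((x[i] - mean)^2) / n, i.e. the population variance for a precomputed mean; the SIMD branches only accelerate that computation. A scalar reference matching the plain tail loop:

    #include <cassert>
    #include <cmath>

    // Scalar reference for wsp_ggml_vec_cvar_f32: center x around `mean` into y
    // and return sum((x - mean)^2) / n, matching the fallback loop in the diff.
    static double vec_cvar_ref(int n, float * y, const float * x, float mean) {
        double sum = 0.0;
        for (int i = 0; i < n; ++i) {
            const float val = x[i] - mean;
            y[i] = val;
            sum += (double) val * val;
        }
        return sum / n;
    }

    int main() {
        const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
        float y[4];
        const double var = vec_cvar_ref(4, y, x, 2.5f);   // mean of x is 2.5
        assert(std::fabs(var - 1.25) < 1e-12);            // ((-1.5)^2 + (-0.5)^2 + 0.5^2 + 1.5^2) / 4
        assert(y[0] == -1.5f && y[3] == 1.5f);
        return 0;
    }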