whisper.rn 0.5.2 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. package/cpp/ggml-alloc.c +11 -4
  2. package/cpp/ggml-backend-reg.cpp +8 -0
  3. package/cpp/ggml-backend.cpp +0 -2
  4. package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
  5. package/cpp/ggml-cpu/ggml-cpu-impl.h +3 -1
  6. package/cpp/ggml-cpu/ggml-cpu.c +50 -21
  7. package/cpp/ggml-cpu/ops.cpp +458 -349
  8. package/cpp/ggml-cpu/ops.h +4 -4
  9. package/cpp/ggml-cpu/repack.cpp +143 -29
  10. package/cpp/ggml-cpu/simd-mappings.h +25 -25
  11. package/cpp/ggml-cpu/unary-ops.cpp +16 -0
  12. package/cpp/ggml-cpu/unary-ops.h +2 -0
  13. package/cpp/ggml-cpu/vec.cpp +17 -0
  14. package/cpp/ggml-cpu/vec.h +10 -0
  15. package/cpp/ggml-impl.h +17 -1
  16. package/cpp/ggml-metal/ggml-metal-context.m +5 -6
  17. package/cpp/ggml-metal/ggml-metal-device.cpp +101 -4
  18. package/cpp/ggml-metal/ggml-metal-device.h +8 -1
  19. package/cpp/ggml-metal/ggml-metal-device.m +216 -14
  20. package/cpp/ggml-metal/ggml-metal-impl.h +90 -2
  21. package/cpp/ggml-metal/ggml-metal-ops.cpp +346 -85
  22. package/cpp/ggml-metal/ggml-metal-ops.h +2 -0
  23. package/cpp/ggml-metal/ggml-metal.cpp +5 -0
  24. package/cpp/ggml-metal/ggml-metal.metal +12436 -0
  25. package/cpp/ggml.c +154 -5
  26. package/cpp/ggml.h +73 -0
  27. package/cpp/whisper.cpp +5 -1
  28. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  29. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +73 -0
  30. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  31. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  32. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  33. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  34. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +73 -0
  35. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  36. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  37. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  38. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  39. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  40. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +73 -0
  41. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  42. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  43. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  44. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  45. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +73 -0
  46. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  47. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  48. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  49. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  50. package/package.json +1 -1
  51. package/whisper-rn.podspec +1 -1
  52. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  53. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  54. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  55. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  56. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  57. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
package/cpp/ggml-cpu/ops.h CHANGED
@@ -34,6 +34,7 @@ void wsp_ggml_compute_forward_add1(const struct wsp_ggml_compute_params * params
  void wsp_ggml_compute_forward_acc(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_sum(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_sum_rows(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_cumsum(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_mean(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_argmax(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_count_equal(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
@@ -51,10 +52,6 @@ void wsp_ggml_compute_forward_scale(const struct wsp_ggml_compute_params * param
  void wsp_ggml_compute_forward_set(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_cpy(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_cont(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
- void wsp_ggml_compute_forward_reshape(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
- void wsp_ggml_compute_forward_view(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
- void wsp_ggml_compute_forward_permute(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
- void wsp_ggml_compute_forward_transpose(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_get_rows(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_get_rows_back(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_set_rows(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
@@ -85,6 +82,8 @@ void wsp_ggml_compute_forward_arange(const struct wsp_ggml_compute_params * para
  void wsp_ggml_compute_forward_timestep_embedding(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_argsort(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_leaky_relu(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_tri(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_fill(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_flash_attn_ext(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_flash_attn_back(
  const struct wsp_ggml_compute_params * params,
@@ -100,6 +99,7 @@ void wsp_ggml_compute_forward_get_rel_pos(const struct wsp_ggml_compute_params *
  void wsp_ggml_compute_forward_add_rel_pos(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_rwkv_wkv6(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_rwkv_wkv7(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_solve_tri(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_gla(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_map_custom1(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_map_custom2(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
package/cpp/ggml-cpu/repack.cpp CHANGED
@@ -1600,6 +1600,55 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, wsp_ggml_type
  return false;
  }

+ void forward_mul_mat_one_chunk(wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * op,
+ int64_t src0_start,
+ int64_t src0_end,
+ int64_t src1_start,
+ int64_t src1_end) {
+ const wsp_ggml_tensor * src0 = op->src[0];
+ const wsp_ggml_tensor * src1 = op->src[1];
+ wsp_ggml_tensor * dst = op;
+
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS
+
+ const size_t src1_col_stride = wsp_ggml_row_size(PARAM_TYPE, ne10);
+
+ WSP_GGML_ASSERT(ne03 == 1 && ne13 == 1);
+ WSP_GGML_ASSERT(ne12 % ne02 == 0);
+ const int64_t r2 = ne12 / ne02;
+
+ const int64_t i12 = src1_start / ne1;
+ const int64_t i11 = src1_start - i12 * ne1;
+
+ // Determine batch index
+ const int64_t i02 = i12 / r2;
+
+ const int64_t i1 = i11;
+ const int64_t i2 = i12;
+
+ const char * src0_ptr = (const char *) src0->data + i02 * nb02;
+ const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride;
+ char * dst_ptr = ((char *) dst->data + (i1 * nb1 + i2 * nb2));
+
+ const int64_t nrows = src1_end - src1_start;
+ const int64_t ncols = src0_end - src0_start;
+
+ WSP_GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize);
+
+ // If there are more than three rows in src1, use gemm; otherwise, use gemv.
+ if (nrows > 3) {
+ gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,
+ src0_ptr + src0_start * nb01, src1_ptr,
+ nrows - (nrows % 4), ncols);
+ }
+ for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {
+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start,
+ ne01, src0_ptr + src0_start * nb01,
+ src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols);
+ }
+ }
+
  void forward_mul_mat(wsp_ggml_compute_params * params, wsp_ggml_tensor * op) {
  const wsp_ggml_tensor * src0 = op->src[0];
  const wsp_ggml_tensor * src1 = op->src[1];
@@ -1621,6 +1670,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, wsp_ggml_type
  WSP_GGML_ASSERT(nb1 <= nb2);
  WSP_GGML_ASSERT(nb2 <= nb3);

+ // TODO: General batched mul mat for 4D tensors
+ // Currently only supports 3D tensors
+ WSP_GGML_ASSERT(ne03 == 1);
+ WSP_GGML_ASSERT(ne13 == 1);
+ WSP_GGML_ASSERT(ne3 == 1);
+
  WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);

  WSP_GGML_ASSERT(wsp_ggml_n_dims(op->src[0]) == 2);
@@ -1628,46 +1683,101 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, wsp_ggml_type

  char * wdata = static_cast<char *>(params->wdata);
  const size_t nbw1 = wsp_ggml_row_size(PARAM_TYPE, ne10);
+ const size_t nbw2 = nbw1 * ne11;

- assert(params->wsize >= nbw1 * ne11);
+ assert(params->wsize >= nbw2 * ne12);

  const wsp_ggml_from_float_t from_float = wsp_ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;

- int64_t i11_processed = 0;
- for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
- wsp_ggml_wsp_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
- }
+ // INFO: Quantization is done in planes to avoid extra complexity in chunking.
+ // Flattening dimensions not multiple of INTER_SIZE would require extra handling depending on how
+ // the planes are broadcast.
+ for (int64_t i12 = 0; i12 < ne12; i12++) {
+ char * data_ptr = (char *) src1->data + i12 * nb12;
+ char * wdata_ptr = wdata + i12 * nbw2;
+
+ for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
+ wsp_ggml_wsp_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
+ (void *) (wdata_ptr + i11 * nbw1), 4, ne10);
+ }

- i11_processed = ne11 - ne11 % 4;
- for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
- from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
+ const int64_t i11_processed = ne11 - ne11 % 4;
+ for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
+ from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);
+ }
  }

- wsp_ggml_barrier(params->threadpool);
+ // disable for NUMA
+ const bool disable_chunking = wsp_ggml_is_numa();

- const void * src1_wdata = params->wdata;
- const size_t src1_col_stride = wsp_ggml_row_size(PARAM_TYPE, ne10);
- int64_t src0_start = (ith * ne01) / nth;
- int64_t src0_end = ((ith + 1) * ne01) / nth;
- src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
- src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
- if (src0_start >= src0_end) {
- return;
+ // 4x chunks per thread
+ const int64_t nr0 = wsp_ggml_nrows(op->src[0]);
+
+ int nth_scaled = nth * 4;
+ int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
+ int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0;
+
+ // src1 is chunked only by full planes.
+ // When we flatten we need to address dimensions not multiple of the q8 INTER_SIZE
+ // to route them thorugh GEMV.
+ // nchunk1 = ne12 also avoids messing the chunking for models with no 3d tensors
+ // to avoid affecting their performance
+ int64_t nchunk1 = ne12;
+
+ // Ensure minimum chunk size to avoid alignment issues with high thread counts
+ // Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
+ const int64_t min_chunk_size = NB_COLS;
+ if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
+ nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
  }

- // If there are more than three rows in src1, use gemm; otherwise, use gemv.
- if (ne11 > 3) {
- gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
- (float *) ((char *) dst->data) + src0_start, ne01,
- (const char *) src0->data + src0_start * nb01,
- (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+ if (nth == 1 || nchunk0 < nth || disable_chunking) {
+ nchunk0 = nth;
  }
- for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
- gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
- (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
- (const char *) src0->data + src0_start * nb01,
- (const char *) src1_wdata + (src1_col_stride * iter), 1,
- src0_end - src0_start);
+
+ const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+
+ // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
+ // This prevents creating too many tiny chunks that could overlap after alignment
+ const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
+ nchunk0 = MIN(nchunk0, max_nchunk);
+
+ if (ith == 0) {
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
+ wsp_ggml_threadpool_chunk_set(params->threadpool, nth);
+ }
+
+ wsp_ggml_barrier(params->threadpool);
+
+ // The first chunk comes from our thread_id, the rest will get auto-assigned.
+ int current_chunk = ith;
+
+ while (current_chunk < nchunk0 * nchunk1) {
+ const int64_t ith0 = current_chunk % nchunk0;
+ const int64_t ith1 = current_chunk / nchunk0;
+
+ int64_t src0_start = dr0 * ith0;
+ int64_t src0_end = MIN(src0_start + dr0, nr0);
+
+ // full-plane range for src1
+ int64_t src1_start = ith1 * ne11;
+ int64_t src1_end = (ith1 + 1) * ne11;
+
+ // Align boundaries to NB_COLS - round up to ensure all data is included
+ // The chunk size limiting above ensures chunks are large enough to prevent overlaps
+ src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
+ src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
+ src0_end = MIN(src0_end, ne01);
+
+ // Make sure current plane is the last one before exiting
+ if (src0_start >= src0_end) {
+ current_chunk = wsp_ggml_threadpool_chunk_add(params->threadpool, 1);
+ continue;
+ }
+
+ forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);
+
+ current_chunk = wsp_ggml_threadpool_chunk_add(params->threadpool, 1);
  }
  }

@@ -1772,8 +1882,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, wsp_ggml_type
  int64_t src0_cur_start = (ith * ne01) / nth;
  int64_t src0_cur_end = ((ith + 1) * ne01) / nth;

+ // Align boundaries to NB_COLS - round up to ensure all data is included
  src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
  src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
+ if (src0_cur_end > ne01) {
+ src0_cur_end = ne01;
+ }

  if (src0_cur_start >= src0_cur_end) {
  return;
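
Note on the new scheduling in repack.cpp: the old code split src0 rows statically per thread, while the new code carves the (src0 row chunks) x (src1 planes) space into nchunk0 * nchunk1 units and lets threads claim them through the shared threadpool counter (wsp_ggml_threadpool_chunk_set / wsp_ggml_threadpool_chunk_add). Below is a minimal standalone sketch of that claim-a-chunk pattern, using std::atomic in place of the threadpool helpers; run_chunks and process_chunk are hypothetical names, not code from the package.

#include <atomic>
#include <cstdint>
#include <thread>
#include <vector>

// Hypothetical stand-in for processing one (src0 rows) x (src1 plane) chunk.
static void process_chunk(int64_t /*chunk*/) { /* gemm/gemv over that range */ }

// Claim-a-chunk loop: thread ith starts on chunk ith, then pulls the next
// unclaimed chunk from a shared counter that starts at nth (mirroring
// wsp_ggml_threadpool_chunk_set(tp, nth) and wsp_ggml_threadpool_chunk_add(tp, 1)).
static void run_chunks(int nth, int64_t nchunk0, int64_t nchunk1) {
    std::atomic<int64_t> next_chunk{nth};
    auto worker = [&](int ith) {
        int64_t current = ith;
        while (current < nchunk0 * nchunk1) {
            process_chunk(current);
            current = next_chunk.fetch_add(1); // returns the previous value, i.e. the claimed chunk
        }
    };
    std::vector<std::thread> threads;
    for (int ith = 0; ith < nth; ++ith) {
        threads.emplace_back(worker, ith);
    }
    for (auto & t : threads) {
        t.join();
    }
}

Each chunk is processed exactly once, and faster threads simply claim more chunks, which is what makes the 4x-chunks-per-thread split useful on unevenly loaded cores.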
package/cpp/ggml-cpu/simd-mappings.h CHANGED
@@ -956,7 +956,7 @@ do { \

  #define WSP_GGML_F32Cx8 __m256
  #define WSP_GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
- #define WSP_GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
+ #define WSP_GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))

  static inline __m256 __lasx_f32cx8_load(const wsp_ggml_fp16_t * x) {
  __m256i a;
@@ -999,34 +999,34 @@ static inline void __lasx_f32cx8_store(wsp_ggml_fp16_t * x, __m256 y) {

  #define WSP_GGML_F32x4 __m128
  #define WSP_GGML_F32x4_ZERO (__m128)__lsx_vldi(0)
- #define WSP_GGML_F32x4_SET1(x) (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+ #define WSP_GGML_F32x4_SET1(x) (__m128)__lsx_vreplfr2vr_s((x))
  #define WSP_GGML_F32x4_LOAD(x) (__m128)__lsx_vld((x), 0)
  #define WSP_GGML_F32x4_STORE(x, y) __lsx_vst(y, x, 0)
  #define WSP_GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
  #define WSP_GGML_F32x4_ADD __lsx_vfadd_s
  #define WSP_GGML_F32x4_MUL __lsx_vfmul_s
- #define WSP_GGML_F32x4_REDUCE(res, x) \
- { \
- int offset = WSP_GGML_F32_ARR >> 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
- } \
- __m128i tmp = __lsx_vsrli_d((__m128i) x[0], 32); \
- tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, x[0]); \
- tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
- const __m128 t0 = (__m128)__lsx_vshuf4i_w(tmp, 0x88); \
- tmp = __lsx_vsrli_d((__m128i) t0, 32); \
- tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, t0); \
- tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
- res = (wsp_ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
+
+ #define WSP_GGML_F32x4_REDUCE(res, x) \
+ { \
+ int offset = WSP_GGML_F32_ARR >> 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+ } \
+ __m128i t0 = __lsx_vpickev_w((__m128i)x[0], (__m128i)x[0]); \
+ __m128i t1 = __lsx_vpickod_w((__m128i)x[0], (__m128i)x[0]); \
+ __m128 t2 = __lsx_vfadd_s((__m128)t0, (__m128)t1); \
+ __m128i t3 = __lsx_vpickev_w((__m128i)t2, (__m128i)t2); \
+ __m128i t4 = __lsx_vpickod_w((__m128i)t2, (__m128i)t2); \
+ __m128 t5 = __lsx_vfadd_s((__m128)t3, (__m128)t4); \
+ res = (wsp_ggml_float) ((v4f32)t5)[0]; \
  }

  #define WSP_GGML_F32_VEC WSP_GGML_F32x4
@@ -1068,7 +1068,7 @@ static inline void __lsx_f16x4_store(wsp_ggml_fp16_t * x, __m128 y) {

  #define WSP_GGML_F32Cx4 __m128
  #define WSP_GGML_F32Cx4_ZERO (__m128)__lsx_vldi(0)
- #define WSP_GGML_F32Cx4_SET1(x) (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+ #define WSP_GGML_F32Cx4_SET1(x) (__m128)__lsx_vreplfr2vr_s((x))
  #define WSP_GGML_F32Cx4_LOAD(x) (__m128)__lsx_f16x4_load(x)
  #define WSP_GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
  #define WSP_GGML_F32Cx4_FMA WSP_GGML_F32x4_FMA
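
The reworked WSP_GGML_F32x4_REDUCE still folds the WSP_GGML_F32_ARR accumulators pairwise and then horizontally adds the four lanes of x[0]; only the lane sum changed, now using __lsx_vpickev_w/__lsx_vpickod_w shuffles plus __lsx_vfadd_s instead of the old shift/shuffle chain. A scalar model of the same reduction, assuming a power-of-two accumulator count (reduce_f32x4 is illustrative only, not code from the package):

#include <cstddef>

// Scalar model of the reduction: pairwise-fold nacc four-lane accumulators
// into x[0], then sum the four lanes of x[0].
static float reduce_f32x4(float x[][4], std::size_t nacc) {
    for (std::size_t offset = nacc >> 1; offset > 0; offset >>= 1) {
        for (std::size_t i = 0; i < offset; ++i) {
            for (int lane = 0; lane < 4; ++lane) {
                x[i][lane] += x[i + offset][lane];
            }
        }
    }
    return x[0][0] + x[0][1] + x[0][2] + x[0][3];
}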
package/cpp/ggml-cpu/unary-ops.cpp CHANGED
@@ -73,6 +73,14 @@ static inline float op_log(float x) {
  return logf(x);
  }

+ static inline float op_expm1(float x) {
+ return expf(x) - 1.0f;
+ }
+
+ static inline float op_softplus(float x) {
+ return (x > 20.0f) ? x : logf(1.0f + expf(x));
+ }
+
  static inline float op_floor(float x) {
  return floorf(x);
  }
@@ -290,6 +298,14 @@ void wsp_ggml_compute_forward_log(const wsp_ggml_compute_params * params, wsp_gg
  unary_op<op_log>(params, dst);
  }

+ void wsp_ggml_compute_forward_expm1(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+ unary_op<op_expm1>(params, dst);
+ }
+
+ void wsp_ggml_compute_forward_softplus(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+ unary_op<op_softplus>(params, dst);
+ }
+
  void wsp_ggml_compute_forward_floor(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
  unary_op<op_floor>(params, dst);
  }
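
The 20.0f cutoff in op_softplus avoids overflowing expf for large inputs, where log(1 + exp(x)) is numerically equal to x in single precision anyway. A small standalone check of that behavior (softplus_ref mirrors the op but uses std::log1p for a slightly more accurate small-x path; it is not code from the package):

#include <cmath>
#include <cstdio>
#include <initializer_list>

// Past the cutoff the identity log(1 + exp(x)) ~= x is exact in float,
// and evaluating exp(x) directly would overflow.
static float softplus_ref(float x) {
    return (x > 20.0f) ? x : std::log1p(std::exp(x));
}

int main() {
    for (float x : {-10.0f, 0.0f, 10.0f, 30.0f}) {
        std::printf("softplus(%.1f) = %g\n", x, softplus_ref(x));
    }
    return 0;
}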
package/cpp/ggml-cpu/unary-ops.h CHANGED
@@ -22,6 +22,8 @@ void wsp_ggml_compute_forward_sqrt(const struct wsp_ggml_compute_params * params
  void wsp_ggml_compute_forward_sin(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_cos(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_log(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_expm1(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_softplus(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_floor(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_ceil(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_round(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
package/cpp/ggml-cpu/vec.cpp CHANGED
@@ -360,6 +360,13 @@ void wsp_ggml_vec_silu_f32(const int n, float * y, const float * x) {
  for (; i + 3 < n; i += 4) {
  vst1q_f32(y + i, wsp_ggml_v_silu(vld1q_f32(x + i)));
  }
+ #elif defined(__riscv_v_intrinsic)
+ for (int vl; i < n; i += vl) {
+ vl = __riscv_vsetvl_e32m2(n - i);
+ vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
+ vfloat32m2_t vy = wsp_ggml_v_silu_m2(vx, vl);
+ __riscv_vse32_v_f32m2(&y[i], vy, vl);
+ }
  #endif
  for (; i < n; ++i) {
  y[i] = wsp_ggml_silu_f32(x[i]);
@@ -460,6 +467,16 @@ wsp_ggml_float wsp_ggml_vec_cvar_f32(const int n, float * y, const float * x, co
  val = vec_mul(val, val);
  sum += (wsp_ggml_float)vec_hsum_f32x4(val);
  }
+ #elif defined(__riscv_v_intrinsic)
+ vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
+ for (int vl; i < n; i += vl) {
+ vl = __riscv_vsetvl_e32m2(n - i);
+ vfloat32m2_t val = __riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], vl), mean, vl);
+ __riscv_vse32_v_f32m2(&y[i], val, vl);
+ val = __riscv_vfmul_vv_f32m2(val, val, vl);
+ vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, vl);
+ }
+ sum = (wsp_ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
  #endif
  for (; i < n; ++i) {
  float val = x[i] - mean;
package/cpp/ggml-cpu/vec.h CHANGED
@@ -1416,6 +1416,16 @@ inline static void wsp_ggml_vec_sum_f32(const int n, float * s, const float * x)
  #endif
  }

+ inline static void wsp_ggml_vec_cumsum_f32(const int n, float * y, const float * x) {
+ for (int i = 0; i < n; ++i) {
+ if (i == 0) {
+ y[i] = x[i];
+ } else {
+ y[i] = y[i - 1] + x[i];
+ }
+ }
+ }
+
  inline static void wsp_ggml_vec_sum_f32_ggf(const int n, wsp_ggml_float * s, const float * x) {
  wsp_ggml_float sum = 0.0;
  for (int i = 0; i < n; ++i) {
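
wsp_ggml_vec_cumsum_f32 is an inclusive prefix sum: y[0] = x[0], then y[i] = y[i-1] + x[i]. An equivalent accumulator form, shown only for illustration (cumsum_f32 is not part of the package):

#include <cstdio>

// Same recurrence as wsp_ggml_vec_cumsum_f32, written with an explicit accumulator.
static void cumsum_f32(int n, float * y, const float * x) {
    float acc = 0.0f;
    for (int i = 0; i < n; ++i) {
        acc += x[i];
        y[i] = acc;
    }
}

int main() {
    const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float y[4];
    cumsum_f32(4, y, x);
    std::printf("%g %g %g %g\n", y[0], y[1], y[2], y[3]); // prints: 1 3 6 10
    return 0;
}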
package/cpp/ggml-impl.h CHANGED
@@ -102,7 +102,7 @@ static bool wsp_ggml_op_is_empty(enum wsp_ggml_op op) {
  }
  }

- static inline float wsp_ggml_softplus(float input) {
+ static inline float wsp_ggml_compute_softplus_f32(float input) {
  return (input > 20.0f) ? input : logf(1 + expf(input));
  }
  //
@@ -682,6 +682,7 @@ static inline bool wsp_ggml_can_fuse_subgraph(const struct wsp_ggml_cgraph * cgr
  #endif

  #ifdef __cplusplus
+ #include <array>
  #include <initializer_list>
  #include <vector>

@@ -697,6 +698,21 @@ inline bool wsp_ggml_can_fuse_subgraph(const struct wsp_ggml_cgraph * c
  return wsp_ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
  }

+ // Return true if the edges in the graph match expectations.
+ inline bool wsp_ggml_check_edges(const struct wsp_ggml_cgraph * cgraph,
+ int start_idx,
+ std::initializer_list<std::array<int, 3>> edges) {
+ for (const auto & edge : edges) {
+ int dst_node = edge[0];
+ int src_idx = edge[1];
+ int src_node = edge[2];
+ if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
+ return false;
+ }
+ }
+ return true;
+ }
+
  // expose GGUF internals for test code
  WSP_GGML_API size_t wsp_gguf_type_size(enum wsp_gguf_type type);
  WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_from_file_impl(FILE * file, struct wsp_gguf_init_params params);
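
wsp_ggml_check_edges verifies that, inside a candidate fusion window starting at start_idx, node dst_node takes node src_node as its src[src_idx]; each expected edge is a {dst_node, src_idx, src_node} triple given relative to start_idx. A self-contained model of that check using a mock node type (Node, check_edges, and the example graph below are hypothetical, used only to illustrate the triple convention):

#include <array>
#include <cstdio>
#include <initializer_list>

// Mock node: only the src[] edges matter for the check.
struct Node {
    const Node * src[4] = {nullptr, nullptr, nullptr, nullptr};
};

// Same logic as wsp_ggml_check_edges, over the mock type: every expected edge
// {dst_node, src_idx, src_node} must hold relative to start_idx.
static bool check_edges(const Node * const * nodes, int start_idx,
                        std::initializer_list<std::array<int, 3>> edges) {
    for (const auto & e : edges) {
        if (nodes[start_idx + e[0]]->src[e[1]] != nodes[start_idx + e[2]]) {
            return false;
        }
    }
    return true;
}

int main() {
    Node a;
    Node b;
    b.src[0] = &a; // node 1 consumes node 0 as its src[0]
    const Node * nodes[2] = {&a, &b};
    std::printf("edges ok: %d\n", check_edges(nodes, 0, {{1, 0, 0}}));
    return 0;
}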
package/cpp/ggml-metal/ggml-metal-context.m CHANGED
@@ -35,7 +35,6 @@ struct wsp_ggml_metal {
  // additional, inference-time compiled pipelines
  wsp_ggml_metal_pipelines_t pipelines_ext;

- bool use_bfloat;
  bool use_fusion;
  bool use_concurrency;
  bool use_graph_optimize;
@@ -121,11 +120,10 @@ wsp_ggml_metal_t wsp_ggml_metal_init(wsp_ggml_metal_device_t dev) {
  }
  }

- const struct wsp_ggml_metal_device_props * props_dev = wsp_ggml_metal_device_get_props(dev);
+ //const struct wsp_ggml_metal_device_props * props_dev = wsp_ggml_metal_device_get_props(dev);

  res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

- res->use_bfloat = props_dev->has_bfloat;
  res->use_fusion = getenv("WSP_GGML_METAL_FUSION_DISABLE") == nil;
  res->use_concurrency = getenv("WSP_GGML_METAL_CONCURRENCY_DISABLE") == nil;

@@ -147,7 +145,6 @@ wsp_ggml_metal_t wsp_ggml_metal_init(wsp_ggml_metal_device_t dev) {

  memset(res->fuse_cnt, 0, sizeof(res->fuse_cnt));

- WSP_GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, res->use_bfloat ? "true" : "false");
  WSP_GGML_LOG_INFO("%s: use fusion = %s\n", __func__, res->use_fusion ? "true" : "false");
  WSP_GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, res->use_concurrency ? "true" : "false");
  WSP_GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false");
@@ -292,7 +289,7 @@ void wsp_ggml_metal_set_tensor_async(wsp_ggml_metal_t ctx, struct wsp_ggml_tenso

  // queue the copy operation into the queue of the Metal context
  // this will be queued at the end, after any currently ongoing GPU operations
- id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+ id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
  id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];

  [encoder copyFromBuffer:buf_src
@@ -303,6 +300,7 @@ void wsp_ggml_metal_set_tensor_async(wsp_ggml_metal_t ctx, struct wsp_ggml_tenso

  [encoder endEncoding];
  [cmd_buf commit];
+ [buf_src release];

  // do not wait here for completion
  //[cmd_buf waitUntilCompleted];
@@ -333,7 +331,7 @@ void wsp_ggml_metal_get_tensor_async(wsp_ggml_metal_t ctx, const struct wsp_ggml

  // queue the copy operation into the queue of the Metal context
  // this will be queued at the end, after any currently ongoing GPU operations
- id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+ id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
  id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];

  [encoder copyFromBuffer:bid_src.metal
@@ -344,6 +342,7 @@ void wsp_ggml_metal_get_tensor_async(wsp_ggml_metal_t ctx, const struct wsp_ggml

  [encoder endEncoding];
  [cmd_buf commit];
+ [buf_dst release];

  // do not wait here for completion
  //[cmd_buf waitUntilCompleted];