whisper.rn 0.5.0-rc.9 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81) hide show
  1. package/cpp/ggml-alloc.c +1 -15
  2. package/cpp/ggml-backend-reg.cpp +17 -8
  3. package/cpp/ggml-backend.cpp +15 -22
  4. package/cpp/ggml-common.h +17 -0
  5. package/cpp/ggml-cpu/arch/arm/quants.c +132 -596
  6. package/cpp/ggml-cpu/arch/arm/repack.cpp +14 -286
  7. package/cpp/ggml-cpu/arch/x86/quants.c +184 -675
  8. package/cpp/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  9. package/cpp/ggml-cpu/arch-fallback.h +34 -0
  10. package/cpp/ggml-cpu/ggml-cpu.c +22 -1
  11. package/cpp/ggml-cpu/ggml-cpu.cpp +21 -24
  12. package/cpp/ggml-cpu/ops.cpp +870 -211
  13. package/cpp/ggml-cpu/ops.h +3 -8
  14. package/cpp/ggml-cpu/quants.c +35 -0
  15. package/cpp/ggml-cpu/quants.h +8 -0
  16. package/cpp/ggml-cpu/repack.cpp +458 -47
  17. package/cpp/ggml-cpu/repack.h +22 -0
  18. package/cpp/ggml-cpu/simd-mappings.h +1 -1
  19. package/cpp/ggml-cpu/traits.cpp +2 -2
  20. package/cpp/ggml-cpu/traits.h +1 -1
  21. package/cpp/ggml-cpu/vec.cpp +12 -9
  22. package/cpp/ggml-cpu/vec.h +107 -13
  23. package/cpp/ggml-impl.h +77 -0
  24. package/cpp/ggml-metal-impl.h +51 -12
  25. package/cpp/ggml-metal.m +610 -115
  26. package/cpp/ggml-opt.cpp +97 -41
  27. package/cpp/ggml-opt.h +25 -6
  28. package/cpp/ggml-quants.c +110 -16
  29. package/cpp/ggml-quants.h +6 -0
  30. package/cpp/ggml-whisper-sim.metallib +0 -0
  31. package/cpp/ggml-whisper.metallib +0 -0
  32. package/cpp/ggml.c +314 -88
  33. package/cpp/ggml.h +137 -11
  34. package/cpp/gguf.cpp +8 -1
  35. package/cpp/jsi/RNWhisperJSI.cpp +23 -6
  36. package/cpp/whisper.cpp +15 -6
  37. package/ios/RNWhisper.mm +6 -6
  38. package/ios/RNWhisperContext.mm +2 -0
  39. package/ios/RNWhisperVadContext.mm +2 -0
  40. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  44. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  45. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
  46. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  47. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  48. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
  49. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  53. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
  54. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  55. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  56. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
  57. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
  58. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
  59. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  61. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
  62. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  63. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  64. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
  65. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
  66. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
  67. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  68. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
  70. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  71. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  72. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +13 -0
  73. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  74. package/lib/module/realtime-transcription/RealtimeTranscriber.js +13 -0
  75. package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  76. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
  77. package/lib/typescript/realtime-transcription/types.d.ts +6 -0
  78. package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
  79. package/package.json +1 -1
  80. package/src/realtime-transcription/RealtimeTranscriber.ts +17 -0
  81. package/src/realtime-transcription/types.ts +6 -0
@@ -44,7 +44,14 @@ struct block_q4_Kx8 {
44
44
  };
45
45
 
46
46
  static_assert(sizeof(block_q4_Kx8) == sizeof(wsp_ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
47
+ struct block_q2_Kx8 {
48
+ wsp_ggml_half d[8]; // super-block scale for quantized scales
49
+ wsp_ggml_half dmin[8]; // super-block scale for quantized mins
50
+ uint8_t scales[128]; // scales and mins, quantized with 4 bits
51
+ uint8_t qs[512]; // 2-bit quants
52
+ };
47
53
 
54
+ static_assert(sizeof(block_q2_Kx8) == sizeof(wsp_ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
48
55
  struct block_q8_Kx4 {
49
56
  float d[4]; // delta
50
57
  int8_t qs[QK_K * 4]; // quants
@@ -60,6 +67,13 @@ struct block_iq4_nlx4 {
60
67
 
61
68
  static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(wsp_ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
62
69
 
70
+ struct block_iq4_nlx8 {
71
+ wsp_ggml_half d[8]; // deltas for 8 iq4_nl blocks
72
+ uint8_t qs[QK4_NL * 4]; // nibbles / quants for 8 iq4_nl blocks
73
+ };
74
+
75
+ static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(wsp_ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding");
76
+
63
77
  #if defined(__cplusplus)
64
78
  extern "C" {
65
79
  #endif
@@ -71,12 +85,16 @@ void wsp_ggml_gemv_q4_0_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs,
71
85
  void wsp_ggml_gemv_q4_0_4x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
72
86
  void wsp_ggml_gemv_q4_0_8x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
73
87
  void wsp_ggml_gemv_q4_K_8x8_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
88
+ void wsp_ggml_gemv_q2_K_8x8_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
74
89
  void wsp_ggml_gemv_iq4_nl_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
90
+ void wsp_ggml_gemv_iq4_nl_8x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
75
91
  void wsp_ggml_gemm_q4_0_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
76
92
  void wsp_ggml_gemm_q4_0_4x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
77
93
  void wsp_ggml_gemm_q4_0_8x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
78
94
  void wsp_ggml_gemm_q4_K_8x8_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
95
+ void wsp_ggml_gemm_q2_K_8x8_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
79
96
  void wsp_ggml_gemm_iq4_nl_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
97
+ void wsp_ggml_gemm_iq4_nl_8x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
80
98
 
81
99
  // Native implementations
82
100
  void wsp_ggml_wsp_quantize_mat_q8_0_4x4_generic(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT vy, int64_t k);
@@ -86,12 +104,16 @@ void wsp_ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, siz
86
104
  void wsp_ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
87
105
  void wsp_ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
88
106
  void wsp_ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
107
+ void wsp_ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
89
108
  void wsp_ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
109
+ void wsp_ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
90
110
  void wsp_ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
91
111
  void wsp_ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
92
112
  void wsp_ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
93
113
  void wsp_ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
114
+ void wsp_ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
94
115
  void wsp_ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
116
+ void wsp_ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
95
117
 
96
118
  #if defined(__cplusplus)
97
119
  } // extern "C"
@@ -189,7 +189,7 @@ inline static float wsp_ggml_lookup_fp16_to_fp32(wsp_ggml_fp16_t f) {
189
189
  #define WSP_GGML_F32xt_LOAD(...) WSP_GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
190
190
  #define WSP_GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
191
191
  #define WSP_GGML_F32xt_STORE(...) WSP_GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
192
- #define WSP_GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c)
192
+ #define WSP_GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
193
193
  #define WSP_GGML_F32xt_FMA(...) WSP_GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
194
194
  #define WSP_GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
195
195
  #define WSP_GGML_F32xt_ADD(...) WSP_GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
@@ -10,7 +10,7 @@ extra_buffer_type::~extra_buffer_type() {}
10
10
  } // namespace ggml::cpu
11
11
 
12
12
  bool wsp_ggml_cpu_extra_compute_forward(struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * op) {
13
- for (auto extra : wsp_ggml_backend_cpu_get_extra_buffers_type()) {
13
+ for (auto extra : wsp_ggml_backend_cpu_get_extra_buffer_types()) {
14
14
  if (extra && extra->context) {
15
15
  auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context;
16
16
  auto tensor_traits = buf_extra->get_tensor_traits(op);
@@ -23,7 +23,7 @@ bool wsp_ggml_cpu_extra_compute_forward(struct wsp_ggml_compute_params * params,
23
23
  }
24
24
 
25
25
  bool wsp_ggml_cpu_extra_work_size(int n_threads, const struct wsp_ggml_tensor * op, size_t * size) {
26
- for (auto extra : wsp_ggml_backend_cpu_get_extra_buffers_type()) {
26
+ for (auto extra : wsp_ggml_backend_cpu_get_extra_buffer_types()) {
27
27
  if (extra && extra->context) {
28
28
  auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context;
29
29
  auto tensor_traits = buf_extra->get_tensor_traits(op);
@@ -33,6 +33,6 @@ class extra_buffer_type {
33
33
  } // namespace ggml::cpu
34
34
 
35
35
  // implemented in ggml-cpu.cpp.
36
- std::vector<wsp_ggml_backend_buffer_type_t> & wsp_ggml_backend_cpu_get_extra_buffers_type();
36
+ std::vector<wsp_ggml_backend_buffer_type_t> & wsp_ggml_backend_cpu_get_extra_buffer_types();
37
37
 
38
38
  #endif
@@ -37,35 +37,35 @@ void wsp_ggml_vec_dot_f32(int n, float * WSP_GGML_RESTRICT s, size_t bs, const f
37
37
  for (int i = 0; i < np; i += wsp_ggml_f32_step) {
38
38
  ax1 = WSP_GGML_F32_VEC_LOAD(x + i);
39
39
  ay1 = WSP_GGML_F32_VEC_LOAD(y + i);
40
- sum1 = WSP_GGML_F32_VEC_FMA(ax1, ay1, sum1);
40
+ sum1 = WSP_GGML_F32_VEC_FMA(sum1, ax1, ay1);
41
41
 
42
42
  ax2 = WSP_GGML_F32_VEC_LOAD(x + i + 1*wsp_ggml_f32_epr);
43
43
  ay2 = WSP_GGML_F32_VEC_LOAD(y + i + 1*wsp_ggml_f32_epr);
44
- sum2 = WSP_GGML_F32_VEC_FMA(ax2, ay2, sum2);
44
+ sum2 = WSP_GGML_F32_VEC_FMA(sum2, ax2, ay2);
45
45
 
46
46
  ax3 = WSP_GGML_F32_VEC_LOAD(x + i + 2*wsp_ggml_f32_epr);
47
47
  ay3 = WSP_GGML_F32_VEC_LOAD(y + i + 2*wsp_ggml_f32_epr);
48
- sum3 = WSP_GGML_F32_VEC_FMA(ax3, ay3, sum3);
48
+ sum3 = WSP_GGML_F32_VEC_FMA(sum3, ax3, ay3);
49
49
 
50
50
  ax4 = WSP_GGML_F32_VEC_LOAD(x + i + 3*wsp_ggml_f32_epr);
51
51
  ay4 = WSP_GGML_F32_VEC_LOAD(y + i + 3*wsp_ggml_f32_epr);
52
- sum4 = WSP_GGML_F32_VEC_FMA(ax4, ay4, sum4);
52
+ sum4 = WSP_GGML_F32_VEC_FMA(sum4, ax4, ay4);
53
53
 
54
54
  ax5 = WSP_GGML_F32_VEC_LOAD(x + i + 4*wsp_ggml_f32_epr);
55
55
  ay5 = WSP_GGML_F32_VEC_LOAD(y + i + 4*wsp_ggml_f32_epr);
56
- sum5 = WSP_GGML_F32_VEC_FMA(ax5, ay5, sum5);
56
+ sum5 = WSP_GGML_F32_VEC_FMA(sum5, ax5, ay5);
57
57
 
58
58
  ax6 = WSP_GGML_F32_VEC_LOAD(x + i + 5*wsp_ggml_f32_epr);
59
59
  ay6 = WSP_GGML_F32_VEC_LOAD(y + i + 5*wsp_ggml_f32_epr);
60
- sum6 = WSP_GGML_F32_VEC_FMA(ax6, ay6, sum6);
60
+ sum6 = WSP_GGML_F32_VEC_FMA(sum6, ax6, ay6);
61
61
 
62
62
  ax7 = WSP_GGML_F32_VEC_LOAD(x + i + 6*wsp_ggml_f32_epr);
63
63
  ay7 = WSP_GGML_F32_VEC_LOAD(y + i + 6*wsp_ggml_f32_epr);
64
- sum7 = WSP_GGML_F32_VEC_FMA(ax7, ay7, sum7);
64
+ sum7 = WSP_GGML_F32_VEC_FMA(sum7, ax7, ay7);
65
65
 
66
66
  ax8 = WSP_GGML_F32_VEC_LOAD(x + i + 7*wsp_ggml_f32_epr);
67
67
  ay8 = WSP_GGML_F32_VEC_LOAD(y + i + 7*wsp_ggml_f32_epr);
68
- sum8 = WSP_GGML_F32_VEC_FMA(ax8, ay8, sum8);
68
+ sum8 = WSP_GGML_F32_VEC_FMA(sum8, ax8, ay8);
69
69
  }
70
70
  // leftovers
71
71
  // Since 8 unrolls are done in above loop, leftovers lie in range [0, wsp_ggml_f32_step] which is handled in below loop
@@ -73,7 +73,7 @@ void wsp_ggml_vec_dot_f32(int n, float * WSP_GGML_RESTRICT s, size_t bs, const f
73
73
  for (int i = np; i < np2; i += wsp_ggml_f32_epr) {
74
74
  ax1 = WSP_GGML_F32_VEC_LOAD(x + i);
75
75
  ay1 = WSP_GGML_F32_VEC_LOAD(y + i);
76
- sum1 = WSP_GGML_F32_VEC_FMA(ax1, ay1, sum1);
76
+ sum1 = WSP_GGML_F32_VEC_FMA(sum1, ax1, ay1);
77
77
  }
78
78
  // maximum number of leftover elements will be less that wsp_ggml_f32_epr. Apply predicated svmad on available elements only
79
79
  if (np2 < n) {
@@ -221,6 +221,9 @@ void wsp_ggml_vec_dot_f16(int n, float * WSP_GGML_RESTRICT s, size_t bs, wsp_ggm
221
221
  for (int i = np; i < n; ++i) {
222
222
  sumf += (wsp_ggml_float)(WSP_GGML_CPU_FP16_TO_FP32(x[i])*WSP_GGML_CPU_FP16_TO_FP32(y[i]));
223
223
  }
224
+
225
+ // if you hit this, you are likely running outside the FP range
226
+ assert(!isnan(sumf) && !isinf(sumf));
224
227
  #else
225
228
  for (int i = 0; i < n; ++i) {
226
229
  sumf += (wsp_ggml_float)(WSP_GGML_CPU_FP16_TO_FP32(x[i])*WSP_GGML_CPU_FP16_TO_FP32(y[i]));
@@ -55,7 +55,22 @@ inline static void wsp_ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t
55
55
 
56
56
  inline static void wsp_ggml_vec_set_f16(const int n, wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
57
57
  inline static void wsp_ggml_vec_set_bf16(const int n, wsp_ggml_bf16_t * x, const wsp_ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
58
- inline static void wsp_ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
58
+
59
+ inline static void wsp_ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
60
+ int i = 0;
61
+ #if defined(__AVX2__)
62
+ for (; i + 7 < n; i += 8) {
63
+ __m256 vx = _mm256_loadu_ps(x + i);
64
+ __m256 vy = _mm256_loadu_ps(y + i);
65
+ __m256 vz = _mm256_add_ps(vx, vy);
66
+ _mm256_storeu_ps(z + i, vz);
67
+ }
68
+ #endif
69
+ for (; i < n; ++i) {
70
+ z[i] = x[i] + y[i];
71
+ }
72
+ }
73
+
59
74
  inline static void wsp_ggml_vec_add_f16 (const int n, wsp_ggml_fp16_t * z, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * y) {
60
75
  for (int i = 0; i < n; ++i) {
61
76
  z[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(x[i]) + WSP_GGML_CPU_FP16_TO_FP32(y[i]));
@@ -163,49 +178,49 @@ inline static void wsp_ggml_vec_mad_f32(const int n, float * WSP_GGML_RESTRICT y
163
178
 
164
179
  ax1 = WSP_GGML_F32_VEC_LOAD(x + i);
165
180
  ay1 = WSP_GGML_F32_VEC_LOAD(y + i);
166
- ay1 = WSP_GGML_F32_VEC_FMA(ax1, vx, ay1);
181
+ ay1 = WSP_GGML_F32_VEC_FMA(ay1, ax1, vx);
167
182
 
168
183
  WSP_GGML_F32_VEC_STORE(y + i, ay1);
169
184
 
170
185
  ax2 = WSP_GGML_F32_VEC_LOAD(x + i + 1*wsp_ggml_f32_epr);
171
186
  ay2 = WSP_GGML_F32_VEC_LOAD(y + i + 1*wsp_ggml_f32_epr);
172
- ay2 = WSP_GGML_F32_VEC_FMA(ax2, vx, ay2);
187
+ ay2 = WSP_GGML_F32_VEC_FMA(ay2, ax2, vx);
173
188
 
174
189
  WSP_GGML_F32_VEC_STORE(y + i + 1*wsp_ggml_f32_epr, ay2);
175
190
 
176
191
  ax3 = WSP_GGML_F32_VEC_LOAD(x + i + 2*wsp_ggml_f32_epr);
177
192
  ay3 = WSP_GGML_F32_VEC_LOAD(y + i + 2*wsp_ggml_f32_epr);
178
- ay3 = WSP_GGML_F32_VEC_FMA(ax3, vx, ay3);
193
+ ay3 = WSP_GGML_F32_VEC_FMA(ay3, ax3, vx);
179
194
 
180
195
  WSP_GGML_F32_VEC_STORE(y + i + 2*wsp_ggml_f32_epr, ay3);
181
196
 
182
197
  ax4 = WSP_GGML_F32_VEC_LOAD(x + i + 3*wsp_ggml_f32_epr);
183
198
  ay4 = WSP_GGML_F32_VEC_LOAD(y + i + 3*wsp_ggml_f32_epr);
184
- ay4 = WSP_GGML_F32_VEC_FMA(ax4, vx, ay4);
199
+ ay4 = WSP_GGML_F32_VEC_FMA(ay4, ax4, vx);
185
200
 
186
201
  WSP_GGML_F32_VEC_STORE(y + i + 3*wsp_ggml_f32_epr, ay4);
187
202
 
188
203
  ax5 = WSP_GGML_F32_VEC_LOAD(x + i + 4*wsp_ggml_f32_epr);
189
204
  ay5 = WSP_GGML_F32_VEC_LOAD(y + i + 4*wsp_ggml_f32_epr);
190
- ay5 = WSP_GGML_F32_VEC_FMA(ax5, vx, ay5);
205
+ ay5 = WSP_GGML_F32_VEC_FMA(ay5, ax5, vx);
191
206
 
192
207
  WSP_GGML_F32_VEC_STORE(y + i + 4*wsp_ggml_f32_epr, ay5);
193
208
 
194
209
  ax6 = WSP_GGML_F32_VEC_LOAD(x + i + 5*wsp_ggml_f32_epr);
195
210
  ay6 = WSP_GGML_F32_VEC_LOAD(y + i + 5*wsp_ggml_f32_epr);
196
- ay6 = WSP_GGML_F32_VEC_FMA(ax6, vx, ay6);
211
+ ay6 = WSP_GGML_F32_VEC_FMA(ay6, ax6, vx);
197
212
 
198
213
  WSP_GGML_F32_VEC_STORE(y + i + 5*wsp_ggml_f32_epr, ay6);
199
214
 
200
215
  ax7 = WSP_GGML_F32_VEC_LOAD(x + i + 6*wsp_ggml_f32_epr);
201
216
  ay7 = WSP_GGML_F32_VEC_LOAD(y + i + 6*wsp_ggml_f32_epr);
202
- ay7 = WSP_GGML_F32_VEC_FMA(ax7, vx, ay7);
217
+ ay7 = WSP_GGML_F32_VEC_FMA(ay7, ax7, vx);
203
218
 
204
219
  WSP_GGML_F32_VEC_STORE(y + i + 6*wsp_ggml_f32_epr, ay7);
205
220
 
206
221
  ax8 = WSP_GGML_F32_VEC_LOAD(x + i + 7*wsp_ggml_f32_epr);
207
222
  ay8 = WSP_GGML_F32_VEC_LOAD(y + i + 7*wsp_ggml_f32_epr);
208
- ay8 = WSP_GGML_F32_VEC_FMA(ax8, vx, ay8);
223
+ ay8 = WSP_GGML_F32_VEC_FMA(ay8, ax8, vx);
209
224
 
210
225
  WSP_GGML_F32_VEC_STORE(y + i + 7*wsp_ggml_f32_epr, ay8);
211
226
  }
@@ -215,7 +230,7 @@ inline static void wsp_ggml_vec_mad_f32(const int n, float * WSP_GGML_RESTRICT y
215
230
  for (int i = np; i < np2; i += wsp_ggml_f32_epr) {
216
231
  ax1 = WSP_GGML_F32_VEC_LOAD(x + i);
217
232
  ay1 = WSP_GGML_F32_VEC_LOAD(y + i);
218
- ay1 = WSP_GGML_F32_VEC_FMA(ax1, vx, ay1);
233
+ ay1 = WSP_GGML_F32_VEC_FMA(ay1, ax1, vx);
219
234
 
220
235
  WSP_GGML_F32_VEC_STORE(y + i, ay1);
221
236
  }
@@ -351,6 +366,45 @@ inline static void wsp_ggml_vec_mad_f32_unroll(const int n, const int xs, const
351
366
  #endif
352
367
  }
353
368
 
369
+ inline static void wsp_ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
370
+ #if defined(WSP_GGML_USE_ACCELERATE)
371
+ vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
372
+ #elif defined(WSP_GGML_SIMD)
373
+ #if defined(__ARM_FEATURE_SVE)
374
+ // scalar ; TODO: Write SVE code
375
+ for (int i = 0; i < n; ++i) {
376
+ y[i] = x[i]*s + b;
377
+ }
378
+ #else
379
+ const int np = (n & ~(WSP_GGML_F32_STEP - 1));
380
+
381
+ WSP_GGML_F32_VEC vs = WSP_GGML_F32_VEC_SET1(s);
382
+ WSP_GGML_F32_VEC vb = WSP_GGML_F32_VEC_SET1(b);
383
+
384
+ WSP_GGML_F32_VEC ay[WSP_GGML_F32_ARR];
385
+
386
+ for (int i = 0; i < np; i += WSP_GGML_F32_STEP) {
387
+ for (int j = 0; j < WSP_GGML_F32_ARR; j++) {
388
+ ay[j] = WSP_GGML_F32_VEC_LOAD(x + i + j*WSP_GGML_F32_EPR);
389
+ ay[j] = WSP_GGML_F32_VEC_FMA(ay[j], vs, vb);
390
+
391
+ WSP_GGML_F32_VEC_STORE(y + i + j*WSP_GGML_F32_EPR, ay[j]);
392
+ }
393
+ }
394
+
395
+ // leftovers
396
+ for (int i = np; i < n; ++i) {
397
+ y[i] = x[i]*s + b;
398
+ }
399
+ #endif
400
+ #else
401
+ // scalar
402
+ for (int i = 0; i < n; ++i) {
403
+ y[i] = x[i]*s + b;
404
+ }
405
+ #endif
406
+ }
407
+
354
408
  //inline static void wsp_ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
355
409
  inline static void wsp_ggml_vec_scale_f32(const int n, float * y, const float v) {
356
410
  #if defined(WSP_GGML_USE_ACCELERATE)
@@ -953,9 +1007,49 @@ void wsp_ggml_vec_swiglu_f32(const int n, float * y, const float * x, const floa
953
1007
 
954
1008
  inline static void wsp_ggml_vec_swiglu_f16(const int n, wsp_ggml_fp16_t * y, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * g) {
955
1009
  for (int i = 0; i < n; ++i) {
956
- float v = WSP_GGML_CPU_FP16_TO_FP32(x[i]);
957
- float w = WSP_GGML_CPU_FP16_TO_FP32(g[i]);
958
- y[i] = WSP_GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
1010
+ float xi = WSP_GGML_CPU_FP16_TO_FP32(x[i]);
1011
+ float gi = WSP_GGML_CPU_FP16_TO_FP32(g[i]);
1012
+ y[i] = WSP_GGML_CPU_FP32_TO_FP16((xi/(1.0f + expf(-xi))) * gi);
1013
+ }
1014
+ }
1015
+
1016
+ inline static void wsp_ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
1017
+ for (int i = 0; i < n; ++i) {
1018
+ float xi = x[i];
1019
+ y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];
1020
+ }
1021
+ }
1022
+
1023
+ inline static void wsp_ggml_vec_geglu_erf_f16(const int n, wsp_ggml_fp16_t * y, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * g) {
1024
+ for (int i = 0; i < n; ++i) {
1025
+ float xi = WSP_GGML_CPU_FP16_TO_FP32(x[i]);
1026
+ float gi = WSP_GGML_CPU_FP16_TO_FP32(g[i]);
1027
+ y[i] = WSP_GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi);
1028
+ }
1029
+ }
1030
+
1031
+ #ifdef WSP_GGML_GELU_QUICK_FP16
1032
+ inline static void wsp_ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
1033
+ uint16_t t;
1034
+ for (int i = 0; i < n; ++i) {
1035
+ wsp_ggml_fp16_t fp16 = WSP_GGML_CPU_FP32_TO_FP16(x[i]);
1036
+ memcpy(&t, &fp16, sizeof(uint16_t));
1037
+ y[i] = WSP_GGML_CPU_FP16_TO_FP32(wsp_ggml_table_gelu_quick_f16[t]) * g[i];
1038
+ }
1039
+ }
1040
+ #else
1041
+ inline static void wsp_ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
1042
+ for (int i = 0; i < n; ++i) {
1043
+ y[i] = wsp_ggml_gelu_quick_f32(x[i]) * g[i];
1044
+ }
1045
+ }
1046
+ #endif
1047
+
1048
+ inline static void wsp_ggml_vec_geglu_quick_f16(const int n, wsp_ggml_fp16_t * y, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * g) {
1049
+ const uint16_t * i16 = (const uint16_t *) x;
1050
+ for (int i = 0; i < n; ++i) {
1051
+ float v = WSP_GGML_CPU_FP16_TO_FP32(g[i]);
1052
+ y[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(wsp_ggml_table_gelu_quick_f16[i16[i]]) * v);
959
1053
  }
960
1054
  }
961
1055
 
package/cpp/ggml-impl.h CHANGED
@@ -73,6 +73,22 @@ static inline int wsp_ggml_up(int n, int m) {
73
73
  return (n + m - 1) & ~(m - 1);
74
74
  }
75
75
 
76
+ // TODO: move to ggml.h?
77
+ static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b) {
78
+ if (a->type != b->type) {
79
+ return false;
80
+ }
81
+ for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
82
+ if (a->ne[i] != b->ne[i]) {
83
+ return false;
84
+ }
85
+ if (a->nb[i] != b->nb[i]) {
86
+ return false;
87
+ }
88
+ }
89
+ return true;
90
+ }
91
+
76
92
  //
77
93
  // logging
78
94
  //
@@ -394,6 +410,67 @@ static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
394
410
  #define WSP_GGML_FP16_TO_FP32(x) WSP_GGML_COMPUTE_FP16_TO_FP32(x)
395
411
  #define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x)
396
412
 
413
+ static inline float wsp_ggml_e8m0_to_fp32(uint8_t x) {
414
+ uint32_t bits; // Stores the raw bit representation of the float
415
+
416
+ // Handle special case for minimum exponent (denormalized float)
417
+ if (x == 0) {
418
+ // Bit pattern for 2^(-127):
419
+ // - Sign bit: 0 (positive)
420
+ // - Exponent: 0 (denormalized number)
421
+ // - Mantissa: 0x400000 (0.5 in fractional form)
422
+ // Value = 0.5 * 2^(-126) = 2^(-127)
423
+ bits = 0x00400000;
424
+ }
425
+ // note: disabled as we don't need to handle NaNs
426
+ //// Handle special case for NaN (all bits set)
427
+ //else if (x == 0xFF) {
428
+ // // Standard quiet NaN pattern:
429
+ // // - Sign bit: 0
430
+ // // - Exponent: all 1s (0xFF)
431
+ // // - Mantissa: 0x400000 (quiet NaN flag)
432
+ // bits = 0x7FC00000;
433
+ //}
434
+ // Normalized values (most common case)
435
+ else {
436
+ // Construct normalized float by shifting exponent into position:
437
+ // - Exponent field: 8 bits (positions 30-23)
438
+ // - Mantissa: 0 (implicit leading 1)
439
+ // Value = 2^(x - 127)
440
+ bits = (uint32_t) x << 23;
441
+ }
442
+
443
+ float result; // Final float value
444
+ // Safely reinterpret bit pattern as float without type-punning issues
445
+ memcpy(&result, &bits, sizeof(float));
446
+ return result;
447
+ }
448
+
449
+ // Equal to wsp_ggml_e8m0_to_fp32/2
450
+ // Useful with MXFP4 quantization since the E0M2 values are doubled
451
+ static inline float wsp_ggml_e8m0_to_fp32_half(uint8_t x) {
452
+ uint32_t bits;
453
+
454
+ // For x < 2: use precomputed denormal patterns
455
+ if (x < 2) {
456
+ // 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)
457
+ bits = 0x00200000 << x;
458
+ }
459
+ // For x >= 2: normalized exponent adjustment
460
+ else {
461
+ // 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1)
462
+ bits = (uint32_t)(x - 1) << 23;
463
+ }
464
+ // Note: NaNs are not handled here
465
+
466
+ float result;
467
+ memcpy(&result, &bits, sizeof(float));
468
+ return result;
469
+ }
470
+
471
+ #define WSP_GGML_E8M0_TO_FP32(x) wsp_ggml_e8m0_to_fp32(x)
472
+ #define WSP_GGML_E8M0_TO_FP32_HALF(x) wsp_ggml_e8m0_to_fp32_half(x)
473
+
397
474
  /**
398
475
  * Converts brain16 to float32.
399
476
  *
@@ -23,6 +23,9 @@
23
23
  #define N_R0_Q8_0 4
24
24
  #define N_SG_Q8_0 2
25
25
 
26
+ #define N_R0_MXFP4 2
27
+ #define N_SG_MXFP4 2
28
+
26
29
  #define N_R0_Q2_K 4
27
30
  #define N_SG_Q2_K 2
28
31
 
@@ -126,8 +129,18 @@ typedef struct {
126
129
  uint64_t nb2;
127
130
  uint64_t nb3;
128
131
  uint64_t offs;
132
+ uint64_t o1[8];
129
133
  } wsp_ggml_metal_kargs_bin;
130
134
 
135
+ typedef struct {
136
+ int64_t ne0;
137
+ int64_t ne1;
138
+ size_t nb01;
139
+ size_t nb02;
140
+ size_t nb11;
141
+ size_t nb21;
142
+ } wsp_ggml_metal_kargs_add_id;
143
+
131
144
  typedef struct {
132
145
  int32_t ne00;
133
146
  int32_t ne01;
@@ -229,14 +242,18 @@ typedef struct {
229
242
  uint64_t nb21;
230
243
  uint64_t nb22;
231
244
  uint64_t nb23;
245
+ int32_t ne32;
246
+ int32_t ne33;
232
247
  uint64_t nb31;
248
+ uint64_t nb32;
249
+ uint64_t nb33;
233
250
  int32_t ne1;
234
251
  int32_t ne2;
235
252
  float scale;
236
253
  float max_bias;
237
254
  float m0;
238
255
  float m1;
239
- uint16_t n_head_log2;
256
+ int32_t n_head_log2;
240
257
  float logit_softcap;
241
258
  } wsp_ggml_metal_kargs_flash_attn_ext;
242
259
 
@@ -373,8 +390,16 @@ typedef struct {
373
390
  typedef struct {
374
391
  int32_t ne00;
375
392
  int32_t ne00_4;
376
- uint64_t nb01;
393
+ uint64_t nb1;
394
+ uint64_t nb2;
395
+ uint64_t nb3;
377
396
  float eps;
397
+ int32_t nef1[3];
398
+ int32_t nef2[3];
399
+ int32_t nef3[3];
400
+ uint64_t nbf1[3];
401
+ uint64_t nbf2[3];
402
+ uint64_t nbf3[3];
378
403
  } wsp_ggml_metal_kargs_rms_norm;
379
404
 
380
405
  typedef struct {
@@ -431,6 +456,8 @@ typedef struct{
431
456
  uint64_t nb1;
432
457
  int32_t i00;
433
458
  int32_t i10;
459
+ float alpha;
460
+ float limit;
434
461
  } wsp_ggml_metal_kargs_glu;
435
462
 
436
463
  typedef struct {
@@ -461,14 +488,26 @@ typedef struct {
461
488
  } wsp_ggml_metal_kargs_sum_rows;
462
489
 
463
490
  typedef struct {
464
- int64_t ne00;
465
- int64_t ne01;
466
- int64_t ne02;
491
+ int32_t ne00;
492
+ int32_t ne01;
493
+ int32_t ne02;
494
+ uint64_t nb01;
495
+ uint64_t nb02;
496
+ uint64_t nb03;
497
+ int32_t ne11;
498
+ int32_t ne12;
499
+ int32_t ne13;
500
+ uint64_t nb11;
501
+ uint64_t nb12;
502
+ uint64_t nb13;
503
+ uint64_t nb1;
504
+ uint64_t nb2;
505
+ uint64_t nb3;
467
506
  float scale;
468
507
  float max_bias;
469
508
  float m0;
470
509
  float m1;
471
- uint32_t n_head_log2;
510
+ int32_t n_head_log2;
472
511
  } wsp_ggml_metal_kargs_soft_max;
473
512
 
474
513
  typedef struct {
@@ -499,26 +538,26 @@ typedef struct {
499
538
  typedef struct {
500
539
  int64_t d_state;
501
540
  int64_t d_inner;
541
+ int64_t n_head;
542
+ int64_t n_group;
502
543
  int64_t n_seq_tokens;
503
544
  int64_t n_seqs;
504
- uint64_t nb00;
545
+ int64_t s_off;
505
546
  uint64_t nb01;
506
547
  uint64_t nb02;
507
- uint64_t nb10;
548
+ uint64_t nb03;
508
549
  uint64_t nb11;
509
550
  uint64_t nb12;
510
551
  uint64_t nb13;
511
- uint64_t nb20;
512
552
  uint64_t nb21;
513
553
  uint64_t nb22;
514
- uint64_t nb30;
515
554
  uint64_t nb31;
516
- uint64_t nb40;
517
555
  uint64_t nb41;
518
556
  uint64_t nb42;
519
- uint64_t nb50;
557
+ uint64_t nb43;
520
558
  uint64_t nb51;
521
559
  uint64_t nb52;
560
+ uint64_t nb53;
522
561
  } wsp_ggml_metal_kargs_ssm_scan;
523
562
 
524
563
  typedef struct {