whisper.rn 0.5.0-rc.9 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136) hide show
  1. package/android/build.gradle +2 -1
  2. package/android/gradle.properties +1 -1
  3. package/cpp/ggml-alloc.c +265 -141
  4. package/cpp/ggml-backend-impl.h +4 -1
  5. package/cpp/ggml-backend-reg.cpp +30 -13
  6. package/cpp/ggml-backend.cpp +221 -38
  7. package/cpp/ggml-backend.h +17 -1
  8. package/cpp/ggml-common.h +17 -0
  9. package/cpp/ggml-cpu/amx/amx.cpp +4 -2
  10. package/cpp/ggml-cpu/arch/arm/quants.c +132 -596
  11. package/cpp/ggml-cpu/arch/arm/repack.cpp +14 -286
  12. package/cpp/ggml-cpu/arch/x86/quants.c +184 -675
  13. package/cpp/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  14. package/cpp/ggml-cpu/arch-fallback.h +32 -2
  15. package/cpp/ggml-cpu/common.h +14 -0
  16. package/cpp/ggml-cpu/ggml-cpu-impl.h +13 -6
  17. package/cpp/ggml-cpu/ggml-cpu.c +70 -42
  18. package/cpp/ggml-cpu/ggml-cpu.cpp +35 -28
  19. package/cpp/ggml-cpu/ops.cpp +1587 -1177
  20. package/cpp/ggml-cpu/ops.h +5 -8
  21. package/cpp/ggml-cpu/quants.c +35 -0
  22. package/cpp/ggml-cpu/quants.h +8 -0
  23. package/cpp/ggml-cpu/repack.cpp +458 -47
  24. package/cpp/ggml-cpu/repack.h +22 -0
  25. package/cpp/ggml-cpu/simd-mappings.h +89 -60
  26. package/cpp/ggml-cpu/traits.cpp +2 -2
  27. package/cpp/ggml-cpu/traits.h +1 -1
  28. package/cpp/ggml-cpu/vec.cpp +170 -26
  29. package/cpp/ggml-cpu/vec.h +506 -63
  30. package/cpp/ggml-cpu.h +1 -1
  31. package/cpp/ggml-impl.h +119 -9
  32. package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
  33. package/cpp/ggml-metal/ggml-metal-common.h +52 -0
  34. package/cpp/ggml-metal/ggml-metal-context.h +33 -0
  35. package/cpp/ggml-metal/ggml-metal-context.m +600 -0
  36. package/cpp/ggml-metal/ggml-metal-device.cpp +1376 -0
  37. package/cpp/ggml-metal/ggml-metal-device.h +226 -0
  38. package/cpp/ggml-metal/ggml-metal-device.m +1312 -0
  39. package/cpp/ggml-metal/ggml-metal-impl.h +722 -0
  40. package/cpp/ggml-metal/ggml-metal-ops.cpp +3158 -0
  41. package/cpp/ggml-metal/ggml-metal-ops.h +82 -0
  42. package/cpp/ggml-metal/ggml-metal.cpp +718 -0
  43. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  44. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  45. package/cpp/ggml-metal-impl.h +90 -51
  46. package/cpp/ggml-metal.h +1 -6
  47. package/cpp/ggml-opt.cpp +97 -41
  48. package/cpp/ggml-opt.h +25 -6
  49. package/cpp/ggml-quants.c +111 -16
  50. package/cpp/ggml-quants.h +6 -0
  51. package/cpp/ggml.c +486 -98
  52. package/cpp/ggml.h +221 -16
  53. package/cpp/gguf.cpp +8 -1
  54. package/cpp/jsi/RNWhisperJSI.cpp +25 -6
  55. package/cpp/jsi/ThreadPool.h +3 -3
  56. package/cpp/whisper.cpp +100 -76
  57. package/cpp/whisper.h +1 -0
  58. package/ios/CMakeLists.txt +6 -1
  59. package/ios/RNWhisper.mm +6 -6
  60. package/ios/RNWhisperContext.mm +2 -0
  61. package/ios/RNWhisperVadContext.mm +16 -13
  62. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  63. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  64. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
  65. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  66. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +119 -9
  67. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
  68. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  69. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  70. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  71. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +221 -16
  72. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  73. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  74. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  75. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  76. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  77. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  78. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
  79. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  80. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +119 -9
  81. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
  82. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  83. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  84. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  85. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +221 -16
  86. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  87. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  88. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  89. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  90. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  91. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  92. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  93. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
  94. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  95. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +119 -9
  96. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
  97. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  98. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  99. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  100. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +221 -16
  101. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  102. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  103. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  104. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  105. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  106. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  107. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
  108. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  109. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +119 -9
  110. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
  111. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  112. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  113. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  114. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +221 -16
  115. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  116. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  117. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  118. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  119. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  120. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +13 -0
  121. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  122. package/lib/commonjs/version.json +1 -1
  123. package/lib/module/realtime-transcription/RealtimeTranscriber.js +13 -0
  124. package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  125. package/lib/module/version.json +1 -1
  126. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
  127. package/lib/typescript/realtime-transcription/types.d.ts +6 -0
  128. package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
  129. package/package.json +1 -1
  130. package/src/realtime-transcription/RealtimeTranscriber.ts +17 -0
  131. package/src/realtime-transcription/types.ts +6 -0
  132. package/src/version.json +1 -1
  133. package/whisper-rn.podspec +8 -9
  134. package/cpp/ggml-metal.m +0 -6284
  135. package/cpp/ggml-whisper-sim.metallib +0 -0
  136. package/cpp/ggml-whisper.metallib +0 -0
@@ -1,5 +1,5 @@
1
- #ifndef WSP_GGML_METAL_IMPL
2
- #define WSP_GGML_METAL_IMPL
1
+ #ifndef WSP_WSP_WSP_GGML_METAL_IMPL
2
+ #define WSP_WSP_WSP_GGML_METAL_IMPL
3
3
 
4
4
  // kernel parameters for mat-vec threadgroups
5
5
  //
@@ -23,6 +23,9 @@
23
23
  #define N_R0_Q8_0 4
24
24
  #define N_SG_Q8_0 2
25
25
 
26
+ #define N_R0_MXFP4 2
27
+ #define N_SG_MXFP4 2
28
+
26
29
  #define N_R0_Q2_K 4
27
30
  #define N_SG_Q2_K 2
28
31
 
@@ -98,7 +101,7 @@ typedef struct {
98
101
  uint64_t nb2;
99
102
  uint64_t nb3;
100
103
  int32_t dim;
101
- } wsp_ggml_metal_kargs_concat;
104
+ } wsp_wsp_wsp_ggml_metal_kargs_concat;
102
105
 
103
106
  typedef struct {
104
107
  int32_t ne00;
@@ -126,7 +129,17 @@ typedef struct {
126
129
  uint64_t nb2;
127
130
  uint64_t nb3;
128
131
  uint64_t offs;
129
- } wsp_ggml_metal_kargs_bin;
132
+ uint64_t o1[8];
133
+ } wsp_wsp_wsp_ggml_metal_kargs_bin;
134
+
135
+ typedef struct {
136
+ int64_t ne0;
137
+ int64_t ne1;
138
+ size_t nb01;
139
+ size_t nb02;
140
+ size_t nb11;
141
+ size_t nb21;
142
+ } wsp_wsp_wsp_ggml_metal_kargs_add_id;
130
143
 
131
144
  typedef struct {
132
145
  int32_t ne00;
@@ -145,7 +158,7 @@ typedef struct {
145
158
  uint64_t nb1;
146
159
  uint64_t nb2;
147
160
  uint64_t nb3;
148
- } wsp_ggml_metal_kargs_repeat;
161
+ } wsp_wsp_wsp_ggml_metal_kargs_repeat;
149
162
 
150
163
  typedef struct {
151
164
  int64_t ne00;
@@ -164,7 +177,7 @@ typedef struct {
164
177
  uint64_t nb1;
165
178
  uint64_t nb2;
166
179
  uint64_t nb3;
167
- } wsp_ggml_metal_kargs_cpy;
180
+ } wsp_wsp_wsp_ggml_metal_kargs_cpy;
168
181
 
169
182
  typedef struct {
170
183
  int64_t ne10;
@@ -179,7 +192,7 @@ typedef struct {
179
192
  uint64_t nb3;
180
193
  uint64_t offs;
181
194
  bool inplace;
182
- } wsp_ggml_metal_kargs_set;
195
+ } wsp_wsp_wsp_ggml_metal_kargs_set;
183
196
 
184
197
  typedef struct {
185
198
  int32_t ne00;
@@ -211,7 +224,7 @@ typedef struct {
211
224
  int32_t sect_1;
212
225
  int32_t sect_2;
213
226
  int32_t sect_3;
214
- } wsp_ggml_metal_kargs_rope;
227
+ } wsp_wsp_wsp_ggml_metal_kargs_rope;
215
228
 
216
229
  typedef struct {
217
230
  int32_t ne01;
@@ -229,16 +242,20 @@ typedef struct {
229
242
  uint64_t nb21;
230
243
  uint64_t nb22;
231
244
  uint64_t nb23;
245
+ int32_t ne32;
246
+ int32_t ne33;
232
247
  uint64_t nb31;
248
+ uint64_t nb32;
249
+ uint64_t nb33;
233
250
  int32_t ne1;
234
251
  int32_t ne2;
235
252
  float scale;
236
253
  float max_bias;
237
254
  float m0;
238
255
  float m1;
239
- uint16_t n_head_log2;
256
+ int32_t n_head_log2;
240
257
  float logit_softcap;
241
- } wsp_ggml_metal_kargs_flash_attn_ext;
258
+ } wsp_wsp_wsp_ggml_metal_kargs_flash_attn_ext;
242
259
 
243
260
  typedef struct {
244
261
  int32_t ne00;
@@ -255,7 +272,7 @@ typedef struct {
255
272
  int32_t ne1;
256
273
  int16_t r2;
257
274
  int16_t r3;
258
- } wsp_ggml_metal_kargs_mul_mm;
275
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mm;
259
276
 
260
277
  typedef struct {
261
278
  int32_t ne00;
@@ -276,7 +293,7 @@ typedef struct {
276
293
  int32_t ne1;
277
294
  int16_t r2;
278
295
  int16_t r3;
279
- } wsp_ggml_metal_kargs_mul_mv;
296
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mv;
280
297
 
281
298
  typedef struct {
282
299
  int32_t ne00;
@@ -300,7 +317,7 @@ typedef struct {
300
317
  int16_t nsg;
301
318
  int16_t nxpsg;
302
319
  int16_t r1ptg;
303
- } wsp_ggml_metal_kargs_mul_mv_ext;
320
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mv_ext;
304
321
 
305
322
  typedef struct {
306
323
  int32_t ne10;
@@ -311,7 +328,7 @@ typedef struct {
311
328
  uint64_t nbh11;
312
329
  int32_t ne20; // n_expert_used
313
330
  uint64_t nb21;
314
- } wsp_ggml_metal_kargs_mul_mm_id_map0;
331
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mm_id_map0;
315
332
 
316
333
  typedef struct {
317
334
  int32_t ne20; // n_expert_used
@@ -322,7 +339,7 @@ typedef struct {
322
339
  int32_t ne0;
323
340
  uint64_t nb1;
324
341
  uint64_t nb2;
325
- } wsp_ggml_metal_kargs_mul_mm_id_map1;
342
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mm_id_map1;
326
343
 
327
344
  typedef struct {
328
345
  int32_t ne00;
@@ -339,7 +356,7 @@ typedef struct {
339
356
  int32_t neh1;
340
357
  int16_t r2;
341
358
  int16_t r3;
342
- } wsp_ggml_metal_kargs_mul_mm_id;
359
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mm_id;
343
360
 
344
361
  typedef struct {
345
362
  int32_t nei0;
@@ -361,28 +378,36 @@ typedef struct {
361
378
  int32_t ne0;
362
379
  int32_t ne1;
363
380
  uint64_t nb1;
364
- } wsp_ggml_metal_kargs_mul_mv_id;
381
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mv_id;
365
382
 
366
383
  typedef struct {
367
384
  int32_t ne00;
368
385
  int32_t ne00_4;
369
386
  uint64_t nb01;
370
387
  float eps;
371
- } wsp_ggml_metal_kargs_norm;
388
+ } wsp_wsp_wsp_ggml_metal_kargs_norm;
372
389
 
373
390
  typedef struct {
374
391
  int32_t ne00;
375
392
  int32_t ne00_4;
376
- uint64_t nb01;
393
+ uint64_t nb1;
394
+ uint64_t nb2;
395
+ uint64_t nb3;
377
396
  float eps;
378
- } wsp_ggml_metal_kargs_rms_norm;
397
+ int32_t nef1[3];
398
+ int32_t nef2[3];
399
+ int32_t nef3[3];
400
+ uint64_t nbf1[3];
401
+ uint64_t nbf2[3];
402
+ uint64_t nbf3[3];
403
+ } wsp_wsp_wsp_ggml_metal_kargs_rms_norm;
379
404
 
380
405
  typedef struct {
381
406
  int32_t ne00;
382
407
  int32_t ne00_4;
383
408
  uint64_t nb01;
384
409
  float eps;
385
- } wsp_ggml_metal_kargs_l2_norm;
410
+ } wsp_wsp_wsp_ggml_metal_kargs_l2_norm;
386
411
 
387
412
  typedef struct {
388
413
  int64_t ne00;
@@ -393,7 +418,7 @@ typedef struct {
393
418
  uint64_t nb02;
394
419
  int32_t n_groups;
395
420
  float eps;
396
- } wsp_ggml_metal_kargs_group_norm;
421
+ } wsp_wsp_wsp_ggml_metal_kargs_group_norm;
397
422
 
398
423
  typedef struct {
399
424
  int32_t IC;
@@ -402,7 +427,7 @@ typedef struct {
402
427
  int32_t s0;
403
428
  uint64_t nb0;
404
429
  uint64_t nb1;
405
- } wsp_ggml_metal_kargs_conv_transpose_1d;
430
+ } wsp_wsp_wsp_ggml_metal_kargs_conv_transpose_1d;
406
431
 
407
432
  typedef struct {
408
433
  uint64_t ofs0;
@@ -420,7 +445,7 @@ typedef struct {
420
445
  int32_t KH;
421
446
  int32_t KW;
422
447
  int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
423
- } wsp_ggml_metal_kargs_im2col;
448
+ } wsp_wsp_wsp_ggml_metal_kargs_im2col;
424
449
 
425
450
  typedef struct{
426
451
  int32_t ne00;
@@ -431,7 +456,9 @@ typedef struct{
431
456
  uint64_t nb1;
432
457
  int32_t i00;
433
458
  int32_t i10;
434
- } wsp_ggml_metal_kargs_glu;
459
+ float alpha;
460
+ float limit;
461
+ } wsp_wsp_wsp_ggml_metal_kargs_glu;
435
462
 
436
463
  typedef struct {
437
464
  int64_t ne00;
@@ -458,24 +485,36 @@ typedef struct {
458
485
  uint64_t nb1;
459
486
  uint64_t nb2;
460
487
  uint64_t nb3;
461
- } wsp_ggml_metal_kargs_sum_rows;
488
+ } wsp_wsp_wsp_ggml_metal_kargs_sum_rows;
462
489
 
463
490
  typedef struct {
464
- int64_t ne00;
465
- int64_t ne01;
466
- int64_t ne02;
491
+ int32_t ne00;
492
+ int32_t ne01;
493
+ int32_t ne02;
494
+ uint64_t nb01;
495
+ uint64_t nb02;
496
+ uint64_t nb03;
497
+ int32_t ne11;
498
+ int32_t ne12;
499
+ int32_t ne13;
500
+ uint64_t nb11;
501
+ uint64_t nb12;
502
+ uint64_t nb13;
503
+ uint64_t nb1;
504
+ uint64_t nb2;
505
+ uint64_t nb3;
467
506
  float scale;
468
507
  float max_bias;
469
508
  float m0;
470
509
  float m1;
471
- uint32_t n_head_log2;
472
- } wsp_ggml_metal_kargs_soft_max;
510
+ int32_t n_head_log2;
511
+ } wsp_wsp_wsp_ggml_metal_kargs_soft_max;
473
512
 
474
513
  typedef struct {
475
514
  int64_t ne00;
476
515
  int64_t ne01;
477
516
  int n_past;
478
- } wsp_ggml_metal_kargs_diag_mask_inf;
517
+ } wsp_wsp_wsp_ggml_metal_kargs_diag_mask_inf;
479
518
 
480
519
  typedef struct {
481
520
  int64_t ne00;
@@ -494,32 +533,32 @@ typedef struct {
494
533
  uint64_t nb0;
495
534
  uint64_t nb1;
496
535
  uint64_t nb2;
497
- } wsp_ggml_metal_kargs_ssm_conv;
536
+ } wsp_wsp_wsp_ggml_metal_kargs_ssm_conv;
498
537
 
499
538
  typedef struct {
500
539
  int64_t d_state;
501
540
  int64_t d_inner;
541
+ int64_t n_head;
542
+ int64_t n_group;
502
543
  int64_t n_seq_tokens;
503
544
  int64_t n_seqs;
504
- uint64_t nb00;
545
+ int64_t s_off;
505
546
  uint64_t nb01;
506
547
  uint64_t nb02;
507
- uint64_t nb10;
548
+ uint64_t nb03;
508
549
  uint64_t nb11;
509
550
  uint64_t nb12;
510
551
  uint64_t nb13;
511
- uint64_t nb20;
512
552
  uint64_t nb21;
513
553
  uint64_t nb22;
514
- uint64_t nb30;
515
554
  uint64_t nb31;
516
- uint64_t nb40;
517
555
  uint64_t nb41;
518
556
  uint64_t nb42;
519
- uint64_t nb50;
557
+ uint64_t nb43;
520
558
  uint64_t nb51;
521
559
  uint64_t nb52;
522
- } wsp_ggml_metal_kargs_ssm_scan;
560
+ uint64_t nb53;
561
+ } wsp_wsp_wsp_ggml_metal_kargs_ssm_scan;
523
562
 
524
563
  typedef struct {
525
564
  int64_t ne00;
@@ -530,7 +569,7 @@ typedef struct {
530
569
  uint64_t nb11;
531
570
  uint64_t nb1;
532
571
  uint64_t nb2;
533
- } wsp_ggml_metal_kargs_get_rows;
572
+ } wsp_wsp_wsp_ggml_metal_kargs_get_rows;
534
573
 
535
574
  typedef struct {
536
575
  int32_t nk0;
@@ -546,7 +585,7 @@ typedef struct {
546
585
  uint64_t nb1;
547
586
  uint64_t nb2;
548
587
  uint64_t nb3;
549
- } wsp_ggml_metal_kargs_set_rows;
588
+ } wsp_wsp_wsp_ggml_metal_kargs_set_rows;
550
589
 
551
590
  typedef struct {
552
591
  int64_t ne00;
@@ -569,7 +608,7 @@ typedef struct {
569
608
  float sf1;
570
609
  float sf2;
571
610
  float sf3;
572
- } wsp_ggml_metal_kargs_upscale;
611
+ } wsp_wsp_wsp_ggml_metal_kargs_upscale;
573
612
 
574
613
  typedef struct {
575
614
  int64_t ne00;
@@ -588,7 +627,7 @@ typedef struct {
588
627
  uint64_t nb1;
589
628
  uint64_t nb2;
590
629
  uint64_t nb3;
591
- } wsp_ggml_metal_kargs_pad;
630
+ } wsp_wsp_wsp_ggml_metal_kargs_pad;
592
631
 
593
632
  typedef struct {
594
633
  int64_t ne00;
@@ -609,28 +648,28 @@ typedef struct {
609
648
  uint64_t nb3;
610
649
  int32_t p0;
611
650
  int32_t p1;
612
- } wsp_ggml_metal_kargs_pad_reflect_1d;
651
+ } wsp_wsp_wsp_ggml_metal_kargs_pad_reflect_1d;
613
652
 
614
653
  typedef struct {
615
654
  uint64_t nb1;
616
655
  int dim;
617
656
  int max_period;
618
- } wsp_ggml_metal_kargs_timestep_embedding;
657
+ } wsp_wsp_wsp_ggml_metal_kargs_timestep_embedding;
619
658
 
620
659
  typedef struct {
621
660
  float slope;
622
- } wsp_ggml_metal_kargs_leaky_relu;
661
+ } wsp_wsp_wsp_ggml_metal_kargs_leaky_relu;
623
662
 
624
663
  typedef struct {
625
664
  int64_t ncols;
626
665
  int64_t ncols_pad;
627
- } wsp_ggml_metal_kargs_argsort;
666
+ } wsp_wsp_wsp_ggml_metal_kargs_argsort;
628
667
 
629
668
  typedef struct {
630
669
  int64_t ne0;
631
670
  float start;
632
671
  float step;
633
- } wsp_ggml_metal_kargs_arange;
672
+ } wsp_wsp_wsp_ggml_metal_kargs_arange;
634
673
 
635
674
  typedef struct {
636
675
  int32_t k0;
@@ -644,6 +683,6 @@ typedef struct {
644
683
  int64_t OH;
645
684
  int64_t OW;
646
685
  int64_t parallel_elements;
647
- } wsp_ggml_metal_kargs_pool_2d;
686
+ } wsp_wsp_wsp_ggml_metal_kargs_pool_2d;
648
687
 
649
- #endif // WSP_GGML_METAL_IMPL
688
+ #endif // WSP_WSP_WSP_GGML_METAL_IMPL
package/cpp/ggml-metal.h CHANGED
@@ -39,18 +39,13 @@ extern "C" {
39
39
  // user-code should use only these functions
40
40
  //
41
41
 
42
+ // TODO: remove in the future
42
43
  WSP_GGML_BACKEND_API wsp_ggml_backend_t wsp_ggml_backend_metal_init(void);
43
44
 
44
45
  WSP_GGML_BACKEND_API bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend);
45
46
 
46
- WSP_GGML_DEPRECATED(
47
- WSP_GGML_BACKEND_API wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
48
- "obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713");
49
-
50
47
  WSP_GGML_BACKEND_API void wsp_ggml_backend_metal_set_abort_callback(wsp_ggml_backend_t backend, wsp_ggml_abort_callback abort_callback, void * user_data);
51
48
 
52
- WSP_GGML_BACKEND_API wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void);
53
-
54
49
  // helper to check if the device supports a specific family
55
50
  // ideally, the user code should be doing these checks
56
51
  // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
package/cpp/ggml-opt.cpp CHANGED
@@ -64,9 +64,11 @@ struct wsp_ggml_opt_context {
64
64
  int32_t opt_i = 0;
65
65
  bool loss_per_datapoint = false;
66
66
 
67
- wsp_ggml_opt_get_optimizer_params get_opt_pars = nullptr;
68
- void * get_opt_pars_ud = nullptr;
69
- struct wsp_ggml_tensor * adamw_params = nullptr;
67
+ wsp_ggml_opt_get_optimizer_params get_opt_pars = nullptr;
68
+ void * get_opt_pars_ud = nullptr;
69
+ struct wsp_ggml_tensor * opt_step_params = nullptr; // Stores output of get_opt_pars.
70
+
71
+ enum wsp_ggml_opt_optimizer_type optimizer = WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW;
70
72
  };
71
73
 
72
74
  struct wsp_ggml_opt_result {
@@ -229,9 +231,13 @@ struct wsp_ggml_opt_optimizer_params wsp_ggml_opt_get_default_optimizer_params(v
229
231
  result.adamw.eps = 1e-8f;
230
232
  result.adamw.wd = 0.0f;
231
233
 
234
+ result.sgd.alpha = 1e-3f;
235
+ result.sgd.wd = 0.0f;
236
+
232
237
  return result;
233
238
  }
234
239
 
240
+
235
241
  struct wsp_ggml_opt_optimizer_params wsp_ggml_opt_get_constant_optimizer_params(void * userdata) {
236
242
  return *((struct wsp_ggml_opt_optimizer_params *) userdata);
237
243
  }
@@ -249,6 +255,7 @@ struct wsp_ggml_opt_params wsp_ggml_opt_default_params(
249
255
  /*opt_period =*/ 1,
250
256
  /*get_opt_pars =*/ wsp_ggml_opt_get_default_optimizer_params,
251
257
  /*get_opt_pars_ud =*/ nullptr,
258
+ /*optimizer =*/ WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW,
252
259
  };
253
260
  }
254
261
 
@@ -316,9 +323,14 @@ static void wsp_ggml_opt_build(wsp_ggml_opt_context_t opt_ctx) {
316
323
  WSP_GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with wsp_ggml_opt_prepare_alloc");
317
324
  WSP_GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
318
325
 
326
+ const enum wsp_ggml_opt_optimizer_type optimizer = opt_ctx->optimizer;
327
+
319
328
  const bool accumulate = opt_ctx->build_type_alloc >= WSP_GGML_OPT_BUILD_TYPE_GRAD &&
320
329
  !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == WSP_GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
321
330
 
331
+ const bool need_momenta = opt_ctx->build_type_alloc == WSP_GGML_OPT_BUILD_TYPE_OPT &&
332
+ opt_ctx->optimizer == WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW;
333
+
322
334
  wsp_ggml_set_input(opt_ctx->inputs);
323
335
  wsp_ggml_set_output(opt_ctx->outputs);
324
336
 
@@ -340,8 +352,7 @@ static void wsp_ggml_opt_build(wsp_ggml_opt_context_t opt_ctx) {
340
352
  // - pred (if using static graphs)
341
353
  // - ncorrect (if using static graphs, 2 tensors).
342
354
  constexpr size_t n_loss = 1;
343
- const size_t tensors_per_param = (accumulate ? 1 : 0) +
344
- (opt_ctx->build_type_alloc == WSP_GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
355
+ const size_t tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0);
345
356
  const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
346
357
  const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * wsp_ggml_tensor_overhead();
347
358
  struct wsp_ggml_init_params params = {
@@ -458,7 +469,7 @@ static void wsp_ggml_opt_build(wsp_ggml_opt_context_t opt_ctx) {
458
469
  }
459
470
  }
460
471
 
461
- if (opt_ctx->build_type_alloc >= WSP_GGML_OPT_BUILD_TYPE_OPT) {
472
+ if (need_momenta && opt_ctx->build_type_alloc >= WSP_GGML_OPT_BUILD_TYPE_OPT) {
462
473
  opt_ctx->grad_m.resize(n_nodes);
463
474
  opt_ctx->grad_v.resize(n_nodes);
464
475
  for (int i = 0; i < n_nodes; ++i) {
@@ -492,23 +503,36 @@ static void wsp_ggml_opt_build(wsp_ggml_opt_context_t opt_ctx) {
492
503
  // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
493
504
  opt_ctx->gb_opt = wsp_ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
494
505
 
495
- opt_ctx->adamw_params = wsp_ggml_new_tensor_1d(opt_ctx->ctx_cpu, WSP_GGML_TYPE_F32, 7);
496
- wsp_ggml_set_input(opt_ctx->adamw_params);
497
- wsp_ggml_set_name(opt_ctx->adamw_params, "adamw_params");
498
-
506
+ opt_ctx->opt_step_params = wsp_ggml_new_tensor_1d(opt_ctx->ctx_cpu, WSP_GGML_TYPE_F32, need_momenta ? 7 : 2);
507
+ wsp_ggml_tensor * adamw_params = opt_ctx->opt_step_params;
508
+ wsp_ggml_set_input(adamw_params);
509
+ const char * optimizer_name = wsp_ggml_opt_optimizer_name(opt_ctx->optimizer);
510
+ wsp_ggml_format_name(adamw_params, "%s_params", optimizer_name);
499
511
  for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
500
512
  struct wsp_ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
501
513
  struct wsp_ggml_tensor * grad = wsp_ggml_graph_get_grad(opt_ctx->gb_opt, node);
502
514
 
503
515
  if (grad && (node->flags & WSP_GGML_TENSOR_FLAG_PARAM)) {
504
- struct wsp_ggml_tensor * m = opt_ctx->grad_m[i];
505
- struct wsp_ggml_tensor * v = opt_ctx->grad_v[i];
506
- struct wsp_ggml_tensor * opt_step = wsp_ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
507
-
508
- wsp_ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str());
509
- wsp_ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str());
510
- wsp_ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
511
-
516
+ struct wsp_ggml_tensor * m = nullptr;
517
+ struct wsp_ggml_tensor * v = nullptr;
518
+ if (need_momenta) {
519
+ m = opt_ctx->grad_m[i];
520
+ v = opt_ctx->grad_v[i];
521
+ wsp_ggml_format_name(m, "AdamW m for %s", node->name);
522
+ wsp_ggml_format_name(v, "AdamW v for %s", node->name);
523
+ }
524
+ struct wsp_ggml_tensor * opt_step;
525
+ switch (optimizer) {
526
+ case WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW:
527
+ opt_step = wsp_ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, adamw_params);
528
+ break;
529
+ case WSP_GGML_OPT_OPTIMIZER_TYPE_SGD:
530
+ opt_step = wsp_ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, adamw_params);
531
+ break;
532
+ default:
533
+ WSP_GGML_ABORT("fatal error");
534
+ }
535
+ wsp_ggml_format_name(opt_step, "%s step for %s", optimizer_name, node->name);
512
536
  wsp_ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
513
537
  }
514
538
  }
@@ -534,6 +558,7 @@ wsp_ggml_opt_context_t wsp_ggml_opt_init(struct wsp_ggml_opt_params params) {
534
558
  result->opt_period = params.opt_period;
535
559
  result->get_opt_pars = params.get_opt_pars;
536
560
  result->get_opt_pars_ud = params.get_opt_pars_ud;
561
+ result->optimizer = params.optimizer;
537
562
 
538
563
  WSP_GGML_ASSERT(result->opt_period >= 1);
539
564
 
@@ -756,29 +781,43 @@ void wsp_ggml_opt_alloc(wsp_ggml_opt_context_t opt_ctx, bool backward) {
756
781
  void wsp_ggml_opt_eval(wsp_ggml_opt_context_t opt_ctx, wsp_ggml_opt_result_t result) {
757
782
  WSP_GGML_ASSERT(opt_ctx->eval_ready);
758
783
  if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
759
- struct wsp_ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
760
-
761
- WSP_GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
762
- WSP_GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
763
- WSP_GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
764
- WSP_GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
765
- WSP_GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
766
- WSP_GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
767
- WSP_GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
768
- WSP_GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
769
-
770
- // beta1, beta2 after applying warmup
771
- const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
772
- const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
773
-
774
- float * adamw_par_data = wsp_ggml_get_data_f32(opt_ctx->adamw_params);
775
- adamw_par_data[0] = opt_pars.adamw.alpha;
776
- adamw_par_data[1] = opt_pars.adamw.beta1;
777
- adamw_par_data[2] = opt_pars.adamw.beta2;
778
- adamw_par_data[3] = opt_pars.adamw.eps;
779
- adamw_par_data[4] = opt_pars.adamw.wd;
780
- adamw_par_data[5] = beta1h;
781
- adamw_par_data[6] = beta2h;
784
+ const wsp_ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
785
+
786
+ switch (opt_ctx->optimizer) {
787
+ case WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW: {
788
+ WSP_GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
789
+ WSP_GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
790
+ WSP_GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
791
+ WSP_GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
792
+ WSP_GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
793
+ WSP_GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
794
+ WSP_GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
795
+ WSP_GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
796
+
797
+ // beta1, beta2 after applying warmup
798
+ const float beta1h = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
799
+ const float beta2h = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
800
+
801
+ float * adamw_par_data = wsp_ggml_get_data_f32(opt_ctx->opt_step_params);
802
+ adamw_par_data[0] = opt_pars.adamw.alpha;
803
+ adamw_par_data[1] = opt_pars.adamw.beta1;
804
+ adamw_par_data[2] = opt_pars.adamw.beta2;
805
+ adamw_par_data[3] = opt_pars.adamw.eps;
806
+ adamw_par_data[4] = opt_pars.adamw.wd;
807
+ adamw_par_data[5] = beta1h;
808
+ adamw_par_data[6] = beta2h;
809
+ } break;
810
+ case WSP_GGML_OPT_OPTIMIZER_TYPE_SGD: {
811
+ WSP_GGML_ASSERT(opt_pars.sgd.alpha > 0.0f);
812
+ WSP_GGML_ASSERT(opt_pars.sgd.wd >= 0.0f);
813
+ WSP_GGML_ASSERT(opt_pars.sgd.wd <= 1.0f);
814
+ float * sgd = wsp_ggml_get_data_f32(opt_ctx->opt_step_params);
815
+ sgd[0] = opt_pars.sgd.alpha;
816
+ sgd[1] = opt_pars.sgd.wd;
817
+ } break;
818
+ default:
819
+ WSP_GGML_ABORT("fatal error");
820
+ }
782
821
  }
783
822
 
784
823
  wsp_ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
@@ -963,6 +1002,7 @@ void wsp_ggml_opt_fit(
963
1002
  wsp_ggml_tensor * outputs,
964
1003
  wsp_ggml_opt_dataset_t dataset,
965
1004
  enum wsp_ggml_opt_loss_type loss_type,
1005
+ enum wsp_ggml_opt_optimizer_type optimizer,
966
1006
  wsp_ggml_opt_get_optimizer_params get_opt_pars,
967
1007
  int64_t nepoch,
968
1008
  int64_t nbatch_logical,
@@ -993,6 +1033,7 @@ void wsp_ggml_opt_fit(
993
1033
  params.opt_period = opt_period;
994
1034
  params.get_opt_pars = get_opt_pars;
995
1035
  params.get_opt_pars_ud = &epoch;
1036
+ params.optimizer = optimizer;
996
1037
  wsp_ggml_opt_context_t opt_ctx = wsp_ggml_opt_init(params);
997
1038
 
998
1039
  // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
@@ -1035,3 +1076,18 @@ void wsp_ggml_opt_fit(
1035
1076
  wsp_ggml_opt_result_free(result_train);
1036
1077
  wsp_ggml_opt_result_free(result_val);
1037
1078
  }
1079
+
1080
+ enum wsp_ggml_opt_optimizer_type wsp_ggml_opt_context_optimizer_type(wsp_ggml_opt_context_t c) {
1081
+ return c->optimizer;
1082
+ }
1083
+
1084
+ WSP_GGML_API const char * wsp_ggml_opt_optimizer_name(enum wsp_ggml_opt_optimizer_type o) {
1085
+ switch (o) {
1086
+ case WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW:
1087
+ return "adamw";
1088
+ case WSP_GGML_OPT_OPTIMIZER_TYPE_SGD:
1089
+ return "sgd";
1090
+ default:
1091
+ return "undefined";
1092
+ };
1093
+ }