whisper.rn 0.5.0 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. package/android/build.gradle +2 -1
  2. package/android/gradle.properties +1 -1
  3. package/cpp/ggml-alloc.c +264 -126
  4. package/cpp/ggml-backend-impl.h +4 -1
  5. package/cpp/ggml-backend-reg.cpp +13 -5
  6. package/cpp/ggml-backend.cpp +207 -17
  7. package/cpp/ggml-backend.h +17 -1
  8. package/cpp/ggml-cpu/amx/amx.cpp +4 -2
  9. package/cpp/ggml-cpu/arch/x86/repack.cpp +2 -2
  10. package/cpp/ggml-cpu/arch-fallback.h +0 -4
  11. package/cpp/ggml-cpu/common.h +14 -0
  12. package/cpp/ggml-cpu/ggml-cpu-impl.h +13 -6
  13. package/cpp/ggml-cpu/ggml-cpu.c +48 -41
  14. package/cpp/ggml-cpu/ggml-cpu.cpp +14 -4
  15. package/cpp/ggml-cpu/ops.cpp +518 -767
  16. package/cpp/ggml-cpu/ops.h +2 -0
  17. package/cpp/ggml-cpu/simd-mappings.h +88 -59
  18. package/cpp/ggml-cpu/vec.cpp +161 -20
  19. package/cpp/ggml-cpu/vec.h +400 -51
  20. package/cpp/ggml-cpu.h +1 -1
  21. package/cpp/ggml-impl.h +43 -10
  22. package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
  23. package/cpp/ggml-metal/ggml-metal-common.h +52 -0
  24. package/cpp/ggml-metal/ggml-metal-context.h +33 -0
  25. package/cpp/ggml-metal/ggml-metal-context.m +600 -0
  26. package/cpp/ggml-metal/ggml-metal-device.cpp +1376 -0
  27. package/cpp/ggml-metal/ggml-metal-device.h +226 -0
  28. package/cpp/ggml-metal/ggml-metal-device.m +1312 -0
  29. package/cpp/ggml-metal/ggml-metal-impl.h +722 -0
  30. package/cpp/ggml-metal/ggml-metal-ops.cpp +3158 -0
  31. package/cpp/ggml-metal/ggml-metal-ops.h +82 -0
  32. package/cpp/ggml-metal/ggml-metal.cpp +718 -0
  33. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  34. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  35. package/cpp/ggml-metal-impl.h +40 -40
  36. package/cpp/ggml-metal.h +1 -6
  37. package/cpp/ggml-quants.c +1 -0
  38. package/cpp/ggml.c +175 -13
  39. package/cpp/ggml.h +84 -5
  40. package/cpp/jsi/RNWhisperJSI.cpp +2 -0
  41. package/cpp/jsi/ThreadPool.h +3 -3
  42. package/cpp/whisper.cpp +85 -70
  43. package/cpp/whisper.h +1 -0
  44. package/ios/CMakeLists.txt +6 -1
  45. package/ios/RNWhisperVadContext.mm +14 -13
  46. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  47. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  48. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  49. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +43 -10
  50. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  51. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  52. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +84 -5
  53. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  54. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  55. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  56. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  57. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  58. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  59. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  60. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +43 -10
  61. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  62. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  63. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +84 -5
  64. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  65. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  66. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  67. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  68. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  70. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  71. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  72. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +43 -10
  73. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  74. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  75. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +84 -5
  76. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  77. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  78. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  79. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  80. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  81. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  82. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  83. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +43 -10
  84. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  85. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  86. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +84 -5
  87. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  88. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  89. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  90. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  91. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  92. package/lib/commonjs/version.json +1 -1
  93. package/lib/module/version.json +1 -1
  94. package/package.json +1 -1
  95. package/src/version.json +1 -1
  96. package/whisper-rn.podspec +8 -9
  97. package/cpp/ggml-metal.m +0 -6779
  98. package/cpp/ggml-whisper-sim.metallib +0 -0
  99. package/cpp/ggml-whisper.metallib +0 -0
package/cpp/ggml-metal-impl.h CHANGED
@@ -1,5 +1,5 @@
1
- #ifndef WSP_GGML_METAL_IMPL
2
- #define WSP_GGML_METAL_IMPL
1
+ #ifndef WSP_WSP_WSP_GGML_METAL_IMPL
2
+ #define WSP_WSP_WSP_GGML_METAL_IMPL
3
3
 
4
4
  // kernel parameters for mat-vec threadgroups
5
5
  //
@@ -101,7 +101,7 @@ typedef struct {
101
101
  uint64_t nb2;
102
102
  uint64_t nb3;
103
103
  int32_t dim;
104
- } wsp_ggml_metal_kargs_concat;
104
+ } wsp_wsp_wsp_ggml_metal_kargs_concat;
105
105
 
106
106
  typedef struct {
107
107
  int32_t ne00;
@@ -130,7 +130,7 @@ typedef struct {
130
130
  uint64_t nb3;
131
131
  uint64_t offs;
132
132
  uint64_t o1[8];
133
- } wsp_ggml_metal_kargs_bin;
133
+ } wsp_wsp_wsp_ggml_metal_kargs_bin;
134
134
 
135
135
  typedef struct {
136
136
  int64_t ne0;
@@ -139,7 +139,7 @@ typedef struct {
139
139
  size_t nb02;
140
140
  size_t nb11;
141
141
  size_t nb21;
142
- } wsp_ggml_metal_kargs_add_id;
142
+ } wsp_wsp_wsp_ggml_metal_kargs_add_id;
143
143
 
144
144
  typedef struct {
145
145
  int32_t ne00;
@@ -158,7 +158,7 @@ typedef struct {
158
158
  uint64_t nb1;
159
159
  uint64_t nb2;
160
160
  uint64_t nb3;
161
- } wsp_ggml_metal_kargs_repeat;
161
+ } wsp_wsp_wsp_ggml_metal_kargs_repeat;
162
162
 
163
163
  typedef struct {
164
164
  int64_t ne00;
@@ -177,7 +177,7 @@ typedef struct {
177
177
  uint64_t nb1;
178
178
  uint64_t nb2;
179
179
  uint64_t nb3;
180
- } wsp_ggml_metal_kargs_cpy;
180
+ } wsp_wsp_wsp_ggml_metal_kargs_cpy;
181
181
 
182
182
  typedef struct {
183
183
  int64_t ne10;
@@ -192,7 +192,7 @@ typedef struct {
192
192
  uint64_t nb3;
193
193
  uint64_t offs;
194
194
  bool inplace;
195
- } wsp_ggml_metal_kargs_set;
195
+ } wsp_wsp_wsp_ggml_metal_kargs_set;
196
196
 
197
197
  typedef struct {
198
198
  int32_t ne00;
@@ -224,7 +224,7 @@ typedef struct {
224
224
  int32_t sect_1;
225
225
  int32_t sect_2;
226
226
  int32_t sect_3;
227
- } wsp_ggml_metal_kargs_rope;
227
+ } wsp_wsp_wsp_ggml_metal_kargs_rope;
228
228
 
229
229
  typedef struct {
230
230
  int32_t ne01;
@@ -255,7 +255,7 @@ typedef struct {
255
255
  float m1;
256
256
  int32_t n_head_log2;
257
257
  float logit_softcap;
258
- } wsp_ggml_metal_kargs_flash_attn_ext;
258
+ } wsp_wsp_wsp_ggml_metal_kargs_flash_attn_ext;
259
259
 
260
260
  typedef struct {
261
261
  int32_t ne00;
@@ -272,7 +272,7 @@ typedef struct {
272
272
  int32_t ne1;
273
273
  int16_t r2;
274
274
  int16_t r3;
275
- } wsp_ggml_metal_kargs_mul_mm;
275
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mm;
276
276
 
277
277
  typedef struct {
278
278
  int32_t ne00;
@@ -293,7 +293,7 @@ typedef struct {
293
293
  int32_t ne1;
294
294
  int16_t r2;
295
295
  int16_t r3;
296
- } wsp_ggml_metal_kargs_mul_mv;
296
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mv;
297
297
 
298
298
  typedef struct {
299
299
  int32_t ne00;
@@ -317,7 +317,7 @@ typedef struct {
317
317
  int16_t nsg;
318
318
  int16_t nxpsg;
319
319
  int16_t r1ptg;
320
- } wsp_ggml_metal_kargs_mul_mv_ext;
320
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mv_ext;
321
321
 
322
322
  typedef struct {
323
323
  int32_t ne10;
@@ -328,7 +328,7 @@ typedef struct {
328
328
  uint64_t nbh11;
329
329
  int32_t ne20; // n_expert_used
330
330
  uint64_t nb21;
331
- } wsp_ggml_metal_kargs_mul_mm_id_map0;
331
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mm_id_map0;
332
332
 
333
333
  typedef struct {
334
334
  int32_t ne20; // n_expert_used
@@ -339,7 +339,7 @@ typedef struct {
339
339
  int32_t ne0;
340
340
  uint64_t nb1;
341
341
  uint64_t nb2;
342
- } wsp_ggml_metal_kargs_mul_mm_id_map1;
342
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mm_id_map1;
343
343
 
344
344
  typedef struct {
345
345
  int32_t ne00;
@@ -356,7 +356,7 @@ typedef struct {
356
356
  int32_t neh1;
357
357
  int16_t r2;
358
358
  int16_t r3;
359
- } wsp_ggml_metal_kargs_mul_mm_id;
359
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mm_id;
360
360
 
361
361
  typedef struct {
362
362
  int32_t nei0;
@@ -378,14 +378,14 @@ typedef struct {
378
378
  int32_t ne0;
379
379
  int32_t ne1;
380
380
  uint64_t nb1;
381
- } wsp_ggml_metal_kargs_mul_mv_id;
381
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mv_id;
382
382
 
383
383
  typedef struct {
384
384
  int32_t ne00;
385
385
  int32_t ne00_4;
386
386
  uint64_t nb01;
387
387
  float eps;
388
- } wsp_ggml_metal_kargs_norm;
388
+ } wsp_wsp_wsp_ggml_metal_kargs_norm;
389
389
 
390
390
  typedef struct {
391
391
  int32_t ne00;
@@ -400,14 +400,14 @@ typedef struct {
400
400
  uint64_t nbf1[3];
401
401
  uint64_t nbf2[3];
402
402
  uint64_t nbf3[3];
403
- } wsp_ggml_metal_kargs_rms_norm;
403
+ } wsp_wsp_wsp_ggml_metal_kargs_rms_norm;
404
404
 
405
405
  typedef struct {
406
406
  int32_t ne00;
407
407
  int32_t ne00_4;
408
408
  uint64_t nb01;
409
409
  float eps;
410
- } wsp_ggml_metal_kargs_l2_norm;
410
+ } wsp_wsp_wsp_ggml_metal_kargs_l2_norm;
411
411
 
412
412
  typedef struct {
413
413
  int64_t ne00;
@@ -418,7 +418,7 @@ typedef struct {
418
418
  uint64_t nb02;
419
419
  int32_t n_groups;
420
420
  float eps;
421
- } wsp_ggml_metal_kargs_group_norm;
421
+ } wsp_wsp_wsp_ggml_metal_kargs_group_norm;
422
422
 
423
423
  typedef struct {
424
424
  int32_t IC;
@@ -427,7 +427,7 @@ typedef struct {
427
427
  int32_t s0;
428
428
  uint64_t nb0;
429
429
  uint64_t nb1;
430
- } wsp_ggml_metal_kargs_conv_transpose_1d;
430
+ } wsp_wsp_wsp_ggml_metal_kargs_conv_transpose_1d;
431
431
 
432
432
  typedef struct {
433
433
  uint64_t ofs0;
@@ -445,7 +445,7 @@ typedef struct {
445
445
  int32_t KH;
446
446
  int32_t KW;
447
447
  int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
448
- } wsp_ggml_metal_kargs_im2col;
448
+ } wsp_wsp_wsp_ggml_metal_kargs_im2col;
449
449
 
450
450
  typedef struct{
451
451
  int32_t ne00;
@@ -458,7 +458,7 @@ typedef struct{
458
458
  int32_t i10;
459
459
  float alpha;
460
460
  float limit;
461
- } wsp_ggml_metal_kargs_glu;
461
+ } wsp_wsp_wsp_ggml_metal_kargs_glu;
462
462
 
463
463
  typedef struct {
464
464
  int64_t ne00;
@@ -485,7 +485,7 @@ typedef struct {
485
485
  uint64_t nb1;
486
486
  uint64_t nb2;
487
487
  uint64_t nb3;
488
- } wsp_ggml_metal_kargs_sum_rows;
488
+ } wsp_wsp_wsp_ggml_metal_kargs_sum_rows;
489
489
 
490
490
  typedef struct {
491
491
  int32_t ne00;
@@ -508,13 +508,13 @@ typedef struct {
508
508
  float m0;
509
509
  float m1;
510
510
  int32_t n_head_log2;
511
- } wsp_ggml_metal_kargs_soft_max;
511
+ } wsp_wsp_wsp_ggml_metal_kargs_soft_max;
512
512
 
513
513
  typedef struct {
514
514
  int64_t ne00;
515
515
  int64_t ne01;
516
516
  int n_past;
517
- } wsp_ggml_metal_kargs_diag_mask_inf;
517
+ } wsp_wsp_wsp_ggml_metal_kargs_diag_mask_inf;
518
518
 
519
519
  typedef struct {
520
520
  int64_t ne00;
@@ -533,7 +533,7 @@ typedef struct {
533
533
  uint64_t nb0;
534
534
  uint64_t nb1;
535
535
  uint64_t nb2;
536
- } wsp_ggml_metal_kargs_ssm_conv;
536
+ } wsp_wsp_wsp_ggml_metal_kargs_ssm_conv;
537
537
 
538
538
  typedef struct {
539
539
  int64_t d_state;
@@ -558,7 +558,7 @@ typedef struct {
558
558
  uint64_t nb51;
559
559
  uint64_t nb52;
560
560
  uint64_t nb53;
561
- } wsp_ggml_metal_kargs_ssm_scan;
561
+ } wsp_wsp_wsp_ggml_metal_kargs_ssm_scan;
562
562
 
563
563
  typedef struct {
564
564
  int64_t ne00;
@@ -569,7 +569,7 @@ typedef struct {
569
569
  uint64_t nb11;
570
570
  uint64_t nb1;
571
571
  uint64_t nb2;
572
- } wsp_ggml_metal_kargs_get_rows;
572
+ } wsp_wsp_wsp_ggml_metal_kargs_get_rows;
573
573
 
574
574
  typedef struct {
575
575
  int32_t nk0;
@@ -585,7 +585,7 @@ typedef struct {
585
585
  uint64_t nb1;
586
586
  uint64_t nb2;
587
587
  uint64_t nb3;
588
- } wsp_ggml_metal_kargs_set_rows;
588
+ } wsp_wsp_wsp_ggml_metal_kargs_set_rows;
589
589
 
590
590
  typedef struct {
591
591
  int64_t ne00;
@@ -608,7 +608,7 @@ typedef struct {
608
608
  float sf1;
609
609
  float sf2;
610
610
  float sf3;
611
- } wsp_ggml_metal_kargs_upscale;
611
+ } wsp_wsp_wsp_ggml_metal_kargs_upscale;
612
612
 
613
613
  typedef struct {
614
614
  int64_t ne00;
@@ -627,7 +627,7 @@ typedef struct {
627
627
  uint64_t nb1;
628
628
  uint64_t nb2;
629
629
  uint64_t nb3;
630
- } wsp_ggml_metal_kargs_pad;
630
+ } wsp_wsp_wsp_ggml_metal_kargs_pad;
631
631
 
632
632
  typedef struct {
633
633
  int64_t ne00;
@@ -648,28 +648,28 @@ typedef struct {
648
648
  uint64_t nb3;
649
649
  int32_t p0;
650
650
  int32_t p1;
651
- } wsp_ggml_metal_kargs_pad_reflect_1d;
651
+ } wsp_wsp_wsp_ggml_metal_kargs_pad_reflect_1d;
652
652
 
653
653
  typedef struct {
654
654
  uint64_t nb1;
655
655
  int dim;
656
656
  int max_period;
657
- } wsp_ggml_metal_kargs_timestep_embedding;
657
+ } wsp_wsp_wsp_ggml_metal_kargs_timestep_embedding;
658
658
 
659
659
  typedef struct {
660
660
  float slope;
661
- } wsp_ggml_metal_kargs_leaky_relu;
661
+ } wsp_wsp_wsp_ggml_metal_kargs_leaky_relu;
662
662
 
663
663
  typedef struct {
664
664
  int64_t ncols;
665
665
  int64_t ncols_pad;
666
- } wsp_ggml_metal_kargs_argsort;
666
+ } wsp_wsp_wsp_ggml_metal_kargs_argsort;
667
667
 
668
668
  typedef struct {
669
669
  int64_t ne0;
670
670
  float start;
671
671
  float step;
672
- } wsp_ggml_metal_kargs_arange;
672
+ } wsp_wsp_wsp_ggml_metal_kargs_arange;
673
673
 
674
674
  typedef struct {
675
675
  int32_t k0;
@@ -683,6 +683,6 @@ typedef struct {
683
683
  int64_t OH;
684
684
  int64_t OW;
685
685
  int64_t parallel_elements;
686
- } wsp_ggml_metal_kargs_pool_2d;
686
+ } wsp_wsp_wsp_ggml_metal_kargs_pool_2d;
687
687
 
688
- #endif // WSP_GGML_METAL_IMPL
688
+ #endif // WSP_WSP_WSP_GGML_METAL_IMPL
package/cpp/ggml-metal.h CHANGED
@@ -39,18 +39,13 @@ extern "C" {
39
39
  // user-code should use only these functions
40
40
  //
41
41
 
42
+ // TODO: remove in the future
42
43
  WSP_GGML_BACKEND_API wsp_ggml_backend_t wsp_ggml_backend_metal_init(void);
43
44
 
44
45
  WSP_GGML_BACKEND_API bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend);
45
46
 
46
- WSP_GGML_DEPRECATED(
47
- WSP_GGML_BACKEND_API wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
48
- "obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713");
49
-
50
47
  WSP_GGML_BACKEND_API void wsp_ggml_backend_metal_set_abort_callback(wsp_ggml_backend_t backend, wsp_ggml_abort_callback abort_callback, void * user_data);
51
48
 
52
- WSP_GGML_BACKEND_API wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void);
53
-
54
49
  // helper to check if the device supports a specific family
55
50
  // ideally, the user code should be doing these checks
56
51
  // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
package/cpp/ggml-quants.c CHANGED
@@ -3721,6 +3721,7 @@ static void wsp_quantize_row_iq3_xxs_impl(int grid_size, const float * WSP_GGML_
3721
3721
  }
3722
3722
  float best = 0;
3723
3723
  float scale = max/(2*kMaxQ-1);
3724
+ for (int k = 0; k < 8; ++k) is_on_grid[k] = true;
3724
3725
  for (int is = -15; is <= 15; ++is) {
3725
3726
  float id = (2*kMaxQ-1+is*0.2f)/max;
3726
3727
  float this_scale = 1/id;
package/cpp/ggml.c CHANGED
@@ -982,7 +982,9 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
982
982
  "CONV_TRANSPOSE_1D",
983
983
  "IM2COL",
984
984
  "IM2COL_BACK",
985
+ "IM2COL_3D",
985
986
  "CONV_2D",
987
+ "CONV_3D",
986
988
  "CONV_2D_DW",
987
989
  "CONV_TRANSPOSE_2D",
988
990
  "POOL_1D",
@@ -1025,7 +1027,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
1025
1027
  "GLU",
1026
1028
  };
1027
1029
 
1028
- static_assert(WSP_GGML_OP_COUNT == 88, "WSP_GGML_OP_COUNT != 88");
1030
+ static_assert(WSP_GGML_OP_COUNT == 90, "WSP_GGML_OP_COUNT != 90");
1029
1031
 
1030
1032
  static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1031
1033
  "none",
@@ -1084,7 +1086,9 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1084
1086
  "conv_transpose_1d(x)",
1085
1087
  "im2col(x)",
1086
1088
  "im2col_back(x)",
1089
+ "im2col_3d(x)",
1087
1090
  "conv_2d(x)",
1091
+ "conv_3d(x)",
1088
1092
  "conv_2d_dw(x)",
1089
1093
  "conv_transpose_2d(x)",
1090
1094
  "pool_1d(x)",
@@ -1127,7 +1131,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1127
1131
  "glu(x)",
1128
1132
  };
1129
1133
 
1130
- static_assert(WSP_GGML_OP_COUNT == 88, "WSP_GGML_OP_COUNT != 88");
1134
+ static_assert(WSP_GGML_OP_COUNT == 90, "WSP_GGML_OP_COUNT != 90");
1131
1135
 
1132
1136
  static_assert(WSP_GGML_OP_POOL_COUNT == 2, "WSP_GGML_OP_POOL_COUNT != 2");
1133
1137
 
@@ -3627,6 +3631,7 @@ struct wsp_ggml_tensor * wsp_ggml_get_rows(
3627
3631
  struct wsp_ggml_tensor * a,
3628
3632
  struct wsp_ggml_tensor * b) {
3629
3633
  WSP_GGML_ASSERT(a->ne[2] == b->ne[1]);
3634
+ WSP_GGML_ASSERT(a->ne[3] == b->ne[2]);
3630
3635
  WSP_GGML_ASSERT(b->ne[3] == 1);
3631
3636
  WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_I32);
3632
3637
 
@@ -3680,7 +3685,7 @@ struct wsp_ggml_tensor * wsp_ggml_set_rows(
3680
3685
  WSP_GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
3681
3686
  WSP_GGML_ASSERT(c->ne[3] == 1);
3682
3687
  WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_F32);
3683
- WSP_GGML_ASSERT(c->type == WSP_GGML_TYPE_I64);
3688
+ WSP_GGML_ASSERT(c->type == WSP_GGML_TYPE_I64 || c->type == WSP_GGML_TYPE_I32);
3684
3689
 
3685
3690
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(a));
3686
3691
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(b));
@@ -3690,6 +3695,7 @@ struct wsp_ggml_tensor * wsp_ggml_set_rows(
3690
3695
  result->op = WSP_GGML_OP_SET_ROWS;
3691
3696
  result->src[0] = b;
3692
3697
  result->src[1] = c;
3698
+ result->src[2] = a; // note: order is weird due to legacy reasons (https://github.com/ggml-org/llama.cpp/pull/16063#discussion_r2385795931)
3693
3699
 
3694
3700
  return result;
3695
3701
  }
@@ -3930,7 +3936,7 @@ static struct wsp_ggml_tensor * wsp_ggml_rope_impl(
3930
3936
  memcpy(params + 8, &attn_factor, sizeof(float));
3931
3937
  memcpy(params + 9, &beta_fast, sizeof(float));
3932
3938
  memcpy(params + 10, &beta_slow, sizeof(float));
3933
- if (mrope_used) {
3939
+ if (mrope_used && sections) {
3934
3940
  memcpy(params + 11, sections, sizeof(int32_t) * WSP_GGML_MROPE_SECTIONS);
3935
3941
  } else {
3936
3942
  memset(params + 11, 0, sizeof(int32_t) * WSP_GGML_MROPE_SECTIONS);
@@ -4367,6 +4373,91 @@ struct wsp_ggml_tensor * wsp_ggml_conv_2d(
4367
4373
  return result;
4368
4374
  }
4369
4375
 
4376
+ // a: [OC*IC, KD, KH, KW]
4377
+ // b: [N*IC, ID, IH, IW]
4378
+ // result: [N*OD, OH, OW, IC * KD * KH * KW]
4379
+ struct wsp_ggml_tensor * wsp_ggml_im2col_3d(
4380
+ struct wsp_ggml_context * ctx,
4381
+ struct wsp_ggml_tensor * a,
4382
+ struct wsp_ggml_tensor * b,
4383
+ int64_t IC,
4384
+ int s0, // stride width
4385
+ int s1, // stride height
4386
+ int s2, // stride depth
4387
+ int p0, // padding width
4388
+ int p1, // padding height
4389
+ int p2, // padding depth
4390
+ int d0, // dilation width
4391
+ int d1, // dilation height
4392
+ int d2, // dilation depth
4393
+ enum wsp_ggml_type dst_type) {
4394
+ const int64_t N = b->ne[3] / IC;
4395
+ const int64_t ID = b->ne[2];
4396
+ const int64_t IH = b->ne[1];
4397
+ const int64_t IW = b->ne[0];
4398
+
4399
+ const int64_t OC = a->ne[3] / IC;
4400
+ UNUSED(OC);
4401
+ const int64_t KD = a->ne[2];
4402
+ const int64_t KH = a->ne[1];
4403
+ const int64_t KW = a->ne[0];
4404
+ const int64_t OD = wsp_ggml_calc_conv_output_size(ID, KD, s2, p2, d2);
4405
+ const int64_t OH = wsp_ggml_calc_conv_output_size(IH, KH, s1, p1, d1);
4406
+ const int64_t OW = wsp_ggml_calc_conv_output_size(IW, KW, s0, p0, d0);
4407
+
4408
+ WSP_GGML_ASSERT((OD > 0) && "b too small compared to a");
4409
+ WSP_GGML_ASSERT((OH > 0) && "b too small compared to a");
4410
+ WSP_GGML_ASSERT((OW > 0) && "b too small compared to a");
4411
+
4412
+
4413
+ const int64_t ne[4] = {KW*KH*KD*IC, OW, OH, OD*N};
4414
+
4415
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, dst_type, 4, ne);
4416
+ int32_t params[] = { s0, s1, s2, p0, p1, p2, d0, d1, d2, (int32_t)IC};
4417
+ wsp_ggml_set_op_params(result, params, sizeof(params));
4418
+
4419
+ result->op = WSP_GGML_OP_IM2COL_3D;
4420
+ result->src[0] = a;
4421
+ result->src[1] = b;
4422
+
4423
+ return result;
4424
+ }
4425
+
4426
+ // a: [OC*IC, KD, KH, KW]
4427
+ // b: [N*IC, ID, IH, IW]
4428
+ // result: [N*OC, OD, OH, OW]
4429
+ struct wsp_ggml_tensor * wsp_ggml_conv_3d(
4430
+ struct wsp_ggml_context * ctx,
4431
+ struct wsp_ggml_tensor * a,
4432
+ struct wsp_ggml_tensor * b,
4433
+ int64_t IC,
4434
+ int s0, // stride width
4435
+ int s1, // stride height
4436
+ int s2, // stride depth
4437
+ int p0, // padding width
4438
+ int p1, // padding height
4439
+ int p2, // padding depth
4440
+ int d0, // dilation width
4441
+ int d1, // dilation height
4442
+ int d2 // dilation depth
4443
+ ) {
4444
+ struct wsp_ggml_tensor * im2col = wsp_ggml_im2col_3d(ctx, a, b, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, a->type); // [N*OD, OH, OW, IC * KD * KH * KW]
4445
+
4446
+ int64_t OC = a->ne[3] / IC;
4447
+ int64_t N = b->ne[3] / IC;
4448
+ struct wsp_ggml_tensor * result =
4449
+ wsp_ggml_mul_mat(ctx,
4450
+ wsp_ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N*OD, OH, OW, IC * KD * KH * KW] => [N*OD*OH*OW, IC * KD * KH * KW]
4451
+ wsp_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2] * IC), OC)); // [OC*IC, KD, KH, KW] => [OC, IC * KD * KH * KW]
4452
+
4453
+ int64_t OD = im2col->ne[3] / N;
4454
+ result = wsp_ggml_reshape_4d(ctx, result, im2col->ne[1]*im2col->ne[2], OD, N, OC); // [OC, N*OD*OH*OW] => [OC, N, OD, OH*OW]
4455
+ result = wsp_ggml_cont(ctx, wsp_ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OD, OH*OW]
4456
+ result = wsp_ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], OD, OC * N); // [N*OC, OD, OH, OW]
4457
+
4458
+ return result;
4459
+ }
4460
+
4370
4461
  // wsp_ggml_conv_2d_sk_p0
4371
4462
 
4372
4463
  struct wsp_ggml_tensor * wsp_ggml_conv_2d_sk_p0(
@@ -4488,6 +4579,56 @@ struct wsp_ggml_tensor * wsp_ggml_conv_2d_direct(
4488
4579
  return result;
4489
4580
  }
4490
4581
 
4582
+ // wsp_ggml_conv_3d_direct
4583
+
4584
+ struct wsp_ggml_tensor * wsp_ggml_conv_3d_direct(
4585
+ struct wsp_ggml_context * ctx,
4586
+ struct wsp_ggml_tensor * a,
4587
+ struct wsp_ggml_tensor * b,
4588
+ int s0,
4589
+ int s1,
4590
+ int s2,
4591
+ int p0,
4592
+ int p1,
4593
+ int p2,
4594
+ int d0,
4595
+ int d1,
4596
+ int d2,
4597
+ int c,
4598
+ int n,
4599
+ int oc) {
4600
+
4601
+ WSP_GGML_ASSERT(a->ne[3] == (int64_t) c * oc);
4602
+ WSP_GGML_ASSERT(b->ne[3] == (int64_t) c * n);
4603
+
4604
+ int64_t ne[4];
4605
+ ne[0] = wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
4606
+ ne[1] = wsp_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
4607
+ ne[2] = wsp_ggml_calc_conv_output_size(b->ne[2], a->ne[2], s2, p2, d2);
4608
+ ne[3] = (int64_t) oc * n;
4609
+
4610
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne);
4611
+
4612
+ wsp_ggml_set_op_params_i32(result, 0, s0);
4613
+ wsp_ggml_set_op_params_i32(result, 1, s1);
4614
+ wsp_ggml_set_op_params_i32(result, 2, s2);
4615
+ wsp_ggml_set_op_params_i32(result, 3, p0);
4616
+ wsp_ggml_set_op_params_i32(result, 4, p1);
4617
+ wsp_ggml_set_op_params_i32(result, 5, p2);
4618
+ wsp_ggml_set_op_params_i32(result, 6, d0);
4619
+ wsp_ggml_set_op_params_i32(result, 7, d1);
4620
+ wsp_ggml_set_op_params_i32(result, 8, d2);
4621
+ wsp_ggml_set_op_params_i32(result, 9, c);
4622
+ wsp_ggml_set_op_params_i32(result, 10, n);
4623
+ wsp_ggml_set_op_params_i32(result, 11, oc);
4624
+
4625
+ result->op = WSP_GGML_OP_CONV_3D;
4626
+ result->src[0] = a;
4627
+ result->src[1] = b;
4628
+
4629
+ return result;
4630
+ }
4631
+
4491
4632
  // wsp_ggml_conv_transpose_2d_p0
4492
4633
 
4493
4634
  static int64_t wsp_ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
@@ -4666,11 +4807,36 @@ struct wsp_ggml_tensor * wsp_ggml_pad(
4666
4807
  int p1,
4667
4808
  int p2,
4668
4809
  int p3) {
4810
+ return wsp_ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
4811
+ }
4812
+
4813
+ struct wsp_ggml_tensor * wsp_ggml_pad_ext(
4814
+ struct wsp_ggml_context * ctx,
4815
+ struct wsp_ggml_tensor * a,
4816
+ int lp0,
4817
+ int rp0,
4818
+ int lp1,
4819
+ int rp1,
4820
+ int lp2,
4821
+ int rp2,
4822
+ int lp3,
4823
+ int rp3
4824
+ ) {
4669
4825
  struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_4d(ctx, a->type,
4670
- a->ne[0] + p0,
4671
- a->ne[1] + p1,
4672
- a->ne[2] + p2,
4673
- a->ne[3] + p3);
4826
+ a->ne[0] + lp0 + rp0,
4827
+ a->ne[1] + lp1 + rp1,
4828
+ a->ne[2] + lp2 + rp2,
4829
+ a->ne[3] + lp3 + rp3);
4830
+
4831
+ wsp_ggml_set_op_params_i32(result, 0, lp0);
4832
+ wsp_ggml_set_op_params_i32(result, 1, rp0);
4833
+ wsp_ggml_set_op_params_i32(result, 2, lp1);
4834
+ wsp_ggml_set_op_params_i32(result, 3, rp1);
4835
+ wsp_ggml_set_op_params_i32(result, 4, lp2);
4836
+ wsp_ggml_set_op_params_i32(result, 5, rp2);
4837
+ wsp_ggml_set_op_params_i32(result, 6, lp3);
4838
+ wsp_ggml_set_op_params_i32(result, 7, rp3);
4839
+
4674
4840
 
4675
4841
  result->op = WSP_GGML_OP_PAD;
4676
4842
  result->src[0] = a;
@@ -4766,12 +4932,8 @@ struct wsp_ggml_tensor * wsp_ggml_timestep_embedding(
4766
4932
  struct wsp_ggml_tensor * timesteps,
4767
4933
  int dim,
4768
4934
  int max_period) {
4769
- int actual_dim = dim;
4770
- if (dim % 2 != 0) {
4771
- actual_dim = dim + 1;
4772
- }
4773
4935
 
4774
- struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, actual_dim, timesteps->ne[0]);
4936
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, dim, timesteps->ne[0]);
4775
4937
 
4776
4938
  wsp_ggml_set_op_params_i32(result, 0, dim);
4777
4939
  wsp_ggml_set_op_params_i32(result, 1, max_period);