whisper.rn 0.5.0-rc.9 → 0.5.1

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the respective public registry.
Files changed (136)
  1. package/android/build.gradle +2 -1
  2. package/android/gradle.properties +1 -1
  3. package/cpp/ggml-alloc.c +265 -141
  4. package/cpp/ggml-backend-impl.h +4 -1
  5. package/cpp/ggml-backend-reg.cpp +30 -13
  6. package/cpp/ggml-backend.cpp +221 -38
  7. package/cpp/ggml-backend.h +17 -1
  8. package/cpp/ggml-common.h +17 -0
  9. package/cpp/ggml-cpu/amx/amx.cpp +4 -2
  10. package/cpp/ggml-cpu/arch/arm/quants.c +132 -596
  11. package/cpp/ggml-cpu/arch/arm/repack.cpp +14 -286
  12. package/cpp/ggml-cpu/arch/x86/quants.c +184 -675
  13. package/cpp/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  14. package/cpp/ggml-cpu/arch-fallback.h +32 -2
  15. package/cpp/ggml-cpu/common.h +14 -0
  16. package/cpp/ggml-cpu/ggml-cpu-impl.h +13 -6
  17. package/cpp/ggml-cpu/ggml-cpu.c +70 -42
  18. package/cpp/ggml-cpu/ggml-cpu.cpp +35 -28
  19. package/cpp/ggml-cpu/ops.cpp +1587 -1177
  20. package/cpp/ggml-cpu/ops.h +5 -8
  21. package/cpp/ggml-cpu/quants.c +35 -0
  22. package/cpp/ggml-cpu/quants.h +8 -0
  23. package/cpp/ggml-cpu/repack.cpp +458 -47
  24. package/cpp/ggml-cpu/repack.h +22 -0
  25. package/cpp/ggml-cpu/simd-mappings.h +89 -60
  26. package/cpp/ggml-cpu/traits.cpp +2 -2
  27. package/cpp/ggml-cpu/traits.h +1 -1
  28. package/cpp/ggml-cpu/vec.cpp +170 -26
  29. package/cpp/ggml-cpu/vec.h +506 -63
  30. package/cpp/ggml-cpu.h +1 -1
  31. package/cpp/ggml-impl.h +119 -9
  32. package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
  33. package/cpp/ggml-metal/ggml-metal-common.h +52 -0
  34. package/cpp/ggml-metal/ggml-metal-context.h +33 -0
  35. package/cpp/ggml-metal/ggml-metal-context.m +600 -0
  36. package/cpp/ggml-metal/ggml-metal-device.cpp +1376 -0
  37. package/cpp/ggml-metal/ggml-metal-device.h +226 -0
  38. package/cpp/ggml-metal/ggml-metal-device.m +1312 -0
  39. package/cpp/ggml-metal/ggml-metal-impl.h +722 -0
  40. package/cpp/ggml-metal/ggml-metal-ops.cpp +3158 -0
  41. package/cpp/ggml-metal/ggml-metal-ops.h +82 -0
  42. package/cpp/ggml-metal/ggml-metal.cpp +718 -0
  43. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  44. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  45. package/cpp/ggml-metal-impl.h +90 -51
  46. package/cpp/ggml-metal.h +1 -6
  47. package/cpp/ggml-opt.cpp +97 -41
  48. package/cpp/ggml-opt.h +25 -6
  49. package/cpp/ggml-quants.c +111 -16
  50. package/cpp/ggml-quants.h +6 -0
  51. package/cpp/ggml.c +486 -98
  52. package/cpp/ggml.h +221 -16
  53. package/cpp/gguf.cpp +8 -1
  54. package/cpp/jsi/RNWhisperJSI.cpp +25 -6
  55. package/cpp/jsi/ThreadPool.h +3 -3
  56. package/cpp/whisper.cpp +100 -76
  57. package/cpp/whisper.h +1 -0
  58. package/ios/CMakeLists.txt +6 -1
  59. package/ios/RNWhisper.mm +6 -6
  60. package/ios/RNWhisperContext.mm +2 -0
  61. package/ios/RNWhisperVadContext.mm +16 -13
  62. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  63. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  64. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
  65. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  66. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +119 -9
  67. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
  68. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  69. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  70. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  71. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +221 -16
  72. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  73. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  74. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  75. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  76. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  77. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  78. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
  79. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  80. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +119 -9
  81. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
  82. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  83. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  84. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  85. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +221 -16
  86. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  87. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  88. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  89. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  90. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  91. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  92. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  93. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
  94. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  95. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +119 -9
  96. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
  97. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  98. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  99. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  100. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +221 -16
  101. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  102. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  103. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  104. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  105. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  106. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  107. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
  108. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  109. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +119 -9
  110. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
  111. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  112. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  113. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  114. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +221 -16
  115. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  116. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  117. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  118. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  119. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  120. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +13 -0
  121. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  122. package/lib/commonjs/version.json +1 -1
  123. package/lib/module/realtime-transcription/RealtimeTranscriber.js +13 -0
  124. package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  125. package/lib/module/version.json +1 -1
  126. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
  127. package/lib/typescript/realtime-transcription/types.d.ts +6 -0
  128. package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
  129. package/package.json +1 -1
  130. package/src/realtime-transcription/RealtimeTranscriber.ts +17 -0
  131. package/src/realtime-transcription/types.ts +6 -0
  132. package/src/version.json +1 -1
  133. package/whisper-rn.podspec +8 -9
  134. package/cpp/ggml-metal.m +0 -6284
  135. package/cpp/ggml-whisper-sim.metallib +0 -0
  136. package/cpp/ggml-whisper.metallib +0 -0
@@ -73,6 +73,35 @@ static inline int wsp_ggml_up(int n, int m) {
73
73
  return (n + m - 1) & ~(m - 1);
74
74
  }
75
75
 
76
+ // TODO: move to ggml.h? (won't be able to inline)
77
+ static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b) {
78
+ if (a->type != b->type) {
79
+ return false;
80
+ }
81
+ for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
82
+ if (a->ne[i] != b->ne[i]) {
83
+ return false;
84
+ }
85
+ if (a->nb[i] != b->nb[i]) {
86
+ return false;
87
+ }
88
+ }
89
+ return true;
90
+ }
91
+
92
+ static bool wsp_ggml_op_is_empty(enum wsp_ggml_op op) {
93
+ switch (op) {
94
+ case WSP_GGML_OP_NONE:
95
+ case WSP_GGML_OP_RESHAPE:
96
+ case WSP_GGML_OP_TRANSPOSE:
97
+ case WSP_GGML_OP_VIEW:
98
+ case WSP_GGML_OP_PERMUTE:
99
+ return true;
100
+ default:
101
+ return false;
102
+ }
103
+ }
104
+
76
105
  //
77
106
  // logging
78
107
  //
@@ -313,6 +342,10 @@ struct wsp_ggml_cgraph {
313
342
  // if you need the gradients, get them from the original graph
314
343
  struct wsp_ggml_cgraph wsp_ggml_graph_view(struct wsp_ggml_cgraph * cgraph, int i0, int i1);
315
344
 
345
+ // ggml-alloc.c: true if the operation can reuse memory from its sources
346
+ WSP_GGML_API bool wsp_ggml_op_can_inplace(enum wsp_ggml_op op);
347
+
348
+
316
349
  // Memory allocation
317
350
 
318
351
  WSP_GGML_API void * wsp_ggml_aligned_malloc(size_t size);
@@ -394,6 +427,67 @@ static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
394
427
  #define WSP_GGML_FP16_TO_FP32(x) WSP_GGML_COMPUTE_FP16_TO_FP32(x)
395
428
  #define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x)
396
429
 
430
+ static inline float wsp_ggml_e8m0_to_fp32(uint8_t x) {
431
+ uint32_t bits; // Stores the raw bit representation of the float
432
+
433
+ // Handle special case for minimum exponent (denormalized float)
434
+ if (x == 0) {
435
+ // Bit pattern for 2^(-127):
436
+ // - Sign bit: 0 (positive)
437
+ // - Exponent: 0 (denormalized number)
438
+ // - Mantissa: 0x400000 (0.5 in fractional form)
439
+ // Value = 0.5 * 2^(-126) = 2^(-127)
440
+ bits = 0x00400000;
441
+ }
442
+ // note: disabled as we don't need to handle NaNs
443
+ //// Handle special case for NaN (all bits set)
444
+ //else if (x == 0xFF) {
445
+ // // Standard quiet NaN pattern:
446
+ // // - Sign bit: 0
447
+ // // - Exponent: all 1s (0xFF)
448
+ // // - Mantissa: 0x400000 (quiet NaN flag)
449
+ // bits = 0x7FC00000;
450
+ //}
451
+ // Normalized values (most common case)
452
+ else {
453
+ // Construct normalized float by shifting exponent into position:
454
+ // - Exponent field: 8 bits (positions 30-23)
455
+ // - Mantissa: 0 (implicit leading 1)
456
+ // Value = 2^(x - 127)
457
+ bits = (uint32_t) x << 23;
458
+ }
459
+
460
+ float result; // Final float value
461
+ // Safely reinterpret bit pattern as float without type-punning issues
462
+ memcpy(&result, &bits, sizeof(float));
463
+ return result;
464
+ }
465
+
466
+ // Equal to wsp_ggml_e8m0_to_fp32/2
467
+ // Useful with MXFP4 quantization since the E0M2 values are doubled
468
+ static inline float wsp_ggml_e8m0_to_fp32_half(uint8_t x) {
469
+ uint32_t bits;
470
+
471
+ // For x < 2: use precomputed denormal patterns
472
+ if (x < 2) {
473
+ // 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)
474
+ bits = 0x00200000 << x;
475
+ }
476
+ // For x >= 2: normalized exponent adjustment
477
+ else {
478
+ // 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1)
479
+ bits = (uint32_t)(x - 1) << 23;
480
+ }
481
+ // Note: NaNs are not handled here
482
+
483
+ float result;
484
+ memcpy(&result, &bits, sizeof(float));
485
+ return result;
486
+ }
487
+
488
+ #define WSP_GGML_E8M0_TO_FP32(x) wsp_ggml_e8m0_to_fp32(x)
489
+ #define WSP_GGML_E8M0_TO_FP32_HALF(x) wsp_ggml_e8m0_to_fp32_half(x)
490
+
397
491
  /**
398
492
  * Converts brain16 to float32.
399
493
  *
@@ -493,27 +587,27 @@ static inline bool wsp_ggml_node_has_n_uses(const struct wsp_ggml_cgraph * cgrap
493
587
  return true;
494
588
  }
495
589
 
496
- // Returns true if nodes [i, i+ops.size()) are the sequence of wsp_ggml_ops in ops[]
590
+ // Returns true if nodes with indices { node_idxs } are the sequence of wsp_ggml_ops in ops[]
497
591
  // and are fusable. Nodes are considered fusable according to this function if:
498
592
  // - all nodes except the last have only one use and are not views/outputs (see wsp_ggml_node_has_N_uses).
499
593
  // - all nodes except the last are a src of the following node.
500
594
  // - all nodes are the same shape.
501
595
  // TODO: Consider allowing WSP_GGML_OP_NONE nodes in between
502
- static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_idx, const enum wsp_ggml_op * ops, int num_ops) {
503
- if (node_idx + num_ops > cgraph->n_nodes) {
504
- return false;
505
- }
506
-
596
+ static inline bool wsp_ggml_can_fuse_ext(const struct wsp_ggml_cgraph * cgraph, const int * node_idxs, const enum wsp_ggml_op * ops, int num_ops) {
507
597
  for (int i = 0; i < num_ops; ++i) {
508
- struct wsp_ggml_tensor * node = cgraph->nodes[node_idx + i];
598
+ if (node_idxs[i] >= cgraph->n_nodes) {
599
+ return false;
600
+ }
601
+
602
+ struct wsp_ggml_tensor * node = cgraph->nodes[node_idxs[i]];
509
603
  if (node->op != ops[i]) {
510
604
  return false;
511
605
  }
512
- if (i < num_ops - 1 && !wsp_ggml_node_has_n_uses(cgraph, node_idx + i, 1)) {
606
+ if (i < num_ops - 1 && !wsp_ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) {
513
607
  return false;
514
608
  }
515
609
  if (i > 0) {
516
- struct wsp_ggml_tensor * prev = cgraph->nodes[node_idx + i - 1];
610
+ struct wsp_ggml_tensor * prev = cgraph->nodes[node_idxs[i - 1]];
517
611
  if (node->src[0] != prev && node->src[1] != prev) {
518
612
  return false;
519
613
  }
@@ -525,6 +619,22 @@ static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int
525
619
  return true;
526
620
  }
527
621
 
622
+ // same as above, for sequential indices starting at node_idx
623
+ static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_idx, const enum wsp_ggml_op * ops, int num_ops) {
624
+ assert(num_ops < 32);
625
+
626
+ if (node_idx + num_ops > cgraph->n_nodes) {
627
+ return false;
628
+ }
629
+
630
+ int idxs[32];
631
+ for (int i = 0; i < num_ops; ++i) {
632
+ idxs[i] = node_idx + i;
633
+ }
634
+
635
+ return wsp_ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
636
+ }
637
+
528
638
  #ifdef __cplusplus
529
639
  }
530
640
  #endif
@@ -1,5 +1,5 @@
1
- #ifndef WSP_GGML_METAL_IMPL
2
- #define WSP_GGML_METAL_IMPL
1
+ #ifndef WSP_WSP_WSP_GGML_METAL_IMPL
2
+ #define WSP_WSP_WSP_GGML_METAL_IMPL
3
3
 
4
4
  // kernel parameters for mat-vec threadgroups
5
5
  //
@@ -23,6 +23,9 @@
23
23
  #define N_R0_Q8_0 4
24
24
  #define N_SG_Q8_0 2
25
25
 
26
+ #define N_R0_MXFP4 2
27
+ #define N_SG_MXFP4 2
28
+
26
29
  #define N_R0_Q2_K 4
27
30
  #define N_SG_Q2_K 2
28
31
 
@@ -98,7 +101,7 @@ typedef struct {
98
101
  uint64_t nb2;
99
102
  uint64_t nb3;
100
103
  int32_t dim;
101
- } wsp_ggml_metal_kargs_concat;
104
+ } wsp_wsp_wsp_ggml_metal_kargs_concat;
102
105
 
103
106
  typedef struct {
104
107
  int32_t ne00;
@@ -126,7 +129,17 @@ typedef struct {
126
129
  uint64_t nb2;
127
130
  uint64_t nb3;
128
131
  uint64_t offs;
129
- } wsp_ggml_metal_kargs_bin;
132
+ uint64_t o1[8];
133
+ } wsp_wsp_wsp_ggml_metal_kargs_bin;
134
+
135
+ typedef struct {
136
+ int64_t ne0;
137
+ int64_t ne1;
138
+ size_t nb01;
139
+ size_t nb02;
140
+ size_t nb11;
141
+ size_t nb21;
142
+ } wsp_wsp_wsp_ggml_metal_kargs_add_id;
130
143
 
131
144
  typedef struct {
132
145
  int32_t ne00;
@@ -145,7 +158,7 @@ typedef struct {
145
158
  uint64_t nb1;
146
159
  uint64_t nb2;
147
160
  uint64_t nb3;
148
- } wsp_ggml_metal_kargs_repeat;
161
+ } wsp_wsp_wsp_ggml_metal_kargs_repeat;
149
162
 
150
163
  typedef struct {
151
164
  int64_t ne00;
@@ -164,7 +177,7 @@ typedef struct {
164
177
  uint64_t nb1;
165
178
  uint64_t nb2;
166
179
  uint64_t nb3;
167
- } wsp_ggml_metal_kargs_cpy;
180
+ } wsp_wsp_wsp_ggml_metal_kargs_cpy;
168
181
 
169
182
  typedef struct {
170
183
  int64_t ne10;
@@ -179,7 +192,7 @@ typedef struct {
179
192
  uint64_t nb3;
180
193
  uint64_t offs;
181
194
  bool inplace;
182
- } wsp_ggml_metal_kargs_set;
195
+ } wsp_wsp_wsp_ggml_metal_kargs_set;
183
196
 
184
197
  typedef struct {
185
198
  int32_t ne00;
@@ -211,7 +224,7 @@ typedef struct {
211
224
  int32_t sect_1;
212
225
  int32_t sect_2;
213
226
  int32_t sect_3;
214
- } wsp_ggml_metal_kargs_rope;
227
+ } wsp_wsp_wsp_ggml_metal_kargs_rope;
215
228
 
216
229
  typedef struct {
217
230
  int32_t ne01;
@@ -229,16 +242,20 @@ typedef struct {
229
242
  uint64_t nb21;
230
243
  uint64_t nb22;
231
244
  uint64_t nb23;
245
+ int32_t ne32;
246
+ int32_t ne33;
232
247
  uint64_t nb31;
248
+ uint64_t nb32;
249
+ uint64_t nb33;
233
250
  int32_t ne1;
234
251
  int32_t ne2;
235
252
  float scale;
236
253
  float max_bias;
237
254
  float m0;
238
255
  float m1;
239
- uint16_t n_head_log2;
256
+ int32_t n_head_log2;
240
257
  float logit_softcap;
241
- } wsp_ggml_metal_kargs_flash_attn_ext;
258
+ } wsp_wsp_wsp_ggml_metal_kargs_flash_attn_ext;
242
259
 
243
260
  typedef struct {
244
261
  int32_t ne00;
@@ -255,7 +272,7 @@ typedef struct {
255
272
  int32_t ne1;
256
273
  int16_t r2;
257
274
  int16_t r3;
258
- } wsp_ggml_metal_kargs_mul_mm;
275
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mm;
259
276
 
260
277
  typedef struct {
261
278
  int32_t ne00;
@@ -276,7 +293,7 @@ typedef struct {
276
293
  int32_t ne1;
277
294
  int16_t r2;
278
295
  int16_t r3;
279
- } wsp_ggml_metal_kargs_mul_mv;
296
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mv;
280
297
 
281
298
  typedef struct {
282
299
  int32_t ne00;
@@ -300,7 +317,7 @@ typedef struct {
300
317
  int16_t nsg;
301
318
  int16_t nxpsg;
302
319
  int16_t r1ptg;
303
- } wsp_ggml_metal_kargs_mul_mv_ext;
320
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mv_ext;
304
321
 
305
322
  typedef struct {
306
323
  int32_t ne10;
@@ -311,7 +328,7 @@ typedef struct {
311
328
  uint64_t nbh11;
312
329
  int32_t ne20; // n_expert_used
313
330
  uint64_t nb21;
314
- } wsp_ggml_metal_kargs_mul_mm_id_map0;
331
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mm_id_map0;
315
332
 
316
333
  typedef struct {
317
334
  int32_t ne20; // n_expert_used
@@ -322,7 +339,7 @@ typedef struct {
322
339
  int32_t ne0;
323
340
  uint64_t nb1;
324
341
  uint64_t nb2;
325
- } wsp_ggml_metal_kargs_mul_mm_id_map1;
342
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mm_id_map1;
326
343
 
327
344
  typedef struct {
328
345
  int32_t ne00;
@@ -339,7 +356,7 @@ typedef struct {
339
356
  int32_t neh1;
340
357
  int16_t r2;
341
358
  int16_t r3;
342
- } wsp_ggml_metal_kargs_mul_mm_id;
359
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mm_id;
343
360
 
344
361
  typedef struct {
345
362
  int32_t nei0;
@@ -361,28 +378,36 @@ typedef struct {
361
378
  int32_t ne0;
362
379
  int32_t ne1;
363
380
  uint64_t nb1;
364
- } wsp_ggml_metal_kargs_mul_mv_id;
381
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mv_id;
365
382
 
366
383
  typedef struct {
367
384
  int32_t ne00;
368
385
  int32_t ne00_4;
369
386
  uint64_t nb01;
370
387
  float eps;
371
- } wsp_ggml_metal_kargs_norm;
388
+ } wsp_wsp_wsp_ggml_metal_kargs_norm;
372
389
 
373
390
  typedef struct {
374
391
  int32_t ne00;
375
392
  int32_t ne00_4;
376
- uint64_t nb01;
393
+ uint64_t nb1;
394
+ uint64_t nb2;
395
+ uint64_t nb3;
377
396
  float eps;
378
- } wsp_ggml_metal_kargs_rms_norm;
397
+ int32_t nef1[3];
398
+ int32_t nef2[3];
399
+ int32_t nef3[3];
400
+ uint64_t nbf1[3];
401
+ uint64_t nbf2[3];
402
+ uint64_t nbf3[3];
403
+ } wsp_wsp_wsp_ggml_metal_kargs_rms_norm;
379
404
 
380
405
  typedef struct {
381
406
  int32_t ne00;
382
407
  int32_t ne00_4;
383
408
  uint64_t nb01;
384
409
  float eps;
385
- } wsp_ggml_metal_kargs_l2_norm;
410
+ } wsp_wsp_wsp_ggml_metal_kargs_l2_norm;
386
411
 
387
412
  typedef struct {
388
413
  int64_t ne00;
@@ -393,7 +418,7 @@ typedef struct {
393
418
  uint64_t nb02;
394
419
  int32_t n_groups;
395
420
  float eps;
396
- } wsp_ggml_metal_kargs_group_norm;
421
+ } wsp_wsp_wsp_ggml_metal_kargs_group_norm;
397
422
 
398
423
  typedef struct {
399
424
  int32_t IC;
@@ -402,7 +427,7 @@ typedef struct {
402
427
  int32_t s0;
403
428
  uint64_t nb0;
404
429
  uint64_t nb1;
405
- } wsp_ggml_metal_kargs_conv_transpose_1d;
430
+ } wsp_wsp_wsp_ggml_metal_kargs_conv_transpose_1d;
406
431
 
407
432
  typedef struct {
408
433
  uint64_t ofs0;
@@ -420,7 +445,7 @@ typedef struct {
420
445
  int32_t KH;
421
446
  int32_t KW;
422
447
  int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
423
- } wsp_ggml_metal_kargs_im2col;
448
+ } wsp_wsp_wsp_ggml_metal_kargs_im2col;
424
449
 
425
450
  typedef struct{
426
451
  int32_t ne00;
@@ -431,7 +456,9 @@ typedef struct{
431
456
  uint64_t nb1;
432
457
  int32_t i00;
433
458
  int32_t i10;
434
- } wsp_ggml_metal_kargs_glu;
459
+ float alpha;
460
+ float limit;
461
+ } wsp_wsp_wsp_ggml_metal_kargs_glu;
435
462
 
436
463
  typedef struct {
437
464
  int64_t ne00;
@@ -458,24 +485,36 @@ typedef struct {
458
485
  uint64_t nb1;
459
486
  uint64_t nb2;
460
487
  uint64_t nb3;
461
- } wsp_ggml_metal_kargs_sum_rows;
488
+ } wsp_wsp_wsp_ggml_metal_kargs_sum_rows;
462
489
 
463
490
  typedef struct {
464
- int64_t ne00;
465
- int64_t ne01;
466
- int64_t ne02;
491
+ int32_t ne00;
492
+ int32_t ne01;
493
+ int32_t ne02;
494
+ uint64_t nb01;
495
+ uint64_t nb02;
496
+ uint64_t nb03;
497
+ int32_t ne11;
498
+ int32_t ne12;
499
+ int32_t ne13;
500
+ uint64_t nb11;
501
+ uint64_t nb12;
502
+ uint64_t nb13;
503
+ uint64_t nb1;
504
+ uint64_t nb2;
505
+ uint64_t nb3;
467
506
  float scale;
468
507
  float max_bias;
469
508
  float m0;
470
509
  float m1;
471
- uint32_t n_head_log2;
472
- } wsp_ggml_metal_kargs_soft_max;
510
+ int32_t n_head_log2;
511
+ } wsp_wsp_wsp_ggml_metal_kargs_soft_max;
473
512
 
474
513
  typedef struct {
475
514
  int64_t ne00;
476
515
  int64_t ne01;
477
516
  int n_past;
478
- } wsp_ggml_metal_kargs_diag_mask_inf;
517
+ } wsp_wsp_wsp_ggml_metal_kargs_diag_mask_inf;
479
518
 
480
519
  typedef struct {
481
520
  int64_t ne00;
@@ -494,32 +533,32 @@ typedef struct {
494
533
  uint64_t nb0;
495
534
  uint64_t nb1;
496
535
  uint64_t nb2;
497
- } wsp_ggml_metal_kargs_ssm_conv;
536
+ } wsp_wsp_wsp_ggml_metal_kargs_ssm_conv;
498
537
 
499
538
  typedef struct {
500
539
  int64_t d_state;
501
540
  int64_t d_inner;
541
+ int64_t n_head;
542
+ int64_t n_group;
502
543
  int64_t n_seq_tokens;
503
544
  int64_t n_seqs;
504
- uint64_t nb00;
545
+ int64_t s_off;
505
546
  uint64_t nb01;
506
547
  uint64_t nb02;
507
- uint64_t nb10;
548
+ uint64_t nb03;
508
549
  uint64_t nb11;
509
550
  uint64_t nb12;
510
551
  uint64_t nb13;
511
- uint64_t nb20;
512
552
  uint64_t nb21;
513
553
  uint64_t nb22;
514
- uint64_t nb30;
515
554
  uint64_t nb31;
516
- uint64_t nb40;
517
555
  uint64_t nb41;
518
556
  uint64_t nb42;
519
- uint64_t nb50;
557
+ uint64_t nb43;
520
558
  uint64_t nb51;
521
559
  uint64_t nb52;
522
- } wsp_ggml_metal_kargs_ssm_scan;
560
+ uint64_t nb53;
561
+ } wsp_wsp_wsp_ggml_metal_kargs_ssm_scan;
523
562
 
524
563
  typedef struct {
525
564
  int64_t ne00;
@@ -530,7 +569,7 @@ typedef struct {
530
569
  uint64_t nb11;
531
570
  uint64_t nb1;
532
571
  uint64_t nb2;
533
- } wsp_ggml_metal_kargs_get_rows;
572
+ } wsp_wsp_wsp_ggml_metal_kargs_get_rows;
534
573
 
535
574
  typedef struct {
536
575
  int32_t nk0;
@@ -546,7 +585,7 @@ typedef struct {
546
585
  uint64_t nb1;
547
586
  uint64_t nb2;
548
587
  uint64_t nb3;
549
- } wsp_ggml_metal_kargs_set_rows;
588
+ } wsp_wsp_wsp_ggml_metal_kargs_set_rows;
550
589
 
551
590
  typedef struct {
552
591
  int64_t ne00;
@@ -569,7 +608,7 @@ typedef struct {
569
608
  float sf1;
570
609
  float sf2;
571
610
  float sf3;
572
- } wsp_ggml_metal_kargs_upscale;
611
+ } wsp_wsp_wsp_ggml_metal_kargs_upscale;
573
612
 
574
613
  typedef struct {
575
614
  int64_t ne00;
@@ -588,7 +627,7 @@ typedef struct {
588
627
  uint64_t nb1;
589
628
  uint64_t nb2;
590
629
  uint64_t nb3;
591
- } wsp_ggml_metal_kargs_pad;
630
+ } wsp_wsp_wsp_ggml_metal_kargs_pad;
592
631
 
593
632
  typedef struct {
594
633
  int64_t ne00;
@@ -609,28 +648,28 @@ typedef struct {
609
648
  uint64_t nb3;
610
649
  int32_t p0;
611
650
  int32_t p1;
612
- } wsp_ggml_metal_kargs_pad_reflect_1d;
651
+ } wsp_wsp_wsp_ggml_metal_kargs_pad_reflect_1d;
613
652
 
614
653
  typedef struct {
615
654
  uint64_t nb1;
616
655
  int dim;
617
656
  int max_period;
618
- } wsp_ggml_metal_kargs_timestep_embedding;
657
+ } wsp_wsp_wsp_ggml_metal_kargs_timestep_embedding;
619
658
 
620
659
  typedef struct {
621
660
  float slope;
622
- } wsp_ggml_metal_kargs_leaky_relu;
661
+ } wsp_wsp_wsp_ggml_metal_kargs_leaky_relu;
623
662
 
624
663
  typedef struct {
625
664
  int64_t ncols;
626
665
  int64_t ncols_pad;
627
- } wsp_ggml_metal_kargs_argsort;
666
+ } wsp_wsp_wsp_ggml_metal_kargs_argsort;
628
667
 
629
668
  typedef struct {
630
669
  int64_t ne0;
631
670
  float start;
632
671
  float step;
633
- } wsp_ggml_metal_kargs_arange;
672
+ } wsp_wsp_wsp_ggml_metal_kargs_arange;
634
673
 
635
674
  typedef struct {
636
675
  int32_t k0;
@@ -644,6 +683,6 @@ typedef struct {
644
683
  int64_t OH;
645
684
  int64_t OW;
646
685
  int64_t parallel_elements;
647
- } wsp_ggml_metal_kargs_pool_2d;
686
+ } wsp_wsp_wsp_ggml_metal_kargs_pool_2d;
648
687
 
649
- #endif // WSP_GGML_METAL_IMPL
688
+ #endif // WSP_WSP_WSP_GGML_METAL_IMPL
@@ -39,18 +39,13 @@ extern "C" {
39
39
  // user-code should use only these functions
40
40
  //
41
41
 
42
+ // TODO: remove in the future
42
43
  WSP_GGML_BACKEND_API wsp_ggml_backend_t wsp_ggml_backend_metal_init(void);
43
44
 
44
45
  WSP_GGML_BACKEND_API bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend);
45
46
 
46
- WSP_GGML_DEPRECATED(
47
- WSP_GGML_BACKEND_API wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
48
- "obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713");
49
-
50
47
  WSP_GGML_BACKEND_API void wsp_ggml_backend_metal_set_abort_callback(wsp_ggml_backend_t backend, wsp_ggml_abort_callback abort_callback, void * user_data);
51
48
 
52
- WSP_GGML_BACKEND_API wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void);
53
-
54
49
  // helper to check if the device supports a specific family
55
50
  // ideally, the user code should be doing these checks
56
51
  // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
@@ -74,16 +74,26 @@ extern "C" {
74
74
  WSP_GGML_OPT_BUILD_TYPE_OPT = 30,
75
75
  };
76
76
 
77
+ enum wsp_ggml_opt_optimizer_type {
78
+ WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW,
79
+ WSP_GGML_OPT_OPTIMIZER_TYPE_SGD,
80
+
81
+ WSP_GGML_OPT_OPTIMIZER_TYPE_COUNT
82
+ };
83
+
77
84
  // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
78
85
  struct wsp_ggml_opt_optimizer_params {
79
- // AdamW optimizer parameters
80
86
  struct {
81
87
  float alpha; // learning rate
82
- float beta1;
83
- float beta2;
88
+ float beta1; // first AdamW momentum
89
+ float beta2; // second AdamW momentum
84
90
  float eps; // epsilon for numerical stability
85
- float wd; // weight decay for AdamW, use 0.0f to disable
91
+ float wd; // weight decay - 0.0f to disable
86
92
  } adamw;
93
+ struct {
94
+ float alpha; // learning rate
95
+ float wd; // weight decay
96
+ } sgd;
87
97
  };
88
98
 
89
99
  // callback to calculate optimizer parameters prior to a backward pass
@@ -112,8 +122,11 @@ extern "C" {
112
122
 
113
123
  int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
114
124
 
115
- wsp_ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
116
- void * get_opt_pars_ud; // userdata for calculating optimizer parameters
125
+ wsp_ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
126
+ void * get_opt_pars_ud; // userdata for calculating optimizer parameters
127
+
128
+ // only WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
129
+ enum wsp_ggml_opt_optimizer_type optimizer;
117
130
  };
118
131
 
119
132
  // get parameters for an optimization context with defaults set where possible
@@ -142,6 +155,10 @@ extern "C" {
142
155
  // get the gradient accumulator for a node from the forward graph
143
156
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_grad_acc(wsp_ggml_opt_context_t opt_ctx, struct wsp_ggml_tensor * node);
144
157
 
158
+ WSP_GGML_API enum wsp_ggml_opt_optimizer_type wsp_ggml_opt_context_optimizer_type(wsp_ggml_opt_context_t); //TODO consistent naming scheme
159
+
160
+ WSP_GGML_API const char * wsp_ggml_opt_optimizer_name(enum wsp_ggml_opt_optimizer_type);
161
+
145
162
  // ====== Optimization Result ======
146
163
 
147
164
  WSP_GGML_API wsp_ggml_opt_result_t wsp_ggml_opt_result_init(void);
@@ -226,12 +243,14 @@ extern "C" {
226
243
  struct wsp_ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
227
244
  wsp_ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
228
245
  enum wsp_ggml_opt_loss_type loss_type, // loss to minimize
246
+ enum wsp_ggml_opt_optimizer_type optimizer, // sgd or adamw
229
247
  wsp_ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
230
248
  int64_t nepoch, // how many times the dataset should be iterated over
231
249
  int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
232
250
  float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
233
251
  bool silent); // whether or not info prints to stderr should be suppressed
234
252
 
253
+
235
254
  #ifdef __cplusplus
236
255
  }
237
256
  #endif
@@ -21,6 +21,8 @@ WSP_GGML_API void wsp_quantize_row_q5_1_ref(const float * WSP_GGML_RESTRICT x, b
21
21
  WSP_GGML_API void wsp_quantize_row_q8_0_ref(const float * WSP_GGML_RESTRICT x, block_q8_0 * WSP_GGML_RESTRICT y, int64_t k);
22
22
  WSP_GGML_API void wsp_quantize_row_q8_1_ref(const float * WSP_GGML_RESTRICT x, block_q8_1 * WSP_GGML_RESTRICT y, int64_t k);
23
23
 
24
+ WSP_GGML_API void wsp_quantize_row_mxfp4_ref(const float * WSP_GGML_RESTRICT x, block_mxfp4 * WSP_GGML_RESTRICT y, int64_t k);
25
+
24
26
  WSP_GGML_API void wsp_quantize_row_q2_K_ref(const float * WSP_GGML_RESTRICT x, block_q2_K * WSP_GGML_RESTRICT y, int64_t k);
25
27
  WSP_GGML_API void wsp_quantize_row_q3_K_ref(const float * WSP_GGML_RESTRICT x, block_q3_K * WSP_GGML_RESTRICT y, int64_t k);
26
28
  WSP_GGML_API void wsp_quantize_row_q4_K_ref(const float * WSP_GGML_RESTRICT x, block_q4_K * WSP_GGML_RESTRICT y, int64_t k);
@@ -45,6 +47,8 @@ WSP_GGML_API void wsp_dewsp_quantize_row_q5_1(const block_q5_1 * WSP_GGML_RESTRI
45
47
  WSP_GGML_API void wsp_dewsp_quantize_row_q8_0(const block_q8_0 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
46
48
  //WSP_GGML_API void wsp_dewsp_quantize_row_q8_1(const block_q8_1 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
47
49
 
50
+ WSP_GGML_API void wsp_dewsp_quantize_row_mxfp4(const block_mxfp4 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
51
+
48
52
  WSP_GGML_API void wsp_dewsp_quantize_row_q2_K(const block_q2_K * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
49
53
  WSP_GGML_API void wsp_dewsp_quantize_row_q3_K(const block_q3_K * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
50
54
  WSP_GGML_API void wsp_dewsp_quantize_row_q4_K(const block_q4_K * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
@@ -90,6 +94,8 @@ WSP_GGML_API size_t wsp_quantize_q5_0(const float * WSP_GGML_RESTRICT src, void
90
94
  WSP_GGML_API size_t wsp_quantize_q5_1(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
91
95
  WSP_GGML_API size_t wsp_quantize_q8_0(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
92
96
 
97
+ WSP_GGML_API size_t wsp_quantize_mxfp4(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
98
+
93
99
  WSP_GGML_API void wsp_iq2xs_init_impl(enum wsp_ggml_type type);
94
100
  WSP_GGML_API void wsp_iq2xs_free_impl(enum wsp_ggml_type type);
95
101
  WSP_GGML_API void wsp_iq3xs_init_impl(int grid_size);