whisper.rn 0.5.0-rc.9 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. package/android/build.gradle +2 -1
  2. package/android/gradle.properties +1 -1
  3. package/cpp/ggml-alloc.c +265 -141
  4. package/cpp/ggml-backend-impl.h +4 -1
  5. package/cpp/ggml-backend-reg.cpp +30 -13
  6. package/cpp/ggml-backend.cpp +221 -38
  7. package/cpp/ggml-backend.h +17 -1
  8. package/cpp/ggml-common.h +17 -0
  9. package/cpp/ggml-cpu/amx/amx.cpp +4 -2
  10. package/cpp/ggml-cpu/arch/arm/quants.c +132 -596
  11. package/cpp/ggml-cpu/arch/arm/repack.cpp +14 -286
  12. package/cpp/ggml-cpu/arch/x86/quants.c +184 -675
  13. package/cpp/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  14. package/cpp/ggml-cpu/arch-fallback.h +32 -2
  15. package/cpp/ggml-cpu/common.h +14 -0
  16. package/cpp/ggml-cpu/ggml-cpu-impl.h +13 -6
  17. package/cpp/ggml-cpu/ggml-cpu.c +70 -42
  18. package/cpp/ggml-cpu/ggml-cpu.cpp +35 -28
  19. package/cpp/ggml-cpu/ops.cpp +1587 -1177
  20. package/cpp/ggml-cpu/ops.h +5 -8
  21. package/cpp/ggml-cpu/quants.c +35 -0
  22. package/cpp/ggml-cpu/quants.h +8 -0
  23. package/cpp/ggml-cpu/repack.cpp +458 -47
  24. package/cpp/ggml-cpu/repack.h +22 -0
  25. package/cpp/ggml-cpu/simd-mappings.h +89 -60
  26. package/cpp/ggml-cpu/traits.cpp +2 -2
  27. package/cpp/ggml-cpu/traits.h +1 -1
  28. package/cpp/ggml-cpu/vec.cpp +170 -26
  29. package/cpp/ggml-cpu/vec.h +506 -63
  30. package/cpp/ggml-cpu.h +1 -1
  31. package/cpp/ggml-impl.h +119 -9
  32. package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
  33. package/cpp/ggml-metal/ggml-metal-common.h +52 -0
  34. package/cpp/ggml-metal/ggml-metal-context.h +33 -0
  35. package/cpp/ggml-metal/ggml-metal-context.m +600 -0
  36. package/cpp/ggml-metal/ggml-metal-device.cpp +1376 -0
  37. package/cpp/ggml-metal/ggml-metal-device.h +226 -0
  38. package/cpp/ggml-metal/ggml-metal-device.m +1312 -0
  39. package/cpp/ggml-metal/ggml-metal-impl.h +722 -0
  40. package/cpp/ggml-metal/ggml-metal-ops.cpp +3158 -0
  41. package/cpp/ggml-metal/ggml-metal-ops.h +82 -0
  42. package/cpp/ggml-metal/ggml-metal.cpp +718 -0
  43. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  44. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  45. package/cpp/ggml-metal-impl.h +90 -51
  46. package/cpp/ggml-metal.h +1 -6
  47. package/cpp/ggml-opt.cpp +97 -41
  48. package/cpp/ggml-opt.h +25 -6
  49. package/cpp/ggml-quants.c +111 -16
  50. package/cpp/ggml-quants.h +6 -0
  51. package/cpp/ggml.c +486 -98
  52. package/cpp/ggml.h +221 -16
  53. package/cpp/gguf.cpp +8 -1
  54. package/cpp/jsi/RNWhisperJSI.cpp +25 -6
  55. package/cpp/jsi/ThreadPool.h +3 -3
  56. package/cpp/whisper.cpp +100 -76
  57. package/cpp/whisper.h +1 -0
  58. package/ios/CMakeLists.txt +6 -1
  59. package/ios/RNWhisper.mm +6 -6
  60. package/ios/RNWhisperContext.mm +2 -0
  61. package/ios/RNWhisperVadContext.mm +16 -13
  62. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  63. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  64. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
  65. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  66. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +119 -9
  67. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
  68. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  69. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  70. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  71. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +221 -16
  72. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  73. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  74. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  75. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  76. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  77. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  78. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
  79. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  80. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +119 -9
  81. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
  82. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  83. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  84. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  85. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +221 -16
  86. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  87. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  88. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  89. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  90. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  91. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  92. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  93. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
  94. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  95. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +119 -9
  96. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
  97. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  98. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  99. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  100. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +221 -16
  101. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  102. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  103. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  104. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  105. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  106. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  107. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
  108. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  109. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +119 -9
  110. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
  111. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  112. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  113. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  114. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +221 -16
  115. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  116. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  117. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  118. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  119. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  120. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +13 -0
  121. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  122. package/lib/commonjs/version.json +1 -1
  123. package/lib/module/realtime-transcription/RealtimeTranscriber.js +13 -0
  124. package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  125. package/lib/module/version.json +1 -1
  126. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
  127. package/lib/typescript/realtime-transcription/types.d.ts +6 -0
  128. package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
  129. package/package.json +1 -1
  130. package/src/realtime-transcription/RealtimeTranscriber.ts +17 -0
  131. package/src/realtime-transcription/types.ts +6 -0
  132. package/src/version.json +1 -1
  133. package/whisper-rn.podspec +8 -9
  134. package/cpp/ggml-metal.m +0 -6284
  135. package/cpp/ggml-whisper-sim.metallib +0 -0
  136. package/cpp/ggml-whisper.metallib +0 -0
package/cpp/ggml-cpu.h CHANGED
@@ -101,7 +101,6 @@ extern "C" {
      WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_riscv_v (void);
      WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_vsx (void);
      WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_vxe (void);
-     WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_nnpa (void);
      WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_wasm_simd (void);
      WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_llamafile (void);

@@ -135,6 +134,7 @@ extern "C" {
      WSP_GGML_BACKEND_API wsp_ggml_backend_reg_t wsp_ggml_backend_cpu_reg(void);

      WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
+     WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_i32 (const float *, int32_t *, int64_t);
      WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_fp16(const float *, wsp_ggml_fp16_t *, int64_t);
      WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp16_to_fp32(const wsp_ggml_fp16_t *, float *, int64_t);
      WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_bf16(const float *, wsp_ggml_bf16_t *, int64_t);
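The only API change in this header is the new `wsp_ggml_cpu_fp32_to_i32` conversion. A minimal sketch of how it could be called, assuming the whisper.rn ggml headers are on the include path (the exact rounding behavior lives in the CPU implementation, which is not shown in this diff):

```cpp
#include <cstdint>
#include <cstdio>

#include "ggml-cpu.h" // declares wsp_ggml_cpu_fp32_to_i32 as of this release

int main() {
    const float src[4] = { 0.0f, 1.5f, -2.0f, 3.9f };
    int32_t     dst[4] = { 0 };

    // (src, dst, element count), matching the declaration above
    wsp_ggml_cpu_fp32_to_i32(src, dst, 4);

    for (int i = 0; i < 4; i++) {
        std::printf("%g -> %d\n", (double) src[i], dst[i]);
    }
    return 0;
}
```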
package/cpp/ggml-impl.h CHANGED
@@ -73,6 +73,35 @@ static inline int wsp_ggml_up(int n, int m) {
      return (n + m - 1) & ~(m - 1);
  }

+ // TODO: move to ggml.h? (won't be able to inline)
+ static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b) {
+     if (a->type != b->type) {
+         return false;
+     }
+     for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
+         if (a->ne[i] != b->ne[i]) {
+             return false;
+         }
+         if (a->nb[i] != b->nb[i]) {
+             return false;
+         }
+     }
+     return true;
+ }
+
+ static bool wsp_ggml_op_is_empty(enum wsp_ggml_op op) {
+     switch (op) {
+         case WSP_GGML_OP_NONE:
+         case WSP_GGML_OP_RESHAPE:
+         case WSP_GGML_OP_TRANSPOSE:
+         case WSP_GGML_OP_VIEW:
+         case WSP_GGML_OP_PERMUTE:
+             return true;
+         default:
+             return false;
+     }
+ }
+
  //
  // logging
  //
@@ -313,6 +342,10 @@ struct wsp_ggml_cgraph {
  // if you need the gradients, get them from the original graph
  struct wsp_ggml_cgraph wsp_ggml_graph_view(struct wsp_ggml_cgraph * cgraph, int i0, int i1);

+ // ggml-alloc.c: true if the operation can reuse memory from its sources
+ WSP_GGML_API bool wsp_ggml_op_can_inplace(enum wsp_ggml_op op);
+
+
  // Memory allocation

  WSP_GGML_API void * wsp_ggml_aligned_malloc(size_t size);
@@ -394,6 +427,67 @@ static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
  #define WSP_GGML_FP16_TO_FP32(x) WSP_GGML_COMPUTE_FP16_TO_FP32(x)
  #define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x)

+ static inline float wsp_ggml_e8m0_to_fp32(uint8_t x) {
+     uint32_t bits; // Stores the raw bit representation of the float
+
+     // Handle special case for minimum exponent (denormalized float)
+     if (x == 0) {
+         // Bit pattern for 2^(-127):
+         // - Sign bit: 0 (positive)
+         // - Exponent: 0 (denormalized number)
+         // - Mantissa: 0x400000 (0.5 in fractional form)
+         // Value = 0.5 * 2^(-126) = 2^(-127)
+         bits = 0x00400000;
+     }
+     // note: disabled as we don't need to handle NaNs
+     //// Handle special case for NaN (all bits set)
+     //else if (x == 0xFF) {
+     //    // Standard quiet NaN pattern:
+     //    // - Sign bit: 0
+     //    // - Exponent: all 1s (0xFF)
+     //    // - Mantissa: 0x400000 (quiet NaN flag)
+     //    bits = 0x7FC00000;
+     //}
+     // Normalized values (most common case)
+     else {
+         // Construct normalized float by shifting exponent into position:
+         // - Exponent field: 8 bits (positions 30-23)
+         // - Mantissa: 0 (implicit leading 1)
+         // Value = 2^(x - 127)
+         bits = (uint32_t) x << 23;
+     }
+
+     float result; // Final float value
+     // Safely reinterpret bit pattern as float without type-punning issues
+     memcpy(&result, &bits, sizeof(float));
+     return result;
+ }
+
+ // Equal to wsp_ggml_e8m0_to_fp32/2
+ // Useful with MXFP4 quantization since the E0M2 values are doubled
+ static inline float wsp_ggml_e8m0_to_fp32_half(uint8_t x) {
+     uint32_t bits;
+
+     // For x < 2: use precomputed denormal patterns
+     if (x < 2) {
+         // 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)
+         bits = 0x00200000 << x;
+     }
+     // For x >= 2: normalized exponent adjustment
+     else {
+         // 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1)
+         bits = (uint32_t)(x - 1) << 23;
+     }
+     // Note: NaNs are not handled here
+
+     float result;
+     memcpy(&result, &bits, sizeof(float));
+     return result;
+ }
+
+ #define WSP_GGML_E8M0_TO_FP32(x) wsp_ggml_e8m0_to_fp32(x)
+ #define WSP_GGML_E8M0_TO_FP32_HALF(x) wsp_ggml_e8m0_to_fp32_half(x)
+
  /**
   * Converts brain16 to float32.
   *
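To see that the bit patterns above really decode to powers of two, here is a small standalone check of the E8M0 rule value = 2^(x - 127), with x == 0 mapped to the denormal 2^(-127). It mirrors the logic of `wsp_ggml_e8m0_to_fp32` without depending on the ggml headers:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <initializer_list>

// standalone copy of the decode logic from wsp_ggml_e8m0_to_fp32 above
static float e8m0_to_fp32(uint8_t x) {
    const uint32_t bits = (x == 0) ? 0x00400000u : ((uint32_t) x << 23);
    float result;
    std::memcpy(&result, &bits, sizeof(float));
    return result;
}

int main() {
    for (int x : { 0, 1, 126, 127, 128, 254 }) {
        // expected: 2^(x - 127), e.g. e8m0(127) == 1.0f
        std::printf("e8m0(%3d) = %g (2^%d = %g)\n",
                    x, e8m0_to_fp32((uint8_t) x), x - 127, std::ldexp(1.0, x - 127));
    }
    return 0;
}
```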
@@ -493,27 +587,27 @@ static inline bool wsp_ggml_node_has_n_uses(const struct wsp_ggml_cgrap
      return true;
  }

- // Returns true if nodes [i, i+ops.size()) are the sequence of wsp_ggml_ops in ops[]
+ // Returns true if nodes with indices { node_idxs } are the sequence of wsp_ggml_ops in ops[]
  // and are fusable. Nodes are considered fusable according to this function if:
  // - all nodes except the last have only one use and are not views/outputs (see wsp_ggml_node_has_N_uses).
  // - all nodes except the last are a src of the following node.
  // - all nodes are the same shape.
  // TODO: Consider allowing WSP_GGML_OP_NONE nodes in between
- static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_idx, const enum wsp_ggml_op * ops, int num_ops) {
-     if (node_idx + num_ops > cgraph->n_nodes) {
-         return false;
-     }
-
+ static inline bool wsp_ggml_can_fuse_ext(const struct wsp_ggml_cgraph * cgraph, const int * node_idxs, const enum wsp_ggml_op * ops, int num_ops) {
      for (int i = 0; i < num_ops; ++i) {
-         struct wsp_ggml_tensor * node = cgraph->nodes[node_idx + i];
+         if (node_idxs[i] >= cgraph->n_nodes) {
+             return false;
+         }
+
+         struct wsp_ggml_tensor * node = cgraph->nodes[node_idxs[i]];
          if (node->op != ops[i]) {
              return false;
          }
-         if (i < num_ops - 1 && !wsp_ggml_node_has_n_uses(cgraph, node_idx + i, 1)) {
+         if (i < num_ops - 1 && !wsp_ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) {
              return false;
          }
          if (i > 0) {
-             struct wsp_ggml_tensor * prev = cgraph->nodes[node_idx + i - 1];
+             struct wsp_ggml_tensor * prev = cgraph->nodes[node_idxs[i - 1]];
              if (node->src[0] != prev && node->src[1] != prev) {
                  return false;
              }
@@ -525,6 +619,22 @@ static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int
      return true;
  }

+ // same as above, for sequential indices starting at node_idx
+ static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_idx, const enum wsp_ggml_op * ops, int num_ops) {
+     assert(num_ops < 32);
+
+     if (node_idx + num_ops > cgraph->n_nodes) {
+         return false;
+     }
+
+     int idxs[32];
+     for (int i = 0; i < num_ops; ++i) {
+         idxs[i] = node_idx + i;
+     }
+
+     return wsp_ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
+ }
+
  #ifdef __cplusplus
  }
  #endif
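As a hedged sketch of how a backend would use the refactored helper (ggml-impl.h is an internal header, and this wrapper is illustrative, not part of the diff): probing whether the three nodes starting at index `i` form a fusable NORM -> MUL -> ADD chain.

```cpp
#include "ggml-impl.h"

// illustrative helper, not from this release: check for a fusable
// NORM -> MUL -> ADD chain starting at node index i
static bool try_fuse_norm_mul_add(const struct wsp_ggml_cgraph * gf, int i) {
    const enum wsp_ggml_op ops[3] = { WSP_GGML_OP_NORM, WSP_GGML_OP_MUL, WSP_GGML_OP_ADD };

    // wsp_ggml_can_fuse() builds the sequential index list [i, i+1, i+2]
    // and defers to wsp_ggml_can_fuse_ext(), as shown above
    return wsp_ggml_can_fuse(gf, i, ops, 3);
}
```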
package/cpp/ggml-metal/ggml-metal-common.cpp ADDED
@@ -0,0 +1,446 @@
+ #include "ggml-metal-common.h"
+
+ #include "ggml-impl.h"
+ #include "ggml-backend-impl.h"
+
+ #include <vector>
+
+ // represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb)
+ // the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it)
+ struct wsp_ggml_mem_range {
+     uint64_t pb; // buffer id
+
+     uint64_t p0; // begin
+     uint64_t p1; // end
+
+     wsp_ggml_mem_range_type pt;
+ };
+
+ struct wsp_ggml_mem_ranges {
+     std::vector<wsp_ggml_mem_range> ranges;
+
+     int debug = 0;
+ };
+
+ wsp_ggml_mem_ranges_t wsp_ggml_mem_ranges_init(int debug) {
+     auto * res = new wsp_ggml_mem_ranges;
+
+     res->ranges.reserve(256);
+     res->debug = debug;
+
+     return res;
+ }
+
+ void wsp_ggml_mem_ranges_free(wsp_ggml_mem_ranges_t mrs) {
+     delete mrs;
+ }
+
+ void wsp_ggml_mem_ranges_reset(wsp_ggml_mem_ranges_t mrs) {
+     mrs->ranges.clear();
+ }
+
+ static bool wsp_ggml_mem_ranges_add(wsp_ggml_mem_ranges_t mrs, wsp_ggml_mem_range mr) {
+     mrs->ranges.push_back(mr);
+
+     return true;
+ }
+
+ static wsp_ggml_mem_range wsp_ggml_mem_range_from_tensor(const wsp_ggml_tensor * tensor, wsp_ggml_mem_range_type pt) {
+     // always use the base tensor
+     tensor = tensor->view_src ? tensor->view_src : tensor;
+
+     WSP_GGML_ASSERT(!tensor->view_src);
+
+     wsp_ggml_mem_range mr;
+
+     if (tensor->buffer) {
+         // when the tensor is allocated, use the actual memory address range in the buffer
+         //
+         // take the actual allocated size with wsp_ggml_backend_buft_get_alloc_size()
+         // this can be larger than the tensor size if the buffer type allocates extra memory
+         // ref: https://github.com/ggml-org/llama.cpp/pull/15966
+         mr = {
+             /*.pb =*/ (uint64_t) tensor->buffer,
+             /*.p0 =*/ (uint64_t) tensor->data,
+             /*.p1 =*/ (uint64_t) tensor->data + wsp_ggml_backend_buft_get_alloc_size(tensor->buffer->buft, tensor),
+             /*.pt =*/ pt,
+         };
+     } else {
+         // otherwise, the pointer address is used as an unique id of the memory ranges
+         // that the tensor will be using when it is allocated
+         mr = {
+             /*.pb =*/ (uint64_t) tensor,
+             /*.p0 =*/ 0,    //
+             /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used
+             /*.pt =*/ pt,
+         };
+     };
+
+     return mr;
+ }
+
+ static wsp_ggml_mem_range wsp_ggml_mem_range_from_tensor_src(const wsp_ggml_tensor * tensor) {
+     return wsp_ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_SRC);
+ }
+
+ static wsp_ggml_mem_range wsp_ggml_mem_range_from_tensor_dst(const wsp_ggml_tensor * tensor) {
+     return wsp_ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST);
+ }
+
+ static bool wsp_ggml_mem_ranges_add_src(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
+     WSP_GGML_ASSERT(tensor);
+
+     wsp_ggml_mem_range mr = wsp_ggml_mem_range_from_tensor_src(tensor);
+
+     if (mrs->debug > 2) {
+         WSP_GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
+     }
+
+     return wsp_ggml_mem_ranges_add(mrs, mr);
+ }
+
+ static bool wsp_ggml_mem_ranges_add_dst(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
+     WSP_GGML_ASSERT(tensor);
+
+     wsp_ggml_mem_range mr = wsp_ggml_mem_range_from_tensor_dst(tensor);
+
+     if (mrs->debug > 2) {
+         WSP_GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
+     }
+
+     return wsp_ggml_mem_ranges_add(mrs, mr);
+ }
+
+ bool wsp_ggml_mem_ranges_add(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
+     for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
+         if (tensor->src[i]) {
+             wsp_ggml_mem_ranges_add_src(mrs, tensor->src[i]);
+         }
+     }
+
+     return wsp_ggml_mem_ranges_add_dst(mrs, tensor);
+ }
+
+ static bool wsp_ggml_mem_ranges_check(wsp_ggml_mem_ranges_t mrs, wsp_ggml_mem_range mr) {
+     for (size_t i = 0; i < mrs->ranges.size(); i++) {
+         const auto & cmp = mrs->ranges[i];
+
+         // two memory ranges cannot intersect if they are in different buffers
+         if (mr.pb != cmp.pb) {
+             continue;
+         }
+
+         // intersecting source ranges are allowed
+         if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
+             continue;
+         }
+
+         if (mr.p0 < cmp.p1 && mr.p1 >= cmp.p0) {
+             if (mrs->debug > 2) {
+                 WSP_GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n",
+                         __func__,
+                         mr.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
+                         mr.pb, mr.p0, mr.p1,
+                         cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
+                         cmp.pb, cmp.p0, cmp.p1);
+             }
+
+             return false;
+         }
+     }
+
+     return true;
+ }
+
+ static bool wsp_ggml_mem_ranges_check_src(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
+     WSP_GGML_ASSERT(tensor);
+
+     wsp_ggml_mem_range mr = wsp_ggml_mem_range_from_tensor_src(tensor);
+
+     const bool res = wsp_ggml_mem_ranges_check(mrs, mr);
+
+     return res;
+ }
+
+ static bool wsp_ggml_mem_ranges_check_dst(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
+     WSP_GGML_ASSERT(tensor);
+
+     wsp_ggml_mem_range mr = wsp_ggml_mem_range_from_tensor_dst(tensor);
+
+     const bool res = wsp_ggml_mem_ranges_check(mrs, mr);
+
+     return res;
+ }
+
+ bool wsp_ggml_mem_ranges_check(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
+     for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
+         if (tensor->src[i]) {
+             if (!wsp_ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
+                 return false;
+             }
+         }
+     }
+
+     return wsp_ggml_mem_ranges_check_dst(mrs, tensor);
+ }
+
+ struct node_info {
+     wsp_ggml_tensor * node;
+
+     std::vector<wsp_ggml_tensor *> fused;
+
+     wsp_ggml_op op() const {
+         return node->op;
+     }
+
+     const wsp_ggml_tensor * dst() const {
+         return fused.empty() ? node : fused.back();
+     }
+
+     bool is_empty() const {
+         return wsp_ggml_op_is_empty(node->op);
+     }
+
+     void add_fused(wsp_ggml_tensor * t) {
+         fused.push_back(t);
+     }
+ };
+
+ static std::vector<int> wsp_ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
+     // helper to add node src and dst ranges
+     const auto & h_add = [](wsp_ggml_mem_ranges_t mrs, const node_info & node) {
+         for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+             if (node.node->src[i]) {
+                 if (!wsp_ggml_mem_ranges_add_src(mrs, node.node->src[i])) {
+                     return false;
+                 }
+             }
+         }
+
+         // keep track of the sources of the fused nodes as well
+         for (const auto * fused : node.fused) {
+             for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+                 if (fused->src[i]) {
+                     if (!wsp_ggml_mem_ranges_add_src(mrs, fused->src[i])) {
+                         return false;
+                     }
+                 }
+             }
+         }
+
+         return wsp_ggml_mem_ranges_add_dst(mrs, node.dst());
+     };
+
+     // helper to check if a node can run concurrently with the existing set of nodes
+     const auto & h_check = [](wsp_ggml_mem_ranges_t mrs, const node_info & node) {
+         for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+             if (node.node->src[i]) {
+                 if (!wsp_ggml_mem_ranges_check_src(mrs, node.node->src[i])) {
+                     return false;
+                 }
+             }
+         }
+
+         for (const auto * fused : node.fused) {
+             for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+                 if (fused->src[i]) {
+                     if (!wsp_ggml_mem_ranges_check_src(mrs, fused->src[i])) {
+                         return false;
+                     }
+                 }
+             }
+         }
+
+         return wsp_ggml_mem_ranges_check_dst(mrs, node.dst());
+     };
+
+     // perform reorders only across these types of ops
+     // can be expanded when needed
+     const auto & h_safe = [](wsp_ggml_op op) {
+         switch (op) {
+             case WSP_GGML_OP_MUL_MAT:
+             case WSP_GGML_OP_MUL_MAT_ID:
+             case WSP_GGML_OP_ROPE:
+             case WSP_GGML_OP_NORM:
+             case WSP_GGML_OP_RMS_NORM:
+             case WSP_GGML_OP_GROUP_NORM:
+             case WSP_GGML_OP_SUM_ROWS:
+             case WSP_GGML_OP_MUL:
+             case WSP_GGML_OP_ADD:
+             case WSP_GGML_OP_DIV:
+             case WSP_GGML_OP_GLU:
+             case WSP_GGML_OP_SCALE:
+             case WSP_GGML_OP_GET_ROWS:
+             case WSP_GGML_OP_CPY:
+             case WSP_GGML_OP_SET_ROWS:
+                 return true;
+             default:
+                 return wsp_ggml_op_is_empty(op);
+         }
+     };
+
+     const int n = nodes.size();
+
+     std::vector<int> res;
+     res.reserve(n);
+
+     std::vector<bool> used(n, false);
+
+     // the memory ranges for the set of currently concurrent nodes
+     wsp_ggml_mem_ranges_t mrs0 = wsp_ggml_mem_ranges_init(0);
+
+     // the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder
+     wsp_ggml_mem_ranges_t mrs1 = wsp_ggml_mem_ranges_init(0);
+
+     for (int i0 = 0; i0 < n; i0++) {
+         if (used[i0]) {
+             continue;
+         }
+
+         const auto & node0 = nodes[i0];
+
+         // the node is not concurrent with the existing concurrent set, so we have to "put a barrier" (i.e reset mrs0)
+         // but before we do that, look forward for some other nodes that can be added to the concurrent set mrs0
+         //
+         // note: we can always add empty nodes to the concurrent set as they don't read nor write anything
+         if (!node0.is_empty() && !h_check(mrs0, node0)) {
+             // this will hold the set of memory ranges from the nodes that haven't been processed yet
+             // if a node is not concurrent with this set, we cannot reorder it
+             wsp_ggml_mem_ranges_reset(mrs1);
+
+             // initialize it with the current node
+             h_add(mrs1, node0);
+
+             // that many nodes forward to search for a concurrent node
+             constexpr int N_FORWARD = 8;
+
+             for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
+                 if (used[i1]) {
+                     continue;
+                 }
+
+                 const auto & node1 = nodes[i1];
+
+                 // disallow reordering of certain ops
+                 if (!h_safe(node1.op())) {
+                     break;
+                 }
+
+                 const bool is_empty = node1.is_empty();
+
+                 // to reorder a node and add it to the concurrent set, it has to be:
+                 //   + empty or concurrent with all nodes in the existing concurrent set (mrs0)
+                 //   + concurrent with all nodes prior to it that haven't been processed yet (mrs1)
+                 if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {
+                     // add the node to the existing concurrent set (i.e. reorder it for early execution)
+                     h_add(mrs0, node1);
+                     res.push_back(i1);
+
+                     // mark as used, so we skip re-processing it later
+                     used[i1] = true;
+                 } else {
+                     // expand the set of nodes that haven't been processed yet
+                     h_add(mrs1, node1);
+                 }
+             }
+
+             // finalize the concurrent set and begin a new one
+             wsp_ggml_mem_ranges_reset(mrs0);
+         }
+
+         // expand the concurrent set with the current node
+         {
+             h_add(mrs0, node0);
+             res.push_back(i0);
+         }
+     }
+
+     wsp_ggml_mem_ranges_free(mrs0);
+     wsp_ggml_mem_ranges_free(mrs1);
+
+     return res;
+ }
+
+ void wsp_ggml_graph_optimize(wsp_ggml_cgraph * gf) {
+     constexpr int MAX_FUSE = 16;
+
+     const int n = gf->n_nodes;
+
+     enum wsp_ggml_op ops[MAX_FUSE];
+
+     std::vector<node_info> nodes;
+     nodes.reserve(gf->n_nodes);
+
+     // fuse nodes:
+     // we don't want to make reorders that break fusing, so we first pack all fusable tensors
+     // and perform the reorder over the fused nodes. after the reorder is done, we unfuse
+     for (int i = 0; i < n; i++) {
+         node_info node = {
+             /*.node  =*/ gf->nodes[i],
+             /*.fused =*/ {},
+         };
+
+         // fuse only ops that start with these operations
+         // can be expanded when needed
+         if (node.op() == WSP_GGML_OP_ADD ||
+             node.op() == WSP_GGML_OP_NORM ||
+             node.op() == WSP_GGML_OP_RMS_NORM) {
+             ops[0] = node.op();
+
+             int f = i + 1;
+             while (f < n && f < i + MAX_FUSE) {
+                 // conservatively allow fusing only these ops
+                 // can be expanded when needed
+                 if (gf->nodes[f]->op != WSP_GGML_OP_ADD &&
+                     gf->nodes[f]->op != WSP_GGML_OP_MUL &&
+                     gf->nodes[f]->op != WSP_GGML_OP_NORM &&
+                     gf->nodes[f]->op != WSP_GGML_OP_RMS_NORM) {
+                     break;
+                 }
+                 ops[f - i] = gf->nodes[f]->op;
+                 f++;
+             }
+
+             f -= i;
+             for (; f > 1; f--) {
+                 if (wsp_ggml_can_fuse(gf, i, ops, f)) {
+                     break;
+                 }
+             }
+
+             // add the fused tensors into the node info so we can unfuse them later
+             for (int k = 1; k < f; k++) {
+                 ++i;
+
+                 // the .dst() becomes the last fused tensor
+                 node.add_fused(gf->nodes[i]);
+             }
+         }
+
+         nodes.push_back(std::move(node));
+     }
+
+ #if 1
+     // reorder to improve concurrency
+     const auto order = wsp_ggml_metal_graph_optimize_reorder(nodes);
+ #else
+     std::vector<int> order(nodes.size());
+     for (size_t i = 0; i < nodes.size(); i++) {
+         order[i] = i;
+     }
+ #endif
+
+     // unfuse
+     {
+         int j = 0;
+         for (const auto i : order) {
+             const auto & node = nodes[i];
+
+             gf->nodes[j++] = node.node;
+
+             for (auto * fused : node.fused) {
+                 gf->nodes[j++] = fused;
+             }
+         }
+     }
+ }
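The conflict rule driving all of this fits in a few lines. A toy, self-contained model (names are illustrative, not the library API): two ranges conflict only if they live in the same buffer, at least one of them is a write, and their intervals intersect. Note that the real check above uses `mr.p1 >= cmp.p0`, which is slightly more conservative in that it also rejects ranges that merely touch end-to-start:

```cpp
#include <cstdint>
#include <cstdio>

enum RangeType { SRC, DST };

struct Range { uint64_t buf, p0, p1; RangeType type; }; // interval [p0, p1)

// toy version of the rule in wsp_ggml_mem_ranges_check above
static bool conflicts(const Range & a, const Range & b) {
    if (a.buf != b.buf)                 return false; // different buffers never alias
    if (a.type == SRC && b.type == SRC) return false; // concurrent reads are fine
    return a.p0 < b.p1 && b.p0 < a.p1;                // a write is involved: check overlap
}

int main() {
    const Range write     = { 0,  0,  64, DST }; // an op writing bytes [0, 64) of buffer 0
    const Range read_tail = { 0, 64, 128, SRC }; // reads [64, 128): disjoint, can run concurrently
    const Range read_head = { 0, 32,  96, SRC }; // reads [32, 96): overlaps the write
    std::printf("tail concurrent: %d\n", !conflicts(write, read_tail)); // 1
    std::printf("head concurrent: %d\n", !conflicts(write, read_head)); // 0
    return 0;
}
```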
package/cpp/ggml-metal/ggml-metal-common.h ADDED
@@ -0,0 +1,52 @@
+ // helper functions for ggml-metal that are too difficult to implement in Objective-C
+
+ #pragma once
+
+ #include <stdbool.h>
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ struct wsp_ggml_tensor;
+ struct wsp_ggml_cgraph;
+
+ enum wsp_ggml_mem_range_type {
+     MEM_RANGE_TYPE_SRC = 0,
+     MEM_RANGE_TYPE_DST = 1,
+ };
+
+ // a helper object that can be used for reordering operations to improve concurrency
+ //
+ // the fundamental idea is that a set of tasks (either ggml ops, or something else) can run concurrently if they
+ // don't write to a memory that is being read by another task or written to by another task in the set
+ //
+ // with this structure, we can add tasks to the set, setting memory constraints. we can also check if a new task
+ // can be added to the set without violating the constraints (i.e. if it can be executed concurrently with the
+ // tasks already in the set)
+ //
+ typedef struct wsp_ggml_mem_ranges * wsp_ggml_mem_ranges_t;
+
+ wsp_ggml_mem_ranges_t wsp_ggml_mem_ranges_init(int debug);
+ void wsp_ggml_mem_ranges_free(wsp_ggml_mem_ranges_t mrs);
+
+ // remove all ranges from the set
+ void wsp_ggml_mem_ranges_reset(wsp_ggml_mem_ranges_t mrs);
+
+ // add src or dst ranges to track
+ bool wsp_ggml_mem_ranges_add(wsp_ggml_mem_ranges_t mrs, const struct wsp_ggml_tensor * tensor);
+
+ // return false if:
+ // - new src range overlaps with any existing dst range
+ // - new dst range overlaps with any existing range (src or dst)
+ bool wsp_ggml_mem_ranges_check(wsp_ggml_mem_ranges_t mrs, const struct wsp_ggml_tensor * tensor);
+
+ // reorder the nodes in the graph to improve concurrency, while respecting fusion
+ //
+ // note: this implementation is generic and not specific to metal
+ // if it proves to work well, we can start using it for other backends in the future
+ void wsp_ggml_graph_optimize(struct wsp_ggml_cgraph * gf);
+
+ #ifdef __cplusplus
+ }
+ #endif
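A minimal usage sketch of this API, assuming `a` and `b` are nodes of an already-built graph; the wrapper function is illustrative, not part of the header:

```cpp
#include "ggml-metal-common.h"

// illustrative wrapper: could the op producing b be issued concurrently
// with the op producing a?
static bool can_run_concurrently(const struct wsp_ggml_tensor * a,
                                 const struct wsp_ggml_tensor * b) {
    wsp_ggml_mem_ranges_t mrs = wsp_ggml_mem_ranges_init(/*debug =*/ 0);

    // record the src/dst ranges touched by the first op
    wsp_ggml_mem_ranges_add(mrs, a);

    // then ask whether the second op's ranges conflict with them
    const bool ok = wsp_ggml_mem_ranges_check(mrs, b);

    wsp_ggml_mem_ranges_free(mrs);
    return ok;
}
```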