whisper.rn 0.5.0 → 0.5.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. package/android/build.gradle +2 -1
  2. package/android/gradle.properties +1 -1
  3. package/android/src/main/jni.cpp +12 -3
  4. package/cpp/ggml-alloc.c +292 -130
  5. package/cpp/ggml-backend-impl.h +4 -4
  6. package/cpp/ggml-backend-reg.cpp +13 -5
  7. package/cpp/ggml-backend.cpp +207 -17
  8. package/cpp/ggml-backend.h +19 -1
  9. package/cpp/ggml-cpu/amx/amx.cpp +5 -2
  10. package/cpp/ggml-cpu/arch/x86/repack.cpp +2 -2
  11. package/cpp/ggml-cpu/arch-fallback.h +0 -4
  12. package/cpp/ggml-cpu/common.h +14 -0
  13. package/cpp/ggml-cpu/ggml-cpu-impl.h +14 -7
  14. package/cpp/ggml-cpu/ggml-cpu.c +65 -44
  15. package/cpp/ggml-cpu/ggml-cpu.cpp +14 -4
  16. package/cpp/ggml-cpu/ops.cpp +542 -775
  17. package/cpp/ggml-cpu/ops.h +2 -0
  18. package/cpp/ggml-cpu/simd-mappings.h +88 -59
  19. package/cpp/ggml-cpu/unary-ops.cpp +135 -0
  20. package/cpp/ggml-cpu/unary-ops.h +5 -0
  21. package/cpp/ggml-cpu/vec.cpp +227 -20
  22. package/cpp/ggml-cpu/vec.h +407 -56
  23. package/cpp/ggml-cpu.h +1 -1
  24. package/cpp/ggml-impl.h +94 -12
  25. package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
  26. package/cpp/ggml-metal/ggml-metal-common.h +52 -0
  27. package/cpp/ggml-metal/ggml-metal-context.h +33 -0
  28. package/cpp/ggml-metal/ggml-metal-context.m +600 -0
  29. package/cpp/ggml-metal/ggml-metal-device.cpp +1565 -0
  30. package/cpp/ggml-metal/ggml-metal-device.h +244 -0
  31. package/cpp/ggml-metal/ggml-metal-device.m +1325 -0
  32. package/cpp/ggml-metal/ggml-metal-impl.h +802 -0
  33. package/cpp/ggml-metal/ggml-metal-ops.cpp +3583 -0
  34. package/cpp/ggml-metal/ggml-metal-ops.h +88 -0
  35. package/cpp/ggml-metal/ggml-metal.cpp +718 -0
  36. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  37. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  38. package/cpp/ggml-metal-impl.h +40 -40
  39. package/cpp/ggml-metal.h +1 -6
  40. package/cpp/ggml-quants.c +1 -0
  41. package/cpp/ggml.c +341 -15
  42. package/cpp/ggml.h +150 -5
  43. package/cpp/jsi/RNWhisperJSI.cpp +9 -2
  44. package/cpp/jsi/ThreadPool.h +3 -3
  45. package/cpp/rn-whisper.h +1 -0
  46. package/cpp/whisper.cpp +89 -72
  47. package/cpp/whisper.h +1 -0
  48. package/ios/CMakeLists.txt +6 -1
  49. package/ios/RNWhisperContext.mm +3 -1
  50. package/ios/RNWhisperVadContext.mm +14 -13
  51. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  52. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  53. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  54. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  55. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  56. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  57. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +150 -5
  58. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  59. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  60. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  61. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  62. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  63. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  64. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  65. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  66. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  67. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  68. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  69. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +150 -5
  70. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  71. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  72. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  73. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  74. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  75. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  76. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  77. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  78. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  79. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  80. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  81. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  82. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +150 -5
  83. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  84. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  85. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  86. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  87. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  88. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  89. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  90. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  91. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  92. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  93. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  94. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +150 -5
  95. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  96. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  97. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  98. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  99. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  100. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  101. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  102. package/lib/commonjs/version.json +1 -1
  103. package/lib/module/NativeRNWhisper.js.map +1 -1
  104. package/lib/module/version.json +1 -1
  105. package/lib/typescript/NativeRNWhisper.d.ts +2 -0
  106. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  107. package/package.json +1 -1
  108. package/src/NativeRNWhisper.ts +2 -0
  109. package/src/version.json +1 -1
  110. package/whisper-rn.podspec +8 -9
  111. package/cpp/ggml-metal.m +0 -6779
  112. package/cpp/ggml-whisper-sim.metallib +0 -0
  113. package/cpp/ggml-whisper.metallib +0 -0
package/cpp/ggml-impl.h CHANGED
@@ -73,7 +73,7 @@ static inline int wsp_ggml_up(int n, int m) {
     return (n + m - 1) & ~(m - 1);
 }
 
-// TODO: move to ggml.h?
+// TODO: move to ggml.h? (won't be able to inline)
 static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b) {
     if (a->type != b->type) {
         return false;
@@ -89,6 +89,22 @@ static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const str
     return true;
 }
 
+static bool wsp_ggml_op_is_empty(enum wsp_ggml_op op) {
+    switch (op) {
+        case WSP_GGML_OP_NONE:
+        case WSP_GGML_OP_RESHAPE:
+        case WSP_GGML_OP_TRANSPOSE:
+        case WSP_GGML_OP_VIEW:
+        case WSP_GGML_OP_PERMUTE:
+            return true;
+        default:
+            return false;
+    }
+}
+
+static inline float wsp_ggml_softplus(float input) {
+    return (input > 20.0f) ? input : logf(1 + expf(input));
+}
 //
 // logging
 //
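
Note on the wsp_ggml_softplus helper added above: expf(x) overflows float range near x ≈ 88, so the naive logf(1 + expf(x)) returns +inf for large inputs, while for any x > 20 the exact result differs from x by less than e^-20 ≈ 2e-9, below single precision. The early return is therefore both an overflow guard and exact at float precision. A minimal standalone sketch (not part of the package) showing the difference:

// build: g++ -O2 softplus_demo.cpp && ./a.out
#include <cmath>
#include <cstdio>

static float softplus_naive(float x) { return logf(1.0f + expf(x)); }
static float softplus_safe (float x) { return (x > 20.0f) ? x : logf(1.0f + expf(x)); }

int main() {
    const float xs[] = { 1.0f, 20.0f, 50.0f, 100.0f };
    for (float x : xs) {
        // at x = 100, expf overflows to +inf and the naive version prints inf;
        // the guarded version just returns x
        printf("x = %6.1f  naive = %12.4f  safe = %12.4f\n", x, softplus_naive(x), softplus_safe(x));
    }
    return 0;
}
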
@@ -329,6 +345,10 @@ struct wsp_ggml_cgraph {
 // if you need the gradients, get them from the original graph
 struct wsp_ggml_cgraph wsp_ggml_graph_view(struct wsp_ggml_cgraph * cgraph, int i0, int i1);
 
+// ggml-alloc.c: true if the operation can reuse memory from its sources
+WSP_GGML_API bool wsp_ggml_op_can_inplace(enum wsp_ggml_op op);
+
+
 // Memory allocation
 
 WSP_GGML_API void * wsp_ggml_aligned_malloc(size_t size);
@@ -545,14 +565,23 @@ static inline wsp_ggml_bf16_t wsp_ggml_compute_fp32_to_bf16(float s) {
 #define WSP_GGML_FP32_TO_BF16(x) wsp_ggml_compute_fp32_to_bf16(x)
 #define WSP_GGML_BF16_TO_FP32(x) wsp_ggml_compute_bf16_to_fp32(x)
 
+static inline int32_t wsp_ggml_node_get_use_count(const struct wsp_ggml_cgraph * cgraph, int node_idx) {
+    const struct wsp_ggml_tensor * node = cgraph->nodes[node_idx];
+
+    size_t hash_pos = wsp_ggml_hash_find(&cgraph->visited_hash_set, node);
+    if (!wsp_ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos)) {
+        return 0;
+    }
+    return cgraph->use_counts[hash_pos];
+}
+
 // return true if the node's results are only used by N other nodes
 // and can be fused into their calculations.
 static inline bool wsp_ggml_node_has_n_uses(const struct wsp_ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
     const struct wsp_ggml_tensor * node = cgraph->nodes[node_idx];
 
     // check the use count against how many we're replacing
-    size_t hash_pos = wsp_ggml_hash_find(&cgraph->visited_hash_set, node);
-    if (!wsp_ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
+    if (wsp_ggml_node_get_use_count(cgraph, node_idx) != n_uses) {
         return false;
     }
 
@@ -570,27 +599,27 @@ static inline bool wsp_ggml_node_has_n_uses(const struct wsp_ggml_cgrap
     return true;
 }
 
-// Returns true if nodes [i, i+ops.size()) are the sequence of wsp_ggml_ops in ops[]
+// Returns true if nodes with indices { node_idxs } are the sequence of wsp_ggml_ops in ops[]
 // and are fusable. Nodes are considered fusable according to this function if:
 // - all nodes except the last have only one use and are not views/outputs (see wsp_ggml_node_has_N_uses).
 // - all nodes except the last are a src of the following node.
 // - all nodes are the same shape.
 // TODO: Consider allowing WSP_GGML_OP_NONE nodes in between
-static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_idx, const enum wsp_ggml_op * ops, int num_ops) {
-    if (node_idx + num_ops > cgraph->n_nodes) {
-        return false;
-    }
-
+static inline bool wsp_ggml_can_fuse_ext(const struct wsp_ggml_cgraph * cgraph, const int * node_idxs, const enum wsp_ggml_op * ops, int num_ops) {
     for (int i = 0; i < num_ops; ++i) {
-        struct wsp_ggml_tensor * node = cgraph->nodes[node_idx + i];
+        if (node_idxs[i] >= cgraph->n_nodes) {
+            return false;
+        }
+
+        struct wsp_ggml_tensor * node = cgraph->nodes[node_idxs[i]];
         if (node->op != ops[i]) {
             return false;
         }
-        if (i < num_ops - 1 && !wsp_ggml_node_has_n_uses(cgraph, node_idx + i, 1)) {
+        if (i < num_ops - 1 && !wsp_ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) {
             return false;
         }
         if (i > 0) {
-            struct wsp_ggml_tensor * prev = cgraph->nodes[node_idx + i - 1];
+            struct wsp_ggml_tensor * prev = cgraph->nodes[node_idxs[i - 1]];
             if (node->src[0] != prev && node->src[1] != prev) {
                 return false;
             }
@@ -602,6 +631,52 @@ static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int
     return true;
 }
 
+// same as above, for sequential indices starting at node_idx
+static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_idx, const enum wsp_ggml_op * ops, int num_ops) {
+    assert(num_ops < 32);
+
+    if (node_idx + num_ops > cgraph->n_nodes) {
+        return false;
+    }
+
+    int idxs[32];
+    for (int i = 0; i < num_ops; ++i) {
+        idxs[i] = node_idx + i;
+    }
+
+    return wsp_ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
+}
+
+WSP_GGML_API bool wsp_ggml_can_fuse_subgraph_ext(const struct wsp_ggml_cgraph * cgraph,
+                                                 const int * node_idxs,
+                                                 int count,
+                                                 const enum wsp_ggml_op * ops,
+                                                 const int * outputs,
+                                                 int num_outputs);
+
+// Returns true if the subgraph formed by {node_idxs} can be fused
+// checks whether all nodes which are not part of outputs can be elided
+// by checking if their num_uses are confined to the subgraph
+static inline bool wsp_ggml_can_fuse_subgraph(const struct wsp_ggml_cgraph * cgraph,
+                                              int node_idx,
+                                              int count,
+                                              const enum wsp_ggml_op * ops,
+                                              const int * outputs,
+                                              int num_outputs) {
+    WSP_GGML_ASSERT(count < 32);
+    if (node_idx + count > cgraph->n_nodes) {
+        return false;
+    }
+
+    int idxs[32];
+
+    for (int i = 0; i < count; ++i) {
+        idxs[i] = node_idx + i;
+    }
+
+    return wsp_ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
+}
+
 #ifdef __cplusplus
 }
 #endif
@@ -615,6 +690,13 @@ inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_id
     return wsp_ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
 }
 
+inline bool wsp_ggml_can_fuse_subgraph(const struct wsp_ggml_cgraph * cgraph,
+                                       int start_idx,
+                                       std::initializer_list<enum wsp_ggml_op> ops,
+                                       std::initializer_list<int> outputs = {}) {
+    return wsp_ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
+}
+
 // expose GGUF internals for test code
 WSP_GGML_API size_t wsp_gguf_type_size(enum wsp_gguf_type type);
 WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_from_file_impl(FILE * file, struct wsp_gguf_init_params params);
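
The fusion predicates added above reduce to three checks: the node ops match the requested pattern, every node except the last has exactly one use, and each node after the first consumes the previous node's result. A standalone toy sketch of the same logic (hypothetical ToyNode type and toy_can_fuse function, not the package's API; the same-shape and view/output checks are omitted):

#include <cstdio>
#include <vector>

enum ToyOp { TOY_NORM, TOY_MUL, TOY_ADD };

struct ToyNode {
    ToyOp op;
    int   src[2];  // indices of source nodes, -1 if unused
    int   n_uses;  // how many later nodes read this node's result
};

static bool toy_can_fuse(const std::vector<ToyNode> & nodes, int idx, const ToyOp * ops, int num_ops) {
    if (idx + num_ops > (int) nodes.size()) return false;
    for (int i = 0; i < num_ops; ++i) {
        const ToyNode & node = nodes[idx + i];
        if (node.op != ops[i]) return false;
        // every node except the last must feed only the next node in the chain
        if (i < num_ops - 1 && node.n_uses != 1) return false;
        // every node after the first must read the previous node's result
        if (i > 0 && node.src[0] != idx + i - 1 && node.src[1] != idx + i - 1) return false;
    }
    return true;
}

int main() {
    // norm -> mul(norm, w) -> add(mul, b): the shape of a NORM + MUL + ADD pattern
    std::vector<ToyNode> g = {
        { TOY_NORM, { -1, -1 }, 1 },
        { TOY_MUL,  {  0, -1 }, 1 },
        { TOY_ADD,  {  1, -1 }, 1 },
    };
    const ToyOp pattern[] = { TOY_NORM, TOY_MUL, TOY_ADD };
    printf("fusable: %s\n", toy_can_fuse(g, 0, pattern, 3) ? "yes" : "no"); // fusable: yes
    return 0;
}
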
package/cpp/ggml-metal/ggml-metal-common.cpp ADDED
@@ -0,0 +1,446 @@
+#include "ggml-metal-common.h"
+
+#include "ggml-impl.h"
+#include "ggml-backend-impl.h"
+
+#include <vector>
+
+// represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb)
+// the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it)
+struct wsp_ggml_mem_range {
+    uint64_t pb; // buffer id
+
+    uint64_t p0; // begin
+    uint64_t p1; // end
+
+    wsp_ggml_mem_range_type pt;
+};
+
+struct wsp_ggml_mem_ranges {
+    std::vector<wsp_ggml_mem_range> ranges;
+
+    int debug = 0;
+};
+
+wsp_ggml_mem_ranges_t wsp_ggml_mem_ranges_init(int debug) {
+    auto * res = new wsp_ggml_mem_ranges;
+
+    res->ranges.reserve(256);
+    res->debug = debug;
+
+    return res;
+}
+
+void wsp_ggml_mem_ranges_free(wsp_ggml_mem_ranges_t mrs) {
+    delete mrs;
+}
+
+void wsp_ggml_mem_ranges_reset(wsp_ggml_mem_ranges_t mrs) {
+    mrs->ranges.clear();
+}
+
+static bool wsp_ggml_mem_ranges_add(wsp_ggml_mem_ranges_t mrs, wsp_ggml_mem_range mr) {
+    mrs->ranges.push_back(mr);
+
+    return true;
+}
+
+static wsp_ggml_mem_range wsp_ggml_mem_range_from_tensor(const wsp_ggml_tensor * tensor, wsp_ggml_mem_range_type pt) {
+    // always use the base tensor
+    tensor = tensor->view_src ? tensor->view_src : tensor;
+
+    WSP_GGML_ASSERT(!tensor->view_src);
+
+    wsp_ggml_mem_range mr;
+
+    if (tensor->buffer) {
+        // when the tensor is allocated, use the actual memory address range in the buffer
+        //
+        // take the actual allocated size with wsp_ggml_backend_buft_get_alloc_size()
+        // this can be larger than the tensor size if the buffer type allocates extra memory
+        // ref: https://github.com/ggml-org/llama.cpp/pull/15966
+        mr = {
+            /*.pb =*/ (uint64_t) tensor->buffer,
+            /*.p0 =*/ (uint64_t) tensor->data,
+            /*.p1 =*/ (uint64_t) tensor->data + wsp_ggml_backend_buft_get_alloc_size(tensor->buffer->buft, tensor),
+            /*.pt =*/ pt,
+        };
+    } else {
+        // otherwise, the pointer address is used as a unique id of the memory ranges
+        // that the tensor will be using when it is allocated
+        mr = {
+            /*.pb =*/ (uint64_t) tensor,
+            /*.p0 =*/ 0,    //
+            /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used
+            /*.pt =*/ pt,
+        };
+    };
+
+    return mr;
+}
+
+static wsp_ggml_mem_range wsp_ggml_mem_range_from_tensor_src(const wsp_ggml_tensor * tensor) {
+    return wsp_ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_SRC);
+}
+
+static wsp_ggml_mem_range wsp_ggml_mem_range_from_tensor_dst(const wsp_ggml_tensor * tensor) {
+    return wsp_ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST);
+}
+
+static bool wsp_ggml_mem_ranges_add_src(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
+    WSP_GGML_ASSERT(tensor);
+
+    wsp_ggml_mem_range mr = wsp_ggml_mem_range_from_tensor_src(tensor);
+
+    if (mrs->debug > 2) {
+        WSP_GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
+    }
+
+    return wsp_ggml_mem_ranges_add(mrs, mr);
+}
+
+static bool wsp_ggml_mem_ranges_add_dst(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
+    WSP_GGML_ASSERT(tensor);
+
+    wsp_ggml_mem_range mr = wsp_ggml_mem_range_from_tensor_dst(tensor);
+
+    if (mrs->debug > 2) {
+        WSP_GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
+    }
+
+    return wsp_ggml_mem_ranges_add(mrs, mr);
+}
+
+bool wsp_ggml_mem_ranges_add(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
+    for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+        if (tensor->src[i]) {
+            wsp_ggml_mem_ranges_add_src(mrs, tensor->src[i]);
+        }
+    }
+
+    return wsp_ggml_mem_ranges_add_dst(mrs, tensor);
+}
+
+static bool wsp_ggml_mem_ranges_check(wsp_ggml_mem_ranges_t mrs, wsp_ggml_mem_range mr) {
+    for (size_t i = 0; i < mrs->ranges.size(); i++) {
+        const auto & cmp = mrs->ranges[i];
+
+        // two memory ranges cannot intersect if they are in different buffers
+        if (mr.pb != cmp.pb) {
+            continue;
+        }
+
+        // intersecting source ranges are allowed
+        if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
+            continue;
+        }
+
+        if (mr.p0 < cmp.p1 && mr.p1 >= cmp.p0) {
+            if (mrs->debug > 2) {
+                WSP_GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n",
+                        __func__,
+                        mr.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
+                        mr.pb, mr.p0, mr.p1,
+                        cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
+                        cmp.pb, cmp.p0, cmp.p1);
+            }
+
+            return false;
+        }
+    }
+
+    return true;
+}
+
+static bool wsp_ggml_mem_ranges_check_src(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
+    WSP_GGML_ASSERT(tensor);
+
+    wsp_ggml_mem_range mr = wsp_ggml_mem_range_from_tensor_src(tensor);
+
+    const bool res = wsp_ggml_mem_ranges_check(mrs, mr);
+
+    return res;
+}
+
+static bool wsp_ggml_mem_ranges_check_dst(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
+    WSP_GGML_ASSERT(tensor);
+
+    wsp_ggml_mem_range mr = wsp_ggml_mem_range_from_tensor_dst(tensor);
+
+    const bool res = wsp_ggml_mem_ranges_check(mrs, mr);
+
+    return res;
+}
+
+bool wsp_ggml_mem_ranges_check(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
+    for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+        if (tensor->src[i]) {
+            if (!wsp_ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
+                return false;
+            }
+        }
+    }
+
+    return wsp_ggml_mem_ranges_check_dst(mrs, tensor);
+}
+
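
The check above is the classic data-hazard rule: ranges in different buffers never conflict, two reads may overlap freely, and any overlap involving a write (read-after-write, write-after-read, write-after-write) forbids concurrency. A standalone sketch of the same rule on a toy Range type (hypothetical, not the package's wsp_ggml_mem_range), reusing the overlap test mr.p0 < cmp.p1 && mr.p1 >= cmp.p0 from above:

#include <cstdint>
#include <cstdio>
#include <vector>

enum RangeType { SRC, DST };

struct Range {
    uint64_t  buf;    // buffer id: ranges in different buffers never conflict
    uint64_t  p0, p1; // [p0, p1)
    RangeType type;
};

static bool compatible(const std::vector<Range> & set, const Range & r) {
    for (const Range & cmp : set) {
        if (r.buf != cmp.buf) continue;                     // different buffers
        if (r.type == SRC && cmp.type == SRC) continue;     // concurrent reads are fine
        if (r.p0 < cmp.p1 && r.p1 >= cmp.p0) return false;  // RAW/WAR/WAW hazard
    }
    return true;
}

int main() {
    std::vector<Range> set = { { 1, 0, 64, DST } }; // op A writes buf 1, [0, 64)
    printf("read  [0, 64):   %s\n", compatible(set, { 1,  0,  64, SRC }) ? "ok" : "hazard"); // hazard
    printf("read  [64, 128): %s\n", compatible(set, { 1, 64, 128, SRC }) ? "ok" : "hazard"); // ok
    printf("write [64, 128): %s\n", compatible(set, { 1, 64, 128, DST }) ? "ok" : "hazard"); // ok
    return 0;
}
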
+struct node_info {
+    wsp_ggml_tensor * node;
+
+    std::vector<wsp_ggml_tensor *> fused;
+
+    wsp_ggml_op op() const {
+        return node->op;
+    }
+
+    const wsp_ggml_tensor * dst() const {
+        return fused.empty() ? node : fused.back();
+    }
+
+    bool is_empty() const {
+        return wsp_ggml_op_is_empty(node->op);
+    }
+
+    void add_fused(wsp_ggml_tensor * t) {
+        fused.push_back(t);
+    }
+};
+
+static std::vector<int> wsp_ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
+    // helper to add node src and dst ranges
+    const auto & h_add = [](wsp_ggml_mem_ranges_t mrs, const node_info & node) {
+        for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+            if (node.node->src[i]) {
+                if (!wsp_ggml_mem_ranges_add_src(mrs, node.node->src[i])) {
+                    return false;
+                }
+            }
+        }
+
+        // keep track of the sources of the fused nodes as well
+        for (const auto * fused : node.fused) {
+            for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+                if (fused->src[i]) {
+                    if (!wsp_ggml_mem_ranges_add_src(mrs, fused->src[i])) {
+                        return false;
+                    }
+                }
+            }
+        }
+
+        return wsp_ggml_mem_ranges_add_dst(mrs, node.dst());
+    };
+
+    // helper to check if a node can run concurrently with the existing set of nodes
+    const auto & h_check = [](wsp_ggml_mem_ranges_t mrs, const node_info & node) {
+        for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+            if (node.node->src[i]) {
+                if (!wsp_ggml_mem_ranges_check_src(mrs, node.node->src[i])) {
+                    return false;
+                }
+            }
+        }
+
+        for (const auto * fused : node.fused) {
+            for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+                if (fused->src[i]) {
+                    if (!wsp_ggml_mem_ranges_check_src(mrs, fused->src[i])) {
+                        return false;
+                    }
+                }
+            }
+        }
+
+        return wsp_ggml_mem_ranges_check_dst(mrs, node.dst());
+    };
+
+    // perform reorders only across these types of ops
+    // can be expanded when needed
+    const auto & h_safe = [](wsp_ggml_op op) {
+        switch (op) {
+            case WSP_GGML_OP_MUL_MAT:
+            case WSP_GGML_OP_MUL_MAT_ID:
+            case WSP_GGML_OP_ROPE:
+            case WSP_GGML_OP_NORM:
+            case WSP_GGML_OP_RMS_NORM:
+            case WSP_GGML_OP_GROUP_NORM:
+            case WSP_GGML_OP_SUM_ROWS:
+            case WSP_GGML_OP_MUL:
+            case WSP_GGML_OP_ADD:
+            case WSP_GGML_OP_DIV:
+            case WSP_GGML_OP_GLU:
+            case WSP_GGML_OP_SCALE:
+            case WSP_GGML_OP_GET_ROWS:
+            case WSP_GGML_OP_CPY:
+            case WSP_GGML_OP_SET_ROWS:
+                return true;
+            default:
+                return wsp_ggml_op_is_empty(op);
+        }
+    };
+
+    const int n = nodes.size();
+
+    std::vector<int> res;
+    res.reserve(n);
+
+    std::vector<bool> used(n, false);
+
+    // the memory ranges for the set of currently concurrent nodes
+    wsp_ggml_mem_ranges_t mrs0 = wsp_ggml_mem_ranges_init(0);
+
+    // the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder
+    wsp_ggml_mem_ranges_t mrs1 = wsp_ggml_mem_ranges_init(0);
+
+    for (int i0 = 0; i0 < n; i0++) {
+        if (used[i0]) {
+            continue;
+        }
+
+        const auto & node0 = nodes[i0];
+
+        // the node is not concurrent with the existing concurrent set, so we have to "put a barrier" (i.e. reset mrs0)
+        // but before we do that, look forward for some other nodes that can be added to the concurrent set mrs0
+        //
+        // note: we can always add empty nodes to the concurrent set as they don't read nor write anything
+        if (!node0.is_empty() && !h_check(mrs0, node0)) {
+            // this will hold the set of memory ranges from the nodes that haven't been processed yet
+            // if a node is not concurrent with this set, we cannot reorder it
+            wsp_ggml_mem_ranges_reset(mrs1);
+
+            // initialize it with the current node
+            h_add(mrs1, node0);
+
+            // how many nodes forward to search for a concurrent node
+            constexpr int N_FORWARD = 8;
+
+            for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
+                if (used[i1]) {
+                    continue;
+                }
+
+                const auto & node1 = nodes[i1];
+
+                // disallow reordering of certain ops
+                if (!h_safe(node1.op())) {
+                    break;
+                }
+
+                const bool is_empty = node1.is_empty();
+
+                // to reorder a node and add it to the concurrent set, it has to be:
+                //   + empty or concurrent with all nodes in the existing concurrent set (mrs0)
+                //   + concurrent with all nodes prior to it that haven't been processed yet (mrs1)
+                if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {
+                    // add the node to the existing concurrent set (i.e. reorder it for early execution)
+                    h_add(mrs0, node1);
+                    res.push_back(i1);
+
+                    // mark as used, so we skip re-processing it later
+                    used[i1] = true;
+                } else {
+                    // expand the set of nodes that haven't been processed yet
+                    h_add(mrs1, node1);
+                }
+            }
+
+            // finalize the concurrent set and begin a new one
+            wsp_ggml_mem_ranges_reset(mrs0);
+        }
+
+        // expand the concurrent set with the current node
+        {
+            h_add(mrs0, node0);
+            res.push_back(i0);
+        }
+    }
+
+    wsp_ggml_mem_ranges_free(mrs0);
+    wsp_ggml_mem_ranges_free(mrs1);
+
+    return res;
+}
+
+void wsp_ggml_graph_optimize(wsp_ggml_cgraph * gf) {
+    constexpr int MAX_FUSE = 16;
+
+    const int n = gf->n_nodes;
+
+    enum wsp_ggml_op ops[MAX_FUSE];
+
+    std::vector<node_info> nodes;
+    nodes.reserve(gf->n_nodes);
+
+    // fuse nodes:
+    // we don't want to make reorders that break fusing, so we first pack all fusable tensors
+    // and perform the reorder over the fused nodes. after the reorder is done, we unfuse
+    for (int i = 0; i < n; i++) {
+        node_info node = {
+            /*.node  =*/ gf->nodes[i],
+            /*.fused =*/ {},
+        };
+
+        // fuse only ops that start with these operations
+        // can be expanded when needed
+        if (node.op() == WSP_GGML_OP_ADD ||
+            node.op() == WSP_GGML_OP_NORM ||
+            node.op() == WSP_GGML_OP_RMS_NORM) {
+            ops[0] = node.op();
+
+            int f = i + 1;
+            while (f < n && f < i + MAX_FUSE) {
+                // conservatively allow fusing only these ops
+                // can be expanded when needed
+                if (gf->nodes[f]->op != WSP_GGML_OP_ADD &&
+                    gf->nodes[f]->op != WSP_GGML_OP_MUL &&
+                    gf->nodes[f]->op != WSP_GGML_OP_NORM &&
+                    gf->nodes[f]->op != WSP_GGML_OP_RMS_NORM) {
+                    break;
+                }
+                ops[f - i] = gf->nodes[f]->op;
+                f++;
+            }
+
+            f -= i;
+            for (; f > 1; f--) {
+                if (wsp_ggml_can_fuse(gf, i, ops, f)) {
+                    break;
+                }
+            }
+
+            // add the fused tensors into the node info so we can unfuse them later
+            for (int k = 1; k < f; k++) {
+                ++i;
+
+                // the .dst() becomes the last fused tensor
+                node.add_fused(gf->nodes[i]);
+            }
+        }
+
+        nodes.push_back(std::move(node));
+    }
+
+#if 1
+    // reorder to improve concurrency
+    const auto order = wsp_ggml_metal_graph_optimize_reorder(nodes);
+#else
+    std::vector<int> order(nodes.size());
+    for (size_t i = 0; i < nodes.size(); i++) {
+        order[i] = i;
+    }
+#endif
+
+    // unfuse
+    {
+        int j = 0;
+        for (const auto i : order) {
+            const auto & node = nodes[i];
+
+            gf->nodes[j++] = node.node;
+
+            for (auto * fused : node.fused) {
+                gf->nodes[j++] = fused;
+            }
+        }
+    }
+}
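
wsp_ggml_metal_graph_optimize_reorder walks the graph in order; when the next node conflicts with the current concurrent set it scans up to N_FORWARD nodes ahead for one that conflicts neither with that set nor with any node it would jump over, pulls it forward, and only then places the barrier. A standalone sketch of that lookahead over toy tasks with explicit read/write cells (hypothetical Task type; the emptiness and h_safe op checks of the real code are omitted):

#include <cstdio>
#include <vector>

struct Task {
    std::vector<int> reads; // cells this task reads
    int              write; // cell this task writes
};

static bool conflicts(const std::vector<Task> & set, const Task & t) {
    for (const Task & s : set) {
        if (s.write == t.write) return true;                 // WAW
        for (int r : t.reads) if (r == s.write) return true; // RAW
        for (int r : s.reads) if (r == t.write) return true; // WAR
    }
    return false;
}

static std::vector<int> reorder(const std::vector<Task> & tasks) {
    const int n = (int) tasks.size();
    constexpr int N_FORWARD = 8;

    std::vector<int>  res;
    std::vector<bool> used(n, false);
    std::vector<Task> set; // current concurrent set (mrs0 in the code above)

    for (int i0 = 0; i0 < n; i0++) {
        if (used[i0]) continue;

        if (conflicts(set, tasks[i0])) {
            // pending plays the role of mrs1: the tasks a candidate would jump over
            std::vector<Task> pending = { tasks[i0] };
            for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
                if (used[i1]) continue;
                if (!conflicts(set, tasks[i1]) && !conflicts(pending, tasks[i1])) {
                    set.push_back(tasks[i1]); // pull forward into the concurrent set
                    res.push_back(i1);
                    used[i1] = true;
                } else {
                    pending.push_back(tasks[i1]);
                }
            }
            set.clear(); // barrier: start a new concurrent set
        }

        set.push_back(tasks[i0]);
        res.push_back(i0);
    }
    return res;
}

int main() {
    // task 1 reads what task 0 writes; task 2 is independent and gets pulled forward
    std::vector<Task> tasks = { { { 0 }, 10 }, { { 10 }, 11 }, { { 1 }, 12 } };
    for (int i : reorder(tasks)) printf("%d ", i); // prints: 0 2 1
    printf("\n");
    return 0;
}
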
package/cpp/ggml-metal/ggml-metal-common.h ADDED
@@ -0,0 +1,52 @@
+// helper functions for ggml-metal that are too difficult to implement in Objective-C
+
+#pragma once
+
+#include <stdbool.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct wsp_ggml_tensor;
+struct wsp_ggml_cgraph;
+
+enum wsp_ggml_mem_range_type {
+    MEM_RANGE_TYPE_SRC = 0,
+    MEM_RANGE_TYPE_DST = 1,
+};
+
+// a helper object that can be used for reordering operations to improve concurrency
+//
+// the fundamental idea is that a set of tasks (either ggml ops, or something else) can run concurrently if they
+// don't write to memory that is being read by another task or written to by another task in the set
+//
+// with this structure, we can add tasks to the set, setting memory constraints. we can also check if a new task
+// can be added to the set without violating the constraints (i.e. if it can be executed concurrently with the
+// tasks already in the set)
+//
+typedef struct wsp_ggml_mem_ranges * wsp_ggml_mem_ranges_t;
+
+wsp_ggml_mem_ranges_t wsp_ggml_mem_ranges_init(int debug);
+void                  wsp_ggml_mem_ranges_free(wsp_ggml_mem_ranges_t mrs);
+
+// remove all ranges from the set
+void wsp_ggml_mem_ranges_reset(wsp_ggml_mem_ranges_t mrs);
+
+// add src or dst ranges to track
+bool wsp_ggml_mem_ranges_add(wsp_ggml_mem_ranges_t mrs, const struct wsp_ggml_tensor * tensor);
+
+// return false if:
+//   - new src range overlaps with any existing dst range
+//   - new dst range overlaps with any existing range (src or dst)
+bool wsp_ggml_mem_ranges_check(wsp_ggml_mem_ranges_t mrs, const struct wsp_ggml_tensor * tensor);
+
+// reorder the nodes in the graph to improve concurrency, while respecting fusion
+//
+// note: this implementation is generic and not specific to metal
+// if it proves to work well, we can start using it for other backends in the future
+void wsp_ggml_graph_optimize(struct wsp_ggml_cgraph * gf);
+
+#ifdef __cplusplus
+}
+#endif
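
A hedged usage sketch of this header's API. It assumes compiling against the package's cpp sources, and it assumes that wsp_ggml_init/wsp_ggml_free, wsp_ggml_new_tensor_1d, wsp_ggml_add, wsp_ggml_mul, wsp_ggml_scale, and WSP_GGML_TYPE_F32 are the package's wsp_-prefixed equivalents of the standard ggml API (those calls do not appear in this diff):

#include "ggml.h"
#include "ggml-metal-common.h"

#include <stdio.h>

int main(void) {
    // no_alloc = true: tensors get dummy per-tensor ranges instead of real buffer addresses
    struct wsp_ggml_init_params params = { 16*1024*1024, NULL, true };
    struct wsp_ggml_context * ctx = wsp_ggml_init(params);

    struct wsp_ggml_tensor * a = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 16);
    struct wsp_ggml_tensor * b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 16);

    struct wsp_ggml_tensor * c = wsp_ggml_add(ctx, a, b);      // writes c, reads a and b
    struct wsp_ggml_tensor * d = wsp_ggml_mul(ctx, a, b);      // only reads what c's op reads
    struct wsp_ggml_tensor * e = wsp_ggml_scale(ctx, c, 2.0f); // reads c: hazard with c's write

    wsp_ggml_mem_ranges_t mrs = wsp_ggml_mem_ranges_init(0);
    wsp_ggml_mem_ranges_add(mrs, c); // track the add's src/dst ranges

    printf("d concurrent with c? %s\n", wsp_ggml_mem_ranges_check(mrs, d) ? "yes" : "no"); // yes
    printf("e concurrent with c? %s\n", wsp_ggml_mem_ranges_check(mrs, e) ? "yes" : "no"); // no

    wsp_ggml_mem_ranges_free(mrs);
    wsp_ggml_free(ctx);
    return 0;
}
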
package/cpp/ggml-metal/ggml-metal-context.h ADDED
@@ -0,0 +1,33 @@
+#pragma once
+
+#include "ggml-metal-device.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//
+// backend context
+//
+
+typedef struct wsp_ggml_metal * wsp_ggml_metal_t;
+
+wsp_ggml_metal_t wsp_ggml_metal_init(wsp_ggml_metal_device_t dev);
+void             wsp_ggml_metal_free(wsp_ggml_metal_t ctx);
+
+void wsp_ggml_metal_synchronize(wsp_ggml_metal_t ctx);
+
+void wsp_ggml_metal_set_tensor_async(wsp_ggml_metal_t ctx, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+void wsp_ggml_metal_get_tensor_async(wsp_ggml_metal_t ctx, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size);
+
+enum wsp_ggml_status wsp_ggml_metal_graph_compute (wsp_ggml_metal_t ctx, struct wsp_ggml_cgraph * gf);
+void                 wsp_ggml_metal_graph_optimize(wsp_ggml_metal_t ctx, struct wsp_ggml_cgraph * gf);
+
+void wsp_ggml_metal_set_n_cb            (wsp_ggml_metal_t ctx, int n_cb);
+void wsp_ggml_metal_set_abort_callback  (wsp_ggml_metal_t ctx, wsp_ggml_abort_callback abort_callback, void * user_data);
+bool wsp_ggml_metal_supports_family     (wsp_ggml_metal_t ctx, int family);
+void wsp_ggml_metal_capture_next_compute(wsp_ggml_metal_t ctx);
+
+#ifdef __cplusplus
+}
+#endif