whisper.rn 0.5.0 → 0.5.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/gradle.properties +1 -1
- package/android/src/main/jni.cpp +12 -3
- package/cpp/ggml-alloc.c +292 -130
- package/cpp/ggml-backend-impl.h +4 -4
- package/cpp/ggml-backend-reg.cpp +13 -5
- package/cpp/ggml-backend.cpp +207 -17
- package/cpp/ggml-backend.h +19 -1
- package/cpp/ggml-cpu/amx/amx.cpp +5 -2
- package/cpp/ggml-cpu/arch/x86/repack.cpp +2 -2
- package/cpp/ggml-cpu/arch-fallback.h +0 -4
- package/cpp/ggml-cpu/common.h +14 -0
- package/cpp/ggml-cpu/ggml-cpu-impl.h +14 -7
- package/cpp/ggml-cpu/ggml-cpu.c +65 -44
- package/cpp/ggml-cpu/ggml-cpu.cpp +14 -4
- package/cpp/ggml-cpu/ops.cpp +542 -775
- package/cpp/ggml-cpu/ops.h +2 -0
- package/cpp/ggml-cpu/simd-mappings.h +88 -59
- package/cpp/ggml-cpu/unary-ops.cpp +135 -0
- package/cpp/ggml-cpu/unary-ops.h +5 -0
- package/cpp/ggml-cpu/vec.cpp +227 -20
- package/cpp/ggml-cpu/vec.h +407 -56
- package/cpp/ggml-cpu.h +1 -1
- package/cpp/ggml-impl.h +94 -12
- package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
- package/cpp/ggml-metal/ggml-metal-common.h +52 -0
- package/cpp/ggml-metal/ggml-metal-context.h +33 -0
- package/cpp/ggml-metal/ggml-metal-context.m +600 -0
- package/cpp/ggml-metal/ggml-metal-device.cpp +1565 -0
- package/cpp/ggml-metal/ggml-metal-device.h +244 -0
- package/cpp/ggml-metal/ggml-metal-device.m +1325 -0
- package/cpp/ggml-metal/ggml-metal-impl.h +802 -0
- package/cpp/ggml-metal/ggml-metal-ops.cpp +3583 -0
- package/cpp/ggml-metal/ggml-metal-ops.h +88 -0
- package/cpp/ggml-metal/ggml-metal.cpp +718 -0
- package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
- package/cpp/ggml-metal-impl.h +40 -40
- package/cpp/ggml-metal.h +1 -6
- package/cpp/ggml-quants.c +1 -0
- package/cpp/ggml.c +341 -15
- package/cpp/ggml.h +150 -5
- package/cpp/jsi/RNWhisperJSI.cpp +9 -2
- package/cpp/jsi/ThreadPool.h +3 -3
- package/cpp/rn-whisper.h +1 -0
- package/cpp/whisper.cpp +89 -72
- package/cpp/whisper.h +1 -0
- package/ios/CMakeLists.txt +6 -1
- package/ios/RNWhisperContext.mm +3 -1
- package/ios/RNWhisperVadContext.mm +14 -13
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +19 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +94 -12
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +150 -5
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +19 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +94 -12
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +150 -5
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +19 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +94 -12
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +150 -5
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +19 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +94 -12
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +150 -5
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +2 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +2 -0
- package/src/version.json +1 -1
- package/whisper-rn.podspec +8 -9
- package/cpp/ggml-metal.m +0 -6779
- package/cpp/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-whisper.metallib +0 -0
package/cpp/ggml-impl.h
CHANGED
@@ -73,7 +73,7 @@ static inline int wsp_ggml_up(int n, int m) {
     return (n + m - 1) & ~(m - 1);
 }
 
-// TODO: move to ggml.h?
+// TODO: move to ggml.h? (won't be able to inline)
 static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b) {
     if (a->type != b->type) {
         return false;
@@ -89,6 +89,22 @@ static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const str
     return true;
 }
 
+static bool wsp_ggml_op_is_empty(enum wsp_ggml_op op) {
+    switch (op) {
+        case WSP_GGML_OP_NONE:
+        case WSP_GGML_OP_RESHAPE:
+        case WSP_GGML_OP_TRANSPOSE:
+        case WSP_GGML_OP_VIEW:
+        case WSP_GGML_OP_PERMUTE:
+            return true;
+        default:
+            return false;
+    }
+}
+
+static inline float wsp_ggml_softplus(float input) {
+    return (input > 20.0f) ? input : logf(1 + expf(input));
+}
 //
 // logging
 //
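
Note: softplus(x) = log(1 + e^x). The x > 20 shortcut in the new wsp_ggml_softplus returns the input directly because at that magnitude 1 + e^x is indistinguishable from e^x in single precision (log(1 + e^20) ≈ 20 + 2.1e-9), and it also sidesteps expf overflow for very large inputs.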
@@ -329,6 +345,10 @@ struct wsp_ggml_cgraph {
 // if you need the gradients, get them from the original graph
 struct wsp_ggml_cgraph wsp_ggml_graph_view(struct wsp_ggml_cgraph * cgraph, int i0, int i1);
 
+// ggml-alloc.c: true if the operation can reuse memory from its sources
+WSP_GGML_API bool wsp_ggml_op_can_inplace(enum wsp_ggml_op op);
+
+
 // Memory allocation
 
 WSP_GGML_API void * wsp_ggml_aligned_malloc(size_t size);
@@ -545,14 +565,23 @@ static inline wsp_ggml_bf16_t wsp_ggml_compute_fp32_to_bf16(float s) {
 #define WSP_GGML_FP32_TO_BF16(x) wsp_ggml_compute_fp32_to_bf16(x)
 #define WSP_GGML_BF16_TO_FP32(x) wsp_ggml_compute_bf16_to_fp32(x)
 
+static inline int32_t wsp_ggml_node_get_use_count(const struct wsp_ggml_cgraph * cgraph, int node_idx) {
+    const struct wsp_ggml_tensor * node = cgraph->nodes[node_idx];
+
+    size_t hash_pos = wsp_ggml_hash_find(&cgraph->visited_hash_set, node);
+    if (!wsp_ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos)) {
+        return 0;
+    }
+    return cgraph->use_counts[hash_pos];
+}
+
 // return true if the node's results are only used by N other nodes
 // and can be fused into their calculations.
 static inline bool wsp_ggml_node_has_n_uses(const struct wsp_ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
     const struct wsp_ggml_tensor * node = cgraph->nodes[node_idx];
 
     // check the use count against how many we're replacing
-
-    if (!wsp_ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
+    if (wsp_ggml_node_get_use_count(cgraph, node_idx) != n_uses) {
         return false;
     }
 
@@ -570,27 +599,27 @@ static inline bool wsp_ggml_node_has_n_uses(const struct wsp_ggml_cgrap
     return true;
 }
 
-// Returns true if nodes
+// Returns true if nodes with indices { node_idxs } are the sequence of wsp_ggml_ops in ops[]
 // and are fusable. Nodes are considered fusable according to this function if:
 // - all nodes except the last have only one use and are not views/outputs (see wsp_ggml_node_has_N_uses).
 // - all nodes except the last are a src of the following node.
 // - all nodes are the same shape.
 // TODO: Consider allowing WSP_GGML_OP_NONE nodes in between
-static inline bool
-    if (node_idx + num_ops > cgraph->n_nodes) {
-        return false;
-    }
-
+static inline bool wsp_ggml_can_fuse_ext(const struct wsp_ggml_cgraph * cgraph, const int * node_idxs, const enum wsp_ggml_op * ops, int num_ops) {
     for (int i = 0; i < num_ops; ++i) {
-
+        if (node_idxs[i] >= cgraph->n_nodes) {
+            return false;
+        }
+
+        struct wsp_ggml_tensor * node = cgraph->nodes[node_idxs[i]];
         if (node->op != ops[i]) {
             return false;
         }
-        if (i < num_ops - 1 && !wsp_ggml_node_has_n_uses(cgraph,
+        if (i < num_ops - 1 && !wsp_ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) {
             return false;
         }
         if (i > 0) {
-            struct wsp_ggml_tensor * prev = cgraph->nodes[
+            struct wsp_ggml_tensor * prev = cgraph->nodes[node_idxs[i - 1]];
             if (node->src[0] != prev && node->src[1] != prev) {
                 return false;
             }
@@ -602,6 +631,52 @@ static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int
     return true;
 }
 
+// same as above, for sequential indices starting at node_idx
+static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_idx, const enum wsp_ggml_op * ops, int num_ops) {
+    assert(num_ops < 32);
+
+    if (node_idx + num_ops > cgraph->n_nodes) {
+        return false;
+    }
+
+    int idxs[32];
+    for (int i = 0; i < num_ops; ++i) {
+        idxs[i] = node_idx + i;
+    }
+
+    return wsp_ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
+}
+
+WSP_GGML_API bool wsp_ggml_can_fuse_subgraph_ext(const struct wsp_ggml_cgraph * cgraph,
+                                                 const int * node_idxs,
+                                                 int count,
+                                                 const enum wsp_ggml_op * ops,
+                                                 const int * outputs,
+                                                 int num_outputs);
+
+// Returns true if the subgraph formed by {node_idxs} can be fused
+// checks whethers all nodes which are not part of outputs can be elided
+// by checking if their num_uses are confined to the subgraph
+static inline bool wsp_ggml_can_fuse_subgraph(const struct wsp_ggml_cgraph * cgraph,
+                                              int node_idx,
+                                              int count,
+                                              const enum wsp_ggml_op * ops,
+                                              const int * outputs,
+                                              int num_outputs) {
+    WSP_GGML_ASSERT(count < 32);
+    if (node_idx + count > cgraph->n_nodes) {
+        return false;
+    }
+
+    int idxs[32];
+
+    for (int i = 0; i < count; ++i) {
+        idxs[i] = node_idx + i;
+    }
+
+    return wsp_ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
+}
+
 #ifdef __cplusplus
 }
 #endif
@@ -615,6 +690,13 @@ inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_id
     return wsp_ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
 }
 
+inline bool wsp_ggml_can_fuse_subgraph(const struct wsp_ggml_cgraph * cgraph,
+                                       int start_idx,
+                                       std::initializer_list<enum wsp_ggml_op> ops,
+                                       std::initializer_list<int> outputs = {}) {
+    return wsp_ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
+}
+
 // expose GGUF internals for test code
 WSP_GGML_API size_t wsp_gguf_type_size(enum wsp_gguf_type type);
 WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_from_file_impl(FILE * file, struct wsp_gguf_init_params params);
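
A minimal sketch of how a backend could drive the new fusion helpers from ggml-impl.h (hypothetical caller, not part of this diff; the RMS_NORM + MUL pair is only an example):

// hypothetical: probe for an RMS_NORM followed by MUL before dispatching a fused kernel
static void encode_nodes(struct wsp_ggml_cgraph * cgraph) {
    for (int i = 0; i < cgraph->n_nodes; ++i) {
        const enum wsp_ggml_op ops[2] = { WSP_GGML_OP_RMS_NORM, WSP_GGML_OP_MUL };
        if (wsp_ggml_can_fuse(cgraph, i, ops, 2)) {
            // ... encode a single fused rms_norm*mul kernel for nodes i and i+1 ...
            i += 1; // skip the node that was folded into the fused kernel
        } else {
            // ... encode cgraph->nodes[i] on its own ...
        }
    }
}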

package/cpp/ggml-metal/ggml-metal-common.cpp
ADDED
@@ -0,0 +1,446 @@
+#include "ggml-metal-common.h"
+
+#include "ggml-impl.h"
+#include "ggml-backend-impl.h"
+
+#include <vector>
+
+// represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb)
+// the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it)
+struct wsp_ggml_mem_range {
+    uint64_t pb; // buffer id
+
+    uint64_t p0; // begin
+    uint64_t p1; // end
+
+    wsp_ggml_mem_range_type pt;
+};
+
+struct wsp_ggml_mem_ranges {
+    std::vector<wsp_ggml_mem_range> ranges;
+
+    int debug = 0;
+};
+
+wsp_ggml_mem_ranges_t wsp_ggml_mem_ranges_init(int debug) {
+    auto * res = new wsp_ggml_mem_ranges;
+
+    res->ranges.reserve(256);
+    res->debug = debug;
+
+    return res;
+}
+
+void wsp_ggml_mem_ranges_free(wsp_ggml_mem_ranges_t mrs) {
+    delete mrs;
+}
+
+void wsp_ggml_mem_ranges_reset(wsp_ggml_mem_ranges_t mrs) {
+    mrs->ranges.clear();
+}
+
+static bool wsp_ggml_mem_ranges_add(wsp_ggml_mem_ranges_t mrs, wsp_ggml_mem_range mr) {
+    mrs->ranges.push_back(mr);
+
+    return true;
+}
+
+static wsp_ggml_mem_range wsp_ggml_mem_range_from_tensor(const wsp_ggml_tensor * tensor, wsp_ggml_mem_range_type pt) {
+    // always use the base tensor
+    tensor = tensor->view_src ? tensor->view_src : tensor;
+
+    WSP_GGML_ASSERT(!tensor->view_src);
+
+    wsp_ggml_mem_range mr;
+
+    if (tensor->buffer) {
+        // when the tensor is allocated, use the actual memory address range in the buffer
+        //
+        // take the actual allocated size with wsp_ggml_backend_buft_get_alloc_size()
+        // this can be larger than the tensor size if the buffer type allocates extra memory
+        // ref: https://github.com/ggml-org/llama.cpp/pull/15966
+        mr = {
+            /*.pb =*/ (uint64_t) tensor->buffer,
+            /*.p0 =*/ (uint64_t) tensor->data,
+            /*.p1 =*/ (uint64_t) tensor->data + wsp_ggml_backend_buft_get_alloc_size(tensor->buffer->buft, tensor),
+            /*.pt =*/ pt,
+        };
+    } else {
+        // otherwise, the pointer address is used as an unique id of the memory ranges
+        // that the tensor will be using when it is allocated
+        mr = {
+            /*.pb =*/ (uint64_t) tensor,
+            /*.p0 =*/ 0, //
+            /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used
+            /*.pt =*/ pt,
+        };
+    };
+
+    return mr;
+}
+
+static wsp_ggml_mem_range wsp_ggml_mem_range_from_tensor_src(const wsp_ggml_tensor * tensor) {
+    return wsp_ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_SRC);
+}
+
+static wsp_ggml_mem_range wsp_ggml_mem_range_from_tensor_dst(const wsp_ggml_tensor * tensor) {
+    return wsp_ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST);
+}
+
+static bool wsp_ggml_mem_ranges_add_src(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
+    WSP_GGML_ASSERT(tensor);
+
+    wsp_ggml_mem_range mr = wsp_ggml_mem_range_from_tensor_src(tensor);
+
+    if (mrs->debug > 2) {
+        WSP_GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
+    }
+
+    return wsp_ggml_mem_ranges_add(mrs, mr);
+}
+
+static bool wsp_ggml_mem_ranges_add_dst(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
+    WSP_GGML_ASSERT(tensor);
+
+    wsp_ggml_mem_range mr = wsp_ggml_mem_range_from_tensor_dst(tensor);
+
+    if (mrs->debug > 2) {
+        WSP_GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
+    }
+
+    return wsp_ggml_mem_ranges_add(mrs, mr);
+}
+
+bool wsp_ggml_mem_ranges_add(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
+    for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+        if (tensor->src[i]) {
+            wsp_ggml_mem_ranges_add_src(mrs, tensor->src[i]);
+        }
+    }
+
+    return wsp_ggml_mem_ranges_add_dst(mrs, tensor);
+}
+
+static bool wsp_ggml_mem_ranges_check(wsp_ggml_mem_ranges_t mrs, wsp_ggml_mem_range mr) {
+    for (size_t i = 0; i < mrs->ranges.size(); i++) {
+        const auto & cmp = mrs->ranges[i];
+
+        // two memory ranges cannot intersect if they are in different buffers
+        if (mr.pb != cmp.pb) {
+            continue;
+        }
+
+        // intersecting source ranges are allowed
+        if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
+            continue;
+        }
+
+        if (mr.p0 < cmp.p1 && mr.p1 >= cmp.p0) {
+            if (mrs->debug > 2) {
+                WSP_GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n",
+                        __func__,
+                        mr.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
+                        mr.pb, mr.p0, mr.p1,
+                        cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
+                        cmp.pb, cmp.p0, cmp.p1);
+            }
+
+            return false;
+        }
+    }
+
+    return true;
+}
+
+static bool wsp_ggml_mem_ranges_check_src(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
+    WSP_GGML_ASSERT(tensor);
+
+    wsp_ggml_mem_range mr = wsp_ggml_mem_range_from_tensor_src(tensor);
+
+    const bool res = wsp_ggml_mem_ranges_check(mrs, mr);
+
+    return res;
+}
+
+static bool wsp_ggml_mem_ranges_check_dst(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
+    WSP_GGML_ASSERT(tensor);
+
+    wsp_ggml_mem_range mr = wsp_ggml_mem_range_from_tensor_dst(tensor);
+
+    const bool res = wsp_ggml_mem_ranges_check(mrs, mr);
+
+    return res;
+}
+
+bool wsp_ggml_mem_ranges_check(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
+    for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+        if (tensor->src[i]) {
+            if (!wsp_ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
+                return false;
+            }
+        }
+    }
+
+    return wsp_ggml_mem_ranges_check_dst(mrs, tensor);
+}
+
+struct node_info {
+    wsp_ggml_tensor * node;
+
+    std::vector<wsp_ggml_tensor *> fused;
+
+    wsp_ggml_op op() const {
+        return node->op;
+    }
+
+    const wsp_ggml_tensor * dst() const {
+        return fused.empty() ? node : fused.back();
+    }
+
+    bool is_empty() const {
+        return wsp_ggml_op_is_empty(node->op);
+    }
+
+    void add_fused(wsp_ggml_tensor * t) {
+        fused.push_back(t);
+    }
+};
+
+static std::vector<int> wsp_ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
+    // helper to add node src and dst ranges
+    const auto & h_add = [](wsp_ggml_mem_ranges_t mrs, const node_info & node) {
+        for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+            if (node.node->src[i]) {
+                if (!wsp_ggml_mem_ranges_add_src(mrs, node.node->src[i])) {
+                    return false;
+                }
+            }
+        }
+
+        // keep track of the sources of the fused nodes as well
+        for (const auto * fused : node.fused) {
+            for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+                if (fused->src[i]) {
+                    if (!wsp_ggml_mem_ranges_add_src(mrs, fused->src[i])) {
+                        return false;
+                    }
+                }
+            }
+        }
+
+        return wsp_ggml_mem_ranges_add_dst(mrs, node.dst());
+    };
+
+    // helper to check if a node can run concurrently with the existing set of nodes
+    const auto & h_check = [](wsp_ggml_mem_ranges_t mrs, const node_info & node) {
+        for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+            if (node.node->src[i]) {
+                if (!wsp_ggml_mem_ranges_check_src(mrs, node.node->src[i])) {
+                    return false;
+                }
+            }
+        }
+
+        for (const auto * fused : node.fused) {
+            for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+                if (fused->src[i]) {
+                    if (!wsp_ggml_mem_ranges_check_src(mrs, fused->src[i])) {
+                        return false;
+                    }
+                }
+            }
+        }
+
+        return wsp_ggml_mem_ranges_check_dst(mrs, node.dst());
+    };
+
+    // perform reorders only across these types of ops
+    // can be expanded when needed
+    const auto & h_safe = [](wsp_ggml_op op) {
+        switch (op) {
+            case WSP_GGML_OP_MUL_MAT:
+            case WSP_GGML_OP_MUL_MAT_ID:
+            case WSP_GGML_OP_ROPE:
+            case WSP_GGML_OP_NORM:
+            case WSP_GGML_OP_RMS_NORM:
+            case WSP_GGML_OP_GROUP_NORM:
+            case WSP_GGML_OP_SUM_ROWS:
+            case WSP_GGML_OP_MUL:
+            case WSP_GGML_OP_ADD:
+            case WSP_GGML_OP_DIV:
+            case WSP_GGML_OP_GLU:
+            case WSP_GGML_OP_SCALE:
+            case WSP_GGML_OP_GET_ROWS:
+            case WSP_GGML_OP_CPY:
+            case WSP_GGML_OP_SET_ROWS:
+                return true;
+            default:
+                return wsp_ggml_op_is_empty(op);
+        }
+    };
+
+    const int n = nodes.size();
+
+    std::vector<int> res;
+    res.reserve(n);
+
+    std::vector<bool> used(n, false);
+
+    // the memory ranges for the set of currently concurrent nodes
+    wsp_ggml_mem_ranges_t mrs0 = wsp_ggml_mem_ranges_init(0);
+
+    // the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder
+    wsp_ggml_mem_ranges_t mrs1 = wsp_ggml_mem_ranges_init(0);
+
+    for (int i0 = 0; i0 < n; i0++) {
+        if (used[i0]) {
+            continue;
+        }
+
+        const auto & node0 = nodes[i0];
+
+        // the node is not concurrent with the existing concurrent set, so we have to "put a barrier" (i.e reset mrs0)
+        // but before we do that, look forward for some other nodes that can be added to the concurrent set mrs0
+        //
+        // note: we can always add empty nodes to the concurrent set as they don't read nor write anything
+        if (!node0.is_empty() && !h_check(mrs0, node0)) {
+            // this will hold the set of memory ranges from the nodes that haven't been processed yet
+            // if a node is not concurrent with this set, we cannot reorder it
+            wsp_ggml_mem_ranges_reset(mrs1);
+
+            // initialize it with the current node
+            h_add(mrs1, node0);
+
+            // that many nodes forward to search for a concurrent node
+            constexpr int N_FORWARD = 8;
+
+            for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
+                if (used[i1]) {
+                    continue;
+                }
+
+                const auto & node1 = nodes[i1];
+
+                // disallow reordering of certain ops
+                if (!h_safe(node1.op())) {
+                    break;
+                }
+
+                const bool is_empty = node1.is_empty();
+
+                // to reorder a node and add it to the concurrent set, it has to be:
+                // + empty or concurrent with all nodes in the existing concurrent set (mrs0)
+                // + concurrent with all nodes prior to it that haven't been processed yet (mrs1)
+                if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {
+                    // add the node to the existing concurrent set (i.e. reorder it for early execution)
+                    h_add(mrs0, node1);
+                    res.push_back(i1);
+
+                    // mark as used, so we skip re-processing it later
+                    used[i1] = true;
+                } else {
+                    // expand the set of nodes that haven't been processed yet
+                    h_add(mrs1, node1);
+                }
+            }
+
+            // finalize the concurrent set and begin a new one
+            wsp_ggml_mem_ranges_reset(mrs0);
+        }
+
+        // expand the concurrent set with the current node
+        {
+            h_add(mrs0, node0);
+            res.push_back(i0);
+        }
+    }
+
+    wsp_ggml_mem_ranges_free(mrs0);
+    wsp_ggml_mem_ranges_free(mrs1);
+
+    return res;
+}
+
+void wsp_ggml_graph_optimize(wsp_ggml_cgraph * gf) {
+    constexpr int MAX_FUSE = 16;
+
+    const int n = gf->n_nodes;
+
+    enum wsp_ggml_op ops[MAX_FUSE];
+
+    std::vector<node_info> nodes;
+    nodes.reserve(gf->n_nodes);
+
+    // fuse nodes:
+    // we don't want to make reorders that break fusing, so we first pack all fusable tensors
+    // and perform the reorder over the fused nodes. after the reorder is done, we unfuse
+    for (int i = 0; i < n; i++) {
+        node_info node = {
+            /*.node =*/ gf->nodes[i],
+            /*.fused =*/ {},
+        };
+
+        // fuse only ops that start with these operations
+        // can be expanded when needed
+        if (node.op() == WSP_GGML_OP_ADD ||
+            node.op() == WSP_GGML_OP_NORM ||
+            node.op() == WSP_GGML_OP_RMS_NORM) {
+            ops[0] = node.op();
+
+            int f = i + 1;
+            while (f < n && f < i + MAX_FUSE) {
+                // conservatively allow fusing only these ops
+                // can be expanded when needed
+                if (gf->nodes[f]->op != WSP_GGML_OP_ADD &&
+                    gf->nodes[f]->op != WSP_GGML_OP_MUL &&
+                    gf->nodes[f]->op != WSP_GGML_OP_NORM &&
+                    gf->nodes[f]->op != WSP_GGML_OP_RMS_NORM) {
+                    break;
+                }
+                ops[f - i] = gf->nodes[f]->op;
+                f++;
+            }
+
+            f -= i;
+            for (; f > 1; f--) {
+                if (wsp_ggml_can_fuse(gf, i, ops, f)) {
+                    break;
+                }
+            }
+
+            // add the fused tensors into the node info so we can unfuse them later
+            for (int k = 1; k < f; k++) {
+                ++i;
+
+                // the .dst() becomes the last fused tensor
+                node.add_fused(gf->nodes[i]);
+            }
+        }
+
+        nodes.push_back(std::move(node));
+    }
+
+#if 1
+    // reorder to improve concurrency
+    const auto order = wsp_ggml_metal_graph_optimize_reorder(nodes);
+#else
+    std::vector<int> order(nodes.size());
+    for (size_t i = 0; i < nodes.size(); i++) {
+        order[i] = i;
+    }
+#endif
+
+    // unfuse
+    {
+        int j = 0;
+        for (const auto i : order) {
+            const auto & node = nodes[i];
+
+            gf->nodes[j++] = node.node;
+
+            for (auto * fused : node.fused) {
+                gf->nodes[j++] = fused;
+            }
+        }
+    }
+}
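
An illustrative trace of the reorder above (constructed for this summary, not taken from the diff): for nodes [M0, A0, M1], where A0 reads M0's output and M1 is an unrelated MUL_MAT, A0 fails h_check against the current concurrent set {M0}; the N_FORWARD look-ahead then finds M1, which passes h_check against both mrs0 ({M0}) and mrs1 ({A0}), so M1 is pulled forward and the emitted order is [M0, M1, A0], letting both matmuls be encoded without a barrier between them.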

package/cpp/ggml-metal/ggml-metal-common.h
ADDED
@@ -0,0 +1,52 @@
+// helper functions for ggml-metal that are too difficult to implement in Objective-C
+
+#pragma once
+
+#include <stdbool.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct wsp_ggml_tensor;
+struct wsp_ggml_cgraph;
+
+enum wsp_ggml_mem_range_type {
+    MEM_RANGE_TYPE_SRC = 0,
+    MEM_RANGE_TYPE_DST = 1,
+};
+
+// a helper object that can be used for reordering operations to improve concurrency
+//
+// the fundamental idea is that a set of tasks (either ggml ops, or something else) can run concurrently if they
+// don't write to a memory that is being read by another task or written to by another task in the set
+//
+// with this structure, we can add tasks to the set, setting memory constraints. we can also check if a new task
+// can be added to the set without violating the constraints (i.e. if it can be executed concurrently with the
+// tasks already in the set)
+//
+typedef struct wsp_ggml_mem_ranges * wsp_ggml_mem_ranges_t;
+
+wsp_ggml_mem_ranges_t wsp_ggml_mem_ranges_init(int debug);
+void wsp_ggml_mem_ranges_free(wsp_ggml_mem_ranges_t mrs);
+
+// remove all ranges from the set
+void wsp_ggml_mem_ranges_reset(wsp_ggml_mem_ranges_t mrs);
+
+// add src or dst ranges to track
+bool wsp_ggml_mem_ranges_add(wsp_ggml_mem_ranges_t mrs, const struct wsp_ggml_tensor * tensor);
+
+// return false if:
+// - new src range overlaps with any existing dst range
+// - new dst range overlaps with any existing range (src or dst)
+bool wsp_ggml_mem_ranges_check(wsp_ggml_mem_ranges_t mrs, const struct wsp_ggml_tensor * tensor);
+
+// reorder the nodes in the graph to improve concurrency, while respecting fusion
+//
+// note: this implementation is generic and not specific to metal
+// if it proves to work well, we can start using it for other backends in the future
+void wsp_ggml_graph_optimize(struct wsp_ggml_cgraph * gf);
+
+#ifdef __cplusplus
+}
+#endif
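
A rough usage sketch of this API (hypothetical, not part of the package; gf is assumed to be a struct wsp_ggml_cgraph *): consecutive nodes are accumulated into one concurrent batch until a conflict forces a barrier.

// hypothetical: group consecutive graph nodes into concurrent batches
wsp_ggml_mem_ranges_t mrs = wsp_ggml_mem_ranges_init(/*debug =*/ 0);

for (int i = 0; i < gf->n_nodes; ++i) {
    struct wsp_ggml_tensor * node = gf->nodes[i];

    if (!wsp_ggml_mem_ranges_check(mrs, node)) {
        // node reads or writes memory touched by the current batch:
        // a barrier would be emitted here, then a fresh batch is tracked
        wsp_ggml_mem_ranges_reset(mrs);
    }

    wsp_ggml_mem_ranges_add(mrs, node);
}

wsp_ggml_mem_ranges_free(mrs);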

package/cpp/ggml-metal/ggml-metal-context.h
ADDED
@@ -0,0 +1,33 @@
+#pragma once
+
+#include "ggml-metal-device.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//
+// backend context
+//
+
+typedef struct wsp_ggml_metal * wsp_ggml_metal_t;
+
+wsp_ggml_metal_t wsp_ggml_metal_init(wsp_ggml_metal_device_t dev);
+void wsp_ggml_metal_free(wsp_ggml_metal_t ctx);
+
+void wsp_ggml_metal_synchronize(wsp_ggml_metal_t ctx);
+
+void wsp_ggml_metal_set_tensor_async(wsp_ggml_metal_t ctx, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+void wsp_ggml_metal_get_tensor_async(wsp_ggml_metal_t ctx, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size);
+
+enum wsp_ggml_status wsp_ggml_metal_graph_compute (wsp_ggml_metal_t ctx, struct wsp_ggml_cgraph * gf);
+void wsp_ggml_metal_graph_optimize(wsp_ggml_metal_t ctx, struct wsp_ggml_cgraph * gf);
+
+void wsp_ggml_metal_set_n_cb            (wsp_ggml_metal_t ctx, int n_cb);
+void wsp_ggml_metal_set_abort_callback  (wsp_ggml_metal_t ctx, wsp_ggml_abort_callback abort_callback, void * user_data);
+bool wsp_ggml_metal_supports_family     (wsp_ggml_metal_t ctx, int family);
+void wsp_ggml_metal_capture_next_compute(wsp_ggml_metal_t ctx);
+
+#ifdef __cplusplus
+}
+#endif