whisper.rn 0.4.0-rc.7 → 0.4.0-rc.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. package/android/src/main/CMakeLists.txt +2 -1
  2. package/android/src/main/java/com/rnwhisper/AudioUtils.java +27 -12
  3. package/android/src/main/java/com/rnwhisper/RNWhisper.java +75 -34
  4. package/android/src/main/java/com/rnwhisper/WhisperContext.java +20 -3
  5. package/android/src/main/jni.cpp +29 -1
  6. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  7. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  8. package/cpp/coreml/whisper-encoder.mm +1 -1
  9. package/cpp/ggml-aarch64.c +3209 -0
  10. package/cpp/ggml-aarch64.h +39 -0
  11. package/cpp/ggml-alloc.c +732 -494
  12. package/cpp/ggml-alloc.h +47 -63
  13. package/cpp/ggml-backend-impl.h +162 -47
  14. package/cpp/ggml-backend.cpp +2635 -0
  15. package/cpp/ggml-backend.h +216 -71
  16. package/cpp/ggml-common.h +1853 -0
  17. package/cpp/ggml-cpu-impl.h +614 -0
  18. package/cpp/ggml-impl.h +144 -178
  19. package/cpp/ggml-metal.h +14 -60
  20. package/cpp/ggml-metal.m +3437 -2097
  21. package/cpp/ggml-quants.c +12559 -4189
  22. package/cpp/ggml-quants.h +135 -212
  23. package/cpp/ggml-whisper.metallib +0 -0
  24. package/cpp/ggml.c +9029 -5219
  25. package/cpp/ggml.h +673 -338
  26. package/cpp/rn-whisper.cpp +91 -0
  27. package/cpp/rn-whisper.h +2 -0
  28. package/cpp/whisper.cpp +1476 -675
  29. package/cpp/whisper.h +84 -28
  30. package/ios/RNWhisper.mm +124 -37
  31. package/ios/RNWhisperAudioUtils.h +1 -0
  32. package/ios/RNWhisperAudioUtils.m +20 -13
  33. package/ios/RNWhisperContext.h +3 -2
  34. package/ios/RNWhisperContext.mm +41 -8
  35. package/jest/mock.js +9 -1
  36. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  37. package/lib/commonjs/index.js +48 -19
  38. package/lib/commonjs/index.js.map +1 -1
  39. package/lib/commonjs/version.json +1 -1
  40. package/lib/module/NativeRNWhisper.js.map +1 -1
  41. package/lib/module/index.js +48 -19
  42. package/lib/module/index.js.map +1 -1
  43. package/lib/module/version.json +1 -1
  44. package/lib/typescript/NativeRNWhisper.d.ts +6 -3
  45. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  46. package/lib/typescript/index.d.ts +25 -3
  47. package/lib/typescript/index.d.ts.map +1 -1
  48. package/package.json +6 -5
  49. package/src/NativeRNWhisper.ts +12 -3
  50. package/src/index.ts +63 -24
  51. package/src/version.json +1 -1
  52. package/whisper-rn.podspec +9 -2
  53. package/cpp/ggml-backend.c +0 -1357
  54. package/cpp/ggml-metal-whisper.metal +0 -4908
package/cpp/ggml-impl.h CHANGED
@@ -1,23 +1,32 @@
1
1
  #pragma once
2
2
 
3
- #include "ggml.h"
4
-
5
3
  // GGML internal header
6
4
 
5
+ #include "ggml.h"
6
+
7
7
  #include <assert.h>
8
- #include <stddef.h>
8
+ #include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
9
9
  #include <stdbool.h>
10
- #include <string.h> // memcpy
11
- #include <math.h> // fabsf
10
+ #include <stdint.h>
12
11
 
13
12
  #ifdef __cplusplus
14
13
  extern "C" {
15
14
  #endif
16
15
 
16
+ #undef MIN
17
+ #undef MAX
18
+
19
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
20
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
21
+
22
+ // required for mmap as gguf only guarantees 32-byte alignment
23
+ #define TENSOR_ALIGNMENT 32
24
+
17
25
  // static_assert should be a #define, but if it's not,
18
26
  // fall back to the _Static_assert C11 keyword.
19
27
  // if C99 - static_assert is noop
20
28
  // ref: https://stackoverflow.com/a/53923785/4039976
29
+ #ifndef __cplusplus
21
30
  #ifndef static_assert
22
31
  #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
23
32
  #define static_assert(cond, msg) _Static_assert(cond, msg)
@@ -25,218 +34,175 @@ extern "C" {
25
34
  #define static_assert(cond, msg) struct global_scope_noop_trick
26
35
  #endif
27
36
  #endif
28
-
29
- // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
30
- #if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
31
- #ifndef __FMA__
32
- #define __FMA__
33
- #endif
34
- #ifndef __F16C__
35
- #define __F16C__
36
- #endif
37
- #ifndef __SSE3__
38
- #define __SSE3__
39
- #endif
40
37
  #endif
41
38
 
42
- // 16-bit float
43
- // on Arm, we use __fp16
44
- // on x86, we use uint16_t
45
- #if defined(__ARM_NEON) && !defined(_MSC_VER)
46
-
47
- // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
48
39
  //
49
- // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
40
+ // logging
50
41
  //
51
- #include <arm_neon.h>
52
42
 
53
- #define WSP_GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
54
- #define WSP_GGML_COMPUTE_FP32_TO_FP16(x) (x)
43
+ WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
44
+ void wsp_ggml_log_internal (enum wsp_ggml_log_level level, const char * format, ...);
45
+ void wsp_ggml_log_callback_default(enum wsp_ggml_log_level level, const char * text, void * user_data);
55
46
 
56
- #define WSP_GGML_FP16_TO_FP32(x) ((float) (x))
57
- #define WSP_GGML_FP32_TO_FP16(x) (x)
47
+ #define WSP_GGML_LOG(...) wsp_ggml_log_internal(WSP_GGML_LOG_LEVEL_NONE , __VA_ARGS__)
48
+ #define WSP_GGML_LOG_INFO(...) wsp_ggml_log_internal(WSP_GGML_LOG_LEVEL_INFO , __VA_ARGS__)
49
+ #define WSP_GGML_LOG_WARN(...) wsp_ggml_log_internal(WSP_GGML_LOG_LEVEL_WARN , __VA_ARGS__)
50
+ #define WSP_GGML_LOG_ERROR(...) wsp_ggml_log_internal(WSP_GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
51
+ #define WSP_GGML_LOG_DEBUG(...) wsp_ggml_log_internal(WSP_GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
52
+ #define WSP_GGML_LOG_CONT(...) wsp_ggml_log_internal(WSP_GGML_LOG_LEVEL_CONT , __VA_ARGS__)
58
53
 
59
- #else
54
+ // bitset
60
55
 
61
- #ifdef __wasm_simd128__
62
- #include <wasm_simd128.h>
63
- #else
64
- #ifdef __POWER9_VECTOR__
65
- #include <altivec.h>
66
- #undef bool
67
- #define bool _Bool
68
- #else
69
- #if defined(_MSC_VER) || defined(__MINGW32__)
70
- #include <intrin.h>
71
- #else
72
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
73
- #if !defined(__riscv)
74
- #include <immintrin.h>
75
- #endif
76
- #endif
77
- #endif
78
- #endif
79
- #endif
56
+ typedef uint32_t wsp_ggml_bitset_t;
80
57
 
81
- #ifdef __riscv_v_intrinsic
82
- #include <riscv_vector.h>
83
- #endif
58
+ static_assert(sizeof(wsp_ggml_bitset_t) == 4, "bitset_t constants must be updated");
59
+ #define BITSET_SHR 5 // log2(sizeof(wsp_ggml_bitset_t)*8)
60
+ #define BITSET_MASK (sizeof(wsp_ggml_bitset_t)*8 - 1)
84
61
 
85
- #ifdef __F16C__
62
+ static size_t wsp_ggml_bitset_size(size_t n) {
63
+ return (n + BITSET_MASK) >> BITSET_SHR;
64
+ }
86
65
 
87
- #ifdef _MSC_VER
88
- #define WSP_GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
89
- #define WSP_GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
90
- #else
91
- #define WSP_GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
92
- #define WSP_GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
93
- #endif
66
+ static inline bool wsp_ggml_bitset_get(const wsp_ggml_bitset_t * bitset, size_t i) {
67
+ return !!(bitset[i >> BITSET_SHR] & (1u << (i & BITSET_MASK)));
68
+ }
94
69
 
95
- #elif defined(__POWER9_VECTOR__)
96
-
97
- #define WSP_GGML_COMPUTE_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x)
98
- #define WSP_GGML_COMPUTE_FP32_TO_FP16(x) wsp_ggml_compute_fp32_to_fp16(x)
99
- /* the inline asm below is about 12% faster than the lookup method */
100
- #define WSP_GGML_FP16_TO_FP32(x) WSP_GGML_COMPUTE_FP16_TO_FP32(x)
101
- #define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x)
102
-
103
- static inline float wsp_ggml_compute_fp16_to_fp32(wsp_ggml_fp16_t h) {
104
- register float f;
105
- register double d;
106
- __asm__(
107
- "mtfprd %0,%2\n"
108
- "xscvhpdp %0,%0\n"
109
- "frsp %1,%0\n" :
110
- /* temp */ "=d"(d),
111
- /* out */ "=f"(f):
112
- /* in */ "r"(h));
113
- return f;
70
+ static inline void wsp_ggml_bitset_set(wsp_ggml_bitset_t * bitset, size_t i) {
71
+ bitset[i >> BITSET_SHR] |= (1u << (i & BITSET_MASK));
114
72
  }
115
73
 
116
- static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
117
- register double d;
118
- register wsp_ggml_fp16_t r;
119
- __asm__( /* xscvdphp can work on double or single precision */
120
- "xscvdphp %0,%2\n"
121
- "mffprd %1,%0\n" :
122
- /* temp */ "=d"(d),
123
- /* out */ "=r"(r):
124
- /* in */ "f"(f));
125
- return r;
74
+ static inline void wsp_ggml_bitset_clear(wsp_ggml_bitset_t * bitset, size_t i) {
75
+ bitset[i >> BITSET_SHR] &= ~(1u << (i & BITSET_MASK));
126
76
  }
127
77
 
128
- #else
78
+ // hash set
129
79
 
130
- // FP16 <-> FP32
131
- // ref: https://github.com/Maratyszcza/FP16
80
+ #define WSP_GGML_HASHSET_FULL ((size_t)-1)
81
+ #define WSP_GGML_HASHSET_ALREADY_EXISTS ((size_t)-2)
132
82
 
133
- static inline float fp32_from_bits(uint32_t w) {
134
- union {
135
- uint32_t as_bits;
136
- float as_value;
137
- } fp32;
138
- fp32.as_bits = w;
139
- return fp32.as_value;
140
- }
83
+ struct wsp_ggml_hash_set {
84
+ size_t size;
85
+ wsp_ggml_bitset_t * used; // whether or not the keys are in use i.e. set
86
+ struct wsp_ggml_tensor ** keys; // actual tensors in the set, keys[i] is only defined if wsp_ggml_bitset_get(used, i)
87
+ };
141
88
 
142
- static inline uint32_t fp32_to_bits(float f) {
143
- union {
144
- float as_value;
145
- uint32_t as_bits;
146
- } fp32;
147
- fp32.as_value = f;
148
- return fp32.as_bits;
149
- }
89
+ struct wsp_ggml_hash_set wsp_ggml_hash_set_new(size_t size);
90
+ void wsp_ggml_hash_set_free(struct wsp_ggml_hash_set * hash_set);
150
91
 
151
- static inline float wsp_ggml_compute_fp16_to_fp32(wsp_ggml_fp16_t h) {
152
- const uint32_t w = (uint32_t) h << 16;
153
- const uint32_t sign = w & UINT32_C(0x80000000);
154
- const uint32_t two_w = w + w;
92
+ // returns the minimum size for a hash set that can hold min_sz elements
93
+ size_t wsp_ggml_hash_size(size_t min_sz);
155
94
 
156
- const uint32_t exp_offset = UINT32_C(0xE0) << 23;
157
- #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
158
- const float exp_scale = 0x1.0p-112f;
159
- #else
160
- const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
161
- #endif
162
- const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
95
+ // remove all elements from the hash set
96
+ void wsp_ggml_hash_set_reset(struct wsp_ggml_hash_set * hash_set);
163
97
 
164
- const uint32_t magic_mask = UINT32_C(126) << 23;
165
- const float magic_bias = 0.5f;
166
- const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
98
+ // returns true if key is in the hash set
99
+ static bool wsp_ggml_hash_contains(const struct wsp_ggml_hash_set * hash_set, struct wsp_ggml_tensor * key);
167
100
 
168
- const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
169
- const uint32_t result = sign |
170
- (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
171
- return fp32_from_bits(result);
172
- }
101
+ // returns WSP_GGML_HASHSET_FULL if table is full, otherwise the current index of the key or where it should be inserted
102
+ static size_t wsp_ggml_hash_find(const struct wsp_ggml_hash_set * hash_set, struct wsp_ggml_tensor * key);
173
103
 
174
- static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
175
- #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
176
- const float scale_to_inf = 0x1.0p+112f;
177
- const float scale_to_zero = 0x1.0p-110f;
178
- #else
179
- const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
180
- const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
181
- #endif
182
- float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
183
-
184
- const uint32_t w = fp32_to_bits(f);
185
- const uint32_t shl1_w = w + w;
186
- const uint32_t sign = w & UINT32_C(0x80000000);
187
- uint32_t bias = shl1_w & UINT32_C(0xFF000000);
188
- if (bias < UINT32_C(0x71000000)) {
189
- bias = UINT32_C(0x71000000);
190
- }
104
+ // returns WSP_GGML_HASHSET_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
105
+ static size_t wsp_ggml_hash_insert(struct wsp_ggml_hash_set * hash_set, struct wsp_ggml_tensor * key);
106
+
107
+ // return index, asserts if table is full
108
+ static size_t wsp_ggml_hash_find_or_insert(struct wsp_ggml_hash_set * hash_set, struct wsp_ggml_tensor * key);
191
109
 
192
- base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
193
- const uint32_t bits = fp32_to_bits(base);
194
- const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
195
- const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
196
- const uint32_t nonsign = exp_bits + mantissa_bits;
197
- return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
110
+ // hash function for wsp_ggml_tensor
111
+ static inline size_t wsp_ggml_hash(const struct wsp_ggml_tensor * p) {
112
+ // the last 4 bits are always zero due to alignment
113
+ return (size_t)(uintptr_t)p >> 4;
198
114
  }
199
115
 
200
- #define WSP_GGML_COMPUTE_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x)
201
- #define WSP_GGML_COMPUTE_FP32_TO_FP16(x) wsp_ggml_compute_fp32_to_fp16(x)
116
+ static size_t wsp_ggml_hash_find(const struct wsp_ggml_hash_set * hash_set, struct wsp_ggml_tensor * key) {
117
+ size_t h = wsp_ggml_hash(key) % hash_set->size;
118
+
119
+ // linear probing
120
+ size_t i = h;
121
+ while (wsp_ggml_bitset_get(hash_set->used, i) && hash_set->keys[i] != key) {
122
+ i = (i + 1) % hash_set->size;
123
+ if (i == h) {
124
+ // visited all hash table entries -> not found
125
+ return WSP_GGML_HASHSET_FULL;
126
+ }
127
+ }
128
+ return i;
129
+ }
202
130
 
203
- #endif // __F16C__
131
+ static bool wsp_ggml_hash_contains(const struct wsp_ggml_hash_set * hash_set, struct wsp_ggml_tensor * key) {
132
+ size_t i = wsp_ggml_hash_find(hash_set, key);
133
+ return i != WSP_GGML_HASHSET_FULL && wsp_ggml_bitset_get(hash_set->used, i);
134
+ }
204
135
 
205
- #endif // __ARM_NEON
136
+ static size_t wsp_ggml_hash_insert(struct wsp_ggml_hash_set * hash_set, struct wsp_ggml_tensor * key) {
137
+ size_t h = wsp_ggml_hash(key) % hash_set->size;
138
+
139
+ // linear probing
140
+ size_t i = h;
141
+ do {
142
+ if (!wsp_ggml_bitset_get(hash_set->used, i)) {
143
+ wsp_ggml_bitset_set(hash_set->used, i);
144
+ hash_set->keys[i] = key;
145
+ return i;
146
+ }
147
+ if (hash_set->keys[i] == key) {
148
+ return WSP_GGML_HASHSET_ALREADY_EXISTS;
149
+ }
150
+ i = (i + 1) % hash_set->size;
151
+ } while (i != h);
152
+
153
+ // visited all hash table entries -> not found
154
+ WSP_GGML_ABORT("fatal error");
155
+ }
206
156
 
207
- // precomputed f32 table for f16 (256 KB)
208
- // defined in ggml.c, initialized in wsp_ggml_init()
209
- extern float wsp_ggml_table_f32_f16[1 << 16];
157
+ static size_t wsp_ggml_hash_find_or_insert(struct wsp_ggml_hash_set * hash_set, struct wsp_ggml_tensor * key) {
158
+ size_t h = wsp_ggml_hash(key) % hash_set->size;
159
+
160
+ // linear probing
161
+ size_t i = h;
162
+ do {
163
+ if (!wsp_ggml_bitset_get(hash_set->used, i)) {
164
+ wsp_ggml_bitset_set(hash_set->used, i);
165
+ hash_set->keys[i] = key;
166
+ return i;
167
+ }
168
+ if (hash_set->keys[i] == key) {
169
+ return i;
170
+ }
171
+ i = (i + 1) % hash_set->size;
172
+ } while (i != h);
173
+
174
+ // visited all hash table entries -> not found
175
+ WSP_GGML_ABORT("fatal error");
176
+ }
210
177
 
211
- // On ARM NEON, it's quicker to directly convert x -> x instead of calling into wsp_ggml_lookup_fp16_to_fp32,
212
- // so we define WSP_GGML_FP16_TO_FP32 and WSP_GGML_FP32_TO_FP16 elsewhere for NEON.
213
- // This is also true for POWER9.
214
- #if !defined(WSP_GGML_FP16_TO_FP32) || !defined(WSP_GGML_FP32_TO_FP16)
178
+ // computation graph
215
179
 
216
- inline static float wsp_ggml_lookup_fp16_to_fp32(wsp_ggml_fp16_t f) {
217
- uint16_t s;
218
- memcpy(&s, &f, sizeof(uint16_t));
219
- return wsp_ggml_table_f32_f16[s];
220
- }
180
+ enum wsp_ggml_cgraph_eval_order {
181
+ WSP_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
182
+ WSP_GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
183
+ WSP_GGML_CGRAPH_EVAL_ORDER_COUNT
184
+ };
221
185
 
222
- #define WSP_GGML_FP16_TO_FP32(x) wsp_ggml_lookup_fp16_to_fp32(x)
223
- #define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x)
186
+ struct wsp_ggml_cgraph {
187
+ int size;
188
+ int n_nodes;
189
+ int n_leafs;
224
190
 
225
- #endif
191
+ struct wsp_ggml_tensor ** nodes;
192
+ struct wsp_ggml_tensor ** grads;
193
+ struct wsp_ggml_tensor ** leafs;
226
194
 
227
- #define WSP_GGML_HASHTABLE_FULL ((size_t)-1)
228
- #define WSP_GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2)
195
+ struct wsp_ggml_hash_set visited_hash_set;
229
196
 
230
- bool wsp_ggml_hash_contains (const struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor * key);
197
+ enum wsp_ggml_cgraph_eval_order order;
198
+ };
231
199
 
232
- // returns WSP_GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
233
- size_t wsp_ggml_hash_find (const struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor * key);
200
+ struct wsp_ggml_cgraph wsp_ggml_graph_view(struct wsp_ggml_cgraph * cgraph, int i0, int i1);
234
201
 
235
- // returns WSP_GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
236
- size_t wsp_ggml_hash_insert ( struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor * key);
202
+ // Memory allocation
237
203
 
238
- // return index, asserts if table is full
239
- size_t wsp_ggml_hash_find_or_insert( struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor * key);
204
+ void * wsp_ggml_aligned_malloc(size_t size);
205
+ void wsp_ggml_aligned_free(void * ptr, size_t size);
240
206
 
241
207
  #ifdef __cplusplus
242
208
  }
package/cpp/ggml-metal.h CHANGED
@@ -1,7 +1,9 @@
1
+ // Note: this description is outdated
2
+ //
1
3
  // An interface allowing to compute wsp_ggml_cgraph with Metal
2
4
  //
3
5
  // This is a fully functional interface that extends ggml with GPU support for Apple devices.
4
- // A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, OpenCL, etc.)
6
+ // A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.)
5
7
  //
6
8
  // How it works?
7
9
  //
@@ -25,10 +27,6 @@
25
27
  #include <stddef.h>
26
28
  #include <stdbool.h>
27
29
 
28
- // max memory buffers that can be mapped to the device
29
- #define WSP_GGML_METAL_MAX_BUFFERS 64
30
- #define WSP_GGML_METAL_MAX_COMMAND_BUFFERS 32
31
-
32
30
  struct wsp_ggml_tensor;
33
31
  struct wsp_ggml_cgraph;
34
32
 
@@ -36,59 +34,6 @@ struct wsp_ggml_cgraph;
36
34
  extern "C" {
37
35
  #endif
38
36
 
39
- //
40
- // internal API
41
- // temporary exposed to user-code
42
- //
43
-
44
- struct wsp_ggml_metal_context;
45
-
46
- void wsp_ggml_metal_log_set_callback(wsp_ggml_log_callback log_callback, void * user_data);
47
-
48
- // number of command buffers to use
49
- struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb);
50
- void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx);
51
-
52
- void * wsp_ggml_metal_host_malloc(size_t n);
53
- void wsp_ggml_metal_host_free (void * data);
54
-
55
- // set the number of command buffers to use
56
- void wsp_ggml_metal_set_n_cb(struct wsp_ggml_metal_context * ctx, int n_cb);
57
-
58
- // creates a mapping between a host memory buffer and a device memory buffer
59
- // - make sure to map all buffers used in the graph before calling wsp_ggml_metal_graph_compute
60
- // - the mapping is used during computation to determine the arguments of the compute kernels
61
- // - you don't need to keep the host memory buffer allocated as it is never accessed by Metal
62
- // - max_size specifies the maximum size of a tensor and is used to create shared views such
63
- // that it is guaranteed that the tensor will fit in at least one of the views
64
- //
65
- bool wsp_ggml_metal_add_buffer(
66
- struct wsp_ggml_metal_context * ctx,
67
- const char * name,
68
- void * data,
69
- size_t size,
70
- size_t max_size);
71
-
72
- // set data from host memory into the device
73
- void wsp_ggml_metal_set_tensor(struct wsp_ggml_metal_context * ctx, struct wsp_ggml_tensor * t);
74
-
75
- // get data from the device into host memory
76
- void wsp_ggml_metal_get_tensor(struct wsp_ggml_metal_context * ctx, struct wsp_ggml_tensor * t);
77
-
78
- // try to find operations that can be run concurrently in the graph
79
- // you should run it again if the topology of your graph changes
80
- void wsp_ggml_metal_graph_find_concurrency(struct wsp_ggml_metal_context * ctx, struct wsp_ggml_cgraph * gf, bool check_mem);
81
-
82
- // if the graph has been optimized for concurrently dispatch, return length of the concur_list if optimized
83
- int wsp_ggml_metal_if_optimized(struct wsp_ggml_metal_context * ctx);
84
-
85
- // output the concur_list for wsp_ggml_alloc
86
- int * wsp_ggml_metal_get_concur_list(struct wsp_ggml_metal_context * ctx);
87
-
88
- // same as wsp_ggml_graph_compute but uses Metal
89
- // creates gf->n_threads command buffers in parallel
90
- void wsp_ggml_metal_graph_compute(struct wsp_ggml_metal_context * ctx, struct wsp_ggml_cgraph * gf);
91
-
92
37
  //
93
38
  // backend API
94
39
  // user-code should use only these functions
@@ -98,7 +43,12 @@ WSP_GGML_API wsp_ggml_backend_t wsp_ggml_backend_metal_init(void);
98
43
 
99
44
  WSP_GGML_API bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend);
100
45
 
101
- WSP_GGML_API void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb);
46
+ WSP_GGML_DEPRECATED(
47
+ WSP_GGML_API wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
48
+ "obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713");
49
+
50
+ WSP_GGML_API void wsp_ggml_backend_metal_set_abort_callback(wsp_ggml_backend_t backend, wsp_ggml_abort_callback abort_callback, void * user_data);
51
+
102
52
  WSP_GGML_API wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void);
103
53
 
104
54
  // helper to check if the device supports a specific family
@@ -106,7 +56,11 @@ WSP_GGML_API wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(v
106
56
  // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
107
57
  WSP_GGML_API bool wsp_ggml_backend_metal_supports_family(wsp_ggml_backend_t backend, int family);
108
58
 
59
+ // capture all command buffers committed the next time `wsp_ggml_backend_graph_compute` is called
60
+ WSP_GGML_API void wsp_ggml_backend_metal_capture_next_compute(wsp_ggml_backend_t backend);
61
+
62
+ WSP_GGML_API wsp_ggml_backend_reg_t wsp_ggml_backend_metal_reg(void);
63
+
109
64
  #ifdef __cplusplus
110
65
  }
111
66
  #endif
112
-