whisper.rn 0.4.0-rc.3 → 0.4.0-rc.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +7 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +6 -1
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +53 -135
  7. package/android/src/main/jni-utils.h +76 -0
  8. package/android/src/main/jni.cpp +188 -109
  9. package/cpp/README.md +1 -1
  10. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  11. package/cpp/coreml/whisper-encoder.h +4 -0
  12. package/cpp/coreml/whisper-encoder.mm +4 -2
  13. package/cpp/ggml-alloc.c +451 -282
  14. package/cpp/ggml-alloc.h +74 -8
  15. package/cpp/ggml-backend-impl.h +112 -0
  16. package/cpp/ggml-backend.c +1357 -0
  17. package/cpp/ggml-backend.h +181 -0
  18. package/cpp/ggml-impl.h +243 -0
  19. package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal} +1556 -329
  20. package/cpp/ggml-metal.h +28 -1
  21. package/cpp/ggml-metal.m +1128 -308
  22. package/cpp/ggml-quants.c +7382 -0
  23. package/cpp/ggml-quants.h +224 -0
  24. package/cpp/ggml.c +3848 -5245
  25. package/cpp/ggml.h +353 -155
  26. package/cpp/rn-audioutils.cpp +68 -0
  27. package/cpp/rn-audioutils.h +14 -0
  28. package/cpp/rn-whisper-log.h +11 -0
  29. package/cpp/rn-whisper.cpp +141 -59
  30. package/cpp/rn-whisper.h +47 -15
  31. package/cpp/whisper.cpp +1750 -964
  32. package/cpp/whisper.h +97 -15
  33. package/ios/RNWhisper.mm +15 -9
  34. package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +4 -0
  35. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  36. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
  37. package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +19 -0
  38. package/ios/RNWhisperAudioUtils.h +0 -2
  39. package/ios/RNWhisperAudioUtils.m +0 -56
  40. package/ios/RNWhisperContext.h +8 -12
  41. package/ios/RNWhisperContext.mm +132 -138
  42. package/jest/mock.js +1 -1
  43. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  44. package/lib/commonjs/index.js +28 -9
  45. package/lib/commonjs/index.js.map +1 -1
  46. package/lib/commonjs/version.json +1 -1
  47. package/lib/module/NativeRNWhisper.js.map +1 -1
  48. package/lib/module/index.js +28 -9
  49. package/lib/module/index.js.map +1 -1
  50. package/lib/module/version.json +1 -1
  51. package/lib/typescript/NativeRNWhisper.d.ts +7 -1
  52. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  53. package/lib/typescript/index.d.ts +7 -2
  54. package/lib/typescript/index.d.ts.map +1 -1
  55. package/package.json +6 -5
  56. package/src/NativeRNWhisper.ts +8 -1
  57. package/src/index.ts +29 -17
  58. package/src/version.json +1 -1
  59. package/whisper-rn.podspec +1 -2
package/cpp/ggml-alloc.c CHANGED
@@ -1,69 +1,21 @@
  #include "ggml-alloc.h"
+ #include "ggml-backend-impl.h"
  #include "ggml.h"
+ #include "ggml-impl.h"
  #include <assert.h>
+ #include <limits.h>
  #include <stdarg.h>
  #include <stdio.h>
  #include <stdlib.h>
  #include <string.h>

- #ifdef __has_include
- #if __has_include(<unistd.h>)
- #include <unistd.h>
- #if defined(_POSIX_MAPPED_FILES)
- #include <sys/types.h>
- #include <sys/mman.h>
- #endif
- #endif
- #endif
-
- #if defined(_WIN32)
- #define WIN32_LEAN_AND_MEAN
- #ifndef NOMINMAX
- #define NOMINMAX
- #endif
- #include <windows.h>
- #include <memoryapi.h>
- #endif
-
-
- #define UNUSED(x) (void)(x)
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
- #define WSP_GGML_MAX_CONCUR (2*WSP_GGML_MAX_NODES)
+ #define MAX_FREE_BLOCKS 256

  //#define WSP_GGML_ALLOCATOR_DEBUG

- //#define AT_PRINTF printf
- #define AT_PRINTF(...) ((void)0)
-
- struct hash_node {
- struct wsp_ggml_tensor * t;
- int n_children;
- int n_views;
- };
-
- static size_t hash(void * p) {
- return (size_t)p % WSP_GGML_GRAPH_HASHTABLE_SIZE;
- }
-
- static struct hash_node * hash_get(struct hash_node hash_table[], struct wsp_ggml_tensor * t) {
- size_t h = hash(t);
-
- // linear probing
- size_t i = h;
- while (hash_table[i].t != NULL) {
- if (hash_table[i].t == t) {
- return &hash_table[i];
- }
- i = (i + 1) % WSP_GGML_GRAPH_HASHTABLE_SIZE;
- if (i == h) {
- // hash table is full
- WSP_GGML_ASSERT(false);
- }
- }
-
- hash_table[i].t = t;
- return &hash_table[i];
- }
+ //#define AT_PRINTF(...) fprintf(stderr, __VA_ARGS__)
+ #define AT_PRINTF(...)

  // TODO: WSP_GGML_PAD ?
  static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
@@ -77,19 +29,18 @@ struct free_block {
  size_t size;
  };

- #define MAX_FREE_BLOCKS 128
-
- struct wsp_ggml_allocr {
- void * data;
- size_t size;
+ struct wsp_ggml_tallocr {
+ struct wsp_ggml_backend_buffer * buffer;
+ bool buffer_owned;
+ void * base;
  size_t alignment;
+
  int n_free_blocks;
  struct free_block free_blocks[MAX_FREE_BLOCKS];
- struct hash_node hash_table[WSP_GGML_GRAPH_HASHTABLE_SIZE];
+
  size_t max_size;
+
  bool measure;
- int parse_seq[WSP_GGML_MAX_CONCUR];
- int parse_seq_len;

  #ifdef WSP_GGML_ALLOCATOR_DEBUG
  struct wsp_ggml_tensor * allocated_tensors[1024];
@@ -97,7 +48,7 @@ struct wsp_ggml_allocr {
  };

  #ifdef WSP_GGML_ALLOCATOR_DEBUG
- static void add_allocated_tensor(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tensor * tensor) {
+ static void add_allocated_tensor(wsp_ggml_tallocr_t alloc, struct wsp_ggml_tensor * tensor) {
  for (int i = 0; i < 1024; i++) {
  if (alloc->allocated_tensors[i] == NULL) {
  alloc->allocated_tensors[i] = tensor;
@@ -106,7 +57,7 @@ static void add_allocated_tensor(struct wsp_ggml_allocr * alloc, struct wsp_ggml
  }
  WSP_GGML_ASSERT(!"out of allocated_tensors");
  }
- static void remove_allocated_tensor(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tensor * tensor) {
+ static void remove_allocated_tensor(wsp_ggml_tallocr_t alloc, struct wsp_ggml_tensor * tensor) {
  for (int i = 0; i < 1024; i++) {
  if (alloc->allocated_tensors[i] == tensor ||
  (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {
@@ -119,28 +70,20 @@ static void remove_allocated_tensor(struct wsp_ggml_allocr * alloc, struct wsp_g
  }
  #endif

- static size_t wsp_ggml_allocr_get_alloc_size(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tensor * tensor) {
- return wsp_ggml_nbytes(tensor);
-
- UNUSED(alloc);
- }
-
  // check if a tensor is allocated by this buffer
- static bool wsp_ggml_allocr_is_own(struct wsp_ggml_allocr * alloc, const struct wsp_ggml_tensor * tensor) {
- void * ptr = tensor->data;
- return ptr >= alloc->data && (char *)ptr < (char *)alloc->data + alloc->max_size;
+ static bool wsp_ggml_tallocr_is_own(wsp_ggml_tallocr_t alloc, const struct wsp_ggml_tensor * tensor) {
+ return tensor->buffer == alloc->buffer;
  }

  static bool wsp_ggml_is_view(struct wsp_ggml_tensor * t) {
  return t->view_src != NULL;
  }

- void wsp_ggml_allocr_alloc(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tensor * tensor) {
- #ifdef WSP_GGML_ALLOCATOR_DEBUG
+ void wsp_ggml_tallocr_alloc(wsp_ggml_tallocr_t alloc, struct wsp_ggml_tensor * tensor) {
  WSP_GGML_ASSERT(!wsp_ggml_is_view(tensor)); // views generally get data pointer from one of their sources
  WSP_GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated
- #endif
- size_t size = wsp_ggml_allocr_get_alloc_size(alloc, tensor);
+
+ size_t size = wsp_ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor);
  size = aligned_offset(NULL, size, alloc->alignment);

  AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
@@ -187,10 +130,14 @@ void wsp_ggml_allocr_alloc(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tenso
  }

  tensor->data = addr;
+ tensor->buffer = alloc->buffer;
+ if (!alloc->measure) {
+ wsp_ggml_backend_buffer_init_tensor(alloc->buffer, tensor);
+ }

  #ifdef WSP_GGML_ALLOCATOR_DEBUG
  add_allocated_tensor(alloc, tensor);
- size_t cur_max = (char*)addr - (char*)alloc->data + size;
+ size_t cur_max = (char*)addr - (char*)alloc->base + size;
  if (cur_max > alloc->max_size) {
  printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
  for (int i = 0; i < 1024; i++) {
@@ -202,23 +149,24 @@ void wsp_ggml_allocr_alloc(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tenso
  }
  #endif

- alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->data + size);
+ alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->base + size);
  }

  // this is a very naive implementation, but for our case the number of free blocks should be very small
- static void wsp_ggml_allocr_free_tensor(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tensor * tensor) {
- void * ptr = tensor->data;
-
- if (wsp_ggml_allocr_is_own(alloc, tensor) == false) {
+ static void wsp_ggml_tallocr_free_tensor(wsp_ggml_tallocr_t alloc, struct wsp_ggml_tensor * tensor) {
+ if (wsp_ggml_tallocr_is_own(alloc, tensor) == false) {
  // the tensor was not allocated in this buffer
  // this can happen because the graph allocator will try to free weights and other tensors from different buffers
  // the easiest way to deal with this is just to ignore it
+ // AT_PRINTF("ignoring %s (their buffer: %p, our buffer: %p)\n", tensor->name, (void *)tensor->buffer, (void *)alloc->buffer);
  return;
  }

- size_t size = wsp_ggml_allocr_get_alloc_size(alloc, tensor);
+ void * ptr = tensor->data;
+
+ size_t size = wsp_ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor);
  size = aligned_offset(NULL, size, alloc->alignment);
- AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
+ AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);

  #ifdef WSP_GGML_ALLOCATOR_DEBUG
  remove_allocated_tensor(alloc, tensor);
@@ -272,136 +220,180 @@ static void wsp_ggml_allocr_free_tensor(struct wsp_ggml_allocr * alloc, struct w
  alloc->n_free_blocks++;
  }

- void wsp_ggml_allocr_set_parse_seq(struct wsp_ggml_allocr * alloc, const int * list, int n) {
- for (int i = 0; i < n; i++) {
- alloc->parse_seq[i] = list[i];
+ void wsp_ggml_tallocr_reset(wsp_ggml_tallocr_t alloc) {
+ alloc->n_free_blocks = 1;
+ size_t align_offset = aligned_offset(alloc->base, 0, alloc->alignment);
+ alloc->free_blocks[0].addr = (char *)alloc->base + align_offset;
+
+ if (alloc->measure) {
+ alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
+ } else {
+ alloc->free_blocks[0].size = wsp_ggml_backend_buffer_get_size(alloc->buffer) - align_offset;
  }
- alloc->parse_seq_len = n;
  }

- void wsp_ggml_allocr_reset(struct wsp_ggml_allocr * alloc) {
- alloc->n_free_blocks = 1;
- size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment);
- alloc->free_blocks[0].addr = (char *)alloc->data + align_offset;
- alloc->free_blocks[0].size = alloc->size - align_offset;
- }
+ wsp_ggml_tallocr_t wsp_ggml_tallocr_new(void * data, size_t size, size_t alignment) {
+ struct wsp_ggml_backend_buffer * buffer = wsp_ggml_backend_cpu_buffer_from_ptr(data, size);

- struct wsp_ggml_allocr * wsp_ggml_allocr_new(void * data, size_t size, size_t alignment) {
- struct wsp_ggml_allocr * alloc = (struct wsp_ggml_allocr *)malloc(sizeof(struct wsp_ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
+ wsp_ggml_tallocr_t alloc = (wsp_ggml_tallocr_t)malloc(sizeof(struct wsp_ggml_tallocr));

- *alloc = (struct wsp_ggml_allocr){
- /*.data = */ data,
- /*.size = */ size,
+ *alloc = (struct wsp_ggml_tallocr) {
+ /*.buffer = */ buffer,
+ /*.buffer_owned = */ true,
+ /*.base = */ wsp_ggml_backend_buffer_get_base(buffer),
  /*.alignment = */ alignment,
  /*.n_free_blocks = */ 0,
  /*.free_blocks = */ {{0}},
- /*.hash_table = */ {{0}},
  /*.max_size = */ 0,
  /*.measure = */ false,
- /*.parse_seq = */ {0},
- /*.parse_seq_len = */ 0,
  #ifdef WSP_GGML_ALLOCATOR_DEBUG
  /*.allocated_tensors = */ {0},
  #endif
  };

- wsp_ggml_allocr_reset(alloc);
+ wsp_ggml_tallocr_reset(alloc);

  return alloc;
  }

- // OS specific functions to allocate and free uncommitted virtual memory
- static void * alloc_vmem(size_t size) {
- #if defined(_WIN32)
- return VirtualAlloc(NULL, size, MEM_RESERVE, PAGE_NOACCESS);
- #elif defined(_POSIX_MAPPED_FILES)
- void * ptr = mmap(NULL, size, PROT_NONE, MAP_PRIVATE | MAP_ANON, -1, 0);
- if (ptr == MAP_FAILED) {
- return NULL;
- }
- return ptr;
- #else
- // use a fixed address for other platforms
- uintptr_t base_addr = (uintptr_t)-size - 0x100;
- return (void *)base_addr;
- #endif
- }
+ wsp_ggml_tallocr_t wsp_ggml_tallocr_new_measure(size_t alignment) {
+ wsp_ggml_tallocr_t alloc = wsp_ggml_tallocr_new((void *)0x1000, SIZE_MAX/2, alignment);
+ alloc->measure = true;

- static void free_vmem(void * base_addr, size_t size) {
- #if defined(_WIN32)
- VirtualFree(base_addr, 0, MEM_RELEASE);
- UNUSED(size);
- #elif defined(_POSIX_MAPPED_FILES)
- munmap(base_addr, size);
- #else
- // nothing to do
- UNUSED(base_addr);
- UNUSED(size);
- #endif
+ return alloc;
  }

- // allocate uncommitted virtual memory to measure the size of the graph
- static void alloc_measure_vmem(void ** base_addr, size_t * size) {
- // 128GB for 64-bit, 1GB for 32-bit
- *size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<37;
- do {
- *base_addr = alloc_vmem(*size);
- if (*base_addr != NULL) {
- AT_PRINTF("allocated %.2f GB of virtual memory for measure buffer at %p\n", *size / 1024.0 / 1024.0 / 1024.0, *base_addr);
- return;
- }
- // try again with half the size
- *size /= 2;
- } while (*size > 0);
+ wsp_ggml_tallocr_t wsp_ggml_tallocr_new_measure_from_backend(struct wsp_ggml_backend * backend) {
+ // create a backend buffer to get the correct tensor allocation sizes
+ wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_alloc_buffer(backend, 1);

- WSP_GGML_ASSERT(!"failed to allocate virtual memory for measure buffer");
+ // TODO: move alloc initialization to a common wsp_ggml_tallocr_new_impl function
+ wsp_ggml_tallocr_t alloc = wsp_ggml_tallocr_new_from_buffer(buffer);
+ alloc->buffer_owned = true;
+ alloc->measure = true;
+ wsp_ggml_tallocr_reset(alloc);
+ return alloc;
  }

- static void free_measure_vmem(void * base_addr, size_t size) {
- free_vmem(base_addr, size);
+ wsp_ggml_tallocr_t wsp_ggml_tallocr_new_from_backend(struct wsp_ggml_backend * backend, size_t size) {
+ wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_alloc_buffer(backend, size);
+ wsp_ggml_tallocr_t alloc = wsp_ggml_tallocr_new_from_buffer(buffer);
+ alloc->buffer_owned = true;
+ return alloc;
  }

- struct wsp_ggml_allocr * wsp_ggml_allocr_new_measure(size_t alignment) {
- struct wsp_ggml_allocr * alloc = (struct wsp_ggml_allocr *)malloc(sizeof(struct wsp_ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
-
- void * base_addr;
- size_t size;
-
- alloc_measure_vmem(&base_addr, &size);
+ wsp_ggml_tallocr_t wsp_ggml_tallocr_new_from_buffer(struct wsp_ggml_backend_buffer * buffer) {
+ wsp_ggml_tallocr_t alloc = (wsp_ggml_tallocr_t)malloc(sizeof(struct wsp_ggml_tallocr));

- *alloc = (struct wsp_ggml_allocr){
- /*.data = */ base_addr,
- /*.size = */ size,
- /*.alignment = */ alignment,
+ *alloc = (struct wsp_ggml_tallocr) {
+ /*.buffer = */ buffer,
+ /*.buffer_owned = */ false,
+ /*.base = */ wsp_ggml_backend_buffer_get_base(buffer),
+ /*.alignment = */ wsp_ggml_backend_buffer_get_alignment(buffer),
  /*.n_free_blocks = */ 0,
  /*.free_blocks = */ {{0}},
- /*.hash_table = */ {{0}},
  /*.max_size = */ 0,
- /*.measure = */ true,
- /*.parse_seq = */ {0},
- /*.parse_seq_len = */ 0,
+ /*.measure = */ false,
  #ifdef WSP_GGML_ALLOCATOR_DEBUG
  /*.allocated_tensors = */ {0},
  #endif
  };

- wsp_ggml_allocr_reset(alloc);
+ wsp_ggml_tallocr_reset(alloc);

  return alloc;
  }

- void wsp_ggml_allocr_free(struct wsp_ggml_allocr * alloc) {
- if (alloc->measure) {
- free_measure_vmem(alloc->data, alloc->size);
+ struct wsp_ggml_backend_buffer * wsp_ggml_tallocr_get_buffer(wsp_ggml_tallocr_t alloc) {
+ return alloc->buffer;
+ }
+
+ void wsp_ggml_tallocr_free(wsp_ggml_tallocr_t alloc) {
+ if (alloc == NULL) {
+ return;
+ }
+
+ if (alloc->buffer_owned) {
+ wsp_ggml_backend_buffer_free(alloc->buffer);
  }
  free(alloc);
  }

- bool wsp_ggml_allocr_is_measure(struct wsp_ggml_allocr * alloc) {
+ bool wsp_ggml_tallocr_is_measure(wsp_ggml_tallocr_t alloc) {
  return alloc->measure;
  }

- //////////// compute graph allocator
+ size_t wsp_ggml_tallocr_max_size(wsp_ggml_tallocr_t alloc) {
+ return alloc->max_size;
+ }
+
+ // graph allocator
+
+ struct hash_node {
+ int n_children;
+ int n_views;
+ };
+
+ struct wsp_ggml_gallocr {
+ wsp_ggml_tallocr_t talloc;
+ struct wsp_ggml_hash_set hash_set;
+ struct hash_node * hash_values;
+ size_t hash_values_size;
+ wsp_ggml_tallocr_t * hash_allocs;
+ int * parse_seq;
+ int parse_seq_len;
+ };
+
+ wsp_ggml_gallocr_t wsp_ggml_gallocr_new(void) {
+ wsp_ggml_gallocr_t galloc = (wsp_ggml_gallocr_t)malloc(sizeof(struct wsp_ggml_gallocr));
+
+ *galloc = (struct wsp_ggml_gallocr) {
+ /*.talloc = */ NULL,
+ /*.hash_set = */ {0},
+ /*.hash_values = */ NULL,
+ /*.hash_values_size = */ 0,
+ /*.hash_allocs = */ NULL,
+ /*.parse_seq = */ NULL,
+ /*.parse_seq_len = */ 0,
+ };
+
+ return galloc;
+ }
+
+ void wsp_ggml_gallocr_free(wsp_ggml_gallocr_t galloc) {
+ if (galloc == NULL) {
+ return;
+ }
+
+ if (galloc->hash_set.keys != NULL) {
+ free(galloc->hash_set.keys);
+ }
+ if (galloc->hash_values != NULL) {
+ free(galloc->hash_values);
+ }
+ if (galloc->hash_allocs != NULL) {
+ free(galloc->hash_allocs);
+ }
+ if (galloc->parse_seq != NULL) {
+ free(galloc->parse_seq);
+ }
+ free(galloc);
+ }
+
+ void wsp_ggml_gallocr_set_parse_seq(wsp_ggml_gallocr_t galloc, const int * list, int n) {
+ free(galloc->parse_seq);
+ galloc->parse_seq = malloc(sizeof(int) * n);
+
+ for (int i = 0; i < n; i++) {
+ galloc->parse_seq[i] = list[i];
+ }
+ galloc->parse_seq_len = n;
+ }
+
+ static struct hash_node * hash_get(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * t) {
+ size_t i = wsp_ggml_hash_find_or_insert(galloc->hash_set, t);
+ return &galloc->hash_values[i];
+ }

  static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b) {
  if (a->type != b->type) {
@@ -435,7 +427,6 @@ static bool wsp_ggml_op_can_inplace(enum wsp_ggml_op op) {
  case WSP_GGML_OP_ROPE:
  case WSP_GGML_OP_RMS_NORM:
  case WSP_GGML_OP_SOFT_MAX:
- case WSP_GGML_OP_CONT:
  return true;

  default:
@@ -443,12 +434,39 @@ static bool wsp_ggml_op_can_inplace(enum wsp_ggml_op op) {
  }
  }

- static void allocate_node(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tensor * node) {
- struct hash_node * ht = alloc->hash_table;
+ static wsp_ggml_tallocr_t node_tallocr(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node) {
+ if (galloc->talloc != NULL) {
+ return galloc->talloc;
+ }
+
+ return galloc->hash_allocs[wsp_ggml_hash_find_or_insert(galloc->hash_set, node)];
+ }
+
+ static void init_view(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * view, bool update_backend) {
+ wsp_ggml_tallocr_t alloc = node_tallocr(galloc, view);
+
+ WSP_GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
+ if (update_backend) {
+ view->backend = view->view_src->backend;
+ }
+ view->buffer = view->view_src->buffer;
+ view->data = (char *)view->view_src->data + view->view_offs;
+
+ // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
+ // due to the wsp_ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
+ assert(wsp_ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft);
+
+ if (!alloc->measure) {
+ wsp_ggml_backend_buffer_init_tensor(alloc->buffer, view);
+ }
+ }
+
+ static void allocate_node(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node) {
+ wsp_ggml_tallocr_t alloc = node_tallocr(galloc, node);
+
  if (node->data == NULL) {
  if (wsp_ggml_is_view(node)) {
- assert(node->view_src->data != NULL);
- node->data = (char *)node->view_src->data + node->view_offs;
+ init_view(galloc, node, true);
  } else {
  // see if we can reuse a parent's buffer (inplace)
  if (wsp_ggml_op_can_inplace(node->op)) {
@@ -459,16 +477,16 @@ static void allocate_node(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tensor
  }

  // if the node's data is external, then we cannot re-use it
- if (wsp_ggml_allocr_is_own(alloc, parent) == false) {
+ if (wsp_ggml_tallocr_is_own(alloc, parent) == false) {
  AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
  continue;
  }

- struct hash_node * p_hn = hash_get(ht, parent);
+ struct hash_node * p_hn = hash_get(galloc, parent);
  if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && wsp_ggml_are_same_layout(node, parent)) {
  if (wsp_ggml_is_view(parent)) {
  struct wsp_ggml_tensor * view_src = parent->view_src;
- struct hash_node * view_src_hn = hash_get(ht, view_src);
+ struct hash_node * view_src_hn = hash_get(galloc, view_src);
  if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
  // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
  // the parent's data that it will need later (same layout requirement). the problem is that then
@@ -476,158 +494,309 @@ static void allocate_node(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tensor
  // adding a view_src pointer to the tensor would solve this and simplify the code dealing with views
  // for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data)
  AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
- node->data = parent->data;
+ node->view_src = view_src;
+ view_src_hn->n_views += 1;
+ init_view(galloc, node, false);
  return;
  }
- }
- else {
+ } else {
  AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
- node->data = parent->data;
+ node->view_src = parent;
+ p_hn->n_views += 1;
+ init_view(galloc, node, false);
  return;
  }
  }
  }
  }
- wsp_ggml_allocr_alloc(alloc, node);
+ wsp_ggml_tallocr_alloc(alloc, node);
  }
  }
  }

- static size_t wsp_ggml_allocr_alloc_graph_tensors_n(
- struct wsp_ggml_allocr * alloc,
- struct wsp_ggml_cgraph ** graphs, int n_graphs,
- struct wsp_ggml_tensor *** inputs, struct wsp_ggml_tensor *** outputs) {
+ static void free_node(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node) {
+ wsp_ggml_tallocr_t alloc = node_tallocr(galloc, node);

- // reset hash table
- struct hash_node * ht = alloc->hash_table;
- memset(ht, 0, sizeof(struct hash_node) * WSP_GGML_GRAPH_HASHTABLE_SIZE);
+ wsp_ggml_tallocr_free_tensor(alloc, node);
+ }
+
+ static void wsp_ggml_tallocr_alloc_graph_impl(wsp_ggml_gallocr_t galloc, struct wsp_ggml_cgraph * gf) {
+ const int * parse_seq = galloc->parse_seq;
+ int parse_seq_len = galloc->parse_seq_len;

  // count number of children and views
- for (int g = 0; g < n_graphs; g++) {
- struct wsp_ggml_cgraph * gf = graphs[g];
- for (int i = 0; i < gf->n_nodes; i++) {
- struct wsp_ggml_tensor * node = gf->nodes[i];
+ for (int i = 0; i < gf->n_nodes; i++) {
+ struct wsp_ggml_tensor * node = gf->nodes[i];

- if (wsp_ggml_is_view(node)) {
- struct wsp_ggml_tensor * view_src = node->view_src;
- hash_get(ht, view_src)->n_views += 1;
+ if (wsp_ggml_is_view(node)) {
+ struct wsp_ggml_tensor * view_src = node->view_src;
+ hash_get(galloc, view_src)->n_views += 1;
+ if (node->buffer == NULL && node->data != NULL) {
+ // view of a pre-allocated tensor, didn't call init_view() yet
+ init_view(galloc, node, true);
+ }
+ }
+
+ for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
+ struct wsp_ggml_tensor * parent = node->src[j];
+ if (parent == NULL) {
+ break;
+ }
+ hash_get(galloc, parent)->n_children += 1;
+ if (wsp_ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) {
+ init_view(galloc, parent, true);
  }
+ }
+ }
+
+ // allocate tensors
+ // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
+ int last_barrier_pos = 0;
+ int n_nodes = parse_seq_len ? parse_seq_len : gf->n_nodes;
+
+ for (int ind = 0; ind < n_nodes; ind++) {
+ // allocate a node if there is no parse_seq or this is not a barrier
+ if (parse_seq_len == 0 || parse_seq[ind] != -1) {
+ int i = parse_seq_len ? parse_seq[ind] : ind;
+ struct wsp_ggml_tensor * node = gf->nodes[i];

+ // allocate parents (leafs)
  for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
  struct wsp_ggml_tensor * parent = node->src[j];
  if (parent == NULL) {
  break;
  }
- hash_get(ht, parent)->n_children += 1;
+ allocate_node(galloc, parent);
  }
- }
- }

- // allocate tensors
- for (int g = 0; g < n_graphs; g++) {
- struct wsp_ggml_cgraph * gf = graphs[g];
- AT_PRINTF("####### graph %d/%d\n", g, n_graphs);
- // graph inputs are allocated first to ensure that they are not overwritten by each other
- if (inputs != NULL && inputs[g] != NULL) {
- for (int i = 0; inputs[g][i] != NULL; i++) {
- struct wsp_ggml_tensor * input = inputs[g][i];
- AT_PRINTF("input: %s\n", input->name);
- allocate_node(alloc, input);
+ // allocate node
+ allocate_node(galloc, node);
+
+ AT_PRINTF("exec: %s (%s) <= ", wsp_ggml_op_name(node->op), node->name);
+ for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
+ struct wsp_ggml_tensor * parent = node->src[j];
+ if (parent == NULL) {
+ break;
+ }
+ AT_PRINTF("%s", parent->name);
+ if (j < WSP_GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
+ AT_PRINTF(", ");
+ }
  }
+ AT_PRINTF("\n");
  }
- // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
- int last_barrier_pos = 0;
- int n_nodes = alloc->parse_seq_len ? alloc->parse_seq_len : gf->n_nodes;

- for (int ind = 0; ind < n_nodes; ind++) {
- // allocate a node if there is no parse_seq or this is not a barrier
- if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] != -1) {
- int i = alloc->parse_seq_len ? alloc->parse_seq[ind] : ind;
- struct wsp_ggml_tensor * node = gf->nodes[i];
+ // update parents
+ // update immediately if there is no parse_seq
+ // update only at barriers if there is parse_seq
+ if ((parse_seq_len == 0) || parse_seq[ind] == -1) {
+ int update_start = parse_seq_len ? last_barrier_pos : ind;
+ int update_end = parse_seq_len ? ind : ind + 1;
+ for (int i = update_start; i < update_end; i++) {
+ int node_i = parse_seq_len ? parse_seq[i] : i;
+ struct wsp_ggml_tensor * node = gf->nodes[node_i];

- // allocate parents (leafs)
  for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
  struct wsp_ggml_tensor * parent = node->src[j];
  if (parent == NULL) {
  break;
  }
- allocate_node(alloc, parent);
- }
-
- // allocate node
- allocate_node(alloc, node);
+ struct hash_node * p_hn = hash_get(galloc, parent);
+ p_hn->n_children -= 1;

- AT_PRINTF("exec: %s (%s) <= ", wsp_ggml_op_name(node->op), node->name);
- for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
- struct wsp_ggml_tensor * parent = node->src[j];
- if (parent == NULL) {
- break;
- }
- AT_PRINTF("%s", parent->name);
- if (j < WSP_GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
- AT_PRINTF(", ");
- }
- }
- AT_PRINTF("\n");
- }
+ //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);

- // update parents
- // update immediately if there is no parse_seq
- // update only at barriers if there is parse_seq
- if ((alloc->parse_seq_len == 0) || alloc->parse_seq[ind] == -1) {
- int update_start = alloc->parse_seq_len ? last_barrier_pos : ind;
- int update_end = alloc->parse_seq_len ? ind : ind + 1;
- for (int i = update_start; i < update_end; i++) {
- int node_i = alloc->parse_seq_len ? alloc->parse_seq[i] : i;
- struct wsp_ggml_tensor * node = gf->nodes[node_i];
-
- for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
- struct wsp_ggml_tensor * parent = node->src[j];
- if (parent == NULL) {
- break;
- }
- struct hash_node * p_hn = hash_get(ht, parent);
- p_hn->n_children -= 1;
-
- //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
-
- if (p_hn->n_children == 0 && p_hn->n_views == 0) {
- if (wsp_ggml_is_view(parent)) {
- struct wsp_ggml_tensor * view_src = parent->view_src;
- struct hash_node * view_src_hn = hash_get(ht, view_src);
- view_src_hn->n_views -= 1;
- AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
- if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
- wsp_ggml_allocr_free_tensor(alloc, view_src);
- }
- }
- else {
- if (parent->data != node->data) {
- wsp_ggml_allocr_free_tensor(alloc, parent);
- }
+ if (p_hn->n_children == 0 && p_hn->n_views == 0) {
+ if (wsp_ggml_is_view(parent)) {
+ struct wsp_ggml_tensor * view_src = parent->view_src;
+ struct hash_node * view_src_hn = hash_get(galloc, view_src);
+ view_src_hn->n_views -= 1;
+ AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
+ if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0) {
+ free_node(galloc, view_src);
  }
  }
+ else {
+ free_node(galloc, parent);
+ }
  }
  }
- AT_PRINTF("\n");
- if (alloc->parse_seq_len) {
- last_barrier_pos = ind + 1;
- }
  }
+ AT_PRINTF("\n");
+ if (parse_seq_len) {
+ last_barrier_pos = ind + 1;
+ }
+ }
+ }
+ }
+
+ size_t wsp_ggml_gallocr_alloc_graph(wsp_ggml_gallocr_t galloc, wsp_ggml_tallocr_t talloc, struct wsp_ggml_cgraph * graph) {
+ size_t hash_size = graph->visited_hash_table.size;
+
+ // check if the hash table is initialized and large enough
+ if (galloc->hash_set.size < hash_size) {
+ if (galloc->hash_set.keys != NULL) {
+ free(galloc->hash_set.keys);
  }
- // free graph outputs here that wouldn't be freed otherwise because they have no children
- if (outputs != NULL && outputs[g] != NULL) {
- for (int i = 0; outputs[g][i] != NULL; i++) {
- struct wsp_ggml_tensor * output = outputs[g][i];
- AT_PRINTF("output: %s\n", output->name);
- wsp_ggml_allocr_free_tensor(alloc, output);
+ if (galloc->hash_values != NULL) {
+ free(galloc->hash_values);
+ }
+ galloc->hash_set.keys = malloc(sizeof(struct wsp_ggml_tensor *) * hash_size);
+ galloc->hash_set.size = hash_size;
+ galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size);
+ }
+
+ // reset hash table
+ memset(galloc->hash_set.keys, 0, sizeof(struct wsp_ggml_tensor *) * hash_size);
+ memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size);
+
+ galloc->talloc = talloc;
+ wsp_ggml_tallocr_alloc_graph_impl(galloc, graph);
+ galloc->talloc = NULL;
+
+ size_t max_size = wsp_ggml_tallocr_max_size(talloc);
+
+ return max_size;
+ }
+
+ void wsp_ggml_gallocr_alloc_graph_n(wsp_ggml_gallocr_t galloc, struct wsp_ggml_cgraph * graph, struct wsp_ggml_hash_set hash_set, wsp_ggml_tallocr_t * hash_node_talloc) {
+ const size_t hash_size = hash_set.size;
+
+ WSP_GGML_ASSERT(hash_size >= (size_t)(graph->n_nodes + graph->n_leafs));
+
+ galloc->talloc = NULL;
+
+ // alloc hash_values if needed
+ if (galloc->hash_values == NULL || galloc->hash_values_size < hash_size) {
+ free(galloc->hash_values);
+ galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size);
+ galloc->hash_values_size = hash_size;
+ }
+
+ // free hash_set.keys if needed
+ if (galloc->hash_set.keys != NULL) {
+ free(galloc->hash_set.keys);
+ }
+ galloc->hash_set = hash_set;
+
+ // reset hash values
+ memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size);
+
+ galloc->hash_allocs = hash_node_talloc;
+
+ wsp_ggml_tallocr_alloc_graph_impl(galloc, graph);
+
+ // remove unowned resources
+ galloc->hash_set.keys = NULL;
+ galloc->hash_allocs = NULL;
+ }
+
+ // legacy API wrapper
+
+ struct wsp_ggml_allocr {
+ wsp_ggml_tallocr_t talloc;
+ wsp_ggml_gallocr_t galloc;
+ };
+
+ static wsp_ggml_allocr_t wsp_ggml_allocr_new_impl(wsp_ggml_tallocr_t talloc) {
+ wsp_ggml_allocr_t alloc = (wsp_ggml_allocr_t)malloc(sizeof(struct wsp_ggml_allocr));
+ *alloc = (struct wsp_ggml_allocr) {
+ /*.talloc = */ talloc,
+ /*.galloc = */ wsp_ggml_gallocr_new(),
+ };
+ return alloc;
+ }
+
+ wsp_ggml_allocr_t wsp_ggml_allocr_new(void * data, size_t size, size_t alignment) {
+ return wsp_ggml_allocr_new_impl(wsp_ggml_tallocr_new(data, size, alignment));
+ }
+
+ wsp_ggml_allocr_t wsp_ggml_allocr_new_measure(size_t alignment) {
+ return wsp_ggml_allocr_new_impl(wsp_ggml_tallocr_new_measure(alignment));
+ }
+
+ wsp_ggml_allocr_t wsp_ggml_allocr_new_from_buffer(struct wsp_ggml_backend_buffer * buffer) {
+ return wsp_ggml_allocr_new_impl(wsp_ggml_tallocr_new_from_buffer(buffer));
+ }
+
+ wsp_ggml_allocr_t wsp_ggml_allocr_new_from_backend(struct wsp_ggml_backend * backend, size_t size) {
+ return wsp_ggml_allocr_new_impl(wsp_ggml_tallocr_new_from_backend(backend, size));
+ }
+
+ wsp_ggml_allocr_t wsp_ggml_allocr_new_measure_from_backend(struct wsp_ggml_backend * backend) {
+ return wsp_ggml_allocr_new_impl(wsp_ggml_tallocr_new_measure_from_backend(backend));
+ }
+
+ struct wsp_ggml_backend_buffer * wsp_ggml_allocr_get_buffer(wsp_ggml_allocr_t alloc) {
+ return wsp_ggml_tallocr_get_buffer(alloc->talloc);
+ }
+
+ void wsp_ggml_allocr_set_parse_seq(wsp_ggml_allocr_t alloc, const int * list, int n) {
+ wsp_ggml_gallocr_set_parse_seq(alloc->galloc, list, n);
+ }
+
+ void wsp_ggml_allocr_free(wsp_ggml_allocr_t alloc) {
+ wsp_ggml_gallocr_free(alloc->galloc);
+ wsp_ggml_tallocr_free(alloc->talloc);
+ free(alloc);
+ }
+
+ bool wsp_ggml_allocr_is_measure(wsp_ggml_allocr_t alloc) {
+ return wsp_ggml_tallocr_is_measure(alloc->talloc);
+ }
+
+ void wsp_ggml_allocr_reset(wsp_ggml_allocr_t alloc) {
+ wsp_ggml_tallocr_reset(alloc->talloc);
+ }
+
+ void wsp_ggml_allocr_alloc(wsp_ggml_allocr_t alloc, struct wsp_ggml_tensor * tensor) {
+ wsp_ggml_tallocr_alloc(alloc->talloc, tensor);
+ }
+
+ size_t wsp_ggml_allocr_max_size(wsp_ggml_allocr_t alloc) {
+ return wsp_ggml_tallocr_max_size(alloc->talloc);
+ }
+
+ size_t wsp_ggml_allocr_alloc_graph(wsp_ggml_allocr_t alloc, struct wsp_ggml_cgraph * graph) {
+ return wsp_ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
+ }
+
+ // utils
+ wsp_ggml_backend_buffer_t wsp_ggml_backend_alloc_ctx_tensors_from_buft(struct wsp_ggml_context * ctx, wsp_ggml_backend_buffer_type_t buft) {
+ WSP_GGML_ASSERT(wsp_ggml_get_no_alloc(ctx) == true);
+
+ size_t alignment = wsp_ggml_backend_buft_get_alignment(buft);
+
+ size_t nbytes = 0;
+ for (struct wsp_ggml_tensor * t = wsp_ggml_get_first_tensor(ctx); t != NULL; t = wsp_ggml_get_next_tensor(ctx, t)) {
+ if (t->data == NULL && t->view_src == NULL) {
+ nbytes += WSP_GGML_PAD(wsp_ggml_backend_buft_get_alloc_size(buft, t), alignment);
+ }
+ }
+
+ if (nbytes == 0) {
+ fprintf(stderr, "%s: no tensors to allocate\n", __func__);
+ return NULL;
+ }
+
+ wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_buft_alloc_buffer(buft, nbytes);
+ wsp_ggml_tallocr_t tallocr = wsp_ggml_tallocr_new_from_buffer(buffer);
+
+ for (struct wsp_ggml_tensor * t = wsp_ggml_get_first_tensor(ctx); t != NULL; t = wsp_ggml_get_next_tensor(ctx, t)) {
+ if (t->data == NULL) {
+ if (t->view_src == NULL) {
+ wsp_ggml_tallocr_alloc(tallocr, t);
+ } else {
+ wsp_ggml_backend_view_init(buffer, t);
  }
  }
  }

- return alloc->max_size;
+ wsp_ggml_tallocr_free(tallocr);
+
+ return buffer;
  }

- size_t wsp_ggml_allocr_alloc_graph(struct wsp_ggml_allocr * alloc, struct wsp_ggml_cgraph * graph) {
- return wsp_ggml_allocr_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
+ wsp_ggml_backend_buffer_t wsp_ggml_backend_alloc_ctx_tensors(struct wsp_ggml_context * ctx, wsp_ggml_backend_t backend) {
+ return wsp_ggml_backend_alloc_ctx_tensors_from_buft(ctx, wsp_ggml_backend_get_default_buffer_type(backend));
  }
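
Note (not part of the published package): the diff above replaces the old single-buffer wsp_ggml_allocr with a tensor allocator (wsp_ggml_tallocr), a graph allocator (wsp_ggml_gallocr), and a legacy wrapper that keeps the old entry points. As a rough sketch of how that legacy wrapper surfaced in this file might be driven, using only functions visible in the diff, the example below assumes the backend and a built wsp_ggml_cgraph come from the caller; in practice the graph is typically rebuilt between the measure pass and the real allocation.

#include "ggml-alloc.h"
#include "ggml-backend.h"

// hypothetical helper, for illustration only
static void example_graph_alloc(struct wsp_ggml_backend * backend, struct wsp_ggml_cgraph * graph) {
    // measure pass: no real memory is committed, only the required size is tracked
    wsp_ggml_allocr_t measure = wsp_ggml_allocr_new_measure_from_backend(backend);
    size_t mem_size = wsp_ggml_allocr_alloc_graph(measure, graph);
    wsp_ggml_allocr_free(measure);

    // real pass: allocate a backend buffer of the measured size and place the tensors
    wsp_ggml_allocr_t alloc = wsp_ggml_allocr_new_from_backend(backend, mem_size);
    wsp_ggml_allocr_alloc_graph(alloc, graph);

    // ... evaluate the graph on the backend ...

    wsp_ggml_allocr_free(alloc);
}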