whisper.rn 0.4.0-rc.3 → 0.4.0-rc.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +7 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +6 -1
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +53 -135
  7. package/android/src/main/jni-utils.h +76 -0
  8. package/android/src/main/jni.cpp +188 -109
  9. package/cpp/README.md +1 -1
  10. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  11. package/cpp/coreml/whisper-encoder.h +4 -0
  12. package/cpp/coreml/whisper-encoder.mm +4 -2
  13. package/cpp/ggml-alloc.c +451 -282
  14. package/cpp/ggml-alloc.h +74 -8
  15. package/cpp/ggml-backend-impl.h +112 -0
  16. package/cpp/ggml-backend.c +1357 -0
  17. package/cpp/ggml-backend.h +181 -0
  18. package/cpp/ggml-impl.h +243 -0
  19. package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal} +1556 -329
  20. package/cpp/ggml-metal.h +28 -1
  21. package/cpp/ggml-metal.m +1128 -308
  22. package/cpp/ggml-quants.c +7382 -0
  23. package/cpp/ggml-quants.h +224 -0
  24. package/cpp/ggml.c +3848 -5245
  25. package/cpp/ggml.h +353 -155
  26. package/cpp/rn-audioutils.cpp +68 -0
  27. package/cpp/rn-audioutils.h +14 -0
  28. package/cpp/rn-whisper-log.h +11 -0
  29. package/cpp/rn-whisper.cpp +141 -59
  30. package/cpp/rn-whisper.h +47 -15
  31. package/cpp/whisper.cpp +1750 -964
  32. package/cpp/whisper.h +97 -15
  33. package/ios/RNWhisper.mm +15 -9
  34. package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +4 -0
  35. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  36. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
  37. package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +19 -0
  38. package/ios/RNWhisperAudioUtils.h +0 -2
  39. package/ios/RNWhisperAudioUtils.m +0 -56
  40. package/ios/RNWhisperContext.h +8 -12
  41. package/ios/RNWhisperContext.mm +132 -138
  42. package/jest/mock.js +1 -1
  43. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  44. package/lib/commonjs/index.js +28 -9
  45. package/lib/commonjs/index.js.map +1 -1
  46. package/lib/commonjs/version.json +1 -1
  47. package/lib/module/NativeRNWhisper.js.map +1 -1
  48. package/lib/module/index.js +28 -9
  49. package/lib/module/index.js.map +1 -1
  50. package/lib/module/version.json +1 -1
  51. package/lib/typescript/NativeRNWhisper.d.ts +7 -1
  52. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  53. package/lib/typescript/index.d.ts +7 -2
  54. package/lib/typescript/index.d.ts.map +1 -1
  55. package/package.json +6 -5
  56. package/src/NativeRNWhisper.ts +8 -1
  57. package/src/index.ts +29 -17
  58. package/src/version.json +1 -1
  59. package/whisper-rn.podspec +1 -2
package/cpp/ggml-backend.c (new file)
@@ -0,0 +1,1357 @@
1
+ #include "ggml-backend-impl.h"
2
+ #include "ggml-alloc.h"
3
+ #include "ggml-impl.h"
4
+
5
+ #include <assert.h>
6
+ #include <limits.h>
7
+ #include <stdarg.h>
8
+ #include <stdio.h>
9
+ #include <stdlib.h>
10
+ #include <string.h>
11
+
12
+
13
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
14
+
15
+
16
+ // backend buffer type
17
+
18
+ wsp_ggml_backend_buffer_t wsp_ggml_backend_buft_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
19
+ return buft->iface.alloc_buffer(buft, size);
20
+ }
21
+
22
+ size_t wsp_ggml_backend_buft_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
23
+ return buft->iface.get_alignment(buft);
24
+ }
25
+
26
+ size_t wsp_ggml_backend_buft_get_alloc_size(wsp_ggml_backend_buffer_type_t buft, struct wsp_ggml_tensor * tensor) {
27
+ // get_alloc_size is optional, defaults to wsp_ggml_nbytes
28
+ if (buft->iface.get_alloc_size) {
29
+ return buft->iface.get_alloc_size(buft, tensor);
30
+ }
31
+ return wsp_ggml_nbytes(tensor);
32
+ }
33
+
34
+ bool wsp_ggml_backend_buft_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) {
35
+ return buft->iface.supports_backend(buft, backend);
36
+ }
37
+
38
+ // backend buffer
39
+
40
+ wsp_ggml_backend_buffer_t wsp_ggml_backend_buffer_init(
41
+ wsp_ggml_backend_buffer_type_t buft,
42
+ struct wsp_ggml_backend_buffer_i iface,
43
+ wsp_ggml_backend_buffer_context_t context,
44
+ size_t size) {
45
+ wsp_ggml_backend_buffer_t buffer = malloc(sizeof(struct wsp_ggml_backend_buffer));
46
+
47
+ WSP_GGML_ASSERT(iface.get_base != NULL);
48
+
49
+ (*buffer) = (struct wsp_ggml_backend_buffer) {
50
+ /* .interface = */ iface,
51
+ /* .buft = */ buft,
52
+ /* .context = */ context,
53
+ /* .size = */ size,
54
+ };
55
+
56
+ return buffer;
57
+ }
58
+
59
+ void wsp_ggml_backend_buffer_free(wsp_ggml_backend_buffer_t buffer) {
60
+ if (buffer == NULL) {
61
+ return;
62
+ }
63
+
64
+ if (buffer->iface.free_buffer != NULL) {
65
+ buffer->iface.free_buffer(buffer);
66
+ }
67
+ free(buffer);
68
+ }
69
+
70
+ size_t wsp_ggml_backend_buffer_get_size(wsp_ggml_backend_buffer_t buffer) {
71
+ return buffer->size;
72
+ }
73
+
74
+ void * wsp_ggml_backend_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
75
+ void * base = buffer->iface.get_base(buffer);
76
+
77
+ WSP_GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL");
78
+
79
+ return base;
80
+ }
81
+
82
+ void wsp_ggml_backend_buffer_init_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) {
83
+ // init_tensor is optional
84
+ if (buffer->iface.init_tensor) {
85
+ buffer->iface.init_tensor(buffer, tensor);
86
+ }
87
+ }
88
+
89
+ size_t wsp_ggml_backend_buffer_get_alignment (wsp_ggml_backend_buffer_t buffer) {
90
+ return wsp_ggml_backend_buft_get_alignment(wsp_ggml_backend_buffer_type(buffer));
91
+ }
92
+
93
+ size_t wsp_ggml_backend_buffer_get_alloc_size(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) {
94
+ return wsp_ggml_backend_buft_get_alloc_size(wsp_ggml_backend_buffer_type(buffer), tensor);
95
+ }
96
+
97
+ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_buffer_type(wsp_ggml_backend_buffer_t buffer) {
98
+ return buffer->buft;
99
+ }
100
+
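For context, a minimal sketch (not part of the diff) of how the buffer-type API above is used, assuming these functions are declared in ggml-backend.h as listed in this release; example_cpu_buffer is an illustrative name only:

#include <stdio.h>
#include "ggml-backend.h"

static void example_cpu_buffer(void) {
    // the CPU buffer type is defined further down in this same file
    wsp_ggml_backend_buffer_type_t buft = wsp_ggml_backend_cpu_buffer_type();

    // allocate 16 MiB through the buffer-type interface
    wsp_ggml_backend_buffer_t buf = wsp_ggml_backend_buft_alloc_buffer(buft, 16u*1024*1024);

    printf("base=%p size=%zu align=%zu\n",
        wsp_ggml_backend_buffer_get_base(buf),
        wsp_ggml_backend_buffer_get_size(buf),
        wsp_ggml_backend_buffer_get_alignment(buf));

    wsp_ggml_backend_buffer_free(buf);
}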
101
+ // backend
102
+
103
+ const char * wsp_ggml_backend_name(wsp_ggml_backend_t backend) {
104
+ if (backend == NULL) {
105
+ return "NULL";
106
+ }
107
+ return backend->iface.get_name(backend);
108
+ }
109
+
110
+ void wsp_ggml_backend_free(wsp_ggml_backend_t backend) {
111
+ if (backend == NULL) {
112
+ return;
113
+ }
114
+
115
+ backend->iface.free(backend);
116
+ }
117
+
118
+ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_get_default_buffer_type(wsp_ggml_backend_t backend) {
119
+ return backend->iface.get_default_buffer_type(backend);
120
+ }
121
+
122
+ wsp_ggml_backend_buffer_t wsp_ggml_backend_alloc_buffer(wsp_ggml_backend_t backend, size_t size) {
123
+ return wsp_ggml_backend_buft_alloc_buffer(wsp_ggml_backend_get_default_buffer_type(backend), size);
124
+ }
125
+
126
+ size_t wsp_ggml_backend_get_alignment(wsp_ggml_backend_t backend) {
127
+ return wsp_ggml_backend_buft_get_alignment(wsp_ggml_backend_get_default_buffer_type(backend));
128
+ }
129
+
130
+ void wsp_ggml_backend_tensor_set_async(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
131
+ WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
132
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
133
+
134
+ backend->iface.set_tensor_async(backend, tensor, data, offset, size);
135
+ }
136
+
137
+ void wsp_ggml_backend_tensor_get_async(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
138
+ WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
139
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
140
+
141
+ backend->iface.get_tensor_async(backend, tensor, data, offset, size);
142
+ }
143
+
144
+ void wsp_ggml_backend_tensor_set(struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
145
+ WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
146
+ WSP_GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set");
147
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
148
+
149
+ tensor->buffer->iface.set_tensor(tensor->buffer, tensor, data, offset, size);
150
+ }
151
+
152
+ void wsp_ggml_backend_tensor_get(const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
153
+ WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
154
+ WSP_GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set");
155
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
156
+
157
+ tensor->buffer->iface.get_tensor(tensor->buffer, tensor, data, offset, size);
158
+ }
159
+
160
+ void wsp_ggml_backend_synchronize(wsp_ggml_backend_t backend) {
161
+ if (backend->iface.synchronize == NULL) {
162
+ return;
163
+ }
164
+
165
+ backend->iface.synchronize(backend);
166
+ }
167
+
168
+ wsp_ggml_backend_graph_plan_t wsp_ggml_backend_graph_plan_create(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
169
+ return backend->iface.graph_plan_create(backend, cgraph);
170
+ }
171
+
172
+ void wsp_ggml_backend_graph_plan_free(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) {
173
+ backend->iface.graph_plan_free(backend, plan);
174
+ }
175
+
176
+ void wsp_ggml_backend_graph_plan_compute(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) {
177
+ backend->iface.graph_plan_compute(backend, plan);
178
+
179
+ // TODO: optional sync
180
+ wsp_ggml_backend_synchronize(backend);
181
+ }
182
+
183
+ void wsp_ggml_backend_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
184
+ backend->iface.graph_compute(backend, cgraph);
185
+
186
+ // TODO: optional sync
187
+ wsp_ggml_backend_synchronize(backend);
188
+ }
189
+
190
+ bool wsp_ggml_backend_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
191
+ return backend->iface.supports_op(backend, op);
192
+ }
193
+
194
+ // backend copy
195
+
196
+ static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b) {
197
+ if (a->type != b->type) {
198
+ return false;
199
+ }
200
+ for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
201
+ if (a->ne[i] != b->ne[i]) {
202
+ return false;
203
+ }
204
+ if (a->nb[i] != b->nb[i]) {
205
+ return false;
206
+ }
207
+ }
208
+ return true;
209
+ }
210
+
211
+ void wsp_ggml_backend_tensor_copy(struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
212
+ //printf("src: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", src->name, (int)src->ne[0], (int)src->ne[1], (int)src->ne[2], (int)src->ne[3], (int)src->nb[0], (int)src->nb[1], (int)src->nb[2], (int)src->nb[3]);
213
+ //printf("dst: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", dst->name, (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], (int)dst->nb[0], (int)dst->nb[1], (int)dst->nb[2], (int)dst->nb[3]);
214
+ WSP_GGML_ASSERT(wsp_ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
215
+
216
+ // fprintf(stderr, "cpy tensor %s from %s to %s (%lu bytes)\n", src->name, wsp_ggml_backend_name(src->backend), wsp_ggml_backend_name(dst->backend), wsp_ggml_nbytes(src));
217
+
218
+ if (src == dst) {
219
+ return;
220
+ }
221
+
222
+ // TODO: allow backends to support copy to/from same backend
223
+
224
+ if (dst->buffer->iface.cpy_tensor_from != NULL) {
225
+ dst->buffer->iface.cpy_tensor_from(dst->buffer, src, dst);
226
+ } else if (src->buffer->iface.cpy_tensor_to != NULL) {
227
+ src->buffer->iface.cpy_tensor_to(src->buffer, src, dst);
228
+ } else {
229
+ // shouldn't be hit when copying from/to CPU
230
+ #ifndef NDEBUG
231
+ fprintf(stderr, "wsp_ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to "
232
+ "are implemented for %s and %s, falling back to get/set\n", src->name, dst->name);
233
+ #endif
234
+ size_t nbytes = wsp_ggml_nbytes(src);
235
+ void * data = malloc(nbytes);
236
+ wsp_ggml_backend_tensor_get(src, data, 0, nbytes);
237
+ wsp_ggml_backend_tensor_set(dst, data, 0, nbytes);
238
+ free(data);
239
+ }
240
+ }
241
+
242
+ // backend registry
243
+
244
+ #define WSP_GGML_MAX_BACKENDS_REG 16
245
+
246
+ struct wsp_ggml_backend_reg {
247
+ char name[128];
248
+ wsp_ggml_backend_init_fn init_fn;
249
+ wsp_ggml_backend_buffer_type_t default_buffer_type;
250
+ void * user_data;
251
+ };
252
+
253
+ static struct wsp_ggml_backend_reg wsp_ggml_backend_registry[WSP_GGML_MAX_BACKENDS_REG];
254
+ static size_t wsp_ggml_backend_registry_count = 0;
255
+
256
+ static wsp_ggml_backend_t wsp_ggml_backend_reg_cpu_init(const char * params, void * user_data);
257
+
258
+ static void wsp_ggml_backend_registry_init(void) {
259
+ static bool initialized = false;
260
+
261
+ if (initialized) {
262
+ return;
263
+ }
264
+
265
+ initialized = true;
266
+
267
+ wsp_ggml_backend_register("CPU", wsp_ggml_backend_reg_cpu_init, wsp_ggml_backend_cpu_buffer_type(), NULL);
268
+
269
+ // add forward decls here to avoid including the backend headers
270
+ #ifdef WSP_GGML_USE_CUBLAS
271
+ extern void wsp_ggml_backend_cuda_reg_devices(void);
272
+ wsp_ggml_backend_cuda_reg_devices();
273
+ #endif
274
+
275
+ #ifdef WSP_GGML_USE_METAL
276
+ extern wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data);
277
+ extern wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void);
278
+ wsp_ggml_backend_register("Metal", wsp_ggml_backend_reg_metal_init, wsp_ggml_backend_metal_buffer_type(), NULL);
279
+ #endif
280
+ }
281
+
282
+ void wsp_ggml_backend_register(const char * name, wsp_ggml_backend_init_fn init_fn, wsp_ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
283
+ WSP_GGML_ASSERT(wsp_ggml_backend_registry_count < WSP_GGML_MAX_BACKENDS_REG);
284
+
285
+ int id = wsp_ggml_backend_registry_count;
286
+
287
+ wsp_ggml_backend_registry[id] = (struct wsp_ggml_backend_reg) {
288
+ /* .name = */ {0},
289
+ /* .fn = */ init_fn,
290
+ /* .default_buffer_type = */ default_buffer_type,
291
+ /* .user_data = */ user_data,
292
+ };
293
+
294
+ snprintf(wsp_ggml_backend_registry[id].name, sizeof(wsp_ggml_backend_registry[id].name), "%s", name);
295
+
296
+ #ifndef NDEBUG
297
+ fprintf(stderr, "%s: registered backend %s\n", __func__, name);
298
+ #endif
299
+
300
+ wsp_ggml_backend_registry_count++;
301
+ }
302
+
303
+ size_t wsp_ggml_backend_reg_get_count(void) {
304
+ wsp_ggml_backend_registry_init();
305
+
306
+ return wsp_ggml_backend_registry_count;
307
+ }
308
+
309
+ size_t wsp_ggml_backend_reg_find_by_name(const char * name) {
310
+ wsp_ggml_backend_registry_init();
311
+
312
+ for (size_t i = 0; i < wsp_ggml_backend_registry_count; i++) {
313
+ // TODO: case insensitive in a portable way
314
+ if (strcmp(wsp_ggml_backend_registry[i].name, name) == 0) {
315
+ return i;
316
+ }
317
+ }
318
+ return SIZE_MAX;
319
+ }
320
+
321
+ // init from backend:params string
322
+ wsp_ggml_backend_t wsp_ggml_backend_reg_init_backend_from_str(const char * backend_str) {
323
+ wsp_ggml_backend_registry_init();
324
+
325
+ const char * params = strchr(backend_str, ':');
326
+ char backend_name[128];
327
+ if (params == NULL) {
328
+ strcpy(backend_name, backend_str);
329
+ params = "";
330
+ } else {
331
+ strncpy(backend_name, backend_str, params - backend_str);
332
+ backend_name[params - backend_str] = '\0';
333
+ params++;
334
+ }
335
+
336
+ size_t backend_i = wsp_ggml_backend_reg_find_by_name(backend_name);
337
+ if (backend_i == SIZE_MAX) {
338
+ fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name);
339
+ return NULL;
340
+ }
341
+
342
+ return wsp_ggml_backend_reg_init_backend(backend_i, params);
343
+ }
344
+
345
+ const char * wsp_ggml_backend_reg_get_name(size_t i) {
346
+ wsp_ggml_backend_registry_init();
347
+
348
+ WSP_GGML_ASSERT(i < wsp_ggml_backend_registry_count);
349
+ return wsp_ggml_backend_registry[i].name;
350
+ }
351
+
352
+ wsp_ggml_backend_t wsp_ggml_backend_reg_init_backend(size_t i, const char * params) {
353
+ wsp_ggml_backend_registry_init();
354
+
355
+ WSP_GGML_ASSERT(i < wsp_ggml_backend_registry_count);
356
+ return wsp_ggml_backend_registry[i].init_fn(params, wsp_ggml_backend_registry[i].user_data);
357
+ }
358
+
359
+ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_reg_get_default_buffer_type(size_t i) {
360
+ wsp_ggml_backend_registry_init();
361
+
362
+ WSP_GGML_ASSERT(i < wsp_ggml_backend_registry_count);
363
+ return wsp_ggml_backend_registry[i].default_buffer_type;
364
+ }
365
+
366
+ wsp_ggml_backend_buffer_t wsp_ggml_backend_reg_alloc_buffer(size_t i, size_t size) {
367
+ wsp_ggml_backend_registry_init();
368
+
369
+ WSP_GGML_ASSERT(i < wsp_ggml_backend_registry_count);
370
+ return wsp_ggml_backend_buft_alloc_buffer(wsp_ggml_backend_registry[i].default_buffer_type, size);
371
+ }
372
+
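A short sketch (not part of the diff) of the registry API above; the CPU backend is always registered, and Metal/CUDA register themselves when compiled in. example_registry is an illustrative name only:

#include <stdio.h>
#include "ggml-backend.h"

static void example_registry(void) {
    // enumerate whatever registered itself at startup
    for (size_t i = 0; i < wsp_ggml_backend_reg_get_count(); i++) {
        printf("backend %zu: %s\n", i, wsp_ggml_backend_reg_get_name(i));
    }

    // "name" or "name:params" form, resolved via wsp_ggml_backend_reg_find_by_name()
    wsp_ggml_backend_t backend = wsp_ggml_backend_reg_init_backend_from_str("CPU");
    if (backend != NULL) {
        printf("initialized %s\n", wsp_ggml_backend_name(backend));
        wsp_ggml_backend_free(backend);
    }
}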
373
+ // backend CPU
374
+
375
+ static void * wsp_ggml_backend_cpu_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
376
+ return (void *)buffer->context;
377
+ }
378
+
379
+ static void wsp_ggml_backend_cpu_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
380
+ free(buffer->context);
381
+ WSP_GGML_UNUSED(buffer);
382
+ }
383
+
384
+ static void wsp_ggml_backend_cpu_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
385
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
386
+ WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
387
+
388
+ memcpy((char *)tensor->data + offset, data, size);
389
+
390
+ WSP_GGML_UNUSED(buffer);
391
+ }
392
+
393
+ static void wsp_ggml_backend_cpu_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
394
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
395
+ WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
396
+
397
+ memcpy(data, (const char *)tensor->data + offset, size);
398
+
399
+ WSP_GGML_UNUSED(buffer);
400
+ }
401
+
402
+ static void wsp_ggml_backend_cpu_buffer_cpy_tensor_from(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
403
+ wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src));
404
+
405
+ WSP_GGML_UNUSED(buffer);
406
+ }
407
+
408
+ static void wsp_ggml_backend_cpu_buffer_cpy_tensor_to(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
409
+ wsp_ggml_backend_tensor_set(dst, src->data, 0, wsp_ggml_nbytes(src));
410
+
411
+ WSP_GGML_UNUSED(buffer);
412
+ }
413
+
414
+ static struct wsp_ggml_backend_buffer_i cpu_backend_buffer_i = {
415
+ /* .free_buffer = */ wsp_ggml_backend_cpu_buffer_free_buffer,
416
+ /* .get_base = */ wsp_ggml_backend_cpu_buffer_get_base,
417
+ /* .init_tensor = */ NULL, // no initialization required
418
+ /* .set_tensor = */ wsp_ggml_backend_cpu_buffer_set_tensor,
419
+ /* .get_tensor = */ wsp_ggml_backend_cpu_buffer_get_tensor,
420
+ /* .cpy_tensor_from = */ wsp_ggml_backend_cpu_buffer_cpy_tensor_from,
421
+ /* .cpy_tensor_to = */ wsp_ggml_backend_cpu_buffer_cpy_tensor_to,
422
+ };
423
+
424
+ // for buffers from ptr, free is not called
425
+ static struct wsp_ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
426
+ /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
427
+ /* .get_base = */ wsp_ggml_backend_cpu_buffer_get_base,
428
+ /* .init_tensor = */ NULL, // no initialization required
429
+ /* .set_tensor = */ wsp_ggml_backend_cpu_buffer_set_tensor,
430
+ /* .get_tensor = */ wsp_ggml_backend_cpu_buffer_get_tensor,
431
+ /* .cpy_tensor_from = */ wsp_ggml_backend_cpu_buffer_cpy_tensor_from,
432
+ /* .cpy_tensor_to = */ wsp_ggml_backend_cpu_buffer_cpy_tensor_to,
433
+ };
434
+
435
+ static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512
436
+
437
+ static wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
438
+ size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned
439
+ void * data = malloc(size); // TODO: maybe use WSP_GGML_ALIGNED_MALLOC?
440
+
441
+ WSP_GGML_ASSERT(data != NULL && "failed to allocate buffer");
442
+
443
+ return wsp_ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size);
444
+ }
445
+
446
+ static size_t wsp_ggml_backend_cpu_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
447
+ return TENSOR_ALIGNMENT;
448
+
449
+ WSP_GGML_UNUSED(buft);
450
+ }
451
+
452
+ static bool wsp_ggml_backend_cpu_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) {
453
+ return wsp_ggml_backend_is_cpu(backend);
454
+
455
+ WSP_GGML_UNUSED(buft);
456
+ }
457
+
458
+ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_buffer_type(void) {
459
+ static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_buffer_type_cpu = {
460
+ /* .iface = */ {
461
+ /* .alloc_buffer = */ wsp_ggml_backend_cpu_buffer_type_alloc_buffer,
462
+ /* .get_alignment = */ wsp_ggml_backend_cpu_buffer_type_get_alignment,
463
+ /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes
464
+ /* .supports_backend = */ wsp_ggml_backend_cpu_buffer_type_supports_backend,
465
+ },
466
+ /* .context = */ NULL,
467
+ };
468
+
469
+ return &wsp_ggml_backend_buffer_type_cpu;
470
+ }
471
+
472
+ struct wsp_ggml_backend_cpu_context {
473
+ int n_threads;
474
+ void * work_data;
475
+ size_t work_size;
476
+ };
477
+
478
+ static const char * wsp_ggml_backend_cpu_name(wsp_ggml_backend_t backend) {
479
+ return "CPU";
480
+
481
+ WSP_GGML_UNUSED(backend);
482
+ }
483
+
484
+ static void wsp_ggml_backend_cpu_free(wsp_ggml_backend_t backend) {
485
+ struct wsp_ggml_backend_cpu_context * cpu_ctx = (struct wsp_ggml_backend_cpu_context *)backend->context;
486
+ free(cpu_ctx->work_data);
487
+ free(cpu_ctx);
488
+ free(backend);
489
+ }
490
+
491
+ static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_get_default_buffer_type(wsp_ggml_backend_t backend) {
492
+ return wsp_ggml_backend_cpu_buffer_type();
493
+
494
+ WSP_GGML_UNUSED(backend);
495
+ }
496
+
497
+ struct wsp_ggml_backend_plan_cpu {
498
+ struct wsp_ggml_cplan cplan;
499
+ struct wsp_ggml_cgraph cgraph;
500
+ };
501
+
502
+ static wsp_ggml_backend_graph_plan_t wsp_ggml_backend_cpu_graph_plan_create(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
503
+ struct wsp_ggml_backend_cpu_context * cpu_ctx = (struct wsp_ggml_backend_cpu_context *)backend->context;
504
+
505
+ struct wsp_ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct wsp_ggml_backend_plan_cpu));
506
+
507
+ cpu_plan->cplan = wsp_ggml_graph_plan(cgraph, cpu_ctx->n_threads);
508
+ cpu_plan->cgraph = *cgraph;
509
+
510
+ if (cpu_plan->cplan.work_size > 0) {
511
+ cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size);
512
+ }
513
+
514
+ return cpu_plan;
515
+ }
516
+
517
+ static void wsp_ggml_backend_cpu_graph_plan_free(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) {
518
+ struct wsp_ggml_backend_plan_cpu * cpu_plan = (struct wsp_ggml_backend_plan_cpu *)plan;
519
+
520
+ free(cpu_plan->cplan.work_data);
521
+ free(cpu_plan);
522
+
523
+ WSP_GGML_UNUSED(backend);
524
+ }
525
+
526
+ static void wsp_ggml_backend_cpu_graph_plan_compute(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) {
527
+ struct wsp_ggml_backend_plan_cpu * cpu_plan = (struct wsp_ggml_backend_plan_cpu *)plan;
528
+
529
+ wsp_ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
530
+
531
+ WSP_GGML_UNUSED(backend);
532
+ }
533
+
534
+ static void wsp_ggml_backend_cpu_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
535
+ struct wsp_ggml_backend_cpu_context * cpu_ctx = (struct wsp_ggml_backend_cpu_context *)backend->context;
536
+
537
+ struct wsp_ggml_cplan cplan = wsp_ggml_graph_plan(cgraph, cpu_ctx->n_threads);
538
+
539
+ if (cpu_ctx->work_size < cplan.work_size) {
540
+ // TODO: may be faster to free and use malloc to avoid the copy
541
+ cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size);
542
+ cpu_ctx->work_size = cplan.work_size;
543
+ }
544
+
545
+ cplan.work_data = cpu_ctx->work_data;
546
+
547
+ wsp_ggml_graph_compute(cgraph, &cplan);
548
+ }
549
+
550
+ static bool wsp_ggml_backend_cpu_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
551
+ return true;
552
+
553
+ WSP_GGML_UNUSED(backend);
554
+ WSP_GGML_UNUSED(op);
555
+ }
556
+
557
+ static struct wsp_ggml_backend_i cpu_backend_i = {
558
+ /* .get_name = */ wsp_ggml_backend_cpu_name,
559
+ /* .free = */ wsp_ggml_backend_cpu_free,
560
+ /* .get_default_buffer_type = */ wsp_ggml_backend_cpu_get_default_buffer_type,
561
+ /* .set_tensor_async = */ NULL,
562
+ /* .get_tensor_async = */ NULL,
563
+ /* .cpy_tensor_from_async = */ NULL,
564
+ /* .cpy_tensor_to_async = */ NULL,
565
+ /* .synchronize = */ NULL,
566
+ /* .graph_plan_create = */ wsp_ggml_backend_cpu_graph_plan_create,
567
+ /* .graph_plan_free = */ wsp_ggml_backend_cpu_graph_plan_free,
568
+ /* .graph_plan_compute = */ wsp_ggml_backend_cpu_graph_plan_compute,
569
+ /* .graph_compute = */ wsp_ggml_backend_cpu_graph_compute,
570
+ /* .supports_op = */ wsp_ggml_backend_cpu_supports_op,
571
+ };
572
+
573
+ wsp_ggml_backend_t wsp_ggml_backend_cpu_init(void) {
574
+ struct wsp_ggml_backend_cpu_context * ctx = malloc(sizeof(struct wsp_ggml_backend_cpu_context));
575
+
576
+ ctx->n_threads = WSP_GGML_DEFAULT_N_THREADS;
577
+ ctx->work_data = NULL;
578
+ ctx->work_size = 0;
579
+
580
+ wsp_ggml_backend_t cpu_backend = malloc(sizeof(struct wsp_ggml_backend));
581
+
582
+ *cpu_backend = (struct wsp_ggml_backend) {
583
+ /* .interface = */ cpu_backend_i,
584
+ /* .context = */ ctx
585
+ };
586
+ return cpu_backend;
587
+ }
588
+
589
+ bool wsp_ggml_backend_is_cpu(wsp_ggml_backend_t backend) {
590
+ return backend->iface.get_name == wsp_ggml_backend_cpu_name;
591
+ }
592
+
593
+ void wsp_ggml_backend_cpu_set_n_threads(wsp_ggml_backend_t backend_cpu, int n_threads) {
594
+ WSP_GGML_ASSERT(wsp_ggml_backend_is_cpu(backend_cpu));
595
+
596
+ struct wsp_ggml_backend_cpu_context * ctx = (struct wsp_ggml_backend_cpu_context *)backend_cpu->context;
597
+ ctx->n_threads = n_threads;
598
+ }
599
+
600
+ wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
601
+ return wsp_ggml_backend_buffer_init(wsp_ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
602
+ }
603
+
604
+ static wsp_ggml_backend_t wsp_ggml_backend_reg_cpu_init(const char * params, void * user_data) {
605
+ return wsp_ggml_backend_cpu_init();
606
+
607
+ WSP_GGML_UNUSED(params);
608
+ WSP_GGML_UNUSED(user_data);
609
+ }
610
+
611
+
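A usage sketch (not part of the diff) of the CPU backend: it assumes the wsp_-prefixed ggml core API (wsp_ggml_init, wsp_ggml_add, wsp_ggml_new_graph, wsp_ggml_build_forward_expand) and wsp_ggml_backend_alloc_ctx_tensors(), which this file also calls and whose declaration is assumed to come from ggml-alloc.h:

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

static void example_cpu_backend(void) {
    wsp_ggml_backend_t cpu = wsp_ggml_backend_cpu_init();
    wsp_ggml_backend_cpu_set_n_threads(cpu, 4);

    struct wsp_ggml_init_params params = {
        /* .mem_size   = */ wsp_ggml_tensor_overhead()*8 + wsp_ggml_graph_overhead(),
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ true, // tensor data will live in the backend buffer
    };
    struct wsp_ggml_context * ctx = wsp_ggml_init(params);

    struct wsp_ggml_tensor * a = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 4);
    struct wsp_ggml_tensor * b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 4);
    struct wsp_ggml_tensor * c = wsp_ggml_add(ctx, a, b);

    struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx);
    wsp_ggml_build_forward_expand(gf, c);

    // allocate all context tensors in the backend's default buffer type
    wsp_ggml_backend_buffer_t buf = wsp_ggml_backend_alloc_ctx_tensors(ctx, cpu);

    const float a_data[4] = {1, 2, 3, 4};
    const float b_data[4] = {10, 20, 30, 40};
    wsp_ggml_backend_tensor_set(a, a_data, 0, sizeof(a_data));
    wsp_ggml_backend_tensor_set(b, b_data, 0, sizeof(b_data));

    wsp_ggml_backend_graph_compute(cpu, gf);

    float out[4];
    wsp_ggml_backend_tensor_get(c, out, 0, sizeof(out));

    wsp_ggml_backend_buffer_free(buf);
    wsp_ggml_free(ctx);
    wsp_ggml_backend_free(cpu);
}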
612
+ // scheduler
613
+
614
+ #define WSP_GGML_MAX_BACKENDS 4
615
+ #define WSP_GGML_MAX_SPLITS 256
616
+ #define WSP_GGML_MAX_SPLIT_INPUTS 16
617
+
618
+ struct wsp_ggml_backend_sched_split {
619
+ wsp_ggml_tallocr_t tallocr;
620
+ int i_start;
621
+ int i_end;
622
+ struct wsp_ggml_tensor * inputs[WSP_GGML_MAX_SPLIT_INPUTS];
623
+ int n_inputs;
624
+ struct wsp_ggml_cgraph graph;
625
+ };
626
+
627
+ struct wsp_ggml_backend_sched {
628
+ int n_backends;
629
+ wsp_ggml_backend_t backends[WSP_GGML_MAX_BACKENDS];
630
+ wsp_ggml_tallocr_t tallocs[WSP_GGML_MAX_BACKENDS];
631
+
632
+ wsp_ggml_gallocr_t galloc;
633
+
634
+ struct wsp_ggml_hash_set hash_set;
635
+ wsp_ggml_tallocr_t * node_talloc; // [hash_set.size]
636
+ struct wsp_ggml_tensor * (* node_copies)[WSP_GGML_MAX_BACKENDS]; // [hash_set.size][WSP_GGML_MAX_BACKENDS]
637
+
638
+ struct wsp_ggml_cgraph * graph;
639
+ struct wsp_ggml_backend_sched_split splits[WSP_GGML_MAX_SPLITS];
640
+ int n_splits;
641
+
642
+ struct wsp_ggml_context * ctx;
643
+
644
+ // align context_buffer to WSP_GGML_MEM_ALIGN
645
+ #ifdef _MSC_VER
646
+ __declspec(align(WSP_GGML_MEM_ALIGN))
647
+ #else
648
+ __attribute__((aligned(WSP_GGML_MEM_ALIGN)))
649
+ #endif
650
+ char context_buffer[WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS*sizeof(struct wsp_ggml_tensor) + sizeof(struct wsp_ggml_cgraph)];
651
+ };
652
+
653
+ #define hash_id(node) wsp_ggml_hash_find_or_insert(sched->hash_set, node)
654
+ #define node_allocr(node) sched->node_talloc[hash_id(node)]
655
+
656
+ static bool wsp_ggml_is_view_op(enum wsp_ggml_op op) {
657
+ return op == WSP_GGML_OP_VIEW || op == WSP_GGML_OP_RESHAPE || op == WSP_GGML_OP_PERMUTE || op == WSP_GGML_OP_TRANSPOSE;
658
+ }
659
+
660
+ // returns the priority of the backend, lower is better
661
+ static int sched_backend_prio(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend) {
662
+ for (int i = 0; i < sched->n_backends; i++) {
663
+ if (sched->backends[i] == backend) {
664
+ return i;
665
+ }
666
+ }
667
+ return INT_MAX;
668
+ }
669
+
670
+ static int sched_allocr_prio(wsp_ggml_backend_sched_t sched, wsp_ggml_tallocr_t allocr) {
671
+ for (int i = 0; i < sched->n_backends; i++) {
672
+ if (sched->tallocs[i] == allocr) {
673
+ return i;
674
+ }
675
+ }
676
+ return INT_MAX;
677
+ }
678
+
679
+ static wsp_ggml_backend_t get_buffer_backend(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_buffer_t buffer) {
680
+ if (buffer == NULL) {
681
+ return NULL;
682
+ }
683
+ // find highest prio backend that supports the buffer type
684
+ for (int i = 0; i < sched->n_backends; i++) {
685
+ if (wsp_ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) {
686
+ return sched->backends[i];
687
+ }
688
+ }
689
+ WSP_GGML_ASSERT(false && "tensor buffer type not supported by any backend");
690
+ }
691
+
692
+ static wsp_ggml_backend_t get_allocr_backend(wsp_ggml_backend_sched_t sched, wsp_ggml_tallocr_t allocr) {
693
+ if (allocr == NULL) {
694
+ return NULL;
695
+ }
696
+ // find highest prio backend that supports the buffer type
697
+ for (int i = 0; i < sched->n_backends; i++) {
698
+ if (sched->tallocs[i] == allocr) {
699
+ return sched->backends[i];
700
+ }
701
+ }
702
+ WSP_GGML_UNREACHABLE();
703
+ }
704
+
705
+ #if 0
706
+ static char causes[WSP_GGML_DEFAULT_GRAPH_SIZE*8 + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
707
+ #define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
708
+ #define GET_CAUSE(node) causes[hash_id(node)]
709
+ #else
710
+ #define SET_CAUSE(node, ...)
711
+ #define GET_CAUSE(node) ""
712
+ #endif
713
+
714
+ // returns the backend that should be used for the node based on the current locations
715
+ static wsp_ggml_backend_t sched_backend_from_cur(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node) {
716
+ // if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there
717
+ // ie. kv cache updates
718
+ // note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend.
719
+ // dst
720
+ wsp_ggml_backend_t cur_backend = get_buffer_backend(sched, node->buffer);
721
+ if (cur_backend != NULL) {
722
+ SET_CAUSE(node, "1.dst");
723
+ return cur_backend;
724
+ }
725
+
726
+ // view_src
727
+ if (node->view_src != NULL && get_buffer_backend(sched, node->view_src->buffer) != NULL) {
728
+ SET_CAUSE(node, "1.vsrc");
729
+ return get_buffer_backend(sched, node->view_src->buffer);
730
+ }
731
+
732
+ // src
733
+ int cur_prio = INT_MAX;
734
+ size_t cur_size = 0;
735
+
736
+ for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
737
+ const struct wsp_ggml_tensor * src = node->src[i];
738
+ if (src == NULL) {
739
+ break;
740
+ }
741
+ wsp_ggml_backend_t src_backend = get_buffer_backend(sched, src->buffer);
742
+ if (src_backend != NULL) {
743
+ int src_prio = sched_backend_prio(sched, src_backend);
744
+ size_t src_size = wsp_ggml_nbytes(src);
745
+ if (src_prio < cur_prio && src_size >= cur_size) {
746
+ cur_prio = src_prio;
747
+ cur_size = src_size;
748
+ cur_backend = src_backend;
749
+ SET_CAUSE(node, "1.src%d", i);
750
+ }
751
+ }
752
+ }
753
+ return cur_backend;
754
+ }
755
+
756
+ static char * fmt_size(size_t size) {
757
+ static char buffer[128];
758
+ if (size >= 1024*1024) {
759
+ sprintf(buffer, "%zuM", size/1024/1024);
760
+ } else {
761
+ sprintf(buffer, "%zuK", size/1024);
762
+ }
763
+ return buffer;
764
+ }
765
+
766
+ static void sched_print_assignments(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph) {
767
+ int cur_split = 0;
768
+ for (int i = 0; i < graph->n_nodes; i++) {
769
+ if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
770
+ wsp_ggml_backend_t split_backend = get_allocr_backend(sched, sched->splits[cur_split].tallocr);
771
+ fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, wsp_ggml_backend_name(split_backend),
772
+ sched->splits[cur_split].n_inputs);
773
+ for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
774
+ fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name,
775
+ fmt_size(wsp_ggml_nbytes(sched->splits[cur_split].inputs[j])));
776
+ }
777
+ fprintf(stderr, "\n");
778
+ cur_split++;
779
+ }
780
+ struct wsp_ggml_tensor * node = graph->nodes[i];
781
+ if (wsp_ggml_is_view_op(node->op)) {
782
+ continue;
783
+ }
784
+ wsp_ggml_tallocr_t node_allocr = node_allocr(node);
785
+ wsp_ggml_backend_t node_backend = node_allocr ? get_allocr_backend(sched, node_allocr) : NULL; // FIXME:
786
+ fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, wsp_ggml_op_name(node->op), node->name,
787
+ fmt_size(wsp_ggml_nbytes(node)), node_allocr ? wsp_ggml_backend_name(node_backend) : "NULL", GET_CAUSE(node));
788
+ for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
789
+ struct wsp_ggml_tensor * src = node->src[j];
790
+ if (src == NULL) {
791
+ break;
792
+ }
793
+ wsp_ggml_tallocr_t src_allocr = node_allocr(src);
794
+ wsp_ggml_backend_t src_backend = src_allocr ? get_allocr_backend(sched, src_allocr) : NULL;
795
+ fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name,
796
+ fmt_size(wsp_ggml_nbytes(src)), src_backend ? wsp_ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src));
797
+ }
798
+ fprintf(stderr, "\n");
799
+ }
800
+ }
801
+
802
+ // creates a copy of the tensor with the same memory layout
803
+ static struct wsp_ggml_tensor * wsp_ggml_dup_tensor_layout(struct wsp_ggml_context * ctx, const struct wsp_ggml_tensor * tensor) {
804
+ struct wsp_ggml_tensor * dup = wsp_ggml_dup_tensor(ctx, tensor);
805
+ for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
806
+ dup->nb[i] = tensor->nb[i];
807
+ }
808
+ return dup;
809
+ }
810
+
811
+ // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
812
+ // TODO: merge passes
813
+ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph) {
814
+ // reset state
815
+ size_t hash_size = sched->hash_set.size;
816
+ memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size);
817
+ memset(sched->node_talloc, 0, sizeof(sched->node_talloc[0]) * hash_size);
818
+ memset(sched->node_copies, 0, sizeof(sched->node_copies[0]) * hash_size);
819
+ sched->n_splits = 0;
820
+
821
+ struct wsp_ggml_init_params params = {
822
+ /* .mem_size = */ sizeof(sched->context_buffer),
823
+ /* .mem_buffer = */ sched->context_buffer,
824
+ /* .no_alloc = */ true
825
+ };
826
+
827
+ if (sched->ctx != NULL) {
828
+ wsp_ggml_free(sched->ctx);
829
+ }
830
+
831
+ sched->ctx = wsp_ggml_init(params);
832
+
833
+ // pass 1: assign backends to ops with allocated inputs
834
+ for (int i = 0; i < graph->n_leafs; i++) {
835
+ struct wsp_ggml_tensor * leaf = graph->leafs[i];
836
+ if (node_allocr(leaf) != NULL) {
837
+ // do not overwrite user assignments
838
+ continue;
839
+ }
840
+ wsp_ggml_backend_t leaf_backend = get_buffer_backend(sched, leaf->buffer);
841
+ if (leaf_backend == NULL && leaf->view_src != NULL) {
842
+ leaf_backend = get_buffer_backend(sched, leaf->view_src->buffer);
843
+ }
844
+ if (leaf_backend != NULL) {
845
+ node_allocr(leaf) = wsp_ggml_backend_sched_get_tallocr(sched, leaf_backend);
846
+ }
847
+ }
848
+
849
+ for (int i = 0; i < graph->n_nodes; i++) {
850
+ struct wsp_ggml_tensor * node = graph->nodes[i];
851
+ if (node_allocr(node) != NULL) {
852
+ // do not overwrite user assignments
853
+ continue;
854
+ }
855
+ wsp_ggml_backend_t node_backend = sched_backend_from_cur(sched, node);
856
+ if (node_backend != NULL) {
857
+ node_allocr(node) = wsp_ggml_backend_sched_get_tallocr(sched, node_backend);
858
+ }
859
+ }
860
+ //printf("PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
861
+
862
+ // pass 2: assign backends to ops from current assignments
863
+ // TODO:
864
+ // - reuse sched_backend_from_cur
865
+ for (int i = 0; i < graph->n_nodes; i++) {
866
+ struct wsp_ggml_tensor * node = graph->nodes[i];
867
+ wsp_ggml_tallocr_t node_allocr = node_allocr(node);
868
+ if (node_allocr == NULL) {
869
+ int cur_prio = INT_MAX;
870
+ size_t cur_size = 0;
871
+ for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
872
+ struct wsp_ggml_tensor * src = node->src[j];
873
+ if (src == NULL) {
874
+ break;
875
+ }
876
+ wsp_ggml_tallocr_t src_allocr = node_allocr(src);
877
+ if (src_allocr != NULL) {
878
+ int src_prio = sched_allocr_prio(sched, src_allocr);
879
+ size_t src_size = wsp_ggml_nbytes(src);
880
+ if (src_prio < cur_prio && src_size >= cur_size) {
881
+ cur_prio = src_prio;
882
+ cur_size = src_size;
883
+ node_allocr = src_allocr;
884
+ SET_CAUSE(node, "2.src%d", j);
885
+ }
886
+ }
887
+ }
888
+ if (node_allocr != NULL) {
889
+ node_allocr(node) = node_allocr;
890
+ }
891
+ }
892
+ }
893
+ //printf("PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
894
+
895
+ // pass 3: assign backends to remaining src from dst (should only be leafs)
896
+ for (int i = 0; i < graph->n_nodes; i++) {
897
+ struct wsp_ggml_tensor * node = graph->nodes[i];
898
+ wsp_ggml_tallocr_t node_allocr = node_allocr(node);
899
+ for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
900
+ struct wsp_ggml_tensor * src = node->src[j];
901
+ if (src == NULL) {
902
+ break;
903
+ }
904
+ wsp_ggml_tallocr_t src_allocr = node_allocr(src);
905
+ if (src_allocr == NULL) {
906
+ node_allocr(src) = node_allocr;
907
+ }
908
+ }
909
+ }
910
+ //printf("PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
911
+
912
+ // pass 4: split graph, find tensors that need to be copied
913
+ // TODO:
914
+ // - when switching from a less preferred backend to a more preferred backend, check if it is possible to move the switch to an earlier point for the same cost
915
+ // find first backend
916
+ int cur_split = 0;
917
+ for (int i = 0; i < graph->n_nodes; i++) {
918
+ struct wsp_ggml_tensor * node = graph->nodes[i];
919
+ if (node->view_src == NULL) {
920
+ sched->splits[0].tallocr = node_allocr(node);
921
+ break;
922
+ }
923
+ }
924
+ sched->splits[0].i_start = 0;
925
+ sched->splits[0].n_inputs = 0;
926
+ memset(sched->splits[0].inputs, 0, sizeof(sched->splits[0].inputs)); //HACK
927
+ wsp_ggml_tallocr_t cur_allocr = sched->splits[0].tallocr;
928
+ size_t cur_backend_id = sched_allocr_prio(sched, cur_allocr);
929
+ for (int i = 0; i < graph->n_nodes; i++) {
930
+ struct wsp_ggml_tensor * node = graph->nodes[i];
931
+
932
+ if (wsp_ggml_is_view_op(node->op)) {
933
+ continue;
934
+ }
935
+
936
+ wsp_ggml_tallocr_t node_allocr = node_allocr(node);
937
+
938
+ if (node_allocr != cur_allocr) {
939
+ sched->splits[cur_split].i_end = i;
940
+ cur_split++;
941
+ WSP_GGML_ASSERT(cur_split < WSP_GGML_MAX_SPLITS);
942
+ sched->splits[cur_split].tallocr = node_allocr;
943
+ sched->splits[cur_split].i_start = i;
944
+ sched->splits[cur_split].n_inputs = 0;
945
+ memset(sched->splits[cur_split].inputs, 0, sizeof(sched->splits[cur_split].inputs)); //HACK
946
+ cur_allocr = node_allocr;
947
+ cur_backend_id = sched_allocr_prio(sched, cur_allocr);
948
+ }
949
+
950
+ // find inputs that are not on the same backend
951
+ for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
952
+ struct wsp_ggml_tensor * src = node->src[j];
953
+ if (src == NULL) {
954
+ break;
955
+ }
956
+ wsp_ggml_tallocr_t src_allocr = node_allocr(src);
957
+ if (src_allocr != node_allocr) {
958
+ int n_inputs = sched->splits[cur_split].n_inputs++;
959
+ WSP_GGML_ASSERT(n_inputs < WSP_GGML_MAX_SPLIT_INPUTS);
960
+ sched->splits[cur_split].inputs[n_inputs] = (struct wsp_ggml_tensor *)src;
961
+
962
+ // create copies
963
+ size_t id = hash_id(src);
964
+ if (sched->node_copies[id][cur_backend_id] == NULL) {
965
+ struct wsp_ggml_tensor * tensor_copy = wsp_ggml_dup_tensor_layout(sched->ctx, src);
966
+ sched->node_copies[id][cur_backend_id] = tensor_copy;
967
+ node_allocr(tensor_copy) = cur_allocr;
968
+ wsp_ggml_backend_t backend = get_allocr_backend(sched, cur_allocr);
969
+ wsp_ggml_format_name(tensor_copy, "%s#%s", wsp_ggml_backend_name(backend), src->name);
970
+ }
971
+ node->src[j] = sched->node_copies[id][cur_backend_id];
972
+ }
973
+ }
974
+ }
975
+ sched->splits[cur_split].i_end = graph->n_nodes;
976
+ sched->n_splits = cur_split + 1;
977
+
978
+ //fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); fflush(stdout);
979
+
980
+ #if 1
981
+ // sanity check: all sources should have the same backend as the node
982
+ for (int i = 0; i < graph->n_nodes; i++) {
983
+ struct wsp_ggml_tensor * node = graph->nodes[i];
984
+ wsp_ggml_tallocr_t node_allocr = node_allocr(node);
985
+ if (node_allocr == NULL) {
986
+ fprintf(stderr, "!!!!!!! %s has no backend\n", node->name);
987
+ }
988
+ for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
989
+ struct wsp_ggml_tensor * src = node->src[j];
990
+ if (src == NULL) {
991
+ break;
992
+ }
993
+ wsp_ggml_tallocr_t src_allocr = node_allocr(src);
994
+ if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now
995
+ fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n",
996
+ node->name, node_allocr ? wsp_ggml_backend_name(get_allocr_backend(sched, node_allocr)) : "NULL",
997
+ j, src->name, src_allocr ? wsp_ggml_backend_name(get_allocr_backend(sched, src_allocr)) : "NULL");
998
+ }
999
+ }
1000
+ }
1001
+ #endif
1002
+
1003
+ // create copies of the graph for each split
1004
+ // FIXME: avoid this copy, pass split inputs to wsp_ggml_gallocr_alloc_graph_n in some other way
1005
+ struct wsp_ggml_cgraph * graph_copy = wsp_ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*WSP_GGML_MAX_SPLIT_INPUTS, false);
1006
+ for (int i = 0; i < sched->n_splits; i++) {
1007
+ struct wsp_ggml_backend_sched_split * split = &sched->splits[i];
1008
+ split->graph = wsp_ggml_graph_view(graph, split->i_start, split->i_end);
1009
+
1010
+ // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
1011
+ for (int j = 0; j < split->n_inputs; j++) {
1012
+ struct wsp_ggml_tensor * input = split->inputs[j];
1013
+ struct wsp_ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_allocr_prio(sched, split->tallocr)];
1014
+ input_cpy->src[0] = input;
1015
+ graph_copy->nodes[graph_copy->n_nodes++] = input_cpy;
1016
+ }
1017
+
1018
+ for (int j = split->i_start; j < split->i_end; j++) {
1019
+ graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j];
1020
+ }
1021
+ }
1022
+ sched->graph = graph_copy;
1023
+ }
1024
+
1025
+ static void sched_alloc_splits(wsp_ggml_backend_sched_t sched) {
1026
+ wsp_ggml_gallocr_alloc_graph_n(
1027
+ sched->galloc,
1028
+ sched->graph,
1029
+ sched->hash_set,
1030
+ sched->node_talloc);
1031
+ }
1032
+
1033
+ static void sched_compute_splits(wsp_ggml_backend_sched_t sched) {
1034
+ uint64_t copy_us[WSP_GGML_MAX_BACKENDS] = {0};
1035
+ uint64_t compute_us[WSP_GGML_MAX_BACKENDS] = {0};
1036
+
1037
+ struct wsp_ggml_backend_sched_split * splits = sched->splits;
1038
+
1039
+ for (int i = 0; i < sched->n_splits; i++) {
1040
+ struct wsp_ggml_backend_sched_split * split = &splits[i];
1041
+ wsp_ggml_backend_t split_backend = get_allocr_backend(sched, split->tallocr);
1042
+ int split_backend_id = sched_backend_prio(sched, split_backend);
1043
+
1044
+ // copy the input tensors to the split backend
1045
+ uint64_t copy_start_us = wsp_ggml_time_us();
1046
+ for (int j = 0; j < split->n_inputs; j++) {
1047
+ struct wsp_ggml_tensor * input = split->inputs[j];
1048
+ struct wsp_ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_backend_prio(sched, split_backend)];
1049
+ if (input->buffer == NULL) {
1050
+ if (input->view_src == NULL) {
1051
+ fprintf(stderr, "input %s has no buffer and no view_src\n", input->name);
1052
+ exit(1);
1053
+ }
1054
+ // FIXME: may need to use the sched buffer instead
1055
+ wsp_ggml_backend_view_init(input->view_src->buffer, input);
1056
+ }
1057
+ if (input_cpy->buffer == NULL) {
1058
+ fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name);
1059
+ exit(1);
1060
+ }
1061
+ //WSP_GGML_ASSERT(input->buffer->backend != input_cpy->buffer->backend);
1062
+ //WSP_GGML_ASSERT(input_cpy->buffer->backend == split_backend);
1063
+ wsp_ggml_backend_tensor_copy(input, input_cpy);
1064
+ }
1065
+ // wsp_ggml_backend_synchronize(split_backend);
1066
+ int64_t copy_end_us = wsp_ggml_time_us();
1067
+ copy_us[split_backend_id] += copy_end_us - copy_start_us;
1068
+
1069
+ #if 0
1070
+ char split_filename[WSP_GGML_MAX_NAME];
1071
+ snprintf(split_filename, WSP_GGML_MAX_NAME, "split_%i_%s.dot", i, wsp_ggml_backend_name(split_backend));
1072
+ wsp_ggml_graph_dump_dot(split->graph, NULL, split_filename);
1073
+ #endif
1074
+
1075
+ uint64_t compute_start_us = wsp_ggml_time_us();
1076
+ wsp_ggml_backend_graph_compute(split_backend, &split->graph);
1077
+ // wsp_ggml_backend_synchronize(split_backend);
1078
+ uint64_t compute_end_us = wsp_ggml_time_us();
1079
+ compute_us[split_backend_id] += compute_end_us - compute_start_us;
1080
+ }
1081
+
1082
+ #if 0
1083
+ // per-backend timings
1084
+ fprintf(stderr, "sched_compute_splits times (%d splits):\n", sched->n_splits);
1085
+ for (int i = 0; i < sched->n_backends; i++) {
1086
+ if (copy_us[i] > 0 || compute_us[i] > 0) {
1087
+ fprintf(stderr, "\t%5.5s: %lu us copy, %lu us compute\n", wsp_ggml_backend_name(sched->backends[i]), copy_us[i], compute_us[i]);
1088
+ }
1089
+ }
1090
+ #endif
1091
+ }
1092
+
1093
+ static void sched_reset(wsp_ggml_backend_sched_t sched) {
1094
+ for (int i = 0; i < sched->n_backends; i++) {
1095
+ wsp_ggml_tallocr_reset(sched->tallocs[i]);
1096
+ }
1097
+ }
1098
+
1099
+ wsp_ggml_backend_sched_t wsp_ggml_backend_sched_new(wsp_ggml_backend_t * backends, int n_backends) {
1100
+ WSP_GGML_ASSERT(n_backends <= WSP_GGML_MAX_BACKENDS);
1101
+
1102
+ struct wsp_ggml_backend_sched * sched = malloc(sizeof(struct wsp_ggml_backend_sched));
1103
+ memset(sched, 0, sizeof(struct wsp_ggml_backend_sched));
1104
+
1105
+ sched->n_backends = n_backends;
1106
+ for (int i = 0; i < n_backends; i++) {
1107
+ sched->backends[i] = backends[i];
1108
+ }
1109
+
1110
+ sched->galloc = wsp_ggml_gallocr_new();
1111
+
1112
+ // init measure allocs for each backend
1113
+ for (int i = 0; i < n_backends; i++) {
1114
+ sched->tallocs[i] = wsp_ggml_tallocr_new_measure_from_backend(backends[i]);
1115
+ }
1116
+
1117
+ return sched;
1118
+ }
1119
+
1120
+ void wsp_ggml_backend_sched_free(wsp_ggml_backend_sched_t sched) {
1121
+ if (sched == NULL) {
1122
+ return;
1123
+ }
1124
+ for (int i = 0; i < sched->n_backends; i++) {
1125
+ wsp_ggml_tallocr_free(sched->tallocs[i]);
1126
+ }
1127
+ wsp_ggml_gallocr_free(sched->galloc);
1128
+ free(sched->hash_set.keys);
1129
+ free(sched->node_talloc);
1130
+ free(sched->node_copies);
1131
+ free(sched);
1132
+ }
1133
+
1134
+ void wsp_ggml_backend_sched_init_measure(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * measure_graph) {
1135
+ // initialize hash tables
1136
+ size_t hash_size = measure_graph->visited_hash_table.size + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS;
1137
+ sched->hash_set.size = hash_size;
1138
+ sched->hash_set.keys = malloc(sizeof(sched->hash_set.keys[0]) * hash_size);
1139
+ sched->node_talloc = malloc(sizeof(sched->node_talloc[0]) * hash_size);
1140
+ sched->node_copies = malloc(sizeof(sched->node_copies[0]) * hash_size);
1141
+
1142
+ sched_split_graph(sched, measure_graph);
1143
+ sched_alloc_splits(sched);
1144
+
1145
+ // allocate buffers and reset allocators
1146
+ for (int i = 0; i < sched->n_backends; i++) {
1147
+ size_t size = wsp_ggml_tallocr_max_size(sched->tallocs[i]);
1148
+ wsp_ggml_tallocr_free(sched->tallocs[i]);
1149
+ sched->tallocs[i] = wsp_ggml_tallocr_new_from_backend(sched->backends[i], size);
1150
+ }
1151
+
1152
+ sched_reset(sched);
1153
+ }
1154
+
1155
+ void wsp_ggml_backend_sched_graph_compute(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph) {
1156
+ WSP_GGML_ASSERT(sched->hash_set.size >= graph->visited_hash_table.size + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS);
1157
+
1158
+ sched_split_graph(sched, graph);
1159
+ sched_alloc_splits(sched);
1160
+ sched_compute_splits(sched);
1161
+ sched_reset(sched);
1162
+ }
1163
+
1164
+ wsp_ggml_tallocr_t wsp_ggml_backend_sched_get_tallocr(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend) {
1165
+ int backend_index = sched_backend_prio(sched, backend);
1166
+ return sched->tallocs[backend_index];
1167
+ }
1168
+
1169
+ wsp_ggml_backend_buffer_t wsp_ggml_backend_sched_get_buffer(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend) {
1170
+ int backend_index = sched_backend_prio(sched, backend);
1171
+ return wsp_ggml_tallocr_get_buffer(sched->tallocs[backend_index]);
1172
+ }
1173
+
1174
+ void wsp_ggml_backend_sched_set_node_backend(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node, wsp_ggml_backend_t backend) {
1175
+ int backend_index = sched_backend_prio(sched, backend);
1176
+ WSP_GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
1177
+ node_allocr(node) = sched->tallocs[backend_index];
1178
+ }
1179
+
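A high-level sketch (not part of the diff) of driving the scheduler defined above; build_graph() is a hypothetical helper standing in for whatever builds the model graph, and the backend handles are assumed to come from the registry or the backend init functions:

#include "ggml.h"
#include "ggml-backend.h"

// hypothetical helper: builds the model's cgraph in its own context
struct wsp_ggml_cgraph * build_graph(void);

static void example_sched(wsp_ggml_backend_t gpu, wsp_ggml_backend_t cpu) {
    wsp_ggml_backend_t backends[2] = { gpu, cpu }; // order is priority, highest first
    wsp_ggml_backend_sched_t sched = wsp_ggml_backend_sched_new(backends, 2);

    // one measure pass over a worst-case graph sizes the per-backend allocators
    wsp_ggml_backend_sched_init_measure(sched, build_graph());

    // later graphs with the same topology are split across backends automatically,
    // with split inputs copied between backends at the boundaries
    wsp_ggml_backend_sched_graph_compute(sched, build_graph());

    wsp_ggml_backend_sched_free(sched);
}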
1180
+ // utils
1181
+ void wsp_ggml_backend_view_init(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) {
1182
+ WSP_GGML_ASSERT(tensor->buffer == NULL);
1183
+ WSP_GGML_ASSERT(tensor->data == NULL);
1184
+ WSP_GGML_ASSERT(tensor->view_src != NULL);
1185
+ WSP_GGML_ASSERT(tensor->view_src->buffer != NULL);
1186
+ WSP_GGML_ASSERT(tensor->view_src->data != NULL);
1187
+
1188
+ tensor->buffer = buffer;
1189
+ tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
1190
+ tensor->backend = tensor->view_src->backend;
1191
+ wsp_ggml_backend_buffer_init_tensor(buffer, tensor);
1192
+ }
1193
+
1194
+ void wsp_ggml_backend_tensor_alloc(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, void * addr) {
1195
+ WSP_GGML_ASSERT(tensor->buffer == NULL);
1196
+ WSP_GGML_ASSERT(tensor->data == NULL);
1197
+ WSP_GGML_ASSERT(tensor->view_src == NULL);
1198
+ WSP_GGML_ASSERT(addr >= wsp_ggml_backend_buffer_get_base(buffer));
1199
+ WSP_GGML_ASSERT((char *)addr + wsp_ggml_backend_buffer_get_alloc_size(buffer, tensor) <=
1200
+ (char *)wsp_ggml_backend_buffer_get_base(buffer) + wsp_ggml_backend_buffer_get_size(buffer));
1201
+
1202
+ tensor->buffer = buffer;
1203
+ tensor->data = addr;
1204
+ wsp_ggml_backend_buffer_init_tensor(buffer, tensor);
1205
+ }
1206
+
1207
+ static struct wsp_ggml_tensor * graph_dup_tensor(struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor ** node_copies,
1208
+ struct wsp_ggml_context * ctx_allocated, struct wsp_ggml_context * ctx_unallocated, struct wsp_ggml_tensor * src) {
1209
+
1210
+ WSP_GGML_ASSERT(src != NULL);
1211
+ WSP_GGML_ASSERT(src->data && "graph must be allocated");
1212
+
1213
+ size_t id = wsp_ggml_hash_insert(hash_set, src);
1214
+ if (id == WSP_GGML_HASHTABLE_ALREADY_EXISTS) {
1215
+ return node_copies[wsp_ggml_hash_find(hash_set, src)];
1216
+ }
1217
+
1218
+ struct wsp_ggml_tensor * dst = wsp_ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src);
1219
+ if (src->view_src != NULL) {
1220
+ dst->view_src = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src);
1221
+ dst->view_offs = src->view_offs;
1222
+ }
1223
+ dst->op = src->op;
1224
+ memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
1225
+ wsp_ggml_set_name(dst, src->name);
1226
+
1227
+ // copy src
1228
+ for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
1229
+ struct wsp_ggml_tensor * s = src->src[i];
1230
+ if (s == NULL) {
1231
+ break;
1232
+ }
1233
+ dst->src[i] = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s);
1234
+ }
1235
+
1236
+ node_copies[id] = dst;
1237
+ return dst;
1238
+ }
1239
+
1240
+ static void graph_init_tensor(struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor ** node_copies, bool * node_init, struct wsp_ggml_tensor * src) {
1241
+ size_t id = wsp_ggml_hash_find(hash_set, src);
1242
+ if (node_init[id]) {
1243
+ return;
1244
+ }
1245
+ node_init[id] = true;
1246
+
1247
+ struct wsp_ggml_tensor * dst = node_copies[id];
1248
+ if (dst->view_src != NULL) {
1249
+ wsp_ggml_backend_view_init(dst->view_src->buffer, dst);
1250
+ }
1251
+ else {
1252
+ wsp_ggml_backend_tensor_copy(src, dst);
1253
+ }
1254
+
1255
+ // init src
1256
+ for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
1257
+ struct wsp_ggml_tensor * s = src->src[i];
1258
+ if (s == NULL) {
1259
+ break;
1260
+ }
1261
+ graph_init_tensor(hash_set, node_copies, node_init, s);
1262
+ }
1263
+ }
1264
+
1265
+ struct wsp_ggml_backend_graph_copy wsp_ggml_backend_graph_copy(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * graph) {
1266
+ struct wsp_ggml_hash_set hash_set = {
1267
+ /* .size = */ graph->visited_hash_table.size,
1268
+ /* .keys = */ calloc(sizeof(hash_set.keys[0]) * graph->visited_hash_table.size, 1)
1269
+ };
1270
+ struct wsp_ggml_tensor ** node_copies = calloc(sizeof(node_copies[0]) * hash_set.size, 1);
1271
+ bool * node_init = calloc(sizeof(node_init[0]) * hash_set.size, 1);
1272
+
1273
+ struct wsp_ggml_init_params params = {
1274
+ /* .mem_size = */ wsp_ggml_tensor_overhead()*hash_set.size + wsp_ggml_graph_overhead_custom(graph->size, false),
1275
+ /* .mem_buffer = */ NULL,
1276
+ /* .no_alloc = */ true
1277
+ };
1278
+
1279
+ struct wsp_ggml_context * ctx_allocated = wsp_ggml_init(params);
1280
+ struct wsp_ggml_context * ctx_unallocated = wsp_ggml_init(params);
1281
+
1282
+ // dup nodes
1283
+ for (int i = 0; i < graph->n_nodes; i++) {
1284
+ struct wsp_ggml_tensor * node = graph->nodes[i];
1285
+ graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node);
1286
+ }
1287
+
1288
+ // allocate nodes
1289
+ wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_alloc_ctx_tensors(ctx_allocated, backend);
1290
+
1291
+ //printf("copy buffer size: %zu MB\n", wsp_ggml_backend_buffer_get_size(buffer) / 1024 / 1024);
1292
+
1293
+ // copy data and init views
1294
+ for (int i = 0; i < graph->n_nodes; i++) {
1295
+ struct wsp_ggml_tensor * node = graph->nodes[i];
1296
+ graph_init_tensor(hash_set, node_copies, node_init, node);
1297
+ }
1298
+
1299
+ // build graph copy
1300
+ struct wsp_ggml_cgraph * graph_copy = wsp_ggml_new_graph_custom(ctx_allocated, graph->size, false);
1301
+ for (int i = 0; i < graph->n_nodes; i++) {
1302
+ struct wsp_ggml_tensor * node = graph->nodes[i];
1303
+ struct wsp_ggml_tensor * node_copy = node_copies[wsp_ggml_hash_find(hash_set, node)];
1304
+ graph_copy->nodes[i] = node_copy;
1305
+ }
1306
+ graph_copy->n_nodes = graph->n_nodes;
1307
+
1308
+ free(hash_set.keys);
1309
+ free(node_copies);
1310
+ free(node_init);
1311
+
1312
+ return (struct wsp_ggml_backend_graph_copy) {
1313
+ /* .buffer = */ buffer,
1314
+ /* .ctx_allocated = */ ctx_allocated,
1315
+ /* .ctx_unallocated = */ ctx_unallocated,
1316
+ /* .graph = */ graph_copy,
1317
+ };
1318
+ }
1319
+
1320
+ void wsp_ggml_backend_graph_copy_free(struct wsp_ggml_backend_graph_copy copy) {
1321
+ wsp_ggml_backend_buffer_free(copy.buffer);
1322
+ wsp_ggml_free(copy.ctx_allocated);
1323
+ wsp_ggml_free(copy.ctx_unallocated);
1324
+ }
1325
+
1326
+ void wsp_ggml_backend_compare_graph_backend(wsp_ggml_backend_t backend1, wsp_ggml_backend_t backend2, struct wsp_ggml_cgraph * graph, wsp_ggml_backend_eval_callback callback, void * user_data) {
1327
+ struct wsp_ggml_backend_graph_copy copy = wsp_ggml_backend_graph_copy(backend2, graph);
1328
+ struct wsp_ggml_cgraph * g1 = graph;
1329
+ struct wsp_ggml_cgraph * g2 = copy.graph;
1330
+
1331
+ assert(g1->n_nodes == g2->n_nodes);
1332
+
1333
+ for (int i = 0; i < g1->n_nodes; i++) {
1334
+ //printf("eval %d/%d\n", i, g1->n_nodes);
1335
+ struct wsp_ggml_tensor * t1 = g1->nodes[i];
1336
+ struct wsp_ggml_tensor * t2 = g2->nodes[i];
1337
+
1338
+ assert(t1->op == t2->op && wsp_ggml_are_same_layout(t1, t2));
1339
+
1340
+ struct wsp_ggml_cgraph g1v = wsp_ggml_graph_view(g1, i, i + 1);
1341
+ struct wsp_ggml_cgraph g2v = wsp_ggml_graph_view(g2, i, i + 1);
1342
+
1343
+ wsp_ggml_backend_graph_compute(backend1, &g1v);
1344
+ wsp_ggml_backend_graph_compute(backend2, &g2v);
1345
+
1346
+ if (wsp_ggml_is_view_op(t1->op)) {
1347
+ continue;
1348
+ }
1349
+
1350
+ // compare results, calculate rms etc
1351
+ if (!callback(i, t1, t2, user_data)) {
1352
+ break;
1353
+ }
1354
+ }
1355
+
1356
+ wsp_ggml_backend_graph_copy_free(copy);
1357
+ }
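Finally, a sketch (not part of the diff) of the comparison helper above: the callback signature is inferred from the call site callback(i, t1, t2, user_data) and assumed to match the wsp_ggml_backend_eval_callback typedef in ggml-backend.h; check_node and the backend handles in the usage comment are illustrative names only:

#include <stdbool.h>
#include <stdio.h>
#include "ggml.h"
#include "ggml-backend.h"

static bool check_node(int i, struct wsp_ggml_tensor * t1, struct wsp_ggml_tensor * t2, void * user_data) {
    (void) user_data;
    (void) t2;
    // t1 was computed by backend1, t2 by backend2; compare them here (e.g. max abs diff)
    printf("node %d: %s (%s)\n", i, t1->name, wsp_ggml_op_name(t1->op));
    return true; // keep going; return false to stop early
}

// usage: wsp_ggml_backend_compare_graph_backend(cpu_backend, metal_backend, gf, check_node, NULL);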