whisper.rn 0.4.0-rc.4 → 0.4.0-rc.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +5 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/WhisperContext.java +51 -133
  6. package/android/src/main/jni-utils.h +76 -0
  7. package/android/src/main/jni.cpp +187 -112
  8. package/cpp/README.md +1 -1
  9. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  10. package/cpp/coreml/whisper-encoder.h +4 -0
  11. package/cpp/coreml/whisper-encoder.mm +4 -2
  12. package/cpp/ggml-alloc.c +55 -19
  13. package/cpp/ggml-alloc.h +7 -0
  14. package/cpp/ggml-backend-impl.h +46 -21
  15. package/cpp/ggml-backend.c +563 -156
  16. package/cpp/ggml-backend.h +62 -17
  17. package/cpp/ggml-impl.h +1 -1
  18. package/cpp/ggml-metal-whisper.metal +1010 -253
  19. package/cpp/ggml-metal.h +7 -1
  20. package/cpp/ggml-metal.m +618 -187
  21. package/cpp/ggml-quants.c +64 -59
  22. package/cpp/ggml-quants.h +40 -40
  23. package/cpp/ggml.c +751 -1466
  24. package/cpp/ggml.h +90 -25
  25. package/cpp/rn-audioutils.cpp +68 -0
  26. package/cpp/rn-audioutils.h +14 -0
  27. package/cpp/rn-whisper-log.h +11 -0
  28. package/cpp/rn-whisper.cpp +141 -59
  29. package/cpp/rn-whisper.h +47 -15
  30. package/cpp/whisper.cpp +1635 -928
  31. package/cpp/whisper.h +55 -10
  32. package/ios/RNWhisper.mm +7 -7
  33. package/ios/RNWhisperAudioUtils.h +0 -2
  34. package/ios/RNWhisperAudioUtils.m +0 -56
  35. package/ios/RNWhisperContext.h +3 -11
  36. package/ios/RNWhisperContext.mm +62 -134
  37. package/lib/commonjs/version.json +1 -1
  38. package/lib/module/version.json +1 -1
  39. package/package.json +6 -5
  40. package/src/version.json +1 -1
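The hunks below are from package/cpp/ggml-backend.c, the largest native change in this version: backend buffers are now created through buffer types (wsp_ggml_backend_buffer_type_t) rather than directly from a backend, and a small backend registry is introduced. The following sketch is illustrative only; it is not part of the package, the wsp_ggml_new_tensor_1d helper and WSP_GGML_TYPE_F32 constant are assumed to exist with the usual wsp_ prefix, and the header names follow the files under package/cpp/. It shows roughly how the new buffer-type path could be driven:

```c
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

// Illustrative sketch, not shipped in the package.
static void backend_buffer_demo(void) {
    // resolve a backend through the new registry; "CPU" is registered unconditionally
    wsp_ggml_backend_t backend = wsp_ggml_backend_reg_init_backend_from_str("CPU");
    if (backend == NULL) {
        backend = wsp_ggml_backend_cpu_init();
    }

    // build a no_alloc context: tensor metadata only, no data
    struct wsp_ggml_init_params params = {
        /* .mem_size   = */ wsp_ggml_tensor_overhead() * 8,
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ true,
    };
    struct wsp_ggml_context * ctx = wsp_ggml_init(params);
    // assumed wsp_-prefixed helper and type constant, mirroring upstream ggml
    struct wsp_ggml_tensor  * t   = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 16);

    // place every tensor of the context in one buffer from the backend's
    // default buffer type (helper used elsewhere in this diff)
    wsp_ggml_backend_buffer_t buf = wsp_ggml_backend_alloc_ctx_tensors(ctx, backend);

    // tensor I/O now goes through buffer->iface.set_tensor / get_tensor
    float data[16] = { 0 };
    wsp_ggml_backend_tensor_set(t, data, 0, sizeof(data));
    wsp_ggml_backend_tensor_get(t, data, 0, sizeof(data));

    wsp_ggml_backend_buffer_free(buf);
    wsp_ggml_free(ctx);
    wsp_ggml_backend_free(backend);
}
```

The extra indirection lets a buffer be described and allocated without a live backend instance, which is what the registry's default_buffer_type field relies on.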
@@ -9,14 +9,36 @@
  #include <stdlib.h>
  #include <string.h>

- #define UNUSED WSP_GGML_UNUSED

  #define MAX(a, b) ((a) > (b) ? (a) : (b))

+
+ // backend buffer type
+
+ wsp_ggml_backend_buffer_t wsp_ggml_backend_buft_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
+ return buft->iface.alloc_buffer(buft, size);
+ }
+
+ size_t wsp_ggml_backend_buft_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
+ return buft->iface.get_alignment(buft);
+ }
+
+ size_t wsp_ggml_backend_buft_get_alloc_size(wsp_ggml_backend_buffer_type_t buft, struct wsp_ggml_tensor * tensor) {
+ // get_alloc_size is optional, defaults to wsp_ggml_nbytes
+ if (buft->iface.get_alloc_size) {
+ return buft->iface.get_alloc_size(buft, tensor);
+ }
+ return wsp_ggml_nbytes(tensor);
+ }
+
+ bool wsp_ggml_backend_buft_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) {
+ return buft->iface.supports_backend(buft, backend);
+ }
+
  // backend buffer

  wsp_ggml_backend_buffer_t wsp_ggml_backend_buffer_init(
- struct wsp_ggml_backend * backend,
+ wsp_ggml_backend_buffer_type_t buft,
  struct wsp_ggml_backend_buffer_i iface,
  wsp_ggml_backend_buffer_context_t context,
  size_t size) {
@@ -26,7 +48,7 @@ wsp_ggml_backend_buffer_t wsp_ggml_backend_buffer_init(

  (*buffer) = (struct wsp_ggml_backend_buffer) {
  /* .interface = */ iface,
- /* .backend = */ backend,
+ /* .buft = */ buft,
  /* .context = */ context,
  /* .size = */ size,
  };
@@ -45,10 +67,6 @@ void wsp_ggml_backend_buffer_free(wsp_ggml_backend_buffer_t buffer) {
  free(buffer);
  }

- size_t wsp_ggml_backend_buffer_get_alignment(wsp_ggml_backend_buffer_t buffer) {
- return wsp_ggml_backend_get_alignment(buffer->backend);
- }
-
  size_t wsp_ggml_backend_buffer_get_size(wsp_ggml_backend_buffer_t buffer) {
  return buffer->size;
  }
@@ -61,14 +79,6 @@ void * wsp_ggml_backend_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
  return base;
  }

- size_t wsp_ggml_backend_buffer_get_alloc_size(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) {
- // get_alloc_size is optional, defaults to wsp_ggml_nbytes
- if (buffer->iface.get_alloc_size) {
- return buffer->iface.get_alloc_size(buffer, tensor);
- }
- return wsp_ggml_nbytes(tensor);
- }
-
  void wsp_ggml_backend_buffer_init_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) {
  // init_tensor is optional
  if (buffer->iface.init_tensor) {
@@ -76,19 +86,20 @@ void wsp_ggml_backend_buffer_init_tensor(wsp_ggml_backend_buffer_t buffer, struc
  }
  }

- void wsp_ggml_backend_buffer_free_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) {
- // free_tensor is optional
- if (buffer->iface.free_tensor) {
- buffer->iface.free_tensor(buffer, tensor);
- }
+ size_t wsp_ggml_backend_buffer_get_alignment (wsp_ggml_backend_buffer_t buffer) {
+ return wsp_ggml_backend_buft_get_alignment(wsp_ggml_backend_buffer_type(buffer));
  }

- // backend
+ size_t wsp_ggml_backend_buffer_get_alloc_size(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) {
+ return wsp_ggml_backend_buft_get_alloc_size(wsp_ggml_backend_buffer_type(buffer), tensor);
+ }

- wsp_ggml_backend_t wsp_ggml_get_backend(const struct wsp_ggml_tensor * tensor) {
- return tensor->buffer ? tensor->buffer->backend : NULL;
+ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_buffer_type(wsp_ggml_backend_buffer_t buffer) {
+ return buffer->buft;
  }

+ // backend
+
  const char * wsp_ggml_backend_name(wsp_ggml_backend_t backend) {
  if (backend == NULL) {
  return "NULL";
@@ -104,43 +115,53 @@ void wsp_ggml_backend_free(wsp_ggml_backend_t backend) {
  backend->iface.free(backend);
  }

+ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_get_default_buffer_type(wsp_ggml_backend_t backend) {
+ return backend->iface.get_default_buffer_type(backend);
+ }
+
  wsp_ggml_backend_buffer_t wsp_ggml_backend_alloc_buffer(wsp_ggml_backend_t backend, size_t size) {
- return backend->iface.alloc_buffer(backend, size);
+ return wsp_ggml_backend_buft_alloc_buffer(wsp_ggml_backend_get_default_buffer_type(backend), size);
  }

  size_t wsp_ggml_backend_get_alignment(wsp_ggml_backend_t backend) {
- return backend->iface.get_alignment(backend);
+ return wsp_ggml_backend_buft_get_alignment(wsp_ggml_backend_get_default_buffer_type(backend));
  }

- void wsp_ggml_backend_tensor_set_async(struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- wsp_ggml_get_backend(tensor)->iface.set_tensor_async(wsp_ggml_get_backend(tensor), tensor, data, offset, size);
+ void wsp_ggml_backend_tensor_set_async(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
+
+ backend->iface.set_tensor_async(backend, tensor, data, offset, size);
  }

- void wsp_ggml_backend_tensor_get_async(const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- wsp_ggml_get_backend(tensor)->iface.get_tensor_async(wsp_ggml_get_backend(tensor), tensor, data, offset, size);
+ void wsp_ggml_backend_tensor_get_async(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
+
+ backend->iface.get_tensor_async(backend, tensor, data, offset, size);
  }

  void wsp_ggml_backend_tensor_set(struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- wsp_ggml_backend_t backend = wsp_ggml_get_backend(tensor);
-
  WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
- WSP_GGML_ASSERT(backend != NULL && "tensor backend not set");
+ WSP_GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set");
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");

- backend->iface.set_tensor_async(backend, tensor, data, offset, size);
- backend->iface.synchronize(backend);
+ tensor->buffer->iface.set_tensor(tensor->buffer, tensor, data, offset, size);
  }

  void wsp_ggml_backend_tensor_get(const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- wsp_ggml_backend_t backend = wsp_ggml_get_backend(tensor);
-
  WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
- WSP_GGML_ASSERT(backend != NULL && "tensor backend not set");
+ WSP_GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set");
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");

- backend->iface.get_tensor_async(backend, tensor, data, offset, size);
- backend->iface.synchronize(backend);
+ tensor->buffer->iface.get_tensor(tensor->buffer, tensor, data, offset, size);
  }

  void wsp_ggml_backend_synchronize(wsp_ggml_backend_t backend) {
+ if (backend->iface.synchronize == NULL) {
+ return;
+ }
+
  backend->iface.synchronize(backend);
  }

@@ -154,10 +175,16 @@ void wsp_ggml_backend_graph_plan_free(wsp_ggml_backend_t backend, wsp_ggml_backe

  void wsp_ggml_backend_graph_plan_compute(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) {
  backend->iface.graph_plan_compute(backend, plan);
+
+ // TODO: optional sync
+ wsp_ggml_backend_synchronize(backend);
  }

  void wsp_ggml_backend_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
  backend->iface.graph_compute(backend, cgraph);
+
+ // TODO: optional sync
+ wsp_ggml_backend_synchronize(backend);
  }

  bool wsp_ggml_backend_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
@@ -194,14 +221,15 @@ void wsp_ggml_backend_tensor_copy(struct wsp_ggml_tensor * src, struct wsp_ggml_

  // TODO: allow backends to support copy to/from same backend

- if (wsp_ggml_get_backend(dst)->iface.cpy_tensor_from != NULL) {
- wsp_ggml_get_backend(dst)->iface.cpy_tensor_from(wsp_ggml_get_backend(dst)->context, src, dst);
- } else if (wsp_ggml_get_backend(src)->iface.cpy_tensor_to != NULL) {
- wsp_ggml_get_backend(src)->iface.cpy_tensor_to(wsp_ggml_get_backend(src)->context, src, dst);
+ if (dst->buffer->iface.cpy_tensor_from != NULL) {
+ dst->buffer->iface.cpy_tensor_from(dst->buffer, src, dst);
+ } else if (src->buffer->iface.cpy_tensor_to != NULL) {
+ src->buffer->iface.cpy_tensor_to(src->buffer, src, dst);
  } else {
  // shouldn't be hit when copying from/to CPU
  #ifndef NDEBUG
- fprintf(stderr, "wsp_ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to are implemented for backends %s and %s, falling back to get/set\n", wsp_ggml_backend_name(src->buffer->backend), wsp_ggml_backend_name(dst->buffer->backend));
+ fprintf(stderr, "wsp_ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to "
+ "are implemented for %s and %s, falling back to get/set\n", src->name, dst->name);
  #endif
  size_t nbytes = wsp_ggml_nbytes(src);
  void * data = malloc(nbytes);
@@ -211,101 +239,259 @@ void wsp_ggml_backend_tensor_copy(struct wsp_ggml_tensor * src, struct wsp_ggml_
  }
  }

- // backend CPU
+ // backend registry

- struct wsp_ggml_backend_cpu_context {
- int n_threads;
- void * work_data;
- size_t work_size;
+ #define WSP_GGML_MAX_BACKENDS_REG 16
+
+ struct wsp_ggml_backend_reg {
+ char name[128];
+ wsp_ggml_backend_init_fn init_fn;
+ wsp_ggml_backend_buffer_type_t default_buffer_type;
+ void * user_data;
  };

- static const char * wsp_ggml_backend_cpu_name(wsp_ggml_backend_t backend) {
- return "CPU";
+ static struct wsp_ggml_backend_reg wsp_ggml_backend_registry[WSP_GGML_MAX_BACKENDS_REG];
+ static size_t wsp_ggml_backend_registry_count = 0;
+
+ static wsp_ggml_backend_t wsp_ggml_backend_reg_cpu_init(const char * params, void * user_data);
+
+ static void wsp_ggml_backend_registry_init(void) {
+ static bool initialized = false;
+
+ if (initialized) {
+ return;
+ }
+
+ initialized = true;

- UNUSED(backend);
+ wsp_ggml_backend_register("CPU", wsp_ggml_backend_reg_cpu_init, wsp_ggml_backend_cpu_buffer_type(), NULL);
+
+ // add forward decls here to avoid including the backend headers
+ #ifdef WSP_GGML_USE_CUBLAS
+ extern void wsp_ggml_backend_cuda_reg_devices(void);
+ wsp_ggml_backend_cuda_reg_devices();
+ #endif
+
+ #ifdef WSP_GGML_USE_METAL
+ extern wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data);
+ extern wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void);
+ wsp_ggml_backend_register("Metal", wsp_ggml_backend_reg_metal_init, wsp_ggml_backend_metal_buffer_type(), NULL);
+ #endif
  }

- static void wsp_ggml_backend_cpu_free(wsp_ggml_backend_t backend) {
- struct wsp_ggml_backend_cpu_context * cpu_ctx = (struct wsp_ggml_backend_cpu_context *)backend->context;
- free(cpu_ctx->work_data);
- free(cpu_ctx);
- free(backend);
+ void wsp_ggml_backend_register(const char * name, wsp_ggml_backend_init_fn init_fn, wsp_ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
+ WSP_GGML_ASSERT(wsp_ggml_backend_registry_count < WSP_GGML_MAX_BACKENDS_REG);
+
+ int id = wsp_ggml_backend_registry_count;
+
+ wsp_ggml_backend_registry[id] = (struct wsp_ggml_backend_reg) {
+ /* .name = */ {0},
+ /* .fn = */ init_fn,
+ /* .default_buffer_type = */ default_buffer_type,
+ /* .user_data = */ user_data,
+ };
+
+ snprintf(wsp_ggml_backend_registry[id].name, sizeof(wsp_ggml_backend_registry[id].name), "%s", name);
+
+ #ifndef NDEBUG
+ fprintf(stderr, "%s: registered backend %s\n", __func__, name);
+ #endif
+
+ wsp_ggml_backend_registry_count++;
+ }
+
+ size_t wsp_ggml_backend_reg_get_count(void) {
+ wsp_ggml_backend_registry_init();
+
+ return wsp_ggml_backend_registry_count;
+ }
+
+ size_t wsp_ggml_backend_reg_find_by_name(const char * name) {
+ wsp_ggml_backend_registry_init();
+
+ for (size_t i = 0; i < wsp_ggml_backend_registry_count; i++) {
+ // TODO: case insensitive in a portable way
+ if (strcmp(wsp_ggml_backend_registry[i].name, name) == 0) {
+ return i;
+ }
+ }
+ return SIZE_MAX;
+ }
+
+ // init from backend:params string
+ wsp_ggml_backend_t wsp_ggml_backend_reg_init_backend_from_str(const char * backend_str) {
+ wsp_ggml_backend_registry_init();
+
+ const char * params = strchr(backend_str, ':');
+ char backend_name[128];
+ if (params == NULL) {
+ strcpy(backend_name, backend_str);
+ params = "";
+ } else {
+ strncpy(backend_name, backend_str, params - backend_str);
+ backend_name[params - backend_str] = '\0';
+ params++;
+ }
+
+ size_t backend_i = wsp_ggml_backend_reg_find_by_name(backend_name);
+ if (backend_i == SIZE_MAX) {
+ fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name);
+ return NULL;
+ }
+
+ return wsp_ggml_backend_reg_init_backend(backend_i, params);
+ }
+
+ const char * wsp_ggml_backend_reg_get_name(size_t i) {
+ wsp_ggml_backend_registry_init();
+
+ WSP_GGML_ASSERT(i < wsp_ggml_backend_registry_count);
+ return wsp_ggml_backend_registry[i].name;
+ }
+
+ wsp_ggml_backend_t wsp_ggml_backend_reg_init_backend(size_t i, const char * params) {
+ wsp_ggml_backend_registry_init();
+
+ WSP_GGML_ASSERT(i < wsp_ggml_backend_registry_count);
+ return wsp_ggml_backend_registry[i].init_fn(params, wsp_ggml_backend_registry[i].user_data);
+ }
+
+ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_reg_get_default_buffer_type(size_t i) {
+ wsp_ggml_backend_registry_init();
+
+ WSP_GGML_ASSERT(i < wsp_ggml_backend_registry_count);
+ return wsp_ggml_backend_registry[i].default_buffer_type;
+ }
+
+ wsp_ggml_backend_buffer_t wsp_ggml_backend_reg_alloc_buffer(size_t i, size_t size) {
+ wsp_ggml_backend_registry_init();
+
+ WSP_GGML_ASSERT(i < wsp_ggml_backend_registry_count);
+ return wsp_ggml_backend_buft_alloc_buffer(wsp_ggml_backend_registry[i].default_buffer_type, size);
  }

+ // backend CPU
+
  static void * wsp_ggml_backend_cpu_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
  return (void *)buffer->context;
  }

  static void wsp_ggml_backend_cpu_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
  free(buffer->context);
- UNUSED(buffer);
+ WSP_GGML_UNUSED(buffer);
+ }
+
+ static void wsp_ggml_backend_cpu_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
+ WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy((char *)tensor->data + offset, data, size);
+
+ WSP_GGML_UNUSED(buffer);
+ }
+
+ static void wsp_ggml_backend_cpu_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
+ WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy(data, (const char *)tensor->data + offset, size);
+
+ WSP_GGML_UNUSED(buffer);
+ }
+
+ static void wsp_ggml_backend_cpu_buffer_cpy_tensor_from(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
+ wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src));
+
+ WSP_GGML_UNUSED(buffer);
+ }
+
+ static void wsp_ggml_backend_cpu_buffer_cpy_tensor_to(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
+ wsp_ggml_backend_tensor_set(dst, src->data, 0, wsp_ggml_nbytes(src));
+
+ WSP_GGML_UNUSED(buffer);
  }

  static struct wsp_ggml_backend_buffer_i cpu_backend_buffer_i = {
- /* .free_buffer = */ wsp_ggml_backend_cpu_buffer_free_buffer,
- /* .get_base = */ wsp_ggml_backend_cpu_buffer_get_base,
- /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes
- /* .init_tensor = */ NULL, // no initialization required
- /* .free_tensor = */ NULL, // no cleanup required
+ /* .free_buffer = */ wsp_ggml_backend_cpu_buffer_free_buffer,
+ /* .get_base = */ wsp_ggml_backend_cpu_buffer_get_base,
+ /* .init_tensor = */ NULL, // no initialization required
+ /* .set_tensor = */ wsp_ggml_backend_cpu_buffer_set_tensor,
+ /* .get_tensor = */ wsp_ggml_backend_cpu_buffer_get_tensor,
+ /* .cpy_tensor_from = */ wsp_ggml_backend_cpu_buffer_cpy_tensor_from,
+ /* .cpy_tensor_to = */ wsp_ggml_backend_cpu_buffer_cpy_tensor_to,
  };

  // for buffers from ptr, free is not called
  static struct wsp_ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
- /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
- /* .get_base = */ wsp_ggml_backend_cpu_buffer_get_base,
- /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes
- /* .init_tensor = */ NULL,
- /* .free_tensor = */ NULL,
+ /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
+ /* .get_base = */ wsp_ggml_backend_cpu_buffer_get_base,
+ /* .init_tensor = */ NULL, // no initialization required
+ /* .set_tensor = */ wsp_ggml_backend_cpu_buffer_set_tensor,
+ /* .get_tensor = */ wsp_ggml_backend_cpu_buffer_get_tensor,
+ /* .cpy_tensor_from = */ wsp_ggml_backend_cpu_buffer_cpy_tensor_from,
+ /* .cpy_tensor_to = */ wsp_ggml_backend_cpu_buffer_cpy_tensor_to,
  };

  static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512

- static wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_alloc_buffer(wsp_ggml_backend_t backend, size_t size) {
+ static wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
  size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned
  void * data = malloc(size); // TODO: maybe use WSP_GGML_ALIGNED_MALLOC?

  WSP_GGML_ASSERT(data != NULL && "failed to allocate buffer");

- return wsp_ggml_backend_buffer_init(backend, cpu_backend_buffer_i, data, size);
+ return wsp_ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size);
  }

- static size_t wsp_ggml_backend_cpu_get_alignment(wsp_ggml_backend_t backend) {
+ static size_t wsp_ggml_backend_cpu_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
  return TENSOR_ALIGNMENT;
- UNUSED(backend);
- }

- static void wsp_ggml_backend_cpu_set_tensor_async(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
- WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+ WSP_GGML_UNUSED(buft);
+ }

- memcpy((char *)tensor->data + offset, data, size);
+ static bool wsp_ggml_backend_cpu_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) {
+ return wsp_ggml_backend_is_cpu(backend);

- UNUSED(backend);
+ WSP_GGML_UNUSED(buft);
  }

- static void wsp_ggml_backend_cpu_get_tensor_async(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
- WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
- memcpy(data, (const char *)tensor->data + offset, size);
+ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_buffer_type(void) {
+ static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_buffer_type_cpu = {
+ /* .iface = */ {
+ /* .alloc_buffer = */ wsp_ggml_backend_cpu_buffer_type_alloc_buffer,
+ /* .get_alignment = */ wsp_ggml_backend_cpu_buffer_type_get_alignment,
+ /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes
+ /* .supports_backend = */ wsp_ggml_backend_cpu_buffer_type_supports_backend,
+ },
+ /* .context = */ NULL,
+ };

- UNUSED(backend);
+ return &wsp_ggml_backend_buffer_type_cpu;
  }

- static void wsp_ggml_backend_cpu_synchronize(wsp_ggml_backend_t backend) {
- UNUSED(backend);
- }
+ struct wsp_ggml_backend_cpu_context {
+ int n_threads;
+ void * work_data;
+ size_t work_size;
+ };

- static void wsp_ggml_backend_cpu_cpy_tensor_from(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
- wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src));
+ static const char * wsp_ggml_backend_cpu_name(wsp_ggml_backend_t backend) {
+ return "CPU";

- UNUSED(backend);
+ WSP_GGML_UNUSED(backend);
  }

- static void wsp_ggml_backend_cpu_cpy_tensor_to(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
- wsp_ggml_backend_tensor_set(dst, src->data, 0, wsp_ggml_nbytes(src));
+ static void wsp_ggml_backend_cpu_free(wsp_ggml_backend_t backend) {
+ struct wsp_ggml_backend_cpu_context * cpu_ctx = (struct wsp_ggml_backend_cpu_context *)backend->context;
+ free(cpu_ctx->work_data);
+ free(cpu_ctx);
+ free(backend);
+ }
+
+ static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_get_default_buffer_type(wsp_ggml_backend_t backend) {
+ return wsp_ggml_backend_cpu_buffer_type();

- UNUSED(backend);
+ WSP_GGML_UNUSED(backend);
  }

  struct wsp_ggml_backend_plan_cpu {
@@ -334,7 +520,7 @@ static void wsp_ggml_backend_cpu_graph_plan_free(wsp_ggml_backend_t backend, wsp
  free(cpu_plan->cplan.work_data);
  free(cpu_plan);

- UNUSED(backend);
+ WSP_GGML_UNUSED(backend);
  }

  static void wsp_ggml_backend_cpu_graph_plan_compute(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) {
@@ -342,7 +528,7 @@ static void wsp_ggml_backend_cpu_graph_plan_compute(wsp_ggml_backend_t backend,

  wsp_ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);

- UNUSED(backend);
+ WSP_GGML_UNUSED(backend);
  }

  static void wsp_ggml_backend_cpu_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
@@ -363,25 +549,25 @@ static void wsp_ggml_backend_cpu_graph_compute(wsp_ggml_backend_t backend, struc

  static bool wsp_ggml_backend_cpu_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
  return true;
- UNUSED(backend);
- UNUSED(op);
+
+ WSP_GGML_UNUSED(backend);
+ WSP_GGML_UNUSED(op);
  }

  static struct wsp_ggml_backend_i cpu_backend_i = {
- /* .get_name = */ wsp_ggml_backend_cpu_name,
- /* .free = */ wsp_ggml_backend_cpu_free,
- /* .alloc_buffer = */ wsp_ggml_backend_cpu_alloc_buffer,
- /* .get_alignment = */ wsp_ggml_backend_cpu_get_alignment,
- /* .set_tensor_async = */ wsp_ggml_backend_cpu_set_tensor_async,
- /* .get_tensor_async = */ wsp_ggml_backend_cpu_get_tensor_async,
- /* .synchronize = */ wsp_ggml_backend_cpu_synchronize,
- /* .cpy_tensor_from = */ wsp_ggml_backend_cpu_cpy_tensor_from,
- /* .cpy_tensor_to = */ wsp_ggml_backend_cpu_cpy_tensor_to,
- /* .graph_plan_create = */ wsp_ggml_backend_cpu_graph_plan_create,
- /* .graph_plan_free = */ wsp_ggml_backend_cpu_graph_plan_free,
- /* .graph_plan_compute = */ wsp_ggml_backend_cpu_graph_plan_compute,
- /* .graph_compute = */ wsp_ggml_backend_cpu_graph_compute,
- /* .supports_op = */ wsp_ggml_backend_cpu_supports_op,
+ /* .get_name = */ wsp_ggml_backend_cpu_name,
+ /* .free = */ wsp_ggml_backend_cpu_free,
+ /* .get_default_buffer_type = */ wsp_ggml_backend_cpu_get_default_buffer_type,
+ /* .set_tensor_async = */ NULL,
+ /* .get_tensor_async = */ NULL,
+ /* .cpy_tensor_from_async = */ NULL,
+ /* .cpy_tensor_to_async = */ NULL,
+ /* .synchronize = */ NULL,
+ /* .graph_plan_create = */ wsp_ggml_backend_cpu_graph_plan_create,
+ /* .graph_plan_free = */ wsp_ggml_backend_cpu_graph_plan_free,
+ /* .graph_plan_compute = */ wsp_ggml_backend_cpu_graph_plan_compute,
+ /* .graph_compute = */ wsp_ggml_backend_cpu_graph_compute,
+ /* .supports_op = */ wsp_ggml_backend_cpu_supports_op,
  };

  wsp_ggml_backend_t wsp_ggml_backend_cpu_init(void) {
@@ -411,10 +597,18 @@ void wsp_ggml_backend_cpu_set_n_threads(wsp_ggml_backend_t backend_cpu, int n_th
  ctx->n_threads = n_threads;
  }

- wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_from_ptr(wsp_ggml_backend_t backend_cpu, void * ptr, size_t size) {
- return wsp_ggml_backend_buffer_init(backend_cpu, cpu_backend_buffer_i_from_ptr, ptr, size);
+ wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
+ return wsp_ggml_backend_buffer_init(wsp_ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
+ }
+
+ static wsp_ggml_backend_t wsp_ggml_backend_reg_cpu_init(const char * params, void * user_data) {
+ return wsp_ggml_backend_cpu_init();
+
+ WSP_GGML_UNUSED(params);
+ WSP_GGML_UNUSED(user_data);
  }

+
  // scheduler

  #define WSP_GGML_MAX_BACKENDS 4
@@ -427,7 +621,7 @@ struct wsp_ggml_backend_sched_split {
  int i_end;
  struct wsp_ggml_tensor * inputs[WSP_GGML_MAX_SPLIT_INPUTS];
  int n_inputs;
- struct wsp_ggml_cgraph * graph;
+ struct wsp_ggml_cgraph graph;
  };

  struct wsp_ggml_backend_sched {
@@ -453,7 +647,7 @@ struct wsp_ggml_backend_sched {
  #else
  __attribute__((aligned(WSP_GGML_MEM_ALIGN)))
  #endif
- char context_buffer[WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS*sizeof(struct wsp_ggml_tensor) + WSP_GGML_MAX_SPLITS*sizeof(struct wsp_ggml_cgraph)];
+ char context_buffer[WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS*sizeof(struct wsp_ggml_tensor) + sizeof(struct wsp_ggml_cgraph)];
  };

  #define hash_id(node) wsp_ggml_hash_find_or_insert(sched->hash_set, node)
@@ -482,23 +676,57 @@ static int sched_allocr_prio(wsp_ggml_backend_sched_t sched, wsp_ggml_tallocr_t
  return INT_MAX;
  }

+ static wsp_ggml_backend_t get_buffer_backend(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_buffer_t buffer) {
+ if (buffer == NULL) {
+ return NULL;
+ }
+ // find highest prio backend that supports the buffer type
+ for (int i = 0; i < sched->n_backends; i++) {
+ if (wsp_ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) {
+ return sched->backends[i];
+ }
+ }
+ WSP_GGML_ASSERT(false && "tensor buffer type not supported by any backend");
+ }
+
+ static wsp_ggml_backend_t get_allocr_backend(wsp_ggml_backend_sched_t sched, wsp_ggml_tallocr_t allocr) {
+ if (allocr == NULL) {
+ return NULL;
+ }
+ // find highest prio backend that supports the buffer type
+ for (int i = 0; i < sched->n_backends; i++) {
+ if (sched->tallocs[i] == allocr) {
+ return sched->backends[i];
+ }
+ }
+ WSP_GGML_UNREACHABLE();
+ }
+
+ #if 0
+ static char causes[WSP_GGML_DEFAULT_GRAPH_SIZE*8 + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
+ #define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
+ #define GET_CAUSE(node) causes[hash_id(node)]
+ #else
+ #define SET_CAUSE(node, ...)
+ #define GET_CAUSE(node) ""
+ #endif
+
  // returns the backend that should be used for the node based on the current locations
- char causes[WSP_GGML_DEFAULT_GRAPH_SIZE*4 + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
  static wsp_ggml_backend_t sched_backend_from_cur(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node) {
  // if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there
  // ie. kv cache updates
  // note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend.
  // dst
- wsp_ggml_backend_t cur_backend = wsp_ggml_get_backend(node);
+ wsp_ggml_backend_t cur_backend = get_buffer_backend(sched, node->buffer);
  if (cur_backend != NULL) {
- sprintf(causes[hash_id(node)], "1.dst");
+ SET_CAUSE(node, "1.dst");
  return cur_backend;
  }

  // view_src
- if (node->view_src != NULL && wsp_ggml_get_backend(node->view_src) != NULL) {
- sprintf(causes[hash_id(node)], "1.vsrc");
- return wsp_ggml_get_backend(node->view_src);
+ if (node->view_src != NULL && get_buffer_backend(sched, node->view_src->buffer) != NULL) {
+ SET_CAUSE(node, "1.vsrc");
+ return get_buffer_backend(sched, node->view_src->buffer);
  }

  // src
@@ -510,7 +738,7 @@ static wsp_ggml_backend_t sched_backend_from_cur(wsp_ggml_backend_sched_t sched,
  if (src == NULL) {
  break;
  }
- wsp_ggml_backend_t src_backend = wsp_ggml_get_backend(src);
+ wsp_ggml_backend_t src_backend = get_buffer_backend(sched, src->buffer);
  if (src_backend != NULL) {
  int src_prio = sched_backend_prio(sched, src_backend);
  size_t src_size = wsp_ggml_nbytes(src);
@@ -518,7 +746,7 @@ static wsp_ggml_backend_t sched_backend_from_cur(wsp_ggml_backend_sched_t sched,
  cur_prio = src_prio;
  cur_size = src_size;
  cur_backend = src_backend;
- sprintf(causes[hash_id(node)], "1.src%d", i);
+ SET_CAUSE(node, "1.src%d", i);
  }
  }
  }
@@ -539,10 +767,12 @@ static void sched_print_assignments(wsp_ggml_backend_sched_t sched, struct wsp_g
  int cur_split = 0;
  for (int i = 0; i < graph->n_nodes; i++) {
  if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
- wsp_ggml_backend_t split_backend = wsp_ggml_tallocr_get_buffer(sched->splits[cur_split].tallocr)->backend;
- fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, wsp_ggml_backend_name(split_backend), sched->splits[cur_split].n_inputs);
+ wsp_ggml_backend_t split_backend = get_allocr_backend(sched, sched->splits[cur_split].tallocr);
+ fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, wsp_ggml_backend_name(split_backend),
+ sched->splits[cur_split].n_inputs);
  for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
- fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, fmt_size(wsp_ggml_nbytes(sched->splits[cur_split].inputs[j])));
+ fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name,
+ fmt_size(wsp_ggml_nbytes(sched->splits[cur_split].inputs[j])));
  }
  fprintf(stderr, "\n");
  cur_split++;
@@ -552,16 +782,18 @@ static void sched_print_assignments(wsp_ggml_backend_sched_t sched, struct wsp_g
  continue;
  }
  wsp_ggml_tallocr_t node_allocr = node_allocr(node);
- wsp_ggml_backend_t node_backend = node_allocr ? wsp_ggml_tallocr_get_buffer(node_allocr)->backend : NULL;
- fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, wsp_ggml_op_name(node->op), node->name, fmt_size(wsp_ggml_nbytes(node)), node_allocr ? wsp_ggml_backend_name(node_backend) : "NULL", causes[hash_id(node)]);
+ wsp_ggml_backend_t node_backend = node_allocr ? get_allocr_backend(sched, node_allocr) : NULL; // FIXME:
+ fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, wsp_ggml_op_name(node->op), node->name,
+ fmt_size(wsp_ggml_nbytes(node)), node_allocr ? wsp_ggml_backend_name(node_backend) : "NULL", GET_CAUSE(node));
  for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
  struct wsp_ggml_tensor * src = node->src[j];
  if (src == NULL) {
  break;
  }
  wsp_ggml_tallocr_t src_allocr = node_allocr(src);
- wsp_ggml_backend_t src_backend = src_allocr ? wsp_ggml_tallocr_get_buffer(src_allocr)->backend : NULL;
- fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name, fmt_size(wsp_ggml_nbytes(src)), src_backend ? wsp_ggml_backend_name(src_backend) : "NULL", causes[hash_id(src)]);
+ wsp_ggml_backend_t src_backend = src_allocr ? get_allocr_backend(sched, src_allocr) : NULL;
+ fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name,
+ fmt_size(wsp_ggml_nbytes(src)), src_backend ? wsp_ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src));
  }
  fprintf(stderr, "\n");
  }
@@ -587,9 +819,9 @@ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cg
  sched->n_splits = 0;

  struct wsp_ggml_init_params params = {
- /*.mem_size = */ sizeof(sched->context_buffer),
- /*.mem_buffer = */ sched->context_buffer,
- /*.no_alloc = */ true
+ /* .mem_size = */ sizeof(sched->context_buffer),
+ /* .mem_buffer = */ sched->context_buffer,
+ /* .no_alloc = */ true
  };

  if (sched->ctx != NULL) {
@@ -605,9 +837,9 @@ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cg
  // do not overwrite user assignments
  continue;
  }
- wsp_ggml_backend_t leaf_backend = wsp_ggml_get_backend(leaf);
+ wsp_ggml_backend_t leaf_backend = get_buffer_backend(sched, leaf->buffer);
  if (leaf_backend == NULL && leaf->view_src != NULL) {
- leaf_backend = wsp_ggml_get_backend(leaf->view_src);
+ leaf_backend = get_buffer_backend(sched, leaf->view_src->buffer);
  }
  if (leaf_backend != NULL) {
  node_allocr(leaf) = wsp_ggml_backend_sched_get_tallocr(sched, leaf_backend);
@@ -649,7 +881,7 @@ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cg
  cur_prio = src_prio;
  cur_size = src_size;
  node_allocr = src_allocr;
- sprintf(causes[hash_id(node)], "2.src%d", j);
+ SET_CAUSE(node, "2.src%d", j);
  }
  }
  }
@@ -733,7 +965,7 @@ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cg
  struct wsp_ggml_tensor * tensor_copy = wsp_ggml_dup_tensor_layout(sched->ctx, src);
  sched->node_copies[id][cur_backend_id] = tensor_copy;
  node_allocr(tensor_copy) = cur_allocr;
- wsp_ggml_backend_t backend = wsp_ggml_tallocr_get_buffer(cur_allocr)->backend;
+ wsp_ggml_backend_t backend = get_allocr_backend(sched, cur_allocr);
  wsp_ggml_format_name(tensor_copy, "%s#%s", wsp_ggml_backend_name(backend), src->name);
  }
  node->src[j] = sched->node_copies[id][cur_backend_id];
@@ -761,8 +993,8 @@ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cg
  wsp_ggml_tallocr_t src_allocr = node_allocr(src);
  if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now
  fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n",
- node->name, node_allocr ? wsp_ggml_backend_name(wsp_ggml_tallocr_get_buffer(node_allocr)->backend) : "NULL",
- j, src->name, src_allocr ? wsp_ggml_backend_name(wsp_ggml_tallocr_get_buffer(src_allocr)->backend) : "NULL");
+ node->name, node_allocr ? wsp_ggml_backend_name(get_allocr_backend(sched, node_allocr)) : "NULL",
+ j, src->name, src_allocr ? wsp_ggml_backend_name(get_allocr_backend(sched, src_allocr)) : "NULL");
  }
  }
  }
@@ -773,7 +1005,7 @@ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cg
  struct wsp_ggml_cgraph * graph_copy = wsp_ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*WSP_GGML_MAX_SPLIT_INPUTS, false);
  for (int i = 0; i < sched->n_splits; i++) {
  struct wsp_ggml_backend_sched_split * split = &sched->splits[i];
- split->graph = wsp_ggml_graph_view(sched->ctx, graph, split->i_start, split->i_end);
+ split->graph = wsp_ggml_graph_view(graph, split->i_start, split->i_end);

  // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
  for (int j = 0; j < split->n_inputs; j++) {
@@ -806,31 +1038,29 @@ static void sched_compute_splits(wsp_ggml_backend_sched_t sched) {

  for (int i = 0; i < sched->n_splits; i++) {
  struct wsp_ggml_backend_sched_split * split = &splits[i];
- wsp_ggml_backend_t split_backend = wsp_ggml_tallocr_get_buffer(split->tallocr)->backend;
+ wsp_ggml_backend_t split_backend = get_allocr_backend(sched, split->tallocr);
  int split_backend_id = sched_backend_prio(sched, split_backend);

  // copy the input tensors to the split backend
  uint64_t copy_start_us = wsp_ggml_time_us();
  for (int j = 0; j < split->n_inputs; j++) {
- struct wsp_ggml_tensor * input_cpy = sched->node_copies[hash_id(split->inputs[j])][sched_backend_prio(sched, split_backend)];
- if (split->inputs[j]->buffer == NULL) {
- if (split->inputs[j]->view_src == NULL) {
- fprintf(stderr, "input %s has no buffer and no view_src\n", split->inputs[j]->name);
+ struct wsp_ggml_tensor * input = split->inputs[j];
+ struct wsp_ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_backend_prio(sched, split_backend)];
+ if (input->buffer == NULL) {
+ if (input->view_src == NULL) {
+ fprintf(stderr, "input %s has no buffer and no view_src\n", input->name);
  exit(1);
  }
- struct wsp_ggml_tensor * view = split->inputs[j];
- view->backend = view->view_src->backend;
- view->buffer = view->view_src->buffer;
- view->data = (char *)view->view_src->data + view->view_offs;
- wsp_ggml_backend_buffer_init_tensor(wsp_ggml_backend_sched_get_buffer(sched, view->buffer->backend), view);
+ // FIXME: may need to use the sched buffer instead
+ wsp_ggml_backend_view_init(input->view_src->buffer, input);
  }
  if (input_cpy->buffer == NULL) {
  fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name);
  exit(1);
  }
- WSP_GGML_ASSERT(split->inputs[j]->buffer->backend != input_cpy->buffer->backend);
- WSP_GGML_ASSERT(input_cpy->buffer->backend == split_backend);
- wsp_ggml_backend_tensor_copy(split->inputs[j], input_cpy);
+ //WSP_GGML_ASSERT(input->buffer->backend != input_cpy->buffer->backend);
+ //WSP_GGML_ASSERT(input_cpy->buffer->backend == split_backend);
+ wsp_ggml_backend_tensor_copy(input, input_cpy);
  }
  // wsp_ggml_backend_synchronize(split_backend);
  int64_t copy_end_us = wsp_ggml_time_us();
@@ -843,7 +1073,7 @@ static void sched_compute_splits(wsp_ggml_backend_sched_t sched) {
  #endif

  uint64_t compute_start_us = wsp_ggml_time_us();
- wsp_ggml_backend_graph_compute(split_backend, split->graph);
+ wsp_ggml_backend_graph_compute(split_backend, &split->graph);
  // wsp_ggml_backend_synchronize(split_backend);
  uint64_t compute_end_us = wsp_ggml_time_us();
  compute_us[split_backend_id] += compute_end_us - compute_start_us;
@@ -872,8 +1102,6 @@ wsp_ggml_backend_sched_t wsp_ggml_backend_sched_new(wsp_ggml_backend_t * backend
  struct wsp_ggml_backend_sched * sched = malloc(sizeof(struct wsp_ggml_backend_sched));
  memset(sched, 0, sizeof(struct wsp_ggml_backend_sched));

- fprintf(stderr, "wsp_ggml_backend_sched size: %lu KB\n", sizeof(struct wsp_ggml_backend_sched)/1024);
-
  sched->n_backends = n_backends;
  for (int i = 0; i < n_backends; i++) {
  sched->backends[i] = backends[i];
@@ -948,3 +1176,182 @@ void wsp_ggml_backend_sched_set_node_backend(wsp_ggml_backend_sched_t sched, str
  WSP_GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
  node_allocr(node) = sched->tallocs[backend_index];
  }
+
+ // utils
+ void wsp_ggml_backend_view_init(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) {
+ WSP_GGML_ASSERT(tensor->buffer == NULL);
+ WSP_GGML_ASSERT(tensor->data == NULL);
+ WSP_GGML_ASSERT(tensor->view_src != NULL);
+ WSP_GGML_ASSERT(tensor->view_src->buffer != NULL);
+ WSP_GGML_ASSERT(tensor->view_src->data != NULL);
+
+ tensor->buffer = buffer;
+ tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
+ tensor->backend = tensor->view_src->backend;
+ wsp_ggml_backend_buffer_init_tensor(buffer, tensor);
+ }
+
+ void wsp_ggml_backend_tensor_alloc(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, void * addr) {
+ WSP_GGML_ASSERT(tensor->buffer == NULL);
+ WSP_GGML_ASSERT(tensor->data == NULL);
+ WSP_GGML_ASSERT(tensor->view_src == NULL);
+ WSP_GGML_ASSERT(addr >= wsp_ggml_backend_buffer_get_base(buffer));
+ WSP_GGML_ASSERT((char *)addr + wsp_ggml_backend_buffer_get_alloc_size(buffer, tensor) <=
+ (char *)wsp_ggml_backend_buffer_get_base(buffer) + wsp_ggml_backend_buffer_get_size(buffer));
+
+ tensor->buffer = buffer;
+ tensor->data = addr;
+ wsp_ggml_backend_buffer_init_tensor(buffer, tensor);
+ }
+
+ static struct wsp_ggml_tensor * graph_dup_tensor(struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor ** node_copies,
+ struct wsp_ggml_context * ctx_allocated, struct wsp_ggml_context * ctx_unallocated, struct wsp_ggml_tensor * src) {
+
+ WSP_GGML_ASSERT(src != NULL);
+ WSP_GGML_ASSERT(src->data && "graph must be allocated");
+
+ size_t id = wsp_ggml_hash_insert(hash_set, src);
+ if (id == WSP_GGML_HASHTABLE_ALREADY_EXISTS) {
+ return node_copies[wsp_ggml_hash_find(hash_set, src)];
+ }
+
+ struct wsp_ggml_tensor * dst = wsp_ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src);
+ if (src->view_src != NULL) {
+ dst->view_src = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src);
+ dst->view_offs = src->view_offs;
+ }
+ dst->op = src->op;
+ memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
+ wsp_ggml_set_name(dst, src->name);
+
+ // copy src
+ for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+ struct wsp_ggml_tensor * s = src->src[i];
+ if (s == NULL) {
+ break;
+ }
+ dst->src[i] = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s);
+ }
+
+ node_copies[id] = dst;
+ return dst;
+ }
+
+ static void graph_init_tensor(struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor ** node_copies, bool * node_init, struct wsp_ggml_tensor * src) {
+ size_t id = wsp_ggml_hash_find(hash_set, src);
+ if (node_init[id]) {
+ return;
+ }
+ node_init[id] = true;
+
+ struct wsp_ggml_tensor * dst = node_copies[id];
+ if (dst->view_src != NULL) {
+ wsp_ggml_backend_view_init(dst->view_src->buffer, dst);
+ }
+ else {
+ wsp_ggml_backend_tensor_copy(src, dst);
+ }
+
+ // init src
+ for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+ struct wsp_ggml_tensor * s = src->src[i];
+ if (s == NULL) {
+ break;
+ }
+ graph_init_tensor(hash_set, node_copies, node_init, s);
+ }
+ }
+
+ struct wsp_ggml_backend_graph_copy wsp_ggml_backend_graph_copy(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * graph) {
+ struct wsp_ggml_hash_set hash_set = {
+ /* .size = */ graph->visited_hash_table.size,
+ /* .keys = */ calloc(sizeof(hash_set.keys[0]) * graph->visited_hash_table.size, 1)
+ };
+ struct wsp_ggml_tensor ** node_copies = calloc(sizeof(node_copies[0]) * hash_set.size, 1);
+ bool * node_init = calloc(sizeof(node_init[0]) * hash_set.size, 1);
+
+ struct wsp_ggml_init_params params = {
+ /* .mem_size = */ wsp_ggml_tensor_overhead()*hash_set.size + wsp_ggml_graph_overhead_custom(graph->size, false),
+ /* .mem_buffer = */ NULL,
+ /* .no_alloc = */ true
+ };
+
+ struct wsp_ggml_context * ctx_allocated = wsp_ggml_init(params);
+ struct wsp_ggml_context * ctx_unallocated = wsp_ggml_init(params);
+
+ // dup nodes
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct wsp_ggml_tensor * node = graph->nodes[i];
+ graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node);
+ }
+
+ // allocate nodes
+ wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_alloc_ctx_tensors(ctx_allocated, backend);
+
+ //printf("copy buffer size: %zu MB\n", wsp_ggml_backend_buffer_get_size(buffer) / 1024 / 1024);
+
+ // copy data and init views
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct wsp_ggml_tensor * node = graph->nodes[i];
+ graph_init_tensor(hash_set, node_copies, node_init, node);
+ }
+
+ // build graph copy
+ struct wsp_ggml_cgraph * graph_copy = wsp_ggml_new_graph_custom(ctx_allocated, graph->size, false);
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct wsp_ggml_tensor * node = graph->nodes[i];
+ struct wsp_ggml_tensor * node_copy = node_copies[wsp_ggml_hash_find(hash_set, node)];
+ graph_copy->nodes[i] = node_copy;
+ }
+ graph_copy->n_nodes = graph->n_nodes;
+
+ free(hash_set.keys);
+ free(node_copies);
+ free(node_init);
+
+ return (struct wsp_ggml_backend_graph_copy) {
+ /* .buffer = */ buffer,
+ /* .ctx_allocated = */ ctx_allocated,
+ /* .ctx_unallocated = */ ctx_unallocated,
+ /* .graph = */ graph_copy,
+ };
+ }
+
+ void wsp_ggml_backend_graph_copy_free(struct wsp_ggml_backend_graph_copy copy) {
+ wsp_ggml_backend_buffer_free(copy.buffer);
+ wsp_ggml_free(copy.ctx_allocated);
+ wsp_ggml_free(copy.ctx_unallocated);
+ }
+
+ void wsp_ggml_backend_compare_graph_backend(wsp_ggml_backend_t backend1, wsp_ggml_backend_t backend2, struct wsp_ggml_cgraph * graph, wsp_ggml_backend_eval_callback callback, void * user_data) {
+ struct wsp_ggml_backend_graph_copy copy = wsp_ggml_backend_graph_copy(backend2, graph);
+ struct wsp_ggml_cgraph * g1 = graph;
+ struct wsp_ggml_cgraph * g2 = copy.graph;
+
+ assert(g1->n_nodes == g2->n_nodes);
+
+ for (int i = 0; i < g1->n_nodes; i++) {
+ //printf("eval %d/%d\n", i, g1->n_nodes);
+ struct wsp_ggml_tensor * t1 = g1->nodes[i];
+ struct wsp_ggml_tensor * t2 = g2->nodes[i];
+
+ assert(t1->op == t2->op && wsp_ggml_are_same_layout(t1, t2));
+
+ struct wsp_ggml_cgraph g1v = wsp_ggml_graph_view(g1, i, i + 1);
+ struct wsp_ggml_cgraph g2v = wsp_ggml_graph_view(g2, i, i + 1);
+
+ wsp_ggml_backend_graph_compute(backend1, &g1v);
+ wsp_ggml_backend_graph_compute(backend2, &g2v);
+
+ if (wsp_ggml_is_view_op(t1->op)) {
+ continue;
+ }
+
+ // compare results, calculate rms etc
+ if (!callback(i, t1, t2, user_data)) {
+ break;
+ }
+ }
+
+ wsp_ggml_backend_graph_copy_free(copy);
+ }
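The wsp_ggml_backend_compare_graph_backend helper added at the end of this file copies a graph to a second backend and evaluates it node by node, handing each pair of result tensors to the supplied callback. A hypothetical callback is sketched below; it is not shipped with the package, it assumes the wsp_ggml_backend_eval_callback typedef matches the call site above (node index, the two tensors, user data, returning bool), and it assumes wsp_ggml_nelements and wsp_ggml_get_f32_1d are available with the wsp_ prefix, mirroring upstream ggml.

```c
#include <math.h>
#include <stdio.h>
#include <stdbool.h>
#include "ggml.h"
#include "ggml-backend.h"

// Hypothetical eval callback for wsp_ggml_backend_compare_graph_backend (sketch only).
// Returning false stops the node-by-node comparison loop.
static bool compare_node_cb(int node_index, struct wsp_ggml_tensor * t1, struct wsp_ggml_tensor * t2, void * user_data) {
    const double tol = 1e-3; // assumed tolerance for this sketch
    double max_err = 0.0;

    // only f32 outputs are compared element-wise here; the read assumes the
    // tensor data is host-visible (otherwise copy it out with wsp_ggml_backend_tensor_get first)
    if (t1->type == WSP_GGML_TYPE_F32 && t2->type == WSP_GGML_TYPE_F32) {
        const int64_t n = wsp_ggml_nelements(t1);
        for (int64_t k = 0; k < n; k++) {
            const double d = fabs(wsp_ggml_get_f32_1d(t1, (int) k) - wsp_ggml_get_f32_1d(t2, (int) k));
            if (d > max_err) {
                max_err = d;
            }
        }
    }

    fprintf(stderr, "node %4d %-16s max abs err = %g\n", node_index, wsp_ggml_op_name(t1->op), max_err);

    (void) user_data;
    return max_err <= tol;
}

// usage (hypothetical backends): wsp_ggml_backend_compare_graph_backend(backend_cpu, backend_metal, gf, compare_node_cb, NULL);
```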