whisper.rn 0.4.0-rc.6 → 0.4.0-rc.8

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
@@ -15,7 +15,11 @@
15
15
 
16
16
  // backend buffer type
17
17
 
18
- wsp_ggml_backend_buffer_t wsp_ggml_backend_buft_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
18
+ const char * wsp_ggml_backend_buft_name(wsp_ggml_backend_buffer_type_t buft) {
19
+ return buft->iface.get_name(buft);
20
+ }
21
+
22
+ WSP_GGML_CALL wsp_ggml_backend_buffer_t wsp_ggml_backend_buft_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
19
23
  return buft->iface.alloc_buffer(buft, size);
20
24
  }
21
25
 
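Note: the hunk above adds a name getter for buffer types and tags the allocator with the new WSP_GGML_CALL calling-convention macro. A minimal caller-side sketch, assuming the declarations come from the wsp_-prefixed ggml-backend header bundled with the package (the header name and the use of the CPU buffer type are assumptions, not shown in this hunk):

    #include <stdio.h>
    #include "ggml-backend.h"   // assumed header name for the wsp_-prefixed ggml build

    // print the human-readable name (new in rc.8) and the alignment of a buffer type
    static void print_buft_info(wsp_ggml_backend_buffer_type_t buft) {
        printf("buffer type %s, alignment %zu\n",
               wsp_ggml_backend_buft_name(buft),
               wsp_ggml_backend_buft_get_alignment(buft));
    }

    // e.g. print_buft_info(wsp_ggml_backend_cpu_buffer_type());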
@@ -23,7 +27,7 @@ size_t wsp_ggml_backend_buft_get_alignment(wsp_ggml_backend_buffer_type_t buft)
23
27
  return buft->iface.get_alignment(buft);
24
28
  }
25
29
 
26
- size_t wsp_ggml_backend_buft_get_alloc_size(wsp_ggml_backend_buffer_type_t buft, struct wsp_ggml_tensor * tensor) {
30
+ WSP_GGML_CALL size_t wsp_ggml_backend_buft_get_alloc_size(wsp_ggml_backend_buffer_type_t buft, struct wsp_ggml_tensor * tensor) {
27
31
  // get_alloc_size is optional, defaults to wsp_ggml_nbytes
28
32
  if (buft->iface.get_alloc_size) {
29
33
  return buft->iface.get_alloc_size(buft, tensor);
@@ -35,9 +39,16 @@ bool wsp_ggml_backend_buft_supports_backend(wsp_ggml_backend_buffer_type_t buft,
35
39
  return buft->iface.supports_backend(buft, backend);
36
40
  }
37
41
 
42
+ bool wsp_ggml_backend_buft_is_host(wsp_ggml_backend_buffer_type_t buft) {
43
+ if (buft->iface.is_host) {
44
+ return buft->iface.is_host(buft);
45
+ }
46
+ return false;
47
+ }
48
+
38
49
  // backend buffer
39
50
 
40
- wsp_ggml_backend_buffer_t wsp_ggml_backend_buffer_init(
51
+ WSP_GGML_CALL wsp_ggml_backend_buffer_t wsp_ggml_backend_buffer_init(
41
52
  wsp_ggml_backend_buffer_type_t buft,
42
53
  struct wsp_ggml_backend_buffer_i iface,
43
54
  wsp_ggml_backend_buffer_context_t context,
@@ -51,11 +62,16 @@ wsp_ggml_backend_buffer_t wsp_ggml_backend_buffer_init(
51
62
  /* .buft = */ buft,
52
63
  /* .context = */ context,
53
64
  /* .size = */ size,
65
+ /* .usage = */ WSP_GGML_BACKEND_BUFFER_USAGE_ANY
54
66
  };
55
67
 
56
68
  return buffer;
57
69
  }
58
70
 
71
+ const char * wsp_ggml_backend_buffer_name(wsp_ggml_backend_buffer_t buffer) {
72
+ return buffer->iface.get_name(buffer);
73
+ }
74
+
59
75
  void wsp_ggml_backend_buffer_free(wsp_ggml_backend_buffer_t buffer) {
60
76
  if (buffer == NULL) {
61
77
  return;
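Note: this hunk initializes the new per-buffer usage field to WSP_GGML_BACKEND_BUFFER_USAGE_ANY and exposes wsp_ggml_backend_buffer_name(). A hedged sketch of using the name getter for diagnostics (the helper below is illustrative, not part of the package):

    #include <stdio.h>
    #include "ggml-backend.h"   // assumed header name

    // report which backend buffer a tensor was allocated in
    static void log_tensor_buffer(const struct wsp_ggml_tensor * t) {
        if (t->buffer != NULL) {
            fprintf(stderr, "%s lives in buffer '%s'\n",
                    t->name, wsp_ggml_backend_buffer_name(t->buffer));
        }
    }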
@@ -79,7 +95,7 @@ void * wsp_ggml_backend_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
79
95
  return base;
80
96
  }
81
97
 
82
- void wsp_ggml_backend_buffer_init_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) {
98
+ WSP_GGML_CALL void wsp_ggml_backend_buffer_init_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) {
83
99
  // init_tensor is optional
84
100
  if (buffer->iface.init_tensor) {
85
101
  buffer->iface.init_tensor(buffer, tensor);
@@ -87,17 +103,43 @@ void wsp_ggml_backend_buffer_init_tensor(wsp_ggml_backend_buffer_t buffer, struc
87
103
  }
88
104
 
89
105
  size_t wsp_ggml_backend_buffer_get_alignment (wsp_ggml_backend_buffer_t buffer) {
90
- return wsp_ggml_backend_buft_get_alignment(wsp_ggml_backend_buffer_type(buffer));
106
+ return wsp_ggml_backend_buft_get_alignment(wsp_ggml_backend_buffer_get_type(buffer));
91
107
  }
92
108
 
93
109
  size_t wsp_ggml_backend_buffer_get_alloc_size(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) {
94
- return wsp_ggml_backend_buft_get_alloc_size(wsp_ggml_backend_buffer_type(buffer), tensor);
110
+ return wsp_ggml_backend_buft_get_alloc_size(wsp_ggml_backend_buffer_get_type(buffer), tensor);
111
+ }
112
+
113
+ void wsp_ggml_backend_buffer_clear(wsp_ggml_backend_buffer_t buffer, uint8_t value) {
114
+ buffer->iface.clear(buffer, value);
95
115
  }
96
116
 
97
- wsp_ggml_backend_buffer_type_t wsp_ggml_backend_buffer_type(wsp_ggml_backend_buffer_t buffer) {
117
+ bool wsp_ggml_backend_buffer_is_host(wsp_ggml_backend_buffer_t buffer) {
118
+ return wsp_ggml_backend_buft_is_host(wsp_ggml_backend_buffer_get_type(buffer));
119
+ }
120
+
121
+ void wsp_ggml_backend_buffer_set_usage(wsp_ggml_backend_buffer_t buffer, enum wsp_ggml_backend_buffer_usage usage) {
122
+ buffer->usage = usage;
123
+ }
124
+
125
+ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_buffer_get_type(wsp_ggml_backend_buffer_t buffer) {
98
126
  return buffer->buft;
99
127
  }
100
128
 
129
+ void wsp_ggml_backend_buffer_reset(wsp_ggml_backend_buffer_t buffer) {
130
+ if (buffer->iface.reset) {
131
+ buffer->iface.reset(buffer);
132
+ }
133
+ }
134
+
135
+ bool wsp_ggml_backend_buffer_copy_tensor(const struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
136
+ wsp_ggml_backend_buffer_t dst_buf = dst->view_src ? dst->view_src->buffer : dst->buffer;
137
+ if (dst_buf->iface.cpy_tensor) {
138
+ return src->buffer->iface.cpy_tensor(dst_buf, src, dst);
139
+ }
140
+ return false;
141
+ }
142
+
101
143
  // backend
102
144
 
103
145
  const char * wsp_ggml_backend_name(wsp_ggml_backend_t backend) {
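Note: the buffer API grows clear, is_host, set_usage, get_type, reset and copy_tensor helpers in the hunk above; the WEIGHTS usage value is consumed by the scheduler changes further down in this diff. A hedged sketch combining them (the helper name and the memcpy fast path are illustrative):

    #include <string.h>
    #include "ggml-backend.h"   // assumed header name

    // zero a buffer, mark it as holding weights, then upload data into one of its tensors
    static void prepare_weights(wsp_ggml_backend_buffer_t buf,
                                struct wsp_ggml_tensor * t, const void * src, size_t n) {
        wsp_ggml_backend_buffer_clear(buf, 0);   // new: fill the buffer with a byte value
        wsp_ggml_backend_buffer_set_usage(buf, WSP_GGML_BACKEND_BUFFER_USAGE_WEIGHTS); // scheduler hint
        if (wsp_ggml_backend_buffer_is_host(buf)) {
            memcpy(t->data, src, n);                   // host memory: direct copy is fine
        } else {
            wsp_ggml_backend_tensor_set(t, src, 0, n); // device memory: go through the backend
        }
    }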
@@ -131,30 +173,42 @@ void wsp_ggml_backend_tensor_set_async(wsp_ggml_backend_t backend, struct wsp_gg
131
173
  WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
132
174
  WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
133
175
 
134
- backend->iface.set_tensor_async(backend, tensor, data, offset, size);
176
+ if (backend->iface.set_tensor_async == NULL) {
177
+ wsp_ggml_backend_tensor_set(tensor, data, offset, size);
178
+ } else {
179
+ backend->iface.set_tensor_async(backend, tensor, data, offset, size);
180
+ }
135
181
  }
136
182
 
137
183
  void wsp_ggml_backend_tensor_get_async(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
138
184
  WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
139
185
  WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
140
186
 
141
- backend->iface.get_tensor_async(backend, tensor, data, offset, size);
187
+ if (backend->iface.get_tensor_async == NULL) {
188
+ wsp_ggml_backend_tensor_get(tensor, data, offset, size);
189
+ } else {
190
+ backend->iface.get_tensor_async(backend, tensor, data, offset, size);
191
+ }
142
192
  }
143
193
 
144
- void wsp_ggml_backend_tensor_set(struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
194
+ WSP_GGML_CALL void wsp_ggml_backend_tensor_set(struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
195
+ wsp_ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
196
+
145
197
  WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
146
- WSP_GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set");
198
+ WSP_GGML_ASSERT(buf != NULL && "tensor buffer not set");
147
199
  WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
148
200
 
149
- tensor->buffer->iface.set_tensor(tensor->buffer, tensor, data, offset, size);
201
+ tensor->buffer->iface.set_tensor(buf, tensor, data, offset, size);
150
202
  }
151
203
 
152
- void wsp_ggml_backend_tensor_get(const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
204
+ WSP_GGML_CALL void wsp_ggml_backend_tensor_get(const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
205
+ wsp_ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
206
+
153
207
  WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
154
208
  WSP_GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set");
155
209
  WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
156
210
 
157
- tensor->buffer->iface.get_tensor(tensor->buffer, tensor, data, offset, size);
211
+ tensor->buffer->iface.get_tensor(buf, tensor, data, offset, size);
158
212
  }
159
213
 
160
214
  void wsp_ggml_backend_synchronize(wsp_ggml_backend_t backend) {
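Note: with the change above, the async tensor accessors fall back to the synchronous wsp_ggml_backend_tensor_set/get when a backend (such as the CPU backend) does not implement the async hooks, so callers no longer need to special-case them. A usage sketch; the synchronize call assumes the host buffer is reused immediately afterwards:

    #include "ggml-backend.h"   // assumed header name

    // queue an upload; on backends without async support this completes before returning
    static void upload_then_wait(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * t,
                                 const void * host_data, size_t nbytes) {
        wsp_ggml_backend_tensor_set_async(backend, t, host_data, 0, nbytes);
        wsp_ggml_backend_synchronize(backend); // ensure the copy has landed before host_data is reused
    }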
@@ -175,16 +229,10 @@ void wsp_ggml_backend_graph_plan_free(wsp_ggml_backend_t backend, wsp_ggml_backe
175
229
 
176
230
  void wsp_ggml_backend_graph_plan_compute(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) {
177
231
  backend->iface.graph_plan_compute(backend, plan);
178
-
179
- // TODO: optional sync
180
- wsp_ggml_backend_synchronize(backend);
181
232
  }
182
233
 
183
- void wsp_ggml_backend_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
184
- backend->iface.graph_compute(backend, cgraph);
185
-
186
- // TODO: optional sync
187
- wsp_ggml_backend_synchronize(backend);
234
+ bool wsp_ggml_backend_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
235
+ return backend->iface.graph_compute(backend, cgraph);
188
236
  }
189
237
 
190
238
  bool wsp_ggml_backend_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
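Note: wsp_ggml_backend_graph_compute() now returns a success flag and no longer synchronizes implicitly (the TODO sync calls above were removed), so callers that read results on the host must synchronize themselves. A minimal sketch, assuming the bundled ggml-backend header:

    #include <stdio.h>
    #include "ggml-backend.h"   // assumed header name

    static int run_graph(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * graph) {
        if (!wsp_ggml_backend_graph_compute(backend, graph)) { // rc.8: returns bool instead of void
            fprintf(stderr, "graph compute failed on %s\n", wsp_ggml_backend_name(backend));
            return 1;
        }
        wsp_ggml_backend_synchronize(backend); // the implicit sync was removed in this version
        return 0;
    }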
@@ -209,28 +257,20 @@ static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const str
209
257
  }
210
258
 
211
259
  void wsp_ggml_backend_tensor_copy(struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
212
- //printf("src: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", src->name, (int)src->ne[0], (int)src->ne[1], (int)src->ne[2], (int)src->ne[3], (int)src->nb[0], (int)src->nb[1], (int)src->nb[2], (int)src->nb[3]);
213
- //printf("dst: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", dst->name, (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], (int)dst->nb[0], (int)dst->nb[1], (int)dst->nb[2], (int)dst->nb[3]);
214
260
  WSP_GGML_ASSERT(wsp_ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
215
261
 
216
- // fprintf(stderr, "cpy tensor %s from %s to %s (%lu bytes)\n", src->name, wsp_ggml_backend_name(src->backend), wsp_ggml_backend_name(dst->backend), wsp_ggml_nbytes(src));
217
-
218
262
  if (src == dst) {
219
263
  return;
220
264
  }
221
265
 
222
- // TODO: allow backends to support copy to/from same backend
223
-
224
- if (dst->buffer->iface.cpy_tensor_from != NULL) {
225
- dst->buffer->iface.cpy_tensor_from(dst->buffer, src, dst);
226
- } else if (src->buffer->iface.cpy_tensor_to != NULL) {
227
- src->buffer->iface.cpy_tensor_to(src->buffer, src, dst);
228
- } else {
229
- // shouldn't be hit when copying from/to CPU
230
- #ifndef NDEBUG
231
- fprintf(stderr, "wsp_ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to "
232
- "are implemented for %s and %s, falling back to get/set\n", src->name, dst->name);
233
- #endif
266
+ if (wsp_ggml_backend_buffer_is_host(src->buffer)) {
267
+ wsp_ggml_backend_tensor_set(dst, src->data, 0, wsp_ggml_nbytes(src));
268
+ } else if (wsp_ggml_backend_buffer_is_host(dst->buffer)) {
269
+ wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src));
270
+ } else if (!wsp_ggml_backend_buffer_copy_tensor(src, dst)) {
271
+ #ifndef NDEBUG
272
+ fprintf(stderr, "%s: warning: slow copy from %s to %s\n", __func__, wsp_ggml_backend_buffer_name(src->buffer), wsp_ggml_backend_buffer_name(dst->buffer));
273
+ #endif
234
274
  size_t nbytes = wsp_ggml_nbytes(src);
235
275
  void * data = malloc(nbytes);
236
276
  wsp_ggml_backend_tensor_get(src, data, 0, nbytes);
@@ -239,6 +279,31 @@ void wsp_ggml_backend_tensor_copy(struct wsp_ggml_tensor * src, struct wsp_ggml_
239
279
  }
240
280
  }
241
281
 
282
+ void wsp_ggml_backend_tensor_copy_async(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
283
+ WSP_GGML_ASSERT(wsp_ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
284
+
285
+ if (src == dst) {
286
+ return;
287
+ }
288
+
289
+ if (wsp_ggml_backend_buft_supports_backend(src->buffer->buft, backend) && wsp_ggml_backend_buft_supports_backend(dst->buffer->buft, backend)) {
290
+ if (backend->iface.cpy_tensor_async != NULL) {
291
+ if (backend->iface.cpy_tensor_async(backend, src, dst)) {
292
+ return;
293
+ }
294
+ }
295
+ }
296
+
297
+ size_t nbytes = wsp_ggml_nbytes(src);
298
+ if (wsp_ggml_backend_buffer_is_host(src->buffer)) {
299
+ wsp_ggml_backend_tensor_set_async(backend, dst, src->data, 0, nbytes);
300
+ }
301
+ else {
302
+ wsp_ggml_backend_tensor_copy(src, dst);
303
+ }
304
+ }
305
+
306
+
242
307
  // backend registry
243
308
 
244
309
  #define WSP_GGML_MAX_BACKENDS_REG 16
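Note: the new wsp_ggml_backend_tensor_copy_async() above tries a backend-provided async copy and otherwise degrades to the host set/get or synchronous copy paths. A caller-side sketch (the synchronize is only needed if the destination is read right away):

    #include "ggml-backend.h"   // assumed header name

    static void stage_input(wsp_ggml_backend_t backend,
                            struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
        wsp_ggml_backend_tensor_copy_async(backend, src, dst);
        wsp_ggml_backend_synchronize(backend); // the copy may still be in flight on async backends
    }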
@@ -253,9 +318,9 @@ struct wsp_ggml_backend_reg {
253
318
  static struct wsp_ggml_backend_reg wsp_ggml_backend_registry[WSP_GGML_MAX_BACKENDS_REG];
254
319
  static size_t wsp_ggml_backend_registry_count = 0;
255
320
 
256
- static wsp_ggml_backend_t wsp_ggml_backend_reg_cpu_init(const char * params, void * user_data);
321
+ WSP_GGML_CALL static wsp_ggml_backend_t wsp_ggml_backend_reg_cpu_init(const char * params, void * user_data);
257
322
 
258
- static void wsp_ggml_backend_registry_init(void) {
323
+ WSP_GGML_CALL static void wsp_ggml_backend_registry_init(void) {
259
324
  static bool initialized = false;
260
325
 
261
326
  if (initialized) {
@@ -268,21 +333,21 @@ static void wsp_ggml_backend_registry_init(void) {
268
333
 
269
334
  // add forward decls here to avoid including the backend headers
270
335
  #ifdef WSP_GGML_USE_CUBLAS
271
- extern void wsp_ggml_backend_cuda_reg_devices(void);
336
+ extern WSP_GGML_CALL void wsp_ggml_backend_cuda_reg_devices(void);
272
337
  wsp_ggml_backend_cuda_reg_devices();
273
338
  #endif
274
339
 
275
340
  #ifdef WSP_GGML_USE_METAL
276
- extern wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data);
277
- extern wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void);
341
+ extern WSP_GGML_CALL wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data);
342
+ extern WSP_GGML_CALL wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void);
278
343
  wsp_ggml_backend_register("Metal", wsp_ggml_backend_reg_metal_init, wsp_ggml_backend_metal_buffer_type(), NULL);
279
344
  #endif
280
345
  }
281
346
 
282
- void wsp_ggml_backend_register(const char * name, wsp_ggml_backend_init_fn init_fn, wsp_ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
347
+ WSP_GGML_CALL void wsp_ggml_backend_register(const char * name, wsp_ggml_backend_init_fn init_fn, wsp_ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
283
348
  WSP_GGML_ASSERT(wsp_ggml_backend_registry_count < WSP_GGML_MAX_BACKENDS_REG);
284
349
 
285
- int id = wsp_ggml_backend_registry_count;
350
+ size_t id = wsp_ggml_backend_registry_count;
286
351
 
287
352
  wsp_ggml_backend_registry[id] = (struct wsp_ggml_backend_reg) {
288
353
  /* .name = */ {0},
@@ -315,6 +380,8 @@ size_t wsp_ggml_backend_reg_find_by_name(const char * name) {
315
380
  return i;
316
381
  }
317
382
  }
383
+
384
+ // not found
318
385
  return SIZE_MAX;
319
386
  }
320
387
 
@@ -325,15 +392,15 @@ wsp_ggml_backend_t wsp_ggml_backend_reg_init_backend_from_str(const char * backe
325
392
  const char * params = strchr(backend_str, ':');
326
393
  char backend_name[128];
327
394
  if (params == NULL) {
328
- strcpy(backend_name, backend_str);
395
+ snprintf(backend_name, sizeof(backend_name), "%s", backend_str);
329
396
  params = "";
330
397
  } else {
331
- strncpy(backend_name, backend_str, params - backend_str);
332
- backend_name[params - backend_str] = '\0';
398
+ snprintf(backend_name, sizeof(backend_name), "%.*s", (int)(params - backend_str), backend_str);
333
399
  params++;
334
400
  }
335
401
 
336
402
  size_t backend_i = wsp_ggml_backend_reg_find_by_name(backend_name);
403
+
337
404
  if (backend_i == SIZE_MAX) {
338
405
  fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name);
339
406
  return NULL;
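Note: the hunk above replaces strcpy/strncpy with bounded snprintf when splitting a "backend:params" string, removing a potential overflow of the 128-byte name buffer. The same pattern in isolation, independent of ggml (the helper name is illustrative):

    #include <stdio.h>
    #include <string.h>

    // split "backend:params" into a bounded, NUL-terminated name and a params pointer
    static const char * split_backend_str(const char * s, char * name, size_t name_size) {
        const char * colon = strchr(s, ':');
        if (colon == NULL) {
            snprintf(name, name_size, "%s", s);                      // bounded copy of the whole string
            return "";
        }
        snprintf(name, name_size, "%.*s", (int)(colon - s), s);      // copy only the name part
        return colon + 1;                                            // params start after the ':'
    }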
@@ -372,69 +439,80 @@ wsp_ggml_backend_buffer_t wsp_ggml_backend_reg_alloc_buffer(size_t i, size_t siz
372
439
 
373
440
  // backend CPU
374
441
 
375
- static void * wsp_ggml_backend_cpu_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
442
+ WSP_GGML_CALL static const char * wsp_ggml_backend_cpu_buffer_name(wsp_ggml_backend_buffer_t buffer) {
443
+ return "CPU";
444
+
445
+ WSP_GGML_UNUSED(buffer);
446
+ }
447
+
448
+ WSP_GGML_CALL static void * wsp_ggml_backend_cpu_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
376
449
  return (void *)buffer->context;
377
450
  }
378
451
 
379
- static void wsp_ggml_backend_cpu_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
452
+ WSP_GGML_CALL static void wsp_ggml_backend_cpu_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
380
453
  free(buffer->context);
381
- WSP_GGML_UNUSED(buffer);
382
454
  }
383
455
 
384
- static void wsp_ggml_backend_cpu_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
385
- WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
386
- WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
387
-
456
+ WSP_GGML_CALL static void wsp_ggml_backend_cpu_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
388
457
  memcpy((char *)tensor->data + offset, data, size);
389
458
 
390
459
  WSP_GGML_UNUSED(buffer);
391
460
  }
392
461
 
393
- static void wsp_ggml_backend_cpu_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
394
- WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
395
- WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
396
-
462
+ WSP_GGML_CALL static void wsp_ggml_backend_cpu_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
397
463
  memcpy(data, (const char *)tensor->data + offset, size);
398
464
 
399
465
  WSP_GGML_UNUSED(buffer);
400
466
  }
401
467
 
402
- static void wsp_ggml_backend_cpu_buffer_cpy_tensor_from(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
403
- wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src));
468
+ WSP_GGML_CALL static bool wsp_ggml_backend_cpu_buffer_cpy_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
469
+ if (wsp_ggml_backend_buffer_is_host(src->buffer)) {
470
+ memcpy(dst->data, src->data, wsp_ggml_nbytes(src));
471
+ return true;
472
+ }
473
+ return false;
404
474
 
405
475
  WSP_GGML_UNUSED(buffer);
406
476
  }
407
477
 
408
- static void wsp_ggml_backend_cpu_buffer_cpy_tensor_to(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
409
- wsp_ggml_backend_tensor_set(dst, src->data, 0, wsp_ggml_nbytes(src));
410
-
411
- WSP_GGML_UNUSED(buffer);
478
+ WSP_GGML_CALL static void wsp_ggml_backend_cpu_buffer_clear(wsp_ggml_backend_buffer_t buffer, uint8_t value) {
479
+ memset(buffer->context, value, buffer->size);
412
480
  }
413
481
 
414
482
  static struct wsp_ggml_backend_buffer_i cpu_backend_buffer_i = {
483
+ /* .get_name = */ wsp_ggml_backend_cpu_buffer_name,
415
484
  /* .free_buffer = */ wsp_ggml_backend_cpu_buffer_free_buffer,
416
485
  /* .get_base = */ wsp_ggml_backend_cpu_buffer_get_base,
417
486
  /* .init_tensor = */ NULL, // no initialization required
418
487
  /* .set_tensor = */ wsp_ggml_backend_cpu_buffer_set_tensor,
419
488
  /* .get_tensor = */ wsp_ggml_backend_cpu_buffer_get_tensor,
420
- /* .cpy_tensor_from = */ wsp_ggml_backend_cpu_buffer_cpy_tensor_from,
421
- /* .cpy_tensor_to = */ wsp_ggml_backend_cpu_buffer_cpy_tensor_to,
489
+ /* .cpy_tensor = */ wsp_ggml_backend_cpu_buffer_cpy_tensor,
490
+ /* .clear = */ wsp_ggml_backend_cpu_buffer_clear,
491
+ /* .reset = */ NULL,
422
492
  };
423
493
 
424
494
  // for buffers from ptr, free is not called
425
495
  static struct wsp_ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
496
+ /* .get_name = */ wsp_ggml_backend_cpu_buffer_name,
426
497
  /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
427
498
  /* .get_base = */ wsp_ggml_backend_cpu_buffer_get_base,
428
499
  /* .init_tensor = */ NULL, // no initialization required
429
500
  /* .set_tensor = */ wsp_ggml_backend_cpu_buffer_set_tensor,
430
501
  /* .get_tensor = */ wsp_ggml_backend_cpu_buffer_get_tensor,
431
- /* .cpy_tensor_from = */ wsp_ggml_backend_cpu_buffer_cpy_tensor_from,
432
- /* .cpy_tensor_to = */ wsp_ggml_backend_cpu_buffer_cpy_tensor_to,
502
+ /* .cpy_tensor = */ wsp_ggml_backend_cpu_buffer_cpy_tensor,
503
+ /* .clear = */ wsp_ggml_backend_cpu_buffer_clear,
504
+ /* .reset = */ NULL,
433
505
  };
434
506
 
435
507
  static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512
436
508
 
437
- static wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
509
+ WSP_GGML_CALL static const char * wsp_ggml_backend_cpu_buffer_type_get_name(wsp_ggml_backend_buffer_type_t buft) {
510
+ return "CPU";
511
+
512
+ WSP_GGML_UNUSED(buft);
513
+ }
514
+
515
+ WSP_GGML_CALL static wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
438
516
  size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned
439
517
  void * data = malloc(size); // TODO: maybe use WSP_GGML_ALIGNED_MALLOC?
440
518
 
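Note: the hunk above rewires the CPU buffer interface tables: get_name, cpy_tensor, clear and reset replace the old cpy_tensor_from/cpy_tensor_to pair, for both the owned and the from_ptr variants. A hedged sketch that wraps caller-owned memory and exercises the new clear hook (the helper is illustrative; the from_ptr buffer never frees the pointer):

    #include <stdlib.h>
    #include "ggml-backend.h"   // assumed header name

    static wsp_ggml_backend_buffer_t wrap_scratch(size_t size) {
        void * ptr = malloc(size);                         // caller keeps ownership of this memory
        wsp_ggml_backend_buffer_t buf = wsp_ggml_backend_cpu_buffer_from_ptr(ptr, size);
        wsp_ggml_backend_buffer_clear(buf, 0);             // routed to the new clear hook (memset)
        return buf;
    }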
@@ -443,31 +521,95 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_type_alloc_buffer(w
443
521
  return wsp_ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size);
444
522
  }
445
523
 
446
- static size_t wsp_ggml_backend_cpu_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
524
+ WSP_GGML_CALL static size_t wsp_ggml_backend_cpu_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
447
525
  return TENSOR_ALIGNMENT;
448
526
 
449
527
  WSP_GGML_UNUSED(buft);
450
528
  }
451
529
 
452
- static bool wsp_ggml_backend_cpu_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) {
530
+ WSP_GGML_CALL static bool wsp_ggml_backend_cpu_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) {
453
531
  return wsp_ggml_backend_is_cpu(backend);
454
532
 
455
533
  WSP_GGML_UNUSED(buft);
456
534
  }
457
535
 
458
- wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_buffer_type(void) {
459
- static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_buffer_type_cpu = {
536
+ WSP_GGML_CALL static bool wsp_ggml_backend_cpu_buffer_type_is_host(wsp_ggml_backend_buffer_type_t buft) {
537
+ return true;
538
+
539
+ WSP_GGML_UNUSED(buft);
540
+ }
541
+
542
+ WSP_GGML_CALL wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_buffer_type(void) {
543
+ static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_cpu_buffer_type = {
460
544
  /* .iface = */ {
545
+ /* .get_name = */ wsp_ggml_backend_cpu_buffer_type_get_name,
461
546
  /* .alloc_buffer = */ wsp_ggml_backend_cpu_buffer_type_alloc_buffer,
462
547
  /* .get_alignment = */ wsp_ggml_backend_cpu_buffer_type_get_alignment,
463
548
  /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes
464
549
  /* .supports_backend = */ wsp_ggml_backend_cpu_buffer_type_supports_backend,
550
+ /* .is_host = */ wsp_ggml_backend_cpu_buffer_type_is_host,
465
551
  },
466
552
  /* .context = */ NULL,
467
553
  };
468
554
 
469
- return &wsp_ggml_backend_buffer_type_cpu;
555
+ return &wsp_ggml_backend_cpu_buffer_type;
556
+ }
557
+
558
+ #ifdef WSP_GGML_USE_CPU_HBM
559
+
560
+ // buffer type HBM
561
+
562
+ #include <hbwmalloc.h>
563
+
564
+ WSP_GGML_CALL static const char * wsp_ggml_backend_cpu_hbm_buffer_type_get_name(wsp_ggml_backend_buffer_type_t buft) {
565
+ return "CPU_HBM";
566
+
567
+ WSP_GGML_UNUSED(buft);
568
+ }
569
+
570
+ WSP_GGML_CALL static const char * wsp_ggml_backend_cpu_hbm_buffer_get_name(wsp_ggml_backend_buffer_t buf) {
571
+ return "CPU_HBM";
572
+
573
+ WSP_GGML_UNUSED(buf);
574
+ }
575
+
576
+ WSP_GGML_CALL static void wsp_ggml_backend_cpu_hbm_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
577
+ hbw_free(buffer->context);
578
+ }
579
+
580
+ WSP_GGML_CALL static wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_hbm_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
581
+ //void * ptr = hbw_malloc(size);
582
+ void * ptr;
583
+ int result = hbw_posix_memalign(&ptr, wsp_ggml_backend_cpu_buffer_type_get_alignment(buft), size);
584
+ if (result != 0) {
585
+ fprintf(stderr, "failed to allocate HBM buffer of size %zu\n", size);
586
+ return NULL;
587
+ }
588
+
589
+ wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_cpu_buffer_from_ptr(ptr, size);
590
+ buffer->buft = buft;
591
+ buffer->iface.get_name = wsp_ggml_backend_cpu_hbm_buffer_get_name;
592
+ buffer->iface.free_buffer = wsp_ggml_backend_cpu_hbm_buffer_free_buffer;
593
+
594
+ return buffer;
595
+ }
596
+
597
+ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_hbm_buffer_type(void) {
598
+ static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_cpu_buffer_type_hbm = {
599
+ /* .iface = */ {
600
+ /* .get_name = */ wsp_ggml_backend_cpu_hbm_buffer_type_get_name,
601
+ /* .alloc_buffer = */ wsp_ggml_backend_cpu_hbm_buffer_type_alloc_buffer,
602
+ /* .get_alignment = */ wsp_ggml_backend_cpu_buffer_type_get_alignment,
603
+ /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes
604
+ /* .supports_backend = */ wsp_ggml_backend_cpu_buffer_type_supports_backend,
605
+ /* .is_host = */ wsp_ggml_backend_cpu_buffer_type_is_host,
606
+ },
607
+ /* .context = */ NULL,
608
+ };
609
+
610
+ return &wsp_ggml_backend_cpu_buffer_type_hbm;
470
611
  }
612
+ #endif
471
613
 
472
614
  struct wsp_ggml_backend_cpu_context {
473
615
  int n_threads;
@@ -475,20 +617,20 @@ struct wsp_ggml_backend_cpu_context {
475
617
  size_t work_size;
476
618
  };
477
619
 
478
- static const char * wsp_ggml_backend_cpu_name(wsp_ggml_backend_t backend) {
620
+ WSP_GGML_CALL static const char * wsp_ggml_backend_cpu_name(wsp_ggml_backend_t backend) {
479
621
  return "CPU";
480
622
 
481
623
  WSP_GGML_UNUSED(backend);
482
624
  }
483
625
 
484
- static void wsp_ggml_backend_cpu_free(wsp_ggml_backend_t backend) {
626
+ WSP_GGML_CALL static void wsp_ggml_backend_cpu_free(wsp_ggml_backend_t backend) {
485
627
  struct wsp_ggml_backend_cpu_context * cpu_ctx = (struct wsp_ggml_backend_cpu_context *)backend->context;
486
628
  free(cpu_ctx->work_data);
487
629
  free(cpu_ctx);
488
630
  free(backend);
489
631
  }
490
632
 
491
- static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_get_default_buffer_type(wsp_ggml_backend_t backend) {
633
+ WSP_GGML_CALL static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_get_default_buffer_type(wsp_ggml_backend_t backend) {
492
634
  return wsp_ggml_backend_cpu_buffer_type();
493
635
 
494
636
  WSP_GGML_UNUSED(backend);
@@ -499,13 +641,13 @@ struct wsp_ggml_backend_plan_cpu {
499
641
  struct wsp_ggml_cgraph cgraph;
500
642
  };
501
643
 
502
- static wsp_ggml_backend_graph_plan_t wsp_ggml_backend_cpu_graph_plan_create(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
644
+ WSP_GGML_CALL static wsp_ggml_backend_graph_plan_t wsp_ggml_backend_cpu_graph_plan_create(wsp_ggml_backend_t backend, const struct wsp_ggml_cgraph * cgraph) {
503
645
  struct wsp_ggml_backend_cpu_context * cpu_ctx = (struct wsp_ggml_backend_cpu_context *)backend->context;
504
646
 
505
647
  struct wsp_ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct wsp_ggml_backend_plan_cpu));
506
648
 
507
649
  cpu_plan->cplan = wsp_ggml_graph_plan(cgraph, cpu_ctx->n_threads);
508
- cpu_plan->cgraph = *cgraph;
650
+ cpu_plan->cgraph = *cgraph; // FIXME: deep copy
509
651
 
510
652
  if (cpu_plan->cplan.work_size > 0) {
511
653
  cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size);
@@ -514,7 +656,7 @@ static wsp_ggml_backend_graph_plan_t wsp_ggml_backend_cpu_graph_plan_create(wsp_
514
656
  return cpu_plan;
515
657
  }
516
658
 
517
- static void wsp_ggml_backend_cpu_graph_plan_free(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) {
659
+ WSP_GGML_CALL static void wsp_ggml_backend_cpu_graph_plan_free(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) {
518
660
  struct wsp_ggml_backend_plan_cpu * cpu_plan = (struct wsp_ggml_backend_plan_cpu *)plan;
519
661
 
520
662
  free(cpu_plan->cplan.work_data);
@@ -523,7 +665,7 @@ static void wsp_ggml_backend_cpu_graph_plan_free(wsp_ggml_backend_t backend, wsp
523
665
  WSP_GGML_UNUSED(backend);
524
666
  }
525
667
 
526
- static void wsp_ggml_backend_cpu_graph_plan_compute(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) {
668
+ WSP_GGML_CALL static void wsp_ggml_backend_cpu_graph_plan_compute(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) {
527
669
  struct wsp_ggml_backend_plan_cpu * cpu_plan = (struct wsp_ggml_backend_plan_cpu *)plan;
528
670
 
529
671
  wsp_ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
@@ -531,7 +673,7 @@ static void wsp_ggml_backend_cpu_graph_plan_compute(wsp_ggml_backend_t backend,
531
673
  WSP_GGML_UNUSED(backend);
532
674
  }
533
675
 
534
- static void wsp_ggml_backend_cpu_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
676
+ WSP_GGML_CALL static bool wsp_ggml_backend_cpu_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
535
677
  struct wsp_ggml_backend_cpu_context * cpu_ctx = (struct wsp_ggml_backend_cpu_context *)backend->context;
536
678
 
537
679
  struct wsp_ggml_cplan cplan = wsp_ggml_graph_plan(cgraph, cpu_ctx->n_threads);
@@ -545,13 +687,20 @@ static void wsp_ggml_backend_cpu_graph_compute(wsp_ggml_backend_t backend, struc
545
687
  cplan.work_data = cpu_ctx->work_data;
546
688
 
547
689
  wsp_ggml_graph_compute(cgraph, &cplan);
690
+ return true;
548
691
  }
549
692
 
550
- static bool wsp_ggml_backend_cpu_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
551
- return true;
693
+ WSP_GGML_CALL static bool wsp_ggml_backend_cpu_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
694
+ switch (op->op) {
695
+ case WSP_GGML_OP_CPY:
696
+ return op->type != WSP_GGML_TYPE_IQ2_XXS && op->type != WSP_GGML_TYPE_IQ2_XS; // missing type_traits.from_float
697
+ case WSP_GGML_OP_MUL_MAT:
698
+ return op->src[1]->type == WSP_GGML_TYPE_F32 || op->src[1]->type == wsp_ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
699
+ default:
700
+ return true;
701
+ }
552
702
 
553
703
  WSP_GGML_UNUSED(backend);
554
- WSP_GGML_UNUSED(op);
555
704
  }
556
705
 
557
706
  static struct wsp_ggml_backend_i cpu_backend_i = {
@@ -560,8 +709,7 @@ static struct wsp_ggml_backend_i cpu_backend_i = {
560
709
  /* .get_default_buffer_type = */ wsp_ggml_backend_cpu_get_default_buffer_type,
561
710
  /* .set_tensor_async = */ NULL,
562
711
  /* .get_tensor_async = */ NULL,
563
- /* .cpy_tensor_from_async = */ NULL,
564
- /* .cpy_tensor_to_async = */ NULL,
712
+ /* .cpy_tensor_async = */ NULL,
565
713
  /* .synchronize = */ NULL,
566
714
  /* .graph_plan_create = */ wsp_ggml_backend_cpu_graph_plan_create,
567
715
  /* .graph_plan_free = */ wsp_ggml_backend_cpu_graph_plan_free,
@@ -586,8 +734,8 @@ wsp_ggml_backend_t wsp_ggml_backend_cpu_init(void) {
586
734
  return cpu_backend;
587
735
  }
588
736
 
589
- bool wsp_ggml_backend_is_cpu(wsp_ggml_backend_t backend) {
590
- return backend->iface.get_name == wsp_ggml_backend_cpu_name;
737
+ WSP_GGML_CALL bool wsp_ggml_backend_is_cpu(wsp_ggml_backend_t backend) {
738
+ return backend && backend->iface.get_name == wsp_ggml_backend_cpu_name;
591
739
  }
592
740
 
593
741
  void wsp_ggml_backend_cpu_set_n_threads(wsp_ggml_backend_t backend_cpu, int n_threads) {
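Note: wsp_ggml_backend_is_cpu() now tolerates a NULL backend. Combined with the CPU entry points around it, a minimal init sketch (the thread-count handling is illustrative):

    #include "ggml-backend.h"   // assumed header name

    static wsp_ggml_backend_t init_cpu_backend(int n_threads) {
        wsp_ggml_backend_t backend = wsp_ggml_backend_cpu_init();
        if (wsp_ggml_backend_is_cpu(backend)) {            // safe even if init returned NULL in rc.8
            wsp_ggml_backend_cpu_set_n_threads(backend, n_threads);
        }
        return backend;
    }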
@@ -597,11 +745,11 @@ void wsp_ggml_backend_cpu_set_n_threads(wsp_ggml_backend_t backend_cpu, int n_th
597
745
  ctx->n_threads = n_threads;
598
746
  }
599
747
 
600
- wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
748
+ WSP_GGML_CALL wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
601
749
  return wsp_ggml_backend_buffer_init(wsp_ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
602
750
  }
603
751
 
604
- static wsp_ggml_backend_t wsp_ggml_backend_reg_cpu_init(const char * params, void * user_data) {
752
+ WSP_GGML_CALL static wsp_ggml_backend_t wsp_ggml_backend_reg_cpu_init(const char * params, void * user_data) {
605
753
  return wsp_ggml_backend_cpu_init();
606
754
 
607
755
  WSP_GGML_UNUSED(params);
@@ -611,7 +759,7 @@ static wsp_ggml_backend_t wsp_ggml_backend_reg_cpu_init(const char * params, voi
611
759
 
612
760
  // scheduler
613
761
 
614
- #define WSP_GGML_MAX_BACKENDS 4
762
+ #define WSP_GGML_MAX_BACKENDS 16
615
763
  #define WSP_GGML_MAX_SPLITS 256
616
764
  #define WSP_GGML_MAX_SPLIT_INPUTS 16
617
765
 
@@ -621,21 +769,29 @@ struct wsp_ggml_backend_sched_split {
621
769
  int i_end;
622
770
  struct wsp_ggml_tensor * inputs[WSP_GGML_MAX_SPLIT_INPUTS];
623
771
  int n_inputs;
772
+ // graph view of this split
624
773
  struct wsp_ggml_cgraph graph;
625
774
  };
626
775
 
627
776
  struct wsp_ggml_backend_sched {
777
+ bool is_reset; // true if the scheduler has been reset since the last graph split
778
+
628
779
  int n_backends;
629
780
  wsp_ggml_backend_t backends[WSP_GGML_MAX_BACKENDS];
781
+ wsp_ggml_backend_buffer_type_t bufts[WSP_GGML_MAX_BACKENDS];
630
782
  wsp_ggml_tallocr_t tallocs[WSP_GGML_MAX_BACKENDS];
631
783
 
632
784
  wsp_ggml_gallocr_t galloc;
633
785
 
786
+ // hash keys of the nodes in the graph
634
787
  struct wsp_ggml_hash_set hash_set;
635
- wsp_ggml_tallocr_t * node_talloc; // [hash_set.size]
636
- struct wsp_ggml_tensor * (* node_copies)[WSP_GGML_MAX_BACKENDS]; // [hash_set.size][WSP_GGML_MAX_BACKENDS]
788
+ // hash values (arrays of [hash_set.size])
789
+ wsp_ggml_tallocr_t * node_talloc; // tallocr assigned to each node (indirectly this is the backend)
790
+ struct wsp_ggml_tensor * (* node_copies)[WSP_GGML_MAX_BACKENDS]; // copies of each node for each destination backend
637
791
 
792
+ // copy of the graph with modified inputs
638
793
  struct wsp_ggml_cgraph * graph;
794
+
639
795
  struct wsp_ggml_backend_sched_split splits[WSP_GGML_MAX_SPLITS];
640
796
  int n_splits;
641
797
 
@@ -648,6 +804,9 @@ struct wsp_ggml_backend_sched {
648
804
  __attribute__((aligned(WSP_GGML_MEM_ALIGN)))
649
805
  #endif
650
806
  char context_buffer[WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS*sizeof(struct wsp_ggml_tensor) + sizeof(struct wsp_ggml_cgraph)];
807
+
808
+ wsp_ggml_backend_sched_eval_callback callback_eval;
809
+ void * callback_eval_user_data;
651
810
  };
652
811
 
653
812
  #define hash_id(node) wsp_ggml_hash_find_or_insert(sched->hash_set, node)
@@ -676,14 +835,22 @@ static int sched_allocr_prio(wsp_ggml_backend_sched_t sched, wsp_ggml_tallocr_t
676
835
  return INT_MAX;
677
836
  }
678
837
 
679
- static wsp_ggml_backend_t get_buffer_backend(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_buffer_t buffer) {
838
+ static wsp_ggml_tallocr_t sched_allocr_from_buffer(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_buffer_t buffer) {
680
839
  if (buffer == NULL) {
681
840
  return NULL;
682
841
  }
842
+
843
+ // check if this is already allocate in a allocr buffer (from user manual allocations)
844
+ for (int i = 0; i < sched->n_backends; i++) {
845
+ if (wsp_ggml_tallocr_get_buffer(sched->tallocs[i]) == buffer) {
846
+ return sched->tallocs[i];
847
+ }
848
+ }
849
+
683
850
  // find highest prio backend that supports the buffer type
684
851
  for (int i = 0; i < sched->n_backends; i++) {
685
852
  if (wsp_ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) {
686
- return sched->backends[i];
853
+ return sched->tallocs[i];
687
854
  }
688
855
  }
689
856
  WSP_GGML_ASSERT(false && "tensor buffer type not supported by any backend");
@@ -693,7 +860,6 @@ static wsp_ggml_backend_t get_allocr_backend(wsp_ggml_backend_sched_t sched, wsp
693
860
  if (allocr == NULL) {
694
861
  return NULL;
695
862
  }
696
- // find highest prio backend that supports the buffer type
697
863
  for (int i = 0; i < sched->n_backends; i++) {
698
864
  if (sched->tallocs[i] == allocr) {
699
865
  return sched->backends[i];
@@ -703,7 +869,7 @@ static wsp_ggml_backend_t get_allocr_backend(wsp_ggml_backend_sched_t sched, wsp
703
869
  }
704
870
 
705
871
  #if 0
706
- static char causes[WSP_GGML_DEFAULT_GRAPH_SIZE*8 + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
872
+ static char causes[WSP_GGML_DEFAULT_GRAPH_SIZE*16 + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS][128]; // debug only
707
873
  #define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
708
874
  #define GET_CAUSE(node) causes[hash_id(node)]
709
875
  #else
@@ -712,45 +878,37 @@ static char causes[WSP_GGML_DEFAULT_GRAPH_SIZE*8 + WSP_GGML_MAX_SPLITS*WSP_GGML_
712
878
  #endif
713
879
 
714
880
  // returns the backend that should be used for the node based on the current locations
715
- static wsp_ggml_backend_t sched_backend_from_cur(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node) {
716
- // if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there
717
- // ie. kv cache updates
718
- // note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend.
881
+ static wsp_ggml_tallocr_t sched_allocr_from_cur(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node) {
882
+ // assign pre-allocated nodes to their backend
719
883
  // dst
720
- wsp_ggml_backend_t cur_backend = get_buffer_backend(sched, node->buffer);
721
- if (cur_backend != NULL) {
884
+ wsp_ggml_tallocr_t cur_allocr = sched_allocr_from_buffer(sched, node->buffer);
885
+ if (cur_allocr != NULL) {
722
886
  SET_CAUSE(node, "1.dst");
723
- return cur_backend;
887
+ return cur_allocr;
724
888
  }
725
-
726
889
  // view_src
727
- if (node->view_src != NULL && get_buffer_backend(sched, node->view_src->buffer) != NULL) {
728
- SET_CAUSE(node, "1.vsrc");
729
- return get_buffer_backend(sched, node->view_src->buffer);
890
+ if (node->view_src != NULL) {
891
+ cur_allocr = sched_allocr_from_buffer(sched, node->view_src->buffer);
892
+ if (cur_allocr != NULL) {
893
+ SET_CAUSE(node, "1.vsrc");
894
+ return cur_allocr;
895
+ }
730
896
  }
731
-
732
- // src
733
- int cur_prio = INT_MAX;
734
- size_t cur_size = 0;
735
-
897
+ // assign nodes that use weights to the backend of the weights
736
898
  for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
737
899
  const struct wsp_ggml_tensor * src = node->src[i];
738
900
  if (src == NULL) {
739
901
  break;
740
902
  }
741
- wsp_ggml_backend_t src_backend = get_buffer_backend(sched, src->buffer);
742
- if (src_backend != NULL) {
743
- int src_prio = sched_backend_prio(sched, src_backend);
744
- size_t src_size = wsp_ggml_nbytes(src);
745
- if (src_prio < cur_prio && src_size >= cur_size) {
746
- cur_prio = src_prio;
747
- cur_size = src_size;
748
- cur_backend = src_backend;
749
- SET_CAUSE(node, "1.src%d", i);
750
- }
903
+ if (src->buffer != NULL && src->buffer->usage == WSP_GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
904
+ wsp_ggml_tallocr_t src_allocr = sched_allocr_from_buffer(sched, src->buffer);
905
+ // operations with weights are always run on the same backend as the weights
906
+ SET_CAUSE(node, "1.wgt%d", i);
907
+ return src_allocr;
751
908
  }
752
909
  }
753
- return cur_backend;
910
+
911
+ return NULL;
754
912
  }
755
913
 
756
914
  static char * fmt_size(size_t size) {
@@ -783,7 +941,7 @@ static void sched_print_assignments(wsp_ggml_backend_sched_t sched, struct wsp_g
783
941
  }
784
942
  wsp_ggml_tallocr_t node_allocr = node_allocr(node);
785
943
  wsp_ggml_backend_t node_backend = node_allocr ? get_allocr_backend(sched, node_allocr) : NULL; // FIXME:
786
- fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, wsp_ggml_op_name(node->op), node->name,
944
+ fprintf(stderr, "node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, wsp_ggml_op_name(node->op), node->name,
787
945
  fmt_size(wsp_ggml_nbytes(node)), node_allocr ? wsp_ggml_backend_name(node_backend) : "NULL", GET_CAUSE(node));
788
946
  for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
789
947
  struct wsp_ggml_tensor * src = node->src[j];
@@ -792,7 +950,7 @@ static void sched_print_assignments(wsp_ggml_backend_sched_t sched, struct wsp_g
792
950
  }
793
951
  wsp_ggml_tallocr_t src_allocr = node_allocr(src);
794
952
  wsp_ggml_backend_t src_backend = src_allocr ? get_allocr_backend(sched, src_allocr) : NULL;
795
- fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name,
953
+ fprintf(stderr, " %20.20s (%5.5s) [%5.5s %8.8s]", src->name,
796
954
  fmt_size(wsp_ggml_nbytes(src)), src_backend ? wsp_ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src));
797
955
  }
798
956
  fprintf(stderr, "\n");
@@ -808,15 +966,17 @@ static struct wsp_ggml_tensor * wsp_ggml_dup_tensor_layout(struct wsp_ggml_conte
808
966
  return dup;
809
967
  }
810
968
 
969
+
970
+ //#define DEBUG_PASS1
971
+ //#define DEBUG_PASS2
972
+ //#define DEBUG_PASS3
973
+ //#define DEBUG_PASS4
974
+
811
975
  // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
812
- // TODO: merge passes
813
976
  static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph) {
814
- // reset state
815
- size_t hash_size = sched->hash_set.size;
816
- memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size);
817
- memset(sched->node_talloc, 0, sizeof(sched->node_talloc[0]) * hash_size);
818
- memset(sched->node_copies, 0, sizeof(sched->node_copies[0]) * hash_size);
977
+ // reset splits
819
978
  sched->n_splits = 0;
979
+ sched->is_reset = false;
820
980
 
821
981
  struct wsp_ggml_init_params params = {
822
982
  /* .mem_size = */ sizeof(sched->context_buffer),
@@ -824,26 +984,22 @@ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cg
824
984
  /* .no_alloc = */ true
825
985
  };
826
986
 
827
- if (sched->ctx != NULL) {
828
- wsp_ggml_free(sched->ctx);
829
- }
987
+ wsp_ggml_free(sched->ctx);
830
988
 
831
989
  sched->ctx = wsp_ggml_init(params);
990
+ if (sched->ctx == NULL) {
991
+ fprintf(stderr, "%s: failed to initialize context\n", __func__);
992
+ WSP_GGML_ASSERT(false);
993
+ }
832
994
 
833
- // pass 1: assign backends to ops with allocated inputs
995
+ // pass 1: assign backends to ops with pre-allocated inputs
834
996
  for (int i = 0; i < graph->n_leafs; i++) {
835
997
  struct wsp_ggml_tensor * leaf = graph->leafs[i];
836
998
  if (node_allocr(leaf) != NULL) {
837
999
  // do not overwrite user assignments
838
1000
  continue;
839
1001
  }
840
- wsp_ggml_backend_t leaf_backend = get_buffer_backend(sched, leaf->buffer);
841
- if (leaf_backend == NULL && leaf->view_src != NULL) {
842
- leaf_backend = get_buffer_backend(sched, leaf->view_src->buffer);
843
- }
844
- if (leaf_backend != NULL) {
845
- node_allocr(leaf) = wsp_ggml_backend_sched_get_tallocr(sched, leaf_backend);
846
- }
1002
+ node_allocr(leaf) = sched_allocr_from_cur(sched, leaf);
847
1003
  }
848
1004
 
849
1005
  for (int i = 0; i < graph->n_nodes; i++) {
@@ -852,50 +1008,120 @@ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cg
852
1008
  // do not overwrite user assignments
853
1009
  continue;
854
1010
  }
855
- wsp_ggml_backend_t node_backend = sched_backend_from_cur(sched, node);
856
- if (node_backend != NULL) {
857
- node_allocr(node) = wsp_ggml_backend_sched_get_tallocr(sched, node_backend);
1011
+ node_allocr(node) = sched_allocr_from_cur(sched, node);
1012
+ // src
1013
+ for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
1014
+ struct wsp_ggml_tensor * src = node->src[j];
1015
+ if (src == NULL) {
1016
+ break;
1017
+ }
1018
+ if (node_allocr(src) == NULL) {
1019
+ node_allocr(src) = sched_allocr_from_cur(sched, src);
1020
+ }
858
1021
  }
859
1022
  }
860
- //printf("PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
1023
+ #ifdef DEBUG_PASS1
1024
+ fprintf(stderr, "PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
1025
+ #endif
861
1026
 
862
- // pass 2: assign backends to ops from current assignments
863
- // TODO:
864
- // - reuse sched_backend_from_cur
865
- for (int i = 0; i < graph->n_nodes; i++) {
866
- struct wsp_ggml_tensor * node = graph->nodes[i];
867
- wsp_ggml_tallocr_t node_allocr = node_allocr(node);
868
- if (node_allocr == NULL) {
869
- int cur_prio = INT_MAX;
870
- size_t cur_size = 0;
871
- for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
872
- struct wsp_ggml_tensor * src = node->src[j];
873
- if (src == NULL) {
874
- break;
1027
+ // pass 2: expand current backend assignments
1028
+ // assign the same backend to adjacent nodes
1029
+ // expand gpu backends (i.e. non last prio) up and down, ignoring cpu (the lowest priority backend)
1030
+ // thus, cpu will never be used unless weights are on cpu, or there are no gpu ops between cpu ops
1031
+
1032
+ // pass 2.1 expand gpu up
1033
+ {
1034
+ wsp_ggml_tallocr_t cur_allocr = NULL;
1035
+ for (int i = graph->n_nodes - 1; i >= 0; i--) {
1036
+ struct wsp_ggml_tensor * node = graph->nodes[i];
1037
+ if (wsp_ggml_is_view_op(node->op)) {
1038
+ continue;
1039
+ }
1040
+ wsp_ggml_tallocr_t node_allocr = node_allocr(node);
1041
+ if (node_allocr != NULL) {
1042
+ if (sched_allocr_prio(sched, node_allocr) == sched->n_backends - 1) {
1043
+ // skip cpu (lowest prio backend)
1044
+ cur_allocr = NULL;
1045
+ } else {
1046
+ cur_allocr = node_allocr;
875
1047
  }
876
- wsp_ggml_tallocr_t src_allocr = node_allocr(src);
877
- if (src_allocr != NULL) {
878
- int src_prio = sched_allocr_prio(sched, src_allocr);
879
- size_t src_size = wsp_ggml_nbytes(src);
880
- if (src_prio < cur_prio && src_size >= cur_size) {
881
- cur_prio = src_prio;
882
- cur_size = src_size;
883
- node_allocr = src_allocr;
884
- SET_CAUSE(node, "2.src%d", j);
885
- }
1048
+ } else {
1049
+ node_allocr(node) = cur_allocr;
1050
+ SET_CAUSE(node, "2.1");
1051
+ }
1052
+ }
1053
+ }
1054
+
1055
+ // pass 2.2 expand gpu down
1056
+ {
1057
+ wsp_ggml_tallocr_t cur_allocr = NULL;
1058
+ for (int i = 0; i < graph->n_nodes; i++) {
1059
+ struct wsp_ggml_tensor * node = graph->nodes[i];
1060
+ if (wsp_ggml_is_view_op(node->op)) {
1061
+ continue;
1062
+ }
1063
+ wsp_ggml_tallocr_t node_allocr = node_allocr(node);
1064
+ if (node_allocr != NULL) {
1065
+ if (sched_allocr_prio(sched, node_allocr) == sched->n_backends - 1) {
1066
+ // skip cpu (lowest prio backend)
1067
+ cur_allocr = NULL;
1068
+ } else {
1069
+ cur_allocr = node_allocr;
886
1070
  }
1071
+ } else {
1072
+ node_allocr(node) = cur_allocr;
1073
+ SET_CAUSE(node, "2.2");
1074
+ }
1075
+ }
1076
+ }
1077
+
1078
+ // pass 2.3 expand rest up
1079
+ {
1080
+ wsp_ggml_tallocr_t cur_allocr = NULL;
1081
+ for (int i = graph->n_nodes - 1; i >= 0; i--) {
1082
+ struct wsp_ggml_tensor * node = graph->nodes[i];
1083
+ if (wsp_ggml_is_view_op(node->op)) {
1084
+ continue;
1085
+ }
1086
+ wsp_ggml_tallocr_t node_allocr = node_allocr(node);
1087
+ if (node_allocr != NULL) {
1088
+ cur_allocr = node_allocr;
1089
+ } else {
1090
+ node_allocr(node) = cur_allocr;
1091
+ SET_CAUSE(node, "2.3");
887
1092
  }
1093
+ }
1094
+ }
1095
+
1096
+ // pass 2.4 expand rest down
1097
+ {
1098
+ wsp_ggml_tallocr_t cur_allocr = NULL;
1099
+ for (int i = 0; i < graph->n_nodes; i++) {
1100
+ struct wsp_ggml_tensor * node = graph->nodes[i];
1101
+ if (wsp_ggml_is_view_op(node->op)) {
1102
+ continue;
1103
+ }
1104
+ wsp_ggml_tallocr_t node_allocr = node_allocr(node);
888
1105
  if (node_allocr != NULL) {
889
- node_allocr(node) = node_allocr;
1106
+ cur_allocr = node_allocr;
1107
+ } else {
1108
+ node_allocr(node) = cur_allocr;
1109
+ SET_CAUSE(node, "2.4");
890
1110
  }
891
1111
  }
892
1112
  }
893
- //printf("PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
1113
+ #ifdef DEBUG_PASS2
1114
+ fprintf(stderr, "PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
1115
+ #endif
894
1116
 
895
- // pass 3: assign backends to remaining src from dst (should only be leafs)
1117
+ // pass 3: assign backends to remaining src from dst and view_src
896
1118
  for (int i = 0; i < graph->n_nodes; i++) {
897
1119
  struct wsp_ggml_tensor * node = graph->nodes[i];
898
- wsp_ggml_tallocr_t node_allocr = node_allocr(node);
1120
+ wsp_ggml_tallocr_t cur_allocr = node_allocr(node);
1121
+ if (node->view_src != NULL && cur_allocr == NULL) {
1122
+ cur_allocr = node_allocr(node) = node_allocr(node->view_src);
1123
+ SET_CAUSE(node, "3.vsrc");
1124
+ }
899
1125
  for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
900
1126
  struct wsp_ggml_tensor * src = node->src[j];
901
1127
  if (src == NULL) {
@@ -903,81 +1129,107 @@ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cg
903
1129
  }
904
1130
  wsp_ggml_tallocr_t src_allocr = node_allocr(src);
905
1131
  if (src_allocr == NULL) {
906
- node_allocr(src) = node_allocr;
1132
+ if (src->view_src != NULL) {
1133
+ // views are always on the same backend as the source
1134
+ node_allocr(src) = node_allocr(src->view_src);
1135
+ SET_CAUSE(src, "3.vsrc");
1136
+ } else {
1137
+ node_allocr(src) = cur_allocr;
1138
+ SET_CAUSE(src, "3.cur");
1139
+ }
907
1140
  }
908
1141
  }
909
1142
  }
910
- //printf("PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
1143
+ #ifdef DEBUG_PASS3
1144
+ fprintf(stderr, "PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
1145
+ #endif
911
1146
 
912
1147
  // pass 4: split graph, find tensors that need to be copied
913
- // TODO:
914
- // - when switching from a less preferred backend to a more preferred backend, check if it is possible to move the switch to an earlier point for the same cost
915
- // find first backend
916
- int cur_split = 0;
917
- for (int i = 0; i < graph->n_nodes; i++) {
918
- struct wsp_ggml_tensor * node = graph->nodes[i];
919
- if (node->view_src == NULL) {
920
- sched->splits[0].tallocr = node_allocr(node);
921
- break;
922
- }
923
- }
924
- sched->splits[0].i_start = 0;
925
- sched->splits[0].n_inputs = 0;
926
- memset(sched->splits[0].inputs, 0, sizeof(sched->splits[0].inputs)); //HACK
927
- wsp_ggml_tallocr_t cur_allocr = sched->splits[0].tallocr;
928
- size_t cur_backend_id = sched_allocr_prio(sched, cur_allocr);
929
- for (int i = 0; i < graph->n_nodes; i++) {
930
- struct wsp_ggml_tensor * node = graph->nodes[i];
931
-
932
- if (wsp_ggml_is_view_op(node->op)) {
933
- continue;
1148
+ {
1149
+ int cur_split = 0;
1150
+ // find the backend of the first split, skipping view ops
1151
+ for (int i = 0; i < graph->n_nodes; i++) {
1152
+ struct wsp_ggml_tensor * node = graph->nodes[i];
1153
+ if (!wsp_ggml_is_view_op(node->op)) {
1154
+ sched->splits[0].tallocr = node_allocr(node);
1155
+ break;
1156
+ }
934
1157
  }
1158
+ sched->splits[0].i_start = 0;
1159
+ sched->splits[0].n_inputs = 0;
1160
+ memset(sched->splits[0].inputs, 0, sizeof(sched->splits[0].inputs)); //HACK
1161
+ wsp_ggml_tallocr_t cur_allocr = sched->splits[0].tallocr;
1162
+ size_t cur_backend_id = sched_allocr_prio(sched, cur_allocr);
1163
+ for (int i = 0; i < graph->n_nodes; i++) {
1164
+ struct wsp_ggml_tensor * node = graph->nodes[i];
1165
+
1166
+ if (wsp_ggml_is_view_op(node->op)) {
1167
+ continue;
1168
+ }
935
1169
 
936
- wsp_ggml_tallocr_t node_allocr = node_allocr(node);
1170
+ wsp_ggml_tallocr_t node_allocr = node_allocr(node);
937
1171
 
938
- if (node_allocr != cur_allocr) {
939
- sched->splits[cur_split].i_end = i;
940
- cur_split++;
941
- WSP_GGML_ASSERT(cur_split < WSP_GGML_MAX_SPLITS);
942
- sched->splits[cur_split].tallocr = node_allocr;
943
- sched->splits[cur_split].i_start = i;
944
- sched->splits[cur_split].n_inputs = 0;
945
- memset(sched->splits[cur_split].inputs, 0, sizeof(sched->splits[cur_split].inputs)); //HACK
946
- cur_allocr = node_allocr;
947
- cur_backend_id = sched_allocr_prio(sched, cur_allocr);
948
- }
1172
+ WSP_GGML_ASSERT(node_allocr != NULL); // all nodes should be assigned by now
949
1173
 
950
- // find inputs that are not on the same backend
951
- for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
952
- struct wsp_ggml_tensor * src = node->src[j];
953
- if (src == NULL) {
954
- break;
1174
+ if (node_allocr != cur_allocr) {
1175
+ sched->splits[cur_split].i_end = i;
1176
+ cur_split++;
1177
+ WSP_GGML_ASSERT(cur_split < WSP_GGML_MAX_SPLITS);
1178
+ sched->splits[cur_split].tallocr = node_allocr;
1179
+ sched->splits[cur_split].i_start = i;
1180
+ sched->splits[cur_split].n_inputs = 0;
1181
+ cur_allocr = node_allocr;
1182
+ cur_backend_id = sched_allocr_prio(sched, cur_allocr);
955
1183
  }
956
- wsp_ggml_tallocr_t src_allocr = node_allocr(src);
957
- if (src_allocr != node_allocr) {
958
- int n_inputs = sched->splits[cur_split].n_inputs++;
959
- WSP_GGML_ASSERT(n_inputs < WSP_GGML_MAX_SPLIT_INPUTS);
960
- sched->splits[cur_split].inputs[n_inputs] = (struct wsp_ggml_tensor *)src;
961
-
962
- // create copies
963
- size_t id = hash_id(src);
964
- if (sched->node_copies[id][cur_backend_id] == NULL) {
965
- struct wsp_ggml_tensor * tensor_copy = wsp_ggml_dup_tensor_layout(sched->ctx, src);
966
- sched->node_copies[id][cur_backend_id] = tensor_copy;
967
- node_allocr(tensor_copy) = cur_allocr;
968
- wsp_ggml_backend_t backend = get_allocr_backend(sched, cur_allocr);
969
- wsp_ggml_format_name(tensor_copy, "%s#%s", wsp_ggml_backend_name(backend), src->name);
1184
+
1185
+ // find inputs that are not on the same backend
1186
+ for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
1187
+ struct wsp_ggml_tensor * src = node->src[j];
1188
+ if (src == NULL) {
1189
+ break;
1190
+ }
1191
+ wsp_ggml_tallocr_t src_allocr = node_allocr(src);
1192
+ WSP_GGML_ASSERT(src_allocr != NULL); // all inputs should be assigned by now
1193
+ if (src_allocr != node_allocr) {
1194
+ // check if the input is already in the split
1195
+ bool found = false;
1196
+ for (int k = 0; k < sched->splits[cur_split].n_inputs; k++) {
1197
+ if (sched->splits[cur_split].inputs[k] == src) {
1198
+ found = true;
1199
+ break;
1200
+ }
1201
+ }
1202
+
1203
+ if (!found) {
1204
+ int n_inputs = sched->splits[cur_split].n_inputs++;
1205
+ //printf("split %d input %d: %s (%s)\n", cur_split, n_inputs, src->name, wsp_ggml_backend_name(get_allocr_backend(sched, src_allocr)));
1206
+ WSP_GGML_ASSERT(n_inputs < WSP_GGML_MAX_SPLIT_INPUTS);
1207
+ sched->splits[cur_split].inputs[n_inputs] = src;
1208
+ }
1209
+
1210
+ // create a copy of the input in the split's backend
1211
+ size_t id = hash_id(src);
1212
+ if (sched->node_copies[id][cur_backend_id] == NULL) {
1213
+ wsp_ggml_backend_t backend = get_allocr_backend(sched, cur_allocr);
1214
+ struct wsp_ggml_tensor * tensor_copy = wsp_ggml_dup_tensor_layout(sched->ctx, src);
1215
+ wsp_ggml_format_name(tensor_copy, "%s#%s", wsp_ggml_backend_name(backend), src->name);
1216
+
1217
+ sched->node_copies[id][cur_backend_id] = tensor_copy;
1218
+ node_allocr(tensor_copy) = cur_allocr;
1219
+ SET_CAUSE(tensor_copy, "4.cpy");
1220
+ }
1221
+ node->src[j] = sched->node_copies[id][cur_backend_id];
970
1222
  }
971
- node->src[j] = sched->node_copies[id][cur_backend_id];
972
1223
  }
973
1224
  }
1225
+ sched->splits[cur_split].i_end = graph->n_nodes;
1226
+ sched->n_splits = cur_split + 1;
974
1227
  }
975
- sched->splits[cur_split].i_end = graph->n_nodes;
976
- sched->n_splits = cur_split + 1;
977
-
978
- //fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); fflush(stdout);
1228
+ #ifdef DEBUG_PASS4
1229
+ fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
1230
+ #endif
979
1231
 
980
- #if 1
1232
+ #ifndef NDEBUG
981
1233
  // sanity check: all sources should have the same backend as the node
982
1234
  for (int i = 0; i < graph->n_nodes; i++) {
983
1235
  struct wsp_ggml_tensor * node = graph->nodes[i];
@@ -985,6 +1237,11 @@ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cg
  if (node_allocr == NULL) {
  fprintf(stderr, "!!!!!!! %s has no backend\n", node->name);
  }
+ if (node->view_src != NULL && node_allocr != node_allocr(node->view_src)) {
+ fprintf(stderr, "!!!!!!! %s has backend %s, view_src %s has backend %s\n",
+ node->name, node_allocr ? wsp_ggml_backend_name(get_allocr_backend(sched, node_allocr)) : "NULL",
+ node->view_src->name, node_allocr(node->view_src) ? wsp_ggml_backend_name(get_allocr_backend(sched, node_allocr(node->view_src))) : "NULL");
+ }
  for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
  struct wsp_ggml_tensor * src = node->src[j];
  if (src == NULL) {
@@ -996,8 +1253,14 @@ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cg
  node->name, node_allocr ? wsp_ggml_backend_name(get_allocr_backend(sched, node_allocr)) : "NULL",
  j, src->name, src_allocr ? wsp_ggml_backend_name(get_allocr_backend(sched, src_allocr)) : "NULL");
  }
+ if (src->view_src != NULL && src_allocr != node_allocr(src->view_src)) {
+ fprintf(stderr, "!!!!!!! [src] %s has backend %s, view_src %s has backend %s\n",
+ src->name, src_allocr ? wsp_ggml_backend_name(get_allocr_backend(sched, src_allocr)) : "NULL",
+ src->view_src->name, node_allocr(src->view_src) ? wsp_ggml_backend_name(get_allocr_backend(sched, node_allocr(src->view_src))) : "NULL");
+ }
  }
  }
+ fflush(stderr);
  #endif

  // create copies of the graph for each split
@@ -1011,6 +1274,8 @@ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cg
  for (int j = 0; j < split->n_inputs; j++) {
  struct wsp_ggml_tensor * input = split->inputs[j];
  struct wsp_ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_allocr_prio(sched, split->tallocr)];
+ // add a dependency to the input source so that it is not freed before the copy is done
+ WSP_GGML_ASSERT(input_cpy->src[0] == NULL || input_cpy->src[0] == input);
  input_cpy->src[0] = input;
  graph_copy->nodes[graph_copy->n_nodes++] = input_cpy;
  }
@@ -1045,24 +1310,16 @@ static void sched_compute_splits(wsp_ggml_backend_sched_t sched) {
  uint64_t copy_start_us = wsp_ggml_time_us();
  for (int j = 0; j < split->n_inputs; j++) {
  struct wsp_ggml_tensor * input = split->inputs[j];
- struct wsp_ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_backend_prio(sched, split_backend)];
- if (input->buffer == NULL) {
- if (input->view_src == NULL) {
- fprintf(stderr, "input %s has no buffer and no view_src\n", input->name);
- exit(1);
- }
- // FIXME: may need to use the sched buffer instead
- wsp_ggml_backend_view_init(input->view_src->buffer, input);
- }
- if (input_cpy->buffer == NULL) {
- fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name);
- exit(1);
- }
- //WSP_GGML_ASSERT(input->buffer->backend != input_cpy->buffer->backend);
- //WSP_GGML_ASSERT(input_cpy->buffer->backend == split_backend);
- wsp_ggml_backend_tensor_copy(input, input_cpy);
+ struct wsp_ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][split_backend_id];
+
+ WSP_GGML_ASSERT(input->buffer != NULL);
+ WSP_GGML_ASSERT(input_cpy->buffer != NULL);
+
+ // TODO: avoid this copy if it was already copied in a previous split, and the input didn't change
+ // this is important to avoid copying constants such as KQ_mask and inp_pos multiple times
+ wsp_ggml_backend_tensor_copy_async(split_backend, input, input_cpy);
  }
- // wsp_ggml_backend_synchronize(split_backend);
+ //wsp_ggml_backend_synchronize(split_backend); // necessary to measure copy time
  int64_t copy_end_us = wsp_ggml_time_us();
  copy_us[split_backend_id] += copy_end_us - copy_start_us;

@@ -1072,9 +1329,38 @@ static void sched_compute_splits(wsp_ggml_backend_sched_t sched) {
  wsp_ggml_graph_dump_dot(split->graph, NULL, split_filename);
  #endif

+
  uint64_t compute_start_us = wsp_ggml_time_us();
- wsp_ggml_backend_graph_compute(split_backend, &split->graph);
- // wsp_ggml_backend_synchronize(split_backend);
+ if (!sched->callback_eval) {
+ wsp_ggml_backend_graph_compute(split_backend, &split->graph);
+ //wsp_ggml_backend_synchronize(split_backend); // necessary to measure compute time
+ } else {
+ // similar to wsp_ggml_backend_compare_graph_backend
+ for (int j0 = 0; j0 < split->graph.n_nodes; j0++) {
+ struct wsp_ggml_tensor * t = split->graph.nodes[j0];
+
+ // check if the user needs data from this node
+ bool need = sched->callback_eval(t, true, sched->callback_eval_user_data);
+
+ int j1 = j0;
+
+ // determine the range [j0, j1] of nodes that can be computed together
+ while (!need && j1 < split->graph.n_nodes - 1) {
+ t = split->graph.nodes[++j1];
+ need = sched->callback_eval(t, true, sched->callback_eval_user_data);
+ }
+
+ struct wsp_ggml_cgraph gv = wsp_ggml_graph_view(&split->graph, j0, j1 + 1);
+
+ wsp_ggml_backend_graph_compute(split_backend, &gv);
+
+ if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {
+ break;
+ }
+
+ j0 = j1;
+ }
+ }
  uint64_t compute_end_us = wsp_ggml_time_us();
  compute_us[split_backend_id] += compute_end_us - compute_start_us;
  }
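The else-branch above introduces an optional per-node eval callback in sched_compute_splits: for every node the scheduler first asks the callback whether its data is needed (second argument true), batches the consecutive "not needed" nodes into a single wsp_ggml_graph_view range, computes that range, and then calls the callback again with the second argument false so the computed tensor can be inspected; returning false at that point stops the rest of the split. A minimal sketch of a callback matching this calling convention follows; the helper name, the "logits" name filter, and the include path for this package's prefixed ggml headers are illustrative assumptions, not part of the diff.

    #include <stdbool.h>
    #include <stdio.h>
    #include <string.h>
    #include "ggml-backend.h" // assumed header name for the wsp_-prefixed ggml shipped with this package

    // Sketch of a scheduler eval callback: (tensor, ask, user_data) -> bool, as used above.
    static bool my_eval_cb(struct wsp_ggml_tensor * t, bool ask, void * user_data) {
        (void) user_data;
        if (ask) {
            // "ask" phase: only request results for tensors whose name starts with "logits" (illustrative filter)
            return strncmp(t->name, "logits", 6) == 0;
        }
        // "result" phase: t has been computed; returning false would abort the rest of the split
        fprintf(stderr, "computed %s\n", t->name);
        return true;
    }

The callback is registered with wsp_ggml_backend_sched_set_eval_callback (added further below); when no callback is set, the split is still computed in a single wsp_ggml_backend_graph_compute call as before.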
@@ -1094,26 +1380,41 @@ static void sched_reset(wsp_ggml_backend_sched_t sched) {
  for (int i = 0; i < sched->n_backends; i++) {
  wsp_ggml_tallocr_reset(sched->tallocs[i]);
  }
+ // reset state for the next run
+ size_t hash_size = sched->hash_set.size;
+ memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size);
+ memset(sched->node_talloc, 0, sizeof(sched->node_talloc[0]) * hash_size);
+ memset(sched->node_copies, 0, sizeof(sched->node_copies[0]) * hash_size);
+
+ sched->is_reset = true;
  }

- wsp_ggml_backend_sched_t wsp_ggml_backend_sched_new(wsp_ggml_backend_t * backends, int n_backends) {
+ wsp_ggml_backend_sched_t wsp_ggml_backend_sched_new(wsp_ggml_backend_t * backends, wsp_ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size) {
+ WSP_GGML_ASSERT(n_backends > 0);
  WSP_GGML_ASSERT(n_backends <= WSP_GGML_MAX_BACKENDS);

- struct wsp_ggml_backend_sched * sched = malloc(sizeof(struct wsp_ggml_backend_sched));
- memset(sched, 0, sizeof(struct wsp_ggml_backend_sched));
+ struct wsp_ggml_backend_sched * sched = calloc(sizeof(struct wsp_ggml_backend_sched), 1);
+
+ // initialize hash table
+ sched->hash_set = wsp_ggml_hash_set_new(graph_size + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS);
+ sched->node_talloc = calloc(sizeof(sched->node_talloc[0]) * sched->hash_set.size, 1);
+ sched->node_copies = calloc(sizeof(sched->node_copies[0]) * sched->hash_set.size, 1);

  sched->n_backends = n_backends;
  for (int i = 0; i < n_backends; i++) {
  sched->backends[i] = backends[i];
+ sched->bufts[i] = bufts ? bufts[i] : wsp_ggml_backend_get_default_buffer_type(backends[i]);
  }

  sched->galloc = wsp_ggml_gallocr_new();

  // init measure allocs for each backend
  for (int i = 0; i < n_backends; i++) {
- sched->tallocs[i] = wsp_ggml_tallocr_new_measure_from_backend(backends[i]);
+ sched->tallocs[i] = wsp_ggml_tallocr_new_measure_from_buft(sched->bufts[i]);
  }

+ sched_reset(sched);
+
  return sched;
  }
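wsp_ggml_backend_sched_new now takes an optional array of buffer types and an expected graph size, sizes the hash tables up front, and resets the scheduler before returning. A minimal construction sketch under the new signature; the backends array is assumed to be initialized by the caller, the helper name is illustrative, and passing NULL for bufts falls back to each backend's default buffer type, as the loop above shows.

    #include "ggml-backend.h" // assumed header name, as in the earlier sketch

    // Sketch: create a scheduler sized for graphs of up to `graph_size` nodes.
    static wsp_ggml_backend_sched_t make_sched(wsp_ggml_backend_t * backends, int n_backends, size_t graph_size) {
        // NULL bufts -> wsp_ggml_backend_get_default_buffer_type() is used for each backend
        return wsp_ggml_backend_sched_new(backends, /*bufts=*/NULL, n_backends, graph_size);
    }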
@@ -1125,6 +1426,7 @@ void wsp_ggml_backend_sched_free(wsp_ggml_backend_sched_t sched) {
  wsp_ggml_tallocr_free(sched->tallocs[i]);
  }
  wsp_ggml_gallocr_free(sched->galloc);
+ wsp_ggml_free(sched->ctx);
  free(sched->hash_set.keys);
  free(sched->node_talloc);
  free(sched->node_copies);
@@ -1132,12 +1434,7 @@ void wsp_ggml_backend_sched_free(wsp_ggml_backend_sched_t sched) {
  }

  void wsp_ggml_backend_sched_init_measure(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * measure_graph) {
- // initialize hash tables
- size_t hash_size = measure_graph->visited_hash_table.size + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS;
- sched->hash_set.size = hash_size;
- sched->hash_set.keys = malloc(sizeof(sched->hash_set.keys[0]) * hash_size);
- sched->node_talloc = malloc(sizeof(sched->node_talloc[0]) * hash_size);
- sched->node_copies = malloc(sizeof(sched->node_copies[0]) * hash_size);
+ WSP_GGML_ASSERT(wsp_ggml_tallocr_is_measure(sched->tallocs[0])); // can only be initialized once

  sched_split_graph(sched, measure_graph);
  sched_alloc_splits(sched);
@@ -1146,28 +1443,47 @@ void wsp_ggml_backend_sched_init_measure(wsp_ggml_backend_sched_t sched, struct
  for (int i = 0; i < sched->n_backends; i++) {
  size_t size = wsp_ggml_tallocr_max_size(sched->tallocs[i]);
  wsp_ggml_tallocr_free(sched->tallocs[i]);
- sched->tallocs[i] = wsp_ggml_tallocr_new_from_backend(sched->backends[i], size);
+ sched->tallocs[i] = wsp_ggml_tallocr_new_from_buft(sched->bufts[i], size);
  }

  sched_reset(sched);
  }

  void wsp_ggml_backend_sched_graph_compute(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph) {
- WSP_GGML_ASSERT(sched->hash_set.size >= graph->visited_hash_table.size + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS);
+ WSP_GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS);
+
+ if (!sched->is_reset) {
+ sched_reset(sched);
+ }

  sched_split_graph(sched, graph);
  sched_alloc_splits(sched);
  sched_compute_splits(sched);
+ }
+
+ void wsp_ggml_backend_sched_reset(wsp_ggml_backend_sched_t sched) {
  sched_reset(sched);
  }

+
+ void wsp_ggml_backend_sched_set_eval_callback(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_sched_eval_callback callback, void * user_data) {
+ sched->callback_eval = callback;
+ sched->callback_eval_user_data = user_data;
+ }
+
+ int wsp_ggml_backend_sched_get_n_splits(wsp_ggml_backend_sched_t sched) {
+ return sched->n_splits;
+ }
+
  wsp_ggml_tallocr_t wsp_ggml_backend_sched_get_tallocr(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend) {
  int backend_index = sched_backend_prio(sched, backend);
+ WSP_GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
  return sched->tallocs[backend_index];
  }

  wsp_ggml_backend_buffer_t wsp_ggml_backend_sched_get_buffer(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend) {
  int backend_index = sched_backend_prio(sched, backend);
+ WSP_GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
  return wsp_ggml_tallocr_get_buffer(sched->tallocs[backend_index]);
  }
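With the new wsp_ggml_backend_sched_reset, wsp_ggml_backend_sched_set_eval_callback and wsp_ggml_backend_sched_get_n_splits entry points, a typical per-graph flow looks roughly like the sketch below. The helper name is illustrative, the one-time measurement pass via wsp_ggml_backend_sched_init_measure is assumed to have happened beforehand, and the callback parameter is of the wsp_ggml_backend_sched_eval_callback type used above.

    #include <stdio.h>
    #include "ggml-backend.h" // assumed header name, as in the earlier sketches

    // Sketch: run one graph through the scheduler and report how many splits it produced.
    static void run_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph,
                          wsp_ggml_backend_sched_eval_callback cb, void * cb_user_data) {
        // optional: observe intermediate tensors while the splits are computed
        wsp_ggml_backend_sched_set_eval_callback(sched, cb, cb_user_data);

        wsp_ggml_backend_sched_graph_compute(sched, graph);
        fprintf(stderr, "graph executed in %d split(s)\n", wsp_ggml_backend_sched_get_n_splits(sched));

        // graph_compute no longer resets at the end; clear the per-graph state explicitly here,
        // otherwise the next call resets lazily through the is_reset flag
        wsp_ggml_backend_sched_reset(sched);
    }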
@@ -1177,10 +1493,19 @@ void wsp_ggml_backend_sched_set_node_backend(wsp_ggml_backend_sched_t sched, str
  node_allocr(node) = sched->tallocs[backend_index];
  }

+ wsp_ggml_backend_t wsp_ggml_backend_sched_get_node_backend(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node) {
+ wsp_ggml_tallocr_t allocr = node_allocr(node);
+ if (allocr == NULL) {
+ return NULL;
+ }
+ return get_allocr_backend(sched, allocr);
+ }
+
  // utils
+
  void wsp_ggml_backend_view_init(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) {
  WSP_GGML_ASSERT(tensor->buffer == NULL);
- WSP_GGML_ASSERT(tensor->data == NULL);
+ //WSP_GGML_ASSERT(tensor->data == NULL); // views of pre-allocated tensors may have the data set in wsp_ggml_new_tensor, but still need to be initialized by the backend
  WSP_GGML_ASSERT(tensor->view_src != NULL);
  WSP_GGML_ASSERT(tensor->view_src->buffer != NULL);
  WSP_GGML_ASSERT(tensor->view_src->data != NULL);
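The new wsp_ggml_backend_sched_get_node_backend accessor exposes which backend a node was assigned to, returning NULL when it has no assignment. It is useful for checking split decisions, or a pin made with wsp_ggml_backend_sched_set_node_backend, while the assignments are still valid, i.e. before the scheduler is reset. A small sketch with an illustrative helper name:

    #include <stdio.h>
    #include "ggml-backend.h" // assumed header name, as in the earlier sketches

    // Sketch: report the backend the scheduler picked for a given node.
    static void print_node_backend(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node) {
        wsp_ggml_backend_t backend = wsp_ggml_backend_sched_get_node_backend(sched, node);
        fprintf(stderr, "%s -> %s\n", node->name, backend ? wsp_ggml_backend_name(backend) : "(unassigned)");
    }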
@@ -1246,6 +1571,7 @@ static void graph_init_tensor(struct wsp_ggml_hash_set hash_set, struct wsp_ggml

  struct wsp_ggml_tensor * dst = node_copies[id];
  if (dst->view_src != NULL) {
+ graph_init_tensor(hash_set, node_copies, node_init, src->view_src);
  wsp_ggml_backend_view_init(dst->view_src->buffer, dst);
  }
  else {
@@ -1279,6 +1605,21 @@ struct wsp_ggml_backend_graph_copy wsp_ggml_backend_graph_copy(wsp_ggml_backend_
  struct wsp_ggml_context * ctx_allocated = wsp_ggml_init(params);
  struct wsp_ggml_context * ctx_unallocated = wsp_ggml_init(params);

+ if (ctx_allocated == NULL || ctx_unallocated == NULL) {
+ fprintf(stderr, "failed to allocate context for graph copy\n");
+ free(hash_set.keys);
+ free(node_copies);
+ free(node_init);
+ wsp_ggml_free(ctx_allocated);
+ wsp_ggml_free(ctx_unallocated);
+ return (struct wsp_ggml_backend_graph_copy) {
+ /* .buffer = */ NULL,
+ /* .ctx_allocated = */ NULL,
+ /* .ctx_unallocated = */ NULL,
+ /* .graph = */ NULL,
+ };
+ }
+
  // dup nodes
  for (int i = 0; i < graph->n_nodes; i++) {
  struct wsp_ggml_tensor * node = graph->nodes[i];
@@ -1287,6 +1628,20 @@ struct wsp_ggml_backend_graph_copy wsp_ggml_backend_graph_copy(wsp_ggml_backend_

  // allocate nodes
  wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_alloc_ctx_tensors(ctx_allocated, backend);
+ if (buffer == NULL) {
+ fprintf(stderr, "failed to allocate buffer for graph copy\n");
+ free(hash_set.keys);
+ free(node_copies);
+ free(node_init);
+ wsp_ggml_free(ctx_allocated);
+ wsp_ggml_free(ctx_unallocated);
+ return (struct wsp_ggml_backend_graph_copy) {
+ /* .buffer = */ NULL,
+ /* .ctx_allocated = */ NULL,
+ /* .ctx_unallocated = */ NULL,
+ /* .graph = */ NULL,
+ };
+ }

  //printf("copy buffer size: %zu MB\n", wsp_ggml_backend_buffer_get_size(buffer) / 1024 / 1024);
@@ -1323,8 +1678,12 @@ void wsp_ggml_backend_graph_copy_free(struct wsp_ggml_backend_graph_copy copy) {
  wsp_ggml_free(copy.ctx_unallocated);
  }

- void wsp_ggml_backend_compare_graph_backend(wsp_ggml_backend_t backend1, wsp_ggml_backend_t backend2, struct wsp_ggml_cgraph * graph, wsp_ggml_backend_eval_callback callback, void * user_data) {
+ bool wsp_ggml_backend_compare_graph_backend(wsp_ggml_backend_t backend1, wsp_ggml_backend_t backend2, struct wsp_ggml_cgraph * graph, wsp_ggml_backend_eval_callback callback, void * user_data) {
  struct wsp_ggml_backend_graph_copy copy = wsp_ggml_backend_graph_copy(backend2, graph);
+ if (copy.buffer == NULL) {
+ return false;
+ }
+
  struct wsp_ggml_cgraph * g1 = graph;
  struct wsp_ggml_cgraph * g2 = copy.graph;

@@ -1354,4 +1713,6 @@ void wsp_ggml_backend_compare_graph_backend(wsp_ggml_backend_t backend1, wsp_ggm
  }

  wsp_ggml_backend_graph_copy_free(copy);
+
+ return true;
  }
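Since wsp_ggml_backend_compare_graph_backend now returns a bool, false means the comparison was skipped because the internal graph copy could not be allocated, not that the backends disagreed. A thin wrapper sketch; the wrapper name is illustrative, and the callback and user_data are assumed to be prepared by the caller with whatever signature wsp_ggml_backend_eval_callback has in the package headers:

    #include <stdbool.h>
    #include <stdio.h>
    #include "ggml-backend.h" // assumed header name, as in the earlier sketches

    // Sketch: surface allocation failures from the two-backend comparison to the caller.
    static bool compare_or_warn(wsp_ggml_backend_t b1, wsp_ggml_backend_t b2, struct wsp_ggml_cgraph * graph,
                                wsp_ggml_backend_eval_callback callback, void * user_data) {
        if (!wsp_ggml_backend_compare_graph_backend(b1, b2, graph, callback, user_data)) {
            fprintf(stderr, "backend comparison skipped: graph copy allocation failed\n");
            return false;
        }
        return true;
    }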