whisper.rn 0.4.0-rc.10 → 0.4.0-rc.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,25 +1,43 @@
1
1
  #include "ggml-backend-impl.h"
2
2
  #include "ggml-backend.h"
3
- #include "ggml-cpu.h"
4
3
  #include "ggml-impl.h"
4
+ #include <algorithm>
5
+ #include <codecvt>
5
6
  #include <cstring>
7
+ #include <filesystem>
8
+ #include <locale>
9
+ #include <memory>
10
+ #include <string>
11
+ #include <type_traits>
6
12
  #include <vector>
7
13
 
14
+ #ifdef _WIN32
15
+ # define WIN32_LEAN_AND_MEAN
16
+ # ifndef NOMINMAX
17
+ # define NOMINMAX
18
+ # endif
19
+ # include <windows.h>
20
+ #elif defined(__APPLE__)
21
+ # include <mach-o/dyld.h>
22
+ # include <dlfcn.h>
23
+ #else
24
+ # include <dlfcn.h>
25
+ # include <unistd.h>
26
+ #endif
27
+
8
28
  // Backend registry
29
+ #ifdef WSP_GGML_USE_CPU
30
+ #include "ggml-cpu.h"
31
+ #endif
9
32
 
10
33
  #ifdef WSP_GGML_USE_CUDA
11
34
  #include "ggml-cuda.h"
12
35
  #endif
13
36
 
14
37
  #ifdef WSP_GGML_USE_METAL
15
- #include <TargetConditionals.h>
16
-
17
- #if !TARGET_OS_SIMULATOR
18
38
  #include "ggml-metal.h"
19
39
  #endif
20
40
 
21
- #endif
22
-
23
41
  #ifdef WSP_GGML_USE_SYCL
24
42
  #include "ggml-sycl.h"
25
43
  #endif
@@ -28,6 +46,10 @@
28
46
  #include "ggml-vulkan.h"
29
47
  #endif
30
48
 
49
+ #ifdef WSP_GGML_USE_OPENCL
50
+ #include "ggml-opencl.h"
51
+ #endif
52
+
31
53
  #ifdef WSP_GGML_USE_BLAS
32
54
  #include "ggml-blas.h"
33
55
  #endif
@@ -36,10 +58,6 @@
36
58
  #include "ggml-rpc.h"
37
59
  #endif
38
60
 
39
- #ifdef WSP_GGML_USE_AMX
40
- # include "ggml-amx.h"
41
- #endif
42
-
43
61
  #ifdef WSP_GGML_USE_CANN
44
62
  #include "ggml-cann.h"
45
63
  #endif
@@ -48,8 +66,90 @@
48
66
  #include "ggml-kompute.h"
49
67
  #endif
50
68
 
69
// std::wstring_convert and std::codecvt_utf8_utf16 are deprecated since C++17;
// silence clang's deprecation diagnostics for the two conversion helpers below.
#if defined(__clang__)
#    pragma clang diagnostic push
#    pragma clang diagnostic ignored "-Wdeprecated-declarations"
#endif

// Convert a UTF-8 encoded narrow string into a wide string via
// std::codecvt_utf8_utf16<wchar_t>. No error string is given to the converter,
// so malformed input makes std::wstring_convert throw std::range_error.
static std::wstring utf8_to_utf16(const std::string & str) {
    using converter_t = std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>>;
    return converter_t{}.from_bytes(str);
}

// Inverse of utf8_to_utf16: encode a wide string back into UTF-8.
static std::string utf16_to_utf8(const std::wstring & str) {
    using converter_t = std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>>;
    return converter_t{}.to_bytes(str);
}

#if defined(__clang__)
#    pragma clang diagnostic pop
#endif
88
+
89
+ #ifdef _WIN32
90
+
91
+ using dl_handle = std::remove_pointer_t<HMODULE>;
92
+
93
+ struct dl_handle_deleter {
94
+ void operator()(HMODULE handle) {
95
+ FreeLibrary(handle);
96
+ }
97
+ };
98
+
99
+ static dl_handle * dl_load_library(const std::wstring & path) {
100
+ // suppress error dialogs for missing DLLs
101
+ DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
102
+ SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
103
+
104
+ HMODULE handle = LoadLibraryW(path.c_str());
105
+
106
+ SetErrorMode(old_mode);
107
+
108
+ return handle;
109
+ }
110
+
111
+ static void * dl_get_sym(dl_handle * handle, const char * name) {
112
+ DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
113
+ SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
114
+
115
+ void * p = (void *) GetProcAddress(handle, name);
116
+
117
+ SetErrorMode(old_mode);
118
+
119
+ return p;
120
+ }
121
+
122
+ #else
123
+
124
+ using dl_handle = void;
125
+
126
+ struct dl_handle_deleter {
127
+ void operator()(void * handle) {
128
+ dlclose(handle);
129
+ }
130
+ };
131
+
132
+ static void * dl_load_library(const std::wstring & path) {
133
+ dl_handle * handle = dlopen(utf16_to_utf8(path).c_str(), RTLD_NOW | RTLD_LOCAL);
134
+
135
+ return handle;
136
+ }
137
+
138
+ static void * dl_get_sym(dl_handle * handle, const char * name) {
139
+ return dlsym(handle, name);
140
+ }
141
+
142
+ #endif
143
+
144
+ using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
145
+
146
+ struct wsp_ggml_backend_reg_entry {
147
+ wsp_ggml_backend_reg_t reg;
148
+ dl_handle_ptr handle;
149
+ };
150
+
51
151
  struct wsp_ggml_backend_registry {
52
- std::vector<wsp_ggml_backend_reg_t> backends;
152
+ std::vector<wsp_ggml_backend_reg_entry> backends;
53
153
  std::vector<wsp_ggml_backend_dev_t> devices;
54
154
 
55
155
  wsp_ggml_backend_registry() {
@@ -57,18 +157,17 @@ struct wsp_ggml_backend_registry {
57
157
  register_backend(wsp_ggml_backend_cuda_reg());
58
158
  #endif
59
159
  #ifdef WSP_GGML_USE_METAL
60
-
61
- #if !TARGET_OS_SIMULATOR
62
160
  register_backend(wsp_ggml_backend_metal_reg());
63
161
  #endif
64
-
65
- #endif
66
162
  #ifdef WSP_GGML_USE_SYCL
67
163
  register_backend(wsp_ggml_backend_sycl_reg());
68
164
  #endif
69
165
  #ifdef WSP_GGML_USE_VULKAN
70
166
  register_backend(wsp_ggml_backend_vk_reg());
71
167
  #endif
168
+ #ifdef WSP_GGML_USE_OPENCL
169
+ register_backend(wsp_ggml_backend_opencl_reg());
170
+ #endif
72
171
  #ifdef WSP_GGML_USE_CANN
73
172
  register_backend(wsp_ggml_backend_cann_reg());
74
173
  #endif
@@ -78,17 +177,25 @@ struct wsp_ggml_backend_registry {
78
177
  #ifdef WSP_GGML_USE_RPC
79
178
  register_backend(wsp_ggml_backend_rpc_reg());
80
179
  #endif
81
- #ifdef WSP_GGML_USE_AMX
82
- register_backend(wsp_ggml_backend_amx_reg());
83
- #endif
84
180
  #ifdef WSP_GGML_USE_KOMPUTE
85
181
  register_backend(wsp_ggml_backend_kompute_reg());
86
182
  #endif
87
-
183
+ #ifdef WSP_GGML_USE_CPU
88
184
  register_backend(wsp_ggml_backend_cpu_reg());
185
+ #endif
186
+ }
187
+
188
+ ~wsp_ggml_backend_registry() {
189
+ // FIXME: backends cannot be safely unloaded without a function to destroy all the backend resources,
190
+ // since backend threads may still be running and accessing resources from the dynamic library
191
+ for (auto & entry : backends) {
192
+ if (entry.handle) {
193
+ entry.handle.release(); // NOLINT
194
+ }
195
+ }
89
196
  }
90
197
 
91
- void register_backend(wsp_ggml_backend_reg_t reg) {
198
+ void register_backend(wsp_ggml_backend_reg_t reg, dl_handle_ptr handle = nullptr) {
92
199
  if (!reg) {
93
200
  return;
94
201
  }
@@ -97,7 +204,7 @@ struct wsp_ggml_backend_registry {
97
204
  WSP_GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n",
98
205
  __func__, wsp_ggml_backend_reg_name(reg), wsp_ggml_backend_reg_dev_count(reg));
99
206
  #endif
100
- backends.push_back(reg);
207
+ backends.push_back({ reg, std::move(handle) });
101
208
  for (size_t i = 0; i < wsp_ggml_backend_reg_dev_count(reg); i++) {
102
209
  register_device(wsp_ggml_backend_reg_dev_get(reg, i));
103
210
  }
@@ -109,6 +216,76 @@ struct wsp_ggml_backend_registry {
109
216
  #endif
110
217
  devices.push_back(device);
111
218
  }
219
+
220
+ wsp_ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
221
+ dl_handle_ptr handle { dl_load_library(path) };
222
+ if (!handle) {
223
+ if (!silent) {
224
+ WSP_GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(path).c_str());
225
+ }
226
+ return nullptr;
227
+ }
228
+
229
+ auto score_fn = (wsp_ggml_backend_score_t) dl_get_sym(handle.get(), "wsp_ggml_backend_score");
230
+ if (score_fn && score_fn() == 0) {
231
+ if (!silent) {
232
+ WSP_GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, utf16_to_utf8(path).c_str());
233
+ }
234
+ return nullptr;
235
+ }
236
+
237
+ auto backend_init_fn = (wsp_ggml_backend_init_t) dl_get_sym(handle.get(), "wsp_ggml_backend_init");
238
+ if (!backend_init_fn) {
239
+ if (!silent) {
240
+ WSP_GGML_LOG_ERROR("%s: failed to find wsp_ggml_backend_init in %s\n", __func__, utf16_to_utf8(path).c_str());
241
+ }
242
+ return nullptr;
243
+ }
244
+
245
+ wsp_ggml_backend_reg_t reg = backend_init_fn();
246
+ if (!reg || reg->api_version != WSP_GGML_BACKEND_API_VERSION) {
247
+ if (!silent) {
248
+ if (!reg) {
249
+ WSP_GGML_LOG_ERROR("%s: failed to initialize backend from %s: wsp_ggml_backend_init returned NULL\n", __func__, utf16_to_utf8(path).c_str());
250
+ } else {
251
+ WSP_GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n",
252
+ __func__, utf16_to_utf8(path).c_str(), reg->api_version, WSP_GGML_BACKEND_API_VERSION);
253
+ }
254
+ }
255
+ return nullptr;
256
+ }
257
+
258
+ WSP_GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, wsp_ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str());
259
+
260
+ register_backend(reg, std::move(handle));
261
+
262
+ return reg;
263
+ }
264
+
265
+ void unload_backend(wsp_ggml_backend_reg_t reg, bool silent) {
266
+ auto it = std::find_if(backends.begin(), backends.end(),
267
+ [reg](const wsp_ggml_backend_reg_entry & entry) { return entry.reg == reg; });
268
+
269
+ if (it == backends.end()) {
270
+ if (!silent) {
271
+ WSP_GGML_LOG_ERROR("%s: backend not found\n", __func__);
272
+ }
273
+ return;
274
+ }
275
+
276
+ if (!silent) {
277
+ WSP_GGML_LOG_DEBUG("%s: unloading %s backend\n", __func__, wsp_ggml_backend_reg_name(reg));
278
+ }
279
+
280
+ // remove devices
281
+ devices.erase(
282
+ std::remove_if(devices.begin(), devices.end(),
283
+ [reg](wsp_ggml_backend_dev_t dev) { return wsp_ggml_backend_dev_backend_reg(dev) == reg; }),
284
+ devices.end());
285
+
286
+ // remove backend
287
+ backends.erase(it);
288
+ }
112
289
  };
113
290
 
114
291
  static wsp_ggml_backend_registry & get_reg() {
@@ -126,23 +303,32 @@ void wsp_ggml_backend_device_register(wsp_ggml_backend_dev_t device) {
126
303
  }
127
304
 
128
305
  // Backend (reg) enumeration
306
// Case-insensitive equality of two NUL-terminated strings, folding case with
// std::tolower (current C locale). Each char is converted to unsigned char first:
// passing a negative value (e.g. a UTF-8 continuation byte where char is signed)
// to std::tolower is undefined behaviour.
static bool striequals(const char * a, const char * b) {
    for (; *a && *b; a++, b++) {
        const int ca = std::tolower(static_cast<unsigned char>(*a));
        const int cb = std::tolower(static_cast<unsigned char>(*b));
        if (ca != cb) {
            return false;
        }
    }
    // equal only if both strings ended at the same position
    return *a == *b;
}
314
+
129
315
  size_t wsp_ggml_backend_reg_count() {
130
316
  return get_reg().backends.size();
131
317
  }
132
318
 
133
319
  wsp_ggml_backend_reg_t wsp_ggml_backend_reg_get(size_t index) {
134
320
  WSP_GGML_ASSERT(index < wsp_ggml_backend_reg_count());
135
- return get_reg().backends[index];
321
+ return get_reg().backends[index].reg;
136
322
  }
137
323
 
138
324
  wsp_ggml_backend_reg_t wsp_ggml_backend_reg_by_name(const char * name) {
139
325
  for (size_t i = 0; i < wsp_ggml_backend_reg_count(); i++) {
140
326
  wsp_ggml_backend_reg_t reg = wsp_ggml_backend_reg_get(i);
141
- if (std::strcmp(wsp_ggml_backend_reg_name(reg), name) == 0) {
327
+ if (striequals(wsp_ggml_backend_reg_name(reg), name)) {
142
328
  return reg;
143
329
  }
144
330
  }
145
- return NULL;
331
+ return nullptr;
146
332
  }
147
333
 
148
334
  // Device enumeration
@@ -158,11 +344,11 @@ wsp_ggml_backend_dev_t wsp_ggml_backend_dev_get(size_t index) {
158
344
  wsp_ggml_backend_dev_t wsp_ggml_backend_dev_by_name(const char * name) {
159
345
  for (size_t i = 0; i < wsp_ggml_backend_dev_count(); i++) {
160
346
  wsp_ggml_backend_dev_t dev = wsp_ggml_backend_dev_get(i);
161
- if (strcmp(wsp_ggml_backend_dev_name(dev), name) == 0) {
347
+ if (striequals(wsp_ggml_backend_dev_name(dev), name)) {
162
348
  return dev;
163
349
  }
164
350
  }
165
- return NULL;
351
+ return nullptr;
166
352
  }
167
353
 
168
354
  wsp_ggml_backend_dev_t wsp_ggml_backend_dev_by_type(enum wsp_ggml_backend_dev_type type) {
@@ -172,14 +358,14 @@ wsp_ggml_backend_dev_t wsp_ggml_backend_dev_by_type(enum wsp_ggml_backend_dev_ty
172
358
  return dev;
173
359
  }
174
360
  }
175
- return NULL;
361
+ return nullptr;
176
362
  }
177
363
 
178
364
  // Convenience functions
179
365
  wsp_ggml_backend_t wsp_ggml_backend_init_by_name(const char * name, const char * params) {
180
366
  wsp_ggml_backend_dev_t dev = wsp_ggml_backend_dev_by_name(name);
181
367
  if (!dev) {
182
- return NULL;
368
+ return nullptr;
183
369
  }
184
370
  return wsp_ggml_backend_dev_init(dev, params);
185
371
  }
@@ -187,7 +373,7 @@ wsp_ggml_backend_t wsp_ggml_backend_init_by_name(const char * name, const char *
187
373
  wsp_ggml_backend_t wsp_ggml_backend_init_by_type(enum wsp_ggml_backend_dev_type type, const char * params) {
188
374
  wsp_ggml_backend_dev_t dev = wsp_ggml_backend_dev_by_type(type);
189
375
  if (!dev) {
190
- return NULL;
376
+ return nullptr;
191
377
  }
192
378
  return wsp_ggml_backend_dev_init(dev, params);
193
379
  }
@@ -198,7 +384,199 @@ wsp_ggml_backend_t wsp_ggml_backend_init_best(void) {
198
384
  dev = wsp_ggml_backend_dev_by_type(WSP_GGML_BACKEND_DEVICE_TYPE_CPU);
199
385
  }
200
386
  if (!dev) {
201
- return NULL;
387
+ return nullptr;
388
+ }
389
+ return wsp_ggml_backend_dev_init(dev, nullptr);
390
+ }
391
+
392
+ // Dynamic loading
393
+ wsp_ggml_backend_reg_t wsp_ggml_backend_load(const char * path) {
394
+ return get_reg().load_backend(utf8_to_utf16(path), false);
395
+ }
396
+
397
+ void wsp_ggml_backend_unload(wsp_ggml_backend_reg_t reg) {
398
+ get_reg().unload_backend(reg, true);
399
+ }
400
+
401
+ static std::wstring get_executable_path() {
402
+ #if defined(__APPLE__)
403
+ // get executable path
404
+ std::vector<char> path;
405
+ uint32_t size;
406
+ while (true) {
407
+ size = path.size();
408
+ if (_NSGetExecutablePath(path.data(), &size) == 0) {
409
+ break;
410
+ }
411
+ path.resize(size);
412
+ }
413
+ std::string base_path(path.data(), size);
414
+ // remove executable name
415
+ auto last_slash = base_path.find_last_of('/');
416
+ if (last_slash != std::string::npos) {
417
+ base_path = base_path.substr(0, last_slash);
418
+ }
419
+ return utf8_to_utf16(base_path + "/");
420
+ #elif defined(__linux__) || defined(__FreeBSD__)
421
+ std::string base_path = ".";
422
+ std::vector<char> path(1024);
423
+ while (true) {
424
+ // get executable path
425
+ # if defined(__linux__)
426
+ ssize_t len = readlink("/proc/self/exe", path.data(), path.size());
427
+ # elif defined(__FreeBSD__)
428
+ ssize_t len = readlink("/proc/curproc/file", path.data(), path.size());
429
+ # endif
430
+ if (len == -1) {
431
+ break;
432
+ }
433
+ if (len < (ssize_t) path.size()) {
434
+ base_path = std::string(path.data(), len);
435
+ // remove executable name
436
+ auto last_slash = base_path.find_last_of('/');
437
+ if (last_slash != std::string::npos) {
438
+ base_path = base_path.substr(0, last_slash);
439
+ }
440
+ break;
441
+ }
442
+ path.resize(path.size() * 2);
443
+ }
444
+
445
+ return utf8_to_utf16(base_path + "/");
446
+ #elif defined(_WIN32)
447
+ std::vector<wchar_t> path(MAX_PATH);
448
+ DWORD len = GetModuleFileNameW(NULL, path.data(), path.size());
449
+ if (len == 0) {
450
+ return {};
451
+ }
452
+ std::wstring base_path(path.data(), len);
453
+ // remove executable name
454
+ auto last_slash = base_path.find_last_of('\\');
455
+ if (last_slash != std::string::npos) {
456
+ base_path = base_path.substr(0, last_slash);
457
+ }
458
+ return base_path + L"\\";
459
+ #else
460
+ return {};
461
+ #endif
462
+ }
463
+
464
// Pieces of the file names searched for dynamically loadable backends,
// i.e. <prefix><name>[-<variant>]<suffix> inside a search directory.
#ifdef _WIN32
static std::wstring backend_filename_prefix() { return L"ggml-"; }
static std::wstring backend_filename_suffix() { return L".dll"; }
static std::wstring path_separator()          { return L"\\"; }
#else
static std::wstring backend_filename_prefix() { return L"libggml-"; }
static std::wstring backend_filename_suffix() { return L".so"; }
static std::wstring path_separator()          { return L"/"; }
#endif
487
+
488
+ static wsp_ggml_backend_reg_t wsp_ggml_backend_load_best(const char * name, bool silent, const char * user_search_path) {
489
+ // enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths
490
+ // TODO: search system paths
491
+ std::wstring file_prefix = backend_filename_prefix() + utf8_to_utf16(name) + L"-";
492
+ std::vector<std::wstring> search_paths;
493
+ if (user_search_path == nullptr) {
494
+ search_paths.push_back(L"." + path_separator());
495
+ search_paths.push_back(get_executable_path());
496
+ } else {
497
+ search_paths.push_back(utf8_to_utf16(user_search_path) + path_separator());
498
+ }
499
+
500
+ int best_score = 0;
501
+ std::wstring best_path;
502
+
503
+ namespace fs = std::filesystem;
504
+ for (const auto & search_path : search_paths) {
505
+ if (!fs::exists(search_path)) {
506
+ continue;
507
+ }
508
+ fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
509
+ for (const auto & entry : dir_it) {
510
+ if (entry.is_regular_file()) {
511
+ std::wstring filename = entry.path().filename().wstring();
512
+ std::wstring ext = entry.path().extension().wstring();
513
+ if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
514
+ dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
515
+ if (!handle && !silent) {
516
+ WSP_GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
517
+ }
518
+ if (handle) {
519
+ auto score_fn = (wsp_ggml_backend_score_t) dl_get_sym(handle.get(), "wsp_ggml_backend_score");
520
+ if (score_fn) {
521
+ int s = score_fn();
522
+ #ifndef NDEBUG
523
+ WSP_GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
524
+ #endif
525
+ if (s > best_score) {
526
+ best_score = s;
527
+ best_path = entry.path().wstring();
528
+ }
529
+ } else {
530
+ if (!silent) {
531
+ WSP_GGML_LOG_INFO("%s: failed to find wsp_ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
532
+ }
533
+ }
534
+ }
535
+ }
536
+ }
537
+ }
538
+ }
539
+
540
+ if (best_score == 0) {
541
+ // try to load the base backend
542
+ for (const auto & search_path : search_paths) {
543
+ std::wstring path = search_path + backend_filename_prefix() + utf8_to_utf16(name) + backend_filename_suffix();
544
+ if (fs::exists(path)) {
545
+ return get_reg().load_backend(path, silent);
546
+ }
547
+ }
548
+ return nullptr;
549
+ }
550
+
551
+ return get_reg().load_backend(best_path, silent);
552
+ }
553
+
554
+ void wsp_ggml_backend_load_all() {
555
+ wsp_ggml_backend_load_all_from_path(nullptr);
556
+ }
557
+
558
+ void wsp_ggml_backend_load_all_from_path(const char * dir_path) {
559
+ #ifdef NDEBUG
560
+ bool silent = true;
561
+ #else
562
+ bool silent = false;
563
+ #endif
564
+
565
+ wsp_ggml_backend_load_best("blas", silent, dir_path);
566
+ wsp_ggml_backend_load_best("cann", silent, dir_path);
567
+ wsp_ggml_backend_load_best("cuda", silent, dir_path);
568
+ wsp_ggml_backend_load_best("hip", silent, dir_path);
569
+ wsp_ggml_backend_load_best("kompute", silent, dir_path);
570
+ wsp_ggml_backend_load_best("metal", silent, dir_path);
571
+ wsp_ggml_backend_load_best("rpc", silent, dir_path);
572
+ wsp_ggml_backend_load_best("sycl", silent, dir_path);
573
+ wsp_ggml_backend_load_best("vulkan", silent, dir_path);
574
+ wsp_ggml_backend_load_best("opencl", silent, dir_path);
575
+ wsp_ggml_backend_load_best("musa", silent, dir_path);
576
+ wsp_ggml_backend_load_best("cpu", silent, dir_path);
577
+ // check the environment variable WSP_GGML_BACKEND_PATH to load an out-of-tree backend
578
+ const char * backend_path = std::getenv("WSP_GGML_BACKEND_PATH");
579
+ if (backend_path) {
580
+ wsp_ggml_backend_load(backend_path);
202
581
  }
203
- return wsp_ggml_backend_dev_init(dev, NULL);
204
582
  }
@@ -252,6 +252,7 @@ void wsp_ggml_backend_tensor_get_async(wsp_ggml_backend_t backend, const struct
252
252
  }
253
253
 
254
254
  void wsp_ggml_backend_tensor_set(struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
255
+ WSP_GGML_ASSERT(tensor);
255
256
  wsp_ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
256
257
 
257
258
  if (size == 0) {
@@ -266,6 +267,7 @@ void wsp_ggml_backend_tensor_set(struct wsp_ggml_tensor * tensor, const void * d
266
267
  }
267
268
 
268
269
  void wsp_ggml_backend_tensor_get(const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
270
+ WSP_GGML_ASSERT(tensor);
269
271
  wsp_ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
270
272
 
271
273
  if (size == 0) {
@@ -740,7 +742,8 @@ static int wsp_ggml_backend_sched_backend_id_from_cur(wsp_ggml_backend_sched_t s
740
742
 
741
743
  if (tensor->buffer || (tensor->view_src && tensor->view_src->buffer)) {
742
744
  // since the tensor is pre-allocated, it cannot be moved to another backend
743
- WSP_GGML_ABORT("pre-allocated tensor (%s) in a backend that cannot run the operation", tensor->name);
745
+ wsp_ggml_backend_buffer_t buffer = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
746
+ WSP_GGML_ABORT("pre-allocated tensor (%s) in a buffer (%s) that cannot run the operation (%s)", tensor->name, wsp_ggml_backend_buffer_name(buffer), wsp_ggml_op_name(tensor->op));
744
747
  }
745
748
 
746
749
  // graph input
@@ -761,7 +764,7 @@ static int wsp_ggml_backend_sched_backend_id_from_cur(wsp_ggml_backend_sched_t s
761
764
  if (tensor->op != WSP_GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == WSP_GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
762
765
  int src_backend_id = wsp_ggml_backend_sched_backend_from_buffer(sched, src, tensor);
763
766
  // check if a backend with higher prio wants to offload the op
764
- if (src_backend_id == sched->n_backends - 1) {
767
+ if (src_backend_id == sched->n_backends - 1 && wsp_ggml_backend_buffer_is_host(src->buffer)) {
765
768
  for (int b = 0; b < src_backend_id; b++) {
766
769
  if (wsp_ggml_backend_supports_op(sched->backends[b], tensor) && wsp_ggml_backend_offload_op(sched->backends[b], tensor)) {
767
770
  SET_CAUSE(tensor, "1.off");
@@ -792,9 +795,12 @@ static void wsp_ggml_backend_sched_print_assignments(wsp_ggml_backend_sched_t sc
792
795
  for (int i = 0; i < graph->n_nodes; i++) {
793
796
  if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
794
797
  wsp_ggml_backend_t split_backend = sched->backends[sched->splits[cur_split].backend_id];
795
- WSP_GGML_LOG_DEBUG("\n## SPLIT #%d: %s # %d inputs: ", cur_split, wsp_ggml_backend_name(split_backend),
798
+ WSP_GGML_LOG_DEBUG("\n## SPLIT #%d: %s # %d inputs", cur_split, wsp_ggml_backend_name(split_backend),
796
799
  sched->splits[cur_split].n_inputs);
797
800
  for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
801
+ if (j == 0) {
802
+ WSP_GGML_LOG_DEBUG(": ");
803
+ }
798
804
  WSP_GGML_LOG_DEBUG("[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name,
799
805
  fmt_size(wsp_ggml_nbytes(sched->splits[cur_split].inputs[j])));
800
806
  }
@@ -190,11 +190,21 @@ extern "C" {
190
190
  typedef void (*wsp_ggml_backend_set_n_threads_t)(wsp_ggml_backend_t backend, int n_threads);
191
191
  // Get additional buffer types provided by the device (returns a NULL-terminated array)
192
192
  typedef wsp_ggml_backend_buffer_type_t * (*wsp_ggml_backend_dev_get_extra_bufts_t)(wsp_ggml_backend_dev_t device);
193
+ // Set the abort callback for the backend
194
+ typedef void (*wsp_ggml_backend_set_abort_callback_t)(wsp_ggml_backend_t backend, wsp_ggml_abort_callback abort_callback, void * abort_callback_data);
195
+ // Get a list of feature flags supported by the backend (returns a NULL-terminated array)
196
// A single feature flag reported by a backend, as a name/value pair of C strings.
// Backends return these in a NULL-terminated array (see wsp_ggml_backend_get_features_t).
struct wsp_ggml_backend_feature {
    const char * name;  // feature identifier
    const char * value; // feature value
};
200
+ typedef struct wsp_ggml_backend_feature * (*wsp_ggml_backend_get_features_t)(wsp_ggml_backend_reg_t reg);
193
201
 
194
202
  //
195
203
  // Backend registry
196
204
  //
197
205
 
206
+ WSP_GGML_API void wsp_ggml_backend_device_register(wsp_ggml_backend_dev_t device);
207
+
198
208
  // Backend (reg) enumeration
199
209
  WSP_GGML_API size_t wsp_ggml_backend_reg_count(void);
200
210
  WSP_GGML_API wsp_ggml_backend_reg_t wsp_ggml_backend_reg_get(size_t index);
@@ -214,6 +224,14 @@ extern "C" {
214
224
  // = wsp_ggml_backend_dev_init(wsp_ggml_backend_dev_by_type(GPU) OR wsp_ggml_backend_dev_by_type(CPU), NULL)
215
225
  WSP_GGML_API wsp_ggml_backend_t wsp_ggml_backend_init_best(void);
216
226
 
227
+ // Load a backend from a dynamic library and register it
228
+ WSP_GGML_API wsp_ggml_backend_reg_t wsp_ggml_backend_load(const char * path);
229
+ // Unload a backend if loaded dynamically and unregister it
230
+ WSP_GGML_API void wsp_ggml_backend_unload(wsp_ggml_backend_reg_t reg);
231
+ // Load all known backends from dynamic libraries
232
+ WSP_GGML_API void wsp_ggml_backend_load_all(void);
233
+ WSP_GGML_API void wsp_ggml_backend_load_all_from_path(const char * dir_path);
234
+
217
235
  //
218
236
  // Backend scheduler
219
237
  //