whisper.rn 0.5.0 → 0.5.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. package/android/build.gradle +2 -1
  2. package/android/gradle.properties +1 -1
  3. package/android/src/main/jni.cpp +12 -3
  4. package/cpp/ggml-alloc.c +292 -130
  5. package/cpp/ggml-backend-impl.h +4 -4
  6. package/cpp/ggml-backend-reg.cpp +13 -5
  7. package/cpp/ggml-backend.cpp +207 -17
  8. package/cpp/ggml-backend.h +19 -1
  9. package/cpp/ggml-cpu/amx/amx.cpp +5 -2
  10. package/cpp/ggml-cpu/arch/x86/repack.cpp +2 -2
  11. package/cpp/ggml-cpu/arch-fallback.h +0 -4
  12. package/cpp/ggml-cpu/common.h +14 -0
  13. package/cpp/ggml-cpu/ggml-cpu-impl.h +14 -7
  14. package/cpp/ggml-cpu/ggml-cpu.c +65 -44
  15. package/cpp/ggml-cpu/ggml-cpu.cpp +14 -4
  16. package/cpp/ggml-cpu/ops.cpp +542 -775
  17. package/cpp/ggml-cpu/ops.h +2 -0
  18. package/cpp/ggml-cpu/simd-mappings.h +88 -59
  19. package/cpp/ggml-cpu/unary-ops.cpp +135 -0
  20. package/cpp/ggml-cpu/unary-ops.h +5 -0
  21. package/cpp/ggml-cpu/vec.cpp +227 -20
  22. package/cpp/ggml-cpu/vec.h +407 -56
  23. package/cpp/ggml-cpu.h +1 -1
  24. package/cpp/ggml-impl.h +94 -12
  25. package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
  26. package/cpp/ggml-metal/ggml-metal-common.h +52 -0
  27. package/cpp/ggml-metal/ggml-metal-context.h +33 -0
  28. package/cpp/ggml-metal/ggml-metal-context.m +600 -0
  29. package/cpp/ggml-metal/ggml-metal-device.cpp +1565 -0
  30. package/cpp/ggml-metal/ggml-metal-device.h +244 -0
  31. package/cpp/ggml-metal/ggml-metal-device.m +1325 -0
  32. package/cpp/ggml-metal/ggml-metal-impl.h +802 -0
  33. package/cpp/ggml-metal/ggml-metal-ops.cpp +3583 -0
  34. package/cpp/ggml-metal/ggml-metal-ops.h +88 -0
  35. package/cpp/ggml-metal/ggml-metal.cpp +718 -0
  36. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  37. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  38. package/cpp/ggml-metal-impl.h +40 -40
  39. package/cpp/ggml-metal.h +1 -6
  40. package/cpp/ggml-quants.c +1 -0
  41. package/cpp/ggml.c +341 -15
  42. package/cpp/ggml.h +150 -5
  43. package/cpp/jsi/RNWhisperJSI.cpp +9 -2
  44. package/cpp/jsi/ThreadPool.h +3 -3
  45. package/cpp/rn-whisper.h +1 -0
  46. package/cpp/whisper.cpp +89 -72
  47. package/cpp/whisper.h +1 -0
  48. package/ios/CMakeLists.txt +6 -1
  49. package/ios/RNWhisperContext.mm +3 -1
  50. package/ios/RNWhisperVadContext.mm +14 -13
  51. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  52. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  53. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  54. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  55. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  56. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  57. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +150 -5
  58. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  59. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  60. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  61. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  62. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  63. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  64. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  65. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  66. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  67. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  68. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  69. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +150 -5
  70. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  71. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  72. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  73. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  74. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  75. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  76. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  77. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  78. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  79. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  80. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  81. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  82. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +150 -5
  83. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  84. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  85. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  86. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  87. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  88. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  89. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  90. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  91. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  92. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  93. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  94. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +150 -5
  95. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  96. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  97. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  98. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  99. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  100. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  101. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  102. package/lib/commonjs/version.json +1 -1
  103. package/lib/module/NativeRNWhisper.js.map +1 -1
  104. package/lib/module/version.json +1 -1
  105. package/lib/typescript/NativeRNWhisper.d.ts +2 -0
  106. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  107. package/package.json +1 -1
  108. package/src/NativeRNWhisper.ts +2 -0
  109. package/src/version.json +1 -1
  110. package/whisper-rn.podspec +8 -9
  111. package/cpp/ggml-metal.m +0 -6779
  112. package/cpp/ggml-whisper-sim.metallib +0 -0
  113. package/cpp/ggml-whisper.metallib +0 -0
@@ -55,7 +55,12 @@ add_library(rnwhisper SHARED
55
55
  ${SOURCE_DIR}/ggml-cpu/binary-ops.cpp
56
56
  ${SOURCE_DIR}/ggml-cpu/vec.cpp
57
57
  ${SOURCE_DIR}/ggml-cpu/ops.cpp
58
- ${SOURCE_DIR}/ggml-metal.m
58
+ ${SOURCE_DIR}/ggml-metal/ggml-metal.cpp
59
+ ${SOURCE_DIR}/ggml-metal/ggml-metal-common.cpp
60
+ ${SOURCE_DIR}/ggml-metal/ggml-metal-device.cpp
61
+ ${SOURCE_DIR}/ggml-metal/ggml-metal-context.m
62
+ ${SOURCE_DIR}/ggml-metal/ggml-metal-device.m
63
+ ${SOURCE_DIR}/ggml-metal/ggml-metal-ops.cpp
59
64
  ${SOURCE_DIR}/ggml-opt.cpp
60
65
  ${SOURCE_DIR}/ggml-threading.cpp
61
66
  ${SOURCE_DIR}/ggml-quants.c
@@ -168,6 +168,7 @@ static void* retained_log_block = nullptr;
168
168
  self->recordState.sliceNSamples.push_back(0);
169
169
 
170
170
  self->recordState.job = rnwhisper::job_new(jobId, [self createParams:options jobId:jobId]);
171
+ self->recordState.job->n_processors = options[@"nProcessors"] != nil ? [options[@"nProcessors"] intValue] : 1;
171
172
  self->recordState.job->set_realtime_params(
172
173
  {
173
174
  .use_vad = options[@"useVad"] != nil ? [options[@"useVad"] boolValue] : false,
@@ -471,6 +472,7 @@ struct rnwhisper_segments_callback_data {
471
472
  }
472
473
 
473
474
  rnwhisper::job* job = rnwhisper::job_new(jobId, params);
475
+ job->n_processors = options[@"nProcessors"] != nil ? [options[@"nProcessors"] intValue] : 1;
474
476
  self->recordState.job = job;
475
477
  int code = [self fullTranscribe:job audioData:audioData audioDataCount:audioDataCount];
476
478
  rnwhisper::job_remove(jobId);
@@ -572,7 +574,7 @@ struct rnwhisper_segments_callback_data {
572
574
  audioDataCount:(int)audioDataCount
573
575
  {
574
576
  whisper_reset_timings(self->ctx);
575
- int code = whisper_full(self->ctx, job->params, audioData, audioDataCount);
577
+ int code = whisper_full_parallel(self->ctx, job->params, audioData, audioDataCount, job->n_processors);
576
578
  if (job && job->is_aborted()) code = -999;
577
579
  // if (code == 0) {
578
580
  // whisper_print_timings(self->ctx);
@@ -20,27 +20,28 @@
20
20
 
21
21
  #ifdef WSP_GGML_USE_METAL
22
22
  if (ctx_params.use_gpu) {
23
- ctx_params.gpu_device = 0;
23
+ // TODO: GPU VAD is forced disabled until the performance is improved (ref: whisper.cpp/whisper_vad_init_context)
24
+ ctx_params.use_gpu = false;
25
+ // ctx_params.gpu_device = 0;
24
26
 
25
- id<MTLDevice> device = MTLCreateSystemDefaultDevice();
27
+ // id<MTLDevice> device = MTLCreateSystemDefaultDevice();
26
28
 
27
- // Check ggml-metal availability
28
- BOOL supportsGgmlMetal = [device supportsFamily:MTLGPUFamilyApple7];
29
- if (@available(iOS 16.0, tvOS 16.0, *)) {
30
- supportsGgmlMetal = supportsGgmlMetal && [device supportsFamily:MTLGPUFamilyMetal3];
31
- }
32
- if (!supportsGgmlMetal) {
33
- ctx_params.use_gpu = false;
34
- reasonNoMetal = @"Metal is not supported in this device";
35
- }
29
+ // // Check ggml-metal availability
30
+ // BOOL supportsGgmlMetal = [device supportsFamily:MTLGPUFamilyApple7];
31
+ // if (@available(iOS 16.0, tvOS 16.0, *)) {
32
+ // supportsGgmlMetal = supportsGgmlMetal && [device supportsFamily:MTLGPUFamilyMetal3];
33
+ // }
34
+ // if (!supportsGgmlMetal) {
35
+ // ctx_params.use_gpu = false;
36
+ // reasonNoMetal = @"Metal is not supported in this device";
37
+ // }
38
+ // device = nil;
36
39
 
37
40
  #if TARGET_OS_SIMULATOR
38
41
  // Use the backend, but no layers because not supported fully on simulator
39
42
  ctx_params.use_gpu = false;
40
43
  reasonNoMetal = @"Metal is not supported in simulator";
41
44
  #endif
42
-
43
- device = nil;
44
45
  }
45
46
  #endif // WSP_GGML_USE_METAL
46
47
 
@@ -8,7 +8,7 @@
8
8
  extern "C" {
9
9
  #endif
10
10
 
11
- #define WSP_GGML_BACKEND_API_VERSION 1
11
+ #define WSP_GGML_BACKEND_API_VERSION 2
12
12
 
13
13
  //
14
14
  // Backend buffer type
@@ -114,6 +114,9 @@ extern "C" {
114
114
  void (*event_record)(wsp_ggml_backend_t backend, wsp_ggml_backend_event_t event);
115
115
  // wait for an event on a different stream
116
116
  void (*event_wait) (wsp_ggml_backend_t backend, wsp_ggml_backend_event_t event);
117
+
118
+ // (optional) sort/optimize the nodes in the graph
119
+ void (*graph_optimize) (wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph);
117
120
  };
118
121
 
119
122
  struct wsp_ggml_backend {
@@ -206,9 +209,6 @@ extern "C" {
206
209
  void * context;
207
210
  };
208
211
 
209
- // Internal backend registry API
210
- WSP_GGML_API void wsp_ggml_backend_register(wsp_ggml_backend_reg_t reg);
211
-
212
212
  // Add backend dynamic loading support to the backend
213
213
 
214
214
  // Initialize the backend
@@ -132,6 +132,8 @@ extern "C" {
132
132
  WSP_GGML_BACKEND_DEVICE_TYPE_CPU,
133
133
  // GPU device using dedicated memory
134
134
  WSP_GGML_BACKEND_DEVICE_TYPE_GPU,
135
+ // integrated GPU device using host memory
136
+ WSP_GGML_BACKEND_DEVICE_TYPE_IGPU,
135
137
  // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX)
136
138
  WSP_GGML_BACKEND_DEVICE_TYPE_ACCEL
137
139
  };
@@ -150,11 +152,21 @@ extern "C" {
150
152
 
151
153
  // all the device properties
152
154
  struct wsp_ggml_backend_dev_props {
155
+ // device name
153
156
  const char * name;
157
+ // device description
154
158
  const char * description;
159
+ // device free memory in bytes
155
160
  size_t memory_free;
161
+ // device total memory in bytes
156
162
  size_t memory_total;
163
+ // device type
157
164
  enum wsp_ggml_backend_dev_type type;
165
+ // device id
166
+ // for PCI devices, this should be the PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:01:00.0")
167
+ // if the id is unknown, this should be NULL
168
+ const char * device_id;
169
+ // device capabilities
158
170
  struct wsp_ggml_backend_dev_caps caps;
159
171
  };
160
172
 
@@ -203,6 +215,8 @@ extern "C" {
203
215
  // Backend registry
204
216
  //
205
217
 
218
+ WSP_GGML_API void wsp_ggml_backend_register(wsp_ggml_backend_reg_t reg);
219
+
206
220
  WSP_GGML_API void wsp_ggml_backend_device_register(wsp_ggml_backend_dev_t device);
207
221
 
208
222
  // Backend (reg) enumeration
@@ -302,11 +316,15 @@ extern "C" {
302
316
  WSP_GGML_API int wsp_ggml_backend_sched_get_n_splits(wsp_ggml_backend_sched_t sched);
303
317
  WSP_GGML_API int wsp_ggml_backend_sched_get_n_copies(wsp_ggml_backend_sched_t sched);
304
318
 
305
- WSP_GGML_API size_t wsp_ggml_backend_sched_get_buffer_size(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend);
319
+ WSP_GGML_API wsp_ggml_backend_buffer_type_t wsp_ggml_backend_sched_get_buffer_type(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend);
320
+ WSP_GGML_API size_t wsp_ggml_backend_sched_get_buffer_size(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend);
306
321
 
307
322
  WSP_GGML_API void wsp_ggml_backend_sched_set_tensor_backend(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node, wsp_ggml_backend_t backend);
308
323
  WSP_GGML_API wsp_ggml_backend_t wsp_ggml_backend_sched_get_tensor_backend(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node);
309
324
 
325
+ // Split graph without allocating it
326
+ WSP_GGML_API void wsp_ggml_backend_sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph);
327
+
310
328
  // Allocate and compute graph on the backend scheduler
311
329
  WSP_GGML_API bool wsp_ggml_backend_sched_alloc_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph); // returns success
312
330
  WSP_GGML_API enum wsp_ggml_status wsp_ggml_backend_sched_graph_compute(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph);
@@ -101,7 +101,6 @@ extern "C" {
101
101
  WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_riscv_v (void);
102
102
  WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_vsx (void);
103
103
  WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_vxe (void);
104
- WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_nnpa (void);
105
104
  WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_wasm_simd (void);
106
105
  WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_llamafile (void);
107
106
 
@@ -135,6 +134,7 @@ extern "C" {
135
134
  WSP_GGML_BACKEND_API wsp_ggml_backend_reg_t wsp_ggml_backend_cpu_reg(void);
136
135
 
137
136
  WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
137
+ WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_i32 (const float *, int32_t *, int64_t);
138
138
  WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_fp16(const float *, wsp_ggml_fp16_t *, int64_t);
139
139
  WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp16_to_fp32(const wsp_ggml_fp16_t *, float *, int64_t);
140
140
  WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_bf16(const float *, wsp_ggml_bf16_t *, int64_t);
@@ -73,7 +73,7 @@ static inline int wsp_ggml_up(int n, int m) {
73
73
  return (n + m - 1) & ~(m - 1);
74
74
  }
75
75
 
76
- // TODO: move to ggml.h?
76
+ // TODO: move to ggml.h? (won't be able to inline)
77
77
  static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b) {
78
78
  if (a->type != b->type) {
79
79
  return false;
@@ -89,6 +89,22 @@ static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const str
89
89
  return true;
90
90
  }
91
91
 
92
+ static bool wsp_ggml_op_is_empty(enum wsp_ggml_op op) {
93
+ switch (op) {
94
+ case WSP_GGML_OP_NONE:
95
+ case WSP_GGML_OP_RESHAPE:
96
+ case WSP_GGML_OP_TRANSPOSE:
97
+ case WSP_GGML_OP_VIEW:
98
+ case WSP_GGML_OP_PERMUTE:
99
+ return true;
100
+ default:
101
+ return false;
102
+ }
103
+ }
104
+
105
+ static inline float wsp_ggml_softplus(float input) {
106
+ return (input > 20.0f) ? input : logf(1 + expf(input));
107
+ }
92
108
  //
93
109
  // logging
94
110
  //
@@ -329,6 +345,10 @@ struct wsp_ggml_cgraph {
329
345
  // if you need the gradients, get them from the original graph
330
346
  struct wsp_ggml_cgraph wsp_ggml_graph_view(struct wsp_ggml_cgraph * cgraph, int i0, int i1);
331
347
 
348
+ // ggml-alloc.c: true if the operation can reuse memory from its sources
349
+ WSP_GGML_API bool wsp_ggml_op_can_inplace(enum wsp_ggml_op op);
350
+
351
+
332
352
  // Memory allocation
333
353
 
334
354
  WSP_GGML_API void * wsp_ggml_aligned_malloc(size_t size);
@@ -545,14 +565,23 @@ static inline wsp_ggml_bf16_t wsp_ggml_compute_fp32_to_bf16(float s) {
545
565
  #define WSP_GGML_FP32_TO_BF16(x) wsp_ggml_compute_fp32_to_bf16(x)
546
566
  #define WSP_GGML_BF16_TO_FP32(x) wsp_ggml_compute_bf16_to_fp32(x)
547
567
 
568
+ static inline int32_t wsp_ggml_node_get_use_count(const struct wsp_ggml_cgraph * cgraph, int node_idx) {
569
+ const struct wsp_ggml_tensor * node = cgraph->nodes[node_idx];
570
+
571
+ size_t hash_pos = wsp_ggml_hash_find(&cgraph->visited_hash_set, node);
572
+ if (!wsp_ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos)) {
573
+ return 0;
574
+ }
575
+ return cgraph->use_counts[hash_pos];
576
+ }
577
+
548
578
  // return true if the node's results are only used by N other nodes
549
579
  // and can be fused into their calculations.
550
580
  static inline bool wsp_ggml_node_has_n_uses(const struct wsp_ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
551
581
  const struct wsp_ggml_tensor * node = cgraph->nodes[node_idx];
552
582
 
553
583
  // check the use count against how many we're replacing
554
- size_t hash_pos = wsp_ggml_hash_find(&cgraph->visited_hash_set, node);
555
- if (!wsp_ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
584
+ if (wsp_ggml_node_get_use_count(cgraph, node_idx) != n_uses) {
556
585
  return false;
557
586
  }
558
587
 
@@ -570,27 +599,27 @@ static inline bool wsp_ggml_node_has_n_uses(const struct wsp_ggml_cgraph * cgrap
570
599
  return true;
571
600
  }
572
601
 
573
- // Returns true if nodes [i, i+ops.size()) are the sequence of wsp_ggml_ops in ops[]
602
+ // Returns true if nodes with indices { node_idxs } are the sequence of wsp_ggml_ops in ops[]
574
603
  // and are fusable. Nodes are considered fusable according to this function if:
575
604
  // - all nodes except the last have only one use and are not views/outputs (see wsp_ggml_node_has_N_uses).
576
605
  // - all nodes except the last are a src of the following node.
577
606
  // - all nodes are the same shape.
578
607
  // TODO: Consider allowing WSP_GGML_OP_NONE nodes in between
579
- static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_idx, const enum wsp_ggml_op * ops, int num_ops) {
580
- if (node_idx + num_ops > cgraph->n_nodes) {
581
- return false;
582
- }
583
-
608
+ static inline bool wsp_ggml_can_fuse_ext(const struct wsp_ggml_cgraph * cgraph, const int * node_idxs, const enum wsp_ggml_op * ops, int num_ops) {
584
609
  for (int i = 0; i < num_ops; ++i) {
585
- struct wsp_ggml_tensor * node = cgraph->nodes[node_idx + i];
610
+ if (node_idxs[i] >= cgraph->n_nodes) {
611
+ return false;
612
+ }
613
+
614
+ struct wsp_ggml_tensor * node = cgraph->nodes[node_idxs[i]];
586
615
  if (node->op != ops[i]) {
587
616
  return false;
588
617
  }
589
- if (i < num_ops - 1 && !wsp_ggml_node_has_n_uses(cgraph, node_idx + i, 1)) {
618
+ if (i < num_ops - 1 && !wsp_ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) {
590
619
  return false;
591
620
  }
592
621
  if (i > 0) {
593
- struct wsp_ggml_tensor * prev = cgraph->nodes[node_idx + i - 1];
622
+ struct wsp_ggml_tensor * prev = cgraph->nodes[node_idxs[i - 1]];
594
623
  if (node->src[0] != prev && node->src[1] != prev) {
595
624
  return false;
596
625
  }
@@ -602,6 +631,52 @@ static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int
602
631
  return true;
603
632
  }
604
633
 
634
+ // same as above, for sequential indices starting at node_idx
635
+ static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_idx, const enum wsp_ggml_op * ops, int num_ops) {
636
+ assert(num_ops < 32);
637
+
638
+ if (node_idx + num_ops > cgraph->n_nodes) {
639
+ return false;
640
+ }
641
+
642
+ int idxs[32];
643
+ for (int i = 0; i < num_ops; ++i) {
644
+ idxs[i] = node_idx + i;
645
+ }
646
+
647
+ return wsp_ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
648
+ }
649
+
650
+ WSP_GGML_API bool wsp_ggml_can_fuse_subgraph_ext(const struct wsp_ggml_cgraph * cgraph,
651
+ const int * node_idxs,
652
+ int count,
653
+ const enum wsp_ggml_op * ops,
654
+ const int * outputs,
655
+ int num_outputs);
656
+
657
+ // Returns true if the subgraph formed by {node_idxs} can be fused
658
+ // checks whether all nodes which are not part of outputs can be elided
659
+ // by checking if their num_uses are confined to the subgraph
660
+ static inline bool wsp_ggml_can_fuse_subgraph(const struct wsp_ggml_cgraph * cgraph,
661
+ int node_idx,
662
+ int count,
663
+ const enum wsp_ggml_op * ops,
664
+ const int * outputs,
665
+ int num_outputs) {
666
+ WSP_GGML_ASSERT(count < 32);
667
+ if (node_idx + count > cgraph->n_nodes) {
668
+ return false;
669
+ }
670
+
671
+ int idxs[32];
672
+
673
+ for (int i = 0; i < count; ++i) {
674
+ idxs[i] = node_idx + i;
675
+ }
676
+
677
+ return wsp_ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
678
+ }
679
+
605
680
  #ifdef __cplusplus
606
681
  }
607
682
  #endif
@@ -615,6 +690,13 @@ inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_id
615
690
  return wsp_ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
616
691
  }
617
692
 
693
+ inline bool wsp_ggml_can_fuse_subgraph(const struct wsp_ggml_cgraph * cgraph,
694
+ int start_idx,
695
+ std::initializer_list<enum wsp_ggml_op> ops,
696
+ std::initializer_list<int> outputs = {}) {
697
+ return wsp_ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
698
+ }
699
+
618
700
  // expose GGUF internals for test code
619
701
  WSP_GGML_API size_t wsp_gguf_type_size(enum wsp_gguf_type type);
620
702
  WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_from_file_impl(FILE * file, struct wsp_gguf_init_params params);
@@ -1,5 +1,5 @@
1
- #ifndef WSP_GGML_METAL_IMPL
2
- #define WSP_GGML_METAL_IMPL
1
+ #ifndef WSP_WSP_WSP_GGML_METAL_IMPL
2
+ #define WSP_WSP_WSP_GGML_METAL_IMPL
3
3
 
4
4
  // kernel parameters for mat-vec threadgroups
5
5
  //
@@ -101,7 +101,7 @@ typedef struct {
101
101
  uint64_t nb2;
102
102
  uint64_t nb3;
103
103
  int32_t dim;
104
- } wsp_ggml_metal_kargs_concat;
104
+ } wsp_wsp_wsp_ggml_metal_kargs_concat;
105
105
 
106
106
  typedef struct {
107
107
  int32_t ne00;
@@ -130,7 +130,7 @@ typedef struct {
130
130
  uint64_t nb3;
131
131
  uint64_t offs;
132
132
  uint64_t o1[8];
133
- } wsp_ggml_metal_kargs_bin;
133
+ } wsp_wsp_wsp_ggml_metal_kargs_bin;
134
134
 
135
135
  typedef struct {
136
136
  int64_t ne0;
@@ -139,7 +139,7 @@ typedef struct {
139
139
  size_t nb02;
140
140
  size_t nb11;
141
141
  size_t nb21;
142
- } wsp_ggml_metal_kargs_add_id;
142
+ } wsp_wsp_wsp_ggml_metal_kargs_add_id;
143
143
 
144
144
  typedef struct {
145
145
  int32_t ne00;
@@ -158,7 +158,7 @@ typedef struct {
158
158
  uint64_t nb1;
159
159
  uint64_t nb2;
160
160
  uint64_t nb3;
161
- } wsp_ggml_metal_kargs_repeat;
161
+ } wsp_wsp_wsp_ggml_metal_kargs_repeat;
162
162
 
163
163
  typedef struct {
164
164
  int64_t ne00;
@@ -177,7 +177,7 @@ typedef struct {
177
177
  uint64_t nb1;
178
178
  uint64_t nb2;
179
179
  uint64_t nb3;
180
- } wsp_ggml_metal_kargs_cpy;
180
+ } wsp_wsp_wsp_ggml_metal_kargs_cpy;
181
181
 
182
182
  typedef struct {
183
183
  int64_t ne10;
@@ -192,7 +192,7 @@ typedef struct {
192
192
  uint64_t nb3;
193
193
  uint64_t offs;
194
194
  bool inplace;
195
- } wsp_ggml_metal_kargs_set;
195
+ } wsp_wsp_wsp_ggml_metal_kargs_set;
196
196
 
197
197
  typedef struct {
198
198
  int32_t ne00;
@@ -224,7 +224,7 @@ typedef struct {
224
224
  int32_t sect_1;
225
225
  int32_t sect_2;
226
226
  int32_t sect_3;
227
- } wsp_ggml_metal_kargs_rope;
227
+ } wsp_wsp_wsp_ggml_metal_kargs_rope;
228
228
 
229
229
  typedef struct {
230
230
  int32_t ne01;
@@ -255,7 +255,7 @@ typedef struct {
255
255
  float m1;
256
256
  int32_t n_head_log2;
257
257
  float logit_softcap;
258
- } wsp_ggml_metal_kargs_flash_attn_ext;
258
+ } wsp_wsp_wsp_ggml_metal_kargs_flash_attn_ext;
259
259
 
260
260
  typedef struct {
261
261
  int32_t ne00;
@@ -272,7 +272,7 @@ typedef struct {
272
272
  int32_t ne1;
273
273
  int16_t r2;
274
274
  int16_t r3;
275
- } wsp_ggml_metal_kargs_mul_mm;
275
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mm;
276
276
 
277
277
  typedef struct {
278
278
  int32_t ne00;
@@ -293,7 +293,7 @@ typedef struct {
293
293
  int32_t ne1;
294
294
  int16_t r2;
295
295
  int16_t r3;
296
- } wsp_ggml_metal_kargs_mul_mv;
296
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mv;
297
297
 
298
298
  typedef struct {
299
299
  int32_t ne00;
@@ -317,7 +317,7 @@ typedef struct {
317
317
  int16_t nsg;
318
318
  int16_t nxpsg;
319
319
  int16_t r1ptg;
320
- } wsp_ggml_metal_kargs_mul_mv_ext;
320
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mv_ext;
321
321
 
322
322
  typedef struct {
323
323
  int32_t ne10;
@@ -328,7 +328,7 @@ typedef struct {
328
328
  uint64_t nbh11;
329
329
  int32_t ne20; // n_expert_used
330
330
  uint64_t nb21;
331
- } wsp_ggml_metal_kargs_mul_mm_id_map0;
331
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mm_id_map0;
332
332
 
333
333
  typedef struct {
334
334
  int32_t ne20; // n_expert_used
@@ -339,7 +339,7 @@ typedef struct {
339
339
  int32_t ne0;
340
340
  uint64_t nb1;
341
341
  uint64_t nb2;
342
- } wsp_ggml_metal_kargs_mul_mm_id_map1;
342
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mm_id_map1;
343
343
 
344
344
  typedef struct {
345
345
  int32_t ne00;
@@ -356,7 +356,7 @@ typedef struct {
356
356
  int32_t neh1;
357
357
  int16_t r2;
358
358
  int16_t r3;
359
- } wsp_ggml_metal_kargs_mul_mm_id;
359
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mm_id;
360
360
 
361
361
  typedef struct {
362
362
  int32_t nei0;
@@ -378,14 +378,14 @@ typedef struct {
378
378
  int32_t ne0;
379
379
  int32_t ne1;
380
380
  uint64_t nb1;
381
- } wsp_ggml_metal_kargs_mul_mv_id;
381
+ } wsp_wsp_wsp_ggml_metal_kargs_mul_mv_id;
382
382
 
383
383
  typedef struct {
384
384
  int32_t ne00;
385
385
  int32_t ne00_4;
386
386
  uint64_t nb01;
387
387
  float eps;
388
- } wsp_ggml_metal_kargs_norm;
388
+ } wsp_wsp_wsp_ggml_metal_kargs_norm;
389
389
 
390
390
  typedef struct {
391
391
  int32_t ne00;
@@ -400,14 +400,14 @@ typedef struct {
400
400
  uint64_t nbf1[3];
401
401
  uint64_t nbf2[3];
402
402
  uint64_t nbf3[3];
403
- } wsp_ggml_metal_kargs_rms_norm;
403
+ } wsp_wsp_wsp_ggml_metal_kargs_rms_norm;
404
404
 
405
405
  typedef struct {
406
406
  int32_t ne00;
407
407
  int32_t ne00_4;
408
408
  uint64_t nb01;
409
409
  float eps;
410
- } wsp_ggml_metal_kargs_l2_norm;
410
+ } wsp_wsp_wsp_ggml_metal_kargs_l2_norm;
411
411
 
412
412
  typedef struct {
413
413
  int64_t ne00;
@@ -418,7 +418,7 @@ typedef struct {
418
418
  uint64_t nb02;
419
419
  int32_t n_groups;
420
420
  float eps;
421
- } wsp_ggml_metal_kargs_group_norm;
421
+ } wsp_wsp_wsp_ggml_metal_kargs_group_norm;
422
422
 
423
423
  typedef struct {
424
424
  int32_t IC;
@@ -427,7 +427,7 @@ typedef struct {
427
427
  int32_t s0;
428
428
  uint64_t nb0;
429
429
  uint64_t nb1;
430
- } wsp_ggml_metal_kargs_conv_transpose_1d;
430
+ } wsp_wsp_wsp_ggml_metal_kargs_conv_transpose_1d;
431
431
 
432
432
  typedef struct {
433
433
  uint64_t ofs0;
@@ -445,7 +445,7 @@ typedef struct {
445
445
  int32_t KH;
446
446
  int32_t KW;
447
447
  int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
448
- } wsp_ggml_metal_kargs_im2col;
448
+ } wsp_wsp_wsp_ggml_metal_kargs_im2col;
449
449
 
450
450
  typedef struct{
451
451
  int32_t ne00;
@@ -458,7 +458,7 @@ typedef struct{
458
458
  int32_t i10;
459
459
  float alpha;
460
460
  float limit;
461
- } wsp_ggml_metal_kargs_glu;
461
+ } wsp_wsp_wsp_ggml_metal_kargs_glu;
462
462
 
463
463
  typedef struct {
464
464
  int64_t ne00;
@@ -485,7 +485,7 @@ typedef struct {
485
485
  uint64_t nb1;
486
486
  uint64_t nb2;
487
487
  uint64_t nb3;
488
- } wsp_ggml_metal_kargs_sum_rows;
488
+ } wsp_wsp_wsp_ggml_metal_kargs_sum_rows;
489
489
 
490
490
  typedef struct {
491
491
  int32_t ne00;
@@ -508,13 +508,13 @@ typedef struct {
508
508
  float m0;
509
509
  float m1;
510
510
  int32_t n_head_log2;
511
- } wsp_ggml_metal_kargs_soft_max;
511
+ } wsp_wsp_wsp_ggml_metal_kargs_soft_max;
512
512
 
513
513
  typedef struct {
514
514
  int64_t ne00;
515
515
  int64_t ne01;
516
516
  int n_past;
517
- } wsp_ggml_metal_kargs_diag_mask_inf;
517
+ } wsp_wsp_wsp_ggml_metal_kargs_diag_mask_inf;
518
518
 
519
519
  typedef struct {
520
520
  int64_t ne00;
@@ -533,7 +533,7 @@ typedef struct {
533
533
  uint64_t nb0;
534
534
  uint64_t nb1;
535
535
  uint64_t nb2;
536
- } wsp_ggml_metal_kargs_ssm_conv;
536
+ } wsp_wsp_wsp_ggml_metal_kargs_ssm_conv;
537
537
 
538
538
  typedef struct {
539
539
  int64_t d_state;
@@ -558,7 +558,7 @@ typedef struct {
558
558
  uint64_t nb51;
559
559
  uint64_t nb52;
560
560
  uint64_t nb53;
561
- } wsp_ggml_metal_kargs_ssm_scan;
561
+ } wsp_wsp_wsp_ggml_metal_kargs_ssm_scan;
562
562
 
563
563
  typedef struct {
564
564
  int64_t ne00;
@@ -569,7 +569,7 @@ typedef struct {
569
569
  uint64_t nb11;
570
570
  uint64_t nb1;
571
571
  uint64_t nb2;
572
- } wsp_ggml_metal_kargs_get_rows;
572
+ } wsp_wsp_wsp_ggml_metal_kargs_get_rows;
573
573
 
574
574
  typedef struct {
575
575
  int32_t nk0;
@@ -585,7 +585,7 @@ typedef struct {
585
585
  uint64_t nb1;
586
586
  uint64_t nb2;
587
587
  uint64_t nb3;
588
- } wsp_ggml_metal_kargs_set_rows;
588
+ } wsp_wsp_wsp_ggml_metal_kargs_set_rows;
589
589
 
590
590
  typedef struct {
591
591
  int64_t ne00;
@@ -608,7 +608,7 @@ typedef struct {
608
608
  float sf1;
609
609
  float sf2;
610
610
  float sf3;
611
- } wsp_ggml_metal_kargs_upscale;
611
+ } wsp_wsp_wsp_ggml_metal_kargs_upscale;
612
612
 
613
613
  typedef struct {
614
614
  int64_t ne00;
@@ -627,7 +627,7 @@ typedef struct {
627
627
  uint64_t nb1;
628
628
  uint64_t nb2;
629
629
  uint64_t nb3;
630
- } wsp_ggml_metal_kargs_pad;
630
+ } wsp_wsp_wsp_ggml_metal_kargs_pad;
631
631
 
632
632
  typedef struct {
633
633
  int64_t ne00;
@@ -648,28 +648,28 @@ typedef struct {
648
648
  uint64_t nb3;
649
649
  int32_t p0;
650
650
  int32_t p1;
651
- } wsp_ggml_metal_kargs_pad_reflect_1d;
651
+ } wsp_wsp_wsp_ggml_metal_kargs_pad_reflect_1d;
652
652
 
653
653
  typedef struct {
654
654
  uint64_t nb1;
655
655
  int dim;
656
656
  int max_period;
657
- } wsp_ggml_metal_kargs_timestep_embedding;
657
+ } wsp_wsp_wsp_ggml_metal_kargs_timestep_embedding;
658
658
 
659
659
  typedef struct {
660
660
  float slope;
661
- } wsp_ggml_metal_kargs_leaky_relu;
661
+ } wsp_wsp_wsp_ggml_metal_kargs_leaky_relu;
662
662
 
663
663
  typedef struct {
664
664
  int64_t ncols;
665
665
  int64_t ncols_pad;
666
- } wsp_ggml_metal_kargs_argsort;
666
+ } wsp_wsp_wsp_ggml_metal_kargs_argsort;
667
667
 
668
668
  typedef struct {
669
669
  int64_t ne0;
670
670
  float start;
671
671
  float step;
672
- } wsp_ggml_metal_kargs_arange;
672
+ } wsp_wsp_wsp_ggml_metal_kargs_arange;
673
673
 
674
674
  typedef struct {
675
675
  int32_t k0;
@@ -683,6 +683,6 @@ typedef struct {
683
683
  int64_t OH;
684
684
  int64_t OW;
685
685
  int64_t parallel_elements;
686
- } wsp_ggml_metal_kargs_pool_2d;
686
+ } wsp_wsp_wsp_ggml_metal_kargs_pool_2d;
687
687
 
688
- #endif // WSP_GGML_METAL_IMPL
688
+ #endif // WSP_WSP_WSP_GGML_METAL_IMPL