whisper.rn 0.4.0-rc.4 → 0.4.0-rc.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +5 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/WhisperContext.java +51 -133
  6. package/android/src/main/jni-utils.h +76 -0
  7. package/android/src/main/jni.cpp +187 -112
  8. package/cpp/README.md +1 -1
  9. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  10. package/cpp/coreml/whisper-encoder.h +4 -0
  11. package/cpp/coreml/whisper-encoder.mm +4 -2
  12. package/cpp/ggml-alloc.c +55 -19
  13. package/cpp/ggml-alloc.h +7 -0
  14. package/cpp/ggml-backend-impl.h +46 -21
  15. package/cpp/ggml-backend.c +563 -156
  16. package/cpp/ggml-backend.h +62 -17
  17. package/cpp/ggml-impl.h +1 -1
  18. package/cpp/ggml-metal-whisper.metal +1010 -253
  19. package/cpp/ggml-metal.h +7 -1
  20. package/cpp/ggml-metal.m +618 -187
  21. package/cpp/ggml-quants.c +64 -59
  22. package/cpp/ggml-quants.h +40 -40
  23. package/cpp/ggml.c +751 -1466
  24. package/cpp/ggml.h +90 -25
  25. package/cpp/rn-audioutils.cpp +68 -0
  26. package/cpp/rn-audioutils.h +14 -0
  27. package/cpp/rn-whisper-log.h +11 -0
  28. package/cpp/rn-whisper.cpp +141 -59
  29. package/cpp/rn-whisper.h +47 -15
  30. package/cpp/whisper.cpp +1635 -928
  31. package/cpp/whisper.h +55 -10
  32. package/ios/RNWhisper.mm +7 -7
  33. package/ios/RNWhisperAudioUtils.h +0 -2
  34. package/ios/RNWhisperAudioUtils.m +0 -56
  35. package/ios/RNWhisperContext.h +3 -11
  36. package/ios/RNWhisperContext.mm +62 -134
  37. package/lib/commonjs/version.json +1 -1
  38. package/lib/module/version.json +1 -1
  39. package/package.json +6 -5
  40. package/src/version.json +1 -1
@@ -10,6 +10,7 @@
10
10
  #include "whisper.h"
11
11
  #include "rn-whisper.h"
12
12
  #include "ggml.h"
13
+ #include "jni-utils.h"
13
14
 
14
15
  #define UNUSED(x) (void)(x)
15
16
  #define TAG "JNI"
@@ -96,7 +97,8 @@ static void input_stream_close(void *ctx) {
96
97
 
97
98
  static struct whisper_context *whisper_init_from_input_stream(
98
99
  JNIEnv *env,
99
- jobject input_stream // PushbackInputStream
100
+ jobject input_stream, // PushbackInputStream
101
+ struct whisper_context_params cparams
100
102
  ) {
101
103
  input_stream_context *context = new input_stream_context;
102
104
  context->env = env;
@@ -108,7 +110,7 @@ static struct whisper_context *whisper_init_from_input_stream(
108
110
  .eof = &input_stream_is_eof,
109
111
  .close = &input_stream_close
110
112
  };
111
- return whisper_init(&loader);
113
+ return whisper_init_with_params(&loader, cparams);
112
114
  }
113
115
 
114
116
  // Load model from asset
@@ -127,7 +129,8 @@ static void asset_close(void *ctx) {
127
129
  static struct whisper_context *whisper_init_from_asset(
128
130
  JNIEnv *env,
129
131
  jobject assetManager,
130
- const char *asset_path
132
+ const char *asset_path,
133
+ struct whisper_context_params cparams
131
134
  ) {
132
135
  LOGI("Loading model from asset '%s'\n", asset_path);
133
136
  AAssetManager *asset_manager = AAssetManager_fromJava(env, assetManager);
@@ -142,7 +145,7 @@ static struct whisper_context *whisper_init_from_asset(
142
145
  .eof = &asset_is_eof,
143
146
  .close = &asset_close
144
147
  };
145
- return whisper_init(&loader);
148
+ return whisper_init_with_params(&loader, cparams);
146
149
  }
147
150
 
148
151
  extern "C" {
@@ -151,9 +154,10 @@ JNIEXPORT jlong JNICALL
151
154
  Java_com_rnwhisper_WhisperContext_initContext(
152
155
  JNIEnv *env, jobject thiz, jstring model_path_str) {
153
156
  UNUSED(thiz);
157
+ struct whisper_context_params cparams;
154
158
  struct whisper_context *context = nullptr;
155
159
  const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
156
- context = whisper_init_from_file(model_path_chars);
160
+ context = whisper_init_from_file_with_params(model_path_chars, cparams);
157
161
  env->ReleaseStringUTFChars(model_path_str, model_path_chars);
158
162
  return reinterpret_cast<jlong>(context);
159
163
  }
@@ -166,9 +170,10 @@ Java_com_rnwhisper_WhisperContext_initContextWithAsset(
166
170
  jstring model_path_str
167
171
  ) {
168
172
  UNUSED(thiz);
173
+ struct whisper_context_params cparams;
169
174
  struct whisper_context *context = nullptr;
170
175
  const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
171
- context = whisper_init_from_asset(env, asset_manager, model_path_chars);
176
+ context = whisper_init_from_asset(env, asset_manager, model_path_chars, cparams);
172
177
  env->ReleaseStringUTFChars(model_path_str, model_path_chars);
173
178
  return reinterpret_cast<jlong>(context);
174
179
  }
@@ -180,30 +185,65 @@ Java_com_rnwhisper_WhisperContext_initContextWithInputStream(
180
185
  jobject input_stream
181
186
  ) {
182
187
  UNUSED(thiz);
188
+ struct whisper_context_params cparams;
183
189
  struct whisper_context *context = nullptr;
184
- context = whisper_init_from_input_stream(env, input_stream);
190
+ context = whisper_init_from_input_stream(env, input_stream, cparams);
185
191
  return reinterpret_cast<jlong>(context);
186
192
  }
187
193
 
188
- JNIEXPORT jboolean JNICALL
189
- Java_com_rnwhisper_WhisperContext_vadSimple(
190
- JNIEnv *env,
191
- jobject thiz,
192
- jfloatArray audio_data,
193
- jint audio_data_len,
194
- jfloat vad_thold,
195
- jfloat vad_freq_thold
196
- ) {
197
- UNUSED(thiz);
198
194
 
199
- std::vector<float> samples(audio_data_len);
200
- jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr);
201
- for (int i = 0; i < audio_data_len; i++) {
202
- samples[i] = audio_data_arr[i];
195
+ struct whisper_full_params createFullParams(JNIEnv *env, jobject options) {
196
+ struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
197
+
198
+ params.print_realtime = false;
199
+ params.print_progress = false;
200
+ params.print_timestamps = false;
201
+ params.print_special = false;
202
+
203
+ int max_threads = std::thread::hardware_concurrency();
204
+ // Use 2 threads by default on 4-core devices, 4 threads on more cores
205
+ int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
206
+ int n_threads = readablemap::getInt(env, options, "maxThreads", default_n_threads);
207
+ params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
208
+ params.translate = readablemap::getBool(env, options, "translate", false);
209
+ params.speed_up = readablemap::getBool(env, options, "speedUp", false);
210
+ params.token_timestamps = readablemap::getBool(env, options, "tokenTimestamps", false);
211
+ params.offset_ms = 0;
212
+ params.no_context = true;
213
+ params.single_segment = false;
214
+
215
+ int beam_size = readablemap::getInt(env, options, "beamSize", -1);
216
+ if (beam_size > -1) {
217
+ params.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
218
+ params.beam_search.beam_size = beam_size;
203
219
  }
204
- bool is_speech = rn_whisper_vad_simple(samples, WHISPER_SAMPLE_RATE, 1000, vad_thold, vad_freq_thold, false);
205
- env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT);
206
- return is_speech;
220
+ int best_of = readablemap::getInt(env, options, "bestOf", -1);
221
+ if (best_of > -1) params.greedy.best_of = best_of;
222
+ int max_len = readablemap::getInt(env, options, "maxLen", -1);
223
+ if (max_len > -1) params.max_len = max_len;
224
+ int max_context = readablemap::getInt(env, options, "maxContext", -1);
225
+ if (max_context > -1) params.n_max_text_ctx = max_context;
226
+ int offset = readablemap::getInt(env, options, "offset", -1);
227
+ if (offset > -1) params.offset_ms = offset;
228
+ int duration = readablemap::getInt(env, options, "duration", -1);
229
+ if (duration > -1) params.duration_ms = duration;
230
+ int word_thold = readablemap::getInt(env, options, "wordThold", -1);
231
+ if (word_thold > -1) params.thold_pt = word_thold;
232
+ float temperature = readablemap::getFloat(env, options, "temperature", -1);
233
+ if (temperature > -1) params.temperature = temperature;
234
+ float temperature_inc = readablemap::getFloat(env, options, "temperatureInc", -1);
235
+ if (temperature_inc > -1) params.temperature_inc = temperature_inc;
236
+ jstring prompt = readablemap::getString(env, options, "prompt", nullptr);
237
+ if (prompt != nullptr) {
238
+ params.initial_prompt = env->GetStringUTFChars(prompt, nullptr);
239
+ env->DeleteLocalRef(prompt);
240
+ }
241
+ jstring language = readablemap::getString(env, options, "language", nullptr);
242
+ if (language != nullptr) {
243
+ params.language = env->GetStringUTFChars(language, nullptr);
244
+ env->DeleteLocalRef(language);
245
+ }
246
+ return params;
207
247
  }
208
248
 
209
249
  struct callback_context {
@@ -212,102 +252,23 @@ struct callback_context {
212
252
  };
213
253
 
214
254
  JNIEXPORT jint JNICALL
215
- Java_com_rnwhisper_WhisperContext_fullTranscribe(
255
+ Java_com_rnwhisper_WhisperContext_fullWithNewJob(
216
256
  JNIEnv *env,
217
257
  jobject thiz,
218
258
  jint job_id,
219
259
  jlong context_ptr,
220
260
  jfloatArray audio_data,
221
261
  jint audio_data_len,
222
- jint n_threads,
223
- jint max_context,
224
- int word_thold,
225
- int max_len,
226
- jboolean token_timestamps,
227
- jint offset,
228
- jint duration,
229
- jfloat temperature,
230
- jfloat temperature_inc,
231
- jint beam_size,
232
- jint best_of,
233
- jboolean speed_up,
234
- jboolean translate,
235
- jstring language,
236
- jstring prompt,
262
+ jobject options,
237
263
  jobject callback_instance
238
264
  ) {
239
265
  UNUSED(thiz);
240
266
  struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
241
267
  jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr);
242
268
 
243
- int max_threads = std::thread::hardware_concurrency();
244
- // Use 2 threads by default on 4-core devices, 4 threads on more cores
245
- int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
246
-
247
269
  LOGI("About to create params");
248
270
 
249
- struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
250
-
251
- if (beam_size > -1) {
252
- params.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
253
- params.beam_search.beam_size = beam_size;
254
- }
255
-
256
- params.print_realtime = false;
257
- params.print_progress = false;
258
- params.print_timestamps = false;
259
- params.print_special = false;
260
- params.translate = translate;
261
- const char *language_chars = env->GetStringUTFChars(language, nullptr);
262
- params.language = language_chars;
263
- params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
264
- params.speed_up = speed_up;
265
- params.offset_ms = 0;
266
- params.no_context = true;
267
- params.single_segment = false;
268
-
269
- if (max_len > -1) {
270
- params.max_len = max_len;
271
- }
272
- params.token_timestamps = token_timestamps;
273
-
274
- if (best_of > -1) {
275
- params.greedy.best_of = best_of;
276
- }
277
- if (max_context > -1) {
278
- params.n_max_text_ctx = max_context;
279
- }
280
- if (offset > -1) {
281
- params.offset_ms = offset;
282
- }
283
- if (duration > -1) {
284
- params.duration_ms = duration;
285
- }
286
- if (word_thold > -1) {
287
- params.thold_pt = word_thold;
288
- }
289
- if (temperature > -1) {
290
- params.temperature = temperature;
291
- }
292
- if (temperature_inc > -1) {
293
- params.temperature_inc = temperature_inc;
294
- }
295
- if (prompt != nullptr) {
296
- params.initial_prompt = env->GetStringUTFChars(prompt, nullptr);
297
- }
298
-
299
- // abort handlers
300
- bool* abort_ptr = rn_whisper_assign_abort_map(job_id);
301
- params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
302
- bool is_aborted = *(bool*)user_data;
303
- return !is_aborted;
304
- };
305
- params.encoder_begin_callback_user_data = abort_ptr;
306
- params.abort_callback = [](void * user_data) {
307
- bool is_aborted = *(bool*)user_data;
308
- return is_aborted;
309
- };
310
- params.abort_callback_user_data = abort_ptr;
271
+ whisper_full_params params = createFullParams(env, options);
311
272
 
312
273
  if (callback_instance != nullptr) {
313
274
  callback_context *cb_ctx = new callback_context;
@@ -335,6 +296,8 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
335
296
  params.new_segment_callback_user_data = cb_ctx;
336
297
  }
337
298
 
299
+ rnwhisper::job* job = rnwhisper::job_new(job_id, params);
300
+
338
301
  LOGI("About to reset timings");
339
302
  whisper_reset_timings(context);
340
303
 
@@ -344,11 +307,122 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
344
307
  // whisper_print_timings(context);
345
308
  }
346
309
  env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT);
347
- env->ReleaseStringUTFChars(language, language_chars);
348
- if (rn_whisper_transcribe_is_aborted(job_id)) {
349
- code = -999;
310
+
311
+ if (job->is_aborted()) code = -999;
312
+ rnwhisper::job_remove(job_id);
313
+ return code;
314
+ }
315
+
316
+ JNIEXPORT void JNICALL
317
+ Java_com_rnwhisper_WhisperContext_createRealtimeTranscribeJob(
318
+ JNIEnv *env,
319
+ jobject thiz,
320
+ jint job_id,
321
+ jlong context_ptr,
322
+ jobject options
323
+ ) {
324
+ whisper_full_params params = createFullParams(env, options);
325
+ rnwhisper::job* job = rnwhisper::job_new(job_id, params);
326
+ rnwhisper::vad_params vad;
327
+ vad.use_vad = readablemap::getBool(env, options, "useVad", false);
328
+ vad.vad_ms = readablemap::getInt(env, options, "vadMs", 2000);
329
+ vad.vad_thold = readablemap::getFloat(env, options, "vadThold", 0.6f);
330
+ vad.freq_thold = readablemap::getFloat(env, options, "vadFreqThold", 100.0f);
331
+
332
+ jstring audio_output_path = readablemap::getString(env, options, "audioOutputPath", nullptr);
333
+ const char* audio_output_path_str = nullptr;
334
+ if (audio_output_path != nullptr) {
335
+ audio_output_path_str = env->GetStringUTFChars(audio_output_path, nullptr);
336
+ env->DeleteLocalRef(audio_output_path);
337
+ }
338
+ job->set_realtime_params(
339
+ vad,
340
+ readablemap::getInt(env, options, "realtimeAudioSec", 0),
341
+ readablemap::getInt(env, options, "realtimeAudioSliceSec", 0),
342
+ audio_output_path_str
343
+ );
344
+ }
345
+
346
+ JNIEXPORT void JNICALL
347
+ Java_com_rnwhisper_WhisperContext_finishRealtimeTranscribeJob(
348
+ JNIEnv *env,
349
+ jobject thiz,
350
+ jint job_id,
351
+ jlong context_ptr,
352
+ jintArray slice_n_samples
353
+ ) {
354
+ UNUSED(env);
355
+ UNUSED(thiz);
356
+ UNUSED(context_ptr);
357
+
358
+ rnwhisper::job *job = rnwhisper::job_get(job_id);
359
+ if (job->audio_output_path != nullptr) {
360
+ RNWHISPER_LOG_INFO("job->params.language: %s\n", job->params.language);
361
+ std::vector<int> slice_n_samples_vec;
362
+ jint *slice_n_samples_arr = env->GetIntArrayElements(slice_n_samples, nullptr);
363
+ slice_n_samples_vec = std::vector<int>(slice_n_samples_arr, slice_n_samples_arr + env->GetArrayLength(slice_n_samples));
364
+ env->ReleaseIntArrayElements(slice_n_samples, slice_n_samples_arr, JNI_ABORT);
365
+
366
+ // TODO: Append in real time so we don't need to keep all slices & also reduce memory usage
367
+ rnaudioutils::save_wav_file(
368
+ rnaudioutils::concat_short_buffers(job->pcm_slices, slice_n_samples_vec),
369
+ job->audio_output_path
370
+ );
371
+ }
372
+ rnwhisper::job_remove(job_id);
373
+ }
374
+
375
+ JNIEXPORT jboolean JNICALL
376
+ Java_com_rnwhisper_WhisperContext_vadSimple(
377
+ JNIEnv *env,
378
+ jobject thiz,
379
+ jint job_id,
380
+ jint slice_index,
381
+ jint n_samples,
382
+ jint n
383
+ ) {
384
+ UNUSED(thiz);
385
+ rnwhisper::job* job = rnwhisper::job_get(job_id);
386
+ return job->vad_simple(slice_index, n_samples, n);
387
+ }
388
+
389
+ JNIEXPORT void JNICALL
390
+ Java_com_rnwhisper_WhisperContext_putPcmData(
391
+ JNIEnv *env,
392
+ jobject thiz,
393
+ jint job_id,
394
+ jshortArray pcm,
395
+ jint slice_index,
396
+ jint n_samples,
397
+ jint n
398
+ ) {
399
+ UNUSED(thiz);
400
+ rnwhisper::job* job = rnwhisper::job_get(job_id);
401
+ jshort *pcm_arr = env->GetShortArrayElements(pcm, nullptr);
402
+ job->put_pcm_data(pcm_arr, slice_index, n_samples, n);
403
+ env->ReleaseShortArrayElements(pcm, pcm_arr, JNI_ABORT);
404
+ }
405
+
406
+ JNIEXPORT jint JNICALL
407
+ Java_com_rnwhisper_WhisperContext_fullWithJob(
408
+ JNIEnv *env,
409
+ jobject thiz,
410
+ jint job_id,
411
+ jlong context_ptr,
412
+ jint slice_index,
413
+ jint n_samples
414
+ ) {
415
+ UNUSED(thiz);
416
+ struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
417
+
418
+ rnwhisper::job* job = rnwhisper::job_get(job_id);
419
+ float* pcmf32 = job->pcm_slice_to_f32(slice_index, n_samples);
420
+ int code = whisper_full(context, job->params, pcmf32, n_samples);
421
+ free(pcmf32);
422
+ if (code == 0) {
423
+ // whisper_print_timings(context);
350
424
  }
351
- rn_whisper_remove_abort_map(job_id);
425
+ if (job->is_aborted()) code = -999;
352
426
  return code;
353
427
  }
354
428
 
@@ -359,7 +433,8 @@ Java_com_rnwhisper_WhisperContext_abortTranscribe(
359
433
  jint job_id
360
434
  ) {
361
435
  UNUSED(thiz);
362
- rn_whisper_abort_transcribe(job_id);
436
+ rnwhisper::job *job = rnwhisper::job_get(job_id);
437
+ if (job) job->abort();
363
438
  }
364
439
 
365
440
  JNIEXPORT void JNICALL
@@ -368,7 +443,7 @@ Java_com_rnwhisper_WhisperContext_abortAllTranscribe(
368
443
  jobject thiz
369
444
  ) {
370
445
  UNUSED(thiz);
371
- rn_whisper_abort_all_transcribe();
446
+ rnwhisper::job_abort_all();
372
447
  }
373
448
 
374
449
  JNIEXPORT jint JNICALL
package/cpp/README.md CHANGED
@@ -1,4 +1,4 @@
1
1
  # Note
2
2
 
3
- - Only `rn-whisper.h` / `rn-whisper.cpp` are the specific files for this project, others are sync from [whisper.cpp](https://github.com/ggerganov/whisper.cpp).
3
+ - Only `rn-*` are the specific files for this project, others are sync from [whisper.cpp](https://github.com/ggerganov/whisper.cpp).
4
4
  - We can update the native source by using the [bootstrap](../scripts/bootstrap.sh) script.
@@ -123,7 +123,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
123
123
 
124
124
  /**
125
125
  Make a prediction using the convenience interface
126
- @param logmel_data as 1 × 80 × 3000 3-dimensional array of floats:
126
+ @param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats:
127
127
  @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
128
128
  @return the prediction as whisper_encoder_implOutput
129
129
  */
@@ -3,6 +3,8 @@
3
3
  // Code is derived from the work of Github user @wangchou
4
4
  // ref: https://github.com/wangchou/callCoreMLFromCpp
5
5
 
6
+ #include <stdint.h>
7
+
6
8
  #if __cplusplus
7
9
  extern "C" {
8
10
  #endif
@@ -14,6 +16,8 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx);
14
16
 
15
17
  void whisper_coreml_encode(
16
18
  const whisper_coreml_context * ctx,
19
+ int64_t n_ctx,
20
+ int64_t n_mel,
17
21
  float * mel,
18
22
  float * out);
19
23
 
@@ -48,13 +48,15 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx) {
48
48
 
49
49
  void whisper_coreml_encode(
50
50
  const whisper_coreml_context * ctx,
51
+ int64_t n_ctx,
52
+ int64_t n_mel,
51
53
  float * mel,
52
54
  float * out) {
53
55
  MLMultiArray * inMultiArray = [
54
56
  [MLMultiArray alloc] initWithDataPointer: mel
55
- shape: @[@1, @80, @3000]
57
+ shape: @[@1, @(n_mel), @(n_ctx)]
56
58
  dataType: MLMultiArrayDataTypeFloat32
57
- strides: @[@(240000), @(3000), @1]
59
+ strides: @[@(n_ctx*n_mel), @(n_ctx), @1]
58
60
  deallocator: nil
59
61
  error: nil
60
62
  ];
package/cpp/ggml-alloc.c CHANGED
@@ -137,7 +137,7 @@ void wsp_ggml_tallocr_alloc(wsp_ggml_tallocr_t alloc, struct wsp_ggml_tensor * t
137
137
 
138
138
  #ifdef WSP_GGML_ALLOCATOR_DEBUG
139
139
  add_allocated_tensor(alloc, tensor);
140
- size_t cur_max = (char*)addr - (char*)alloc->data + size;
140
+ size_t cur_max = (char*)addr - (char*)alloc->base + size;
141
141
  if (cur_max > alloc->max_size) {
142
142
  printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
143
143
  for (int i = 0; i < 1024; i++) {
@@ -168,10 +168,6 @@ static void wsp_ggml_tallocr_free_tensor(wsp_ggml_tallocr_t alloc, struct wsp_gg
168
168
  size = aligned_offset(NULL, size, alloc->alignment);
169
169
  AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
170
170
 
171
- if (!alloc->measure) {
172
- wsp_ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
173
- }
174
-
175
171
  #ifdef WSP_GGML_ALLOCATOR_DEBUG
176
172
  remove_allocated_tensor(alloc, tensor);
177
173
  #endif
@@ -237,7 +233,7 @@ void wsp_ggml_tallocr_reset(wsp_ggml_tallocr_t alloc) {
237
233
  }
238
234
 
239
235
  wsp_ggml_tallocr_t wsp_ggml_tallocr_new(void * data, size_t size, size_t alignment) {
240
- struct wsp_ggml_backend_buffer * buffer = wsp_ggml_backend_cpu_buffer_from_ptr(NULL, data, size);
236
+ struct wsp_ggml_backend_buffer * buffer = wsp_ggml_backend_cpu_buffer_from_ptr(data, size);
241
237
 
242
238
  wsp_ggml_tallocr_t alloc = (wsp_ggml_tallocr_t)malloc(sizeof(struct wsp_ggml_tallocr));
243
239
 
@@ -446,18 +442,19 @@ static wsp_ggml_tallocr_t node_tallocr(wsp_ggml_gallocr_t galloc, struct wsp_ggm
446
442
  return galloc->hash_allocs[wsp_ggml_hash_find_or_insert(galloc->hash_set, node)];
447
443
  }
448
444
 
449
- static void init_view(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * view) {
445
+ static void init_view(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * view, bool update_backend) {
450
446
  wsp_ggml_tallocr_t alloc = node_tallocr(galloc, view);
451
447
 
452
- //printf("init_view: %s from src %s\n", view->name, view->view_src->name);
453
448
  WSP_GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
454
- view->backend = view->view_src->backend;
449
+ if (update_backend) {
450
+ view->backend = view->view_src->backend;
451
+ }
455
452
  view->buffer = view->view_src->buffer;
456
453
  view->data = (char *)view->view_src->data + view->view_offs;
457
454
 
458
455
  // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
459
456
  // due to the wsp_ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
460
- assert(wsp_ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
457
+ assert(wsp_ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft);
461
458
 
462
459
  if (!alloc->measure) {
463
460
  wsp_ggml_backend_buffer_init_tensor(alloc->buffer, view);
@@ -469,7 +466,7 @@ static void allocate_node(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * no
469
466
 
470
467
  if (node->data == NULL) {
471
468
  if (wsp_ggml_is_view(node)) {
472
- init_view(galloc, node);
469
+ init_view(galloc, node, true);
473
470
  } else {
474
471
  // see if we can reuse a parent's buffer (inplace)
475
472
  if (wsp_ggml_op_can_inplace(node->op)) {
@@ -499,15 +496,14 @@ static void allocate_node(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * no
499
496
  AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
500
497
  node->view_src = view_src;
501
498
  view_src_hn->n_views += 1;
502
- init_view(galloc, node);
499
+ init_view(galloc, node, false);
503
500
  return;
504
501
  }
505
- }
506
- else {
502
+ } else {
507
503
  AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
508
504
  node->view_src = parent;
509
505
  p_hn->n_views += 1;
510
- init_view(galloc, node);
506
+ init_view(galloc, node, false);
511
507
  return;
512
508
  }
513
509
  }
@@ -537,7 +533,7 @@ static void wsp_ggml_tallocr_alloc_graph_impl(wsp_ggml_gallocr_t galloc, struct
537
533
  hash_get(galloc, view_src)->n_views += 1;
538
534
  if (node->buffer == NULL && node->data != NULL) {
539
535
  // view of a pre-allocated tensor, didn't call init_view() yet
540
- init_view(galloc, node);
536
+ init_view(galloc, node, true);
541
537
  }
542
538
  }
543
539
 
@@ -548,7 +544,7 @@ static void wsp_ggml_tallocr_alloc_graph_impl(wsp_ggml_gallocr_t galloc, struct
548
544
  }
549
545
  hash_get(galloc, parent)->n_children += 1;
550
546
  if (wsp_ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) {
551
- init_view(galloc, parent);
547
+ init_view(galloc, parent, true);
552
548
  }
553
549
  }
554
550
  }
@@ -663,7 +659,7 @@ size_t wsp_ggml_gallocr_alloc_graph(wsp_ggml_gallocr_t galloc, wsp_ggml_tallocr_
663
659
  return max_size;
664
660
  }
665
661
 
666
- void wsp_ggml_gallocr_alloc_graph_n(wsp_ggml_gallocr_t galloc, struct wsp_ggml_cgraph * graph, struct wsp_ggml_hash_set hash_set, wsp_ggml_tallocr_t * hash_node_alloct) {
662
+ void wsp_ggml_gallocr_alloc_graph_n(wsp_ggml_gallocr_t galloc, struct wsp_ggml_cgraph * graph, struct wsp_ggml_hash_set hash_set, wsp_ggml_tallocr_t * hash_node_talloc) {
667
663
  const size_t hash_size = hash_set.size;
668
664
 
669
665
  WSP_GGML_ASSERT(hash_size >= (size_t)(graph->n_nodes + graph->n_leafs));
@@ -686,7 +682,7 @@ void wsp_ggml_gallocr_alloc_graph_n(wsp_ggml_gallocr_t galloc, struct wsp_ggml_c
686
682
  // reset hash values
687
683
  memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size);
688
684
 
689
- galloc->hash_allocs = hash_node_alloct;
685
+ galloc->hash_allocs = hash_node_talloc;
690
686
 
691
687
  wsp_ggml_tallocr_alloc_graph_impl(galloc, graph);
692
688
 
@@ -764,3 +760,43 @@ size_t wsp_ggml_allocr_max_size(wsp_ggml_allocr_t alloc) {
764
760
  size_t wsp_ggml_allocr_alloc_graph(wsp_ggml_allocr_t alloc, struct wsp_ggml_cgraph * graph) {
765
761
  return wsp_ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
766
762
  }
763
+
764
+ // utils
765
+ wsp_ggml_backend_buffer_t wsp_ggml_backend_alloc_ctx_tensors_from_buft(struct wsp_ggml_context * ctx, wsp_ggml_backend_buffer_type_t buft) {
766
+ WSP_GGML_ASSERT(wsp_ggml_get_no_alloc(ctx) == true);
767
+
768
+ size_t alignment = wsp_ggml_backend_buft_get_alignment(buft);
769
+
770
+ size_t nbytes = 0;
771
+ for (struct wsp_ggml_tensor * t = wsp_ggml_get_first_tensor(ctx); t != NULL; t = wsp_ggml_get_next_tensor(ctx, t)) {
772
+ if (t->data == NULL && t->view_src == NULL) {
773
+ nbytes += WSP_GGML_PAD(wsp_ggml_backend_buft_get_alloc_size(buft, t), alignment);
774
+ }
775
+ }
776
+
777
+ if (nbytes == 0) {
778
+ fprintf(stderr, "%s: no tensors to allocate\n", __func__);
779
+ return NULL;
780
+ }
781
+
782
+ wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_buft_alloc_buffer(buft, nbytes);
783
+ wsp_ggml_tallocr_t tallocr = wsp_ggml_tallocr_new_from_buffer(buffer);
784
+
785
+ for (struct wsp_ggml_tensor * t = wsp_ggml_get_first_tensor(ctx); t != NULL; t = wsp_ggml_get_next_tensor(ctx, t)) {
786
+ if (t->data == NULL) {
787
+ if (t->view_src == NULL) {
788
+ wsp_ggml_tallocr_alloc(tallocr, t);
789
+ } else {
790
+ wsp_ggml_backend_view_init(buffer, t);
791
+ }
792
+ }
793
+ }
794
+
795
+ wsp_ggml_tallocr_free(tallocr);
796
+
797
+ return buffer;
798
+ }
799
+
800
+ wsp_ggml_backend_buffer_t wsp_ggml_backend_alloc_ctx_tensors(struct wsp_ggml_context * ctx, wsp_ggml_backend_t backend) {
801
+ return wsp_ggml_backend_alloc_ctx_tensors_from_buft(ctx, wsp_ggml_backend_get_default_buffer_type(backend));
802
+ }
package/cpp/ggml-alloc.h CHANGED
@@ -8,6 +8,7 @@ extern "C" {
8
8
 
9
9
  struct wsp_ggml_backend;
10
10
  struct wsp_ggml_backend_buffer;
11
+ struct wsp_ggml_backend_buffer_type;
11
12
 
12
13
  //
13
14
  // Legacy API
@@ -80,6 +81,12 @@ WSP_GGML_API void wsp_ggml_gallocr_alloc_graph_n(
80
81
  struct wsp_ggml_hash_set hash_set,
81
82
  wsp_ggml_tallocr_t * hash_node_talloc);
82
83
 
84
+
85
+ // Utils
86
+ // Create a buffer and allocate all the tensors in a wsp_ggml_context
87
+ WSP_GGML_API struct wsp_ggml_backend_buffer * wsp_ggml_backend_alloc_ctx_tensors_from_buft(struct wsp_ggml_context * ctx, struct wsp_ggml_backend_buffer_type * buft);
88
+ WSP_GGML_API struct wsp_ggml_backend_buffer * wsp_ggml_backend_alloc_ctx_tensors(struct wsp_ggml_context * ctx, struct wsp_ggml_backend * backend);
89
+
83
90
  #ifdef __cplusplus
84
91
  }
85
92
  #endif