whisper.rn 0.5.1 → 0.5.2

Files changed (66)
  1. package/android/src/main/jni.cpp +12 -3
  2. package/cpp/ggml-alloc.c +38 -14
  3. package/cpp/ggml-backend-impl.h +0 -3
  4. package/cpp/ggml-backend.h +2 -0
  5. package/cpp/ggml-cpu/amx/amx.cpp +1 -0
  6. package/cpp/ggml-cpu/ggml-cpu-impl.h +1 -1
  7. package/cpp/ggml-cpu/ggml-cpu.c +17 -3
  8. package/cpp/ggml-cpu/ops.cpp +33 -17
  9. package/cpp/ggml-cpu/unary-ops.cpp +135 -0
  10. package/cpp/ggml-cpu/unary-ops.h +5 -0
  11. package/cpp/ggml-cpu/vec.cpp +66 -0
  12. package/cpp/ggml-cpu/vec.h +10 -8
  13. package/cpp/ggml-impl.h +51 -2
  14. package/cpp/ggml-metal/ggml-metal-common.cpp +2 -2
  15. package/cpp/ggml-metal/ggml-metal-device.cpp +199 -10
  16. package/cpp/ggml-metal/ggml-metal-device.h +18 -0
  17. package/cpp/ggml-metal/ggml-metal-device.m +27 -14
  18. package/cpp/ggml-metal/ggml-metal-impl.h +87 -7
  19. package/cpp/ggml-metal/ggml-metal-ops.cpp +513 -88
  20. package/cpp/ggml-metal/ggml-metal-ops.h +6 -0
  21. package/cpp/ggml-metal/ggml-metal.cpp +3 -3
  22. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  23. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  24. package/cpp/ggml.c +166 -2
  25. package/cpp/ggml.h +66 -0
  26. package/cpp/jsi/RNWhisperJSI.cpp +7 -2
  27. package/cpp/rn-whisper.h +1 -0
  28. package/cpp/whisper.cpp +4 -2
  29. package/ios/RNWhisperContext.mm +3 -1
  30. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  31. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  32. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +51 -2
  33. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +66 -0
  34. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  35. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  36. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  37. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  38. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  39. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +51 -2
  40. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +66 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  42. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  43. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  44. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  45. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  46. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +51 -2
  47. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +66 -0
  48. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  49. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  50. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  51. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  52. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  53. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +51 -2
  54. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +66 -0
  55. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  56. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  57. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  58. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  59. package/lib/commonjs/version.json +1 -1
  60. package/lib/module/NativeRNWhisper.js.map +1 -1
  61. package/lib/module/version.json +1 -1
  62. package/lib/typescript/NativeRNWhisper.d.ts +2 -0
  63. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  64. package/package.json +1 -1
  65. package/src/NativeRNWhisper.ts +2 -0
  66. package/src/version.json +1 -1
package/android/src/main/jni.cpp CHANGED
@@ -276,6 +276,7 @@ JNIEXPORT jlong JNICALL
  Java_com_rnwhisper_WhisperContext_initContextWithAsset(
  JNIEnv *env,
  jobject thiz,
+ jint context_id,
  jobject asset_manager,
  jstring model_path_str
  ) {
@@ -290,6 +291,7 @@ Java_com_rnwhisper_WhisperContext_initContextWithAsset(
  const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
  context = whisper_init_from_asset(env, asset_manager, model_path_chars, cparams);
  env->ReleaseStringUTFChars(model_path_str, model_path_chars);
+ rnwhisper_jsi::addContext(context_id, reinterpret_cast<jlong>(context));
  return reinterpret_cast<jlong>(context);
  }

@@ -297,6 +299,7 @@ JNIEXPORT jlong JNICALL
  Java_com_rnwhisper_WhisperContext_initContextWithInputStream(
  JNIEnv *env,
  jobject thiz,
+ jint context_id,
  jobject input_stream
  ) {
  UNUSED(thiz);
@@ -308,6 +311,7 @@ Java_com_rnwhisper_WhisperContext_initContextWithInputStream(

  struct whisper_context *context = nullptr;
  context = whisper_init_from_input_stream(env, input_stream, cparams);
+ rnwhisper_jsi::addContext(context_id, reinterpret_cast<jlong>(context));
  return reinterpret_cast<jlong>(context);
  }

@@ -421,8 +425,9 @@ Java_com_rnwhisper_WhisperContext_fullWithNewJob(
  LOGI("About to reset timings");
  whisper_reset_timings(context);

- LOGI("About to run whisper_full");
- int code = whisper_full(context, params, audio_data_arr, audio_data_len);
+ int n_processors = readablemap::getInt(env, options, "nProcessors", 1);
+ LOGI("About to run whisper_full_parallel with n_processors=%d", n_processors);
+ int code = whisper_full_parallel(context, params, audio_data_arr, audio_data_len, n_processors);
  if (code == 0) {
  // whisper_print_timings(context);
  }
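whisper_full_parallel splits the input audio into n_processors roughly equal chunks, transcribes each on its own whisper state, and merges the results; accuracy can degrade near chunk boundaries, and n_processors == 1 is equivalent to the old whisper_full call. For reference, its declaration in whisper.h (unchanged by this diff):

    // Declared in whisper.h; shown here for context only.
    WHISPER_API int whisper_full_parallel(
                   struct whisper_context * ctx,
                   struct whisper_full_params params,
                   const float * samples,
                   int n_samples,
                   int n_processors);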
@@ -441,8 +446,11 @@ Java_com_rnwhisper_WhisperContext_createRealtimeTranscribeJob(
  jlong context_ptr,
  jobject options
  ) {
+ UNUSED(thiz);
+ UNUSED(context_ptr);
  whisper_full_params params = createFullParams(env, options);
  rnwhisper::job* job = rnwhisper::job_new(job_id, params);
+ job->n_processors = readablemap::getInt(env, options, "nProcessors", 1);
  rnwhisper::vad_params vad;
  vad.use_vad = readablemap::getBool(env, options, "useVad", false);
  vad.vad_ms = readablemap::getInt(env, options, "vadMs", 2000);
@@ -534,11 +542,12 @@ Java_com_rnwhisper_WhisperContext_fullWithJob(
  jint n_samples
  ) {
  UNUSED(thiz);
+ UNUSED(env);
  struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);

  rnwhisper::job* job = rnwhisper::job_get(job_id);
  float* pcmf32 = job->pcm_slice_to_f32(slice_index, n_samples);
- int code = whisper_full(context, job->params, pcmf32, n_samples);
+ int code = whisper_full_parallel(context, job->params, pcmf32, n_samples, job->n_processors);
  free(pcmf32);
  if (code == 0) {
  // whisper_print_timings(context);
package/cpp/ggml-alloc.c CHANGED
@@ -392,12 +392,8 @@ static void wsp_ggml_dyn_tallocr_free(struct wsp_ggml_dyn_tallocr * alloc) {
  free(alloc);
  }

- static size_t wsp_ggml_dyn_tallocr_max_size(struct wsp_ggml_dyn_tallocr * alloc) {
- size_t max_size = 0;
- for (int i = 0; i < alloc->n_chunks; i++) {
- max_size += alloc->chunks[i]->max_size;
- }
- return max_size;
+ static size_t wsp_ggml_dyn_tallocr_max_size(struct wsp_ggml_dyn_tallocr * alloc, int chunk) {
+ return chunk < alloc->n_chunks ? alloc->chunks[chunk]->max_size : 0;
  }

@@ -417,10 +413,8 @@ static void wsp_ggml_vbuffer_free(struct vbuffer * buf) {
  free(buf);
  }

- static int wsp_ggml_vbuffer_n_chunks(struct vbuffer * buf) {
- int n = 0;
- while (n < WSP_GGML_VBUFFER_MAX_CHUNKS && buf->chunks[n]) n++;
- return n;
+ static size_t wsp_ggml_vbuffer_chunk_size(struct vbuffer * buf, int chunk) {
+ return buf->chunks[chunk] ? wsp_ggml_backend_buffer_get_size(buf->chunks[chunk]) : 0;
  }

  static size_t wsp_ggml_vbuffer_size(struct vbuffer * buf) {
@@ -604,6 +598,26 @@ static bool wsp_ggml_gallocr_is_allocated(wsp_ggml_gallocr_t galloc, struct wsp_
  return t->data != NULL || wsp_ggml_gallocr_hash_get(galloc, t)->allocated;
  }

+ // free the extra space at the end if the new tensor is smaller
+ static void wsp_ggml_gallocr_free_extra_space(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node, struct wsp_ggml_tensor * parent) {
+ struct hash_node * hn = wsp_ggml_gallocr_hash_get(galloc, node);
+ struct hash_node * p_hn = wsp_ggml_gallocr_hash_get(galloc, parent);
+
+ size_t parent_size = wsp_ggml_backend_buft_get_alloc_size(galloc->bufts[p_hn->buffer_id], parent);
+ size_t node_size = wsp_ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node);
+
+ WSP_GGML_ASSERT(parent_size >= node_size);
+
+ if (parent_size > node_size) {
+ struct wsp_ggml_dyn_tallocr * p_alloc = galloc->buf_tallocs[p_hn->buffer_id];
+ struct buffer_address p_addr = p_hn->addr;
+ p_addr.offset += node_size;
+ size_t extra_size = parent_size - node_size;
+ AT_PRINTF("freeing extra %zu bytes from parent %s for %s\n", extra_size, parent->name, node->name);
+ wsp_ggml_dyn_tallocr_free_tensor(p_alloc, p_addr, extra_size, parent);
+ }
+ }
+
  static void wsp_ggml_gallocr_allocate_node(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node, int buffer_id) {
  WSP_GGML_ASSERT(buffer_id >= 0);
  struct hash_node * hn = wsp_ggml_gallocr_hash_get(galloc, node);
@@ -649,6 +663,7 @@ static void wsp_ggml_gallocr_allocate_node(wsp_ggml_gallocr_t galloc, struct wsp
  hn->addr = p_hn->addr;
  p_hn->allocated = false; // avoid freeing the parent
  view_src_hn->allocated = false;
+ wsp_ggml_gallocr_free_extra_space(galloc, node, view_src);
  return;
  }
  } else {
@@ -656,6 +671,7 @@ static void wsp_ggml_gallocr_allocate_node(wsp_ggml_gallocr_t galloc, struct wsp
  hn->buffer_id = p_hn->buffer_id;
  hn->addr = p_hn->addr;
  p_hn->allocated = false; // avoid freeing the parent
+ wsp_ggml_gallocr_free_extra_space(galloc, node, parent);
  return;
  }
  }
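wsp_ggml_gallocr_free_extra_space returns the unused tail of an in-place reuse to the dynamic allocator, so a smaller node no longer pins its parent's full allocation. A minimal sketch of the arithmetic, with invented sizes:

    #include <assert.h>
    #include <stddef.h>

    // Illustration only (sizes are made up): a 1024-byte parent reused in
    // place by a 256-byte node leaves a 768-byte tail right after the node,
    // which is handed back to the allocator.
    static void tail_example(void) {
        size_t parent_size = 1024, node_size = 256, parent_offset = 4096;
        assert(parent_size >= node_size);
        size_t tail_offset = parent_offset + node_size; // 4352
        size_t tail_size   = parent_size - node_size;   // 768 bytes freed
        (void)tail_offset; (void)tail_size;
    }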
@@ -885,12 +901,20 @@ bool wsp_ggml_gallocr_reserve_n(wsp_ggml_gallocr_t galloc, struct wsp_ggml_cgrap
  }
  }

- size_t cur_size = galloc->buffers[i] ? wsp_ggml_vbuffer_size(galloc->buffers[i]) : 0;
- size_t new_size = wsp_ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]);
-
  // even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views
- if (new_size > cur_size || galloc->buffers[i] == NULL) {
+ bool realloc = galloc->buffers[i] == NULL;
+ size_t new_size = 0;
+ for (int c = 0; c < galloc->buf_tallocs[i]->n_chunks; c++) {
+ size_t cur_chunk_size = galloc->buffers[i] ? wsp_ggml_vbuffer_chunk_size(galloc->buffers[i], c) : 0;
+ size_t new_chunk_size = wsp_ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i], c);
+ new_size += new_chunk_size;
+ if (new_chunk_size > cur_chunk_size) {
+ realloc = true;
+ }
+ }
+ if (realloc) {
  #ifndef NDEBUG
+ size_t cur_size = galloc->buffers[i] ? wsp_ggml_vbuffer_size(galloc->buffers[i]) : 0;
  WSP_GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, wsp_ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
  #endif
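The reallocation test is now per chunk rather than on the summed size: a buffer whose total requirement stays flat can still need reallocating if any single chunk grows. A minimal sketch of the check, with invented sizes:

    #include <stdbool.h>
    #include <stddef.h>

    // Sketch of the per-chunk growth test above; the sizes are illustrative.
    static bool needs_realloc(const size_t *cur, const size_t *req, int n_chunks) {
        for (int c = 0; c < n_chunks; c++) {
            if (req[c] > cur[c]) return true; // any growing chunk forces a reallocation
        }
        return false;
    }

    // needs_realloc((size_t[]){60, 40}, (size_t[]){50, 50}, 2) returns true:
    // the totals match (100), but the second chunk grew from 40 to 50.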
package/cpp/ggml-backend-impl.h CHANGED
@@ -209,9 +209,6 @@ extern "C" {
  void * context;
  };

- // Internal backend registry API
- WSP_GGML_API void wsp_ggml_backend_register(wsp_ggml_backend_reg_t reg);
-
  // Add backend dynamic loading support to the backend

  // Initialize the backend

package/cpp/ggml-backend.h CHANGED
@@ -215,6 +215,8 @@ extern "C" {
  // Backend registry
  //

+ WSP_GGML_API void wsp_ggml_backend_register(wsp_ggml_backend_reg_t reg);
+
  WSP_GGML_API void wsp_ggml_backend_device_register(wsp_ggml_backend_dev_t device);

  // Backend (reg) enumeration

package/cpp/ggml-cpu/amx/amx.cpp CHANGED
@@ -149,6 +149,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
  if (op->op == WSP_GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous
  is_contiguous_2d(op->src[1]) && // src1 must be contiguous
  op->src[0]->buffer && op->src[0]->buffer->buft == wsp_ggml_backend_amx_buffer_type() &&
+ op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315)
  op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x
  (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == WSP_GGML_TYPE_F16))) {
  // src1 must be host buffer

package/cpp/ggml-cpu/ggml-cpu-impl.h CHANGED
@@ -68,7 +68,7 @@ struct wsp_ggml_compute_params {
  #endif // __VXE2__
  #endif // __s390x__ && __VEC__

- #if defined(__ARM_FEATURE_SVE)
+ #if defined(__ARM_FEATURE_SVE) && defined(__linux__)
  #include <sys/prctl.h>
  #endif

package/cpp/ggml-cpu/ggml-cpu.c CHANGED
@@ -689,8 +689,13 @@ bool wsp_ggml_is_numa(void) {
  #endif

  static void wsp_ggml_init_arm_arch_features(void) {
- #if defined(__linux__) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
+ #if defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
+ #if defined(__linux__)
  wsp_ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
+ #else
+ // TODO: add support of SVE for non-linux systems
+ #error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here."
+ #endif
  #endif
  }

@@ -2179,6 +2184,10 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
  case WSP_GGML_UNARY_OP_HARDSWISH:
  case WSP_GGML_UNARY_OP_HARDSIGMOID:
  case WSP_GGML_UNARY_OP_EXP:
+ case WSP_GGML_UNARY_OP_FLOOR:
+ case WSP_GGML_UNARY_OP_CEIL:
+ case WSP_GGML_UNARY_OP_ROUND:
+ case WSP_GGML_UNARY_OP_TRUNC:
  {
  n_tasks = 1;
  } break;
@@ -2187,6 +2196,7 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
  case WSP_GGML_UNARY_OP_GELU_ERF:
  case WSP_GGML_UNARY_OP_GELU_QUICK:
  case WSP_GGML_UNARY_OP_SILU:
+ case WSP_GGML_UNARY_OP_XIELU:
  {
  n_tasks = n_threads;
  } break;
@@ -3557,13 +3567,17 @@ void wsp_ggml_cpu_init(void) {
  #ifdef WSP_GGML_USE_OPENMP
  //if (!getenv("OMP_WAIT_POLICY")) {
  // // set the wait policy to active, so that OpenMP threads don't sleep
- // putenv("OMP_WAIT_POLICY=active");
+ // setenv("OMP_WAIT_POLICY", "active", 0)
  //}

  if (!getenv("KMP_BLOCKTIME")) {
  // set the time to wait before sleeping a thread
  // this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases
- putenv("KMP_BLOCKTIME=200"); // 200ms
+ #ifdef _WIN32
+ _putenv_s("KMP_BLOCKTIME", "200"); // 200ms
+ #else
+ setenv("KMP_BLOCKTIME", "200", 0); // 200ms
+ #endif
  }
  #endif
  }
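Unlike putenv, setenv with overwrite = 0 leaves an existing value untouched, so a user-supplied KMP_BLOCKTIME now takes precedence over the library default (on Windows, _putenv_s overwrites unconditionally). A minimal POSIX illustration:

    #include <stdio.h>
    #include <stdlib.h>

    int main(void) {
        setenv("KMP_BLOCKTIME", "100", 1);       // simulate a value set by the user
        setenv("KMP_BLOCKTIME", "200", 0);       // overwrite = 0: existing value kept
        printf("%s\n", getenv("KMP_BLOCKTIME")); // prints 100
        return 0;
    }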

package/cpp/ggml-cpu/ops.cpp CHANGED
@@ -3467,31 +3467,27 @@ static void wsp_ggml_compute_forward_norm_f32(

  WSP_GGML_ASSERT(eps >= 0.0f);

- // TODO: optimize
  for (int64_t i03 = 0; i03 < ne03; i03++) {
  for (int64_t i02 = 0; i02 < ne02; i02++) {
  for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

- wsp_ggml_float sum = 0.0;
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- sum += (wsp_ggml_float)x[i00];
- }
-
+ float sum = 0.0;
+ wsp_ggml_vec_sum_f32(ne00, &sum, x);
  float mean = sum/ne00;

  float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+ float variance = 0;

- wsp_ggml_float sum2 = 0.0;
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- float v = x[i00] - mean;
- y[i00] = v;
- sum2 += (wsp_ggml_float)(v*v);
- }
+ #ifdef WSP_GGML_USE_ACCELERATE
+ mean = -mean;
+ vDSP_vsadd(x, 1, &mean, y, 1, ne00);
+ vDSP_measqv(y, 1, &variance, ne00);
+ #else
+ variance = wsp_ggml_vec_cvar_f32(ne00, y, x, mean);
+ #endif //WSP_GGML_USE_ACCELERATE

- float variance = sum2/ne00;
  const float scale = 1.0f/sqrtf(variance + eps);
-
  wsp_ggml_vec_scale_f32(ne00, y, scale);
  }
  }
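Both branches compute the same standard normalization; on the Accelerate path, vDSP_vsadd adds the negated mean to every element and vDSP_measqv returns the mean of the squared, now-centered values, which is exactly the variance:

\[
\mu = \frac{1}{n}\sum_i x_i, \qquad
\sigma^2 = \frac{1}{n}\sum_i (x_i - \mu)^2, \qquad
y_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}.
\]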
@@ -8135,7 +8131,7 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
  }

  // V /= S
- const float S_inv = 1.0f/S;
+ const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
  wsp_ggml_vec_scale_f32(DV, VKQ32, S_inv);

  // dst indices
@@ -8637,7 +8633,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
  // n_head
  for (int h = ih0; h < ih1; ++h) {
  // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
- const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+ const float dt_soft_plus = wsp_ggml_softplus(dt[h]);
  const float dA = expf(dt_soft_plus * A[h]);
  const int g = h / (nh / ng); // repeat_interleave

@@ -8734,7 +8730,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
  // n_head
  for (int h = ih0; h < ih1; ++h) {
  // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
- const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+ const float dt_soft_plus = wsp_ggml_softplus(dt[h]);
  const int g = h / (nh / ng); // repeat_interleave

  // dim
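wsp_ggml_softplus itself lands in ggml-impl.h (the +51 −2 change listed above, not shown in this excerpt); judging from the inline expression it replaces, it presumably keeps the same numerically stable form:

    #include <math.h>

    // Hedged sketch, inferred from the removed inline code; the actual helper
    // is defined in the ggml-impl.h changes not shown here.
    static inline float wsp_ggml_softplus(float x) {
        // log1pf(expf(x)) would overflow for large x, where softplus(x) ~= x anyway
        return x <= 20.0f ? log1pf(expf(x)) : x;
    }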
@@ -8997,6 +8993,26 @@ void wsp_ggml_compute_forward_unary(
  {
  wsp_ggml_compute_forward_exp(params, dst);
  } break;
+ case WSP_GGML_UNARY_OP_FLOOR:
+ {
+ wsp_ggml_compute_forward_floor(params, dst);
+ } break;
+ case WSP_GGML_UNARY_OP_CEIL:
+ {
+ wsp_ggml_compute_forward_ceil(params, dst);
+ } break;
+ case WSP_GGML_UNARY_OP_ROUND:
+ {
+ wsp_ggml_compute_forward_round(params, dst);
+ } break;
+ case WSP_GGML_UNARY_OP_TRUNC:
+ {
+ wsp_ggml_compute_forward_trunc(params, dst);
+ } break;
+ case WSP_GGML_UNARY_OP_XIELU:
+ {
+ wsp_ggml_compute_forward_xielu(params, dst);
+ } break;
  default:
  {
  WSP_GGML_ABORT("fatal error");

package/cpp/ggml-cpu/unary-ops.cpp CHANGED
@@ -52,6 +52,15 @@ static inline float op_sqrt(float x) {
  return sqrtf(x);
  }

+ static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) {
+ if (x > 0.0f) {
+ return alpha_p * x * x + beta * x;
+ } else {
+ const float min_x_eps = fminf(x, eps);
+ return (expm1f(min_x_eps) - x) * alpha_n + beta * x;
+ }
+ }
+
  static inline float op_sin(float x) {
  return sinf(x);
  }
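In math form, the piecewise activation above is

\[
\mathrm{xielu}(x) =
\begin{cases}
\alpha_p\,x^2 + \beta\,x, & x > 0,\\
\alpha_n\bigl(e^{\min(x,\,\varepsilon)} - 1 - x\bigr) + \beta\,x, & x \le 0,
\end{cases}
\]

with min(x, ε) clamping the argument of expm1f exactly as in the code.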
@@ -64,6 +73,22 @@ static inline float op_log(float x) {
  return logf(x);
  }

+ static inline float op_floor(float x) {
+ return floorf(x);
+ }
+
+ static inline float op_ceil(float x) {
+ return ceilf(x);
+ }
+
+ static inline float op_round(float x) {
+ return roundf(x);
+ }
+
+ static inline float op_trunc(float x) {
+ return truncf(x);
+ }
+
  template <float (*op)(float), typename src0_t, typename dst_t>
  static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
  constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
@@ -121,6 +146,86 @@ static void unary_op(const wsp_ggml_compute_params * params, wsp_ggml_tensor * d
  }
  }

+ template <float (*op)(float, wsp_ggml_tensor *)>
+ static void unary_op_params(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+ const wsp_ggml_tensor * src0 = dst->src[0];
+
+ /* */ if (src0->type == WSP_GGML_TYPE_F32 && dst->type == WSP_GGML_TYPE_F32) { // all f32
+ apply_unary_op<op, float, float>(params, dst);
+ } else if (src0->type == WSP_GGML_TYPE_F16 && dst->type == WSP_GGML_TYPE_F16) { // all f16
+ apply_unary_op<op, wsp_ggml_fp16_t, wsp_ggml_fp16_t>(params, dst);
+ } else if (src0->type == WSP_GGML_TYPE_BF16 && dst->type == WSP_GGML_TYPE_BF16) { // all bf16
+ apply_unary_op<op, wsp_ggml_bf16_t, wsp_ggml_bf16_t>(params, dst);
+ } else if (src0->type == WSP_GGML_TYPE_BF16 && dst->type == WSP_GGML_TYPE_F32) {
+ apply_unary_op<op, wsp_ggml_bf16_t, float>(params, dst);
+ } else if (src0->type == WSP_GGML_TYPE_F16 && dst->type == WSP_GGML_TYPE_F32) {
+ apply_unary_op<op, wsp_ggml_fp16_t, float>(params, dst);
+ } else {
+ fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
+ wsp_ggml_type_name(dst->type), wsp_ggml_type_name(src0->type));
+ WSP_GGML_ABORT("fatal error");
+ }
+ }
+
+ // Extend vec_unary_op to support functors
+ template <typename Op, typename src0_t, typename dst_t>
+ static inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) {
+ constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
+ constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
+
+ for (int i = 0; i < n; i++) {
+ y[i] = f32_to_dst(op(src0_to_f32(x[i])));
+ }
+ }
+
+ // Extend apply_unary_op to support functors
+ template <typename Op, typename src0_t, typename dst_t>
+ static void apply_unary_op_functor(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst, Op op) {
+ const wsp_ggml_tensor * src0 = dst->src[0];
+
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0) && wsp_ggml_is_contiguous_1(dst) && wsp_ggml_are_same_shape(src0, dst));
+
+ WSP_GGML_TENSOR_UNARY_OP_LOCALS
+
+ WSP_GGML_ASSERT( nb0 == sizeof(dst_t));
+ WSP_GGML_ASSERT(nb00 == sizeof(src0_t));
+
+ const auto [ir0, ir1] = get_thread_range(params, src0);
+
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+ vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op);
+ }
+ }
+
+ // Generic dispatcher for functors
+ template <typename Op>
+ static void unary_op_functor(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst, Op op) {
+ const wsp_ggml_tensor * src0 = dst->src[0];
+
+ /* */ if (src0->type == WSP_GGML_TYPE_F32 && dst->type == WSP_GGML_TYPE_F32) { // all f32
+ apply_unary_op_functor<Op, float, float>(params, dst, op);
+ } else if (src0->type == WSP_GGML_TYPE_F16 && dst->type == WSP_GGML_TYPE_F16) { // all f16
+ apply_unary_op_functor<Op, wsp_ggml_fp16_t, wsp_ggml_fp16_t>(params, dst, op);
+ } else if (src0->type == WSP_GGML_TYPE_BF16 && dst->type == WSP_GGML_TYPE_BF16) { // all bf16
+ apply_unary_op_functor<Op, wsp_ggml_bf16_t, wsp_ggml_bf16_t>(params, dst, op);
+ } else if (src0->type == WSP_GGML_TYPE_BF16 && dst->type == WSP_GGML_TYPE_F32) {
+ apply_unary_op_functor<Op, wsp_ggml_bf16_t, float>(params, dst, op);
+ } else if (src0->type == WSP_GGML_TYPE_F16 && dst->type == WSP_GGML_TYPE_F32) {
+ apply_unary_op_functor<Op, wsp_ggml_fp16_t, float>(params, dst, op);
+ } else {
+ fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
+ wsp_ggml_type_name(dst->type), wsp_ggml_type_name(src0->type));
+ WSP_GGML_ABORT("fatal error");
+ }
+ }
+
  void wsp_ggml_compute_forward_abs(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
  unary_op<op_abs>(params, dst);
  }
@@ -184,3 +289,33 @@ void wsp_ggml_compute_forward_cos(const wsp_ggml_compute_params * params, wsp_gg
  void wsp_ggml_compute_forward_log(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
  unary_op<op_log>(params, dst);
  }
+
+ void wsp_ggml_compute_forward_floor(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+ unary_op<op_floor>(params, dst);
+ }
+
+ void wsp_ggml_compute_forward_ceil(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+ unary_op<op_ceil>(params, dst);
+ }
+
+ void wsp_ggml_compute_forward_round(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+ unary_op<op_round>(params, dst);
+ }
+
+ void wsp_ggml_compute_forward_trunc(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+ unary_op<op_trunc>(params, dst);
+ }
+
+ void wsp_ggml_compute_forward_xielu(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
+ const float alpha_n = wsp_ggml_get_op_params_f32(dst, 1);
+ const float alpha_p = wsp_ggml_get_op_params_f32(dst, 2);
+ const float beta = wsp_ggml_get_op_params_f32(dst, 3);
+ const float eps = wsp_ggml_get_op_params_f32(dst, 4);
+
+ const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) {
+ return op_xielu(f, alpha_n, alpha_p, beta, eps);
+ };
+
+ unary_op_functor(params, dst, xielu_op_params);
+ }
+
package/cpp/ggml-cpu/unary-ops.h CHANGED
@@ -22,6 +22,11 @@ void wsp_ggml_compute_forward_sqrt(const struct wsp_ggml_compute_params * params
  void wsp_ggml_compute_forward_sin(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_cos(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
  void wsp_ggml_compute_forward_log(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_floor(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_ceil(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_round(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_trunc(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);
+ void wsp_ggml_compute_forward_xielu(const struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * dst);

  #ifdef __cplusplus
  }

package/cpp/ggml-cpu/vec.cpp CHANGED
@@ -404,6 +404,72 @@ void wsp_ggml_vec_swiglu_f32(const int n, float * y, const float * x, const floa
  }
  }

+ wsp_ggml_float wsp_ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) {
+ int i = 0;
+ wsp_ggml_float sum = 0;
+ // TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE
+ // ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344
+ #if defined(__AVX512F__) && defined(__AVX512DQ__)
+ for (; i + 15 < n; i += 16) {
+ __m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i),
+ _mm512_set1_ps(mean));
+ _mm512_storeu_ps(y + i, val);
+ sum += (wsp_ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val));
+ }
+ #elif defined(__AVX2__) && defined(__FMA__)
+ for (; i + 7 < n; i += 8) {
+ __m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i),
+ _mm256_set1_ps(mean));
+ _mm256_storeu_ps(y + i, val);
+ val = _mm256_mul_ps(val,val);
+ __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
+ _mm256_castps256_ps128(val));
+ val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
+ val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
+ sum += (wsp_ggml_float)_mm_cvtss_f32(val2);
+ }
+ #elif defined(__SSE2__)
+ for (; i + 3 < n; i += 4) {
+ __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i),
+ _mm_set1_ps(mean));
+ _mm_storeu_ps(y + i, val);
+ val = _mm_mul_ps(val, val);
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+ val = _mm_add_ps(val, _mm_movehl_ps(val, val));
+ val = _mm_add_ss(val, _mm_movehdup_ps(val));
+ #else
+ __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
+ val = _mm_add_ps(val, tmp);
+ tmp = _mm_movehl_ps(tmp, val);
+ val = _mm_add_ss(val, tmp);
+ #endif // __AVX__ || __AVX2__ || __AVX512F__
+ sum += (wsp_ggml_float)_mm_cvtss_f32(val);
+ }
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
+ for (; i + 3 < n; i += 4) {
+ float32x4_t val = vsubq_f32(vld1q_f32(x + i),
+ vdupq_n_f32(mean));
+ vst1q_f32(y + i, val);
+ val = vmulq_f32(val, val);
+ sum += (wsp_ggml_float)vaddvq_f32(val);
+ }
+ #elif defined(__VXE__) || defined(__VXE2__)
+ for (; i + 3 < n; i += 4) {
+ float32x4_t val = vec_sub(vec_xl(0, x + i), vec_splats(mean));
+ vec_xst(val, 0, y + i);
+ val = vec_mul(val, val);
+ sum += (wsp_ggml_float)vec_hsum_f32x4(val);
+ }
+ #endif
+ for (; i < n; ++i) {
+ float val = x[i] - mean;
+ y[i] = val;
+ val *= val;
+ sum += (wsp_ggml_float)val;
+ }
+ return sum/n;
+ }
+
  wsp_ggml_float wsp_ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
  int i = 0;
  wsp_ggml_float sum = 0;

package/cpp/ggml-cpu/vec.h CHANGED
@@ -44,6 +44,7 @@ void wsp_ggml_vec_dot_bf16(int n, float * WSP_GGML_RESTRICT s, size_t bs, wsp_gg
  void wsp_ggml_vec_dot_f16(int n, float * WSP_GGML_RESTRICT s, size_t bs, wsp_ggml_fp16_t * WSP_GGML_RESTRICT x, size_t bx, wsp_ggml_fp16_t * WSP_GGML_RESTRICT y, size_t by, int nrc);

  void wsp_ggml_vec_silu_f32(const int n, float * y, const float * x);
+ wsp_ggml_float wsp_ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean); //it will also center y ( y = y - mean )
  wsp_ggml_float wsp_ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max);
  wsp_ggml_float wsp_ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max);

@@ -143,14 +144,14 @@ inline static void wsp_ggml_vec_dot_f16_unroll(const int n, const int xs, float
  for (int i = 0; i < np; i += wsp_ggml_f16_step) {
  ay1 = WSP_GGML_F16x_VEC_LOAD(y + i + 0 * wsp_ggml_f16_epr, 0); // 8 elements

- ax1 = WSP_GGML_F16x_VEC_LOAD(x[0] + i + 0*wsp_ggml_f16_epr, 0); // 8 elemnst
+ ax1 = WSP_GGML_F16x_VEC_LOAD(x[0] + i + 0*wsp_ggml_f16_epr, 0); // 8 elements
  sum_00 = WSP_GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1
  ax1 = WSP_GGML_F16x_VEC_LOAD(x[1] + i + 0*wsp_ggml_f16_epr, 0); // 8 elements
  sum_10 = WSP_GGML_F16x_VEC_FMA(sum_10, ax1, ay1);

  ay2 = WSP_GGML_F16x_VEC_LOAD(y + i + 1 * wsp_ggml_f16_epr, 1); // next 8 elements

- ax2 = WSP_GGML_F16x_VEC_LOAD(x[0] + i + 1*wsp_ggml_f16_epr, 1); // next 8 ekements
+ ax2 = WSP_GGML_F16x_VEC_LOAD(x[0] + i + 1*wsp_ggml_f16_epr, 1); // next 8 elements
  sum_01 = WSP_GGML_F16x_VEC_FMA(sum_01, ax2, ay2);
  ax2 = WSP_GGML_F16x_VEC_LOAD(x[1] + i + 1*wsp_ggml_f16_epr, 1);
  sum_11 = WSP_GGML_F16x_VEC_FMA(sum_11, ax2, ay2);
@@ -159,7 +160,7 @@ inline static void wsp_ggml_vec_dot_f16_unroll(const int n, const int xs, float

  ax3 = WSP_GGML_F16x_VEC_LOAD(x[0] + i + 2*wsp_ggml_f16_epr, 2);
  sum_02 = WSP_GGML_F16x_VEC_FMA(sum_02, ax3, ay3);
- ax1 = WSP_GGML_F16x_VEC_LOAD(x[1] + i + 2*wsp_ggml_f16_epr, 2);
+ ax3 = WSP_GGML_F16x_VEC_LOAD(x[1] + i + 2*wsp_ggml_f16_epr, 2);
  sum_12 = WSP_GGML_F16x_VEC_FMA(sum_12, ax3, ay3);

  ay4 = WSP_GGML_F16x_VEC_LOAD(y + i + 3 * wsp_ggml_f16_epr, 3);
@@ -654,11 +655,11 @@ inline static void wsp_ggml_vec_scale_f32(const int n, float * y, const float
  }
  // leftovers
  // maximum number of leftover elements will be less that wsp_ggml_f32_epr. Apply predicated svmad on available elements only
- if (np < n) {
- svbool_t pg = svwhilelt_b32(np, n);
- ay1 = svld1_f32(pg, y + np);
+ for (int i = np; i < n; i += wsp_ggml_f32_epr) {
+ svbool_t pg = svwhilelt_b32(i, n);
+ ay1 = svld1_f32(pg, y + i);
  ay1 = svmul_f32_m(pg, ay1, vx);
- svst1_f32(pg, y + np, ay1);
+ svst1_f32(pg, y + i, ay1);
  }
  #elif defined(__riscv_v_intrinsic)
  for (int i = 0, avl; i < n; i += avl) {
@@ -819,7 +820,8 @@ inline static void wsp_ggml_vec_tanh_f16 (const int n, wsp_ggml_fp16_t * y, cons
  inline static void wsp_ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
  inline static void wsp_ggml_vec_elu_f16 (const int n, wsp_ggml_fp16_t * y, const wsp_ggml_fp16_t * x) {
  for (int i = 0; i < n; ++i) {
- y[i] = WSP_GGML_CPU_FP32_TO_FP16(expm1f(WSP_GGML_CPU_FP16_TO_FP32(x[i])));
+ const float v = WSP_GGML_CPU_FP16_TO_FP32(x[i]);
+ y[i] = WSP_GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : expm1f(v));
  }
  }
  inline static void wsp_ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
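For reference, the function both ELU variants now compute is

\[
\mathrm{elu}(x) =
\begin{cases}
x, & x > 0,\\
e^{x} - 1, & x \le 0;
\end{cases}
\]

the f16 path previously applied expm1f to positive inputs as well.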