whisper.rn 0.4.2 → 0.5.0-rc.0

This diff compares the published contents of two package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (98)
  1. package/README.md +1 -3
  2. package/android/build.gradle +70 -11
  3. package/android/src/main/CMakeLists.txt +28 -1
  4. package/android/src/main/java/com/rnwhisper/JSCallInvokerResolver.java +40 -0
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +80 -27
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +21 -9
  7. package/android/src/main/java/com/rnwhisper/WhisperVadContext.java +1 -1
  8. package/android/src/main/jni.cpp +79 -2
  9. package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
  12. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
  15. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +5 -0
  16. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +5 -0
  17. package/cpp/ggml-backend.cpp +36 -18
  18. package/cpp/ggml-backend.h +1 -1
  19. package/cpp/ggml-cpu/amx/mmq.cpp +10 -9
  20. package/cpp/ggml-cpu/arch/arm/quants.c +109 -108
  21. package/cpp/ggml-cpu/arch/arm/repack.cpp +13 -12
  22. package/cpp/ggml-cpu/arch/x86/quants.c +83 -82
  23. package/cpp/ggml-cpu/arch/x86/repack.cpp +20 -19
  24. package/cpp/ggml-cpu/common.h +3 -2
  25. package/cpp/ggml-cpu/ggml-cpu-impl.h +9 -3
  26. package/cpp/ggml-cpu/ggml-cpu.c +95 -17
  27. package/cpp/ggml-cpu/ggml-cpu.cpp +4 -0
  28. package/cpp/ggml-cpu/ops.cpp +775 -74
  29. package/cpp/ggml-cpu/ops.h +7 -0
  30. package/cpp/ggml-cpu/quants.c +25 -24
  31. package/cpp/ggml-cpu/repack.cpp +15 -14
  32. package/cpp/ggml-cpu/simd-mappings.h +211 -33
  33. package/cpp/ggml-cpu/vec.cpp +26 -2
  34. package/cpp/ggml-cpu/vec.h +99 -45
  35. package/cpp/ggml-cpu.h +2 -0
  36. package/cpp/ggml-impl.h +125 -183
  37. package/cpp/ggml-metal-impl.h +27 -0
  38. package/cpp/ggml-metal.m +298 -41
  39. package/cpp/ggml-quants.c +6 -6
  40. package/cpp/ggml-whisper-sim.metallib +0 -0
  41. package/cpp/ggml-whisper.metallib +0 -0
  42. package/cpp/ggml.c +269 -40
  43. package/cpp/ggml.h +122 -2
  44. package/cpp/gguf.cpp +5 -1
  45. package/cpp/jsi/RNWhisperJSI.cpp +681 -0
  46. package/cpp/jsi/RNWhisperJSI.h +44 -0
  47. package/cpp/jsi/ThreadPool.h +100 -0
  48. package/cpp/whisper.cpp +4 -0
  49. package/cpp/whisper.h +2 -0
  50. package/ios/RNWhisper.h +3 -0
  51. package/ios/RNWhisper.mm +66 -31
  52. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  53. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  54. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  55. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  56. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +122 -2
  57. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +2 -0
  58. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  59. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  60. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  61. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  62. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  63. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  64. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +122 -2
  65. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +2 -0
  66. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  67. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  68. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  69. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  70. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  71. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  72. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +122 -2
  73. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +2 -0
  74. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  75. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  76. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  77. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  78. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  79. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  80. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +122 -2
  81. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +2 -0
  82. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  83. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  84. package/jest/mock.js +1 -0
  85. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  86. package/lib/commonjs/index.js +83 -2
  87. package/lib/commonjs/index.js.map +1 -1
  88. package/lib/module/NativeRNWhisper.js.map +1 -1
  89. package/lib/module/index.js +83 -2
  90. package/lib/module/index.js.map +1 -1
  91. package/lib/typescript/NativeRNWhisper.d.ts +4 -0
  92. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  93. package/lib/typescript/index.d.ts +18 -6
  94. package/lib/typescript/index.d.ts.map +1 -1
  95. package/package.json +2 -3
  96. package/src/NativeRNWhisper.ts +2 -0
  97. package/src/index.ts +162 -33
  98. package/whisper-rn.podspec +6 -3
@@ -301,6 +301,7 @@ struct wsp_ggml_cgraph {
     struct wsp_ggml_tensor ** grads;      // the outputs of these tensors are the gradients of the nodes
     struct wsp_ggml_tensor ** grad_accs;  // accumulators for node gradients
     struct wsp_ggml_tensor ** leafs;      // tensors with constant data
+    int32_t                 * use_counts; // number of uses of each tensor, indexed by hash table slot
 
     struct wsp_ggml_hash_set visited_hash_set;
 
@@ -317,203 +318,81 @@ struct wsp_ggml_cgraph wsp_ggml_graph_view(struct wsp_ggml_cgraph * cgraph, int
 WSP_GGML_API void * wsp_ggml_aligned_malloc(size_t size);
 WSP_GGML_API void wsp_ggml_aligned_free(void * ptr, size_t size);
 
-// FP16 to FP32 conversion
+// FP16 <-> FP32
+// ref: https://github.com/Maratyszcza/FP16
 
-// 16-bit float
-// on Arm, we use __fp16
-// on x86, we use uint16_t
-//
-// for old CUDA compilers (<= 11), we use uint16_t: ref https://github.com/ggml-org/llama.cpp/pull/10616
-// for MUSA compilers , we use uint16_t: ref https://github.com/ggml-org/llama.cpp/pull/11843
-//
-#if defined(__ARM_NEON) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) && !defined(__MUSACC__)
-    #define WSP_GGML_COMPUTE_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x)
-    #define WSP_GGML_COMPUTE_FP32_TO_FP16(x) wsp_ggml_compute_fp32_to_fp16(x)
-
-    #define WSP_GGML_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x)
-
-    static inline float wsp_ggml_compute_fp16_to_fp32(wsp_ggml_fp16_t h) {
-        __fp16 tmp;
-        memcpy(&tmp, &h, sizeof(wsp_ggml_fp16_t));
-        return (float)tmp;
-    }
-
-    static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
-        wsp_ggml_fp16_t res;
-        __fp16 tmp = f;
-        memcpy(&res, &tmp, sizeof(wsp_ggml_fp16_t));
-        return res;
-    }
-
-#elif defined(__F16C__)
-
-    #ifdef _MSC_VER
-        #define WSP_GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
-        #define WSP_GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
-    #else
-        #define WSP_GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
-        #define WSP_GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
-    #endif
-
-#elif defined(__POWER9_VECTOR__)
-
-    #define WSP_GGML_COMPUTE_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x)
-    #define WSP_GGML_COMPUTE_FP32_TO_FP16(x) wsp_ggml_compute_fp32_to_fp16(x)
-    /* the inline asm below is about 12% faster than the lookup method */
-    #define WSP_GGML_FP16_TO_FP32(x) WSP_GGML_COMPUTE_FP16_TO_FP32(x)
-    #define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x)
-
-    static inline float wsp_ggml_compute_fp16_to_fp32(wsp_ggml_fp16_t h) {
-        float f;
-        double d;
-        __asm__(
-            "mtfprd %0,%2\n"
-            "xscvhpdp %0,%0\n"
-            "frsp %1,%0\n" :
-            /* temp */ "=d"(d),
-            /* out */  "=f"(f):
-            /* in */   "r"(h));
-        return f;
-    }
-
-    static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
-        double d;
-        wsp_ggml_fp16_t r;
-        __asm__( /* xscvdphp can work on double or single precision */
-            "xscvdphp %0,%2\n"
-            "mffprd %1,%0\n" :
-            /* temp */ "=d"(d),
-            /* out */  "=r"(r):
-            /* in */   "f"(f));
-        return r;
-    }
-
-#elif defined(__riscv) && defined(__riscv_zfhmin)
-
-    static inline float wsp_ggml_compute_fp16_to_fp32(wsp_ggml_fp16_t h) {
-        float f;
-        __asm__(
-            "fmv.h.x %[f], %[h]\n\t"
-            "fcvt.s.h %[f], %[f]"
-            : [f] "=&f" (f)
-            : [h] "r" (h)
-        );
-        return f;
-    }
+static inline float fp32_from_bits(uint32_t w) {
+    union {
+        uint32_t as_bits;
+        float as_value;
+    } fp32;
+    fp32.as_bits = w;
+    return fp32.as_value;
+}
 
-    static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
-        wsp_ggml_fp16_t res;
-        __asm__(
-            "fcvt.h.s %[f], %[f]\n\t"
-            "fmv.x.h %[h], %[f]"
-            : [h] "=&r" (res)
-            : [f] "f" (f)
-        );
-        return res;
-    }
+static inline uint32_t fp32_to_bits(float f) {
+    union {
+        float as_value;
+        uint32_t as_bits;
+    } fp32;
+    fp32.as_value = f;
+    return fp32.as_bits;
+}
 
-    #define WSP_GGML_COMPUTE_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x)
-    #define WSP_GGML_COMPUTE_FP32_TO_FP16(x) wsp_ggml_compute_fp32_to_fp16(x)
-    #define WSP_GGML_FP16_TO_FP32(x) WSP_GGML_COMPUTE_FP16_TO_FP32(x)
-    #define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x)
+static inline float wsp_ggml_compute_fp16_to_fp32(wsp_ggml_fp16_t h) {
+    const uint32_t w = (uint32_t) h << 16;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    const uint32_t two_w = w + w;
 
+    const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+#if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)
+    const float exp_scale = 0x1.0p-112f;
 #else
+    const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
+#endif
+    const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
 
-// FP16 <-> FP32
-// ref: https://github.com/Maratyszcza/FP16
-
-static inline float fp32_from_bits(uint32_t w) {
-    union {
-        uint32_t as_bits;
-        float as_value;
-    } fp32;
-    fp32.as_bits = w;
-    return fp32.as_value;
-}
-
-static inline uint32_t fp32_to_bits(float f) {
-    union {
-        float as_value;
-        uint32_t as_bits;
-    } fp32;
-    fp32.as_value = f;
-    return fp32.as_bits;
-}
-
-static inline float wsp_ggml_compute_fp16_to_fp32(wsp_ggml_fp16_t h) {
-    const uint32_t w = (uint32_t) h << 16;
-    const uint32_t sign = w & UINT32_C(0x80000000);
-    const uint32_t two_w = w + w;
-
-    const uint32_t exp_offset = UINT32_C(0xE0) << 23;
-#if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)
-    const float exp_scale = 0x1.0p-112f;
-#else
-    const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
-#endif
-    const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
-
-    const uint32_t magic_mask = UINT32_C(126) << 23;
-    const float magic_bias = 0.5f;
-    const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
+    const uint32_t magic_mask = UINT32_C(126) << 23;
+    const float magic_bias = 0.5f;
+    const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
 
-    const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
-    const uint32_t result = sign |
-        (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
-    return fp32_from_bits(result);
-}
-
-static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
-#if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)
-    const float scale_to_inf = 0x1.0p+112f;
-    const float scale_to_zero = 0x1.0p-110f;
-#else
-    const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
-    const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
-#endif
-    float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
-
-    const uint32_t w = fp32_to_bits(f);
-    const uint32_t shl1_w = w + w;
-    const uint32_t sign = w & UINT32_C(0x80000000);
-    uint32_t bias = shl1_w & UINT32_C(0xFF000000);
-    if (bias < UINT32_C(0x71000000)) {
-        bias = UINT32_C(0x71000000);
-    }
+    const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
+    const uint32_t result = sign |
+        (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
+    return fp32_from_bits(result);
+}
 
-    base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
-    const uint32_t bits = fp32_to_bits(base);
-    const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
-    const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
-    const uint32_t nonsign = exp_bits + mantissa_bits;
-    return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
+static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
+#if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)
+    const float scale_to_inf = 0x1.0p+112f;
+    const float scale_to_zero = 0x1.0p-110f;
+#else
+    const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
+    const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
+#endif
+    float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
+
+    const uint32_t w = fp32_to_bits(f);
+    const uint32_t shl1_w = w + w;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    uint32_t bias = shl1_w & UINT32_C(0xFF000000);
+    if (bias < UINT32_C(0x71000000)) {
+        bias = UINT32_C(0x71000000);
     }
 
-    #define WSP_GGML_COMPUTE_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x)
-    #define WSP_GGML_COMPUTE_FP32_TO_FP16(x) wsp_ggml_compute_fp32_to_fp16(x)
-
-#endif // defined(__ARM_NEON) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) && !defined(__MUSACC__)
-
-// precomputed f32 table for f16 (256 KB)
-// defined in ggml.c, initialized in wsp_ggml_init()
-WSP_GGML_API float wsp_ggml_table_f32_f16[1 << 16];
-
-// On ARM NEON, it's quicker to directly convert x -> x instead of calling into wsp_ggml_lookup_fp16_to_fp32,
-// so we define WSP_GGML_FP16_TO_FP32 and WSP_GGML_FP32_TO_FP16 elsewhere for NEON.
-// This is also true for POWER9.
-#if !defined(WSP_GGML_FP16_TO_FP32)
-inline static float wsp_ggml_lookup_fp16_to_fp32(wsp_ggml_fp16_t f) {
-    uint16_t s;
-    memcpy(&s, &f, sizeof(uint16_t));
-    return wsp_ggml_table_f32_f16[s];
+    base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
+    const uint32_t bits = fp32_to_bits(base);
+    const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
+    const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
+    const uint32_t nonsign = exp_bits + mantissa_bits;
+    return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
 }
 
-#define WSP_GGML_FP16_TO_FP32(x) wsp_ggml_lookup_fp16_to_fp32(x)
-#endif
+#define WSP_GGML_COMPUTE_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x)
+#define WSP_GGML_COMPUTE_FP32_TO_FP16(x) wsp_ggml_compute_fp32_to_fp16(x)
 
-#if !defined(WSP_GGML_FP32_TO_FP16)
+#define WSP_GGML_FP16_TO_FP32(x) WSP_GGML_COMPUTE_FP16_TO_FP32(x)
 #define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x)
-#endif
 
 /**
  * Converts brain16 to float32.
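Note on the change above: the rewritten ggml-impl.h drops the per-architecture FP16 paths (NEON `__fp16`, x86 F16C, POWER9 and RISC-V inline asm) and the 256 KB `wsp_ggml_table_f32_f16` lookup table in favor of the portable bit-twiddling conversion from the referenced Maratyszcza/FP16 project. The sketch below round-trips the decode path standalone; the conversion body is copied from the diff, while the `memcpy`-based bit casts and the `main` harness are illustrative assumptions (the header uses unions, and `wsp_ggml_fp16_t` is assumed to be a plain `uint16_t` holding an IEEE half bit pattern).

```c
#include <stdint.h>
#include <stdio.h>
#include <string.h>

typedef uint16_t fp16_t; // stand-in for wsp_ggml_fp16_t (assumption)

static inline float fp32_from_bits(uint32_t w) {
    float f;
    memcpy(&f, &w, sizeof(f)); // equivalent to the union in the header
    return f;
}

static inline uint32_t fp32_to_bits(float f) {
    uint32_t w;
    memcpy(&w, &f, sizeof(w));
    return w;
}

static inline float fp16_to_fp32(fp16_t h) {
    const uint32_t w     = (uint32_t) h << 16;
    const uint32_t sign  = w & UINT32_C(0x80000000);
    const uint32_t two_w = w + w;

    // rescale the 5-bit half exponent into the 8-bit float exponent range
    const uint32_t exp_offset = UINT32_C(0xE0) << 23;
    const float exp_scale = 0x1.0p-112f;
    const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;

    // subnormal halves are reconstructed via a magic-number addition
    const uint32_t magic_mask = UINT32_C(126) << 23;
    const float magic_bias = 0.5f;
    const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;

    const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
    const uint32_t result = sign |
        (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
    return fp32_from_bits(result);
}

int main(void) {
    // 0x3C00 is 1.0 in IEEE half precision, 0xC000 is -2.0
    printf("0x3C00 -> %f\n", fp16_to_fp32((fp16_t) 0x3C00)); // prints 1.000000
    printf("0xC000 -> %f\n", fp16_to_fp32((fp16_t) 0xC000)); // prints -2.000000
    return 0;
}
```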
@@ -589,13 +468,76 @@ static inline wsp_ggml_bf16_t wsp_ggml_compute_fp32_to_bf16(float s) {
 #define WSP_GGML_FP32_TO_BF16(x) wsp_ggml_compute_fp32_to_bf16(x)
 #define WSP_GGML_BF16_TO_FP32(x) wsp_ggml_compute_bf16_to_fp32(x)
 
+// return true if the node's results are only used by N other nodes
+// and can be fused into their calculations.
+static inline bool wsp_ggml_node_has_n_uses(const struct wsp_ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
+    const struct wsp_ggml_tensor * node = cgraph->nodes[node_idx];
+
+    // check the use count against how many we're replacing
+    size_t hash_pos = wsp_ggml_hash_find(&cgraph->visited_hash_set, node);
+    if (!wsp_ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
+        return false;
+    }
+
+    // if node is a view, some other node might be using the intermediate result
+    // via the view source.
+    if (node->view_src) {
+        return false;
+    }
+
+    // If the user requested output for the node, can't fuse
+    if (node->flags & WSP_GGML_TENSOR_FLAG_OUTPUT) {
+        return false;
+    }
+
+    return true;
+}
+
+// Returns true if nodes [i, i+ops.size()) are the sequence of wsp_ggml_ops in ops[]
+// and are fusable. Nodes are considered fusable according to this function if:
+// - all nodes except the last have only one use and are not views/outputs (see wsp_ggml_node_has_N_uses).
+// - all nodes except the last are a src of the following node.
+// - all nodes are the same shape.
+// TODO: Consider allowing WSP_GGML_OP_NONE nodes in between
+static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_idx, const enum wsp_ggml_op * ops, int num_ops) {
+    if (node_idx + num_ops > cgraph->n_nodes) {
+        return false;
+    }
+
+    for (int i = 0; i < num_ops; ++i) {
+        struct wsp_ggml_tensor * node = cgraph->nodes[node_idx + i];
+        if (node->op != ops[i]) {
+            return false;
+        }
+        if (i < num_ops - 1 && !wsp_ggml_node_has_n_uses(cgraph, node_idx + i, 1)) {
+            return false;
+        }
+        if (i > 0) {
+            struct wsp_ggml_tensor * prev = cgraph->nodes[node_idx + i - 1];
+            if (node->src[0] != prev && node->src[1] != prev) {
+                return false;
+            }
+            if (!wsp_ggml_are_same_shape(node, prev)) {
+                return false;
+            }
+        }
+    }
+    return true;
+}
+
 #ifdef __cplusplus
 }
 #endif
 
 #ifdef __cplusplus
+#include <initializer_list>
 #include <vector>
 
+// nicer C++ syntax for wsp_ggml_can_fuse
+inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum wsp_ggml_op> ops) {
+    return wsp_ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
+}
+
 // expose GGUF internals for test code
 WSP_GGML_API size_t wsp_gguf_type_size(enum wsp_gguf_type type);
 WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_from_file_impl(FILE * file, struct wsp_gguf_init_params params);
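The `use_counts` bookkeeping added to `wsp_ggml_cgraph` above exists to make these checks cheap: `wsp_ggml_node_has_n_uses` rejects nodes whose intermediate results escape (extra uses, views, user-requested outputs), and `wsp_ggml_can_fuse` then verifies an exact op sequence. A hypothetical backend-side usage sketch follows; the fused-kernel launch is invented for illustration, though `WSP_GGML_OP_MUL` and `WSP_GGML_OP_ADD` are existing ggml ops.

```c
// Sketch: before emitting node i, ask whether it starts a MUL -> ADD chain
// whose intermediate result is used exactly once, so both ops can be
// lowered to one fused kernel launch (launch_fused_mul_add is hypothetical).
static const enum wsp_ggml_op mul_add[] = { WSP_GGML_OP_MUL, WSP_GGML_OP_ADD };

if (wsp_ggml_can_fuse(cgraph, i, mul_add, 2)) {
    // launch_fused_mul_add(cgraph->nodes[i], cgraph->nodes[i + 1]);
    i += 1; // skip the node consumed by the fused kernel
}
```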
@@ -422,6 +422,17 @@ typedef struct {
     int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
 } wsp_ggml_metal_kargs_im2col;
 
+typedef struct {
+    int32_t  ne00;
+    uint64_t nb01;
+    int32_t  ne10;
+    uint64_t nb11;
+    int32_t  ne0;
+    uint64_t nb1;
+    int32_t  i00;
+    int32_t  i10;
+} wsp_ggml_metal_kargs_glu;
+
 typedef struct {
     int64_t ne00;
     int64_t ne01;
@@ -521,6 +532,22 @@ typedef struct {
     uint64_t nb2;
 } wsp_ggml_metal_kargs_get_rows;
 
+typedef struct {
+    int32_t  nk0;
+    int32_t  ne01;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne11;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} wsp_ggml_metal_kargs_set_rows;
+
 typedef struct {
     int64_t ne00;
     int64_t ne01;
@@ -470,6 +470,7 @@ extern "C" {
         WSP_GGML_OP_TRANSPOSE,
         WSP_GGML_OP_GET_ROWS,
         WSP_GGML_OP_GET_ROWS_BACK,
+        WSP_GGML_OP_SET_ROWS,
         WSP_GGML_OP_DIAG,
         WSP_GGML_OP_DIAG_MASK_INF,
         WSP_GGML_OP_DIAG_MASK_ZERO,
@@ -481,6 +482,7 @@ extern "C" {
        WSP_GGML_OP_CONV_TRANSPOSE_1D,
        WSP_GGML_OP_IM2COL,
        WSP_GGML_OP_IM2COL_BACK,
+       WSP_GGML_OP_CONV_2D,
        WSP_GGML_OP_CONV_2D_DW,
        WSP_GGML_OP_CONV_TRANSPOSE_2D,
        WSP_GGML_OP_POOL_1D,
@@ -519,6 +521,8 @@ extern "C" {
        WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK,
        WSP_GGML_OP_OPT_STEP_ADAMW,
 
+       WSP_GGML_OP_GLU,
+
        WSP_GGML_OP_COUNT,
    };
 
@@ -542,6 +546,14 @@ extern "C" {
        WSP_GGML_UNARY_OP_COUNT,
    };
 
+   enum wsp_ggml_glu_op {
+       WSP_GGML_GLU_OP_REGLU,
+       WSP_GGML_GLU_OP_GEGLU,
+       WSP_GGML_GLU_OP_SWIGLU,
+
+       WSP_GGML_GLU_OP_COUNT,
+   };
+
    enum wsp_ggml_object_type {
        WSP_GGML_OBJECT_TYPE_TENSOR,
        WSP_GGML_OBJECT_TYPE_GRAPH,
@@ -657,6 +669,7 @@ extern "C" {
    WSP_GGML_API const char * wsp_ggml_op_symbol(enum wsp_ggml_op op);
 
    WSP_GGML_API const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op);
+   WSP_GGML_API const char * wsp_ggml_glu_op_name(enum wsp_ggml_glu_op op);
    WSP_GGML_API const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t); // unary or op name
 
    WSP_GGML_API size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor);
@@ -687,6 +700,9 @@ extern "C" {
    // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
    WSP_GGML_API bool wsp_ggml_is_contiguous_channels(const struct wsp_ggml_tensor * tensor);
 
+   // true if the elements in dimension 0 are contiguous, or there is just 1 block of elements
+   WSP_GGML_API bool wsp_ggml_is_contiguous_rows(const struct wsp_ggml_tensor * tensor);
+
    WSP_GGML_API bool wsp_ggml_are_same_shape (const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
    WSP_GGML_API bool wsp_ggml_are_same_stride(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
 
@@ -758,6 +774,7 @@ extern "C" {
    WSP_GGML_API void wsp_ggml_unravel_index(const struct wsp_ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
 
    WSP_GGML_API enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tensor);
+   WSP_GGML_API enum wsp_ggml_glu_op wsp_ggml_get_glu_op(const struct wsp_ggml_tensor * tensor);
 
    WSP_GGML_API void * wsp_ggml_get_data (const struct wsp_ggml_tensor * tensor);
    WSP_GGML_API float * wsp_ggml_get_data_f32(const struct wsp_ggml_tensor * tensor);
@@ -1086,6 +1103,63 @@ extern "C" {
            struct wsp_ggml_context * ctx,
            struct wsp_ggml_tensor * a);
 
+   // gated linear unit ops
+   // A: n columns, r rows,
+   // result is n / 2 columns, r rows,
+   // expects gate in second half of row, unless swapped is true
+   WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_glu(
+           struct wsp_ggml_context * ctx,
+           struct wsp_ggml_tensor * a,
+           enum wsp_ggml_glu_op op,
+           bool swapped);
+
+   WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_reglu(
+           struct wsp_ggml_context * ctx,
+           struct wsp_ggml_tensor * a);
+
+   WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_reglu_swapped(
+           struct wsp_ggml_context * ctx,
+           struct wsp_ggml_tensor * a);
+
+   WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_geglu(
+           struct wsp_ggml_context * ctx,
+           struct wsp_ggml_tensor * a);
+
+   WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_geglu_swapped(
+           struct wsp_ggml_context * ctx,
+           struct wsp_ggml_tensor * a);
+
+   WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_swiglu(
+           struct wsp_ggml_context * ctx,
+           struct wsp_ggml_tensor * a);
+
+   WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_swiglu_swapped(
+           struct wsp_ggml_context * ctx,
+           struct wsp_ggml_tensor * a);
+
+   // A: n columns, r rows,
+   // B: n columns, r rows,
+   WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_glu_split(
+           struct wsp_ggml_context * ctx,
+           struct wsp_ggml_tensor * a,
+           struct wsp_ggml_tensor * b,
+           enum wsp_ggml_glu_op op);
+
+   WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_reglu_split(
+           struct wsp_ggml_context * ctx,
+           struct wsp_ggml_tensor * a,
+           struct wsp_ggml_tensor * b);
+
+   WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_geglu_split(
+           struct wsp_ggml_context * ctx,
+           struct wsp_ggml_tensor * a,
+           struct wsp_ggml_tensor * b);
+
+   WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_swiglu_split(
+           struct wsp_ggml_context * ctx,
+           struct wsp_ggml_tensor * a,
+           struct wsp_ggml_tensor * b);
+
    // normalize along rows
    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_norm(
            struct wsp_ggml_context * ctx,
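Per the comments above, the fused GLU variants split each row of `a` in half and gate one half with the other, halving the column count; the `_split` variants take the two halves as separate tensors `a` and `b`. A scalar sketch of the row-wise computation follows; the orientation (activation on the first half, gate in the second unless `swapped`) is inferred from the header comment and is an assumption about the kernel:

```c
#include <math.h>

// One row of swiglu: y[i] = silu(x[i]) * g[i], where a row of n columns is
// split into x (first n/2 values) and the gate g (second n/2 values).
static void swiglu_row(float * y, const float * x, const float * g, int n_half) {
    for (int i = 0; i < n_half; ++i) {
        const float silu = x[i] / (1.0f + expf(-x[i])); // silu(v) = v * sigmoid(v)
        y[i] = silu * g[i];
        // reglu would use fmaxf(x[i], 0.0f) instead of silu,
        // geglu a GELU approximation.
    }
}
```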
@@ -1375,6 +1449,23 @@ extern "C" {
            struct wsp_ggml_tensor * b, // row indices
            struct wsp_ggml_tensor * c); // data for wsp_ggml_get_rows, only used for its shape
 
+   // a TD  [n_embd, ne1,    ne2,  ne3]
+   // b TS  [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3
+   // c I64 [n_rows, ne11,   ne12, 1]    | c[i] in [0, ne1)
+   //
+   // undefined behavior if destination rows overlap
+   //
+   // broadcast:
+   //   ne2 % ne11 == 0
+   //   ne3 % ne12 == 0
+   //
+   // return view(a)
+   WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_rows(
+           struct wsp_ggml_context * ctx,
+           struct wsp_ggml_tensor * a, // destination
+           struct wsp_ggml_tensor * b, // source
+           struct wsp_ggml_tensor * c); // row indices
+
    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_diag(
            struct wsp_ggml_context * ctx,
            struct wsp_ggml_tensor * a);
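`wsp_ggml_set_rows` is the scatter counterpart of `wsp_ggml_get_rows`: rows of `b` are written into `a` at the positions named by the I64 index tensor `c`. A scalar sketch of the single-batch float case, without broadcasting (the function and parameter names are illustrative):

```c
#include <stdint.h>
#include <string.h>

// dst: [n_embd, ne1] destination rows; src: [n_embd, n_rows] source rows;
// idx: n_rows indices, each in [0, ne1). Overlapping destination rows are
// undefined behavior, per the header comment.
static void set_rows_f32(float * dst, const float * src,
                         const int64_t * idx, int64_t n_rows, int64_t n_embd) {
    for (int64_t r = 0; r < n_rows; ++r) {
        memcpy(dst + idx[r] * n_embd, src + r * n_embd, n_embd * sizeof(float));
    }
}
```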
@@ -1723,6 +1814,17 @@ extern "C" {
            struct wsp_ggml_tensor * b,
            int stride);
 
+   WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d_direct(
+           struct wsp_ggml_context * ctx,
+           struct wsp_ggml_tensor * a,  // convolution kernel [KW, KH, IC, OC]
+           struct wsp_ggml_tensor * b,  // input data [W, H, C, N]
+           int s0,  // stride dimension 0
+           int s1,  // stride dimension 1
+           int p0,  // padding dimension 0
+           int p1,  // padding dimension 1
+           int d0,  // dilation dimension 0
+           int d1); // dilation dimension 1
+
    enum wsp_ggml_op_pool {
        WSP_GGML_OP_POOL_MAX,
        WSP_GGML_OP_POOL_AVG,
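`wsp_ggml_conv_2d_direct` evaluates the convolution without the im2col expansion, matching the new `WSP_GGML_OP_CONV_2D` op added above. The output extent per spatial dimension should follow the usual convolution arithmetic; a small helper, stated as an assumption since the header does not spell it out:

```c
// Expected output size along one dimension, for input size in, kernel size k,
// stride s, padding p, dilation d (standard convolution arithmetic).
static inline int conv_out_size(int in, int k, int s, int p, int d) {
    return (in + 2 * p - d * (k - 1) - 1) / s + 1;
}
// e.g. in=224, k=3, s=2, p=1, d=1 -> (224 + 2 - 2 - 1) / 2 + 1 = 112
```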
@@ -1765,6 +1867,12 @@ extern "C" {
    enum wsp_ggml_scale_mode {
        WSP_GGML_SCALE_MODE_NEAREST  = 0,
        WSP_GGML_SCALE_MODE_BILINEAR = 1,
+
+       WSP_GGML_SCALE_MODE_COUNT
+   };
+
+   enum wsp_ggml_scale_flag {
+       WSP_GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8)
    };
 
    // interpolate
@@ -1777,14 +1885,26 @@ extern "C" {
 
    // interpolate
    // interpolate scale to specified dimensions
-   WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_upscale_ext(
+   WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_upscale_ext(
            struct wsp_ggml_context * ctx,
            struct wsp_ggml_tensor * a,
            int ne0,
            int ne1,
            int ne2,
            int ne3,
-           enum wsp_ggml_scale_mode mode);
+           enum wsp_ggml_scale_mode mode),
+       "use wsp_ggml_interpolate instead");
+
+   // Up- or downsamples the input to the specified size.
+   // 2D scale modes (eg. bilinear) are applied to the first two dimensions.
+   WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_interpolate(
+           struct wsp_ggml_context * ctx,
+           struct wsp_ggml_tensor * a,
+           int64_t ne0,
+           int64_t ne1,
+           int64_t ne2,
+           int64_t ne3,
+           uint32_t mode); // wsp_ggml_scale_mode [ | wsp_ggml_scale_flag...]
 
    // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pad(
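The packed `uint32_t mode` argument replaces the plain enum so that scale flags can be OR-ed in, which is why `WSP_GGML_SCALE_FLAG_ALIGN_CORNERS` starts at bit 8, above any mode value. A usage sketch, with tensor `t` and context `ctx` assumed to exist:

```c
// Bilinear 2x upscale of the first two dimensions with align-corners
// sampling: low bits carry the mode, bit 8 and up carry flags.
struct wsp_ggml_tensor * up = wsp_ggml_interpolate(ctx, t,
        t->ne[0] * 2, t->ne[1] * 2, t->ne[2], t->ne[3],
        WSP_GGML_SCALE_MODE_BILINEAR | WSP_GGML_SCALE_FLAG_ALIGN_CORNERS);
```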
@@ -199,6 +199,8 @@ extern "C" {
        float samples_overlap; // Overlap in seconds when copying audio samples from speech segment.
    } whisper_vad_params;
 
+   WHISPER_API const char * whisper_version(void);
+
    // Various functions for loading a ggml whisper model.
    // Allocate (almost) all memory needed for the model.
    // Return NULL on failure
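The new accessor exposes the bundled whisper.cpp version string at runtime, e.g.:

```c
// Log the version of the whisper.cpp build linked into the library.
printf("whisper.cpp %s\n", whisper_version());
```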
@@ -339,7 +339,7 @@ extern "C" {
    typedef bool (*wsp_ggml_backend_eval_callback)(int node_index, struct wsp_ggml_tensor * t1, struct wsp_ggml_tensor * t2, void * user_data);
 
    // Compare the output of two backends
-   WSP_GGML_API bool wsp_ggml_backend_compare_graph_backend(wsp_ggml_backend_t backend1, wsp_ggml_backend_t backend2, struct wsp_ggml_cgraph * graph, wsp_ggml_backend_eval_callback callback, void * user_data);
+   WSP_GGML_API bool wsp_ggml_backend_compare_graph_backend(wsp_ggml_backend_t backend1, wsp_ggml_backend_t backend2, struct wsp_ggml_cgraph * graph, wsp_ggml_backend_eval_callback callback, void * user_data, struct wsp_ggml_tensor * test_node);
 
    // Tensor initialization
    WSP_GGML_API enum wsp_ggml_status wsp_ggml_backend_tensor_alloc(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, void * addr);
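The extra `test_node` parameter narrows the comparison to a single tensor. Passing `NULL` presumably keeps the old compare-every-node behavior; that is an assumption based on how upstream ggml handles this parameter:

```c
// Compare all nodes across two backends, as before the signature change
// (NULL test_node assumed to mean "no single-node filter").
bool ok = wsp_ggml_backend_compare_graph_backend(
        backend_cpu, backend_metal, graph, eval_cb, NULL, /*test_node=*/NULL);
```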
@@ -101,6 +101,7 @@ extern "C" {
    WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_riscv_v   (void);
    WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_vsx       (void);
    WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_vxe       (void);
+   WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_nnpa      (void);
    WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_wasm_simd (void);
    WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_llamafile (void);
 
@@ -133,6 +134,7 @@ extern "C" {
 
    WSP_GGML_BACKEND_API wsp_ggml_backend_reg_t wsp_ggml_backend_cpu_reg(void);
 
+   WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
    WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_fp16(const float *, wsp_ggml_fp16_t *, int64_t);
    WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp16_to_fp32(const wsp_ggml_fp16_t *, float *, int64_t);
    WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_bf16(const float *, wsp_ggml_bf16_t *, int64_t);
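These entry points are the CPU backend's bulk conversion routines; the new `wsp_ggml_cpu_fp32_to_fp32` variant covers the identity (plain copy) case so callers can dispatch uniformly. A minimal call sketch:

```c
// Convert a small buffer to FP16 and back through the CPU backend helpers.
float src[4] = { 0.5f, -1.25f, 3.0f, 65504.0f }; // 65504 = max finite FP16
wsp_ggml_fp16_t half[4];
wsp_ggml_cpu_fp32_to_fp16(src, half, 4);

float back[4];
wsp_ggml_cpu_fp16_to_fp32(half, back, 4);
```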