whisper.rn 0.4.1 → 0.4.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. package/android/src/main/java/com/rnwhisper/RNWhisper.java +24 -18
  2. package/android/src/main/java/com/rnwhisper/WhisperVadContext.java +1 -57
  3. package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
  5. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
  9. package/cpp/ggml-backend.cpp +36 -18
  10. package/cpp/ggml-backend.h +1 -1
  11. package/cpp/ggml-cpu/amx/mmq.cpp +10 -9
  12. package/cpp/ggml-cpu/arch/arm/quants.c +109 -108
  13. package/cpp/ggml-cpu/arch/arm/repack.cpp +13 -12
  14. package/cpp/ggml-cpu/arch/x86/quants.c +83 -82
  15. package/cpp/ggml-cpu/arch/x86/repack.cpp +20 -19
  16. package/cpp/ggml-cpu/common.h +3 -2
  17. package/cpp/ggml-cpu/ggml-cpu-impl.h +9 -3
  18. package/cpp/ggml-cpu/ggml-cpu.c +95 -17
  19. package/cpp/ggml-cpu/ggml-cpu.cpp +4 -0
  20. package/cpp/ggml-cpu/ops.cpp +775 -74
  21. package/cpp/ggml-cpu/ops.h +7 -0
  22. package/cpp/ggml-cpu/quants.c +25 -24
  23. package/cpp/ggml-cpu/repack.cpp +15 -14
  24. package/cpp/ggml-cpu/simd-mappings.h +211 -33
  25. package/cpp/ggml-cpu/vec.cpp +26 -2
  26. package/cpp/ggml-cpu/vec.h +99 -45
  27. package/cpp/ggml-cpu.h +2 -0
  28. package/cpp/ggml-impl.h +125 -183
  29. package/cpp/ggml-metal-impl.h +27 -0
  30. package/cpp/ggml-metal.m +298 -41
  31. package/cpp/ggml-quants.c +6 -6
  32. package/cpp/ggml-whisper-sim.metallib +0 -0
  33. package/cpp/ggml-whisper.metallib +0 -0
  34. package/cpp/ggml.c +269 -40
  35. package/cpp/ggml.h +122 -2
  36. package/cpp/gguf.cpp +5 -1
  37. package/cpp/whisper.cpp +4 -0
  38. package/cpp/whisper.h +2 -0
  39. package/ios/RNWhisper.mm +35 -38
  40. package/ios/RNWhisperVadContext.h +1 -1
  41. package/ios/RNWhisperVadContext.mm +2 -6
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  44. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  45. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  46. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +122 -2
  47. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +2 -0
  48. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  49. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  53. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  54. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +122 -2
  55. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +2 -0
  56. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  57. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  58. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  59. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  61. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  62. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +122 -2
  63. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +2 -0
  64. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  65. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  66. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  67. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  68. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  70. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +122 -2
  71. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +2 -0
  72. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  73. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  74. package/package.json +1 -1
package/cpp/ggml-impl.h CHANGED
@@ -301,6 +301,7 @@ struct wsp_ggml_cgraph {
301
301
  struct wsp_ggml_tensor ** grads; // the outputs of these tensors are the gradients of the nodes
302
302
  struct wsp_ggml_tensor ** grad_accs; // accumulators for node gradients
303
303
  struct wsp_ggml_tensor ** leafs; // tensors with constant data
304
+ int32_t * use_counts;// number of uses of each tensor, indexed by hash table slot
304
305
 
305
306
  struct wsp_ggml_hash_set visited_hash_set;
306
307
 
@@ -317,203 +318,81 @@ struct wsp_ggml_cgraph wsp_ggml_graph_view(struct wsp_ggml_cgraph * cgraph, int
317
318
  WSP_GGML_API void * wsp_ggml_aligned_malloc(size_t size);
318
319
  WSP_GGML_API void wsp_ggml_aligned_free(void * ptr, size_t size);
319
320
 
320
- // FP16 to FP32 conversion
321
+ // FP16 <-> FP32
322
+ // ref: https://github.com/Maratyszcza/FP16
321
323
 
322
- // 16-bit float
323
- // on Arm, we use __fp16
324
- // on x86, we use uint16_t
325
- //
326
- // for old CUDA compilers (<= 11), we use uint16_t: ref https://github.com/ggml-org/llama.cpp/pull/10616
327
- // for MUSA compilers , we use uint16_t: ref https://github.com/ggml-org/llama.cpp/pull/11843
328
- //
329
- #if defined(__ARM_NEON) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) && !defined(__MUSACC__)
330
- #define WSP_GGML_COMPUTE_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x)
331
- #define WSP_GGML_COMPUTE_FP32_TO_FP16(x) wsp_ggml_compute_fp32_to_fp16(x)
332
-
333
- #define WSP_GGML_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x)
334
-
335
- static inline float wsp_ggml_compute_fp16_to_fp32(wsp_ggml_fp16_t h) {
336
- __fp16 tmp;
337
- memcpy(&tmp, &h, sizeof(wsp_ggml_fp16_t));
338
- return (float)tmp;
339
- }
340
-
341
- static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
342
- wsp_ggml_fp16_t res;
343
- __fp16 tmp = f;
344
- memcpy(&res, &tmp, sizeof(wsp_ggml_fp16_t));
345
- return res;
346
- }
347
-
348
- #elif defined(__F16C__)
349
-
350
- #ifdef _MSC_VER
351
- #define WSP_GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
352
- #define WSP_GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
353
- #else
354
- #define WSP_GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
355
- #define WSP_GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
356
- #endif
357
-
358
- #elif defined(__POWER9_VECTOR__)
359
-
360
- #define WSP_GGML_COMPUTE_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x)
361
- #define WSP_GGML_COMPUTE_FP32_TO_FP16(x) wsp_ggml_compute_fp32_to_fp16(x)
362
- /* the inline asm below is about 12% faster than the lookup method */
363
- #define WSP_GGML_FP16_TO_FP32(x) WSP_GGML_COMPUTE_FP16_TO_FP32(x)
364
- #define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x)
365
-
366
- static inline float wsp_ggml_compute_fp16_to_fp32(wsp_ggml_fp16_t h) {
367
- float f;
368
- double d;
369
- __asm__(
370
- "mtfprd %0,%2\n"
371
- "xscvhpdp %0,%0\n"
372
- "frsp %1,%0\n" :
373
- /* temp */ "=d"(d),
374
- /* out */ "=f"(f):
375
- /* in */ "r"(h));
376
- return f;
377
- }
378
-
379
- static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
380
- double d;
381
- wsp_ggml_fp16_t r;
382
- __asm__( /* xscvdphp can work on double or single precision */
383
- "xscvdphp %0,%2\n"
384
- "mffprd %1,%0\n" :
385
- /* temp */ "=d"(d),
386
- /* out */ "=r"(r):
387
- /* in */ "f"(f));
388
- return r;
389
- }
390
-
391
- #elif defined(__riscv) && defined(__riscv_zfhmin)
392
-
393
- static inline float wsp_ggml_compute_fp16_to_fp32(wsp_ggml_fp16_t h) {
394
- float f;
395
- __asm__(
396
- "fmv.h.x %[f], %[h]\n\t"
397
- "fcvt.s.h %[f], %[f]"
398
- : [f] "=&f" (f)
399
- : [h] "r" (h)
400
- );
401
- return f;
402
- }
324
+ static inline float fp32_from_bits(uint32_t w) {
325
+ union {
326
+ uint32_t as_bits;
327
+ float as_value;
328
+ } fp32;
329
+ fp32.as_bits = w;
330
+ return fp32.as_value;
331
+ }
403
332
 
404
- static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
405
- wsp_ggml_fp16_t res;
406
- __asm__(
407
- "fcvt.h.s %[f], %[f]\n\t"
408
- "fmv.x.h %[h], %[f]"
409
- : [h] "=&r" (res)
410
- : [f] "f" (f)
411
- );
412
- return res;
413
- }
333
+ static inline uint32_t fp32_to_bits(float f) {
334
+ union {
335
+ float as_value;
336
+ uint32_t as_bits;
337
+ } fp32;
338
+ fp32.as_value = f;
339
+ return fp32.as_bits;
340
+ }
414
341
 
415
- #define WSP_GGML_COMPUTE_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x)
416
- #define WSP_GGML_COMPUTE_FP32_TO_FP16(x) wsp_ggml_compute_fp32_to_fp16(x)
417
- #define WSP_GGML_FP16_TO_FP32(x) WSP_GGML_COMPUTE_FP16_TO_FP32(x)
418
- #define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x)
342
+ static inline float wsp_ggml_compute_fp16_to_fp32(wsp_ggml_fp16_t h) {
343
+ const uint32_t w = (uint32_t) h << 16;
344
+ const uint32_t sign = w & UINT32_C(0x80000000);
345
+ const uint32_t two_w = w + w;
419
346
 
347
+ const uint32_t exp_offset = UINT32_C(0xE0) << 23;
348
+ #if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)
349
+ const float exp_scale = 0x1.0p-112f;
420
350
  #else
351
+ const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
352
+ #endif
353
+ const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
421
354
 
422
- // FP16 <-> FP32
423
- // ref: https://github.com/Maratyszcza/FP16
424
-
425
- static inline float fp32_from_bits(uint32_t w) {
426
- union {
427
- uint32_t as_bits;
428
- float as_value;
429
- } fp32;
430
- fp32.as_bits = w;
431
- return fp32.as_value;
432
- }
433
-
434
- static inline uint32_t fp32_to_bits(float f) {
435
- union {
436
- float as_value;
437
- uint32_t as_bits;
438
- } fp32;
439
- fp32.as_value = f;
440
- return fp32.as_bits;
441
- }
442
-
443
- static inline float wsp_ggml_compute_fp16_to_fp32(wsp_ggml_fp16_t h) {
444
- const uint32_t w = (uint32_t) h << 16;
445
- const uint32_t sign = w & UINT32_C(0x80000000);
446
- const uint32_t two_w = w + w;
447
-
448
- const uint32_t exp_offset = UINT32_C(0xE0) << 23;
449
- #if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)
450
- const float exp_scale = 0x1.0p-112f;
451
- #else
452
- const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
453
- #endif
454
- const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
455
-
456
- const uint32_t magic_mask = UINT32_C(126) << 23;
457
- const float magic_bias = 0.5f;
458
- const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
355
+ const uint32_t magic_mask = UINT32_C(126) << 23;
356
+ const float magic_bias = 0.5f;
357
+ const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
459
358
 
460
- const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
461
- const uint32_t result = sign |
462
- (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
463
- return fp32_from_bits(result);
464
- }
465
-
466
- static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
467
- #if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)
468
- const float scale_to_inf = 0x1.0p+112f;
469
- const float scale_to_zero = 0x1.0p-110f;
470
- #else
471
- const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
472
- const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
473
- #endif
474
- float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
475
-
476
- const uint32_t w = fp32_to_bits(f);
477
- const uint32_t shl1_w = w + w;
478
- const uint32_t sign = w & UINT32_C(0x80000000);
479
- uint32_t bias = shl1_w & UINT32_C(0xFF000000);
480
- if (bias < UINT32_C(0x71000000)) {
481
- bias = UINT32_C(0x71000000);
482
- }
359
+ const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
360
+ const uint32_t result = sign |
361
+ (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
362
+ return fp32_from_bits(result);
363
+ }
483
364
 
484
- base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
485
- const uint32_t bits = fp32_to_bits(base);
486
- const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
487
- const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
488
- const uint32_t nonsign = exp_bits + mantissa_bits;
489
- return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
365
+ static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
366
+ #if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)
367
+ const float scale_to_inf = 0x1.0p+112f;
368
+ const float scale_to_zero = 0x1.0p-110f;
369
+ #else
370
+ const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
371
+ const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
372
+ #endif
373
+ float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
374
+
375
+ const uint32_t w = fp32_to_bits(f);
376
+ const uint32_t shl1_w = w + w;
377
+ const uint32_t sign = w & UINT32_C(0x80000000);
378
+ uint32_t bias = shl1_w & UINT32_C(0xFF000000);
379
+ if (bias < UINT32_C(0x71000000)) {
380
+ bias = UINT32_C(0x71000000);
490
381
  }
491
382
 
492
- #define WSP_GGML_COMPUTE_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x)
493
- #define WSP_GGML_COMPUTE_FP32_TO_FP16(x) wsp_ggml_compute_fp32_to_fp16(x)
494
-
495
- #endif // defined(__ARM_NEON) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) && !defined(__MUSACC__)
496
-
497
- // precomputed f32 table for f16 (256 KB)
498
- // defined in ggml.c, initialized in wsp_ggml_init()
499
- WSP_GGML_API float wsp_ggml_table_f32_f16[1 << 16];
500
-
501
- // On ARM NEON, it's quicker to directly convert x -> x instead of calling into wsp_ggml_lookup_fp16_to_fp32,
502
- // so we define WSP_GGML_FP16_TO_FP32 and WSP_GGML_FP32_TO_FP16 elsewhere for NEON.
503
- // This is also true for POWER9.
504
- #if !defined(WSP_GGML_FP16_TO_FP32)
505
- inline static float wsp_ggml_lookup_fp16_to_fp32(wsp_ggml_fp16_t f) {
506
- uint16_t s;
507
- memcpy(&s, &f, sizeof(uint16_t));
508
- return wsp_ggml_table_f32_f16[s];
383
+ base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
384
+ const uint32_t bits = fp32_to_bits(base);
385
+ const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
386
+ const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
387
+ const uint32_t nonsign = exp_bits + mantissa_bits;
388
+ return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
509
389
  }
510
390
 
511
- #define WSP_GGML_FP16_TO_FP32(x) wsp_ggml_lookup_fp16_to_fp32(x)
512
- #endif
391
+ #define WSP_GGML_COMPUTE_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x)
392
+ #define WSP_GGML_COMPUTE_FP32_TO_FP16(x) wsp_ggml_compute_fp32_to_fp16(x)
513
393
 
514
- #if !defined(WSP_GGML_FP32_TO_FP16)
394
+ #define WSP_GGML_FP16_TO_FP32(x) WSP_GGML_COMPUTE_FP16_TO_FP32(x)
515
395
  #define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x)
516
- #endif
517
396
 
518
397
  /**
519
398
  * Converts brain16 to float32.
@@ -589,13 +468,76 @@ static inline wsp_ggml_bf16_t wsp_ggml_compute_fp32_to_bf16(float s) {
589
468
  #define WSP_GGML_FP32_TO_BF16(x) wsp_ggml_compute_fp32_to_bf16(x)
590
469
  #define WSP_GGML_BF16_TO_FP32(x) wsp_ggml_compute_bf16_to_fp32(x)
591
470
 
471
+ // return true if the node's results are only used by N other nodes
472
+ // and can be fused into their calculations.
473
+ static inline bool wsp_ggml_node_has_n_uses(const struct wsp_ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
474
+ const struct wsp_ggml_tensor * node = cgraph->nodes[node_idx];
475
+
476
+ // check the use count against how many we're replacing
477
+ size_t hash_pos = wsp_ggml_hash_find(&cgraph->visited_hash_set, node);
478
+ if (!wsp_ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
479
+ return false;
480
+ }
481
+
482
+ // if node is a view, some other node might be using the intermediate result
483
+ // via the view source.
484
+ if (node->view_src) {
485
+ return false;
486
+ }
487
+
488
+ // If the user requested output for the node, can't fuse
489
+ if (node->flags & WSP_GGML_TENSOR_FLAG_OUTPUT) {
490
+ return false;
491
+ }
492
+
493
+ return true;
494
+ }
495
+
496
+ // Returns true if nodes [i, i+ops.size()) are the sequence of wsp_ggml_ops in ops[]
497
+ // and are fusable. Nodes are considered fusable according to this function if:
498
+ // - all nodes except the last have only one use and are not views/outputs (see wsp_ggml_node_has_n_uses).
499
+ // - all nodes except the last are a src of the following node.
500
+ // - all nodes are the same shape.
501
+ // TODO: Consider allowing WSP_GGML_OP_NONE nodes in between
502
+ static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_idx, const enum wsp_ggml_op * ops, int num_ops) {
503
+ if (node_idx + num_ops > cgraph->n_nodes) {
504
+ return false;
505
+ }
506
+
507
+ for (int i = 0; i < num_ops; ++i) {
508
+ struct wsp_ggml_tensor * node = cgraph->nodes[node_idx + i];
509
+ if (node->op != ops[i]) {
510
+ return false;
511
+ }
512
+ if (i < num_ops - 1 && !wsp_ggml_node_has_n_uses(cgraph, node_idx + i, 1)) {
513
+ return false;
514
+ }
515
+ if (i > 0) {
516
+ struct wsp_ggml_tensor * prev = cgraph->nodes[node_idx + i - 1];
517
+ if (node->src[0] != prev && node->src[1] != prev) {
518
+ return false;
519
+ }
520
+ if (!wsp_ggml_are_same_shape(node, prev)) {
521
+ return false;
522
+ }
523
+ }
524
+ }
525
+ return true;
526
+ }
527
+
592
528
  #ifdef __cplusplus
593
529
  }
594
530
  #endif
595
531
 
596
532
  #ifdef __cplusplus
533
+ #include <initializer_list>
597
534
  #include <vector>
598
535
 
536
+ // nicer C++ syntax for wsp_ggml_can_fuse
537
+ inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum wsp_ggml_op> ops) {
538
+ return wsp_ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
539
+ }
540
+
599
541
  // expose GGUF internals for test code
600
542
  WSP_GGML_API size_t wsp_gguf_type_size(enum wsp_gguf_type type);
601
543
  WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_from_file_impl(FILE * file, struct wsp_gguf_init_params params);
package/cpp/ggml-metal-impl.h CHANGED
@@ -422,6 +422,17 @@ typedef struct {
422
422
  int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
423
423
  } wsp_ggml_metal_kargs_im2col;
424
424
 
425
+ typedef struct{
426
+ int32_t ne00;
427
+ uint64_t nb01;
428
+ int32_t ne10;
429
+ uint64_t nb11;
430
+ int32_t ne0;
431
+ uint64_t nb1;
432
+ int32_t i00;
433
+ int32_t i10;
434
+ } wsp_ggml_metal_kargs_glu;
435
+
425
436
  typedef struct {
426
437
  int64_t ne00;
427
438
  int64_t ne01;
@@ -521,6 +532,22 @@ typedef struct {
521
532
  uint64_t nb2;
522
533
  } wsp_ggml_metal_kargs_get_rows;
523
534
 
535
+ typedef struct {
536
+ int32_t nk0;
537
+ int32_t ne01;
538
+ uint64_t nb01;
539
+ uint64_t nb02;
540
+ uint64_t nb03;
541
+ int32_t ne11;
542
+ int32_t ne12;
543
+ uint64_t nb10;
544
+ uint64_t nb11;
545
+ uint64_t nb12;
546
+ uint64_t nb1;
547
+ uint64_t nb2;
548
+ uint64_t nb3;
549
+ } wsp_ggml_metal_kargs_set_rows;
550
+
524
551
  typedef struct {
525
552
  int64_t ne00;
526
553
  int64_t ne01;