whisper.rn 0.5.1 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. package/android/src/main/jni.cpp +12 -3
  2. package/cpp/ggml-alloc.c +49 -18
  3. package/cpp/ggml-backend-impl.h +0 -3
  4. package/cpp/ggml-backend-reg.cpp +8 -0
  5. package/cpp/ggml-backend.cpp +0 -2
  6. package/cpp/ggml-backend.h +2 -0
  7. package/cpp/ggml-cpu/amx/amx.cpp +1 -0
  8. package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
  9. package/cpp/ggml-cpu/ggml-cpu-impl.h +4 -2
  10. package/cpp/ggml-cpu/ggml-cpu.c +67 -24
  11. package/cpp/ggml-cpu/ops.cpp +489 -364
  12. package/cpp/ggml-cpu/ops.h +4 -4
  13. package/cpp/ggml-cpu/repack.cpp +143 -29
  14. package/cpp/ggml-cpu/simd-mappings.h +25 -25
  15. package/cpp/ggml-cpu/unary-ops.cpp +151 -0
  16. package/cpp/ggml-cpu/unary-ops.h +7 -0
  17. package/cpp/ggml-cpu/vec.cpp +83 -0
  18. package/cpp/ggml-cpu/vec.h +20 -8
  19. package/cpp/ggml-impl.h +67 -2
  20. package/cpp/ggml-metal/ggml-metal-common.cpp +2 -2
  21. package/cpp/ggml-metal/ggml-metal-context.m +5 -6
  22. package/cpp/ggml-metal/ggml-metal-device.cpp +300 -14
  23. package/cpp/ggml-metal/ggml-metal-device.h +26 -1
  24. package/cpp/ggml-metal/ggml-metal-device.m +243 -28
  25. package/cpp/ggml-metal/ggml-metal-impl.h +177 -9
  26. package/cpp/ggml-metal/ggml-metal-ops.cpp +843 -157
  27. package/cpp/ggml-metal/ggml-metal-ops.h +8 -0
  28. package/cpp/ggml-metal/ggml-metal.cpp +8 -3
  29. package/cpp/ggml-metal/ggml-metal.metal +12436 -0
  30. package/cpp/ggml.c +317 -4
  31. package/cpp/ggml.h +139 -0
  32. package/cpp/jsi/RNWhisperJSI.cpp +7 -2
  33. package/cpp/rn-whisper.h +1 -0
  34. package/cpp/whisper.cpp +8 -2
  35. package/ios/RNWhisperContext.mm +3 -1
  36. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  37. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  38. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  39. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
  40. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  44. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  45. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  46. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  47. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
  48. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  49. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  53. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  54. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  55. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  56. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
  57. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  58. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  59. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  61. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  62. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  63. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  64. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
  65. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  66. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  67. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  68. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  70. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  71. package/lib/commonjs/version.json +1 -1
  72. package/lib/module/NativeRNWhisper.js.map +1 -1
  73. package/lib/module/version.json +1 -1
  74. package/lib/typescript/NativeRNWhisper.d.ts +2 -0
  75. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  76. package/package.json +1 -1
  77. package/src/NativeRNWhisper.ts +2 -0
  78. package/src/version.json +1 -1
  79. package/whisper-rn.podspec +1 -1
  80. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  81. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  82. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  83. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  84. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  85. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
@@ -7,6 +7,8 @@
7
7
 
8
8
  #include <Metal/Metal.h>
9
9
 
10
+ #include <stdatomic.h>
11
+
10
12
  #ifndef TARGET_OS_VISION
11
13
  #define TARGET_OS_VISION 0
12
14
  #endif
@@ -19,8 +21,12 @@
19
21
  #define WSP_GGML_METAL_HAS_RESIDENCY_SETS 1
20
22
  #endif
21
23
 
22
- // overload of MTLGPUFamilyMetal3 (not available in some environments)
24
+ // overload of MTLGPUFamilyMetalX (not available in some environments)
23
25
  static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
26
+ static const NSInteger MTLGPUFamilyMetal4_GGML = 5002;
27
+
28
+ // virtual address for GPU memory allocations
29
+ static atomic_uintptr_t g_addr_device = 0x000000400ULL;
24
30
 
25
31
  #if !WSP_GGML_METAL_EMBED_LIBRARY
26
32
  // Here to assist with NSBundle Path Hack
@@ -175,11 +181,7 @@ wsp_ggml_metal_library_t wsp_ggml_metal_library_init(wsp_ggml_metal_device_t dev
175
181
  NSBundle * bundle = [NSBundle bundleForClass:[WSPGGMLMetalClass class]];
176
182
  #endif
177
183
 
178
- #if TARGET_OS_SIMULATOR
179
- NSString * path_lib = [bundle pathForResource:@"ggml-whisper-sim" ofType:@"metallib"];
180
- #else
181
- NSString * path_lib = [bundle pathForResource:@"ggml-whisper" ofType:@"metallib"];
182
- #endif
184
+ NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
183
185
  if (path_lib == nil) {
184
186
  // Try to find the resource in the directory where the current binary located.
185
187
  NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
@@ -260,6 +262,10 @@ wsp_ggml_metal_library_t wsp_ggml_metal_library_init(wsp_ggml_metal_device_t dev
260
262
  [prep setObject:@"1" forKey:@"WSP_GGML_METAL_HAS_BF16"];
261
263
  }
262
264
 
265
+ if (wsp_ggml_metal_device_get_props(dev)->has_tensor) {
266
+ [prep setObject:@"1" forKey:@"WSP_GGML_METAL_HAS_TENSOR"];
267
+ }
268
+
263
269
  #if WSP_GGML_METAL_EMBED_LIBRARY
264
270
  [prep setObject:@"1" forKey:@"WSP_GGML_METAL_EMBED_LIBRARY"];
265
271
  #endif
@@ -297,6 +303,72 @@ wsp_ggml_metal_library_t wsp_ggml_metal_library_init(wsp_ggml_metal_device_t dev
297
303
  return res;
298
304
  }
299
305
 
306
+ wsp_ggml_metal_library_t wsp_ggml_metal_library_init_from_source(wsp_ggml_metal_device_t dev, const char * source, bool verbose) {
307
+ if (source == NULL) {
308
+ WSP_GGML_LOG_ERROR("%s: source is NULL\n", __func__);
309
+ return NULL;
310
+ }
311
+
312
+ id<MTLDevice> device = wsp_ggml_metal_device_get_obj(dev);
313
+ id<MTLLibrary> library = nil;
314
+ NSError * error = nil;
315
+
316
+ const int64_t t_start = wsp_ggml_time_us();
317
+
318
+ NSString * src = [[NSString alloc] initWithBytes:source
319
+ length:strlen(source)
320
+ encoding:NSUTF8StringEncoding];
321
+ if (!src) {
322
+ WSP_GGML_LOG_ERROR("%s: failed to create NSString from source\n", __func__);
323
+ return NULL;
324
+ }
325
+
326
+ @autoreleasepool {
327
+ NSMutableDictionary * prep = [NSMutableDictionary dictionary];
328
+
329
+ MTLCompileOptions * options = [MTLCompileOptions new];
330
+ options.preprocessorMacros = prep;
331
+
332
+ library = [device newLibraryWithSource:src options:options error:&error];
333
+ if (error) {
334
+ if (verbose) {
335
+ WSP_GGML_LOG_ERROR("%s: error compiling source: %s\n", __func__, [[error description] UTF8String]);
336
+ } else {
337
+ WSP_GGML_LOG_ERROR("%s: error compiling source\n", __func__);
338
+ }
339
+ library = nil;
340
+ }
341
+
342
+ [options release];
343
+ }
344
+
345
+ [src release];
346
+
347
+ if (!library) {
348
+ if (verbose) {
349
+ WSP_GGML_LOG_ERROR("%s: failed to create Metal library from source\n", __func__);
350
+ }
351
+
352
+ return NULL;
353
+ }
354
+
355
+ if (verbose) {
356
+ WSP_GGML_LOG_INFO("%s: compiled in %.3f sec\n", __func__, (wsp_ggml_time_us() - t_start) / 1e6);
357
+ }
358
+
359
+ wsp_ggml_metal_library_t res = calloc(1, sizeof(struct wsp_ggml_metal_library));
360
+ if (!res) {
361
+ WSP_GGML_LOG_ERROR("%s: calloc failed\n", __func__);
362
+ return NULL;
363
+ }
364
+
365
+ res->obj = library;
366
+ res->device = device;
367
+ res->pipelines = wsp_ggml_metal_pipelines_init();
368
+
369
+ return res;
370
+ }
371
+
300
372
  void wsp_ggml_metal_library_free(wsp_ggml_metal_library_t lib) {
301
373
  if (!lib) {
302
374
  return;
@@ -344,9 +416,9 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_compile_pipeline(wsp_ggml_metal
344
416
  if (!mtl_function) {
345
417
  wsp_ggml_critical_section_end();
346
418
 
347
- WSP_GGML_LOG_ERROR("%s: error: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
419
+ WSP_GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
348
420
  if (error) {
349
- WSP_GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
421
+ WSP_GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
350
422
  }
351
423
 
352
424
  return nil;
@@ -354,13 +426,21 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_compile_pipeline(wsp_ggml_metal
354
426
 
355
427
  res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
356
428
 
357
- wsp_ggml_metal_pipelines_add(lib->pipelines, name, res);
358
-
359
429
  [mtl_function release];
360
430
 
361
431
  WSP_GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj,
362
432
  (int) res->obj.maxTotalThreadsPerThreadgroup,
363
433
  (int) res->obj.threadExecutionWidth);
434
+
435
+ if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) {
436
+ wsp_ggml_critical_section_end();
437
+
438
+ WSP_GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
439
+
440
+ return nil;
441
+ }
442
+
443
+ wsp_ggml_metal_pipelines_add(lib->pipelines, name, res);
364
444
  }
365
445
 
366
446
  wsp_ggml_critical_section_end();
@@ -468,6 +548,128 @@ wsp_ggml_metal_device_t wsp_ggml_metal_device_init(void) {
468
548
 
469
549
  dev->props.has_bfloat = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
470
550
  dev->props.has_bfloat |= [dev->mtl_device supportsFamily:MTLGPUFamilyApple6];
551
+ if (getenv("WSP_GGML_METAL_BF16_DISABLE") != NULL) {
552
+ dev->props.has_bfloat = false;
553
+ }
554
+
555
+ dev->props.has_tensor = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal4_GGML];
556
+ if (getenv("WSP_GGML_METAL_TENSOR_DISABLE") != NULL) {
557
+ dev->props.has_tensor = false;
558
+ }
559
+
560
+ // note: disable the tensor API by default for old chips because with the current implementation it is not useful
561
+ // - M2 Ultra: ~5% slower
562
+ // - M4, M4 Max: no significant difference
563
+ //
564
+ // TODO: try to update the tensor API kernels to at least match the simdgroup performance
565
+ if (getenv("WSP_GGML_METAL_TENSOR_ENABLE") == NULL &&
566
+ ![[dev->mtl_device name] containsString:@"M5"] &&
567
+ ![[dev->mtl_device name] containsString:@"M6"] &&
568
+ ![[dev->mtl_device name] containsString:@"A19"] &&
569
+ ![[dev->mtl_device name] containsString:@"A20"]) {
570
+ WSP_GGML_LOG_WARN("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__);
571
+ dev->props.has_tensor = false;
572
+ }
573
+
574
+ // double-check that the tensor API compiles
575
+ if (dev->props.has_tensor) {
576
+ const char * src_tensor_f16 = "\n"
577
+ "#include <metal_stdlib> \n"
578
+ "#include <metal_tensor> \n"
579
+ "#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n"
580
+ " \n"
581
+ "using namespace metal; \n"
582
+ "using namespace mpp::tensor_ops; \n"
583
+ " \n"
584
+ "kernel void dummy_kernel( \n"
585
+ " tensor<device half, dextents<int32_t, 2>> A [[buffer(0)]], \n"
586
+ " tensor<device half, dextents<int32_t, 2>> B [[buffer(1)]], \n"
587
+ " device float * C [[buffer(2)]], \n"
588
+ " uint2 tgid [[threadgroup_position_in_grid]]) \n"
589
+ "{ \n"
590
+ " auto tA = A.slice(0, (int)tgid.y); \n"
591
+ " auto tB = B.slice((int)tgid.x, 0); \n"
592
+ " \n"
593
+ " matmul2d< \n"
594
+ " matmul2d_descriptor(8, 8, dynamic_extent), \n"
595
+ " execution_simdgroups<4>> mm; \n"
596
+ " \n"
597
+ " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
598
+ " \n"
599
+ " auto sA = tA.slice(0, 0); \n"
600
+ " auto sB = tB.slice(0, 0); \n"
601
+ " mm.run(sB, sA, cT); \n"
602
+ " \n"
603
+ " auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
604
+ " \n"
605
+ " cT.store(tC); \n"
606
+ "}";
607
+
608
+ WSP_GGML_LOG_INFO("%s: testing tensor API for f16 support\n", __func__);
609
+ wsp_ggml_metal_library_t lib = wsp_ggml_metal_library_init_from_source(dev, src_tensor_f16, false);
610
+ if (lib == NULL) {
611
+ WSP_GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
612
+ dev->props.has_tensor = false;
613
+ } else {
614
+ wsp_ggml_metal_pipeline_t ppl = wsp_ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
615
+ if (!ppl) {
616
+ WSP_GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
617
+ dev->props.has_tensor = false;
618
+ }
619
+
620
+ wsp_ggml_metal_library_free(lib);
621
+ }
622
+ }
623
+
624
+ // try to compile a dummy kernel to determine if the tensor API is supported for bfloat
625
+ if (dev->props.has_tensor && dev->props.has_bfloat) {
626
+ const char * src_tensor_bf16 = "\n"
627
+ "#include <metal_stdlib> \n"
628
+ "#include <metal_tensor> \n"
629
+ "#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n"
630
+ " \n"
631
+ "using namespace metal; \n"
632
+ "using namespace mpp::tensor_ops; \n"
633
+ " \n"
634
+ "kernel void dummy_kernel( \n"
635
+ " tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \n"
636
+ " tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \n"
637
+ " device float * C [[buffer(2)]], \n"
638
+ " uint2 tgid [[threadgroup_position_in_grid]]) \n"
639
+ "{ \n"
640
+ " auto tA = A.slice(0, (int)tgid.y); \n"
641
+ " auto tB = B.slice((int)tgid.x, 0); \n"
642
+ " \n"
643
+ " matmul2d< \n"
644
+ " matmul2d_descriptor(8, 8, dynamic_extent), \n"
645
+ " execution_simdgroups<4>> mm; \n"
646
+ " \n"
647
+ " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
648
+ " \n"
649
+ " auto sA = tA.slice(0, 0); \n"
650
+ " auto sB = tB.slice(0, 0); \n"
651
+ " mm.run(sB, sA, cT); \n"
652
+ " \n"
653
+ " auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
654
+ " \n"
655
+ " cT.store(tC); \n"
656
+ "}";
657
+
658
+ WSP_GGML_LOG_INFO("%s: testing tensor API for bfloat support\n", __func__);
659
+ wsp_ggml_metal_library_t lib = wsp_ggml_metal_library_init_from_source(dev, src_tensor_bf16, false);
660
+ if (lib == NULL) {
661
+ WSP_GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
662
+ dev->props.has_bfloat = false;
663
+ } else {
664
+ wsp_ggml_metal_pipeline_t ppl = wsp_ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
665
+ if (!ppl) {
666
+ WSP_GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
667
+ dev->props.has_bfloat = false;
668
+ }
669
+
670
+ wsp_ggml_metal_library_free(lib);
671
+ }
672
+ }
471
673
 
472
674
  dev->props.use_residency_sets = true;
473
675
  #if defined(WSP_GGML_METAL_HAS_RESIDENCY_SETS)
@@ -475,7 +677,6 @@ wsp_ggml_metal_device_t wsp_ggml_metal_device_init(void) {
475
677
  #endif
476
678
 
477
679
  dev->props.use_shared_buffers = dev->props.has_unified_memory;
478
-
479
680
  if (getenv("WSP_GGML_METAL_SHARED_BUFFERS_DISABLE") != NULL) {
480
681
  dev->props.use_shared_buffers = false;
481
682
  }
@@ -528,6 +729,7 @@ wsp_ggml_metal_device_t wsp_ggml_metal_device_init(void) {
528
729
  WSP_GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, dev->props.has_simdgroup_mm ? "true" : "false");
529
730
  WSP_GGML_LOG_INFO("%s: has unified memory = %s\n", __func__, dev->props.has_unified_memory ? "true" : "false");
530
731
  WSP_GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, dev->props.has_bfloat ? "true" : "false");
732
+ WSP_GGML_LOG_INFO("%s: has tensor = %s\n", __func__, dev->props.has_tensor ? "true" : "false");
531
733
  WSP_GGML_LOG_INFO("%s: use residency sets = %s\n", __func__, dev->props.use_residency_sets ? "true" : "false");
532
734
  WSP_GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, dev->props.use_shared_buffers ? "true" : "false");
533
735
 
@@ -652,6 +854,11 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
652
854
  case WSP_GGML_OP_SCALE:
653
855
  case WSP_GGML_OP_CONV_TRANSPOSE_1D:
654
856
  return true;
857
+ case WSP_GGML_OP_CONV_TRANSPOSE_2D:
858
+ return wsp_ggml_is_contiguous(op->src[0]) && wsp_ggml_is_contiguous(op->src[1]) &&
859
+ (op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32) &&
860
+ op->src[1]->type == WSP_GGML_TYPE_F32 &&
861
+ op->type == WSP_GGML_TYPE_F32;
655
862
  case WSP_GGML_OP_CLAMP:
656
863
  return op->src[0]->type == WSP_GGML_TYPE_F32;
657
864
  case WSP_GGML_OP_SQR:
@@ -660,7 +867,10 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
660
867
  case WSP_GGML_OP_COS:
661
868
  case WSP_GGML_OP_LOG:
662
869
  return wsp_ggml_is_contiguous(op->src[0]) && op->src[0]->type == WSP_GGML_TYPE_F32;
870
+ case WSP_GGML_OP_SUM:
871
+ return has_simdgroup_reduction && wsp_ggml_is_contiguous(op->src[0]);
663
872
  case WSP_GGML_OP_SUM_ROWS:
873
+ case WSP_GGML_OP_CUMSUM:
664
874
  case WSP_GGML_OP_MEAN:
665
875
  case WSP_GGML_OP_SOFT_MAX:
666
876
  case WSP_GGML_OP_GROUP_NORM:
@@ -676,6 +886,11 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
676
886
  return true;
677
887
  case WSP_GGML_OP_IM2COL:
678
888
  return wsp_ggml_is_contiguous(op->src[1]) && op->src[1]->type == WSP_GGML_TYPE_F32 && (op->type == WSP_GGML_TYPE_F16 || op->type == WSP_GGML_TYPE_F32);
889
+ case WSP_GGML_OP_CONV_2D:
890
+ return wsp_ggml_is_contiguous(op->src[0]) &&
891
+ op->src[1]->type == WSP_GGML_TYPE_F32 &&
892
+ op->type == WSP_GGML_TYPE_F32 &&
893
+ (op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32);
679
894
  case WSP_GGML_OP_POOL_1D:
680
895
  return false;
681
896
  case WSP_GGML_OP_UPSCALE:
@@ -690,14 +905,14 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
690
905
  case WSP_GGML_OP_LEAKY_RELU:
691
906
  return op->src[0]->type == WSP_GGML_TYPE_F32;
692
907
  case WSP_GGML_OP_ARGSORT:
693
- // TODO: Support arbitrary column width
694
- return op->src[0]->ne[0] <= 1024;
695
908
  case WSP_GGML_OP_ARANGE:
696
909
  return true;
697
910
  case WSP_GGML_OP_FLASH_ATTN_EXT:
698
911
  // for new head sizes, add checks here
699
- if (op->src[0]->ne[0] != 40 &&
912
+ if (op->src[0]->ne[0] != 32 &&
913
+ op->src[0]->ne[0] != 40 &&
700
914
  op->src[0]->ne[0] != 64 &&
915
+ op->src[0]->ne[0] != 72 &&
701
916
  op->src[0]->ne[0] != 80 &&
702
917
  op->src[0]->ne[0] != 96 &&
703
918
  op->src[0]->ne[0] != 112 &&
@@ -774,15 +989,13 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
774
989
  return false;
775
990
  }
776
991
  case WSP_GGML_TYPE_I32:
777
- return op->type == WSP_GGML_TYPE_F32;
992
+ return op->type == WSP_GGML_TYPE_F32 || op->type == WSP_GGML_TYPE_I32;
778
993
  default:
779
994
  return false;
780
995
  };
781
996
  }
782
997
  case WSP_GGML_OP_GET_ROWS:
783
- {
784
- return op->ne[3] == 1;
785
- }
998
+ return true;
786
999
  case WSP_GGML_OP_SET_ROWS:
787
1000
  {
788
1001
  if (op->src[0]->type != WSP_GGML_TYPE_F32) {
@@ -804,6 +1017,9 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
804
1017
  return false;
805
1018
  };
806
1019
  }
1020
+ case WSP_GGML_OP_OPT_STEP_ADAMW:
1021
+ case WSP_GGML_OP_OPT_STEP_SGD:
1022
+ return has_simdgroup_reduction;
807
1023
  default:
808
1024
  return false;
809
1025
  }
@@ -828,7 +1044,7 @@ struct wsp_ggml_metal_buffer_wrapper {
828
1044
  };
829
1045
 
830
1046
  struct wsp_ggml_metal_buffer {
831
- void * all_data; // TODO: https://github.com/ggml-org/llama.cpp/pull/15985
1047
+ void * all_data;
832
1048
  size_t all_size;
833
1049
 
834
1050
  // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
@@ -966,14 +1182,15 @@ wsp_ggml_metal_buffer_t wsp_ggml_metal_buffer_init(wsp_ggml_metal_device_t dev,
966
1182
  if (shared) {
967
1183
  res->all_data = wsp_ggml_metal_host_malloc(size_aligned);
968
1184
  res->is_shared = true;
969
- res->owned = true;
970
1185
  } else {
971
- // dummy, non-NULL value - we'll populate this after creating the Metal buffer below
972
- res->all_data = (void *) 0x000000400ULL;
1186
+ // use virtual address from g_addr_device counter
1187
+ res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed);
973
1188
  res->is_shared = false;
974
1189
  }
975
1190
  res->all_size = size_aligned;
976
1191
 
1192
+ res->owned = true;
1193
+
977
1194
  res->device = wsp_ggml_metal_device_get_obj(dev);
978
1195
  res->queue = wsp_ggml_metal_device_get_queue(dev);
979
1196
 
@@ -984,15 +1201,13 @@ wsp_ggml_metal_buffer_t wsp_ggml_metal_buffer_init(wsp_ggml_metal_device_t dev,
984
1201
  res->buffers[0].metal = nil;
985
1202
 
986
1203
  if (size_aligned > 0) {
987
- if (props_dev->use_shared_buffers &&shared) {
1204
+ if (props_dev->use_shared_buffers && shared) {
988
1205
  res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data
989
1206
  length:size_aligned
990
1207
  options:MTLResourceStorageModeShared
991
1208
  deallocator:nil];
992
1209
  } else {
993
1210
  res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
994
-
995
- res->all_data = (void *) (res->buffers[0].metal.gpuAddress);
996
1211
  }
997
1212
  }
998
1213
 
@@ -1140,7 +1355,7 @@ bool wsp_ggml_metal_buffer_is_shared(wsp_ggml_metal_buffer_t buf) {
1140
1355
 
1141
1356
  void wsp_ggml_metal_buffer_memset_tensor(wsp_ggml_metal_buffer_t buf, struct wsp_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
1142
1357
  if (buf->is_shared) {
1143
- memset((char *)tensor->data + offset, value, size);
1358
+ memset((char *) tensor->data + offset, value, size);
1144
1359
  return;
1145
1360
  }
1146
1361
 
@@ -1169,7 +1384,7 @@ void wsp_ggml_metal_buffer_memset_tensor(wsp_ggml_metal_buffer_t buf, struct wsp
1169
1384
 
1170
1385
  void wsp_ggml_metal_buffer_set_tensor(wsp_ggml_metal_buffer_t buf, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1171
1386
  if (buf->is_shared) {
1172
- memcpy((char *)tensor->data + offset, data, size);
1387
+ memcpy((char *) tensor->data + offset, data, size);
1173
1388
  return;
1174
1389
  }
1175
1390
 
@@ -1224,7 +1439,7 @@ void wsp_ggml_metal_buffer_set_tensor(wsp_ggml_metal_buffer_t buf, struct wsp_gg
1224
1439
 
1225
1440
  void wsp_ggml_metal_buffer_get_tensor(wsp_ggml_metal_buffer_t buf, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
1226
1441
  if (buf->is_shared) {
1227
- memcpy(data, (const char *)tensor->data + offset, size);
1442
+ memcpy(data, (const char *) tensor->data + offset, size);
1228
1443
  return;
1229
1444
  }
1230
1445