whisper.rn 0.5.1 → 0.5.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/jni.cpp +12 -3
- package/cpp/ggml-alloc.c +49 -18
- package/cpp/ggml-backend-impl.h +0 -3
- package/cpp/ggml-backend-reg.cpp +8 -0
- package/cpp/ggml-backend.cpp +0 -2
- package/cpp/ggml-backend.h +2 -0
- package/cpp/ggml-cpu/amx/amx.cpp +1 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
- package/cpp/ggml-cpu/ggml-cpu-impl.h +4 -2
- package/cpp/ggml-cpu/ggml-cpu.c +67 -24
- package/cpp/ggml-cpu/ops.cpp +489 -364
- package/cpp/ggml-cpu/ops.h +4 -4
- package/cpp/ggml-cpu/repack.cpp +143 -29
- package/cpp/ggml-cpu/simd-mappings.h +25 -25
- package/cpp/ggml-cpu/unary-ops.cpp +151 -0
- package/cpp/ggml-cpu/unary-ops.h +7 -0
- package/cpp/ggml-cpu/vec.cpp +83 -0
- package/cpp/ggml-cpu/vec.h +20 -8
- package/cpp/ggml-impl.h +67 -2
- package/cpp/ggml-metal/ggml-metal-common.cpp +2 -2
- package/cpp/ggml-metal/ggml-metal-context.m +5 -6
- package/cpp/ggml-metal/ggml-metal-device.cpp +300 -14
- package/cpp/ggml-metal/ggml-metal-device.h +26 -1
- package/cpp/ggml-metal/ggml-metal-device.m +243 -28
- package/cpp/ggml-metal/ggml-metal-impl.h +177 -9
- package/cpp/ggml-metal/ggml-metal-ops.cpp +843 -157
- package/cpp/ggml-metal/ggml-metal-ops.h +8 -0
- package/cpp/ggml-metal/ggml-metal.cpp +8 -3
- package/cpp/ggml-metal/ggml-metal.metal +12436 -0
- package/cpp/ggml.c +317 -4
- package/cpp/ggml.h +139 -0
- package/cpp/jsi/RNWhisperJSI.cpp +7 -2
- package/cpp/rn-whisper.h +1 -0
- package/cpp/whisper.cpp +8 -2
- package/ios/RNWhisperContext.mm +3 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +2 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +2 -0
- package/src/version.json +1 -1
- package/whisper-rn.podspec +1 -1
- package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
@@ -7,6 +7,8 @@

 #include <Metal/Metal.h>

+#include <stdatomic.h>
+
 #ifndef TARGET_OS_VISION
 #define TARGET_OS_VISION 0
 #endif
@@ -19,8 +21,12 @@
 #define WSP_GGML_METAL_HAS_RESIDENCY_SETS 1
 #endif

-// overload of
+// overload of MTLGPUFamilyMetalX (not available in some environments)
 static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
+static const NSInteger MTLGPUFamilyMetal4_GGML = 5002;
+
+// virtual address for GPU memory allocations
+static atomic_uintptr_t g_addr_device = 0x000000400ULL;

 #if !WSP_GGML_METAL_EMBED_LIBRARY
 // Here to assist with NSBundle Path Hack
@@ -175,11 +181,7 @@ wsp_ggml_metal_library_t wsp_ggml_metal_library_init(wsp_ggml_metal_device_t dev
     NSBundle * bundle = [NSBundle bundleForClass:[WSPGGMLMetalClass class]];
 #endif

-
-    NSString * path_lib = [bundle pathForResource:@"ggml-whisper-sim" ofType:@"metallib"];
-#else
-    NSString * path_lib = [bundle pathForResource:@"ggml-whisper" ofType:@"metallib"];
-#endif
+    NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
     if (path_lib == nil) {
         // Try to find the resource in the directory where the current binary located.
         NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
@@ -260,6 +262,10 @@ wsp_ggml_metal_library_t wsp_ggml_metal_library_init(wsp_ggml_metal_device_t dev
         [prep setObject:@"1" forKey:@"WSP_GGML_METAL_HAS_BF16"];
     }

+    if (wsp_ggml_metal_device_get_props(dev)->has_tensor) {
+        [prep setObject:@"1" forKey:@"WSP_GGML_METAL_HAS_TENSOR"];
+    }
+
 #if WSP_GGML_METAL_EMBED_LIBRARY
     [prep setObject:@"1" forKey:@"WSP_GGML_METAL_EMBED_LIBRARY"];
 #endif
@@ -297,6 +303,72 @@ wsp_ggml_metal_library_t wsp_ggml_metal_library_init(wsp_ggml_metal_device_t dev
     return res;
 }

+wsp_ggml_metal_library_t wsp_ggml_metal_library_init_from_source(wsp_ggml_metal_device_t dev, const char * source, bool verbose) {
+    if (source == NULL) {
+        WSP_GGML_LOG_ERROR("%s: source is NULL\n", __func__);
+        return NULL;
+    }
+
+    id<MTLDevice> device = wsp_ggml_metal_device_get_obj(dev);
+    id<MTLLibrary> library = nil;
+    NSError * error = nil;
+
+    const int64_t t_start = wsp_ggml_time_us();
+
+    NSString * src = [[NSString alloc] initWithBytes:source
+                                              length:strlen(source)
+                                            encoding:NSUTF8StringEncoding];
+    if (!src) {
+        WSP_GGML_LOG_ERROR("%s: failed to create NSString from source\n", __func__);
+        return NULL;
+    }
+
+    @autoreleasepool {
+        NSMutableDictionary * prep = [NSMutableDictionary dictionary];
+
+        MTLCompileOptions * options = [MTLCompileOptions new];
+        options.preprocessorMacros = prep;
+
+        library = [device newLibraryWithSource:src options:options error:&error];
+        if (error) {
+            if (verbose) {
+                WSP_GGML_LOG_ERROR("%s: error compiling source: %s\n", __func__, [[error description] UTF8String]);
+            } else {
+                WSP_GGML_LOG_ERROR("%s: error compiling source\n", __func__);
+            }
+            library = nil;
+        }
+
+        [options release];
+    }
+
+    [src release];
+
+    if (!library) {
+        if (verbose) {
+            WSP_GGML_LOG_ERROR("%s: failed to create Metal library from source\n", __func__);
+        }
+
+        return NULL;
+    }
+
+    if (verbose) {
+        WSP_GGML_LOG_INFO("%s: compiled in %.3f sec\n", __func__, (wsp_ggml_time_us() - t_start) / 1e6);
+    }
+
+    wsp_ggml_metal_library_t res = calloc(1, sizeof(struct wsp_ggml_metal_library));
+    if (!res) {
+        WSP_GGML_LOG_ERROR("%s: calloc failed\n", __func__);
+        return NULL;
+    }
+
+    res->obj = library;
+    res->device = device;
+    res->pipelines = wsp_ggml_metal_pipelines_init();
+
+    return res;
+}
+
 void wsp_ggml_metal_library_free(wsp_ggml_metal_library_t lib) {
     if (!lib) {
         return;
@@ -344,9 +416,9 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_compile_pipeline(wsp_ggml_metal
     if (!mtl_function) {
         wsp_ggml_critical_section_end();

-        WSP_GGML_LOG_ERROR("%s:
+        WSP_GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
         if (error) {
-            WSP_GGML_LOG_ERROR("%s:
+            WSP_GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
         }

         return nil;
@@ -354,13 +426,21 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_compile_pipeline(wsp_ggml_metal

         res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];

-        wsp_ggml_metal_pipelines_add(lib->pipelines, name, res);
-
         [mtl_function release];

         WSP_GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj,
                 (int) res->obj.maxTotalThreadsPerThreadgroup,
                 (int) res->obj.threadExecutionWidth);
+
+        if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) {
+            wsp_ggml_critical_section_end();
+
+            WSP_GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
+
+            return nil;
+        }
+
+        wsp_ggml_metal_pipelines_add(lib->pipelines, name, res);
     }

     wsp_ggml_critical_section_end();
@@ -468,6 +548,128 @@ wsp_ggml_metal_device_t wsp_ggml_metal_device_init(void) {

     dev->props.has_bfloat  = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
     dev->props.has_bfloat |= [dev->mtl_device supportsFamily:MTLGPUFamilyApple6];
+    if (getenv("WSP_GGML_METAL_BF16_DISABLE") != NULL) {
+        dev->props.has_bfloat = false;
+    }
+
+    dev->props.has_tensor = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal4_GGML];
+    if (getenv("WSP_GGML_METAL_TENSOR_DISABLE") != NULL) {
+        dev->props.has_tensor = false;
+    }
+
+    // note: disable the tensor API by default for old chips because with the current implementation it is not useful
+    // - M2 Ultra: ~5% slower
+    // - M4, M4 Max: no significant difference
+    //
+    // TODO: try to update the tensor API kernels to at least match the simdgroup performance
+    if (getenv("WSP_GGML_METAL_TENSOR_ENABLE") == NULL &&
+        ![[dev->mtl_device name] containsString:@"M5"] &&
+        ![[dev->mtl_device name] containsString:@"M6"] &&
+        ![[dev->mtl_device name] containsString:@"A19"] &&
+        ![[dev->mtl_device name] containsString:@"A20"]) {
+        WSP_GGML_LOG_WARN("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__);
+        dev->props.has_tensor = false;
+    }
+
+    // double-check that the tensor API compiles
+    if (dev->props.has_tensor) {
+        const char * src_tensor_f16 = "\n"
+            "#include <metal_stdlib> \n"
+            "#include <metal_tensor> \n"
+            "#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n"
+            " \n"
+            "using namespace metal; \n"
+            "using namespace mpp::tensor_ops; \n"
+            " \n"
+            "kernel void dummy_kernel( \n"
+            "    tensor<device half, dextents<int32_t, 2>> A [[buffer(0)]], \n"
+            "    tensor<device half, dextents<int32_t, 2>> B [[buffer(1)]], \n"
+            "    device float * C [[buffer(2)]], \n"
+            "    uint2 tgid [[threadgroup_position_in_grid]]) \n"
+            "{ \n"
+            "    auto tA = A.slice(0, (int)tgid.y); \n"
+            "    auto tB = B.slice((int)tgid.x, 0); \n"
+            " \n"
+            "    matmul2d< \n"
+            "        matmul2d_descriptor(8, 8, dynamic_extent), \n"
+            "        execution_simdgroups<4>> mm; \n"
+            " \n"
+            "    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
+            " \n"
+            "    auto sA = tA.slice(0, 0); \n"
+            "    auto sB = tB.slice(0, 0); \n"
+            "    mm.run(sB, sA, cT); \n"
+            " \n"
+            "    auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
+            " \n"
+            "    cT.store(tC); \n"
+            "}";
+
+        WSP_GGML_LOG_INFO("%s: testing tensor API for f16 support\n", __func__);
+        wsp_ggml_metal_library_t lib = wsp_ggml_metal_library_init_from_source(dev, src_tensor_f16, false);
+        if (lib == NULL) {
+            WSP_GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
+            dev->props.has_tensor = false;
+        } else {
+            wsp_ggml_metal_pipeline_t ppl = wsp_ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
+            if (!ppl) {
+                WSP_GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
+                dev->props.has_tensor = false;
+            }
+
+            wsp_ggml_metal_library_free(lib);
+        }
+    }
+
+    // try to compile a dummy kernel to determine if the tensor API is supported for bfloat
+    if (dev->props.has_tensor && dev->props.has_bfloat) {
+        const char * src_tensor_bf16 = "\n"
+            "#include <metal_stdlib> \n"
+            "#include <metal_tensor> \n"
+            "#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n"
+            " \n"
+            "using namespace metal; \n"
+            "using namespace mpp::tensor_ops; \n"
+            " \n"
+            "kernel void dummy_kernel( \n"
+            "    tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \n"
+            "    tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \n"
+            "    device float * C [[buffer(2)]], \n"
+            "    uint2 tgid [[threadgroup_position_in_grid]]) \n"
+            "{ \n"
+            "    auto tA = A.slice(0, (int)tgid.y); \n"
+            "    auto tB = B.slice((int)tgid.x, 0); \n"
+            " \n"
+            "    matmul2d< \n"
+            "        matmul2d_descriptor(8, 8, dynamic_extent), \n"
+            "        execution_simdgroups<4>> mm; \n"
+            " \n"
+            "    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
+            " \n"
+            "    auto sA = tA.slice(0, 0); \n"
+            "    auto sB = tB.slice(0, 0); \n"
+            "    mm.run(sB, sA, cT); \n"
+            " \n"
+            "    auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
+            " \n"
+            "    cT.store(tC); \n"
+            "}";
+
+        WSP_GGML_LOG_INFO("%s: testing tensor API for bfloat support\n", __func__);
+        wsp_ggml_metal_library_t lib = wsp_ggml_metal_library_init_from_source(dev, src_tensor_bf16, false);
+        if (lib == NULL) {
+            WSP_GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
+            dev->props.has_bfloat = false;
+        } else {
+            wsp_ggml_metal_pipeline_t ppl = wsp_ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
+            if (!ppl) {
+                WSP_GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
+                dev->props.has_bfloat = false;
+            }
+
+            wsp_ggml_metal_library_free(lib);
+        }
+    }

     dev->props.use_residency_sets = true;
 #if defined(WSP_GGML_METAL_HAS_RESIDENCY_SETS)
@@ -475,7 +677,6 @@ wsp_ggml_metal_device_t wsp_ggml_metal_device_init(void) {
 #endif

     dev->props.use_shared_buffers = dev->props.has_unified_memory;
-
     if (getenv("WSP_GGML_METAL_SHARED_BUFFERS_DISABLE") != NULL) {
         dev->props.use_shared_buffers = false;
     }
@@ -528,6 +729,7 @@ wsp_ggml_metal_device_t wsp_ggml_metal_device_init(void) {
     WSP_GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, dev->props.has_simdgroup_mm ? "true" : "false");
     WSP_GGML_LOG_INFO("%s: has unified memory = %s\n", __func__, dev->props.has_unified_memory ? "true" : "false");
     WSP_GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, dev->props.has_bfloat ? "true" : "false");
+    WSP_GGML_LOG_INFO("%s: has tensor = %s\n", __func__, dev->props.has_tensor ? "true" : "false");
     WSP_GGML_LOG_INFO("%s: use residency sets = %s\n", __func__, dev->props.use_residency_sets ? "true" : "false");
     WSP_GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, dev->props.use_shared_buffers ? "true" : "false");

@@ -652,6 +854,11 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
         case WSP_GGML_OP_SCALE:
         case WSP_GGML_OP_CONV_TRANSPOSE_1D:
             return true;
+        case WSP_GGML_OP_CONV_TRANSPOSE_2D:
+            return wsp_ggml_is_contiguous(op->src[0]) && wsp_ggml_is_contiguous(op->src[1]) &&
+                   (op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32) &&
+                   op->src[1]->type == WSP_GGML_TYPE_F32 &&
+                   op->type == WSP_GGML_TYPE_F32;
         case WSP_GGML_OP_CLAMP:
             return op->src[0]->type == WSP_GGML_TYPE_F32;
         case WSP_GGML_OP_SQR:
@@ -660,7 +867,10 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
         case WSP_GGML_OP_COS:
         case WSP_GGML_OP_LOG:
             return wsp_ggml_is_contiguous(op->src[0]) && op->src[0]->type == WSP_GGML_TYPE_F32;
+        case WSP_GGML_OP_SUM:
+            return has_simdgroup_reduction && wsp_ggml_is_contiguous(op->src[0]);
         case WSP_GGML_OP_SUM_ROWS:
+        case WSP_GGML_OP_CUMSUM:
         case WSP_GGML_OP_MEAN:
         case WSP_GGML_OP_SOFT_MAX:
         case WSP_GGML_OP_GROUP_NORM:
@@ -676,6 +886,11 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
             return true;
         case WSP_GGML_OP_IM2COL:
             return wsp_ggml_is_contiguous(op->src[1]) && op->src[1]->type == WSP_GGML_TYPE_F32 && (op->type == WSP_GGML_TYPE_F16 || op->type == WSP_GGML_TYPE_F32);
+        case WSP_GGML_OP_CONV_2D:
+            return wsp_ggml_is_contiguous(op->src[0]) &&
+                   op->src[1]->type == WSP_GGML_TYPE_F32 &&
+                   op->type == WSP_GGML_TYPE_F32 &&
+                   (op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32);
         case WSP_GGML_OP_POOL_1D:
             return false;
         case WSP_GGML_OP_UPSCALE:
@@ -690,14 +905,14 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
         case WSP_GGML_OP_LEAKY_RELU:
             return op->src[0]->type == WSP_GGML_TYPE_F32;
         case WSP_GGML_OP_ARGSORT:
-            // TODO: Support arbitrary column width
-            return op->src[0]->ne[0] <= 1024;
         case WSP_GGML_OP_ARANGE:
             return true;
         case WSP_GGML_OP_FLASH_ATTN_EXT:
             // for new head sizes, add checks here
-            if (op->src[0]->ne[0] !=
+            if (op->src[0]->ne[0] != 32 &&
+                op->src[0]->ne[0] != 40 &&
                 op->src[0]->ne[0] != 64 &&
+                op->src[0]->ne[0] != 72 &&
                 op->src[0]->ne[0] != 80 &&
                 op->src[0]->ne[0] != 96 &&
                 op->src[0]->ne[0] != 112 &&
@@ -774,15 +989,13 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
                        return false;
                }
                case WSP_GGML_TYPE_I32:
-                   return op->type == WSP_GGML_TYPE_F32;
+                   return op->type == WSP_GGML_TYPE_F32 || op->type == WSP_GGML_TYPE_I32;
                default:
                    return false;
            };
        }
        case WSP_GGML_OP_GET_ROWS:
-
-            return op->ne[3] == 1;
-        }
+            return true;
        case WSP_GGML_OP_SET_ROWS:
        {
            if (op->src[0]->type != WSP_GGML_TYPE_F32) {
@@ -804,6 +1017,9 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
                    return false;
            };
        }
+        case WSP_GGML_OP_OPT_STEP_ADAMW:
+        case WSP_GGML_OP_OPT_STEP_SGD:
+            return has_simdgroup_reduction;
        default:
            return false;
    }
@@ -828,7 +1044,7 @@ struct wsp_ggml_metal_buffer_wrapper {
 };

 struct wsp_ggml_metal_buffer {
-    void * all_data;
+    void * all_data;
     size_t all_size;

     // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
@@ -966,14 +1182,15 @@ wsp_ggml_metal_buffer_t wsp_ggml_metal_buffer_init(wsp_ggml_metal_device_t dev,
     if (shared) {
         res->all_data = wsp_ggml_metal_host_malloc(size_aligned);
         res->is_shared = true;
-        res->owned = true;
     } else {
-        //
-        res->all_data = (void *)
+        // use virtual address from g_addr_device counter
+        res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed);
         res->is_shared = false;
     }
     res->all_size = size_aligned;

+    res->owned = true;
+
     res->device = wsp_ggml_metal_device_get_obj(dev);
     res->queue = wsp_ggml_metal_device_get_queue(dev);

@@ -984,15 +1201,13 @@ wsp_ggml_metal_buffer_t wsp_ggml_metal_buffer_init(wsp_ggml_metal_device_t dev,
     res->buffers[0].metal = nil;

     if (size_aligned > 0) {
-        if (props_dev->use_shared_buffers &&shared) {
+        if (props_dev->use_shared_buffers && shared) {
             res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data
                                                                    length:size_aligned
                                                                   options:MTLResourceStorageModeShared
                                                               deallocator:nil];
         } else {
             res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
-
-            res->all_data = (void *) (res->buffers[0].metal.gpuAddress);
         }
     }

@@ -1140,7 +1355,7 @@ bool wsp_ggml_metal_buffer_is_shared(wsp_ggml_metal_buffer_t buf) {

 void wsp_ggml_metal_buffer_memset_tensor(wsp_ggml_metal_buffer_t buf, struct wsp_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
     if (buf->is_shared) {
-        memset((char *)tensor->data + offset, value, size);
+        memset((char *) tensor->data + offset, value, size);
         return;
     }

@@ -1169,7 +1384,7 @@ void wsp_ggml_metal_buffer_memset_tensor(wsp_ggml_metal_buffer_t buf, struct wsp

 void wsp_ggml_metal_buffer_set_tensor(wsp_ggml_metal_buffer_t buf, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     if (buf->is_shared) {
-        memcpy((char *)tensor->data + offset, data, size);
+        memcpy((char *) tensor->data + offset, data, size);
         return;
     }

@@ -1224,7 +1439,7 @@ void wsp_ggml_metal_buffer_set_tensor(wsp_ggml_metal_buffer_t buf, struct wsp_gg

 void wsp_ggml_metal_buffer_get_tensor(wsp_ggml_metal_buffer_t buf, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
     if (buf->is_shared) {
-        memcpy(data, (const char *)tensor->data + offset, size);
+        memcpy(data, (const char *) tensor->data + offset, size);
         return;
     }

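Note on the new runtime switches: the device-init hunks above gate the Metal tensor and bfloat paths behind getenv() checks (WSP_GGML_METAL_TENSOR_DISABLE, WSP_GGML_METAL_TENSOR_ENABLE, WSP_GGML_METAL_BF16_DISABLE, WSP_GGML_METAL_SHARED_BUFFERS_DISABLE). The C sketch below shows how a host process could set these overrides before the Metal device is initialized; it is only an illustration under that assumption, and my_init_whisper() is a hypothetical placeholder, not a whisper.rn API.

// Hedged sketch: opting out of the new Metal tensor / bfloat probes via the
// environment variables checked in wsp_ggml_metal_device_init() above.
#include <stdio.h>
#include <stdlib.h>

static void my_init_whisper(void) {
    // stand-in for the real initialization (e.g. creating a whisper context)
    printf("initializing whisper with Metal overrides applied\n");
}

int main(void) {
    // Skip the Metal 4 tensor API entirely.
    setenv("WSP_GGML_METAL_TENSOR_DISABLE", "1", 1);

    // Or force the probe on devices the heuristic skips (pre-M5 / pre-A19):
    // setenv("WSP_GGML_METAL_TENSOR_ENABLE", "1", 1);

    // Other overrides visible in this diff:
    // setenv("WSP_GGML_METAL_BF16_DISABLE", "1", 1);
    // setenv("WSP_GGML_METAL_SHARED_BUFFERS_DISABLE", "1", 1);

    my_init_whisper();
    return 0;
}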