whisper.rn 0.5.2 → 0.5.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/cpp/ggml-alloc.c +11 -4
- package/cpp/ggml-backend-reg.cpp +8 -0
- package/cpp/ggml-backend.cpp +0 -2
- package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
- package/cpp/ggml-cpu/ggml-cpu-impl.h +3 -1
- package/cpp/ggml-cpu/ggml-cpu.c +50 -21
- package/cpp/ggml-cpu/ops.cpp +458 -349
- package/cpp/ggml-cpu/ops.h +4 -4
- package/cpp/ggml-cpu/repack.cpp +143 -29
- package/cpp/ggml-cpu/simd-mappings.h +25 -25
- package/cpp/ggml-cpu/unary-ops.cpp +16 -0
- package/cpp/ggml-cpu/unary-ops.h +2 -0
- package/cpp/ggml-cpu/vec.cpp +17 -0
- package/cpp/ggml-cpu/vec.h +10 -0
- package/cpp/ggml-impl.h +17 -1
- package/cpp/ggml-metal/ggml-metal-context.m +5 -6
- package/cpp/ggml-metal/ggml-metal-device.cpp +101 -4
- package/cpp/ggml-metal/ggml-metal-device.h +8 -1
- package/cpp/ggml-metal/ggml-metal-device.m +216 -14
- package/cpp/ggml-metal/ggml-metal-impl.h +90 -2
- package/cpp/ggml-metal/ggml-metal-ops.cpp +346 -85
- package/cpp/ggml-metal/ggml-metal-ops.h +2 -0
- package/cpp/ggml-metal/ggml-metal.cpp +5 -0
- package/cpp/ggml-metal/ggml-metal.metal +12436 -0
- package/cpp/ggml.c +154 -5
- package/cpp/ggml.h +73 -0
- package/cpp/whisper.cpp +6 -2
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +17 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +73 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +17 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +73 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +17 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +73 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +17 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +73 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +156 -12
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/module/realtime-transcription/RealtimeTranscriber.js +155 -12
- package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts +29 -0
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
- package/lib/typescript/realtime-transcription/types.d.ts +7 -0
- package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/realtime-transcription/RealtimeTranscriber.ts +179 -9
- package/src/realtime-transcription/types.ts +9 -0
- package/whisper-rn.podspec +1 -1
- package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
package/cpp/ggml-metal/ggml-metal-device.cpp

@@ -318,6 +318,44 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum_rows(wsp_ggml_
     return res;
 }
 
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_blk(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
+    WSP_GGML_ASSERT(op->op == WSP_GGML_OP_CUMSUM);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_cumsum_blk_%s", wsp_ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s", base);
+
+    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_add(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
+    WSP_GGML_ASSERT(op->op == WSP_GGML_OP_CUMSUM);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_cumsum_add_%s", wsp_ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s", base);
+
+    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_soft_max(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
     WSP_GGML_ASSERT(!op->src[1] || op->src[1]->type == WSP_GGML_TYPE_F16 || op->src[1]->type == WSP_GGML_TYPE_F32);
 
@@ -677,7 +715,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mm_id_map0(wsp
     char name[256];
 
     snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
-    snprintf(name, 256, "%
+    snprintf(name, 256, "%s_ne02=%d", base, ne02);
 
     wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
     if (res) {
@@ -943,6 +981,34 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort(wsp_ggml_m
     return res;
 }
 
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort_merge(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
+    assert(op->op == WSP_GGML_OP_ARGSORT);
+
+    char base[256];
+    char name[256];
+
+    wsp_ggml_sort_order order = (wsp_ggml_sort_order) op->op_params[0];
+
+    const char * order_str = "undefined";
+    switch (order) {
+        case WSP_GGML_SORT_ORDER_ASC:  order_str = "asc";  break;
+        case WSP_GGML_SORT_ORDER_DESC: order_str = "desc"; break;
+        default: WSP_GGML_ABORT("fatal error");
+    };
+
+    snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->type), order_str);
+    snprintf(name, 256, "%s", base);
+
+    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(
         wsp_ggml_metal_library_t lib,
         const struct wsp_ggml_tensor * op,
@@ -1332,11 +1398,12 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rope(wsp_ggml_meta
 
     const bool is_neox   = mode & WSP_GGML_ROPE_TYPE_NEOX;
     const bool is_mrope  = mode & WSP_GGML_ROPE_TYPE_MROPE;
+    const bool is_imrope = mode == WSP_GGML_ROPE_TYPE_IMROPE;
     const bool is_vision = mode == WSP_GGML_ROPE_TYPE_VISION;
 
     if (is_neox) {
         snprintf(base, 256, "kernel_rope_neox_%s", wsp_ggml_type_name(op->src[0]->type));
-    } else if (is_mrope && !is_vision) {
+    } else if ((is_mrope || is_imrope) && !is_vision) {
         WSP_GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
         snprintf(base, 256, "kernel_rope_multi_%s", wsp_ggml_type_name(op->src[0]->type));
     } else if (is_vision) {
@@ -1346,14 +1413,20 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rope(wsp_ggml_meta
         snprintf(base, 256, "kernel_rope_norm_%s", wsp_ggml_type_name(op->src[0]->type));
     }
 
-    snprintf(name, 256, "%
+    snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
 
     wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (res) {
        return res;
    }
 
-
+    wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
+
+    wsp_ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
+
+    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+    wsp_ggml_metal_cv_free(cv);
 
     return res;
 }
@@ -1431,6 +1504,30 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(
     return res;
 }
 
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
+    assert(op->op == WSP_GGML_OP_CONV_2D);
+
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
+    WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_F32);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_conv_2d_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type));
+    snprintf(name, 256, "%s", base);
+
+    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_upscale(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
     assert(op->op == WSP_GGML_OP_UPSCALE);
 
package/cpp/ggml-metal/ggml-metal-device.h

@@ -95,7 +95,9 @@ void wsp_ggml_metal_encoder_end_encoding(wsp_ggml_metal_encoder_t encoder);
 
 typedef struct wsp_ggml_metal_library * wsp_ggml_metal_library_t;
 
-wsp_ggml_metal_library_t wsp_ggml_metal_library_init(wsp_ggml_metal_device_t dev);
+wsp_ggml_metal_library_t wsp_ggml_metal_library_init            (wsp_ggml_metal_device_t dev);
+wsp_ggml_metal_library_t wsp_ggml_metal_library_init_from_source(wsp_ggml_metal_device_t dev, const char * source, bool verbose);
+
 void wsp_ggml_metal_library_free(wsp_ggml_metal_library_t lib);
 
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline (wsp_ggml_metal_library_t lib, const char * name);
@@ -111,6 +113,8 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_unary
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_glu (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum_rows (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_blk (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_add (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_soft_max (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_conv (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_scan (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
@@ -123,6 +127,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mm_id
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mv_id (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argmax (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort_merge (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_bin (wsp_ggml_metal_library_t lib, enum wsp_ggml_op op, int32_t n_fuse, bool row);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_l2_norm (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_group_norm (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
@@ -131,6 +136,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rope
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_im2col (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_1d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_2d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_2d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_upscale (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad_reflect_1d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
@@ -193,6 +199,7 @@ struct wsp_ggml_metal_device_props {
     bool has_simdgroup_mm;
     bool has_unified_memory;
     bool has_bfloat;
+    bool has_tensor;
     bool use_residency_sets;
     bool use_shared_buffers;
 
package/cpp/ggml-metal/ggml-metal-device.m

@@ -21,8 +21,9 @@
 #define WSP_GGML_METAL_HAS_RESIDENCY_SETS 1
 #endif
 
-// overload of
+// overload of MTLGPUFamilyMetalX (not available in some environments)
 static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
+static const NSInteger MTLGPUFamilyMetal4_GGML = 5002;
 
 // virtual address for GPU memory allocations
 static atomic_uintptr_t g_addr_device = 0x000000400ULL;
@@ -180,11 +181,7 @@ wsp_ggml_metal_library_t wsp_ggml_metal_library_init(wsp_ggml_metal_device_t dev
     NSBundle * bundle = [NSBundle bundleForClass:[WSPGGMLMetalClass class]];
 #endif
 
-
-    NSString * path_lib = [bundle pathForResource:@"ggml-whisper-sim" ofType:@"metallib"];
-#else
-    NSString * path_lib = [bundle pathForResource:@"ggml-whisper" ofType:@"metallib"];
-#endif
+    NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
     if (path_lib == nil) {
         // Try to find the resource in the directory where the current binary located.
         NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
@@ -265,6 +262,10 @@ wsp_ggml_metal_library_t wsp_ggml_metal_library_init(wsp_ggml_metal_device_t dev
         [prep setObject:@"1" forKey:@"WSP_GGML_METAL_HAS_BF16"];
     }
 
+    if (wsp_ggml_metal_device_get_props(dev)->has_tensor) {
+        [prep setObject:@"1" forKey:@"WSP_GGML_METAL_HAS_TENSOR"];
+    }
+
 #if WSP_GGML_METAL_EMBED_LIBRARY
     [prep setObject:@"1" forKey:@"WSP_GGML_METAL_EMBED_LIBRARY"];
 #endif
@@ -302,6 +303,72 @@ wsp_ggml_metal_library_t wsp_ggml_metal_library_init(wsp_ggml_metal_device_t dev
     return res;
 }
 
+wsp_ggml_metal_library_t wsp_ggml_metal_library_init_from_source(wsp_ggml_metal_device_t dev, const char * source, bool verbose) {
+    if (source == NULL) {
+        WSP_GGML_LOG_ERROR("%s: source is NULL\n", __func__);
+        return NULL;
+    }
+
+    id<MTLDevice> device = wsp_ggml_metal_device_get_obj(dev);
+    id<MTLLibrary> library = nil;
+    NSError * error = nil;
+
+    const int64_t t_start = wsp_ggml_time_us();
+
+    NSString * src = [[NSString alloc] initWithBytes:source
+                                              length:strlen(source)
+                                            encoding:NSUTF8StringEncoding];
+    if (!src) {
+        WSP_GGML_LOG_ERROR("%s: failed to create NSString from source\n", __func__);
+        return NULL;
+    }
+
+    @autoreleasepool {
+        NSMutableDictionary * prep = [NSMutableDictionary dictionary];
+
+        MTLCompileOptions * options = [MTLCompileOptions new];
+        options.preprocessorMacros = prep;
+
+        library = [device newLibraryWithSource:src options:options error:&error];
+        if (error) {
+            if (verbose) {
+                WSP_GGML_LOG_ERROR("%s: error compiling source: %s\n", __func__, [[error description] UTF8String]);
+            } else {
+                WSP_GGML_LOG_ERROR("%s: error compiling source\n", __func__);
+            }
+            library = nil;
+        }
+
+        [options release];
+    }
+
+    [src release];
+
+    if (!library) {
+        if (verbose) {
+            WSP_GGML_LOG_ERROR("%s: failed to create Metal library from source\n", __func__);
+        }
+
+        return NULL;
+    }
+
+    if (verbose) {
+        WSP_GGML_LOG_INFO("%s: compiled in %.3f sec\n", __func__, (wsp_ggml_time_us() - t_start) / 1e6);
+    }
+
+    wsp_ggml_metal_library_t res = calloc(1, sizeof(struct wsp_ggml_metal_library));
+    if (!res) {
+        WSP_GGML_LOG_ERROR("%s: calloc failed\n", __func__);
+        return NULL;
+    }
+
+    res->obj = library;
+    res->device = device;
+    res->pipelines = wsp_ggml_metal_pipelines_init();
+
+    return res;
+}
+
 void wsp_ggml_metal_library_free(wsp_ggml_metal_library_t lib) {
     if (!lib) {
         return;
@@ -349,9 +416,9 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_compile_pipeline(wsp_ggml_metal
     if (!mtl_function) {
         wsp_ggml_critical_section_end();
 
-        WSP_GGML_LOG_ERROR("%s:
+        WSP_GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
         if (error) {
-            WSP_GGML_LOG_ERROR("%s:
+            WSP_GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
         }
 
         return nil;
@@ -359,13 +426,21 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_compile_pipeline(wsp_ggml_metal
 
         res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
 
-        wsp_ggml_metal_pipelines_add(lib->pipelines, name, res);
-
         [mtl_function release];
 
         WSP_GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj,
                 (int) res->obj.maxTotalThreadsPerThreadgroup,
                 (int) res->obj.threadExecutionWidth);
+
+        if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) {
+            wsp_ggml_critical_section_end();
+
+            WSP_GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
+
+            return nil;
+        }
+
+        wsp_ggml_metal_pipelines_add(lib->pipelines, name, res);
     }
 
     wsp_ggml_critical_section_end();
@@ -473,6 +548,128 @@ wsp_ggml_metal_device_t wsp_ggml_metal_device_init(void) {
 
     dev->props.has_bfloat = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
     dev->props.has_bfloat |= [dev->mtl_device supportsFamily:MTLGPUFamilyApple6];
+    if (getenv("WSP_GGML_METAL_BF16_DISABLE") != NULL) {
+        dev->props.has_bfloat = false;
+    }
+
+    dev->props.has_tensor = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal4_GGML];
+    if (getenv("WSP_GGML_METAL_TENSOR_DISABLE") != NULL) {
+        dev->props.has_tensor = false;
+    }
+
+    // note: disable the tensor API by default for old chips because with the current implementation it is not useful
+    // - M2 Ultra: ~5% slower
+    // - M4, M4 Max: no significant difference
+    //
+    // TODO: try to update the tensor API kernels to at least match the simdgroup performance
+    if (getenv("WSP_GGML_METAL_TENSOR_ENABLE") == NULL &&
+        ![[dev->mtl_device name] containsString:@"M5"] &&
+        ![[dev->mtl_device name] containsString:@"M6"] &&
+        ![[dev->mtl_device name] containsString:@"A19"] &&
+        ![[dev->mtl_device name] containsString:@"A20"]) {
+        WSP_GGML_LOG_WARN("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__);
+        dev->props.has_tensor = false;
+    }
+
+    // double-check that the tensor API compiles
+    if (dev->props.has_tensor) {
+        const char * src_tensor_f16 = "\n"
+            "#include <metal_stdlib> \n"
+            "#include <metal_tensor> \n"
+            "#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n"
+            " \n"
+            "using namespace metal; \n"
+            "using namespace mpp::tensor_ops; \n"
+            " \n"
+            "kernel void dummy_kernel( \n"
+            "    tensor<device half, dextents<int32_t, 2>> A [[buffer(0)]], \n"
+            "    tensor<device half, dextents<int32_t, 2>> B [[buffer(1)]], \n"
+            "    device float * C [[buffer(2)]], \n"
+            "    uint2 tgid [[threadgroup_position_in_grid]]) \n"
+            "{ \n"
+            "    auto tA = A.slice(0, (int)tgid.y); \n"
+            "    auto tB = B.slice((int)tgid.x, 0); \n"
+            " \n"
+            "    matmul2d< \n"
+            "        matmul2d_descriptor(8, 8, dynamic_extent), \n"
+            "        execution_simdgroups<4>> mm; \n"
+            " \n"
+            "    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
+            " \n"
+            "    auto sA = tA.slice(0, 0); \n"
+            "    auto sB = tB.slice(0, 0); \n"
+            "    mm.run(sB, sA, cT); \n"
+            " \n"
+            "    auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
+            " \n"
+            "    cT.store(tC); \n"
+            "}";
+
+        WSP_GGML_LOG_INFO("%s: testing tensor API for f16 support\n", __func__);
+        wsp_ggml_metal_library_t lib = wsp_ggml_metal_library_init_from_source(dev, src_tensor_f16, false);
+        if (lib == NULL) {
+            WSP_GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
+            dev->props.has_tensor = false;
+        } else {
+            wsp_ggml_metal_pipeline_t ppl = wsp_ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
+            if (!ppl) {
+                WSP_GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
+                dev->props.has_tensor = false;
+            }
+
+            wsp_ggml_metal_library_free(lib);
+        }
+    }
+
+    // try to compile a dummy kernel to determine if the tensor API is supported for bfloat
+    if (dev->props.has_tensor && dev->props.has_bfloat) {
+        const char * src_tensor_bf16 = "\n"
+            "#include <metal_stdlib> \n"
+            "#include <metal_tensor> \n"
+            "#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n"
+            " \n"
+            "using namespace metal; \n"
+            "using namespace mpp::tensor_ops; \n"
+            " \n"
+            "kernel void dummy_kernel( \n"
+            "    tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \n"
+            "    tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \n"
+            "    device float * C [[buffer(2)]], \n"
+            "    uint2 tgid [[threadgroup_position_in_grid]]) \n"
+            "{ \n"
+            "    auto tA = A.slice(0, (int)tgid.y); \n"
+            "    auto tB = B.slice((int)tgid.x, 0); \n"
+            " \n"
+            "    matmul2d< \n"
+            "        matmul2d_descriptor(8, 8, dynamic_extent), \n"
+            "        execution_simdgroups<4>> mm; \n"
+            " \n"
+            "    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
+            " \n"
+            "    auto sA = tA.slice(0, 0); \n"
+            "    auto sB = tB.slice(0, 0); \n"
+            "    mm.run(sB, sA, cT); \n"
+            " \n"
+            "    auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
+            " \n"
+            "    cT.store(tC); \n"
+            "}";
+
+        WSP_GGML_LOG_INFO("%s: testing tensor API for bfloat support\n", __func__);
+        wsp_ggml_metal_library_t lib = wsp_ggml_metal_library_init_from_source(dev, src_tensor_bf16, false);
+        if (lib == NULL) {
+            WSP_GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
+            dev->props.has_bfloat = false;
+        } else {
+            wsp_ggml_metal_pipeline_t ppl = wsp_ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
+            if (!ppl) {
+                WSP_GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
+                dev->props.has_bfloat = false;
+            }
+
+            wsp_ggml_metal_library_free(lib);
+        }
+    }
 
     dev->props.use_residency_sets = true;
 #if defined(WSP_GGML_METAL_HAS_RESIDENCY_SETS)
@@ -480,7 +677,6 @@ wsp_ggml_metal_device_t wsp_ggml_metal_device_init(void) {
 #endif
 
     dev->props.use_shared_buffers = dev->props.has_unified_memory;
-
     if (getenv("WSP_GGML_METAL_SHARED_BUFFERS_DISABLE") != NULL) {
         dev->props.use_shared_buffers = false;
     }
@@ -533,6 +729,7 @@ wsp_ggml_metal_device_t wsp_ggml_metal_device_init(void) {
     WSP_GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, dev->props.has_simdgroup_mm ? "true" : "false");
     WSP_GGML_LOG_INFO("%s: has unified memory = %s\n", __func__, dev->props.has_unified_memory ? "true" : "false");
     WSP_GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, dev->props.has_bfloat ? "true" : "false");
+    WSP_GGML_LOG_INFO("%s: has tensor = %s\n", __func__, dev->props.has_tensor ? "true" : "false");
     WSP_GGML_LOG_INFO("%s: use residency sets = %s\n", __func__, dev->props.use_residency_sets ? "true" : "false");
     WSP_GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, dev->props.use_shared_buffers ? "true" : "false");
 
@@ -673,6 +870,7 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
         case WSP_GGML_OP_SUM:
            return has_simdgroup_reduction && wsp_ggml_is_contiguous(op->src[0]);
        case WSP_GGML_OP_SUM_ROWS:
+        case WSP_GGML_OP_CUMSUM:
        case WSP_GGML_OP_MEAN:
        case WSP_GGML_OP_SOFT_MAX:
        case WSP_GGML_OP_GROUP_NORM:
@@ -688,6 +886,11 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
            return true;
        case WSP_GGML_OP_IM2COL:
            return wsp_ggml_is_contiguous(op->src[1]) && op->src[1]->type == WSP_GGML_TYPE_F32 && (op->type == WSP_GGML_TYPE_F16 || op->type == WSP_GGML_TYPE_F32);
+        case WSP_GGML_OP_CONV_2D:
+            return wsp_ggml_is_contiguous(op->src[0]) &&
+                   op->src[1]->type == WSP_GGML_TYPE_F32 &&
+                   op->type == WSP_GGML_TYPE_F32 &&
+                   (op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32);
        case WSP_GGML_OP_POOL_1D:
            return false;
        case WSP_GGML_OP_UPSCALE:
@@ -702,8 +905,6 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
        case WSP_GGML_OP_LEAKY_RELU:
            return op->src[0]->type == WSP_GGML_TYPE_F32;
        case WSP_GGML_OP_ARGSORT:
-            // TODO: Support arbitrary column width
-            return op->src[0]->ne[0] <= 1024;
        case WSP_GGML_OP_ARANGE:
            return true;
        case WSP_GGML_OP_FLASH_ATTN_EXT:
@@ -711,6 +912,7 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
            if (op->src[0]->ne[0] != 32 &&
                op->src[0]->ne[0] != 40 &&
                op->src[0]->ne[0] != 64 &&
+                op->src[0]->ne[0] != 72 &&
                op->src[0]->ne[0] != 80 &&
                op->src[0]->ne[0] != 96 &&
                op->src[0]->ne[0] != 112 &&
@@ -787,7 +989,7 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
                    return false;
                }
            case WSP_GGML_TYPE_I32:
-                return op->type == WSP_GGML_TYPE_F32;
+                return op->type == WSP_GGML_TYPE_F32 || op->type == WSP_GGML_TYPE_I32;
            default:
                return false;
        };
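
Note on the ggml-metal-device.m changes above: the new wsp_ggml_metal_library_init_from_source entry point compiles a Metal library from an in-memory string, and wsp_ggml_metal_device_init now uses it to check whether the Metal 4 tensor API (and its bfloat variant) actually compiles before advertising has_tensor / has_bfloat. A minimal sketch of that probe pattern, using only functions that appear in this diff; the kernel source, its name, and the dev handle below are placeholders rather than code from the package:

    /* Hypothetical probe sketch (C); wsp_ggml_* names are taken from this diff. */
    const char * src =
        "#include <metal_stdlib>\n"
        "using namespace metal;\n"
        "kernel void dummy_kernel(device float * dst [[buffer(0)]]) { dst[0] = 1.0f; }\n";

    bool supported = false;
    wsp_ggml_metal_library_t lib = wsp_ggml_metal_library_init_from_source(dev, src, /*verbose =*/ false);
    if (lib != NULL) {
        /* pipelines are compiled by kernel name and cached inside the library */
        wsp_ggml_metal_pipeline_t ppl = wsp_ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", NULL);
        supported = (ppl != NULL);
        wsp_ggml_metal_library_free(lib);
    }

When either step fails, the corresponding device property (for example dev->props.has_tensor) is cleared, so the backend keeps using the existing simdgroup paths.
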
package/cpp/ggml-metal/ggml-metal-impl.h

@@ -76,6 +76,7 @@
 #define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
 #define FC_MUL_MV 600
 #define FC_MUL_MM 700
+#define FC_ROPE 800
 
 // op-specific constants
 #define OP_FLASH_ATTN_EXT_NQPTG 8
@@ -527,6 +528,36 @@ typedef struct {
     uint64_t nb2;
 } wsp_ggml_metal_kargs_conv_transpose_2d;
 
+typedef struct {
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t IW;
+    int32_t IH;
+    int32_t KW;
+    int32_t KH;
+    int32_t IC;
+    int32_t OC;
+    int32_t OW;
+    int32_t OH;
+    int32_t N;
+    int32_t s0;
+    int32_t s1;
+    int32_t p0;
+    int32_t p1;
+    int32_t d0;
+    int32_t d1;
+} wsp_ggml_metal_kargs_conv_2d;
+
 typedef struct {
     uint64_t ofs0;
     uint64_t ofs1;
@@ -581,6 +612,45 @@ typedef struct {
     uint64_t nb3;
 } wsp_ggml_metal_kargs_sum_rows;
 
+typedef struct {
+    int64_t ne00;
+    int64_t ne01;
+    int64_t ne02;
+    int64_t ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t net0;
+    int64_t net1;
+    int64_t net2;
+    int64_t net3;
+    uint64_t nbt0;
+    uint64_t nbt1;
+    uint64_t nbt2;
+    uint64_t nbt3;
+    bool outb;
+} wsp_ggml_metal_kargs_cumsum_blk;
+
+typedef struct {
+    int64_t ne00;
+    int64_t ne01;
+    int64_t ne02;
+    int64_t ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t net0;
+    int64_t net1;
+    int64_t net2;
+    int64_t net3;
+    uint64_t nbt0;
+    uint64_t nbt1;
+    uint64_t nbt2;
+    uint64_t nbt3;
+} wsp_ggml_metal_kargs_cumsum_add;
+
 typedef struct {
     int32_t ne00;
     int32_t ne01;
@@ -762,10 +832,28 @@ typedef struct {
 } wsp_ggml_metal_kargs_leaky_relu;
 
 typedef struct {
-    int64_t
-    int64_t
+    int64_t ne00;
+    int64_t ne01;
+    int64_t ne02;
+    int64_t ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
 } wsp_ggml_metal_kargs_argsort;
 
+typedef struct {
+    int64_t ne00;
+    int64_t ne01;
+    int64_t ne02;
+    int64_t ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t len;
+} wsp_ggml_metal_kargs_argsort_merge;
+
 typedef struct {
     int64_t ne0;
     float start;