whisper.rn 0.4.1 → 0.4.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/java/com/rnwhisper/RNWhisper.java +24 -18
- package/android/src/main/java/com/rnwhisper/WhisperVadContext.java +1 -57
- package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
- package/cpp/ggml-backend.cpp +36 -18
- package/cpp/ggml-backend.h +1 -1
- package/cpp/ggml-cpu/amx/mmq.cpp +10 -9
- package/cpp/ggml-cpu/arch/arm/quants.c +109 -108
- package/cpp/ggml-cpu/arch/arm/repack.cpp +13 -12
- package/cpp/ggml-cpu/arch/x86/quants.c +83 -82
- package/cpp/ggml-cpu/arch/x86/repack.cpp +20 -19
- package/cpp/ggml-cpu/common.h +3 -2
- package/cpp/ggml-cpu/ggml-cpu-impl.h +9 -3
- package/cpp/ggml-cpu/ggml-cpu.c +95 -17
- package/cpp/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/ggml-cpu/ops.cpp +775 -74
- package/cpp/ggml-cpu/ops.h +7 -0
- package/cpp/ggml-cpu/quants.c +25 -24
- package/cpp/ggml-cpu/repack.cpp +15 -14
- package/cpp/ggml-cpu/simd-mappings.h +211 -33
- package/cpp/ggml-cpu/vec.cpp +26 -2
- package/cpp/ggml-cpu/vec.h +99 -45
- package/cpp/ggml-cpu.h +2 -0
- package/cpp/ggml-impl.h +125 -183
- package/cpp/ggml-metal-impl.h +27 -0
- package/cpp/ggml-metal.m +298 -41
- package/cpp/ggml-quants.c +6 -6
- package/cpp/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +269 -40
- package/cpp/ggml.h +122 -2
- package/cpp/gguf.cpp +5 -1
- package/cpp/whisper.cpp +4 -0
- package/cpp/whisper.h +2 -0
- package/ios/RNWhisper.mm +35 -38
- package/ios/RNWhisperVadContext.h +1 -1
- package/ios/RNWhisperVadContext.mm +2 -6
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +125 -183
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +122 -2
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +125 -183
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +122 -2
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +125 -183
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +122 -2
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +125 -183
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +122 -2
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/package.json +1 -1
package/cpp/ggml-metal.m
CHANGED
|
@@ -48,22 +48,28 @@ static struct wsp_ggml_backend_metal_device_context {
|
|
|
48
48
|
int mtl_device_ref_count;
|
|
49
49
|
id<MTLLibrary> mtl_library;
|
|
50
50
|
|
|
51
|
+
NSLock * mtl_lock;
|
|
52
|
+
|
|
51
53
|
bool has_simdgroup_reduction;
|
|
52
54
|
bool has_simdgroup_mm;
|
|
53
55
|
bool has_residency_sets;
|
|
54
56
|
bool has_bfloat;
|
|
55
57
|
bool use_bfloat;
|
|
56
58
|
|
|
59
|
+
size_t max_size;
|
|
60
|
+
|
|
57
61
|
char name[128];
|
|
58
62
|
} g_wsp_ggml_ctx_dev_main = {
|
|
59
63
|
/*.mtl_device =*/ nil,
|
|
60
64
|
/*.mtl_device_ref_count =*/ 0,
|
|
61
65
|
/*.mtl_library =*/ nil,
|
|
66
|
+
/*.mtl_lock =*/ nil,
|
|
62
67
|
/*.has_simdgroup_reduction =*/ false,
|
|
63
68
|
/*.has_simdgroup_mm =*/ false,
|
|
64
69
|
/*.has_residency_sets =*/ false,
|
|
65
70
|
/*.has_bfloat =*/ false,
|
|
66
71
|
/*.use_bfloat =*/ false,
|
|
72
|
+
/*.max_size =*/ 0,
|
|
67
73
|
/*.name =*/ "",
|
|
68
74
|
};
|
|
69
75
|
|
|
@@ -71,6 +77,10 @@ static struct wsp_ggml_backend_metal_device_context {
|
|
|
71
77
|
static id<MTLDevice> wsp_ggml_backend_metal_device_acq(struct wsp_ggml_backend_metal_device_context * ctx) {
|
|
72
78
|
assert(ctx != NULL);
|
|
73
79
|
|
|
80
|
+
if (ctx->mtl_lock == nil) {
|
|
81
|
+
ctx->mtl_lock = [[NSLock alloc] init];
|
|
82
|
+
}
|
|
83
|
+
|
|
74
84
|
if (ctx->mtl_device == nil) {
|
|
75
85
|
ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
|
76
86
|
}
|
|
@@ -94,6 +104,8 @@ static id<MTLDevice> wsp_ggml_backend_metal_device_acq(struct wsp_ggml_backend_m
|
|
|
94
104
|
ctx->use_bfloat = false;
|
|
95
105
|
#endif
|
|
96
106
|
|
|
107
|
+
ctx->max_size = ctx->mtl_device.maxBufferLength;
|
|
108
|
+
|
|
97
109
|
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
|
|
98
110
|
}
|
|
99
111
|
|
|
@@ -110,6 +122,11 @@ static void wsp_ggml_backend_metal_device_rel(struct wsp_ggml_backend_metal_devi
|
|
|
110
122
|
ctx->mtl_device_ref_count--;
|
|
111
123
|
|
|
112
124
|
if (ctx->mtl_device_ref_count == 0) {
|
|
125
|
+
if (ctx->mtl_lock) {
|
|
126
|
+
[ctx->mtl_lock release];
|
|
127
|
+
ctx->mtl_lock = nil;
|
|
128
|
+
}
|
|
129
|
+
|
|
113
130
|
if (ctx->mtl_library) {
|
|
114
131
|
[ctx->mtl_library release];
|
|
115
132
|
ctx->mtl_library = nil;
|
|
@@ -185,6 +202,15 @@ enum wsp_ggml_metal_kernel_type {
|
|
|
185
202
|
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
|
186
203
|
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
|
187
204
|
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
|
205
|
+
WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
|
|
206
|
+
WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
|
|
207
|
+
WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
|
|
208
|
+
WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
|
|
209
|
+
WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
|
|
210
|
+
WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
|
|
211
|
+
WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
|
|
212
|
+
WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
|
|
213
|
+
WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
|
|
188
214
|
WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
|
189
215
|
WSP_GGML_METAL_KERNEL_TYPE_L2_NORM,
|
|
190
216
|
WSP_GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
|
@@ -194,11 +220,14 @@ enum wsp_ggml_metal_kernel_type {
|
|
|
194
220
|
WSP_GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
|
|
195
221
|
WSP_GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
|
|
196
222
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
|
223
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,
|
|
197
224
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
|
225
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,
|
|
198
226
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
|
199
227
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
|
|
200
228
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
|
201
229
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
|
|
230
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,
|
|
202
231
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
|
|
203
232
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
|
|
204
233
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
|
|
@@ -497,6 +526,9 @@ enum wsp_ggml_metal_kernel_type {
|
|
|
497
526
|
WSP_GGML_METAL_KERNEL_TYPE_SIN,
|
|
498
527
|
WSP_GGML_METAL_KERNEL_TYPE_COS,
|
|
499
528
|
WSP_GGML_METAL_KERNEL_TYPE_NEG,
|
|
529
|
+
WSP_GGML_METAL_KERNEL_TYPE_REGLU,
|
|
530
|
+
WSP_GGML_METAL_KERNEL_TYPE_GEGLU,
|
|
531
|
+
WSP_GGML_METAL_KERNEL_TYPE_SWIGLU,
|
|
500
532
|
WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
501
533
|
WSP_GGML_METAL_KERNEL_TYPE_MEAN,
|
|
502
534
|
WSP_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
|
@@ -981,7 +1013,7 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
981
1013
|
struct wsp_ggml_backend_metal_context * ctx = calloc(1, sizeof(struct wsp_ggml_backend_metal_context));
|
|
982
1014
|
struct wsp_ggml_backend_metal_device_context * ctx_dev = dev->context;
|
|
983
1015
|
|
|
984
|
-
id<MTLDevice> device =
|
|
1016
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
985
1017
|
|
|
986
1018
|
WSP_GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
|
987
1019
|
|
|
@@ -995,9 +1027,16 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
995
1027
|
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
|
996
1028
|
|
|
997
1029
|
// load library
|
|
998
|
-
|
|
999
|
-
ctx_dev->
|
|
1030
|
+
{
|
|
1031
|
+
[ctx_dev->mtl_lock lock];
|
|
1032
|
+
|
|
1033
|
+
if (ctx_dev->mtl_library == nil) {
|
|
1034
|
+
ctx_dev->mtl_library = wsp_ggml_metal_load_library(device, ctx_dev->use_bfloat);
|
|
1035
|
+
}
|
|
1036
|
+
|
|
1037
|
+
[ctx_dev->mtl_lock unlock];
|
|
1000
1038
|
}
|
|
1039
|
+
|
|
1001
1040
|
id<MTLLibrary> metal_library = ctx_dev->mtl_library;
|
|
1002
1041
|
if (metal_library == nil) {
|
|
1003
1042
|
WSP_GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
|
|
@@ -1146,6 +1185,15 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
1146
1185
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
|
1147
1186
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
|
1148
1187
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
|
1188
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
|
|
1189
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
|
|
1190
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
|
|
1191
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true);
|
|
1192
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true);
|
|
1193
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true);
|
|
1194
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
|
|
1195
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
|
|
1196
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
|
|
1149
1197
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
|
1150
1198
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
|
1151
1199
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
|
@@ -1155,11 +1203,14 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
1155
1203
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
|
|
1156
1204
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
|
|
1157
1205
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
|
1206
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true);
|
|
1158
1207
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
|
|
1208
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat);
|
|
1159
1209
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
|
|
1160
1210
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
|
|
1161
1211
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
|
|
1162
1212
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
|
|
1213
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true);
|
|
1163
1214
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
|
|
1164
1215
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
|
|
1165
1216
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
|
|
@@ -1458,6 +1509,9 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
1458
1509
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
|
1459
1510
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
|
1460
1511
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
|
1512
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
|
|
1513
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
|
|
1514
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
|
|
1461
1515
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
|
1462
1516
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
|
1463
1517
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
|
@@ -1609,6 +1663,10 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
|
|
|
1609
1663
|
const bool use_bfloat = ctx_dev->use_bfloat;
|
|
1610
1664
|
|
|
1611
1665
|
if (!use_bfloat) {
|
|
1666
|
+
if (op->type == WSP_GGML_TYPE_BF16) {
|
|
1667
|
+
return false;
|
|
1668
|
+
}
|
|
1669
|
+
|
|
1612
1670
|
for (size_t i = 0, n = 3; i < n; ++i) {
|
|
1613
1671
|
if (op->src[i] != NULL && op->src[i]->type == WSP_GGML_TYPE_BF16) {
|
|
1614
1672
|
return false;
|
|
@@ -1632,6 +1690,15 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
|
|
|
1632
1690
|
default:
|
|
1633
1691
|
return false;
|
|
1634
1692
|
}
|
|
1693
|
+
case WSP_GGML_OP_GLU:
|
|
1694
|
+
switch (wsp_ggml_get_glu_op(op)) {
|
|
1695
|
+
case WSP_GGML_GLU_OP_REGLU:
|
|
1696
|
+
case WSP_GGML_GLU_OP_GEGLU:
|
|
1697
|
+
case WSP_GGML_GLU_OP_SWIGLU:
|
|
1698
|
+
return wsp_ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == WSP_GGML_TYPE_F32;
|
|
1699
|
+
default:
|
|
1700
|
+
return false;
|
|
1701
|
+
}
|
|
1635
1702
|
case WSP_GGML_OP_NONE:
|
|
1636
1703
|
case WSP_GGML_OP_RESHAPE:
|
|
1637
1704
|
case WSP_GGML_OP_VIEW:
|
|
@@ -1778,6 +1845,27 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
|
|
|
1778
1845
|
{
|
|
1779
1846
|
return op->ne[3] == 1;
|
|
1780
1847
|
}
|
|
1848
|
+
case WSP_GGML_OP_SET_ROWS:
|
|
1849
|
+
{
|
|
1850
|
+
if (op->src[0]->type != WSP_GGML_TYPE_F32) {
|
|
1851
|
+
return false;
|
|
1852
|
+
}
|
|
1853
|
+
|
|
1854
|
+
switch (op->type) {
|
|
1855
|
+
case WSP_GGML_TYPE_F32:
|
|
1856
|
+
case WSP_GGML_TYPE_F16:
|
|
1857
|
+
case WSP_GGML_TYPE_BF16:
|
|
1858
|
+
case WSP_GGML_TYPE_Q8_0:
|
|
1859
|
+
case WSP_GGML_TYPE_Q4_0:
|
|
1860
|
+
case WSP_GGML_TYPE_Q4_1:
|
|
1861
|
+
case WSP_GGML_TYPE_Q5_0:
|
|
1862
|
+
case WSP_GGML_TYPE_Q5_1:
|
|
1863
|
+
case WSP_GGML_TYPE_IQ4_NL:
|
|
1864
|
+
return true;
|
|
1865
|
+
default:
|
|
1866
|
+
return false;
|
|
1867
|
+
};
|
|
1868
|
+
}
|
|
1781
1869
|
default:
|
|
1782
1870
|
return false;
|
|
1783
1871
|
}
|
|
@@ -2350,6 +2438,62 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
2350
2438
|
WSP_GGML_ABORT("fatal error");
|
|
2351
2439
|
}
|
|
2352
2440
|
} break;
|
|
2441
|
+
case WSP_GGML_OP_GLU:
|
|
2442
|
+
{
|
|
2443
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
|
|
2444
|
+
|
|
2445
|
+
if (src1) {
|
|
2446
|
+
WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, src1));
|
|
2447
|
+
}
|
|
2448
|
+
|
|
2449
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2450
|
+
|
|
2451
|
+
switch (wsp_ggml_get_glu_op(node)) {
|
|
2452
|
+
case WSP_GGML_GLU_OP_REGLU:
|
|
2453
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
|
|
2454
|
+
break;
|
|
2455
|
+
case WSP_GGML_GLU_OP_GEGLU:
|
|
2456
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
|
|
2457
|
+
break;
|
|
2458
|
+
case WSP_GGML_GLU_OP_SWIGLU:
|
|
2459
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
|
|
2460
|
+
break;
|
|
2461
|
+
default:
|
|
2462
|
+
WSP_GGML_ABORT("fatal error");
|
|
2463
|
+
}
|
|
2464
|
+
|
|
2465
|
+
const int32_t swp = ((const int32_t *) dst->op_params)[1];
|
|
2466
|
+
|
|
2467
|
+
const int32_t i00 = swp ? ne0 : 0;
|
|
2468
|
+
const int32_t i10 = swp ? 0 : ne0;
|
|
2469
|
+
|
|
2470
|
+
wsp_ggml_metal_kargs_glu args = {
|
|
2471
|
+
/*.ne00 =*/ ne00,
|
|
2472
|
+
/*.nb01 =*/ nb01,
|
|
2473
|
+
/*.ne10 =*/ src1 ? ne10 : ne00,
|
|
2474
|
+
/*.nb11 =*/ src1 ? nb11 : nb01,
|
|
2475
|
+
/*.ne0 =*/ ne0,
|
|
2476
|
+
/*.nb1 =*/ nb1,
|
|
2477
|
+
/*.i00 =*/ src1 ? 0 : i00,
|
|
2478
|
+
/*.i10 =*/ src1 ? 0 : i10,
|
|
2479
|
+
};
|
|
2480
|
+
|
|
2481
|
+
[encoder setComputePipelineState:pipeline];
|
|
2482
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2483
|
+
if (src1) {
|
|
2484
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
2485
|
+
} else {
|
|
2486
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
2487
|
+
}
|
|
2488
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
2489
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
|
2490
|
+
|
|
2491
|
+
const int64_t nrows = wsp_ggml_nrows(src0);
|
|
2492
|
+
|
|
2493
|
+
const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
|
|
2494
|
+
|
|
2495
|
+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2496
|
+
} break;
|
|
2353
2497
|
case WSP_GGML_OP_SQR:
|
|
2354
2498
|
{
|
|
2355
2499
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
|
|
@@ -2430,6 +2574,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
2430
2574
|
nth *= 2;
|
|
2431
2575
|
}
|
|
2432
2576
|
|
|
2577
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
2433
2578
|
nth = MIN(nth, ne00);
|
|
2434
2579
|
|
|
2435
2580
|
wsp_ggml_metal_kargs_sum_rows args = {
|
|
@@ -3090,14 +3235,23 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3090
3235
|
nsg = 1;
|
|
3091
3236
|
nr0 = 1;
|
|
3092
3237
|
nr1 = 4;
|
|
3093
|
-
|
|
3238
|
+
if (ne00 == 4) {
|
|
3239
|
+
nr0 = 32;
|
|
3240
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline;
|
|
3241
|
+
} else {
|
|
3242
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
|
|
3243
|
+
}
|
|
3094
3244
|
} break;
|
|
3095
3245
|
case WSP_GGML_TYPE_F16:
|
|
3096
3246
|
{
|
|
3097
3247
|
nsg = 1;
|
|
3098
3248
|
nr0 = 1;
|
|
3099
3249
|
if (src1t == WSP_GGML_TYPE_F32) {
|
|
3100
|
-
if (
|
|
3250
|
+
if (ne00 == 4) {
|
|
3251
|
+
nr0 = 32;
|
|
3252
|
+
nr1 = 4;
|
|
3253
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline;
|
|
3254
|
+
} else if (ne11 * ne12 < 4) {
|
|
3101
3255
|
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
|
3102
3256
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
|
3103
3257
|
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
|
@@ -3116,7 +3270,11 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3116
3270
|
nsg = 1;
|
|
3117
3271
|
nr0 = 1;
|
|
3118
3272
|
if (src1t == WSP_GGML_TYPE_F32) {
|
|
3119
|
-
if (
|
|
3273
|
+
if (ne00 == 4) {
|
|
3274
|
+
nr0 = 32;
|
|
3275
|
+
nr1 = 4;
|
|
3276
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline;
|
|
3277
|
+
} else if (ne11 * ne12 < 4) {
|
|
3120
3278
|
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
|
|
3121
3279
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
|
3122
3280
|
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
|
|
@@ -3737,13 +3895,74 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3737
3895
|
};
|
|
3738
3896
|
|
|
3739
3897
|
[encoder setComputePipelineState:pipeline];
|
|
3740
|
-
[encoder
|
|
3741
|
-
[encoder setBuffer:
|
|
3742
|
-
[encoder setBuffer:
|
|
3743
|
-
[encoder
|
|
3898
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
3899
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
3900
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
|
3901
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
3744
3902
|
|
|
3745
3903
|
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
|
3746
3904
|
} break;
|
|
3905
|
+
case WSP_GGML_OP_SET_ROWS:
|
|
3906
|
+
{
|
|
3907
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
3908
|
+
|
|
3909
|
+
switch (dst->type) {
|
|
3910
|
+
case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break;
|
|
3911
|
+
case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break;
|
|
3912
|
+
case WSP_GGML_TYPE_BF16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break;
|
|
3913
|
+
case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break;
|
|
3914
|
+
case WSP_GGML_TYPE_Q4_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break;
|
|
3915
|
+
case WSP_GGML_TYPE_Q4_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break;
|
|
3916
|
+
case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break;
|
|
3917
|
+
case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break;
|
|
3918
|
+
case WSP_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
|
|
3919
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
3920
|
+
}
|
|
3921
|
+
|
|
3922
|
+
const int32_t nk0 = ne0/wsp_ggml_blck_size(dst->type);
|
|
3923
|
+
|
|
3924
|
+
int nth = 32; // SIMD width
|
|
3925
|
+
|
|
3926
|
+
while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
3927
|
+
nth *= 2;
|
|
3928
|
+
}
|
|
3929
|
+
|
|
3930
|
+
int nrptg = 1;
|
|
3931
|
+
if (nth > nk0) {
|
|
3932
|
+
nrptg = (nth + nk0 - 1)/nk0;
|
|
3933
|
+
nth = nk0;
|
|
3934
|
+
|
|
3935
|
+
if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
3936
|
+
nrptg--;
|
|
3937
|
+
}
|
|
3938
|
+
}
|
|
3939
|
+
|
|
3940
|
+
nth = MIN(nth, nk0);
|
|
3941
|
+
|
|
3942
|
+
wsp_ggml_metal_kargs_set_rows args = {
|
|
3943
|
+
/*.nk0 =*/ nk0,
|
|
3944
|
+
/*.ne01 =*/ ne01,
|
|
3945
|
+
/*.nb01 =*/ nb01,
|
|
3946
|
+
/*.nb02 =*/ nb02,
|
|
3947
|
+
/*.nb03 =*/ nb03,
|
|
3948
|
+
/*.ne11 =*/ ne11,
|
|
3949
|
+
/*.ne12 =*/ ne12,
|
|
3950
|
+
/*.nb10 =*/ nb10,
|
|
3951
|
+
/*.nb11 =*/ nb11,
|
|
3952
|
+
/*.nb12 =*/ nb12,
|
|
3953
|
+
/*.nb1 =*/ nb1,
|
|
3954
|
+
/*.nb2 =*/ nb2,
|
|
3955
|
+
/*.nb3 =*/ nb3,
|
|
3956
|
+
};
|
|
3957
|
+
|
|
3958
|
+
[encoder setComputePipelineState:pipeline];
|
|
3959
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
3960
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
3961
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
|
3962
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
3963
|
+
|
|
3964
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
|
|
3965
|
+
} break;
|
|
3747
3966
|
case WSP_GGML_OP_RMS_NORM:
|
|
3748
3967
|
{
|
|
3749
3968
|
WSP_GGML_ASSERT(ne00 % 4 == 0);
|
|
@@ -3760,6 +3979,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3760
3979
|
nth *= 2;
|
|
3761
3980
|
}
|
|
3762
3981
|
|
|
3982
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
3763
3983
|
nth = MIN(nth, ne00/4);
|
|
3764
3984
|
|
|
3765
3985
|
wsp_ggml_metal_kargs_rms_norm args = {
|
|
@@ -3796,6 +4016,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3796
4016
|
nth *= 2;
|
|
3797
4017
|
}
|
|
3798
4018
|
|
|
4019
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
3799
4020
|
nth = MIN(nth, ne00/4);
|
|
3800
4021
|
|
|
3801
4022
|
wsp_ggml_metal_kargs_l2_norm args = {
|
|
@@ -3868,6 +4089,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3868
4089
|
nth *= 2;
|
|
3869
4090
|
}
|
|
3870
4091
|
|
|
4092
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
3871
4093
|
nth = MIN(nth, ne00/4);
|
|
3872
4094
|
|
|
3873
4095
|
wsp_ggml_metal_kargs_norm args = {
|
|
@@ -4954,8 +5176,39 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
4954
5176
|
default: WSP_GGML_ABORT("not implemented");
|
|
4955
5177
|
}
|
|
4956
5178
|
|
|
5179
|
+
WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0);
|
|
5180
|
+
|
|
5181
|
+
// TODO: support
|
|
5182
|
+
//const int32_t nk00 = ne00/wsp_ggml_blck_size(dst->type);
|
|
5183
|
+
const int32_t nk00 = ne00;
|
|
5184
|
+
|
|
5185
|
+
int nth = 32; // SIMD width
|
|
5186
|
+
|
|
5187
|
+
while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
5188
|
+
nth *= 2;
|
|
5189
|
+
}
|
|
5190
|
+
|
|
5191
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
5192
|
+
|
|
5193
|
+
// when rows are small, we can batch them together in a single threadgroup
|
|
5194
|
+
int nrptg = 1;
|
|
5195
|
+
|
|
5196
|
+
// TODO: relax this constraint in the future
|
|
5197
|
+
if (wsp_ggml_blck_size(src0->type) == 1 && wsp_ggml_blck_size(dst->type) == 1) {
|
|
5198
|
+
if (nth > nk00) {
|
|
5199
|
+
nrptg = (nth + nk00 - 1)/nk00;
|
|
5200
|
+
nth = nk00;
|
|
5201
|
+
|
|
5202
|
+
if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
5203
|
+
nrptg--;
|
|
5204
|
+
}
|
|
5205
|
+
}
|
|
5206
|
+
}
|
|
5207
|
+
|
|
5208
|
+
nth = MIN(nth, nk00);
|
|
5209
|
+
|
|
4957
5210
|
wsp_ggml_metal_kargs_cpy args = {
|
|
4958
|
-
/*.ne00 =*/
|
|
5211
|
+
/*.ne00 =*/ nk00,
|
|
4959
5212
|
/*.ne01 =*/ ne01,
|
|
4960
5213
|
/*.ne02 =*/ ne02,
|
|
4961
5214
|
/*.ne03 =*/ ne03,
|
|
@@ -4978,11 +5231,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
4978
5231
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
4979
5232
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
4980
5233
|
|
|
4981
|
-
|
|
4982
|
-
int nth = MIN(1024, ne00/wsp_ggml_blck_size(src0->type));
|
|
4983
|
-
|
|
4984
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
4985
|
-
|
|
5234
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
|
|
4986
5235
|
} break;
|
|
4987
5236
|
case WSP_GGML_OP_SET:
|
|
4988
5237
|
{
|
|
@@ -5288,7 +5537,6 @@ static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t
|
|
|
5288
5537
|
}
|
|
5289
5538
|
|
|
5290
5539
|
wsp_ggml_backend_metal_buffer_rset_free(ctx);
|
|
5291
|
-
wsp_ggml_backend_metal_device_rel(buffer->buft->device->context);
|
|
5292
5540
|
|
|
5293
5541
|
if (ctx->owned) {
|
|
5294
5542
|
#if TARGET_OS_OSX
|
|
@@ -5397,7 +5645,10 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer
|
|
|
5397
5645
|
}
|
|
5398
5646
|
|
|
5399
5647
|
struct wsp_ggml_backend_metal_device_context * ctx_dev = (struct wsp_ggml_backend_metal_device_context *)buft->device->context;
|
|
5400
|
-
|
|
5648
|
+
|
|
5649
|
+
WSP_GGML_ASSERT(ctx_dev->mtl_device != nil);
|
|
5650
|
+
|
|
5651
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
5401
5652
|
|
|
5402
5653
|
ctx->all_data = wsp_ggml_metal_host_malloc(size_aligned);
|
|
5403
5654
|
ctx->all_size = size_aligned;
|
|
@@ -5420,14 +5671,12 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer
|
|
|
5420
5671
|
if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
|
|
5421
5672
|
WSP_GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
|
5422
5673
|
free(ctx);
|
|
5423
|
-
wsp_ggml_backend_metal_device_rel(ctx_dev);
|
|
5424
5674
|
return NULL;
|
|
5425
5675
|
}
|
|
5426
5676
|
|
|
5427
5677
|
if (!wsp_ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
|
5428
5678
|
WSP_GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
|
5429
5679
|
free(ctx);
|
|
5430
|
-
wsp_ggml_backend_metal_device_rel(ctx_dev);
|
|
5431
5680
|
return NULL;
|
|
5432
5681
|
}
|
|
5433
5682
|
|
|
@@ -5438,17 +5687,14 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer
|
|
|
5438
5687
|
|
|
5439
5688
|
static size_t wsp_ggml_backend_metal_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
|
|
5440
5689
|
return 32;
|
|
5690
|
+
|
|
5441
5691
|
WSP_GGML_UNUSED(buft);
|
|
5442
5692
|
}
|
|
5443
5693
|
|
|
5444
5694
|
static size_t wsp_ggml_backend_metal_buffer_type_get_max_size(wsp_ggml_backend_buffer_type_t buft) {
|
|
5445
|
-
|
|
5446
|
-
const size_t max_size = device.maxBufferLength;
|
|
5447
|
-
wsp_ggml_backend_metal_device_rel(buft->device->context);
|
|
5695
|
+
const size_t max_size = ((struct wsp_ggml_backend_metal_device_context *)buft->device->context)->max_size;
|
|
5448
5696
|
|
|
5449
5697
|
return max_size;
|
|
5450
|
-
|
|
5451
|
-
WSP_GGML_UNUSED(buft);
|
|
5452
5698
|
}
|
|
5453
5699
|
|
|
5454
5700
|
static bool wsp_ggml_backend_metal_buffer_type_is_host(wsp_ggml_backend_buffer_type_t buft) {
|
|
@@ -5521,7 +5767,10 @@ wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, si
|
|
|
5521
5767
|
}
|
|
5522
5768
|
|
|
5523
5769
|
struct wsp_ggml_backend_metal_device_context * ctx_dev = &g_wsp_ggml_ctx_dev_main;
|
|
5524
|
-
|
|
5770
|
+
|
|
5771
|
+
WSP_GGML_ASSERT(ctx_dev->mtl_device != nil);
|
|
5772
|
+
|
|
5773
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
5525
5774
|
|
|
5526
5775
|
// the buffer fits into the max buffer size allowed by the device
|
|
5527
5776
|
if (size_aligned <= device.maxBufferLength) {
|
|
@@ -5577,7 +5826,6 @@ wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, si
|
|
|
5577
5826
|
if (!wsp_ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
|
5578
5827
|
WSP_GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
|
5579
5828
|
free(ctx);
|
|
5580
|
-
wsp_ggml_backend_metal_device_rel(ctx_dev);
|
|
5581
5829
|
return NULL;
|
|
5582
5830
|
}
|
|
5583
5831
|
|
|
@@ -5593,10 +5841,8 @@ static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) {
|
|
|
5593
5841
|
}
|
|
5594
5842
|
|
|
5595
5843
|
static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) {
|
|
5596
|
-
struct wsp_ggml_backend_metal_context
|
|
5597
|
-
struct wsp_ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
|
5844
|
+
struct wsp_ggml_backend_metal_context * ctx = backend->context;
|
|
5598
5845
|
|
|
5599
|
-
wsp_ggml_backend_metal_device_rel(ctx_dev);
|
|
5600
5846
|
wsp_ggml_metal_free(ctx);
|
|
5601
5847
|
|
|
5602
5848
|
free(backend);
|
|
@@ -5736,6 +5982,8 @@ bool wsp_ggml_backend_metal_supports_family(wsp_ggml_backend_t backend, int fami
|
|
|
5736
5982
|
|
|
5737
5983
|
struct wsp_ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
|
5738
5984
|
|
|
5985
|
+
WSP_GGML_ASSERT(ctx_dev->mtl_device != nil);
|
|
5986
|
+
|
|
5739
5987
|
return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
|
5740
5988
|
}
|
|
5741
5989
|
|
|
@@ -5755,10 +6003,7 @@ static const char * wsp_ggml_backend_metal_device_get_name(wsp_ggml_backend_dev_
|
|
|
5755
6003
|
}
|
|
5756
6004
|
|
|
5757
6005
|
static const char * wsp_ggml_backend_metal_device_get_description(wsp_ggml_backend_dev_t dev) {
|
|
5758
|
-
// acq/rel just to populate ctx->name in case it hasn't been done yet
|
|
5759
6006
|
struct wsp_ggml_backend_metal_device_context * ctx_dev = (struct wsp_ggml_backend_metal_device_context *)dev->context;
|
|
5760
|
-
wsp_ggml_backend_metal_device_acq(ctx_dev);
|
|
5761
|
-
wsp_ggml_backend_metal_device_rel(ctx_dev);
|
|
5762
6007
|
|
|
5763
6008
|
return ctx_dev->name;
|
|
5764
6009
|
}
|
|
@@ -5766,12 +6011,10 @@ static const char * wsp_ggml_backend_metal_device_get_description(wsp_ggml_backe
|
|
|
5766
6011
|
static void wsp_ggml_backend_metal_device_get_memory(wsp_ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
|
5767
6012
|
if (@available(macOS 10.12, iOS 16.0, *)) {
|
|
5768
6013
|
struct wsp_ggml_backend_metal_device_context * ctx_dev = (struct wsp_ggml_backend_metal_device_context *)dev->context;
|
|
5769
|
-
id<MTLDevice> device =
|
|
6014
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
5770
6015
|
|
|
5771
6016
|
*total = device.recommendedMaxWorkingSetSize;
|
|
5772
6017
|
*free = *total - device.currentAllocatedSize;
|
|
5773
|
-
|
|
5774
|
-
wsp_ggml_backend_metal_device_rel(ctx_dev);
|
|
5775
6018
|
} else {
|
|
5776
6019
|
*free = 1;
|
|
5777
6020
|
*total = 1;
|
|
@@ -5849,7 +6092,10 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_device_buffer_from_ptr(w
|
|
|
5849
6092
|
}
|
|
5850
6093
|
|
|
5851
6094
|
struct wsp_ggml_backend_metal_device_context * ctx_dev = (struct wsp_ggml_backend_metal_device_context *)dev->context;
|
|
5852
|
-
|
|
6095
|
+
|
|
6096
|
+
WSP_GGML_ASSERT(ctx_dev->mtl_device != nil);
|
|
6097
|
+
|
|
6098
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
5853
6099
|
|
|
5854
6100
|
// the buffer fits into the max buffer size allowed by the device
|
|
5855
6101
|
if (size_aligned <= device.maxBufferLength) {
|
|
@@ -5905,7 +6151,6 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_device_buffer_from_ptr(w
|
|
|
5905
6151
|
if (!wsp_ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
|
5906
6152
|
WSP_GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
|
5907
6153
|
free(ctx);
|
|
5908
|
-
wsp_ggml_backend_metal_device_rel(ctx_dev);
|
|
5909
6154
|
return NULL;
|
|
5910
6155
|
}
|
|
5911
6156
|
|
|
@@ -5919,8 +6164,9 @@ static bool wsp_ggml_backend_metal_device_supports_op(wsp_ggml_backend_dev_t dev
|
|
|
5919
6164
|
}
|
|
5920
6165
|
|
|
5921
6166
|
static bool wsp_ggml_backend_metal_device_supports_buft(wsp_ggml_backend_dev_t dev, wsp_ggml_backend_buffer_type_t buft) {
|
|
5922
|
-
return
|
|
5923
|
-
|
|
6167
|
+
return
|
|
6168
|
+
buft->iface.get_name == wsp_ggml_backend_metal_buffer_type_get_name ||
|
|
6169
|
+
buft->iface.get_name == wsp_ggml_backend_metal_buffer_from_ptr_type_get_name;
|
|
5924
6170
|
|
|
5925
6171
|
WSP_GGML_UNUSED(dev);
|
|
5926
6172
|
}
|
|
@@ -6005,8 +6251,19 @@ static struct wsp_ggml_backend_reg_i wsp_ggml_backend_metal_reg_i = {
|
|
|
6005
6251
|
/* .get_proc_address = */ wsp_ggml_backend_metal_get_proc_address,
|
|
6006
6252
|
};
|
|
6007
6253
|
|
|
6254
|
+
// called upon program exit
|
|
6255
|
+
static void wsp_ggml_metal_cleanup(void) {
|
|
6256
|
+
wsp_ggml_backend_metal_device_rel(&g_wsp_ggml_ctx_dev_main);
|
|
6257
|
+
}
|
|
6258
|
+
|
|
6259
|
+
// TODO: make thread-safe
|
|
6008
6260
|
wsp_ggml_backend_reg_t wsp_ggml_backend_metal_reg(void) {
|
|
6009
|
-
|
|
6261
|
+
wsp_ggml_backend_metal_device_acq(&g_wsp_ggml_ctx_dev_main);
|
|
6262
|
+
|
|
6263
|
+
// register cleanup callback
|
|
6264
|
+
// TODO: not ideal, but not sure if there is a better way to do this in Objective-C
|
|
6265
|
+
atexit(wsp_ggml_metal_cleanup);
|
|
6266
|
+
|
|
6010
6267
|
{
|
|
6011
6268
|
g_wsp_ggml_backend_metal_reg = (struct wsp_ggml_backend_reg) {
|
|
6012
6269
|
/* .api_version = */ WSP_GGML_BACKEND_API_VERSION,
|