whisper.rn 0.5.3 → 0.5.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +5 -0
- package/android/src/main/jni.cpp +13 -0
- package/cpp/ggml-alloc.c +78 -26
- package/cpp/ggml-alloc.h +9 -0
- package/cpp/ggml-backend-impl.h +1 -1
- package/cpp/ggml-backend-reg.cpp +19 -3
- package/cpp/ggml-backend.cpp +72 -20
- package/cpp/ggml-backend.h +2 -1
- package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- package/cpp/ggml-cpu/arch/arm/repack.cpp +1004 -0
- package/cpp/ggml-cpu/arch/x86/repack.cpp +6 -6
- package/cpp/ggml-cpu/arch-fallback.h +50 -2
- package/cpp/ggml-cpu/ggml-cpu-impl.h +1 -1
- package/cpp/ggml-cpu/ggml-cpu.c +139 -58
- package/cpp/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/ggml-cpu/ops.cpp +170 -18
- package/cpp/ggml-cpu/ops.h +1 -0
- package/cpp/ggml-cpu/repack.cpp +531 -5
- package/cpp/ggml-cpu/repack.h +14 -0
- package/cpp/ggml-cpu/simd-mappings.h +16 -18
- package/cpp/ggml-cpu/vec.cpp +41 -1
- package/cpp/ggml-cpu/vec.h +241 -138
- package/cpp/ggml-cpu.h +1 -0
- package/cpp/ggml-impl.h +0 -4
- package/cpp/ggml-metal/ggml-metal-context.m +26 -16
- package/cpp/ggml-metal/ggml-metal-device.cpp +452 -371
- package/cpp/ggml-metal/ggml-metal-device.h +87 -65
- package/cpp/ggml-metal/ggml-metal-device.m +263 -104
- package/cpp/ggml-metal/ggml-metal-impl.h +58 -4
- package/cpp/ggml-metal/ggml-metal-ops.cpp +415 -98
- package/cpp/ggml-metal/ggml-metal-ops.h +4 -0
- package/cpp/ggml-metal/ggml-metal.cpp +6 -5
- package/cpp/ggml-metal/ggml-metal.metal +404 -34
- package/cpp/ggml.c +110 -31
- package/cpp/ggml.h +51 -12
- package/cpp/jsi/RNWhisperJSI.cpp +1 -0
- package/cpp/whisper.cpp +17 -4
- package/ios/CMakeLists.txt +21 -1
- package/ios/RNWhisperContext.mm +5 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/jest-mock.js +2 -0
- package/lib/commonjs/jest-mock.js.map +1 -1
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +156 -12
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/jest-mock.js +2 -0
- package/lib/module/jest-mock.js.map +1 -1
- package/lib/module/realtime-transcription/RealtimeTranscriber.js +155 -12
- package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +1 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts +29 -0
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
- package/lib/typescript/realtime-transcription/types.d.ts +7 -0
- package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +1 -0
- package/src/jest-mock.ts +2 -0
- package/src/realtime-transcription/RealtimeTranscriber.ts +179 -9
- package/src/realtime-transcription/types.ts +9 -0
- package/src/version.json +1 -1
|
@@ -50,14 +50,14 @@ void wsp_ggml_metal_pipelines_add(wsp_ggml_metal_pipelines_t ppls, const char *
|
|
|
50
50
|
}
|
|
51
51
|
|
|
52
52
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_pipelines_get(wsp_ggml_metal_pipelines_t ppls, const char * name) {
|
|
53
|
-
if
|
|
53
|
+
if (ppls->data.find(name) == ppls->data.end()) {
|
|
54
54
|
return nullptr;
|
|
55
55
|
}
|
|
56
56
|
|
|
57
57
|
return ppls->data[name];
|
|
58
58
|
}
|
|
59
59
|
|
|
60
|
-
|
|
60
|
+
struct wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_base(wsp_ggml_metal_library_t lib, wsp_ggml_op op) {
|
|
61
61
|
char base[256];
|
|
62
62
|
char name[256];
|
|
63
63
|
|
|
@@ -71,34 +71,30 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_base(wsp_ggml_meta
|
|
|
71
71
|
snprintf(base, 256, "kernel_%s", op_str);
|
|
72
72
|
snprintf(name, 256, "%s", base);
|
|
73
73
|
|
|
74
|
-
|
|
75
|
-
if (res) {
|
|
76
|
-
|
|
74
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
75
|
+
if (!res.pipeline) {
|
|
76
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
77
77
|
}
|
|
78
78
|
|
|
79
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
80
|
-
|
|
81
79
|
return res;
|
|
82
80
|
}
|
|
83
81
|
|
|
84
|
-
|
|
82
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_cpy(wsp_ggml_metal_library_t lib, wsp_ggml_type tsrc, wsp_ggml_type tdst) {
|
|
85
83
|
char base[256];
|
|
86
84
|
char name[256];
|
|
87
85
|
|
|
88
86
|
snprintf(base, 256, "kernel_cpy_%s_%s", wsp_ggml_type_name(tsrc), wsp_ggml_type_name(tdst));
|
|
89
87
|
snprintf(name, 256, "%s", base);
|
|
90
88
|
|
|
91
|
-
|
|
92
|
-
if (res) {
|
|
93
|
-
|
|
89
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
90
|
+
if (!res.pipeline) {
|
|
91
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
94
92
|
}
|
|
95
93
|
|
|
96
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
97
|
-
|
|
98
94
|
return res;
|
|
99
95
|
}
|
|
100
96
|
|
|
101
|
-
|
|
97
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_pool_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op, wsp_ggml_op_pool op_pool) {
|
|
102
98
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
|
|
103
99
|
WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32 && op->src[0]->type == op->type);
|
|
104
100
|
|
|
@@ -115,68 +111,60 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pool_2d(wsp_ggml_m
|
|
|
115
111
|
snprintf(base, 256, "kernel_pool_2d_%s_%s", pool_str, wsp_ggml_type_name(op->src[0]->type));
|
|
116
112
|
snprintf(name, 256, "%s", base);
|
|
117
113
|
|
|
118
|
-
|
|
119
|
-
if (res) {
|
|
120
|
-
|
|
114
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
115
|
+
if (!res.pipeline) {
|
|
116
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
121
117
|
}
|
|
122
118
|
|
|
123
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
124
|
-
|
|
125
119
|
return res;
|
|
126
120
|
}
|
|
127
121
|
|
|
128
|
-
|
|
122
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_get_rows(wsp_ggml_metal_library_t lib, wsp_ggml_type tsrc) {
|
|
129
123
|
char base[256];
|
|
130
124
|
char name[256];
|
|
131
125
|
|
|
132
126
|
snprintf(base, 256, "kernel_get_rows_%s", wsp_ggml_type_name(tsrc));
|
|
133
127
|
snprintf(name, 256, "%s", base);
|
|
134
128
|
|
|
135
|
-
|
|
136
|
-
if (res) {
|
|
137
|
-
|
|
129
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
130
|
+
if (!res.pipeline) {
|
|
131
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
138
132
|
}
|
|
139
133
|
|
|
140
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
141
|
-
|
|
142
134
|
return res;
|
|
143
135
|
}
|
|
144
136
|
|
|
145
|
-
|
|
137
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_set_rows(wsp_ggml_metal_library_t lib, wsp_ggml_type tidx, wsp_ggml_type tdst) {
|
|
146
138
|
char base[256];
|
|
147
139
|
char name[256];
|
|
148
140
|
|
|
149
141
|
snprintf(base, 256, "kernel_set_rows_%s_%s", wsp_ggml_type_name(tdst), wsp_ggml_type_name(tidx));
|
|
150
142
|
snprintf(name, 256, "%s", base);
|
|
151
143
|
|
|
152
|
-
|
|
153
|
-
if (res) {
|
|
154
|
-
|
|
144
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
145
|
+
if (!res.pipeline) {
|
|
146
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
155
147
|
}
|
|
156
148
|
|
|
157
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
158
|
-
|
|
159
149
|
return res;
|
|
160
150
|
}
|
|
161
151
|
|
|
162
|
-
|
|
152
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_repeat(wsp_ggml_metal_library_t lib, wsp_ggml_type tsrc) {
|
|
163
153
|
char base[256];
|
|
164
154
|
char name[256];
|
|
165
155
|
|
|
166
156
|
snprintf(base, 256, "kernel_repeat_%s", wsp_ggml_type_name(tsrc));
|
|
167
157
|
snprintf(name, 256, "%s", base);
|
|
168
158
|
|
|
169
|
-
|
|
170
|
-
if (res) {
|
|
171
|
-
|
|
159
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
160
|
+
if (!res.pipeline) {
|
|
161
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
172
162
|
}
|
|
173
163
|
|
|
174
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
175
|
-
|
|
176
164
|
return res;
|
|
177
165
|
}
|
|
178
166
|
|
|
179
|
-
|
|
167
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_unary(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
180
168
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
|
|
181
169
|
|
|
182
170
|
char base[256];
|
|
@@ -187,6 +175,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_unary(wsp_ggml_met
|
|
|
187
175
|
const char * op_str = "undefined";
|
|
188
176
|
switch (op->op) {
|
|
189
177
|
case WSP_GGML_OP_SCALE: op_str = "scale"; break;
|
|
178
|
+
case WSP_GGML_OP_FILL: op_str = "fill"; break;
|
|
190
179
|
case WSP_GGML_OP_CLAMP: op_str = "clamp"; break;
|
|
191
180
|
case WSP_GGML_OP_SQR: op_str = "sqr"; break;
|
|
192
181
|
case WSP_GGML_OP_SQRT: op_str = "sqrt"; break;
|
|
@@ -211,6 +200,8 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_unary(wsp_ggml_met
|
|
|
211
200
|
case WSP_GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break;
|
|
212
201
|
case WSP_GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
|
|
213
202
|
case WSP_GGML_UNARY_OP_EXP: op_str = "exp"; break;
|
|
203
|
+
case WSP_GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break;
|
|
204
|
+
case WSP_GGML_UNARY_OP_EXPM1: op_str = "expm1"; break;
|
|
214
205
|
default: WSP_GGML_ABORT("fatal error");
|
|
215
206
|
} break;
|
|
216
207
|
default: WSP_GGML_ABORT("fatal error");
|
|
@@ -224,17 +215,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_unary(wsp_ggml_met
|
|
|
224
215
|
snprintf(base, 256, "kernel_%s_%s%s", op_str, wsp_ggml_type_name(op->src[0]->type), suffix);
|
|
225
216
|
snprintf(name, 256, "%s", base);
|
|
226
217
|
|
|
227
|
-
|
|
228
|
-
if (res) {
|
|
229
|
-
|
|
218
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
219
|
+
if (!res.pipeline) {
|
|
220
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
230
221
|
}
|
|
231
222
|
|
|
232
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
233
|
-
|
|
234
223
|
return res;
|
|
235
224
|
}
|
|
236
225
|
|
|
237
|
-
|
|
226
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_glu(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
238
227
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(op->src[0]));
|
|
239
228
|
|
|
240
229
|
char base[256];
|
|
@@ -258,17 +247,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_glu(wsp_ggml_metal
|
|
|
258
247
|
snprintf(base, 256, "kernel_%s_%s", op_str, wsp_ggml_type_name(op->src[0]->type));
|
|
259
248
|
snprintf(name, 256, "%s", base);
|
|
260
249
|
|
|
261
|
-
|
|
262
|
-
if (res) {
|
|
263
|
-
|
|
250
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
251
|
+
if (!res.pipeline) {
|
|
252
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
264
253
|
}
|
|
265
254
|
|
|
266
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
267
|
-
|
|
268
255
|
return res;
|
|
269
256
|
}
|
|
270
257
|
|
|
271
|
-
|
|
258
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_sum(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
272
259
|
assert(op->op == WSP_GGML_OP_SUM);
|
|
273
260
|
|
|
274
261
|
char base[256];
|
|
@@ -277,17 +264,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum(wsp_ggml_metal
|
|
|
277
264
|
snprintf(base, 256, "kernel_op_sum_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
278
265
|
snprintf(name, 256, "%s", base);
|
|
279
266
|
|
|
280
|
-
|
|
281
|
-
if (res) {
|
|
282
|
-
|
|
267
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
268
|
+
if (!res.pipeline) {
|
|
269
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
283
270
|
}
|
|
284
271
|
|
|
285
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
286
|
-
|
|
287
272
|
return res;
|
|
288
273
|
}
|
|
289
274
|
|
|
290
|
-
|
|
275
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_sum_rows(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
291
276
|
WSP_GGML_ASSERT(op->src[0]->nb[0] == wsp_ggml_type_size(op->src[0]->type));
|
|
292
277
|
|
|
293
278
|
char base[256];
|
|
@@ -306,19 +291,17 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum_rows(wsp_ggml_
|
|
|
306
291
|
|
|
307
292
|
snprintf(name, 256, "%s", base);
|
|
308
293
|
|
|
309
|
-
|
|
310
|
-
if (res) {
|
|
311
|
-
|
|
294
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
295
|
+
if (!res.pipeline) {
|
|
296
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
312
297
|
}
|
|
313
298
|
|
|
314
|
-
res =
|
|
315
|
-
|
|
316
|
-
wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
|
|
299
|
+
res.smem = 32*sizeof(float);
|
|
317
300
|
|
|
318
301
|
return res;
|
|
319
302
|
}
|
|
320
303
|
|
|
321
|
-
|
|
304
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_cumsum_blk(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
322
305
|
WSP_GGML_ASSERT(op->op == WSP_GGML_OP_CUMSUM);
|
|
323
306
|
|
|
324
307
|
char base[256];
|
|
@@ -327,17 +310,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_blk(wsp_ggm
|
|
|
327
310
|
snprintf(base, 256, "kernel_cumsum_blk_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
328
311
|
snprintf(name, 256, "%s", base);
|
|
329
312
|
|
|
330
|
-
|
|
331
|
-
if (res) {
|
|
332
|
-
|
|
313
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
314
|
+
if (!res.pipeline) {
|
|
315
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
333
316
|
}
|
|
334
317
|
|
|
335
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
336
|
-
|
|
337
318
|
return res;
|
|
338
319
|
}
|
|
339
320
|
|
|
340
|
-
|
|
321
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_cumsum_add(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
341
322
|
WSP_GGML_ASSERT(op->op == WSP_GGML_OP_CUMSUM);
|
|
342
323
|
|
|
343
324
|
char base[256];
|
|
@@ -346,17 +327,37 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_add(wsp_ggm
|
|
|
346
327
|
snprintf(base, 256, "kernel_cumsum_add_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
347
328
|
snprintf(name, 256, "%s", base);
|
|
348
329
|
|
|
349
|
-
|
|
350
|
-
if (res) {
|
|
351
|
-
|
|
330
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
331
|
+
if (!res.pipeline) {
|
|
332
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
352
333
|
}
|
|
353
334
|
|
|
354
|
-
res
|
|
335
|
+
return res;
|
|
336
|
+
}
|
|
337
|
+
|
|
338
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_tri(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
339
|
+
WSP_GGML_ASSERT(op->op == WSP_GGML_OP_TRI);
|
|
340
|
+
WSP_GGML_ASSERT(op->src[0]->nb[0] == wsp_ggml_type_size(op->src[0]->type));
|
|
341
|
+
|
|
342
|
+
char base[256];
|
|
343
|
+
char name[256];
|
|
344
|
+
|
|
345
|
+
const char * op_str = "tri";
|
|
346
|
+
const int ttype = op->op_params[0];
|
|
347
|
+
|
|
348
|
+
snprintf(base, 256, "kernel_%s_%s_%d", op_str, wsp_ggml_type_name(op->src[0]->type), ttype);
|
|
349
|
+
|
|
350
|
+
snprintf(name, 256, "%s", base);
|
|
351
|
+
|
|
352
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
353
|
+
if (!res.pipeline) {
|
|
354
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
355
|
+
}
|
|
355
356
|
|
|
356
357
|
return res;
|
|
357
358
|
}
|
|
358
359
|
|
|
359
|
-
|
|
360
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_soft_max(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
360
361
|
WSP_GGML_ASSERT(!op->src[1] || op->src[1]->type == WSP_GGML_TYPE_F16 || op->src[1]->type == WSP_GGML_TYPE_F32);
|
|
361
362
|
|
|
362
363
|
char base[256];
|
|
@@ -373,19 +374,17 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_soft_max(wsp_ggml_
|
|
|
373
374
|
snprintf(base, 256, "kernel_soft_max_%s%s", wsp_ggml_type_name(tsrc1), suffix);
|
|
374
375
|
snprintf(name, 256, "%s", base);
|
|
375
376
|
|
|
376
|
-
|
|
377
|
-
if (res) {
|
|
378
|
-
|
|
377
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
378
|
+
if (!res.pipeline) {
|
|
379
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
379
380
|
}
|
|
380
381
|
|
|
381
|
-
res =
|
|
382
|
-
|
|
383
|
-
wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
|
|
382
|
+
res.smem = 32*sizeof(float);
|
|
384
383
|
|
|
385
384
|
return res;
|
|
386
385
|
}
|
|
387
386
|
|
|
388
|
-
|
|
387
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_ssm_conv(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
389
388
|
WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
|
|
390
389
|
WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
|
|
391
390
|
|
|
@@ -404,17 +403,47 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_conv(wsp_ggml_
|
|
|
404
403
|
snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type), suffix);
|
|
405
404
|
snprintf(name, 256, "%s", base);
|
|
406
405
|
|
|
407
|
-
|
|
408
|
-
if (res) {
|
|
409
|
-
|
|
406
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
407
|
+
if (!res.pipeline) {
|
|
408
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
410
409
|
}
|
|
411
410
|
|
|
412
|
-
res
|
|
411
|
+
return res;
|
|
412
|
+
}
|
|
413
|
+
|
|
414
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_ssm_conv_batched(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op, int ssm_conv_bs) {
|
|
415
|
+
WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
|
|
416
|
+
WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
|
|
417
|
+
|
|
418
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
|
|
419
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[1]));
|
|
420
|
+
|
|
421
|
+
char base[256];
|
|
422
|
+
char name[256];
|
|
423
|
+
|
|
424
|
+
const char * suffix = "";
|
|
425
|
+
if (op->src[1]->ne[0] % 4 == 0) {
|
|
426
|
+
suffix = "_4";
|
|
427
|
+
}
|
|
428
|
+
|
|
429
|
+
snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type), suffix);
|
|
430
|
+
snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs);
|
|
431
|
+
|
|
432
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
433
|
+
if (!res.pipeline) {
|
|
434
|
+
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
435
|
+
|
|
436
|
+
wsp_ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0);
|
|
437
|
+
|
|
438
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
439
|
+
|
|
440
|
+
wsp_ggml_metal_cv_free(cv);
|
|
441
|
+
}
|
|
413
442
|
|
|
414
443
|
return res;
|
|
415
444
|
}
|
|
416
445
|
|
|
417
|
-
|
|
446
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_ssm_scan(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
418
447
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
419
448
|
|
|
420
449
|
char base[256];
|
|
@@ -425,19 +454,22 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_scan(wsp_ggml_
|
|
|
425
454
|
snprintf(base, 256, "kernel_ssm_scan_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
426
455
|
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
|
427
456
|
|
|
428
|
-
|
|
429
|
-
if (res) {
|
|
430
|
-
|
|
457
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
458
|
+
if (!res.pipeline) {
|
|
459
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
431
460
|
}
|
|
432
461
|
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
462
|
+
// Shared memory layout:
|
|
463
|
+
// - sgptg * NW floats for partial sums (nsg * 32)
|
|
464
|
+
// - sgptg floats for shared_x_dt (nsg)
|
|
465
|
+
// - sgptg floats for shared_dA (nsg)
|
|
466
|
+
// Total: nsg * (32 + 2) floats
|
|
467
|
+
res.smem = (32 + 2)*sizeof(float)*nsg;
|
|
436
468
|
|
|
437
469
|
return res;
|
|
438
470
|
}
|
|
439
471
|
|
|
440
|
-
|
|
472
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_rwkv(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
441
473
|
char base[256];
|
|
442
474
|
char name[256];
|
|
443
475
|
|
|
@@ -467,41 +499,37 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rwkv(wsp_ggml_meta
|
|
|
467
499
|
|
|
468
500
|
snprintf(name, 256, "%s", base);
|
|
469
501
|
|
|
470
|
-
|
|
471
|
-
if (res) {
|
|
472
|
-
|
|
502
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
503
|
+
if (!res.pipeline) {
|
|
504
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
473
505
|
}
|
|
474
506
|
|
|
475
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
476
|
-
|
|
477
507
|
return res;
|
|
478
508
|
}
|
|
479
509
|
|
|
480
|
-
|
|
510
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_mul_mv_ext(wsp_ggml_metal_library_t lib, wsp_ggml_type tsrc0, wsp_ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
|
|
481
511
|
char base[256];
|
|
482
512
|
char name[256];
|
|
483
513
|
|
|
484
514
|
snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", wsp_ggml_type_name(tsrc0), wsp_ggml_type_name(tsrc1), r1ptg);
|
|
485
515
|
snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
|
|
486
516
|
|
|
487
|
-
|
|
488
|
-
if (res) {
|
|
489
|
-
|
|
490
|
-
}
|
|
491
|
-
|
|
492
|
-
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
517
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
518
|
+
if (!res.pipeline) {
|
|
519
|
+
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
493
520
|
|
|
494
|
-
|
|
495
|
-
|
|
521
|
+
wsp_ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
|
|
522
|
+
wsp_ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
|
|
496
523
|
|
|
497
|
-
|
|
524
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
498
525
|
|
|
499
|
-
|
|
526
|
+
wsp_ggml_metal_cv_free(cv);
|
|
527
|
+
}
|
|
500
528
|
|
|
501
529
|
return res;
|
|
502
530
|
}
|
|
503
531
|
|
|
504
|
-
|
|
532
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_mul_mm(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
505
533
|
char base[256];
|
|
506
534
|
char name[256];
|
|
507
535
|
|
|
@@ -514,27 +542,25 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mm(wsp_ggml_me
|
|
|
514
542
|
snprintf(base, 256, "kernel_mul_mm_%s_%s", wsp_ggml_type_name(tsrc0), wsp_ggml_type_name(tsrc1));
|
|
515
543
|
snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
|
|
516
544
|
|
|
517
|
-
|
|
518
|
-
if (res) {
|
|
519
|
-
|
|
520
|
-
}
|
|
521
|
-
|
|
522
|
-
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
545
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
546
|
+
if (!res.pipeline) {
|
|
547
|
+
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
523
548
|
|
|
524
|
-
|
|
525
|
-
|
|
549
|
+
wsp_ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
|
|
550
|
+
wsp_ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
|
|
526
551
|
|
|
527
|
-
|
|
552
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
528
553
|
|
|
529
|
-
|
|
554
|
+
wsp_ggml_metal_cv_free(cv);
|
|
555
|
+
}
|
|
530
556
|
|
|
531
557
|
// when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes
|
|
532
|
-
|
|
558
|
+
res.smem = bc_out ? 8192 : 4096 + 2048;
|
|
533
559
|
|
|
534
560
|
return res;
|
|
535
561
|
}
|
|
536
562
|
|
|
537
|
-
|
|
563
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_mul_mv(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
538
564
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
539
565
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
540
566
|
|
|
@@ -689,49 +715,43 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mv(wsp_ggml_me
|
|
|
689
715
|
snprintf(base, 256, "kernel_mul_mv_%s_%s%s", wsp_ggml_type_name(tsrc0), wsp_ggml_type_name(tsrc1), suffix);
|
|
690
716
|
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
|
691
717
|
|
|
692
|
-
|
|
693
|
-
if (res) {
|
|
694
|
-
|
|
695
|
-
}
|
|
696
|
-
|
|
697
|
-
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
718
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
719
|
+
if (!res.pipeline) {
|
|
720
|
+
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
698
721
|
|
|
699
|
-
|
|
722
|
+
wsp_ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
|
|
700
723
|
|
|
701
|
-
|
|
724
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
702
725
|
|
|
703
|
-
|
|
726
|
+
wsp_ggml_metal_cv_free(cv);
|
|
727
|
+
}
|
|
704
728
|
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
729
|
+
res.nr0 = nr0;
|
|
730
|
+
res.nr1 = nr1;
|
|
731
|
+
res.nsg = nsg;
|
|
732
|
+
res.smem = smem;
|
|
709
733
|
|
|
710
734
|
return res;
|
|
711
735
|
}
|
|
712
736
|
|
|
713
|
-
|
|
737
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_mul_mm_id_map0(wsp_ggml_metal_library_t lib, int ne02, int ne20) {
|
|
714
738
|
char base[256];
|
|
715
739
|
char name[256];
|
|
716
740
|
|
|
717
741
|
snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
|
|
718
742
|
snprintf(name, 256, "%s_ne02=%d", base, ne02);
|
|
719
743
|
|
|
720
|
-
|
|
721
|
-
if (res) {
|
|
722
|
-
|
|
744
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
745
|
+
if (!res.pipeline) {
|
|
746
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
723
747
|
}
|
|
724
748
|
|
|
725
|
-
res =
|
|
726
|
-
|
|
727
|
-
const size_t smem = (size_t) ne02*ne20*sizeof(uint16_t);
|
|
728
|
-
|
|
729
|
-
wsp_ggml_metal_pipeline_set_smem(res, smem);
|
|
749
|
+
res.smem = (size_t) ne02*ne20*sizeof(uint16_t);
|
|
730
750
|
|
|
731
751
|
return res;
|
|
732
752
|
}
|
|
733
753
|
|
|
734
|
-
|
|
754
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_mul_mm_id(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
735
755
|
char base[256];
|
|
736
756
|
char name[256];
|
|
737
757
|
|
|
@@ -743,25 +763,23 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mm_id(wsp_ggml
|
|
|
743
763
|
snprintf(base, 256, "kernel_mul_mm_id_%s_%s", wsp_ggml_type_name(tsrc0), wsp_ggml_type_name(tsrc1));
|
|
744
764
|
snprintf(name, 256, "%s_bci=%d", base, bc_inp);
|
|
745
765
|
|
|
746
|
-
|
|
747
|
-
if (res) {
|
|
748
|
-
|
|
749
|
-
}
|
|
766
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
767
|
+
if (!res.pipeline) {
|
|
768
|
+
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
750
769
|
|
|
751
|
-
|
|
770
|
+
wsp_ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
|
|
752
771
|
|
|
753
|
-
|
|
772
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
754
773
|
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
wsp_ggml_metal_cv_free(cv);
|
|
774
|
+
wsp_ggml_metal_cv_free(cv);
|
|
775
|
+
}
|
|
758
776
|
|
|
759
|
-
|
|
777
|
+
res.smem = 8192;
|
|
760
778
|
|
|
761
779
|
return res;
|
|
762
780
|
}
|
|
763
781
|
|
|
764
|
-
|
|
782
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_mul_mv_id(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
765
783
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
766
784
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
767
785
|
|
|
@@ -909,28 +927,26 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mv_id(wsp_ggml
|
|
|
909
927
|
snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", wsp_ggml_type_name(tsrc0), wsp_ggml_type_name(tsrc1), suffix);
|
|
910
928
|
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
|
911
929
|
|
|
912
|
-
|
|
913
|
-
if (res) {
|
|
914
|
-
|
|
915
|
-
}
|
|
930
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
931
|
+
if (!res.pipeline) {
|
|
932
|
+
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
916
933
|
|
|
917
|
-
|
|
934
|
+
wsp_ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
|
|
918
935
|
|
|
919
|
-
|
|
936
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
920
937
|
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
wsp_ggml_metal_cv_free(cv);
|
|
938
|
+
wsp_ggml_metal_cv_free(cv);
|
|
939
|
+
}
|
|
924
940
|
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
941
|
+
res.nr0 = nr0;
|
|
942
|
+
res.nr1 = nr1;
|
|
943
|
+
res.nsg = nsg;
|
|
944
|
+
res.smem = smem;
|
|
929
945
|
|
|
930
946
|
return res;
|
|
931
947
|
}
|
|
932
948
|
|
|
933
|
-
|
|
949
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_argmax(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
934
950
|
WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
|
|
935
951
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(op->src[0]));
|
|
936
952
|
WSP_GGML_ASSERT(op->src[0]->nb[0] == wsp_ggml_type_size(op->src[0]->type));
|
|
@@ -941,19 +957,17 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argmax(wsp_ggml_me
|
|
|
941
957
|
snprintf(base, 256, "kernel_argmax_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
942
958
|
snprintf(name, 256, "%s", base);
|
|
943
959
|
|
|
944
|
-
|
|
945
|
-
if (res) {
|
|
946
|
-
|
|
960
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
961
|
+
if (!res.pipeline) {
|
|
962
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
947
963
|
}
|
|
948
964
|
|
|
949
|
-
res =
|
|
950
|
-
|
|
951
|
-
wsp_ggml_metal_pipeline_set_smem(res, 32*(sizeof(float) + sizeof(int32_t)));
|
|
965
|
+
res.smem = 32*(sizeof(float) + sizeof(int32_t));
|
|
952
966
|
|
|
953
967
|
return res;
|
|
954
968
|
}
|
|
955
969
|
|
|
956
|
-
|
|
970
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_argsort(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
957
971
|
assert(op->op == WSP_GGML_OP_ARGSORT);
|
|
958
972
|
|
|
959
973
|
char base[256];
|
|
@@ -971,17 +985,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort(wsp_ggml_m
|
|
|
971
985
|
snprintf(base, 256, "kernel_argsort_%s_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->type), order_str);
|
|
972
986
|
snprintf(name, 256, "%s", base);
|
|
973
987
|
|
|
974
|
-
|
|
975
|
-
if (res) {
|
|
976
|
-
|
|
988
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
989
|
+
if (!res.pipeline) {
|
|
990
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
977
991
|
}
|
|
978
992
|
|
|
979
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
980
|
-
|
|
981
993
|
return res;
|
|
982
994
|
}
|
|
983
995
|
|
|
984
|
-
|
|
996
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_argsort_merge(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
985
997
|
assert(op->op == WSP_GGML_OP_ARGSORT);
|
|
986
998
|
|
|
987
999
|
char base[256];
|
|
@@ -999,17 +1011,69 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort_merge(wsp_
|
|
|
999
1011
|
snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->type), order_str);
|
|
1000
1012
|
snprintf(name, 256, "%s", base);
|
|
1001
1013
|
|
|
1002
|
-
|
|
1003
|
-
if (res) {
|
|
1004
|
-
|
|
1014
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1015
|
+
if (!res.pipeline) {
|
|
1016
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1005
1017
|
}
|
|
1006
1018
|
|
|
1007
|
-
res
|
|
1019
|
+
return res;
|
|
1020
|
+
}
|
|
1021
|
+
|
|
1022
|
+
// note: reuse the argsort kernel for top_k
|
|
1023
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_top_k(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1024
|
+
assert(op->op == WSP_GGML_OP_TOP_K);
|
|
1025
|
+
|
|
1026
|
+
char base[256];
|
|
1027
|
+
char name[256];
|
|
1028
|
+
|
|
1029
|
+
// note: the top_k kernel is always descending order
|
|
1030
|
+
wsp_ggml_sort_order order = WSP_GGML_SORT_ORDER_DESC;
|
|
1031
|
+
|
|
1032
|
+
const char * order_str = "undefined";
|
|
1033
|
+
switch (order) {
|
|
1034
|
+
case WSP_GGML_SORT_ORDER_ASC: order_str = "asc"; break;
|
|
1035
|
+
case WSP_GGML_SORT_ORDER_DESC: order_str = "desc"; break;
|
|
1036
|
+
default: WSP_GGML_ABORT("fatal error");
|
|
1037
|
+
};
|
|
1038
|
+
|
|
1039
|
+
snprintf(base, 256, "kernel_argsort_%s_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->type), order_str);
|
|
1040
|
+
snprintf(name, 256, "%s", base);
|
|
1041
|
+
|
|
1042
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1043
|
+
if (!res.pipeline) {
|
|
1044
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1045
|
+
}
|
|
1008
1046
|
|
|
1009
1047
|
return res;
|
|
1010
1048
|
}
|
|
1011
1049
|
|
|
1012
|
-
|
|
1050
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_top_k_merge(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1051
|
+
assert(op->op == WSP_GGML_OP_TOP_K);
|
|
1052
|
+
|
|
1053
|
+
char base[256];
|
|
1054
|
+
char name[256];
|
|
1055
|
+
|
|
1056
|
+
wsp_ggml_sort_order order = WSP_GGML_SORT_ORDER_DESC;
|
|
1057
|
+
|
|
1058
|
+
const char * order_str = "undefined";
|
|
1059
|
+
switch (order) {
|
|
1060
|
+
case WSP_GGML_SORT_ORDER_ASC: order_str = "asc"; break;
|
|
1061
|
+
case WSP_GGML_SORT_ORDER_DESC: order_str = "desc"; break;
|
|
1062
|
+
default: WSP_GGML_ABORT("fatal error");
|
|
1063
|
+
};
|
|
1064
|
+
|
|
1065
|
+
snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->type), order_str);
|
|
1066
|
+
snprintf(name, 256, "%s", base);
|
|
1067
|
+
|
|
1068
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1069
|
+
if (!res.pipeline) {
|
|
1070
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1071
|
+
}
|
|
1072
|
+
|
|
1073
|
+
return res;
|
|
1074
|
+
}
|
|
1075
|
+
|
|
1076
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
|
1013
1077
|
wsp_ggml_metal_library_t lib,
|
|
1014
1078
|
const struct wsp_ggml_tensor * op,
|
|
1015
1079
|
bool has_mask,
|
|
@@ -1028,33 +1092,31 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad
|
|
|
1028
1092
|
has_mask,
|
|
1029
1093
|
ncpsg);
|
|
1030
1094
|
|
|
1031
|
-
|
|
1032
|
-
if (res) {
|
|
1033
|
-
|
|
1034
|
-
}
|
|
1035
|
-
|
|
1036
|
-
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
1095
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1096
|
+
if (!res.pipeline) {
|
|
1097
|
+
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
1037
1098
|
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1099
|
+
wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
|
|
1100
|
+
//wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
|
|
1101
|
+
//wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
|
|
1102
|
+
//wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
|
|
1042
1103
|
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1104
|
+
//wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
|
|
1105
|
+
//wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
|
|
1106
|
+
//wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
|
|
1107
|
+
//wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
|
|
1108
|
+
//wsp_ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
|
|
1109
|
+
wsp_ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
|
|
1049
1110
|
|
|
1050
|
-
|
|
1111
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1051
1112
|
|
|
1052
|
-
|
|
1113
|
+
wsp_ggml_metal_cv_free(cv);
|
|
1114
|
+
}
|
|
1053
1115
|
|
|
1054
1116
|
return res;
|
|
1055
1117
|
}
|
|
1056
1118
|
|
|
1057
|
-
|
|
1119
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(
|
|
1058
1120
|
wsp_ggml_metal_library_t lib,
|
|
1059
1121
|
const struct wsp_ggml_tensor * op,
|
|
1060
1122
|
int32_t nqptg,
|
|
@@ -1073,33 +1135,31 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk
|
|
|
1073
1135
|
nqptg,
|
|
1074
1136
|
ncpsg);
|
|
1075
1137
|
|
|
1076
|
-
|
|
1077
|
-
if (res) {
|
|
1078
|
-
|
|
1079
|
-
}
|
|
1080
|
-
|
|
1081
|
-
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
1138
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1139
|
+
if (!res.pipeline) {
|
|
1140
|
+
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
1082
1141
|
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
|
|
1142
|
+
//wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
|
|
1143
|
+
//wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
|
|
1144
|
+
//wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
|
|
1145
|
+
//wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
|
|
1087
1146
|
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1147
|
+
//wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
|
|
1148
|
+
//wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
|
|
1149
|
+
//wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
|
|
1150
|
+
//wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
|
|
1151
|
+
wsp_ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
|
|
1152
|
+
wsp_ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
|
|
1094
1153
|
|
|
1095
|
-
|
|
1154
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1096
1155
|
|
|
1097
|
-
|
|
1156
|
+
wsp_ggml_metal_cv_free(cv);
|
|
1157
|
+
}
|
|
1098
1158
|
|
|
1099
1159
|
return res;
|
|
1100
1160
|
}
|
|
1101
1161
|
|
|
1102
|
-
|
|
1162
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
|
|
1103
1163
|
wsp_ggml_metal_library_t lib,
|
|
1104
1164
|
const wsp_ggml_tensor * op,
|
|
1105
1165
|
bool has_mask,
|
|
@@ -1140,33 +1200,31 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
|
|
|
1140
1200
|
ns20,
|
|
1141
1201
|
nsg);
|
|
1142
1202
|
|
|
1143
|
-
|
|
1144
|
-
if (res) {
|
|
1145
|
-
|
|
1146
|
-
}
|
|
1147
|
-
|
|
1148
|
-
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
1203
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1204
|
+
if (!res.pipeline) {
|
|
1205
|
+
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
1149
1206
|
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1207
|
+
wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0);
|
|
1208
|
+
wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
|
|
1209
|
+
wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
|
|
1210
|
+
wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
|
|
1211
|
+
wsp_ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
|
|
1155
1212
|
|
|
1156
|
-
|
|
1213
|
+
wsp_ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
|
|
1157
1214
|
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1215
|
+
wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
|
|
1216
|
+
wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
|
|
1217
|
+
wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22);
|
|
1161
1218
|
|
|
1162
|
-
|
|
1219
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1163
1220
|
|
|
1164
|
-
|
|
1221
|
+
wsp_ggml_metal_cv_free(cv);
|
|
1222
|
+
}
|
|
1165
1223
|
|
|
1166
1224
|
return res;
|
|
1167
1225
|
}
|
|
1168
1226
|
|
|
1169
|
-
|
|
1227
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
|
1170
1228
|
wsp_ggml_metal_library_t lib,
|
|
1171
1229
|
const wsp_ggml_tensor * op,
|
|
1172
1230
|
bool has_mask,
|
|
@@ -1204,32 +1262,30 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec
|
|
|
1204
1262
|
ns20,
|
|
1205
1263
|
nsg, nwg);
|
|
1206
1264
|
|
|
1207
|
-
|
|
1208
|
-
if (res) {
|
|
1209
|
-
|
|
1210
|
-
}
|
|
1211
|
-
|
|
1212
|
-
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
1265
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1266
|
+
if (!res.pipeline) {
|
|
1267
|
+
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
1213
1268
|
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1269
|
+
wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0);
|
|
1270
|
+
wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
|
|
1271
|
+
wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
|
|
1272
|
+
wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
|
|
1273
|
+
wsp_ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
|
|
1219
1274
|
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
1275
|
+
wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
|
|
1276
|
+
wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
|
|
1277
|
+
wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22);
|
|
1278
|
+
wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23);
|
|
1224
1279
|
|
|
1225
|
-
|
|
1280
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1226
1281
|
|
|
1227
|
-
|
|
1282
|
+
wsp_ggml_metal_cv_free(cv);
|
|
1283
|
+
}
|
|
1228
1284
|
|
|
1229
1285
|
return res;
|
|
1230
1286
|
}
|
|
1231
1287
|
|
|
1232
|
-
|
|
1288
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
|
|
1233
1289
|
wsp_ggml_metal_library_t lib,
|
|
1234
1290
|
const wsp_ggml_tensor * op,
|
|
1235
1291
|
int32_t dv,
|
|
@@ -1242,26 +1298,24 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec
|
|
|
1242
1298
|
snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
|
|
1243
1299
|
snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);
|
|
1244
1300
|
|
|
1245
|
-
|
|
1246
|
-
if (res) {
|
|
1247
|
-
|
|
1248
|
-
}
|
|
1249
|
-
|
|
1250
|
-
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
1301
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1302
|
+
if (!res.pipeline) {
|
|
1303
|
+
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
1251
1304
|
|
|
1252
|
-
|
|
1253
|
-
|
|
1305
|
+
wsp_ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0);
|
|
1306
|
+
wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1);
|
|
1254
1307
|
|
|
1255
|
-
|
|
1308
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1256
1309
|
|
|
1257
|
-
|
|
1310
|
+
wsp_ggml_metal_cv_free(cv);
|
|
1311
|
+
}
|
|
1258
1312
|
|
|
1259
1313
|
return res;
|
|
1260
1314
|
|
|
1261
1315
|
WSP_GGML_UNUSED(op);
|
|
1262
1316
|
}
|
|
1263
1317
|
|
|
1264
|
-
|
|
1318
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_bin(
|
|
1265
1319
|
wsp_ggml_metal_library_t lib,
|
|
1266
1320
|
wsp_ggml_op op,
|
|
1267
1321
|
int32_t n_fuse,
|
|
@@ -1286,17 +1340,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_bin(
|
|
|
1286
1340
|
|
|
1287
1341
|
snprintf(name, 256, "%s", base);
|
|
1288
1342
|
|
|
1289
|
-
|
|
1290
|
-
if (res) {
|
|
1291
|
-
|
|
1343
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1344
|
+
if (!res.pipeline) {
|
|
1345
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1292
1346
|
}
|
|
1293
1347
|
|
|
1294
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1295
|
-
|
|
1296
1348
|
return res;
|
|
1297
1349
|
}
|
|
1298
1350
|
|
|
1299
|
-
|
|
1351
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_l2_norm(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1300
1352
|
assert(op->op == WSP_GGML_OP_L2_NORM);
|
|
1301
1353
|
|
|
1302
1354
|
WSP_GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
|
|
@@ -1308,19 +1360,17 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_l2_norm(wsp_ggml_m
|
|
|
1308
1360
|
snprintf(base, 256, "kernel_l2_norm_f32");
|
|
1309
1361
|
snprintf(name, 256, "%s", base);
|
|
1310
1362
|
|
|
1311
|
-
|
|
1312
|
-
if (res) {
|
|
1313
|
-
|
|
1363
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1364
|
+
if (!res.pipeline) {
|
|
1365
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1314
1366
|
}
|
|
1315
1367
|
|
|
1316
|
-
res =
|
|
1317
|
-
|
|
1318
|
-
wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
|
|
1368
|
+
res.smem = 32*sizeof(float);
|
|
1319
1369
|
|
|
1320
1370
|
return res;
|
|
1321
1371
|
}
|
|
1322
1372
|
|
|
1323
|
-
|
|
1373
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_group_norm(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1324
1374
|
assert(op->op == WSP_GGML_OP_GROUP_NORM);
|
|
1325
1375
|
|
|
1326
1376
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
|
|
@@ -1331,19 +1381,17 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_group_norm(wsp_ggm
|
|
|
1331
1381
|
snprintf(base, 256, "kernel_group_norm_f32");
|
|
1332
1382
|
snprintf(name, 256, "%s", base);
|
|
1333
1383
|
|
|
1334
|
-
|
|
1335
|
-
if (res) {
|
|
1336
|
-
|
|
1384
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1385
|
+
if (!res.pipeline) {
|
|
1386
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1337
1387
|
}
|
|
1338
1388
|
|
|
1339
|
-
res =
|
|
1340
|
-
|
|
1341
|
-
wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
|
|
1389
|
+
res.smem = 32*sizeof(float);
|
|
1342
1390
|
|
|
1343
1391
|
return res;
|
|
1344
1392
|
}
|
|
1345
1393
|
|
|
1346
|
-
|
|
1394
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_norm(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op, int n_fuse) {
|
|
1347
1395
|
assert(op->op == WSP_GGML_OP_NORM || op->op == WSP_GGML_OP_RMS_NORM);
|
|
1348
1396
|
|
|
1349
1397
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(op->src[0]));
|
|
@@ -1376,19 +1424,17 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_norm(wsp_ggml_meta
|
|
|
1376
1424
|
|
|
1377
1425
|
snprintf(name, 256, "%s", base);
|
|
1378
1426
|
|
|
1379
|
-
|
|
1380
|
-
if (res) {
|
|
1381
|
-
|
|
1427
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1428
|
+
if (!res.pipeline) {
|
|
1429
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1382
1430
|
}
|
|
1383
1431
|
|
|
1384
|
-
res =
|
|
1385
|
-
|
|
1386
|
-
wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
|
|
1432
|
+
res.smem = 32*sizeof(float);
|
|
1387
1433
|
|
|
1388
1434
|
return res;
|
|
1389
1435
|
}
|
|
1390
1436
|
|
|
1391
|
-
|
|
1437
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_rope(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1392
1438
|
assert(op->op == WSP_GGML_OP_ROPE);
|
|
1393
1439
|
|
|
1394
1440
|
char base[256];
|
|
@@ -1415,23 +1461,21 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rope(wsp_ggml_meta
|
|
|
1415
1461
|
|
|
1416
1462
|
snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
|
|
1417
1463
|
|
|
1418
|
-
|
|
1419
|
-
if (res) {
|
|
1420
|
-
|
|
1421
|
-
}
|
|
1464
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1465
|
+
if (!res.pipeline) {
|
|
1466
|
+
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
1422
1467
|
|
|
1423
|
-
|
|
1468
|
+
wsp_ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
|
|
1424
1469
|
|
|
1425
|
-
|
|
1470
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1426
1471
|
|
|
1427
|
-
|
|
1428
|
-
|
|
1429
|
-
wsp_ggml_metal_cv_free(cv);
|
|
1472
|
+
wsp_ggml_metal_cv_free(cv);
|
|
1473
|
+
}
|
|
1430
1474
|
|
|
1431
1475
|
return res;
|
|
1432
1476
|
}
|
|
1433
1477
|
|
|
1434
|
-
|
|
1478
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_im2col(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1435
1479
|
assert(op->op == WSP_GGML_OP_IM2COL);
|
|
1436
1480
|
|
|
1437
1481
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[1]));
|
|
@@ -1444,17 +1488,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_im2col(wsp_ggml_me
|
|
|
1444
1488
|
snprintf(base, 256, "kernel_im2col_%s", wsp_ggml_type_name(op->type));
|
|
1445
1489
|
snprintf(name, 256, "%s", base);
|
|
1446
1490
|
|
|
1447
|
-
|
|
1448
|
-
if (res) {
|
|
1449
|
-
|
|
1491
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1492
|
+
if (!res.pipeline) {
|
|
1493
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1450
1494
|
}
|
|
1451
1495
|
|
|
1452
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1453
|
-
|
|
1454
1496
|
return res;
|
|
1455
1497
|
}
|
|
1456
1498
|
|
|
1457
|
-
|
|
1499
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_conv_transpose_1d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1458
1500
|
assert(op->op == WSP_GGML_OP_CONV_TRANSPOSE_1D);
|
|
1459
1501
|
|
|
1460
1502
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
|
|
@@ -1469,17 +1511,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_1d(
|
|
|
1469
1511
|
snprintf(base, 256, "kernel_conv_transpose_1d_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type));
|
|
1470
1512
|
snprintf(name, 256, "%s", base);
|
|
1471
1513
|
|
|
1472
|
-
|
|
1473
|
-
if (res) {
|
|
1474
|
-
|
|
1514
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1515
|
+
if (!res.pipeline) {
|
|
1516
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1475
1517
|
}
|
|
1476
1518
|
|
|
1477
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1478
|
-
|
|
1479
1519
|
return res;
|
|
1480
1520
|
}
|
|
1481
1521
|
|
|
1482
|
-
|
|
1522
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1483
1523
|
assert(op->op == WSP_GGML_OP_CONV_TRANSPOSE_2D);
|
|
1484
1524
|
|
|
1485
1525
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
|
|
@@ -1494,17 +1534,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(
|
|
|
1494
1534
|
snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type));
|
|
1495
1535
|
snprintf(name, 256, "%s", base);
|
|
1496
1536
|
|
|
1497
|
-
|
|
1498
|
-
if (res) {
|
|
1499
|
-
|
|
1537
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1538
|
+
if (!res.pipeline) {
|
|
1539
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1500
1540
|
}
|
|
1501
1541
|
|
|
1502
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1503
|
-
|
|
1504
1542
|
return res;
|
|
1505
1543
|
}
|
|
1506
1544
|
|
|
1507
|
-
|
|
1545
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_conv_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1508
1546
|
assert(op->op == WSP_GGML_OP_CONV_2D);
|
|
1509
1547
|
|
|
1510
1548
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
|
|
@@ -1518,17 +1556,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_2d(wsp_ggml_m
|
|
|
1518
1556
|
snprintf(base, 256, "kernel_conv_2d_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type));
|
|
1519
1557
|
snprintf(name, 256, "%s", base);
|
|
1520
1558
|
|
|
1521
|
-
|
|
1522
|
-
if (res) {
|
|
1523
|
-
|
|
1559
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1560
|
+
if (!res.pipeline) {
|
|
1561
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1524
1562
|
}
|
|
1525
1563
|
|
|
1526
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1527
|
-
|
|
1528
1564
|
return res;
|
|
1529
1565
|
}
|
|
1530
1566
|
|
|
1531
|
-
|
|
1567
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_upscale(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1532
1568
|
assert(op->op == WSP_GGML_OP_UPSCALE);
|
|
1533
1569
|
|
|
1534
1570
|
char base[256];
|
|
@@ -1537,17 +1573,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_upscale(wsp_ggml_m
|
|
|
1537
1573
|
snprintf(base, 256, "kernel_upscale_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
1538
1574
|
snprintf(name, 256, "%s", base);
|
|
1539
1575
|
|
|
1540
|
-
|
|
1541
|
-
if (res) {
|
|
1542
|
-
|
|
1576
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1577
|
+
if (!res.pipeline) {
|
|
1578
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1543
1579
|
}
|
|
1544
1580
|
|
|
1545
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1546
|
-
|
|
1547
1581
|
return res;
|
|
1548
1582
|
}
|
|
1549
1583
|
|
|
1550
|
-
|
|
1584
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_pad(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1551
1585
|
assert(op->op == WSP_GGML_OP_PAD);
|
|
1552
1586
|
|
|
1553
1587
|
char base[256];
|
|
@@ -1556,8 +1590,8 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad(wsp_ggml_metal
|
|
|
1556
1590
|
snprintf(base, 256, "kernel_pad_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
1557
1591
|
snprintf(name, 256, "%s", base);
|
|
1558
1592
|
|
|
1559
|
-
|
|
1560
|
-
if (res) {
|
|
1593
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1594
|
+
if (res.pipeline) {
|
|
1561
1595
|
return res;
|
|
1562
1596
|
}
|
|
1563
1597
|
|
|
@@ -1566,7 +1600,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad(wsp_ggml_metal
|
|
|
1566
1600
|
return res;
|
|
1567
1601
|
}
|
|
1568
1602
|
|
|
1569
|
-
|
|
1603
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_pad_reflect_1d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1570
1604
|
assert(op->op == WSP_GGML_OP_PAD_REFLECT_1D);
|
|
1571
1605
|
|
|
1572
1606
|
char base[256];
|
|
@@ -1575,17 +1609,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad_reflect_1d(wsp
|
|
|
1575
1609
|
snprintf(base, 256, "kernel_pad_reflect_1d_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
1576
1610
|
snprintf(name, 256, "%s", base);
|
|
1577
1611
|
|
|
1578
|
-
|
|
1579
|
-
if (res) {
|
|
1580
|
-
|
|
1612
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1613
|
+
if (!res.pipeline) {
|
|
1614
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1581
1615
|
}
|
|
1582
1616
|
|
|
1583
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1584
|
-
|
|
1585
1617
|
return res;
|
|
1586
1618
|
}
|
|
1587
1619
|
|
|
1588
|
-
|
|
1620
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_arange(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1589
1621
|
assert(op->op == WSP_GGML_OP_ARANGE);
|
|
1590
1622
|
|
|
1591
1623
|
char base[256];
|
|
@@ -1594,17 +1626,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_arange(wsp_ggml_me
|
|
|
1594
1626
|
snprintf(base, 256, "kernel_arange_%s", wsp_ggml_type_name(op->type));
|
|
1595
1627
|
snprintf(name, 256, "%s", base);
|
|
1596
1628
|
|
|
1597
|
-
|
|
1598
|
-
if (res) {
|
|
1599
|
-
|
|
1629
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1630
|
+
if (!res.pipeline) {
|
|
1631
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1600
1632
|
}
|
|
1601
1633
|
|
|
1602
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1603
|
-
|
|
1604
1634
|
return res;
|
|
1605
1635
|
}
|
|
1606
1636
|
|
|
1607
|
-
|
|
1637
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_timestep_embedding(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1608
1638
|
assert(op->op == WSP_GGML_OP_TIMESTEP_EMBEDDING);
|
|
1609
1639
|
|
|
1610
1640
|
char base[256];
|
|
@@ -1613,17 +1643,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_timestep_embedding
|
|
|
1613
1643
|
snprintf(base, 256, "kernel_timestep_embedding_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
1614
1644
|
snprintf(name, 256, "%s", base);
|
|
1615
1645
|
|
|
1616
|
-
|
|
1617
|
-
if (res) {
|
|
1618
|
-
|
|
1646
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1647
|
+
if (!res.pipeline) {
|
|
1648
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1619
1649
|
}
|
|
1620
1650
|
|
|
1621
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1622
|
-
|
|
1623
1651
|
return res;
|
|
1624
1652
|
}
|
|
1625
1653
|
|
|
1626
|
-
|
|
1654
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_opt_step_adamw(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1627
1655
|
assert(op->op == WSP_GGML_OP_OPT_STEP_ADAMW);
|
|
1628
1656
|
|
|
1629
1657
|
char base[256];
|
|
@@ -1632,17 +1660,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_adamw(wsp
|
|
|
1632
1660
|
snprintf(base, 256, "kernel_opt_step_adamw_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
1633
1661
|
snprintf(name, 256, "%s", base);
|
|
1634
1662
|
|
|
1635
|
-
|
|
1636
|
-
if (res) {
|
|
1637
|
-
|
|
1663
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1664
|
+
if (!res.pipeline) {
|
|
1665
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1638
1666
|
}
|
|
1639
1667
|
|
|
1640
|
-
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1641
|
-
|
|
1642
1668
|
return res;
|
|
1643
1669
|
}
|
|
1644
1670
|
|
|
1645
|
-
|
|
1671
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_opt_step_sgd(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1646
1672
|
assert(op->op == WSP_GGML_OP_OPT_STEP_SGD);
|
|
1647
1673
|
|
|
1648
1674
|
char base[256];
|
|
@@ -1651,12 +1677,67 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_sgd(wsp_g
|
|
|
1651
1677
|
snprintf(base, 256, "kernel_opt_step_sgd_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
1652
1678
|
snprintf(name, 256, "%s", base);
|
|
1653
1679
|
|
|
1654
|
-
|
|
1655
|
-
if (res) {
|
|
1656
|
-
|
|
1680
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1681
|
+
if (!res.pipeline) {
|
|
1682
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1657
1683
|
}
|
|
1658
1684
|
|
|
1659
|
-
res
|
|
1685
|
+
return res;
|
|
1686
|
+
}
|
|
1687
|
+
|
|
1688
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_memset(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1689
|
+
WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_I64);
|
|
1690
|
+
|
|
1691
|
+
char base[256];
|
|
1692
|
+
char name[256];
|
|
1693
|
+
|
|
1694
|
+
snprintf(base, 256, "kernel_memset_%s", wsp_ggml_type_name(op->type));
|
|
1695
|
+
snprintf(name, 256, "%s", base);
|
|
1696
|
+
|
|
1697
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1698
|
+
if (!res.pipeline) {
|
|
1699
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1700
|
+
}
|
|
1701
|
+
|
|
1702
|
+
return res;
|
|
1703
|
+
}
|
|
1704
|
+
|
|
1705
|
+
wsp_ggml_metal_pipeline_with_params wsp_ggml_metal_library_get_pipeline_count_equal(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1706
|
+
assert(op->op == WSP_GGML_OP_COUNT_EQUAL);
|
|
1707
|
+
|
|
1708
|
+
WSP_GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);
|
|
1709
|
+
|
|
1710
|
+
WSP_GGML_ASSERT(op->src[0]->type == op->src[1]->type);
|
|
1711
|
+
WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_I32);
|
|
1712
|
+
WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_I64);
|
|
1713
|
+
|
|
1714
|
+
// note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int
|
|
1715
|
+
WSP_GGML_ASSERT(wsp_ggml_nelements(op->src[0]) < (1LL << 31));
|
|
1716
|
+
|
|
1717
|
+
char base[256];
|
|
1718
|
+
char name[256];
|
|
1719
|
+
|
|
1720
|
+
int nsg = 1;
|
|
1721
|
+
while (32*nsg < ne00 && nsg < 32) {
|
|
1722
|
+
nsg *= 2;
|
|
1723
|
+
}
|
|
1724
|
+
|
|
1725
|
+
snprintf(base, 256, "kernel_count_equal_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
1726
|
+
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
|
1727
|
+
|
|
1728
|
+
wsp_ggml_metal_pipeline_with_params res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1729
|
+
if (!res.pipeline) {
|
|
1730
|
+
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
1731
|
+
|
|
1732
|
+
wsp_ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);
|
|
1733
|
+
|
|
1734
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1735
|
+
|
|
1736
|
+
wsp_ggml_metal_cv_free(cv);
|
|
1737
|
+
}
|
|
1738
|
+
|
|
1739
|
+
res.smem = 32 * sizeof(int32_t);
|
|
1740
|
+
res.nsg = nsg;
|
|
1660
1741
|
|
|
1661
1742
|
return res;
|
|
1662
1743
|
}
|