whisper.rn 0.5.0 → 0.5.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/gradle.properties +1 -1
- package/android/src/main/jni.cpp +12 -3
- package/cpp/ggml-alloc.c +292 -130
- package/cpp/ggml-backend-impl.h +4 -4
- package/cpp/ggml-backend-reg.cpp +13 -5
- package/cpp/ggml-backend.cpp +207 -17
- package/cpp/ggml-backend.h +19 -1
- package/cpp/ggml-cpu/amx/amx.cpp +5 -2
- package/cpp/ggml-cpu/arch/x86/repack.cpp +2 -2
- package/cpp/ggml-cpu/arch-fallback.h +0 -4
- package/cpp/ggml-cpu/common.h +14 -0
- package/cpp/ggml-cpu/ggml-cpu-impl.h +14 -7
- package/cpp/ggml-cpu/ggml-cpu.c +65 -44
- package/cpp/ggml-cpu/ggml-cpu.cpp +14 -4
- package/cpp/ggml-cpu/ops.cpp +542 -775
- package/cpp/ggml-cpu/ops.h +2 -0
- package/cpp/ggml-cpu/simd-mappings.h +88 -59
- package/cpp/ggml-cpu/unary-ops.cpp +135 -0
- package/cpp/ggml-cpu/unary-ops.h +5 -0
- package/cpp/ggml-cpu/vec.cpp +227 -20
- package/cpp/ggml-cpu/vec.h +407 -56
- package/cpp/ggml-cpu.h +1 -1
- package/cpp/ggml-impl.h +94 -12
- package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
- package/cpp/ggml-metal/ggml-metal-common.h +52 -0
- package/cpp/ggml-metal/ggml-metal-context.h +33 -0
- package/cpp/ggml-metal/ggml-metal-context.m +600 -0
- package/cpp/ggml-metal/ggml-metal-device.cpp +1565 -0
- package/cpp/ggml-metal/ggml-metal-device.h +244 -0
- package/cpp/ggml-metal/ggml-metal-device.m +1325 -0
- package/cpp/ggml-metal/ggml-metal-impl.h +802 -0
- package/cpp/ggml-metal/ggml-metal-ops.cpp +3583 -0
- package/cpp/ggml-metal/ggml-metal-ops.h +88 -0
- package/cpp/ggml-metal/ggml-metal.cpp +718 -0
- package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
- package/cpp/ggml-metal-impl.h +40 -40
- package/cpp/ggml-metal.h +1 -6
- package/cpp/ggml-quants.c +1 -0
- package/cpp/ggml.c +341 -15
- package/cpp/ggml.h +150 -5
- package/cpp/jsi/RNWhisperJSI.cpp +9 -2
- package/cpp/jsi/ThreadPool.h +3 -3
- package/cpp/rn-whisper.h +1 -0
- package/cpp/whisper.cpp +89 -72
- package/cpp/whisper.h +1 -0
- package/ios/CMakeLists.txt +6 -1
- package/ios/RNWhisperContext.mm +3 -1
- package/ios/RNWhisperVadContext.mm +14 -13
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +19 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +94 -12
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +150 -5
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +19 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +94 -12
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +150 -5
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +19 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +94 -12
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +150 -5
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +19 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +94 -12
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +150 -5
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +2 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +2 -0
- package/src/version.json +1 -1
- package/whisper-rn.podspec +8 -9
- package/cpp/ggml-metal.m +0 -6779
- package/cpp/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-whisper.metallib +0 -0
|
@@ -0,0 +1,1565 @@
|
|
|
1
|
+
#include "ggml-metal-device.h"
|
|
2
|
+
|
|
3
|
+
#include "ggml-metal-impl.h"
|
|
4
|
+
|
|
5
|
+
#include "ggml-impl.h"
|
|
6
|
+
|
|
7
|
+
#include <cassert>
|
|
8
|
+
#include <memory>
|
|
9
|
+
#include <string>
|
|
10
|
+
#include <unordered_map>
|
|
11
|
+
|
|
12
|
+
struct wsp_ggml_metal_device_deleter {
|
|
13
|
+
void operator()(wsp_ggml_metal_device_t ctx) {
|
|
14
|
+
wsp_ggml_metal_device_free(ctx);
|
|
15
|
+
}
|
|
16
|
+
};
|
|
17
|
+
|
|
18
|
+
typedef std::unique_ptr<wsp_ggml_metal_device, wsp_ggml_metal_device_deleter> wsp_ggml_metal_device_ptr;
|
|
19
|
+
|
|
20
|
+
// Returns the shared Metal device. The device is created by
// wsp_ggml_metal_device_init() on the first call and held in a function-local
// static unique_ptr. That pointer frees it via wsp_ggml_metal_device_deleter when
// static objects are destroyed.
wsp_ggml_metal_device_t wsp_ggml_metal_device_get(void) {
    static wsp_ggml_metal_device_ptr s_device(wsp_ggml_metal_device_init());
    return s_device.get();
}
|
|
25
|
+
|
|
26
|
+
struct wsp_ggml_metal_pipelines {
|
|
27
|
+
std::unordered_map<std::string, wsp_ggml_metal_pipeline_t> data;
|
|
28
|
+
};
|
|
29
|
+
|
|
30
|
+
// Allocates an empty pipeline map; release it with wsp_ggml_metal_pipelines_free().
wsp_ggml_metal_pipelines_t wsp_ggml_metal_pipelines_init(void) {
    return new wsp_ggml_metal_pipelines();
}
|
|
35
|
+
|
|
36
|
+
// Frees every pipeline stored in `ppls`, then the map object itself.
// A null `ppls` is accepted and ignored.
void wsp_ggml_metal_pipelines_free(wsp_ggml_metal_pipelines_t ppls) {
    if (ppls == nullptr) {
        return;
    }

    for (auto & entry : ppls->data) {
        wsp_ggml_metal_pipeline_free(entry.second);
    }

    delete ppls;
}
|
|
47
|
+
|
|
48
|
+
// Stores `pipeline` under `name`. The map owns stored pipelines:
// wsp_ggml_metal_pipelines_free() frees every value it holds.
// NOTE(review): if `name` is already present, the previous pointer is overwritten
// without being freed. Callers presumably look the name up first; confirm.
void wsp_ggml_metal_pipelines_add(wsp_ggml_metal_pipelines_t ppls, const char * name, wsp_ggml_metal_pipeline_t pipeline) {
    wsp_ggml_metal_pipeline_t & slot = ppls->data[name];
    slot = pipeline;
}
|
|
51
|
+
|
|
52
|
+
// Returns the pipeline stored under `name`, or nullptr if there is none.
// The lookup is done once: the iterator from find() is reused instead of a second
// operator[] lookup, which also avoids building a second std::string from `name`.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_pipelines_get(wsp_ggml_metal_pipelines_t ppls, const char * name) {
    const auto it = ppls->data.find(name);
    if (it == ppls->data.end()) {
        return nullptr;
    }

    return it->second;
}
|
|
59
|
+
|
|
60
|
+
// Returns the pipeline for kernel "kernel_<op>" (ADD_ID or CONCAT).
// The pipeline is looked up in `lib` by name first and compiled only on a miss.
// Aborts for any other op.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_base(wsp_ggml_metal_library_t lib, wsp_ggml_op op) {
    const char * op_str = "undefined";
    switch (op) {
        case WSP_GGML_OP_ADD_ID: op_str = "add_id"; break;
        case WSP_GGML_OP_CONCAT: op_str = "concat"; break;
        default: WSP_GGML_ABORT("fatal error");
    }

    char kernel_base[256];
    char kernel_name[256];

    snprintf(kernel_base, sizeof(kernel_base), "kernel_%s", op_str);
    snprintf(kernel_name, sizeof(kernel_name), "%s", kernel_base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, kernel_name);
    if (!pipeline) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, kernel_base, kernel_name, nullptr);
    }

    return pipeline;
}
|
|
83
|
+
|
|
84
|
+
// Returns the copy/convert pipeline "kernel_cpy_<tsrc>_<tdst>".
// The pipeline is looked up in `lib` by name first and compiled only on a miss.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cpy(wsp_ggml_metal_library_t lib, wsp_ggml_type tsrc, wsp_ggml_type tdst) {
    char kernel_base[256];
    char kernel_name[256];

    snprintf(kernel_base, sizeof(kernel_base), "kernel_cpy_%s_%s", wsp_ggml_type_name(tsrc), wsp_ggml_type_name(tdst));
    snprintf(kernel_name, sizeof(kernel_name), "%s", kernel_base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, kernel_name);
    if (!pipeline) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, kernel_base, kernel_name, nullptr);
    }

    return pipeline;
}
|
|
100
|
+
|
|
101
|
+
// Returns the 2D pooling pipeline "kernel_pool_2d_<avg|max>_<src0 type>".
// Requires a contiguous F32 src0 whose type matches the output type.
// Aborts for unsupported pooling modes, using WSP_GGML_ABORT like the other
// selectors in this file instead of WSP_GGML_ASSERT(false && ...).
// The pipeline is looked up in `lib` by name first and compiled only on a miss.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pool_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op, wsp_ggml_op_pool op_pool) {
    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
    WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32 && op->src[0]->type == op->type);

    const char * pool_str = "undefined";
    switch (op_pool) {
        case WSP_GGML_OP_POOL_AVG: pool_str = "avg"; break;
        case WSP_GGML_OP_POOL_MAX: pool_str = "max"; break;
        default: WSP_GGML_ABORT("not implemented");
    }

    char base[256];
    char name[256];

    snprintf(base, sizeof(base), "kernel_pool_2d_%s_%s", pool_str, wsp_ggml_type_name(op->src[0]->type));
    snprintf(name, sizeof(name), "%s", base);

    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (res) {
        return res;
    }

    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);

    return res;
}
|
|
127
|
+
|
|
128
|
+
// Returns the row-gather pipeline "kernel_get_rows_<tsrc>".
// The pipeline is looked up in `lib` by name first and compiled only on a miss.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_get_rows(wsp_ggml_metal_library_t lib, wsp_ggml_type tsrc) {
    char kernel_base[256];
    char kernel_name[256];

    snprintf(kernel_base, sizeof(kernel_base), "kernel_get_rows_%s", wsp_ggml_type_name(tsrc));
    snprintf(kernel_name, sizeof(kernel_name), "%s", kernel_base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, kernel_name);
    if (!pipeline) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, kernel_base, kernel_name, nullptr);
    }

    return pipeline;
}
|
|
144
|
+
|
|
145
|
+
// Returns the row-scatter pipeline "kernel_set_rows_<tdst>_<tidx>".
// The destination type comes first in the kernel name even though `tidx` is the
// first parameter.
// The pipeline is looked up in `lib` by name first and compiled only on a miss.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_set_rows(wsp_ggml_metal_library_t lib, wsp_ggml_type tidx, wsp_ggml_type tdst) {
    char kernel_base[256];
    char kernel_name[256];

    snprintf(kernel_base, sizeof(kernel_base), "kernel_set_rows_%s_%s", wsp_ggml_type_name(tdst), wsp_ggml_type_name(tidx));
    snprintf(kernel_name, sizeof(kernel_name), "%s", kernel_base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, kernel_name);
    if (!pipeline) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, kernel_base, kernel_name, nullptr);
    }

    return pipeline;
}
|
|
161
|
+
|
|
162
|
+
// Returns the repeat pipeline "kernel_repeat_<tsrc>".
// The pipeline is looked up in `lib` by name first and compiled only on a miss.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_repeat(wsp_ggml_metal_library_t lib, wsp_ggml_type tsrc) {
    char kernel_base[256];
    char kernel_name[256];

    snprintf(kernel_base, sizeof(kernel_base), "kernel_repeat_%s", wsp_ggml_type_name(tsrc));
    snprintf(kernel_name, sizeof(kernel_name), "%s", kernel_base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, kernel_name);
    if (!pipeline) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, kernel_base, kernel_name, nullptr);
    }

    return pipeline;
}
|
|
178
|
+
|
|
179
|
+
// Returns the pipeline for an element-wise op on a contiguous src0: SCALE, CLAMP,
// SQR, SQRT, SIN, COS, LOG, LEAKY_RELU, or one of the listed WSP_GGML_OP_UNARY
// sub-ops. The kernel name is "kernel_<op>_<src0 type>", plus a "_4" suffix when
// the element count of `op` is divisible by 4. Aborts on any other op.
// The pipeline is looked up in `lib` by name first and compiled only on a miss.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_unary(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));

    const int64_t n_elements = wsp_ggml_nelements(op);

    const char * op_str = "undefined";
    switch (op->op) {
        case WSP_GGML_OP_SCALE:      op_str = "scale";      break;
        case WSP_GGML_OP_CLAMP:      op_str = "clamp";      break;
        case WSP_GGML_OP_SQR:        op_str = "sqr";        break;
        case WSP_GGML_OP_SQRT:       op_str = "sqrt";       break;
        case WSP_GGML_OP_SIN:        op_str = "sin";        break;
        case WSP_GGML_OP_COS:        op_str = "cos";        break;
        case WSP_GGML_OP_LOG:        op_str = "log";        break;
        case WSP_GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break;
        case WSP_GGML_OP_UNARY:
            switch (wsp_ggml_get_unary_op(op)) {
                case WSP_GGML_UNARY_OP_TANH:        op_str = "tanh";        break;
                case WSP_GGML_UNARY_OP_RELU:        op_str = "relu";        break;
                case WSP_GGML_UNARY_OP_SIGMOID:     op_str = "sigmoid";     break;
                case WSP_GGML_UNARY_OP_GELU:        op_str = "gelu";        break;
                case WSP_GGML_UNARY_OP_GELU_ERF:    op_str = "gelu_erf";    break;
                case WSP_GGML_UNARY_OP_GELU_QUICK:  op_str = "gelu_quick";  break;
                case WSP_GGML_UNARY_OP_SILU:        op_str = "silu";        break;
                case WSP_GGML_UNARY_OP_ELU:         op_str = "elu";         break;
                case WSP_GGML_UNARY_OP_NEG:         op_str = "neg";         break;
                case WSP_GGML_UNARY_OP_ABS:         op_str = "abs";         break;
                case WSP_GGML_UNARY_OP_SGN:         op_str = "sgn";         break;
                case WSP_GGML_UNARY_OP_STEP:        op_str = "step";        break;
                case WSP_GGML_UNARY_OP_HARDSWISH:   op_str = "hardswish";   break;
                case WSP_GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
                case WSP_GGML_UNARY_OP_EXP:         op_str = "exp";         break;
                default: WSP_GGML_ABORT("fatal error");
            }
            break;
        default: WSP_GGML_ABORT("fatal error");
    }

    const char * suffix = (n_elements % 4 == 0) ? "_4" : "";

    char kernel_base[256];
    char kernel_name[256];

    snprintf(kernel_base, sizeof(kernel_base), "kernel_%s_%s%s", op_str, wsp_ggml_type_name(op->src[0]->type), suffix);
    snprintf(kernel_name, sizeof(kernel_name), "%s", kernel_base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, kernel_name);
    if (!pipeline) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, kernel_base, kernel_name, nullptr);
    }

    return pipeline;
}
|
|
236
|
+
|
|
237
|
+
// Returns the gated-linear-unit pipeline "kernel_<glu op>_<src0 type>" for a
// WSP_GGML_OP_GLU node whose src0 satisfies wsp_ggml_is_contiguous_1().
// Aborts on any other op or GLU variant.
// The pipeline is looked up in `lib` by name first and compiled only on a miss.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_glu(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(op->src[0]));

    const char * op_str = "undefined";
    switch (op->op) {
        case WSP_GGML_OP_GLU:
            switch (wsp_ggml_get_glu_op(op)) {
                case WSP_GGML_GLU_OP_REGLU:       op_str = "reglu";       break;
                case WSP_GGML_GLU_OP_GEGLU:       op_str = "geglu";       break;
                case WSP_GGML_GLU_OP_SWIGLU:      op_str = "swiglu";      break;
                case WSP_GGML_GLU_OP_SWIGLU_OAI:  op_str = "swiglu_oai";  break;
                case WSP_GGML_GLU_OP_GEGLU_ERF:   op_str = "geglu_erf";   break;
                case WSP_GGML_GLU_OP_GEGLU_QUICK: op_str = "geglu_quick"; break;
                default: WSP_GGML_ABORT("fatal error");
            }
            break;
        default: WSP_GGML_ABORT("fatal error");
    }

    char kernel_base[256];
    char kernel_name[256];

    snprintf(kernel_base, sizeof(kernel_base), "kernel_%s_%s", op_str, wsp_ggml_type_name(op->src[0]->type));
    snprintf(kernel_name, sizeof(kernel_name), "%s", kernel_base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, kernel_name);
    if (!pipeline) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, kernel_base, kernel_name, nullptr);
    }

    return pipeline;
}
|
|
270
|
+
|
|
271
|
+
// Returns the full-reduction pipeline "kernel_op_sum_<src0 type>" for a
// WSP_GGML_OP_SUM node. The op check now uses WSP_GGML_ASSERT, like every other
// selector here, instead of plain assert(), which NDEBUG builds compile out.
// The pipeline is looked up in `lib` by name first and compiled only on a miss.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    WSP_GGML_ASSERT(op->op == WSP_GGML_OP_SUM);

    char base[256];
    char name[256];

    snprintf(base, sizeof(base), "kernel_op_sum_%s", wsp_ggml_type_name(op->src[0]->type));
    snprintf(name, sizeof(name), "%s", base);

    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (res) {
        return res;
    }

    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);

    return res;
}
|
|
289
|
+
|
|
290
|
+
// Returns the per-row reduction pipeline "kernel_<sum_rows|mean>_<src0 type>".
// src0 elements must be densely packed along dim 0 (nb[0] == type size).
// Aborts on any op other than SUM_ROWS or MEAN.
// The pipeline is looked up in `lib` by name first. Only on a miss is it compiled
// and given 32 floats of shared memory.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum_rows(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    WSP_GGML_ASSERT(op->src[0]->nb[0] == wsp_ggml_type_size(op->src[0]->type));

    const char * op_str = "undefined";
    switch (op->op) {
        case WSP_GGML_OP_SUM_ROWS: op_str = "sum_rows"; break;
        case WSP_GGML_OP_MEAN:     op_str = "mean";     break;
        default: WSP_GGML_ABORT("fatal error");
    }

    char kernel_base[256];
    char kernel_name[256];

    snprintf(kernel_base, sizeof(kernel_base), "kernel_%s_%s", op_str, wsp_ggml_type_name(op->src[0]->type));
    snprintf(kernel_name, sizeof(kernel_name), "%s", kernel_base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, kernel_name);
    if (!pipeline) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, kernel_base, kernel_name, nullptr);
        wsp_ggml_metal_pipeline_set_smem(pipeline, 32*sizeof(float));
    }

    return pipeline;
}
|
|
320
|
+
|
|
321
|
+
// Returns the softmax pipeline "kernel_soft_max_<src1 type>[_4]".
// src1 is optional and, when present, must be F16 or F32; without it the name
// uses F32. The "_4" suffix is used when src0->ne[0] is divisible by 4.
// The pipeline is looked up in `lib` by name first. Only on a miss is it compiled
// and given 32 floats of shared memory.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_soft_max(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    const wsp_ggml_tensor * mask = op->src[1];

    WSP_GGML_ASSERT(!mask || mask->type == WSP_GGML_TYPE_F16 || mask->type == WSP_GGML_TYPE_F32);

    const char *        suffix    = (op->src[0]->ne[0] % 4 == 0) ? "_4" : "";
    const wsp_ggml_type mask_type = mask ? mask->type : WSP_GGML_TYPE_F32;

    char kernel_base[256];
    char kernel_name[256];

    snprintf(kernel_base, sizeof(kernel_base), "kernel_soft_max_%s%s", wsp_ggml_type_name(mask_type), suffix);
    snprintf(kernel_name, sizeof(kernel_name), "%s", kernel_base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, kernel_name);
    if (!pipeline) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, kernel_base, kernel_name, nullptr);
        wsp_ggml_metal_pipeline_set_smem(pipeline, 32*sizeof(float));
    }

    return pipeline;
}
|
|
349
|
+
|
|
350
|
+
// Returns the SSM convolution pipeline "kernel_ssm_conv_<src0>_<src1>[_4]".
// src0 and src1 must both be contiguous F32. The "_4" suffix is used when
// src1->ne[0] is divisible by 4.
// The pipeline is looked up in `lib` by name first and compiled only on a miss.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_conv(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    const wsp_ggml_tensor * src0 = op->src[0];
    const wsp_ggml_tensor * src1 = op->src[1];

    WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
    WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);

    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));

    const char * suffix = (src1->ne[0] % 4 == 0) ? "_4" : "";

    char kernel_base[256];
    char kernel_name[256];

    snprintf(kernel_base, sizeof(kernel_base), "kernel_ssm_conv_%s_%s%s", wsp_ggml_type_name(src0->type), wsp_ggml_type_name(src1->type), suffix);
    snprintf(kernel_name, sizeof(kernel_name), "%s", kernel_base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, kernel_name);
    if (!pipeline) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, kernel_base, kernel_name, nullptr);
    }

    return pipeline;
}
|
|
378
|
+
|
|
379
|
+
// Returns the SSM scan pipeline "kernel_ssm_scan_<src0 type>". Pipelines are
// cached under "<base>_nsg=<nsg>", where nsg = ceil(src0->ne[0] / 32).
// The pipeline is looked up in `lib` by name first. Only on a miss is it compiled
// and given 32 floats of shared memory per nsg.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_scan(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);

    const int nsg = (ne00 + 31)/32;

    char kernel_base[256];
    char kernel_name[256];

    snprintf(kernel_base, sizeof(kernel_base), "kernel_ssm_scan_%s", wsp_ggml_type_name(op->src[0]->type));
    snprintf(kernel_name, sizeof(kernel_name), "%s_nsg=%d", kernel_base, nsg);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, kernel_name);
    if (!pipeline) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, kernel_base, kernel_name, nullptr);
        wsp_ggml_metal_pipeline_set_smem(pipeline, 32*sizeof(float)*nsg);
    }

    return pipeline;
}
|
|
401
|
+
|
|
402
|
+
// Returns the RWKV WKV6/WKV7 pipeline "kernel_rwkv_wkv{6,7}_<src0 type>".
// C = op->ne[0] must split evenly into H = src0->ne[1] groups of exactly 64.
// The state input (src[5] for WKV6, src[6] for WKV7) must be F32.
// Aborts on any other op.
// The pipeline is looked up in `lib` by name first and compiled only on a miss.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rwkv(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    const int64_t C = op->ne[0];
    const int64_t H = op->src[0]->ne[1];

    const char * src0_type = wsp_ggml_type_name(op->src[0]->type);

    char kernel_base[256];
    char kernel_name[256];

    if (op->op == WSP_GGML_OP_RWKV_WKV6) {
        WSP_GGML_ASSERT(op->src[5]->type == WSP_GGML_TYPE_F32);
        WSP_GGML_ASSERT(C % H == 0);
        WSP_GGML_ASSERT(C / H == 64);

        snprintf(kernel_base, sizeof(kernel_base), "kernel_rwkv_wkv6_%s", src0_type);
    } else if (op->op == WSP_GGML_OP_RWKV_WKV7) {
        WSP_GGML_ASSERT(op->src[6]->type == WSP_GGML_TYPE_F32);
        WSP_GGML_ASSERT(C % H == 0);
        WSP_GGML_ASSERT(C / H == 64);

        snprintf(kernel_base, sizeof(kernel_base), "kernel_rwkv_wkv7_%s", src0_type);
    } else {
        WSP_GGML_ABORT("fatal error");
    }

    snprintf(kernel_name, sizeof(kernel_name), "%s", kernel_base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, kernel_name);
    if (!pipeline) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, kernel_base, kernel_name, nullptr);
    }

    return pipeline;
}
|
|
441
|
+
|
|
442
|
+
// Returns the extended mat-vec pipeline "kernel_mul_mv_ext_<tsrc0>_<tsrc1>_r1_<r1ptg>".
// The pipeline is cached under "<base>_nsg=<nsg>_nxpsg=<nxpsg>". On a cache miss it
// is compiled with nsg and nxpsg bound as int16 function constants
// FC_MUL_MV + 0 and FC_MUL_MV + 1.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mv_ext(wsp_ggml_metal_library_t lib, wsp_ggml_type tsrc0, wsp_ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
    char kernel_base[256];
    char kernel_name[256];

    snprintf(kernel_base, sizeof(kernel_base), "kernel_mul_mv_ext_%s_%s_r1_%d", wsp_ggml_type_name(tsrc0), wsp_ggml_type_name(tsrc1), r1ptg);
    snprintf(kernel_name, sizeof(kernel_name), "%s_nsg=%d_nxpsg=%d", kernel_base, nsg, nxpsg);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, kernel_name);
    if (!pipeline) {
        wsp_ggml_metal_cv_t constants = wsp_ggml_metal_cv_init();

        wsp_ggml_metal_cv_set_int16(constants, nsg,   FC_MUL_MV + 0);
        wsp_ggml_metal_cv_set_int16(constants, nxpsg, FC_MUL_MV + 1);

        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, kernel_base, kernel_name, constants);

        wsp_ggml_metal_cv_free(constants);
    }

    return pipeline;
}
|
|
465
|
+
|
|
466
|
+
// Returns the mat-mat pipeline "kernel_mul_mm_<src0>_<src1>".
// Two boolean function constants select boundary checks:
//   bc_inp (FC_MUL_MM + 0): src0->ne[0] is not a multiple of 32.
//   bc_out (FC_MUL_MM + 1): the output is not a multiple of 64x32.
// The pipeline is cached under "<base>_bci=<bc_inp>_bco=<bc_out>". On a miss it is
// compiled with those constants and its shared memory size is set.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mm(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    const wsp_ggml_type tsrc0 = op->src[0]->type;
    const wsp_ggml_type tsrc1 = op->src[1]->type;

    const bool bc_inp = (op->src[0]->ne[0] % 32) != 0;
    const bool bc_out = (op->ne[0] % 64) != 0 || (op->ne[1] % 32) != 0;

    char kernel_base[256];
    char kernel_name[256];

    snprintf(kernel_base, sizeof(kernel_base), "kernel_mul_mm_%s_%s", wsp_ggml_type_name(tsrc0), wsp_ggml_type_name(tsrc1));
    snprintf(kernel_name, sizeof(kernel_name), "%s_bci=%d_bco=%d", kernel_base, bc_inp, bc_out);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, kernel_name);
    if (!pipeline) {
        wsp_ggml_metal_cv_t constants = wsp_ggml_metal_cv_init();

        wsp_ggml_metal_cv_set_bool(constants, bc_inp, FC_MUL_MM + 0);
        wsp_ggml_metal_cv_set_bool(constants, bc_out, FC_MUL_MM + 1);

        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, kernel_base, kernel_name, constants);

        wsp_ggml_metal_cv_free(constants);

        // When the output size is not a multiple of 64x32, extra smem is needed to
        // prevent out-of-bounds writes.
        wsp_ggml_metal_pipeline_set_smem(pipeline, bc_out ? 8192 : (4096 + 2048));
    }

    return pipeline;
}
|
|
498
|
+
|
|
499
|
+
// Returns the mat-vec pipeline "kernel_mul_mv_<src0>_<src1><suffix>". The pipeline
// is cached under "<base>_nsg=<nsg>". Launch parameters are chosen per src0 type:
//   nsg  - number of simdgroups
//   nr0  - number of src0 rows per simdgroup
//   nr1  - number of src1 rows per threadgroup
//   smem - shared memory bytes
// On a cache miss the pipeline is compiled with nsg as int16 function constant
// FC_MUL_MV + 0, and nr0/nr1/nsg/smem are recorded on it.
// Aborts (after logging) for unsupported src0 types.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mv(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
    WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);

    const wsp_ggml_type tsrc0 = op->src[0]->type;
    const wsp_ggml_type tsrc1 = op->src[1]->type;

    int    nsg    = 0;
    int    nr0    = 0;
    int    nr1    = 1;
    size_t smem   = 0;
    const char * suffix = "";

    switch (tsrc0) {
        case WSP_GGML_TYPE_F32:
        case WSP_GGML_TYPE_F16:
        case WSP_GGML_TYPE_BF16:
            if (ne00 < 32) {
                // short rows: dedicated "_short" kernel with a single simdgroup
                nsg    = 1;
                nr0    = 32;
                nr1    = 1;
                suffix = "_short";
            } else {
                nsg    = std::min(4, (ne00 + 127) / 128);
                nr0    = 2;
                nr1    = 1;
                smem   = 32*sizeof(float)*nr0;
                suffix = ne00 % 4 == 0 ? "_4" : "";
            }
            break;
        case WSP_GGML_TYPE_Q4_0:    nsg = N_SG_Q4_0;    nr0 = N_R0_Q4_0;    break;
        case WSP_GGML_TYPE_Q4_1:    nsg = N_SG_Q4_1;    nr0 = N_R0_Q4_1;    break;
        case WSP_GGML_TYPE_Q5_0:    nsg = N_SG_Q5_0;    nr0 = N_R0_Q5_0;    break;
        case WSP_GGML_TYPE_Q5_1:    nsg = N_SG_Q5_1;    nr0 = N_R0_Q5_1;    break;
        case WSP_GGML_TYPE_Q8_0:    nsg = N_SG_Q8_0;    nr0 = N_R0_Q8_0;    smem = 32*sizeof(float)*N_R0_Q8_0; break;
        case WSP_GGML_TYPE_MXFP4:   nsg = N_SG_MXFP4;   nr0 = N_R0_MXFP4;   smem = 32*sizeof(float);           break;
        case WSP_GGML_TYPE_Q2_K:    nsg = N_SG_Q2_K;    nr0 = N_R0_Q2_K;    break;
        case WSP_GGML_TYPE_Q3_K:    nsg = N_SG_Q3_K;    nr0 = N_R0_Q3_K;    break;
        case WSP_GGML_TYPE_Q4_K:    nsg = N_SG_Q4_K;    nr0 = N_R0_Q4_K;    break;
        case WSP_GGML_TYPE_Q5_K:    nsg = N_SG_Q5_K;    nr0 = N_R0_Q5_K;    break;
        case WSP_GGML_TYPE_Q6_K:    nsg = N_SG_Q6_K;    nr0 = N_R0_Q6_K;    break;
        case WSP_GGML_TYPE_IQ2_XXS: nsg = N_SG_IQ2_XXS; nr0 = N_R0_IQ2_XXS; smem = 256*8+128;               break;
        case WSP_GGML_TYPE_IQ2_XS:  nsg = N_SG_IQ2_XS;  nr0 = N_R0_IQ2_XS;  smem = 512*8+128;               break;
        case WSP_GGML_TYPE_IQ3_XXS: nsg = N_SG_IQ3_XXS; nr0 = N_R0_IQ3_XXS; smem = 256*4+128;               break;
        case WSP_GGML_TYPE_IQ3_S:   nsg = N_SG_IQ3_S;   nr0 = N_R0_IQ3_S;   smem = 512*4;                   break;
        case WSP_GGML_TYPE_IQ2_S:   nsg = N_SG_IQ2_S;   nr0 = N_R0_IQ2_S;   break;
        case WSP_GGML_TYPE_IQ1_S:   nsg = N_SG_IQ1_S;   nr0 = N_R0_IQ1_S;   break;
        case WSP_GGML_TYPE_IQ1_M:   nsg = N_SG_IQ1_M;   nr0 = N_R0_IQ1_M;   break;
        case WSP_GGML_TYPE_IQ4_NL:  nsg = N_SG_IQ4_NL;  nr0 = N_R0_IQ4_NL;  smem = 32*sizeof(float);           break;
        case WSP_GGML_TYPE_IQ4_XS:  nsg = N_SG_IQ4_XS;  nr0 = N_R0_IQ4_XS;  smem = 32*sizeof(float);           break;
        default:
            WSP_GGML_LOG_ERROR("Asserting on type %d\n", (int) tsrc0);
            WSP_GGML_ABORT("not implemented");
    }

    char kernel_base[256];
    char kernel_name[256];

    snprintf(kernel_base, sizeof(kernel_base), "kernel_mul_mv_%s_%s%s", wsp_ggml_type_name(tsrc0), wsp_ggml_type_name(tsrc1), suffix);
    snprintf(kernel_name, sizeof(kernel_name), "%s_nsg=%d", kernel_base, nsg);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, kernel_name);
    if (!pipeline) {
        wsp_ggml_metal_cv_t constants = wsp_ggml_metal_cv_init();

        wsp_ggml_metal_cv_set_int16(constants, nsg, FC_MUL_MV + 0);

        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, kernel_base, kernel_name, constants);

        wsp_ggml_metal_cv_free(constants);

        wsp_ggml_metal_pipeline_set_nr0 (pipeline, nr0);
        wsp_ggml_metal_pipeline_set_nr1 (pipeline, nr1);
        wsp_ggml_metal_pipeline_set_nsg (pipeline, nsg);
        wsp_ggml_metal_pipeline_set_smem(pipeline, smem);
    }

    return pipeline;
}
|
|
674
|
+
|
|
675
|
+
// Returns the pipeline "kernel_mul_mm_id_map0_ne20_<ne20>".
// The pipeline is looked up in `lib` by name first. Only on a miss is it compiled
// and given ne02*ne20 uint16_t entries of shared memory.
// NOTE(review): the cache key does not include ne02, so smem is sized by whichever
// ne02 first compiled this name; confirm ne02 is fixed per ne20.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mm_id_map0(wsp_ggml_metal_library_t lib, int ne02, int ne20) {
    char kernel_base[256];
    char kernel_name[256];

    snprintf(kernel_base, sizeof(kernel_base), "kernel_mul_mm_id_map0_ne20_%d", ne20);
    snprintf(kernel_name, sizeof(kernel_name), "%s", kernel_base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, kernel_name);
    if (!pipeline) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, kernel_base, kernel_name, nullptr);

        const size_t smem_bytes = (size_t) ne02*ne20*sizeof(uint16_t);
        wsp_ggml_metal_pipeline_set_smem(pipeline, smem_bytes);
    }

    return pipeline;
}
|
|
695
|
+
|
|
696
|
+
// Pipeline for kernel_mul_mm_id_<src0 type>_<src1 type>.
//
// bc_inp is true when src0->ne[0] is not a multiple of 32. It is passed as function constant
// FC_MUL_MM + 0 and is part of the cache key. On first compile a fixed shared-memory size of
// 8192 is attached. A cache hit returns the pipeline as-is.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mm_id(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    const wsp_ggml_tensor * src0 = op->src[0];
    const wsp_ggml_tensor * src1 = op->src[1];

    // presumably enables bounds checks on the src0 rows inside the kernel — TODO confirm
    const bool bc_inp = (src0->ne[0] % 32) != 0;

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "kernel_mul_mm_id_%s_%s", wsp_ggml_type_name(src0->type), wsp_ggml_type_name(src1->type));
    snprintf(name, sizeof(name), "%s_bci=%d", base, bc_inp);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        wsp_ggml_metal_cv_t consts = wsp_ggml_metal_cv_init();
        wsp_ggml_metal_cv_set_bool(consts, bc_inp, FC_MUL_MM + 0);

        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, consts);
        wsp_ggml_metal_cv_free(consts);

        wsp_ggml_metal_pipeline_set_smem(pipeline, 8192);
    }

    return pipeline;
}
|
|
725
|
+
|
|
726
|
+
// Pipeline for kernel_mul_mv_id_<src0 type>_<src1 type>[_4].
//
// Per src0 type this selects:
//   nsg  - number of simdgroups (also function constant FC_MUL_MV + 0 and part of the cache key)
//   nr0  - number of src0 rows per simdgroup
//   nr1  - number of src1 rows per threadgroup
//   smem - shared-memory size
// These values are attached to the pipeline only when it is first compiled.
// Aborts ("not implemented") for src0 types without a kernel.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mv_id(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
    WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);

    char base[256];
    char name[256];

    int nsg = 0; // number of simdgroups
    int nr0 = 0; // number of src0 rows per simdgroup
    int nr1 = 1; // number of src1 rows per threadgroup

    size_t smem = 0; // shared memory

    const wsp_ggml_type tsrc0 = op->src[0]->type;
    const wsp_ggml_type tsrc1 = op->src[1]->type;

    const char * suffix = "";

    // use custom matrix x vector kernel
    switch (tsrc0) {
        case WSP_GGML_TYPE_F32:
        case WSP_GGML_TYPE_F16:
        case WSP_GGML_TYPE_BF16:
            {
                nsg = std::min(4, (ne00 + 127) / 128);
                nr0 = 2;
                nr1 = 1;
                smem = 32*sizeof(float)*nr0;
                suffix = ne00 % 4 == 0 ? "_4" : "";
            } break;
        case WSP_GGML_TYPE_Q4_0:
            {
                nsg = N_SG_Q4_0;
                nr0 = N_R0_Q4_0;
            } break;
        case WSP_GGML_TYPE_Q4_1:
            {
                nsg = N_SG_Q4_1;
                nr0 = N_R0_Q4_1;
            } break;
        case WSP_GGML_TYPE_Q5_0:
            {
                nsg = N_SG_Q5_0;
                nr0 = N_R0_Q5_0;
            } break;
        case WSP_GGML_TYPE_Q5_1:
            {
                nsg = N_SG_Q5_1;
                nr0 = N_R0_Q5_1;
            } break;
        case WSP_GGML_TYPE_Q8_0:
            {
                nsg = N_SG_Q8_0;
                nr0 = N_R0_Q8_0;
                smem = 32*sizeof(float)*N_R0_Q8_0;
            } break;
        case WSP_GGML_TYPE_MXFP4:
            {
                nsg = N_SG_MXFP4;
                nr0 = N_R0_MXFP4;
                smem = 32*sizeof(float);
            } break;
        case WSP_GGML_TYPE_Q2_K:
            {
                nsg = N_SG_Q2_K;
                nr0 = N_R0_Q2_K;
            } break;
        case WSP_GGML_TYPE_Q3_K:
            {
                nsg = N_SG_Q3_K;
                nr0 = N_R0_Q3_K;
            } break;
        case WSP_GGML_TYPE_Q4_K:
            {
                nsg = N_SG_Q4_K;
                nr0 = N_R0_Q4_K;
            } break;
        case WSP_GGML_TYPE_Q5_K:
            {
                nsg = N_SG_Q5_K;
                nr0 = N_R0_Q5_K;
            } break;
        case WSP_GGML_TYPE_Q6_K:
            {
                nsg = N_SG_Q6_K;
                nr0 = N_R0_Q6_K;
            } break;
        case WSP_GGML_TYPE_IQ2_XXS:
            {
                nsg = N_SG_IQ2_XXS;
                nr0 = N_R0_IQ2_XXS;
                smem = 256*8+128;
            } break;
        case WSP_GGML_TYPE_IQ2_XS:
            {
                nsg = N_SG_IQ2_XS;
                nr0 = N_R0_IQ2_XS;
                smem = 512*8+128;
            } break;
        case WSP_GGML_TYPE_IQ3_XXS:
            {
                nsg = N_SG_IQ3_XXS;
                nr0 = N_R0_IQ3_XXS;
                smem = 256*4+128;
            } break;
        case WSP_GGML_TYPE_IQ3_S:
            {
                nsg = N_SG_IQ3_S;
                nr0 = N_R0_IQ3_S;
                smem = 512*4;
            } break;
        case WSP_GGML_TYPE_IQ2_S:
            {
                nsg = N_SG_IQ2_S;
                nr0 = N_R0_IQ2_S;
            } break;
        case WSP_GGML_TYPE_IQ1_S:
            {
                nsg = N_SG_IQ1_S;
                nr0 = N_R0_IQ1_S;
            } break;
        case WSP_GGML_TYPE_IQ1_M:
            {
                nsg = N_SG_IQ1_M;
                nr0 = N_R0_IQ1_M;
            } break;
        case WSP_GGML_TYPE_IQ4_NL:
            {
                nsg = N_SG_IQ4_NL;
                nr0 = N_R0_IQ4_NL;
                smem = 32*sizeof(float);
            } break;
        case WSP_GGML_TYPE_IQ4_XS:
            {
                nsg = N_SG_IQ4_XS;
                nr0 = N_R0_IQ4_XS;
                smem = 32*sizeof(float);
            } break;
        default:
            {
                // the switch dispatches on tsrc0 (src[0]), so that is the unsupported type to report
                WSP_GGML_LOG_ERROR("Asserting on type %d\n", (int) tsrc0);
                WSP_GGML_ABORT("not implemented");
            }
    }

    snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", wsp_ggml_type_name(tsrc0), wsp_ggml_type_name(tsrc1), suffix);
    snprintf(name, 256, "%s_nsg=%d", base, nsg);

    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (res) {
        return res;
    }

    wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();

    wsp_ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);

    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);

    wsp_ggml_metal_cv_free(cv);

    wsp_ggml_metal_pipeline_set_nr0 (res, nr0);
    wsp_ggml_metal_pipeline_set_nr1 (res, nr1);
    wsp_ggml_metal_pipeline_set_nsg (res, nsg);
    wsp_ggml_metal_pipeline_set_smem(res, smem);

    return res;
}
|
|
894
|
+
|
|
895
|
+
// Pipeline for kernel_argmax_<src0 type>.
//
// src0 must be F32, contiguous in dim 1 (wsp_ggml_is_contiguous_1), with element-sized nb[0].
// On first compile the shared-memory size is set to 32 float/int32_t pairs.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argmax(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(op->src[0]));
    WSP_GGML_ASSERT(op->src[0]->nb[0] == wsp_ggml_type_size(op->src[0]->type));

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "kernel_argmax_%s", wsp_ggml_type_name(op->src[0]->type));
    snprintf(name, sizeof(name), "%s", base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);

        const size_t bytes_per_slot = sizeof(float) + sizeof(int32_t);
        wsp_ggml_metal_pipeline_set_smem(pipeline, 32*bytes_per_slot);
    }

    return pipeline;
}
|
|
917
|
+
|
|
918
|
+
// Pipeline for kernel_argsort_<src0 type>_<dst type>_<asc|desc>.
//
// The sort order is read from op->op_params[0]. Any value other than ASC or DESC aborts.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    assert(op->op == WSP_GGML_OP_ARGSORT);

    const char * order_str = "undefined";
    switch ((wsp_ggml_sort_order) op->op_params[0]) {
        case WSP_GGML_SORT_ORDER_ASC:
            order_str = "asc";
            break;
        case WSP_GGML_SORT_ORDER_DESC:
            order_str = "desc";
            break;
        default:
            WSP_GGML_ABORT("fatal error");
    }

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "kernel_argsort_%s_%s_%s",
            wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->type), order_str);
    snprintf(name, sizeof(name), "%s", base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return pipeline;
}
|
|
945
|
+
|
|
946
|
+
// Pipeline for kernel_flash_attn_ext_pad.
//
// Specialized through two function constants, both part of the cache key:
//   FC_FLASH_ATTN_EXT_PAD + 0  <- has_mask
//   FC_FLASH_ATTN_EXT_PAD + 25 <- ncpsg
// No other slots of the FC_FLASH_ATTN_EXT_PAD range are set for this kernel.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(
        wsp_ggml_metal_library_t lib,
        const struct wsp_ggml_tensor * op,
        bool has_mask,
        int32_t ncpsg) {
    assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
    WSP_GGML_UNUSED(op);

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "%s", "kernel_flash_attn_ext_pad");
    snprintf(name, sizeof(name), "%s_mask=%d_ncpsg=%d", base, has_mask, ncpsg);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        wsp_ggml_metal_cv_t consts = wsp_ggml_metal_cv_init();

        wsp_ggml_metal_cv_set_bool (consts, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
        wsp_ggml_metal_cv_set_int32(consts, ncpsg,    FC_FLASH_ATTN_EXT_PAD + 25);

        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, consts);
        wsp_ggml_metal_cv_free(consts);
    }

    return pipeline;
}
|
|
990
|
+
|
|
991
|
+
// Pipeline for kernel_flash_attn_ext_blk.
//
// Specialized through two function constants, both part of the cache key:
//   FC_FLASH_ATTN_EXT_BLK + 24 <- nqptg
//   FC_FLASH_ATTN_EXT_BLK + 25 <- ncpsg
// No other slots of the FC_FLASH_ATTN_EXT_BLK range are set for this kernel.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(
        wsp_ggml_metal_library_t lib,
        const struct wsp_ggml_tensor * op,
        int32_t nqptg,
        int32_t ncpsg) {
    assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
    WSP_GGML_UNUSED(op);

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "%s", "kernel_flash_attn_ext_blk");
    snprintf(name, sizeof(name), "%s_nqptg=%d_ncpsg=%d", base, nqptg, ncpsg);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        wsp_ggml_metal_cv_t consts = wsp_ggml_metal_cv_init();

        wsp_ggml_metal_cv_set_int32(consts, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
        wsp_ggml_metal_cv_set_int32(consts, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);

        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, consts);
        wsp_ggml_metal_cv_free(consts);
    }

    return pipeline;
}
|
|
1035
|
+
|
|
1036
|
+
// Pipeline for kernel_flash_attn_ext_<src1 type>_dk<D>_dv<D>.
//
// dk and dv are src1->ne[0] and src2->ne[0] and select the kernel function. Every specialization
// flag below is passed as a function constant (offset from FC_FLASH_ATTN_EXT) and also appears
// in the cache key, so each distinct combination compiles its own pipeline.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
        wsp_ggml_metal_library_t lib,
        const wsp_ggml_tensor * op,
        bool has_mask,
        bool has_sinks,
        bool has_bias,
        bool has_scap,
        bool has_kvpad,
        int32_t nsg) {
    assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);

    const wsp_ggml_tensor * src1 = op->src[1];
    const wsp_ggml_tensor * src2 = op->src[2];
    const wsp_ggml_tensor * src3 = op->src[3];

    const int32_t dk = (int32_t) src1->ne[0];
    const int32_t dv = (int32_t) src2->ne[0];

    // dim-1 stride expressed in units of the dim-0 stride (nb[1]/nb[0])
    const int32_t ns10 = src1->nb[1]/src1->nb[0];
    const int32_t ns20 = src2->nb[1]/src2->nb[0];

    // mask bounds checks are requested when src3 exists and its ne[1] is not a multiple of 8
    const bool bc_mask = src3 && (src3->ne[1] % 8 != 0);

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "kernel_%s_%s_dk%d_dv%d",
            "flash_attn_ext", wsp_ggml_type_name(src1->type), dk, dv);
    snprintf(name, sizeof(name), "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
            base, has_mask, has_sinks, has_bias, has_scap, has_kvpad, bc_mask, ns10, ns20, nsg);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        wsp_ggml_metal_cv_t consts = wsp_ggml_metal_cv_init();

        // boolean feature switches
        wsp_ggml_metal_cv_set_bool(consts, has_mask,  FC_FLASH_ATTN_EXT + 0);
        wsp_ggml_metal_cv_set_bool(consts, has_sinks, FC_FLASH_ATTN_EXT + 1);
        wsp_ggml_metal_cv_set_bool(consts, has_bias,  FC_FLASH_ATTN_EXT + 2);
        wsp_ggml_metal_cv_set_bool(consts, has_scap,  FC_FLASH_ATTN_EXT + 3);
        wsp_ggml_metal_cv_set_bool(consts, has_kvpad, FC_FLASH_ATTN_EXT + 4);

        wsp_ggml_metal_cv_set_bool(consts, bc_mask,   FC_FLASH_ATTN_EXT + 10);

        // integer parameters
        wsp_ggml_metal_cv_set_int32(consts, ns10, FC_FLASH_ATTN_EXT + 20);
        wsp_ggml_metal_cv_set_int32(consts, ns20, FC_FLASH_ATTN_EXT + 21);
        wsp_ggml_metal_cv_set_int32(consts, nsg,  FC_FLASH_ATTN_EXT + 22);

        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, consts);
        wsp_ggml_metal_cv_free(consts);
    }

    return pipeline;
}
|
|
1102
|
+
|
|
1103
|
+
// Pipeline for kernel_flash_attn_ext_vec_<src1 type>_dk<D>_dv<D>.
//
// Same specialization scheme as the non-vec flash-attention getter, using the
// FC_FLASH_ATTN_EXT_VEC constant range. It adds nwg (slot +23) and has no mask-bounds-check flag.
// All flags and counts are part of the cache key.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(
        wsp_ggml_metal_library_t lib,
        const wsp_ggml_tensor * op,
        bool has_mask,
        bool has_sinks,
        bool has_bias,
        bool has_scap,
        bool has_kvpad,
        int32_t nsg,
        int32_t nwg) {
    assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);

    const wsp_ggml_tensor * src1 = op->src[1];
    const wsp_ggml_tensor * src2 = op->src[2];

    const int32_t dk = (int32_t) src1->ne[0];
    const int32_t dv = (int32_t) src2->ne[0];

    // dim-1 stride expressed in units of the dim-0 stride (nb[1]/nb[0])
    const int32_t ns10 = src1->nb[1]/src1->nb[0];
    const int32_t ns20 = src2->nb[1]/src2->nb[0];

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "kernel_%s_%s_dk%d_dv%d",
            "flash_attn_ext_vec", wsp_ggml_type_name(src1->type), dk, dv);
    snprintf(name, sizeof(name), "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
            base, has_mask, has_sinks, has_bias, has_scap, has_kvpad, ns10, ns20, nsg, nwg);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        wsp_ggml_metal_cv_t consts = wsp_ggml_metal_cv_init();

        // boolean feature switches
        wsp_ggml_metal_cv_set_bool(consts, has_mask,  FC_FLASH_ATTN_EXT_VEC + 0);
        wsp_ggml_metal_cv_set_bool(consts, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
        wsp_ggml_metal_cv_set_bool(consts, has_bias,  FC_FLASH_ATTN_EXT_VEC + 2);
        wsp_ggml_metal_cv_set_bool(consts, has_scap,  FC_FLASH_ATTN_EXT_VEC + 3);
        wsp_ggml_metal_cv_set_bool(consts, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);

        // integer parameters
        wsp_ggml_metal_cv_set_int32(consts, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
        wsp_ggml_metal_cv_set_int32(consts, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
        wsp_ggml_metal_cv_set_int32(consts, nsg,  FC_FLASH_ATTN_EXT_VEC + 22);
        wsp_ggml_metal_cv_set_int32(consts, nwg,  FC_FLASH_ATTN_EXT_VEC + 23);

        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, consts);
        wsp_ggml_metal_cv_free(consts);
    }

    return pipeline;
}
|
|
1165
|
+
|
|
1166
|
+
// Pipeline for kernel_flash_attn_ext_vec_reduce.
//
// dv and nwg are passed as function constants FC_FLASH_ATTN_EXT_VEC_REDUCE + 0 and + 1 and are
// part of the cache key. op is only inspected by the debug assert.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
        wsp_ggml_metal_library_t lib,
        const wsp_ggml_tensor * op,
        int32_t dv,
        int32_t nwg) {
    assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
    // op is only read by the assert above; mark it used so NDEBUG builds do not warn.
    // (Previously this sat after the return statement, where it was unreachable.)
    WSP_GGML_UNUSED(op);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
    snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);

    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (res) {
        return res;
    }

    wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();

    wsp_ggml_metal_cv_set_int32(cv, dv,  FC_FLASH_ATTN_EXT_VEC_REDUCE + 0);
    wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1);

    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);

    wsp_ggml_metal_cv_free(cv);

    return res;
}
|
|
1197
|
+
|
|
1198
|
+
// Pipeline for the elementwise binary kernels (add/sub/mul/div).
//
// Kernel name is kernel_<op>[_row_c4]_fuse_<n_fuse>. The "_row_c4" variant is selected by `row`.
// Any other op aborts.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_bin(
        wsp_ggml_metal_library_t lib,
        wsp_ggml_op op,
        int32_t n_fuse,
        bool row) {
    const char * op_str = "undefined";
    switch (op) {
        case WSP_GGML_OP_ADD: op_str = "add"; break;
        case WSP_GGML_OP_SUB: op_str = "sub"; break;
        case WSP_GGML_OP_MUL: op_str = "mul"; break;
        case WSP_GGML_OP_DIV: op_str = "div"; break;
        default: WSP_GGML_ABORT("fatal error");
    }

    const char * variant = row ? "_row_c4" : "";

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "kernel_%s%s_fuse_%d", op_str, variant, n_fuse);
    snprintf(name, sizeof(name), "%s", base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return pipeline;
}
|
|
1232
|
+
|
|
1233
|
+
// Pipeline for kernel_l2_norm_f32.
//
// Requires src0->ne[0] to be a multiple of 4 and src0 to be contiguous in dim 1.
// On first compile the shared-memory size is set to 32 floats.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_l2_norm(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    assert(op->op == WSP_GGML_OP_L2_NORM);

    WSP_GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(op->src[0]));

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "kernel_l2_norm_f32");
    snprintf(name, sizeof(name), "%s", base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
        wsp_ggml_metal_pipeline_set_smem(pipeline, 32*sizeof(float));
    }

    return pipeline;
}
|
|
1256
|
+
|
|
1257
|
+
// Pipeline for kernel_group_norm_f32.
//
// Requires a fully contiguous src0. On first compile the shared-memory size is set to 32 floats.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_group_norm(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    assert(op->op == WSP_GGML_OP_GROUP_NORM);

    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "kernel_group_norm_f32");
    snprintf(name, sizeof(name), "%s", base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
        wsp_ggml_metal_pipeline_set_smem(pipeline, 32*sizeof(float));
    }

    return pipeline;
}
|
|
1279
|
+
|
|
1280
|
+
// Pipeline for NORM / RMS_NORM, optionally fused with trailing mul (n_fuse == 2) or
// mul+add (n_fuse == 3).
//
// Kernel name is kernel_<norm|rms_norm>[_mul|_mul_add]_f32[_4]. The "_4" variant is chosen when
// op->ne[0] is a multiple of 4. Unsupported ops or n_fuse values abort.
// On first compile the shared-memory size is set to 32 floats.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_norm(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op, int n_fuse) {
    assert(op->op == WSP_GGML_OP_NORM || op->op == WSP_GGML_OP_RMS_NORM);

    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(op->src[0]));

    const char * suffix = (op->ne[0] % 4 == 0) ? "_4" : "";

    const char * op_str = nullptr;
    switch (op->op) {
        case WSP_GGML_OP_NORM:     op_str = "norm";     break;
        case WSP_GGML_OP_RMS_NORM: op_str = "rms_norm"; break;
        default: WSP_GGML_ABORT("fatal error");
    }

    const char * fuse_str = nullptr;
    switch (n_fuse) {
        case 1: fuse_str = "";         break;
        case 2: fuse_str = "_mul";     break;
        case 3: fuse_str = "_mul_add"; break;
        default: WSP_GGML_ABORT("fatal error");
    }

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "kernel_%s%s_f32%s", op_str, fuse_str, suffix);
    snprintf(name, sizeof(name), "%s", base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
        wsp_ggml_metal_pipeline_set_smem(pipeline, 32*sizeof(float));
    }

    return pipeline;
}
|
|
1324
|
+
|
|
1325
|
+
// Pipeline for the rope kernels: kernel_rope_<neox|multi|vision|norm>_<src0 type>.
//
// The variant comes from the mode bits in op_params[2]:
//   NEOX bit set          -> "neox"
//   mode == VISION        -> "vision"
//   MROPE bit, not VISION -> "multi"
//   otherwise             -> "norm"
// The multi and vision variants require at least 4 positions per token in src1.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rope(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    assert(op->op == WSP_GGML_OP_ROPE);

    const int mode = ((const int32_t *) op->op_params)[2];

    const bool is_neox   = mode & WSP_GGML_ROPE_TYPE_NEOX;
    const bool is_mrope  = mode & WSP_GGML_ROPE_TYPE_MROPE;
    const bool is_vision = mode == WSP_GGML_ROPE_TYPE_VISION;

    const char * variant = "norm";
    if (is_neox) {
        variant = "neox";
    } else if (is_mrope || is_vision) {
        WSP_GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
        variant = is_vision ? "vision" : "multi";
    }

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "kernel_rope_%s_%s", variant, wsp_ggml_type_name(op->src[0]->type));
    snprintf(name, sizeof(name), "%s", base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return pipeline;
}
|
|
1360
|
+
|
|
1361
|
+
// Pipeline for kernel_im2col_<dst type>.
//
// Requires a contiguous F32 src1 and an F16 or F32 destination.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_im2col(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    assert(op->op == WSP_GGML_OP_IM2COL);

    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[1]));
    WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
    WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_F16 || op->type == WSP_GGML_TYPE_F32);

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "kernel_im2col_%s", wsp_ggml_type_name(op->type));
    snprintf(name, sizeof(name), "%s", base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return pipeline;
}
|
|
1383
|
+
|
|
1384
|
+
// Pipeline for kernel_conv_transpose_1d_<src0 type>_<src1 type>.
//
// Requires contiguous inputs, an F16/F32 src0, an F32 src1 and an F32 destination.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_1d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    assert(op->op == WSP_GGML_OP_CONV_TRANSPOSE_1D);

    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[1]));
    WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32);
    WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
    WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_F32);

    const char * t0 = wsp_ggml_type_name(op->src[0]->type);
    const char * t1 = wsp_ggml_type_name(op->src[1]->type);

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "kernel_conv_transpose_1d_%s_%s", t0, t1);
    snprintf(name, sizeof(name), "%s", base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return pipeline;
}
|
|
1408
|
+
|
|
1409
|
+
// Pipeline for kernel_conv_transpose_2d_<src0 type>_<src1 type>.
//
// Requires contiguous inputs, an F16/F32 src0, an F32 src1 and an F32 destination.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    assert(op->op == WSP_GGML_OP_CONV_TRANSPOSE_2D);

    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[1]));
    WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32);
    WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
    WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_F32);

    const char * t0 = wsp_ggml_type_name(op->src[0]->type);
    const char * t1 = wsp_ggml_type_name(op->src[1]->type);

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "kernel_conv_transpose_2d_%s_%s", t0, t1);
    snprintf(name, sizeof(name), "%s", base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return pipeline;
}
|
|
1433
|
+
|
|
1434
|
+
// Pipeline for kernel_upscale_<src0 type>; looked up by name in `lib`, compiled on a miss.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_upscale(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    assert(op->op == WSP_GGML_OP_UPSCALE);

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "kernel_upscale_%s", wsp_ggml_type_name(op->src[0]->type));
    snprintf(name, sizeof(name), "%s", base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return pipeline;
}
|
|
1452
|
+
|
|
1453
|
+
// Pipeline for kernel_pad_<src0 type>; looked up by name in `lib`, compiled on a miss.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    assert(op->op == WSP_GGML_OP_PAD);

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "kernel_pad_%s", wsp_ggml_type_name(op->src[0]->type));
    snprintf(name, sizeof(name), "%s", base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return pipeline;
}
|
|
1471
|
+
|
|
1472
|
+
// Pipeline for kernel_pad_reflect_1d_<src0 type>; looked up by name in `lib`, compiled on a miss.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad_reflect_1d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    assert(op->op == WSP_GGML_OP_PAD_REFLECT_1D);

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "kernel_pad_reflect_1d_%s", wsp_ggml_type_name(op->src[0]->type));
    snprintf(name, sizeof(name), "%s", base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return pipeline;
}
|
|
1490
|
+
|
|
1491
|
+
// Pipeline for kernel_arange_<dst type>; keyed on the output type (op->type), not on a source.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_arange(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    assert(op->op == WSP_GGML_OP_ARANGE);

    char base[256];
    char name[256];
    snprintf(base, sizeof(base), "kernel_arange_%s", wsp_ggml_type_name(op->type));
    snprintf(name, sizeof(name), "%s", base);

    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (pipeline == nullptr) {
        pipeline = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return pipeline;
}
|
|
1509
|
+
|
|
1510
|
+
// Returns the Metal pipeline for a WSP_GGML_OP_TIMESTEP_EMBEDDING node.
//
// The kernel name is "kernel_timestep_embedding_<type>", where <type> is the
// type name of op->src[0]. A pipeline already registered in `lib` under that
// name is returned directly; otherwise it is compiled via
// wsp_ggml_metal_library_compile_pipeline (no extra compile arguments) and
// that result is returned as-is.
// NOTE(review): presumably res is null if compilation fails — confirm in
// wsp_ggml_metal_library_compile_pipeline.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_timestep_embedding(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    assert(op->op == WSP_GGML_OP_TIMESTEP_EMBEDDING);

    char base[256];
    char name[256];

    // Use sizeof rather than a repeated literal so the bound always matches the buffer.
    snprintf(base, sizeof base, "kernel_timestep_embedding_%s", wsp_ggml_type_name(op->src[0]->type));
    snprintf(name, sizeof name, "%s", base);

    // Fast path: reuse a previously built pipeline.
    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (res) {
        return res;
    }

    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);

    return res;
}
|
|
1528
|
+
|
|
1529
|
+
// Returns the Metal pipeline for a WSP_GGML_OP_OPT_STEP_ADAMW node.
//
// The kernel name is "kernel_opt_step_adamw_<type>", where <type> is the type
// name of op->src[0]. A pipeline already registered in `lib` under that name
// is returned directly; otherwise it is compiled via
// wsp_ggml_metal_library_compile_pipeline (no extra compile arguments) and
// that result is returned as-is.
// NOTE(review): presumably res is null if compilation fails — confirm in
// wsp_ggml_metal_library_compile_pipeline.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_adamw(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    assert(op->op == WSP_GGML_OP_OPT_STEP_ADAMW);

    char base[256];
    char name[256];

    // Use sizeof rather than a repeated literal so the bound always matches the buffer.
    snprintf(base, sizeof base, "kernel_opt_step_adamw_%s", wsp_ggml_type_name(op->src[0]->type));
    snprintf(name, sizeof name, "%s", base);

    // Fast path: reuse a previously built pipeline.
    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (res) {
        return res;
    }

    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);

    return res;
}
|
|
1547
|
+
|
|
1548
|
+
// Returns the Metal pipeline for a WSP_GGML_OP_OPT_STEP_SGD node.
//
// The kernel name is "kernel_opt_step_sgd_<type>", where <type> is the type
// name of op->src[0]. A pipeline already registered in `lib` under that name
// is returned directly; otherwise it is compiled via
// wsp_ggml_metal_library_compile_pipeline (no extra compile arguments) and
// that result is returned as-is.
// NOTE(review): presumably res is null if compilation fails — confirm in
// wsp_ggml_metal_library_compile_pipeline.
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_sgd(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
    assert(op->op == WSP_GGML_OP_OPT_STEP_SGD);

    char base[256];
    char name[256];

    // Use sizeof rather than a repeated literal so the bound always matches the buffer.
    snprintf(base, sizeof base, "kernel_opt_step_sgd_%s", wsp_ggml_type_name(op->src[0]->type));
    snprintf(name, sizeof name, "%s", base);

    // Fast path: reuse a previously built pipeline.
    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
    if (res) {
        return res;
    }

    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);

    return res;
}
|