whisper.rn 0.3.9 → 0.4.0-rc.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +2 -1
- package/android/src/main/jni.cpp +7 -1
- package/cpp/coreml/whisper-encoder.mm +7 -1
- package/cpp/ggml-alloc.c +633 -0
- package/cpp/ggml-alloc.h +26 -0
- package/cpp/ggml-metal.h +85 -0
- package/cpp/ggml-metal.m +1283 -0
- package/cpp/ggml-metal.metal +2353 -0
- package/cpp/ggml.c +5024 -2924
- package/cpp/ggml.h +569 -95
- package/cpp/whisper.cpp +1014 -667
- package/cpp/whisper.h +13 -0
- package/ios/RNWhisper.mm +2 -0
- package/ios/RNWhisperContext.h +1 -1
- package/ios/RNWhisperContext.mm +18 -4
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/index.js +3 -1
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/index.js +3 -1
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +1 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +3 -1
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +1 -0
- package/src/index.ts +4 -0
- package/whisper-rn.podspec +8 -2
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +0 -4
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +0 -8
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
- package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +0 -19
package/cpp/ggml-metal.m
ADDED
|
@@ -0,0 +1,1283 @@
|
|
|
1
|
+
#import "ggml-metal.h"
|
|
2
|
+
|
|
3
|
+
#import "ggml.h"
|
|
4
|
+
|
|
5
|
+
#import <Foundation/Foundation.h>
|
|
6
|
+
|
|
7
|
+
#import <Metal/Metal.h>
|
|
8
|
+
|
|
9
|
+
#undef MIN
|
|
10
|
+
#undef MAX
|
|
11
|
+
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
|
12
|
+
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
|
13
|
+
|
|
14
|
+
// TODO: temporary - reuse llama.cpp logging
|
|
15
|
+
#ifdef WSP_GGML_METAL_NDEBUG
|
|
16
|
+
#define metal_printf(...)
|
|
17
|
+
#else
|
|
18
|
+
#define metal_printf(...) fprintf(stderr, __VA_ARGS__)
|
|
19
|
+
#endif
|
|
20
|
+
|
|
21
|
+
#define UNUSED(x) (void)(x)
|
|
22
|
+
|
|
23
|
+
#define WSP_GGML_MAX_CONCUR (2*WSP_GGML_MAX_NODES)
|
|
24
|
+
|
|
25
|
+
// A host memory region registered with the Metal device as a shared
// (no-copy) MTLBuffer; created by wsp_ggml_metal_add_buffer().
struct wsp_ggml_metal_buffer {
    const char * name;      // label used only in log messages

    void * data;            // host pointer backing the buffer
    size_t size;            // size in bytes of the region visible to the GPU

    id<MTLBuffer> metal;    // Metal buffer wrapping `data` (storage mode shared)
};
|
|
33
|
+
|
|
34
|
+
// Backend state: device/queue handles, compiled compute pipelines, the set of
// registered shared host/GPU buffers, and the concurrency plan produced by
// wsp_ggml_metal_graph_find_concurrency().
struct wsp_ggml_metal_context {
    int n_cb;   // number of command buffers used to encode a graph in parallel

    id<MTLDevice>       device;
    id<MTLCommandQueue> queue;
    id<MTLLibrary>      library;

    id<MTLCommandBuffer>         command_buffers [WSP_GGML_METAL_MAX_COMMAND_BUFFERS];
    id<MTLComputeCommandEncoder> command_encoders[WSP_GGML_METAL_MAX_COMMAND_BUFFERS];

    dispatch_queue_t d_queue;   // concurrent queue used to encode command buffers in parallel

    int n_buffers;              // number of entries used in buffers[]
    struct wsp_ggml_metal_buffer buffers[WSP_GGML_METAL_MAX_BUFFERS];

    // flattened execution plan: node indices grouped into concurrent levels,
    // each level terminated by a -1 barrier marker
    int concur_list[WSP_GGML_MAX_CONCUR];
    int concur_list_len;

    // custom kernels
#define WSP_GGML_METAL_DECL_KERNEL(name) \
    id<MTLFunction>             function_##name; \
    id<MTLComputePipelineState> pipeline_##name

    WSP_GGML_METAL_DECL_KERNEL(add);
    WSP_GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
    WSP_GGML_METAL_DECL_KERNEL(mul);
    WSP_GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
    WSP_GGML_METAL_DECL_KERNEL(scale);
    WSP_GGML_METAL_DECL_KERNEL(silu);
    WSP_GGML_METAL_DECL_KERNEL(relu);
    WSP_GGML_METAL_DECL_KERNEL(gelu);
    WSP_GGML_METAL_DECL_KERNEL(soft_max);
    WSP_GGML_METAL_DECL_KERNEL(soft_max_4);
    WSP_GGML_METAL_DECL_KERNEL(diag_mask_inf);
    WSP_GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
    WSP_GGML_METAL_DECL_KERNEL(get_rows_f32);
    WSP_GGML_METAL_DECL_KERNEL(get_rows_f16);
    WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_0);
    WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_1);
    WSP_GGML_METAL_DECL_KERNEL(get_rows_q8_0);
    WSP_GGML_METAL_DECL_KERNEL(get_rows_q2_K);
    WSP_GGML_METAL_DECL_KERNEL(get_rows_q3_K);
    WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_K);
    WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_K);
    WSP_GGML_METAL_DECL_KERNEL(get_rows_q6_K);
    WSP_GGML_METAL_DECL_KERNEL(rms_norm);
    WSP_GGML_METAL_DECL_KERNEL(norm);
    WSP_GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
    WSP_GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
    WSP_GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
    WSP_GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
    WSP_GGML_METAL_DECL_KERNEL(rope);
    WSP_GGML_METAL_DECL_KERNEL(alibi_f32);
    WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f16);
    WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f32);
    WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f16);

#undef WSP_GGML_METAL_DECL_KERNEL
};
|
|
111
|
+
|
|
112
|
+
// MSL code
|
|
113
|
+
// TODO: move the contents here when ready
|
|
114
|
+
// for now it is easier to work in a separate file
|
|
115
|
+
// placeholder only — the actual MSL source lives in ggml-metal.metal and is
// loaded at runtime by wsp_ggml_metal_init()
static NSString * const msl_library_source = @"see metal.metal";

// Here to assist with NSBundle Path Hack
// (empty class whose only purpose is to let wsp_ggml_metal_init locate this
// library's resource bundle via +[NSBundle bundleForClass:])
@interface GGMLMetalClass : NSObject
@end
@implementation GGMLMetalClass
@end
|
|
122
|
+
|
|
123
|
+
// Create and initialize a Metal compute context.
//
// n_cb - requested number of command buffers used to encode a graph in
//        parallel (clamped to the capacity of the command_buffers[] array).
//
// Returns a heap-allocated context, or NULL if the Metal shader library
// cannot be located/compiled or any compute pipeline fails to build.
struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
    metal_printf("%s: allocating\n", __func__);

    id <MTLDevice> device;
    NSString * s;

#if TARGET_OS_OSX
    // Show all the Metal device instances in the system
    NSArray * devices = MTLCopyAllDevices();
    for (device in devices) {
        s = [device name];
        metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
    }
#endif

    // Pick and show default Metal device
    device = MTLCreateSystemDefaultDevice();
    s = [device name];
    metal_printf("%s: picking default device: %s\n", __func__, [s UTF8String]);

    // Configure context
    struct wsp_ggml_metal_context * ctx = malloc(sizeof(struct wsp_ggml_metal_context));
    ctx->device = device;
    // n_cb indexes command_buffers[]/command_encoders[], which are sized by
    // WSP_GGML_METAL_MAX_COMMAND_BUFFERS — clamp against that limit (the
    // previous code clamped against WSP_GGML_METAL_MAX_BUFFERS, which is the
    // limit for data buffers, not command buffers)
    ctx->n_cb   = MIN(n_cb, WSP_GGML_METAL_MAX_COMMAND_BUFFERS);
    ctx->queue  = [ctx->device newCommandQueue];
    ctx->n_buffers = 0;
    ctx->concur_list_len = 0;

    ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

#ifdef WSP_GGML_SWIFT
    // load the default.metallib file
    {
        NSError * error = nil;

        NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
        NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
        NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
        NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
        NSURL * libURL = [NSURL fileURLWithPath:libPath];

        // Load the metallib file into a Metal library
        ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];

        if (error) {
            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
            return NULL;
        }
    }
#else
    UNUSED(msl_library_source);

    // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource
    {
        NSError * error = nil;

        //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
        NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
        NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
        metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);

        NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
        if (error) {
            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
            return NULL;
        }

#ifdef WSP_GGML_QKK_64
        // compile the shaders with QK_K = 64 (small-quant variant)
        MTLCompileOptions* options = [MTLCompileOptions new];
        options.preprocessorMacros = @{ @"QK_K" : @(64) };
        ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
#else
        ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
#endif
        if (error) {
            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
            return NULL;
        }
    }
#endif

    // load kernels
    {
        NSError * error = nil;
        // look up each kernel function in the compiled library and build its
        // compute pipeline; abort initialization on the first failure
#define WSP_GGML_METAL_ADD_KERNEL(name) \
        ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
        ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
        metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (__bridge void *) ctx->pipeline_##name, \
                (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
                (int) ctx->pipeline_##name.threadExecutionWidth); \
        if (error) { \
            metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
            return NULL; \
        }

        WSP_GGML_METAL_ADD_KERNEL(add);
        WSP_GGML_METAL_ADD_KERNEL(add_row);
        WSP_GGML_METAL_ADD_KERNEL(mul);
        WSP_GGML_METAL_ADD_KERNEL(mul_row);
        WSP_GGML_METAL_ADD_KERNEL(scale);
        WSP_GGML_METAL_ADD_KERNEL(silu);
        WSP_GGML_METAL_ADD_KERNEL(relu);
        WSP_GGML_METAL_ADD_KERNEL(gelu);
        WSP_GGML_METAL_ADD_KERNEL(soft_max);
        WSP_GGML_METAL_ADD_KERNEL(soft_max_4);
        WSP_GGML_METAL_ADD_KERNEL(diag_mask_inf);
        WSP_GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
        WSP_GGML_METAL_ADD_KERNEL(get_rows_f32);
        WSP_GGML_METAL_ADD_KERNEL(get_rows_f16);
        WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_0);
        WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_1);
        WSP_GGML_METAL_ADD_KERNEL(get_rows_q8_0);
        WSP_GGML_METAL_ADD_KERNEL(get_rows_q2_K);
        WSP_GGML_METAL_ADD_KERNEL(get_rows_q3_K);
        WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_K);
        WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_K);
        WSP_GGML_METAL_ADD_KERNEL(get_rows_q6_K);
        WSP_GGML_METAL_ADD_KERNEL(rms_norm);
        WSP_GGML_METAL_ADD_KERNEL(norm);
        WSP_GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
        WSP_GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
        WSP_GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
        WSP_GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
        WSP_GGML_METAL_ADD_KERNEL(rope);
        WSP_GGML_METAL_ADD_KERNEL(alibi_f32);
        WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f16);
        WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f32);
        WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f16);

#undef WSP_GGML_METAL_ADD_KERNEL
    }

    metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
#if TARGET_OS_OSX
    metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
    if (ctx->device.maxTransferRate != 0) {
        metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
    } else {
        metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
    }
#endif

    return ctx;
}
|
|
285
|
+
|
|
286
|
+
// Tear down a context created by wsp_ggml_metal_init and release its memory.
void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
    metal_printf("%s: deallocating\n", __func__);
    // NOTE(review): the macro body below is empty (the trailing '\' continues
    // onto a blank line), so every WSP_GGML_METAL_DEL_KERNEL(...) call below
    // expands to nothing. Presumably the per-kernel release statements were
    // lost when this file was generated — confirm against upstream ggml-metal.m.
#define WSP_GGML_METAL_DEL_KERNEL(name) \

    WSP_GGML_METAL_DEL_KERNEL(add);
    WSP_GGML_METAL_DEL_KERNEL(add_row);
    WSP_GGML_METAL_DEL_KERNEL(mul);
    WSP_GGML_METAL_DEL_KERNEL(mul_row);
    WSP_GGML_METAL_DEL_KERNEL(scale);
    WSP_GGML_METAL_DEL_KERNEL(silu);
    WSP_GGML_METAL_DEL_KERNEL(relu);
    WSP_GGML_METAL_DEL_KERNEL(gelu);
    WSP_GGML_METAL_DEL_KERNEL(soft_max);
    WSP_GGML_METAL_DEL_KERNEL(soft_max_4);
    WSP_GGML_METAL_DEL_KERNEL(diag_mask_inf);
    WSP_GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
    WSP_GGML_METAL_DEL_KERNEL(get_rows_f32);
    WSP_GGML_METAL_DEL_KERNEL(get_rows_f16);
    WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_0);
    WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_1);
    WSP_GGML_METAL_DEL_KERNEL(get_rows_q8_0);
    WSP_GGML_METAL_DEL_KERNEL(get_rows_q2_K);
    WSP_GGML_METAL_DEL_KERNEL(get_rows_q3_K);
    WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_K);
    WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_K);
    WSP_GGML_METAL_DEL_KERNEL(get_rows_q6_K);
    WSP_GGML_METAL_DEL_KERNEL(rms_norm);
    WSP_GGML_METAL_DEL_KERNEL(norm);
    WSP_GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
    WSP_GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
    WSP_GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
    WSP_GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
    WSP_GGML_METAL_DEL_KERNEL(rope);
    WSP_GGML_METAL_DEL_KERNEL(alibi_f32);
    WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f16);
    WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f32);
    WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f16);

#undef WSP_GGML_METAL_DEL_KERNEL

    free(ctx);
}
|
|
346
|
+
|
|
347
|
+
// Allocate n bytes of page-aligned host memory, suitable for registering as a
// shared (no-copy) Metal buffer via wsp_ggml_metal_add_buffer.
// Returns NULL (after logging) if the allocation fails.
void * wsp_ggml_metal_host_malloc(size_t n) {
    void * ptr = NULL;

    // page alignment is obtained from the OS rather than hard-coded
    if (posix_memalign(&ptr, sysconf(_SC_PAGESIZE), n) != 0) {
        metal_printf("%s: error: posix_memalign failed\n", __func__);
        return NULL;
    }

    return ptr;
}
|
|
357
|
+
|
|
358
|
+
// Free memory previously allocated with wsp_ggml_metal_host_malloc.
void wsp_ggml_metal_host_free(void * data) {
    free(data);
}
|
|
361
|
+
|
|
362
|
+
// Update the number of command buffers used to encode a graph.
// Clamped to WSP_GGML_METAL_MAX_COMMAND_BUFFERS, the capacity of the
// command_buffers[]/command_encoders[] arrays that n_cb indexes (the
// previous code clamped against WSP_GGML_METAL_MAX_BUFFERS, which is the
// unrelated limit for data buffers).
void wsp_ggml_metal_set_n_cb(struct wsp_ggml_metal_context * ctx, int n_cb) {
    ctx->n_cb = MIN(n_cb, WSP_GGML_METAL_MAX_COMMAND_BUFFERS);
}
|
|
365
|
+
|
|
366
|
+
// Returns the length of the concurrency plan (non-zero once
// wsp_ggml_metal_graph_find_concurrency has produced a plan for this context).
int wsp_ggml_metal_if_optimized(struct wsp_ggml_metal_context * ctx) {
    return ctx->concur_list_len;
}
|
|
369
|
+
|
|
370
|
+
// Returns a pointer to the context's concurrency plan array
// (node indices grouped into levels, each level terminated by -1).
int * wsp_ggml_metal_get_concur_list(struct wsp_ggml_metal_context * ctx) {
    return ctx->concur_list;
}
|
|
373
|
+
|
|
374
|
+
// finds the Metal buffer that contains the tensor data on the GPU device
|
|
375
|
+
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
|
|
376
|
+
// Metal buffer based on the host memory pointer
|
|
377
|
+
//
|
|
378
|
+
// Finds the registered Metal buffer that fully contains tensor t's data and
// writes the byte offset of t->data within that buffer to *offs.
// Returns nil (after logging) if no registered buffer contains the tensor.
static id<MTLBuffer> wsp_ggml_metal_get_buffer(struct wsp_ggml_metal_context * ctx, struct wsp_ggml_tensor * t, size_t * offs) {
    //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);

    const int64_t tsize = wsp_ggml_nbytes(t);

    // find the view that contains the tensor fully
    for (int i = 0; i < ctx->n_buffers; ++i) {
        // signed offset of the tensor within this buffer; negative means the
        // tensor starts before the buffer
        const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;

        //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
        if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
            *offs = (size_t) ioffs;

            //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);

            return ctx->buffers[i].metal;
        }
    }

    metal_printf("%s: error: buffer is nil\n", __func__);

    return nil;
}
|
|
401
|
+
|
|
402
|
+
// Register a host memory region with the Metal device as one or more shared
// (no-copy) MTLBuffer views.
//
// name     - label for logging
// data     - host pointer (page-aligned allocations recommended; see
//            wsp_ggml_metal_host_malloc). NULL is accepted and treated as a no-op.
// size     - size in bytes of the region
// max_size - size of the largest tensor that will live in the region; used to
//            size the overlap between views when the region must be split
//
// Returns false on overlap with an existing buffer, too many buffers, or
// MTLBuffer creation failure; true otherwise.
bool wsp_ggml_metal_add_buffer(
        struct wsp_ggml_metal_context * ctx,
                           const char * name,
                                 void * data,
                               size_t   size,
                               size_t   max_size) {
    if (ctx->n_buffers >= WSP_GGML_METAL_MAX_BUFFERS) {
        metal_printf("%s: too many buffers\n", __func__);
        return false;
    }

    if (data) {
        // verify that the buffer does not overlap with any of the existing buffers
        for (int i = 0; i < ctx->n_buffers; ++i) {
            const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;

            if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
                metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
                return false;
            }
        }

        const size_t size_page = sysconf(_SC_PAGESIZE);

        // round the mapped length up to a whole number of pages, as required
        // for a shared no-copy buffer
        size_t size_aligned = size;
        if ((size_aligned % size_page) != 0) {
            size_aligned += (size_page - (size_aligned % size_page));
        }

        // the buffer fits into the max buffer size allowed by the device
        if (size_aligned <= ctx->device.maxBufferLength) {
            ctx->buffers[ctx->n_buffers].name = name;
            ctx->buffers[ctx->n_buffers].data = data;
            ctx->buffers[ctx->n_buffers].size = size;

            ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];

            if (ctx->buffers[ctx->n_buffers].metal == nil) {
                metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
                return false;
            }

            // no trailing '\n': the OSX branch below continues this log line
            metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);

            ++ctx->n_buffers;
        } else {
            // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
            // one of the views
            const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
            const size_t size_step = ctx->device.maxBufferLength - size_ovlp;
            const size_t size_view = ctx->device.maxBufferLength;

            // NOTE(review): this loop does not re-check n_buffers against
            // WSP_GGML_METAL_MAX_BUFFERS while creating views — verify that a
            // very large region cannot overflow ctx->buffers[]
            for (size_t i = 0; i < size; i += size_step) {
                const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);

                ctx->buffers[ctx->n_buffers].name = name;
                ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
                ctx->buffers[ctx->n_buffers].size = size_step_aligned;

                ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];

                if (ctx->buffers[ctx->n_buffers].metal == nil) {
                    metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
                    return false;
                }

                metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
                if (i + size_step < size) {
                    metal_printf("\n");
                }

                ++ctx->n_buffers;
            }
        }

#if TARGET_OS_OSX
        // finish the log line started above with current/recommended usage
        metal_printf(", (%8.2f / %8.2f)",
                ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
                ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);

        if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
            metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
        } else {
            metal_printf("\n");
        }
#else
        metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
#endif
    }

    return true;
}
|
|
494
|
+
|
|
495
|
+
// Upload: copy the host-side contents of tensor `t` into the Metal buffer
// that backs it (located via wsp_ggml_metal_get_buffer).
void wsp_ggml_metal_set_tensor(
        struct wsp_ggml_metal_context * ctx,
        struct wsp_ggml_tensor * t) {
    size_t offs;
    id<MTLBuffer> backing = wsp_ggml_metal_get_buffer(ctx, t, &offs);

    uint8_t * base = (uint8_t *) backing.contents;
    memcpy((void *) (base + offs), t->data, wsp_ggml_nbytes(t));
}
|
|
503
|
+
|
|
504
|
+
// Download: copy the GPU-visible contents of tensor `t` from its backing
// Metal buffer back into the host-side t->data.
void wsp_ggml_metal_get_tensor(
        struct wsp_ggml_metal_context * ctx,
        struct wsp_ggml_tensor * t) {
    size_t offs;
    id<MTLBuffer> backing = wsp_ggml_metal_get_buffer(ctx, t, &offs);

    const uint8_t * base = (const uint8_t *) backing.contents;
    memcpy(t->data, (const void *) (base + offs), wsp_ggml_nbytes(t));
}
|
|
512
|
+
|
|
513
|
+
// Build a concurrency plan for graph gf: ctx->concur_list receives node
// indices grouped into "levels" whose members have no dependencies on each
// other; each level is terminated by a -1 barrier marker, and
// ctx->concur_list_len counts all entries including barriers.
// When check_mem is true, nodes whose output memory overlaps an earlier
// unscheduled node's output are also kept out of the current level.
void wsp_ggml_metal_graph_find_concurrency(
        struct wsp_ggml_metal_context * ctx,
        struct wsp_ggml_cgraph * gf, bool check_mem) {
    int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
    int nodes_unused[WSP_GGML_MAX_CONCUR];   // 1 = node not yet scheduled

    for (int i = 0; i < WSP_GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
    for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
    ctx->concur_list_len = 0;

    int n_left    = gf->n_nodes; // nodes still to be scheduled
    int n_start   = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
    int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos

    while (n_left > 0) {
        // number of nodes at a layer (that can be issued concurrently)
        int concurrency = 0;
        for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
            if (nodes_unused[i]) {
                // if the requirements for gf->nodes[i] are satisfied
                int exe_flag = 1;

                // scan all srcs
                for (int src_ind = 0; src_ind < WSP_GGML_MAX_SRC; src_ind++) {
                    struct wsp_ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
                    if (src_cur) {
                        // if is leaf nodes it's satisfied.
                        // TODO: wsp_ggml_is_leaf()
                        if (src_cur->op == WSP_GGML_OP_NONE && src_cur->grad == NULL) {
                            continue;
                        }

                        // otherwise this src should be the output from previous nodes.
                        int is_found = 0;

                        // scan 2*search_depth back because we inserted barrier.
                        //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
                        for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
                            if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
                                is_found = 1;
                                break;
                            }
                        }
                        if (is_found == 0) {
                            exe_flag = 0;
                            break;
                        }
                    }
                }
                if (exe_flag && check_mem) {
                    // check if nodes[i]'s data will be overwritten by a node before nodes[i].
                    // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
                    int64_t data_start = (int64_t) gf->nodes[i]->data;
                    int64_t length     = (int64_t) wsp_ggml_nbytes(gf->nodes[i]);
                    for (int j = n_start; j < i; j++) {
                        // reshape/view/transpose/permute do not write memory,
                        // so they cannot conflict
                        if (nodes_unused[j] && gf->nodes[j]->op != WSP_GGML_OP_RESHAPE \
                                            && gf->nodes[j]->op != WSP_GGML_OP_VIEW \
                                            && gf->nodes[j]->op != WSP_GGML_OP_TRANSPOSE \
                                            && gf->nodes[j]->op != WSP_GGML_OP_PERMUTE) {
                            // disjoint ranges are fine; anything else blocks node i
                            if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
                                ((int64_t)gf->nodes[j]->data) + (int64_t) wsp_ggml_nbytes(gf->nodes[j]) <= data_start) {
                                continue;
                            }

                            exe_flag = 0;
                        }
                    }
                }
                if (exe_flag) {
                    // schedule node i into the current level
                    ctx->concur_list[level_pos + concurrency] = i;
                    nodes_unused[i] = 0;
                    concurrency++;
                    ctx->concur_list_len++;
                }
            }
        }
        n_left -= concurrency;
        // adding a barrier different layer
        ctx->concur_list[level_pos + concurrency] = -1;
        ctx->concur_list_len++;
        // jump all sorted nodes at nodes_bak
        while (!nodes_unused[n_start]) {
            n_start++;
        }
        level_pos += concurrency + 1;
    }

    // plan overflowed its storage — wsp_ggml_metal_graph_compute falls back to
    // serial dispatch in this case
    if (ctx->concur_list_len > WSP_GGML_MAX_CONCUR) {
        metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
    }
}
|
|
604
|
+
|
|
605
|
+
void wsp_ggml_metal_graph_compute(
|
|
606
|
+
struct wsp_ggml_metal_context * ctx,
|
|
607
|
+
struct wsp_ggml_cgraph * gf) {
|
|
608
|
+
@autoreleasepool {
|
|
609
|
+
|
|
610
|
+
// if there is ctx->concur_list, dispatch concurrently
|
|
611
|
+
// else fallback to serial dispatch
|
|
612
|
+
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
|
|
613
|
+
|
|
614
|
+
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= WSP_GGML_MAX_CONCUR;
|
|
615
|
+
|
|
616
|
+
const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
|
|
617
|
+
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
|
|
618
|
+
|
|
619
|
+
// create multiple command buffers and enqueue them
|
|
620
|
+
// then, we encode the graph into the command buffers in parallel
|
|
621
|
+
|
|
622
|
+
const int n_cb = ctx->n_cb;
|
|
623
|
+
|
|
624
|
+
for (int i = 0; i < n_cb; ++i) {
|
|
625
|
+
ctx->command_buffers[i] = [ctx->queue commandBuffer];
|
|
626
|
+
|
|
627
|
+
// enqueue the command buffers in order to specify their execution order
|
|
628
|
+
[ctx->command_buffers[i] enqueue];
|
|
629
|
+
|
|
630
|
+
ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
|
|
631
|
+
}
|
|
632
|
+
|
|
633
|
+
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
|
634
|
+
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
|
|
635
|
+
|
|
636
|
+
dispatch_async(ctx->d_queue, ^{
|
|
637
|
+
size_t offs_src0 = 0;
|
|
638
|
+
size_t offs_src1 = 0;
|
|
639
|
+
size_t offs_dst = 0;
|
|
640
|
+
|
|
641
|
+
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
|
|
642
|
+
id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
|
|
643
|
+
|
|
644
|
+
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
|
645
|
+
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
|
|
646
|
+
|
|
647
|
+
for (int ind = node_start; ind < node_end; ++ind) {
|
|
648
|
+
const int i = has_concur ? ctx->concur_list[ind] : ind;
|
|
649
|
+
|
|
650
|
+
if (i == -1) {
|
|
651
|
+
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
|
|
652
|
+
continue;
|
|
653
|
+
}
|
|
654
|
+
|
|
655
|
+
//metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, wsp_ggml_op_name(gf->nodes[i]->op));
|
|
656
|
+
|
|
657
|
+
struct wsp_ggml_tensor * src0 = gf->nodes[i]->src[0];
|
|
658
|
+
struct wsp_ggml_tensor * src1 = gf->nodes[i]->src[1];
|
|
659
|
+
struct wsp_ggml_tensor * dst = gf->nodes[i];
|
|
660
|
+
|
|
661
|
+
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
|
662
|
+
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
|
663
|
+
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
|
664
|
+
const int64_t ne03 = src0 ? src0->ne[3] : 0;
|
|
665
|
+
|
|
666
|
+
const uint64_t nb00 = src0 ? src0->nb[0] : 0;
|
|
667
|
+
const uint64_t nb01 = src0 ? src0->nb[1] : 0;
|
|
668
|
+
const uint64_t nb02 = src0 ? src0->nb[2] : 0;
|
|
669
|
+
const uint64_t nb03 = src0 ? src0->nb[3] : 0;
|
|
670
|
+
|
|
671
|
+
const int64_t ne10 = src1 ? src1->ne[0] : 0;
|
|
672
|
+
const int64_t ne11 = src1 ? src1->ne[1] : 0;
|
|
673
|
+
const int64_t ne12 = src1 ? src1->ne[2] : 0;
|
|
674
|
+
const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
|
|
675
|
+
|
|
676
|
+
const uint64_t nb10 = src1 ? src1->nb[0] : 0;
|
|
677
|
+
const uint64_t nb11 = src1 ? src1->nb[1] : 0;
|
|
678
|
+
const uint64_t nb12 = src1 ? src1->nb[2] : 0;
|
|
679
|
+
const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
|
|
680
|
+
|
|
681
|
+
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
|
682
|
+
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
|
683
|
+
const int64_t ne2 = dst ? dst->ne[2] : 0;
|
|
684
|
+
const int64_t ne3 = dst ? dst->ne[3] : 0;
|
|
685
|
+
|
|
686
|
+
const uint64_t nb0 = dst ? dst->nb[0] : 0;
|
|
687
|
+
const uint64_t nb1 = dst ? dst->nb[1] : 0;
|
|
688
|
+
const uint64_t nb2 = dst ? dst->nb[2] : 0;
|
|
689
|
+
const uint64_t nb3 = dst ? dst->nb[3] : 0;
|
|
690
|
+
|
|
691
|
+
const enum wsp_ggml_type src0t = src0 ? src0->type : WSP_GGML_TYPE_COUNT;
|
|
692
|
+
const enum wsp_ggml_type src1t = src1 ? src1->type : WSP_GGML_TYPE_COUNT;
|
|
693
|
+
const enum wsp_ggml_type dstt = dst ? dst->type : WSP_GGML_TYPE_COUNT;
|
|
694
|
+
|
|
695
|
+
id<MTLBuffer> id_src0 = src0 ? wsp_ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
|
|
696
|
+
id<MTLBuffer> id_src1 = src1 ? wsp_ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
|
|
697
|
+
id<MTLBuffer> id_dst = dst ? wsp_ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
|
|
698
|
+
|
|
699
|
+
//metal_printf("%s: op - %s\n", __func__, wsp_ggml_op_name(dst->op));
|
|
700
|
+
//if (src0) {
|
|
701
|
+
// metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src0t), ne00, ne01, ne02,
|
|
702
|
+
// wsp_ggml_is_contiguous(src0), src0->name);
|
|
703
|
+
//}
|
|
704
|
+
//if (src1) {
|
|
705
|
+
// metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src1t), ne10, ne11, ne12,
|
|
706
|
+
// wsp_ggml_is_contiguous(src1), src1->name);
|
|
707
|
+
//}
|
|
708
|
+
//if (dst) {
|
|
709
|
+
// metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(dstt), ne0, ne1, ne2,
|
|
710
|
+
// dst->name);
|
|
711
|
+
//}
|
|
712
|
+
|
|
713
|
+
switch (dst->op) {
|
|
714
|
+
case WSP_GGML_OP_NONE:
|
|
715
|
+
case WSP_GGML_OP_RESHAPE:
|
|
716
|
+
case WSP_GGML_OP_VIEW:
|
|
717
|
+
case WSP_GGML_OP_TRANSPOSE:
|
|
718
|
+
case WSP_GGML_OP_PERMUTE:
|
|
719
|
+
{
|
|
720
|
+
// noop
|
|
721
|
+
} break;
|
|
722
|
+
case WSP_GGML_OP_ADD:
|
|
723
|
+
{
|
|
724
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
|
|
725
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
|
|
726
|
+
|
|
727
|
+
// utilize float4
|
|
728
|
+
WSP_GGML_ASSERT(ne00 % 4 == 0);
|
|
729
|
+
const int64_t nb = ne00/4;
|
|
730
|
+
|
|
731
|
+
if (wsp_ggml_nelements(src1) == ne10) {
|
|
732
|
+
// src1 is a row
|
|
733
|
+
WSP_GGML_ASSERT(ne11 == 1);
|
|
734
|
+
[encoder setComputePipelineState:ctx->pipeline_add_row];
|
|
735
|
+
} else {
|
|
736
|
+
[encoder setComputePipelineState:ctx->pipeline_add];
|
|
737
|
+
}
|
|
738
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
739
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
740
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
741
|
+
[encoder setBytes:&nb length:sizeof(nb) atIndex:3];
|
|
742
|
+
|
|
743
|
+
const int64_t n = wsp_ggml_nelements(dst)/4;
|
|
744
|
+
|
|
745
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
746
|
+
} break;
|
|
747
|
+
case WSP_GGML_OP_MUL:
|
|
748
|
+
{
|
|
749
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
|
|
750
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
|
|
751
|
+
|
|
752
|
+
// utilize float4
|
|
753
|
+
WSP_GGML_ASSERT(ne00 % 4 == 0);
|
|
754
|
+
const int64_t nb = ne00/4;
|
|
755
|
+
|
|
756
|
+
if (wsp_ggml_nelements(src1) == ne10) {
|
|
757
|
+
// src1 is a row
|
|
758
|
+
WSP_GGML_ASSERT(ne11 == 1);
|
|
759
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
|
760
|
+
} else {
|
|
761
|
+
[encoder setComputePipelineState:ctx->pipeline_mul];
|
|
762
|
+
}
|
|
763
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
764
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
765
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
766
|
+
[encoder setBytes:&nb length:sizeof(nb) atIndex:3];
|
|
767
|
+
|
|
768
|
+
const int64_t n = wsp_ggml_nelements(dst)/4;
|
|
769
|
+
|
|
770
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
771
|
+
} break;
|
|
772
|
+
case WSP_GGML_OP_SCALE:
|
|
773
|
+
{
|
|
774
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
|
|
775
|
+
|
|
776
|
+
const float scale = *(const float *) src1->data;
|
|
777
|
+
|
|
778
|
+
[encoder setComputePipelineState:ctx->pipeline_scale];
|
|
779
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
780
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
781
|
+
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
|
782
|
+
|
|
783
|
+
const int64_t n = wsp_ggml_nelements(dst)/4;
|
|
784
|
+
|
|
785
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
786
|
+
} break;
|
|
787
|
+
case WSP_GGML_OP_UNARY:
|
|
788
|
+
switch (wsp_ggml_get_unary_op(gf->nodes[i])) {
|
|
789
|
+
case WSP_GGML_UNARY_OP_SILU:
|
|
790
|
+
{
|
|
791
|
+
[encoder setComputePipelineState:ctx->pipeline_silu];
|
|
792
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
793
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
794
|
+
|
|
795
|
+
const int64_t n = wsp_ggml_nelements(dst)/4;
|
|
796
|
+
|
|
797
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
798
|
+
} break;
|
|
799
|
+
case WSP_GGML_UNARY_OP_RELU:
|
|
800
|
+
{
|
|
801
|
+
[encoder setComputePipelineState:ctx->pipeline_relu];
|
|
802
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
803
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
804
|
+
|
|
805
|
+
const int64_t n = wsp_ggml_nelements(dst);
|
|
806
|
+
|
|
807
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
808
|
+
} break;
|
|
809
|
+
case WSP_GGML_UNARY_OP_GELU:
|
|
810
|
+
{
|
|
811
|
+
[encoder setComputePipelineState:ctx->pipeline_gelu];
|
|
812
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
813
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
814
|
+
|
|
815
|
+
const int64_t n = wsp_ggml_nelements(dst)/4;
|
|
816
|
+
|
|
817
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
818
|
+
} break;
|
|
819
|
+
default:
|
|
820
|
+
{
|
|
821
|
+
metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
|
|
822
|
+
WSP_GGML_ASSERT(false);
|
|
823
|
+
}
|
|
824
|
+
} break;
|
|
825
|
+
case WSP_GGML_OP_SOFT_MAX:
|
|
826
|
+
{
|
|
827
|
+
const int nth = 32;
|
|
828
|
+
|
|
829
|
+
if (ne00%4 == 0) {
|
|
830
|
+
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
|
831
|
+
} else {
|
|
832
|
+
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
|
833
|
+
}
|
|
834
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
835
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
836
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
837
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
838
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
839
|
+
|
|
840
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
841
|
+
} break;
|
|
842
|
+
case WSP_GGML_OP_DIAG_MASK_INF:
|
|
843
|
+
{
|
|
844
|
+
const int n_past = ((int32_t *)(dst->op_params))[0];
|
|
845
|
+
|
|
846
|
+
if (ne00%8 == 0) {
|
|
847
|
+
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
|
|
848
|
+
} else {
|
|
849
|
+
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
|
850
|
+
}
|
|
851
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
852
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
853
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
854
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
855
|
+
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
|
856
|
+
|
|
857
|
+
if (ne00%8 == 0) {
|
|
858
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
859
|
+
}
|
|
860
|
+
else {
|
|
861
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
862
|
+
}
|
|
863
|
+
} break;
|
|
864
|
+
case WSP_GGML_OP_MUL_MAT:
|
|
865
|
+
{
|
|
866
|
+
// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
|
|
867
|
+
|
|
868
|
+
WSP_GGML_ASSERT(ne00 == ne10);
|
|
869
|
+
// WSP_GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
|
|
870
|
+
uint gqa = ne12/ne02;
|
|
871
|
+
WSP_GGML_ASSERT(ne03 == ne13);
|
|
872
|
+
|
|
873
|
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
|
874
|
+
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
|
875
|
+
if (!wsp_ggml_is_transposed(src0) &&
|
|
876
|
+
!wsp_ggml_is_transposed(src1) &&
|
|
877
|
+
src1t == WSP_GGML_TYPE_F32 &&
|
|
878
|
+
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
|
879
|
+
ne00%32 == 0 &&
|
|
880
|
+
ne11 > 1) {
|
|
881
|
+
switch (src0->type) {
|
|
882
|
+
case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
|
|
883
|
+
case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
|
884
|
+
case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
|
|
885
|
+
case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
|
|
886
|
+
case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
|
|
887
|
+
case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
|
|
888
|
+
case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
|
|
889
|
+
case WSP_GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
|
|
890
|
+
case WSP_GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
|
|
891
|
+
case WSP_GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
|
|
892
|
+
default: WSP_GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
|
893
|
+
}
|
|
894
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
895
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
896
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
897
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
898
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
899
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
|
900
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
|
901
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
|
902
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
|
|
903
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
|
|
904
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
|
905
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
|
906
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
|
907
|
+
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
|
|
908
|
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
|
909
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
|
910
|
+
} else {
|
|
911
|
+
int nth0 = 32;
|
|
912
|
+
int nth1 = 1;
|
|
913
|
+
int nrows = 1;
|
|
914
|
+
|
|
915
|
+
// use custom matrix x vector kernel
|
|
916
|
+
switch (src0t) {
|
|
917
|
+
case WSP_GGML_TYPE_F32:
|
|
918
|
+
{
|
|
919
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
|
|
920
|
+
nrows = 4;
|
|
921
|
+
} break;
|
|
922
|
+
case WSP_GGML_TYPE_F16:
|
|
923
|
+
{
|
|
924
|
+
nth0 = 32;
|
|
925
|
+
nth1 = 1;
|
|
926
|
+
if (ne11 * ne12 < 4) {
|
|
927
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
|
|
928
|
+
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
|
929
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
|
|
930
|
+
nrows = ne11;
|
|
931
|
+
} else {
|
|
932
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
|
933
|
+
nrows = 4;
|
|
934
|
+
}
|
|
935
|
+
} break;
|
|
936
|
+
case WSP_GGML_TYPE_Q4_0:
|
|
937
|
+
{
|
|
938
|
+
WSP_GGML_ASSERT(ne02 == 1);
|
|
939
|
+
WSP_GGML_ASSERT(ne12 == 1);
|
|
940
|
+
|
|
941
|
+
nth0 = 8;
|
|
942
|
+
nth1 = 8;
|
|
943
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
|
|
944
|
+
} break;
|
|
945
|
+
case WSP_GGML_TYPE_Q4_1:
|
|
946
|
+
{
|
|
947
|
+
WSP_GGML_ASSERT(ne02 == 1);
|
|
948
|
+
WSP_GGML_ASSERT(ne12 == 1);
|
|
949
|
+
|
|
950
|
+
nth0 = 8;
|
|
951
|
+
nth1 = 8;
|
|
952
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
|
|
953
|
+
} break;
|
|
954
|
+
case WSP_GGML_TYPE_Q8_0:
|
|
955
|
+
{
|
|
956
|
+
WSP_GGML_ASSERT(ne02 == 1);
|
|
957
|
+
WSP_GGML_ASSERT(ne12 == 1);
|
|
958
|
+
|
|
959
|
+
nth0 = 8;
|
|
960
|
+
nth1 = 8;
|
|
961
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
|
|
962
|
+
} break;
|
|
963
|
+
case WSP_GGML_TYPE_Q2_K:
|
|
964
|
+
{
|
|
965
|
+
WSP_GGML_ASSERT(ne02 == 1);
|
|
966
|
+
WSP_GGML_ASSERT(ne12 == 1);
|
|
967
|
+
|
|
968
|
+
nth0 = 2;
|
|
969
|
+
nth1 = 32;
|
|
970
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
|
|
971
|
+
} break;
|
|
972
|
+
case WSP_GGML_TYPE_Q3_K:
|
|
973
|
+
{
|
|
974
|
+
WSP_GGML_ASSERT(ne02 == 1);
|
|
975
|
+
WSP_GGML_ASSERT(ne12 == 1);
|
|
976
|
+
|
|
977
|
+
nth0 = 2;
|
|
978
|
+
nth1 = 32;
|
|
979
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
|
|
980
|
+
} break;
|
|
981
|
+
case WSP_GGML_TYPE_Q4_K:
|
|
982
|
+
{
|
|
983
|
+
WSP_GGML_ASSERT(ne02 == 1);
|
|
984
|
+
WSP_GGML_ASSERT(ne12 == 1);
|
|
985
|
+
|
|
986
|
+
nth0 = 4; //1;
|
|
987
|
+
nth1 = 8; //32;
|
|
988
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
|
|
989
|
+
} break;
|
|
990
|
+
case WSP_GGML_TYPE_Q5_K:
|
|
991
|
+
{
|
|
992
|
+
WSP_GGML_ASSERT(ne02 == 1);
|
|
993
|
+
WSP_GGML_ASSERT(ne12 == 1);
|
|
994
|
+
|
|
995
|
+
nth0 = 2;
|
|
996
|
+
nth1 = 32;
|
|
997
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
|
|
998
|
+
} break;
|
|
999
|
+
case WSP_GGML_TYPE_Q6_K:
|
|
1000
|
+
{
|
|
1001
|
+
WSP_GGML_ASSERT(ne02 == 1);
|
|
1002
|
+
WSP_GGML_ASSERT(ne12 == 1);
|
|
1003
|
+
|
|
1004
|
+
nth0 = 2;
|
|
1005
|
+
nth1 = 32;
|
|
1006
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
|
|
1007
|
+
} break;
|
|
1008
|
+
default:
|
|
1009
|
+
{
|
|
1010
|
+
metal_printf("Asserting on type %d\n",(int)src0t);
|
|
1011
|
+
WSP_GGML_ASSERT(false && "not implemented");
|
|
1012
|
+
}
|
|
1013
|
+
};
|
|
1014
|
+
|
|
1015
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1016
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1017
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1018
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
1019
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
1020
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
1021
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
1022
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
1023
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
1024
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
|
|
1025
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
|
|
1026
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
|
|
1027
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
|
|
1028
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
|
|
1029
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
|
1030
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
|
1031
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
|
1032
|
+
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
|
1033
|
+
|
|
1034
|
+
if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 || src0t == WSP_GGML_TYPE_Q8_0 ||
|
|
1035
|
+
src0t == WSP_GGML_TYPE_Q2_K) {// || src0t == WSP_GGML_TYPE_Q4_K) {
|
|
1036
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1037
|
+
}
|
|
1038
|
+
else if (src0t == WSP_GGML_TYPE_Q4_K) {
|
|
1039
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1040
|
+
}
|
|
1041
|
+
else if (src0t == WSP_GGML_TYPE_Q3_K) {
|
|
1042
|
+
#ifdef WSP_GGML_QKK_64
|
|
1043
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1044
|
+
#else
|
|
1045
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1046
|
+
#endif
|
|
1047
|
+
}
|
|
1048
|
+
else if (src0t == WSP_GGML_TYPE_Q5_K) {
|
|
1049
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1050
|
+
}
|
|
1051
|
+
else if (src0t == WSP_GGML_TYPE_Q6_K) {
|
|
1052
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1053
|
+
} else {
|
|
1054
|
+
int64_t ny = (ne11 + nrows - 1)/nrows;
|
|
1055
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1056
|
+
}
|
|
1057
|
+
}
|
|
1058
|
+
} break;
|
|
1059
|
+
case WSP_GGML_OP_GET_ROWS:
|
|
1060
|
+
{
|
|
1061
|
+
switch (src0->type) {
|
|
1062
|
+
case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
|
|
1063
|
+
case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
|
1064
|
+
case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
|
1065
|
+
case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
|
1066
|
+
case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
|
|
1067
|
+
case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
|
|
1068
|
+
case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
|
|
1069
|
+
case WSP_GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
|
|
1070
|
+
case WSP_GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
|
|
1071
|
+
case WSP_GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
|
|
1072
|
+
default: WSP_GGML_ASSERT(false && "not implemented");
|
|
1073
|
+
}
|
|
1074
|
+
|
|
1075
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1076
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1077
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1078
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
|
1079
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
|
1080
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
|
|
1081
|
+
|
|
1082
|
+
const int64_t n = wsp_ggml_nelements(src1);
|
|
1083
|
+
|
|
1084
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1085
|
+
} break;
|
|
1086
|
+
case WSP_GGML_OP_RMS_NORM:
|
|
1087
|
+
{
|
|
1088
|
+
float eps;
|
|
1089
|
+
memcpy(&eps, dst->op_params, sizeof(float));
|
|
1090
|
+
|
|
1091
|
+
const int nth = 512;
|
|
1092
|
+
|
|
1093
|
+
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
|
|
1094
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1095
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1096
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1097
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
|
1098
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
|
1099
|
+
[encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
|
|
1100
|
+
|
|
1101
|
+
const int64_t nrows = wsp_ggml_nrows(src0);
|
|
1102
|
+
|
|
1103
|
+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1104
|
+
} break;
|
|
1105
|
+
case WSP_GGML_OP_NORM:
|
|
1106
|
+
{
|
|
1107
|
+
float eps;
|
|
1108
|
+
memcpy(&eps, dst->op_params, sizeof(float));
|
|
1109
|
+
|
|
1110
|
+
const int nth = 256;
|
|
1111
|
+
|
|
1112
|
+
[encoder setComputePipelineState:ctx->pipeline_norm];
|
|
1113
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1114
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1115
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1116
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
|
1117
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
|
1118
|
+
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
|
|
1119
|
+
|
|
1120
|
+
const int64_t nrows = wsp_ggml_nrows(src0);
|
|
1121
|
+
|
|
1122
|
+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1123
|
+
} break;
|
|
1124
|
+
case WSP_GGML_OP_ALIBI:
|
|
1125
|
+
{
|
|
1126
|
+
WSP_GGML_ASSERT((src0t == WSP_GGML_TYPE_F32));
|
|
1127
|
+
|
|
1128
|
+
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
|
|
1129
|
+
const int n_head = ((int32_t *) dst->op_params)[1];
|
|
1130
|
+
float max_bias;
|
|
1131
|
+
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
|
1132
|
+
|
|
1133
|
+
if (__builtin_popcount(n_head) != 1) {
|
|
1134
|
+
WSP_GGML_ASSERT(false && "only power-of-two n_head implemented");
|
|
1135
|
+
}
|
|
1136
|
+
|
|
1137
|
+
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
|
1138
|
+
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
|
1139
|
+
|
|
1140
|
+
[encoder setComputePipelineState:ctx->pipeline_alibi_f32];
|
|
1141
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1142
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1143
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1144
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
1145
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
1146
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
|
1147
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
|
1148
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
|
1149
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
|
1150
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
|
1151
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
|
1152
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
|
1153
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
|
1154
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
|
1155
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
|
1156
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
|
1157
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
|
1158
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
|
1159
|
+
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
|
1160
|
+
|
|
1161
|
+
const int nth = 32;
|
|
1162
|
+
|
|
1163
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1164
|
+
} break;
|
|
1165
|
+
case WSP_GGML_OP_ROPE:
|
|
1166
|
+
{
|
|
1167
|
+
const int n_past = ((int32_t *) dst->op_params)[0];
|
|
1168
|
+
const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
1169
|
+
const int mode = ((int32_t *) dst->op_params)[2];
|
|
1170
|
+
|
|
1171
|
+
float freq_base;
|
|
1172
|
+
float freq_scale;
|
|
1173
|
+
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
|
1174
|
+
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
|
1175
|
+
|
|
1176
|
+
[encoder setComputePipelineState:ctx->pipeline_rope];
|
|
1177
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1178
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1179
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1180
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
1181
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
1182
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
|
1183
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
|
1184
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
|
1185
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
|
1186
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
|
1187
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
|
1188
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
|
1189
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
|
1190
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
|
1191
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
|
1192
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
|
1193
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
|
1194
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
|
1195
|
+
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
|
|
1196
|
+
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
|
|
1197
|
+
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
|
|
1198
|
+
[encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
|
|
1199
|
+
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
|
|
1200
|
+
|
|
1201
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
|
1202
|
+
} break;
|
|
1203
|
+
case WSP_GGML_OP_DUP:
|
|
1204
|
+
case WSP_GGML_OP_CPY:
|
|
1205
|
+
case WSP_GGML_OP_CONT:
|
|
1206
|
+
{
|
|
1207
|
+
const int nth = 32;
|
|
1208
|
+
|
|
1209
|
+
switch (src0t) {
|
|
1210
|
+
case WSP_GGML_TYPE_F32:
|
|
1211
|
+
{
|
|
1212
|
+
switch (dstt) {
|
|
1213
|
+
case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
|
|
1214
|
+
case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
|
|
1215
|
+
default: WSP_GGML_ASSERT(false && "not implemented");
|
|
1216
|
+
};
|
|
1217
|
+
} break;
|
|
1218
|
+
case WSP_GGML_TYPE_F16:
|
|
1219
|
+
{
|
|
1220
|
+
switch (dstt) {
|
|
1221
|
+
case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
|
|
1222
|
+
case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
|
|
1223
|
+
default: WSP_GGML_ASSERT(false && "not implemented");
|
|
1224
|
+
};
|
|
1225
|
+
} break;
|
|
1226
|
+
default: WSP_GGML_ASSERT(false && "not implemented");
|
|
1227
|
+
}
|
|
1228
|
+
|
|
1229
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1230
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1231
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1232
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
1233
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
1234
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
|
1235
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
|
1236
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
|
1237
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
|
1238
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
|
1239
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
|
1240
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
|
1241
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
|
1242
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
|
1243
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
|
1244
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
|
1245
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
|
1246
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
|
1247
|
+
|
|
1248
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1249
|
+
} break;
|
|
1250
|
+
default:
|
|
1251
|
+
{
|
|
1252
|
+
metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
|
|
1253
|
+
WSP_GGML_ASSERT(false);
|
|
1254
|
+
}
|
|
1255
|
+
}
|
|
1256
|
+
}
|
|
1257
|
+
|
|
1258
|
+
if (encoder != nil) {
|
|
1259
|
+
[encoder endEncoding];
|
|
1260
|
+
encoder = nil;
|
|
1261
|
+
}
|
|
1262
|
+
|
|
1263
|
+
[command_buffer commit];
|
|
1264
|
+
});
|
|
1265
|
+
}
|
|
1266
|
+
|
|
1267
|
+
// wait for all threads to finish
|
|
1268
|
+
dispatch_barrier_sync(ctx->d_queue, ^{});
|
|
1269
|
+
|
|
1270
|
+
// check status of command buffers
|
|
1271
|
+
// needed to detect if the device ran out-of-memory for example (#1881)
|
|
1272
|
+
for (int i = 0; i < n_cb; i++) {
|
|
1273
|
+
[ctx->command_buffers[i] waitUntilCompleted];
|
|
1274
|
+
|
|
1275
|
+
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
|
|
1276
|
+
if (status != MTLCommandBufferStatusCompleted) {
|
|
1277
|
+
metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
|
1278
|
+
WSP_GGML_ASSERT(false);
|
|
1279
|
+
}
|
|
1280
|
+
}
|
|
1281
|
+
|
|
1282
|
+
}
|
|
1283
|
+
}
|