llama_cpp 0.3.8 → 0.5.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +19 -0
- data/README.md +1 -1
- data/examples/chat.rb +4 -6
- data/ext/llama_cpp/extconf.rb +3 -3
- data/ext/llama_cpp/llama_cpp.cpp +129 -124
- data/ext/llama_cpp/src/ggml-alloc.c +90 -113
- data/ext/llama_cpp/src/ggml-alloc.h +1 -1
- data/ext/llama_cpp/src/ggml-cuda.cu +350 -77
- data/ext/llama_cpp/src/ggml-cuda.h +13 -0
- data/ext/llama_cpp/src/ggml-metal.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.m +226 -121
- data/ext/llama_cpp/src/ggml-metal.metal +157 -35
- data/ext/llama_cpp/src/ggml.c +2724 -584
- data/ext/llama_cpp/src/ggml.h +282 -31
- data/ext/llama_cpp/src/k_quants.c +112 -56
- data/ext/llama_cpp/src/llama.cpp +4857 -2986
- data/ext/llama_cpp/src/llama.h +180 -126
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +2 -2
- data/sig/llama_cpp.rbs +12 -11
- metadata +2 -2
@@ -11,6 +11,7 @@
|
|
11
11
|
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
12
12
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
13
13
|
|
14
|
+
// TODO: temporary - reuse llama.cpp logging
|
14
15
|
#ifdef GGML_METAL_NDEBUG
|
15
16
|
#define metal_printf(...)
|
16
17
|
#else
|
@@ -33,12 +34,15 @@ struct ggml_metal_buffer {
|
|
33
34
|
struct ggml_metal_context {
|
34
35
|
int n_cb;
|
35
36
|
|
36
|
-
float * logits;
|
37
|
-
|
38
37
|
id<MTLDevice> device;
|
39
38
|
id<MTLCommandQueue> queue;
|
40
39
|
id<MTLLibrary> library;
|
41
40
|
|
41
|
+
id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
|
42
|
+
id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
|
43
|
+
|
44
|
+
dispatch_queue_t d_queue;
|
45
|
+
|
42
46
|
int n_buffers;
|
43
47
|
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
|
44
48
|
|
@@ -63,6 +67,7 @@ struct ggml_metal_context {
|
|
63
67
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
64
68
|
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
65
69
|
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
70
|
+
GGML_METAL_DECL_KERNEL(get_rows_q8_0);
|
66
71
|
GGML_METAL_DECL_KERNEL(get_rows_q2_K);
|
67
72
|
GGML_METAL_DECL_KERNEL(get_rows_q3_K);
|
68
73
|
GGML_METAL_DECL_KERNEL(get_rows_q4_K);
|
@@ -73,6 +78,7 @@ struct ggml_metal_context {
|
|
73
78
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
74
79
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
75
80
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
81
|
+
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
|
76
82
|
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
|
77
83
|
GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
|
78
84
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
|
@@ -81,6 +87,7 @@ struct ggml_metal_context {
|
|
81
87
|
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
82
88
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
83
89
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
|
90
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
|
84
91
|
GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
|
85
92
|
GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
|
86
93
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
|
@@ -107,16 +114,17 @@ static NSString * const msl_library_source = @"see metal.metal";
|
|
107
114
|
@end
|
108
115
|
|
109
116
|
struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
110
|
-
|
117
|
+
metal_printf("%s: allocating\n", __func__);
|
111
118
|
|
112
119
|
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
|
113
120
|
|
114
|
-
ctx->n_cb = n_cb;
|
121
|
+
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
115
122
|
ctx->device = MTLCreateSystemDefaultDevice();
|
116
123
|
ctx->queue = [ctx->device newCommandQueue];
|
117
124
|
ctx->n_buffers = 0;
|
118
125
|
ctx->concur_list_len = 0;
|
119
126
|
|
127
|
+
ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
|
120
128
|
|
121
129
|
#if 0
|
122
130
|
// compile from source string and show compile log
|
@@ -125,7 +133,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
125
133
|
|
126
134
|
ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
|
127
135
|
if (error) {
|
128
|
-
|
136
|
+
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
129
137
|
return NULL;
|
130
138
|
}
|
131
139
|
}
|
@@ -139,11 +147,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
139
147
|
//NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
|
140
148
|
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
141
149
|
NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
142
|
-
|
150
|
+
metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
|
143
151
|
|
144
152
|
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
|
145
153
|
if (error) {
|
146
|
-
|
154
|
+
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
147
155
|
return NULL;
|
148
156
|
}
|
149
157
|
|
@@ -155,7 +163,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
155
163
|
ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
|
156
164
|
#endif
|
157
165
|
if (error) {
|
158
|
-
|
166
|
+
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
159
167
|
return NULL;
|
160
168
|
}
|
161
169
|
}
|
@@ -167,9 +175,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
167
175
|
#define GGML_METAL_ADD_KERNEL(name) \
|
168
176
|
ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
|
169
177
|
ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
|
170
|
-
|
178
|
+
metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
|
179
|
+
(int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
|
180
|
+
(int) ctx->pipeline_##name.threadExecutionWidth); \
|
171
181
|
if (error) { \
|
172
|
-
|
182
|
+
metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
173
183
|
return NULL; \
|
174
184
|
}
|
175
185
|
|
@@ -186,6 +196,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
186
196
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
187
197
|
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
188
198
|
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
199
|
+
GGML_METAL_ADD_KERNEL(get_rows_q8_0);
|
189
200
|
GGML_METAL_ADD_KERNEL(get_rows_q2_K);
|
190
201
|
GGML_METAL_ADD_KERNEL(get_rows_q3_K);
|
191
202
|
GGML_METAL_ADD_KERNEL(get_rows_q4_K);
|
@@ -196,6 +207,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
196
207
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
197
208
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
198
209
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
210
|
+
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
|
199
211
|
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
|
200
212
|
GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
|
201
213
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
|
@@ -203,6 +215,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
203
215
|
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
|
204
216
|
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
205
217
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
|
218
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
|
206
219
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
|
207
220
|
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
|
208
221
|
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
|
@@ -218,27 +231,100 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
218
231
|
#undef GGML_METAL_ADD_KERNEL
|
219
232
|
}
|
220
233
|
|
221
|
-
|
222
|
-
|
234
|
+
metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
235
|
+
metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
223
236
|
if (ctx->device.maxTransferRate != 0) {
|
224
|
-
|
237
|
+
metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
225
238
|
} else {
|
226
|
-
|
239
|
+
metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
|
227
240
|
}
|
228
241
|
|
229
242
|
return ctx;
|
230
243
|
}
|
231
244
|
|
232
245
|
void ggml_metal_free(struct ggml_metal_context * ctx) {
|
233
|
-
|
246
|
+
metal_printf("%s: deallocating\n", __func__);
|
247
|
+
#define GGML_METAL_DEL_KERNEL(name) \
|
248
|
+
[ctx->function_##name release]; \
|
249
|
+
[ctx->pipeline_##name release];
|
250
|
+
|
251
|
+
GGML_METAL_DEL_KERNEL(add);
|
252
|
+
GGML_METAL_DEL_KERNEL(add_row);
|
253
|
+
GGML_METAL_DEL_KERNEL(mul);
|
254
|
+
GGML_METAL_DEL_KERNEL(mul_row);
|
255
|
+
GGML_METAL_DEL_KERNEL(scale);
|
256
|
+
GGML_METAL_DEL_KERNEL(silu);
|
257
|
+
GGML_METAL_DEL_KERNEL(relu);
|
258
|
+
GGML_METAL_DEL_KERNEL(gelu);
|
259
|
+
GGML_METAL_DEL_KERNEL(soft_max);
|
260
|
+
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
261
|
+
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
262
|
+
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
263
|
+
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
264
|
+
GGML_METAL_DEL_KERNEL(get_rows_q8_0);
|
265
|
+
GGML_METAL_DEL_KERNEL(get_rows_q2_K);
|
266
|
+
GGML_METAL_DEL_KERNEL(get_rows_q3_K);
|
267
|
+
GGML_METAL_DEL_KERNEL(get_rows_q4_K);
|
268
|
+
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
|
269
|
+
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
270
|
+
GGML_METAL_DEL_KERNEL(rms_norm);
|
271
|
+
GGML_METAL_DEL_KERNEL(norm);
|
272
|
+
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
273
|
+
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
|
274
|
+
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
|
275
|
+
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
|
276
|
+
GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
|
277
|
+
GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
|
278
|
+
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
|
279
|
+
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
|
280
|
+
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
|
281
|
+
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
282
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
|
283
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
|
284
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
|
285
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
|
286
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
|
287
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
|
288
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
|
289
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
|
290
|
+
GGML_METAL_DEL_KERNEL(rope);
|
291
|
+
GGML_METAL_DEL_KERNEL(alibi_f32);
|
292
|
+
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
293
|
+
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
294
|
+
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
295
|
+
|
296
|
+
#undef GGML_METAL_DEL_KERNEL
|
297
|
+
|
234
298
|
for (int i = 0; i < ctx->n_buffers; ++i) {
|
235
299
|
[ctx->buffers[i].metal release];
|
236
300
|
}
|
301
|
+
|
302
|
+
[ctx->library release];
|
303
|
+
[ctx->queue release];
|
304
|
+
[ctx->device release];
|
305
|
+
|
306
|
+
dispatch_release(ctx->d_queue);
|
307
|
+
|
237
308
|
free(ctx);
|
238
309
|
}
|
239
310
|
|
311
|
+
void * ggml_metal_host_malloc(size_t n) {
|
312
|
+
void * data = NULL;
|
313
|
+
const int result = posix_memalign((void **) &data, getpagesize(), n);
|
314
|
+
if (result != 0) {
|
315
|
+
metal_printf("%s: error: posix_memalign failed\n", __func__);
|
316
|
+
return NULL;
|
317
|
+
}
|
318
|
+
|
319
|
+
return data;
|
320
|
+
}
|
321
|
+
|
322
|
+
void ggml_metal_host_free(void * data) {
|
323
|
+
free(data);
|
324
|
+
}
|
325
|
+
|
240
326
|
void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
|
241
|
-
ctx->n_cb = n_cb;
|
327
|
+
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
242
328
|
}
|
243
329
|
|
244
330
|
int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
|
@@ -254,7 +340,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
|
|
254
340
|
// Metal buffer based on the host memory pointer
|
255
341
|
//
|
256
342
|
static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
|
257
|
-
//
|
343
|
+
//metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
|
258
344
|
|
259
345
|
const int64_t tsize = ggml_nbytes(t);
|
260
346
|
|
@@ -265,13 +351,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
|
265
351
|
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
|
266
352
|
*offs = (size_t) ioffs;
|
267
353
|
|
268
|
-
//
|
354
|
+
//metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
|
269
355
|
|
270
356
|
return ctx->buffers[i].metal;
|
271
357
|
}
|
272
358
|
}
|
273
359
|
|
274
|
-
|
360
|
+
metal_printf("%s: error: buffer is nil\n", __func__);
|
275
361
|
|
276
362
|
return nil;
|
277
363
|
}
|
@@ -283,7 +369,7 @@ bool ggml_metal_add_buffer(
|
|
283
369
|
size_t size,
|
284
370
|
size_t max_size) {
|
285
371
|
if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
|
286
|
-
|
372
|
+
metal_printf("%s: too many buffers\n", __func__);
|
287
373
|
return false;
|
288
374
|
}
|
289
375
|
|
@@ -293,7 +379,7 @@ bool ggml_metal_add_buffer(
|
|
293
379
|
const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
|
294
380
|
|
295
381
|
if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
|
296
|
-
|
382
|
+
metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
|
297
383
|
return false;
|
298
384
|
}
|
299
385
|
}
|
@@ -314,11 +400,11 @@ bool ggml_metal_add_buffer(
|
|
314
400
|
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
315
401
|
|
316
402
|
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
317
|
-
|
403
|
+
metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
|
318
404
|
return false;
|
319
405
|
}
|
320
406
|
|
321
|
-
|
407
|
+
metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
|
322
408
|
|
323
409
|
++ctx->n_buffers;
|
324
410
|
} else {
|
@@ -338,27 +424,27 @@ bool ggml_metal_add_buffer(
|
|
338
424
|
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
339
425
|
|
340
426
|
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
341
|
-
|
427
|
+
metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
|
342
428
|
return false;
|
343
429
|
}
|
344
430
|
|
345
|
-
|
431
|
+
metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
|
346
432
|
if (i + size_step < size) {
|
347
|
-
|
433
|
+
metal_printf("\n");
|
348
434
|
}
|
349
435
|
|
350
436
|
++ctx->n_buffers;
|
351
437
|
}
|
352
438
|
}
|
353
439
|
|
354
|
-
|
440
|
+
metal_printf(", (%8.2f / %8.2f)",
|
355
441
|
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
|
356
442
|
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
357
443
|
|
358
444
|
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
|
359
|
-
|
445
|
+
metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
|
360
446
|
} else {
|
361
|
-
|
447
|
+
metal_printf("\n");
|
362
448
|
}
|
363
449
|
}
|
364
450
|
|
@@ -368,8 +454,6 @@ bool ggml_metal_add_buffer(
|
|
368
454
|
void ggml_metal_set_tensor(
|
369
455
|
struct ggml_metal_context * ctx,
|
370
456
|
struct ggml_tensor * t) {
|
371
|
-
metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
|
372
|
-
|
373
457
|
size_t offs;
|
374
458
|
id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
|
375
459
|
|
@@ -379,8 +463,6 @@ void ggml_metal_set_tensor(
|
|
379
463
|
void ggml_metal_get_tensor(
|
380
464
|
struct ggml_metal_context * ctx,
|
381
465
|
struct ggml_tensor * t) {
|
382
|
-
metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
|
383
|
-
|
384
466
|
size_t offs;
|
385
467
|
id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
|
386
468
|
|
@@ -475,14 +557,14 @@ void ggml_metal_graph_find_concurrency(
|
|
475
557
|
}
|
476
558
|
|
477
559
|
if (ctx->concur_list_len > GGML_MAX_CONCUR) {
|
478
|
-
|
560
|
+
metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
|
479
561
|
}
|
480
562
|
}
|
481
563
|
|
482
564
|
void ggml_metal_graph_compute(
|
483
565
|
struct ggml_metal_context * ctx,
|
484
566
|
struct ggml_cgraph * gf) {
|
485
|
-
|
567
|
+
@autoreleasepool {
|
486
568
|
|
487
569
|
// if there is ctx->concur_list, dispatch concurrently
|
488
570
|
// else fallback to serial dispatch
|
@@ -498,32 +580,28 @@ void ggml_metal_graph_compute(
|
|
498
580
|
|
499
581
|
const int n_cb = ctx->n_cb;
|
500
582
|
|
501
|
-
NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
|
502
|
-
|
503
583
|
for (int i = 0; i < n_cb; ++i) {
|
504
|
-
command_buffers[i] = [ctx->queue commandBuffer];
|
584
|
+
ctx->command_buffers[i] = [ctx->queue commandBuffer];
|
505
585
|
|
506
586
|
// enqueue the command buffers in order to specify their execution order
|
507
|
-
[command_buffers[i] enqueue];
|
508
|
-
}
|
587
|
+
[ctx->command_buffers[i] enqueue];
|
509
588
|
|
510
|
-
|
511
|
-
|
589
|
+
ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
|
590
|
+
}
|
512
591
|
|
513
592
|
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
514
593
|
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
|
515
594
|
|
516
|
-
dispatch_async(
|
595
|
+
dispatch_async(ctx->d_queue, ^{
|
517
596
|
size_t offs_src0 = 0;
|
518
597
|
size_t offs_src1 = 0;
|
519
598
|
size_t offs_dst = 0;
|
520
599
|
|
521
|
-
id<MTLCommandBuffer> command_buffer
|
522
|
-
|
523
|
-
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
600
|
+
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
|
601
|
+
id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
|
524
602
|
|
525
|
-
const int node_start =
|
526
|
-
const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
|
603
|
+
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
604
|
+
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
|
527
605
|
|
528
606
|
for (int ind = node_start; ind < node_end; ++ind) {
|
529
607
|
const int i = has_concur ? ctx->concur_list[ind] : ind;
|
@@ -533,7 +611,7 @@ void ggml_metal_graph_compute(
|
|
533
611
|
continue;
|
534
612
|
}
|
535
613
|
|
536
|
-
metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
614
|
+
//metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
537
615
|
|
538
616
|
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
539
617
|
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
|
@@ -602,6 +680,12 @@ void ggml_metal_graph_compute(
|
|
602
680
|
} break;
|
603
681
|
case GGML_OP_ADD:
|
604
682
|
{
|
683
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
684
|
+
|
685
|
+
// utilize float4
|
686
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
687
|
+
const int64_t nb = ne00/4;
|
688
|
+
|
605
689
|
if (ggml_nelements(src1) == ne10) {
|
606
690
|
// src1 is a row
|
607
691
|
[encoder setComputePipelineState:ctx->pipeline_add_row];
|
@@ -611,14 +695,20 @@ void ggml_metal_graph_compute(
|
|
611
695
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
612
696
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
613
697
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
614
|
-
[encoder setBytes:&
|
698
|
+
[encoder setBytes:&nb length:sizeof(nb) atIndex:3];
|
615
699
|
|
616
|
-
const int64_t n = ggml_nelements(dst);
|
700
|
+
const int64_t n = ggml_nelements(dst)/4;
|
617
701
|
|
618
702
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
619
703
|
} break;
|
620
704
|
case GGML_OP_MUL:
|
621
705
|
{
|
706
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
707
|
+
|
708
|
+
// utilize float4
|
709
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
710
|
+
const int64_t nb = ne00/4;
|
711
|
+
|
622
712
|
if (ggml_nelements(src1) == ne10) {
|
623
713
|
// src1 is a row
|
624
714
|
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
@@ -628,9 +718,9 @@ void ggml_metal_graph_compute(
|
|
628
718
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
629
719
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
630
720
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
631
|
-
[encoder setBytes:&
|
721
|
+
[encoder setBytes:&nb length:sizeof(nb) atIndex:3];
|
632
722
|
|
633
|
-
const int64_t n = ggml_nelements(dst);
|
723
|
+
const int64_t n = ggml_nelements(dst)/4;
|
634
724
|
|
635
725
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
636
726
|
} break;
|
@@ -681,7 +771,7 @@ void ggml_metal_graph_compute(
|
|
681
771
|
} break;
|
682
772
|
default:
|
683
773
|
{
|
684
|
-
|
774
|
+
metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
685
775
|
GGML_ASSERT(false);
|
686
776
|
}
|
687
777
|
} break;
|
@@ -729,32 +819,32 @@ void ggml_metal_graph_compute(
|
|
729
819
|
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
730
820
|
ne00%32 == 0 &&
|
731
821
|
ne11 > 1) {
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
744
|
-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
745
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
746
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
747
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
748
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
749
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
750
|
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
751
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
|
752
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
|
753
|
-
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
|
754
|
-
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
755
|
-
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
822
|
+
switch (src0->type) {
|
823
|
+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
824
|
+
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
|
825
|
+
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
|
826
|
+
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
|
827
|
+
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
|
828
|
+
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
|
829
|
+
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
|
830
|
+
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
|
831
|
+
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
|
832
|
+
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
756
833
|
}
|
757
|
-
|
834
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
835
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
836
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
837
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
838
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
839
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
840
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
841
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
842
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
|
843
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
|
844
|
+
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
|
845
|
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
846
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
847
|
+
} else {
|
758
848
|
int nth0 = 32;
|
759
849
|
int nth1 = 1;
|
760
850
|
|
@@ -762,7 +852,7 @@ void ggml_metal_graph_compute(
|
|
762
852
|
switch (src0t) {
|
763
853
|
case GGML_TYPE_F16:
|
764
854
|
{
|
765
|
-
nth0 =
|
855
|
+
nth0 = 32;
|
766
856
|
nth1 = 1;
|
767
857
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
768
858
|
} break;
|
@@ -784,6 +874,15 @@ void ggml_metal_graph_compute(
|
|
784
874
|
nth1 = 8;
|
785
875
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
|
786
876
|
} break;
|
877
|
+
case GGML_TYPE_Q8_0:
|
878
|
+
{
|
879
|
+
GGML_ASSERT(ne02 == 1);
|
880
|
+
GGML_ASSERT(ne12 == 1);
|
881
|
+
|
882
|
+
nth0 = 8;
|
883
|
+
nth1 = 8;
|
884
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
|
885
|
+
} break;
|
787
886
|
case GGML_TYPE_Q2_K:
|
788
887
|
{
|
789
888
|
GGML_ASSERT(ne02 == 1);
|
@@ -831,7 +930,7 @@ void ggml_metal_graph_compute(
|
|
831
930
|
} break;
|
832
931
|
default:
|
833
932
|
{
|
834
|
-
|
933
|
+
metal_printf("Asserting on type %d\n",(int)src0t);
|
835
934
|
GGML_ASSERT(false && "not implemented");
|
836
935
|
}
|
837
936
|
};
|
@@ -853,24 +952,24 @@ void ggml_metal_graph_compute(
|
|
853
952
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
854
953
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
855
954
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
856
|
-
[encoder setBytes:&gqa
|
955
|
+
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
857
956
|
|
858
|
-
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
957
|
+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
|
859
958
|
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
860
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)
|
959
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
861
960
|
}
|
862
961
|
else if (src0t == GGML_TYPE_Q3_K) {
|
863
962
|
#ifdef GGML_QKK_64
|
864
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
963
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
865
964
|
#else
|
866
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
965
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
867
966
|
#endif
|
868
967
|
}
|
869
968
|
else if (src0t == GGML_TYPE_Q5_K) {
|
870
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)
|
969
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
871
970
|
}
|
872
971
|
else if (src0t == GGML_TYPE_Q6_K) {
|
873
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
972
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
874
973
|
} else {
|
875
974
|
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
876
975
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -880,9 +979,10 @@ void ggml_metal_graph_compute(
|
|
880
979
|
case GGML_OP_GET_ROWS:
|
881
980
|
{
|
882
981
|
switch (src0->type) {
|
883
|
-
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];
|
982
|
+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
884
983
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
885
984
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
985
|
+
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
|
886
986
|
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
|
887
987
|
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
|
888
988
|
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
|
@@ -923,16 +1023,17 @@ void ggml_metal_graph_compute(
|
|
923
1023
|
} break;
|
924
1024
|
case GGML_OP_NORM:
|
925
1025
|
{
|
926
|
-
|
1026
|
+
float eps;
|
1027
|
+
memcpy(&eps, dst->op_params, sizeof(float));
|
927
1028
|
|
928
1029
|
const int nth = 256;
|
929
1030
|
|
930
1031
|
[encoder setComputePipelineState:ctx->pipeline_norm];
|
931
|
-
[encoder setBuffer:id_src0 offset:offs_src0
|
932
|
-
[encoder setBuffer:id_dst offset:offs_dst
|
933
|
-
[encoder setBytes:&ne00
|
934
|
-
[encoder setBytes:&nb01
|
935
|
-
[encoder setBytes:&eps
|
1032
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1033
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1034
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
1035
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
1036
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
936
1037
|
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
|
937
1038
|
|
938
1039
|
const int64_t nrows = ggml_nrows(src0);
|
@@ -975,7 +1076,9 @@ void ggml_metal_graph_compute(
|
|
975
1076
|
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
976
1077
|
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
977
1078
|
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
1079
|
+
|
978
1080
|
const int nth = 32;
|
1081
|
+
|
979
1082
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
980
1083
|
} break;
|
981
1084
|
case GGML_OP_ROPE:
|
@@ -990,8 +1093,8 @@ void ggml_metal_graph_compute(
|
|
990
1093
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
991
1094
|
|
992
1095
|
[encoder setComputePipelineState:ctx->pipeline_rope];
|
993
|
-
[encoder setBuffer:id_src0 offset:offs_src0
|
994
|
-
[encoder setBuffer:id_dst offset:offs_dst
|
1096
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1097
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
995
1098
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
996
1099
|
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
997
1100
|
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
@@ -1042,30 +1145,30 @@ void ggml_metal_graph_compute(
|
|
1042
1145
|
default: GGML_ASSERT(false && "not implemented");
|
1043
1146
|
}
|
1044
1147
|
|
1045
|
-
[encoder setBuffer:id_src0 offset:offs_src0
|
1046
|
-
[encoder setBuffer:id_dst offset:offs_dst
|
1047
|
-
[encoder setBytes:&ne00
|
1048
|
-
[encoder setBytes:&ne01
|
1049
|
-
[encoder setBytes:&ne02
|
1050
|
-
[encoder setBytes:&ne03
|
1051
|
-
[encoder setBytes:&nb00
|
1052
|
-
[encoder setBytes:&nb01
|
1053
|
-
[encoder setBytes:&nb02
|
1054
|
-
[encoder setBytes:&nb03
|
1055
|
-
[encoder setBytes:&ne0
|
1056
|
-
[encoder setBytes:&ne1
|
1057
|
-
[encoder setBytes:&ne2
|
1058
|
-
[encoder setBytes:&ne3
|
1059
|
-
[encoder setBytes:&nb0
|
1060
|
-
[encoder setBytes:&nb1
|
1061
|
-
[encoder setBytes:&nb2
|
1062
|
-
[encoder setBytes:&nb3
|
1148
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1149
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1150
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
1151
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
1152
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
1153
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
1154
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
1155
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
1156
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
1157
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
1158
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
1159
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
1160
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
1161
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
1162
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
1163
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
1164
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
1165
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
1063
1166
|
|
1064
1167
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1065
1168
|
} break;
|
1066
1169
|
default:
|
1067
1170
|
{
|
1068
|
-
|
1171
|
+
metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
1069
1172
|
GGML_ASSERT(false);
|
1070
1173
|
}
|
1071
1174
|
}
|
@@ -1081,17 +1184,19 @@ void ggml_metal_graph_compute(
|
|
1081
1184
|
}
|
1082
1185
|
|
1083
1186
|
// wait for all threads to finish
|
1084
|
-
dispatch_barrier_sync(
|
1085
|
-
|
1086
|
-
[command_buffers[n_cb - 1] waitUntilCompleted];
|
1187
|
+
dispatch_barrier_sync(ctx->d_queue, ^{});
|
1087
1188
|
|
1088
1189
|
// check status of command buffers
|
1089
1190
|
// needed to detect if the device ran out-of-memory for example (#1881)
|
1090
1191
|
for (int i = 0; i < n_cb; i++) {
|
1091
|
-
|
1192
|
+
[ctx->command_buffers[i] waitUntilCompleted];
|
1193
|
+
|
1194
|
+
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
|
1092
1195
|
if (status != MTLCommandBufferStatusCompleted) {
|
1093
|
-
|
1196
|
+
metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
1094
1197
|
GGML_ASSERT(false);
|
1095
1198
|
}
|
1096
1199
|
}
|
1200
|
+
|
1201
|
+
}
|
1097
1202
|
}
|