llama_cpp 0.3.8 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +19 -0
- data/README.md +1 -1
- data/examples/chat.rb +4 -6
- data/ext/llama_cpp/extconf.rb +3 -3
- data/ext/llama_cpp/llama_cpp.cpp +129 -124
- data/ext/llama_cpp/src/ggml-alloc.c +90 -113
- data/ext/llama_cpp/src/ggml-alloc.h +1 -1
- data/ext/llama_cpp/src/ggml-cuda.cu +350 -77
- data/ext/llama_cpp/src/ggml-cuda.h +13 -0
- data/ext/llama_cpp/src/ggml-metal.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.m +226 -121
- data/ext/llama_cpp/src/ggml-metal.metal +157 -35
- data/ext/llama_cpp/src/ggml.c +2724 -584
- data/ext/llama_cpp/src/ggml.h +282 -31
- data/ext/llama_cpp/src/k_quants.c +112 -56
- data/ext/llama_cpp/src/llama.cpp +4857 -2986
- data/ext/llama_cpp/src/llama.h +180 -126
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +2 -2
- data/sig/llama_cpp.rbs +12 -11
- metadata +2 -2
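
Most of this release is churn in the vendored llama.cpp/ggml sources listed above. The excerpt below is from data/ext/llama_cpp/src/ggml-metal.m, which gains Q8_0 kernels, per-context command buffers/encoders plus a dedicated dispatch queue, and page-aligned host allocation helpers (ggml_metal_host_malloc / ggml_metal_host_free). As a rough standalone sketch of that allocation pattern (illustrative only, not code shipped in the gem; the helper name host_alloc_page_aligned is made up):

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>   /* getpagesize */

/* Page-aligned allocation, mirroring the pattern the new ggml_metal_host_malloc
 * in this diff uses, so the memory can later be wrapped by a no-copy Metal buffer. */
static void * host_alloc_page_aligned(size_t n) {
    void * data = NULL;
    /* posix_memalign requires the alignment to be a power of two and a
     * multiple of sizeof(void *); the page size satisfies both */
    if (posix_memalign(&data, (size_t) getpagesize(), n) != 0) {
        fprintf(stderr, "posix_memalign failed\n");
        return NULL;
    }
    return data;
}

int main(void) {
    void * buf = host_alloc_page_aligned(1 << 20); /* 1 MiB */
    if (buf != NULL) {
        printf("allocated %p, page size = %ld\n", buf, (long) getpagesize());
        free(buf); /* posix_memalign memory is released with free() */
    }
    return 0;
}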
data/ext/llama_cpp/src/ggml-metal.m

@@ -11,6 +11,7 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+// TODO: temporary - reuse llama.cpp logging
 #ifdef GGML_METAL_NDEBUG
 #define metal_printf(...)
 #else
@@ -33,12 +34,15 @@ struct ggml_metal_buffer {
 struct ggml_metal_context {
     int n_cb;
 
-    float * logits;
-
     id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
     id<MTLLibrary>      library;
 
+    id<MTLCommandBuffer>         command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
+    id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
+
+    dispatch_queue_t d_queue;
+
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
@@ -63,6 +67,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DECL_KERNEL(get_rows_q8_0);
     GGML_METAL_DECL_KERNEL(get_rows_q2_K);
     GGML_METAL_DECL_KERNEL(get_rows_q3_K);
     GGML_METAL_DECL_KERNEL(get_rows_q4_K);
@@ -73,6 +78,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
@@ -81,6 +87,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
@@ -107,16 +114,17 @@ static NSString * const msl_library_source = @"see metal.metal";
 @end
 
 struct ggml_metal_context * ggml_metal_init(int n_cb) {
-    fprintf(stderr, "%s: allocating\n", __func__);
+    metal_printf("%s: allocating\n", __func__);
 
     struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
 
-    ctx->n_cb = n_cb;
+    ctx->n_cb   = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue  = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
     ctx->concur_list_len = 0;
 
+    ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
 #if 0
     // compile from source string and show compile log
@@ -125,7 +133,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 
         ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
         if (error) {
-            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
     }
@@ -139,11 +147,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
         NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
         NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
-        fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
+        metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
 
         NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
         if (error) {
-            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
 
@@ -155,7 +163,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
 #endif
         if (error) {
-            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
     }
@@ -167,9 +175,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #define GGML_METAL_ADD_KERNEL(name) \
         ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
         ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
-        fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \
+        metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+                (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
+                (int) ctx->pipeline_##name.threadExecutionWidth); \
         if (error) { \
-            fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+            metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
            return NULL; \
        }
 
@@ -186,6 +196,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+        GGML_METAL_ADD_KERNEL(get_rows_q8_0);
         GGML_METAL_ADD_KERNEL(get_rows_q2_K);
         GGML_METAL_ADD_KERNEL(get_rows_q3_K);
         GGML_METAL_ADD_KERNEL(get_rows_q4_K);
@@ -196,6 +207,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
@@ -203,6 +215,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
@@ -218,27 +231,100 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #undef GGML_METAL_ADD_KERNEL
     }
 
-    fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-    fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+    metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
     if (ctx->device.maxTransferRate != 0) {
-        fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+        metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
     } else {
-        fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
+        metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
     }
 
     return ctx;
 }
 
 void ggml_metal_free(struct ggml_metal_context * ctx) {
-    fprintf(stderr, "%s: deallocating\n", __func__);
+    metal_printf("%s: deallocating\n", __func__);
+#define GGML_METAL_DEL_KERNEL(name) \
+    [ctx->function_##name release]; \
+    [ctx->pipeline_##name release];
+
+    GGML_METAL_DEL_KERNEL(add);
+    GGML_METAL_DEL_KERNEL(add_row);
+    GGML_METAL_DEL_KERNEL(mul);
+    GGML_METAL_DEL_KERNEL(mul_row);
+    GGML_METAL_DEL_KERNEL(scale);
+    GGML_METAL_DEL_KERNEL(silu);
+    GGML_METAL_DEL_KERNEL(relu);
+    GGML_METAL_DEL_KERNEL(gelu);
+    GGML_METAL_DEL_KERNEL(soft_max);
+    GGML_METAL_DEL_KERNEL(diag_mask_inf);
+    GGML_METAL_DEL_KERNEL(get_rows_f16);
+    GGML_METAL_DEL_KERNEL(get_rows_q4_0);
+    GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DEL_KERNEL(get_rows_q8_0);
+    GGML_METAL_DEL_KERNEL(get_rows_q2_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q3_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q4_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q5_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q6_K);
+    GGML_METAL_DEL_KERNEL(rms_norm);
+    GGML_METAL_DEL_KERNEL(norm);
+    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(rope);
+    GGML_METAL_DEL_KERNEL(alibi_f32);
+    GGML_METAL_DEL_KERNEL(cpy_f32_f16);
+    GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+    GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+
+#undef GGML_METAL_DEL_KERNEL
+
     for (int i = 0; i < ctx->n_buffers; ++i) {
         [ctx->buffers[i].metal release];
     }
+
+    [ctx->library release];
+    [ctx->queue release];
+    [ctx->device release];
+
+    dispatch_release(ctx->d_queue);
+
     free(ctx);
 }
 
+void * ggml_metal_host_malloc(size_t n) {
+    void * data = NULL;
+    const int result = posix_memalign((void **) &data, getpagesize(), n);
+    if (result != 0) {
+        metal_printf("%s: error: posix_memalign failed\n", __func__);
+        return NULL;
+    }
+
+    return data;
+}
+
+void ggml_metal_host_free(void * data) {
+    free(data);
+}
+
 void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
-    ctx->n_cb = n_cb;
+    ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
 }
 
 int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
@@ -254,7 +340,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
 // Metal buffer based on the host memory pointer
 //
 static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
-    //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
+    //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
 
     const int64_t tsize = ggml_nbytes(t);
 
@@ -265,13 +351,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
         if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
             *offs = (size_t) ioffs;
 
-            //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
+            //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
 
             return ctx->buffers[i].metal;
         }
     }
 
-    fprintf(stderr, "%s: error: buffer is nil\n", __func__);
+    metal_printf("%s: error: buffer is nil\n", __func__);
 
     return nil;
 }
@@ -283,7 +369,7 @@ bool ggml_metal_add_buffer(
                          size_t   size,
                          size_t   max_size) {
     if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
-        fprintf(stderr, "%s: too many buffers\n", __func__);
+        metal_printf("%s: too many buffers\n", __func__);
         return false;
     }
 
@@ -293,7 +379,7 @@ bool ggml_metal_add_buffer(
         const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
 
         if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
-            fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
+            metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
             return false;
         }
     }
@@ -314,11 +400,11 @@ bool ggml_metal_add_buffer(
             ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
             if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
+                metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
                 return false;
             }
 
-            fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
+            metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
 
             ++ctx->n_buffers;
         } else {
@@ -338,27 +424,27 @@ bool ggml_metal_add_buffer(
                 ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
                 if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                    fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
+                    metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
                     return false;
                 }
 
-                fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
+                metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
                 if (i + size_step < size) {
-                    fprintf(stderr, "\n");
+                    metal_printf("\n");
                 }
 
                 ++ctx->n_buffers;
             }
         }
 
-        fprintf(stderr, ", (%8.2f / %8.2f)",
+        metal_printf(", (%8.2f / %8.2f)",
                 ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
                 ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 
         if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
-            fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
+            metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
         } else {
-            fprintf(stderr, "\n");
+            metal_printf("\n");
         }
     }
 
@@ -368,8 +454,6 @@ bool ggml_metal_add_buffer(
 void ggml_metal_set_tensor(
         struct ggml_metal_context * ctx,
         struct ggml_tensor * t) {
-    metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
-
     size_t offs;
     id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
 
@@ -379,8 +463,6 @@ void ggml_metal_set_tensor(
 void ggml_metal_get_tensor(
         struct ggml_metal_context * ctx,
         struct ggml_tensor * t) {
-    metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
-
     size_t offs;
     id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
 
@@ -475,14 +557,14 @@ void ggml_metal_graph_find_concurrency(
     }
 
     if (ctx->concur_list_len > GGML_MAX_CONCUR) {
-        fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+        metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
     }
 }
 
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
                struct ggml_cgraph * gf) {
-    metal_printf("%s: evaluating graph\n", __func__);
+    @autoreleasepool {
 
     // if there is ctx->concur_list, dispatch concurrently
     // else fallback to serial dispatch
@@ -498,32 +580,28 @@ void ggml_metal_graph_compute(
 
     const int n_cb = ctx->n_cb;
 
-    NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
-
     for (int i = 0; i < n_cb; ++i) {
-        command_buffers[i] = [ctx->queue commandBuffer];
+        ctx->command_buffers[i] = [ctx->queue commandBuffer];
 
         // enqueue the command buffers in order to specify their execution order
-        [command_buffers[i] enqueue];
-    }
+        [ctx->command_buffers[i] enqueue];
 
-    // TODO: is this the best way to start threads?
-    dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+        ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
+    }
 
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
         const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
-        dispatch_async(queue, ^{
+        dispatch_async(ctx->d_queue, ^{
             size_t offs_src0 = 0;
             size_t offs_src1 = 0;
             size_t offs_dst  = 0;
 
-            id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
-
-            id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+            id<MTLCommandBuffer> command_buffer  = ctx->command_buffers[cb_idx];
+            id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
 
-            const int node_start = (cb_idx + 0) * n_nodes_per_cb;
-            const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+            const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+            const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
 
             for (int ind = node_start; ind < node_end; ++ind) {
                 const int i = has_concur ? ctx->concur_list[ind] : ind;
@@ -533,7 +611,7 @@ void ggml_metal_graph_compute(
                     continue;
                 }
 
-                metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+                //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
                 struct ggml_tensor * src0 = gf->nodes[i]->src[0];
                 struct ggml_tensor * src1 = gf->nodes[i]->src[1];
@@ -602,6 +680,12 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_ADD:
                         {
+                            GGML_ASSERT(ggml_is_contiguous(src0));
+
+                            // utilize float4
+                            GGML_ASSERT(ne00 % 4 == 0);
+                            const int64_t nb = ne00/4;
+
                             if (ggml_nelements(src1) == ne10) {
                                 // src1 is a row
                                 [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -611,14 +695,20 @@ void ggml_metal_graph_compute(
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                            [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
 
-                            const int64_t n = ggml_nelements(dst);
+                            const int64_t n = ggml_nelements(dst)/4;
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
                     case GGML_OP_MUL:
                         {
+                            GGML_ASSERT(ggml_is_contiguous(src0));
+
+                            // utilize float4
+                            GGML_ASSERT(ne00 % 4 == 0);
+                            const int64_t nb = ne00/4;
+
                             if (ggml_nelements(src1) == ne10) {
                                 // src1 is a row
                                 [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -628,9 +718,9 @@ void ggml_metal_graph_compute(
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                            [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
 
-                            const int64_t n = ggml_nelements(dst);
+                            const int64_t n = ggml_nelements(dst)/4;
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
@@ -681,7 +771,7 @@ void ggml_metal_graph_compute(
                             } break;
                         default:
                             {
-                                fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                                metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                                 GGML_ASSERT(false);
                             }
                     } break;
@@ -729,32 +819,32 @@ void ggml_metal_graph_compute(
                                 [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                                 ne00%32 == 0 &&
                                 ne11 > 1) {
-                                switch (src0->type) {
-                                    case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
-                                    case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
-                                    case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
-                                    case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
-                                    case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
-                                    case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
-                                    case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
-                                    case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
-                                    default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
-                                }
-                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
-                                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
-                                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
-                                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:8];
-                                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:9];
-                                [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:10];
-                                [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                                switch (src0->type) {
+                                    case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
+                                    case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
+                                    case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+                                    case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
+                                    case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
+                                    case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
+                                    case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
+                                    case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
+                                    case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
+                                    default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
                                 }
-                            else {
+                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
+                                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
+                                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+                                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:8];
+                                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:9];
+                                [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:10];
+                                [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                            } else {
                                 int nth0 = 32;
                                 int nth1 = 1;
 
@@ -762,7 +852,7 @@ void ggml_metal_graph_compute(
                                 switch (src0t) {
                                     case GGML_TYPE_F16:
                                         {
-                                            nth0 = 64;
+                                            nth0 = 32;
                                             nth1 = 1;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
                                         } break;
@@ -784,6 +874,15 @@ void ggml_metal_graph_compute(
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
                                         } break;
+                                    case GGML_TYPE_Q8_0:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
+                                        } break;
                                     case GGML_TYPE_Q2_K:
                                         {
                                             GGML_ASSERT(ne02 == 1);
@@ -831,7 +930,7 @@ void ggml_metal_graph_compute(
                                         } break;
                                     default:
                                         {
-                                            fprintf(stderr, "Asserting on type %d\n",(int)src0t);
+                                            metal_printf("Asserting on type %d\n",(int)src0t);
                                             GGML_ASSERT(false && "not implemented");
                                         }
                                 };
@@ -853,24 +952,24 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
                                 [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15];
                                 [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
-                                [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
+                                [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17];
 
-                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
                                     src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                                 }
                                 else if (src0t == GGML_TYPE_Q5_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q6_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
                                     [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -880,9 +979,10 @@ void ggml_metal_graph_compute(
                     case GGML_OP_GET_ROWS:
                         {
                             switch (src0->type) {
-                                case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];
+                                case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
                                 case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                                 case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+                                case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
                                 case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
                                 case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
                                 case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
@@ -923,16 +1023,17 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_NORM:
                         {
-                            const float eps = 1e-5f;
+                            float eps;
+                            memcpy(&eps, dst->op_params, sizeof(float));
 
                             const int nth = 256;
 
                             [encoder setComputePipelineState:ctx->pipeline_norm];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
-                            [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
+                            [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
                             [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
@@ -975,7 +1076,9 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
                             [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
                             [encoder setBytes:&m0  length:sizeof(   float) atIndex:18];
+
                             const int nth = 32;
+
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     case GGML_OP_ROPE:
@@ -990,8 +1093,8 @@ void ggml_metal_graph_compute(
                             memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
                             [encoder setComputePipelineState:ctx->pipeline_rope];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                             [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
                             [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
@@ -1042,30 +1145,30 @@ void ggml_metal_graph_compute(
                                 default: GGML_ASSERT(false && "not implemented");
                             }
 
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                            [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
-                            [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
-                            [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
-                            [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
-                            [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
-                            [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
-                            [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
-                            [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
+                            [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
+                            [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     default:
                         {
-                            fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                            metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                             GGML_ASSERT(false);
                         }
                 }
@@ -1081,17 +1184,19 @@ void ggml_metal_graph_compute(
     }
 
     // wait for all threads to finish
-    dispatch_barrier_sync(queue, ^{});
-
-    [command_buffers[n_cb - 1] waitUntilCompleted];
+    dispatch_barrier_sync(ctx->d_queue, ^{});
 
     // check status of command buffers
     // needed to detect if the device ran out-of-memory for example (#1881)
     for (int i = 0; i < n_cb; i++) {
-        MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
+        [ctx->command_buffers[i] waitUntilCompleted];
+
+        MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
         if (status != MTLCommandBufferStatusCompleted) {
-            fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
+            metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
             GGML_ASSERT(false);
         }
     }
+
+    }
 }