llama_cpp 0.4.0 → 0.5.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/examples/chat.rb +2 -2
- data/ext/llama_cpp/extconf.rb +1 -1
- data/ext/llama_cpp/llama_cpp.cpp +23 -11
- data/ext/llama_cpp/src/ggml-alloc.c +118 -73
- data/ext/llama_cpp/src/ggml-cuda.cu +106 -34
- data/ext/llama_cpp/src/ggml-metal.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.m +165 -72
- data/ext/llama_cpp/src/ggml-metal.metal +160 -89
- data/ext/llama_cpp/src/ggml-opencl.cpp +7 -7
- data/ext/llama_cpp/src/ggml.c +661 -380
- data/ext/llama_cpp/src/ggml.h +45 -19
- data/ext/llama_cpp/src/k_quants.c +47 -14
- data/ext/llama_cpp/src/llama.cpp +571 -166
- data/ext/llama_cpp/src/llama.h +54 -5
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +5 -3
- metadata +2 -2
@@ -11,6 +11,7 @@
|
|
11
11
|
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
12
12
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
13
13
|
|
14
|
+
// TODO: temporary - reuse llama.cpp logging
|
14
15
|
#ifdef GGML_METAL_NDEBUG
|
15
16
|
#define metal_printf(...)
|
16
17
|
#else
|
@@ -33,12 +34,15 @@ struct ggml_metal_buffer {
|
|
33
34
|
struct ggml_metal_context {
|
34
35
|
int n_cb;
|
35
36
|
|
36
|
-
float * logits;
|
37
|
-
|
38
37
|
id<MTLDevice> device;
|
39
38
|
id<MTLCommandQueue> queue;
|
40
39
|
id<MTLLibrary> library;
|
41
40
|
|
41
|
+
id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
|
42
|
+
id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
|
43
|
+
|
44
|
+
dispatch_queue_t d_queue;
|
45
|
+
|
42
46
|
int n_buffers;
|
43
47
|
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
|
44
48
|
|
@@ -72,6 +76,7 @@ struct ggml_metal_context {
|
|
72
76
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
73
77
|
GGML_METAL_DECL_KERNEL(norm);
|
74
78
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
79
|
+
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
|
75
80
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
76
81
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
77
82
|
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
|
@@ -110,16 +115,31 @@ static NSString * const msl_library_source = @"see metal.metal";
|
|
110
115
|
@end
|
111
116
|
|
112
117
|
struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
113
|
-
|
118
|
+
metal_printf("%s: allocating\n", __func__);
|
119
|
+
|
120
|
+
// Show all the Metal device instances in the system
|
121
|
+
NSArray * devices = MTLCopyAllDevices();
|
122
|
+
id <MTLDevice> device;
|
123
|
+
NSString * s;
|
124
|
+
for (device in devices) {
|
125
|
+
s = [device name];
|
126
|
+
metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
|
127
|
+
}
|
114
128
|
|
115
|
-
|
129
|
+
// Pick and show default Metal device
|
130
|
+
device = MTLCreateSystemDefaultDevice();
|
131
|
+
s = [device name];
|
132
|
+
metal_printf("%s: picking default device: %s\n", __func__, [s UTF8String]);
|
116
133
|
|
117
|
-
|
118
|
-
ctx
|
134
|
+
// Configure context
|
135
|
+
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
|
136
|
+
ctx->device = device;
|
137
|
+
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
119
138
|
ctx->queue = [ctx->device newCommandQueue];
|
120
139
|
ctx->n_buffers = 0;
|
121
140
|
ctx->concur_list_len = 0;
|
122
141
|
|
142
|
+
ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
|
123
143
|
|
124
144
|
#if 0
|
125
145
|
// compile from source string and show compile log
|
@@ -128,7 +148,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
128
148
|
|
129
149
|
ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
|
130
150
|
if (error) {
|
131
|
-
|
151
|
+
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
132
152
|
return NULL;
|
133
153
|
}
|
134
154
|
}
|
@@ -142,11 +162,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
142
162
|
//NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
|
143
163
|
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
144
164
|
NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
145
|
-
|
165
|
+
metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
|
146
166
|
|
147
167
|
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
|
148
168
|
if (error) {
|
149
|
-
|
169
|
+
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
150
170
|
return NULL;
|
151
171
|
}
|
152
172
|
|
@@ -158,7 +178,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
158
178
|
ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
|
159
179
|
#endif
|
160
180
|
if (error) {
|
161
|
-
|
181
|
+
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
162
182
|
return NULL;
|
163
183
|
}
|
164
184
|
}
|
@@ -170,11 +190,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
170
190
|
#define GGML_METAL_ADD_KERNEL(name) \
|
171
191
|
ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
|
172
192
|
ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
|
173
|
-
|
193
|
+
metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
|
174
194
|
(int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
|
175
195
|
(int) ctx->pipeline_##name.threadExecutionWidth); \
|
176
196
|
if (error) { \
|
177
|
-
|
197
|
+
metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
178
198
|
return NULL; \
|
179
199
|
}
|
180
200
|
|
@@ -200,6 +220,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
200
220
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
201
221
|
GGML_METAL_ADD_KERNEL(norm);
|
202
222
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
223
|
+
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
|
203
224
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
204
225
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
205
226
|
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
|
@@ -226,30 +247,89 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
226
247
|
#undef GGML_METAL_ADD_KERNEL
|
227
248
|
}
|
228
249
|
|
229
|
-
|
230
|
-
|
250
|
+
metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
251
|
+
metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
231
252
|
if (ctx->device.maxTransferRate != 0) {
|
232
|
-
|
253
|
+
metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
233
254
|
} else {
|
234
|
-
|
255
|
+
metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
|
235
256
|
}
|
236
257
|
|
237
258
|
return ctx;
|
238
259
|
}
|
239
260
|
|
240
261
|
void ggml_metal_free(struct ggml_metal_context * ctx) {
|
241
|
-
|
262
|
+
metal_printf("%s: deallocating\n", __func__);
|
263
|
+
#define GGML_METAL_DEL_KERNEL(name) \
|
264
|
+
[ctx->function_##name release]; \
|
265
|
+
[ctx->pipeline_##name release];
|
266
|
+
|
267
|
+
GGML_METAL_DEL_KERNEL(add);
|
268
|
+
GGML_METAL_DEL_KERNEL(add_row);
|
269
|
+
GGML_METAL_DEL_KERNEL(mul);
|
270
|
+
GGML_METAL_DEL_KERNEL(mul_row);
|
271
|
+
GGML_METAL_DEL_KERNEL(scale);
|
272
|
+
GGML_METAL_DEL_KERNEL(silu);
|
273
|
+
GGML_METAL_DEL_KERNEL(relu);
|
274
|
+
GGML_METAL_DEL_KERNEL(gelu);
|
275
|
+
GGML_METAL_DEL_KERNEL(soft_max);
|
276
|
+
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
277
|
+
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
278
|
+
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
279
|
+
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
280
|
+
GGML_METAL_DEL_KERNEL(get_rows_q8_0);
|
281
|
+
GGML_METAL_DEL_KERNEL(get_rows_q2_K);
|
282
|
+
GGML_METAL_DEL_KERNEL(get_rows_q3_K);
|
283
|
+
GGML_METAL_DEL_KERNEL(get_rows_q4_K);
|
284
|
+
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
|
285
|
+
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
286
|
+
GGML_METAL_DEL_KERNEL(rms_norm);
|
287
|
+
GGML_METAL_DEL_KERNEL(norm);
|
288
|
+
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
289
|
+
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
|
290
|
+
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
|
291
|
+
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
|
292
|
+
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
|
293
|
+
GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
|
294
|
+
GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
|
295
|
+
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
|
296
|
+
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
|
297
|
+
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
|
298
|
+
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
299
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
|
300
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
|
301
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
|
302
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
|
303
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
|
304
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
|
305
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
|
306
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
|
307
|
+
GGML_METAL_DEL_KERNEL(rope);
|
308
|
+
GGML_METAL_DEL_KERNEL(alibi_f32);
|
309
|
+
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
310
|
+
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
311
|
+
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
312
|
+
|
313
|
+
#undef GGML_METAL_DEL_KERNEL
|
314
|
+
|
242
315
|
for (int i = 0; i < ctx->n_buffers; ++i) {
|
243
316
|
[ctx->buffers[i].metal release];
|
244
317
|
}
|
318
|
+
|
319
|
+
[ctx->library release];
|
320
|
+
[ctx->queue release];
|
321
|
+
[ctx->device release];
|
322
|
+
|
323
|
+
dispatch_release(ctx->d_queue);
|
324
|
+
|
245
325
|
free(ctx);
|
246
326
|
}
|
247
327
|
|
248
328
|
void * ggml_metal_host_malloc(size_t n) {
|
249
329
|
void * data = NULL;
|
250
|
-
const int result = posix_memalign((void **) &data,
|
330
|
+
const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
|
251
331
|
if (result != 0) {
|
252
|
-
|
332
|
+
metal_printf("%s: error: posix_memalign failed\n", __func__);
|
253
333
|
return NULL;
|
254
334
|
}
|
255
335
|
|
@@ -261,7 +341,7 @@ void ggml_metal_host_free(void * data) {
|
|
261
341
|
}
|
262
342
|
|
263
343
|
void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
|
264
|
-
ctx->n_cb = n_cb;
|
344
|
+
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
265
345
|
}
|
266
346
|
|
267
347
|
int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
|
@@ -277,7 +357,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
|
|
277
357
|
// Metal buffer based on the host memory pointer
|
278
358
|
//
|
279
359
|
static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
|
280
|
-
//
|
360
|
+
//metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
|
281
361
|
|
282
362
|
const int64_t tsize = ggml_nbytes(t);
|
283
363
|
|
@@ -288,13 +368,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
|
288
368
|
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
|
289
369
|
*offs = (size_t) ioffs;
|
290
370
|
|
291
|
-
//
|
371
|
+
//metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
|
292
372
|
|
293
373
|
return ctx->buffers[i].metal;
|
294
374
|
}
|
295
375
|
}
|
296
376
|
|
297
|
-
|
377
|
+
metal_printf("%s: error: buffer is nil\n", __func__);
|
298
378
|
|
299
379
|
return nil;
|
300
380
|
}
|
@@ -306,7 +386,7 @@ bool ggml_metal_add_buffer(
|
|
306
386
|
size_t size,
|
307
387
|
size_t max_size) {
|
308
388
|
if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
|
309
|
-
|
389
|
+
metal_printf("%s: too many buffers\n", __func__);
|
310
390
|
return false;
|
311
391
|
}
|
312
392
|
|
@@ -316,12 +396,12 @@ bool ggml_metal_add_buffer(
|
|
316
396
|
const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
|
317
397
|
|
318
398
|
if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
|
319
|
-
|
399
|
+
metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
|
320
400
|
return false;
|
321
401
|
}
|
322
402
|
}
|
323
403
|
|
324
|
-
const size_t size_page =
|
404
|
+
const size_t size_page = sysconf(_SC_PAGESIZE);
|
325
405
|
|
326
406
|
size_t size_aligned = size;
|
327
407
|
if ((size_aligned % size_page) != 0) {
|
@@ -337,11 +417,11 @@ bool ggml_metal_add_buffer(
|
|
337
417
|
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
338
418
|
|
339
419
|
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
340
|
-
|
420
|
+
metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
|
341
421
|
return false;
|
342
422
|
}
|
343
423
|
|
344
|
-
|
424
|
+
metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
|
345
425
|
|
346
426
|
++ctx->n_buffers;
|
347
427
|
} else {
|
@@ -361,27 +441,27 @@ bool ggml_metal_add_buffer(
|
|
361
441
|
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
362
442
|
|
363
443
|
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
364
|
-
|
444
|
+
metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
|
365
445
|
return false;
|
366
446
|
}
|
367
447
|
|
368
|
-
|
448
|
+
metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
|
369
449
|
if (i + size_step < size) {
|
370
|
-
|
450
|
+
metal_printf("\n");
|
371
451
|
}
|
372
452
|
|
373
453
|
++ctx->n_buffers;
|
374
454
|
}
|
375
455
|
}
|
376
456
|
|
377
|
-
|
457
|
+
metal_printf(", (%8.2f / %8.2f)",
|
378
458
|
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
|
379
459
|
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
380
460
|
|
381
461
|
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
|
382
|
-
|
462
|
+
metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
|
383
463
|
} else {
|
384
|
-
|
464
|
+
metal_printf("\n");
|
385
465
|
}
|
386
466
|
}
|
387
467
|
|
@@ -391,8 +471,6 @@ bool ggml_metal_add_buffer(
|
|
391
471
|
void ggml_metal_set_tensor(
|
392
472
|
struct ggml_metal_context * ctx,
|
393
473
|
struct ggml_tensor * t) {
|
394
|
-
metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
|
395
|
-
|
396
474
|
size_t offs;
|
397
475
|
id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
|
398
476
|
|
@@ -402,8 +480,6 @@ void ggml_metal_set_tensor(
|
|
402
480
|
void ggml_metal_get_tensor(
|
403
481
|
struct ggml_metal_context * ctx,
|
404
482
|
struct ggml_tensor * t) {
|
405
|
-
metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
|
406
|
-
|
407
483
|
size_t offs;
|
408
484
|
id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
|
409
485
|
|
@@ -498,14 +574,14 @@ void ggml_metal_graph_find_concurrency(
|
|
498
574
|
}
|
499
575
|
|
500
576
|
if (ctx->concur_list_len > GGML_MAX_CONCUR) {
|
501
|
-
|
577
|
+
metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
|
502
578
|
}
|
503
579
|
}
|
504
580
|
|
505
581
|
void ggml_metal_graph_compute(
|
506
582
|
struct ggml_metal_context * ctx,
|
507
583
|
struct ggml_cgraph * gf) {
|
508
|
-
|
584
|
+
@autoreleasepool {
|
509
585
|
|
510
586
|
// if there is ctx->concur_list, dispatch concurrently
|
511
587
|
// else fallback to serial dispatch
|
@@ -521,29 +597,25 @@ void ggml_metal_graph_compute(
|
|
521
597
|
|
522
598
|
const int n_cb = ctx->n_cb;
|
523
599
|
|
524
|
-
NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
|
525
|
-
|
526
600
|
for (int i = 0; i < n_cb; ++i) {
|
527
|
-
command_buffers[i] = [ctx->queue commandBuffer];
|
601
|
+
ctx->command_buffers[i] = [ctx->queue commandBuffer];
|
528
602
|
|
529
603
|
// enqueue the command buffers in order to specify their execution order
|
530
|
-
[command_buffers[i] enqueue];
|
531
|
-
}
|
604
|
+
[ctx->command_buffers[i] enqueue];
|
532
605
|
|
533
|
-
|
534
|
-
|
606
|
+
ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
|
607
|
+
}
|
535
608
|
|
536
609
|
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
537
610
|
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
|
538
611
|
|
539
|
-
dispatch_async(
|
612
|
+
dispatch_async(ctx->d_queue, ^{
|
540
613
|
size_t offs_src0 = 0;
|
541
614
|
size_t offs_src1 = 0;
|
542
615
|
size_t offs_dst = 0;
|
543
616
|
|
544
|
-
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
|
545
|
-
|
546
|
-
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
617
|
+
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
|
618
|
+
id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
|
547
619
|
|
548
620
|
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
549
621
|
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
|
@@ -556,7 +628,7 @@ void ggml_metal_graph_compute(
|
|
556
628
|
continue;
|
557
629
|
}
|
558
630
|
|
559
|
-
metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
631
|
+
//metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
560
632
|
|
561
633
|
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
562
634
|
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
|
@@ -625,6 +697,12 @@ void ggml_metal_graph_compute(
|
|
625
697
|
} break;
|
626
698
|
case GGML_OP_ADD:
|
627
699
|
{
|
700
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
701
|
+
|
702
|
+
// utilize float4
|
703
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
704
|
+
const int64_t nb = ne00/4;
|
705
|
+
|
628
706
|
if (ggml_nelements(src1) == ne10) {
|
629
707
|
// src1 is a row
|
630
708
|
[encoder setComputePipelineState:ctx->pipeline_add_row];
|
@@ -634,14 +712,20 @@ void ggml_metal_graph_compute(
|
|
634
712
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
635
713
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
636
714
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
637
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
715
|
+
[encoder setBytes:&nb length:sizeof(nb) atIndex:3];
|
638
716
|
|
639
|
-
const int64_t n = ggml_nelements(dst);
|
717
|
+
const int64_t n = ggml_nelements(dst)/4;
|
640
718
|
|
641
719
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
642
720
|
} break;
|
643
721
|
case GGML_OP_MUL:
|
644
722
|
{
|
723
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
724
|
+
|
725
|
+
// utilize float4
|
726
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
727
|
+
const int64_t nb = ne00/4;
|
728
|
+
|
645
729
|
if (ggml_nelements(src1) == ne10) {
|
646
730
|
// src1 is a row
|
647
731
|
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
@@ -651,9 +735,9 @@ void ggml_metal_graph_compute(
|
|
651
735
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
652
736
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
653
737
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
654
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
738
|
+
[encoder setBytes:&nb length:sizeof(nb) atIndex:3];
|
655
739
|
|
656
|
-
const int64_t n = ggml_nelements(dst);
|
740
|
+
const int64_t n = ggml_nelements(dst)/4;
|
657
741
|
|
658
742
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
659
743
|
} break;
|
@@ -704,7 +788,7 @@ void ggml_metal_graph_compute(
|
|
704
788
|
} break;
|
705
789
|
default:
|
706
790
|
{
|
707
|
-
|
791
|
+
metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
708
792
|
GGML_ASSERT(false);
|
709
793
|
}
|
710
794
|
} break;
|
@@ -785,9 +869,13 @@ void ggml_metal_graph_compute(
|
|
785
869
|
switch (src0t) {
|
786
870
|
case GGML_TYPE_F16:
|
787
871
|
{
|
788
|
-
nth0 =
|
872
|
+
nth0 = 32;
|
789
873
|
nth1 = 1;
|
790
|
-
|
874
|
+
if (ne11 * ne12 < 4) {
|
875
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
|
876
|
+
} else {
|
877
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
878
|
+
}
|
791
879
|
} break;
|
792
880
|
case GGML_TYPE_Q4_0:
|
793
881
|
{
|
@@ -839,8 +927,8 @@ void ggml_metal_graph_compute(
|
|
839
927
|
GGML_ASSERT(ne02 == 1);
|
840
928
|
GGML_ASSERT(ne12 == 1);
|
841
929
|
|
842
|
-
nth0 =
|
843
|
-
nth1 = 32;
|
930
|
+
nth0 = 4; //1;
|
931
|
+
nth1 = 8; //32;
|
844
932
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
|
845
933
|
} break;
|
846
934
|
case GGML_TYPE_Q5_K:
|
@@ -863,7 +951,7 @@ void ggml_metal_graph_compute(
|
|
863
951
|
} break;
|
864
952
|
default:
|
865
953
|
{
|
866
|
-
|
954
|
+
metal_printf("Asserting on type %d\n",(int)src0t);
|
867
955
|
GGML_ASSERT(false && "not implemented");
|
868
956
|
}
|
869
957
|
};
|
@@ -888,9 +976,12 @@ void ggml_metal_graph_compute(
|
|
888
976
|
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
889
977
|
|
890
978
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
|
891
|
-
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
979
|
+
src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
|
892
980
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
893
981
|
}
|
982
|
+
else if (src0t == GGML_TYPE_Q4_K) {
|
983
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
984
|
+
}
|
894
985
|
else if (src0t == GGML_TYPE_Q3_K) {
|
895
986
|
#ifdef GGML_QKK_64
|
896
987
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -904,8 +995,8 @@ void ggml_metal_graph_compute(
|
|
904
995
|
else if (src0t == GGML_TYPE_Q6_K) {
|
905
996
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
906
997
|
} else {
|
907
|
-
|
908
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01,
|
998
|
+
int64_t ny = (ne11 + 3)/4;
|
999
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
909
1000
|
}
|
910
1001
|
}
|
911
1002
|
} break;
|
@@ -1050,7 +1141,7 @@ void ggml_metal_graph_compute(
|
|
1050
1141
|
[encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
|
1051
1142
|
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
|
1052
1143
|
|
1053
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(
|
1144
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
1054
1145
|
} break;
|
1055
1146
|
case GGML_OP_DUP:
|
1056
1147
|
case GGML_OP_CPY:
|
@@ -1101,7 +1192,7 @@ void ggml_metal_graph_compute(
|
|
1101
1192
|
} break;
|
1102
1193
|
default:
|
1103
1194
|
{
|
1104
|
-
|
1195
|
+
metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
1105
1196
|
GGML_ASSERT(false);
|
1106
1197
|
}
|
1107
1198
|
}
|
@@ -1117,17 +1208,19 @@ void ggml_metal_graph_compute(
|
|
1117
1208
|
}
|
1118
1209
|
|
1119
1210
|
// wait for all threads to finish
|
1120
|
-
dispatch_barrier_sync(
|
1121
|
-
|
1122
|
-
[command_buffers[n_cb - 1] waitUntilCompleted];
|
1211
|
+
dispatch_barrier_sync(ctx->d_queue, ^{});
|
1123
1212
|
|
1124
1213
|
// check status of command buffers
|
1125
1214
|
// needed to detect if the device ran out-of-memory for example (#1881)
|
1126
1215
|
for (int i = 0; i < n_cb; i++) {
|
1127
|
-
|
1216
|
+
[ctx->command_buffers[i] waitUntilCompleted];
|
1217
|
+
|
1218
|
+
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
|
1128
1219
|
if (status != MTLCommandBufferStatusCompleted) {
|
1129
|
-
|
1220
|
+
metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
1130
1221
|
GGML_ASSERT(false);
|
1131
1222
|
}
|
1132
1223
|
}
|
1224
|
+
|
1225
|
+
}
|
1133
1226
|
}
|