llama_cpp 0.4.0 → 0.5.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/examples/chat.rb +2 -2
- data/ext/llama_cpp/extconf.rb +1 -1
- data/ext/llama_cpp/llama_cpp.cpp +23 -11
- data/ext/llama_cpp/src/ggml-alloc.c +13 -50
- data/ext/llama_cpp/src/ggml-cuda.cu +23 -11
- data/ext/llama_cpp/src/ggml-metal.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.m +130 -61
- data/ext/llama_cpp/src/ggml-metal.metal +44 -26
- data/ext/llama_cpp/src/ggml.c +637 -328
- data/ext/llama_cpp/src/ggml.h +45 -19
- data/ext/llama_cpp/src/k_quants.c +2 -2
- data/ext/llama_cpp/src/llama.cpp +426 -97
- data/ext/llama_cpp/src/llama.h +51 -5
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +5 -3
- metadata +2 -2
data/ext/llama_cpp/src/ggml-metal.m:

@@ -11,6 +11,7 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+// TODO: temporary - reuse llama.cpp logging
 #ifdef GGML_METAL_NDEBUG
 #define metal_printf(...)
 #else
@@ -33,12 +34,15 @@ struct ggml_metal_buffer {
 struct ggml_metal_context {
     int n_cb;
 
-    float * logits;
-
     id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
     id<MTLLibrary>      library;
 
+    id<MTLCommandBuffer>         command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
+    id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
+
+    dispatch_queue_t d_queue;
+
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
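Note: the context now owns statically sized pools of command buffers and encoders, plus a long-lived concurrent dispatch queue, instead of building them on every graph evaluation. A minimal C++ sketch of the same idea (hypothetical type and constant names, not the library's API):

    #include <array>
    #include <cstddef>

    // Hypothetical stand-ins for the Metal handle types.
    struct CommandBuffer  {};
    struct CommandEncoder {};

    constexpr std::size_t kMaxCommandBuffers = 32;

    struct Context {
        int n_cb = 0; // how many pooled entries are currently in use

        // Created once with the context and reused for every graph
        // evaluation, so no per-call container allocation is needed.
        std::array<CommandBuffer,  kMaxCommandBuffers> command_buffers{};
        std::array<CommandEncoder, kMaxCommandBuffers> command_encoders{};
    };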
@@ -110,16 +114,17 @@ static NSString * const msl_library_source = @"see metal.metal";
 @end
 
 struct ggml_metal_context * ggml_metal_init(int n_cb) {
-    fprintf(stderr, "%s: allocating\n", __func__);
+    metal_printf("%s: allocating\n", __func__);
 
     struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
 
-    ctx->n_cb = n_cb;
+    ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
     ctx->concur_list_len = 0;
 
+    ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
 #if 0
     // compile from source string and show compile log
@@ -128,7 +133,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 
         ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
         if (error) {
-            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
     }
@@ -142,11 +147,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
         NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
         NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
-        fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
+        metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
 
         NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
         if (error) {
-            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
 
@@ -158,7 +163,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
 #endif
         if (error) {
-            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
     }
@@ -170,11 +175,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #define GGML_METAL_ADD_KERNEL(name) \
         ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
         ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
-        fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+        metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
                 (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
                 (int) ctx->pipeline_##name.threadExecutionWidth); \
         if (error) { \
-            fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+            metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
             return NULL; \
         }
 
@@ -226,22 +231,80 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #undef GGML_METAL_ADD_KERNEL
     }
 
-    fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-    fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+    metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
     if (ctx->device.maxTransferRate != 0) {
-        fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+        metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
     } else {
-        fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
+        metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
     }
 
     return ctx;
 }
 
 void ggml_metal_free(struct ggml_metal_context * ctx) {
-    fprintf(stderr, "%s: deallocating\n", __func__);
+    metal_printf("%s: deallocating\n", __func__);
+#define GGML_METAL_DEL_KERNEL(name) \
+    [ctx->function_##name release]; \
+    [ctx->pipeline_##name release];
+
+    GGML_METAL_DEL_KERNEL(add);
+    GGML_METAL_DEL_KERNEL(add_row);
+    GGML_METAL_DEL_KERNEL(mul);
+    GGML_METAL_DEL_KERNEL(mul_row);
+    GGML_METAL_DEL_KERNEL(scale);
+    GGML_METAL_DEL_KERNEL(silu);
+    GGML_METAL_DEL_KERNEL(relu);
+    GGML_METAL_DEL_KERNEL(gelu);
+    GGML_METAL_DEL_KERNEL(soft_max);
+    GGML_METAL_DEL_KERNEL(diag_mask_inf);
+    GGML_METAL_DEL_KERNEL(get_rows_f16);
+    GGML_METAL_DEL_KERNEL(get_rows_q4_0);
+    GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DEL_KERNEL(get_rows_q8_0);
+    GGML_METAL_DEL_KERNEL(get_rows_q2_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q3_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q4_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q5_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q6_K);
+    GGML_METAL_DEL_KERNEL(rms_norm);
+    GGML_METAL_DEL_KERNEL(norm);
+    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(rope);
+    GGML_METAL_DEL_KERNEL(alibi_f32);
+    GGML_METAL_DEL_KERNEL(cpy_f32_f16);
+    GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+    GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+
+#undef GGML_METAL_DEL_KERNEL
+
     for (int i = 0; i < ctx->n_buffers; ++i) {
         [ctx->buffers[i].metal release];
     }
+
+    [ctx->library release];
+    [ctx->queue release];
+    [ctx->device release];
+
+    dispatch_release(ctx->d_queue);
+
     free(ctx);
 }
 
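Note: ggml_metal_free now mirrors ggml_metal_init: every pipeline/function pair registered by GGML_METAL_ADD_KERNEL is released by a matching GGML_METAL_DEL_KERNEL, and the library, queue, device, and dispatch queue are released too, so repeatedly creating and freeing contexts no longer leaks. In C++ terms this is the "destructor mirrors constructor" rule; a sketch under that reading (hypothetical types, not the library's API):

    #include <memory>
    #include <vector>

    struct Pipeline { /* GPU pipeline state */ };

    // Acquire in the constructor, release automatically in the destructor -
    // the pairing the C init/free functions enforce by hand.
    class MetalLikeContext {
    public:
        MetalLikeContext() {
            for (int i = 0; i < 4; ++i) {
                pipelines_.push_back(std::make_unique<Pipeline>());
            }
        }
        // std::unique_ptr members give the "del mirrors add" behavior for
        // free; nothing created in the constructor outlives *this.
    private:
        std::vector<std::unique_ptr<Pipeline>> pipelines_;
    };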
@@ -249,7 +312,7 @@ void * ggml_metal_host_malloc(size_t n) {
     void * data = NULL;
     const int result = posix_memalign((void **) &data, getpagesize(), n);
     if (result != 0) {
-        fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
+        metal_printf("%s: error: posix_memalign failed\n", __func__);
         return NULL;
     }
 
@@ -261,7 +324,7 @@ void ggml_metal_host_free(void * data) {
 }
 
 void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
-    ctx->n_cb = n_cb;
+    ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
 }
 
 int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
@@ -277,7 +340,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
 // Metal buffer based on the host memory pointer
 //
 static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
-    //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
+    //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
 
     const int64_t tsize = ggml_nbytes(t);
 
@@ -288,13 +351,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
         if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
             *offs = (size_t) ioffs;
 
-            //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
+            //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
 
             return ctx->buffers[i].metal;
         }
     }
 
-    fprintf(stderr, "%s: error: buffer is nil\n", __func__);
+    metal_printf("%s: error: buffer is nil\n", __func__);
 
     return nil;
 }
@@ -306,7 +369,7 @@ bool ggml_metal_add_buffer(
         size_t size,
         size_t max_size) {
     if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
-        fprintf(stderr, "%s: too many buffers\n", __func__);
+        metal_printf("%s: too many buffers\n", __func__);
         return false;
     }
 
@@ -316,7 +379,7 @@ bool ggml_metal_add_buffer(
         const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
 
         if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
-            fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
+            metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
             return false;
         }
     }
@@ -337,11 +400,11 @@ bool ggml_metal_add_buffer(
         ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
         if (ctx->buffers[ctx->n_buffers].metal == nil) {
-            fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
+            metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
             return false;
         }
 
-        fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
+        metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
 
         ++ctx->n_buffers;
     } else {
@@ -361,27 +424,27 @@ bool ggml_metal_add_buffer(
             ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
             if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
+                metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
                 return false;
             }
 
-            fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
+            metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
             if (i + size_step < size) {
-                fprintf(stderr, "\n");
+                metal_printf("\n");
             }
 
             ++ctx->n_buffers;
         }
     }
 
-    fprintf(stderr, ", (%8.2f / %8.2f)",
+    metal_printf(", (%8.2f / %8.2f)",
             ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
             ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 
     if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
-        fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
+        metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
     } else {
-        fprintf(stderr, "\n");
+        metal_printf("\n");
     }
 }
 
@@ -391,8 +454,6 @@ bool ggml_metal_add_buffer(
 void ggml_metal_set_tensor(
         struct ggml_metal_context * ctx,
         struct ggml_tensor * t) {
-    metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
-
     size_t offs;
     id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
 
@@ -402,8 +463,6 @@ void ggml_metal_set_tensor(
 void ggml_metal_get_tensor(
         struct ggml_metal_context * ctx,
         struct ggml_tensor * t) {
-    metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
-
     size_t offs;
     id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
 
@@ -498,14 +557,14 @@ void ggml_metal_graph_find_concurrency(
     }
 
     if (ctx->concur_list_len > GGML_MAX_CONCUR) {
-        fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+        metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
     }
 }
 
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
         struct ggml_cgraph * gf) {
-    metal_printf("%s: evaluating graph\n", __func__);
+    @autoreleasepool {
 
     // if there is ctx->concur_list, dispatch concurrently
     // else fallback to serial dispatch
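Note: the whole body of ggml_metal_graph_compute is now wrapped in @autoreleasepool, so Objective-C temporaries created while encoding are drained on every call instead of accumulating. A rough C++ analogue of the scoping (hypothetical names, just to show the shape):

    #include <vector>

    struct ScopedPool {
        std::vector<void (*)()> cleanups; // hypothetical deferred releases
        ~ScopedPool() {
            for (auto f : cleanups) f(); // run deferred releases at scope exit
        }
    };

    void graph_compute_like() {
        ScopedPool pool; // drained when the function returns, on each call
        // ... encode and run work, registering cleanups with `pool` ...
    }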
@@ -521,29 +580,25 @@ void ggml_metal_graph_compute(
 
     const int n_cb = ctx->n_cb;
 
-    NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
-
     for (int i = 0; i < n_cb; ++i) {
-        command_buffers[i] = [ctx->queue commandBuffer];
+        ctx->command_buffers[i] = [ctx->queue commandBuffer];
 
         // enqueue the command buffers in order to specify their execution order
-        [command_buffers[i] enqueue];
-    }
+        [ctx->command_buffers[i] enqueue];
 
-    // TODO: is this the best way to start threads?
-    dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+        ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
+    }
 
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
         const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
-        dispatch_async(queue, ^{
+        dispatch_async(ctx->d_queue, ^{
             size_t offs_src0 = 0;
             size_t offs_src1 = 0;
             size_t offs_dst  = 0;
 
-            id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
-
-            id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+            id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
+            id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
 
             const int node_start = (cb_idx + 0) * n_nodes_per_cb;
             const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
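Note: the graph nodes are split across n_cb command buffers with a ceiling division, so every buffer gets at most n_nodes_per_cb nodes and the last one takes the remainder. A small self-contained C++ sketch of the same partitioning arithmetic:

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int n_nodes = 10;
        const int n_cb    = 3;

        // ceil(n_nodes / n_cb): 10 nodes over 3 buffers -> 4 per buffer
        const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;

        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
            const int node_start = cb_idx * n_nodes_per_cb;
            const int node_end   = std::min((cb_idx == n_cb - 1) ? n_nodes
                                            : (cb_idx + 1) * n_nodes_per_cb,
                                            n_nodes);
            std::printf("cb %d -> nodes [%d, %d)\n", cb_idx, node_start, node_end);
            // prints: [0, 4), [4, 8), [8, 10)
        }
        return 0;
    }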
@@ -556,7 +611,7 @@ void ggml_metal_graph_compute(
                 continue;
             }
 
-            metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+            //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
             struct ggml_tensor * src0 = gf->nodes[i]->src[0];
             struct ggml_tensor * src1 = gf->nodes[i]->src[1];
@@ -625,6 +680,12 @@ void ggml_metal_graph_compute(
                     } break;
                 case GGML_OP_ADD:
                     {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
+                        // utilize float4
+                        GGML_ASSERT(ne00 % 4 == 0);
+                        const int64_t nb = ne00/4;
+
                         if (ggml_nelements(src1) == ne10) {
                             // src1 is a row
                             [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -634,14 +695,20 @@ void ggml_metal_graph_compute(
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                        [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
 
-                        const int64_t n = ggml_nelements(dst);
+                        const int64_t n = ggml_nelements(dst)/4;
 
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 case GGML_OP_MUL:
                     {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
+                        // utilize float4
+                        GGML_ASSERT(ne00 % 4 == 0);
+                        const int64_t nb = ne00/4;
+
                         if (ggml_nelements(src1) == ne10) {
                             // src1 is a row
                             [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -651,9 +718,9 @@ void ggml_metal_graph_compute(
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                        [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
 
-                        const int64_t n = ggml_nelements(dst);
+                        const int64_t n = ggml_nelements(dst)/4;
 
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
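Note: ADD and MUL now assert contiguity and ne00 % 4 == 0, pass nb = ne00/4 to the kernel, and dispatch ggml_nelements(dst)/4 threads, so each thread handles one float4 (four scalar elements). A hedged C++ sketch of the bookkeeping, with plain arrays standing in for Metal buffers and a local float4 struct (not the MSL type):

    #include <cassert>
    #include <cstdint>

    struct float4 { float v[4]; }; // four scalars per "thread"

    void add_row_vec4(const float4 * src0, const float4 * src1, float4 * dst,
                      int64_t n_elements, int64_t ne00) {
        assert(ne00 % 4 == 0);             // same precondition as the GGML_ASSERT
        const int64_t nb = ne00 / 4;       // row length in float4 units
        const int64_t n  = n_elements / 4; // one "thread" per float4
        for (int64_t tpig = 0; tpig < n; ++tpig) {
            for (int k = 0; k < 4; ++k) {  // src1 is a row, broadcast with % nb
                dst[tpig].v[k] = src0[tpig].v[k] + src1[tpig % nb].v[k];
            }
        }
    }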
@@ -704,7 +771,7 @@ void ggml_metal_graph_compute(
                     } break;
                 default:
                     {
-                        fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                        metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                         GGML_ASSERT(false);
                     }
                 } break;
@@ -785,7 +852,7 @@ void ggml_metal_graph_compute(
                         switch (src0t) {
                             case GGML_TYPE_F16:
                                 {
-                                    nth0 = 64;
+                                    nth0 = 32;
                                     nth1 = 1;
                                     [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
                                 } break;
@@ -863,7 +930,7 @@ void ggml_metal_graph_compute(
                                 } break;
                             default:
                                 {
-                                    fprintf(stderr, "Asserting on type %d\n",(int)src0t);
+                                    metal_printf("Asserting on type %d\n",(int)src0t);
                                     GGML_ASSERT(false && "not implemented");
                                 }
                         };
@@ -1101,7 +1168,7 @@ void ggml_metal_graph_compute(
                     } break;
                 default:
                     {
-                        fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                        metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                         GGML_ASSERT(false);
                     }
                 }
@@ -1117,17 +1184,19 @@ void ggml_metal_graph_compute(
     }
 
     // wait for all threads to finish
-    dispatch_barrier_sync(queue, ^{});
-
-    [command_buffers[n_cb - 1] waitUntilCompleted];
+    dispatch_barrier_sync(ctx->d_queue, ^{});
 
     // check status of command buffers
     // needed to detect if the device ran out-of-memory for example (#1881)
     for (int i = 0; i < n_cb; i++) {
-        MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
+        [ctx->command_buffers[i] waitUntilCompleted];
+
+        MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
         if (status != MTLCommandBufferStatusCompleted) {
-            fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
+            metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
             GGML_ASSERT(false);
         }
     }
+
+    }
 }
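Note: completion handling changed above: a barrier on the concurrent dispatch queue waits for all encoder threads, then every command buffer is waited on individually and its status checked, so a device failure such as out-of-memory (#1881) surfaces as an assertion instead of silently producing garbage. The shape of the pattern in portable C++, with std::async standing in for the dispatch queue (a sketch, not the Metal API):

    #include <cassert>
    #include <future>
    #include <vector>

    enum class Status { Completed, Error };

    int main() {
        std::vector<std::future<Status>> jobs;
        for (int i = 0; i < 3; ++i) {
            // each job stands in for one command buffer being encoded and run
            jobs.push_back(std::async(std::launch::async,
                                      [] { return Status::Completed; }));
        }
        // wait for every job, then check each result - mirroring
        // waitUntilCompleted followed by the status != Completed check
        for (auto & job : jobs) {
            const Status status = job.get(); // blocks until this job is done
            assert(status == Status::Completed);
        }
        return 0;
    }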
data/ext/llama_cpp/src/ggml-metal.metal:

@@ -25,9 +25,9 @@ typedef struct {
 } block_q8_0;
 
 kernel void kernel_add(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig];
 }
@@ -35,18 +35,18 @@ kernel void kernel_add(
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_add_row(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] + src1[tpig % ne00];
+    dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
 kernel void kernel_mul(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * src1[tpig];
 }
@@ -54,12 +54,12 @@ kernel void kernel_mul(
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_mul_row(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src1[tpig % ne00];
+    dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
 
 kernel void kernel_scale(
@@ -528,24 +528,42 @@ kernel void kernel_mul_mat_f16_f32(
     device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
-    sum[tpitg.x] = 0.0f;
+    uint ith = tpitg.x;
+    uint nth = tptg.x;
+
+    sum[ith] = 0.0f;
 
-    for (int i = tpitg.x; i < ne00; i += tptg.x) {
-        sum[tpitg.x] += (float) x[i] * (float) y[i];
+    for (int i = ith; i < ne00; i += nth) {
+        sum[ith] += (float) x[i] * (float) y[i];
     }
 
     // accumulate the sum from all threads in the threadgroup
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = tptg.x/2; i > 0; i /= 2) {
-        if (tpitg.x < i) {
-            sum[tpitg.x] += sum[tpitg.x + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
     }
-
-    if (tpitg.x == 0) {
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
     }
+
+    // Original implementation. Left behind commented out for now
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (uint i = tptg.x/2; i > 0; i /= 2) {
+    //    if (tpitg.x < i) {
+    //        sum[tpitg.x] += sum[tpitg.x + i];
+    //    }
+    //    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //}
+    //
+    //if (tpitg.x == 0) {
+    //    dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+    //}
 }
 
 kernel void kernel_alibi_f32(
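Note: the matrix-vector kernel replaces the log2-step tree reduction (kept above as a comment) with a fixed three-stage reduction: lanes at multiples of 4 fold their three neighbours, lanes at multiples of 16 fold the three following 4-sums, and lane 0 folds the remaining 16-sums, needing only two barriers after the initial one. A single-threaded C++ model of the same arithmetic, assuming nth is a multiple of 16 (the host-side nth0 = 32 satisfies this):

    #include <cassert>
    #include <cstdio>
    #include <vector>

    // Simulate the staged threadgroup reduction over the partial sums.
    float reduce_staged(std::vector<float> sum) {
        const int nth = (int) sum.size();
        assert(nth % 16 == 0);
        // stage 1: every 4th lane accumulates its group of 4
        for (int ith = 0; ith < nth; ith += 4)
            for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
        // stage 2: every 16th lane accumulates the three trailing 4-sums
        for (int ith = 0; ith < nth; ith += 16)
            for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
        // stage 3: lane 0 accumulates the 16-sums
        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
        return sum[0];
    }

    int main() {
        std::vector<float> partial(32, 1.0f); // 32 lanes, each contributed 1.0
        std::printf("%f\n", reduce_staged(partial)); // prints 32.000000
        return 0;
    }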