llama_cpp 0.4.0 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/examples/chat.rb +2 -2
- data/ext/llama_cpp/extconf.rb +1 -1
- data/ext/llama_cpp/llama_cpp.cpp +23 -11
- data/ext/llama_cpp/src/ggml-alloc.c +13 -50
- data/ext/llama_cpp/src/ggml-cuda.cu +23 -11
- data/ext/llama_cpp/src/ggml-metal.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.m +130 -61
- data/ext/llama_cpp/src/ggml-metal.metal +44 -26
- data/ext/llama_cpp/src/ggml.c +637 -328
- data/ext/llama_cpp/src/ggml.h +45 -19
- data/ext/llama_cpp/src/k_quants.c +2 -2
- data/ext/llama_cpp/src/llama.cpp +426 -97
- data/ext/llama_cpp/src/llama.h +51 -5
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +5 -3
- metadata +2 -2
--- 0.4.0/data/ext/llama_cpp/src/ggml-metal.m
+++ 0.5.0/data/ext/llama_cpp/src/ggml-metal.m
@@ -11,6 +11,7 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+// TODO: temporary - reuse llama.cpp logging
 #ifdef GGML_METAL_NDEBUG
 #define metal_printf(...)
 #else
@@ -33,12 +34,15 @@ struct ggml_metal_buffer {
 struct ggml_metal_context {
     int n_cb;
 
-    float * logits;
-
     id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
     id<MTLLibrary>      library;
 
+    id<MTLCommandBuffer>         command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
+    id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
+
+    dispatch_queue_t d_queue;
+
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
@@ -110,16 +114,17 @@ static NSString * const msl_library_source = @"see metal.metal";
 @end
 
 struct ggml_metal_context * ggml_metal_init(int n_cb) {
-    fprintf(stderr, "%s: allocating\n", __func__);
+    metal_printf("%s: allocating\n", __func__);
 
     struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
 
-    ctx->n_cb    = n_cb;
+    ctx->n_cb    = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
     ctx->device  = MTLCreateSystemDefaultDevice();
     ctx->queue   = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
     ctx->concur_list_len = 0;
 
+    ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+
 #if 0
     // compile from source string and show compile log
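Both `ggml_metal_init` and `ggml_metal_set_n_cb` (further down) now clamp the requested callback count instead of storing it verbatim, so it can never exceed the capacity of the new fixed-size `command_buffers`/`command_encoders` arrays. A minimal C sketch of the clamp; the capacity value here is a hypothetical stand-in, the real constants live in the bundled headers:

```c
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

/* Hypothetical capacity standing in for the real GGML_METAL_* constant. */
#define MAX_COMMAND_BUFFERS 32

int main(void) {
    /* Whatever the caller asks for, the stored count stays in bounds. */
    int requested[] = { 1, 8, 32, 1000 };
    for (int i = 0; i < 4; ++i) {
        int n_cb = MIN(requested[i], MAX_COMMAND_BUFFERS);
        printf("requested %4d -> using %2d\n", requested[i], n_cb);
    }
    return 0;
}
```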
@@ -128,7 +133,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 
         ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
         if (error) {
-            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
     }
@@ -142,11 +147,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
         NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
         NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
-        fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
+        metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
 
         NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
         if (error) {
-            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
 
@@ -158,7 +163,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
 #endif
         if (error) {
-            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
     }
@@ -170,11 +175,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #define GGML_METAL_ADD_KERNEL(name) \
         ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
         ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
-        fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+        metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
                 (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
                 (int) ctx->pipeline_##name.threadExecutionWidth); \
         if (error) { \
-            fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+            metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
             return NULL; \
         }
 
@@ -226,22 +231,80 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #undef GGML_METAL_ADD_KERNEL
     }
 
-    fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-    fprintf(stderr, "%s: hasUnifiedMemory             = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+    metal_printf("%s: hasUnifiedMemory             = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
     if (ctx->device.maxTransferRate != 0) {
-        fprintf(stderr, "%s: maxTransferRate              = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+        metal_printf("%s: maxTransferRate              = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
     } else {
-        fprintf(stderr, "%s: maxTransferRate              = built-in GPU\n", __func__);
+        metal_printf("%s: maxTransferRate              = built-in GPU\n", __func__);
     }
 
     return ctx;
 }
 
 void ggml_metal_free(struct ggml_metal_context * ctx) {
-    fprintf(stderr, "%s: deallocating\n", __func__);
+    metal_printf("%s: deallocating\n", __func__);
+#define GGML_METAL_DEL_KERNEL(name) \
+    [ctx->function_##name release]; \
+    [ctx->pipeline_##name release];
+
+    GGML_METAL_DEL_KERNEL(add);
+    GGML_METAL_DEL_KERNEL(add_row);
+    GGML_METAL_DEL_KERNEL(mul);
+    GGML_METAL_DEL_KERNEL(mul_row);
+    GGML_METAL_DEL_KERNEL(scale);
+    GGML_METAL_DEL_KERNEL(silu);
+    GGML_METAL_DEL_KERNEL(relu);
+    GGML_METAL_DEL_KERNEL(gelu);
+    GGML_METAL_DEL_KERNEL(soft_max);
+    GGML_METAL_DEL_KERNEL(diag_mask_inf);
+    GGML_METAL_DEL_KERNEL(get_rows_f16);
+    GGML_METAL_DEL_KERNEL(get_rows_q4_0);
+    GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DEL_KERNEL(get_rows_q8_0);
+    GGML_METAL_DEL_KERNEL(get_rows_q2_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q3_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q4_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q5_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q6_K);
+    GGML_METAL_DEL_KERNEL(rms_norm);
+    GGML_METAL_DEL_KERNEL(norm);
+    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(rope);
+    GGML_METAL_DEL_KERNEL(alibi_f32);
+    GGML_METAL_DEL_KERNEL(cpy_f32_f16);
+    GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+    GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+
+#undef GGML_METAL_DEL_KERNEL
+
     for (int i = 0; i < ctx->n_buffers; ++i) {
         [ctx->buffers[i].metal release];
     }
+
+    [ctx->library release];
+    [ctx->queue release];
+    [ctx->device release];
+
+    dispatch_release(ctx->d_queue);
+
     free(ctx);
 }
 
@@ -249,7 +312,7 @@ void * ggml_metal_host_malloc(size_t n) {
     void * data = NULL;
     const int result = posix_memalign((void **) &data, getpagesize(), n);
     if (result != 0) {
-        fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
+        metal_printf("%s: error: posix_memalign failed\n", __func__);
         return NULL;
     }
 
@@ -261,7 +324,7 @@ void ggml_metal_host_free(void * data) {
 }
 
 void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
-    ctx->n_cb = n_cb;
+    ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
 }
 
 int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
@@ -277,7 +340,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
 // Metal buffer based on the host memory pointer
 //
 static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
-    //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
+    //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
 
     const int64_t tsize = ggml_nbytes(t);
 
@@ -288,13 +351,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
         if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
             *offs = (size_t) ioffs;
 
-            //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
+            //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
 
             return ctx->buffers[i].metal;
         }
     }
 
-    fprintf(stderr, "%s: error: buffer is nil\n", __func__);
+    metal_printf("%s: error: buffer is nil\n", __func__);
 
     return nil;
 }
@@ -306,7 +369,7 @@ bool ggml_metal_add_buffer(
                          size_t   size,
                          size_t   max_size) {
     if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
-        fprintf(stderr, "%s: too many buffers\n", __func__);
+        metal_printf("%s: too many buffers\n", __func__);
        return false;
     }
 
@@ -316,7 +379,7 @@ bool ggml_metal_add_buffer(
         const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
 
         if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
-            fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
+            metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
             return false;
         }
     }
@@ -337,11 +400,11 @@ bool ggml_metal_add_buffer(
             ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
             if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
+                metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
                 return false;
             }
 
-            fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
+            metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
 
             ++ctx->n_buffers;
         } else {
@@ -361,27 +424,27 @@ bool ggml_metal_add_buffer(
                 ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
                 if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                    fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
+                    metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
                     return false;
                 }
 
-                fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
+                metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
                 if (i + size_step < size) {
-                    fprintf(stderr, "\n");
+                    metal_printf("\n");
                 }
 
                 ++ctx->n_buffers;
             }
         }
 
-        fprintf(stderr, ", (%8.2f / %8.2f)",
+        metal_printf(", (%8.2f / %8.2f)",
                 ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
                 ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 
         if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
-            fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
+            metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
         } else {
-            fprintf(stderr, "\n");
+            metal_printf("\n");
         }
     }
 
@@ -391,8 +454,6 @@ bool ggml_metal_add_buffer(
 void ggml_metal_set_tensor(
         struct ggml_metal_context * ctx,
         struct ggml_tensor * t) {
-    metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
-
     size_t offs;
     id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
 
@@ -402,8 +463,6 @@ void ggml_metal_set_tensor(
 void ggml_metal_get_tensor(
         struct ggml_metal_context * ctx,
         struct ggml_tensor * t) {
-    metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
-
     size_t offs;
     id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
 
@@ -498,14 +557,14 @@ void ggml_metal_graph_find_concurrency(
     }
 
     if (ctx->concur_list_len > GGML_MAX_CONCUR) {
-        fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+        metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
     }
 }
 
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
         struct ggml_cgraph * gf) {
-    metal_printf("%s: evaluating graph\n", __func__);
+    @autoreleasepool {
 
     // if there is ctx->concur_list, dispatch concurrently
     // else fallback to serial dispatch
@@ -521,29 +580,25 @@ void ggml_metal_graph_compute(
 
     const int n_cb = ctx->n_cb;
 
-    NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
-
     for (int i = 0; i < n_cb; ++i) {
-        command_buffers[i] = [ctx->queue commandBuffer];
+        ctx->command_buffers[i] = [ctx->queue commandBuffer];
 
         // enqueue the command buffers in order to specify their execution order
-        [command_buffers[i] enqueue];
-    }
+        [ctx->command_buffers[i] enqueue];
 
-    // TODO: is this the best way to start threads?
-    dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+        ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
+    }
 
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
         const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
-        dispatch_async(queue, ^{
+        dispatch_async(ctx->d_queue, ^{
             size_t offs_src0 = 0;
             size_t offs_src1 = 0;
             size_t offs_dst  = 0;
 
-            id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
-
-            id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+            id<MTLCommandBuffer> command_buffer  = ctx->command_buffers[cb_idx];
+            id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
 
             const int node_start =                                      (cb_idx + 0) * n_nodes_per_cb;
             const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
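The hunk above keeps the existing scheme of splitting the graph's nodes evenly across the `n_cb` command buffers; only the storage of the buffers and encoders moved into the context. The `node_start`/`node_end` arithmetic is taken verbatim from the diff; here is a standalone C sketch of just that partitioning, with a made-up graph size and buffer count:

```c
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int n_nodes = 10; /* example graph size (illustrative only) */
    const int n_cb    = 3;  /* number of command buffers              */

    /* ceil(n_nodes / n_cb), as computed in ggml_metal_graph_compute */
    const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;

    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
        const int node_start = (cb_idx + 0) * n_nodes_per_cb;
        const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes
                                                        : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
        /* the last command buffer picks up the remainder: [8, 10) here */
        printf("cb %d encodes nodes [%d, %d)\n", cb_idx, node_start, node_end);
    }
    return 0;
}
```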
@@ -556,7 +611,7 @@ void ggml_metal_graph_compute(
                     continue;
                 }
 
-                metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+                //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
                 struct ggml_tensor * src0 = gf->nodes[i]->src[0];
                 struct ggml_tensor * src1 = gf->nodes[i]->src[1];
@@ -625,6 +680,12 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_ADD:
                         {
+                            GGML_ASSERT(ggml_is_contiguous(src0));
+
+                            // utilize float4
+                            GGML_ASSERT(ne00 % 4 == 0);
+                            const int64_t nb = ne00/4;
+
                             if (ggml_nelements(src1) == ne10) {
                                 // src1 is a row
                                 [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -634,14 +695,20 @@ void ggml_metal_graph_compute(
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                            [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
 
-                            const int64_t n = ggml_nelements(dst);
+                            const int64_t n = ggml_nelements(dst)/4;
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
                     case GGML_OP_MUL:
                         {
+                            GGML_ASSERT(ggml_is_contiguous(src0));
+
+                            // utilize float4
+                            GGML_ASSERT(ne00 % 4 == 0);
+                            const int64_t nb = ne00/4;
+
                             if (ggml_nelements(src1) == ne10) {
                                 // src1 is a row
                                 [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -651,9 +718,9 @@ void ggml_metal_graph_compute(
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                            [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
 
-                            const int64_t n = ggml_nelements(dst);
+                            const int64_t n = ggml_nelements(dst)/4;
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
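The element-wise ADD and MUL paths now require a contiguous `src0` whose row length `ne00` is divisible by 4: the row is reinterpreted as `float4` elements, the kernel receives `nb = ne00/4` at buffer index 3, and only a quarter as many threads are dispatched because each thread handles one `float4`. A small C sketch of that dispatch arithmetic; the tensor shape is invented for illustration:

```c
#include <assert.h>
#include <stdio.h>

int main(void) {
    /* Toy tensor shape: ne00 columns, ne01 rows. */
    const int ne00 = 4096, ne01 = 32;
    const int nelements = ne00 * ne01;

    /* The encoder now asserts this so each row splits cleanly into float4s. */
    assert(ne00 % 4 == 0);

    const long nb = ne00 / 4;      /* float4 elements per row, passed at index 3 */
    const long n  = nelements / 4; /* threads dispatched, one float4 each        */

    printf("nb = %ld float4/row, dispatching %ld threads (was %d)\n",
           nb, n, nelements);
    return 0;
}
```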
@@ -704,7 +771,7 @@ void ggml_metal_graph_compute(
                         } break;
                     default:
                         {
-                            fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                            metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                             GGML_ASSERT(false);
                         }
                 } break;
@@ -785,7 +852,7 @@ void ggml_metal_graph_compute(
                     switch (src0t) {
                         case GGML_TYPE_F16:
                             {
-                                nth0 = 64;
+                                nth0 = 32;
                                 nth1 = 1;
                                 [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
                             } break;
@@ -863,7 +930,7 @@ void ggml_metal_graph_compute(
                             } break;
                         default:
                             {
-                                fprintf(stderr, "Asserting on type %d\n",(int)src0t);
+                                metal_printf("Asserting on type %d\n",(int)src0t);
                                 GGML_ASSERT(false && "not implemented");
                             }
                     };
@@ -1101,7 +1168,7 @@ void ggml_metal_graph_compute(
                         } break;
                     default:
                         {
-                            fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                            metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                             GGML_ASSERT(false);
                         }
                 }
@@ -1117,17 +1184,19 @@ void ggml_metal_graph_compute(
     }
 
     // wait for all threads to finish
-    dispatch_barrier_sync(queue, ^{});
-
-    [command_buffers[n_cb - 1] waitUntilCompleted];
+    dispatch_barrier_sync(ctx->d_queue, ^{});
 
     // check status of command buffers
     // needed to detect if the device ran out-of-memory for example (#1881)
     for (int i = 0; i < n_cb; i++) {
-        MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
+        [ctx->command_buffers[i] waitUntilCompleted];
+
+        MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
         if (status != MTLCommandBufferStatusCompleted) {
-            fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
+            metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
             GGML_ASSERT(false);
         }
     }
+
+    }
 }
--- 0.4.0/data/ext/llama_cpp/src/ggml-metal.metal
+++ 0.5.0/data/ext/llama_cpp/src/ggml-metal.metal
@@ -25,9 +25,9 @@ typedef struct {
 } block_q8_0;
 
 kernel void kernel_add(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig];
 }
@@ -35,18 +35,18 @@ kernel void kernel_add(
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_add_row(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] + src1[tpig % ne00];
+    dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
 kernel void kernel_mul(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * src1[tpig];
 }
@@ -54,12 +54,12 @@ kernel void kernel_mul(
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_mul_row(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src1[tpig % ne00];
+    dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
 
 kernel void kernel_scale(
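The row-broadcast kernels keep the same `tpig % nb` indexing as before, just over `float4` elements instead of single floats. A serial C model of `kernel_add_row` under that assumption; the `float4` struct is a stand-in for Metal's vector type, and the sizes and values are made up:

```c
#include <assert.h>
#include <stdio.h>

typedef struct { float v[4]; } float4; /* stand-in for Metal's float4 */

/* Serial model of kernel_add_row: src1 is a single row, broadcast over src0. */
static void add_row(const float4 *src0, const float4 *src1,
                    float4 *dst, long nb, long n) {
    for (long tpig = 0; tpig < n; ++tpig) {   /* thread_position_in_grid */
        for (int k = 0; k < 4; ++k) {
            dst[tpig].v[k] = src0[tpig].v[k] + src1[tpig % nb].v[k];
        }
    }
}

int main(void) {
    enum { NB = 2, N = 6 };                   /* 8-wide rows, 3 rows */
    float4 src0[N], src1[NB], dst[N];
    for (int i = 0; i < N;  ++i) for (int k = 0; k < 4; ++k) src0[i].v[k] = (float) i;
    for (int i = 0; i < NB; ++i) for (int k = 0; k < 4; ++k) src1[i].v[k] = 100.0f + i;

    add_row(src0, src1, dst, NB, N);

    /* element i pairs with src1[i % NB], so the row wraps every NB float4s */
    assert(dst[0].v[0] == 100.0f && dst[3].v[0] == 104.0f);
    printf("dst[3].v[0] = %.0f\n", dst[3].v[0]);
    return 0;
}
```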
@@ -528,24 +528,42 @@ kernel void kernel_mul_mat_f16_f32(
     device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
-    sum[tpitg.x] = 0.0f;
+    uint ith = tpitg.x;
+    uint nth = tptg.x;
+
+    sum[ith] = 0.0f;
 
-    for (int i = tpitg.x; i < ne00; i += tptg.x) {
-        sum[tpitg.x] += (float) x[i] * (float) y[i];
+    for (int i = ith; i < ne00; i += nth) {
+        sum[ith] += (float) x[i] * (float) y[i];
     }
 
     // accumulate the sum from all threads in the threadgroup
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = tptg.x/2; i > 0; i /= 2) {
-        if (tpitg.x < i) {
-            sum[tpitg.x] += sum[tpitg.x + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
     }
-
-    if (tpitg.x == 0) {
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
     }
+
+    // Original implementation. Left behind commented out for now
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (uint i = tptg.x/2; i > 0; i /= 2) {
+    //    if (tpitg.x < i) {
+    //        sum[tpitg.x] += sum[tpitg.x + i];
+    //    }
+    //    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //}
+    //
+    //if (tpitg.x == 0) {
+    //    dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+    //}
 }
 
 kernel void kernel_alibi_f32(
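The new reduction replaces the log2 halving tree (kept above as the commented-out original) with three fixed stages: every 4th thread folds its group of 4 partial sums, every 16th thread folds its group of 16, and thread 0 folds the per-16 partials into `sum[0]`. A serial C model of the three stages, assuming `nth` is a multiple of 16 as the stage boundaries imply; barriers are implicit here because the loops run serially, where the kernel fences each stage with `threadgroup_barrier(mem_flags::mem_threadgroup)`:

```c
#include <assert.h>
#include <stdio.h>

/* Serial model of the three-stage reduction in kernel_mul_mat_f16_f32. */
static float reduce(float *sum, int nth) {
    for (int ith = 0; ith < nth; ith += 4)       /* threads with ith % 4 == 0  */
        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
    for (int ith = 0; ith < nth; ith += 16)      /* threads with ith % 16 == 0 */
        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
    for (int i = 16; i < nth; i += 16)           /* thread 0 only              */
        sum[0] += sum[i];
    return sum[0];
}

int main(void) {
    enum { NTH = 64 };                           /* threads per threadgroup    */
    float sum[NTH];
    for (int i = 0; i < NTH; ++i) sum[i] = 1.0f; /* dummy partial sums         */
    assert(reduce(sum, NTH) == (float) NTH);     /* all partials reach sum[0]  */
    printf("total = %.0f\n", sum[0]);
    return 0;
}
```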