llama_cpp 0.5.3 → 0.6.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +6 -5
- data/examples/chat.rb +13 -13
- data/examples/embedding.rb +9 -9
- data/ext/llama_cpp/llama_cpp.cpp +547 -272
- data/ext/llama_cpp/src/ggml-alloc.c +8 -2
- data/ext/llama_cpp/src/ggml-alloc.h +1 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +209 -82
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.m +163 -84
- data/ext/llama_cpp/src/ggml-metal.metal +121 -38
- data/ext/llama_cpp/src/ggml.c +1596 -842
- data/ext/llama_cpp/src/ggml.h +116 -35
- data/ext/llama_cpp/src/llama.cpp +1015 -586
- data/ext/llama_cpp/src/llama.h +304 -119
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +5 -9
- data/sig/llama_cpp.rbs +65 -34
- metadata +3 -3
@@ -11,11 +11,14 @@
|
|
11
11
|
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
12
12
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
13
13
|
|
14
|
-
// TODO: temporary - reuse llama.cpp logging
|
15
14
|
#ifdef GGML_METAL_NDEBUG
|
16
|
-
#define
|
15
|
+
#define GGML_METAL_LOG_INFO(...)
|
16
|
+
#define GGML_METAL_LOG_WARN(...)
|
17
|
+
#define GGML_METAL_LOG_ERROR(...)
|
17
18
|
#else
|
18
|
-
#define
|
19
|
+
#define GGML_METAL_LOG_INFO(...) ggml_metal_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
|
20
|
+
#define GGML_METAL_LOG_WARN(...) ggml_metal_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
|
21
|
+
#define GGML_METAL_LOG_ERROR(...) ggml_metal_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
19
22
|
#endif
|
20
23
|
|
21
24
|
#define UNUSED(x) (void)(x)
|
@@ -100,7 +103,8 @@ struct ggml_metal_context {
|
|
100
103
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
|
101
104
|
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
|
102
105
|
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
|
103
|
-
GGML_METAL_DECL_KERNEL(
|
106
|
+
GGML_METAL_DECL_KERNEL(rope_f32);
|
107
|
+
GGML_METAL_DECL_KERNEL(rope_f16);
|
104
108
|
GGML_METAL_DECL_KERNEL(alibi_f32);
|
105
109
|
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
106
110
|
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
@@ -120,8 +124,37 @@ static NSString * const msl_library_source = @"see metal.metal";
|
|
120
124
|
@implementation GGMLMetalClass
|
121
125
|
@end
|
122
126
|
|
127
|
+
ggml_log_callback ggml_metal_log_callback = NULL;
|
128
|
+
void * ggml_metal_log_user_data = NULL;
|
129
|
+
|
130
|
+
void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
|
131
|
+
ggml_metal_log_callback = log_callback;
|
132
|
+
ggml_metal_log_user_data = user_data;
|
133
|
+
}
|
134
|
+
|
135
|
+
static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
|
136
|
+
if (ggml_metal_log_callback != NULL) {
|
137
|
+
va_list args;
|
138
|
+
va_start(args, format);
|
139
|
+
char buffer[128];
|
140
|
+
int len = vsnprintf(buffer, 128, format, args);
|
141
|
+
if (len < 128) {
|
142
|
+
ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
|
143
|
+
} else {
|
144
|
+
char* buffer2 = malloc(len+1);
|
145
|
+
vsnprintf(buffer2, len+1, format, args);
|
146
|
+
buffer2[len] = 0;
|
147
|
+
ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
|
148
|
+
free(buffer2);
|
149
|
+
}
|
150
|
+
va_end(args);
|
151
|
+
}
|
152
|
+
}
|
153
|
+
|
154
|
+
|
155
|
+
|
123
156
|
struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
124
|
-
|
157
|
+
GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
|
125
158
|
|
126
159
|
id <MTLDevice> device;
|
127
160
|
NSString * s;
|
@@ -131,14 +164,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
131
164
|
NSArray * devices = MTLCopyAllDevices();
|
132
165
|
for (device in devices) {
|
133
166
|
s = [device name];
|
134
|
-
|
167
|
+
GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
|
135
168
|
}
|
136
169
|
#endif
|
137
170
|
|
138
171
|
// Pick and show default Metal device
|
139
172
|
device = MTLCreateSystemDefaultDevice();
|
140
173
|
s = [device name];
|
141
|
-
|
174
|
+
GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
|
142
175
|
|
143
176
|
// Configure context
|
144
177
|
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
|
@@ -165,7 +198,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
165
198
|
ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
|
166
199
|
|
167
200
|
if (error) {
|
168
|
-
|
201
|
+
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
169
202
|
return NULL;
|
170
203
|
}
|
171
204
|
}
|
@@ -179,11 +212,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
179
212
|
//NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
|
180
213
|
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
181
214
|
NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
182
|
-
|
215
|
+
GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path UTF8String]);
|
183
216
|
|
184
217
|
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
|
185
218
|
if (error) {
|
186
|
-
|
219
|
+
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
187
220
|
return NULL;
|
188
221
|
}
|
189
222
|
|
@@ -195,7 +228,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
195
228
|
ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
|
196
229
|
#endif
|
197
230
|
if (error) {
|
198
|
-
|
231
|
+
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
199
232
|
return NULL;
|
200
233
|
}
|
201
234
|
}
|
@@ -207,11 +240,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
207
240
|
#define GGML_METAL_ADD_KERNEL(name) \
|
208
241
|
ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
|
209
242
|
ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
|
210
|
-
|
243
|
+
GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
|
211
244
|
(int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
|
212
245
|
(int) ctx->pipeline_##name.threadExecutionWidth); \
|
213
246
|
if (error) { \
|
214
|
-
|
247
|
+
GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
215
248
|
return NULL; \
|
216
249
|
}
|
217
250
|
|
@@ -261,7 +294,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
261
294
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
|
262
295
|
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
|
263
296
|
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
|
264
|
-
GGML_METAL_ADD_KERNEL(
|
297
|
+
GGML_METAL_ADD_KERNEL(rope_f32);
|
298
|
+
GGML_METAL_ADD_KERNEL(rope_f16);
|
265
299
|
GGML_METAL_ADD_KERNEL(alibi_f32);
|
266
300
|
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
267
301
|
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
@@ -270,13 +304,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
270
304
|
#undef GGML_METAL_ADD_KERNEL
|
271
305
|
}
|
272
306
|
|
273
|
-
|
307
|
+
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
274
308
|
#if TARGET_OS_OSX
|
275
|
-
|
309
|
+
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
276
310
|
if (ctx->device.maxTransferRate != 0) {
|
277
|
-
|
311
|
+
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
278
312
|
} else {
|
279
|
-
|
313
|
+
GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
|
280
314
|
}
|
281
315
|
#endif
|
282
316
|
|
@@ -284,7 +318,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
284
318
|
}
|
285
319
|
|
286
320
|
void ggml_metal_free(struct ggml_metal_context * ctx) {
|
287
|
-
|
321
|
+
GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
|
288
322
|
#define GGML_METAL_DEL_KERNEL(name) \
|
289
323
|
[ctx->function_##name release]; \
|
290
324
|
[ctx->pipeline_##name release];
|
@@ -335,7 +369,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
335
369
|
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
|
336
370
|
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
|
337
371
|
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
|
338
|
-
GGML_METAL_DEL_KERNEL(
|
372
|
+
GGML_METAL_DEL_KERNEL(rope_f32);
|
373
|
+
GGML_METAL_DEL_KERNEL(rope_f16);
|
339
374
|
GGML_METAL_DEL_KERNEL(alibi_f32);
|
340
375
|
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
341
376
|
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
@@ -360,7 +395,7 @@ void * ggml_metal_host_malloc(size_t n) {
|
|
360
395
|
void * data = NULL;
|
361
396
|
const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
|
362
397
|
if (result != 0) {
|
363
|
-
|
398
|
+
GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
|
364
399
|
return NULL;
|
365
400
|
}
|
366
401
|
|
@@ -388,7 +423,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
|
|
388
423
|
// Metal buffer based on the host memory pointer
|
389
424
|
//
|
390
425
|
static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
|
391
|
-
//
|
426
|
+
//GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
|
392
427
|
|
393
428
|
const int64_t tsize = ggml_nbytes(t);
|
394
429
|
|
@@ -400,13 +435,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
|
400
435
|
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
|
401
436
|
*offs = (size_t) ioffs;
|
402
437
|
|
403
|
-
//
|
438
|
+
//GGML_METAL_LOG_INFO("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
|
404
439
|
|
405
440
|
return ctx->buffers[i].metal;
|
406
441
|
}
|
407
442
|
}
|
408
443
|
|
409
|
-
|
444
|
+
GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__);
|
410
445
|
|
411
446
|
return nil;
|
412
447
|
}
|
@@ -418,7 +453,7 @@ bool ggml_metal_add_buffer(
|
|
418
453
|
size_t size,
|
419
454
|
size_t max_size) {
|
420
455
|
if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
|
421
|
-
|
456
|
+
GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
|
422
457
|
return false;
|
423
458
|
}
|
424
459
|
|
@@ -428,7 +463,7 @@ bool ggml_metal_add_buffer(
|
|
428
463
|
const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
|
429
464
|
|
430
465
|
if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
|
431
|
-
|
466
|
+
GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
|
432
467
|
return false;
|
433
468
|
}
|
434
469
|
}
|
@@ -449,11 +484,11 @@ bool ggml_metal_add_buffer(
|
|
449
484
|
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
450
485
|
|
451
486
|
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
452
|
-
|
487
|
+
GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
|
453
488
|
return false;
|
454
489
|
}
|
455
490
|
|
456
|
-
|
491
|
+
GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
|
457
492
|
|
458
493
|
++ctx->n_buffers;
|
459
494
|
} else {
|
@@ -473,13 +508,13 @@ bool ggml_metal_add_buffer(
|
|
473
508
|
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
474
509
|
|
475
510
|
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
476
|
-
|
511
|
+
GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
|
477
512
|
return false;
|
478
513
|
}
|
479
514
|
|
480
|
-
|
515
|
+
GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
|
481
516
|
if (i + size_step < size) {
|
482
|
-
|
517
|
+
GGML_METAL_LOG_INFO("\n");
|
483
518
|
}
|
484
519
|
|
485
520
|
++ctx->n_buffers;
|
@@ -487,17 +522,17 @@ bool ggml_metal_add_buffer(
|
|
487
522
|
}
|
488
523
|
|
489
524
|
#if TARGET_OS_OSX
|
490
|
-
|
525
|
+
GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
|
491
526
|
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
|
492
527
|
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
493
528
|
|
494
529
|
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
|
495
|
-
|
530
|
+
GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
496
531
|
} else {
|
497
|
-
|
532
|
+
GGML_METAL_LOG_INFO("\n");
|
498
533
|
}
|
499
534
|
#else
|
500
|
-
|
535
|
+
GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
|
501
536
|
#endif
|
502
537
|
}
|
503
538
|
|
@@ -610,7 +645,7 @@ void ggml_metal_graph_find_concurrency(
|
|
610
645
|
}
|
611
646
|
|
612
647
|
if (ctx->concur_list_len > GGML_MAX_CONCUR) {
|
613
|
-
|
648
|
+
GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
|
614
649
|
}
|
615
650
|
}
|
616
651
|
|
@@ -664,7 +699,7 @@ void ggml_metal_graph_compute(
|
|
664
699
|
continue;
|
665
700
|
}
|
666
701
|
|
667
|
-
//
|
702
|
+
//GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
668
703
|
|
669
704
|
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
670
705
|
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
|
@@ -708,17 +743,17 @@ void ggml_metal_graph_compute(
|
|
708
743
|
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
|
709
744
|
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
|
710
745
|
|
711
|
-
//
|
746
|
+
//GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
|
712
747
|
//if (src0) {
|
713
|
-
//
|
748
|
+
// GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
|
714
749
|
// ggml_is_contiguous(src0), src0->name);
|
715
750
|
//}
|
716
751
|
//if (src1) {
|
717
|
-
//
|
752
|
+
// GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
|
718
753
|
// ggml_is_contiguous(src1), src1->name);
|
719
754
|
//}
|
720
755
|
//if (dst) {
|
721
|
-
//
|
756
|
+
// GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
|
722
757
|
// dst->name);
|
723
758
|
//}
|
724
759
|
|
@@ -736,25 +771,59 @@ void ggml_metal_graph_compute(
|
|
736
771
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
737
772
|
GGML_ASSERT(ggml_is_contiguous(src1));
|
738
773
|
|
739
|
-
|
740
|
-
GGML_ASSERT(ne00 % 4 == 0);
|
741
|
-
const int64_t nb = ne00/4;
|
774
|
+
bool bcast_row = false;
|
742
775
|
|
743
|
-
|
776
|
+
int64_t nb = ne00;
|
777
|
+
|
778
|
+
if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
|
744
779
|
// src1 is a row
|
745
780
|
GGML_ASSERT(ne11 == 1);
|
781
|
+
|
782
|
+
nb = ne00 / 4;
|
746
783
|
[encoder setComputePipelineState:ctx->pipeline_add_row];
|
784
|
+
|
785
|
+
bcast_row = true;
|
747
786
|
} else {
|
748
787
|
[encoder setComputePipelineState:ctx->pipeline_add];
|
749
788
|
}
|
750
789
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
751
790
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
752
791
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
753
|
-
[encoder setBytes:&
|
754
|
-
|
755
|
-
|
792
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
793
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
794
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
795
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
796
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
797
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
798
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
799
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
800
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
801
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
802
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
803
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
804
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
805
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
806
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
807
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
808
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
809
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
810
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
811
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
812
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
813
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
814
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
815
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
816
|
+
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
|
817
|
+
|
818
|
+
if (bcast_row) {
|
819
|
+
const int64_t n = ggml_nelements(dst)/4;
|
820
|
+
|
821
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
822
|
+
} else {
|
823
|
+
const int nth = MIN(1024, ne0);
|
756
824
|
|
757
|
-
|
825
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
826
|
+
}
|
758
827
|
} break;
|
759
828
|
case GGML_OP_MUL:
|
760
829
|
{
|
@@ -830,13 +899,13 @@ void ggml_metal_graph_compute(
|
|
830
899
|
} break;
|
831
900
|
default:
|
832
901
|
{
|
833
|
-
|
902
|
+
GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
834
903
|
GGML_ASSERT(false);
|
835
904
|
}
|
836
905
|
} break;
|
837
906
|
case GGML_OP_SOFT_MAX:
|
838
907
|
{
|
839
|
-
const int nth = 32;
|
908
|
+
const int nth = MIN(32, ne00);
|
840
909
|
|
841
910
|
if (ne00%4 == 0) {
|
842
911
|
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
@@ -889,7 +958,7 @@ void ggml_metal_graph_compute(
|
|
889
958
|
src1t == GGML_TYPE_F32 &&
|
890
959
|
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
891
960
|
ne00%32 == 0 &&
|
892
|
-
ne11 >
|
961
|
+
ne11 > 2) {
|
893
962
|
switch (src0->type) {
|
894
963
|
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
|
895
964
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
@@ -1019,7 +1088,7 @@ void ggml_metal_graph_compute(
|
|
1019
1088
|
} break;
|
1020
1089
|
default:
|
1021
1090
|
{
|
1022
|
-
|
1091
|
+
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
1023
1092
|
GGML_ASSERT(false && "not implemented");
|
1024
1093
|
}
|
1025
1094
|
};
|
@@ -1100,7 +1169,7 @@ void ggml_metal_graph_compute(
|
|
1100
1169
|
float eps;
|
1101
1170
|
memcpy(&eps, dst->op_params, sizeof(float));
|
1102
1171
|
|
1103
|
-
const int nth = 512;
|
1172
|
+
const int nth = MIN(512, ne00);
|
1104
1173
|
|
1105
1174
|
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
|
1106
1175
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
@@ -1119,7 +1188,7 @@ void ggml_metal_graph_compute(
|
|
1119
1188
|
float eps;
|
1120
1189
|
memcpy(&eps, dst->op_params, sizeof(float));
|
1121
1190
|
|
1122
|
-
const int nth = 256;
|
1191
|
+
const int nth = MIN(256, ne00);
|
1123
1192
|
|
1124
1193
|
[encoder setComputePipelineState:ctx->pipeline_norm];
|
1125
1194
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
@@ -1137,6 +1206,8 @@ void ggml_metal_graph_compute(
|
|
1137
1206
|
{
|
1138
1207
|
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
1139
1208
|
|
1209
|
+
const int nth = MIN(1024, ne00);
|
1210
|
+
|
1140
1211
|
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
|
1141
1212
|
const int n_head = ((int32_t *) dst->op_params)[1];
|
1142
1213
|
float max_bias;
|
@@ -1170,12 +1241,14 @@ void ggml_metal_graph_compute(
|
|
1170
1241
|
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
1171
1242
|
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
1172
1243
|
|
1173
|
-
const int nth = 32;
|
1174
|
-
|
1175
1244
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1176
1245
|
} break;
|
1177
1246
|
case GGML_OP_ROPE:
|
1178
1247
|
{
|
1248
|
+
GGML_ASSERT(ne10 == ne02);
|
1249
|
+
|
1250
|
+
const int nth = MIN(1024, ne00);
|
1251
|
+
|
1179
1252
|
const int n_past = ((int32_t *) dst->op_params)[0];
|
1180
1253
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
1181
1254
|
const int mode = ((int32_t *) dst->op_params)[2];
|
@@ -1185,38 +1258,44 @@ void ggml_metal_graph_compute(
|
|
1185
1258
|
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
1186
1259
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
1187
1260
|
|
1188
|
-
|
1261
|
+
switch (src0->type) {
|
1262
|
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
|
1263
|
+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
|
1264
|
+
default: GGML_ASSERT(false);
|
1265
|
+
};
|
1266
|
+
|
1189
1267
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1190
|
-
[encoder setBuffer:
|
1191
|
-
[encoder
|
1192
|
-
[encoder setBytes:&
|
1193
|
-
[encoder setBytes:&
|
1194
|
-
[encoder setBytes:&
|
1195
|
-
[encoder setBytes:&
|
1196
|
-
[encoder setBytes:&
|
1197
|
-
[encoder setBytes:&
|
1198
|
-
[encoder setBytes:&
|
1199
|
-
[encoder setBytes:&
|
1200
|
-
[encoder setBytes:&
|
1201
|
-
[encoder setBytes:&
|
1202
|
-
[encoder setBytes:&
|
1203
|
-
[encoder setBytes:&
|
1204
|
-
[encoder setBytes:&
|
1205
|
-
[encoder setBytes:&
|
1206
|
-
[encoder setBytes:&
|
1207
|
-
[encoder setBytes:&
|
1208
|
-
[encoder setBytes:&
|
1209
|
-
[encoder setBytes:&
|
1210
|
-
[encoder setBytes:&
|
1211
|
-
[encoder setBytes:&
|
1268
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1269
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1270
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
1271
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
|
1272
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
|
1273
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
|
1274
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
|
1275
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
1276
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
1277
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
1278
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
|
1279
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
|
1280
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
|
1281
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
|
1282
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
|
1283
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
|
1284
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
|
1285
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
|
1286
|
+
[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
|
1287
|
+
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
|
1288
|
+
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
|
1289
|
+
[encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
|
1290
|
+
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];
|
1212
1291
|
|
1213
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(
|
1292
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1214
1293
|
} break;
|
1215
1294
|
case GGML_OP_DUP:
|
1216
1295
|
case GGML_OP_CPY:
|
1217
1296
|
case GGML_OP_CONT:
|
1218
1297
|
{
|
1219
|
-
const int nth =
|
1298
|
+
const int nth = MIN(1024, ne00);
|
1220
1299
|
|
1221
1300
|
switch (src0t) {
|
1222
1301
|
case GGML_TYPE_F32:
|
@@ -1261,7 +1340,7 @@ void ggml_metal_graph_compute(
|
|
1261
1340
|
} break;
|
1262
1341
|
default:
|
1263
1342
|
{
|
1264
|
-
|
1343
|
+
GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
1265
1344
|
GGML_ASSERT(false);
|
1266
1345
|
}
|
1267
1346
|
}
|
@@ -1286,7 +1365,7 @@ void ggml_metal_graph_compute(
|
|
1286
1365
|
|
1287
1366
|
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
|
1288
1367
|
if (status != MTLCommandBufferStatusCompleted) {
|
1289
|
-
|
1368
|
+
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
1290
1369
|
GGML_ASSERT(false);
|
1291
1370
|
}
|
1292
1371
|
}
|