llama_cpp 0.5.3 → 0.6.0
This diff compares the contents of the two publicly released package versions as they appear in their public registry; it is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +6 -5
- data/examples/chat.rb +13 -13
- data/examples/embedding.rb +9 -9
- data/ext/llama_cpp/llama_cpp.cpp +547 -272
- data/ext/llama_cpp/src/ggml-alloc.c +8 -2
- data/ext/llama_cpp/src/ggml-alloc.h +1 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +209 -82
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.m +163 -84
- data/ext/llama_cpp/src/ggml-metal.metal +121 -38
- data/ext/llama_cpp/src/ggml.c +1596 -842
- data/ext/llama_cpp/src/ggml.h +116 -35
- data/ext/llama_cpp/src/llama.cpp +1015 -586
- data/ext/llama_cpp/src/llama.h +304 -119
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +5 -9
- data/sig/llama_cpp.rbs +65 -34
- metadata +3 -3
data/ext/llama_cpp/src/ggml-metal.m

@@ -11,11 +11,14 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))

-// TODO: temporary - reuse llama.cpp logging
 #ifdef GGML_METAL_NDEBUG
-#define
+#define GGML_METAL_LOG_INFO(...)
+#define GGML_METAL_LOG_WARN(...)
+#define GGML_METAL_LOG_ERROR(...)
 #else
-#define
+#define GGML_METAL_LOG_INFO(...) ggml_metal_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
+#define GGML_METAL_LOG_WARN(...) ggml_metal_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
+#define GGML_METAL_LOG_ERROR(...) ggml_metal_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
 #endif

 #define UNUSED(x) (void)(x)

@@ -100,7 +103,8 @@ struct ggml_metal_context {
 GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
 GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
 GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
-GGML_METAL_DECL_KERNEL(
+GGML_METAL_DECL_KERNEL(rope_f32);
+GGML_METAL_DECL_KERNEL(rope_f16);
 GGML_METAL_DECL_KERNEL(alibi_f32);
 GGML_METAL_DECL_KERNEL(cpy_f32_f16);
 GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -120,8 +124,37 @@ static NSString * const msl_library_source = @"see metal.metal";
 @implementation GGMLMetalClass
 @end

+ggml_log_callback ggml_metal_log_callback = NULL;
+void * ggml_metal_log_user_data = NULL;
+
+void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
+ggml_metal_log_callback = log_callback;
+ggml_metal_log_user_data = user_data;
+}
+
+static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
+if (ggml_metal_log_callback != NULL) {
+va_list args;
+va_start(args, format);
+char buffer[128];
+int len = vsnprintf(buffer, 128, format, args);
+if (len < 128) {
+ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
+} else {
+char* buffer2 = malloc(len+1);
+vsnprintf(buffer2, len+1, format, args);
+buffer2[len] = 0;
+ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
+free(buffer2);
+}
+va_end(args);
+}
+}
+
+
+
 struct ggml_metal_context * ggml_metal_init(int n_cb) {
-
+GGML_METAL_LOG_INFO("%s: allocating\n", __func__);

 id <MTLDevice> device;
 NSString * s;
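The hunk above replaces the backend's direct stderr printing with a pluggable logging callback. A minimal usage sketch follows; it assumes the ggml_log_callback typedef exposed by the bundled ggml.h matches the way ggml_metal_log_callback is invoked above (level, message text, user data), and that ggml_metal_log_set_callback is declared in ggml-metal.h.

// Sketch only: forward ggml-metal log messages to a stream chosen by the host program.
#include <stdio.h>
#include "ggml.h"        // enum ggml_log_level, ggml_log_callback (assumed)
#include "ggml-metal.h"  // ggml_metal_log_set_callback, ggml_metal_init, ggml_metal_free

static void metal_log_to_stream(enum ggml_log_level level, const char * text, void * user_data) {
    FILE * stream = (FILE *) user_data;
    const char * tag = (level == GGML_LOG_LEVEL_ERROR) ? "E"
                     : (level == GGML_LOG_LEVEL_WARN)  ? "W" : "I";
    // The messages already carry their own newlines (see the format strings in the diff).
    fprintf(stream, "[ggml-metal %s] %s", tag, text);
}

int main(void) {
    ggml_metal_log_set_callback(metal_log_to_stream, stderr);

    struct ggml_metal_context * ctx = ggml_metal_init(1); // logs "allocating", device info, kernel loads
    if (ctx != NULL) {
        ggml_metal_free(ctx);                             // logs "deallocating"
    }
    return 0;
}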
@@ -131,14 +164,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 NSArray * devices = MTLCopyAllDevices();
 for (device in devices) {
 s = [device name];
-
+GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
 }
 #endif

 // Pick and show default Metal device
 device = MTLCreateSystemDefaultDevice();
 s = [device name];
-
+GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);

 // Configure context
 struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));

@@ -165,7 +198,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];

 if (error) {
-
+GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
 return NULL;
 }
 }

@@ -179,11 +212,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
 NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
 NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
-
+GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path UTF8String]);

 NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
 if (error) {
-
+GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
 return NULL;
 }

@@ -195,7 +228,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
 #endif
 if (error) {
-
+GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
 return NULL;
 }
 }

@@ -207,11 +240,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #define GGML_METAL_ADD_KERNEL(name) \
 ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
 ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
-
+GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
 (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
 (int) ctx->pipeline_##name.threadExecutionWidth); \
 if (error) { \
-
+GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
 return NULL; \
 }

@@ -261,7 +294,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
 GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
 GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
-GGML_METAL_ADD_KERNEL(
+GGML_METAL_ADD_KERNEL(rope_f32);
+GGML_METAL_ADD_KERNEL(rope_f16);
 GGML_METAL_ADD_KERNEL(alibi_f32);
 GGML_METAL_ADD_KERNEL(cpy_f32_f16);
 GGML_METAL_ADD_KERNEL(cpy_f32_f32);

@@ -270,13 +304,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #undef GGML_METAL_ADD_KERNEL
 }

-
+GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
 #if TARGET_OS_OSX
-
+GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 if (ctx->device.maxTransferRate != 0) {
-
+GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
 } else {
-
+GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
 }
 #endif

@@ -284,7 +318,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 }

 void ggml_metal_free(struct ggml_metal_context * ctx) {
-
+GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
 #define GGML_METAL_DEL_KERNEL(name) \
 [ctx->function_##name release]; \
 [ctx->pipeline_##name release];
@@ -335,7 +369,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
 GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
 GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
 GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
-GGML_METAL_DEL_KERNEL(
+GGML_METAL_DEL_KERNEL(rope_f32);
+GGML_METAL_DEL_KERNEL(rope_f16);
 GGML_METAL_DEL_KERNEL(alibi_f32);
 GGML_METAL_DEL_KERNEL(cpy_f32_f16);
 GGML_METAL_DEL_KERNEL(cpy_f32_f32);

@@ -360,7 +395,7 @@ void * ggml_metal_host_malloc(size_t n) {
 void * data = NULL;
 const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
 if (result != 0) {
-
+GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
 return NULL;
 }

@@ -388,7 +423,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
 // Metal buffer based on the host memory pointer
 //
 static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
-//
+//GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);

 const int64_t tsize = ggml_nbytes(t);

@@ -400,13 +435,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
 if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
 *offs = (size_t) ioffs;

-//
+//GGML_METAL_LOG_INFO("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);

 return ctx->buffers[i].metal;
 }
 }

-
+GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__);

 return nil;
 }

@@ -418,7 +453,7 @@ bool ggml_metal_add_buffer(
 size_t size,
 size_t max_size) {
 if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
-
+GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
 return false;
 }

@@ -428,7 +463,7 @@ bool ggml_metal_add_buffer(
 const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;

 if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
-
+GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
 return false;
 }
 }

@@ -449,11 +484,11 @@ bool ggml_metal_add_buffer(
 ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];

 if (ctx->buffers[ctx->n_buffers].metal == nil) {
-
+GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
 return false;
 }

-
+GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);

 ++ctx->n_buffers;
 } else {

@@ -473,13 +508,13 @@ bool ggml_metal_add_buffer(
 ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];

 if (ctx->buffers[ctx->n_buffers].metal == nil) {
-
+GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
 return false;
 }

-
+GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
 if (i + size_step < size) {
-
+GGML_METAL_LOG_INFO("\n");
 }

 ++ctx->n_buffers;

@@ -487,17 +522,17 @@ bool ggml_metal_add_buffer(
 }

 #if TARGET_OS_OSX
-
+GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
 ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
 ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);

 if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
-
+GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
 } else {
-
+GGML_METAL_LOG_INFO("\n");
 }
 #else
-
+GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
 #endif
 }
@@ -610,7 +645,7 @@ void ggml_metal_graph_find_concurrency(
 }

 if (ctx->concur_list_len > GGML_MAX_CONCUR) {
-
+GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
 }
 }

@@ -664,7 +699,7 @@ void ggml_metal_graph_compute(
 continue;
 }

-//
+//GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

 struct ggml_tensor * src0 = gf->nodes[i]->src[0];
 struct ggml_tensor * src1 = gf->nodes[i]->src[1];

@@ -708,17 +743,17 @@ void ggml_metal_graph_compute(
 id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
 id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;

-//
+//GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
 //if (src0) {
-//
+// GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
 // ggml_is_contiguous(src0), src0->name);
 //}
 //if (src1) {
-//
+// GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
 // ggml_is_contiguous(src1), src1->name);
 //}
 //if (dst) {
-//
+// GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
 // dst->name);
 //}
@@ -736,25 +771,59 @@ void ggml_metal_graph_compute(
 GGML_ASSERT(ggml_is_contiguous(src0));
 GGML_ASSERT(ggml_is_contiguous(src1));

-
-GGML_ASSERT(ne00 % 4 == 0);
-const int64_t nb = ne00/4;
+bool bcast_row = false;

-
+int64_t nb = ne00;
+
+if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
 // src1 is a row
 GGML_ASSERT(ne11 == 1);
+
+nb = ne00 / 4;
 [encoder setComputePipelineState:ctx->pipeline_add_row];
+
+bcast_row = true;
 } else {
 [encoder setComputePipelineState:ctx->pipeline_add];
 }
 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
 [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-[encoder setBytes:&
-
-
+[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
+[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
+[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
+
+if (bcast_row) {
+const int64_t n = ggml_nelements(dst)/4;
+
+[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+} else {
+const int nth = MIN(1024, ne0);

-
+[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+}
 } break;
 case GGML_OP_MUL:
 {
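To make the new GGML_OP_ADD dispatch easier to follow, the decision encoded in the hunk above can be summarized in plain C. This is an illustrative sketch only, not code from the gem; the add_plan struct and plan_ggml_op_add function are invented names for the example, while the conditions and thread counts mirror the diff.

#include <stdbool.h>
#include <stdint.h>

// Invented helper type for illustration only.
struct add_plan {
    bool    bcast_row;        // true -> pipeline_add_row, one float4 per thread
    int64_t nb;               // value bound at buffer index 27 above
    int64_t tg_x, tg_y, tg_z; // threadgroups dispatched
    int64_t threads_per_tg;   // threads per threadgroup (x)
};

struct add_plan plan_ggml_op_add(int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
                                 int64_t ne0, int64_t ne10,
                                 int64_t src1_nelements, int64_t dst_nelements) {
    struct add_plan p = { .bcast_row = false, .nb = ne00 };
    if (src1_nelements == ne10 && ne00 % 4 == 0) {
        // src1 is a single row and rows are float4-aligned: broadcast-add it,
        // letting each thread handle one float4 of the destination.
        p.bcast_row      = true;
        p.nb             = ne00 / 4;
        p.tg_x           = dst_nelements / 4;
        p.tg_y = p.tg_z  = 1;
        p.threads_per_tg = 1;
    } else {
        // General case: one threadgroup per (row, plane, batch), up to 1024 threads each.
        p.tg_x           = ne01;
        p.tg_y           = ne02;
        p.tg_z           = ne03;
        p.threads_per_tg = ne0 < 1024 ? ne0 : 1024; // MIN(1024, ne0)
    }
    return p;
}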
@@ -830,13 +899,13 @@ void ggml_metal_graph_compute(
 } break;
 default:
 {
-
+GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
 GGML_ASSERT(false);
 }
 } break;
 case GGML_OP_SOFT_MAX:
 {
-const int nth = 32;
+const int nth = MIN(32, ne00);

 if (ne00%4 == 0) {
 [encoder setComputePipelineState:ctx->pipeline_soft_max_4];

@@ -889,7 +958,7 @@ void ggml_metal_graph_compute(
 src1t == GGML_TYPE_F32 &&
 [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
 ne00%32 == 0 &&
-ne11 >
+ne11 > 2) {
 switch (src0->type) {
 case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
 case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;

@@ -1019,7 +1088,7 @@ void ggml_metal_graph_compute(
 } break;
 default:
 {
-
+GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
 GGML_ASSERT(false && "not implemented");
 }
 };

@@ -1100,7 +1169,7 @@ void ggml_metal_graph_compute(
 float eps;
 memcpy(&eps, dst->op_params, sizeof(float));

-const int nth = 512;
+const int nth = MIN(512, ne00);

 [encoder setComputePipelineState:ctx->pipeline_rms_norm];
 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];

@@ -1119,7 +1188,7 @@ void ggml_metal_graph_compute(
 float eps;
 memcpy(&eps, dst->op_params, sizeof(float));

-const int nth = 256;
+const int nth = MIN(256, ne00);

 [encoder setComputePipelineState:ctx->pipeline_norm];
 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];

@@ -1137,6 +1206,8 @@ void ggml_metal_graph_compute(
 {
 GGML_ASSERT((src0t == GGML_TYPE_F32));

+const int nth = MIN(1024, ne00);
+
 const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
 const int n_head = ((int32_t *) dst->op_params)[1];
 float max_bias;

@@ -1170,12 +1241,14 @@ void ggml_metal_graph_compute(
 [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
 [encoder setBytes:&m0 length:sizeof( float) atIndex:18];

-const int nth = 32;
-
 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 } break;
 case GGML_OP_ROPE:
 {
+GGML_ASSERT(ne10 == ne02);
+
+const int nth = MIN(1024, ne00);
+
 const int n_past = ((int32_t *) dst->op_params)[0];
 const int n_dims = ((int32_t *) dst->op_params)[1];
 const int mode = ((int32_t *) dst->op_params)[2];
@@ -1185,38 +1258,44 @@ void ggml_metal_graph_compute(
 memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
 memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

-
+switch (src0->type) {
+case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
+case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
+default: GGML_ASSERT(false);
+};
+
 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-[encoder setBuffer:
-[encoder
-[encoder setBytes:&
-[encoder setBytes:&
-[encoder setBytes:&
-[encoder setBytes:&
-[encoder setBytes:&
-[encoder setBytes:&
-[encoder setBytes:&
-[encoder setBytes:&
-[encoder setBytes:&
-[encoder setBytes:&
-[encoder setBytes:&
-[encoder setBytes:&
-[encoder setBytes:&
-[encoder setBytes:&
-[encoder setBytes:&
-[encoder setBytes:&
-[encoder setBytes:&
-[encoder setBytes:&
-[encoder setBytes:&
-[encoder setBytes:&
+[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
+[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
+[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
+[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
+[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
+[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
+[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
+[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
+[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
+[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
+[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
+[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
+[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
+[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
+[encoder setBytes:&mode length:sizeof( int) atIndex:21];
+[encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
+[encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];

-[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(
+[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 } break;
 case GGML_OP_DUP:
 case GGML_OP_CPY:
 case GGML_OP_CONT:
 {
-const int nth =
+const int nth = MIN(1024, ne00);

 switch (src0t) {
 case GGML_TYPE_F32:
@@ -1261,7 +1340,7 @@ void ggml_metal_graph_compute(
 } break;
 default:
 {
-
+GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
 GGML_ASSERT(false);
 }
 }

@@ -1286,7 +1365,7 @@ void ggml_metal_graph_compute(

 MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
 if (status != MTLCommandBufferStatusCompleted) {
-
+GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
 GGML_ASSERT(false);
 }
 }