llama_cpp 0.5.2 → 0.6.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/README.md +6 -5
- data/examples/chat.rb +13 -13
- data/examples/embedding.rb +9 -9
- data/ext/llama_cpp/llama_cpp.cpp +547 -272
- data/ext/llama_cpp/src/ggml-alloc.c +14 -8
- data/ext/llama_cpp/src/ggml-alloc.h +1 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +307 -127
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.m +200 -94
- data/ext/llama_cpp/src/ggml-metal.metal +264 -82
- data/ext/llama_cpp/src/ggml-opencl.cpp +3 -3
- data/ext/llama_cpp/src/ggml.c +1647 -865
- data/ext/llama_cpp/src/ggml.h +143 -52
- data/ext/llama_cpp/src/llama.cpp +1427 -635
- data/ext/llama_cpp/src/llama.h +308 -119
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +5 -9
- data/sig/llama_cpp.rbs +65 -34
- metadata +3 -3
@@ -11,11 +11,14 @@
|
|
11
11
|
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
12
12
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
13
13
|
|
14
|
-
// TODO: temporary - reuse llama.cpp logging
|
15
14
|
#ifdef GGML_METAL_NDEBUG
|
16
|
-
#define
|
15
|
+
#define GGML_METAL_LOG_INFO(...)
|
16
|
+
#define GGML_METAL_LOG_WARN(...)
|
17
|
+
#define GGML_METAL_LOG_ERROR(...)
|
17
18
|
#else
|
18
|
-
#define
|
19
|
+
#define GGML_METAL_LOG_INFO(...) ggml_metal_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
|
20
|
+
#define GGML_METAL_LOG_WARN(...) ggml_metal_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
|
21
|
+
#define GGML_METAL_LOG_ERROR(...) ggml_metal_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
19
22
|
#endif
|
20
23
|
|
21
24
|
#define UNUSED(x) (void)(x)
|
@@ -66,6 +69,7 @@ struct ggml_metal_context {
|
|
66
69
|
GGML_METAL_DECL_KERNEL(soft_max_4);
|
67
70
|
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
68
71
|
GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
|
72
|
+
GGML_METAL_DECL_KERNEL(get_rows_f32);
|
69
73
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
70
74
|
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
71
75
|
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
@@ -77,6 +81,7 @@ struct ggml_metal_context {
|
|
77
81
|
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
78
82
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
79
83
|
GGML_METAL_DECL_KERNEL(norm);
|
84
|
+
GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
|
80
85
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
81
86
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
|
82
87
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
|
@@ -88,6 +93,7 @@ struct ggml_metal_context {
|
|
88
93
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
|
89
94
|
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
|
90
95
|
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
|
96
|
+
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
|
91
97
|
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
92
98
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
93
99
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
|
@@ -97,7 +103,8 @@ struct ggml_metal_context {
|
|
97
103
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
|
98
104
|
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
|
99
105
|
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
|
100
|
-
GGML_METAL_DECL_KERNEL(
|
106
|
+
GGML_METAL_DECL_KERNEL(rope_f32);
|
107
|
+
GGML_METAL_DECL_KERNEL(rope_f16);
|
101
108
|
GGML_METAL_DECL_KERNEL(alibi_f32);
|
102
109
|
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
103
110
|
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
@@ -117,8 +124,37 @@ static NSString * const msl_library_source = @"see metal.metal";
|
|
117
124
|
@implementation GGMLMetalClass
|
118
125
|
@end
|
119
126
|
|
127
|
+
ggml_log_callback ggml_metal_log_callback = NULL;
|
128
|
+
void * ggml_metal_log_user_data = NULL;
|
129
|
+
|
130
|
+
void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
|
131
|
+
ggml_metal_log_callback = log_callback;
|
132
|
+
ggml_metal_log_user_data = user_data;
|
133
|
+
}
|
134
|
+
|
135
|
+
static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
|
136
|
+
if (ggml_metal_log_callback != NULL) {
|
137
|
+
va_list args;
|
138
|
+
va_start(args, format);
|
139
|
+
char buffer[128];
|
140
|
+
int len = vsnprintf(buffer, 128, format, args);
|
141
|
+
if (len < 128) {
|
142
|
+
ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
|
143
|
+
} else {
|
144
|
+
char* buffer2 = malloc(len+1);
|
145
|
+
vsnprintf(buffer2, len+1, format, args);
|
146
|
+
buffer2[len] = 0;
|
147
|
+
ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
|
148
|
+
free(buffer2);
|
149
|
+
}
|
150
|
+
va_end(args);
|
151
|
+
}
|
152
|
+
}
|
153
|
+
|
154
|
+
|
155
|
+
|
120
156
|
struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
121
|
-
|
157
|
+
GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
|
122
158
|
|
123
159
|
id <MTLDevice> device;
|
124
160
|
NSString * s;
|
@@ -128,14 +164,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
128
164
|
NSArray * devices = MTLCopyAllDevices();
|
129
165
|
for (device in devices) {
|
130
166
|
s = [device name];
|
131
|
-
|
167
|
+
GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
|
132
168
|
}
|
133
169
|
#endif
|
134
170
|
|
135
171
|
// Pick and show default Metal device
|
136
172
|
device = MTLCreateSystemDefaultDevice();
|
137
173
|
s = [device name];
|
138
|
-
|
174
|
+
GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
|
139
175
|
|
140
176
|
// Configure context
|
141
177
|
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
|
@@ -145,7 +181,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
145
181
|
ctx->n_buffers = 0;
|
146
182
|
ctx->concur_list_len = 0;
|
147
183
|
|
148
|
-
ctx->d_queue = dispatch_queue_create("
|
184
|
+
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
149
185
|
|
150
186
|
#ifdef GGML_SWIFT
|
151
187
|
// load the default.metallib file
|
@@ -162,7 +198,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
162
198
|
ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
|
163
199
|
|
164
200
|
if (error) {
|
165
|
-
|
201
|
+
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
166
202
|
return NULL;
|
167
203
|
}
|
168
204
|
}
|
@@ -175,12 +211,12 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
175
211
|
|
176
212
|
//NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
|
177
213
|
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
178
|
-
NSString * path
|
179
|
-
|
214
|
+
NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
215
|
+
GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path UTF8String]);
|
180
216
|
|
181
217
|
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
|
182
218
|
if (error) {
|
183
|
-
|
219
|
+
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
184
220
|
return NULL;
|
185
221
|
}
|
186
222
|
|
@@ -192,7 +228,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
192
228
|
ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
|
193
229
|
#endif
|
194
230
|
if (error) {
|
195
|
-
|
231
|
+
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
196
232
|
return NULL;
|
197
233
|
}
|
198
234
|
}
|
@@ -204,11 +240,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
204
240
|
#define GGML_METAL_ADD_KERNEL(name) \
|
205
241
|
ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
|
206
242
|
ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
|
207
|
-
|
243
|
+
GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
|
208
244
|
(int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
|
209
245
|
(int) ctx->pipeline_##name.threadExecutionWidth); \
|
210
246
|
if (error) { \
|
211
|
-
|
247
|
+
GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
212
248
|
return NULL; \
|
213
249
|
}
|
214
250
|
|
@@ -224,6 +260,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
224
260
|
GGML_METAL_ADD_KERNEL(soft_max_4);
|
225
261
|
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
226
262
|
GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
|
263
|
+
GGML_METAL_ADD_KERNEL(get_rows_f32);
|
227
264
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
228
265
|
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
229
266
|
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
@@ -235,6 +272,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
235
272
|
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
236
273
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
237
274
|
GGML_METAL_ADD_KERNEL(norm);
|
275
|
+
GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
|
238
276
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
239
277
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
|
240
278
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
|
@@ -246,6 +284,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
246
284
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
|
247
285
|
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
|
248
286
|
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
|
287
|
+
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
|
249
288
|
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
250
289
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
|
251
290
|
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
|
@@ -255,7 +294,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
255
294
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
|
256
295
|
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
|
257
296
|
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
|
258
|
-
GGML_METAL_ADD_KERNEL(
|
297
|
+
GGML_METAL_ADD_KERNEL(rope_f32);
|
298
|
+
GGML_METAL_ADD_KERNEL(rope_f16);
|
259
299
|
GGML_METAL_ADD_KERNEL(alibi_f32);
|
260
300
|
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
261
301
|
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
@@ -264,13 +304,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
264
304
|
#undef GGML_METAL_ADD_KERNEL
|
265
305
|
}
|
266
306
|
|
267
|
-
|
307
|
+
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
268
308
|
#if TARGET_OS_OSX
|
269
|
-
|
309
|
+
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
270
310
|
if (ctx->device.maxTransferRate != 0) {
|
271
|
-
|
311
|
+
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
272
312
|
} else {
|
273
|
-
|
313
|
+
GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
|
274
314
|
}
|
275
315
|
#endif
|
276
316
|
|
@@ -278,7 +318,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
278
318
|
}
|
279
319
|
|
280
320
|
void ggml_metal_free(struct ggml_metal_context * ctx) {
|
281
|
-
|
321
|
+
GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
|
282
322
|
#define GGML_METAL_DEL_KERNEL(name) \
|
283
323
|
[ctx->function_##name release]; \
|
284
324
|
[ctx->pipeline_##name release];
|
@@ -293,7 +333,9 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
293
333
|
GGML_METAL_DEL_KERNEL(gelu);
|
294
334
|
GGML_METAL_DEL_KERNEL(soft_max);
|
295
335
|
GGML_METAL_DEL_KERNEL(soft_max_4);
|
336
|
+
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
296
337
|
GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
|
338
|
+
GGML_METAL_DEL_KERNEL(get_rows_f32);
|
297
339
|
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
298
340
|
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
299
341
|
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
@@ -305,6 +347,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
305
347
|
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
306
348
|
GGML_METAL_DEL_KERNEL(rms_norm);
|
307
349
|
GGML_METAL_DEL_KERNEL(norm);
|
350
|
+
GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
|
308
351
|
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
309
352
|
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
|
310
353
|
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
|
@@ -316,6 +359,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
316
359
|
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
|
317
360
|
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
|
318
361
|
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
|
362
|
+
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
319
363
|
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
320
364
|
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
|
321
365
|
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
|
@@ -325,7 +369,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
325
369
|
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
|
326
370
|
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
|
327
371
|
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
|
328
|
-
GGML_METAL_DEL_KERNEL(
|
372
|
+
GGML_METAL_DEL_KERNEL(rope_f32);
|
373
|
+
GGML_METAL_DEL_KERNEL(rope_f16);
|
329
374
|
GGML_METAL_DEL_KERNEL(alibi_f32);
|
330
375
|
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
331
376
|
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
@@ -350,7 +395,7 @@ void * ggml_metal_host_malloc(size_t n) {
|
|
350
395
|
void * data = NULL;
|
351
396
|
const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
|
352
397
|
if (result != 0) {
|
353
|
-
|
398
|
+
GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
|
354
399
|
return NULL;
|
355
400
|
}
|
356
401
|
|
@@ -378,7 +423,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
|
|
378
423
|
// Metal buffer based on the host memory pointer
|
379
424
|
//
|
380
425
|
static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
|
381
|
-
//
|
426
|
+
//GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
|
382
427
|
|
383
428
|
const int64_t tsize = ggml_nbytes(t);
|
384
429
|
|
@@ -386,16 +431,17 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
|
386
431
|
for (int i = 0; i < ctx->n_buffers; ++i) {
|
387
432
|
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
|
388
433
|
|
434
|
+
//metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
|
389
435
|
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
|
390
436
|
*offs = (size_t) ioffs;
|
391
437
|
|
392
|
-
//
|
438
|
+
//GGML_METAL_LOG_INFO("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
|
393
439
|
|
394
440
|
return ctx->buffers[i].metal;
|
395
441
|
}
|
396
442
|
}
|
397
443
|
|
398
|
-
|
444
|
+
GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__);
|
399
445
|
|
400
446
|
return nil;
|
401
447
|
}
|
@@ -407,7 +453,7 @@ bool ggml_metal_add_buffer(
|
|
407
453
|
size_t size,
|
408
454
|
size_t max_size) {
|
409
455
|
if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
|
410
|
-
|
456
|
+
GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
|
411
457
|
return false;
|
412
458
|
}
|
413
459
|
|
@@ -417,7 +463,7 @@ bool ggml_metal_add_buffer(
|
|
417
463
|
const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
|
418
464
|
|
419
465
|
if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
|
420
|
-
|
466
|
+
GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
|
421
467
|
return false;
|
422
468
|
}
|
423
469
|
}
|
@@ -438,11 +484,11 @@ bool ggml_metal_add_buffer(
|
|
438
484
|
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
439
485
|
|
440
486
|
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
441
|
-
|
487
|
+
GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
|
442
488
|
return false;
|
443
489
|
}
|
444
490
|
|
445
|
-
|
491
|
+
GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
|
446
492
|
|
447
493
|
++ctx->n_buffers;
|
448
494
|
} else {
|
@@ -462,13 +508,13 @@ bool ggml_metal_add_buffer(
|
|
462
508
|
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
463
509
|
|
464
510
|
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
465
|
-
|
511
|
+
GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
|
466
512
|
return false;
|
467
513
|
}
|
468
514
|
|
469
|
-
|
515
|
+
GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
|
470
516
|
if (i + size_step < size) {
|
471
|
-
|
517
|
+
GGML_METAL_LOG_INFO("\n");
|
472
518
|
}
|
473
519
|
|
474
520
|
++ctx->n_buffers;
|
@@ -476,17 +522,17 @@ bool ggml_metal_add_buffer(
|
|
476
522
|
}
|
477
523
|
|
478
524
|
#if TARGET_OS_OSX
|
479
|
-
|
525
|
+
GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
|
480
526
|
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
|
481
527
|
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
482
528
|
|
483
529
|
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
|
484
|
-
|
530
|
+
GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
485
531
|
} else {
|
486
|
-
|
532
|
+
GGML_METAL_LOG_INFO("\n");
|
487
533
|
}
|
488
534
|
#else
|
489
|
-
|
535
|
+
GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
|
490
536
|
#endif
|
491
537
|
}
|
492
538
|
|
@@ -599,7 +645,7 @@ void ggml_metal_graph_find_concurrency(
|
|
599
645
|
}
|
600
646
|
|
601
647
|
if (ctx->concur_list_len > GGML_MAX_CONCUR) {
|
602
|
-
|
648
|
+
GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
|
603
649
|
}
|
604
650
|
}
|
605
651
|
|
@@ -653,7 +699,7 @@ void ggml_metal_graph_compute(
|
|
653
699
|
continue;
|
654
700
|
}
|
655
701
|
|
656
|
-
//
|
702
|
+
//GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
657
703
|
|
658
704
|
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
659
705
|
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
|
@@ -697,17 +743,17 @@ void ggml_metal_graph_compute(
|
|
697
743
|
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
|
698
744
|
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
|
699
745
|
|
700
|
-
//
|
746
|
+
//GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
|
701
747
|
//if (src0) {
|
702
|
-
//
|
748
|
+
// GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
|
703
749
|
// ggml_is_contiguous(src0), src0->name);
|
704
750
|
//}
|
705
751
|
//if (src1) {
|
706
|
-
//
|
752
|
+
// GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
|
707
753
|
// ggml_is_contiguous(src1), src1->name);
|
708
754
|
//}
|
709
755
|
//if (dst) {
|
710
|
-
//
|
756
|
+
// GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
|
711
757
|
// dst->name);
|
712
758
|
//}
|
713
759
|
|
@@ -723,29 +769,66 @@ void ggml_metal_graph_compute(
|
|
723
769
|
case GGML_OP_ADD:
|
724
770
|
{
|
725
771
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
772
|
+
GGML_ASSERT(ggml_is_contiguous(src1));
|
726
773
|
|
727
|
-
|
728
|
-
GGML_ASSERT(ne00 % 4 == 0);
|
729
|
-
const int64_t nb = ne00/4;
|
774
|
+
bool bcast_row = false;
|
730
775
|
|
731
|
-
|
776
|
+
int64_t nb = ne00;
|
777
|
+
|
778
|
+
if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
|
732
779
|
// src1 is a row
|
780
|
+
GGML_ASSERT(ne11 == 1);
|
781
|
+
|
782
|
+
nb = ne00 / 4;
|
733
783
|
[encoder setComputePipelineState:ctx->pipeline_add_row];
|
784
|
+
|
785
|
+
bcast_row = true;
|
734
786
|
} else {
|
735
787
|
[encoder setComputePipelineState:ctx->pipeline_add];
|
736
788
|
}
|
737
789
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
738
790
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
739
791
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
740
|
-
[encoder setBytes:&
|
741
|
-
|
742
|
-
|
792
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
793
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
794
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
795
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
796
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
797
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
798
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
799
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
800
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
801
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
802
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
803
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
804
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
805
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
806
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
807
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
808
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
809
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
810
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
811
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
812
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
813
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
814
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
815
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
816
|
+
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
|
817
|
+
|
818
|
+
if (bcast_row) {
|
819
|
+
const int64_t n = ggml_nelements(dst)/4;
|
820
|
+
|
821
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
822
|
+
} else {
|
823
|
+
const int nth = MIN(1024, ne0);
|
743
824
|
|
744
|
-
|
825
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
826
|
+
}
|
745
827
|
} break;
|
746
828
|
case GGML_OP_MUL:
|
747
829
|
{
|
748
830
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
831
|
+
GGML_ASSERT(ggml_is_contiguous(src1));
|
749
832
|
|
750
833
|
// utilize float4
|
751
834
|
GGML_ASSERT(ne00 % 4 == 0);
|
@@ -753,6 +836,7 @@ void ggml_metal_graph_compute(
|
|
753
836
|
|
754
837
|
if (ggml_nelements(src1) == ne10) {
|
755
838
|
// src1 is a row
|
839
|
+
GGML_ASSERT(ne11 == 1);
|
756
840
|
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
757
841
|
} else {
|
758
842
|
[encoder setComputePipelineState:ctx->pipeline_mul];
|
@@ -768,6 +852,8 @@ void ggml_metal_graph_compute(
|
|
768
852
|
} break;
|
769
853
|
case GGML_OP_SCALE:
|
770
854
|
{
|
855
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
856
|
+
|
771
857
|
const float scale = *(const float *) src1->data;
|
772
858
|
|
773
859
|
[encoder setComputePipelineState:ctx->pipeline_scale];
|
@@ -813,13 +899,13 @@ void ggml_metal_graph_compute(
|
|
813
899
|
} break;
|
814
900
|
default:
|
815
901
|
{
|
816
|
-
|
902
|
+
GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
817
903
|
GGML_ASSERT(false);
|
818
904
|
}
|
819
905
|
} break;
|
820
906
|
case GGML_OP_SOFT_MAX:
|
821
907
|
{
|
822
|
-
const int nth = 32;
|
908
|
+
const int nth = MIN(32, ne00);
|
823
909
|
|
824
910
|
if (ne00%4 == 0) {
|
825
911
|
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
@@ -867,13 +953,14 @@ void ggml_metal_graph_compute(
|
|
867
953
|
|
868
954
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
869
955
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
870
|
-
if (
|
871
|
-
|
956
|
+
if (!ggml_is_transposed(src0) &&
|
957
|
+
!ggml_is_transposed(src1) &&
|
872
958
|
src1t == GGML_TYPE_F32 &&
|
873
959
|
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
874
960
|
ne00%32 == 0 &&
|
875
|
-
ne11 >
|
961
|
+
ne11 > 2) {
|
876
962
|
switch (src0->type) {
|
963
|
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
|
877
964
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
878
965
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
|
879
966
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
|
@@ -893,9 +980,12 @@ void ggml_metal_graph_compute(
|
|
893
980
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
894
981
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
895
982
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
896
|
-
[encoder setBytes:&
|
897
|
-
[encoder setBytes:&
|
898
|
-
[encoder setBytes:&
|
983
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
|
984
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
|
985
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
986
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
987
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
988
|
+
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
|
899
989
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
900
990
|
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
901
991
|
} else {
|
@@ -905,6 +995,11 @@ void ggml_metal_graph_compute(
|
|
905
995
|
|
906
996
|
// use custom matrix x vector kernel
|
907
997
|
switch (src0t) {
|
998
|
+
case GGML_TYPE_F32:
|
999
|
+
{
|
1000
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
|
1001
|
+
nrows = 4;
|
1002
|
+
} break;
|
908
1003
|
case GGML_TYPE_F16:
|
909
1004
|
{
|
910
1005
|
nth0 = 32;
|
@@ -993,7 +1088,7 @@ void ggml_metal_graph_compute(
|
|
993
1088
|
} break;
|
994
1089
|
default:
|
995
1090
|
{
|
996
|
-
|
1091
|
+
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
997
1092
|
GGML_ASSERT(false && "not implemented");
|
998
1093
|
}
|
999
1094
|
};
|
@@ -1045,6 +1140,7 @@ void ggml_metal_graph_compute(
|
|
1045
1140
|
case GGML_OP_GET_ROWS:
|
1046
1141
|
{
|
1047
1142
|
switch (src0->type) {
|
1143
|
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
|
1048
1144
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
1049
1145
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
1050
1146
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
@@ -1060,9 +1156,9 @@ void ggml_metal_graph_compute(
|
|
1060
1156
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1061
1157
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1062
1158
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1063
|
-
[encoder setBytes:&
|
1064
|
-
[encoder setBytes:&
|
1065
|
-
[encoder setBytes:&
|
1159
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
1160
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
1161
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
|
1066
1162
|
|
1067
1163
|
const int64_t n = ggml_nelements(src1);
|
1068
1164
|
|
@@ -1073,7 +1169,7 @@ void ggml_metal_graph_compute(
|
|
1073
1169
|
float eps;
|
1074
1170
|
memcpy(&eps, dst->op_params, sizeof(float));
|
1075
1171
|
|
1076
|
-
const int nth = 512;
|
1172
|
+
const int nth = MIN(512, ne00);
|
1077
1173
|
|
1078
1174
|
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
|
1079
1175
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
@@ -1092,7 +1188,7 @@ void ggml_metal_graph_compute(
|
|
1092
1188
|
float eps;
|
1093
1189
|
memcpy(&eps, dst->op_params, sizeof(float));
|
1094
1190
|
|
1095
|
-
const int nth = 256;
|
1191
|
+
const int nth = MIN(256, ne00);
|
1096
1192
|
|
1097
1193
|
[encoder setComputePipelineState:ctx->pipeline_norm];
|
1098
1194
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
@@ -1110,6 +1206,8 @@ void ggml_metal_graph_compute(
|
|
1110
1206
|
{
|
1111
1207
|
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
1112
1208
|
|
1209
|
+
const int nth = MIN(1024, ne00);
|
1210
|
+
|
1113
1211
|
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
|
1114
1212
|
const int n_head = ((int32_t *) dst->op_params)[1];
|
1115
1213
|
float max_bias;
|
@@ -1143,12 +1241,14 @@ void ggml_metal_graph_compute(
|
|
1143
1241
|
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
1144
1242
|
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
1145
1243
|
|
1146
|
-
const int nth = 32;
|
1147
|
-
|
1148
1244
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1149
1245
|
} break;
|
1150
1246
|
case GGML_OP_ROPE:
|
1151
1247
|
{
|
1248
|
+
GGML_ASSERT(ne10 == ne02);
|
1249
|
+
|
1250
|
+
const int nth = MIN(1024, ne00);
|
1251
|
+
|
1152
1252
|
const int n_past = ((int32_t *) dst->op_params)[0];
|
1153
1253
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
1154
1254
|
const int mode = ((int32_t *) dst->op_params)[2];
|
@@ -1158,38 +1258,44 @@ void ggml_metal_graph_compute(
|
|
1158
1258
|
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
1159
1259
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
1160
1260
|
|
1161
|
-
|
1261
|
+
switch (src0->type) {
|
1262
|
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
|
1263
|
+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
|
1264
|
+
default: GGML_ASSERT(false);
|
1265
|
+
};
|
1266
|
+
|
1162
1267
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1163
|
-
[encoder setBuffer:
|
1164
|
-
[encoder
|
1165
|
-
[encoder setBytes:&
|
1166
|
-
[encoder setBytes:&
|
1167
|
-
[encoder setBytes:&
|
1168
|
-
[encoder setBytes:&
|
1169
|
-
[encoder setBytes:&
|
1170
|
-
[encoder setBytes:&
|
1171
|
-
[encoder setBytes:&
|
1172
|
-
[encoder setBytes:&
|
1173
|
-
[encoder setBytes:&
|
1174
|
-
[encoder setBytes:&
|
1175
|
-
[encoder setBytes:&
|
1176
|
-
[encoder setBytes:&
|
1177
|
-
[encoder setBytes:&
|
1178
|
-
[encoder setBytes:&
|
1179
|
-
[encoder setBytes:&
|
1180
|
-
[encoder setBytes:&
|
1181
|
-
[encoder setBytes:&
|
1182
|
-
[encoder setBytes:&
|
1183
|
-
[encoder setBytes:&
|
1184
|
-
[encoder setBytes:&
|
1268
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1269
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1270
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
1271
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
|
1272
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
|
1273
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
|
1274
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
|
1275
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
1276
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
1277
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
1278
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
|
1279
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
|
1280
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
|
1281
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
|
1282
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
|
1283
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
|
1284
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
|
1285
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
|
1286
|
+
[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
|
1287
|
+
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
|
1288
|
+
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
|
1289
|
+
[encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
|
1290
|
+
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];
|
1185
1291
|
|
1186
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(
|
1292
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1187
1293
|
} break;
|
1188
1294
|
case GGML_OP_DUP:
|
1189
1295
|
case GGML_OP_CPY:
|
1190
1296
|
case GGML_OP_CONT:
|
1191
1297
|
{
|
1192
|
-
const int nth =
|
1298
|
+
const int nth = MIN(1024, ne00);
|
1193
1299
|
|
1194
1300
|
switch (src0t) {
|
1195
1301
|
case GGML_TYPE_F32:
|
@@ -1234,7 +1340,7 @@ void ggml_metal_graph_compute(
|
|
1234
1340
|
} break;
|
1235
1341
|
default:
|
1236
1342
|
{
|
1237
|
-
|
1343
|
+
GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
1238
1344
|
GGML_ASSERT(false);
|
1239
1345
|
}
|
1240
1346
|
}
|
@@ -1259,7 +1365,7 @@ void ggml_metal_graph_compute(
|
|
1259
1365
|
|
1260
1366
|
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
|
1261
1367
|
if (status != MTLCommandBufferStatusCompleted) {
|
1262
|
-
|
1368
|
+
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
1263
1369
|
GGML_ASSERT(false);
|
1264
1370
|
}
|
1265
1371
|
}
|