llama_cpp 0.5.2 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/README.md +6 -5
- data/examples/chat.rb +13 -13
- data/examples/embedding.rb +9 -9
- data/ext/llama_cpp/llama_cpp.cpp +547 -272
- data/ext/llama_cpp/src/ggml-alloc.c +14 -8
- data/ext/llama_cpp/src/ggml-alloc.h +1 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +307 -127
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.m +200 -94
- data/ext/llama_cpp/src/ggml-metal.metal +264 -82
- data/ext/llama_cpp/src/ggml-opencl.cpp +3 -3
- data/ext/llama_cpp/src/ggml.c +1647 -865
- data/ext/llama_cpp/src/ggml.h +143 -52
- data/ext/llama_cpp/src/llama.cpp +1427 -635
- data/ext/llama_cpp/src/llama.h +308 -119
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +5 -9
- data/sig/llama_cpp.rbs +65 -34
- metadata +3 -3
@@ -11,11 +11,14 @@
|
|
11
11
|
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
12
12
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
13
13
|
|
14
|
-
// TODO: temporary - reuse llama.cpp logging
|
15
14
|
#ifdef GGML_METAL_NDEBUG
|
16
|
-
#define
|
15
|
+
#define GGML_METAL_LOG_INFO(...)
|
16
|
+
#define GGML_METAL_LOG_WARN(...)
|
17
|
+
#define GGML_METAL_LOG_ERROR(...)
|
17
18
|
#else
|
18
|
-
#define
|
19
|
+
#define GGML_METAL_LOG_INFO(...) ggml_metal_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
|
20
|
+
#define GGML_METAL_LOG_WARN(...) ggml_metal_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
|
21
|
+
#define GGML_METAL_LOG_ERROR(...) ggml_metal_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
19
22
|
#endif
|
20
23
|
|
21
24
|
#define UNUSED(x) (void)(x)
|
@@ -66,6 +69,7 @@ struct ggml_metal_context {
|
|
66
69
|
GGML_METAL_DECL_KERNEL(soft_max_4);
|
67
70
|
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
68
71
|
GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
|
72
|
+
GGML_METAL_DECL_KERNEL(get_rows_f32);
|
69
73
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
70
74
|
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
71
75
|
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
@@ -77,6 +81,7 @@ struct ggml_metal_context {
|
|
77
81
|
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
78
82
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
79
83
|
GGML_METAL_DECL_KERNEL(norm);
|
84
|
+
GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
|
80
85
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
81
86
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
|
82
87
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
|
@@ -88,6 +93,7 @@ struct ggml_metal_context {
|
|
88
93
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
|
89
94
|
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
|
90
95
|
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
|
96
|
+
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
|
91
97
|
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
92
98
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
93
99
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
|
@@ -97,7 +103,8 @@ struct ggml_metal_context {
|
|
97
103
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
|
98
104
|
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
|
99
105
|
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
|
100
|
-
GGML_METAL_DECL_KERNEL(
|
106
|
+
GGML_METAL_DECL_KERNEL(rope_f32);
|
107
|
+
GGML_METAL_DECL_KERNEL(rope_f16);
|
101
108
|
GGML_METAL_DECL_KERNEL(alibi_f32);
|
102
109
|
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
103
110
|
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
@@ -117,8 +124,37 @@ static NSString * const msl_library_source = @"see metal.metal";
|
|
117
124
|
@implementation GGMLMetalClass
|
118
125
|
@end
|
119
126
|
|
127
|
+
// Logging sink shared by this file: the callback that receives log messages
// and the opaque pointer handed back to it on every call. Both start as NULL.
ggml_log_callback ggml_metal_log_callback  = NULL;
void *            ggml_metal_log_user_data = NULL;

// Registers the callback that receives log messages, along with the
// caller-owned context pointer that is passed back to it unchanged.
//
//   log_callback - function to receive messages; stored as-is (may be NULL)
//   user_data    - opaque pointer stored as-is alongside the callback
void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
    ggml_metal_log_user_data = user_data;
    ggml_metal_log_callback  = log_callback;
}
|
134
|
+
|
135
|
+
// Formats a printf-style message and forwards it to the registered log callback.
//
//   level  - severity passed through to the callback unchanged
//   format - printf-style format string
//   ...    - arguments consumed by `format`
//
// Does nothing when no callback is registered. Messages that fit in a small
// stack buffer are forwarded without allocating; longer ones are formatted
// into a temporary heap buffer that is freed before returning.
static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
    if (ggml_metal_log_callback == NULL) {
        return;
    }

    char buffer[128];

    va_list args;
    va_start(args, format);

    // A va_list is indeterminate once vsnprintf has consumed it, so keep an
    // untouched copy for the second formatting pass needed by long messages.
    va_list args_retry;
    va_copy(args_retry, args);

    const int len = vsnprintf(buffer, sizeof(buffer), format, args);
    va_end(args);

    if (len < 0) {
        // Formatting failed: the buffer contents are unspecified, so emit nothing.
    } else if ((size_t) len < sizeof(buffer)) {
        ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
    } else {
        char* buffer2 = malloc((size_t) len + 1);
        if (buffer2 != NULL) {
            vsnprintf(buffer2, (size_t) len + 1, format, args_retry);
            buffer2[len] = 0;
            ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
            free(buffer2);
        } else {
            // Out of memory: fall back to the truncated (but NUL-terminated)
            // stack copy rather than dropping the message entirely.
            ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
        }
    }

    va_end(args_retry);
}
|
153
|
+
|
154
|
+
|
155
|
+
|
120
156
|
struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
121
|
-
|
157
|
+
GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
|
122
158
|
|
123
159
|
id <MTLDevice> device;
|
124
160
|
NSString * s;
|
@@ -128,14 +164,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
128
164
|
NSArray * devices = MTLCopyAllDevices();
|
129
165
|
for (device in devices) {
|
130
166
|
s = [device name];
|
131
|
-
|
167
|
+
GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
|
132
168
|
}
|
133
169
|
#endif
|
134
170
|
|
135
171
|
// Pick and show default Metal device
|
136
172
|
device = MTLCreateSystemDefaultDevice();
|
137
173
|
s = [device name];
|
138
|
-
|
174
|
+
GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
|
139
175
|
|
140
176
|
// Configure context
|
141
177
|
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
|
@@ -145,7 +181,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
145
181
|
ctx->n_buffers = 0;
|
146
182
|
ctx->concur_list_len = 0;
|
147
183
|
|
148
|
-
ctx->d_queue = dispatch_queue_create("
|
184
|
+
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
149
185
|
|
150
186
|
#ifdef GGML_SWIFT
|
151
187
|
// load the default.metallib file
|
@@ -162,7 +198,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
162
198
|
ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
|
163
199
|
|
164
200
|
if (error) {
|
165
|
-
|
201
|
+
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
166
202
|
return NULL;
|
167
203
|
}
|
168
204
|
}
|
@@ -175,12 +211,12 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
175
211
|
|
176
212
|
//NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
|
177
213
|
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
178
|
-
NSString * path
|
179
|
-
|
214
|
+
NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
215
|
+
GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path UTF8String]);
|
180
216
|
|
181
217
|
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
|
182
218
|
if (error) {
|
183
|
-
|
219
|
+
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
184
220
|
return NULL;
|
185
221
|
}
|
186
222
|
|
@@ -192,7 +228,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
192
228
|
ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
|
193
229
|
#endif
|
194
230
|
if (error) {
|
195
|
-
|
231
|
+
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
196
232
|
return NULL;
|
197
233
|
}
|
198
234
|
}
|
@@ -204,11 +240,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
204
240
|
#define GGML_METAL_ADD_KERNEL(name) \
|
205
241
|
ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
|
206
242
|
ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
|
207
|
-
|
243
|
+
GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
|
208
244
|
(int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
|
209
245
|
(int) ctx->pipeline_##name.threadExecutionWidth); \
|
210
246
|
if (error) { \
|
211
|
-
|
247
|
+
GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
212
248
|
return NULL; \
|
213
249
|
}
|
214
250
|
|
@@ -224,6 +260,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
224
260
|
GGML_METAL_ADD_KERNEL(soft_max_4);
|
225
261
|
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
226
262
|
GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
|
263
|
+
GGML_METAL_ADD_KERNEL(get_rows_f32);
|
227
264
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
228
265
|
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
229
266
|
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
@@ -235,6 +272,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
235
272
|
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
236
273
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
237
274
|
GGML_METAL_ADD_KERNEL(norm);
|
275
|
+
GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
|
238
276
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
239
277
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
|
240
278
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
|
@@ -246,6 +284,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
246
284
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
|
247
285
|
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
|
248
286
|
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
|
287
|
+
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
|
249
288
|
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
250
289
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
|
251
290
|
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
|
@@ -255,7 +294,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
255
294
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
|
256
295
|
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
|
257
296
|
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
|
258
|
-
GGML_METAL_ADD_KERNEL(
|
297
|
+
GGML_METAL_ADD_KERNEL(rope_f32);
|
298
|
+
GGML_METAL_ADD_KERNEL(rope_f16);
|
259
299
|
GGML_METAL_ADD_KERNEL(alibi_f32);
|
260
300
|
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
261
301
|
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
@@ -264,13 +304,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
264
304
|
#undef GGML_METAL_ADD_KERNEL
|
265
305
|
}
|
266
306
|
|
267
|
-
|
307
|
+
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
268
308
|
#if TARGET_OS_OSX
|
269
|
-
|
309
|
+
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
270
310
|
if (ctx->device.maxTransferRate != 0) {
|
271
|
-
|
311
|
+
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
272
312
|
} else {
|
273
|
-
|
313
|
+
GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
|
274
314
|
}
|
275
315
|
#endif
|
276
316
|
|
@@ -278,7 +318,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
278
318
|
}
|
279
319
|
|
280
320
|
void ggml_metal_free(struct ggml_metal_context * ctx) {
|
281
|
-
|
321
|
+
GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
|
282
322
|
#define GGML_METAL_DEL_KERNEL(name) \
|
283
323
|
[ctx->function_##name release]; \
|
284
324
|
[ctx->pipeline_##name release];
|
@@ -293,7 +333,9 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
293
333
|
GGML_METAL_DEL_KERNEL(gelu);
|
294
334
|
GGML_METAL_DEL_KERNEL(soft_max);
|
295
335
|
GGML_METAL_DEL_KERNEL(soft_max_4);
|
336
|
+
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
296
337
|
GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
|
338
|
+
GGML_METAL_DEL_KERNEL(get_rows_f32);
|
297
339
|
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
298
340
|
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
299
341
|
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
@@ -305,6 +347,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
305
347
|
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
306
348
|
GGML_METAL_DEL_KERNEL(rms_norm);
|
307
349
|
GGML_METAL_DEL_KERNEL(norm);
|
350
|
+
GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
|
308
351
|
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
309
352
|
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
|
310
353
|
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
|
@@ -316,6 +359,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
316
359
|
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
|
317
360
|
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
|
318
361
|
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
|
362
|
+
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
319
363
|
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
320
364
|
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
|
321
365
|
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
|
@@ -325,7 +369,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
325
369
|
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
|
326
370
|
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
|
327
371
|
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
|
328
|
-
GGML_METAL_DEL_KERNEL(
|
372
|
+
GGML_METAL_DEL_KERNEL(rope_f32);
|
373
|
+
GGML_METAL_DEL_KERNEL(rope_f16);
|
329
374
|
GGML_METAL_DEL_KERNEL(alibi_f32);
|
330
375
|
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
331
376
|
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
@@ -350,7 +395,7 @@ void * ggml_metal_host_malloc(size_t n) {
|
|
350
395
|
void * data = NULL;
|
351
396
|
const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
|
352
397
|
if (result != 0) {
|
353
|
-
|
398
|
+
GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
|
354
399
|
return NULL;
|
355
400
|
}
|
356
401
|
|
@@ -378,7 +423,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
|
|
378
423
|
// Metal buffer based on the host memory pointer
|
379
424
|
//
|
380
425
|
static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
|
381
|
-
//
|
426
|
+
//GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
|
382
427
|
|
383
428
|
const int64_t tsize = ggml_nbytes(t);
|
384
429
|
|
@@ -386,16 +431,17 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
|
386
431
|
for (int i = 0; i < ctx->n_buffers; ++i) {
|
387
432
|
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
|
388
433
|
|
434
|
+
//metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
|
389
435
|
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
|
390
436
|
*offs = (size_t) ioffs;
|
391
437
|
|
392
|
-
//
|
438
|
+
//GGML_METAL_LOG_INFO("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
|
393
439
|
|
394
440
|
return ctx->buffers[i].metal;
|
395
441
|
}
|
396
442
|
}
|
397
443
|
|
398
|
-
|
444
|
+
GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__);
|
399
445
|
|
400
446
|
return nil;
|
401
447
|
}
|
@@ -407,7 +453,7 @@ bool ggml_metal_add_buffer(
|
|
407
453
|
size_t size,
|
408
454
|
size_t max_size) {
|
409
455
|
if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
|
410
|
-
|
456
|
+
GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
|
411
457
|
return false;
|
412
458
|
}
|
413
459
|
|
@@ -417,7 +463,7 @@ bool ggml_metal_add_buffer(
|
|
417
463
|
const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
|
418
464
|
|
419
465
|
if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
|
420
|
-
|
466
|
+
GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
|
421
467
|
return false;
|
422
468
|
}
|
423
469
|
}
|
@@ -438,11 +484,11 @@ bool ggml_metal_add_buffer(
|
|
438
484
|
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
439
485
|
|
440
486
|
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
441
|
-
|
487
|
+
GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
|
442
488
|
return false;
|
443
489
|
}
|
444
490
|
|
445
|
-
|
491
|
+
GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
|
446
492
|
|
447
493
|
++ctx->n_buffers;
|
448
494
|
} else {
|
@@ -462,13 +508,13 @@ bool ggml_metal_add_buffer(
|
|
462
508
|
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
463
509
|
|
464
510
|
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
465
|
-
|
511
|
+
GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
|
466
512
|
return false;
|
467
513
|
}
|
468
514
|
|
469
|
-
|
515
|
+
GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
|
470
516
|
if (i + size_step < size) {
|
471
|
-
|
517
|
+
GGML_METAL_LOG_INFO("\n");
|
472
518
|
}
|
473
519
|
|
474
520
|
++ctx->n_buffers;
|
@@ -476,17 +522,17 @@ bool ggml_metal_add_buffer(
|
|
476
522
|
}
|
477
523
|
|
478
524
|
#if TARGET_OS_OSX
|
479
|
-
|
525
|
+
GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
|
480
526
|
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
|
481
527
|
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
482
528
|
|
483
529
|
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
|
484
|
-
|
530
|
+
GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
485
531
|
} else {
|
486
|
-
|
532
|
+
GGML_METAL_LOG_INFO("\n");
|
487
533
|
}
|
488
534
|
#else
|
489
|
-
|
535
|
+
GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
|
490
536
|
#endif
|
491
537
|
}
|
492
538
|
|
@@ -599,7 +645,7 @@ void ggml_metal_graph_find_concurrency(
|
|
599
645
|
}
|
600
646
|
|
601
647
|
if (ctx->concur_list_len > GGML_MAX_CONCUR) {
|
602
|
-
|
648
|
+
GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
|
603
649
|
}
|
604
650
|
}
|
605
651
|
|
@@ -653,7 +699,7 @@ void ggml_metal_graph_compute(
|
|
653
699
|
continue;
|
654
700
|
}
|
655
701
|
|
656
|
-
//
|
702
|
+
//GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
657
703
|
|
658
704
|
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
659
705
|
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
|
@@ -697,17 +743,17 @@ void ggml_metal_graph_compute(
|
|
697
743
|
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
|
698
744
|
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
|
699
745
|
|
700
|
-
//
|
746
|
+
//GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
|
701
747
|
//if (src0) {
|
702
|
-
//
|
748
|
+
// GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
|
703
749
|
// ggml_is_contiguous(src0), src0->name);
|
704
750
|
//}
|
705
751
|
//if (src1) {
|
706
|
-
//
|
752
|
+
// GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
|
707
753
|
// ggml_is_contiguous(src1), src1->name);
|
708
754
|
//}
|
709
755
|
//if (dst) {
|
710
|
-
//
|
756
|
+
// GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
|
711
757
|
// dst->name);
|
712
758
|
//}
|
713
759
|
|
@@ -723,29 +769,66 @@ void ggml_metal_graph_compute(
|
|
723
769
|
case GGML_OP_ADD:
|
724
770
|
{
|
725
771
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
772
|
+
GGML_ASSERT(ggml_is_contiguous(src1));
|
726
773
|
|
727
|
-
|
728
|
-
GGML_ASSERT(ne00 % 4 == 0);
|
729
|
-
const int64_t nb = ne00/4;
|
774
|
+
bool bcast_row = false;
|
730
775
|
|
731
|
-
|
776
|
+
int64_t nb = ne00;
|
777
|
+
|
778
|
+
if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
|
732
779
|
// src1 is a row
|
780
|
+
GGML_ASSERT(ne11 == 1);
|
781
|
+
|
782
|
+
nb = ne00 / 4;
|
733
783
|
[encoder setComputePipelineState:ctx->pipeline_add_row];
|
784
|
+
|
785
|
+
bcast_row = true;
|
734
786
|
} else {
|
735
787
|
[encoder setComputePipelineState:ctx->pipeline_add];
|
736
788
|
}
|
737
789
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
738
790
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
739
791
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
740
|
-
[encoder setBytes:&
|
741
|
-
|
742
|
-
|
792
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
793
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
794
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
795
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
796
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
797
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
798
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
799
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
800
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
801
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
802
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
803
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
804
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
805
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
806
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
807
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
808
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
809
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
810
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
811
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
812
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
813
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
814
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
815
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
816
|
+
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
|
817
|
+
|
818
|
+
if (bcast_row) {
|
819
|
+
const int64_t n = ggml_nelements(dst)/4;
|
820
|
+
|
821
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
822
|
+
} else {
|
823
|
+
const int nth = MIN(1024, ne0);
|
743
824
|
|
744
|
-
|
825
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
826
|
+
}
|
745
827
|
} break;
|
746
828
|
case GGML_OP_MUL:
|
747
829
|
{
|
748
830
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
831
|
+
GGML_ASSERT(ggml_is_contiguous(src1));
|
749
832
|
|
750
833
|
// utilize float4
|
751
834
|
GGML_ASSERT(ne00 % 4 == 0);
|
@@ -753,6 +836,7 @@ void ggml_metal_graph_compute(
|
|
753
836
|
|
754
837
|
if (ggml_nelements(src1) == ne10) {
|
755
838
|
// src1 is a row
|
839
|
+
GGML_ASSERT(ne11 == 1);
|
756
840
|
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
757
841
|
} else {
|
758
842
|
[encoder setComputePipelineState:ctx->pipeline_mul];
|
@@ -768,6 +852,8 @@ void ggml_metal_graph_compute(
|
|
768
852
|
} break;
|
769
853
|
case GGML_OP_SCALE:
|
770
854
|
{
|
855
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
856
|
+
|
771
857
|
const float scale = *(const float *) src1->data;
|
772
858
|
|
773
859
|
[encoder setComputePipelineState:ctx->pipeline_scale];
|
@@ -813,13 +899,13 @@ void ggml_metal_graph_compute(
|
|
813
899
|
} break;
|
814
900
|
default:
|
815
901
|
{
|
816
|
-
|
902
|
+
GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
817
903
|
GGML_ASSERT(false);
|
818
904
|
}
|
819
905
|
} break;
|
820
906
|
case GGML_OP_SOFT_MAX:
|
821
907
|
{
|
822
|
-
const int nth = 32;
|
908
|
+
const int nth = MIN(32, ne00);
|
823
909
|
|
824
910
|
if (ne00%4 == 0) {
|
825
911
|
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
@@ -867,13 +953,14 @@ void ggml_metal_graph_compute(
|
|
867
953
|
|
868
954
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
869
955
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
870
|
-
if (
|
871
|
-
|
956
|
+
if (!ggml_is_transposed(src0) &&
|
957
|
+
!ggml_is_transposed(src1) &&
|
872
958
|
src1t == GGML_TYPE_F32 &&
|
873
959
|
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
874
960
|
ne00%32 == 0 &&
|
875
|
-
ne11 >
|
961
|
+
ne11 > 2) {
|
876
962
|
switch (src0->type) {
|
963
|
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
|
877
964
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
878
965
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
|
879
966
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
|
@@ -893,9 +980,12 @@ void ggml_metal_graph_compute(
|
|
893
980
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
894
981
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
895
982
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
896
|
-
[encoder setBytes:&
|
897
|
-
[encoder setBytes:&
|
898
|
-
[encoder setBytes:&
|
983
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
|
984
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
|
985
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
986
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
987
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
988
|
+
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
|
899
989
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
900
990
|
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
901
991
|
} else {
|
@@ -905,6 +995,11 @@ void ggml_metal_graph_compute(
|
|
905
995
|
|
906
996
|
// use custom matrix x vector kernel
|
907
997
|
switch (src0t) {
|
998
|
+
case GGML_TYPE_F32:
|
999
|
+
{
|
1000
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
|
1001
|
+
nrows = 4;
|
1002
|
+
} break;
|
908
1003
|
case GGML_TYPE_F16:
|
909
1004
|
{
|
910
1005
|
nth0 = 32;
|
@@ -993,7 +1088,7 @@ void ggml_metal_graph_compute(
|
|
993
1088
|
} break;
|
994
1089
|
default:
|
995
1090
|
{
|
996
|
-
|
1091
|
+
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
997
1092
|
GGML_ASSERT(false && "not implemented");
|
998
1093
|
}
|
999
1094
|
};
|
@@ -1045,6 +1140,7 @@ void ggml_metal_graph_compute(
|
|
1045
1140
|
case GGML_OP_GET_ROWS:
|
1046
1141
|
{
|
1047
1142
|
switch (src0->type) {
|
1143
|
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
|
1048
1144
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
1049
1145
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
1050
1146
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
@@ -1060,9 +1156,9 @@ void ggml_metal_graph_compute(
|
|
1060
1156
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1061
1157
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1062
1158
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1063
|
-
[encoder setBytes:&
|
1064
|
-
[encoder setBytes:&
|
1065
|
-
[encoder setBytes:&
|
1159
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
1160
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
1161
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
|
1066
1162
|
|
1067
1163
|
const int64_t n = ggml_nelements(src1);
|
1068
1164
|
|
@@ -1073,7 +1169,7 @@ void ggml_metal_graph_compute(
|
|
1073
1169
|
float eps;
|
1074
1170
|
memcpy(&eps, dst->op_params, sizeof(float));
|
1075
1171
|
|
1076
|
-
const int nth = 512;
|
1172
|
+
const int nth = MIN(512, ne00);
|
1077
1173
|
|
1078
1174
|
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
|
1079
1175
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
@@ -1092,7 +1188,7 @@ void ggml_metal_graph_compute(
|
|
1092
1188
|
float eps;
|
1093
1189
|
memcpy(&eps, dst->op_params, sizeof(float));
|
1094
1190
|
|
1095
|
-
const int nth = 256;
|
1191
|
+
const int nth = MIN(256, ne00);
|
1096
1192
|
|
1097
1193
|
[encoder setComputePipelineState:ctx->pipeline_norm];
|
1098
1194
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
@@ -1110,6 +1206,8 @@ void ggml_metal_graph_compute(
|
|
1110
1206
|
{
|
1111
1207
|
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
1112
1208
|
|
1209
|
+
const int nth = MIN(1024, ne00);
|
1210
|
+
|
1113
1211
|
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
|
1114
1212
|
const int n_head = ((int32_t *) dst->op_params)[1];
|
1115
1213
|
float max_bias;
|
@@ -1143,12 +1241,14 @@ void ggml_metal_graph_compute(
|
|
1143
1241
|
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
1144
1242
|
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
1145
1243
|
|
1146
|
-
const int nth = 32;
|
1147
|
-
|
1148
1244
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1149
1245
|
} break;
|
1150
1246
|
case GGML_OP_ROPE:
|
1151
1247
|
{
|
1248
|
+
GGML_ASSERT(ne10 == ne02);
|
1249
|
+
|
1250
|
+
const int nth = MIN(1024, ne00);
|
1251
|
+
|
1152
1252
|
const int n_past = ((int32_t *) dst->op_params)[0];
|
1153
1253
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
1154
1254
|
const int mode = ((int32_t *) dst->op_params)[2];
|
@@ -1158,38 +1258,44 @@ void ggml_metal_graph_compute(
|
|
1158
1258
|
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
1159
1259
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
1160
1260
|
|
1161
|
-
|
1261
|
+
switch (src0->type) {
|
1262
|
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
|
1263
|
+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
|
1264
|
+
default: GGML_ASSERT(false);
|
1265
|
+
};
|
1266
|
+
|
1162
1267
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1163
|
-
[encoder setBuffer:
|
1164
|
-
[encoder
|
1165
|
-
[encoder setBytes:&
|
1166
|
-
[encoder setBytes:&
|
1167
|
-
[encoder setBytes:&
|
1168
|
-
[encoder setBytes:&
|
1169
|
-
[encoder setBytes:&
|
1170
|
-
[encoder setBytes:&
|
1171
|
-
[encoder setBytes:&
|
1172
|
-
[encoder setBytes:&
|
1173
|
-
[encoder setBytes:&
|
1174
|
-
[encoder setBytes:&
|
1175
|
-
[encoder setBytes:&
|
1176
|
-
[encoder setBytes:&
|
1177
|
-
[encoder setBytes:&
|
1178
|
-
[encoder setBytes:&
|
1179
|
-
[encoder setBytes:&
|
1180
|
-
[encoder setBytes:&
|
1181
|
-
[encoder setBytes:&
|
1182
|
-
[encoder setBytes:&
|
1183
|
-
[encoder setBytes:&
|
1184
|
-
[encoder setBytes:&
|
1268
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1269
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1270
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
1271
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
|
1272
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
|
1273
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
|
1274
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
|
1275
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
1276
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
1277
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
1278
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
|
1279
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
|
1280
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
|
1281
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
|
1282
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
|
1283
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
|
1284
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
|
1285
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
|
1286
|
+
[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
|
1287
|
+
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
|
1288
|
+
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
|
1289
|
+
[encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
|
1290
|
+
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];
|
1185
1291
|
|
1186
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(
|
1292
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1187
1293
|
} break;
|
1188
1294
|
case GGML_OP_DUP:
|
1189
1295
|
case GGML_OP_CPY:
|
1190
1296
|
case GGML_OP_CONT:
|
1191
1297
|
{
|
1192
|
-
const int nth =
|
1298
|
+
const int nth = MIN(1024, ne00);
|
1193
1299
|
|
1194
1300
|
switch (src0t) {
|
1195
1301
|
case GGML_TYPE_F32:
|
@@ -1234,7 +1340,7 @@ void ggml_metal_graph_compute(
|
|
1234
1340
|
} break;
|
1235
1341
|
default:
|
1236
1342
|
{
|
1237
|
-
|
1343
|
+
GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
1238
1344
|
GGML_ASSERT(false);
|
1239
1345
|
}
|
1240
1346
|
}
|
@@ -1259,7 +1365,7 @@ void ggml_metal_graph_compute(
|
|
1259
1365
|
|
1260
1366
|
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
|
1261
1367
|
if (status != MTLCommandBufferStatusCompleted) {
|
1262
|
-
|
1368
|
+
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
1263
1369
|
GGML_ASSERT(false);
|
1264
1370
|
}
|
1265
1371
|
}
|