gpt_neox_client 0.1.0

@@ -0,0 +1,1195 @@
1
+ #import "ggml-metal.h"
2
+
3
+ #import "ggml.h"
4
+
5
+ #import <Foundation/Foundation.h>
6
+
7
+ #import <Metal/Metal.h>
8
+
9
+ #undef MIN
10
+ #undef MAX
11
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
12
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
13
+
14
+ #ifdef GGML_METAL_NDEBUG
15
+ #define metal_printf(...)
16
+ #else
17
+ #define metal_printf(...) fprintf(stderr, __VA_ARGS__)
18
+ #endif
19
+
20
+ #define UNUSED(x) (void)(x)
21
+
22
+ #define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
23
+
24
+ struct ggml_metal_buffer {
25
+ const char * name;
26
+
27
+ void * data;
28
+ size_t size;
29
+
30
+ id<MTLBuffer> metal;
31
+ };
32
+
33
+ struct ggml_metal_context {
34
+ int n_cb;
35
+
36
+ id<MTLDevice> device;
37
+ id<MTLCommandQueue> queue;
38
+ id<MTLLibrary> library;
39
+
40
+ id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
41
+ id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
42
+
43
+ dispatch_queue_t d_queue;
44
+
45
+ int n_buffers;
46
+ struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
47
+
48
+ int concur_list[GGML_MAX_CONCUR];
49
+ int concur_list_len;
50
+
51
+ // custom kernels
52
+ #define GGML_METAL_DECL_KERNEL(name) \
53
+ id<MTLFunction> function_##name; \
54
+ id<MTLComputePipelineState> pipeline_##name
55
+
56
+ GGML_METAL_DECL_KERNEL(add);
57
+ GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
58
+ GGML_METAL_DECL_KERNEL(mul);
59
+ GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
60
+ GGML_METAL_DECL_KERNEL(scale);
61
+ GGML_METAL_DECL_KERNEL(silu);
62
+ GGML_METAL_DECL_KERNEL(relu);
63
+ GGML_METAL_DECL_KERNEL(gelu);
64
+ GGML_METAL_DECL_KERNEL(soft_max);
65
+ GGML_METAL_DECL_KERNEL(diag_mask_inf);
66
+ GGML_METAL_DECL_KERNEL(get_rows_f16);
67
+ GGML_METAL_DECL_KERNEL(get_rows_q4_0);
68
+ GGML_METAL_DECL_KERNEL(get_rows_q4_1);
69
+ GGML_METAL_DECL_KERNEL(get_rows_q8_0);
70
+ GGML_METAL_DECL_KERNEL(get_rows_q2_K);
71
+ GGML_METAL_DECL_KERNEL(get_rows_q3_K);
72
+ GGML_METAL_DECL_KERNEL(get_rows_q4_K);
73
+ GGML_METAL_DECL_KERNEL(get_rows_q5_K);
74
+ GGML_METAL_DECL_KERNEL(get_rows_q6_K);
75
+ GGML_METAL_DECL_KERNEL(rms_norm);
76
+ GGML_METAL_DECL_KERNEL(norm);
77
+ GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
78
+ GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
79
+ GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
80
+ GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
81
+ GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
82
+ GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
83
+ GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
84
+ GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
85
+ GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
86
+ GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
87
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
88
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
89
+ GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
90
+ GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
91
+ GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
92
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
93
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
94
+ GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
95
+ GGML_METAL_DECL_KERNEL(rope);
96
+ GGML_METAL_DECL_KERNEL(alibi_f32);
97
+ GGML_METAL_DECL_KERNEL(cpy_f32_f16);
98
+ GGML_METAL_DECL_KERNEL(cpy_f32_f32);
99
+ GGML_METAL_DECL_KERNEL(cpy_f16_f16);
100
+
101
+ #undef GGML_METAL_DECL_KERNEL
102
+ };
103
+
104
+ // MSL code
105
+ // TODO: move the contents here when ready
106
+ // for now it is easier to work in a separate file
107
+ static NSString * const msl_library_source = @"see metal.metal";
108
+
109
+ // Here to assist with NSBundle Path Hack
110
+ @interface GGMLMetalClass : NSObject
111
+ @end
112
+ @implementation GGMLMetalClass
113
+ @end
114
+
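+ // initialize the Metal context: allocate the context struct, create the default device, command queue and
+ // concurrent dispatch queue, compile the library from ggml-metal.metal and build all compute pipelines;
+ // n_cb is the number of command buffers used when encoding a graph (clamped to GGML_METAL_MAX_BUFFERS)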
115
+ struct ggml_metal_context * ggml_metal_init(int n_cb) {
116
+ fprintf(stderr, "%s: allocating\n", __func__);
117
+
118
+ struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
119
+
120
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
121
+ ctx->device = MTLCreateSystemDefaultDevice();
122
+ ctx->queue = [ctx->device newCommandQueue];
123
+ ctx->n_buffers = 0;
124
+ ctx->concur_list_len = 0;
125
+
126
+ ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
127
+
128
+ #if 0
129
+ // compile from source string and show compile log
130
+ {
131
+ NSError * error = nil;
132
+
133
+ ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
134
+ if (error) {
135
+ fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
136
+ return NULL;
137
+ }
138
+ }
139
+ #else
140
+ UNUSED(msl_library_source);
141
+
142
+ // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource
143
+ {
144
+ NSError * error = nil;
145
+
146
+ //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
147
+ NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
148
+ NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
149
+ fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
150
+
151
+ NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
152
+ if (error) {
153
+ fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
154
+ return NULL;
155
+ }
156
+
157
+ #ifdef GGML_QKK_64
158
+ MTLCompileOptions* options = [MTLCompileOptions new];
159
+ options.preprocessorMacros = @{ @"QK_K" : @(64) };
160
+ ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
161
+ #else
162
+ ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
163
+ #endif
164
+ if (error) {
165
+ fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
166
+ return NULL;
167
+ }
168
+ }
169
+ #endif
170
+
171
+ // load kernels
172
+ {
173
+ NSError * error = nil;
174
+ #define GGML_METAL_ADD_KERNEL(name) \
175
+ ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
176
+ ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
177
+ fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
178
+ (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
179
+ (int) ctx->pipeline_##name.threadExecutionWidth); \
180
+ if (error) { \
181
+ fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
182
+ return NULL; \
183
+ }
184
+
185
+ GGML_METAL_ADD_KERNEL(add);
186
+ GGML_METAL_ADD_KERNEL(add_row);
187
+ GGML_METAL_ADD_KERNEL(mul);
188
+ GGML_METAL_ADD_KERNEL(mul_row);
189
+ GGML_METAL_ADD_KERNEL(scale);
190
+ GGML_METAL_ADD_KERNEL(silu);
191
+ GGML_METAL_ADD_KERNEL(relu);
192
+ GGML_METAL_ADD_KERNEL(gelu);
193
+ GGML_METAL_ADD_KERNEL(soft_max);
194
+ GGML_METAL_ADD_KERNEL(diag_mask_inf);
195
+ GGML_METAL_ADD_KERNEL(get_rows_f16);
196
+ GGML_METAL_ADD_KERNEL(get_rows_q4_0);
197
+ GGML_METAL_ADD_KERNEL(get_rows_q4_1);
198
+ GGML_METAL_ADD_KERNEL(get_rows_q8_0);
199
+ GGML_METAL_ADD_KERNEL(get_rows_q2_K);
200
+ GGML_METAL_ADD_KERNEL(get_rows_q3_K);
201
+ GGML_METAL_ADD_KERNEL(get_rows_q4_K);
202
+ GGML_METAL_ADD_KERNEL(get_rows_q5_K);
203
+ GGML_METAL_ADD_KERNEL(get_rows_q6_K);
204
+ GGML_METAL_ADD_KERNEL(rms_norm);
205
+ GGML_METAL_ADD_KERNEL(norm);
206
+ GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
207
+ GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
208
+ GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
209
+ GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
210
+ GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
211
+ GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
212
+ GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
213
+ GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
214
+ GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
215
+ GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
216
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
217
+ GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
218
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
219
+ GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
220
+ GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
221
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
222
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
223
+ GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
224
+ GGML_METAL_ADD_KERNEL(rope);
225
+ GGML_METAL_ADD_KERNEL(alibi_f32);
226
+ GGML_METAL_ADD_KERNEL(cpy_f32_f16);
227
+ GGML_METAL_ADD_KERNEL(cpy_f32_f32);
228
+ GGML_METAL_ADD_KERNEL(cpy_f16_f16);
229
+
230
+ #undef GGML_METAL_ADD_KERNEL
231
+ }
232
+
233
+ fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
234
+ fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
235
+ if (ctx->device.maxTransferRate != 0) {
236
+ fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
237
+ } else {
238
+ fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
239
+ }
240
+
241
+ return ctx;
242
+ }
243
+
244
+ void ggml_metal_free(struct ggml_metal_context * ctx) {
245
+ fprintf(stderr, "%s: deallocating\n", __func__);
246
+ #define GGML_METAL_DEL_KERNEL(name) \
247
+ [ctx->function_##name release]; \
248
+ [ctx->pipeline_##name release];
249
+
250
+ GGML_METAL_DEL_KERNEL(add);
251
+ GGML_METAL_DEL_KERNEL(add_row);
252
+ GGML_METAL_DEL_KERNEL(mul);
253
+ GGML_METAL_DEL_KERNEL(mul_row);
254
+ GGML_METAL_DEL_KERNEL(scale);
255
+ GGML_METAL_DEL_KERNEL(silu);
256
+ GGML_METAL_DEL_KERNEL(relu);
257
+ GGML_METAL_DEL_KERNEL(gelu);
258
+ GGML_METAL_DEL_KERNEL(soft_max);
259
+ GGML_METAL_DEL_KERNEL(diag_mask_inf);
260
+ GGML_METAL_DEL_KERNEL(get_rows_f16);
261
+ GGML_METAL_DEL_KERNEL(get_rows_q4_0);
262
+ GGML_METAL_DEL_KERNEL(get_rows_q4_1);
263
+ GGML_METAL_DEL_KERNEL(get_rows_q8_0);
264
+ GGML_METAL_DEL_KERNEL(get_rows_q2_K);
265
+ GGML_METAL_DEL_KERNEL(get_rows_q3_K);
266
+ GGML_METAL_DEL_KERNEL(get_rows_q4_K);
267
+ GGML_METAL_DEL_KERNEL(get_rows_q5_K);
268
+ GGML_METAL_DEL_KERNEL(get_rows_q6_K);
269
+ GGML_METAL_DEL_KERNEL(rms_norm);
270
+ GGML_METAL_DEL_KERNEL(norm);
271
+ GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
272
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
273
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
274
+ GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
275
+ GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
276
+ GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
277
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
278
+ GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
279
+ GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
280
+ GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
281
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
282
+ GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
283
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
284
+ GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
285
+ GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
286
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
287
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
288
+ GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
289
+ GGML_METAL_DEL_KERNEL(rope);
290
+ GGML_METAL_DEL_KERNEL(alibi_f32);
291
+ GGML_METAL_DEL_KERNEL(cpy_f32_f16);
292
+ GGML_METAL_DEL_KERNEL(cpy_f32_f32);
293
+ GGML_METAL_DEL_KERNEL(cpy_f16_f16);
294
+
295
+ #undef GGML_METAL_DEL_KERNEL
296
+
297
+ for (int i = 0; i < ctx->n_buffers; ++i) {
298
+ [ctx->buffers[i].metal release];
299
+ }
300
+
301
+ [ctx->library release];
302
+ [ctx->queue release];
303
+ [ctx->device release];
304
+
305
+ dispatch_release(ctx->d_queue);
306
+
307
+ free(ctx);
308
+ }
309
+
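+ // allocate page-aligned host memory (via posix_memalign), suitable for wrapping in a no-copy Metal buffer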
310
+ void * ggml_metal_host_malloc(size_t n) {
311
+ void * data = NULL;
312
+ const int result = posix_memalign((void **) &data, getpagesize(), n);
313
+ if (result != 0) {
314
+ fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
315
+ return NULL;
316
+ }
317
+
318
+ return data;
319
+ }
320
+
321
+ void ggml_metal_host_free(void * data) {
322
+ free(data);
323
+ }
324
+
325
+ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
326
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
327
+ }
328
+
329
+ int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
330
+ return ctx->concur_list_len;
331
+ }
332
+
333
+ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
334
+ return ctx->concur_list;
335
+ }
336
+
337
+ // finds the Metal buffer that contains the tensor data on the GPU device
338
+ // the assumption is that there is a 1-to-1 mapping between the host and device memory buffers, so we can find the
339
+ // Metal buffer based on the host memory pointer
340
+ //
341
+ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
342
+ //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
343
+
344
+ const int64_t tsize = ggml_nbytes(t);
345
+
346
+ // find the view that contains the tensor fully
347
+ for (int i = 0; i < ctx->n_buffers; ++i) {
348
+ const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
349
+
350
+ if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
351
+ *offs = (size_t) ioffs;
352
+
353
+ //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
354
+
355
+ return ctx->buffers[i].metal;
356
+ }
357
+ }
358
+
359
+ fprintf(stderr, "%s: error: buffer is nil\n", __func__);
360
+
361
+ return nil;
362
+ }
363
+
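+ // map a host memory region into Metal as a shared, no-copy MTLBuffer;
+ // if the region is larger than the device's maxBufferLength it is split into overlapping views so that
+ // any single tensor (up to max_size bytes) fits entirely within one view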
364
+ bool ggml_metal_add_buffer(
365
+ struct ggml_metal_context * ctx,
366
+ const char * name,
367
+ void * data,
368
+ size_t size,
369
+ size_t max_size) {
370
+ if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
371
+ fprintf(stderr, "%s: too many buffers\n", __func__);
372
+ return false;
373
+ }
374
+
375
+ if (data) {
376
+ // verify that the buffer does not overlap with any of the existing buffers
377
+ for (int i = 0; i < ctx->n_buffers; ++i) {
378
+ const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
379
+
380
+ if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
381
+ fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
382
+ return false;
383
+ }
384
+ }
385
+
386
+ const size_t size_page = getpagesize();
387
+
388
+ size_t size_aligned = size;
389
+ if ((size_aligned % size_page) != 0) {
390
+ size_aligned += (size_page - (size_aligned % size_page));
391
+ }
392
+
393
+ // the buffer fits into the max buffer size allowed by the device
394
+ if (size_aligned <= ctx->device.maxBufferLength) {
395
+ ctx->buffers[ctx->n_buffers].name = name;
396
+ ctx->buffers[ctx->n_buffers].data = data;
397
+ ctx->buffers[ctx->n_buffers].size = size;
398
+
399
+ ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
400
+
401
+ if (ctx->buffers[ctx->n_buffers].metal == nil) {
402
+ fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
403
+ return false;
404
+ }
405
+
406
+ fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
407
+
408
+ ++ctx->n_buffers;
409
+ } else {
410
+ // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
411
+ // one of the views
412
+ const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
413
+ const size_t size_step = ctx->device.maxBufferLength - size_ovlp;
414
+ const size_t size_view = ctx->device.maxBufferLength;
415
+
416
+ for (size_t i = 0; i < size; i += size_step) {
417
+ const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
418
+
419
+ ctx->buffers[ctx->n_buffers].name = name;
420
+ ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
421
+ ctx->buffers[ctx->n_buffers].size = size_step_aligned;
422
+
423
+ ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
424
+
425
+ if (ctx->buffers[ctx->n_buffers].metal == nil) {
426
+ fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
427
+ return false;
428
+ }
429
+
430
+ fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
431
+ if (i + size_step < size) {
432
+ fprintf(stderr, "\n");
433
+ }
434
+
435
+ ++ctx->n_buffers;
436
+ }
437
+ }
438
+
439
+ fprintf(stderr, ", (%8.2f / %8.2f)",
440
+ ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
441
+ ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
442
+
443
+ if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
444
+ fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
445
+ } else {
446
+ fprintf(stderr, "\n");
447
+ }
448
+ }
449
+
450
+ return true;
451
+ }
452
+
453
+ void ggml_metal_set_tensor(
454
+ struct ggml_metal_context * ctx,
455
+ struct ggml_tensor * t) {
456
+ metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
457
+
458
+ size_t offs;
459
+ id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
460
+
461
+ memcpy((void *) ((uint8_t *) id_dst.contents + offs), t->data, ggml_nbytes(t));
462
+ }
463
+
464
+ void ggml_metal_get_tensor(
465
+ struct ggml_metal_context * ctx,
466
+ struct ggml_tensor * t) {
467
+ metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
468
+
469
+ size_t offs;
470
+ id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
471
+
472
+ memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
473
+ }
474
+
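+ // reorder the graph nodes into levels of mutually independent nodes that can be dispatched concurrently;
+ // the result is stored in ctx->concur_list, with a -1 entry acting as a barrier between levels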
475
+ void ggml_metal_graph_find_concurrency(
476
+ struct ggml_metal_context * ctx,
477
+ struct ggml_cgraph * gf, bool check_mem) {
478
+ int search_depth = gf->n_nodes; // we only look for concurrency within this range to avoid wasting too much time
479
+ int nodes_unused[GGML_MAX_CONCUR];
480
+
481
+ for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
482
+ for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
483
+ ctx->concur_list_len = 0;
484
+
485
+ int n_left = gf->n_nodes;
486
+ int n_start = 0; // all nodes before n_start in the nodes_unused array have been sorted and stored back into ctx->concur_list
487
+ int level_pos = 0; // in ctx->concur_list, the last layer (level) ends at level_pos
488
+
489
+ while (n_left > 0) {
490
+ // number of nodes at a layer (that can be issued concurrently)
491
+ int concurrency = 0;
492
+ for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
493
+ if (nodes_unused[i]) {
494
+ // if the requirements for gf->nodes[i] are satisfied
495
+ int exe_flag = 1;
496
+
497
+ // scan all srcs
498
+ for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
499
+ struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
500
+ if (src_cur) {
501
+ // if it is a leaf node, the requirement is satisfied.
502
+ // TODO: ggml_is_leaf()
503
+ if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
504
+ continue;
505
+ }
506
+
507
+ // otherwise this src should be the output from previous nodes.
508
+ int is_found = 0;
509
+
510
+ // scan 2*search_depth entries back because barriers were inserted.
511
+ //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
512
+ for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
513
+ if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
514
+ is_found = 1;
515
+ break;
516
+ }
517
+ }
518
+ if (is_found == 0) {
519
+ exe_flag = 0;
520
+ break;
521
+ }
522
+ }
523
+ }
524
+ if (exe_flag && check_mem) {
525
+ // check if nodes[i]'s data will be overwritten by a node before nodes[i].
526
+ // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
527
+ int64_t data_start = (int64_t) gf->nodes[i]->data;
528
+ int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
529
+ for (int j = n_start; j < i; j++) {
530
+ if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
531
+ && gf->nodes[j]->op != GGML_OP_VIEW \
532
+ && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
533
+ && gf->nodes[j]->op != GGML_OP_PERMUTE) {
534
+ if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
535
+ ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
536
+ continue;
537
+ }
538
+
539
+ exe_flag = 0;
540
+ }
541
+ }
542
+ }
543
+ if (exe_flag) {
544
+ ctx->concur_list[level_pos + concurrency] = i;
545
+ nodes_unused[i] = 0;
546
+ concurrency++;
547
+ ctx->concur_list_len++;
548
+ }
549
+ }
550
+ }
551
+ n_left -= concurrency;
552
+ // add a barrier to separate this layer from the next
553
+ ctx->concur_list[level_pos + concurrency] = -1;
554
+ ctx->concur_list_len++;
555
+ // skip over all nodes that have already been sorted
556
+ while (!nodes_unused[n_start]) {
557
+ n_start++;
558
+ }
559
+ level_pos += concurrency + 1;
560
+ }
561
+
562
+ if (ctx->concur_list_len > GGML_MAX_CONCUR) {
563
+ fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
564
+ }
565
+ }
566
+
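+ // encode and execute the graph on the GPU: the nodes are split across n_cb command buffers which are
+ // encoded in parallel on d_queue (dispatched concurrently when a concur_list is available), then
+ // committed and waited on for completion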
567
+ void ggml_metal_graph_compute(
568
+ struct ggml_metal_context * ctx,
569
+ struct ggml_cgraph * gf) {
570
+ metal_printf("%s: evaluating graph\n", __func__);
571
+
572
+ @autoreleasepool {
573
+
574
+ // if there is ctx->concur_list, dispatch concurrently
575
+ // else fallback to serial dispatch
576
+ MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
577
+
578
+ const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
579
+
580
+ const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
581
+ edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
582
+
583
+ // create multiple command buffers and enqueue them
584
+ // then, we encode the graph into the command buffers in parallel
585
+
586
+ const int n_cb = ctx->n_cb;
587
+
588
+ for (int i = 0; i < n_cb; ++i) {
589
+ ctx->command_buffers[i] = [ctx->queue commandBuffer];
590
+
591
+ // enqueue the command buffers in order to specify their execution order
592
+ [ctx->command_buffers[i] enqueue];
593
+
594
+ ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
595
+ }
596
+
597
+ for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
598
+ const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
599
+
600
+ dispatch_async(ctx->d_queue, ^{
601
+ size_t offs_src0 = 0;
602
+ size_t offs_src1 = 0;
603
+ size_t offs_dst = 0;
604
+
605
+ id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
606
+ id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
607
+
608
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
609
+ const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
610
+
611
+ for (int ind = node_start; ind < node_end; ++ind) {
612
+ const int i = has_concur ? ctx->concur_list[ind] : ind;
613
+
614
+ if (i == -1) {
615
+ [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
616
+ continue;
617
+ }
618
+
619
+ metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
620
+
621
+ struct ggml_tensor * src0 = gf->nodes[i]->src[0];
622
+ struct ggml_tensor * src1 = gf->nodes[i]->src[1];
623
+ struct ggml_tensor * dst = gf->nodes[i];
624
+
625
+ const int64_t ne00 = src0 ? src0->ne[0] : 0;
626
+ const int64_t ne01 = src0 ? src0->ne[1] : 0;
627
+ const int64_t ne02 = src0 ? src0->ne[2] : 0;
628
+ const int64_t ne03 = src0 ? src0->ne[3] : 0;
629
+
630
+ const uint64_t nb00 = src0 ? src0->nb[0] : 0;
631
+ const uint64_t nb01 = src0 ? src0->nb[1] : 0;
632
+ const uint64_t nb02 = src0 ? src0->nb[2] : 0;
633
+ const uint64_t nb03 = src0 ? src0->nb[3] : 0;
634
+
635
+ const int64_t ne10 = src1 ? src1->ne[0] : 0;
636
+ const int64_t ne11 = src1 ? src1->ne[1] : 0;
637
+ const int64_t ne12 = src1 ? src1->ne[2] : 0;
638
+ const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
639
+
640
+ const uint64_t nb10 = src1 ? src1->nb[0] : 0;
641
+ const uint64_t nb11 = src1 ? src1->nb[1] : 0;
642
+ const uint64_t nb12 = src1 ? src1->nb[2] : 0;
643
+ const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
644
+
645
+ const int64_t ne0 = dst ? dst->ne[0] : 0;
646
+ const int64_t ne1 = dst ? dst->ne[1] : 0;
647
+ const int64_t ne2 = dst ? dst->ne[2] : 0;
648
+ const int64_t ne3 = dst ? dst->ne[3] : 0;
649
+
650
+ const uint64_t nb0 = dst ? dst->nb[0] : 0;
651
+ const uint64_t nb1 = dst ? dst->nb[1] : 0;
652
+ const uint64_t nb2 = dst ? dst->nb[2] : 0;
653
+ const uint64_t nb3 = dst ? dst->nb[3] : 0;
654
+
655
+ const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
656
+ const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
657
+ const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
658
+
659
+ id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
660
+ id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
661
+ id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
662
+
663
+ //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
664
+ //if (src0) {
665
+ // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
666
+ // ggml_is_contiguous(src0), src0->name);
667
+ //}
668
+ //if (src1) {
669
+ // metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
670
+ // ggml_is_contiguous(src1), src1->name);
671
+ //}
672
+ //if (dst) {
673
+ // metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
674
+ // dst->name);
675
+ //}
676
+
677
+ switch (dst->op) {
678
+ case GGML_OP_NONE:
679
+ case GGML_OP_RESHAPE:
680
+ case GGML_OP_VIEW:
681
+ case GGML_OP_TRANSPOSE:
682
+ case GGML_OP_PERMUTE:
683
+ {
684
+ // noop
685
+ } break;
686
+ case GGML_OP_ADD:
687
+ {
688
+ if (ggml_nelements(src1) == ne10) {
689
+ // src1 is a row
690
+ [encoder setComputePipelineState:ctx->pipeline_add_row];
691
+ } else {
692
+ [encoder setComputePipelineState:ctx->pipeline_add];
693
+ }
694
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
695
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
696
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
697
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
698
+
699
+ const int64_t n = ggml_nelements(dst);
700
+
701
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
702
+ } break;
703
+ case GGML_OP_MUL:
704
+ {
705
+ if (ggml_nelements(src1) == ne10) {
706
+ // src1 is a row
707
+ [encoder setComputePipelineState:ctx->pipeline_mul_row];
708
+ } else {
709
+ [encoder setComputePipelineState:ctx->pipeline_mul];
710
+ }
711
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
712
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
713
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
714
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
715
+
716
+ const int64_t n = ggml_nelements(dst);
717
+
718
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
719
+ } break;
720
+ case GGML_OP_SCALE:
721
+ {
722
+ const float scale = *(const float *) src1->data;
723
+
724
+ [encoder setComputePipelineState:ctx->pipeline_scale];
725
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
726
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
727
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
728
+
729
+ const int64_t n = ggml_nelements(dst);
730
+
731
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
732
+ } break;
733
+ case GGML_OP_UNARY:
734
+ switch (ggml_get_unary_op(gf->nodes[i])) {
735
+ case GGML_UNARY_OP_SILU:
736
+ {
737
+ [encoder setComputePipelineState:ctx->pipeline_silu];
738
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
739
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
740
+
741
+ const int64_t n = ggml_nelements(dst);
742
+
743
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
744
+ } break;
745
+ case GGML_UNARY_OP_RELU:
746
+ {
747
+ [encoder setComputePipelineState:ctx->pipeline_relu];
748
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
749
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
750
+
751
+ const int64_t n = ggml_nelements(dst);
752
+
753
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
754
+ } break;
755
+ case GGML_UNARY_OP_GELU:
756
+ {
757
+ [encoder setComputePipelineState:ctx->pipeline_gelu];
758
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
759
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
760
+
761
+ const int64_t n = ggml_nelements(dst);
762
+
763
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
764
+ } break;
765
+ default:
766
+ {
767
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
768
+ GGML_ASSERT(false);
769
+ }
770
+ } break;
771
+ case GGML_OP_SOFT_MAX:
772
+ {
773
+ const int nth = 32;
774
+
775
+ [encoder setComputePipelineState:ctx->pipeline_soft_max];
776
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
777
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
778
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
779
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
780
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
781
+ [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
782
+
783
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
784
+ } break;
785
+ case GGML_OP_DIAG_MASK_INF:
786
+ {
787
+ const int n_past = ((int32_t *)(dst->op_params))[0];
788
+
789
+ [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
790
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
791
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
792
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
793
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
794
+ [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
795
+
796
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
797
+ } break;
798
+ case GGML_OP_MUL_MAT:
799
+ {
800
+ // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
801
+
802
+ GGML_ASSERT(ne00 == ne10);
803
+ // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
804
+ uint gqa = ne12/ne02;
805
+ GGML_ASSERT(ne03 == ne13);
806
+
807
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
808
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
809
+ if (ggml_is_contiguous(src0) &&
810
+ ggml_is_contiguous(src1) &&
811
+ src1t == GGML_TYPE_F32 &&
812
+ [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
813
+ ne00%32 == 0 &&
814
+ ne11 > 1) {
815
+ switch (src0->type) {
816
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
817
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
818
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
819
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
820
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
821
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
822
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
823
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
824
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
825
+ default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
826
+ }
827
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
828
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
829
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
830
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
831
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
832
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
833
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
834
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
835
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
836
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
837
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
838
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
839
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
840
+ } else {
841
+ int nth0 = 32;
842
+ int nth1 = 1;
843
+
844
+ // use custom matrix x vector kernel
845
+ switch (src0t) {
846
+ case GGML_TYPE_F16:
847
+ {
848
+ nth0 = 64;
849
+ nth1 = 1;
850
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
851
+ } break;
852
+ case GGML_TYPE_Q4_0:
853
+ {
854
+ GGML_ASSERT(ne02 == 1);
855
+ GGML_ASSERT(ne12 == 1);
856
+
857
+ nth0 = 8;
858
+ nth1 = 8;
859
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
860
+ } break;
861
+ case GGML_TYPE_Q4_1:
862
+ {
863
+ GGML_ASSERT(ne02 == 1);
864
+ GGML_ASSERT(ne12 == 1);
865
+
866
+ nth0 = 8;
867
+ nth1 = 8;
868
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
869
+ } break;
870
+ case GGML_TYPE_Q8_0:
871
+ {
872
+ GGML_ASSERT(ne02 == 1);
873
+ GGML_ASSERT(ne12 == 1);
874
+
875
+ nth0 = 8;
876
+ nth1 = 8;
877
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
878
+ } break;
879
+ case GGML_TYPE_Q2_K:
880
+ {
881
+ GGML_ASSERT(ne02 == 1);
882
+ GGML_ASSERT(ne12 == 1);
883
+
884
+ nth0 = 2;
885
+ nth1 = 32;
886
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
887
+ } break;
888
+ case GGML_TYPE_Q3_K:
889
+ {
890
+ GGML_ASSERT(ne02 == 1);
891
+ GGML_ASSERT(ne12 == 1);
892
+
893
+ nth0 = 2;
894
+ nth1 = 32;
895
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
896
+ } break;
897
+ case GGML_TYPE_Q4_K:
898
+ {
899
+ GGML_ASSERT(ne02 == 1);
900
+ GGML_ASSERT(ne12 == 1);
901
+
902
+ nth0 = 2;
903
+ nth1 = 32;
904
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
905
+ } break;
906
+ case GGML_TYPE_Q5_K:
907
+ {
908
+ GGML_ASSERT(ne02 == 1);
909
+ GGML_ASSERT(ne12 == 1);
910
+
911
+ nth0 = 2;
912
+ nth1 = 32;
913
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
914
+ } break;
915
+ case GGML_TYPE_Q6_K:
916
+ {
917
+ GGML_ASSERT(ne02 == 1);
918
+ GGML_ASSERT(ne12 == 1);
919
+
920
+ nth0 = 2;
921
+ nth1 = 32;
922
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
923
+ } break;
924
+ default:
925
+ {
926
+ fprintf(stderr, "Asserting on type %d\n",(int)src0t);
927
+ GGML_ASSERT(false && "not implemented");
928
+ }
929
+ };
930
+
931
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
932
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
933
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
934
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
935
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
936
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
937
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
938
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
939
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
940
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
941
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
942
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
943
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
944
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
945
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
946
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
947
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
948
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
949
+
950
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
951
+ src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
952
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
953
+ }
954
+ else if (src0t == GGML_TYPE_Q3_K) {
955
+ #ifdef GGML_QKK_64
956
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
957
+ #else
958
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
959
+ #endif
960
+ }
961
+ else if (src0t == GGML_TYPE_Q5_K) {
962
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
963
+ }
964
+ else if (src0t == GGML_TYPE_Q6_K) {
965
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
966
+ } else {
967
+ [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
968
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
969
+ }
970
+ }
971
+ } break;
972
+ case GGML_OP_GET_ROWS:
973
+ {
974
+ switch (src0->type) {
975
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
976
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
977
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
978
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
979
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
980
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
981
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
982
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
983
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
984
+ default: GGML_ASSERT(false && "not implemented");
985
+ }
986
+
987
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
988
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
989
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
990
+ [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
991
+ [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
992
+ [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
993
+
994
+ const int64_t n = ggml_nelements(src1);
995
+
996
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
997
+ } break;
998
+ case GGML_OP_RMS_NORM:
999
+ {
1000
+ float eps;
1001
+ memcpy(&eps, dst->op_params, sizeof(float));
1002
+
1003
+ const int nth = 512;
1004
+
1005
+ [encoder setComputePipelineState:ctx->pipeline_rms_norm];
1006
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1007
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1008
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1009
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1010
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1011
+ [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
1012
+
1013
+ const int64_t nrows = ggml_nrows(src0);
1014
+
1015
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1016
+ } break;
1017
+ case GGML_OP_NORM:
1018
+ {
1019
+ float eps;
1020
+ memcpy(&eps, dst->op_params, sizeof(float));
1021
+
1022
+ const int nth = 256;
1023
+
1024
+ [encoder setComputePipelineState:ctx->pipeline_norm];
1025
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1026
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1027
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1028
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1029
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1030
+ [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
1031
+
1032
+ const int64_t nrows = ggml_nrows(src0);
1033
+
1034
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1035
+ } break;
1036
+ case GGML_OP_ALIBI:
1037
+ {
1038
+ GGML_ASSERT((src0t == GGML_TYPE_F32));
1039
+
1040
+ const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
1041
+ const int n_head = ((int32_t *) dst->op_params)[1];
1042
+ float max_bias;
1043
+ memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
1044
+
1045
+ if (__builtin_popcount(n_head) != 1) {
1046
+ GGML_ASSERT(false && "only power-of-two n_head implemented");
1047
+ }
1048
+
1049
+ const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
1050
+ const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
1051
+
1052
+ [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
1053
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1054
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1055
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1056
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1057
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1058
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1059
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1060
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1061
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1062
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1063
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1064
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1065
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1066
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1067
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1068
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1069
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1070
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1071
+ [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
1072
+
1073
+ const int nth = 32;
1074
+
1075
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1076
+ } break;
1077
+ case GGML_OP_ROPE:
1078
+ {
1079
+ const int n_past = ((int32_t *) dst->op_params)[0];
1080
+ const int n_dims = ((int32_t *) dst->op_params)[1];
1081
+ const int mode = ((int32_t *) dst->op_params)[2];
1082
+
1083
+ float freq_base;
1084
+ float freq_scale;
1085
+ memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
1086
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
1087
+
1088
+ [encoder setComputePipelineState:ctx->pipeline_rope];
1089
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1090
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1091
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1092
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1093
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1094
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1095
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1096
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1097
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1098
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1099
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1100
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1101
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1102
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1103
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1104
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1105
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1106
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1107
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
1108
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
1109
+ [encoder setBytes:&mode length:sizeof( int) atIndex:20];
1110
+ [encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
1111
+ [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
1112
+
1113
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1114
+ } break;
1115
+ case GGML_OP_DUP:
1116
+ case GGML_OP_CPY:
1117
+ case GGML_OP_CONT:
1118
+ {
1119
+ const int nth = 32;
1120
+
1121
+ switch (src0t) {
1122
+ case GGML_TYPE_F32:
1123
+ {
1124
+ switch (dstt) {
1125
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
1126
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
1127
+ default: GGML_ASSERT(false && "not implemented");
1128
+ };
1129
+ } break;
1130
+ case GGML_TYPE_F16:
1131
+ {
1132
+ switch (dstt) {
1133
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
1134
+ case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
1135
+ default: GGML_ASSERT(false && "not implemented");
1136
+ };
1137
+ } break;
1138
+ default: GGML_ASSERT(false && "not implemented");
1139
+ }
1140
+
1141
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1142
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1143
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1144
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1145
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1146
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1147
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1148
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1149
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1150
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1151
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1152
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1153
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1154
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1155
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1156
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1157
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1158
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1159
+
1160
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1161
+ } break;
1162
+ default:
1163
+ {
1164
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1165
+ GGML_ASSERT(false);
1166
+ }
1167
+ }
1168
+ }
1169
+
1170
+ if (encoder != nil) {
1171
+ [encoder endEncoding];
1172
+ encoder = nil;
1173
+ }
1174
+
1175
+ [command_buffer commit];
1176
+ });
1177
+ }
1178
+
1179
+ // wait for all threads to finish
1180
+ dispatch_barrier_sync(ctx->d_queue, ^{});
1181
+
1182
+ // check status of command buffers
1183
+ // needed to detect if the device ran out-of-memory for example (#1881)
1184
+ for (int i = 0; i < n_cb; i++) {
1185
+ [ctx->command_buffers[i] waitUntilCompleted];
1186
+
1187
+ MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
1188
+ if (status != MTLCommandBufferStatusCompleted) {
1189
+ fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
1190
+ GGML_ASSERT(false);
1191
+ }
1192
+ }
1193
+
1194
+ }
1195
+ }