llama_cpp 0.1.4 → 0.2.1

@@ -0,0 +1,834 @@
+ #import "ggml-metal.h"
+
+ #import "ggml.h"
+
+ #import <Foundation/Foundation.h>
+
+ #import <Metal/Metal.h>
+ #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
+
+ #ifdef GGML_METAL_NDEBUG
+ #define metal_printf(...)
+ #else
+ #define metal_printf(...) fprintf(stderr, __VA_ARGS__)
+ #endif
+
+ #define UNUSED(x) (void)(x)
+
+ struct ggml_metal_buffer {
+ const char * name;
+
+ void * data;
+ size_t size;
+
+ id<MTLBuffer> metal;
+ };
+
+ struct ggml_metal_context {
+ float * logits;
+
+ id<MTLDevice> device;
+ id<MTLCommandQueue> queue;
+ id<MTLLibrary> library;
+
+ int n_buffers;
+ struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
+
+ // custom kernels
+ #define GGML_METAL_DECL_KERNEL(name) \
+ id<MTLFunction> function_##name; \
+ id<MTLComputePipelineState> pipeline_##name
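+ // e.g. GGML_METAL_DECL_KERNEL(add) declares the members function_add and pipeline_add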
+
+ GGML_METAL_DECL_KERNEL(add);
+ GGML_METAL_DECL_KERNEL(mul);
+ GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
+ GGML_METAL_DECL_KERNEL(scale);
+ GGML_METAL_DECL_KERNEL(silu);
+ GGML_METAL_DECL_KERNEL(relu);
+ GGML_METAL_DECL_KERNEL(gelu);
+ GGML_METAL_DECL_KERNEL(soft_max);
+ GGML_METAL_DECL_KERNEL(diag_mask_inf);
+ GGML_METAL_DECL_KERNEL(get_rows_f16);
+ GGML_METAL_DECL_KERNEL(get_rows_q4_0);
+ GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+ GGML_METAL_DECL_KERNEL(get_rows_q2_k);
+ GGML_METAL_DECL_KERNEL(get_rows_q3_k);
+ GGML_METAL_DECL_KERNEL(get_rows_q4_k);
+ GGML_METAL_DECL_KERNEL(get_rows_q5_k);
+ GGML_METAL_DECL_KERNEL(get_rows_q6_k);
+ GGML_METAL_DECL_KERNEL(rms_norm);
+ GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
+ GGML_METAL_DECL_KERNEL(rope);
+ GGML_METAL_DECL_KERNEL(cpy_f32_f16);
+ GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+
+ #undef GGML_METAL_DECL_KERNEL
+ };
+
+ // MSL code
+ // TODO: move the contents here when ready
+ // for now it is easier to work in a separate file
+ static NSString * const msl_library_source = @"see metal.metal";
+
+ // Here to assist with NSBundle Path Hack
+ @interface GGMLMetalClass : NSObject
+ @end
+ @implementation GGMLMetalClass
+ @end
+
+ struct ggml_metal_context * ggml_metal_init(void) {
+ fprintf(stderr, "%s: allocating\n", __func__);
+
+ struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+
+ ctx->device = MTLCreateSystemDefaultDevice();
+ ctx->queue = [ctx->device newCommandQueue];
+ ctx->n_buffers = 0;
+
+ // determine if we can use MPS
+ if (MPSSupportsMTLDevice(ctx->device)) {
+ fprintf(stderr, "%s: using MPS\n", __func__);
+ } else {
+ fprintf(stderr, "%s: not using MPS\n", __func__);
+ GGML_ASSERT(false && "MPS not supported");
+ }
+
+ #if 0
+ // compile from source string and show compile log
+ {
+ NSError * error = nil;
+
+ ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
+ if (error) {
+ fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+ exit(1);
+ }
+ }
+ #else
+ UNUSED(msl_library_source);
+
+ // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource
+ {
+ NSError * error = nil;
+
+ //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
+ NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
+ NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+ fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
+
+ NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
+ if (error) {
+ fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+ exit(1);
+ }
+
+ ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
+ if (error) {
+ fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+ exit(1);
+ }
+ }
+ #endif
+
+ // load kernels
+ {
+ #define GGML_METAL_ADD_KERNEL(name) \
+ ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
+ ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:nil]; \
+ fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
+
+ GGML_METAL_ADD_KERNEL(add);
+ GGML_METAL_ADD_KERNEL(mul);
+ GGML_METAL_ADD_KERNEL(mul_row);
+ GGML_METAL_ADD_KERNEL(scale);
+ GGML_METAL_ADD_KERNEL(silu);
+ GGML_METAL_ADD_KERNEL(relu);
+ GGML_METAL_ADD_KERNEL(gelu);
+ GGML_METAL_ADD_KERNEL(soft_max);
+ GGML_METAL_ADD_KERNEL(diag_mask_inf);
+ GGML_METAL_ADD_KERNEL(get_rows_f16);
+ GGML_METAL_ADD_KERNEL(get_rows_q4_0);
+ GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+ GGML_METAL_ADD_KERNEL(get_rows_q2_k);
+ GGML_METAL_ADD_KERNEL(get_rows_q3_k);
+ GGML_METAL_ADD_KERNEL(get_rows_q4_k);
+ GGML_METAL_ADD_KERNEL(get_rows_q5_k);
+ GGML_METAL_ADD_KERNEL(get_rows_q6_k);
+ GGML_METAL_ADD_KERNEL(rms_norm);
+ GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
+ GGML_METAL_ADD_KERNEL(rope);
+ GGML_METAL_ADD_KERNEL(cpy_f32_f16);
+ GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+
+ #undef GGML_METAL_ADD_KERNEL
+ }
+
+ return ctx;
+ }
+
+ void ggml_metal_free(struct ggml_metal_context * ctx) {
+ fprintf(stderr, "%s: deallocating\n", __func__);
+
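+ // only the host-side context struct is freed here; the Metal objects it references are not explicitly released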
+ free(ctx);
+ }
+
+ // finds the Metal buffer that contains the tensor data on the GPU device
+ // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
+ // Metal buffer based on the host memory pointer
+ //
+ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
+ //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
+
+ for (int i = 0; i < ctx->n_buffers; ++i) {
+ const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
+
+ if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
+ *offs = (size_t) ioffs;
+
+ //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
+
+ return ctx->buffers[i].metal;
+ }
+ }
+
+ fprintf(stderr, "%s: error: buffer is nil\n", __func__);
+
+ return nil;
+ }
+
+ bool ggml_metal_add_buffer(
+ struct ggml_metal_context * ctx,
+ const char * name,
+ void * data,
+ size_t size) {
+ if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
+ fprintf(stderr, "%s: too many buffers\n", __func__);
+ return false;
+ }
+
+ if (data) {
+ // verify that the buffer does not overlap with any of the existing buffers
+ for (int i = 0; i < ctx->n_buffers; ++i) {
+ const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
+
+ if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
+ fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
+ return false;
+ }
+ }
+
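+ // round the size up to a whole number of pages, since the no-copy Metal buffer created below is expected to span page-aligned memory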
+ size_t page_size = getpagesize();
+ size_t aligned_size = size;
+ if ((aligned_size % page_size) != 0) {
+ aligned_size += (page_size - (aligned_size % page_size));
+ }
+
+ ctx->buffers[ctx->n_buffers].name = name;
+ ctx->buffers[ctx->n_buffers].data = data;
+ ctx->buffers[ctx->n_buffers].size = size;
+
+ if (ctx->device.maxBufferLength < aligned_size) {
+ fprintf(stderr, "%s: buffer '%s' size %zu is larger than buffer maximum of %zu\n", __func__, name, aligned_size, ctx->device.maxBufferLength);
+ return false;
+ }
+ ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:aligned_size options:MTLResourceStorageModeShared deallocator:nil];
+
+ if (ctx->buffers[ctx->n_buffers].metal == nil) {
+ fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
+ return false;
+ } else {
+ fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
+ }
+
+ ++ctx->n_buffers;
+ }
+
+ return true;
+ }
+
+ void ggml_metal_set_tensor(
+ struct ggml_metal_context * ctx,
+ struct ggml_tensor * t) {
+ metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
+
+ size_t offs;
+ id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
+
+ memcpy((void *) ((uint8_t *) id_dst.contents + offs), t->data, ggml_nbytes(t));
+ }
+
+ void ggml_metal_get_tensor(
+ struct ggml_metal_context * ctx,
+ struct ggml_tensor * t) {
+ metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
+
+ size_t offs;
+ id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
+
+ memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
+ }
+
+ void ggml_metal_graph_compute(
+ struct ggml_metal_context * ctx,
+ struct ggml_cgraph * gf) {
+ metal_printf("%s: evaluating graph\n", __func__);
+
+ // create multiple command buffers and enqueue them
+ // then, we encode the graph into the command buffers in parallel
+
+ const int n_cb = gf->n_threads;
+
+ NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
+
+ for (int i = 0; i < n_cb; ++i) {
+ command_buffers[i] = [ctx->queue commandBuffer];
+
+ // enqueue the command buffers in order to specify their execution order
+ [command_buffers[i] enqueue];
+ }
+
+ // TODO: is this the best way to start threads?
+ dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+
+ for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
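+ // each command buffer encodes roughly n_nodes/n_cb graph nodes (ceiling division; the last buffer covers whatever remains)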
+ const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+
+ dispatch_async(queue, ^{
+ size_t offs_src0 = 0;
+ size_t offs_src1 = 0;
+ size_t offs_dst = 0;
+
+ id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
+
+ id<MTLComputeCommandEncoder> encoder = nil;
+
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+ const int node_end = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+ for (int i = node_start; i < node_end; ++i) {
+ metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+
+ struct ggml_tensor * src0 = gf->nodes[i]->src0;
+ struct ggml_tensor * src1 = gf->nodes[i]->src1;
+ struct ggml_tensor * dst = gf->nodes[i];
+
+ const int64_t ne00 = src0 ? src0->ne[0] : 0;
+ const int64_t ne01 = src0 ? src0->ne[1] : 0;
+ const int64_t ne02 = src0 ? src0->ne[2] : 0;
+ const int64_t ne03 = src0 ? src0->ne[3] : 0;
+
+ const uint64_t nb00 = src0 ? src0->nb[0] : 0;
+ const uint64_t nb01 = src0 ? src0->nb[1] : 0;
+ const uint64_t nb02 = src0 ? src0->nb[2] : 0;
+ const uint64_t nb03 = src0 ? src0->nb[3] : 0;
+
+ const int64_t ne10 = src1 ? src1->ne[0] : 0;
+ const int64_t ne11 = src1 ? src1->ne[1] : 0;
+ const int64_t ne12 = src1 ? src1->ne[2] : 0;
+ const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+
+ const uint64_t nb10 = src1 ? src1->nb[0] : 0;
+ const uint64_t nb11 = src1 ? src1->nb[1] : 0;
+ const uint64_t nb12 = src1 ? src1->nb[2] : 0;
+ const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+
+ const int64_t ne0 = dst ? dst->ne[0] : 0;
+ const int64_t ne1 = dst ? dst->ne[1] : 0;
+ const int64_t ne2 = dst ? dst->ne[2] : 0;
+ const int64_t ne3 = dst ? dst->ne[3] : 0;
+
+ const uint64_t nb0 = dst ? dst->nb[0] : 0;
+ const uint64_t nb1 = dst ? dst->nb[1] : 0;
+ const uint64_t nb2 = dst ? dst->nb[2] : 0;
+ const uint64_t nb3 = dst ? dst->nb[3] : 0;
+
+ const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+ const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+ const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
+
+ id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
+ id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
+ id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
+
+ //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+ //if (src0) {
+ // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
+ // ggml_is_contiguous(src0), src0->name);
+ //}
+ //if (src1) {
+ // metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
+ // ggml_is_contiguous(src1), src1->name);
+ //}
+ //if (dst) {
+ // metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
+ // dst->name);
+ //}
+
+ switch (dst->op) {
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE:
+ {
+ // noop
+ } break;
+ case GGML_OP_ADD:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_add];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_MUL:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ if (ggml_nelements(src1) == ne10) {
+ // src1 is a row
+ [encoder setComputePipelineState:ctx->pipeline_mul_row];
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_mul];
+ }
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_SCALE:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
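+ // the scale factor lives in a 1-element ggml tensor; it is read on the host at encode time and passed to the kernel as constant bytes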
+ const float scale = *(const float *) src1->data;
+
+ [encoder setComputePipelineState:ctx->pipeline_scale];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_SILU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_silu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_RELU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_relu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_GELU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_gelu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_SOFT_MAX:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ const int nth = 32;
+
+ [encoder setComputePipelineState:ctx->pipeline_soft_max];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_DIAG_MASK_INF:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ const int n_past = ((int32_t *)(src1->data))[0];
+
+ [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_MUL_MAT:
+ {
+ // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
+
+ GGML_ASSERT(ne00 == ne10);
+ GGML_ASSERT(ne02 == ne12);
+
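+ // contiguous F16/F32 src0 with a multi-row src1 (e.g. prompt processing) is routed through MPS matrix multiplication;
+ // everything else falls through to the custom matrix x vector kernels below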
+ if (ggml_is_contiguous(src0) &&
+ ggml_is_contiguous(src1) &&
+ (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
+
+ if (encoder != nil) {
+ [encoder endEncoding];
+ encoder = nil;
+ }
+
+ MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
+ MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
+
+ // for F32 x F32 we use MPS
+ MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
+ matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
+
+ MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
+ matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
+
+ MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
+ matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
+
+ MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
+ initWithDevice:ctx->device transposeLeft:false transposeRight:true
+ resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
+
+ // we need to do ne02 multiplications
+ // TODO: is there a way to do this in parallel - currently very slow ..
+ // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
+ for (int64_t i02 = 0; i02 < ne02; ++i02) {
+ size_t offs_src0_cur = offs_src0 + i02*nb02;
+ size_t offs_src1_cur = offs_src1 + i02*nb12;
+ size_t offs_dst_cur = offs_dst + i02*nb2;
+
+ MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
+ MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
+ MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
+
+ [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
+ }
+ } else {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ int nth0 = 32;
+ int nth1 = 1;
+
+ // use custom matrix x vector kernel
+ switch (src0t) {
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(ne02 == ne12);
+
+ nth0 = 64;
+ nth1 = 1;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+ } break;
+ case GGML_TYPE_Q4_0:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
+ } break;
+ case GGML_TYPE_Q4_1:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
+ } break;
+ case GGML_TYPE_Q2_K:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 4;
+ nth1 = 16;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
+ } break;
+ case GGML_TYPE_Q3_K:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 4;
+ nth1 = 16;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
+ } break;
+ case GGML_TYPE_Q4_K:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 4;
+ nth1 = 16;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
+ } break;
+ case GGML_TYPE_Q5_K:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 4;
+ nth1 = 16;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
+ } break;
+ case GGML_TYPE_Q6_K:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 4;
+ nth1 = 16;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
+ } break;
+ default:
+ {
+ fprintf(stderr, "Asserting on type %d\n",(int)src0t);
+ GGML_ASSERT(false && "not implemented");
+ }
+ };
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
+
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
+ [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q2_K ||
+ src0t == GGML_TYPE_Q3_K ||
+ src0t == GGML_TYPE_Q4_K ||
+ src0t == GGML_TYPE_Q5_K ||
+ src0t == GGML_TYPE_Q6_K) {
+ [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ } else {
+ [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ }
+ } break;
+ case GGML_OP_GET_ROWS:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ switch (src0->type) {
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
+ default: GGML_ASSERT(false && "not implemented");
+ }
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
+ [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
+
+ const int64_t n = ggml_nelements(src1);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_RMS_NORM:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ const float eps = 1e-6f;
+
+ const int nth = 256;
+
+ [encoder setComputePipelineState:ctx->pipeline_rms_norm];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+ const int64_t nrows = ggml_nrows(src0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ROPE:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ const int n_dims = ((int32_t *) src1->data)[1];
+ const int mode = ((int32_t *) src1->data)[2];
+
+ const int n_past = ((int32_t *)(src1->data))[0];
+
+ [encoder setComputePipelineState:ctx->pipeline_rope];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
+ [encoder setBytes:&mode length:sizeof( int) atIndex:20];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_CPY:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ const int nth = 32;
+
+ switch (src0t) {
+ case GGML_TYPE_F32:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+ default: GGML_ASSERT(false && "not implemented");
+ };
+ } break;
+ default: GGML_ASSERT(false && "not implemented");
+ }
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ default:
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
+ }
+
+ if (encoder != nil) {
+ [encoder endEncoding];
+ encoder = nil;
+ }
+
+ [command_buffer commit];
+ });
+ }
+
+ // wait for all threads to finish
+ dispatch_barrier_sync(queue, ^{});
+
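+ // the barrier above waits for the encoding blocks; since the command buffers were enqueued in order,
+ // waiting on the last one also guarantees that all earlier GPU work has completed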
+ [command_buffers[n_cb - 1] waitUntilCompleted];
+ }
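
The file above exposes a small C API: ggml_metal_init, ggml_metal_add_buffer, ggml_metal_set_tensor, ggml_metal_graph_compute, ggml_metal_get_tensor and ggml_metal_free. Below is a minimal, illustrative sketch (not part of the gem) of how a host program might drive it. It assumes the ggml.h API bundled with this release (ggml_init, ggml_new_tensor_1d, ggml_mul, ggml_build_forward) and a hypothetical 16 MB tensor arena; note that ggml_metal_init() loads ggml-metal.metal at runtime, so that file must be discoverable next to the binary.

// usage_sketch.c - illustrative only, not part of the diff above
#include "ggml.h"
#include "ggml-metal.h"

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

int main(void) {
    // page-aligned host arena backing all tensors; ggml_metal_add_buffer wraps the same
    // memory in a shared MTLBuffer, so tensor data is not copied to the GPU
    const size_t mem_size = 16*1024*1024;               // hypothetical arena size
    void * mem_buffer = NULL;
    posix_memalign(&mem_buffer, (size_t) getpagesize(), mem_size);

    struct ggml_init_params params = {
        .mem_size   = mem_size,
        .mem_buffer = mem_buffer,
        .no_alloc   = false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // a trivial graph: c = a * b
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * c = ggml_mul(ctx, a, b);

    struct ggml_cgraph gf = ggml_build_forward(c);
    gf.n_threads = 1;                                    // one Metal command buffer

    // bring up the backend and register the arena with it
    struct ggml_metal_context * ctx_metal = ggml_metal_init();
    ggml_metal_add_buffer(ctx_metal, "data", mem_buffer, mem_size);

    // fill the inputs on the host and mirror them into the Metal buffer
    for (int i = 0; i < 8; ++i) {
        ((float *) a->data)[i] = 2.0f;
        ((float *) b->data)[i] = 3.0f;
    }
    ggml_metal_set_tensor(ctx_metal, a);
    ggml_metal_set_tensor(ctx_metal, b);

    // evaluate on the GPU, then read the result back into host memory
    ggml_metal_graph_compute(ctx_metal, &gf);
    ggml_metal_get_tensor(ctx_metal, c);

    printf("c[0] = %f\n", ((float *) c->data)[0]);       // expected: 6.0

    ggml_metal_free(ctx_metal);
    ggml_free(ctx);
    free(mem_buffer);
    return 0;
}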