llama_cpp 0.2.0 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -52,18 +52,25 @@ struct ggml_metal_context {
52
52
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
53
53
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
54
54
  GGML_METAL_DECL_KERNEL(get_rows_q2_k);
55
+ GGML_METAL_DECL_KERNEL(get_rows_q3_k);
55
56
  GGML_METAL_DECL_KERNEL(get_rows_q4_k);
57
+ GGML_METAL_DECL_KERNEL(get_rows_q5_k);
56
58
  GGML_METAL_DECL_KERNEL(get_rows_q6_k);
57
59
  GGML_METAL_DECL_KERNEL(rms_norm);
60
+ GGML_METAL_DECL_KERNEL(norm);
58
61
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
59
62
  GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
60
63
  GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
61
64
  GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
65
+ GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32);
62
66
  GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
67
+ GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
63
68
  GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
64
69
  GGML_METAL_DECL_KERNEL(rope);
70
+ GGML_METAL_DECL_KERNEL(alibi_f32);
65
71
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
66
72
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
73
+ GGML_METAL_DECL_KERNEL(cpy_f16_f16);
67
74
 
68
75
  #undef GGML_METAL_DECL_KERNEL
69
76
  };
@@ -86,6 +93,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
86
93
 
87
94
  ctx->device = MTLCreateSystemDefaultDevice();
88
95
  ctx->queue = [ctx->device newCommandQueue];
96
+ ctx->n_buffers = 0;
89
97
 
90
98
  // determine if we can use MPS
91
99
  if (MPSSupportsMTLDevice(ctx->device)) {
@@ -152,22 +160,37 @@ struct ggml_metal_context * ggml_metal_init(void) {
152
160
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
153
161
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
154
162
  GGML_METAL_ADD_KERNEL(get_rows_q2_k);
163
+ GGML_METAL_ADD_KERNEL(get_rows_q3_k);
155
164
  GGML_METAL_ADD_KERNEL(get_rows_q4_k);
165
+ GGML_METAL_ADD_KERNEL(get_rows_q5_k);
156
166
  GGML_METAL_ADD_KERNEL(get_rows_q6_k);
157
167
  GGML_METAL_ADD_KERNEL(rms_norm);
168
+ GGML_METAL_ADD_KERNEL(norm);
158
169
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
159
170
  GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
160
171
  GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
161
172
  GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
173
+ GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32);
162
174
  GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
175
+ GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
163
176
  GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
164
177
  GGML_METAL_ADD_KERNEL(rope);
178
+ GGML_METAL_ADD_KERNEL(alibi_f32);
165
179
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
166
180
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
181
+ GGML_METAL_ADD_KERNEL(cpy_f16_f16);
167
182
 
168
183
  #undef GGML_METAL_ADD_KERNEL
169
184
  }
170
185
 
186
+ fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
187
+ fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
188
+ if (ctx->device.maxTransferRate != 0) {
189
+ fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
190
+ } else {
191
+ fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
192
+ }
193
+
171
194
  return ctx;
172
195
  }
173
196
 
@@ -184,10 +207,13 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
184
207
  static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
185
208
  //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
186
209
 
210
+ const int64_t tsize = ggml_nbytes(t);
211
+
212
+ // find the view that contains the tensor fully
187
213
  for (int i = 0; i < ctx->n_buffers; ++i) {
188
214
  const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
189
215
 
190
- if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
216
+ if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
191
217
  *offs = (size_t) ioffs;
192
218
 
193
219
  //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
@@ -205,7 +231,8 @@ bool ggml_metal_add_buffer(
205
231
  struct ggml_metal_context * ctx,
206
232
  const char * name,
207
233
  void * data,
208
- size_t size) {
234
+ size_t size,
235
+ size_t max_size) {
209
236
  if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
210
237
  fprintf(stderr, "%s: too many buffers\n", __func__);
211
238
  return false;
@@ -222,30 +249,68 @@ bool ggml_metal_add_buffer(
222
249
  }
223
250
  }
224
251
 
225
- size_t page_size = getpagesize();
226
- size_t aligned_size = size;
227
- if ((aligned_size % page_size) != 0) {
228
- aligned_size += (page_size - (aligned_size % page_size));
252
+ const size_t size_page = getpagesize();
253
+
254
+ size_t size_aligned = size;
255
+ if ((size_aligned % size_page) != 0) {
256
+ size_aligned += (size_page - (size_aligned % size_page));
229
257
  }
230
258
 
231
- ctx->buffers[ctx->n_buffers].name = name;
232
- ctx->buffers[ctx->n_buffers].data = data;
233
- ctx->buffers[ctx->n_buffers].size = size;
259
+ // the buffer fits into the max buffer size allowed by the device
260
+ if (size_aligned <= ctx->device.maxBufferLength) {
261
+ ctx->buffers[ctx->n_buffers].name = name;
262
+ ctx->buffers[ctx->n_buffers].data = data;
263
+ ctx->buffers[ctx->n_buffers].size = size;
234
264
 
235
- if (ctx->device.maxBufferLength < aligned_size) {
236
- fprintf(stderr, "%s: buffer '%s' size %zu is larger than buffer maximum of %zu\n", __func__, name, aligned_size, ctx->device.maxBufferLength);
237
- return false;
238
- }
239
- ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:aligned_size options:MTLResourceStorageModeShared deallocator:nil];
265
+ ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
240
266
 
241
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
242
- fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
243
- return false;
267
+ if (ctx->buffers[ctx->n_buffers].metal == nil) {
268
+ fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
269
+ return false;
270
+ }
271
+
272
+ fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
273
+
274
+ ++ctx->n_buffers;
244
275
  } else {
245
- fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
276
+ // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
277
+ // one of the views
278
+ const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
279
+ const size_t size_step = ctx->device.maxBufferLength - size_ovlp;
280
+ const size_t size_view = ctx->device.maxBufferLength;
281
+
282
+ for (size_t i = 0; i < size; i += size_step) {
283
+ const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
284
+
285
+ ctx->buffers[ctx->n_buffers].name = name;
286
+ ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
287
+ ctx->buffers[ctx->n_buffers].size = size_step_aligned;
288
+
289
+ ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
290
+
291
+ if (ctx->buffers[ctx->n_buffers].metal == nil) {
292
+ fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
293
+ return false;
294
+ }
295
+
296
+ fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
297
+ if (i + size_step < size) {
298
+ fprintf(stderr, "\n");
299
+ }
300
+
301
+ ++ctx->n_buffers;
302
+ }
246
303
  }
247
304
 
248
- ++ctx->n_buffers;
305
+ fprintf(stderr, ", (%8.2f / %8.2f)",
306
+ ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
307
+ ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
308
+
309
+ if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
310
+ fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
311
+ } else {
312
+ fprintf(stderr, "\n");
313
+ }
249
314
  }
250
315
 
251
316
  return true;
@@ -275,509 +340,633 @@ void ggml_metal_get_tensor(
275
340
 
276
341
  void ggml_metal_graph_compute(
277
342
  struct ggml_metal_context * ctx,
278
- struct ggml_cgraph * gf) {
343
+ struct ggml_cgraph * gf) {
279
344
  metal_printf("%s: evaluating graph\n", __func__);
280
345
 
281
- size_t offs_src0 = 0;
282
- size_t offs_src1 = 0;
283
- size_t offs_dst = 0;
284
-
285
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBuffer];
286
- id<MTLComputeCommandEncoder> encoder = nil;
287
-
288
- for (int i = 0; i < gf->n_nodes; ++i) {
289
- //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
290
-
291
- struct ggml_tensor * src0 = gf->nodes[i]->src0;
292
- struct ggml_tensor * src1 = gf->nodes[i]->src1;
293
- struct ggml_tensor * dst = gf->nodes[i];
294
-
295
- const int64_t ne00 = src0 ? src0->ne[0] : 0;
296
- const int64_t ne01 = src0 ? src0->ne[1] : 0;
297
- const int64_t ne02 = src0 ? src0->ne[2] : 0;
298
- const int64_t ne03 = src0 ? src0->ne[3] : 0;
299
-
300
- const uint64_t nb00 = src0 ? src0->nb[0] : 0;
301
- const uint64_t nb01 = src0 ? src0->nb[1] : 0;
302
- const uint64_t nb02 = src0 ? src0->nb[2] : 0;
303
- const uint64_t nb03 = src0 ? src0->nb[3] : 0;
304
-
305
- const int64_t ne10 = src1 ? src1->ne[0] : 0;
306
- const int64_t ne11 = src1 ? src1->ne[1] : 0;
307
- const int64_t ne12 = src1 ? src1->ne[2] : 0;
308
- const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
309
-
310
- const uint64_t nb10 = src1 ? src1->nb[0] : 0;
311
- const uint64_t nb11 = src1 ? src1->nb[1] : 0;
312
- const uint64_t nb12 = src1 ? src1->nb[2] : 0;
313
- const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
314
-
315
- const int64_t ne0 = dst ? dst->ne[0] : 0;
316
- const int64_t ne1 = dst ? dst->ne[1] : 0;
317
- const int64_t ne2 = dst ? dst->ne[2] : 0;
318
- const int64_t ne3 = dst ? dst->ne[3] : 0;
319
-
320
- const uint64_t nb0 = dst ? dst->nb[0] : 0;
321
- const uint64_t nb1 = dst ? dst->nb[1] : 0;
322
- const uint64_t nb2 = dst ? dst->nb[2] : 0;
323
- const uint64_t nb3 = dst ? dst->nb[3] : 0;
324
-
325
- const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
326
- const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
327
- const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
328
-
329
- id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
330
- id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
331
- id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
332
-
333
- //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
334
- //if (src0) {
335
- // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
336
- // ggml_is_contiguous(src0), src0->name);
337
- //}
338
- //if (src1) {
339
- // metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
340
- // ggml_is_contiguous(src1), src1->name);
341
- //}
342
- //if (dst) {
343
- // metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
344
- // dst->name);
345
- //}
346
-
347
- switch (dst->op) {
348
- case GGML_OP_RESHAPE:
349
- case GGML_OP_VIEW:
350
- case GGML_OP_TRANSPOSE:
351
- case GGML_OP_PERMUTE:
352
- {
353
- // noop
354
- } break;
355
- case GGML_OP_ADD:
356
- {
357
- if (encoder == nil) {
358
- encoder = [command_buffer computeCommandEncoder];
359
- }
360
-
361
- [encoder setComputePipelineState:ctx->pipeline_add];
362
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
363
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
364
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
365
-
366
- const int64_t n = ggml_nelements(dst);
367
-
368
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
369
- } break;
370
- case GGML_OP_MUL:
371
- {
372
- if (encoder == nil) {
373
- encoder = [command_buffer computeCommandEncoder];
374
- }
375
-
376
- if (ggml_nelements(src1) == ne10) {
377
- // src1 is a row
378
- [encoder setComputePipelineState:ctx->pipeline_mul_row];
379
- } else {
380
- [encoder setComputePipelineState:ctx->pipeline_mul];
381
- }
382
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
383
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
384
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
385
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
386
-
387
- const int64_t n = ggml_nelements(dst);
388
-
389
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
390
- } break;
391
- case GGML_OP_SCALE:
392
- {
393
- if (encoder == nil) {
394
- encoder = [command_buffer computeCommandEncoder];
395
- }
396
-
397
- const float scale = *(const float *) src1->data;
398
-
399
- [encoder setComputePipelineState:ctx->pipeline_scale];
400
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
401
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
402
- [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
403
-
404
- const int64_t n = ggml_nelements(dst);
405
-
406
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
407
- } break;
408
- case GGML_OP_SILU:
409
- {
410
- if (encoder == nil) {
411
- encoder = [command_buffer computeCommandEncoder];
412
- }
413
-
414
- [encoder setComputePipelineState:ctx->pipeline_silu];
415
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
416
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
417
-
418
- const int64_t n = ggml_nelements(dst);
419
-
420
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
421
- } break;
422
- case GGML_OP_RELU:
423
- {
424
- if (encoder == nil) {
425
- encoder = [command_buffer computeCommandEncoder];
426
- }
427
-
428
- [encoder setComputePipelineState:ctx->pipeline_relu];
429
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
430
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
431
-
432
- const int64_t n = ggml_nelements(dst);
433
-
434
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
435
- } break;
436
- case GGML_OP_GELU:
437
- {
438
- if (encoder == nil) {
439
- encoder = [command_buffer computeCommandEncoder];
440
- }
441
-
442
- [encoder setComputePipelineState:ctx->pipeline_gelu];
443
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
444
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
445
-
446
- const int64_t n = ggml_nelements(dst);
447
-
448
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
449
- } break;
450
- case GGML_OP_SOFT_MAX:
451
- {
452
- if (encoder == nil) {
453
- encoder = [command_buffer computeCommandEncoder];
454
- }
455
-
456
- const int nth = 32;
457
-
458
- [encoder setComputePipelineState:ctx->pipeline_soft_max];
459
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
460
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
461
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
462
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
463
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
464
- [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
465
-
466
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
467
- } break;
468
- case GGML_OP_DIAG_MASK_INF:
469
- {
470
- if (encoder == nil) {
471
- encoder = [command_buffer computeCommandEncoder];
472
- }
473
-
474
- const int n_past = ((int32_t *)(src1->data))[0];
475
-
476
- [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
477
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
478
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
479
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
480
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
481
- [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
482
-
483
- [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
484
- } break;
485
- case GGML_OP_MUL_MAT:
486
- {
487
- // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
488
-
489
- GGML_ASSERT(ne00 == ne10);
490
- GGML_ASSERT(ne02 == ne12);
491
-
492
- if (ggml_is_contiguous(src0) &&
493
- ggml_is_contiguous(src1) &&
494
- (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
495
-
496
- if (encoder != nil) {
497
- [encoder endEncoding];
498
- encoder = nil;
499
- }
500
-
501
- MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
502
- MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
503
-
504
- // for F32 x F32 we use MPS
505
- MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
506
- matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
507
-
508
- MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
509
- matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
510
-
511
- MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
512
- matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
513
-
514
- MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
515
- initWithDevice:ctx->device transposeLeft:false transposeRight:true
516
- resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
517
-
518
- // we need to do ne02 multiplications
519
- // TODO: is there a way to do this in parallel - currently very slow ..
520
- // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
521
- for (int64_t i02 = 0; i02 < ne02; ++i02) {
522
- size_t offs_src0_cur = offs_src0 + i02*nb02;
523
- size_t offs_src1_cur = offs_src1 + i02*nb12;
524
- size_t offs_dst_cur = offs_dst + i02*nb2;
525
-
526
- MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
527
- MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
528
- MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
529
-
530
- [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
531
- }
532
- } else {
533
- if (encoder == nil) {
534
- encoder = [command_buffer computeCommandEncoder];
535
- }
536
-
537
- int nth0 = 32;
538
- int nth1 = 1;
539
-
540
- // use custom matrix x vector kernel
541
- switch (src0t) {
542
- case GGML_TYPE_F16:
543
- {
544
- GGML_ASSERT(ne02 == ne12);
545
-
546
- nth0 = 64;
547
- nth1 = 1;
548
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
549
- } break;
550
- case GGML_TYPE_Q4_0:
551
- {
552
- GGML_ASSERT(ne02 == 1);
553
- GGML_ASSERT(ne12 == 1);
554
-
555
- nth0 = 8;
556
- nth1 = 8;
557
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
558
- } break;
559
- case GGML_TYPE_Q4_1:
560
- {
561
- GGML_ASSERT(ne02 == 1);
562
- GGML_ASSERT(ne12 == 1);
563
-
564
- nth0 = 8;
565
- nth1 = 8;
566
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
567
- } break;
568
- case GGML_TYPE_Q2_K:
569
- {
570
- GGML_ASSERT(ne02 == 1);
571
- GGML_ASSERT(ne12 == 1);
572
-
573
- nth0 = 4;
574
- nth1 = 16;
575
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
576
- } break;
577
- case GGML_TYPE_Q4_K:
578
- {
579
- GGML_ASSERT(ne02 == 1);
580
- GGML_ASSERT(ne12 == 1);
581
-
582
- nth0 = 4;
583
- nth1 = 16;
584
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
585
- } break;
586
- case GGML_TYPE_Q6_K:
587
- {
588
- GGML_ASSERT(ne02 == 1);
589
- GGML_ASSERT(ne12 == 1);
590
-
591
- nth0 = 4;
592
- nth1 = 16;
593
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
594
- } break;
595
- default:
596
- {
597
- fprintf(stderr, "Asserting on type %d\n",(int)src0t);
598
- GGML_ASSERT(false && "not implemented");
346
+ // create multiple command buffers and enqueue them
347
+ // then, we encode the graph into the command buffers in parallel
348
+
349
+ const int n_cb = gf->n_threads;
350
+
351
+ NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
352
+
353
+ for (int i = 0; i < n_cb; ++i) {
354
+ command_buffers[i] = [ctx->queue commandBuffer];
355
+
356
+ // enqueue the command buffers in order to specify their execution order
357
+ [command_buffers[i] enqueue];
358
+ }
359
+
360
+ // TODO: is this the best way to start threads?
361
+ dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
362
+
363
+ for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
364
+ const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
365
+
366
+ dispatch_async(queue, ^{
367
+ size_t offs_src0 = 0;
368
+ size_t offs_src1 = 0;
369
+ size_t offs_dst = 0;
370
+
371
+ id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
372
+
373
+ id<MTLComputeCommandEncoder> encoder = nil;
374
+
375
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
376
+ const int node_end = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
377
+
378
+ for (int i = node_start; i < node_end; ++i) {
379
+ metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
380
+
381
+ struct ggml_tensor * src0 = gf->nodes[i]->src0;
382
+ struct ggml_tensor * src1 = gf->nodes[i]->src1;
383
+ struct ggml_tensor * dst = gf->nodes[i];
384
+
385
+ const int64_t ne00 = src0 ? src0->ne[0] : 0;
386
+ const int64_t ne01 = src0 ? src0->ne[1] : 0;
387
+ const int64_t ne02 = src0 ? src0->ne[2] : 0;
388
+ const int64_t ne03 = src0 ? src0->ne[3] : 0;
389
+
390
+ const uint64_t nb00 = src0 ? src0->nb[0] : 0;
391
+ const uint64_t nb01 = src0 ? src0->nb[1] : 0;
392
+ const uint64_t nb02 = src0 ? src0->nb[2] : 0;
393
+ const uint64_t nb03 = src0 ? src0->nb[3] : 0;
394
+
395
+ const int64_t ne10 = src1 ? src1->ne[0] : 0;
396
+ const int64_t ne11 = src1 ? src1->ne[1] : 0;
397
+ const int64_t ne12 = src1 ? src1->ne[2] : 0;
398
+ const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
399
+
400
+ const uint64_t nb10 = src1 ? src1->nb[0] : 0;
401
+ const uint64_t nb11 = src1 ? src1->nb[1] : 0;
402
+ const uint64_t nb12 = src1 ? src1->nb[2] : 0;
403
+ const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
404
+
405
+ const int64_t ne0 = dst ? dst->ne[0] : 0;
406
+ const int64_t ne1 = dst ? dst->ne[1] : 0;
407
+ const int64_t ne2 = dst ? dst->ne[2] : 0;
408
+ const int64_t ne3 = dst ? dst->ne[3] : 0;
409
+
410
+ const uint64_t nb0 = dst ? dst->nb[0] : 0;
411
+ const uint64_t nb1 = dst ? dst->nb[1] : 0;
412
+ const uint64_t nb2 = dst ? dst->nb[2] : 0;
413
+ const uint64_t nb3 = dst ? dst->nb[3] : 0;
414
+
415
+ const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
416
+ const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
417
+ const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
418
+
419
+ id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
420
+ id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
421
+ id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
422
+
423
+ //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
424
+ //if (src0) {
425
+ // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
426
+ // ggml_is_contiguous(src0), src0->name);
427
+ //}
428
+ //if (src1) {
429
+ // metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
430
+ // ggml_is_contiguous(src1), src1->name);
431
+ //}
432
+ //if (dst) {
433
+ // metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
434
+ // dst->name);
435
+ //}
436
+
437
+ switch (dst->op) {
438
+ case GGML_OP_RESHAPE:
439
+ case GGML_OP_VIEW:
440
+ case GGML_OP_TRANSPOSE:
441
+ case GGML_OP_PERMUTE:
442
+ {
443
+ // noop
444
+ } break;
445
+ case GGML_OP_ADD:
446
+ {
447
+ if (encoder == nil) {
448
+ encoder = [command_buffer computeCommandEncoder];
449
+ }
450
+
451
+ [encoder setComputePipelineState:ctx->pipeline_add];
452
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
453
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
454
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
455
+
456
+ const int64_t n = ggml_nelements(dst);
457
+
458
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
459
+ } break;
460
+ case GGML_OP_MUL:
461
+ {
462
+ if (encoder == nil) {
463
+ encoder = [command_buffer computeCommandEncoder];
464
+ }
465
+
466
+ if (ggml_nelements(src1) == ne10) {
467
+ // src1 is a row
468
+ [encoder setComputePipelineState:ctx->pipeline_mul_row];
469
+ } else {
470
+ [encoder setComputePipelineState:ctx->pipeline_mul];
471
+ }
472
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
473
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
474
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
475
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
476
+
477
+ const int64_t n = ggml_nelements(dst);
478
+
479
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
480
+ } break;
481
+ case GGML_OP_SCALE:
482
+ {
483
+ if (encoder == nil) {
484
+ encoder = [command_buffer computeCommandEncoder];
485
+ }
486
+
487
+ const float scale = *(const float *) src1->data;
488
+
489
+ [encoder setComputePipelineState:ctx->pipeline_scale];
490
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
491
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
492
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
493
+
494
+ const int64_t n = ggml_nelements(dst);
495
+
496
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
497
+ } break;
498
+ case GGML_OP_SILU:
499
+ {
500
+ if (encoder == nil) {
501
+ encoder = [command_buffer computeCommandEncoder];
502
+ }
503
+
504
+ [encoder setComputePipelineState:ctx->pipeline_silu];
505
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
506
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
507
+
508
+ const int64_t n = ggml_nelements(dst);
509
+
510
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
511
+ } break;
512
+ case GGML_OP_RELU:
513
+ {
514
+ if (encoder == nil) {
515
+ encoder = [command_buffer computeCommandEncoder];
516
+ }
517
+
518
+ [encoder setComputePipelineState:ctx->pipeline_relu];
519
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
520
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
521
+
522
+ const int64_t n = ggml_nelements(dst);
523
+
524
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
525
+ } break;
526
+ case GGML_OP_GELU:
527
+ {
528
+ if (encoder == nil) {
529
+ encoder = [command_buffer computeCommandEncoder];
530
+ }
531
+
532
+ [encoder setComputePipelineState:ctx->pipeline_gelu];
533
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
534
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
535
+
536
+ const int64_t n = ggml_nelements(dst);
537
+
538
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
539
+ } break;
540
+ case GGML_OP_SOFT_MAX:
541
+ {
542
+ if (encoder == nil) {
543
+ encoder = [command_buffer computeCommandEncoder];
544
+ }
545
+
546
+ const int nth = 32;
547
+
548
+ [encoder setComputePipelineState:ctx->pipeline_soft_max];
549
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
550
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
551
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
552
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
553
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
554
+ [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
555
+
556
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
557
+ } break;
558
+ case GGML_OP_DIAG_MASK_INF:
559
+ {
560
+ if (encoder == nil) {
561
+ encoder = [command_buffer computeCommandEncoder];
562
+ }
563
+
564
+ const int n_past = ((int32_t *)(src1->data))[0];
565
+
566
+ [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
567
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
568
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
569
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
570
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
571
+ [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
572
+
573
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
574
+ } break;
575
+ case GGML_OP_MUL_MAT:
576
+ {
577
+ // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
578
+
579
+ GGML_ASSERT(ne00 == ne10);
580
+ GGML_ASSERT(ne02 == ne12);
581
+
582
+ if (ggml_is_contiguous(src0) &&
583
+ ggml_is_contiguous(src1) &&
584
+ (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
585
+
586
+ if (encoder != nil) {
587
+ [encoder endEncoding];
588
+ encoder = nil;
589
+ }
590
+
591
+ MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
592
+ MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
593
+
594
+ // for F32 x F32 we use MPS
595
+ MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
596
+ matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
597
+
598
+ MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
599
+ matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
600
+
601
+ MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
602
+ matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
603
+
604
+ MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
605
+ initWithDevice:ctx->device transposeLeft:false transposeRight:true
606
+ resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
607
+
608
+ // we need to do ne02 multiplications
609
+ // TODO: is there a way to do this in parallel - currently very slow ..
610
+ // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
611
+ for (int64_t i02 = 0; i02 < ne02; ++i02) {
612
+ size_t offs_src0_cur = offs_src0 + i02*nb02;
613
+ size_t offs_src1_cur = offs_src1 + i02*nb12;
614
+ size_t offs_dst_cur = offs_dst + i02*nb2;
615
+
616
+ MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
617
+ MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
618
+ MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
619
+
620
+ [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
621
+ }
622
+ } else {
623
+ if (encoder == nil) {
624
+ encoder = [command_buffer computeCommandEncoder];
599
625
  }
600
- };
601
-
602
-
603
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
604
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
605
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
606
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
607
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
608
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
609
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
610
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
611
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
612
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
613
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
614
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
615
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
616
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
617
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
618
-
619
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
620
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
621
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
622
- } else if (src0t == GGML_TYPE_Q2_K) {
623
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
624
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
625
- } else if (src0t == GGML_TYPE_Q4_K) {
626
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
627
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
628
- } else if (src0t == GGML_TYPE_Q6_K) {
629
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
630
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
631
- } else {
632
- [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
633
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
634
- }
635
- }
636
- } break;
637
- case GGML_OP_GET_ROWS:
638
- {
639
- if (encoder == nil) {
640
- encoder = [command_buffer computeCommandEncoder];
641
- }
642
-
643
- switch (src0->type) {
644
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
645
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
646
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
647
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
648
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
649
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
650
- default: GGML_ASSERT(false && "not implemented");
651
- }
652
-
653
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
654
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
655
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
656
- [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
657
- [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
658
- [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
659
-
660
- const int64_t n = ggml_nelements(src1);
661
-
662
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
663
- } break;
664
- case GGML_OP_RMS_NORM:
665
- {
666
- if (encoder == nil) {
667
- encoder = [command_buffer computeCommandEncoder];
668
- }
669
-
670
- const float eps = 1e-6f;
671
-
672
- const int nth = 256;
673
-
674
- [encoder setComputePipelineState:ctx->pipeline_rms_norm];
675
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
676
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
677
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
678
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
679
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
680
- [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
681
-
682
- const int64_t nrows = ggml_nrows(src0);
683
-
684
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
685
- } break;
686
- case GGML_OP_ROPE:
687
- {
688
- if (encoder == nil) {
689
- encoder = [command_buffer computeCommandEncoder];
690
- }
691
-
692
- const int n_dims = ((int32_t *) src1->data)[1];
693
- const int mode = ((int32_t *) src1->data)[2];
694
-
695
- const int n_past = ((int32_t *)(src1->data))[0];
696
-
697
- [encoder setComputePipelineState:ctx->pipeline_rope];
698
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
699
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
700
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
701
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
702
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
703
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
704
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
705
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
706
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
707
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
708
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
709
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
710
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
711
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
712
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
713
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
714
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
715
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
716
- [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
717
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
718
- [encoder setBytes:&mode length:sizeof( int) atIndex:20];
719
-
720
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
721
- } break;
722
- case GGML_OP_CPY:
723
- {
724
- if (encoder == nil) {
725
- encoder = [command_buffer computeCommandEncoder];
726
- }
727
-
728
- const int nth = 32;
729
-
730
- switch (src0t) {
731
- case GGML_TYPE_F32:
732
- {
733
- switch (dstt) {
734
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
735
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
736
- default: GGML_ASSERT(false && "not implemented");
626
+
627
+ int nth0 = 32;
628
+ int nth1 = 1;
629
+
630
+ // use custom matrix x vector kernel
631
+ switch (src0t) {
632
+ case GGML_TYPE_F16:
633
+ {
634
+ GGML_ASSERT(ne02 == ne12);
635
+
636
+ nth0 = 64;
637
+ nth1 = 1;
638
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
639
+ } break;
640
+ case GGML_TYPE_Q4_0:
641
+ {
642
+ GGML_ASSERT(ne02 == 1);
643
+ GGML_ASSERT(ne12 == 1);
644
+
645
+ nth0 = 8;
646
+ nth1 = 8;
647
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
648
+ } break;
649
+ case GGML_TYPE_Q4_1:
650
+ {
651
+ GGML_ASSERT(ne02 == 1);
652
+ GGML_ASSERT(ne12 == 1);
653
+
654
+ nth0 = 8;
655
+ nth1 = 8;
656
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
657
+ } break;
658
+ case GGML_TYPE_Q2_K:
659
+ {
660
+ GGML_ASSERT(ne02 == 1);
661
+ GGML_ASSERT(ne12 == 1);
662
+
663
+ nth0 = 4;
664
+ nth1 = 16;
665
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
666
+ } break;
667
+ case GGML_TYPE_Q3_K:
668
+ {
669
+ GGML_ASSERT(ne02 == 1);
670
+ GGML_ASSERT(ne12 == 1);
671
+
672
+ nth0 = 4;
673
+ nth1 = 16;
674
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
675
+ } break;
676
+ case GGML_TYPE_Q4_K:
677
+ {
678
+ GGML_ASSERT(ne02 == 1);
679
+ GGML_ASSERT(ne12 == 1);
680
+
681
+ nth0 = 4;
682
+ nth1 = 16;
683
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
684
+ } break;
685
+ case GGML_TYPE_Q5_K:
686
+ {
687
+ GGML_ASSERT(ne02 == 1);
688
+ GGML_ASSERT(ne12 == 1);
689
+
690
+ nth0 = 4;
691
+ nth1 = 16;
692
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
693
+ } break;
694
+ case GGML_TYPE_Q6_K:
695
+ {
696
+ GGML_ASSERT(ne02 == 1);
697
+ GGML_ASSERT(ne12 == 1);
698
+
699
+ nth0 = 4;
700
+ nth1 = 16;
701
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
702
+ } break;
703
+ default:
704
+ {
705
+ fprintf(stderr, "Asserting on type %d\n",(int)src0t);
706
+ GGML_ASSERT(false && "not implemented");
707
+ }
737
708
  };
738
- } break;
739
- default: GGML_ASSERT(false && "not implemented");
740
- }
741
-
742
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
743
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
744
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
745
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
746
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
747
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
748
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
749
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
750
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
751
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
752
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
753
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
754
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
755
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
756
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
757
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
758
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
759
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
760
-
761
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
762
- } break;
763
- default:
764
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
765
- GGML_ASSERT(false);
766
- }
767
- }
768
709
 
769
- if (encoder != nil) {
770
- [encoder endEncoding];
771
- encoder = nil;
710
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
711
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
712
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
713
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
714
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
715
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
716
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
717
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
718
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
719
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
720
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
721
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
722
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
723
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
724
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
725
+
726
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
727
+ [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
728
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
729
+ }
730
+ else if (src0t == GGML_TYPE_Q2_K ||
731
+ src0t == GGML_TYPE_Q3_K ||
732
+ src0t == GGML_TYPE_Q4_K ||
733
+ src0t == GGML_TYPE_Q5_K ||
734
+ src0t == GGML_TYPE_Q6_K) {
735
+ [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
736
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
737
+ } else {
738
+ [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
739
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
740
+ }
741
+ }
742
+ } break;
743
+ case GGML_OP_GET_ROWS:
744
+ {
745
+ if (encoder == nil) {
746
+ encoder = [command_buffer computeCommandEncoder];
747
+ }
748
+
749
+ switch (src0->type) {
750
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
751
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
752
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
753
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
754
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
755
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
756
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
757
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
758
+ default: GGML_ASSERT(false && "not implemented");
759
+ }
760
+
761
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
762
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
763
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
764
+ [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
765
+ [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
766
+ [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
767
+
768
+ const int64_t n = ggml_nelements(src1);
769
+
770
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
771
+ } break;
772
+ case GGML_OP_RMS_NORM:
773
+ {
774
+ if (encoder == nil) {
775
+ encoder = [command_buffer computeCommandEncoder];
776
+ }
777
+
778
+ const float eps = 1e-6f;
779
+
780
+ const int nth = 256;
781
+
782
+ [encoder setComputePipelineState:ctx->pipeline_rms_norm];
783
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
784
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
785
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
786
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
787
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
788
+ [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
789
+
790
+ const int64_t nrows = ggml_nrows(src0);
791
+
792
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
793
+ } break;
794
+ case GGML_OP_NORM:
795
+ {
796
+ if (encoder == nil) {
797
+ encoder = [command_buffer computeCommandEncoder];
798
+ }
799
+
800
+ const float eps = 1e-5f;
801
+
802
+ const int nth = 256;
803
+
804
+ [encoder setComputePipelineState:ctx->pipeline_norm];
805
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
806
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
807
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
808
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
809
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
810
+ [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
811
+
812
+ const int64_t nrows = ggml_nrows(src0);
813
+
814
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
815
+ } break;
816
+ case GGML_OP_ALIBI:
817
+ {
818
+ if (encoder == nil) {
819
+ encoder = [command_buffer computeCommandEncoder];
820
+ }
821
+
822
+ GGML_ASSERT((src0t == GGML_TYPE_F32));
823
+
824
+ const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past);
825
+ const int n_head = ((int32_t *) src1->data)[1];
826
+ const float max_bias = ((float *) src1->data)[2];
827
+
828
+ if (__builtin_popcount(n_head) != 1) {
829
+ GGML_ASSERT(false && "only power-of-two n_head implemented");
830
+ }
831
+
832
+ const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
833
+ const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
834
+
835
+ [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
836
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
837
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
838
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
839
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
840
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
841
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
842
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
843
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
844
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
845
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
846
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
847
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
848
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
849
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
850
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
851
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
852
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
853
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
854
+ [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
855
+ const int nth = 32;
856
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
857
+ } break;
858
+ case GGML_OP_ROPE:
859
+ {
860
+ if (encoder == nil) {
861
+ encoder = [command_buffer computeCommandEncoder];
862
+ }
863
+
864
+ const int n_dims = ((int32_t *) src1->data)[1];
865
+ const int mode = ((int32_t *) src1->data)[2];
866
+
867
+ const int n_past = ((int32_t *)(src1->data))[0];
868
+
869
+ [encoder setComputePipelineState:ctx->pipeline_rope];
870
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
871
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
872
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
873
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
874
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
875
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
876
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
877
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
878
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
879
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
880
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
881
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
882
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
883
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
884
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
885
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
886
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
887
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
888
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
889
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
890
+ [encoder setBytes:&mode length:sizeof( int) atIndex:20];
891
+
892
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
893
+ } break;
894
+ case GGML_OP_CPY:
895
+ {
896
+ if (encoder == nil) {
897
+ encoder = [command_buffer computeCommandEncoder];
898
+ }
899
+
900
+ const int nth = 32;
901
+
902
+ switch (src0t) {
903
+ case GGML_TYPE_F32:
904
+ {
905
+ switch (dstt) {
906
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
907
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
908
+ default: GGML_ASSERT(false && "not implemented");
909
+ };
910
+ } break;
911
+ case GGML_TYPE_F16:
912
+ {
913
+ switch (dstt) {
914
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
915
+ case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
916
+ default: GGML_ASSERT(false && "not implemented");
917
+ };
918
+ } break;
919
+ default: GGML_ASSERT(false && "not implemented");
920
+ }
921
+
922
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
923
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
924
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
925
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
926
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
927
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
928
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
929
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
930
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
931
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
932
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
933
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
934
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
935
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
936
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
937
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
938
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
939
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
940
+
941
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
942
+ } break;
943
+ default:
944
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
945
+ GGML_ASSERT(false);
946
+ }
947
+ }
948
+
949
+ if (encoder != nil) {
950
+ [encoder endEncoding];
951
+ encoder = nil;
952
+ }
953
+
954
+ [command_buffer commit];
955
+ });
772
956
  }
773
957
 
774
- [command_buffer commit];
775
- [command_buffer waitUntilCompleted];
958
+ // wait for all threads to finish
959
+ dispatch_barrier_sync(queue, ^{});
776
960
 
777
- {
778
- const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
779
- UNUSED(time_elapsed);
961
+ [command_buffers[n_cb - 1] waitUntilCompleted];
780
962
 
781
- metal_printf("%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
963
+ // check status of command buffers
964
+ // needed to detect if the device ran out-of-memory for example (#1881)
965
+ for (int i = 0; i < n_cb; i++) {
966
+ MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
967
+ if (status != MTLCommandBufferStatusCompleted) {
968
+ fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
969
+ GGML_ASSERT(false);
970
+ }
782
971
  }
783
972
  }