llama_cpp 0.4.0 → 0.5.1

This diff shows the changes between publicly released versions of this package as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
@@ -11,6 +11,7 @@
11
11
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
12
12
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
13
13
 
14
+ // TODO: temporary - reuse llama.cpp logging
14
15
  #ifdef GGML_METAL_NDEBUG
15
16
  #define metal_printf(...)
16
17
  #else
@@ -33,12 +34,15 @@ struct ggml_metal_buffer {
33
34
  struct ggml_metal_context {
34
35
  int n_cb;
35
36
 
36
- float * logits;
37
-
38
37
  id<MTLDevice> device;
39
38
  id<MTLCommandQueue> queue;
40
39
  id<MTLLibrary> library;
41
40
 
41
+ id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
42
+ id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
43
+
44
+ dispatch_queue_t d_queue;
45
+
42
46
  int n_buffers;
43
47
  struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
44
48
 
@@ -72,6 +76,7 @@ struct ggml_metal_context {
72
76
  GGML_METAL_DECL_KERNEL(rms_norm);
73
77
  GGML_METAL_DECL_KERNEL(norm);
74
78
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
79
+ GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
75
80
  GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
76
81
  GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
77
82
  GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
@@ -110,16 +115,31 @@ static NSString * const msl_library_source = @"see metal.metal";
110
115
  @end
111
116
 
112
117
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
113
- fprintf(stderr, "%s: allocating\n", __func__);
118
+ metal_printf("%s: allocating\n", __func__);
119
+
120
+ // Show all the Metal device instances in the system
121
+ NSArray * devices = MTLCopyAllDevices();
122
+ id <MTLDevice> device;
123
+ NSString * s;
124
+ for (device in devices) {
125
+ s = [device name];
126
+ metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
127
+ }
114
128
 
115
- struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
129
+ // Pick and show default Metal device
130
+ device = MTLCreateSystemDefaultDevice();
131
+ s = [device name];
132
+ metal_printf("%s: picking default device: %s\n", __func__, [s UTF8String]);
116
133
 
117
- ctx->n_cb = n_cb;
118
- ctx->device = MTLCreateSystemDefaultDevice();
134
+ // Configure context
135
+ struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
136
+ ctx->device = device;
137
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
119
138
  ctx->queue = [ctx->device newCommandQueue];
120
139
  ctx->n_buffers = 0;
121
140
  ctx->concur_list_len = 0;
122
141
 
142
+ ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
123
143
 
124
144
  #if 0
125
145
  // compile from source string and show compile log
@@ -128,7 +148,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
128
148
 
129
149
  ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
130
150
  if (error) {
131
- fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
151
+ metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
132
152
  return NULL;
133
153
  }
134
154
  }
@@ -142,11 +162,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
142
162
  //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
143
163
  NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
144
164
  NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
145
- fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
165
+ metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
146
166
 
147
167
  NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
148
168
  if (error) {
149
- fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
169
+ metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
150
170
  return NULL;
151
171
  }
152
172
 
@@ -158,7 +178,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
158
178
  ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
159
179
  #endif
160
180
  if (error) {
161
- fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
181
+ metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
162
182
  return NULL;
163
183
  }
164
184
  }
@@ -170,11 +190,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
170
190
  #define GGML_METAL_ADD_KERNEL(name) \
171
191
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
172
192
  ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
173
- fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
193
+ metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
174
194
  (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
175
195
  (int) ctx->pipeline_##name.threadExecutionWidth); \
176
196
  if (error) { \
177
- fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
197
+ metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
178
198
  return NULL; \
179
199
  }
180
200
 
@@ -200,6 +220,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
200
220
  GGML_METAL_ADD_KERNEL(rms_norm);
201
221
  GGML_METAL_ADD_KERNEL(norm);
202
222
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
223
+ GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
203
224
  GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
204
225
  GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
205
226
  GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
@@ -226,30 +247,89 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
226
247
  #undef GGML_METAL_ADD_KERNEL
227
248
  }
228
249
 
229
- fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
230
- fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
250
+ metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
251
+ metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
231
252
  if (ctx->device.maxTransferRate != 0) {
232
- fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
253
+ metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
233
254
  } else {
234
- fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
255
+ metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
235
256
  }
236
257
 
237
258
  return ctx;
238
259
  }
239
260
 
240
261
  void ggml_metal_free(struct ggml_metal_context * ctx) {
241
- fprintf(stderr, "%s: deallocating\n", __func__);
262
+ metal_printf("%s: deallocating\n", __func__);
263
+ #define GGML_METAL_DEL_KERNEL(name) \
264
+ [ctx->function_##name release]; \
265
+ [ctx->pipeline_##name release];
266
+
267
+ GGML_METAL_DEL_KERNEL(add);
268
+ GGML_METAL_DEL_KERNEL(add_row);
269
+ GGML_METAL_DEL_KERNEL(mul);
270
+ GGML_METAL_DEL_KERNEL(mul_row);
271
+ GGML_METAL_DEL_KERNEL(scale);
272
+ GGML_METAL_DEL_KERNEL(silu);
273
+ GGML_METAL_DEL_KERNEL(relu);
274
+ GGML_METAL_DEL_KERNEL(gelu);
275
+ GGML_METAL_DEL_KERNEL(soft_max);
276
+ GGML_METAL_DEL_KERNEL(diag_mask_inf);
277
+ GGML_METAL_DEL_KERNEL(get_rows_f16);
278
+ GGML_METAL_DEL_KERNEL(get_rows_q4_0);
279
+ GGML_METAL_DEL_KERNEL(get_rows_q4_1);
280
+ GGML_METAL_DEL_KERNEL(get_rows_q8_0);
281
+ GGML_METAL_DEL_KERNEL(get_rows_q2_K);
282
+ GGML_METAL_DEL_KERNEL(get_rows_q3_K);
283
+ GGML_METAL_DEL_KERNEL(get_rows_q4_K);
284
+ GGML_METAL_DEL_KERNEL(get_rows_q5_K);
285
+ GGML_METAL_DEL_KERNEL(get_rows_q6_K);
286
+ GGML_METAL_DEL_KERNEL(rms_norm);
287
+ GGML_METAL_DEL_KERNEL(norm);
288
+ GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
289
+ GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
290
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
291
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
292
+ GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
293
+ GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
294
+ GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
295
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
296
+ GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
297
+ GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
298
+ GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
299
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
300
+ GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
301
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
302
+ GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
303
+ GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
304
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
305
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
306
+ GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
307
+ GGML_METAL_DEL_KERNEL(rope);
308
+ GGML_METAL_DEL_KERNEL(alibi_f32);
309
+ GGML_METAL_DEL_KERNEL(cpy_f32_f16);
310
+ GGML_METAL_DEL_KERNEL(cpy_f32_f32);
311
+ GGML_METAL_DEL_KERNEL(cpy_f16_f16);
312
+
313
+ #undef GGML_METAL_DEL_KERNEL
314
+
242
315
  for (int i = 0; i < ctx->n_buffers; ++i) {
243
316
  [ctx->buffers[i].metal release];
244
317
  }
318
+
319
+ [ctx->library release];
320
+ [ctx->queue release];
321
+ [ctx->device release];
322
+
323
+ dispatch_release(ctx->d_queue);
324
+
245
325
  free(ctx);
246
326
  }
247
327
 
248
328
  void * ggml_metal_host_malloc(size_t n) {
249
329
  void * data = NULL;
250
- const int result = posix_memalign((void **) &data, getpagesize(), n);
330
+ const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
251
331
  if (result != 0) {
252
- fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
332
+ metal_printf("%s: error: posix_memalign failed\n", __func__);
253
333
  return NULL;
254
334
  }
255
335
 
@@ -261,7 +341,7 @@ void ggml_metal_host_free(void * data) {
261
341
  }
262
342
 
263
343
  void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
264
- ctx->n_cb = n_cb;
344
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
265
345
  }
266
346
 
267
347
  int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
@@ -277,7 +357,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
277
357
  // Metal buffer based on the host memory pointer
278
358
  //
279
359
  static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
280
- //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
360
+ //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
281
361
 
282
362
  const int64_t tsize = ggml_nbytes(t);
283
363
 
@@ -288,13 +368,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
288
368
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
289
369
  *offs = (size_t) ioffs;
290
370
 
291
- //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
371
+ //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
292
372
 
293
373
  return ctx->buffers[i].metal;
294
374
  }
295
375
  }
296
376
 
297
- fprintf(stderr, "%s: error: buffer is nil\n", __func__);
377
+ metal_printf("%s: error: buffer is nil\n", __func__);
298
378
 
299
379
  return nil;
300
380
  }
@@ -306,7 +386,7 @@ bool ggml_metal_add_buffer(
306
386
  size_t size,
307
387
  size_t max_size) {
308
388
  if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
309
- fprintf(stderr, "%s: too many buffers\n", __func__);
389
+ metal_printf("%s: too many buffers\n", __func__);
310
390
  return false;
311
391
  }
312
392
 
@@ -316,12 +396,12 @@ bool ggml_metal_add_buffer(
316
396
  const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
317
397
 
318
398
  if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
319
- fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
399
+ metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
320
400
  return false;
321
401
  }
322
402
  }
323
403
 
324
- const size_t size_page = getpagesize();
404
+ const size_t size_page = sysconf(_SC_PAGESIZE);
325
405
 
326
406
  size_t size_aligned = size;
327
407
  if ((size_aligned % size_page) != 0) {
@@ -337,11 +417,11 @@ bool ggml_metal_add_buffer(
337
417
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
338
418
 
339
419
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
340
- fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
420
+ metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
341
421
  return false;
342
422
  }
343
423
 
344
- fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
424
+ metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
345
425
 
346
426
  ++ctx->n_buffers;
347
427
  } else {
@@ -361,27 +441,27 @@ bool ggml_metal_add_buffer(
361
441
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
362
442
 
363
443
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
364
- fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
444
+ metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
365
445
  return false;
366
446
  }
367
447
 
368
- fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
448
+ metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
369
449
  if (i + size_step < size) {
370
- fprintf(stderr, "\n");
450
+ metal_printf("\n");
371
451
  }
372
452
 
373
453
  ++ctx->n_buffers;
374
454
  }
375
455
  }
376
456
 
377
- fprintf(stderr, ", (%8.2f / %8.2f)",
457
+ metal_printf(", (%8.2f / %8.2f)",
378
458
  ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
379
459
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
380
460
 
381
461
  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
382
- fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
462
+ metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
383
463
  } else {
384
- fprintf(stderr, "\n");
464
+ metal_printf("\n");
385
465
  }
386
466
  }
387
467
 
@@ -391,8 +471,6 @@ bool ggml_metal_add_buffer(
391
471
  void ggml_metal_set_tensor(
392
472
  struct ggml_metal_context * ctx,
393
473
  struct ggml_tensor * t) {
394
- metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
395
-
396
474
  size_t offs;
397
475
  id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
398
476
 
@@ -402,8 +480,6 @@ void ggml_metal_set_tensor(
402
480
  void ggml_metal_get_tensor(
403
481
  struct ggml_metal_context * ctx,
404
482
  struct ggml_tensor * t) {
405
- metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
406
-
407
483
  size_t offs;
408
484
  id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
409
485
 
@@ -498,14 +574,14 @@ void ggml_metal_graph_find_concurrency(
498
574
  }
499
575
 
500
576
  if (ctx->concur_list_len > GGML_MAX_CONCUR) {
501
- fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
577
+ metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
502
578
  }
503
579
  }
504
580
 
505
581
  void ggml_metal_graph_compute(
506
582
  struct ggml_metal_context * ctx,
507
583
  struct ggml_cgraph * gf) {
508
- metal_printf("%s: evaluating graph\n", __func__);
584
+ @autoreleasepool {
509
585
 
510
586
  // if there is ctx->concur_list, dispatch concurrently
511
587
  // else fallback to serial dispatch
@@ -521,29 +597,25 @@ void ggml_metal_graph_compute(
521
597
 
522
598
  const int n_cb = ctx->n_cb;
523
599
 
524
- NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
525
-
526
600
  for (int i = 0; i < n_cb; ++i) {
527
- command_buffers[i] = [ctx->queue commandBuffer];
601
+ ctx->command_buffers[i] = [ctx->queue commandBuffer];
528
602
 
529
603
  // enqueue the command buffers in order to specify their execution order
530
- [command_buffers[i] enqueue];
531
- }
604
+ [ctx->command_buffers[i] enqueue];
532
605
 
533
- // TODO: is this the best way to start threads?
534
- dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
606
+ ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
607
+ }
535
608
 
536
609
  for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
537
610
  const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
538
611
 
539
- dispatch_async(queue, ^{
612
+ dispatch_async(ctx->d_queue, ^{
540
613
  size_t offs_src0 = 0;
541
614
  size_t offs_src1 = 0;
542
615
  size_t offs_dst = 0;
543
616
 
544
- id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
545
-
546
- id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
617
+ id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
618
+ id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
547
619
 
548
620
  const int node_start = (cb_idx + 0) * n_nodes_per_cb;
549
621
  const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
@@ -556,7 +628,7 @@ void ggml_metal_graph_compute(
556
628
  continue;
557
629
  }
558
630
 
559
- metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
631
+ //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
560
632
 
561
633
  struct ggml_tensor * src0 = gf->nodes[i]->src[0];
562
634
  struct ggml_tensor * src1 = gf->nodes[i]->src[1];
@@ -625,6 +697,12 @@ void ggml_metal_graph_compute(
625
697
  } break;
626
698
  case GGML_OP_ADD:
627
699
  {
700
+ GGML_ASSERT(ggml_is_contiguous(src0));
701
+
702
+ // utilize float4
703
+ GGML_ASSERT(ne00 % 4 == 0);
704
+ const int64_t nb = ne00/4;
705
+
628
706
  if (ggml_nelements(src1) == ne10) {
629
707
  // src1 is a row
630
708
  [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -634,14 +712,20 @@ void ggml_metal_graph_compute(
634
712
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
635
713
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
636
714
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
637
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
715
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
638
716
 
639
- const int64_t n = ggml_nelements(dst);
717
+ const int64_t n = ggml_nelements(dst)/4;
640
718
 
641
719
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
642
720
  } break;
643
721
  case GGML_OP_MUL:
644
722
  {
723
+ GGML_ASSERT(ggml_is_contiguous(src0));
724
+
725
+ // utilize float4
726
+ GGML_ASSERT(ne00 % 4 == 0);
727
+ const int64_t nb = ne00/4;
728
+
645
729
  if (ggml_nelements(src1) == ne10) {
646
730
  // src1 is a row
647
731
  [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -651,9 +735,9 @@ void ggml_metal_graph_compute(
651
735
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
652
736
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
653
737
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
654
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
738
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
655
739
 
656
- const int64_t n = ggml_nelements(dst);
740
+ const int64_t n = ggml_nelements(dst)/4;
657
741
 
658
742
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
659
743
  } break;
@@ -704,7 +788,7 @@ void ggml_metal_graph_compute(
704
788
  } break;
705
789
  default:
706
790
  {
707
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
791
+ metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
708
792
  GGML_ASSERT(false);
709
793
  }
710
794
  } break;
@@ -785,9 +869,13 @@ void ggml_metal_graph_compute(
785
869
  switch (src0t) {
786
870
  case GGML_TYPE_F16:
787
871
  {
788
- nth0 = 64;
872
+ nth0 = 32;
789
873
  nth1 = 1;
790
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
874
+ if (ne11 * ne12 < 4) {
875
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
876
+ } else {
877
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
878
+ }
791
879
  } break;
792
880
  case GGML_TYPE_Q4_0:
793
881
  {
@@ -839,8 +927,8 @@ void ggml_metal_graph_compute(
839
927
  GGML_ASSERT(ne02 == 1);
840
928
  GGML_ASSERT(ne12 == 1);
841
929
 
842
- nth0 = 2;
843
- nth1 = 32;
930
+ nth0 = 4; //1;
931
+ nth1 = 8; //32;
844
932
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
845
933
  } break;
846
934
  case GGML_TYPE_Q5_K:
@@ -863,7 +951,7 @@ void ggml_metal_graph_compute(
863
951
  } break;
864
952
  default:
865
953
  {
866
- fprintf(stderr, "Asserting on type %d\n",(int)src0t);
954
+ metal_printf("Asserting on type %d\n",(int)src0t);
867
955
  GGML_ASSERT(false && "not implemented");
868
956
  }
869
957
  };
@@ -888,9 +976,12 @@ void ggml_metal_graph_compute(
888
976
  [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
889
977
 
890
978
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
891
- src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
979
+ src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
892
980
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
893
981
  }
982
+ else if (src0t == GGML_TYPE_Q4_K) {
983
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
984
+ }
894
985
  else if (src0t == GGML_TYPE_Q3_K) {
895
986
  #ifdef GGML_QKK_64
896
987
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -904,8 +995,8 @@ void ggml_metal_graph_compute(
904
995
  else if (src0t == GGML_TYPE_Q6_K) {
905
996
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
906
997
  } else {
907
- [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
908
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
998
+ int64_t ny = (ne11 + 3)/4;
999
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
909
1000
  }
910
1001
  }
911
1002
  } break;
@@ -1050,7 +1141,7 @@ void ggml_metal_graph_compute(
1050
1141
  [encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
1051
1142
  [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
1052
1143
 
1053
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1144
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
1054
1145
  } break;
1055
1146
  case GGML_OP_DUP:
1056
1147
  case GGML_OP_CPY:
@@ -1101,7 +1192,7 @@ void ggml_metal_graph_compute(
1101
1192
  } break;
1102
1193
  default:
1103
1194
  {
1104
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1195
+ metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1105
1196
  GGML_ASSERT(false);
1106
1197
  }
1107
1198
  }
@@ -1117,17 +1208,19 @@ void ggml_metal_graph_compute(
1117
1208
  }
1118
1209
 
1119
1210
  // wait for all threads to finish
1120
- dispatch_barrier_sync(queue, ^{});
1121
-
1122
- [command_buffers[n_cb - 1] waitUntilCompleted];
1211
+ dispatch_barrier_sync(ctx->d_queue, ^{});
1123
1212
 
1124
1213
  // check status of command buffers
1125
1214
  // needed to detect if the device ran out-of-memory for example (#1881)
1126
1215
  for (int i = 0; i < n_cb; i++) {
1127
- MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
1216
+ [ctx->command_buffers[i] waitUntilCompleted];
1217
+
1218
+ MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
1128
1219
  if (status != MTLCommandBufferStatusCompleted) {
1129
- fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
1220
+ metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
1130
1221
  GGML_ASSERT(false);
1131
1222
  }
1132
1223
  }
1224
+
1225
+ }
1133
1226
  }