llama_cpp 0.4.0 → 0.5.1

@@ -11,6 +11,7 @@
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
  #define MAX(a, b) ((a) > (b) ? (a) : (b))

+ // TODO: temporary - reuse llama.cpp logging
  #ifdef GGML_METAL_NDEBUG
  #define metal_printf(...)
  #else
@@ -33,12 +34,15 @@ struct ggml_metal_buffer {
  struct ggml_metal_context {
  int n_cb;

- float * logits;
-
  id<MTLDevice> device;
  id<MTLCommandQueue> queue;
  id<MTLLibrary> library;

+ id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
+ id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
+
+ dispatch_queue_t d_queue;
+
  int n_buffers;
  struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];

@@ -72,6 +76,7 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(rms_norm);
  GGML_METAL_DECL_KERNEL(norm);
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
  GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
@@ -110,16 +115,31 @@ static NSString * const msl_library_source = @"see metal.metal";
  @end

  struct ggml_metal_context * ggml_metal_init(int n_cb) {
- fprintf(stderr, "%s: allocating\n", __func__);
+ metal_printf("%s: allocating\n", __func__);
+
+ // Show all the Metal device instances in the system
+ NSArray * devices = MTLCopyAllDevices();
+ id <MTLDevice> device;
+ NSString * s;
+ for (device in devices) {
+ s = [device name];
+ metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
+ }

- struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+ // Pick and show default Metal device
+ device = MTLCreateSystemDefaultDevice();
+ s = [device name];
+ metal_printf("%s: picking default device: %s\n", __func__, [s UTF8String]);

- ctx->n_cb = n_cb;
- ctx->device = MTLCreateSystemDefaultDevice();
+ // Configure context
+ struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+ ctx->device = device;
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
  ctx->queue = [ctx->device newCommandQueue];
  ctx->n_buffers = 0;
  ctx->concur_list_len = 0;

+ ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);

  #if 0
  // compile from source string and show compile log
@@ -128,7 +148,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {

  ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
  if (error) {
- fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+ metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
  return NULL;
  }
  }
@@ -142,11 +162,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
  NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
  NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
- fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
+ metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);

  NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
  if (error) {
- fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+ metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
  return NULL;
  }

@@ -158,7 +178,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
  #endif
  if (error) {
- fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+ metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
  return NULL;
  }
  }
@@ -170,11 +190,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  #define GGML_METAL_ADD_KERNEL(name) \
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
  ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
- fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+ metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
  (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
  (int) ctx->pipeline_##name.threadExecutionWidth); \
  if (error) { \
- fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+ metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
  return NULL; \
  }

@@ -200,6 +220,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(rms_norm);
  GGML_METAL_ADD_KERNEL(norm);
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
  GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
@@ -226,30 +247,89 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  #undef GGML_METAL_ADD_KERNEL
  }

- fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
- fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+ metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+ metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
  if (ctx->device.maxTransferRate != 0) {
- fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+ metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
  } else {
- fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
+ metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
  }

  return ctx;
  }

  void ggml_metal_free(struct ggml_metal_context * ctx) {
- fprintf(stderr, "%s: deallocating\n", __func__);
+ metal_printf("%s: deallocating\n", __func__);
+ #define GGML_METAL_DEL_KERNEL(name) \
+ [ctx->function_##name release]; \
+ [ctx->pipeline_##name release];
+
+ GGML_METAL_DEL_KERNEL(add);
+ GGML_METAL_DEL_KERNEL(add_row);
+ GGML_METAL_DEL_KERNEL(mul);
+ GGML_METAL_DEL_KERNEL(mul_row);
+ GGML_METAL_DEL_KERNEL(scale);
+ GGML_METAL_DEL_KERNEL(silu);
+ GGML_METAL_DEL_KERNEL(relu);
+ GGML_METAL_DEL_KERNEL(gelu);
+ GGML_METAL_DEL_KERNEL(soft_max);
+ GGML_METAL_DEL_KERNEL(diag_mask_inf);
+ GGML_METAL_DEL_KERNEL(get_rows_f16);
+ GGML_METAL_DEL_KERNEL(get_rows_q4_0);
+ GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+ GGML_METAL_DEL_KERNEL(get_rows_q8_0);
+ GGML_METAL_DEL_KERNEL(get_rows_q2_K);
+ GGML_METAL_DEL_KERNEL(get_rows_q3_K);
+ GGML_METAL_DEL_KERNEL(get_rows_q4_K);
+ GGML_METAL_DEL_KERNEL(get_rows_q5_K);
+ GGML_METAL_DEL_KERNEL(get_rows_q6_K);
+ GGML_METAL_DEL_KERNEL(rms_norm);
+ GGML_METAL_DEL_KERNEL(norm);
+ GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
+ GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(rope);
+ GGML_METAL_DEL_KERNEL(alibi_f32);
+ GGML_METAL_DEL_KERNEL(cpy_f32_f16);
+ GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+ GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+
+ #undef GGML_METAL_DEL_KERNEL
+
  for (int i = 0; i < ctx->n_buffers; ++i) {
  [ctx->buffers[i].metal release];
  }
+
+ [ctx->library release];
+ [ctx->queue release];
+ [ctx->device release];
+
+ dispatch_release(ctx->d_queue);
+
  free(ctx);
  }

  void * ggml_metal_host_malloc(size_t n) {
  void * data = NULL;
- const int result = posix_memalign((void **) &data, getpagesize(), n);
+ const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
  if (result != 0) {
- fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
+ metal_printf("%s: error: posix_memalign failed\n", __func__);
  return NULL;
  }

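The allocator hunk above swaps getpagesize() for sysconf(_SC_PAGESIZE): both return the VM page size, but the latter is the portable POSIX spelling. Page alignment matters here because ggml_metal_add_buffer later wraps these host allocations with newBufferWithBytesNoCopy, which requires page-aligned memory. A minimal standalone C sketch of the same allocation pattern (illustrative only, not code from the gem):

    #include <stdio.h>
    #include <stdlib.h>
    #include <unistd.h>

    int main(void) {
        // portable POSIX page-size query, as used in the hunk above
        const size_t page = (size_t) sysconf(_SC_PAGESIZE);

        void * data = NULL;
        // page-aligned allocation, mirroring ggml_metal_host_malloc
        if (posix_memalign(&data, page, 16 * page) != 0) {
            fprintf(stderr, "posix_memalign failed\n");
            return 1;
        }

        printf("page size = %zu bytes, data = %p\n", page, data);
        free(data);
        return 0;
    }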
@@ -261,7 +341,7 @@ void ggml_metal_host_free(void * data) {
  }

  void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
- ctx->n_cb = n_cb;
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
  }

  int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
@@ -277,7 +357,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
  // Metal buffer based on the host memory pointer
  //
  static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
- //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
+ //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);

  const int64_t tsize = ggml_nbytes(t);

@@ -288,13 +368,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
  *offs = (size_t) ioffs;

- //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
+ //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);

  return ctx->buffers[i].metal;
  }
  }

- fprintf(stderr, "%s: error: buffer is nil\n", __func__);
+ metal_printf("%s: error: buffer is nil\n", __func__);

  return nil;
  }
@@ -306,7 +386,7 @@ bool ggml_metal_add_buffer(
  size_t size,
  size_t max_size) {
  if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
- fprintf(stderr, "%s: too many buffers\n", __func__);
+ metal_printf("%s: too many buffers\n", __func__);
  return false;
  }

@@ -316,12 +396,12 @@ bool ggml_metal_add_buffer(
  const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;

  if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
- fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
+ metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
  return false;
  }
  }

- const size_t size_page = getpagesize();
+ const size_t size_page = sysconf(_SC_PAGESIZE);

  size_t size_aligned = size;
  if ((size_aligned % size_page) != 0) {
@@ -337,11 +417,11 @@ bool ggml_metal_add_buffer(
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];

  if (ctx->buffers[ctx->n_buffers].metal == nil) {
- fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
+ metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
  return false;
  }

- fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
+ metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);

  ++ctx->n_buffers;
  } else {
@@ -361,27 +441,27 @@ bool ggml_metal_add_buffer(
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];

  if (ctx->buffers[ctx->n_buffers].metal == nil) {
- fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
+ metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
  return false;
  }

- fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
+ metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
  if (i + size_step < size) {
- fprintf(stderr, "\n");
+ metal_printf("\n");
  }

  ++ctx->n_buffers;
  }
  }

- fprintf(stderr, ", (%8.2f / %8.2f)",
+ metal_printf(", (%8.2f / %8.2f)",
  ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);

  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
- fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
+ metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
  } else {
- fprintf(stderr, "\n");
+ metal_printf("\n");
  }
  }

@@ -391,8 +471,6 @@ bool ggml_metal_add_buffer(
  void ggml_metal_set_tensor(
  struct ggml_metal_context * ctx,
  struct ggml_tensor * t) {
- metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
-
  size_t offs;
  id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);

@@ -402,8 +480,6 @@ void ggml_metal_set_tensor(
  void ggml_metal_get_tensor(
  struct ggml_metal_context * ctx,
  struct ggml_tensor * t) {
- metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
-
  size_t offs;
  id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);

@@ -498,14 +574,14 @@ void ggml_metal_graph_find_concurrency(
  }

  if (ctx->concur_list_len > GGML_MAX_CONCUR) {
- fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+ metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
  }
  }

  void ggml_metal_graph_compute(
  struct ggml_metal_context * ctx,
  struct ggml_cgraph * gf) {
- metal_printf("%s: evaluating graph\n", __func__);
+ @autoreleasepool {

  // if there is ctx->concur_list, dispatch concurrently
  // else fallback to serial dispatch
@@ -521,29 +597,25 @@ void ggml_metal_graph_compute(

  const int n_cb = ctx->n_cb;

- NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
-
  for (int i = 0; i < n_cb; ++i) {
- command_buffers[i] = [ctx->queue commandBuffer];
+ ctx->command_buffers[i] = [ctx->queue commandBuffer];

  // enqueue the command buffers in order to specify their execution order
- [command_buffers[i] enqueue];
- }
+ [ctx->command_buffers[i] enqueue];

- // TODO: is this the best way to start threads?
- dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+ ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
+ }

  for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
  const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;

- dispatch_async(queue, ^{
+ dispatch_async(ctx->d_queue, ^{
  size_t offs_src0 = 0;
  size_t offs_src1 = 0;
  size_t offs_dst = 0;

- id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
-
- id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
+ id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];

  const int node_start = (cb_idx + 0) * n_nodes_per_cb;
  const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
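The rewritten loop above creates and enqueues one command buffer per slot up front, stores the buffers and encoders in the context (replacing the temporary NSMutableArray and per-call dispatch queue), and then encodes each slice of the graph from a block on the context's persistent concurrent queue. A standalone C sketch of just the node-partitioning arithmetic, with example sizes assumed:

    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    int main(void) {
        const int n_nodes = 10; // assumed example graph size
        const int n_cb    = 3;  // assumed example command-buffer count

        // ceiling division: every buffer handles at most n_nodes_per_cb nodes
        const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;

        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
            const int node_start = (cb_idx + 0) * n_nodes_per_cb;
            const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes
                                     : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
            printf("command buffer %d encodes nodes [%d, %d)\n",
                   cb_idx, node_start, node_end);
        }
        return 0;
    }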
@@ -556,7 +628,7 @@ void ggml_metal_graph_compute(
  continue;
  }

- metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+ //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

  struct ggml_tensor * src0 = gf->nodes[i]->src[0];
  struct ggml_tensor * src1 = gf->nodes[i]->src[1];
@@ -625,6 +697,12 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_ADD:
  {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ // utilize float4
+ GGML_ASSERT(ne00 % 4 == 0);
+ const int64_t nb = ne00/4;
+
  if (ggml_nelements(src1) == ne10) {
  // src1 is a row
  [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -634,14 +712,20 @@ void ggml_metal_graph_compute(
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:3];

- const int64_t n = ggml_nelements(dst);
+ const int64_t n = ggml_nelements(dst)/4;

  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  case GGML_OP_MUL:
  {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ // utilize float4
+ GGML_ASSERT(ne00 % 4 == 0);
+ const int64_t nb = ne00/4;
+
  if (ggml_nelements(src1) == ne10) {
  // src1 is a row
  [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -651,9 +735,9 @@ void ggml_metal_graph_compute(
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:3];

- const int64_t n = ggml_nelements(dst);
+ const int64_t n = ggml_nelements(dst)/4;

  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
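The new GGML_ASSERT(ne00 % 4 == 0) in the ADD and MUL cases above lets the kernels operate on float4 vectors: the host now passes nb = ne00/4 elements per row and dispatches ggml_nelements(dst)/4 threadgroups, a 4x reduction in thread count. A scalar C analogue of the row-broadcast add under that assumption (the float4 struct here merely stands in for Metal's vector type):

    #include <stdio.h>

    typedef struct { float v[4]; } float4; /* stand-in for Metal's float4 */

    /* src1 is one row of nb float4s, broadcast-added onto every row of src0,
       mirroring the pipeline_add_row path after the ne00 % 4 == 0 assert */
    static void add_row(const float4 * src0, const float4 * src1,
                        float4 * dst, long nb, long nrows) {
        for (long r = 0; r < nrows; ++r) {
            for (long j = 0; j < nb; ++j) {
                for (int k = 0; k < 4; ++k) {
                    dst[r*nb + j].v[k] = src0[r*nb + j].v[k] + src1[j].v[k];
                }
            }
        }
    }

    int main(void) {
        /* ne00 = 8 floats per row -> nb = ne00/4 = 2 float4s per row */
        float4 src0[4] = {{{10, 10, 10, 10}}};             /* 2 rows x 2 float4s */
        float4 src1[2] = {{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}; /* the broadcast row */
        float4 dst[4];
        add_row(src0, src1, dst, 2, 2);
        printf("dst[0].v[0] = %.1f, dst[3].v[3] = %.1f\n",
               dst[0].v[0], dst[3].v[3]); /* prints 11.0 and 8.0 */
        return 0;
    }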
@@ -704,7 +788,7 @@ void ggml_metal_graph_compute(
  } break;
  default:
  {
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
  GGML_ASSERT(false);
  }
  } break;
@@ -785,9 +869,13 @@ void ggml_metal_graph_compute(
  switch (src0t) {
  case GGML_TYPE_F16:
  {
- nth0 = 64;
+ nth0 = 32;
  nth1 = 1;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+ if (ne11 * ne12 < 4) {
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+ }
  } break;
  case GGML_TYPE_Q4_0:
  {
@@ -839,8 +927,8 @@ void ggml_metal_graph_compute(
  GGML_ASSERT(ne02 == 1);
  GGML_ASSERT(ne12 == 1);

- nth0 = 2;
- nth1 = 32;
+ nth0 = 4; //1;
+ nth1 = 8; //32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
  } break;
  case GGML_TYPE_Q5_K:
@@ -863,7 +951,7 @@ void ggml_metal_graph_compute(
  } break;
  default:
  {
- fprintf(stderr, "Asserting on type %d\n",(int)src0t);
+ metal_printf("Asserting on type %d\n",(int)src0t);
  GGML_ASSERT(false && "not implemented");
  }
  };
@@ -888,9 +976,12 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];

  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
- src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
+ src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
+ else if (src0t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
  else if (src0t == GGML_TYPE_Q3_K) {
  #ifdef GGML_QKK_64
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -904,8 +995,8 @@ void ggml_metal_graph_compute(
  else if (src0t == GGML_TYPE_Q6_K) {
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  } else {
- [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ int64_t ny = (ne11 + 3)/4;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  }
  } break;
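The dispatch arithmetic above is ceiling division: ne01 output rows are rounded up by a per-type rows-per-threadgroup factor. Q4_K now takes its own branch with a factor of 4 to match its new nth0 = 4, nth1 = 8 thread shape, and the F16 fallback walks four src1 columns per threadgroup along y instead of using threadgroup memory. The same divisions in a standalone C sketch (tensor sizes assumed for illustration):

    #include <stdio.h>

    static long ceil_div(long n, long d) { return (n + d - 1) / d; }

    int main(void) {
        const long ne01 = 4096, ne11 = 2; // assumed example dims

        // Q4_0/Q4_1/Q8_0/Q2_K: 8 rows of src0 per threadgroup
        printf("q8_0 grid x = %ld\n", ceil_div(ne01, 8)); // (ne01 + 7)/8
        // Q4_K: dedicated branch, 4 rows per threadgroup
        printf("q4_K grid x = %ld\n", ceil_div(ne01, 4)); // (ne01 + 3)/4
        // F16 fallback: 4 src1 columns per threadgroup along y
        printf("f16  grid y = %ld\n", ceil_div(ne11, 4)); // ny = (ne11 + 3)/4
        return 0;
    }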
@@ -1050,7 +1141,7 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
  [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];

- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
  } break;
  case GGML_OP_DUP:
  case GGML_OP_CPY:
@@ -1101,7 +1192,7 @@ void ggml_metal_graph_compute(
  } break;
  default:
  {
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
  GGML_ASSERT(false);
  }
  }
@@ -1117,17 +1208,19 @@ void ggml_metal_graph_compute(
  }

  // wait for all threads to finish
- dispatch_barrier_sync(queue, ^{});
-
- [command_buffers[n_cb - 1] waitUntilCompleted];
+ dispatch_barrier_sync(ctx->d_queue, ^{});

  // check status of command buffers
  // needed to detect if the device ran out-of-memory for example (#1881)
  for (int i = 0; i < n_cb; i++) {
- MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
+ [ctx->command_buffers[i] waitUntilCompleted];
+
+ MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
  if (status != MTLCommandBufferStatusCompleted) {
- fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
+ metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
  GGML_ASSERT(false);
  }
  }
+
+ }
  }
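The closing hunk wraps the whole encode/dispatch/wait sequence of ggml_metal_graph_compute in @autoreleasepool, so Objective-C temporaries created on every graph evaluation (command buffers, encoders, strings) are drained when the call returns instead of accumulating in the caller's pool. A minimal Objective-C sketch of that structure (illustrative only, not code from the gem):

    #import <Foundation/Foundation.h>

    // stand-in for one graph evaluation; compile with: clang sketch.m -framework Foundation
    static void compute_once(int n_cb) {
        @autoreleasepool {
            // ... create command buffers, encode nodes, wait for completion ...
            NSString * tmp = [NSString stringWithFormat:@"encoded %d command buffers", n_cb];
            NSLog(@"%@", tmp);
        } // autoreleased temporaries (like tmp) are released here, per call
    }

    int main(void) {
        compute_once(2);
        return 0;
    }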