llama_cpp 0.12.2 → 0.12.4

This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only. All of the hunks below are in the bundled `ggml-metal.m` source — the Metal (Apple GPU) backend that ships with llama.cpp.
@@ -24,19 +24,7 @@
 
 #define UNUSED(x) (void)(x)
 
-#define GGML_METAL_MAX_KERNELS 256
-
-struct ggml_metal_buffer {
-    const char * name;
-
-    void * data;
-    size_t size;
-
-    id<MTLBuffer> metal;
-};
-
 struct ggml_metal_kernel {
-    id<MTLFunction> function;
     id<MTLComputePipelineState> pipeline;
 };
 
@@ -72,6 +60,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -93,6 +82,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
     //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -110,6 +100,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -124,6 +115,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -138,10 +130,12 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
     GGML_METAL_KERNEL_TYPE_ALIBI_F32,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
+    GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
     GGML_METAL_KERNEL_TYPE_PAD_F32,
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@@ -168,17 +162,15 @@ struct ggml_metal_context {
 
     id<MTLDevice> device;
     id<MTLCommandQueue> queue;
-    id<MTLLibrary> library;
 
     dispatch_queue_t d_queue;
 
-    int n_buffers;
-    struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
-
-    struct ggml_metal_kernel kernels[GGML_METAL_MAX_KERNELS];
+    struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
 
     bool support_simdgroup_reduction;
     bool support_simdgroup_mm;
+
+    bool should_capture_next_compute;
 };
 
 // MSL code
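
Two context changes stand out here: the per-context buffer list and the retained `MTLLibrary` handle are gone, and the kernel table is now sized by the enum's trailing `_COUNT` entry rather than the magic `GGML_METAL_MAX_KERNELS = 256`. Sizing an array by a final `_COUNT` enumerator keeps the table and the enum in lockstep; a minimal C sketch of the pattern (illustrative names, not from the diff):

    // Sketch: size a lookup table by the enum's trailing _COUNT entry.
    enum demo_kernel_type {
        DEMO_KERNEL_TYPE_ADD,
        DEMO_KERNEL_TYPE_MUL,

        DEMO_KERNEL_TYPE_COUNT, // keep last: evaluates to the number of kernels
    };

    struct demo_kernel {
        void * pipeline; // stand-in for id<MTLComputePipelineState>
    };

    // grows automatically whenever an enumerator is inserted above _COUNT
    static struct demo_kernel demo_kernels[DEMO_KERNEL_TYPE_COUNT];
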
@@ -238,32 +230,28 @@ static void * ggml_metal_host_malloc(size_t n) {
 static struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
 
-    id<MTLDevice> device;
-    NSString * s;
-
-#if TARGET_OS_OSX
+#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
     // Show all the Metal device instances in the system
     NSArray * devices = MTLCopyAllDevices();
-    for (device in devices) {
-        s = [device name];
-        GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
+    for (id<MTLDevice> device in devices) {
+        GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
     }
+    [devices release]; // since it was created by a *Copy* C method
 #endif
 
     // Pick and show default Metal device
-    device = MTLCreateSystemDefaultDevice();
-    s = [device name];
-    GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
+    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+    GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
 
     // Configure context
     struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
     ctx->device = device;
     ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
     ctx->queue = [ctx->device newCommandQueue];
-    ctx->n_buffers = 0;
-
     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
+    id<MTLLibrary> metal_library;
+
     // load library
     {
         NSBundle * bundle = nil;
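
Beyond scoping `device` into the loop, the rewrite fixes a leak: `MTLCopyAllDevices()` follows the Core Foundation Create Rule (functions named *Create*/*Copy* return a reference the caller owns), so under manual retain/release the array must be released, as the new comment notes. A small sketch of the convention, assuming MRC compilation as in ggml-metal.m:

    #import <Metal/Metal.h>

    // Sketch: *Copy*/*Create* C functions hand back an object the caller owns (MRC).
    static void log_devices(void) {
        NSArray * devices = MTLCopyAllDevices(); // owned (+1): we must release it
        for (id<MTLDevice> device in devices) {
            NSLog(@"found device: %@", [device name]); // `name` is not owned by us
        }
        [devices release]; // balances the ownership taken by the *Copy* call
    }
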
@@ -278,7 +266,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
             // pre-compiled library found
             NSURL * libURL = [NSURL fileURLWithPath:libPath];
             GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
-            ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
+            metal_library = [ctx->device newLibraryWithURL:libURL error:&error];
+            if (error) {
+                GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+                return NULL;
+            }
         } else {
             GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
 
@@ -303,27 +295,25 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
                 return NULL;
             }
 
-            // dictionary of preprocessor macros
-            NSMutableDictionary * prep = [NSMutableDictionary dictionary];
+            @autoreleasepool {
+                // dictionary of preprocessor macros
+                NSMutableDictionary * prep = [NSMutableDictionary dictionary];
 
 #ifdef GGML_QKK_64
-            prep[@"QK_K"] = @(64);
+                prep[@"QK_K"] = @(64);
 #endif
 
-            MTLCompileOptions* options = [MTLCompileOptions new];
-            options.preprocessorMacros = prep;
-
-            //[options setFastMathEnabled:false];
+                MTLCompileOptions* options = [MTLCompileOptions new];
+                options.preprocessorMacros = prep;
 
-            ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
-
-            [options release];
-            [prep release];
-        }
+                //[options setFastMathEnabled:false];
 
-        if (error) {
-            GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
-            return NULL;
+                metal_library = [ctx->device newLibraryWithSource:src options:options error:&error];
+                if (error) {
+                    GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+                    return NULL;
+                }
+            }
         }
     }
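
The source-compilation path now runs inside `@autoreleasepool`, which drains the temporaries created while building the library (the macro dictionary and any autoreleased intermediates) instead of pairing manual `release` calls, and the `if (error)` check moves next to the call that can fail. A minimal sketch of the pool's effect, again assuming MRC:

    #import <Foundation/Foundation.h>

    // Sketch: autoreleased objects created inside the pool die at the closing brace.
    static void compile_step(void) {
        @autoreleasepool {
            // +0 autoreleased object: owned by the pool, not by us
            NSMutableDictionary * prep = [NSMutableDictionary dictionary];
            prep[@"QK_K"] = @(64);
            // ... build compile options, compile the library ...
        } // <- `prep` and friends are released here, before the function returns
    }
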
@@ -367,6 +357,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
     GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
 
+    ctx->should_capture_next_compute = false;
+
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
     if (@available(macOS 10.12, iOS 16.0, *)) {
         GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
@@ -383,8 +375,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
     {
         NSError * error = nil;
 
-        for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
-            ctx->kernels[i].function = nil;
+        for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
             ctx->kernels[i].pipeline = nil;
         }
 
@@ -396,10 +387,12 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #define GGML_METAL_ADD_KERNEL(e, name, supported) \
         if (supported) { \
             struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
-            kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \
-            kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \
+            id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
+            kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
+            [metal_function release]; \
             if (error) { \
                 GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+                [metal_library release]; \
                 return NULL; \
             } \
         } else { \
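
The macro no longer stores the `MTLFunction`: a compute pipeline state retains the compiled code it needs, so the function object can be released as soon as the pipeline is built. A standalone sketch of that lifetime (hypothetical kernel name, MRC assumed):

    #import <Metal/Metal.h>

    // Sketch: build a pipeline, then drop the MTLFunction right away.
    static id<MTLComputePipelineState> make_pipeline(id<MTLDevice> device,
                                                     id<MTLLibrary> library) {
        NSError * error = nil;

        // "kernel_demo" is a hypothetical function name for illustration
        id<MTLFunction> fn = [library newFunctionWithName:@"kernel_demo"];
        id<MTLComputePipelineState> pso =
            [device newComputePipelineStateWithFunction:fn error:&error];
        [fn release]; // the pipeline state keeps what it needs

        if (error) {
            NSLog(@"pipeline error: %@", error);
        }
        return pso; // +1, caller releases
    }
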
@@ -439,6 +432,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
@@ -460,6 +454,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
         //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
@@ -477,6 +472,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
@@ -491,6 +487,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
@@ -505,10 +502,12 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
@@ -528,27 +527,17 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
     }
 
+    [metal_library release];
     return ctx;
 }
 
 static void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
 
-    for (int i = 0; i < ctx->n_buffers; ++i) {
-        [ctx->buffers[i].metal release];
+    for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
+        [ctx->kernels[i].pipeline release];
     }
 
-    for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
-        if (ctx->kernels[i].pipeline) {
-            [ctx->kernels[i].pipeline release];
-        }
-
-        if (ctx->kernels[i].function) {
-            [ctx->kernels[i].function release];
-        }
-    }
-
-    [ctx->library release];
     [ctx->queue release];
     [ctx->device release];
 
@@ -580,51 +569,30 @@ struct ggml_backend_metal_buffer_context {
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
 //
-static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
+static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) {
     //GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
 
     const int64_t tsize = ggml_nbytes(t);
 
     ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
 
-    // compatibility with ggml-backend
-    if (buffer && buffer->buft == ggml_backend_metal_buffer_type()) {
-        struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;
-
-        // find the view that contains the tensor fully
-        for (int i = 0; i < buf_ctx->n_buffers; ++i) {
-            const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
-
-            //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
-            if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
-                *offs = (size_t) ioffs;
-
-                //GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
-
-                return buf_ctx->buffers[i].metal;
-            }
-        }
-
-        GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
-
-        return nil;
-    }
+    struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;
 
     // find the view that contains the tensor fully
-    for (int i = 0; i < ctx->n_buffers; ++i) {
-        const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
+    for (int i = 0; i < buf_ctx->n_buffers; ++i) {
+        const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
 
-        //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
-        if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
+        //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
+        if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
             *offs = (size_t) ioffs;
 
-            //GGML_METAL_LOG_INFO("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
+            //GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
 
-            return ctx->buffers[i].metal;
+            return buf_ctx->buffers[i].metal;
         }
     }
 
-    GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__);
+    GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
 
     return nil;
 }
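
With the legacy per-context buffer list removed, every tensor now resolves through its ggml-backend buffer context, and the lookup reduces to pointer arithmetic: compute the tensor's offset inside each registered host region and verify the whole tensor fits. A self-contained sketch of that containment test (simplified stand-in types, not the ggml structs):

    #include <stdint.h>
    #include <stddef.h>
    #include <stdbool.h>

    // Sketch: does the range [ptr, ptr + size) fall inside a registered region?
    struct region { void * base; size_t size; };

    static bool find_offset(const struct region * r, const void * ptr, size_t size,
                            size_t * offs) {
        const int64_t ioffs = (int64_t) ((const char *) ptr - (const char *) r->base);
        if (ioffs >= 0 && (size_t) ioffs + size <= r->size) {
            *offs = (size_t) ioffs; // offset to pass alongside the device buffer
            return true;
        }
        return false;
    }
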
@@ -664,6 +632,10 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_ALIBI:
         case GGML_OP_ROPE:
         case GGML_OP_IM2COL:
+            return true;
+        case GGML_OP_POOL_1D:
+        case GGML_OP_POOL_2D:
+            return false;
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_ARGSORT:
@@ -671,7 +643,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
-            return ctx->support_simdgroup_reduction;
+            return ctx->support_simdgroup_reduction &&
+                (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
         case GGML_OP_CPY:
         case GGML_OP_DUP:
         case GGML_OP_CONT:
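
The extra condition reads: if the weights (`src0`) are F32, the activations (`src1`) must be F32 too; F16 and quantized weights are unaffected. In effect an F32 × F16 matrix multiply is now reported as unsupported so the scheduler can fall back instead of dispatching a kernel that does not exist. A small sketch of the predicate (illustrative names):

    #include <stdbool.h>

    // Sketch of the src-type condition added above.
    enum demo_type { DEMO_F32, DEMO_F16, DEMO_Q4_0 };

    static bool mul_mat_types_ok(enum demo_type src0, enum demo_type src1) {
        // the only pairing this rejects is F32 weights with non-F32 activations
        return src0 != DEMO_F32 || src1 == DEMO_F32;
    }

    // mul_mat_types_ok(DEMO_F32,  DEMO_F16) == false  -> falls back to CPU
    // mul_mat_types_ok(DEMO_F32,  DEMO_F32) == true
    // mul_mat_types_ok(DEMO_Q4_0, DEMO_F32) == true
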
@@ -713,7 +686,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
 static bool ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
         struct ggml_cgraph * gf) {
-    @autoreleasepool {
 
     MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
     edesc.dispatchType = MTLDispatchTypeSerial;
@@ -725,6 +697,20 @@ static bool ggml_metal_graph_compute(
     const int n_cb = ctx->n_cb;
     const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
+    const bool should_capture = ctx->should_capture_next_compute;
+    if (should_capture) {
+        ctx->should_capture_next_compute = false;
+
+        MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
+        descriptor.captureObject = ctx->queue;
+
+        NSError * error = nil;
+        if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
+            GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
+            GGML_ASSERT(!"capture failed");
+        }
+    }
+
     id<MTLCommandBuffer> command_buffer_builder[n_cb];
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
         id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
@@ -733,6 +719,7 @@ static bool ggml_metal_graph_compute(
         // enqueue the command buffers in order to specify their execution order
         [command_buffer enqueue];
     }
+
     const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
 
     dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
@@ -779,9 +766,9 @@ static bool ggml_metal_graph_compute(
                 GGML_ASSERT(!"unsupported op");
             }
 
-#ifndef GGML_METAL_NDEBUG
-            [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
-#endif
+            if (should_capture) {
+                [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
+            }
 
             const int64_t ne00 = src0 ? src0->ne[0] : 0;
             const int64_t ne01 = src0 ? src0->ne[1] : 0;
@@ -817,9 +804,9 @@ static bool ggml_metal_graph_compute(
             const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
             const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
 
-            id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
-            id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
-            id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
+            id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
+            id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
+            id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
 
             //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
             //if (src0) {
@@ -1308,6 +1295,7 @@ static bool ggml_metal_graph_compute(
                         case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
                         case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
                         case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
+                        case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
                         default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
                     }
 
@@ -1436,6 +1424,12 @@ static bool ggml_metal_graph_compute(
                                 nth1 = 16;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
                             } break;
+                        case GGML_TYPE_IQ3_XXS:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
+                            } break;
                         default:
                             {
                                 GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1478,6 +1472,11 @@ static bool ggml_metal_graph_compute(
                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
+                    else if (src0t == GGML_TYPE_IQ3_XXS) {
+                        const int mem_size = 256*4+128;
+                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
                     else if (src0t == GGML_TYPE_Q4_K) {
                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
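
The `256*4+128` threadgroup budget for the new iq3_xxs kernels parallels the grid-plus-signs allocations used by the iq2 variants. A plausible breakdown — an assumption based on the iq3_xxs lookup tables in ggml-metal.metal, not something stated in this diff — is 256 grid entries of 4 bytes plus a 128-byte sign-mask table:

    #include <stdint.h>

    // Assumed breakdown of the 1152-byte threadgroup buffer (illustrative only):
    enum {
        IQ3_XXS_GRID_BYTES = 256 * sizeof(uint32_t), // assumed 256-entry grid, 4 B each
        SIGN_MASK_BYTES    = 128,                    // assumed shared sign-mask table
        IQ3_XXS_TG_BYTES   = IQ3_XXS_GRID_BYTES + SIGN_MASK_BYTES, // 256*4+128 == 1152
    };
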
@@ -1572,6 +1571,7 @@ static bool ggml_metal_graph_compute(
                        case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
                        case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
                        case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
+                        case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
                        default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
                    }
 
@@ -1601,7 +1601,7 @@ static bool ggml_metal_graph_compute(
                        struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
 
                        size_t offs_src_cur = 0;
-                        id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+                        id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
 
                        [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
                    }
@@ -1703,6 +1703,12 @@ static bool ggml_metal_graph_compute(
                                nth1 = 16;
                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
                            } break;
+                        case GGML_TYPE_IQ3_XXS:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
+                            } break;
                        default:
                            {
                                GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1746,7 +1752,7 @@ static bool ggml_metal_graph_compute(
                        struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
 
                        size_t offs_src_cur = 0;
-                        id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+                        id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
 
                        [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
                    }
@@ -1761,6 +1767,11 @@ static bool ggml_metal_graph_compute(
                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                        [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
+                    else if (src2t == GGML_TYPE_IQ3_XXS) {
+                        const int mem_size = 256*4+128;
+                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
                    else if (src2t == GGML_TYPE_Q4_K) {
                        [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
@@ -1801,6 +1812,7 @@ static bool ggml_metal_graph_compute(
                        case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
                        case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
                        case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
+                        case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
                        case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
                        default: GGML_ASSERT(false && "not implemented");
                    }
@@ -2009,7 +2021,7 @@ static bool ggml_metal_graph_compute(
                {
                    GGML_ASSERT(src0->type == GGML_TYPE_F16);
                    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-                    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+                    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
 
                    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
                    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
@@ -2017,6 +2029,7 @@ static bool ggml_metal_graph_compute(
                    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
                    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
                    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+
                    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
 
                    const int32_t N = src1->ne[is_2D ? 3 : 2];
@@ -2037,8 +2050,8 @@ static bool ggml_metal_graph_compute(
 
                    id<MTLComputePipelineState> pipeline = nil;
 
-                    switch (src0->type) {
-                        case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
+                    switch (dst->type) {
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
                        default: GGML_ASSERT(false);
                    };
@@ -2231,15 +2244,12 @@ static bool ggml_metal_graph_compute(
                        }
                }
 
-#ifndef GGML_METAL_NDEBUG
-            [encoder popDebugGroup];
-#endif
+            if (should_capture) {
+                [encoder popDebugGroup];
+            }
        }
 
-        if (encoder != nil) {
-            [encoder endEncoding];
-            encoder = nil;
-        }
+        [encoder endEncoding];
 
        [command_buffer commit];
    });
@@ -2258,8 +2268,11 @@ static bool ggml_metal_graph_compute(
            return false;
        }
    }
 
-    return true;
+    if (should_capture) {
+        [[MTLCaptureManager sharedCaptureManager] stopCapture];
    }
+
+    return true;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -2427,6 +2440,16 @@ GGML_CALL static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backen
    UNUSED(buft);
 }
 
+GGML_CALL static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
+    id<MTLDevice> device = ggml_backend_metal_get_device();
+    size_t max_size = device.maxBufferLength;
+    ggml_backend_metal_free_device();
+
+    return max_size;
+
+    UNUSED(buft);
+}
+
 GGML_CALL static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
    return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
 
@@ -2445,6 +2468,7 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
            /* .get_name         = */ ggml_backend_metal_buffer_type_get_name,
            /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_alloc_buffer,
            /* .get_alignment    = */ ggml_backend_metal_buffer_type_get_alignment,
+            /* .get_max_size     = */ ggml_backend_metal_buffer_type_get_max_size,
            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
            /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
            /* .is_host          = */ ggml_backend_metal_buffer_type_is_host,
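
The newly registered `.get_max_size` hook exposes Metal's per-buffer allocation cap so ggml-backend can split tensors that would exceed a single `MTLBuffer`. A standalone sketch of the underlying query (bypassing ggml's device refcounting helpers):

    #import <Metal/Metal.h>

    // Sketch: query the largest single MTLBuffer the device will allocate.
    static size_t device_max_buffer_size(void) {
        id<MTLDevice> device = MTLCreateSystemDefaultDevice(); // +1, we own it
        const size_t max_size = device.maxBufferLength;        // per-buffer cap in bytes
        [device release];
        return max_size;
    }
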
@@ -2619,6 +2643,13 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
    return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
 }
 
+void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
+    GGML_ASSERT(ggml_backend_is_metal(backend));
+
+    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+    ctx->should_capture_next_compute = true;
+}
+
 GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
 
 GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
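
Taken together, the capture plumbing replaces the old compile-time debug groups: callers arm a one-shot flag, and the next `ggml_metal_graph_compute` wraps itself in a GPU trace via `MTLCaptureManager`, tagging each graph node with a debug group. A hedged usage sketch (the surrounding setup may differ by version; programmatic capture also generally requires running under Xcode or launching with `MTL_CAPTURE_ENABLED=1`):

    // Sketch: arm a one-shot GPU capture around the next graph evaluation.
    // Assumes a ggml Metal backend has already been created elsewhere.
    void capture_one_graph(ggml_backend_t backend /*, graph, inputs, ... */) {
        ggml_backend_metal_capture_next_compute(backend);

        // The next compute call starts the capture, pushes/pops one debug
        // group per node, and stops the capture when the graph finishes,
        // e.g. ggml_backend_graph_compute(backend, graph);
    }
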