llama_cpp 0.12.3 → 0.12.5

@@ -0,0 +1,46 @@
+ #pragma once
+
+ #include "ggml.h"
+ #include "ggml-backend.h"
+
+ #include <stdbool.h>
+ #include <stddef.h>
+ #include <stdint.h>
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ struct ggml_vk_device {
+ int index;
+ int type; // same as VkPhysicalDeviceType
+ size_t heapSize;
+ const char * name;
+ const char * vendor;
+ int subgroupSize;
+ uint64_t bufferAlignment;
+ uint64_t maxAlloc;
+ };
+
+ struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count);
+ bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name);
+ bool ggml_vk_has_vulkan(void);
+ bool ggml_vk_has_device(void);
+ struct ggml_vk_device ggml_vk_current_device(void);
+
+ //
+ // backend API
+ //
+
+ // forward declaration
+ typedef struct ggml_backend * ggml_backend_t;
+
+ GGML_API ggml_backend_t ggml_backend_kompute_init(int device);
+
+ GGML_API bool ggml_backend_is_kompute(ggml_backend_t backend);
+
+ GGML_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
+
+ #ifdef __cplusplus
+ }
+ #endif
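
The hunk above adds ggml-kompute.h, a new public header for the Kompute (Vulkan) backend: a device descriptor, enumeration helpers, and the backend entry points. A minimal usage sketch of these declarations (hypothetical; the header does not document ownership of the array returned by ggml_vk_available_devices, so the free() below is an assumption):

    #include <stdio.h>
    #include <stdlib.h>

    #include "ggml-backend.h"
    #include "ggml-kompute.h"

    int main(void) {
        if (!ggml_vk_has_vulkan()) {
            fprintf(stderr, "Vulkan is not available\n");
            return 1;
        }

        // enumerate devices that can satisfy the given memory requirement (0 = no requirement)
        size_t count = 0;
        struct ggml_vk_device * devs = ggml_vk_available_devices(0, &count);
        for (size_t i = 0; i < count; ++i) {
            printf("device %d: %s (%s), heap %zu bytes, subgroup size %d\n",
                   devs[i].index, devs[i].name, devs[i].vendor, devs[i].heapSize, devs[i].subgroupSize);
        }
        if (count == 0) {
            return 1;
        }

        const int device = devs[0].index;
        free(devs); // assumption: the caller releases the enumeration result

        ggml_backend_t backend = ggml_backend_kompute_init(device);
        if (backend == NULL || !ggml_backend_is_kompute(backend)) {
            return 1;
        }
        ggml_backend_free(backend); // existing ggml-backend API
        return 0;
    }
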
@@ -57,6 +57,9 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(voi
  // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
  GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);

+ // capture all command buffers committed the next time `ggml_backend_graph_compute` is called
+ GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
+
  #ifdef __cplusplus
  }
  #endif
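
The new hook is a one-shot flag: arm it, and the next ggml_backend_graph_compute call on that backend records a Metal GPU capture. A sketch of the intended call sequence (assuming a graph gf built elsewhere; ggml_backend_metal_init and ggml_backend_graph_compute are the existing entry points from ggml-metal.h and ggml-backend.h):

    ggml_backend_t backend = ggml_backend_metal_init();

    ggml_backend_metal_capture_next_compute(backend); // arm the one-shot capture
    ggml_backend_graph_compute(backend, gf);          // this compute is captured
    ggml_backend_graph_compute(backend, gf);          // later computes are not
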
@@ -24,19 +24,7 @@

  #define UNUSED(x) (void)(x)

- #define GGML_METAL_MAX_KERNELS 256
-
- struct ggml_metal_buffer {
- const char * name;
-
- void * data;
- size_t size;
-
- id<MTLBuffer> metal;
- };
-
  struct ggml_metal_kernel {
- id<MTLFunction> function;
  id<MTLComputePipelineState> pipeline;
  };

@@ -72,6 +60,7 @@ enum ggml_metal_kernel_type {
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
  GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
  GGML_METAL_KERNEL_TYPE_RMS_NORM,
  GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -93,6 +82,7 @@ enum ggml_metal_kernel_type {
  GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -110,6 +100,7 @@ enum ggml_metal_kernel_type {
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -124,6 +115,7 @@ enum ggml_metal_kernel_type {
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -138,10 +130,12 @@ enum ggml_metal_kernel_type {
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
  GGML_METAL_KERNEL_TYPE_ROPE_F16,
  GGML_METAL_KERNEL_TYPE_ALIBI_F32,
  GGML_METAL_KERNEL_TYPE_IM2COL_F16,
+ GGML_METAL_KERNEL_TYPE_IM2COL_F32,
  GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
  GGML_METAL_KERNEL_TYPE_PAD_F32,
  GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@@ -168,17 +162,15 @@ struct ggml_metal_context {

  id<MTLDevice> device;
  id<MTLCommandQueue> queue;
- id<MTLLibrary> library;

  dispatch_queue_t d_queue;

- int n_buffers;
- struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
-
- struct ggml_metal_kernel kernels[GGML_METAL_MAX_KERNELS];
+ struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];

  bool support_simdgroup_reduction;
  bool support_simdgroup_mm;
+
+ bool should_capture_next_compute;
  };

  // MSL code
@@ -242,26 +234,24 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  // Show all the Metal device instances in the system
  NSArray * devices = MTLCopyAllDevices();
  for (id<MTLDevice> device in devices) {
- NSString * s = [device name];
- GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
+ GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
  }
  [devices release]; // since it was created by a *Copy* C method
  #endif

  // Pick and show default Metal device
  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
- NSString * s = [device name];
- GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
+ GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);

  // Configure context
  struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
  ctx->device = device;
  ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
  ctx->queue = [ctx->device newCommandQueue];
- ctx->n_buffers = 0;
-
  ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

+ id<MTLLibrary> metal_library;
+
  // load library
  {
  NSBundle * bundle = nil;
@@ -276,7 +266,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  // pre-compiled library found
  NSURL * libURL = [NSURL fileURLWithPath:libPath];
  GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
- ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
+ metal_library = [ctx->device newLibraryWithURL:libURL error:&error];
  if (error) {
  GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
  return NULL;
@@ -318,7 +308,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {

  //[options setFastMathEnabled:false];

- ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+ metal_library = [ctx->device newLibraryWithSource:src options:options error:&error];
  if (error) {
  GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
  return NULL;
@@ -367,6 +357,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
  GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");

+ ctx->should_capture_next_compute = false;
+
  #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
  if (@available(macOS 10.12, iOS 16.0, *)) {
  GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
@@ -383,8 +375,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  {
  NSError * error = nil;

- for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
- ctx->kernels[i].function = nil;
+ for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
  ctx->kernels[i].pipeline = nil;
  }

@@ -396,10 +387,12 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  #define GGML_METAL_ADD_KERNEL(e, name, supported) \
  if (supported) { \
  struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
- kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \
- kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \
+ id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
+ kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
+ [metal_function release]; \
  if (error) { \
  GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+ [metal_library release]; \
  return NULL; \
  } \
  } else { \
@@ -439,6 +432,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
@@ -460,6 +454,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
  //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
@@ -477,6 +472,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
@@ -491,6 +487,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
@@ -505,10 +502,12 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
@@ -528,27 +527,17 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
  }

+ [metal_library release];
  return ctx;
  }

  static void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);

- for (int i = 0; i < ctx->n_buffers; ++i) {
- [ctx->buffers[i].metal release];
+ for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
+ [ctx->kernels[i].pipeline release];
  }

- for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
- if (ctx->kernels[i].pipeline) {
- [ctx->kernels[i].pipeline release];
- }
-
- if (ctx->kernels[i].function) {
- [ctx->kernels[i].function release];
- }
- }
-
- [ctx->library release];
  [ctx->queue release];
  [ctx->device release];

@@ -580,51 +569,30 @@ struct ggml_backend_metal_buffer_context {
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
  // Metal buffer based on the host memory pointer
  //
- static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
+ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) {
  //GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);

  const int64_t tsize = ggml_nbytes(t);

  ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;

- // compatibility with ggml-backend
- if (buffer && buffer->buft == ggml_backend_metal_buffer_type()) {
- struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;
-
- // find the view that contains the tensor fully
- for (int i = 0; i < buf_ctx->n_buffers; ++i) {
- const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
-
- //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
- if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
- *offs = (size_t) ioffs;
-
- //GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
-
- return buf_ctx->buffers[i].metal;
- }
- }
-
- GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
-
- return nil;
- }
+ struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;

  // find the view that contains the tensor fully
- for (int i = 0; i < ctx->n_buffers; ++i) {
- const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
+ for (int i = 0; i < buf_ctx->n_buffers; ++i) {
+ const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;

- //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
- if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
+ //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
+ if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
  *offs = (size_t) ioffs;

- //GGML_METAL_LOG_INFO("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
+ //GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);

- return ctx->buffers[i].metal;
+ return buf_ctx->buffers[i].metal;
  }
  }

- GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__);
+ GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);

  return nil;
  }
@@ -664,6 +632,10 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
  case GGML_OP_ALIBI:
  case GGML_OP_ROPE:
  case GGML_OP_IM2COL:
+ return true;
+ case GGML_OP_POOL_1D:
+ case GGML_OP_POOL_2D:
+ return false;
  case GGML_OP_UPSCALE:
  case GGML_OP_PAD:
  case GGML_OP_ARGSORT:
@@ -725,6 +697,20 @@ static bool ggml_metal_graph_compute(
  const int n_cb = ctx->n_cb;
  const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;

+ const bool should_capture = ctx->should_capture_next_compute;
+ if (should_capture) {
+ ctx->should_capture_next_compute = false;
+
+ MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
+ descriptor.captureObject = ctx->queue;
+
+ NSError * error = nil;
+ if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
+ GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
+ GGML_ASSERT(!"capture failed");
+ }
+ }
+
  id<MTLCommandBuffer> command_buffer_builder[n_cb];
  for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
  id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
@@ -733,6 +719,7 @@ static bool ggml_metal_graph_compute(
  // enqueue the command buffers in order to specify their execution order
  [command_buffer enqueue];
  }
+
  const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;

  dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
@@ -779,9 +766,9 @@ static bool ggml_metal_graph_compute(
  GGML_ASSERT(!"unsupported op");
  }

- #ifndef GGML_METAL_NDEBUG
- [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
- #endif
+ if (should_capture) {
+ [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
+ }

  const int64_t ne00 = src0 ? src0->ne[0] : 0;
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
@@ -817,9 +804,9 @@ static bool ggml_metal_graph_compute(
  const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
  const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;

- id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
- id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
- id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
+ id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
+ id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
+ id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;

  //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
  //if (src0) {
@@ -1308,6 +1295,7 @@ static bool ggml_metal_graph_compute(
  case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
+ case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
  default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
  }

@@ -1436,6 +1424,12 @@ static bool ggml_metal_graph_compute(
  nth1 = 16;
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
  } break;
+ case GGML_TYPE_IQ3_XXS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
+ } break;
  default:
  {
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1478,6 +1472,11 @@ static bool ggml_metal_graph_compute(
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
+ else if (src0t == GGML_TYPE_IQ3_XXS) {
+ const int mem_size = 256*4+128;
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
  else if (src0t == GGML_TYPE_Q4_K) {
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
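
The 256*4+128 threadgroup size is not explained here; a plausible reading (an assumption based on the iq3_xxs lookup tables in ggml-quants.c, not stated in this diff) is that the kernel stages the codebook in shared memory:

    // assumed breakdown of the 256*4+128 literal used above:
    //   256 grid entries * sizeof(uint32_t) = 1024 bytes (iq3xxs_grid)
    //   128 sign-mask bytes                 =  128 bytes (ksigns_iq2xs)
    //   total                               = 1152 bytes
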
@@ -1572,6 +1571,7 @@ static bool ggml_metal_graph_compute(
  case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
+ case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
  default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
  }

@@ -1601,7 +1601,7 @@ static bool ggml_metal_graph_compute(
  struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];

  size_t offs_src_cur = 0;
- id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);

  [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
  }
@@ -1703,6 +1703,12 @@ static bool ggml_metal_graph_compute(
  nth1 = 16;
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
  } break;
+ case GGML_TYPE_IQ3_XXS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
+ } break;
  default:
  {
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1746,7 +1752,7 @@ static bool ggml_metal_graph_compute(
  struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];

  size_t offs_src_cur = 0;
- id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);

  [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
  }
@@ -1761,6 +1767,11 @@ static bool ggml_metal_graph_compute(
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
+ else if (src2t == GGML_TYPE_IQ3_XXS) {
+ const int mem_size = 256*4+128;
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
  else if (src2t == GGML_TYPE_Q4_K) {
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
@@ -1801,6 +1812,7 @@ static bool ggml_metal_graph_compute(
  case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
+ case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
  case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
  default: GGML_ASSERT(false && "not implemented");
  }
@@ -2009,7 +2021,7 @@ static bool ggml_metal_graph_compute(
  {
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);

  const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
  const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
@@ -2017,6 +2029,7 @@ static bool ggml_metal_graph_compute(
  const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
  const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
  const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+
  const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;

  const int32_t N = src1->ne[is_2D ? 3 : 2];
@@ -2037,8 +2050,8 @@ static bool ggml_metal_graph_compute(

  id<MTLComputePipelineState> pipeline = nil;

- switch (src0->type) {
- case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
+ switch (dst->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
  case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
  default: GGML_ASSERT(false);
  };
@@ -2231,9 +2244,9 @@ static bool ggml_metal_graph_compute(
  }
  }

- #ifndef GGML_METAL_NDEBUG
- [encoder popDebugGroup];
- #endif
+ if (should_capture) {
+ [encoder popDebugGroup];
+ }
  }

  [encoder endEncoding];
@@ -2255,6 +2268,10 @@ static bool ggml_metal_graph_compute(
  }
  }

+ if (should_capture) {
+ [[MTLCaptureManager sharedCaptureManager] stopCapture];
+ }
+
  return true;
  }

@@ -2423,6 +2440,16 @@ GGML_CALL static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backen
  UNUSED(buft);
  }

+ GGML_CALL static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
+ id<MTLDevice> device = ggml_backend_metal_get_device();
+ size_t max_size = device.maxBufferLength;
+ ggml_backend_metal_free_device();
+
+ return max_size;
+
+ UNUSED(buft);
+ }
+
  GGML_CALL static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
  return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);

@@ -2441,6 +2468,7 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
  /* .get_name = */ ggml_backend_metal_buffer_type_get_name,
  /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
  /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
+ /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size,
  /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
  /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
  /* .is_host = */ ggml_backend_metal_buffer_type_is_host,
@@ -2615,6 +2643,13 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
  return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
  }

+ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
+ GGML_ASSERT(ggml_backend_is_metal(backend));
+
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+ ctx->should_capture_next_compute = true;
+ }
+
  GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning

  GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {