llama_cpp 0.3.6 → 0.3.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,29 +8,25 @@ extern "C" {
 
  #define GGML_CUDA_MAX_DEVICES 16
 
- void ggml_init_cublas(void);
- void ggml_cuda_set_tensor_split(const float * tensor_split);
-
- void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
- bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
- size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
- void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
-
- // TODO: export these with GGML_API
- void * ggml_cuda_host_malloc(size_t size);
- void ggml_cuda_host_free(void * ptr);
-
- void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
-
- void ggml_cuda_free_data(struct ggml_tensor * tensor);
- void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
- void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
- void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
- void ggml_cuda_set_main_device(int main_device);
- void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
- void ggml_cuda_set_scratch_size(size_t scratch_size);
- void ggml_cuda_free_scratch(void);
- bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
+ GGML_API void ggml_init_cublas(void);
+ GGML_API void * ggml_cuda_host_malloc(size_t size);
+ GGML_API void ggml_cuda_host_free(void * ptr);
+
+ GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+ GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
+ GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
+ GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
+ GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
+ GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
+ GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
+ GGML_API void ggml_cuda_set_main_device(int main_device);
+ GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
+ GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);
+ GGML_API void ggml_cuda_free_scratch(void);
+ GGML_API bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
+
+ GGML_API int ggml_cuda_get_device_count(void);
+ GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
 
  #ifdef __cplusplus
  }
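
The vendored CUDA header now exports every entry point with GGML_API and adds two device-query helpers. A minimal sketch of how a caller might use the new accessors, assuming the vendored header is on the include path and the library was built with cuBLAS support:

    #include <stdio.h>
    #include "ggml-cuda.h"   // vendored header shown in the hunk above

    int main(void) {
        ggml_init_cublas();  // initialize the CUDA/cuBLAS backend once

        const int n_devices = ggml_cuda_get_device_count();
        for (int i = 0; i < n_devices; ++i) {
            char desc[128];
            ggml_cuda_get_device_description(i, desc, sizeof(desc));
            printf("CUDA device %d: %s\n", i, desc);
        }
        return 0;
    }
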
@@ -63,10 +63,13 @@ void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
 
  // try to find operations that can be run concurrently in the graph
  // you should run it again if the topology of your graph changes
- void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+ void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool check_mem);
 
- // if the graph has been optimized for concurrently dispatch
- bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+ // if the graph has been optimized for concurrently dispatch, return length of the concur_list if optimized
+ int ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
+ // output the concur_list for ggml_alloc
+ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
 
  // same as ggml_graph_compute but uses Metal
  // creates gf->n_threads command buffers in parallel
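
In this hunk ggml_metal_graph_find_concurrency() gains a check_mem flag, ggml_metal_if_optimized() now returns the length of the concurrency list rather than a bool, and ggml_metal_get_concur_list() exposes the list itself. A sketch of a call site, assuming an already-initialized Metal context and a built compute graph:

    #include <stdio.h>
    #include "ggml-metal.h"  // vendored header shown in the hunk above

    static void dump_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf) {
        // 0.3.8: the third argument toggles the overlapping-memory check
        ggml_metal_graph_find_concurrency(ctx, gf, true);

        const int n = ggml_metal_if_optimized(ctx);  // 0 means no concurrent schedule was found
        if (n > 0) {
            const int * list = ggml_metal_get_concur_list(ctx);
            for (int i = 0; i < n; ++i) {
                // entries are indices into gf->nodes; -1 marks a barrier between levels
                fprintf(stderr, "%d ", list[i]);
            }
            fprintf(stderr, "\n");
        }
    }
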
@@ -5,7 +5,11 @@
  #import <Foundation/Foundation.h>
 
  #import <Metal/Metal.h>
- #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
+
+ #undef MIN
+ #undef MAX
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
  #ifdef GGML_METAL_NDEBUG
  #define metal_printf(...)
@@ -15,6 +19,8 @@
 
  #define UNUSED(x) (void)(x)
 
+ #define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
+
  struct ggml_metal_buffer {
  const char * name;
 
@@ -36,7 +42,7 @@ struct ggml_metal_context {
  int n_buffers;
  struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
- int concur_list[GGML_MAX_NODES];
+ int concur_list[GGML_MAX_CONCUR];
  int concur_list_len;
 
  // custom kernels
@@ -72,6 +78,14 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
  GGML_METAL_DECL_KERNEL(rope);
  GGML_METAL_DECL_KERNEL(alibi_f32);
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -103,13 +117,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->n_buffers = 0;
  ctx->concur_list_len = 0;
 
- // determine if we can use MPS
- if (MPSSupportsMTLDevice(ctx->device)) {
- fprintf(stderr, "%s: using MPS\n", __func__);
- } else {
- fprintf(stderr, "%s: not using MPS\n", __func__);
- GGML_ASSERT(false && "MPS not supported");
- }
 
  #if 0
  // compile from source string and show compile log
@@ -119,7 +126,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
  if (error) {
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
- exit(1);
+ return NULL;
  }
  }
  #else
@@ -137,7 +144,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
  if (error) {
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
- exit(1);
+ return NULL;
  }
 
  #ifdef GGML_QKK_64
@@ -149,17 +156,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  #endif
  if (error) {
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
- exit(1);
+ return NULL;
  }
  }
  #endif
 
  // load kernels
  {
+ NSError * error = nil;
  #define GGML_METAL_ADD_KERNEL(name) \
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
- ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:nil]; \
- fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
+ ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
+ fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \
+ if (error) { \
+ fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+ return NULL; \
+ }
 
  GGML_METAL_ADD_KERNEL(add);
  GGML_METAL_ADD_KERNEL(add_row);
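
Together with the removal of the MPS availability assert above, these hunks change ggml_metal_init() from aborting on failure (exit(1) / GGML_ASSERT) to returning NULL. A hedged sketch of the caller-side check this implies, assuming the usual n_cb argument:

    #include <stdio.h>
    #include "ggml-metal.h"

    // Sketch: callers should now treat a NULL context as "Metal backend unavailable"
    // rather than relying on the library to terminate the process.
    static struct ggml_metal_context * init_metal_or_warn(void) {
        struct ggml_metal_context * ctx = ggml_metal_init(1 /* n_cb */);
        if (ctx == NULL) {
            fprintf(stderr, "ggml_metal_init() failed; falling back to the CPU backend\n");
        }
        return ctx;
    }
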
@@ -189,6 +201,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
  GGML_METAL_ADD_KERNEL(rope);
  GGML_METAL_ADD_KERNEL(alibi_f32);
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -221,11 +241,12 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
  ctx->n_cb = n_cb;
  }
 
- bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
- if (ctx->concur_list_len) {
- return true;
- }
- return false;
+ int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+ return ctx->concur_list_len;
+ }
+
+ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
+ return ctx->concur_list;
  }
 
  // finds the Metal buffer that contains the tensor data on the GPU device
@@ -368,17 +389,17 @@ void ggml_metal_get_tensor(
 
  void ggml_metal_graph_find_concurrency(
  struct ggml_metal_context * ctx,
- struct ggml_cgraph * gf) {
+ struct ggml_cgraph * gf, bool check_mem) {
  int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
- int nodes_unused[GGML_MAX_NODES];
+ int nodes_unused[GGML_MAX_CONCUR];
 
- for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
- for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
+ for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
+ for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
  ctx->concur_list_len = 0;
 
- int n_left = gf->n_nodes;
- int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
- int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
+ int n_left = gf->n_nodes;
+ int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
+ int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
 
  while (n_left > 0) {
  // number of nodes at a layer (that can be issued concurrently)
@@ -386,28 +407,40 @@ void ggml_metal_graph_find_concurrency(
  for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
  if (nodes_unused[i]) {
  // if the requirements for gf->nodes[i] are satisfied
- int exe_flag=1;
+ int exe_flag = 1;
+
  // scan all srcs
  for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
  struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
  if (src_cur) {
  // if is leaf nodes it's satisfied.
- if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
+ // TODO: ggml_is_leaf()
+ if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
+ continue;
+ }
 
  // otherwise this src should be the output from previous nodes.
  int is_found = 0;
+
  // scan 2*search_depth back because we inserted barrier.
- for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
- if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
+ //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+ for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
+ if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
+ is_found = 1;
+ break;
+ }
+ }
+ if (is_found == 0) {
+ exe_flag = 0;
+ break;
  }
- if (is_found == 0) {exe_flag = 0; break;}
  }
  }
- if (exe_flag) {
+ if (exe_flag && check_mem) {
  // check if nodes[i]'s data will be overwritten by a node before nodes[i].
  // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
  int64_t data_start = (int64_t) gf->nodes[i]->data;
- int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
+ int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
  for (int j = n_start; j < i; j++) {
  if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
  && gf->nodes[j]->op != GGML_OP_VIEW \
@@ -416,9 +449,9 @@ void ggml_metal_graph_find_concurrency(
  if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
  ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
  continue;
- } else {
- exe_flag = 0;
  }
+
+ exe_flag = 0;
  }
  }
  }
@@ -435,11 +468,13 @@ void ggml_metal_graph_find_concurrency(
  ctx->concur_list[level_pos + concurrency] = -1;
  ctx->concur_list_len++;
  // jump all sorted nodes at nodes_bak
- while (!nodes_unused[n_start]) {n_start++;}
+ while (!nodes_unused[n_start]) {
+ n_start++;
+ }
  level_pos += concurrency + 1;
  }
 
- if (ctx->concur_list_len > GGML_MAX_NODES) {
+ if (ctx->concur_list_len > GGML_MAX_CONCUR) {
  fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
  }
  }
@@ -453,7 +488,7 @@ void ggml_metal_graph_compute(
  // else fallback to serial dispatch
  MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
 
- const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+ const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
 
  const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
  edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
@@ -485,7 +520,7 @@ void ggml_metal_graph_compute(
 
  id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
 
- id<MTLComputeCommandEncoder> encoder = nil;
+ id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
 
  const int node_start = (cb_idx + 0) * n_nodes_per_cb;
  const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
@@ -494,10 +529,6 @@ void ggml_metal_graph_compute(
  const int i = has_concur ? ctx->concur_list[ind] : ind;
 
  if (i == -1) {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- continue;
- }
  [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
  continue;
  }
@@ -571,10 +602,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_ADD:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  if (ggml_nelements(src1) == ne10) {
  // src1 is a row
  [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -592,10 +619,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_MUL:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  if (ggml_nelements(src1) == ne10) {
  // src1 is a row
  [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -613,10 +636,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_SCALE:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  const float scale = *(const float *) src1->data;
 
  [encoder setComputePipelineState:ctx->pipeline_scale];
@@ -632,10 +651,6 @@ void ggml_metal_graph_compute(
  switch (ggml_get_unary_op(gf->nodes[i])) {
  case GGML_UNARY_OP_SILU:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  [encoder setComputePipelineState:ctx->pipeline_silu];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -646,10 +661,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_UNARY_OP_RELU:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  [encoder setComputePipelineState:ctx->pipeline_relu];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -660,10 +671,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_UNARY_OP_GELU:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  [encoder setComputePipelineState:ctx->pipeline_gelu];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -680,10 +687,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_SOFT_MAX:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  const int nth = 32;
 
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
@@ -698,10 +701,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_DIAG_MASK_INF:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  const int n_past = ((int32_t *)(dst->op_params))[0];
 
  [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
@@ -719,53 +718,43 @@ void ggml_metal_graph_compute(
 
  GGML_ASSERT(ne00 == ne10);
  // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
+ uint gqa = ne12/ne02;
  GGML_ASSERT(ne03 == ne13);
 
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
  if (ggml_is_contiguous(src0) &&
  ggml_is_contiguous(src1) &&
- (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
-
- if (encoder != nil) {
- [encoder endEncoding];
- encoder = nil;
- }
-
- MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
- MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
-
- // for F32 x F32 we use MPS
- MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
- matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
-
- MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
- matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
-
- MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
- matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
-
- MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
- initWithDevice:ctx->device transposeLeft:false transposeRight:true
- resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
-
- // we need to do ne12 multiplications
- // TODO: is there a way to do this in parallel - currently very slow ..
- // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
- for (int64_t i02 = 0; i02 < ne12; ++i02) {
- size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
- size_t offs_src1_cur = offs_src1 + i02*nb12;
- size_t offs_dst_cur = offs_dst + i02*nb2;
-
- MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
- MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
- MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
-
- [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
- }
- } else {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ src1t == GGML_TYPE_F32 &&
+ [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ ne00%32 == 0 &&
+ ne11 > 1) {
+ switch (src0->type) {
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
+ default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
+ }
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  }
-
+ else {
  int nth0 = 32;
  int nth1 = 1;
 
@@ -864,23 +853,24 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
 
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
  src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q3_K) {
  #ifdef GGML_QKK_64
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #else
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #endif
  }
  else if (src0t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  } else {
  [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
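
The two hunks above replace the MPS matrix multiplication with custom mul_mm Metal kernels, gated on an Apple7-family GPU, contiguous inputs, F32 src1, and ne00 divisible by 32; gqa = ne12/ne02 carries the grouped-query-attention broadcast ratio into the kernels, and the matrix-vector dispatches now use ne12 as the batch dimension. A small sketch of the threadgroup grid the new mul_mm dispatch requests (names mirror the diff; the 32x64 tiling is read off the dispatch call above):

    #include <stdint.h>

    // ceil-division grid used by the new mul_mm path:
    // x covers src1 rows in tiles of 32, y covers src0 rows in tiles of 64,
    // z is one threadgroup per batch/broadcast plane (ne12), 128 threads each.
    static void mul_mm_grid(int64_t ne01, int64_t ne11, int64_t ne12,
                            int64_t * tg_x, int64_t * tg_y, int64_t * tg_z) {
        *tg_x = (ne11 + 31) / 32;
        *tg_y = (ne01 + 63) / 64;
        *tg_z = ne12;
    }
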
@@ -889,10 +879,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_GET_ROWS:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  switch (src0->type) {
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
@@ -918,10 +904,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_RMS_NORM:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  float eps;
  memcpy(&eps, dst->op_params, sizeof(float));
 
@@ -941,10 +923,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_NORM:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  const float eps = 1e-5f;
 
  const int nth = 256;
@@ -963,10 +941,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_ALIBI:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  GGML_ASSERT((src0t == GGML_TYPE_F32));
 
  const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
@@ -1006,10 +980,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_ROPE:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  const int n_past = ((int32_t *) dst->op_params)[0];
  const int n_dims = ((int32_t *) dst->op_params)[1];
  const int mode = ((int32_t *) dst->op_params)[2];
@@ -1050,10 +1020,6 @@ void ggml_metal_graph_compute(
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  const int nth = 32;
 
  switch (src0t) {