llama_cpp 0.3.6 → 0.3.8

@@ -8,29 +8,25 @@ extern "C" {
 
  #define GGML_CUDA_MAX_DEVICES 16
 
- void ggml_init_cublas(void);
- void ggml_cuda_set_tensor_split(const float * tensor_split);
-
- void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
- bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
- size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
- void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
-
- // TODO: export these with GGML_API
- void * ggml_cuda_host_malloc(size_t size);
- void ggml_cuda_host_free(void * ptr);
-
- void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
-
- void ggml_cuda_free_data(struct ggml_tensor * tensor);
- void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
- void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
- void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
- void ggml_cuda_set_main_device(int main_device);
- void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
- void ggml_cuda_set_scratch_size(size_t scratch_size);
- void ggml_cuda_free_scratch(void);
- bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
+ GGML_API void ggml_init_cublas(void);
+ GGML_API void * ggml_cuda_host_malloc(size_t size);
+ GGML_API void ggml_cuda_host_free(void * ptr);
+
+ GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+ GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
+ GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
+ GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
+ GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
+ GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
+ GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
+ GGML_API void ggml_cuda_set_main_device(int main_device);
+ GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
+ GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);
+ GGML_API void ggml_cuda_free_scratch(void);
+ GGML_API bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
+
+ GGML_API int ggml_cuda_get_device_count(void);
+ GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
 
  #ifdef __cplusplus
  }
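The hunk above is from the vendored ggml-cuda.h: the CUDA entry points are now exported with GGML_API, the old standalone mul/mul_mat helpers are dropped, and two device-query functions are added. A minimal sketch of how host code might use the new queries (the include path, the buffer size, and calling ggml_init_cublas first are assumptions, not shown in the diff):

```c
#include <stdio.h>

#include "ggml-cuda.h" // vendored header shown in the hunk above (path assumed)

int main(void) {
    ggml_init_cublas(); // assumed here; the queries below may not strictly require it

    const int n_devices = ggml_cuda_get_device_count();
    for (int i = 0; i < n_devices; ++i) {
        char description[256]; // buffer size chosen arbitrarily for the example
        ggml_cuda_get_device_description(i, description, sizeof(description));
        printf("CUDA device %d: %s\n", i, description);
    }

    return 0;
}
```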
@@ -63,10 +63,13 @@ void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
 
  // try to find operations that can be run concurrently in the graph
  // you should run it again if the topology of your graph changes
- void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+ void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool check_mem);
 
- // if the graph has been optimized for concurrently dispatch
- bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+ // if the graph has been optimized for concurrently dispatch, return length of the concur_list if optimized
+ int ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
+ // output the concur_list for ggml_alloc
+ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
 
  // same as ggml_graph_compute but uses Metal
  // creates gf->n_threads command buffers in parallel
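This hunk is the matching ggml-metal.h change: `ggml_metal_graph_find_concurrency` gains a `check_mem` flag (it controls the write-hazard check, as the implementation later in the diff shows), `ggml_metal_if_optimized` now returns the length of the concurrency list instead of a bool, and the new `ggml_metal_get_concur_list` exposes the list itself for the allocator. A hedged sketch of how a caller adapts (`ctx` and `gf` are assumed to exist already):

```c
// Search the graph for operations that can be dispatched concurrently;
// `true` keeps the memory-overlap (write-hazard) check enabled.
ggml_metal_graph_find_concurrency(ctx, gf, true);

// The old bool predicate is now an int: 0 means "not optimized",
// otherwise it is the number of entries in ctx->concur_list.
const int n_concur = ggml_metal_if_optimized(ctx);
if (n_concur > 0) {
    const int * concur_list = ggml_metal_get_concur_list(ctx);
    // e.g. hand concur_list / n_concur to the graph allocator ("ggml_alloc"
    // per the header comment); entries of -1 mark barriers between levels.
    (void) concur_list;
}
```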
@@ -5,7 +5,11 @@
  #import <Foundation/Foundation.h>
 
  #import <Metal/Metal.h>
- #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
+
+ #undef MIN
+ #undef MAX
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
  #ifdef GGML_METAL_NDEBUG
  #define metal_printf(...)
@@ -15,6 +19,8 @@
 
  #define UNUSED(x) (void)(x)
 
+ #define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
+
  struct ggml_metal_buffer {
  const char * name;
 
@@ -36,7 +42,7 @@ struct ggml_metal_context {
  int n_buffers;
  struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
- int concur_list[GGML_MAX_NODES];
+ int concur_list[GGML_MAX_CONCUR];
  int concur_list_len;
 
  // custom kernels
@@ -72,6 +78,14 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
  GGML_METAL_DECL_KERNEL(rope);
  GGML_METAL_DECL_KERNEL(alibi_f32);
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -103,13 +117,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->n_buffers = 0;
  ctx->concur_list_len = 0;
 
- // determine if we can use MPS
- if (MPSSupportsMTLDevice(ctx->device)) {
- fprintf(stderr, "%s: using MPS\n", __func__);
- } else {
- fprintf(stderr, "%s: not using MPS\n", __func__);
- GGML_ASSERT(false && "MPS not supported");
- }
 
  #if 0
  // compile from source string and show compile log
@@ -119,7 +126,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
  if (error) {
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
- exit(1);
+ return NULL;
  }
  }
  #else
@@ -137,7 +144,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
  if (error) {
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
- exit(1);
+ return NULL;
  }
 
  #ifdef GGML_QKK_64
@@ -149,17 +156,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  #endif
  if (error) {
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
- exit(1);
+ return NULL;
  }
  }
  #endif
 
  // load kernels
  {
+ NSError * error = nil;
  #define GGML_METAL_ADD_KERNEL(name) \
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
- ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:nil]; \
- fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
+ ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
+ fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \
+ if (error) { \
+ fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+ return NULL; \
+ }
 
  GGML_METAL_ADD_KERNEL(add);
  GGML_METAL_ADD_KERNEL(add_row);
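The ggml-metal.m hunks above drop the MetalPerformanceShaders import and the MPS availability check, and turn every failure in `ggml_metal_init` (shader-source loading, library compilation, and now per-kernel pipeline creation) into a `return NULL` instead of `exit(1)`. Callers should therefore check the result; a minimal sketch under that assumption:

```c
#include <stdio.h>

#include "ggml-metal.h" // vendored header (path assumed)

// Wraps the new error handling: ggml_metal_init() now reports failure by
// returning NULL rather than terminating the process.
static struct ggml_metal_context * init_metal_or_warn(int n_cb) {
    struct ggml_metal_context * ctx = ggml_metal_init(n_cb);
    if (ctx == NULL) {
        fprintf(stderr, "failed to initialize the Metal backend, falling back to CPU\n");
    }
    return ctx;
}
```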
@@ -189,6 +201,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
  GGML_METAL_ADD_KERNEL(rope);
  GGML_METAL_ADD_KERNEL(alibi_f32);
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -221,11 +241,12 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
  ctx->n_cb = n_cb;
  }
 
- bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
- if (ctx->concur_list_len) {
- return true;
- }
- return false;
+ int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+ return ctx->concur_list_len;
+ }
+
+ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
+ return ctx->concur_list;
  }
 
  // finds the Metal buffer that contains the tensor data on the GPU device
@@ -368,17 +389,17 @@ void ggml_metal_get_tensor(
 
  void ggml_metal_graph_find_concurrency(
  struct ggml_metal_context * ctx,
- struct ggml_cgraph * gf) {
+ struct ggml_cgraph * gf, bool check_mem) {
  int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
- int nodes_unused[GGML_MAX_NODES];
+ int nodes_unused[GGML_MAX_CONCUR];
 
- for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
- for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
+ for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
+ for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
  ctx->concur_list_len = 0;
 
- int n_left = gf->n_nodes;
- int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
- int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
+ int n_left = gf->n_nodes;
+ int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
+ int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
 
  while (n_left > 0) {
  // number of nodes at a layer (that can be issued concurrently)
@@ -386,28 +407,40 @@ void ggml_metal_graph_find_concurrency(
  for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
  if (nodes_unused[i]) {
  // if the requirements for gf->nodes[i] are satisfied
- int exe_flag=1;
+ int exe_flag = 1;
+
  // scan all srcs
  for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
  struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
  if (src_cur) {
  // if is leaf nodes it's satisfied.
- if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
+ // TODO: ggml_is_leaf()
+ if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
+ continue;
+ }
 
  // otherwise this src should be the output from previous nodes.
  int is_found = 0;
+
  // scan 2*search_depth back because we inserted barrier.
- for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
- if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
+ //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+ for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
+ if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
+ is_found = 1;
+ break;
+ }
+ }
+ if (is_found == 0) {
+ exe_flag = 0;
+ break;
  }
- if (is_found == 0) {exe_flag = 0; break;}
  }
  }
- if (exe_flag) {
+ if (exe_flag && check_mem) {
  // check if nodes[i]'s data will be overwritten by a node before nodes[i].
  // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
  int64_t data_start = (int64_t) gf->nodes[i]->data;
- int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
+ int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
  for (int j = n_start; j < i; j++) {
  if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
  && gf->nodes[j]->op != GGML_OP_VIEW \
@@ -416,9 +449,9 @@ void ggml_metal_graph_find_concurrency(
  if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
  ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
  continue;
- } else {
- exe_flag = 0;
  }
+
+ exe_flag = 0;
  }
  }
  }
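The concurrency search only runs its write-hazard test when `check_mem` is set: a candidate node is rejected if any earlier, still-pending node (other than view-like ops such as GGML_OP_RESHAPE and GGML_OP_VIEW) writes into the byte range `[data, data + ggml_nbytes)` of the candidate. The test is the usual half-open interval overlap; a standalone restatement of the same predicate (the helper name is hypothetical):

```c
#include <stdbool.h>
#include <stdint.h>

// Two half-open byte ranges [a_start, a_start + a_len) and
// [b_start, b_start + b_len) overlap unless one ends at or before the
// start of the other -- the same condition the loop above checks.
static bool ranges_overlap(int64_t a_start, int64_t a_len,
                           int64_t b_start, int64_t b_len) {
    if (b_start >= a_start + a_len) return false; // b begins after a ends
    if (b_start + b_len <= a_start) return false; // b ends before a begins
    return true;
}
```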
@@ -435,11 +468,13 @@ void ggml_metal_graph_find_concurrency(
  ctx->concur_list[level_pos + concurrency] = -1;
  ctx->concur_list_len++;
  // jump all sorted nodes at nodes_bak
- while (!nodes_unused[n_start]) {n_start++;}
+ while (!nodes_unused[n_start]) {
+ n_start++;
+ }
  level_pos += concurrency + 1;
  }
 
- if (ctx->concur_list_len > GGML_MAX_NODES) {
+ if (ctx->concur_list_len > GGML_MAX_CONCUR) {
  fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
  }
  }
@@ -453,7 +488,7 @@ void ggml_metal_graph_compute(
  // else fallback to serial dispatch
  MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
 
- const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+ const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
 
  const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
  edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
@@ -485,7 +520,7 @@ void ggml_metal_graph_compute(
 
  id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
 
- id<MTLComputeCommandEncoder> encoder = nil;
+ id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
 
  const int node_start = (cb_idx + 0) * n_nodes_per_cb;
  const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
@@ -494,10 +529,6 @@ void ggml_metal_graph_compute(
  const int i = has_concur ? ctx->concur_list[ind] : ind;
 
  if (i == -1) {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- continue;
- }
  [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
  continue;
  }
@@ -571,10 +602,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_ADD:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  if (ggml_nelements(src1) == ne10) {
  // src1 is a row
  [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -592,10 +619,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_MUL:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  if (ggml_nelements(src1) == ne10) {
  // src1 is a row
  [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -613,10 +636,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_SCALE:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  const float scale = *(const float *) src1->data;
 
  [encoder setComputePipelineState:ctx->pipeline_scale];
@@ -632,10 +651,6 @@ void ggml_metal_graph_compute(
  switch (ggml_get_unary_op(gf->nodes[i])) {
  case GGML_UNARY_OP_SILU:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  [encoder setComputePipelineState:ctx->pipeline_silu];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -646,10 +661,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_UNARY_OP_RELU:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  [encoder setComputePipelineState:ctx->pipeline_relu];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -660,10 +671,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_UNARY_OP_GELU:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  [encoder setComputePipelineState:ctx->pipeline_gelu];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -680,10 +687,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_SOFT_MAX:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  const int nth = 32;
 
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
@@ -698,10 +701,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_DIAG_MASK_INF:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  const int n_past = ((int32_t *)(dst->op_params))[0];
 
  [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
@@ -719,53 +718,43 @@ void ggml_metal_graph_compute(
 
  GGML_ASSERT(ne00 == ne10);
  // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
+ uint gqa = ne12/ne02;
  GGML_ASSERT(ne03 == ne13);
 
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
  if (ggml_is_contiguous(src0) &&
  ggml_is_contiguous(src1) &&
- (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
-
- if (encoder != nil) {
- [encoder endEncoding];
- encoder = nil;
- }
-
- MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
- MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
-
- // for F32 x F32 we use MPS
- MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
- matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
-
- MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
- matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
-
- MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
- matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
-
- MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
- initWithDevice:ctx->device transposeLeft:false transposeRight:true
- resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
-
- // we need to do ne12 multiplications
- // TODO: is there a way to do this in parallel - currently very slow ..
- // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
- for (int64_t i02 = 0; i02 < ne12; ++i02) {
- size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
- size_t offs_src1_cur = offs_src1 + i02*nb12;
- size_t offs_dst_cur = offs_dst + i02*nb2;
-
- MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
- MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
- MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
-
- [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
- }
- } else {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ src1t == GGML_TYPE_F32 &&
+ [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ ne00%32 == 0 &&
+ ne11 > 1) {
+ switch (src0->type) {
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
+ default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
+ }
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  }
-
+ else {
  int nth0 = 32;
  int nth1 = 1;
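This is the core of the release: the MPS-based matrix multiplication is removed, along with the lazy compute-encoder creation that accommodated it, and a family of tiled `mul_mm_*` compute kernels takes its place. The new path is only taken on Apple7-family GPUs (A14/M1 and newer) and only when both operands are contiguous, src1 is F32, ne00 is a multiple of 32, and src1 has more than one row; everything else falls back to the existing matrix-vector kernels. The dispatch of (ne11+31)/32 by (ne01+63)/64 threadgroups of 128 threads suggests each threadgroup produces a 32x64 tile of the output. A plain-C restatement of the gate (the function and its bool parameters are illustrative; the real check is inline above):

```c
#include <stdbool.h>
#include <stdint.h>

// Mirrors the condition in the diff that selects the tiled
// matrix-matrix kernel over the matrix-vector fallback.
static bool use_mul_mm(bool src0_contiguous, bool src1_contiguous,
                       bool src1_is_f32, bool gpu_is_apple7_family,
                       int64_t ne00, int64_t ne11) {
    return src0_contiguous      &&
           src1_contiguous      &&
           src1_is_f32          &&
           gpu_is_apple7_family &&
           ne00 % 32 == 0       &&
           ne11 > 1;
}
```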
 
@@ -864,23 +853,24 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
 
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
  src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q3_K) {
  #ifdef GGML_QKK_64
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #else
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #endif
  }
  else if (src0t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  } else {
  [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
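In the matrix-vector fallback the kernels now receive `gqa = ne12/ne02` at buffer index 17 and the threadgroup grid gains `ne12` in the Z dimension, so a batch of src1 heads can be broadcast over the smaller number of src0 heads, as grouped-query attention needs; the removed MPS loop did the equivalent with `i02/(ne12/ne02)`. A tiny sketch of that index mapping as I read it (the kernel-side naming is an assumption):

```c
#include <stdint.h>

// Map a src1/dst batch index i12 in [0, ne12) to the src0 batch it reads
// from when src0 only has ne02 = ne12 / gqa batches -- the same ratio the
// removed MPS path applied via i02/(ne12/ne02).
static int64_t src0_batch_index(int64_t i12, int64_t gqa) {
    return i12 / gqa;
}
```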
@@ -889,10 +879,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_GET_ROWS:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  switch (src0->type) {
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
@@ -918,10 +904,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_RMS_NORM:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  float eps;
  memcpy(&eps, dst->op_params, sizeof(float));
 
@@ -941,10 +923,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_NORM:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  const float eps = 1e-5f;
 
  const int nth = 256;
@@ -963,10 +941,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_ALIBI:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  GGML_ASSERT((src0t == GGML_TYPE_F32));
 
  const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
@@ -1006,10 +980,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_ROPE:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  const int n_past = ((int32_t *) dst->op_params)[0];
  const int n_dims = ((int32_t *) dst->op_params)[1];
  const int mode = ((int32_t *) dst->op_params)[2];
@@ -1050,10 +1020,6 @@ void ggml_metal_graph_compute(
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  const int nth = 32;
 
  switch (src0t) {