llama_cpp 0.12.0 → 0.12.2

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.
@@ -24,7 +24,7 @@
 
  #define UNUSED(x) (void)(x)
 
- #define GGML_MAX_CONCUR (2*GGML_DEFAULT_GRAPH_SIZE)
+ #define GGML_METAL_MAX_KERNELS 256
 
  struct ggml_metal_buffer {
  const char * name;
@@ -35,6 +35,134 @@ struct ggml_metal_buffer {
  id<MTLBuffer> metal;
  };
 
+ struct ggml_metal_kernel {
+ id<MTLFunction> function;
+ id<MTLComputePipelineState> pipeline;
+ };
+
+ enum ggml_metal_kernel_type {
+ GGML_METAL_KERNEL_TYPE_ADD,
+ GGML_METAL_KERNEL_TYPE_ADD_ROW,
+ GGML_METAL_KERNEL_TYPE_MUL,
+ GGML_METAL_KERNEL_TYPE_MUL_ROW,
+ GGML_METAL_KERNEL_TYPE_DIV,
+ GGML_METAL_KERNEL_TYPE_DIV_ROW,
+ GGML_METAL_KERNEL_TYPE_SCALE,
+ GGML_METAL_KERNEL_TYPE_SCALE_4,
+ GGML_METAL_KERNEL_TYPE_TANH,
+ GGML_METAL_KERNEL_TYPE_RELU,
+ GGML_METAL_KERNEL_TYPE_GELU,
+ GGML_METAL_KERNEL_TYPE_GELU_QUICK,
+ GGML_METAL_KERNEL_TYPE_SILU,
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX,
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
+ GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
+ GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
+ GGML_METAL_KERNEL_TYPE_RMS_NORM,
+ GGML_METAL_KERNEL_TYPE_GROUP_NORM,
+ GGML_METAL_KERNEL_TYPE_NORM,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
+ //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
+ //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
+ //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_ROPE_F32,
+ GGML_METAL_KERNEL_TYPE_ROPE_F16,
+ GGML_METAL_KERNEL_TYPE_ALIBI_F32,
+ GGML_METAL_KERNEL_TYPE_IM2COL_F16,
+ GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
+ GGML_METAL_KERNEL_TYPE_PAD_F32,
+ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
+ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
+ GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
+ //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
+ //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
+ GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
+ GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
+ GGML_METAL_KERNEL_TYPE_CONCAT,
+ GGML_METAL_KERNEL_TYPE_SQR,
+ GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+
+ GGML_METAL_KERNEL_TYPE_COUNT
+ };
+
  struct ggml_metal_context {
  int n_cb;
 
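Note: the hunk above replaces the old per-kernel macro fields with a single table of `struct ggml_metal_kernel` entries indexed by `enum ggml_metal_kernel_type`. A minimal sketch of the lookup pattern this enables is shown here; `ggml_metal_get_pipeline` is a hypothetical helper written for illustration, not part of this diff:

    // Hypothetical helper: fetch a pipeline from the enum-indexed table.
    // Entries whose GPU requirements were not met at init time stay nil.
    static id<MTLComputePipelineState> ggml_metal_get_pipeline(
            struct ggml_metal_context * ctx, enum ggml_metal_kernel_type t) {
        id<MTLComputePipelineState> pipeline = ctx->kernels[t].pipeline;
        if (pipeline == nil) {
            GGML_METAL_LOG_ERROR("%s: kernel %d not loaded\n", __func__, (int) t);
        }
        return pipeline;
    }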
@@ -42,132 +170,15 @@ struct ggml_metal_context {
  id<MTLCommandQueue> queue;
  id<MTLLibrary> library;
 
- id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
- id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
-
  dispatch_queue_t d_queue;
 
  int n_buffers;
  struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
- int concur_list[GGML_MAX_CONCUR];
- int concur_list_len;
-
- // custom kernels
- #define GGML_METAL_DECL_KERNEL(name) \
- id<MTLFunction> function_##name; \
- id<MTLComputePipelineState> pipeline_##name
-
- GGML_METAL_DECL_KERNEL(add);
- GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
- GGML_METAL_DECL_KERNEL(mul);
- GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
- GGML_METAL_DECL_KERNEL(div);
- GGML_METAL_DECL_KERNEL(div_row);
- GGML_METAL_DECL_KERNEL(scale);
- GGML_METAL_DECL_KERNEL(scale_4);
- GGML_METAL_DECL_KERNEL(tanh);
- GGML_METAL_DECL_KERNEL(relu);
- GGML_METAL_DECL_KERNEL(gelu);
- GGML_METAL_DECL_KERNEL(gelu_quick);
- GGML_METAL_DECL_KERNEL(silu);
- GGML_METAL_DECL_KERNEL(soft_max);
- GGML_METAL_DECL_KERNEL(soft_max_4);
- GGML_METAL_DECL_KERNEL(diag_mask_inf);
- GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
- GGML_METAL_DECL_KERNEL(get_rows_f32);
- GGML_METAL_DECL_KERNEL(get_rows_f16);
- GGML_METAL_DECL_KERNEL(get_rows_q4_0);
- GGML_METAL_DECL_KERNEL(get_rows_q4_1);
- GGML_METAL_DECL_KERNEL(get_rows_q5_0);
- GGML_METAL_DECL_KERNEL(get_rows_q5_1);
- GGML_METAL_DECL_KERNEL(get_rows_q8_0);
- GGML_METAL_DECL_KERNEL(get_rows_q2_K);
- GGML_METAL_DECL_KERNEL(get_rows_q3_K);
- GGML_METAL_DECL_KERNEL(get_rows_q4_K);
- GGML_METAL_DECL_KERNEL(get_rows_q5_K);
- GGML_METAL_DECL_KERNEL(get_rows_q6_K);
- GGML_METAL_DECL_KERNEL(get_rows_i32);
- GGML_METAL_DECL_KERNEL(rms_norm);
- GGML_METAL_DECL_KERNEL(group_norm);
- GGML_METAL_DECL_KERNEL(norm);
- GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
- GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
- GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
- GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
- //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
- GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
- //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
- //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
- GGML_METAL_DECL_KERNEL(rope_f32);
- GGML_METAL_DECL_KERNEL(rope_f16);
- GGML_METAL_DECL_KERNEL(alibi_f32);
- GGML_METAL_DECL_KERNEL(im2col_f16);
- GGML_METAL_DECL_KERNEL(upscale_f32);
- GGML_METAL_DECL_KERNEL(pad_f32);
- GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
- GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
- GGML_METAL_DECL_KERNEL(leaky_relu_f32);
- GGML_METAL_DECL_KERNEL(cpy_f32_f16);
- GGML_METAL_DECL_KERNEL(cpy_f32_f32);
- GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
- GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
- GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
- //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
- //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
- GGML_METAL_DECL_KERNEL(cpy_f16_f16);
- GGML_METAL_DECL_KERNEL(cpy_f16_f32);
- GGML_METAL_DECL_KERNEL(concat);
- GGML_METAL_DECL_KERNEL(sqr);
- GGML_METAL_DECL_KERNEL(sum_rows);
-
- #undef GGML_METAL_DECL_KERNEL
+ struct ggml_metal_kernel kernels[GGML_METAL_MAX_KERNELS];
+
+ bool support_simdgroup_reduction;
+ bool support_simdgroup_mm;
  };
 
  // MSL code
@@ -181,7 +192,6 @@ struct ggml_metal_context {
  @implementation GGMLMetalClass
  @end
 
-
  static void ggml_metal_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
  fprintf(stderr, "%s", msg);
 
@@ -192,11 +202,6 @@ static void ggml_metal_default_log_callback(enum ggml_log_level level, const cha
  ggml_log_callback ggml_metal_log_callback = ggml_metal_default_log_callback;
  void * ggml_metal_log_user_data = NULL;
 
- void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
- ggml_metal_log_callback = log_callback;
- ggml_metal_log_user_data = user_data;
- }
-
  GGML_ATTRIBUTE_FORMAT(2, 3)
  static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
  if (ggml_metal_log_callback != NULL) {
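Note: `ggml_metal_log_set_callback` is removed from this file in 0.12.2, but the remaining code still routes all logging through the `ggml_metal_log_callback` pointer, so a replacement callback only has to match the shape of `ggml_metal_default_log_callback` above. A minimal sketch, assuming the same `ggml_log_level` / `ggml_log_callback` types from ggml:

    // Sketch of a callback compatible with ggml_metal_log_callback.
    static void my_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
        (void) level;
        (void) user_data;
        fprintf(stderr, "%s", msg); // forward everything to stderr
    }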
@@ -219,7 +224,18 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
  }
  }
 
- struct ggml_metal_context * ggml_metal_init(int n_cb) {
+ static void * ggml_metal_host_malloc(size_t n) {
+ void * data = NULL;
+ const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
+ if (result != 0) {
+ GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
+ return NULL;
+ }
+
+ return data;
+ }
+
+ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
 
  id<MTLDevice> device;
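Note: `ggml_metal_host_malloc` (moved here and made static above) allocates on a page boundary via `posix_memalign`, since such host pointers are later wrapped in Metal buffers with `newBufferWithBytesNoCopy:`, which expects page-aligned memory. A standalone sketch of the idiom, runnable as plain C:

    #include <stdio.h>
    #include <stdlib.h>
    #include <unistd.h>

    int main(void) {
        void * data = NULL;
        const size_t page = (size_t) sysconf(_SC_PAGESIZE);
        // one page-aligned 1 MiB region, as ggml_metal_host_malloc would return
        if (posix_memalign(&data, page, 1024 * 1024) != 0) {
            fprintf(stderr, "posix_memalign failed\n");
            return 1;
        }
        printf("aligned base %p, page size %zu\n", data, page);
        free(data);
        return 0;
    }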
@@ -245,7 +261,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
  ctx->queue = [ctx->device newCommandQueue];
  ctx->n_buffers = 0;
- ctx->concur_list_len = 0;
 
  ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
@@ -258,14 +273,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
  #endif
  NSError * error = nil;
- NSString * libPath = [bundle pathForResource:@"ggml" ofType:@"metallib"];
+ NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
  if (libPath != nil) {
  // pre-compiled library found
  NSURL * libURL = [NSURL fileURLWithPath:libPath];
  GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
  ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
  } else {
- GGML_METAL_LOG_INFO("%s: ggml.metallib not found, loading from source\n", __func__);
+ GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
 
  NSString * sourcePath;
  NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
@@ -288,19 +303,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  return NULL;
  }
 
- MTLCompileOptions* options = nil;
+ // dictionary of preprocessor macros
+ NSMutableDictionary * prep = [NSMutableDictionary dictionary];
+
  #ifdef GGML_QKK_64
- options = [MTLCompileOptions new];
- options.preprocessorMacros = @{ @"QK_K" : @(64) };
+ prep[@"QK_K"] = @(64);
  #endif
- // try to disable fast-math
- // NOTE: this seems to have no effect whatsoever
- // instead, in order to disable fast-math, we have to build ggml.metallib from the command line
- // using xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
- // and go through the "pre-compiled library found" path above
+
+ MTLCompileOptions* options = [MTLCompileOptions new];
+ options.preprocessorMacros = prep;
+
  //[options setFastMathEnabled:false];
 
  ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+
+ [options release];
+ [prep release];
  }
 
  if (error) {
@@ -309,22 +327,51 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  }
  }
 
- #if TARGET_OS_OSX
  // print MTL GPU family:
  GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
 
+ const NSInteger MTLGPUFamilyMetal3 = 5001;
+
  // determine max supported GPU family
  // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
  // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
- for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
- if ([ctx->device supportsFamily:i]) {
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
- break;
+ {
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+ break;
+ }
+ }
+
+ for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
+ break;
+ }
+ }
+
+ for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
+ break;
+ }
  }
  }
 
+ ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7];
+ ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];
+
+ ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];
+
+ GGML_METAL_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false");
+ GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
  GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
- GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+
+ #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
+ if (@available(macOS 10.12, iOS 16.0, *)) {
+ GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+ }
+ #elif TARGET_OS_OSX
  if (ctx->device.maxTransferRate != 0) {
  GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
  } else {
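Note: the two new flags gate kernel loading on GPU capability: simdgroup reductions (`simd_sum`/`simd_max`) are reported for `MTLGPUFamilyApple7` or Metal 3 devices, simdgroup matrix multiply for `MTLGPUFamilyApple7` only. A hedged sketch of the same probe in isolation; `kMTLGPUFamilyMetal3` is hard-coded as 5001 just as in the diff, because older SDKs do not define the symbol:

    #import <Metal/Metal.h>
    #include <stdio.h>

    static void probe_simdgroup_support(id<MTLDevice> device) {
        const NSInteger kMTLGPUFamilyMetal3 = 5001; // matches the diff's constant

        const bool reduction = [device supportsFamily:MTLGPUFamilyApple7] ||
                               [device supportsFamily:(MTLGPUFamily) kMTLGPUFamilyMetal3];
        const bool mm        = [device supportsFamily:MTLGPUFamilyApple7];

        fprintf(stderr, "simdgroup reduction: %s, simdgroup mm: %s\n",
                reduction ? "true" : "false", mm ? "true" : "false");
    }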
@@ -336,259 +383,171 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  {
  NSError * error = nil;
 
+ for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
+ ctx->kernels[i].function = nil;
+ ctx->kernels[i].pipeline = nil;
+ }
+
  /*
- GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
- (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
- (int) ctx->pipeline_##name.threadExecutionWidth); \
+ GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+ (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+ (int) kernel->pipeline.threadExecutionWidth); \
  */
- #define GGML_METAL_ADD_KERNEL(name) \
- ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
- ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
- if (error) { \
- GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
- return NULL; \
+ #define GGML_METAL_ADD_KERNEL(e, name, supported) \
+ if (supported) { \
+ struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
+ kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \
+ kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \
+ if (error) { \
+ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+ return NULL; \
+ } \
+ } else { \
+ GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \
  }
 
- GGML_METAL_ADD_KERNEL(add);
- GGML_METAL_ADD_KERNEL(add_row);
- GGML_METAL_ADD_KERNEL(mul);
- GGML_METAL_ADD_KERNEL(mul_row);
- GGML_METAL_ADD_KERNEL(div);
- GGML_METAL_ADD_KERNEL(div_row);
- GGML_METAL_ADD_KERNEL(scale);
- GGML_METAL_ADD_KERNEL(scale_4);
- GGML_METAL_ADD_KERNEL(tanh);
- GGML_METAL_ADD_KERNEL(relu);
- GGML_METAL_ADD_KERNEL(gelu);
- GGML_METAL_ADD_KERNEL(gelu_quick);
- GGML_METAL_ADD_KERNEL(silu);
- GGML_METAL_ADD_KERNEL(soft_max);
- GGML_METAL_ADD_KERNEL(soft_max_4);
- GGML_METAL_ADD_KERNEL(diag_mask_inf);
- GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
- GGML_METAL_ADD_KERNEL(get_rows_f32);
- GGML_METAL_ADD_KERNEL(get_rows_f16);
- GGML_METAL_ADD_KERNEL(get_rows_q4_0);
- GGML_METAL_ADD_KERNEL(get_rows_q4_1);
- GGML_METAL_ADD_KERNEL(get_rows_q5_0);
- GGML_METAL_ADD_KERNEL(get_rows_q5_1);
- GGML_METAL_ADD_KERNEL(get_rows_q8_0);
- GGML_METAL_ADD_KERNEL(get_rows_q2_K);
- GGML_METAL_ADD_KERNEL(get_rows_q3_K);
- GGML_METAL_ADD_KERNEL(get_rows_q4_K);
- GGML_METAL_ADD_KERNEL(get_rows_q5_K);
- GGML_METAL_ADD_KERNEL(get_rows_q6_K);
- GGML_METAL_ADD_KERNEL(get_rows_i32);
- GGML_METAL_ADD_KERNEL(rms_norm);
- GGML_METAL_ADD_KERNEL(group_norm);
- GGML_METAL_ADD_KERNEL(norm);
- GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
- GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
- GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
- GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
- //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
- GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
- //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
- //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
- GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
- }
- GGML_METAL_ADD_KERNEL(rope_f32);
- GGML_METAL_ADD_KERNEL(rope_f16);
- GGML_METAL_ADD_KERNEL(alibi_f32);
- GGML_METAL_ADD_KERNEL(im2col_f16);
- GGML_METAL_ADD_KERNEL(upscale_f32);
- GGML_METAL_ADD_KERNEL(pad_f32);
- GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
- GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
- GGML_METAL_ADD_KERNEL(leaky_relu_f32);
- GGML_METAL_ADD_KERNEL(cpy_f32_f16);
- GGML_METAL_ADD_KERNEL(cpy_f32_f32);
- GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
- GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
- GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
- //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
- //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
- GGML_METAL_ADD_KERNEL(cpy_f16_f16);
- GGML_METAL_ADD_KERNEL(cpy_f16_f32);
- GGML_METAL_ADD_KERNEL(concat);
- GGML_METAL_ADD_KERNEL(sqr);
- GGML_METAL_ADD_KERNEL(sum_rows);
-
- #undef GGML_METAL_ADD_KERNEL
+ // simd_sum and simd_max requires MTLGPUFamilyApple7
+
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
  }
 
  return ctx;
  }
 
- void ggml_metal_free(struct ggml_metal_context * ctx) {
+ static void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
- #define GGML_METAL_DEL_KERNEL(name) \
- [ctx->function_##name release]; \
- [ctx->pipeline_##name release];
-
- GGML_METAL_DEL_KERNEL(add);
- GGML_METAL_DEL_KERNEL(add_row);
- GGML_METAL_DEL_KERNEL(mul);
- GGML_METAL_DEL_KERNEL(mul_row);
- GGML_METAL_DEL_KERNEL(div);
- GGML_METAL_DEL_KERNEL(div_row);
- GGML_METAL_DEL_KERNEL(scale);
- GGML_METAL_DEL_KERNEL(scale_4);
- GGML_METAL_DEL_KERNEL(tanh);
- GGML_METAL_DEL_KERNEL(relu);
- GGML_METAL_DEL_KERNEL(gelu);
- GGML_METAL_DEL_KERNEL(gelu_quick);
- GGML_METAL_DEL_KERNEL(silu);
- GGML_METAL_DEL_KERNEL(soft_max);
- GGML_METAL_DEL_KERNEL(soft_max_4);
- GGML_METAL_DEL_KERNEL(diag_mask_inf);
- GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
- GGML_METAL_DEL_KERNEL(get_rows_f32);
- GGML_METAL_DEL_KERNEL(get_rows_f16);
- GGML_METAL_DEL_KERNEL(get_rows_q4_0);
- GGML_METAL_DEL_KERNEL(get_rows_q4_1);
- GGML_METAL_DEL_KERNEL(get_rows_q5_0);
- GGML_METAL_DEL_KERNEL(get_rows_q5_1);
- GGML_METAL_DEL_KERNEL(get_rows_q8_0);
- GGML_METAL_DEL_KERNEL(get_rows_q2_K);
- GGML_METAL_DEL_KERNEL(get_rows_q3_K);
- GGML_METAL_DEL_KERNEL(get_rows_q4_K);
- GGML_METAL_DEL_KERNEL(get_rows_q5_K);
- GGML_METAL_DEL_KERNEL(get_rows_q6_K);
- GGML_METAL_DEL_KERNEL(get_rows_i32);
- GGML_METAL_DEL_KERNEL(rms_norm);
- GGML_METAL_DEL_KERNEL(group_norm);
- GGML_METAL_DEL_KERNEL(norm);
- GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
- GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
- GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
- GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
- //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
- GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
- //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
- //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
- GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
- }
- GGML_METAL_DEL_KERNEL(rope_f32);
- GGML_METAL_DEL_KERNEL(rope_f16);
- GGML_METAL_DEL_KERNEL(alibi_f32);
- GGML_METAL_DEL_KERNEL(im2col_f16);
- GGML_METAL_DEL_KERNEL(upscale_f32);
- GGML_METAL_DEL_KERNEL(pad_f32);
- GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
- GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
- GGML_METAL_DEL_KERNEL(leaky_relu_f32);
- GGML_METAL_DEL_KERNEL(cpy_f32_f16);
- GGML_METAL_DEL_KERNEL(cpy_f32_f32);
- GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
- GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
- GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
- //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
- //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
- GGML_METAL_DEL_KERNEL(cpy_f16_f16);
- GGML_METAL_DEL_KERNEL(cpy_f16_f32);
- GGML_METAL_DEL_KERNEL(concat);
- GGML_METAL_DEL_KERNEL(sqr);
- GGML_METAL_DEL_KERNEL(sum_rows);
-
- #undef GGML_METAL_DEL_KERNEL
 
  for (int i = 0; i < ctx->n_buffers; ++i) {
  [ctx->buffers[i].metal release];
  }
 
+ for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
+ if (ctx->kernels[i].pipeline) {
+ [ctx->kernels[i].pipeline release];
+ }
+
+ if (ctx->kernels[i].function) {
+ [ctx->kernels[i].function release];
+ }
+ }
+
  [ctx->library release];
  [ctx->queue release];
  [ctx->device release];
@@ -598,33 +557,6 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  free(ctx);
  }
 
- void * ggml_metal_host_malloc(size_t n) {
- void * data = NULL;
- const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
- if (result != 0) {
- GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
- return NULL;
- }
-
- return data;
- }
-
- void ggml_metal_host_free(void * data) {
- free(data);
- }
-
- void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
- ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
- }
-
- int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
- return ctx->concur_list_len;
- }
-
- int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
- return ctx->concur_list;
- }
-
  // temporarily defined here for compatibility between ggml-backend and the old API
 
  struct ggml_backend_metal_buffer {
@@ -697,210 +629,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
  return nil;
  }
 
- bool ggml_metal_add_buffer(
- struct ggml_metal_context * ctx,
- const char * name,
- void * data,
- size_t size,
- size_t max_size) {
- if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
- GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
- return false;
- }
-
- if (data) {
- // verify that the buffer does not overlap with any of the existing buffers
- for (int i = 0; i < ctx->n_buffers; ++i) {
- const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
-
- if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
- GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
- return false;
- }
- }
-
- const size_t size_page = sysconf(_SC_PAGESIZE);
-
- size_t size_aligned = size;
- if ((size_aligned % size_page) != 0) {
- size_aligned += (size_page - (size_aligned % size_page));
- }
-
- // the buffer fits into the max buffer size allowed by the device
- if (size_aligned <= ctx->device.maxBufferLength) {
- ctx->buffers[ctx->n_buffers].name = name;
- ctx->buffers[ctx->n_buffers].data = data;
- ctx->buffers[ctx->n_buffers].size = size;
-
- ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
-
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
- GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
- return false;
- }
-
- GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);
-
- ++ctx->n_buffers;
- } else {
- // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
- // one of the views
- const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
- const size_t size_step = ctx->device.maxBufferLength - size_ovlp;
- const size_t size_view = ctx->device.maxBufferLength;
-
- for (size_t i = 0; i < size; i += size_step) {
- const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
-
- ctx->buffers[ctx->n_buffers].name = name;
- ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
- ctx->buffers[ctx->n_buffers].size = size_step_aligned;
-
- ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
-
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
- GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
- return false;
- }
-
- GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
- if (i + size_step < size) {
- GGML_METAL_LOG_INFO("\n");
- }
-
- ++ctx->n_buffers;
- }
- }
-
- #if TARGET_OS_OSX
- GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
- ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
- ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-
- if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
- GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
- } else {
- GGML_METAL_LOG_INFO("\n");
- }
- #else
- GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
- #endif
- }
-
- return true;
- }
-
- void ggml_metal_set_tensor(
- struct ggml_metal_context * ctx,
- struct ggml_tensor * t) {
- size_t offs;
- id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
-
- memcpy((void *) ((uint8_t *) id_dst.contents + offs), t->data, ggml_nbytes(t));
- }
-
- void ggml_metal_get_tensor(
- struct ggml_metal_context * ctx,
- struct ggml_tensor * t) {
- size_t offs;
- id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
-
- memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
- }
-
- void ggml_metal_graph_find_concurrency(
- struct ggml_metal_context * ctx,
- struct ggml_cgraph * gf, bool check_mem) {
- int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
- int nodes_unused[GGML_MAX_CONCUR];
-
- for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
- for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
- ctx->concur_list_len = 0;
-
- int n_left = gf->n_nodes;
- int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
- int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
-
- while (n_left > 0) {
- // number of nodes at a layer (that can be issued concurrently)
- int concurrency = 0;
- for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
- if (nodes_unused[i]) {
- // if the requirements for gf->nodes[i] are satisfied
- int exe_flag = 1;
-
- // scan all srcs
- for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
- struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
- if (src_cur) {
- // if is leaf nodes it's satisfied.
- // TODO: ggml_is_leaf()
- if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
- continue;
- }
-
- // otherwise this src should be the output from previous nodes.
- int is_found = 0;
-
- // scan 2*search_depth back because we inserted barrier.
- //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
- for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
- if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
- is_found = 1;
- break;
- }
- }
- if (is_found == 0) {
- exe_flag = 0;
- break;
- }
- }
- }
- if (exe_flag && check_mem) {
- // check if nodes[i]'s data will be overwritten by a node before nodes[i].
- // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
- int64_t data_start = (int64_t) gf->nodes[i]->data;
- int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
- for (int j = n_start; j < i; j++) {
- if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
- && gf->nodes[j]->op != GGML_OP_VIEW \
- && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
- && gf->nodes[j]->op != GGML_OP_PERMUTE) {
- if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
- ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
- continue;
- }
-
- exe_flag = 0;
- }
- }
- }
- if (exe_flag) {
- ctx->concur_list[level_pos + concurrency] = i;
- nodes_unused[i] = 0;
- concurrency++;
- ctx->concur_list_len++;
- }
- }
- }
- n_left -= concurrency;
- // adding a barrier different layer
- ctx->concur_list[level_pos + concurrency] = -1;
- ctx->concur_list_len++;
- // jump all sorted nodes at nodes_bak
- while (!nodes_unused[n_start]) {
- n_start++;
- }
- level_pos += concurrency + 1;
- }
-
- if (ctx->concur_list_len > GGML_MAX_CONCUR) {
- GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
- }
- }
-
- static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
+ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
  switch (op->op) {
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(op)) {
@@ -926,9 +655,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
  case GGML_OP_SCALE:
  case GGML_OP_SQR:
  case GGML_OP_SUM_ROWS:
+ return true;
  case GGML_OP_SOFT_MAX:
  case GGML_OP_RMS_NORM:
  case GGML_OP_GROUP_NORM:
+ return ctx->support_simdgroup_reduction;
  case GGML_OP_NORM:
  case GGML_OP_ALIBI:
  case GGML_OP_ROPE:
@@ -937,9 +668,10 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
  case GGML_OP_PAD:
  case GGML_OP_ARGSORT:
  case GGML_OP_LEAKY_RELU:
+ return true;
  case GGML_OP_MUL_MAT:
  case GGML_OP_MUL_MAT_ID:
- return true;
+ return ctx->support_simdgroup_reduction;
  case GGML_OP_CPY:
  case GGML_OP_DUP:
  case GGML_OP_CONT:
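Note: with the context parameter added above, `ggml_metal_supports_op` can reject ops per device instead of unconditionally returning true: ops that need simdgroup reductions (soft_max, rms_norm, group_norm, mul_mat, mul_mat_id) now report `ctx->support_simdgroup_reduction`. A sketch of the expected call-site shape, mirroring the old `ggml_metal_supports_op(dst)` check removed further down in this diff but updated for the new signature (an assumption, since the new call site is not shown here):

    // Sketch: per-device op gating at encode time.
    if (!ggml_metal_supports_op(ctx, dst)) {
        GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
        GGML_ASSERT(!"unsupported op");
    }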
@@ -977,1438 +709,1556 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
  return false;
  }
  }
- void ggml_metal_graph_compute(
+
+ static bool ggml_metal_graph_compute(
  struct ggml_metal_context * ctx,
  struct ggml_cgraph * gf) {
  @autoreleasepool {
 
- // if there is ctx->concur_list, dispatch concurrently
- // else fallback to serial dispatch
  MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
-
- const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
-
- const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
- edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+ edesc.dispatchType = MTLDispatchTypeSerial;
 
  // create multiple command buffers and enqueue them
  // then, we encode the graph into the command buffers in parallel
 
+ const int n_nodes = gf->n_nodes;
  const int n_cb = ctx->n_cb;
+ const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
- for (int i = 0; i < n_cb; ++i) {
- ctx->command_buffers[i] = [ctx->queue commandBuffer];
+ id<MTLCommandBuffer> command_buffer_builder[n_cb];
+ for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+ id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
+ command_buffer_builder[cb_idx] = command_buffer;
 
  // enqueue the command buffers in order to specify their execution order
- [ctx->command_buffers[i] enqueue];
-
- ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
+ [command_buffer enqueue];
  }
+ const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
 
- for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
- const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
-
- dispatch_async(ctx->d_queue, ^{
- size_t offs_src0 = 0;
- size_t offs_src1 = 0;
- size_t offs_dst = 0;
-
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
- id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
-
- const int node_start = (cb_idx + 0) * n_nodes_per_cb;
- const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
-
- for (int ind = node_start; ind < node_end; ++ind) {
- const int i = has_concur ? ctx->concur_list[ind] : ind;
-
- if (i == -1) {
- [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
- continue;
- }
-
- //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
-
- struct ggml_tensor * src0 = gf->nodes[i]->src[0];
- struct ggml_tensor * src1 = gf->nodes[i]->src[1];
- struct ggml_tensor * dst = gf->nodes[i];
-
- switch (dst->op) {
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_PERMUTE:
- {
- // noop -> next node
- } continue;
- default:
- {
- } break;
- }
-
- if (!ggml_metal_supports_op(dst)) {
- GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
- GGML_ASSERT(!"unsupported op");
- }
-
- const int64_t ne00 = src0 ? src0->ne[0] : 0;
- const int64_t ne01 = src0 ? src0->ne[1] : 0;
- const int64_t ne02 = src0 ? src0->ne[2] : 0;
- const int64_t ne03 = src0 ? src0->ne[3] : 0;
-
- const uint64_t nb00 = src0 ? src0->nb[0] : 0;
- const uint64_t nb01 = src0 ? src0->nb[1] : 0;
- const uint64_t nb02 = src0 ? src0->nb[2] : 0;
- const uint64_t nb03 = src0 ? src0->nb[3] : 0;
-
- const int64_t ne10 = src1 ? src1->ne[0] : 0;
- const int64_t ne11 = src1 ? src1->ne[1] : 0;
- const int64_t ne12 = src1 ? src1->ne[2] : 0;
- const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
-
- const uint64_t nb10 = src1 ? src1->nb[0] : 0;
- const uint64_t nb11 = src1 ? src1->nb[1] : 0;
- const uint64_t nb12 = src1 ? src1->nb[2] : 0;
- const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
-
- const int64_t ne0 = dst ? dst->ne[0] : 0;
- const int64_t ne1 = dst ? dst->ne[1] : 0;
- const int64_t ne2 = dst ? dst->ne[2] : 0;
- const int64_t ne3 = dst ? dst->ne[3] : 0;
-
- const uint64_t nb0 = dst ? dst->nb[0] : 0;
- const uint64_t nb1 = dst ? dst->nb[1] : 0;
- const uint64_t nb2 = dst ? dst->nb[2] : 0;
- const uint64_t nb3 = dst ? dst->nb[3] : 0;
-
- const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
- const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
- const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
-
- id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
- id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
- id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
-
- //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
- //if (src0) {
- // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
- // ggml_is_contiguous(src0), src0->name);
- //}
- //if (src1) {
- // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
- // ggml_is_contiguous(src1), src1->name);
- //}
- //if (dst) {
- // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
- // dst->name);
- //}
-
- switch (dst->op) {
- case GGML_OP_CONCAT:
- {
- const int64_t nb = ne00;
-
- [encoder setComputePipelineState:ctx->pipeline_concat];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1114
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1115
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1116
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1117
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1118
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1119
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1120
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1121
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1122
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1123
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1124
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1125
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1126
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1127
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1128
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1129
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1130
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1131
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1132
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1133
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1134
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1135
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1136
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1137
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1138
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1139
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1140
- [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
1141
-
1142
- const int nth = MIN(1024, ne0);
1143
-
1144
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1145
- } break;
1146
- case GGML_OP_ADD:
1147
- case GGML_OP_MUL:
1148
- case GGML_OP_DIV:
1149
- {
1150
- const size_t offs = 0;
1151
-
1152
- bool bcast_row = false;
1153
-
1154
- int64_t nb = ne00;
1155
-
1156
- id<MTLComputePipelineState> pipeline = nil;
1157
-
1158
- if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
1159
- GGML_ASSERT(ggml_is_contiguous(src0));
738
+ dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
739
+ const int cb_idx = iter;
1160
740
 
1161
- // src1 is a row
1162
- GGML_ASSERT(ne11 == 1);
741
+ size_t offs_src0 = 0;
742
+ size_t offs_src1 = 0;
743
+ size_t offs_dst = 0;
1163
744
 
1164
- nb = ne00 / 4;
1165
- switch (dst->op) {
1166
- case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
1167
- case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
1168
- case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
1169
- default: GGML_ASSERT(false);
1170
- }
745
+ id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
746
+ id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
1171
747
 
1172
- bcast_row = true;
1173
- } else {
1174
- switch (dst->op) {
1175
- case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
1176
- case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
1177
- case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
1178
- default: GGML_ASSERT(false);
1179
- }
1180
- }
748
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
749
+ const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
1181
750
 
1182
- [encoder setComputePipelineState:pipeline];
1183
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1184
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1185
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1186
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1187
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1188
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1189
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1190
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1191
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1192
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1193
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1194
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1195
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1196
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1197
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1198
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1199
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1200
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1201
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1202
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1203
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1204
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1205
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1206
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1207
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1208
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1209
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1210
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1211
- [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
1212
-
1213
- if (bcast_row) {
1214
- const int64_t n = ggml_nelements(dst)/4;
751
+ for (int i = node_start; i < node_end; ++i) {
752
+ if (i == -1) {
753
+ [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
754
+ continue;
755
+ }
1215
756
 
1216
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1217
- } else {
1218
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
757
+ //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
758
+
759
+ struct ggml_tensor * src0 = gf->nodes[i]->src[0];
760
+ struct ggml_tensor * src1 = gf->nodes[i]->src[1];
761
+ struct ggml_tensor * dst = gf->nodes[i];
762
+
763
+ switch (dst->op) {
764
+ case GGML_OP_NONE:
765
+ case GGML_OP_RESHAPE:
766
+ case GGML_OP_VIEW:
767
+ case GGML_OP_TRANSPOSE:
768
+ case GGML_OP_PERMUTE:
769
+ {
770
+ // noop -> next node
771
+ } continue;
772
+ default:
773
+ {
774
+ } break;
775
+ }
1219
776
 
1220
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1221
- }
1222
- } break;
1223
- case GGML_OP_ACC:
1224
- {
1225
- GGML_ASSERT(src0t == GGML_TYPE_F32);
1226
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1227
- GGML_ASSERT(dstt == GGML_TYPE_F32);
777
+ if (!ggml_metal_supports_op(ctx, dst)) {
778
+ GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
779
+ GGML_ASSERT(!"unsupported op");
780
+ }
1228
781
 
1229
- GGML_ASSERT(ggml_is_contiguous(src0));
1230
- GGML_ASSERT(ggml_is_contiguous(src1));
1231
-
1232
- const size_t pnb1 = ((int32_t *) dst->op_params)[0];
1233
- const size_t pnb2 = ((int32_t *) dst->op_params)[1];
1234
- const size_t pnb3 = ((int32_t *) dst->op_params)[2];
1235
- const size_t offs = ((int32_t *) dst->op_params)[3];
1236
-
1237
- const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
1238
-
1239
- if (!inplace) {
1240
- // run a separate kernel to cpy src->dst
1241
- // not sure how to avoid this
1242
- // TODO: make a simpler cpy_bytes kernel
1243
-
1244
- const int nth = MIN((int) ctx->pipeline_cpy_f32_f32.maxTotalThreadsPerThreadgroup, ne00);
1245
-
1246
- [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
1247
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1248
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1249
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1250
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1251
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1252
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1253
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1254
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1255
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1256
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1257
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1258
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1259
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1260
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1261
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1262
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1263
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1264
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1265
-
1266
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1267
- }
782
+ #ifndef GGML_METAL_NDEBUG
783
+ [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
784
+ #endif
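
Also new in this version: unless GGML_METAL_NDEBUG is defined, each node's encoding is opened as a Metal debug group named after the op, which makes Xcode GPU captures navigable per graph node. A hypothetical helper (not in the package) showing the balanced push/pop pattern; the matching popDebugGroup presumably appears later in the loop, outside this hunk:

    #import <Metal/Metal.h>

    // Bracket a node's encoding in a named debug group so GPU frame
    // captures show one labeled region per graph node.
    static void encode_with_label(id<MTLComputeCommandEncoder> encoder,
                                  const char * label, void (^encode)(void)) {
        [encoder pushDebugGroup:[NSString stringWithUTF8String:label]];
        encode();                // encode the node's dispatches
        [encoder popDebugGroup]; // groups must be balanced
    }
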
1268
785
 
1269
- [encoder setComputePipelineState:ctx->pipeline_add];
1270
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1271
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1272
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1273
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1274
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1275
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1276
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1277
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1278
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
1279
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
1280
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
1281
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1282
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1283
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1284
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1285
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1286
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1287
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1288
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1289
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1290
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1291
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1292
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1293
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1294
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
1295
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
1296
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1297
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1298
-
1299
- const int nth = MIN((int) ctx->pipeline_add.maxTotalThreadsPerThreadgroup, ne00);
1300
-
1301
- [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1302
- } break;
1303
- case GGML_OP_SCALE:
1304
- {
786
+ const int64_t ne00 = src0 ? src0->ne[0] : 0;
787
+ const int64_t ne01 = src0 ? src0->ne[1] : 0;
788
+ const int64_t ne02 = src0 ? src0->ne[2] : 0;
789
+ const int64_t ne03 = src0 ? src0->ne[3] : 0;
790
+
791
+ const uint64_t nb00 = src0 ? src0->nb[0] : 0;
792
+ const uint64_t nb01 = src0 ? src0->nb[1] : 0;
793
+ const uint64_t nb02 = src0 ? src0->nb[2] : 0;
794
+ const uint64_t nb03 = src0 ? src0->nb[3] : 0;
795
+
796
+ const int64_t ne10 = src1 ? src1->ne[0] : 0;
797
+ const int64_t ne11 = src1 ? src1->ne[1] : 0;
798
+ const int64_t ne12 = src1 ? src1->ne[2] : 0;
799
+ const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
800
+
801
+ const uint64_t nb10 = src1 ? src1->nb[0] : 0;
802
+ const uint64_t nb11 = src1 ? src1->nb[1] : 0;
803
+ const uint64_t nb12 = src1 ? src1->nb[2] : 0;
804
+ const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
805
+
806
+ const int64_t ne0 = dst ? dst->ne[0] : 0;
807
+ const int64_t ne1 = dst ? dst->ne[1] : 0;
808
+ const int64_t ne2 = dst ? dst->ne[2] : 0;
809
+ const int64_t ne3 = dst ? dst->ne[3] : 0;
810
+
811
+ const uint64_t nb0 = dst ? dst->nb[0] : 0;
812
+ const uint64_t nb1 = dst ? dst->nb[1] : 0;
813
+ const uint64_t nb2 = dst ? dst->nb[2] : 0;
814
+ const uint64_t nb3 = dst ? dst->nb[3] : 0;
815
+
816
+ const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
817
+ const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
818
+ const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
819
+
820
+ id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
821
+ id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
822
+ id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
823
+
824
+ //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
825
+ //if (src0) {
826
+ // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
827
+ // ggml_is_contiguous(src0), src0->name);
828
+ //}
829
+ //if (src1) {
830
+ // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
831
+ // ggml_is_contiguous(src1), src1->name);
832
+ //}
833
+ //if (dst) {
834
+ // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
835
+ // dst->name);
836
+ //}
837
+
838
+ switch (dst->op) {
839
+ case GGML_OP_CONCAT:
840
+ {
841
+ const int64_t nb = ne00;
842
+
843
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
844
+
845
+ [encoder setComputePipelineState:pipeline];
846
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
847
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
848
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
849
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
850
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
851
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
852
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
853
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
854
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
855
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
856
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
857
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
858
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
859
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
860
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
861
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
862
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
863
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
864
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
865
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
866
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
867
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
868
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
869
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
870
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
871
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
872
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
873
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
874
+
875
+ const int nth = MIN(1024, ne0);
876
+
877
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
878
+ } break;
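
The central refactor of this release is visible in the CONCAT case above: dedicated pipeline fields like ctx->pipeline_concat give way to a table lookup (ctx->kernels[...].pipeline), while the long runs of setBytes binding consecutive scalar arguments are unchanged. A hypothetical helper (not in the package) that would express such a run as a loop; it assumes the arguments are 8-byte scalars bound at consecutive indices, as above:

    #import <Metal/Metal.h>
    #include <stdint.h>

    // Bind n consecutive 8-byte scalar arguments starting at first_idx.
    // int64_t and uint64_t arguments have the same size, so one loop covers
    // both the ne* and nb* values seen above.
    static void set_int64_args(id<MTLComputeCommandEncoder> encoder,
                               const int64_t * args, int n, int first_idx) {
        for (int j = 0; j < n; ++j) {
            [encoder setBytes:&args[j] length:sizeof(args[j]) atIndex:first_idx + j];
        }
    }
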
879
+ case GGML_OP_ADD:
880
+ case GGML_OP_MUL:
881
+ case GGML_OP_DIV:
882
+ {
883
+ const size_t offs = 0;
884
+
885
+ bool bcast_row = false;
886
+
887
+ int64_t nb = ne00;
888
+
889
+ id<MTLComputePipelineState> pipeline = nil;
890
+
891
+ if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
1305
892
  GGML_ASSERT(ggml_is_contiguous(src0));
1306
893
 
1307
- const float scale = *(const float *) dst->op_params;
894
+ // src1 is a row
895
+ GGML_ASSERT(ne11 == 1);
1308
896
 
1309
- int64_t n = ggml_nelements(dst);
897
+ nb = ne00 / 4;
898
+ switch (dst->op) {
899
+ case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
900
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
901
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
902
+ default: GGML_ASSERT(false);
903
+ }
1310
904
 
1311
- if (n % 4 == 0) {
1312
- n /= 4;
1313
- [encoder setComputePipelineState:ctx->pipeline_scale_4];
1314
- } else {
1315
- [encoder setComputePipelineState:ctx->pipeline_scale];
905
+ bcast_row = true;
906
+ } else {
907
+ switch (dst->op) {
908
+ case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
909
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
910
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
911
+ default: GGML_ASSERT(false);
1316
912
  }
913
+ }
1317
914
 
1318
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1319
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1320
- [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
915
+ [encoder setComputePipelineState:pipeline];
916
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
917
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
918
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
919
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
920
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
921
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
922
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
923
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
924
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
925
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
926
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
927
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
928
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
929
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
930
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
931
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
932
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
933
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
934
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
935
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
936
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
937
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
938
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
939
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
940
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
941
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
942
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
943
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
944
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
945
+
946
+ if (bcast_row) {
947
+ const int64_t n = ggml_nelements(dst)/4;
1321
948
 
1322
949
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1323
- } break;
1324
- case GGML_OP_UNARY:
1325
- switch (ggml_get_unary_op(gf->nodes[i])) {
1326
- case GGML_UNARY_OP_TANH:
1327
- {
1328
- [encoder setComputePipelineState:ctx->pipeline_tanh];
1329
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1330
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
950
+ } else {
951
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1331
952
 
1332
- const int64_t n = ggml_nelements(dst);
953
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
954
+ }
955
+ } break;
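
The ADD/MUL/DIV case above selects between general broadcast kernels and faster _ROW variants. A sketch of the fast-path predicate: src1 must consist of exactly one contiguous row of length ne10, and both row widths must divide by 4 so the _ROW kernels can operate on float4 lanes:

    #include <stdbool.h>
    #include <stdint.h>

    static bool use_row_broadcast(int64_t src1_nelem, bool src1_cont,
                                  int64_t ne00, int64_t ne10) {
        return src1_nelem == ne10 && src1_cont
            && ne00 % 4 == 0 && ne10 % 4 == 0;
    }
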
956
+ case GGML_OP_ACC:
957
+ {
958
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
959
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
960
+ GGML_ASSERT(dstt == GGML_TYPE_F32);
1333
961
 
1334
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1335
- } break;
1336
- case GGML_UNARY_OP_RELU:
1337
- {
1338
- [encoder setComputePipelineState:ctx->pipeline_relu];
1339
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1340
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
962
+ GGML_ASSERT(ggml_is_contiguous(src0));
963
+ GGML_ASSERT(ggml_is_contiguous(src1));
1341
964
 
1342
- const int64_t n = ggml_nelements(dst);
965
+ const size_t pnb1 = ((int32_t *) dst->op_params)[0];
966
+ const size_t pnb2 = ((int32_t *) dst->op_params)[1];
967
+ const size_t pnb3 = ((int32_t *) dst->op_params)[2];
968
+ const size_t offs = ((int32_t *) dst->op_params)[3];
1343
969
 
1344
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1345
- } break;
1346
- case GGML_UNARY_OP_GELU:
1347
- {
1348
- [encoder setComputePipelineState:ctx->pipeline_gelu];
1349
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1350
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
970
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
1351
971
 
1352
- const int64_t n = ggml_nelements(dst);
1353
- GGML_ASSERT(n % 4 == 0);
972
+ if (!inplace) {
973
+ // run a separate kernel to cpy src->dst
974
+ // not sure how to avoid this
975
+ // TODO: make a simpler cpy_bytes kernel
1354
976
 
1355
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1356
- } break;
1357
- case GGML_UNARY_OP_GELU_QUICK:
1358
- {
1359
- [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
1360
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1361
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
977
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
1362
978
 
1363
- const int64_t n = ggml_nelements(dst);
1364
- GGML_ASSERT(n % 4 == 0);
979
+ [encoder setComputePipelineState:pipeline];
980
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
981
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
982
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
983
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
984
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
985
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
986
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
987
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
988
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
989
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
990
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
991
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
992
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
993
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
994
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
995
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
996
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
997
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1365
998
 
1366
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1367
- } break;
1368
- case GGML_UNARY_OP_SILU:
1369
- {
1370
- [encoder setComputePipelineState:ctx->pipeline_silu];
1371
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1372
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
999
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
1373
1000
 
1374
- const int64_t n = ggml_nelements(dst);
1375
- GGML_ASSERT(n % 4 == 0);
1001
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1002
+ }
1376
1003
 
1377
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1378
- } break;
1379
- default:
1380
- {
1381
- GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1382
- GGML_ASSERT(false);
1383
- }
1384
- } break;
1385
- case GGML_OP_SQR:
1386
- {
1387
- GGML_ASSERT(ggml_is_contiguous(src0));
1004
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
1005
+
1006
+ [encoder setComputePipelineState:pipeline];
1007
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1008
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1009
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1010
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1011
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1012
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1013
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1014
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1015
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
1016
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
1017
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
1018
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1019
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1020
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1021
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1022
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1023
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1024
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1025
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1026
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1027
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1028
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1029
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1030
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1031
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
1032
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
1033
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1034
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1035
+
1036
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
1037
+
1038
+ [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1039
+ } break;
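
A sketch of GGML_OP_ACC's op_params layout as unpacked in the case above: five int32 slots carrying the view's byte strides, its byte offset into dst, and an inplace flag (when false, src0 is first copied into dst by the cpy kernel):

    #include <stdbool.h>
    #include <stddef.h>
    #include <stdint.h>

    struct acc_view {
        size_t nb1, nb2, nb3; // byte strides of the viewed region of dst
        size_t offs;          // byte offset of the view into dst
        bool   inplace;       // if false, copy src0 -> dst first
    };

    static struct acc_view acc_unpack(const int32_t op_params[5]) {
        return (struct acc_view) {
            .nb1     = (size_t) op_params[0],
            .nb2     = (size_t) op_params[1],
            .nb3     = (size_t) op_params[2],
            .offs    = (size_t) op_params[3],
            .inplace = op_params[4] != 0,
        };
    }
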
1040
+ case GGML_OP_SCALE:
1041
+ {
1042
+ GGML_ASSERT(ggml_is_contiguous(src0));
1043
+
1044
+ const float scale = *(const float *) dst->op_params;
1045
+
1046
+ int64_t n = ggml_nelements(dst);
1047
+
1048
+ id<MTLComputePipelineState> pipeline = nil;
1049
+
1050
+ if (n % 4 == 0) {
1051
+ n /= 4;
1052
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
1053
+ } else {
1054
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
1055
+ }
1388
1056
 
1389
- [encoder setComputePipelineState:ctx->pipeline_sqr];
1390
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1391
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1057
+ [encoder setComputePipelineState:pipeline];
1058
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1059
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1060
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
1392
1061
 
1393
- const int64_t n = ggml_nelements(dst);
1394
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1395
- } break;
1396
- case GGML_OP_SUM_ROWS:
1397
- {
1398
- GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
1062
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1063
+ } break;
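
A sketch of the width selection in the SCALE case above: an element count that is a multiple of 4 takes the float4 kernel (SCALE_4) over n/4 threadgroups, otherwise the scalar kernel runs one threadgroup per element:

    #include <stdbool.h>
    #include <stdint.h>

    static int64_t scale_grid(int64_t n, bool * vec4) {
        *vec4 = (n % 4 == 0);
        return *vec4 ? n/4 : n; // number of threadgroups to dispatch
    }
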
1064
+ case GGML_OP_UNARY:
1065
+ switch (ggml_get_unary_op(gf->nodes[i])) {
1066
+ case GGML_UNARY_OP_TANH:
1067
+ {
1068
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
1399
1069
 
1400
- [encoder setComputePipelineState:ctx->pipeline_sum_rows];
1401
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1402
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1403
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1404
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1405
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1406
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1407
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1408
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1409
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1410
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1411
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1412
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1413
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1414
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1415
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1416
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1417
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1418
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1419
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1420
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
1421
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
1422
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
1423
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
1424
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
1425
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
1426
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
1427
-
1428
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1429
- } break;
1430
- case GGML_OP_SOFT_MAX:
1431
- {
1432
- int nth = 32; // SIMD width
1433
-
1434
- if (ne00%4 == 0) {
1435
- while (nth < ne00/4 && nth < 256) {
1436
- nth *= 2;
1437
- }
1438
- [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
1439
- } else {
1440
- while (nth < ne00 && nth < 1024) {
1441
- nth *= 2;
1442
- }
1443
- [encoder setComputePipelineState:ctx->pipeline_soft_max];
1444
- }
1070
+ [encoder setComputePipelineState:pipeline];
1071
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1072
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1445
1073
 
1446
- const float scale = ((float *) dst->op_params)[0];
1074
+ const int64_t n = ggml_nelements(dst);
1447
1075
 
1448
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1449
- if (id_src1) {
1450
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1451
- } else {
1452
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1453
- }
1454
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1455
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1456
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1457
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1458
- [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1459
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1460
-
1461
- [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1462
- } break;
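
The SOFT_MAX sizing logic (unchanged in substance on the new side of this hunk) grows the threadgroup from the SIMD width by doubling until it covers the row, with a lower cap for the float4 variant. A standalone sketch of that computation:

    #include <stdint.h>

    static int soft_max_nth(int64_t ne00) {
        int nth = 32; // SIMD width
        if (ne00 % 4 == 0) {
            while (nth < ne00/4 && nth < 256) nth *= 2;  // float4 kernel
        } else {
            while (nth < ne00 && nth < 1024) nth *= 2;   // scalar kernel
        }
        return nth;
    }
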
1463
- case GGML_OP_DIAG_MASK_INF:
1464
- {
1465
- const int n_past = ((int32_t *)(dst->op_params))[0];
1466
-
1467
- if (ne00%8 == 0) {
1468
- [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
1469
- } else {
1470
- [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
1471
- }
1472
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1473
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1474
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1475
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1476
- [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
1076
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1077
+ } break;
1078
+ case GGML_UNARY_OP_RELU:
1079
+ {
1080
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
1477
1081
 
1478
- if (ne00%8 == 0) {
1479
- [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1480
- }
1481
- else {
1482
- [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1483
- }
1484
- } break;
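
A sketch of the grid-shape choice in the DIAG_MASK_INF case above: when ne00 divides by 8, the vectorized kernel walks the flattened tensor 8 elements per thread, otherwise a full 3D grid covers it element by element:

    #import <Metal/Metal.h>

    static MTLSize diag_mask_inf_grid(int64_t ne00, int64_t ne01, int64_t ne02) {
        if (ne00 % 8 == 0) {
            return MTLSizeMake(ne00*ne01*ne02/8, 1, 1); // 8 elems per thread
        }
        return MTLSizeMake(ne00, ne01, ne02);
    }
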
1485
- case GGML_OP_MUL_MAT:
1486
- {
1487
- GGML_ASSERT(ne00 == ne10);
1082
+ [encoder setComputePipelineState:pipeline];
1083
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1084
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1488
1085
 
1489
- // TODO: assert that dim2 and dim3 are contiguous
1490
- GGML_ASSERT(ne12 % ne02 == 0);
1491
- GGML_ASSERT(ne13 % ne03 == 0);
1086
+ const int64_t n = ggml_nelements(dst);
1492
1087
 
1493
- const uint r2 = ne12/ne02;
1494
- const uint r3 = ne13/ne03;
1088
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1089
+ } break;
1090
+ case GGML_UNARY_OP_GELU:
1091
+ {
1092
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
1495
1093
 
1496
- // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1497
- // to the matrix-vector kernel
1498
- int ne11_mm_min = 1;
1094
+ [encoder setComputePipelineState:pipeline];
1095
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1096
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1499
1097
 
1500
- #if 0
1501
- // the numbers below are measured on M2 Ultra for 7B and 13B models
1502
- // these numbers do not translate to other devices or model sizes
1503
- // TODO: need to find a better approach
1504
- if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
1505
- switch (src0t) {
1506
- case GGML_TYPE_F16: ne11_mm_min = 2; break;
1507
- case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1508
- case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1509
- case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1510
- case GGML_TYPE_Q4_0:
1511
- case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1512
- case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1513
- case GGML_TYPE_Q5_0: // not tested yet
1514
- case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1515
- case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1516
- case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1517
- default: ne11_mm_min = 1; break;
1518
- }
1519
- }
1520
- #endif
1098
+ const int64_t n = ggml_nelements(dst);
1099
+ GGML_ASSERT(n % 4 == 0);
1521
1100
 
1522
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1523
- // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1524
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1525
- !ggml_is_transposed(src0) &&
1526
- !ggml_is_transposed(src1) &&
1527
- src1t == GGML_TYPE_F32 &&
1528
- ne00 % 32 == 0 && ne00 >= 64 &&
1529
- (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
1530
- //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1531
- switch (src0->type) {
1532
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
1533
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
1534
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
1535
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
1536
- case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
1537
- case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
1538
- case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
1539
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
1540
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
1541
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
1542
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
1543
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
1544
- default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1545
- }
1546
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1547
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1548
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1549
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1550
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1551
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
1552
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
1553
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1554
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1555
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1556
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1557
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1558
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1559
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1560
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1561
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1562
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1563
- } else {
1564
- int nth0 = 32;
1565
- int nth1 = 1;
1566
- int nrows = 1;
1567
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1568
-
1569
- // use custom matrix x vector kernel
1570
- switch (src0t) {
1571
- case GGML_TYPE_F32:
1572
- {
1573
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1574
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
1575
- nrows = 4;
1576
- } break;
1577
- case GGML_TYPE_F16:
1578
- {
1579
- nth0 = 32;
1580
- nth1 = 1;
1581
- if (src1t == GGML_TYPE_F32) {
1582
- if (ne11 * ne12 < 4) {
1583
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
1584
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1585
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
1586
- nrows = ne11;
1587
- } else {
1588
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
1589
- nrows = 4;
1590
- }
1591
- } else {
1592
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
1593
- nrows = 4;
1594
- }
1595
- } break;
1596
- case GGML_TYPE_Q4_0:
1597
- {
1598
- nth0 = 8;
1599
- nth1 = 8;
1600
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
1601
- } break;
1602
- case GGML_TYPE_Q4_1:
1603
- {
1604
- nth0 = 8;
1605
- nth1 = 8;
1606
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
1607
- } break;
1608
- case GGML_TYPE_Q5_0:
1609
- {
1610
- nth0 = 8;
1611
- nth1 = 8;
1612
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
1613
- } break;
1614
- case GGML_TYPE_Q5_1:
1615
- {
1616
- nth0 = 8;
1617
- nth1 = 8;
1618
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
1619
- } break;
1620
- case GGML_TYPE_Q8_0:
1621
- {
1622
- nth0 = 8;
1623
- nth1 = 8;
1624
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
1625
- } break;
1626
- case GGML_TYPE_Q2_K:
1627
- {
1628
- nth0 = 2;
1629
- nth1 = 32;
1630
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
1631
- } break;
1632
- case GGML_TYPE_Q3_K:
1633
- {
1634
- nth0 = 2;
1635
- nth1 = 32;
1636
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
1637
- } break;
1638
- case GGML_TYPE_Q4_K:
1639
- {
1640
- nth0 = 4; //1;
1641
- nth1 = 8; //32;
1642
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
1643
- } break;
1644
- case GGML_TYPE_Q5_K:
1645
- {
1646
- nth0 = 2;
1647
- nth1 = 32;
1648
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
1649
- } break;
1650
- case GGML_TYPE_Q6_K:
1651
- {
1652
- nth0 = 2;
1653
- nth1 = 32;
1654
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
1655
- } break;
1656
- default:
1657
- {
1658
- GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1659
- GGML_ASSERT(false && "not implemented");
1660
- }
1661
- };
1662
-
1663
- if (ggml_is_quantized(src0t)) {
1664
- GGML_ASSERT(ne00 >= nth0*nth1);
1665
- }
1101
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1102
+ } break;
1103
+ case GGML_UNARY_OP_GELU_QUICK:
1104
+ {
1105
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
1666
1106
 
1107
+ [encoder setComputePipelineState:pipeline];
1667
1108
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1668
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1669
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1670
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1671
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1672
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1673
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1674
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1675
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1676
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
1677
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
1678
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
1679
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
1680
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
1681
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1682
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1683
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1684
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1685
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1686
-
1687
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1688
- src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1689
- src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1690
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1691
- }
1692
- else if (src0t == GGML_TYPE_Q4_K) {
1693
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1694
- }
1695
- else if (src0t == GGML_TYPE_Q3_K) {
1696
- #ifdef GGML_QKK_64
1697
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1698
- #else
1699
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1700
- #endif
1701
- }
1702
- else if (src0t == GGML_TYPE_Q5_K) {
1703
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1704
- }
1705
- else if (src0t == GGML_TYPE_Q6_K) {
1706
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1707
- } else {
1708
- const int64_t ny = (ne11 + nrows - 1)/nrows;
1709
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1710
- }
1711
- }
1712
- } break;
1713
- case GGML_OP_MUL_MAT_ID:
1714
- {
1715
- //GGML_ASSERT(ne00 == ne10);
1716
- //GGML_ASSERT(ne03 == ne13);
1717
-
1718
- GGML_ASSERT(src0t == GGML_TYPE_I32);
1719
-
1720
- const int n_as = ((int32_t *) dst->op_params)[1];
1721
-
1722
- // TODO: make this more general
1723
- GGML_ASSERT(n_as <= 8);
1724
-
1725
- // max size of the src1ids array in the kernel stack
1726
- GGML_ASSERT(ne11 <= 512);
1727
-
1728
- struct ggml_tensor * src2 = gf->nodes[i]->src[2];
1729
-
1730
- const int64_t ne20 = src2 ? src2->ne[0] : 0;
1731
- const int64_t ne21 = src2 ? src2->ne[1] : 0;
1732
- const int64_t ne22 = src2 ? src2->ne[2] : 0;
1733
- const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
1734
-
1735
- const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
1736
- const uint64_t nb21 = src2 ? src2->nb[1] : 0;
1737
- const uint64_t nb22 = src2 ? src2->nb[2] : 0;
1738
- const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
1739
-
1740
- const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
1741
-
1742
- GGML_ASSERT(!ggml_is_transposed(src2));
1743
- GGML_ASSERT(!ggml_is_transposed(src1));
1744
-
1745
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1746
-
1747
- const uint r2 = ne12/ne22;
1748
- const uint r3 = ne13/ne23;
1749
-
1750
- // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1751
- // to the matrix-vector kernel
1752
- int ne11_mm_min = n_as;
1753
-
1754
- const int idx = ((int32_t *) dst->op_params)[0];
1755
-
1756
- // batch size
1757
- GGML_ASSERT(ne01 == ne11);
1758
-
1759
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1760
- // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1761
- // !!!
1762
- // TODO: for now, always use mat-vec kernels until we figure out how to improve the
1763
- // indirect matrix multiplication
1764
- // !!!
1765
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1766
- ne20 % 32 == 0 && ne20 >= 64 &&
1767
- ne11 > ne11_mm_min) {
1768
- switch (src2->type) {
1769
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
1770
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
1771
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
1772
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
1773
- case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
1774
- case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
1775
- case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
1776
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
1777
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
1778
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
1779
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
1780
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
1781
- default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1782
- }
1783
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1784
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1785
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1786
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1787
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1788
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
1789
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1790
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
1791
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1792
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
1793
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
1794
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
1795
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
1796
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1797
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
1798
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1799
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
1800
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
1801
- [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
1802
- // TODO: how to make this an array? read Metal docs
1803
- for (int j = 0; j < 8; ++j) {
1804
- // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1805
- struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1806
-
1807
- size_t offs_src_cur = 0;
1808
- id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1809
-
1810
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
1811
- }
1812
-
1813
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1814
-
1815
- [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1816
- } else {
1817
- int nth0 = 32;
1818
- int nth1 = 1;
1819
- int nrows = 1;
1820
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1821
-
1822
- // use custom matrix x vector kernel
1823
- switch (src2t) {
1824
- case GGML_TYPE_F32:
1825
- {
1826
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1827
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
1828
- } break;
1829
- case GGML_TYPE_F16:
1830
- {
1831
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1832
- nth0 = 32;
1833
- nth1 = 1;
1834
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
1835
- } break;
1836
- case GGML_TYPE_Q4_0:
1837
- {
1838
- nth0 = 8;
1839
- nth1 = 8;
1840
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
1841
- } break;
1842
- case GGML_TYPE_Q4_1:
1843
- {
1844
- nth0 = 8;
1845
- nth1 = 8;
1846
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
1847
- } break;
1848
- case GGML_TYPE_Q5_0:
1849
- {
1850
- nth0 = 8;
1851
- nth1 = 8;
1852
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
1853
- } break;
1854
- case GGML_TYPE_Q5_1:
1855
- {
1856
- nth0 = 8;
1857
- nth1 = 8;
1858
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
1859
- } break;
1860
- case GGML_TYPE_Q8_0:
1861
- {
1862
- nth0 = 8;
1863
- nth1 = 8;
1864
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
1865
- } break;
1866
- case GGML_TYPE_Q2_K:
1867
- {
1868
- nth0 = 2;
1869
- nth1 = 32;
1870
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
1871
- } break;
1872
- case GGML_TYPE_Q3_K:
1873
- {
1874
- nth0 = 2;
1875
- nth1 = 32;
1876
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
1877
- } break;
1878
- case GGML_TYPE_Q4_K:
1879
- {
1880
- nth0 = 4; //1;
1881
- nth1 = 8; //32;
1882
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
1883
- } break;
1884
- case GGML_TYPE_Q5_K:
1885
- {
1886
- nth0 = 2;
1887
- nth1 = 32;
1888
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
1889
- } break;
1890
- case GGML_TYPE_Q6_K:
1891
- {
1892
- nth0 = 2;
1893
- nth1 = 32;
1894
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
1895
- } break;
1896
- default:
1897
- {
1898
- GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
1899
- GGML_ASSERT(false && "not implemented");
1900
- }
1901
- };
1109
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1902
1110
 
1903
- if (ggml_is_quantized(src2t)) {
1904
- GGML_ASSERT(ne20 >= nth0*nth1);
1905
- }
1111
+ const int64_t n = ggml_nelements(dst);
1112
+ GGML_ASSERT(n % 4 == 0);
1906
1113
 
1907
- const int64_t _ne1 = 1; // the kernel needs a reference in constant memory
1114
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1115
+ } break;
1116
+ case GGML_UNARY_OP_SILU:
1117
+ {
1118
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
1908
1119
 
1120
+ [encoder setComputePipelineState:pipeline];
1909
1121
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1910
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1911
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1912
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1913
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1914
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1915
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
1916
- [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
1917
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
1918
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
1919
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1920
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
1921
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1922
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1923
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1924
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1925
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1926
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1927
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
1928
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1929
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
1930
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1931
- [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1932
- // TODO: how to make this an array? read Metal docs
1933
- for (int j = 0; j < 8; ++j) {
1934
- // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1935
- struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1936
-
1937
- size_t offs_src_cur = 0;
1938
- id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1939
-
1940
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
1941
- }
1942
-
1943
- if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1944
- src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1945
- src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
1946
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1947
- }
1948
- else if (src2t == GGML_TYPE_Q4_K) {
1949
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1950
- }
1951
- else if (src2t == GGML_TYPE_Q3_K) {
1952
- #ifdef GGML_QKK_64
1953
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1954
- #else
1955
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1956
- #endif
1957
- }
1958
- else if (src2t == GGML_TYPE_Q5_K) {
1959
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1960
- }
1961
- else if (src2t == GGML_TYPE_Q6_K) {
1962
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1963
- } else {
1964
- const int64_t ny = (_ne1 + nrows - 1)/nrows;
1965
- [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1966
- }
1122
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1123
+
1124
+ const int64_t n = ggml_nelements(dst);
1125
+ GGML_ASSERT(n % 4 == 0);
1126
+
1127
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1128
+ } break;
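Both unary paths above encode the same host-side convention: the GELU and SILU kernels consume four elements per invocation, so the host asserts that the element count is divisible by 4 and launches n/4 single-thread threadgroups. A minimal C sketch of that grid calculation (struct grid3d is a stand-in for MTLSize, not a ggml type):

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

// Stand-in for MTLSize: a 3D grid of threadgroups.
struct grid3d { uint64_t x, y, z; };

// Grid for an elementwise kernel that processes 4 elements per invocation:
// one threadgroup per 4 elements, a single thread per threadgroup.
static struct grid3d elementwise_vec4_grid(int64_t n_elements) {
    assert(n_elements % 4 == 0 && "vectorized kernel requires n % 4 == 0");
    return (struct grid3d){ (uint64_t)(n_elements / 4), 1, 1 };
}

int main(void) {
    struct grid3d g = elementwise_vec4_grid(4096);
    printf("threadgroups: %llu x %llu x %llu\n",
           (unsigned long long)g.x, (unsigned long long)g.y,
           (unsigned long long)g.z);
    return 0;
}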
1129
+ default:
1130
+ {
1131
+ GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1132
+ GGML_ASSERT(false);
1967
1133
  }
1968
- } break;
1969
- case GGML_OP_GET_ROWS:
1970
- {
1971
- switch (src0->type) {
1972
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
1973
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
1974
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
1975
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
1976
- case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
1977
- case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
1978
- case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
1979
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
1980
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
1981
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
1982
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
1983
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
1984
- case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
1985
- default: GGML_ASSERT(false && "not implemented");
1134
+ } break;
1135
+ case GGML_OP_SQR:
1136
+ {
1137
+ GGML_ASSERT(ggml_is_contiguous(src0));
1138
+
1139
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
1140
+
1141
+ [encoder setComputePipelineState:pipeline];
1142
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1143
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1144
+
1145
+ const int64_t n = ggml_nelements(dst);
1146
+
1147
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1148
+ } break;
1149
+ case GGML_OP_SUM_ROWS:
1150
+ {
1151
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
1152
+
1153
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
1154
+
1155
+ [encoder setComputePipelineState:pipeline];
1156
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1157
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1158
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1159
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1160
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1161
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1162
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1163
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1164
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1165
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1166
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1167
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1168
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1169
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1170
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1171
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1172
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1173
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1174
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1175
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
1176
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
1177
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
1178
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
1179
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
1180
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
1181
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
1182
+
1183
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1184
+ } break;
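SUM_ROWS dispatches one single-thread threadgroup per row (an ne01 x ne02 x ne03 grid), so each invocation reduces its ne00 row elements serially. A plain-C reference of the same reduction, assuming a contiguous float tensor (the kernel itself handles strided layouts through the nb* arguments passed above):

#include <stdint.h>
#include <stdio.h>

// Reference for SUM_ROWS on a contiguous [ne00 x nrows] float tensor:
// dst holds one float per row, the sum of that row.
static void sum_rows_ref(const float *src, float *dst,
                         int64_t ne00, int64_t nrows) {
    for (int64_t r = 0; r < nrows; ++r) {      // one "threadgroup" per row
        float acc = 0.0f;
        for (int64_t i = 0; i < ne00; ++i) {
            acc += src[r*ne00 + i];
        }
        dst[r] = acc;
    }
}

int main(void) {
    const float src[2*3] = {1, 2, 3, 4, 5, 6};
    float dst[2];
    sum_rows_ref(src, dst, 3, 2);
    printf("%g %g\n", dst[0], dst[1]); // 6 15
    return 0;
}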
1185
+ case GGML_OP_SOFT_MAX:
1186
+ {
1187
+ int nth = 32; // SIMD width
1188
+
1189
+ id<MTLComputePipelineState> pipeline = nil;
1190
+
1191
+ if (ne00%4 == 0) {
1192
+ while (nth < ne00/4 && nth < 256) {
1193
+ nth *= 2;
1986
1194
  }
1987
-
1988
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1989
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1990
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1991
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1992
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
1993
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
1994
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
1995
- [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
1996
- [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
1997
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
1998
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
1999
-
2000
- [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
2001
- } break;
2002
- case GGML_OP_RMS_NORM:
2003
- {
2004
- GGML_ASSERT(ne00 % 4 == 0);
2005
-
2006
- float eps;
2007
- memcpy(&eps, dst->op_params, sizeof(float));
2008
-
2009
- int nth = 32; // SIMD width
2010
-
2011
- while (nth < ne00/4 && nth < 1024) {
1195
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
1196
+ } else {
1197
+ while (nth < ne00 && nth < 1024) {
2012
1198
  nth *= 2;
2013
1199
  }
1200
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
1201
+ }
2014
1202
 
2015
- [encoder setComputePipelineState:ctx->pipeline_rms_norm];
2016
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2017
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2018
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2019
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
2020
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
2021
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2022
-
2023
- const int64_t nrows = ggml_nrows(src0);
2024
-
2025
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2026
- } break;
2027
- case GGML_OP_GROUP_NORM:
2028
- {
2029
- GGML_ASSERT(ne00 % 4 == 0);
2030
-
2031
- //float eps;
2032
- //memcpy(&eps, dst->op_params, sizeof(float));
2033
-
2034
- const float eps = 1e-6f; // TODO: temporarily hardcoded
2035
-
2036
- const int32_t n_groups = ((int32_t *) dst->op_params)[0];
2037
-
2038
- int nth = 32; // SIMD width
2039
-
2040
- //while (nth < ne00/4 && nth < 1024) {
2041
- // nth *= 2;
2042
- //}
2043
-
2044
- [encoder setComputePipelineState:ctx->pipeline_group_norm];
2045
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2046
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2047
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2048
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2049
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2050
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
2051
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
2052
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
2053
- [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
2054
- [encoder setBytes:&eps length:sizeof( float) atIndex:9];
2055
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2056
-
2057
- [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2058
- } break;
2059
- case GGML_OP_NORM:
2060
- {
2061
- float eps;
2062
- memcpy(&eps, dst->op_params, sizeof(float));
2063
-
2064
- const int nth = MIN(256, ne00);
2065
-
2066
- [encoder setComputePipelineState:ctx->pipeline_norm];
2067
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2068
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2069
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2070
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
2071
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
2072
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
1203
+ const float scale = ((float *) dst->op_params)[0];
2073
1204
 
2074
- const int64_t nrows = ggml_nrows(src0);
1205
+ [encoder setComputePipelineState:pipeline];
1206
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1207
+ if (id_src1) {
1208
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1209
+ } else {
1210
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1211
+ }
1212
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1213
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1214
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1215
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1216
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1217
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1218
+
1219
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1220
+ } break;
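The threadgroup size for SOFT_MAX is chosen by doubling from the 32-lane SIMD width until the group covers the row, with a lower cap for the float4 variant since each of its threads reads four elements. The selection logic, extracted into a standalone C sketch:

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

// Threadgroup-size selection as in the SOFT_MAX path above: start at the
// SIMD width and double until the threadgroup covers the row or hits a cap.
// The float4 variant reads 4 elements per thread, hence ne00/4 and the
// smaller cap.
static int soft_max_nth(int64_t ne00, bool vec4) {
    int nth = 32; // SIMD width
    const int64_t work = vec4 ? ne00/4 : ne00;
    const int     cap  = vec4 ? 256    : 1024;
    while (nth < work && nth < cap) {
        nth *= 2;
    }
    return nth;
}

int main(void) {
    printf("ne00=4096, vec4:   nth=%d\n", soft_max_nth(4096, true));  // 256
    printf("ne00=100,  scalar: nth=%d\n", soft_max_nth(100, false));  // 128
    return 0;
}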
1221
+ case GGML_OP_DIAG_MASK_INF:
1222
+ {
1223
+ const int n_past = ((int32_t *)(dst->op_params))[0];
1224
+
1225
+ id<MTLComputePipelineState> pipeline = nil;
1226
+
1227
+ if (ne00%8 == 0) {
1228
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
1229
+ } else {
1230
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
1231
+ }
2075
1232
 
2076
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2077
- } break;
2078
- case GGML_OP_ALIBI:
2079
- {
2080
- GGML_ASSERT((src0t == GGML_TYPE_F32));
1233
+ [encoder setComputePipelineState:pipeline];
1234
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1235
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1236
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1237
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1238
+ [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
2081
1239
 
2082
- const int nth = MIN(1024, ne00);
1240
+ if (ne00%8 == 0) {
1241
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1242
+ }
1243
+ else {
1244
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1245
+ }
1246
+ } break;
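DIAG_MASK_INF picks between two kernels at encode time: when the row length is a multiple of 8, the _8 variant masks eight elements per invocation over a flattened 1D grid; otherwise the scalar variant runs one invocation per element on a 3D grid. A C sketch of that selection (struct grid3d again stands in for MTLSize):

#include <stdint.h>
#include <stdio.h>

struct grid3d { uint64_t x, y, z; };

// Variant selection for DIAG_MASK_INF: the vectorized kernel covers 8
// elements per invocation over a flattened grid when ne00 allows it.
static struct grid3d diag_mask_inf_grid(int64_t ne00, int64_t ne01,
                                        int64_t ne02, int *use_vec8) {
    if (ne00 % 8 == 0) {
        *use_vec8 = 1;
        return (struct grid3d){ (uint64_t)(ne00*ne01*ne02/8), 1, 1 };
    }
    *use_vec8 = 0;
    return (struct grid3d){ (uint64_t)ne00, (uint64_t)ne01, (uint64_t)ne02 };
}

int main(void) {
    int vec8 = 0;
    struct grid3d g = diag_mask_inf_grid(128, 32, 8, &vec8);
    printf("vec8=%d grid=(%llu,%llu,%llu)\n", vec8,
           (unsigned long long)g.x, (unsigned long long)g.y,
           (unsigned long long)g.z);
    return 0;
}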
1247
+ case GGML_OP_MUL_MAT:
1248
+ {
1249
+ GGML_ASSERT(ne00 == ne10);
2083
1250
 
2084
- //const int n_past = ((int32_t *) dst->op_params)[0];
2085
- const int n_head = ((int32_t *) dst->op_params)[1];
2086
- float max_bias;
2087
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
1251
+ // TODO: assert that dim2 and dim3 are contiguous
1252
+ GGML_ASSERT(ne12 % ne02 == 0);
1253
+ GGML_ASSERT(ne13 % ne03 == 0);
2088
1254
 
2089
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
2090
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
2091
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
1255
+ const uint r2 = ne12/ne02;
1256
+ const uint r3 = ne13/ne03;
2092
1257
 
2093
- [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
2094
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2095
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2096
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2097
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2098
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2099
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
2100
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
2101
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
2102
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
2103
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
2104
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
2105
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
2106
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
2107
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
2108
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
2109
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
2110
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
2111
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
2112
- [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
2113
- [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
2114
- [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
1258
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1259
+ // to the matrix-vector kernel
1260
+ int ne11_mm_min = 1;
2115
1261
 
2116
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2117
- } break;
2118
- case GGML_OP_ROPE:
2119
- {
2120
- GGML_ASSERT(ne10 == ne02);
2121
-
2122
- const int nth = MIN(1024, ne00);
2123
-
2124
- const int n_past = ((int32_t *) dst->op_params)[0];
2125
- const int n_dims = ((int32_t *) dst->op_params)[1];
2126
- const int mode = ((int32_t *) dst->op_params)[2];
2127
- // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
2128
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
2129
-
2130
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
2131
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
2132
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
2133
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
2134
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
2135
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
2136
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
1262
+ #if 0
1263
+ // the numbers below are measured on M2 Ultra for 7B and 13B models
1264
+ // these numbers do not translate to other devices or model sizes
1265
+ // TODO: need to find a better approach
1266
+ if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
1267
+ switch (src0t) {
1268
+ case GGML_TYPE_F16: ne11_mm_min = 2; break;
1269
+ case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1270
+ case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1271
+ case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1272
+ case GGML_TYPE_Q4_0:
1273
+ case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1274
+ case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1275
+ case GGML_TYPE_Q5_0: // not tested yet
1276
+ case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1277
+ case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1278
+ case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1279
+ default: ne11_mm_min = 1; break;
1280
+ }
1281
+ }
1282
+ #endif
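The disabled block above records measured break-even batch sizes; the live heuristic below only requires ne11 to exceed ne11_mm_min (plus alignment and GPU-family constraints). A C sketch of the decision, mirroring the condition used below ('mm_capable' is a stand-in for the MTLGPUFamilyApple7 and transpose/type checks):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

// Host-side choice between the tiled matrix-matrix kernel and the
// matrix-vector kernel.
static bool use_mat_mat_kernel(bool mm_capable, int64_t ne00, int64_t ne11,
                               int64_t ne12, bool quantized, int ne11_mm_min) {
    return mm_capable
        && ne00 % 32 == 0 && ne00 >= 64           // K must fit the tile layout
        && (ne11 > ne11_mm_min                    // batch is past break-even
            || (quantized && ne12 > 1));          // quantized multi-batch case
}

int main(void) {
    // single-token decode (ne11 == 1) falls back to the mat-vec kernel
    printf("%d\n", use_mat_mat_kernel(true, 4096, 1, 1, true, 1));   // 0
    // prompt processing with a large batch uses the mat-mat kernel
    printf("%d\n", use_mat_mat_kernel(true, 4096, 512, 1, true, 1)); // 1
    return 0;
}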
1283
+
1284
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1285
+ // AMD GPUs and older A-chips will reuse the matrix-vector multiplication kernel
1286
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1287
+ !ggml_is_transposed(src0) &&
1288
+ !ggml_is_transposed(src1) &&
1289
+ src1t == GGML_TYPE_F32 &&
1290
+ ne00 % 32 == 0 && ne00 >= 64 &&
1291
+ (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
1292
+ //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1293
+
1294
+ id<MTLComputePipelineState> pipeline = nil;
2137
1295
 
2138
1296
  switch (src0->type) {
2139
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
2140
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
2141
- default: GGML_ASSERT(false);
1297
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
1298
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
1299
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
1300
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
1301
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
1302
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
1303
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
1304
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
1305
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
1306
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
1307
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
1308
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
1309
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
1310
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
1311
+ default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1312
+ }
1313
+
1314
+ [encoder setComputePipelineState:pipeline];
1315
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1316
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1317
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1318
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1319
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1320
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
1321
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
1322
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1323
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1324
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1325
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1326
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1327
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1328
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1329
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1330
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1331
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1332
+ } else {
1333
+ int nth0 = 32;
1334
+ int nth1 = 1;
1335
+ int nrows = 1;
1336
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1337
+
1338
+ id<MTLComputePipelineState> pipeline = nil;
1339
+
1340
+ // use custom matrix x vector kernel
1341
+ switch (src0t) {
1342
+ case GGML_TYPE_F32:
1343
+ {
1344
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1345
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
1346
+ nrows = 4;
1347
+ } break;
1348
+ case GGML_TYPE_F16:
1349
+ {
1350
+ nth0 = 32;
1351
+ nth1 = 1;
1352
+ if (src1t == GGML_TYPE_F32) {
1353
+ if (ne11 * ne12 < 4) {
1354
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
1355
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1356
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
1357
+ nrows = ne11;
1358
+ } else {
1359
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
1360
+ nrows = 4;
1361
+ }
1362
+ } else {
1363
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
1364
+ nrows = 4;
1365
+ }
1366
+ } break;
1367
+ case GGML_TYPE_Q4_0:
1368
+ {
1369
+ nth0 = 8;
1370
+ nth1 = 8;
1371
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
1372
+ } break;
1373
+ case GGML_TYPE_Q4_1:
1374
+ {
1375
+ nth0 = 8;
1376
+ nth1 = 8;
1377
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
1378
+ } break;
1379
+ case GGML_TYPE_Q5_0:
1380
+ {
1381
+ nth0 = 8;
1382
+ nth1 = 8;
1383
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
1384
+ } break;
1385
+ case GGML_TYPE_Q5_1:
1386
+ {
1387
+ nth0 = 8;
1388
+ nth1 = 8;
1389
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
1390
+ } break;
1391
+ case GGML_TYPE_Q8_0:
1392
+ {
1393
+ nth0 = 8;
1394
+ nth1 = 8;
1395
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
1396
+ } break;
1397
+ case GGML_TYPE_Q2_K:
1398
+ {
1399
+ nth0 = 2;
1400
+ nth1 = 32;
1401
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
1402
+ } break;
1403
+ case GGML_TYPE_Q3_K:
1404
+ {
1405
+ nth0 = 2;
1406
+ nth1 = 32;
1407
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
1408
+ } break;
1409
+ case GGML_TYPE_Q4_K:
1410
+ {
1411
+ nth0 = 4; //1;
1412
+ nth1 = 8; //32;
1413
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
1414
+ } break;
1415
+ case GGML_TYPE_Q5_K:
1416
+ {
1417
+ nth0 = 2;
1418
+ nth1 = 32;
1419
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
1420
+ } break;
1421
+ case GGML_TYPE_Q6_K:
1422
+ {
1423
+ nth0 = 2;
1424
+ nth1 = 32;
1425
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
1426
+ } break;
1427
+ case GGML_TYPE_IQ2_XXS:
1428
+ {
1429
+ nth0 = 4;
1430
+ nth1 = 16;
1431
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
1432
+ } break;
1433
+ case GGML_TYPE_IQ2_XS:
1434
+ {
1435
+ nth0 = 4;
1436
+ nth1 = 16;
1437
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
1438
+ } break;
1439
+ default:
1440
+ {
1441
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1442
+ GGML_ASSERT(false && "not implemented");
1443
+ }
2142
1444
  };
2143
1445
 
2144
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2145
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2146
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2147
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
2148
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
2149
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
2150
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
2151
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
2152
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2153
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2154
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2155
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
2156
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
2157
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
2158
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
2159
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
2160
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
2161
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
2162
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
2163
- [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
2164
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
2165
- [encoder setBytes:&mode length:sizeof( int) atIndex:21];
2166
- [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
2167
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
2168
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
2169
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
2170
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
2171
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
2172
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
1446
+ if (ggml_is_quantized(src0t)) {
1447
+ GGML_ASSERT(ne00 >= nth0*nth1);
1448
+ }
2173
1449
 
2174
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2175
- } break;
2176
- case GGML_OP_IM2COL:
2177
- {
2178
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
2179
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
2180
- GGML_ASSERT( dst->type == GGML_TYPE_F16);
1450
+ [encoder setComputePipelineState:pipeline];
1451
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1452
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1453
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1454
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1455
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1456
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1457
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1458
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1459
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1460
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
1461
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
1462
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
1463
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
1464
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
1465
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1466
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1467
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1468
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1469
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1470
+
1471
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1472
+ src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1473
+ src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1474
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1475
+ }
1476
+ else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
1477
+ const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1478
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1479
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1480
+ }
1481
+ else if (src0t == GGML_TYPE_Q4_K) {
1482
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1483
+ }
1484
+ else if (src0t == GGML_TYPE_Q3_K) {
1485
+ #ifdef GGML_QKK_64
1486
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1487
+ #else
1488
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1489
+ #endif
1490
+ }
1491
+ else if (src0t == GGML_TYPE_Q5_K) {
1492
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1493
+ }
1494
+ else if (src0t == GGML_TYPE_Q6_K) {
1495
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1496
+ } else {
1497
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
1498
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1499
+ }
1500
+ }
1501
+ } break;
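The two dispatch shapes in this case encode the work decomposition: the mat-mat kernel assigns one 128-thread threadgroup (with 8 KiB of threadgroup staging memory) to each 64-row by 32-column tile of the output, while the quantized mat-vec kernels assign a small fixed number of output rows per threadgroup (8 for the Q4_0-class types, 4 for Q4_K, 2 for Q6_K, with the IQ2 types also reserving 256*8+128 or 512*8+128 bytes of threadgroup memory for their lookup tables). The grid arithmetic, written out in C:

#include <stdint.h>
#include <stdio.h>

struct grid3d { uint64_t x, y, z; };

static int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1)/b; }

// Mat-mat kernel: one threadgroup per 64 x 32 tile of the [ne01 x ne11]
// output, replicated over the ne12*ne13 batch dimension.
static struct grid3d mat_mat_grid(int64_t ne01, int64_t ne11,
                                  int64_t ne12, int64_t ne13) {
    return (struct grid3d){
        (uint64_t)ceil_div(ne11, 32),
        (uint64_t)ceil_div(ne01, 64),
        (uint64_t)(ne12*ne13),
    };
}

// Quantized mat-vec kernels: 'rows_per_tg' output rows per threadgroup
// (8, 4, or 2 depending on the quantization type, as in the code above).
static struct grid3d mat_vec_grid(int64_t ne01, int64_t ne11,
                                  int64_t ne12, int64_t ne13,
                                  int rows_per_tg) {
    return (struct grid3d){
        (uint64_t)ceil_div(ne01, rows_per_tg),
        (uint64_t)ne11,
        (uint64_t)(ne12*ne13),
    };
}

int main(void) {
    struct grid3d mm = mat_mat_grid(4096, 512, 1, 1);
    struct grid3d mv = mat_vec_grid(4096, 1, 1, 1, 8);
    printf("mat-mat: (%llu,%llu,%llu)  mat-vec: (%llu,%llu,%llu)\n",
           (unsigned long long)mm.x, (unsigned long long)mm.y, (unsigned long long)mm.z,
           (unsigned long long)mv.x, (unsigned long long)mv.y, (unsigned long long)mv.z);
    return 0;
}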
1502
+ case GGML_OP_MUL_MAT_ID:
1503
+ {
1504
+ //GGML_ASSERT(ne00 == ne10);
1505
+ //GGML_ASSERT(ne03 == ne13);
2181
1506
 
2182
- const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
2183
- const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
2184
- const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
2185
- const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
2186
- const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
2187
- const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
2188
- const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
1507
+ GGML_ASSERT(src0t == GGML_TYPE_I32);
2189
1508
 
2190
- const int32_t N = src1->ne[is_2D ? 3 : 2];
2191
- const int32_t IC = src1->ne[is_2D ? 2 : 1];
2192
- const int32_t IH = is_2D ? src1->ne[1] : 1;
2193
- const int32_t IW = src1->ne[0];
1509
+ const int n_as = ((int32_t *) dst->op_params)[1];
2194
1510
 
2195
- const int32_t KH = is_2D ? src0->ne[1] : 1;
2196
- const int32_t KW = src0->ne[0];
1511
+ // TODO: make this more general
1512
+ GGML_ASSERT(n_as <= 8);
2197
1513
 
2198
- const int32_t OH = is_2D ? dst->ne[2] : 1;
2199
- const int32_t OW = dst->ne[1];
1514
+ // max size of the src1ids array in the kernel stack
1515
+ GGML_ASSERT(ne11 <= 512);
2200
1516
 
2201
- const int32_t CHW = IC * KH * KW;
1517
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2];
2202
1518
 
2203
- const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
2204
- const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
1519
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
1520
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
1521
+ const int64_t ne22 = src2 ? src2->ne[2] : 0;
1522
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
2205
1523
 
2206
- switch (src0->type) {
2207
- case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
2208
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
2209
- default: GGML_ASSERT(false);
2210
- };
1524
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
1525
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
1526
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
1527
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
2211
1528
 
2212
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
2213
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2214
- [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
2215
- [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
2216
- [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
2217
- [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
2218
- [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
2219
- [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
2220
- [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
2221
- [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
2222
- [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
2223
- [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
2224
- [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
2225
-
2226
- [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
2227
- } break;
2228
- case GGML_OP_UPSCALE:
2229
- {
2230
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2231
-
2232
- const int sf = dst->op_params[0];
2233
-
2234
- [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
2235
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2236
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2237
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2238
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2239
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2240
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2241
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2242
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2243
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2244
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2245
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2246
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2247
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2248
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2249
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2250
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2251
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2252
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2253
- [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
2254
-
2255
- const int nth = MIN((int) ctx->pipeline_upscale_f32.maxTotalThreadsPerThreadgroup, ne0);
2256
-
2257
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2258
- } break;
2259
- case GGML_OP_PAD:
2260
- {
2261
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2262
-
2263
- [encoder setComputePipelineState:ctx->pipeline_pad_f32];
2264
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2265
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2266
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2267
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2268
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2269
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2270
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2271
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2272
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2273
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2274
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2275
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2276
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2277
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2278
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2279
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2280
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2281
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2282
-
2283
- const int nth = MIN(1024, ne0);
2284
-
2285
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2286
- } break;
2287
- case GGML_OP_ARGSORT:
2288
- {
2289
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2290
- GGML_ASSERT( dst->type == GGML_TYPE_I32);
2291
-
2292
- const int nrows = ggml_nrows(src0);
2293
-
2294
- enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
2295
-
2296
- switch (order) {
2297
- case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
2298
- case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
2299
- default: GGML_ASSERT(false);
2300
- };
1529
+ const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
2301
1530
 
2302
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2303
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2304
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1531
+ GGML_ASSERT(!ggml_is_transposed(src2));
1532
+ GGML_ASSERT(!ggml_is_transposed(src1));
2305
1533
 
2306
- [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
2307
- } break;
2308
- case GGML_OP_LEAKY_RELU:
2309
- {
2310
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
1534
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
2311
1535
 
2312
- float slope;
2313
- memcpy(&slope, dst->op_params, sizeof(float));
1536
+ const uint r2 = ne12/ne22;
1537
+ const uint r3 = ne13/ne23;
2314
1538
 
2315
- [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
2316
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2317
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2318
- [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
1539
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1540
+ // to the matrix-vector kernel
1541
+ int ne11_mm_min = n_as;
2319
1542
 
2320
- const int64_t n = ggml_nelements(dst);
1543
+ const int idx = ((int32_t *) dst->op_params)[0];
2321
1544
 
2322
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2323
- } break;
2324
- case GGML_OP_DUP:
2325
- case GGML_OP_CPY:
2326
- case GGML_OP_CONT:
2327
- {
2328
- GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
1545
+ // batch size
1546
+ GGML_ASSERT(ne01 == ne11);
2329
1547
 
2330
- int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
1548
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1549
+ // AMD GPUs and older A-chips will reuse the matrix-vector multiplication kernel
1550
+ // !!!
1551
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
1552
+ // indirect matrix multiplication
1553
+ // !!!
1554
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1555
+ ne20 % 32 == 0 && ne20 >= 64 &&
1556
+ ne11 > ne11_mm_min) {
2331
1557
 
2332
- switch (src0t) {
1558
+ id<MTLComputePipelineState> pipeline = nil;
1559
+
1560
+ switch (src2->type) {
1561
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
1562
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
1563
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
1564
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
1565
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
1566
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
1567
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
1568
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
1569
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
1570
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
1571
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
1572
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
1573
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
1574
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
1575
+ default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1576
+ }
1577
+
1578
+ [encoder setComputePipelineState:pipeline];
1579
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1580
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1581
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1582
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1583
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1584
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
1585
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1586
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
1587
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1588
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
1589
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
1590
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
1591
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
1592
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1593
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
1594
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1595
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
1596
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
1597
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
1598
+ // TODO: how to make this an array? read Metal docs
1599
+ for (int j = 0; j < 8; ++j) {
1600
+ // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1601
+ struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1602
+
1603
+ size_t offs_src_cur = 0;
1604
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1605
+
1606
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
1607
+ }
1608
+
1609
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1610
+
1611
+ [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1612
+ } else {
1613
+ int nth0 = 32;
1614
+ int nth1 = 1;
1615
+ int nrows = 1;
1616
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1617
+
1618
+ id<MTLComputePipelineState> pipeline = nil;
1619
+
1620
+ // use custom matrix x vector kernel
1621
+ switch (src2t) {
2333
1622
  case GGML_TYPE_F32:
2334
1623
  {
2335
- GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
2336
-
2337
- switch (dstt) {
2338
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
2339
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
2340
- case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
2341
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
2342
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
2343
- //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
2344
- //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
2345
- default: GGML_ASSERT(false && "not implemented");
2346
- };
1624
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1625
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
2347
1626
  } break;
2348
1627
  case GGML_TYPE_F16:
2349
1628
  {
2350
- switch (dstt) {
2351
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
2352
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
2353
- default: GGML_ASSERT(false && "not implemented");
2354
- };
1629
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1630
+ nth0 = 32;
1631
+ nth1 = 1;
1632
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
1633
+ } break;
1634
+ case GGML_TYPE_Q4_0:
1635
+ {
1636
+ nth0 = 8;
1637
+ nth1 = 8;
1638
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
1639
+ } break;
1640
+ case GGML_TYPE_Q4_1:
1641
+ {
1642
+ nth0 = 8;
1643
+ nth1 = 8;
1644
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
2355
1645
  } break;
2356
- default: GGML_ASSERT(false && "not implemented");
1646
+ case GGML_TYPE_Q5_0:
1647
+ {
1648
+ nth0 = 8;
1649
+ nth1 = 8;
1650
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
1651
+ } break;
1652
+ case GGML_TYPE_Q5_1:
1653
+ {
1654
+ nth0 = 8;
1655
+ nth1 = 8;
1656
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
1657
+ } break;
1658
+ case GGML_TYPE_Q8_0:
1659
+ {
1660
+ nth0 = 8;
1661
+ nth1 = 8;
1662
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
1663
+ } break;
1664
+ case GGML_TYPE_Q2_K:
1665
+ {
1666
+ nth0 = 2;
1667
+ nth1 = 32;
1668
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
1669
+ } break;
1670
+ case GGML_TYPE_Q3_K:
1671
+ {
1672
+ nth0 = 2;
1673
+ nth1 = 32;
1674
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
1675
+ } break;
1676
+ case GGML_TYPE_Q4_K:
1677
+ {
1678
+ nth0 = 4; //1;
1679
+ nth1 = 8; //32;
1680
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
1681
+ } break;
1682
+ case GGML_TYPE_Q5_K:
1683
+ {
1684
+ nth0 = 2;
1685
+ nth1 = 32;
1686
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
1687
+ } break;
1688
+ case GGML_TYPE_Q6_K:
1689
+ {
1690
+ nth0 = 2;
1691
+ nth1 = 32;
1692
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
1693
+ } break;
1694
+ case GGML_TYPE_IQ2_XXS:
1695
+ {
1696
+ nth0 = 4;
1697
+ nth1 = 16;
1698
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
1699
+ } break;
1700
+ case GGML_TYPE_IQ2_XS:
1701
+ {
1702
+ nth0 = 4;
1703
+ nth1 = 16;
1704
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
1705
+ } break;
1706
+ default:
1707
+ {
1708
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
1709
+ GGML_ASSERT(false && "not implemented");
1710
+ }
1711
+ };
1712
+
1713
+ if (ggml_is_quantized(src2t)) {
1714
+ GGML_ASSERT(ne20 >= nth0*nth1);
2357
1715
  }
2358
1716
 
2359
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2360
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2361
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2362
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2363
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2364
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
2365
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
2366
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
2367
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
2368
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
2369
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
2370
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
2371
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
2372
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
2373
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
2374
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
2375
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
2376
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1717
+ const int64_t _ne1 = 1; // kernels need a reference in constant memory
2377
1718
 
2378
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2379
- } break;
2380
- default:
2381
- {
2382
- GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
2383
- GGML_ASSERT(false);
1719
+ [encoder setComputePipelineState:pipeline];
1720
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1721
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1722
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1723
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1724
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1725
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1726
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
1727
+ [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
1728
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
1729
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
1730
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1731
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
1732
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1733
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1734
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1735
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1736
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1737
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1738
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
1739
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1740
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
1741
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1742
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1743
+ // TODO: how to make this an array? read Metal docs
1744
+ for (int j = 0; j < 8; ++j) {
1745
+ // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1746
+ struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1747
+
1748
+ size_t offs_src_cur = 0;
1749
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1750
+
1751
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
1752
+ }
1753
+
1754
+ if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1755
+ src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1756
+ src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
1757
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1758
+ }
1759
+ else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
1760
+ const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1761
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1762
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1763
+ }
1764
+ else if (src2t == GGML_TYPE_Q4_K) {
1765
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1766
+ }
1767
+ else if (src2t == GGML_TYPE_Q3_K) {
1768
+ #ifdef GGML_QKK_64
1769
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1770
+ #else
1771
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1772
+ #endif
1773
+ }
1774
+ else if (src2t == GGML_TYPE_Q5_K) {
1775
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1776
+ }
1777
+ else if (src2t == GGML_TYPE_Q6_K) {
1778
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1779
+ } else {
1780
+ const int64_t ny = (_ne1 + nrows - 1)/nrows;
1781
+ [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1782
+ }
1783
+ }
1784
+ } break;
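A detail worth calling out in the MUL_MAT_ID case: the kernels declare a fixed set of 8 expert-weight buffer slots, so the host binds all 8 and wraps the index with j % n_as, ensuring no argument slot is left uninitialized when fewer than 8 experts are present (the slots start at buffer index 19 in the mat-mat branch and 23 in the mat-vec branch). A C sketch of that binding loop (bind_buffer is a hypothetical stand-in for [encoder setBuffer:offset:atIndex:]):

#include <stdio.h>

// Fill all 8 kernel buffer slots, wrapping with j % n_as so none is
// left uninitialized when there are fewer than 8 experts.
static void bind_experts(const void *experts[], int n_as, int base_index,
                         void (*bind_buffer)(const void *buf, int index)) {
    for (int j = 0; j < 8; ++j) {
        bind_buffer(experts[j % n_as], base_index + j);
    }
}

static void print_bind(const void *buf, int index) {
    printf("slot %d <- expert %s\n", index, (const char *)buf);
}

int main(void) {
    const void *experts[3] = { "E0", "E1", "E2" };
    bind_experts(experts, 3, 19, print_bind); // slots 19..26: E0 E1 E2 E0 ...
    return 0;
}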
1785
+ case GGML_OP_GET_ROWS:
1786
+ {
1787
+ id<MTLComputePipelineState> pipeline = nil;
1788
+
1789
+ switch (src0->type) {
1790
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
1791
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
1792
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
1793
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
1794
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
1795
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
1796
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
1797
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
1798
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
1799
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
1800
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
1801
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
1802
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
1803
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
1804
+ case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
1805
+ default: GGML_ASSERT(false && "not implemented");
1806
+ }
1807
+
1808
+ [encoder setComputePipelineState:pipeline];
1809
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1810
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1811
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1812
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1813
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
1814
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
1815
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
1816
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
1817
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
1818
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
1819
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
1820
+
1821
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
1822
+ } break;
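GET_ROWS gathers whole rows: row i of dst is row src1[i] of src0, dequantized when src0 is a quantized type. A plain-C reference for the f32 case (dequantization omitted), assuming contiguous rows:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

// Reference semantics of GET_ROWS for an f32 source: row i of dst is a
// copy of row ids[i] of src.
static void get_rows_f32_ref(const float *src, int64_t ne00,
                             const int32_t *ids, int64_t n_ids,
                             float *dst) {
    for (int64_t i = 0; i < n_ids; ++i) {
        memcpy(dst + i*ne00, src + (int64_t)ids[i]*ne00, ne00*sizeof(float));
    }
}

int main(void) {
    const float   src[3*2] = {0, 1, 10, 11, 20, 21};
    const int32_t ids[2]   = {2, 0};
    float dst[2*2];
    get_rows_f32_ref(src, 2, ids, 2, dst);
    printf("%g %g | %g %g\n", dst[0], dst[1], dst[2], dst[3]); // 20 21 | 0 1
    return 0;
}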
1823
+ case GGML_OP_RMS_NORM:
1824
+ {
1825
+ GGML_ASSERT(ne00 % 4 == 0);
1826
+
1827
+ float eps;
1828
+ memcpy(&eps, dst->op_params, sizeof(float));
1829
+
1830
+ int nth = 32; // SIMD width
1831
+
1832
+ while (nth < ne00/4 && nth < 1024) {
1833
+ nth *= 2;
1834
+ }
1835
+
1836
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
1837
+
1838
+ [encoder setComputePipelineState:pipeline];
1839
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1840
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1841
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1842
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1843
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1844
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1845
+
1846
+ const int64_t nrows = ggml_nrows(src0);
1847
+
1848
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1849
+ } break;
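For reference, RMS_NORM divides each row by the root of its mean square plus eps (eps is read from dst->op_params above); on the GPU the sum of squares is reduced cooperatively by the nth threads using the 32-float threadgroup buffer set above. A single-row C reference of the math:

#include <math.h>
#include <stdio.h>

// Reference for RMS_NORM over one row of n elements.
static void rms_norm_row_ref(const float *x, float *y, int n, float eps) {
    float sum2 = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum2 += x[i]*x[i];
    }
    const float scale = 1.0f/sqrtf(sum2/n + eps);
    for (int i = 0; i < n; ++i) {
        y[i] = x[i]*scale;
    }
}

int main(void) {
    const float x[4] = {1, 2, 3, 4};
    float y[4];
    rms_norm_row_ref(x, y, 4, 1e-5f);
    printf("%f %f %f %f\n", y[0], y[1], y[2], y[3]);
    return 0;
}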
+ case GGML_OP_GROUP_NORM:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+
+ //float eps;
+ //memcpy(&eps, dst->op_params, sizeof(float));
+
+ const float eps = 1e-6f; // TODO: temporarily hardcoded
+
+ const int32_t n_groups = ((int32_t *) dst->op_params)[0];
+
+ int nth = 32; // SIMD width
+
+ //while (nth < ne00/4 && nth < 1024) {
+ // nth *= 2;
+ //}
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:9];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_NORM:
+ {
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ const int nth = MIN(256, ne00);
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
+
+ const int64_t nrows = ggml_nrows(src0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ALIBI:
+ {
+ GGML_ASSERT((src0t == GGML_TYPE_F32));
+
+ const int nth = MIN(1024, ne00);
+
+ //const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_head = ((int32_t *) dst->op_params)[1];
+ float max_bias;
+ memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
+
+ const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+ const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
+ [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
+ [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
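
Note: only m0, m1 and n_heads_log2_floor are sent to the ALIBI_F32 kernel; the per-head slope is expanded on the GPU. For reference, the usual ALiBi expansion, as in ggml's CPU path (we assume the Metal kernel matches; the helper name is ours):

    #include <math.h>

    // the first 2^floor(log2(n_head)) heads take successive powers of m0;
    // the remaining heads (when n_head is not a power of two) take odd powers of m1
    static float alibi_slope(int h, int n_heads_log2_floor, float m0, float m1) {
        if (h < n_heads_log2_floor) {
            return powf(m0, (float)(h + 1));
        }
        return powf(m1, (float)(2*(h - n_heads_log2_floor) + 1));
    }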
+ case GGML_OP_ROPE:
+ {
+ GGML_ASSERT(ne10 == ne02);
+
+ const int nth = MIN(1024, ne00);
+
+ const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
+ default: GGML_ASSERT(false);
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
+ [encoder setBytes:&mode length:sizeof( int) atIndex:21];
+ [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
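
Note: dst->op_params is a raw int32 array that doubles as storage for floats (slots 5-10 above). The memcpy reads are deliberate: a cast such as *(const float *) &params[5] would violate strict aliasing. The pattern, isolated (helper name is ours):

    #include <stdint.h>
    #include <string.h>

    // bit-copy a float out of an int32 parameter slot without type punning
    static float op_param_f32(const int32_t * op_params, int slot) {
        float v;
        memcpy(&v, op_params + slot, sizeof(float));
        return v;
    }

    // e.g. the ROPE case above effectively does: freq_base = op_param_f32(params, 5);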
+ case GGML_OP_IM2COL:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+ const int32_t N = src1->ne[is_2D ? 3 : 2];
+ const int32_t IC = src1->ne[is_2D ? 2 : 1];
+ const int32_t IH = is_2D ? src1->ne[1] : 1;
+ const int32_t IW = src1->ne[0];
+
+ const int32_t KH = is_2D ? src0->ne[1] : 1;
+ const int32_t KW = src0->ne[0];
+
+ const int32_t OH = is_2D ? dst->ne[2] : 1;
+ const int32_t OW = dst->ne[1];
+
+ const int32_t CHW = IC * KH * KW;
+
+ const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+ const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
+ default: GGML_ASSERT(false);
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
+ [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
+ [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
+ [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
+ [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
+ [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
+ [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
+ [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
+ [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
+ [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
+ [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+ } break;
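
Note: the IM2COL case folds conv1d and conv2d into one path via is_2D, reading sizes off ggml's ne arrays (ne[0] is the innermost dimension). The shape bookkeeping above, restated as a small helper for clarity (struct and function names are ours):

    #include <stdbool.h>
    #include <stdint.h>

    struct im2col_dims { int32_t N, IC, IH, IW, KH, KW, OH, OW, CHW; };

    // src0 = kernel, src1 = input; in 1D mode the H axis collapses to 1
    static struct im2col_dims im2col_dims_make(const int64_t * src0_ne,
                                               const int64_t * src1_ne,
                                               const int64_t * dst_ne,
                                               bool is_2D) {
        struct im2col_dims d;
        d.N   = (int32_t) src1_ne[is_2D ? 3 : 2]; // batch
        d.IC  = (int32_t) src1_ne[is_2D ? 2 : 1]; // input channels
        d.IH  = is_2D ? (int32_t) src1_ne[1] : 1; // input height
        d.IW  = (int32_t) src1_ne[0];             // input width
        d.KH  = is_2D ? (int32_t) src0_ne[1] : 1; // kernel height
        d.KW  = (int32_t) src0_ne[0];             // kernel width
        d.OH  = is_2D ? (int32_t) dst_ne[2] : 1;  // output height
        d.OW  = (int32_t) dst_ne[1];              // output width
        d.CHW = d.IC * d.KH * d.KW;               // one output column per (c, kh, kw)
        return d;
    }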
+ case GGML_OP_UPSCALE:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ const int sf = dst->op_params[0];
+
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+ [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
+
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_PAD:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ARGSORT:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+ const int nrows = ggml_nrows(src0);
+
+ enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (order) {
+ case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
+ case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
+ default: GGML_ASSERT(false);
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+ } break;
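
Note: the ARGSORT dispatch runs one threadgroup per row with ne00 threads, so each row is sorted entirely on-chip; this implicitly caps the sortable row length at the device's maximum threadgroup size. For reference, the expected output semantics in plain C for the ascending case (illustrative only; the file-scope variable exists solely because C's qsort takes no context argument):

    #include <stdint.h>
    #include <stdlib.h>

    static const float * g_row; // comparator context

    static int cmp_asc(const void * a, const void * b) {
        const float x = g_row[*(const int32_t *) a];
        const float y = g_row[*(const int32_t *) b];
        return (x > y) - (x < y); // negate for GGML_SORT_DESC
    }

    // writes into idx the permutation that sorts row ascending
    static void argsort_row_asc(const float * row, int32_t * idx, int64_t n) {
        for (int64_t i = 0; i < n; ++i) idx[i] = (int32_t) i;
        g_row = row;
        qsort(idx, (size_t) n, sizeof(int32_t), cmp_asc);
    }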
+ case GGML_OP_LEAKY_RELU:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ float slope;
+ memcpy(&slope, dst->op_params, sizeof(float));
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_DUP:
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ {
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+
+ int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (src0t) {
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
+
+ switch (dstt) {
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
+ //case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
+ //case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
+ default: GGML_ASSERT(false && "not implemented");
+ };
+ } break;
+ case GGML_TYPE_F16:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
+ default: GGML_ASSERT(false && "not implemented");
+ };
+ } break;
+ default: GGML_ASSERT(false && "not implemented");
 }
- }
- }

- if (encoder != nil) {
- [encoder endEncoding];
- encoder = nil;
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
 }

- [command_buffer commit];
- });
- }
+ #ifndef GGML_METAL_NDEBUG
+ [encoder popDebugGroup];
+ #endif
+ }
+
+ if (encoder != nil) {
+ [encoder endEncoding];
+ encoder = nil;
+ }

- // wait for all threads to finish
- dispatch_barrier_sync(ctx->d_queue, ^{});
+ [command_buffer commit];
+ });

- // check status of command buffers
+ // Wait for completion and check status of each command buffer
 // needed to detect if the device ran out-of-memory for example (#1881)
- for (int i = 0; i < n_cb; i++) {
- [ctx->command_buffers[i] waitUntilCompleted];

- MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
+ for (int i = 0; i < n_cb; ++i) {
+ id<MTLCommandBuffer> command_buffer = command_buffers[i];
+ [command_buffer waitUntilCompleted];
+
+ MTLCommandBufferStatus status = [command_buffer status];
 if (status != MTLCommandBufferStatusCompleted) {
 GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
- GGML_ASSERT(false);
+ return false;
 }
 }

+ return true;
 }
 }
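
Note: with the tail of the status loop changed from GGML_ASSERT(false) to return false (and a final return true added), ggml_metal_graph_compute now reports device-side failures such as the out-of-memory case referenced above (#1881) instead of aborting the process. A sketch of the call pattern this enables, using names from this file (illustrative, not part of the diff):

    // propagate a failed command buffer to the caller instead of crashing
    static bool compute_or_report(struct ggml_metal_context * ctx, struct ggml_cgraph * gf) {
        if (!ggml_metal_graph_compute(ctx, gf)) {
            GGML_METAL_LOG_ERROR("%s: graph computation failed\n", __func__);
            return false; // caller may retry or fall back to another backend
        }
        return true;
    }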
 
@@ -2441,13 +2291,13 @@ static void ggml_backend_metal_free_device(void) {
 }
 }

- static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
- struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+ GGML_CALL static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
+ return "Metal";

- return ctx->all_data;
+ UNUSED(buffer);
 }

- static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ GGML_CALL static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
 struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;

 for (int i = 0; i < ctx->n_buffers; i++) {
@@ -2462,50 +2312,80 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
 free(ctx);
 }

- static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- memcpy((char *)tensor->data + offset, data, size);
+ GGML_CALL static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;

- UNUSED(buffer);
+ return ctx->all_data;
 }

- static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- memcpy(data, (const char *)tensor->data + offset, size);
+ GGML_CALL static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ memcpy((char *)tensor->data + offset, data, size);

 UNUSED(buffer);
 }

- static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
- ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+ GGML_CALL static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ memcpy(data, (const char *)tensor->data + offset, size);

 UNUSED(buffer);
 }

- static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
- ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+ GGML_CALL static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+ if (ggml_backend_buffer_is_host(src->buffer)) {
+ memcpy(dst->data, src->data, ggml_nbytes(src));
+ return true;
+ }
+ return false;

 UNUSED(buffer);
 }
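
Note: the two directional hooks cpy_tensor_from/cpy_tensor_to are merged into a single bool-returning cpy_tensor: true means the buffer handled the copy (here, a plain memcpy whenever the source buffer is host-visible, which Metal's shared-memory buffers are), false tells generic code to stage the copy itself. A sketch of that caller-side contract (illustrative; field names follow the interface table below, not the ggml-backend source):

    // try the buffer's direct path first, then fall back to a staged copy
    static void copy_tensor_sketch(ggml_backend_buffer_t buf,
                                   const struct ggml_tensor * src,
                                   struct ggml_tensor * dst) {
        if (buf->iface.cpy_tensor && buf->iface.cpy_tensor(buf, src, dst)) {
            return; // fast path taken
        }
        ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
    }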
 
- static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ GGML_CALL static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
 struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;

 memset(ctx->all_data, value, ctx->all_size);
 }

 static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
+ /* .get_name = */ ggml_backend_metal_buffer_get_name,
 /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
 /* .get_base = */ ggml_backend_metal_buffer_get_base,
 /* .init_tensor = */ NULL,
 /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
 /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
- /* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
- /* .cpy_tensor_to = */ ggml_backend_metal_buffer_cpy_tensor_to,
+ /* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor,
 /* .clear = */ ggml_backend_metal_buffer_clear,
+ /* .reset = */ NULL,
 };

 // default buffer type

- static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+ return "Metal";
+
+ UNUSED(buft);
+ }
+
+ static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
+ #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
+ if (@available(macOS 10.12, iOS 16.0, *)) {
+ GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
+ device.currentAllocatedSize / 1024.0 / 1024.0,
+ device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+
+ if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
+ GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
+ } else {
+ GGML_METAL_LOG_INFO("\n");
+ }
+ } else {
+ GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
+ }
+ #endif
+ UNUSED(device);
+ }
+
+ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
 struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));

 const size_t size_page = sysconf(_SC_PAGESIZE);
@@ -2537,46 +2417,32 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
 }

 GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
-
-
- #if TARGET_OS_OSX
- GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
- device.currentAllocatedSize / 1024.0 / 1024.0,
- device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-
- if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
- GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
- } else {
- GGML_METAL_LOG_INFO("\n");
- }
- #else
- GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
- #endif
-
+ ggml_backend_metal_log_allocated_size(device);

 return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
 }

- static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+ GGML_CALL static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
 return 32;
 UNUSED(buft);
 }

- static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+ GGML_CALL static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
 return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);

 UNUSED(buft);
 }

- static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+ GGML_CALL static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
 return true;

 UNUSED(buft);
 }

- ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
+ GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
 static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
 /* .iface = */ {
+ /* .get_name = */ ggml_backend_metal_buffer_type_get_name,
 /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
 /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
 /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
@@ -2591,7 +2457,7 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {

 // buffer from ptr

- ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
+ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
 struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));

 ctx->all_data = data;
@@ -2600,6 +2466,14 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
 ctx->n_buffers = 0;

 const size_t size_page = sysconf(_SC_PAGESIZE);
+
+ // page-align the data ptr
+ {
+ const uintptr_t offs = (uintptr_t) data % size_page;
+ data = (void *) ((char *) data - offs);
+ size += offs;
+ }
+
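
Note: new in this version, ggml_backend_metal_buffer_from_ptr tolerates pointers that are not page-aligned: the start is rounded down to a page boundary and the size grown by the same offset, since Metal's no-copy buffer path requires page-aligned ranges (the size itself is then rounded up just below). The arithmetic, isolated (helper name is ours):

    #include <stdint.h>
    #include <unistd.h>

    static void page_align_down(void ** data, size_t * size) {
        const size_t size_page = (size_t) sysconf(_SC_PAGESIZE);
        const uintptr_t offs = (uintptr_t) *data % size_page;
        *data  = (void *) ((char *) *data - offs); // round the start down
        *size += offs;                             // still cover all original bytes
    }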
 size_t size_aligned = size;
 if ((size_aligned % size_page) != 0) {
 size_aligned += (size_page - (size_aligned % size_page));
@@ -2651,63 +2525,50 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
 }
 }

- #if TARGET_OS_OSX
- GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
- device.currentAllocatedSize / 1024.0 / 1024.0,
- device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-
- if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
- GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
- } else {
- GGML_METAL_LOG_INFO("\n");
- }
- #else
- GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
- #endif
+ ggml_backend_metal_log_allocated_size(device);

 return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
 }

 // backend

- static const char * ggml_backend_metal_name(ggml_backend_t backend) {
+ GGML_CALL static const char * ggml_backend_metal_name(ggml_backend_t backend) {
 return "Metal";

 UNUSED(backend);
 }

- static void ggml_backend_metal_free(ggml_backend_t backend) {
+ GGML_CALL static void ggml_backend_metal_free(ggml_backend_t backend) {
 struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
 ggml_metal_free(ctx);
 free(backend);
 }

- static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
+ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
 return ggml_backend_metal_buffer_type();

 UNUSED(backend);
 }

- static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ GGML_CALL static bool ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
 struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;

- ggml_metal_graph_compute(metal_ctx, cgraph);
+ return ggml_metal_graph_compute(metal_ctx, cgraph);
 }

- static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
- return ggml_metal_supports_op(op);
+ GGML_CALL static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+ struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;

- UNUSED(backend);
+ return ggml_metal_supports_op(metal_ctx, op);
 }

- static struct ggml_backend_i metal_backend_i = {
+ static struct ggml_backend_i ggml_backend_metal_i = {
 /* .get_name = */ ggml_backend_metal_name,
 /* .free = */ ggml_backend_metal_free,
 /* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
 /* .set_tensor_async = */ NULL,
 /* .get_tensor_async = */ NULL,
- /* .cpy_tensor_from_async = */ NULL,
- /* .cpy_tensor_to_async = */ NULL,
+ /* .cpy_tensor_async = */ NULL,
 /* .synchronize = */ NULL,
 /* .graph_plan_create = */ NULL,
 /* .graph_plan_free = */ NULL,
@@ -2716,6 +2577,11 @@ static struct ggml_backend_i metal_backend_i = {
 /* .supports_op = */ ggml_backend_metal_supports_op,
 };

+ void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
+ ggml_metal_log_callback = log_callback;
+ ggml_metal_log_user_data = user_data;
+ }
+
 ggml_backend_t ggml_backend_metal_init(void) {
 struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);

@@ -2726,7 +2592,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
 ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));

 *metal_backend = (struct ggml_backend) {
- /* .interface = */ metal_backend_i,
+ /* .interface = */ ggml_backend_metal_i,
 /* .context = */ ctx,
 };

@@ -2734,7 +2600,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
 }

 bool ggml_backend_is_metal(ggml_backend_t backend) {
- return backend->iface.get_name == ggml_backend_metal_name;
+ return backend && backend->iface.get_name == ggml_backend_metal_name;
 }

 void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
@@ -2742,7 +2608,7 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {

 struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;

- ggml_metal_set_n_cb(ctx, n_cb);
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
 }

 bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
@@ -2753,9 +2619,9 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
 return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
 }

- ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
+ GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning

- ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
+ GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
 return ggml_backend_metal_init();

 GGML_UNUSED(params);