llama_cpp 0.12.1 → 0.12.2

This diff compares the publicly released contents of llama_cpp 0.12.1 and 0.12.2 as they appear in their public registry, and is provided for informational purposes only.
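The bulk of this release is a rework of the Metal backend (ggml-metal.m). The per-kernel `function_##name` / `pipeline_##name` fields generated by `GGML_METAL_DECL_KERNEL` are replaced with a single `kernels` array indexed by a new `enum ggml_metal_kernel_type`, the old graph-concurrency machinery (`GGML_MAX_CONCUR`, `concur_list`, `ggml_metal_graph_find_concurrency`) is deleted, and two per-device capability flags (`support_simdgroup_reduction`, `support_simdgroup_mm`) now gate which kernels get loaded. Below is a minimal C sketch of the data-structure change only, with plain pointers standing in for the Objective-C `id<MTLFunction>` / `id<MTLComputePipelineState>` handles used by the real code:

    #include <stddef.h>

    /* before: one macro-generated field pair per kernel
     * after:  one table entry per kernel, indexed by an enum */
    enum ggml_metal_kernel_type {
        GGML_METAL_KERNEL_TYPE_ADD,
        GGML_METAL_KERNEL_TYPE_MUL,
        /* ... one entry per kernel, as listed in the diff ... */
        GGML_METAL_KERNEL_TYPE_COUNT
    };

    struct ggml_metal_kernel {
        void * function; /* id<MTLFunction> in the real backend             */
        void * pipeline; /* id<MTLComputePipelineState> in the real backend */
    };

    struct metal_context_sketch {
        struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
    };

    /* pipeline lookup becomes an array index instead of a token-pasted
     * field access (ctx->pipeline_##name): */
    static inline void * get_pipeline(struct metal_context_sketch * ctx,
                                      enum ggml_metal_kernel_type t) {
        return ctx->kernels[t].pipeline;
    }

Note that the real context sizes the table with the new `GGML_METAL_MAX_KERNELS` (256) rather than the enum count; `metal_context_sketch` and `get_pipeline` are illustrative names, not symbols from the package.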
@@ -24,7 +24,7 @@
 
  #define UNUSED(x) (void)(x)
 
- #define GGML_MAX_CONCUR (2*GGML_DEFAULT_GRAPH_SIZE)
+ #define GGML_METAL_MAX_KERNELS 256
 
  struct ggml_metal_buffer {
  const char * name;
@@ -35,6 +35,134 @@ struct ggml_metal_buffer {
  id<MTLBuffer> metal;
  };
 
+ struct ggml_metal_kernel {
+ id<MTLFunction> function;
+ id<MTLComputePipelineState> pipeline;
+ };
+
+ enum ggml_metal_kernel_type {
+ GGML_METAL_KERNEL_TYPE_ADD,
+ GGML_METAL_KERNEL_TYPE_ADD_ROW,
+ GGML_METAL_KERNEL_TYPE_MUL,
+ GGML_METAL_KERNEL_TYPE_MUL_ROW,
+ GGML_METAL_KERNEL_TYPE_DIV,
+ GGML_METAL_KERNEL_TYPE_DIV_ROW,
+ GGML_METAL_KERNEL_TYPE_SCALE,
+ GGML_METAL_KERNEL_TYPE_SCALE_4,
+ GGML_METAL_KERNEL_TYPE_TANH,
+ GGML_METAL_KERNEL_TYPE_RELU,
+ GGML_METAL_KERNEL_TYPE_GELU,
+ GGML_METAL_KERNEL_TYPE_GELU_QUICK,
+ GGML_METAL_KERNEL_TYPE_SILU,
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX,
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
+ GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
+ GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
+ GGML_METAL_KERNEL_TYPE_RMS_NORM,
+ GGML_METAL_KERNEL_TYPE_GROUP_NORM,
+ GGML_METAL_KERNEL_TYPE_NORM,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
+ //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
+ //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
+ //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_ROPE_F32,
+ GGML_METAL_KERNEL_TYPE_ROPE_F16,
+ GGML_METAL_KERNEL_TYPE_ALIBI_F32,
+ GGML_METAL_KERNEL_TYPE_IM2COL_F16,
+ GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
+ GGML_METAL_KERNEL_TYPE_PAD_F32,
+ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
+ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
+ GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
+ //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
+ //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
+ GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
+ GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
+ GGML_METAL_KERNEL_TYPE_CONCAT,
+ GGML_METAL_KERNEL_TYPE_SQR,
+ GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+
+ GGML_METAL_KERNEL_TYPE_COUNT
+ };
+
  struct ggml_metal_context {
  int n_cb;
 
@@ -42,142 +170,15 @@ struct ggml_metal_context {
  id<MTLCommandQueue> queue;
  id<MTLLibrary> library;
 
- id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
- id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
-
  dispatch_queue_t d_queue;
 
  int n_buffers;
  struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
- int concur_list[GGML_MAX_CONCUR];
- int concur_list_len;
-
- // custom kernels
- #define GGML_METAL_DECL_KERNEL(name) \
- id<MTLFunction> function_##name; \
- id<MTLComputePipelineState> pipeline_##name
-
- GGML_METAL_DECL_KERNEL(add);
- GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
- GGML_METAL_DECL_KERNEL(mul);
- GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
- GGML_METAL_DECL_KERNEL(div);
- GGML_METAL_DECL_KERNEL(div_row);
- GGML_METAL_DECL_KERNEL(scale);
- GGML_METAL_DECL_KERNEL(scale_4);
- GGML_METAL_DECL_KERNEL(tanh);
- GGML_METAL_DECL_KERNEL(relu);
- GGML_METAL_DECL_KERNEL(gelu);
- GGML_METAL_DECL_KERNEL(gelu_quick);
- GGML_METAL_DECL_KERNEL(silu);
- GGML_METAL_DECL_KERNEL(soft_max);
- GGML_METAL_DECL_KERNEL(soft_max_4);
- GGML_METAL_DECL_KERNEL(diag_mask_inf);
- GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
- GGML_METAL_DECL_KERNEL(get_rows_f32);
- GGML_METAL_DECL_KERNEL(get_rows_f16);
- GGML_METAL_DECL_KERNEL(get_rows_q4_0);
- GGML_METAL_DECL_KERNEL(get_rows_q4_1);
- GGML_METAL_DECL_KERNEL(get_rows_q5_0);
- GGML_METAL_DECL_KERNEL(get_rows_q5_1);
- GGML_METAL_DECL_KERNEL(get_rows_q8_0);
- GGML_METAL_DECL_KERNEL(get_rows_q2_K);
- GGML_METAL_DECL_KERNEL(get_rows_q3_K);
- GGML_METAL_DECL_KERNEL(get_rows_q4_K);
- GGML_METAL_DECL_KERNEL(get_rows_q5_K);
- GGML_METAL_DECL_KERNEL(get_rows_q6_K);
- GGML_METAL_DECL_KERNEL(get_rows_i32);
- GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs);
- GGML_METAL_DECL_KERNEL(get_rows_iq2_xs);
- GGML_METAL_DECL_KERNEL(rms_norm);
- GGML_METAL_DECL_KERNEL(group_norm);
- GGML_METAL_DECL_KERNEL(norm);
- GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
- GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
- GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
- GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_iq2_xs_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
- //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
- GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
- //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
- //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xs_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_iq2_xs_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xs_f32);
- GGML_METAL_DECL_KERNEL(rope_f32);
- GGML_METAL_DECL_KERNEL(rope_f16);
- GGML_METAL_DECL_KERNEL(alibi_f32);
- GGML_METAL_DECL_KERNEL(im2col_f16);
- GGML_METAL_DECL_KERNEL(upscale_f32);
- GGML_METAL_DECL_KERNEL(pad_f32);
- GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
- GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
- GGML_METAL_DECL_KERNEL(leaky_relu_f32);
- GGML_METAL_DECL_KERNEL(cpy_f32_f16);
- GGML_METAL_DECL_KERNEL(cpy_f32_f32);
- GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
- GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
- GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
- //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
- //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
- GGML_METAL_DECL_KERNEL(cpy_f16_f16);
- GGML_METAL_DECL_KERNEL(cpy_f16_f32);
- GGML_METAL_DECL_KERNEL(concat);
- GGML_METAL_DECL_KERNEL(sqr);
- GGML_METAL_DECL_KERNEL(sum_rows);
-
- #undef GGML_METAL_DECL_KERNEL
+ struct ggml_metal_kernel kernels[GGML_METAL_MAX_KERNELS];
+
+ bool support_simdgroup_reduction;
+ bool support_simdgroup_mm;
  };
 
  // MSL code
@@ -191,7 +192,6 @@ struct ggml_metal_context {
  @implementation GGMLMetalClass
  @end
 
-
  static void ggml_metal_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
  fprintf(stderr, "%s", msg);
 
@@ -202,11 +202,6 @@ static void ggml_metal_default_log_callback(enum ggml_log_level level, const cha
  ggml_log_callback ggml_metal_log_callback = ggml_metal_default_log_callback;
  void * ggml_metal_log_user_data = NULL;
 
- void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
- ggml_metal_log_callback = log_callback;
- ggml_metal_log_user_data = user_data;
- }
-
  GGML_ATTRIBUTE_FORMAT(2, 3)
  static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
  if (ggml_metal_log_callback != NULL) {
@@ -229,7 +224,18 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
  }
  }
 
- struct ggml_metal_context * ggml_metal_init(int n_cb) {
+ static void * ggml_metal_host_malloc(size_t n) {
+ void * data = NULL;
+ const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
+ if (result != 0) {
+ GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
+ return NULL;
+ }
+
+ return data;
+ }
+
+ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
 
  id<MTLDevice> device;
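`ggml_metal_host_malloc` moves up here and becomes `static`; its old public definition (together with `ggml_metal_host_free`, `ggml_metal_set_n_cb` and the concurrency accessors) is deleted further down. The allocation itself is unchanged: page-aligned host memory, which Metal's `newBufferWithBytesNoCopy:` expects when the backend later wraps the same region in a zero-copy `MTLBuffer`. A standalone C version of the same pattern:

    #include <stdio.h>
    #include <stdlib.h>
    #include <unistd.h>

    /* page-aligned host allocation, as in ggml_metal_host_malloc */
    static void * host_malloc(size_t n) {
        void * data = NULL;
        if (posix_memalign(&data, (size_t) sysconf(_SC_PAGESIZE), n) != 0) {
            fprintf(stderr, "posix_memalign failed\n");
            return NULL;
        }
        return data;
    }

    int main(void) {
        void * p = host_malloc(1 << 20); /* 1 MiB, page-aligned */
        printf("page size = %ld, ptr = %p\n", (long) sysconf(_SC_PAGESIZE), p);
        free(p);
        return 0;
    }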
@@ -255,7 +261,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
  ctx->queue = [ctx->device newCommandQueue];
  ctx->n_buffers = 0;
- ctx->concur_list_len = 0;
 
  ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
@@ -298,19 +303,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  return NULL;
  }
 
- MTLCompileOptions* options = nil;
+ // dictionary of preprocessor macros
+ NSMutableDictionary * prep = [NSMutableDictionary dictionary];
+
  #ifdef GGML_QKK_64
- options = [MTLCompileOptions new];
- options.preprocessorMacros = @{ @"QK_K" : @(64) };
+ prep[@"QK_K"] = @(64);
  #endif
- // try to disable fast-math
- // NOTE: this seems to have no effect whatsoever
- // instead, in order to disable fast-math, we have to build default.metallib from the command line
- // using xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
- // and go through the "pre-compiled library found" path above
+
+ MTLCompileOptions* options = [MTLCompileOptions new];
+ options.preprocessorMacros = prep;
+
  //[options setFastMathEnabled:false];
 
  ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+
+ [options release];
+ [prep release];
  }
 
  if (error) {
@@ -319,22 +327,51 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  }
  }
 
- #if TARGET_OS_OSX
  // print MTL GPU family:
  GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
 
+ const NSInteger MTLGPUFamilyMetal3 = 5001;
+
  // determine max supported GPU family
  // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
  // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
- for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
- if ([ctx->device supportsFamily:i]) {
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
- break;
+ {
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+ break;
+ }
+ }
+
+ for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
+ break;
+ }
+ }
+
+ for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
+ break;
+ }
  }
  }
 
+ ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7];
+ ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];
+
+ ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];
+
+ GGML_METAL_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false");
+ GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
  GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
- GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+
+ #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
+ if (@available(macOS 10.12, iOS 16.0, *)) {
+ GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+ }
+ #elif TARGET_OS_OSX
  if (ctx->device.maxTransferRate != 0) {
  GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
  } else {
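Initialization now records two capability flags instead of gating everything on a single `supportsFamily:MTLGPUFamilyApple7` check: per the new comment in the next hunk, the simd_sum/simd_max based kernels need simdgroup reductions, which this code treats as available on Apple7-or-newer or any Metal3-family device, while simdgroup matrix multiplication still requires Apple7. `MTLGPUFamilyMetal3` is hard-coded as 5001, presumably so the file keeps compiling against SDKs that predate the constant. A reduced C model of the detection follows (the raw family values mirror the Metal headers, where `MTLGPUFamilyApple1` is 1001; `supports_family` stands in for `[device supportsFamily:]`):

    #include <stdbool.h>
    #include <stdio.h>

    enum { GPU_FAMILY_APPLE7 = 1007, GPU_FAMILY_METAL3 = 5001 };

    struct metal_caps {
        bool support_simdgroup_reduction; /* soft_max, rms_norm, mul_mv, ... */
        bool support_simdgroup_mm;        /* the mul_mm matrix kernels       */
    };

    static struct metal_caps detect_caps(bool (*supports_family)(int)) {
        struct metal_caps c;
        c.support_simdgroup_reduction = supports_family(GPU_FAMILY_APPLE7)
                                     || supports_family(GPU_FAMILY_METAL3);
        c.support_simdgroup_mm        = supports_family(GPU_FAMILY_APPLE7);
        return c;
    }

    /* hypothetical device that reports Metal3 but not Apple7 */
    static bool fake_device(int family) { return family == GPU_FAMILY_METAL3; }

    int main(void) {
        struct metal_caps c = detect_caps(fake_device);
        printf("reduction = %d, mm = %d\n",
               c.support_simdgroup_reduction, c.support_simdgroup_mm);
        return 0;
    }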
@@ -346,279 +383,171 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  {
  NSError * error = nil;
 
+ for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
+ ctx->kernels[i].function = nil;
+ ctx->kernels[i].pipeline = nil;
+ }
+
  /*
- GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
- (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
- (int) ctx->pipeline_##name.threadExecutionWidth); \
+ GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+ (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+ (int) kernel->pipeline.threadExecutionWidth); \
  */
- #define GGML_METAL_ADD_KERNEL(name) \
- ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
- ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
- if (error) { \
- GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
- return NULL; \
+ #define GGML_METAL_ADD_KERNEL(e, name, supported) \
+ if (supported) { \
+ struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
+ kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \
+ kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \
+ if (error) { \
+ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+ return NULL; \
+ } \
+ } else { \
+ GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \
  }
 
- GGML_METAL_ADD_KERNEL(add);
- GGML_METAL_ADD_KERNEL(add_row);
- GGML_METAL_ADD_KERNEL(mul);
- GGML_METAL_ADD_KERNEL(mul_row);
- GGML_METAL_ADD_KERNEL(div);
- GGML_METAL_ADD_KERNEL(div_row);
- GGML_METAL_ADD_KERNEL(scale);
- GGML_METAL_ADD_KERNEL(scale_4);
- GGML_METAL_ADD_KERNEL(tanh);
- GGML_METAL_ADD_KERNEL(relu);
- GGML_METAL_ADD_KERNEL(gelu);
- GGML_METAL_ADD_KERNEL(gelu_quick);
- GGML_METAL_ADD_KERNEL(silu);
- GGML_METAL_ADD_KERNEL(soft_max);
- GGML_METAL_ADD_KERNEL(soft_max_4);
- GGML_METAL_ADD_KERNEL(diag_mask_inf);
- GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
- GGML_METAL_ADD_KERNEL(get_rows_f32);
- GGML_METAL_ADD_KERNEL(get_rows_f16);
- GGML_METAL_ADD_KERNEL(get_rows_q4_0);
- GGML_METAL_ADD_KERNEL(get_rows_q4_1);
- GGML_METAL_ADD_KERNEL(get_rows_q5_0);
- GGML_METAL_ADD_KERNEL(get_rows_q5_1);
- GGML_METAL_ADD_KERNEL(get_rows_q8_0);
- GGML_METAL_ADD_KERNEL(get_rows_q2_K);
- GGML_METAL_ADD_KERNEL(get_rows_q3_K);
- GGML_METAL_ADD_KERNEL(get_rows_q4_K);
- GGML_METAL_ADD_KERNEL(get_rows_q5_K);
- GGML_METAL_ADD_KERNEL(get_rows_q6_K);
- GGML_METAL_ADD_KERNEL(get_rows_i32);
- GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs);
- GGML_METAL_ADD_KERNEL(get_rows_iq2_xs);
- GGML_METAL_ADD_KERNEL(rms_norm);
- GGML_METAL_ADD_KERNEL(group_norm);
- GGML_METAL_ADD_KERNEL(norm);
- GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
- GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
- GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
- GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_iq2_xs_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
- //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
- GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
- //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
- //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xs_f32);
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
- GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_iq2_xs_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xs_f32);
- }
- GGML_METAL_ADD_KERNEL(rope_f32);
- GGML_METAL_ADD_KERNEL(rope_f16);
- GGML_METAL_ADD_KERNEL(alibi_f32);
- GGML_METAL_ADD_KERNEL(im2col_f16);
- GGML_METAL_ADD_KERNEL(upscale_f32);
- GGML_METAL_ADD_KERNEL(pad_f32);
- GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
- GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
- GGML_METAL_ADD_KERNEL(leaky_relu_f32);
- GGML_METAL_ADD_KERNEL(cpy_f32_f16);
- GGML_METAL_ADD_KERNEL(cpy_f32_f32);
- GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
- GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
- GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
- //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
- //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
- GGML_METAL_ADD_KERNEL(cpy_f16_f16);
- GGML_METAL_ADD_KERNEL(cpy_f16_f32);
- GGML_METAL_ADD_KERNEL(concat);
- GGML_METAL_ADD_KERNEL(sqr);
- GGML_METAL_ADD_KERNEL(sum_rows);
-
- #undef GGML_METAL_ADD_KERNEL
+ // simd_sum and simd_max requires MTLGPUFamilyApple7
+
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
  }
 
  return ctx;
  }
 
- void ggml_metal_free(struct ggml_metal_context * ctx) {
+ static void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
- #define GGML_METAL_DEL_KERNEL(name) \
- [ctx->function_##name release]; \
- [ctx->pipeline_##name release];
-
- GGML_METAL_DEL_KERNEL(add);
- GGML_METAL_DEL_KERNEL(add_row);
- GGML_METAL_DEL_KERNEL(mul);
- GGML_METAL_DEL_KERNEL(mul_row);
- GGML_METAL_DEL_KERNEL(div);
- GGML_METAL_DEL_KERNEL(div_row);
- GGML_METAL_DEL_KERNEL(scale);
- GGML_METAL_DEL_KERNEL(scale_4);
- GGML_METAL_DEL_KERNEL(tanh);
- GGML_METAL_DEL_KERNEL(relu);
- GGML_METAL_DEL_KERNEL(gelu);
- GGML_METAL_DEL_KERNEL(gelu_quick);
- GGML_METAL_DEL_KERNEL(silu);
- GGML_METAL_DEL_KERNEL(soft_max);
- GGML_METAL_DEL_KERNEL(soft_max_4);
- GGML_METAL_DEL_KERNEL(diag_mask_inf);
- GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
- GGML_METAL_DEL_KERNEL(get_rows_f32);
- GGML_METAL_DEL_KERNEL(get_rows_f16);
- GGML_METAL_DEL_KERNEL(get_rows_q4_0);
- GGML_METAL_DEL_KERNEL(get_rows_q4_1);
- GGML_METAL_DEL_KERNEL(get_rows_q5_0);
- GGML_METAL_DEL_KERNEL(get_rows_q5_1);
- GGML_METAL_DEL_KERNEL(get_rows_q8_0);
- GGML_METAL_DEL_KERNEL(get_rows_q2_K);
- GGML_METAL_DEL_KERNEL(get_rows_q3_K);
- GGML_METAL_DEL_KERNEL(get_rows_q4_K);
- GGML_METAL_DEL_KERNEL(get_rows_q5_K);
- GGML_METAL_DEL_KERNEL(get_rows_q6_K);
- GGML_METAL_DEL_KERNEL(get_rows_i32);
- GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs);
- GGML_METAL_DEL_KERNEL(get_rows_iq2_xs);
- GGML_METAL_DEL_KERNEL(rms_norm);
- GGML_METAL_DEL_KERNEL(group_norm);
- GGML_METAL_DEL_KERNEL(norm);
- GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
- GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
- GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
- GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_iq2_xs_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
- //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
- GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
- //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
- //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xs_f32);
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
- GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_iq2_xs_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xs_f32);
- }
- GGML_METAL_DEL_KERNEL(rope_f32);
- GGML_METAL_DEL_KERNEL(rope_f16);
- GGML_METAL_DEL_KERNEL(alibi_f32);
- GGML_METAL_DEL_KERNEL(im2col_f16);
- GGML_METAL_DEL_KERNEL(upscale_f32);
- GGML_METAL_DEL_KERNEL(pad_f32);
- GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
- GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
- GGML_METAL_DEL_KERNEL(leaky_relu_f32);
- GGML_METAL_DEL_KERNEL(cpy_f32_f16);
- GGML_METAL_DEL_KERNEL(cpy_f32_f32);
- GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
- GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
- GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
- //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
- //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
- GGML_METAL_DEL_KERNEL(cpy_f16_f16);
- GGML_METAL_DEL_KERNEL(cpy_f16_f32);
- GGML_METAL_DEL_KERNEL(concat);
- GGML_METAL_DEL_KERNEL(sqr);
- GGML_METAL_DEL_KERNEL(sum_rows);
-
- #undef GGML_METAL_DEL_KERNEL
 
  for (int i = 0; i < ctx->n_buffers; ++i) {
  [ctx->buffers[i].metal release];
  }
 
+ for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
+ if (ctx->kernels[i].pipeline) {
+ [ctx->kernels[i].pipeline release];
+ }
+
+ if (ctx->kernels[i].function) {
+ [ctx->kernels[i].function release];
+ }
+ }
+
  [ctx->library release];
  [ctx->queue release];
  [ctx->device release];
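The loader macro is what ties the kernel table and the capability flags together: `GGML_METAL_ADD_KERNEL(e, name, supported)` fills slot `e` only when `supported` is true and logs a skip warning otherwise, so initialization no longer depends on every kernel compiling on every device, and `ggml_metal_free` releases only the slots that were actually populated. A reduced, compilable C model of that behavior (strings stand in for the pipeline objects; `K_ADD` / `K_SOFT_MAX` are illustrative names):

    #include <stdio.h>

    enum { K_ADD, K_SOFT_MAX, K_COUNT };

    static const char * kernels[K_COUNT]; /* NULL slot = kernel not loaded */

    #define ADD_KERNEL(e, name, supported)                                      \
        do {                                                                    \
            if (supported) {                                                    \
                kernels[e] = "kernel_" #name;                                   \
            } else {                                                            \
                fprintf(stderr, "skipping kernel_" #name " (not supported)\n"); \
            }                                                                   \
        } while (0)

    int main(void) {
        const int support_simdgroup_reduction = 0; /* e.g. an older GPU */
        ADD_KERNEL(K_ADD,      add,      1);
        ADD_KERNEL(K_SOFT_MAX, soft_max, support_simdgroup_reduction);
        printf("add loaded: %s\n",      kernels[K_ADD]      ? "yes" : "no");
        printf("soft_max loaded: %s\n", kernels[K_SOFT_MAX] ? "yes" : "no");
        return 0;
    }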
@@ -628,33 +557,6 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  free(ctx);
  }
 
- void * ggml_metal_host_malloc(size_t n) {
- void * data = NULL;
- const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
- if (result != 0) {
- GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
- return NULL;
- }
-
- return data;
- }
-
- void ggml_metal_host_free(void * data) {
- free(data);
- }
-
- void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
- ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
- }
-
- int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
- return ctx->concur_list_len;
- }
-
- int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
- return ctx->concur_list;
- }
-
  // temporarily defined here for compatibility between ggml-backend and the old API
 
  struct ggml_backend_metal_buffer {
@@ -727,210 +629,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
  return nil;
  }
 
- bool ggml_metal_add_buffer(
- struct ggml_metal_context * ctx,
- const char * name,
- void * data,
- size_t size,
- size_t max_size) {
- if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
- GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
- return false;
- }
-
- if (data) {
- // verify that the buffer does not overlap with any of the existing buffers
- for (int i = 0; i < ctx->n_buffers; ++i) {
- const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
-
- if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
- GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
- return false;
- }
- }
-
- const size_t size_page = sysconf(_SC_PAGESIZE);
-
- size_t size_aligned = size;
- if ((size_aligned % size_page) != 0) {
- size_aligned += (size_page - (size_aligned % size_page));
- }
-
- // the buffer fits into the max buffer size allowed by the device
- if (size_aligned <= ctx->device.maxBufferLength) {
- ctx->buffers[ctx->n_buffers].name = name;
- ctx->buffers[ctx->n_buffers].data = data;
- ctx->buffers[ctx->n_buffers].size = size;
-
- ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
-
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
- GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
- return false;
- }
-
- GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);
-
- ++ctx->n_buffers;
- } else {
- // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
- // one of the views
- const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
- const size_t size_step = ctx->device.maxBufferLength - size_ovlp;
- const size_t size_view = ctx->device.maxBufferLength;
-
- for (size_t i = 0; i < size; i += size_step) {
- const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
-
- ctx->buffers[ctx->n_buffers].name = name;
- ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
- ctx->buffers[ctx->n_buffers].size = size_step_aligned;
-
- ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
-
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
- GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
- return false;
- }
-
- GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
- if (i + size_step < size) {
- GGML_METAL_LOG_INFO("\n");
- }
-
- ++ctx->n_buffers;
- }
- }
-
- #if TARGET_OS_OSX
- GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
- ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
- ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-
- if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
- GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
- } else {
- GGML_METAL_LOG_INFO("\n");
- }
- #else
- GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
- #endif
- }
-
- return true;
- }
-
- void ggml_metal_set_tensor(
- struct ggml_metal_context * ctx,
- struct ggml_tensor * t) {
- size_t offs;
- id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
-
- memcpy((void *) ((uint8_t *) id_dst.contents + offs), t->data, ggml_nbytes(t));
- }
-
- void ggml_metal_get_tensor(
- struct ggml_metal_context * ctx,
- struct ggml_tensor * t) {
- size_t offs;
- id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
-
- memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
- }
-
- void ggml_metal_graph_find_concurrency(
- struct ggml_metal_context * ctx,
- struct ggml_cgraph * gf, bool check_mem) {
- int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
- int nodes_unused[GGML_MAX_CONCUR];
-
- for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
- for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
- ctx->concur_list_len = 0;
-
- int n_left = gf->n_nodes;
- int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
- int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
-
- while (n_left > 0) {
- // number of nodes at a layer (that can be issued concurrently)
- int concurrency = 0;
- for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
- if (nodes_unused[i]) {
- // if the requirements for gf->nodes[i] are satisfied
- int exe_flag = 1;
-
- // scan all srcs
- for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
- struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
- if (src_cur) {
- // if is leaf nodes it's satisfied.
- // TODO: ggml_is_leaf()
- if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
- continue;
- }
-
- // otherwise this src should be the output from previous nodes.
- int is_found = 0;
-
- // scan 2*search_depth back because we inserted barrier.
- //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
- for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
- if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
- is_found = 1;
- break;
- }
- }
- if (is_found == 0) {
- exe_flag = 0;
- break;
- }
- }
- }
- if (exe_flag && check_mem) {
- // check if nodes[i]'s data will be overwritten by a node before nodes[i].
- // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
- int64_t data_start = (int64_t) gf->nodes[i]->data;
- int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
- for (int j = n_start; j < i; j++) {
- if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
- && gf->nodes[j]->op != GGML_OP_VIEW \
- && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
- && gf->nodes[j]->op != GGML_OP_PERMUTE) {
- if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
- ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
- continue;
- }
-
- exe_flag = 0;
- }
- }
- }
- if (exe_flag) {
- ctx->concur_list[level_pos + concurrency] = i;
- nodes_unused[i] = 0;
- concurrency++;
- ctx->concur_list_len++;
- }
- }
- }
- n_left -= concurrency;
- // adding a barrier different layer
- ctx->concur_list[level_pos + concurrency] = -1;
- ctx->concur_list_len++;
- // jump all sorted nodes at nodes_bak
- while (!nodes_unused[n_start]) {
- n_start++;
- }
- level_pos += concurrency + 1;
- }
-
- if (ctx->concur_list_len > GGML_MAX_CONCUR) {
- GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
- }
- }
-
- static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
+ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
  switch (op->op) {
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(op)) {
@@ -956,9 +655,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
  case GGML_OP_SCALE:
  case GGML_OP_SQR:
  case GGML_OP_SUM_ROWS:
+ return true;
  case GGML_OP_SOFT_MAX:
  case GGML_OP_RMS_NORM:
  case GGML_OP_GROUP_NORM:
+ return ctx->support_simdgroup_reduction;
  case GGML_OP_NORM:
  case GGML_OP_ALIBI:
  case GGML_OP_ROPE:
@@ -967,9 +668,10 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
  case GGML_OP_PAD:
  case GGML_OP_ARGSORT:
  case GGML_OP_LEAKY_RELU:
+ return true;
  case GGML_OP_MUL_MAT:
  case GGML_OP_MUL_MAT_ID:
- return true;
+ return ctx->support_simdgroup_reduction;
  case GGML_OP_CPY:
  case GGML_OP_DUP:
  case GGML_OP_CONT:
@@ -1007,1480 +709,1549 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
  return false;
  }
  }
- bool ggml_metal_graph_compute(
+
+ static bool ggml_metal_graph_compute(
  struct ggml_metal_context * ctx,
  struct ggml_cgraph * gf) {
  @autoreleasepool {
 
- // if there is ctx->concur_list, dispatch concurrently
- // else fallback to serial dispatch
  MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
-
- const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
-
- const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
- edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+ edesc.dispatchType = MTLDispatchTypeSerial;
 
  // create multiple command buffers and enqueue them
  // then, we encode the graph into the command buffers in parallel
 
+ const int n_nodes = gf->n_nodes;
  const int n_cb = ctx->n_cb;
+ const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
- for (int i = 0; i < n_cb; ++i) {
- ctx->command_buffers[i] = [ctx->queue commandBuffer];
+ id<MTLCommandBuffer> command_buffer_builder[n_cb];
+ for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+ id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
+ command_buffer_builder[cb_idx] = command_buffer;
 
  // enqueue the command buffers in order to specify their execution order
- [ctx->command_buffers[i] enqueue];
-
- ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
+ [command_buffer enqueue];
  }
+ const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
 
- for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
- const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
-
- dispatch_async(ctx->d_queue, ^{
- size_t offs_src0 = 0;
- size_t offs_src1 = 0;
- size_t offs_dst = 0;
-
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
- id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
-
- const int node_start = (cb_idx + 0) * n_nodes_per_cb;
- const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
-
- for (int ind = node_start; ind < node_end; ++ind) {
- const int i = has_concur ? ctx->concur_list[ind] : ind;
-
- if (i == -1) {
- [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
- continue;
- }
-
- //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
-
- struct ggml_tensor * src0 = gf->nodes[i]->src[0];
- struct ggml_tensor * src1 = gf->nodes[i]->src[1];
- struct ggml_tensor * dst = gf->nodes[i];
-
- switch (dst->op) {
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_PERMUTE:
- {
- // noop -> next node
- } continue;
- default:
- {
- } break;
- }
+ dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
+ const int cb_idx = iter;
 
- if (!ggml_metal_supports_op(dst)) {
- GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
- GGML_ASSERT(!"unsupported op");
- }
-
- #ifndef GGML_METAL_NDEBUG
- [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
- #endif
+ size_t offs_src0 = 0;
+ size_t offs_src1 = 0;
+ size_t offs_dst = 0;
 
- const int64_t ne00 = src0 ? src0->ne[0] : 0;
- const int64_t ne01 = src0 ? src0->ne[1] : 0;
- const int64_t ne02 = src0 ? src0->ne[2] : 0;
- const int64_t ne03 = src0 ? src0->ne[3] : 0;
-
- const uint64_t nb00 = src0 ? src0->nb[0] : 0;
- const uint64_t nb01 = src0 ? src0->nb[1] : 0;
- const uint64_t nb02 = src0 ? src0->nb[2] : 0;
- const uint64_t nb03 = src0 ? src0->nb[3] : 0;
-
- const int64_t ne10 = src1 ? src1->ne[0] : 0;
- const int64_t ne11 = src1 ? src1->ne[1] : 0;
- const int64_t ne12 = src1 ? src1->ne[2] : 0;
- const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
-
- const uint64_t nb10 = src1 ? src1->nb[0] : 0;
- const uint64_t nb11 = src1 ? src1->nb[1] : 0;
- const uint64_t nb12 = src1 ? src1->nb[2] : 0;
- const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
-
- const int64_t ne0 = dst ? dst->ne[0] : 0;
- const int64_t ne1 = dst ? dst->ne[1] : 0;
- const int64_t ne2 = dst ? dst->ne[2] : 0;
- const int64_t ne3 = dst ? dst->ne[3] : 0;
-
- const uint64_t nb0 = dst ? dst->nb[0] : 0;
- const uint64_t nb1 = dst ? dst->nb[1] : 0;
- const uint64_t nb2 = dst ? dst->nb[2] : 0;
- const uint64_t nb3 = dst ? dst->nb[3] : 0;
-
- const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
- const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
- const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
-
- id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
- id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
- id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
-
- //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
1128
- //if (src0) {
1129
- // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
1130
- // ggml_is_contiguous(src0), src0->name);
1131
- //}
1132
- //if (src1) {
1133
- // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
1134
- // ggml_is_contiguous(src1), src1->name);
1135
- //}
1136
- //if (dst) {
1137
- // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
1138
- // dst->name);
1139
- //}
1140
-
1141
- switch (dst->op) {
1142
- case GGML_OP_CONCAT:
1143
- {
1144
- const int64_t nb = ne00;
1145
-
1146
- [encoder setComputePipelineState:ctx->pipeline_concat];
1147
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1148
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1149
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1150
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1151
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1152
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1153
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1154
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1155
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1156
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1157
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1158
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1159
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1160
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1161
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1162
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1163
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1164
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1165
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1166
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1167
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1168
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1169
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1170
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1171
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1172
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1173
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1174
- [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
1175
-
1176
- const int nth = MIN(1024, ne0);
1177
-
1178
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1179
- } break;
1180
- case GGML_OP_ADD:
1181
- case GGML_OP_MUL:
1182
- case GGML_OP_DIV:
1183
- {
1184
- const size_t offs = 0;
1185
-
1186
- bool bcast_row = false;
1187
-
1188
- int64_t nb = ne00;
745
+ id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
746
+ id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
1189
747
 
1190
- id<MTLComputePipelineState> pipeline = nil;
748
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
749
+ const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
1191
750
 
1192
- if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
1193
- GGML_ASSERT(ggml_is_contiguous(src0));
751
+ for (int i = node_start; i < node_end; ++i) {
752
+ if (i == -1) {
753
+ [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
754
+ continue;
755
+ }
1194
756
 
1195
- // src1 is a row
1196
- GGML_ASSERT(ne11 == 1);
757
+ //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
758
+
759
+ struct ggml_tensor * src0 = gf->nodes[i]->src[0];
760
+ struct ggml_tensor * src1 = gf->nodes[i]->src[1];
761
+ struct ggml_tensor * dst = gf->nodes[i];
762
+
763
+ switch (dst->op) {
764
+ case GGML_OP_NONE:
765
+ case GGML_OP_RESHAPE:
766
+ case GGML_OP_VIEW:
767
+ case GGML_OP_TRANSPOSE:
768
+ case GGML_OP_PERMUTE:
769
+ {
770
+ // noop -> next node
771
+ } continue;
772
+ default:
773
+ {
774
+ } break;
775
+ }
1197
776
 
1198
- nb = ne00 / 4;
1199
- switch (dst->op) {
1200
- case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
1201
- case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
1202
- case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
1203
- default: GGML_ASSERT(false);
1204
- }
777
+ if (!ggml_metal_supports_op(ctx, dst)) {
778
+ GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
779
+ GGML_ASSERT(!"unsupported op");
780
+ }
1205
781
 
1206
- bcast_row = true;
1207
- } else {
1208
- switch (dst->op) {
1209
- case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
1210
- case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
1211
- case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
1212
- default: GGML_ASSERT(false);
1213
- }
1214
- }
782
+ #ifndef GGML_METAL_NDEBUG
783
+ [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
784
+ #endif
1215
785
 
1216
- [encoder setComputePipelineState:pipeline];
1217
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1218
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1219
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1220
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1221
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1222
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1223
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1224
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1225
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1226
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1227
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1228
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1229
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1230
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1231
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1232
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1233
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1234
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1235
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1236
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1237
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1238
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1239
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1240
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1241
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1242
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1243
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1244
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1245
- [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
1246
-
1247
- if (bcast_row) {
1248
- const int64_t n = ggml_nelements(dst)/4;
786
+ const int64_t ne00 = src0 ? src0->ne[0] : 0;
787
+ const int64_t ne01 = src0 ? src0->ne[1] : 0;
788
+ const int64_t ne02 = src0 ? src0->ne[2] : 0;
789
+ const int64_t ne03 = src0 ? src0->ne[3] : 0;
790
+
791
+ const uint64_t nb00 = src0 ? src0->nb[0] : 0;
792
+ const uint64_t nb01 = src0 ? src0->nb[1] : 0;
793
+ const uint64_t nb02 = src0 ? src0->nb[2] : 0;
794
+ const uint64_t nb03 = src0 ? src0->nb[3] : 0;
795
+
796
+ const int64_t ne10 = src1 ? src1->ne[0] : 0;
797
+ const int64_t ne11 = src1 ? src1->ne[1] : 0;
798
+ const int64_t ne12 = src1 ? src1->ne[2] : 0;
799
+ const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
800
+
801
+ const uint64_t nb10 = src1 ? src1->nb[0] : 0;
802
+ const uint64_t nb11 = src1 ? src1->nb[1] : 0;
803
+ const uint64_t nb12 = src1 ? src1->nb[2] : 0;
804
+ const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
805
+
806
+ const int64_t ne0 = dst ? dst->ne[0] : 0;
807
+ const int64_t ne1 = dst ? dst->ne[1] : 0;
808
+ const int64_t ne2 = dst ? dst->ne[2] : 0;
809
+ const int64_t ne3 = dst ? dst->ne[3] : 0;
810
+
811
+ const uint64_t nb0 = dst ? dst->nb[0] : 0;
812
+ const uint64_t nb1 = dst ? dst->nb[1] : 0;
813
+ const uint64_t nb2 = dst ? dst->nb[2] : 0;
814
+ const uint64_t nb3 = dst ? dst->nb[3] : 0;
815
+
816
+ const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
817
+ const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
818
+ const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
819
+
820
+ id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
821
+ id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
822
+ id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
823
+
824
+ //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
825
+ //if (src0) {
826
+ // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
827
+ // ggml_is_contiguous(src0), src0->name);
828
+ //}
829
+ //if (src1) {
830
+ // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
831
+ // ggml_is_contiguous(src1), src1->name);
832
+ //}
833
+ //if (dst) {
834
+ // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
835
+ // dst->name);
836
+ //}
837
+
838
+ switch (dst->op) {
839
+ case GGML_OP_CONCAT:
840
+ {
841
+ const int64_t nb = ne00;
842
+
843
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
844
+
845
+ [encoder setComputePipelineState:pipeline];
846
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
847
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
848
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
849
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
850
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
851
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
852
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
853
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
854
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
855
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
856
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
857
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
858
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
859
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
860
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
861
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
862
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
863
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
864
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
865
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
866
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
867
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
868
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
869
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
870
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
871
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
872
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
873
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
874
+
875
+ const int nth = MIN(1024, ne0);
876
+
877
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
878
+ } break;
879
+ case GGML_OP_ADD:
880
+ case GGML_OP_MUL:
881
+ case GGML_OP_DIV:
882
+ {
883
+ const size_t offs = 0;
884
+
885
+ bool bcast_row = false;
886
+
887
+ int64_t nb = ne00;
888
+
889
+ id<MTLComputePipelineState> pipeline = nil;
890
+
891
+ if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
892
+ GGML_ASSERT(ggml_is_contiguous(src0));
1249
893
 
1250
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1251
- } else {
1252
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
894
+ // src1 is a row
895
+ GGML_ASSERT(ne11 == 1);
1253
896
 
1254
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
897
+ nb = ne00 / 4;
898
+ switch (dst->op) {
899
+ case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
900
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
901
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
902
+ default: GGML_ASSERT(false);
1255
903
  }
1256
- } break;
1257
- case GGML_OP_ACC:
1258
- {
1259
- GGML_ASSERT(src0t == GGML_TYPE_F32);
1260
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1261
- GGML_ASSERT(dstt == GGML_TYPE_F32);
1262
904
 
1263
- GGML_ASSERT(ggml_is_contiguous(src0));
1264
- GGML_ASSERT(ggml_is_contiguous(src1));
1265
-
1266
- const size_t pnb1 = ((int32_t *) dst->op_params)[0];
1267
- const size_t pnb2 = ((int32_t *) dst->op_params)[1];
1268
- const size_t pnb3 = ((int32_t *) dst->op_params)[2];
1269
- const size_t offs = ((int32_t *) dst->op_params)[3];
1270
-
1271
- const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
1272
-
1273
- if (!inplace) {
1274
 - // run a separate kernel to cpy src->dst
1275
- // not sure how to avoid this
1276
- // TODO: make a simpler cpy_bytes kernel
1277
-
1278
- const int nth = MIN((int) ctx->pipeline_cpy_f32_f32.maxTotalThreadsPerThreadgroup, ne00);
1279
-
1280
- [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
1281
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1282
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1283
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1284
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1285
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1286
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1287
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1288
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1289
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1290
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1291
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1292
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1293
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1294
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1295
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1296
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1297
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1298
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1299
-
1300
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
905
+ bcast_row = true;
906
+ } else {
907
+ switch (dst->op) {
908
+ case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
909
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
910
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
911
+ default: GGML_ASSERT(false);
1301
912
  }
913
+ }
1302
914
 
1303
- [encoder setComputePipelineState:ctx->pipeline_add];
1304
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1305
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1306
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1307
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1308
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1309
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1310
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1311
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1312
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
1313
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
1314
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
1315
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1316
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1317
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1318
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1319
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1320
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1321
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1322
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1323
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1324
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1325
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1326
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1327
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1328
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
1329
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
1330
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1331
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1332
-
1333
- const int nth = MIN((int) ctx->pipeline_add.maxTotalThreadsPerThreadgroup, ne00);
1334
-
1335
- [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1336
- } break;
1337
- case GGML_OP_SCALE:
1338
- {
1339
- GGML_ASSERT(ggml_is_contiguous(src0));
915
+ [encoder setComputePipelineState:pipeline];
916
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
917
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
918
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
919
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
920
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
921
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
922
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
923
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
924
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
925
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
926
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
927
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
928
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
929
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
930
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
931
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
932
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
933
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
934
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
935
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
936
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
937
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
938
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
939
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
940
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
941
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
942
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
943
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
944
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
945
+
946
+ if (bcast_row) {
947
+ const int64_t n = ggml_nelements(dst)/4;
1340
948
 
1341
- const float scale = *(const float *) dst->op_params;
949
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
950
+ } else {
951
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1342
952
 
1343
- int64_t n = ggml_nelements(dst);
953
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
954
+ }
955
+ } break;
956
+ case GGML_OP_ACC:
957
+ {
958
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
959
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
960
+ GGML_ASSERT(dstt == GGML_TYPE_F32);
1344
961
 
1345
- if (n % 4 == 0) {
1346
- n /= 4;
1347
- [encoder setComputePipelineState:ctx->pipeline_scale_4];
1348
- } else {
1349
- [encoder setComputePipelineState:ctx->pipeline_scale];
1350
- }
962
+ GGML_ASSERT(ggml_is_contiguous(src0));
963
+ GGML_ASSERT(ggml_is_contiguous(src1));
1351
964
 
1352
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1353
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1354
- [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
965
+ const size_t pnb1 = ((int32_t *) dst->op_params)[0];
966
+ const size_t pnb2 = ((int32_t *) dst->op_params)[1];
967
+ const size_t pnb3 = ((int32_t *) dst->op_params)[2];
968
+ const size_t offs = ((int32_t *) dst->op_params)[3];
1355
969
 
1356
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1357
- } break;
1358
- case GGML_OP_UNARY:
1359
- switch (ggml_get_unary_op(gf->nodes[i])) {
1360
- case GGML_UNARY_OP_TANH:
1361
- {
1362
- [encoder setComputePipelineState:ctx->pipeline_tanh];
1363
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1364
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
970
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
1365
971
 
1366
- const int64_t n = ggml_nelements(dst);
972
+ if (!inplace) {
973
 + // run a separate kernel to cpy src->dst
974
+ // not sure how to avoid this
975
+ // TODO: make a simpler cpy_bytes kernel
1367
976
 
1368
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1369
- } break;
1370
- case GGML_UNARY_OP_RELU:
1371
- {
1372
- [encoder setComputePipelineState:ctx->pipeline_relu];
1373
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1374
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
977
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
1375
978
 
1376
- const int64_t n = ggml_nelements(dst);
979
+ [encoder setComputePipelineState:pipeline];
980
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
981
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
982
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
983
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
984
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
985
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
986
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
987
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
988
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
989
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
990
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
991
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
992
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
993
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
994
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
995
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
996
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
997
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1377
998
 
1378
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1379
- } break;
1380
- case GGML_UNARY_OP_GELU:
1381
- {
1382
- [encoder setComputePipelineState:ctx->pipeline_gelu];
1383
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1384
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
999
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
1385
1000
 
1386
- const int64_t n = ggml_nelements(dst);
1387
- GGML_ASSERT(n % 4 == 0);
1001
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1002
+ }
1388
1003
 
1389
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1390
- } break;
1391
- case GGML_UNARY_OP_GELU_QUICK:
1392
- {
1393
- [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
1394
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1395
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1004
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
1005
+
1006
+ [encoder setComputePipelineState:pipeline];
1007
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1008
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1009
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1010
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1011
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1012
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1013
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1014
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1015
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
1016
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
1017
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
1018
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1019
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1020
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1021
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1022
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1023
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1024
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1025
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1026
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1027
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1028
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1029
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1030
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1031
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
1032
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
1033
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1034
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1035
+
1036
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
1037
+
1038
+ [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1039
+ } break;
1040
+ case GGML_OP_SCALE:
1041
+ {
1042
+ GGML_ASSERT(ggml_is_contiguous(src0));
1043
+
1044
+ const float scale = *(const float *) dst->op_params;
1045
+
1046
+ int64_t n = ggml_nelements(dst);
1047
+
1048
+ id<MTLComputePipelineState> pipeline = nil;
1049
+
1050
+ if (n % 4 == 0) {
1051
+ n /= 4;
1052
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
1053
+ } else {
1054
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
1055
+ }
1396
1056
 
1397
- const int64_t n = ggml_nelements(dst);
1398
- GGML_ASSERT(n % 4 == 0);
1057
+ [encoder setComputePipelineState:pipeline];
1058
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1059
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1060
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
1399
1061
 
1400
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1401
- } break;
1402
- case GGML_UNARY_OP_SILU:
1403
- {
1404
- [encoder setComputePipelineState:ctx->pipeline_silu];
1405
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1406
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1062
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1063
+ } break;
1064
+ case GGML_OP_UNARY:
1065
+ switch (ggml_get_unary_op(gf->nodes[i])) {
1066
+ case GGML_UNARY_OP_TANH:
1067
+ {
1068
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
1407
1069
 
1408
- const int64_t n = ggml_nelements(dst);
1409
- GGML_ASSERT(n % 4 == 0);
1070
+ [encoder setComputePipelineState:pipeline];
1071
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1072
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1410
1073
 
1411
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1412
- } break;
1413
- default:
1414
- {
1415
- GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1416
- GGML_ASSERT(false);
1417
- }
1418
- } break;
1419
- case GGML_OP_SQR:
1420
- {
1421
- GGML_ASSERT(ggml_is_contiguous(src0));
1074
+ const int64_t n = ggml_nelements(dst);
1422
1075
 
1423
- [encoder setComputePipelineState:ctx->pipeline_sqr];
1424
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1425
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1076
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1077
+ } break;
1078
+ case GGML_UNARY_OP_RELU:
1079
+ {
1080
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
1426
1081
 
1427
- const int64_t n = ggml_nelements(dst);
1428
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1429
- } break;
1430
- case GGML_OP_SUM_ROWS:
1431
- {
1432
- GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
1082
+ [encoder setComputePipelineState:pipeline];
1083
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1084
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1433
1085
 
1434
- [encoder setComputePipelineState:ctx->pipeline_sum_rows];
1435
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1436
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1437
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1438
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1439
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1440
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1441
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1442
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1443
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1444
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1445
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1446
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1447
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1448
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1449
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1450
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1451
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1452
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1453
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1454
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
1455
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
1456
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
1457
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
1458
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
1459
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
1460
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
1461
-
1462
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1463
- } break;
1464
- case GGML_OP_SOFT_MAX:
1465
- {
1466
- int nth = 32; // SIMD width
1467
-
1468
- if (ne00%4 == 0) {
1469
- while (nth < ne00/4 && nth < 256) {
1470
- nth *= 2;
1471
- }
1472
- [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
1473
- } else {
1474
- while (nth < ne00 && nth < 1024) {
1475
- nth *= 2;
1476
- }
1477
- [encoder setComputePipelineState:ctx->pipeline_soft_max];
1478
- }
1086
+ const int64_t n = ggml_nelements(dst);
1479
1087
 
1480
- const float scale = ((float *) dst->op_params)[0];
1088
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1089
+ } break;
1090
+ case GGML_UNARY_OP_GELU:
1091
+ {
1092
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
1481
1093
 
1482
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1483
- if (id_src1) {
1484
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1485
- } else {
1486
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1487
- }
1488
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1489
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1490
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1491
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1492
- [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1493
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1494
-
1495
- [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1496
- } break;
1497
- case GGML_OP_DIAG_MASK_INF:
1498
- {
1499
- const int n_past = ((int32_t *)(dst->op_params))[0];
1500
-
1501
- if (ne00%8 == 0) {
1502
- [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
1503
- } else {
1504
- [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
1505
- }
1506
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1507
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1508
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1509
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1510
- [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
1094
+ [encoder setComputePipelineState:pipeline];
1095
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1096
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1511
1097
 
1512
- if (ne00%8 == 0) {
1513
- [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1514
- }
1515
- else {
1516
- [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1517
- }
1518
- } break;
1519
- case GGML_OP_MUL_MAT:
1520
- {
1521
- GGML_ASSERT(ne00 == ne10);
1098
+ const int64_t n = ggml_nelements(dst);
1099
+ GGML_ASSERT(n % 4 == 0);
1522
1100
 
1523
- // TODO: assert that dim2 and dim3 are contiguous
1524
- GGML_ASSERT(ne12 % ne02 == 0);
1525
- GGML_ASSERT(ne13 % ne03 == 0);
1101
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1102
+ } break;
1103
+ case GGML_UNARY_OP_GELU_QUICK:
1104
+ {
1105
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
1526
1106
 
1527
- const uint r2 = ne12/ne02;
1528
- const uint r3 = ne13/ne03;
1107
+ [encoder setComputePipelineState:pipeline];
1108
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1109
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1529
1110
 
1530
- // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1531
- // to the matrix-vector kernel
1532
- int ne11_mm_min = 1;
1111
+ const int64_t n = ggml_nelements(dst);
1112
+ GGML_ASSERT(n % 4 == 0);
1113
+
1114
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1115
+ } break;
1116
+ case GGML_UNARY_OP_SILU:
1117
+ {
1118
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
1119
+
1120
+ [encoder setComputePipelineState:pipeline];
1121
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1122
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1123
+
1124
+ const int64_t n = ggml_nelements(dst);
1125
+ GGML_ASSERT(n % 4 == 0);
1126
+
1127
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1128
+ } break;
1129
+ default:
1130
+ {
1131
+ GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1132
+ GGML_ASSERT(false);
1133
+ }
1134
+ } break;
1135
+ case GGML_OP_SQR:
1136
+ {
1137
+ GGML_ASSERT(ggml_is_contiguous(src0));
1138
+
1139
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
1140
+
1141
+ [encoder setComputePipelineState:pipeline];
1142
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1143
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1144
+
1145
+ const int64_t n = ggml_nelements(dst);
1146
+
1147
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1148
+ } break;
1149
+ case GGML_OP_SUM_ROWS:
1150
+ {
1151
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
1152
+
1153
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
1154
+
1155
+ [encoder setComputePipelineState:pipeline];
1156
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1157
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1158
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1159
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1160
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1161
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1162
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1163
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1164
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1165
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1166
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1167
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1168
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1169
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1170
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1171
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1172
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1173
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1174
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1175
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
1176
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
1177
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
1178
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
1179
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
1180
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
1181
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
1182
+
1183
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1184
+ } break;
1185
+ case GGML_OP_SOFT_MAX:
1186
+ {
1187
+ int nth = 32; // SIMD width
1188
+
1189
+ id<MTLComputePipelineState> pipeline = nil;
1190
+
1191
+ if (ne00%4 == 0) {
1192
+ while (nth < ne00/4 && nth < 256) {
1193
+ nth *= 2;
1194
+ }
1195
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
1196
+ } else {
1197
+ while (nth < ne00 && nth < 1024) {
1198
+ nth *= 2;
1199
+ }
1200
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
1201
+ }
1202
+
1203
+ const float scale = ((float *) dst->op_params)[0];
1204
+
1205
+ [encoder setComputePipelineState:pipeline];
1206
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1207
+ if (id_src1) {
1208
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1209
+ } else {
1210
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1211
+ }
1212
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1213
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1214
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1215
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1216
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1217
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1218
+
1219
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1220
+ } break;
1221
+ case GGML_OP_DIAG_MASK_INF:
1222
+ {
1223
+ const int n_past = ((int32_t *)(dst->op_params))[0];
1224
+
1225
+ id<MTLComputePipelineState> pipeline = nil;
1226
+
1227
+ if (ne00%8 == 0) {
1228
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
1229
+ } else {
1230
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
1231
+ }
1232
+
1233
+ [encoder setComputePipelineState:pipeline];
1234
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1235
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1236
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1237
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1238
+ [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
1239
+
1240
+ if (ne00%8 == 0) {
1241
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1242
+ }
1243
+ else {
1244
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1245
+ }
1246
+ } break;
1247
+ case GGML_OP_MUL_MAT:
1248
+ {
1249
+ GGML_ASSERT(ne00 == ne10);
1250
+
1251
+ // TODO: assert that dim2 and dim3 are contiguous
1252
+ GGML_ASSERT(ne12 % ne02 == 0);
1253
+ GGML_ASSERT(ne13 % ne03 == 0);
1254
+
1255
+ const uint r2 = ne12/ne02;
1256
+ const uint r3 = ne13/ne03;
1257
+
1258
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1259
+ // to the matrix-vector kernel
1260
+ int ne11_mm_min = 1;
1533
1261
 
1534
1262
  #if 0
1535
- // the numbers below are measured on M2 Ultra for 7B and 13B models
1536
- // these numbers do not translate to other devices or model sizes
1537
- // TODO: need to find a better approach
1538
- if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
1539
- switch (src0t) {
1540
- case GGML_TYPE_F16: ne11_mm_min = 2; break;
1541
- case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1542
- case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1543
- case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1544
- case GGML_TYPE_Q4_0:
1545
- case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1546
- case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1547
- case GGML_TYPE_Q5_0: // not tested yet
1548
- case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1549
- case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1550
- case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1551
- default: ne11_mm_min = 1; break;
1552
- }
1263
+ // the numbers below are measured on M2 Ultra for 7B and 13B models
1264
+ // these numbers do not translate to other devices or model sizes
1265
+ // TODO: need to find a better approach
1266
+ if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
1267
+ switch (src0t) {
1268
+ case GGML_TYPE_F16: ne11_mm_min = 2; break;
1269
+ case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1270
+ case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1271
+ case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1272
+ case GGML_TYPE_Q4_0:
1273
+ case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1274
+ case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1275
+ case GGML_TYPE_Q5_0: // not tested yet
1276
+ case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1277
+ case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1278
+ case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1279
+ default: ne11_mm_min = 1; break;
1553
1280
  }
1281
+ }
1554
1282
  #endif
1555
1283
 
1556
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1557
- // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1558
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1559
- !ggml_is_transposed(src0) &&
1560
- !ggml_is_transposed(src1) &&
1561
- src1t == GGML_TYPE_F32 &&
1562
- ne00 % 32 == 0 && ne00 >= 64 &&
1563
- (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
1564
- //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1565
- switch (src0->type) {
1566
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
1567
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
1568
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
1569
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
1570
- case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
1571
- case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
1572
- case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
1573
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
1574
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
1575
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
1576
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
1577
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
1578
- case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break;
1579
- case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xs_f32]; break;
1580
- default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1581
- }
1582
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1583
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1584
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1585
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1586
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1587
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
1588
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
1589
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1590
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1591
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1592
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1593
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1594
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1595
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1596
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1597
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1598
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1599
- } else {
1600
- int nth0 = 32;
1601
- int nth1 = 1;
1602
- int nrows = 1;
1603
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1604
-
1605
- // use custom matrix x vector kernel
1606
- switch (src0t) {
1607
- case GGML_TYPE_F32:
1608
- {
1609
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1610
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
1611
- nrows = 4;
1612
- } break;
1613
- case GGML_TYPE_F16:
1614
- {
1615
- nth0 = 32;
1616
- nth1 = 1;
1617
- if (src1t == GGML_TYPE_F32) {
1618
- if (ne11 * ne12 < 4) {
1619
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
1620
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1621
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
1622
- nrows = ne11;
1623
- } else {
1624
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
1625
- nrows = 4;
1626
- }
1284
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1285
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1286
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1287
+ !ggml_is_transposed(src0) &&
1288
+ !ggml_is_transposed(src1) &&
1289
+ src1t == GGML_TYPE_F32 &&
1290
+ ne00 % 32 == 0 && ne00 >= 64 &&
1291
+ (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
1292
+ //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1293
+
1294
+ id<MTLComputePipelineState> pipeline = nil;
1295
+
1296
+ switch (src0->type) {
1297
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
1298
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
1299
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
1300
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
1301
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
1302
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
1303
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
1304
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
1305
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
1306
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
1307
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
1308
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
1309
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
1310
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
1311
+ default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1312
+ }
1313
+
1314
+ [encoder setComputePipelineState:pipeline];
1315
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1316
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1317
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1318
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1319
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1320
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
1321
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
1322
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1323
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1324
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1325
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1326
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1327
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1328
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1329
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1330
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1331
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1332
+ } else {
1333
+ int nth0 = 32;
1334
+ int nth1 = 1;
1335
+ int nrows = 1;
1336
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1337
+
1338
+ id<MTLComputePipelineState> pipeline = nil;
1339
+
1340
+ // use custom matrix x vector kernel
1341
+ switch (src0t) {
1342
+ case GGML_TYPE_F32:
1343
+ {
1344
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1345
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
1346
+ nrows = 4;
1347
+ } break;
1348
+ case GGML_TYPE_F16:
1349
+ {
1350
+ nth0 = 32;
1351
+ nth1 = 1;
1352
+ if (src1t == GGML_TYPE_F32) {
1353
+ if (ne11 * ne12 < 4) {
1354
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
1355
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1356
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
1357
+ nrows = ne11;
1627
1358
  } else {
1628
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
1359
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
1629
1360
  nrows = 4;
1630
1361
  }
1631
- } break;
1632
- case GGML_TYPE_Q4_0:
1633
- {
1634
- nth0 = 8;
1635
- nth1 = 8;
1636
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
1637
- } break;
1638
- case GGML_TYPE_Q4_1:
1639
- {
1640
- nth0 = 8;
1641
- nth1 = 8;
1642
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
1643
- } break;
1644
- case GGML_TYPE_Q5_0:
1645
- {
1646
- nth0 = 8;
1647
- nth1 = 8;
1648
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
1649
- } break;
1650
- case GGML_TYPE_Q5_1:
1651
- {
1652
- nth0 = 8;
1653
- nth1 = 8;
1654
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
1655
- } break;
1656
- case GGML_TYPE_Q8_0:
1657
- {
1658
- nth0 = 8;
1659
- nth1 = 8;
1660
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
1661
- } break;
1662
- case GGML_TYPE_Q2_K:
1663
- {
1664
- nth0 = 2;
1665
- nth1 = 32;
1666
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
1667
- } break;
1668
- case GGML_TYPE_Q3_K:
1669
- {
1670
- nth0 = 2;
1671
- nth1 = 32;
1672
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
1673
- } break;
1674
- case GGML_TYPE_Q4_K:
1675
- {
1676
- nth0 = 4; //1;
1677
- nth1 = 8; //32;
1678
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
1679
- } break;
1680
- case GGML_TYPE_Q5_K:
1681
- {
1682
- nth0 = 2;
1683
- nth1 = 32;
1684
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
1685
- } break;
1686
- case GGML_TYPE_Q6_K:
1687
- {
1688
- nth0 = 2;
1689
- nth1 = 32;
1690
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
1691
- } break;
1692
- case GGML_TYPE_IQ2_XXS:
1693
- {
1694
- nth0 = 4;
1695
- nth1 = 16;
1696
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
1697
- } break;
1698
- case GGML_TYPE_IQ2_XS:
1699
- {
1700
- nth0 = 4;
1701
- nth1 = 16;
1702
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xs_f32];
1703
- } break;
1704
- default:
1705
- {
1706
- GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1707
- GGML_ASSERT(false && "not implemented");
1362
+ } else {
1363
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
1364
+ nrows = 4;
1708
1365
  }
1709
- };
1710
-
1711
- if (ggml_is_quantized(src0t)) {
1712
- GGML_ASSERT(ne00 >= nth0*nth1);
1713
- }
1366
+ } break;
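[Editor's note] The F16 case just closed picks one of three f16 x f32 mat-vec kernels from the shape: a 1-row variant for tiny batches, an L4 variant for long rows whose length is a multiple of 4, and the generic variant otherwise. The same decision as a standalone C function (thresholds copied from the branch; the enum names mirror the kernel types):

```c
enum f16_mv_kernel { MV_F16_F32_1ROW, MV_F16_F32_L4, MV_F16_F32 };

/* Shape-based kernel choice, as in the branch above. */
static enum f16_mv_kernel pick_f16_mv(long ne00, long ne01, long ne11, long ne12) {
    if (ne11 * ne12 < 4)                           return MV_F16_F32_1ROW; /* almost no batch */
    if (ne00 >= 128 && ne01 >= 8 && ne00 % 4 == 0) return MV_F16_F32_L4;   /* 4-wide row loads */
    return MV_F16_F32;                                                     /* generic, 4 rows/tg */
}
```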
1367
+ case GGML_TYPE_Q4_0:
1368
+ {
1369
+ nth0 = 8;
1370
+ nth1 = 8;
1371
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
1372
+ } break;
1373
+ case GGML_TYPE_Q4_1:
1374
+ {
1375
+ nth0 = 8;
1376
+ nth1 = 8;
1377
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
1378
+ } break;
1379
+ case GGML_TYPE_Q5_0:
1380
+ {
1381
+ nth0 = 8;
1382
+ nth1 = 8;
1383
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
1384
+ } break;
1385
+ case GGML_TYPE_Q5_1:
1386
+ {
1387
+ nth0 = 8;
1388
+ nth1 = 8;
1389
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
1390
+ } break;
1391
+ case GGML_TYPE_Q8_0:
1392
+ {
1393
+ nth0 = 8;
1394
+ nth1 = 8;
1395
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
1396
+ } break;
1397
+ case GGML_TYPE_Q2_K:
1398
+ {
1399
+ nth0 = 2;
1400
+ nth1 = 32;
1401
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
1402
+ } break;
1403
+ case GGML_TYPE_Q3_K:
1404
+ {
1405
+ nth0 = 2;
1406
+ nth1 = 32;
1407
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
1408
+ } break;
1409
+ case GGML_TYPE_Q4_K:
1410
+ {
1411
+ nth0 = 4; //1;
1412
+ nth1 = 8; //32;
1413
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
1414
+ } break;
1415
+ case GGML_TYPE_Q5_K:
1416
+ {
1417
+ nth0 = 2;
1418
+ nth1 = 32;
1419
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
1420
+ } break;
1421
+ case GGML_TYPE_Q6_K:
1422
+ {
1423
+ nth0 = 2;
1424
+ nth1 = 32;
1425
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
1426
+ } break;
1427
+ case GGML_TYPE_IQ2_XXS:
1428
+ {
1429
+ nth0 = 4;
1430
+ nth1 = 16;
1431
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
1432
+ } break;
1433
+ case GGML_TYPE_IQ2_XS:
1434
+ {
1435
+ nth0 = 4;
1436
+ nth1 = 16;
1437
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
1438
+ } break;
1439
+ default:
1440
+ {
1441
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1442
+ GGML_ASSERT(false && "not implemented");
1443
+ }
1444
+ };
1714
1445
 
1715
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1716
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1717
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1718
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1719
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1720
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1721
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1722
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1723
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1724
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
1725
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
1726
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
1727
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
1728
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
1729
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1730
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1731
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1732
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1733
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1734
-
1735
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1736
- src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1737
- src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1738
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1739
- }
1740
- else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
1741
- const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1742
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1743
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1744
- }
1745
- else if (src0t == GGML_TYPE_Q4_K) {
1746
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1747
- }
1748
- else if (src0t == GGML_TYPE_Q3_K) {
1749
- #ifdef GGML_QKK_64
1750
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1751
- #else
1752
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1753
- #endif
1754
- }
1755
- else if (src0t == GGML_TYPE_Q5_K) {
1756
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1757
- }
1758
- else if (src0t == GGML_TYPE_Q6_K) {
1759
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1760
- } else {
1761
- const int64_t ny = (ne11 + nrows - 1)/nrows;
1762
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1763
- }
1446
+ if (ggml_is_quantized(src0t)) {
1447
+ GGML_ASSERT(ne00 >= nth0*nth1);
1764
1448
  }
1765
- } break;
1766
- case GGML_OP_MUL_MAT_ID:
1767
- {
1768
- //GGML_ASSERT(ne00 == ne10);
1769
- //GGML_ASSERT(ne03 == ne13);
1770
-
1771
- GGML_ASSERT(src0t == GGML_TYPE_I32);
1772
-
1773
- const int n_as = ((int32_t *) dst->op_params)[1];
1774
-
1775
- // TODO: make this more general
1776
- GGML_ASSERT(n_as <= 8);
1777
-
1778
- // max size of the src1ids array in the kernel stack
1779
- GGML_ASSERT(ne11 <= 512);
1780
-
1781
- struct ggml_tensor * src2 = gf->nodes[i]->src[2];
1782
-
1783
- const int64_t ne20 = src2 ? src2->ne[0] : 0;
1784
- const int64_t ne21 = src2 ? src2->ne[1] : 0;
1785
- const int64_t ne22 = src2 ? src2->ne[2] : 0;
1786
- const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
1787
-
1788
- const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
1789
- const uint64_t nb21 = src2 ? src2->nb[1] : 0;
1790
- const uint64_t nb22 = src2 ? src2->nb[2] : 0;
1791
- const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
1792
-
1793
- const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
1794
-
1795
- GGML_ASSERT(!ggml_is_transposed(src2));
1796
- GGML_ASSERT(!ggml_is_transposed(src1));
1797
-
1798
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1799
-
1800
- const uint r2 = ne12/ne22;
1801
- const uint r3 = ne13/ne23;
1802
-
1803
- // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1804
- // to the matrix-vector kernel
1805
- int ne11_mm_min = n_as;
1806
-
1807
- const int idx = ((int32_t *) dst->op_params)[0];
1808
-
1809
- // batch size
1810
- GGML_ASSERT(ne01 == ne11);
1811
-
1812
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1813
- // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1814
- // !!!
1815
- // TODO: for now, always use mat-vec kernels until we figure out how to improve the
1816
- // indirect matrix multiplication
1817
- // !!!
1818
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1819
- ne20 % 32 == 0 && ne20 >= 64 &&
1820
- ne11 > ne11_mm_min) {
1821
- switch (src2->type) {
1822
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
1823
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
1824
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
1825
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
1826
- case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
1827
- case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
1828
- case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
1829
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
1830
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
1831
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
1832
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
1833
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
1834
- case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
1835
- case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xs_f32]; break;
1836
- default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1837
- }
1838
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1839
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1840
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1841
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1842
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1843
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
1844
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1845
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
1846
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1847
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
1848
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
1849
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
1850
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
1851
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1852
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
1853
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1854
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
1855
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
1856
- [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
1857
- // TODO: how to make this an array? read Metal docs
1858
- for (int j = 0; j < 8; ++j) {
1859
- // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1860
- struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1861
-
1862
- size_t offs_src_cur = 0;
1863
- id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1864
-
1865
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
1866
- }
1867
-
1868
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1869
-
1870
- [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1871
- } else {
1872
- int nth0 = 32;
1873
- int nth1 = 1;
1874
- int nrows = 1;
1875
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1876
-
1877
- // use custom matrix x vector kernel
1878
- switch (src2t) {
1879
- case GGML_TYPE_F32:
1880
- {
1881
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1882
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
1883
- } break;
1884
- case GGML_TYPE_F16:
1885
- {
1886
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1887
- nth0 = 32;
1888
- nth1 = 1;
1889
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
1890
- } break;
1891
- case GGML_TYPE_Q4_0:
1892
- {
1893
- nth0 = 8;
1894
- nth1 = 8;
1895
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
1896
- } break;
1897
- case GGML_TYPE_Q4_1:
1898
- {
1899
- nth0 = 8;
1900
- nth1 = 8;
1901
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
1902
- } break;
1903
- case GGML_TYPE_Q5_0:
1904
- {
1905
- nth0 = 8;
1906
- nth1 = 8;
1907
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
1908
- } break;
1909
- case GGML_TYPE_Q5_1:
1910
- {
1911
- nth0 = 8;
1912
- nth1 = 8;
1913
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
1914
- } break;
1915
- case GGML_TYPE_Q8_0:
1916
- {
1917
- nth0 = 8;
1918
- nth1 = 8;
1919
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
1920
- } break;
1921
- case GGML_TYPE_Q2_K:
1922
- {
1923
- nth0 = 2;
1924
- nth1 = 32;
1925
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
1926
- } break;
1927
- case GGML_TYPE_Q3_K:
1928
- {
1929
- nth0 = 2;
1930
- nth1 = 32;
1931
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
1932
- } break;
1933
- case GGML_TYPE_Q4_K:
1934
- {
1935
- nth0 = 4; //1;
1936
- nth1 = 8; //32;
1937
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
1938
- } break;
1939
- case GGML_TYPE_Q5_K:
1940
- {
1941
- nth0 = 2;
1942
- nth1 = 32;
1943
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
1944
- } break;
1945
- case GGML_TYPE_Q6_K:
1946
- {
1947
- nth0 = 2;
1948
- nth1 = 32;
1949
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
1950
- } break;
1951
- case GGML_TYPE_IQ2_XXS:
1952
- {
1953
- nth0 = 4;
1954
- nth1 = 16;
1955
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
1956
- } break;
1957
- case GGML_TYPE_IQ2_XS:
1958
- {
1959
- nth0 = 4;
1960
- nth1 = 16;
1961
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xs_f32];
1962
- } break;
1963
- default:
1964
- {
1965
- GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
1966
- GGML_ASSERT(false && "not implemented");
1967
- }
1968
- };
1969
-
1970
- if (ggml_is_quantized(src2t)) {
1971
- GGML_ASSERT(ne20 >= nth0*nth1);
1972
- }
1973
1449
 
1974
- const int64_t _ne1 = 1; // kernels need a reference in constant memory
1975
-
1976
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1977
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1978
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1979
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1980
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1981
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1982
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
1983
- [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
1984
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
1985
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
1986
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1987
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
1988
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1989
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1990
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1991
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1992
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1993
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1994
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
1995
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1996
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
1997
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1998
- [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1999
- // TODO: how to make this an array? read Metal docs
2000
- for (int j = 0; j < 8; ++j) {
2001
- // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
2002
- struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
2003
-
2004
- size_t offs_src_cur = 0;
2005
- id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
2006
-
2007
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
2008
- }
2009
-
2010
- if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
2011
- src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
2012
- src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
2013
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2014
- }
2015
- else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
2016
- const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
2017
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2018
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2019
- }
2020
- else if (src2t == GGML_TYPE_Q4_K) {
2021
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2022
- }
2023
- else if (src2t == GGML_TYPE_Q3_K) {
1450
+ [encoder setComputePipelineState:pipeline];
1451
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1452
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1453
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1454
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1455
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1456
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1457
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1458
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1459
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1460
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
1461
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
1462
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
1463
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
1464
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
1465
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1466
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1467
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1468
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1469
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1470
+
1471
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1472
+ src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1473
+ src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1474
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1475
+ }
1476
+ else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
1477
+ const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1478
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1479
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1480
+ }
1481
+ else if (src0t == GGML_TYPE_Q4_K) {
1482
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1483
+ }
1484
+ else if (src0t == GGML_TYPE_Q3_K) {
2024
1485
  #ifdef GGML_QKK_64
2025
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1486
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2026
1487
  #else
2027
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1488
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2028
1489
  #endif
2029
- }
2030
- else if (src2t == GGML_TYPE_Q5_K) {
2031
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2032
- }
2033
- else if (src2t == GGML_TYPE_Q6_K) {
2034
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2035
- } else {
2036
- const int64_t ny = (_ne1 + nrows - 1)/nrows;
2037
- [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2038
- }
2039
1490
  }
2040
- } break;
2041
- case GGML_OP_GET_ROWS:
2042
- {
2043
- switch (src0->type) {
2044
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
2045
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
2046
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
2047
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
2048
- case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
2049
- case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
2050
- case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
2051
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
2052
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
2053
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
2054
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
2055
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
2056
- case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
2057
- case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
2058
- case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xs]; break;
2059
- default: GGML_ASSERT(false && "not implemented");
1491
+ else if (src0t == GGML_TYPE_Q5_K) {
1492
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2060
1493
  }
2061
-
2062
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2063
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2064
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2065
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
2066
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
2067
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
2068
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
2069
- [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
2070
- [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
2071
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
2072
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
2073
-
2074
- [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
2075
- } break;
2076
- case GGML_OP_RMS_NORM:
2077
- {
2078
- GGML_ASSERT(ne00 % 4 == 0);
2079
-
2080
- float eps;
2081
- memcpy(&eps, dst->op_params, sizeof(float));
2082
-
2083
- int nth = 32; // SIMD width
2084
-
2085
- while (nth < ne00/4 && nth < 1024) {
2086
- nth *= 2;
1494
+ else if (src0t == GGML_TYPE_Q6_K) {
1495
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1496
+ } else {
1497
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
1498
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2087
1499
  }
1500
+ }
1501
+ } break;
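[Editor's note] In the dispatch chain that closes this MUL_MAT case, the x-dimension of the grid encodes how many src0 rows one threadgroup reduces: `(ne01 + 7)/8` means 8 rows per group (the Q4_0-class types and Q2_K), `(ne01 + 3)/4` means 4 (Q4_K, Q5_K, and Q3_K unless GGML_QKK_64), `(ne01 + 1)/2` means 2 (Q6_K); the IQ2 branch additionally reserves `256*8+128` or `512*8+128` bytes of threadgroup memory. A C sketch of just the row-grouping arithmetic (the type tags are illustrative groupings, not source identifiers):

```c
/* Rows of src0 handled per threadgroup, read off the dispatch calls above. */
enum qfam { QFAM_Q4_0_LIKE, QFAM_Q2_K, QFAM_Q4_K, QFAM_Q5_K, QFAM_Q3_K, QFAM_Q6_K };

static long rows_per_threadgroup(enum qfam f) {
    switch (f) {
        case QFAM_Q4_0_LIKE:
        case QFAM_Q2_K: return 8;
        case QFAM_Q4_K:
        case QFAM_Q5_K:
        case QFAM_Q3_K: return 4;  /* 2 when built with GGML_QKK_64 */
        case QFAM_Q6_K: return 2;
    }
    return 1;
}

/* Grid width: ceil(ne01 / rows), e.g. (ne01 + 7)/8 for 8 rows per group. */
static long grid_width(long ne01, enum qfam f) {
    const long r = rows_per_threadgroup(f);
    return (ne01 + r - 1) / r;
}
```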
1502
+ case GGML_OP_MUL_MAT_ID:
1503
+ {
1504
+ //GGML_ASSERT(ne00 == ne10);
1505
+ //GGML_ASSERT(ne03 == ne13);
2088
1506
 
2089
- [encoder setComputePipelineState:ctx->pipeline_rms_norm];
2090
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2091
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2092
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2093
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
2094
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
2095
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2096
-
2097
- const int64_t nrows = ggml_nrows(src0);
2098
-
2099
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2100
- } break;
2101
- case GGML_OP_GROUP_NORM:
2102
- {
2103
- GGML_ASSERT(ne00 % 4 == 0);
2104
-
2105
- //float eps;
2106
- //memcpy(&eps, dst->op_params, sizeof(float));
2107
-
2108
- const float eps = 1e-6f; // TODO: temporarily hardcoded
2109
-
2110
- const int32_t n_groups = ((int32_t *) dst->op_params)[0];
2111
-
2112
- int nth = 32; // SIMD width
2113
-
2114
- //while (nth < ne00/4 && nth < 1024) {
2115
- // nth *= 2;
2116
- //}
2117
-
2118
- [encoder setComputePipelineState:ctx->pipeline_group_norm];
2119
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2120
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2121
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2122
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2123
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2124
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
2125
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
2126
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
2127
- [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
2128
- [encoder setBytes:&eps length:sizeof( float) atIndex:9];
2129
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2130
-
2131
- [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2132
- } break;
2133
- case GGML_OP_NORM:
2134
- {
2135
- float eps;
2136
- memcpy(&eps, dst->op_params, sizeof(float));
2137
-
2138
- const int nth = MIN(256, ne00);
2139
-
2140
- [encoder setComputePipelineState:ctx->pipeline_norm];
2141
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2142
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2143
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2144
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
2145
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
2146
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
2147
-
2148
- const int64_t nrows = ggml_nrows(src0);
2149
-
2150
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2151
- } break;
2152
- case GGML_OP_ALIBI:
2153
- {
2154
- GGML_ASSERT((src0t == GGML_TYPE_F32));
2155
-
2156
- const int nth = MIN(1024, ne00);
2157
-
2158
- //const int n_past = ((int32_t *) dst->op_params)[0];
2159
- const int n_head = ((int32_t *) dst->op_params)[1];
2160
- float max_bias;
2161
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
1507
+ GGML_ASSERT(src0t == GGML_TYPE_I32);
2162
1508
 
2163
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
2164
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
2165
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
1509
+ const int n_as = ((int32_t *) dst->op_params)[1];
2166
1510
 
2167
- [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
2168
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2169
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2170
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2171
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2172
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2173
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
2174
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
2175
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
2176
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
2177
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
2178
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
2179
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
2180
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
2181
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
2182
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
2183
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
2184
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
2185
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
2186
- [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
2187
- [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
2188
- [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
1511
+ // TODO: make this more general
1512
+ GGML_ASSERT(n_as <= 8);
2189
1513
 
2190
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2191
- } break;
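[Editor's note] In the ALIBI binding above, `m0` and `m1` are the two geometric bases of the ALiBi head slopes; the per-head slope itself is formed inside the kernel. A hedged C sketch of the standard ALiBi recipe those constants come from (the first `n_heads_log2_floor` heads take successive powers of `m0`, the remainder odd powers of `m1`):

```c
#include <math.h>

/* ALiBi slope for head k (0-based), following the constants computed above. */
static float alibi_slope(int k, int n_head, float max_bias) {
    const int   n  = 1 << (int) floorf(log2f((float) n_head)); /* n_heads_log2_floor */
    const float m0 = powf(2.0f, -max_bias          / (float) n);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / (float) n);
    return k < n ? powf(m0, (float) (k + 1))
                 : powf(m1, (float) (2*(k - n) + 1));
}
```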
2192
- case GGML_OP_ROPE:
2193
- {
2194
- GGML_ASSERT(ne10 == ne02);
2195
-
2196
- const int nth = MIN(1024, ne00);
2197
-
2198
- const int n_past = ((int32_t *) dst->op_params)[0];
2199
- const int n_dims = ((int32_t *) dst->op_params)[1];
2200
- const int mode = ((int32_t *) dst->op_params)[2];
2201
- // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
2202
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
2203
-
2204
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
2205
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
2206
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
2207
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
2208
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
2209
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
2210
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
1514
+ // max size of the src1ids array in the kernel stack
1515
+ GGML_ASSERT(ne11 <= 512);
2211
1516
 
2212
- switch (src0->type) {
2213
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
2214
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
2215
- default: GGML_ASSERT(false);
2216
- };
1517
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2];
2217
1518
 
2218
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2219
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2220
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2221
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
2222
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
2223
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
2224
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
2225
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
2226
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2227
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2228
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2229
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
2230
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
2231
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
2232
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
2233
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
2234
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
2235
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
2236
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
2237
- [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
2238
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
2239
- [encoder setBytes:&mode length:sizeof( int) atIndex:21];
2240
- [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
2241
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
2242
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
2243
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
2244
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
2245
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
2246
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
1519
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
1520
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
1521
+ const int64_t ne22 = src2 ? src2->ne[2] : 0;
1522
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
2247
1523
 
2248
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2249
- } break;
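[Editor's note] The ROPE block above also shows this file's convention for float operator parameters: `dst->op_params` is an `int32_t` array, so floats such as `freq_base` are recovered with `memcpy` rather than a pointer cast, which would violate strict aliasing. A minimal sketch of the idiom (helper name is mine):

```c
#include <stdint.h>
#include <string.h>

/* Bit-exact read of a float stored in slot idx of an int32-typed params array. */
static float op_param_f32(const int32_t * op_params, int idx) {
    float v;
    memcpy(&v, op_params + idx, sizeof(float)); /* safe type punning */
    return v;
}
```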
2250
- case GGML_OP_IM2COL:
2251
- {
2252
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
2253
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
2254
- GGML_ASSERT( dst->type == GGML_TYPE_F16);
1524
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
1525
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
1526
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
1527
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
2255
1528
 
2256
- const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
2257
- const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
2258
- const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
2259
- const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
2260
- const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
2261
- const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
2262
- const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
1529
+ const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
2263
1530
 
2264
- const int32_t N = src1->ne[is_2D ? 3 : 2];
2265
- const int32_t IC = src1->ne[is_2D ? 2 : 1];
2266
- const int32_t IH = is_2D ? src1->ne[1] : 1;
2267
- const int32_t IW = src1->ne[0];
1531
+ GGML_ASSERT(!ggml_is_transposed(src2));
1532
+ GGML_ASSERT(!ggml_is_transposed(src1));
2268
1533
 
2269
- const int32_t KH = is_2D ? src0->ne[1] : 1;
2270
- const int32_t KW = src0->ne[0];
1534
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
2271
1535
 
2272
- const int32_t OH = is_2D ? dst->ne[2] : 1;
2273
- const int32_t OW = dst->ne[1];
1536
+ const uint r2 = ne12/ne22;
1537
+ const uint r3 = ne13/ne23;
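[Editor's note] As in the plain MUL_MAT path, `r2` and `r3` are batch broadcast ratios: the activations carry `ne12`/`ne13` batches while src2 carries only `ne22`/`ne23`, so the kernel divides a dst batch index by the ratio to find its weight slice. A one-function sketch under that reading:

```c
/* Map activation batch indices to the (smaller) weight batch they broadcast from.
 * Assumes ne12 % ne22 == 0 and ne13 % ne23 == 0, i.e. r2 = ne12/ne22, r3 = ne13/ne23. */
static void weight_batch(long i12, long i13, long r2, long r3, long * i22, long * i23) {
    *i22 = i12 / r2; /* r2 consecutive activation batches share one weight batch */
    *i23 = i13 / r3;
}
```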
2274
1538
 
2275
- const int32_t CHW = IC * KH * KW;
1539
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1540
+ // to the matrix-vector kernel
1541
+ int ne11_mm_min = n_as;
2276
1542
 
2277
- const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
2278
- const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
1543
+ const int idx = ((int32_t *) dst->op_params)[0];
2279
1544
 
2280
- switch (src0->type) {
2281
- case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
2282
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
2283
- default: GGML_ASSERT(false);
2284
- };
1545
+ // batch size
1546
+ GGML_ASSERT(ne01 == ne11);
2285
1547
 
2286
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
2287
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2288
- [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
2289
- [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
2290
- [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
2291
- [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
2292
- [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
2293
- [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
2294
- [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
2295
- [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
2296
- [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
2297
- [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
2298
- [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
2299
-
2300
- [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
2301
- } break;
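[Editor's note] One detail worth flagging in the IM2COL binding above: `ofs0`/`ofs1` divide the byte strides `src1->nb[...]` by 4, converting them to element strides on the assumption that src1 is f32. The conversion in isolation:

```c
#include <stdint.h>

/* Byte stride -> f32 element stride (ggml nb[] strides are in bytes). */
static int32_t f32_elem_stride(uint64_t nb_bytes) {
    return (int32_t) (nb_bytes / sizeof(float)); /* the "/ 4" above */
}
```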
2302
- case GGML_OP_UPSCALE:
2303
- {
2304
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2305
-
2306
- const int sf = dst->op_params[0];
2307
-
2308
- [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
2309
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2310
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2311
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2312
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2313
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2314
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2315
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2316
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2317
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2318
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2319
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2320
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2321
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2322
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2323
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2324
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2325
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2326
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2327
- [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
2328
-
2329
- const int nth = MIN((int) ctx->pipeline_upscale_f32.maxTotalThreadsPerThreadgroup, ne0);
2330
-
2331
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2332
- } break;
2333
- case GGML_OP_PAD:
2334
- {
2335
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2336
-
2337
- [encoder setComputePipelineState:ctx->pipeline_pad_f32];
2338
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2339
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2340
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2341
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2342
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2343
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2344
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2345
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2346
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2347
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2348
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2349
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2350
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2351
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2352
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2353
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2354
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2355
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2356
-
2357
- const int nth = MIN(1024, ne0);
2358
-
2359
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2360
- } break;
2361
- case GGML_OP_ARGSORT:
2362
- {
2363
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2364
- GGML_ASSERT( dst->type == GGML_TYPE_I32);
2365
-
2366
- const int nrows = ggml_nrows(src0);
2367
-
2368
- enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
2369
-
2370
- switch (order) {
2371
- case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
2372
- case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
2373
- default: GGML_ASSERT(false);
2374
- };
1548
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1549
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1550
+ // !!!
1551
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
1552
+ // indirect matrix multiplication
1553
+ // !!!
1554
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1555
+ ne20 % 32 == 0 && ne20 >= 64 &&
1556
+ ne11 > ne11_mm_min) {
2375
1557
 
2376
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2377
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2378
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2379
-
2380
- [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
2381
- } break;
2382
- case GGML_OP_LEAKY_RELU:
2383
- {
2384
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
1558
+ id<MTLComputePipelineState> pipeline = nil;
2385
1559
 
2386
- float slope;
2387
- memcpy(&slope, dst->op_params, sizeof(float));
1560
+ switch (src2->type) {
1561
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
1562
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
1563
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
1564
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
1565
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
1566
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
1567
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
1568
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
1569
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
1570
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
1571
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
1572
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
1573
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
1574
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
1575
+ default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1576
+ }
2388
1577
 
2389
- [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
2390
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2391
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2392
- [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
1578
+ [encoder setComputePipelineState:pipeline];
1579
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1580
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1581
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1582
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1583
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1584
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
1585
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1586
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
1587
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1588
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
1589
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
1590
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
1591
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
1592
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1593
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
1594
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1595
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
1596
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
1597
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
1598
+ // TODO: how to make this an array? read Metal docs
1599
+ for (int j = 0; j < 8; ++j) {
1600
+ // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1601
+ struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1602
+
1603
+ size_t offs_src_cur = 0;
1604
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1605
+
1606
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
1607
+ }
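[Editor's note] The loop just above binds a fixed 8 expert buffers even when `n_as < 8`, wrapping the index with `j % n_as` so no Metal argument slot is left uninitialized (see the NOTE in the source). The same padding idea in plain C, with hypothetical names:

```c
/* Fill all `slots` bindings from n real sources (1 <= n <= slots),
 * repeating sources so every slot holds a valid pointer. */
static void bind_padded(const void * const src[], int n, const void * bound[], int slots) {
    for (int j = 0; j < slots; ++j) {
        bound[j] = src[j % n]; /* slot j reuses source j mod n */
    }
}
```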
2393
1608
 
2394
- const int64_t n = ggml_nelements(dst);
1609
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
2395
1610
 
2396
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2397
- } break;
2398
- case GGML_OP_DUP:
2399
- case GGML_OP_CPY:
2400
- case GGML_OP_CONT:
2401
- {
2402
- GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
1611
+ [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1612
+ } else {
1613
+ int nth0 = 32;
1614
+ int nth1 = 1;
1615
+ int nrows = 1;
1616
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
2403
1617
 
2404
- int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
1618
+ id<MTLComputePipelineState> pipeline = nil;
2405
1619
 
2406
- switch (src0t) {
1620
+ // use custom matrix x vector kernel
1621
+ switch (src2t) {
2407
1622
  case GGML_TYPE_F32:
2408
1623
  {
2409
- GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
2410
-
2411
- switch (dstt) {
2412
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
2413
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
2414
- case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
2415
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
2416
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
2417
- //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
2418
- //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
2419
- default: GGML_ASSERT(false && "not implemented");
2420
- };
1624
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1625
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
2421
1626
  } break;
2422
1627
  case GGML_TYPE_F16:
2423
1628
  {
2424
- switch (dstt) {
2425
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
2426
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
2427
- default: GGML_ASSERT(false && "not implemented");
2428
- };
1629
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1630
+ nth0 = 32;
1631
+ nth1 = 1;
1632
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
1633
+ } break;
1634
+ case GGML_TYPE_Q4_0:
1635
+ {
1636
+ nth0 = 8;
1637
+ nth1 = 8;
1638
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
2429
1639
  } break;
2430
- default: GGML_ASSERT(false && "not implemented");
1640
+ case GGML_TYPE_Q4_1:
1641
+ {
1642
+ nth0 = 8;
1643
+ nth1 = 8;
1644
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
1645
+ } break;
1646
+ case GGML_TYPE_Q5_0:
1647
+ {
1648
+ nth0 = 8;
1649
+ nth1 = 8;
1650
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
1651
+ } break;
1652
+ case GGML_TYPE_Q5_1:
1653
+ {
1654
+ nth0 = 8;
1655
+ nth1 = 8;
1656
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
1657
+ } break;
1658
+ case GGML_TYPE_Q8_0:
1659
+ {
1660
+ nth0 = 8;
1661
+ nth1 = 8;
1662
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
1663
+ } break;
1664
+ case GGML_TYPE_Q2_K:
1665
+ {
1666
+ nth0 = 2;
1667
+ nth1 = 32;
1668
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
1669
+ } break;
1670
+ case GGML_TYPE_Q3_K:
1671
+ {
1672
+ nth0 = 2;
1673
+ nth1 = 32;
1674
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
1675
+ } break;
1676
+ case GGML_TYPE_Q4_K:
1677
+ {
1678
+ nth0 = 4; //1;
1679
+ nth1 = 8; //32;
1680
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
1681
+ } break;
1682
+ case GGML_TYPE_Q5_K:
1683
+ {
1684
+ nth0 = 2;
1685
+ nth1 = 32;
1686
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
1687
+ } break;
1688
+ case GGML_TYPE_Q6_K:
1689
+ {
1690
+ nth0 = 2;
1691
+ nth1 = 32;
1692
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
1693
+ } break;
1694
+ case GGML_TYPE_IQ2_XXS:
1695
+ {
1696
+ nth0 = 4;
1697
+ nth1 = 16;
1698
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
1699
+ } break;
1700
+ case GGML_TYPE_IQ2_XS:
1701
+ {
1702
+ nth0 = 4;
1703
+ nth1 = 16;
1704
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
1705
+ } break;
1706
+ default:
1707
+ {
1708
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
1709
+ GGML_ASSERT(false && "not implemented");
1710
+ }
1711
+ };
1712
+
1713
+ if (ggml_is_quantized(src2t)) {
1714
+ GGML_ASSERT(ne20 >= nth0*nth1);
2431
1715
  }
2432
1716
 
2433
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2434
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2435
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2436
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2437
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2438
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
2439
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
2440
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
2441
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
2442
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
2443
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
2444
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
2445
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
2446
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
2447
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
2448
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
2449
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
2450
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1717
+ const int64_t _ne1 = 1; // kernels need a reference in constant memory
2451
1718
 
2452
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2453
- } break;
2454
- default:
2455
- {
2456
- GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
2457
- GGML_ASSERT(false);
2458
- }
2459
- }
1719
+ [encoder setComputePipelineState:pipeline];
1720
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1721
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1722
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1723
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1724
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1725
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1726
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
1727
+ [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
1728
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
1729
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
1730
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1731
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
1732
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1733
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1734
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1735
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1736
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1737
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1738
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
1739
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1740
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
1741
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1742
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1743
+ // TODO: how to make this an array? read Metal docs
1744
+ for (int j = 0; j < 8; ++j) {
1745
+ // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1746
+ struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1747
+
1748
+ size_t offs_src_cur = 0;
1749
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1750
+
1751
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
1752
+ }
2460
1753
 
2461
- #ifndef GGML_METAL_NDEBUG
2462
- [encoder popDebugGroup];
1754
+ if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1755
+ src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1756
+ src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
1757
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1758
+ }
1759
+ else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
1760
+ const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1761
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1762
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1763
+ }
1764
+ else if (src2t == GGML_TYPE_Q4_K) {
1765
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1766
+ }
1767
+ else if (src2t == GGML_TYPE_Q3_K) {
1768
+ #ifdef GGML_QKK_64
1769
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1770
+ #else
1771
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2463
1772
  #endif
2464
- }
1773
+ }
1774
+ else if (src2t == GGML_TYPE_Q5_K) {
1775
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1776
+ }
1777
+ else if (src2t == GGML_TYPE_Q6_K) {
1778
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1779
+ } else {
1780
+ const int64_t ny = (_ne1 + nrows - 1)/nrows;
1781
+ [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1782
+ }
1783
+ }
1784
+ } break;
1785
+ case GGML_OP_GET_ROWS:
1786
+ {
1787
+ id<MTLComputePipelineState> pipeline = nil;
1788
+
1789
+ switch (src0->type) {
1790
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
1791
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
1792
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
1793
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
1794
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
1795
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
1796
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
1797
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
1798
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
1799
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
1800
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
1801
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
1802
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
1803
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
1804
+ case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
1805
+ default: GGML_ASSERT(false && "not implemented");
1806
+ }
1807
+
1808
+ [encoder setComputePipelineState:pipeline];
1809
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1810
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1811
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1812
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1813
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
1814
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
1815
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
1816
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
1817
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
1818
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
1819
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
1820
+
1821
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
1822
+ } break;
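
For reference, GET_ROWS is just an indexed row gather: each entry of src1 selects one row of src0 to copy into dst. A minimal CPU sketch of the F32 case (hypothetical helper, not the ggml API):

    #include <stdint.h>
    #include <string.h>

    // dst row i = src0 row ids[i]; ne00 is the row length in floats
    static void get_rows_f32(const float * src0, int64_t ne00,
                             const int32_t * ids, int64_t n_ids,
                             float * dst) {
        for (int64_t i = 0; i < n_ids; ++i) {
            memcpy(dst + i*ne00, src0 + ids[i]*ne00, ne00*sizeof(float));
        }
    }
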
1823
+ case GGML_OP_RMS_NORM:
1824
+ {
1825
+ GGML_ASSERT(ne00 % 4 == 0);
1826
+
1827
+ float eps;
1828
+ memcpy(&eps, dst->op_params, sizeof(float));
1829
+
1830
+ int nth = 32; // SIMD width
1831
+
1832
+ while (nth < ne00/4 && nth < 1024) {
1833
+ nth *= 2;
1834
+ }
2465
1835
 
2466
- if (encoder != nil) {
2467
- [encoder endEncoding];
2468
- encoder = nil;
1836
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
1837
+
1838
+ [encoder setComputePipelineState:pipeline];
1839
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1840
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1841
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1842
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1843
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1844
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1845
+
1846
+ const int64_t nrows = ggml_nrows(src0);
1847
+
1848
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1849
+ } break;
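
The nth loop above picks a power-of-two threadgroup width: it starts at the SIMD width (32) and doubles until it covers ne00/4 elements (each thread handles a 4-wide vector, hence the ne00 % 4 == 0 assert) or hits the 1024-thread cap. The same heuristic in isolation:

    #include <stdio.h>

    static int pick_nth(int ne00) {
        int nth = 32;                        // SIMD width
        while (nth < ne00/4 && nth < 1024) { // cover ne00/4 lanes, capped at 1024
            nth *= 2;
        }
        return nth;
    }

    int main(void) {
        printf("%d\n", pick_nth(512));  // 128: covers 512/4 lanes exactly
        printf("%d\n", pick_nth(8192)); // 1024: capped
        return 0;
    }
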
1850
+ case GGML_OP_GROUP_NORM:
1851
+ {
1852
+ GGML_ASSERT(ne00 % 4 == 0);
1853
+
1854
+ //float eps;
1855
+ //memcpy(&eps, dst->op_params, sizeof(float));
1856
+
1857
+ const float eps = 1e-6f; // TODO: temporarily hardcoded
1858
+
1859
+ const int32_t n_groups = ((int32_t *) dst->op_params)[0];
1860
+
1861
+ int nth = 32; // SIMD width
1862
+
1863
+ //while (nth < ne00/4 && nth < 1024) {
1864
+ // nth *= 2;
1865
+ //}
1866
+
1867
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
1868
+
1869
+ [encoder setComputePipelineState:pipeline];
1870
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1871
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1872
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1873
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1874
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1875
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
1876
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
1877
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
1878
+ [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
1879
+ [encoder setBytes:&eps length:sizeof( float) atIndex:9];
1880
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1881
+
1882
+ [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1883
+ } break;
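
For orientation, group norm partitions the ne02 channels into n_groups groups and normalizes each group to zero mean and unit variance, which is what the one-threadgroup-per-group dispatch above computes. A reference CPU sketch (hypothetical helper; assumes ne02 is divisible by n_groups and uses the same hardcoded eps):

    #include <math.h>

    static void group_norm_f32(float * x, int ne00, int ne01, int ne02,
                               int n_groups, float eps) {
        const long group_size = (long) ne00 * ne01 * (ne02 / n_groups);
        for (int g = 0; g < n_groups; ++g) {
            float * p = x + g * group_size;
            double sum = 0.0, sum2 = 0.0;
            for (long i = 0; i < group_size; ++i) {
                sum  += p[i];
                sum2 += (double) p[i] * p[i];
            }
            const float mean  = (float) (sum / group_size);
            const float var   = (float) (sum2 / group_size) - mean*mean;
            const float scale = 1.0f / sqrtf(var + eps);
            for (long i = 0; i < group_size; ++i) {
                p[i] = (p[i] - mean) * scale;
            }
        }
    }
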
1884
+ case GGML_OP_NORM:
1885
+ {
1886
+ float eps;
1887
+ memcpy(&eps, dst->op_params, sizeof(float));
1888
+
1889
+ const int nth = MIN(256, ne00);
1890
+
1891
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
1892
+
1893
+ [encoder setComputePipelineState:pipeline];
1894
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1895
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1896
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1897
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1898
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1899
+ [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
1900
+
1901
+ const int64_t nrows = ggml_nrows(src0);
1902
+
1903
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1904
+ } break;
1905
+ case GGML_OP_ALIBI:
1906
+ {
1907
+ GGML_ASSERT((src0t == GGML_TYPE_F32));
1908
+
1909
+ const int nth = MIN(1024, ne00);
1910
+
1911
+ //const int n_past = ((int32_t *) dst->op_params)[0];
1912
+ const int n_head = ((int32_t *) dst->op_params)[1];
1913
+ float max_bias;
1914
+ memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
1915
+
1916
+ const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
1917
+ const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
1918
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
1919
+
1920
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
1921
+
1922
+ [encoder setComputePipelineState:pipeline];
1923
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1924
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1925
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1926
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1927
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1928
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1929
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1930
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1931
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1932
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1933
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1934
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1935
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1936
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1937
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1938
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1939
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1940
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1941
+ [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
1942
+ [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
1943
+ [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
1944
+
1945
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1946
+ } break;
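
m0 and m1 above are the two slope bases of the ALiBi scheme: heads below n_heads_log2_floor get slopes that are powers of m0, and the remaining heads interpolate with powers of m1. A worked example of just the host-side math, with hypothetical values:

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const int   n_head   = 12;
        const float max_bias = 8.0f;

        const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));       // 8
        const float m0 = powf(2.0f, -(max_bias)        / n_heads_log2_floor);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

        printf("m0 = %.4f, m1 = %.4f\n", m0, m1); // m0 = 0.5000, m1 = 0.7071
        return 0;
    }
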
1947
+ case GGML_OP_ROPE:
1948
+ {
1949
+ GGML_ASSERT(ne10 == ne02);
1950
+
1951
+ const int nth = MIN(1024, ne00);
1952
+
1953
+ const int n_past = ((int32_t *) dst->op_params)[0];
1954
+ const int n_dims = ((int32_t *) dst->op_params)[1];
1955
+ const int mode = ((int32_t *) dst->op_params)[2];
1956
+ // skip op_params[3] (n_ctx): used by GLM RoPE, not implemented in Metal
1957
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
1958
+
1959
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1960
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
1961
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
1962
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
1963
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
1964
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
1965
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
1966
+
1967
+ id<MTLComputePipelineState> pipeline = nil;
1968
+
1969
+ switch (src0->type) {
1970
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
1971
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
1972
+ default: GGML_ASSERT(false);
1973
+ };
1974
+
1975
+ [encoder setComputePipelineState:pipeline];
1976
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1977
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1978
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1979
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1980
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
1981
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
1982
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
1983
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
1984
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
1985
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
1986
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
1987
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
1988
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
1989
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
1990
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
1991
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
1992
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
1993
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
1994
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
1995
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
1996
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
1997
+ [encoder setBytes:&mode length:sizeof( int) atIndex:21];
1998
+ [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
1999
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
2000
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
2001
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
2002
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
2003
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
2004
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
2005
+
2006
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2007
+ } break;
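
The memcpy calls above are deliberate: op_params is an int32_t array, and the float hyperparameters are bit-stored in its slots, so memcpy is the portable way to reinterpret the bits without violating strict aliasing (a plain *(float *)& cast would be undefined behavior). A self-contained round-trip:

    #include <stdint.h>
    #include <string.h>
    #include <stdio.h>

    int main(void) {
        int32_t op_params[16] = {0};

        float freq_base = 10000.0f;
        memcpy(&op_params[5], &freq_base, sizeof(float)); // store the bits

        float out;
        memcpy(&out, &op_params[5], sizeof(float));       // read them back
        printf("%f\n", out);                              // 10000.000000
        return 0;
    }
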
2008
+ case GGML_OP_IM2COL:
2009
+ {
2010
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
2011
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
2012
+ GGML_ASSERT( dst->type == GGML_TYPE_F16);
2013
+
2014
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
2015
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
2016
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
2017
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
2018
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
2019
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
2020
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
2021
+
2022
+ const int32_t N = src1->ne[is_2D ? 3 : 2];
2023
+ const int32_t IC = src1->ne[is_2D ? 2 : 1];
2024
+ const int32_t IH = is_2D ? src1->ne[1] : 1;
2025
+ const int32_t IW = src1->ne[0];
2026
+
2027
+ const int32_t KH = is_2D ? src0->ne[1] : 1;
2028
+ const int32_t KW = src0->ne[0];
2029
+
2030
+ const int32_t OH = is_2D ? dst->ne[2] : 1;
2031
+ const int32_t OW = dst->ne[1];
2032
+
2033
+ const int32_t CHW = IC * KH * KW;
2034
+
2035
+ const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
2036
+ const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
2037
+
2038
+ id<MTLComputePipelineState> pipeline = nil;
2039
+
2040
+ switch (src0->type) {
2041
+ case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
2042
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
2043
+ default: GGML_ASSERT(false);
2044
+ };
2045
+
2046
+ [encoder setComputePipelineState:pipeline];
2047
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
2048
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2049
+ [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
2050
+ [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
2051
+ [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
2052
+ [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
2053
+ [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
2054
+ [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
2055
+ [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
2056
+ [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
2057
+ [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
2058
+ [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
2059
+ [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
2060
+
2061
+ [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
2062
+ } break;
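
The launch above reads OH/OW straight from dst, but they follow the standard convolution output-size formula applied to the stride/padding/dilation parameters unpacked earlier. A consistency-check sketch (not ggml code):

    #include <stdio.h>

    // standard conv output size: floor((in + 2p - d*(k-1) - 1)/s) + 1
    static int conv_out(int in, int k, int s, int p, int d) {
        return (in + 2*p - d*(k - 1) - 1) / s + 1;
    }

    int main(void) {
        // e.g. 224x224 input, 3x3 kernel, stride 1, pad 1, dilation 1 -> 224
        printf("%d\n", conv_out(224, 3, 1, 1, 1));
        return 0;
    }
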
2063
+ case GGML_OP_UPSCALE:
2064
+ {
2065
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2066
+
2067
+ const int sf = dst->op_params[0];
2068
+
2069
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
2070
+
2071
+ [encoder setComputePipelineState:pipeline];
2072
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2073
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2074
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2075
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2076
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2077
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2078
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2079
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2080
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2081
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2082
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2083
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2084
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2085
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2086
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2087
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2088
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2089
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2090
+ [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
2091
+
2092
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
2093
+
2094
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2095
+ } break;
2096
+ case GGML_OP_PAD:
2097
+ {
2098
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2099
+
2100
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
2101
+
2102
+ [encoder setComputePipelineState:pipeline];
2103
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2104
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2105
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2106
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2107
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2108
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2109
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2110
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2111
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2112
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2113
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2114
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2115
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2116
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2117
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2118
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2119
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2120
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2121
+
2122
+ const int nth = MIN(1024, ne0);
2123
+
2124
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2125
+ } break;
2126
+ case GGML_OP_ARGSORT:
2127
+ {
2128
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2129
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
2130
+
2131
+ const int nrows = ggml_nrows(src0);
2132
+
2133
+ enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
2134
+
2135
+ id<MTLComputePipelineState> pipeline = nil;
2136
+
2137
+ switch (order) {
2138
+ case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
2139
+ case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
2140
+ default: GGML_ASSERT(false);
2141
+ };
2142
+
2143
+ [encoder setComputePipelineState:pipeline];
2144
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2145
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2146
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2147
+
2148
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
2149
+ } break;
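
Note that the dispatch above sorts each row inside a single threadgroup of ne00 threads. Semantically, ARGSORT writes the indices of each row's values in the requested order; a qsort-based CPU reference for the ascending case (hypothetical helper; the file-scope row pointer makes it non-thread-safe):

    #include <stdint.h>
    #include <stdlib.h>

    static const float * g_row;

    static int cmp_asc(const void * a, const void * b) {
        const float va = g_row[*(const int32_t *) a];
        const float vb = g_row[*(const int32_t *) b];
        return (va > vb) - (va < vb);
    }

    // idx receives 0..n-1 permuted so that row[idx[0]] <= row[idx[1]] <= ...
    static void argsort_row_f32(const float * row, int32_t * idx, int n) {
        for (int i = 0; i < n; ++i) idx[i] = i;
        g_row = row;
        qsort(idx, n, sizeof(int32_t), cmp_asc);
    }
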
2150
+ case GGML_OP_LEAKY_RELU:
2151
+ {
2152
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2153
+
2154
+ float slope;
2155
+ memcpy(&slope, dst->op_params, sizeof(float));
2156
+
2157
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
2158
+
2159
+ [encoder setComputePipelineState:pipeline];
2160
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2161
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2162
+ [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
2163
+
2164
+ const int64_t n = ggml_nelements(dst);
2165
+
2166
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2167
+ } break;
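
The kernel's elementwise semantics, for reference (slope is read from op_params with the same memcpy pattern as above):

    #include <stdio.h>

    static float leaky_relu(float x, float slope) {
        return x > 0.0f ? x : slope * x;
    }

    int main(void) {
        printf("%f %f\n", leaky_relu(2.0f, 0.1f),   // 2.000000
                          leaky_relu(-2.0f, 0.1f)); // -0.200000
        return 0;
    }
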
2168
+ case GGML_OP_DUP:
2169
+ case GGML_OP_CPY:
2170
+ case GGML_OP_CONT:
2171
+ {
2172
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
2173
+
2174
+ int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
2175
+
2176
+ id<MTLComputePipelineState> pipeline = nil;
2177
+
2178
+ switch (src0t) {
2179
+ case GGML_TYPE_F32:
2180
+ {
2181
+ GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
2182
+
2183
+ switch (dstt) {
2184
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
2185
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
2186
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
2187
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
2188
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
2189
+ //case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
2190
+ //case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
2191
+ default: GGML_ASSERT(false && "not implemented");
2192
+ };
2193
+ } break;
2194
+ case GGML_TYPE_F16:
2195
+ {
2196
+ switch (dstt) {
2197
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
2198
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
2199
+ default: GGML_ASSERT(false && "not implemented");
2200
+ };
2201
+ } break;
2202
+ default: GGML_ASSERT(false && "not implemented");
2203
+ }
2204
+
2205
+ [encoder setComputePipelineState:pipeline];
2206
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2207
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2208
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2209
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2210
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2211
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
2212
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
2213
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
2214
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
2215
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
2216
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
2217
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
2218
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
2219
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
2220
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
2221
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
2222
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
2223
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
2224
+
2225
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2226
+ } break;
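
The copy kernels run one threadgroup per source row, with threads counted in src0 blocks: ggml_blck_size is 1 for F32/F16, so nth reduces to min(1024, ne00) on the float paths. Trivial illustration:

    #include <stdio.h>
    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    int main(void) {
        const int ne00 = 4096;
        const int blck = 1;                   // F32 source: block size 1
        printf("%d\n", MIN(1024, ne00/blck)); // 1024
        return 0;
    }
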
2227
+ default:
2228
+ {
2229
+ GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
2230
+ GGML_ASSERT(false);
2231
+ }
2469
2232
  }
2470
2233
 
2471
- [command_buffer commit];
2472
- });
2473
- }
2234
+ #ifndef GGML_METAL_NDEBUG
2235
+ [encoder popDebugGroup];
2236
+ #endif
2237
+ }
2238
+
2239
+ if (encoder != nil) {
2240
+ [encoder endEncoding];
2241
+ encoder = nil;
2242
+ }
2474
2243
 
2475
- // wait for all threads to finish
2476
- dispatch_barrier_sync(ctx->d_queue, ^{});
2244
+ [command_buffer commit];
2245
+ });
2477
2246
 
2478
- // check status of command buffers
2247
+ // wait for completion and check the status of each command buffer
2479
2248
// needed to detect, for example, when the device ran out of memory (#1881)
2480
- for (int i = 0; i < n_cb; i++) {
2481
- [ctx->command_buffers[i] waitUntilCompleted];
2482
2249
 
2483
- MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
2250
+ for (int i = 0; i < n_cb; ++i) {
2251
+ id<MTLCommandBuffer> command_buffer = command_buffers[i];
2252
+ [command_buffer waitUntilCompleted];
2253
+
2254
+ MTLCommandBufferStatus status = [command_buffer status];
2484
2255
  if (status != MTLCommandBufferStatusCompleted) {
2485
2256
  GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
2486
2257
  return false;
@@ -2520,13 +2291,13 @@ static void ggml_backend_metal_free_device(void) {
2520
2291
  }
2521
2292
  }
2522
2293
 
2523
- static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
2524
- struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
2294
+ GGML_CALL static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
2295
+ return "Metal";
2525
2296
 
2526
- return ctx->all_data;
2297
+ UNUSED(buffer);
2527
2298
  }
2528
2299
 
2529
- static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
2300
+ GGML_CALL static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
2530
2301
  struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
2531
2302
 
2532
2303
  for (int i = 0; i < ctx->n_buffers; i++) {
@@ -2541,50 +2312,80 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
2541
2312
  free(ctx);
2542
2313
  }
2543
2314
 
2544
- static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
2545
- memcpy((char *)tensor->data + offset, data, size);
2315
+ GGML_CALL static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
2316
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
2546
2317
 
2547
- UNUSED(buffer);
2318
+ return ctx->all_data;
2548
2319
  }
2549
2320
 
2550
- static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
2551
- memcpy(data, (const char *)tensor->data + offset, size);
2321
+ GGML_CALL static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
2322
+ memcpy((char *)tensor->data + offset, data, size);
2552
2323
 
2553
2324
  UNUSED(buffer);
2554
2325
  }
2555
2326
 
2556
- static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
2557
- ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
2327
+ GGML_CALL static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
2328
+ memcpy(data, (const char *)tensor->data + offset, size);
2558
2329
 
2559
2330
  UNUSED(buffer);
2560
2331
  }
2561
2332
 
2562
- static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
2563
- ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
2333
+ GGML_CALL static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
2334
+ if (ggml_backend_buffer_is_host(src->buffer)) {
2335
+ memcpy(dst->data, src->data, ggml_nbytes(src));
2336
+ return true;
2337
+ }
2338
+ return false;
2564
2339
 
2565
2340
  UNUSED(buffer);
2566
2341
  }
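
Note on the new cpy_tensor hook: unlike the removed cpy_tensor_from/cpy_tensor_to pair, it returns a bool. true means the copy was handled here, which is possible whenever the source buffer is host-visible and a plain memcpy suffices; false tells ggml-backend to fall back to its generic copy path.
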
2567
2342
 
2568
- static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
2343
+ GGML_CALL static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
2569
2344
  struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
2570
2345
 
2571
2346
  memset(ctx->all_data, value, ctx->all_size);
2572
2347
  }
2573
2348
 
2574
2349
  static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
2350
+ /* .get_name = */ ggml_backend_metal_buffer_get_name,
2575
2351
  /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
2576
2352
  /* .get_base = */ ggml_backend_metal_buffer_get_base,
2577
2353
  /* .init_tensor = */ NULL,
2578
2354
  /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
2579
2355
  /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
2580
- /* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
2581
- /* .cpy_tensor_to = */ ggml_backend_metal_buffer_cpy_tensor_to,
2356
+ /* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor,
2582
2357
  /* .clear = */ ggml_backend_metal_buffer_clear,
2358
+ /* .reset = */ NULL,
2583
2359
  };
2584
2360
 
2585
2361
  // default buffer type
2586
2362
 
2587
- static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
2363
+ GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
2364
+ return "Metal";
2365
+
2366
+ UNUSED(buft);
2367
+ }
2368
+
2369
+ static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
2370
+ #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
2371
+ if (@available(macOS 10.12, iOS 16.0, *)) {
2372
+ GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
2373
+ device.currentAllocatedSize / 1024.0 / 1024.0,
2374
+ device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
2375
+
2376
+ if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
2377
+ GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
2378
+ } else {
2379
+ GGML_METAL_LOG_INFO("\n");
2380
+ }
2381
+ } else {
2382
+ GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
2383
+ }
2384
+ #endif
2385
+ UNUSED(device);
2386
+ }
2387
+
2388
+ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
2588
2389
  struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
2589
2390
 
2590
2391
  const size_t size_page = sysconf(_SC_PAGESIZE);
@@ -2616,46 +2417,32 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
2616
2417
  }
2617
2418
 
2618
2419
  GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
2619
-
2620
-
2621
- #if TARGET_OS_OSX
2622
- GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
2623
- device.currentAllocatedSize / 1024.0 / 1024.0,
2624
- device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
2625
-
2626
- if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
2627
- GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
2628
- } else {
2629
- GGML_METAL_LOG_INFO("\n");
2630
- }
2631
- #else
2632
- GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
2633
- #endif
2634
-
2420
+ ggml_backend_metal_log_allocated_size(device);
2635
2421
 
2636
2422
  return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
2637
2423
  }
2638
2424
 
2639
- static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
2425
+ GGML_CALL static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
2640
2426
  return 32;
2641
2427
  UNUSED(buft);
2642
2428
  }
2643
2429
 
2644
- static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
2430
+ GGML_CALL static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
2645
2431
  return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
2646
2432
 
2647
2433
  UNUSED(buft);
2648
2434
  }
2649
2435
 
2650
- static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
2436
+ GGML_CALL static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
2651
2437
  return true;
2652
2438
 
2653
2439
  UNUSED(buft);
2654
2440
  }
2655
2441
 
2656
- ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
2442
+ GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
2657
2443
  static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
2658
2444
  /* .iface = */ {
2445
+ /* .get_name = */ ggml_backend_metal_buffer_type_get_name,
2659
2446
  /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
2660
2447
  /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
2661
2448
  /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
@@ -2670,7 +2457,7 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
2670
2457
 
2671
2458
  // buffer from ptr
2672
2459
 
2673
- ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
2460
+ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
2674
2461
  struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
2675
2462
 
2676
2463
  ctx->all_data = data;
@@ -2679,6 +2466,14 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
2679
2466
  ctx->n_buffers = 0;
2680
2467
 
2681
2468
  const size_t size_page = sysconf(_SC_PAGESIZE);
2469
+
2470
+ // page-align the data ptr
2471
+ {
2472
+ const uintptr_t offs = (uintptr_t) data % size_page;
2473
+ data = (void *) ((char *) data - offs);
2474
+ size += offs;
2475
+ }
2476
+
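
The new page-align block rounds the incoming pointer down to a page boundary and grows the size by the same offset, so the mapped range still covers the original payload (Metal's newBufferWithBytesNoCopy requires page-aligned storage). Worked arithmetic with hypothetical values:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const size_t size_page = 4096;           // typical _SC_PAGESIZE
        uintptr_t data = 0x100000u + 100;        // hypothetical unaligned pointer
        size_t    size = 1000;

        const uintptr_t offs = data % size_page; // 100
        data -= offs;                            // 0x100000, page-aligned
        size += offs;                            // 1100, still spans the payload

        printf("%#lx %zu\n", (unsigned long) data, size);
        return 0;
    }
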
2682
2477
  size_t size_aligned = size;
2683
2478
  if ((size_aligned % size_page) != 0) {
2684
2479
  size_aligned += (size_page - (size_aligned % size_page));
@@ -2730,63 +2525,50 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
2730
2525
  }
2731
2526
  }
2732
2527
 
2733
- #if TARGET_OS_OSX
2734
- GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
2735
- device.currentAllocatedSize / 1024.0 / 1024.0,
2736
- device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
2737
-
2738
- if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
2739
- GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
2740
- } else {
2741
- GGML_METAL_LOG_INFO("\n");
2742
- }
2743
- #else
2744
- GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
2745
- #endif
2528
+ ggml_backend_metal_log_allocated_size(device);
2746
2529
 
2747
2530
  return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
2748
2531
  }
2749
2532
 
2750
2533
  // backend
2751
2534
 
2752
- static const char * ggml_backend_metal_name(ggml_backend_t backend) {
2535
+ GGML_CALL static const char * ggml_backend_metal_name(ggml_backend_t backend) {
2753
2536
  return "Metal";
2754
2537
 
2755
2538
  UNUSED(backend);
2756
2539
  }
2757
2540
 
2758
- static void ggml_backend_metal_free(ggml_backend_t backend) {
2541
+ GGML_CALL static void ggml_backend_metal_free(ggml_backend_t backend) {
2759
2542
  struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
2760
2543
  ggml_metal_free(ctx);
2761
2544
  free(backend);
2762
2545
  }
2763
2546
 
2764
- static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
2547
+ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
2765
2548
  return ggml_backend_metal_buffer_type();
2766
2549
 
2767
2550
  UNUSED(backend);
2768
2551
  }
2769
2552
 
2770
- static bool ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
2553
+ GGML_CALL static bool ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
2771
2554
  struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
2772
2555
 
2773
2556
  return ggml_metal_graph_compute(metal_ctx, cgraph);
2774
2557
  }
2775
2558
 
2776
- static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
2777
- return ggml_metal_supports_op(op);
2559
+ GGML_CALL static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
2560
+ struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
2778
2561
 
2779
- UNUSED(backend);
2562
+ return ggml_metal_supports_op(metal_ctx, op);
2780
2563
  }
2781
2564
 
2782
- static struct ggml_backend_i metal_backend_i = {
2565
+ static struct ggml_backend_i ggml_backend_metal_i = {
2783
2566
  /* .get_name = */ ggml_backend_metal_name,
2784
2567
  /* .free = */ ggml_backend_metal_free,
2785
2568
  /* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
2786
2569
  /* .set_tensor_async = */ NULL,
2787
2570
  /* .get_tensor_async = */ NULL,
2788
- /* .cpy_tensor_from_async = */ NULL,
2789
- /* .cpy_tensor_to_async = */ NULL,
2571
+ /* .cpy_tensor_async = */ NULL,
2790
2572
  /* .synchronize = */ NULL,
2791
2573
  /* .graph_plan_create = */ NULL,
2792
2574
  /* .graph_plan_free = */ NULL,
@@ -2795,6 +2577,11 @@ static struct ggml_backend_i metal_backend_i = {
2795
2577
  /* .supports_op = */ ggml_backend_metal_supports_op,
2796
2578
  };
2797
2579
 
2580
+ void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
2581
+ ggml_metal_log_callback = log_callback;
2582
+ ggml_metal_log_user_data = user_data;
2583
+ }
2584
+
2798
2585
  ggml_backend_t ggml_backend_metal_init(void) {
2799
2586
  struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
2800
2587
 
@@ -2805,7 +2592,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
2805
2592
  ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
2806
2593
 
2807
2594
  *metal_backend = (struct ggml_backend) {
2808
- /* .interface = */ metal_backend_i,
2595
+ /* .interface = */ ggml_backend_metal_i,
2809
2596
  /* .context = */ ctx,
2810
2597
  };
2811
2598
 
@@ -2813,7 +2600,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
2813
2600
  }
2814
2601
 
2815
2602
  bool ggml_backend_is_metal(ggml_backend_t backend) {
2816
- return backend->iface.get_name == ggml_backend_metal_name;
2603
+ return backend && backend->iface.get_name == ggml_backend_metal_name;
2817
2604
  }
2818
2605
 
2819
2606
  void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
@@ -2821,7 +2608,7 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
2821
2608
 
2822
2609
  struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
2823
2610
 
2824
- ggml_metal_set_n_cb(ctx, n_cb);
2611
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
2825
2612
  }
2826
2613
 
2827
2614
  bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
@@ -2832,9 +2619,9 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
2832
2619
  return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
2833
2620
  }
2834
2621
 
2835
- ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
2622
+ GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
2836
2623
 
2837
- ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
2624
+ GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
2838
2625
  return ggml_backend_metal_init();
2839
2626
 
2840
2627
  GGML_UNUSED(params);