llama_cpp 0.12.1 → 0.12.2

@@ -24,7 +24,7 @@
 
  #define UNUSED(x) (void)(x)
 
- #define GGML_MAX_CONCUR (2*GGML_DEFAULT_GRAPH_SIZE)
+ #define GGML_METAL_MAX_KERNELS 256
 
  struct ggml_metal_buffer {
  const char * name;
@@ -35,6 +35,134 @@ struct ggml_metal_buffer {
  id<MTLBuffer> metal;
  };
 
+ struct ggml_metal_kernel {
+ id<MTLFunction> function;
+ id<MTLComputePipelineState> pipeline;
+ };
+
+ enum ggml_metal_kernel_type {
+ GGML_METAL_KERNEL_TYPE_ADD,
+ GGML_METAL_KERNEL_TYPE_ADD_ROW,
+ GGML_METAL_KERNEL_TYPE_MUL,
+ GGML_METAL_KERNEL_TYPE_MUL_ROW,
+ GGML_METAL_KERNEL_TYPE_DIV,
+ GGML_METAL_KERNEL_TYPE_DIV_ROW,
+ GGML_METAL_KERNEL_TYPE_SCALE,
+ GGML_METAL_KERNEL_TYPE_SCALE_4,
+ GGML_METAL_KERNEL_TYPE_TANH,
+ GGML_METAL_KERNEL_TYPE_RELU,
+ GGML_METAL_KERNEL_TYPE_GELU,
+ GGML_METAL_KERNEL_TYPE_GELU_QUICK,
+ GGML_METAL_KERNEL_TYPE_SILU,
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX,
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
+ GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
+ GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
+ GGML_METAL_KERNEL_TYPE_RMS_NORM,
+ GGML_METAL_KERNEL_TYPE_GROUP_NORM,
+ GGML_METAL_KERNEL_TYPE_NORM,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
+ //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
+ //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
+ //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_ROPE_F32,
+ GGML_METAL_KERNEL_TYPE_ROPE_F16,
+ GGML_METAL_KERNEL_TYPE_ALIBI_F32,
+ GGML_METAL_KERNEL_TYPE_IM2COL_F16,
+ GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
+ GGML_METAL_KERNEL_TYPE_PAD_F32,
+ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
+ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
+ GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
+ //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
+ //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
+ GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
+ GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
+ GGML_METAL_KERNEL_TYPE_CONCAT,
+ GGML_METAL_KERNEL_TYPE_SQR,
+ GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+
+ GGML_METAL_KERNEL_TYPE_COUNT
+ };
+
  struct ggml_metal_context {
  int n_cb;
 
@@ -42,142 +170,15 @@ struct ggml_metal_context {
  id<MTLCommandQueue> queue;
  id<MTLLibrary> library;
 
- id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
- id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
-
  dispatch_queue_t d_queue;
 
  int n_buffers;
  struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
- int concur_list[GGML_MAX_CONCUR];
- int concur_list_len;
-
- // custom kernels
- #define GGML_METAL_DECL_KERNEL(name) \
- id<MTLFunction> function_##name; \
- id<MTLComputePipelineState> pipeline_##name
-
- GGML_METAL_DECL_KERNEL(add);
- GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
- GGML_METAL_DECL_KERNEL(mul);
- GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
- GGML_METAL_DECL_KERNEL(div);
- GGML_METAL_DECL_KERNEL(div_row);
- GGML_METAL_DECL_KERNEL(scale);
- GGML_METAL_DECL_KERNEL(scale_4);
- GGML_METAL_DECL_KERNEL(tanh);
- GGML_METAL_DECL_KERNEL(relu);
- GGML_METAL_DECL_KERNEL(gelu);
- GGML_METAL_DECL_KERNEL(gelu_quick);
- GGML_METAL_DECL_KERNEL(silu);
- GGML_METAL_DECL_KERNEL(soft_max);
- GGML_METAL_DECL_KERNEL(soft_max_4);
- GGML_METAL_DECL_KERNEL(diag_mask_inf);
- GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
- GGML_METAL_DECL_KERNEL(get_rows_f32);
- GGML_METAL_DECL_KERNEL(get_rows_f16);
- GGML_METAL_DECL_KERNEL(get_rows_q4_0);
- GGML_METAL_DECL_KERNEL(get_rows_q4_1);
- GGML_METAL_DECL_KERNEL(get_rows_q5_0);
- GGML_METAL_DECL_KERNEL(get_rows_q5_1);
- GGML_METAL_DECL_KERNEL(get_rows_q8_0);
- GGML_METAL_DECL_KERNEL(get_rows_q2_K);
- GGML_METAL_DECL_KERNEL(get_rows_q3_K);
- GGML_METAL_DECL_KERNEL(get_rows_q4_K);
- GGML_METAL_DECL_KERNEL(get_rows_q5_K);
- GGML_METAL_DECL_KERNEL(get_rows_q6_K);
- GGML_METAL_DECL_KERNEL(get_rows_i32);
- GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs);
- GGML_METAL_DECL_KERNEL(get_rows_iq2_xs);
- GGML_METAL_DECL_KERNEL(rms_norm);
- GGML_METAL_DECL_KERNEL(group_norm);
- GGML_METAL_DECL_KERNEL(norm);
- GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
- GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
- GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
- GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_iq2_xs_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
- //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
- GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
- //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
- //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xs_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_iq2_xs_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xs_f32);
- GGML_METAL_DECL_KERNEL(rope_f32);
- GGML_METAL_DECL_KERNEL(rope_f16);
- GGML_METAL_DECL_KERNEL(alibi_f32);
- GGML_METAL_DECL_KERNEL(im2col_f16);
- GGML_METAL_DECL_KERNEL(upscale_f32);
- GGML_METAL_DECL_KERNEL(pad_f32);
- GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
- GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
- GGML_METAL_DECL_KERNEL(leaky_relu_f32);
- GGML_METAL_DECL_KERNEL(cpy_f32_f16);
- GGML_METAL_DECL_KERNEL(cpy_f32_f32);
- GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
- GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
- GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
- //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
- //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
- GGML_METAL_DECL_KERNEL(cpy_f16_f16);
- GGML_METAL_DECL_KERNEL(cpy_f16_f32);
- GGML_METAL_DECL_KERNEL(concat);
- GGML_METAL_DECL_KERNEL(sqr);
- GGML_METAL_DECL_KERNEL(sum_rows);
-
- #undef GGML_METAL_DECL_KERNEL
+ struct ggml_metal_kernel kernels[GGML_METAL_MAX_KERNELS];
+
+ bool support_simdgroup_reduction;
+ bool support_simdgroup_mm;
  };
 
  // MSL code
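
The hunks above replace the macro-generated per-kernel fields (GGML_METAL_DECL_KERNEL) with a single enum-indexed table: each kernel is a {function, pipeline} pair stored in ctx->kernels[] and addressed by its ggml_metal_kernel_type value. A minimal sketch of the lookup this enables (the helper name is hypothetical; the real code indexes the array inline):

    // Table-driven registry: one array slot per kernel, indexed by enum.
    // Slots for kernels the device cannot run are left nil.
    static id<MTLComputePipelineState> ggml_metal_get_pipeline(
            struct ggml_metal_context * ctx,
            enum ggml_metal_kernel_type type) {
        return ctx->kernels[type].pipeline; // e.g. type = GGML_METAL_KERNEL_TYPE_ADD
    }
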
@@ -191,7 +192,6 @@ struct ggml_metal_context {
  @implementation GGMLMetalClass
  @end
 
-
  static void ggml_metal_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
  fprintf(stderr, "%s", msg);
 
@@ -202,11 +202,6 @@ static void ggml_metal_default_log_callback(enum ggml_log_level level, const cha
  ggml_log_callback ggml_metal_log_callback = ggml_metal_default_log_callback;
  void * ggml_metal_log_user_data = NULL;
 
- void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
- ggml_metal_log_callback = log_callback;
- ggml_metal_log_user_data = user_data;
- }
-
  GGML_ATTRIBUTE_FORMAT(2, 3)
  static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
  if (ggml_metal_log_callback != NULL) {
@@ -229,7 +224,18 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
  }
  }
 
- struct ggml_metal_context * ggml_metal_init(int n_cb) {
+ static void * ggml_metal_host_malloc(size_t n) {
+ void * data = NULL;
+ const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
+ if (result != 0) {
+ GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
+ return NULL;
+ }
+
+ return data;
+ }
+
+ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
 
  id<MTLDevice> device;
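
ggml_metal_host_malloc moves above ggml_metal_init and becomes static: Metal's no-copy buffers require page-aligned host memory, which posix_memalign provides. A hedged usage sketch (the 1 MiB size is illustrative):

    #include <stdlib.h>
    #include <unistd.h>

    // Page-aligned allocation suitable for wrapping in a no-copy MTLBuffer.
    // sysconf(_SC_PAGESIZE) is a power-of-two multiple of sizeof(void *),
    // as posix_memalign requires.
    void * data = NULL;
    if (posix_memalign(&data, sysconf(_SC_PAGESIZE), 1 << 20) != 0) {
        data = NULL; // allocation failed; caller must handle NULL
    }
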
@@ -255,7 +261,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
  ctx->queue = [ctx->device newCommandQueue];
  ctx->n_buffers = 0;
- ctx->concur_list_len = 0;
 
  ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
@@ -298,19 +303,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  return NULL;
  }
 
- MTLCompileOptions* options = nil;
+ // dictionary of preprocessor macros
+ NSMutableDictionary * prep = [NSMutableDictionary dictionary];
+
  #ifdef GGML_QKK_64
- options = [MTLCompileOptions new];
- options.preprocessorMacros = @{ @"QK_K" : @(64) };
+ prep[@"QK_K"] = @(64);
  #endif
- // try to disable fast-math
- // NOTE: this seems to have no effect whatsoever
- // instead, in order to disable fast-math, we have to build default.metallib from the command line
- // using xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
- // and go through the "pre-compiled library found" path above
+
+ MTLCompileOptions* options = [MTLCompileOptions new];
+ options.preprocessorMacros = prep;
+
  //[options setFastMathEnabled:false];
 
  ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+
+ [options release];
+ [prep release];
  }
 
  if (error) {
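
The library-compilation path now collects preprocessor defines in a mutable dictionary before attaching them to the compile options, so macros such as QK_K can be added conditionally. A condensed sketch of the same pattern (ARC spelling, without the manual release calls in the diff):

    // Compile MSL source with runtime-supplied preprocessor macros.
    NSMutableDictionary * prep = [NSMutableDictionary dictionary];
    prep[@"QK_K"] = @(64); // only added under GGML_QKK_64 in the real code

    MTLCompileOptions * options = [MTLCompileOptions new];
    options.preprocessorMacros = prep;

    NSError * error = nil;
    id<MTLLibrary> library = [device newLibraryWithSource:src
                                                  options:options
                                                    error:&error];
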
@@ -319,22 +327,51 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  }
  }
 
- #if TARGET_OS_OSX
  // print MTL GPU family:
  GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
 
+ const NSInteger MTLGPUFamilyMetal3 = 5001;
+
  // determine max supported GPU family
  // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
  // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
- for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
- if ([ctx->device supportsFamily:i]) {
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
- break;
+ {
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+ break;
+ }
+ }
+
+ for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
+ break;
+ }
+ }
+
+ for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
+ break;
+ }
  }
  }
 
+ ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7];
+ ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];
+
+ ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];
+
+ GGML_METAL_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false");
+ GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
  GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
- GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+
+ #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
+ if (@available(macOS 10.12, iOS 16.0, *)) {
+ GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+ }
+ #elif TARGET_OS_OSX
  if (ctx->device.maxTransferRate != 0) {
  GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
  } else {
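
Capability probing now happens once at init and is cached on the context: simdgroup reductions (simd_sum, simd_max) need MTLGPUFamilyApple7 or Metal 3, and simdgroup matrix multiplication needs Apple7; MTLGPUFamilyMetal3 is declared locally with its raw value (5001) so the file also builds against older SDKs. The flags later gate both kernel compilation and ggml_metal_supports_op; a hypothetical guard a caller could write before encoding a reduction-based kernel:

    // Hypothetical: only fetch the simdgroup-based soft_max pipeline when
    // the device reported reduction support at init time.
    if (!ctx->support_simdgroup_reduction) {
        GGML_METAL_LOG_ERROR("%s: device lacks simdgroup reductions\n", __func__);
        return false;
    }
    id<MTLComputePipelineState> pipeline =
        ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
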
@@ -346,279 +383,171 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  {
  NSError * error = nil;
 
+ for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
+ ctx->kernels[i].function = nil;
+ ctx->kernels[i].pipeline = nil;
+ }
+
  /*
- GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
- (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
- (int) ctx->pipeline_##name.threadExecutionWidth); \
+ GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+ (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+ (int) kernel->pipeline.threadExecutionWidth); \
  */
- #define GGML_METAL_ADD_KERNEL(name) \
- ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
- ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
- if (error) { \
- GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
- return NULL; \
+ #define GGML_METAL_ADD_KERNEL(e, name, supported) \
+ if (supported) { \
+ struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
+ kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \
+ kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \
+ if (error) { \
+ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+ return NULL; \
+ } \
+ } else { \
+ GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \
  }
 
- GGML_METAL_ADD_KERNEL(add);
- GGML_METAL_ADD_KERNEL(add_row);
- GGML_METAL_ADD_KERNEL(mul);
- GGML_METAL_ADD_KERNEL(mul_row);
- GGML_METAL_ADD_KERNEL(div);
- GGML_METAL_ADD_KERNEL(div_row);
- GGML_METAL_ADD_KERNEL(scale);
- GGML_METAL_ADD_KERNEL(scale_4);
- GGML_METAL_ADD_KERNEL(tanh);
- GGML_METAL_ADD_KERNEL(relu);
- GGML_METAL_ADD_KERNEL(gelu);
- GGML_METAL_ADD_KERNEL(gelu_quick);
- GGML_METAL_ADD_KERNEL(silu);
- GGML_METAL_ADD_KERNEL(soft_max);
- GGML_METAL_ADD_KERNEL(soft_max_4);
- GGML_METAL_ADD_KERNEL(diag_mask_inf);
- GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
- GGML_METAL_ADD_KERNEL(get_rows_f32);
- GGML_METAL_ADD_KERNEL(get_rows_f16);
- GGML_METAL_ADD_KERNEL(get_rows_q4_0);
- GGML_METAL_ADD_KERNEL(get_rows_q4_1);
- GGML_METAL_ADD_KERNEL(get_rows_q5_0);
- GGML_METAL_ADD_KERNEL(get_rows_q5_1);
- GGML_METAL_ADD_KERNEL(get_rows_q8_0);
- GGML_METAL_ADD_KERNEL(get_rows_q2_K);
- GGML_METAL_ADD_KERNEL(get_rows_q3_K);
- GGML_METAL_ADD_KERNEL(get_rows_q4_K);
- GGML_METAL_ADD_KERNEL(get_rows_q5_K);
- GGML_METAL_ADD_KERNEL(get_rows_q6_K);
- GGML_METAL_ADD_KERNEL(get_rows_i32);
- GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs);
- GGML_METAL_ADD_KERNEL(get_rows_iq2_xs);
- GGML_METAL_ADD_KERNEL(rms_norm);
- GGML_METAL_ADD_KERNEL(group_norm);
- GGML_METAL_ADD_KERNEL(norm);
- GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
- GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
- GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
- GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_iq2_xs_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
- //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
- GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
- //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
- //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xs_f32);
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
- GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_iq2_xs_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xs_f32);
- }
- GGML_METAL_ADD_KERNEL(rope_f32);
- GGML_METAL_ADD_KERNEL(rope_f16);
- GGML_METAL_ADD_KERNEL(alibi_f32);
- GGML_METAL_ADD_KERNEL(im2col_f16);
- GGML_METAL_ADD_KERNEL(upscale_f32);
- GGML_METAL_ADD_KERNEL(pad_f32);
- GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
- GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
- GGML_METAL_ADD_KERNEL(leaky_relu_f32);
- GGML_METAL_ADD_KERNEL(cpy_f32_f16);
- GGML_METAL_ADD_KERNEL(cpy_f32_f32);
- GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
- GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
- GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
- //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
- //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
- GGML_METAL_ADD_KERNEL(cpy_f16_f16);
- GGML_METAL_ADD_KERNEL(cpy_f16_f32);
- GGML_METAL_ADD_KERNEL(concat);
- GGML_METAL_ADD_KERNEL(sqr);
- GGML_METAL_ADD_KERNEL(sum_rows);
-
- #undef GGML_METAL_ADD_KERNEL
+ // simd_sum and simd_max requires MTLGPUFamilyApple7
+
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                           add,                            true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,                       add_row,                        true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,                           mul,                            true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,                       mul_row,                        true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                           div,                            true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,                       div_row,                        true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,                         scale,                          true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,                       scale_4,                        true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,                          tanh,                           true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,                          relu,                           true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,                          gelu,                           true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK,                    gelu_quick,                     true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                          silu,                           true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX,                      soft_max,                       ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,                    soft_max_4,                     ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,                 diag_mask_inf,                  true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,               diag_mask_inf_8,                true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,                  get_rows_f32,                   true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,                  get_rows_f16,                   true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,                 get_rows_q4_0,                  true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,                 get_rows_q4_1,                  true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,                 get_rows_q5_0,                  true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,                 get_rows_q5_1,                  true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,                 get_rows_q8_0,                  true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,                 get_rows_q2_K,                  true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,                 get_rows_q3_K,                  true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,                 get_rows_q4_K,                  true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,                 get_rows_q5_K,                  true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,                 get_rows_q6_K,                  true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,              get_rows_iq2_xxs,               true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,               get_rows_iq2_xs,                true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,                  get_rows_i32,                   true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                      rms_norm,                       ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                    group_norm,                     ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                          norm,                           true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                mul_mv_f32_f32,                 ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,                mul_mv_f16_f16,                 ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,                mul_mv_f16_f32,                 ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,           mul_mv_f16_f32_1row,            ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,             mul_mv_f16_f32_l4,              ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,               mul_mv_q4_0_f32,                ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,               mul_mv_q4_1_f32,                ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,               mul_mv_q5_0_f32,                ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,               mul_mv_q5_1_f32,                ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,               mul_mv_q8_0_f32,                ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,               mul_mv_q2_K_f32,                ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,               mul_mv_q3_K_f32,                ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,               mul_mv_q4_K_f32,                ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,               mul_mv_q5_K_f32,                ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,               mul_mv_q6_K_f32,                ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,            mul_mv_iq2_xxs_f32,             ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,             mul_mv_iq2_xs_f32,              ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,             mul_mv_id_f32_f32,              ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,           mul_mv_id_f16_f16,              ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,             mul_mv_id_f16_f32,              ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,      mul_mv_id_f16_f32_1row,         ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,        mul_mv_id_f16_f32_l4,           ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,            mul_mv_id_q4_0_f32,             ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,            mul_mv_id_q4_1_f32,             ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,            mul_mv_id_q5_0_f32,             ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,            mul_mv_id_q5_1_f32,             ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,            mul_mv_id_q8_0_f32,             ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,            mul_mv_id_q2_K_f32,             ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,            mul_mv_id_q3_K_f32,             ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,            mul_mv_id_q4_K_f32,             ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,            mul_mv_id_q5_K_f32,             ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,            mul_mv_id_q6_K_f32,             ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,         mul_mv_id_iq2_xxs_f32,          ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,          mul_mv_id_iq2_xs_f32,           ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,                mul_mm_f32_f32,                 ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,                mul_mm_f16_f32,                 ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,               mul_mm_q4_0_f32,                ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,               mul_mm_q4_1_f32,                ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,               mul_mm_q5_0_f32,                ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,               mul_mm_q5_1_f32,                ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,               mul_mm_q8_0_f32,                ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,               mul_mm_q2_K_f32,                ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,               mul_mm_q3_K_f32,                ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,               mul_mm_q4_K_f32,                ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,               mul_mm_q5_K_f32,                ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,               mul_mm_q6_K_f32,                ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,            mul_mm_iq2_xxs_f32,             ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,             mul_mm_iq2_xs_f32,              ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,             mul_mm_id_f32_f32,              ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,             mul_mm_id_f16_f32,              ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,            mul_mm_id_q4_0_f32,             ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,            mul_mm_id_q4_1_f32,             ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,            mul_mm_id_q5_0_f32,             ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,            mul_mm_id_q5_1_f32,             ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,            mul_mm_id_q8_0_f32,             ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,            mul_mm_id_q2_K_f32,             ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,            mul_mm_id_q3_K_f32,             ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,            mul_mm_id_q4_K_f32,             ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,            mul_mm_id_q5_K_f32,             ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,            mul_mm_id_q6_K_f32,             ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,         mul_mm_id_iq2_xxs_f32,          ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,          mul_mm_id_iq2_xs_f32,           ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,                      rope_f32,                       true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,                      rope_f16,                       true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32,                     alibi_f32,                      true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                    im2col_f16,                     true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,           argsort_f32_i32_asc,            true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,          argsort_f32_i32_desc,           true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,                leaky_relu_f32,                 true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                   cpy_f32_f32,                    true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,                  cpy_f32_q8_0,                   true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,                  cpy_f32_q4_0,                   true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,                  cpy_f32_q4_1,                   true);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,                cpy_f32_q5_0,                   true);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,                cpy_f32_q5_1,                   true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                   cpy_f16_f16,                    true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,                   cpy_f16_f32,                    true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                        concat,                         true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                           sqr,                            true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                      sum_rows,                       true);
  }
 
  return ctx;
  }
 
- void ggml_metal_free(struct ggml_metal_context * ctx) {
+ static void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
- #define GGML_METAL_DEL_KERNEL(name) \
- [ctx->function_##name release]; \
- [ctx->pipeline_##name release];
-
- GGML_METAL_DEL_KERNEL(add);
- GGML_METAL_DEL_KERNEL(add_row);
- GGML_METAL_DEL_KERNEL(mul);
- GGML_METAL_DEL_KERNEL(mul_row);
- GGML_METAL_DEL_KERNEL(div);
- GGML_METAL_DEL_KERNEL(div_row);
- GGML_METAL_DEL_KERNEL(scale);
- GGML_METAL_DEL_KERNEL(scale_4);
- GGML_METAL_DEL_KERNEL(tanh);
- GGML_METAL_DEL_KERNEL(relu);
- GGML_METAL_DEL_KERNEL(gelu);
- GGML_METAL_DEL_KERNEL(gelu_quick);
- GGML_METAL_DEL_KERNEL(silu);
- GGML_METAL_DEL_KERNEL(soft_max);
- GGML_METAL_DEL_KERNEL(soft_max_4);
- GGML_METAL_DEL_KERNEL(diag_mask_inf);
- GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
- GGML_METAL_DEL_KERNEL(get_rows_f32);
- GGML_METAL_DEL_KERNEL(get_rows_f16);
- GGML_METAL_DEL_KERNEL(get_rows_q4_0);
- GGML_METAL_DEL_KERNEL(get_rows_q4_1);
- GGML_METAL_DEL_KERNEL(get_rows_q5_0);
- GGML_METAL_DEL_KERNEL(get_rows_q5_1);
- GGML_METAL_DEL_KERNEL(get_rows_q8_0);
- GGML_METAL_DEL_KERNEL(get_rows_q2_K);
- GGML_METAL_DEL_KERNEL(get_rows_q3_K);
- GGML_METAL_DEL_KERNEL(get_rows_q4_K);
- GGML_METAL_DEL_KERNEL(get_rows_q5_K);
- GGML_METAL_DEL_KERNEL(get_rows_q6_K);
- GGML_METAL_DEL_KERNEL(get_rows_i32);
- GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs);
- GGML_METAL_DEL_KERNEL(get_rows_iq2_xs);
- GGML_METAL_DEL_KERNEL(rms_norm);
- GGML_METAL_DEL_KERNEL(group_norm);
- GGML_METAL_DEL_KERNEL(norm);
- GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
- GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
- GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
- GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_iq2_xs_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
- //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
- GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
- //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
- //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xs_f32);
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
- GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_iq2_xs_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xs_f32);
- }
- GGML_METAL_DEL_KERNEL(rope_f32);
- GGML_METAL_DEL_KERNEL(rope_f16);
- GGML_METAL_DEL_KERNEL(alibi_f32);
- GGML_METAL_DEL_KERNEL(im2col_f16);
- GGML_METAL_DEL_KERNEL(upscale_f32);
- GGML_METAL_DEL_KERNEL(pad_f32);
- GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
- GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
- GGML_METAL_DEL_KERNEL(leaky_relu_f32);
- GGML_METAL_DEL_KERNEL(cpy_f32_f16);
- GGML_METAL_DEL_KERNEL(cpy_f32_f32);
- GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
- GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
- GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
- //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
- //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
- GGML_METAL_DEL_KERNEL(cpy_f16_f16);
- GGML_METAL_DEL_KERNEL(cpy_f16_f32);
- GGML_METAL_DEL_KERNEL(concat);
- GGML_METAL_DEL_KERNEL(sqr);
- GGML_METAL_DEL_KERNEL(sum_rows);
-
- #undef GGML_METAL_DEL_KERNEL
 
  for (int i = 0; i < ctx->n_buffers; ++i) {
  [ctx->buffers[i].metal release];
  }
 
+ for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
+ if (ctx->kernels[i].pipeline) {
+ [ctx->kernels[i].pipeline release];
+ }
+
+ if (ctx->kernels[i].function) {
+ [ctx->kernels[i].function release];
+ }
+ }
+
  [ctx->library release];
  [ctx->queue release];
  [ctx->device release];
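
Each GGML_METAL_ADD_KERNEL(e, name, supported) invocation above expands to a guarded load of one pipeline into the table. Roughly, for the ADD kernel (expansion paraphrased from the macro in this hunk):

    // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true) expands to:
    if (true) {
        struct ggml_metal_kernel * kernel = &ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD];
        kernel->function = [ctx->library newFunctionWithName:@"kernel_add"];
        kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function
                                                                      error:&error];
        if (error) {
            GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n",
                    __func__, [[error description] UTF8String]);
            return NULL;
        }
    } else {
        GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_add");
    }

Teardown becomes correspondingly uniform: ggml_metal_free now walks the whole table and releases whatever is non-nil, instead of mirroring the registration list with GGML_METAL_DEL_KERNEL.
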
@@ -628,33 +557,6 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  free(ctx);
  }
 
- void * ggml_metal_host_malloc(size_t n) {
- void * data = NULL;
- const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
- if (result != 0) {
- GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
- return NULL;
- }
-
- return data;
- }
-
- void ggml_metal_host_free(void * data) {
- free(data);
- }
-
- void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
- ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
- }
-
- int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
- return ctx->concur_list_len;
- }
-
- int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
- return ctx->concur_list;
- }
-
  // temporarily defined here for compatibility between ggml-backend and the old API
 
  struct ggml_backend_metal_buffer {
@@ -727,210 +629,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
  return nil;
  }
 
- bool ggml_metal_add_buffer(
- struct ggml_metal_context * ctx,
- const char * name,
- void * data,
- size_t size,
- size_t max_size) {
- if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
- GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
- return false;
- }
-
- if (data) {
- // verify that the buffer does not overlap with any of the existing buffers
- for (int i = 0; i < ctx->n_buffers; ++i) {
- const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
-
- if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
- GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
- return false;
- }
- }
-
- const size_t size_page = sysconf(_SC_PAGESIZE);
-
- size_t size_aligned = size;
- if ((size_aligned % size_page) != 0) {
- size_aligned += (size_page - (size_aligned % size_page));
- }
-
- // the buffer fits into the max buffer size allowed by the device
- if (size_aligned <= ctx->device.maxBufferLength) {
- ctx->buffers[ctx->n_buffers].name = name;
- ctx->buffers[ctx->n_buffers].data = data;
- ctx->buffers[ctx->n_buffers].size = size;
-
- ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
-
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
- GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
- return false;
- }
-
- GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);
-
- ++ctx->n_buffers;
- } else {
- // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
- // one of the views
- const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
- const size_t size_step = ctx->device.maxBufferLength - size_ovlp;
- const size_t size_view = ctx->device.maxBufferLength;
-
- for (size_t i = 0; i < size; i += size_step) {
- const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
-
- ctx->buffers[ctx->n_buffers].name = name;
- ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
- ctx->buffers[ctx->n_buffers].size = size_step_aligned;
-
- ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
-
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
- GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
- return false;
- }
-
- GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
- if (i + size_step < size) {
- GGML_METAL_LOG_INFO("\n");
- }
-
- ++ctx->n_buffers;
- }
- }
-
- #if TARGET_OS_OSX
- GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
- ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
- ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-
- if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
- GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
- } else {
- GGML_METAL_LOG_INFO("\n");
- }
- #else
- GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
- #endif
- }
-
- return true;
- }
-
- void ggml_metal_set_tensor(
- struct ggml_metal_context * ctx,
- struct ggml_tensor * t) {
- size_t offs;
- id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
-
- memcpy((void *) ((uint8_t *) id_dst.contents + offs), t->data, ggml_nbytes(t));
- }
-
- void ggml_metal_get_tensor(
- struct ggml_metal_context * ctx,
- struct ggml_tensor * t) {
- size_t offs;
- id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
-
- memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
- }
-
- void ggml_metal_graph_find_concurrency(
- struct ggml_metal_context * ctx,
- struct ggml_cgraph * gf, bool check_mem) {
- int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
- int nodes_unused[GGML_MAX_CONCUR];
-
- for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
- for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
- ctx->concur_list_len = 0;
-
- int n_left = gf->n_nodes;
- int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
- int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
-
- while (n_left > 0) {
- // number of nodes at a layer (that can be issued concurrently)
- int concurrency = 0;
- for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
- if (nodes_unused[i]) {
- // if the requirements for gf->nodes[i] are satisfied
- int exe_flag = 1;
-
- // scan all srcs
- for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
- struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
- if (src_cur) {
- // if is leaf nodes it's satisfied.
- // TODO: ggml_is_leaf()
- if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
- continue;
- }
-
- // otherwise this src should be the output from previous nodes.
- int is_found = 0;
-
- // scan 2*search_depth back because we inserted barrier.
- //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
- for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
- if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
- is_found = 1;
- break;
- }
- }
- if (is_found == 0) {
- exe_flag = 0;
- break;
- }
- }
- }
- if (exe_flag && check_mem) {
- // check if nodes[i]'s data will be overwritten by a node before nodes[i].
- // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
- int64_t data_start = (int64_t) gf->nodes[i]->data;
- int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
- for (int j = n_start; j < i; j++) {
- if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
- && gf->nodes[j]->op != GGML_OP_VIEW \
- && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
- && gf->nodes[j]->op != GGML_OP_PERMUTE) {
- if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
- ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
- continue;
- }
-
- exe_flag = 0;
- }
- }
- }
- if (exe_flag) {
- ctx->concur_list[level_pos + concurrency] = i;
- nodes_unused[i] = 0;
- concurrency++;
- ctx->concur_list_len++;
- }
- }
- }
- n_left -= concurrency;
- // adding a barrier different layer
- ctx->concur_list[level_pos + concurrency] = -1;
- ctx->concur_list_len++;
- // jump all sorted nodes at nodes_bak
- while (!nodes_unused[n_start]) {
- n_start++;
- }
- level_pos += concurrency + 1;
- }
-
- if (ctx->concur_list_len > GGML_MAX_CONCUR) {
- GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
- }
- }
-
- static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
+ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
  switch (op->op) {
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(op)) {
@@ -956,9 +655,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
  case GGML_OP_SCALE:
  case GGML_OP_SQR:
  case GGML_OP_SUM_ROWS:
+ return true;
  case GGML_OP_SOFT_MAX:
  case GGML_OP_RMS_NORM:
  case GGML_OP_GROUP_NORM:
+ return ctx->support_simdgroup_reduction;
  case GGML_OP_NORM:
  case GGML_OP_ALIBI:
  case GGML_OP_ROPE:
@@ -967,9 +668,10 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
  case GGML_OP_PAD:
  case GGML_OP_ARGSORT:
  case GGML_OP_LEAKY_RELU:
+ return true;
  case GGML_OP_MUL_MAT:
  case GGML_OP_MUL_MAT_ID:
- return true;
+ return ctx->support_simdgroup_reduction;
  case GGML_OP_CPY:
  case GGML_OP_DUP:
  case GGML_OP_CONT:
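
ggml_metal_supports_op now receives the context, so op support can depend on the probed device flags: SOFT_MAX, RMS_NORM, GROUP_NORM, MUL_MAT and MUL_MAT_ID fall through to ctx->support_simdgroup_reduction instead of an unconditional true. The shape of the new switch, condensed (the real code keeps the other op groups shown in the diff):

    switch (op->op) {
        case GGML_OP_SOFT_MAX:
        case GGML_OP_RMS_NORM:
        case GGML_OP_GROUP_NORM:
        case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID:
            // only when the device reported simdgroup reductions at init
            return ctx->support_simdgroup_reduction;
        // ... remaining op groups keep their unconditional answers
        default:
            return false;
    }
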
@@ -1007,1480 +709,1549 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
  return false;
  }
  }
- bool ggml_metal_graph_compute(
+
+ static bool ggml_metal_graph_compute(
  struct ggml_metal_context * ctx,
  struct ggml_cgraph * gf) {
  @autoreleasepool {
 
- // if there is ctx->concur_list, dispatch concurrently
- // else fallback to serial dispatch
  MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
-
- const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
-
- const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
- edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+ edesc.dispatchType = MTLDispatchTypeSerial;
 
  // create multiple command buffers and enqueue them
  // then, we encode the graph into the command buffers in parallel
 
+ const int n_nodes = gf->n_nodes;
  const int n_cb = ctx->n_cb;
+ const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
- for (int i = 0; i < n_cb; ++i) {
- ctx->command_buffers[i] = [ctx->queue commandBuffer];
+ id<MTLCommandBuffer> command_buffer_builder[n_cb];
+ for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+ id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
+ command_buffer_builder[cb_idx] = command_buffer;
 
  // enqueue the command buffers in order to specify their execution order
- [ctx->command_buffers[i] enqueue];
-
- ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
+ [command_buffer enqueue];
  }
+ const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
 
- for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
- const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
-
- dispatch_async(ctx->d_queue, ^{
- size_t offs_src0 = 0;
- size_t offs_src1 = 0;
- size_t offs_dst = 0;
-
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
- id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
-
- const int node_start = (cb_idx + 0) * n_nodes_per_cb;
- const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
-
- for (int ind = node_start; ind < node_end; ++ind) {
- const int i = has_concur ? ctx->concur_list[ind] : ind;
-
- if (i == -1) {
- [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
- continue;
- }
-
- //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
-
- struct ggml_tensor * src0 = gf->nodes[i]->src[0];
- struct ggml_tensor * src1 = gf->nodes[i]->src[1];
- struct ggml_tensor * dst = gf->nodes[i];
-
- switch (dst->op) {
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_PERMUTE:
- {
- // noop -> next node
- } continue;
- default:
- {
- } break;
- }
+ dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
+ const int cb_idx = iter;
 
- if (!ggml_metal_supports_op(dst)) {
- GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
- GGML_ASSERT(!"unsupported op");
- }
-
- #ifndef GGML_METAL_NDEBUG
- [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
- #endif
+ size_t offs_src0 = 0;
+ size_t offs_src1 = 0;
+ size_t offs_dst = 0;
 
- const int64_t ne00 = src0 ? src0->ne[0] : 0;
- const int64_t ne01 = src0 ? src0->ne[1] : 0;
- const int64_t ne02 = src0 ? src0->ne[2] : 0;
- const int64_t ne03 = src0 ? src0->ne[3] : 0;
-
- const uint64_t nb00 = src0 ? src0->nb[0] : 0;
- const uint64_t nb01 = src0 ? src0->nb[1] : 0;
- const uint64_t nb02 = src0 ? src0->nb[2] : 0;
- const uint64_t nb03 = src0 ? src0->nb[3] : 0;
-
- const int64_t ne10 = src1 ? src1->ne[0] : 0;
- const int64_t ne11 = src1 ? src1->ne[1] : 0;
- const int64_t ne12 = src1 ? src1->ne[2] : 0;
- const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
-
- const uint64_t nb10 = src1 ? src1->nb[0] : 0;
- const uint64_t nb11 = src1 ? src1->nb[1] : 0;
- const uint64_t nb12 = src1 ? src1->nb[2] : 0;
- const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
-
- const int64_t ne0 = dst ? dst->ne[0] : 0;
- const int64_t ne1 = dst ? dst->ne[1] : 0;
- const int64_t ne2 = dst ? dst->ne[2] : 0;
- const int64_t ne3 = dst ? dst->ne[3] : 0;
-
- const uint64_t nb0 = dst ? dst->nb[0] : 0;
- const uint64_t nb1 = dst ? dst->nb[1] : 0;
- const uint64_t nb2 = dst ? dst->nb[2] : 0;
- const uint64_t nb3 = dst ? dst->nb[3] : 0;
-
- const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
- const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
- const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
-
- id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
- id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
- id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
-
- //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
- //if (src0) {
- // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
1130
- // ggml_is_contiguous(src0), src0->name);
1131
- //}
1132
- //if (src1) {
1133
- // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
1134
- // ggml_is_contiguous(src1), src1->name);
1135
- //}
1136
- //if (dst) {
1137
- // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
1138
- // dst->name);
1139
- //}
1140
-
1141
- switch (dst->op) {
1142
- case GGML_OP_CONCAT:
1143
- {
1144
- const int64_t nb = ne00;
1145
-
1146
- [encoder setComputePipelineState:ctx->pipeline_concat];
1147
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1148
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1149
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1150
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1151
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1152
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1153
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1154
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1155
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1156
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1157
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1158
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1159
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1160
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1161
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1162
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1163
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1164
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1165
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1166
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1167
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1168
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1169
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1170
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1171
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1172
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1173
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1174
- [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
1175
-
1176
- const int nth = MIN(1024, ne0);
1177
-
1178
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1179
- } break;
1180
- case GGML_OP_ADD:
1181
- case GGML_OP_MUL:
1182
- case GGML_OP_DIV:
1183
- {
1184
- const size_t offs = 0;
1185
-
1186
- bool bcast_row = false;
1187
-
1188
- int64_t nb = ne00;
745
+ id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
746
+ id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
1189
747
 
1190
- id<MTLComputePipelineState> pipeline = nil;
748
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
749
+ const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
1191
750
 
1192
- if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
1193
- GGML_ASSERT(ggml_is_contiguous(src0));
751
+ for (int i = node_start; i < node_end; ++i) {
752
+ if (i == -1) {
753
+ [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
754
+ continue;
755
+ }
1194
756
 
1195
- // src1 is a row
1196
- GGML_ASSERT(ne11 == 1);
757
+ //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
758
+
759
+ struct ggml_tensor * src0 = gf->nodes[i]->src[0];
760
+ struct ggml_tensor * src1 = gf->nodes[i]->src[1];
761
+ struct ggml_tensor * dst = gf->nodes[i];
762
+
763
+ switch (dst->op) {
764
+ case GGML_OP_NONE:
765
+ case GGML_OP_RESHAPE:
766
+ case GGML_OP_VIEW:
767
+ case GGML_OP_TRANSPOSE:
768
+ case GGML_OP_PERMUTE:
769
+ {
770
+ // noop -> next node
771
+ } continue;
772
+ default:
773
+ {
774
+ } break;
775
+ }
1197
776
 
1198
- nb = ne00 / 4;
1199
- switch (dst->op) {
1200
- case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
1201
- case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
1202
- case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
1203
- default: GGML_ASSERT(false);
1204
- }
777
+ if (!ggml_metal_supports_op(ctx, dst)) {
778
+ GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
779
+ GGML_ASSERT(!"unsupported op");
780
+ }
1205
781
 
1206
- bcast_row = true;
1207
- } else {
1208
- switch (dst->op) {
1209
- case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
1210
- case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
1211
- case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
1212
- default: GGML_ASSERT(false);
1213
- }
1214
- }
782
+ #ifndef GGML_METAL_NDEBUG
783
+ [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
784
+ #endif
1215
785
 
1216
- [encoder setComputePipelineState:pipeline];
1217
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1218
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1219
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1220
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1221
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1222
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1223
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1224
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1225
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1226
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1227
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1228
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1229
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1230
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1231
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1232
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1233
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1234
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1235
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1236
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1237
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1238
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1239
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1240
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1241
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1242
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1243
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1244
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1245
- [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
1246
-
1247
- if (bcast_row) {
1248
- const int64_t n = ggml_nelements(dst)/4;
786
+ const int64_t ne00 = src0 ? src0->ne[0] : 0;
787
+ const int64_t ne01 = src0 ? src0->ne[1] : 0;
788
+ const int64_t ne02 = src0 ? src0->ne[2] : 0;
789
+ const int64_t ne03 = src0 ? src0->ne[3] : 0;
790
+
791
+ const uint64_t nb00 = src0 ? src0->nb[0] : 0;
792
+ const uint64_t nb01 = src0 ? src0->nb[1] : 0;
793
+ const uint64_t nb02 = src0 ? src0->nb[2] : 0;
794
+ const uint64_t nb03 = src0 ? src0->nb[3] : 0;
795
+
796
+ const int64_t ne10 = src1 ? src1->ne[0] : 0;
797
+ const int64_t ne11 = src1 ? src1->ne[1] : 0;
798
+ const int64_t ne12 = src1 ? src1->ne[2] : 0;
799
+ const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
800
+
801
+ const uint64_t nb10 = src1 ? src1->nb[0] : 0;
802
+ const uint64_t nb11 = src1 ? src1->nb[1] : 0;
803
+ const uint64_t nb12 = src1 ? src1->nb[2] : 0;
804
+ const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
805
+
806
+ const int64_t ne0 = dst ? dst->ne[0] : 0;
807
+ const int64_t ne1 = dst ? dst->ne[1] : 0;
808
+ const int64_t ne2 = dst ? dst->ne[2] : 0;
809
+ const int64_t ne3 = dst ? dst->ne[3] : 0;
810
+
811
+ const uint64_t nb0 = dst ? dst->nb[0] : 0;
812
+ const uint64_t nb1 = dst ? dst->nb[1] : 0;
813
+ const uint64_t nb2 = dst ? dst->nb[2] : 0;
814
+ const uint64_t nb3 = dst ? dst->nb[3] : 0;
815
+
816
+ const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
817
+ const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
818
+ const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
819
+
820
+ id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
821
+ id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
822
+ id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
823
+
824
+ //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
825
+ //if (src0) {
826
+ // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
827
+ // ggml_is_contiguous(src0), src0->name);
828
+ //}
829
+ //if (src1) {
830
+ // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
831
+ // ggml_is_contiguous(src1), src1->name);
832
+ //}
833
+ //if (dst) {
834
+ // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
835
+ // dst->name);
836
+ //}
837
+
838
+ switch (dst->op) {
839
+ case GGML_OP_CONCAT:
840
+ {
841
+ const int64_t nb = ne00;
842
+
843
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
844
+
845
+ [encoder setComputePipelineState:pipeline];
846
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
847
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
848
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
849
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
850
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
851
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
852
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
853
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
854
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
855
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
856
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
857
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
858
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
859
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
860
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
861
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
862
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
863
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
864
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
865
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
866
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
867
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
868
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
869
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
870
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
871
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
872
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
873
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
874
+
875
+ const int nth = MIN(1024, ne0);
876
+
877
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
878
+ } break;
879
+ case GGML_OP_ADD:
880
+ case GGML_OP_MUL:
881
+ case GGML_OP_DIV:
882
+ {
883
+ const size_t offs = 0;
884
+
885
+ bool bcast_row = false;
886
+
887
+ int64_t nb = ne00;
888
+
889
+ id<MTLComputePipelineState> pipeline = nil;
890
+
891
+ if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
892
+ GGML_ASSERT(ggml_is_contiguous(src0));
1249
893
 
1250
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1251
- } else {
1252
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
894
+ // src1 is a row
895
+ GGML_ASSERT(ne11 == 1);
1253
896
 
1254
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
897
+ nb = ne00 / 4;
898
+ switch (dst->op) {
899
+ case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
900
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
901
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
902
+ default: GGML_ASSERT(false);
1255
903
  }
1256
- } break;
1257
- case GGML_OP_ACC:
1258
- {
1259
- GGML_ASSERT(src0t == GGML_TYPE_F32);
1260
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1261
- GGML_ASSERT(dstt == GGML_TYPE_F32);
1262
904
 
1263
- GGML_ASSERT(ggml_is_contiguous(src0));
1264
- GGML_ASSERT(ggml_is_contiguous(src1));
1265
-
1266
- const size_t pnb1 = ((int32_t *) dst->op_params)[0];
1267
- const size_t pnb2 = ((int32_t *) dst->op_params)[1];
1268
- const size_t pnb3 = ((int32_t *) dst->op_params)[2];
1269
- const size_t offs = ((int32_t *) dst->op_params)[3];
1270
-
1271
- const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
1272
-
1273
- if (!inplace) {
1274
- // run a separate kernel to cpy src->dst
1275
- // not sure how to avoid this
1276
- // TODO: make a simpler cpy_bytes kernel
1277
-
1278
- const int nth = MIN((int) ctx->pipeline_cpy_f32_f32.maxTotalThreadsPerThreadgroup, ne00);
1279
-
1280
- [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
1281
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1282
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1283
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1284
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1285
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1286
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1287
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1288
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1289
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1290
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1291
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1292
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1293
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1294
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1295
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1296
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1297
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1298
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1299
-
1300
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
905
+ bcast_row = true;
906
+ } else {
907
+ switch (dst->op) {
908
+ case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
909
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
910
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
911
+ default: GGML_ASSERT(false);
1301
912
  }
913
+ }
1302
914
 
1303
- [encoder setComputePipelineState:ctx->pipeline_add];
1304
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1305
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1306
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1307
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1308
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1309
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1310
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1311
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1312
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
1313
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
1314
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
1315
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1316
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1317
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1318
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1319
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1320
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1321
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1322
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1323
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1324
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1325
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1326
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1327
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1328
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
1329
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
1330
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1331
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1332
-
1333
- const int nth = MIN((int) ctx->pipeline_add.maxTotalThreadsPerThreadgroup, ne00);
1334
-
1335
- [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1336
- } break;
1337
- case GGML_OP_SCALE:
1338
- {
1339
- GGML_ASSERT(ggml_is_contiguous(src0));
915
+ [encoder setComputePipelineState:pipeline];
916
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
917
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
918
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
919
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
920
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
921
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
922
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
923
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
924
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
925
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
926
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
927
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
928
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
929
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
930
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
931
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
932
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
933
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
934
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
935
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
936
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
937
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
938
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
939
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
940
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
941
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
942
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
943
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
944
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
945
+
946
+ if (bcast_row) {
947
+ const int64_t n = ggml_nelements(dst)/4;
1340
948
 
1341
- const float scale = *(const float *) dst->op_params;
949
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
950
+ } else {
951
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1342
952
 
1343
- int64_t n = ggml_nelements(dst);
953
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
954
+ }
955
+ } break;
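
The row-broadcast fast path above fires only when src1 collapses to a single contiguous row (ggml_nelements(src1) == ne10, with ne11 == 1 asserted) and both row widths are divisible by 4, so the *_ROW kernels can consume float4 lanes; nb is rescaled to ne00/4 to match. The predicate restated on its own (the helper name is hypothetical; the calls are real ggml API):

    #include "ggml.h"

    // true when the float4 row-broadcast ADD/MUL/DIV kernels are applicable
    static bool can_use_row_kernel(const struct ggml_tensor * src0,
                                   const struct ggml_tensor * src1) {
        return ggml_nelements(src1) == src1->ne[0] && // src1 is one row
               ggml_is_contiguous(src1)            &&
               src0->ne[0] % 4 == 0                &&
               src1->ne[0] % 4 == 0;
    }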
956
+ case GGML_OP_ACC:
957
+ {
958
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
959
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
960
+ GGML_ASSERT(dstt == GGML_TYPE_F32);
1344
961
 
1345
- if (n % 4 == 0) {
1346
- n /= 4;
1347
- [encoder setComputePipelineState:ctx->pipeline_scale_4];
1348
- } else {
1349
- [encoder setComputePipelineState:ctx->pipeline_scale];
1350
- }
962
+ GGML_ASSERT(ggml_is_contiguous(src0));
963
+ GGML_ASSERT(ggml_is_contiguous(src1));
1351
964
 
1352
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1353
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1354
- [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
965
+ const size_t pnb1 = ((int32_t *) dst->op_params)[0];
966
+ const size_t pnb2 = ((int32_t *) dst->op_params)[1];
967
+ const size_t pnb3 = ((int32_t *) dst->op_params)[2];
968
+ const size_t offs = ((int32_t *) dst->op_params)[3];
1355
969
 
1356
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1357
- } break;
1358
- case GGML_OP_UNARY:
1359
- switch (ggml_get_unary_op(gf->nodes[i])) {
1360
- case GGML_UNARY_OP_TANH:
1361
- {
1362
- [encoder setComputePipelineState:ctx->pipeline_tanh];
1363
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1364
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
970
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
1365
971
 
1366
- const int64_t n = ggml_nelements(dst);
972
+ if (!inplace) {
973
+ // run a separate kernel to cpy src->dst
974
+ // not sure how to avoid this
975
+ // TODO: make a simpler cpy_bytes kernel
1367
976
 
1368
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1369
- } break;
1370
- case GGML_UNARY_OP_RELU:
1371
- {
1372
- [encoder setComputePipelineState:ctx->pipeline_relu];
1373
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1374
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
977
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
1375
978
 
1376
- const int64_t n = ggml_nelements(dst);
979
+ [encoder setComputePipelineState:pipeline];
980
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
981
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
982
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
983
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
984
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
985
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
986
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
987
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
988
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
989
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
990
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
991
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
992
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
993
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
994
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
995
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
996
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
997
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1377
998
 
1378
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1379
- } break;
1380
- case GGML_UNARY_OP_GELU:
1381
- {
1382
- [encoder setComputePipelineState:ctx->pipeline_gelu];
1383
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1384
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
999
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
1385
1000
 
1386
- const int64_t n = ggml_nelements(dst);
1387
- GGML_ASSERT(n % 4 == 0);
1001
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1002
+ }
1388
1003
 
1389
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1390
- } break;
1391
- case GGML_UNARY_OP_GELU_QUICK:
1392
- {
1393
- [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
1394
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1395
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1004
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
1005
+
1006
+ [encoder setComputePipelineState:pipeline];
1007
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1008
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1009
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1010
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1011
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1012
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1013
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1014
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1015
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
1016
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
1017
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
1018
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1019
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1020
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1021
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1022
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1023
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1024
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1025
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1026
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1027
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1028
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1029
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1030
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1031
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
1032
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
1033
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1034
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1035
+
1036
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
1037
+
1038
+ [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1039
+ } break;
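
GGML_OP_ACC reads five int32 parameters out of dst->op_params: three destination byte strides, a byte offset, and an inplace flag; when inplace is false, a CPY_F32_F32 pass first clones src0 into dst before the strided ADD runs. The layout unpacked as a struct (struct and helper are hypothetical; the field order is taken from the code above):

    #include <stdbool.h>
    #include <stddef.h>
    #include <stdint.h>
    #include "ggml.h"

    struct acc_params {
        size_t pnb1, pnb2, pnb3; // dst strides, in bytes
        size_t offs;             // byte offset into dst
        bool   inplace;          // skip the initial copy pass when set
    };

    static struct acc_params acc_params_unpack(const struct ggml_tensor * dst) {
        const int32_t * p = (const int32_t *) dst->op_params;
        return (struct acc_params) {
            .pnb1 = (size_t) p[0], .pnb2 = (size_t) p[1], .pnb3 = (size_t) p[2],
            .offs = (size_t) p[3], .inplace = p[4] != 0,
        };
    }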
1040
+ case GGML_OP_SCALE:
1041
+ {
1042
+ GGML_ASSERT(ggml_is_contiguous(src0));
1043
+
1044
+ const float scale = *(const float *) dst->op_params;
1045
+
1046
+ int64_t n = ggml_nelements(dst);
1047
+
1048
+ id<MTLComputePipelineState> pipeline = nil;
1049
+
1050
+ if (n % 4 == 0) {
1051
+ n /= 4;
1052
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
1053
+ } else {
1054
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
1055
+ }
1396
1056
 
1397
- const int64_t n = ggml_nelements(dst);
1398
- GGML_ASSERT(n % 4 == 0);
1057
+ [encoder setComputePipelineState:pipeline];
1058
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1059
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1060
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
1399
1061
 
1400
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1401
- } break;
1402
- case GGML_UNARY_OP_SILU:
1403
- {
1404
- [encoder setComputePipelineState:ctx->pipeline_silu];
1405
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1406
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1062
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1063
+ } break;
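
SCALE chooses its kernel purely on divisibility: if ggml_nelements(dst) is a multiple of 4, n is divided by 4 and the SCALE_4 kernel works in float4 lanes; otherwise the scalar SCALE kernel runs one element per thread. The selection in isolation (hypothetical helper):

    #include <stdbool.h>
    #include <stdint.h>

    // returns the dispatch count and whether the float4 variant applies
    static int64_t scale_dispatch_count(int64_t n_elements, bool * use_float4) {
        *use_float4 = (n_elements % 4 == 0);
        return *use_float4 ? n_elements / 4 : n_elements;
    }
    // scale_dispatch_count(4096, &v) == 1024 with v == true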
1064
+ case GGML_OP_UNARY:
1065
+ switch (ggml_get_unary_op(gf->nodes[i])) {
1066
+ case GGML_UNARY_OP_TANH:
1067
+ {
1068
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
1407
1069
 
1408
- const int64_t n = ggml_nelements(dst);
1409
- GGML_ASSERT(n % 4 == 0);
1070
+ [encoder setComputePipelineState:pipeline];
1071
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1072
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1410
1073
 
1411
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1412
- } break;
1413
- default:
1414
- {
1415
- GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1416
- GGML_ASSERT(false);
1417
- }
1418
- } break;
1419
- case GGML_OP_SQR:
1420
- {
1421
- GGML_ASSERT(ggml_is_contiguous(src0));
1074
+ const int64_t n = ggml_nelements(dst);
1422
1075
 
1423
- [encoder setComputePipelineState:ctx->pipeline_sqr];
1424
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1425
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1076
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1077
+ } break;
1078
+ case GGML_UNARY_OP_RELU:
1079
+ {
1080
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
1426
1081
 
1427
- const int64_t n = ggml_nelements(dst);
1428
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1429
- } break;
1430
- case GGML_OP_SUM_ROWS:
1431
- {
1432
- GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
1082
+ [encoder setComputePipelineState:pipeline];
1083
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1084
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1433
1085
 
1434
- [encoder setComputePipelineState:ctx->pipeline_sum_rows];
1435
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1436
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1437
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1438
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1439
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1440
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1441
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1442
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1443
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1444
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1445
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1446
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1447
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1448
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1449
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1450
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1451
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1452
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1453
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1454
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
1455
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
1456
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
1457
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
1458
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
1459
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
1460
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
1461
-
1462
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1463
- } break;
1464
- case GGML_OP_SOFT_MAX:
1465
- {
1466
- int nth = 32; // SIMD width
1467
-
1468
- if (ne00%4 == 0) {
1469
- while (nth < ne00/4 && nth < 256) {
1470
- nth *= 2;
1471
- }
1472
- [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
1473
- } else {
1474
- while (nth < ne00 && nth < 1024) {
1475
- nth *= 2;
1476
- }
1477
- [encoder setComputePipelineState:ctx->pipeline_soft_max];
1478
- }
1086
+ const int64_t n = ggml_nelements(dst);
1479
1087
 
1480
- const float scale = ((float *) dst->op_params)[0];
1088
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1089
+ } break;
1090
+ case GGML_UNARY_OP_GELU:
1091
+ {
1092
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
1481
1093
 
1482
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1483
- if (id_src1) {
1484
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1485
- } else {
1486
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1487
- }
1488
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1489
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1490
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1491
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1492
- [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1493
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1494
-
1495
- [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1496
- } break;
1497
- case GGML_OP_DIAG_MASK_INF:
1498
- {
1499
- const int n_past = ((int32_t *)(dst->op_params))[0];
1500
-
1501
- if (ne00%8 == 0) {
1502
- [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
1503
- } else {
1504
- [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
1505
- }
1506
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1507
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1508
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1509
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1510
- [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
1094
+ [encoder setComputePipelineState:pipeline];
1095
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1096
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1511
1097
 
1512
- if (ne00%8 == 0) {
1513
- [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1514
- }
1515
- else {
1516
- [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1517
- }
1518
- } break;
1519
- case GGML_OP_MUL_MAT:
1520
- {
1521
- GGML_ASSERT(ne00 == ne10);
1098
+ const int64_t n = ggml_nelements(dst);
1099
+ GGML_ASSERT(n % 4 == 0);
1522
1100
 
1523
- // TODO: assert that dim2 and dim3 are contiguous
1524
- GGML_ASSERT(ne12 % ne02 == 0);
1525
- GGML_ASSERT(ne13 % ne03 == 0);
1101
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1102
+ } break;
1103
+ case GGML_UNARY_OP_GELU_QUICK:
1104
+ {
1105
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
1526
1106
 
1527
- const uint r2 = ne12/ne02;
1528
- const uint r3 = ne13/ne03;
1107
+ [encoder setComputePipelineState:pipeline];
1108
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1109
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1529
1110
 
1530
- // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1531
- // to the matrix-vector kernel
1532
- int ne11_mm_min = 1;
1111
+ const int64_t n = ggml_nelements(dst);
1112
+ GGML_ASSERT(n % 4 == 0);
1113
+
1114
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1115
+ } break;
1116
+ case GGML_UNARY_OP_SILU:
1117
+ {
1118
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
1119
+
1120
+ [encoder setComputePipelineState:pipeline];
1121
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1122
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1123
+
1124
+ const int64_t n = ggml_nelements(dst);
1125
+ GGML_ASSERT(n % 4 == 0);
1126
+
1127
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1128
+ } break;
1129
+ default:
1130
+ {
1131
+ GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1132
+ GGML_ASSERT(false);
1133
+ }
1134
+ } break;
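
Among the unary ops, GELU, GELU_QUICK and SILU assert n % 4 == 0 and dispatch n/4 threadgroups because those kernels load float4 with no scalar tail; TANH and RELU remain scalar and accept any n. A plain-C model of why the divisibility requirement exists (illustrative only, not the Metal source):

    #include <assert.h>
    #include <math.h>
    #include <stdint.h>

    // processes 4 elements per step, like the float4 Metal kernel - hence the
    // n % 4 == 0 assertion before dispatching n/4 groups
    static void silu_f32x4(const float * src, float * dst, int64_t n) {
        assert(n % 4 == 0);
        for (int64_t i = 0; i < n; i += 4) {
            for (int j = 0; j < 4; ++j) {
                const float x = src[i + j];
                dst[i + j] = x / (1.0f + expf(-x)); // SiLU(x) = x * sigmoid(x)
            }
        }
    }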
1135
+ case GGML_OP_SQR:
1136
+ {
1137
+ GGML_ASSERT(ggml_is_contiguous(src0));
1138
+
1139
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
1140
+
1141
+ [encoder setComputePipelineState:pipeline];
1142
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1143
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1144
+
1145
+ const int64_t n = ggml_nelements(dst);
1146
+
1147
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1148
+ } break;
1149
+ case GGML_OP_SUM_ROWS:
1150
+ {
1151
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
1152
+
1153
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
1154
+
1155
+ [encoder setComputePipelineState:pipeline];
1156
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1157
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1158
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1159
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1160
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1161
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1162
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1163
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1164
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1165
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1166
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1167
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1168
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1169
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1170
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1171
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1172
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1173
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1174
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1175
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
1176
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
1177
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
1178
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
1179
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
1180
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
1181
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
1182
+
1183
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1184
+ } break;
1185
+ case GGML_OP_SOFT_MAX:
1186
+ {
1187
+ int nth = 32; // SIMD width
1188
+
1189
+ id<MTLComputePipelineState> pipeline = nil;
1190
+
1191
+ if (ne00%4 == 0) {
1192
+ while (nth < ne00/4 && nth < 256) {
1193
+ nth *= 2;
1194
+ }
1195
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
1196
+ } else {
1197
+ while (nth < ne00 && nth < 1024) {
1198
+ nth *= 2;
1199
+ }
1200
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
1201
+ }
1202
+
1203
+ const float scale = ((float *) dst->op_params)[0];
1204
+
1205
+ [encoder setComputePipelineState:pipeline];
1206
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1207
+ if (id_src1) {
1208
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1209
+ } else {
1210
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1211
+ }
1212
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1213
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1214
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1215
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1216
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1217
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1218
+
1219
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1220
+ } break;
1221
+ case GGML_OP_DIAG_MASK_INF:
1222
+ {
1223
+ const int n_past = ((int32_t *)(dst->op_params))[0];
1224
+
1225
+ id<MTLComputePipelineState> pipeline = nil;
1226
+
1227
+ if (ne00%8 == 0) {
1228
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
1229
+ } else {
1230
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
1231
+ }
1232
+
1233
+ [encoder setComputePipelineState:pipeline];
1234
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1235
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1236
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1237
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1238
+ [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
1239
+
1240
+ if (ne00%8 == 0) {
1241
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1242
+ }
1243
+ else {
1244
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1245
+ }
1246
+ } break;
1247
+ case GGML_OP_MUL_MAT:
1248
+ {
1249
+ GGML_ASSERT(ne00 == ne10);
1250
+
1251
+ // TODO: assert that dim2 and dim3 are contiguous
1252
+ GGML_ASSERT(ne12 % ne02 == 0);
1253
+ GGML_ASSERT(ne13 % ne03 == 0);
1254
+
1255
+ const uint r2 = ne12/ne02;
1256
+ const uint r3 = ne13/ne03;
1257
+
1258
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1259
+ // to the matrix-vector kernel
1260
+ int ne11_mm_min = 1;
1533
1261
 
1534
1262
  #if 0
1535
- // the numbers below are measured on M2 Ultra for 7B and 13B models
1536
- // these numbers do not translate to other devices or model sizes
1537
- // TODO: need to find a better approach
1538
- if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
1539
- switch (src0t) {
1540
- case GGML_TYPE_F16: ne11_mm_min = 2; break;
1541
- case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1542
- case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1543
- case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1544
- case GGML_TYPE_Q4_0:
1545
- case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1546
- case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1547
- case GGML_TYPE_Q5_0: // not tested yet
1548
- case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1549
- case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1550
- case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1551
- default: ne11_mm_min = 1; break;
1552
- }
1263
+ // the numbers below are measured on M2 Ultra for 7B and 13B models
1264
+ // these numbers do not translate to other devices or model sizes
1265
+ // TODO: need to find a better approach
1266
+ if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
1267
+ switch (src0t) {
1268
+ case GGML_TYPE_F16: ne11_mm_min = 2; break;
1269
+ case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1270
+ case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1271
+ case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1272
+ case GGML_TYPE_Q4_0:
1273
+ case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1274
+ case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1275
+ case GGML_TYPE_Q5_0: // not tested yet
1276
+ case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1277
+ case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1278
+ case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1279
+ default: ne11_mm_min = 1; break;
1553
1280
  }
1281
+ }
1554
1282
  #endif
1555
1283
 
1556
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1557
- // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1558
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1559
- !ggml_is_transposed(src0) &&
1560
- !ggml_is_transposed(src1) &&
1561
- src1t == GGML_TYPE_F32 &&
1562
- ne00 % 32 == 0 && ne00 >= 64 &&
1563
- (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
1564
- //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1565
- switch (src0->type) {
1566
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
1567
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
1568
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
1569
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
1570
- case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
1571
- case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
1572
- case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
1573
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
1574
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
1575
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
1576
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
1577
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
1578
- case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break;
1579
- case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xs_f32]; break;
1580
- default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1581
- }
1582
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1583
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1584
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1585
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1586
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1587
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
1588
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
1589
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1590
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1591
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1592
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1593
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1594
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1595
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1596
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1597
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1598
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1599
- } else {
1600
- int nth0 = 32;
1601
- int nth1 = 1;
1602
- int nrows = 1;
1603
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1604
-
1605
- // use custom matrix x vector kernel
1606
- switch (src0t) {
1607
- case GGML_TYPE_F32:
1608
- {
1609
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1610
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
1611
- nrows = 4;
1612
- } break;
1613
- case GGML_TYPE_F16:
1614
- {
1615
- nth0 = 32;
1616
- nth1 = 1;
1617
- if (src1t == GGML_TYPE_F32) {
1618
- if (ne11 * ne12 < 4) {
1619
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
1620
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1621
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
1622
- nrows = ne11;
1623
- } else {
1624
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
1625
- nrows = 4;
1626
- }
1284
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1285
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1286
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1287
+ !ggml_is_transposed(src0) &&
1288
+ !ggml_is_transposed(src1) &&
1289
+ src1t == GGML_TYPE_F32 &&
1290
+ ne00 % 32 == 0 && ne00 >= 64 &&
1291
+ (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
1292
+ //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1293
+
1294
+ id<MTLComputePipelineState> pipeline = nil;
1295
+
1296
+ switch (src0->type) {
1297
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
1298
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
1299
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
1300
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
1301
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
1302
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
1303
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
1304
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
1305
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
1306
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
1307
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
1308
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
1309
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
1310
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
1311
+ default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1312
+ }
1313
+
1314
+ [encoder setComputePipelineState:pipeline];
1315
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1316
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1317
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1318
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1319
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1320
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
1321
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
1322
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1323
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1324
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1325
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1326
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1327
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1328
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1329
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1330
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1331
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1332
+ } else {
1333
+ int nth0 = 32;
1334
+ int nth1 = 1;
1335
+ int nrows = 1;
1336
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1337
+
1338
+ id<MTLComputePipelineState> pipeline = nil;
1339
+
1340
+ // use custom matrix x vector kernel
1341
+ switch (src0t) {
1342
+ case GGML_TYPE_F32:
1343
+ {
1344
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1345
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
1346
+ nrows = 4;
1347
+ } break;
1348
+ case GGML_TYPE_F16:
1349
+ {
1350
+ nth0 = 32;
1351
+ nth1 = 1;
1352
+ if (src1t == GGML_TYPE_F32) {
1353
+ if (ne11 * ne12 < 4) {
1354
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
1355
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1356
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
1357
+ nrows = ne11;
1627
1358
  } else {
1628
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
1359
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
1629
1360
  nrows = 4;
1630
1361
  }
1631
- } break;
1632
- case GGML_TYPE_Q4_0:
1633
- {
1634
- nth0 = 8;
1635
- nth1 = 8;
1636
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
1637
- } break;
1638
- case GGML_TYPE_Q4_1:
1639
- {
1640
- nth0 = 8;
1641
- nth1 = 8;
1642
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
1643
- } break;
1644
- case GGML_TYPE_Q5_0:
1645
- {
1646
- nth0 = 8;
1647
- nth1 = 8;
1648
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
1649
- } break;
1650
- case GGML_TYPE_Q5_1:
1651
- {
1652
- nth0 = 8;
1653
- nth1 = 8;
1654
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
1655
- } break;
1656
- case GGML_TYPE_Q8_0:
1657
- {
1658
- nth0 = 8;
1659
- nth1 = 8;
1660
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
1661
- } break;
1662
- case GGML_TYPE_Q2_K:
1663
- {
1664
- nth0 = 2;
1665
- nth1 = 32;
1666
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
1667
- } break;
1668
- case GGML_TYPE_Q3_K:
1669
- {
1670
- nth0 = 2;
1671
- nth1 = 32;
1672
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
1673
- } break;
1674
- case GGML_TYPE_Q4_K:
1675
- {
1676
- nth0 = 4; //1;
1677
- nth1 = 8; //32;
1678
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
1679
- } break;
1680
- case GGML_TYPE_Q5_K:
1681
- {
1682
- nth0 = 2;
1683
- nth1 = 32;
1684
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
1685
- } break;
1686
- case GGML_TYPE_Q6_K:
1687
- {
1688
- nth0 = 2;
1689
- nth1 = 32;
1690
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
1691
- } break;
1692
- case GGML_TYPE_IQ2_XXS:
1693
- {
1694
- nth0 = 4;
1695
- nth1 = 16;
1696
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
1697
- } break;
1698
- case GGML_TYPE_IQ2_XS:
1699
- {
1700
- nth0 = 4;
1701
- nth1 = 16;
1702
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xs_f32];
1703
- } break;
1704
- default:
1705
- {
1706
- GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1707
- GGML_ASSERT(false && "not implemented");
1362
+ } else {
1363
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
1364
+ nrows = 4;
1708
1365
  }
1709
- };
1710
-
1711
- if (ggml_is_quantized(src0t)) {
1712
- GGML_ASSERT(ne00 >= nth0*nth1);
1713
- }
1366
+ } break;
1367
+ case GGML_TYPE_Q4_0:
1368
+ {
1369
+ nth0 = 8;
1370
+ nth1 = 8;
1371
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
1372
+ } break;
1373
+ case GGML_TYPE_Q4_1:
1374
+ {
1375
+ nth0 = 8;
1376
+ nth1 = 8;
1377
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
1378
+ } break;
1379
+ case GGML_TYPE_Q5_0:
1380
+ {
1381
+ nth0 = 8;
1382
+ nth1 = 8;
1383
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
1384
+ } break;
1385
+ case GGML_TYPE_Q5_1:
1386
+ {
1387
+ nth0 = 8;
1388
+ nth1 = 8;
1389
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
1390
+ } break;
1391
+ case GGML_TYPE_Q8_0:
1392
+ {
1393
+ nth0 = 8;
1394
+ nth1 = 8;
1395
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
1396
+ } break;
1397
+ case GGML_TYPE_Q2_K:
1398
+ {
1399
+ nth0 = 2;
1400
+ nth1 = 32;
1401
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
1402
+ } break;
1403
+ case GGML_TYPE_Q3_K:
1404
+ {
1405
+ nth0 = 2;
1406
+ nth1 = 32;
1407
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
1408
+ } break;
1409
+ case GGML_TYPE_Q4_K:
1410
+ {
1411
+ nth0 = 4; //1;
1412
+ nth1 = 8; //32;
1413
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
1414
+ } break;
1415
+ case GGML_TYPE_Q5_K:
1416
+ {
1417
+ nth0 = 2;
1418
+ nth1 = 32;
1419
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
1420
+ } break;
1421
+ case GGML_TYPE_Q6_K:
1422
+ {
1423
+ nth0 = 2;
1424
+ nth1 = 32;
1425
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
1426
+ } break;
1427
+ case GGML_TYPE_IQ2_XXS:
1428
+ {
1429
+ nth0 = 4;
1430
+ nth1 = 16;
1431
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
1432
+ } break;
1433
+ case GGML_TYPE_IQ2_XS:
1434
+ {
1435
+ nth0 = 4;
1436
+ nth1 = 16;
1437
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
1438
+ } break;
1439
+ default:
1440
+ {
1441
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1442
+ GGML_ASSERT(false && "not implemented");
1443
+ }
1444
+ };
1714
1445
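
The hunks above and below replace the old per-kernel context fields (ctx->pipeline_mul_mv_q4_0_f32 and friends) with a single table indexed by the GGML_METAL_KERNEL_TYPE_* enum, so every case reduces to one lookup: ctx->kernels[...].pipeline. A minimal C sketch of the same pattern, with hypothetical names (my_kernel, my_context) standing in for the actual Objective-C types:

    /* Hypothetical sketch of the enum-indexed kernel table; the real
       entries hold Metal function/pipeline objects. */
    typedef struct { void * function; void * pipeline; } my_kernel;

    enum my_kernel_type {
        MY_KERNEL_TYPE_MUL_MV_F32_F32,
        MY_KERNEL_TYPE_MUL_MV_F16_F32,
        MY_KERNEL_TYPE_COUNT,
    };

    struct my_context {
        /* before: one named field per kernel; after: one array, so kernel
           loading and selection become data-driven */
        my_kernel kernels[MY_KERNEL_TYPE_COUNT];
    };

    static void * my_get_pipeline(struct my_context * ctx, enum my_kernel_type t) {
        return ctx->kernels[t].pipeline; /* mirrors ctx->kernels[...].pipeline */
    }
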
 
1715
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1716
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1717
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1718
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1719
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1720
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1721
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1722
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1723
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1724
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
1725
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
1726
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
1727
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
1728
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
1729
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1730
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1731
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1732
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1733
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1734
-
1735
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1736
- src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1737
- src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1738
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1739
- }
1740
- else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
1741
- const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1742
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1743
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1744
- }
1745
- else if (src0t == GGML_TYPE_Q4_K) {
1746
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1747
- }
1748
- else if (src0t == GGML_TYPE_Q3_K) {
1749
- #ifdef GGML_QKK_64
1750
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1751
- #else
1752
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1753
- #endif
1754
- }
1755
- else if (src0t == GGML_TYPE_Q5_K) {
1756
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1757
- }
1758
- else if (src0t == GGML_TYPE_Q6_K) {
1759
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1760
- } else {
1761
- const int64_t ny = (ne11 + nrows - 1)/nrows;
1762
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1763
- }
1446
+ if (ggml_is_quantized(src0t)) {
1447
+ GGML_ASSERT(ne00 >= nth0*nth1);
1764
1448
  }
1765
- } break;
1766
- case GGML_OP_MUL_MAT_ID:
1767
- {
1768
- //GGML_ASSERT(ne00 == ne10);
1769
- //GGML_ASSERT(ne03 == ne13);
1770
-
1771
- GGML_ASSERT(src0t == GGML_TYPE_I32);
1772
-
1773
- const int n_as = ((int32_t *) dst->op_params)[1];
1774
-
1775
- // TODO: make this more general
1776
- GGML_ASSERT(n_as <= 8);
1777
-
1778
- // max size of the src1ids array in the kernel stack
1779
- GGML_ASSERT(ne11 <= 512);
1780
-
1781
- struct ggml_tensor * src2 = gf->nodes[i]->src[2];
1782
-
1783
- const int64_t ne20 = src2 ? src2->ne[0] : 0;
1784
- const int64_t ne21 = src2 ? src2->ne[1] : 0;
1785
- const int64_t ne22 = src2 ? src2->ne[2] : 0;
1786
- const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
1787
-
1788
- const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
1789
- const uint64_t nb21 = src2 ? src2->nb[1] : 0;
1790
- const uint64_t nb22 = src2 ? src2->nb[2] : 0;
1791
- const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
1792
-
1793
- const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
1794
-
1795
- GGML_ASSERT(!ggml_is_transposed(src2));
1796
- GGML_ASSERT(!ggml_is_transposed(src1));
1797
-
1798
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1799
-
1800
- const uint r2 = ne12/ne22;
1801
- const uint r3 = ne13/ne23;
1802
-
1803
- // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1804
- // to the matrix-vector kernel
1805
- int ne11_mm_min = n_as;
1806
-
1807
- const int idx = ((int32_t *) dst->op_params)[0];
1808
-
1809
- // batch size
1810
- GGML_ASSERT(ne01 == ne11);
1811
-
1812
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1813
- // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1814
- // !!!
1815
- // TODO: for now, always use mat-vec kernels until we figure out how to improve the
1816
- // indirect matrix multiplication
1817
- // !!!
1818
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1819
- ne20 % 32 == 0 && ne20 >= 64 &&
1820
- ne11 > ne11_mm_min) {
1821
- switch (src2->type) {
1822
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
1823
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
1824
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
1825
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
1826
- case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
1827
- case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
1828
- case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
1829
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
1830
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
1831
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
1832
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
1833
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
1834
- case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
1835
- case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xs_f32]; break;
1836
- default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1837
- }
1838
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1839
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1840
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1841
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1842
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1843
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
1844
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1845
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
1846
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1847
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
1848
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
1849
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
1850
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
1851
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1852
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
1853
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1854
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
1855
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
1856
- [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
1857
- // TODO: how to make this an array? read Metal docs
1858
- for (int j = 0; j < 8; ++j) {
1859
- // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1860
- struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1861
-
1862
- size_t offs_src_cur = 0;
1863
- id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1864
-
1865
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
1866
- }
1867
-
1868
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1869
-
1870
- [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1871
- } else {
1872
- int nth0 = 32;
1873
- int nth1 = 1;
1874
- int nrows = 1;
1875
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1876
-
1877
- // use custom matrix x vector kernel
1878
- switch (src2t) {
1879
- case GGML_TYPE_F32:
1880
- {
1881
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1882
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
1883
- } break;
1884
- case GGML_TYPE_F16:
1885
- {
1886
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1887
- nth0 = 32;
1888
- nth1 = 1;
1889
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
1890
- } break;
1891
- case GGML_TYPE_Q4_0:
1892
- {
1893
- nth0 = 8;
1894
- nth1 = 8;
1895
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
1896
- } break;
1897
- case GGML_TYPE_Q4_1:
1898
- {
1899
- nth0 = 8;
1900
- nth1 = 8;
1901
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
1902
- } break;
1903
- case GGML_TYPE_Q5_0:
1904
- {
1905
- nth0 = 8;
1906
- nth1 = 8;
1907
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
1908
- } break;
1909
- case GGML_TYPE_Q5_1:
1910
- {
1911
- nth0 = 8;
1912
- nth1 = 8;
1913
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
1914
- } break;
1915
- case GGML_TYPE_Q8_0:
1916
- {
1917
- nth0 = 8;
1918
- nth1 = 8;
1919
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
1920
- } break;
1921
- case GGML_TYPE_Q2_K:
1922
- {
1923
- nth0 = 2;
1924
- nth1 = 32;
1925
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
1926
- } break;
1927
- case GGML_TYPE_Q3_K:
1928
- {
1929
- nth0 = 2;
1930
- nth1 = 32;
1931
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
1932
- } break;
1933
- case GGML_TYPE_Q4_K:
1934
- {
1935
- nth0 = 4; //1;
1936
- nth1 = 8; //32;
1937
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
1938
- } break;
1939
- case GGML_TYPE_Q5_K:
1940
- {
1941
- nth0 = 2;
1942
- nth1 = 32;
1943
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
1944
- } break;
1945
- case GGML_TYPE_Q6_K:
1946
- {
1947
- nth0 = 2;
1948
- nth1 = 32;
1949
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
1950
- } break;
1951
- case GGML_TYPE_IQ2_XXS:
1952
- {
1953
- nth0 = 4;
1954
- nth1 = 16;
1955
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
1956
- } break;
1957
- case GGML_TYPE_IQ2_XS:
1958
- {
1959
- nth0 = 4;
1960
- nth1 = 16;
1961
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xs_f32];
1962
- } break;
1963
- default:
1964
- {
1965
- GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
1966
- GGML_ASSERT(false && "not implemented");
1967
- }
1968
- };
1969
-
1970
- if (ggml_is_quantized(src2t)) {
1971
- GGML_ASSERT(ne20 >= nth0*nth1);
1972
- }
1973
1449
 
1974
- const int64_t _ne1 = 1; // the kernel needs a reference in constant memory
1975
-
1976
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1977
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1978
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1979
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1980
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1981
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1982
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
1983
- [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
1984
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
1985
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
1986
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1987
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
1988
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1989
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1990
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1991
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1992
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1993
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1994
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
1995
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1996
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
1997
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1998
- [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1999
- // TODO: how to make this an array? read Metal docs
2000
- for (int j = 0; j < 8; ++j) {
2001
- // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
2002
- struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
2003
-
2004
- size_t offs_src_cur = 0;
2005
- id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
2006
-
2007
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
2008
- }
2009
-
2010
- if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
2011
- src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
2012
- src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
2013
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2014
- }
2015
- else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
2016
- const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
2017
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2018
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2019
- }
2020
- else if (src2t == GGML_TYPE_Q4_K) {
2021
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2022
- }
2023
- else if (src2t == GGML_TYPE_Q3_K) {
1450
+ [encoder setComputePipelineState:pipeline];
1451
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1452
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1453
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1454
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1455
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1456
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1457
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1458
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1459
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1460
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
1461
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
1462
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
1463
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
1464
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
1465
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1466
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1467
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1468
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1469
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1470
+
1471
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1472
+ src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1473
+ src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1474
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1475
+ }
1476
+ else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
1477
+ const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1478
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1479
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1480
+ }
1481
+ else if (src0t == GGML_TYPE_Q4_K) {
1482
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1483
+ }
1484
+ else if (src0t == GGML_TYPE_Q3_K) {
2024
1485
  #ifdef GGML_QKK_64
2025
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1486
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2026
1487
  #else
2027
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1488
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2028
1489
  #endif
2029
- }
2030
- else if (src2t == GGML_TYPE_Q5_K) {
2031
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2032
- }
2033
- else if (src2t == GGML_TYPE_Q6_K) {
2034
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2035
- } else {
2036
- const int64_t ny = (_ne1 + nrows - 1)/nrows;
2037
- [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2038
- }
2039
1490
  }
2040
- } break;
2041
- case GGML_OP_GET_ROWS:
2042
- {
2043
- switch (src0->type) {
2044
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
2045
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
2046
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
2047
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
2048
- case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
2049
- case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
2050
- case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
2051
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
2052
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
2053
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
2054
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
2055
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
2056
- case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
2057
- case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
2058
- case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xs]; break;
2059
- default: GGML_ASSERT(false && "not implemented");
1491
+ else if (src0t == GGML_TYPE_Q5_K) {
1492
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2060
1493
  }
2061
-
2062
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2063
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2064
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2065
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
2066
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
2067
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
2068
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
2069
- [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
2070
- [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
2071
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
2072
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
2073
-
2074
- [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
2075
- } break;
2076
- case GGML_OP_RMS_NORM:
2077
- {
2078
- GGML_ASSERT(ne00 % 4 == 0);
2079
-
2080
- float eps;
2081
- memcpy(&eps, dst->op_params, sizeof(float));
2082
-
2083
- int nth = 32; // SIMD width
2084
-
2085
- while (nth < ne00/4 && nth < 1024) {
2086
- nth *= 2;
1494
+ else if (src0t == GGML_TYPE_Q6_K) {
1495
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1496
+ } else {
1497
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
1498
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2087
1499
  }
1500
+ }
1501
+ } break;
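
Every dispatchThreadgroups call in this case rounds a dimension up to a whole number of threadgroups with the (n + k - 1)/k idiom: the matrix-matrix kernel covers a 32x64 output tile per 128-thread group, while the matrix-vector kernels cover 8, 4, or 2 rows per group depending on the quantization type. A small worked sketch in plain C (example sizes only):

    #include <stdio.h>

    /* integer ceiling division, the (n + k - 1)/k idiom used above */
    static int ceil_div(int n, int k) { return (n + k - 1)/k; }

    int main(void) {
        const int ne01 = 4096; /* rows of src0 (example)    */
        const int ne11 = 100;  /* columns of src1 (example) */

        /* mat-mat kernel: one threadgroup per 32x64 output tile */
        printf("mm: %d x %d groups\n", ceil_div(ne11, 32), ceil_div(ne01, 64));

        /* mat-vec kernels, e.g. 8 rows per group: (ne01 + 7)/8 */
        printf("mv: %d groups\n", ceil_div(ne01, 8));
        return 0;
    }
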
1502
+ case GGML_OP_MUL_MAT_ID:
1503
+ {
1504
+ //GGML_ASSERT(ne00 == ne10);
1505
+ //GGML_ASSERT(ne03 == ne13);
2088
1506
 
2089
- [encoder setComputePipelineState:ctx->pipeline_rms_norm];
2090
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2091
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2092
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2093
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
2094
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
2095
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2096
-
2097
- const int64_t nrows = ggml_nrows(src0);
2098
-
2099
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2100
- } break;
2101
- case GGML_OP_GROUP_NORM:
2102
- {
2103
- GGML_ASSERT(ne00 % 4 == 0);
2104
-
2105
- //float eps;
2106
- //memcpy(&eps, dst->op_params, sizeof(float));
2107
-
2108
- const float eps = 1e-6f; // TODO: temporarily hardcoded
2109
-
2110
- const int32_t n_groups = ((int32_t *) dst->op_params)[0];
2111
-
2112
- int nth = 32; // SIMD width
2113
-
2114
- //while (nth < ne00/4 && nth < 1024) {
2115
- // nth *= 2;
2116
- //}
2117
-
2118
- [encoder setComputePipelineState:ctx->pipeline_group_norm];
2119
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2120
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2121
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2122
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2123
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2124
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
2125
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
2126
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
2127
- [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
2128
- [encoder setBytes:&eps length:sizeof( float) atIndex:9];
2129
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2130
-
2131
- [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2132
- } break;
2133
- case GGML_OP_NORM:
2134
- {
2135
- float eps;
2136
- memcpy(&eps, dst->op_params, sizeof(float));
2137
-
2138
- const int nth = MIN(256, ne00);
2139
-
2140
- [encoder setComputePipelineState:ctx->pipeline_norm];
2141
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2142
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2143
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2144
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
2145
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
2146
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
2147
-
2148
- const int64_t nrows = ggml_nrows(src0);
2149
-
2150
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2151
- } break;
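
The NORM path above pads its threadgroup memory request with GGML_PAD(nth*sizeof(float), 16), which rounds a byte count up to the next multiple of a power-of-two alignment. The macro body below mirrors the definition in ggml.h, reproduced here under a hypothetical name:

    /* round x up to the next multiple of n (n must be a power of two) */
    #define MY_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

    /* e.g. nth = 33 partial sums: 33*4 = 132 bytes -> MY_PAD(132, 16) = 144;
       nth = 200: 800 bytes is already 16-aligned, so it stays 800 */
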
2152
- case GGML_OP_ALIBI:
2153
- {
2154
- GGML_ASSERT((src0t == GGML_TYPE_F32));
2155
-
2156
- const int nth = MIN(1024, ne00);
2157
-
2158
- //const int n_past = ((int32_t *) dst->op_params)[0];
2159
- const int n_head = ((int32_t *) dst->op_params)[1];
2160
- float max_bias;
2161
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
1507
+ GGML_ASSERT(src0t == GGML_TYPE_I32);
2162
1508
 
2163
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
2164
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
2165
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
1509
+ const int n_as = ((int32_t *) dst->op_params)[1];
2166
1510
 
2167
- [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
2168
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2169
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2170
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2171
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2172
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2173
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
2174
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
2175
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
2176
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
2177
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
2178
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
2179
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
2180
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
2181
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
2182
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
2183
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
2184
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
2185
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
2186
- [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
2187
- [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
2188
- [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
1511
+ // TODO: make this more general
1512
+ GGML_ASSERT(n_as <= 8);
2189
1513
 
2190
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2191
- } break;
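
The ALIBI case derives two slope bases from max_bias: the first n = 2^floor(log2(n_head)) heads use powers of m0 = 2^(-max_bias/n), and any remaining heads use powers of m1 = 2^(-(max_bias/2)/n), which is the standard ALiBi slope schedule. A worked example of the two lines above (example parameter values):

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const int   n_head   = 12;   /* example head count      */
        const float max_bias = 8.0f; /* example op_params value */

        /* same computation as the removed lines above */
        const int   n  = 1 << (int) floor(log2(n_head));    /* 8       */
        const float m0 = powf(2.0f, -max_bias / n);         /* 0.5     */
        const float m1 = powf(2.0f, -(max_bias/2.0f) / n);  /* ~0.7071 */

        printf("n_heads_log2_floor = %d, m0 = %f, m1 = %f\n", n, m0, m1);
        return 0;
    }
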
2192
- case GGML_OP_ROPE:
2193
- {
2194
- GGML_ASSERT(ne10 == ne02);
2195
-
2196
- const int nth = MIN(1024, ne00);
2197
-
2198
- const int n_past = ((int32_t *) dst->op_params)[0];
2199
- const int n_dims = ((int32_t *) dst->op_params)[1];
2200
- const int mode = ((int32_t *) dst->op_params)[2];
2201
- // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
2202
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
2203
-
2204
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
2205
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
2206
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
2207
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
2208
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
2209
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
2210
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
1514
+ // max size of the src1ids array in the kernel stack
1515
+ GGML_ASSERT(ne11 <= 512);
2211
1516
 
2212
- switch (src0->type) {
2213
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
2214
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
2215
- default: GGML_ASSERT(false);
2216
- };
1517
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2];
2217
1518
 
2218
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2219
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2220
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2221
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
2222
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
2223
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
2224
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
2225
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
2226
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2227
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2228
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2229
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
2230
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
2231
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
2232
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
2233
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
2234
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
2235
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
2236
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
2237
- [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
2238
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
2239
- [encoder setBytes:&mode length:sizeof( int) atIndex:21];
2240
- [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
2241
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
2242
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
2243
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
2244
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
2245
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
2246
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
1519
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
1520
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
1521
+ const int64_t ne22 = src2 ? src2->ne[2] : 0;
1522
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
2247
1523
 
2248
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2249
- } break;
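
Note how the ROPE case reads its float parameters (freq_base, freq_scale, ext_factor, ...) out of the int32_t op_params array with memcpy rather than a pointer cast: each float was stored bit-for-bit in an int32_t slot, and memcpy is the strict-aliasing-safe way to reinterpret those bits. A minimal helper showing the same idiom (hypothetical name):

    #include <stdint.h>
    #include <string.h>

    /* A cast like *(const float *) &params[5] would violate strict
       aliasing; memcpy of the 4 bytes is well-defined and compiles to
       a plain load on common compilers. */
    static float read_f32_param(const int32_t * params, int idx) {
        float v;
        memcpy(&v, params + idx, sizeof(float));
        return v;
    }
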
2250
- case GGML_OP_IM2COL:
2251
- {
2252
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
2253
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
2254
- GGML_ASSERT( dst->type == GGML_TYPE_F16);
1524
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
1525
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
1526
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
1527
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
2255
1528
 
2256
- const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
2257
- const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
2258
- const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
2259
- const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
2260
- const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
2261
- const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
2262
- const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
1529
+ const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
2263
1530
 
2264
- const int32_t N = src1->ne[is_2D ? 3 : 2];
2265
- const int32_t IC = src1->ne[is_2D ? 2 : 1];
2266
- const int32_t IH = is_2D ? src1->ne[1] : 1;
2267
- const int32_t IW = src1->ne[0];
1531
+ GGML_ASSERT(!ggml_is_transposed(src2));
1532
+ GGML_ASSERT(!ggml_is_transposed(src1));
2268
1533
 
2269
- const int32_t KH = is_2D ? src0->ne[1] : 1;
2270
- const int32_t KW = src0->ne[0];
1534
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
2271
1535
 
2272
- const int32_t OH = is_2D ? dst->ne[2] : 1;
2273
- const int32_t OW = dst->ne[1];
1536
+ const uint r2 = ne12/ne22;
1537
+ const uint r3 = ne13/ne23;
2274
1538
 
2275
- const int32_t CHW = IC * KH * KW;
1539
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1540
+ // to the matrix-vector kernel
1541
+ int ne11_mm_min = n_as;
2276
1542
 
2277
- const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
2278
- const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
1543
+ const int idx = ((int32_t *) dst->op_params)[0];
2279
1544
 
2280
- switch (src0->type) {
2281
- case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
2282
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
2283
- default: GGML_ASSERT(false);
2284
- };
1545
+ // batch size
1546
+ GGML_ASSERT(ne01 == ne11);
2285
1547
 
2286
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
2287
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2288
- [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
2289
- [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
2290
- [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
2291
- [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
2292
- [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
2293
- [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
2294
- [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
2295
- [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
2296
- [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
2297
- [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
2298
- [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
2299
-
2300
- [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
2301
- } break;
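
The IM2COL case reads the convolution geometry straight off the tensor shapes: N/IC/IH/IW from src1, KH/KW from src0, OH/OW from dst, and CHW = IC*KH*KW columns per output pixel. The output sizes themselves follow the usual convolution formula, which is not in the hunk above and is stated here as an assumption:

    #include <stdio.h>

    /* standard conv output size: (in + 2*pad - dilation*(k-1) - 1)/stride + 1 */
    static int conv_out(int in, int k, int s, int p, int d) {
        return (in + 2*p - d*(k - 1) - 1)/s + 1;
    }

    int main(void) {
        const int IW = 8, IH = 8, IC = 3; /* example 2D input */
        const int KW = 3, KH = 3;         /* example kernel   */
        const int s0 = 1, s1 = 1, p0 = 1, p1 = 1, d0 = 1, d1 = 1;

        const int OW  = conv_out(IW, KW, s0, p0, d0); /* 8  */
        const int OH  = conv_out(IH, KH, s1, p1, d1); /* 8  */
        const int CHW = IC*KH*KW;                     /* 27 */

        printf("OW = %d, OH = %d, CHW = %d\n", OW, OH, CHW);
        return 0;
    }
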
2302
- case GGML_OP_UPSCALE:
2303
- {
2304
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2305
-
2306
- const int sf = dst->op_params[0];
2307
-
2308
- [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
2309
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2310
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2311
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2312
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2313
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2314
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2315
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2316
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2317
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2318
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2319
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2320
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2321
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2322
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2323
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2324
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2325
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2326
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2327
- [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
2328
-
2329
- const int nth = MIN((int) ctx->pipeline_upscale_f32.maxTotalThreadsPerThreadgroup, ne0);
2330
-
2331
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2332
- } break;
2333
- case GGML_OP_PAD:
2334
- {
2335
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2336
-
2337
- [encoder setComputePipelineState:ctx->pipeline_pad_f32];
2338
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2339
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2340
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2341
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2342
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2343
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2344
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2345
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2346
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2347
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2348
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2349
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2350
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2351
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2352
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2353
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2354
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2355
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2356
-
2357
- const int nth = MIN(1024, ne0);
2358
-
2359
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2360
- } break;
2361
- case GGML_OP_ARGSORT:
2362
- {
2363
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2364
- GGML_ASSERT( dst->type == GGML_TYPE_I32);
2365
-
2366
- const int nrows = ggml_nrows(src0);
2367
-
2368
- enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
2369
-
2370
- switch (order) {
2371
- case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
2372
- case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
2373
- default: GGML_ASSERT(false);
2374
- };
1548
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1549
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1550
+ // !!!
1551
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
1552
+ // indirect matrix multiplication
1553
+ // !!!
1554
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1555
+ ne20 % 32 == 0 && ne20 >= 64 &&
1556
+ ne11 > ne11_mm_min) {
2375
1557
 
2376
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2377
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2378
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2379
-
2380
- [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
2381
- } break;
2382
- case GGML_OP_LEAKY_RELU:
2383
- {
2384
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
1558
+ id<MTLComputePipelineState> pipeline = nil;
2385
1559
 
2386
- float slope;
2387
- memcpy(&slope, dst->op_params, sizeof(float));
1560
+ switch (src2->type) {
1561
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
1562
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
1563
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
1564
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
1565
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
1566
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
1567
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
1568
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
1569
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
1570
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
1571
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
1572
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
1573
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
1574
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
1575
+ default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1576
+ }
2388
1577
 
2389
- [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
2390
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2391
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2392
- [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
1578
+ [encoder setComputePipelineState:pipeline];
1579
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1580
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1581
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1582
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1583
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1584
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
1585
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1586
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
1587
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1588
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
1589
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
1590
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
1591
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
1592
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1593
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
1594
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1595
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
1596
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
1597
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
1598
+ // TODO: how to make this an array? read Metal docs
1599
+ for (int j = 0; j < 8; ++j) {
1600
+ // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1601
+ struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1602
+
1603
+ size_t offs_src_cur = 0;
1604
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1605
+
1606
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
1607
+ }
2393
1608
 
2394
- const int64_t n = ggml_nelements(dst);
1609
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
2395
1610
 
2396
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2397
- } break;
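
LEAKY_RELU passes a single float slope read from op_params and dispatches one thread per element. The op itself is the usual piecewise map; for reference:

    /* reference semantics of LEAKY_RELU with the slope parameter above:
       positive inputs pass through, negative inputs are scaled by slope */
    static float leaky_relu(float x, float slope) {
        return x > 0.0f ? x : slope*x;
    }
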
2398
- case GGML_OP_DUP:
2399
- case GGML_OP_CPY:
2400
- case GGML_OP_CONT:
2401
- {
2402
- GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
1611
+ [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1612
+ } else {
1613
+ int nth0 = 32;
1614
+ int nth1 = 1;
1615
+ int nrows = 1;
1616
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
2403
1617
 
2404
- int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
1618
+ id<MTLComputePipelineState> pipeline = nil;
2405
1619
 
2406
- switch (src0t) {
1620
+ // use custom matrix x vector kernel
1621
+ switch (src2t) {
2407
1622
  case GGML_TYPE_F32:
2408
1623
  {
2409
- GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
2410
-
2411
- switch (dstt) {
2412
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
2413
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
2414
- case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
2415
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
2416
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
2417
- //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
2418
- //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
2419
- default: GGML_ASSERT(false && "not implemented");
2420
- };
1624
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1625
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
2421
1626
  } break;
2422
1627
  case GGML_TYPE_F16:
2423
1628
  {
2424
- switch (dstt) {
2425
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
2426
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
2427
- default: GGML_ASSERT(false && "not implemented");
2428
- };
1629
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1630
+ nth0 = 32;
1631
+ nth1 = 1;
1632
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
1633
+ } break;
1634
+ case GGML_TYPE_Q4_0:
1635
+ {
1636
+ nth0 = 8;
1637
+ nth1 = 8;
1638
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
2429
1639
  } break;
2430
- default: GGML_ASSERT(false && "not implemented");
1640
+ case GGML_TYPE_Q4_1:
1641
+ {
1642
+ nth0 = 8;
1643
+ nth1 = 8;
1644
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
1645
+ } break;
1646
+ case GGML_TYPE_Q5_0:
1647
+ {
1648
+ nth0 = 8;
1649
+ nth1 = 8;
1650
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
1651
+ } break;
1652
+ case GGML_TYPE_Q5_1:
1653
+ {
1654
+ nth0 = 8;
1655
+ nth1 = 8;
1656
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
1657
+ } break;
1658
+ case GGML_TYPE_Q8_0:
1659
+ {
1660
+ nth0 = 8;
1661
+ nth1 = 8;
1662
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
1663
+ } break;
1664
+ case GGML_TYPE_Q2_K:
1665
+ {
1666
+ nth0 = 2;
1667
+ nth1 = 32;
1668
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
1669
+ } break;
1670
+ case GGML_TYPE_Q3_K:
1671
+ {
1672
+ nth0 = 2;
1673
+ nth1 = 32;
1674
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
1675
+ } break;
1676
+ case GGML_TYPE_Q4_K:
1677
+ {
1678
+ nth0 = 4; //1;
1679
+ nth1 = 8; //32;
1680
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
1681
+ } break;
1682
+ case GGML_TYPE_Q5_K:
1683
+ {
1684
+ nth0 = 2;
1685
+ nth1 = 32;
1686
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
1687
+ } break;
1688
+ case GGML_TYPE_Q6_K:
1689
+ {
1690
+ nth0 = 2;
1691
+ nth1 = 32;
1692
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
1693
+ } break;
1694
+ case GGML_TYPE_IQ2_XXS:
1695
+ {
1696
+ nth0 = 4;
1697
+ nth1 = 16;
1698
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
1699
+ } break;
1700
+ case GGML_TYPE_IQ2_XS:
1701
+ {
1702
+ nth0 = 4;
1703
+ nth1 = 16;
1704
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
1705
+ } break;
1706
+ default:
1707
+ {
1708
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
1709
+ GGML_ASSERT(false && "not implemented");
1710
+ }
1711
+ };
1712
+
1713
+ if (ggml_is_quantized(src2t)) {
1714
+ GGML_ASSERT(ne20 >= nth0*nth1);
2431
1715
  }
2432
1716
 
2433
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2434
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2435
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2436
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2437
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2438
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
2439
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
2440
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
2441
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
2442
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
2443
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
2444
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
2445
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
2446
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
2447
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
2448
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
2449
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
2450
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1717
+ const int64_t _ne1 = 1; // the kernel needs a reference in constant memory
2451
1718
 
2452
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2453
- } break;
2454
- default:
2455
- {
2456
- GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
2457
- GGML_ASSERT(false);
2458
- }
2459
- }
1719
+ [encoder setComputePipelineState:pipeline];
1720
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1721
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1722
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1723
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1724
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1725
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1726
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
1727
+ [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
1728
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
1729
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
1730
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1731
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
1732
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1733
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1734
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1735
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1736
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1737
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1738
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
1739
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1740
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
1741
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1742
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1743
+ // TODO: how to make this an array? read Metal docs
1744
+ for (int j = 0; j < 8; ++j) {
1745
+ // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1746
+ struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1747
+
1748
+ size_t offs_src_cur = 0;
1749
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1750
+
1751
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
1752
+ }
2460
1753
 
2461
- #ifndef GGML_METAL_NDEBUG
2462
- [encoder popDebugGroup];
1754
+ if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1755
+ src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1756
+ src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
1757
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1758
+ }
1759
+ else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
1760
+ const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1761
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1762
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1763
+ }
1764
+ else if (src2t == GGML_TYPE_Q4_K) {
1765
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1766
+ }
1767
+ else if (src2t == GGML_TYPE_Q3_K) {
1768
+ #ifdef GGML_QKK_64
1769
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1770
+ #else
1771
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2463
1772
  #endif
2464
- }
1773
+ }
1774
+ else if (src2t == GGML_TYPE_Q5_K) {
1775
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1776
+ }
1777
+ else if (src2t == GGML_TYPE_Q6_K) {
1778
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1779
+ } else {
1780
+ const int64_t ny = (_ne1 + nrows - 1)/nrows;
1781
+ [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1782
+ }
1783
+ }
1784
+ } break;
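
Both MUL_MAT_ID paths bind exactly 8 expert buffers even when n_as < 8: the loop wraps with dst->src[2 + (j % n_as)] so that every Metal argument slot receives a valid buffer and none is left uninitialized, while the kernel only ever reads the first n_as of them. The indexing in isolation:

    #include <stdio.h>

    int main(void) {
        const int n_as = 3; /* example expert count, <= 8 as asserted above */
        for (int j = 0; j < 8; ++j) {
            /* slots 0..7 map to experts 0,1,2,0,1,2,0,1 */
            printf("arg slot %d -> expert %d\n", 19 + j, j % n_as);
        }
        return 0;
    }
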
1785
+ case GGML_OP_GET_ROWS:
1786
+ {
1787
+ id<MTLComputePipelineState> pipeline = nil;
1788
+
1789
+ switch (src0->type) {
1790
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
1791
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
1792
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
1793
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
1794
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
1795
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
1796
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
1797
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
1798
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
1799
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
1800
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
1801
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
1802
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
1803
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
1804
+ case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
1805
+ default: GGML_ASSERT(false && "not implemented");
1806
+ }
1807
+
1808
+ [encoder setComputePipelineState:pipeline];
1809
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1810
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1811
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1812
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1813
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
1814
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
1815
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
1816
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
1817
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
1818
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
1819
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
1820
+
1821
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
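+ // (note: the grid is one 32-thread threadgroup per (row index, batch) pair — ne10 indices by
+ // ne11 batches — with the threadgroup copying/dequantizing one source row cooperatively)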
+ } break;
+ case GGML_OP_RMS_NORM:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ int nth = 32; // SIMD width
+
+ while (nth < ne00/4 && nth < 1024) {
+ nth *= 2;
+ }
 
- if (encoder != nil) {
- [encoder endEncoding];
- encoder = nil;
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ const int64_t nrows = ggml_nrows(src0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
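+ // (note: nth doubles from the SIMD width of 32 until it covers the ne00/4 float4 chunks of a row,
+ // capped at 1024; one threadgroup normalizes one row, and the 32 floats of threadgroup memory
+ // hold the partial sums exchanged between simdgroups)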
+ } break;
+ case GGML_OP_GROUP_NORM:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+
+ //float eps;
+ //memcpy(&eps, dst->op_params, sizeof(float));
+
+ const float eps = 1e-6f; // TODO: temporarily hardcoded
+
+ const int32_t n_groups = ((int32_t *) dst->op_params)[0];
+
+ int nth = 32; // SIMD width
+
+ //while (nth < ne00/4 && nth < 1024) {
+ // nth *= 2;
+ //}
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:9];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_NORM:
+ {
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ const int nth = MIN(256, ne00);
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
+
+ const int64_t nrows = ggml_nrows(src0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ALIBI:
+ {
+ GGML_ASSERT((src0t == GGML_TYPE_F32));
+
+ const int nth = MIN(1024, ne00);
+
+ //const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_head = ((int32_t *) dst->op_params)[1];
+ float max_bias;
+ memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
+
+ const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+ const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
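+ // (note: m0 and m1 seed the geometric slope sequence from the ALiBi paper — head h takes
+ // slope m0^(h+1) for h < n_heads_log2_floor, and the remaining heads take odd powers of m1,
+ // which is how non-power-of-two head counts are handled)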
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
+ [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
+ [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ROPE:
+ {
+ GGML_ASSERT(ne10 == ne02);
+
+ const int nth = MIN(1024, ne00);
+
+ const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
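+ // (note: op_params is laid out as [0]=n_past, [1]=n_dims, [2]=mode, [3]=n_ctx (skipped here),
+ // [4]=n_orig_ctx, followed by six floats at [5..10]; memcpy is used rather than a pointer cast
+ // to read the floats out of the int32 array without violating strict aliasing)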
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
+ default: GGML_ASSERT(false);
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
+ [encoder setBytes:&mode length:sizeof( int) atIndex:21];
+ [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_IM2COL:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+ const int32_t N = src1->ne[is_2D ? 3 : 2];
+ const int32_t IC = src1->ne[is_2D ? 2 : 1];
+ const int32_t IH = is_2D ? src1->ne[1] : 1;
+ const int32_t IW = src1->ne[0];
+
+ const int32_t KH = is_2D ? src0->ne[1] : 1;
+ const int32_t KW = src0->ne[0];
+
+ const int32_t OH = is_2D ? dst->ne[2] : 1;
+ const int32_t OW = dst->ne[1];
+
+ const int32_t CHW = IC * KH * KW;
+
+ const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+ const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
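+ // (note: s*/p*/d* are the stride/padding/dilation of the convolution; when is_2D is false the
+ // H dimensions collapse to 1 for the 1D case, and ofs0/ofs1 are the batch and channel strides
+ // of src1 converted from bytes to elements — the /4 assumes F32 input, per the assert above)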
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
+ default: GGML_ASSERT(false);
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
+ [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
+ [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
+ [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
+ [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
+ [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
+ [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
+ [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
+ [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
+ [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
+ [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+ } break;
+ case GGML_OP_UPSCALE:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ const int sf = dst->op_params[0];
+
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+ [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
+
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_PAD:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ARGSORT:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+ const int nrows = ggml_nrows(src0);
+
+ enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (order) {
+ case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
+ case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
+ default: GGML_ASSERT(false);
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
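+ // (note: a single threadgroup of ne00 threads sorts each row in place, so this path requires
+ // ne00 to fit within the pipeline's maxTotalThreadsPerThreadgroup)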
+ } break;
+ case GGML_OP_LEAKY_RELU:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ float slope;
+ memcpy(&slope, dst->op_params, sizeof(float));
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
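+ // (note: this expresses one thread per element as n single-thread threadgroups)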
+ } break;
+ case GGML_OP_DUP:
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ {
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+
+ int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (src0t) {
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
+
+ switch (dstt) {
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
+ //case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
+ //case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
+ default: GGML_ASSERT(false && "not implemented");
+ };
+ } break;
+ case GGML_TYPE_F16:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
+ default: GGML_ASSERT(false && "not implemented");
+ };
+ } break;
+ default: GGML_ASSERT(false && "not implemented");
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
  }
 
- [command_buffer commit];
- });
- }
+ #ifndef GGML_METAL_NDEBUG
+ [encoder popDebugGroup];
+ #endif
+ }
+
+ if (encoder != nil) {
+ [encoder endEncoding];
+ encoder = nil;
+ }
 
- // wait for all threads to finish
- dispatch_barrier_sync(ctx->d_queue, ^{});
+ [command_buffer commit];
+ });
 
- // check status of command buffers
+ // Wait for completion and check status of each command buffer
  // needed to detect if the device ran out-of-memory for example (#1881)
- for (int i = 0; i < n_cb; i++) {
- [ctx->command_buffers[i] waitUntilCompleted];
 
- MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
+ for (int i = 0; i < n_cb; ++i) {
+ id<MTLCommandBuffer> command_buffer = command_buffers[i];
+ [command_buffer waitUntilCompleted];
+
+ MTLCommandBufferStatus status = [command_buffer status];
  if (status != MTLCommandBufferStatusCompleted) {
  GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
  return false;
@@ -2520,13 +2291,13 @@ static void ggml_backend_metal_free_device(void) {
  }
  }
 
- static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
- struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+ GGML_CALL static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
+ return "Metal";
 
- return ctx->all_data;
+ UNUSED(buffer);
  }
 
- static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ GGML_CALL static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
  struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
 
  for (int i = 0; i < ctx->n_buffers; i++) {
@@ -2541,50 +2312,80 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
  free(ctx);
  }
 
- static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- memcpy((char *)tensor->data + offset, data, size);
+ GGML_CALL static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
 
- UNUSED(buffer);
+ return ctx->all_data;
  }
 
- static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- memcpy(data, (const char *)tensor->data + offset, size);
+ GGML_CALL static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ memcpy((char *)tensor->data + offset, data, size);
 
  UNUSED(buffer);
  }
 
- static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
- ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+ GGML_CALL static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ memcpy(data, (const char *)tensor->data + offset, size);
 
  UNUSED(buffer);
  }
 
- static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
- ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+ GGML_CALL static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+ if (ggml_backend_buffer_is_host(src->buffer)) {
+ memcpy(dst->data, src->data, ggml_nbytes(src));
+ return true;
+ }
+ return false;
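+ // (note: Metal buffers here are host-mapped, so a plain memcpy suffices when the source lives
+ // in a host buffer; returning false hands the copy back to the generic ggml-backend fallback)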
 
  UNUSED(buffer);
  }
 
- static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ GGML_CALL static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
  struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
 
  memset(ctx->all_data, value, ctx->all_size);
  }
 
  static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
+ /* .get_name = */ ggml_backend_metal_buffer_get_name,
  /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
  /* .get_base = */ ggml_backend_metal_buffer_get_base,
  /* .init_tensor = */ NULL,
  /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
  /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
- /* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
- /* .cpy_tensor_to = */ ggml_backend_metal_buffer_cpy_tensor_to,
+ /* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor,
  /* .clear = */ ggml_backend_metal_buffer_clear,
+ /* .reset = */ NULL,
  };
 
  // default buffer type
 
- static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+ return "Metal";
+
+ UNUSED(buft);
+ }
+
+ static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
+ #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
+ if (@available(macOS 10.12, iOS 16.0, *)) {
+ GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
+ device.currentAllocatedSize / 1024.0 / 1024.0,
+ device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+
+ if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
+ GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
+ } else {
+ GGML_METAL_LOG_INFO("\n");
+ }
+ } else {
+ GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
+ }
+ #endif
+ UNUSED(device);
+ }
+
+ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
  struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
 
  const size_t size_page = sysconf(_SC_PAGESIZE);
@@ -2616,46 +2417,32 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
  }
 
  GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
-
-
- #if TARGET_OS_OSX
- GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
- device.currentAllocatedSize / 1024.0 / 1024.0,
- device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-
- if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
- GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
- } else {
- GGML_METAL_LOG_INFO("\n");
- }
- #else
- GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
- #endif
-
+ ggml_backend_metal_log_allocated_size(device);
 
  return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
  }
 
- static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+ GGML_CALL static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
  return 32;
  UNUSED(buft);
  }
 
- static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+ GGML_CALL static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
  return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
 
  UNUSED(buft);
  }
 
- static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+ GGML_CALL static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
  return true;
 
  UNUSED(buft);
  }
 
- ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
+ GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
  static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
  /* .iface = */ {
+ /* .get_name = */ ggml_backend_metal_buffer_type_get_name,
  /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
  /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
  /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
@@ -2670,7 +2457,7 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
 
  // buffer from ptr
 
- ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
+ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
  struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
 
  ctx->all_data = data;
@@ -2679,6 +2466,14 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
  ctx->n_buffers = 0;
 
  const size_t size_page = sysconf(_SC_PAGESIZE);
+
+ // page-align the data ptr
+ {
+ const uintptr_t offs = (uintptr_t) data % size_page;
+ data = (void *) ((char *) data - offs);
+ size += offs;
+ }
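+ // (note: the pointer is rounded down to a page boundary and the size grown by the remainder
+ // because newBufferWithBytesNoCopy requires page-aligned pointers and lengths)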
+
  size_t size_aligned = size;
  if ((size_aligned % size_page) != 0) {
  size_aligned += (size_page - (size_aligned % size_page));
@@ -2730,63 +2525,50 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
  }
  }
 
- #if TARGET_OS_OSX
- GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
- device.currentAllocatedSize / 1024.0 / 1024.0,
- device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-
- if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
- GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
- } else {
- GGML_METAL_LOG_INFO("\n");
- }
- #else
- GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
- #endif
+ ggml_backend_metal_log_allocated_size(device);
 
  return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
  }
 
  // backend
 
- static const char * ggml_backend_metal_name(ggml_backend_t backend) {
+ GGML_CALL static const char * ggml_backend_metal_name(ggml_backend_t backend) {
  return "Metal";
 
  UNUSED(backend);
  }
 
- static void ggml_backend_metal_free(ggml_backend_t backend) {
+ GGML_CALL static void ggml_backend_metal_free(ggml_backend_t backend) {
  struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
  ggml_metal_free(ctx);
  free(backend);
  }
 
- static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
+ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
  return ggml_backend_metal_buffer_type();
 
  UNUSED(backend);
  }
 
- static bool ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ GGML_CALL static bool ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
  struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
 
  return ggml_metal_graph_compute(metal_ctx, cgraph);
  }
 
- static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
- return ggml_metal_supports_op(op);
+ GGML_CALL static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+ struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
 
- UNUSED(backend);
+ return ggml_metal_supports_op(metal_ctx, op);
  }
 
- static struct ggml_backend_i metal_backend_i = {
+ static struct ggml_backend_i ggml_backend_metal_i = {
  /* .get_name = */ ggml_backend_metal_name,
  /* .free = */ ggml_backend_metal_free,
  /* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
  /* .set_tensor_async = */ NULL,
  /* .get_tensor_async = */ NULL,
- /* .cpy_tensor_from_async = */ NULL,
- /* .cpy_tensor_to_async = */ NULL,
+ /* .cpy_tensor_async = */ NULL,
  /* .synchronize = */ NULL,
  /* .graph_plan_create = */ NULL,
  /* .graph_plan_free = */ NULL,
@@ -2795,6 +2577,11 @@ static struct ggml_backend_i metal_backend_i = {
  /* .supports_op = */ ggml_backend_metal_supports_op,
  };
 
+ void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
+ ggml_metal_log_callback = log_callback;
+ ggml_metal_log_user_data = user_data;
+ }
+
  ggml_backend_t ggml_backend_metal_init(void) {
  struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
 
@@ -2805,7 +2592,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
  ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
 
  *metal_backend = (struct ggml_backend) {
- /* .interface = */ metal_backend_i,
+ /* .interface = */ ggml_backend_metal_i,
  /* .context = */ ctx,
  };
 
@@ -2813,7 +2600,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
  }
 
  bool ggml_backend_is_metal(ggml_backend_t backend) {
- return backend->iface.get_name == ggml_backend_metal_name;
+ return backend && backend->iface.get_name == ggml_backend_metal_name;
  }
 
  void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
@@ -2821,7 +2608,7 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
  struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
 
- ggml_metal_set_n_cb(ctx, n_cb);
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
  }
 
  bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
@@ -2832,9 +2619,9 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
  return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
  }
 
- ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
+ GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
 
- ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
+ GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
  return ggml_backend_metal_init();
 
  GGML_UNUSED(params);