whisper.rn 0.4.0-rc.6 → 0.4.0-rc.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/cpp/ggml-metal.m CHANGED
@@ -24,7 +24,7 @@
 
 #define UNUSED(x) (void)(x)
 
- #define WSP_GGML_MAX_CONCUR (2*WSP_GGML_DEFAULT_GRAPH_SIZE)
+ #define WSP_GGML_METAL_MAX_KERNELS 256
 
 struct wsp_ggml_metal_buffer {
 const char * name;
@@ -35,6 +35,134 @@ struct wsp_ggml_metal_buffer {
 id<MTLBuffer> metal;
 };
 
+ struct wsp_ggml_metal_kernel {
+ id<MTLFunction> function;
+ id<MTLComputePipelineState> pipeline;
+ };
+
+ enum wsp_ggml_metal_kernel_type {
+ WSP_GGML_METAL_KERNEL_TYPE_ADD,
+ WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW,
+ WSP_GGML_METAL_KERNEL_TYPE_DIV,
+ WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW,
+ WSP_GGML_METAL_KERNEL_TYPE_SCALE,
+ WSP_GGML_METAL_KERNEL_TYPE_SCALE_4,
+ WSP_GGML_METAL_KERNEL_TYPE_TANH,
+ WSP_GGML_METAL_KERNEL_TYPE_RELU,
+ WSP_GGML_METAL_KERNEL_TYPE_GELU,
+ WSP_GGML_METAL_KERNEL_TYPE_GELU_QUICK,
+ WSP_GGML_METAL_KERNEL_TYPE_SILU,
+ WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX,
+ WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
+ WSP_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
+ WSP_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
+ WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
+ WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
+ WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
+ WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
+ WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
+ WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
+ WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
+ WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
+ WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
+ WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
+ WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
+ WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
+ WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
+ WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
+ WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM,
+ WSP_GGML_METAL_KERNEL_TYPE_GROUP_NORM,
+ WSP_GGML_METAL_KERNEL_TYPE_NORM,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
+ //WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
+ //WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
+ //WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_ROPE_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_ROPE_F16,
+ WSP_GGML_METAL_KERNEL_TYPE_ALIBI_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_IM2COL_F16,
+ WSP_GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_PAD_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
+ WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
+ WSP_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
+ WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
+ WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
+ WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
+ //WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
+ //WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
+ WSP_GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
+ WSP_GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_CONCAT,
+ WSP_GGML_METAL_KERNEL_TYPE_SQR,
+ WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+
+ WSP_GGML_METAL_KERNEL_TYPE_COUNT
+ };
+
 struct wsp_ggml_metal_context {
 int n_cb;
 
@@ -42,131 +170,15 @@ struct wsp_ggml_metal_context {
 id<MTLCommandQueue> queue;
 id<MTLLibrary> library;
 
- id<MTLCommandBuffer> command_buffers [WSP_GGML_METAL_MAX_COMMAND_BUFFERS];
- id<MTLComputeCommandEncoder> command_encoders[WSP_GGML_METAL_MAX_COMMAND_BUFFERS];
-
 dispatch_queue_t d_queue;
 
 int n_buffers;
 struct wsp_ggml_metal_buffer buffers[WSP_GGML_METAL_MAX_BUFFERS];
 
- int concur_list[WSP_GGML_MAX_CONCUR];
- int concur_list_len;
-
- // custom kernels
- #define WSP_GGML_METAL_DECL_KERNEL(name) \
- id<MTLFunction> function_##name; \
- id<MTLComputePipelineState> pipeline_##name
-
- WSP_GGML_METAL_DECL_KERNEL(add);
- WSP_GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
- WSP_GGML_METAL_DECL_KERNEL(mul);
- WSP_GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
- WSP_GGML_METAL_DECL_KERNEL(div);
- WSP_GGML_METAL_DECL_KERNEL(div_row);
- WSP_GGML_METAL_DECL_KERNEL(scale);
- WSP_GGML_METAL_DECL_KERNEL(scale_4);
- WSP_GGML_METAL_DECL_KERNEL(tanh);
- WSP_GGML_METAL_DECL_KERNEL(relu);
- WSP_GGML_METAL_DECL_KERNEL(gelu);
- WSP_GGML_METAL_DECL_KERNEL(gelu_quick);
- WSP_GGML_METAL_DECL_KERNEL(silu);
- WSP_GGML_METAL_DECL_KERNEL(soft_max);
- WSP_GGML_METAL_DECL_KERNEL(soft_max_4);
- WSP_GGML_METAL_DECL_KERNEL(diag_mask_inf);
- WSP_GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
- WSP_GGML_METAL_DECL_KERNEL(get_rows_f32);
- WSP_GGML_METAL_DECL_KERNEL(get_rows_f16);
- WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_0);
- WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_1);
- WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_0);
- WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_1);
- WSP_GGML_METAL_DECL_KERNEL(get_rows_q8_0);
- WSP_GGML_METAL_DECL_KERNEL(get_rows_q2_K);
- WSP_GGML_METAL_DECL_KERNEL(get_rows_q3_K);
- WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_K);
- WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_K);
- WSP_GGML_METAL_DECL_KERNEL(get_rows_q6_K);
- WSP_GGML_METAL_DECL_KERNEL(rms_norm);
- WSP_GGML_METAL_DECL_KERNEL(group_norm);
- WSP_GGML_METAL_DECL_KERNEL(norm);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
- //WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
- //WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
- //WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
- WSP_GGML_METAL_DECL_KERNEL(rope_f32);
- WSP_GGML_METAL_DECL_KERNEL(rope_f16);
- WSP_GGML_METAL_DECL_KERNEL(alibi_f32);
- WSP_GGML_METAL_DECL_KERNEL(im2col_f16);
- WSP_GGML_METAL_DECL_KERNEL(upscale_f32);
- WSP_GGML_METAL_DECL_KERNEL(pad_f32);
- WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
- WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
- WSP_GGML_METAL_DECL_KERNEL(leaky_relu_f32);
- WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f16);
- WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f32);
- WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
- WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
- WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
- //WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
- //WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
- WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f16);
- WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f32);
- WSP_GGML_METAL_DECL_KERNEL(concat);
- WSP_GGML_METAL_DECL_KERNEL(sqr);
- WSP_GGML_METAL_DECL_KERNEL(sum_rows);
-
- #undef WSP_GGML_METAL_DECL_KERNEL
+ struct wsp_ggml_metal_kernel kernels[WSP_GGML_METAL_MAX_KERNELS];
+
+ bool support_simdgroup_reduction;
+ bool support_simdgroup_mm;
 };
 
 // MSL code
@@ -180,14 +192,16 @@ struct wsp_ggml_metal_context {
 @implementation WSPGGMLMetalClass
 @end
 
- wsp_ggml_log_callback wsp_ggml_metal_log_callback = NULL;
- void * wsp_ggml_metal_log_user_data = NULL;
+ static void wsp_ggml_metal_default_log_callback(enum wsp_ggml_log_level level, const char * msg, void * user_data) {
+ fprintf(stderr, "%s", msg);
 
- void wsp_ggml_metal_log_set_callback(wsp_ggml_log_callback log_callback, void * user_data) {
- wsp_ggml_metal_log_callback = log_callback;
- wsp_ggml_metal_log_user_data = user_data;
+ UNUSED(level);
+ UNUSED(user_data);
 }
 
+ wsp_ggml_log_callback wsp_ggml_metal_log_callback = wsp_ggml_metal_default_log_callback;
+ void * wsp_ggml_metal_log_user_data = NULL;
+
 WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
 static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char * format, ...){
 if (wsp_ggml_metal_log_callback != NULL) {
@@ -210,24 +224,33 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char * forma
 }
 }
 
- struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
- WSP_GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
+ static void * wsp_ggml_metal_host_malloc(size_t n) {
+ void * data = NULL;
+ const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
+ if (result != 0) {
+ WSP_GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
+ return NULL;
+ }
+
+ return data;
+ }
 
- id<MTLDevice> device;
- NSString * s;
+ static struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
+ WSP_GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
 
- #if TARGET_OS_OSX
+ #if TARGET_OS_OSX && !WSP_GGML_METAL_NDEBUG
 // Show all the Metal device instances in the system
 NSArray * devices = MTLCopyAllDevices();
- for (device in devices) {
- s = [device name];
+ for (id<MTLDevice> device in devices) {
+ NSString * s = [device name];
 WSP_GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
 }
+ [devices release]; // since it was created by a *Copy* C method
 #endif
 
 // Pick and show default Metal device
- device = MTLCreateSystemDefaultDevice();
- s = [device name];
+ id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+ NSString * s = [device name];
 WSP_GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
 
 // Configure context
@@ -236,7 +259,6 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 ctx->n_cb = MIN(n_cb, WSP_GGML_METAL_MAX_BUFFERS);
 ctx->queue = [ctx->device newCommandQueue];
 ctx->n_buffers = 0;
- ctx->concur_list_len = 0;
 
 ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
@@ -251,6 +273,7 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 NSError * error = nil;
 NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
 if (libPath != nil) {
+ // pre-compiled library found
 NSURL * libURL = [NSURL fileURLWithPath:libPath];
 WSP_GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
 ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
@@ -278,12 +301,21 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 return NULL;
 }
 
- MTLCompileOptions* options = nil;
+ @autoreleasepool {
+ // dictionary of preprocessor macros
+ NSMutableDictionary * prep = [NSMutableDictionary dictionary];
+
 #ifdef WSP_GGML_QKK_64
- options = [MTLCompileOptions new];
- options.preprocessorMacros = @{ @"QK_K" : @(64) };
+ prep[@"QK_K"] = @(64);
 #endif
- ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+
+ MTLCompileOptions* options = [MTLCompileOptions new];
+ options.preprocessorMacros = prep;
+
+ //[options setFastMathEnabled:false];
+
+ ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+ }
 }
 
 if (error) {
@@ -292,22 +324,51 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 }
 }
 
- #if TARGET_OS_OSX
 // print MTL GPU family:
 WSP_GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
 
+ const NSInteger MTLGPUFamilyMetal3 = 5001;
+
 // determine max supported GPU family
 // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
 // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
- for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
- if ([ctx->device supportsFamily:i]) {
- WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
- break;
+ {
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+ break;
+ }
+ }
+
+ for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
+ break;
+ }
+ }
+
+ for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
+ break;
+ }
 }
 }
 
+ ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7];
+ ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];
+
+ ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];
+
+ WSP_GGML_METAL_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false");
+ WSP_GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
 WSP_GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
- WSP_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+
+ #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
+ if (@available(macOS 10.12, iOS 16.0, *)) {
+ WSP_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+ }
+ #elif TARGET_OS_OSX
 if (ctx->device.maxTransferRate != 0) {
 WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
 } else {
@@ -319,286 +380,177 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 {
 NSError * error = nil;
 
+ for (int i = 0; i < WSP_GGML_METAL_MAX_KERNELS; ++i) {
+ ctx->kernels[i].function = nil;
+ ctx->kernels[i].pipeline = nil;
+ }
+
 /*
- WSP_GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
- (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
- (int) ctx->pipeline_##name.threadExecutionWidth); \
+ WSP_GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+ (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+ (int) kernel->pipeline.threadExecutionWidth); \
 */
- #define WSP_GGML_METAL_ADD_KERNEL(name) \
- ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
- ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
- if (error) { \
- WSP_GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
- return NULL; \
+ #define WSP_GGML_METAL_ADD_KERNEL(e, name, supported) \
+ if (supported) { \
+ struct wsp_ggml_metal_kernel * kernel = &ctx->kernels[e]; \
+ kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \
+ kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \
+ if (error) { \
+ WSP_GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+ return NULL; \
+ } \
+ } else { \
+ WSP_GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \
 }
 
- WSP_GGML_METAL_ADD_KERNEL(add);
- WSP_GGML_METAL_ADD_KERNEL(add_row);
- WSP_GGML_METAL_ADD_KERNEL(mul);
- WSP_GGML_METAL_ADD_KERNEL(mul_row);
- WSP_GGML_METAL_ADD_KERNEL(div);
- WSP_GGML_METAL_ADD_KERNEL(div_row);
- WSP_GGML_METAL_ADD_KERNEL(scale);
- WSP_GGML_METAL_ADD_KERNEL(scale_4);
- WSP_GGML_METAL_ADD_KERNEL(tanh);
- WSP_GGML_METAL_ADD_KERNEL(relu);
- WSP_GGML_METAL_ADD_KERNEL(gelu);
- WSP_GGML_METAL_ADD_KERNEL(gelu_quick);
- WSP_GGML_METAL_ADD_KERNEL(silu);
- WSP_GGML_METAL_ADD_KERNEL(soft_max);
- WSP_GGML_METAL_ADD_KERNEL(soft_max_4);
- WSP_GGML_METAL_ADD_KERNEL(diag_mask_inf);
- WSP_GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
- WSP_GGML_METAL_ADD_KERNEL(get_rows_f32);
- WSP_GGML_METAL_ADD_KERNEL(get_rows_f16);
- WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_0);
- WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_1);
- WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_0);
- WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_1);
- WSP_GGML_METAL_ADD_KERNEL(get_rows_q8_0);
- WSP_GGML_METAL_ADD_KERNEL(get_rows_q2_K);
- WSP_GGML_METAL_ADD_KERNEL(get_rows_q3_K);
- WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_K);
- WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_K);
- WSP_GGML_METAL_ADD_KERNEL(get_rows_q6_K);
- WSP_GGML_METAL_ADD_KERNEL(rms_norm);
- WSP_GGML_METAL_ADD_KERNEL(group_norm);
- WSP_GGML_METAL_ADD_KERNEL(norm);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
- //WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
- //WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
- //WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
- }
- WSP_GGML_METAL_ADD_KERNEL(rope_f32);
- WSP_GGML_METAL_ADD_KERNEL(rope_f16);
- WSP_GGML_METAL_ADD_KERNEL(alibi_f32);
- WSP_GGML_METAL_ADD_KERNEL(im2col_f16);
- WSP_GGML_METAL_ADD_KERNEL(upscale_f32);
- WSP_GGML_METAL_ADD_KERNEL(pad_f32);
- WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
- WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
- WSP_GGML_METAL_ADD_KERNEL(leaky_relu_f32);
- WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f16);
- WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f32);
- WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
- WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
- WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
- //WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
- //WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
- WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f16);
- WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f32);
- WSP_GGML_METAL_ADD_KERNEL(concat);
- WSP_GGML_METAL_ADD_KERNEL(sqr);
- WSP_GGML_METAL_ADD_KERNEL(sum_rows);
-
- #undef WSP_GGML_METAL_ADD_KERNEL
+ // simd_sum and simd_max requires MTLGPUFamilyApple7
+
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD, add, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL, mul, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_DIV, div, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RELU, relu, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SILU, silu, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_NORM, norm, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
+ //WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
+ //WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
+ //WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
+ //WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
+ //WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
 }
 
 return ctx;
 }
 
- void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
+ static void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
 WSP_GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
- #define WSP_GGML_METAL_DEL_KERNEL(name) \
-
- WSP_GGML_METAL_DEL_KERNEL(add);
- WSP_GGML_METAL_DEL_KERNEL(add_row);
- WSP_GGML_METAL_DEL_KERNEL(mul);
- WSP_GGML_METAL_DEL_KERNEL(mul_row);
- WSP_GGML_METAL_DEL_KERNEL(div);
- WSP_GGML_METAL_DEL_KERNEL(div_row);
- WSP_GGML_METAL_DEL_KERNEL(scale);
- WSP_GGML_METAL_DEL_KERNEL(scale_4);
- WSP_GGML_METAL_DEL_KERNEL(tanh);
- WSP_GGML_METAL_DEL_KERNEL(relu);
- WSP_GGML_METAL_DEL_KERNEL(gelu);
- WSP_GGML_METAL_DEL_KERNEL(gelu_quick);
- WSP_GGML_METAL_DEL_KERNEL(silu);
- WSP_GGML_METAL_DEL_KERNEL(soft_max);
- WSP_GGML_METAL_DEL_KERNEL(soft_max_4);
- WSP_GGML_METAL_DEL_KERNEL(diag_mask_inf);
- WSP_GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
- WSP_GGML_METAL_DEL_KERNEL(get_rows_f32);
- WSP_GGML_METAL_DEL_KERNEL(get_rows_f16);
- WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_0);
- WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_1);
- WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_0);
- WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_1);
- WSP_GGML_METAL_DEL_KERNEL(get_rows_q8_0);
- WSP_GGML_METAL_DEL_KERNEL(get_rows_q2_K);
- WSP_GGML_METAL_DEL_KERNEL(get_rows_q3_K);
- WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_K);
- WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_K);
- WSP_GGML_METAL_DEL_KERNEL(get_rows_q6_K);
- WSP_GGML_METAL_DEL_KERNEL(rms_norm);
- WSP_GGML_METAL_DEL_KERNEL(group_norm);
- WSP_GGML_METAL_DEL_KERNEL(norm);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
- //WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
- //WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
- //WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
- }
- WSP_GGML_METAL_DEL_KERNEL(rope_f32);
- WSP_GGML_METAL_DEL_KERNEL(rope_f16);
- WSP_GGML_METAL_DEL_KERNEL(alibi_f32);
- WSP_GGML_METAL_DEL_KERNEL(im2col_f16);
- WSP_GGML_METAL_DEL_KERNEL(upscale_f32);
- WSP_GGML_METAL_DEL_KERNEL(pad_f32);
- WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
- WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
- WSP_GGML_METAL_DEL_KERNEL(leaky_relu_f32);
- WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f16);
- WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f32);
- WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
- WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
- WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
- //WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
- //WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
- WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f16);
- WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f32);
- WSP_GGML_METAL_DEL_KERNEL(concat);
- WSP_GGML_METAL_DEL_KERNEL(sqr);
- WSP_GGML_METAL_DEL_KERNEL(sum_rows);
-
- #undef WSP_GGML_METAL_DEL_KERNEL
 
 free(ctx);
 }
 
- void * wsp_ggml_metal_host_malloc(size_t n) {
- void * data = NULL;
- const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
- if (result != 0) {
- WSP_GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
- return NULL;
- }
-
- return data;
- }
-
- void wsp_ggml_metal_host_free(void * data) {
- free(data);
- }
-
- void wsp_ggml_metal_set_n_cb(struct wsp_ggml_metal_context * ctx, int n_cb) {
- ctx->n_cb = MIN(n_cb, WSP_GGML_METAL_MAX_BUFFERS);
- }
+ // temporarily defined here for compatibility between ggml-backend and the old API
 
- int wsp_ggml_metal_if_optimized(struct wsp_ggml_metal_context * ctx) {
- return ctx->concur_list_len;
- }
+ struct wsp_ggml_backend_metal_buffer {
+ void * data;
+ size_t size;
 
- int * wsp_ggml_metal_get_concur_list(struct wsp_ggml_metal_context * ctx) {
- return ctx->concur_list;
- }
+ id<MTLBuffer> metal;
+ };
 
- // temporarily defined here for compatibility between ggml-backend and the old API
 struct wsp_ggml_backend_metal_buffer_context {
- void * data;
+ void * all_data;
+ size_t all_size;
+ bool owned;
 
- id<MTLBuffer> metal;
+ // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
+ int n_buffers;
+ struct wsp_ggml_backend_metal_buffer buffers[WSP_GGML_METAL_MAX_BUFFERS];
 };
 
 // finds the Metal buffer that contains the tensor data on the GPU device
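The hunk above replaces the old per-kernel `function_##name`/`pipeline_##name` ivars with an enum-indexed table of `struct wsp_ggml_metal_kernel`, and the rewritten `WSP_GGML_METAL_ADD_KERNEL(e, name, supported)` macro now takes a capability flag so kernels that rely on simdgroup features are skipped on GPUs without them. The following standalone sketch mirrors that pattern on a toy two-kernel library; it is not part of the package, and names such as `demo_kernel_type`, `demo_kernel`, and `DEMO_ADD_KERNEL` are hypothetical. Like ggml-metal.m itself, it assumes manual reference counting (compile with -fno-objc-arc), since the struct holds Objective-C object pointers.

// demo_kernels.m — build with: clang -fno-objc-arc -framework Foundation -framework Metal demo_kernels.m
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

enum demo_kernel_type {
    DEMO_KERNEL_TYPE_ADD,      // always available
    DEMO_KERNEL_TYPE_REDUCE,   // pretend it needs simdgroup reductions (Apple7+)
    DEMO_KERNEL_TYPE_COUNT
};

struct demo_kernel {
    id<MTLFunction>             function;
    id<MTLComputePipelineState> pipeline;
};

int main(void) {
    @autoreleasepool {
        id<MTLDevice> device = MTLCreateSystemDefaultDevice();
        if (!device) return 1;

        // toy MSL source with two trivial kernels
        NSString * src = @"#include <metal_stdlib>\n"
                          "using namespace metal;\n"
                          "kernel void kernel_add(device float * x [[buffer(0)]], uint i [[thread_position_in_grid]]) { x[i] += 1.0f; }\n"
                          "kernel void kernel_reduce(device float * x [[buffer(0)]], uint i [[thread_position_in_grid]]) { x[i] *= 2.0f; }\n";

        NSError * error = nil;
        id<MTLLibrary> library = [device newLibraryWithSource:src options:nil error:&error];
        if (!library) return 1;

        // capability flag, analogous to ctx->support_simdgroup_reduction in the diff
        const bool support_simdgroup_reduction = [device supportsFamily:MTLGPUFamilyApple7];

        struct demo_kernel kernels[DEMO_KERNEL_TYPE_COUNT] = {0};

        // same shape as the new WSP_GGML_METAL_ADD_KERNEL(e, name, supported) macro:
        // look the function up by name, build a pipeline, and skip unsupported entries
        #define DEMO_ADD_KERNEL(e, name, supported)                                    \
            if (supported) {                                                           \
                kernels[e].function = [library newFunctionWithName:@"kernel_" #name]; \
                kernels[e].pipeline =                                                  \
                    [device newComputePipelineStateWithFunction:kernels[e].function    \
                                                          error:&error];               \
            } else {                                                                   \
                NSLog(@"skipping kernel_%s (not supported)", #name);                   \
            }

        DEMO_ADD_KERNEL(DEMO_KERNEL_TYPE_ADD,    add,    true);
        DEMO_ADD_KERNEL(DEMO_KERNEL_TYPE_REDUCE, reduce, support_simdgroup_reduction);

        // dispatch would later index the table by enum, e.g. kernels[DEMO_KERNEL_TYPE_ADD].pipeline
        NSLog(@"add pipeline: %@, reduce pipeline: %@",
              kernels[DEMO_KERNEL_TYPE_ADD].pipeline,
              kernels[DEMO_KERNEL_TYPE_REDUCE].pipeline);
    }
    return 0;
}

The benefit of the table-driven form over the old macro-declared ivars is that an unsupported kernel simply stays nil in the array and can be reported once at init time, rather than producing a pipeline-creation failure on older GPU families.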
@@ -610,17 +562,29 @@ static id<MTLBuffer> wsp_ggml_metal_get_buffer(struct wsp_ggml_metal_context * c
 
 const int64_t tsize = wsp_ggml_nbytes(t);
 
+ wsp_ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
+
 // compatibility with ggml-backend
- if (t->buffer && t->buffer->buft == wsp_ggml_backend_metal_buffer_type()) {
- struct wsp_ggml_backend_metal_buffer_context * buf_ctx = (struct wsp_ggml_backend_metal_buffer_context *) t->buffer->context;
+ if (buffer && buffer->buft == wsp_ggml_backend_metal_buffer_type()) {
+ struct wsp_ggml_backend_metal_buffer_context * buf_ctx = (struct wsp_ggml_backend_metal_buffer_context *) buffer->context;
 
- const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
+ // find the view that contains the tensor fully
+ for (int i = 0; i < buf_ctx->n_buffers; ++i) {
+ const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
 
- WSP_GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
+ //WSP_GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
+ if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
+ *offs = (size_t) ioffs;
 
- *offs = (size_t) ioffs;
+ //WSP_GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
 
- return buf_ctx->metal;
+ return buf_ctx->buffers[i].metal;
+ }
+ }
+
+ WSP_GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
+
+ return nil;
 }
 
 // find the view that contains the tensor fully
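In the hunk above, `wsp_ggml_metal_get_buffer` stops assuming a single backing buffer and instead walks the context's `buffers[]` views until it finds one that fully contains the tensor's byte range, which is needed because a large mmap'd allocation may be split across several MTLBuffer views. The sketch below isolates that containment search in plain C with hypothetical names (`demo_view`, `demo_find_view`); it is an illustration of the lookup logic, not the package's API.

// demo_views.c — build with: cc demo_views.c
#include <stdint.h>
#include <stdio.h>
#include <stddef.h>

struct demo_view {
    void  * data; // host base address of this view
    size_t  size; // view size in bytes
};

// returns the index of the view containing [tensor_data, tensor_data + tensor_size)
// and writes the offset inside that view, or returns -1 if no view contains it
static int demo_find_view(const struct demo_view * views, int n_views,
                          const void * tensor_data, size_t tensor_size, size_t * offs) {
    for (int i = 0; i < n_views; ++i) {
        const int64_t ioffs = (int64_t) tensor_data - (int64_t) views[i].data;
        if (ioffs >= 0 && ioffs + (int64_t) tensor_size <= (int64_t) views[i].size) {
            *offs = (size_t) ioffs;
            return i;
        }
    }
    return -1;
}

int main(void) {
    static uint8_t pool[4096];
    // two views over one allocation, mimicking how a large region is split
    struct demo_view views[2] = {
        { pool,        2048 },
        { pool + 2048, 2048 },
    };

    size_t offs = 0;
    const int idx = demo_find_view(views, 2, pool + 3000, 512, &offs);
    printf("view = %d, offs = %zu\n", idx, offs); // expect: view = 1, offs = 952
    return 0;
}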
@@ -642,210 +606,7 @@ static id<MTLBuffer> wsp_ggml_metal_get_buffer(struct wsp_ggml_metal_context * c
642
606
  return nil;
643
607
  }
644
608
 
645
- bool wsp_ggml_metal_add_buffer(
646
- struct wsp_ggml_metal_context * ctx,
647
- const char * name,
648
- void * data,
649
- size_t size,
650
- size_t max_size) {
651
- if (ctx->n_buffers >= WSP_GGML_METAL_MAX_BUFFERS) {
652
- WSP_GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
653
- return false;
654
- }
655
-
656
- if (data) {
657
- // verify that the buffer does not overlap with any of the existing buffers
658
- for (int i = 0; i < ctx->n_buffers; ++i) {
659
- const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
660
-
661
- if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
662
- WSP_GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
663
- return false;
664
- }
665
- }
666
-
667
- const size_t size_page = sysconf(_SC_PAGESIZE);
668
-
669
- size_t size_aligned = size;
670
- if ((size_aligned % size_page) != 0) {
671
- size_aligned += (size_page - (size_aligned % size_page));
672
- }
673
-
674
- // the buffer fits into the max buffer size allowed by the device
675
- if (size_aligned <= ctx->device.maxBufferLength) {
676
- ctx->buffers[ctx->n_buffers].name = name;
677
- ctx->buffers[ctx->n_buffers].data = data;
678
- ctx->buffers[ctx->n_buffers].size = size;
679
-
680
- ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
681
-
682
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
683
- WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
684
- return false;
685
- }
686
-
687
- WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);
688
-
689
- ++ctx->n_buffers;
690
- } else {
691
- // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
692
- // one of the views
693
- const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
694
- const size_t size_step = ctx->device.maxBufferLength - size_ovlp;
695
- const size_t size_view = ctx->device.maxBufferLength;
696
-
697
- for (size_t i = 0; i < size; i += size_step) {
698
- const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
699
-
700
- ctx->buffers[ctx->n_buffers].name = name;
701
- ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
702
- ctx->buffers[ctx->n_buffers].size = size_step_aligned;
703
-
704
- ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
705
-
706
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
707
- WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
708
- return false;
709
- }
710
-
711
- WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
712
- if (i + size_step < size) {
713
- WSP_GGML_METAL_LOG_INFO("\n");
714
- }
715
-
716
- ++ctx->n_buffers;
717
- }
718
- }
719
-
720
- #if TARGET_OS_OSX
721
- WSP_GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
722
- ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
723
- ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
724
-
725
- if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
726
- WSP_GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
727
- } else {
728
- WSP_GGML_METAL_LOG_INFO("\n");
729
- }
730
- #else
731
- WSP_GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
732
- #endif
733
- }
734
-
735
- return true;
736
- }
737
-
738
- void wsp_ggml_metal_set_tensor(
739
- struct wsp_ggml_metal_context * ctx,
740
- struct wsp_ggml_tensor * t) {
741
- size_t offs;
742
- id<MTLBuffer> id_dst = wsp_ggml_metal_get_buffer(ctx, t, &offs);
743
-
744
- memcpy((void *) ((uint8_t *) id_dst.contents + offs), t->data, wsp_ggml_nbytes(t));
745
- }
746
-
747
- void wsp_ggml_metal_get_tensor(
748
- struct wsp_ggml_metal_context * ctx,
749
- struct wsp_ggml_tensor * t) {
750
- size_t offs;
751
- id<MTLBuffer> id_src = wsp_ggml_metal_get_buffer(ctx, t, &offs);
752
-
753
- memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), wsp_ggml_nbytes(t));
754
- }
755
-
756
- void wsp_ggml_metal_graph_find_concurrency(
757
- struct wsp_ggml_metal_context * ctx,
758
- struct wsp_ggml_cgraph * gf, bool check_mem) {
759
- int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
760
- int nodes_unused[WSP_GGML_MAX_CONCUR];
761
-
762
- for (int i = 0; i < WSP_GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
763
- for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
764
- ctx->concur_list_len = 0;
765
-
766
- int n_left = gf->n_nodes;
767
- int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
768
- int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
769
-
770
- while (n_left > 0) {
771
- // number of nodes at a layer (that can be issued concurrently)
772
- int concurrency = 0;
773
- for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
774
- if (nodes_unused[i]) {
775
- // if the requirements for gf->nodes[i] are satisfied
776
- int exe_flag = 1;
777
-
778
- // scan all srcs
779
- for (int src_ind = 0; src_ind < WSP_GGML_MAX_SRC; src_ind++) {
780
- struct wsp_ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
781
- if (src_cur) {
782
- // if is leaf nodes it's satisfied.
783
- // TODO: wsp_ggml_is_leaf()
784
- if (src_cur->op == WSP_GGML_OP_NONE && src_cur->grad == NULL) {
785
- continue;
786
- }
787
-
788
- // otherwise this src should be the output from previous nodes.
789
- int is_found = 0;
790
-
791
- // scan 2*search_depth back because we inserted barrier.
792
- //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
793
- for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
794
- if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
795
- is_found = 1;
796
- break;
797
- }
798
- }
799
- if (is_found == 0) {
800
- exe_flag = 0;
801
- break;
802
- }
803
- }
804
- }
805
- if (exe_flag && check_mem) {
806
- // check if nodes[i]'s data will be overwritten by a node before nodes[i].
807
- // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
808
- int64_t data_start = (int64_t) gf->nodes[i]->data;
809
- int64_t length = (int64_t) wsp_ggml_nbytes(gf->nodes[i]);
810
- for (int j = n_start; j < i; j++) {
811
- if (nodes_unused[j] && gf->nodes[j]->op != WSP_GGML_OP_RESHAPE \
812
- && gf->nodes[j]->op != WSP_GGML_OP_VIEW \
813
- && gf->nodes[j]->op != WSP_GGML_OP_TRANSPOSE \
814
- && gf->nodes[j]->op != WSP_GGML_OP_PERMUTE) {
815
- if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
816
- ((int64_t)gf->nodes[j]->data) + (int64_t) wsp_ggml_nbytes(gf->nodes[j]) <= data_start) {
817
- continue;
818
- }
819
-
820
- exe_flag = 0;
821
- }
822
- }
823
- }
824
- if (exe_flag) {
825
- ctx->concur_list[level_pos + concurrency] = i;
826
- nodes_unused[i] = 0;
827
- concurrency++;
828
- ctx->concur_list_len++;
829
- }
830
- }
831
- }
832
- n_left -= concurrency;
833
- // adding a barrier different layer
834
- ctx->concur_list[level_pos + concurrency] = -1;
835
- ctx->concur_list_len++;
836
- // jump all sorted nodes at nodes_bak
837
- while (!nodes_unused[n_start]) {
838
- n_start++;
839
- }
840
- level_pos += concurrency + 1;
841
- }
842
-
843
- if (ctx->concur_list_len > WSP_GGML_MAX_CONCUR) {
844
- WSP_GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
845
- }
846
- }
847
-
848
- static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
609
+ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_metal_context * ctx, const struct wsp_ggml_tensor * op) {
849
610
  switch (op->op) {
850
611
  case WSP_GGML_OP_UNARY:
851
612
  switch (wsp_ggml_get_unary_op(op)) {
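Note: the hunk above deletes the tensor upload/download helpers (wsp_ggml_metal_set_tensor, wsp_ggml_metal_get_tensor) and the whole wsp_ggml_metal_graph_find_concurrency pass. As the hunks below show, graph nodes are now simply split evenly across the n_cb command buffers, each buffer encoded serially (MTLDispatchTypeSerial) and the per-buffer ranges handled in parallel via dispatch_apply. A minimal sketch of that partitioning, using only the arithmetic visible in the new code (the helper itself is illustrative and not part of the package):

    // Sketch: how the new wsp_ggml_metal_graph_compute splits n_nodes across
    // n_cb command buffers. Ceiling division guarantees every node is covered;
    // the last buffer simply takes whatever remains.
    static void example_node_range(int n_nodes, int n_cb, int cb_idx,
                                   int * node_start, int * node_end) {
        const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;

        const int end = (cb_idx + 1) * n_nodes_per_cb;

        *node_start = (cb_idx + 0) * n_nodes_per_cb;
        *node_end   = (cb_idx == n_cb - 1 || end > n_nodes) ? n_nodes : end;
    }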
@@ -871,9 +632,11 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
871
632
  case WSP_GGML_OP_SCALE:
872
633
  case WSP_GGML_OP_SQR:
873
634
  case WSP_GGML_OP_SUM_ROWS:
635
+ return true;
874
636
  case WSP_GGML_OP_SOFT_MAX:
875
637
  case WSP_GGML_OP_RMS_NORM:
876
638
  case WSP_GGML_OP_GROUP_NORM:
639
+ return ctx->support_simdgroup_reduction;
877
640
  case WSP_GGML_OP_NORM:
878
641
  case WSP_GGML_OP_ALIBI:
879
642
  case WSP_GGML_OP_ROPE:
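Note: wsp_ggml_metal_supports_op now receives the Metal context as its first argument, so op support becomes device-dependent: SOFT_MAX, RMS_NORM and GROUP_NORM report ctx->support_simdgroup_reduction instead of an unconditional true (the next hunk applies the same gate to MUL_MAT and MUL_MAT_ID). A minimal sketch of the gating pattern, assuming only the field name taken from the diff; the cut-down struct and helper below are illustrative:

    // Sketch: capability-gated op support. Only support_simdgroup_reduction is
    // taken from the diff; the reduced context struct and function are illustrative.
    struct example_metal_ctx {
        bool support_simdgroup_reduction; // assumed to be filled in at context init
    };

    static bool example_supports_soft_max(const struct example_metal_ctx * ctx) {
        // SOFT_MAX (like RMS_NORM and GROUP_NORM) relies on SIMD-group
        // reductions, so it is only reported as supported when the device
        // provides them
        return ctx->support_simdgroup_reduction;
    }

The graph-compute path below still asserts per node when an op is unsupported, now passing ctx through to this check.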
@@ -882,9 +645,10 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
882
645
  case WSP_GGML_OP_PAD:
883
646
  case WSP_GGML_OP_ARGSORT:
884
647
  case WSP_GGML_OP_LEAKY_RELU:
648
+ return true;
885
649
  case WSP_GGML_OP_MUL_MAT:
886
650
  case WSP_GGML_OP_MUL_MAT_ID:
887
- return true;
651
+ return ctx->support_simdgroup_reduction;
888
652
  case WSP_GGML_OP_CPY:
889
653
  case WSP_GGML_OP_DUP:
890
654
  case WSP_GGML_OP_CONT:
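Note: in the large graph-compute hunk that follows, wsp_ggml_metal_graph_compute becomes a static bool function, the @autoreleasepool wrapper and the per-context command_buffers/command_encoders arrays are dropped in favour of locally created command buffers (commandBufferWithUnretainedReferences) encoded inside a dispatch_apply block, and every direct ctx->pipeline_* access is replaced by a lookup into the ctx->kernels[...] table selected by a WSP_GGML_METAL_KERNEL_TYPE_* value (see the CONCAT and ADD/MUL/DIV cases below). A minimal sketch of that lookup pattern; only the kernels[...].pipeline shape comes from the diff, the cut-down types and helper are illustrative:

    #import <Metal/Metal.h>

    // Sketch: the kernel-table pattern used throughout the rewritten switch.
    struct example_metal_kernel {
        id<MTLComputePipelineState> pipeline;
    };

    struct example_metal_ctx_kernels {
        struct example_metal_kernel kernels[256]; // illustrative size
    };

    static id<MTLComputePipelineState> example_get_pipeline(
            const struct example_metal_ctx_kernels * ctx,
            int kernel_type /* a WSP_GGML_METAL_KERNEL_TYPE_* value */) {
        return ctx->kernels[kernel_type].pipeline;
    }

    // usage, mirroring the CONCAT case below:
    //   [encoder setComputePipelineState:example_get_pipeline(ctx, WSP_GGML_METAL_KERNEL_TYPE_CONCAT)];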
@@ -922,1433 +686,1559 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
922
686
  return false;
923
687
  }
924
688
  }
925
- void wsp_ggml_metal_graph_compute(
689
+
690
+ static bool wsp_ggml_metal_graph_compute(
926
691
  struct wsp_ggml_metal_context * ctx,
927
692
  struct wsp_ggml_cgraph * gf) {
928
- @autoreleasepool {
929
693
 
930
- // if there is ctx->concur_list, dispatch concurrently
931
- // else fallback to serial dispatch
932
694
  MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
933
-
934
- const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= WSP_GGML_MAX_CONCUR;
935
-
936
- const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
937
- edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
695
+ edesc.dispatchType = MTLDispatchTypeSerial;
938
696
 
939
697
  // create multiple command buffers and enqueue them
940
698
  // then, we encode the graph into the command buffers in parallel
941
699
 
700
+ const int n_nodes = gf->n_nodes;
942
701
  const int n_cb = ctx->n_cb;
702
+ const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
943
703
 
944
- for (int i = 0; i < n_cb; ++i) {
945
- ctx->command_buffers[i] = [ctx->queue commandBuffer];
704
+ id<MTLCommandBuffer> command_buffer_builder[n_cb];
705
+ for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
706
+ id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
707
+ command_buffer_builder[cb_idx] = command_buffer;
946
708
 
947
709
  // enqueue the command buffers in order to specify their execution order
948
- [ctx->command_buffers[i] enqueue];
949
-
950
- ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
710
+ [command_buffer enqueue];
951
711
  }
712
+ const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
952
713
 
953
- for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
954
- const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
955
-
956
- dispatch_async(ctx->d_queue, ^{
957
- size_t offs_src0 = 0;
958
- size_t offs_src1 = 0;
959
- size_t offs_dst = 0;
960
-
961
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
962
- id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
963
-
964
- const int node_start = (cb_idx + 0) * n_nodes_per_cb;
965
- const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
966
-
967
- for (int ind = node_start; ind < node_end; ++ind) {
968
- const int i = has_concur ? ctx->concur_list[ind] : ind;
969
-
970
- if (i == -1) {
971
- [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
972
- continue;
973
- }
974
-
975
- //WSP_GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, wsp_ggml_op_name(gf->nodes[i]->op));
976
-
977
- struct wsp_ggml_tensor * src0 = gf->nodes[i]->src[0];
978
- struct wsp_ggml_tensor * src1 = gf->nodes[i]->src[1];
979
- struct wsp_ggml_tensor * dst = gf->nodes[i];
980
-
981
- switch (dst->op) {
982
- case WSP_GGML_OP_NONE:
983
- case WSP_GGML_OP_RESHAPE:
984
- case WSP_GGML_OP_VIEW:
985
- case WSP_GGML_OP_TRANSPOSE:
986
- case WSP_GGML_OP_PERMUTE:
987
- {
988
- // noop -> next node
989
- } continue;
990
- default:
991
- {
992
- } break;
993
- }
994
-
995
- if (!wsp_ggml_metal_supports_op(dst)) {
996
- WSP_GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, wsp_ggml_op_desc(dst));
997
- WSP_GGML_ASSERT(!"unsupported op");
998
- }
999
-
1000
- const int64_t ne00 = src0 ? src0->ne[0] : 0;
1001
- const int64_t ne01 = src0 ? src0->ne[1] : 0;
1002
- const int64_t ne02 = src0 ? src0->ne[2] : 0;
1003
- const int64_t ne03 = src0 ? src0->ne[3] : 0;
1004
-
1005
- const uint64_t nb00 = src0 ? src0->nb[0] : 0;
1006
- const uint64_t nb01 = src0 ? src0->nb[1] : 0;
1007
- const uint64_t nb02 = src0 ? src0->nb[2] : 0;
1008
- const uint64_t nb03 = src0 ? src0->nb[3] : 0;
1009
-
1010
- const int64_t ne10 = src1 ? src1->ne[0] : 0;
1011
- const int64_t ne11 = src1 ? src1->ne[1] : 0;
1012
- const int64_t ne12 = src1 ? src1->ne[2] : 0;
1013
- const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
1014
-
1015
- const uint64_t nb10 = src1 ? src1->nb[0] : 0;
1016
- const uint64_t nb11 = src1 ? src1->nb[1] : 0;
1017
- const uint64_t nb12 = src1 ? src1->nb[2] : 0;
1018
- const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
1019
-
1020
- const int64_t ne0 = dst ? dst->ne[0] : 0;
1021
- const int64_t ne1 = dst ? dst->ne[1] : 0;
1022
- const int64_t ne2 = dst ? dst->ne[2] : 0;
1023
- const int64_t ne3 = dst ? dst->ne[3] : 0;
1024
-
1025
- const uint64_t nb0 = dst ? dst->nb[0] : 0;
1026
- const uint64_t nb1 = dst ? dst->nb[1] : 0;
1027
- const uint64_t nb2 = dst ? dst->nb[2] : 0;
1028
- const uint64_t nb3 = dst ? dst->nb[3] : 0;
1029
-
1030
- const enum wsp_ggml_type src0t = src0 ? src0->type : WSP_GGML_TYPE_COUNT;
1031
- const enum wsp_ggml_type src1t = src1 ? src1->type : WSP_GGML_TYPE_COUNT;
1032
- const enum wsp_ggml_type dstt = dst ? dst->type : WSP_GGML_TYPE_COUNT;
1033
-
1034
- id<MTLBuffer> id_src0 = src0 ? wsp_ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
1035
- id<MTLBuffer> id_src1 = src1 ? wsp_ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
1036
- id<MTLBuffer> id_dst = dst ? wsp_ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
1037
-
1038
- //WSP_GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, wsp_ggml_op_name(dst->op));
1039
- //if (src0) {
1040
- // WSP_GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src0t), ne00, ne01, ne02,
1041
- // wsp_ggml_is_contiguous(src0), src0->name);
1042
- //}
1043
- //if (src1) {
1044
- // WSP_GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src1t), ne10, ne11, ne12,
1045
- // wsp_ggml_is_contiguous(src1), src1->name);
1046
- //}
1047
- //if (dst) {
1048
- // WSP_GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(dstt), ne0, ne1, ne2,
1049
- // dst->name);
1050
- //}
1051
-
1052
- switch (dst->op) {
1053
- case WSP_GGML_OP_CONCAT:
1054
- {
1055
- const int64_t nb = ne00;
1056
-
1057
- [encoder setComputePipelineState:ctx->pipeline_concat];
1058
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1059
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1060
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1061
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1062
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1063
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1064
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1065
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1066
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1067
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1068
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1069
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1070
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1071
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1072
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1073
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1074
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1075
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1076
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1077
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1078
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1079
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1080
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1081
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1082
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1083
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1084
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1085
- [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
1086
-
1087
- const int nth = MIN(1024, ne0);
1088
-
1089
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1090
- } break;
1091
- case WSP_GGML_OP_ADD:
1092
- case WSP_GGML_OP_MUL:
1093
- case WSP_GGML_OP_DIV:
1094
- {
1095
- const size_t offs = 0;
1096
-
1097
- bool bcast_row = false;
1098
-
1099
- int64_t nb = ne00;
1100
-
1101
- id<MTLComputePipelineState> pipeline = nil;
1102
-
1103
- if (wsp_ggml_nelements(src1) == ne10 && wsp_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
1104
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
714
+ dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
715
+ const int cb_idx = iter;
1105
716
 
1106
- // src1 is a row
1107
- WSP_GGML_ASSERT(ne11 == 1);
717
+ size_t offs_src0 = 0;
718
+ size_t offs_src1 = 0;
719
+ size_t offs_dst = 0;
1108
720
 
1109
- nb = ne00 / 4;
1110
- switch (dst->op) {
1111
- case WSP_GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
1112
- case WSP_GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
1113
- case WSP_GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
1114
- default: WSP_GGML_ASSERT(false);
1115
- }
721
+ id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
722
+ id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
1116
723
 
1117
- bcast_row = true;
1118
- } else {
1119
- switch (dst->op) {
1120
- case WSP_GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
1121
- case WSP_GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
1122
- case WSP_GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
1123
- default: WSP_GGML_ASSERT(false);
1124
- }
1125
- }
724
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
725
+ const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
1126
726
 
1127
- [encoder setComputePipelineState:pipeline];
1128
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1129
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1130
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1131
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1132
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1133
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1134
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1135
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1136
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1137
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1138
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1139
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1140
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1141
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1142
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1143
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1144
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1145
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1146
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1147
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1148
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1149
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1150
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1151
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1152
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1153
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1154
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1155
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1156
- [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
1157
-
1158
- if (bcast_row) {
1159
- const int64_t n = wsp_ggml_nelements(dst)/4;
727
+ for (int i = node_start; i < node_end; ++i) {
728
+ if (i == -1) {
729
+ [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
730
+ continue;
731
+ }
1160
732
 
1161
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1162
- } else {
1163
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
733
+ //WSP_GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, wsp_ggml_op_name(gf->nodes[i]->op));
734
+
735
+ struct wsp_ggml_tensor * src0 = gf->nodes[i]->src[0];
736
+ struct wsp_ggml_tensor * src1 = gf->nodes[i]->src[1];
737
+ struct wsp_ggml_tensor * dst = gf->nodes[i];
738
+
739
+ switch (dst->op) {
740
+ case WSP_GGML_OP_NONE:
741
+ case WSP_GGML_OP_RESHAPE:
742
+ case WSP_GGML_OP_VIEW:
743
+ case WSP_GGML_OP_TRANSPOSE:
744
+ case WSP_GGML_OP_PERMUTE:
745
+ {
746
+ // noop -> next node
747
+ } continue;
748
+ default:
749
+ {
750
+ } break;
751
+ }
1164
752
 
1165
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1166
- }
1167
- } break;
1168
- case WSP_GGML_OP_ACC:
1169
- {
1170
- WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_F32);
1171
- WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
1172
- WSP_GGML_ASSERT(dstt == WSP_GGML_TYPE_F32);
753
+ if (!wsp_ggml_metal_supports_op(ctx, dst)) {
754
+ WSP_GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, wsp_ggml_op_desc(dst));
755
+ WSP_GGML_ASSERT(!"unsupported op");
756
+ }
1173
757
 
1174
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
1175
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
1176
-
1177
- const size_t pnb1 = ((int32_t *) dst->op_params)[0];
1178
- const size_t pnb2 = ((int32_t *) dst->op_params)[1];
1179
- const size_t pnb3 = ((int32_t *) dst->op_params)[2];
1180
- const size_t offs = ((int32_t *) dst->op_params)[3];
1181
-
1182
- const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
1183
-
1184
- if (!inplace) {
1185
- // run a separete kernel to cpy src->dst
1186
- // not sure how to avoid this
1187
- // TODO: make a simpler cpy_bytes kernel
1188
-
1189
- const int nth = MIN(1024, ne00);
1190
-
1191
- [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
1192
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1193
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1194
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1195
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1196
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1197
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1198
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1199
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1200
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1201
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1202
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1203
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1204
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1205
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1206
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1207
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1208
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1209
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1210
-
1211
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1212
- }
758
+ #ifndef WSP_GGML_METAL_NDEBUG
759
+ [encoder pushDebugGroup:[NSString stringWithCString:wsp_ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
760
+ #endif
1213
761
 
1214
- [encoder setComputePipelineState:ctx->pipeline_add];
1215
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1216
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1217
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1218
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1219
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1220
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1221
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1222
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1223
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
1224
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
1225
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
1226
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1227
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1228
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1229
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1230
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1231
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1232
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1233
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1234
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1235
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1236
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1237
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1238
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1239
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
1240
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
1241
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1242
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1243
-
1244
- const int nth = MIN(1024, ne0);
1245
-
1246
- [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1247
- } break;
1248
- case WSP_GGML_OP_SCALE:
1249
- {
762
+ const int64_t ne00 = src0 ? src0->ne[0] : 0;
763
+ const int64_t ne01 = src0 ? src0->ne[1] : 0;
764
+ const int64_t ne02 = src0 ? src0->ne[2] : 0;
765
+ const int64_t ne03 = src0 ? src0->ne[3] : 0;
766
+
767
+ const uint64_t nb00 = src0 ? src0->nb[0] : 0;
768
+ const uint64_t nb01 = src0 ? src0->nb[1] : 0;
769
+ const uint64_t nb02 = src0 ? src0->nb[2] : 0;
770
+ const uint64_t nb03 = src0 ? src0->nb[3] : 0;
771
+
772
+ const int64_t ne10 = src1 ? src1->ne[0] : 0;
773
+ const int64_t ne11 = src1 ? src1->ne[1] : 0;
774
+ const int64_t ne12 = src1 ? src1->ne[2] : 0;
775
+ const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
776
+
777
+ const uint64_t nb10 = src1 ? src1->nb[0] : 0;
778
+ const uint64_t nb11 = src1 ? src1->nb[1] : 0;
779
+ const uint64_t nb12 = src1 ? src1->nb[2] : 0;
780
+ const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
781
+
782
+ const int64_t ne0 = dst ? dst->ne[0] : 0;
783
+ const int64_t ne1 = dst ? dst->ne[1] : 0;
784
+ const int64_t ne2 = dst ? dst->ne[2] : 0;
785
+ const int64_t ne3 = dst ? dst->ne[3] : 0;
786
+
787
+ const uint64_t nb0 = dst ? dst->nb[0] : 0;
788
+ const uint64_t nb1 = dst ? dst->nb[1] : 0;
789
+ const uint64_t nb2 = dst ? dst->nb[2] : 0;
790
+ const uint64_t nb3 = dst ? dst->nb[3] : 0;
791
+
792
+ const enum wsp_ggml_type src0t = src0 ? src0->type : WSP_GGML_TYPE_COUNT;
793
+ const enum wsp_ggml_type src1t = src1 ? src1->type : WSP_GGML_TYPE_COUNT;
794
+ const enum wsp_ggml_type dstt = dst ? dst->type : WSP_GGML_TYPE_COUNT;
795
+
796
+ id<MTLBuffer> id_src0 = src0 ? wsp_ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
797
+ id<MTLBuffer> id_src1 = src1 ? wsp_ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
798
+ id<MTLBuffer> id_dst = dst ? wsp_ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
799
+
800
+ //WSP_GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, wsp_ggml_op_name(dst->op));
801
+ //if (src0) {
802
+ // WSP_GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src0t), ne00, ne01, ne02,
803
+ // wsp_ggml_is_contiguous(src0), src0->name);
804
+ //}
805
+ //if (src1) {
806
+ // WSP_GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src1t), ne10, ne11, ne12,
807
+ // wsp_ggml_is_contiguous(src1), src1->name);
808
+ //}
809
+ //if (dst) {
810
+ // WSP_GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(dstt), ne0, ne1, ne2,
811
+ // dst->name);
812
+ //}
813
+
814
+ switch (dst->op) {
815
+ case WSP_GGML_OP_CONCAT:
816
+ {
817
+ const int64_t nb = ne00;
818
+
819
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
820
+
821
+ [encoder setComputePipelineState:pipeline];
822
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
823
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
824
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
825
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
826
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
827
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
828
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
829
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
830
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
831
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
832
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
833
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
834
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
835
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
836
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
837
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
838
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
839
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
840
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
841
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
842
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
843
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
844
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
845
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
846
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
847
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
848
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
849
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
850
+
851
+ const int nth = MIN(1024, ne0);
852
+
853
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
854
+ } break;
855
+ case WSP_GGML_OP_ADD:
856
+ case WSP_GGML_OP_MUL:
857
+ case WSP_GGML_OP_DIV:
858
+ {
859
+ const size_t offs = 0;
860
+
861
+ bool bcast_row = false;
862
+
863
+ int64_t nb = ne00;
864
+
865
+ id<MTLComputePipelineState> pipeline = nil;
866
+
867
+ if (wsp_ggml_nelements(src1) == ne10 && wsp_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
1250
868
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
1251
869
 
1252
- const float scale = *(const float *) src1->data;
870
+ // src1 is a row
871
+ WSP_GGML_ASSERT(ne11 == 1);
1253
872
 
1254
- int64_t n = wsp_ggml_nelements(dst);
873
+ nb = ne00 / 4;
874
+ switch (dst->op) {
875
+ case WSP_GGML_OP_ADD: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
876
+ case WSP_GGML_OP_MUL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
877
+ case WSP_GGML_OP_DIV: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
878
+ default: WSP_GGML_ASSERT(false);
879
+ }
1255
880
 
1256
- if (n % 4 == 0) {
1257
- n /= 4;
1258
- [encoder setComputePipelineState:ctx->pipeline_scale_4];
1259
- } else {
1260
- [encoder setComputePipelineState:ctx->pipeline_scale];
881
+ bcast_row = true;
882
+ } else {
883
+ switch (dst->op) {
884
+ case WSP_GGML_OP_ADD: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
885
+ case WSP_GGML_OP_MUL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
886
+ case WSP_GGML_OP_DIV: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
887
+ default: WSP_GGML_ASSERT(false);
1261
888
  }
889
+ }
1262
890
 
1263
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1264
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1265
- [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
891
+ [encoder setComputePipelineState:pipeline];
892
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
893
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
894
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
895
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
896
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
897
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
898
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
899
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
900
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
901
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
902
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
903
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
904
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
905
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
906
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
907
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
908
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
909
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
910
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
911
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
912
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
913
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
914
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
915
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
916
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
917
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
918
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
919
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
920
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
921
+
922
+ if (bcast_row) {
923
+ const int64_t n = wsp_ggml_nelements(dst)/4;
1266
924
 
1267
925
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1268
- } break;
1269
- case WSP_GGML_OP_UNARY:
1270
- switch (wsp_ggml_get_unary_op(gf->nodes[i])) {
1271
- case WSP_GGML_UNARY_OP_TANH:
1272
- {
1273
- [encoder setComputePipelineState:ctx->pipeline_tanh];
1274
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1275
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
926
+ } else {
927
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1276
928
 
1277
- const int64_t n = wsp_ggml_nelements(dst);
929
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
930
+ }
931
+ } break;
932
+ case WSP_GGML_OP_ACC:
933
+ {
934
+ WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_F32);
935
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
936
+ WSP_GGML_ASSERT(dstt == WSP_GGML_TYPE_F32);
1278
937
 
1279
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1280
- } break;
1281
- case WSP_GGML_UNARY_OP_RELU:
1282
- {
1283
- [encoder setComputePipelineState:ctx->pipeline_relu];
1284
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1285
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
938
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
939
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
1286
940
 
1287
- const int64_t n = wsp_ggml_nelements(dst);
941
+ const size_t pnb1 = ((int32_t *) dst->op_params)[0];
942
+ const size_t pnb2 = ((int32_t *) dst->op_params)[1];
943
+ const size_t pnb3 = ((int32_t *) dst->op_params)[2];
944
+ const size_t offs = ((int32_t *) dst->op_params)[3];
1288
945
 
1289
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1290
- } break;
1291
- case WSP_GGML_UNARY_OP_GELU:
1292
- {
1293
- [encoder setComputePipelineState:ctx->pipeline_gelu];
1294
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1295
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
946
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
1296
947
 
1297
- const int64_t n = wsp_ggml_nelements(dst);
1298
- WSP_GGML_ASSERT(n % 4 == 0);
948
+ if (!inplace) {
949
+ // run a separete kernel to cpy src->dst
950
+ // not sure how to avoid this
951
+ // TODO: make a simpler cpy_bytes kernel
1299
952
 
1300
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1301
- } break;
1302
- case WSP_GGML_UNARY_OP_GELU_QUICK:
1303
- {
1304
- [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
1305
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1306
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
953
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
1307
954
 
1308
- const int64_t n = wsp_ggml_nelements(dst);
1309
- WSP_GGML_ASSERT(n % 4 == 0);
955
+ [encoder setComputePipelineState:pipeline];
956
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
957
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
958
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
959
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
960
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
961
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
962
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
963
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
964
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
965
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
966
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
967
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
968
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
969
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
970
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
971
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
972
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
973
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1310
974
 
1311
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1312
- } break;
1313
- case WSP_GGML_UNARY_OP_SILU:
1314
- {
1315
- [encoder setComputePipelineState:ctx->pipeline_silu];
1316
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1317
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
975
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
1318
976
 
1319
- const int64_t n = wsp_ggml_nelements(dst);
1320
- WSP_GGML_ASSERT(n % 4 == 0);
977
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
978
+ }
1321
979
 
1322
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1323
- } break;
1324
- default:
1325
- {
1326
- WSP_GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
1327
- WSP_GGML_ASSERT(false);
1328
- }
1329
- } break;
1330
- case WSP_GGML_OP_SQR:
1331
- {
1332
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
980
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD].pipeline;
981
+
982
+ [encoder setComputePipelineState:pipeline];
983
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
984
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
985
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
986
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
987
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
988
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
989
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
990
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
991
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
992
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
993
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
994
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
995
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
996
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
997
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
998
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
999
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1000
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1001
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1002
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1003
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1004
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1005
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1006
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1007
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
1008
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
1009
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1010
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1011
+
1012
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
1013
+
1014
+ [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1015
+ } break;
1016
+ case WSP_GGML_OP_SCALE:
1017
+ {
1018
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
1019
+
1020
+ const float scale = *(const float *) dst->op_params;
1021
+
1022
+ int64_t n = wsp_ggml_nelements(dst);
1023
+
1024
+ id<MTLComputePipelineState> pipeline = nil;
1025
+
1026
+ if (n % 4 == 0) {
1027
+ n /= 4;
1028
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
1029
+ } else {
1030
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
1031
+ }
1333
1032
 
1334
- [encoder setComputePipelineState:ctx->pipeline_sqr];
1335
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1336
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1033
+ [encoder setComputePipelineState:pipeline];
1034
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1035
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1036
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
1337
1037
 
1338
- const int64_t n = wsp_ggml_nelements(dst);
1339
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1340
- } break;
1341
- case WSP_GGML_OP_SUM_ROWS:
1342
- {
1343
- WSP_GGML_ASSERT(src0->nb[0] == wsp_ggml_type_size(src0->type));
1038
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1039
+ } break;
1040
+ case WSP_GGML_OP_UNARY:
1041
+ switch (wsp_ggml_get_unary_op(gf->nodes[i])) {
1042
+ case WSP_GGML_UNARY_OP_TANH:
1043
+ {
1044
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_TANH].pipeline;
1344
1045
 
1345
- [encoder setComputePipelineState:ctx->pipeline_sum_rows];
1346
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1347
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1348
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1349
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1350
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1351
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1352
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1353
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1354
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1355
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1356
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1357
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1358
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1359
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1360
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1361
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1362
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1363
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1364
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1365
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
1366
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
1367
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
1368
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
1369
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
1370
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
1371
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
1372
-
1373
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1374
- } break;
1375
- case WSP_GGML_OP_SOFT_MAX:
1376
- {
1377
- int nth = 32; // SIMD width
1378
-
1379
- if (ne00%4 == 0) {
1380
- while (nth < ne00/4 && nth < 256) {
1381
- nth *= 2;
1382
- }
1383
- [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
1384
- } else {
1385
- while (nth < ne00 && nth < 1024) {
1386
- nth *= 2;
1387
- }
1388
- [encoder setComputePipelineState:ctx->pipeline_soft_max];
1389
- }
1046
+ [encoder setComputePipelineState:pipeline];
1047
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1048
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1390
1049
 
1391
- const float scale = ((float *) dst->op_params)[0];
1050
+ const int64_t n = wsp_ggml_nelements(dst);
1392
1051
 
1393
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1394
- if (id_src1) {
1395
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1396
- } else {
1397
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1398
- }
1399
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1400
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1401
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1402
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1403
- [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1404
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1405
-
1406
- [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1407
- } break;
1408
- case WSP_GGML_OP_DIAG_MASK_INF:
1409
- {
1410
- const int n_past = ((int32_t *)(dst->op_params))[0];
1411
-
1412
- if (ne00%8 == 0) {
1413
- [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
1414
- } else {
1415
- [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
1416
- }
1417
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1418
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1419
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1420
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1421
- [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
1052
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1053
+ } break;
1054
+ case WSP_GGML_UNARY_OP_RELU:
1055
+ {
1056
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_RELU].pipeline;
1422
1057
 
1423
- if (ne00%8 == 0) {
1424
- [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1425
- }
1426
- else {
1427
- [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1058
+ [encoder setComputePipelineState:pipeline];
1059
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1060
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1061
+
1062
+ const int64_t n = wsp_ggml_nelements(dst);
1063
+
1064
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1065
+ } break;
1066
+ case WSP_GGML_UNARY_OP_GELU:
1067
+ {
1068
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GELU].pipeline;
1069
+
1070
+ [encoder setComputePipelineState:pipeline];
1071
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1072
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1073
+
1074
+ const int64_t n = wsp_ggml_nelements(dst);
1075
+ WSP_GGML_ASSERT(n % 4 == 0);
1076
+
1077
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1078
+ } break;
1079
+ case WSP_GGML_UNARY_OP_GELU_QUICK:
1080
+ {
1081
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
1082
+
1083
+ [encoder setComputePipelineState:pipeline];
1084
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1085
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1086
+
1087
+ const int64_t n = wsp_ggml_nelements(dst);
1088
+ WSP_GGML_ASSERT(n % 4 == 0);
1089
+
1090
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1091
+ } break;
1092
+ case WSP_GGML_UNARY_OP_SILU:
1093
+ {
1094
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SILU].pipeline;
1095
+
1096
+ [encoder setComputePipelineState:pipeline];
1097
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1098
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1099
+
1100
+ const int64_t n = wsp_ggml_nelements(dst);
1101
+ WSP_GGML_ASSERT(n % 4 == 0);
1102
+
1103
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1104
+ } break;
1105
+ default:
1106
+ {
1107
+ WSP_GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
1108
+ WSP_GGML_ASSERT(false);
1428
1109
  }
1429
- } break;
1430
- case WSP_GGML_OP_MUL_MAT:
1431
- {
1432
- WSP_GGML_ASSERT(ne00 == ne10);
1110
+ } break;
1111
+ case WSP_GGML_OP_SQR:
1112
+ {
1113
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
1114
+
1115
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SQR].pipeline;
1116
+
1117
+ [encoder setComputePipelineState:pipeline];
1118
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1119
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1120
+
1121
+ const int64_t n = wsp_ggml_nelements(dst);
1122
+
1123
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1124
+ } break;
1125
+ case WSP_GGML_OP_SUM_ROWS:
1126
+ {
1127
+ WSP_GGML_ASSERT(src0->nb[0] == wsp_ggml_type_size(src0->type));
1128
+
1129
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
1130
+
1131
+ [encoder setComputePipelineState:pipeline];
1132
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1133
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1134
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1135
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1136
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1137
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1138
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1139
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1140
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1141
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1142
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1143
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1144
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1145
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1146
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1147
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1148
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1149
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1150
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1151
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
1152
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
1153
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
1154
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
1155
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
1156
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
1157
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
1158
+
1159
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1160
+ } break;
1161
+ case WSP_GGML_OP_SOFT_MAX:
1162
+ {
1163
+ int nth = 32; // SIMD width
1164
+
1165
+ id<MTLComputePipelineState> pipeline = nil;
1166
+
1167
+ if (ne00%4 == 0) {
1168
+ while (nth < ne00/4 && nth < 256) {
1169
+ nth *= 2;
1170
+ }
1171
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
1172
+ } else {
1173
+ while (nth < ne00 && nth < 1024) {
1174
+ nth *= 2;
1175
+ }
1176
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
1177
+ }
1433
1178
 
1434
- // TODO: assert that dim2 and dim3 are contiguous
1435
- WSP_GGML_ASSERT(ne12 % ne02 == 0);
1436
- WSP_GGML_ASSERT(ne13 % ne03 == 0);
1179
+ const float scale = ((float *) dst->op_params)[0];
1437
1180
 
1438
- const uint r2 = ne12/ne02;
1439
- const uint r3 = ne13/ne03;
1181
+ [encoder setComputePipelineState:pipeline];
1182
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1183
+ if (id_src1) {
1184
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1185
+ } else {
1186
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1187
+ }
1188
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1189
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1190
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1191
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1192
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1193
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1194
+
1195
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1196
+ } break;
1197
+ case WSP_GGML_OP_DIAG_MASK_INF:
1198
+ {
1199
+ const int n_past = ((int32_t *)(dst->op_params))[0];
1200
+
1201
+ id<MTLComputePipelineState> pipeline = nil;
1202
+
1203
+ if (ne00%8 == 0) {
1204
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
1205
+ } else {
1206
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
1207
+ }
1208
+
1209
+ [encoder setComputePipelineState:pipeline];
1210
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1211
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1212
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1213
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1214
+ [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
1215
+
1216
+ if (ne00%8 == 0) {
1217
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1218
+ }
1219
+ else {
1220
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1221
+ }
1222
+ } break;
1223
+ case WSP_GGML_OP_MUL_MAT:
1224
+ {
1225
+ WSP_GGML_ASSERT(ne00 == ne10);
1440
1226
 
1441
- // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1442
- // to the matrix-vector kernel
1443
- int ne11_mm_min = 1;
1227
+ // TODO: assert that dim2 and dim3 are contiguous
1228
+ WSP_GGML_ASSERT(ne12 % ne02 == 0);
1229
+ WSP_GGML_ASSERT(ne13 % ne03 == 0);
1230
+
1231
+ const uint r2 = ne12/ne02;
1232
+ const uint r3 = ne13/ne03;
1233
+
1234
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1235
+ // to the matrix-vector kernel
1236
+ int ne11_mm_min = 1;
1444
1237
 
1445
1238
  #if 0
1446
- // the numbers below are measured on M2 Ultra for 7B and 13B models
1447
- // these numbers do not translate to other devices or model sizes
1448
- // TODO: need to find a better approach
1449
- if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
1450
- switch (src0t) {
1451
- case WSP_GGML_TYPE_F16: ne11_mm_min = 2; break;
1452
- case WSP_GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1453
- case WSP_GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1454
- case WSP_GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1455
- case WSP_GGML_TYPE_Q4_0:
1456
- case WSP_GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1457
- case WSP_GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1458
- case WSP_GGML_TYPE_Q5_0: // not tested yet
1459
- case WSP_GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1460
- case WSP_GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1461
- case WSP_GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1462
- default: ne11_mm_min = 1; break;
1463
- }
1239
+ // the numbers below are measured on M2 Ultra for 7B and 13B models
1240
+ // these numbers do not translate to other devices or model sizes
1241
+ // TODO: need to find a better approach
1242
+ if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
1243
+ switch (src0t) {
1244
+ case WSP_GGML_TYPE_F16: ne11_mm_min = 2; break;
1245
+ case WSP_GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1246
+ case WSP_GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1247
+ case WSP_GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1248
+ case WSP_GGML_TYPE_Q4_0:
1249
+ case WSP_GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1250
+ case WSP_GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1251
+ case WSP_GGML_TYPE_Q5_0: // not tested yet
1252
+ case WSP_GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1253
+ case WSP_GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1254
+ case WSP_GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1255
+ default: ne11_mm_min = 1; break;
1464
1256
  }
1257
+ }
1465
1258
  #endif
1466
1259
 
1467
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1468
- // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1469
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1470
- !wsp_ggml_is_transposed(src0) &&
1471
- !wsp_ggml_is_transposed(src1) &&
1472
- src1t == WSP_GGML_TYPE_F32 &&
1473
- ne00 % 32 == 0 && ne00 >= 64 &&
1474
- (ne11 > ne11_mm_min || (wsp_ggml_is_quantized(src0t) && ne12 > 1))) {
1475
- //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1476
- switch (src0->type) {
1477
- case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
1478
- case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
1479
- case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
1480
- case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
1481
- case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
1482
- case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
1483
- case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
1484
- case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
1485
- case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
1486
- case WSP_GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
1487
- case WSP_GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
1488
- case WSP_GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
1489
- default: WSP_GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1490
- }
1491
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1492
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1493
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1494
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1495
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1496
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
1497
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
1498
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1499
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1500
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1501
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1502
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1503
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1504
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1505
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1506
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1507
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1508
- } else {
1509
- int nth0 = 32;
1510
- int nth1 = 1;
1511
- int nrows = 1;
1512
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1513
-
1514
- // use custom matrix x vector kernel
1515
- switch (src0t) {
1516
- case WSP_GGML_TYPE_F32:
1517
- {
1518
- WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
1519
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
1520
- nrows = 4;
1521
- } break;
1522
- case WSP_GGML_TYPE_F16:
1523
- {
1524
- nth0 = 32;
1525
- nth1 = 1;
1526
- if (src1t == WSP_GGML_TYPE_F32) {
1527
- if (ne11 * ne12 < 4) {
1528
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
1529
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1530
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
1531
- nrows = ne11;
1532
- } else {
1533
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
1534
- nrows = 4;
1535
- }
1260
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1261
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1262
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1263
+ !wsp_ggml_is_transposed(src0) &&
1264
+ !wsp_ggml_is_transposed(src1) &&
1265
+ src1t == WSP_GGML_TYPE_F32 &&
1266
+ ne00 % 32 == 0 && ne00 >= 64 &&
1267
+ (ne11 > ne11_mm_min || (wsp_ggml_is_quantized(src0t) && ne12 > 1))) {
1268
+ //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1269
+
1270
+ id<MTLComputePipelineState> pipeline = nil;
1271
+
1272
+ switch (src0->type) {
1273
+ case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
1274
+ case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
1275
+ case WSP_GGML_TYPE_Q4_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
1276
+ case WSP_GGML_TYPE_Q4_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
1277
+ case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
1278
+ case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
1279
+ case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
1280
+ case WSP_GGML_TYPE_Q2_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
1281
+ case WSP_GGML_TYPE_Q3_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
1282
+ case WSP_GGML_TYPE_Q4_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
1283
+ case WSP_GGML_TYPE_Q5_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
1284
+ case WSP_GGML_TYPE_Q6_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
1285
+ case WSP_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
1286
+ case WSP_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
1287
+ default: WSP_GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1288
+ }
1289
+
1290
+ [encoder setComputePipelineState:pipeline];
1291
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1292
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1293
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1294
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1295
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1296
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
1297
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
1298
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1299
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1300
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1301
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1302
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1303
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1304
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1305
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1306
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1307
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
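// NOTE (illustration only, not in the source): the dispatch above implies that each
// 128-thread threadgroup produces one 64x32 output tile (64 rows of src0 by 32 src1
// columns), staged through the 8192 bytes of threadgroup memory set just before it.
// For example, with ne01 = 384, ne11 = 100 and ne12*ne13 = 1:
//
//   threadgroups = MTLSizeMake((100 + 31)/32, (384 + 63)/64, 1)
//                = MTLSizeMake(4, 6, 1)            // 24 threadgroups
//   threads      = 24 * 128 = 3072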
1308
+ } else {
1309
+ int nth0 = 32;
1310
+ int nth1 = 1;
1311
+ int nrows = 1;
1312
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1313
+
1314
+ id<MTLComputePipelineState> pipeline = nil;
1315
+
1316
+ // use custom matrix x vector kernel
1317
+ switch (src0t) {
1318
+ case WSP_GGML_TYPE_F32:
1319
+ {
1320
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
1321
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
1322
+ nrows = 4;
1323
+ } break;
1324
+ case WSP_GGML_TYPE_F16:
1325
+ {
1326
+ nth0 = 32;
1327
+ nth1 = 1;
1328
+ if (src1t == WSP_GGML_TYPE_F32) {
1329
+ if (ne11 * ne12 < 4) {
1330
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
1331
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1332
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
1333
+ nrows = ne11;
1536
1334
  } else {
1537
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
1335
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
1538
1336
  nrows = 4;
1539
1337
  }
1540
- } break;
1541
- case WSP_GGML_TYPE_Q4_0:
1542
- {
1543
- nth0 = 8;
1544
- nth1 = 8;
1545
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
1546
- } break;
1547
- case WSP_GGML_TYPE_Q4_1:
1548
- {
1549
- nth0 = 8;
1550
- nth1 = 8;
1551
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
1552
- } break;
1553
- case WSP_GGML_TYPE_Q5_0:
1554
- {
1555
- nth0 = 8;
1556
- nth1 = 8;
1557
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
1558
- } break;
1559
- case WSP_GGML_TYPE_Q5_1:
1560
- {
1561
- nth0 = 8;
1562
- nth1 = 8;
1563
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
1564
- } break;
1565
- case WSP_GGML_TYPE_Q8_0:
1566
- {
1567
- nth0 = 8;
1568
- nth1 = 8;
1569
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
1570
- } break;
1571
- case WSP_GGML_TYPE_Q2_K:
1572
- {
1573
- nth0 = 2;
1574
- nth1 = 32;
1575
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
1576
- } break;
1577
- case WSP_GGML_TYPE_Q3_K:
1578
- {
1579
- nth0 = 2;
1580
- nth1 = 32;
1581
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
1582
- } break;
1583
- case WSP_GGML_TYPE_Q4_K:
1584
- {
1585
- nth0 = 4; //1;
1586
- nth1 = 8; //32;
1587
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
1588
- } break;
1589
- case WSP_GGML_TYPE_Q5_K:
1590
- {
1591
- nth0 = 2;
1592
- nth1 = 32;
1593
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
1594
- } break;
1595
- case WSP_GGML_TYPE_Q6_K:
1596
- {
1597
- nth0 = 2;
1598
- nth1 = 32;
1599
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
1600
- } break;
1601
- default:
1602
- {
1603
- WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1604
- WSP_GGML_ASSERT(false && "not implemented");
1338
+ } else {
1339
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
1340
+ nrows = 4;
1605
1341
  }
1606
- };
1342
+ } break;
1343
+ case WSP_GGML_TYPE_Q4_0:
1344
+ {
1345
+ nth0 = 8;
1346
+ nth1 = 8;
1347
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
1348
+ } break;
1349
+ case WSP_GGML_TYPE_Q4_1:
1350
+ {
1351
+ nth0 = 8;
1352
+ nth1 = 8;
1353
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
1354
+ } break;
1355
+ case WSP_GGML_TYPE_Q5_0:
1356
+ {
1357
+ nth0 = 8;
1358
+ nth1 = 8;
1359
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
1360
+ } break;
1361
+ case WSP_GGML_TYPE_Q5_1:
1362
+ {
1363
+ nth0 = 8;
1364
+ nth1 = 8;
1365
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
1366
+ } break;
1367
+ case WSP_GGML_TYPE_Q8_0:
1368
+ {
1369
+ nth0 = 8;
1370
+ nth1 = 8;
1371
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
1372
+ } break;
1373
+ case WSP_GGML_TYPE_Q2_K:
1374
+ {
1375
+ nth0 = 2;
1376
+ nth1 = 32;
1377
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
1378
+ } break;
1379
+ case WSP_GGML_TYPE_Q3_K:
1380
+ {
1381
+ nth0 = 2;
1382
+ nth1 = 32;
1383
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
1384
+ } break;
1385
+ case WSP_GGML_TYPE_Q4_K:
1386
+ {
1387
+ nth0 = 4; //1;
1388
+ nth1 = 8; //32;
1389
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
1390
+ } break;
1391
+ case WSP_GGML_TYPE_Q5_K:
1392
+ {
1393
+ nth0 = 2;
1394
+ nth1 = 32;
1395
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
1396
+ } break;
1397
+ case WSP_GGML_TYPE_Q6_K:
1398
+ {
1399
+ nth0 = 2;
1400
+ nth1 = 32;
1401
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
1402
+ } break;
1403
+ case WSP_GGML_TYPE_IQ2_XXS:
1404
+ {
1405
+ nth0 = 4;
1406
+ nth1 = 16;
1407
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
1408
+ } break;
1409
+ case WSP_GGML_TYPE_IQ2_XS:
1410
+ {
1411
+ nth0 = 4;
1412
+ nth1 = 16;
1413
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
1414
+ } break;
1415
+ default:
1416
+ {
1417
+ WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1418
+ WSP_GGML_ASSERT(false && "not implemented");
1419
+ }
1420
+ };
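// NOTE (sketch only, not in the source): for F16 src0 the switch above chooses
// between the f16 x f32 kernels with a small size heuristic; restated as
// pseudo-branches for readability:
//
//   if (src1t == F32) {
//       if (ne11*ne12 < 4)                                   -> MUL_MV_F16_F32_1ROW
//       else if (ne00 >= 128 && ne01 >= 8 && ne00 % 4 == 0)  -> MUL_MV_F16_F32_L4, nrows = ne11
//       else                                                 -> MUL_MV_F16_F32,    nrows = 4
//   } else                                                   -> MUL_MV_F16_F16,    nrows = 4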
1607
1421
 
1608
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1609
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1610
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1611
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1612
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1613
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1614
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1615
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1616
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1617
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
1618
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
1619
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
1620
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
1621
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
1622
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1623
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1624
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1625
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1626
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1627
-
1628
- if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 ||
1629
- src0t == WSP_GGML_TYPE_Q5_0 || src0t == WSP_GGML_TYPE_Q5_1 || src0t == WSP_GGML_TYPE_Q8_0 ||
1630
- src0t == WSP_GGML_TYPE_Q2_K) { // || src0t == WSP_GGML_TYPE_Q4_K) {
1631
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1632
- }
1633
- else if (src0t == WSP_GGML_TYPE_Q4_K) {
1634
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1635
- }
1636
- else if (src0t == WSP_GGML_TYPE_Q3_K) {
1637
- #ifdef WSP_GGML_QKK_64
1638
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1639
- #else
1640
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1641
- #endif
1642
- }
1643
- else if (src0t == WSP_GGML_TYPE_Q5_K) {
1644
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1645
- }
1646
- else if (src0t == WSP_GGML_TYPE_Q6_K) {
1647
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1648
- } else {
1649
- const int64_t ny = (ne11 + nrows - 1)/nrows;
1650
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1651
- }
1422
+ if (wsp_ggml_is_quantized(src0t)) {
1423
+ WSP_GGML_ASSERT(ne00 >= nth0*nth1);
1652
1424
  }
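// NOTE (illustration only, not in the source): the assert above pairs with the
// per-type threadgroup shapes chosen in the switch; nth0*nth1 appears to be the
// number of threads that cooperate on a group of rows, so quantized rows must be
// at least that wide. From the values above:
//
//   Q4_0/Q4_1/Q5_0/Q5_1/Q8_0 : 8 x 8  = 64  -> requires ne00 >= 64
//   Q2_K/Q3_K/Q5_K/Q6_K      : 2 x 32 = 64  -> requires ne00 >= 64
//   Q4_K                     : 4 x 8  = 32  -> requires ne00 >= 32
//   IQ2_XXS/IQ2_XS           : 4 x 16 = 64  -> requires ne00 >= 64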
1653
- } break;
1654
- case WSP_GGML_OP_MUL_MAT_ID:
1655
- {
1656
- //WSP_GGML_ASSERT(ne00 == ne10);
1657
- //WSP_GGML_ASSERT(ne03 == ne13);
1658
-
1659
- WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_I32);
1660
-
1661
- const int n_as = ((int32_t *) dst->op_params)[1];
1662
-
1663
- // TODO: make this more general
1664
- WSP_GGML_ASSERT(n_as <= 8);
1665
-
1666
- struct wsp_ggml_tensor * src2 = gf->nodes[i]->src[2];
1667
-
1668
- const int64_t ne20 = src2 ? src2->ne[0] : 0;
1669
- const int64_t ne21 = src2 ? src2->ne[1] : 0;
1670
- const int64_t ne22 = src2 ? src2->ne[2] : 0;
1671
- const int64_t ne23 = src2 ? src2->ne[3] : 0; WSP_GGML_UNUSED(ne23);
1672
-
1673
- const uint64_t nb20 = src2 ? src2->nb[0] : 0; WSP_GGML_UNUSED(nb20);
1674
- const uint64_t nb21 = src2 ? src2->nb[1] : 0;
1675
- const uint64_t nb22 = src2 ? src2->nb[2] : 0;
1676
- const uint64_t nb23 = src2 ? src2->nb[3] : 0; WSP_GGML_UNUSED(nb23);
1677
-
1678
- const enum wsp_ggml_type src2t = src2 ? src2->type : WSP_GGML_TYPE_COUNT; WSP_GGML_UNUSED(src2t);
1679
-
1680
- WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src2));
1681
- WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src1));
1682
-
1683
- WSP_GGML_ASSERT(ne20 % 32 == 0);
1684
- // !!!!!!!!! TODO: this assert is probably required but not sure!
1685
- //WSP_GGML_ASSERT(ne20 >= 64);
1686
- WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
1687
-
1688
- const uint r2 = ne12/ne22;
1689
- const uint r3 = ne13/ne23;
1690
-
1691
- // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1692
- // to the matrix-vector kernel
1693
- int ne11_mm_min = 1;
1694
-
1695
- const int idx = ((int32_t *) dst->op_params)[0];
1696
-
1697
- // batch size
1698
- WSP_GGML_ASSERT(ne01 == ne11);
1699
-
1700
- const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
1701
-
1702
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1703
- // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1704
- // !!!
1705
- // TODO: for now, always use mat-vec kernels until we figure out how to improve the
1706
- // indirect matrix multiplication
1707
- // !!!
1708
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
1709
- switch (src2->type) {
1710
- case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
1711
- case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
1712
- case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
1713
- case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
1714
- case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
1715
- case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
1716
- case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
1717
- case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
1718
- case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
1719
- case WSP_GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
1720
- case WSP_GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
1721
- case WSP_GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
1722
- default: WSP_GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1723
- }
1724
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1725
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1726
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1727
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1728
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1729
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
1730
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1731
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
1732
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1733
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
1734
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
1735
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
1736
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
1737
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1738
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
1739
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1740
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
1741
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
1742
- [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
1743
- // TODO: how to make this an array? read Metal docs
1744
- for (int j = 0; j < n_as; ++j) {
1745
- struct wsp_ggml_tensor * src_cur = dst->src[2 + j];
1746
-
1747
- size_t offs_src_cur = 0;
1748
- id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1749
-
1750
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
1751
- }
1752
-
1753
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1754
-
1755
- // TODO: processing one row at a time (ne11 -> 1) is not efficient
1756
- [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1757
- } else {
1758
- int nth0 = 32;
1759
- int nth1 = 1;
1760
- int nrows = 1;
1761
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1762
-
1763
- // use custom matrix x vector kernel
1764
- switch (src2t) {
1765
- case WSP_GGML_TYPE_F32:
1766
- {
1767
- WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
1768
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
1769
- } break;
1770
- case WSP_GGML_TYPE_F16:
1771
- {
1772
- WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
1773
- nth0 = 32;
1774
- nth1 = 1;
1775
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
1776
- } break;
1777
- case WSP_GGML_TYPE_Q4_0:
1778
- {
1779
- nth0 = 8;
1780
- nth1 = 8;
1781
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
1782
- } break;
1783
- case WSP_GGML_TYPE_Q4_1:
1784
- {
1785
- nth0 = 8;
1786
- nth1 = 8;
1787
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
1788
- } break;
1789
- case WSP_GGML_TYPE_Q5_0:
1790
- {
1791
- nth0 = 8;
1792
- nth1 = 8;
1793
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
1794
- } break;
1795
- case WSP_GGML_TYPE_Q5_1:
1796
- {
1797
- nth0 = 8;
1798
- nth1 = 8;
1799
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
1800
- } break;
1801
- case WSP_GGML_TYPE_Q8_0:
1802
- {
1803
- nth0 = 8;
1804
- nth1 = 8;
1805
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
1806
- } break;
1807
- case WSP_GGML_TYPE_Q2_K:
1808
- {
1809
- nth0 = 2;
1810
- nth1 = 32;
1811
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
1812
- } break;
1813
- case WSP_GGML_TYPE_Q3_K:
1814
- {
1815
- nth0 = 2;
1816
- nth1 = 32;
1817
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
1818
- } break;
1819
- case WSP_GGML_TYPE_Q4_K:
1820
- {
1821
- nth0 = 4; //1;
1822
- nth1 = 8; //32;
1823
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
1824
- } break;
1825
- case WSP_GGML_TYPE_Q5_K:
1826
- {
1827
- nth0 = 2;
1828
- nth1 = 32;
1829
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
1830
- } break;
1831
- case WSP_GGML_TYPE_Q6_K:
1832
- {
1833
- nth0 = 2;
1834
- nth1 = 32;
1835
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
1836
- } break;
1837
- default:
1838
- {
1839
- WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1840
- WSP_GGML_ASSERT(false && "not implemented");
1841
- }
1842
- };
1843
1425
 
1844
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1845
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1846
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1847
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1848
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1849
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1850
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
1851
- [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
1852
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
1853
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
1854
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1855
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
1856
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1857
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1858
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1859
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1860
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1861
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1862
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
1863
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1864
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
1865
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1866
- [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1867
- // TODO: how to make this an array? read Metal docs
1868
- for (int j = 0; j < n_as; ++j) {
1869
- struct wsp_ggml_tensor * src_cur = dst->src[2 + j];
1870
-
1871
- size_t offs_src_cur = 0;
1872
- id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1873
-
1874
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
1875
- }
1876
-
1877
- if (src2t == WSP_GGML_TYPE_Q4_0 || src2t == WSP_GGML_TYPE_Q4_1 ||
1878
- src2t == WSP_GGML_TYPE_Q5_0 || src2t == WSP_GGML_TYPE_Q5_1 || src2t == WSP_GGML_TYPE_Q8_0 ||
1879
- src2t == WSP_GGML_TYPE_Q2_K) { // || src2t == WSP_GGML_TYPE_Q4_K) {
1880
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1881
- }
1882
- else if (src2t == WSP_GGML_TYPE_Q4_K) {
1883
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1884
- }
1885
- else if (src2t == WSP_GGML_TYPE_Q3_K) {
1426
+ [encoder setComputePipelineState:pipeline];
1427
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1428
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1429
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1430
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1431
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1432
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1433
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1434
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1435
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1436
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
1437
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
1438
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
1439
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
1440
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
1441
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1442
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1443
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1444
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1445
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1446
+
1447
+ if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 ||
1448
+ src0t == WSP_GGML_TYPE_Q5_0 || src0t == WSP_GGML_TYPE_Q5_1 || src0t == WSP_GGML_TYPE_Q8_0 ||
1449
+ src0t == WSP_GGML_TYPE_Q2_K) { // || src0t == WSP_GGML_TYPE_Q4_K) {
1450
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1451
+ }
1452
+ else if (src0t == WSP_GGML_TYPE_IQ2_XXS || src0t == WSP_GGML_TYPE_IQ2_XS) {
1453
+ const int mem_size = src0t == WSP_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1454
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1455
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1456
+ }
1457
+ else if (src0t == WSP_GGML_TYPE_Q4_K) {
1458
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1459
+ }
1460
+ else if (src0t == WSP_GGML_TYPE_Q3_K) {
1886
1461
  #ifdef WSP_GGML_QKK_64
1887
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1462
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1888
1463
  #else
1889
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1464
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1890
1465
  #endif
1891
- }
1892
- else if (src2t == WSP_GGML_TYPE_Q5_K) {
1893
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1894
- }
1895
- else if (src2t == WSP_GGML_TYPE_Q6_K) {
1896
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1897
- } else {
1898
- const int64_t ny = (_ne1 + nrows - 1)/nrows;
1899
- [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1900
- }
1901
1466
  }
1902
- } break;
1903
- case WSP_GGML_OP_GET_ROWS:
1904
- {
1905
- switch (src0->type) {
1906
- case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
1907
- case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
1908
- case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
1909
- case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
1910
- case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
1911
- case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
1912
- case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
1913
- case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
1914
- case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
1915
- case WSP_GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
1916
- case WSP_GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
1917
- case WSP_GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
1918
- default: WSP_GGML_ASSERT(false && "not implemented");
1467
+ else if (src0t == WSP_GGML_TYPE_Q5_K) {
1468
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1919
1469
  }
1920
-
1921
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1922
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1923
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1924
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1925
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
1926
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
1927
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
1928
- [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
1929
- [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
1930
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
1931
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
1932
-
1933
- [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
1934
- } break;
1935
- case WSP_GGML_OP_RMS_NORM:
1936
- {
1937
- WSP_GGML_ASSERT(ne00 % 4 == 0);
1938
-
1939
- float eps;
1940
- memcpy(&eps, dst->op_params, sizeof(float));
1941
-
1942
- int nth = 32; // SIMD width
1943
-
1944
- while (nth < ne00/4 && nth < 1024) {
1945
- nth *= 2;
1470
+ else if (src0t == WSP_GGML_TYPE_Q6_K) {
1471
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1472
+ } else {
1473
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
1474
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1946
1475
  }
1476
+ }
1477
+ } break;
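// NOTE (illustration only, not in the source): in the matrix-vector branch above, the
// x dimension of the grid encodes how many src0 rows one threadgroup covers: 8 for
// Q4_0..Q8_0, Q2_K and the IQ2 types ((ne01 + 7)/8), 4 for Q4_K, Q5_K and Q3_K without
// WSP_GGML_QKK_64 ((ne01 + 3)/4), and 2 for Q6_K ((ne01 + 1)/2). The generic F32/F16
// path instead walks one src0 row per threadgroup in x and groups src1 rows by nrows
// in y (ny = (ne11 + nrows - 1)/nrows). The IQ2 kernels additionally reserve
// threadgroup memory, presumably for their shared lookup tables:
//
//   IQ2_XXS : 256*8 + 128 = 2176 bytes
//   IQ2_XS  : 512*8 + 128 = 4224 bytes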
1478
+ case WSP_GGML_OP_MUL_MAT_ID:
1479
+ {
1480
+ //WSP_GGML_ASSERT(ne00 == ne10);
1481
+ //WSP_GGML_ASSERT(ne03 == ne13);
1947
1482
 
1948
- [encoder setComputePipelineState:ctx->pipeline_rms_norm];
1949
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1950
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1951
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1952
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1953
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1954
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1955
-
1956
- const int64_t nrows = wsp_ggml_nrows(src0);
1957
-
1958
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1959
- } break;
1960
- case WSP_GGML_OP_GROUP_NORM:
1961
- {
1962
- WSP_GGML_ASSERT(ne00 % 4 == 0);
1963
-
1964
- //float eps;
1965
- //memcpy(&eps, dst->op_params, sizeof(float));
1966
-
1967
- const float eps = 1e-6f; // TODO: temporarily hardcoded
1968
-
1969
- const int32_t n_groups = ((int32_t *) dst->op_params)[0];
1970
-
1971
- int nth = 32; // SIMD width
1972
-
1973
- //while (nth < ne00/4 && nth < 1024) {
1974
- // nth *= 2;
1975
- //}
1976
-
1977
- [encoder setComputePipelineState:ctx->pipeline_group_norm];
1978
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1979
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1980
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1981
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1982
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1983
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
1984
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
1985
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
1986
- [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
1987
- [encoder setBytes:&eps length:sizeof( float) atIndex:9];
1988
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1989
-
1990
- [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1991
- } break;
1992
- case WSP_GGML_OP_NORM:
1993
- {
1994
- float eps;
1995
- memcpy(&eps, dst->op_params, sizeof(float));
1996
-
1997
- const int nth = MIN(256, ne00);
1998
-
1999
- [encoder setComputePipelineState:ctx->pipeline_norm];
2000
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2001
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2002
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2003
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
2004
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
2005
- [encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth*sizeof(float), 16) atIndex:0];
2006
-
2007
- const int64_t nrows = wsp_ggml_nrows(src0);
2008
-
2009
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2010
- } break;
2011
- case WSP_GGML_OP_ALIBI:
2012
- {
2013
- WSP_GGML_ASSERT((src0t == WSP_GGML_TYPE_F32));
2014
-
2015
- const int nth = MIN(1024, ne00);
2016
-
2017
- //const int n_past = ((int32_t *) dst->op_params)[0];
2018
- const int n_head = ((int32_t *) dst->op_params)[1];
2019
- float max_bias;
2020
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
2021
-
2022
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
2023
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
2024
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
1483
+ WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_I32);
2025
1484
 
2026
- [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
2027
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2028
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2029
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2030
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2031
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2032
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
2033
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
2034
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
2035
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
2036
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
2037
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
2038
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
2039
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
2040
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
2041
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
2042
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
2043
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
2044
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
2045
- [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
2046
- [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
2047
- [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
1485
+ const int n_as = ((int32_t *) dst->op_params)[1];
2048
1486
 
2049
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2050
- } break;
2051
- case WSP_GGML_OP_ROPE:
2052
- {
2053
- WSP_GGML_ASSERT(ne10 == ne02);
2054
-
2055
- const int nth = MIN(1024, ne00);
2056
-
2057
- const int n_past = ((int32_t *) dst->op_params)[0];
2058
- const int n_dims = ((int32_t *) dst->op_params)[1];
2059
- const int mode = ((int32_t *) dst->op_params)[2];
2060
- // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
2061
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
2062
-
2063
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
2064
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
2065
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
2066
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
2067
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
2068
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
2069
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
1487
+ // TODO: make this more general
1488
+ WSP_GGML_ASSERT(n_as <= 8);
2070
1489
 
2071
- switch (src0->type) {
2072
- case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
2073
- case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
2074
- default: WSP_GGML_ASSERT(false);
2075
- };
1490
+ // max size of the src1ids array in the kernel stack
1491
+ WSP_GGML_ASSERT(ne11 <= 512);
2076
1492
 
2077
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2078
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2079
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2080
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
2081
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
2082
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
2083
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
2084
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
2085
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2086
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2087
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2088
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
2089
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
2090
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
2091
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
2092
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
2093
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
2094
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
2095
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
2096
- [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
2097
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
2098
- [encoder setBytes:&mode length:sizeof( int) atIndex:21];
2099
- [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
2100
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
2101
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
2102
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
2103
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
2104
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
2105
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
1493
+ struct wsp_ggml_tensor * src2 = gf->nodes[i]->src[2];
2106
1494
 
2107
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2108
- } break;
2109
- case WSP_GGML_OP_IM2COL:
2110
- {
2111
- WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
2112
- WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
2113
- WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
1495
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
1496
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
1497
+ const int64_t ne22 = src2 ? src2->ne[2] : 0;
1498
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; WSP_GGML_UNUSED(ne23);
2114
1499
 
2115
- const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
2116
- const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
2117
- const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
2118
- const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
2119
- const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
2120
- const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
2121
- const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
1500
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; WSP_GGML_UNUSED(nb20);
1501
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
1502
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
1503
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0; WSP_GGML_UNUSED(nb23);
2122
1504
 
2123
- const int32_t N = src1->ne[is_2D ? 3 : 2];
2124
- const int32_t IC = src1->ne[is_2D ? 2 : 1];
2125
- const int32_t IH = is_2D ? src1->ne[1] : 1;
2126
- const int32_t IW = src1->ne[0];
1505
+ const enum wsp_ggml_type src2t = src2 ? src2->type : WSP_GGML_TYPE_COUNT; WSP_GGML_UNUSED(src2t);
2127
1506
 
2128
- const int32_t KH = is_2D ? src0->ne[1] : 1;
2129
- const int32_t KW = src0->ne[0];
1507
+ WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src2));
1508
+ WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src1));
2130
1509
 
2131
- const int32_t OH = is_2D ? dst->ne[2] : 1;
2132
- const int32_t OW = dst->ne[1];
1510
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
2133
1511
 
2134
- const int32_t CHW = IC * KH * KW;
1512
+ const uint r2 = ne12/ne22;
1513
+ const uint r3 = ne13/ne23;
2135
1514
 
2136
- const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
2137
- const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
1515
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1516
+ // to the matrix-vector kernel
1517
+ int ne11_mm_min = n_as;
2138
1518
 
2139
- switch (src0->type) {
2140
- case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(false && "not implemented"); break;
2141
- case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
2142
- default: WSP_GGML_ASSERT(false);
2143
- };
1519
+ const int idx = ((int32_t *) dst->op_params)[0];
2144
1520
 
2145
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
2146
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2147
- [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
2148
- [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
2149
- [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
2150
- [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
2151
- [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
2152
- [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
2153
- [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
2154
- [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
2155
- [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
2156
- [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
2157
- [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
2158
-
2159
- [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
2160
- } break;
2161
- case WSP_GGML_OP_UPSCALE:
2162
- {
2163
- WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
2164
-
2165
- const int sf = dst->op_params[0];
2166
-
2167
- [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
2168
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2169
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2170
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2171
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2172
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2173
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2174
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2175
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2176
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2177
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2178
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2179
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2180
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2181
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2182
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2183
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2184
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2185
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2186
- [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
2187
-
2188
- const int nth = MIN(1024, ne0);
2189
-
2190
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2191
- } break;
2192
- case WSP_GGML_OP_PAD:
2193
- {
2194
- WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
2195
-
2196
- [encoder setComputePipelineState:ctx->pipeline_pad_f32];
2197
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2198
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2199
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2200
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2201
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2202
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2203
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2204
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2205
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2206
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2207
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2208
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2209
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2210
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2211
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2212
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2213
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2214
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2215
-
2216
- const int nth = MIN(1024, ne0);
2217
-
2218
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2219
- } break;
2220
- case WSP_GGML_OP_ARGSORT:
2221
- {
2222
- WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
2223
- WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_I32);
2224
-
2225
- const int nrows = wsp_ggml_nrows(src0);
2226
-
2227
- enum wsp_ggml_sort_order order = (enum wsp_ggml_sort_order) dst->op_params[0];
2228
-
2229
- switch (order) {
2230
- case WSP_GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
2231
- case WSP_GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
2232
- default: WSP_GGML_ASSERT(false);
2233
- };
1521
+ // batch size
1522
+ WSP_GGML_ASSERT(ne01 == ne11);
2234
1523
 
2235
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2236
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2237
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1524
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1525
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1526
+ // !!!
1527
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
1528
+ // indirect matrix multiplication
1529
+ // !!!
1530
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1531
+ ne20 % 32 == 0 && ne20 >= 64 &&
1532
+ ne11 > ne11_mm_min) {
2238
1533
 
2239
- [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
2240
- } break;
2241
- case WSP_GGML_OP_LEAKY_RELU:
2242
- {
2243
- WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
1534
+ id<MTLComputePipelineState> pipeline = nil;
2244
1535
 
2245
- float slope;
2246
- memcpy(&slope, dst->op_params, sizeof(float));
1536
+ switch (src2->type) {
1537
+ case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
1538
+ case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
1539
+ case WSP_GGML_TYPE_Q4_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
1540
+ case WSP_GGML_TYPE_Q4_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
1541
+ case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
1542
+ case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
1543
+ case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
1544
+ case WSP_GGML_TYPE_Q2_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
1545
+ case WSP_GGML_TYPE_Q3_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
1546
+ case WSP_GGML_TYPE_Q4_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
1547
+ case WSP_GGML_TYPE_Q5_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
1548
+ case WSP_GGML_TYPE_Q6_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
1549
+ case WSP_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
1550
+ case WSP_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
1551
+ default: WSP_GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1552
+ }
2247
1553
 
2248
- [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
2249
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2250
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2251
- [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
1554
+ [encoder setComputePipelineState:pipeline];
1555
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1556
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1557
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1558
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1559
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1560
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
1561
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1562
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
1563
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1564
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
1565
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
1566
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
1567
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
1568
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1569
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
1570
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1571
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
1572
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
1573
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
1574
+ // TODO: how to make this an array? read Metal docs
1575
+ for (int j = 0; j < 8; ++j) {
1576
+ // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1577
+ struct wsp_ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1578
+
1579
+ size_t offs_src_cur = 0;
1580
+ id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1581
+
1582
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
1583
+ }
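// NOTE (illustration only, not in the source): as the in-code NOTE above says, the loop
// always binds all 8 expert-buffer slots (argument indices 19..26) and wraps with
// j % n_as so no slot is left unbound when fewer than 8 experts are present.
// For example, with n_as = 3 the bindings become:
//
//   slot 19 20 21 22 23 24 25 26
//   src   0  1  2  0  1  2  0  1     // dst->src[2 + (j % 3)]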
2252
1584
 
2253
- const int64_t n = wsp_ggml_nelements(dst);
1585
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
2254
1586
 
2255
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2256
- } break;
2257
- case WSP_GGML_OP_DUP:
2258
- case WSP_GGML_OP_CPY:
2259
- case WSP_GGML_OP_CONT:
2260
- {
2261
- WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0);
1587
+ [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1588
+ } else {
1589
+ int nth0 = 32;
1590
+ int nth1 = 1;
1591
+ int nrows = 1;
1592
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
2262
1593
 
2263
- int nth = MIN(1024, ne00/wsp_ggml_blck_size(src0->type));
1594
+ id<MTLComputePipelineState> pipeline = nil;
2264
1595
 
2265
- switch (src0t) {
1596
+ // use custom matrix x vector kernel
1597
+ switch (src2t) {
2266
1598
  case WSP_GGML_TYPE_F32:
2267
1599
  {
2268
- WSP_GGML_ASSERT(ne0 % wsp_ggml_blck_size(dst->type) == 0);
2269
-
2270
- switch (dstt) {
2271
- case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
2272
- case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
2273
- case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
2274
- case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
2275
- case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
2276
- //case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
2277
- //case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
2278
- default: WSP_GGML_ASSERT(false && "not implemented");
2279
- };
1600
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
1601
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
2280
1602
  } break;
2281
1603
  case WSP_GGML_TYPE_F16:
2282
1604
  {
2283
- switch (dstt) {
2284
- case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
2285
- case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
2286
- default: WSP_GGML_ASSERT(false && "not implemented");
2287
- };
1605
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
1606
+ nth0 = 32;
1607
+ nth1 = 1;
1608
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
1609
+ } break;
1610
+ case WSP_GGML_TYPE_Q4_0:
1611
+ {
1612
+ nth0 = 8;
1613
+ nth1 = 8;
1614
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
1615
+ } break;
1616
+ case WSP_GGML_TYPE_Q4_1:
1617
+ {
1618
+ nth0 = 8;
1619
+ nth1 = 8;
1620
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
1621
+ } break;
1622
+ case WSP_GGML_TYPE_Q5_0:
1623
+ {
1624
+ nth0 = 8;
1625
+ nth1 = 8;
1626
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
1627
+ } break;
1628
+ case WSP_GGML_TYPE_Q5_1:
1629
+ {
1630
+ nth0 = 8;
1631
+ nth1 = 8;
1632
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
1633
+ } break;
1634
+ case WSP_GGML_TYPE_Q8_0:
1635
+ {
1636
+ nth0 = 8;
1637
+ nth1 = 8;
1638
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
1639
+ } break;
1640
+ case WSP_GGML_TYPE_Q2_K:
1641
+ {
1642
+ nth0 = 2;
1643
+ nth1 = 32;
1644
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
1645
+ } break;
1646
+ case WSP_GGML_TYPE_Q3_K:
1647
+ {
1648
+ nth0 = 2;
1649
+ nth1 = 32;
1650
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
1651
+ } break;
1652
+ case WSP_GGML_TYPE_Q4_K:
1653
+ {
1654
+ nth0 = 4; //1;
1655
+ nth1 = 8; //32;
1656
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
1657
+ } break;
1658
+ case WSP_GGML_TYPE_Q5_K:
1659
+ {
1660
+ nth0 = 2;
1661
+ nth1 = 32;
1662
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
1663
+ } break;
1664
+ case WSP_GGML_TYPE_Q6_K:
1665
+ {
1666
+ nth0 = 2;
1667
+ nth1 = 32;
1668
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
1669
+ } break;
1670
+ case WSP_GGML_TYPE_IQ2_XXS:
1671
+ {
1672
+ nth0 = 4;
1673
+ nth1 = 16;
1674
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
2288
1675
  } break;
2289
- default: WSP_GGML_ASSERT(false && "not implemented");
1676
+ case WSP_GGML_TYPE_IQ2_XS:
1677
+ {
1678
+ nth0 = 4;
1679
+ nth1 = 16;
1680
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
1681
+ } break;
1682
+ default:
1683
+ {
1684
+ WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
1685
+ WSP_GGML_ASSERT(false && "not implemented");
1686
+ }
1687
+ };
1688
+
1689
+ if (wsp_ggml_is_quantized(src2t)) {
1690
+ WSP_GGML_ASSERT(ne20 >= nth0*nth1);
2290
1691
  }
2291
1692
 
2292
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2293
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2294
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2295
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2296
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2297
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
2298
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
2299
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
2300
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
2301
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
2302
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
2303
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
2304
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
2305
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
2306
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
2307
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
2308
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
2309
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1693
+ const int64_t _ne1 = 1; // kernels needs a reference in constant memory
2310
1694
 
2311
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2312
- } break;
2313
- default:
2314
- {
2315
- WSP_GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
2316
- WSP_GGML_ASSERT(false);
1695
+ [encoder setComputePipelineState:pipeline];
1696
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1697
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1698
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1699
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1700
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1701
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1702
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
1703
+ [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
1704
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
1705
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
1706
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1707
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
1708
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1709
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1710
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1711
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1712
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1713
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1714
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
1715
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1716
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
1717
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1718
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1719
+ // TODO: how to make this an array? read Metal docs
1720
+ for (int j = 0; j < 8; ++j) {
1721
+ // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1722
+ struct wsp_ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1723
+
1724
+ size_t offs_src_cur = 0;
1725
+ id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1726
+
1727
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
1728
+ }
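
Editorial note (not part of the package diff): the j % n_as wrapping in the loop above is what the NOTE refers to — all 8 expert-buffer argument slots get bound even when fewer than 8 experts exist, by reusing the available ones. A minimal C sketch of the index wrapping; n_as = 3 and the slot base 23 are illustrative values mirroring the code above.

    #include <stdio.h>

    int main(void) {
        const int n_as = 3; // number of experts actually present (illustrative)
        // Bind 8 slots regardless of n_as, wrapping with j % n_as so that no
        // kernel argument is left unbound when n_as < 8.
        for (int j = 0; j < 8; ++j) {
            printf("argument slot %d <- expert tensor src[%d]\n", 23 + j, 2 + (j % n_as));
        }
        return 0;
    }
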
1729
+
1730
+ if (src2t == WSP_GGML_TYPE_Q4_0 || src2t == WSP_GGML_TYPE_Q4_1 ||
1731
+ src2t == WSP_GGML_TYPE_Q5_0 || src2t == WSP_GGML_TYPE_Q5_1 || src2t == WSP_GGML_TYPE_Q8_0 ||
1732
+ src2t == WSP_GGML_TYPE_Q2_K) { // || src2t == WSP_GGML_TYPE_Q4_K) {
1733
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1734
+ }
1735
+ else if (src2t == WSP_GGML_TYPE_IQ2_XXS || src2t == WSP_GGML_TYPE_IQ2_XS) {
1736
+ const int mem_size = src2t == WSP_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1737
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1738
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1739
+ }
1740
+ else if (src2t == WSP_GGML_TYPE_Q4_K) {
1741
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1742
+ }
1743
+ else if (src2t == WSP_GGML_TYPE_Q3_K) {
1744
+ #ifdef WSP_GGML_QKK_64
1745
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1746
+ #else
1747
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1748
+ #endif
1749
+ }
1750
+ else if (src2t == WSP_GGML_TYPE_Q5_K) {
1751
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1752
+ }
1753
+ else if (src2t == WSP_GGML_TYPE_Q6_K) {
1754
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1755
+ } else {
1756
+ const int64_t ny = (_ne1 + nrows - 1)/nrows;
1757
+ [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1758
+ }
1759
+ }
1760
+ } break;
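
Editorial note (not part of the package diff): the dispatchThreadgroups calls above size the grid with expressions such as (ne21 + 7)/8 and (ne21 + 3)/4 — integer ceiling division, so the final, partially filled threadgroup is still launched when ne21 is not a multiple of the per-group row count. A small C sketch of the same arithmetic; nrows and rows_per_tg are illustrative names, not identifiers from this file.

    #include <stdio.h>
    #include <stdint.h>

    // ceil(a/b) in integer arithmetic, the same pattern as (ne21 + 7)/8,
    // (ne21 + 3)/4 and (ne21 + 1)/2 in the dispatch calls above.
    static int64_t ceil_div(int64_t a, int64_t b) {
        return (a + b - 1) / b;
    }

    int main(void) {
        const int64_t nrows       = 1000; // stand-in for ne21
        const int64_t rows_per_tg = 8;    // rows handled by one threadgroup
        printf("threadgroups = %lld\n", (long long) ceil_div(nrows, rows_per_tg)); // 125
        return 0;
    }
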
1761
+ case WSP_GGML_OP_GET_ROWS:
1762
+ {
1763
+ id<MTLComputePipelineState> pipeline = nil;
1764
+
1765
+ switch (src0->type) {
1766
+ case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
1767
+ case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
1768
+ case WSP_GGML_TYPE_Q4_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
1769
+ case WSP_GGML_TYPE_Q4_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
1770
+ case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
1771
+ case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
1772
+ case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
1773
+ case WSP_GGML_TYPE_Q2_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
1774
+ case WSP_GGML_TYPE_Q3_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
1775
+ case WSP_GGML_TYPE_Q4_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
1776
+ case WSP_GGML_TYPE_Q5_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
1777
+ case WSP_GGML_TYPE_Q6_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
1778
+ case WSP_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
1779
+ case WSP_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
1780
+ case WSP_GGML_TYPE_I32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
1781
+ default: WSP_GGML_ASSERT(false && "not implemented");
1782
+ }
1783
+
1784
+ [encoder setComputePipelineState:pipeline];
1785
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1786
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1787
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1788
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1789
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
1790
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
1791
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
1792
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
1793
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
1794
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
1795
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
1796
+
1797
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
1798
+ } break;
1799
+ case WSP_GGML_OP_RMS_NORM:
1800
+ {
1801
+ WSP_GGML_ASSERT(ne00 % 4 == 0);
1802
+
1803
+ float eps;
1804
+ memcpy(&eps, dst->op_params, sizeof(float));
1805
+
1806
+ int nth = 32; // SIMD width
1807
+
1808
+ while (nth < ne00/4 && nth < 1024) {
1809
+ nth *= 2;
2317
1810
  }
2318
- }
2319
- }
2320
1811
 
2321
- if (encoder != nil) {
2322
- [encoder endEncoding];
2323
- encoder = nil;
1812
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
1813
+
1814
+ [encoder setComputePipelineState:pipeline];
1815
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1816
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1817
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1818
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1819
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1820
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1821
+
1822
+ const int64_t nrows = wsp_ggml_nrows(src0);
1823
+
1824
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1825
+ } break;
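
Editorial note (not part of the package diff): the RMS_NORM path above picks its threadgroup width by starting at the SIMD width and doubling until the group covers ne00/4 elements, capped at 1024 threads. A standalone C sketch of that sizing loop with a few example values.

    #include <stdio.h>

    // Same sizing loop as the RMS_NORM path: start at the SIMD width (32) and
    // double until the threadgroup covers ne00/4 elements, capped at 1024.
    static int rms_norm_nth(int ne00) {
        int nth = 32;
        while (nth < ne00/4 && nth < 1024) {
            nth *= 2;
        }
        return nth;
    }

    int main(void) {
        printf("%d %d %d\n", rms_norm_nth(512), rms_norm_nth(4096), rms_norm_nth(1 << 20));
        // prints: 128 1024 1024
        return 0;
    }
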
1826
+ case WSP_GGML_OP_GROUP_NORM:
1827
+ {
1828
+ WSP_GGML_ASSERT(ne00 % 4 == 0);
1829
+
1830
+ //float eps;
1831
+ //memcpy(&eps, dst->op_params, sizeof(float));
1832
+
1833
+ const float eps = 1e-6f; // TODO: temporarily hardcoded
1834
+
1835
+ const int32_t n_groups = ((int32_t *) dst->op_params)[0];
1836
+
1837
+ int nth = 32; // SIMD width
1838
+
1839
+ //while (nth < ne00/4 && nth < 1024) {
1840
+ // nth *= 2;
1841
+ //}
1842
+
1843
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
1844
+
1845
+ [encoder setComputePipelineState:pipeline];
1846
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1847
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1848
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1849
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1850
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1851
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
1852
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
1853
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
1854
+ [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
1855
+ [encoder setBytes:&eps length:sizeof( float) atIndex:9];
1856
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1857
+
1858
+ [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1859
+ } break;
1860
+ case WSP_GGML_OP_NORM:
1861
+ {
1862
+ float eps;
1863
+ memcpy(&eps, dst->op_params, sizeof(float));
1864
+
1865
+ const int nth = MIN(256, ne00);
1866
+
1867
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_NORM].pipeline;
1868
+
1869
+ [encoder setComputePipelineState:pipeline];
1870
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1871
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1872
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1873
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1874
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1875
+ [encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth*sizeof(float), 16) atIndex:0];
1876
+
1877
+ const int64_t nrows = wsp_ggml_nrows(src0);
1878
+
1879
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1880
+ } break;
1881
+ case WSP_GGML_OP_ALIBI:
1882
+ {
1883
+ WSP_GGML_ASSERT((src0t == WSP_GGML_TYPE_F32));
1884
+
1885
+ const int nth = MIN(1024, ne00);
1886
+
1887
+ //const int n_past = ((int32_t *) dst->op_params)[0];
1888
+ const int n_head = ((int32_t *) dst->op_params)[1];
1889
+ float max_bias;
1890
+ memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
1891
+
1892
+ const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
1893
+ const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
1894
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
1895
+
1896
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
1897
+
1898
+ [encoder setComputePipelineState:pipeline];
1899
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1900
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1901
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1902
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1903
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1904
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1905
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1906
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1907
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1908
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1909
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1910
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1911
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1912
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1913
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1914
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1915
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1916
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1917
+ [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
1918
+ [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
1919
+ [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
1920
+
1921
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1922
+ } break;
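
Editorial note (not part of the package diff): the ALIBI path computes the two base slopes m0 and m1 on the host and passes them to the kernel. The sketch below reproduces that computation in plain C and then expands it into per-head slopes using the standard ALiBi recipe; the per-head formula is an assumption for illustration and is not shown in this diff.

    #include <math.h>
    #include <stdio.h>

    // Host-side slope bases, as computed in the ALIBI path above.
    static void alibi_bases(int n_head, float max_bias, int *n_floor, float *m0, float *m1) {
        *n_floor = 1 << (int) floor(log2(n_head));
        *m0 = powf(2.0f, -(max_bias) / *n_floor);
        *m1 = powf(2.0f, -(max_bias / 2.0f) / *n_floor);
    }

    int main(void) {
        int n_floor; float m0, m1;
        alibi_bases(12, 8.0f, &n_floor, &m0, &m1); // 12 heads -> n_floor = 8
        for (int h = 0; h < 12; ++h) {
            // standard ALiBi recipe (assumption, not shown in this diff)
            const float slope = h < n_floor ? powf(m0, (float)(h + 1))
                                            : powf(m1, (float)(2*(h - n_floor) + 1));
            printf("head %2d slope %.6f\n", h, slope);
        }
        return 0;
    }
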
1923
+ case WSP_GGML_OP_ROPE:
1924
+ {
1925
+ WSP_GGML_ASSERT(ne10 == ne02);
1926
+
1927
+ const int nth = MIN(1024, ne00);
1928
+
1929
+ const int n_past = ((int32_t *) dst->op_params)[0];
1930
+ const int n_dims = ((int32_t *) dst->op_params)[1];
1931
+ const int mode = ((int32_t *) dst->op_params)[2];
1932
+ // skip 3, n_ctx, used in GLM RoPE, not implemented in Metal
1933
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
1934
+
1935
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1936
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
1937
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
1938
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
1939
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
1940
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
1941
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
1942
+
1943
+ id<MTLComputePipelineState> pipeline = nil;
1944
+
1945
+ switch (src0->type) {
1946
+ case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
1947
+ case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
1948
+ default: WSP_GGML_ASSERT(false);
1949
+ };
1950
+
1951
+ [encoder setComputePipelineState:pipeline];
1952
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1953
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1954
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1955
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1956
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
1957
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
1958
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
1959
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
1960
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
1961
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
1962
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
1963
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
1964
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
1965
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
1966
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
1967
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
1968
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
1969
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
1970
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
1971
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
1972
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
1973
+ [encoder setBytes:&mode length:sizeof( int) atIndex:21];
1974
+ [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
1975
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
1976
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
1977
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
1978
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
1979
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
1980
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
1981
+
1982
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1983
+ } break;
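
Editorial note (not part of the package diff): the ROPE path reads its float parameters out of dst->op_params, an int32_t array, with memcpy rather than a pointer cast. That is a deliberate bit-copy — the float was stored into the slot as raw 32-bit data, and memcpy recovers it without strict-aliasing problems. A self-contained C sketch of the round trip; op_params here is a local stand-in, not the real tensor field.

    #include <stdio.h>
    #include <string.h>
    #include <stdint.h>

    int main(void) {
        int32_t op_params[16] = {0};          // local stand-in for dst->op_params

        const float freq_base_in = 10000.0f;
        memcpy(&op_params[5], &freq_base_in, sizeof(float)); // parameter stored as raw bits

        float freq_base;
        memcpy(&freq_base, &op_params[5], sizeof(float));    // read back bit-for-bit as a float
        printf("freq_base = %.1f\n", freq_base);             // 10000.0
        return 0;
    }
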
1984
+ case WSP_GGML_OP_IM2COL:
1985
+ {
1986
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
1987
+ WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
1988
+ WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
1989
+
1990
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
1991
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
1992
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
1993
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
1994
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
1995
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
1996
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
1997
+
1998
+ const int32_t N = src1->ne[is_2D ? 3 : 2];
1999
+ const int32_t IC = src1->ne[is_2D ? 2 : 1];
2000
+ const int32_t IH = is_2D ? src1->ne[1] : 1;
2001
+ const int32_t IW = src1->ne[0];
2002
+
2003
+ const int32_t KH = is_2D ? src0->ne[1] : 1;
2004
+ const int32_t KW = src0->ne[0];
2005
+
2006
+ const int32_t OH = is_2D ? dst->ne[2] : 1;
2007
+ const int32_t OW = dst->ne[1];
2008
+
2009
+ const int32_t CHW = IC * KH * KW;
2010
+
2011
+ const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
2012
+ const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
2013
+
2014
+ id<MTLComputePipelineState> pipeline = nil;
2015
+
2016
+ switch (src0->type) {
2017
+ case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(false && "not implemented"); break;
2018
+ case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
2019
+ default: WSP_GGML_ASSERT(false);
2020
+ };
2021
+
2022
+ [encoder setComputePipelineState:pipeline];
2023
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
2024
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2025
+ [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
2026
+ [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
2027
+ [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
2028
+ [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
2029
+ [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
2030
+ [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
2031
+ [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
2032
+ [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
2033
+ [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
2034
+ [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
2035
+ [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
2036
+
2037
+ [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
2038
+ } break;
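
Editorial note (not part of the package diff): in the IM2COL path, ofs0 and ofs1 are the src1 byte strides nb[...] divided by 4, i.e. converted into float-element offsets, since src1 is F32. A tiny C sketch of that conversion with a made-up stride.

    #include <stdio.h>
    #include <stdint.h>

    int main(void) {
        // Illustrative byte stride of one image plane in an F32 tensor.
        const uint64_t nb_bytes = 224ull * 224ull * sizeof(float);
        // The IM2COL path divides by 4 (sizeof(float)) to get an element offset.
        const int32_t ofs = (int32_t)(nb_bytes / 4);
        printf("byte stride %llu -> element offset %d\n",
               (unsigned long long) nb_bytes, ofs);
        return 0;
    }
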
2039
+ case WSP_GGML_OP_UPSCALE:
2040
+ {
2041
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
2042
+
2043
+ const int sf = dst->op_params[0];
2044
+
2045
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
2046
+
2047
+ [encoder setComputePipelineState:pipeline];
2048
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2049
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2050
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2051
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2052
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2053
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2054
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2055
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2056
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2057
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2058
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2059
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2060
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2061
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2062
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2063
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2064
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2065
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2066
+ [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
2067
+
2068
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
2069
+
2070
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2071
+ } break;
2072
+ case WSP_GGML_OP_PAD:
2073
+ {
2074
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
2075
+
2076
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
2077
+
2078
+ [encoder setComputePipelineState:pipeline];
2079
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2080
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2081
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2082
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2083
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2084
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2085
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2086
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2087
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2088
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2089
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2090
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2091
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2092
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2093
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2094
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2095
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2096
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2097
+
2098
+ const int nth = MIN(1024, ne0);
2099
+
2100
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2101
+ } break;
2102
+ case WSP_GGML_OP_ARGSORT:
2103
+ {
2104
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
2105
+ WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_I32);
2106
+
2107
+ const int nrows = wsp_ggml_nrows(src0);
2108
+
2109
+ enum wsp_ggml_sort_order order = (enum wsp_ggml_sort_order) dst->op_params[0];
2110
+
2111
+ id<MTLComputePipelineState> pipeline = nil;
2112
+
2113
+ switch (order) {
2114
+ case WSP_GGML_SORT_ASC: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
2115
+ case WSP_GGML_SORT_DESC: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
2116
+ default: WSP_GGML_ASSERT(false);
2117
+ };
2118
+
2119
+ [encoder setComputePipelineState:pipeline];
2120
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2121
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2122
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2123
+
2124
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
2125
+ } break;
2126
+ case WSP_GGML_OP_LEAKY_RELU:
2127
+ {
2128
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
2129
+
2130
+ float slope;
2131
+ memcpy(&slope, dst->op_params, sizeof(float));
2132
+
2133
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
2134
+
2135
+ [encoder setComputePipelineState:pipeline];
2136
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2137
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2138
+ [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
2139
+
2140
+ const int64_t n = wsp_ggml_nelements(dst);
2141
+
2142
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2143
+ } break;
2144
+ case WSP_GGML_OP_DUP:
2145
+ case WSP_GGML_OP_CPY:
2146
+ case WSP_GGML_OP_CONT:
2147
+ {
2148
+ WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0);
2149
+
2150
+ int nth = MIN(1024, ne00/wsp_ggml_blck_size(src0->type));
2151
+
2152
+ id<MTLComputePipelineState> pipeline = nil;
2153
+
2154
+ switch (src0t) {
2155
+ case WSP_GGML_TYPE_F32:
2156
+ {
2157
+ WSP_GGML_ASSERT(ne0 % wsp_ggml_blck_size(dst->type) == 0);
2158
+
2159
+ switch (dstt) {
2160
+ case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
2161
+ case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
2162
+ case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
2163
+ case WSP_GGML_TYPE_Q4_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
2164
+ case WSP_GGML_TYPE_Q4_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
2165
+ //case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
2166
+ //case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
2167
+ default: WSP_GGML_ASSERT(false && "not implemented");
2168
+ };
2169
+ } break;
2170
+ case WSP_GGML_TYPE_F16:
2171
+ {
2172
+ switch (dstt) {
2173
+ case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
2174
+ case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
2175
+ default: WSP_GGML_ASSERT(false && "not implemented");
2176
+ };
2177
+ } break;
2178
+ default: WSP_GGML_ASSERT(false && "not implemented");
2179
+ }
2180
+
2181
+ [encoder setComputePipelineState:pipeline];
2182
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2183
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2184
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2185
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2186
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2187
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
2188
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
2189
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
2190
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
2191
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
2192
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
2193
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
2194
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
2195
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
2196
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
2197
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
2198
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
2199
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
2200
+
2201
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2202
+ } break;
2203
+ default:
2204
+ {
2205
+ WSP_GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
2206
+ WSP_GGML_ASSERT(false);
2207
+ }
2324
2208
  }
2325
2209
 
2326
- [command_buffer commit];
2327
- });
2328
- }
2210
+ #ifndef WSP_GGML_METAL_NDEBUG
2211
+ [encoder popDebugGroup];
2212
+ #endif
2213
+ }
2214
+
2215
+ [encoder endEncoding];
2329
2216
 
2330
- // wait for all threads to finish
2331
- dispatch_barrier_sync(ctx->d_queue, ^{});
2217
+ [command_buffer commit];
2218
+ });
2332
2219
 
2333
- // check status of command buffers
2220
+ // Wait for completion and check status of each command buffer
2334
2221
  // needed to detect if the device ran out-of-memory for example (#1881)
2335
- for (int i = 0; i < n_cb; i++) {
2336
- [ctx->command_buffers[i] waitUntilCompleted];
2337
2222
 
2338
- MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
2223
+ for (int i = 0; i < n_cb; ++i) {
2224
+ id<MTLCommandBuffer> command_buffer = command_buffers[i];
2225
+ [command_buffer waitUntilCompleted];
2226
+
2227
+ MTLCommandBufferStatus status = [command_buffer status];
2339
2228
  if (status != MTLCommandBufferStatusCompleted) {
2340
2229
  WSP_GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
2341
- WSP_GGML_ASSERT(false);
2230
+ return false;
2342
2231
  }
2343
2232
  }
2344
2233
 
2345
- }
2234
+ return true;
2346
2235
  }
2347
2236
 
2348
2237
  ////////////////////////////////////////////////////////////////////////////////
2349
2238
 
2350
2239
  // backend interface
2351
2240
 
2241
+ // default buffer
2352
2242
  static id<MTLDevice> g_backend_device = nil;
2353
2243
  static int g_backend_device_ref_count = 0;
2354
2244
 
@@ -2372,64 +2262,98 @@ static void wsp_ggml_backend_metal_free_device(void) {
2372
2262
  }
2373
2263
  }
2374
2264
 
2375
- static void * wsp_ggml_backend_metal_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
2376
- struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context;
2265
+ WSP_GGML_CALL static const char * wsp_ggml_backend_metal_buffer_get_name(wsp_ggml_backend_buffer_t buffer) {
2266
+ return "Metal";
2377
2267
 
2378
- return ctx->data;
2268
+ UNUSED(buffer);
2379
2269
  }
2380
2270
 
2381
- static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
2271
+ WSP_GGML_CALL static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
2382
2272
  struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context;
2383
2273
 
2384
2274
  wsp_ggml_backend_metal_free_device();
2385
2275
 
2386
- free(ctx->data);
2387
- free(ctx);
2276
+ if (ctx->owned) {
2277
+ free(ctx->all_data);
2278
+ }
2388
2279
 
2389
- UNUSED(buffer);
2280
+ free(ctx);
2390
2281
  }
2391
2282
 
2392
- static void wsp_ggml_backend_metal_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
2393
- WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
2394
- WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
2283
+ WSP_GGML_CALL static void * wsp_ggml_backend_metal_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
2284
+ struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context;
2285
+
2286
+ return ctx->all_data;
2287
+ }
2395
2288
 
2289
+ WSP_GGML_CALL static void wsp_ggml_backend_metal_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
2396
2290
  memcpy((char *)tensor->data + offset, data, size);
2397
2291
 
2398
2292
  UNUSED(buffer);
2399
2293
  }
2400
2294
 
2401
- static void wsp_ggml_backend_metal_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
2402
- WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
2403
- WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
2404
-
2295
+ WSP_GGML_CALL static void wsp_ggml_backend_metal_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
2405
2296
  memcpy(data, (const char *)tensor->data + offset, size);
2406
2297
 
2407
2298
  UNUSED(buffer);
2408
2299
  }
2409
2300
 
2410
- static void wsp_ggml_backend_metal_buffer_cpy_tensor_from(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
2411
- wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src));
2301
+ WSP_GGML_CALL static bool wsp_ggml_backend_metal_buffer_cpy_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
2302
+ if (wsp_ggml_backend_buffer_is_host(src->buffer)) {
2303
+ memcpy(dst->data, src->data, wsp_ggml_nbytes(src));
2304
+ return true;
2305
+ }
2306
+ return false;
2412
2307
 
2413
2308
  UNUSED(buffer);
2414
2309
  }
2415
2310
 
2416
- static void wsp_ggml_backend_metal_buffer_cpy_tensor_to(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
2417
- wsp_ggml_backend_tensor_set(dst, src->data, 0, wsp_ggml_nbytes(src));
2311
+ WSP_GGML_CALL static void wsp_ggml_backend_metal_buffer_clear(wsp_ggml_backend_buffer_t buffer, uint8_t value) {
2312
+ struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context;
2418
2313
 
2419
- UNUSED(buffer);
2314
+ memset(ctx->all_data, value, ctx->all_size);
2420
2315
  }
2421
2316
 
2422
- static struct wsp_ggml_backend_buffer_i metal_backend_buffer_i = {
2317
+ static struct wsp_ggml_backend_buffer_i wsp_ggml_backend_metal_buffer_i = {
2318
+ /* .get_name = */ wsp_ggml_backend_metal_buffer_get_name,
2423
2319
  /* .free_buffer = */ wsp_ggml_backend_metal_buffer_free_buffer,
2424
2320
  /* .get_base = */ wsp_ggml_backend_metal_buffer_get_base,
2425
2321
  /* .init_tensor = */ NULL,
2426
2322
  /* .set_tensor = */ wsp_ggml_backend_metal_buffer_set_tensor,
2427
2323
  /* .get_tensor = */ wsp_ggml_backend_metal_buffer_get_tensor,
2428
- /* .cpy_tensor_from = */ wsp_ggml_backend_metal_buffer_cpy_tensor_from,
2429
- /* .cpy_tensor_to = */ wsp_ggml_backend_metal_buffer_cpy_tensor_to,
2324
+ /* .cpy_tensor = */ wsp_ggml_backend_metal_buffer_cpy_tensor,
2325
+ /* .clear = */ wsp_ggml_backend_metal_buffer_clear,
2326
+ /* .reset = */ NULL,
2430
2327
  };
2431
2328
 
2432
- static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
2329
+ // default buffer type
2330
+
2331
+ WSP_GGML_CALL static const char * wsp_ggml_backend_metal_buffer_type_get_name(wsp_ggml_backend_buffer_type_t buft) {
2332
+ return "Metal";
2333
+
2334
+ UNUSED(buft);
2335
+ }
2336
+
2337
+ static void wsp_ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
2338
+ #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
2339
+ if (@available(macOS 10.12, iOS 16.0, *)) {
2340
+ WSP_GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
2341
+ device.currentAllocatedSize / 1024.0 / 1024.0,
2342
+ device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
2343
+
2344
+ if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
2345
+ WSP_GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
2346
+ } else {
2347
+ WSP_GGML_METAL_LOG_INFO("\n");
2348
+ }
2349
+ } else {
2350
+ WSP_GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
2351
+ }
2352
+ #endif
2353
+ UNUSED(device);
2354
+ }
2355
+
2356
+ WSP_GGML_CALL static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
2433
2357
  struct wsp_ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct wsp_ggml_backend_metal_buffer_context));
2434
2358
 
2435
2359
  const size_t size_page = sysconf(_SC_PAGESIZE);
@@ -2439,33 +2363,59 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer
2439
2363
  size_aligned += (size_page - (size_aligned % size_page));
2440
2364
  }
2441
2365
 
2442
- ctx->data = wsp_ggml_metal_host_malloc(size);
2443
- ctx->metal = [wsp_ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
2366
+ id<MTLDevice> device = wsp_ggml_backend_metal_get_device();
2367
+
2368
+ ctx->all_data = wsp_ggml_metal_host_malloc(size_aligned);
2369
+ ctx->all_size = size_aligned;
2370
+ ctx->owned = true;
2371
+ ctx->n_buffers = 1;
2372
+
2373
+ ctx->buffers[0].data = ctx->all_data;
2374
+ ctx->buffers[0].size = size;
2375
+ ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
2444
2376
  length:size_aligned
2445
2377
  options:MTLResourceStorageModeShared
2446
2378
  deallocator:nil];
2447
2379
 
2448
- return wsp_ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
2380
+ if (ctx->buffers[0].metal == nil) {
2381
+ WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
2382
+ free(ctx);
2383
+ wsp_ggml_backend_metal_free_device();
2384
+ return NULL;
2385
+ }
2386
+
2387
+ WSP_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
2388
+ wsp_ggml_backend_metal_log_allocated_size(device);
2389
+
2390
+ return wsp_ggml_backend_buffer_init(buft, wsp_ggml_backend_metal_buffer_i, ctx, size);
2449
2391
  }
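
Editorial note (not part of the package diff): the allocation path above rounds the requested size up to a whole number of pages (sysconf(_SC_PAGESIZE)) before wrapping the memory with newBufferWithBytesNoCopy, which expects page-aligned, page-sized regions for shared no-copy buffers. A minimal C sketch of the rounding.

    #include <stdio.h>
    #include <unistd.h>

    // Round a requested size up to a whole number of pages, as the allocator
    // above does before creating the shared Metal buffer.
    static size_t page_align(size_t size) {
        const size_t size_page = (size_t) sysconf(_SC_PAGESIZE);
        size_t size_aligned = size;
        if ((size_aligned % size_page) != 0) {
            size_aligned += size_page - (size_aligned % size_page);
        }
        return size_aligned;
    }

    int main(void) {
        printf("%zu -> %zu\n", (size_t) 100000, page_align(100000));
        return 0;
    }
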
2450
2392
 
2451
- static size_t wsp_ggml_backend_metal_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
2393
+ WSP_GGML_CALL static size_t wsp_ggml_backend_metal_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
2452
2394
  return 32;
2453
2395
  UNUSED(buft);
2454
2396
  }
2455
2397
 
2456
- static bool wsp_ggml_backend_metal_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) {
2398
+ WSP_GGML_CALL static bool wsp_ggml_backend_metal_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) {
2457
2399
  return wsp_ggml_backend_is_metal(backend) || wsp_ggml_backend_is_cpu(backend);
2458
2400
 
2459
- WSP_GGML_UNUSED(buft);
2401
+ UNUSED(buft);
2402
+ }
2403
+
2404
+ WSP_GGML_CALL static bool wsp_ggml_backend_metal_buffer_type_is_host(wsp_ggml_backend_buffer_type_t buft) {
2405
+ return true;
2406
+
2407
+ UNUSED(buft);
2460
2408
  }
2461
2409
 
2462
- wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void) {
2410
+ WSP_GGML_CALL wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void) {
2463
2411
  static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_buffer_type_metal = {
2464
2412
  /* .iface = */ {
2413
+ /* .get_name = */ wsp_ggml_backend_metal_buffer_type_get_name,
2465
2414
  /* .alloc_buffer = */ wsp_ggml_backend_metal_buffer_type_alloc_buffer,
2466
2415
  /* .get_alignment = */ wsp_ggml_backend_metal_buffer_type_get_alignment,
2467
2416
  /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes
2468
2417
  /* .supports_backend = */ wsp_ggml_backend_metal_buffer_type_supports_backend,
2418
+ /* .is_host = */ wsp_ggml_backend_metal_buffer_type_is_host,
2469
2419
  },
2470
2420
  /* .context = */ NULL,
2471
2421
  };
@@ -2473,67 +2423,134 @@ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void) {
2473
2423
  return &wsp_ggml_backend_buffer_type_metal;
2474
2424
  }
2475
2425
 
2476
- static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) {
2426
+ // buffer from ptr
2427
+
2428
+ WSP_GGML_CALL wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
2429
+ struct wsp_ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct wsp_ggml_backend_metal_buffer_context));
2430
+
2431
+ ctx->all_data = data;
2432
+ ctx->all_size = size;
2433
+ ctx->owned = false;
2434
+ ctx->n_buffers = 0;
2435
+
2436
+ const size_t size_page = sysconf(_SC_PAGESIZE);
2437
+
2438
+ // page-align the data ptr
2439
+ {
2440
+ const uintptr_t offs = (uintptr_t) data % size_page;
2441
+ data = (void *) ((char *) data - offs);
2442
+ size += offs;
2443
+ }
2444
+
2445
+ size_t size_aligned = size;
2446
+ if ((size_aligned % size_page) != 0) {
2447
+ size_aligned += (size_page - (size_aligned % size_page));
2448
+ }
2449
+
2450
+ id<MTLDevice> device = wsp_ggml_backend_metal_get_device();
2451
+
2452
+ // the buffer fits into the max buffer size allowed by the device
2453
+ if (size_aligned <= device.maxBufferLength) {
2454
+ ctx->buffers[ctx->n_buffers].data = data;
2455
+ ctx->buffers[ctx->n_buffers].size = size;
2456
+
2457
+ ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
2458
+
2459
+ if (ctx->buffers[ctx->n_buffers].metal == nil) {
2460
+ WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
2461
+ return false;
2462
+ }
2463
+
2464
+ WSP_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
2465
+
2466
+ ++ctx->n_buffers;
2467
+ } else {
2468
+ // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
2469
+ // one of the views
2470
+ const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round max_size up to a page boundary and add one extra page, just in case
2471
+ const size_t size_step = device.maxBufferLength - size_ovlp;
2472
+ const size_t size_view = device.maxBufferLength;
2473
+
2474
+ for (size_t i = 0; i < size; i += size_step) {
2475
+ const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
2476
+
2477
+ ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
2478
+ ctx->buffers[ctx->n_buffers].size = size_step_aligned;
2479
+
2480
+ ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
2481
+
2482
+ if (ctx->buffers[ctx->n_buffers].metal == nil) {
2483
+ WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
2484
+ return false;
2485
+ }
2486
+
2487
+ WSP_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i);
2488
+ if (i + size_step < size) {
2489
+ WSP_GGML_METAL_LOG_INFO("\n");
2490
+ }
2491
+
2492
+ ++ctx->n_buffers;
2493
+ }
2494
+ }
2495
+
2496
+ wsp_ggml_backend_metal_log_allocated_size(device);
2497
+
2498
+ return wsp_ggml_backend_buffer_init(wsp_ggml_backend_metal_buffer_type(), wsp_ggml_backend_metal_buffer_i, ctx, size);
2499
+ }
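
Editorial note (not part of the package diff): when the region passed to wsp_ggml_backend_metal_buffer_from_ptr is larger than device.maxBufferLength, the code above splits it into several Metal buffer views that overlap by a little more than max_size, so any single tensor is guaranteed to lie entirely inside at least one view. The C sketch below reproduces only the offset/size arithmetic; page size, buffer limit and sizes are made-up numbers for illustration.

    #include <stdio.h>
    #include <stddef.h>

    int main(void) {
        // All numbers are illustrative; the real values come from the device
        // and the mapped model file.
        const size_t size_page  = 16384;               // host page size
        const size_t max_buffer = 8  * size_page;      // stand-in for device.maxBufferLength
        const size_t max_size   = 3  * size_page - 64; // largest single tensor
        const size_t size       = 40 * size_page;      // total page-aligned region

        // Overlap: max_size rounded up to a page, plus one extra page.
        const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page;
        const size_t size_step = max_buffer - size_ovlp;
        const size_t size_view = max_buffer;

        for (size_t i = 0; i < size; i += size_step) {
            const size_t view_size = (i + size_view <= size) ? size_view : (size - i);
            printf("view @ offset %8zu, size %8zu\n", i, view_size);
        }
        return 0;
    }
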
2500
+
2501
+ // backend
2502
+
2503
+ WSP_GGML_CALL static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) {
2477
2504
  return "Metal";
2478
2505
 
2479
2506
  UNUSED(backend);
2480
2507
  }
2481
2508
 
2482
- static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) {
2509
+ WSP_GGML_CALL static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) {
2483
2510
  struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
2484
2511
  wsp_ggml_metal_free(ctx);
2485
2512
  free(backend);
2486
2513
  }
2487
2514
 
2488
- static void wsp_ggml_backend_metal_synchronize(wsp_ggml_backend_t backend) {
2489
- UNUSED(backend);
2490
- }
2491
-
2492
- static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_get_default_buffer_type(wsp_ggml_backend_t backend) {
2515
+ WSP_GGML_CALL static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_get_default_buffer_type(wsp_ggml_backend_t backend) {
2493
2516
  return wsp_ggml_backend_metal_buffer_type();
2494
2517
 
2495
2518
  UNUSED(backend);
2496
2519
  }
2497
2520
 
2498
- static void wsp_ggml_backend_metal_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
2521
+ WSP_GGML_CALL static bool wsp_ggml_backend_metal_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
2499
2522
  struct wsp_ggml_metal_context * metal_ctx = (struct wsp_ggml_metal_context *)backend->context;
2500
2523
 
2501
- wsp_ggml_metal_graph_compute(metal_ctx, cgraph);
2524
+ return wsp_ggml_metal_graph_compute(metal_ctx, cgraph);
2502
2525
  }
2503
2526
 
2504
- static bool wsp_ggml_backend_metal_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
2505
- return wsp_ggml_metal_supports_op(op);
2527
+ WSP_GGML_CALL static bool wsp_ggml_backend_metal_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
2528
+ struct wsp_ggml_metal_context * metal_ctx = (struct wsp_ggml_metal_context *)backend->context;
2506
2529
 
2507
- UNUSED(backend);
2530
+ return wsp_ggml_metal_supports_op(metal_ctx, op);
2508
2531
  }
2509
2532
 
2510
- static struct wsp_ggml_backend_i metal_backend_i = {
2533
+ static struct wsp_ggml_backend_i wsp_ggml_backend_metal_i = {
2511
2534
  /* .get_name = */ wsp_ggml_backend_metal_name,
2512
2535
  /* .free = */ wsp_ggml_backend_metal_free,
2513
2536
  /* .get_default_buffer_type = */ wsp_ggml_backend_metal_get_default_buffer_type,
2514
2537
  /* .set_tensor_async = */ NULL,
2515
2538
  /* .get_tensor_async = */ NULL,
2516
- /* .cpy_tensor_from_async = */ NULL,
2517
- /* .cpy_tensor_to_async = */ NULL,
2518
- /* .synchronize = */ wsp_ggml_backend_metal_synchronize,
2519
- /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
2539
+ /* .cpy_tensor_async = */ NULL,
2540
+ /* .synchronize = */ NULL,
2541
+ /* .graph_plan_create = */ NULL,
2520
2542
  /* .graph_plan_free = */ NULL,
2521
2543
  /* .graph_plan_compute = */ NULL,
2522
2544
  /* .graph_compute = */ wsp_ggml_backend_metal_graph_compute,
2523
2545
  /* .supports_op = */ wsp_ggml_backend_metal_supports_op,
2524
2546
  };
2525
2547
 
2526
- // TODO: make a common log callback for all backends in ggml-backend
2527
- static void wsp_ggml_backend_log_callback(enum wsp_ggml_log_level level, const char * msg, void * user_data) {
2528
- fprintf(stderr, "%s", msg);
2529
-
2530
- UNUSED(level);
2531
- UNUSED(user_data);
2548
+ void wsp_ggml_backend_metal_log_set_callback(wsp_ggml_log_callback log_callback, void * user_data) {
2549
+ wsp_ggml_metal_log_callback = log_callback;
2550
+ wsp_ggml_metal_log_user_data = user_data;
2532
2551
  }
2533
2552
 
2534
2553
  wsp_ggml_backend_t wsp_ggml_backend_metal_init(void) {
2535
- wsp_ggml_metal_log_set_callback(wsp_ggml_backend_log_callback, NULL);
2536
-
2537
2554
  struct wsp_ggml_metal_context * ctx = wsp_ggml_metal_init(WSP_GGML_DEFAULT_N_THREADS);
2538
2555
 
2539
2556
  if (ctx == NULL) {
@@ -2543,7 +2560,7 @@ wsp_ggml_backend_t wsp_ggml_backend_metal_init(void) {
2543
2560
  wsp_ggml_backend_t metal_backend = malloc(sizeof(struct wsp_ggml_backend));
2544
2561
 
2545
2562
  *metal_backend = (struct wsp_ggml_backend) {
2546
- /* .interface = */ metal_backend_i,
2563
+ /* .interface = */ wsp_ggml_backend_metal_i,
2547
2564
  /* .context = */ ctx,
2548
2565
  };
2549
2566
 
@@ -2551,7 +2568,7 @@ wsp_ggml_backend_t wsp_ggml_backend_metal_init(void) {
2551
2568
  }
2552
2569
 
2553
2570
  bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend) {
2554
- return backend->iface.get_name == wsp_ggml_backend_metal_name;
2571
+ return backend && backend->iface.get_name == wsp_ggml_backend_metal_name;
2555
2572
  }
2556
2573
 
2557
2574
  void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb) {
@@ -2559,7 +2576,7 @@ void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb) {
2559
2576
 
2560
2577
  struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
2561
2578
 
2562
- wsp_ggml_metal_set_n_cb(ctx, n_cb);
2579
+ ctx->n_cb = MIN(n_cb, WSP_GGML_METAL_MAX_BUFFERS);
2563
2580
  }
2564
2581
 
2565
2582
  bool wsp_ggml_backend_metal_supports_family(wsp_ggml_backend_t backend, int family) {
@@ -2570,9 +2587,9 @@ bool wsp_ggml_backend_metal_supports_family(wsp_ggml_backend_t backend, int fami
2570
2587
  return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
2571
2588
  }
2572
2589
 
2573
- wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
2590
+ WSP_GGML_CALL wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
2574
2591
 
2575
- wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data) {
2592
+ WSP_GGML_CALL wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data) {
2576
2593
  return wsp_ggml_backend_metal_init();
2577
2594
 
2578
2595
  WSP_GGML_UNUSED(params);