whisper.rn 0.4.0-rc.6 → 0.4.0-rc.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/java/com/rnwhisper/RNWhisper.java +5 -5
- package/cpp/coreml/whisper-encoder.mm +1 -1
- package/cpp/ggml-alloc.c +41 -11
- package/cpp/ggml-alloc.h +3 -1
- package/cpp/ggml-backend-impl.h +38 -34
- package/cpp/ggml-backend.c +630 -269
- package/cpp/ggml-backend.h +58 -30
- package/cpp/ggml-impl.h +3 -0
- package/cpp/ggml-metal-whisper.metal +1253 -341
- package/cpp/ggml-metal.h +6 -54
- package/cpp/ggml-metal.m +2004 -1987
- package/cpp/ggml-quants.c +2230 -421
- package/cpp/ggml-quants.h +39 -1
- package/cpp/ggml.c +735 -265
- package/cpp/ggml.h +94 -43
- package/cpp/rn-whisper.cpp +1 -0
- package/cpp/whisper.cpp +118 -86
- package/ios/RNWhisperContext.mm +4 -2
- package/lib/commonjs/version.json +1 -1
- package/lib/module/version.json +1 -1
- package/package.json +1 -1
- package/src/version.json +1 -1
package/cpp/ggml-metal.m
CHANGED
|
@@ -24,7 +24,7 @@
|
|
|
24
24
|
|
|
25
25
|
#define UNUSED(x) (void)(x)
|
|
26
26
|
|
|
27
|
-
#define
|
|
27
|
+
#define WSP_GGML_METAL_MAX_KERNELS 256
|
|
28
28
|
|
|
29
29
|
struct wsp_ggml_metal_buffer {
|
|
30
30
|
const char * name;
|
|
@@ -35,6 +35,134 @@ struct wsp_ggml_metal_buffer {
|
|
|
35
35
|
id<MTLBuffer> metal;
|
|
36
36
|
};
|
|
37
37
|
|
|
38
|
+
struct wsp_ggml_metal_kernel {
|
|
39
|
+
id<MTLFunction> function;
|
|
40
|
+
id<MTLComputePipelineState> pipeline;
|
|
41
|
+
};
|
|
42
|
+
|
|
43
|
+
enum wsp_ggml_metal_kernel_type {
|
|
44
|
+
WSP_GGML_METAL_KERNEL_TYPE_ADD,
|
|
45
|
+
WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW,
|
|
46
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL,
|
|
47
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW,
|
|
48
|
+
WSP_GGML_METAL_KERNEL_TYPE_DIV,
|
|
49
|
+
WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
|
50
|
+
WSP_GGML_METAL_KERNEL_TYPE_SCALE,
|
|
51
|
+
WSP_GGML_METAL_KERNEL_TYPE_SCALE_4,
|
|
52
|
+
WSP_GGML_METAL_KERNEL_TYPE_TANH,
|
|
53
|
+
WSP_GGML_METAL_KERNEL_TYPE_RELU,
|
|
54
|
+
WSP_GGML_METAL_KERNEL_TYPE_GELU,
|
|
55
|
+
WSP_GGML_METAL_KERNEL_TYPE_GELU_QUICK,
|
|
56
|
+
WSP_GGML_METAL_KERNEL_TYPE_SILU,
|
|
57
|
+
WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX,
|
|
58
|
+
WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
|
|
59
|
+
WSP_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
|
|
60
|
+
WSP_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
|
|
61
|
+
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
|
|
62
|
+
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
|
|
63
|
+
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
|
|
64
|
+
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
|
|
65
|
+
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
|
|
66
|
+
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
|
|
67
|
+
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
|
|
68
|
+
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
|
|
69
|
+
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
|
|
70
|
+
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
|
|
71
|
+
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
|
|
72
|
+
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
|
|
73
|
+
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
|
|
74
|
+
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
|
|
75
|
+
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
|
76
|
+
WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
|
77
|
+
WSP_GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
|
78
|
+
WSP_GGML_METAL_KERNEL_TYPE_NORM,
|
|
79
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
|
80
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
|
81
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
|
82
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
|
83
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
|
|
84
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
|
|
85
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
|
|
86
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
|
87
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
|
|
88
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
|
|
89
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
|
|
90
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
|
|
91
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
|
|
92
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
|
|
93
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
|
|
94
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
|
|
95
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
|
|
96
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
|
97
|
+
//WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
|
98
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
|
99
|
+
//WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
|
|
100
|
+
//WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
|
|
101
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
|
|
102
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
|
|
103
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
|
|
104
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
|
|
105
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
|
|
106
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
|
|
107
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
|
|
108
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
|
|
109
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,
|
|
110
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
|
|
111
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
|
|
112
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
|
|
113
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
|
114
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
|
115
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
|
|
116
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
|
|
117
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
|
|
118
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
|
|
119
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
|
|
120
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
|
|
121
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
|
|
122
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
|
|
123
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,
|
|
124
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
|
|
125
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
|
|
126
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
|
|
127
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
|
128
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
|
129
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
|
130
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
|
|
131
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
|
|
132
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
|
|
133
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
|
|
134
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
|
|
135
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
|
|
136
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
|
|
137
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
|
|
138
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
|
|
139
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
|
|
140
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
|
|
141
|
+
WSP_GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
|
142
|
+
WSP_GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
|
143
|
+
WSP_GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
|
144
|
+
WSP_GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
|
145
|
+
WSP_GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
|
146
|
+
WSP_GGML_METAL_KERNEL_TYPE_PAD_F32,
|
|
147
|
+
WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
|
148
|
+
WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
|
|
149
|
+
WSP_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
|
150
|
+
WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
|
151
|
+
WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
|
152
|
+
WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
|
153
|
+
WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
|
154
|
+
WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
|
155
|
+
//WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
|
|
156
|
+
//WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
|
|
157
|
+
WSP_GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
|
158
|
+
WSP_GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
|
159
|
+
WSP_GGML_METAL_KERNEL_TYPE_CONCAT,
|
|
160
|
+
WSP_GGML_METAL_KERNEL_TYPE_SQR,
|
|
161
|
+
WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
162
|
+
|
|
163
|
+
WSP_GGML_METAL_KERNEL_TYPE_COUNT
|
|
164
|
+
};
|
|
165
|
+
|
|
38
166
|
struct wsp_ggml_metal_context {
|
|
39
167
|
int n_cb;
|
|
40
168
|
|
|
@@ -42,131 +170,15 @@ struct wsp_ggml_metal_context {
|
|
|
42
170
|
id<MTLCommandQueue> queue;
|
|
43
171
|
id<MTLLibrary> library;
|
|
44
172
|
|
|
45
|
-
id<MTLCommandBuffer> command_buffers [WSP_GGML_METAL_MAX_COMMAND_BUFFERS];
|
|
46
|
-
id<MTLComputeCommandEncoder> command_encoders[WSP_GGML_METAL_MAX_COMMAND_BUFFERS];
|
|
47
|
-
|
|
48
173
|
dispatch_queue_t d_queue;
|
|
49
174
|
|
|
50
175
|
int n_buffers;
|
|
51
176
|
struct wsp_ggml_metal_buffer buffers[WSP_GGML_METAL_MAX_BUFFERS];
|
|
52
177
|
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
#define WSP_GGML_METAL_DECL_KERNEL(name) \
|
|
58
|
-
id<MTLFunction> function_##name; \
|
|
59
|
-
id<MTLComputePipelineState> pipeline_##name
|
|
60
|
-
|
|
61
|
-
WSP_GGML_METAL_DECL_KERNEL(add);
|
|
62
|
-
WSP_GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
|
|
63
|
-
WSP_GGML_METAL_DECL_KERNEL(mul);
|
|
64
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
|
|
65
|
-
WSP_GGML_METAL_DECL_KERNEL(div);
|
|
66
|
-
WSP_GGML_METAL_DECL_KERNEL(div_row);
|
|
67
|
-
WSP_GGML_METAL_DECL_KERNEL(scale);
|
|
68
|
-
WSP_GGML_METAL_DECL_KERNEL(scale_4);
|
|
69
|
-
WSP_GGML_METAL_DECL_KERNEL(tanh);
|
|
70
|
-
WSP_GGML_METAL_DECL_KERNEL(relu);
|
|
71
|
-
WSP_GGML_METAL_DECL_KERNEL(gelu);
|
|
72
|
-
WSP_GGML_METAL_DECL_KERNEL(gelu_quick);
|
|
73
|
-
WSP_GGML_METAL_DECL_KERNEL(silu);
|
|
74
|
-
WSP_GGML_METAL_DECL_KERNEL(soft_max);
|
|
75
|
-
WSP_GGML_METAL_DECL_KERNEL(soft_max_4);
|
|
76
|
-
WSP_GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
|
77
|
-
WSP_GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
|
|
78
|
-
WSP_GGML_METAL_DECL_KERNEL(get_rows_f32);
|
|
79
|
-
WSP_GGML_METAL_DECL_KERNEL(get_rows_f16);
|
|
80
|
-
WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
|
81
|
-
WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
|
82
|
-
WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_0);
|
|
83
|
-
WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_1);
|
|
84
|
-
WSP_GGML_METAL_DECL_KERNEL(get_rows_q8_0);
|
|
85
|
-
WSP_GGML_METAL_DECL_KERNEL(get_rows_q2_K);
|
|
86
|
-
WSP_GGML_METAL_DECL_KERNEL(get_rows_q3_K);
|
|
87
|
-
WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_K);
|
|
88
|
-
WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_K);
|
|
89
|
-
WSP_GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
|
90
|
-
WSP_GGML_METAL_DECL_KERNEL(rms_norm);
|
|
91
|
-
WSP_GGML_METAL_DECL_KERNEL(group_norm);
|
|
92
|
-
WSP_GGML_METAL_DECL_KERNEL(norm);
|
|
93
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
|
|
94
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
|
|
95
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
|
|
96
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
|
|
97
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
|
|
98
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
|
|
99
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
|
|
100
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
|
|
101
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
|
|
102
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
|
|
103
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
|
|
104
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
|
|
105
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
|
|
106
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
|
|
107
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
|
|
108
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
|
|
109
|
-
//WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
|
|
110
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
|
|
111
|
-
//WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
|
|
112
|
-
//WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
|
|
113
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
|
|
114
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
|
|
115
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
|
|
116
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
|
|
117
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
|
|
118
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
|
|
119
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
|
|
120
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
|
|
121
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
|
|
122
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
|
|
123
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
|
|
124
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
|
125
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
|
126
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
|
|
127
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
|
|
128
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
|
|
129
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
|
|
130
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
|
|
131
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
|
|
132
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
|
|
133
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
|
|
134
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
|
|
135
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
|
|
136
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
|
|
137
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
|
|
138
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
|
|
139
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
|
|
140
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
|
|
141
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
|
|
142
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
|
|
143
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
|
|
144
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
|
|
145
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
|
|
146
|
-
WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
|
|
147
|
-
WSP_GGML_METAL_DECL_KERNEL(rope_f32);
|
|
148
|
-
WSP_GGML_METAL_DECL_KERNEL(rope_f16);
|
|
149
|
-
WSP_GGML_METAL_DECL_KERNEL(alibi_f32);
|
|
150
|
-
WSP_GGML_METAL_DECL_KERNEL(im2col_f16);
|
|
151
|
-
WSP_GGML_METAL_DECL_KERNEL(upscale_f32);
|
|
152
|
-
WSP_GGML_METAL_DECL_KERNEL(pad_f32);
|
|
153
|
-
WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
|
|
154
|
-
WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
|
|
155
|
-
WSP_GGML_METAL_DECL_KERNEL(leaky_relu_f32);
|
|
156
|
-
WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
|
157
|
-
WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
|
158
|
-
WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
|
|
159
|
-
WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
|
|
160
|
-
WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
|
|
161
|
-
//WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
|
|
162
|
-
//WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
|
|
163
|
-
WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
|
164
|
-
WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f32);
|
|
165
|
-
WSP_GGML_METAL_DECL_KERNEL(concat);
|
|
166
|
-
WSP_GGML_METAL_DECL_KERNEL(sqr);
|
|
167
|
-
WSP_GGML_METAL_DECL_KERNEL(sum_rows);
|
|
168
|
-
|
|
169
|
-
#undef WSP_GGML_METAL_DECL_KERNEL
|
|
178
|
+
struct wsp_ggml_metal_kernel kernels[WSP_GGML_METAL_MAX_KERNELS];
|
|
179
|
+
|
|
180
|
+
bool support_simdgroup_reduction;
|
|
181
|
+
bool support_simdgroup_mm;
|
|
170
182
|
};
|
|
171
183
|
|
|
172
184
|
// MSL code
|
|
@@ -180,14 +192,16 @@ struct wsp_ggml_metal_context {
|
|
|
180
192
|
@implementation WSPGGMLMetalClass
|
|
181
193
|
@end
|
|
182
194
|
|
|
183
|
-
|
|
184
|
-
|
|
195
|
+
static void wsp_ggml_metal_default_log_callback(enum wsp_ggml_log_level level, const char * msg, void * user_data) {
|
|
196
|
+
fprintf(stderr, "%s", msg);
|
|
185
197
|
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
wsp_ggml_metal_log_user_data = user_data;
|
|
198
|
+
UNUSED(level);
|
|
199
|
+
UNUSED(user_data);
|
|
189
200
|
}
|
|
190
201
|
|
|
202
|
+
wsp_ggml_log_callback wsp_ggml_metal_log_callback = wsp_ggml_metal_default_log_callback;
|
|
203
|
+
void * wsp_ggml_metal_log_user_data = NULL;
|
|
204
|
+
|
|
191
205
|
WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
|
|
192
206
|
static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char * format, ...){
|
|
193
207
|
if (wsp_ggml_metal_log_callback != NULL) {
|
|
@@ -210,24 +224,33 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char * forma
|
|
|
210
224
|
}
|
|
211
225
|
}
|
|
212
226
|
|
|
213
|
-
|
|
214
|
-
|
|
227
|
+
static void * wsp_ggml_metal_host_malloc(size_t n) {
|
|
228
|
+
void * data = NULL;
|
|
229
|
+
const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
|
|
230
|
+
if (result != 0) {
|
|
231
|
+
WSP_GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
|
|
232
|
+
return NULL;
|
|
233
|
+
}
|
|
234
|
+
|
|
235
|
+
return data;
|
|
236
|
+
}
|
|
215
237
|
|
|
216
|
-
|
|
217
|
-
|
|
238
|
+
static struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
|
|
239
|
+
WSP_GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
|
|
218
240
|
|
|
219
|
-
#if TARGET_OS_OSX
|
|
241
|
+
#if TARGET_OS_OSX && !WSP_GGML_METAL_NDEBUG
|
|
220
242
|
// Show all the Metal device instances in the system
|
|
221
243
|
NSArray * devices = MTLCopyAllDevices();
|
|
222
|
-
for (device in devices) {
|
|
223
|
-
s = [device name];
|
|
244
|
+
for (id<MTLDevice> device in devices) {
|
|
245
|
+
NSString * s = [device name];
|
|
224
246
|
WSP_GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
|
|
225
247
|
}
|
|
248
|
+
[devices release]; // since it was created by a *Copy* C method
|
|
226
249
|
#endif
|
|
227
250
|
|
|
228
251
|
// Pick and show default Metal device
|
|
229
|
-
device = MTLCreateSystemDefaultDevice();
|
|
230
|
-
s = [device name];
|
|
252
|
+
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
|
253
|
+
NSString * s = [device name];
|
|
231
254
|
WSP_GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
|
|
232
255
|
|
|
233
256
|
// Configure context
|
|
@@ -236,7 +259,6 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
|
|
|
236
259
|
ctx->n_cb = MIN(n_cb, WSP_GGML_METAL_MAX_BUFFERS);
|
|
237
260
|
ctx->queue = [ctx->device newCommandQueue];
|
|
238
261
|
ctx->n_buffers = 0;
|
|
239
|
-
ctx->concur_list_len = 0;
|
|
240
262
|
|
|
241
263
|
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
|
242
264
|
|
|
@@ -251,6 +273,7 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
|
|
|
251
273
|
NSError * error = nil;
|
|
252
274
|
NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
|
|
253
275
|
if (libPath != nil) {
|
|
276
|
+
// pre-compiled library found
|
|
254
277
|
NSURL * libURL = [NSURL fileURLWithPath:libPath];
|
|
255
278
|
WSP_GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
|
|
256
279
|
ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
|
|
@@ -278,12 +301,21 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
|
|
|
278
301
|
return NULL;
|
|
279
302
|
}
|
|
280
303
|
|
|
281
|
-
|
|
304
|
+
@autoreleasepool {
|
|
305
|
+
// dictionary of preprocessor macros
|
|
306
|
+
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
|
307
|
+
|
|
282
308
|
#ifdef WSP_GGML_QKK_64
|
|
283
|
-
|
|
284
|
-
options.preprocessorMacros = @{ @"QK_K" : @(64) };
|
|
309
|
+
prep[@"QK_K"] = @(64);
|
|
285
310
|
#endif
|
|
286
|
-
|
|
311
|
+
|
|
312
|
+
MTLCompileOptions* options = [MTLCompileOptions new];
|
|
313
|
+
options.preprocessorMacros = prep;
|
|
314
|
+
|
|
315
|
+
//[options setFastMathEnabled:false];
|
|
316
|
+
|
|
317
|
+
ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
|
|
318
|
+
}
|
|
287
319
|
}
|
|
288
320
|
|
|
289
321
|
if (error) {
|
|
@@ -292,22 +324,51 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
|
|
|
292
324
|
}
|
|
293
325
|
}
|
|
294
326
|
|
|
295
|
-
#if TARGET_OS_OSX
|
|
296
327
|
// print MTL GPU family:
|
|
297
328
|
WSP_GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
|
|
298
329
|
|
|
330
|
+
const NSInteger MTLGPUFamilyMetal3 = 5001;
|
|
331
|
+
|
|
299
332
|
// determine max supported GPU family
|
|
300
333
|
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
|
301
334
|
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
335
|
+
{
|
|
336
|
+
for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
|
|
337
|
+
if ([ctx->device supportsFamily:i]) {
|
|
338
|
+
WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
|
|
339
|
+
break;
|
|
340
|
+
}
|
|
341
|
+
}
|
|
342
|
+
|
|
343
|
+
for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
|
|
344
|
+
if ([ctx->device supportsFamily:i]) {
|
|
345
|
+
WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
|
|
346
|
+
break;
|
|
347
|
+
}
|
|
348
|
+
}
|
|
349
|
+
|
|
350
|
+
for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
|
|
351
|
+
if ([ctx->device supportsFamily:i]) {
|
|
352
|
+
WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
|
|
353
|
+
break;
|
|
354
|
+
}
|
|
306
355
|
}
|
|
307
356
|
}
|
|
308
357
|
|
|
358
|
+
ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7];
|
|
359
|
+
ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];
|
|
360
|
+
|
|
361
|
+
ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];
|
|
362
|
+
|
|
363
|
+
WSP_GGML_METAL_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false");
|
|
364
|
+
WSP_GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
|
|
309
365
|
WSP_GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
|
310
|
-
|
|
366
|
+
|
|
367
|
+
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
|
368
|
+
if (@available(macOS 10.12, iOS 16.0, *)) {
|
|
369
|
+
WSP_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
|
|
370
|
+
}
|
|
371
|
+
#elif TARGET_OS_OSX
|
|
311
372
|
if (ctx->device.maxTransferRate != 0) {
|
|
312
373
|
WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
|
|
313
374
|
} else {
|
|
@@ -319,286 +380,177 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
|
|
|
319
380
|
{
|
|
320
381
|
NSError * error = nil;
|
|
321
382
|
|
|
383
|
+
for (int i = 0; i < WSP_GGML_METAL_MAX_KERNELS; ++i) {
|
|
384
|
+
ctx->kernels[i].function = nil;
|
|
385
|
+
ctx->kernels[i].pipeline = nil;
|
|
386
|
+
}
|
|
387
|
+
|
|
322
388
|
/*
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
389
|
+
WSP_GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
|
|
390
|
+
(int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
|
|
391
|
+
(int) kernel->pipeline.threadExecutionWidth); \
|
|
326
392
|
*/
|
|
327
|
-
#define WSP_GGML_METAL_ADD_KERNEL(name) \
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
393
|
+
#define WSP_GGML_METAL_ADD_KERNEL(e, name, supported) \
|
|
394
|
+
if (supported) { \
|
|
395
|
+
struct wsp_ggml_metal_kernel * kernel = &ctx->kernels[e]; \
|
|
396
|
+
kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \
|
|
397
|
+
kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \
|
|
398
|
+
if (error) { \
|
|
399
|
+
WSP_GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
|
400
|
+
return NULL; \
|
|
401
|
+
} \
|
|
402
|
+
} else { \
|
|
403
|
+
WSP_GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \
|
|
333
404
|
}
|
|
334
405
|
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
338
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
339
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
340
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
341
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
342
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
343
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
344
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
345
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
346
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
347
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
348
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
349
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
350
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
351
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
352
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
353
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
354
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
355
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
356
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
357
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
358
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
359
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
360
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
361
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
362
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
363
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
364
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
365
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
366
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
367
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
368
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
369
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
370
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
371
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
372
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
373
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
374
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
375
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
376
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
377
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
378
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
379
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
380
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
381
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
382
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
383
|
-
|
|
384
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
388
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
389
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
390
|
-
|
|
391
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
395
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
396
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
424
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
425
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
426
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
427
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
428
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
429
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
430
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
431
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
432
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
433
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
434
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
435
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
436
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
440
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
441
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
442
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
443
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
444
|
-
|
|
445
|
-
|
|
406
|
+
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
|
407
|
+
|
|
408
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
|
409
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
|
|
410
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
|
411
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
|
412
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
|
413
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
|
414
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
|
415
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
|
416
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
|
417
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
|
418
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
|
419
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
|
420
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
|
421
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
|
|
422
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
|
|
423
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
|
424
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
|
425
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
|
426
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
|
|
427
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
|
|
428
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
|
|
429
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
|
|
430
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
|
|
431
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
|
|
432
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
|
|
433
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
|
|
434
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
|
|
435
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
|
|
436
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
|
|
437
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
|
438
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
|
439
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
|
440
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
|
441
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
|
442
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
|
443
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
|
|
444
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
|
|
445
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
|
446
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
|
|
447
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
|
|
448
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
|
|
449
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
|
|
450
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
|
|
451
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
|
|
452
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
|
|
453
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
|
|
454
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
|
|
455
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
|
|
456
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
|
|
457
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
|
|
458
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
|
459
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
|
460
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
|
461
|
+
//WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
|
462
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
|
463
|
+
//WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
|
|
464
|
+
//WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
|
|
465
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
|
|
466
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
|
|
467
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
|
|
468
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
|
|
469
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
|
|
470
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
|
|
471
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
|
|
472
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
|
|
473
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
|
|
474
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
|
|
475
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
|
476
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
|
477
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
|
478
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
|
479
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
|
480
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
|
|
481
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
|
|
482
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
|
|
483
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
|
|
484
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
|
|
485
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
|
|
486
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
|
|
487
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
|
|
488
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
|
|
489
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
|
490
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
|
491
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
|
492
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
|
493
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
|
494
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
|
|
495
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
|
|
496
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
|
|
497
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
|
|
498
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
|
|
499
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
|
|
500
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
|
|
501
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
|
|
502
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
|
|
503
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
|
504
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
|
505
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
|
506
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
|
507
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
|
508
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
|
509
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
|
510
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
|
511
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
|
512
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
|
513
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
|
514
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
|
515
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
|
516
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
|
517
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
|
518
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
|
519
|
+
//WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
|
520
|
+
//WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
|
521
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
|
522
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
|
523
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
|
524
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
|
525
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
|
446
526
|
}
|
|
447
527
|
|
|
448
528
|
return ctx;
|
|
449
529
|
}
|
|
450
530
|
|
|
451
|
-
void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
|
|
531
|
+
static void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
|
|
452
532
|
WSP_GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
|
|
453
|
-
#define WSP_GGML_METAL_DEL_KERNEL(name) \
|
|
454
|
-
|
|
455
|
-
WSP_GGML_METAL_DEL_KERNEL(add);
|
|
456
|
-
WSP_GGML_METAL_DEL_KERNEL(add_row);
|
|
457
|
-
WSP_GGML_METAL_DEL_KERNEL(mul);
|
|
458
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_row);
|
|
459
|
-
WSP_GGML_METAL_DEL_KERNEL(div);
|
|
460
|
-
WSP_GGML_METAL_DEL_KERNEL(div_row);
|
|
461
|
-
WSP_GGML_METAL_DEL_KERNEL(scale);
|
|
462
|
-
WSP_GGML_METAL_DEL_KERNEL(scale_4);
|
|
463
|
-
WSP_GGML_METAL_DEL_KERNEL(tanh);
|
|
464
|
-
WSP_GGML_METAL_DEL_KERNEL(relu);
|
|
465
|
-
WSP_GGML_METAL_DEL_KERNEL(gelu);
|
|
466
|
-
WSP_GGML_METAL_DEL_KERNEL(gelu_quick);
|
|
467
|
-
WSP_GGML_METAL_DEL_KERNEL(silu);
|
|
468
|
-
WSP_GGML_METAL_DEL_KERNEL(soft_max);
|
|
469
|
-
WSP_GGML_METAL_DEL_KERNEL(soft_max_4);
|
|
470
|
-
WSP_GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
|
471
|
-
WSP_GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
|
|
472
|
-
WSP_GGML_METAL_DEL_KERNEL(get_rows_f32);
|
|
473
|
-
WSP_GGML_METAL_DEL_KERNEL(get_rows_f16);
|
|
474
|
-
WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
|
475
|
-
WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
|
476
|
-
WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_0);
|
|
477
|
-
WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_1);
|
|
478
|
-
WSP_GGML_METAL_DEL_KERNEL(get_rows_q8_0);
|
|
479
|
-
WSP_GGML_METAL_DEL_KERNEL(get_rows_q2_K);
|
|
480
|
-
WSP_GGML_METAL_DEL_KERNEL(get_rows_q3_K);
|
|
481
|
-
WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_K);
|
|
482
|
-
WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_K);
|
|
483
|
-
WSP_GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
|
484
|
-
WSP_GGML_METAL_DEL_KERNEL(rms_norm);
|
|
485
|
-
WSP_GGML_METAL_DEL_KERNEL(group_norm);
|
|
486
|
-
WSP_GGML_METAL_DEL_KERNEL(norm);
|
|
487
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
|
|
488
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
|
|
489
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
|
|
490
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
|
|
491
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
|
|
492
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
|
|
493
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
|
|
494
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
|
|
495
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
|
|
496
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
|
|
497
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
|
|
498
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
|
|
499
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
|
|
500
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
|
|
501
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
|
|
502
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
|
|
503
|
-
//WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
|
|
504
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
|
|
505
|
-
//WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
|
|
506
|
-
//WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
|
|
507
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
|
|
508
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
|
|
509
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
|
|
510
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
|
|
511
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
|
|
512
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
|
|
513
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
|
|
514
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
|
|
515
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
|
|
516
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
|
|
517
|
-
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
|
518
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
|
519
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
|
520
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
|
|
521
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
|
|
522
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
|
|
523
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
|
|
524
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
|
|
525
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
|
|
526
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
|
|
527
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
|
|
528
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
|
|
529
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
|
|
530
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
|
|
531
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
|
|
532
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
|
|
533
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
|
|
534
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
|
|
535
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
|
|
536
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
|
|
537
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
|
|
538
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
|
|
539
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
|
|
540
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
|
|
541
|
-
WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
|
|
542
|
-
}
|
|
543
|
-
WSP_GGML_METAL_DEL_KERNEL(rope_f32);
|
|
544
|
-
WSP_GGML_METAL_DEL_KERNEL(rope_f16);
|
|
545
|
-
WSP_GGML_METAL_DEL_KERNEL(alibi_f32);
|
|
546
|
-
WSP_GGML_METAL_DEL_KERNEL(im2col_f16);
|
|
547
|
-
WSP_GGML_METAL_DEL_KERNEL(upscale_f32);
|
|
548
|
-
WSP_GGML_METAL_DEL_KERNEL(pad_f32);
|
|
549
|
-
WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
|
|
550
|
-
WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
|
|
551
|
-
WSP_GGML_METAL_DEL_KERNEL(leaky_relu_f32);
|
|
552
|
-
WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
|
553
|
-
WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
|
554
|
-
WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
|
|
555
|
-
WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
|
|
556
|
-
WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
|
|
557
|
-
//WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
|
|
558
|
-
//WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
|
|
559
|
-
WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
|
560
|
-
WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f32);
|
|
561
|
-
WSP_GGML_METAL_DEL_KERNEL(concat);
|
|
562
|
-
WSP_GGML_METAL_DEL_KERNEL(sqr);
|
|
563
|
-
WSP_GGML_METAL_DEL_KERNEL(sum_rows);
|
|
564
|
-
|
|
565
|
-
#undef WSP_GGML_METAL_DEL_KERNEL
|
|
566
533
|
|
|
567
534
|
free(ctx);
|
|
568
535
|
}
|
|
569
536
|
|
|
570
|
-
|
|
571
|
-
void * data = NULL;
|
|
572
|
-
const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
|
|
573
|
-
if (result != 0) {
|
|
574
|
-
WSP_GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
|
|
575
|
-
return NULL;
|
|
576
|
-
}
|
|
577
|
-
|
|
578
|
-
return data;
|
|
579
|
-
}
|
|
580
|
-
|
|
581
|
-
void wsp_ggml_metal_host_free(void * data) {
|
|
582
|
-
free(data);
|
|
583
|
-
}
|
|
584
|
-
|
|
585
|
-
void wsp_ggml_metal_set_n_cb(struct wsp_ggml_metal_context * ctx, int n_cb) {
|
|
586
|
-
ctx->n_cb = MIN(n_cb, WSP_GGML_METAL_MAX_BUFFERS);
|
|
587
|
-
}
|
|
537
|
+
// temporarily defined here for compatibility between ggml-backend and the old API
|
|
588
538
|
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
539
|
+
struct wsp_ggml_backend_metal_buffer {
|
|
540
|
+
void * data;
|
|
541
|
+
size_t size;
|
|
592
542
|
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
}
|
|
543
|
+
id<MTLBuffer> metal;
|
|
544
|
+
};
|
|
596
545
|
|
|
597
|
-
// temporarily defined here for compatibility between ggml-backend and the old API
|
|
598
546
|
struct wsp_ggml_backend_metal_buffer_context {
|
|
599
|
-
void *
|
|
547
|
+
void * all_data;
|
|
548
|
+
size_t all_size;
|
|
549
|
+
bool owned;
|
|
600
550
|
|
|
601
|
-
|
|
551
|
+
// multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
|
|
552
|
+
int n_buffers;
|
|
553
|
+
struct wsp_ggml_backend_metal_buffer buffers[WSP_GGML_METAL_MAX_BUFFERS];
|
|
602
554
|
};
|
|
603
555
|
|
|
604
556
|
// finds the Metal buffer that contains the tensor data on the GPU device
|
|
@@ -610,17 +562,29 @@ static id<MTLBuffer> wsp_ggml_metal_get_buffer(struct wsp_ggml_metal_context * c
|
|
|
610
562
|
|
|
611
563
|
const int64_t tsize = wsp_ggml_nbytes(t);
|
|
612
564
|
|
|
565
|
+
wsp_ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
|
|
566
|
+
|
|
613
567
|
// compatibility with ggml-backend
|
|
614
|
-
if (
|
|
615
|
-
struct wsp_ggml_backend_metal_buffer_context * buf_ctx = (struct wsp_ggml_backend_metal_buffer_context *)
|
|
568
|
+
if (buffer && buffer->buft == wsp_ggml_backend_metal_buffer_type()) {
|
|
569
|
+
struct wsp_ggml_backend_metal_buffer_context * buf_ctx = (struct wsp_ggml_backend_metal_buffer_context *) buffer->context;
|
|
616
570
|
|
|
617
|
-
|
|
571
|
+
// find the view that contains the tensor fully
|
|
572
|
+
for (int i = 0; i < buf_ctx->n_buffers; ++i) {
|
|
573
|
+
const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
|
|
618
574
|
|
|
619
|
-
|
|
575
|
+
//WSP_GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
|
|
576
|
+
if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
|
|
577
|
+
*offs = (size_t) ioffs;
|
|
620
578
|
|
|
621
|
-
|
|
579
|
+
//WSP_GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
|
|
622
580
|
|
|
623
|
-
|
|
581
|
+
return buf_ctx->buffers[i].metal;
|
|
582
|
+
}
|
|
583
|
+
}
|
|
584
|
+
|
|
585
|
+
WSP_GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
|
|
586
|
+
|
|
587
|
+
return nil;
|
|
624
588
|
}
|
|
625
589
|
|
|
626
590
|
// find the view that contains the tensor fully
|
|
@@ -642,210 +606,7 @@ static id<MTLBuffer> wsp_ggml_metal_get_buffer(struct wsp_ggml_metal_context * c
|
|
|
642
606
|
return nil;
|
|
643
607
|
}
|
|
644
608
|
|
|
645
|
-
bool
|
|
646
|
-
struct wsp_ggml_metal_context * ctx,
|
|
647
|
-
const char * name,
|
|
648
|
-
void * data,
|
|
649
|
-
size_t size,
|
|
650
|
-
size_t max_size) {
|
|
651
|
-
if (ctx->n_buffers >= WSP_GGML_METAL_MAX_BUFFERS) {
|
|
652
|
-
WSP_GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
|
|
653
|
-
return false;
|
|
654
|
-
}
|
|
655
|
-
|
|
656
|
-
if (data) {
|
|
657
|
-
// verify that the buffer does not overlap with any of the existing buffers
|
|
658
|
-
for (int i = 0; i < ctx->n_buffers; ++i) {
|
|
659
|
-
const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
|
|
660
|
-
|
|
661
|
-
if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
|
|
662
|
-
WSP_GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
|
|
663
|
-
return false;
|
|
664
|
-
}
|
|
665
|
-
}
|
|
666
|
-
|
|
667
|
-
const size_t size_page = sysconf(_SC_PAGESIZE);
|
|
668
|
-
|
|
669
|
-
size_t size_aligned = size;
|
|
670
|
-
if ((size_aligned % size_page) != 0) {
|
|
671
|
-
size_aligned += (size_page - (size_aligned % size_page));
|
|
672
|
-
}
|
|
673
|
-
|
|
674
|
-
// the buffer fits into the max buffer size allowed by the device
|
|
675
|
-
if (size_aligned <= ctx->device.maxBufferLength) {
|
|
676
|
-
ctx->buffers[ctx->n_buffers].name = name;
|
|
677
|
-
ctx->buffers[ctx->n_buffers].data = data;
|
|
678
|
-
ctx->buffers[ctx->n_buffers].size = size;
|
|
679
|
-
|
|
680
|
-
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
|
681
|
-
|
|
682
|
-
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
|
683
|
-
WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
|
|
684
|
-
return false;
|
|
685
|
-
}
|
|
686
|
-
|
|
687
|
-
WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);
|
|
688
|
-
|
|
689
|
-
++ctx->n_buffers;
|
|
690
|
-
} else {
|
|
691
|
-
// this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
|
|
692
|
-
// one of the views
|
|
693
|
-
const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
|
|
694
|
-
const size_t size_step = ctx->device.maxBufferLength - size_ovlp;
|
|
695
|
-
const size_t size_view = ctx->device.maxBufferLength;
|
|
696
|
-
|
|
697
|
-
for (size_t i = 0; i < size; i += size_step) {
|
|
698
|
-
const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
|
|
699
|
-
|
|
700
|
-
ctx->buffers[ctx->n_buffers].name = name;
|
|
701
|
-
ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
|
|
702
|
-
ctx->buffers[ctx->n_buffers].size = size_step_aligned;
|
|
703
|
-
|
|
704
|
-
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
|
705
|
-
|
|
706
|
-
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
|
707
|
-
WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
|
|
708
|
-
return false;
|
|
709
|
-
}
|
|
710
|
-
|
|
711
|
-
WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
|
|
712
|
-
if (i + size_step < size) {
|
|
713
|
-
WSP_GGML_METAL_LOG_INFO("\n");
|
|
714
|
-
}
|
|
715
|
-
|
|
716
|
-
++ctx->n_buffers;
|
|
717
|
-
}
|
|
718
|
-
}
|
|
719
|
-
|
|
720
|
-
#if TARGET_OS_OSX
|
|
721
|
-
WSP_GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
|
|
722
|
-
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
|
|
723
|
-
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
|
724
|
-
|
|
725
|
-
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
|
|
726
|
-
WSP_GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
|
727
|
-
} else {
|
|
728
|
-
WSP_GGML_METAL_LOG_INFO("\n");
|
|
729
|
-
}
|
|
730
|
-
#else
|
|
731
|
-
WSP_GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
|
|
732
|
-
#endif
|
|
733
|
-
}
|
|
734
|
-
|
|
735
|
-
return true;
|
|
736
|
-
}
|
|
737
|
-
|
|
738
|
-
void wsp_ggml_metal_set_tensor(
|
|
739
|
-
struct wsp_ggml_metal_context * ctx,
|
|
740
|
-
struct wsp_ggml_tensor * t) {
|
|
741
|
-
size_t offs;
|
|
742
|
-
id<MTLBuffer> id_dst = wsp_ggml_metal_get_buffer(ctx, t, &offs);
|
|
743
|
-
|
|
744
|
-
memcpy((void *) ((uint8_t *) id_dst.contents + offs), t->data, wsp_ggml_nbytes(t));
|
|
745
|
-
}
|
|
746
|
-
|
|
747
|
-
void wsp_ggml_metal_get_tensor(
|
|
748
|
-
struct wsp_ggml_metal_context * ctx,
|
|
749
|
-
struct wsp_ggml_tensor * t) {
|
|
750
|
-
size_t offs;
|
|
751
|
-
id<MTLBuffer> id_src = wsp_ggml_metal_get_buffer(ctx, t, &offs);
|
|
752
|
-
|
|
753
|
-
memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), wsp_ggml_nbytes(t));
|
|
754
|
-
}
|
|
755
|
-
|
|
756
|
-
void wsp_ggml_metal_graph_find_concurrency(
|
|
757
|
-
struct wsp_ggml_metal_context * ctx,
|
|
758
|
-
struct wsp_ggml_cgraph * gf, bool check_mem) {
|
|
759
|
-
int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
|
|
760
|
-
int nodes_unused[WSP_GGML_MAX_CONCUR];
|
|
761
|
-
|
|
762
|
-
for (int i = 0; i < WSP_GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
|
|
763
|
-
for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
|
|
764
|
-
ctx->concur_list_len = 0;
|
|
765
|
-
|
|
766
|
-
int n_left = gf->n_nodes;
|
|
767
|
-
int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
|
|
768
|
-
int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
|
|
769
|
-
|
|
770
|
-
while (n_left > 0) {
|
|
771
|
-
// number of nodes at a layer (that can be issued concurrently)
|
|
772
|
-
int concurrency = 0;
|
|
773
|
-
for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
|
|
774
|
-
if (nodes_unused[i]) {
|
|
775
|
-
// if the requirements for gf->nodes[i] are satisfied
|
|
776
|
-
int exe_flag = 1;
|
|
777
|
-
|
|
778
|
-
// scan all srcs
|
|
779
|
-
for (int src_ind = 0; src_ind < WSP_GGML_MAX_SRC; src_ind++) {
|
|
780
|
-
struct wsp_ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
|
|
781
|
-
if (src_cur) {
|
|
782
|
-
// if is leaf nodes it's satisfied.
|
|
783
|
-
// TODO: wsp_ggml_is_leaf()
|
|
784
|
-
if (src_cur->op == WSP_GGML_OP_NONE && src_cur->grad == NULL) {
|
|
785
|
-
continue;
|
|
786
|
-
}
|
|
787
|
-
|
|
788
|
-
// otherwise this src should be the output from previous nodes.
|
|
789
|
-
int is_found = 0;
|
|
790
|
-
|
|
791
|
-
// scan 2*search_depth back because we inserted barrier.
|
|
792
|
-
//for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
|
|
793
|
-
for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
|
|
794
|
-
if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
|
|
795
|
-
is_found = 1;
|
|
796
|
-
break;
|
|
797
|
-
}
|
|
798
|
-
}
|
|
799
|
-
if (is_found == 0) {
|
|
800
|
-
exe_flag = 0;
|
|
801
|
-
break;
|
|
802
|
-
}
|
|
803
|
-
}
|
|
804
|
-
}
|
|
805
|
-
if (exe_flag && check_mem) {
|
|
806
|
-
// check if nodes[i]'s data will be overwritten by a node before nodes[i].
|
|
807
|
-
// if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
|
|
808
|
-
int64_t data_start = (int64_t) gf->nodes[i]->data;
|
|
809
|
-
int64_t length = (int64_t) wsp_ggml_nbytes(gf->nodes[i]);
|
|
810
|
-
for (int j = n_start; j < i; j++) {
|
|
811
|
-
if (nodes_unused[j] && gf->nodes[j]->op != WSP_GGML_OP_RESHAPE \
|
|
812
|
-
&& gf->nodes[j]->op != WSP_GGML_OP_VIEW \
|
|
813
|
-
&& gf->nodes[j]->op != WSP_GGML_OP_TRANSPOSE \
|
|
814
|
-
&& gf->nodes[j]->op != WSP_GGML_OP_PERMUTE) {
|
|
815
|
-
if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
|
|
816
|
-
((int64_t)gf->nodes[j]->data) + (int64_t) wsp_ggml_nbytes(gf->nodes[j]) <= data_start) {
|
|
817
|
-
continue;
|
|
818
|
-
}
|
|
819
|
-
|
|
820
|
-
exe_flag = 0;
|
|
821
|
-
}
|
|
822
|
-
}
|
|
823
|
-
}
|
|
824
|
-
if (exe_flag) {
|
|
825
|
-
ctx->concur_list[level_pos + concurrency] = i;
|
|
826
|
-
nodes_unused[i] = 0;
|
|
827
|
-
concurrency++;
|
|
828
|
-
ctx->concur_list_len++;
|
|
829
|
-
}
|
|
830
|
-
}
|
|
831
|
-
}
|
|
832
|
-
n_left -= concurrency;
|
|
833
|
-
// adding a barrier different layer
|
|
834
|
-
ctx->concur_list[level_pos + concurrency] = -1;
|
|
835
|
-
ctx->concur_list_len++;
|
|
836
|
-
// jump all sorted nodes at nodes_bak
|
|
837
|
-
while (!nodes_unused[n_start]) {
|
|
838
|
-
n_start++;
|
|
839
|
-
}
|
|
840
|
-
level_pos += concurrency + 1;
|
|
841
|
-
}
|
|
842
|
-
|
|
843
|
-
if (ctx->concur_list_len > WSP_GGML_MAX_CONCUR) {
|
|
844
|
-
WSP_GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
|
|
845
|
-
}
|
|
846
|
-
}
|
|
847
|
-
|
|
848
|
-
static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
|
|
609
|
+
static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_metal_context * ctx, const struct wsp_ggml_tensor * op) {
|
|
849
610
|
switch (op->op) {
|
|
850
611
|
case WSP_GGML_OP_UNARY:
|
|
851
612
|
switch (wsp_ggml_get_unary_op(op)) {
|
|
@@ -871,9 +632,11 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
|
|
|
871
632
|
case WSP_GGML_OP_SCALE:
|
|
872
633
|
case WSP_GGML_OP_SQR:
|
|
873
634
|
case WSP_GGML_OP_SUM_ROWS:
|
|
635
|
+
return true;
|
|
874
636
|
case WSP_GGML_OP_SOFT_MAX:
|
|
875
637
|
case WSP_GGML_OP_RMS_NORM:
|
|
876
638
|
case WSP_GGML_OP_GROUP_NORM:
|
|
639
|
+
return ctx->support_simdgroup_reduction;
|
|
877
640
|
case WSP_GGML_OP_NORM:
|
|
878
641
|
case WSP_GGML_OP_ALIBI:
|
|
879
642
|
case WSP_GGML_OP_ROPE:
|
|
@@ -882,9 +645,10 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
|
|
|
882
645
|
case WSP_GGML_OP_PAD:
|
|
883
646
|
case WSP_GGML_OP_ARGSORT:
|
|
884
647
|
case WSP_GGML_OP_LEAKY_RELU:
|
|
648
|
+
return true;
|
|
885
649
|
case WSP_GGML_OP_MUL_MAT:
|
|
886
650
|
case WSP_GGML_OP_MUL_MAT_ID:
|
|
887
|
-
return
|
|
651
|
+
return ctx->support_simdgroup_reduction;
|
|
888
652
|
case WSP_GGML_OP_CPY:
|
|
889
653
|
case WSP_GGML_OP_DUP:
|
|
890
654
|
case WSP_GGML_OP_CONT:
|
|
@@ -922,1433 +686,1559 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
|
|
|
922
686
|
return false;
|
|
923
687
|
}
|
|
924
688
|
}
|
|
925
|
-
|
|
689
|
+
|
|
690
|
+
static bool wsp_ggml_metal_graph_compute(
|
|
926
691
|
struct wsp_ggml_metal_context * ctx,
|
|
927
692
|
struct wsp_ggml_cgraph * gf) {
|
|
928
|
-
@autoreleasepool {
|
|
929
693
|
|
|
930
|
-
// if there is ctx->concur_list, dispatch concurrently
|
|
931
|
-
// else fallback to serial dispatch
|
|
932
694
|
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
|
|
933
|
-
|
|
934
|
-
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= WSP_GGML_MAX_CONCUR;
|
|
935
|
-
|
|
936
|
-
const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
|
|
937
|
-
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
|
|
695
|
+
edesc.dispatchType = MTLDispatchTypeSerial;
|
|
938
696
|
|
|
939
697
|
// create multiple command buffers and enqueue them
|
|
940
698
|
// then, we encode the graph into the command buffers in parallel
|
|
941
699
|
|
|
700
|
+
const int n_nodes = gf->n_nodes;
|
|
942
701
|
const int n_cb = ctx->n_cb;
|
|
702
|
+
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
|
|
943
703
|
|
|
944
|
-
|
|
945
|
-
|
|
704
|
+
id<MTLCommandBuffer> command_buffer_builder[n_cb];
|
|
705
|
+
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
|
706
|
+
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
|
707
|
+
command_buffer_builder[cb_idx] = command_buffer;
|
|
946
708
|
|
|
947
709
|
// enqueue the command buffers in order to specify their execution order
|
|
948
|
-
[
|
|
949
|
-
|
|
950
|
-
ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
|
|
710
|
+
[command_buffer enqueue];
|
|
951
711
|
}
|
|
712
|
+
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
|
|
952
713
|
|
|
953
|
-
|
|
954
|
-
const int
|
|
955
|
-
|
|
956
|
-
dispatch_async(ctx->d_queue, ^{
|
|
957
|
-
size_t offs_src0 = 0;
|
|
958
|
-
size_t offs_src1 = 0;
|
|
959
|
-
size_t offs_dst = 0;
|
|
960
|
-
|
|
961
|
-
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
|
|
962
|
-
id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
|
|
963
|
-
|
|
964
|
-
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
|
965
|
-
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
|
|
966
|
-
|
|
967
|
-
for (int ind = node_start; ind < node_end; ++ind) {
|
|
968
|
-
const int i = has_concur ? ctx->concur_list[ind] : ind;
|
|
969
|
-
|
|
970
|
-
if (i == -1) {
|
|
971
|
-
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
|
|
972
|
-
continue;
|
|
973
|
-
}
|
|
974
|
-
|
|
975
|
-
//WSP_GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, wsp_ggml_op_name(gf->nodes[i]->op));
|
|
976
|
-
|
|
977
|
-
struct wsp_ggml_tensor * src0 = gf->nodes[i]->src[0];
|
|
978
|
-
struct wsp_ggml_tensor * src1 = gf->nodes[i]->src[1];
|
|
979
|
-
struct wsp_ggml_tensor * dst = gf->nodes[i];
|
|
980
|
-
|
|
981
|
-
switch (dst->op) {
|
|
982
|
-
case WSP_GGML_OP_NONE:
|
|
983
|
-
case WSP_GGML_OP_RESHAPE:
|
|
984
|
-
case WSP_GGML_OP_VIEW:
|
|
985
|
-
case WSP_GGML_OP_TRANSPOSE:
|
|
986
|
-
case WSP_GGML_OP_PERMUTE:
|
|
987
|
-
{
|
|
988
|
-
// noop -> next node
|
|
989
|
-
} continue;
|
|
990
|
-
default:
|
|
991
|
-
{
|
|
992
|
-
} break;
|
|
993
|
-
}
|
|
994
|
-
|
|
995
|
-
if (!wsp_ggml_metal_supports_op(dst)) {
|
|
996
|
-
WSP_GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, wsp_ggml_op_desc(dst));
|
|
997
|
-
WSP_GGML_ASSERT(!"unsupported op");
|
|
998
|
-
}
|
|
999
|
-
|
|
1000
|
-
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
|
1001
|
-
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
|
1002
|
-
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
|
1003
|
-
const int64_t ne03 = src0 ? src0->ne[3] : 0;
|
|
1004
|
-
|
|
1005
|
-
const uint64_t nb00 = src0 ? src0->nb[0] : 0;
|
|
1006
|
-
const uint64_t nb01 = src0 ? src0->nb[1] : 0;
|
|
1007
|
-
const uint64_t nb02 = src0 ? src0->nb[2] : 0;
|
|
1008
|
-
const uint64_t nb03 = src0 ? src0->nb[3] : 0;
|
|
1009
|
-
|
|
1010
|
-
const int64_t ne10 = src1 ? src1->ne[0] : 0;
|
|
1011
|
-
const int64_t ne11 = src1 ? src1->ne[1] : 0;
|
|
1012
|
-
const int64_t ne12 = src1 ? src1->ne[2] : 0;
|
|
1013
|
-
const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
|
|
1014
|
-
|
|
1015
|
-
const uint64_t nb10 = src1 ? src1->nb[0] : 0;
|
|
1016
|
-
const uint64_t nb11 = src1 ? src1->nb[1] : 0;
|
|
1017
|
-
const uint64_t nb12 = src1 ? src1->nb[2] : 0;
|
|
1018
|
-
const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
|
|
1019
|
-
|
|
1020
|
-
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
|
1021
|
-
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
|
1022
|
-
const int64_t ne2 = dst ? dst->ne[2] : 0;
|
|
1023
|
-
const int64_t ne3 = dst ? dst->ne[3] : 0;
|
|
1024
|
-
|
|
1025
|
-
const uint64_t nb0 = dst ? dst->nb[0] : 0;
|
|
1026
|
-
const uint64_t nb1 = dst ? dst->nb[1] : 0;
|
|
1027
|
-
const uint64_t nb2 = dst ? dst->nb[2] : 0;
|
|
1028
|
-
const uint64_t nb3 = dst ? dst->nb[3] : 0;
|
|
1029
|
-
|
|
1030
|
-
const enum wsp_ggml_type src0t = src0 ? src0->type : WSP_GGML_TYPE_COUNT;
|
|
1031
|
-
const enum wsp_ggml_type src1t = src1 ? src1->type : WSP_GGML_TYPE_COUNT;
|
|
1032
|
-
const enum wsp_ggml_type dstt = dst ? dst->type : WSP_GGML_TYPE_COUNT;
|
|
1033
|
-
|
|
1034
|
-
id<MTLBuffer> id_src0 = src0 ? wsp_ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
|
|
1035
|
-
id<MTLBuffer> id_src1 = src1 ? wsp_ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
|
|
1036
|
-
id<MTLBuffer> id_dst = dst ? wsp_ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
|
|
1037
|
-
|
|
1038
|
-
//WSP_GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, wsp_ggml_op_name(dst->op));
|
|
1039
|
-
//if (src0) {
|
|
1040
|
-
// WSP_GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src0t), ne00, ne01, ne02,
|
|
1041
|
-
// wsp_ggml_is_contiguous(src0), src0->name);
|
|
1042
|
-
//}
|
|
1043
|
-
//if (src1) {
|
|
1044
|
-
// WSP_GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src1t), ne10, ne11, ne12,
|
|
1045
|
-
// wsp_ggml_is_contiguous(src1), src1->name);
|
|
1046
|
-
//}
|
|
1047
|
-
//if (dst) {
|
|
1048
|
-
// WSP_GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(dstt), ne0, ne1, ne2,
|
|
1049
|
-
// dst->name);
|
|
1050
|
-
//}
|
|
1051
|
-
|
|
1052
|
-
switch (dst->op) {
|
|
1053
|
-
case WSP_GGML_OP_CONCAT:
|
|
1054
|
-
{
|
|
1055
|
-
const int64_t nb = ne00;
|
|
1056
|
-
|
|
1057
|
-
[encoder setComputePipelineState:ctx->pipeline_concat];
|
|
1058
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1059
|
-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1060
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1061
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
1062
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
1063
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
1064
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
|
1065
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
|
1066
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
|
1067
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
|
1068
|
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
|
1069
|
-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
|
1070
|
-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
|
1071
|
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
|
1072
|
-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
|
1073
|
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
|
1074
|
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
|
1075
|
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
|
1076
|
-
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
|
1077
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
|
1078
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
|
1079
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
|
1080
|
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
|
1081
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
|
1082
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
|
1083
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
|
1084
|
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
|
1085
|
-
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
|
|
1086
|
-
|
|
1087
|
-
const int nth = MIN(1024, ne0);
|
|
1088
|
-
|
|
1089
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1090
|
-
} break;
|
|
1091
|
-
case WSP_GGML_OP_ADD:
|
|
1092
|
-
case WSP_GGML_OP_MUL:
|
|
1093
|
-
case WSP_GGML_OP_DIV:
|
|
1094
|
-
{
|
|
1095
|
-
const size_t offs = 0;
|
|
1096
|
-
|
|
1097
|
-
bool bcast_row = false;
|
|
1098
|
-
|
|
1099
|
-
int64_t nb = ne00;
|
|
1100
|
-
|
|
1101
|
-
id<MTLComputePipelineState> pipeline = nil;
|
|
1102
|
-
|
|
1103
|
-
if (wsp_ggml_nelements(src1) == ne10 && wsp_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
|
1104
|
-
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
|
|
714
|
+
dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
|
|
715
|
+
const int cb_idx = iter;
|
|
1105
716
|
|
|
1106
|
-
|
|
1107
|
-
|
|
717
|
+
size_t offs_src0 = 0;
|
|
718
|
+
size_t offs_src1 = 0;
|
|
719
|
+
size_t offs_dst = 0;
|
|
1108
720
|
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
case WSP_GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
|
|
1112
|
-
case WSP_GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
|
|
1113
|
-
case WSP_GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
|
|
1114
|
-
default: WSP_GGML_ASSERT(false);
|
|
1115
|
-
}
|
|
721
|
+
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
|
|
722
|
+
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
1116
723
|
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
switch (dst->op) {
|
|
1120
|
-
case WSP_GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
|
|
1121
|
-
case WSP_GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
|
|
1122
|
-
case WSP_GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
|
|
1123
|
-
default: WSP_GGML_ASSERT(false);
|
|
1124
|
-
}
|
|
1125
|
-
}
|
|
724
|
+
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
|
725
|
+
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
|
|
1126
726
|
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
1133
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
1134
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
|
1135
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
|
1136
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
|
1137
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
|
1138
|
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
|
1139
|
-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
|
1140
|
-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
|
1141
|
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
|
1142
|
-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
|
1143
|
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
|
1144
|
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
|
1145
|
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
|
1146
|
-
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
|
1147
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
|
1148
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
|
1149
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
|
1150
|
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
|
1151
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
|
1152
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
|
1153
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
|
1154
|
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
|
1155
|
-
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
|
1156
|
-
[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
|
|
1157
|
-
|
|
1158
|
-
if (bcast_row) {
|
|
1159
|
-
const int64_t n = wsp_ggml_nelements(dst)/4;
|
|
727
|
+
for (int i = node_start; i < node_end; ++i) {
|
|
728
|
+
if (i == -1) {
|
|
729
|
+
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
|
|
730
|
+
continue;
|
|
731
|
+
}
|
|
1160
732
|
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
|
|
733
|
+
//WSP_GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, wsp_ggml_op_name(gf->nodes[i]->op));
|
|
734
|
+
|
|
735
|
+
struct wsp_ggml_tensor * src0 = gf->nodes[i]->src[0];
|
|
736
|
+
struct wsp_ggml_tensor * src1 = gf->nodes[i]->src[1];
|
|
737
|
+
struct wsp_ggml_tensor * dst = gf->nodes[i];
|
|
738
|
+
|
|
739
|
+
switch (dst->op) {
|
|
740
|
+
case WSP_GGML_OP_NONE:
|
|
741
|
+
case WSP_GGML_OP_RESHAPE:
|
|
742
|
+
case WSP_GGML_OP_VIEW:
|
|
743
|
+
case WSP_GGML_OP_TRANSPOSE:
|
|
744
|
+
case WSP_GGML_OP_PERMUTE:
|
|
745
|
+
{
|
|
746
|
+
// noop -> next node
|
|
747
|
+
} continue;
|
|
748
|
+
default:
|
|
749
|
+
{
|
|
750
|
+
} break;
|
|
751
|
+
}
|
|
1164
752
|
|
|
1165
|
-
|
|
1166
|
-
|
|
1167
|
-
|
|
1168
|
-
|
|
1169
|
-
{
|
|
1170
|
-
WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_F32);
|
|
1171
|
-
WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
|
|
1172
|
-
WSP_GGML_ASSERT(dstt == WSP_GGML_TYPE_F32);
|
|
753
|
+
if (!wsp_ggml_metal_supports_op(ctx, dst)) {
|
|
754
|
+
WSP_GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, wsp_ggml_op_desc(dst));
|
|
755
|
+
WSP_GGML_ASSERT(!"unsupported op");
|
|
756
|
+
}
|
|
1173
757
|
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
const size_t pnb1 = ((int32_t *) dst->op_params)[0];
|
|
1178
|
-
const size_t pnb2 = ((int32_t *) dst->op_params)[1];
|
|
1179
|
-
const size_t pnb3 = ((int32_t *) dst->op_params)[2];
|
|
1180
|
-
const size_t offs = ((int32_t *) dst->op_params)[3];
|
|
1181
|
-
|
|
1182
|
-
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
|
|
1183
|
-
|
|
1184
|
-
if (!inplace) {
|
|
1185
|
-
// run a separete kernel to cpy src->dst
|
|
1186
|
-
// not sure how to avoid this
|
|
1187
|
-
// TODO: make a simpler cpy_bytes kernel
|
|
1188
|
-
|
|
1189
|
-
const int nth = MIN(1024, ne00);
|
|
1190
|
-
|
|
1191
|
-
[encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
|
|
1192
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1193
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1194
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1195
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
1196
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
1197
|
-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
|
1198
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
|
1199
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
|
1200
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
|
1201
|
-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
|
1202
|
-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
|
1203
|
-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
|
1204
|
-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
|
1205
|
-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
|
1206
|
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
|
1207
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
|
1208
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
|
1209
|
-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
|
1210
|
-
|
|
1211
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1212
|
-
}
|
|
758
|
+
#ifndef WSP_GGML_METAL_NDEBUG
|
|
759
|
+
[encoder pushDebugGroup:[NSString stringWithCString:wsp_ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
|
|
760
|
+
#endif
|
|
1213
761
|
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1219
|
-
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
1224
|
-
|
|
1225
|
-
|
|
1226
|
-
|
|
1227
|
-
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
|
|
1249
|
-
|
|
762
|
+
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
|
763
|
+
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
|
764
|
+
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
|
765
|
+
const int64_t ne03 = src0 ? src0->ne[3] : 0;
|
|
766
|
+
|
|
767
|
+
const uint64_t nb00 = src0 ? src0->nb[0] : 0;
|
|
768
|
+
const uint64_t nb01 = src0 ? src0->nb[1] : 0;
|
|
769
|
+
const uint64_t nb02 = src0 ? src0->nb[2] : 0;
|
|
770
|
+
const uint64_t nb03 = src0 ? src0->nb[3] : 0;
|
|
771
|
+
|
|
772
|
+
const int64_t ne10 = src1 ? src1->ne[0] : 0;
|
|
773
|
+
const int64_t ne11 = src1 ? src1->ne[1] : 0;
|
|
774
|
+
const int64_t ne12 = src1 ? src1->ne[2] : 0;
|
|
775
|
+
const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
|
|
776
|
+
|
|
777
|
+
const uint64_t nb10 = src1 ? src1->nb[0] : 0;
|
|
778
|
+
const uint64_t nb11 = src1 ? src1->nb[1] : 0;
|
|
779
|
+
const uint64_t nb12 = src1 ? src1->nb[2] : 0;
|
|
780
|
+
const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
|
|
781
|
+
|
|
782
|
+
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
|
783
|
+
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
|
784
|
+
const int64_t ne2 = dst ? dst->ne[2] : 0;
|
|
785
|
+
const int64_t ne3 = dst ? dst->ne[3] : 0;
|
|
786
|
+
|
|
787
|
+
const uint64_t nb0 = dst ? dst->nb[0] : 0;
|
|
788
|
+
const uint64_t nb1 = dst ? dst->nb[1] : 0;
|
|
789
|
+
const uint64_t nb2 = dst ? dst->nb[2] : 0;
|
|
790
|
+
const uint64_t nb3 = dst ? dst->nb[3] : 0;
|
|
791
|
+
|
|
792
|
+
const enum wsp_ggml_type src0t = src0 ? src0->type : WSP_GGML_TYPE_COUNT;
|
|
793
|
+
const enum wsp_ggml_type src1t = src1 ? src1->type : WSP_GGML_TYPE_COUNT;
|
|
794
|
+
const enum wsp_ggml_type dstt = dst ? dst->type : WSP_GGML_TYPE_COUNT;
|
|
795
|
+
|
|
796
|
+
id<MTLBuffer> id_src0 = src0 ? wsp_ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
|
|
797
|
+
id<MTLBuffer> id_src1 = src1 ? wsp_ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
|
|
798
|
+
id<MTLBuffer> id_dst = dst ? wsp_ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
|
|
799
|
+
|
|
800
|
+
//WSP_GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, wsp_ggml_op_name(dst->op));
|
|
801
|
+
//if (src0) {
|
|
802
|
+
// WSP_GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src0t), ne00, ne01, ne02,
|
|
803
|
+
// wsp_ggml_is_contiguous(src0), src0->name);
|
|
804
|
+
//}
|
|
805
|
+
//if (src1) {
|
|
806
|
+
// WSP_GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src1t), ne10, ne11, ne12,
|
|
807
|
+
// wsp_ggml_is_contiguous(src1), src1->name);
|
|
808
|
+
//}
|
|
809
|
+
//if (dst) {
|
|
810
|
+
// WSP_GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(dstt), ne0, ne1, ne2,
|
|
811
|
+
// dst->name);
|
|
812
|
+
//}
|
|
813
|
+
|
|
814
|
+
switch (dst->op) {
|
|
815
|
+
case WSP_GGML_OP_CONCAT:
|
|
816
|
+
{
|
|
817
|
+
const int64_t nb = ne00;
|
|
818
|
+
|
|
819
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
|
|
820
|
+
|
|
821
|
+
[encoder setComputePipelineState:pipeline];
|
|
822
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
823
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
824
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
825
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
826
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
827
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
828
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
|
829
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
|
830
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
|
831
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
|
832
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
|
833
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
|
834
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
|
835
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
|
836
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
|
837
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
|
838
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
|
839
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
|
840
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
|
841
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
|
842
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
|
843
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
|
844
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
|
845
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
|
846
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
|
847
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
|
848
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
|
849
|
+
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
|
|
850
|
+
|
|
851
|
+
const int nth = MIN(1024, ne0);
|
|
852
|
+
|
|
853
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
854
|
+
} break;
|
|
855
|
+
case WSP_GGML_OP_ADD:
|
|
856
|
+
case WSP_GGML_OP_MUL:
|
|
857
|
+
case WSP_GGML_OP_DIV:
|
|
858
|
+
{
|
|
859
|
+
const size_t offs = 0;
|
|
860
|
+
|
|
861
|
+
bool bcast_row = false;
|
|
862
|
+
|
|
863
|
+
int64_t nb = ne00;
|
|
864
|
+
|
|
865
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
866
|
+
|
|
867
|
+
if (wsp_ggml_nelements(src1) == ne10 && wsp_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
|
1250
868
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
|
|
1251
869
|
|
|
1252
|
-
|
|
870
|
+
// src1 is a row
|
|
871
|
+
WSP_GGML_ASSERT(ne11 == 1);
|
|
1253
872
|
|
|
1254
|
-
|
|
873
|
+
nb = ne00 / 4;
|
|
874
|
+
switch (dst->op) {
|
|
875
|
+
case WSP_GGML_OP_ADD: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
|
|
876
|
+
case WSP_GGML_OP_MUL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
|
|
877
|
+
case WSP_GGML_OP_DIV: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
|
|
878
|
+
default: WSP_GGML_ASSERT(false);
|
|
879
|
+
}
|
|
1255
880
|
|
|
1256
|
-
|
|
1257
|
-
|
|
1258
|
-
|
|
1259
|
-
|
|
1260
|
-
|
|
881
|
+
bcast_row = true;
|
|
882
|
+
} else {
|
|
883
|
+
switch (dst->op) {
|
|
884
|
+
case WSP_GGML_OP_ADD: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
|
|
885
|
+
case WSP_GGML_OP_MUL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
|
|
886
|
+
case WSP_GGML_OP_DIV: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
|
|
887
|
+
default: WSP_GGML_ASSERT(false);
|
|
1261
888
|
}
|
|
889
|
+
}
|
|
1262
890
|
|
|
1263
|
-
|
|
1264
|
-
|
|
1265
|
-
|
|
891
|
+
[encoder setComputePipelineState:pipeline];
|
|
892
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
893
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
894
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
895
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
896
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
897
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
898
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
|
899
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
|
900
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
|
901
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
|
902
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
|
903
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
|
904
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
|
905
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
|
906
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
|
907
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
|
908
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
|
909
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
|
910
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
|
911
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
|
912
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
|
913
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
|
914
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
|
915
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
|
916
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
|
917
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
|
918
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
|
919
|
+
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
|
920
|
+
[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
|
|
921
|
+
|
|
922
|
+
if (bcast_row) {
|
|
923
|
+
const int64_t n = wsp_ggml_nelements(dst)/4;
|
|
1266
924
|
|
|
1267
925
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1268
|
-
}
|
|
1269
|
-
|
|
1270
|
-
switch (wsp_ggml_get_unary_op(gf->nodes[i])) {
|
|
1271
|
-
case WSP_GGML_UNARY_OP_TANH:
|
|
1272
|
-
{
|
|
1273
|
-
[encoder setComputePipelineState:ctx->pipeline_tanh];
|
|
1274
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1275
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
926
|
+
} else {
|
|
927
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
|
1276
928
|
|
|
1277
|
-
|
|
929
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
930
|
+
}
|
|
931
|
+
} break;
|
|
932
|
+
case WSP_GGML_OP_ACC:
|
|
933
|
+
{
|
|
934
|
+
WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_F32);
|
|
935
|
+
WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
|
|
936
|
+
WSP_GGML_ASSERT(dstt == WSP_GGML_TYPE_F32);
|
|
1278
937
|
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
case WSP_GGML_UNARY_OP_RELU:
|
|
1282
|
-
{
|
|
1283
|
-
[encoder setComputePipelineState:ctx->pipeline_relu];
|
|
1284
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1285
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
938
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
|
|
939
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
|
|
1286
940
|
|
|
1287
|
-
|
|
941
|
+
const size_t pnb1 = ((int32_t *) dst->op_params)[0];
|
|
942
|
+
const size_t pnb2 = ((int32_t *) dst->op_params)[1];
|
|
943
|
+
const size_t pnb3 = ((int32_t *) dst->op_params)[2];
|
|
944
|
+
const size_t offs = ((int32_t *) dst->op_params)[3];
|
|
1288
945
|
|
|
1289
|
-
|
|
1290
|
-
} break;
|
|
1291
|
-
case WSP_GGML_UNARY_OP_GELU:
|
|
1292
|
-
{
|
|
1293
|
-
[encoder setComputePipelineState:ctx->pipeline_gelu];
|
|
1294
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1295
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
946
|
+
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
|
|
1296
947
|
|
|
1297
|
-
|
|
1298
|
-
|
|
948
|
+
if (!inplace) {
|
|
949
|
+
// run a separete kernel to cpy src->dst
|
|
950
|
+
// not sure how to avoid this
|
|
951
|
+
// TODO: make a simpler cpy_bytes kernel
|
|
1299
952
|
|
|
1300
|
-
|
|
1301
|
-
} break;
|
|
1302
|
-
case WSP_GGML_UNARY_OP_GELU_QUICK:
|
|
1303
|
-
{
|
|
1304
|
-
[encoder setComputePipelineState:ctx->pipeline_gelu_quick];
|
|
1305
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1306
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
953
|
+
const id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
|
|
1307
954
|
|
|
1308
|
-
|
|
1309
|
-
|
|
955
|
+
[encoder setComputePipelineState:pipeline];
|
|
956
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
957
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
958
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
959
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
960
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
961
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
|
962
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
|
963
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
|
964
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
|
965
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
|
966
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
|
967
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
|
968
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
|
969
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
|
970
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
|
971
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
|
972
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
|
973
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
|
1310
974
|
|
|
1311
|
-
|
|
1312
|
-
} break;
|
|
1313
|
-
case WSP_GGML_UNARY_OP_SILU:
|
|
1314
|
-
{
|
|
1315
|
-
[encoder setComputePipelineState:ctx->pipeline_silu];
|
|
1316
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1317
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
975
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
|
|
1318
976
|
|
|
1319
|
-
|
|
1320
|
-
|
|
977
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
978
|
+
}
|
|
1321
979
|
|
|
1322
|
-
|
|
1323
|
-
|
|
1324
|
-
|
|
1325
|
-
|
|
1326
|
-
|
|
1327
|
-
|
|
1328
|
-
|
|
1329
|
-
|
|
1330
|
-
|
|
1331
|
-
|
|
1332
|
-
|
|
980
|
+
const id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD].pipeline;
|
|
981
|
+
|
|
982
|
+
[encoder setComputePipelineState:pipeline];
|
|
983
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
984
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
985
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
986
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
987
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
988
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
989
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
|
990
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
|
991
|
+
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
|
|
992
|
+
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
|
|
993
|
+
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
|
|
994
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
|
995
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
|
996
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
|
997
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
|
998
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
|
999
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
|
1000
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
|
1001
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
|
1002
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
|
1003
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
|
1004
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
|
1005
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
|
1006
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
|
1007
|
+
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
|
|
1008
|
+
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
|
|
1009
|
+
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
|
|
1010
|
+
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
|
1011
|
+
|
|
1012
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
|
|
1013
|
+
|
|
1014
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1015
|
+
} break;
|
|
1016
|
+
case WSP_GGML_OP_SCALE:
|
|
1017
|
+
{
|
|
1018
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
|
|
1019
|
+
|
|
1020
|
+
const float scale = *(const float *) dst->op_params;
|
|
1021
|
+
|
|
1022
|
+
int64_t n = wsp_ggml_nelements(dst);
|
|
1023
|
+
|
|
1024
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
1025
|
+
|
|
1026
|
+
if (n % 4 == 0) {
|
|
1027
|
+
n /= 4;
|
|
1028
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
|
|
1029
|
+
} else {
|
|
1030
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
|
|
1031
|
+
}
|
|
1333
1032
|
|
|
1334
|
-
|
|
1335
|
-
|
|
1336
|
-
|
|
1033
|
+
[encoder setComputePipelineState:pipeline];
|
|
1034
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1035
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1036
|
+
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
|
1337
1037
|
|
|
1338
|
-
|
|
1339
|
-
|
|
1340
|
-
|
|
1341
|
-
|
|
1342
|
-
|
|
1343
|
-
|
|
1038
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1039
|
+
} break;
|
|
1040
|
+
case WSP_GGML_OP_UNARY:
|
|
1041
|
+
switch (wsp_ggml_get_unary_op(gf->nodes[i])) {
|
|
1042
|
+
case WSP_GGML_UNARY_OP_TANH:
|
|
1043
|
+
{
|
|
1044
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_TANH].pipeline;
|
|
1344
1045
|
|
|
1345
|
-
|
|
1346
|
-
|
|
1347
|
-
|
|
1348
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
1349
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
1350
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
1351
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
|
1352
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
1353
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
1354
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
1355
|
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
|
1356
|
-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
|
1357
|
-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
|
1358
|
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
|
1359
|
-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
|
1360
|
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
|
1361
|
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
|
1362
|
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
|
1363
|
-
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
|
1364
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
|
|
1365
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
|
|
1366
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
|
|
1367
|
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
|
|
1368
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
|
|
1369
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
|
|
1370
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
|
|
1371
|
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
|
|
1372
|
-
|
|
1373
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1374
|
-
} break;
|
|
1375
|
-
case WSP_GGML_OP_SOFT_MAX:
|
|
1376
|
-
{
|
|
1377
|
-
int nth = 32; // SIMD width
|
|
1378
|
-
|
|
1379
|
-
if (ne00%4 == 0) {
|
|
1380
|
-
while (nth < ne00/4 && nth < 256) {
|
|
1381
|
-
nth *= 2;
|
|
1382
|
-
}
|
|
1383
|
-
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
|
1384
|
-
} else {
|
|
1385
|
-
while (nth < ne00 && nth < 1024) {
|
|
1386
|
-
nth *= 2;
|
|
1387
|
-
}
|
|
1388
|
-
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
|
1389
|
-
}
|
|
1046
|
+
[encoder setComputePipelineState:pipeline];
|
|
1047
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1048
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1390
1049
|
|
|
1391
|
-
|
|
1050
|
+
const int64_t n = wsp_ggml_nelements(dst);
|
|
1392
1051
|
|
|
1393
|
-
|
|
1394
|
-
|
|
1395
|
-
|
|
1396
|
-
|
|
1397
|
-
|
|
1398
|
-
}
|
|
1399
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1400
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
1401
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
1402
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
1403
|
-
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
|
1404
|
-
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
1405
|
-
|
|
1406
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1407
|
-
} break;
|
|
1408
|
-
case WSP_GGML_OP_DIAG_MASK_INF:
|
|
1409
|
-
{
|
|
1410
|
-
const int n_past = ((int32_t *)(dst->op_params))[0];
|
|
1411
|
-
|
|
1412
|
-
if (ne00%8 == 0) {
|
|
1413
|
-
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
|
|
1414
|
-
} else {
|
|
1415
|
-
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
|
1416
|
-
}
|
|
1417
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1418
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1419
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
1420
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
1421
|
-
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
|
1052
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1053
|
+
} break;
|
|
1054
|
+
case WSP_GGML_UNARY_OP_RELU:
|
|
1055
|
+
{
|
|
1056
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_RELU].pipeline;
|
|
1422
1057
|
|
|
1423
|
-
|
|
1424
|
-
[encoder
|
|
1425
|
-
|
|
1426
|
-
|
|
1427
|
-
|
|
1058
|
+
[encoder setComputePipelineState:pipeline];
|
|
1059
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1060
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1061
|
+
|
|
1062
|
+
const int64_t n = wsp_ggml_nelements(dst);
|
|
1063
|
+
|
|
1064
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1065
|
+
} break;
|
|
1066
|
+
case WSP_GGML_UNARY_OP_GELU:
|
|
1067
|
+
{
|
|
1068
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GELU].pipeline;
|
|
1069
|
+
|
|
1070
|
+
[encoder setComputePipelineState:pipeline];
|
|
1071
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1072
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1073
|
+
|
|
1074
|
+
const int64_t n = wsp_ggml_nelements(dst);
|
|
1075
|
+
WSP_GGML_ASSERT(n % 4 == 0);
|
|
1076
|
+
|
|
1077
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1078
|
+
} break;
|
|
1079
|
+
case WSP_GGML_UNARY_OP_GELU_QUICK:
|
|
1080
|
+
{
|
|
1081
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
|
|
1082
|
+
|
|
1083
|
+
[encoder setComputePipelineState:pipeline];
|
|
1084
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1085
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1086
|
+
|
|
1087
|
+
const int64_t n = wsp_ggml_nelements(dst);
|
|
1088
|
+
WSP_GGML_ASSERT(n % 4 == 0);
|
|
1089
|
+
|
|
1090
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1091
|
+
} break;
|
|
1092
|
+
case WSP_GGML_UNARY_OP_SILU:
|
|
1093
|
+
{
|
|
1094
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SILU].pipeline;
|
|
1095
|
+
|
|
1096
|
+
[encoder setComputePipelineState:pipeline];
|
|
1097
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1098
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1099
|
+
|
|
1100
|
+
const int64_t n = wsp_ggml_nelements(dst);
|
|
1101
|
+
WSP_GGML_ASSERT(n % 4 == 0);
|
|
1102
|
+
|
|
1103
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1104
|
+
} break;
|
|
1105
|
+
default:
|
|
1106
|
+
{
|
|
1107
|
+
WSP_GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
|
|
1108
|
+
WSP_GGML_ASSERT(false);
|
|
1428
1109
|
}
|
|
1429
|
-
|
|
1430
|
-
|
|
1431
|
-
|
|
1432
|
-
|
|
1110
|
+
} break;
|
|
1111
|
+
case WSP_GGML_OP_SQR:
|
|
1112
|
+
{
|
|
1113
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
|
|
1114
|
+
|
|
1115
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SQR].pipeline;
|
|
1116
|
+
|
|
1117
|
+
[encoder setComputePipelineState:pipeline];
|
|
1118
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1119
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1120
|
+
|
|
1121
|
+
const int64_t n = wsp_ggml_nelements(dst);
|
|
1122
|
+
|
|
1123
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1124
|
+
} break;
|
|
1125
|
+
case WSP_GGML_OP_SUM_ROWS:
|
|
1126
|
+
{
|
|
1127
|
+
WSP_GGML_ASSERT(src0->nb[0] == wsp_ggml_type_size(src0->type));
|
|
1128
|
+
|
|
1129
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
|
1130
|
+
|
|
1131
|
+
[encoder setComputePipelineState:pipeline];
|
|
1132
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1133
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1134
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
1135
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
1136
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
1137
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
|
1138
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
1139
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
1140
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
1141
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
|
1142
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
|
1143
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
|
1144
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
|
1145
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
|
1146
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
|
1147
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
|
1148
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
|
1149
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
|
1150
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
|
|
1151
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
|
|
1152
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
|
|
1153
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
|
|
1154
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
|
|
1155
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
|
|
1156
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
|
|
1157
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
|
|
1158
|
+
|
|
1159
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1160
|
+
} break;
|
|
1161
|
+
case WSP_GGML_OP_SOFT_MAX:
|
|
1162
|
+
{
|
|
1163
|
+
int nth = 32; // SIMD width
|
|
1164
|
+
|
|
1165
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
1166
|
+
|
|
1167
|
+
if (ne00%4 == 0) {
|
|
1168
|
+
while (nth < ne00/4 && nth < 256) {
|
|
1169
|
+
nth *= 2;
|
|
1170
|
+
}
|
|
1171
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
|
|
1172
|
+
} else {
|
|
1173
|
+
while (nth < ne00 && nth < 1024) {
|
|
1174
|
+
nth *= 2;
|
|
1175
|
+
}
|
|
1176
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
|
|
1177
|
+
}
|
|
1433
1178
|
|
|
1434
|
-
|
|
1435
|
-
WSP_GGML_ASSERT(ne12 % ne02 == 0);
|
|
1436
|
-
WSP_GGML_ASSERT(ne13 % ne03 == 0);
|
|
1179
|
+
const float scale = ((float *) dst->op_params)[0];
|
|
1437
1180
|
|
|
1438
|
-
|
|
1439
|
-
|
|
1181
|
+
[encoder setComputePipelineState:pipeline];
|
|
1182
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1183
|
+
if (id_src1) {
|
|
1184
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1185
|
+
} else {
|
|
1186
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
1187
|
+
}
|
|
1188
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1189
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
1190
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
1191
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
1192
|
+
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
|
1193
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
1194
|
+
|
|
1195
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1196
|
+
} break;
|
|
1197
|
+
case WSP_GGML_OP_DIAG_MASK_INF:
|
|
1198
|
+
{
|
|
1199
|
+
const int n_past = ((int32_t *)(dst->op_params))[0];
|
|
1200
|
+
|
|
1201
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
1202
|
+
|
|
1203
|
+
if (ne00%8 == 0) {
|
|
1204
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
|
|
1205
|
+
} else {
|
|
1206
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
|
|
1207
|
+
}
|
|
1208
|
+
|
|
1209
|
+
[encoder setComputePipelineState:pipeline];
|
|
1210
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1211
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1212
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
1213
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
1214
|
+
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
|
1215
|
+
|
|
1216
|
+
if (ne00%8 == 0) {
|
|
1217
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1218
|
+
}
|
|
1219
|
+
else {
|
|
1220
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1221
|
+
}
|
|
1222
|
+
} break;
|
|
1223
|
+
case WSP_GGML_OP_MUL_MAT:
|
|
1224
|
+
{
|
|
1225
|
+
WSP_GGML_ASSERT(ne00 == ne10);
|
|
1440
1226
|
|
|
1441
|
-
|
|
1442
|
-
|
|
1443
|
-
|
|
1227
|
+
// TODO: assert that dim2 and dim3 are contiguous
|
|
1228
|
+
WSP_GGML_ASSERT(ne12 % ne02 == 0);
|
|
1229
|
+
WSP_GGML_ASSERT(ne13 % ne03 == 0);
|
|
1230
|
+
|
|
1231
|
+
const uint r2 = ne12/ne02;
|
|
1232
|
+
const uint r3 = ne13/ne03;
|
|
1233
|
+
|
|
1234
|
+
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
|
1235
|
+
// to the matrix-vector kernel
|
|
1236
|
+
int ne11_mm_min = 1;
|
|
1444
1237
|
|
|
1445
1238
|
#if 0
|
|
1446
|
-
|
|
1447
|
-
|
|
1448
|
-
|
|
1449
|
-
|
|
1450
|
-
|
|
1451
|
-
|
|
1452
|
-
|
|
1453
|
-
|
|
1454
|
-
|
|
1455
|
-
|
|
1456
|
-
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
|
-
|
|
1460
|
-
|
|
1461
|
-
|
|
1462
|
-
|
|
1463
|
-
}
|
|
1239
|
+
// the numbers below are measured on M2 Ultra for 7B and 13B models
|
|
1240
|
+
// these numbers do not translate to other devices or model sizes
|
|
1241
|
+
// TODO: need to find a better approach
|
|
1242
|
+
if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
|
|
1243
|
+
switch (src0t) {
|
|
1244
|
+
case WSP_GGML_TYPE_F16: ne11_mm_min = 2; break;
|
|
1245
|
+
case WSP_GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
|
|
1246
|
+
case WSP_GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
|
|
1247
|
+
case WSP_GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
|
|
1248
|
+
case WSP_GGML_TYPE_Q4_0:
|
|
1249
|
+
case WSP_GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
|
|
1250
|
+
case WSP_GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
|
|
1251
|
+
case WSP_GGML_TYPE_Q5_0: // not tested yet
|
|
1252
|
+
case WSP_GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
|
|
1253
|
+
case WSP_GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
|
|
1254
|
+
case WSP_GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
|
|
1255
|
+
default: ne11_mm_min = 1; break;
|
|
1464
1256
|
}
|
|
1257
|
+
}
|
|
1465
1258
|
#endif
|
|
1466
1259
|
|
|
1467
|
-
|
|
1468
|
-
|
|
1469
|
-
|
|
1470
|
-
|
|
1471
|
-
|
|
1472
|
-
|
|
1473
|
-
|
|
1474
|
-
|
|
1475
|
-
|
|
1476
|
-
|
|
1477
|
-
|
|
1478
|
-
|
|
1479
|
-
|
|
1480
|
-
|
|
1481
|
-
|
|
1482
|
-
|
|
1483
|
-
|
|
1484
|
-
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
|
|
1488
|
-
|
|
1489
|
-
|
|
1490
|
-
|
|
1491
|
-
|
|
1492
|
-
|
|
1493
|
-
|
|
1494
|
-
|
|
1495
|
-
|
|
1496
|
-
|
|
1497
|
-
|
|
1498
|
-
|
|
1499
|
-
|
|
1500
|
-
|
|
1501
|
-
|
|
1502
|
-
|
|
1503
|
-
|
|
1504
|
-
|
|
1505
|
-
|
|
1506
|
-
|
|
1507
|
-
|
|
1508
|
-
|
|
1509
|
-
|
|
1510
|
-
|
|
1511
|
-
|
|
1512
|
-
|
|
1513
|
-
|
|
1514
|
-
|
|
1515
|
-
|
|
1516
|
-
|
|
1517
|
-
|
|
1518
|
-
|
|
1519
|
-
|
|
1520
|
-
|
|
1521
|
-
|
|
1522
|
-
|
|
1523
|
-
|
|
1524
|
-
|
|
1525
|
-
|
|
1526
|
-
|
|
1527
|
-
|
|
1528
|
-
|
|
1529
|
-
|
|
1530
|
-
|
|
1531
|
-
|
|
1532
|
-
|
|
1533
|
-
|
|
1534
|
-
|
|
1535
|
-
|
|
1260
|
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
|
1261
|
+
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
|
1262
|
+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
|
1263
|
+
!wsp_ggml_is_transposed(src0) &&
|
|
1264
|
+
!wsp_ggml_is_transposed(src1) &&
|
|
1265
|
+
src1t == WSP_GGML_TYPE_F32 &&
|
|
1266
|
+
ne00 % 32 == 0 && ne00 >= 64 &&
|
|
1267
|
+
(ne11 > ne11_mm_min || (wsp_ggml_is_quantized(src0t) && ne12 > 1))) {
|
|
1268
|
+
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
1269
|
+
|
|
1270
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
1271
|
+
|
|
1272
|
+
switch (src0->type) {
|
|
1273
|
+
case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
|
|
1274
|
+
case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
|
|
1275
|
+
case WSP_GGML_TYPE_Q4_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
|
|
1276
|
+
case WSP_GGML_TYPE_Q4_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
|
|
1277
|
+
case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
|
|
1278
|
+
case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
|
|
1279
|
+
case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
|
|
1280
|
+
case WSP_GGML_TYPE_Q2_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
|
|
1281
|
+
case WSP_GGML_TYPE_Q3_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
|
|
1282
|
+
case WSP_GGML_TYPE_Q4_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
|
|
1283
|
+
case WSP_GGML_TYPE_Q5_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
|
|
1284
|
+
case WSP_GGML_TYPE_Q6_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
|
|
1285
|
+
case WSP_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
|
1286
|
+
case WSP_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
|
1287
|
+
default: WSP_GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
|
1288
|
+
}
|
|
1289
|
+
|
|
1290
|
+
[encoder setComputePipelineState:pipeline];
|
|
1291
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1292
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1293
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1294
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
1295
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
1296
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
|
1297
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
|
1298
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
|
1299
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
|
|
1300
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
|
|
1301
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
|
1302
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
|
1303
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
|
1304
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
|
|
1305
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
|
|
1306
|
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
|
1307
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
|
1308
|
+
} else {
|
|
1309
|
+
int nth0 = 32;
|
|
1310
|
+
int nth1 = 1;
|
|
1311
|
+
int nrows = 1;
|
|
1312
|
+
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
1313
|
+
|
|
1314
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
1315
|
+
|
|
1316
|
+
// use custom matrix x vector kernel
|
|
1317
|
+
switch (src0t) {
|
|
1318
|
+
case WSP_GGML_TYPE_F32:
|
|
1319
|
+
{
|
|
1320
|
+
WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
|
|
1321
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
|
|
1322
|
+
nrows = 4;
|
|
1323
|
+
} break;
|
|
1324
|
+
case WSP_GGML_TYPE_F16:
|
|
1325
|
+
{
|
|
1326
|
+
nth0 = 32;
|
|
1327
|
+
nth1 = 1;
|
|
1328
|
+
if (src1t == WSP_GGML_TYPE_F32) {
|
|
1329
|
+
if (ne11 * ne12 < 4) {
|
|
1330
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
|
1331
|
+
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
|
1332
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
|
1333
|
+
nrows = ne11;
|
|
1536
1334
|
} else {
|
|
1537
|
-
|
|
1335
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
|
|
1538
1336
|
nrows = 4;
|
|
1539
1337
|
}
|
|
1540
|
-
}
|
|
1541
|
-
|
|
1542
|
-
|
|
1543
|
-
nth0 = 8;
|
|
1544
|
-
nth1 = 8;
|
|
1545
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
|
|
1546
|
-
} break;
|
|
1547
|
-
case WSP_GGML_TYPE_Q4_1:
|
|
1548
|
-
{
|
|
1549
|
-
nth0 = 8;
|
|
1550
|
-
nth1 = 8;
|
|
1551
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
|
|
1552
|
-
} break;
|
|
1553
|
-
case WSP_GGML_TYPE_Q5_0:
|
|
1554
|
-
{
|
|
1555
|
-
nth0 = 8;
|
|
1556
|
-
nth1 = 8;
|
|
1557
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
|
|
1558
|
-
} break;
|
|
1559
|
-
case WSP_GGML_TYPE_Q5_1:
|
|
1560
|
-
{
|
|
1561
|
-
nth0 = 8;
|
|
1562
|
-
nth1 = 8;
|
|
1563
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
|
|
1564
|
-
} break;
|
|
1565
|
-
case WSP_GGML_TYPE_Q8_0:
|
|
1566
|
-
{
|
|
1567
|
-
nth0 = 8;
|
|
1568
|
-
nth1 = 8;
|
|
1569
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
|
|
1570
|
-
} break;
|
|
1571
|
-
case WSP_GGML_TYPE_Q2_K:
|
|
1572
|
-
{
|
|
1573
|
-
nth0 = 2;
|
|
1574
|
-
nth1 = 32;
|
|
1575
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
|
|
1576
|
-
} break;
|
|
1577
|
-
case WSP_GGML_TYPE_Q3_K:
|
|
1578
|
-
{
|
|
1579
|
-
nth0 = 2;
|
|
1580
|
-
nth1 = 32;
|
|
1581
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
|
|
1582
|
-
} break;
|
|
1583
|
-
case WSP_GGML_TYPE_Q4_K:
|
|
1584
|
-
{
|
|
1585
|
-
nth0 = 4; //1;
|
|
1586
|
-
nth1 = 8; //32;
|
|
1587
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
|
|
1588
|
-
} break;
|
|
1589
|
-
case WSP_GGML_TYPE_Q5_K:
|
|
1590
|
-
{
|
|
1591
|
-
nth0 = 2;
|
|
1592
|
-
nth1 = 32;
|
|
1593
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
|
|
1594
|
-
} break;
|
|
1595
|
-
case WSP_GGML_TYPE_Q6_K:
|
|
1596
|
-
{
|
|
1597
|
-
nth0 = 2;
|
|
1598
|
-
nth1 = 32;
|
|
1599
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
|
|
1600
|
-
} break;
|
|
1601
|
-
default:
|
|
1602
|
-
{
|
|
1603
|
-
WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
|
1604
|
-
WSP_GGML_ASSERT(false && "not implemented");
|
|
1338
|
+
} else {
|
|
1339
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
|
|
1340
|
+
nrows = 4;
|
|
1605
1341
|
}
|
|
1606
|
-
|
|
1342
|
+
} break;
|
|
1343
|
+
case WSP_GGML_TYPE_Q4_0:
|
|
1344
|
+
{
|
|
1345
|
+
nth0 = 8;
|
|
1346
|
+
nth1 = 8;
|
|
1347
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
|
|
1348
|
+
} break;
|
|
1349
|
+
case WSP_GGML_TYPE_Q4_1:
|
|
1350
|
+
{
|
|
1351
|
+
nth0 = 8;
|
|
1352
|
+
nth1 = 8;
|
|
1353
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
|
|
1354
|
+
} break;
|
|
1355
|
+
case WSP_GGML_TYPE_Q5_0:
|
|
1356
|
+
{
|
|
1357
|
+
nth0 = 8;
|
|
1358
|
+
nth1 = 8;
|
|
1359
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
|
|
1360
|
+
} break;
|
|
1361
|
+
case WSP_GGML_TYPE_Q5_1:
|
|
1362
|
+
{
|
|
1363
|
+
nth0 = 8;
|
|
1364
|
+
nth1 = 8;
|
|
1365
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
|
|
1366
|
+
} break;
|
|
1367
|
+
case WSP_GGML_TYPE_Q8_0:
|
|
1368
|
+
{
|
|
1369
|
+
nth0 = 8;
|
|
1370
|
+
nth1 = 8;
|
|
1371
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
|
|
1372
|
+
} break;
|
|
1373
|
+
case WSP_GGML_TYPE_Q2_K:
|
|
1374
|
+
{
|
|
1375
|
+
nth0 = 2;
|
|
1376
|
+
nth1 = 32;
|
|
1377
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
|
|
1378
|
+
} break;
|
|
1379
|
+
case WSP_GGML_TYPE_Q3_K:
|
|
1380
|
+
{
|
|
1381
|
+
nth0 = 2;
|
|
1382
|
+
nth1 = 32;
|
|
1383
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
|
|
1384
|
+
} break;
|
|
1385
|
+
case WSP_GGML_TYPE_Q4_K:
|
|
1386
|
+
{
|
|
1387
|
+
nth0 = 4; //1;
|
|
1388
|
+
nth1 = 8; //32;
|
|
1389
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
|
|
1390
|
+
} break;
|
|
1391
|
+
case WSP_GGML_TYPE_Q5_K:
|
|
1392
|
+
{
|
|
1393
|
+
nth0 = 2;
|
|
1394
|
+
nth1 = 32;
|
|
1395
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
|
|
1396
|
+
} break;
|
|
1397
|
+
case WSP_GGML_TYPE_Q6_K:
|
|
1398
|
+
{
|
|
1399
|
+
nth0 = 2;
|
|
1400
|
+
nth1 = 32;
|
|
1401
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
|
|
1402
|
+
} break;
|
|
1403
|
+
case WSP_GGML_TYPE_IQ2_XXS:
|
|
1404
|
+
{
|
|
1405
|
+
nth0 = 4;
|
|
1406
|
+
nth1 = 16;
|
|
1407
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
|
|
1408
|
+
} break;
|
|
1409
|
+
case WSP_GGML_TYPE_IQ2_XS:
|
|
1410
|
+
{
|
|
1411
|
+
nth0 = 4;
|
|
1412
|
+
nth1 = 16;
|
|
1413
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
|
|
1414
|
+
} break;
|
|
1415
|
+
default:
|
|
1416
|
+
{
|
|
1417
|
+
WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
|
1418
|
+
WSP_GGML_ASSERT(false && "not implemented");
|
|
1419
|
+
}
|
|
1420
|
+
};
|
|
1607
1421
|
|
|
1608
|
-
|
|
1609
|
-
|
|
1610
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1611
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
1612
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
1613
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
1614
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
1615
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
1616
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
1617
|
-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
|
|
1618
|
-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
|
|
1619
|
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
|
|
1620
|
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
|
|
1621
|
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
|
|
1622
|
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
|
1623
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
|
1624
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
|
1625
|
-
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
|
|
1626
|
-
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
|
|
1627
|
-
|
|
1628
|
-
if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 ||
|
|
1629
|
-
src0t == WSP_GGML_TYPE_Q5_0 || src0t == WSP_GGML_TYPE_Q5_1 || src0t == WSP_GGML_TYPE_Q8_0 ||
|
|
1630
|
-
src0t == WSP_GGML_TYPE_Q2_K) { // || src0t == WSP_GGML_TYPE_Q4_K) {
|
|
1631
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1632
|
-
}
|
|
1633
|
-
else if (src0t == WSP_GGML_TYPE_Q4_K) {
|
|
1634
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1635
|
-
}
|
|
1636
|
-
else if (src0t == WSP_GGML_TYPE_Q3_K) {
|
|
1637
|
-
#ifdef WSP_GGML_QKK_64
|
|
1638
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1639
|
-
#else
|
|
1640
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1641
|
-
#endif
|
|
1642
|
-
}
|
|
1643
|
-
else if (src0t == WSP_GGML_TYPE_Q5_K) {
|
|
1644
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1645
|
-
}
|
|
1646
|
-
else if (src0t == WSP_GGML_TYPE_Q6_K) {
|
|
1647
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1648
|
-
} else {
|
|
1649
|
-
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
|
1650
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1651
|
-
}
|
|
1422
|
+
if (wsp_ggml_is_quantized(src0t)) {
|
|
1423
|
+
WSP_GGML_ASSERT(ne00 >= nth0*nth1);
|
|
1652
1424
|
}
|
|
1653
|
-
} break;
|
|
1654
|
-
case WSP_GGML_OP_MUL_MAT_ID:
|
|
1655
|
-
{
|
|
1656
|
-
//WSP_GGML_ASSERT(ne00 == ne10);
|
|
1657
|
-
//WSP_GGML_ASSERT(ne03 == ne13);
|
|
1658
|
-
|
|
1659
|
-
WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_I32);
|
|
1660
|
-
|
|
1661
|
-
const int n_as = ((int32_t *) dst->op_params)[1];
|
|
1662
|
-
|
|
1663
|
-
// TODO: make this more general
|
|
1664
|
-
WSP_GGML_ASSERT(n_as <= 8);
|
|
1665
|
-
|
|
1666
|
-
struct wsp_ggml_tensor * src2 = gf->nodes[i]->src[2];
|
|
1667
|
-
|
|
1668
|
-
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
|
1669
|
-
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
|
1670
|
-
const int64_t ne22 = src2 ? src2->ne[2] : 0;
|
|
1671
|
-
const int64_t ne23 = src2 ? src2->ne[3] : 0; WSP_GGML_UNUSED(ne23);
|
|
1672
|
-
|
|
1673
|
-
const uint64_t nb20 = src2 ? src2->nb[0] : 0; WSP_GGML_UNUSED(nb20);
|
|
1674
|
-
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
|
1675
|
-
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
|
1676
|
-
const uint64_t nb23 = src2 ? src2->nb[3] : 0; WSP_GGML_UNUSED(nb23);
|
|
1677
|
-
|
|
1678
|
-
const enum wsp_ggml_type src2t = src2 ? src2->type : WSP_GGML_TYPE_COUNT; WSP_GGML_UNUSED(src2t);
|
|
1679
|
-
|
|
1680
|
-
WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src2));
|
|
1681
|
-
WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src1));
|
|
1682
|
-
|
|
1683
|
-
WSP_GGML_ASSERT(ne20 % 32 == 0);
|
|
1684
|
-
// !!!!!!!!! TODO: this assert is probably required but not sure!
|
|
1685
|
-
//WSP_GGML_ASSERT(ne20 >= 64);
|
|
1686
|
-
WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
|
|
1687
|
-
|
|
1688
|
-
const uint r2 = ne12/ne22;
|
|
1689
|
-
const uint r3 = ne13/ne23;
|
|
1690
|
-
|
|
1691
|
-
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
|
1692
|
-
// to the matrix-vector kernel
|
|
1693
|
-
int ne11_mm_min = 1;
|
|
1694
|
-
|
|
1695
|
-
const int idx = ((int32_t *) dst->op_params)[0];
|
|
1696
|
-
|
|
1697
|
-
// batch size
|
|
1698
|
-
WSP_GGML_ASSERT(ne01 == ne11);
|
|
1699
|
-
|
|
1700
|
-
const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
|
|
1701
|
-
|
|
1702
|
-
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
|
1703
|
-
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
|
1704
|
-
// !!!
|
|
1705
|
-
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
|
|
1706
|
-
// indirect matrix multiplication
|
|
1707
|
-
// !!!
|
|
1708
|
-
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
|
|
1709
|
-
switch (src2->type) {
|
|
1710
|
-
case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
|
|
1711
|
-
case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
|
|
1712
|
-
case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
|
|
1713
|
-
case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
|
|
1714
|
-
case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
|
|
1715
|
-
case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
|
|
1716
|
-
case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
|
|
1717
|
-
case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
|
|
1718
|
-
case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
|
|
1719
|
-
case WSP_GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
|
|
1720
|
-
case WSP_GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
|
|
1721
|
-
case WSP_GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
|
|
1722
|
-
default: WSP_GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
|
1723
|
-
}
|
|
1724
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1725
|
-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1726
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1727
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
|
1728
|
-
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
|
1729
|
-
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
|
|
1730
|
-
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
|
1731
|
-
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
|
|
1732
|
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
|
|
1733
|
-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
|
|
1734
|
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
|
|
1735
|
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
|
|
1736
|
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
|
|
1737
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
|
1738
|
-
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
|
|
1739
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
|
1740
|
-
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
|
|
1741
|
-
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
|
1742
|
-
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
|
|
1743
|
-
// TODO: how to make this an array? read Metal docs
|
|
1744
|
-
for (int j = 0; j < n_as; ++j) {
|
|
1745
|
-
struct wsp_ggml_tensor * src_cur = dst->src[2 + j];
|
|
1746
|
-
|
|
1747
|
-
size_t offs_src_cur = 0;
|
|
1748
|
-
id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
|
1749
|
-
|
|
1750
|
-
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
|
1751
|
-
}
|
|
1752
|
-
|
|
1753
|
-
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
|
1754
|
-
|
|
1755
|
-
// TODO: processing one row at a time (ne11 -> 1) is not efficient
|
|
1756
|
-
[encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
|
1757
|
-
} else {
|
|
1758
|
-
int nth0 = 32;
|
|
1759
|
-
int nth1 = 1;
|
|
1760
|
-
int nrows = 1;
|
|
1761
|
-
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
1762
|
-
|
|
1763
|
-
// use custom matrix x vector kernel
|
|
1764
|
-
switch (src2t) {
|
|
1765
|
-
case WSP_GGML_TYPE_F32:
|
|
1766
|
-
{
|
|
1767
|
-
WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
|
|
1768
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
|
|
1769
|
-
} break;
|
|
1770
|
-
case WSP_GGML_TYPE_F16:
|
|
1771
|
-
{
|
|
1772
|
-
WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
|
|
1773
|
-
nth0 = 32;
|
|
1774
|
-
nth1 = 1;
|
|
1775
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
|
|
1776
|
-
} break;
|
|
1777
|
-
case WSP_GGML_TYPE_Q4_0:
|
|
1778
|
-
{
|
|
1779
|
-
nth0 = 8;
|
|
1780
|
-
nth1 = 8;
|
|
1781
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
|
|
1782
|
-
} break;
|
|
1783
|
-
case WSP_GGML_TYPE_Q4_1:
|
|
1784
|
-
{
|
|
1785
|
-
nth0 = 8;
|
|
1786
|
-
nth1 = 8;
|
|
1787
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
|
|
1788
|
-
} break;
|
|
1789
|
-
case WSP_GGML_TYPE_Q5_0:
|
|
1790
|
-
{
|
|
1791
|
-
nth0 = 8;
|
|
1792
|
-
nth1 = 8;
|
|
1793
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
|
|
1794
|
-
} break;
|
|
1795
|
-
case WSP_GGML_TYPE_Q5_1:
|
|
1796
|
-
{
|
|
1797
|
-
nth0 = 8;
|
|
1798
|
-
nth1 = 8;
|
|
1799
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
|
|
1800
|
-
} break;
|
|
1801
|
-
case WSP_GGML_TYPE_Q8_0:
|
|
1802
|
-
{
|
|
1803
|
-
nth0 = 8;
|
|
1804
|
-
nth1 = 8;
|
|
1805
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
|
|
1806
|
-
} break;
|
|
1807
|
-
case WSP_GGML_TYPE_Q2_K:
|
|
1808
|
-
{
|
|
1809
|
-
nth0 = 2;
|
|
1810
|
-
nth1 = 32;
|
|
1811
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
|
|
1812
|
-
} break;
|
|
1813
|
-
case WSP_GGML_TYPE_Q3_K:
|
|
1814
|
-
{
|
|
1815
|
-
nth0 = 2;
|
|
1816
|
-
nth1 = 32;
|
|
1817
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
|
|
1818
|
-
} break;
|
|
1819
|
-
case WSP_GGML_TYPE_Q4_K:
|
|
1820
|
-
{
|
|
1821
|
-
nth0 = 4; //1;
|
|
1822
|
-
nth1 = 8; //32;
|
|
1823
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
|
|
1824
|
-
} break;
|
|
1825
|
-
case WSP_GGML_TYPE_Q5_K:
|
|
1826
|
-
{
|
|
1827
|
-
nth0 = 2;
|
|
1828
|
-
nth1 = 32;
|
|
1829
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
|
|
1830
|
-
} break;
|
|
1831
|
-
case WSP_GGML_TYPE_Q6_K:
|
|
1832
|
-
{
|
|
1833
|
-
nth0 = 2;
|
|
1834
|
-
nth1 = 32;
|
|
1835
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
|
|
1836
|
-
} break;
|
|
1837
|
-
default:
|
|
1838
|
-
{
|
|
1839
|
-
WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
|
1840
|
-
WSP_GGML_ASSERT(false && "not implemented");
|
|
1841
|
-
}
|
|
1842
|
-
};
|
|
1843
1425
|
|
|
1844
|
-
|
|
1845
|
-
|
|
1846
|
-
|
|
1847
|
-
|
|
1848
|
-
|
|
1849
|
-
|
|
1850
|
-
|
|
1851
|
-
|
|
1852
|
-
|
|
1853
|
-
|
|
1854
|
-
|
|
1855
|
-
|
|
1856
|
-
|
|
1857
|
-
|
|
1858
|
-
|
|
1859
|
-
|
|
1860
|
-
|
|
1861
|
-
|
|
1862
|
-
|
|
1863
|
-
|
|
1864
|
-
|
|
1865
|
-
|
|
1866
|
-
|
|
1867
|
-
|
|
1868
|
-
|
|
1869
|
-
|
|
1870
|
-
|
|
1871
|
-
|
|
1872
|
-
|
|
1873
|
-
|
|
1874
|
-
|
|
1875
|
-
|
|
1876
|
-
|
|
1877
|
-
|
|
1878
|
-
|
|
1879
|
-
src2t == WSP_GGML_TYPE_Q2_K) { // || src2t == WSP_GGML_TYPE_Q4_K) {
|
|
1880
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1881
|
-
}
|
|
1882
|
-
else if (src2t == WSP_GGML_TYPE_Q4_K) {
|
|
1883
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1884
|
-
}
|
|
1885
|
-
else if (src2t == WSP_GGML_TYPE_Q3_K) {
|
|
1426
|
+
[encoder setComputePipelineState:pipeline];
|
|
1427
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1428
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1429
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1430
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
1431
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
1432
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
1433
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
1434
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
1435
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
1436
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
|
|
1437
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
|
|
1438
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
|
|
1439
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
|
|
1440
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
|
|
1441
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
|
1442
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
|
1443
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
|
1444
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
|
|
1445
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
|
|
1446
|
+
|
|
1447
|
+
if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 ||
|
|
1448
|
+
src0t == WSP_GGML_TYPE_Q5_0 || src0t == WSP_GGML_TYPE_Q5_1 || src0t == WSP_GGML_TYPE_Q8_0 ||
|
|
1449
|
+
src0t == WSP_GGML_TYPE_Q2_K) { // || src0t == WSP_GGML_TYPE_Q4_K) {
|
|
1450
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1451
|
+
}
|
|
1452
|
+
else if (src0t == WSP_GGML_TYPE_IQ2_XXS || src0t == WSP_GGML_TYPE_IQ2_XS) {
|
|
1453
|
+
const int mem_size = src0t == WSP_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
|
1454
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
1455
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1456
|
+
}
|
|
1457
|
+
else if (src0t == WSP_GGML_TYPE_Q4_K) {
|
|
1458
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1459
|
+
}
|
|
1460
|
+
else if (src0t == WSP_GGML_TYPE_Q3_K) {
|
|
1886
1461
|
#ifdef WSP_GGML_QKK_64
|
|
1887
|
-
|
|
1462
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1888
1463
|
#else
|
|
1889
|
-
|
|
1464
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1890
1465
|
#endif
|
|
1891
|
-
}
|
|
1892
|
-
else if (src2t == WSP_GGML_TYPE_Q5_K) {
|
|
1893
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1894
|
-
}
|
|
1895
|
-
else if (src2t == WSP_GGML_TYPE_Q6_K) {
|
|
1896
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1897
|
-
} else {
|
|
1898
|
-
const int64_t ny = (_ne1 + nrows - 1)/nrows;
|
|
1899
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1900
|
-
}
|
|
1901
1466
|
}
|
|
1902
|
-
|
|
1903
|
-
|
|
1904
|
-
{
|
|
1905
|
-
switch (src0->type) {
|
|
1906
|
-
case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
|
|
1907
|
-
case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
|
1908
|
-
case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
|
1909
|
-
case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
|
1910
|
-
case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
|
|
1911
|
-
case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
|
|
1912
|
-
case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
|
|
1913
|
-
case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
|
|
1914
|
-
case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
|
|
1915
|
-
case WSP_GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
|
|
1916
|
-
case WSP_GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
|
|
1917
|
-
case WSP_GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
|
|
1918
|
-
default: WSP_GGML_ASSERT(false && "not implemented");
|
|
1467
|
+
else if (src0t == WSP_GGML_TYPE_Q5_K) {
|
|
1468
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1919
1469
|
}
|
|
1920
|
-
|
|
1921
|
-
|
|
1922
|
-
|
|
1923
|
-
|
|
1924
|
-
|
|
1925
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
|
1926
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
|
1927
|
-
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
|
1928
|
-
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
|
|
1929
|
-
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
|
|
1930
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
|
|
1931
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
|
|
1932
|
-
|
|
1933
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
|
1934
|
-
} break;
|
|
1935
|
-
case WSP_GGML_OP_RMS_NORM:
|
|
1936
|
-
{
|
|
1937
|
-
WSP_GGML_ASSERT(ne00 % 4 == 0);
|
|
1938
|
-
|
|
1939
|
-
float eps;
|
|
1940
|
-
memcpy(&eps, dst->op_params, sizeof(float));
|
|
1941
|
-
|
|
1942
|
-
int nth = 32; // SIMD width
|
|
1943
|
-
|
|
1944
|
-
while (nth < ne00/4 && nth < 1024) {
|
|
1945
|
-
nth *= 2;
|
|
1470
|
+
else if (src0t == WSP_GGML_TYPE_Q6_K) {
|
|
1471
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1472
|
+
} else {
|
|
1473
|
+
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
|
1474
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1946
1475
|
}
|
|
1476
|
+
}
|
|
1477
|
+
} break;
|
|
1478
|
+
case WSP_GGML_OP_MUL_MAT_ID:
|
|
1479
|
+
{
|
|
1480
|
+
//WSP_GGML_ASSERT(ne00 == ne10);
|
|
1481
|
+
//WSP_GGML_ASSERT(ne03 == ne13);
|
|
1947
1482
|
|
|
1948
|
-
|
|
1949
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1950
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1951
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1952
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
|
1953
|
-
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
|
1954
|
-
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
1955
|
-
|
|
1956
|
-
const int64_t nrows = wsp_ggml_nrows(src0);
|
|
1957
|
-
|
|
1958
|
-
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1959
|
-
} break;
|
|
1960
|
-
case WSP_GGML_OP_GROUP_NORM:
|
|
1961
|
-
{
|
|
1962
|
-
WSP_GGML_ASSERT(ne00 % 4 == 0);
|
|
1963
|
-
|
|
1964
|
-
//float eps;
|
|
1965
|
-
//memcpy(&eps, dst->op_params, sizeof(float));
|
|
1966
|
-
|
|
1967
|
-
const float eps = 1e-6f; // TODO: temporarily hardcoded
|
|
1968
|
-
|
|
1969
|
-
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
|
|
1970
|
-
|
|
1971
|
-
int nth = 32; // SIMD width
|
|
1972
|
-
|
|
1973
|
-
//while (nth < ne00/4 && nth < 1024) {
|
|
1974
|
-
// nth *= 2;
|
|
1975
|
-
//}
|
|
1976
|
-
|
|
1977
|
-
[encoder setComputePipelineState:ctx->pipeline_group_norm];
|
|
1978
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1979
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1980
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1981
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
1982
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
1983
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
|
|
1984
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
|
|
1985
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
|
|
1986
|
-
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
|
|
1987
|
-
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
|
|
1988
|
-
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
1989
|
-
|
|
1990
|
-
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1991
|
-
} break;
|
|
1992
|
-
case WSP_GGML_OP_NORM:
|
|
1993
|
-
{
|
|
1994
|
-
float eps;
|
|
1995
|
-
memcpy(&eps, dst->op_params, sizeof(float));
|
|
1996
|
-
|
|
1997
|
-
const int nth = MIN(256, ne00);
|
|
1998
|
-
|
|
1999
|
-
[encoder setComputePipelineState:ctx->pipeline_norm];
|
|
2000
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2001
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2002
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
2003
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
|
2004
|
-
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
|
2005
|
-
[encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth*sizeof(float), 16) atIndex:0];
|
|
2006
|
-
|
|
2007
|
-
const int64_t nrows = wsp_ggml_nrows(src0);
|
|
2008
|
-
|
|
2009
|
-
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2010
|
-
} break;
|
|
2011
|
-
case WSP_GGML_OP_ALIBI:
|
|
2012
|
-
{
|
|
2013
|
-
WSP_GGML_ASSERT((src0t == WSP_GGML_TYPE_F32));
|
|
2014
|
-
|
|
2015
|
-
const int nth = MIN(1024, ne00);
|
|
2016
|
-
|
|
2017
|
-
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
2018
|
-
const int n_head = ((int32_t *) dst->op_params)[1];
|
|
2019
|
-
float max_bias;
|
|
2020
|
-
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
|
2021
|
-
|
|
2022
|
-
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
|
2023
|
-
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
|
2024
|
-
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
|
1483
|
+
WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_I32);
|
|
2025
1484
|
|
|
2026
|
-
|
|
2027
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2028
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2029
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
2030
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
2031
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
2032
|
-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
|
2033
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
|
2034
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
|
2035
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
|
2036
|
-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
|
2037
|
-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
|
2038
|
-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
|
2039
|
-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
|
2040
|
-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
|
2041
|
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
|
2042
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
|
2043
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
|
2044
|
-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
|
2045
|
-
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
|
2046
|
-
[encoder setBytes:&m1 length:sizeof( float) atIndex:19];
|
|
2047
|
-
[encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
|
|
1485
|
+
const int n_as = ((int32_t *) dst->op_params)[1];
|
|
2048
1486
|
|
|
2049
|
-
|
|
2050
|
-
|
|
2051
|
-
case WSP_GGML_OP_ROPE:
|
|
2052
|
-
{
|
|
2053
|
-
WSP_GGML_ASSERT(ne10 == ne02);
|
|
2054
|
-
|
|
2055
|
-
const int nth = MIN(1024, ne00);
|
|
2056
|
-
|
|
2057
|
-
const int n_past = ((int32_t *) dst->op_params)[0];
|
|
2058
|
-
const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
2059
|
-
const int mode = ((int32_t *) dst->op_params)[2];
|
|
2060
|
-
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
|
2061
|
-
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
|
2062
|
-
|
|
2063
|
-
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
|
2064
|
-
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
|
2065
|
-
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
|
2066
|
-
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
|
2067
|
-
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
|
2068
|
-
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
|
2069
|
-
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
|
1487
|
+
// TODO: make this more general
|
|
1488
|
+
WSP_GGML_ASSERT(n_as <= 8);
|
|
2070
1489
|
|
|
2071
|
-
|
|
2072
|
-
|
|
2073
|
-
case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
|
|
2074
|
-
default: WSP_GGML_ASSERT(false);
|
|
2075
|
-
};
|
|
1490
|
+
// max size of the src1ids array in the kernel stack
|
|
1491
|
+
WSP_GGML_ASSERT(ne11 <= 512);
|
|
2076
1492
|
|
|
2077
|
-
|
|
2078
|
-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
2079
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
2080
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
|
2081
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
|
|
2082
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
|
|
2083
|
-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
|
|
2084
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
|
|
2085
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
|
2086
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
|
2087
|
-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
|
2088
|
-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
|
|
2089
|
-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
|
|
2090
|
-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
|
|
2091
|
-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
|
|
2092
|
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
|
|
2093
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
|
|
2094
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
|
|
2095
|
-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
|
|
2096
|
-
[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
|
|
2097
|
-
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
|
|
2098
|
-
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
|
|
2099
|
-
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
|
|
2100
|
-
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
|
|
2101
|
-
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
|
|
2102
|
-
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
|
|
2103
|
-
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
|
|
2104
|
-
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
|
|
2105
|
-
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
|
|
1493
|
+
struct wsp_ggml_tensor * src2 = gf->nodes[i]->src[2];
|
|
2106
1494
|
|
|
2107
|
-
|
|
2108
|
-
|
|
2109
|
-
|
|
2110
|
-
|
|
2111
|
-
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
|
|
2112
|
-
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
|
|
2113
|
-
WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
|
|
1495
|
+
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
|
1496
|
+
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
|
1497
|
+
const int64_t ne22 = src2 ? src2->ne[2] : 0;
|
|
1498
|
+
const int64_t ne23 = src2 ? src2->ne[3] : 0; WSP_GGML_UNUSED(ne23);
|
|
2114
1499
|
|
|
2115
|
-
|
|
2116
|
-
|
|
2117
|
-
|
|
2118
|
-
|
|
2119
|
-
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
|
2120
|
-
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
|
2121
|
-
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
|
1500
|
+
const uint64_t nb20 = src2 ? src2->nb[0] : 0; WSP_GGML_UNUSED(nb20);
|
|
1501
|
+
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
|
1502
|
+
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
|
1503
|
+
const uint64_t nb23 = src2 ? src2->nb[3] : 0; WSP_GGML_UNUSED(nb23);
|
|
2122
1504
|
|
|
2123
|
-
|
|
2124
|
-
const int32_t IC = src1->ne[is_2D ? 2 : 1];
|
|
2125
|
-
const int32_t IH = is_2D ? src1->ne[1] : 1;
|
|
2126
|
-
const int32_t IW = src1->ne[0];
|
|
1505
|
+
const enum wsp_ggml_type src2t = src2 ? src2->type : WSP_GGML_TYPE_COUNT; WSP_GGML_UNUSED(src2t);
|
|
2127
1506
|
|
|
2128
|
-
|
|
2129
|
-
|
|
1507
|
+
WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src2));
|
|
1508
|
+
WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src1));
|
|
2130
1509
|
|
|
2131
|
-
|
|
2132
|
-
const int32_t OW = dst->ne[1];
|
|
1510
|
+
WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
|
|
2133
1511
|
|
|
2134
|
-
|
|
1512
|
+
const uint r2 = ne12/ne22;
|
|
1513
|
+
const uint r3 = ne13/ne23;
|
|
2135
1514
|
|
|
2136
|
-
|
|
2137
|
-
|
|
1515
|
+
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
|
1516
|
+
// to the matrix-vector kernel
|
|
1517
|
+
int ne11_mm_min = n_as;
|
|
2138
1518
|
|
|
2139
|
-
|
|
2140
|
-
case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(false && "not implemented"); break;
|
|
2141
|
-
case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
|
|
2142
|
-
default: WSP_GGML_ASSERT(false);
|
|
2143
|
-
};
|
|
1519
|
+
const int idx = ((int32_t *) dst->op_params)[0];
|
|
2144
1520
|
|
|
2145
|
-
|
|
2146
|
-
|
|
2147
|
-
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
|
|
2148
|
-
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
|
|
2149
|
-
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
|
|
2150
|
-
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
|
|
2151
|
-
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
|
|
2152
|
-
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
|
|
2153
|
-
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
|
|
2154
|
-
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
|
|
2155
|
-
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
|
|
2156
|
-
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
|
|
2157
|
-
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
|
|
2158
|
-
|
|
2159
|
-
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
|
2160
|
-
} break;
|
|
2161
|
-
case WSP_GGML_OP_UPSCALE:
|
|
2162
|
-
{
|
|
2163
|
-
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
|
|
2164
|
-
|
|
2165
|
-
const int sf = dst->op_params[0];
|
|
2166
|
-
|
|
2167
|
-
[encoder setComputePipelineState:ctx->pipeline_upscale_f32];
|
|
2168
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2169
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2170
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
2171
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
2172
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
2173
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
|
2174
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
2175
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
2176
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
2177
|
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
|
2178
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
|
2179
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
|
2180
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
|
2181
|
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
|
2182
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
|
2183
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
|
2184
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
|
2185
|
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
|
2186
|
-
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
|
|
2187
|
-
|
|
2188
|
-
const int nth = MIN(1024, ne0);
|
|
2189
|
-
|
|
2190
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2191
|
-
} break;
|
|
2192
|
-
case WSP_GGML_OP_PAD:
|
|
2193
|
-
{
|
|
2194
|
-
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
|
|
2195
|
-
|
|
2196
|
-
[encoder setComputePipelineState:ctx->pipeline_pad_f32];
|
|
2197
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2198
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2199
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
2200
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
2201
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
2202
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
|
2203
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
2204
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
2205
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
2206
|
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
|
2207
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
|
2208
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
|
2209
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
|
2210
|
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
|
2211
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
|
2212
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
|
2213
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
|
2214
|
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
|
2215
|
-
|
|
2216
|
-
const int nth = MIN(1024, ne0);
|
|
2217
|
-
|
|
2218
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2219
|
-
} break;
|
|
2220
|
-
case WSP_GGML_OP_ARGSORT:
|
|
2221
|
-
{
|
|
2222
|
-
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
|
|
2223
|
-
WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_I32);
|
|
2224
|
-
|
|
2225
|
-
const int nrows = wsp_ggml_nrows(src0);
|
|
2226
|
-
|
|
2227
|
-
enum wsp_ggml_sort_order order = (enum wsp_ggml_sort_order) dst->op_params[0];
|
|
2228
|
-
|
|
2229
|
-
switch (order) {
|
|
2230
|
-
case WSP_GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
|
|
2231
|
-
case WSP_GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
|
|
2232
|
-
default: WSP_GGML_ASSERT(false);
|
|
2233
|
-
};
|
|
1521
|
+
// batch size
|
|
1522
|
+
WSP_GGML_ASSERT(ne01 == ne11);
|
|
2234
1523
|
|
|
2235
|
-
|
|
2236
|
-
|
|
2237
|
-
|
|
1524
|
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
|
1525
|
+
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
|
1526
|
+
// !!!
|
|
1527
|
+
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
|
|
1528
|
+
// indirect matrix multiplication
|
|
1529
|
+
// !!!
|
|
1530
|
+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
|
1531
|
+
ne20 % 32 == 0 && ne20 >= 64 &&
|
|
1532
|
+
ne11 > ne11_mm_min) {
|
|
2238
1533
|
|
|
2239
|
-
|
|
2240
|
-
} break;
|
|
2241
|
-
case WSP_GGML_OP_LEAKY_RELU:
|
|
2242
|
-
{
|
|
2243
|
-
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
|
|
1534
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2244
1535
|
|
|
2245
|
-
|
|
2246
|
-
|
|
1536
|
+
switch (src2->type) {
|
|
1537
|
+
case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
|
|
1538
|
+
case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
|
|
1539
|
+
case WSP_GGML_TYPE_Q4_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
|
|
1540
|
+
case WSP_GGML_TYPE_Q4_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
|
|
1541
|
+
case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
|
|
1542
|
+
case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
|
|
1543
|
+
case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
|
|
1544
|
+
case WSP_GGML_TYPE_Q2_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
|
|
1545
|
+
case WSP_GGML_TYPE_Q3_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
|
|
1546
|
+
case WSP_GGML_TYPE_Q4_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
|
|
1547
|
+
case WSP_GGML_TYPE_Q5_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
|
|
1548
|
+
case WSP_GGML_TYPE_Q6_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
|
|
1549
|
+
case WSP_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
|
|
1550
|
+
case WSP_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
|
1551
|
+
default: WSP_GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
|
1552
|
+
}
|
|
2247
1553
|
|
|
2248
|
-
[encoder setComputePipelineState:
|
|
2249
|
-
[encoder setBuffer:id_src0 offset:offs_src0
|
|
2250
|
-
[encoder setBuffer:
|
|
2251
|
-
[encoder
|
|
1554
|
+
[encoder setComputePipelineState:pipeline];
|
|
1555
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1556
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1557
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1558
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
|
1559
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
|
1560
|
+
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
|
|
1561
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
|
1562
|
+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
|
|
1563
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
|
|
1564
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
|
|
1565
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
|
|
1566
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
|
|
1567
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
|
|
1568
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
|
1569
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
|
1570
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
|
1571
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
|
|
1572
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
|
1573
|
+
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
|
|
1574
|
+
// TODO: how to make this an array? read Metal docs
|
|
1575
|
+
for (int j = 0; j < 8; ++j) {
|
|
1576
|
+
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
|
1577
|
+
struct wsp_ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
|
1578
|
+
|
|
1579
|
+
size_t offs_src_cur = 0;
|
|
1580
|
+
id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
|
1581
|
+
|
|
1582
|
+
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
|
1583
|
+
}
|
|
2252
1584
|
|
|
2253
|
-
|
|
1585
|
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
|
2254
1586
|
|
|
2255
|
-
[encoder dispatchThreadgroups:MTLSizeMake(
|
|
2256
|
-
}
|
|
2257
|
-
|
|
2258
|
-
|
|
2259
|
-
|
|
2260
|
-
|
|
2261
|
-
WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0);
|
|
1587
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
|
1588
|
+
} else {
|
|
1589
|
+
int nth0 = 32;
|
|
1590
|
+
int nth1 = 1;
|
|
1591
|
+
int nrows = 1;
|
|
1592
|
+
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
2262
1593
|
|
|
2263
|
-
|
|
1594
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2264
1595
|
|
|
2265
|
-
|
|
1596
|
+
// use custom matrix x vector kernel
|
|
1597
|
+
switch (src2t) {
|
|
2266
1598
|
case WSP_GGML_TYPE_F32:
|
|
2267
1599
|
{
|
|
2268
|
-
WSP_GGML_ASSERT(
|
|
2269
|
-
|
|
2270
|
-
switch (dstt) {
|
|
2271
|
-
case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
|
|
2272
|
-
case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
|
|
2273
|
-
case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
|
|
2274
|
-
case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
|
|
2275
|
-
case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
|
|
2276
|
-
//case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
|
|
2277
|
-
//case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
|
|
2278
|
-
default: WSP_GGML_ASSERT(false && "not implemented");
|
|
2279
|
-
};
|
|
1600
|
+
WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
|
|
1601
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
|
|
2280
1602
|
} break;
|
|
2281
1603
|
case WSP_GGML_TYPE_F16:
|
|
2282
1604
|
{
|
|
2283
|
-
|
|
2284
|
-
|
|
2285
|
-
|
|
2286
|
-
|
|
2287
|
-
|
|
1605
|
+
WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
|
|
1606
|
+
nth0 = 32;
|
|
1607
|
+
nth1 = 1;
|
|
1608
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
|
|
1609
|
+
} break;
|
|
1610
|
+
case WSP_GGML_TYPE_Q4_0:
|
|
1611
|
+
{
|
|
1612
|
+
nth0 = 8;
|
|
1613
|
+
nth1 = 8;
|
|
1614
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
|
|
1615
|
+
} break;
|
|
1616
|
+
case WSP_GGML_TYPE_Q4_1:
|
|
1617
|
+
{
|
|
1618
|
+
nth0 = 8;
|
|
1619
|
+
nth1 = 8;
|
|
1620
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
|
|
1621
|
+
} break;
|
|
1622
|
+
case WSP_GGML_TYPE_Q5_0:
|
|
1623
|
+
{
|
|
1624
|
+
nth0 = 8;
|
|
1625
|
+
nth1 = 8;
|
|
1626
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
|
|
1627
|
+
} break;
|
|
1628
|
+
case WSP_GGML_TYPE_Q5_1:
|
|
1629
|
+
{
|
|
1630
|
+
nth0 = 8;
|
|
1631
|
+
nth1 = 8;
|
|
1632
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
|
|
1633
|
+
} break;
|
|
1634
|
+
case WSP_GGML_TYPE_Q8_0:
|
|
1635
|
+
{
|
|
1636
|
+
nth0 = 8;
|
|
1637
|
+
nth1 = 8;
|
|
1638
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
|
|
1639
|
+
} break;
|
|
1640
|
+
case WSP_GGML_TYPE_Q2_K:
|
|
1641
|
+
{
|
|
1642
|
+
nth0 = 2;
|
|
1643
|
+
nth1 = 32;
|
|
1644
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
|
|
1645
|
+
} break;
|
|
1646
|
+
case WSP_GGML_TYPE_Q3_K:
|
|
1647
|
+
{
|
|
1648
|
+
nth0 = 2;
|
|
1649
|
+
nth1 = 32;
|
|
1650
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
|
|
1651
|
+
} break;
|
|
1652
|
+
case WSP_GGML_TYPE_Q4_K:
|
|
1653
|
+
{
|
|
1654
|
+
nth0 = 4; //1;
|
|
1655
|
+
nth1 = 8; //32;
|
|
1656
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
|
|
1657
|
+
} break;
|
|
1658
|
+
case WSP_GGML_TYPE_Q5_K:
|
|
1659
|
+
{
|
|
1660
|
+
nth0 = 2;
|
|
1661
|
+
nth1 = 32;
|
|
1662
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
|
|
1663
|
+
} break;
|
|
1664
|
+
case WSP_GGML_TYPE_Q6_K:
|
|
1665
|
+
{
|
|
1666
|
+
nth0 = 2;
|
|
1667
|
+
nth1 = 32;
|
|
1668
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
|
|
1669
|
+
} break;
|
|
1670
|
+
case WSP_GGML_TYPE_IQ2_XXS:
|
|
1671
|
+
{
|
|
1672
|
+
nth0 = 4;
|
|
1673
|
+
nth1 = 16;
|
|
1674
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
|
|
2288
1675
|
} break;
|
|
2289
|
-
|
|
1676
|
+
case WSP_GGML_TYPE_IQ2_XS:
|
|
1677
|
+
{
|
|
1678
|
+
nth0 = 4;
|
|
1679
|
+
nth1 = 16;
|
|
1680
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
|
|
1681
|
+
} break;
|
|
1682
|
+
default:
|
|
1683
|
+
{
|
|
1684
|
+
WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
|
1685
|
+
WSP_GGML_ASSERT(false && "not implemented");
|
|
1686
|
+
}
|
|
1687
|
+
};
|
|
1688
|
+
|
|
1689
|
+
if (wsp_ggml_is_quantized(src2t)) {
|
|
1690
|
+
WSP_GGML_ASSERT(ne20 >= nth0*nth1);
|
|
2290
1691
|
}
|
|
2291
1692
|
|
|
2292
|
-
|
|
2293
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2294
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
2295
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
2296
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
2297
|
-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
|
2298
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
|
2299
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
|
2300
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
|
2301
|
-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
|
2302
|
-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
|
2303
|
-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
|
2304
|
-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
|
2305
|
-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
|
2306
|
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
|
2307
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
|
2308
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
|
2309
|
-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
|
1693
|
+
const int64_t _ne1 = 1; // kernels needs a reference in constant memory
|
|
2310
1694
|
|
|
2311
|
-
[encoder
|
|
2312
|
-
|
|
2313
|
-
|
|
2314
|
-
|
|
2315
|
-
|
|
2316
|
-
|
|
1695
|
+
[encoder setComputePipelineState:pipeline];
|
|
1696
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1697
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1698
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1699
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
|
1700
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
|
1701
|
+
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
|
1702
|
+
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
|
|
1703
|
+
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
|
|
1704
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
|
|
1705
|
+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
|
|
1706
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
|
1707
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
|
|
1708
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
|
1709
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
|
1710
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
|
1711
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
|
1712
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
|
1713
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
|
1714
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
|
|
1715
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
|
|
1716
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
|
|
1717
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
|
|
1718
|
+
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
|
|
1719
|
+
// TODO: how to make this an array? read Metal docs
|
|
1720
|
+
for (int j = 0; j < 8; ++j) {
|
|
1721
|
+
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
|
1722
|
+
struct wsp_ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
|
1723
|
+
|
|
1724
|
+
size_t offs_src_cur = 0;
|
|
1725
|
+
id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
|
1726
|
+
|
|
1727
|
+
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
|
1728
|
+
}
|
|
1729
|
+
|
|
1730
|
+
if (src2t == WSP_GGML_TYPE_Q4_0 || src2t == WSP_GGML_TYPE_Q4_1 ||
|
|
1731
|
+
src2t == WSP_GGML_TYPE_Q5_0 || src2t == WSP_GGML_TYPE_Q5_1 || src2t == WSP_GGML_TYPE_Q8_0 ||
|
|
1732
|
+
src2t == WSP_GGML_TYPE_Q2_K) { // || src2t == WSP_GGML_TYPE_Q4_K) {
|
|
1733
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1734
|
+
}
|
|
1735
|
+
else if (src2t == WSP_GGML_TYPE_IQ2_XXS || src2t == WSP_GGML_TYPE_IQ2_XS) {
|
|
1736
|
+
const int mem_size = src2t == WSP_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
|
1737
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
1738
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1739
|
+
}
|
|
1740
|
+
else if (src2t == WSP_GGML_TYPE_Q4_K) {
|
|
1741
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1742
|
+
}
|
|
1743
|
+
else if (src2t == WSP_GGML_TYPE_Q3_K) {
|
|
1744
|
+
#ifdef WSP_GGML_QKK_64
|
|
1745
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1746
|
+
#else
|
|
1747
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1748
|
+
#endif
|
|
1749
|
+
}
|
|
1750
|
+
else if (src2t == WSP_GGML_TYPE_Q5_K) {
|
|
1751
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1752
|
+
}
|
|
1753
|
+
else if (src2t == WSP_GGML_TYPE_Q6_K) {
|
|
1754
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1755
|
+
} else {
|
|
1756
|
+
const int64_t ny = (_ne1 + nrows - 1)/nrows;
|
|
1757
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1758
|
+
}
|
|
1759
|
+
}
|
|
1760
|
+
} break;
|
|
1761
|
+
case WSP_GGML_OP_GET_ROWS:
|
|
1762
|
+
{
|
|
1763
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
1764
|
+
|
|
1765
|
+
switch (src0->type) {
|
|
1766
|
+
case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
|
|
1767
|
+
case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
|
|
1768
|
+
case WSP_GGML_TYPE_Q4_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
|
|
1769
|
+
case WSP_GGML_TYPE_Q4_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
|
|
1770
|
+
case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
|
|
1771
|
+
case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
|
|
1772
|
+
case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
|
|
1773
|
+
case WSP_GGML_TYPE_Q2_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
|
|
1774
|
+
case WSP_GGML_TYPE_Q3_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
|
|
1775
|
+
case WSP_GGML_TYPE_Q4_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
|
|
1776
|
+
case WSP_GGML_TYPE_Q5_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
|
|
1777
|
+
case WSP_GGML_TYPE_Q6_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
|
|
1778
|
+
case WSP_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
|
|
1779
|
+
case WSP_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
|
|
1780
|
+
case WSP_GGML_TYPE_I32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
|
1781
|
+
default: WSP_GGML_ASSERT(false && "not implemented");
|
|
1782
|
+
}
|
|
1783
|
+
|
|
1784
|
+
[encoder setComputePipelineState:pipeline];
|
|
1785
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1786
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1787
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1788
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
|
1789
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
|
1790
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
|
1791
|
+
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
|
1792
|
+
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
|
|
1793
|
+
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
|
|
1794
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
|
|
1795
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
|
|
1796
|
+
|
|
1797
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
|
1798
|
+
} break;
|
|
1799
|
+
case WSP_GGML_OP_RMS_NORM:
|
|
1800
|
+
{
|
|
1801
|
+
WSP_GGML_ASSERT(ne00 % 4 == 0);
|
|
1802
|
+
|
|
1803
|
+
float eps;
|
|
1804
|
+
memcpy(&eps, dst->op_params, sizeof(float));
|
|
1805
|
+
|
|
1806
|
+
int nth = 32; // SIMD width
|
|
1807
|
+
|
|
1808
|
+
while (nth < ne00/4 && nth < 1024) {
|
|
1809
|
+
nth *= 2;
|
|
2317
1810
|
}
|
|
2318
|
-
}
|
|
2319
|
-
}
|
|
2320
1811
|
|
|
2321
|
-
|
|
2322
|
-
|
|
2323
|
-
|
|
1812
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
|
|
1813
|
+
|
|
1814
|
+
[encoder setComputePipelineState:pipeline];
|
|
1815
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1816
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1817
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1818
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
|
1819
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
|
1820
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
1821
|
+
|
|
1822
|
+
const int64_t nrows = wsp_ggml_nrows(src0);
|
|
1823
|
+
|
|
1824
|
+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1825
|
+
} break;
|
|
1826
|
+
case WSP_GGML_OP_GROUP_NORM:
|
|
1827
|
+
{
|
|
1828
|
+
WSP_GGML_ASSERT(ne00 % 4 == 0);
|
|
1829
|
+
|
|
1830
|
+
//float eps;
|
|
1831
|
+
//memcpy(&eps, dst->op_params, sizeof(float));
|
|
1832
|
+
|
|
1833
|
+
const float eps = 1e-6f; // TODO: temporarily hardcoded
|
|
1834
|
+
|
|
1835
|
+
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
|
|
1836
|
+
|
|
1837
|
+
int nth = 32; // SIMD width
|
|
1838
|
+
|
|
1839
|
+
//while (nth < ne00/4 && nth < 1024) {
|
|
1840
|
+
// nth *= 2;
|
|
1841
|
+
//}
|
|
1842
|
+
|
|
1843
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
|
|
1844
|
+
|
|
1845
|
+
[encoder setComputePipelineState:pipeline];
|
|
1846
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1847
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1848
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1849
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
1850
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
1851
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
|
|
1852
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
|
|
1853
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
|
|
1854
|
+
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
|
|
1855
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
|
|
1856
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
1857
|
+
|
|
1858
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1859
|
+
} break;
|
|
1860
|
+
case WSP_GGML_OP_NORM:
|
|
1861
|
+
{
|
|
1862
|
+
float eps;
|
|
1863
|
+
memcpy(&eps, dst->op_params, sizeof(float));
|
|
1864
|
+
|
|
1865
|
+
const int nth = MIN(256, ne00);
|
|
1866
|
+
|
|
1867
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_NORM].pipeline;
|
|
1868
|
+
|
|
1869
|
+
[encoder setComputePipelineState:pipeline];
|
|
1870
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1871
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1872
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1873
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
|
1874
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
|
1875
|
+
[encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth*sizeof(float), 16) atIndex:0];
|
|
1876
|
+
|
|
1877
|
+
const int64_t nrows = wsp_ggml_nrows(src0);
|
|
1878
|
+
|
|
1879
|
+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1880
|
+
} break;
|
|
1881
|
+
case WSP_GGML_OP_ALIBI:
|
|
1882
|
+
{
|
|
1883
|
+
WSP_GGML_ASSERT((src0t == WSP_GGML_TYPE_F32));
|
|
1884
|
+
|
|
1885
|
+
const int nth = MIN(1024, ne00);
|
|
1886
|
+
|
|
1887
|
+
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
1888
|
+
const int n_head = ((int32_t *) dst->op_params)[1];
|
|
1889
|
+
float max_bias;
|
|
1890
|
+
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
|
1891
|
+
|
|
1892
|
+
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
|
1893
|
+
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
|
1894
|
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
|
1895
|
+
|
|
1896
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
|
|
1897
|
+
|
|
1898
|
+
[encoder setComputePipelineState:pipeline];
|
|
1899
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1900
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1901
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1902
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
1903
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
1904
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
|
1905
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
|
1906
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
|
1907
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
|
1908
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
|
1909
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
|
1910
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
|
1911
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
|
1912
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
|
1913
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
|
1914
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
|
1915
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
|
1916
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
|
1917
|
+
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
|
1918
|
+
[encoder setBytes:&m1 length:sizeof( float) atIndex:19];
|
|
1919
|
+
[encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
|
|
1920
|
+
|
|
1921
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1922
|
+
} break;
|
|
1923
|
+
case WSP_GGML_OP_ROPE:
|
|
1924
|
+
{
|
|
1925
|
+
WSP_GGML_ASSERT(ne10 == ne02);
|
|
1926
|
+
|
|
1927
|
+
const int nth = MIN(1024, ne00);
|
|
1928
|
+
|
|
1929
|
+
const int n_past = ((int32_t *) dst->op_params)[0];
|
|
1930
|
+
const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
1931
|
+
const int mode = ((int32_t *) dst->op_params)[2];
|
|
1932
|
+
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
|
1933
|
+
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
|
1934
|
+
|
|
1935
|
+
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
|
1936
|
+
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
|
1937
|
+
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
|
1938
|
+
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
|
1939
|
+
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
|
1940
|
+
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
|
1941
|
+
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
|
1942
|
+
|
|
1943
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
1944
|
+
|
|
1945
|
+
switch (src0->type) {
|
|
1946
|
+
case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
|
|
1947
|
+
case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
|
|
1948
|
+
default: WSP_GGML_ASSERT(false);
|
|
1949
|
+
};
|
|
1950
|
+
|
|
1951
|
+
[encoder setComputePipelineState:pipeline];
|
|
1952
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1953
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1954
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1955
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
|
1956
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
|
|
1957
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
|
|
1958
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
|
|
1959
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
|
|
1960
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
|
1961
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
|
1962
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
|
1963
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
|
|
1964
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
|
|
1965
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
|
|
1966
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
|
|
1967
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
|
|
1968
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
|
|
1969
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
|
|
1970
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
|
|
1971
|
+
[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
|
|
1972
|
+
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
|
|
1973
|
+
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
|
|
1974
|
+
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
|
|
1975
|
+
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
|
|
1976
|
+
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
|
|
1977
|
+
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
|
|
1978
|
+
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
|
|
1979
|
+
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
|
|
1980
|
+
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
|
|
1981
|
+
|
|
1982
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1983
|
+
} break;
|
|
1984
|
+
case WSP_GGML_OP_IM2COL:
|
|
1985
|
+
{
|
|
1986
|
+
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
|
|
1987
|
+
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
|
|
1988
|
+
WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
|
|
1989
|
+
|
|
1990
|
+
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
|
1991
|
+
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
|
1992
|
+
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
|
|
1993
|
+
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
|
1994
|
+
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
|
1995
|
+
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
|
1996
|
+
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
|
1997
|
+
|
|
1998
|
+
const int32_t N = src1->ne[is_2D ? 3 : 2];
|
|
1999
|
+
const int32_t IC = src1->ne[is_2D ? 2 : 1];
|
|
2000
|
+
const int32_t IH = is_2D ? src1->ne[1] : 1;
|
|
2001
|
+
const int32_t IW = src1->ne[0];
|
|
2002
|
+
|
|
2003
|
+
const int32_t KH = is_2D ? src0->ne[1] : 1;
|
|
2004
|
+
const int32_t KW = src0->ne[0];
|
|
2005
|
+
|
|
2006
|
+
const int32_t OH = is_2D ? dst->ne[2] : 1;
|
|
2007
|
+
const int32_t OW = dst->ne[1];
|
|
2008
|
+
|
|
2009
|
+
const int32_t CHW = IC * KH * KW;
|
|
2010
|
+
|
|
2011
|
+
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
|
2012
|
+
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
|
2013
|
+
|
|
2014
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2015
|
+
|
|
2016
|
+
switch (src0->type) {
|
|
2017
|
+
case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(false && "not implemented"); break;
|
|
2018
|
+
case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
|
|
2019
|
+
default: WSP_GGML_ASSERT(false);
|
|
2020
|
+
};
|
|
2021
|
+
|
|
2022
|
+
[encoder setComputePipelineState:pipeline];
|
|
2023
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
|
2024
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2025
|
+
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
|
|
2026
|
+
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
|
|
2027
|
+
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
|
|
2028
|
+
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
|
|
2029
|
+
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
|
|
2030
|
+
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
|
|
2031
|
+
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
|
|
2032
|
+
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
|
|
2033
|
+
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
|
|
2034
|
+
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
|
|
2035
|
+
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
|
|
2036
|
+
|
|
2037
|
+
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
|
2038
|
+
} break;
|
|
2039
|
+
case WSP_GGML_OP_UPSCALE:
|
|
2040
|
+
{
|
|
2041
|
+
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
|
|
2042
|
+
|
|
2043
|
+
const int sf = dst->op_params[0];
|
|
2044
|
+
|
|
2045
|
+
const id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
|
|
2046
|
+
|
|
2047
|
+
[encoder setComputePipelineState:pipeline];
|
|
2048
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2049
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2050
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
2051
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
2052
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
2053
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
|
2054
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
2055
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
2056
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
2057
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
|
2058
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
|
2059
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
|
2060
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
|
2061
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
|
2062
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
|
2063
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
|
2064
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
|
2065
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
|
2066
|
+
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
|
|
2067
|
+
|
|
2068
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
|
2069
|
+
|
|
2070
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2071
|
+
} break;
|
|
2072
|
+
case WSP_GGML_OP_PAD:
|
|
2073
|
+
{
|
|
2074
|
+
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
|
|
2075
|
+
|
|
2076
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
|
|
2077
|
+
|
|
2078
|
+
[encoder setComputePipelineState:pipeline];
|
|
2079
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2080
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2081
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
2082
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
2083
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
2084
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
|
2085
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
2086
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
2087
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
2088
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
|
2089
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
|
2090
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
|
2091
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
|
2092
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
|
2093
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
|
2094
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
|
2095
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
|
2096
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
|
2097
|
+
|
|
2098
|
+
const int nth = MIN(1024, ne0);
|
|
2099
|
+
|
|
2100
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2101
|
+
} break;
|
|
2102
|
+
case WSP_GGML_OP_ARGSORT:
|
|
2103
|
+
{
|
|
2104
|
+
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
|
|
2105
|
+
WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_I32);
|
|
2106
|
+
|
|
2107
|
+
const int nrows = wsp_ggml_nrows(src0);
|
|
2108
|
+
|
|
2109
|
+
enum wsp_ggml_sort_order order = (enum wsp_ggml_sort_order) dst->op_params[0];
|
|
2110
|
+
|
|
2111
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2112
|
+
|
|
2113
|
+
switch (order) {
|
|
2114
|
+
case WSP_GGML_SORT_ASC: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
|
|
2115
|
+
case WSP_GGML_SORT_DESC: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
|
|
2116
|
+
default: WSP_GGML_ASSERT(false);
|
|
2117
|
+
};
|
|
2118
|
+
|
|
2119
|
+
[encoder setComputePipelineState:pipeline];
|
|
2120
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2121
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2122
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
2123
|
+
|
|
2124
|
+
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
|
|
2125
|
+
} break;
|
|
2126
|
+
case WSP_GGML_OP_LEAKY_RELU:
|
|
2127
|
+
{
|
|
2128
|
+
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
|
|
2129
|
+
|
|
2130
|
+
float slope;
|
|
2131
|
+
memcpy(&slope, dst->op_params, sizeof(float));
|
|
2132
|
+
|
|
2133
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
|
|
2134
|
+
|
|
2135
|
+
[encoder setComputePipelineState:pipeline];
|
|
2136
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2137
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2138
|
+
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
|
|
2139
|
+
|
|
2140
|
+
const int64_t n = wsp_ggml_nelements(dst);
|
|
2141
|
+
|
|
2142
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2143
|
+
} break;
|
|
2144
|
+
case WSP_GGML_OP_DUP:
|
|
2145
|
+
case WSP_GGML_OP_CPY:
|
|
2146
|
+
case WSP_GGML_OP_CONT:
|
|
2147
|
+
{
|
|
2148
|
+
WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0);
|
|
2149
|
+
|
|
2150
|
+
int nth = MIN(1024, ne00/wsp_ggml_blck_size(src0->type));
|
|
2151
|
+
|
|
2152
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2153
|
+
|
|
2154
|
+
switch (src0t) {
|
|
2155
|
+
case WSP_GGML_TYPE_F32:
|
|
2156
|
+
{
|
|
2157
|
+
WSP_GGML_ASSERT(ne0 % wsp_ggml_blck_size(dst->type) == 0);
|
|
2158
|
+
|
|
2159
|
+
switch (dstt) {
|
|
2160
|
+
case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
|
|
2161
|
+
case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
|
|
2162
|
+
case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
|
|
2163
|
+
case WSP_GGML_TYPE_Q4_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
|
|
2164
|
+
case WSP_GGML_TYPE_Q4_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
|
|
2165
|
+
//case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
|
|
2166
|
+
//case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
|
|
2167
|
+
default: WSP_GGML_ASSERT(false && "not implemented");
|
|
2168
|
+
};
|
|
2169
|
+
} break;
|
|
2170
|
+
case WSP_GGML_TYPE_F16:
|
|
2171
|
+
{
|
|
2172
|
+
switch (dstt) {
|
|
2173
|
+
case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
|
|
2174
|
+
case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
|
|
2175
|
+
default: WSP_GGML_ASSERT(false && "not implemented");
|
|
2176
|
+
};
|
|
2177
|
+
} break;
|
|
2178
|
+
default: WSP_GGML_ASSERT(false && "not implemented");
|
|
2179
|
+
}
|
|
2180
|
+
|
|
2181
|
+
[encoder setComputePipelineState:pipeline];
|
|
2182
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2183
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2184
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
2185
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
2186
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
2187
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
|
2188
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
|
2189
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
|
2190
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
|
2191
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
|
2192
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
|
2193
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
|
2194
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
|
2195
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
|
2196
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
|
2197
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
|
2198
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
|
2199
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
|
2200
|
+
|
|
2201
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2202
|
+
} break;
|
|
2203
|
+
default:
|
|
2204
|
+
{
|
|
2205
|
+
WSP_GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
|
|
2206
|
+
WSP_GGML_ASSERT(false);
|
|
2207
|
+
}
|
|
2324
2208
|
}
|
|
2325
2209
|
|
|
2326
|
-
|
|
2327
|
-
|
|
2328
|
-
|
|
2210
|
+
#ifndef WSP_GGML_METAL_NDEBUG
|
|
2211
|
+
[encoder popDebugGroup];
|
|
2212
|
+
#endif
|
|
2213
|
+
}
|
|
2214
|
+
|
|
2215
|
+
[encoder endEncoding];
|
|
2329
2216
|
|
|
2330
|
-
|
|
2331
|
-
|
|
2217
|
+
[command_buffer commit];
|
|
2218
|
+
});
|
|
2332
2219
|
|
|
2333
|
-
// check status of command
|
|
2220
|
+
// Wait for completion and check status of each command buffer
|
|
2334
2221
|
// needed to detect if the device ran out-of-memory for example (#1881)
|
|
2335
|
-
for (int i = 0; i < n_cb; i++) {
|
|
2336
|
-
[ctx->command_buffers[i] waitUntilCompleted];
|
|
2337
2222
|
|
|
2338
|
-
|
|
2223
|
+
for (int i = 0; i < n_cb; ++i) {
|
|
2224
|
+
id<MTLCommandBuffer> command_buffer = command_buffers[i];
|
|
2225
|
+
[command_buffer waitUntilCompleted];
|
|
2226
|
+
|
|
2227
|
+
MTLCommandBufferStatus status = [command_buffer status];
|
|
2339
2228
|
if (status != MTLCommandBufferStatusCompleted) {
|
|
2340
2229
|
WSP_GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
|
2341
|
-
|
|
2230
|
+
return false;
|
|
2342
2231
|
}
|
|
2343
2232
|
}
|
|
2344
2233
|
|
|
2345
|
-
|
|
2234
|
+
return true;
|
|
2346
2235
|
}
|
|
2347
2236
|
|
|
2348
2237
|
////////////////////////////////////////////////////////////////////////////////
|
|
2349
2238
|
|
|
2350
2239
|
// backend interface
|
|
2351
2240
|
|
|
2241
|
+
// default buffer
|
|
2352
2242
|
static id<MTLDevice> g_backend_device = nil;
|
|
2353
2243
|
static int g_backend_device_ref_count = 0;
|
|
2354
2244
|
|
|
@@ -2372,64 +2262,98 @@ static void wsp_ggml_backend_metal_free_device(void) {
|
|
|
2372
2262
|
}
|
|
2373
2263
|
}
|
|
2374
2264
|
|
|
2375
|
-
static
|
|
2376
|
-
|
|
2265
|
+
// Buffer interface: name of this buffer implementation.
// Always yields the literal "Metal"; the buffer handle is not inspected.
WSP_GGML_CALL static const char * wsp_ggml_backend_metal_buffer_get_name(wsp_ggml_backend_buffer_t buffer) {
    UNUSED(buffer);

    return "Metal";
}
|
|
2380
2270
|
|
|
2381
|
-
static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
|
|
2271
|
+
// Buffer interface: tear down a Metal backend buffer.
//
// Order of operations:
//   1. wsp_ggml_backend_metal_free_device() is called (presumably dropping the
//      device reference taken when the buffer was created - confirm against
//      wsp_ggml_backend_metal_get_device).
//   2. The host allocation is freed, but only when this context owns it.
//   3. The context struct itself is freed.
WSP_GGML_CALL static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
    struct wsp_ggml_backend_metal_buffer_context * buf_ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context;

    wsp_ggml_backend_metal_free_device();

    // free(NULL) is a no-op, so memory we do not own is simply left alone
    void * owned_mem = buf_ctx->owned ? buf_ctx->all_data : NULL;
    free(owned_mem);

    free(buf_ctx);
}
|
|
2391
2282
|
|
|
2392
|
-
static void
|
|
2393
|
-
|
|
2394
|
-
|
|
2283
|
+
// Buffer interface: base address of the buffer's backing memory.
// Returns the all_data pointer stored in the buffer's Metal context.
WSP_GGML_CALL static void * wsp_ggml_backend_metal_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
    return ((struct wsp_ggml_backend_metal_buffer_context *)buffer->context)->all_data;
}
|
|
2395
2288
|
|
|
2289
|
+
// Buffer interface: write `size` bytes from `data` into the tensor's storage,
// starting `offset` bytes past tensor->data. A plain memcpy is used; no bounds
// checking is performed here.
WSP_GGML_CALL static void wsp_ggml_backend_metal_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    UNUSED(buffer);

    char * dst_bytes = (char *)tensor->data;
    memcpy(dst_bytes + offset, data, size);
}
|
|
2400
2294
|
|
|
2401
|
-
static void wsp_ggml_backend_metal_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
|
2402
|
-
WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
|
|
2403
|
-
WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
|
|
2404
|
-
|
|
2295
|
+
// Buffer interface: read `size` bytes of the tensor's storage into `data`,
// starting `offset` bytes past tensor->data. A plain memcpy is used; no bounds
// checking is performed here.
WSP_GGML_CALL static void wsp_ggml_backend_metal_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    UNUSED(buffer);

    const char * src_bytes = (const char *)tensor->data;
    memcpy(data, src_bytes + offset, size);
}
|
|
2409
2300
|
|
|
2410
|
-
static
|
|
2411
|
-
|
|
2301
|
+
// Buffer interface: copy the contents of `src` into `dst`.
//
// Only handles the case where the source tensor lives in a host buffer, in
// which case wsp_ggml_nbytes(src) bytes are memcpy'd and true is returned.
// For any other source buffer nothing is copied and false is returned.
WSP_GGML_CALL static bool wsp_ggml_backend_metal_buffer_cpy_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
    UNUSED(buffer);

    const bool src_on_host = wsp_ggml_backend_buffer_is_host(src->buffer);
    if (!src_on_host) {
        return false;
    }

    memcpy(dst->data, src->data, wsp_ggml_nbytes(src));
    return true;
}
|
|
2415
2310
|
|
|
2416
|
-
static void
|
|
2417
|
-
|
|
2311
|
+
// Buffer interface: fill the whole backing region (all_size bytes starting at
// all_data) with the byte `value`.
WSP_GGML_CALL static void wsp_ggml_backend_metal_buffer_clear(wsp_ggml_backend_buffer_t buffer, uint8_t value) {
    struct wsp_ggml_backend_metal_buffer_context * buf_ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context;

    void * const region       = buf_ctx->all_data;
    const size_t region_bytes = buf_ctx->all_size;

    memset(region, value, region_bytes);
}
|
|
2421
2316
|
|
|
2422
|
-
static struct wsp_ggml_backend_buffer_i
|
|
2317
|
+
// Function table for Metal backend buffers.
// Entries are positional, so their order must match the slot order of
// struct wsp_ggml_backend_buffer_i; the trailing comment names each slot.
// Slots set to NULL are not provided by this backend.
static struct wsp_ggml_backend_buffer_i wsp_ggml_backend_metal_buffer_i = {
    wsp_ggml_backend_metal_buffer_get_name,    // .get_name
    wsp_ggml_backend_metal_buffer_free_buffer, // .free_buffer
    wsp_ggml_backend_metal_buffer_get_base,    // .get_base
    NULL,                                      // .init_tensor
    wsp_ggml_backend_metal_buffer_set_tensor,  // .set_tensor
    wsp_ggml_backend_metal_buffer_get_tensor,  // .get_tensor
    wsp_ggml_backend_metal_buffer_cpy_tensor,  // .cpy_tensor
    wsp_ggml_backend_metal_buffer_clear,       // .clear
    NULL,                                      // .reset
};
|
|
2431
2328
|
|
|
2432
|
-
|
|
2329
|
+
// default buffer type
|
|
2330
|
+
|
|
2331
|
+
// Buffer-type interface: name of the default Metal buffer type.
// Always yields the literal "Metal"; the buffer-type handle is not inspected.
WSP_GGML_CALL static const char * wsp_ggml_backend_metal_buffer_type_get_name(wsp_ggml_backend_buffer_type_t buft) {
    UNUSED(buft);

    return "Metal";
}
|
|
2336
|
+
|
|
2337
|
+
// Appends the device's memory statistics (in MiB) to a log line that the
// caller has started but not yet terminated with a newline.
//
// When the OS reports a recommended working-set size, both the current
// allocation and that limit are printed, followed by either a warning (if the
// allocation exceeds the limit) or a bare newline. Otherwise only the current
// allocation is printed.
//
// On targets that do not satisfy the preprocessor condition below, nothing is
// logged at all.
static void wsp_ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
    if (@available(macOS 10.12, iOS 16.0, *)) {
        WSP_GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
                device.currentAllocatedSize / 1024.0 / 1024.0,
                device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);

        const bool within_limit = device.currentAllocatedSize <= device.recommendedMaxWorkingSetSize;
        if (within_limit) {
            WSP_GGML_METAL_LOG_INFO("\n");
        } else {
            WSP_GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
        }
    } else {
        WSP_GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
    }
#endif
    UNUSED(device);
}
|
|
2355
|
+
|
|
2356
|
+
WSP_GGML_CALL static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
|
|
2433
2357
|
struct wsp_ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct wsp_ggml_backend_metal_buffer_context));
|
|
2434
2358
|
|
|
2435
2359
|
const size_t size_page = sysconf(_SC_PAGESIZE);
|
|
@@ -2439,33 +2363,59 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer
|
|
|
2439
2363
|
size_aligned += (size_page - (size_aligned % size_page));
|
|
2440
2364
|
}
|
|
2441
2365
|
|
|
2442
|
-
|
|
2443
|
-
|
|
2366
|
+
id<MTLDevice> device = wsp_ggml_backend_metal_get_device();
|
|
2367
|
+
|
|
2368
|
+
ctx->all_data = wsp_ggml_metal_host_malloc(size_aligned);
|
|
2369
|
+
ctx->all_size = size_aligned;
|
|
2370
|
+
ctx->owned = true;
|
|
2371
|
+
ctx->n_buffers = 1;
|
|
2372
|
+
|
|
2373
|
+
ctx->buffers[0].data = ctx->all_data;
|
|
2374
|
+
ctx->buffers[0].size = size;
|
|
2375
|
+
ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
|
|
2444
2376
|
length:size_aligned
|
|
2445
2377
|
options:MTLResourceStorageModeShared
|
|
2446
2378
|
deallocator:nil];
|
|
2447
2379
|
|
|
2448
|
-
|
|
2380
|
+
if (ctx->buffers[0].metal == nil) {
|
|
2381
|
+
WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
|
2382
|
+
free(ctx);
|
|
2383
|
+
wsp_ggml_backend_metal_free_device();
|
|
2384
|
+
return NULL;
|
|
2385
|
+
}
|
|
2386
|
+
|
|
2387
|
+
WSP_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
|
|
2388
|
+
wsp_ggml_backend_metal_log_allocated_size(device);
|
|
2389
|
+
|
|
2390
|
+
return wsp_ggml_backend_buffer_init(buft, wsp_ggml_backend_metal_buffer_i, ctx, size);
|
|
2449
2391
|
}
|
|
2450
2392
|
|
|
2451
|
-
static size_t wsp_ggml_backend_metal_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
|
|
2393
|
+
// Buffer-type interface: alignment required for tensors in Metal buffers.
// Returns the fixed value 32 (presumably bytes - confirm against the
// allocator that consumes it).
WSP_GGML_CALL static size_t wsp_ggml_backend_metal_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
    UNUSED(buft);

    const size_t alignment = 32;
    return alignment;
}
|
|
2455
2397
|
|
|
2456
|
-
static bool wsp_ggml_backend_metal_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) {
|
|
2398
|
+
// Buffer-type interface: can `backend` work with buffers of this type?
// True for the Metal backend and for the CPU backend, false for anything else.
WSP_GGML_CALL static bool wsp_ggml_backend_metal_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) {
    UNUSED(buft);

    if (wsp_ggml_backend_is_metal(backend)) {
        return true;
    }

    return wsp_ggml_backend_is_cpu(backend);
}
|
|
2403
|
+
|
|
2404
|
+
// Buffer-type interface: reports whether buffers of this type are host
// buffers. Unconditionally true for the Metal buffer type.
WSP_GGML_CALL static bool wsp_ggml_backend_metal_buffer_type_is_host(wsp_ggml_backend_buffer_type_t buft) {
    UNUSED(buft);

    return true;
}
|
|
2461
2409
|
|
|
2462
|
-
wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void) {
|
|
2410
|
+
WSP_GGML_CALL wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void) {
|
|
2463
2411
|
static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_buffer_type_metal = {
|
|
2464
2412
|
/* .iface = */ {
|
|
2413
|
+
/* .get_name = */ wsp_ggml_backend_metal_buffer_type_get_name,
|
|
2465
2414
|
/* .alloc_buffer = */ wsp_ggml_backend_metal_buffer_type_alloc_buffer,
|
|
2466
2415
|
/* .get_alignment = */ wsp_ggml_backend_metal_buffer_type_get_alignment,
|
|
2467
2416
|
/* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes
|
|
2468
2417
|
/* .supports_backend = */ wsp_ggml_backend_metal_buffer_type_supports_backend,
|
|
2418
|
+
/* .is_host = */ wsp_ggml_backend_metal_buffer_type_is_host,
|
|
2469
2419
|
},
|
|
2470
2420
|
/* .context = */ NULL,
|
|
2471
2421
|
};
|
|
@@ -2473,67 +2423,134 @@ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void) {
|
|
|
2473
2423
|
return &wsp_ggml_backend_buffer_type_metal;
|
|
2474
2424
|
}
|
|
2475
2425
|
|
|
2476
|
-
|
|
2426
|
+
// buffer from ptr
|
|
2427
|
+
|
|
2428
|
+
// Wraps an existing host memory region in one or more no-copy Metal buffers.
//
//   data     - start of the caller's memory region; it is recorded with
//              owned = false, so this backend never frees it
//   size     - size of the region in bytes
//   max_size - size of the largest tensor that will live in the region; used
//              to size the overlap between views when the region has to be
//              split across several Metal buffers
//
// The region is widened to page boundaries before being handed to Metal. If
// the aligned size fits in device.maxBufferLength a single buffer is created,
// otherwise a sequence of overlapping views is created.
//
// Returns the new backend buffer, or NULL if Metal fails to create a view.
// On failure the context is freed and the device reference is released, the
// same way the alloc_buffer failure path does it.
WSP_GGML_CALL wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
    struct wsp_ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct wsp_ggml_backend_metal_buffer_context));

    ctx->all_data = data;
    ctx->all_size = size;
    ctx->owned = false;
    ctx->n_buffers = 0;

    const size_t size_page = sysconf(_SC_PAGESIZE);

    // page-align the data ptr
    {
        const uintptr_t offs = (uintptr_t) data % size_page;
        data  = (void *) ((char *) data - offs);
        size += offs;
    }

    size_t size_aligned = size;
    if ((size_aligned % size_page) != 0) {
        size_aligned += (size_page - (size_aligned % size_page));
    }

    id<MTLDevice> device = wsp_ggml_backend_metal_get_device();

    // the buffer fits into the max buffer size allowed by the device
    if (size_aligned <= device.maxBufferLength) {
        ctx->buffers[ctx->n_buffers].data = data;
        ctx->buffers[ctx->n_buffers].size = size;

        ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];

        if (ctx->buffers[ctx->n_buffers].metal == nil) {
            WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
            // this function returns a pointer: report failure as NULL and do not leak
            // the context or the device reference
            free(ctx);
            wsp_ggml_backend_metal_free_device();
            return NULL;
        }

        WSP_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);

        ++ctx->n_buffers;
    } else {
        // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
        // one of the views
        const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
        const size_t size_step = device.maxBufferLength - size_ovlp;
        const size_t size_view = device.maxBufferLength;

        for (size_t i = 0; i < size; i += size_step) {
            // every view is size_view long except the last one, which covers the remainder
            const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);

            ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
            ctx->buffers[ctx->n_buffers].size = size_step_aligned;

            ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];

            if (ctx->buffers[ctx->n_buffers].metal == nil) {
                WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
                free(ctx);
                wsp_ggml_backend_metal_free_device();
                return NULL;
            }

            // i is a size_t, so it is printed with %zu
            WSP_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12zu", __func__, size_step_aligned / 1024.0 / 1024.0, i);
            if (i + size_step < size) {
                WSP_GGML_METAL_LOG_INFO("\n");
            }

            ++ctx->n_buffers;
        }
    }

    // terminates the log line started above with the device memory statistics
    wsp_ggml_backend_metal_log_allocated_size(device);

    return wsp_ggml_backend_buffer_init(wsp_ggml_backend_metal_buffer_type(), wsp_ggml_backend_metal_buffer_i, ctx, size);
}
|
|
2500
|
+
|
|
2501
|
+
// backend
|
|
2502
|
+
|
|
2503
|
+
// Backend interface: name of this backend.
// Always yields the literal "Metal". The address of this function is also
// what wsp_ggml_backend_is_metal() compares against to recognise the backend.
WSP_GGML_CALL static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) {
    UNUSED(backend);

    return "Metal";
}
|
|
2481
2508
|
|
|
2482
|
-
static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) {
|
|
2509
|
+
// Backend interface: destroy a Metal backend.
// Releases the Metal context via wsp_ggml_metal_free() and then frees the
// backend struct itself.
WSP_GGML_CALL static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) {
    wsp_ggml_metal_free((struct wsp_ggml_metal_context *)backend->context);

    free(backend);
}
|
|
2487
2514
|
|
|
2488
|
-
static
|
|
2489
|
-
UNUSED(backend);
|
|
2490
|
-
}
|
|
2491
|
-
|
|
2492
|
-
static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_get_default_buffer_type(wsp_ggml_backend_t backend) {
|
|
2515
|
+
// Backend interface: buffer type used by default with this backend.
// Simply forwards to wsp_ggml_backend_metal_buffer_type(); the backend handle
// is not inspected.
WSP_GGML_CALL static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_get_default_buffer_type(wsp_ggml_backend_t backend) {
    UNUSED(backend);

    return wsp_ggml_backend_metal_buffer_type();
}
|
|
2497
2520
|
|
|
2498
|
-
static
|
|
2521
|
+
// Backend interface: evaluate a compute graph on the Metal backend.
// Forwards to wsp_ggml_metal_graph_compute() with this backend's Metal context
// and returns its result unchanged.
WSP_GGML_CALL static bool wsp_ggml_backend_metal_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
    return wsp_ggml_metal_graph_compute((struct wsp_ggml_metal_context *)backend->context, cgraph);
}
|
|
2503
2526
|
|
|
2504
|
-
static bool wsp_ggml_backend_metal_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
|
|
2505
|
-
|
|
2527
|
+
// Backend interface: can this backend execute the operation described by `op`?
// Forwards to wsp_ggml_metal_supports_op() with this backend's Metal context
// and returns its answer unchanged.
WSP_GGML_CALL static bool wsp_ggml_backend_metal_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
    return wsp_ggml_metal_supports_op((struct wsp_ggml_metal_context *)backend->context, op);
}
|
|
2509
2532
|
|
|
2510
|
-
static struct wsp_ggml_backend_i
|
|
2533
|
+
// Function table for the Metal backend.
// Entries are positional, so their order must match the slot order of
// struct wsp_ggml_backend_i; the trailing comment names each slot.
// Slots set to NULL (async transfers, synchronize, graph plans) are not
// provided by this backend.
static struct wsp_ggml_backend_i wsp_ggml_backend_metal_i = {
    wsp_ggml_backend_metal_name,                    // .get_name
    wsp_ggml_backend_metal_free,                    // .free
    wsp_ggml_backend_metal_get_default_buffer_type, // .get_default_buffer_type
    NULL,                                           // .set_tensor_async
    NULL,                                           // .get_tensor_async
    NULL,                                           // .cpy_tensor_async
    NULL,                                           // .synchronize
    NULL,                                           // .graph_plan_create
    NULL,                                           // .graph_plan_free
    NULL,                                           // .graph_plan_compute
    wsp_ggml_backend_metal_graph_compute,           // .graph_compute
    wsp_ggml_backend_metal_supports_op,             // .supports_op
};
|
|
2525
2547
|
|
|
2526
|
-
|
|
2527
|
-
|
|
2528
|
-
|
|
2529
|
-
|
|
2530
|
-
UNUSED(level);
|
|
2531
|
-
UNUSED(user_data);
|
|
2548
|
+
// Installs the logging callback used by the Metal backend.
//
//   log_callback - function stored in wsp_ggml_metal_log_callback
//   user_data    - opaque pointer stored in wsp_ggml_metal_log_user_data
//                  (presumably handed back to the callback on each message -
//                  confirm against the logging helper)
void wsp_ggml_backend_metal_log_set_callback(wsp_ggml_log_callback log_callback, void * user_data) {
    wsp_ggml_metal_log_user_data = user_data;
    wsp_ggml_metal_log_callback  = log_callback;
}
|
|
2533
2552
|
|
|
2534
2553
|
wsp_ggml_backend_t wsp_ggml_backend_metal_init(void) {
|
|
2535
|
-
wsp_ggml_metal_log_set_callback(wsp_ggml_backend_log_callback, NULL);
|
|
2536
|
-
|
|
2537
2554
|
struct wsp_ggml_metal_context * ctx = wsp_ggml_metal_init(WSP_GGML_DEFAULT_N_THREADS);
|
|
2538
2555
|
|
|
2539
2556
|
if (ctx == NULL) {
|
|
@@ -2543,7 +2560,7 @@ wsp_ggml_backend_t wsp_ggml_backend_metal_init(void) {
|
|
|
2543
2560
|
wsp_ggml_backend_t metal_backend = malloc(sizeof(struct wsp_ggml_backend));
|
|
2544
2561
|
|
|
2545
2562
|
*metal_backend = (struct wsp_ggml_backend) {
|
|
2546
|
-
/* .interface = */
|
|
2563
|
+
/* .interface = */ wsp_ggml_backend_metal_i,
|
|
2547
2564
|
/* .context = */ ctx,
|
|
2548
2565
|
};
|
|
2549
2566
|
|
|
@@ -2551,7 +2568,7 @@ wsp_ggml_backend_t wsp_ggml_backend_metal_init(void) {
|
|
|
2551
2568
|
}
|
|
2552
2569
|
|
|
2553
2570
|
// Reports whether `backend` is a Metal backend.
// A NULL handle is never a Metal backend; otherwise the backend is recognised
// by its get_name slot pointing at wsp_ggml_backend_metal_name.
bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend) {
    if (backend == NULL) {
        return false;
    }

    return backend->iface.get_name == wsp_ggml_backend_metal_name;
}
|
|
2556
2573
|
|
|
2557
2574
|
void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb) {
|
|
@@ -2559,7 +2576,7 @@ void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb) {
|
|
|
2559
2576
|
|
|
2560
2577
|
struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
|
|
2561
2578
|
|
|
2562
|
-
|
|
2579
|
+
ctx->n_cb = MIN(n_cb, WSP_GGML_METAL_MAX_BUFFERS);
|
|
2563
2580
|
}
|
|
2564
2581
|
|
|
2565
2582
|
bool wsp_ggml_backend_metal_supports_family(wsp_ggml_backend_t backend, int family) {
|
|
@@ -2570,9 +2587,9 @@ bool wsp_ggml_backend_metal_supports_family(wsp_ggml_backend_t backend, int fami
|
|
|
2570
2587
|
return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
|
2571
2588
|
}
|
|
2572
2589
|
|
|
2573
|
-
wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
|
|
2590
|
+
WSP_GGML_CALL wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
|
|
2574
2591
|
|
|
2575
|
-
wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data) {
|
|
2592
|
+
WSP_GGML_CALL wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data) {
|
|
2576
2593
|
return wsp_ggml_backend_metal_init();
|
|
2577
2594
|
|
|
2578
2595
|
WSP_GGML_UNUSED(params);
|