llama_cpp 0.12.1 → 0.12.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/ext/llama_cpp/llama_cpp.cpp +64 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +7 -0
- data/vendor/tmp/llama.cpp/Makefile +0 -9
- data/vendor/tmp/llama.cpp/ggml-alloc.c +28 -6
- data/vendor/tmp/llama.cpp/ggml-alloc.h +3 -1
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +36 -36
- data/vendor/tmp/llama.cpp/ggml-backend.c +510 -263
- data/vendor/tmp/llama.cpp/ggml-backend.h +42 -32
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +692 -476
- data/vendor/tmp/llama.cpp/ggml-cuda.h +18 -30
- data/vendor/tmp/llama.cpp/ggml-impl.h +2 -0
- data/vendor/tmp/llama.cpp/ggml-metal.h +4 -56
- data/vendor/tmp/llama.cpp/ggml-metal.m +1860 -2073
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +321 -14
- data/vendor/tmp/llama.cpp/ggml-opencl.h +13 -3
- data/vendor/tmp/llama.cpp/ggml-quants.c +1638 -134
- data/vendor/tmp/llama.cpp/ggml-quants.h +15 -4
- data/vendor/tmp/llama.cpp/ggml.c +142 -64
- data/vendor/tmp/llama.cpp/ggml.h +47 -29
- data/vendor/tmp/llama.cpp/llama.cpp +1219 -1615
- data/vendor/tmp/llama.cpp/llama.h +30 -8
- metadata +2 -2
@@ -24,7 +24,7 @@
|
|
24
24
|
|
25
25
|
#define UNUSED(x) (void)(x)
|
26
26
|
|
27
|
-
#define
|
27
|
+
#define GGML_METAL_MAX_KERNELS 256
|
28
28
|
|
29
29
|
struct ggml_metal_buffer {
|
30
30
|
const char * name;
|
@@ -35,6 +35,134 @@ struct ggml_metal_buffer {
|
|
35
35
|
id<MTLBuffer> metal;
|
36
36
|
};
|
37
37
|
|
38
|
+
struct ggml_metal_kernel {
|
39
|
+
id<MTLFunction> function;
|
40
|
+
id<MTLComputePipelineState> pipeline;
|
41
|
+
};
|
42
|
+
|
43
|
+
enum ggml_metal_kernel_type {
|
44
|
+
GGML_METAL_KERNEL_TYPE_ADD,
|
45
|
+
GGML_METAL_KERNEL_TYPE_ADD_ROW,
|
46
|
+
GGML_METAL_KERNEL_TYPE_MUL,
|
47
|
+
GGML_METAL_KERNEL_TYPE_MUL_ROW,
|
48
|
+
GGML_METAL_KERNEL_TYPE_DIV,
|
49
|
+
GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
50
|
+
GGML_METAL_KERNEL_TYPE_SCALE,
|
51
|
+
GGML_METAL_KERNEL_TYPE_SCALE_4,
|
52
|
+
GGML_METAL_KERNEL_TYPE_TANH,
|
53
|
+
GGML_METAL_KERNEL_TYPE_RELU,
|
54
|
+
GGML_METAL_KERNEL_TYPE_GELU,
|
55
|
+
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
|
56
|
+
GGML_METAL_KERNEL_TYPE_SILU,
|
57
|
+
GGML_METAL_KERNEL_TYPE_SOFT_MAX,
|
58
|
+
GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
|
59
|
+
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
|
60
|
+
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
|
61
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
|
62
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
|
63
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
|
64
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
|
65
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
|
66
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
|
67
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
|
68
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
|
69
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
|
70
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
|
71
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
|
72
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
|
73
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
|
74
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
|
75
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
76
|
+
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
77
|
+
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
78
|
+
GGML_METAL_KERNEL_TYPE_NORM,
|
79
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
80
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
81
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
82
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
83
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
|
84
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
|
85
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
|
86
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
87
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
|
88
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
|
89
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
|
90
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
|
91
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
|
92
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
|
93
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
|
94
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
|
95
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
|
96
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
97
|
+
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
98
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
99
|
+
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
|
100
|
+
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
|
101
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
|
102
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
|
103
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
|
104
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
|
105
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
|
106
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
|
107
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
|
108
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
|
109
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,
|
110
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
|
111
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
|
112
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
|
113
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
114
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
115
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
|
116
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
|
117
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
|
118
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
|
119
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
|
120
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
|
121
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
|
122
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
|
123
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,
|
124
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
|
125
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
|
126
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
|
127
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
128
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
129
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
130
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
|
131
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
|
132
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
|
133
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
|
134
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
|
135
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
|
136
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
|
137
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
|
138
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
|
139
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
|
140
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
|
141
|
+
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
142
|
+
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
143
|
+
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
144
|
+
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
145
|
+
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
146
|
+
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
147
|
+
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
148
|
+
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
|
149
|
+
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
150
|
+
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
151
|
+
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
152
|
+
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
153
|
+
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
154
|
+
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
155
|
+
//GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
|
156
|
+
//GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
|
157
|
+
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
158
|
+
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
159
|
+
GGML_METAL_KERNEL_TYPE_CONCAT,
|
160
|
+
GGML_METAL_KERNEL_TYPE_SQR,
|
161
|
+
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
162
|
+
|
163
|
+
GGML_METAL_KERNEL_TYPE_COUNT
|
164
|
+
};
|
165
|
+
|
38
166
|
struct ggml_metal_context {
|
39
167
|
int n_cb;
|
40
168
|
|
@@ -42,142 +170,15 @@ struct ggml_metal_context {
|
|
42
170
|
id<MTLCommandQueue> queue;
|
43
171
|
id<MTLLibrary> library;
|
44
172
|
|
45
|
-
id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
|
46
|
-
id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
|
47
|
-
|
48
173
|
dispatch_queue_t d_queue;
|
49
174
|
|
50
175
|
int n_buffers;
|
51
176
|
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
|
52
177
|
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
#define GGML_METAL_DECL_KERNEL(name) \
|
58
|
-
id<MTLFunction> function_##name; \
|
59
|
-
id<MTLComputePipelineState> pipeline_##name
|
60
|
-
|
61
|
-
GGML_METAL_DECL_KERNEL(add);
|
62
|
-
GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
|
63
|
-
GGML_METAL_DECL_KERNEL(mul);
|
64
|
-
GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
|
65
|
-
GGML_METAL_DECL_KERNEL(div);
|
66
|
-
GGML_METAL_DECL_KERNEL(div_row);
|
67
|
-
GGML_METAL_DECL_KERNEL(scale);
|
68
|
-
GGML_METAL_DECL_KERNEL(scale_4);
|
69
|
-
GGML_METAL_DECL_KERNEL(tanh);
|
70
|
-
GGML_METAL_DECL_KERNEL(relu);
|
71
|
-
GGML_METAL_DECL_KERNEL(gelu);
|
72
|
-
GGML_METAL_DECL_KERNEL(gelu_quick);
|
73
|
-
GGML_METAL_DECL_KERNEL(silu);
|
74
|
-
GGML_METAL_DECL_KERNEL(soft_max);
|
75
|
-
GGML_METAL_DECL_KERNEL(soft_max_4);
|
76
|
-
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
77
|
-
GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
|
78
|
-
GGML_METAL_DECL_KERNEL(get_rows_f32);
|
79
|
-
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
80
|
-
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
81
|
-
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
82
|
-
GGML_METAL_DECL_KERNEL(get_rows_q5_0);
|
83
|
-
GGML_METAL_DECL_KERNEL(get_rows_q5_1);
|
84
|
-
GGML_METAL_DECL_KERNEL(get_rows_q8_0);
|
85
|
-
GGML_METAL_DECL_KERNEL(get_rows_q2_K);
|
86
|
-
GGML_METAL_DECL_KERNEL(get_rows_q3_K);
|
87
|
-
GGML_METAL_DECL_KERNEL(get_rows_q4_K);
|
88
|
-
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
|
89
|
-
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
90
|
-
GGML_METAL_DECL_KERNEL(get_rows_i32);
|
91
|
-
GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs);
|
92
|
-
GGML_METAL_DECL_KERNEL(get_rows_iq2_xs);
|
93
|
-
GGML_METAL_DECL_KERNEL(rms_norm);
|
94
|
-
GGML_METAL_DECL_KERNEL(group_norm);
|
95
|
-
GGML_METAL_DECL_KERNEL(norm);
|
96
|
-
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
|
97
|
-
GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
|
98
|
-
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
|
99
|
-
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
|
100
|
-
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
|
101
|
-
GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
|
102
|
-
GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
|
103
|
-
GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
|
104
|
-
GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
|
105
|
-
GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
|
106
|
-
GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
|
107
|
-
GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
|
108
|
-
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
|
109
|
-
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
|
110
|
-
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
|
111
|
-
GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32);
|
112
|
-
GGML_METAL_DECL_KERNEL(mul_mv_iq2_xs_f32);
|
113
|
-
GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
|
114
|
-
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
|
115
|
-
GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
|
116
|
-
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
|
117
|
-
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
|
118
|
-
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
|
119
|
-
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
|
120
|
-
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
|
121
|
-
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
|
122
|
-
GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
|
123
|
-
GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
|
124
|
-
GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
|
125
|
-
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
|
126
|
-
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
|
127
|
-
GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
|
128
|
-
GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32);
|
129
|
-
GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xs_f32);
|
130
|
-
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
|
131
|
-
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
132
|
-
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
133
|
-
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
|
134
|
-
GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
|
135
|
-
GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
|
136
|
-
GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
|
137
|
-
GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
|
138
|
-
GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
|
139
|
-
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
|
140
|
-
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
|
141
|
-
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
|
142
|
-
GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32);
|
143
|
-
GGML_METAL_DECL_KERNEL(mul_mm_iq2_xs_f32);
|
144
|
-
GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
|
145
|
-
GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
|
146
|
-
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
|
147
|
-
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
|
148
|
-
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
|
149
|
-
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
|
150
|
-
GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
|
151
|
-
GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
|
152
|
-
GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
|
153
|
-
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
|
154
|
-
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
|
155
|
-
GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
|
156
|
-
GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32);
|
157
|
-
GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xs_f32);
|
158
|
-
GGML_METAL_DECL_KERNEL(rope_f32);
|
159
|
-
GGML_METAL_DECL_KERNEL(rope_f16);
|
160
|
-
GGML_METAL_DECL_KERNEL(alibi_f32);
|
161
|
-
GGML_METAL_DECL_KERNEL(im2col_f16);
|
162
|
-
GGML_METAL_DECL_KERNEL(upscale_f32);
|
163
|
-
GGML_METAL_DECL_KERNEL(pad_f32);
|
164
|
-
GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
|
165
|
-
GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
|
166
|
-
GGML_METAL_DECL_KERNEL(leaky_relu_f32);
|
167
|
-
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
168
|
-
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
169
|
-
GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
|
170
|
-
GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
|
171
|
-
GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
|
172
|
-
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
|
173
|
-
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
|
174
|
-
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
175
|
-
GGML_METAL_DECL_KERNEL(cpy_f16_f32);
|
176
|
-
GGML_METAL_DECL_KERNEL(concat);
|
177
|
-
GGML_METAL_DECL_KERNEL(sqr);
|
178
|
-
GGML_METAL_DECL_KERNEL(sum_rows);
|
179
|
-
|
180
|
-
#undef GGML_METAL_DECL_KERNEL
|
178
|
+
struct ggml_metal_kernel kernels[GGML_METAL_MAX_KERNELS];
|
179
|
+
|
180
|
+
bool support_simdgroup_reduction;
|
181
|
+
bool support_simdgroup_mm;
|
181
182
|
};
|
182
183
|
|
183
184
|
// MSL code
|
@@ -191,7 +192,6 @@ struct ggml_metal_context {
|
|
191
192
|
@implementation GGMLMetalClass
|
192
193
|
@end
|
193
194
|
|
194
|
-
|
195
195
|
static void ggml_metal_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
|
196
196
|
fprintf(stderr, "%s", msg);
|
197
197
|
|
@@ -202,11 +202,6 @@ static void ggml_metal_default_log_callback(enum ggml_log_level level, const cha
|
|
202
202
|
ggml_log_callback ggml_metal_log_callback = ggml_metal_default_log_callback;
|
203
203
|
void * ggml_metal_log_user_data = NULL;
|
204
204
|
|
205
|
-
void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
|
206
|
-
ggml_metal_log_callback = log_callback;
|
207
|
-
ggml_metal_log_user_data = user_data;
|
208
|
-
}
|
209
|
-
|
210
205
|
GGML_ATTRIBUTE_FORMAT(2, 3)
|
211
206
|
static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
212
207
|
if (ggml_metal_log_callback != NULL) {
|
@@ -229,7 +224,18 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
|
229
224
|
}
|
230
225
|
}
|
231
226
|
|
232
|
-
|
227
|
+
static void * ggml_metal_host_malloc(size_t n) {
|
228
|
+
void * data = NULL;
|
229
|
+
const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
|
230
|
+
if (result != 0) {
|
231
|
+
GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
|
232
|
+
return NULL;
|
233
|
+
}
|
234
|
+
|
235
|
+
return data;
|
236
|
+
}
|
237
|
+
|
238
|
+
static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
233
239
|
GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
|
234
240
|
|
235
241
|
id<MTLDevice> device;
|
@@ -255,7 +261,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
255
261
|
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
256
262
|
ctx->queue = [ctx->device newCommandQueue];
|
257
263
|
ctx->n_buffers = 0;
|
258
|
-
ctx->concur_list_len = 0;
|
259
264
|
|
260
265
|
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
261
266
|
|
@@ -298,19 +303,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
298
303
|
return NULL;
|
299
304
|
}
|
300
305
|
|
301
|
-
|
306
|
+
// dictionary of preprocessor macros
|
307
|
+
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
308
|
+
|
302
309
|
#ifdef GGML_QKK_64
|
303
|
-
|
304
|
-
options.preprocessorMacros = @{ @"QK_K" : @(64) };
|
310
|
+
prep[@"QK_K"] = @(64);
|
305
311
|
#endif
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
// and go through the "pre-compiled library found" path above
|
312
|
+
|
313
|
+
MTLCompileOptions* options = [MTLCompileOptions new];
|
314
|
+
options.preprocessorMacros = prep;
|
315
|
+
|
311
316
|
//[options setFastMathEnabled:false];
|
312
317
|
|
313
318
|
ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
|
319
|
+
|
320
|
+
[options release];
|
321
|
+
[prep release];
|
314
322
|
}
|
315
323
|
|
316
324
|
if (error) {
|
@@ -319,22 +327,51 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
319
327
|
}
|
320
328
|
}
|
321
329
|
|
322
|
-
#if TARGET_OS_OSX
|
323
330
|
// print MTL GPU family:
|
324
331
|
GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
|
325
332
|
|
333
|
+
const NSInteger MTLGPUFamilyMetal3 = 5001;
|
334
|
+
|
326
335
|
// determine max supported GPU family
|
327
336
|
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
328
337
|
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
338
|
+
{
|
339
|
+
for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
|
340
|
+
if ([ctx->device supportsFamily:i]) {
|
341
|
+
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
|
342
|
+
break;
|
343
|
+
}
|
344
|
+
}
|
345
|
+
|
346
|
+
for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
|
347
|
+
if ([ctx->device supportsFamily:i]) {
|
348
|
+
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
|
349
|
+
break;
|
350
|
+
}
|
351
|
+
}
|
352
|
+
|
353
|
+
for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
|
354
|
+
if ([ctx->device supportsFamily:i]) {
|
355
|
+
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
|
356
|
+
break;
|
357
|
+
}
|
333
358
|
}
|
334
359
|
}
|
335
360
|
|
361
|
+
ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7];
|
362
|
+
ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];
|
363
|
+
|
364
|
+
ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];
|
365
|
+
|
366
|
+
GGML_METAL_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false");
|
367
|
+
GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
|
336
368
|
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
337
|
-
|
369
|
+
|
370
|
+
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
371
|
+
if (@available(macOS 10.12, iOS 16.0, *)) {
|
372
|
+
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
|
373
|
+
}
|
374
|
+
#elif TARGET_OS_OSX
|
338
375
|
if (ctx->device.maxTransferRate != 0) {
|
339
376
|
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
|
340
377
|
} else {
|
@@ -346,279 +383,171 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
346
383
|
{
|
347
384
|
NSError * error = nil;
|
348
385
|
|
386
|
+
for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
|
387
|
+
ctx->kernels[i].function = nil;
|
388
|
+
ctx->kernels[i].pipeline = nil;
|
389
|
+
}
|
390
|
+
|
349
391
|
/*
|
350
|
-
|
351
|
-
|
352
|
-
|
392
|
+
GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
|
393
|
+
(int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
|
394
|
+
(int) kernel->pipeline.threadExecutionWidth); \
|
353
395
|
*/
|
354
|
-
#define GGML_METAL_ADD_KERNEL(name) \
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
396
|
+
#define GGML_METAL_ADD_KERNEL(e, name, supported) \
|
397
|
+
if (supported) { \
|
398
|
+
struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
|
399
|
+
kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \
|
400
|
+
kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \
|
401
|
+
if (error) { \
|
402
|
+
GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
403
|
+
return NULL; \
|
404
|
+
} \
|
405
|
+
} else { \
|
406
|
+
GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \
|
360
407
|
}
|
361
408
|
|
362
|
-
|
363
|
-
|
364
|
-
GGML_METAL_ADD_KERNEL(
|
365
|
-
GGML_METAL_ADD_KERNEL(
|
366
|
-
GGML_METAL_ADD_KERNEL(
|
367
|
-
GGML_METAL_ADD_KERNEL(
|
368
|
-
GGML_METAL_ADD_KERNEL(
|
369
|
-
GGML_METAL_ADD_KERNEL(
|
370
|
-
GGML_METAL_ADD_KERNEL(
|
371
|
-
GGML_METAL_ADD_KERNEL(
|
372
|
-
GGML_METAL_ADD_KERNEL(
|
373
|
-
GGML_METAL_ADD_KERNEL(
|
374
|
-
GGML_METAL_ADD_KERNEL(
|
375
|
-
GGML_METAL_ADD_KERNEL(
|
376
|
-
GGML_METAL_ADD_KERNEL(
|
377
|
-
GGML_METAL_ADD_KERNEL(
|
378
|
-
GGML_METAL_ADD_KERNEL(
|
379
|
-
GGML_METAL_ADD_KERNEL(
|
380
|
-
GGML_METAL_ADD_KERNEL(
|
381
|
-
GGML_METAL_ADD_KERNEL(
|
382
|
-
GGML_METAL_ADD_KERNEL(
|
383
|
-
GGML_METAL_ADD_KERNEL(
|
384
|
-
GGML_METAL_ADD_KERNEL(
|
385
|
-
GGML_METAL_ADD_KERNEL(
|
386
|
-
GGML_METAL_ADD_KERNEL(
|
387
|
-
GGML_METAL_ADD_KERNEL(
|
388
|
-
GGML_METAL_ADD_KERNEL(
|
389
|
-
GGML_METAL_ADD_KERNEL(
|
390
|
-
GGML_METAL_ADD_KERNEL(
|
391
|
-
GGML_METAL_ADD_KERNEL(
|
392
|
-
GGML_METAL_ADD_KERNEL(
|
393
|
-
GGML_METAL_ADD_KERNEL(
|
394
|
-
GGML_METAL_ADD_KERNEL(
|
395
|
-
GGML_METAL_ADD_KERNEL(
|
396
|
-
GGML_METAL_ADD_KERNEL(
|
397
|
-
GGML_METAL_ADD_KERNEL(
|
398
|
-
GGML_METAL_ADD_KERNEL(
|
399
|
-
GGML_METAL_ADD_KERNEL(
|
400
|
-
GGML_METAL_ADD_KERNEL(
|
401
|
-
GGML_METAL_ADD_KERNEL(
|
402
|
-
GGML_METAL_ADD_KERNEL(
|
403
|
-
GGML_METAL_ADD_KERNEL(
|
404
|
-
GGML_METAL_ADD_KERNEL(
|
405
|
-
GGML_METAL_ADD_KERNEL(
|
406
|
-
GGML_METAL_ADD_KERNEL(
|
407
|
-
GGML_METAL_ADD_KERNEL(
|
408
|
-
GGML_METAL_ADD_KERNEL(
|
409
|
-
GGML_METAL_ADD_KERNEL(
|
410
|
-
GGML_METAL_ADD_KERNEL(
|
411
|
-
GGML_METAL_ADD_KERNEL(
|
412
|
-
GGML_METAL_ADD_KERNEL(
|
413
|
-
GGML_METAL_ADD_KERNEL(
|
414
|
-
GGML_METAL_ADD_KERNEL(
|
415
|
-
|
416
|
-
GGML_METAL_ADD_KERNEL(
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
GGML_METAL_ADD_KERNEL(
|
422
|
-
GGML_METAL_ADD_KERNEL(
|
423
|
-
GGML_METAL_ADD_KERNEL(
|
424
|
-
GGML_METAL_ADD_KERNEL(
|
425
|
-
GGML_METAL_ADD_KERNEL(
|
426
|
-
GGML_METAL_ADD_KERNEL(
|
427
|
-
GGML_METAL_ADD_KERNEL(
|
428
|
-
GGML_METAL_ADD_KERNEL(
|
429
|
-
GGML_METAL_ADD_KERNEL(
|
430
|
-
GGML_METAL_ADD_KERNEL(
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
GGML_METAL_ADD_KERNEL(rope_f32);
|
462
|
-
GGML_METAL_ADD_KERNEL(rope_f16);
|
463
|
-
GGML_METAL_ADD_KERNEL(alibi_f32);
|
464
|
-
GGML_METAL_ADD_KERNEL(im2col_f16);
|
465
|
-
GGML_METAL_ADD_KERNEL(upscale_f32);
|
466
|
-
GGML_METAL_ADD_KERNEL(pad_f32);
|
467
|
-
GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
|
468
|
-
GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
|
469
|
-
GGML_METAL_ADD_KERNEL(leaky_relu_f32);
|
470
|
-
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
471
|
-
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
472
|
-
GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
|
473
|
-
GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
|
474
|
-
GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
|
475
|
-
|
476
|
-
|
477
|
-
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
478
|
-
GGML_METAL_ADD_KERNEL(cpy_f16_f32);
|
479
|
-
GGML_METAL_ADD_KERNEL(concat);
|
480
|
-
GGML_METAL_ADD_KERNEL(sqr);
|
481
|
-
GGML_METAL_ADD_KERNEL(sum_rows);
|
482
|
-
|
483
|
-
#undef GGML_METAL_ADD_KERNEL
|
409
|
+
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
410
|
+
|
411
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
412
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
|
413
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
414
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
415
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
416
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
417
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
418
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
419
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
420
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
421
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
422
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
423
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
424
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
|
425
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
|
426
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
427
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
428
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
429
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
|
430
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
|
431
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
|
432
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
|
433
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
|
434
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
|
435
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
|
436
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
|
437
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
|
438
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
|
439
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
|
440
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
441
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
442
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
443
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
444
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
445
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
446
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
|
447
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
|
448
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
449
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
|
450
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
|
451
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
|
452
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
|
453
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
|
454
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
|
455
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
|
456
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
|
457
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
|
458
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
|
459
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
|
460
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
|
461
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
462
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
463
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
464
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
465
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
466
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
|
467
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
|
468
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
|
469
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
|
470
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
|
471
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
|
472
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
|
473
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
|
474
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
|
475
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
|
476
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
|
477
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
|
478
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
479
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
480
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
481
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
482
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
483
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
|
484
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
|
485
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
|
486
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
|
487
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
|
488
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
|
489
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
|
490
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
|
491
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
|
492
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
493
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
494
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
495
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
496
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
497
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
|
498
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
|
499
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
|
500
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
|
501
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
|
502
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
|
503
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
|
504
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
|
505
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
|
506
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
507
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
508
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
509
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
510
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
511
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
512
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
513
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
514
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
515
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
516
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
517
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
518
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
519
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
520
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
521
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
522
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
523
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
524
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
525
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
526
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
527
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
528
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
484
529
|
}
|
485
530
|
|
486
531
|
return ctx;
|
487
532
|
}
|
488
533
|
|
489
|
-
void ggml_metal_free(struct ggml_metal_context * ctx) {
|
534
|
+
static void ggml_metal_free(struct ggml_metal_context * ctx) {
|
490
535
|
GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
|
491
|
-
#define GGML_METAL_DEL_KERNEL(name) \
|
492
|
-
[ctx->function_##name release]; \
|
493
|
-
[ctx->pipeline_##name release];
|
494
|
-
|
495
|
-
GGML_METAL_DEL_KERNEL(add);
|
496
|
-
GGML_METAL_DEL_KERNEL(add_row);
|
497
|
-
GGML_METAL_DEL_KERNEL(mul);
|
498
|
-
GGML_METAL_DEL_KERNEL(mul_row);
|
499
|
-
GGML_METAL_DEL_KERNEL(div);
|
500
|
-
GGML_METAL_DEL_KERNEL(div_row);
|
501
|
-
GGML_METAL_DEL_KERNEL(scale);
|
502
|
-
GGML_METAL_DEL_KERNEL(scale_4);
|
503
|
-
GGML_METAL_DEL_KERNEL(tanh);
|
504
|
-
GGML_METAL_DEL_KERNEL(relu);
|
505
|
-
GGML_METAL_DEL_KERNEL(gelu);
|
506
|
-
GGML_METAL_DEL_KERNEL(gelu_quick);
|
507
|
-
GGML_METAL_DEL_KERNEL(silu);
|
508
|
-
GGML_METAL_DEL_KERNEL(soft_max);
|
509
|
-
GGML_METAL_DEL_KERNEL(soft_max_4);
|
510
|
-
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
511
|
-
GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
|
512
|
-
GGML_METAL_DEL_KERNEL(get_rows_f32);
|
513
|
-
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
514
|
-
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
515
|
-
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
516
|
-
GGML_METAL_DEL_KERNEL(get_rows_q5_0);
|
517
|
-
GGML_METAL_DEL_KERNEL(get_rows_q5_1);
|
518
|
-
GGML_METAL_DEL_KERNEL(get_rows_q8_0);
|
519
|
-
GGML_METAL_DEL_KERNEL(get_rows_q2_K);
|
520
|
-
GGML_METAL_DEL_KERNEL(get_rows_q3_K);
|
521
|
-
GGML_METAL_DEL_KERNEL(get_rows_q4_K);
|
522
|
-
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
|
523
|
-
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
524
|
-
GGML_METAL_DEL_KERNEL(get_rows_i32);
|
525
|
-
GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs);
|
526
|
-
GGML_METAL_DEL_KERNEL(get_rows_iq2_xs);
|
527
|
-
GGML_METAL_DEL_KERNEL(rms_norm);
|
528
|
-
GGML_METAL_DEL_KERNEL(group_norm);
|
529
|
-
GGML_METAL_DEL_KERNEL(norm);
|
530
|
-
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
|
531
|
-
GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
|
532
|
-
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
|
533
|
-
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
|
534
|
-
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
|
535
|
-
GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
|
536
|
-
GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
|
537
|
-
GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
|
538
|
-
GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
|
539
|
-
GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
|
540
|
-
GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
|
541
|
-
GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
|
542
|
-
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
|
543
|
-
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
|
544
|
-
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
|
545
|
-
GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32);
|
546
|
-
GGML_METAL_DEL_KERNEL(mul_mv_iq2_xs_f32);
|
547
|
-
GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
|
548
|
-
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
|
549
|
-
GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
|
550
|
-
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
|
551
|
-
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
|
552
|
-
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
|
553
|
-
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
|
554
|
-
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
|
555
|
-
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
|
556
|
-
GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
|
557
|
-
GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
|
558
|
-
GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
|
559
|
-
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
|
560
|
-
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
|
561
|
-
GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
|
562
|
-
GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32);
|
563
|
-
GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xs_f32);
|
564
|
-
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
565
|
-
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
566
|
-
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
567
|
-
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
|
568
|
-
GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
|
569
|
-
GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
|
570
|
-
GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
|
571
|
-
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
|
572
|
-
GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
|
573
|
-
GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
|
574
|
-
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
|
575
|
-
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
|
576
|
-
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
|
577
|
-
GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32);
|
578
|
-
GGML_METAL_DEL_KERNEL(mul_mm_iq2_xs_f32);
|
579
|
-
GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
|
580
|
-
GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
|
581
|
-
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
|
582
|
-
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
|
583
|
-
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
|
584
|
-
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
|
585
|
-
GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
|
586
|
-
GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
|
587
|
-
GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
|
588
|
-
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
|
589
|
-
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
|
590
|
-
GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
|
591
|
-
GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32);
|
592
|
-
GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xs_f32);
|
593
|
-
}
|
594
|
-
GGML_METAL_DEL_KERNEL(rope_f32);
|
595
|
-
GGML_METAL_DEL_KERNEL(rope_f16);
|
596
|
-
GGML_METAL_DEL_KERNEL(alibi_f32);
|
597
|
-
GGML_METAL_DEL_KERNEL(im2col_f16);
|
598
|
-
GGML_METAL_DEL_KERNEL(upscale_f32);
|
599
|
-
GGML_METAL_DEL_KERNEL(pad_f32);
|
600
|
-
GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
|
601
|
-
GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
|
602
|
-
GGML_METAL_DEL_KERNEL(leaky_relu_f32);
|
603
|
-
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
604
|
-
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
605
|
-
GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
|
606
|
-
GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
|
607
|
-
GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
|
608
|
-
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
|
609
|
-
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
|
610
|
-
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
611
|
-
GGML_METAL_DEL_KERNEL(cpy_f16_f32);
|
612
|
-
GGML_METAL_DEL_KERNEL(concat);
|
613
|
-
GGML_METAL_DEL_KERNEL(sqr);
|
614
|
-
GGML_METAL_DEL_KERNEL(sum_rows);
|
615
|
-
|
616
|
-
#undef GGML_METAL_DEL_KERNEL
|
617
536
|
|
618
537
|
for (int i = 0; i < ctx->n_buffers; ++i) {
|
619
538
|
[ctx->buffers[i].metal release];
|
620
539
|
}
|
621
540
|
|
541
|
+
for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
|
542
|
+
if (ctx->kernels[i].pipeline) {
|
543
|
+
[ctx->kernels[i].pipeline release];
|
544
|
+
}
|
545
|
+
|
546
|
+
if (ctx->kernels[i].function) {
|
547
|
+
[ctx->kernels[i].function release];
|
548
|
+
}
|
549
|
+
}
|
550
|
+
|
622
551
|
[ctx->library release];
|
623
552
|
[ctx->queue release];
|
624
553
|
[ctx->device release];
|
@@ -628,33 +557,6 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
628
557
|
free(ctx);
|
629
558
|
}
|
630
559
|
|
631
|
-
void * ggml_metal_host_malloc(size_t n) {
|
632
|
-
void * data = NULL;
|
633
|
-
const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
|
634
|
-
if (result != 0) {
|
635
|
-
GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
|
636
|
-
return NULL;
|
637
|
-
}
|
638
|
-
|
639
|
-
return data;
|
640
|
-
}
|
641
|
-
|
642
|
-
void ggml_metal_host_free(void * data) {
|
643
|
-
free(data);
|
644
|
-
}
|
645
|
-
|
646
|
-
void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
|
647
|
-
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
648
|
-
}
|
649
|
-
|
650
|
-
int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
|
651
|
-
return ctx->concur_list_len;
|
652
|
-
}
|
653
|
-
|
654
|
-
int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
|
655
|
-
return ctx->concur_list;
|
656
|
-
}
|
657
|
-
|
658
560
|
// temporarily defined here for compatibility between ggml-backend and the old API
|
659
561
|
|
660
562
|
struct ggml_backend_metal_buffer {
|
@@ -727,210 +629,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
|
727
629
|
return nil;
|
728
630
|
}
|
729
631
|
|
730
|
-
bool
|
731
|
-
struct ggml_metal_context * ctx,
|
732
|
-
const char * name,
|
733
|
-
void * data,
|
734
|
-
size_t size,
|
735
|
-
size_t max_size) {
|
736
|
-
if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
|
737
|
-
GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
|
738
|
-
return false;
|
739
|
-
}
|
740
|
-
|
741
|
-
if (data) {
|
742
|
-
// verify that the buffer does not overlap with any of the existing buffers
|
743
|
-
for (int i = 0; i < ctx->n_buffers; ++i) {
|
744
|
-
const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
|
745
|
-
|
746
|
-
if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
|
747
|
-
GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
|
748
|
-
return false;
|
749
|
-
}
|
750
|
-
}
|
751
|
-
|
752
|
-
const size_t size_page = sysconf(_SC_PAGESIZE);
|
753
|
-
|
754
|
-
size_t size_aligned = size;
|
755
|
-
if ((size_aligned % size_page) != 0) {
|
756
|
-
size_aligned += (size_page - (size_aligned % size_page));
|
757
|
-
}
|
758
|
-
|
759
|
-
// the buffer fits into the max buffer size allowed by the device
|
760
|
-
if (size_aligned <= ctx->device.maxBufferLength) {
|
761
|
-
ctx->buffers[ctx->n_buffers].name = name;
|
762
|
-
ctx->buffers[ctx->n_buffers].data = data;
|
763
|
-
ctx->buffers[ctx->n_buffers].size = size;
|
764
|
-
|
765
|
-
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
766
|
-
|
767
|
-
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
768
|
-
GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
|
769
|
-
return false;
|
770
|
-
}
|
771
|
-
|
772
|
-
GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);
|
773
|
-
|
774
|
-
++ctx->n_buffers;
|
775
|
-
} else {
|
776
|
-
// this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
|
777
|
-
// one of the views
|
778
|
-
const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
|
779
|
-
const size_t size_step = ctx->device.maxBufferLength - size_ovlp;
|
780
|
-
const size_t size_view = ctx->device.maxBufferLength;
|
781
|
-
|
782
|
-
for (size_t i = 0; i < size; i += size_step) {
|
783
|
-
const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
|
784
|
-
|
785
|
-
ctx->buffers[ctx->n_buffers].name = name;
|
786
|
-
ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
|
787
|
-
ctx->buffers[ctx->n_buffers].size = size_step_aligned;
|
788
|
-
|
789
|
-
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
790
|
-
|
791
|
-
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
792
|
-
GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
|
793
|
-
return false;
|
794
|
-
}
|
795
|
-
|
796
|
-
GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
|
797
|
-
if (i + size_step < size) {
|
798
|
-
GGML_METAL_LOG_INFO("\n");
|
799
|
-
}
|
800
|
-
|
801
|
-
++ctx->n_buffers;
|
802
|
-
}
|
803
|
-
}
|
804
|
-
|
805
|
-
#if TARGET_OS_OSX
|
806
|
-
GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
|
807
|
-
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
|
808
|
-
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
809
|
-
|
810
|
-
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
|
811
|
-
GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
812
|
-
} else {
|
813
|
-
GGML_METAL_LOG_INFO("\n");
|
814
|
-
}
|
815
|
-
#else
|
816
|
-
GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
|
817
|
-
#endif
|
818
|
-
}
|
819
|
-
|
820
|
-
return true;
|
821
|
-
}
|
822
|
-
|
823
|
-
void ggml_metal_set_tensor(
|
824
|
-
struct ggml_metal_context * ctx,
|
825
|
-
struct ggml_tensor * t) {
|
826
|
-
size_t offs;
|
827
|
-
id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
|
828
|
-
|
829
|
-
memcpy((void *) ((uint8_t *) id_dst.contents + offs), t->data, ggml_nbytes(t));
|
830
|
-
}
|
831
|
-
|
832
|
-
void ggml_metal_get_tensor(
|
833
|
-
struct ggml_metal_context * ctx,
|
834
|
-
struct ggml_tensor * t) {
|
835
|
-
size_t offs;
|
836
|
-
id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
|
837
|
-
|
838
|
-
memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
|
839
|
-
}
|
840
|
-
|
841
|
-
void ggml_metal_graph_find_concurrency(
|
842
|
-
struct ggml_metal_context * ctx,
|
843
|
-
struct ggml_cgraph * gf, bool check_mem) {
|
844
|
-
int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
|
845
|
-
int nodes_unused[GGML_MAX_CONCUR];
|
846
|
-
|
847
|
-
for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
|
848
|
-
for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
|
849
|
-
ctx->concur_list_len = 0;
|
850
|
-
|
851
|
-
int n_left = gf->n_nodes;
|
852
|
-
int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
|
853
|
-
int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
|
854
|
-
|
855
|
-
while (n_left > 0) {
|
856
|
-
// number of nodes at a layer (that can be issued concurrently)
|
857
|
-
int concurrency = 0;
|
858
|
-
for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
|
859
|
-
if (nodes_unused[i]) {
|
860
|
-
// if the requirements for gf->nodes[i] are satisfied
|
861
|
-
int exe_flag = 1;
|
862
|
-
|
863
|
-
// scan all srcs
|
864
|
-
for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
|
865
|
-
struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
|
866
|
-
if (src_cur) {
|
867
|
-
// if is leaf nodes it's satisfied.
|
868
|
-
// TODO: ggml_is_leaf()
|
869
|
-
if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
|
870
|
-
continue;
|
871
|
-
}
|
872
|
-
|
873
|
-
// otherwise this src should be the output from previous nodes.
|
874
|
-
int is_found = 0;
|
875
|
-
|
876
|
-
// scan 2*search_depth back because we inserted barrier.
|
877
|
-
//for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
|
878
|
-
for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
|
879
|
-
if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
|
880
|
-
is_found = 1;
|
881
|
-
break;
|
882
|
-
}
|
883
|
-
}
|
884
|
-
if (is_found == 0) {
|
885
|
-
exe_flag = 0;
|
886
|
-
break;
|
887
|
-
}
|
888
|
-
}
|
889
|
-
}
|
890
|
-
if (exe_flag && check_mem) {
|
891
|
-
// check if nodes[i]'s data will be overwritten by a node before nodes[i].
|
892
|
-
// if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
|
893
|
-
int64_t data_start = (int64_t) gf->nodes[i]->data;
|
894
|
-
int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
|
895
|
-
for (int j = n_start; j < i; j++) {
|
896
|
-
if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
|
897
|
-
&& gf->nodes[j]->op != GGML_OP_VIEW \
|
898
|
-
&& gf->nodes[j]->op != GGML_OP_TRANSPOSE \
|
899
|
-
&& gf->nodes[j]->op != GGML_OP_PERMUTE) {
|
900
|
-
if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
|
901
|
-
((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
|
902
|
-
continue;
|
903
|
-
}
|
904
|
-
|
905
|
-
exe_flag = 0;
|
906
|
-
}
|
907
|
-
}
|
908
|
-
}
|
909
|
-
if (exe_flag) {
|
910
|
-
ctx->concur_list[level_pos + concurrency] = i;
|
911
|
-
nodes_unused[i] = 0;
|
912
|
-
concurrency++;
|
913
|
-
ctx->concur_list_len++;
|
914
|
-
}
|
915
|
-
}
|
916
|
-
}
|
917
|
-
n_left -= concurrency;
|
918
|
-
// adding a barrier different layer
|
919
|
-
ctx->concur_list[level_pos + concurrency] = -1;
|
920
|
-
ctx->concur_list_len++;
|
921
|
-
// jump all sorted nodes at nodes_bak
|
922
|
-
while (!nodes_unused[n_start]) {
|
923
|
-
n_start++;
|
924
|
-
}
|
925
|
-
level_pos += concurrency + 1;
|
926
|
-
}
|
927
|
-
|
928
|
-
if (ctx->concur_list_len > GGML_MAX_CONCUR) {
|
929
|
-
GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
|
930
|
-
}
|
931
|
-
}
|
932
|
-
|
933
|
-
static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
632
|
+
static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
|
934
633
|
switch (op->op) {
|
935
634
|
case GGML_OP_UNARY:
|
936
635
|
switch (ggml_get_unary_op(op)) {
|
@@ -956,9 +655,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
|
956
655
|
case GGML_OP_SCALE:
|
957
656
|
case GGML_OP_SQR:
|
958
657
|
case GGML_OP_SUM_ROWS:
|
658
|
+
return true;
|
959
659
|
case GGML_OP_SOFT_MAX:
|
960
660
|
case GGML_OP_RMS_NORM:
|
961
661
|
case GGML_OP_GROUP_NORM:
|
662
|
+
return ctx->support_simdgroup_reduction;
|
962
663
|
case GGML_OP_NORM:
|
963
664
|
case GGML_OP_ALIBI:
|
964
665
|
case GGML_OP_ROPE:
|
@@ -967,9 +668,10 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
|
967
668
|
case GGML_OP_PAD:
|
968
669
|
case GGML_OP_ARGSORT:
|
969
670
|
case GGML_OP_LEAKY_RELU:
|
671
|
+
return true;
|
970
672
|
case GGML_OP_MUL_MAT:
|
971
673
|
case GGML_OP_MUL_MAT_ID:
|
972
|
-
return
|
674
|
+
return ctx->support_simdgroup_reduction;
|
973
675
|
case GGML_OP_CPY:
|
974
676
|
case GGML_OP_DUP:
|
975
677
|
case GGML_OP_CONT:
|
@@ -1007,1480 +709,1549 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
|
1007
709
|
return false;
|
1008
710
|
}
|
1009
711
|
}
|
1010
|
-
|
712
|
+
|
713
|
+
static bool ggml_metal_graph_compute(
|
1011
714
|
struct ggml_metal_context * ctx,
|
1012
715
|
struct ggml_cgraph * gf) {
|
1013
716
|
@autoreleasepool {
|
1014
717
|
|
1015
|
-
// if there is ctx->concur_list, dispatch concurrently
|
1016
|
-
// else fallback to serial dispatch
|
1017
718
|
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
|
1018
|
-
|
1019
|
-
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
|
1020
|
-
|
1021
|
-
const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
|
1022
|
-
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
|
719
|
+
edesc.dispatchType = MTLDispatchTypeSerial;
|
1023
720
|
|
1024
721
|
// create multiple command buffers and enqueue them
|
1025
722
|
// then, we encode the graph into the command buffers in parallel
|
1026
723
|
|
724
|
+
const int n_nodes = gf->n_nodes;
|
1027
725
|
const int n_cb = ctx->n_cb;
|
726
|
+
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
|
1028
727
|
|
1029
|
-
|
1030
|
-
|
728
|
+
id<MTLCommandBuffer> command_buffer_builder[n_cb];
|
729
|
+
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
730
|
+
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
731
|
+
command_buffer_builder[cb_idx] = command_buffer;
|
1031
732
|
|
1032
733
|
// enqueue the command buffers in order to specify their execution order
|
1033
|
-
[
|
1034
|
-
|
1035
|
-
ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
|
734
|
+
[command_buffer enqueue];
|
1036
735
|
}
|
736
|
+
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
|
1037
737
|
|
1038
|
-
|
1039
|
-
const int
|
1040
|
-
|
1041
|
-
dispatch_async(ctx->d_queue, ^{
|
1042
|
-
size_t offs_src0 = 0;
|
1043
|
-
size_t offs_src1 = 0;
|
1044
|
-
size_t offs_dst = 0;
|
1045
|
-
|
1046
|
-
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
|
1047
|
-
id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
|
1048
|
-
|
1049
|
-
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
1050
|
-
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
|
1051
|
-
|
1052
|
-
for (int ind = node_start; ind < node_end; ++ind) {
|
1053
|
-
const int i = has_concur ? ctx->concur_list[ind] : ind;
|
1054
|
-
|
1055
|
-
if (i == -1) {
|
1056
|
-
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
|
1057
|
-
continue;
|
1058
|
-
}
|
1059
|
-
|
1060
|
-
//GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
1061
|
-
|
1062
|
-
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
1063
|
-
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
|
1064
|
-
struct ggml_tensor * dst = gf->nodes[i];
|
1065
|
-
|
1066
|
-
switch (dst->op) {
|
1067
|
-
case GGML_OP_NONE:
|
1068
|
-
case GGML_OP_RESHAPE:
|
1069
|
-
case GGML_OP_VIEW:
|
1070
|
-
case GGML_OP_TRANSPOSE:
|
1071
|
-
case GGML_OP_PERMUTE:
|
1072
|
-
{
|
1073
|
-
// noop -> next node
|
1074
|
-
} continue;
|
1075
|
-
default:
|
1076
|
-
{
|
1077
|
-
} break;
|
1078
|
-
}
|
738
|
+
dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
|
739
|
+
const int cb_idx = iter;
|
1079
740
|
|
1080
|
-
|
1081
|
-
|
1082
|
-
|
1083
|
-
}
|
1084
|
-
|
1085
|
-
#ifndef GGML_METAL_NDEBUG
|
1086
|
-
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
|
1087
|
-
#endif
|
741
|
+
size_t offs_src0 = 0;
|
742
|
+
size_t offs_src1 = 0;
|
743
|
+
size_t offs_dst = 0;
|
1088
744
|
|
1089
|
-
|
1090
|
-
|
1091
|
-
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
1092
|
-
const int64_t ne03 = src0 ? src0->ne[3] : 0;
|
1093
|
-
|
1094
|
-
const uint64_t nb00 = src0 ? src0->nb[0] : 0;
|
1095
|
-
const uint64_t nb01 = src0 ? src0->nb[1] : 0;
|
1096
|
-
const uint64_t nb02 = src0 ? src0->nb[2] : 0;
|
1097
|
-
const uint64_t nb03 = src0 ? src0->nb[3] : 0;
|
1098
|
-
|
1099
|
-
const int64_t ne10 = src1 ? src1->ne[0] : 0;
|
1100
|
-
const int64_t ne11 = src1 ? src1->ne[1] : 0;
|
1101
|
-
const int64_t ne12 = src1 ? src1->ne[2] : 0;
|
1102
|
-
const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
|
1103
|
-
|
1104
|
-
const uint64_t nb10 = src1 ? src1->nb[0] : 0;
|
1105
|
-
const uint64_t nb11 = src1 ? src1->nb[1] : 0;
|
1106
|
-
const uint64_t nb12 = src1 ? src1->nb[2] : 0;
|
1107
|
-
const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
|
1108
|
-
|
1109
|
-
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
1110
|
-
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
1111
|
-
const int64_t ne2 = dst ? dst->ne[2] : 0;
|
1112
|
-
const int64_t ne3 = dst ? dst->ne[3] : 0;
|
1113
|
-
|
1114
|
-
const uint64_t nb0 = dst ? dst->nb[0] : 0;
|
1115
|
-
const uint64_t nb1 = dst ? dst->nb[1] : 0;
|
1116
|
-
const uint64_t nb2 = dst ? dst->nb[2] : 0;
|
1117
|
-
const uint64_t nb3 = dst ? dst->nb[3] : 0;
|
1118
|
-
|
1119
|
-
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
|
1120
|
-
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
|
1121
|
-
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
|
1122
|
-
|
1123
|
-
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
|
1124
|
-
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
|
1125
|
-
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
|
1126
|
-
|
1127
|
-
//GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
|
1128
|
-
//if (src0) {
|
1129
|
-
// GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
|
1130
|
-
// ggml_is_contiguous(src0), src0->name);
|
1131
|
-
//}
|
1132
|
-
//if (src1) {
|
1133
|
-
// GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
|
1134
|
-
// ggml_is_contiguous(src1), src1->name);
|
1135
|
-
//}
|
1136
|
-
//if (dst) {
|
1137
|
-
// GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
|
1138
|
-
// dst->name);
|
1139
|
-
//}
|
1140
|
-
|
1141
|
-
switch (dst->op) {
|
1142
|
-
case GGML_OP_CONCAT:
|
1143
|
-
{
|
1144
|
-
const int64_t nb = ne00;
|
1145
|
-
|
1146
|
-
[encoder setComputePipelineState:ctx->pipeline_concat];
|
1147
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1148
|
-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1149
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1150
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
1151
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1152
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1153
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
1154
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
1155
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
1156
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
1157
|
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
1158
|
-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
1159
|
-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
1160
|
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
1161
|
-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
1162
|
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
1163
|
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
1164
|
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
1165
|
-
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
1166
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
1167
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
1168
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
1169
|
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
1170
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
1171
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
1172
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
1173
|
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
1174
|
-
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
|
1175
|
-
|
1176
|
-
const int nth = MIN(1024, ne0);
|
1177
|
-
|
1178
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1179
|
-
} break;
|
1180
|
-
case GGML_OP_ADD:
|
1181
|
-
case GGML_OP_MUL:
|
1182
|
-
case GGML_OP_DIV:
|
1183
|
-
{
|
1184
|
-
const size_t offs = 0;
|
1185
|
-
|
1186
|
-
bool bcast_row = false;
|
1187
|
-
|
1188
|
-
int64_t nb = ne00;
|
745
|
+
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
|
746
|
+
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
1189
747
|
|
1190
|
-
|
748
|
+
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
749
|
+
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
|
1191
750
|
|
1192
|
-
|
1193
|
-
|
751
|
+
for (int i = node_start; i < node_end; ++i) {
|
752
|
+
if (i == -1) {
|
753
|
+
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
|
754
|
+
continue;
|
755
|
+
}
|
1194
756
|
|
1195
|
-
|
1196
|
-
|
757
|
+
//GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
758
|
+
|
759
|
+
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
760
|
+
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
|
761
|
+
struct ggml_tensor * dst = gf->nodes[i];
|
762
|
+
|
763
|
+
switch (dst->op) {
|
764
|
+
case GGML_OP_NONE:
|
765
|
+
case GGML_OP_RESHAPE:
|
766
|
+
case GGML_OP_VIEW:
|
767
|
+
case GGML_OP_TRANSPOSE:
|
768
|
+
case GGML_OP_PERMUTE:
|
769
|
+
{
|
770
|
+
// noop -> next node
|
771
|
+
} continue;
|
772
|
+
default:
|
773
|
+
{
|
774
|
+
} break;
|
775
|
+
}
|
1197
776
|
|
1198
|
-
|
1199
|
-
|
1200
|
-
|
1201
|
-
|
1202
|
-
case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
|
1203
|
-
default: GGML_ASSERT(false);
|
1204
|
-
}
|
777
|
+
if (!ggml_metal_supports_op(ctx, dst)) {
|
778
|
+
GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
|
779
|
+
GGML_ASSERT(!"unsupported op");
|
780
|
+
}
|
1205
781
|
|
1206
|
-
|
1207
|
-
|
1208
|
-
|
1209
|
-
case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
|
1210
|
-
case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
|
1211
|
-
case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
|
1212
|
-
default: GGML_ASSERT(false);
|
1213
|
-
}
|
1214
|
-
}
|
782
|
+
#ifndef GGML_METAL_NDEBUG
|
783
|
+
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
|
784
|
+
#endif
|
1215
785
|
|
1216
|
-
|
1217
|
-
|
1218
|
-
|
1219
|
-
|
1220
|
-
|
1221
|
-
|
1222
|
-
|
1223
|
-
|
1224
|
-
|
1225
|
-
|
1226
|
-
|
1227
|
-
|
1228
|
-
|
1229
|
-
|
1230
|
-
|
1231
|
-
|
1232
|
-
|
1233
|
-
|
1234
|
-
|
1235
|
-
|
1236
|
-
|
1237
|
-
|
1238
|
-
|
1239
|
-
|
1240
|
-
|
1241
|
-
|
1242
|
-
|
1243
|
-
|
1244
|
-
|
1245
|
-
|
1246
|
-
|
1247
|
-
|
1248
|
-
|
786
|
+
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
787
|
+
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
788
|
+
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
789
|
+
const int64_t ne03 = src0 ? src0->ne[3] : 0;
|
790
|
+
|
791
|
+
const uint64_t nb00 = src0 ? src0->nb[0] : 0;
|
792
|
+
const uint64_t nb01 = src0 ? src0->nb[1] : 0;
|
793
|
+
const uint64_t nb02 = src0 ? src0->nb[2] : 0;
|
794
|
+
const uint64_t nb03 = src0 ? src0->nb[3] : 0;
|
795
|
+
|
796
|
+
const int64_t ne10 = src1 ? src1->ne[0] : 0;
|
797
|
+
const int64_t ne11 = src1 ? src1->ne[1] : 0;
|
798
|
+
const int64_t ne12 = src1 ? src1->ne[2] : 0;
|
799
|
+
const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
|
800
|
+
|
801
|
+
const uint64_t nb10 = src1 ? src1->nb[0] : 0;
|
802
|
+
const uint64_t nb11 = src1 ? src1->nb[1] : 0;
|
803
|
+
const uint64_t nb12 = src1 ? src1->nb[2] : 0;
|
804
|
+
const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
|
805
|
+
|
806
|
+
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
807
|
+
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
808
|
+
const int64_t ne2 = dst ? dst->ne[2] : 0;
|
809
|
+
const int64_t ne3 = dst ? dst->ne[3] : 0;
|
810
|
+
|
811
|
+
const uint64_t nb0 = dst ? dst->nb[0] : 0;
|
812
|
+
const uint64_t nb1 = dst ? dst->nb[1] : 0;
|
813
|
+
const uint64_t nb2 = dst ? dst->nb[2] : 0;
|
814
|
+
const uint64_t nb3 = dst ? dst->nb[3] : 0;
|
815
|
+
|
816
|
+
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
|
817
|
+
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
|
818
|
+
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
|
819
|
+
|
820
|
+
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
|
821
|
+
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
|
822
|
+
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
|
823
|
+
|
824
|
+
//GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
|
825
|
+
//if (src0) {
|
826
|
+
// GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
|
827
|
+
// ggml_is_contiguous(src0), src0->name);
|
828
|
+
//}
|
829
|
+
//if (src1) {
|
830
|
+
// GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
|
831
|
+
// ggml_is_contiguous(src1), src1->name);
|
832
|
+
//}
|
833
|
+
//if (dst) {
|
834
|
+
// GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
|
835
|
+
// dst->name);
|
836
|
+
//}
|
837
|
+
|
838
|
+
switch (dst->op) {
|
839
|
+
case GGML_OP_CONCAT:
|
840
|
+
{
|
841
|
+
const int64_t nb = ne00;
|
842
|
+
|
843
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
|
844
|
+
|
845
|
+
[encoder setComputePipelineState:pipeline];
|
846
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
847
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
848
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
849
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
850
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
851
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
852
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
853
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
854
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
855
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
856
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
857
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
858
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
859
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
860
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
861
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
862
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
863
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
864
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
865
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
866
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
867
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
868
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
869
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
870
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
871
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
872
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
873
|
+
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
|
874
|
+
|
875
|
+
const int nth = MIN(1024, ne0);
|
876
|
+
|
877
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
878
|
+
} break;
|
879
|
+
case GGML_OP_ADD:
|
880
|
+
case GGML_OP_MUL:
|
881
|
+
case GGML_OP_DIV:
|
882
|
+
{
|
883
|
+
const size_t offs = 0;
|
884
|
+
|
885
|
+
bool bcast_row = false;
|
886
|
+
|
887
|
+
int64_t nb = ne00;
|
888
|
+
|
889
|
+
id<MTLComputePipelineState> pipeline = nil;
|
890
|
+
|
891
|
+
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
892
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
1249
893
|
|
1250
|
-
|
1251
|
-
|
1252
|
-
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
894
|
+
// src1 is a row
|
895
|
+
GGML_ASSERT(ne11 == 1);
|
1253
896
|
|
1254
|
-
|
897
|
+
nb = ne00 / 4;
|
898
|
+
switch (dst->op) {
|
899
|
+
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
|
900
|
+
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
|
901
|
+
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
|
902
|
+
default: GGML_ASSERT(false);
|
1255
903
|
}
|
1256
|
-
} break;
|
1257
|
-
case GGML_OP_ACC:
|
1258
|
-
{
|
1259
|
-
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
1260
|
-
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1261
|
-
GGML_ASSERT(dstt == GGML_TYPE_F32);
|
1262
904
|
|
1263
|
-
|
1264
|
-
|
1265
|
-
|
1266
|
-
|
1267
|
-
|
1268
|
-
|
1269
|
-
|
1270
|
-
|
1271
|
-
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
|
1272
|
-
|
1273
|
-
if (!inplace) {
|
1274
|
-
// run a separete kernel to cpy src->dst
|
1275
|
-
// not sure how to avoid this
|
1276
|
-
// TODO: make a simpler cpy_bytes kernel
|
1277
|
-
|
1278
|
-
const int nth = MIN((int) ctx->pipeline_cpy_f32_f32.maxTotalThreadsPerThreadgroup, ne00);
|
1279
|
-
|
1280
|
-
[encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
|
1281
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1282
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1283
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
1284
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
1285
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
1286
|
-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
1287
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
1288
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
1289
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
1290
|
-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
1291
|
-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
1292
|
-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
1293
|
-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
1294
|
-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
1295
|
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
1296
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
1297
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
1298
|
-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
1299
|
-
|
1300
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
905
|
+
bcast_row = true;
|
906
|
+
} else {
|
907
|
+
switch (dst->op) {
|
908
|
+
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
|
909
|
+
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
|
910
|
+
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
|
911
|
+
default: GGML_ASSERT(false);
|
1301
912
|
}
|
913
|
+
}
|
1302
914
|
|
1303
|
-
|
1304
|
-
|
1305
|
-
|
1306
|
-
|
1307
|
-
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1313
|
-
|
1314
|
-
|
1315
|
-
|
1316
|
-
|
1317
|
-
|
1318
|
-
|
1319
|
-
|
1320
|
-
|
1321
|
-
|
1322
|
-
|
1323
|
-
|
1324
|
-
|
1325
|
-
|
1326
|
-
|
1327
|
-
|
1328
|
-
|
1329
|
-
|
1330
|
-
|
1331
|
-
|
1332
|
-
|
1333
|
-
|
1334
|
-
|
1335
|
-
|
1336
|
-
} break;
|
1337
|
-
case GGML_OP_SCALE:
|
1338
|
-
{
|
1339
|
-
GGML_ASSERT(ggml_is_contiguous(src0));
|
915
|
+
[encoder setComputePipelineState:pipeline];
|
916
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
917
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
918
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
919
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
920
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
921
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
922
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
923
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
924
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
925
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
926
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
927
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
928
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
929
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
930
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
931
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
932
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
933
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
934
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
935
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
936
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
937
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
938
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
939
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
940
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
941
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
942
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
943
|
+
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
944
|
+
[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
|
945
|
+
|
946
|
+
if (bcast_row) {
|
947
|
+
const int64_t n = ggml_nelements(dst)/4;
|
1340
948
|
|
1341
|
-
|
949
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
950
|
+
} else {
|
951
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
1342
952
|
|
1343
|
-
|
953
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
954
|
+
}
|
955
|
+
} break;
|
956
|
+
case GGML_OP_ACC:
|
957
|
+
{
|
958
|
+
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
959
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
960
|
+
GGML_ASSERT(dstt == GGML_TYPE_F32);
|
1344
961
|
|
1345
|
-
|
1346
|
-
|
1347
|
-
[encoder setComputePipelineState:ctx->pipeline_scale_4];
|
1348
|
-
} else {
|
1349
|
-
[encoder setComputePipelineState:ctx->pipeline_scale];
|
1350
|
-
}
|
962
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
963
|
+
GGML_ASSERT(ggml_is_contiguous(src1));
|
1351
964
|
|
1352
|
-
|
1353
|
-
|
1354
|
-
|
965
|
+
const size_t pnb1 = ((int32_t *) dst->op_params)[0];
|
966
|
+
const size_t pnb2 = ((int32_t *) dst->op_params)[1];
|
967
|
+
const size_t pnb3 = ((int32_t *) dst->op_params)[2];
|
968
|
+
const size_t offs = ((int32_t *) dst->op_params)[3];
|
1355
969
|
|
1356
|
-
|
1357
|
-
} break;
|
1358
|
-
case GGML_OP_UNARY:
|
1359
|
-
switch (ggml_get_unary_op(gf->nodes[i])) {
|
1360
|
-
case GGML_UNARY_OP_TANH:
|
1361
|
-
{
|
1362
|
-
[encoder setComputePipelineState:ctx->pipeline_tanh];
|
1363
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1364
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
970
|
+
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
|
1365
971
|
|
1366
|
-
|
972
|
+
if (!inplace) {
|
973
|
+
// run a separete kernel to cpy src->dst
|
974
|
+
// not sure how to avoid this
|
975
|
+
// TODO: make a simpler cpy_bytes kernel
|
1367
976
|
|
1368
|
-
|
1369
|
-
} break;
|
1370
|
-
case GGML_UNARY_OP_RELU:
|
1371
|
-
{
|
1372
|
-
[encoder setComputePipelineState:ctx->pipeline_relu];
|
1373
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1374
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
977
|
+
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
|
1375
978
|
|
1376
|
-
|
979
|
+
[encoder setComputePipelineState:pipeline];
|
980
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
981
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
982
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
983
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
984
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
985
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
986
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
987
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
988
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
989
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
990
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
991
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
992
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
993
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
994
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
995
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
996
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
997
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
1377
998
|
|
1378
|
-
|
1379
|
-
} break;
|
1380
|
-
case GGML_UNARY_OP_GELU:
|
1381
|
-
{
|
1382
|
-
[encoder setComputePipelineState:ctx->pipeline_gelu];
|
1383
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1384
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
999
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
|
1385
1000
|
|
1386
|
-
|
1387
|
-
|
1001
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1002
|
+
}
|
1388
1003
|
|
1389
|
-
|
1390
|
-
|
1391
|
-
|
1392
|
-
|
1393
|
-
|
1394
|
-
|
1395
|
-
|
1004
|
+
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
|
1005
|
+
|
1006
|
+
[encoder setComputePipelineState:pipeline];
|
1007
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1008
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1009
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1010
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
1011
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1012
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1013
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
1014
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
1015
|
+
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
|
1016
|
+
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
|
1017
|
+
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
|
1018
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
1019
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
1020
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
1021
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
1022
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
1023
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
1024
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
1025
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
1026
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
1027
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
1028
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
1029
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
1030
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
1031
|
+
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
|
1032
|
+
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
|
1033
|
+
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
|
1034
|
+
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
1035
|
+
|
1036
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
|
1037
|
+
|
1038
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1039
|
+
} break;
|
1040
|
+
case GGML_OP_SCALE:
|
1041
|
+
{
|
1042
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
1043
|
+
|
1044
|
+
const float scale = *(const float *) dst->op_params;
|
1045
|
+
|
1046
|
+
int64_t n = ggml_nelements(dst);
|
1047
|
+
|
1048
|
+
id<MTLComputePipelineState> pipeline = nil;
|
1049
|
+
|
1050
|
+
if (n % 4 == 0) {
|
1051
|
+
n /= 4;
|
1052
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
|
1053
|
+
} else {
|
1054
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
|
1055
|
+
}
|
1396
1056
|
|
1397
|
-
|
1398
|
-
|
1057
|
+
[encoder setComputePipelineState:pipeline];
|
1058
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1059
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1060
|
+
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
1399
1061
|
|
1400
|
-
|
1401
|
-
|
1402
|
-
|
1403
|
-
|
1404
|
-
|
1405
|
-
|
1406
|
-
|
1062
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1063
|
+
} break;
|
1064
|
+
case GGML_OP_UNARY:
|
1065
|
+
switch (ggml_get_unary_op(gf->nodes[i])) {
|
1066
|
+
case GGML_UNARY_OP_TANH:
|
1067
|
+
{
|
1068
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
|
1407
1069
|
|
1408
|
-
|
1409
|
-
|
1070
|
+
[encoder setComputePipelineState:pipeline];
|
1071
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1072
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1410
1073
|
|
1411
|
-
|
1412
|
-
} break;
|
1413
|
-
default:
|
1414
|
-
{
|
1415
|
-
GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
1416
|
-
GGML_ASSERT(false);
|
1417
|
-
}
|
1418
|
-
} break;
|
1419
|
-
case GGML_OP_SQR:
|
1420
|
-
{
|
1421
|
-
GGML_ASSERT(ggml_is_contiguous(src0));
|
1074
|
+
const int64_t n = ggml_nelements(dst);
|
1422
1075
|
|
1423
|
-
|
1424
|
-
|
1425
|
-
|
1076
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1077
|
+
} break;
|
1078
|
+
case GGML_UNARY_OP_RELU:
|
1079
|
+
{
|
1080
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
|
1426
1081
|
|
1427
|
-
|
1428
|
-
|
1429
|
-
|
1430
|
-
case GGML_OP_SUM_ROWS:
|
1431
|
-
{
|
1432
|
-
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
1082
|
+
[encoder setComputePipelineState:pipeline];
|
1083
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1084
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1433
1085
|
|
1434
|
-
|
1435
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1436
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1437
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
1438
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
1439
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
1440
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
1441
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
1442
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
1443
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
1444
|
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
1445
|
-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
1446
|
-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
1447
|
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
1448
|
-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
1449
|
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
1450
|
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
1451
|
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
1452
|
-
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
1453
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
|
1454
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
|
1455
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
|
1456
|
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
|
1457
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
|
1458
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
|
1459
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
|
1460
|
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
|
1461
|
-
|
1462
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1463
|
-
} break;
|
1464
|
-
case GGML_OP_SOFT_MAX:
|
1465
|
-
{
|
1466
|
-
int nth = 32; // SIMD width
|
1467
|
-
|
1468
|
-
if (ne00%4 == 0) {
|
1469
|
-
while (nth < ne00/4 && nth < 256) {
|
1470
|
-
nth *= 2;
|
1471
|
-
}
|
1472
|
-
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
1473
|
-
} else {
|
1474
|
-
while (nth < ne00 && nth < 1024) {
|
1475
|
-
nth *= 2;
|
1476
|
-
}
|
1477
|
-
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
1478
|
-
}
|
1086
|
+
const int64_t n = ggml_nelements(dst);
|
1479
1087
|
|
1480
|
-
|
1088
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1089
|
+
} break;
|
1090
|
+
case GGML_UNARY_OP_GELU:
|
1091
|
+
{
|
1092
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
|
1481
1093
|
|
1482
|
-
|
1483
|
-
|
1484
|
-
[encoder setBuffer:
|
1485
|
-
} else {
|
1486
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
1487
|
-
}
|
1488
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1489
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
1490
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1491
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1492
|
-
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
1493
|
-
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
1494
|
-
|
1495
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1496
|
-
} break;
|
1497
|
-
case GGML_OP_DIAG_MASK_INF:
|
1498
|
-
{
|
1499
|
-
const int n_past = ((int32_t *)(dst->op_params))[0];
|
1500
|
-
|
1501
|
-
if (ne00%8 == 0) {
|
1502
|
-
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
|
1503
|
-
} else {
|
1504
|
-
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
1505
|
-
}
|
1506
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1507
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1508
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
1509
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
1510
|
-
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
1094
|
+
[encoder setComputePipelineState:pipeline];
|
1095
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1096
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1511
1097
|
|
1512
|
-
|
1513
|
-
|
1514
|
-
}
|
1515
|
-
else {
|
1516
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1517
|
-
}
|
1518
|
-
} break;
|
1519
|
-
case GGML_OP_MUL_MAT:
|
1520
|
-
{
|
1521
|
-
GGML_ASSERT(ne00 == ne10);
|
1098
|
+
const int64_t n = ggml_nelements(dst);
|
1099
|
+
GGML_ASSERT(n % 4 == 0);
|
1522
1100
|
|
1523
|
-
|
1524
|
-
|
1525
|
-
|
1101
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1102
|
+
} break;
|
1103
|
+
case GGML_UNARY_OP_GELU_QUICK:
|
1104
|
+
{
|
1105
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
|
1526
1106
|
|
1527
|
-
|
1528
|
-
|
1107
|
+
[encoder setComputePipelineState:pipeline];
|
1108
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1109
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1529
1110
|
|
1530
|
-
|
1531
|
-
|
1532
|
-
|
1111
|
+
const int64_t n = ggml_nelements(dst);
|
1112
|
+
GGML_ASSERT(n % 4 == 0);
|
1113
|
+
|
1114
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1115
|
+
} break;
|
1116
|
+
case GGML_UNARY_OP_SILU:
|
1117
|
+
{
|
1118
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
|
1119
|
+
|
1120
|
+
[encoder setComputePipelineState:pipeline];
|
1121
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1122
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1123
|
+
|
1124
|
+
const int64_t n = ggml_nelements(dst);
|
1125
|
+
GGML_ASSERT(n % 4 == 0);
|
1126
|
+
|
1127
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1128
|
+
} break;
|
1129
|
+
default:
|
1130
|
+
{
|
1131
|
+
GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
1132
|
+
GGML_ASSERT(false);
|
1133
|
+
}
|
1134
|
+
} break;
|
1135
|
+
case GGML_OP_SQR:
|
1136
|
+
{
|
1137
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
1138
|
+
|
1139
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
|
1140
|
+
|
1141
|
+
[encoder setComputePipelineState:pipeline];
|
1142
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1143
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1144
|
+
|
1145
|
+
const int64_t n = ggml_nelements(dst);
|
1146
|
+
|
1147
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1148
|
+
} break;
|
1149
|
+
case GGML_OP_SUM_ROWS:
|
1150
|
+
{
|
1151
|
+
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
1152
|
+
|
1153
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
1154
|
+
|
1155
|
+
[encoder setComputePipelineState:pipeline];
|
1156
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1157
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1158
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
1159
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
1160
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
1161
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
1162
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
1163
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
1164
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
1165
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
1166
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
1167
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
1168
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
1169
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
1170
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
1171
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
1172
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
1173
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
1174
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
|
1175
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
|
1176
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
|
1177
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
|
1178
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
|
1179
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
|
1180
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
|
1181
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
|
1182
|
+
|
1183
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1184
|
+
} break;
|
1185
|
+
case GGML_OP_SOFT_MAX:
|
1186
|
+
{
|
1187
|
+
int nth = 32; // SIMD width
|
1188
|
+
|
1189
|
+
id<MTLComputePipelineState> pipeline = nil;
|
1190
|
+
|
1191
|
+
if (ne00%4 == 0) {
|
1192
|
+
while (nth < ne00/4 && nth < 256) {
|
1193
|
+
nth *= 2;
|
1194
|
+
}
|
1195
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
|
1196
|
+
} else {
|
1197
|
+
while (nth < ne00 && nth < 1024) {
|
1198
|
+
nth *= 2;
|
1199
|
+
}
|
1200
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
|
1201
|
+
}
|
1202
|
+
|
1203
|
+
const float scale = ((float *) dst->op_params)[0];
|
1204
|
+
|
1205
|
+
[encoder setComputePipelineState:pipeline];
|
1206
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1207
|
+
if (id_src1) {
|
1208
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1209
|
+
} else {
|
1210
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
1211
|
+
}
|
1212
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1213
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
1214
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1215
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1216
|
+
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
1217
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
1218
|
+
|
1219
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1220
|
+
} break;
|
1221
|
+
case GGML_OP_DIAG_MASK_INF:
|
1222
|
+
{
|
1223
|
+
const int n_past = ((int32_t *)(dst->op_params))[0];
|
1224
|
+
|
1225
|
+
id<MTLComputePipelineState> pipeline = nil;
|
1226
|
+
|
1227
|
+
if (ne00%8 == 0) {
|
1228
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
|
1229
|
+
} else {
|
1230
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
|
1231
|
+
}
|
1232
|
+
|
1233
|
+
[encoder setComputePipelineState:pipeline];
|
1234
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1235
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1236
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
1237
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
1238
|
+
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
1239
|
+
|
1240
|
+
if (ne00%8 == 0) {
|
1241
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1242
|
+
}
|
1243
|
+
else {
|
1244
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1245
|
+
}
|
1246
|
+
} break;
|
1247
|
+
case GGML_OP_MUL_MAT:
|
1248
|
+
{
|
1249
|
+
GGML_ASSERT(ne00 == ne10);
|
1250
|
+
|
1251
|
+
// TODO: assert that dim2 and dim3 are contiguous
|
1252
|
+
GGML_ASSERT(ne12 % ne02 == 0);
|
1253
|
+
GGML_ASSERT(ne13 % ne03 == 0);
|
1254
|
+
|
1255
|
+
const uint r2 = ne12/ne02;
|
1256
|
+
const uint r3 = ne13/ne03;
|
1257
|
+
|
1258
|
+
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
1259
|
+
// to the matrix-vector kernel
|
1260
|
+
int ne11_mm_min = 1;
|
1533
1261
|
|
1534
1262
|
#if 0
|
1535
|
-
|
1536
|
-
|
1537
|
-
|
1538
|
-
|
1539
|
-
|
1540
|
-
|
1541
|
-
|
1542
|
-
|
1543
|
-
|
1544
|
-
|
1545
|
-
|
1546
|
-
|
1547
|
-
|
1548
|
-
|
1549
|
-
|
1550
|
-
|
1551
|
-
|
1552
|
-
}
|
1263
|
+
// the numbers below are measured on M2 Ultra for 7B and 13B models
|
1264
|
+
// these numbers do not translate to other devices or model sizes
|
1265
|
+
// TODO: need to find a better approach
|
1266
|
+
if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
|
1267
|
+
switch (src0t) {
|
1268
|
+
case GGML_TYPE_F16: ne11_mm_min = 2; break;
|
1269
|
+
case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
|
1270
|
+
case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
|
1271
|
+
case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
|
1272
|
+
case GGML_TYPE_Q4_0:
|
1273
|
+
case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
|
1274
|
+
case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
|
1275
|
+
case GGML_TYPE_Q5_0: // not tested yet
|
1276
|
+
case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
|
1277
|
+
case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
|
1278
|
+
case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
|
1279
|
+
default: ne11_mm_min = 1; break;
|
1553
1280
|
}
|
1281
|
+
}
|
1554
1282
|
#endif
|
1555
1283
|
|
1556
|
-
|
1557
|
-
|
1558
|
-
|
1559
|
-
|
1560
|
-
|
1561
|
-
|
1562
|
-
|
1563
|
-
|
1564
|
-
|
1565
|
-
|
1566
|
-
|
1567
|
-
|
1568
|
-
|
1569
|
-
|
1570
|
-
|
1571
|
-
|
1572
|
-
|
1573
|
-
|
1574
|
-
|
1575
|
-
|
1576
|
-
|
1577
|
-
|
1578
|
-
|
1579
|
-
|
1580
|
-
|
1581
|
-
|
1582
|
-
|
1583
|
-
|
1584
|
-
|
1585
|
-
|
1586
|
-
|
1587
|
-
|
1588
|
-
|
1589
|
-
|
1590
|
-
|
1591
|
-
|
1592
|
-
|
1593
|
-
|
1594
|
-
|
1595
|
-
|
1596
|
-
|
1597
|
-
|
1598
|
-
|
1599
|
-
|
1600
|
-
|
1601
|
-
|
1602
|
-
|
1603
|
-
|
1604
|
-
|
1605
|
-
|
1606
|
-
|
1607
|
-
|
1608
|
-
|
1609
|
-
|
1610
|
-
|
1611
|
-
|
1612
|
-
|
1613
|
-
|
1614
|
-
|
1615
|
-
|
1616
|
-
|
1617
|
-
|
1618
|
-
|
1619
|
-
|
1620
|
-
|
1621
|
-
|
1622
|
-
|
1623
|
-
|
1624
|
-
|
1625
|
-
|
1626
|
-
|
1284
|
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
1285
|
+
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
1286
|
+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
1287
|
+
!ggml_is_transposed(src0) &&
|
1288
|
+
!ggml_is_transposed(src1) &&
|
1289
|
+
src1t == GGML_TYPE_F32 &&
|
1290
|
+
ne00 % 32 == 0 && ne00 >= 64 &&
|
1291
|
+
(ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
|
1292
|
+
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
1293
|
+
|
1294
|
+
id<MTLComputePipelineState> pipeline = nil;
|
1295
|
+
|
1296
|
+
switch (src0->type) {
|
1297
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
|
1298
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
|
1299
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
|
1300
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
|
1301
|
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
|
1302
|
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
|
1303
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
|
1304
|
+
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
|
1305
|
+
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
|
1306
|
+
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
|
1307
|
+
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
|
1308
|
+
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
|
1309
|
+
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
1310
|
+
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
1311
|
+
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
1312
|
+
}
|
1313
|
+
|
1314
|
+
[encoder setComputePipelineState:pipeline];
|
1315
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1316
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1317
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1318
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
1319
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
1320
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
1321
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
1322
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
1323
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
|
1324
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
|
1325
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
1326
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
1327
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
1328
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
|
1329
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
|
1330
|
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
1331
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
1332
|
+
} else {
|
1333
|
+
int nth0 = 32;
|
1334
|
+
int nth1 = 1;
|
1335
|
+
int nrows = 1;
|
1336
|
+
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
1337
|
+
|
1338
|
+
id<MTLComputePipelineState> pipeline = nil;
|
1339
|
+
|
1340
|
+
// use custom matrix x vector kernel
|
1341
|
+
switch (src0t) {
|
1342
|
+
case GGML_TYPE_F32:
|
1343
|
+
{
|
1344
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1345
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
|
1346
|
+
nrows = 4;
|
1347
|
+
} break;
|
1348
|
+
case GGML_TYPE_F16:
|
1349
|
+
{
|
1350
|
+
nth0 = 32;
|
1351
|
+
nth1 = 1;
|
1352
|
+
if (src1t == GGML_TYPE_F32) {
|
1353
|
+
if (ne11 * ne12 < 4) {
|
1354
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
1355
|
+
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
1356
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
1357
|
+
nrows = ne11;
|
1627
1358
|
} else {
|
1628
|
-
|
1359
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
|
1629
1360
|
nrows = 4;
|
1630
1361
|
}
|
1631
|
-
}
|
1632
|
-
|
1633
|
-
|
1634
|
-
nth0 = 8;
|
1635
|
-
nth1 = 8;
|
1636
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
|
1637
|
-
} break;
|
1638
|
-
case GGML_TYPE_Q4_1:
|
1639
|
-
{
|
1640
|
-
nth0 = 8;
|
1641
|
-
nth1 = 8;
|
1642
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
|
1643
|
-
} break;
|
1644
|
-
case GGML_TYPE_Q5_0:
|
1645
|
-
{
|
1646
|
-
nth0 = 8;
|
1647
|
-
nth1 = 8;
|
1648
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
|
1649
|
-
} break;
|
1650
|
-
case GGML_TYPE_Q5_1:
|
1651
|
-
{
|
1652
|
-
nth0 = 8;
|
1653
|
-
nth1 = 8;
|
1654
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
|
1655
|
-
} break;
|
1656
|
-
case GGML_TYPE_Q8_0:
|
1657
|
-
{
|
1658
|
-
nth0 = 8;
|
1659
|
-
nth1 = 8;
|
1660
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
|
1661
|
-
} break;
|
1662
|
-
case GGML_TYPE_Q2_K:
|
1663
|
-
{
|
1664
|
-
nth0 = 2;
|
1665
|
-
nth1 = 32;
|
1666
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
|
1667
|
-
} break;
|
1668
|
-
case GGML_TYPE_Q3_K:
|
1669
|
-
{
|
1670
|
-
nth0 = 2;
|
1671
|
-
nth1 = 32;
|
1672
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
|
1673
|
-
} break;
|
1674
|
-
case GGML_TYPE_Q4_K:
|
1675
|
-
{
|
1676
|
-
nth0 = 4; //1;
|
1677
|
-
nth1 = 8; //32;
|
1678
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
|
1679
|
-
} break;
|
1680
|
-
case GGML_TYPE_Q5_K:
|
1681
|
-
{
|
1682
|
-
nth0 = 2;
|
1683
|
-
nth1 = 32;
|
1684
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
|
1685
|
-
} break;
|
1686
|
-
case GGML_TYPE_Q6_K:
|
1687
|
-
{
|
1688
|
-
nth0 = 2;
|
1689
|
-
nth1 = 32;
|
1690
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
|
1691
|
-
} break;
|
1692
|
-
case GGML_TYPE_IQ2_XXS:
|
1693
|
-
{
|
1694
|
-
nth0 = 4;
|
1695
|
-
nth1 = 16;
|
1696
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
|
1697
|
-
} break;
|
1698
|
-
case GGML_TYPE_IQ2_XS:
|
1699
|
-
{
|
1700
|
-
nth0 = 4;
|
1701
|
-
nth1 = 16;
|
1702
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xs_f32];
|
1703
|
-
} break;
|
1704
|
-
default:
|
1705
|
-
{
|
1706
|
-
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
1707
|
-
GGML_ASSERT(false && "not implemented");
|
1362
|
+
} else {
|
1363
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
|
1364
|
+
nrows = 4;
|
1708
1365
|
}
|
1709
|
-
|
1710
|
-
|
1711
|
-
|
1712
|
-
|
1713
|
-
|
1366
|
+
} break;
|
1367
|
+
case GGML_TYPE_Q4_0:
|
1368
|
+
{
|
1369
|
+
nth0 = 8;
|
1370
|
+
nth1 = 8;
|
1371
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
|
1372
|
+
} break;
|
1373
|
+
case GGML_TYPE_Q4_1:
|
1374
|
+
{
|
1375
|
+
nth0 = 8;
|
1376
|
+
nth1 = 8;
|
1377
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
|
1378
|
+
} break;
|
1379
|
+
case GGML_TYPE_Q5_0:
|
1380
|
+
{
|
1381
|
+
nth0 = 8;
|
1382
|
+
nth1 = 8;
|
1383
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
|
1384
|
+
} break;
|
1385
|
+
case GGML_TYPE_Q5_1:
|
1386
|
+
{
|
1387
|
+
nth0 = 8;
|
1388
|
+
nth1 = 8;
|
1389
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
|
1390
|
+
} break;
|
1391
|
+
case GGML_TYPE_Q8_0:
|
1392
|
+
{
|
1393
|
+
nth0 = 8;
|
1394
|
+
nth1 = 8;
|
1395
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
|
1396
|
+
} break;
|
1397
|
+
case GGML_TYPE_Q2_K:
|
1398
|
+
{
|
1399
|
+
nth0 = 2;
|
1400
|
+
nth1 = 32;
|
1401
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
|
1402
|
+
} break;
|
1403
|
+
case GGML_TYPE_Q3_K:
|
1404
|
+
{
|
1405
|
+
nth0 = 2;
|
1406
|
+
nth1 = 32;
|
1407
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
|
1408
|
+
} break;
|
1409
|
+
case GGML_TYPE_Q4_K:
|
1410
|
+
{
|
1411
|
+
nth0 = 4; //1;
|
1412
|
+
nth1 = 8; //32;
|
1413
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
|
1414
|
+
} break;
|
1415
|
+
case GGML_TYPE_Q5_K:
|
1416
|
+
{
|
1417
|
+
nth0 = 2;
|
1418
|
+
nth1 = 32;
|
1419
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
|
1420
|
+
} break;
|
1421
|
+
case GGML_TYPE_Q6_K:
|
1422
|
+
{
|
1423
|
+
nth0 = 2;
|
1424
|
+
nth1 = 32;
|
1425
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
|
1426
|
+
} break;
|
1427
|
+
case GGML_TYPE_IQ2_XXS:
|
1428
|
+
{
|
1429
|
+
nth0 = 4;
|
1430
|
+
nth1 = 16;
|
1431
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
|
1432
|
+
} break;
|
1433
|
+
case GGML_TYPE_IQ2_XS:
|
1434
|
+
{
|
1435
|
+
nth0 = 4;
|
1436
|
+
nth1 = 16;
|
1437
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
|
1438
|
+
} break;
|
1439
|
+
default:
|
1440
|
+
{
|
1441
|
+
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
1442
|
+
GGML_ASSERT(false && "not implemented");
|
1443
|
+
}
|
1444
|
+
};
|
1714
1445
|
|
1715
|
-
|
1716
|
-
|
1717
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1718
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
1719
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1720
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1721
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
1722
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
1723
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
1724
|
-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
|
1725
|
-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
|
1726
|
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
|
1727
|
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
|
1728
|
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
|
1729
|
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
1730
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
1731
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
1732
|
-
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
|
1733
|
-
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
|
1734
|
-
|
1735
|
-
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
1736
|
-
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
1737
|
-
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
|
1738
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1739
|
-
}
|
1740
|
-
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
1741
|
-
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
1742
|
-
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1743
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1744
|
-
}
|
1745
|
-
else if (src0t == GGML_TYPE_Q4_K) {
|
1746
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1747
|
-
}
|
1748
|
-
else if (src0t == GGML_TYPE_Q3_K) {
|
1749
|
-
#ifdef GGML_QKK_64
|
1750
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1751
|
-
#else
|
1752
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1753
|
-
#endif
|
1754
|
-
}
|
1755
|
-
else if (src0t == GGML_TYPE_Q5_K) {
|
1756
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1757
|
-
}
|
1758
|
-
else if (src0t == GGML_TYPE_Q6_K) {
|
1759
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1760
|
-
} else {
|
1761
|
-
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
1762
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1763
|
-
}
|
1446
|
+
if (ggml_is_quantized(src0t)) {
|
1447
|
+
GGML_ASSERT(ne00 >= nth0*nth1);
|
1764
1448
|
}
|
1765
|
-
} break;
|
1766
|
-
case GGML_OP_MUL_MAT_ID:
|
1767
|
-
{
|
1768
|
-
//GGML_ASSERT(ne00 == ne10);
|
1769
|
-
//GGML_ASSERT(ne03 == ne13);
|
1770
|
-
|
1771
|
-
GGML_ASSERT(src0t == GGML_TYPE_I32);
|
1772
|
-
|
1773
|
-
const int n_as = ((int32_t *) dst->op_params)[1];
|
1774
|
-
|
1775
|
-
// TODO: make this more general
|
1776
|
-
GGML_ASSERT(n_as <= 8);
|
1777
|
-
|
1778
|
-
// max size of the src1ids array in the kernel stack
|
1779
|
-
GGML_ASSERT(ne11 <= 512);
|
1780
|
-
|
1781
|
-
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
1782
|
-
|
1783
|
-
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
1784
|
-
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
1785
|
-
const int64_t ne22 = src2 ? src2->ne[2] : 0;
|
1786
|
-
const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
|
1787
|
-
|
1788
|
-
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
1789
|
-
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
1790
|
-
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
1791
|
-
const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
|
1792
|
-
|
1793
|
-
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
|
1794
|
-
|
1795
|
-
GGML_ASSERT(!ggml_is_transposed(src2));
|
1796
|
-
GGML_ASSERT(!ggml_is_transposed(src1));
|
1797
|
-
|
1798
|
-
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1799
|
-
|
1800
|
-
const uint r2 = ne12/ne22;
|
1801
|
-
const uint r3 = ne13/ne23;
|
1802
|
-
|
1803
|
-
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
1804
|
-
// to the matrix-vector kernel
|
1805
|
-
int ne11_mm_min = n_as;
|
1806
|
-
|
1807
|
-
const int idx = ((int32_t *) dst->op_params)[0];
|
1808
|
-
|
1809
|
-
// batch size
|
1810
|
-
GGML_ASSERT(ne01 == ne11);
|
1811
|
-
|
1812
|
-
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
1813
|
-
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
1814
|
-
// !!!
|
1815
|
-
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
|
1816
|
-
// indirect matrix multiplication
|
1817
|
-
// !!!
|
1818
|
-
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
1819
|
-
ne20 % 32 == 0 && ne20 >= 64 &&
|
1820
|
-
ne11 > ne11_mm_min) {
|
1821
|
-
switch (src2->type) {
|
1822
|
-
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
|
1823
|
-
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
|
1824
|
-
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
|
1825
|
-
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
|
1826
|
-
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
|
1827
|
-
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
|
1828
|
-
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
|
1829
|
-
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
|
1830
|
-
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
|
1831
|
-
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
|
1832
|
-
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
|
1833
|
-
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
|
1834
|
-
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
|
1835
|
-
case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xs_f32]; break;
|
1836
|
-
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
1837
|
-
}
|
1838
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1839
|
-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1840
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1841
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
1842
|
-
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
1843
|
-
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
|
1844
|
-
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
1845
|
-
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
|
1846
|
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
|
1847
|
-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
|
1848
|
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
|
1849
|
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
|
1850
|
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
|
1851
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
1852
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
1853
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
1854
|
-
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
|
1855
|
-
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
1856
|
-
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
|
1857
|
-
// TODO: how to make this an array? read Metal docs
|
1858
|
-
for (int j = 0; j < 8; ++j) {
|
1859
|
-
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
1860
|
-
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
1861
|
-
|
1862
|
-
size_t offs_src_cur = 0;
|
1863
|
-
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
1864
|
-
|
1865
|
-
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
1866
|
-
}
|
1867
|
-
|
1868
|
-
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
1869
|
-
|
1870
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
1871
|
-
} else {
|
1872
|
-
int nth0 = 32;
|
1873
|
-
int nth1 = 1;
|
1874
|
-
int nrows = 1;
|
1875
|
-
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
1876
|
-
|
1877
|
-
// use custom matrix x vector kernel
|
1878
|
-
switch (src2t) {
|
1879
|
-
case GGML_TYPE_F32:
|
1880
|
-
{
|
1881
|
-
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1882
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
|
1883
|
-
} break;
|
1884
|
-
case GGML_TYPE_F16:
|
1885
|
-
{
|
1886
|
-
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1887
|
-
nth0 = 32;
|
1888
|
-
nth1 = 1;
|
1889
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
|
1890
|
-
} break;
|
1891
|
-
case GGML_TYPE_Q4_0:
|
1892
|
-
{
|
1893
|
-
nth0 = 8;
|
1894
|
-
nth1 = 8;
|
1895
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
|
1896
|
-
} break;
|
1897
|
-
case GGML_TYPE_Q4_1:
|
1898
|
-
{
|
1899
|
-
nth0 = 8;
|
1900
|
-
nth1 = 8;
|
1901
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
|
1902
|
-
} break;
|
1903
|
-
case GGML_TYPE_Q5_0:
|
1904
|
-
{
|
1905
|
-
nth0 = 8;
|
1906
|
-
nth1 = 8;
|
1907
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
|
1908
|
-
} break;
|
1909
|
-
case GGML_TYPE_Q5_1:
|
1910
|
-
{
|
1911
|
-
nth0 = 8;
|
1912
|
-
nth1 = 8;
|
1913
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
|
1914
|
-
} break;
|
1915
|
-
case GGML_TYPE_Q8_0:
|
1916
|
-
{
|
1917
|
-
nth0 = 8;
|
1918
|
-
nth1 = 8;
|
1919
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
|
1920
|
-
} break;
|
1921
|
-
case GGML_TYPE_Q2_K:
|
1922
|
-
{
|
1923
|
-
nth0 = 2;
|
1924
|
-
nth1 = 32;
|
1925
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
|
1926
|
-
} break;
|
1927
|
-
case GGML_TYPE_Q3_K:
|
1928
|
-
{
|
1929
|
-
nth0 = 2;
|
1930
|
-
nth1 = 32;
|
1931
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
|
1932
|
-
} break;
|
1933
|
-
case GGML_TYPE_Q4_K:
|
1934
|
-
{
|
1935
|
-
nth0 = 4; //1;
|
1936
|
-
nth1 = 8; //32;
|
1937
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
|
1938
|
-
} break;
|
1939
|
-
case GGML_TYPE_Q5_K:
|
1940
|
-
{
|
1941
|
-
nth0 = 2;
|
1942
|
-
nth1 = 32;
|
1943
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
|
1944
|
-
} break;
|
1945
|
-
case GGML_TYPE_Q6_K:
|
1946
|
-
{
|
1947
|
-
nth0 = 2;
|
1948
|
-
nth1 = 32;
|
1949
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
|
1950
|
-
} break;
|
1951
|
-
case GGML_TYPE_IQ2_XXS:
|
1952
|
-
{
|
1953
|
-
nth0 = 4;
|
1954
|
-
nth1 = 16;
|
1955
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
|
1956
|
-
} break;
|
1957
|
-
case GGML_TYPE_IQ2_XS:
|
1958
|
-
{
|
1959
|
-
nth0 = 4;
|
1960
|
-
nth1 = 16;
|
1961
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xs_f32];
|
1962
|
-
} break;
|
1963
|
-
default:
|
1964
|
-
{
|
1965
|
-
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
1966
|
-
GGML_ASSERT(false && "not implemented");
|
1967
|
-
}
|
1968
|
-
};
|
1969
|
-
|
1970
|
-
if (ggml_is_quantized(src2t)) {
|
1971
|
-
GGML_ASSERT(ne20 >= nth0*nth1);
|
1972
|
-
}
|
1973
1449
|
|
1974
|
-
|
1975
|
-
|
1976
|
-
|
1977
|
-
|
1978
|
-
|
1979
|
-
|
1980
|
-
|
1981
|
-
|
1982
|
-
|
1983
|
-
|
1984
|
-
|
1985
|
-
|
1986
|
-
|
1987
|
-
|
1988
|
-
|
1989
|
-
|
1990
|
-
|
1991
|
-
|
1992
|
-
|
1993
|
-
|
1994
|
-
|
1995
|
-
|
1996
|
-
|
1997
|
-
|
1998
|
-
[encoder
|
1999
|
-
|
2000
|
-
|
2001
|
-
|
2002
|
-
|
2003
|
-
|
2004
|
-
|
2005
|
-
|
2006
|
-
|
2007
|
-
|
2008
|
-
|
2009
|
-
|
2010
|
-
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
2011
|
-
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
2012
|
-
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
|
2013
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2014
|
-
}
|
2015
|
-
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
|
2016
|
-
const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
2017
|
-
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
2018
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2019
|
-
}
|
2020
|
-
else if (src2t == GGML_TYPE_Q4_K) {
|
2021
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2022
|
-
}
|
2023
|
-
else if (src2t == GGML_TYPE_Q3_K) {
|
1450
|
+
[encoder setComputePipelineState:pipeline];
|
1451
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1452
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1453
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1454
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
1455
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1456
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1457
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
1458
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
1459
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
1460
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
|
1461
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
|
1462
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
|
1463
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
|
1464
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
|
1465
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
1466
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
1467
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
1468
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
|
1469
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
|
1470
|
+
|
1471
|
+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
1472
|
+
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
1473
|
+
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
|
1474
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1475
|
+
}
|
1476
|
+
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
1477
|
+
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
1478
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1479
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1480
|
+
}
|
1481
|
+
else if (src0t == GGML_TYPE_Q4_K) {
|
1482
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1483
|
+
}
|
1484
|
+
else if (src0t == GGML_TYPE_Q3_K) {
|
2024
1485
|
#ifdef GGML_QKK_64
|
2025
|
-
|
1486
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2026
1487
|
#else
|
2027
|
-
|
1488
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2028
1489
|
#endif
|
2029
|
-
}
|
2030
|
-
else if (src2t == GGML_TYPE_Q5_K) {
|
2031
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2032
|
-
}
|
2033
|
-
else if (src2t == GGML_TYPE_Q6_K) {
|
2034
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2035
|
-
} else {
|
2036
|
-
const int64_t ny = (_ne1 + nrows - 1)/nrows;
|
2037
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2038
|
-
}
|
2039
1490
|
}
|
2040
|
-
|
2041
|
-
|
2042
|
-
{
|
2043
|
-
switch (src0->type) {
|
2044
|
-
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
|
2045
|
-
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
2046
|
-
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
2047
|
-
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
2048
|
-
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
|
2049
|
-
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
|
2050
|
-
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
|
2051
|
-
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
|
2052
|
-
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
|
2053
|
-
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
|
2054
|
-
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
|
2055
|
-
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
|
2056
|
-
case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
|
2057
|
-
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
|
2058
|
-
case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xs]; break;
|
2059
|
-
default: GGML_ASSERT(false && "not implemented");
|
1491
|
+
else if (src0t == GGML_TYPE_Q5_K) {
|
1492
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2060
1493
|
}
|
2061
|
-
|
2062
|
-
|
2063
|
-
|
2064
|
-
|
2065
|
-
|
2066
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
2067
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
2068
|
-
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
2069
|
-
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
|
2070
|
-
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
|
2071
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
|
2072
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
|
2073
|
-
|
2074
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
2075
|
-
} break;
|
2076
|
-
case GGML_OP_RMS_NORM:
|
2077
|
-
{
|
2078
|
-
GGML_ASSERT(ne00 % 4 == 0);
|
2079
|
-
|
2080
|
-
float eps;
|
2081
|
-
memcpy(&eps, dst->op_params, sizeof(float));
|
2082
|
-
|
2083
|
-
int nth = 32; // SIMD width
|
2084
|
-
|
2085
|
-
while (nth < ne00/4 && nth < 1024) {
|
2086
|
-
nth *= 2;
|
1494
|
+
else if (src0t == GGML_TYPE_Q6_K) {
|
1495
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1496
|
+
} else {
|
1497
|
+
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
1498
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2087
1499
|
}
|
1500
|
+
}
|
1501
|
+
} break;
|
1502
|
+
case GGML_OP_MUL_MAT_ID:
|
1503
|
+
{
|
1504
|
+
//GGML_ASSERT(ne00 == ne10);
|
1505
|
+
//GGML_ASSERT(ne03 == ne13);
|
2088
1506
|
|
2089
|
-
|
2090
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2091
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2092
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
2093
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
2094
|
-
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
2095
|
-
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
2096
|
-
|
2097
|
-
const int64_t nrows = ggml_nrows(src0);
|
2098
|
-
|
2099
|
-
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2100
|
-
} break;
|
2101
|
-
case GGML_OP_GROUP_NORM:
|
2102
|
-
{
|
2103
|
-
GGML_ASSERT(ne00 % 4 == 0);
|
2104
|
-
|
2105
|
-
//float eps;
|
2106
|
-
//memcpy(&eps, dst->op_params, sizeof(float));
|
2107
|
-
|
2108
|
-
const float eps = 1e-6f; // TODO: temporarily hardcoded
|
2109
|
-
|
2110
|
-
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
|
2111
|
-
|
2112
|
-
int nth = 32; // SIMD width
|
2113
|
-
|
2114
|
-
//while (nth < ne00/4 && nth < 1024) {
|
2115
|
-
// nth *= 2;
|
2116
|
-
//}
|
2117
|
-
|
2118
|
-
[encoder setComputePipelineState:ctx->pipeline_group_norm];
|
2119
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2120
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2121
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
2122
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
2123
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
2124
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
|
2125
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
|
2126
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
|
2127
|
-
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
|
2128
|
-
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
|
2129
|
-
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
2130
|
-
|
2131
|
-
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2132
|
-
} break;
|
2133
|
-
case GGML_OP_NORM:
|
2134
|
-
{
|
2135
|
-
float eps;
|
2136
|
-
memcpy(&eps, dst->op_params, sizeof(float));
|
2137
|
-
|
2138
|
-
const int nth = MIN(256, ne00);
|
2139
|
-
|
2140
|
-
[encoder setComputePipelineState:ctx->pipeline_norm];
|
2141
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2142
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2143
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
2144
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
2145
|
-
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
2146
|
-
[encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
|
2147
|
-
|
2148
|
-
const int64_t nrows = ggml_nrows(src0);
|
2149
|
-
|
2150
|
-
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2151
|
-
} break;
|
2152
|
-
case GGML_OP_ALIBI:
|
2153
|
-
{
|
2154
|
-
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
2155
|
-
|
2156
|
-
const int nth = MIN(1024, ne00);
|
2157
|
-
|
2158
|
-
//const int n_past = ((int32_t *) dst->op_params)[0];
|
2159
|
-
const int n_head = ((int32_t *) dst->op_params)[1];
|
2160
|
-
float max_bias;
|
2161
|
-
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
1507
|
+
GGML_ASSERT(src0t == GGML_TYPE_I32);
|
2162
1508
|
|
2163
|
-
|
2164
|
-
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
2165
|
-
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
1509
|
+
const int n_as = ((int32_t *) dst->op_params)[1];
|
2166
1510
|
|
2167
|
-
|
2168
|
-
|
2169
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2170
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
2171
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
2172
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
2173
|
-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
2174
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
2175
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
2176
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
2177
|
-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
2178
|
-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
2179
|
-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
2180
|
-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
2181
|
-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
2182
|
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
2183
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
2184
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
2185
|
-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
2186
|
-
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
2187
|
-
[encoder setBytes:&m1 length:sizeof( float) atIndex:19];
|
2188
|
-
[encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
|
1511
|
+
// TODO: make this more general
|
1512
|
+
GGML_ASSERT(n_as <= 8);
|
2189
1513
|
|
2190
|
-
|
2191
|
-
|
2192
|
-
case GGML_OP_ROPE:
|
2193
|
-
{
|
2194
|
-
GGML_ASSERT(ne10 == ne02);
|
2195
|
-
|
2196
|
-
const int nth = MIN(1024, ne00);
|
2197
|
-
|
2198
|
-
const int n_past = ((int32_t *) dst->op_params)[0];
|
2199
|
-
const int n_dims = ((int32_t *) dst->op_params)[1];
|
2200
|
-
const int mode = ((int32_t *) dst->op_params)[2];
|
2201
|
-
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
2202
|
-
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
2203
|
-
|
2204
|
-
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
2205
|
-
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
2206
|
-
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
2207
|
-
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
2208
|
-
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
2209
|
-
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
2210
|
-
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
1514
|
+
// max size of the src1ids array in the kernel stack
|
1515
|
+
GGML_ASSERT(ne11 <= 512);
|
2211
1516
|
|
2212
|
-
|
2213
|
-
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
|
2214
|
-
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
|
2215
|
-
default: GGML_ASSERT(false);
|
2216
|
-
};
|
1517
|
+
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
2217
1518
|
|
2218
|
-
|
2219
|
-
|
2220
|
-
|
2221
|
-
|
2222
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
|
2223
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
|
2224
|
-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
|
2225
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
|
2226
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
2227
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
2228
|
-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
2229
|
-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
|
2230
|
-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
|
2231
|
-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
|
2232
|
-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
|
2233
|
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
|
2234
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
|
2235
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
|
2236
|
-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
|
2237
|
-
[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
|
2238
|
-
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
|
2239
|
-
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
|
2240
|
-
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
|
2241
|
-
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
|
2242
|
-
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
|
2243
|
-
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
|
2244
|
-
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
|
2245
|
-
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
|
2246
|
-
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
|
1519
|
+
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
1520
|
+
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
1521
|
+
const int64_t ne22 = src2 ? src2->ne[2] : 0;
|
1522
|
+
const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
|
2247
1523
|
|
2248
|
-
|
2249
|
-
|
2250
|
-
|
2251
|
-
|
2252
|
-
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
2253
|
-
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
2254
|
-
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
1524
|
+
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
1525
|
+
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
1526
|
+
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
1527
|
+
const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
|
2255
1528
|
|
2256
|
-
|
2257
|
-
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
2258
|
-
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
|
2259
|
-
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
2260
|
-
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
2261
|
-
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
2262
|
-
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
1529
|
+
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
|
2263
1530
|
|
2264
|
-
|
2265
|
-
|
2266
|
-
const int32_t IH = is_2D ? src1->ne[1] : 1;
|
2267
|
-
const int32_t IW = src1->ne[0];
|
1531
|
+
GGML_ASSERT(!ggml_is_transposed(src2));
|
1532
|
+
GGML_ASSERT(!ggml_is_transposed(src1));
|
2268
1533
|
|
2269
|
-
|
2270
|
-
const int32_t KW = src0->ne[0];
|
1534
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
2271
1535
|
|
2272
|
-
|
2273
|
-
|
1536
|
+
const uint r2 = ne12/ne22;
|
1537
|
+
const uint r3 = ne13/ne23;
|
2274
1538
|
|
2275
|
-
|
1539
|
+
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
1540
|
+
// to the matrix-vector kernel
|
1541
|
+
int ne11_mm_min = n_as;
|
2276
1542
|
|
2277
|
-
|
2278
|
-
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
1543
|
+
const int idx = ((int32_t *) dst->op_params)[0];
|
2279
1544
|
|
2280
|
-
|
2281
|
-
|
2282
|
-
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
|
2283
|
-
default: GGML_ASSERT(false);
|
2284
|
-
};
|
1545
|
+
// batch size
|
1546
|
+
GGML_ASSERT(ne01 == ne11);
|
2285
1547
|
|
2286
|
-
|
2287
|
-
|
2288
|
-
|
2289
|
-
|
2290
|
-
|
2291
|
-
|
2292
|
-
|
2293
|
-
|
2294
|
-
|
2295
|
-
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
|
2296
|
-
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
|
2297
|
-
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
|
2298
|
-
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
|
2299
|
-
|
2300
|
-
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
2301
|
-
} break;
|
2302
|
-
case GGML_OP_UPSCALE:
|
2303
|
-
{
|
2304
|
-
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2305
|
-
|
2306
|
-
const int sf = dst->op_params[0];
|
2307
|
-
|
2308
|
-
[encoder setComputePipelineState:ctx->pipeline_upscale_f32];
|
2309
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2310
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2311
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
2312
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
2313
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
2314
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
2315
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
2316
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
2317
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
2318
|
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
2319
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
2320
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
2321
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
2322
|
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
2323
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
2324
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
2325
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
2326
|
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
2327
|
-
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
|
2328
|
-
|
2329
|
-
const int nth = MIN((int) ctx->pipeline_upscale_f32.maxTotalThreadsPerThreadgroup, ne0);
|
2330
|
-
|
2331
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2332
|
-
} break;
|
2333
|
-
case GGML_OP_PAD:
|
2334
|
-
{
|
2335
|
-
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2336
|
-
|
2337
|
-
[encoder setComputePipelineState:ctx->pipeline_pad_f32];
|
2338
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2339
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2340
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
2341
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
2342
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
2343
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
2344
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
2345
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
2346
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
2347
|
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
2348
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
2349
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
2350
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
2351
|
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
2352
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
2353
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
2354
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
2355
|
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
2356
|
-
|
2357
|
-
const int nth = MIN(1024, ne0);
|
2358
|
-
|
2359
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2360
|
-
} break;
|
2361
|
-
case GGML_OP_ARGSORT:
|
2362
|
-
{
|
2363
|
-
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2364
|
-
GGML_ASSERT( dst->type == GGML_TYPE_I32);
|
2365
|
-
|
2366
|
-
const int nrows = ggml_nrows(src0);
|
2367
|
-
|
2368
|
-
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
2369
|
-
|
2370
|
-
switch (order) {
|
2371
|
-
case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
|
2372
|
-
case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
|
2373
|
-
default: GGML_ASSERT(false);
|
2374
|
-
};
|
1548
|
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
1549
|
+
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
1550
|
+
// !!!
|
1551
|
+
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
|
1552
|
+
// indirect matrix multiplication
|
1553
|
+
// !!!
|
1554
|
+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
1555
|
+
ne20 % 32 == 0 && ne20 >= 64 &&
|
1556
|
+
ne11 > ne11_mm_min) {
|
2375
1557
|
|
2376
|
-
|
2377
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2378
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
2379
|
-
|
2380
|
-
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
|
2381
|
-
} break;
|
2382
|
-
case GGML_OP_LEAKY_RELU:
|
2383
|
-
{
|
2384
|
-
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
1558
|
+
id<MTLComputePipelineState> pipeline = nil;
|
2385
1559
|
|
2386
|
-
|
2387
|
-
|
1560
|
+
switch (src2->type) {
|
1561
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
|
1562
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
|
1563
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
|
1564
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
|
1565
|
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
|
1566
|
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
|
1567
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
|
1568
|
+
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
|
1569
|
+
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
|
1570
|
+
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
|
1571
|
+
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
|
1572
|
+
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
|
1573
|
+
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
|
1574
|
+
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
1575
|
+
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
1576
|
+
}
|
2388
1577
|
|
2389
|
-
[encoder setComputePipelineState:
|
2390
|
-
[encoder setBuffer:id_src0 offset:offs_src0
|
2391
|
-
[encoder setBuffer:
|
2392
|
-
[encoder
|
1578
|
+
[encoder setComputePipelineState:pipeline];
|
1579
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1580
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1581
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1582
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
1583
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
1584
|
+
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
|
1585
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
1586
|
+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
|
1587
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
|
1588
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
|
1589
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
|
1590
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
|
1591
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
|
1592
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
1593
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
1594
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
1595
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
|
1596
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
1597
|
+
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
|
1598
|
+
// TODO: how to make this an array? read Metal docs
|
1599
|
+
for (int j = 0; j < 8; ++j) {
|
1600
|
+
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
1601
|
+
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
1602
|
+
|
1603
|
+
size_t offs_src_cur = 0;
|
1604
|
+
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
1605
|
+
|
1606
|
+
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
1607
|
+
}
|
2393
1608
|
|
2394
|
-
|
1609
|
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
2395
1610
|
|
2396
|
-
[encoder dispatchThreadgroups:MTLSizeMake(
|
2397
|
-
}
|
2398
|
-
|
2399
|
-
|
2400
|
-
|
2401
|
-
|
2402
|
-
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
1611
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
1612
|
+
} else {
|
1613
|
+
int nth0 = 32;
|
1614
|
+
int nth1 = 1;
|
1615
|
+
int nrows = 1;
|
1616
|
+
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
2403
1617
|
|
2404
|
-
|
1618
|
+
id<MTLComputePipelineState> pipeline = nil;
|
2405
1619
|
|
2406
|
-
|
1620
|
+
// use custom matrix x vector kernel
|
1621
|
+
switch (src2t) {
|
2407
1622
|
case GGML_TYPE_F32:
|
2408
1623
|
{
|
2409
|
-
GGML_ASSERT(
|
2410
|
-
|
2411
|
-
switch (dstt) {
|
2412
|
-
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
|
2413
|
-
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
|
2414
|
-
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
|
2415
|
-
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
|
2416
|
-
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
|
2417
|
-
//case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
|
2418
|
-
//case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
|
2419
|
-
default: GGML_ASSERT(false && "not implemented");
|
2420
|
-
};
|
1624
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1625
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
|
2421
1626
|
} break;
|
2422
1627
|
case GGML_TYPE_F16:
|
2423
1628
|
{
|
2424
|
-
|
2425
|
-
|
2426
|
-
|
2427
|
-
|
2428
|
-
|
1629
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1630
|
+
nth0 = 32;
|
1631
|
+
nth1 = 1;
|
1632
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
|
1633
|
+
} break;
|
1634
|
+
case GGML_TYPE_Q4_0:
|
1635
|
+
{
|
1636
|
+
nth0 = 8;
|
1637
|
+
nth1 = 8;
|
1638
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
|
2429
1639
|
} break;
|
2430
|
-
|
1640
|
+
case GGML_TYPE_Q4_1:
|
1641
|
+
{
|
1642
|
+
nth0 = 8;
|
1643
|
+
nth1 = 8;
|
1644
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
|
1645
|
+
} break;
|
1646
|
+
case GGML_TYPE_Q5_0:
|
1647
|
+
{
|
1648
|
+
nth0 = 8;
|
1649
|
+
nth1 = 8;
|
1650
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
|
1651
|
+
} break;
|
1652
|
+
case GGML_TYPE_Q5_1:
|
1653
|
+
{
|
1654
|
+
nth0 = 8;
|
1655
|
+
nth1 = 8;
|
1656
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
|
1657
|
+
} break;
|
1658
|
+
case GGML_TYPE_Q8_0:
|
1659
|
+
{
|
1660
|
+
nth0 = 8;
|
1661
|
+
nth1 = 8;
|
1662
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
|
1663
|
+
} break;
|
1664
|
+
case GGML_TYPE_Q2_K:
|
1665
|
+
{
|
1666
|
+
nth0 = 2;
|
1667
|
+
nth1 = 32;
|
1668
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
|
1669
|
+
} break;
|
1670
|
+
case GGML_TYPE_Q3_K:
|
1671
|
+
{
|
1672
|
+
nth0 = 2;
|
1673
|
+
nth1 = 32;
|
1674
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
|
1675
|
+
} break;
|
1676
|
+
case GGML_TYPE_Q4_K:
|
1677
|
+
{
|
1678
|
+
nth0 = 4; //1;
|
1679
|
+
nth1 = 8; //32;
|
1680
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
|
1681
|
+
} break;
|
1682
|
+
case GGML_TYPE_Q5_K:
|
1683
|
+
{
|
1684
|
+
nth0 = 2;
|
1685
|
+
nth1 = 32;
|
1686
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
|
1687
|
+
} break;
|
1688
|
+
case GGML_TYPE_Q6_K:
|
1689
|
+
{
|
1690
|
+
nth0 = 2;
|
1691
|
+
nth1 = 32;
|
1692
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
|
1693
|
+
} break;
|
1694
|
+
case GGML_TYPE_IQ2_XXS:
|
1695
|
+
{
|
1696
|
+
nth0 = 4;
|
1697
|
+
nth1 = 16;
|
1698
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
|
1699
|
+
} break;
|
1700
|
+
case GGML_TYPE_IQ2_XS:
|
1701
|
+
{
|
1702
|
+
nth0 = 4;
|
1703
|
+
nth1 = 16;
|
1704
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
|
1705
|
+
} break;
|
1706
|
+
default:
|
1707
|
+
{
|
1708
|
+
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
1709
|
+
GGML_ASSERT(false && "not implemented");
|
1710
|
+
}
|
1711
|
+
};
|
1712
|
+
|
1713
|
+
if (ggml_is_quantized(src2t)) {
|
1714
|
+
GGML_ASSERT(ne20 >= nth0*nth1);
|
2431
1715
|
}
|
2432
1716
|
|
2433
|
-
|
2434
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2435
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
2436
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
2437
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
2438
|
-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
2439
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
2440
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
2441
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
2442
|
-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
2443
|
-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
2444
|
-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
2445
|
-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
2446
|
-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
2447
|
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
2448
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
2449
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
2450
|
-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
1717
|
+
const int64_t _ne1 = 1; // kernels needs a reference in constant memory
|
2451
1718
|
|
2452
|
-
[encoder
|
2453
|
-
|
2454
|
-
|
2455
|
-
|
2456
|
-
|
2457
|
-
|
2458
|
-
|
2459
|
-
|
1719
|
+
[encoder setComputePipelineState:pipeline];
|
1720
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1721
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1722
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1723
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
1724
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
1725
|
+
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
1726
|
+
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
|
1727
|
+
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
|
1728
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
|
1729
|
+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
|
1730
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
1731
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
|
1732
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
1733
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
1734
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
1735
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
1736
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
1737
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
1738
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
|
1739
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
|
1740
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
|
1741
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
|
1742
|
+
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
|
1743
|
+
// TODO: how to make this an array? read Metal docs
|
1744
|
+
for (int j = 0; j < 8; ++j) {
|
1745
|
+
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
1746
|
+
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
1747
|
+
|
1748
|
+
size_t offs_src_cur = 0;
|
1749
|
+
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
1750
|
+
|
1751
|
+
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
1752
|
+
}
|
2460
1753
|
|
2461
|
-
|
2462
|
-
|
1754
|
+
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
1755
|
+
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
1756
|
+
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
|
1757
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1758
|
+
}
|
1759
|
+
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
|
1760
|
+
const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
1761
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1762
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1763
|
+
}
|
1764
|
+
else if (src2t == GGML_TYPE_Q4_K) {
|
1765
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1766
|
+
}
|
1767
|
+
else if (src2t == GGML_TYPE_Q3_K) {
|
1768
|
+
#ifdef GGML_QKK_64
|
1769
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1770
|
+
#else
|
1771
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2463
1772
|
#endif
|
2464
|
-
|
1773
|
+
}
|
1774
|
+
else if (src2t == GGML_TYPE_Q5_K) {
|
1775
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1776
|
+
}
|
1777
|
+
else if (src2t == GGML_TYPE_Q6_K) {
|
1778
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1779
|
+
} else {
|
1780
|
+
const int64_t ny = (_ne1 + nrows - 1)/nrows;
|
1781
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1782
|
+
}
|
1783
|
+
}
|
1784
|
+
} break;
|
1785
|
+
case GGML_OP_GET_ROWS:
|
1786
|
+
{
|
1787
|
+
id<MTLComputePipelineState> pipeline = nil;
|
1788
|
+
|
1789
|
+
switch (src0->type) {
|
1790
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
|
1791
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
|
1792
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
|
1793
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
|
1794
|
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
|
1795
|
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
|
1796
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
|
1797
|
+
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
|
1798
|
+
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
|
1799
|
+
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
|
1800
|
+
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
|
1801
|
+
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
|
1802
|
+
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
|
1803
|
+
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
|
1804
|
+
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
1805
|
+
default: GGML_ASSERT(false && "not implemented");
|
1806
|
+
}
|
1807
|
+
|
1808
|
+
[encoder setComputePipelineState:pipeline];
|
1809
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1810
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1811
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1812
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
1813
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
1814
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
1815
|
+
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
1816
|
+
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
|
1817
|
+
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
|
1818
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
|
1819
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
|
1820
|
+
|
1821
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
1822
|
+
} break;
|
1823
|
+
case GGML_OP_RMS_NORM:
|
1824
|
+
{
|
1825
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
1826
|
+
|
1827
|
+
float eps;
|
1828
|
+
memcpy(&eps, dst->op_params, sizeof(float));
|
1829
|
+
|
1830
|
+
int nth = 32; // SIMD width
|
1831
|
+
|
1832
|
+
while (nth < ne00/4 && nth < 1024) {
|
1833
|
+
nth *= 2;
|
1834
|
+
}
|
2465
1835
|
|
2466
|
-
|
2467
|
-
|
2468
|
-
|
1836
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
|
1837
|
+
|
1838
|
+
[encoder setComputePipelineState:pipeline];
|
1839
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1840
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1841
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
1842
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
1843
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
1844
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
1845
|
+
|
1846
|
+
const int64_t nrows = ggml_nrows(src0);
|
1847
|
+
|
1848
|
+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1849
|
+
} break;
|
1850
|
+
case GGML_OP_GROUP_NORM:
|
1851
|
+
{
|
1852
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
1853
|
+
|
1854
|
+
//float eps;
|
1855
|
+
//memcpy(&eps, dst->op_params, sizeof(float));
|
1856
|
+
|
1857
|
+
const float eps = 1e-6f; // TODO: temporarily hardcoded
|
1858
|
+
|
1859
|
+
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
|
1860
|
+
|
1861
|
+
int nth = 32; // SIMD width
|
1862
|
+
|
1863
|
+
//while (nth < ne00/4 && nth < 1024) {
|
1864
|
+
// nth *= 2;
|
1865
|
+
//}
|
1866
|
+
|
1867
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
|
1868
|
+
|
1869
|
+
[encoder setComputePipelineState:pipeline];
|
1870
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1871
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1872
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
1873
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
1874
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
1875
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
|
1876
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
|
1877
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
|
1878
|
+
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
|
1879
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
|
1880
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
1881
|
+
|
1882
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1883
|
+
} break;
|
1884
|
+
case GGML_OP_NORM:
|
1885
|
+
{
|
1886
|
+
float eps;
|
1887
|
+
memcpy(&eps, dst->op_params, sizeof(float));
|
1888
|
+
|
1889
|
+
const int nth = MIN(256, ne00);
|
1890
|
+
|
1891
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
|
1892
|
+
|
1893
|
+
[encoder setComputePipelineState:pipeline];
|
1894
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1895
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1896
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
1897
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
1898
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
1899
|
+
[encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
|
1900
|
+
|
1901
|
+
const int64_t nrows = ggml_nrows(src0);
|
1902
|
+
|
1903
|
+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1904
|
+
} break;
|
1905
|
+
case GGML_OP_ALIBI:
|
1906
|
+
{
|
1907
|
+
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
1908
|
+
|
1909
|
+
const int nth = MIN(1024, ne00);
|
1910
|
+
|
1911
|
+
//const int n_past = ((int32_t *) dst->op_params)[0];
|
1912
|
+
const int n_head = ((int32_t *) dst->op_params)[1];
|
1913
|
+
float max_bias;
|
1914
|
+
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
1915
|
+
|
1916
|
+
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
1917
|
+
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
1918
|
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
1919
|
+
|
1920
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
|
1921
|
+
|
1922
|
+
[encoder setComputePipelineState:pipeline];
|
1923
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1924
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1925
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
1926
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
1927
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
1928
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
1929
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
1930
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
1931
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
1932
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
1933
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
1934
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
1935
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
1936
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
1937
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
1938
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
1939
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
1940
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
1941
|
+
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
1942
|
+
[encoder setBytes:&m1 length:sizeof( float) atIndex:19];
|
1943
|
+
[encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
|
1944
|
+
|
1945
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1946
|
+
} break;
|
1947
|
+
case GGML_OP_ROPE:
|
1948
|
+
{
|
1949
|
+
GGML_ASSERT(ne10 == ne02);
|
1950
|
+
|
1951
|
+
const int nth = MIN(1024, ne00);
|
1952
|
+
|
1953
|
+
const int n_past = ((int32_t *) dst->op_params)[0];
|
1954
|
+
const int n_dims = ((int32_t *) dst->op_params)[1];
|
1955
|
+
const int mode = ((int32_t *) dst->op_params)[2];
|
1956
|
+
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
1957
|
+
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
1958
|
+
|
1959
|
+
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
1960
|
+
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
1961
|
+
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
1962
|
+
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
1963
|
+
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
1964
|
+
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
1965
|
+
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
1966
|
+
|
1967
|
+
id<MTLComputePipelineState> pipeline = nil;
|
1968
|
+
|
1969
|
+
switch (src0->type) {
|
1970
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
|
1971
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
|
1972
|
+
default: GGML_ASSERT(false);
|
1973
|
+
};
|
1974
|
+
|
1975
|
+
[encoder setComputePipelineState:pipeline];
|
1976
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1977
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1978
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1979
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
1980
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
|
1981
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
|
1982
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
|
1983
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
|
1984
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
1985
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
1986
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
1987
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
|
1988
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
|
1989
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
|
1990
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
|
1991
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
|
1992
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
|
1993
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
|
1994
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
|
1995
|
+
[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
|
1996
|
+
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
|
1997
|
+
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
|
1998
|
+
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
|
1999
|
+
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
|
2000
|
+
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
|
2001
|
+
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
|
2002
|
+
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
|
2003
|
+
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
|
2004
|
+
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
|
2005
|
+
|
2006
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2007
|
+
} break;
|
2008
|
+
case GGML_OP_IM2COL:
|
2009
|
+
{
|
2010
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
2011
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
2012
|
+
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
2013
|
+
|
2014
|
+
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
2015
|
+
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
2016
|
+
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
|
2017
|
+
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
2018
|
+
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
2019
|
+
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
2020
|
+
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
2021
|
+
|
2022
|
+
const int32_t N = src1->ne[is_2D ? 3 : 2];
|
2023
|
+
const int32_t IC = src1->ne[is_2D ? 2 : 1];
|
2024
|
+
const int32_t IH = is_2D ? src1->ne[1] : 1;
|
2025
|
+
const int32_t IW = src1->ne[0];
|
2026
|
+
|
2027
|
+
const int32_t KH = is_2D ? src0->ne[1] : 1;
|
2028
|
+
const int32_t KW = src0->ne[0];
|
2029
|
+
|
2030
|
+
const int32_t OH = is_2D ? dst->ne[2] : 1;
|
2031
|
+
const int32_t OW = dst->ne[1];
|
2032
|
+
|
2033
|
+
const int32_t CHW = IC * KH * KW;
|
2034
|
+
|
2035
|
+
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
2036
|
+
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
2037
|
+
|
2038
|
+
id<MTLComputePipelineState> pipeline = nil;
|
2039
|
+
|
2040
|
+
switch (src0->type) {
|
2041
|
+
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
|
2042
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
|
2043
|
+
default: GGML_ASSERT(false);
|
2044
|
+
};
|
2045
|
+
|
2046
|
+
[encoder setComputePipelineState:pipeline];
|
2047
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
2048
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2049
|
+
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
|
2050
|
+
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
|
2051
|
+
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
|
2052
|
+
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
|
2053
|
+
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
|
2054
|
+
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
|
2055
|
+
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
|
2056
|
+
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
|
2057
|
+
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
|
2058
|
+
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
|
2059
|
+
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
|
2060
|
+
|
2061
|
+
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
2062
|
+
} break;
|
2063
|
+
case GGML_OP_UPSCALE:
|
2064
|
+
{
|
2065
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2066
|
+
|
2067
|
+
const int sf = dst->op_params[0];
|
2068
|
+
|
2069
|
+
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
|
2070
|
+
|
2071
|
+
[encoder setComputePipelineState:pipeline];
|
2072
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2073
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2074
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
2075
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
2076
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
2077
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
2078
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
2079
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
2080
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
2081
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
2082
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
2083
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
2084
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
2085
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
2086
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
2087
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
2088
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
2089
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
2090
|
+
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
|
2091
|
+
|
2092
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
2093
|
+
|
2094
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2095
|
+
} break;
|
2096
|
+
case GGML_OP_PAD:
|
2097
|
+
{
|
2098
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2099
|
+
|
2100
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
|
2101
|
+
|
2102
|
+
[encoder setComputePipelineState:pipeline];
|
2103
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2104
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2105
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
2106
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
2107
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
2108
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
2109
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
2110
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
2111
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
2112
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
2113
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
2114
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
2115
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
2116
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
2117
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
2118
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
2119
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
2120
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
2121
|
+
|
2122
|
+
const int nth = MIN(1024, ne0);
|
2123
|
+
|
2124
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2125
|
+
} break;
|
2126
|
+
case GGML_OP_ARGSORT:
|
2127
|
+
{
|
2128
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2129
|
+
GGML_ASSERT( dst->type == GGML_TYPE_I32);
|
2130
|
+
|
2131
|
+
const int nrows = ggml_nrows(src0);
|
2132
|
+
|
2133
|
+
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
2134
|
+
|
2135
|
+
id<MTLComputePipelineState> pipeline = nil;
|
2136
|
+
|
2137
|
+
switch (order) {
|
2138
|
+
case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
|
2139
|
+
case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
|
2140
|
+
default: GGML_ASSERT(false);
|
2141
|
+
};
|
2142
|
+
|
2143
|
+
[encoder setComputePipelineState:pipeline];
|
2144
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2145
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2146
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
2147
|
+
|
2148
|
+
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
|
2149
|
+
} break;
|
2150
|
+
case GGML_OP_LEAKY_RELU:
|
2151
|
+
{
|
2152
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2153
|
+
|
2154
|
+
float slope;
|
2155
|
+
memcpy(&slope, dst->op_params, sizeof(float));
|
2156
|
+
|
2157
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
|
2158
|
+
|
2159
|
+
[encoder setComputePipelineState:pipeline];
|
2160
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2161
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2162
|
+
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
|
2163
|
+
|
2164
|
+
const int64_t n = ggml_nelements(dst);
|
2165
|
+
|
2166
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
2167
|
+
} break;
|
2168
|
+
case GGML_OP_DUP:
|
2169
|
+
case GGML_OP_CPY:
|
2170
|
+
case GGML_OP_CONT:
|
2171
|
+
{
|
2172
|
+
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
2173
|
+
|
2174
|
+
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
|
2175
|
+
|
2176
|
+
id<MTLComputePipelineState> pipeline = nil;
|
2177
|
+
|
2178
|
+
switch (src0t) {
|
2179
|
+
case GGML_TYPE_F32:
|
2180
|
+
{
|
2181
|
+
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
|
2182
|
+
|
2183
|
+
switch (dstt) {
|
2184
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
|
2185
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
|
2186
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
|
2187
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
|
2188
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
|
2189
|
+
//case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
|
2190
|
+
//case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
|
2191
|
+
default: GGML_ASSERT(false && "not implemented");
|
2192
|
+
};
|
2193
|
+
} break;
|
2194
|
+
case GGML_TYPE_F16:
|
2195
|
+
{
|
2196
|
+
switch (dstt) {
|
2197
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
|
2198
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
|
2199
|
+
default: GGML_ASSERT(false && "not implemented");
|
2200
|
+
};
|
2201
|
+
} break;
|
2202
|
+
default: GGML_ASSERT(false && "not implemented");
|
2203
|
+
}
|
2204
|
+
|
2205
|
+
[encoder setComputePipelineState:pipeline];
|
2206
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2207
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2208
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
2209
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
2210
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
2211
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
2212
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
2213
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
2214
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
2215
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
2216
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
2217
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
2218
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
2219
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
2220
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
2221
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
2222
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
2223
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
2224
|
+
|
2225
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2226
|
+
} break;
|
2227
|
+
default:
|
2228
|
+
{
|
2229
|
+
GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
2230
|
+
GGML_ASSERT(false);
|
2231
|
+
}
|
2469
2232
|
}
|
2470
2233
|
|
2471
|
-
|
2472
|
-
|
2473
|
-
|
2234
|
+
#ifndef GGML_METAL_NDEBUG
|
2235
|
+
[encoder popDebugGroup];
|
2236
|
+
#endif
|
2237
|
+
}
|
2238
|
+
|
2239
|
+
if (encoder != nil) {
|
2240
|
+
[encoder endEncoding];
|
2241
|
+
encoder = nil;
|
2242
|
+
}
|
2474
2243
|
|
2475
|
-
|
2476
|
-
|
2244
|
+
[command_buffer commit];
|
2245
|
+
});
|
2477
2246
|
|
2478
|
-
// check status of command
|
2247
|
+
// Wait for completion and check status of each command buffer
|
2479
2248
|
// needed to detect if the device ran out-of-memory for example (#1881)
|
2480
|
-
for (int i = 0; i < n_cb; i++) {
|
2481
|
-
[ctx->command_buffers[i] waitUntilCompleted];
|
2482
2249
|
|
2483
|
-
|
2250
|
+
for (int i = 0; i < n_cb; ++i) {
|
2251
|
+
id<MTLCommandBuffer> command_buffer = command_buffers[i];
|
2252
|
+
[command_buffer waitUntilCompleted];
|
2253
|
+
|
2254
|
+
MTLCommandBufferStatus status = [command_buffer status];
|
2484
2255
|
if (status != MTLCommandBufferStatusCompleted) {
|
2485
2256
|
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
2486
2257
|
return false;
|
@@ -2520,13 +2291,13 @@ static void ggml_backend_metal_free_device(void) {
|
|
2520
2291
|
}
|
2521
2292
|
}
|
2522
2293
|
|
2523
|
-
static
|
2524
|
-
|
2294
|
+
GGML_CALL static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
|
2295
|
+
return "Metal";
|
2525
2296
|
|
2526
|
-
|
2297
|
+
UNUSED(buffer);
|
2527
2298
|
}
|
2528
2299
|
|
2529
|
-
static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
2300
|
+
GGML_CALL static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
2530
2301
|
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
|
2531
2302
|
|
2532
2303
|
for (int i = 0; i < ctx->n_buffers; i++) {
|
@@ -2541,50 +2312,80 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
|
|
2541
2312
|
free(ctx);
|
2542
2313
|
}
|
2543
2314
|
|
2544
|
-
static void
|
2545
|
-
|
2315
|
+
GGML_CALL static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
|
2316
|
+
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
|
2546
2317
|
|
2547
|
-
|
2318
|
+
return ctx->all_data;
|
2548
2319
|
}
|
2549
2320
|
|
2550
|
-
static void
|
2551
|
-
memcpy(
|
2321
|
+
GGML_CALL static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
2322
|
+
memcpy((char *)tensor->data + offset, data, size);
|
2552
2323
|
|
2553
2324
|
UNUSED(buffer);
|
2554
2325
|
}
|
2555
2326
|
|
2556
|
-
static void
|
2557
|
-
|
2327
|
+
GGML_CALL static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
2328
|
+
memcpy(data, (const char *)tensor->data + offset, size);
|
2558
2329
|
|
2559
2330
|
UNUSED(buffer);
|
2560
2331
|
}
|
2561
2332
|
|
2562
|
-
static
|
2563
|
-
|
2333
|
+
GGML_CALL static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
|
2334
|
+
if (ggml_backend_buffer_is_host(src->buffer)) {
|
2335
|
+
memcpy(dst->data, src->data, ggml_nbytes(src));
|
2336
|
+
return true;
|
2337
|
+
}
|
2338
|
+
return false;
|
2564
2339
|
|
2565
2340
|
UNUSED(buffer);
|
2566
2341
|
}
|
2567
2342
|
|
2568
|
-
static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
2343
|
+
GGML_CALL static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
2569
2344
|
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
|
2570
2345
|
|
2571
2346
|
memset(ctx->all_data, value, ctx->all_size);
|
2572
2347
|
}
|
2573
2348
|
|
2574
2349
|
static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
|
2350
|
+
/* .get_name = */ ggml_backend_metal_buffer_get_name,
|
2575
2351
|
/* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
|
2576
2352
|
/* .get_base = */ ggml_backend_metal_buffer_get_base,
|
2577
2353
|
/* .init_tensor = */ NULL,
|
2578
2354
|
/* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
|
2579
2355
|
/* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
|
2580
|
-
/* .
|
2581
|
-
/* .cpy_tensor_to = */ ggml_backend_metal_buffer_cpy_tensor_to,
|
2356
|
+
/* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor,
|
2582
2357
|
/* .clear = */ ggml_backend_metal_buffer_clear,
|
2358
|
+
/* .reset = */ NULL,
|
2583
2359
|
};
|
2584
2360
|
|
2585
2361
|
// default buffer type
|
2586
2362
|
|
2587
|
-
static
|
2363
|
+
GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
|
2364
|
+
return "Metal";
|
2365
|
+
|
2366
|
+
UNUSED(buft);
|
2367
|
+
}
|
2368
|
+
|
2369
|
+
static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
|
2370
|
+
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
2371
|
+
if (@available(macOS 10.12, iOS 16.0, *)) {
|
2372
|
+
GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
|
2373
|
+
device.currentAllocatedSize / 1024.0 / 1024.0,
|
2374
|
+
device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
2375
|
+
|
2376
|
+
if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
|
2377
|
+
GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
2378
|
+
} else {
|
2379
|
+
GGML_METAL_LOG_INFO("\n");
|
2380
|
+
}
|
2381
|
+
} else {
|
2382
|
+
GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
|
2383
|
+
}
|
2384
|
+
#endif
|
2385
|
+
UNUSED(device);
|
2386
|
+
}
|
2387
|
+
|
2388
|
+
GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
2588
2389
|
struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
|
2589
2390
|
|
2590
2391
|
const size_t size_page = sysconf(_SC_PAGESIZE);
|
@@ -2616,46 +2417,32 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
2616
2417
|
}
|
2617
2418
|
|
2618
2419
|
GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
|
2619
|
-
|
2620
|
-
|
2621
|
-
#if TARGET_OS_OSX
|
2622
|
-
GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
|
2623
|
-
device.currentAllocatedSize / 1024.0 / 1024.0,
|
2624
|
-
device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
2625
|
-
|
2626
|
-
if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
|
2627
|
-
GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
2628
|
-
} else {
|
2629
|
-
GGML_METAL_LOG_INFO("\n");
|
2630
|
-
}
|
2631
|
-
#else
|
2632
|
-
GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
|
2633
|
-
#endif
|
2634
|
-
|
2420
|
+
ggml_backend_metal_log_allocated_size(device);
|
2635
2421
|
|
2636
2422
|
return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
|
2637
2423
|
}
|
2638
2424
|
|
2639
|
-
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
2425
|
+
GGML_CALL static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
2640
2426
|
return 32;
|
2641
2427
|
UNUSED(buft);
|
2642
2428
|
}
|
2643
2429
|
|
2644
|
-
static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
|
2430
|
+
GGML_CALL static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
|
2645
2431
|
return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
|
2646
2432
|
|
2647
2433
|
UNUSED(buft);
|
2648
2434
|
}
|
2649
2435
|
|
2650
|
-
static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
2436
|
+
GGML_CALL static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
2651
2437
|
return true;
|
2652
2438
|
|
2653
2439
|
UNUSED(buft);
|
2654
2440
|
}
|
2655
2441
|
|
2656
|
-
ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
2442
|
+
GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
2657
2443
|
static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
|
2658
2444
|
/* .iface = */ {
|
2445
|
+
/* .get_name = */ ggml_backend_metal_buffer_type_get_name,
|
2659
2446
|
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
|
2660
2447
|
/* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
|
2661
2448
|
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
@@ -2670,7 +2457,7 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
|
2670
2457
|
|
2671
2458
|
// buffer from ptr
|
2672
2459
|
|
2673
|
-
ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
|
2460
|
+
GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
|
2674
2461
|
struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
|
2675
2462
|
|
2676
2463
|
ctx->all_data = data;
|
@@ -2679,6 +2466,14 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
|
2679
2466
|
ctx->n_buffers = 0;
|
2680
2467
|
|
2681
2468
|
const size_t size_page = sysconf(_SC_PAGESIZE);
|
2469
|
+
|
2470
|
+
// page-align the data ptr
|
2471
|
+
{
|
2472
|
+
const uintptr_t offs = (uintptr_t) data % size_page;
|
2473
|
+
data = (void *) ((char *) data - offs);
|
2474
|
+
size += offs;
|
2475
|
+
}
|
2476
|
+
|
2682
2477
|
size_t size_aligned = size;
|
2683
2478
|
if ((size_aligned % size_page) != 0) {
|
2684
2479
|
size_aligned += (size_page - (size_aligned % size_page));
|
@@ -2730,63 +2525,50 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
|
2730
2525
|
}
|
2731
2526
|
}
|
2732
2527
|
|
2733
|
-
|
2734
|
-
GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
|
2735
|
-
device.currentAllocatedSize / 1024.0 / 1024.0,
|
2736
|
-
device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
2737
|
-
|
2738
|
-
if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
|
2739
|
-
GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
2740
|
-
} else {
|
2741
|
-
GGML_METAL_LOG_INFO("\n");
|
2742
|
-
}
|
2743
|
-
#else
|
2744
|
-
GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
|
2745
|
-
#endif
|
2528
|
+
ggml_backend_metal_log_allocated_size(device);
|
2746
2529
|
|
2747
2530
|
return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
|
2748
2531
|
}
|
2749
2532
|
|
2750
2533
|
// backend
|
2751
2534
|
|
2752
|
-
static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
2535
|
+
GGML_CALL static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
2753
2536
|
return "Metal";
|
2754
2537
|
|
2755
2538
|
UNUSED(backend);
|
2756
2539
|
}
|
2757
2540
|
|
2758
|
-
static void ggml_backend_metal_free(ggml_backend_t backend) {
|
2541
|
+
GGML_CALL static void ggml_backend_metal_free(ggml_backend_t backend) {
|
2759
2542
|
struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
|
2760
2543
|
ggml_metal_free(ctx);
|
2761
2544
|
free(backend);
|
2762
2545
|
}
|
2763
2546
|
|
2764
|
-
static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
|
2547
|
+
GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
|
2765
2548
|
return ggml_backend_metal_buffer_type();
|
2766
2549
|
|
2767
2550
|
UNUSED(backend);
|
2768
2551
|
}
|
2769
2552
|
|
2770
|
-
static bool ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
2553
|
+
GGML_CALL static bool ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
2771
2554
|
struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
|
2772
2555
|
|
2773
2556
|
return ggml_metal_graph_compute(metal_ctx, cgraph);
|
2774
2557
|
}
|
2775
2558
|
|
2776
|
-
static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
2777
|
-
|
2559
|
+
GGML_CALL static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
2560
|
+
struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
|
2778
2561
|
|
2779
|
-
|
2562
|
+
return ggml_metal_supports_op(metal_ctx, op);
|
2780
2563
|
}
|
2781
2564
|
|
2782
|
-
static struct ggml_backend_i
|
2565
|
+
static struct ggml_backend_i ggml_backend_metal_i = {
|
2783
2566
|
/* .get_name = */ ggml_backend_metal_name,
|
2784
2567
|
/* .free = */ ggml_backend_metal_free,
|
2785
2568
|
/* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
|
2786
2569
|
/* .set_tensor_async = */ NULL,
|
2787
2570
|
/* .get_tensor_async = */ NULL,
|
2788
|
-
/* .
|
2789
|
-
/* .cpy_tensor_to_async = */ NULL,
|
2571
|
+
/* .cpy_tensor_async = */ NULL,
|
2790
2572
|
/* .synchronize = */ NULL,
|
2791
2573
|
/* .graph_plan_create = */ NULL,
|
2792
2574
|
/* .graph_plan_free = */ NULL,
|
@@ -2795,6 +2577,11 @@ static struct ggml_backend_i metal_backend_i = {
|
|
2795
2577
|
/* .supports_op = */ ggml_backend_metal_supports_op,
|
2796
2578
|
};
|
2797
2579
|
|
2580
|
+
void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
|
2581
|
+
ggml_metal_log_callback = log_callback;
|
2582
|
+
ggml_metal_log_user_data = user_data;
|
2583
|
+
}
|
2584
|
+
|
2798
2585
|
ggml_backend_t ggml_backend_metal_init(void) {
|
2799
2586
|
struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
|
2800
2587
|
|
@@ -2805,7 +2592,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
|
|
2805
2592
|
ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
|
2806
2593
|
|
2807
2594
|
*metal_backend = (struct ggml_backend) {
|
2808
|
-
/* .interface = */
|
2595
|
+
/* .interface = */ ggml_backend_metal_i,
|
2809
2596
|
/* .context = */ ctx,
|
2810
2597
|
};
|
2811
2598
|
|
@@ -2813,7 +2600,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
|
|
2813
2600
|
}
|
2814
2601
|
|
2815
2602
|
bool ggml_backend_is_metal(ggml_backend_t backend) {
|
2816
|
-
return backend->iface.get_name == ggml_backend_metal_name;
|
2603
|
+
return backend && backend->iface.get_name == ggml_backend_metal_name;
|
2817
2604
|
}
|
2818
2605
|
|
2819
2606
|
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
@@ -2821,7 +2608,7 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|
2821
2608
|
|
2822
2609
|
struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
|
2823
2610
|
|
2824
|
-
|
2611
|
+
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
2825
2612
|
}
|
2826
2613
|
|
2827
2614
|
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
@@ -2832,9 +2619,9 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
|
2832
2619
|
return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
2833
2620
|
}
|
2834
2621
|
|
2835
|
-
ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
|
2622
|
+
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
|
2836
2623
|
|
2837
|
-
ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
|
2624
|
+
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
|
2838
2625
|
return ggml_backend_metal_init();
|
2839
2626
|
|
2840
2627
|
GGML_UNUSED(params);
|