llama_cpp 0.12.0 → 0.12.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/ext/llama_cpp/llama_cpp.cpp +78 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +11 -0
- data/vendor/tmp/llama.cpp/Makefile +7 -10
- data/vendor/tmp/llama.cpp/ggml-alloc.c +28 -6
- data/vendor/tmp/llama.cpp/ggml-alloc.h +3 -1
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +36 -36
- data/vendor/tmp/llama.cpp/ggml-backend.c +512 -261
- data/vendor/tmp/llama.cpp/ggml-backend.h +43 -33
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +1494 -559
- data/vendor/tmp/llama.cpp/ggml-cuda.h +18 -30
- data/vendor/tmp/llama.cpp/ggml-impl.h +2 -0
- data/vendor/tmp/llama.cpp/ggml-metal.h +4 -56
- data/vendor/tmp/llama.cpp/ggml-metal.m +1868 -2002
- data/vendor/tmp/llama.cpp/ggml-metal.metal +692 -8
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +321 -14
- data/vendor/tmp/llama.cpp/ggml-opencl.h +13 -3
- data/vendor/tmp/llama.cpp/ggml-quants.c +2182 -44
- data/vendor/tmp/llama.cpp/ggml-quants.h +36 -1
- data/vendor/tmp/llama.cpp/ggml.c +222 -105
- data/vendor/tmp/llama.cpp/ggml.h +56 -35
- data/vendor/tmp/llama.cpp/llama.cpp +1271 -1618
- data/vendor/tmp/llama.cpp/llama.h +44 -8
- metadata +2 -2
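The largest change in this release sits in data/vendor/tmp/llama.cpp/ggml-metal.m: the per-kernel GGML_METAL_DECL_KERNEL(name) fields (a function_##name / pipeline_##name pair per kernel) are replaced by a ggml_metal_kernel_type enum plus a fixed kernels[GGML_METAL_MAX_KERNELS] table, and each entry is registered through GGML_METAL_ADD_KERNEL(e, name, supported) so kernels that need simdgroup reductions or simdgroup matrix multiplication are skipped on devices that lack them. The C sketch below only illustrates that registry pattern; the types and names are simplified stand-ins, not the gem's actual code:

    #include <stdbool.h>
    #include <stdio.h>

    /* Stand-in for the pair of Metal objects kept per kernel
       (id<MTLFunction> / id<MTLComputePipelineState> in the real code). */
    typedef struct {
        const char * fn_name;
    } kernel_entry;

    enum kernel_type {          /* mirrors ggml_metal_kernel_type */
        KERNEL_TYPE_ADD,
        KERNEL_TYPE_MUL_MM_F32_F32,
        KERNEL_TYPE_COUNT
    };

    #define MAX_KERNELS 256     /* mirrors GGML_METAL_MAX_KERNELS */

    static kernel_entry kernels[MAX_KERNELS];

    /* Register a kernel only when the device supports it, like the new
       GGML_METAL_ADD_KERNEL(e, name, supported) macro. */
    #define ADD_KERNEL(e, name, supported)                      \
        do {                                                    \
            if (supported) {                                    \
                kernels[e].fn_name = "kernel_" #name;           \
            } else {                                            \
                fprintf(stderr, "skipping kernel_" #name "\n"); \
            }                                                   \
        } while (0)

    int main(void) {
        bool support_simdgroup_mm = false; /* detected from the GPU family in the real code */

        ADD_KERNEL(KERNEL_TYPE_ADD,            add,            true);
        ADD_KERNEL(KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, support_simdgroup_mm);

        if (kernels[KERNEL_TYPE_ADD].fn_name) {
            printf("%s registered\n", kernels[KERNEL_TYPE_ADD].fn_name);
        }
        return 0;
    }

The same table is what the reworked ggml_metal_free walks to release each kernel's pipeline and function objects, instead of expanding a per-name deletion macro.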
data/vendor/tmp/llama.cpp/ggml-metal.m

@@ -24,7 +24,7 @@
 
 #define UNUSED(x) (void)(x)
 
-#define
+#define GGML_METAL_MAX_KERNELS 256
 
 struct ggml_metal_buffer {
     const char * name;
@@ -35,6 +35,134 @@ struct ggml_metal_buffer {
     id<MTLBuffer> metal;
 };
 
+struct ggml_metal_kernel {
+    id<MTLFunction> function;
+    id<MTLComputePipelineState> pipeline;
+};
+
+enum ggml_metal_kernel_type {
+    GGML_METAL_KERNEL_TYPE_ADD,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_MUL,
+    GGML_METAL_KERNEL_TYPE_MUL_ROW,
+    GGML_METAL_KERNEL_TYPE_DIV,
+    GGML_METAL_KERNEL_TYPE_DIV_ROW,
+    GGML_METAL_KERNEL_TYPE_SCALE,
+    GGML_METAL_KERNEL_TYPE_SCALE_4,
+    GGML_METAL_KERNEL_TYPE_TANH,
+    GGML_METAL_KERNEL_TYPE_RELU,
+    GGML_METAL_KERNEL_TYPE_GELU,
+    GGML_METAL_KERNEL_TYPE_GELU_QUICK,
+    GGML_METAL_KERNEL_TYPE_SILU,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
+    GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
+    GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
+    GGML_METAL_KERNEL_TYPE_RMS_NORM,
+    GGML_METAL_KERNEL_TYPE_GROUP_NORM,
+    GGML_METAL_KERNEL_TYPE_NORM,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
+  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
+  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
+  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_F16,
+    GGML_METAL_KERNEL_TYPE_ALIBI_F32,
+    GGML_METAL_KERNEL_TYPE_IM2COL_F16,
+    GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
+    GGML_METAL_KERNEL_TYPE_PAD_F32,
+    GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
+    GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
+    GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
+  //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
+  //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
+    GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
+    GGML_METAL_KERNEL_TYPE_CONCAT,
+    GGML_METAL_KERNEL_TYPE_SQR,
+    GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+
+    GGML_METAL_KERNEL_TYPE_COUNT
+};
+
 struct ggml_metal_context {
     int n_cb;
 
@@ -42,132 +170,15 @@ struct ggml_metal_context {
     id<MTLCommandQueue> queue;
     id<MTLLibrary> library;
 
-    id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
-    id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
-
     dispatch_queue_t d_queue;
 
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
-
-
-
-
-#define GGML_METAL_DECL_KERNEL(name) \
-    id<MTLFunction> function_##name; \
-    id<MTLComputePipelineState> pipeline_##name
-
-    GGML_METAL_DECL_KERNEL(add);
-    GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
-    GGML_METAL_DECL_KERNEL(mul);
-    GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
-    GGML_METAL_DECL_KERNEL(div);
-    GGML_METAL_DECL_KERNEL(div_row);
-    [ ... one GGML_METAL_DECL_KERNEL(...) member pair per kernel (scale through mul_mm_id_*, rope, alibi, im2col, upscale, pad, argsort, leaky_relu, cpy_*), mirroring the ggml_metal_kernel_type list above, all removed ... ]
-    GGML_METAL_DECL_KERNEL(concat);
-    GGML_METAL_DECL_KERNEL(sqr);
-    GGML_METAL_DECL_KERNEL(sum_rows);
-
-#undef GGML_METAL_DECL_KERNEL
+    struct ggml_metal_kernel kernels[GGML_METAL_MAX_KERNELS];
+
+    bool support_simdgroup_reduction;
+    bool support_simdgroup_mm;
 };
 
 // MSL code
@@ -181,7 +192,6 @@ struct ggml_metal_context {
 @implementation GGMLMetalClass
 @end
 
-
 static void ggml_metal_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
     fprintf(stderr, "%s", msg);
 
@@ -192,11 +202,6 @@ static void ggml_metal_default_log_callback(enum ggml_log_level level, const cha
 ggml_log_callback ggml_metal_log_callback = ggml_metal_default_log_callback;
 void * ggml_metal_log_user_data = NULL;
 
-void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
-    ggml_metal_log_callback = log_callback;
-    ggml_metal_log_user_data = user_data;
-}
-
 GGML_ATTRIBUTE_FORMAT(2, 3)
 static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
     if (ggml_metal_log_callback != NULL) {
@@ -219,7 +224,18 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
     }
 }
 
-
+static void * ggml_metal_host_malloc(size_t n) {
+    void * data = NULL;
+    const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
+    if (result != 0) {
+        GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
+        return NULL;
+    }
+
+    return data;
+}
+
+static struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
 
     id<MTLDevice> device;
@@ -245,7 +261,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
     ctx->queue = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
-    ctx->concur_list_len = 0;
 
     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
@@ -258,14 +273,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
 #endif
         NSError * error = nil;
-        NSString * libPath = [bundle pathForResource:@"
+        NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
         if (libPath != nil) {
             // pre-compiled library found
             NSURL * libURL = [NSURL fileURLWithPath:libPath];
             GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
             ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
         } else {
-            GGML_METAL_LOG_INFO("%s:
+            GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
 
             NSString * sourcePath;
             NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
@@ -288,19 +303,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
                 return NULL;
             }
 
-
+            // dictionary of preprocessor macros
+            NSMutableDictionary * prep = [NSMutableDictionary dictionary];
+
 #ifdef GGML_QKK_64
-
-            options.preprocessorMacros = @{ @"QK_K" : @(64) };
+            prep[@"QK_K"] = @(64);
 #endif
-
-
-
-
-            // and go through the "pre-compiled library found" path above
+
+            MTLCompileOptions* options = [MTLCompileOptions new];
+            options.preprocessorMacros = prep;
+
             //[options setFastMathEnabled:false];
 
             ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+
+            [options release];
+            [prep release];
         }
 
         if (error) {
@@ -309,22 +327,51 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         }
     }
 
-#if TARGET_OS_OSX
     // print MTL GPU family:
     GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
 
+    const NSInteger MTLGPUFamilyMetal3 = 5001;
+
     // determine max supported GPU family
     // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
     // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
-
-
-
-
+    {
+        for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+            if ([ctx->device supportsFamily:i]) {
+                GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+                break;
+            }
+        }
+
+        for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
+            if ([ctx->device supportsFamily:i]) {
+                GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
+                break;
+            }
+        }
+
+        for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
+            if ([ctx->device supportsFamily:i]) {
+                GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
+                break;
+            }
        }
     }
 
+    ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7];
+    ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];
+
+    ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];
+
+    GGML_METAL_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false");
+    GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
     GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
-
+
+#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
+    if (@available(macOS 10.12, iOS 16.0, *)) {
+        GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+    }
+#elif TARGET_OS_OSX
     if (ctx->device.maxTransferRate != 0) {
         GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
     } else {
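The hunk above is where the two new capability flags are computed: ggml_metal_init probes the GPU family through [ctx->device supportsFamily:...] (the diff declares MTLGPUFamilyMetal3 with the raw value 5001) and derives support_simdgroup_reduction and support_simdgroup_mm, which later gate kernel registration and ggml_metal_supports_op. The following C sketch only illustrates that derivation; device_supports_family() and the Apple7 constant are hypothetical stand-ins, not the real Metal API:

    #include <stdbool.h>

    /* Hypothetical stand-in for [ctx->device supportsFamily:family]. */
    static bool device_supports_family(long family) {
        (void) family;
        return true; /* pretend every family is available in this sketch */
    }

    /* MTLGPUFamilyMetal3 is given the raw value 5001 in the diff; the Apple7
       value below is a placeholder, not the real enum constant. */
    #define GPU_FAMILY_APPLE7 7
    #define GPU_FAMILY_METAL3 5001

    typedef struct {
        bool support_simdgroup_reduction; /* simd_sum / simd_max based kernels */
        bool support_simdgroup_mm;        /* simdgroup matrix-multiply kernels */
    } gpu_caps;

    static gpu_caps detect_caps(void) {
        gpu_caps caps;
        caps.support_simdgroup_reduction = device_supports_family(GPU_FAMILY_APPLE7)
                                        || device_supports_family(GPU_FAMILY_METAL3);
        caps.support_simdgroup_mm        = device_supports_family(GPU_FAMILY_APPLE7);
        return caps;
    }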
@@ -336,259 +383,171 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     {
         NSError * error = nil;
 
+        for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
+            ctx->kernels[i].function = nil;
+            ctx->kernels[i].pipeline = nil;
+        }
+
         /*
-
-
-
+        GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+                (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+                (int) kernel->pipeline.threadExecutionWidth); \
         */
-#define GGML_METAL_ADD_KERNEL(name) \
-
-
-
-
-
+#define GGML_METAL_ADD_KERNEL(e, name, supported) \
+        if (supported) { \
+            struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
+            kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \
+            kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \
+            if (error) { \
+                GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+                return NULL; \
+            } \
+        } else { \
+            GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \
         }
 
-
-
-        GGML_METAL_ADD_KERNEL(
-        GGML_METAL_ADD_KERNEL(
-        [ ... the rest of the old GGML_METAL_ADD_KERNEL(name) registrations were removed; the kernel names are truncated in the diff view ... ]
-
+        // simd_sum and simd_max requires MTLGPUFamilyApple7
+
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
+        [ ... one GGML_METAL_ADD_KERNEL(...) registration is added for every ggml_metal_kernel_type entry above: element-wise, get_rows, rope, alibi, im2col, upscale, pad, argsort, leaky_relu, cpy, concat, sqr and sum_rows kernels pass true; mul_mv*, soft_max*, rms_norm and group_norm pass ctx->support_simdgroup_reduction; mul_mm* pass ctx->support_simdgroup_mm ... ]
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
     }
 
     return ctx;
 }
 
-void ggml_metal_free(struct ggml_metal_context * ctx) {
+static void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
-#define GGML_METAL_DEL_KERNEL(name) \
-    [ctx->function_##name release]; \
-    [ctx->pipeline_##name release];
-
-    GGML_METAL_DEL_KERNEL(add);
-    GGML_METAL_DEL_KERNEL(add_row);
-    [ ... GGML_METAL_DEL_KERNEL(...) calls for every previously declared kernel were removed, including the block guarded by if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) that released the mul_mm kernels ... ]
-    GGML_METAL_DEL_KERNEL(sqr);
-    GGML_METAL_DEL_KERNEL(sum_rows);
-
-#undef GGML_METAL_DEL_KERNEL
 
     for (int i = 0; i < ctx->n_buffers; ++i) {
         [ctx->buffers[i].metal release];
     }
 
+    for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
+        if (ctx->kernels[i].pipeline) {
+            [ctx->kernels[i].pipeline release];
+        }
+
+        if (ctx->kernels[i].function) {
+            [ctx->kernels[i].function release];
+        }
+    }
+
     [ctx->library release];
     [ctx->queue release];
     [ctx->device release];
@@ -598,33 +557,6 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     free(ctx);
 }
 
-void * ggml_metal_host_malloc(size_t n) {
-    void * data = NULL;
-    const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
-    if (result != 0) {
-        GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
-        return NULL;
-    }
-
-    return data;
-}
-
-void ggml_metal_host_free(void * data) {
-    free(data);
-}
-
-void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
-    ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
-}
-
-int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
-    return ctx->concur_list_len;
-}
-
-int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
-    return ctx->concur_list;
-}
-
 // temporarily defined here for compatibility between ggml-backend and the old API
 
 struct ggml_backend_metal_buffer {
@@ -697,210 +629,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
     return nil;
 }
 
-bool
-        struct ggml_metal_context * ctx,
-        const char * name,
-        void * data,
-        size_t size,
-        size_t max_size) {
-    if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
-        GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
-        return false;
-    }
-
-    if (data) {
-        // verify that the buffer does not overlap with any of the existing buffers
-        [ ... remainder of the removed old-API buffer-registration function: page-size alignment, newBufferWithBytesNoCopy allocation, splitting oversized buffers into overlapping views, and allocation logging against recommendedMaxWorkingSetSize ... ]
-    }
-
-    return true;
-}
-
-void ggml_metal_set_tensor(
-        struct ggml_metal_context * ctx,
-        struct ggml_tensor * t) {
-    size_t offs;
-    id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
-
-    memcpy((void *) ((uint8_t *) id_dst.contents + offs), t->data, ggml_nbytes(t));
-}
-
-void ggml_metal_get_tensor(
-        struct ggml_metal_context * ctx,
-        struct ggml_tensor * t) {
-    size_t offs;
-    id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
-
-    memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
-}
-
-void ggml_metal_graph_find_concurrency(
-        struct ggml_metal_context * ctx,
-        struct ggml_cgraph * gf, bool check_mem) {
-    int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
-    int nodes_unused[GGML_MAX_CONCUR];
-    [ ... remainder of the removed concurrency-finding pass over gf->nodes and ctx->concur_list ... ]
-}
-
-static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
+static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
@@ -926,9 +655,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_SUM_ROWS:
+            return true;
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
         case GGML_OP_GROUP_NORM:
+            return ctx->support_simdgroup_reduction;
         case GGML_OP_NORM:
         case GGML_OP_ALIBI:
         case GGML_OP_ROPE:
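Standalone sketch (not part of this diff): the capability gating that passing ctx into ggml_metal_supports_op enables — ops that need simdgroup reductions report support only when the device sets that flag. The struct, enum, and function names below are simplified stand-ins for illustration, not the real ggml API.

    #include <stdbool.h>
    #include <stdio.h>

    enum op { OP_SUM_ROWS, OP_SOFT_MAX, OP_MUL_MAT };

    struct metal_ctx { bool support_simdgroup_reduction; };

    static bool supports_op(const struct metal_ctx * ctx, enum op op) {
        switch (op) {
            case OP_SUM_ROWS: return true;                              // no special hardware needed
            case OP_SOFT_MAX:
            case OP_MUL_MAT:  return ctx->support_simdgroup_reduction;  // gated on the device capability
        }
        return false;
    }

    int main(void) {
        struct metal_ctx ctx = { .support_simdgroup_reduction = false };
        printf("soft_max supported: %d\n", supports_op(&ctx, OP_SOFT_MAX));
        return 0;
    }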
@@ -937,9 +668,10 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
         case GGML_OP_PAD:
         case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
+            return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
-            return
+            return ctx->support_simdgroup_reduction;
         case GGML_OP_CPY:
         case GGML_OP_DUP:
         case GGML_OP_CONT:
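Standalone sketch (not part of this diff): the node-partitioning arithmetic used in the hunk below — n_nodes graph nodes are split across n_cb command buffers by ceiling division, and the last buffer absorbs the remainder. MIN is redefined locally; the values are illustrative.

    #include <stdio.h>
    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    int main(void) {
        const int n_nodes = 10;
        const int n_cb    = 3;
        const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; // ceil(n_nodes / n_cb)

        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
            const int node_start = cb_idx * n_nodes_per_cb;
            const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
            printf("cb %d encodes nodes [%d, %d)\n", cb_idx, node_start, node_end);
        }
        return 0;
    }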
@@ -977,1438 +709,1556 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
                 return false;
         }
 }
-
+
+static bool ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
         struct ggml_cgraph * gf) {
     @autoreleasepool {

-    // if there is ctx->concur_list, dispatch concurrently
-    // else fallback to serial dispatch
     MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
-
-    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
-
-    const int n_nodes  = has_concur ? ctx->concur_list_len : gf->n_nodes;
-    edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+    edesc.dispatchType = MTLDispatchTypeSerial;

     // create multiple command buffers and enqueue them
     // then, we encode the graph into the command buffers in parallel

+    const int n_nodes = gf->n_nodes;
     const int n_cb = ctx->n_cb;
+    const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;

-
-
+    id<MTLCommandBuffer> command_buffer_builder[n_cb];
+    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+        id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
+        command_buffer_builder[cb_idx] = command_buffer;

         // enqueue the command buffers in order to specify their execution order
-        [
-
-        ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
+        [command_buffer enqueue];
     }
+    const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;

-
-    const int
-
-    dispatch_async(ctx->d_queue, ^{
-        size_t offs_src0 = 0;
-        size_t offs_src1 = 0;
-        size_t offs_dst  = 0;
-
-        id<MTLCommandBuffer> command_buffer  = ctx->command_buffers[cb_idx];
-        id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
-
-        const int node_start = (cb_idx + 0) * n_nodes_per_cb;
-        const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
-
-        for (int ind = node_start; ind < node_end; ++ind) {
-            const int i = has_concur ? ctx->concur_list[ind] : ind;
-
-            if (i == -1) {
-                [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
-                continue;
-            }
-
-            //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
-
-            struct ggml_tensor * src0 = gf->nodes[i]->src[0];
-            struct ggml_tensor * src1 = gf->nodes[i]->src[1];
-            struct ggml_tensor * dst  = gf->nodes[i];
-
-            switch (dst->op) {
-                case GGML_OP_NONE:
-                case GGML_OP_RESHAPE:
-                case GGML_OP_VIEW:
-                case GGML_OP_TRANSPOSE:
-                case GGML_OP_PERMUTE:
-                    {
-                        // noop -> next node
-                    } continue;
-                default:
-                    {
-                    } break;
-            }
-
-            if (!ggml_metal_supports_op(dst)) {
-                GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
-                GGML_ASSERT(!"unsupported op");
-            }
-
-            const int64_t ne00 = src0 ? src0->ne[0] : 0;
-            const int64_t ne01 = src0 ? src0->ne[1] : 0;
-            const int64_t ne02 = src0 ? src0->ne[2] : 0;
-            const int64_t ne03 = src0 ? src0->ne[3] : 0;
-
-            const uint64_t nb00 = src0 ? src0->nb[0] : 0;
-            const uint64_t nb01 = src0 ? src0->nb[1] : 0;
-            const uint64_t nb02 = src0 ? src0->nb[2] : 0;
-            const uint64_t nb03 = src0 ? src0->nb[3] : 0;
-
-            const int64_t ne10 = src1 ? src1->ne[0] : 0;
-            const int64_t ne11 = src1 ? src1->ne[1] : 0;
-            const int64_t ne12 = src1 ? src1->ne[2] : 0;
-            const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
-
-            const uint64_t nb10 = src1 ? src1->nb[0] : 0;
-            const uint64_t nb11 = src1 ? src1->nb[1] : 0;
-            const uint64_t nb12 = src1 ? src1->nb[2] : 0;
-            const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
-
-            const int64_t ne0 = dst ? dst->ne[0] : 0;
-            const int64_t ne1 = dst ? dst->ne[1] : 0;
-            const int64_t ne2 = dst ? dst->ne[2] : 0;
-            const int64_t ne3 = dst ? dst->ne[3] : 0;
-
-            const uint64_t nb0 = dst ? dst->nb[0] : 0;
-            const uint64_t nb1 = dst ? dst->nb[1] : 0;
-            const uint64_t nb2 = dst ? dst->nb[2] : 0;
-            const uint64_t nb3 = dst ? dst->nb[3] : 0;
-
-            const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
-            const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
-            const enum ggml_type dstt  = dst  ? dst->type  : GGML_TYPE_COUNT;
-
-            id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
-            id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
-            id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(ctx, dst,  &offs_dst)  : nil;
-
-            //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
-            //if (src0) {
-            //  GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
-            //      ggml_is_contiguous(src0), src0->name);
-            //}
-            //if (src1) {
-            //  GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
-            //      ggml_is_contiguous(src1), src1->name);
-            //}
-            //if (dst) {
-            //  GGML_METAL_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
-            //      dst->name);
-            //}
-
-            switch (dst->op) {
-                case GGML_OP_CONCAT:
-                    {
-                        const int64_t nb = ne00;
-
-                        [encoder setComputePipelineState:ctx->pipeline_concat];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                        [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
-                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
-                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
-                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
-                        [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
-                        [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
-                        [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
-                        [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
-                        [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
-                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
-                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
-                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
-                        [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
-                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
-                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
-                        [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
-                        [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
-                        [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
-                        [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
-                        [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
-                        [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
-                        [encoder setBytes:&nb   length:sizeof(nb)   atIndex:27];
-
-                        const int nth = MIN(1024, ne0);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                    } break;
-                case GGML_OP_ADD:
-                case GGML_OP_MUL:
-                case GGML_OP_DIV:
-                    {
-                        const size_t offs = 0;
-
-                        bool bcast_row = false;
-
-                        int64_t nb = ne00;
-
-                        id<MTLComputePipelineState> pipeline = nil;
-
-                        if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
-                            GGML_ASSERT(ggml_is_contiguous(src0));
+    dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
+        const int cb_idx = iter;

-
-
+        size_t offs_src0 = 0;
+        size_t offs_src1 = 0;
+        size_t offs_dst  = 0;

-
-
-                                case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
-                                case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
-                                case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
-                                default: GGML_ASSERT(false);
-                            }
+        id<MTLCommandBuffer> command_buffer  = command_buffers[cb_idx];
+        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];

-
-
-                            switch (dst->op) {
-                                case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
-                                case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
-                                case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
-                                default: GGML_ASSERT(false);
-                            }
-                        }
+        const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+        const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);

-
-
-
-
-
-                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                        [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
-                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
-                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
-                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
-                        [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
-                        [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
-                        [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
-                        [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
-                        [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
-                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
-                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
-                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
-                        [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
-                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
-                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
-                        [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
-                        [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
-                        [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
-                        [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
-                        [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
-                        [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
-                        [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
-                        [encoder setBytes:&nb   length:sizeof(nb)   atIndex:28];
-
-                        if (bcast_row) {
-                            const int64_t n = ggml_nelements(dst)/4;
+        for (int i = node_start; i < node_end; ++i) {
+            if (i == -1) {
+                [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+                continue;
+            }

-
-
-
+            //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+
+            struct ggml_tensor * src0 = gf->nodes[i]->src[0];
+            struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+            struct ggml_tensor * dst  = gf->nodes[i];
+
+            switch (dst->op) {
+                case GGML_OP_NONE:
+                case GGML_OP_RESHAPE:
+                case GGML_OP_VIEW:
+                case GGML_OP_TRANSPOSE:
+                case GGML_OP_PERMUTE:
+                    {
+                        // noop -> next node
+                    } continue;
+                default:
+                    {
+                    } break;
+            }

-
-
-
-
-                    {
-                        GGML_ASSERT(src0t == GGML_TYPE_F32);
-                        GGML_ASSERT(src1t == GGML_TYPE_F32);
-                        GGML_ASSERT(dstt  == GGML_TYPE_F32);
+            if (!ggml_metal_supports_op(ctx, dst)) {
+                GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
+                GGML_ASSERT(!"unsupported op");
+            }

-
-
-
-                        const size_t pnb1 = ((int32_t *) dst->op_params)[0];
-                        const size_t pnb2 = ((int32_t *) dst->op_params)[1];
-                        const size_t pnb3 = ((int32_t *) dst->op_params)[2];
-                        const size_t offs = ((int32_t *) dst->op_params)[3];
-
-                        const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
-
-                        if (!inplace) {
-                            // run a separete kernel to cpy src->dst
-                            // not sure how to avoid this
-                            // TODO: make a simpler cpy_bytes kernel
-
-                            const int nth = MIN((int) ctx->pipeline_cpy_f32_f32.maxTotalThreadsPerThreadgroup, ne00);
-
-                            [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                            [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
-                            [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
-                            [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
-                            [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
-                            [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
-                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
-                            [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
-                            [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                        }
+#ifndef GGML_METAL_NDEBUG
+            [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
+#endif

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            const int64_t ne00 = src0 ? src0->ne[0] : 0;
+            const int64_t ne01 = src0 ? src0->ne[1] : 0;
+            const int64_t ne02 = src0 ? src0->ne[2] : 0;
+            const int64_t ne03 = src0 ? src0->ne[3] : 0;
+
+            const uint64_t nb00 = src0 ? src0->nb[0] : 0;
+            const uint64_t nb01 = src0 ? src0->nb[1] : 0;
+            const uint64_t nb02 = src0 ? src0->nb[2] : 0;
+            const uint64_t nb03 = src0 ? src0->nb[3] : 0;
+
+            const int64_t ne10 = src1 ? src1->ne[0] : 0;
+            const int64_t ne11 = src1 ? src1->ne[1] : 0;
+            const int64_t ne12 = src1 ? src1->ne[2] : 0;
+            const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+
+            const uint64_t nb10 = src1 ? src1->nb[0] : 0;
+            const uint64_t nb11 = src1 ? src1->nb[1] : 0;
+            const uint64_t nb12 = src1 ? src1->nb[2] : 0;
+            const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+
+            const int64_t ne0 = dst ? dst->ne[0] : 0;
+            const int64_t ne1 = dst ? dst->ne[1] : 0;
+            const int64_t ne2 = dst ? dst->ne[2] : 0;
+            const int64_t ne3 = dst ? dst->ne[3] : 0;
+
+            const uint64_t nb0 = dst ? dst->nb[0] : 0;
+            const uint64_t nb1 = dst ? dst->nb[1] : 0;
+            const uint64_t nb2 = dst ? dst->nb[2] : 0;
+            const uint64_t nb3 = dst ? dst->nb[3] : 0;
+
+            const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+            const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+            const enum ggml_type dstt  = dst  ? dst->type  : GGML_TYPE_COUNT;
+
+            id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
+            id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
+            id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(ctx, dst,  &offs_dst)  : nil;
+
+            //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+            //if (src0) {
+            //  GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
+            //      ggml_is_contiguous(src0), src0->name);
+            //}
+            //if (src1) {
+            //  GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
+            //      ggml_is_contiguous(src1), src1->name);
+            //}
+            //if (dst) {
+            //  GGML_METAL_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
+            //      dst->name);
+            //}
+
+            switch (dst->op) {
+                case GGML_OP_CONCAT:
+                    {
+                        const int64_t nb = ne00;
+
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                        [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+                        [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+                        [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                        [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                        [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                        [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                        [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
+                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
+                        [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
+                        [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
+                        [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
+                        [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
+                        [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
+                        [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
+                        [encoder setBytes:&nb   length:sizeof(nb)   atIndex:27];
+
+                        const int nth = MIN(1024, ne0);
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    } break;
+                case GGML_OP_ADD:
+                case GGML_OP_MUL:
+                case GGML_OP_DIV:
+                    {
+                        const size_t offs = 0;
+
+                        bool bcast_row = false;
+
+                        int64_t nb = ne00;
+
+                        id<MTLComputePipelineState> pipeline = nil;
+
+                        if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
                             GGML_ASSERT(ggml_is_contiguous(src0));

-
+                            // src1 is a row
+                            GGML_ASSERT(ne11 == 1);

-
+                            nb = ne00 / 4;
+                            switch (dst->op) {
+                                case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
+                                case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
+                                case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
+                                default: GGML_ASSERT(false);
+                            }

-
-
-
-
-
+                            bcast_row = true;
+                        } else {
+                            switch (dst->op) {
+                                case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
+                                case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
+                                case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
+                                default: GGML_ASSERT(false);
                             }
+                        }

-
-
-
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                        [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+                        [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+                        [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                        [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                        [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                        [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                        [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
+                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
+                        [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
+                        [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
+                        [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
+                        [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
+                        [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
+                        [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
+                        [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+                        [encoder setBytes:&nb   length:sizeof(nb)   atIndex:28];
+
+                        if (bcast_row) {
+                            const int64_t n = ggml_nelements(dst)/4;

                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                        }
-
-                    switch (ggml_get_unary_op(gf->nodes[i])) {
-                        case GGML_UNARY_OP_TANH:
-                            {
-                                [encoder setComputePipelineState:ctx->pipeline_tanh];
-                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                        } else {
+                            const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);

-
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        }
+                    } break;
+                case GGML_OP_ACC:
+                    {
+                        GGML_ASSERT(src0t == GGML_TYPE_F32);
+                        GGML_ASSERT(src1t == GGML_TYPE_F32);
+                        GGML_ASSERT(dstt  == GGML_TYPE_F32);

-
-
-                        case GGML_UNARY_OP_RELU:
-                            {
-                                [encoder setComputePipelineState:ctx->pipeline_relu];
-                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+                        GGML_ASSERT(ggml_is_contiguous(src1));

-
+                        const size_t pnb1 = ((int32_t *) dst->op_params)[0];
+                        const size_t pnb2 = ((int32_t *) dst->op_params)[1];
+                        const size_t pnb3 = ((int32_t *) dst->op_params)[2];
+                        const size_t offs = ((int32_t *) dst->op_params)[3];

-
-                            } break;
-                        case GGML_UNARY_OP_GELU:
-                            {
-                                [encoder setComputePipelineState:ctx->pipeline_gelu];
-                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                        const bool inplace = (bool) ((int32_t *) dst->op_params)[4];

-
-
+                        if (!inplace) {
+                            // run a separete kernel to cpy src->dst
+                            // not sure how to avoid this
+                            // TODO: make a simpler cpy_bytes kernel

-
-                            } break;
-                        case GGML_UNARY_OP_GELU_QUICK:
-                            {
-                                [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
-                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;

-
-
+                            [encoder setComputePipelineState:pipeline];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
+                            [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
+                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];

-
-                            } break;
-                        case GGML_UNARY_OP_SILU:
-                            {
-                                [encoder setComputePipelineState:ctx->pipeline_silu];
-                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);

-
-
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        }

-
-
-
-
-
-
-
-
-
-
-
+                        const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                        [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                        [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
+                        [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
+                        [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
+                        [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                        [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                        [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                        [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                        [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
+                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
+                        [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
+                        [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
+                        [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
+                        [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
+                        [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
+                        [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
+                        [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+
+                        const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    } break;
+                case GGML_OP_SCALE:
+                    {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
+                        const float scale = *(const float *) dst->op_params;
+
+                        int64_t n = ggml_nelements(dst);
+
+                        id<MTLComputePipelineState> pipeline = nil;
+
+                        if (n % 4 == 0) {
+                            n /= 4;
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
+                        } else {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
+                        }

-
-
-
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                        [encoder setBytes:&scale length:sizeof(scale) atIndex:2];

-
-
-
-
-
-
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                case GGML_OP_UNARY:
+                    switch (ggml_get_unary_op(gf->nodes[i])) {
+                        case GGML_UNARY_OP_TANH:
+                            {
+                                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;

-
-
-
-                                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
-                                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
-                                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
-                                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
-                                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
-                                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
-                                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
-                                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
-                                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:18];
-                                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:19];
-                                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:20];
-                                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:21];
-                                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:22];
-                                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:23];
-                                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:24];
-                                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:25];
-
-                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                            } break;
-                case GGML_OP_SOFT_MAX:
-                    {
-                        int nth = 32; // SIMD width
-
-                        if (ne00%4 == 0) {
-                            while (nth < ne00/4 && nth < 256) {
-                                nth *= 2;
-                            }
-                            [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
-                        } else {
-                            while (nth < ne00 && nth < 1024) {
-                                nth *= 2;
-                            }
-                            [encoder setComputePipelineState:ctx->pipeline_soft_max];
-                        }
+                                [encoder setComputePipelineState:pipeline];
+                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];

-
+                                const int64_t n = ggml_nelements(dst);

-
-
-
-
-
-                        }
-                        [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                        [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
-                        [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                    } break;
-                case GGML_OP_DIAG_MASK_INF:
-                    {
-                        const int n_past = ((int32_t *)(dst->op_params))[0];
-
-                        if (ne00%8 == 0) {
-                            [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
-                        } else {
-                            [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
-                        }
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                        [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
+                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            } break;
+                        case GGML_UNARY_OP_RELU:
+                            {
+                                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;

-
-                        [encoder
-
-                        else {
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                        }
-                    } break;
-                case GGML_OP_MUL_MAT:
-                    {
-                        GGML_ASSERT(ne00 == ne10);
+                                [encoder setComputePipelineState:pipeline];
+                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];

-
-                        GGML_ASSERT(ne12 % ne02 == 0);
-                        GGML_ASSERT(ne13 % ne03 == 0);
+                                const int64_t n = ggml_nelements(dst);

-
-
+                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            } break;
+                        case GGML_UNARY_OP_GELU:
+                            {
+                                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;

-
-
-
+                                [encoder setComputePipelineState:pipeline];
+                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];

-
-
-                        // these numbers do not translate to other devices or model sizes
-                        // TODO: need to find a better approach
-                        if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
-                            switch (src0t) {
-                                case GGML_TYPE_F16:  ne11_mm_min = 2;  break;
-                                case GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;
-                                case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
-                                case GGML_TYPE_Q3_K: ne11_mm_min = 7;  break;
-                                case GGML_TYPE_Q4_0:
-                                case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
-                                case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
-                                case GGML_TYPE_Q5_0:                        // not tested yet
-                                case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
-                                case GGML_TYPE_Q5_K: ne11_mm_min = 7;  break;
-                                case GGML_TYPE_Q6_K: ne11_mm_min = 7;  break;
-                                default:             ne11_mm_min = 1;  break;
-                            }
-                        }
-#endif
+                                const int64_t n = ggml_nelements(dst);
+                                GGML_ASSERT(n % 4 == 0);

-
-
-
-
-
-                            src1t == GGML_TYPE_F32 &&
-                            ne00 % 32 == 0 && ne00 >= 64 &&
-                            (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
-                            //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-                            switch (src0->type) {
-                                case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32];  break;
-                                case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
-                                case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
-                                case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
-                                case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
-                                case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
-                                case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
-                                case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
-                                case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
-                                case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
-                                case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
-                                case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
-                                default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
-                            }
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
-                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
-                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
-                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
-                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
-                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
-                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:11];
-                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:12];
-                            [encoder setBytes:&r2   length:sizeof(r2)   atIndex:13];
-                            [encoder setBytes:&r3   length:sizeof(r3)   atIndex:14];
-                            [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                            [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
-                        } else {
-                            int nth0 = 32;
-                            int nth1 = 1;
-                            int nrows = 1;
-                            //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
-                            // use custom matrix x vector kernel
-                            switch (src0t) {
-                                case GGML_TYPE_F32:
-                                    {
-                                        GGML_ASSERT(src1t == GGML_TYPE_F32);
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
-                                        nrows = 4;
-                                    } break;
-                                case GGML_TYPE_F16:
-                                    {
-                                        nth0 = 32;
-                                        nth1 = 1;
-                                        if (src1t == GGML_TYPE_F32) {
-                                            if (ne11 * ne12 < 4) {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
-                                            } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
-                                                nrows = ne11;
-                                            } else {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
-                                                nrows = 4;
-                                            }
-                                        } else {
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
-                                            nrows = 4;
-                                        }
-                                    } break;
-                                case GGML_TYPE_Q4_0:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
-                                    } break;
-                                case GGML_TYPE_Q4_1:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
-                                    } break;
-                                case GGML_TYPE_Q5_0:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
-                                    } break;
-                                case GGML_TYPE_Q5_1:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
-                                    } break;
-                                case GGML_TYPE_Q8_0:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
-                                    } break;
-                                case GGML_TYPE_Q2_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
-                                    } break;
-                                case GGML_TYPE_Q3_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
-                                    } break;
-                                case GGML_TYPE_Q4_K:
-                                    {
-                                        nth0 = 4; //1;
-                                        nth1 = 8; //32;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
-                                    } break;
-                                case GGML_TYPE_Q5_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
-                                    } break;
-                                case GGML_TYPE_Q6_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
-                                    } break;
-                                default:
-                                    {
-                                        GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
-                                        GGML_ASSERT(false && "not implemented");
-                                    }
-                            };
-
-                            if (ggml_is_quantized(src0t)) {
-                                GGML_ASSERT(ne00 >= nth0*nth1);
-                            }
+                                [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            } break;
+                        case GGML_UNARY_OP_GELU_QUICK:
+                            {
+                                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;

+                                [encoder setComputePipelineState:pipeline];
                                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
-                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
-                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
-                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
-                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
-                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
-                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15];
-                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
-                            [encoder setBytes:&r2   length:sizeof(r2)   atIndex:17];
-                            [encoder setBytes:&r3   length:sizeof(r3)   atIndex:18];
-
-                            if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
-                                src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
-                                src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_Q4_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_Q3_K) {
-#ifdef GGML_QKK_64
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-#else
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-#endif
-                            }
-                            else if (src0t == GGML_TYPE_Q5_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_Q6_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            } else {
-                                const int64_t ny = (ne11 + nrows - 1)/nrows;
-                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                        }
-                    } break;
-                case GGML_OP_MUL_MAT_ID:
-                    {
-                        //GGML_ASSERT(ne00 == ne10);
-                        //GGML_ASSERT(ne03 == ne13);
-
-                        GGML_ASSERT(src0t == GGML_TYPE_I32);
-
-                        const int n_as = ((int32_t *) dst->op_params)[1];
-
-                        // TODO: make this more general
-                        GGML_ASSERT(n_as <= 8);
-
-                        // max size of the src1ids array in the kernel stack
-                        GGML_ASSERT(ne11 <= 512);
-
-                        struct ggml_tensor * src2 = gf->nodes[i]->src[2];
-
-                        const int64_t  ne20 = src2 ? src2->ne[0] : 0;
-                        const int64_t  ne21 = src2 ? src2->ne[1] : 0;
-                        const int64_t  ne22 = src2 ? src2->ne[2] : 0;
-                        const int64_t  ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
-
-                        const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
-                        const uint64_t nb21 = src2 ? src2->nb[1] : 0;
-                        const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-                        const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
-
-                        const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
-
-                        GGML_ASSERT(!ggml_is_transposed(src2));
-                        GGML_ASSERT(!ggml_is_transposed(src1));
-
-                        GGML_ASSERT(src1t == GGML_TYPE_F32);
-
-                        const uint r2 = ne12/ne22;
-                        const uint r3 = ne13/ne23;
-
-                        // find the break-even point where the matrix-matrix kernel becomes more efficient compared
-                        // to the matrix-vector kernel
-                        int ne11_mm_min = n_as;
-
-                        const int idx = ((int32_t *) dst->op_params)[0];
-
-                        // batch size
-                        GGML_ASSERT(ne01 == ne11);
-
-                        // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
-                        // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                        // !!!
-                        // TODO: for now, always use mat-vec kernels until we figure out how to improve the
-                        //       indirect matrix multiplication
-                        // !!!
-                        if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-                            ne20 % 32 == 0 && ne20 >= 64 &&
-                            ne11 > ne11_mm_min) {
-                            switch (src2->type) {
-                                case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32];  break;
-                                case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32];  break;
-                                case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
-                                case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
-                                case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
-                                case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
-                                case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
-                                case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
-                                case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
-                                case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
-                                case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
-                                case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
-                                default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
-                            }
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
-                            [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
-                            [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
-                            [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
-                            [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
-                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
-                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
-                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
-                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
-                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
-                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
-                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
-                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
-                            [encoder setBytes:&r2   length:sizeof(r2)   atIndex:16];
-                            [encoder setBytes:&r3   length:sizeof(r3)   atIndex:17];
-                            [encoder setBytes:&idx  length:sizeof(idx)  atIndex:18];
-                            // TODO: how to make this an array? read Metal docs
-                            for (int j = 0; j < 8; ++j) {
-                                // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
-                                struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
-
-                                size_t offs_src_cur = 0;
-                                id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
-
-                                [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
-                            }
-
-                            [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-
-                            [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
-                        } else {
-                            int nth0 = 32;
-                            int nth1 = 1;
-                            int nrows = 1;
-                            //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
-                            // use custom matrix x vector kernel
-                            switch (src2t) {
-                                case GGML_TYPE_F32:
-                                    {
-                                        GGML_ASSERT(src1t == GGML_TYPE_F32);
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
-                                    } break;
-                                case GGML_TYPE_F16:
-                                    {
-                                        GGML_ASSERT(src1t == GGML_TYPE_F32);
-                                        nth0 = 32;
-                                        nth1 = 1;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
-                                    } break;
-                                case GGML_TYPE_Q4_0:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
-                                    } break;
-                                case GGML_TYPE_Q4_1:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
-                                    } break;
-                                case GGML_TYPE_Q5_0:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
-                                    } break;
-                                case GGML_TYPE_Q5_1:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
-                                    } break;
-                                case GGML_TYPE_Q8_0:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
-                                    } break;
-                                case GGML_TYPE_Q2_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
-                                    } break;
-                                case GGML_TYPE_Q3_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
-                                    } break;
-                                case GGML_TYPE_Q4_K:
-                                    {
-                                        nth0 = 4; //1;
-                                        nth1 = 8; //32;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
-                                    } break;
-                                case GGML_TYPE_Q5_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
-                                    } break;
-                                case GGML_TYPE_Q6_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
-                                    } break;
-                                default:
-                                    {
-                                        GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
-                                        GGML_ASSERT(false && "not implemented");
-                                    }
-                            };
+                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
|
|
1902
1110
|
|
|
1903
|
-
|
|
1904
|
-
|
|
1905
|
-
}
|
|
1111
|
+
const int64_t n = ggml_nelements(dst);
|
|
1112
|
+
GGML_ASSERT(n % 4 == 0);
|
|
1906
1113
|
|
|
1907
|
-
|
|
1114
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1115
|
+
} break;
|
|
1116
|
+
case GGML_UNARY_OP_SILU:
|
|
1117
|
+
{
|
|
1118
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
|
|
1908
1119
|
|
|
1120
|
+
[encoder setComputePipelineState:pipeline];
|
|
1909
1121
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1910
|
-
[encoder setBuffer:
|
|
1911
|
-
|
|
1912
|
-
|
|
1913
|
-
|
|
1914
|
-
|
|
1915
|
-
[encoder
|
|
1916
|
-
|
|
1917
|
-
|
|
1918
|
-
|
|
1919
|
-
|
|
1920
|
-
|
|
1921
|
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
|
1922
|
-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
|
1923
|
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
|
1924
|
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
|
1925
|
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
|
1926
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
|
1927
|
-
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
|
|
1928
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
|
|
1929
|
-
[encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
|
|
1930
|
-
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
|
|
1931
|
-
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
|
|
1932
|
-
// TODO: how to make this an array? read Metal docs
|
|
1933
|
-
for (int j = 0; j < 8; ++j) {
|
|
1934
|
-
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
|
1935
|
-
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
|
1936
|
-
|
|
1937
|
-
size_t offs_src_cur = 0;
|
|
1938
|
-
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
|
1939
|
-
|
|
1940
|
-
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
|
1941
|
-
}
|
|
1942
|
-
|
|
1943
|
-
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
|
1944
|
-
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
|
1945
|
-
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
|
|
1946
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1947
|
-
}
|
|
1948
|
-
else if (src2t == GGML_TYPE_Q4_K) {
|
|
1949
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1950
|
-
}
|
|
1951
|
-
else if (src2t == GGML_TYPE_Q3_K) {
|
|
1952
|
-
#ifdef GGML_QKK_64
|
|
1953
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1954
|
-
#else
|
|
1955
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1956
|
-
#endif
|
|
1957
|
-
}
|
|
1958
|
-
else if (src2t == GGML_TYPE_Q5_K) {
|
|
1959
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1960
|
-
}
|
|
1961
|
-
else if (src2t == GGML_TYPE_Q6_K) {
|
|
1962
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1963
|
-
} else {
|
|
1964
|
-
const int64_t ny = (_ne1 + nrows - 1)/nrows;
|
|
1965
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1966
|
-
}
|
|
1122
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1123
|
+
|
|
1124
|
+
const int64_t n = ggml_nelements(dst);
|
|
1125
|
+
GGML_ASSERT(n % 4 == 0);
|
|
1126
|
+
|
|
1127
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1128
|
+
} break;
|
|
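Note on the unary dispatch above: the GELU, GELU_QUICK and SILU kernels are written against float4, so the element count n must be a multiple of 4 and the grid is n/4 single-thread threadgroups. A minimal C sketch of that grid math, with a scalar SILU reference for what each lane computes; the helper names here are illustrative, not part of ggml:

    // Sketch only: mirrors the grid math used by the float4 unary kernels above.
    #include <assert.h>
    #include <math.h>
    #include <stdio.h>

    // What one lane computes for SILU, written per scalar element.
    static float silu_ref(float x) { return x / (1.0f + expf(-x)); }

    // Grid used by the encoder: one single-thread threadgroup per float4.
    static long unary_threadgroups(long n_elements) {
        assert(n_elements % 4 == 0); // same precondition as GGML_ASSERT(n % 4 == 0)
        return n_elements / 4;
    }

    int main(void) {
        printf("threadgroups for 4096 elements: %ld\n", unary_threadgroups(4096));
        printf("silu(1.0) = %f\n", silu_ref(1.0f));
        return 0;
    }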
1129
|
+
default:
|
|
1130
|
+
{
|
|
1131
|
+
GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
|
1132
|
+
GGML_ASSERT(false);
|
|
1967
1133
|
}
|
|
1968-1985
|
-
|
|
1134
|
+
} break;
|
|
1135
|
+
case GGML_OP_SQR:
|
|
1136
|
+
{
|
|
1137
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
1138
|
+
|
|
1139
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
|
|
1140
|
+
|
|
1141
|
+
[encoder setComputePipelineState:pipeline];
|
|
1142
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1143
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1144
|
+
|
|
1145
|
+
const int64_t n = ggml_nelements(dst);
|
|
1146
|
+
|
|
1147
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1148
|
+
} break;
|
|
1149
|
+
case GGML_OP_SUM_ROWS:
|
|
1150
|
+
{
|
|
1151
|
+
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
|
1152
|
+
|
|
1153
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
|
1154
|
+
|
|
1155
|
+
[encoder setComputePipelineState:pipeline];
|
|
1156
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1157
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1158
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
1159
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
1160
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
1161
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
|
1162
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
1163
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
1164
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
1165
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
|
1166
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
|
1167
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
|
1168
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
|
1169
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
|
1170
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
|
1171
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
|
1172
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
|
1173
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
|
1174
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
|
|
1175
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
|
|
1176
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
|
|
1177
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
|
|
1178
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
|
|
1179
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
|
|
1180
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
|
|
1181
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
|
|
1182
|
+
|
|
1183
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1184
|
+
} break;
|
|
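The SUM_ROWS dispatch above launches one single-thread threadgroup per source row (ne01*ne02*ne03 in total), each reducing ne00 values along dimension 0. A scalar C reference of that reduction, assuming a contiguous f32 layout (the kernel itself works from the nb* strides bound above):

    // Sketch only: what each SUM_ROWS threadgroup computes, per row.
    #include <stdio.h>

    static void sum_rows_ref(const float *src, float *dst, int ne00, int nrows) {
        for (int r = 0; r < nrows; ++r) {      // one Metal threadgroup per row
            float acc = 0.0f;
            for (int i = 0; i < ne00; ++i) {
                acc += src[r*ne00 + i];
            }
            dst[r] = acc;                       // the result row has a single element
        }
    }

    int main(void) {
        const float src[2*4] = {1,2,3,4, 5,6,7,8};
        float dst[2];
        sum_rows_ref(src, dst, 4, 2);
        printf("%f %f\n", dst[0], dst[1]);      // 10.000000 26.000000
        return 0;
    }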
1185
|
+
case GGML_OP_SOFT_MAX:
|
|
1186
|
+
{
|
|
1187
|
+
int nth = 32; // SIMD width
|
|
1188
|
+
|
|
1189
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
1190
|
+
|
|
1191
|
+
if (ne00%4 == 0) {
|
|
1192
|
+
while (nth < ne00/4 && nth < 256) {
|
|
1193
|
+
nth *= 2;
|
|
1986
1194
|
}
|
|
1987
|
-
|
|
1988
|
-
|
|
1989
|
-
|
|
1990
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1991
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
|
1992
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
|
1993
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
|
1994
|
-
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
|
1995
|
-
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
|
|
1996
|
-
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
|
|
1997
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
|
|
1998
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
|
|
1999
|
-
|
|
2000
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
|
2001
|
-
} break;
|
|
2002
|
-
case GGML_OP_RMS_NORM:
|
|
2003
|
-
{
|
|
2004
|
-
GGML_ASSERT(ne00 % 4 == 0);
|
|
2005
|
-
|
|
2006
|
-
float eps;
|
|
2007
|
-
memcpy(&eps, dst->op_params, sizeof(float));
|
|
2008
|
-
|
|
2009
|
-
int nth = 32; // SIMD width
|
|
2010
|
-
|
|
2011
|
-
while (nth < ne00/4 && nth < 1024) {
|
|
1195
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
|
|
1196
|
+
} else {
|
|
1197
|
+
while (nth < ne00 && nth < 1024) {
|
|
2012
1198
|
nth *= 2;
|
|
2013
1199
|
}
|
|
1200
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
|
|
1201
|
+
}
|
|
2014
1202
|
|
|
2015
|
-
|
|
2016
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2017
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2018
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
2019
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
|
2020
|
-
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
|
2021
|
-
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
2022
|
-
|
|
2023
|
-
const int64_t nrows = ggml_nrows(src0);
|
|
2024
|
-
|
|
2025
|
-
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2026
|
-
} break;
|
|
2027
|
-
case GGML_OP_GROUP_NORM:
|
|
2028
|
-
{
|
|
2029
|
-
GGML_ASSERT(ne00 % 4 == 0);
|
|
2030
|
-
|
|
2031
|
-
//float eps;
|
|
2032
|
-
//memcpy(&eps, dst->op_params, sizeof(float));
|
|
2033
|
-
|
|
2034
|
-
const float eps = 1e-6f; // TODO: temporarily hardcoded
|
|
2035
|
-
|
|
2036
|
-
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
|
|
2037
|
-
|
|
2038
|
-
int nth = 32; // SIMD width
|
|
2039
|
-
|
|
2040
|
-
//while (nth < ne00/4 && nth < 1024) {
|
|
2041
|
-
// nth *= 2;
|
|
2042
|
-
//}
|
|
2043
|
-
|
|
2044
|
-
[encoder setComputePipelineState:ctx->pipeline_group_norm];
|
|
2045
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2046
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2047
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
2048
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
2049
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
2050
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
|
|
2051
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
|
|
2052
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
|
|
2053
|
-
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
|
|
2054
|
-
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
|
|
2055
|
-
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
2056
|
-
|
|
2057
|
-
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2058
|
-
} break;
|
|
2059
|
-
case GGML_OP_NORM:
|
|
2060
|
-
{
|
|
2061
|
-
float eps;
|
|
2062
|
-
memcpy(&eps, dst->op_params, sizeof(float));
|
|
2063
|
-
|
|
2064
|
-
const int nth = MIN(256, ne00);
|
|
2065
|
-
|
|
2066
|
-
[encoder setComputePipelineState:ctx->pipeline_norm];
|
|
2067
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2068
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2069
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
2070
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
|
2071
|
-
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
|
2072
|
-
[encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
|
|
1203
|
+
const float scale = ((float *) dst->op_params)[0];
|
|
2073
1204
|
|
|
2074
|
-
|
|
1205
|
+
[encoder setComputePipelineState:pipeline];
|
|
1206
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1207
|
+
if (id_src1) {
|
|
1208
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1209
|
+
} else {
|
|
1210
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
1211
|
+
}
|
|
1212
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1213
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
1214
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
1215
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
1216
|
+
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
|
1217
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
1218
|
+
|
|
1219
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1220
|
+
} break;
|
|
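For SOFT_MAX the threads-per-threadgroup count starts at the 32-wide SIMD width and doubles while it stays below the per-row work, capped at 256 for the float4 kernel and 1024 for the scalar one; the row scale comes from op_params[0], and when there is no mask src0 is bound again at index 1 so the argument table has no unbound slot. A small C sketch of the nth selection, mirroring the loop above:

    // Sketch only: threads-per-threadgroup selection used by SOFT_MAX above.
    #include <stdbool.h>
    #include <stdio.h>

    static int soft_max_nth(long ne00) {
        int  nth  = 32;                          // Apple GPU SIMD width
        bool vec4 = (ne00 % 4 == 0);             // picks SOFT_MAX_4 vs SOFT_MAX
        const long work = vec4 ? ne00/4 : ne00;
        const int  cap  = vec4 ? 256    : 1024;
        while (nth < work && nth < cap) {
            nth *= 2;
        }
        return nth;
    }

    int main(void) {
        printf("ne00=4096 -> nth=%d\n", soft_max_nth(4096)); // 256 (float4 path)
        printf("ne00=33   -> nth=%d\n", soft_max_nth(33));   // 64  (scalar path)
        return 0;
    }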
1221
|
+
case GGML_OP_DIAG_MASK_INF:
|
|
1222
|
+
{
|
|
1223
|
+
const int n_past = ((int32_t *)(dst->op_params))[0];
|
|
1224
|
+
|
|
1225
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
1226
|
+
|
|
1227
|
+
if (ne00%8 == 0) {
|
|
1228
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
|
|
1229
|
+
} else {
|
|
1230
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
|
|
1231
|
+
}
|
|
2075
1232
|
|
|
2076
|
-
|
|
2077
|
-
|
|
2078
|
-
|
|
2079
|
-
|
|
2080
|
-
|
|
1233
|
+
[encoder setComputePipelineState:pipeline];
|
|
1234
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1235
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1236
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
1237
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
1238
|
+
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
|
2081
1239
|
|
|
2082
|
-
|
|
1240
|
+
if (ne00%8 == 0) {
|
|
1241
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1242
|
+
}
|
|
1243
|
+
else {
|
|
1244
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
1245
|
+
}
|
|
1246
|
+
} break;
|
|
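DIAG_MASK_INF picks the 8-elements-per-thread variant when ne00 is divisible by 8 (dispatching ne00*ne01*ne02/8 groups) and otherwise falls back to one element per threadgroup. A scalar C reference of the masking itself, assuming the usual causal semantics of this op (entries whose column index exceeds n_past plus the row index are set to -INF):

    // Sketch only: scalar reference of the causal masking the kernels above apply.
    #include <math.h>
    #include <stdio.h>

    static void diag_mask_inf_ref(float *x, int ne00, int ne01, int n_past) {
        for (int i = 0; i < ne01; ++i) {          // rows
            for (int j = 0; j < ne00; ++j) {      // columns
                if (j > n_past + i) {
                    x[i*ne00 + j] = -INFINITY;
                }
            }
        }
    }

    int main(void) {
        float x[3*3] = {1,1,1, 1,1,1, 1,1,1};
        diag_mask_inf_ref(x, 3, 3, 0);
        for (int i = 0; i < 9; ++i) printf("%g%c", x[i], i%3==2 ? '\n' : ' ');
        return 0;
    }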
1247
|
+
case GGML_OP_MUL_MAT:
|
|
1248
|
+
{
|
|
1249
|
+
GGML_ASSERT(ne00 == ne10);
|
|
2083
1250
|
|
|
2084
|
-
|
|
2085
|
-
|
|
2086
|
-
|
|
2087
|
-
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
|
1251
|
+
// TODO: assert that dim2 and dim3 are contiguous
|
|
1252
|
+
GGML_ASSERT(ne12 % ne02 == 0);
|
|
1253
|
+
GGML_ASSERT(ne13 % ne03 == 0);
|
|
2088
1254
|
|
|
2089
|
-
|
|
2090
|
-
|
|
2091
|
-
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
|
1255
|
+
const uint r2 = ne12/ne02;
|
|
1256
|
+
const uint r3 = ne13/ne03;
|
|
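The r2 and r3 ratios above are the batch broadcast factors for dimensions 2 and 3: the asserts that precede them require ne12 and ne13 to be whole multiples of ne02 and ne03, so each src0 matrix is reused for r2 (respectively r3) consecutive src1 batches. A tiny C sketch of that index mapping; the head counts used are illustrative only:

    // Sketch only: how r2/r3 map a src1 batch index back onto a src0 batch.
    #include <stdio.h>

    int main(void) {
        const long ne02 = 8,  ne03 = 1;     // src0 batches (illustrative)
        const long ne12 = 32, ne13 = 1;     // src1 batches (illustrative)
        const long r2 = ne12/ne02, r3 = ne13/ne03;

        for (long i13 = 0; i13 < ne13; ++i13) {
            for (long i12 = 0; i12 < ne12; i12 += 8) {   // print a subset
                const long i02 = i12 / r2;               // src0 matrix reused r2 times
                const long i03 = i13 / r3;
                printf("src1 batch (%ld,%ld) -> src0 batch (%ld,%ld)\n", i12, i13, i02, i03);
            }
        }
        return 0;
    }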
2092
1257
|
|
|
2093
|
-
|
|
2094
|
-
|
|
2095
|
-
|
|
2096
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
2097
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
2098
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
2099
|
-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
|
2100
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
|
2101
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
|
2102
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
|
2103
|
-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
|
2104
|
-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
|
2105
|
-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
|
2106
|
-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
|
2107
|
-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
|
2108
|
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
|
2109
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
|
2110
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
|
2111
|
-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
|
2112
|
-
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
|
2113
|
-
[encoder setBytes:&m1 length:sizeof( float) atIndex:19];
|
|
2114
|
-
[encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
|
|
1258
|
+
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
|
1259
|
+
// to the matrix-vector kernel
|
|
1260
|
+
int ne11_mm_min = 1;
|
|
2115
1261
|
|
|
2116-2136
|
-
|
|
1262
|
+
#if 0
|
|
1263
|
+
// the numbers below are measured on M2 Ultra for 7B and 13B models
|
|
1264
|
+
// these numbers do not translate to other devices or model sizes
|
|
1265
|
+
// TODO: need to find a better approach
|
|
1266
|
+
if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
|
|
1267
|
+
switch (src0t) {
|
|
1268
|
+
case GGML_TYPE_F16: ne11_mm_min = 2; break;
|
|
1269
|
+
case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
|
|
1270
|
+
case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
|
|
1271
|
+
case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
|
|
1272
|
+
case GGML_TYPE_Q4_0:
|
|
1273
|
+
case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
|
|
1274
|
+
case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
|
|
1275
|
+
case GGML_TYPE_Q5_0: // not tested yet
|
|
1276
|
+
case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
|
|
1277
|
+
case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
|
|
1278
|
+
case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
|
|
1279
|
+
default: ne11_mm_min = 1; break;
|
|
1280
|
+
}
|
|
1281
|
+
}
|
|
1282
|
+
#endif
|
|
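ne11_mm_min is the break-even threshold below which the matrix-vector kernels win; it is left at 1 here, and the disabled block above keeps per-type values measured on a single device only as a reference. The actual mat-mat versus mat-vec decision is the condition that follows this block; lifted into a standalone C predicate it reads as below (a sketch, with plain booleans standing in for the device and tensor checks):

    // Sketch only: the mat-mat vs mat-vec selection condition, as a predicate.
    #include <stdbool.h>
    #include <stdio.h>

    static bool use_mul_mm(bool apple7_family, bool src0_transposed, bool src1_transposed,
                           bool src1_is_f32, bool src0_quantized,
                           long ne00, long ne11, long ne12, int ne11_mm_min) {
        return apple7_family &&
               !src0_transposed && !src1_transposed &&
               src1_is_f32 &&
               ne00 % 32 == 0 && ne00 >= 64 &&
               (ne11 > ne11_mm_min || (src0_quantized && ne12 > 1));
    }

    int main(void) {
        // Single-token decode (ne11 == 1) on a quantized weight stays on mat-vec
        // unless there is more than one batch.
        printf("%d\n", use_mul_mm(true, false, false, true, true, 4096, 1, 1, 1)); // 0
        printf("%d\n", use_mul_mm(true, false, false, true, true, 4096, 1, 2, 1)); // 1
        return 0;
    }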
1283
|
+
|
|
1284
|
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
|
1285
|
+
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
|
1286
|
+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
|
1287
|
+
!ggml_is_transposed(src0) &&
|
|
1288
|
+
!ggml_is_transposed(src1) &&
|
|
1289
|
+
src1t == GGML_TYPE_F32 &&
|
|
1290
|
+
ne00 % 32 == 0 && ne00 >= 64 &&
|
|
1291
|
+
(ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
|
|
1292
|
+
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
1293
|
+
|
|
1294
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2137
1295
|
|
|
2138
1296
|
switch (src0->type) {
|
|
2139
|
-
case GGML_TYPE_F32:
|
|
2140
|
-
case GGML_TYPE_F16:
|
|
2141
|
-
|
|
1297
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
|
|
1298
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
|
|
1299
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
|
|
1300
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
|
|
1301
|
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
|
|
1302
|
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
|
|
1303
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
|
|
1304
|
+
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
|
|
1305
|
+
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
|
|
1306
|
+
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
|
|
1307
|
+
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
|
|
1308
|
+
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
|
|
1309
|
+
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
|
1310
|
+
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
|
1311
|
+
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
|
1312
|
+
}
|
|
1313
|
+
|
|
1314
|
+
[encoder setComputePipelineState:pipeline];
|
|
1315
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1316
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1317
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1318
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
1319
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
1320
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
|
1321
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
|
1322
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
|
1323
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
|
|
1324
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
|
|
1325
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
|
1326
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
|
1327
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
|
1328
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
|
|
1329
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
|
|
1330
|
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
|
1331
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
|
1332
|
+
} else {
|
|
1333
|
+
int nth0 = 32;
|
|
1334
|
+
int nth1 = 1;
|
|
1335
|
+
int nrows = 1;
|
|
1336
|
+
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
1337
|
+
|
|
1338
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
1339
|
+
|
|
1340
|
+
// use custom matrix x vector kernel
|
|
1341
|
+
switch (src0t) {
|
|
1342
|
+
case GGML_TYPE_F32:
|
|
1343
|
+
{
|
|
1344
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
|
1345
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
|
|
1346
|
+
nrows = 4;
|
|
1347
|
+
} break;
|
|
1348
|
+
case GGML_TYPE_F16:
|
|
1349
|
+
{
|
|
1350
|
+
nth0 = 32;
|
|
1351
|
+
nth1 = 1;
|
|
1352
|
+
if (src1t == GGML_TYPE_F32) {
|
|
1353
|
+
if (ne11 * ne12 < 4) {
|
|
1354
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
|
1355
|
+
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
|
1356
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
|
1357
|
+
nrows = ne11;
|
|
1358
|
+
} else {
|
|
1359
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
|
|
1360
|
+
nrows = 4;
|
|
1361
|
+
}
|
|
1362
|
+
} else {
|
|
1363
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
|
|
1364
|
+
nrows = 4;
|
|
1365
|
+
}
|
|
1366
|
+
} break;
|
|
1367
|
+
case GGML_TYPE_Q4_0:
|
|
1368
|
+
{
|
|
1369
|
+
nth0 = 8;
|
|
1370
|
+
nth1 = 8;
|
|
1371
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
|
|
1372
|
+
} break;
|
|
1373
|
+
case GGML_TYPE_Q4_1:
|
|
1374
|
+
{
|
|
1375
|
+
nth0 = 8;
|
|
1376
|
+
nth1 = 8;
|
|
1377
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
|
|
1378
|
+
} break;
|
|
1379
|
+
case GGML_TYPE_Q5_0:
|
|
1380
|
+
{
|
|
1381
|
+
nth0 = 8;
|
|
1382
|
+
nth1 = 8;
|
|
1383
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
|
|
1384
|
+
} break;
|
|
1385
|
+
case GGML_TYPE_Q5_1:
|
|
1386
|
+
{
|
|
1387
|
+
nth0 = 8;
|
|
1388
|
+
nth1 = 8;
|
|
1389
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
|
|
1390
|
+
} break;
|
|
1391
|
+
case GGML_TYPE_Q8_0:
|
|
1392
|
+
{
|
|
1393
|
+
nth0 = 8;
|
|
1394
|
+
nth1 = 8;
|
|
1395
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
|
|
1396
|
+
} break;
|
|
1397
|
+
case GGML_TYPE_Q2_K:
|
|
1398
|
+
{
|
|
1399
|
+
nth0 = 2;
|
|
1400
|
+
nth1 = 32;
|
|
1401
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
|
|
1402
|
+
} break;
|
|
1403
|
+
case GGML_TYPE_Q3_K:
|
|
1404
|
+
{
|
|
1405
|
+
nth0 = 2;
|
|
1406
|
+
nth1 = 32;
|
|
1407
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
|
|
1408
|
+
} break;
|
|
1409
|
+
case GGML_TYPE_Q4_K:
|
|
1410
|
+
{
|
|
1411
|
+
nth0 = 4; //1;
|
|
1412
|
+
nth1 = 8; //32;
|
|
1413
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
|
|
1414
|
+
} break;
|
|
1415
|
+
case GGML_TYPE_Q5_K:
|
|
1416
|
+
{
|
|
1417
|
+
nth0 = 2;
|
|
1418
|
+
nth1 = 32;
|
|
1419
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
|
|
1420
|
+
} break;
|
|
1421
|
+
case GGML_TYPE_Q6_K:
|
|
1422
|
+
{
|
|
1423
|
+
nth0 = 2;
|
|
1424
|
+
nth1 = 32;
|
|
1425
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
|
|
1426
|
+
} break;
|
|
1427
|
+
case GGML_TYPE_IQ2_XXS:
|
|
1428
|
+
{
|
|
1429
|
+
nth0 = 4;
|
|
1430
|
+
nth1 = 16;
|
|
1431
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
|
|
1432
|
+
} break;
|
|
1433
|
+
case GGML_TYPE_IQ2_XS:
|
|
1434
|
+
{
|
|
1435
|
+
nth0 = 4;
|
|
1436
|
+
nth1 = 16;
|
|
1437
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
|
|
1438
|
+
} break;
|
|
1439
|
+
default:
|
|
1440
|
+
{
|
|
1441
|
+
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
|
1442
|
+
GGML_ASSERT(false && "not implemented");
|
|
1443
|
+
}
|
|
2142
1444
|
};
|
|
2143
1445
|
|
|
2144
|
-
|
|
2145
|
-
|
|
2146
|
-
|
|
2147
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
|
2148
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
|
|
2149
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
|
|
2150
|
-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
|
|
2151
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
|
|
2152
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
|
2153
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
|
2154
|
-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
|
2155
|
-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
|
|
2156
|
-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
|
|
2157
|
-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
|
|
2158
|
-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
|
|
2159
|
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
|
|
2160
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
|
|
2161
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
|
|
2162
|
-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
|
|
2163
|
-
[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
|
|
2164
|
-
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
|
|
2165
|
-
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
|
|
2166
|
-
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
|
|
2167
|
-
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
|
|
2168
|
-
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
|
|
2169
|
-
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
|
|
2170
|
-
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
|
|
2171
|
-
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
|
|
2172
|
-
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
|
|
1446
|
+
if (ggml_is_quantized(src0t)) {
|
|
1447
|
+
GGML_ASSERT(ne00 >= nth0*nth1);
|
|
1448
|
+
}
|
|
2173
1449
|
|
|
2174
|
-
[encoder
|
|
2175
|
-
|
|
2176
|
-
|
|
2177
|
-
|
|
2178
|
-
|
|
2179
|
-
|
|
2180
|
-
|
|
1450
|
+
[encoder setComputePipelineState:pipeline];
|
|
1451
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1452
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1453
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1454
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
1455
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
1456
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
1457
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
1458
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
1459
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
1460
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
|
|
1461
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
|
|
1462
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
|
|
1463
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
|
|
1464
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
|
|
1465
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
|
1466
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
|
1467
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
|
1468
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
|
|
1469
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
|
|
1470
|
+
|
|
1471
|
+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
|
1472
|
+
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
|
1473
|
+
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
|
|
1474
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1475
|
+
}
|
|
1476
|
+
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
|
1477
|
+
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
|
1478
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
1479
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1480
|
+
}
|
|
1481
|
+
else if (src0t == GGML_TYPE_Q4_K) {
|
|
1482
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1483
|
+
}
|
|
1484
|
+
else if (src0t == GGML_TYPE_Q3_K) {
|
|
1485
|
+
#ifdef GGML_QKK_64
|
|
1486
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1487
|
+
#else
|
|
1488
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1489
|
+
#endif
|
|
1490
|
+
}
|
|
1491
|
+
else if (src0t == GGML_TYPE_Q5_K) {
|
|
1492
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1493
|
+
}
|
|
1494
|
+
else if (src0t == GGML_TYPE_Q6_K) {
|
|
1495
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1496
|
+
} else {
|
|
1497
|
+
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
|
1498
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1499
|
+
}
|
|
1500
|
+
}
|
|
1501
|
+
} break;
|
|
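The two dispatch shapes used in the MUL_MAT case above reduce to simple integer math: the matrix-matrix kernel gets one 128-thread threadgroup per 32-column by 64-row output tile, replicated across the ne12*ne13 batches and backed by 8 KiB of threadgroup memory, while the matrix-vector kernels get one threadgroup per group of 8, 4 or 2 output rows depending on the quantization type. The IQ2 variants additionally reserve 256*8+128 or 512*8+128 bytes of threadgroup memory, which is consistent with a table of 256 or 512 eight-byte grid entries plus a 128-byte sign table. A C sketch of the grid computation:

    // Sketch only: the threadgroup grids used by the MUL_MAT dispatch above.
    #include <stdio.h>

    typedef struct { long x, y, z; } grid3;

    // matrix-matrix kernel: one threadgroup per 32x64 output tile per batch.
    static grid3 mm_grid(long ne01, long ne11, long ne12, long ne13) {
        grid3 g = { (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13 };
        return g;
    }

    // matrix-vector kernels: one threadgroup per group of output rows; the group
    // size (8, 4 or 2) depends on the quantization type in the code above.
    static grid3 mv_grid(long ne01, long ne11, long ne12, long ne13, int rows_per_tg) {
        grid3 g = { (ne01 + rows_per_tg - 1)/rows_per_tg, ne11, ne12*ne13 };
        return g;
    }

    int main(void) {
        grid3 a = mm_grid(4096, 512, 1, 1);
        grid3 b = mv_grid(4096, 1, 1, 1, 8);                  // e.g. Q4_0 decode step
        printf("mm grid: %ld x %ld x %ld\n", a.x, a.y, a.z);  // 16 x 64 x 1
        printf("mv grid: %ld x %ld x %ld\n", b.x, b.y, b.z);  // 512 x 1 x 1
        return 0;
    }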
1502
|
+
case GGML_OP_MUL_MAT_ID:
|
|
1503
|
+
{
|
|
1504
|
+
//GGML_ASSERT(ne00 == ne10);
|
|
1505
|
+
//GGML_ASSERT(ne03 == ne13);
|
|
2181
1506
|
|
|
2182
|
-
|
|
2183
|
-
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
|
2184
|
-
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
|
|
2185
|
-
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
|
2186
|
-
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
|
2187
|
-
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
|
2188
|
-
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
|
1507
|
+
GGML_ASSERT(src0t == GGML_TYPE_I32);
|
|
2189
1508
|
|
|
2190
|
-
|
|
2191
|
-
const int32_t IC = src1->ne[is_2D ? 2 : 1];
|
|
2192
|
-
const int32_t IH = is_2D ? src1->ne[1] : 1;
|
|
2193
|
-
const int32_t IW = src1->ne[0];
|
|
1509
|
+
const int n_as = ((int32_t *) dst->op_params)[1];
|
|
2194
1510
|
|
|
2195
|
-
|
|
2196
|
-
|
|
1511
|
+
// TODO: make this more general
|
|
1512
|
+
GGML_ASSERT(n_as <= 8);
|
|
2197
1513
|
|
|
2198
|
-
|
|
2199
|
-
|
|
1514
|
+
// max size of the src1ids array in the kernel stack
|
|
1515
|
+
GGML_ASSERT(ne11 <= 512);
|
|
2200
1516
|
|
|
2201
|
-
|
|
1517
|
+
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
|
2202
1518
|
|
|
2203
|
-
|
|
2204
|
-
|
|
1519
|
+
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
|
1520
|
+
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
|
1521
|
+
const int64_t ne22 = src2 ? src2->ne[2] : 0;
|
|
1522
|
+
const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
|
|
2205
1523
|
|
|
2206
|
-
|
|
2207
|
-
|
|
2208
|
-
|
|
2209
|
-
|
|
2210
|
-
};
|
|
1524
|
+
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
|
1525
|
+
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
|
1526
|
+
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
|
1527
|
+
const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
|
|
2211
1528
|
|
|
2212
|
-
|
|
2213
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2214
|
-
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
|
|
2215
|
-
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
|
|
2216
|
-
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
|
|
2217
|
-
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
|
|
2218
|
-
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
|
|
2219
|
-
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
|
|
2220
|
-
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
|
|
2221
|
-
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
|
|
2222
|
-
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
|
|
2223
|
-
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
|
|
2224
|
-
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
|
|
2225
|
-
|
|
2226
|
-
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
|
2227
|
-
} break;
|
|
2228
|
-
case GGML_OP_UPSCALE:
|
|
2229
|
-
{
|
|
2230
|
-
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
2231
|
-
|
|
2232
|
-
const int sf = dst->op_params[0];
|
|
2233
|
-
|
|
2234
|
-
[encoder setComputePipelineState:ctx->pipeline_upscale_f32];
|
|
2235
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2236
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2237
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
2238
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
2239
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
2240
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
|
2241
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
2242
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
2243
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
2244
|
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
|
2245
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
|
2246
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
|
2247
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
|
2248
|
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
|
2249
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
|
2250
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
|
2251
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
|
2252
|
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
|
2253
|
-
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
|
|
2254
|
-
|
|
2255
|
-
const int nth = MIN((int) ctx->pipeline_upscale_f32.maxTotalThreadsPerThreadgroup, ne0);
|
|
2256
|
-
|
|
2257
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2258
|
-
} break;
|
|
2259
|
-
case GGML_OP_PAD:
|
|
2260
|
-
{
|
|
2261
|
-
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
2262
|
-
|
|
2263
|
-
[encoder setComputePipelineState:ctx->pipeline_pad_f32];
|
|
2264
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2265
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2266
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
2267
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
2268
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
2269
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
|
2270
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
2271
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
2272
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
2273
|
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
|
2274
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
|
2275
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
|
2276
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
|
2277
|
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
|
2278
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
|
2279
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
|
2280
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
|
2281
|
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
|
2282
|
-
|
|
2283
|
-
const int nth = MIN(1024, ne0);
|
|
2284
|
-
|
|
2285
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2286
|
-
} break;
|
|
2287
|
-
case GGML_OP_ARGSORT:
|
|
2288
|
-
{
|
|
2289
|
-
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
2290
|
-
GGML_ASSERT( dst->type == GGML_TYPE_I32);
|
|
2291
|
-
|
|
2292
|
-
const int nrows = ggml_nrows(src0);
|
|
2293
|
-
|
|
2294
|
-
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
|
2295
|
-
|
|
2296
|
-
switch (order) {
|
|
2297
|
-
case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
|
|
2298
|
-
case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
|
|
2299
|
-
default: GGML_ASSERT(false);
|
|
2300
|
-
};
|
|
1529
|
+
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
|
|
2301
1530
|
|
|
2302
|
-
|
|
2303
|
-
|
|
2304
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1531
|
+
GGML_ASSERT(!ggml_is_transposed(src2));
|
|
1532
|
+
GGML_ASSERT(!ggml_is_transposed(src1));
|
|
2305
1533
|
|
|
2306
|
-
|
|
2307
|
-
} break;
|
|
2308
|
-
case GGML_OP_LEAKY_RELU:
|
|
2309
|
-
{
|
|
2310
|
-
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
1534
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
|
2311
1535
|
|
|
2312
|
-
|
|
2313
|
-
|
|
1536
|
+
const uint r2 = ne12/ne22;
|
|
1537
|
+
const uint r3 = ne13/ne23;
|
|
2314
1538
|
|
|
2315
|
-
|
|
2316
|
-
|
|
2317
|
-
|
|
2318
|
-
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
|
|
1539
|
+
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
|
1540
|
+
// to the matrix-vector kernel
|
|
1541
|
+
int ne11_mm_min = n_as;
|
|
2319
1542
|
|
|
2320
|
-
|
|
1543
|
+
const int idx = ((int32_t *) dst->op_params)[0];
|
|
2321
1544
|
|
|
2322
|
-
|
|
2323
|
-
|
|
2324
|
-
case GGML_OP_DUP:
|
|
2325
|
-
case GGML_OP_CPY:
|
|
2326
|
-
case GGML_OP_CONT:
|
|
2327
|
-
{
|
|
2328
|
-
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
|
1545
|
+
// batch size
|
|
1546
|
+
GGML_ASSERT(ne01 == ne11);
|
|
2329
1547
|
|
|
2330
|
-
|
|
1548
|
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
|
1549
|
+
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
|
1550
|
+
// !!!
|
|
1551
|
+
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
|
|
1552
|
+
// indirect matrix multiplication
|
|
1553
|
+
// !!!
|
|
1554
|
+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
|
1555
|
+
ne20 % 32 == 0 && ne20 >= 64 &&
|
|
1556
|
+
ne11 > ne11_mm_min) {
|
|
2331
1557
|
|
|
2332
|
-
|
|
1558
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
1559
|
+
|
|
1560
|
+
switch (src2->type) {
|
|
1561
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
|
|
1562
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
|
|
1563
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
|
|
1564
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
|
|
1565
|
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
|
|
1566
|
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
|
|
1567
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
|
|
1568
|
+
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
|
|
1569
|
+
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
|
|
1570
|
+
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
|
|
1571
|
+
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
|
|
1572
|
+
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
|
|
1573
|
+
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
|
|
1574
|
+
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
|
1575
|
+
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
|
1576
|
+
}
|
|
1577
|
+
|
|
1578
|
+
[encoder setComputePipelineState:pipeline];
|
|
1579
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1580
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1581
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1582
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
|
1583
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
|
1584
|
+
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
|
|
1585
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
|
1586
|
+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
|
|
1587
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
|
|
1588
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
|
|
1589
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
|
|
1590
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
|
|
1591
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
|
|
1592
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
|
1593
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
|
1594
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
|
1595
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
|
|
1596
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
|
1597
|
+
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
|
|
1598
|
+
// TODO: how to make this an array? read Metal docs
|
|
1599
|
+
for (int j = 0; j < 8; ++j) {
|
|
1600
|
+
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
|
1601
|
+
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
|
1602
|
+
|
|
1603
|
+
size_t offs_src_cur = 0;
|
|
1604
|
+
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
|
1605
|
+
|
|
1606
|
+
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
|
1607
|
+
}
|
|
1608
|
+
|
|
1609
|
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
|
1610
|
+
|
|
1611
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
|
1612
|
+
} else {
|
|
1613
|
+
int nth0 = 32;
|
|
1614
|
+
int nth1 = 1;
|
|
1615
|
+
int nrows = 1;
|
|
1616
|
+
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
1617
|
+
|
|
1618
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
1619
|
+
|
|
1620
|
+
// use custom matrix x vector kernel
|
|
1621
|
+
switch (src2t) {
|
|
2333
1622
|
case GGML_TYPE_F32:
|
|
2334
1623
|
{
|
|
2335
|
-
GGML_ASSERT(
|
|
2336
|
-
|
|
2337
|
-
switch (dstt) {
|
|
2338
|
-
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
|
|
2339
|
-
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
|
|
2340
|
-
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
|
|
2341
|
-
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
|
|
2342
|
-
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
|
|
2343
|
-
//case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
|
|
2344
|
-
//case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
|
|
2345
|
-
default: GGML_ASSERT(false && "not implemented");
|
|
2346
|
-
};
|
|
1624
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
|
1625
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
|
|
2347
1626
|
} break;
|
|
2348
1627
|
case GGML_TYPE_F16:
|
|
2349
1628
|
{
|
|
2350
|
-
|
|
2351
|
-
|
|
2352
|
-
|
|
2353
|
-
|
|
2354
|
-
|
|
1629
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
|
1630
|
+
nth0 = 32;
|
|
1631
|
+
nth1 = 1;
|
|
1632
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
|
|
1633
|
+
} break;
|
|
1634
|
+
case GGML_TYPE_Q4_0:
|
|
1635
|
+
{
|
|
1636
|
+
nth0 = 8;
|
|
1637
|
+
nth1 = 8;
|
|
1638
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
|
|
1639
|
+
} break;
|
|
1640
|
+
case GGML_TYPE_Q4_1:
|
|
1641
|
+
{
|
|
1642
|
+
nth0 = 8;
|
|
1643
|
+
nth1 = 8;
|
|
1644
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
|
|
2355
1645
|
} break;
|
|
2356
|
-
|
|
1646
|
+
case GGML_TYPE_Q5_0:
|
|
1647
|
+
{
|
|
1648
|
+
nth0 = 8;
|
|
1649
|
+
nth1 = 8;
|
|
1650
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
|
|
1651
|
+
} break;
|
|
1652
|
+
case GGML_TYPE_Q5_1:
|
|
1653
|
+
{
|
|
1654
|
+
nth0 = 8;
|
|
1655
|
+
nth1 = 8;
|
|
1656
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
|
|
1657
|
+
} break;
|
|
1658
|
+
case GGML_TYPE_Q8_0:
|
|
1659
|
+
{
|
|
1660
|
+
nth0 = 8;
|
|
1661
|
+
nth1 = 8;
|
|
1662
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
|
|
1663
|
+
} break;
|
|
1664
|
+
case GGML_TYPE_Q2_K:
|
|
1665
|
+
{
|
|
1666
|
+
nth0 = 2;
|
|
1667
|
+
nth1 = 32;
|
|
1668
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
|
|
1669
|
+
} break;
|
|
1670
|
+
case GGML_TYPE_Q3_K:
|
|
1671
|
+
{
|
|
1672
|
+
nth0 = 2;
|
|
1673
|
+
nth1 = 32;
|
|
1674
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
|
|
1675
|
+
} break;
|
|
1676
|
+
case GGML_TYPE_Q4_K:
|
|
1677
|
+
{
|
|
1678
|
+
nth0 = 4; //1;
|
|
1679
|
+
nth1 = 8; //32;
|
|
1680
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
|
|
1681
|
+
} break;
|
|
1682
|
+
case GGML_TYPE_Q5_K:
|
|
1683
|
+
{
|
|
1684
|
+
nth0 = 2;
|
|
1685
|
+
nth1 = 32;
|
|
1686
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
|
|
1687
|
+
} break;
|
|
1688
|
+
case GGML_TYPE_Q6_K:
|
|
1689
|
+
{
|
|
1690
|
+
nth0 = 2;
|
|
1691
|
+
nth1 = 32;
|
|
1692
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
|
|
1693
|
+
} break;
|
|
1694
|
+
case GGML_TYPE_IQ2_XXS:
|
|
1695
|
+
{
|
|
1696
|
+
nth0 = 4;
|
|
1697
|
+
nth1 = 16;
|
|
1698
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
|
|
1699
|
+
} break;
|
|
1700
|
+
case GGML_TYPE_IQ2_XS:
|
|
1701
|
+
{
|
|
1702
|
+
nth0 = 4;
|
|
1703
|
+
nth1 = 16;
|
|
1704
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
|
|
1705
|
+
} break;
|
|
1706
|
+
default:
|
|
1707
|
+
{
|
|
1708
|
+
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
|
1709
|
+
GGML_ASSERT(false && "not implemented");
|
|
1710
|
+
}
|
|
1711
|
+
};
|
|
1712
|
+
|
|
1713
|
+
if (ggml_is_quantized(src2t)) {
|
|
1714
|
+
GGML_ASSERT(ne20 >= nth0*nth1);
|
|
2357
1715
|
}
|
|
2358
1716
|
|
|
2359
|
-
|
|
2360
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2361
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
2362
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
2363
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
2364
|
-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
|
2365
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
|
2366
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
|
2367
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
|
2368
|
-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
|
2369
|
-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
|
2370
|
-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
|
2371
|
-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
|
2372
|
-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
|
2373
|
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
|
2374
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
|
2375
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
|
2376
|
-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
|
1717
|
+
const int64_t _ne1 = 1; // kernels needs a reference in constant memory
|
|
2377
1718
|
|
|
2378
|
-
[encoder
|
|
2379
|
-
|
|
2380
|
-
|
|
2381
|
-
|
|
2382
|
-
|
|
2383
|
-
|
|
1719
|
+
[encoder setComputePipelineState:pipeline];
|
|
1720
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1721
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1722
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1723
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
|
1724
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
|
1725
|
+
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
|
1726
|
+
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
|
|
1727
|
+
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
|
|
1728
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
|
|
1729
|
+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
|
|
1730
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
|
1731
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
|
|
1732
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
|
1733
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
|
1734
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
|
1735
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
|
1736
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
|
1737
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
|
1738
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
|
|
1739
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
|
|
1740
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
|
|
1741
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
|
|
1742
|
+
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
|
|
1743
|
+
// TODO: how to make this an array? read Metal docs
|
|
1744
|
+
for (int j = 0; j < 8; ++j) {
|
|
1745
|
+
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
|
1746
|
+
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
|
1747
|
+
|
|
1748
|
+
size_t offs_src_cur = 0;
|
|
1749
|
+
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
|
1750
|
+
|
|
1751
|
+
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
|
1752
|
+
}
|
|
1753
|
+
|
|
1754
|
+
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
|
1755
|
+
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
|
1756
|
+
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
|
|
1757
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1758
|
+
}
|
|
1759
|
+
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
|
|
1760
|
+
const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
|
1761
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
1762
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1763
|
+
}
|
|
1764
|
+
else if (src2t == GGML_TYPE_Q4_K) {
|
|
1765
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1766
|
+
}
|
|
1767
|
+
else if (src2t == GGML_TYPE_Q3_K) {
|
|
1768
|
+
#ifdef GGML_QKK_64
|
|
1769
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1770
|
+
#else
|
|
1771
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1772
|
+
#endif
|
|
1773
|
+
}
|
|
1774
|
+
else if (src2t == GGML_TYPE_Q5_K) {
|
|
1775
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1776
|
+
}
|
|
1777
|
+
else if (src2t == GGML_TYPE_Q6_K) {
|
|
1778
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1779
|
+
} else {
|
|
1780
|
+
const int64_t ny = (_ne1 + nrows - 1)/nrows;
|
|
1781
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1782
|
+
}
|
|
1783
|
+
}
|
|
1784
|
+
} break;
|
|
1785
|
+
case GGML_OP_GET_ROWS:
|
|
1786
|
+
{
|
|
1787
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
1788
|
+
|
|
1789
|
+
switch (src0->type) {
|
|
1790
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
|
|
1791
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
|
|
1792
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
|
|
1793
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
|
|
1794
|
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
|
|
1795
|
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
|
|
1796
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
|
|
1797
|
+
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
|
|
1798
|
+
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
|
|
1799
|
+
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
|
|
1800
|
+
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
|
|
1801
|
+
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
|
|
1802
|
+
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
|
|
1803
|
+
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
|
|
1804
|
+
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
|
1805
|
+
default: GGML_ASSERT(false && "not implemented");
|
|
1806
|
+
}
|
|
1807
|
+
|
|
1808
|
+
[encoder setComputePipelineState:pipeline];
|
|
1809
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1810
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1811
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1812
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
|
1813
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
|
1814
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
|
1815
|
+
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
|
1816
|
+
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
|
|
1817
|
+
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
|
|
1818
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
|
|
1819
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
|
|
1820
|
+
|
|
1821
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
|
1822
|
+
} break;
|
|
1823
|
+
case GGML_OP_RMS_NORM:
|
|
1824
|
+
{
|
|
1825
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
|
1826
|
+
|
|
1827
|
+
float eps;
|
|
1828
|
+
memcpy(&eps, dst->op_params, sizeof(float));
|
|
1829
|
+
|
|
1830
|
+
int nth = 32; // SIMD width
|
|
1831
|
+
|
|
1832
|
+
while (nth < ne00/4 && nth < 1024) {
|
|
1833
|
+
nth *= 2;
|
|
1834
|
+
}
|
|
1835
|
+
|
|
1836
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
|
|
1837
|
+
|
|
1838
|
+
[encoder setComputePipelineState:pipeline];
|
|
1839
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1840
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1841
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1842
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
|
1843
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
|
1844
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
1845
|
+
|
|
1846
|
+
const int64_t nrows = ggml_nrows(src0);
|
|
1847
|
+
|
|
1848
|
+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1849
|
+
} break;
|
|
1850
|
+
case GGML_OP_GROUP_NORM:
|
|
1851
|
+
{
|
|
1852
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
|
1853
|
+
|
|
1854
|
+
//float eps;
|
|
1855
|
+
//memcpy(&eps, dst->op_params, sizeof(float));
|
|
1856
|
+
|
|
1857
|
+
const float eps = 1e-6f; // TODO: temporarily hardcoded
|
|
1858
|
+
|
|
1859
|
+
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
|
|
1860
|
+
|
|
1861
|
+
int nth = 32; // SIMD width
|
|
1862
|
+
|
|
1863
|
+
//while (nth < ne00/4 && nth < 1024) {
|
|
1864
|
+
// nth *= 2;
|
|
1865
|
+
//}
|
|
1866
|
+
|
|
1867
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
|
|
1868
|
+
|
|
1869
|
+
[encoder setComputePipelineState:pipeline];
|
|
1870
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1871
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1872
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1873
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
1874
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
1875
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
|
|
1876
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
|
|
1877
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
|
|
1878
|
+
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
|
|
1879
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
|
|
1880
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
1881
|
+
|
|
1882
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1883
|
+
} break;
|
|
1884
|
+
case GGML_OP_NORM:
|
|
1885
|
+
{
|
|
1886
|
+
float eps;
|
|
1887
|
+
memcpy(&eps, dst->op_params, sizeof(float));
|
|
1888
|
+
|
|
1889
|
+
const int nth = MIN(256, ne00);
|
|
1890
|
+
|
|
1891
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
|
|
1892
|
+
|
|
1893
|
+
[encoder setComputePipelineState:pipeline];
|
|
1894
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1895
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1896
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1897
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
|
1898
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
|
1899
|
+
[encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
|
|
1900
|
+
|
|
1901
|
+
const int64_t nrows = ggml_nrows(src0);
|
|
1902
|
+
|
|
1903
|
+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1904
|
+
} break;
|
|
1905
|
+
case GGML_OP_ALIBI:
|
|
1906
|
+
{
|
|
1907
|
+
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
|
1908
|
+
|
|
1909
|
+
const int nth = MIN(1024, ne00);
|
|
1910
|
+
|
|
1911
|
+
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
1912
|
+
const int n_head = ((int32_t *) dst->op_params)[1];
|
|
1913
|
+
float max_bias;
|
|
1914
|
+
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
|
1915
|
+
|
|
1916
|
+
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
|
1917
|
+
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
|
1918
|
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
|
1919
|
+
|
|
1920
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
|
|
1921
|
+
|
|
1922
|
+
[encoder setComputePipelineState:pipeline];
|
|
1923
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1924
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1925
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
1926
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
1927
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
1928
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
|
1929
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
|
1930
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
|
1931
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
|
1932
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
|
1933
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
|
1934
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
|
1935
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
|
1936
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
|
1937
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
|
1938
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
|
1939
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
|
1940
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
|
1941
|
+
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
|
1942
|
+
[encoder setBytes:&m1 length:sizeof( float) atIndex:19];
|
|
1943
|
+
[encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
|
|
1944
|
+
|
|
1945
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1946
|
+
} break;
|
|
1947
|
+
case GGML_OP_ROPE:
|
|
1948
|
+
{
|
|
1949
|
+
GGML_ASSERT(ne10 == ne02);
|
|
1950
|
+
|
|
1951
|
+
const int nth = MIN(1024, ne00);
|
|
1952
|
+
|
|
1953
|
+
const int n_past = ((int32_t *) dst->op_params)[0];
|
|
1954
|
+
const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
1955
|
+
const int mode = ((int32_t *) dst->op_params)[2];
|
|
1956
|
+
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
|
1957
|
+
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
|
1958
|
+
|
|
1959
|
+
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
|
1960
|
+
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
|
1961
|
+
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
|
1962
|
+
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
|
1963
|
+
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
|
1964
|
+
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
|
1965
|
+
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
|
1966
|
+
|
|
1967
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
1968
|
+
|
|
1969
|
+
switch (src0->type) {
|
|
1970
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
|
|
1971
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
|
|
1972
|
+
default: GGML_ASSERT(false);
|
|
1973
|
+
};
|
|
1974
|
+
|
|
1975
|
+
[encoder setComputePipelineState:pipeline];
|
|
1976
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1977
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1978
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1979
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
|
1980
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
|
|
1981
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
|
|
1982
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
|
|
1983
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
|
|
1984
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
|
1985
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
|
1986
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
|
1987
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
|
|
1988
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
|
|
1989
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
|
|
1990
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
|
|
1991
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
|
|
1992
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
|
|
1993
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
|
|
1994
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
|
|
1995
|
+
[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
|
|
1996
|
+
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
|
|
1997
|
+
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
|
|
1998
|
+
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
|
|
1999
|
+
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
|
|
2000
|
+
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
|
|
2001
|
+
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
|
|
2002
|
+
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
|
|
2003
|
+
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
|
|
2004
|
+
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
|
|
2005
|
+
|
|
2006
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2007
|
+
} break;
|
|
2008
|
+
case GGML_OP_IM2COL:
|
|
2009
|
+
{
|
|
2010
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
|
2011
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
2012
|
+
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
|
2013
|
+
|
|
2014
|
+
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
|
2015
|
+
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
|
2016
|
+
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
|
|
2017
|
+
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
|
2018
|
+
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
|
2019
|
+
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
|
2020
|
+
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
|
2021
|
+
|
|
2022
|
+
const int32_t N = src1->ne[is_2D ? 3 : 2];
|
|
2023
|
+
const int32_t IC = src1->ne[is_2D ? 2 : 1];
|
|
2024
|
+
const int32_t IH = is_2D ? src1->ne[1] : 1;
|
|
2025
|
+
const int32_t IW = src1->ne[0];
|
|
2026
|
+
|
|
2027
|
+
const int32_t KH = is_2D ? src0->ne[1] : 1;
|
|
2028
|
+
const int32_t KW = src0->ne[0];
|
|
2029
|
+
|
|
2030
|
+
const int32_t OH = is_2D ? dst->ne[2] : 1;
|
|
2031
|
+
const int32_t OW = dst->ne[1];
|
|
2032
|
+
|
|
2033
|
+
const int32_t CHW = IC * KH * KW;
|
|
2034
|
+
|
|
2035
|
+
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
|
2036
|
+
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
|
2037
|
+
|
|
2038
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2039
|
+
|
|
2040
|
+
switch (src0->type) {
|
|
2041
|
+
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
|
|
2042
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
|
|
2043
|
+
default: GGML_ASSERT(false);
|
|
2044
|
+
};
|
|
2045
|
+
|
|
2046
|
+
[encoder setComputePipelineState:pipeline];
|
|
2047
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
|
2048
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2049
|
+
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
|
|
2050
|
+
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
|
|
2051
|
+
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
|
|
2052
|
+
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
|
|
2053
|
+
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
|
|
2054
|
+
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
|
|
2055
|
+
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
|
|
2056
|
+
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
|
|
2057
|
+
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
|
|
2058
|
+
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
|
|
2059
|
+
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
|
|
2060
|
+
|
|
2061
|
+
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
|
2062
|
+
} break;
|
|
2063
|
+
case GGML_OP_UPSCALE:
|
|
2064
|
+
{
|
|
2065
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
2066
|
+
|
|
2067
|
+
const int sf = dst->op_params[0];
|
|
2068
|
+
|
|
2069
|
+
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
|
|
2070
|
+
|
|
2071
|
+
[encoder setComputePipelineState:pipeline];
|
|
2072
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2073
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2074
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
2075
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
2076
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
2077
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
|
2078
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
2079
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
2080
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
2081
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
|
2082
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
|
2083
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
|
2084
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
|
2085
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
|
2086
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
|
2087
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
|
2088
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
|
2089
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
|
2090
|
+
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
|
|
2091
|
+
|
|
2092
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
|
2093
|
+
|
|
2094
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2095
|
+
} break;
|
|
2096
|
+
case GGML_OP_PAD:
|
|
2097
|
+
{
|
|
2098
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
2099
|
+
|
|
2100
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
|
|
2101
|
+
|
|
2102
|
+
[encoder setComputePipelineState:pipeline];
|
|
2103
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2104
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2105
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
2106
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
2107
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
2108
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
|
2109
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
2110
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
2111
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
2112
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
|
2113
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
|
2114
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
|
2115
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
|
2116
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
|
2117
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
|
2118
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
|
2119
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
|
2120
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
|
2121
|
+
|
|
2122
|
+
const int nth = MIN(1024, ne0);
|
|
2123
|
+
|
|
2124
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2125
|
+
} break;
|
|
2126
|
+
case GGML_OP_ARGSORT:
|
|
2127
|
+
{
|
|
2128
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
2129
|
+
GGML_ASSERT( dst->type == GGML_TYPE_I32);
|
|
2130
|
+
|
|
2131
|
+
const int nrows = ggml_nrows(src0);
|
|
2132
|
+
|
|
2133
|
+
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
|
2134
|
+
|
|
2135
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2136
|
+
|
|
2137
|
+
switch (order) {
|
|
2138
|
+
case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
|
|
2139
|
+
case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
|
|
2140
|
+
default: GGML_ASSERT(false);
|
|
2141
|
+
};
|
|
2142
|
+
|
|
2143
|
+
[encoder setComputePipelineState:pipeline];
|
|
2144
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2145
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2146
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
2147
|
+
|
|
2148
|
+
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
|
|
2149
|
+
} break;
|
|
2150
|
+
case GGML_OP_LEAKY_RELU:
|
|
2151
|
+
{
|
|
2152
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
2153
|
+
|
|
2154
|
+
float slope;
|
|
2155
|
+
memcpy(&slope, dst->op_params, sizeof(float));
|
|
2156
|
+
|
|
2157
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
|
|
2158
|
+
|
|
2159
|
+
[encoder setComputePipelineState:pipeline];
|
|
2160
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2161
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2162
|
+
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
|
|
2163
|
+
|
|
2164
|
+
const int64_t n = ggml_nelements(dst);
|
|
2165
|
+
|
|
2166
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2167
|
+
} break;
|
|
2168
|
+
case GGML_OP_DUP:
|
|
2169
|
+
case GGML_OP_CPY:
|
|
2170
|
+
case GGML_OP_CONT:
|
|
2171
|
+
{
|
|
2172
|
+
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
|
2173
|
+
|
|
2174
|
+
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
|
|
2175
|
+
|
|
2176
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2177
|
+
|
|
2178
|
+
switch (src0t) {
|
|
2179
|
+
case GGML_TYPE_F32:
|
|
2180
|
+
{
|
|
2181
|
+
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
|
|
2182
|
+
|
|
2183
|
+
switch (dstt) {
|
|
2184
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
|
|
2185
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
|
|
2186
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
|
|
2187
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
|
|
2188
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
|
|
2189
|
+
//case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
|
|
2190
|
+
//case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
|
|
2191
|
+
default: GGML_ASSERT(false && "not implemented");
|
|
2192
|
+
};
|
|
2193
|
+
} break;
|
|
2194
|
+
case GGML_TYPE_F16:
|
|
2195
|
+
{
|
|
2196
|
+
switch (dstt) {
|
|
2197
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
|
|
2198
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
|
|
2199
|
+
default: GGML_ASSERT(false && "not implemented");
|
|
2200
|
+
};
|
|
2201
|
+
} break;
|
|
2202
|
+
default: GGML_ASSERT(false && "not implemented");
|
|
2384
2203
|
}
|
|
2385
|
-
}
|
|
2386
|
-
}
|
|
2387
2204
|
|
|
2388
|
-
|
|
2389
|
-
|
|
2390
|
-
|
|
2205
|
+
[encoder setComputePipelineState:pipeline];
|
|
2206
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2207
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2208
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
2209
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
2210
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
2211
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
|
2212
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
|
2213
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
|
2214
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
|
2215
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
|
2216
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
|
2217
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
|
2218
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
|
2219
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
|
2220
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
|
2221
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
|
2222
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
|
2223
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
|
2224
|
+
|
|
2225
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2226
|
+
} break;
|
|
2227
|
+
default:
|
|
2228
|
+
{
|
|
2229
|
+
GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
|
2230
|
+
GGML_ASSERT(false);
|
|
2231
|
+
}
|
|
2391
2232
|
}
|
|
2392
2233
|
|
|
2393
|
-
|
|
2394
|
-
|
|
2395
|
-
|
|
2234
|
+
#ifndef GGML_METAL_NDEBUG
|
|
2235
|
+
[encoder popDebugGroup];
|
|
2236
|
+
#endif
|
|
2237
|
+
}
|
|
2238
|
+
|
|
2239
|
+
if (encoder != nil) {
|
|
2240
|
+
[encoder endEncoding];
|
|
2241
|
+
encoder = nil;
|
|
2242
|
+
}
|
|
2396
2243
|
|
|
2397
|
-
|
|
2398
|
-
|
|
2244
|
+
[command_buffer commit];
|
|
2245
|
+
});
|
|
2399
2246
|
|
|
2400
|
-
// check status of command
|
|
2247
|
+
// Wait for completion and check status of each command buffer
|
|
2401
2248
|
// needed to detect if the device ran out-of-memory for example (#1881)
|
|
2402
|
-
for (int i = 0; i < n_cb; i++) {
|
|
2403
|
-
[ctx->command_buffers[i] waitUntilCompleted];
|
|
2404
2249
|
|
|
2405
|
-
|
|
2250
|
+
for (int i = 0; i < n_cb; ++i) {
|
|
2251
|
+
id<MTLCommandBuffer> command_buffer = command_buffers[i];
|
|
2252
|
+
[command_buffer waitUntilCompleted];
|
|
2253
|
+
|
|
2254
|
+
MTLCommandBufferStatus status = [command_buffer status];
|
|
2406
2255
|
if (status != MTLCommandBufferStatusCompleted) {
|
|
2407
2256
|
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
|
2408
|
-
|
|
2257
|
+
return false;
|
|
2409
2258
|
}
|
|
2410
2259
|
}
|
|
2411
2260
|
|
|
2261
|
+
return true;
|
|
2412
2262
|
}
|
|
2413
2263
|
}
|
|
2414
2264
|
|
|
@@ -2441,13 +2291,13 @@ static void ggml_backend_metal_free_device(void) {
|
|
|
2441
2291
|
}
|
|
2442
2292
|
}
|
|
2443
2293
|
|
|
2444
|
-
static
|
|
2445
|
-
|
|
2294
|
+
GGML_CALL static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
|
|
2295
|
+
return "Metal";
|
|
2446
2296
|
|
|
2447
|
-
|
|
2297
|
+
UNUSED(buffer);
|
|
2448
2298
|
}
|
|
2449
2299
|
|
|
2450
|
-
static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
|
2300
|
+
GGML_CALL static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
|
2451
2301
|
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
|
|
2452
2302
|
|
|
2453
2303
|
for (int i = 0; i < ctx->n_buffers; i++) {
|
|
@@ -2462,50 +2312,80 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
|
|
|
2462
2312
|
free(ctx);
|
|
2463
2313
|
}
|
|
2464
2314
|
|
|
2465
|
-
static void
|
|
2466
|
-
|
|
2315
|
+
GGML_CALL static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
|
|
2316
|
+
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
|
|
2467
2317
|
|
|
2468
|
-
|
|
2318
|
+
return ctx->all_data;
|
|
2469
2319
|
}
|
|
2470
2320
|
|
|
2471
|
-
static void
|
|
2472
|
-
memcpy(
|
|
2321
|
+
GGML_CALL static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
|
2322
|
+
memcpy((char *)tensor->data + offset, data, size);
|
|
2473
2323
|
|
|
2474
2324
|
UNUSED(buffer);
|
|
2475
2325
|
}
|
|
2476
2326
|
|
|
2477
|
-
static void
|
|
2478
|
-
|
|
2327
|
+
GGML_CALL static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
|
2328
|
+
memcpy(data, (const char *)tensor->data + offset, size);
|
|
2479
2329
|
|
|
2480
2330
|
UNUSED(buffer);
|
|
2481
2331
|
}
|
|
2482
2332
|
|
|
2483
|
-
static
|
|
2484
|
-
|
|
2333
|
+
GGML_CALL static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
|
|
2334
|
+
if (ggml_backend_buffer_is_host(src->buffer)) {
|
|
2335
|
+
memcpy(dst->data, src->data, ggml_nbytes(src));
|
|
2336
|
+
return true;
|
|
2337
|
+
}
|
|
2338
|
+
return false;
|
|
2485
2339
|
|
|
2486
2340
|
UNUSED(buffer);
|
|
2487
2341
|
}
|
|
2488
2342
|
|
|
2489
|
-
static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
|
2343
|
+
GGML_CALL static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
|
2490
2344
|
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
|
|
2491
2345
|
|
|
2492
2346
|
memset(ctx->all_data, value, ctx->all_size);
|
|
2493
2347
|
}
|
|
2494
2348
|
|
|
2495
2349
|
static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
|
|
2350
|
+
/* .get_name = */ ggml_backend_metal_buffer_get_name,
|
|
2496
2351
|
/* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
|
|
2497
2352
|
/* .get_base = */ ggml_backend_metal_buffer_get_base,
|
|
2498
2353
|
/* .init_tensor = */ NULL,
|
|
2499
2354
|
/* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
|
|
2500
2355
|
/* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
|
|
2501
|
-
/* .
|
|
2502
|
-
/* .cpy_tensor_to = */ ggml_backend_metal_buffer_cpy_tensor_to,
|
|
2356
|
+
/* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor,
|
|
2503
2357
|
/* .clear = */ ggml_backend_metal_buffer_clear,
|
|
2358
|
+
/* .reset = */ NULL,
|
|
2504
2359
|
};
|
|
2505
2360
|
|
|
2506
2361
|
// default buffer type
|
|
2507
2362
|
|
|
2508
|
-
static
|
|
2363
|
+
GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
|
|
2364
|
+
return "Metal";
|
|
2365
|
+
|
|
2366
|
+
UNUSED(buft);
|
|
2367
|
+
}
|
|
2368
|
+
|
|
2369
|
+
static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
|
|
2370
|
+
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
|
2371
|
+
if (@available(macOS 10.12, iOS 16.0, *)) {
|
|
2372
|
+
GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
|
|
2373
|
+
device.currentAllocatedSize / 1024.0 / 1024.0,
|
|
2374
|
+
device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
|
2375
|
+
|
|
2376
|
+
if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
|
|
2377
|
+
GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
|
2378
|
+
} else {
|
|
2379
|
+
GGML_METAL_LOG_INFO("\n");
|
|
2380
|
+
}
|
|
2381
|
+
} else {
|
|
2382
|
+
GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
|
|
2383
|
+
}
|
|
2384
|
+
#endif
|
|
2385
|
+
UNUSED(device);
|
|
2386
|
+
}
|
|
2387
|
+
|
|
2388
|
+
GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
|
2509
2389
|
struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
|
|
2510
2390
|
|
|
2511
2391
|
const size_t size_page = sysconf(_SC_PAGESIZE);
|
|
@@ -2537,46 +2417,32 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
|
2537
2417
|
}
|
|
2538
2418
|
|
|
2539
2419
|
GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
|
|
2540
|
-
|
|
2541
|
-
|
|
2542
|
-
#if TARGET_OS_OSX
|
|
2543
|
-
GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
|
|
2544
|
-
device.currentAllocatedSize / 1024.0 / 1024.0,
|
|
2545
|
-
device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
|
2546
|
-
|
|
2547
|
-
if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
|
|
2548
|
-
GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
|
2549
|
-
} else {
|
|
2550
|
-
GGML_METAL_LOG_INFO("\n");
|
|
2551
|
-
}
|
|
2552
|
-
#else
|
|
2553
|
-
GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
|
|
2554
|
-
#endif
|
|
2555
|
-
|
|
2420
|
+
ggml_backend_metal_log_allocated_size(device);
|
|
2556
2421
|
|
|
2557
2422
|
return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
|
|
2558
2423
|
}
|
|
2559
2424
|
|
|
2560
|
-
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
|
2425
|
+
GGML_CALL static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
|
2561
2426
|
return 32;
|
|
2562
2427
|
UNUSED(buft);
|
|
2563
2428
|
}
|
|
2564
2429
|
|
|
2565
|
-
static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
|
|
2430
|
+
GGML_CALL static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
|
|
2566
2431
|
return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
|
|
2567
2432
|
|
|
2568
2433
|
UNUSED(buft);
|
|
2569
2434
|
}
|
|
2570
2435
|
|
|
2571
|
-
static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
|
2436
|
+
GGML_CALL static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
|
2572
2437
|
return true;
|
|
2573
2438
|
|
|
2574
2439
|
UNUSED(buft);
|
|
2575
2440
|
}
|
|
2576
2441
|
|
|
2577
|
-
ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
|
2442
|
+
GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
|
2578
2443
|
static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
|
|
2579
2444
|
/* .iface = */ {
|
|
2445
|
+
/* .get_name = */ ggml_backend_metal_buffer_type_get_name,
|
|
2580
2446
|
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
|
|
2581
2447
|
/* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
|
|
2582
2448
|
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
|
@@ -2591,7 +2457,7 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
|
|
2591
2457
|
|
|
2592
2458
|
// buffer from ptr
|
|
2593
2459
|
|
|
2594
|
-
ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
|
|
2460
|
+
GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
|
|
2595
2461
|
struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
|
|
2596
2462
|
|
|
2597
2463
|
ctx->all_data = data;
|
|
@@ -2600,6 +2466,14 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
|
|
2600
2466
|
ctx->n_buffers = 0;
|
|
2601
2467
|
|
|
2602
2468
|
const size_t size_page = sysconf(_SC_PAGESIZE);
|
|
2469
|
+
|
|
2470
|
+
// page-align the data ptr
|
|
2471
|
+
{
|
|
2472
|
+
const uintptr_t offs = (uintptr_t) data % size_page;
|
|
2473
|
+
data = (void *) ((char *) data - offs);
|
|
2474
|
+
size += offs;
|
|
2475
|
+
}
|
|
2476
|
+
|
|
2603
2477
|
size_t size_aligned = size;
|
|
2604
2478
|
if ((size_aligned % size_page) != 0) {
|
|
2605
2479
|
size_aligned += (size_page - (size_aligned % size_page));
|
|
@@ -2651,63 +2525,50 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
|
|
2651
2525
|
}
|
|
2652
2526
|
}
|
|
2653
2527
|
|
|
2654
|
-
|
|
2655
|
-
GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
|
|
2656
|
-
device.currentAllocatedSize / 1024.0 / 1024.0,
|
|
2657
|
-
device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
|
2658
|
-
|
|
2659
|
-
if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
|
|
2660
|
-
GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
|
2661
|
-
} else {
|
|
2662
|
-
GGML_METAL_LOG_INFO("\n");
|
|
2663
|
-
}
|
|
2664
|
-
#else
|
|
2665
|
-
GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
|
|
2666
|
-
#endif
|
|
2528
|
+
ggml_backend_metal_log_allocated_size(device);
|
|
2667
2529
|
|
|
2668
2530
|
return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
|
|
2669
2531
|
}
|
|
2670
2532
|
|
|
2671
2533
|
// backend
|
|
2672
2534
|
|
|
2673
|
-
static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
|
2535
|
+
GGML_CALL static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
|
2674
2536
|
return "Metal";
|
|
2675
2537
|
|
|
2676
2538
|
UNUSED(backend);
|
|
2677
2539
|
}
|
|
2678
2540
|
|
|
2679
|
-
static void ggml_backend_metal_free(ggml_backend_t backend) {
|
|
2541
|
+
GGML_CALL static void ggml_backend_metal_free(ggml_backend_t backend) {
|
|
2680
2542
|
struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
|
|
2681
2543
|
ggml_metal_free(ctx);
|
|
2682
2544
|
free(backend);
|
|
2683
2545
|
}
|
|
2684
2546
|
|
|
2685
|
-
static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
|
|
2547
|
+
GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
|
|
2686
2548
|
return ggml_backend_metal_buffer_type();
|
|
2687
2549
|
|
|
2688
2550
|
UNUSED(backend);
|
|
2689
2551
|
}
|
|
2690
2552
|
|
|
2691
|
-
static
|
|
2553
|
+
GGML_CALL static bool ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
|
2692
2554
|
struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
|
|
2693
2555
|
|
|
2694
|
-
ggml_metal_graph_compute(metal_ctx, cgraph);
|
|
2556
|
+
return ggml_metal_graph_compute(metal_ctx, cgraph);
|
|
2695
2557
|
}
|
|
2696
2558
|
|
|
2697
|
-
static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
|
2698
|
-
|
|
2559
|
+
GGML_CALL static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
|
2560
|
+
struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
|
|
2699
2561
|
|
|
2700
|
-
|
|
2562
|
+
return ggml_metal_supports_op(metal_ctx, op);
|
|
2701
2563
|
}
|
|
2702
2564
|
|
|
2703
|
-
static struct ggml_backend_i
|
|
2565
|
+
static struct ggml_backend_i ggml_backend_metal_i = {
|
|
2704
2566
|
/* .get_name = */ ggml_backend_metal_name,
|
|
2705
2567
|
/* .free = */ ggml_backend_metal_free,
|
|
2706
2568
|
/* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
|
|
2707
2569
|
/* .set_tensor_async = */ NULL,
|
|
2708
2570
|
/* .get_tensor_async = */ NULL,
|
|
2709
|
-
/* .
|
|
2710
|
-
/* .cpy_tensor_to_async = */ NULL,
|
|
2571
|
+
/* .cpy_tensor_async = */ NULL,
|
|
2711
2572
|
/* .synchronize = */ NULL,
|
|
2712
2573
|
/* .graph_plan_create = */ NULL,
|
|
2713
2574
|
/* .graph_plan_free = */ NULL,
|
|
@@ -2716,6 +2577,11 @@ static struct ggml_backend_i metal_backend_i = {
|
|
|
2716
2577
|
/* .supports_op = */ ggml_backend_metal_supports_op,
|
|
2717
2578
|
};
|
|
2718
2579
|
|
|
2580
|
+
void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
|
|
2581
|
+
ggml_metal_log_callback = log_callback;
|
|
2582
|
+
ggml_metal_log_user_data = user_data;
|
|
2583
|
+
}
|
|
2584
|
+
|
|
2719
2585
|
ggml_backend_t ggml_backend_metal_init(void) {
|
|
2720
2586
|
struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
|
|
2721
2587
|
|
|
@@ -2726,7 +2592,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
|
|
|
2726
2592
|
ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
|
|
2727
2593
|
|
|
2728
2594
|
*metal_backend = (struct ggml_backend) {
|
|
2729
|
-
/* .interface = */
|
|
2595
|
+
/* .interface = */ ggml_backend_metal_i,
|
|
2730
2596
|
/* .context = */ ctx,
|
|
2731
2597
|
};
|
|
2732
2598
|
|
|
@@ -2734,7 +2600,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
|
|
|
2734
2600
|
}
|
|
2735
2601
|
|
|
2736
2602
|
bool ggml_backend_is_metal(ggml_backend_t backend) {
|
|
2737
|
-
return backend->iface.get_name == ggml_backend_metal_name;
|
|
2603
|
+
return backend && backend->iface.get_name == ggml_backend_metal_name;
|
|
2738
2604
|
}
|
|
2739
2605
|
|
|
2740
2606
|
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|
@@ -2742,7 +2608,7 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|
|
2742
2608
|
|
|
2743
2609
|
struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
|
|
2744
2610
|
|
|
2745
|
-
|
|
2611
|
+
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
|
2746
2612
|
}
|
|
2747
2613
|
|
|
2748
2614
|
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
|
@@ -2753,9 +2619,9 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
|
|
2753
2619
|
return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
|
2754
2620
|
}
|
|
2755
2621
|
|
|
2756
|
-
ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
|
|
2622
|
+
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
|
|
2757
2623
|
|
|
2758
|
-
ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
|
|
2624
|
+
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
|
|
2759
2625
|
return ggml_backend_metal_init();
|
|
2760
2626
|
|
|
2761
2627
|
GGML_UNUSED(params);
|