@lgrammel/ds4-provider 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +96 -0
- package/binding.gyp +75 -0
- package/dist/ds4-language-model.d.ts +71 -0
- package/dist/ds4-language-model.d.ts.map +1 -0
- package/dist/ds4-language-model.js +888 -0
- package/dist/ds4-language-model.js.map +1 -0
- package/dist/ds4-provider.d.ts +13 -0
- package/dist/ds4-provider.d.ts.map +1 -0
- package/dist/ds4-provider.js +20 -0
- package/dist/ds4-provider.js.map +1 -0
- package/dist/index.d.ts +4 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +4 -0
- package/dist/index.js.map +1 -0
- package/dist/native-binding.d.ts +42 -0
- package/dist/native-binding.d.ts.map +1 -0
- package/dist/native-binding.js +157 -0
- package/dist/native-binding.js.map +1 -0
- package/ds4/LICENSE +22 -0
- package/ds4/ds4.c +18268 -0
- package/ds4/ds4.h +196 -0
- package/ds4/ds4_gpu.h +804 -0
- package/ds4/ds4_metal.m +14657 -0
- package/ds4/metal/argsort.metal +266 -0
- package/ds4/metal/bin.metal +192 -0
- package/ds4/metal/concat.metal +62 -0
- package/ds4/metal/cpy.metal +57 -0
- package/ds4/metal/dense.metal +1121 -0
- package/ds4/metal/dsv4_hc.metal +861 -0
- package/ds4/metal/dsv4_kv.metal +227 -0
- package/ds4/metal/dsv4_misc.metal +1088 -0
- package/ds4/metal/dsv4_rope.metal +155 -0
- package/ds4/metal/flash_attn.metal +1426 -0
- package/ds4/metal/get_rows.metal +54 -0
- package/ds4/metal/glu.metal +36 -0
- package/ds4/metal/moe.metal +1737 -0
- package/ds4/metal/norm.metal +153 -0
- package/ds4/metal/repeat.metal +52 -0
- package/ds4/metal/set_rows.metal +55 -0
- package/ds4/metal/softmax.metal +241 -0
- package/ds4/metal/sum_rows.metal +102 -0
- package/ds4/metal/unary.metal +312 -0
- package/native/binding.cpp +621 -0
- package/package.json +66 -0
- package/scripts/postinstall.cjs +13 -0
- package/scripts/vendor-ds4.cjs +67 -0
|
@@ -0,0 +1,1737 @@
|
|
|
1
|
+
// DS4 Metal routed-MoE matvec kernels.
|
|
2
|
+
|
|
3
|
+
// Number of weights in one K-quant / IQ2 super-block (every block_* struct
// below encodes QK_K values).
#define QK_K 256
// Rows-per-SIMD-group constants for each weight format. Presumably passed as
// the nr0 template argument of the matching kernel_mul_mv_*_impl
// instantiation (the instantiations are not visible here) — TODO confirm.
#define N_R0_Q2_K 4
#define N_R0_Q4_K 2
#define N_R0_IQ2_XXS 4
|
|
7
|
+
|
|
8
|
+
// Single-bit masks: entry i isolates bit i of a packed 8-entry sign byte.
static constant uchar ds4_metal_kmask_iq2xs[8] = {
    0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
};
|
|
11
|
+
|
|
12
|
+
// 7-bit sign index -> 8-bit sign byte. Entry i equals i with bit 7 set to the
// parity of i, so every entry has an even number of set (negative) bits.
static constant uchar ds4_metal_ksigns_iq2xs[128] = {
    /*   0 */   0, 129, 130,   3, 132,   5,   6, 135,
    /*   8 */ 136,   9,  10, 139,  12, 141, 142,  15,
    /*  16 */ 144,  17,  18, 147,  20, 149, 150,  23,
    /*  24 */  24, 153, 154,  27, 156,  29,  30, 159,
    /*  32 */ 160,  33,  34, 163,  36, 165, 166,  39,
    /*  40 */  40, 169, 170,  43, 172,  45,  46, 175,
    /*  48 */  48, 177, 178,  51, 180,  53,  54, 183,
    /*  56 */ 184,  57,  58, 187,  60, 189, 190,  63,
    /*  64 */ 192,  65,  66, 195,  68, 197, 198,  71,
    /*  72 */  72, 201, 202,  75, 204,  77,  78, 207,
    /*  80 */  80, 209, 210,  83, 212,  85,  86, 215,
    /*  88 */ 216,  89,  90, 219,  92, 221, 222,  95,
    /*  96 */  96, 225, 226,  99, 228, 101, 102, 231,
    /* 104 */ 232, 105, 106, 235, 108, 237, 238, 111,
    /* 112 */ 240, 113, 114, 243, 116, 245, 246, 119,
    /* 120 */ 120, 249, 250, 123, 252, 125, 126, 255,
};
|
|
22
|
+
|
|
23
|
+
// IQ2_XXS codebook: 256 entries, each packing 8 unsigned byte magnitudes
// (only 0x08, 0x19 and 0x2b occur). Consumers index an entry and then read it
// byte-wise through a uint8_t pointer; signs are applied separately from
// ds4_metal_ksigns_iq2xs.
static constant ulong ds4_metal_iq2xxs_grid[256] = {
    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
};
|
|
89
|
+
|
|
90
|
+
// Short, unprefixed aliases for the lookup tables. The dequantize_iq2_xxs
// helper refers to the tables by these names.
#define kmask_iq2xs ds4_metal_kmask_iq2xs
#define ksigns_iq2xs ds4_metal_ksigns_iq2xs
#define iq2xxs_grid ds4_metal_iq2xxs_grid
|
|
93
|
+
|
|
94
|
+
// Q2_K super-block: QK_K 2-bit weights.
//   scales: one byte per 16 weights; low nibble = scale, high nibble = min.
//   qs:     four 2-bit weights per byte.
//   d/dmin: super-block scale factors for the scales and mins.
// The matvec kernels address d and dmin as dh[0]/dh[1] and step rows by raw
// byte strides, so the layout below must stay packed exactly as declared.
struct block_q2_K {
    uchar scales[QK_K/16];
    uchar qs[QK_K/4];
    half d;
    half dmin;
};
static_assert(sizeof(block_q2_K) == 2*sizeof(half) + QK_K/16 + QK_K/4,
              "wrong q2_K block size/padding");
|
|
100
|
+
|
|
101
|
+
// Q4_K super-block: QK_K 4-bit weights.
//   d/dmin: super-block scale factors (read as dh[0]/dh[1] by the kernels).
//   scales: 12 bytes packing eight 6-bit scales and eight 6-bit mins.
//   qs:     two 4-bit weights per byte.
// The kernels reinterpret scales/qs as uint16_t arrays and step rows by raw
// byte strides, so the layout must stay packed exactly as declared.
struct block_q4_K {
    half d;
    half dmin;
    uchar scales[12];
    uchar qs[QK_K/2];
};
static_assert(sizeof(block_q4_K) == 2*sizeof(half) + 12 + QK_K/2,
              "wrong q4_K block size/padding");
|
|
107
|
+
|
|
108
|
+
// IQ2_XXS super-block: QK_K weights.
//   d:  super-block scale.
//   qs: 4 uint16 per 32 weights. The first two hold four 8-bit grid indices;
//       the last two hold four 7-bit sign indices plus a 4-bit sub-block scale
//       in the top bits.
// The kernels step rows by raw byte strides, so the layout must stay packed.
struct block_iq2_xxs {
    half d;
    ushort qs[QK_K/8];
};
static_assert(sizeof(block_iq2_xxs) == sizeof(half) + (QK_K/8)*sizeof(ushort),
              "wrong iq2_xxs block size/padding");
|
|
112
|
+
|
|
113
|
+
// Arguments for kernel_dsv4_moe_swiglu_weight{,_f16}.
// NOTE(review): presumably mirrored by a host-side struct; keep field order
// and types in sync with it.
struct ds4_metal_dsv4_moe_swiglu_weight_args {
    uint32_t width;            // elements processed per row
    uint32_t rows;             // number of rows (one threadgroup each)
    uint64_t gate_row_stride;  // byte stride between gate rows
    uint64_t up_row_stride;    // byte stride between up rows
    uint64_t mid_row_stride;   // byte stride between output (mid) rows
    uint64_t weight_stride;    // byte stride between per-row route weights
    uint32_t write_clamped;    // nonzero: store clamped gate/up back in place
    float clamp_value;         // clamp limit; <= 1e-6 disables clamping
};
|
|
123
|
+
|
|
124
|
+
// Fused routed-MoE activation, F32 output.
//
// For each selected-expert row: optionally clamp gate from above and up to
// [-clamp, clamp], then write silu(gate) * up * route_weight into mid. One
// threadgroup handles one row; its threads stride over the columns.
//
// The clamped gate/up values are only written back when args.write_clamped
// is set. Normal inference never reads them again, so skipping that store
// saves bandwidth. The write-back exists for diagnostic comparison against
// the older multi-kernel path.
kernel void kernel_dsv4_moe_swiglu_weight(
        constant ds4_metal_dsv4_moe_swiglu_weight_args &args,
        device char *gate,
        device char *up,
        device char *mid,
        device const char *weights,
        uint row [[threadgroup_position_in_grid]],
        uint tid [[thread_position_in_threadgroup]],
        uint ntg [[threads_per_threadgroup]]) {
    if (row >= args.rows) {
        return;
    }

    const uint64_t r = (uint64_t)row;
    device float *g_ptr = (device float *)(gate + r * args.gate_row_stride);
    device float *u_ptr = (device float *)(up   + r * args.up_row_stride);
    device float *out   = (device float *)(mid  + r * args.mid_row_stride);
    const float scale   = *(device const float *)(weights + r * args.weight_stride);

    const float limit        = args.clamp_value;
    const bool do_clamp      = limit > 1.0e-6f;
    const bool store_clamped = do_clamp && args.write_clamped != 0;

    for (uint col = tid; col < args.width; col += ntg) {
        float gv = g_ptr[col];
        float uv = u_ptr[col];
        if (do_clamp) {
            gv = min(gv, limit);
            uv = clamp(uv, -limit, limit);
        }
        if (store_clamped) {
            g_ptr[col] = gv;
            u_ptr[col] = uv;
        }
        const float act = gv / (1.0f + exp(-gv));
        out[col] = act * uv * scale;
    }
}
|
|
162
|
+
|
|
163
|
+
// Fused routed-MoE activation, F16 output.
//
// Identical math to the F32 variant (clamp gate from above and up
// symmetrically, then silu(gate) * up * route_weight), but the result is
// stored as half. The grouped matmul converts its F32 activations to half
// before the MMA anyway, so writing half here halves the mid-buffer traffic
// without changing the effective matmul input precision.
// Clamped gate/up are written back only when args.write_clamped is set.
kernel void kernel_dsv4_moe_swiglu_weight_f16(
        constant ds4_metal_dsv4_moe_swiglu_weight_args &args,
        device char *gate,
        device char *up,
        device char *mid,
        device const char *weights,
        uint row [[threadgroup_position_in_grid]],
        uint tid [[thread_position_in_threadgroup]],
        uint ntg [[threads_per_threadgroup]]) {
    if (row >= args.rows) {
        return;
    }

    const uint64_t r = (uint64_t)row;
    device float *g_ptr = (device float *)(gate + r * args.gate_row_stride);
    device float *u_ptr = (device float *)(up   + r * args.up_row_stride);
    device half  *out   = (device half  *)(mid  + r * args.mid_row_stride);
    const float scale   = *(device const float *)(weights + r * args.weight_stride);

    const float limit        = args.clamp_value;
    const bool do_clamp      = limit > 1.0e-6f;
    const bool store_clamped = do_clamp && args.write_clamped != 0;

    for (uint col = tid; col < args.width; col += ntg) {
        float gv = g_ptr[col];
        float uv = u_ptr[col];
        if (do_clamp) {
            gv = min(gv, limit);
            uv = clamp(uv, -limit, limit);
        }
        if (store_clamped) {
            g_ptr[col] = gv;
            u_ptr[col] = uv;
        }
        const float act = gv / (1.0f + exp(-gv));
        out[col] = (half)(act * uv * scale);
    }
}
|
|
200
|
+
|
|
201
|
+
// Dequantize 16 consecutive Q2_K weights into a 4x4 register tile.
// il (0..15) indexes the 16-weight group within the super-block:
//   scales[il] holds the group's scale (low nibble) and min (high nibble);
//   il/8 selects the 32-byte half of qs, il&1 the 16-byte slice within it,
//   and (il/2)%4 the 2-bit plane inside each byte.
// The plane is masked without shifting, so the result is rescaled by
// coef = 1/4^plane.
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
    const float d_scale  = xb->d;
    const float d_min    = xb->dmin;
    const uint8_t packed = xb->scales[il];

    device const uint8_t *src = (device const uint8_t *)xb->qs + 32*(il/8) + 16*(il&1);

    const short plane = (il/2)%4;
    const uchar mask  = uchar(3 << (2*plane));
    const half  coef  = plane == 0 ? 1.h : (plane == 1 ? 1/4.h : (plane == 2 ? 1/16.h : 1/64.h));

    const float dl = d_scale * (packed & 0xF) * coef;
    const float ml = d_min * (packed >> 4);

    for (short r = 0; r < 4; ++r) {
        for (short c = 0; c < 4; ++c) {
            reg[r][c] = dl * (src[4*r + c] & mask) - ml;
        }
    }
}
|
|
219
|
+
|
|
220
|
+
// Unpack the 6-bit (scale, min) pair for sub-block j+k from a 12-byte K-quant
// scales array. Sub-blocks 0..3 store both values directly in the low 6 bits.
// Sub-blocks 4..7 combine a nibble from bytes 8..11 with the top two bits of
// an earlier byte.
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
    if (j < 4) {
        const uchar scale = uchar(q[j + k] & 63);
        const uchar mn    = uchar(q[j + 4 + k] & 63);
        return uchar2{scale, mn};
    }
    const uchar scale = uchar((q[j + 4 + k] & 0xF) | ((q[j - 4 + k] & 0xc0) >> 2));
    const uchar mn    = uchar((q[j + 4 + k] >> 4)  | ((q[j + k]     & 0xc0) >> 2));
    return uchar2{scale, mn};
}
|
|
225
|
+
|
|
226
|
+
// Dequantize 16 consecutive Q4_K weights into a 4x4 register tile.
// il (0..15) indexes the 16-weight group within the super-block:
//   il/4      selects the 32-byte chunk of qs (64 weights);
//   il&1      selects the 16-byte slice of that chunk;
//   (il&3)/2  selects low (0) or high (1) nibbles, i.e. the 32-weight
//             sub-block whose scale/min is fetched.
// High nibbles are masked without shifting, so d is pre-divided by 16.
// (Fixes the reference parameter, which had been mangled to "®" and did
// not compile.)
template <typename type4x4>
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 &reg) {
    device const uchar *q = xb->qs;

    short is = (il / 4) * 2;
    q = q + (il / 4) * 32 + 16 * (il & 1);
    il = il & 3;
    const uchar2 sc = get_scale_min_k4_just2(is, il / 2, xb->scales);
    const float d = il < 2 ? xb->d : xb->d / 16.h;
    const float min = xb->dmin;
    const float dl = d * sc[0];
    const float ml = min * sc[1];

    const ushort mask = il < 2 ? 0x0F : 0xF0;
    for (int i = 0; i < 16; ++i) {
        reg[i / 4][i % 4] = dl * (q[i] & mask) - ml;
    }
}
|
|
244
|
+
|
|
245
|
+
// Dequantize 16 consecutive IQ2_XXS weights into a 4x4 register tile.
// il (0..15): il/2 selects the 32-weight sub-block (4 uint16 of qs), il%2
// selects its first or second 16 weights. Each 8-weight group takes one byte
// of the first uint32 as a grid index and 7 bits of the second uint32 as a
// sign index. The top 4 bits of the second uint32 hold the sub-block scale.
template <typename type4x4>
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
    const float d    = xb->d;
    const short blk  = il/2;
    const short part = il%2;

    device const uint16_t *words = xb->qs + 4*blk;
    const uint32_t grid_idx   = words[0] | (words[1] << 16);
    const uint32_t sign_scale = words[2] | (words[3] << 16);
    thread const uint8_t *idx8 = (thread const uint8_t *)&grid_idx;

    const float scale = d * (0.5f + (sign_scale >> 28)) * 0.25f;

    for (short g = 0; g < 2; ++g) {
        constant uint8_t *grid = (constant uint8_t *)(iq2xxs_grid + idx8[2*part + g]);
        const uint8_t signs = ksigns_iq2xs[(sign_scale >> (14*part + 7*g)) & 127];
        for (short i = 0; i < 8; ++i) {
            reg[2*g + i/4][i%4] = scale * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
        }
    }
}
|
|
266
|
+
|
|
267
|
+
// Arguments for the indirect (expert-id) matrix-vector kernels; not used in
// the visible part of this file.
// NOTE(review): names appear to follow the ggml convention (ne* = element
// counts, nb* = byte strides; 0 = weights, 1 = activations, i = ids, none =
// dst) — confirm against the host-side definition, and keep field order and
// types identical to it.
struct ds4_metal_args_mul_mv_id {
    int32_t nei0;
    int32_t nei1;
    uint64_t nbi1;
    int32_t ne00;
    int32_t ne01;
    int32_t ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    int32_t ne10;
    int32_t ne11;
    int32_t ne12;
    int32_t ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    int32_t ne0;
    int32_t ne1;
    uint64_t nb1;
    int32_t nr0;
};
|
|
289
|
+
|
|
290
|
+
// Arguments for the expert-id mapping pass that precedes the grouped
// matrix-matrix kernels; not used in the visible part of this file.
// NOTE(review): presumably ggml-style counts (ne*) and byte strides (nb*),
// with *2* referring to the expert-id tensor — confirm on the host side.
// Field order and types must match the host-side struct.
struct ds4_metal_args_mul_mm_id_map0 {
    int32_t ne02;
    int32_t ne10;
    int32_t ne11;
    uint64_t nb11;
    uint64_t nb12;
    int32_t ne21;
    int32_t ne20;
    uint64_t nb21;
};
|
|
300
|
+
|
|
301
|
+
// Arguments for the grouped (expert-id) matrix-matrix kernels; not used in
// the visible part of this file.
// NOTE(review): presumably ggml-style counts (ne*), byte strides (nb*) and
// broadcast ratios (r2/r3) — confirm against the host-side definition.
// Field order and types must match it.
struct ds4_metal_args_mul_mm_id {
    int32_t ne00;
    int32_t ne02;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t ne11;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t ne20;
    int32_t ne21;
    int32_t ne0;
    int32_t ne1;
    int16_t r2;
    int16_t r3;
};
|
|
319
|
+
|
|
320
|
+
// Q2_K x F32 matrix-vector product. Each SIMD group produces nr0 consecutive
// output rows starting at first_row.
//
// Lane mapping (tiisg = lane within the 32-wide SIMD group):
//   ix = lane/8 -> first super-block; the lane then steps by 4 super-blocks.
//   iq          -> which 128-weight half of the super-block (32 qs bytes).
//   ir          -> which 8-byte column slice of that half.
// Each qs byte packs 4 two-bit weights that are 32 positions apart. A lane
// therefore loads 4 groups of 8 activations (stride 32) and pairs group k
// with bit plane k of the same bytes. sumy[k] keeps each group's activation
// sum for the min term.
//
// NOTE(review): FC_mul_mv_nsg (SIMD groups per threadgroup) and the args_t
// fields (ne00, ne0, ne1, ne12, nb01..nb03, nb11..nb13, r2, r3) are defined
// outside this chunk.
template<int nr0, typename args_t>
void kernel_mul_mv_q2_K_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem,
        uint3 tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;   // super-blocks per weight row

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    // src0 is indexed by i12/r2, i13/r3, so several src1 batches may share
    // one weight slice.
    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;

    device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
    device const float * y = (device const float *) (src1 + offset1);

    float yl[32];
    float sumf[nr0]={0.f};

    const short ix = tiisg/8; // 0...3
    const short it = tiisg%8; // 0...7
    const short iq = it/4; // 0 or 1
    const short ir = it%4; // 0...3
    const short is = (8*ir)/16;// 0 or 1: which 16-byte half -> which scale

    device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;

    for (int ib = ix; ib < nb; ib += 4) {
        // Gather the 4 bit-plane activation groups and their sums.
        float4 sumy = {0.f, 0.f, 0.f, 0.f};
        for (short i = 0; i < 8; ++i) {
            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
            yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
            yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
            yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
        }

        // sc[2*k] is the scale/min byte for bit plane k; dh[0] = d, dh[1] = dmin.
        device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
        device const half * dh = &x[ib].d;

        for (short row = 0; row < nr0; row++) {
            // Two bytes are read at once: acc1 collects the low byte, acc2 the
            // high byte (still scaled by 256, undone with 1/256 below). Planes
            // are masked unshifted and rescaled by 1/4^k in the sum.
            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
            for (int i = 0; i < 8; i += 2) {
                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
            }
            float dall = dh[0];
            // Mins are read as (sc & 0xF0), i.e. x16, hence the 1/16.
            float dmin = dh[1] * 1.f/16.f;
            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
                         dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));

            // Advance one weight row (nb01 bytes) in each pointer's own units.
            qs += args.nb01/2;
            sc += args.nb01;
            dh += args.nb01/2;
        }

        y4 += 4 * QK_K;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    // Reduce across the SIMD group; lane 0 writes each in-range row.
    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}
|
|
411
|
+
|
|
412
|
+
// Q4_K x F32 matrix-vector product. Each SIMD group produces nr0 consecutive
// output rows starting at first_row.
//
// Lane mapping (tiisg = lane within the 32-wide SIMD group):
//   ix = lane/8 -> first super-block; the lane then steps by 4 super-blocks.
//   iq          -> 32-byte qs chunk iq (yl) and chunk iq+2 (yh, via q1+32).
//   ir          -> which 8-byte column slice of those chunks.
// Within a 32-byte chunk, low nibbles are the first 32 weights and high
// nibbles the next 32. yl/yh therefore each hold 8 activations for the low
// nibbles and 8 (32 positions later) for the high nibbles.
//
// NOTE(review): FC_mul_mv_nsg, FOR_UNROLL and the args_t fields are defined
// outside this chunk.
template<int nr0, typename args_t>
void kernel_mul_mv_q4_K_f32_impl(
        args_t args,
        device const char *src0,
        device const char *src1,
        device char *dst,
        threadgroup char *shmem,
        uint3 tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    // Masks for unpacking two 6-bit scale/min values per uint16 at once.
    constexpr uint16_t kmask1 = 0x3f3f;
    constexpr uint16_t kmask2 = 0x0f0f;
    constexpr uint16_t kmask3 = 0xc0c0;

    const short ix = tiisg / 8;
    const short it = tiisg % 8;
    const short iq = it / 4;
    const short ir = it % 4;

    const int nb = args.ne00 / QK_K;   // super-blocks per weight row

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im % args.ne12;
    const uint i13 = im / args.ne12;

    // src0 is indexed by i12/r2, i13/r3, so several src1 batches may share
    // one weight slice.
    const uint64_t offset0 = first_row * args.nb01 + (i12 / args.r2) * args.nb02 + (i13 / args.r3) * args.nb03;
    const uint64_t offset1 = r1 * args.nb11 + i12 * args.nb12 + i13 * args.nb13;

    device const block_q4_K *x = (device const block_q4_K *)(src0 + offset0);
    device const float *y = (device const float *)(src1 + offset1);

    float yl[16];
    float yh[16];
    float sumf[nr0] = {0.f};

    device const float *y4 = y + ix * QK_K + 64 * iq + 8 * ir;

    // sc16 holds the unpacked scales/mins; sc8 views them as bytes:
    // [0],[1] scales of sub-blocks 2iq, 2iq+1; [2],[3] their mins;
    // [4],[5] scales of sub-blocks 2iq+4, 2iq+5; [6],[7] their mins.
    uint16_t sc16[4];
    thread const uint8_t *sc8 = (thread const uint8_t *)sc16;

    for (int ib = ix; ib < nb; ib += 4) {
        float4 sumy = {0.f, 0.f, 0.f, 0.f};

        for (short i = 0; i < 8; ++i) {
            yl[i + 0] = y4[i + 0]; sumy[0] += yl[i + 0];
            yl[i + 8] = y4[i + 32]; sumy[1] += yl[i + 8];
            yh[i + 0] = y4[i + 128]; sumy[2] += yh[i + 0];
            yh[i + 8] = y4[i + 160]; sumy[3] += yh[i + 8];
        }

        device const uint16_t *sc = (device const uint16_t *)x[ib].scales + iq;
        device const uint16_t *q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
        device const half *dh = &x[ib].d;   // dh[0] = d, dh[1] = dmin

        for (short row = 0; row < nr0; row++) {
            sc16[0] = sc[0] & kmask1;
            sc16[1] = sc[2] & kmask1;
            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);

            device const uint16_t *q2 = q1 + 32;   // chunk iq+2 (64 bytes on)

            // Nibbles are masked unshifted: high-byte terms carry x256 and
            // high-nibble terms x16, compensated by the constants below.
            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
            float4 acc2 = {0.f, 0.f, 0.f, 0.f};

            FOR_UNROLL (short i = 0; i < 4; ++i) {
                acc1[0] += yl[2 * i + 0] * (q1[i] & 0x000F);
                acc1[1] += yl[2 * i + 1] * (q1[i] & 0x0F00);
                acc1[2] += yl[2 * i + 8] * (q1[i] & 0x00F0);
                acc1[3] += yl[2 * i + 9] * (q1[i] & 0xF000);
                acc2[0] += yh[2 * i + 0] * (q2[i] & 0x000F);
                acc2[1] += yh[2 * i + 1] * (q2[i] & 0x0F00);
                acc2[2] += yh[2 * i + 8] * (q2[i] & 0x00F0);
                acc2[3] += yh[2 * i + 9] * (q2[i] & 0xF000);
            }

            sumf[row] += dh[0] * ((acc1[0] + 1.f / 256.f * acc1[1]) * sc8[0] +
                                  (acc1[2] + 1.f / 256.f * acc1[3]) * sc8[1] * 1.f / 16.f +
                                  (acc2[0] + 1.f / 256.f * acc2[1]) * sc8[4] +
                                  (acc2[2] + 1.f / 256.f * acc2[3]) * sc8[5] * 1.f / 16.f) -
                         dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);

            // Advance one weight row (nb01 bytes) in uint16/half units.
            q1 += args.nb01 / 2;
            sc += args.nb01 / 2;
            dh += args.nb01 / 2;
        }

        y4 += 4 * QK_K;
    }

    device float *dst_f32 = (device float *)dst + (uint64_t)im * args.ne0 * args.ne1 + (uint64_t)r1 * args.ne0;

    // Reduce across the SIMD group; lane 0 writes each in-range row.
    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }

    (void)shmem;   // this variant needs no threadgroup memory
}
|
|
520
|
+
|
|
521
|
+
// IQ2_XXS x F32 matrix-vector product. Each SIMD group produces nr0
// consecutive output rows starting at first_row. Each lane processes whole
// 32-weight sub-blocks, stepping by 32 sub-blocks.
//
// The IQ2_XXS codebook (256 x uint64) and sign table (128 bytes) are first
// staged in threadgroup memory. shmem must provide at least
// 256*sizeof(uint64_t) + 128 bytes.
//
// The staging copy strides over every thread of the threadgroup
// (32 * NSG lanes). The previous version gave each thread a fixed 4 grid /
// 2 sign entries, which only covered the tables when NSG == 2:
//   - NSG > 2 read past the constant tables and wrote past the threadgroup
//     area;
//   - NSG == 1 left half of each table uninitialized.
// Results for NSG == 2 are unchanged.
//
// The 0.25 factor of the IQ2_XXS scale is applied once, after the reduction.
template<int nr0, typename args_t>
void kernel_mul_mv_iq2_xxs_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem,
        uint3 tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;   // super-blocks per weight row

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;

    device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
    device const float * y = (device const float *) (src1 + offset1);

    float yl[32];
    float sumf[nr0]={0.f};

    const int nb32 = nb * (QK_K / 32);   // 32-weight sub-blocks per row

    threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
    threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256);
    {
        // Cooperative, NSG-independent copy of both lookup tables.
        const int nthreads = 32*NSG;
        const int lid = 32*sgitg + tiisg;
        for (int i = lid; i < 256; i += nthreads) svalues[i] = ds4_metal_iq2xxs_grid[i];
        for (int i = lid; i < 128; i += nthreads) ssigns[i] = ds4_metal_ksigns_iq2xs[i];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const int ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);   // super-block index
        const int ib = ib32 % (QK_K / 32);    // sub-block within it

        device const block_iq2_xxs * xr = x + ibl;
        device const uint16_t * q2 = xr->qs + 4 * ib;
        device const half * dh = &xr->d;

        for (short row = 0; row < nr0; row++) {
            const float db = dh[0];
            // Bytes 0..3: grid indices. Second uint32: four 7-bit sign
            // indices plus a 4-bit sub-block scale in the top bits.
            device const uint8_t * aux8 = (device const uint8_t *)q2;
            const uint32_t aux32 = q2[2] | (q2[3] << 16);
            const float d = db * (0.5f + (aux32 >> 28));

            float sum = 0;
            for (short l = 0; l < 4; ++l) {
                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
                const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
                for (short j = 0; j < 8; ++j) {
                    sum += yl[8*l + j] * grid[j] * (signs & ds4_metal_kmask_iq2xs[j] ? -1.f : 1.f);
                }
            }
            sumf[row] += d * sum;

            // Advance one weight row (nb01 bytes) in uint16/half units.
            dh += args.nb01/2;
            q2 += args.nb01/2;
        }

        y4 += 32 * 32;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all * 0.25f;
        }
    }
}
|
|
615
|
+
|
|
616
|
+
// Paired IQ2_XXS matvec used by the routed-expert gate/up projection. One pass
// over the activation row y computes both
//
//   dst_gate[r] = dot(W_gate[r, :], y)      dst_up[r] = dot(W_up[r, :], y)
//
// for the nr0 rows owned by this simdgroup, so y is loaded once and the
// IQ2_XXS decode tables in threadgroup memory are shared by both products.
//
// Work split: simdgroup sgitg of threadgroup tgpig.x owns rows
// [first_row, first_row + nr0). Lane tiisg handles the 32-value sub-blocks
// ib32 = tiisg, tiisg + 32, ... of those rows. The per-lane partial sums are
// combined with simd_sum and lane 0 stores the results.
//
// shmem must hold 256 uint64_t grid entries followed by 128 sign bytes.
template<int nr0>
void kernel_mul_mv_iq2_xxs_pair_f32_impl(
        ds4_metal_args_mul_mv args,
        device const char * src0_gate,
        device const char * src0_up,
        device const char * src1,
        device char * dst_gate,
        device char * dst_up,
        threadgroup char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;   // QK_K-value super-blocks per row

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    // Gate and up weights share one layout, so a single offset serves both.
    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;

    device const block_iq2_xxs * xg = (device const block_iq2_xxs *) (src0_gate + offset0);
    device const block_iq2_xxs * xu = (device const block_iq2_xxs *) (src0_up   + offset0);
    device const float         * y  = (device const float         *) (src1      + offset1);

    float yl[32];
    float sumg[nr0]={0.f};
    float sumu[nr0]={0.f};

    const int nb32 = nb * (QK_K / 32);   // 32-value sub-blocks per row

    // IQ2_XXS decode tables in threadgroup memory: 256 grid entries followed
    // by 128 sign bytes. Threads copy with a stride equal to the threadgroup
    // size, so both tables are filled for any simdgroup count. A fixed slice
    // of 4 grid entries per thread covers the grid exactly only when NSG == 2,
    // and overruns into the sign table when NSG > 2.
    threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 256);
    {
        const int nthreads = 32*NSG;
        const int tid      = 32*sgitg + tiisg;
        for (int i = tid; i < 256; i += nthreads) svalues[i] = ds4_metal_iq2xxs_grid[i];
        for (int i = tid; i < 128; i += nthreads) ssigns[i]  = ds4_metal_ksigns_iq2xs[i];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const int ix = tiisg;
    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);   // super-block index
        const int ib  = ib32 % (QK_K / 32);   // sub-block within the super-block

        device const block_iq2_xxs * xgr = xg + ibl;
        device const block_iq2_xxs * xur = xu + ibl;
        device const uint16_t * qg = xgr->qs + 4 * ib;
        device const uint16_t * qu = xur->qs + 4 * ib;
        device const half * dhg = &xgr->d;
        device const half * dhu = &xur->d;

        for (short row = 0; row < nr0; row++) {
            // Sub-block encoding: bytes 0..3 index the grid. The upper 32-bit
            // word holds four 7-bit sign-table indices and a 4-bit scale in
            // bits 28..31.
            device const uint8_t * aux8g = (device const uint8_t *)qg;
            device const uint8_t * aux8u = (device const uint8_t *)qu;
            const uint32_t aux32g = qg[2] | (qg[3] << 16);
            const uint32_t aux32u = qu[2] | (qu[3] << 16);
            const float dg = (float)dhg[0] * (0.5f + (aux32g >> 28));
            const float du = (float)dhu[0] * (0.5f + (aux32u >> 28));

            float sg = 0;
            float su = 0;
            for (short l = 0; l < 4; ++l) {
                const threadgroup uint8_t * gridg = (const threadgroup uint8_t *)(svalues + aux8g[l]);
                const threadgroup uint8_t * gridu = (const threadgroup uint8_t *)(svalues + aux8u[l]);
                const uint8_t signg = ssigns[(aux32g >> 7*l) & 127];
                const uint8_t signu = ssigns[(aux32u >> 7*l) & 127];
                for (short j = 0; j < 8; ++j) {
                    const float v = yl[8*l + j];
                    sg += v * gridg[j] * (signg & ds4_metal_kmask_iq2xs[j] ? -1.f : 1.f);
                    su += v * gridu[j] * (signu & ds4_metal_kmask_iq2xs[j] ? -1.f : 1.f);
                }
            }
            sumg[row] += dg * sg;
            sumu[row] += du * su;

            // Step to the same sub-block of the next row (nb01 bytes ahead).
            dhg += args.nb01/2;
            dhu += args.nb01/2;
            qg  += args.nb01/2;
            qu  += args.nb01/2;
        }

        y4 += 32 * 32;
    }

    device float * dst_gate_f32 = (device float *) dst_gate + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
    device float * dst_up_f32   = (device float *) dst_up   + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        const float sum_gate = simd_sum(sumg[row]);
        const float sum_up   = simd_sum(sumu[row]);
        if (tiisg == 0) {
            // The constant 0.25 factor of the IQ2_XXS scale is applied once,
            // after the reduction, instead of per sub-block.
            dst_gate_f32[first_row + row] = sum_gate * 0.25f;
            dst_up_f32[first_row + row]   = sum_up   * 0.25f;
        }
    }
}
|
|
729
|
+
|
|
730
|
+
// Signature shared by the quantized row kernels (`*_f32_impl`) that the
// routed-expert wrappers dispatch to.
using kernel_mul_mv2_disp_t = void(
        ds4_metal_args_mul_mv args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg);

// Adapter that gives a row kernel the full kernel-argument list. This includes
// the thread-in-threadgroup index, which the row kernels do not take. The
// adapter is what gets passed as the template argument of kernel_mul_mv_id.
template<kernel_mul_mv2_disp_t disp_fn>
void mmv_fn(
        ds4_metal_args_mul_mv args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem,
        uint3  tgpig,
        ushort tiitg,
        ushort tiisg,
        ushort sgitg) {
    (void) tiitg;   // accepted for signature compatibility only
    disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

// Function type common to every mmv_fn instantiation; used as the template
// parameter type of kernel_mul_mv_id.
using mul_mv_id_disp_fn_t = decltype(mmv_fn<kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K>>);
|
|
755
|
+
|
|
756
|
+
// Decode-time routed-expert matvec. Each grid z index names one
// (token, expert slot) pair as z = token * nei0 + slot, and the ids tensor
// gives the routed expert for that pair. This wrapper points the weight,
// activation and output pointers at that expert and row. It then runs the
// quantized row kernel `disp_fn` as a single-row matvec. The Q8_0, Q2_K, Q4_K
// and IQ2_XXS variants are instantiated below. Because the expert is selected
// on the GPU, the CPU never materializes per-expert dispatches.
template<mul_mv_id_disp_fn_t disp_fn>
kernel void kernel_mul_mv_id(
        constant ds4_metal_args_mul_mv_id & args,
        device const char * src0s,
        device const char * src1,
        device char * dst,
        device const char * ids,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    (void) tiitg;

    const int slot  = tgpig.z % args.nei0;
    const int token = tgpig.z / args.nei0;

    // The row kernel sees a 2D problem; the z coordinate is folded into the
    // pointer offsets below instead.
    tgpig.z = 0;

    const int32_t expert = ((device const int32_t *) (ids + token*args.nbi1))[slot];

    const int64_t src1_row = slot % args.ne11;   // 0 for every slot when ne11 == 1
    const int64_t out_row  = slot;
    const int64_t tok      = token;

    // Single-matrix, single-row view of the selected expert.
    ds4_metal_args_mul_mv row_args = {
        /*.ne00 =*/ args.ne00,
        /*.ne01 =*/ args.ne01,
        /*.ne02 =*/ 1,
        /*.nb00 =*/ args.nb00,
        /*.nb01 =*/ args.nb01,
        /*.nb02 =*/ args.nb02,
        /*.nb03 =*/ args.nb02,
        /*.ne10 =*/ args.ne10,
        /*.ne11 =*/ 1,
        /*.ne12 =*/ 1,
        /*.nb10 =*/ args.nb10,
        /*.nb11 =*/ args.nb11,
        /*.nb12 =*/ args.nb12,
        /*.nb13 =*/ args.nb12,
        /*.ne0  =*/ args.ne0,
        /*.ne1  =*/ 1,
        /*.nr0  =*/ args.nr0,
        /*.r2   =*/ 1,
        /*.r3   =*/ 1,
    };

    disp_fn(
        row_args,
        /* src0 */ src0s + expert*args.nb02,
        /* src1 */ src1 + src1_row*args.nb11 + tok*args.nb12,
        /* dst  */ dst + (out_row*args.ne0 + tok*args.ne1*args.ne0)*sizeof(float),
        shmem,
        tgpig,
        tiitg,
        tiisg,
        sgitg);
}
|
|
824
|
+
|
|
825
|
+
// Function types of the id-matvec kernels. The kernel_mul_mv_id signature
// does not depend on the row kernel, so the Q2_K-derived alias is also used
// for the Q4_K and IQ2_XXS instantiations.
using kernel_mul_mv_id_q_t    = decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K>>>);
using kernel_mul_mv_id_q8_0_t = decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>);

// Host-visible decode MoE matvec variants for the DS4 quant formats. The
// host_name strings are the names under which the host looks these kernels up.
template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_q8_0_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_q_t    kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K>>>;
template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_q_t    kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K>>>;
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_q_t    kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS>>>;
|
|
833
|
+
|
|
834
|
+
// DS4 attention output low projection, specialized for the fixed block-
// diagonal mapping used by the model:
//
//   low[token, group, rank] = heads[token, group, :] * Woa[group, rank, :]
//
// The generic GGML-style id matvec supports arbitrary routed expert ids. Here
// the id always equals the group number. This wrapper therefore keeps the
// exact Q8_0 dot kernel but drops the id-buffer load and the CPU-side id
// table. Grid z enumerates z = token * nei0 + group.
kernel void kernel_dsv4_attn_out_low_q8_0_f32(
        constant ds4_metal_args_mul_mv_id & args,
        device const char * src0s,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    (void) tiitg;

    const int group = tgpig.z % args.nei0;
    const int token = tgpig.z / args.nei0;

    // The row kernel sees a 2D problem; z is folded into the pointers below.
    tgpig.z = 0;

    const int64_t src1_row = group % args.ne11;
    const int64_t tok      = token;

    device const char * w_group   = src0s + group*args.nb02;   // expert == group
    device const char * act_row   = src1 + src1_row*args.nb11 + tok*args.nb12;
    device char       * out_group = dst + (group*args.ne0 + tok*args.ne1*args.ne0)*sizeof(float);

    ds4_metal_args_mul_mv row_args = {
        /*.ne00 =*/ args.ne00,
        /*.ne01 =*/ args.ne01,
        /*.ne02 =*/ 1,
        /*.nb00 =*/ args.nb00,
        /*.nb01 =*/ args.nb01,
        /*.nb02 =*/ args.nb02,
        /*.nb03 =*/ args.nb02,
        /*.ne10 =*/ args.ne10,
        /*.ne11 =*/ 1,
        /*.ne12 =*/ 1,
        /*.nb10 =*/ args.nb10,
        /*.nb11 =*/ args.nb11,
        /*.nb12 =*/ args.nb12,
        /*.nb13 =*/ args.nb12,
        /*.ne0  =*/ args.ne0,
        /*.ne1  =*/ 1,
        /*.nr0  =*/ args.nr0,
        /*.r2   =*/ 1,
        /*.r3   =*/ 1,
    };

    kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, thread ds4_metal_args_mul_mv &>(
        row_args, w_group, act_row, out_group, shmem, tgpig, tiisg, sgitg);
}
|
|
896
|
+
|
|
897
|
+
// Routed gate/up projection pair for IQ2_XXS experts. Both matvecs for one
// (token, slot) are evaluated in a single dispatch through the paired row
// kernel, so the activation row and the decode tables are shared. Grid z
// enumerates z = token * nei0 + slot. Outputs use the same (slot, token) row
// layout as kernel_mul_mv_id.
kernel void kernel_mul_mv_id_iq2_xxs_pair_f32(
        constant ds4_metal_args_mul_mv_id & args,
        device const char * src0_gate,
        device const char * src0_up,
        device const char * src1,
        device char * dst_gate,
        device char * dst_up,
        device const char * ids,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    (void) tiitg;

    const int slot  = tgpig.z % args.nei0;
    const int token = tgpig.z / args.nei0;
    tgpig.z = 0;   // z is folded into the pointer offsets below

    const int32_t expert = ((device const int32_t *) (ids + token*args.nbi1))[slot];

    const int64_t src1_row = slot % args.ne11;
    const int64_t tok      = token;

    device const char * act       = src1 + src1_row*args.nb11 + tok*args.nb12;
    const auto          out_bytes = (slot*args.ne0 + tok*args.ne1*args.ne0)*sizeof(float);

    ds4_metal_args_mul_mv row_args = {
        /*.ne00 =*/ args.ne00,
        /*.ne01 =*/ args.ne01,
        /*.ne02 =*/ 1,
        /*.nb00 =*/ args.nb00,
        /*.nb01 =*/ args.nb01,
        /*.nb02 =*/ args.nb02,
        /*.nb03 =*/ args.nb02,
        /*.ne10 =*/ args.ne10,
        /*.ne11 =*/ 1,
        /*.ne12 =*/ 1,
        /*.nb10 =*/ args.nb10,
        /*.nb11 =*/ args.nb11,
        /*.nb12 =*/ args.nb12,
        /*.nb13 =*/ args.nb12,
        /*.ne0  =*/ args.ne0,
        /*.ne1  =*/ 1,
        /*.nr0  =*/ args.nr0,
        /*.r2   =*/ 1,
        /*.r3   =*/ 1,
    };

    kernel_mul_mv_iq2_xxs_pair_f32_impl<N_R0_IQ2_XXS>(
        row_args,
        src0_gate + expert*args.nb02,
        src0_up   + expert*args.nb02,
        act,
        dst_gate + out_bytes,
        dst_up   + out_bytes,
        shmem,
        tgpig,
        tiisg,
        sgitg);
}
|
|
948
|
+
|
|
949
|
+
// Decode-only routed expert gate/up projection fused with the DS4 activation:
//
//   mid = silu(clamp(gate)) * clamp(up) * route_weight
//
// with clamp(gate) = min(gate, c) and clamp(up) = clamp(up, -c, c), where
// c = act.clamp_value. No clamping is applied when c <= 1e-6.
//
// The quantized dot products follow the same IQ2_XXS paired path as
// `kernel_mul_mv_id_iq2_xxs_pair_f32`. The only extra work is done by lane 0
// after each exact reduced row has been produced. This removes the separate
// routed activation dispatch and avoids rereading the gate/up rows before the
// down projection. The unclamped gate/up rows are still written to
// dst_gate/dst_up, using the pair kernel's (slot, token) row layout. The host
// uses this only for the normal release path where diagnostics do not request
// clamped gate/up intermediates.
//
// Grid z = token * nei0 + slot. ids[token, slot] selects the expert. The route
// weight and the mid row are addressed by slot via act.weight_stride and
// act.mid_row_stride.
kernel void kernel_mul_mv_id_iq2_xxs_pair_swiglu_f32(
        constant ds4_metal_args_mul_mv_id & args,
        constant ds4_metal_dsv4_moe_swiglu_weight_args & act,
        device const char * src0_gate,
        device const char * src0_up,
        device const char * src1,
        device char * dst_gate,
        device char * dst_up,
        device char * dst_mid,
        device const char * ids,
        device const char * weights,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    const short NSG = FC_mul_mv_nsg;
    const int iid1 = tgpig.z / args.nei0;   // token
    const int idx  = tgpig.z % args.nei0;   // expert slot

    tgpig.z = 0;

    const int32_t i02 = ((device const int32_t *) (ids + iid1 * args.nbi1))[idx];
    const int64_t i11 = idx % args.ne11;   // activation row (0 when src1 is broadcast)
    const int64_t i12 = iid1;

    const int nb        = args.ne00 / QK_K;
    const int first_row = (tgpig.x * NSG + sgitg) * N_R0_IQ2_XXS;
    const int nb32      = nb * (QK_K / 32);

    device const block_iq2_xxs *xg =
        (device const block_iq2_xxs *)(src0_gate + i02 * args.nb02 + (uint64_t)first_row * args.nb01);
    device const block_iq2_xxs *xu =
        (device const block_iq2_xxs *)(src0_up + i02 * args.nb02 + (uint64_t)first_row * args.nb01);
    device const float *y =
        (device const float *)(src1 + i11 * args.nb11 + i12 * args.nb12);

    float yl[32];
    float sumg[N_R0_IQ2_XXS] = {0.f};
    float sumu[N_R0_IQ2_XXS] = {0.f};

    // IQ2_XXS decode tables in threadgroup memory: 256 grid entries followed
    // by 128 sign bytes. Threads copy with a stride equal to the threadgroup
    // size, so both tables are filled for any simdgroup count. A fixed slice
    // of 4 grid entries per thread only covers the grid exactly when NSG == 2.
    threadgroup uint64_t *svalues = (threadgroup uint64_t *)(shmem);
    threadgroup uint8_t *ssigns = (threadgroup uint8_t *)(svalues + 256);
    {
        const int nthreads = 32 * NSG;
        const int tid = 32 * sgitg + tiisg;
        for (int i = tid; i < 256; i += nthreads) svalues[i] = ds4_metal_iq2xxs_grid[i];
        for (int i = tid; i < 128; i += nthreads) ssigns[i] = ds4_metal_ksigns_iq2xs[i];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const int ix = tiisg;
    device const float *y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);
        const int ib = ib32 % (QK_K / 32);

        device const block_iq2_xxs *xgr = xg + ibl;
        device const block_iq2_xxs *xur = xu + ibl;
        device const uint16_t *qg = xgr->qs + 4 * ib;
        device const uint16_t *qu = xur->qs + 4 * ib;
        device const half *dhg = &xgr->d;
        device const half *dhu = &xur->d;

        for (short row = 0; row < N_R0_IQ2_XXS; row++) {
            device const uint8_t *aux8g = (device const uint8_t *)qg;
            device const uint8_t *aux8u = (device const uint8_t *)qu;
            const uint32_t aux32g = qg[2] | (qg[3] << 16);
            const uint32_t aux32u = qu[2] | (qu[3] << 16);
            const float dg = (float)dhg[0] * (0.5f + (aux32g >> 28));
            const float du = (float)dhu[0] * (0.5f + (aux32u >> 28));

            float sg = 0;
            float su = 0;
            for (short l = 0; l < 4; ++l) {
                const threadgroup uint8_t *gridg = (const threadgroup uint8_t *)(svalues + aux8g[l]);
                const threadgroup uint8_t *gridu = (const threadgroup uint8_t *)(svalues + aux8u[l]);
                const uint8_t signg = ssigns[(aux32g >> 7 * l) & 127];
                const uint8_t signu = ssigns[(aux32u >> 7 * l) & 127];
                for (short j = 0; j < 8; ++j) {
                    const float v = yl[8 * l + j];
                    sg += v * gridg[j] * (signg & ds4_metal_kmask_iq2xs[j] ? -1.f : 1.f);
                    su += v * gridu[j] * (signu & ds4_metal_kmask_iq2xs[j] ? -1.f : 1.f);
                }
            }
            sumg[row] += dg * sg;
            sumu[row] += du * su;

            // Step to the same sub-block of the next row (nb01 bytes ahead).
            dhg += args.nb01 / 2;
            dhu += args.nb01 / 2;
            qg += args.nb01 / 2;
            qu += args.nb01 / 2;
        }

        y4 += 32 * 32;
    }

    // Gate/up outputs go to row `idx` (the expert slot) of token i12. This is
    // the same layout as kernel_mul_mv_id_iq2_xxs_pair_f32 and the Q4_K pair
    // kernels. Indexing by the activation row i11 instead would make every
    // slot write row 0 whenever src1 is broadcast (ne11 == 1).
    device float *dst_gate_f32 =
        (device float *)dst_gate + (uint64_t)i12 * args.ne0 * args.ne1 + (uint64_t)idx * args.ne0;
    device float *dst_up_f32 =
        (device float *)dst_up + (uint64_t)i12 * args.ne0 * args.ne1 + (uint64_t)idx * args.ne0;
    device float *dst_mid_f32 =
        (device float *)(dst_mid + (uint64_t)idx * act.mid_row_stride);
    device const float *route_w =
        (device const float *)(weights + (uint64_t)idx * act.weight_stride);

    const float c = act.clamp_value;
    const float route_weight = route_w[0];
    for (int row = 0; row < N_R0_IQ2_XXS && first_row + row < args.ne0; ++row) {
        const float sum_gate = simd_sum(sumg[row]);
        const float sum_up = simd_sum(sumu[row]);
        if (tiisg == 0) {
            const uint out_row = first_row + row;
            // The constant 0.25 IQ2_XXS scale factor is applied after reduction.
            const float gate = sum_gate * 0.25f;
            const float up = sum_up * 0.25f;
            float g = gate;
            float u = up;
            if (c > 1.0e-6f) {
                g = min(g, c);          // gate: upper clamp only
                u = clamp(u, -c, c);    // up: symmetric clamp
            }
            dst_gate_f32[out_row] = gate;   // unclamped intermediates
            dst_up_f32[out_row] = up;
            const float silu = g / (1.0f + exp(-g));
            dst_mid_f32[out_row] = silu * u * route_weight;
        }
    }

    (void)tiitg;
}
|
|
1096
|
+
|
|
1097
|
+
// Routed gate/up projection pair for Q4_K experts. One dispatch evaluates the
// gate and up matvecs for the same (token, slot) by running the single-matrix
// Q4_K row kernel twice on the same activation row. Grid z enumerates
// z = token * nei0 + slot. Outputs use the same (slot, token) row layout as
// kernel_mul_mv_id.
kernel void kernel_mul_mv_id_q4_K_pair_f32(
        constant ds4_metal_args_mul_mv_id & args,
        device const char * src0_gate,
        device const char * src0_up,
        device const char * src1,
        device char * dst_gate,
        device char * dst_up,
        device const char * ids,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    (void) tiitg;

    const int slot  = tgpig.z % args.nei0;
    const int token = tgpig.z / args.nei0;
    tgpig.z = 0;   // z is folded into the pointer offsets below

    const int32_t expert   = ((device const int32_t *)(ids + token * args.nbi1))[slot];
    const int64_t src1_row = slot % args.ne11;
    const int64_t tok      = token;

    device const char * act       = src1 + src1_row * args.nb11 + tok * args.nb12;
    const auto          out_bytes = (slot * args.ne0 + tok * args.ne1 * args.ne0) * sizeof(float);

    ds4_metal_args_mul_mv row_args = {
        /*.ne00 =*/ args.ne00,
        /*.ne01 =*/ args.ne01,
        /*.ne02 =*/ 1,
        /*.nb00 =*/ args.nb00,
        /*.nb01 =*/ args.nb01,
        /*.nb02 =*/ args.nb02,
        /*.nb03 =*/ args.nb02,
        /*.ne10 =*/ args.ne10,
        /*.ne11 =*/ 1,
        /*.ne12 =*/ 1,
        /*.nb10 =*/ args.nb10,
        /*.nb11 =*/ args.nb11,
        /*.nb12 =*/ args.nb12,
        /*.nb13 =*/ args.nb12,
        /*.ne0  =*/ args.ne0,
        /*.ne1  =*/ 1,
        /*.nr0  =*/ args.nr0,
        /*.r2   =*/ 1,
        /*.r3   =*/ 1,
    };

    // Gate projection.
    kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K>(
        row_args, src0_gate + expert * args.nb02, act, dst_gate + out_bytes,
        shmem, tgpig, tiisg, sgitg);

    // Up projection.
    kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K>(
        row_args, src0_up + expert * args.nb02, act, dst_up + out_bytes,
        shmem, tgpig, tiisg, sgitg);
}
|
|
1154
|
+
|
|
1155
|
+
// Release-path gate/up + SwiGLU fusion for Q4_K experts, mirroring the
// IQ2_XXS SwiGLU kernel. The existing Q4_K row kernel produces the gate and up
// rows with exactly the arithmetic of the plain pair projection. Lane 0 then
// turns each row into the routed SwiGLU input
//
//   mid = silu(min(gate, c)) * clamp(up, -c, c) * route_weight
//
// Clamping is applied only when c = act.clamp_value > 1e-6. This keeps Q4
// behaviour aligned with the Q2 optimization. Grid z enumerates
// z = token * nei0 + slot.
kernel void kernel_mul_mv_id_q4_K_pair_swiglu_f32(
        constant ds4_metal_args_mul_mv_id & args,
        constant ds4_metal_dsv4_moe_swiglu_weight_args & act,
        device const char * src0_gate,
        device const char * src0_up,
        device const char * src1,
        device char * dst_gate,
        device char * dst_up,
        device char * dst_mid,
        device const char * ids,
        device const char * weights,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    (void) tiitg;

    const int slot  = tgpig.z % args.nei0;
    const int token = tgpig.z / args.nei0;
    tgpig.z = 0;   // z is folded into the pointer offsets below

    const int32_t expert   = ((device const int32_t *)(ids + token * args.nbi1))[slot];
    const int64_t src1_row = slot % args.ne11;
    const int64_t tok      = token;

    device const char * act_row   = src1 + src1_row * args.nb11 + tok * args.nb12;
    const auto          out_bytes = (slot * args.ne0 + tok * args.ne1 * args.ne0) * sizeof(float);
    device char       * gate_out  = dst_gate + out_bytes;
    device char       * up_out    = dst_up + out_bytes;

    ds4_metal_args_mul_mv row_args = {
        /*.ne00 =*/ args.ne00,
        /*.ne01 =*/ args.ne01,
        /*.ne02 =*/ 1,
        /*.nb00 =*/ args.nb00,
        /*.nb01 =*/ args.nb01,
        /*.nb02 =*/ args.nb02,
        /*.nb03 =*/ args.nb02,
        /*.ne10 =*/ args.ne10,
        /*.ne11 =*/ 1,
        /*.ne12 =*/ 1,
        /*.nb10 =*/ args.nb10,
        /*.nb11 =*/ args.nb11,
        /*.nb12 =*/ args.nb12,
        /*.nb13 =*/ args.nb12,
        /*.ne0  =*/ args.ne0,
        /*.ne1  =*/ 1,
        /*.nr0  =*/ args.nr0,
        /*.r2   =*/ 1,
        /*.r3   =*/ 1,
    };

    kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K>(
        row_args, src0_gate + expert * args.nb02, act_row, gate_out,
        shmem, tgpig, tiisg, sgitg);
    kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K>(
        row_args, src0_up + expert * args.nb02, act_row, up_out,
        shmem, tgpig, tiisg, sgitg);

    // SwiGLU epilogue, lane 0 only. NOTE(review): this reads back the rows the
    // Q4_K row kernel just stored. It relies on that kernel storing each
    // reduced row from lane 0 of this same simdgroup; confirm if that kernel
    // changes.
    if (tiisg != 0) {
        return;
    }

    const short NSG       = FC_mul_mv_nsg;
    const int   first_row = (tgpig.x * NSG + sgitg) * N_R0_Q4_K;

    const float c            = act.clamp_value;
    const float route_weight = ((device const float *)(weights + (uint64_t)slot * act.weight_stride))[0];

    device const float * gate_rows = (device const float *) gate_out;
    device const float * up_rows   = (device const float *) up_out;
    device float       * mid_rows  = (device float *)(dst_mid + (uint64_t)slot * act.mid_row_stride);

    for (int row = 0; row < N_R0_Q4_K && first_row + row < args.ne0; ++row) {
        const uint out_row = first_row + row;
        float g = gate_rows[out_row];
        float u = up_rows[out_row];
        if (c > 1.0e-6f) {
            g = min(g, c);
            u = clamp(u, -c, c);
        }
        const float silu = g / (1.0f + exp(-g));
        mid_rows[out_row] = silu * u * route_weight;
    }
}
|
|
1244
|
+
|
|
1245
|
+
// Routed down projection for Q2_K experts, with the six expert slots of a
// token summed in-kernel. For every output row r of token t (t = grid y):
//
//   dst[t, r] = sum over slot in [0, 6) of dot(W[ids[t, slot]][r, :], src1[t, slot, :])
//
// Grid x and the simdgroup index select a band of N_R0_Q2_K rows. All six
// expert dot products accumulate into one register per row before a single
// simd_sum, so no per-expert output is written. The slot count is fixed at 6.
kernel void kernel_mul_mv_id_q2_K_sum6_f32(
        constant ds4_metal_args_mul_mv_id & args,
        device const char * src0s,
        device const char * src1,
        device char * dst,
        device const char * ids,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    (void) shmem;
    (void) tiitg;

    const short NSG = FC_mul_mv_nsg;
    const short nr0 = N_R0_Q2_K;

    const int  nb        = args.ne00/QK_K;
    const int  first_row = (tgpig.x * NSG + sgitg) * nr0;
    const uint token     = tgpig.y;

    device const int32_t * slot_expert = (device const int32_t *)(ids + (uint64_t)token * args.nbi1);
    device const char    * token_act   = src1 + (uint64_t)token * args.nb12;

    // Lane layout: four groups of 8 lanes (ix) each take every 4th super-block.
    // Within a group, iq picks the 128-value half and ir the 8-value column
    // offset; is is the matching scale-byte offset.
    const short ix = tiisg/8;
    const short it = tiisg%8;
    const short iq = it/4;
    const short ir = it%4;
    const short is = (8*ir)/16;

    float row_acc[nr0] = {0.f};

    for (int slot = 0; slot < 6; slot++) {
        const int32_t expert = slot_expert[slot];

        device const block_q2_K * w  = (device const block_q2_K *)(src0s + expert*args.nb02 + first_row*args.nb01);
        device const float      * yp = (device const float *)(token_act + slot*args.nb11) + ix * QK_K + 128 * iq + 8 * ir;

        for (int ib = ix; ib < nb; ib += 4, yp += 4 * QK_K) {
            // Gather this lane's 32 activations (4 strided groups of 8) and
            // their per-group sums, used for the min correction.
            float  yv[32];
            float4 ysum = {0.f, 0.f, 0.f, 0.f};
            for (short i = 0; i < 8; ++i) {
                yv[i+ 0] = yp[i+ 0]; ysum[0] += yv[i+ 0];
                yv[i+ 8] = yp[i+32]; ysum[1] += yv[i+ 8];
                yv[i+16] = yp[i+64]; ysum[2] += yv[i+16];
                yv[i+24] = yp[i+96]; ysum[3] += yv[i+24];
            }

            device const uint8_t  * sc = (device const uint8_t  *)w[ib].scales + 8*iq + is;
            device const uint16_t * qs = (device const uint16_t *)w[ib].qs + 16 * iq + 4 * ir;
            device const half     * dh = &w[ib].d;

            // Walk the nr0 rows; the pointers advance by one row (nb01 bytes)
            // per iteration whether or not the row is in range.
            for (short row = 0; row < nr0; row++, qs += args.nb01/2, sc += args.nb01, dh += args.nb01/2) {
                if (first_row + row >= args.ne0) {
                    continue;
                }

                float4 acc_lo = {0.f, 0.f, 0.f, 0.f};
                float4 acc_hi = {0.f, 0.f, 0.f, 0.f};
                for (int k = 0; k < 4; ++k) {
                    const int i = 2*k;
                    acc_lo[0] += yv[i+ 0] * (qs[k] & 0x0003);
                    acc_hi[0] += yv[i+ 1] * (qs[k] & 0x0300);
                    acc_lo[1] += yv[i+ 8] * (qs[k] & 0x000c);
                    acc_hi[1] += yv[i+ 9] * (qs[k] & 0x0c00);
                    acc_lo[2] += yv[i+16] * (qs[k] & 0x0030);
                    acc_hi[2] += yv[i+17] * (qs[k] & 0x3000);
                    acc_lo[3] += yv[i+24] * (qs[k] & 0x00c0);
                    acc_hi[3] += yv[i+25] * (qs[k] & 0xc000);
                }
                float dall = dh[0];
                float dmin = dh[1] * 1.f/16.f;
                row_acc[row] += dall * ((acc_lo[0] + 1.f/256.f * acc_hi[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
                                        (acc_lo[1] + 1.f/256.f * acc_hi[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
                                        (acc_lo[2] + 1.f/256.f * acc_hi[2]) * (sc[4] & 0xF) * 1.f/16.f +
                                        (acc_lo[3] + 1.f/256.f * acc_hi[3]) * (sc[6] & 0xF) * 1.f/64.f) -
                                dmin * (ysum[0] * (sc[0] & 0xF0) + ysum[1] * (sc[2] & 0xF0) +
                                        ysum[2] * (sc[4] & 0xF0) + ysum[3] * (sc[6] & 0xF0));
            }
        }
    }

    device float * out = (device float *)(dst + (uint64_t)token * args.nb1);
    for (int row = 0; row < nr0 && first_row + row < args.ne0; row++) {
        const float total = simd_sum(row_acc[row]);
        if (tiisg == 0) {
            out[first_row + row] = total;
        }
    }
}
|
|
1335
|
+
|
|
1336
|
+
// Fused routed-expert mat-vec for Q4_K weights with the expert sum folded in.
//
// tgpig.y selects the token. For each of the token's 6 routed expert slots
// (ids row `ids` + token*nbi1), the kernel:
//   - multiplies that expert's Q4_K weight rows (`src0s` + expert*nb02) with
//     the slot's own f32 activation row (`src1` + token*nb12 + slot*nb11);
//   - accumulates all 6 partial results into the same per-row sums.
// The resulting f32 row is written to `dst` + token*nb1.
//
// Work split:
//   - Simdgroup `sgitg` of threadgroup tgpig.x owns N_R0_Q4_K consecutive
//     output rows starting at `first_row`.
//   - Inside a simdgroup, the 4 lane groups (ix = lane/8) stride over the
//     QK_K-value super-blocks of a row.
//   - The 8 lanes of a group each cover 32 of a block's activations.
//   - A simd_sum over the simdgroup then yields the complete dot product.
//
// `shmem` and `tiitg` are part of the dispatch signature but unused here.
kernel void kernel_mul_mv_id_q4_K_sum6_f32(
        constant ds4_metal_args_mul_mv_id & args,
        device const char * src0s,
        device const char * src1,
        device       char * dst,
        device const char * ids,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    constexpr int n_slots = 6; // routed experts summed per token ("sum6")

    const short nsg   = FC_mul_mv_nsg;
    const short nrows = N_R0_Q4_K;

    const int  n_blocks  = args.ne00 / QK_K;
    const int  first_row = (tgpig.x * nsg + sgitg) * nrows;
    const uint token     = tgpig.y;

    device const int32_t * token_ids  = (device const int32_t *) (ids + (uint64_t) token * args.nbi1);
    device const char    * token_src1 = src1 + (uint64_t) token * args.nb12;

    // Masks used to decode this lane's 6-bit scales (sc8[0,1,4,5]) and
    // mins (sc8[2,3,6,7]) from the packed `scales` bytes into sc16.
    constexpr uint16_t kmask1 = 0x3f3f;
    constexpr uint16_t kmask2 = 0x0f0f;
    constexpr uint16_t kmask3 = 0xc0c0;

    const short ix = tiisg / 8; // lane group: strides over super-blocks
    const short it = tiisg % 8; // lane inside the group
    const short iq = it / 4;    // selects a 64-value offset inside the block
    const short ir = it % 4;    // selects an 8-value offset inside that

    float sumf[nrows] = {0.f};

    uint16_t sc16[4];
    thread const uint8_t * sc8 = (thread const uint8_t *) sc16;

    for (int slot = 0; slot < n_slots; slot++) {
        const int32_t expert = token_ids[slot];

        device const block_q4_K * xe =
            (device const block_q4_K *) (src0s + expert * args.nb02 + first_row * args.nb01);
        device const float * ye = (device const float *) (token_src1 + slot * args.nb11);
        device const float * y4 = ye + ix * QK_K + 64 * iq + 8 * ir;

        for (int ib = ix; ib < n_blocks; ib += 4, y4 += 4 * QK_K) {
            // Gather this lane's 32 activations plus four 8-value partial sums.
            // The sums are later multiplied by the mins.
            float yl[16];
            float yh[16];
            float4 sumy = {0.f, 0.f, 0.f, 0.f};

            for (short i = 0; i < 8; ++i) {
                yl[i + 0] = y4[i +   0]; sumy[0] += yl[i + 0];
                yl[i + 8] = y4[i +  32]; sumy[1] += yl[i + 8];
                yh[i + 0] = y4[i + 128]; sumy[2] += yh[i + 0];
                yh[i + 8] = y4[i + 160]; sumy[3] += yh[i + 8];
            }

            device const uint16_t * sc = (device const uint16_t *) xe[ib].scales + iq;
            device const uint16_t * q1 = (device const uint16_t *) xe[ib].qs + 16 * iq + 4 * ir;
            device const half     * dh = &xe[ib].d; // dh[0] = d; dh[1] (the next half) scales the mins

            // Rows past ne0 are skipped. Rows are consecutive, so the first
            // out-of-range row ends the loop.
            for (short row = 0; row < nrows && first_row + row < args.ne0; row++) {
                sc16[0] = sc[0] & kmask1;
                sc16[1] = sc[2] & kmask1;
                sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
                sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);

                device const uint16_t * q2 = q1 + 32;

                float4 acc1 = {0.f, 0.f, 0.f, 0.f};
                float4 acc2 = {0.f, 0.f, 0.f, 0.f};

                // Nibbles are multiplied in place, without shifting them down.
                // The 1/16 and 1/256 factors below undo their bit positions.
                FOR_UNROLL (short i = 0; i < 4; ++i) {
                    acc1[0] += yl[2 * i + 0] * (q1[i] & 0x000F);
                    acc1[1] += yl[2 * i + 1] * (q1[i] & 0x0F00);
                    acc1[2] += yl[2 * i + 8] * (q1[i] & 0x00F0);
                    acc1[3] += yl[2 * i + 9] * (q1[i] & 0xF000);
                    acc2[0] += yh[2 * i + 0] * (q2[i] & 0x000F);
                    acc2[1] += yh[2 * i + 1] * (q2[i] & 0x0F00);
                    acc2[2] += yh[2 * i + 8] * (q2[i] & 0x00F0);
                    acc2[3] += yh[2 * i + 9] * (q2[i] & 0xF000);
                }

                sumf[row] += dh[0] * ((acc1[0] + 1.f / 256.f * acc1[1]) * sc8[0] +
                                      (acc1[2] + 1.f / 256.f * acc1[3]) * sc8[1] * 1.f / 16.f +
                                      (acc2[0] + 1.f / 256.f * acc2[1]) * sc8[4] +
                                      (acc2[2] + 1.f / 256.f * acc2[3]) * sc8[5] * 1.f / 16.f) -
                             dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] +
                                      sumy[2] * sc8[6] + sumy[3] * sc8[7]);

                // Advance to the same super-block of the next weight row
                // (nb01 bytes = nb01/2 uint16/half elements).
                q1 += args.nb01 / 2;
                sc += args.nb01 / 2;
                dh += args.nb01 / 2;
            }
        }
    }

    // Reduce each owned row across the simdgroup; lane 0 stores the result.
    device float * dst_f32 = (device float *) (dst + (uint64_t) token * args.nb1);
    for (int row = 0; row < nrows && first_row + row < args.ne0; row++) {
        const float total = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = total;
        }
    }

    (void) shmem;
    (void) tiitg;
}
|
|
1441
|
+
|
|
1442
|
+
// Number of 16-value (4x4) dequantization chunks per QK_K super-block.
// It is passed as the `nl` template argument of kernel_mul_mm_id for the
// super-block formats (presumably QK_K / 16 with QK_K == 256).
#define QK_NL 16
|
|
1443
|
+
|
|
1444
|
+
// Builds the compact per-expert work map used by the batched MoE matmul.
// DS4 routes each token to a small fixed top-k list of length ne20. This
// kernel turns those token-major lists into expert-major slices that the
// tiled matmul can consume.
//
// The thread index is used as the expert id, so the dispatch presumably uses
// one thread per expert. For every token whose list contains that expert,
// the thread appends the flat index `token*ne20 + slot` to its own row of
// `hids` (row stride ne21). At the end it stores the number of appended
// entries in `htpe[expert]`.
//
// Tokens are handled in chunks of `ntg`:
//   1. Each thread stages one token's list into `shmem` as uint16. This needs
//      at least ntg*ne20*2 bytes of threadgroup memory.
//   2. Every thread then scans the whole staged chunk for its own id.
template<short ne20>
kernel void kernel_mul_mm_id_map0(
        constant ds4_metal_args_mul_mm_id_map0 & args,
        device const char * src2,
        device       char * htpe,
        device       char * hids,
        threadgroup  char * shmem [[threadgroup(0)]],
        ushort tpitg[[thread_position_in_threadgroup]],
        ushort   ntg[[threads_per_threadgroup]]) {
    const short expert = tpitg;

    device int32_t * out_ids = (device int32_t *) hids + expert*args.ne21;

    uint32_t count = 0;

    for (int chunk = 0; chunk < args.ne21; chunk += ntg) {
        // Stage: this thread copies the routing list of token `chunk + tpitg`.
        const int my_token = chunk + tpitg;
        if (my_token < args.ne21) {
            device const int32_t * route  = (device const int32_t *) (src2 + my_token*args.nb21);
            threadgroup uint16_t * staged = (threadgroup uint16_t *) shmem + tpitg*ne20;

            #pragma unroll(ne20)
            for (short k = 0; k < ne20; k++) {
                staged[k] = route[k];
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Scan every staged token of this chunk for `expert`.
        for (short t = 0; t < ntg && chunk + t < args.ne21; t++) {
            threadgroup const uint16_t * staged = (threadgroup const uint16_t *) shmem + t*ne20;

            // 1-based slot of `expert` in this token's list; 0 when absent.
            short slot1 = 0;
            #pragma unroll(ne20)
            for (short k = 0; k < ne20; k++) {
                slot1 += (staged[k] == expert)*(k + 1);
            }

            // Written unconditionally (branchless); `count` only advances on a
            // hit. A miss therefore lands past the final count or is
            // overwritten by the next hit.
            out_ids[count] = (chunk + t)*ne20 + slot1 - 1;

            count += slot1 > 0;
        }

        // Keep the next chunk's staging from overwriting tokens still being scanned.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    device uint32_t * counts = (device uint32_t *) (htpe);
    counts[expert] = count;
}
|
|
1500
|
+
|
|
1501
|
+
// The map0 parameter list does not depend on ne20, so one function type
// covers every instantiation below.
typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;

// Host-visible map builders, one per supported ne20 (the length of each
// token's routed-expert list). Not every arity is necessarily used by the
// DS4 graphs; the extra ones are kept as generic variants.
template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;
template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
|
|
1514
|
+
|
|
1515
|
+
// Batched routed-expert matmul. It reads the expert-major work map built by
// kernel_mul_mm_id_map0, loads the selected expert's weights, and writes
// results back to token-major slots so the DS4 FFN can apply SwiGLU,
// weighting, and the down projection.
//
// Grid:
//   - tgpig.z = expert `im`.
//   - tgpig.y = 64-row tile of weight/output rows (r0 .. r0+63).
//   - tgpig.x = 32-entry tile of expert `im`'s work list (`hids` row `im`,
//     with `htpe[im]` valid entries).
// Each work entry is a flat `token*ne20 + slot` index:
//   - Its activation row is src1[token][slot % ne11].
//   - Its result row is dst[token][slot] (token stride ne1*ne0, slot stride ne0).
// The indexing (lr0 = tiitg/2 spanning 64 rows, `j += 4` over simdgroups)
// assumes 128 threads = 4 simdgroups per threadgroup.
//
// Threadgroup memory:
//   - `sa` (weight tile, S0) starts at byte 0.
//   - `sb` (activation tile, S1) starts at byte 4096.
//   - After the K loop, the start of `shmem` is reused as an NR1 x NR0 float
//     staging tile, so partial tiles are stored without writing outside dst.
//
// Token indices decoded from the work map are kept in `int`, not `short`:
// a short wraps past 32767 tokens and would address the wrong src1/dst
// rows. The dst element offset is also formed in 64-bit.
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(
        constant ds4_metal_args_mul_mm_id & args,
        device const char * src0,
        device const char * src1,
        device const char * htpe,
        device const char * hids,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    threadgroup S0 * sa = (threadgroup S0 *)(shmem);
    threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);

    constexpr int NR0 = 64; // output rows per tile
    constexpr int NR1 = 32; // work-list entries per tile

    constexpr int NK  = 32; // K elements consumed per loop iteration
    constexpr int NL0 = NK/16;
    constexpr int NL1 = NK/8;

    const int im = tgpig.z;
    const int r0 = tgpig.y*NR0;
    const int r1 = tgpig.x*NR1;

    device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
    device const int32_t  * ids_i32 = (device const int32_t  *) (hids);

    const int32_t neh1 = tpe_u32[im];

    // This expert has fewer work entries than this tile's start.
    if (r1 >= neh1) {
        return;
    }

    // Actual extent of this tile; it may be smaller than 64x32 at the edges.
    const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
    const short nr1 = (    neh1 - r1 < NR1) ? (    neh1 - r1) : NR1;

    // Clamp load rows so threads past the edge re-read the last valid row
    // instead of reading outside the matrix.
    const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1;
    const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1;

    const short il0 = (tiitg % NL0);

    short il = il0;

    const int id = ids_i32[im*args.ne21 + r1 + lr1];

    const short i11 = (id % args.ne20) % args.ne11; // activation slot
    const int   i12 = (id / args.ne20);             // token (int: may exceed 32767)
    const short i13 = 0;

    const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
    const short    offset1 = il0/nl;

    device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;

    const short iy = 8*(tiitg % NL1);

    device const T1 * y = (device const T1 *)(src1
        + args.nb13*i13
        + args.nb12*i12
        + args.nb11*i11
        + args.nb10*iy);

    S0_8x8 ma[4];
    S1_8x8 mb[2];

    simdgroup_float8x8 mc[8];

    for (short i = 0; i < 8; i++){
        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
    }

    for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
        // Stage the weight tile into `sa`.
        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
            threadgroup_barrier(mem_flags::mem_threadgroup);

            // Unquantized weights with bounds checking: copy directly, zero past ne00.
            for (short i = 0; i < 16; i++) {
                const short sx = 2*il0 + i/8;
                const short sy = (tiitg/NL0)/8;

                const short lx = (tiitg/NL0)%8;
                const short ly = i%8;

                const short ib = 8*sx + sy;

                *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
            }
        } else {
            S0_4x4 temp_a;
            dequantize_func(x, il, temp_a);

            threadgroup_barrier(mem_flags::mem_threadgroup);

            FOR_UNROLL (short i = 0; i < 16; i++) {
                const short sx = 2*il0 + i/8;
                const short sy = (tiitg/NL0)/8;

                const short lx = (tiitg/NL0)%8;
                const short ly = i%8;

                const short ib = 8*sx + sy;

                *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
            }
        }

        // Stage the activation tile into `sb`.
        if (FC_mul_mm_bc_inp) {
            for (short i = 0; i < 8; ++i) {
                const short sx = (tiitg%NL1);
                const short sy = (tiitg/NL1)/8;

                const short lx = i;
                const short ly = (tiitg/NL1)%8;

                const short ib = 4*sx + sy;

                *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
            }
        } else {
            const short sx = (tiitg%NL1);
            const short sy = (tiitg/NL1)/8;

            const short ly = (tiitg/NL1)%8;

            const short ib = 4*sx + sy;

            *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
        }

        // Advance to the next 16-value chunk (il), moving to the next block
        // once all nl chunks of the current one are consumed.
        il = (il + 2 < nl) ? il + 2 : il % 2;
        x  = (il < 2) ? x + (2 + nl - 1)/nl : x;

        y += NK;

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply the staged tiles; each simdgroup owns a 32x16 quadrant.
        threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
        threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));

        FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
            simdgroup_barrier(mem_flags::mem_none);

            FOR_UNROLL (short i = 0; i < 4; i++) {
                simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
            }

            simdgroup_barrier(mem_flags::mem_none);

            FOR_UNROLL (short i = 0; i < 2; i++) {
                simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
            }

            simdgroup_barrier(mem_flags::mem_none);

            FOR_UNROLL (short i = 0; i < 8; i++){
                simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
            }

            lsma += 8*64;
            lsmb += 4*64;
        }
    }

    // Spill accumulators to threadgroup memory so partial tiles can be
    // stored without writing outside dst.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;

    for (short i = 0; i < 8; i++) {
        simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Scatter each work entry's nr0 results to its token-major dst row.
    for (short j = sgitg; j < nr1; j += 4) {
        const int idj = ids_i32[im*args.ne21 + r1 + j];

        const short ide = idj % args.ne20; // slot within the token's list
        const int   idt = idj / args.ne20; // token (int: may exceed 32767)

        const uint64_t dst_off = (uint64_t) r0
                               + (uint64_t) ide*args.ne0
                               + (uint64_t) idt*args.ne1*args.ne0;

        device float  * D  = (device float  *) dst + dst_off;
        device float4 * D4 = (device float4 *) D;

        threadgroup float  * C  = (threadgroup float  *) shmem + j*NR0;
        threadgroup float4 * C4 = (threadgroup float4 *) C;

        int i = tiisg;
        for (; i < nr0/4; i += 32) {
            *(D4 + i) = *(C4 + i);
        }

        i = (4*(nr0/4)) + tiisg;
        for (; i < nr0; i += 32) {
            *(D + i) = *(C + i);
        }
    }
}
|
|
1716
|
+
|
|
1717
|
+
// Function types for the batched MoE matmul instantiations. The kernel's
// parameter list does not depend on its template arguments, so both names
// denote the same type. They are kept separate for the f32- and
// f16-activation families.
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>) mul_mm_id;
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>) mul_mm_id_f16_rhs;

// Host-visible batched MoE matmul variants, grouped per DS4 weight format
// (f32 activations first, then f16 activations).

// q8_0 (nl = 2)
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, half, half4x4, half, half2x4>;

// q2_K
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>;

// q4_K
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half4x4, half, half2x4>;

// iq2_xxs
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half4x4, half, half2x4>;
|
|
1729
|
+
|
|
1730
|
+
// Drop this file's helper macros so they do not leak past the end of it
// (relevant if shader sources are compiled together; confirm with the loader).
#undef QK_NL

#undef N_R0_Q2_K
#undef N_R0_Q4_K
#undef N_R0_IQ2_XXS

#undef kmask_iq2xs
#undef ksigns_iq2xs
#undef iq2xxs_grid

#undef QK_K
|