node-llama-cpp 2.8.2 → 2.8.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/llama/binariesGithubRelease.json +1 -1
- package/llama/gitRelease.bundle +0 -0
- package/llamaBins/linux-arm64/llama-addon.node +0 -0
- package/llamaBins/linux-armv7l/llama-addon.node +0 -0
- package/llamaBins/linux-x64/llama-addon.node +0 -0
- package/llamaBins/mac-arm64/ggml-metal.metal +2674 -434
- package/llamaBins/mac-arm64/llama-addon.node +0 -0
- package/llamaBins/mac-x64/ggml-metal.metal +2674 -434
- package/llamaBins/mac-x64/llama-addon.node +0 -0
- package/llamaBins/win-x64/llama-addon.node +0 -0
- package/package.json +9 -1
|
@@ -59,26 +59,27 @@ kernel void kernel_add(
|
|
|
59
59
|
constant int64_t & ne01,
|
|
60
60
|
constant int64_t & ne02,
|
|
61
61
|
constant int64_t & ne03,
|
|
62
|
-
constant
|
|
63
|
-
constant
|
|
64
|
-
constant
|
|
65
|
-
constant
|
|
62
|
+
constant uint64_t & nb00,
|
|
63
|
+
constant uint64_t & nb01,
|
|
64
|
+
constant uint64_t & nb02,
|
|
65
|
+
constant uint64_t & nb03,
|
|
66
66
|
constant int64_t & ne10,
|
|
67
67
|
constant int64_t & ne11,
|
|
68
68
|
constant int64_t & ne12,
|
|
69
69
|
constant int64_t & ne13,
|
|
70
|
-
constant
|
|
71
|
-
constant
|
|
72
|
-
constant
|
|
73
|
-
constant
|
|
70
|
+
constant uint64_t & nb10,
|
|
71
|
+
constant uint64_t & nb11,
|
|
72
|
+
constant uint64_t & nb12,
|
|
73
|
+
constant uint64_t & nb13,
|
|
74
74
|
constant int64_t & ne0,
|
|
75
75
|
constant int64_t & ne1,
|
|
76
76
|
constant int64_t & ne2,
|
|
77
77
|
constant int64_t & ne3,
|
|
78
|
-
constant
|
|
79
|
-
constant
|
|
80
|
-
constant
|
|
81
|
-
constant
|
|
78
|
+
constant uint64_t & nb0,
|
|
79
|
+
constant uint64_t & nb1,
|
|
80
|
+
constant uint64_t & nb2,
|
|
81
|
+
constant uint64_t & nb3,
|
|
82
|
+
constant int64_t & offs,
|
|
82
83
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
83
84
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
84
85
|
uint3 ntg[[threads_per_threadgroup]]) {
|
|
@@ -90,9 +91,9 @@ kernel void kernel_add(
|
|
|
90
91
|
const int64_t i12 = i02 % ne12;
|
|
91
92
|
const int64_t i11 = i01 % ne11;
|
|
92
93
|
|
|
93
|
-
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
|
94
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
|
|
94
95
|
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
|
95
|
-
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
|
96
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
|
|
96
97
|
|
|
97
98
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
|
98
99
|
const int i10 = i0 % ne10;
|
|
@@ -108,26 +109,26 @@ kernel void kernel_mul(
|
|
|
108
109
|
constant int64_t & ne01,
|
|
109
110
|
constant int64_t & ne02,
|
|
110
111
|
constant int64_t & ne03,
|
|
111
|
-
constant
|
|
112
|
-
constant
|
|
113
|
-
constant
|
|
114
|
-
constant
|
|
112
|
+
constant uint64_t & nb00,
|
|
113
|
+
constant uint64_t & nb01,
|
|
114
|
+
constant uint64_t & nb02,
|
|
115
|
+
constant uint64_t & nb03,
|
|
115
116
|
constant int64_t & ne10,
|
|
116
117
|
constant int64_t & ne11,
|
|
117
118
|
constant int64_t & ne12,
|
|
118
119
|
constant int64_t & ne13,
|
|
119
|
-
constant
|
|
120
|
-
constant
|
|
121
|
-
constant
|
|
122
|
-
constant
|
|
120
|
+
constant uint64_t & nb10,
|
|
121
|
+
constant uint64_t & nb11,
|
|
122
|
+
constant uint64_t & nb12,
|
|
123
|
+
constant uint64_t & nb13,
|
|
123
124
|
constant int64_t & ne0,
|
|
124
125
|
constant int64_t & ne1,
|
|
125
126
|
constant int64_t & ne2,
|
|
126
127
|
constant int64_t & ne3,
|
|
127
|
-
constant
|
|
128
|
-
constant
|
|
129
|
-
constant
|
|
130
|
-
constant
|
|
128
|
+
constant uint64_t & nb0,
|
|
129
|
+
constant uint64_t & nb1,
|
|
130
|
+
constant uint64_t & nb2,
|
|
131
|
+
constant uint64_t & nb3,
|
|
131
132
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
132
133
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
133
134
|
uint3 ntg[[threads_per_threadgroup]]) {
|
|
@@ -157,26 +158,26 @@ kernel void kernel_div(
|
|
|
157
158
|
constant int64_t & ne01,
|
|
158
159
|
constant int64_t & ne02,
|
|
159
160
|
constant int64_t & ne03,
|
|
160
|
-
constant
|
|
161
|
-
constant
|
|
162
|
-
constant
|
|
163
|
-
constant
|
|
161
|
+
constant uint64_t & nb00,
|
|
162
|
+
constant uint64_t & nb01,
|
|
163
|
+
constant uint64_t & nb02,
|
|
164
|
+
constant uint64_t & nb03,
|
|
164
165
|
constant int64_t & ne10,
|
|
165
166
|
constant int64_t & ne11,
|
|
166
167
|
constant int64_t & ne12,
|
|
167
168
|
constant int64_t & ne13,
|
|
168
|
-
constant
|
|
169
|
-
constant
|
|
170
|
-
constant
|
|
171
|
-
constant
|
|
169
|
+
constant uint64_t & nb10,
|
|
170
|
+
constant uint64_t & nb11,
|
|
171
|
+
constant uint64_t & nb12,
|
|
172
|
+
constant uint64_t & nb13,
|
|
172
173
|
constant int64_t & ne0,
|
|
173
174
|
constant int64_t & ne1,
|
|
174
175
|
constant int64_t & ne2,
|
|
175
176
|
constant int64_t & ne3,
|
|
176
|
-
constant
|
|
177
|
-
constant
|
|
178
|
-
constant
|
|
179
|
-
constant
|
|
177
|
+
constant uint64_t & nb0,
|
|
178
|
+
constant uint64_t & nb1,
|
|
179
|
+
constant uint64_t & nb2,
|
|
180
|
+
constant uint64_t & nb3,
|
|
180
181
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
181
182
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
182
183
|
uint3 ntg[[threads_per_threadgroup]]) {
|
|
@@ -204,7 +205,7 @@ kernel void kernel_add_row(
|
|
|
204
205
|
device const float4 * src0,
|
|
205
206
|
device const float4 * src1,
|
|
206
207
|
device float4 * dst,
|
|
207
|
-
constant
|
|
208
|
+
constant uint64_t & nb [[buffer(28)]],
|
|
208
209
|
uint tpig[[thread_position_in_grid]]) {
|
|
209
210
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
|
210
211
|
}
|
|
@@ -213,7 +214,7 @@ kernel void kernel_mul_row(
|
|
|
213
214
|
device const float4 * src0,
|
|
214
215
|
device const float4 * src1,
|
|
215
216
|
device float4 * dst,
|
|
216
|
-
constant
|
|
217
|
+
constant uint64_t & nb [[buffer(28)]],
|
|
217
218
|
uint tpig[[thread_position_in_grid]]) {
|
|
218
219
|
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
|
219
220
|
}
|
|
@@ -222,7 +223,7 @@ kernel void kernel_div_row(
|
|
|
222
223
|
device const float4 * src0,
|
|
223
224
|
device const float4 * src1,
|
|
224
225
|
device float4 * dst,
|
|
225
|
-
constant
|
|
226
|
+
constant uint64_t & nb [[buffer(28)]],
|
|
226
227
|
uint tpig[[thread_position_in_grid]]) {
|
|
227
228
|
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
|
228
229
|
}
|
|
@@ -243,19 +244,53 @@ kernel void kernel_scale_4(
|
|
|
243
244
|
dst[tpig] = src0[tpig] * scale;
|
|
244
245
|
}
|
|
245
246
|
|
|
246
|
-
kernel void
|
|
247
|
-
device const
|
|
248
|
-
device
|
|
247
|
+
kernel void kernel_relu(
|
|
248
|
+
device const float * src0,
|
|
249
|
+
device float * dst,
|
|
249
250
|
uint tpig[[thread_position_in_grid]]) {
|
|
250
|
-
|
|
251
|
-
dst[tpig] = x / (1.0f + exp(-x));
|
|
251
|
+
dst[tpig] = max(0.0f, src0[tpig]);
|
|
252
252
|
}
|
|
253
253
|
|
|
254
|
-
kernel void
|
|
254
|
+
kernel void kernel_tanh(
|
|
255
255
|
device const float * src0,
|
|
256
256
|
device float * dst,
|
|
257
257
|
uint tpig[[thread_position_in_grid]]) {
|
|
258
|
-
|
|
258
|
+
device const float & x = src0[tpig];
|
|
259
|
+
dst[tpig] = precise::tanh(x);
|
|
260
|
+
}
|
|
261
|
+
|
|
262
|
+
constant float GELU_COEF_A = 0.044715f;
|
|
263
|
+
constant float GELU_QUICK_COEF = -1.702f;
|
|
264
|
+
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
|
265
|
+
|
|
266
|
+
kernel void kernel_gelu(
|
|
267
|
+
device const float4 * src0,
|
|
268
|
+
device float4 * dst,
|
|
269
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
270
|
+
device const float4 & x = src0[tpig];
|
|
271
|
+
|
|
272
|
+
// BEWARE !!!
|
|
273
|
+
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
|
274
|
+
// This was observed with Falcon 7B and 40B models
|
|
275
|
+
//
|
|
276
|
+
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
|
277
|
+
}
|
|
278
|
+
|
|
279
|
+
kernel void kernel_gelu_quick(
|
|
280
|
+
device const float4 * src0,
|
|
281
|
+
device float4 * dst,
|
|
282
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
283
|
+
device const float4 & x = src0[tpig];
|
|
284
|
+
|
|
285
|
+
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
|
286
|
+
}
|
|
287
|
+
|
|
288
|
+
kernel void kernel_silu(
|
|
289
|
+
device const float4 * src0,
|
|
290
|
+
device float4 * dst,
|
|
291
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
292
|
+
device const float4 & x = src0[tpig];
|
|
293
|
+
dst[tpig] = x / (1.0f + exp(-x));
|
|
259
294
|
}
|
|
260
295
|
|
|
261
296
|
kernel void kernel_sqr(
|
|
@@ -272,26 +307,26 @@ kernel void kernel_sum_rows(
|
|
|
272
307
|
constant int64_t & ne01,
|
|
273
308
|
constant int64_t & ne02,
|
|
274
309
|
constant int64_t & ne03,
|
|
275
|
-
constant
|
|
276
|
-
constant
|
|
277
|
-
constant
|
|
278
|
-
constant
|
|
310
|
+
constant uint64_t & nb00,
|
|
311
|
+
constant uint64_t & nb01,
|
|
312
|
+
constant uint64_t & nb02,
|
|
313
|
+
constant uint64_t & nb03,
|
|
279
314
|
constant int64_t & ne10,
|
|
280
315
|
constant int64_t & ne11,
|
|
281
316
|
constant int64_t & ne12,
|
|
282
317
|
constant int64_t & ne13,
|
|
283
|
-
constant
|
|
284
|
-
constant
|
|
285
|
-
constant
|
|
286
|
-
constant
|
|
318
|
+
constant uint64_t & nb10,
|
|
319
|
+
constant uint64_t & nb11,
|
|
320
|
+
constant uint64_t & nb12,
|
|
321
|
+
constant uint64_t & nb13,
|
|
287
322
|
constant int64_t & ne0,
|
|
288
323
|
constant int64_t & ne1,
|
|
289
324
|
constant int64_t & ne2,
|
|
290
325
|
constant int64_t & ne3,
|
|
291
|
-
constant
|
|
292
|
-
constant
|
|
293
|
-
constant
|
|
294
|
-
constant
|
|
326
|
+
constant uint64_t & nb0,
|
|
327
|
+
constant uint64_t & nb1,
|
|
328
|
+
constant uint64_t & nb2,
|
|
329
|
+
constant uint64_t & nb3,
|
|
295
330
|
uint3 tpig[[thread_position_in_grid]]) {
|
|
296
331
|
int64_t i3 = tpig.z;
|
|
297
332
|
int64_t i2 = tpig.y;
|
|
@@ -313,22 +348,6 @@ kernel void kernel_sum_rows(
|
|
|
313
348
|
dst_row[0] = row_sum;
|
|
314
349
|
}
|
|
315
350
|
|
|
316
|
-
constant float GELU_COEF_A = 0.044715f;
|
|
317
|
-
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
|
318
|
-
|
|
319
|
-
kernel void kernel_gelu(
|
|
320
|
-
device const float4 * src0,
|
|
321
|
-
device float4 * dst,
|
|
322
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
323
|
-
device const float4 & x = src0[tpig];
|
|
324
|
-
|
|
325
|
-
// BEWARE !!!
|
|
326
|
-
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
|
327
|
-
// This was observed with Falcon 7B and 40B models
|
|
328
|
-
//
|
|
329
|
-
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
|
330
|
-
}
|
|
331
|
-
|
|
332
351
|
kernel void kernel_soft_max(
|
|
333
352
|
device const float * src0,
|
|
334
353
|
device const float * src1,
|
|
@@ -347,9 +366,9 @@ kernel void kernel_soft_max(
|
|
|
347
366
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
|
348
367
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
|
349
368
|
|
|
350
|
-
device const float * psrc0 =
|
|
351
|
-
device const float * pmask = src1 ? src1
|
|
352
|
-
device float * pdst =
|
|
369
|
+
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
370
|
+
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
|
371
|
+
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
353
372
|
|
|
354
373
|
// parallel max
|
|
355
374
|
float lmax = -INFINITY;
|
|
@@ -385,7 +404,12 @@ kernel void kernel_soft_max(
|
|
|
385
404
|
pdst[i00] = exp_psrc0;
|
|
386
405
|
}
|
|
387
406
|
|
|
407
|
+
// This barrier fixes a failing test
|
|
408
|
+
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
|
|
409
|
+
threadgroup_barrier(mem_flags::mem_none);
|
|
410
|
+
|
|
388
411
|
float sum = simd_sum(lsum);
|
|
412
|
+
|
|
389
413
|
if (ntg > N_SIMDWIDTH) {
|
|
390
414
|
if (sgitg == 0) {
|
|
391
415
|
buf[tiisg] = 0.0f;
|
|
@@ -428,9 +452,9 @@ kernel void kernel_soft_max_4(
|
|
|
428
452
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
|
429
453
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
|
430
454
|
|
|
431
|
-
device const float4 * psrc4 =
|
|
432
|
-
device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
|
433
|
-
device float4 * pdst4 =
|
|
455
|
+
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
456
|
+
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
|
457
|
+
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
434
458
|
|
|
435
459
|
// parallel max
|
|
436
460
|
float4 lmax4 = -INFINITY;
|
|
@@ -468,7 +492,13 @@ kernel void kernel_soft_max_4(
|
|
|
468
492
|
}
|
|
469
493
|
|
|
470
494
|
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
|
495
|
+
|
|
496
|
+
// This barrier fixes a failing test
|
|
497
|
+
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
|
|
498
|
+
threadgroup_barrier(mem_flags::mem_none);
|
|
499
|
+
|
|
471
500
|
float sum = simd_sum(lsum);
|
|
501
|
+
|
|
472
502
|
if (ntg > N_SIMDWIDTH) {
|
|
473
503
|
if (sgitg == 0) {
|
|
474
504
|
buf[tiisg] = 0.0f;
|
|
@@ -639,6 +669,94 @@ kernel void kernel_rms_norm(
|
|
|
639
669
|
}
|
|
640
670
|
}
|
|
641
671
|
|
|
672
|
+
kernel void kernel_group_norm(
|
|
673
|
+
device const float * src0,
|
|
674
|
+
device float * dst,
|
|
675
|
+
constant int64_t & ne00,
|
|
676
|
+
constant int64_t & ne01,
|
|
677
|
+
constant int64_t & ne02,
|
|
678
|
+
constant uint64_t & nb00,
|
|
679
|
+
constant uint64_t & nb01,
|
|
680
|
+
constant uint64_t & nb02,
|
|
681
|
+
constant int32_t & n_groups,
|
|
682
|
+
constant float & eps,
|
|
683
|
+
threadgroup float * buf [[threadgroup(0)]],
|
|
684
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
685
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
686
|
+
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
687
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
688
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
689
|
+
const int64_t ne = ne00*ne01*ne02;
|
|
690
|
+
const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
|
|
691
|
+
|
|
692
|
+
int start = tgpig * gs;
|
|
693
|
+
int end = start + gs;
|
|
694
|
+
|
|
695
|
+
start += tpitg;
|
|
696
|
+
|
|
697
|
+
if (end >= ne) {
|
|
698
|
+
end = ne;
|
|
699
|
+
}
|
|
700
|
+
|
|
701
|
+
float tmp = 0.0f; // partial sum for thread in warp
|
|
702
|
+
|
|
703
|
+
for (int j = start; j < end; j += ntg) {
|
|
704
|
+
tmp += src0[j];
|
|
705
|
+
}
|
|
706
|
+
|
|
707
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
708
|
+
tmp = simd_sum(tmp);
|
|
709
|
+
if (ntg > N_SIMDWIDTH) {
|
|
710
|
+
if (sgitg == 0) {
|
|
711
|
+
buf[tiisg] = 0.0f;
|
|
712
|
+
}
|
|
713
|
+
|
|
714
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
715
|
+
|
|
716
|
+
if (tiisg == 0) {
|
|
717
|
+
buf[sgitg] = tmp;
|
|
718
|
+
}
|
|
719
|
+
|
|
720
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
721
|
+
|
|
722
|
+
tmp = buf[tiisg];
|
|
723
|
+
tmp = simd_sum(tmp);
|
|
724
|
+
}
|
|
725
|
+
|
|
726
|
+
const float mean = tmp / gs;
|
|
727
|
+
tmp = 0.0f;
|
|
728
|
+
|
|
729
|
+
for (int j = start; j < end; j += ntg) {
|
|
730
|
+
float xi = src0[j] - mean;
|
|
731
|
+
dst[j] = xi;
|
|
732
|
+
tmp += xi * xi;
|
|
733
|
+
}
|
|
734
|
+
|
|
735
|
+
tmp = simd_sum(tmp);
|
|
736
|
+
if (ntg > N_SIMDWIDTH) {
|
|
737
|
+
if (sgitg == 0) {
|
|
738
|
+
buf[tiisg] = 0.0f;
|
|
739
|
+
}
|
|
740
|
+
|
|
741
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
742
|
+
|
|
743
|
+
if (tiisg == 0) {
|
|
744
|
+
buf[sgitg] = tmp;
|
|
745
|
+
}
|
|
746
|
+
|
|
747
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
748
|
+
|
|
749
|
+
tmp = buf[tiisg];
|
|
750
|
+
tmp = simd_sum(tmp);
|
|
751
|
+
}
|
|
752
|
+
|
|
753
|
+
const float variance = tmp / gs;
|
|
754
|
+
const float scale = 1.0f/sqrt(variance + eps);
|
|
755
|
+
for (int j = start; j < end; j += ntg) {
|
|
756
|
+
dst[j] *= scale;
|
|
757
|
+
}
|
|
758
|
+
}
|
|
759
|
+
|
|
642
760
|
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
|
643
761
|
// il indicates where the q4 quants begin (0 or QK4_0/4)
|
|
644
762
|
// we assume that the yl's have been multiplied with the appropriate scale factor
|
|
@@ -728,10 +846,10 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|
|
728
846
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
|
729
847
|
//Note: This is a template, but strictly speaking it only applies to
|
|
730
848
|
// quantizations where the block size is 32. It also does not
|
|
731
|
-
//
|
|
849
|
+
// guard against the number of rows not being divisible by
|
|
732
850
|
// N_DST, so this is another explicit assumption of the implementation.
|
|
733
851
|
template<typename block_q_type, int nr, int nsg, int nw>
|
|
734
|
-
void
|
|
852
|
+
void mul_vec_q_n_f32_impl(
|
|
735
853
|
device const void * src0,
|
|
736
854
|
device const float * src1,
|
|
737
855
|
device float * dst,
|
|
@@ -802,18 +920,25 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
|
802
920
|
device const float * src1,
|
|
803
921
|
device float * dst,
|
|
804
922
|
constant int64_t & ne00,
|
|
805
|
-
constant int64_t & ne01
|
|
806
|
-
constant int64_t & ne02
|
|
807
|
-
constant
|
|
808
|
-
constant
|
|
809
|
-
constant
|
|
810
|
-
constant int64_t &
|
|
811
|
-
constant
|
|
812
|
-
constant
|
|
923
|
+
constant int64_t & ne01,
|
|
924
|
+
constant int64_t & ne02,
|
|
925
|
+
constant uint64_t & nb00,
|
|
926
|
+
constant uint64_t & nb01,
|
|
927
|
+
constant uint64_t & nb02,
|
|
928
|
+
constant int64_t & ne10,
|
|
929
|
+
constant int64_t & ne11,
|
|
930
|
+
constant int64_t & ne12,
|
|
931
|
+
constant uint64_t & nb10,
|
|
932
|
+
constant uint64_t & nb11,
|
|
933
|
+
constant uint64_t & nb12,
|
|
934
|
+
constant int64_t & ne0,
|
|
935
|
+
constant int64_t & ne1,
|
|
936
|
+
constant uint & r2,
|
|
937
|
+
constant uint & r3,
|
|
813
938
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
814
939
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
815
940
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
816
|
-
|
|
941
|
+
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
817
942
|
}
|
|
818
943
|
|
|
819
944
|
kernel void kernel_mul_mv_q4_1_f32(
|
|
@@ -821,18 +946,25 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
|
821
946
|
device const float * src1,
|
|
822
947
|
device float * dst,
|
|
823
948
|
constant int64_t & ne00,
|
|
824
|
-
constant int64_t & ne01
|
|
825
|
-
constant int64_t & ne02
|
|
826
|
-
constant
|
|
827
|
-
constant
|
|
828
|
-
constant
|
|
829
|
-
constant int64_t &
|
|
830
|
-
constant
|
|
831
|
-
constant
|
|
949
|
+
constant int64_t & ne01,
|
|
950
|
+
constant int64_t & ne02,
|
|
951
|
+
constant uint64_t & nb00,
|
|
952
|
+
constant uint64_t & nb01,
|
|
953
|
+
constant uint64_t & nb02,
|
|
954
|
+
constant int64_t & ne10,
|
|
955
|
+
constant int64_t & ne11,
|
|
956
|
+
constant int64_t & ne12,
|
|
957
|
+
constant uint64_t & nb10,
|
|
958
|
+
constant uint64_t & nb11,
|
|
959
|
+
constant uint64_t & nb12,
|
|
960
|
+
constant int64_t & ne0,
|
|
961
|
+
constant int64_t & ne1,
|
|
962
|
+
constant uint & r2,
|
|
963
|
+
constant uint & r3,
|
|
832
964
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
833
965
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
834
966
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
835
|
-
|
|
967
|
+
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
836
968
|
}
|
|
837
969
|
|
|
838
970
|
kernel void kernel_mul_mv_q5_0_f32(
|
|
@@ -840,18 +972,25 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
|
840
972
|
device const float * src1,
|
|
841
973
|
device float * dst,
|
|
842
974
|
constant int64_t & ne00,
|
|
843
|
-
constant int64_t & ne01
|
|
844
|
-
constant int64_t & ne02
|
|
845
|
-
constant
|
|
846
|
-
constant
|
|
847
|
-
constant
|
|
848
|
-
constant int64_t &
|
|
849
|
-
constant
|
|
850
|
-
constant
|
|
975
|
+
constant int64_t & ne01,
|
|
976
|
+
constant int64_t & ne02,
|
|
977
|
+
constant uint64_t & nb00,
|
|
978
|
+
constant uint64_t & nb01,
|
|
979
|
+
constant uint64_t & nb02,
|
|
980
|
+
constant int64_t & ne10,
|
|
981
|
+
constant int64_t & ne11,
|
|
982
|
+
constant int64_t & ne12,
|
|
983
|
+
constant uint64_t & nb10,
|
|
984
|
+
constant uint64_t & nb11,
|
|
985
|
+
constant uint64_t & nb12,
|
|
986
|
+
constant int64_t & ne0,
|
|
987
|
+
constant int64_t & ne1,
|
|
988
|
+
constant uint & r2,
|
|
989
|
+
constant uint & r3,
|
|
851
990
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
852
991
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
853
992
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
854
|
-
|
|
993
|
+
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
855
994
|
}
|
|
856
995
|
|
|
857
996
|
kernel void kernel_mul_mv_q5_1_f32(
|
|
@@ -859,39 +998,46 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
|
859
998
|
device const float * src1,
|
|
860
999
|
device float * dst,
|
|
861
1000
|
constant int64_t & ne00,
|
|
862
|
-
constant int64_t & ne01
|
|
863
|
-
constant int64_t & ne02
|
|
864
|
-
constant
|
|
865
|
-
constant
|
|
866
|
-
constant
|
|
867
|
-
constant int64_t &
|
|
868
|
-
constant
|
|
869
|
-
constant
|
|
1001
|
+
constant int64_t & ne01,
|
|
1002
|
+
constant int64_t & ne02,
|
|
1003
|
+
constant uint64_t & nb00,
|
|
1004
|
+
constant uint64_t & nb01,
|
|
1005
|
+
constant uint64_t & nb02,
|
|
1006
|
+
constant int64_t & ne10,
|
|
1007
|
+
constant int64_t & ne11,
|
|
1008
|
+
constant int64_t & ne12,
|
|
1009
|
+
constant uint64_t & nb10,
|
|
1010
|
+
constant uint64_t & nb11,
|
|
1011
|
+
constant uint64_t & nb12,
|
|
1012
|
+
constant int64_t & ne0,
|
|
1013
|
+
constant int64_t & ne1,
|
|
1014
|
+
constant uint & r2,
|
|
1015
|
+
constant uint & r3,
|
|
870
1016
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
871
1017
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
872
1018
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
873
|
-
|
|
1019
|
+
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
874
1020
|
}
|
|
875
1021
|
|
|
876
1022
|
|
|
877
1023
|
#define NB_Q8_0 8
|
|
878
1024
|
|
|
879
|
-
|
|
1025
|
+
void kernel_mul_mv_q8_0_f32_impl(
|
|
880
1026
|
device const void * src0,
|
|
881
1027
|
device const float * src1,
|
|
882
1028
|
device float * dst,
|
|
883
1029
|
constant int64_t & ne00,
|
|
884
|
-
constant int64_t & ne01
|
|
885
|
-
constant int64_t & ne02
|
|
886
|
-
constant int64_t & ne10
|
|
887
|
-
constant int64_t & ne12
|
|
888
|
-
constant int64_t & ne0
|
|
889
|
-
constant int64_t & ne1
|
|
890
|
-
constant uint & r2
|
|
891
|
-
constant uint & r3
|
|
1030
|
+
constant int64_t & ne01,
|
|
1031
|
+
constant int64_t & ne02,
|
|
1032
|
+
constant int64_t & ne10,
|
|
1033
|
+
constant int64_t & ne12,
|
|
1034
|
+
constant int64_t & ne0,
|
|
1035
|
+
constant int64_t & ne1,
|
|
1036
|
+
constant uint & r2,
|
|
1037
|
+
constant uint & r3,
|
|
892
1038
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
893
|
-
uint
|
|
894
|
-
uint
|
|
1039
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
1040
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
895
1041
|
const int nr = N_DST;
|
|
896
1042
|
const int nsg = N_SIMDGROUP;
|
|
897
1043
|
const int nw = N_SIMDWIDTH;
|
|
@@ -945,9 +1091,36 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
|
945
1091
|
}
|
|
946
1092
|
}
|
|
947
1093
|
|
|
1094
|
+
[[host_name("kernel_mul_mv_q8_0_f32")]]
|
|
1095
|
+
kernel void kernel_mul_mv_q8_0_f32(
|
|
1096
|
+
device const void * src0,
|
|
1097
|
+
device const float * src1,
|
|
1098
|
+
device float * dst,
|
|
1099
|
+
constant int64_t & ne00,
|
|
1100
|
+
constant int64_t & ne01,
|
|
1101
|
+
constant int64_t & ne02,
|
|
1102
|
+
constant uint64_t & nb00,
|
|
1103
|
+
constant uint64_t & nb01,
|
|
1104
|
+
constant uint64_t & nb02,
|
|
1105
|
+
constant int64_t & ne10,
|
|
1106
|
+
constant int64_t & ne11,
|
|
1107
|
+
constant int64_t & ne12,
|
|
1108
|
+
constant uint64_t & nb10,
|
|
1109
|
+
constant uint64_t & nb11,
|
|
1110
|
+
constant uint64_t & nb12,
|
|
1111
|
+
constant int64_t & ne0,
|
|
1112
|
+
constant int64_t & ne1,
|
|
1113
|
+
constant uint & r2,
|
|
1114
|
+
constant uint & r3,
|
|
1115
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1116
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
1117
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
1118
|
+
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
1119
|
+
}
|
|
1120
|
+
|
|
948
1121
|
#define N_F32_F32 4
|
|
949
1122
|
|
|
950
|
-
|
|
1123
|
+
void kernel_mul_mv_f32_f32_impl(
|
|
951
1124
|
device const char * src0,
|
|
952
1125
|
device const char * src1,
|
|
953
1126
|
device float * dst,
|
|
@@ -965,8 +1138,8 @@ kernel void kernel_mul_mv_f32_f32(
|
|
|
965
1138
|
constant uint64_t & nb12,
|
|
966
1139
|
constant int64_t & ne0,
|
|
967
1140
|
constant int64_t & ne1,
|
|
968
|
-
constant uint & r2
|
|
969
|
-
constant uint & r3
|
|
1141
|
+
constant uint & r2,
|
|
1142
|
+
constant uint & r3,
|
|
970
1143
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
971
1144
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
972
1145
|
|
|
@@ -1025,6 +1198,32 @@ kernel void kernel_mul_mv_f32_f32(
|
|
|
1025
1198
|
}
|
|
1026
1199
|
}
|
|
1027
1200
|
|
|
1201
|
+
[[host_name("kernel_mul_mv_f32_f32")]]
|
|
1202
|
+
kernel void kernel_mul_mv_f32_f32(
|
|
1203
|
+
device const char * src0,
|
|
1204
|
+
device const char * src1,
|
|
1205
|
+
device float * dst,
|
|
1206
|
+
constant int64_t & ne00,
|
|
1207
|
+
constant int64_t & ne01,
|
|
1208
|
+
constant int64_t & ne02,
|
|
1209
|
+
constant uint64_t & nb00,
|
|
1210
|
+
constant uint64_t & nb01,
|
|
1211
|
+
constant uint64_t & nb02,
|
|
1212
|
+
constant int64_t & ne10,
|
|
1213
|
+
constant int64_t & ne11,
|
|
1214
|
+
constant int64_t & ne12,
|
|
1215
|
+
constant uint64_t & nb10,
|
|
1216
|
+
constant uint64_t & nb11,
|
|
1217
|
+
constant uint64_t & nb12,
|
|
1218
|
+
constant int64_t & ne0,
|
|
1219
|
+
constant int64_t & ne1,
|
|
1220
|
+
constant uint & r2,
|
|
1221
|
+
constant uint & r3,
|
|
1222
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1223
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1224
|
+
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
1225
|
+
}
|
|
1226
|
+
|
|
1028
1227
|
#define N_F16_F16 4
|
|
1029
1228
|
|
|
1030
1229
|
kernel void kernel_mul_mv_f16_f16(
|
|
@@ -1045,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16(
|
|
|
1045
1244
|
constant uint64_t & nb12,
|
|
1046
1245
|
constant int64_t & ne0,
|
|
1047
1246
|
constant int64_t & ne1,
|
|
1048
|
-
constant uint & r2
|
|
1049
|
-
constant uint & r3
|
|
1247
|
+
constant uint & r2,
|
|
1248
|
+
constant uint & r3,
|
|
1050
1249
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1051
1250
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1052
1251
|
|
|
@@ -1105,7 +1304,7 @@ kernel void kernel_mul_mv_f16_f16(
|
|
|
1105
1304
|
}
|
|
1106
1305
|
}
|
|
1107
1306
|
|
|
1108
|
-
|
|
1307
|
+
void kernel_mul_mv_f16_f32_1row_impl(
|
|
1109
1308
|
device const char * src0,
|
|
1110
1309
|
device const char * src1,
|
|
1111
1310
|
device float * dst,
|
|
@@ -1123,8 +1322,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
|
1123
1322
|
constant uint64_t & nb12,
|
|
1124
1323
|
constant int64_t & ne0,
|
|
1125
1324
|
constant int64_t & ne1,
|
|
1126
|
-
constant uint & r2
|
|
1127
|
-
constant uint & r3
|
|
1325
|
+
constant uint & r2,
|
|
1326
|
+
constant uint & r3,
|
|
1128
1327
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1129
1328
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1130
1329
|
|
|
@@ -1161,12 +1360,10 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
|
1161
1360
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
1162
1361
|
}
|
|
1163
1362
|
}
|
|
1164
|
-
|
|
1165
1363
|
}
|
|
1166
1364
|
|
|
1167
|
-
|
|
1168
|
-
|
|
1169
|
-
kernel void kernel_mul_mv_f16_f32(
|
|
1365
|
+
[[host_name("kernel_mul_mv_f16_f32_1row")]]
|
|
1366
|
+
kernel void kernel_mul_mv_f16_f32_1row(
|
|
1170
1367
|
device const char * src0,
|
|
1171
1368
|
device const char * src1,
|
|
1172
1369
|
device float * dst,
|
|
@@ -1184,21 +1381,48 @@ kernel void kernel_mul_mv_f16_f32(
|
|
|
1184
1381
|
constant uint64_t & nb12,
|
|
1185
1382
|
constant int64_t & ne0,
|
|
1186
1383
|
constant int64_t & ne1,
|
|
1187
|
-
constant uint & r2
|
|
1188
|
-
constant uint & r3
|
|
1384
|
+
constant uint & r2,
|
|
1385
|
+
constant uint & r3,
|
|
1189
1386
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1190
|
-
uint
|
|
1191
|
-
|
|
1192
|
-
|
|
1193
|
-
const int64_t rb = tgpig.y*N_F16_F32;
|
|
1194
|
-
const int64_t im = tgpig.z;
|
|
1195
|
-
|
|
1196
|
-
const uint i12 = im%ne12;
|
|
1197
|
-
const uint i13 = im/ne12;
|
|
1387
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1388
|
+
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
1389
|
+
}
|
|
1198
1390
|
|
|
1199
|
-
|
|
1391
|
+
#define N_F16_F32 4
|
|
1200
1392
|
|
|
1201
|
-
|
|
1393
|
+
void kernel_mul_mv_f16_f32_impl(
|
|
1394
|
+
device const char * src0,
|
|
1395
|
+
device const char * src1,
|
|
1396
|
+
device float * dst,
|
|
1397
|
+
constant int64_t & ne00,
|
|
1398
|
+
constant int64_t & ne01,
|
|
1399
|
+
constant int64_t & ne02,
|
|
1400
|
+
constant uint64_t & nb00,
|
|
1401
|
+
constant uint64_t & nb01,
|
|
1402
|
+
constant uint64_t & nb02,
|
|
1403
|
+
constant int64_t & ne10,
|
|
1404
|
+
constant int64_t & ne11,
|
|
1405
|
+
constant int64_t & ne12,
|
|
1406
|
+
constant uint64_t & nb10,
|
|
1407
|
+
constant uint64_t & nb11,
|
|
1408
|
+
constant uint64_t & nb12,
|
|
1409
|
+
constant int64_t & ne0,
|
|
1410
|
+
constant int64_t & ne1,
|
|
1411
|
+
constant uint & r2,
|
|
1412
|
+
constant uint & r3,
|
|
1413
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1414
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1415
|
+
|
|
1416
|
+
const int64_t r0 = tgpig.x;
|
|
1417
|
+
const int64_t rb = tgpig.y*N_F16_F32;
|
|
1418
|
+
const int64_t im = tgpig.z;
|
|
1419
|
+
|
|
1420
|
+
const uint i12 = im%ne12;
|
|
1421
|
+
const uint i13 = im/ne12;
|
|
1422
|
+
|
|
1423
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
1424
|
+
|
|
1425
|
+
device const half * x = (device const half *) (src0 + offset0);
|
|
1202
1426
|
|
|
1203
1427
|
if (ne00 < 128) {
|
|
1204
1428
|
for (int row = 0; row < N_F16_F32; ++row) {
|
|
@@ -1244,6 +1468,32 @@ kernel void kernel_mul_mv_f16_f32(
|
|
|
1244
1468
|
}
|
|
1245
1469
|
}
|
|
1246
1470
|
|
|
1471
|
+
[[host_name("kernel_mul_mv_f16_f32")]]
|
|
1472
|
+
kernel void kernel_mul_mv_f16_f32(
|
|
1473
|
+
device const char * src0,
|
|
1474
|
+
device const char * src1,
|
|
1475
|
+
device float * dst,
|
|
1476
|
+
constant int64_t & ne00,
|
|
1477
|
+
constant int64_t & ne01,
|
|
1478
|
+
constant int64_t & ne02,
|
|
1479
|
+
constant uint64_t & nb00,
|
|
1480
|
+
constant uint64_t & nb01,
|
|
1481
|
+
constant uint64_t & nb02,
|
|
1482
|
+
constant int64_t & ne10,
|
|
1483
|
+
constant int64_t & ne11,
|
|
1484
|
+
constant int64_t & ne12,
|
|
1485
|
+
constant uint64_t & nb10,
|
|
1486
|
+
constant uint64_t & nb11,
|
|
1487
|
+
constant uint64_t & nb12,
|
|
1488
|
+
constant int64_t & ne0,
|
|
1489
|
+
constant int64_t & ne1,
|
|
1490
|
+
constant uint & r2,
|
|
1491
|
+
constant uint & r3,
|
|
1492
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1493
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1494
|
+
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
1495
|
+
}
|
|
1496
|
+
|
|
1247
1497
|
// Assumes row size (ne00) is a multiple of 4
|
|
1248
1498
|
kernel void kernel_mul_mv_f16_f32_l4(
|
|
1249
1499
|
device const char * src0,
|
|
@@ -1263,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
|
1263
1513
|
constant uint64_t & nb12,
|
|
1264
1514
|
constant int64_t & ne0,
|
|
1265
1515
|
constant int64_t & ne1,
|
|
1266
|
-
constant uint & r2
|
|
1267
|
-
constant uint & r3
|
|
1516
|
+
constant uint & r2,
|
|
1517
|
+
constant uint & r3,
|
|
1268
1518
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1269
1519
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1270
1520
|
|
|
@@ -1328,7 +1578,8 @@ kernel void kernel_alibi_f32(
|
|
|
1328
1578
|
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
1329
1579
|
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
1330
1580
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
1331
|
-
|
|
1581
|
+
//const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
1582
|
+
|
|
1332
1583
|
const int64_t k = i3*ne3 + i2;
|
|
1333
1584
|
|
|
1334
1585
|
float m_k;
|
|
@@ -1487,8 +1738,9 @@ kernel void kernel_rope(
|
|
|
1487
1738
|
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
|
1488
1739
|
}
|
|
1489
1740
|
} else {
|
|
1490
|
-
for (int64_t
|
|
1491
|
-
|
|
1741
|
+
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
|
|
1742
|
+
if (ic < n_dims) {
|
|
1743
|
+
const int64_t ib = 0;
|
|
1492
1744
|
|
|
1493
1745
|
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
|
1494
1746
|
const float cur_rot = inv_ndims*ic - ib;
|
|
@@ -1507,6 +1759,14 @@ kernel void kernel_rope(
|
|
|
1507
1759
|
|
|
1508
1760
|
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
1509
1761
|
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
1762
|
+
} else {
|
|
1763
|
+
const int64_t i0 = ic;
|
|
1764
|
+
|
|
1765
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
1766
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1767
|
+
|
|
1768
|
+
dst_data[0] = src[0];
|
|
1769
|
+
dst_data[1] = src[1];
|
|
1510
1770
|
}
|
|
1511
1771
|
}
|
|
1512
1772
|
}
|
|
@@ -1548,6 +1808,97 @@ kernel void kernel_im2col_f16(
|
|
|
1548
1808
|
}
|
|
1549
1809
|
}
|
|
1550
1810
|
|
|
1811
|
+
kernel void kernel_upscale_f32(
|
|
1812
|
+
device const char * src0,
|
|
1813
|
+
device char * dst,
|
|
1814
|
+
constant int64_t & ne00,
|
|
1815
|
+
constant int64_t & ne01,
|
|
1816
|
+
constant int64_t & ne02,
|
|
1817
|
+
constant int64_t & ne03,
|
|
1818
|
+
constant uint64_t & nb00,
|
|
1819
|
+
constant uint64_t & nb01,
|
|
1820
|
+
constant uint64_t & nb02,
|
|
1821
|
+
constant uint64_t & nb03,
|
|
1822
|
+
constant int64_t & ne0,
|
|
1823
|
+
constant int64_t & ne1,
|
|
1824
|
+
constant int64_t & ne2,
|
|
1825
|
+
constant int64_t & ne3,
|
|
1826
|
+
constant uint64_t & nb0,
|
|
1827
|
+
constant uint64_t & nb1,
|
|
1828
|
+
constant uint64_t & nb2,
|
|
1829
|
+
constant uint64_t & nb3,
|
|
1830
|
+
constant int32_t & sf,
|
|
1831
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1832
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1833
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
1834
|
+
|
|
1835
|
+
const int64_t i3 = tgpig.z;
|
|
1836
|
+
const int64_t i2 = tgpig.y;
|
|
1837
|
+
const int64_t i1 = tgpig.x;
|
|
1838
|
+
|
|
1839
|
+
const int64_t i03 = i3;
|
|
1840
|
+
const int64_t i02 = i2;
|
|
1841
|
+
const int64_t i01 = i1/sf;
|
|
1842
|
+
|
|
1843
|
+
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
|
1844
|
+
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
|
|
1845
|
+
|
|
1846
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
|
1847
|
+
dst_ptr[i0] = src0_ptr[i0/sf];
|
|
1848
|
+
}
|
|
1849
|
+
}
|
|
1850
|
+
|
|
1851
|
+
kernel void kernel_pad_f32(
|
|
1852
|
+
device const char * src0,
|
|
1853
|
+
device char * dst,
|
|
1854
|
+
constant int64_t & ne00,
|
|
1855
|
+
constant int64_t & ne01,
|
|
1856
|
+
constant int64_t & ne02,
|
|
1857
|
+
constant int64_t & ne03,
|
|
1858
|
+
constant uint64_t & nb00,
|
|
1859
|
+
constant uint64_t & nb01,
|
|
1860
|
+
constant uint64_t & nb02,
|
|
1861
|
+
constant uint64_t & nb03,
|
|
1862
|
+
constant int64_t & ne0,
|
|
1863
|
+
constant int64_t & ne1,
|
|
1864
|
+
constant int64_t & ne2,
|
|
1865
|
+
constant int64_t & ne3,
|
|
1866
|
+
constant uint64_t & nb0,
|
|
1867
|
+
constant uint64_t & nb1,
|
|
1868
|
+
constant uint64_t & nb2,
|
|
1869
|
+
constant uint64_t & nb3,
|
|
1870
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1871
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1872
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
1873
|
+
|
|
1874
|
+
const int64_t i3 = tgpig.z;
|
|
1875
|
+
const int64_t i2 = tgpig.y;
|
|
1876
|
+
const int64_t i1 = tgpig.x;
|
|
1877
|
+
|
|
1878
|
+
const int64_t i03 = i3;
|
|
1879
|
+
const int64_t i02 = i2;
|
|
1880
|
+
const int64_t i01 = i1;
|
|
1881
|
+
|
|
1882
|
+
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
|
1883
|
+
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
|
|
1884
|
+
|
|
1885
|
+
if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
|
1886
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
|
1887
|
+
if (i0 < ne00) {
|
|
1888
|
+
dst_ptr[i0] = src0_ptr[i0];
|
|
1889
|
+
} else {
|
|
1890
|
+
dst_ptr[i0] = 0.0f;
|
|
1891
|
+
}
|
|
1892
|
+
}
|
|
1893
|
+
|
|
1894
|
+
return;
|
|
1895
|
+
}
|
|
1896
|
+
|
|
1897
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
|
1898
|
+
dst_ptr[i0] = 0.0f;
|
|
1899
|
+
}
|
|
1900
|
+
}
|
|
1901
|
+
|
|
1551
1902
|
// bitonic sort implementation following the CUDA kernels as reference
|
|
1552
1903
|
typedef void (argsort_t)(
|
|
1553
1904
|
device const float * x,
|
|
@@ -1600,9 +1951,17 @@ kernel void kernel_argsort_f32_i32(
|
|
|
1600
1951
|
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
|
|
1601
1952
|
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
|
|
1602
1953
|
|
|
1954
|
+
kernel void kernel_leaky_relu_f32(
|
|
1955
|
+
device const float * src0,
|
|
1956
|
+
device float * dst,
|
|
1957
|
+
constant float & slope,
|
|
1958
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
1959
|
+
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
|
|
1960
|
+
}
|
|
1961
|
+
|
|
1603
1962
|
kernel void kernel_cpy_f16_f16(
|
|
1604
|
-
device
|
|
1605
|
-
device
|
|
1963
|
+
device const half * src0,
|
|
1964
|
+
device half * dst,
|
|
1606
1965
|
constant int64_t & ne00,
|
|
1607
1966
|
constant int64_t & ne01,
|
|
1608
1967
|
constant int64_t & ne02,
|
|
@@ -1641,6 +2000,47 @@ kernel void kernel_cpy_f16_f16(
|
|
|
1641
2000
|
}
|
|
1642
2001
|
}
|
|
1643
2002
|
|
|
2003
|
+
kernel void kernel_cpy_f16_f32(
|
|
2004
|
+
device const half * src0,
|
|
2005
|
+
device float * dst,
|
|
2006
|
+
constant int64_t & ne00,
|
|
2007
|
+
constant int64_t & ne01,
|
|
2008
|
+
constant int64_t & ne02,
|
|
2009
|
+
constant int64_t & ne03,
|
|
2010
|
+
constant uint64_t & nb00,
|
|
2011
|
+
constant uint64_t & nb01,
|
|
2012
|
+
constant uint64_t & nb02,
|
|
2013
|
+
constant uint64_t & nb03,
|
|
2014
|
+
constant int64_t & ne0,
|
|
2015
|
+
constant int64_t & ne1,
|
|
2016
|
+
constant int64_t & ne2,
|
|
2017
|
+
constant int64_t & ne3,
|
|
2018
|
+
constant uint64_t & nb0,
|
|
2019
|
+
constant uint64_t & nb1,
|
|
2020
|
+
constant uint64_t & nb2,
|
|
2021
|
+
constant uint64_t & nb3,
|
|
2022
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2023
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
2024
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
2025
|
+
const int64_t i03 = tgpig[2];
|
|
2026
|
+
const int64_t i02 = tgpig[1];
|
|
2027
|
+
const int64_t i01 = tgpig[0];
|
|
2028
|
+
|
|
2029
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
2030
|
+
|
|
2031
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
2032
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
2033
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
2034
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
2035
|
+
|
|
2036
|
+
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
2037
|
+
|
|
2038
|
+
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
2039
|
+
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
2040
|
+
dst_data[i00] = src[0];
|
|
2041
|
+
}
|
|
2042
|
+
}
|
|
2043
|
+
|
|
1644
2044
|
kernel void kernel_cpy_f32_f16(
|
|
1645
2045
|
device const float * src0,
|
|
1646
2046
|
device half * dst,
|
|
@@ -1917,9 +2317,9 @@ kernel void kernel_cpy_f32_q4_1(
|
|
|
1917
2317
|
}
|
|
1918
2318
|
|
|
1919
2319
|
kernel void kernel_concat(
|
|
1920
|
-
device
|
|
1921
|
-
device
|
|
1922
|
-
device
|
|
2320
|
+
device const char * src0,
|
|
2321
|
+
device const char * src1,
|
|
2322
|
+
device char * dst,
|
|
1923
2323
|
constant int64_t & ne00,
|
|
1924
2324
|
constant int64_t & ne01,
|
|
1925
2325
|
constant int64_t & ne02,
|
|
@@ -1956,7 +2356,7 @@ kernel void kernel_concat(
|
|
|
1956
2356
|
const int64_t i12 = i02 % ne12;
|
|
1957
2357
|
const int64_t i11 = i01 % ne11;
|
|
1958
2358
|
|
|
1959
|
-
device const char * src0_ptr = src0 + i03
|
|
2359
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
|
|
1960
2360
|
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
|
1961
2361
|
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
|
1962
2362
|
|
|
@@ -2046,37 +2446,34 @@ typedef struct {
|
|
|
2046
2446
|
} block_q6_K;
|
|
2047
2447
|
// 210 bytes / block
|
|
2048
2448
|
|
|
2049
|
-
|
|
2050
|
-
|
|
2051
|
-
|
|
2052
|
-
|
|
2053
|
-
|
|
2054
|
-
|
|
2055
|
-
|
|
2056
|
-
|
|
2057
|
-
|
|
2058
|
-
|
|
2059
|
-
|
|
2060
|
-
|
|
2061
|
-
}
|
|
2062
|
-
return r;
|
|
2063
|
-
}
|
|
2449
|
+
typedef struct {
|
|
2450
|
+
half d;
|
|
2451
|
+
uint16_t qs[QK_K/8];
|
|
2452
|
+
} block_iq2_xxs;
|
|
2453
|
+
// 66 bytes / block for QK_K = 256, so 2.0625 bpw
|
|
2454
|
+
|
|
2455
|
+
typedef struct {
|
|
2456
|
+
half d;
|
|
2457
|
+
uint16_t qs[QK_K/8];
|
|
2458
|
+
uint8_t scales[QK_K/32];
|
|
2459
|
+
} block_iq2_xs;
|
|
2460
|
+
// 74 bytes / block for QK_K = 256, so 2.3125 bpw
|
|
2064
2461
|
|
|
2065
2462
|
//====================================== dot products =========================
|
|
2066
2463
|
|
|
2067
|
-
|
|
2464
|
+
void kernel_mul_mv_q2_K_f32_impl(
|
|
2068
2465
|
device const void * src0,
|
|
2069
2466
|
device const float * src1,
|
|
2070
2467
|
device float * dst,
|
|
2071
2468
|
constant int64_t & ne00,
|
|
2072
|
-
constant int64_t & ne01
|
|
2073
|
-
constant int64_t & ne02
|
|
2074
|
-
constant int64_t & ne10
|
|
2075
|
-
constant int64_t & ne12
|
|
2076
|
-
constant int64_t & ne0
|
|
2077
|
-
constant int64_t & ne1
|
|
2078
|
-
constant uint & r2
|
|
2079
|
-
constant uint & r3
|
|
2469
|
+
constant int64_t & ne01,
|
|
2470
|
+
constant int64_t & ne02,
|
|
2471
|
+
constant int64_t & ne10,
|
|
2472
|
+
constant int64_t & ne12,
|
|
2473
|
+
constant int64_t & ne0,
|
|
2474
|
+
constant int64_t & ne1,
|
|
2475
|
+
constant uint & r2,
|
|
2476
|
+
constant uint & r3,
|
|
2080
2477
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2081
2478
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
2082
2479
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -2214,23 +2611,51 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
|
2214
2611
|
}
|
|
2215
2612
|
}
|
|
2216
2613
|
|
|
2614
|
+
[[host_name("kernel_mul_mv_q2_K_f32")]]
|
|
2615
|
+
kernel void kernel_mul_mv_q2_K_f32(
|
|
2616
|
+
device const void * src0,
|
|
2617
|
+
device const float * src1,
|
|
2618
|
+
device float * dst,
|
|
2619
|
+
constant int64_t & ne00,
|
|
2620
|
+
constant int64_t & ne01,
|
|
2621
|
+
constant int64_t & ne02,
|
|
2622
|
+
constant uint64_t & nb00,
|
|
2623
|
+
constant uint64_t & nb01,
|
|
2624
|
+
constant uint64_t & nb02,
|
|
2625
|
+
constant int64_t & ne10,
|
|
2626
|
+
constant int64_t & ne11,
|
|
2627
|
+
constant int64_t & ne12,
|
|
2628
|
+
constant uint64_t & nb10,
|
|
2629
|
+
constant uint64_t & nb11,
|
|
2630
|
+
constant uint64_t & nb12,
|
|
2631
|
+
constant int64_t & ne0,
|
|
2632
|
+
constant int64_t & ne1,
|
|
2633
|
+
constant uint & r2,
|
|
2634
|
+
constant uint & r3,
|
|
2635
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2636
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
2637
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
2638
|
+
|
|
2639
|
+
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
2640
|
+
}
|
|
2641
|
+
|
|
2217
2642
|
#if QK_K == 256
|
|
2218
|
-
|
|
2643
|
+
void kernel_mul_mv_q3_K_f32_impl(
|
|
2219
2644
|
device const void * src0,
|
|
2220
2645
|
device const float * src1,
|
|
2221
2646
|
device float * dst,
|
|
2222
2647
|
constant int64_t & ne00,
|
|
2223
|
-
constant int64_t & ne01
|
|
2224
|
-
constant int64_t & ne02
|
|
2225
|
-
constant int64_t & ne10
|
|
2226
|
-
constant int64_t & ne12
|
|
2227
|
-
constant int64_t & ne0
|
|
2228
|
-
constant int64_t & ne1
|
|
2229
|
-
constant uint & r2
|
|
2230
|
-
constant uint & r3
|
|
2648
|
+
constant int64_t & ne01,
|
|
2649
|
+
constant int64_t & ne02,
|
|
2650
|
+
constant int64_t & ne10,
|
|
2651
|
+
constant int64_t & ne12,
|
|
2652
|
+
constant int64_t & ne0,
|
|
2653
|
+
constant int64_t & ne1,
|
|
2654
|
+
constant uint & r2,
|
|
2655
|
+
constant uint & r3,
|
|
2231
2656
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2232
|
-
uint
|
|
2233
|
-
uint
|
|
2657
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
2658
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
2234
2659
|
|
|
2235
2660
|
const int nb = ne00/QK_K;
|
|
2236
2661
|
|
|
@@ -2373,19 +2798,19 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
2373
2798
|
}
|
|
2374
2799
|
}
|
|
2375
2800
|
#else
|
|
2376
|
-
|
|
2801
|
+
void kernel_mul_mv_q3_K_f32_impl(
|
|
2377
2802
|
device const void * src0,
|
|
2378
2803
|
device const float * src1,
|
|
2379
2804
|
device float * dst,
|
|
2380
2805
|
constant int64_t & ne00,
|
|
2381
|
-
constant int64_t & ne01
|
|
2382
|
-
constant int64_t & ne02
|
|
2383
|
-
constant int64_t & ne10
|
|
2384
|
-
constant int64_t & ne12
|
|
2385
|
-
constant int64_t & ne0
|
|
2386
|
-
constant int64_t & ne1
|
|
2387
|
-
constant uint & r2
|
|
2388
|
-
constant uint & r3
|
|
2806
|
+
constant int64_t & ne01,
|
|
2807
|
+
constant int64_t & ne02,
|
|
2808
|
+
constant int64_t & ne10,
|
|
2809
|
+
constant int64_t & ne12,
|
|
2810
|
+
constant int64_t & ne0,
|
|
2811
|
+
constant int64_t & ne1,
|
|
2812
|
+
constant uint & r2,
|
|
2813
|
+
constant uint & r3,
|
|
2389
2814
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2390
2815
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
2391
2816
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -2450,20 +2875,48 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
2450
2875
|
}
|
|
2451
2876
|
#endif
|
|
2452
2877
|
|
|
2878
|
+
[[host_name("kernel_mul_mv_q3_K_f32")]]
|
|
2879
|
+
kernel void kernel_mul_mv_q3_K_f32(
|
|
2880
|
+
device const void * src0,
|
|
2881
|
+
device const float * src1,
|
|
2882
|
+
device float * dst,
|
|
2883
|
+
constant int64_t & ne00,
|
|
2884
|
+
constant int64_t & ne01,
|
|
2885
|
+
constant int64_t & ne02,
|
|
2886
|
+
constant uint64_t & nb00,
|
|
2887
|
+
constant uint64_t & nb01,
|
|
2888
|
+
constant uint64_t & nb02,
|
|
2889
|
+
constant int64_t & ne10,
|
|
2890
|
+
constant int64_t & ne11,
|
|
2891
|
+
constant int64_t & ne12,
|
|
2892
|
+
constant uint64_t & nb10,
|
|
2893
|
+
constant uint64_t & nb11,
|
|
2894
|
+
constant uint64_t & nb12,
|
|
2895
|
+
constant int64_t & ne0,
|
|
2896
|
+
constant int64_t & ne1,
|
|
2897
|
+
constant uint & r2,
|
|
2898
|
+
constant uint & r3,
|
|
2899
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2900
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
2901
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
2902
|
+
|
|
2903
|
+
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
2904
|
+
}
|
|
2905
|
+
|
|
2453
2906
|
#if QK_K == 256
|
|
2454
|
-
|
|
2907
|
+
void kernel_mul_mv_q4_K_f32_impl(
|
|
2455
2908
|
device const void * src0,
|
|
2456
2909
|
device const float * src1,
|
|
2457
2910
|
device float * dst,
|
|
2458
2911
|
constant int64_t & ne00,
|
|
2459
|
-
constant int64_t & ne01
|
|
2460
|
-
constant int64_t & ne02
|
|
2461
|
-
constant int64_t & ne10
|
|
2462
|
-
constant int64_t & ne12
|
|
2463
|
-
constant int64_t & ne0
|
|
2464
|
-
constant int64_t & ne1
|
|
2465
|
-
constant uint & r2
|
|
2466
|
-
constant uint & r3
|
|
2912
|
+
constant int64_t & ne01,
|
|
2913
|
+
constant int64_t & ne02,
|
|
2914
|
+
constant int64_t & ne10,
|
|
2915
|
+
constant int64_t & ne12,
|
|
2916
|
+
constant int64_t & ne0,
|
|
2917
|
+
constant int64_t & ne1,
|
|
2918
|
+
constant uint & r2,
|
|
2919
|
+
constant uint & r3,
|
|
2467
2920
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2468
2921
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
2469
2922
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -2564,31 +3017,31 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
2564
3017
|
}
|
|
2565
3018
|
}
|
|
2566
3019
|
#else
|
|
2567
|
-
|
|
3020
|
+
void kernel_mul_mv_q4_K_f32_impl(
|
|
2568
3021
|
device const void * src0,
|
|
2569
3022
|
device const float * src1,
|
|
2570
3023
|
device float * dst,
|
|
2571
3024
|
constant int64_t & ne00,
|
|
2572
|
-
constant int64_t & ne01
|
|
2573
|
-
constant int64_t & ne02
|
|
2574
|
-
constant int64_t & ne10
|
|
2575
|
-
constant int64_t & ne12
|
|
2576
|
-
constant int64_t & ne0
|
|
2577
|
-
constant int64_t & ne1
|
|
2578
|
-
constant uint & r2
|
|
2579
|
-
constant uint & r3
|
|
2580
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2581
|
-
uint
|
|
2582
|
-
uint
|
|
2583
|
-
|
|
2584
|
-
const int ix = tiisg/4; // 0...7
|
|
3025
|
+
constant int64_t & ne01,
|
|
3026
|
+
constant int64_t & ne02,
|
|
3027
|
+
constant int64_t & ne10,
|
|
3028
|
+
constant int64_t & ne12,
|
|
3029
|
+
constant int64_t & ne0,
|
|
3030
|
+
constant int64_t & ne1,
|
|
3031
|
+
constant uint & r2,
|
|
3032
|
+
constant uint & r3,
|
|
3033
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3034
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3035
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3036
|
+
|
|
3037
|
+
const int ix = tiisg/4; // 0...7
|
|
2585
3038
|
const int it = tiisg%4; // 0...3
|
|
2586
3039
|
|
|
2587
3040
|
const int nb = ne00/QK_K;
|
|
2588
3041
|
const int r0 = tgpig.x;
|
|
2589
3042
|
const int r1 = tgpig.y;
|
|
2590
3043
|
const int im = tgpig.z;
|
|
2591
|
-
const int first_row =
|
|
3044
|
+
const int first_row = r0 * N_DST;
|
|
2592
3045
|
const int ib_row = first_row * nb;
|
|
2593
3046
|
|
|
2594
3047
|
const uint i12 = im%ne12;
|
|
@@ -2654,29 +3107,57 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
2654
3107
|
for (int row = 0; row < N_DST; ++row) {
|
|
2655
3108
|
all_sum = simd_sum(sumf[row]);
|
|
2656
3109
|
if (tiisg == 0) {
|
|
2657
|
-
dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
|
|
3110
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
2658
3111
|
}
|
|
2659
3112
|
}
|
|
2660
3113
|
}
|
|
2661
3114
|
#endif
|
|
2662
3115
|
|
|
2663
|
-
|
|
3116
|
+
[[host_name("kernel_mul_mv_q4_K_f32")]]
|
|
3117
|
+
kernel void kernel_mul_mv_q4_K_f32(
|
|
2664
3118
|
device const void * src0,
|
|
2665
3119
|
device const float * src1,
|
|
2666
3120
|
device float * dst,
|
|
2667
3121
|
constant int64_t & ne00,
|
|
2668
|
-
constant int64_t & ne01
|
|
2669
|
-
constant int64_t & ne02
|
|
2670
|
-
constant
|
|
2671
|
-
constant
|
|
2672
|
-
constant
|
|
2673
|
-
constant int64_t &
|
|
2674
|
-
constant
|
|
2675
|
-
constant
|
|
3122
|
+
constant int64_t & ne01,
|
|
3123
|
+
constant int64_t & ne02,
|
|
3124
|
+
constant uint64_t & nb00,
|
|
3125
|
+
constant uint64_t & nb01,
|
|
3126
|
+
constant uint64_t & nb02,
|
|
3127
|
+
constant int64_t & ne10,
|
|
3128
|
+
constant int64_t & ne11,
|
|
3129
|
+
constant int64_t & ne12,
|
|
3130
|
+
constant uint64_t & nb10,
|
|
3131
|
+
constant uint64_t & nb11,
|
|
3132
|
+
constant uint64_t & nb12,
|
|
3133
|
+
constant int64_t & ne0,
|
|
3134
|
+
constant int64_t & ne1,
|
|
3135
|
+
constant uint & r2,
|
|
3136
|
+
constant uint & r3,
|
|
2676
3137
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2677
3138
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
2678
3139
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
2679
3140
|
|
|
3141
|
+
kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
3142
|
+
}
|
|
3143
|
+
|
|
3144
|
+
void kernel_mul_mv_q5_K_f32_impl(
|
|
3145
|
+
device const void * src0,
|
|
3146
|
+
device const float * src1,
|
|
3147
|
+
device float * dst,
|
|
3148
|
+
constant int64_t & ne00,
|
|
3149
|
+
constant int64_t & ne01,
|
|
3150
|
+
constant int64_t & ne02,
|
|
3151
|
+
constant int64_t & ne10,
|
|
3152
|
+
constant int64_t & ne12,
|
|
3153
|
+
constant int64_t & ne0,
|
|
3154
|
+
constant int64_t & ne1,
|
|
3155
|
+
constant uint & r2,
|
|
3156
|
+
constant uint & r3,
|
|
3157
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3158
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3159
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3160
|
+
|
|
2680
3161
|
const int nb = ne00/QK_K;
|
|
2681
3162
|
|
|
2682
3163
|
const int64_t r0 = tgpig.x;
|
|
@@ -2836,25 +3317,52 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
2836
3317
|
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
|
2837
3318
|
}
|
|
2838
3319
|
}
|
|
3320
|
+
}
|
|
3321
|
+
|
|
3322
|
+
[[host_name("kernel_mul_mv_q5_K_f32")]]
|
|
3323
|
+
kernel void kernel_mul_mv_q5_K_f32(
|
|
3324
|
+
device const void * src0,
|
|
3325
|
+
device const float * src1,
|
|
3326
|
+
device float * dst,
|
|
3327
|
+
constant int64_t & ne00,
|
|
3328
|
+
constant int64_t & ne01,
|
|
3329
|
+
constant int64_t & ne02,
|
|
3330
|
+
constant uint64_t & nb00,
|
|
3331
|
+
constant uint64_t & nb01,
|
|
3332
|
+
constant uint64_t & nb02,
|
|
3333
|
+
constant int64_t & ne10,
|
|
3334
|
+
constant int64_t & ne11,
|
|
3335
|
+
constant int64_t & ne12,
|
|
3336
|
+
constant uint64_t & nb10,
|
|
3337
|
+
constant uint64_t & nb11,
|
|
3338
|
+
constant uint64_t & nb12,
|
|
3339
|
+
constant int64_t & ne0,
|
|
3340
|
+
constant int64_t & ne1,
|
|
3341
|
+
constant uint & r2,
|
|
3342
|
+
constant uint & r3,
|
|
3343
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3344
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3345
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
2839
3346
|
|
|
3347
|
+
kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
2840
3348
|
}
|
|
2841
3349
|
|
|
2842
|
-
|
|
3350
|
+
void kernel_mul_mv_q6_K_f32_impl(
|
|
2843
3351
|
device const void * src0,
|
|
2844
3352
|
device const float * src1,
|
|
2845
3353
|
device float * dst,
|
|
2846
3354
|
constant int64_t & ne00,
|
|
2847
|
-
constant int64_t & ne01
|
|
2848
|
-
constant int64_t & ne02
|
|
2849
|
-
constant int64_t & ne10
|
|
2850
|
-
constant int64_t & ne12
|
|
2851
|
-
constant int64_t & ne0
|
|
2852
|
-
constant int64_t & ne1
|
|
2853
|
-
constant uint & r2
|
|
2854
|
-
constant uint & r3
|
|
3355
|
+
constant int64_t & ne01,
|
|
3356
|
+
constant int64_t & ne02,
|
|
3357
|
+
constant int64_t & ne10,
|
|
3358
|
+
constant int64_t & ne12,
|
|
3359
|
+
constant int64_t & ne0,
|
|
3360
|
+
constant int64_t & ne1,
|
|
3361
|
+
constant uint & r2,
|
|
3362
|
+
constant uint & r3,
|
|
2855
3363
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2856
|
-
uint
|
|
2857
|
-
uint
|
|
3364
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3365
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
2858
3366
|
|
|
2859
3367
|
const uint8_t kmask1 = 0x03;
|
|
2860
3368
|
const uint8_t kmask2 = 0x0C;
|
|
@@ -2945,160 +3453,677 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
|
2945
3453
|
}
|
|
2946
3454
|
}
|
|
2947
3455
|
|
|
2948
|
-
|
|
3456
|
+
[[host_name("kernel_mul_mv_q6_K_f32")]]
|
|
3457
|
+
kernel void kernel_mul_mv_q6_K_f32(
|
|
3458
|
+
device const void * src0,
|
|
3459
|
+
device const float * src1,
|
|
3460
|
+
device float * dst,
|
|
3461
|
+
constant int64_t & ne00,
|
|
3462
|
+
constant int64_t & ne01,
|
|
3463
|
+
constant int64_t & ne02,
|
|
3464
|
+
constant uint64_t & nb00,
|
|
3465
|
+
constant uint64_t & nb01,
|
|
3466
|
+
constant uint64_t & nb02,
|
|
3467
|
+
constant int64_t & ne10,
|
|
3468
|
+
constant int64_t & ne11,
|
|
3469
|
+
constant int64_t & ne12,
|
|
3470
|
+
constant uint64_t & nb10,
|
|
3471
|
+
constant uint64_t & nb11,
|
|
3472
|
+
constant uint64_t & nb12,
|
|
3473
|
+
constant int64_t & ne0,
|
|
3474
|
+
constant int64_t & ne1,
|
|
3475
|
+
constant uint & r2,
|
|
3476
|
+
constant uint & r3,
|
|
3477
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3478
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3479
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
2949
3480
|
|
|
2950
|
-
|
|
2951
|
-
template <typename type4x4>
|
|
2952
|
-
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
|
2953
|
-
float4x4 temp = *(((device float4x4 *)src));
|
|
2954
|
-
for (int i = 0; i < 16; i++){
|
|
2955
|
-
reg[i/4][i%4] = temp[i/4][i%4];
|
|
2956
|
-
}
|
|
3481
|
+
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
2957
3482
|
}
|
|
2958
3483
|
|
|
2959
|
-
|
|
2960
|
-
|
|
2961
|
-
|
|
2962
|
-
|
|
2963
|
-
|
|
2964
|
-
|
|
2965
|
-
|
|
3484
|
+
// ======================= "True" 2-bit
|
|
3485
|
+
|
|
3486
|
+
constexpr constant static uint64_t iq2xxs_grid[256] = {
|
|
3487
|
+
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
|
3488
|
+
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
|
|
3489
|
+
0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
|
|
3490
|
+
0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
|
|
3491
|
+
0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
|
|
3492
|
+
0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
|
|
3493
|
+
0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
|
|
3494
|
+
0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
|
|
3495
|
+
0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
|
|
3496
|
+
0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
|
|
3497
|
+
0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
|
|
3498
|
+
0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
|
|
3499
|
+
0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
|
|
3500
|
+
0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
|
|
3501
|
+
0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
|
|
3502
|
+
0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
|
|
3503
|
+
0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
|
|
3504
|
+
0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
|
|
3505
|
+
0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
|
|
3506
|
+
0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
|
|
3507
|
+
0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
|
|
3508
|
+
0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
|
|
3509
|
+
0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
|
|
3510
|
+
0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
|
|
3511
|
+
0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
|
|
3512
|
+
0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
|
|
3513
|
+
0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
|
|
3514
|
+
0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
|
|
3515
|
+
0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
|
|
3516
|
+
0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
|
|
3517
|
+
0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
|
|
3518
|
+
0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
|
|
3519
|
+
0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
|
|
3520
|
+
0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
|
|
3521
|
+
0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
|
|
3522
|
+
0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
|
|
3523
|
+
0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
|
|
3524
|
+
0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
|
|
3525
|
+
0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
|
|
3526
|
+
0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
|
|
3527
|
+
0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
|
|
3528
|
+
0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
|
|
3529
|
+
0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
|
|
3530
|
+
0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
|
|
3531
|
+
0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
|
|
3532
|
+
0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
|
|
3533
|
+
0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
|
|
3534
|
+
0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
|
|
3535
|
+
0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
|
|
3536
|
+
0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
|
|
3537
|
+
0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
|
|
3538
|
+
0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
|
|
3539
|
+
0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
|
|
3540
|
+
0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
|
|
3541
|
+
0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
|
|
3542
|
+
0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
|
|
3543
|
+
0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
|
|
3544
|
+
0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
|
|
3545
|
+
0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
|
|
3546
|
+
0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
|
|
3547
|
+
0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
|
|
3548
|
+
0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
|
|
3549
|
+
0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
|
|
3550
|
+
0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
|
|
3551
|
+
};
|
|
2966
3552
|
|
|
2967
|
-
|
|
2968
|
-
|
|
2969
|
-
|
|
2970
|
-
|
|
2971
|
-
|
|
2972
|
-
|
|
2973
|
-
|
|
2974
|
-
|
|
3553
|
+
constexpr constant static uint64_t iq2xs_grid[512] = {
|
|
3554
|
+
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
|
3555
|
+
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
|
|
3556
|
+
0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
|
|
3557
|
+
0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
|
|
3558
|
+
0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
|
|
3559
|
+
0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
|
|
3560
|
+
0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
|
|
3561
|
+
0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
|
|
3562
|
+
0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
|
|
3563
|
+
0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
|
|
3564
|
+
0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
|
|
3565
|
+
0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
|
|
3566
|
+
0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
|
|
3567
|
+
0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
|
|
3568
|
+
0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
|
|
3569
|
+
0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
|
|
3570
|
+
0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
|
|
3571
|
+
0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
|
|
3572
|
+
0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
|
|
3573
|
+
0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
|
|
3574
|
+
0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
|
|
3575
|
+
0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
|
|
3576
|
+
0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
|
|
3577
|
+
0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
|
|
3578
|
+
0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
|
|
3579
|
+
0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
|
|
3580
|
+
0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
|
|
3581
|
+
0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
|
|
3582
|
+
0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
|
|
3583
|
+
0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
|
|
3584
|
+
0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
|
|
3585
|
+
0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
|
|
3586
|
+
0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
|
|
3587
|
+
0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
|
|
3588
|
+
0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
|
|
3589
|
+
0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
|
|
3590
|
+
0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
|
|
3591
|
+
0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
|
|
3592
|
+
0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
|
|
3593
|
+
0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
|
|
3594
|
+
0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
|
|
3595
|
+
0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
|
|
3596
|
+
0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
|
|
3597
|
+
0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
|
|
3598
|
+
0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
|
|
3599
|
+
0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
|
|
3600
|
+
0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
|
|
3601
|
+
0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
|
|
3602
|
+
0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
|
|
3603
|
+
0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
|
|
3604
|
+
0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
|
|
3605
|
+
0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
|
|
3606
|
+
0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
|
|
3607
|
+
0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
|
|
3608
|
+
0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
|
|
3609
|
+
0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
|
|
3610
|
+
0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
|
|
3611
|
+
0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
|
|
3612
|
+
0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
|
|
3613
|
+
0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
|
|
3614
|
+
0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
|
|
3615
|
+
0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
|
|
3616
|
+
0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
|
|
3617
|
+
0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
|
|
3618
|
+
0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
|
|
3619
|
+
0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
|
|
3620
|
+
0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
|
|
3621
|
+
0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
|
|
3622
|
+
0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
|
|
3623
|
+
0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
|
|
3624
|
+
0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
|
|
3625
|
+
0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
|
|
3626
|
+
0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
|
|
3627
|
+
0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
|
|
3628
|
+
0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
|
|
3629
|
+
0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
|
|
3630
|
+
0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
|
|
3631
|
+
0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
|
|
3632
|
+
0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
|
|
3633
|
+
0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
|
|
3634
|
+
0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
|
|
3635
|
+
0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
|
|
3636
|
+
0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
|
|
3637
|
+
0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
|
|
3638
|
+
0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
|
|
3639
|
+
0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
|
|
3640
|
+
0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
|
|
3641
|
+
0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
|
|
3642
|
+
0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
|
|
3643
|
+
0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
|
|
3644
|
+
0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
|
|
3645
|
+
0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
|
|
3646
|
+
0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
|
|
3647
|
+
0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
|
|
3648
|
+
0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
|
|
3649
|
+
0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
|
|
3650
|
+
0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
|
|
3651
|
+
0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
|
|
3652
|
+
0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
|
|
3653
|
+
0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
|
|
3654
|
+
0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
|
|
3655
|
+
0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
|
|
3656
|
+
0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
|
|
3657
|
+
0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
|
|
3658
|
+
0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
|
|
3659
|
+
0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
|
|
3660
|
+
0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
|
|
3661
|
+
0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
|
|
3662
|
+
0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
|
|
3663
|
+
0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
|
|
3664
|
+
0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
|
|
3665
|
+
0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
|
|
3666
|
+
0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
|
|
3667
|
+
0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
|
|
3668
|
+
0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
|
|
3669
|
+
0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
|
|
3670
|
+
0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
|
|
3671
|
+
0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
|
|
3672
|
+
0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
|
|
3673
|
+
0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
|
|
3674
|
+
0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
|
|
3675
|
+
0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
|
|
3676
|
+
0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
|
|
3677
|
+
0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
|
|
3678
|
+
0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
|
|
3679
|
+
0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
|
|
3680
|
+
0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
|
|
3681
|
+
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
|
3682
|
+
};
|
|
2975
3683
|
|
|
2976
|
-
|
|
2977
|
-
|
|
2978
|
-
|
|
2979
|
-
|
|
2980
|
-
|
|
3684
|
+
constexpr constant static uint8_t ksigns_iq2xs[128] = {
|
|
3685
|
+
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
|
3686
|
+
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
3687
|
+
160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
|
|
3688
|
+
48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
|
|
3689
|
+
192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
|
|
3690
|
+
80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
|
|
3691
|
+
96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
|
|
3692
|
+
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
|
|
3693
|
+
};
|
|
2981
3694
|
|
|
2982
|
-
|
|
2983
|
-
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
|
|
2984
|
-
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
|
2985
|
-
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
|
2986
|
-
const float d2 = d1 / 256.f;
|
|
2987
|
-
const float m = xb->m;
|
|
2988
|
-
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
|
2989
|
-
const ushort mask1 = mask0 << 8;
|
|
3695
|
+
constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
|
|
2990
3696
|
|
|
2991
|
-
|
|
2992
|
-
|
|
2993
|
-
|
|
2994
|
-
|
|
2995
|
-
|
|
3697
|
+
void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
3698
|
+
device const void * src0,
|
|
3699
|
+
device const float * src1,
|
|
3700
|
+
device float * dst,
|
|
3701
|
+
constant int64_t & ne00,
|
|
3702
|
+
constant int64_t & ne01,
|
|
3703
|
+
constant int64_t & ne02,
|
|
3704
|
+
constant int64_t & ne10,
|
|
3705
|
+
constant int64_t & ne12,
|
|
3706
|
+
constant int64_t & ne0,
|
|
3707
|
+
constant int64_t & ne1,
|
|
3708
|
+
constant uint & r2,
|
|
3709
|
+
constant uint & r3,
|
|
3710
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
3711
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3712
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3713
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
2996
3714
|
|
|
2997
|
-
|
|
2998
|
-
|
|
2999
|
-
|
|
3000
|
-
const
|
|
3001
|
-
const float md = -16.h * xb->d;
|
|
3002
|
-
const ushort mask = il ? 0x00F0 : 0x000F;
|
|
3715
|
+
const int nb = ne00/QK_K;
|
|
3716
|
+
const int r0 = tgpig.x;
|
|
3717
|
+
const int r1 = tgpig.y;
|
|
3718
|
+
const int im = tgpig.z;
|
|
3003
3719
|
|
|
3004
|
-
const
|
|
3720
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
3721
|
+
const int ib_row = first_row * nb;
|
|
3005
3722
|
|
|
3006
|
-
const
|
|
3723
|
+
const uint i12 = im%ne12;
|
|
3724
|
+
const uint i13 = im/ne12;
|
|
3007
3725
|
|
|
3008
|
-
const
|
|
3009
|
-
const int gh_bk = il ? 0 : 4;
|
|
3726
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
3010
3727
|
|
|
3011
|
-
|
|
3012
|
-
|
|
3013
|
-
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
|
3014
|
-
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
|
3728
|
+
device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
|
|
3729
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
3015
3730
|
|
|
3016
|
-
|
|
3017
|
-
|
|
3018
|
-
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
|
3731
|
+
float yl[32];
|
|
3732
|
+
float sumf[N_DST]={0.f}, all_sum;
|
|
3019
3733
|
|
|
3020
|
-
|
|
3021
|
-
|
|
3734
|
+
const int nb32 = nb * (QK_K / 32);
|
|
3735
|
+
|
|
3736
|
+
threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
|
|
3737
|
+
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
|
|
3738
|
+
{
|
|
3739
|
+
int nval = 4;
|
|
3740
|
+
int pos = (32*sgitg + tiisg)*nval;
|
|
3741
|
+
for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
|
|
3742
|
+
nval = 2;
|
|
3743
|
+
pos = (32*sgitg + tiisg)*nval;
|
|
3744
|
+
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
|
|
3745
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3022
3746
|
}
|
|
3023
|
-
}
|
|
3024
3747
|
|
|
3025
|
-
|
|
3026
|
-
|
|
3027
|
-
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
|
|
3028
|
-
const float d = xb->d;
|
|
3029
|
-
const float m = xb->m;
|
|
3030
|
-
const ushort mask = il ? 0x00F0 : 0x000F;
|
|
3748
|
+
#if QK_K == 256
|
|
3749
|
+
const int ix = tiisg;
|
|
3031
3750
|
|
|
3032
|
-
const
|
|
3751
|
+
device const float * y4 = y + 32 * ix;
|
|
3033
3752
|
|
|
3034
|
-
|
|
3753
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
|
3035
3754
|
|
|
3036
|
-
|
|
3037
|
-
|
|
3755
|
+
for (int i = 0; i < 32; ++i) {
|
|
3756
|
+
yl[i] = y4[i];
|
|
3757
|
+
}
|
|
3038
3758
|
|
|
3039
|
-
|
|
3040
|
-
|
|
3041
|
-
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
|
3042
|
-
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
|
3759
|
+
const int ibl = ib32 / (QK_K / 32);
|
|
3760
|
+
const int ib = ib32 % (QK_K / 32);
|
|
3043
3761
|
|
|
3044
|
-
|
|
3045
|
-
const
|
|
3046
|
-
const
|
|
3762
|
+
device const block_iq2_xxs * xr = x + ibl;
|
|
3763
|
+
device const uint16_t * q2 = xr->qs + 4 * ib;
|
|
3764
|
+
device const half * dh = &xr->d;
|
|
3047
3765
|
|
|
3048
|
-
|
|
3049
|
-
reg[i/2][2*(i%2)+1] = d * x1 + m;
|
|
3050
|
-
}
|
|
3051
|
-
}
|
|
3766
|
+
for (int row = 0; row < N_DST; row++) {
|
|
3052
3767
|
|
|
3053
|
-
|
|
3054
|
-
|
|
3055
|
-
|
|
3056
|
-
|
|
3768
|
+
const float db = dh[0];
|
|
3769
|
+
device const uint8_t * aux8 = (device const uint8_t *)q2;
|
|
3770
|
+
const uint32_t aux32 = q2[2] | (q2[3] << 16);
|
|
3771
|
+
const float d = db * (0.5f + (aux32 >> 28));
|
|
3057
3772
|
|
|
3058
|
-
|
|
3059
|
-
|
|
3060
|
-
|
|
3061
|
-
|
|
3773
|
+
float sum = 0;
|
|
3774
|
+
for (int l = 0; l < 4; ++l) {
|
|
3775
|
+
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
|
|
3776
|
+
const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
|
|
3777
|
+
for (int j = 0; j < 8; ++j) {
|
|
3778
|
+
sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
|
3779
|
+
}
|
|
3780
|
+
}
|
|
3781
|
+
sumf[row] += d * sum;
|
|
3062
3782
|
|
|
3063
|
-
|
|
3064
|
-
|
|
3065
|
-
|
|
3066
|
-
const half min = xb->dmin;
|
|
3067
|
-
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
|
3068
|
-
half dl, ml;
|
|
3069
|
-
uint8_t sc = xb->scales[il];
|
|
3783
|
+
dh += nb*sizeof(block_iq2_xxs)/2;
|
|
3784
|
+
q2 += nb*sizeof(block_iq2_xxs)/2;
|
|
3785
|
+
}
|
|
3070
3786
|
|
|
3071
|
-
|
|
3072
|
-
|
|
3073
|
-
|
|
3787
|
+
y4 += 32 * 32;
|
|
3788
|
+
}
|
|
3789
|
+
#else
|
|
3790
|
+
// TODO
|
|
3074
3791
|
#endif
|
|
3075
|
-
|
|
3076
|
-
|
|
3077
|
-
|
|
3078
|
-
|
|
3079
|
-
|
|
3792
|
+
|
|
3793
|
+
for (int row = 0; row < N_DST; ++row) {
|
|
3794
|
+
all_sum = simd_sum(sumf[row]);
|
|
3795
|
+
if (tiisg == 0) {
|
|
3796
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
|
|
3797
|
+
}
|
|
3080
3798
|
}
|
|
3081
3799
|
}
|
|
3082
3800
|
|
|
3083
|
-
|
|
3084
|
-
void
|
|
3085
|
-
|
|
3086
|
-
|
|
3087
|
-
|
|
3088
|
-
|
|
3089
|
-
|
|
3090
|
-
|
|
3091
|
-
|
|
3092
|
-
|
|
3093
|
-
|
|
3094
|
-
|
|
3801
|
+
[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
|
|
3802
|
+
kernel void kernel_mul_mv_iq2_xxs_f32(
|
|
3803
|
+
device const void * src0,
|
|
3804
|
+
device const float * src1,
|
|
3805
|
+
device float * dst,
|
|
3806
|
+
constant int64_t & ne00,
|
|
3807
|
+
constant int64_t & ne01,
|
|
3808
|
+
constant int64_t & ne02,
|
|
3809
|
+
constant uint64_t & nb00,
|
|
3810
|
+
constant uint64_t & nb01,
|
|
3811
|
+
constant uint64_t & nb02,
|
|
3812
|
+
constant int64_t & ne10,
|
|
3813
|
+
constant int64_t & ne11,
|
|
3814
|
+
constant int64_t & ne12,
|
|
3815
|
+
constant uint64_t & nb10,
|
|
3816
|
+
constant uint64_t & nb11,
|
|
3817
|
+
constant uint64_t & nb12,
|
|
3818
|
+
constant int64_t & ne0,
|
|
3819
|
+
constant int64_t & ne1,
|
|
3820
|
+
constant uint & r2,
|
|
3821
|
+
constant uint & r3,
|
|
3822
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
3823
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3824
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3825
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3826
|
+
|
|
3827
|
+
kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
3828
|
+
}
|
|
3829
|
+
|
|
3830
|
+
void kernel_mul_mv_iq2_xs_f32_impl(
|
|
3831
|
+
device const void * src0,
|
|
3832
|
+
device const float * src1,
|
|
3833
|
+
device float * dst,
|
|
3834
|
+
constant int64_t & ne00,
|
|
3835
|
+
constant int64_t & ne01,
|
|
3836
|
+
constant int64_t & ne02,
|
|
3837
|
+
constant int64_t & ne10,
|
|
3838
|
+
constant int64_t & ne12,
|
|
3839
|
+
constant int64_t & ne0,
|
|
3840
|
+
constant int64_t & ne1,
|
|
3841
|
+
constant uint & r2,
|
|
3842
|
+
constant uint & r3,
|
|
3843
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
3844
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3845
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3846
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3847
|
+
|
|
3848
|
+
const int nb = ne00/QK_K;
|
|
3849
|
+
const int r0 = tgpig.x;
|
|
3850
|
+
const int r1 = tgpig.y;
|
|
3851
|
+
const int im = tgpig.z;
|
|
3852
|
+
|
|
3853
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
3854
|
+
const int ib_row = first_row * nb;
|
|
3855
|
+
|
|
3856
|
+
const uint i12 = im%ne12;
|
|
3857
|
+
const uint i13 = im/ne12;
|
|
3858
|
+
|
|
3859
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
3860
|
+
|
|
3861
|
+
device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
|
|
3862
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
3863
|
+
|
|
3864
|
+
float yl[32];
|
|
3865
|
+
float sumf[N_DST]={0.f}, all_sum;
|
|
3866
|
+
|
|
3867
|
+
const int nb32 = nb * (QK_K / 32);
|
|
3868
|
+
|
|
3869
|
+
threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
|
|
3870
|
+
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512);
|
|
3871
|
+
{
|
|
3872
|
+
int nval = 8;
|
|
3873
|
+
int pos = (32*sgitg + tiisg)*nval;
|
|
3874
|
+
for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
|
|
3875
|
+
nval = 2;
|
|
3876
|
+
pos = (32*sgitg + tiisg)*nval;
|
|
3877
|
+
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
|
|
3878
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3879
|
+
}
|
|
3880
|
+
|
|
3881
|
+
#if QK_K == 256
|
|
3882
|
+
const int ix = tiisg;
|
|
3883
|
+
|
|
3884
|
+
device const float * y4 = y + 32 * ix;
|
|
3885
|
+
|
|
3886
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
|
3887
|
+
|
|
3888
|
+
for (int i = 0; i < 32; ++i) {
|
|
3889
|
+
yl[i] = y4[i];
|
|
3890
|
+
}
|
|
3891
|
+
|
|
3892
|
+
const int ibl = ib32 / (QK_K / 32);
|
|
3893
|
+
const int ib = ib32 % (QK_K / 32);
|
|
3894
|
+
|
|
3895
|
+
device const block_iq2_xs * xr = x + ibl;
|
|
3896
|
+
device const uint16_t * q2 = xr->qs + 4 * ib;
|
|
3897
|
+
device const uint8_t * sc = xr->scales + ib;
|
|
3898
|
+
device const half * dh = &xr->d;
|
|
3899
|
+
|
|
3900
|
+
for (int row = 0; row < N_DST; row++) {
|
|
3901
|
+
|
|
3902
|
+
const float db = dh[0];
|
|
3903
|
+
const uint8_t ls1 = sc[0] & 0xf;
|
|
3904
|
+
const uint8_t ls2 = sc[0] >> 4;
|
|
3905
|
+
const float d1 = db * (0.5f + ls1);
|
|
3906
|
+
const float d2 = db * (0.5f + ls2);
|
|
3907
|
+
|
|
3908
|
+
float sum1 = 0, sum2 = 0;
|
|
3909
|
+
for (int l = 0; l < 2; ++l) {
|
|
3910
|
+
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
|
|
3911
|
+
const uint8_t signs = shared_signs[(q2[l] >> 9)];
|
|
3912
|
+
for (int j = 0; j < 8; ++j) {
|
|
3913
|
+
sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
|
3914
|
+
}
|
|
3915
|
+
}
|
|
3916
|
+
for (int l = 2; l < 4; ++l) {
|
|
3917
|
+
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
|
|
3918
|
+
const uint8_t signs = shared_signs[(q2[l] >> 9)];
|
|
3919
|
+
for (int j = 0; j < 8; ++j) {
|
|
3920
|
+
sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
|
3921
|
+
}
|
|
3922
|
+
}
|
|
3923
|
+
sumf[row] += d1 * sum1 + d2 * sum2;
|
|
3924
|
+
|
|
3925
|
+
dh += nb*sizeof(block_iq2_xs)/2;
|
|
3926
|
+
q2 += nb*sizeof(block_iq2_xs)/2;
|
|
3927
|
+
sc += nb*sizeof(block_iq2_xs);
|
|
3928
|
+
}
|
|
3929
|
+
|
|
3930
|
+
y4 += 32 * 32;
|
|
3931
|
+
}
|
|
3932
|
+
#else
|
|
3933
|
+
// TODO
|
|
3934
|
+
#endif
|
|
3935
|
+
|
|
3936
|
+
for (int row = 0; row < N_DST; ++row) {
|
|
3937
|
+
all_sum = simd_sum(sumf[row]);
|
|
3938
|
+
if (tiisg == 0) {
|
|
3939
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
|
|
3940
|
+
}
|
|
3941
|
+
}
|
|
3942
|
+
}
|
|
3943
|
+
|
|
3944
|
+
[[host_name("kernel_mul_mv_iq2_xs_f32")]]
|
|
3945
|
+
kernel void kernel_mul_mv_iq2_xs_f32(
|
|
3946
|
+
device const void * src0,
|
|
3947
|
+
device const float * src1,
|
|
3948
|
+
device float * dst,
|
|
3949
|
+
constant int64_t & ne00,
|
|
3950
|
+
constant int64_t & ne01,
|
|
3951
|
+
constant int64_t & ne02,
|
|
3952
|
+
constant uint64_t & nb00,
|
|
3953
|
+
constant uint64_t & nb01,
|
|
3954
|
+
constant uint64_t & nb02,
|
|
3955
|
+
constant int64_t & ne10,
|
|
3956
|
+
constant int64_t & ne11,
|
|
3957
|
+
constant int64_t & ne12,
|
|
3958
|
+
constant uint64_t & nb10,
|
|
3959
|
+
constant uint64_t & nb11,
|
|
3960
|
+
constant uint64_t & nb12,
|
|
3961
|
+
constant int64_t & ne0,
|
|
3962
|
+
constant int64_t & ne1,
|
|
3963
|
+
constant uint & r2,
|
|
3964
|
+
constant uint & r3,
|
|
3965
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
3966
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3967
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3968
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3969
|
+
|
|
3970
|
+
kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
3971
|
+
}
|
|
3972
|
+
|
|
3973
|
+
//============================= templates and their specializations =============================
|
|
3974
|
+
|
|
3975
|
+
// NOTE: nothing is dequantized here - this overload only exists so that plain
//       f32 data fits the same dequantize_func template slot as the quantized types
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
    // il is not used by this overload
    const float4x4 tile = *src;

    // element-wise copy so that type4x4 may differ from float4x4 (each element is converted)
    for (int vec = 0; vec < 4; vec++) {
        for (int lane = 0; lane < 4; lane++) {
            reg[vec][lane] = tile[vec][lane];
        }
    }
}
|
|
3983
|
+
|
|
3984
|
+
// f16 counterpart of dequantize_f32: copies one half4x4 tile into reg element by
// element (no dequantization; il is not used by this overload).
template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
    const half4x4 tile = *src;

    for (int vec = 0; vec < 4; vec++) {
        for (int lane = 0; lane < 4; lane++) {
            reg[vec][lane] = tile[vec][lane];
        }
    }
}
|
|
3991
|
+
|
|
3992
|
+
// Expands 16 values of a q4_0 block into reg.
//   il == 0 : the low  nibble of every quant byte is used
//   il != 0 : the high nibble of every quant byte is used
// The nibbles are never shifted down; instead the scale is pre-divided
// (by 16 for the high nibble, by a further 256 for the upper byte of each 16-bit word).
template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
    // quants are read as 16-bit words, starting one uint16_t past the start of the block
    device const uint16_t * words = ((device const uint16_t *)xb + 1);

    const float scale_lo = il ? (xb->d / 16.h) : xb->d; // scale for the low  byte of a word
    const float scale_hi = scale_lo / 256.f;            // scale for the high byte of a word
    const float offset   = -8.h * xb->d;

    const ushort mask_lo = il ? 0x00F0 : 0x000F;
    const ushort mask_hi = mask_lo << 8;

    // each 16-bit word yields two outputs; two words fill one 4-element vector of reg
    for (int vec = 0; vec < 4; vec++) {
        for (int pair = 0; pair < 2; pair++) {
            const uint16_t w = words[2*vec + pair];

            reg[vec][2*pair + 0] = scale_lo * (w & mask_lo) + offset;
            reg[vec][2*pair + 1] = scale_hi * (w & mask_hi) + offset;
        }
    }
}
|
|
4006
|
+
|
|
4007
|
+
// Expands 16 values of a q4_1 block into reg: value = nibble * d + m.
//   il == 0 : low nibbles, il != 0 : high nibbles.
// As in dequantize_q4_0, the masked bits stay in place and the scale is pre-divided instead.
template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
    // quants are read as 16-bit words, starting two uint16_t past the start of the block
    device const uint16_t * words = ((device const uint16_t *)xb + 2);

    const float scale_lo = il ? (xb->d / 16.h) : xb->d; // scale for the low  byte of a word
    const float scale_hi = scale_lo / 256.f;            // scale for the high byte of a word
    const float base     = xb->m;

    const ushort mask_lo = il ? 0x00F0 : 0x000F;
    const ushort mask_hi = mask_lo << 8;

    for (int vec = 0; vec < 4; vec++) {
        for (int pair = 0; pair < 2; pair++) {
            const uint16_t w = words[2*vec + pair];

            reg[vec][2*pair + 0] = ((w & mask_lo) * scale_lo) + base;
            reg[vec][2*pair + 1] = ((w & mask_hi) * scale_hi) + base;
        }
    }
}
|
|
4021
|
+
|
|
4022
|
+
// Expands 16 values of a q5_0 block into reg: value = d * q5 - 16 * d, where q5 is
// a 4-bit nibble from qs combined with a 5th bit taken from the 32-bit qh field.
//   il == 0 : low nibbles  + qh bits  0..15
//   il != 0 : high nibbles + qh bits 12..27 (as selected by the shifts below)
template <typename type4x4>
void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
    // quants are read as 16-bit words, starting three uint16_t past the start of the block
    device const uint16_t * words = ((device const uint16_t *)xb + 3);

    const float scale  = xb->d;
    const float offset = -16.h * xb->d;

    const ushort nibble_mask = il ? 0x00F0 : 0x000F;
    const int    nibble_shr  = il ? 4 : 0;

    const uint32_t high_bits = *((device const uint32_t *)xb->qh);

    // the selected qh bit is shifted right by (hb_shr + index) and then left by hb_shl
    // so that it lands on bit 4 (0x10)
    const int hb_shr = il ? 12 : 0;
    const int hb_shl = il ?  0 : 4;

    for (int vec = 0; vec < 4; vec++) {
        for (int pair = 0; pair < 2; pair++) {
            const int i = 2*vec + pair;

            // extract the 5-th bits for the two values held in this word
            const uint8_t bit5_a = ((high_bits >> (hb_shr + 2*i  )) << hb_shl) & 0x10;
            const uint8_t bit5_b = ((high_bits >> (hb_shr + 2*i+1)) << hb_shl) & 0x10;

            // combine the 4 bits from the quant word with the 5th bit
            const int32_t qa = ((((words[i]     ) & nibble_mask) >> nibble_shr) | bit5_a);
            const int32_t qb = ((((words[i] >> 8) & nibble_mask) >> nibble_shr) | bit5_b);

            reg[vec][2*pair + 0] = scale * qa + offset;
            reg[vec][2*pair + 1] = scale * qb + offset;
        }
    }
}
|
|
4049
|
+
|
|
4050
|
+
// Expands 16 values of a q5_1 block into reg: value = d * q5 + m, where q5 is
// a 4-bit nibble from qs combined with a 5th bit taken from the 32-bit qh field.
//   il == 0 : low nibbles, il != 0 : high nibbles.
template <typename type4x4>
void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
    // quants are read as 16-bit words, starting four uint16_t past the start of the block
    device const uint16_t * words = ((device const uint16_t *)xb + 4);

    const float scale = xb->d;
    const float base  = xb->m;

    const ushort nibble_mask = il ? 0x00F0 : 0x000F;
    const int    nibble_shr  = il ? 4 : 0;

    const uint32_t high_bits = *((device const uint32_t *)xb->qh);

    // the selected qh bit is shifted right by (hb_shr + index) and then left by hb_shl
    // so that it lands on bit 4 (0x10)
    const int hb_shr = il ? 12 : 0;
    const int hb_shl = il ?  0 : 4;

    for (int vec = 0; vec < 4; vec++) {
        for (int pair = 0; pair < 2; pair++) {
            const int i = 2*vec + pair;

            // extract the 5-th bits for the two values held in this word
            const uint8_t bit5_a = ((high_bits >> (hb_shr + 2*i  )) << hb_shl) & 0x10;
            const uint8_t bit5_b = ((high_bits >> (hb_shr + 2*i+1)) << hb_shl) & 0x10;

            // combine the 4 bits from the quant word with the 5th bit
            const int32_t qa = ((((words[i]     ) & nibble_mask) >> nibble_shr) | bit5_a);
            const int32_t qb = ((((words[i] >> 8) & nibble_mask) >> nibble_shr) | bit5_b);

            reg[vec][2*pair + 0] = scale * qa + base;
            reg[vec][2*pair + 1] = scale * qb + base;
        }
    }
}
|
|
4077
|
+
|
|
4078
|
+
// Expands 16 values of a q8_0 block into reg: value = q * d.
// il selects which group of 16 consecutive int8 quants is read (offset 16*il).
template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
    device const int8_t * quants = ((device const int8_t *)xb->qs) + 16*il;
    const half scale = xb->d;

    for (int vec = 0; vec < 4; vec++) {
        for (int lane = 0; lane < 4; lane++) {
            reg[vec][lane] = (quants[4*vec + lane] * scale);
        }
    }
}
|
|
4087
|
+
|
|
4088
|
+
// Expands 16 values of a q2_K block into reg: value = dl * (q & mask) - ml.
// The per-sub-block scale byte packs the scale in its low nibble and the min in its
// high nibble. The 2-bit quants are masked in place (not shifted), so the scale is
// multiplied by coef = 1, 1/4, 1/16 or 1/64 to compensate for the mask position.
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
    const float super_scale = xb->d;
    const float super_min   = xb->dmin;

    device const uint8_t * quants = (device const uint8_t *)xb->qs;

    // must be read before il is remapped below
    const uint8_t packed_scale = xb->scales[il];

#if QK_K == 256
    quants = quants + 32*(il/8) + 16*(il&1);
    il = (il/2)%4;
#endif

    // pick the 2-bit field (and the matching compensation factor) selected by il
    half  coef;
    uchar mask;
    if (il > 2) {
        coef = 1/64.h;
        mask = 192;
    } else if (il > 1) {
        coef = 1/16.h;
        mask = 48;
    } else if (il > 0) {
        coef = 1/4.h;
        mask = 12;
    } else {
        coef = 1.h;
        mask = 3;
    }

    const float dl = super_scale * (packed_scale & 0xF) * coef;
    const float ml = super_min * (packed_scale >> 4);

    for (int vec = 0; vec < 4; ++vec) {
        for (int lane = 0; lane < 4; ++lane) {
            reg[vec][lane] = dl * (quants[4*vec + lane] & mask) - ml;
        }
    }
}
|
|
4107
|
+
|
|
4108
|
+
template <typename type4x4>
|
|
4109
|
+
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
|
|
4110
|
+
const half d_all = xb->d;
|
|
4111
|
+
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
|
4112
|
+
device const uint8_t * h = (device const uint8_t *)xb->hmask;
|
|
4113
|
+
device const int8_t * scales = (device const int8_t *)xb->scales;
|
|
4114
|
+
|
|
4115
|
+
#if QK_K == 256
|
|
4116
|
+
q = q + 32 * (il/8) + 16 * (il&1);
|
|
4117
|
+
h = h + 16 * (il&1);
|
|
4118
|
+
uint8_t m = 1 << (il/2);
|
|
4119
|
+
uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
|
|
3095
4120
|
((il/4)>0 ? 12 : 3);
|
|
3096
4121
|
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
|
|
3097
4122
|
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
|
|
3098
4123
|
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
|
|
3099
4124
|
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
|
|
3100
|
-
|
|
3101
|
-
const
|
|
4125
|
+
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
|
|
4126
|
+
const float ml = 4.f * dl;
|
|
3102
4127
|
|
|
3103
4128
|
il = (il/2) & 3;
|
|
3104
4129
|
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
|
@@ -3135,10 +4160,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
|
|
|
3135
4160
|
q = q + (il/4) * 32 + 16 * (il&1);
|
|
3136
4161
|
il = il & 3;
|
|
3137
4162
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
|
3138
|
-
const
|
|
3139
|
-
const
|
|
3140
|
-
const
|
|
3141
|
-
const
|
|
4163
|
+
const float d = il < 2 ? xb->d : xb->d / 16.h;
|
|
4164
|
+
const float min = xb->dmin;
|
|
4165
|
+
const float dl = d * sc[0];
|
|
4166
|
+
const float ml = min * sc[1];
|
|
3142
4167
|
#else
|
|
3143
4168
|
q = q + 16 * (il&1);
|
|
3144
4169
|
device const uint8_t * s = xb->scales;
|
|
@@ -3165,13 +4190,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
|
3165
4190
|
uint8_t ul = 1 << (il/2);
|
|
3166
4191
|
il = il & 3;
|
|
3167
4192
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
|
3168
|
-
const
|
|
3169
|
-
const
|
|
3170
|
-
const
|
|
3171
|
-
const
|
|
4193
|
+
const float d = il < 2 ? xb->d : xb->d / 16.f;
|
|
4194
|
+
const float min = xb->dmin;
|
|
4195
|
+
const float dl = d * sc[0];
|
|
4196
|
+
const float ml = min * sc[1];
|
|
3172
4197
|
|
|
3173
|
-
const ushort mask
|
|
3174
|
-
const
|
|
4198
|
+
const ushort mask = il<2 ? 0x0F : 0xF0;
|
|
4199
|
+
const float qh_val = il<2 ? 16.f : 256.f;
|
|
3175
4200
|
for (int i = 0; i < 16; ++i) {
|
|
3176
4201
|
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
|
3177
4202
|
}
|
|
@@ -3198,17 +4223,17 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
|
3198
4223
|
#if QK_K == 256
|
|
3199
4224
|
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
|
3200
4225
|
qh = qh + 32*(il/8) + 16*(il&1);
|
|
3201
|
-
|
|
4226
|
+
float sc = scales[(il%2) + 2 * ((il/2))];
|
|
3202
4227
|
il = (il/2) & 3;
|
|
3203
4228
|
#else
|
|
3204
4229
|
ql = ql + 16 * (il&1);
|
|
3205
|
-
|
|
4230
|
+
float sc = scales[il];
|
|
3206
4231
|
#endif
|
|
3207
4232
|
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
|
3208
4233
|
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
|
3209
|
-
const
|
|
3210
|
-
const
|
|
3211
|
-
const
|
|
4234
|
+
const float coef = il>1 ? 1.f/16.f : 1.f;
|
|
4235
|
+
const float ml = d_all * sc * 32.f;
|
|
4236
|
+
const float dl = d_all * sc * coef;
|
|
3212
4237
|
for (int i = 0; i < 16; ++i) {
|
|
3213
4238
|
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
|
|
3214
4239
|
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
|
|
@@ -3216,28 +4241,171 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
|
3216
4241
|
}
|
|
3217
4242
|
}
|
|
3218
4243
|
|
|
3219
|
-
template<typename
|
|
3220
|
-
|
|
4244
|
+
// Expands 16 values of an iq2_xxs block into reg.
// Each output is dl * grid[i], negated when the matching bit of the sign byte is set.
template <typename type4x4>
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float super_scale = xb->d;
    const int   block32     = il/2;

    // sub = 0 or 1: 0 processes the first 16 quants in a block of 32, 1 the second 16
    const short sub = il%2;

    // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
    device const uint16_t * q2 = xb->qs + 4*block32;
    const uint32_t grid_word = q2[0] | (q2[1] << 16); // four 8-bit grid indices
    const uint32_t sign_word = q2[2] | (q2[3] << 16); // sign selectors + scale in the top 4 bits
    thread const uint8_t * grid_idx = (thread const uint8_t *)&grid_word;

    const float dl = super_scale * (0.5f + (sign_word >> 28)) * 0.25f;

    // two groups of 8 values; group g fills reg[2*g] and reg[2*g + 1]
    for (int g = 0; g < 2; ++g) {
        constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + grid_idx[2*sub + g]);
        const uint8_t signs = ksigns_iq2xs[(sign_word >> (14*sub + 7*g)) & 127];

        for (int i = 0; i < 8; ++i) {
            reg[2*g + i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
        }
    }
}
|
|
4268
|
+
|
|
4269
|
+
// Expands 16 values of an iq2_xs block into reg.
// Every 16-bit quant word holds a 9-bit grid index (low bits) and a 7-bit sign
// selector (high bits); each output is dl * grid[i], negated when the sign bit is set.
template <typename type4x4>
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float super_scale = xb->d;
    const int   block32     = il/2;

    // sub = 0 or 1: 0 processes the first 16 quants in a block of 32, 1 the second 16
    const short sub = il%2;

    device const uint16_t * q2 = xb->qs + 4*block32;

    // 4-bit scale: low nibble of scales[block32] for sub == 0, high nibble for sub == 1
    const float dl = super_scale * (0.5f + ((xb->scales[block32] >> 4*sub) & 0xf)) * 0.25f;

    // two groups of 8 values; group g fills reg[2*g] and reg[2*g + 1]
    for (int g = 0; g < 2; ++g) {
        const uint16_t code = q2[2*sub + g];

        constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (code & 511));
        const uint8_t signs = ksigns_iq2xs[code >> 9];

        for (int i = 0; i < 8; ++i) {
            reg[2*g + i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
        }
    }
}
|
|
4289
|
+
|
|
4290
|
+
// Gathers rows of a (possibly quantized) src0 into a float dst.
// One threadgroup handles one index of src1: the int32 stored at src1[i11][i10]
// is the src0 row to fetch, and the dequantized row is written to dst[i11][i10].
// Threads of the group stride over the row in tiles of 16 values (one float4x4 each);
// nl is the number of 16-value tiles per block_q.
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows(
        device const  void * src0,
        device const  char * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        uint3                tgpig[[threadgroup_position_in_grid]],
        uint                 tiitg[[thread_index_in_threadgroup]],
        uint3                tptg [[threads_per_threadgroup]]) {
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;

    // row index to gather, read from src1 at (i10, i11)
    const int64_t r = *((device const int32_t *) (src1 + i11*nb11 + i10*nb10));

    // the second index of src1 also selects the src0 slice
    const int64_t i02 = i11;

    device const block_q * src_row = (device const block_q *) ((device const char *) src0 + r*nb01 + i02*nb02);
    device float4x4      * dst_row = (device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1);

    const int64_t n_tiles = ne00/16;

    for (int64_t ind = tiitg; ind < n_tiles; ind += tptg.x) {
        float4x4 tile;
        dequantize_func(src_row + ind/nl, ind%nl, tile);
        *(dst_row + ind) = tile;
    }
}
|
|
4323
|
+
|
|
4324
|
+
// Gathers f32 rows: for the index stored at src1[i11][i10], copies ne00 floats from
// row r of src0 (slice i02 = i11) to dst[i11][i10]. One threadgroup per index; the
// threads of the group stride over the row elements.
kernel void kernel_get_rows_f32(
        device const  void * src0,
        device const  char * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        uint3                tgpig[[threadgroup_position_in_grid]],
        uint                 tiitg[[thread_index_in_threadgroup]],
        uint3                tptg [[threads_per_threadgroup]]) {
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;

    // row index to gather, read from src1 at (i10, i11)
    const int64_t r = *((device const int32_t *) (src1 + i11*nb11 + i10*nb10));

    const int64_t i02 = i11;

    device const float * src_row = (device const float *) ((device const char *) src0 + r*nb01 + i02*nb02);
    device       float * dst_row = (device       float *) ((device       char *) dst  + i11*nb2 + i10*nb1);

    // 64-bit index: ne00 is int64_t, and this matches the templated kernel_get_rows
    for (int64_t ind = tiitg; ind < ne00; ind += tptg.x) {
        dst_row[ind] = src_row[ind];
    }
}
|
|
4351
|
+
|
|
4352
|
+
// Gathers f16 rows into a float dst: for the index stored at src1[i11][i10], converts
// ne00 halfs from row r of src0 (slice i02 = i11) and writes them to dst[i11][i10].
// One threadgroup per index; the threads of the group stride over the row elements.
kernel void kernel_get_rows_f16(
        device const  void * src0,
        device const  char * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        uint3                tgpig[[threadgroup_position_in_grid]],
        uint                 tiitg[[thread_index_in_threadgroup]],
        uint3                tptg [[threads_per_threadgroup]]) {
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;

    // row index to gather, read from src1 at (i10, i11)
    const int64_t r = *((device const int32_t *) (src1 + i11*nb11 + i10*nb10));

    const int64_t i02 = i11;

    device const half * src_row = (device const half *) ((device const char *) src0 + r*nb01 + i02*nb02);
    device      float * dst_row = (device      float *) ((device       char *) dst  + i11*nb2 + i10*nb1);

    // 64-bit index: ne00 is int64_t, and this matches the templated kernel_get_rows
    for (int64_t ind = tiitg; ind < ne00; ind += tptg.x) {
        dst_row[ind] = src_row[ind]; // half -> float conversion on store
    }
}
|
|
4379
|
+
|
|
4380
|
+
// Gathers int32 rows: for the index stored at src1[i11][i10], copies ne00 int32 values
// from row r of src0 (slice i02 = i11) to dst[i11][i10]. One threadgroup per index;
// the threads of the group stride over the row elements.
kernel void kernel_get_rows_i32(
        device const  void * src0,
        device const  char * src1,
        device     int32_t * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        uint3                tgpig[[threadgroup_position_in_grid]],
        uint                 tiitg[[thread_index_in_threadgroup]],
        uint3                tptg [[threads_per_threadgroup]]) {
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;

    // row index to gather, read from src1 at (i10, i11)
    const int64_t r = *((device const int32_t *) (src1 + i11*nb11 + i10*nb10));

    const int64_t i02 = i11;

    device const int32_t * src_row = (device const int32_t *) ((device const char *) src0 + r*nb01 + i02*nb02);
    device       int32_t * dst_row = (device       int32_t *) ((device       char *) dst  + i11*nb2 + i10*nb1);

    // 64-bit index: ne00 is int64_t, and this matches the templated kernel_get_rows
    for (int64_t ind = tiitg; ind < ne00; ind += tptg.x) {
        dst_row[ind] = src_row[ind];
    }
}
|
|
3240
4407
|
|
|
4408
|
+
|
|
3241
4409
|
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
|
3242
4410
|
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
|
3243
4411
|
#define BLOCK_SIZE_K 32
|
|
@@ -3256,12 +4424,12 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
3256
4424
|
device float * dst,
|
|
3257
4425
|
constant int64_t & ne00,
|
|
3258
4426
|
constant int64_t & ne02,
|
|
3259
|
-
constant
|
|
3260
|
-
constant
|
|
4427
|
+
constant uint64_t & nb01,
|
|
4428
|
+
constant uint64_t & nb02,
|
|
3261
4429
|
constant int64_t & ne12,
|
|
3262
|
-
constant
|
|
3263
|
-
constant
|
|
3264
|
-
constant
|
|
4430
|
+
constant uint64_t & nb10,
|
|
4431
|
+
constant uint64_t & nb11,
|
|
4432
|
+
constant uint64_t & nb12,
|
|
3265
4433
|
constant int64_t & ne0,
|
|
3266
4434
|
constant int64_t & ne1,
|
|
3267
4435
|
constant uint & r2,
|
|
@@ -3382,18 +4550,143 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
3382
4550
|
}
|
|
3383
4551
|
}
|
|
3384
4552
|
|
|
4553
|
+
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
//
// Computes one BLOCK_SIZE_M x BLOCK_SIZE_N output tile per threadgroup:
//   - src0 rows are dequantized through dequantize_func into threadgroup memory (sa, half)
//   - src1 rows are selected through src1ids and copied into threadgroup memory (sb, float)
//   - partial products are accumulated with simdgroup 8x8 matrix operations
//   - results are staged in shared_memory and scattered to the dst rows named by src1ids
// ne1 is passed by value: it is the number of valid entries in src1ids.
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
void kernel_mul_mm_id_impl(
        device const  uchar * src0,
        device const  uchar * src1,
        thread        short * src1ids,
        device        float * dst,
        constant    int64_t & ne00,
        constant    int64_t & ne02,
        constant   uint64_t & nb01,
        constant   uint64_t & nb02,
        constant    int64_t & ne12,
        constant   uint64_t & nb10,
        constant   uint64_t & nb11,
        constant   uint64_t & nb12,
        constant    int64_t & ne0,
                    int64_t   ne1,
        constant       uint & r2,
        constant       uint & r3,
        threadgroup   uchar * shared_memory,
        uint3                 tgpig[[threadgroup_position_in_grid]],
        uint                  tiitg[[thread_index_in_threadgroup]],
        uint                  sgitg[[simdgroup_index_in_threadgroup]]) {

    // sa starts at byte 0 of shared_memory, sb at byte 4096
    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);

    const uint r0 = tgpig.y;
    const uint r1 = tgpig.x;
    const uint im = tgpig.z;

    // nothing to do for tiles that start past the last selected src1 row
    if (r1 * BLOCK_SIZE_N >= ne1) return;

    // if this block is of 64x32 shape or smaller
    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;

    // a thread shouldn't load data outside of the matrix
    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;

    simdgroup_half8x8  ma[4];
    simdgroup_float8x8 mb[2];
    simdgroup_float8x8 c_res[8];
    for (int i = 0; i < 8; i++){
        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
    }

    short il = (tiitg % THREAD_PER_ROW);

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    // byte offset into src0; kept in 64 bits because the strides are uint64_t and the
    // product can exceed the range of a 32-bit uint for large tensors
    uint64_t offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
    ushort   offset1 = il/nl;

    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
    device const float   * y = (device const float   *)(src1
        + nb12 * im
        + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));

    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
        // load data and store to threadgroup memory
        half4x4 temp_a;
        dequantize_func(x, il, temp_a);
        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (int i = 0; i < 16; i++) {
            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
        }

        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);

        // advance to the next 16-value tile; move to the next block_q once il wraps
        il = (il + 2 < nl) ? il + 2 : il % 2;
        x  = (il < 2) ? x + (2+nl-1)/nl : x;
        y += BLOCK_SIZE_K;

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // load matrices from threadgroup memory and conduct outer products
        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));

        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
            for (int i = 0; i < 4; i++) {
                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
            }
            simdgroup_barrier(mem_flags::mem_none);
            for (int i = 0; i < 2; i++) {
                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
            }

            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;

            for (int i = 0; i < 8; i++){
                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
            }
        }
    }

    {
        // stage the accumulators in threadgroup memory, then scatter them to dst:
        // the destination row of each column comes from src1ids
        threadgroup_barrier(mem_flags::mem_threadgroup);
        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
        for (int i = 0; i < 8; i++) {
            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
        if (sgitg == 0) {
            for (int i = 0; i < n_rows; i++) {
                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                    *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                }
            }
        }
    }
}
|
|
4677
|
+
|
|
3385
4678
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
3386
4679
|
kernel void kernel_mul_mm(device const uchar * src0,
|
|
3387
4680
|
device const uchar * src1,
|
|
3388
4681
|
device float * dst,
|
|
3389
4682
|
constant int64_t & ne00,
|
|
3390
4683
|
constant int64_t & ne02,
|
|
3391
|
-
constant
|
|
3392
|
-
constant
|
|
4684
|
+
constant uint64_t & nb01,
|
|
4685
|
+
constant uint64_t & nb02,
|
|
3393
4686
|
constant int64_t & ne12,
|
|
3394
|
-
constant
|
|
3395
|
-
constant
|
|
3396
|
-
constant
|
|
4687
|
+
constant uint64_t & nb10,
|
|
4688
|
+
constant uint64_t & nb11,
|
|
4689
|
+
constant uint64_t & nb12,
|
|
3397
4690
|
constant int64_t & ne0,
|
|
3398
4691
|
constant int64_t & ne1,
|
|
3399
4692
|
constant uint & r2,
|
|
@@ -3426,19 +4719,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
|
3426
4719
|
|
|
3427
4720
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
3428
4721
|
kernel void kernel_mul_mm_id(
|
|
3429
|
-
device const
|
|
4722
|
+
device const uchar * ids,
|
|
3430
4723
|
device const uchar * src1,
|
|
3431
4724
|
device float * dst,
|
|
4725
|
+
constant uint64_t & nbi1,
|
|
3432
4726
|
constant int64_t & ne00,
|
|
3433
4727
|
constant int64_t & ne02,
|
|
3434
|
-
constant
|
|
3435
|
-
constant
|
|
4728
|
+
constant uint64_t & nb01,
|
|
4729
|
+
constant uint64_t & nb02,
|
|
3436
4730
|
constant int64_t & ne12,
|
|
3437
|
-
constant int64_t &
|
|
3438
|
-
constant
|
|
3439
|
-
constant
|
|
4731
|
+
constant int64_t & ne13,
|
|
4732
|
+
constant uint64_t & nb10,
|
|
4733
|
+
constant uint64_t & nb11,
|
|
4734
|
+
constant uint64_t & nb12,
|
|
3440
4735
|
constant int64_t & ne0,
|
|
3441
4736
|
constant int64_t & ne1,
|
|
4737
|
+
constant uint64_t & nb1,
|
|
3442
4738
|
constant uint & r2,
|
|
3443
4739
|
constant uint & r3,
|
|
3444
4740
|
constant int & idx,
|
|
@@ -3454,11 +4750,27 @@ kernel void kernel_mul_mm_id(
|
|
|
3454
4750
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3455
4751
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
3456
4752
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3457
|
-
device const uchar *
|
|
4753
|
+
device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
3458
4754
|
|
|
3459
|
-
|
|
3460
|
-
|
|
4755
|
+
// expert id
|
|
4756
|
+
const int32_t id = tgpig.z/(ne12*ne13);
|
|
4757
|
+
|
|
4758
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
4759
|
+
|
|
4760
|
+
// row indices of src1 for expert id
|
|
4761
|
+
int64_t _ne1 = 0;
|
|
4762
|
+
short src1ids[512];
|
|
4763
|
+
|
|
4764
|
+
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
4765
|
+
if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
|
|
4766
|
+
src1ids[_ne1++] = i1;
|
|
4767
|
+
}
|
|
4768
|
+
}
|
|
4769
|
+
|
|
4770
|
+
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
|
4771
|
+
src0s[id],
|
|
3461
4772
|
src1,
|
|
4773
|
+
src1ids,
|
|
3462
4774
|
dst,
|
|
3463
4775
|
ne00,
|
|
3464
4776
|
ne02,
|
|
@@ -3469,7 +4781,7 @@ kernel void kernel_mul_mm_id(
|
|
|
3469
4781
|
nb11,
|
|
3470
4782
|
nb12,
|
|
3471
4783
|
ne0,
|
|
3472
|
-
|
|
4784
|
+
_ne1,
|
|
3473
4785
|
r2,
|
|
3474
4786
|
r3,
|
|
3475
4787
|
shared_memory,
|
|
@@ -3484,17 +4796,26 @@ kernel void kernel_mul_mm_id(
|
|
|
3484
4796
|
#define QK_NL 4
|
|
3485
4797
|
#endif
|
|
3486
4798
|
|
|
4799
|
+
//
|
|
4800
|
+
// get rows
|
|
4801
|
+
//
|
|
4802
|
+
|
|
3487
4803
|
typedef void (get_rows_t)(
|
|
3488
4804
|
device const void * src0,
|
|
3489
|
-
device const
|
|
4805
|
+
device const char * src1,
|
|
3490
4806
|
device float * dst,
|
|
3491
4807
|
constant int64_t & ne00,
|
|
3492
4808
|
constant uint64_t & nb01,
|
|
4809
|
+
constant uint64_t & nb02,
|
|
4810
|
+
constant int64_t & ne10,
|
|
4811
|
+
constant uint64_t & nb10,
|
|
4812
|
+
constant uint64_t & nb11,
|
|
3493
4813
|
constant uint64_t & nb1,
|
|
3494
|
-
|
|
4814
|
+
constant uint64_t & nb2,
|
|
4815
|
+
uint3, uint, uint3);
|
|
3495
4816
|
|
|
3496
|
-
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
|
3497
|
-
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
|
4817
|
+
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
|
4818
|
+
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
|
3498
4819
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
|
3499
4820
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
|
3500
4821
|
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
|
|
@@ -3505,6 +4826,12 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows
|
|
|
3505
4826
|
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
3506
4827
|
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
3507
4828
|
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
4829
|
+
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
4830
|
+
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
4831
|
+
|
|
4832
|
+
//
|
|
4833
|
+
// matrix-matrix multiplication
|
|
4834
|
+
//
|
|
3508
4835
|
|
|
3509
4836
|
typedef void (mat_mm_t)(
|
|
3510
4837
|
device const uchar * src0,
|
|
@@ -3512,12 +4839,12 @@ typedef void (mat_mm_t)(
|
|
|
3512
4839
|
device float * dst,
|
|
3513
4840
|
constant int64_t & ne00,
|
|
3514
4841
|
constant int64_t & ne02,
|
|
3515
|
-
constant
|
|
3516
|
-
constant
|
|
4842
|
+
constant uint64_t & nb01,
|
|
4843
|
+
constant uint64_t & nb02,
|
|
3517
4844
|
constant int64_t & ne12,
|
|
3518
|
-
constant
|
|
3519
|
-
constant
|
|
3520
|
-
constant
|
|
4845
|
+
constant uint64_t & nb10,
|
|
4846
|
+
constant uint64_t & nb11,
|
|
4847
|
+
constant uint64_t & nb12,
|
|
3521
4848
|
constant int64_t & ne0,
|
|
3522
4849
|
constant int64_t & ne1,
|
|
3523
4850
|
constant uint & r2,
|
|
@@ -3537,21 +4864,30 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
|
3537
4864
|
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
3538
4865
|
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
3539
4866
|
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
4867
|
+
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
4868
|
+
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
4869
|
+
|
|
4870
|
+
//
|
|
4871
|
+
// indirect matrix-matrix multiplication
|
|
4872
|
+
//
|
|
3540
4873
|
|
|
3541
4874
|
typedef void (mat_mm_id_t)(
|
|
3542
|
-
device const
|
|
4875
|
+
device const uchar * ids,
|
|
3543
4876
|
device const uchar * src1,
|
|
3544
4877
|
device float * dst,
|
|
4878
|
+
constant uint64_t & nbi1,
|
|
3545
4879
|
constant int64_t & ne00,
|
|
3546
4880
|
constant int64_t & ne02,
|
|
3547
|
-
constant
|
|
3548
|
-
constant
|
|
4881
|
+
constant uint64_t & nb01,
|
|
4882
|
+
constant uint64_t & nb02,
|
|
3549
4883
|
constant int64_t & ne12,
|
|
3550
|
-
constant int64_t &
|
|
3551
|
-
constant
|
|
3552
|
-
constant
|
|
4884
|
+
constant int64_t & ne13,
|
|
4885
|
+
constant uint64_t & nb10,
|
|
4886
|
+
constant uint64_t & nb11,
|
|
4887
|
+
constant uint64_t & nb12,
|
|
3553
4888
|
constant int64_t & ne0,
|
|
3554
4889
|
constant int64_t & ne1,
|
|
4890
|
+
constant uint64_t & nb1,
|
|
3555
4891
|
constant uint & r2,
|
|
3556
4892
|
constant uint & r3,
|
|
3557
4893
|
constant int & idx,
|
|
@@ -3578,3 +4914,907 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
|
3578
4914
|
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
3579
4915
|
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
3580
4916
|
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
4917
|
+
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
4918
|
+
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
4919
|
+
|
|
4920
|
+
//
|
|
4921
|
+
// matrix-vector multiplication
|
|
4922
|
+
//
|
|
4923
|
+
|
|
4924
|
+
[[host_name("kernel_mul_mv_id_f32_f32")]]
|
|
4925
|
+
kernel void kernel_mul_mv_id_f32_f32(
|
|
4926
|
+
device const char * ids,
|
|
4927
|
+
device const char * src1,
|
|
4928
|
+
device float * dst,
|
|
4929
|
+
constant uint64_t & nbi1,
|
|
4930
|
+
constant int64_t & ne00,
|
|
4931
|
+
constant int64_t & ne01,
|
|
4932
|
+
constant int64_t & ne02,
|
|
4933
|
+
constant uint64_t & nb00,
|
|
4934
|
+
constant uint64_t & nb01,
|
|
4935
|
+
constant uint64_t & nb02,
|
|
4936
|
+
constant int64_t & ne10,
|
|
4937
|
+
constant int64_t & ne11,
|
|
4938
|
+
constant int64_t & ne12,
|
|
4939
|
+
constant int64_t & ne13,
|
|
4940
|
+
constant uint64_t & nb10,
|
|
4941
|
+
constant uint64_t & nb11,
|
|
4942
|
+
constant uint64_t & nb12,
|
|
4943
|
+
constant int64_t & ne0,
|
|
4944
|
+
constant int64_t & ne1,
|
|
4945
|
+
constant uint64_t & nb1,
|
|
4946
|
+
constant uint & r2,
|
|
4947
|
+
constant uint & r3,
|
|
4948
|
+
constant int & idx,
|
|
4949
|
+
device const char * src00,
|
|
4950
|
+
device const char * src01,
|
|
4951
|
+
device const char * src02,
|
|
4952
|
+
device const char * src03,
|
|
4953
|
+
device const char * src04,
|
|
4954
|
+
device const char * src05,
|
|
4955
|
+
device const char * src06,
|
|
4956
|
+
device const char * src07,
|
|
4957
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4958
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
4959
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
4960
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4961
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
4962
|
+
|
|
4963
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
4964
|
+
|
|
4965
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
4966
|
+
|
|
4967
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
4968
|
+
|
|
4969
|
+
kernel_mul_mv_f32_f32_impl(
|
|
4970
|
+
src0[id],
|
|
4971
|
+
src1 + bid*nb11,
|
|
4972
|
+
dst + bid*ne0,
|
|
4973
|
+
ne00,
|
|
4974
|
+
ne01,
|
|
4975
|
+
ne02,
|
|
4976
|
+
nb00,
|
|
4977
|
+
nb01,
|
|
4978
|
+
nb02,
|
|
4979
|
+
ne10,
|
|
4980
|
+
ne11,
|
|
4981
|
+
ne12,
|
|
4982
|
+
nb10,
|
|
4983
|
+
nb11,
|
|
4984
|
+
nb12,
|
|
4985
|
+
ne0,
|
|
4986
|
+
ne1,
|
|
4987
|
+
r2,
|
|
4988
|
+
r3,
|
|
4989
|
+
tgpig,
|
|
4990
|
+
tiisg);
|
|
4991
|
+
}
|
|
4992
|
+
|
|
4993
|
+
[[host_name("kernel_mul_mv_id_f16_f32")]]
|
|
4994
|
+
kernel void kernel_mul_mv_id_f16_f32(
|
|
4995
|
+
device const char * ids,
|
|
4996
|
+
device const char * src1,
|
|
4997
|
+
device float * dst,
|
|
4998
|
+
constant uint64_t & nbi1,
|
|
4999
|
+
constant int64_t & ne00,
|
|
5000
|
+
constant int64_t & ne01,
|
|
5001
|
+
constant int64_t & ne02,
|
|
5002
|
+
constant uint64_t & nb00,
|
|
5003
|
+
constant uint64_t & nb01,
|
|
5004
|
+
constant uint64_t & nb02,
|
|
5005
|
+
constant int64_t & ne10,
|
|
5006
|
+
constant int64_t & ne11,
|
|
5007
|
+
constant int64_t & ne12,
|
|
5008
|
+
constant int64_t & ne13,
|
|
5009
|
+
constant uint64_t & nb10,
|
|
5010
|
+
constant uint64_t & nb11,
|
|
5011
|
+
constant uint64_t & nb12,
|
|
5012
|
+
constant int64_t & ne0,
|
|
5013
|
+
constant int64_t & ne1,
|
|
5014
|
+
constant uint64_t & nb1,
|
|
5015
|
+
constant uint & r2,
|
|
5016
|
+
constant uint & r3,
|
|
5017
|
+
constant int & idx,
|
|
5018
|
+
device const char * src00,
|
|
5019
|
+
device const char * src01,
|
|
5020
|
+
device const char * src02,
|
|
5021
|
+
device const char * src03,
|
|
5022
|
+
device const char * src04,
|
|
5023
|
+
device const char * src05,
|
|
5024
|
+
device const char * src06,
|
|
5025
|
+
device const char * src07,
|
|
5026
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5027
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5028
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5029
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5030
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5031
|
+
|
|
5032
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5033
|
+
|
|
5034
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5035
|
+
|
|
5036
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5037
|
+
|
|
5038
|
+
kernel_mul_mv_f16_f32_impl(
|
|
5039
|
+
src0[id],
|
|
5040
|
+
src1 + bid*nb11,
|
|
5041
|
+
dst + bid*ne0,
|
|
5042
|
+
ne00,
|
|
5043
|
+
ne01,
|
|
5044
|
+
ne02,
|
|
5045
|
+
nb00,
|
|
5046
|
+
nb01,
|
|
5047
|
+
nb02,
|
|
5048
|
+
ne10,
|
|
5049
|
+
ne11,
|
|
5050
|
+
ne12,
|
|
5051
|
+
nb10,
|
|
5052
|
+
nb11,
|
|
5053
|
+
nb12,
|
|
5054
|
+
ne0,
|
|
5055
|
+
ne1,
|
|
5056
|
+
r2,
|
|
5057
|
+
r3,
|
|
5058
|
+
tgpig,
|
|
5059
|
+
tiisg);
|
|
5060
|
+
}
|
|
5061
|
+
|
|
5062
|
+
[[host_name("kernel_mul_mv_id_q8_0_f32")]]
|
|
5063
|
+
kernel void kernel_mul_mv_id_q8_0_f32(
|
|
5064
|
+
device const char * ids,
|
|
5065
|
+
device const char * src1,
|
|
5066
|
+
device float * dst,
|
|
5067
|
+
constant uint64_t & nbi1,
|
|
5068
|
+
constant int64_t & ne00,
|
|
5069
|
+
constant int64_t & ne01,
|
|
5070
|
+
constant int64_t & ne02,
|
|
5071
|
+
constant uint64_t & nb00,
|
|
5072
|
+
constant uint64_t & nb01,
|
|
5073
|
+
constant uint64_t & nb02,
|
|
5074
|
+
constant int64_t & ne10,
|
|
5075
|
+
constant int64_t & ne11,
|
|
5076
|
+
constant int64_t & ne12,
|
|
5077
|
+
constant int64_t & ne13,
|
|
5078
|
+
constant uint64_t & nb10,
|
|
5079
|
+
constant uint64_t & nb11,
|
|
5080
|
+
constant uint64_t & nb12,
|
|
5081
|
+
constant int64_t & ne0,
|
|
5082
|
+
constant int64_t & ne1,
|
|
5083
|
+
constant uint64_t & nb1,
|
|
5084
|
+
constant uint & r2,
|
|
5085
|
+
constant uint & r3,
|
|
5086
|
+
constant int & idx,
|
|
5087
|
+
device const char * src00,
|
|
5088
|
+
device const char * src01,
|
|
5089
|
+
device const char * src02,
|
|
5090
|
+
device const char * src03,
|
|
5091
|
+
device const char * src04,
|
|
5092
|
+
device const char * src05,
|
|
5093
|
+
device const char * src06,
|
|
5094
|
+
device const char * src07,
|
|
5095
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5096
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5097
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5098
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5099
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5100
|
+
|
|
5101
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5102
|
+
|
|
5103
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5104
|
+
|
|
5105
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5106
|
+
|
|
5107
|
+
kernel_mul_mv_q8_0_f32_impl(
|
|
5108
|
+
src0[id],
|
|
5109
|
+
(device const float *) (src1 + bid*nb11),
|
|
5110
|
+
dst + bid*ne0,
|
|
5111
|
+
ne00,
|
|
5112
|
+
ne01,
|
|
5113
|
+
ne02,
|
|
5114
|
+
ne10,
|
|
5115
|
+
ne12,
|
|
5116
|
+
ne0,
|
|
5117
|
+
ne1,
|
|
5118
|
+
r2,
|
|
5119
|
+
r3,
|
|
5120
|
+
tgpig,
|
|
5121
|
+
tiisg,
|
|
5122
|
+
sgitg);
|
|
5123
|
+
}
|
|
5124
|
+
|
|
5125
|
+
[[host_name("kernel_mul_mv_id_q4_0_f32")]]
|
|
5126
|
+
kernel void kernel_mul_mv_id_q4_0_f32(
|
|
5127
|
+
device const char * ids,
|
|
5128
|
+
device const char * src1,
|
|
5129
|
+
device float * dst,
|
|
5130
|
+
constant uint64_t & nbi1,
|
|
5131
|
+
constant int64_t & ne00,
|
|
5132
|
+
constant int64_t & ne01,
|
|
5133
|
+
constant int64_t & ne02,
|
|
5134
|
+
constant uint64_t & nb00,
|
|
5135
|
+
constant uint64_t & nb01,
|
|
5136
|
+
constant uint64_t & nb02,
|
|
5137
|
+
constant int64_t & ne10,
|
|
5138
|
+
constant int64_t & ne11,
|
|
5139
|
+
constant int64_t & ne12,
|
|
5140
|
+
constant int64_t & ne13,
|
|
5141
|
+
constant uint64_t & nb10,
|
|
5142
|
+
constant uint64_t & nb11,
|
|
5143
|
+
constant uint64_t & nb12,
|
|
5144
|
+
constant int64_t & ne0,
|
|
5145
|
+
constant int64_t & ne1,
|
|
5146
|
+
constant uint64_t & nb1,
|
|
5147
|
+
constant uint & r2,
|
|
5148
|
+
constant uint & r3,
|
|
5149
|
+
constant int & idx,
|
|
5150
|
+
device const char * src00,
|
|
5151
|
+
device const char * src01,
|
|
5152
|
+
device const char * src02,
|
|
5153
|
+
device const char * src03,
|
|
5154
|
+
device const char * src04,
|
|
5155
|
+
device const char * src05,
|
|
5156
|
+
device const char * src06,
|
|
5157
|
+
device const char * src07,
|
|
5158
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5159
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5160
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5161
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5162
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5163
|
+
|
|
5164
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5165
|
+
|
|
5166
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5167
|
+
|
|
5168
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5169
|
+
|
|
5170
|
+
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
5171
|
+
src0[id],
|
|
5172
|
+
(device const float *) (src1 + bid*nb11),
|
|
5173
|
+
dst + bid*ne0,
|
|
5174
|
+
ne00,
|
|
5175
|
+
ne01,
|
|
5176
|
+
ne02,
|
|
5177
|
+
ne10,
|
|
5178
|
+
ne12,
|
|
5179
|
+
ne0,
|
|
5180
|
+
ne1,
|
|
5181
|
+
r2,
|
|
5182
|
+
r3,
|
|
5183
|
+
tgpig,
|
|
5184
|
+
tiisg,
|
|
5185
|
+
sgitg);
|
|
5186
|
+
}
|
|
5187
|
+
|
|
5188
|
+
[[host_name("kernel_mul_mv_id_q4_1_f32")]]
|
|
5189
|
+
kernel void kernel_mul_mv_id_q4_1_f32(
|
|
5190
|
+
device const char * ids,
|
|
5191
|
+
device const char * src1,
|
|
5192
|
+
device float * dst,
|
|
5193
|
+
constant uint64_t & nbi1,
|
|
5194
|
+
constant int64_t & ne00,
|
|
5195
|
+
constant int64_t & ne01,
|
|
5196
|
+
constant int64_t & ne02,
|
|
5197
|
+
constant uint64_t & nb00,
|
|
5198
|
+
constant uint64_t & nb01,
|
|
5199
|
+
constant uint64_t & nb02,
|
|
5200
|
+
constant int64_t & ne10,
|
|
5201
|
+
constant int64_t & ne11,
|
|
5202
|
+
constant int64_t & ne12,
|
|
5203
|
+
constant int64_t & ne13,
|
|
5204
|
+
constant uint64_t & nb10,
|
|
5205
|
+
constant uint64_t & nb11,
|
|
5206
|
+
constant uint64_t & nb12,
|
|
5207
|
+
constant int64_t & ne0,
|
|
5208
|
+
constant int64_t & ne1,
|
|
5209
|
+
constant uint64_t & nb1,
|
|
5210
|
+
constant uint & r2,
|
|
5211
|
+
constant uint & r3,
|
|
5212
|
+
constant int & idx,
|
|
5213
|
+
device const char * src00,
|
|
5214
|
+
device const char * src01,
|
|
5215
|
+
device const char * src02,
|
|
5216
|
+
device const char * src03,
|
|
5217
|
+
device const char * src04,
|
|
5218
|
+
device const char * src05,
|
|
5219
|
+
device const char * src06,
|
|
5220
|
+
device const char * src07,
|
|
5221
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5222
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5223
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5224
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5225
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5226
|
+
|
|
5227
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5228
|
+
|
|
5229
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5230
|
+
|
|
5231
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5232
|
+
|
|
5233
|
+
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
5234
|
+
src0[id],
|
|
5235
|
+
(device const float *) (src1 + bid*nb11),
|
|
5236
|
+
dst + bid*ne0,
|
|
5237
|
+
ne00,
|
|
5238
|
+
ne01,
|
|
5239
|
+
ne02,
|
|
5240
|
+
ne10,
|
|
5241
|
+
ne12,
|
|
5242
|
+
ne0,
|
|
5243
|
+
ne1,
|
|
5244
|
+
r2,
|
|
5245
|
+
r3,
|
|
5246
|
+
tgpig,
|
|
5247
|
+
tiisg,
|
|
5248
|
+
sgitg);
|
|
5249
|
+
}
|
|
5250
|
+
|
|
5251
|
+
[[host_name("kernel_mul_mv_id_q5_0_f32")]]
|
|
5252
|
+
kernel void kernel_mul_mv_id_q5_0_f32(
|
|
5253
|
+
device const char * ids,
|
|
5254
|
+
device const char * src1,
|
|
5255
|
+
device float * dst,
|
|
5256
|
+
constant uint64_t & nbi1,
|
|
5257
|
+
constant int64_t & ne00,
|
|
5258
|
+
constant int64_t & ne01,
|
|
5259
|
+
constant int64_t & ne02,
|
|
5260
|
+
constant uint64_t & nb00,
|
|
5261
|
+
constant uint64_t & nb01,
|
|
5262
|
+
constant uint64_t & nb02,
|
|
5263
|
+
constant int64_t & ne10,
|
|
5264
|
+
constant int64_t & ne11,
|
|
5265
|
+
constant int64_t & ne12,
|
|
5266
|
+
constant int64_t & ne13,
|
|
5267
|
+
constant uint64_t & nb10,
|
|
5268
|
+
constant uint64_t & nb11,
|
|
5269
|
+
constant uint64_t & nb12,
|
|
5270
|
+
constant int64_t & ne0,
|
|
5271
|
+
constant int64_t & ne1,
|
|
5272
|
+
constant uint64_t & nb1,
|
|
5273
|
+
constant uint & r2,
|
|
5274
|
+
constant uint & r3,
|
|
5275
|
+
constant int & idx,
|
|
5276
|
+
device const char * src00,
|
|
5277
|
+
device const char * src01,
|
|
5278
|
+
device const char * src02,
|
|
5279
|
+
device const char * src03,
|
|
5280
|
+
device const char * src04,
|
|
5281
|
+
device const char * src05,
|
|
5282
|
+
device const char * src06,
|
|
5283
|
+
device const char * src07,
|
|
5284
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5285
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5286
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5287
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5288
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5289
|
+
|
|
5290
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5291
|
+
|
|
5292
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5293
|
+
|
|
5294
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5295
|
+
|
|
5296
|
+
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
5297
|
+
src0[id],
|
|
5298
|
+
(device const float *) (src1 + bid*nb11),
|
|
5299
|
+
dst + bid*ne0,
|
|
5300
|
+
ne00,
|
|
5301
|
+
ne01,
|
|
5302
|
+
ne02,
|
|
5303
|
+
ne10,
|
|
5304
|
+
ne12,
|
|
5305
|
+
ne0,
|
|
5306
|
+
ne1,
|
|
5307
|
+
r2,
|
|
5308
|
+
r3,
|
|
5309
|
+
tgpig,
|
|
5310
|
+
tiisg,
|
|
5311
|
+
sgitg);
|
|
5312
|
+
}
|
|
5313
|
+
|
|
5314
|
+
[[host_name("kernel_mul_mv_id_q5_1_f32")]]
|
|
5315
|
+
kernel void kernel_mul_mv_id_q5_1_f32(
|
|
5316
|
+
device const char * ids,
|
|
5317
|
+
device const char * src1,
|
|
5318
|
+
device float * dst,
|
|
5319
|
+
constant uint64_t & nbi1,
|
|
5320
|
+
constant int64_t & ne00,
|
|
5321
|
+
constant int64_t & ne01,
|
|
5322
|
+
constant int64_t & ne02,
|
|
5323
|
+
constant uint64_t & nb00,
|
|
5324
|
+
constant uint64_t & nb01,
|
|
5325
|
+
constant uint64_t & nb02,
|
|
5326
|
+
constant int64_t & ne10,
|
|
5327
|
+
constant int64_t & ne11,
|
|
5328
|
+
constant int64_t & ne12,
|
|
5329
|
+
constant int64_t & ne13,
|
|
5330
|
+
constant uint64_t & nb10,
|
|
5331
|
+
constant uint64_t & nb11,
|
|
5332
|
+
constant uint64_t & nb12,
|
|
5333
|
+
constant int64_t & ne0,
|
|
5334
|
+
constant int64_t & ne1,
|
|
5335
|
+
constant uint64_t & nb1,
|
|
5336
|
+
constant uint & r2,
|
|
5337
|
+
constant uint & r3,
|
|
5338
|
+
constant int & idx,
|
|
5339
|
+
device const char * src00,
|
|
5340
|
+
device const char * src01,
|
|
5341
|
+
device const char * src02,
|
|
5342
|
+
device const char * src03,
|
|
5343
|
+
device const char * src04,
|
|
5344
|
+
device const char * src05,
|
|
5345
|
+
device const char * src06,
|
|
5346
|
+
device const char * src07,
|
|
5347
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5348
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5349
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5350
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5351
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5352
|
+
|
|
5353
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5354
|
+
|
|
5355
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5356
|
+
|
|
5357
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5358
|
+
|
|
5359
|
+
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
5360
|
+
src0[id],
|
|
5361
|
+
(device const float *) (src1 + bid*nb11),
|
|
5362
|
+
dst + bid*ne0,
|
|
5363
|
+
ne00,
|
|
5364
|
+
ne01,
|
|
5365
|
+
ne02,
|
|
5366
|
+
ne10,
|
|
5367
|
+
ne12,
|
|
5368
|
+
ne0,
|
|
5369
|
+
ne1,
|
|
5370
|
+
r2,
|
|
5371
|
+
r3,
|
|
5372
|
+
tgpig,
|
|
5373
|
+
tiisg,
|
|
5374
|
+
sgitg);
|
|
5375
|
+
}
|
|
5376
|
+
|
|
5377
|
+
[[host_name("kernel_mul_mv_id_q2_K_f32")]]
|
|
5378
|
+
kernel void kernel_mul_mv_id_q2_K_f32(
|
|
5379
|
+
device const char * ids,
|
|
5380
|
+
device const char * src1,
|
|
5381
|
+
device float * dst,
|
|
5382
|
+
constant uint64_t & nbi1,
|
|
5383
|
+
constant int64_t & ne00,
|
|
5384
|
+
constant int64_t & ne01,
|
|
5385
|
+
constant int64_t & ne02,
|
|
5386
|
+
constant uint64_t & nb00,
|
|
5387
|
+
constant uint64_t & nb01,
|
|
5388
|
+
constant uint64_t & nb02,
|
|
5389
|
+
constant int64_t & ne10,
|
|
5390
|
+
constant int64_t & ne11,
|
|
5391
|
+
constant int64_t & ne12,
|
|
5392
|
+
constant int64_t & ne13,
|
|
5393
|
+
constant uint64_t & nb10,
|
|
5394
|
+
constant uint64_t & nb11,
|
|
5395
|
+
constant uint64_t & nb12,
|
|
5396
|
+
constant int64_t & ne0,
|
|
5397
|
+
constant int64_t & ne1,
|
|
5398
|
+
constant uint64_t & nb1,
|
|
5399
|
+
constant uint & r2,
|
|
5400
|
+
constant uint & r3,
|
|
5401
|
+
constant int & idx,
|
|
5402
|
+
device const char * src00,
|
|
5403
|
+
device const char * src01,
|
|
5404
|
+
device const char * src02,
|
|
5405
|
+
device const char * src03,
|
|
5406
|
+
device const char * src04,
|
|
5407
|
+
device const char * src05,
|
|
5408
|
+
device const char * src06,
|
|
5409
|
+
device const char * src07,
|
|
5410
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5411
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5412
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5413
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5414
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5415
|
+
|
|
5416
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5417
|
+
|
|
5418
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5419
|
+
|
|
5420
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5421
|
+
|
|
5422
|
+
kernel_mul_mv_q2_K_f32_impl(
|
|
5423
|
+
src0[id],
|
|
5424
|
+
(device const float *) (src1 + bid*nb11),
|
|
5425
|
+
dst + bid*ne0,
|
|
5426
|
+
ne00,
|
|
5427
|
+
ne01,
|
|
5428
|
+
ne02,
|
|
5429
|
+
ne10,
|
|
5430
|
+
ne12,
|
|
5431
|
+
ne0,
|
|
5432
|
+
ne1,
|
|
5433
|
+
r2,
|
|
5434
|
+
r3,
|
|
5435
|
+
tgpig,
|
|
5436
|
+
tiisg,
|
|
5437
|
+
sgitg);
|
|
5438
|
+
}
|
|
5439
|
+
|
|
5440
|
+
[[host_name("kernel_mul_mv_id_q3_K_f32")]]
|
|
5441
|
+
kernel void kernel_mul_mv_id_q3_K_f32(
|
|
5442
|
+
device const char * ids,
|
|
5443
|
+
device const char * src1,
|
|
5444
|
+
device float * dst,
|
|
5445
|
+
constant uint64_t & nbi1,
|
|
5446
|
+
constant int64_t & ne00,
|
|
5447
|
+
constant int64_t & ne01,
|
|
5448
|
+
constant int64_t & ne02,
|
|
5449
|
+
constant uint64_t & nb00,
|
|
5450
|
+
constant uint64_t & nb01,
|
|
5451
|
+
constant uint64_t & nb02,
|
|
5452
|
+
constant int64_t & ne10,
|
|
5453
|
+
constant int64_t & ne11,
|
|
5454
|
+
constant int64_t & ne12,
|
|
5455
|
+
constant int64_t & ne13,
|
|
5456
|
+
constant uint64_t & nb10,
|
|
5457
|
+
constant uint64_t & nb11,
|
|
5458
|
+
constant uint64_t & nb12,
|
|
5459
|
+
constant int64_t & ne0,
|
|
5460
|
+
constant int64_t & ne1,
|
|
5461
|
+
constant uint64_t & nb1,
|
|
5462
|
+
constant uint & r2,
|
|
5463
|
+
constant uint & r3,
|
|
5464
|
+
constant int & idx,
|
|
5465
|
+
device const char * src00,
|
|
5466
|
+
device const char * src01,
|
|
5467
|
+
device const char * src02,
|
|
5468
|
+
device const char * src03,
|
|
5469
|
+
device const char * src04,
|
|
5470
|
+
device const char * src05,
|
|
5471
|
+
device const char * src06,
|
|
5472
|
+
device const char * src07,
|
|
5473
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5474
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5475
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5476
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5477
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5478
|
+
|
|
5479
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5480
|
+
|
|
5481
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5482
|
+
|
|
5483
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5484
|
+
|
|
5485
|
+
kernel_mul_mv_q3_K_f32_impl(
|
|
5486
|
+
src0[id],
|
|
5487
|
+
(device const float *) (src1 + bid*nb11),
|
|
5488
|
+
dst + bid*ne0,
|
|
5489
|
+
ne00,
|
|
5490
|
+
ne01,
|
|
5491
|
+
ne02,
|
|
5492
|
+
ne10,
|
|
5493
|
+
ne12,
|
|
5494
|
+
ne0,
|
|
5495
|
+
ne1,
|
|
5496
|
+
r2,
|
|
5497
|
+
r3,
|
|
5498
|
+
tgpig,
|
|
5499
|
+
tiisg,
|
|
5500
|
+
sgitg);
|
|
5501
|
+
}
|
|
5502
|
+
|
|
5503
|
+
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
|
|
5504
|
+
kernel void kernel_mul_mv_id_q4_K_f32(
|
|
5505
|
+
device const char * ids,
|
|
5506
|
+
device const char * src1,
|
|
5507
|
+
device float * dst,
|
|
5508
|
+
constant uint64_t & nbi1,
|
|
5509
|
+
constant int64_t & ne00,
|
|
5510
|
+
constant int64_t & ne01,
|
|
5511
|
+
constant int64_t & ne02,
|
|
5512
|
+
constant uint64_t & nb00,
|
|
5513
|
+
constant uint64_t & nb01,
|
|
5514
|
+
constant uint64_t & nb02,
|
|
5515
|
+
constant int64_t & ne10,
|
|
5516
|
+
constant int64_t & ne11,
|
|
5517
|
+
constant int64_t & ne12,
|
|
5518
|
+
constant int64_t & ne13,
|
|
5519
|
+
constant uint64_t & nb10,
|
|
5520
|
+
constant uint64_t & nb11,
|
|
5521
|
+
constant uint64_t & nb12,
|
|
5522
|
+
constant int64_t & ne0,
|
|
5523
|
+
constant int64_t & ne1,
|
|
5524
|
+
constant uint64_t & nb1,
|
|
5525
|
+
constant uint & r2,
|
|
5526
|
+
constant uint & r3,
|
|
5527
|
+
constant int & idx,
|
|
5528
|
+
device const char * src00,
|
|
5529
|
+
device const char * src01,
|
|
5530
|
+
device const char * src02,
|
|
5531
|
+
device const char * src03,
|
|
5532
|
+
device const char * src04,
|
|
5533
|
+
device const char * src05,
|
|
5534
|
+
device const char * src06,
|
|
5535
|
+
device const char * src07,
|
|
5536
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5537
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5538
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5539
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5540
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5541
|
+
|
|
5542
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5543
|
+
|
|
5544
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5545
|
+
|
|
5546
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5547
|
+
|
|
5548
|
+
kernel_mul_mv_q4_K_f32_impl(
|
|
5549
|
+
src0[id],
|
|
5550
|
+
(device const float *) (src1 + bid*nb11),
|
|
5551
|
+
dst + bid*ne0,
|
|
5552
|
+
ne00,
|
|
5553
|
+
ne01,
|
|
5554
|
+
ne02,
|
|
5555
|
+
ne10,
|
|
5556
|
+
ne12,
|
|
5557
|
+
ne0,
|
|
5558
|
+
ne1,
|
|
5559
|
+
r2,
|
|
5560
|
+
r3,
|
|
5561
|
+
tgpig,
|
|
5562
|
+
tiisg,
|
|
5563
|
+
sgitg);
|
|
5564
|
+
}
|
|
5565
|
+
|
|
5566
|
+
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
|
|
5567
|
+
kernel void kernel_mul_mv_id_q5_K_f32(
|
|
5568
|
+
device const char * ids,
|
|
5569
|
+
device const char * src1,
|
|
5570
|
+
device float * dst,
|
|
5571
|
+
constant uint64_t & nbi1,
|
|
5572
|
+
constant int64_t & ne00,
|
|
5573
|
+
constant int64_t & ne01,
|
|
5574
|
+
constant int64_t & ne02,
|
|
5575
|
+
constant uint64_t & nb00,
|
|
5576
|
+
constant uint64_t & nb01,
|
|
5577
|
+
constant uint64_t & nb02,
|
|
5578
|
+
constant int64_t & ne10,
|
|
5579
|
+
constant int64_t & ne11,
|
|
5580
|
+
constant int64_t & ne12,
|
|
5581
|
+
constant int64_t & ne13,
|
|
5582
|
+
constant uint64_t & nb10,
|
|
5583
|
+
constant uint64_t & nb11,
|
|
5584
|
+
constant uint64_t & nb12,
|
|
5585
|
+
constant int64_t & ne0,
|
|
5586
|
+
constant int64_t & ne1,
|
|
5587
|
+
constant uint64_t & nb1,
|
|
5588
|
+
constant uint & r2,
|
|
5589
|
+
constant uint & r3,
|
|
5590
|
+
constant int & idx,
|
|
5591
|
+
device const char * src00,
|
|
5592
|
+
device const char * src01,
|
|
5593
|
+
device const char * src02,
|
|
5594
|
+
device const char * src03,
|
|
5595
|
+
device const char * src04,
|
|
5596
|
+
device const char * src05,
|
|
5597
|
+
device const char * src06,
|
|
5598
|
+
device const char * src07,
|
|
5599
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5600
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5601
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5602
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5603
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5604
|
+
|
|
5605
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5606
|
+
|
|
5607
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5608
|
+
|
|
5609
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5610
|
+
|
|
5611
|
+
kernel_mul_mv_q5_K_f32_impl(
|
|
5612
|
+
src0[id],
|
|
5613
|
+
(device const float *) (src1 + bid*nb11),
|
|
5614
|
+
dst + bid*ne0,
|
|
5615
|
+
ne00,
|
|
5616
|
+
ne01,
|
|
5617
|
+
ne02,
|
|
5618
|
+
ne10,
|
|
5619
|
+
ne12,
|
|
5620
|
+
ne0,
|
|
5621
|
+
ne1,
|
|
5622
|
+
r2,
|
|
5623
|
+
r3,
|
|
5624
|
+
tgpig,
|
|
5625
|
+
tiisg,
|
|
5626
|
+
sgitg);
|
|
5627
|
+
}
|
|
5628
|
+
|
|
5629
|
+
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
|
|
5630
|
+
kernel void kernel_mul_mv_id_q6_K_f32(
|
|
5631
|
+
device const char * ids,
|
|
5632
|
+
device const char * src1,
|
|
5633
|
+
device float * dst,
|
|
5634
|
+
constant uint64_t & nbi1,
|
|
5635
|
+
constant int64_t & ne00,
|
|
5636
|
+
constant int64_t & ne01,
|
|
5637
|
+
constant int64_t & ne02,
|
|
5638
|
+
constant uint64_t & nb00,
|
|
5639
|
+
constant uint64_t & nb01,
|
|
5640
|
+
constant uint64_t & nb02,
|
|
5641
|
+
constant int64_t & ne10,
|
|
5642
|
+
constant int64_t & ne11,
|
|
5643
|
+
constant int64_t & ne12,
|
|
5644
|
+
constant int64_t & ne13,
|
|
5645
|
+
constant uint64_t & nb10,
|
|
5646
|
+
constant uint64_t & nb11,
|
|
5647
|
+
constant uint64_t & nb12,
|
|
5648
|
+
constant int64_t & ne0,
|
|
5649
|
+
constant int64_t & ne1,
|
|
5650
|
+
constant uint64_t & nb1,
|
|
5651
|
+
constant uint & r2,
|
|
5652
|
+
constant uint & r3,
|
|
5653
|
+
constant int & idx,
|
|
5654
|
+
device const char * src00,
|
|
5655
|
+
device const char * src01,
|
|
5656
|
+
device const char * src02,
|
|
5657
|
+
device const char * src03,
|
|
5658
|
+
device const char * src04,
|
|
5659
|
+
device const char * src05,
|
|
5660
|
+
device const char * src06,
|
|
5661
|
+
device const char * src07,
|
|
5662
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5663
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5664
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5665
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5666
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5667
|
+
|
|
5668
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5669
|
+
|
|
5670
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5671
|
+
|
|
5672
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5673
|
+
|
|
5674
|
+
kernel_mul_mv_q6_K_f32_impl(
|
|
5675
|
+
src0[id],
|
|
5676
|
+
(device const float *) (src1 + bid*nb11),
|
|
5677
|
+
dst + bid*ne0,
|
|
5678
|
+
ne00,
|
|
5679
|
+
ne01,
|
|
5680
|
+
ne02,
|
|
5681
|
+
ne10,
|
|
5682
|
+
ne12,
|
|
5683
|
+
ne0,
|
|
5684
|
+
ne1,
|
|
5685
|
+
r2,
|
|
5686
|
+
r3,
|
|
5687
|
+
tgpig,
|
|
5688
|
+
tiisg,
|
|
5689
|
+
sgitg);
|
|
5690
|
+
}
|
|
5691
|
+
|
|
5692
|
+
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
|
|
5693
|
+
kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
5694
|
+
device const char * ids,
|
|
5695
|
+
device const char * src1,
|
|
5696
|
+
device float * dst,
|
|
5697
|
+
constant uint64_t & nbi1,
|
|
5698
|
+
constant int64_t & ne00,
|
|
5699
|
+
constant int64_t & ne01,
|
|
5700
|
+
constant int64_t & ne02,
|
|
5701
|
+
constant uint64_t & nb00,
|
|
5702
|
+
constant uint64_t & nb01,
|
|
5703
|
+
constant uint64_t & nb02,
|
|
5704
|
+
constant int64_t & ne10,
|
|
5705
|
+
constant int64_t & ne11,
|
|
5706
|
+
constant int64_t & ne12,
|
|
5707
|
+
constant int64_t & ne13,
|
|
5708
|
+
constant uint64_t & nb10,
|
|
5709
|
+
constant uint64_t & nb11,
|
|
5710
|
+
constant uint64_t & nb12,
|
|
5711
|
+
constant int64_t & ne0,
|
|
5712
|
+
constant int64_t & ne1,
|
|
5713
|
+
constant uint64_t & nb1,
|
|
5714
|
+
constant uint & r2,
|
|
5715
|
+
constant uint & r3,
|
|
5716
|
+
constant int & idx,
|
|
5717
|
+
device const char * src00,
|
|
5718
|
+
device const char * src01,
|
|
5719
|
+
device const char * src02,
|
|
5720
|
+
device const char * src03,
|
|
5721
|
+
device const char * src04,
|
|
5722
|
+
device const char * src05,
|
|
5723
|
+
device const char * src06,
|
|
5724
|
+
device const char * src07,
|
|
5725
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
5726
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5727
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5728
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5729
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5730
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5731
|
+
|
|
5732
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5733
|
+
|
|
5734
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5735
|
+
|
|
5736
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5737
|
+
|
|
5738
|
+
kernel_mul_mv_iq2_xxs_f32_impl(
|
|
5739
|
+
src0[id],
|
|
5740
|
+
(device const float *) (src1 + bid*nb11),
|
|
5741
|
+
dst + bid*ne0,
|
|
5742
|
+
ne00,
|
|
5743
|
+
ne01,
|
|
5744
|
+
ne02,
|
|
5745
|
+
ne10,
|
|
5746
|
+
ne12,
|
|
5747
|
+
ne0,
|
|
5748
|
+
ne1,
|
|
5749
|
+
r2,
|
|
5750
|
+
r3,
|
|
5751
|
+
shared_values,
|
|
5752
|
+
tgpig,
|
|
5753
|
+
tiisg,
|
|
5754
|
+
sgitg);
|
|
5755
|
+
}
|
|
5756
|
+
|
|
5757
|
+
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
|
|
5758
|
+
kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
5759
|
+
device const char * ids,
|
|
5760
|
+
device const char * src1,
|
|
5761
|
+
device float * dst,
|
|
5762
|
+
constant uint64_t & nbi1,
|
|
5763
|
+
constant int64_t & ne00,
|
|
5764
|
+
constant int64_t & ne01,
|
|
5765
|
+
constant int64_t & ne02,
|
|
5766
|
+
constant uint64_t & nb00,
|
|
5767
|
+
constant uint64_t & nb01,
|
|
5768
|
+
constant uint64_t & nb02,
|
|
5769
|
+
constant int64_t & ne10,
|
|
5770
|
+
constant int64_t & ne11,
|
|
5771
|
+
constant int64_t & ne12,
|
|
5772
|
+
constant int64_t & ne13,
|
|
5773
|
+
constant uint64_t & nb10,
|
|
5774
|
+
constant uint64_t & nb11,
|
|
5775
|
+
constant uint64_t & nb12,
|
|
5776
|
+
constant int64_t & ne0,
|
|
5777
|
+
constant int64_t & ne1,
|
|
5778
|
+
constant uint64_t & nb1,
|
|
5779
|
+
constant uint & r2,
|
|
5780
|
+
constant uint & r3,
|
|
5781
|
+
constant int & idx,
|
|
5782
|
+
device const char * src00,
|
|
5783
|
+
device const char * src01,
|
|
5784
|
+
device const char * src02,
|
|
5785
|
+
device const char * src03,
|
|
5786
|
+
device const char * src04,
|
|
5787
|
+
device const char * src05,
|
|
5788
|
+
device const char * src06,
|
|
5789
|
+
device const char * src07,
|
|
5790
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
5791
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5792
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5793
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5794
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5795
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5796
|
+
|
|
5797
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5798
|
+
|
|
5799
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5800
|
+
|
|
5801
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5802
|
+
|
|
5803
|
+
kernel_mul_mv_iq2_xs_f32_impl(
|
|
5804
|
+
src0[id],
|
|
5805
|
+
(device const float *) (src1 + bid*nb11),
|
|
5806
|
+
dst + bid*ne0,
|
|
5807
|
+
ne00,
|
|
5808
|
+
ne01,
|
|
5809
|
+
ne02,
|
|
5810
|
+
ne10,
|
|
5811
|
+
ne12,
|
|
5812
|
+
ne0,
|
|
5813
|
+
ne1,
|
|
5814
|
+
r2,
|
|
5815
|
+
r3,
|
|
5816
|
+
shared_values,
|
|
5817
|
+
tgpig,
|
|
5818
|
+
tiisg,
|
|
5819
|
+
sgitg);
|
|
5820
|
+
}
|