whisper.rn 0.4.0-rc.2 → 0.4.0-rc.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +2 -0
- package/android/src/main/java/com/rnwhisper/RNWhisper.java +6 -1
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +29 -15
- package/android/src/main/jni.cpp +6 -2
- package/cpp/ggml-alloc.c +413 -280
- package/cpp/ggml-alloc.h +67 -8
- package/cpp/ggml-backend-impl.h +87 -0
- package/cpp/ggml-backend.c +950 -0
- package/cpp/ggml-backend.h +136 -0
- package/cpp/ggml-impl.h +243 -0
- package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal} +591 -121
- package/cpp/ggml-metal.h +21 -0
- package/cpp/ggml-metal.m +623 -234
- package/cpp/ggml-quants.c +7377 -0
- package/cpp/ggml-quants.h +224 -0
- package/cpp/ggml.c +3773 -4455
- package/cpp/ggml.h +279 -146
- package/cpp/whisper.cpp +182 -103
- package/cpp/whisper.h +48 -11
- package/ios/RNWhisper.mm +8 -2
- package/ios/RNWhisperContext.h +6 -2
- package/ios/RNWhisperContext.mm +97 -26
- package/jest/mock.js +1 -1
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/index.js +28 -9
- package/lib/commonjs/index.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/index.js +28 -9
- package/lib/module/index.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +7 -1
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +8 -3
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +8 -1
- package/src/index.ts +30 -18
- package/src/version.json +1 -1
- package/whisper-rn.podspec +1 -2
package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal}

@@ -13,23 +13,85 @@ typedef struct {
 
 #define QK4_1 32
 typedef struct {
-    half d;
-    half m;
+    half d;          // delta
+    half m;          // min
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;
 
+#define QK5_0 32
+typedef struct {
+    half d;                // delta
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+
+#define QK5_1 32
+typedef struct {
+    half d;                // delta
+    half m;                // min
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+
 #define QK8_0 32
 typedef struct {
     half d;           // delta
     int8_t qs[QK8_0]; // quants
 } block_q8_0;
 
+// general-purpose kernel for addition of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
+// cons: not very efficient
 kernel void kernel_add(
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] + src1[tpig];
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant  int64_t & nb00,
+        constant  int64_t & nb01,
+        constant  int64_t & nb02,
+        constant  int64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant  int64_t & nb10,
+        constant  int64_t & nb11,
+        constant  int64_t & nb12,
+        constant  int64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant  int64_t & nb0,
+        constant  int64_t & nb1,
+        constant  int64_t & nb2,
+        constant  int64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
+
+        src0_ptr += ntg.x*nb00;
+        src1_ptr += ntg.x*nb10;
+        dst_ptr  += ntg.x*nb0;
+    }
 }
 
 // assumption: src1 is a row
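Note: the two new block types added above mirror ggml's CPU-side q5_0/q5_1 formats: 32 weights share one fp16 scale `d` (q5_1 adds a min `m`), the low 4 bits of each quant are packed two-per-byte in `qs`, and the 32 fifth bits live in `qh`. A minimal C sketch of the scalar decoding, assuming the same bit order as ggml's reference dequantizer (an illustration, not code from this package):

```c
#include <stdint.h>

#define QK5_0 32

typedef struct {
    float   d;              // scale (fp16 in the real block; float here for brevity)
    uint8_t qh[4];          // 5-th bit of each of the 32 quants
    uint8_t qs[QK5_0 / 2];  // low 4 bits, two quants per byte
} block_q5_0_model;

// decode quant j (0..31) of one block into a float
static float dequant_q5_0_one(const block_q5_0_model *b, int j) {
    const uint32_t qh = (uint32_t)b->qh[0]        | ((uint32_t)b->qh[1] <<  8)
                      | ((uint32_t)b->qh[2] << 16) | ((uint32_t)b->qh[3] << 24);
    const int nib  = (j < 16) ? (b->qs[j] & 0x0F) : (b->qs[j - 16] >> 4); // low/high nibble
    const int bit5 = (qh >> j) & 1;                                       // the 5-th bit
    const int q    = nib | (bit5 << 4);                                   // 5-bit quant, 0..31
    return b->d * (float)(q - 16);                                        // symmetric around 0
}
```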
@@ -38,7 +100,7 @@ kernel void kernel_add_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant   int64_t & nb,
+        constant   int64_t & nb [[buffer(27)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
@@ -63,9 +125,17 @@ kernel void kernel_mul_row(
 }
 
 kernel void kernel_scale(
+        device const float * src0,
+        device       float * dst,
+        constant     float & scale,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_scale_4(
         device const float4 * src0,
         device       float4 * dst,
-        constant     float  & scale,
+        constant     float & scale,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }
@@ -85,6 +155,13 @@ kernel void kernel_relu(
     dst[tpig] = max(0.0f, src0[tpig]);
 }
 
+kernel void kernel_sqr(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src0[tpig];
+}
+
 constant float GELU_COEF_A     = 0.044715f;
 constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
 
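For reference, the two constants shown as context here feed ggml's usual tanh approximation of GELU; written out (a standard identity, not something this diff changes):

```latex
\mathrm{GELU}(x) \approx \tfrac{1}{2}\,x\left(1 + \tanh\!\Bigl(\sqrt{\tfrac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)\Bigr)\right),
\qquad \sqrt{\tfrac{2}{\pi}} \approx 0.7978845608028653559
```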
@@ -107,36 +184,73 @@ kernel void kernel_soft_max(
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
     device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    float lmax = psrc0[tpitg[0]];
-    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+    float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
+
+    for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
         lmax = MAX(lmax, psrc0[i00]);
     }
-    const float max = simd_max(lmax);
+
+    float max = simd_max(lmax);
+    if (tiisg == 0) {
+        buf[sgitg] = max;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    max = buf[0];
 
     // parallel sum
     float lsum = 0.0f;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         const float exp_psrc0 = exp(psrc0[i00] - max);
         lsum += exp_psrc0;
         // Remember the result of exp here. exp is expensive, so we really do not
-        // whish to compute it twice.
+        // wish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
-    const float sum = simd_sum(lsum);
+    float sum = simd_sum(lsum);
+    if (tiisg == 0) {
+        buf[sgitg] = sum;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] += buf[tpitg + i];
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sum = buf[0];
 
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         pdst[i00] /= sum;
     }
 }
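The rewritten soft_max reduces in two stages: each 32-lane SIMD group folds its values in registers (`simd_max`, later `simd_sum`), lane 0 of every group parks one partial in the threadgroup buffer `buf`, and a short tree reduction combines the per-group partials before every thread reads back `buf[0]`. A serial C model of that shape, with loops standing in for the Metal intrinsics (threadgroup size of 128 is an assumption; illustration only):

```c
#include <math.h>

#define NTG        128  // threads per threadgroup (assumed for the sketch)
#define SIMD_WIDTH  32  // lanes per SIMD group

// serial model of the two-stage max reduction over vals[0..NTG-1]
static float two_stage_max(const float vals[NTG]) {
    float buf[NTG / SIMD_WIDTH]; // one slot per SIMD group, like buf[[threadgroup(0)]]

    // stage 1: each "SIMD group" reduces its 32 lanes (simd_max in Metal)
    for (int sg = 0; sg < NTG / SIMD_WIDTH; ++sg) {
        float m = -INFINITY;
        for (int lane = 0; lane < SIMD_WIDTH; ++lane) {
            m = fmaxf(m, vals[sg * SIMD_WIDTH + lane]);
        }
        buf[sg] = m; // "if (tiisg == 0) { buf[sgitg] = max; }"
    }

    // stage 2: tree-reduce the per-group partials ("for (uint i = ntg / 32 / 2; ...)")
    for (int i = NTG / SIMD_WIDTH / 2; i > 0; i /= 2) {
        for (int t = 0; t < i; ++t) {
            buf[t] = fmaxf(buf[t], buf[t + i]);
        }
    }
    return buf[0]; // what every thread reads back as "max = buf[0];"
}
```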
@@ -147,37 +261,73 @@ kernel void kernel_soft_max_4(
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
     device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
-    float4 lmax4 = psrc4[tpitg[0]];
-    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+    float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
+
+    for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
         lmax4 = fmax(lmax4, psrc4[i00]);
     }
-    float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
-    const float max = simd_max(lmax);
+    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+    float max = simd_max(lmax);
+    if (tiisg == 0) {
+        buf[sgitg] = max;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    max = buf[0];
 
     // parallel sum
     float4 lsum4 = 0.0f;
-    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         const float4 exp_psrc4 = exp(psrc4[i00] - max);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
-    float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
-    const float sum = simd_sum(lsum);
+    const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+    float sum = simd_sum(lsum);
+    if (tiisg == 0) {
+        buf[sgitg] = sum;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] += buf[tpitg + i];
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sum = buf[0];
 
-    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         pdst4[i00] /= sum;
     }
 }
@@ -197,7 +347,7 @@ kernel void kernel_diag_mask_inf(
         dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
     } else {
         dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
-     }
+    }
 }
 
 kernel void kernel_diag_mask_inf_8(
@@ -291,10 +441,11 @@ kernel void kernel_rms_norm(
         uint sgitg[[simdgroup_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint   ntg[[threads_per_threadgroup]]) {
-    device const float4 * x = (device const float4 *)((device const char *) src0 + tgpig*nb01);
-    device const float * x_scalar = (device const float *) x;
-    float4 sumf=0;
-    float all_sum=0;
+    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+    device const float * x_scalar = (device const float *) x;
+
+    float4 sumf = 0;
+    float all_sum = 0;
 
     // parallel sum
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
@@ -307,6 +458,7 @@ kernel void kernel_rms_norm(
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
+
    // broadcast, simd group number is ntg / 32
    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
        if (tpitg < i) {
@@ -314,7 +466,9 @@ kernel void kernel_rms_norm(
        }
    }
    if (tpitg == 0) {
-        for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
+        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
+            sum[0] += x_scalar[i];
+        }
         sum[0] /= ne00;
    }
 
@@ -329,7 +483,9 @@ kernel void kernel_rms_norm(
         y[i00] = x[i00] * scale;
     }
     if (tpitg == 0) {
-        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
+        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
+            y_scalar[i00] = x_scalar[i00] * scale;
+        }
     }
 }
 
@@ -339,8 +495,11 @@ kernel void kernel_rms_norm(
 // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
 inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
+
     float2 acc = 0.f;
+
     device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+
     for (int i = 0; i < 8; i+=2) {
         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                 + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -357,8 +516,11 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
 inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
     float m = qb_curr->m;
-    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
     float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
     for (int i = 0; i < 8; i+=2) {
         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                 + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -368,9 +530,52 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
     return d * (acc[0] + acc[1]) + sumy * m;
 }
 
+// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
+    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+    return d * (sumy * -16.f + acc[0] + acc[1]);
+}
+
+// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_1/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
+    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+    return d * (acc[0] + acc[1]) + sumy * m;
+}
+
 // putting them in the kernel cause a significant performance penalty
-#define N_DST 4
-#define N_SIMDGROUP 2
+#define N_DST 4        // each SIMD group works on 4 rows
+#define N_SIMDGROUP 2  // number of SIMD groups in a thread group
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 //Note: This is a template, but strictly speaking it only applies to
 //      quantizations where the block size is 32. It also does not
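The q5 dot products above keep each nibble in place inside a 16-bit word (masks 0x000F, 0x0F00, 0x00F0, 0xF000) because the `yl` values were pre-scaled by 1, 1/256, 1/16 and 1/4096; the fifth bit is then OR-ed in at bit 4, 12, 8 or 16 of the same word. A small C check of the low-slot reconstruction, using the same expression as the kernel (the input values are hypothetical, for illustration):

```c
#include <assert.h>
#include <stdint.h>

// reconstruct quant (i + il) in place in the low-nibble slot, exactly as the
// Metal code does: nibble from qs, 5-th bit OR-ed in at bit 4
static int q5_low_slot(uint16_t qs_word, uint32_t qh, int i, int il) {
    return (qs_word & 0x000F) | ((qh >> (i + il) << 4) & 0x00010);
}

int main(void) {
    // quant value 0b10110 (= 22): low nibble 0b0110 in qs, 5-th bit set in qh
    const uint16_t qs_word = 0x0006;  // nibble 6 sitting in the low slot
    const uint32_t qh      = 1u << 3; // 5-th bit of quant index 3
    assert(q5_low_slot(qs_word, qh, 3, 0) == 22); // 16 + 6
    return 0;
}
```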
@@ -381,18 +586,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
                     int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
                     uint3 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
+
     const int first_row = (r0 * nsg + sgitg) * nr;
+
     const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+
     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
     device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
-    float yl[16];       // src1 vector cache
-    float sumf[nr]={0.f};
 
-    const int ix = tiisg/2;
-    const int il = 8*(tiisg%2);
+    float yl[16]; // src1 vector cache
+    float sumf[nr] = {0.f};
+
+    const int ix = (tiisg/2);
+    const int il = (tiisg%2)*8;
 
     device const float * yb = y + ix * QK4_0 + il;
 
@@ -403,6 +613,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
             sumy += yb[i] + yb[i+1];
             yl[i+0] = yb[i+ 0];
             yl[i+1] = yb[i+ 1]/256.f;
+
             sumy += yb[i+16] + yb[i+17];
             yl[i+8] = yb[i+16]/16.f;
             yl[i+9] = yb[i+17]/4096.f;
@@ -418,12 +629,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
     for (int row = 0; row < nr; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0 && first_row + row < ne01) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+            dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
         }
     }
 }
 
-kernel void kernel_mul_mat_q4_0_f32(
+kernel void kernel_mul_mv_q4_0_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
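In `mul_vec_q_n_f32`, each threadgroup holds N_SIMDGROUP SIMD groups and each SIMD group owns N_DST consecutive output rows; inside a group, `ix = tiisg/2` picks the quant block a lane starts at and `il = (tiisg%2)*8` picks which 8-float half of it. A C sketch of just that index math (all values illustrative):

```c
#include <stdio.h>

#define N_DST       4 // rows per SIMD group (from the defines above)
#define N_SIMDGROUP 2 // SIMD groups per threadgroup

// mirror of the row/block assignment in mul_vec_q_n_f32; r0 plays tgpig.x,
// sgitg the SIMD-group index, tiisg the lane index within the group
int main(void) {
    const int r0 = 0; // first threadgroup along x (illustrative)
    for (int sgitg = 0; sgitg < N_SIMDGROUP; ++sgitg) {
        const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
        for (int tiisg = 0; tiisg < 32; tiisg += 8) {
            const int ix = tiisg / 2;       // quant block this lane starts at
            const int il = (tiisg % 2) * 8; // first or second 8-float slice
            printf("sg %d lane %2d: rows %d..%d, block %d, offset %d\n",
                   sgitg, tiisg, first_row, first_row + N_DST - 1, ix, il);
        }
    }
    return 0;
}
```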
@@ -459,9 +670,46 @@ kernel void kernel_mul_mat_q4_1_f32(
     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
+kernel void kernel_mul_mv_q5_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
+        constant   uint    & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q5_1_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
+        constant   uint    & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+
 #define NB_Q8_0 8
 
-kernel void kernel_mul_mat_q8_0_f32(
+kernel void kernel_mul_mv_q8_0_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -525,7 +773,7 @@ kernel void kernel_mul_mat_q8_0_f32(
 
 #define N_F32_F32 4
 
-kernel void kernel_mul_mat_f32_f32(
+kernel void kernel_mul_mv_f32_f32(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -596,7 +844,7 @@ kernel void kernel_mul_mat_f32_f32(
     }
 }
 
-kernel void kernel_mul_mat_f16_f32_1row(
+kernel void kernel_mul_mv_f16_f32_1row(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -615,7 +863,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
@@ -650,7 +898,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
 
 #define N_F16_F32 4
 
-kernel void kernel_mul_mat_f16_f32(
+kernel void kernel_mul_mv_f16_f32(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -722,7 +970,7 @@ kernel void kernel_mul_mat_f16_f32(
 }
 
 // Assumes row size (ne00) is a multiple of 4
-kernel void kernel_mul_mat_f16_f32_l4(
+kernel void kernel_mul_mv_f16_f32_l4(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -783,7 +1031,9 @@ kernel void kernel_alibi_f32(
         constant  uint64_t & nb1,
         constant  uint64_t & nb2,
         constant  uint64_t & nb3,
-        constant      float & m0,
+        constant     float & m0,
+        constant     float & m1,
+        constant       int & n_heads_log2_floor,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
@@ -799,37 +1049,122 @@ kernel void kernel_alibi_f32(
     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
 
     device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-    float m_k = pow(m0, i2 + 1);
+    float m_k;
+    if (i2 < n_heads_log2_floor) {
+        m_k = pow(m0, i2 + 1);
+    } else {
+        m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
+    }
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
         device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
         dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
     }
 }
 
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+    thread float * cos_theta, thread float * sin_theta
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+    }
+    *cos_theta = cos(theta) * mscale;
+    *sin_theta = sin(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
+    return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+static void rope_yarn_corr_dims(
+    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+    // start and end correction dims
+    dims[0] = max(0.0f,          floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f,  ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
+
+typedef void (rope_t)(
+        device const    void * src0,
+        device const int32_t * src1,
+        device         float * dst,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant     int64_t & ne03,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant    uint64_t & nb03,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & ne2,
+        constant     int64_t & ne3,
+        constant    uint64_t & nb0,
+        constant    uint64_t & nb1,
+        constant    uint64_t & nb2,
+        constant    uint64_t & nb3,
+        constant         int & n_past,
+        constant         int & n_dims,
+        constant         int & mode,
+        constant         int & n_orig_ctx,
+        constant       float & freq_base,
+        constant       float & freq_scale,
+        constant       float & ext_factor,
+        constant       float & attn_factor,
+        constant       float & beta_fast,
+        constant       float & beta_slow,
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint3  tptg[[threads_per_threadgroup]],
+        uint3 tgpig[[threadgroup_position_in_grid]]);
+
+template<typename T>
 kernel void kernel_rope(
-        device const    void * src0,
-        device         float * dst,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant     int64_t & ne03,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant    uint64_t & nb03,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant     int64_t & ne2,
-        constant     int64_t & ne3,
-        constant    uint64_t & nb0,
-        constant    uint64_t & nb1,
-        constant    uint64_t & nb2,
-        constant    uint64_t & nb3,
-        constant         int & n_past,
-        constant         int & n_dims,
-        constant         int & mode,
-        constant       float & freq_base,
-        constant       float & freq_scale,
+        device const    void * src0,
+        device const int32_t * src1,
+        device         float * dst,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant     int64_t & ne03,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant    uint64_t & nb03,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & ne2,
+        constant     int64_t & ne3,
+        constant    uint64_t & nb0,
+        constant    uint64_t & nb1,
+        constant    uint64_t & nb2,
+        constant    uint64_t & nb3,
+        constant         int & n_past,
+        constant         int & n_dims,
+        constant         int & mode,
+        constant         int & n_orig_ctx,
+        constant       float & freq_base,
+        constant       float & freq_scale,
+        constant       float & ext_factor,
+        constant       float & attn_factor,
+        constant       float & beta_fast,
+        constant       float & beta_slow,
         uint  tiitg[[thread_index_in_threadgroup]],
         uint3  tptg[[threads_per_threadgroup]],
         uint3 tgpig[[threadgroup_position_in_grid]]) {
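The comment above `rope_yarn_corr_factor` compresses a short derivation. In the code's notation, dimension pair i of an n_dims-dimensional RoPE rotates with wavelength 2π·base^(2i/n_dims); asking which dimension completes exactly n_rot rotations over the original context length gives (one way to unpack the comment, matching the returned expression):

```latex
\frac{n_{\mathrm{orig\_ctx}}}{2\pi \cdot \mathrm{base}^{\,2i/n_{\mathrm{dims}}}} = n_{\mathrm{rot}}
\quad\Longrightarrow\quad
i = \frac{n_{\mathrm{dims}}}{2\ln(\mathrm{base})}\,
    \ln\!\left(\frac{n_{\mathrm{orig\_ctx}}}{2\pi\,n_{\mathrm{rot}}}\right)
```

which is exactly `n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base))`.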
@@ -839,23 +1174,28 @@ kernel void kernel_rope(
 
     const bool is_neox = mode & 2;
 
-    const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
-    const float theta_0 = freq_scale * (float)p;
+    device const int32_t * pos = src1;
+
+    const int64_t p = pos[i2];
+
+    const float theta_0 = (float)p;
     const float inv_ndims = -1.f/n_dims;
 
     if (!is_neox) {
         for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
 
             const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
-            const float cos_theta = cos(theta);
-            const float sin_theta = sin(theta);
+            float cos_theta, sin_theta;
+            rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
-            device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-            device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
-            const float x0 = src[0];
-            const float x1 = src[1];
+            const T x0 = src[0];
+            const T x1 = src[1];
 
             dst_data[0] = x0*cos_theta - x1*sin_theta;
             dst_data[1] = x0*sin_theta + x1*cos_theta;
@@ -864,14 +1204,17 @@ kernel void kernel_rope(
         for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
             for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
 
-                const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
-                const float cos_theta = cos(theta);
-                const float sin_theta = sin(theta);
+                // simplified from `(ib * n_dims + ic) * inv_ndims`
+                const float cur_rot = inv_ndims*ic - ib;
+
+                const float theta = theta_0 * pow(freq_base, cur_rot);
+                float cos_theta, sin_theta;
+                rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
                 const int64_t i0 = ib*n_dims + ic/2;
 
-                device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                 const float x0 = src[0];
                 const float x1 = src[n_dims/2];
@@ -883,6 +1226,9 @@ kernel void kernel_rope(
     }
 }
 
+template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
+template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
+
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device       half * dst,
@@ -1008,6 +1354,62 @@ kernel void kernel_cpy_f32_f32(
     }
 }
 
+kernel void kernel_concat(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        if (i02 < ne02) {
+            ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
+            src0_ptr += ntg.x*nb00;
+        } else {
+            ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
+            src1_ptr += ntg.x*nb10;
+        }
+        dst_ptr += ntg.x*nb0;
+    }
+}
+
 //============================================ k-quants ======================================================
 
 #ifndef QK_K
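`kernel_concat` writes src0 rows while `i02 < ne02` and src1 rows after that, i.e. it concatenates along dim 2, as `ggml_concat` does. A simplified C reference for contiguous float tensors — the Metal kernel additionally handles arbitrary strides via the nb* arguments, and its modulo indexing coincides with the subtraction below for the shapes ggml_concat produces:

```c
#include <stddef.h>

// reference semantics: dst = concat(src0, src1) along dim 2, contiguous floats
static void concat_dim2(const float *src0, const float *src1, float *dst,
                        size_t ne0, size_t ne1, size_t ne02, size_t ne12) {
    for (size_t i2 = 0; i2 < ne02 + ne12; ++i2) {
        for (size_t i1 = 0; i1 < ne1; ++i1) {
            for (size_t i0 = 0; i0 < ne0; ++i0) {
                const size_t o = (i2 * ne1 + i1) * ne0 + i0;
                dst[o] = (i2 < ne02)
                    ? src0[(i2 * ne1 + i1) * ne0 + i0]
                    : src1[((i2 - ne02) * ne1 + i1) * ne0 + i0];
            }
        }
    }
}
```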
@@ -1100,7 +1502,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
 
 //====================================== dot products =========================
 
-kernel void kernel_mul_mat_q2_K_f32(
+kernel void kernel_mul_mv_q2_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1244,7 +1646,7 @@ kernel void kernel_mul_mat_q2_K_f32(
 }
 
 #if QK_K == 256
-kernel void kernel_mul_mat_q3_K_f32(
+kernel void kernel_mul_mv_q3_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1273,8 +1675,8 @@ kernel void kernel_mul_mat_q3_K_f32(
 
     float yl[32];
 
-    const uint16_t kmask1 = 0x3030;
-    const uint16_t kmask2 = 0x0f0f;
+    //const uint16_t kmask1 = 0x3030;
+    //const uint16_t kmask2 = 0x0f0f;
 
     const int tid = tiisg/4;
     const int ix  = tiisg%4;
@@ -1396,7 +1798,7 @@ kernel void kernel_mul_mat_q3_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mat_q3_K_f32(
+kernel void kernel_mul_mv_q3_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1467,7 +1869,7 @@ kernel void kernel_mul_mat_q3_K_f32(
 #endif
 
 #if QK_K == 256
-kernel void kernel_mul_mat_q4_K_f32(
+kernel void kernel_mul_mv_q4_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1573,7 +1975,7 @@ kernel void kernel_mul_mat_q4_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mat_q4_K_f32(
+kernel void kernel_mul_mv_q4_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1662,7 +2064,7 @@ kernel void kernel_mul_mat_q4_K_f32(
 }
 #endif
 
-kernel void kernel_mul_mat_q5_K_f32(
+kernel void kernel_mul_mv_q5_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1835,7 +2237,7 @@ kernel void kernel_mul_mat_q5_K_f32(
 
 }
 
-kernel void kernel_mul_mat_q6_K_f32(
+kernel void kernel_mul_mv_q6_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1984,6 +2386,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     }
 }
 
+template <typename type4x4>
+void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+    const float d = xb->d;
+    const float md = -16.h * xb->d;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + md;
+        reg[i/2][2*(i%2)+1] = d * x1 + md;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+    const float d = xb->d;
+    const float m = xb->m;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + m;
+        reg[i/2][2*(i%2)+1] = d * x1 + m;
+    }
+}
+
 template <typename type4x4>
 void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
     device const int8_t * qs = ((device const int8_t *)xb->qs);
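The two dequantizers above walk half a block per call: `il` selects the low or high nibble of every byte pair, and `gh_mv`/`gh_bk` line the matching fifth bits up at bit 4. The same mechanics in plain C, mirroring the Metal template line by line (fp16 handling elided; a model for reading the code, not a drop-in replacement):

```c
#include <stdint.h>

// scalar model of dequantize_q5_0 for one half-block (il = 0 or 1)
static void dequant_q5_0_half(const uint16_t qs[8], uint32_t qh, int il,
                              float d, float out[16]) {
    const uint16_t mask  = il ? 0x00F0 : 0x000F; // which nibble of each byte
    const int      x_mv  = il ? 4 : 0;           // shift that nibble down to bits 0..3
    const int      gh_mv = il ? 12 : 0;          // first 5-th bit of this half
    const int      gh_bk = il ? 0 : 4;           // bring it to bit position 4
    const float    md    = -16.0f * d;           // so d*q + md == d*(q - 16)

    for (int i = 0; i < 8; i++) {
        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i    )) << gh_bk) & 0x10;
        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i + 1)) << gh_bk) & 0x10;

        const int32_t x0 = (((qs[i]     ) & mask) >> x_mv) | xh_0;
        const int32_t x1 = (((qs[i] >> 8) & mask) >> x_mv) | xh_1;

        out[2*i + 0] = d * (float)x0 + md;
        out[2*i + 1] = d * (float)x1 + md;
    }
}
```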
@@ -2173,7 +2631,7 @@ kernel void kernel_get_rows(
 }
 
 #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
-#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
+#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
 #define BLOCK_SIZE_K 32
 #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
 #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
@@ -2210,9 +2668,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
     const uint r0 = tgpig.y;
     const uint r1 = tgpig.x;
     const uint im = tgpig.z;
+
     // if this block is of 64x32 shape or smaller
     short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
     short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
     // a thread shouldn't load data outside of the matrix
     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
     short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
@@ -2236,26 +2696,30 @@ kernel void kernel_mul_mm(device const uchar * src0,
                            + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
-        //load data and store to threadgroup memory
+        // load data and store to threadgroup memory
         half4x4 temp_a;
         dequantize_func(x, il, temp_a);
         threadgroup_barrier(mem_flags::mem_threadgroup);
+
         #pragma unroll(16)
         for (int i = 0; i < 16; i++) {
             *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
-            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
-            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
+                                + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+                                + (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
         }
-        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
-                = *((device float2x4 *)y);
+
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
         il = (il + 2 < nl) ? il + 2 : il % 2;
         x  = (il < 2) ? x + (2+nl-1)/nl : x;
         y += BLOCK_SIZE_K;
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        //load matrices from threadgroup memory and conduct outer products
+
+        // load matrices from threadgroup memory and conduct outer products
         threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
         threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
         #pragma unroll(4)
         for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
             #pragma unroll(4)
@@ -2270,6 +2734,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
 
             lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
             lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
             #pragma unroll(8)
             for (int i = 0; i < 8; i++){
                 simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
@@ -2278,25 +2743,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
     }
 
     if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
-        device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
-                        + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
+        device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg &  1)) \
+                               + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
         for (int i = 0; i < 8; i++) {
             simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
         }
     } else {
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
+        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
         for (int i = 0; i < 8; i++) {
             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
        }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
-        if (sgitg==0) {
+
+        device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+        if (sgitg == 0) {
             for (int i = 0; i < n_rows; i++) {
-                for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
+                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                     *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                 }
             }
@@ -2317,6 +2783,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
 template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -2345,6 +2813,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
 template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
 template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;