llama_cpp 0.14.6 → 0.14.7
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/ext/llama_cpp/llama_cpp.cpp +37 -2
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -0
- data/vendor/tmp/llama.cpp/Makefile +11 -6
- data/vendor/tmp/llama.cpp/ggml-alloc.c +8 -8
- data/vendor/tmp/llama.cpp/ggml-backend.c +14 -10
- data/vendor/tmp/llama.cpp/ggml-impl.h +262 -4
- data/vendor/tmp/llama.cpp/ggml-quants.c +0 -293
- data/vendor/tmp/llama.cpp/ggml.c +3 -17
- data/vendor/tmp/llama.cpp/llama.cpp +379 -66
- data/vendor/tmp/llama.cpp/llama.h +19 -6
- data/vendor/tmp/llama.cpp/sgemm.cpp +404 -553
- metadata +2 -2
@@ -50,6 +50,7 @@
 #pragma GCC diagnostic ignored "-Wignored-attributes"
 
 #include "sgemm.h"
+#include <algorithm>
 #include "ggml-impl.h"
 #include "ggml-quants.h"
 
@@ -65,22 +66,6 @@
 #define VECTOR_REGISTERS 16
 #endif
 
-// there will be blocks
-#define BEGIN_KERNEL(RM, RN) \
-    int ytiles = (m - m0) / RM; \
-    int xtiles = (n - n0) / RN; \
-    int tiles = ytiles * xtiles; \
-    int duty = (tiles + nth - 1) / nth; \
-    int start = duty * ith; \
-    int end = start + duty; \
-    if (end > tiles) \
-        end = tiles; \
-    for (int job = start; job < end; ++job) { \
-        int i = m0 + job / xtiles * RM; \
-        int j = n0 + job % xtiles * RN;
-
-#define END_KERNEL() }
-
 #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
 
 namespace {
@@ -122,6 +107,45 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
 inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// VECTORIZED FUSED MULTIPLY ADD
+
+/**
+ * Computes a * b + c.
+ */
+template <typename T, typename U>
+inline U madd(T a, T b, U c) {
+    return add(mul(a, b), c);
+}
+
+#if defined(__FMA__)
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+template <>
+inline __m256 madd(__m256 a, __m256 b, __m256 c) {
+    return _mm256_fmadd_ps(a, b, c);
+}
+#endif
+#if defined(__AVX512F__)
+template <>
+inline __m512 madd(__m512 a, __m512 b, __m512 c) {
+    return _mm512_fmadd_ps(a, b, c);
+}
+#endif
+#endif
+
+#if defined(__ARM_FEATURE_FMA)
+template <>
+inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
+    return vfmaq_f32(c, b, a);
+}
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
+template <>
+inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
+    return vfmaq_f16(c, b, a);
+}
+#endif
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED HORIZONTAL SUM
 
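The hunk above moves the fused multiply-add helper ahead of the kernels and specializes it per ISA, so `madd()` compiles to one FMA instruction where the target supports it. Below is a standalone sketch of that generic-template-plus-specialization pattern; it is not part of the gem or of sgemm.cpp, and the scalar `add`/`mul` helpers and `main()` are invented purely for illustration.

```cpp
// Illustrative sketch only: mirrors the madd() pattern in the diff above.
#include <cstdio>

static inline float add(float x, float y) { return x + y; }   // hypothetical scalar fallbacks
static inline float mul(float x, float y) { return x * y; }

template <typename T, typename U>
inline U madd(T a, T b, U c) {
    return add(mul(a, b), c);          // generic path: separate multiply and add
}

#if defined(__FMA__)
#include <immintrin.h>
// Specialization: a single fused _mm256_fmadd_ps replaces mul+add for AVX vectors.
template <>
inline __m256 madd(__m256 a, __m256 b, __m256 c) {
    return _mm256_fmadd_ps(a, b, c);
}
#endif

int main() {
    std::printf("%g\n", madd(2.0f, 3.0f, 1.0f));  // 2*3 + 1 = 7 via the generic template
}
```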
@@ -213,36 +237,6 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
 }
 #endif // __AVX512F__
 
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// ABSTRACTIONS
-
-/**
- * Computes a * b + c.
- *
- * This operation will become fused into a single arithmetic instruction
- * if the hardware has support for this feature, e.g. Intel Haswell+ (c.
- * 2013), AMD Bulldozer+ (c. 2011), etc.
- */
-template <typename T, typename U>
-inline U madd(T a, T b, U c) {
-    return add(mul(a, b), c);
-}
-
-/**
- * Computes a * b + c with error correction.
- *
- * @see W. Kahan, "Further remarks on reducing truncation errors,"
- * Communications of the ACM, vol. 8, no. 1, p. 40, Jan. 1965,
- * doi: 10.1145/363707.363723.
- */
-template <typename T, typename U>
-inline U madder(T a, T b, U c, U *e) {
-    U y = sub(mul(a, b), *e);
-    U t = add(c, y);
-    *e = sub(sub(t, c), y);
-    return t;
-}
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // FLOATING POINT MATRIX MULTIPLICATION
 
@@ -265,226 +259,179 @@ class tinyBLAS {
   private:
     NOINLINE void mnpack(int m0, int m, int n0, int n) {
         int mc, nc, mp, np;
-
-
-
+        switch ((std::min(m - m0, 5) << 4) | std::min(n - n0, 5)) {
+#if VECTOR_REGISTERS == 32
+        case 0x55:
+            mc = 5;
+            nc = 5;
+            gemm<5, 5>(m0, m, n0, n);
+            break;
+        case 0x45:
+            mc = 4;
+            nc = 5;
+            gemm<4, 5>(m0, m, n0, n);
+            break;
+        case 0x54:
             mc = 5;
+            nc = 4;
+            gemm<5, 4>(m0, m, n0, n);
+            break;
+        case 0x44:
+            mc = 4;
+            nc = 4;
+            gemm<4, 4>(m0, m, n0, n);
+            break;
+        case 0x53:
+            mc = 5;
+            nc = 3;
+            gemm<5, 3>(m0, m, n0, n);
+            break;
+        case 0x35:
+            mc = 3;
             nc = 5;
-
-
+            gemm<3, 5>(m0, m, n0, n);
+            break;
+        case 0x43:
+            mc = 4;
+            nc = 3;
+            gemm<4, 3>(m0, m, n0, n);
+            break;
+#else
+        case 0x55:
+        case 0x54:
+        case 0x53:
+        case 0x45:
+        case 0x44:
+        case 0x43:
+            mc = 4;
+            nc = 3;
+            gemm<4, 3>(m0, m, n0, n);
+            break;
+        case 0x35:
+#endif
+        case 0x34:
             mc = 3;
             nc = 4;
-
-
-
+            gemm<3, 4>(m0, m, n0, n);
+            break;
+        case 0x52:
+            mc = 5;
+            nc = 2;
+            gemm<5, 2>(m0, m, n0, n);
+            break;
+        case 0x33:
+            mc = 3;
+            nc = 3;
+            gemm<3, 3>(m0, m, n0, n);
+            break;
+        case 0x25:
+            mc = 2;
+            nc = 5;
+            gemm<2, 5>(m0, m, n0, n);
+            break;
+        case 0x42:
+            mc = 4;
+            nc = 2;
+            gemm<4, 2>(m0, m, n0, n);
+            break;
+        case 0x24:
+            mc = 2;
             nc = 4;
-
-
+            gemm<2, 4>(m0, m, n0, n);
+            break;
+        case 0x32:
+            mc = 3;
+            nc = 2;
+            gemm<3, 2>(m0, m, n0, n);
+            break;
+        case 0x23:
+            mc = 2;
+            nc = 3;
+            gemm<2, 3>(m0, m, n0, n);
+            break;
+        case 0x51:
+            mc = 5;
+            nc = 1;
+            gemm<5, 1>(m0, m, n0, n);
+            break;
+        case 0x41:
             mc = 4;
             nc = 1;
-
-
+            gemm<4, 1>(m0, m, n0, n);
+            break;
+        case 0x22:
+            mc = 2;
+            nc = 2;
+            gemm<2, 2>(m0, m, n0, n);
+            break;
+        case 0x15:
+            mc = 1;
+            nc = 5;
+            gemm<1, 5>(m0, m, n0, n);
+            break;
+        case 0x14:
+            mc = 1;
+            nc = 4;
+            gemm<1, 4>(m0, m, n0, n);
+            break;
+        case 0x31:
+            mc = 3;
+            nc = 1;
+            gemm<3, 1>(m0, m, n0, n);
+            break;
+        case 0x13:
             mc = 1;
+            nc = 3;
+            gemm<1, 3>(m0, m, n0, n);
+            break;
+        case 0x21:
+            mc = 2;
             nc = 1;
-
+            gemm<2, 1>(m0, m, n0, n);
+            break;
+        case 0x12:
+            mc = 1;
+            nc = 2;
+            gemm<1, 2>(m0, m, n0, n);
+            break;
+        case 0x11:
+            mc = 1;
+            nc = 1;
+            gemm<1, 1>(m0, m, n0, n);
+            break;
+        default:
+            return;
         }
         mp = m0 + (m - m0) / mc * mc;
         np = n0 + (n - n0) / nc * nc;
         mnpack(mp, m, n0, np);
-        mnpack(m0,
-
-
-
-    NOINLINE void
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        D c40 = {0};
-        D c41 = {0};
-        D c42 = {0};
-        D c43 = {0};
-        D c44 = {0};
-        for (int l = 0; l < k; l += KN) {
-            V k0 = load<V>(B + ldb * (j + 0) + l);
-            V k1 = load<V>(B + ldb * (j + 1) + l);
-            V k2 = load<V>(B + ldb * (j + 2) + l);
-            V k3 = load<V>(B + ldb * (j + 3) + l);
-            V k4 = load<V>(B + ldb * (j + 4) + l);
-            V a0 = load<V>(A + lda * (i + 0) + l);
-            c00 = madd(a0, k0, c00);
-            c01 = madd(a0, k1, c01);
-            c02 = madd(a0, k2, c02);
-            c03 = madd(a0, k3, c03);
-            c04 = madd(a0, k4, c04);
-            V a1 = load<V>(A + lda * (i + 1) + l);
-            c10 = madd(a1, k0, c10);
-            c11 = madd(a1, k1, c11);
-            c12 = madd(a1, k2, c12);
-            c13 = madd(a1, k3, c13);
-            c14 = madd(a1, k4, c14);
-            V a2 = load<V>(A + lda * (i + 2) + l);
-            c20 = madd(a2, k0, c20);
-            c21 = madd(a2, k1, c21);
-            c22 = madd(a2, k2, c22);
-            c23 = madd(a2, k3, c23);
-            c24 = madd(a2, k4, c24);
-            V a3 = load<V>(A + lda * (i + 3) + l);
-            c30 = madd(a3, k0, c30);
-            c31 = madd(a3, k1, c31);
-            c32 = madd(a3, k2, c32);
-            c33 = madd(a3, k3, c33);
-            c34 = madd(a3, k4, c34);
-            V a4 = load<V>(A + lda * (i + 4) + l);
-            c40 = madd(a4, k0, c40);
-            c41 = madd(a4, k1, c41);
-            c42 = madd(a4, k2, c42);
-            c43 = madd(a4, k3, c43);
-            c44 = madd(a4, k4, c44);
-        }
-        C[ldc * (j + 0) + (i + 0)] = hsum(c00);
-        C[ldc * (j + 0) + (i + 1)] = hsum(c10);
-        C[ldc * (j + 0) + (i + 2)] = hsum(c20);
-        C[ldc * (j + 0) + (i + 3)] = hsum(c30);
-        C[ldc * (j + 0) + (i + 4)] = hsum(c40);
-        C[ldc * (j + 1) + (i + 0)] = hsum(c01);
-        C[ldc * (j + 1) + (i + 1)] = hsum(c11);
-        C[ldc * (j + 1) + (i + 2)] = hsum(c21);
-        C[ldc * (j + 1) + (i + 3)] = hsum(c31);
-        C[ldc * (j + 1) + (i + 4)] = hsum(c41);
-        C[ldc * (j + 2) + (i + 0)] = hsum(c02);
-        C[ldc * (j + 2) + (i + 1)] = hsum(c12);
-        C[ldc * (j + 2) + (i + 2)] = hsum(c22);
-        C[ldc * (j + 2) + (i + 3)] = hsum(c32);
-        C[ldc * (j + 2) + (i + 4)] = hsum(c42);
-        C[ldc * (j + 3) + (i + 0)] = hsum(c03);
-        C[ldc * (j + 3) + (i + 1)] = hsum(c13);
-        C[ldc * (j + 3) + (i + 2)] = hsum(c23);
-        C[ldc * (j + 3) + (i + 3)] = hsum(c33);
-        C[ldc * (j + 3) + (i + 4)] = hsum(c43);
-        C[ldc * (j + 4) + (i + 0)] = hsum(c04);
-        C[ldc * (j + 4) + (i + 1)] = hsum(c14);
-        C[ldc * (j + 4) + (i + 2)] = hsum(c24);
-        C[ldc * (j + 4) + (i + 3)] = hsum(c34);
-        C[ldc * (j + 4) + (i + 4)] = hsum(c44);
-        END_KERNEL()
-    }
-
-    NOINLINE void gemm3x4(int m0, int m, int n0, int n) {
-        BEGIN_KERNEL(3, 4)
-        D c00 = {0};
-        D c01 = {0};
-        D c02 = {0};
-        D c03 = {0};
-        D c10 = {0};
-        D c11 = {0};
-        D c12 = {0};
-        D c13 = {0};
-        D c20 = {0};
-        D c21 = {0};
-        D c22 = {0};
-        D c23 = {0};
-        for (int l = 0; l < k; l += KN) {
-            V k0 = load<V>(B + ldb * (j + 0) + l);
-            V k1 = load<V>(B + ldb * (j + 1) + l);
-            V k2 = load<V>(B + ldb * (j + 2) + l);
-            V k3 = load<V>(B + ldb * (j + 3) + l);
-            V a0 = load<V>(A + lda * (i + 0) + l);
-            c00 = madd(a0, k0, c00);
-            c01 = madd(a0, k1, c01);
-            c02 = madd(a0, k2, c02);
-            c03 = madd(a0, k3, c03);
-            V a1 = load<V>(A + lda * (i + 1) + l);
-            c10 = madd(a1, k0, c10);
-            c11 = madd(a1, k1, c11);
-            c12 = madd(a1, k2, c12);
-            c13 = madd(a1, k3, c13);
-            V a2 = load<V>(A + lda * (i + 2) + l);
-            c20 = madd(a2, k0, c20);
-            c21 = madd(a2, k1, c21);
-            c22 = madd(a2, k2, c22);
-            c23 = madd(a2, k3, c23);
-        }
-        C[ldc * (j + 0) + (i + 0)] = hsum(c00);
-        C[ldc * (j + 0) + (i + 1)] = hsum(c10);
-        C[ldc * (j + 0) + (i + 2)] = hsum(c20);
-        C[ldc * (j + 1) + (i + 0)] = hsum(c01);
-        C[ldc * (j + 1) + (i + 1)] = hsum(c11);
-        C[ldc * (j + 1) + (i + 2)] = hsum(c21);
-        C[ldc * (j + 2) + (i + 0)] = hsum(c02);
-        C[ldc * (j + 2) + (i + 1)] = hsum(c12);
-        C[ldc * (j + 2) + (i + 2)] = hsum(c22);
-        C[ldc * (j + 3) + (i + 0)] = hsum(c03);
-        C[ldc * (j + 3) + (i + 1)] = hsum(c13);
-        C[ldc * (j + 3) + (i + 2)] = hsum(c23);
-        END_KERNEL()
-    }
-
-    NOINLINE void gemm1x4(int m0, int m, int n0, int n) {
-        BEGIN_KERNEL(1, 4)
-        D c00 = {0}, e00 = {0};
-        D c01 = {0}, e01 = {0};
-        D c02 = {0}, e02 = {0};
-        D c03 = {0}, e03 = {0};
-        for (int l = 0; l < k; l += KN) {
-            V a = load<V>(A + lda * (i + 0) + l);
-            c00 = madder(a, load<V>(B + ldb * (j + 0) + l), c00, &e00);
-            c01 = madder(a, load<V>(B + ldb * (j + 1) + l), c01, &e01);
-            c02 = madder(a, load<V>(B + ldb * (j + 2) + l), c02, &e02);
-            c03 = madder(a, load<V>(B + ldb * (j + 3) + l), c03, &e03);
-        }
-        C[ldc * (j + 0) + (i + 0)] = hsum(c00);
-        C[ldc * (j + 1) + (i + 0)] = hsum(c01);
-        C[ldc * (j + 2) + (i + 0)] = hsum(c02);
-        C[ldc * (j + 3) + (i + 0)] = hsum(c03);
-        END_KERNEL()
-    }
-
-    NOINLINE void gemm4x1(int m0, int m, int n0, int n) {
-        BEGIN_KERNEL(4, 1)
-        D c00 = {0}, e00 = {0};
-        D c10 = {0}, e10 = {0};
-        D c20 = {0}, e20 = {0};
-        D c30 = {0}, e30 = {0};
-        for (int l = 0; l < k; l += KN) {
-            V b = load<V>(B + ldb * (j + 0) + l);
-            c00 = madder(load<V>(A + lda * (i + 0) + l), b, c00, &e00);
-            c10 = madder(load<V>(A + lda * (i + 1) + l), b, c10, &e10);
-            c20 = madder(load<V>(A + lda * (i + 2) + l), b, c20, &e20);
-            c30 = madder(load<V>(A + lda * (i + 3) + l), b, c30, &e30);
+        mnpack(m0, m, np, n);
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int m0, int m, int n0, int n) {
+        int ytiles = (m - m0) / RM;
+        int xtiles = (n - n0) / RN;
+        int tiles = xtiles * ytiles;
+        int duty = (tiles + nth - 1) / nth;
+        int start = duty * ith;
+        int end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int job = start; job < end; ++job) {
+            int ii = m0 + job / xtiles * RM;
+            int jj = n0 + job % xtiles * RN;
+            D Cv[RN][RM] = {};
+            for (int l = 0; l < k; l += KN)
+                for (int j = 0; j < RN; ++j)
+                    for (int i = 0; i < RM; ++i)
+                        Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
+                                        load<V>(B + ldb * (jj + j) + l),
+                                        Cv[j][i]);
+            for (int j = 0; j < RN; ++j)
+                for (int i = 0; i < RM; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
         }
-        C[ldc * (j + 0) + (i + 0)] = hsum(c00);
-        C[ldc * (j + 0) + (i + 1)] = hsum(c10);
-        C[ldc * (j + 0) + (i + 2)] = hsum(c20);
-        C[ldc * (j + 0) + (i + 3)] = hsum(c30);
-        END_KERNEL()
-    }
-
-    NOINLINE void gemm1x1(int m0, int m, int n0, int n) {
-        BEGIN_KERNEL(1, 1)
-        D c = {0}, e = {0};
-        for (int l = 0; l < k; l += KN)
-            c = madder(load<V>(A + lda * i + l),
-                       load<V>(B + ldb * j + l), c, &e);
-        C[ldc * j + i] = hsum(c);
-        END_KERNEL()
     }
 
     const TA *const A;
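In the new mnpack() above, the chain of per-shape kernels is replaced by one switch whose key packs the two clamped remainder dimensions into a single byte: 0x53 means "at least 5 rows and exactly 3 columns remain", which dispatches to gemm<5, 3>. This is also why the first hunk adds #include <algorithm> for std::min. A standalone sketch of just that encoding (dispatch_key() and the sample sizes are hypothetical, not from the library):

```cpp
// Illustrative sketch of the switch key used by tinyBLAS::mnpack() above.
#include <algorithm>
#include <cstdio>

static int dispatch_key(int m_left, int n_left) {
    // clamp each remaining dimension to the largest supported tile (5),
    // then pack: high nibble = rows, low nibble = columns
    return (std::min(m_left, 5) << 4) | std::min(n_left, 5);
}

int main() {
    std::printf("0x%02x\n", dispatch_key(7, 3));  // 0x53 -> would select gemm<5, 3>
    std::printf("0x%02x\n", dispatch_key(2, 1));  // 0x21 -> would select gemm<2, 1>
}
```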
@@ -521,120 +468,97 @@ class tinyBLAS_Q0_ARM {
   private:
     NOINLINE void mnpack(int m0, int m, int n0, int n) {
         int mc, nc, mp, np;
-
-
-        if (m - m0 >= 3 && n - n0 >= 3) {
+        switch ((std::min(m - m0, 3) << 4) | std::min(n - n0, 3)) {
+        case 0x33:
             mc = 3;
             nc = 3;
-
-
+            gemm<3, 3>(m0, m, n0, n);
+            break;
+        case 0x32:
+            mc = 3;
+            nc = 2;
+            gemm<3, 2>(m0, m, n0, n);
+            break;
+        case 0x23:
+            mc = 2;
+            nc = 3;
+            gemm<2, 3>(m0, m, n0, n);
+            break;
+        case 0x22:
+            mc = 2;
+            nc = 2;
+            gemm<2, 2>(m0, m, n0, n);
+            break;
+        case 0x31:
+            mc = 3;
+            nc = 1;
+            gemm<3, 1>(m0, m, n0, n);
+            break;
+        case 0x13:
+            mc = 1;
+            nc = 3;
+            gemm<1, 3>(m0, m, n0, n);
+            break;
+        case 0x21:
+            mc = 2;
+            nc = 1;
+            gemm<2, 1>(m0, m, n0, n);
+            break;
+        case 0x12:
+            mc = 1;
+            nc = 2;
+            gemm<1, 2>(m0, m, n0, n);
+            break;
+        case 0x11:
             mc = 1;
             nc = 1;
-
+            gemm<1, 1>(m0, m, n0, n);
+            break;
+        default:
+            return;
         }
         mp = m0 + (m - m0) / mc * mc;
         np = n0 + (n - n0) / nc * nc;
         mnpack(mp, m, n0, np);
-        mnpack(m0,
-
-
-
-    NOINLINE void
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                unhalf(Ap0[l].d) * unhalf(Bp1[l].d));
-            c02 = vmlaq_n_f32(
-                c02,
-                vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp2 + l)),
-                                        load_hi(Ap0 + l), load_hi(Bp2 + l))),
-                unhalf(Ap0[l].d) * unhalf(Bp2[l].d));
-            c10 = vmlaq_n_f32(
-                c10,
-                vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp0 + l)),
-                                        load_hi(Ap1 + l), load_hi(Bp0 + l))),
-                unhalf(Ap1[l].d) * unhalf(Bp0[l].d));
-            c11 = vmlaq_n_f32(
-                c11,
-                vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp1 + l)),
-                                        load_hi(Ap1 + l), load_hi(Bp1 + l))),
-                unhalf(Ap1[l].d) * unhalf(Bp1[l].d));
-            c12 = vmlaq_n_f32(
-                c12,
-                vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp2 + l)),
-                                        load_hi(Ap1 + l), load_hi(Bp2 + l))),
-                unhalf(Ap1[l].d) * unhalf(Bp2[l].d));
-            c20 = vmlaq_n_f32(
-                c20,
-                vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp0 + l)),
-                                        load_hi(Ap2 + l), load_hi(Bp0 + l))),
-                unhalf(Ap2[l].d) * unhalf(Bp0[l].d));
-            c21 = vmlaq_n_f32(
-                c21,
-                vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp1 + l)),
-                                        load_hi(Ap2 + l), load_hi(Bp1 + l))),
-                unhalf(Ap2[l].d) * unhalf(Bp1[l].d));
-            c22 = vmlaq_n_f32(
-                c22,
-                vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp2 + l)),
-                                        load_hi(Ap2 + l), load_hi(Bp2 + l))),
-                unhalf(Ap2[l].d) * unhalf(Bp2[l].d));
-        }
-        C[ldc * (j + 0) + (i + 0)] = hsum(c00);
-        C[ldc * (j + 0) + (i + 1)] = hsum(c10);
-        C[ldc * (j + 0) + (i + 2)] = hsum(c20);
-        C[ldc * (j + 1) + (i + 0)] = hsum(c01);
-        C[ldc * (j + 1) + (i + 1)] = hsum(c11);
-        C[ldc * (j + 1) + (i + 2)] = hsum(c21);
-        C[ldc * (j + 2) + (i + 0)] = hsum(c02);
-        C[ldc * (j + 2) + (i + 1)] = hsum(c12);
-        C[ldc * (j + 2) + (i + 2)] = hsum(c22);
-        END_KERNEL()
-    }
-
-    NOINLINE void gemm1x1(int m0, int m, int n0, int n) {
-        BEGIN_KERNEL(1, 1)
-        float32x4_t acc = vdupq_n_f32(0.f);
-        const TA *Ap = A + lda * i;
-        const block_q8_0 *Bp = B + ldb * j;
-        for (int l = 0; l < k; ++l) {
-            acc = vmlaq_n_f32(acc,
-                              vcvtq_f32_s32(vdotq_s32(
-                                  vdotq_s32(vdupq_n_s32(0), load_lo(Ap + l), load_lo(Bp + l)),
-                                  load_hi(Ap + l), load_hi(Bp + l))),
-                              unhalf(Ap[l].d) * unhalf(Bp[l].d));
+        mnpack(m0, m, np, n);
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int m0, int m, int n0, int n) {
+        int ytiles = (m - m0) / RM;
+        int xtiles = (n - n0) / RN;
+        int tiles = xtiles * ytiles;
+        int duty = (tiles + nth - 1) / nth;
+        int start = duty * ith;
+        int end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int job = start; job < end; ++job) {
+            int ii = m0 + job / xtiles * RM;
+            int jj = n0 + job % xtiles * RN;
+            float32x4_t Cv[RN][RM] = {};
+            for (int l = 0; l < k; ++l)
+                for (int j = 0; j < RN; ++j)
+                    for (int i = 0; i < RM; ++i)
+                        Cv[j][i] = vmlaq_n_f32(Cv[j][i],
+                                               vcvtq_f32_s32(vdotq_s32(
+                                                   vdotq_s32(vdupq_n_s32(0),
+                                                             load_lo(A + lda * (ii + i) + l),
+                                                             load_lo(B + ldb * (jj + j) + l)),
+                                                   load_hi(A + lda * (ii + i) + l),
+                                                   load_hi(B + ldb * (jj + j) + l))),
+                                               unhalf(A[lda * (ii + i) + l].d) *
+                                               unhalf(B[ldb * (jj + j) + l].d));
+            for (int j = 0; j < RN; ++j)
+                for (int i = 0; i < RM; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
         }
-        C[ldc * j + i] = hsum(acc);
-        END_KERNEL()
     }
 
     inline int8x16_t load_lo(const block_q8_0 *b) {
         return vld1q_s8(b->qs);
     }
+
     inline int8x16_t load_hi(const block_q8_0 *b) {
         return vld1q_s8(b->qs + 16);
     }
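Both templated gemm<RM, RN>() kernels above (the float tinyBLAS one and the ARM quantized one) partition the output into RM-by-RN register tiles and deal whole tiles out to threads with the same ytiles/xtiles/duty arithmetic. The standalone sketch below walks that bookkeeping with made-up sizes; ith and nth mirror the thread-index/thread-count members used in the diff, but every concrete number here is illustrative.

```cpp
// Illustrative sketch of the tile/thread bookkeeping in gemm<RM, RN>() above.
#include <cstdio>

int main() {
    const int RM = 4, RN = 3;            // register tile shape (template parameters)
    const int m0 = 0, m = 10, n0 = 0, n = 9;
    const int ith = 1, nth = 4;          // this thread index / total threads (assumed)
    int ytiles = (m - m0) / RM;          // 2 full row tiles
    int xtiles = (n - n0) / RN;          // 3 full column tiles
    int tiles = xtiles * ytiles;         // 6 tiles shared among the threads
    int duty = (tiles + nth - 1) / nth;  // ceil(6 / 4) = 2 tiles per thread
    int start = duty * ith;              // thread 1 starts at tile 2
    int end = start + duty;
    if (end > tiles)
        end = tiles;
    for (int job = start; job < end; ++job)
        std::printf("thread %d -> tile at row %d, col %d\n", ith,
                    m0 + job / xtiles * RM, n0 + job % xtiles * RN);
    // prints tiles (0, 6) and (4, 0); leftover rows/columns are handled by the
    // recursive mnpack(mp, ...) / mnpack(..., np, n) calls shown in the diff.
}
```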
@@ -644,6 +568,7 @@ class tinyBLAS_Q0_ARM {
                                                      vdupq_n_u8(0x0f))),
                         vdupq_n_s8(0x8));
     }
+
     inline int8x16_t load_hi(const block_q4_0 *b) {
         return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
                         vdupq_n_s8(0x8));
@@ -679,217 +604,143 @@ class tinyBLAS_Q0_AVX2 {
     }
 
   private:
-
+    void mnpack(int m0, int m, int n0, int n) {
         int mc, nc, mp, np;
-
-
-
+        switch ((std::min(m - m0, 4) << 4) | std::min(n - n0, 4)) {
+#if VECTOR_REGISTERS == 32
+        case 0x44:
+            mc = 4;
+            nc = 4;
+            gemm<4, 4>(m0, m, n0, n);
+            break;
+        case 0x43:
+            mc = 4;
+            nc = 3;
+            gemm<4, 3>(m0, m, n0, n);
+            break;
+        case 0x34:
+            mc = 3;
+            nc = 4;
+            gemm<3, 4>(m0, m, n0, n);
+            break;
+        case 0x33:
+            mc = 3;
+            nc = 3;
+            gemm<3, 3>(m0, m, n0, n);
+            break;
+        case 0x42:
             mc = 4;
+            nc = 2;
+            gemm<4, 2>(m0, m, n0, n);
+            break;
+        case 0x24:
+            mc = 2;
+            nc = 4;
+            gemm<2, 4>(m0, m, n0, n);
+            break;
+#else
+        case 0x44:
+        case 0x43:
+        case 0x42:
+            mc = 4;
+            nc = 2;
+            gemm<4, 2>(m0, m, n0, n);
+            break;
+        case 0x34:
+        case 0x24:
+            mc = 2;
+            nc = 4;
+            gemm<2, 4>(m0, m, n0, n);
+            break;
+        case 0x33:
+#endif
+        case 0x32:
+            mc = 3;
+            nc = 2;
+            gemm<3, 2>(m0, m, n0, n);
+            break;
+        case 0x23:
+            mc = 2;
             nc = 3;
-
-
+            gemm<2, 3>(m0, m, n0, n);
+            break;
+        case 0x41:
             mc = 4;
             nc = 1;
-
-
+            gemm<4, 1>(m0, m, n0, n);
+            break;
+        case 0x22:
+            mc = 2;
+            nc = 2;
+            gemm<2, 2>(m0, m, n0, n);
+            break;
+        case 0x14:
             mc = 1;
             nc = 4;
-
-
+            gemm<1, 4>(m0, m, n0, n);
+            break;
+        case 0x31:
+            mc = 3;
+            nc = 1;
+            gemm<3, 1>(m0, m, n0, n);
+            break;
+        case 0x13:
             mc = 1;
+            nc = 3;
+            gemm<1, 3>(m0, m, n0, n);
+            break;
+        case 0x21:
+            mc = 2;
             nc = 1;
-
+            gemm<2, 1>(m0, m, n0, n);
+            break;
+        case 0x12:
+            mc = 1;
+            nc = 2;
+            gemm<1, 2>(m0, m, n0, n);
+            break;
+        case 0x11:
+            mc = 1;
+            nc = 1;
+            gemm<1, 1>(m0, m, n0, n);
+            break;
+        default:
+            return;
         }
         mp = m0 + (m - m0) / mc * mc;
         np = n0 + (n - n0) / nc * nc;
         mnpack(mp, m, n0, np);
-        mnpack(m0,
-
-
-
-    NOINLINE void
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            __m256i e0 = load(Ap0 + l);
-            __m256i e1 = load(Ap1 + l);
-            __m256i e2 = load(Ap2 + l);
-            __m256i e3 = load(Ap3 + l);
-            float db0 = unhalf(Bp0[l].d);
-            __m256 d00 = _mm256_set1_ps(da0 * db0);
-            __m256 d10 = _mm256_set1_ps(da1 * db0);
-            __m256 d20 = _mm256_set1_ps(da2 * db0);
-            __m256 d30 = _mm256_set1_ps(da3 * db0);
-            __m256i f0 = load(Bp0 + l);
-            __m256i u0 = _mm256_sign_epi8(f0, f0);
-            __m256i s00 = _mm256_sign_epi8(e0, f0);
-            __m256i s10 = _mm256_sign_epi8(e1, f0);
-            __m256i s20 = _mm256_sign_epi8(e2, f0);
-            __m256i s30 = _mm256_sign_epi8(e3, f0);
-            c00 = madd(d00, updot(u0, s00), c00);
-            c10 = madd(d10, updot(u0, s10), c10);
-            c20 = madd(d20, updot(u0, s20), c20);
-            c30 = madd(d30, updot(u0, s30), c30);
-            float db1 = unhalf(Bp1[l].d);
-            __m256 d01 = _mm256_set1_ps(da0 * db1);
-            __m256 d11 = _mm256_set1_ps(da1 * db1);
-            __m256 d21 = _mm256_set1_ps(da2 * db1);
-            __m256 d31 = _mm256_set1_ps(da3 * db1);
-            __m256i f1 = load(Bp1 + l);
-            __m256i u1 = _mm256_sign_epi8(f1, f1);
-            __m256i s01 = _mm256_sign_epi8(e0, f1);
-            __m256i s11 = _mm256_sign_epi8(e1, f1);
-            __m256i s21 = _mm256_sign_epi8(e2, f1);
-            __m256i s31 = _mm256_sign_epi8(e3, f1);
-            c01 = madd(d01, updot(u1, s01), c01);
-            c11 = madd(d11, updot(u1, s11), c11);
-            c21 = madd(d21, updot(u1, s21), c21);
-            c31 = madd(d31, updot(u1, s31), c31);
-            float db2 = unhalf(Bp2[l].d);
-            __m256 d02 = _mm256_set1_ps(da0 * db2);
-            __m256 d12 = _mm256_set1_ps(da1 * db2);
-            __m256 d22 = _mm256_set1_ps(da2 * db2);
-            __m256 d32 = _mm256_set1_ps(da3 * db2);
-            __m256i f2 = load(Bp2 + l);
-            __m256i u2 = _mm256_sign_epi8(f2, f2);
-            __m256i s02 = _mm256_sign_epi8(e0, f2);
-            __m256i s12 = _mm256_sign_epi8(e1, f2);
-            __m256i s22 = _mm256_sign_epi8(e2, f2);
-            __m256i s32 = _mm256_sign_epi8(e3, f2);
-            c02 = madd(d02, updot(u2, s02), c02);
-            c12 = madd(d12, updot(u2, s12), c12);
-            c22 = madd(d22, updot(u2, s22), c22);
-            c32 = madd(d32, updot(u2, s32), c32);
-        }
-        C[ldc * (j + 0) + (i + 0)] = hsum(c00);
-        C[ldc * (j + 0) + (i + 1)] = hsum(c10);
-        C[ldc * (j + 0) + (i + 2)] = hsum(c20);
-        C[ldc * (j + 0) + (i + 3)] = hsum(c30);
-        C[ldc * (j + 1) + (i + 0)] = hsum(c01);
-        C[ldc * (j + 1) + (i + 1)] = hsum(c11);
-        C[ldc * (j + 1) + (i + 2)] = hsum(c21);
-        C[ldc * (j + 1) + (i + 3)] = hsum(c31);
-        C[ldc * (j + 2) + (i + 0)] = hsum(c02);
-        C[ldc * (j + 2) + (i + 1)] = hsum(c12);
-        C[ldc * (j + 2) + (i + 2)] = hsum(c22);
-        C[ldc * (j + 2) + (i + 3)] = hsum(c32);
-        END_KERNEL()
-    }
-
-    NOINLINE void gemm4x1(int m0, int m, int n0, int n) {
-        BEGIN_KERNEL(4, 1)
-        __m256 c0 = _mm256_setzero_ps();
-        __m256 c1 = _mm256_setzero_ps();
-        __m256 c2 = _mm256_setzero_ps();
-        __m256 c3 = _mm256_setzero_ps();
-        const TA *Ap0 = A + lda * (i + 0);
-        const TA *Ap1 = A + lda * (i + 1);
-        const TA *Ap2 = A + lda * (i + 2);
-        const TA *Ap3 = A + lda * (i + 3);
-        const TB *Bp = B + ldb * j;
-        for (int l = 0; l < k; ++l) {
-            float db0 = unhalf(Bp[l].d);
-            __m256i f = load(Bp + l);
-            __m256i u = _mm256_sign_epi8(f, f);
-            __m256 d0 = _mm256_set1_ps(unhalf(Ap0[l].d) * db0);
-            __m256 d1 = _mm256_set1_ps(unhalf(Ap1[l].d) * db0);
-            __m256 d2 = _mm256_set1_ps(unhalf(Ap2[l].d) * db0);
-            __m256 d3 = _mm256_set1_ps(unhalf(Ap3[l].d) * db0);
-            __m256i e0 = load(Ap0 + l);
-            __m256i e1 = load(Ap1 + l);
-            __m256i e2 = load(Ap2 + l);
-            __m256i e3 = load(Ap3 + l);
-            __m256i s0 = _mm256_sign_epi8(e0, f);
-            __m256i s1 = _mm256_sign_epi8(e1, f);
-            __m256i s2 = _mm256_sign_epi8(e2, f);
-            __m256i s3 = _mm256_sign_epi8(e3, f);
-            __m256 g0 = updot(u, s0);
-            __m256 g1 = updot(u, s1);
-            __m256 g2 = updot(u, s2);
-            __m256 g3 = updot(u, s3);
-            c0 = madd(d0, g0, c0);
-            c1 = madd(d1, g1, c1);
-            c2 = madd(d2, g2, c2);
-            c3 = madd(d3, g3, c3);
-        }
-        C[ldc * j + (i + 0)] = hsum(c0);
-        C[ldc * j + (i + 1)] = hsum(c1);
-        C[ldc * j + (i + 2)] = hsum(c2);
-        C[ldc * j + (i + 3)] = hsum(c3);
-        END_KERNEL()
-    }
-
-    NOINLINE void gemm1x4(int m0, int m, int n0, int n) {
-        BEGIN_KERNEL(1, 4)
-        __m256 c0 = _mm256_setzero_ps();
-        __m256 c1 = _mm256_setzero_ps();
-        __m256 c2 = _mm256_setzero_ps();
-        __m256 c3 = _mm256_setzero_ps();
-        const TB *Bp0 = B + ldb * (j + 0);
-        const TB *Bp1 = B + ldb * (j + 1);
-        const TB *Bp2 = B + ldb * (j + 2);
-        const TB *Bp3 = B + ldb * (j + 3);
-        const TA *Ap = A + lda * i;
-        for (int l = 0; l < k; ++l) {
-            float da0 = unhalf(Ap[l].d);
-            __m256i f = load(Ap + l);
-            __m256i u = _mm256_sign_epi8(f, f);
-            __m256 d0 = _mm256_set1_ps(unhalf(Bp0[l].d) * da0);
-            __m256 d1 = _mm256_set1_ps(unhalf(Bp1[l].d) * da0);
-            __m256 d2 = _mm256_set1_ps(unhalf(Bp2[l].d) * da0);
-            __m256 d3 = _mm256_set1_ps(unhalf(Bp3[l].d) * da0);
-            __m256 g0 = updot(u, _mm256_sign_epi8(load(Bp0 + l), f));
-            __m256 g1 = updot(u, _mm256_sign_epi8(load(Bp1 + l), f));
-            __m256 g2 = updot(u, _mm256_sign_epi8(load(Bp2 + l), f));
-            __m256 g3 = updot(u, _mm256_sign_epi8(load(Bp3 + l), f));
-            c0 = madd(d0, g0, c0);
-            c1 = madd(d1, g1, c1);
-            c2 = madd(d2, g2, c2);
-            c3 = madd(d3, g3, c3);
-        }
-        C[ldc * (j + 0) + i] = hsum(c0);
-        C[ldc * (j + 1) + i] = hsum(c1);
-        C[ldc * (j + 2) + i] = hsum(c2);
-        C[ldc * (j + 3) + i] = hsum(c3);
-        END_KERNEL()
-    }
-
-    NOINLINE void gemm1x1(int m0, int m, int n0, int n) {
-        BEGIN_KERNEL(1, 1)
-        __m256 c = _mm256_setzero_ps();
-        const TA *Ap = A + lda * i;
-        const TB *Bp = B + ldb * j;
-        for (int l = 0; l < k; ++l) {
-            __m256 d = _mm256_set1_ps(unhalf(Ap[l].d) * unhalf(Bp[l].d));
-            __m256i e = load(Ap + l);
-            __m256i f = load(Bp + l);
-            __m256 g = updot(_mm256_sign_epi8(e, e), _mm256_sign_epi8(f, e));
-            c = madd(d, g, c);
+        mnpack(m0, m, np, n);
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int m0, int m, int n0, int n) {
+        int ytiles = (m - m0) / RM;
+        int xtiles = (n - n0) / RN;
+        int tiles = xtiles * ytiles;
+        int duty = (tiles + nth - 1) / nth;
+        int start = duty * ith;
+        int end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int job = start; job < end; ++job) {
+            int ii = m0 + job / xtiles * RM;
+            int jj = n0 + job % xtiles * RN;
+            __m256 Cv[RN][RM] = {};
+            for (int l = 0; l < k; ++l)
+                for (int j = 0; j < RN; ++j)
+                    for (int i = 0; i < RM; ++i)
+                        Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
+                                                       unhalf(B[ldb * (jj + j) + l].d)),
+                                        updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                               load(A + lda * (ii + i) + l)),
+                                              _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
+                                                               load(A + lda * (ii + i) + l))),
+                                        Cv[j][i]);
+            for (int j = 0; j < RN; ++j)
+                for (int i = 0; i < RM; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
         }
-        C[ldc * j + i] = hsum(c);
-        END_KERNEL()
     }
 
     inline __m256i load(const block_q8_0 *b) {
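The AVX2 kernel above leans on _mm256_sign_epi8 so that updot() can use an unsigned-times-signed multiply-add: it multiplies |a| (the A operand made non-negative) by a copy of b whose sign has been adjusted by a, which equals a*b element-wise, and is 0 wherever a is 0. A scalar model of that identity, standalone and illustrative only:

```cpp
// Illustrative scalar model of the _mm256_sign_epi8 trick used by updot() above.
#include <cstdio>
#include <cstdlib>

// scalar behaviour of _mm256_sign_epi8(x, y): x if y > 0, -x if y < 0, 0 if y == 0
static int sign_adjust(int x, int y) { return y < 0 ? -x : (y == 0 ? 0 : x); }

int main() {
    for (int a = -4; a <= 4; ++a)
        for (int b = -4; b <= 4; ++b)
            if (std::abs(a) * sign_adjust(b, a) != a * b)   // |a| * sign_adjust(b, a) == a * b
                return std::printf("mismatch at a=%d b=%d\n", a, b), 1;
    std::printf("identity holds for all sampled pairs\n");
}
```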
@@ -911,10 +762,10 @@ class tinyBLAS_Q0_AVX2 {
     }
 
     static inline __m256i denibble(const uint8_t *p) {
-
-
-
-
+        __m128i x = _mm_loadu_si128((const __m128i *)p);
+        return _mm256_and_si256(_mm256_set1_epi8(15),
+                                _mm256_insertf128_si256(_mm256_castsi128_si256(x),
+                                                        _mm_srli_epi16(x, 4), 1));
    }
 
     const TA *const A;