llama_cpp 0.14.6 → 0.14.7

@@ -50,6 +50,7 @@
  #pragma GCC diagnostic ignored "-Wignored-attributes"

  #include "sgemm.h"
+ #include <algorithm>
  #include "ggml-impl.h"
  #include "ggml-quants.h"

@@ -65,22 +66,6 @@
  #define VECTOR_REGISTERS 16
  #endif

- // there will be blocks
- #define BEGIN_KERNEL(RM, RN) \
- int ytiles = (m - m0) / RM; \
- int xtiles = (n - n0) / RN; \
- int tiles = ytiles * xtiles; \
- int duty = (tiles + nth - 1) / nth; \
- int start = duty * ith; \
- int end = start + duty; \
- if (end > tiles) \
- end = tiles; \
- for (int job = start; job < end; ++job) { \
- int i = m0 + job / xtiles * RM; \
- int j = n0 + job % xtiles * RN;
-
- #define END_KERNEL() }
-
  #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

  namespace {
@@ -122,6 +107,45 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
  inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
  #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+ // VECTORIZED FUSED MULTIPLY ADD
+
+ /**
+ * Computes a * b + c.
+ */
+ template <typename T, typename U>
+ inline U madd(T a, T b, U c) {
+ return add(mul(a, b), c);
+ }
+
+ #if defined(__FMA__)
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+ template <>
+ inline __m256 madd(__m256 a, __m256 b, __m256 c) {
+ return _mm256_fmadd_ps(a, b, c);
+ }
+ #endif
+ #if defined(__AVX512F__)
+ template <>
+ inline __m512 madd(__m512 a, __m512 b, __m512 c) {
+ return _mm512_fmadd_ps(a, b, c);
+ }
+ #endif
+ #endif
+
+ #if defined(__ARM_FEATURE_FMA)
+ template <>
+ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
+ return vfmaq_f32(c, b, a);
+ }
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
+ template <>
+ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
+ return vfmaq_f16(c, b, a);
+ }
+ #endif
+ #endif
+
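
Note on the block above: the generic madd(a, b, c) falls back to a separate multiply and add, while the template specializations route the same call onto a hardware fused multiply-add when __FMA__ or __ARM_FEATURE_FMA is available. The sketch below shows the same dispatch pattern reduced to plain floats so it compiles anywhere; the scalar add/mul overloads and the use of std::fma are illustrative stand-ins, not part of the library.

#include <cmath>
#include <cstdio>

inline float add(float x, float y) { return x + y; }
inline float mul(float x, float y) { return x * y; }

// Generic fallback: two separately rounded operations (multiply, then add).
template <typename T, typename U>
inline U madd(T a, T b, U c) {
    return add(mul(a, b), c);
}

// A specialization maps the same call onto a fused instruction, which rounds
// once; the real code does this with _mm256_fmadd_ps / vfmaq_f32, and std::fma
// is the portable scalar analogue used here.
template <>
inline float madd(float a, float b, float c) {
    return std::fma(a, b, c);
}

int main() {
    std::printf("%g\n", madd(2.0f, 3.0f, 1.0f)); // prints 7
}
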
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // VECTORIZED HORIZONTAL SUM

@@ -213,36 +237,6 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
  }
  #endif // __AVX512F__

- ////////////////////////////////////////////////////////////////////////////////////////////////////
- // ABSTRACTIONS
-
- /**
- * Computes a * b + c.
- *
- * This operation will become fused into a single arithmetic instruction
- * if the hardware has support for this feature, e.g. Intel Haswell+ (c.
- * 2013), AMD Bulldozer+ (c. 2011), etc.
- */
- template <typename T, typename U>
- inline U madd(T a, T b, U c) {
- return add(mul(a, b), c);
- }
-
- /**
- * Computes a * b + c with error correction.
- *
- * @see W. Kahan, "Further remarks on reducing truncation errors,"
- * Communications of the ACM, vol. 8, no. 1, p. 40, Jan. 1965,
- * doi: 10.1145/363707.363723.
- */
- template <typename T, typename U>
- inline U madder(T a, T b, U c, U *e) {
- U y = sub(mul(a, b), *e);
- U t = add(c, y);
- *e = sub(sub(t, c), y);
- return t;
- }
-
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // FLOATING POINT MATRIX MULTIPLICATION

@@ -265,226 +259,179 @@ class tinyBLAS {
  private:
  NOINLINE void mnpack(int m0, int m, int n0, int n) {
  int mc, nc, mp, np;
- if (m - m0 <= 0 || n - n0 <= 0)
- return;
- if (VECTOR_REGISTERS >= 32 && n - n0 >= 5 && m - m0 >= 5) {
+ switch ((std::min(m - m0, 5) << 4) | std::min(n - n0, 5)) {
+ #if VECTOR_REGISTERS == 32
+ case 0x55:
+ mc = 5;
+ nc = 5;
+ gemm<5, 5>(m0, m, n0, n);
+ break;
+ case 0x45:
+ mc = 4;
+ nc = 5;
+ gemm<4, 5>(m0, m, n0, n);
+ break;
+ case 0x54:
  mc = 5;
+ nc = 4;
+ gemm<5, 4>(m0, m, n0, n);
+ break;
+ case 0x44:
+ mc = 4;
+ nc = 4;
+ gemm<4, 4>(m0, m, n0, n);
+ break;
+ case 0x53:
+ mc = 5;
+ nc = 3;
+ gemm<5, 3>(m0, m, n0, n);
+ break;
+ case 0x35:
+ mc = 3;
  nc = 5;
- gemm5x5(m0, m, n0, n);
- } else if (n - n0 >= 4 && m - m0 >= 3) {
+ gemm<3, 5>(m0, m, n0, n);
+ break;
+ case 0x43:
+ mc = 4;
+ nc = 3;
+ gemm<4, 3>(m0, m, n0, n);
+ break;
+ #else
+ case 0x55:
+ case 0x54:
+ case 0x53:
+ case 0x45:
+ case 0x44:
+ case 0x43:
+ mc = 4;
+ nc = 3;
+ gemm<4, 3>(m0, m, n0, n);
+ break;
+ case 0x35:
+ #endif
+ case 0x34:
  mc = 3;
  nc = 4;
- gemm3x4(m0, m, n0, n);
- } else if (n - n0 >= 4) {
- mc = 1;
+ gemm<3, 4>(m0, m, n0, n);
+ break;
+ case 0x52:
+ mc = 5;
+ nc = 2;
+ gemm<5, 2>(m0, m, n0, n);
+ break;
+ case 0x33:
+ mc = 3;
+ nc = 3;
+ gemm<3, 3>(m0, m, n0, n);
+ break;
+ case 0x25:
+ mc = 2;
+ nc = 5;
+ gemm<2, 5>(m0, m, n0, n);
+ break;
+ case 0x42:
+ mc = 4;
+ nc = 2;
+ gemm<4, 2>(m0, m, n0, n);
+ break;
+ case 0x24:
+ mc = 2;
  nc = 4;
- gemm1x4(m0, m, n0, n);
- } else if (m - m0 >= 4) {
+ gemm<2, 4>(m0, m, n0, n);
+ break;
+ case 0x32:
+ mc = 3;
+ nc = 2;
+ gemm<3, 2>(m0, m, n0, n);
+ break;
+ case 0x23:
+ mc = 2;
+ nc = 3;
+ gemm<2, 3>(m0, m, n0, n);
+ break;
+ case 0x51:
+ mc = 5;
+ nc = 1;
+ gemm<5, 1>(m0, m, n0, n);
+ break;
+ case 0x41:
  mc = 4;
  nc = 1;
- gemm4x1(m0, m, n0, n);
- } else {
+ gemm<4, 1>(m0, m, n0, n);
+ break;
+ case 0x22:
+ mc = 2;
+ nc = 2;
+ gemm<2, 2>(m0, m, n0, n);
+ break;
+ case 0x15:
+ mc = 1;
+ nc = 5;
+ gemm<1, 5>(m0, m, n0, n);
+ break;
+ case 0x14:
+ mc = 1;
+ nc = 4;
+ gemm<1, 4>(m0, m, n0, n);
+ break;
+ case 0x31:
+ mc = 3;
+ nc = 1;
+ gemm<3, 1>(m0, m, n0, n);
+ break;
+ case 0x13:
  mc = 1;
+ nc = 3;
+ gemm<1, 3>(m0, m, n0, n);
+ break;
+ case 0x21:
+ mc = 2;
  nc = 1;
- gemm1x1(m0, m, n0, n);
+ gemm<2, 1>(m0, m, n0, n);
+ break;
+ case 0x12:
+ mc = 1;
+ nc = 2;
+ gemm<1, 2>(m0, m, n0, n);
+ break;
+ case 0x11:
+ mc = 1;
+ nc = 1;
+ gemm<1, 1>(m0, m, n0, n);
+ break;
+ default:
+ return;
  }
  mp = m0 + (m - m0) / mc * mc;
  np = n0 + (n - n0) / nc * nc;
  mnpack(mp, m, n0, np);
- mnpack(m0, mp, np, n);
- mnpack(mp, m, np, n);
- }
-
- NOINLINE void gemm5x5(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(5, 5)
- D c00 = {0};
- D c01 = {0};
- D c02 = {0};
- D c03 = {0};
- D c04 = {0};
- D c10 = {0};
- D c11 = {0};
- D c12 = {0};
- D c13 = {0};
- D c14 = {0};
- D c20 = {0};
- D c21 = {0};
- D c22 = {0};
- D c23 = {0};
- D c24 = {0};
- D c30 = {0};
- D c31 = {0};
- D c32 = {0};
- D c33 = {0};
- D c34 = {0};
- D c40 = {0};
- D c41 = {0};
- D c42 = {0};
- D c43 = {0};
- D c44 = {0};
- for (int l = 0; l < k; l += KN) {
- V k0 = load<V>(B + ldb * (j + 0) + l);
- V k1 = load<V>(B + ldb * (j + 1) + l);
- V k2 = load<V>(B + ldb * (j + 2) + l);
- V k3 = load<V>(B + ldb * (j + 3) + l);
- V k4 = load<V>(B + ldb * (j + 4) + l);
- V a0 = load<V>(A + lda * (i + 0) + l);
- c00 = madd(a0, k0, c00);
- c01 = madd(a0, k1, c01);
- c02 = madd(a0, k2, c02);
- c03 = madd(a0, k3, c03);
- c04 = madd(a0, k4, c04);
- V a1 = load<V>(A + lda * (i + 1) + l);
- c10 = madd(a1, k0, c10);
- c11 = madd(a1, k1, c11);
- c12 = madd(a1, k2, c12);
- c13 = madd(a1, k3, c13);
- c14 = madd(a1, k4, c14);
- V a2 = load<V>(A + lda * (i + 2) + l);
- c20 = madd(a2, k0, c20);
- c21 = madd(a2, k1, c21);
- c22 = madd(a2, k2, c22);
- c23 = madd(a2, k3, c23);
- c24 = madd(a2, k4, c24);
- V a3 = load<V>(A + lda * (i + 3) + l);
- c30 = madd(a3, k0, c30);
- c31 = madd(a3, k1, c31);
- c32 = madd(a3, k2, c32);
- c33 = madd(a3, k3, c33);
- c34 = madd(a3, k4, c34);
- V a4 = load<V>(A + lda * (i + 4) + l);
- c40 = madd(a4, k0, c40);
- c41 = madd(a4, k1, c41);
- c42 = madd(a4, k2, c42);
- c43 = madd(a4, k3, c43);
- c44 = madd(a4, k4, c44);
- }
- C[ldc * (j + 0) + (i + 0)] = hsum(c00);
- C[ldc * (j + 0) + (i + 1)] = hsum(c10);
- C[ldc * (j + 0) + (i + 2)] = hsum(c20);
- C[ldc * (j + 0) + (i + 3)] = hsum(c30);
- C[ldc * (j + 0) + (i + 4)] = hsum(c40);
- C[ldc * (j + 1) + (i + 0)] = hsum(c01);
- C[ldc * (j + 1) + (i + 1)] = hsum(c11);
- C[ldc * (j + 1) + (i + 2)] = hsum(c21);
- C[ldc * (j + 1) + (i + 3)] = hsum(c31);
- C[ldc * (j + 1) + (i + 4)] = hsum(c41);
- C[ldc * (j + 2) + (i + 0)] = hsum(c02);
- C[ldc * (j + 2) + (i + 1)] = hsum(c12);
- C[ldc * (j + 2) + (i + 2)] = hsum(c22);
- C[ldc * (j + 2) + (i + 3)] = hsum(c32);
- C[ldc * (j + 2) + (i + 4)] = hsum(c42);
- C[ldc * (j + 3) + (i + 0)] = hsum(c03);
- C[ldc * (j + 3) + (i + 1)] = hsum(c13);
- C[ldc * (j + 3) + (i + 2)] = hsum(c23);
- C[ldc * (j + 3) + (i + 3)] = hsum(c33);
- C[ldc * (j + 3) + (i + 4)] = hsum(c43);
- C[ldc * (j + 4) + (i + 0)] = hsum(c04);
- C[ldc * (j + 4) + (i + 1)] = hsum(c14);
- C[ldc * (j + 4) + (i + 2)] = hsum(c24);
- C[ldc * (j + 4) + (i + 3)] = hsum(c34);
- C[ldc * (j + 4) + (i + 4)] = hsum(c44);
- END_KERNEL()
- }
-
- NOINLINE void gemm3x4(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(3, 4)
- D c00 = {0};
- D c01 = {0};
- D c02 = {0};
- D c03 = {0};
- D c10 = {0};
- D c11 = {0};
- D c12 = {0};
- D c13 = {0};
- D c20 = {0};
- D c21 = {0};
- D c22 = {0};
- D c23 = {0};
- for (int l = 0; l < k; l += KN) {
- V k0 = load<V>(B + ldb * (j + 0) + l);
- V k1 = load<V>(B + ldb * (j + 1) + l);
- V k2 = load<V>(B + ldb * (j + 2) + l);
- V k3 = load<V>(B + ldb * (j + 3) + l);
- V a0 = load<V>(A + lda * (i + 0) + l);
- c00 = madd(a0, k0, c00);
- c01 = madd(a0, k1, c01);
- c02 = madd(a0, k2, c02);
- c03 = madd(a0, k3, c03);
- V a1 = load<V>(A + lda * (i + 1) + l);
- c10 = madd(a1, k0, c10);
- c11 = madd(a1, k1, c11);
- c12 = madd(a1, k2, c12);
- c13 = madd(a1, k3, c13);
- V a2 = load<V>(A + lda * (i + 2) + l);
- c20 = madd(a2, k0, c20);
- c21 = madd(a2, k1, c21);
- c22 = madd(a2, k2, c22);
- c23 = madd(a2, k3, c23);
- }
- C[ldc * (j + 0) + (i + 0)] = hsum(c00);
- C[ldc * (j + 0) + (i + 1)] = hsum(c10);
- C[ldc * (j + 0) + (i + 2)] = hsum(c20);
- C[ldc * (j + 1) + (i + 0)] = hsum(c01);
- C[ldc * (j + 1) + (i + 1)] = hsum(c11);
- C[ldc * (j + 1) + (i + 2)] = hsum(c21);
- C[ldc * (j + 2) + (i + 0)] = hsum(c02);
- C[ldc * (j + 2) + (i + 1)] = hsum(c12);
- C[ldc * (j + 2) + (i + 2)] = hsum(c22);
- C[ldc * (j + 3) + (i + 0)] = hsum(c03);
- C[ldc * (j + 3) + (i + 1)] = hsum(c13);
- C[ldc * (j + 3) + (i + 2)] = hsum(c23);
- END_KERNEL()
- }
-
- NOINLINE void gemm1x4(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(1, 4)
- D c00 = {0}, e00 = {0};
- D c01 = {0}, e01 = {0};
- D c02 = {0}, e02 = {0};
- D c03 = {0}, e03 = {0};
- for (int l = 0; l < k; l += KN) {
- V a = load<V>(A + lda * (i + 0) + l);
- c00 = madder(a, load<V>(B + ldb * (j + 0) + l), c00, &e00);
- c01 = madder(a, load<V>(B + ldb * (j + 1) + l), c01, &e01);
- c02 = madder(a, load<V>(B + ldb * (j + 2) + l), c02, &e02);
- c03 = madder(a, load<V>(B + ldb * (j + 3) + l), c03, &e03);
- }
- C[ldc * (j + 0) + (i + 0)] = hsum(c00);
- C[ldc * (j + 1) + (i + 0)] = hsum(c01);
- C[ldc * (j + 2) + (i + 0)] = hsum(c02);
- C[ldc * (j + 3) + (i + 0)] = hsum(c03);
- END_KERNEL()
- }
-
- NOINLINE void gemm4x1(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(4, 1)
- D c00 = {0}, e00 = {0};
- D c10 = {0}, e10 = {0};
- D c20 = {0}, e20 = {0};
- D c30 = {0}, e30 = {0};
- for (int l = 0; l < k; l += KN) {
- V b = load<V>(B + ldb * (j + 0) + l);
- c00 = madder(load<V>(A + lda * (i + 0) + l), b, c00, &e00);
- c10 = madder(load<V>(A + lda * (i + 1) + l), b, c10, &e10);
- c20 = madder(load<V>(A + lda * (i + 2) + l), b, c20, &e20);
- c30 = madder(load<V>(A + lda * (i + 3) + l), b, c30, &e30);
+ mnpack(m0, m, np, n);
+ }
+
+ template <int RM, int RN>
+ NOINLINE void gemm(int m0, int m, int n0, int n) {
+ int ytiles = (m - m0) / RM;
+ int xtiles = (n - n0) / RN;
+ int tiles = xtiles * ytiles;
+ int duty = (tiles + nth - 1) / nth;
+ int start = duty * ith;
+ int end = start + duty;
+ if (end > tiles)
+ end = tiles;
+ for (int job = start; job < end; ++job) {
+ int ii = m0 + job / xtiles * RM;
+ int jj = n0 + job % xtiles * RN;
+ D Cv[RN][RM] = {};
+ for (int l = 0; l < k; l += KN)
+ for (int j = 0; j < RN; ++j)
+ for (int i = 0; i < RM; ++i)
+ Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
+ load<V>(B + ldb * (jj + j) + l),
+ Cv[j][i]);
+ for (int j = 0; j < RN; ++j)
+ for (int i = 0; i < RM; ++i)
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  }
- C[ldc * (j + 0) + (i + 0)] = hsum(c00);
- C[ldc * (j + 0) + (i + 1)] = hsum(c10);
- C[ldc * (j + 0) + (i + 2)] = hsum(c20);
- C[ldc * (j + 0) + (i + 3)] = hsum(c30);
- END_KERNEL()
- }
-
- NOINLINE void gemm1x1(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(1, 1)
- D c = {0}, e = {0};
- for (int l = 0; l < k; l += KN)
- c = madder(load<V>(A + lda * i + l),
- load<V>(B + ldb * j + l), c, &e);
- C[ldc * j + i] = hsum(c);
- END_KERNEL()
  }

  const TA *const A;
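
Note on the tinyBLAS change above: mnpack now packs the clamped edge sizes into a single byte, (std::min(m - m0, 5) << 4) | std::min(n - n0, 5), so for example 0x54 means "at least 5 rows and exactly 4 columns remain" and selects gemm<5, 4>, while a zero in either nibble falls through to default: return. The hand-unrolled gemm5x5/gemm3x4/... kernels are replaced by one gemm<RM, RN> template that stripes RM x RN output tiles across threads, with the leftover fringes handled by the two recursive mnpack calls. The sketch below isolates just that tile-scheduling arithmetic; the kernel body is replaced by a print, and the function and variable names are illustrative.

#include <cstdio>

template <int RM, int RN>
void schedule(int m0, int m, int n0, int n, int ith, int nth) {
    int ytiles = (m - m0) / RM;          // whole tiles down the rows
    int xtiles = (n - n0) / RN;          // whole tiles across the columns
    int tiles = xtiles * ytiles;
    int duty = (tiles + nth - 1) / nth;  // tiles per thread, rounded up
    int start = duty * ith;
    int end = start + duty;
    if (end > tiles)
        end = tiles;
    for (int job = start; job < end; ++job) {
        int ii = m0 + job / xtiles * RM; // top-left corner of this tile
        int jj = n0 + job % xtiles * RN;
        std::printf("thread %d -> C[%d..%d, %d..%d]\n",
                    ith, ii, ii + RM - 1, jj, jj + RN - 1);
    }
}

int main() {
    // e.g. a 10 x 8 output split into 5 x 4 tiles over 2 threads
    for (int ith = 0; ith < 2; ++ith)
        schedule<5, 4>(0, 10, 0, 8, ith, 2);
}
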
@@ -521,120 +468,97 @@ class tinyBLAS_Q0_ARM {
  private:
  NOINLINE void mnpack(int m0, int m, int n0, int n) {
  int mc, nc, mp, np;
- if (m - m0 <= 0 || n - n0 <= 0)
- return;
- if (m - m0 >= 3 && n - n0 >= 3) {
+ switch ((std::min(m - m0, 3) << 4) | std::min(n - n0, 3)) {
+ case 0x33:
  mc = 3;
  nc = 3;
- gemm3x3(m0, m, n0, n);
- } else {
+ gemm<3, 3>(m0, m, n0, n);
+ break;
+ case 0x32:
+ mc = 3;
+ nc = 2;
+ gemm<3, 2>(m0, m, n0, n);
+ break;
+ case 0x23:
+ mc = 2;
+ nc = 3;
+ gemm<2, 3>(m0, m, n0, n);
+ break;
+ case 0x22:
+ mc = 2;
+ nc = 2;
+ gemm<2, 2>(m0, m, n0, n);
+ break;
+ case 0x31:
+ mc = 3;
+ nc = 1;
+ gemm<3, 1>(m0, m, n0, n);
+ break;
+ case 0x13:
+ mc = 1;
+ nc = 3;
+ gemm<1, 3>(m0, m, n0, n);
+ break;
+ case 0x21:
+ mc = 2;
+ nc = 1;
+ gemm<2, 1>(m0, m, n0, n);
+ break;
+ case 0x12:
+ mc = 1;
+ nc = 2;
+ gemm<1, 2>(m0, m, n0, n);
+ break;
+ case 0x11:
  mc = 1;
  nc = 1;
- gemm1x1(m0, m, n0, n);
+ gemm<1, 1>(m0, m, n0, n);
+ break;
+ default:
+ return;
  }
  mp = m0 + (m - m0) / mc * mc;
  np = n0 + (n - n0) / nc * nc;
  mnpack(mp, m, n0, np);
- mnpack(m0, mp, np, n);
- mnpack(mp, m, np, n);
- }
-
- NOINLINE void gemm3x3(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(3, 3)
- int32x4_t zero = vdupq_n_s32(0);
- float32x4_t c00 = vdupq_n_f32(0.f);
- float32x4_t c01 = vdupq_n_f32(0.f);
- float32x4_t c02 = vdupq_n_f32(0.f);
- float32x4_t c10 = vdupq_n_f32(0.f);
- float32x4_t c11 = vdupq_n_f32(0.f);
- float32x4_t c12 = vdupq_n_f32(0.f);
- float32x4_t c20 = vdupq_n_f32(0.f);
- float32x4_t c21 = vdupq_n_f32(0.f);
- float32x4_t c22 = vdupq_n_f32(0.f);
- const TA *Ap0 = A + lda * (i + 0);
- const TA *Ap1 = A + lda * (i + 1);
- const TA *Ap2 = A + lda * (i + 2);
- const block_q8_0 *Bp0 = B + ldb * (j + 0);
- const block_q8_0 *Bp1 = B + ldb * (j + 1);
- const block_q8_0 *Bp2 = B + ldb * (j + 2);
- for (int l = 0; l < k; ++l) {
- c00 = vmlaq_n_f32(
- c00,
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp0 + l)),
- load_hi(Ap0 + l), load_hi(Bp0 + l))),
- unhalf(Ap0[l].d) * unhalf(Bp0[l].d));
- c01 = vmlaq_n_f32(
- c01,
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp1 + l)),
- load_hi(Ap0 + l), load_hi(Bp1 + l))),
- unhalf(Ap0[l].d) * unhalf(Bp1[l].d));
- c02 = vmlaq_n_f32(
- c02,
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp2 + l)),
- load_hi(Ap0 + l), load_hi(Bp2 + l))),
- unhalf(Ap0[l].d) * unhalf(Bp2[l].d));
- c10 = vmlaq_n_f32(
- c10,
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp0 + l)),
- load_hi(Ap1 + l), load_hi(Bp0 + l))),
- unhalf(Ap1[l].d) * unhalf(Bp0[l].d));
- c11 = vmlaq_n_f32(
- c11,
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp1 + l)),
- load_hi(Ap1 + l), load_hi(Bp1 + l))),
- unhalf(Ap1[l].d) * unhalf(Bp1[l].d));
- c12 = vmlaq_n_f32(
- c12,
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp2 + l)),
- load_hi(Ap1 + l), load_hi(Bp2 + l))),
- unhalf(Ap1[l].d) * unhalf(Bp2[l].d));
- c20 = vmlaq_n_f32(
- c20,
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp0 + l)),
- load_hi(Ap2 + l), load_hi(Bp0 + l))),
- unhalf(Ap2[l].d) * unhalf(Bp0[l].d));
- c21 = vmlaq_n_f32(
- c21,
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp1 + l)),
- load_hi(Ap2 + l), load_hi(Bp1 + l))),
- unhalf(Ap2[l].d) * unhalf(Bp1[l].d));
- c22 = vmlaq_n_f32(
- c22,
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp2 + l)),
- load_hi(Ap2 + l), load_hi(Bp2 + l))),
- unhalf(Ap2[l].d) * unhalf(Bp2[l].d));
- }
- C[ldc * (j + 0) + (i + 0)] = hsum(c00);
- C[ldc * (j + 0) + (i + 1)] = hsum(c10);
- C[ldc * (j + 0) + (i + 2)] = hsum(c20);
- C[ldc * (j + 1) + (i + 0)] = hsum(c01);
- C[ldc * (j + 1) + (i + 1)] = hsum(c11);
- C[ldc * (j + 1) + (i + 2)] = hsum(c21);
- C[ldc * (j + 2) + (i + 0)] = hsum(c02);
- C[ldc * (j + 2) + (i + 1)] = hsum(c12);
- C[ldc * (j + 2) + (i + 2)] = hsum(c22);
- END_KERNEL()
- }
-
- NOINLINE void gemm1x1(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(1, 1)
- float32x4_t acc = vdupq_n_f32(0.f);
- const TA *Ap = A + lda * i;
- const block_q8_0 *Bp = B + ldb * j;
- for (int l = 0; l < k; ++l) {
- acc = vmlaq_n_f32(acc,
- vcvtq_f32_s32(vdotq_s32(
- vdotq_s32(vdupq_n_s32(0), load_lo(Ap + l), load_lo(Bp + l)),
- load_hi(Ap + l), load_hi(Bp + l))),
- unhalf(Ap[l].d) * unhalf(Bp[l].d));
+ mnpack(m0, m, np, n);
+ }
+
+ template <int RM, int RN>
+ NOINLINE void gemm(int m0, int m, int n0, int n) {
+ int ytiles = (m - m0) / RM;
+ int xtiles = (n - n0) / RN;
+ int tiles = xtiles * ytiles;
+ int duty = (tiles + nth - 1) / nth;
+ int start = duty * ith;
+ int end = start + duty;
+ if (end > tiles)
+ end = tiles;
+ for (int job = start; job < end; ++job) {
+ int ii = m0 + job / xtiles * RM;
+ int jj = n0 + job % xtiles * RN;
+ float32x4_t Cv[RN][RM] = {};
+ for (int l = 0; l < k; ++l)
+ for (int j = 0; j < RN; ++j)
+ for (int i = 0; i < RM; ++i)
+ Cv[j][i] = vmlaq_n_f32(Cv[j][i],
+ vcvtq_f32_s32(vdotq_s32(
+ vdotq_s32(vdupq_n_s32(0),
+ load_lo(A + lda * (ii + i) + l),
+ load_lo(B + ldb * (jj + j) + l)),
+ load_hi(A + lda * (ii + i) + l),
+ load_hi(B + ldb * (jj + j) + l))),
+ unhalf(A[lda * (ii + i) + l].d) *
+ unhalf(B[ldb * (jj + j) + l].d));
+ for (int j = 0; j < RN; ++j)
+ for (int i = 0; i < RM; ++i)
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  }
- C[ldc * j + i] = hsum(acc);
- END_KERNEL()
  }

  inline int8x16_t load_lo(const block_q8_0 *b) {
  return vld1q_s8(b->qs);
  }
+
  inline int8x16_t load_hi(const block_q8_0 *b) {
  return vld1q_s8(b->qs + 16);
  }
@@ -644,6 +568,7 @@ class tinyBLAS_Q0_ARM {
  vdupq_n_u8(0x0f))),
  vdupq_n_s8(0x8));
  }
+
  inline int8x16_t load_hi(const block_q4_0 *b) {
  return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
  vdupq_n_s8(0x8));
@@ -679,217 +604,143 @@ class tinyBLAS_Q0_AVX2 {
  }

  private:
- NOINLINE void mnpack(int m0, int m, int n0, int n) {
+ void mnpack(int m0, int m, int n0, int n) {
  int mc, nc, mp, np;
- if (m - m0 <= 0 || n - n0 <= 0)
- return;
- if (m - m0 >= 4 && n - n0 >= 3) {
+ switch ((std::min(m - m0, 4) << 4) | std::min(n - n0, 4)) {
+ #if VECTOR_REGISTERS == 32
+ case 0x44:
+ mc = 4;
+ nc = 4;
+ gemm<4, 4>(m0, m, n0, n);
+ break;
+ case 0x43:
+ mc = 4;
+ nc = 3;
+ gemm<4, 3>(m0, m, n0, n);
+ break;
+ case 0x34:
+ mc = 3;
+ nc = 4;
+ gemm<3, 4>(m0, m, n0, n);
+ break;
+ case 0x33:
+ mc = 3;
+ nc = 3;
+ gemm<3, 3>(m0, m, n0, n);
+ break;
+ case 0x42:
  mc = 4;
+ nc = 2;
+ gemm<4, 2>(m0, m, n0, n);
+ break;
+ case 0x24:
+ mc = 2;
+ nc = 4;
+ gemm<2, 4>(m0, m, n0, n);
+ break;
+ #else
+ case 0x44:
+ case 0x43:
+ case 0x42:
+ mc = 4;
+ nc = 2;
+ gemm<4, 2>(m0, m, n0, n);
+ break;
+ case 0x34:
+ case 0x24:
+ mc = 2;
+ nc = 4;
+ gemm<2, 4>(m0, m, n0, n);
+ break;
+ case 0x33:
+ #endif
+ case 0x32:
+ mc = 3;
+ nc = 2;
+ gemm<3, 2>(m0, m, n0, n);
+ break;
+ case 0x23:
+ mc = 2;
  nc = 3;
- gemm4x3(m0, m, n0, n);
- } else if (m - m0 >= 4 && n - n0 >= 1) {
+ gemm<2, 3>(m0, m, n0, n);
+ break;
+ case 0x41:
  mc = 4;
  nc = 1;
- gemm4x1(m0, m, n0, n);
- } else if (m - m0 >= 1 && n - n0 >= 4) {
+ gemm<4, 1>(m0, m, n0, n);
+ break;
+ case 0x22:
+ mc = 2;
+ nc = 2;
+ gemm<2, 2>(m0, m, n0, n);
+ break;
+ case 0x14:
  mc = 1;
  nc = 4;
- gemm1x4(m0, m, n0, n);
- } else {
+ gemm<1, 4>(m0, m, n0, n);
+ break;
+ case 0x31:
+ mc = 3;
+ nc = 1;
+ gemm<3, 1>(m0, m, n0, n);
+ break;
+ case 0x13:
  mc = 1;
+ nc = 3;
+ gemm<1, 3>(m0, m, n0, n);
+ break;
+ case 0x21:
+ mc = 2;
  nc = 1;
- gemm1x1(m0, m, n0, n);
+ gemm<2, 1>(m0, m, n0, n);
+ break;
+ case 0x12:
+ mc = 1;
+ nc = 2;
+ gemm<1, 2>(m0, m, n0, n);
+ break;
+ case 0x11:
+ mc = 1;
+ nc = 1;
+ gemm<1, 1>(m0, m, n0, n);
+ break;
+ default:
+ return;
  }
  mp = m0 + (m - m0) / mc * mc;
  np = n0 + (n - n0) / nc * nc;
  mnpack(mp, m, n0, np);
- mnpack(m0, mp, np, n);
- mnpack(mp, m, np, n);
- }
-
- NOINLINE void gemm4x3(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(4, 3)
- __m256 c00 = _mm256_setzero_ps();
- __m256 c10 = _mm256_setzero_ps();
- __m256 c20 = _mm256_setzero_ps();
- __m256 c30 = _mm256_setzero_ps();
- __m256 c01 = _mm256_setzero_ps();
- __m256 c11 = _mm256_setzero_ps();
- __m256 c21 = _mm256_setzero_ps();
- __m256 c31 = _mm256_setzero_ps();
- __m256 c02 = _mm256_setzero_ps();
- __m256 c12 = _mm256_setzero_ps();
- __m256 c22 = _mm256_setzero_ps();
- __m256 c32 = _mm256_setzero_ps();
- const TA *Ap0 = A + lda * (i + 0);
- const TA *Ap1 = A + lda * (i + 1);
- const TA *Ap2 = A + lda * (i + 2);
- const TA *Ap3 = A + lda * (i + 3);
- const TB *Bp0 = B + ldb * (j + 0);
- const TB *Bp1 = B + ldb * (j + 1);
- const TB *Bp2 = B + ldb * (j + 2);
- for (int l = 0; l < k; ++l) {
- float da0 = unhalf(Ap0[l].d);
- float da1 = unhalf(Ap1[l].d);
- float da2 = unhalf(Ap2[l].d);
- float da3 = unhalf(Ap3[l].d);
- __m256i e0 = load(Ap0 + l);
- __m256i e1 = load(Ap1 + l);
- __m256i e2 = load(Ap2 + l);
- __m256i e3 = load(Ap3 + l);
- float db0 = unhalf(Bp0[l].d);
- __m256 d00 = _mm256_set1_ps(da0 * db0);
- __m256 d10 = _mm256_set1_ps(da1 * db0);
- __m256 d20 = _mm256_set1_ps(da2 * db0);
- __m256 d30 = _mm256_set1_ps(da3 * db0);
- __m256i f0 = load(Bp0 + l);
- __m256i u0 = _mm256_sign_epi8(f0, f0);
- __m256i s00 = _mm256_sign_epi8(e0, f0);
- __m256i s10 = _mm256_sign_epi8(e1, f0);
- __m256i s20 = _mm256_sign_epi8(e2, f0);
- __m256i s30 = _mm256_sign_epi8(e3, f0);
- c00 = madd(d00, updot(u0, s00), c00);
- c10 = madd(d10, updot(u0, s10), c10);
- c20 = madd(d20, updot(u0, s20), c20);
- c30 = madd(d30, updot(u0, s30), c30);
- float db1 = unhalf(Bp1[l].d);
- __m256 d01 = _mm256_set1_ps(da0 * db1);
- __m256 d11 = _mm256_set1_ps(da1 * db1);
- __m256 d21 = _mm256_set1_ps(da2 * db1);
- __m256 d31 = _mm256_set1_ps(da3 * db1);
- __m256i f1 = load(Bp1 + l);
- __m256i u1 = _mm256_sign_epi8(f1, f1);
- __m256i s01 = _mm256_sign_epi8(e0, f1);
- __m256i s11 = _mm256_sign_epi8(e1, f1);
- __m256i s21 = _mm256_sign_epi8(e2, f1);
- __m256i s31 = _mm256_sign_epi8(e3, f1);
- c01 = madd(d01, updot(u1, s01), c01);
- c11 = madd(d11, updot(u1, s11), c11);
- c21 = madd(d21, updot(u1, s21), c21);
- c31 = madd(d31, updot(u1, s31), c31);
- float db2 = unhalf(Bp2[l].d);
- __m256 d02 = _mm256_set1_ps(da0 * db2);
- __m256 d12 = _mm256_set1_ps(da1 * db2);
- __m256 d22 = _mm256_set1_ps(da2 * db2);
- __m256 d32 = _mm256_set1_ps(da3 * db2);
- __m256i f2 = load(Bp2 + l);
- __m256i u2 = _mm256_sign_epi8(f2, f2);
- __m256i s02 = _mm256_sign_epi8(e0, f2);
- __m256i s12 = _mm256_sign_epi8(e1, f2);
- __m256i s22 = _mm256_sign_epi8(e2, f2);
- __m256i s32 = _mm256_sign_epi8(e3, f2);
- c02 = madd(d02, updot(u2, s02), c02);
- c12 = madd(d12, updot(u2, s12), c12);
- c22 = madd(d22, updot(u2, s22), c22);
- c32 = madd(d32, updot(u2, s32), c32);
- }
- C[ldc * (j + 0) + (i + 0)] = hsum(c00);
- C[ldc * (j + 0) + (i + 1)] = hsum(c10);
- C[ldc * (j + 0) + (i + 2)] = hsum(c20);
- C[ldc * (j + 0) + (i + 3)] = hsum(c30);
- C[ldc * (j + 1) + (i + 0)] = hsum(c01);
- C[ldc * (j + 1) + (i + 1)] = hsum(c11);
- C[ldc * (j + 1) + (i + 2)] = hsum(c21);
- C[ldc * (j + 1) + (i + 3)] = hsum(c31);
- C[ldc * (j + 2) + (i + 0)] = hsum(c02);
- C[ldc * (j + 2) + (i + 1)] = hsum(c12);
- C[ldc * (j + 2) + (i + 2)] = hsum(c22);
- C[ldc * (j + 2) + (i + 3)] = hsum(c32);
- END_KERNEL()
- }
-
- NOINLINE void gemm4x1(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(4, 1)
- __m256 c0 = _mm256_setzero_ps();
- __m256 c1 = _mm256_setzero_ps();
- __m256 c2 = _mm256_setzero_ps();
- __m256 c3 = _mm256_setzero_ps();
- const TA *Ap0 = A + lda * (i + 0);
- const TA *Ap1 = A + lda * (i + 1);
- const TA *Ap2 = A + lda * (i + 2);
- const TA *Ap3 = A + lda * (i + 3);
- const TB *Bp = B + ldb * j;
- for (int l = 0; l < k; ++l) {
- float db0 = unhalf(Bp[l].d);
- __m256i f = load(Bp + l);
- __m256i u = _mm256_sign_epi8(f, f);
- __m256 d0 = _mm256_set1_ps(unhalf(Ap0[l].d) * db0);
- __m256 d1 = _mm256_set1_ps(unhalf(Ap1[l].d) * db0);
- __m256 d2 = _mm256_set1_ps(unhalf(Ap2[l].d) * db0);
- __m256 d3 = _mm256_set1_ps(unhalf(Ap3[l].d) * db0);
- __m256i e0 = load(Ap0 + l);
- __m256i e1 = load(Ap1 + l);
- __m256i e2 = load(Ap2 + l);
- __m256i e3 = load(Ap3 + l);
- __m256i s0 = _mm256_sign_epi8(e0, f);
- __m256i s1 = _mm256_sign_epi8(e1, f);
- __m256i s2 = _mm256_sign_epi8(e2, f);
- __m256i s3 = _mm256_sign_epi8(e3, f);
- __m256 g0 = updot(u, s0);
- __m256 g1 = updot(u, s1);
- __m256 g2 = updot(u, s2);
- __m256 g3 = updot(u, s3);
- c0 = madd(d0, g0, c0);
- c1 = madd(d1, g1, c1);
- c2 = madd(d2, g2, c2);
- c3 = madd(d3, g3, c3);
- }
- C[ldc * j + (i + 0)] = hsum(c0);
- C[ldc * j + (i + 1)] = hsum(c1);
- C[ldc * j + (i + 2)] = hsum(c2);
- C[ldc * j + (i + 3)] = hsum(c3);
- END_KERNEL()
- }
-
- NOINLINE void gemm1x4(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(1, 4)
- __m256 c0 = _mm256_setzero_ps();
- __m256 c1 = _mm256_setzero_ps();
- __m256 c2 = _mm256_setzero_ps();
- __m256 c3 = _mm256_setzero_ps();
- const TB *Bp0 = B + ldb * (j + 0);
- const TB *Bp1 = B + ldb * (j + 1);
- const TB *Bp2 = B + ldb * (j + 2);
- const TB *Bp3 = B + ldb * (j + 3);
- const TA *Ap = A + lda * i;
- for (int l = 0; l < k; ++l) {
- float da0 = unhalf(Ap[l].d);
- __m256i f = load(Ap + l);
- __m256i u = _mm256_sign_epi8(f, f);
- __m256 d0 = _mm256_set1_ps(unhalf(Bp0[l].d) * da0);
- __m256 d1 = _mm256_set1_ps(unhalf(Bp1[l].d) * da0);
- __m256 d2 = _mm256_set1_ps(unhalf(Bp2[l].d) * da0);
- __m256 d3 = _mm256_set1_ps(unhalf(Bp3[l].d) * da0);
- __m256 g0 = updot(u, _mm256_sign_epi8(load(Bp0 + l), f));
- __m256 g1 = updot(u, _mm256_sign_epi8(load(Bp1 + l), f));
- __m256 g2 = updot(u, _mm256_sign_epi8(load(Bp2 + l), f));
- __m256 g3 = updot(u, _mm256_sign_epi8(load(Bp3 + l), f));
- c0 = madd(d0, g0, c0);
- c1 = madd(d1, g1, c1);
- c2 = madd(d2, g2, c2);
- c3 = madd(d3, g3, c3);
- }
- C[ldc * (j + 0) + i] = hsum(c0);
- C[ldc * (j + 1) + i] = hsum(c1);
- C[ldc * (j + 2) + i] = hsum(c2);
- C[ldc * (j + 3) + i] = hsum(c3);
- END_KERNEL()
- }
-
- NOINLINE void gemm1x1(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(1, 1)
- __m256 c = _mm256_setzero_ps();
- const TA *Ap = A + lda * i;
- const TB *Bp = B + ldb * j;
- for (int l = 0; l < k; ++l) {
- __m256 d = _mm256_set1_ps(unhalf(Ap[l].d) * unhalf(Bp[l].d));
- __m256i e = load(Ap + l);
- __m256i f = load(Bp + l);
- __m256 g = updot(_mm256_sign_epi8(e, e), _mm256_sign_epi8(f, e));
- c = madd(d, g, c);
+ mnpack(m0, m, np, n);
+ }
+
+ template <int RM, int RN>
+ NOINLINE void gemm(int m0, int m, int n0, int n) {
+ int ytiles = (m - m0) / RM;
+ int xtiles = (n - n0) / RN;
+ int tiles = xtiles * ytiles;
+ int duty = (tiles + nth - 1) / nth;
+ int start = duty * ith;
+ int end = start + duty;
+ if (end > tiles)
+ end = tiles;
+ for (int job = start; job < end; ++job) {
+ int ii = m0 + job / xtiles * RM;
+ int jj = n0 + job % xtiles * RN;
+ __m256 Cv[RN][RM] = {};
+ for (int l = 0; l < k; ++l)
+ for (int j = 0; j < RN; ++j)
+ for (int i = 0; i < RM; ++i)
+ Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
+ unhalf(B[ldb * (jj + j) + l].d)),
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+ load(A + lda * (ii + i) + l)),
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
+ load(A + lda * (ii + i) + l))),
+ Cv[j][i]);
+ for (int j = 0; j < RN; ++j)
+ for (int i = 0; i < RM; ++i)
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  }
- C[ldc * j + i] = hsum(c);
- END_KERNEL()
  }

  inline __m256i load(const block_q8_0 *b) {
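
Note on the AVX2 kernel above: _mm256_sign_epi8(x, y) negates each byte of x where the corresponding byte of y is negative and zeroes it where y is zero, so updot(_mm256_sign_epi8(a, a), _mm256_sign_epi8(b, a)) feeds an unsigned-times-signed byte product with |a_i| and sign(a_i) * b_i, whose lane-wise product equals a_i * b_i (updot itself is defined elsewhere in this file and is not shown in this diff). The scalar sketch below only checks that lane-wise identity; sign_like and the test values are illustrative.

#include <cstdio>
#include <cstdlib>

// Scalar analogue of one lane of _mm256_sign_epi8.
static int sign_like(int x, int y) {
    return y < 0 ? -x : (y == 0 ? 0 : x);
}

int main() {
    const int a[4] = {-3, 0, 7, -1};
    const int b[4] = {5, -2, -4, 6};
    for (int i = 0; i < 4; ++i) {
        int lhs = sign_like(a[i], a[i]) * sign_like(b[i], a[i]); // |a| * (sign(a) * b)
        int rhs = a[i] * b[i];
        std::printf("%3d * %3d: %4d == %4d\n", a[i], b[i], lhs, rhs);
        if (lhs != rhs)
            std::abort();
    }
    return 0;
}
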
@@ -911,10 +762,10 @@ class tinyBLAS_Q0_AVX2 {
  }

  static inline __m256i denibble(const uint8_t *p) {
- const __m128i tmp = _mm_loadu_si128((const __m128i *)p);
- const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
- const __m256i lowMask = _mm256_set1_epi8(15);
- return _mm256_and_si256(lowMask, bytes);
+ __m128i x = _mm_loadu_si128((const __m128i *)p);
+ return _mm256_and_si256(_mm256_set1_epi8(15),
+ _mm256_insertf128_si256(_mm256_castsi128_si256(x),
+ _mm_srli_epi16(x, 4), 1));
  }

  const TA *const A;
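
Note on denibble() above: both the old and the new version unpack 16 bytes of packed 4-bit quants into 32 bytes, with the low nibbles in the lower 128-bit half and the high nibbles in the upper half, each masked to 0..15; the rewrite just expands the MM256_SET_M128I macro inline. A scalar sketch of the same unpacking, with illustrative names:

#include <cstdint>
#include <cstdio>

static void denibble_scalar(const uint8_t *p, uint8_t out[32]) {
    for (int i = 0; i < 16; ++i) {
        out[i]      = p[i] & 0x0F;        // low nibbles -> first 16 bytes
        out[16 + i] = (p[i] >> 4) & 0x0F; // high nibbles -> last 16 bytes
    }
}

int main() {
    uint8_t packed[16];
    for (int i = 0; i < 16; ++i)
        packed[i] = (uint8_t)((i & 0x0F) | ((15 - i) << 4));
    uint8_t out[32];
    denibble_scalar(packed, out);
    for (int i = 0; i < 32; ++i)
        std::printf("%d ", out[i]);
    std::printf("\n");
    return 0;
}
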