llama_cpp 0.14.6 → 0.14.7

This diff shows the publicly released contents of the two package versions as they appear in their respective public registries; it is provided for informational purposes only.
@@ -50,6 +50,7 @@
  #pragma GCC diagnostic ignored "-Wignored-attributes"

  #include "sgemm.h"
+ #include <algorithm>
  #include "ggml-impl.h"
  #include "ggml-quants.h"

@@ -65,22 +66,6 @@
  #define VECTOR_REGISTERS 16
  #endif

- // there will be blocks
- #define BEGIN_KERNEL(RM, RN) \
- int ytiles = (m - m0) / RM; \
- int xtiles = (n - n0) / RN; \
- int tiles = ytiles * xtiles; \
- int duty = (tiles + nth - 1) / nth; \
- int start = duty * ith; \
- int end = start + duty; \
- if (end > tiles) \
- end = tiles; \
- for (int job = start; job < end; ++job) { \
- int i = m0 + job / xtiles * RM; \
- int j = n0 + job % xtiles * RN;
-
- #define END_KERNEL() }
-
  #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

  namespace {
@@ -122,6 +107,45 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
  inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
  #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+ // VECTORIZED FUSED MULTIPLY ADD
+
+ /**
+ * Computes a * b + c.
+ */
+ template <typename T, typename U>
+ inline U madd(T a, T b, U c) {
+ return add(mul(a, b), c);
+ }
+
+ #if defined(__FMA__)
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+ template <>
+ inline __m256 madd(__m256 a, __m256 b, __m256 c) {
+ return _mm256_fmadd_ps(a, b, c);
+ }
+ #endif
+ #if defined(__AVX512F__)
+ template <>
+ inline __m512 madd(__m512 a, __m512 b, __m512 c) {
+ return _mm512_fmadd_ps(a, b, c);
+ }
+ #endif
+ #endif
+
+ #if defined(__ARM_FEATURE_FMA)
+ template <>
+ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
+ return vfmaq_f32(c, b, a);
+ }
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
+ template <>
+ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
+ return vfmaq_f16(c, b, a);
+ }
+ #endif
+ #endif
+
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // VECTORIZED HORIZONTAL SUM

@@ -213,36 +237,6 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
  }
  #endif // __AVX512F__

- ////////////////////////////////////////////////////////////////////////////////////////////////////
- // ABSTRACTIONS
-
- /**
- * Computes a * b + c.
- *
- * This operation will become fused into a single arithmetic instruction
- * if the hardware has support for this feature, e.g. Intel Haswell+ (c.
- * 2013), AMD Bulldozer+ (c. 2011), etc.
- */
- template <typename T, typename U>
- inline U madd(T a, T b, U c) {
- return add(mul(a, b), c);
- }
-
- /**
- * Computes a * b + c with error correction.
- *
- * @see W. Kahan, "Further remarks on reducing truncation errors,"
- * Communications of the ACM, vol. 8, no. 1, p. 40, Jan. 1965,
- * doi: 10.1145/363707.363723.
- */
- template <typename T, typename U>
- inline U madder(T a, T b, U c, U *e) {
- U y = sub(mul(a, b), *e);
- U t = add(c, y);
- *e = sub(sub(t, c), y);
- return t;
- }
-
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // FLOATING POINT MATRIX MULTIPLICATION

@@ -265,226 +259,179 @@ class tinyBLAS {
  private:
  NOINLINE void mnpack(int m0, int m, int n0, int n) {
  int mc, nc, mp, np;
- if (m - m0 <= 0 || n - n0 <= 0)
- return;
- if (VECTOR_REGISTERS >= 32 && n - n0 >= 5 && m - m0 >= 5) {
+ switch ((std::min(m - m0, 5) << 4) | std::min(n - n0, 5)) {
+ #if VECTOR_REGISTERS == 32
+ case 0x55:
+ mc = 5;
+ nc = 5;
+ gemm<5, 5>(m0, m, n0, n);
+ break;
+ case 0x45:
+ mc = 4;
+ nc = 5;
+ gemm<4, 5>(m0, m, n0, n);
+ break;
+ case 0x54:
  mc = 5;
+ nc = 4;
+ gemm<5, 4>(m0, m, n0, n);
+ break;
+ case 0x44:
+ mc = 4;
+ nc = 4;
+ gemm<4, 4>(m0, m, n0, n);
+ break;
+ case 0x53:
+ mc = 5;
+ nc = 3;
+ gemm<5, 3>(m0, m, n0, n);
+ break;
+ case 0x35:
+ mc = 3;
  nc = 5;
- gemm5x5(m0, m, n0, n);
- } else if (n - n0 >= 4 && m - m0 >= 3) {
+ gemm<3, 5>(m0, m, n0, n);
+ break;
+ case 0x43:
+ mc = 4;
+ nc = 3;
+ gemm<4, 3>(m0, m, n0, n);
+ break;
+ #else
+ case 0x55:
+ case 0x54:
+ case 0x53:
+ case 0x45:
+ case 0x44:
+ case 0x43:
+ mc = 4;
+ nc = 3;
+ gemm<4, 3>(m0, m, n0, n);
+ break;
+ case 0x35:
+ #endif
+ case 0x34:
  mc = 3;
  nc = 4;
- gemm3x4(m0, m, n0, n);
- } else if (n - n0 >= 4) {
- mc = 1;
+ gemm<3, 4>(m0, m, n0, n);
+ break;
+ case 0x52:
+ mc = 5;
+ nc = 2;
+ gemm<5, 2>(m0, m, n0, n);
+ break;
+ case 0x33:
+ mc = 3;
+ nc = 3;
+ gemm<3, 3>(m0, m, n0, n);
+ break;
+ case 0x25:
+ mc = 2;
+ nc = 5;
+ gemm<2, 5>(m0, m, n0, n);
+ break;
+ case 0x42:
+ mc = 4;
+ nc = 2;
+ gemm<4, 2>(m0, m, n0, n);
+ break;
+ case 0x24:
+ mc = 2;
  nc = 4;
- gemm1x4(m0, m, n0, n);
- } else if (m - m0 >= 4) {
+ gemm<2, 4>(m0, m, n0, n);
+ break;
+ case 0x32:
+ mc = 3;
+ nc = 2;
+ gemm<3, 2>(m0, m, n0, n);
+ break;
+ case 0x23:
+ mc = 2;
+ nc = 3;
+ gemm<2, 3>(m0, m, n0, n);
+ break;
+ case 0x51:
+ mc = 5;
+ nc = 1;
+ gemm<5, 1>(m0, m, n0, n);
+ break;
+ case 0x41:
  mc = 4;
  nc = 1;
- gemm4x1(m0, m, n0, n);
- } else {
+ gemm<4, 1>(m0, m, n0, n);
+ break;
+ case 0x22:
+ mc = 2;
+ nc = 2;
+ gemm<2, 2>(m0, m, n0, n);
+ break;
+ case 0x15:
+ mc = 1;
+ nc = 5;
+ gemm<1, 5>(m0, m, n0, n);
+ break;
+ case 0x14:
+ mc = 1;
+ nc = 4;
+ gemm<1, 4>(m0, m, n0, n);
+ break;
+ case 0x31:
+ mc = 3;
+ nc = 1;
+ gemm<3, 1>(m0, m, n0, n);
+ break;
+ case 0x13:
  mc = 1;
+ nc = 3;
+ gemm<1, 3>(m0, m, n0, n);
+ break;
+ case 0x21:
+ mc = 2;
  nc = 1;
- gemm1x1(m0, m, n0, n);
+ gemm<2, 1>(m0, m, n0, n);
+ break;
+ case 0x12:
+ mc = 1;
+ nc = 2;
+ gemm<1, 2>(m0, m, n0, n);
+ break;
+ case 0x11:
+ mc = 1;
+ nc = 1;
+ gemm<1, 1>(m0, m, n0, n);
+ break;
+ default:
+ return;
  }
  mp = m0 + (m - m0) / mc * mc;
  np = n0 + (n - n0) / nc * nc;
  mnpack(mp, m, n0, np);
- mnpack(m0, mp, np, n);
- mnpack(mp, m, np, n);
- }
-
- NOINLINE void gemm5x5(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(5, 5)
- D c00 = {0};
- D c01 = {0};
- D c02 = {0};
- D c03 = {0};
- D c04 = {0};
- D c10 = {0};
- D c11 = {0};
- D c12 = {0};
- D c13 = {0};
- D c14 = {0};
- D c20 = {0};
- D c21 = {0};
- D c22 = {0};
- D c23 = {0};
- D c24 = {0};
- D c30 = {0};
- D c31 = {0};
- D c32 = {0};
- D c33 = {0};
- D c34 = {0};
- D c40 = {0};
- D c41 = {0};
- D c42 = {0};
- D c43 = {0};
- D c44 = {0};
- for (int l = 0; l < k; l += KN) {
- V k0 = load<V>(B + ldb * (j + 0) + l);
- V k1 = load<V>(B + ldb * (j + 1) + l);
- V k2 = load<V>(B + ldb * (j + 2) + l);
- V k3 = load<V>(B + ldb * (j + 3) + l);
- V k4 = load<V>(B + ldb * (j + 4) + l);
- V a0 = load<V>(A + lda * (i + 0) + l);
- c00 = madd(a0, k0, c00);
- c01 = madd(a0, k1, c01);
- c02 = madd(a0, k2, c02);
- c03 = madd(a0, k3, c03);
- c04 = madd(a0, k4, c04);
- V a1 = load<V>(A + lda * (i + 1) + l);
- c10 = madd(a1, k0, c10);
- c11 = madd(a1, k1, c11);
- c12 = madd(a1, k2, c12);
- c13 = madd(a1, k3, c13);
- c14 = madd(a1, k4, c14);
- V a2 = load<V>(A + lda * (i + 2) + l);
- c20 = madd(a2, k0, c20);
- c21 = madd(a2, k1, c21);
- c22 = madd(a2, k2, c22);
- c23 = madd(a2, k3, c23);
- c24 = madd(a2, k4, c24);
- V a3 = load<V>(A + lda * (i + 3) + l);
- c30 = madd(a3, k0, c30);
- c31 = madd(a3, k1, c31);
- c32 = madd(a3, k2, c32);
- c33 = madd(a3, k3, c33);
- c34 = madd(a3, k4, c34);
- V a4 = load<V>(A + lda * (i + 4) + l);
- c40 = madd(a4, k0, c40);
- c41 = madd(a4, k1, c41);
- c42 = madd(a4, k2, c42);
- c43 = madd(a4, k3, c43);
- c44 = madd(a4, k4, c44);
- }
- C[ldc * (j + 0) + (i + 0)] = hsum(c00);
- C[ldc * (j + 0) + (i + 1)] = hsum(c10);
- C[ldc * (j + 0) + (i + 2)] = hsum(c20);
- C[ldc * (j + 0) + (i + 3)] = hsum(c30);
- C[ldc * (j + 0) + (i + 4)] = hsum(c40);
- C[ldc * (j + 1) + (i + 0)] = hsum(c01);
- C[ldc * (j + 1) + (i + 1)] = hsum(c11);
- C[ldc * (j + 1) + (i + 2)] = hsum(c21);
- C[ldc * (j + 1) + (i + 3)] = hsum(c31);
- C[ldc * (j + 1) + (i + 4)] = hsum(c41);
- C[ldc * (j + 2) + (i + 0)] = hsum(c02);
- C[ldc * (j + 2) + (i + 1)] = hsum(c12);
- C[ldc * (j + 2) + (i + 2)] = hsum(c22);
- C[ldc * (j + 2) + (i + 3)] = hsum(c32);
- C[ldc * (j + 2) + (i + 4)] = hsum(c42);
- C[ldc * (j + 3) + (i + 0)] = hsum(c03);
- C[ldc * (j + 3) + (i + 1)] = hsum(c13);
- C[ldc * (j + 3) + (i + 2)] = hsum(c23);
- C[ldc * (j + 3) + (i + 3)] = hsum(c33);
- C[ldc * (j + 3) + (i + 4)] = hsum(c43);
- C[ldc * (j + 4) + (i + 0)] = hsum(c04);
- C[ldc * (j + 4) + (i + 1)] = hsum(c14);
- C[ldc * (j + 4) + (i + 2)] = hsum(c24);
- C[ldc * (j + 4) + (i + 3)] = hsum(c34);
- C[ldc * (j + 4) + (i + 4)] = hsum(c44);
- END_KERNEL()
- }
-
- NOINLINE void gemm3x4(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(3, 4)
- D c00 = {0};
- D c01 = {0};
- D c02 = {0};
- D c03 = {0};
- D c10 = {0};
- D c11 = {0};
- D c12 = {0};
- D c13 = {0};
- D c20 = {0};
- D c21 = {0};
- D c22 = {0};
- D c23 = {0};
- for (int l = 0; l < k; l += KN) {
- V k0 = load<V>(B + ldb * (j + 0) + l);
- V k1 = load<V>(B + ldb * (j + 1) + l);
- V k2 = load<V>(B + ldb * (j + 2) + l);
- V k3 = load<V>(B + ldb * (j + 3) + l);
- V a0 = load<V>(A + lda * (i + 0) + l);
- c00 = madd(a0, k0, c00);
- c01 = madd(a0, k1, c01);
- c02 = madd(a0, k2, c02);
- c03 = madd(a0, k3, c03);
- V a1 = load<V>(A + lda * (i + 1) + l);
- c10 = madd(a1, k0, c10);
- c11 = madd(a1, k1, c11);
- c12 = madd(a1, k2, c12);
- c13 = madd(a1, k3, c13);
- V a2 = load<V>(A + lda * (i + 2) + l);
- c20 = madd(a2, k0, c20);
- c21 = madd(a2, k1, c21);
- c22 = madd(a2, k2, c22);
- c23 = madd(a2, k3, c23);
- }
- C[ldc * (j + 0) + (i + 0)] = hsum(c00);
- C[ldc * (j + 0) + (i + 1)] = hsum(c10);
- C[ldc * (j + 0) + (i + 2)] = hsum(c20);
- C[ldc * (j + 1) + (i + 0)] = hsum(c01);
- C[ldc * (j + 1) + (i + 1)] = hsum(c11);
- C[ldc * (j + 1) + (i + 2)] = hsum(c21);
- C[ldc * (j + 2) + (i + 0)] = hsum(c02);
- C[ldc * (j + 2) + (i + 1)] = hsum(c12);
- C[ldc * (j + 2) + (i + 2)] = hsum(c22);
- C[ldc * (j + 3) + (i + 0)] = hsum(c03);
- C[ldc * (j + 3) + (i + 1)] = hsum(c13);
- C[ldc * (j + 3) + (i + 2)] = hsum(c23);
- END_KERNEL()
- }
-
- NOINLINE void gemm1x4(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(1, 4)
- D c00 = {0}, e00 = {0};
- D c01 = {0}, e01 = {0};
- D c02 = {0}, e02 = {0};
- D c03 = {0}, e03 = {0};
- for (int l = 0; l < k; l += KN) {
- V a = load<V>(A + lda * (i + 0) + l);
- c00 = madder(a, load<V>(B + ldb * (j + 0) + l), c00, &e00);
- c01 = madder(a, load<V>(B + ldb * (j + 1) + l), c01, &e01);
- c02 = madder(a, load<V>(B + ldb * (j + 2) + l), c02, &e02);
- c03 = madder(a, load<V>(B + ldb * (j + 3) + l), c03, &e03);
- }
- C[ldc * (j + 0) + (i + 0)] = hsum(c00);
- C[ldc * (j + 1) + (i + 0)] = hsum(c01);
- C[ldc * (j + 2) + (i + 0)] = hsum(c02);
- C[ldc * (j + 3) + (i + 0)] = hsum(c03);
- END_KERNEL()
- }
-
- NOINLINE void gemm4x1(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(4, 1)
- D c00 = {0}, e00 = {0};
- D c10 = {0}, e10 = {0};
- D c20 = {0}, e20 = {0};
- D c30 = {0}, e30 = {0};
- for (int l = 0; l < k; l += KN) {
- V b = load<V>(B + ldb * (j + 0) + l);
- c00 = madder(load<V>(A + lda * (i + 0) + l), b, c00, &e00);
- c10 = madder(load<V>(A + lda * (i + 1) + l), b, c10, &e10);
- c20 = madder(load<V>(A + lda * (i + 2) + l), b, c20, &e20);
- c30 = madder(load<V>(A + lda * (i + 3) + l), b, c30, &e30);
+ mnpack(m0, m, np, n);
+ }
+
+ template <int RM, int RN>
+ NOINLINE void gemm(int m0, int m, int n0, int n) {
+ int ytiles = (m - m0) / RM;
+ int xtiles = (n - n0) / RN;
+ int tiles = xtiles * ytiles;
+ int duty = (tiles + nth - 1) / nth;
+ int start = duty * ith;
+ int end = start + duty;
+ if (end > tiles)
+ end = tiles;
+ for (int job = start; job < end; ++job) {
+ int ii = m0 + job / xtiles * RM;
+ int jj = n0 + job % xtiles * RN;
+ D Cv[RN][RM] = {};
+ for (int l = 0; l < k; l += KN)
+ for (int j = 0; j < RN; ++j)
+ for (int i = 0; i < RM; ++i)
+ Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
+ load<V>(B + ldb * (jj + j) + l),
+ Cv[j][i]);
+ for (int j = 0; j < RN; ++j)
+ for (int i = 0; i < RM; ++i)
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  }
- C[ldc * (j + 0) + (i + 0)] = hsum(c00);
- C[ldc * (j + 0) + (i + 1)] = hsum(c10);
- C[ldc * (j + 0) + (i + 2)] = hsum(c20);
- C[ldc * (j + 0) + (i + 3)] = hsum(c30);
- END_KERNEL()
- }
-
- NOINLINE void gemm1x1(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(1, 1)
- D c = {0}, e = {0};
- for (int l = 0; l < k; l += KN)
- c = madder(load<V>(A + lda * i + l),
- load<V>(B + ldb * j + l), c, &e);
- C[ldc * j + i] = hsum(c);
- END_KERNEL()
  }

  const TA *const A;
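The rewritten mnpack() packs the clamped remainder dimensions into a single key, (std::min(m - m0, 5) << 4) | std::min(n - n0, 5), so case 0x45 means "at least 4 rows and at least 5 columns remain", and the templated gemm<RM, RN> replaces the hand-unrolled gemm5x5/gemm3x4/... kernels with one RM x RN register-tiled loop shared across all shapes. The following is a minimal standalone C++ sketch of the same dispatch key and thread/tile partitioning, using plain scalars and hypothetical names (dispatch_key, show_tiles), for illustration only:

// Illustrative sketch: mirrors the dispatch key and job partitioning used by
// the new mnpack()/gemm<RM, RN>, with plain ints instead of SIMD registers.
#include <algorithm>
#include <cstdio>

// Same key construction as the diff: clamp both remainders to 5 and pack
// them into one byte, rows in the high nibble, columns in the low nibble.
static int dispatch_key(int m0, int m, int n0, int n) {
    return (std::min(m - m0, 5) << 4) | std::min(n - n0, 5);
}

// Same work split as gemm<RM, RN>: the RM x RN tiles are numbered row-major
// and divided evenly (rounding up) across nth threads; thread ith handles
// jobs [start, end).
static void show_tiles(int m, int n, int RM, int RN, int ith, int nth) {
    int ytiles = m / RM;
    int xtiles = n / RN;
    int tiles  = xtiles * ytiles;
    int duty   = (tiles + nth - 1) / nth;
    int start  = duty * ith;
    int end    = std::min(start + duty, tiles);
    for (int job = start; job < end; ++job)
        std::printf("thread %d -> tile at row %d, col %d\n",
                    ith, job / xtiles * RM, job % xtiles * RN);
}

int main() {
    std::printf("key for 7x2 remainder: 0x%02x\n", dispatch_key(0, 7, 0, 2)); // 0x52
    show_tiles(10, 10, 5, 5, 0, 2); // thread 0 of 2, 10x10 output, 5x5 tiles
}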
@@ -521,120 +468,97 @@ class tinyBLAS_Q0_ARM {
  private:
  NOINLINE void mnpack(int m0, int m, int n0, int n) {
  int mc, nc, mp, np;
- if (m - m0 <= 0 || n - n0 <= 0)
- return;
- if (m - m0 >= 3 && n - n0 >= 3) {
+ switch ((std::min(m - m0, 3) << 4) | std::min(n - n0, 3)) {
+ case 0x33:
  mc = 3;
  nc = 3;
- gemm3x3(m0, m, n0, n);
- } else {
+ gemm<3, 3>(m0, m, n0, n);
+ break;
+ case 0x32:
+ mc = 3;
+ nc = 2;
+ gemm<3, 2>(m0, m, n0, n);
+ break;
+ case 0x23:
+ mc = 2;
+ nc = 3;
+ gemm<2, 3>(m0, m, n0, n);
+ break;
+ case 0x22:
+ mc = 2;
+ nc = 2;
+ gemm<2, 2>(m0, m, n0, n);
+ break;
+ case 0x31:
+ mc = 3;
+ nc = 1;
+ gemm<3, 1>(m0, m, n0, n);
+ break;
+ case 0x13:
+ mc = 1;
+ nc = 3;
+ gemm<1, 3>(m0, m, n0, n);
+ break;
+ case 0x21:
+ mc = 2;
+ nc = 1;
+ gemm<2, 1>(m0, m, n0, n);
+ break;
+ case 0x12:
+ mc = 1;
+ nc = 2;
+ gemm<1, 2>(m0, m, n0, n);
+ break;
+ case 0x11:
  mc = 1;
  nc = 1;
- gemm1x1(m0, m, n0, n);
+ gemm<1, 1>(m0, m, n0, n);
+ break;
+ default:
+ return;
  }
  mp = m0 + (m - m0) / mc * mc;
  np = n0 + (n - n0) / nc * nc;
  mnpack(mp, m, n0, np);
- mnpack(m0, mp, np, n);
- mnpack(mp, m, np, n);
- }
-
- NOINLINE void gemm3x3(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(3, 3)
- int32x4_t zero = vdupq_n_s32(0);
- float32x4_t c00 = vdupq_n_f32(0.f);
- float32x4_t c01 = vdupq_n_f32(0.f);
- float32x4_t c02 = vdupq_n_f32(0.f);
- float32x4_t c10 = vdupq_n_f32(0.f);
- float32x4_t c11 = vdupq_n_f32(0.f);
- float32x4_t c12 = vdupq_n_f32(0.f);
- float32x4_t c20 = vdupq_n_f32(0.f);
- float32x4_t c21 = vdupq_n_f32(0.f);
- float32x4_t c22 = vdupq_n_f32(0.f);
- const TA *Ap0 = A + lda * (i + 0);
- const TA *Ap1 = A + lda * (i + 1);
- const TA *Ap2 = A + lda * (i + 2);
- const block_q8_0 *Bp0 = B + ldb * (j + 0);
- const block_q8_0 *Bp1 = B + ldb * (j + 1);
- const block_q8_0 *Bp2 = B + ldb * (j + 2);
- for (int l = 0; l < k; ++l) {
- c00 = vmlaq_n_f32(
- c00,
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp0 + l)),
- load_hi(Ap0 + l), load_hi(Bp0 + l))),
- unhalf(Ap0[l].d) * unhalf(Bp0[l].d));
- c01 = vmlaq_n_f32(
- c01,
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp1 + l)),
- load_hi(Ap0 + l), load_hi(Bp1 + l))),
- unhalf(Ap0[l].d) * unhalf(Bp1[l].d));
- c02 = vmlaq_n_f32(
- c02,
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp2 + l)),
- load_hi(Ap0 + l), load_hi(Bp2 + l))),
- unhalf(Ap0[l].d) * unhalf(Bp2[l].d));
- c10 = vmlaq_n_f32(
- c10,
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp0 + l)),
- load_hi(Ap1 + l), load_hi(Bp0 + l))),
- unhalf(Ap1[l].d) * unhalf(Bp0[l].d));
- c11 = vmlaq_n_f32(
- c11,
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp1 + l)),
- load_hi(Ap1 + l), load_hi(Bp1 + l))),
- unhalf(Ap1[l].d) * unhalf(Bp1[l].d));
- c12 = vmlaq_n_f32(
- c12,
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp2 + l)),
- load_hi(Ap1 + l), load_hi(Bp2 + l))),
- unhalf(Ap1[l].d) * unhalf(Bp2[l].d));
- c20 = vmlaq_n_f32(
- c20,
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp0 + l)),
- load_hi(Ap2 + l), load_hi(Bp0 + l))),
- unhalf(Ap2[l].d) * unhalf(Bp0[l].d));
- c21 = vmlaq_n_f32(
- c21,
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp1 + l)),
- load_hi(Ap2 + l), load_hi(Bp1 + l))),
- unhalf(Ap2[l].d) * unhalf(Bp1[l].d));
- c22 = vmlaq_n_f32(
- c22,
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp2 + l)),
- load_hi(Ap2 + l), load_hi(Bp2 + l))),
- unhalf(Ap2[l].d) * unhalf(Bp2[l].d));
- }
- C[ldc * (j + 0) + (i + 0)] = hsum(c00);
- C[ldc * (j + 0) + (i + 1)] = hsum(c10);
- C[ldc * (j + 0) + (i + 2)] = hsum(c20);
- C[ldc * (j + 1) + (i + 0)] = hsum(c01);
- C[ldc * (j + 1) + (i + 1)] = hsum(c11);
- C[ldc * (j + 1) + (i + 2)] = hsum(c21);
- C[ldc * (j + 2) + (i + 0)] = hsum(c02);
- C[ldc * (j + 2) + (i + 1)] = hsum(c12);
- C[ldc * (j + 2) + (i + 2)] = hsum(c22);
- END_KERNEL()
- }
-
- NOINLINE void gemm1x1(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(1, 1)
- float32x4_t acc = vdupq_n_f32(0.f);
- const TA *Ap = A + lda * i;
- const block_q8_0 *Bp = B + ldb * j;
- for (int l = 0; l < k; ++l) {
- acc = vmlaq_n_f32(acc,
- vcvtq_f32_s32(vdotq_s32(
- vdotq_s32(vdupq_n_s32(0), load_lo(Ap + l), load_lo(Bp + l)),
- load_hi(Ap + l), load_hi(Bp + l))),
- unhalf(Ap[l].d) * unhalf(Bp[l].d));
+ mnpack(m0, m, np, n);
+ }
+
+ template <int RM, int RN>
+ NOINLINE void gemm(int m0, int m, int n0, int n) {
+ int ytiles = (m - m0) / RM;
+ int xtiles = (n - n0) / RN;
+ int tiles = xtiles * ytiles;
+ int duty = (tiles + nth - 1) / nth;
+ int start = duty * ith;
+ int end = start + duty;
+ if (end > tiles)
+ end = tiles;
+ for (int job = start; job < end; ++job) {
+ int ii = m0 + job / xtiles * RM;
+ int jj = n0 + job % xtiles * RN;
+ float32x4_t Cv[RN][RM] = {};
+ for (int l = 0; l < k; ++l)
+ for (int j = 0; j < RN; ++j)
+ for (int i = 0; i < RM; ++i)
+ Cv[j][i] = vmlaq_n_f32(Cv[j][i],
+ vcvtq_f32_s32(vdotq_s32(
+ vdotq_s32(vdupq_n_s32(0),
+ load_lo(A + lda * (ii + i) + l),
+ load_lo(B + ldb * (jj + j) + l)),
+ load_hi(A + lda * (ii + i) + l),
+ load_hi(B + ldb * (jj + j) + l))),
+ unhalf(A[lda * (ii + i) + l].d) *
+ unhalf(B[ldb * (jj + j) + l].d));
+ for (int j = 0; j < RN; ++j)
+ for (int i = 0; i < RM; ++i)
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  }
- C[ldc * j + i] = hsum(acc);
- END_KERNEL()
  }

  inline int8x16_t load_lo(const block_q8_0 *b) {
  return vld1q_s8(b->qs);
  }
+
  inline int8x16_t load_hi(const block_q8_0 *b) {
  return vld1q_s8(b->qs + 16);
  }
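In the quantized ARM path the same consolidation applies: one templated gemm<RM, RN> accumulates, per pair of blocks, a 32-element int8 dot product (two chained vdotq_s32 over the low and high 16 bytes) scaled by the product of the two block scales. The following is a scalar reference sketch of what one such block term computes, assuming the usual llama.cpp block_q8_0 layout (32 int8 quants plus one scale, stored as fp16 in ggml and converted here by unhalf(); the struct below uses a plain float for simplicity):

// Illustrative scalar reference for one block_q8_0 x block_q8_0 term of the
// accumulator that the NEON kernel builds with vdotq_s32 + vmlaq_n_f32.
#include <cstdint>

struct q8_block {        // stand-in for block_q8_0
    float  d;            // block scale (fp16 in ggml, converted by unhalf())
    int8_t qs[32];       // quantized values
};

// One term of C[jj + j][ii + i]: 32 int8 products summed, then scaled by
// both block scales.
float block_dot(const q8_block &a, const q8_block &b) {
    int sum = 0;
    for (int i = 0; i < 32; ++i)
        sum += int(a.qs[i]) * int(b.qs[i]);
    return float(sum) * a.d * b.d;
}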
@@ -644,6 +568,7 @@ class tinyBLAS_Q0_ARM {
  vdupq_n_u8(0x0f))),
  vdupq_n_s8(0x8));
  }
+
  inline int8x16_t load_hi(const block_q4_0 *b) {
  return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
  vdupq_n_s8(0x8));
@@ -679,217 +604,143 @@ class tinyBLAS_Q0_AVX2 {
  }

  private:
- NOINLINE void mnpack(int m0, int m, int n0, int n) {
+ void mnpack(int m0, int m, int n0, int n) {
  int mc, nc, mp, np;
- if (m - m0 <= 0 || n - n0 <= 0)
- return;
- if (m - m0 >= 4 && n - n0 >= 3) {
+ switch ((std::min(m - m0, 4) << 4) | std::min(n - n0, 4)) {
+ #if VECTOR_REGISTERS == 32
+ case 0x44:
+ mc = 4;
+ nc = 4;
+ gemm<4, 4>(m0, m, n0, n);
+ break;
+ case 0x43:
+ mc = 4;
+ nc = 3;
+ gemm<4, 3>(m0, m, n0, n);
+ break;
+ case 0x34:
+ mc = 3;
+ nc = 4;
+ gemm<3, 4>(m0, m, n0, n);
+ break;
+ case 0x33:
+ mc = 3;
+ nc = 3;
+ gemm<3, 3>(m0, m, n0, n);
+ break;
+ case 0x42:
  mc = 4;
+ nc = 2;
+ gemm<4, 2>(m0, m, n0, n);
+ break;
+ case 0x24:
+ mc = 2;
+ nc = 4;
+ gemm<2, 4>(m0, m, n0, n);
+ break;
+ #else
+ case 0x44:
+ case 0x43:
+ case 0x42:
+ mc = 4;
+ nc = 2;
+ gemm<4, 2>(m0, m, n0, n);
+ break;
+ case 0x34:
+ case 0x24:
+ mc = 2;
+ nc = 4;
+ gemm<2, 4>(m0, m, n0, n);
+ break;
+ case 0x33:
+ #endif
+ case 0x32:
+ mc = 3;
+ nc = 2;
+ gemm<3, 2>(m0, m, n0, n);
+ break;
+ case 0x23:
+ mc = 2;
  nc = 3;
- gemm4x3(m0, m, n0, n);
- } else if (m - m0 >= 4 && n - n0 >= 1) {
+ gemm<2, 3>(m0, m, n0, n);
+ break;
+ case 0x41:
  mc = 4;
  nc = 1;
- gemm4x1(m0, m, n0, n);
- } else if (m - m0 >= 1 && n - n0 >= 4) {
+ gemm<4, 1>(m0, m, n0, n);
+ break;
+ case 0x22:
+ mc = 2;
+ nc = 2;
+ gemm<2, 2>(m0, m, n0, n);
+ break;
+ case 0x14:
  mc = 1;
  nc = 4;
- gemm1x4(m0, m, n0, n);
- } else {
+ gemm<1, 4>(m0, m, n0, n);
+ break;
+ case 0x31:
+ mc = 3;
+ nc = 1;
+ gemm<3, 1>(m0, m, n0, n);
+ break;
+ case 0x13:
  mc = 1;
+ nc = 3;
+ gemm<1, 3>(m0, m, n0, n);
+ break;
+ case 0x21:
+ mc = 2;
  nc = 1;
- gemm1x1(m0, m, n0, n);
+ gemm<2, 1>(m0, m, n0, n);
+ break;
+ case 0x12:
+ mc = 1;
+ nc = 2;
+ gemm<1, 2>(m0, m, n0, n);
+ break;
+ case 0x11:
+ mc = 1;
+ nc = 1;
+ gemm<1, 1>(m0, m, n0, n);
+ break;
+ default:
+ return;
  }
  mp = m0 + (m - m0) / mc * mc;
  np = n0 + (n - n0) / nc * nc;
  mnpack(mp, m, n0, np);
- mnpack(m0, mp, np, n);
- mnpack(mp, m, np, n);
- }
-
- NOINLINE void gemm4x3(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(4, 3)
- __m256 c00 = _mm256_setzero_ps();
- __m256 c10 = _mm256_setzero_ps();
- __m256 c20 = _mm256_setzero_ps();
- __m256 c30 = _mm256_setzero_ps();
- __m256 c01 = _mm256_setzero_ps();
- __m256 c11 = _mm256_setzero_ps();
- __m256 c21 = _mm256_setzero_ps();
- __m256 c31 = _mm256_setzero_ps();
- __m256 c02 = _mm256_setzero_ps();
- __m256 c12 = _mm256_setzero_ps();
- __m256 c22 = _mm256_setzero_ps();
- __m256 c32 = _mm256_setzero_ps();
- const TA *Ap0 = A + lda * (i + 0);
- const TA *Ap1 = A + lda * (i + 1);
- const TA *Ap2 = A + lda * (i + 2);
- const TA *Ap3 = A + lda * (i + 3);
- const TB *Bp0 = B + ldb * (j + 0);
- const TB *Bp1 = B + ldb * (j + 1);
- const TB *Bp2 = B + ldb * (j + 2);
- for (int l = 0; l < k; ++l) {
- float da0 = unhalf(Ap0[l].d);
- float da1 = unhalf(Ap1[l].d);
- float da2 = unhalf(Ap2[l].d);
- float da3 = unhalf(Ap3[l].d);
- __m256i e0 = load(Ap0 + l);
- __m256i e1 = load(Ap1 + l);
- __m256i e2 = load(Ap2 + l);
- __m256i e3 = load(Ap3 + l);
- float db0 = unhalf(Bp0[l].d);
- __m256 d00 = _mm256_set1_ps(da0 * db0);
- __m256 d10 = _mm256_set1_ps(da1 * db0);
- __m256 d20 = _mm256_set1_ps(da2 * db0);
- __m256 d30 = _mm256_set1_ps(da3 * db0);
- __m256i f0 = load(Bp0 + l);
- __m256i u0 = _mm256_sign_epi8(f0, f0);
- __m256i s00 = _mm256_sign_epi8(e0, f0);
- __m256i s10 = _mm256_sign_epi8(e1, f0);
- __m256i s20 = _mm256_sign_epi8(e2, f0);
- __m256i s30 = _mm256_sign_epi8(e3, f0);
- c00 = madd(d00, updot(u0, s00), c00);
- c10 = madd(d10, updot(u0, s10), c10);
- c20 = madd(d20, updot(u0, s20), c20);
- c30 = madd(d30, updot(u0, s30), c30);
- float db1 = unhalf(Bp1[l].d);
- __m256 d01 = _mm256_set1_ps(da0 * db1);
- __m256 d11 = _mm256_set1_ps(da1 * db1);
- __m256 d21 = _mm256_set1_ps(da2 * db1);
- __m256 d31 = _mm256_set1_ps(da3 * db1);
- __m256i f1 = load(Bp1 + l);
- __m256i u1 = _mm256_sign_epi8(f1, f1);
- __m256i s01 = _mm256_sign_epi8(e0, f1);
- __m256i s11 = _mm256_sign_epi8(e1, f1);
- __m256i s21 = _mm256_sign_epi8(e2, f1);
- __m256i s31 = _mm256_sign_epi8(e3, f1);
- c01 = madd(d01, updot(u1, s01), c01);
- c11 = madd(d11, updot(u1, s11), c11);
- c21 = madd(d21, updot(u1, s21), c21);
- c31 = madd(d31, updot(u1, s31), c31);
- float db2 = unhalf(Bp2[l].d);
- __m256 d02 = _mm256_set1_ps(da0 * db2);
- __m256 d12 = _mm256_set1_ps(da1 * db2);
- __m256 d22 = _mm256_set1_ps(da2 * db2);
- __m256 d32 = _mm256_set1_ps(da3 * db2);
- __m256i f2 = load(Bp2 + l);
- __m256i u2 = _mm256_sign_epi8(f2, f2);
- __m256i s02 = _mm256_sign_epi8(e0, f2);
- __m256i s12 = _mm256_sign_epi8(e1, f2);
- __m256i s22 = _mm256_sign_epi8(e2, f2);
- __m256i s32 = _mm256_sign_epi8(e3, f2);
- c02 = madd(d02, updot(u2, s02), c02);
- c12 = madd(d12, updot(u2, s12), c12);
- c22 = madd(d22, updot(u2, s22), c22);
- c32 = madd(d32, updot(u2, s32), c32);
- }
- C[ldc * (j + 0) + (i + 0)] = hsum(c00);
- C[ldc * (j + 0) + (i + 1)] = hsum(c10);
- C[ldc * (j + 0) + (i + 2)] = hsum(c20);
- C[ldc * (j + 0) + (i + 3)] = hsum(c30);
- C[ldc * (j + 1) + (i + 0)] = hsum(c01);
- C[ldc * (j + 1) + (i + 1)] = hsum(c11);
- C[ldc * (j + 1) + (i + 2)] = hsum(c21);
- C[ldc * (j + 1) + (i + 3)] = hsum(c31);
- C[ldc * (j + 2) + (i + 0)] = hsum(c02);
- C[ldc * (j + 2) + (i + 1)] = hsum(c12);
- C[ldc * (j + 2) + (i + 2)] = hsum(c22);
- C[ldc * (j + 2) + (i + 3)] = hsum(c32);
- END_KERNEL()
- }
-
- NOINLINE void gemm4x1(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(4, 1)
- __m256 c0 = _mm256_setzero_ps();
- __m256 c1 = _mm256_setzero_ps();
- __m256 c2 = _mm256_setzero_ps();
- __m256 c3 = _mm256_setzero_ps();
- const TA *Ap0 = A + lda * (i + 0);
- const TA *Ap1 = A + lda * (i + 1);
- const TA *Ap2 = A + lda * (i + 2);
- const TA *Ap3 = A + lda * (i + 3);
- const TB *Bp = B + ldb * j;
- for (int l = 0; l < k; ++l) {
- float db0 = unhalf(Bp[l].d);
- __m256i f = load(Bp + l);
- __m256i u = _mm256_sign_epi8(f, f);
- __m256 d0 = _mm256_set1_ps(unhalf(Ap0[l].d) * db0);
- __m256 d1 = _mm256_set1_ps(unhalf(Ap1[l].d) * db0);
- __m256 d2 = _mm256_set1_ps(unhalf(Ap2[l].d) * db0);
- __m256 d3 = _mm256_set1_ps(unhalf(Ap3[l].d) * db0);
- __m256i e0 = load(Ap0 + l);
- __m256i e1 = load(Ap1 + l);
- __m256i e2 = load(Ap2 + l);
- __m256i e3 = load(Ap3 + l);
- __m256i s0 = _mm256_sign_epi8(e0, f);
- __m256i s1 = _mm256_sign_epi8(e1, f);
- __m256i s2 = _mm256_sign_epi8(e2, f);
- __m256i s3 = _mm256_sign_epi8(e3, f);
- __m256 g0 = updot(u, s0);
- __m256 g1 = updot(u, s1);
- __m256 g2 = updot(u, s2);
- __m256 g3 = updot(u, s3);
- c0 = madd(d0, g0, c0);
- c1 = madd(d1, g1, c1);
- c2 = madd(d2, g2, c2);
- c3 = madd(d3, g3, c3);
- }
- C[ldc * j + (i + 0)] = hsum(c0);
- C[ldc * j + (i + 1)] = hsum(c1);
- C[ldc * j + (i + 2)] = hsum(c2);
- C[ldc * j + (i + 3)] = hsum(c3);
- END_KERNEL()
- }
-
- NOINLINE void gemm1x4(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(1, 4)
- __m256 c0 = _mm256_setzero_ps();
- __m256 c1 = _mm256_setzero_ps();
- __m256 c2 = _mm256_setzero_ps();
- __m256 c3 = _mm256_setzero_ps();
- const TB *Bp0 = B + ldb * (j + 0);
- const TB *Bp1 = B + ldb * (j + 1);
- const TB *Bp2 = B + ldb * (j + 2);
- const TB *Bp3 = B + ldb * (j + 3);
- const TA *Ap = A + lda * i;
- for (int l = 0; l < k; ++l) {
- float da0 = unhalf(Ap[l].d);
- __m256i f = load(Ap + l);
- __m256i u = _mm256_sign_epi8(f, f);
- __m256 d0 = _mm256_set1_ps(unhalf(Bp0[l].d) * da0);
- __m256 d1 = _mm256_set1_ps(unhalf(Bp1[l].d) * da0);
- __m256 d2 = _mm256_set1_ps(unhalf(Bp2[l].d) * da0);
- __m256 d3 = _mm256_set1_ps(unhalf(Bp3[l].d) * da0);
- __m256 g0 = updot(u, _mm256_sign_epi8(load(Bp0 + l), f));
- __m256 g1 = updot(u, _mm256_sign_epi8(load(Bp1 + l), f));
- __m256 g2 = updot(u, _mm256_sign_epi8(load(Bp2 + l), f));
- __m256 g3 = updot(u, _mm256_sign_epi8(load(Bp3 + l), f));
- c0 = madd(d0, g0, c0);
- c1 = madd(d1, g1, c1);
- c2 = madd(d2, g2, c2);
- c3 = madd(d3, g3, c3);
- }
- C[ldc * (j + 0) + i] = hsum(c0);
- C[ldc * (j + 1) + i] = hsum(c1);
- C[ldc * (j + 2) + i] = hsum(c2);
- C[ldc * (j + 3) + i] = hsum(c3);
- END_KERNEL()
- }
-
- NOINLINE void gemm1x1(int m0, int m, int n0, int n) {
- BEGIN_KERNEL(1, 1)
- __m256 c = _mm256_setzero_ps();
- const TA *Ap = A + lda * i;
- const TB *Bp = B + ldb * j;
- for (int l = 0; l < k; ++l) {
- __m256 d = _mm256_set1_ps(unhalf(Ap[l].d) * unhalf(Bp[l].d));
- __m256i e = load(Ap + l);
- __m256i f = load(Bp + l);
- __m256 g = updot(_mm256_sign_epi8(e, e), _mm256_sign_epi8(f, e));
- c = madd(d, g, c);
+ mnpack(m0, m, np, n);
+ }
+
+ template <int RM, int RN>
+ NOINLINE void gemm(int m0, int m, int n0, int n) {
+ int ytiles = (m - m0) / RM;
+ int xtiles = (n - n0) / RN;
+ int tiles = xtiles * ytiles;
+ int duty = (tiles + nth - 1) / nth;
+ int start = duty * ith;
+ int end = start + duty;
+ if (end > tiles)
+ end = tiles;
+ for (int job = start; job < end; ++job) {
+ int ii = m0 + job / xtiles * RM;
+ int jj = n0 + job % xtiles * RN;
+ __m256 Cv[RN][RM] = {};
+ for (int l = 0; l < k; ++l)
+ for (int j = 0; j < RN; ++j)
+ for (int i = 0; i < RM; ++i)
+ Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
+ unhalf(B[ldb * (jj + j) + l].d)),
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+ load(A + lda * (ii + i) + l)),
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
+ load(A + lda * (ii + i) + l))),
+ Cv[j][i]);
+ for (int j = 0; j < RN; ++j)
+ for (int i = 0; i < RM; ++i)
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  }
- C[ldc * j + i] = hsum(c);
- END_KERNEL()
  }

  inline __m256i load(const block_q8_0 *b) {
@@ -911,10 +762,10 @@ class tinyBLAS_Q0_AVX2 {
  }

  static inline __m256i denibble(const uint8_t *p) {
- const __m128i tmp = _mm_loadu_si128((const __m128i *)p);
- const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
- const __m256i lowMask = _mm256_set1_epi8(15);
- return _mm256_and_si256(lowMask, bytes);
+ __m128i x = _mm_loadu_si128((const __m128i *)p);
+ return _mm256_and_si256(_mm256_set1_epi8(15),
+ _mm256_insertf128_si256(_mm256_castsi128_si256(x),
+ _mm_srli_epi16(x, 4), 1));
  }

  const TA *const A;
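The reworked denibble() inlines the former MM256_SET_M128I macro: it loads 16 packed bytes, keeps them in the low 128-bit lane, puts a 4-bit-right-shifted copy in the high lane, and masks every byte with 15, producing all 32 nibbles as bytes (low nibbles first, then high nibbles). The following is an illustrative scalar sketch of that unpacking under the same layout assumption, with a hypothetical name (denibble_scalar); the Q4_0 load_lo/load_hi pair shown earlier additionally subtracts the offset of 8:

// Illustrative scalar equivalent of denibble(): expand 16 packed bytes
// (two 4-bit quants per byte) into 32 nibble values in the order the AVX2
// version produces them.
#include <cstdint>

void denibble_scalar(const uint8_t *p, uint8_t out[32]) {
    for (int i = 0; i < 16; ++i) {
        out[i]      = p[i] & 15;        // low lane: b & 0x0f
        out[i + 16] = (p[i] >> 4) & 15; // high lane: (b >> 4) & 0x0f
    }
    // For block_q4_0 the kernels then subtract 8 (see load_lo/load_hi above)
    // to recover signed values in [-8, 7].
}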