cui-llama.rn 1.3.4 → 1.3.6

This diff shows the content of publicly available package versions as published to their respective public registries; it is provided for informational purposes only.
Files changed (56)
  1. package/android/src/main/CMakeLists.txt +14 -8
  2. package/android/src/main/jni.cpp +38 -37
  3. package/cpp/common.cpp +50 -30
  4. package/cpp/common.h +32 -13
  5. package/cpp/ggml-alloc.c +0 -1
  6. package/cpp/ggml-backend-reg.cpp +79 -49
  7. package/cpp/ggml-backend.cpp +5 -2
  8. package/cpp/ggml-cpp.h +1 -0
  9. package/cpp/ggml-cpu-aarch64.cpp +57 -72
  10. package/cpp/ggml-cpu-quants.c +5 -1
  11. package/cpp/ggml-cpu.c +6 -6
  12. package/cpp/ggml-cpu.cpp +9 -0
  13. package/cpp/ggml-impl.h +11 -0
  14. package/cpp/ggml-metal.m +2 -2
  15. package/cpp/ggml.c +129 -1388
  16. package/cpp/ggml.h +29 -152
  17. package/cpp/gguf.cpp +1325 -0
  18. package/cpp/gguf.h +202 -0
  19. package/cpp/llama-adapter.cpp +346 -0
  20. package/cpp/llama-adapter.h +73 -0
  21. package/cpp/llama-arch.cpp +1434 -0
  22. package/cpp/llama-arch.h +395 -0
  23. package/cpp/llama-batch.cpp +368 -0
  24. package/cpp/llama-batch.h +88 -0
  25. package/cpp/llama-chat.cpp +567 -0
  26. package/cpp/llama-chat.h +51 -0
  27. package/cpp/llama-context.cpp +1771 -0
  28. package/cpp/llama-context.h +128 -0
  29. package/cpp/llama-cparams.cpp +1 -0
  30. package/cpp/llama-cparams.h +37 -0
  31. package/cpp/llama-cpp.h +30 -0
  32. package/cpp/llama-grammar.cpp +16 -15
  33. package/cpp/llama-grammar.h +5 -6
  34. package/cpp/llama-hparams.cpp +71 -0
  35. package/cpp/llama-hparams.h +140 -0
  36. package/cpp/llama-impl.cpp +167 -0
  37. package/cpp/llama-impl.h +16 -136
  38. package/cpp/llama-kv-cache.cpp +718 -0
  39. package/cpp/llama-kv-cache.h +218 -0
  40. package/cpp/llama-mmap.cpp +589 -0
  41. package/cpp/llama-mmap.h +67 -0
  42. package/cpp/llama-model-loader.cpp +1011 -0
  43. package/cpp/llama-model-loader.h +158 -0
  44. package/cpp/llama-model.cpp +2202 -0
  45. package/cpp/llama-model.h +391 -0
  46. package/cpp/llama-sampling.cpp +117 -4
  47. package/cpp/llama-vocab.cpp +26 -29
  48. package/cpp/llama-vocab.h +14 -2
  49. package/cpp/llama.cpp +8839 -19131
  50. package/cpp/llama.cpp.rej +23 -0
  51. package/cpp/llama.h +31 -9
  52. package/cpp/rn-llama.hpp +39 -37
  53. package/cpp/sgemm.cpp +1091 -378
  54. package/cpp/sgemm.h +2 -2
  55. package/cpp/unicode.cpp +6 -0
  56. package/package.json +1 -1
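Before the per-file diff: the headline change in package/cpp/sgemm.cpp is that the `llamafile_sgemm` entry point no longer takes explicit `ith`/`nth` thread indices; it now receives the whole `lm_ggml_compute_params` struct (so the kernels can also reach `params->threadpool` for `lm_ggml_barrier`), `tinyBLAS::matmul` now returns a `bool` indicating whether it handled the shape, and an `LM_GGML_TYPE_BF16` path is added. A minimal sketch of the signature change, taken from the hunks below; `lm_ggml_compute_params` is only forward-declared here (it is presumably defined in "ggml-cpu-impl.h", which sgemm.cpp already includes):

```cpp
#include <cstdint>

struct lm_ggml_compute_params;  // assumed to come from "ggml-cpu-impl.h"; not shown in this diff

// 1.3.4: thread index and thread count passed explicitly
bool llamafile_sgemm(int64_t m, int64_t n, int64_t k,
                     const void *A, int64_t lda, const void *B, int64_t ldb,
                     void *C, int64_t ldc, int ith, int nth,
                     int Atype, int Btype, int Ctype);

// 1.3.6: the compute-params struct replaces ith/nth
bool llamafile_sgemm(const struct lm_ggml_compute_params *params,
                     int64_t m, int64_t n, int64_t k,
                     const void *A, int64_t lda, const void *B, int64_t ldb,
                     void *C, int64_t ldc,
                     int Atype, int Btype, int Ctype);
```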
package/cpp/sgemm.cpp CHANGED
@@ -53,6 +53,9 @@
53
53
  #include "ggml-cpu-impl.h"
54
54
  #include "ggml-quants.h"
55
55
 
56
+ #include <atomic>
57
+ #include <array>
58
+
56
59
  #ifdef _MSC_VER
57
60
  #define NOINLINE __declspec(noinline)
58
61
  #else
@@ -134,6 +137,16 @@ inline __m512 madd(__m512 a, __m512 b, __m512 c) {
134
137
  return _mm512_fmadd_ps(a, b, c);
135
138
  }
136
139
  #endif
140
+ #if defined(__AVX512BF16__)
141
+ template <>
142
+ inline __m512 madd(__m512bh a, __m512bh b, __m512 c) {
143
+ return _mm512_dpbf16_ps(c, a, b);
144
+ }
145
+ template <>
146
+ inline __m256 madd(__m256bh a, __m256bh b, __m256 c) {
147
+ return _mm256_dpbf16_ps(c, a, b);
148
+ }
149
+ #endif
137
150
  #endif
138
151
 
139
152
  #if defined(__ARM_FEATURE_FMA)
@@ -204,6 +217,7 @@ template <> inline float32x4_t load(const float *p) {
204
217
  return vld1q_f32(p);
205
218
  }
206
219
  #if !defined(_MSC_VER)
220
+ // FIXME: this should check for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
207
221
  template <> inline float16x8_t load(const lm_ggml_fp16_t *p) {
208
222
  return vld1q_f16((const float16_t *)p);
209
223
  }
@@ -225,6 +239,13 @@ template <> inline __m256 load(const float *p) {
225
239
  }
226
240
  #endif // __AVX__
227
241
 
242
+ #if defined(__AVX2__) || defined(__AVX512F__)
243
+ template <> inline __m256 load(const lm_ggml_bf16_t *p) {
244
+ return _mm256_castsi256_ps(
245
+ _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)p)), 16));
246
+ }
247
+ #endif // __AVX2__
248
+
228
249
  #if defined(__F16C__)
229
250
  template <> inline __m256 load(const lm_ggml_fp16_t *p) {
230
251
  return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
@@ -238,8 +259,27 @@ template <> inline __m512 load(const float *p) {
238
259
  template <> inline __m512 load(const lm_ggml_fp16_t *p) {
239
260
  return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
240
261
  }
262
+ template <> inline __m512 load(const lm_ggml_bf16_t *p) {
263
+ return _mm512_castsi512_ps(
264
+ _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)p)), 16));
265
+ }
241
266
  #endif // __AVX512F__
242
267
 
268
+ #if defined(__AVX512BF16__)
269
+ template <> inline __m512bh load(const lm_ggml_bf16_t *p) {
270
+ return (__m512bh)_mm512_loadu_ps((const float *)p);
271
+ }
272
+ template <> inline __m256bh load(const lm_ggml_bf16_t *p) {
273
+ return (__m256bh)_mm256_loadu_ps((const float *)p);
274
+ }
275
+ template <> inline __m512bh load(const float *p) {
276
+ return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p));
277
+ }
278
+ template <> inline __m256bh load(const float *p) {
279
+ return _mm512_cvtneps_pbh(_mm512_loadu_ps(p));
280
+ }
281
+ #endif
282
+
243
283
  ////////////////////////////////////////////////////////////////////////////////////////////////////
244
284
  // CONSTANTS
245
285
 
@@ -251,199 +291,170 @@ static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
251
291
  ////////////////////////////////////////////////////////////////////////////////////////////////////
252
292
  // FLOATING POINT MATRIX MULTIPLICATION
253
293
 
294
+ template <int M>
295
+ static inline int64_t BLOCK_SIZE(size_t m) {
296
+ const int64_t NB_BLOC_M = (m + M - 1) / M;
297
+ return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
298
+ }
299
+
300
+ static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
301
+ return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
302
+ }
303
+
254
304
  template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
255
305
  class tinyBLAS {
256
306
  public:
257
- tinyBLAS(int64_t k,
307
+ tinyBLAS(const lm_ggml_compute_params * params, int64_t k,
258
308
  const TA *A, int64_t lda,
259
309
  const TB *B, int64_t ldb,
260
- TC *C, int64_t ldc,
261
- int ith, int nth)
262
- : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
310
+ TC *C, int64_t ldc)
311
+ : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
263
312
  }
264
313
 
265
- void matmul(int64_t m, int64_t n) {
266
- mnpack(0, m, 0, n);
314
+ bool matmul(int64_t m, int64_t n) {
315
+ if (k % KN != 0)
316
+ return false;
317
+ // compute RM for only need tile with size RM&RM-1
318
+ #if VECTOR_REGISTERS == 32
319
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
320
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
321
+ mnpack<4, 6, 4>(m, n, SIZE_N, 12);
322
+ return true;
323
+ }
324
+ if (m % 8 == 0 ) {
325
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
326
+ mnpack<4, 6, 2>(m, n, SIZE_N, 12);
327
+ return true;
328
+ }
329
+ if (m % 4 == 0) {
330
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
331
+ mnpack<4, 6, 1>(m, n, SIZE_N, 12);
332
+ return true;
333
+ }
334
+ #else // VECTOR_REGISTERS == 16
335
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
336
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
337
+ mnpack<4, 3, 4>(m, n, SIZE_N, 24);
338
+ return true;
339
+ }
340
+ if (m % 8 == 0 ) {
341
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
342
+ mnpack<4, 3, 2>(m, n, SIZE_N, 24);
343
+ return true;
344
+ }
345
+ if (m % 4 == 0) {
346
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
347
+ mnpack<4, 3, 1>(m, n, SIZE_N, 24);
348
+ return true;
349
+ }
350
+ #endif
351
+ return false;
267
352
  }
268
353
 
269
354
  private:
270
- NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
271
- int64_t mc, nc, mp, np;
272
- switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
273
- #if VECTOR_REGISTERS == 32
274
- case 0x55:
275
- mc = 5;
276
- nc = 5;
277
- gemm<5, 5>(m0, m, n0, n);
278
- break;
279
- case 0x45:
280
- mc = 4;
281
- nc = 5;
282
- gemm<4, 5>(m0, m, n0, n);
283
- break;
284
- case 0x54:
285
- mc = 5;
286
- nc = 4;
287
- gemm<5, 4>(m0, m, n0, n);
288
- break;
289
- case 0x44:
290
- mc = 4;
291
- nc = 4;
292
- gemm<4, 4>(m0, m, n0, n);
293
- break;
294
- case 0x53:
295
- mc = 5;
296
- nc = 3;
297
- gemm<5, 3>(m0, m, n0, n);
298
- break;
299
- case 0x35:
300
- mc = 3;
301
- nc = 5;
302
- gemm<3, 5>(m0, m, n0, n);
303
- break;
304
- case 0x43:
305
- mc = 4;
306
- nc = 3;
307
- gemm<4, 3>(m0, m, n0, n);
308
- break;
309
- #else
310
- case 0x55:
311
- case 0x54:
312
- case 0x53:
313
- case 0x45:
314
- case 0x44:
315
- case 0x43:
316
- mc = 4;
317
- nc = 3;
318
- gemm<4, 3>(m0, m, n0, n);
319
- break;
320
- case 0x35:
321
- #endif
322
- case 0x34:
323
- mc = 3;
324
- nc = 4;
325
- gemm<3, 4>(m0, m, n0, n);
326
- break;
327
- case 0x52:
328
- mc = 5;
329
- nc = 2;
330
- gemm<5, 2>(m0, m, n0, n);
331
- break;
332
- case 0x33:
333
- mc = 3;
334
- nc = 3;
335
- gemm<3, 3>(m0, m, n0, n);
336
- break;
337
- case 0x25:
338
- mc = 2;
339
- nc = 5;
340
- gemm<2, 5>(m0, m, n0, n);
341
- break;
342
- case 0x42:
343
- mc = 4;
344
- nc = 2;
345
- gemm<4, 2>(m0, m, n0, n);
346
- break;
347
- case 0x24:
348
- mc = 2;
349
- nc = 4;
350
- gemm<2, 4>(m0, m, n0, n);
351
- break;
352
- case 0x32:
353
- mc = 3;
354
- nc = 2;
355
- gemm<3, 2>(m0, m, n0, n);
356
- break;
357
- case 0x23:
358
- mc = 2;
359
- nc = 3;
360
- gemm<2, 3>(m0, m, n0, n);
361
- break;
362
- case 0x51:
363
- mc = 5;
364
- nc = 1;
365
- gemm<5, 1>(m0, m, n0, n);
366
- break;
367
- case 0x41:
368
- mc = 4;
369
- nc = 1;
370
- gemm<4, 1>(m0, m, n0, n);
371
- break;
372
- case 0x22:
373
- mc = 2;
374
- nc = 2;
375
- gemm<2, 2>(m0, m, n0, n);
376
- break;
377
- case 0x15:
378
- mc = 1;
379
- nc = 5;
380
- gemm<1, 5>(m0, m, n0, n);
381
- break;
382
- case 0x14:
383
- mc = 1;
384
- nc = 4;
385
- gemm<1, 4>(m0, m, n0, n);
386
- break;
387
- case 0x31:
388
- mc = 3;
389
- nc = 1;
390
- gemm<3, 1>(m0, m, n0, n);
391
- break;
392
- case 0x13:
393
- mc = 1;
394
- nc = 3;
395
- gemm<1, 3>(m0, m, n0, n);
396
- break;
397
- case 0x21:
398
- mc = 2;
399
- nc = 1;
400
- gemm<2, 1>(m0, m, n0, n);
401
- break;
402
- case 0x12:
403
- mc = 1;
404
- nc = 2;
405
- gemm<1, 2>(m0, m, n0, n);
406
- break;
407
- case 0x11:
408
- mc = 1;
409
- nc = 1;
410
- gemm<1, 1>(m0, m, n0, n);
411
- break;
412
- default:
413
- return;
355
+ template <int RM, int RN, int BM>
356
+ inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
357
+ if (SIZE_N == RN) {
358
+ return gemm<RM, RN, BM>(m, n, BN);
359
+ }
360
+ if constexpr (RN > 1) {
361
+ return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
362
+ } else {
363
+ LM_GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
364
+ LM_GGML_ASSERT(false); // we have miss something.
414
365
  }
415
- mp = m0 + (m - m0) / mc * mc;
416
- np = n0 + (n - n0) / nc * nc;
417
- mnpack(mp, m, n0, np);
418
- mnpack(m0, m, np, n);
419
366
  }
420
367
 
421
368
  template <int RM, int RN>
422
- NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
423
- int64_t ytiles = (m - m0) / RM;
424
- int64_t xtiles = (n - n0) / RN;
425
- int64_t tiles = xtiles * ytiles;
426
- int64_t duty = (tiles + nth - 1) / nth;
427
- int64_t start = duty * ith;
428
- int64_t end = start + duty;
429
- if (end > tiles)
430
- end = tiles;
431
- for (int64_t job = start; job < end; ++job) {
432
- int64_t ii = m0 + job / xtiles * RM;
433
- int64_t jj = n0 + job % xtiles * RN;
434
- D Cv[RN][RM] = {};
435
- for (int64_t l = 0; l < k; l += KN)
436
- for (int64_t j = 0; j < RN; ++j)
437
- for (int64_t i = 0; i < RM; ++i)
438
- Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
439
- load<V>(B + ldb * (jj + j) + l),
440
- Cv[j][i]);
441
- for (int64_t j = 0; j < RN; ++j)
442
- for (int64_t i = 0; i < RM; ++i)
443
- C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
369
+ inline void gemm_bloc(int64_t ii, int64_t jj) {
370
+ D Cv[RN][RM] = {};
371
+ for (int64_t l = 0; l < k; l += KN) {
372
+ // help compiler for op order.
373
+ if constexpr (RM <= RN) {
374
+ V Av[RM];
375
+ for (int64_t i = 0; i < RM; ++i) {
376
+ Av[i] = load<V>(A + lda * (ii + i) + l);
377
+ }
378
+ for (int64_t j = 0; j < RN; ++j) {
379
+ V Bv = load<V>(B + ldb * (jj + j) + l);
380
+ for (int64_t i = 0; i < RM; ++i) {
381
+ Cv[j][i] = madd(Av[i], Bv, Cv[j][i]);
382
+ }
383
+ }
384
+ } else {
385
+ V Bv[RN];
386
+ for (int64_t j = 0; j < RN; ++j) {
387
+ Bv[j] = load<V>(B + ldb * (jj + j) + l);
388
+ }
389
+ for (int64_t i = 0; i < RM; ++i) {
390
+ V Av = load<V>(A + lda * (ii + i) + l);
391
+ for (int64_t j = 0; j < RN; ++j) {
392
+ Cv[j][i] = madd(Av, Bv[j], Cv[j][i]);
393
+ }
394
+ }
395
+ }
396
+ }
397
+ for (int64_t j = 0; j < RN; ++j)
398
+ for (int64_t i = 0; i < RM; ++i)
399
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
400
+ }
401
+
402
+ template <int RM, int RN, int BM>
403
+ NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
404
+ static std::atomic<int64_t> current_chunk;
405
+
406
+ LM_GGML_ASSERT(m % (RM * BM) == 0);
407
+ const int64_t ytiles = m / (RM * BM);
408
+ const int64_t xtiles = (n + RN -1) / RN;
409
+ const int64_t jj_RN = (xtiles - (xtiles * RN - n));
410
+
411
+ // "round" bloc_size to "nearest" BN
412
+ const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
413
+ const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
414
+ const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
415
+ const int64_t nb_job = ytiles * NB_BN;
416
+
417
+ if (params->ith == 0) {
418
+ LM_GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
419
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
420
+ std::atomic_store_explicit(&current_chunk, (int64_t)params->nth, std::memory_order_relaxed);
421
+ }
422
+
423
+ lm_ggml_barrier(params->threadpool);
424
+
425
+ int64_t job = params->ith;
426
+ while (job < nb_job) {
427
+ const int64_t ii = (job % ytiles) * RM * BM;
428
+ const int64_t jb = job / ytiles;
429
+ const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN);
430
+ const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
431
+
432
+ const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
433
+ const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
434
+ const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
435
+
436
+ for (int64_t bi = 0; bi < BM * RM; bi += RM) {
437
+ int64_t jj = jj0;
438
+ for (; jj < jj1; jj += RN) {
439
+ gemm_bloc<RM, RN>(ii + bi, jj);
440
+ }
441
+ if constexpr (RN > 1) {
442
+ for (; jj < jj2; jj += RN - 1) {
443
+ gemm_bloc<RM, RN-1>(ii + bi, jj);
444
+ }
445
+ }
446
+ LM_GGML_ASSERT(jj == jj2);
447
+ }
448
+
449
+ // next step.
450
+ job = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed);
444
451
  }
452
+
453
+ lm_ggml_barrier(params->threadpool);
454
+ return;
445
455
  }
446
456
 
457
+ const lm_ggml_compute_params * params;
447
458
  const TA *const A;
448
459
  const TB *const B;
449
460
  TC *const C;
@@ -451,8 +462,6 @@ class tinyBLAS {
451
462
  const int64_t lda;
452
463
  const int64_t ldb;
453
464
  const int64_t ldc;
454
- const int ith;
455
- const int nth;
456
465
  };
457
466
 
458
467
  //////////////////////////////////////////////////////////////////////////////////////////
@@ -992,8 +1001,10 @@ class tinyBLAS_Q0_AVX {
992
1001
 
993
1002
  inline __m256 updot(__m256i u, __m256i s) {
994
1003
  __m256i res;
995
- #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
1004
+ #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
996
1005
  res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
1006
+ #elif defined(__AVXVNNI__)
1007
+ res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
997
1008
  #else
998
1009
  res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
999
1010
  #endif
@@ -1042,9 +1053,9 @@ class tinyBLAS_Q0_AVX {
1042
1053
  } \
1043
1054
 
1044
1055
  template <typename TA, typename TB, typename TC>
1045
- class tinyBLAS_PPC {
1056
+ class tinyBLAS_Q0_PPC {
1046
1057
  public:
1047
- tinyBLAS_PPC(int64_t k,
1058
+ tinyBLAS_Q0_PPC(int64_t k,
1048
1059
  const TA *A, int64_t lda,
1049
1060
  const TB *B, int64_t ldb,
1050
1061
  TC *C, int64_t ldc,
@@ -1053,74 +1064,773 @@ class tinyBLAS_PPC {
1053
1064
  }
1054
1065
 
1055
1066
  void matmul(int64_t m, int64_t n) {
1056
- mnpack(0, m, 0, n);
1067
+ mnpack(0, m, 0, n);
1057
1068
  }
1058
1069
 
1059
1070
  private:
1060
1071
 
1061
- void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
1072
+ template<int RM, int RN>
1073
+ inline void save_res(int ii, int jj, int idx, vector float* fin_res) {
1074
+ for (int I = 0; I < RM; I++) {
1075
+ for (int J = 0; J < RN; J++) {
1076
+ *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
1077
+ }
1078
+ }
1079
+ }
1062
1080
 
1063
- void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) {
1064
- int64_t i, j;
1065
- float *aoffset = NULL, *boffset = NULL;
1066
- float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1067
- float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1081
+ template<int size>
1082
+ inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array<int, size>& comparray, vector float* vs, vector float* fin_res) {
1083
+ vector signed int vec_C[4];
1084
+ vector float CA[4] = {0};
1085
+ vector float res[4] = {0};
1086
+ __builtin_mma_disassemble_acc(vec_C, ACC);
1087
+ for (int i = 0; i < 4; i++) {
1088
+ CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
1089
+ res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
1090
+ fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
1091
+ }
1092
+ }
1068
1093
 
1069
- aoffset = const_cast<float*>(a);
1070
- boffset = vec;
1094
+ template<typename VA, typename VB>
1095
+ void packNormal(const TA* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
1096
+ int64_t i, j;
1097
+ TA *aoffset = NULL;
1098
+ VA *vecOffset = NULL;
1099
+ TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1100
+ TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1101
+ __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
1102
+ VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
1103
+ VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
1104
+ VB t1, t2, t3, t4, t5, t6, t7, t8;
1105
+ vector unsigned char xor_vector;
1106
+ uint8_t flip_vec = 0x80;
1107
+ xor_vector = vec_splats(flip_vec);
1108
+ vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
1109
+ vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
1110
+ vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
1111
+ vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
1112
+
1113
+ aoffset = const_cast<TA*>(a);
1114
+ vecOffset = vec;
1071
1115
  j = (rows >> 3);
1072
1116
  if (j > 0) {
1073
1117
  do {
1074
- aoffset1 = aoffset;
1075
- aoffset2 = aoffset1 + lda;
1076
- aoffset3 = aoffset2 + lda;
1077
- aoffset4 = aoffset3 + lda;
1078
- aoffset5 = aoffset4 + lda;
1079
- aoffset6 = aoffset5 + lda;
1080
- aoffset7 = aoffset6 + lda;
1081
- aoffset8 = aoffset7 + lda;
1082
- aoffset += 8 * lda;
1083
- i = (cols >> 3);
1084
- if (i > 0) {
1085
- __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
1086
- vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
1087
- vector float t1, t2, t3, t4, t5, t6, t7, t8;
1088
- do {
1089
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
1090
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
1091
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
1092
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
1093
- C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
1094
- C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
1095
- C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
1096
- C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
1097
- __builtin_vsx_disassemble_pair(c1, &C1);
1098
- __builtin_vsx_disassemble_pair(c2, &C2);
1099
- __builtin_vsx_disassemble_pair(c3, &C3);
1100
- __builtin_vsx_disassemble_pair(c4, &C4);
1101
- __builtin_vsx_disassemble_pair(c5, &C5);
1102
- __builtin_vsx_disassemble_pair(c6, &C6);
1103
- __builtin_vsx_disassemble_pair(c7, &C7);
1104
- __builtin_vsx_disassemble_pair(c8, &C8);
1118
+ aoffset1 = aoffset;
1119
+ aoffset2 = aoffset1 + lda;
1120
+ aoffset3 = aoffset2 + lda;
1121
+ aoffset4 = aoffset3 + lda;
1122
+ aoffset5 = aoffset4 + lda;
1123
+ aoffset6 = aoffset5 + lda;
1124
+ aoffset7 = aoffset6 + lda;
1125
+ aoffset8 = aoffset7 + lda;
1126
+ aoffset += 8 * lda;
1105
1127
 
1106
- t1 = vec_mergeh(c1[0], c2[0]);
1107
- t2 = vec_mergeh(c3[0], c4[0]);
1108
- t3 = vec_mergeh(c5[0], c6[0]);
1109
- t4 = vec_mergeh(c7[0], c8[0]);
1110
- t5 = vec_xxpermdi(t1, t2, 0);
1111
- t6 = vec_xxpermdi(t3, t4, 0);
1112
- t7 = vec_xxpermdi(t1, t2, 3);
1113
- t8 = vec_xxpermdi(t3, t4, 3);
1114
- vec_xst(t5, 0, boffset);
1115
- vec_xst(t6, 0, boffset+4);
1116
- vec_xst(t7, 0, boffset+8);
1117
- vec_xst(t8, 0, boffset+12);
1128
+ i = (cols >> 3);
1129
+ if (i > 0) {
1130
+ do {
1131
+ C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
1132
+ C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
1133
+ C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
1134
+ C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
1135
+ C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs);
1136
+ C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs);
1137
+ C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs);
1138
+ C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs);
1118
1139
 
1119
- t1 = vec_mergel(c1[0], c2[0]);
1120
- t2 = vec_mergel(c3[0], c4[0]);
1121
- t3 = vec_mergel(c5[0], c6[0]);
1122
- t4 = vec_mergel(c7[0], c8[0]);
1123
- t5 = vec_xxpermdi(t1, t2, 0);
1140
+ __builtin_vsx_disassemble_pair(c1, &C1);
1141
+ __builtin_vsx_disassemble_pair(c2, &C2);
1142
+ __builtin_vsx_disassemble_pair(c3, &C3);
1143
+ __builtin_vsx_disassemble_pair(c4, &C4);
1144
+ __builtin_vsx_disassemble_pair(c5, &C5);
1145
+ __builtin_vsx_disassemble_pair(c6, &C6);
1146
+ __builtin_vsx_disassemble_pair(c7, &C7);
1147
+ __builtin_vsx_disassemble_pair(c8, &C8);
1148
+
1149
+ t1 = vec_perm(c1[0], c2[0], swiz1);
1150
+ t2 = vec_perm(c1[0], c2[0], swiz2);
1151
+ t3 = vec_perm(c3[0], c4[0], swiz1);
1152
+ t4 = vec_perm(c3[0], c4[0], swiz2);
1153
+ t5 = vec_perm(t1, t3, swiz3);
1154
+ t6 = vec_perm(t1, t3, swiz4);
1155
+ t7 = vec_perm(t2, t4, swiz3);
1156
+ t8 = vec_perm(t2, t4, swiz4);
1157
+ if (flip == true) {
1158
+ t5 = vec_xor(t5, xor_vector);
1159
+ t6 = vec_xor(t6, xor_vector);
1160
+ t7 = vec_xor(t7, xor_vector);
1161
+ t8 = vec_xor(t8, xor_vector);
1162
+ }
1163
+ vec_xst(t5, 0, vecOffset);
1164
+ vec_xst(t6, 0, vecOffset+16);
1165
+ vec_xst(t7, 0, vecOffset+32);
1166
+ vec_xst(t8, 0, vecOffset+48);
1167
+
1168
+ t1 = vec_perm(c1[1], c2[1], swiz1);
1169
+ t2 = vec_perm(c1[1], c2[1], swiz2);
1170
+ t3 = vec_perm(c3[1], c4[1], swiz1);
1171
+ t4 = vec_perm(c3[1], c4[1], swiz2);
1172
+ t5 = vec_perm(t1, t3, swiz3);
1173
+ t6 = vec_perm(t1, t3, swiz4);
1174
+ t7 = vec_perm(t2, t4, swiz3);
1175
+ t8 = vec_perm(t2, t4, swiz4);
1176
+ if (flip == true) {
1177
+ t5 = vec_xor(t5, xor_vector);
1178
+ t6 = vec_xor(t6, xor_vector);
1179
+ t7 = vec_xor(t7, xor_vector);
1180
+ t8 = vec_xor(t8, xor_vector);
1181
+ }
1182
+ vec_xst(t5, 0, vecOffset+64);
1183
+ vec_xst(t6, 0, vecOffset+80);
1184
+ vec_xst(t7, 0, vecOffset+96);
1185
+ vec_xst(t8, 0, vecOffset+112);
1186
+
1187
+ t1 = vec_perm(c5[0], c6[0], swiz1);
1188
+ t2 = vec_perm(c5[0], c6[0], swiz2);
1189
+ t3 = vec_perm(c7[0], c8[0], swiz1);
1190
+ t4 = vec_perm(c7[0], c8[0], swiz2);
1191
+ t5 = vec_perm(t1, t3, swiz3);
1192
+ t6 = vec_perm(t1, t3, swiz4);
1193
+ t7 = vec_perm(t2, t4, swiz3);
1194
+ t8 = vec_perm(t2, t4, swiz4);
1195
+ if (flip == true) {
1196
+ t5 = vec_xor(t5, xor_vector);
1197
+ t6 = vec_xor(t6, xor_vector);
1198
+ t7 = vec_xor(t7, xor_vector);
1199
+ t8 = vec_xor(t8, xor_vector);
1200
+ }
1201
+ vec_xst(t5, 0, vecOffset+128);
1202
+ vec_xst(t6, 0, vecOffset+144);
1203
+ vec_xst(t7, 0, vecOffset+160);
1204
+ vec_xst(t8, 0, vecOffset+176);
1205
+
1206
+ t1 = vec_perm(c5[1], c6[1], swiz1);
1207
+ t2 = vec_perm(c5[1], c6[1], swiz2);
1208
+ t3 = vec_perm(c7[1], c8[1], swiz1);
1209
+ t4 = vec_perm(c7[1], c8[1], swiz2);
1210
+ t5 = vec_perm(t1, t3, swiz3);
1211
+ t6 = vec_perm(t1, t3, swiz4);
1212
+ t7 = vec_perm(t2, t4, swiz3);
1213
+ t8 = vec_perm(t2, t4, swiz4);
1214
+ if (flip == true) {
1215
+ t5 = vec_xor(t5, xor_vector);
1216
+ t6 = vec_xor(t6, xor_vector);
1217
+ t7 = vec_xor(t7, xor_vector);
1218
+ t8 = vec_xor(t8, xor_vector);
1219
+ }
1220
+ vec_xst(t5, 0, vecOffset+192);
1221
+ vec_xst(t6, 0, vecOffset+208);
1222
+ vec_xst(t7, 0, vecOffset+224);
1223
+ vec_xst(t8, 0, vecOffset+240);
1224
+
1225
+ aoffset1 += lda;
1226
+ aoffset2 += lda;
1227
+ aoffset3 += lda;
1228
+ aoffset4 += lda;
1229
+ aoffset5 += lda;
1230
+ aoffset6 += lda;
1231
+ aoffset7 += lda;
1232
+ aoffset8 += lda;
1233
+ vecOffset += 256;
1234
+ i--;
1235
+ } while(i > 0);
1236
+ }
1237
+ j--;
1238
+ } while(j > 0);
1239
+ }
1240
+
1241
+ if (rows & 4) {
1242
+ aoffset1 = aoffset;
1243
+ aoffset2 = aoffset1 + lda;
1244
+ aoffset3 = aoffset2 + lda;
1245
+ aoffset4 = aoffset3 + lda;
1246
+ aoffset += 4 * lda;
1247
+
1248
+ i = (cols >> 3);
1249
+ if (i > 0) {
1250
+ do {
1251
+ C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
1252
+ C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
1253
+ C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
1254
+ C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
1255
+
1256
+ __builtin_vsx_disassemble_pair(c1, &C1);
1257
+ __builtin_vsx_disassemble_pair(c2, &C2);
1258
+ __builtin_vsx_disassemble_pair(c3, &C3);
1259
+ __builtin_vsx_disassemble_pair(c4, &C4);
1260
+
1261
+ t1 = vec_perm(c1[0], c2[0], swiz1);
1262
+ t2 = vec_perm(c1[0], c2[0], swiz2);
1263
+ t3 = vec_perm(c3[0], c4[0], swiz1);
1264
+ t4 = vec_perm(c3[0], c4[0], swiz2);
1265
+ t5 = vec_perm(t1, t3, swiz3);
1266
+ t6 = vec_perm(t1, t3, swiz4);
1267
+ t7 = vec_perm(t2, t4, swiz3);
1268
+ t8 = vec_perm(t2, t4, swiz4);
1269
+ if (flip == true) {
1270
+ t5 = vec_xor(t5, xor_vector);
1271
+ t6 = vec_xor(t6, xor_vector);
1272
+ t7 = vec_xor(t7, xor_vector);
1273
+ t8 = vec_xor(t8, xor_vector);
1274
+ }
1275
+ vec_xst(t5, 0, vecOffset);
1276
+ vec_xst(t6, 0, vecOffset+16);
1277
+ vec_xst(t7, 0, vecOffset+32);
1278
+ vec_xst(t8, 0, vecOffset+48);
1279
+
1280
+ t1 = vec_perm(c1[1], c2[1], swiz1);
1281
+ t2 = vec_perm(c1[1], c2[1], swiz2);
1282
+ t3 = vec_perm(c3[1], c4[1], swiz1);
1283
+ t4 = vec_perm(c3[1], c4[1], swiz2);
1284
+ t5 = vec_perm(t1, t3, swiz3);
1285
+ t6 = vec_perm(t1, t3, swiz4);
1286
+ t7 = vec_perm(t2, t4, swiz3);
1287
+ t8 = vec_perm(t2, t4, swiz4);
1288
+ if (flip == true) {
1289
+ t5 = vec_xor(t5, xor_vector);
1290
+ t6 = vec_xor(t6, xor_vector);
1291
+ t7 = vec_xor(t7, xor_vector);
1292
+ t8 = vec_xor(t8, xor_vector);
1293
+ }
1294
+ vec_xst(t5, 0, vecOffset+64);
1295
+ vec_xst(t6, 0, vecOffset+80);
1296
+ vec_xst(t7, 0, vecOffset+96);
1297
+ vec_xst(t8, 0, vecOffset+112);
1298
+
1299
+ aoffset1 += lda;
1300
+ aoffset2 += lda;
1301
+ aoffset3 += lda;
1302
+ aoffset4 += lda;
1303
+ vecOffset += 128;
1304
+ i--;
1305
+ } while(i > 0);
1306
+ }
1307
+ }
1308
+ if (rows & 3) {
1309
+ aoffset1 = aoffset;
1310
+ aoffset2 = aoffset1 + lda;
1311
+ aoffset3 = aoffset2 + lda;
1312
+ i = (cols >> 3);
1313
+ if (i > 0) {
1314
+ do {
1315
+ switch(rows) {
1316
+ case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
1317
+ __builtin_vsx_disassemble_pair(c3, &C3);
1318
+ case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
1319
+ __builtin_vsx_disassemble_pair(c2, &C2);
1320
+ case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
1321
+ __builtin_vsx_disassemble_pair(c1, &C1);
1322
+ break;
1323
+ }
1324
+ t1 = vec_perm(c1[0], c2[0], swiz1);
1325
+ t2 = vec_perm(c1[0], c2[0], swiz2);
1326
+ t3 = vec_perm(c3[0], c4[0], swiz1);
1327
+ t4 = vec_perm(c3[0], c4[0], swiz2);
1328
+ t5 = vec_perm(t1, t3, swiz3);
1329
+ t6 = vec_perm(t1, t3, swiz4);
1330
+ t7 = vec_perm(t2, t4, swiz3);
1331
+ t8 = vec_perm(t2, t4, swiz4);
1332
+ if (flip == true) {
1333
+ t5 = vec_xor(t5, xor_vector);
1334
+ t6 = vec_xor(t6, xor_vector);
1335
+ t7 = vec_xor(t7, xor_vector);
1336
+ t8 = vec_xor(t8, xor_vector);
1337
+ }
1338
+ vec_xst(t5, 0, vecOffset);
1339
+ vec_xst(t6, 0, vecOffset+16);
1340
+ vec_xst(t7, 0, vecOffset+32);
1341
+ vec_xst(t8, 0, vecOffset+48);
1342
+
1343
+ t1 = vec_perm(c1[1], c2[1], swiz1);
1344
+ t2 = vec_perm(c1[1], c2[1], swiz2);
1345
+ t3 = vec_perm(c3[1], c4[1], swiz1);
1346
+ t4 = vec_perm(c3[1], c4[1], swiz2);
1347
+ t5 = vec_perm(t1, t3, swiz3);
1348
+ t6 = vec_perm(t1, t3, swiz4);
1349
+ t7 = vec_perm(t2, t4, swiz3);
1350
+ t8 = vec_perm(t2, t4, swiz4);
1351
+ if (flip == true) {
1352
+ t5 = vec_xor(t5, xor_vector);
1353
+ t6 = vec_xor(t6, xor_vector);
1354
+ t7 = vec_xor(t7, xor_vector);
1355
+ t8 = vec_xor(t8, xor_vector);
1356
+ }
1357
+ vec_xst(t5, 0, vecOffset+64);
1358
+ vec_xst(t6, 0, vecOffset+80);
1359
+ vec_xst(t7, 0, vecOffset+96);
1360
+ vec_xst(t8, 0, vecOffset+112);
1361
+
1362
+ aoffset1 += lda;
1363
+ aoffset2 += lda;
1364
+ aoffset3 += lda;
1365
+ vecOffset += 128;
1366
+ i--;
1367
+ } while(i > 0);
1368
+ }
1369
+ }
1370
+ }
1371
+
1372
+ void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1373
+ int64_t mc, nc, mp, np;
1374
+ int m_rem = MIN(m - m0, 8);
1375
+ int n_rem = MIN(n - n0, 8);
1376
+ // TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance
1377
+ // issues. After resolving them, below code will be enabled.
1378
+ /*if (m_rem >= 16 && n_rem >= 8) {
1379
+ mc = 16;
1380
+ nc = 8;
1381
+ gemm<16,8>(m0, m, n0, n);
1382
+ } else if(m_rem >= 8 && n_rem >= 16) {
1383
+ mc = 8;
1384
+ nc = 16;
1385
+ gemm<8,16>(m0, m, n0, n);
1386
+ }*/
1387
+ if (m_rem >= 8 && n_rem >= 8) {
1388
+ mc = 8;
1389
+ nc = 8;
1390
+ gemm<8,8>(m0, m, n0, n);
1391
+ } else if (m_rem >= 4 && n_rem >= 8) {
1392
+ mc = 4;
1393
+ nc = 8;
1394
+ gemm<4,8>(m0, m, n0, n);
1395
+ } else if (m_rem >= 8 && n_rem >= 4) {
1396
+ mc = 8;
1397
+ nc = 4;
1398
+ gemm<8,4>(m0, m, n0, n);
1399
+ } else if (m_rem >= 4 && n_rem >= 4) {
1400
+ mc = 4;
1401
+ nc = 4;
1402
+ gemm_small<4, 4>(m0, m, n0, n);
1403
+ } else if ((m_rem < 4) && (n_rem > 4)) {
1404
+ nc = 4;
1405
+ switch(m_rem) {
1406
+ case 1:
1407
+ mc = 1;
1408
+ gemm_small<1, 4>(m0, m, n0, n);
1409
+ break;
1410
+ case 2:
1411
+ mc = 2;
1412
+ gemm_small<2, 4>(m0, m, n0, n);
1413
+ break;
1414
+ case 3:
1415
+ mc = 3;
1416
+ gemm_small<3, 4>(m0, m, n0, n);
1417
+ break;
1418
+ default:
1419
+ return;
1420
+ }
1421
+ } else if ((m_rem > 4) && (n_rem < 4)) {
1422
+ mc = 4;
1423
+ switch(n_rem) {
1424
+ case 1:
1425
+ nc = 1;
1426
+ gemm_small<4, 1>(m0, m, n0, n);
1427
+ break;
1428
+ case 2:
1429
+ nc = 2;
1430
+ gemm_small<4, 2>(m0, m, n0, n);
1431
+ break;
1432
+ case 3:
1433
+ nc = 3;
1434
+ gemm_small<4, 3>(m0, m, n0, n);
1435
+ break;
1436
+ default:
1437
+ return;
1438
+ }
1439
+ } else {
1440
+ switch((m_rem << 4) | n_rem) {
1441
+ case 0x43:
1442
+ mc = 4;
1443
+ nc = 3;
1444
+ gemm_small<4, 3>(m0, m, n0, n);
1445
+ break;
1446
+ case 0x42:
1447
+ mc = 4;
1448
+ nc = 2;
1449
+ gemm_small<4, 2>(m0, m, n0, n);
1450
+ break;
1451
+ case 0x41:
1452
+ mc = 4;
1453
+ nc = 1;
1454
+ gemm_small<4, 1>(m0, m, n0, n);
1455
+ break;
1456
+ case 0x34:
1457
+ mc = 3;
1458
+ nc = 4;
1459
+ gemm_small<3, 4>(m0, m, n0, n);
1460
+ break;
1461
+ case 0x33:
1462
+ mc = 3;
1463
+ nc = 3;
1464
+ gemm_small<3, 3>(m0, m, n0, n);
1465
+ break;
1466
+ case 0x32:
1467
+ mc = 3;
1468
+ nc = 2;
1469
+ gemm_small<3, 2>(m0, m, n0, n);
1470
+ break;
1471
+ case 0x31:
1472
+ mc = 3;
1473
+ nc = 1;
1474
+ gemm_small<3, 1>(m0, m, n0, n);
1475
+ break;
1476
+ case 0x24:
1477
+ mc = 2;
1478
+ nc = 4;
1479
+ gemm_small<2, 4>(m0, m, n0, n);
1480
+ break;
1481
+ case 0x23:
1482
+ mc = 2;
1483
+ nc = 3;
1484
+ gemm_small<2, 3>(m0, m, n0, n);
1485
+ break;
1486
+ case 0x22:
1487
+ mc = 2;
1488
+ nc = 2;
1489
+ gemm_small<2, 2>(m0, m, n0, n);
1490
+ break;
1491
+ case 0x21:
1492
+ mc = 2;
1493
+ nc = 1;
1494
+ gemm_small<2, 1>(m0, m, n0, n);
1495
+ break;
1496
+ case 0x14:
1497
+ mc = 1;
1498
+ nc = 4;
1499
+ gemm_small<1, 4>(m0, m, n0, n);
1500
+ break;
1501
+ case 0x13:
1502
+ mc = 1;
1503
+ nc = 3;
1504
+ gemm_small<1, 3>(m0, m, n0, n);
1505
+ break;
1506
+ case 0x12:
1507
+ mc = 1;
1508
+ nc = 2;
1509
+ gemm_small<1, 2>(m0, m, n0, n);
1510
+ break;
1511
+ case 0x11:
1512
+ mc = 1;
1513
+ nc = 1;
1514
+ gemm_small<1, 1>(m0, m, n0, n);
1515
+ break;
1516
+ default:
1517
+ return;
1518
+ }
1519
+ }
1520
+ mp = m0 + (m - m0) / mc * mc;
1521
+ np = n0 + (n - n0) / nc * nc;
1522
+ mnpack(mp, m, n0, np);
1523
+ mnpack(m0, m, np, n);
1524
+ }
1525
+
1526
+ void KERNEL_4x8(int64_t ii, int64_t jj) {
1527
+ vec_t vec_A[8], vec_B[16] = {0};
1528
+ acc_t acc_0, acc_1;
1529
+ std::array<int, 4> comparray;
1530
+ vector float fin_res[8] = {0};
1531
+ vector float vs[8] = {0};
1532
+ for (int l = 0; l < k; l++) {
1533
+ __builtin_mma_xxsetaccz(&acc_0);
1534
+ __builtin_mma_xxsetaccz(&acc_1);
1535
+ packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
1536
+ packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
1537
+ for(int x = 0; x < 8; x++) {
1538
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
1539
+ __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
1540
+ }
1541
+ for (int I = 0; I<4; I++) {
1542
+ for (int J = 0; J<4; J++) {
1543
+ *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
1544
+ *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
1545
+ }
1546
+ }
1547
+ auto aoffset = A+(ii*lda)+l;
1548
+ for (int i = 0; i < 4; i++) {
1549
+ comparray[i] = 0;
1550
+ int ca = 0;
1551
+ const int8_t *at = aoffset->qs;
1552
+ for (int j = 0; j < 32; j++)
1553
+ ca += (int)*at++;
1554
+ comparray[i] = ca;
1555
+ aoffset += lda;
1556
+ }
1557
+ compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
1558
+ compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
1559
+ }
1560
+ save_res<4, 4>(ii, jj, 0, fin_res);
1561
+ save_res<4, 4>(ii, jj+4, 4, fin_res);
1562
+ }
1563
+
1564
+ void KERNEL_8x4(int64_t ii, int64_t jj) {
1565
+ vec_t vec_A[16], vec_B[8] = {0};
1566
+ acc_t acc_0, acc_1;
1567
+ std::array<int, 8> comparray;
1568
+ vector float fin_res[8] = {0};
1569
+ vector float vs[8] = {0};
1570
+ for (int l = 0; l < k; l++) {
1571
+ __builtin_mma_xxsetaccz(&acc_0);
1572
+ __builtin_mma_xxsetaccz(&acc_1);
1573
+ packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
1574
+ packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
1575
+ for(int x = 0; x < 8; x++) {
1576
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
1577
+ __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
1578
+ }
1579
+ for (int I = 0; I<8; I++) {
1580
+ for (int J = 0; J<4; J++) {
1581
+ *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
1582
+ }
1583
+ }
1584
+ auto aoffset = A+(ii*lda)+l;
1585
+ for (int i = 0; i < 8; i++) {
1586
+ comparray[i] = 0;
1587
+ int ca = 0;
1588
+ const int8_t *at = aoffset->qs;
1589
+ for (int j = 0; j < 32; j++)
1590
+ ca += (int)*at++;
1591
+ comparray[i] = ca;
1592
+ aoffset += lda;
1593
+ }
1594
+ compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
1595
+ compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
1596
+ }
1597
+ save_res<4, 4>(ii, jj, 0, fin_res);
1598
+ save_res<4, 4>(ii+4, jj, 4, fin_res);
1599
+ }
1600
+
1601
+ void KERNEL_8x8(int64_t ii, int64_t jj) {
1602
+ vec_t vec_A[16], vec_B[16] = {0};
1603
+ acc_t acc_0, acc_1, acc_2, acc_3;
1604
+ std::array<int, 8> comparray;
1605
+ vector float fin_res[16] = {0};
1606
+ vector float vs[16] = {0};
1607
+ for (int l = 0; l < k; l++) {
1608
+ __builtin_mma_xxsetaccz(&acc_0);
1609
+ __builtin_mma_xxsetaccz(&acc_1);
1610
+ __builtin_mma_xxsetaccz(&acc_2);
1611
+ __builtin_mma_xxsetaccz(&acc_3);
1612
+ packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
1613
+ packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
1614
+ for(int x = 0; x < 8; x++) {
1615
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
1616
+ __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
1617
+ __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
1618
+ __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
1619
+ }
1620
+ for (int I = 0; I<8; I++) {
1621
+ for (int J = 0; J<4; J++) {
1622
+ *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
1623
+ *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
1624
+ }
1625
+ }
1626
+ auto aoffset = A+(ii*lda)+l;
1627
+ for (int i = 0; i < 8; i++) {
1628
+ comparray[i] = 0;
1629
+ int ca = 0;
1630
+ const int8_t *at = aoffset->qs;
1631
+ for (int j = 0; j < 32; j++)
1632
+ ca += (int)*at++;
1633
+ comparray[i] = ca;
1634
+ aoffset += lda;
1635
+ }
1636
+ compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
1637
+ compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
1638
+ compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
1639
+ compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
1640
+ }
1641
+ save_res<4, 4>(ii, jj, 0, fin_res);
1642
+ save_res<4, 4>(ii+4, jj, 4, fin_res);
1643
+ save_res<4, 4>(ii, jj+4, 8, fin_res);
1644
+ save_res<4, 4>(ii+4, jj+4, 12, fin_res);
1645
+ }
1646
+
1647
+ template<int RM, int RN>
1648
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1649
+ int64_t ytiles = (m - m0) / RM;
1650
+ int64_t xtiles = (n - n0) / RN;
1651
+ int64_t tiles = xtiles * ytiles;
1652
+ int64_t duty = (tiles + nth - 1) / nth;
1653
+ int64_t start = duty * ith;
1654
+ int64_t end = start + duty;
1655
+ vec_t vec_A[8], vec_B[8] = {0};
1656
+ vector signed int vec_C[4];
1657
+ acc_t acc_0;
1658
+
1659
+ if (end > tiles)
1660
+ end = tiles;
1661
+ for (int64_t job = start; job < end; ++job) {
1662
+ int64_t ii = m0 + job / xtiles * RM;
1663
+ int64_t jj = n0 + job % xtiles * RN;
1664
+ std::array<int, RM> comparray;
1665
+ vector float res[4] = {0};
1666
+ vector float fin_res[4] = {0};
1667
+ vector float vs[4] = {0};
1668
+ vector float CA[4] = {0};
1669
+ __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
1670
+ __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
1671
+ for (int l = 0; l < k; l++) {
1672
+ __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
1673
+ __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
1674
+ __builtin_mma_xxsetaccz(&acc_0);
1675
+ packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
1676
+ packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
1677
+ for(int x = 0; x < 8; x+=4) {
1678
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
1679
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
1680
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
1681
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
1682
+ }
1683
+ for (int I = 0; I<RM; I++) {
1684
+ for (int J = 0; J<RN; J++) {
1685
+ *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
1686
+ }
1687
+ }
1688
+ __builtin_mma_disassemble_acc(vec_C, &acc_0);
1689
+ auto aoffset = A+(ii*lda)+l;
1690
+ for (int i = 0; i < RM; i++) {
1691
+ comparray[i] = 0;
1692
+ int ca = 0;
1693
+ const int8_t *at = aoffset->qs;
1694
+ for (int j = 0; j < 32; j++)
1695
+ ca += (int)*at++;
1696
+ comparray[i] = ca;
1697
+ aoffset += lda;
1698
+ }
1699
+
1700
+ for (int i = 0; i < RM; i++) {
1701
+ CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
1702
+ res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
1703
+ fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
1704
+ }
1705
+ }
1706
+ save_res<RM, RN>(ii, jj, 0, fin_res);
1707
+ }
1708
+ }
1709
+
1710
+ template<int RM, int RN>
1711
+ inline void kernel(int64_t ii, int64_t jj) {
1712
+ if constexpr(RM == 4 && RN == 8) {
1713
+ KERNEL_4x8(ii,jj);
1714
+ } else if constexpr(RM == 8 && RN == 4) {
1715
+ KERNEL_8x4(ii,jj);
1716
+ } else if constexpr(RM == 8 && RN == 8) {
1717
+ KERNEL_8x8(ii,jj);
1718
+ } else {
1719
+ static_assert(false, "RN/RM values not supported");
1720
+ }
1721
+ }
1722
+
1723
+ template <int RM, int RN>
1724
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1725
+ int64_t ytiles = (m - m0) / RM;
1726
+ int64_t xtiles = (n - n0) / RN;
1727
+ int64_t tiles = xtiles * ytiles;
1728
+ int64_t duty = (tiles + nth - 1) / nth;
1729
+ int64_t start = duty * ith;
1730
+ int64_t end = start + duty;
1731
+ if (end > tiles)
1732
+ end = tiles;
1733
+ for (int64_t job = start; job < end; ++job) {
1734
+ int64_t ii = m0 + job / xtiles * RM;
1735
+ int64_t jj = n0 + job % xtiles * RN;
1736
+ kernel<RM, RN>(ii, jj);
1737
+ }
1738
+ }
1739
+
1740
+ const TA *const A;
1741
+ const TB *const B;
1742
+ TC *C;
1743
+ TA *At;
1744
+ TB *Bt;
1745
+ const int64_t k;
1746
+ const int64_t lda;
1747
+ const int64_t ldb;
1748
+ const int64_t ldc;
1749
+ const int ith;
1750
+ const int nth;
1751
+ };
1752
+
1753
+ template <typename TA, typename TB, typename TC>
1754
+ class tinyBLAS_PPC {
1755
+ public:
1756
+ tinyBLAS_PPC(int64_t k,
1757
+ const TA *A, int64_t lda,
1758
+ const TB *B, int64_t ldb,
1759
+ TC *C, int64_t ldc,
1760
+ int ith, int nth)
1761
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
1762
+ }
1763
+
1764
+ void matmul(int64_t m, int64_t n) {
1765
+ mnpack(0, m, 0, n);
1766
+ }
1767
+
1768
+ private:
1769
+
1770
+ void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
1771
+
1772
+ template<typename VA>
1773
+ void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
1774
+ int64_t i, j;
1775
+ TA *aoffset = NULL, *boffset = NULL;
1776
+ TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1777
+ TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1778
+ __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
1779
+ VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
1780
+ VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
1781
+ VA t1, t2, t3, t4, t5, t6, t7, t8;
1782
+ aoffset = const_cast<TA*>(a);
1783
+ boffset = vec;
1784
+ j = (rows >> 3);
1785
+ if (j > 0) {
1786
+ do {
1787
+ aoffset1 = aoffset;
1788
+ aoffset2 = aoffset1 + lda;
1789
+ aoffset3 = aoffset2 + lda;
1790
+ aoffset4 = aoffset3 + lda;
1791
+ aoffset5 = aoffset4 + lda;
1792
+ aoffset6 = aoffset5 + lda;
1793
+ aoffset7 = aoffset6 + lda;
1794
+ aoffset8 = aoffset7 + lda;
1795
+ aoffset += 8 * lda;
1796
+ i = (cols >> 3);
1797
+ if (i > 0) {
1798
+ do {
1799
+ C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
1800
+ C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
1801
+ C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
1802
+ C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
1803
+ C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
1804
+ C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
1805
+ C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
1806
+ C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
1807
+ __builtin_vsx_disassemble_pair(c1, &C1);
1808
+ __builtin_vsx_disassemble_pair(c2, &C2);
1809
+ __builtin_vsx_disassemble_pair(c3, &C3);
1810
+ __builtin_vsx_disassemble_pair(c4, &C4);
1811
+ __builtin_vsx_disassemble_pair(c5, &C5);
1812
+ __builtin_vsx_disassemble_pair(c6, &C6);
1813
+ __builtin_vsx_disassemble_pair(c7, &C7);
1814
+ __builtin_vsx_disassemble_pair(c8, &C8);
1815
+
1816
+ t1 = vec_mergeh(c1[0], c2[0]);
1817
+ t2 = vec_mergeh(c3[0], c4[0]);
1818
+ t3 = vec_mergeh(c5[0], c6[0]);
1819
+ t4 = vec_mergeh(c7[0], c8[0]);
1820
+ t5 = vec_xxpermdi(t1, t2, 0);
1821
+ t6 = vec_xxpermdi(t3, t4, 0);
1822
+ t7 = vec_xxpermdi(t1, t2, 3);
1823
+ t8 = vec_xxpermdi(t3, t4, 3);
1824
+ vec_xst(t5, 0, boffset);
1825
+ vec_xst(t6, 0, boffset+4);
1826
+ vec_xst(t7, 0, boffset+8);
1827
+ vec_xst(t8, 0, boffset+12);
1828
+
1829
+ t1 = vec_mergel(c1[0], c2[0]);
1830
+ t2 = vec_mergel(c3[0], c4[0]);
1831
+ t3 = vec_mergel(c5[0], c6[0]);
1832
+ t4 = vec_mergel(c7[0], c8[0]);
1833
+ t5 = vec_xxpermdi(t1, t2, 0);
1124
1834
  t6 = vec_xxpermdi(t3, t4, 0);
1125
1835
  t7 = vec_xxpermdi(t1, t2, 3);
1126
1836
  t8 = vec_xxpermdi(t3, t4, 3);
@@ -1164,21 +1874,19 @@ class tinyBLAS_PPC {
1164
1874
  } while(i > 0);
1165
1875
  }
1166
1876
  if (cols & 4) {
1167
- vector float c1, c2, c3, c4, c5, c6, c7, c8;
1168
- vector float t1, t2, t3, t4, t5, t6, t7, t8;
1169
- c1 = vec_xl(0, aoffset1);
1170
- c2 = vec_xl(0, aoffset2);
1171
- c3 = vec_xl(0, aoffset3);
1172
- c4 = vec_xl(0, aoffset4);
1173
- c5 = vec_xl(0, aoffset5);
1174
- c6 = vec_xl(0, aoffset6);
1175
- c7 = vec_xl(0, aoffset7);
1176
- c8 = vec_xl(0, aoffset8);
1177
-
1178
- t1 = vec_mergeh(c1, c2);
1179
- t2 = vec_mergeh(c3, c4);
1180
- t3 = vec_mergeh(c5, c6);
1181
- t4 = vec_mergeh(c7, c8);
1877
+ c1[0] = vec_xl(0, aoffset1);
1878
+ c2[0] = vec_xl(0, aoffset2);
1879
+ c3[0] = vec_xl(0, aoffset3);
1880
+ c4[0] = vec_xl(0, aoffset4);
1881
+ c5[0] = vec_xl(0, aoffset5);
1882
+ c6[0] = vec_xl(0, aoffset6);
1883
+ c7[0] = vec_xl(0, aoffset7);
1884
+ c8[0] = vec_xl(0, aoffset8);
1885
+
1886
+ t1 = vec_mergeh(c1[0], c2[0]);
1887
+ t2 = vec_mergeh(c3[0], c4[0]);
1888
+ t3 = vec_mergeh(c5[0], c6[0]);
1889
+ t4 = vec_mergeh(c7[0], c8[0]);
1182
1890
  t5 = vec_xxpermdi(t1, t2, 0);
1183
1891
  t6 = vec_xxpermdi(t3, t4, 0);
1184
1892
  t7 = vec_xxpermdi(t1, t2, 3);
@@ -1188,10 +1896,10 @@ class tinyBLAS_PPC {
1188
1896
  vec_xst(t7, 0, boffset+8);
1189
1897
  vec_xst(t8, 0, boffset+12);
1190
1898
 
1191
- t1 = vec_mergel(c1, c2);
1192
- t2 = vec_mergel(c3, c4);
1193
- t3 = vec_mergel(c5, c6);
1194
- t4 = vec_mergel(c7, c8);
1899
+ t1 = vec_mergel(c1[0], c2[0]);
1900
+ t2 = vec_mergel(c3[0], c4[0]);
1901
+ t3 = vec_mergel(c5[0], c6[0]);
1902
+ t4 = vec_mergel(c7[0], c8[0]);
1195
1903
  t5 = vec_xxpermdi(t1, t2, 0);
1196
1904
  t6 = vec_xxpermdi(t3, t4, 0);
1197
1905
  t7 = vec_xxpermdi(t1, t2, 3);
@@ -1213,9 +1921,6 @@ class tinyBLAS_PPC {
1213
1921
  aoffset += 4 * lda;
1214
1922
  i = (cols >> 3);
1215
1923
  if (i > 0) {
1216
- __vector_pair C1, C2, C3, C4;
1217
- vector float c1[2], c2[2], c3[2], c4[2];
1218
- vector float t1, t2, t3, t4, t5, t6, t7, t8;
1219
1924
  do {
1220
1925
  C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
1221
1926
  C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
@@ -1262,22 +1967,20 @@ class tinyBLAS_PPC {
1262
1967
  }
1263
1968
 
1264
1969
  if (cols & 4) {
1265
- vector float c1, c2, c3, c4;
1266
- vector float t1, t2, t3, t4;
1267
- c1 = vec_xl(0, aoffset1);
1268
- c2 = vec_xl(0, aoffset2);
1269
- c3 = vec_xl(0, aoffset3);
1270
- c4 = vec_xl(0, aoffset4);
1271
-
1272
- t1 = vec_mergeh(c1, c2);
1273
- t2 = vec_mergeh(c3, c4);
1970
+ c1[0] = vec_xl(0, aoffset1);
1971
+ c2[0] = vec_xl(0, aoffset2);
1972
+ c3[0] = vec_xl(0, aoffset3);
1973
+ c4[0] = vec_xl(0, aoffset4);
1974
+
1975
+ t1 = vec_mergeh(c1[0], c2[0]);
1976
+ t2 = vec_mergeh(c3[0], c4[0]);
1274
1977
  t3 = vec_xxpermdi(t1, t2, 0);
1275
1978
  t4 = vec_xxpermdi(t1, t2, 3);
1276
1979
  vec_xst(t3, 0, boffset);
1277
1980
  vec_xst(t4, 0, boffset+4);
1278
1981
 
1279
- t1 = vec_mergel(c1, c2);
1280
- t2 = vec_mergel(c3, c4);
1982
+ t1 = vec_mergel(c1[0], c2[0]);
1983
+ t2 = vec_mergel(c3[0], c4[0]);
1281
1984
  t3 = vec_xxpermdi(t1, t2, 0);
1282
1985
  t4 = vec_xxpermdi(t1, t2, 3);
1283
1986
  vec_xst(t3, 0, boffset+8);
@@ -1289,21 +1992,19 @@ class tinyBLAS_PPC {
1289
1992
  aoffset2 = aoffset1 + lda;
1290
1993
  aoffset3 = aoffset2 + lda;
1291
1994
  if (cols & 4) {
1292
- vector float c1, c2, c3, c4 = {0};
1293
- vector float t1, t2, t3, t4;
1294
- c1 = vec_xl(0, aoffset1);
1295
- c2 = vec_xl(0, aoffset2);
1296
- c3 = vec_xl(0, aoffset3);
1297
-
1298
- t1 = vec_mergeh(c1, c2);
1299
- t2 = vec_mergeh(c3, c4);
1995
+ c1[0] = vec_xl(0, aoffset1);
1996
+ c2[0] = vec_xl(0, aoffset2);
1997
+ c3[0] = vec_xl(0, aoffset3);
1998
+
1999
+ t1 = vec_mergeh(c1[0], c2[0]);
2000
+ t2 = vec_mergeh(c3[0], c4[0]);
1300
2001
  t3 = vec_xxpermdi(t1, t2, 0);
1301
2002
  t4 = vec_xxpermdi(t1, t2, 3);
1302
2003
  vec_xst(t3, 0, boffset);
1303
2004
  vec_xst(t4, 0, boffset+4);
1304
2005
 
1305
- t1 = vec_mergel(c1, c2);
1306
- t2 = vec_mergel(c3, c4);
2006
+ t1 = vec_mergel(c1[0], c2[0]);
2007
+ t2 = vec_mergel(c3[0], c4[0]);
1307
2008
  t3 = vec_xxpermdi(t1, t2, 0);
1308
2009
  t4 = vec_xxpermdi(t1, t2, 3);
1309
2010
  vec_xst(t3, 0, boffset+8);
@@ -1311,14 +2012,13 @@ class tinyBLAS_PPC {
1311
2012
  }
1312
2013
  }
1313
2014
  }
1314
-
1315
2015
  void KERNEL_4x4(int64_t ii, int64_t jj) {
1316
2016
  vec_t vec_A[4], vec_B[4], vec_C[4];
1317
2017
  acc_t acc_0;
1318
2018
  __builtin_mma_xxsetaccz(&acc_0);
1319
2019
  for (int l = 0; l < k; l+=4) {
1320
- READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
1321
- READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
2020
+ packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
2021
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
1322
2022
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
1323
2023
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
1324
2024
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
@@ -1333,8 +2033,8 @@ class tinyBLAS_PPC {
1333
2033
  __builtin_mma_xxsetaccz(&acc_0);
1334
2034
  __builtin_mma_xxsetaccz(&acc_1);
1335
2035
  for (int64_t l = 0; l < k; l+=4) {
1336
- READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
1337
- READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
2036
+ packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
2037
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B);
1338
2038
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
1339
2039
  __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
1340
2040
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@@ -1354,8 +2054,8 @@ class tinyBLAS_PPC {
1354
2054
  __builtin_mma_xxsetaccz(&acc_0);
1355
2055
  __builtin_mma_xxsetaccz(&acc_1);
1356
2056
  for (int64_t l = 0; l < k; l+=4) {
1357
- READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
1358
- READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
2057
+ packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A);
2058
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
1359
2059
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
1360
2060
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
1361
2061
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@@ -1377,8 +2077,8 @@ class tinyBLAS_PPC {
1377
2077
  __builtin_mma_xxsetaccz(&acc_2);
1378
2078
  __builtin_mma_xxsetaccz(&acc_3);
1379
2079
  for (int l = 0; l < k; l+=8) {
1380
- READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
1381
- READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
2080
+ packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A);
2081
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B);
1382
2082
  for(int x = 0; x < 16; x+=2) {
1383
2083
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
1384
2084
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
@@ -1561,15 +2261,15 @@ class tinyBLAS_PPC {
1561
2261
  vec_t vec_A[4], vec_B[4];
1562
2262
  for (int l=0; l<k; l+=4) {
1563
2263
  if (RN >= 4 && RM == 1) {
1564
- float* a = const_cast<float*>(A+(ii)*lda+l);
1565
- READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
2264
+ TA* a = const_cast<TA*>(A+(ii)*lda+l);
2265
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
1566
2266
  vec_A[0] = (vec_t)vec_xl(0,a);
1567
- vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
1568
- vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
1569
- vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
2267
+ vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
2268
+ vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
2269
+ vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
1570
2270
  } else {
1571
- READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
1572
- READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
2271
+ packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
2272
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
1573
2273
  }
1574
2274
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
1575
2275
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@@ -1579,7 +2279,7 @@ class tinyBLAS_PPC {
1579
2279
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
1580
2280
  for (int I = 0; I < RM; I++) {
1581
2281
  for (int J = 0; J < RN; J++) {
1582
- *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
2282
+ *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
1583
2283
  }
1584
2284
  }
1585
2285
  }
@@ -1656,8 +2356,9 @@ class tinyBLAS_PPC {
1656
2356
  * @param Ctype is GGML data type of `C`
1657
2357
  * @return true if this function was able to service the matmul request
1658
2358
  */
1659
- bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
1660
- int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
2359
+ bool llamafile_sgemm(const struct lm_ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
2360
+ const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
2361
+ int64_t ldc, int Atype, int Btype, int Ctype) {
1661
2362
 
1662
2363
  assert(m >= 0);
1663
2364
  assert(n >= 0);
@@ -1665,8 +2366,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
1665
2366
  assert(lda >= k);
1666
2367
  assert(ldb >= k);
1667
2368
  assert(ldc >= m);
1668
- assert(nth > 0);
1669
- assert(ith < nth);
2369
+ assert(params->nth > 0);
2370
+ assert(params->ith < params->nth);
1670
2371
 
1671
2372
  // only enable sgemm for prompt processing
1672
2373
  if (n < 2)
@@ -1681,37 +2382,25 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
1681
2382
  if (Btype != LM_GGML_TYPE_F32)
1682
2383
  return false;
1683
2384
  #if defined(__AVX512F__)
1684
- if (k % 16)
1685
- return false;
1686
- tinyBLAS<16, __m512, __m512, float, float, float> tb{
2385
+ tinyBLAS<16, __m512, __m512, float, float, float> tb{ params,
1687
2386
  k, (const float *)A, lda,
1688
2387
  (const float *)B, ldb,
1689
- (float *)C, ldc,
1690
- ith, nth};
1691
- tb.matmul(m, n);
1692
- return true;
2388
+ (float *)C, ldc};
2389
+ return tb.matmul(m, n);
1693
2390
  #elif defined(__AVX__) || defined(__AVX2__)
1694
- if (k % 8)
1695
- return false;
1696
- tinyBLAS<8, __m256, __m256, float, float, float> tb{
2391
+ tinyBLAS<8, __m256, __m256, float, float, float> tb{ params,
1697
2392
  k, (const float *)A, lda,
1698
2393
  (const float *)B, ldb,
1699
- (float *)C, ldc,
1700
- ith, nth};
1701
- tb.matmul(m, n);
1702
- return true;
2394
+ (float *)C, ldc};
2395
+ return tb.matmul(m, n);
1703
2396
  #elif defined(__ARM_NEON)
1704
2397
  if (n < 4)
1705
2398
  return false;
1706
- if (k % 4)
1707
- return false;
1708
- tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
2399
+ tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
1709
2400
  k, (const float *)A, lda,
1710
2401
  (const float *)B, ldb,
1711
- (float *)C, ldc,
1712
- ith, nth};
1713
- tb.matmul(m, n);
1714
- return true;
2402
+ (float *)C, ldc};
2403
+ return tb.matmul(m, n);
1715
2404
  #elif defined(__MMA__)
1716
2405
  if (k % 8)
1717
2406
  return false;
@@ -1719,7 +2408,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
1719
2408
  k, (const float *)A, lda,
1720
2409
  (const float *)B, ldb,
1721
2410
  (float *)C, ldc,
1722
- ith, nth};
2411
+ params->ith, params->nth};
1723
2412
  tb.matmul(m, n);
1724
2413
  return true;
1725
2414
  #else
@@ -1727,60 +2416,71 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
1727
2416
  #endif
1728
2417
  }
1729
2418
 
2419
+ case LM_GGML_TYPE_BF16: {
2420
+ #if defined(__AVX512BF16__)
2421
+ if (Btype == LM_GGML_TYPE_BF16) {
2422
+ tinyBLAS<32, __m512, __m512bh, lm_ggml_bf16_t, lm_ggml_bf16_t, float> tb{ params, k,
2423
+ (const lm_ggml_bf16_t *)A, lda,
2424
+ (const lm_ggml_bf16_t *)B, ldb,
2425
+ (float *)C, ldc};
2426
+ return tb.matmul(m, n);
2427
+ }
2428
+ #elif defined(__AVX512F__)
2429
+ if (Btype == LM_GGML_TYPE_BF16) {
2430
+ tinyBLAS<16, __m512, __m512, lm_ggml_bf16_t, lm_ggml_bf16_t, float> tb{ params, k,
2431
+ (const lm_ggml_bf16_t *)A, lda,
2432
+ (const lm_ggml_bf16_t *)B, ldb,
2433
+ (float *)C, ldc};
2434
+ return tb.matmul(m, n);
2435
+ }
2436
+ #elif defined(__AVX2__)
2437
+ if (Btype == LM_GGML_TYPE_BF16) {
2438
+ tinyBLAS<8, __m256, __m256, lm_ggml_bf16_t, lm_ggml_bf16_t, float> tb{ params, k,
2439
+ (const lm_ggml_bf16_t *)A, lda,
2440
+ (const lm_ggml_bf16_t *)B, ldb,
2441
+ (float *)C, ldc};
2442
+ return tb.matmul(m, n);
2443
+ }
2444
+ #endif
2445
+ return false;
2446
+ }
1730
2447
  case LM_GGML_TYPE_F16: {
1731
2448
  #if defined(__AVX512F__)
1732
- if (k % 16)
1733
- return false;
1734
- if (Btype != LM_GGML_TYPE_F32)
1735
- return false;
1736
- tinyBLAS<16, __m512, __m512, lm_ggml_fp16_t, float, float> tb{
1737
- k, (const lm_ggml_fp16_t *)A, lda,
1738
- (const float *)B, ldb,
1739
- (float *)C, ldc,
1740
- ith, nth};
1741
- tb.matmul(m, n);
1742
- return true;
2449
+ if (Btype == LM_GGML_TYPE_F16) {
2450
+ tinyBLAS<16, __m512, __m512, lm_ggml_fp16_t, lm_ggml_fp16_t, float> tb{ params, k,
2451
+ (const lm_ggml_fp16_t *)A, lda,
2452
+ (const lm_ggml_fp16_t *)B, ldb,
2453
+ (float *)C, ldc};
2454
+ return tb.matmul(m, n);
2455
+ }
1743
2456
  #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
1744
- if (k % 8)
1745
- return false;
1746
- if (Btype != LM_GGML_TYPE_F32)
1747
- return false;
1748
- tinyBLAS<8, __m256, __m256, lm_ggml_fp16_t, float, float> tb{
1749
- k, (const lm_ggml_fp16_t *)A, lda,
1750
- (const float *)B, ldb,
1751
- (float *)C, ldc,
1752
- ith, nth};
1753
- tb.matmul(m, n);
1754
- return true;
2457
+ if (Btype == LM_GGML_TYPE_F16) {
2458
+ tinyBLAS<8, __m256, __m256, lm_ggml_fp16_t, lm_ggml_fp16_t, float> tb{ params, k,
2459
+ (const lm_ggml_fp16_t *)A, lda,
2460
+ (const lm_ggml_fp16_t *)B, ldb,
2461
+ (float *)C, ldc};
2462
+ return tb.matmul(m, n);
2463
+ }
1755
2464
  #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
1756
2465
  if (n < 8)
1757
2466
  return false;
1758
- if (k % 8)
1759
- return false;
1760
- if (Btype != LM_GGML_TYPE_F16)
1761
- return false;
1762
- tinyBLAS<8, float16x8_t, float16x8_t, lm_ggml_fp16_t, lm_ggml_fp16_t, float> tb{
1763
- k, (const lm_ggml_fp16_t *)A, lda,
1764
- (const lm_ggml_fp16_t *)B, ldb,
1765
- (float *)C, ldc,
1766
- ith, nth};
1767
- tb.matmul(m, n);
1768
- return true;
2467
+ if (Btype == LM_GGML_TYPE_F16) {
2468
+ tinyBLAS<8, float16x8_t, float16x8_t, lm_ggml_fp16_t, lm_ggml_fp16_t, float> tb{ params,
2469
+ k, (const lm_ggml_fp16_t *)A, lda,
2470
+ (const lm_ggml_fp16_t *)B, ldb,
2471
+ (float *)C, ldc};
2472
+ return tb.matmul(m, n);
2473
+ }
1769
2474
  #elif defined(__ARM_NEON) && !defined(_MSC_VER)
1770
- if (k % 4)
1771
- return false;
1772
- if (Btype != LM_GGML_TYPE_F32)
1773
- return false;
1774
- tinyBLAS<4, float32x4_t, float32x4_t, lm_ggml_fp16_t, float, float> tb{
1775
- k, (const lm_ggml_fp16_t *)A, lda,
1776
- (const float *)B, ldb,
1777
- (float *)C, ldc,
1778
- ith, nth};
1779
- tb.matmul(m, n);
1780
- return true;
1781
- #else
1782
- return false;
2475
+ if (Btype == LM_GGML_TYPE_F32) {
2476
+ tinyBLAS<4, float32x4_t, float32x4_t, lm_ggml_fp16_t, float, float> tb{ params,
2477
+ k, (const lm_ggml_fp16_t *)A, lda,
2478
+ (const float *)B, ldb,
2479
+ (float *)C, ldc};
2480
+ return tb.matmul(m, n);
2481
+ }
1783
2482
  #endif
2483
+ return false;
1784
2484
  }
1785
2485
 
1786
2486
  case LM_GGML_TYPE_Q8_0: {
@@ -1791,7 +2491,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
1791
2491
  k, (const block_q8_0 *)A, lda,
1792
2492
  (const block_q8_0 *)B, ldb,
1793
2493
  (float *)C, ldc,
1794
- ith, nth};
2494
+ params->ith, params->nth};
1795
2495
  tb.matmul(m, n);
1796
2496
  return true;
1797
2497
  #elif defined(__ARM_FEATURE_DOTPROD)
@@ -1799,9 +2499,23 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
1799
2499
  k, (const block_q8_0 *)A, lda,
1800
2500
  (const block_q8_0 *)B, ldb,
1801
2501
  (float *)C, ldc,
1802
- ith, nth};
2502
+ params->ith, params->nth};
2503
+ tb.matmul(m, n);
2504
+ return true;
2505
+
2506
+ #elif defined(__MMA__)
2507
+ if (n < 8 && n != 4)
2508
+ return false;
2509
+ if (m < 8 && m != 4)
2510
+ return false;
2511
+ tinyBLAS_Q0_PPC<block_q8_0, block_q8_0, float> tb{
2512
+ k, (const block_q8_0 *)A, lda,
2513
+ (const block_q8_0 *)B, ldb,
2514
+ (float *)C, ldc,
2515
+ params->ith, params->nth};
1803
2516
  tb.matmul(m, n);
1804
2517
  return true;
2518
+
1805
2519
  #else
1806
2520
  return false;
1807
2521
  #endif
@@ -1815,7 +2529,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
1815
2529
  k, (const block_q4_0 *)A, lda,
1816
2530
  (const block_q8_0 *)B, ldb,
1817
2531
  (float *)C, ldc,
1818
- ith, nth};
2532
+ params->ith, params->nth};
1819
2533
  tb.matmul(m, n);
1820
2534
  return true;
1821
2535
  #elif defined(__ARM_FEATURE_DOTPROD)
@@ -1823,7 +2537,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
1823
2537
  k, (const block_q4_0 *)A, lda,
1824
2538
  (const block_q8_0 *)B, ldb,
1825
2539
  (float *)C, ldc,
1826
- ith, nth};
2540
+ params->ith, params->nth};
1827
2541
  tb.matmul(m, n);
1828
2542
  return true;
1829
2543
  #else
@@ -1839,7 +2553,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
1839
2553
  k, (const block_q5_0 *)A, lda,
1840
2554
  (const block_q8_0 *)B, ldb,
1841
2555
  (float *)C, ldc,
1842
- ith, nth};
2556
+ params->ith, params->nth};
1843
2557
  tb.matmul(m, n);
1844
2558
  return true;
1845
2559
  #else
@@ -1855,7 +2569,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
1855
2569
  k, (const block_iq4_nl *)A, lda,
1856
2570
  (const block_q8_0 *)B, ldb,
1857
2571
  (float *)C, ldc,
1858
- ith, nth};
2572
+ params->ith, params->nth};
1859
2573
  tb.matmul(m, n);
1860
2574
  return true;
1861
2575
  #else
@@ -1867,6 +2581,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
1867
2581
  return false;
1868
2582
  }
1869
2583
 
2584
+ (void)params;
1870
2585
  (void)m;
1871
2586
  (void)n;
1872
2587
  (void)k;
@@ -1876,8 +2591,6 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
1876
2591
  (void)ldb;
1877
2592
  (void)C;
1878
2593
  (void)ldc;
1879
- (void)ith;
1880
- (void)nth;
1881
2594
  (void)Atype;
1882
2595
  (void)Btype;
1883
2596
  (void)Ctype;
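The rewritten floating-point `gemm()` above no longer carves tiles statically per thread; after a barrier, each thread starts on the job equal to its own index and then pulls further chunks from a shared atomic counter seeded to `nth`. The following is a standalone sketch of that scheduling pattern only, with illustrative names (`worker`, `num_jobs`, `done`) that are not part of the package, and with the seeding done before the threads start instead of by thread 0 behind a barrier:

```cpp
#include <atomic>
#include <cstdint>
#include <thread>
#include <vector>

// Shared work counter. Seeding it to nth means the first fetch_add already
// hands out job nth, because jobs 0..nth-1 are taken implicitly by each
// thread starting on the job equal to its own index.
static std::atomic<int64_t> current_chunk;

static void worker(int ith, int64_t num_jobs, std::vector<int64_t> &done) {
    int64_t job = ith;                      // every thread starts at its own index
    while (job < num_jobs) {
        done[job] = ith;                    // "process" the chunk
        job = current_chunk.fetch_add(1, std::memory_order_relaxed);
    }
}

int main() {
    const int     nth      = 4;
    const int64_t num_jobs = 32;
    std::vector<int64_t> done(num_jobs, -1);

    current_chunk.store(nth, std::memory_order_relaxed);  // skip past the first nth jobs

    std::vector<std::thread> pool;
    for (int ith = 0; ith < nth; ++ith)
        pool.emplace_back(worker, ith, num_jobs, std::ref(done));
    for (auto &t : pool)
        t.join();
}
```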