@fugood/llama.node 1.1.10 → 1.2.0-rc.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. package/CMakeLists.txt +5 -8
  2. package/lib/binding.ts +20 -2
  3. package/lib/index.js +2 -2
  4. package/lib/index.ts +2 -2
  5. package/package.json +20 -16
  6. package/src/DecodeAudioTokenWorker.cpp +23 -26
  7. package/src/DecodeAudioTokenWorker.h +6 -8
  8. package/src/DetokenizeWorker.cpp +5 -8
  9. package/src/DetokenizeWorker.h +6 -5
  10. package/src/DisposeWorker.cpp +23 -3
  11. package/src/DisposeWorker.h +4 -2
  12. package/src/EmbeddingWorker.cpp +9 -35
  13. package/src/EmbeddingWorker.h +3 -2
  14. package/src/LlamaCompletionWorker.cpp +217 -315
  15. package/src/LlamaCompletionWorker.h +6 -12
  16. package/src/LlamaContext.cpp +174 -388
  17. package/src/LlamaContext.h +8 -13
  18. package/src/LoadSessionWorker.cpp +22 -19
  19. package/src/LoadSessionWorker.h +3 -2
  20. package/src/RerankWorker.h +3 -2
  21. package/src/SaveSessionWorker.cpp +22 -19
  22. package/src/SaveSessionWorker.h +3 -2
  23. package/src/TokenizeWorker.cpp +38 -35
  24. package/src/TokenizeWorker.h +12 -3
  25. package/src/common.hpp +0 -458
  26. package/src/llama.cpp/common/arg.cpp +67 -37
  27. package/src/llama.cpp/common/chat.cpp +263 -2
  28. package/src/llama.cpp/common/chat.h +4 -0
  29. package/src/llama.cpp/common/common.cpp +10 -3
  30. package/src/llama.cpp/common/common.h +5 -2
  31. package/src/llama.cpp/common/log.cpp +53 -2
  32. package/src/llama.cpp/common/log.h +10 -4
  33. package/src/llama.cpp/common/sampling.cpp +23 -2
  34. package/src/llama.cpp/common/sampling.h +3 -1
  35. package/src/llama.cpp/common/speculative.cpp +1 -1
  36. package/src/llama.cpp/ggml/CMakeLists.txt +4 -3
  37. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -0
  38. package/src/llama.cpp/ggml/include/ggml-cpu.h +0 -1
  39. package/src/llama.cpp/ggml/include/ggml.h +50 -1
  40. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +19 -16
  41. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +210 -96
  42. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +1 -7
  43. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +11 -37
  44. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +3 -4
  45. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +43 -6
  46. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +4 -1
  47. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +18 -18
  48. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
  49. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +234 -16
  50. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  51. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +80 -51
  52. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +161 -20
  53. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +399 -50
  54. package/src/llama.cpp/include/llama.h +32 -7
  55. package/src/llama.cpp/src/llama-adapter.cpp +101 -4
  56. package/src/llama.cpp/src/llama-adapter.h +6 -0
  57. package/src/llama.cpp/src/llama-arch.cpp +69 -2
  58. package/src/llama.cpp/src/llama-arch.h +6 -0
  59. package/src/llama.cpp/src/llama-context.cpp +92 -45
  60. package/src/llama.cpp/src/llama-context.h +1 -5
  61. package/src/llama.cpp/src/llama-graph.cpp +74 -19
  62. package/src/llama.cpp/src/llama-graph.h +10 -1
  63. package/src/llama.cpp/src/llama-hparams.cpp +37 -0
  64. package/src/llama.cpp/src/llama-hparams.h +9 -3
  65. package/src/llama.cpp/src/llama-impl.h +2 -0
  66. package/src/llama.cpp/src/llama-kv-cache.cpp +33 -120
  67. package/src/llama.cpp/src/llama-kv-cache.h +4 -13
  68. package/src/llama.cpp/src/llama-model-loader.cpp +1 -0
  69. package/src/llama.cpp/src/llama-model.cpp +434 -21
  70. package/src/llama.cpp/src/llama-model.h +1 -1
  71. package/src/llama.cpp/src/llama-sampling.cpp +226 -126
  72. package/src/llama.cpp/src/llama-vocab.cpp +1 -1
  73. package/src/llama.cpp/src/llama.cpp +12 -0
  74. package/src/anyascii.c +0 -22223
  75. package/src/anyascii.h +0 -42
  76. package/src/tts_utils.cpp +0 -371
  77. package/src/tts_utils.h +0 -103
@@ -2169,94 +2169,117 @@ class tinyBLAS_Q0_PPC {
2169
2169
  class tinyBLAS_PPC {
2170
2170
  public:
2171
2171
  tinyBLAS_PPC(int64_t k,
2172
- const float *A, int64_t lda,
2173
- const float *B, int64_t ldb,
2174
- float *C, int64_t ldc,
2172
+ const float * A, int64_t lda,
2173
+ const float * B, int64_t ldb,
2174
+ float * C, int64_t ldc,
2175
2175
  int ith, int nth)
2176
2176
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
2177
2177
  }
2178
2178
 
2179
2179
  void matmul(int64_t m, int64_t n) {
2180
- mnpack(0, m, 0, n);
2180
+ int64_t mc = 256; int64_t nc = 256; int64_t kc = 256;
2181
+ if (m % mc == 0 && n % nc == 0 && k % kc == 0) {
2182
+ matmul_tiled(m, n, mc, nc, kc);
2183
+ } else {
2184
+ mnpack(0, m, 0, n);
2185
+ }
2181
2186
  }
2182
2187
 
2183
2188
  private:
2184
2189
 
2185
- void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
2190
+ inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
2191
+ vec_t vec_C[4];
2192
+ __builtin_mma_disassemble_acc(vec_C, ACC);
2193
+ for (int I = 0; I < 4; I++) {
2194
+ for (int J = 0; J < 4; J++) {
2195
+ *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
2196
+ }
2197
+ }
2198
+ }
2186
2199
 
2187
- inline void vector_permute_store_4(vector float *src, float *vecOffset) {
2188
- vector float t1, t2, t3, t4, t5, t6, t7, t8;
2189
- t1 = vec_mergeh(src[0], src[1]);
2190
- t2 = vec_mergeh(src[2], src[3]);
2191
- t3 = vec_mergel(src[0], src[1]);
2192
- t4 = vec_mergel(src[2], src[3]);
2200
+ inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
2201
+ vec_t vec_C[4];
2202
+ __builtin_mma_disassemble_acc(vec_C, ACC);
2203
+ for (int I = 0; I < 4; I++) {
2204
+ for (int J = 0; J < 4; J++) {
2205
+ float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
2206
+ *c_ptr += *((float *)&vec_C[I]+J);
2207
+ }
2208
+ }
2209
+ }
2193
2210
 
2194
- t5 = vec_xxpermdi(t1, t2, 0);
2195
- t6 = vec_xxpermdi(t1, t2, 3);
2196
- t7 = vec_xxpermdi(t3, t4, 0);
2197
- t8 = vec_xxpermdi(t3, t4, 3);
2211
+ inline void vector_permute_store_4(vector float * src, float * vecOffset) {
2212
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
2213
+ t1 = vec_mergeh(src[0], src[1]);
2214
+ t2 = vec_mergeh(src[2], src[3]);
2215
+ t3 = vec_mergel(src[0], src[1]);
2216
+ t4 = vec_mergel(src[2], src[3]);
2198
2217
 
2199
- vec_xst(t5, 0, vecOffset);
2200
- vec_xst(t6, 0, vecOffset + 4);
2201
- vec_xst(t7, 0, vecOffset + 8);
2202
- vec_xst(t8, 0, vecOffset + 12);
2203
- }
2218
+ t5 = vec_xxpermdi(t1, t2, 0);
2219
+ t6 = vec_xxpermdi(t1, t2, 3);
2220
+ t7 = vec_xxpermdi(t3, t4, 0);
2221
+ t8 = vec_xxpermdi(t3, t4, 3);
2204
2222
 
2205
- inline void vector_permute_store_8(vector float *src, float *vecOffset) {
2206
- vector float t1, t2, t3, t4, t5, t6, t7, t8;
2207
- t1 = vec_mergeh(src[0], src[1]);
2208
- t2 = vec_mergeh(src[2], src[3]);
2209
- t3 = vec_mergeh(src[4], src[5]);
2210
- t4 = vec_mergeh(src[6], src[7]);
2223
+ vec_xst(t5, 0, vecOffset);
2224
+ vec_xst(t6, 0, vecOffset + 4);
2225
+ vec_xst(t7, 0, vecOffset + 8);
2226
+ vec_xst(t8, 0, vecOffset + 12);
2227
+ }
2211
2228
 
2212
- t5 = vec_xxpermdi(t1, t2, 0);
2213
- t6 = vec_xxpermdi(t3, t4, 0);
2214
- t7 = vec_xxpermdi(t1, t2, 3);
2215
- t8 = vec_xxpermdi(t3, t4, 3);
2229
+ inline void vector_permute_store_8(vector float * src, float * vecOffset) {
2230
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
2231
+ t1 = vec_mergeh(src[0], src[1]);
2232
+ t2 = vec_mergeh(src[2], src[3]);
2233
+ t3 = vec_mergeh(src[4], src[5]);
2234
+ t4 = vec_mergeh(src[6], src[7]);
2216
2235
 
2217
- vec_xst(t5, 0, vecOffset);
2218
- vec_xst(t6, 0, vecOffset + 4);
2219
- vec_xst(t7, 0, vecOffset + 8);
2220
- vec_xst(t8, 0, vecOffset + 12);
2236
+ t5 = vec_xxpermdi(t1, t2, 0);
2237
+ t6 = vec_xxpermdi(t3, t4, 0);
2238
+ t7 = vec_xxpermdi(t1, t2, 3);
2239
+ t8 = vec_xxpermdi(t3, t4, 3);
2221
2240
 
2222
- t1 = vec_mergel(src[0], src[1]);
2223
- t2 = vec_mergel(src[2], src[3]);
2224
- t3 = vec_mergel(src[4], src[5]);
2225
- t4 = vec_mergel(src[6], src[7]);
2241
+ vec_xst(t5, 0, vecOffset);
2242
+ vec_xst(t6, 0, vecOffset + 4);
2243
+ vec_xst(t7, 0, vecOffset + 8);
2244
+ vec_xst(t8, 0, vecOffset + 12);
2226
2245
 
2227
- t5 = vec_xxpermdi(t1, t2, 0);
2228
- t6 = vec_xxpermdi(t3, t4, 0);
2229
- t7 = vec_xxpermdi(t1, t2, 3);
2230
- t8 = vec_xxpermdi(t3, t4, 3);
2246
+ t1 = vec_mergel(src[0], src[1]);
2247
+ t2 = vec_mergel(src[2], src[3]);
2248
+ t3 = vec_mergel(src[4], src[5]);
2249
+ t4 = vec_mergel(src[6], src[7]);
2231
2250
 
2232
- vec_xst(t5, 0, vecOffset + 16);
2233
- vec_xst(t6, 0, vecOffset + 20);
2234
- vec_xst(t7, 0, vecOffset + 24);
2235
- vec_xst(t8, 0, vecOffset + 28);
2251
+ t5 = vec_xxpermdi(t1, t2, 0);
2252
+ t6 = vec_xxpermdi(t3, t4, 0);
2253
+ t7 = vec_xxpermdi(t1, t2, 3);
2254
+ t8 = vec_xxpermdi(t3, t4, 3);
2255
+
2256
+ vec_xst(t5, 0, vecOffset + 16);
2257
+ vec_xst(t6, 0, vecOffset + 20);
2258
+ vec_xst(t7, 0, vecOffset + 24);
2259
+ vec_xst(t8, 0, vecOffset + 28);
2236
2260
  }
2237
2261
 
2238
- void packTranspose(const float* a, int64_t lda, int rows, int cols, float* vec) {
2262
+ void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
2239
2263
  int64_t i, j;
2240
2264
  float * aoffsets[8];
2241
- float *aoffset = NULL, *boffset = NULL;
2265
+ float * aoffset = NULL, * boffset = NULL;
2242
2266
  __vector_pair arr[8];
2243
2267
  vector float c[8][2] = {0};
2244
2268
  vector float c1[8] = {0};
2245
2269
  vector float c2[8] = {0};
2246
- aoffset = const_cast<float*>(a);
2270
+ aoffset = const_cast<float *>(a);
2247
2271
  boffset = vec;
2248
2272
  j = (rows >> 3);
2249
2273
  if (j > 0) {
2250
-
2251
2274
  do {
2252
2275
  aoffsets[0] = aoffset;
2253
- for (int it = 1; it< 8; it++)
2276
+ for (int it = 1; it < 8; it++)
2254
2277
  aoffsets[it] = aoffsets[it-1] + lda;
2255
2278
  aoffset += 8 * lda;
2256
2279
  i = (cols >> 3);
2257
2280
  if (i > 0) {
2258
2281
  do {
2259
- for (int it = 0; it< 8; it++) {
2282
+ for (int it = 0; it < 8; it++) {
2260
2283
  arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
2261
2284
  __builtin_vsx_disassemble_pair(c[it], &arr[it]);
2262
2285
  c1[it] = c[it][0];
@@ -2264,11 +2287,14 @@ class tinyBLAS_PPC {
2264
2287
  }
2265
2288
 
2266
2289
  vector_permute_store_8(c1, boffset);
2267
- vector_permute_store_8(c2, boffset+32);
2268
- for (int it = 0; it < 4; it++)
2269
- aoffsets[it] = aoffsets[it] + 8*lda;
2290
+ vector_permute_store_8(c2, boffset + 32);
2270
2291
  boffset += 64;
2271
2292
  i--;
2293
+ if (i > 0) {
2294
+ for (int it = 0; it < 8; it++) {
2295
+ aoffsets[it] = aoffsets[it] + 8;
2296
+ }
2297
+ }
2272
2298
  } while(i > 0);
2273
2299
  }
2274
2300
  if (cols & 4) {
@@ -2295,9 +2321,9 @@ class tinyBLAS_PPC {
2295
2321
  c2[it] = c[it][1];
2296
2322
  }
2297
2323
  vector_permute_store_4(c1, boffset);
2298
- vector_permute_store_4(c2, boffset+16);
2324
+ vector_permute_store_4(c2, boffset + 16);
2299
2325
  for (int it = 0; it < 4; it++)
2300
- aoffsets[it] += 8*lda;
2326
+ aoffsets[it] += 8 * lda;
2301
2327
  boffset += 32;
2302
2328
  i--;
2303
2329
  } while(i > 0);
@@ -2325,15 +2351,15 @@ class tinyBLAS_PPC {
2325
2351
  vec_t vec_A[4], vec_B[4], vec_C[4];
2326
2352
  acc_t acc_0;
2327
2353
  __builtin_mma_xxsetaccz(&acc_0);
2328
- for (int l = 0; l < k; l+=4) {
2329
- packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
2330
- packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
2354
+ for (int l = 0; l < k; l += 4) {
2355
+ packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
2356
+ packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
2331
2357
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
2332
2358
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
2333
2359
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
2334
2360
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
2335
2361
  }
2336
- SAVE_ACC(&acc_0, ii, jj);
2362
+ save_acc(&acc_0, ii, jj);
2337
2363
  }
2338
2364
 
2339
2365
  void KERNEL_4x8(int64_t ii, int64_t jj) {
@@ -2341,9 +2367,9 @@ class tinyBLAS_PPC {
2341
2367
  acc_t acc_0, acc_1;
2342
2368
  __builtin_mma_xxsetaccz(&acc_0);
2343
2369
  __builtin_mma_xxsetaccz(&acc_1);
2344
- for (int64_t l = 0; l < k; l+=4) {
2345
- packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
2346
- packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
2370
+ for (int64_t l = 0; l < k; l += 4) {
2371
+ packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
2372
+ packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B);
2347
2373
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
2348
2374
  __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
2349
2375
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@@ -2353,8 +2379,8 @@ class tinyBLAS_PPC {
2353
2379
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
2354
2380
  __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
2355
2381
  }
2356
- SAVE_ACC(&acc_0, ii, jj);
2357
- SAVE_ACC(&acc_1, ii, jj+4);
2382
+ save_acc(&acc_0, ii, jj);
2383
+ save_acc(&acc_1, ii, jj + 4);
2358
2384
  }
2359
2385
 
2360
2386
  void KERNEL_8x4(int64_t ii, int64_t jj) {
@@ -2362,9 +2388,9 @@ class tinyBLAS_PPC {
2362
2388
  acc_t acc_0, acc_1;
2363
2389
  __builtin_mma_xxsetaccz(&acc_0);
2364
2390
  __builtin_mma_xxsetaccz(&acc_1);
2365
- for (int64_t l = 0; l < k; l+=4) {
2366
- packTranspose(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
2367
- packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
2391
+ for (int64_t l = 0; l < k; l += 4) {
2392
+ packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A);
2393
+ packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
2368
2394
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
2369
2395
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
2370
2396
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@@ -2374,8 +2400,8 @@ class tinyBLAS_PPC {
2374
2400
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
2375
2401
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
2376
2402
  }
2377
- SAVE_ACC(&acc_0, ii, jj);
2378
- SAVE_ACC(&acc_1, ii+4, jj);
2403
+ save_acc(&acc_0, ii, jj);
2404
+ save_acc(&acc_1, ii + 4, jj);
2379
2405
  }
2380
2406
 
2381
2407
  void KERNEL_8x8(int64_t ii, int64_t jj) {
@@ -2386,19 +2412,96 @@ class tinyBLAS_PPC {
2386
2412
  __builtin_mma_xxsetaccz(&acc_2);
2387
2413
  __builtin_mma_xxsetaccz(&acc_3);
2388
2414
  for (int l = 0; l < k; l+=8) {
2389
- packTranspose(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
2390
- packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
2415
+ packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A);
2416
+ packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B);
2391
2417
  for(int x = 0; x < 16; x+=2) {
2392
2418
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
2393
- __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
2394
- __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
2395
- __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
2419
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]);
2420
+ __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]);
2421
+ __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]);
2422
+ }
2423
+ }
2424
+ save_acc(&acc_0, ii, jj);
2425
+ save_acc(&acc_1, ii, jj + 4);
2426
+ save_acc(&acc_2, ii + 4, jj);
2427
+ save_acc(&acc_3, ii + 4, jj + 4);
2428
+ }
2429
+
2430
+ inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) {
2431
+ for (int x = 0; x < 16; x += 2) {
2432
+ __builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);
2433
+ __builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);
2434
+ __builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);
2435
+ __builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);
2436
+ __builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);
2437
+ __builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);
2438
+ __builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);
2439
+ __builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);
2440
+ }
2441
+ }
2442
+
2443
+ void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) {
2444
+ for (int64_t i = 0; i < mc; i += 16) {
2445
+ int A_base_addr = (mc / 8) * (i / 8) * 16;
2446
+ for (int64_t j = 0; j < nc; j += 8) {
2447
+ int B_base_addr = (nc / 8) * (j / 8) * 16;
2448
+ acc_t acc[8];
2449
+ vec_t A0_block[16]; vec_t A1_block[16];
2450
+ for (int x = 0; x < 8; x++)
2451
+ __builtin_mma_xxsetaccz(&acc[x]);
2452
+ for (int64_t l = 0; l < kc; l += 8) {
2453
+ int A0_block_idx = A_base_addr + (l / 8) * 16;
2454
+ int A1_block_idx = A0_block_idx + (mc / 8) * 16;
2455
+ int B_block_idx = B_base_addr + (l / 8) * 16;
2456
+ vec_t* A0_block = &vec_A[A0_block_idx];
2457
+ vec_t* A1_block = &vec_A[A1_block_idx];
2458
+ vec_t* B_block = &vec_B[B_block_idx];
2459
+ MMA_16x8(A0_block, A1_block, B_block, acc);
2460
+ }
2461
+ if (kk == 0) {
2462
+ save_acc(&acc[0], ii + i, jj + j);
2463
+ save_acc(&acc[1], ii + i, jj + j + 4);
2464
+ save_acc(&acc[2], ii + i + 4, jj + j);
2465
+ save_acc(&acc[3], ii + i + 4, jj + j + 4);
2466
+ save_acc(&acc[4], ii + i + 8, jj + j);
2467
+ save_acc(&acc[5], ii + i + 8, jj + j + 4);
2468
+ save_acc(&acc[6], ii + i + 12, jj + j);
2469
+ save_acc(&acc[7], ii + i + 12, jj + j + 4);
2470
+ } else {
2471
+ add_save_acc(&acc[0], ii + i, jj + j);
2472
+ add_save_acc(&acc[1], ii + i, jj + j + 4);
2473
+ add_save_acc(&acc[2], ii + i + 4, jj + j);
2474
+ add_save_acc(&acc[3], ii + i + 4, jj + j + 4);
2475
+ add_save_acc(&acc[4], ii + i + 8, jj + j);
2476
+ add_save_acc(&acc[5], ii + i + 8, jj + j + 4);
2477
+ add_save_acc(&acc[6], ii + i + 12, jj + j);
2478
+ add_save_acc(&acc[7], ii + i + 12, jj + j + 4);
2479
+ }
2480
+ }
2481
+ }
2482
+ }
2483
+
2484
+ void matmul_tiled(int64_t m , int64_t n, int64_t mc, int64_t nc, int64_t kc) {
2485
+ int64_t ytiles = m / mc;
2486
+ int64_t xtiles = n / nc;
2487
+ int64_t tiles = xtiles * ytiles;
2488
+ int64_t duty = (tiles + nth - 1) / nth;
2489
+ int64_t start = duty * ith;
2490
+ int64_t end = start + duty;
2491
+ if (end > tiles) {
2492
+ end = tiles;
2493
+ }
2494
+ for (int64_t job = start; job < end; ++job) {
2495
+ int64_t ii = (job / xtiles) * mc;
2496
+ int64_t jj = (job % xtiles) * nc;
2497
+ for (int64_t kk = 0; kk < k; kk += kc) {
2498
+ vec_t A_pack[kc * mc / 4];
2499
+ vec_t B_pack[kc * nc / 4];
2500
+ packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack);
2501
+ packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack);
2502
+ KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
2396
2503
  }
2397
2504
  }
2398
- SAVE_ACC(&acc_0, ii, jj);
2399
- SAVE_ACC(&acc_1, ii, jj+4);
2400
- SAVE_ACC(&acc_2, ii+4, jj);
2401
- SAVE_ACC(&acc_3, ii+4, jj+4);
2402
2505
  }
2403
2506
 
2404
2507
  void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
@@ -2406,35 +2509,35 @@ class tinyBLAS_PPC {
2406
2509
  int n_rem = MIN(n - n0, 8);
2407
2510
  int mc = 0, nc = 0;
2408
2511
  if (m_rem >= 8 && n_rem >= 8) {
2409
- mc = 8;
2410
- nc = 8;
2411
- gemm<8, 8>(m0, m, n0, n);
2512
+ mc = 8;
2513
+ nc = 8;
2514
+ gemm<8, 8>(m0, m, n0, n);
2412
2515
  } else if (m_rem >= 4 && n_rem >= 8) {
2413
- mc = 4;
2414
- nc = 8;
2415
- gemm<4, 8>(m0, m, n0, n);
2516
+ mc = 4;
2517
+ nc = 8;
2518
+ gemm<4, 8>(m0, m, n0, n);
2416
2519
  } else if (m_rem >= 8 && n_rem >= 4) {
2417
- mc = 8;
2418
- nc = 4;
2419
- gemm<8, 4>(m0, m, n0, n);
2520
+ mc = 8;
2521
+ nc = 4;
2522
+ gemm<8, 4>(m0, m, n0, n);
2420
2523
  } else if (m_rem >= 4 && n_rem >= 4) {
2421
- mc = 4;
2422
- nc = 4;
2423
- gemm<4, 4>(m0, m, n0, n);
2524
+ mc = 4;
2525
+ nc = 4;
2526
+ gemm<4, 4>(m0, m, n0, n);
2424
2527
  } else {
2425
2528
  mc = (m_rem >= 4) ? 4 : m_rem;
2426
2529
  nc = (n_rem >= 4) ? 4 : n_rem;
2427
2530
  if (mc == 0 || nc == 0)
2428
- return;
2531
+ return;
2429
2532
  gemm_small(m0, m, n0, n, mc, nc);
2430
2533
  }
2431
2534
  int64_t mp = m0 + ((m - m0) / mc) * mc;
2432
2535
  int64_t np = n0 + ((n - n0) / nc) * nc;
2433
2536
  mnpack(mp, m, n0, np);
2434
2537
  mnpack(m0, m, np, n);
2435
- }
2538
+ }
2436
2539
 
2437
- void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
2540
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
2438
2541
  int64_t ytiles = (m - m0) / RM;
2439
2542
  int64_t xtiles = (n - n0) / RN;
2440
2543
  int64_t tiles = xtiles * ytiles;
@@ -2449,30 +2552,30 @@ class tinyBLAS_PPC {
2449
2552
  vec_t vec_C[4];
2450
2553
  acc_t acc_0;
2451
2554
  __builtin_mma_xxsetaccz(&acc_0);
2452
- vec_t vec_A[4] {0}, vec_B[4] = {0};
2453
- for (int l=0; l<k; l+=4) {
2555
+ vec_t vec_A[4] = {0}, vec_B[4] = {0};
2556
+ for (int l = 0; l < k; l += 4) {
2454
2557
  /* 'GEMV Forwarding' concept is used in first two conditional loops.
2455
2558
  * when one of the matrix has a single row/column, the elements are
2456
2559
  * broadcasted, instead of using packing routine to prepack the
2457
2560
  * matrix elements.
2458
2561
  */
2459
2562
  if (RM == 1) {
2460
- float* a = const_cast<float*>(A+(ii)*lda+l);
2461
- packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
2563
+ float * a = const_cast<float *>(A + (ii) * lda + l);
2564
+ packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
2462
2565
  vec_A[0] = (vec_t)vec_xl(0,a);
2463
- vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
2464
- vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
2465
- vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
2566
+ vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1));
2567
+ vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2));
2568
+ vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3));
2466
2569
  } else if (RN == 1) {
2467
- packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
2468
- float* b = const_cast<float*>(B+(jj)*ldb+l);
2570
+ packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
2571
+ float * b = const_cast<float *>(B + (jj) * ldb + l);
2469
2572
  vec_B[0] = (vec_t)vec_xl(0,b);
2470
- vec_B[1] = (vec_t)vec_splats(*((float*)&vec_B+1));
2471
- vec_B[2] = (vec_t)vec_splats(*((float*)&vec_B+2));
2472
- vec_B[3] = (vec_t)vec_splats(*((float*)&vec_B+3));
2573
+ vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1));
2574
+ vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2));
2575
+ vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3));
2473
2576
  } else {
2474
- packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
2475
- packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
2577
+ packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
2578
+ packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
2476
2579
  }
2477
2580
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
2478
2581
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@@ -2482,12 +2585,27 @@ class tinyBLAS_PPC {
2482
2585
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
2483
2586
  for (int I = 0; I < RM; I++) {
2484
2587
  for (int J = 0; J < RN; J++) {
2485
- *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
2588
+ *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
2486
2589
  }
2487
2590
  }
2488
2591
  }
2489
2592
  }
2490
2593
 
2594
+ template<int RM, int RN>
2595
+ inline void kernel(int64_t ii, int64_t jj) {
2596
+ if constexpr(RM == 4 && RN == 4) {
2597
+ KERNEL_4x4(ii, jj);
2598
+ } else if constexpr(RM == 4 && RN == 8) {
2599
+ KERNEL_4x8(ii, jj);
2600
+ } else if constexpr(RM == 8 && RN == 4) {
2601
+ KERNEL_8x4(ii, jj);
2602
+ } else if constexpr(RM == 8 && RN == 8) {
2603
+ KERNEL_8x8(ii, jj);
2604
+ } else {
2605
+ static_assert(false, "RN/RM values not supported");
2606
+ }
2607
+ }
2608
+
2491
2609
  template <int RM, int RN>
2492
2610
  NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2493
2611
  int64_t ytiles = (m - m0) / RM;
@@ -2496,27 +2614,18 @@ class tinyBLAS_PPC {
2496
2614
  int64_t duty = (tiles + nth - 1) / nth;
2497
2615
  int64_t start = duty * ith;
2498
2616
  int64_t end = start + duty;
2499
- if (RM == 4 && RN == 4) {
2500
- kernel = &tinyBLAS_PPC::KERNEL_4x4;
2501
- } else if (RM == 4 && RN == 8) {
2502
- kernel = &tinyBLAS_PPC::KERNEL_4x8;
2503
- } else if (RM == 8 && RN == 4) {
2504
- kernel = &tinyBLAS_PPC::KERNEL_8x4;
2505
- } else if (RM == 8 && RN == 8) {
2506
- kernel = &tinyBLAS_PPC::KERNEL_8x8;
2507
- }
2508
2617
  if (end > tiles)
2509
2618
  end = tiles;
2510
2619
  for (int64_t job = start; job < end; ++job) {
2511
2620
  int64_t ii = m0 + job / xtiles * RM;
2512
2621
  int64_t jj = n0 + job % xtiles * RN;
2513
- (this->*kernel)(ii, jj);
2622
+ kernel<RM, RN>(ii, jj);
2514
2623
  }
2515
2624
  }
2516
2625
 
2517
- const float *const A;
2518
- const float *const B;
2519
- float *C;
2626
+ const float * const A;
2627
+ const float * const B;
2628
+ float * C;
2520
2629
  const int64_t k;
2521
2630
  const int64_t lda;
2522
2631
  const int64_t ldb;