@fugood/llama.node 1.1.9 → 1.1.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. package/lib/binding.ts +7 -1
  2. package/package.json +14 -14
  3. package/scripts/llama.cpp.patch +15 -5
  4. package/src/LlamaCompletionWorker.cpp +12 -3
  5. package/src/LlamaCompletionWorker.h +3 -1
  6. package/src/LlamaContext.cpp +20 -2
  7. package/src/llama.cpp/common/arg.cpp +29 -19
  8. package/src/llama.cpp/common/chat.cpp +153 -3
  9. package/src/llama.cpp/common/chat.h +1 -0
  10. package/src/llama.cpp/common/common.cpp +10 -3
  11. package/src/llama.cpp/common/common.h +4 -1
  12. package/src/llama.cpp/ggml/CMakeLists.txt +1 -1
  13. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -4
  14. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +1 -1
  15. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +43 -6
  16. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +4 -1
  17. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +14 -9
  18. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
  19. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +16 -12
  20. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
  21. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
  22. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +103 -1
  23. package/src/llama.cpp/include/llama.h +27 -1
  24. package/src/llama.cpp/src/llama-adapter.cpp +68 -4
  25. package/src/llama.cpp/src/llama-adapter.h +3 -0
  26. package/src/llama.cpp/src/llama-arch.cpp +46 -2
  27. package/src/llama.cpp/src/llama-arch.h +4 -0
  28. package/src/llama.cpp/src/llama-context.cpp +80 -39
  29. package/src/llama.cpp/src/llama-context.h +0 -4
  30. package/src/llama.cpp/src/llama-graph.cpp +20 -10
  31. package/src/llama.cpp/src/llama-graph.h +2 -1
  32. package/src/llama.cpp/src/llama-hparams.cpp +25 -0
  33. package/src/llama.cpp/src/llama-hparams.h +6 -0
  34. package/src/llama.cpp/src/llama-impl.h +2 -0
  35. package/src/llama.cpp/src/llama-kv-cache-iswa.cpp +24 -7
  36. package/src/llama.cpp/src/llama-kv-cache-iswa.h +4 -2
  37. package/src/llama.cpp/src/llama-kv-cache.cpp +67 -130
  38. package/src/llama.cpp/src/llama-kv-cache.h +16 -28
  39. package/src/llama.cpp/src/llama-memory-hybrid.cpp +29 -28
  40. package/src/llama.cpp/src/llama-memory-hybrid.h +18 -22
  41. package/src/llama.cpp/src/llama-memory-recurrent.cpp +7 -7
  42. package/src/llama.cpp/src/llama-memory-recurrent.h +7 -11
  43. package/src/llama.cpp/src/llama-memory.h +8 -0
  44. package/src/llama.cpp/src/llama-model-loader.cpp +1 -0
  45. package/src/llama.cpp/src/llama-model.cpp +302 -31
  46. package/src/llama.cpp/src/llama-model.h +1 -0
  47. package/src/llama.cpp/src/llama-vocab.cpp +1 -1
  48. package/src/llama.cpp/src/llama.cpp +12 -0
@@ -2169,94 +2169,117 @@ class tinyBLAS_Q0_PPC {
  class tinyBLAS_PPC {
  public:
  tinyBLAS_PPC(int64_t k,
- const float *A, int64_t lda,
- const float *B, int64_t ldb,
- float *C, int64_t ldc,
+ const float * A, int64_t lda,
+ const float * B, int64_t ldb,
+ float * C, int64_t ldc,
  int ith, int nth)
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  }

  void matmul(int64_t m, int64_t n) {
- mnpack(0, m, 0, n);
+ int64_t mc = 256; int64_t nc = 256; int64_t kc = 256;
+ if (m % mc == 0 && n % nc == 0 && k % kc == 0) {
+ matmul_tiled(m, n, mc, nc, kc);
+ } else {
+ mnpack(0, m, 0, n);
+ }
  }

  private:

- void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
+ inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
+ vec_t vec_C[4];
+ __builtin_mma_disassemble_acc(vec_C, ACC);
+ for (int I = 0; I < 4; I++) {
+ for (int J = 0; J < 4; J++) {
+ *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
+ }
+ }
+ }

- inline void vector_permute_store_4(vector float *src, float *vecOffset) {
- vector float t1, t2, t3, t4, t5, t6, t7, t8;
- t1 = vec_mergeh(src[0], src[1]);
- t2 = vec_mergeh(src[2], src[3]);
- t3 = vec_mergel(src[0], src[1]);
- t4 = vec_mergel(src[2], src[3]);
+ inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
+ vec_t vec_C[4];
+ __builtin_mma_disassemble_acc(vec_C, ACC);
+ for (int I = 0; I < 4; I++) {
+ for (int J = 0; J < 4; J++) {
+ float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
+ *c_ptr += *((float *)&vec_C[I]+J);
+ }
+ }
+ }

- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t1, t2, 3);
- t7 = vec_xxpermdi(t3, t4, 0);
- t8 = vec_xxpermdi(t3, t4, 3);
+ inline void vector_permute_store_4(vector float * src, float * vecOffset) {
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
+ t1 = vec_mergeh(src[0], src[1]);
+ t2 = vec_mergeh(src[2], src[3]);
+ t3 = vec_mergel(src[0], src[1]);
+ t4 = vec_mergel(src[2], src[3]);

- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset + 4);
- vec_xst(t7, 0, vecOffset + 8);
- vec_xst(t8, 0, vecOffset + 12);
- }
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t1, t2, 3);
+ t7 = vec_xxpermdi(t3, t4, 0);
+ t8 = vec_xxpermdi(t3, t4, 3);

- inline void vector_permute_store_8(vector float *src, float *vecOffset) {
- vector float t1, t2, t3, t4, t5, t6, t7, t8;
- t1 = vec_mergeh(src[0], src[1]);
- t2 = vec_mergeh(src[2], src[3]);
- t3 = vec_mergeh(src[4], src[5]);
- t4 = vec_mergeh(src[6], src[7]);
+ vec_xst(t5, 0, vecOffset);
+ vec_xst(t6, 0, vecOffset + 4);
+ vec_xst(t7, 0, vecOffset + 8);
+ vec_xst(t8, 0, vecOffset + 12);
+ }

- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
+ inline void vector_permute_store_8(vector float * src, float * vecOffset) {
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
+ t1 = vec_mergeh(src[0], src[1]);
+ t2 = vec_mergeh(src[2], src[3]);
+ t3 = vec_mergeh(src[4], src[5]);
+ t4 = vec_mergeh(src[6], src[7]);

- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset + 4);
- vec_xst(t7, 0, vecOffset + 8);
- vec_xst(t8, 0, vecOffset + 12);
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t3, t4, 0);
+ t7 = vec_xxpermdi(t1, t2, 3);
+ t8 = vec_xxpermdi(t3, t4, 3);

- t1 = vec_mergel(src[0], src[1]);
- t2 = vec_mergel(src[2], src[3]);
- t3 = vec_mergel(src[4], src[5]);
- t4 = vec_mergel(src[6], src[7]);
+ vec_xst(t5, 0, vecOffset);
+ vec_xst(t6, 0, vecOffset + 4);
+ vec_xst(t7, 0, vecOffset + 8);
+ vec_xst(t8, 0, vecOffset + 12);

- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
+ t1 = vec_mergel(src[0], src[1]);
+ t2 = vec_mergel(src[2], src[3]);
+ t3 = vec_mergel(src[4], src[5]);
+ t4 = vec_mergel(src[6], src[7]);

- vec_xst(t5, 0, vecOffset + 16);
- vec_xst(t6, 0, vecOffset + 20);
- vec_xst(t7, 0, vecOffset + 24);
- vec_xst(t8, 0, vecOffset + 28);
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t3, t4, 0);
+ t7 = vec_xxpermdi(t1, t2, 3);
+ t8 = vec_xxpermdi(t3, t4, 3);
+
+ vec_xst(t5, 0, vecOffset + 16);
+ vec_xst(t6, 0, vecOffset + 20);
+ vec_xst(t7, 0, vecOffset + 24);
+ vec_xst(t8, 0, vecOffset + 28);
  }

- void packTranspose(const float* a, int64_t lda, int rows, int cols, float* vec) {
+ void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
  int64_t i, j;
  float * aoffsets[8];
- float *aoffset = NULL, *boffset = NULL;
+ float * aoffset = NULL, * boffset = NULL;
  __vector_pair arr[8];
  vector float c[8][2] = {0};
  vector float c1[8] = {0};
  vector float c2[8] = {0};
- aoffset = const_cast<float*>(a);
+ aoffset = const_cast<float *>(a);
  boffset = vec;
  j = (rows >> 3);
  if (j > 0) {
-
  do {
  aoffsets[0] = aoffset;
- for (int it = 1; it< 8; it++)
+ for (int it = 1; it < 8; it++)
  aoffsets[it] = aoffsets[it-1] + lda;
  aoffset += 8 * lda;
  i = (cols >> 3);
  if (i > 0) {
  do {
- for (int it = 0; it< 8; it++) {
+ for (int it = 0; it < 8; it++) {
  arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
  __builtin_vsx_disassemble_pair(c[it], &arr[it]);
  c1[it] = c[it][0];
@@ -2264,11 +2287,14 @@ class tinyBLAS_PPC {
  }

  vector_permute_store_8(c1, boffset);
- vector_permute_store_8(c2, boffset+32);
- for (int it = 0; it < 4; it++)
- aoffsets[it] = aoffsets[it] + 8*lda;
+ vector_permute_store_8(c2, boffset + 32);
  boffset += 64;
  i--;
+ if (i > 0) {
+ for (int it = 0; it < 8; it++) {
+ aoffsets[it] = aoffsets[it] + 8;
+ }
+ }
  } while(i > 0);
  }
  if (cols & 4) {
@@ -2295,9 +2321,9 @@ class tinyBLAS_PPC {
  c2[it] = c[it][1];
  }
  vector_permute_store_4(c1, boffset);
- vector_permute_store_4(c2, boffset+16);
+ vector_permute_store_4(c2, boffset + 16);
  for (int it = 0; it < 4; it++)
- aoffsets[it] += 8*lda;
+ aoffsets[it] += 8 * lda;
  boffset += 32;
  i--;
  } while(i > 0);
@@ -2325,15 +2351,15 @@ class tinyBLAS_PPC {
  vec_t vec_A[4], vec_B[4], vec_C[4];
  acc_t acc_0;
  __builtin_mma_xxsetaccz(&acc_0);
- for (int l = 0; l < k; l+=4) {
- packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+ for (int l = 0; l < k; l += 4) {
+ packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
  }
- SAVE_ACC(&acc_0, ii, jj);
+ save_acc(&acc_0, ii, jj);
  }

  void KERNEL_4x8(int64_t ii, int64_t jj) {
@@ -2341,9 +2367,9 @@ class tinyBLAS_PPC {
  acc_t acc_0, acc_1;
  __builtin_mma_xxsetaccz(&acc_0);
  __builtin_mma_xxsetaccz(&acc_1);
- for (int64_t l = 0; l < k; l+=4) {
- packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
+ for (int64_t l = 0; l < k; l += 4) {
+ packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@@ -2353,8 +2379,8 @@ class tinyBLAS_PPC {
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
  __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
  }
- SAVE_ACC(&acc_0, ii, jj);
- SAVE_ACC(&acc_1, ii, jj+4);
+ save_acc(&acc_0, ii, jj);
+ save_acc(&acc_1, ii, jj + 4);
  }

  void KERNEL_8x4(int64_t ii, int64_t jj) {
@@ -2362,9 +2388,9 @@ class tinyBLAS_PPC {
  acc_t acc_0, acc_1;
  __builtin_mma_xxsetaccz(&acc_0);
  __builtin_mma_xxsetaccz(&acc_1);
- for (int64_t l = 0; l < k; l+=4) {
- packTranspose(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+ for (int64_t l = 0; l < k; l += 4) {
+ packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@@ -2374,8 +2400,8 @@ class tinyBLAS_PPC {
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
  }
- SAVE_ACC(&acc_0, ii, jj);
- SAVE_ACC(&acc_1, ii+4, jj);
+ save_acc(&acc_0, ii, jj);
+ save_acc(&acc_1, ii + 4, jj);
  }

  void KERNEL_8x8(int64_t ii, int64_t jj) {
@@ -2386,19 +2412,96 @@ class tinyBLAS_PPC {
  __builtin_mma_xxsetaccz(&acc_2);
  __builtin_mma_xxsetaccz(&acc_3);
  for (int l = 0; l < k; l+=8) {
- packTranspose(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
+ packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B);
  for(int x = 0; x < 16; x+=2) {
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
- __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
- __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
- __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]);
+ __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]);
+ __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]);
+ }
+ }
+ save_acc(&acc_0, ii, jj);
+ save_acc(&acc_1, ii, jj + 4);
+ save_acc(&acc_2, ii + 4, jj);
+ save_acc(&acc_3, ii + 4, jj + 4);
+ }
+
+ inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) {
+ for (int x = 0; x < 16; x += 2) {
+ __builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);
+ __builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);
+ __builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);
+ __builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);
+ __builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);
+ __builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);
+ __builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);
+ __builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);
+ }
+ }
+
+ void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) {
+ for (int64_t i = 0; i < mc; i += 16) {
+ int A_base_addr = (mc / 8) * (i / 8) * 16;
+ for (int64_t j = 0; j < nc; j += 8) {
+ int B_base_addr = (nc / 8) * (j / 8) * 16;
+ acc_t acc[8];
+ vec_t A0_block[16]; vec_t A1_block[16];
+ for (int x = 0; x < 8; x++)
+ __builtin_mma_xxsetaccz(&acc[x]);
+ for (int64_t l = 0; l < kc; l += 8) {
+ int A0_block_idx = A_base_addr + (l / 8) * 16;
+ int A1_block_idx = A0_block_idx + (mc / 8) * 16;
+ int B_block_idx = B_base_addr + (l / 8) * 16;
+ vec_t* A0_block = &vec_A[A0_block_idx];
+ vec_t* A1_block = &vec_A[A1_block_idx];
+ vec_t* B_block = &vec_B[B_block_idx];
+ MMA_16x8(A0_block, A1_block, B_block, acc);
+ }
+ if (kk == 0) {
+ save_acc(&acc[0], ii + i, jj + j);
+ save_acc(&acc[1], ii + i, jj + j + 4);
+ save_acc(&acc[2], ii + i + 4, jj + j);
+ save_acc(&acc[3], ii + i + 4, jj + j + 4);
+ save_acc(&acc[4], ii + i + 8, jj + j);
+ save_acc(&acc[5], ii + i + 8, jj + j + 4);
+ save_acc(&acc[6], ii + i + 12, jj + j);
+ save_acc(&acc[7], ii + i + 12, jj + j + 4);
+ } else {
+ add_save_acc(&acc[0], ii + i, jj + j);
+ add_save_acc(&acc[1], ii + i, jj + j + 4);
+ add_save_acc(&acc[2], ii + i + 4, jj + j);
+ add_save_acc(&acc[3], ii + i + 4, jj + j + 4);
+ add_save_acc(&acc[4], ii + i + 8, jj + j);
+ add_save_acc(&acc[5], ii + i + 8, jj + j + 4);
+ add_save_acc(&acc[6], ii + i + 12, jj + j);
+ add_save_acc(&acc[7], ii + i + 12, jj + j + 4);
+ }
+ }
+ }
+ }
+
+ void matmul_tiled(int64_t m , int64_t n, int64_t mc, int64_t nc, int64_t kc) {
+ int64_t ytiles = m / mc;
+ int64_t xtiles = n / nc;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
+ if (end > tiles) {
+ end = tiles;
+ }
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = (job / xtiles) * mc;
+ int64_t jj = (job % xtiles) * nc;
+ for (int64_t kk = 0; kk < k; kk += kc) {
+ vec_t A_pack[kc * mc / 4];
+ vec_t B_pack[kc * nc / 4];
+ packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack);
+ packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack);
+ KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
  }
  }
- SAVE_ACC(&acc_0, ii, jj);
- SAVE_ACC(&acc_1, ii, jj+4);
- SAVE_ACC(&acc_2, ii+4, jj);
- SAVE_ACC(&acc_3, ii+4, jj+4);
  }

  void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
@@ -2406,35 +2509,35 @@ class tinyBLAS_PPC {
  int n_rem = MIN(n - n0, 8);
  int mc = 0, nc = 0;
  if (m_rem >= 8 && n_rem >= 8) {
- mc = 8;
- nc = 8;
- gemm<8, 8>(m0, m, n0, n);
+ mc = 8;
+ nc = 8;
+ gemm<8, 8>(m0, m, n0, n);
  } else if (m_rem >= 4 && n_rem >= 8) {
- mc = 4;
- nc = 8;
- gemm<4, 8>(m0, m, n0, n);
+ mc = 4;
+ nc = 8;
+ gemm<4, 8>(m0, m, n0, n);
  } else if (m_rem >= 8 && n_rem >= 4) {
- mc = 8;
- nc = 4;
- gemm<8, 4>(m0, m, n0, n);
+ mc = 8;
+ nc = 4;
+ gemm<8, 4>(m0, m, n0, n);
  } else if (m_rem >= 4 && n_rem >= 4) {
- mc = 4;
- nc = 4;
- gemm<4, 4>(m0, m, n0, n);
+ mc = 4;
+ nc = 4;
+ gemm<4, 4>(m0, m, n0, n);
  } else {
  mc = (m_rem >= 4) ? 4 : m_rem;
  nc = (n_rem >= 4) ? 4 : n_rem;
  if (mc == 0 || nc == 0)
- return;
+ return;
  gemm_small(m0, m, n0, n, mc, nc);
  }
  int64_t mp = m0 + ((m - m0) / mc) * mc;
  int64_t np = n0 + ((n - n0) / nc) * nc;
  mnpack(mp, m, n0, np);
  mnpack(m0, m, np, n);
- }
+ }

- void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
  int64_t ytiles = (m - m0) / RM;
  int64_t xtiles = (n - n0) / RN;
  int64_t tiles = xtiles * ytiles;
@@ -2449,30 +2552,30 @@ class tinyBLAS_PPC {
  vec_t vec_C[4];
  acc_t acc_0;
  __builtin_mma_xxsetaccz(&acc_0);
- vec_t vec_A[4] {0}, vec_B[4] = {0};
- for (int l=0; l<k; l+=4) {
+ vec_t vec_A[4] = {0}, vec_B[4] = {0};
+ for (int l = 0; l < k; l += 4) {
  /* 'GEMV Forwarding' concept is used in first two conditional loops.
  * when one of the matrix has a single row/column, the elements are
  * broadcasted, instead of using packing routine to prepack the
  * matrix elements.
  */
  if (RM == 1) {
- float* a = const_cast<float*>(A+(ii)*lda+l);
- packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
+ float * a = const_cast<float *>(A + (ii) * lda + l);
+ packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
  vec_A[0] = (vec_t)vec_xl(0,a);
- vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
- vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
- vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
+ vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1));
+ vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2));
+ vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3));
  } else if (RN == 1) {
- packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
- float* b = const_cast<float*>(B+(jj)*ldb+l);
+ packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
+ float * b = const_cast<float *>(B + (jj) * ldb + l);
  vec_B[0] = (vec_t)vec_xl(0,b);
- vec_B[1] = (vec_t)vec_splats(*((float*)&vec_B+1));
- vec_B[2] = (vec_t)vec_splats(*((float*)&vec_B+2));
- vec_B[3] = (vec_t)vec_splats(*((float*)&vec_B+3));
+ vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1));
+ vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2));
+ vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3));
  } else {
- packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
+ packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
  }
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@@ -2482,12 +2585,27 @@ class tinyBLAS_PPC {
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
  for (int I = 0; I < RM; I++) {
  for (int J = 0; J < RN; J++) {
- *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
+ *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
  }
  }
  }
  }

+ template<int RM, int RN>
+ inline void kernel(int64_t ii, int64_t jj) {
+ if constexpr(RM == 4 && RN == 4) {
+ KERNEL_4x4(ii, jj);
+ } else if constexpr(RM == 4 && RN == 8) {
+ KERNEL_4x8(ii, jj);
+ } else if constexpr(RM == 8 && RN == 4) {
+ KERNEL_8x4(ii, jj);
+ } else if constexpr(RM == 8 && RN == 8) {
+ KERNEL_8x8(ii, jj);
+ } else {
+ static_assert(false, "RN/RM values not supported");
+ }
+ }
+
  template <int RM, int RN>
  NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  int64_t ytiles = (m - m0) / RM;
@@ -2496,27 +2614,18 @@ class tinyBLAS_PPC {
  int64_t duty = (tiles + nth - 1) / nth;
  int64_t start = duty * ith;
  int64_t end = start + duty;
- if (RM == 4 && RN == 4) {
- kernel = &tinyBLAS_PPC::KERNEL_4x4;
- } else if (RM == 4 && RN == 8) {
- kernel = &tinyBLAS_PPC::KERNEL_4x8;
- } else if (RM == 8 && RN == 4) {
- kernel = &tinyBLAS_PPC::KERNEL_8x4;
- } else if (RM == 8 && RN == 8) {
- kernel = &tinyBLAS_PPC::KERNEL_8x8;
- }
  if (end > tiles)
  end = tiles;
  for (int64_t job = start; job < end; ++job) {
  int64_t ii = m0 + job / xtiles * RM;
  int64_t jj = n0 + job % xtiles * RN;
- (this->*kernel)(ii, jj);
+ kernel<RM, RN>(ii, jj);
  }
  }

- const float *const A;
- const float *const B;
- float *C;
+ const float * const A;
+ const float * const B;
+ float * C;
  const int64_t k;
  const int64_t lda;
  const int64_t ldb;
@@ -9003,8 +9003,7 @@ static void ggml_compute_forward_ssm_scan_f32(
  GGML_ASSERT(src4->nb[0] == sizeof(float));
  GGML_ASSERT(src5->nb[0] == sizeof(float));
  GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
- // allows optimizing the modulo since n_group should be a power of 2
- GGML_ASSERT((ng & -ng) == ng);
+ GGML_ASSERT(nh % ng == 0);

  // heads per thread
  const int dh = (nh + nth - 1)/nth;
@@ -9035,6 +9034,7 @@ static void ggml_compute_forward_ssm_scan_f32(
  // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
  const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
  const float dA = expf(dt_soft_plus * A[h]);
+ const int g = h / (nh / ng); // repeat_interleave

  // dim
  for (int i1 = 0; i1 < nr; ++i1) {
@@ -9057,8 +9057,8 @@ static void ggml_compute_forward_ssm_scan_f32(
  // TODO: maybe unroll more?
  for (int j = 0; j < 1; j++) {
  GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
- GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
- GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+ GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
+ GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);

  t0 = GGML_F32_VEC_MUL(t0, adA);
  t1 = GGML_F32_VEC_MUL(t1, axdt);
@@ -9072,6 +9072,9 @@ static void ggml_compute_forward_ssm_scan_f32(
  }

  sumf = GGML_F32xt_REDUCE_ONE(sum);
+ #elif defined(__riscv_v_intrinsic)
+ // todo: RVV implementation
+ const int np = 0;
  #else
  const int np = (nc & ~(GGML_F32_STEP - 1));

@@ -9087,8 +9090,8 @@ static void ggml_compute_forward_ssm_scan_f32(
  for (int i = 0; i < np; i += GGML_F32_STEP) {
  for (int j = 0; j < GGML_F32_ARR; j++) {
  ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
- ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
- az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+ ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
+ az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);

  ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
  ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
@@ -9110,7 +9113,7 @@ static void ggml_compute_forward_ssm_scan_f32(
  // d_state
  for (int i0 = np; i0 < nc; ++i0) {
  const int i = i0 + ii*nc;
- const int ig = i0 + (h & (ng - 1))*nc;
+ const int ig = i0 + g*nc;
  // state = prev_state * dA + dB * x
  const float state = (s0[i] * dA) + (B[ig] * x_dt);
  // y = rowwise_dotprod(state, C)
@@ -9127,6 +9130,7 @@ static void ggml_compute_forward_ssm_scan_f32(
  for (int h = ih0; h < ih1; ++h) {
  // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
  const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+ const int g = h / (nh / ng); // repeat_interleave

  // dim
  for (int i1 = 0; i1 < nr; ++i1) {
@@ -9141,8 +9145,8 @@ static void ggml_compute_forward_ssm_scan_f32(
  // TODO: what happens when (d_state % svcntw()) != 0?
  for (int64_t k = 0; k < nc; k += svcntw()) {
  svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
- svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
- svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+ svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
+ svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
  svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);

  svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
@@ -9162,7 +9166,7 @@ static void ggml_compute_forward_ssm_scan_f32(
  // d_state
  for (int i0 = 0; i0 < nc; ++i0) {
  const int i = i0 + ii*nc;
- const int ig = i0 + (h & (ng - 1))*nc;
+ const int ig = i0 + g*nc;
  // state = prev_state * dA + dB * x
  const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
  // y = rowwise_dotprod(state, C)
@@ -10023,8 +10027,8 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
  int64_t h_stride_2d = head_size * head_size;

  #if defined(GGML_SIMD)
- #if defined(__ARM_FEATURE_SVE)
- // scalar Route to scalar implementation //TODO: Write SVE code
+ #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
+ // scalar Route to scalar implementation //TODO: Write SVE code and RVV code
  for (int64_t t = 0; t < T; t++) {
  int64_t t_offset = t * t_stride;
  int64_t state_offset = head_size * C * (t / (T / n_seqs));