llama_cpp 0.14.6 → 0.15.0

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
@@ -65,22 +65,6 @@
65
65
  #define VECTOR_REGISTERS 16
66
66
  #endif
67
67
 
68
- // there will be blocks
69
- #define BEGIN_KERNEL(RM, RN) \
70
- int ytiles = (m - m0) / RM; \
71
- int xtiles = (n - n0) / RN; \
72
- int tiles = ytiles * xtiles; \
73
- int duty = (tiles + nth - 1) / nth; \
74
- int start = duty * ith; \
75
- int end = start + duty; \
76
- if (end > tiles) \
77
- end = tiles; \
78
- for (int job = start; job < end; ++job) { \
79
- int i = m0 + job / xtiles * RM; \
80
- int j = n0 + job % xtiles * RN;
81
-
82
- #define END_KERNEL() }
83
-
84
68
  #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
85
69
 
86
70
  namespace {
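The removed BEGIN_KERNEL/END_KERNEL macros carried the tile-partitioning arithmetic that the templated gemm() further down in this diff now inlines: the ytiles * xtiles output tiles are split into equal "duty" slices, one per thread. A minimal standalone sketch of that partitioning (names are illustrative, not from the source):

```cpp
#include <algorithm>
#include <cstdint>
#include <utility>

// Returns the half-open job range [start, end) that thread `ith` of `nth`
// should process out of `tiles` total tile jobs, mirroring the arithmetic
// in BEGIN_KERNEL and in the new templated gemm().
static std::pair<int64_t, int64_t> thread_slice(int64_t tiles, int ith, int nth) {
    int64_t duty  = (tiles + nth - 1) / nth;     // ceil(tiles / nth)
    int64_t start = duty * ith;
    int64_t end   = std::min(start + duty, tiles);
    return {start, end};
}
// Example: tiles = 10, nth = 4 gives duty = 3 and slices [0,3) [3,6) [6,9) [9,10).
```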
@@ -122,6 +106,45 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
122
106
  inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
123
107
  #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
124
108
 
109
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
110
+ // VECTORIZED FUSED MULTIPLY ADD
111
+
112
+ /**
113
+ * Computes a * b + c.
114
+ */
115
+ template <typename T, typename U>
116
+ inline U madd(T a, T b, U c) {
117
+ return add(mul(a, b), c);
118
+ }
119
+
120
+ #if defined(__FMA__)
121
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
122
+ template <>
123
+ inline __m256 madd(__m256 a, __m256 b, __m256 c) {
124
+ return _mm256_fmadd_ps(a, b, c);
125
+ }
126
+ #endif
127
+ #if defined(__AVX512F__)
128
+ template <>
129
+ inline __m512 madd(__m512 a, __m512 b, __m512 c) {
130
+ return _mm512_fmadd_ps(a, b, c);
131
+ }
132
+ #endif
133
+ #endif
134
+
135
+ #if defined(__ARM_FEATURE_FMA)
136
+ template <>
137
+ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
138
+ return vfmaq_f32(c, b, a);
139
+ }
140
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
141
+ template <>
142
+ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
143
+ return vfmaq_f16(c, b, a);
144
+ }
145
+ #endif
146
+ #endif
147
+
125
148
  ////////////////////////////////////////////////////////////////////////////////////////////////////
126
149
  // VECTORIZED HORIZONTAL SUM
127
150
 
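The block added here moves madd() ahead of the horizontal-sum helpers and specializes it to the hardware fused-multiply-add intrinsics where available. As a scalar reference (illustrative, not from the source), a fused multiply-add performs the multiply and the add with a single rounding, which is what _mm256_fmadd_ps / vfmaq_f32 do lane-wise:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    float a = 1.0000001f, b = 3.0f, c = -3.0f;
    float unfused = a * b + c;           // typically two roundings (may be contracted by the compiler)
    float fused   = std::fmaf(a, b, c);  // single rounding, the scalar analogue of vfmadd/vfmaq
    std::printf("unfused = %.9g, fused = %.9g\n", unfused, fused);
    return 0;
}
```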
@@ -213,287 +236,210 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
213
236
  }
214
237
  #endif // __AVX512F__
215
238
 
216
- ////////////////////////////////////////////////////////////////////////////////////////////////////
217
- // ABSTRACTIONS
218
-
219
- /**
220
- * Computes a * b + c.
221
- *
222
- * This operation will become fused into a single arithmetic instruction
223
- * if the hardware has support for this feature, e.g. Intel Haswell+ (c.
224
- * 2013), AMD Bulldozer+ (c. 2011), etc.
225
- */
226
- template <typename T, typename U>
227
- inline U madd(T a, T b, U c) {
228
- return add(mul(a, b), c);
229
- }
230
-
231
- /**
232
- * Computes a * b + c with error correction.
233
- *
234
- * @see W. Kahan, "Further remarks on reducing truncation errors,"
235
- * Communications of the ACM, vol. 8, no. 1, p. 40, Jan. 1965,
236
- * doi: 10.1145/363707.363723.
237
- */
238
- template <typename T, typename U>
239
- inline U madder(T a, T b, U c, U *e) {
240
- U y = sub(mul(a, b), *e);
241
- U t = add(c, y);
242
- *e = sub(sub(t, c), y);
243
- return t;
244
- }
245
-
246
239
  ////////////////////////////////////////////////////////////////////////////////////////////////////
247
240
  // FLOATING POINT MATRIX MULTIPLICATION
248
241
 
249
242
  template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
250
243
  class tinyBLAS {
251
244
  public:
252
- tinyBLAS(int k,
253
- const TA *A, int lda,
254
- const TB *B, int ldb,
255
- TC *C, int ldc,
245
+ tinyBLAS(int64_t k,
246
+ const TA *A, int64_t lda,
247
+ const TB *B, int64_t ldb,
248
+ TC *C, int64_t ldc,
256
249
  int ith, int nth)
257
250
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
258
251
  }
259
252
 
260
- void matmul(int m, int n, int task) {
253
+ void matmul(int64_t m, int64_t n, int task) {
261
254
  if (task == GGML_TASK_TYPE_COMPUTE)
262
255
  mnpack(0, m, 0, n);
263
256
  }
264
257
 
265
258
  private:
266
- NOINLINE void mnpack(int m0, int m, int n0, int n) {
267
- int mc, nc, mp, np;
268
- if (m - m0 <= 0 || n - n0 <= 0)
269
- return;
270
- if (VECTOR_REGISTERS >= 32 && n - n0 >= 5 && m - m0 >= 5) {
259
+ NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
260
+ int64_t mc, nc, mp, np;
261
+ switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
262
+ #if VECTOR_REGISTERS == 32
263
+ case 0x55:
264
+ mc = 5;
265
+ nc = 5;
266
+ gemm<5, 5>(m0, m, n0, n);
267
+ break;
268
+ case 0x45:
269
+ mc = 4;
270
+ nc = 5;
271
+ gemm<4, 5>(m0, m, n0, n);
272
+ break;
273
+ case 0x54:
271
274
  mc = 5;
275
+ nc = 4;
276
+ gemm<5, 4>(m0, m, n0, n);
277
+ break;
278
+ case 0x44:
279
+ mc = 4;
280
+ nc = 4;
281
+ gemm<4, 4>(m0, m, n0, n);
282
+ break;
283
+ case 0x53:
284
+ mc = 5;
285
+ nc = 3;
286
+ gemm<5, 3>(m0, m, n0, n);
287
+ break;
288
+ case 0x35:
289
+ mc = 3;
272
290
  nc = 5;
273
- gemm5x5(m0, m, n0, n);
274
- } else if (n - n0 >= 4 && m - m0 >= 3) {
291
+ gemm<3, 5>(m0, m, n0, n);
292
+ break;
293
+ case 0x43:
294
+ mc = 4;
295
+ nc = 3;
296
+ gemm<4, 3>(m0, m, n0, n);
297
+ break;
298
+ #else
299
+ case 0x55:
300
+ case 0x54:
301
+ case 0x53:
302
+ case 0x45:
303
+ case 0x44:
304
+ case 0x43:
305
+ mc = 4;
306
+ nc = 3;
307
+ gemm<4, 3>(m0, m, n0, n);
308
+ break;
309
+ case 0x35:
310
+ #endif
311
+ case 0x34:
275
312
  mc = 3;
276
313
  nc = 4;
277
- gemm3x4(m0, m, n0, n);
278
- } else if (n - n0 >= 4) {
279
- mc = 1;
314
+ gemm<3, 4>(m0, m, n0, n);
315
+ break;
316
+ case 0x52:
317
+ mc = 5;
318
+ nc = 2;
319
+ gemm<5, 2>(m0, m, n0, n);
320
+ break;
321
+ case 0x33:
322
+ mc = 3;
323
+ nc = 3;
324
+ gemm<3, 3>(m0, m, n0, n);
325
+ break;
326
+ case 0x25:
327
+ mc = 2;
328
+ nc = 5;
329
+ gemm<2, 5>(m0, m, n0, n);
330
+ break;
331
+ case 0x42:
332
+ mc = 4;
333
+ nc = 2;
334
+ gemm<4, 2>(m0, m, n0, n);
335
+ break;
336
+ case 0x24:
337
+ mc = 2;
280
338
  nc = 4;
281
- gemm1x4(m0, m, n0, n);
282
- } else if (m - m0 >= 4) {
339
+ gemm<2, 4>(m0, m, n0, n);
340
+ break;
341
+ case 0x32:
342
+ mc = 3;
343
+ nc = 2;
344
+ gemm<3, 2>(m0, m, n0, n);
345
+ break;
346
+ case 0x23:
347
+ mc = 2;
348
+ nc = 3;
349
+ gemm<2, 3>(m0, m, n0, n);
350
+ break;
351
+ case 0x51:
352
+ mc = 5;
353
+ nc = 1;
354
+ gemm<5, 1>(m0, m, n0, n);
355
+ break;
356
+ case 0x41:
283
357
  mc = 4;
284
358
  nc = 1;
285
- gemm4x1(m0, m, n0, n);
286
- } else {
359
+ gemm<4, 1>(m0, m, n0, n);
360
+ break;
361
+ case 0x22:
362
+ mc = 2;
363
+ nc = 2;
364
+ gemm<2, 2>(m0, m, n0, n);
365
+ break;
366
+ case 0x15:
367
+ mc = 1;
368
+ nc = 5;
369
+ gemm<1, 5>(m0, m, n0, n);
370
+ break;
371
+ case 0x14:
287
372
  mc = 1;
373
+ nc = 4;
374
+ gemm<1, 4>(m0, m, n0, n);
375
+ break;
376
+ case 0x31:
377
+ mc = 3;
288
378
  nc = 1;
289
- gemm1x1(m0, m, n0, n);
379
+ gemm<3, 1>(m0, m, n0, n);
380
+ break;
381
+ case 0x13:
382
+ mc = 1;
383
+ nc = 3;
384
+ gemm<1, 3>(m0, m, n0, n);
385
+ break;
386
+ case 0x21:
387
+ mc = 2;
388
+ nc = 1;
389
+ gemm<2, 1>(m0, m, n0, n);
390
+ break;
391
+ case 0x12:
392
+ mc = 1;
393
+ nc = 2;
394
+ gemm<1, 2>(m0, m, n0, n);
395
+ break;
396
+ case 0x11:
397
+ mc = 1;
398
+ nc = 1;
399
+ gemm<1, 1>(m0, m, n0, n);
400
+ break;
401
+ default:
402
+ return;
290
403
  }
291
404
  mp = m0 + (m - m0) / mc * mc;
292
405
  np = n0 + (n - n0) / nc * nc;
293
406
  mnpack(mp, m, n0, np);
294
- mnpack(m0, mp, np, n);
295
- mnpack(mp, m, np, n);
296
- }
297
-
298
- NOINLINE void gemm5x5(int m0, int m, int n0, int n) {
299
- BEGIN_KERNEL(5, 5)
300
- D c00 = {0};
301
- D c01 = {0};
302
- D c02 = {0};
303
- D c03 = {0};
304
- D c04 = {0};
305
- D c10 = {0};
306
- D c11 = {0};
307
- D c12 = {0};
308
- D c13 = {0};
309
- D c14 = {0};
310
- D c20 = {0};
311
- D c21 = {0};
312
- D c22 = {0};
313
- D c23 = {0};
314
- D c24 = {0};
315
- D c30 = {0};
316
- D c31 = {0};
317
- D c32 = {0};
318
- D c33 = {0};
319
- D c34 = {0};
320
- D c40 = {0};
321
- D c41 = {0};
322
- D c42 = {0};
323
- D c43 = {0};
324
- D c44 = {0};
325
- for (int l = 0; l < k; l += KN) {
326
- V k0 = load<V>(B + ldb * (j + 0) + l);
327
- V k1 = load<V>(B + ldb * (j + 1) + l);
328
- V k2 = load<V>(B + ldb * (j + 2) + l);
329
- V k3 = load<V>(B + ldb * (j + 3) + l);
330
- V k4 = load<V>(B + ldb * (j + 4) + l);
331
- V a0 = load<V>(A + lda * (i + 0) + l);
332
- c00 = madd(a0, k0, c00);
333
- c01 = madd(a0, k1, c01);
334
- c02 = madd(a0, k2, c02);
335
- c03 = madd(a0, k3, c03);
336
- c04 = madd(a0, k4, c04);
337
- V a1 = load<V>(A + lda * (i + 1) + l);
338
- c10 = madd(a1, k0, c10);
339
- c11 = madd(a1, k1, c11);
340
- c12 = madd(a1, k2, c12);
341
- c13 = madd(a1, k3, c13);
342
- c14 = madd(a1, k4, c14);
343
- V a2 = load<V>(A + lda * (i + 2) + l);
344
- c20 = madd(a2, k0, c20);
345
- c21 = madd(a2, k1, c21);
346
- c22 = madd(a2, k2, c22);
347
- c23 = madd(a2, k3, c23);
348
- c24 = madd(a2, k4, c24);
349
- V a3 = load<V>(A + lda * (i + 3) + l);
350
- c30 = madd(a3, k0, c30);
351
- c31 = madd(a3, k1, c31);
352
- c32 = madd(a3, k2, c32);
353
- c33 = madd(a3, k3, c33);
354
- c34 = madd(a3, k4, c34);
355
- V a4 = load<V>(A + lda * (i + 4) + l);
356
- c40 = madd(a4, k0, c40);
357
- c41 = madd(a4, k1, c41);
358
- c42 = madd(a4, k2, c42);
359
- c43 = madd(a4, k3, c43);
360
- c44 = madd(a4, k4, c44);
361
- }
362
- C[ldc * (j + 0) + (i + 0)] = hsum(c00);
363
- C[ldc * (j + 0) + (i + 1)] = hsum(c10);
364
- C[ldc * (j + 0) + (i + 2)] = hsum(c20);
365
- C[ldc * (j + 0) + (i + 3)] = hsum(c30);
366
- C[ldc * (j + 0) + (i + 4)] = hsum(c40);
367
- C[ldc * (j + 1) + (i + 0)] = hsum(c01);
368
- C[ldc * (j + 1) + (i + 1)] = hsum(c11);
369
- C[ldc * (j + 1) + (i + 2)] = hsum(c21);
370
- C[ldc * (j + 1) + (i + 3)] = hsum(c31);
371
- C[ldc * (j + 1) + (i + 4)] = hsum(c41);
372
- C[ldc * (j + 2) + (i + 0)] = hsum(c02);
373
- C[ldc * (j + 2) + (i + 1)] = hsum(c12);
374
- C[ldc * (j + 2) + (i + 2)] = hsum(c22);
375
- C[ldc * (j + 2) + (i + 3)] = hsum(c32);
376
- C[ldc * (j + 2) + (i + 4)] = hsum(c42);
377
- C[ldc * (j + 3) + (i + 0)] = hsum(c03);
378
- C[ldc * (j + 3) + (i + 1)] = hsum(c13);
379
- C[ldc * (j + 3) + (i + 2)] = hsum(c23);
380
- C[ldc * (j + 3) + (i + 3)] = hsum(c33);
381
- C[ldc * (j + 3) + (i + 4)] = hsum(c43);
382
- C[ldc * (j + 4) + (i + 0)] = hsum(c04);
383
- C[ldc * (j + 4) + (i + 1)] = hsum(c14);
384
- C[ldc * (j + 4) + (i + 2)] = hsum(c24);
385
- C[ldc * (j + 4) + (i + 3)] = hsum(c34);
386
- C[ldc * (j + 4) + (i + 4)] = hsum(c44);
387
- END_KERNEL()
388
- }
389
-
390
- NOINLINE void gemm3x4(int m0, int m, int n0, int n) {
391
- BEGIN_KERNEL(3, 4)
392
- D c00 = {0};
393
- D c01 = {0};
394
- D c02 = {0};
395
- D c03 = {0};
396
- D c10 = {0};
397
- D c11 = {0};
398
- D c12 = {0};
399
- D c13 = {0};
400
- D c20 = {0};
401
- D c21 = {0};
402
- D c22 = {0};
403
- D c23 = {0};
404
- for (int l = 0; l < k; l += KN) {
405
- V k0 = load<V>(B + ldb * (j + 0) + l);
406
- V k1 = load<V>(B + ldb * (j + 1) + l);
407
- V k2 = load<V>(B + ldb * (j + 2) + l);
408
- V k3 = load<V>(B + ldb * (j + 3) + l);
409
- V a0 = load<V>(A + lda * (i + 0) + l);
410
- c00 = madd(a0, k0, c00);
411
- c01 = madd(a0, k1, c01);
412
- c02 = madd(a0, k2, c02);
413
- c03 = madd(a0, k3, c03);
414
- V a1 = load<V>(A + lda * (i + 1) + l);
415
- c10 = madd(a1, k0, c10);
416
- c11 = madd(a1, k1, c11);
417
- c12 = madd(a1, k2, c12);
418
- c13 = madd(a1, k3, c13);
419
- V a2 = load<V>(A + lda * (i + 2) + l);
420
- c20 = madd(a2, k0, c20);
421
- c21 = madd(a2, k1, c21);
422
- c22 = madd(a2, k2, c22);
423
- c23 = madd(a2, k3, c23);
424
- }
425
- C[ldc * (j + 0) + (i + 0)] = hsum(c00);
426
- C[ldc * (j + 0) + (i + 1)] = hsum(c10);
427
- C[ldc * (j + 0) + (i + 2)] = hsum(c20);
428
- C[ldc * (j + 1) + (i + 0)] = hsum(c01);
429
- C[ldc * (j + 1) + (i + 1)] = hsum(c11);
430
- C[ldc * (j + 1) + (i + 2)] = hsum(c21);
431
- C[ldc * (j + 2) + (i + 0)] = hsum(c02);
432
- C[ldc * (j + 2) + (i + 1)] = hsum(c12);
433
- C[ldc * (j + 2) + (i + 2)] = hsum(c22);
434
- C[ldc * (j + 3) + (i + 0)] = hsum(c03);
435
- C[ldc * (j + 3) + (i + 1)] = hsum(c13);
436
- C[ldc * (j + 3) + (i + 2)] = hsum(c23);
437
- END_KERNEL()
438
- }
439
-
440
- NOINLINE void gemm1x4(int m0, int m, int n0, int n) {
441
- BEGIN_KERNEL(1, 4)
442
- D c00 = {0}, e00 = {0};
443
- D c01 = {0}, e01 = {0};
444
- D c02 = {0}, e02 = {0};
445
- D c03 = {0}, e03 = {0};
446
- for (int l = 0; l < k; l += KN) {
447
- V a = load<V>(A + lda * (i + 0) + l);
448
- c00 = madder(a, load<V>(B + ldb * (j + 0) + l), c00, &e00);
449
- c01 = madder(a, load<V>(B + ldb * (j + 1) + l), c01, &e01);
450
- c02 = madder(a, load<V>(B + ldb * (j + 2) + l), c02, &e02);
451
- c03 = madder(a, load<V>(B + ldb * (j + 3) + l), c03, &e03);
452
- }
453
- C[ldc * (j + 0) + (i + 0)] = hsum(c00);
454
- C[ldc * (j + 1) + (i + 0)] = hsum(c01);
455
- C[ldc * (j + 2) + (i + 0)] = hsum(c02);
456
- C[ldc * (j + 3) + (i + 0)] = hsum(c03);
457
- END_KERNEL()
458
- }
459
-
460
- NOINLINE void gemm4x1(int m0, int m, int n0, int n) {
461
- BEGIN_KERNEL(4, 1)
462
- D c00 = {0}, e00 = {0};
463
- D c10 = {0}, e10 = {0};
464
- D c20 = {0}, e20 = {0};
465
- D c30 = {0}, e30 = {0};
466
- for (int l = 0; l < k; l += KN) {
467
- V b = load<V>(B + ldb * (j + 0) + l);
468
- c00 = madder(load<V>(A + lda * (i + 0) + l), b, c00, &e00);
469
- c10 = madder(load<V>(A + lda * (i + 1) + l), b, c10, &e10);
470
- c20 = madder(load<V>(A + lda * (i + 2) + l), b, c20, &e20);
471
- c30 = madder(load<V>(A + lda * (i + 3) + l), b, c30, &e30);
407
+ mnpack(m0, m, np, n);
408
+ }
409
+
410
+ template <int RM, int RN>
411
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
412
+ int64_t ytiles = (m - m0) / RM;
413
+ int64_t xtiles = (n - n0) / RN;
414
+ int64_t tiles = xtiles * ytiles;
415
+ int64_t duty = (tiles + nth - 1) / nth;
416
+ int64_t start = duty * ith;
417
+ int64_t end = start + duty;
418
+ if (end > tiles)
419
+ end = tiles;
420
+ for (int64_t job = start; job < end; ++job) {
421
+ int64_t ii = m0 + job / xtiles * RM;
422
+ int64_t jj = n0 + job % xtiles * RN;
423
+ D Cv[RN][RM] = {};
424
+ for (int64_t l = 0; l < k; l += KN)
425
+ for (int64_t j = 0; j < RN; ++j)
426
+ for (int64_t i = 0; i < RM; ++i)
427
+ Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
428
+ load<V>(B + ldb * (jj + j) + l),
429
+ Cv[j][i]);
430
+ for (int64_t j = 0; j < RN; ++j)
431
+ for (int64_t i = 0; i < RM; ++i)
432
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
472
433
  }
473
- C[ldc * (j + 0) + (i + 0)] = hsum(c00);
474
- C[ldc * (j + 0) + (i + 1)] = hsum(c10);
475
- C[ldc * (j + 0) + (i + 2)] = hsum(c20);
476
- C[ldc * (j + 0) + (i + 3)] = hsum(c30);
477
- END_KERNEL()
478
- }
479
-
480
- NOINLINE void gemm1x1(int m0, int m, int n0, int n) {
481
- BEGIN_KERNEL(1, 1)
482
- D c = {0}, e = {0};
483
- for (int l = 0; l < k; l += KN)
484
- c = madder(load<V>(A + lda * i + l),
485
- load<V>(B + ldb * j + l), c, &e);
486
- C[ldc * j + i] = hsum(c);
487
- END_KERNEL()
488
434
  }
489
435
 
490
436
  const TA *const A;
491
437
  const TB *const B;
492
438
  TC *const C;
493
- const int k;
494
- const int lda;
495
- const int ldb;
496
- const int ldc;
439
+ const int64_t k;
440
+ const int64_t lda;
441
+ const int64_t ldb;
442
+ const int64_t ldc;
497
443
  const int ith;
498
444
  const int nth;
499
445
  };
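The rewritten mnpack() replaces the if/else ladder and the per-shape gemmNxM methods with a single templated gemm<RM, RN>() selected by a switch on a packed key: the remaining row and column counts are clamped to the largest supported tile edge and packed into one byte, so case 0x43 selects gemm<4, 3>. The tail recursion also changed: the second call now spans m0..m, which covers the bottom-right corner and makes the old third recursive call unnecessary. A small sketch of the key computation (illustrative names, not from the source):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Packs the clamped remaining-rows/remaining-columns counts into one byte,
// matching the `(MIN(m - m0, 5) << 4) | MIN(n - n0, 5)` switch key above.
static int tile_key(int64_t rows_left, int64_t cols_left, int64_t max_edge) {
    int64_t rm = std::min(rows_left, max_edge);
    int64_t rn = std::min(cols_left, max_edge);
    return (int)((rm << 4) | rn);
}

int main() {
    std::printf("%#x\n", tile_key(7, 3, 5));    // 0x53 -> gemm<5, 3>
    std::printf("%#x\n", tile_key(2, 100, 5));  // 0x25 -> gemm<2, 5>
    std::printf("%#x\n", tile_key(0, 4, 5));    // 0x4  -> no case matches, mnpack() returns
    return 0;
}
```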
@@ -505,136 +451,113 @@ class tinyBLAS {
505
451
  template <typename TA>
506
452
  class tinyBLAS_Q0_ARM {
507
453
  public:
508
- tinyBLAS_Q0_ARM(int k,
509
- const TA *A, int lda,
510
- const block_q8_0 *B, int ldb,
511
- float *C, int ldc,
454
+ tinyBLAS_Q0_ARM(int64_t k,
455
+ const TA *A, int64_t lda,
456
+ const block_q8_0 *B, int64_t ldb,
457
+ float *C, int64_t ldc,
512
458
  int ith, int nth)
513
459
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
514
460
  }
515
461
 
516
- void matmul(int m, int n, int task) {
462
+ void matmul(int64_t m, int64_t n, int task) {
517
463
  if (task == GGML_TASK_TYPE_COMPUTE)
518
464
  mnpack(0, m, 0, n);
519
465
  }
520
466
 
521
467
  private:
522
- NOINLINE void mnpack(int m0, int m, int n0, int n) {
523
- int mc, nc, mp, np;
524
- if (m - m0 <= 0 || n - n0 <= 0)
525
- return;
526
- if (m - m0 >= 3 && n - n0 >= 3) {
468
+ NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
469
+ int64_t mc, nc, mp, np;
470
+ switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
471
+ case 0x33:
527
472
  mc = 3;
528
473
  nc = 3;
529
- gemm3x3(m0, m, n0, n);
530
- } else {
474
+ gemm<3, 3>(m0, m, n0, n);
475
+ break;
476
+ case 0x32:
477
+ mc = 3;
478
+ nc = 2;
479
+ gemm<3, 2>(m0, m, n0, n);
480
+ break;
481
+ case 0x23:
482
+ mc = 2;
483
+ nc = 3;
484
+ gemm<2, 3>(m0, m, n0, n);
485
+ break;
486
+ case 0x22:
487
+ mc = 2;
488
+ nc = 2;
489
+ gemm<2, 2>(m0, m, n0, n);
490
+ break;
491
+ case 0x31:
492
+ mc = 3;
493
+ nc = 1;
494
+ gemm<3, 1>(m0, m, n0, n);
495
+ break;
496
+ case 0x13:
497
+ mc = 1;
498
+ nc = 3;
499
+ gemm<1, 3>(m0, m, n0, n);
500
+ break;
501
+ case 0x21:
502
+ mc = 2;
503
+ nc = 1;
504
+ gemm<2, 1>(m0, m, n0, n);
505
+ break;
506
+ case 0x12:
507
+ mc = 1;
508
+ nc = 2;
509
+ gemm<1, 2>(m0, m, n0, n);
510
+ break;
511
+ case 0x11:
531
512
  mc = 1;
532
513
  nc = 1;
533
- gemm1x1(m0, m, n0, n);
514
+ gemm<1, 1>(m0, m, n0, n);
515
+ break;
516
+ default:
517
+ return;
534
518
  }
535
519
  mp = m0 + (m - m0) / mc * mc;
536
520
  np = n0 + (n - n0) / nc * nc;
537
521
  mnpack(mp, m, n0, np);
538
- mnpack(m0, mp, np, n);
539
- mnpack(mp, m, np, n);
540
- }
541
-
542
- NOINLINE void gemm3x3(int m0, int m, int n0, int n) {
543
- BEGIN_KERNEL(3, 3)
544
- int32x4_t zero = vdupq_n_s32(0);
545
- float32x4_t c00 = vdupq_n_f32(0.f);
546
- float32x4_t c01 = vdupq_n_f32(0.f);
547
- float32x4_t c02 = vdupq_n_f32(0.f);
548
- float32x4_t c10 = vdupq_n_f32(0.f);
549
- float32x4_t c11 = vdupq_n_f32(0.f);
550
- float32x4_t c12 = vdupq_n_f32(0.f);
551
- float32x4_t c20 = vdupq_n_f32(0.f);
552
- float32x4_t c21 = vdupq_n_f32(0.f);
553
- float32x4_t c22 = vdupq_n_f32(0.f);
554
- const TA *Ap0 = A + lda * (i + 0);
555
- const TA *Ap1 = A + lda * (i + 1);
556
- const TA *Ap2 = A + lda * (i + 2);
557
- const block_q8_0 *Bp0 = B + ldb * (j + 0);
558
- const block_q8_0 *Bp1 = B + ldb * (j + 1);
559
- const block_q8_0 *Bp2 = B + ldb * (j + 2);
560
- for (int l = 0; l < k; ++l) {
561
- c00 = vmlaq_n_f32(
562
- c00,
563
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp0 + l)),
564
- load_hi(Ap0 + l), load_hi(Bp0 + l))),
565
- unhalf(Ap0[l].d) * unhalf(Bp0[l].d));
566
- c01 = vmlaq_n_f32(
567
- c01,
568
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp1 + l)),
569
- load_hi(Ap0 + l), load_hi(Bp1 + l))),
570
- unhalf(Ap0[l].d) * unhalf(Bp1[l].d));
571
- c02 = vmlaq_n_f32(
572
- c02,
573
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp2 + l)),
574
- load_hi(Ap0 + l), load_hi(Bp2 + l))),
575
- unhalf(Ap0[l].d) * unhalf(Bp2[l].d));
576
- c10 = vmlaq_n_f32(
577
- c10,
578
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp0 + l)),
579
- load_hi(Ap1 + l), load_hi(Bp0 + l))),
580
- unhalf(Ap1[l].d) * unhalf(Bp0[l].d));
581
- c11 = vmlaq_n_f32(
582
- c11,
583
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp1 + l)),
584
- load_hi(Ap1 + l), load_hi(Bp1 + l))),
585
- unhalf(Ap1[l].d) * unhalf(Bp1[l].d));
586
- c12 = vmlaq_n_f32(
587
- c12,
588
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp2 + l)),
589
- load_hi(Ap1 + l), load_hi(Bp2 + l))),
590
- unhalf(Ap1[l].d) * unhalf(Bp2[l].d));
591
- c20 = vmlaq_n_f32(
592
- c20,
593
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp0 + l)),
594
- load_hi(Ap2 + l), load_hi(Bp0 + l))),
595
- unhalf(Ap2[l].d) * unhalf(Bp0[l].d));
596
- c21 = vmlaq_n_f32(
597
- c21,
598
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp1 + l)),
599
- load_hi(Ap2 + l), load_hi(Bp1 + l))),
600
- unhalf(Ap2[l].d) * unhalf(Bp1[l].d));
601
- c22 = vmlaq_n_f32(
602
- c22,
603
- vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp2 + l)),
604
- load_hi(Ap2 + l), load_hi(Bp2 + l))),
605
- unhalf(Ap2[l].d) * unhalf(Bp2[l].d));
606
- }
607
- C[ldc * (j + 0) + (i + 0)] = hsum(c00);
608
- C[ldc * (j + 0) + (i + 1)] = hsum(c10);
609
- C[ldc * (j + 0) + (i + 2)] = hsum(c20);
610
- C[ldc * (j + 1) + (i + 0)] = hsum(c01);
611
- C[ldc * (j + 1) + (i + 1)] = hsum(c11);
612
- C[ldc * (j + 1) + (i + 2)] = hsum(c21);
613
- C[ldc * (j + 2) + (i + 0)] = hsum(c02);
614
- C[ldc * (j + 2) + (i + 1)] = hsum(c12);
615
- C[ldc * (j + 2) + (i + 2)] = hsum(c22);
616
- END_KERNEL()
617
- }
618
-
619
- NOINLINE void gemm1x1(int m0, int m, int n0, int n) {
620
- BEGIN_KERNEL(1, 1)
621
- float32x4_t acc = vdupq_n_f32(0.f);
622
- const TA *Ap = A + lda * i;
623
- const block_q8_0 *Bp = B + ldb * j;
624
- for (int l = 0; l < k; ++l) {
625
- acc = vmlaq_n_f32(acc,
626
- vcvtq_f32_s32(vdotq_s32(
627
- vdotq_s32(vdupq_n_s32(0), load_lo(Ap + l), load_lo(Bp + l)),
628
- load_hi(Ap + l), load_hi(Bp + l))),
629
- unhalf(Ap[l].d) * unhalf(Bp[l].d));
522
+ mnpack(m0, m, np, n);
523
+ }
524
+
525
+ template <int RM, int RN>
526
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
527
+ int64_t ytiles = (m - m0) / RM;
528
+ int64_t xtiles = (n - n0) / RN;
529
+ int64_t tiles = xtiles * ytiles;
530
+ int64_t duty = (tiles + nth - 1) / nth;
531
+ int64_t start = duty * ith;
532
+ int64_t end = start + duty;
533
+ if (end > tiles)
534
+ end = tiles;
535
+ for (int64_t job = start; job < end; ++job) {
536
+ int64_t ii = m0 + job / xtiles * RM;
537
+ int64_t jj = n0 + job % xtiles * RN;
538
+ float32x4_t Cv[RN][RM] = {};
539
+ for (int64_t l = 0; l < k; ++l)
540
+ for (int64_t j = 0; j < RN; ++j)
541
+ for (int64_t i = 0; i < RM; ++i)
542
+ Cv[j][i] = vmlaq_n_f32(Cv[j][i],
543
+ vcvtq_f32_s32(vdotq_s32(
544
+ vdotq_s32(vdupq_n_s32(0),
545
+ load_lo(A + lda * (ii + i) + l),
546
+ load_lo(B + ldb * (jj + j) + l)),
547
+ load_hi(A + lda * (ii + i) + l),
548
+ load_hi(B + ldb * (jj + j) + l))),
549
+ unhalf(A[lda * (ii + i) + l].d) *
550
+ unhalf(B[ldb * (jj + j) + l].d));
551
+ for (int64_t j = 0; j < RN; ++j)
552
+ for (int64_t i = 0; i < RM; ++i)
553
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
630
554
  }
631
- C[ldc * j + i] = hsum(acc);
632
- END_KERNEL()
633
555
  }
634
556
 
635
557
  inline int8x16_t load_lo(const block_q8_0 *b) {
636
558
  return vld1q_s8(b->qs);
637
559
  }
560
+
638
561
  inline int8x16_t load_hi(const block_q8_0 *b) {
639
562
  return vld1q_s8(b->qs + 16);
640
563
  }
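For the quantized ARM path above, each block pairs a per-block scale with 32 quantized values (int8 for Q8_0, 4-bit for Q4_0), loaded in two halves by load_lo/load_hi, and a block pair contributes scale_A * scale_B * Σ qa_i * qb_i to the accumulator. A scalar reference of that per-block math (illustrative struct, not the library's block_q8_0 layout):

```cpp
#include <cstdint>

// Hypothetical stand-in for a Q8_0 block: a per-block scale plus 32 int8 quants.
struct q8_block_ref {
    float  d;        // scale (stored as fp16 in the real block)
    int8_t qs[32];   // quantized values
};

// Scalar equivalent of what the vdotq_s32 + vmlaq_n_f32 sequence accumulates
// for one pair of blocks: d_a * d_b * dot(qs_a, qs_b).
static float block_dot(const q8_block_ref &a, const q8_block_ref &b) {
    int32_t sum = 0;
    for (int i = 0; i < 32; ++i)
        sum += (int32_t)a.qs[i] * (int32_t)b.qs[i];
    return a.d * b.d * (float)sum;
}
```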
@@ -644,6 +567,7 @@ class tinyBLAS_Q0_ARM {
644
567
  vdupq_n_u8(0x0f))),
645
568
  vdupq_n_s8(0x8));
646
569
  }
570
+
647
571
  inline int8x16_t load_hi(const block_q4_0 *b) {
648
572
  return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
649
573
  vdupq_n_s8(0x8));
@@ -652,10 +576,10 @@ class tinyBLAS_Q0_ARM {
652
576
  const TA *const A;
653
577
  const block_q8_0 *const B;
654
578
  float *const C;
655
- const int k;
656
- const int lda;
657
- const int ldb;
658
- const int ldc;
579
+ const int64_t k;
580
+ const int64_t lda;
581
+ const int64_t ldb;
582
+ const int64_t ldc;
659
583
  const int ith;
660
584
  const int nth;
661
585
  };
@@ -665,231 +589,157 @@ class tinyBLAS_Q0_ARM {
665
589
  template <typename TA, typename TB, typename TC>
666
590
  class tinyBLAS_Q0_AVX2 {
667
591
  public:
668
- tinyBLAS_Q0_AVX2(int k,
669
- const TA *A, int lda,
670
- const TB *B, int ldb,
671
- TC *C, int ldc,
592
+ tinyBLAS_Q0_AVX2(int64_t k,
593
+ const TA *A, int64_t lda,
594
+ const TB *B, int64_t ldb,
595
+ TC *C, int64_t ldc,
672
596
  int ith, int nth)
673
597
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
674
598
  }
675
599
 
676
- void matmul(int m, int n, int task) {
600
+ void matmul(int64_t m, int64_t n, int task) {
677
601
  if (task == GGML_TASK_TYPE_COMPUTE)
678
602
  mnpack(0, m, 0, n);
679
603
  }
680
604
 
681
605
  private:
682
- NOINLINE void mnpack(int m0, int m, int n0, int n) {
683
- int mc, nc, mp, np;
684
- if (m - m0 <= 0 || n - n0 <= 0)
685
- return;
686
- if (m - m0 >= 4 && n - n0 >= 3) {
606
+ void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
607
+ int64_t mc, nc, mp, np;
608
+ switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
609
+ #if VECTOR_REGISTERS == 32
610
+ case 0x44:
611
+ mc = 4;
612
+ nc = 4;
613
+ gemm<4, 4>(m0, m, n0, n);
614
+ break;
615
+ case 0x43:
616
+ mc = 4;
617
+ nc = 3;
618
+ gemm<4, 3>(m0, m, n0, n);
619
+ break;
620
+ case 0x34:
621
+ mc = 3;
622
+ nc = 4;
623
+ gemm<3, 4>(m0, m, n0, n);
624
+ break;
625
+ case 0x33:
626
+ mc = 3;
627
+ nc = 3;
628
+ gemm<3, 3>(m0, m, n0, n);
629
+ break;
630
+ case 0x42:
631
+ mc = 4;
632
+ nc = 2;
633
+ gemm<4, 2>(m0, m, n0, n);
634
+ break;
635
+ case 0x24:
636
+ mc = 2;
637
+ nc = 4;
638
+ gemm<2, 4>(m0, m, n0, n);
639
+ break;
640
+ #else
641
+ case 0x44:
642
+ case 0x43:
643
+ case 0x42:
687
644
  mc = 4;
645
+ nc = 2;
646
+ gemm<4, 2>(m0, m, n0, n);
647
+ break;
648
+ case 0x34:
649
+ case 0x24:
650
+ mc = 2;
651
+ nc = 4;
652
+ gemm<2, 4>(m0, m, n0, n);
653
+ break;
654
+ case 0x33:
655
+ #endif
656
+ case 0x32:
657
+ mc = 3;
658
+ nc = 2;
659
+ gemm<3, 2>(m0, m, n0, n);
660
+ break;
661
+ case 0x23:
662
+ mc = 2;
688
663
  nc = 3;
689
- gemm4x3(m0, m, n0, n);
690
- } else if (m - m0 >= 4 && n - n0 >= 1) {
664
+ gemm<2, 3>(m0, m, n0, n);
665
+ break;
666
+ case 0x41:
691
667
  mc = 4;
692
668
  nc = 1;
693
- gemm4x1(m0, m, n0, n);
694
- } else if (m - m0 >= 1 && n - n0 >= 4) {
669
+ gemm<4, 1>(m0, m, n0, n);
670
+ break;
671
+ case 0x22:
672
+ mc = 2;
673
+ nc = 2;
674
+ gemm<2, 2>(m0, m, n0, n);
675
+ break;
676
+ case 0x14:
695
677
  mc = 1;
696
678
  nc = 4;
697
- gemm1x4(m0, m, n0, n);
698
- } else {
679
+ gemm<1, 4>(m0, m, n0, n);
680
+ break;
681
+ case 0x31:
682
+ mc = 3;
683
+ nc = 1;
684
+ gemm<3, 1>(m0, m, n0, n);
685
+ break;
686
+ case 0x13:
687
+ mc = 1;
688
+ nc = 3;
689
+ gemm<1, 3>(m0, m, n0, n);
690
+ break;
691
+ case 0x21:
692
+ mc = 2;
693
+ nc = 1;
694
+ gemm<2, 1>(m0, m, n0, n);
695
+ break;
696
+ case 0x12:
697
+ mc = 1;
698
+ nc = 2;
699
+ gemm<1, 2>(m0, m, n0, n);
700
+ break;
701
+ case 0x11:
699
702
  mc = 1;
700
703
  nc = 1;
701
- gemm1x1(m0, m, n0, n);
704
+ gemm<1, 1>(m0, m, n0, n);
705
+ break;
706
+ default:
707
+ return;
702
708
  }
703
709
  mp = m0 + (m - m0) / mc * mc;
704
710
  np = n0 + (n - n0) / nc * nc;
705
711
  mnpack(mp, m, n0, np);
706
- mnpack(m0, mp, np, n);
707
- mnpack(mp, m, np, n);
708
- }
709
-
710
- NOINLINE void gemm4x3(int m0, int m, int n0, int n) {
711
- BEGIN_KERNEL(4, 3)
712
- __m256 c00 = _mm256_setzero_ps();
713
- __m256 c10 = _mm256_setzero_ps();
714
- __m256 c20 = _mm256_setzero_ps();
715
- __m256 c30 = _mm256_setzero_ps();
716
- __m256 c01 = _mm256_setzero_ps();
717
- __m256 c11 = _mm256_setzero_ps();
718
- __m256 c21 = _mm256_setzero_ps();
719
- __m256 c31 = _mm256_setzero_ps();
720
- __m256 c02 = _mm256_setzero_ps();
721
- __m256 c12 = _mm256_setzero_ps();
722
- __m256 c22 = _mm256_setzero_ps();
723
- __m256 c32 = _mm256_setzero_ps();
724
- const TA *Ap0 = A + lda * (i + 0);
725
- const TA *Ap1 = A + lda * (i + 1);
726
- const TA *Ap2 = A + lda * (i + 2);
727
- const TA *Ap3 = A + lda * (i + 3);
728
- const TB *Bp0 = B + ldb * (j + 0);
729
- const TB *Bp1 = B + ldb * (j + 1);
730
- const TB *Bp2 = B + ldb * (j + 2);
731
- for (int l = 0; l < k; ++l) {
732
- float da0 = unhalf(Ap0[l].d);
733
- float da1 = unhalf(Ap1[l].d);
734
- float da2 = unhalf(Ap2[l].d);
735
- float da3 = unhalf(Ap3[l].d);
736
- __m256i e0 = load(Ap0 + l);
737
- __m256i e1 = load(Ap1 + l);
738
- __m256i e2 = load(Ap2 + l);
739
- __m256i e3 = load(Ap3 + l);
740
- float db0 = unhalf(Bp0[l].d);
741
- __m256 d00 = _mm256_set1_ps(da0 * db0);
742
- __m256 d10 = _mm256_set1_ps(da1 * db0);
743
- __m256 d20 = _mm256_set1_ps(da2 * db0);
744
- __m256 d30 = _mm256_set1_ps(da3 * db0);
745
- __m256i f0 = load(Bp0 + l);
746
- __m256i u0 = _mm256_sign_epi8(f0, f0);
747
- __m256i s00 = _mm256_sign_epi8(e0, f0);
748
- __m256i s10 = _mm256_sign_epi8(e1, f0);
749
- __m256i s20 = _mm256_sign_epi8(e2, f0);
750
- __m256i s30 = _mm256_sign_epi8(e3, f0);
751
- c00 = madd(d00, updot(u0, s00), c00);
752
- c10 = madd(d10, updot(u0, s10), c10);
753
- c20 = madd(d20, updot(u0, s20), c20);
754
- c30 = madd(d30, updot(u0, s30), c30);
755
- float db1 = unhalf(Bp1[l].d);
756
- __m256 d01 = _mm256_set1_ps(da0 * db1);
757
- __m256 d11 = _mm256_set1_ps(da1 * db1);
758
- __m256 d21 = _mm256_set1_ps(da2 * db1);
759
- __m256 d31 = _mm256_set1_ps(da3 * db1);
760
- __m256i f1 = load(Bp1 + l);
761
- __m256i u1 = _mm256_sign_epi8(f1, f1);
762
- __m256i s01 = _mm256_sign_epi8(e0, f1);
763
- __m256i s11 = _mm256_sign_epi8(e1, f1);
764
- __m256i s21 = _mm256_sign_epi8(e2, f1);
765
- __m256i s31 = _mm256_sign_epi8(e3, f1);
766
- c01 = madd(d01, updot(u1, s01), c01);
767
- c11 = madd(d11, updot(u1, s11), c11);
768
- c21 = madd(d21, updot(u1, s21), c21);
769
- c31 = madd(d31, updot(u1, s31), c31);
770
- float db2 = unhalf(Bp2[l].d);
771
- __m256 d02 = _mm256_set1_ps(da0 * db2);
772
- __m256 d12 = _mm256_set1_ps(da1 * db2);
773
- __m256 d22 = _mm256_set1_ps(da2 * db2);
774
- __m256 d32 = _mm256_set1_ps(da3 * db2);
775
- __m256i f2 = load(Bp2 + l);
776
- __m256i u2 = _mm256_sign_epi8(f2, f2);
777
- __m256i s02 = _mm256_sign_epi8(e0, f2);
778
- __m256i s12 = _mm256_sign_epi8(e1, f2);
779
- __m256i s22 = _mm256_sign_epi8(e2, f2);
780
- __m256i s32 = _mm256_sign_epi8(e3, f2);
781
- c02 = madd(d02, updot(u2, s02), c02);
782
- c12 = madd(d12, updot(u2, s12), c12);
783
- c22 = madd(d22, updot(u2, s22), c22);
784
- c32 = madd(d32, updot(u2, s32), c32);
785
- }
786
- C[ldc * (j + 0) + (i + 0)] = hsum(c00);
787
- C[ldc * (j + 0) + (i + 1)] = hsum(c10);
788
- C[ldc * (j + 0) + (i + 2)] = hsum(c20);
789
- C[ldc * (j + 0) + (i + 3)] = hsum(c30);
790
- C[ldc * (j + 1) + (i + 0)] = hsum(c01);
791
- C[ldc * (j + 1) + (i + 1)] = hsum(c11);
792
- C[ldc * (j + 1) + (i + 2)] = hsum(c21);
793
- C[ldc * (j + 1) + (i + 3)] = hsum(c31);
794
- C[ldc * (j + 2) + (i + 0)] = hsum(c02);
795
- C[ldc * (j + 2) + (i + 1)] = hsum(c12);
796
- C[ldc * (j + 2) + (i + 2)] = hsum(c22);
797
- C[ldc * (j + 2) + (i + 3)] = hsum(c32);
798
- END_KERNEL()
799
- }
800
-
801
- NOINLINE void gemm4x1(int m0, int m, int n0, int n) {
802
- BEGIN_KERNEL(4, 1)
803
- __m256 c0 = _mm256_setzero_ps();
804
- __m256 c1 = _mm256_setzero_ps();
805
- __m256 c2 = _mm256_setzero_ps();
806
- __m256 c3 = _mm256_setzero_ps();
807
- const TA *Ap0 = A + lda * (i + 0);
808
- const TA *Ap1 = A + lda * (i + 1);
809
- const TA *Ap2 = A + lda * (i + 2);
810
- const TA *Ap3 = A + lda * (i + 3);
811
- const TB *Bp = B + ldb * j;
812
- for (int l = 0; l < k; ++l) {
813
- float db0 = unhalf(Bp[l].d);
814
- __m256i f = load(Bp + l);
815
- __m256i u = _mm256_sign_epi8(f, f);
816
- __m256 d0 = _mm256_set1_ps(unhalf(Ap0[l].d) * db0);
817
- __m256 d1 = _mm256_set1_ps(unhalf(Ap1[l].d) * db0);
818
- __m256 d2 = _mm256_set1_ps(unhalf(Ap2[l].d) * db0);
819
- __m256 d3 = _mm256_set1_ps(unhalf(Ap3[l].d) * db0);
820
- __m256i e0 = load(Ap0 + l);
821
- __m256i e1 = load(Ap1 + l);
822
- __m256i e2 = load(Ap2 + l);
823
- __m256i e3 = load(Ap3 + l);
824
- __m256i s0 = _mm256_sign_epi8(e0, f);
825
- __m256i s1 = _mm256_sign_epi8(e1, f);
826
- __m256i s2 = _mm256_sign_epi8(e2, f);
827
- __m256i s3 = _mm256_sign_epi8(e3, f);
828
- __m256 g0 = updot(u, s0);
829
- __m256 g1 = updot(u, s1);
830
- __m256 g2 = updot(u, s2);
831
- __m256 g3 = updot(u, s3);
832
- c0 = madd(d0, g0, c0);
833
- c1 = madd(d1, g1, c1);
834
- c2 = madd(d2, g2, c2);
835
- c3 = madd(d3, g3, c3);
836
- }
837
- C[ldc * j + (i + 0)] = hsum(c0);
838
- C[ldc * j + (i + 1)] = hsum(c1);
839
- C[ldc * j + (i + 2)] = hsum(c2);
840
- C[ldc * j + (i + 3)] = hsum(c3);
841
- END_KERNEL()
842
- }
843
-
844
- NOINLINE void gemm1x4(int m0, int m, int n0, int n) {
845
- BEGIN_KERNEL(1, 4)
846
- __m256 c0 = _mm256_setzero_ps();
847
- __m256 c1 = _mm256_setzero_ps();
848
- __m256 c2 = _mm256_setzero_ps();
849
- __m256 c3 = _mm256_setzero_ps();
850
- const TB *Bp0 = B + ldb * (j + 0);
851
- const TB *Bp1 = B + ldb * (j + 1);
852
- const TB *Bp2 = B + ldb * (j + 2);
853
- const TB *Bp3 = B + ldb * (j + 3);
854
- const TA *Ap = A + lda * i;
855
- for (int l = 0; l < k; ++l) {
856
- float da0 = unhalf(Ap[l].d);
857
- __m256i f = load(Ap + l);
858
- __m256i u = _mm256_sign_epi8(f, f);
859
- __m256 d0 = _mm256_set1_ps(unhalf(Bp0[l].d) * da0);
860
- __m256 d1 = _mm256_set1_ps(unhalf(Bp1[l].d) * da0);
861
- __m256 d2 = _mm256_set1_ps(unhalf(Bp2[l].d) * da0);
862
- __m256 d3 = _mm256_set1_ps(unhalf(Bp3[l].d) * da0);
863
- __m256 g0 = updot(u, _mm256_sign_epi8(load(Bp0 + l), f));
864
- __m256 g1 = updot(u, _mm256_sign_epi8(load(Bp1 + l), f));
865
- __m256 g2 = updot(u, _mm256_sign_epi8(load(Bp2 + l), f));
866
- __m256 g3 = updot(u, _mm256_sign_epi8(load(Bp3 + l), f));
867
- c0 = madd(d0, g0, c0);
868
- c1 = madd(d1, g1, c1);
869
- c2 = madd(d2, g2, c2);
870
- c3 = madd(d3, g3, c3);
871
- }
872
- C[ldc * (j + 0) + i] = hsum(c0);
873
- C[ldc * (j + 1) + i] = hsum(c1);
874
- C[ldc * (j + 2) + i] = hsum(c2);
875
- C[ldc * (j + 3) + i] = hsum(c3);
876
- END_KERNEL()
877
- }
878
-
879
- NOINLINE void gemm1x1(int m0, int m, int n0, int n) {
880
- BEGIN_KERNEL(1, 1)
881
- __m256 c = _mm256_setzero_ps();
882
- const TA *Ap = A + lda * i;
883
- const TB *Bp = B + ldb * j;
884
- for (int l = 0; l < k; ++l) {
885
- __m256 d = _mm256_set1_ps(unhalf(Ap[l].d) * unhalf(Bp[l].d));
886
- __m256i e = load(Ap + l);
887
- __m256i f = load(Bp + l);
888
- __m256 g = updot(_mm256_sign_epi8(e, e), _mm256_sign_epi8(f, e));
889
- c = madd(d, g, c);
712
+ mnpack(m0, m, np, n);
713
+ }
714
+
715
+ template <int RM, int RN>
716
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
717
+ int64_t ytiles = (m - m0) / RM;
718
+ int64_t xtiles = (n - n0) / RN;
719
+ int64_t tiles = xtiles * ytiles;
720
+ int64_t duty = (tiles + nth - 1) / nth;
721
+ int64_t start = duty * ith;
722
+ int64_t end = start + duty;
723
+ if (end > tiles)
724
+ end = tiles;
725
+ for (int64_t job = start; job < end; ++job) {
726
+ int64_t ii = m0 + job / xtiles * RM;
727
+ int64_t jj = n0 + job % xtiles * RN;
728
+ __m256 Cv[RN][RM] = {};
729
+ for (int64_t l = 0; l < k; ++l)
730
+ for (int64_t j = 0; j < RN; ++j)
731
+ for (int64_t i = 0; i < RM; ++i)
732
+ Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
733
+ unhalf(B[ldb * (jj + j) + l].d)),
734
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
735
+ load(A + lda * (ii + i) + l)),
736
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
737
+ load(A + lda * (ii + i) + l))),
738
+ Cv[j][i]);
739
+ for (int64_t j = 0; j < RN; ++j)
740
+ for (int64_t i = 0; i < RM; ++i)
741
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
890
742
  }
891
- C[ldc * j + i] = hsum(c);
892
- END_KERNEL()
893
743
  }
894
744
 
895
745
  inline __m256i load(const block_q8_0 *b) {
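In the AVX2 quantized kernel above, the two _mm256_sign_epi8 calls prepare an unsigned |a| vector and a b vector carrying a's signs before the updot() dot product (defined elsewhere in this file); the underlying x86 byte dot-product instructions take one unsigned and one signed operand, and the identity |a_i| * sign(a_i) * b_i = a_i * b_i keeps the result exact. A scalar sketch of that identity (illustrative only):

```cpp
#include <cstdint>
#include <cstdlib>

// Scalar model of the _mm256_sign_epi8 pairing: ua mirrors sign(a, a) (|a|,
// zeros preserved) and sb mirrors sign(b, a) (b with a's sign applied), so
// their product equals a_i * b_i even though ua is treated as unsigned.
static int32_t signed_dot_via_unsigned(const int8_t *a, const int8_t *b, int n) {
    int32_t sum = 0;
    for (int i = 0; i < n; ++i) {
        uint32_t ua = (uint32_t)std::abs((int)a[i]);
        int32_t  sb = a[i] > 0 ? b[i] : a[i] < 0 ? -b[i] : 0;
        sum += (int32_t)ua * sb;
    }
    return sum;
}
```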
@@ -911,19 +761,19 @@ class tinyBLAS_Q0_AVX2 {
911
761
  }
912
762
 
913
763
  static inline __m256i denibble(const uint8_t *p) {
914
- const __m128i tmp = _mm_loadu_si128((const __m128i *)p);
915
- const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
916
- const __m256i lowMask = _mm256_set1_epi8(15);
917
- return _mm256_and_si256(lowMask, bytes);
764
+ __m128i x = _mm_loadu_si128((const __m128i *)p);
765
+ return _mm256_and_si256(_mm256_set1_epi8(15),
766
+ _mm256_insertf128_si256(_mm256_castsi128_si256(x),
767
+ _mm_srli_epi16(x, 4), 1));
918
768
  }
919
769
 
920
770
  const TA *const A;
921
771
  const TB *const B;
922
772
  TC *const C;
923
- const int k;
924
- const int lda;
925
- const int ldb;
926
- const int ldc;
773
+ const int64_t k;
774
+ const int64_t lda;
775
+ const int64_t ldb;
776
+ const int64_t ldc;
927
777
  const int ith;
928
778
  const int nth;
929
779
  };
@@ -962,8 +812,8 @@ class tinyBLAS_Q0_AVX2 {
962
812
  * @param Ctype is GGML data type of `C`
963
813
  * @return true if this function was able to service the matmul request
964
814
  */
965
- bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, int ldb, void *C,
966
- int ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
815
+ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
816
+ int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
967
817
 
968
818
  assert(m >= 0);
969
819
  assert(n >= 0);
@@ -973,9 +823,6 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
973
823
  assert(ldc >= m);
974
824
  assert(nth > 0);
975
825
  assert(ith < nth);
976
- assert(1ll * lda * m <= 0x7fffffff);
977
- assert(1ll * ldb * n <= 0x7fffffff);
978
- assert(1ll * ldc * n <= 0x7fffffff);
979
826
 
980
827
  if (Ctype != GGML_TYPE_F32)
981
828
  return false;
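Finally, the llamafile_sgemm entry point and the class members now use int64_t for dimensions and strides, and the three asserts that capped lda*m, ldb*n and ldc*n at 0x7fffffff are dropped, which the 64-bit index arithmetic inside the kernels makes unnecessary. A minimal illustration of the limit the old asserts enforced (values are made up):

```cpp
#include <climits>
#include <cstdint>
#include <cstdio>

int main() {
    int64_t lda = 65536, m = 65536;   // made-up leading dimension and row count
    int64_t elems = lda * m;          // 4294967296, exact in 64-bit arithmetic
    std::printf("lda * m = %lld, fits in int32_t: %s\n",
                (long long)elems, elems <= INT32_MAX ? "yes" : "no");
    return 0;
}
```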