llama_cpp 0.14.6 → 0.15.0
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +22 -0
- data/README.md +2 -2
- data/ext/llama_cpp/extconf.rb +2 -1
- data/ext/llama_cpp/llama_cpp.cpp +90 -11
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +22 -3
- data/vendor/tmp/llama.cpp/Makefile +52 -22
- data/vendor/tmp/llama.cpp/ggml-alloc.c +8 -8
- data/vendor/tmp/llama.cpp/ggml-backend.c +21 -15
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +6 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +262 -4
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +7 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +376 -176
- data/vendor/tmp/llama.cpp/ggml-metal.metal +654 -18
- data/vendor/tmp/llama.cpp/ggml-quants.c +284 -293
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +17 -7
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +5 -0
- data/vendor/tmp/llama.cpp/ggml.c +394 -44
- data/vendor/tmp/llama.cpp/ggml.h +22 -0
- data/vendor/tmp/llama.cpp/llama.cpp +996 -455
- data/vendor/tmp/llama.cpp/llama.h +46 -15
- data/vendor/tmp/llama.cpp/sgemm.cpp +437 -590
- data/vendor/tmp/llama.cpp/sgemm.h +4 -2
- data/vendor/tmp/llama.cpp/unicode-data.cpp +1 -1
- data/vendor/tmp/llama.cpp/unicode-data.h +2 -2
- data/vendor/tmp/llama.cpp/unicode.cpp +448 -39
- data/vendor/tmp/llama.cpp/unicode.h +2 -1
- metadata +3 -3
data/vendor/tmp/llama.cpp/sgemm.cpp

@@ -65,22 +65,6 @@
 #define VECTOR_REGISTERS 16
 #endif

-// there will be blocks
-#define BEGIN_KERNEL(RM, RN) \
-    int ytiles = (m - m0) / RM; \
-    int xtiles = (n - n0) / RN; \
-    int tiles = ytiles * xtiles; \
-    int duty = (tiles + nth - 1) / nth; \
-    int start = duty * ith; \
-    int end = start + duty; \
-    if (end > tiles) \
-        end = tiles; \
-    for (int job = start; job < end; ++job) { \
-        int i = m0 + job / xtiles * RM; \
-        int j = n0 + job % xtiles * RN;
-
-#define END_KERNEL() }
-
 #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

 namespace {
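Both the removed BEGIN_KERNEL macro above and the templated gemm<RM, RN> kernels added later in this diff split the grid of RM×RN output tiles evenly across threads with the same duty/start/end arithmetic. A minimal standalone sketch of that partitioning (function and variable names here are illustrative, not part of the package):

```c++
#include <cstdint>
#include <cstdio>

// Illustrative only: how one thread (ith of nth) picks its share of RM x RN
// output tiles, mirroring the duty/start/end arithmetic in sgemm.cpp.
void show_tile_split(int64_t m, int64_t n, int RM, int RN, int ith, int nth) {
    int64_t ytiles = m / RM;                   // full tile rows
    int64_t xtiles = n / RN;                   // full tile columns
    int64_t tiles  = xtiles * ytiles;          // total jobs
    int64_t duty   = (tiles + nth - 1) / nth;  // jobs per thread, rounded up
    int64_t start  = duty * ith;
    int64_t end    = start + duty > tiles ? tiles : start + duty;
    for (int64_t job = start; job < end; ++job) {
        int64_t ii = job / xtiles * RM;        // top row of this tile
        int64_t jj = job % xtiles * RN;        // left column of this tile
        std::printf("thread %d -> C[%lld..%lld][%lld..%lld]\n", ith,
                    (long long)ii, (long long)(ii + RM - 1),
                    (long long)jj, (long long)(jj + RN - 1));
    }
}

int main() {
    for (int ith = 0; ith < 4; ++ith)
        show_tile_split(16, 16, 4, 4, ith, 4);  // 16 tiles spread over 4 threads
}
```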
@@ -122,6 +106,45 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
 inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

+////////////////////////////////////////////////////////////////////////////////////////////////////
+// VECTORIZED FUSED MULTIPLY ADD
+
+/**
+ * Computes a * b + c.
+ */
+template <typename T, typename U>
+inline U madd(T a, T b, U c) {
+    return add(mul(a, b), c);
+}
+
+#if defined(__FMA__)
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+template <>
+inline __m256 madd(__m256 a, __m256 b, __m256 c) {
+    return _mm256_fmadd_ps(a, b, c);
+}
+#endif
+#if defined(__AVX512F__)
+template <>
+inline __m512 madd(__m512 a, __m512 b, __m512 c) {
+    return _mm512_fmadd_ps(a, b, c);
+}
+#endif
+#endif
+
+#if defined(__ARM_FEATURE_FMA)
+template <>
+inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
+    return vfmaq_f32(c, b, a);
+}
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
+template <>
+inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
+    return vfmaq_f16(c, b, a);
+}
+#endif
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED HORIZONTAL SUM

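The generic madd above does the multiply and the add as two separate operations; the specializations map it onto a single fused instruction when FMA is enabled at compile time. A small illustrative snippet of the AVX path (not from the package, just the intrinsic the __m256 specialization wraps):

```c++
#include <immintrin.h>

// Illustrative: accumulate an 8-lane partial dot product. Built with -mavx2 -mfma,
// this is the same _mm256_fmadd_ps that the madd<__m256> specialization resolves to.
static inline __m256 fma_accumulate(__m256 acc, const float *a, const float *b) {
    __m256 va = _mm256_loadu_ps(a);       // 8 floats from a
    __m256 vb = _mm256_loadu_ps(b);       // 8 floats from b
    return _mm256_fmadd_ps(va, vb, acc);  // acc += va * vb with one rounding step
}
```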
@@ -213,287 +236,210 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
 }
 #endif // __AVX512F__

-////////////////////////////////////////////////////////////////////////////////////////////////////
-// ABSTRACTIONS
-
-/**
- * Computes a * b + c.
- *
- * This operation will become fused into a single arithmetic instruction
- * if the hardware has support for this feature, e.g. Intel Haswell+ (c.
- * 2013), AMD Bulldozer+ (c. 2011), etc.
- */
-template <typename T, typename U>
-inline U madd(T a, T b, U c) {
-    return add(mul(a, b), c);
-}
-
-/**
- * Computes a * b + c with error correction.
- *
- * @see W. Kahan, "Further remarks on reducing truncation errors,"
- *     Communications of the ACM, vol. 8, no. 1, p. 40, Jan. 1965,
- *     doi: 10.1145/363707.363723.
- */
-template <typename T, typename U>
-inline U madder(T a, T b, U c, U *e) {
-    U y = sub(mul(a, b), *e);
-    U t = add(c, y);
-    *e = sub(sub(t, c), y);
-    return t;
-}
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // FLOATING POINT MATRIX MULTIPLICATION

 template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
 class tinyBLAS {
   public:
-    tinyBLAS(…
-             const TA *A, …
-             const TB *B, …
-             TC *C, …
+    tinyBLAS(int64_t k,
+             const TA *A, int64_t lda,
+             const TB *B, int64_t ldb,
+             TC *C, int64_t ldc,
              int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }

-    void matmul(…
+    void matmul(int64_t m, int64_t n, int task) {
         if (task == GGML_TASK_TYPE_COMPUTE)
             mnpack(0, m, 0, n);
     }

   private:
-    NOINLINE void mnpack(…
-        […]
+    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
+#if VECTOR_REGISTERS == 32
+        case 0x55:
+            mc = 5;
+            nc = 5;
+            gemm<5, 5>(m0, m, n0, n);
+            break;
+        [… cases 0x45, 0x54, 0x44, 0x53, 0x35 and 0x43 follow the same pattern …]
+#else
+        case 0x55:
+        case 0x54:
+        case 0x53:
+        case 0x45:
+        case 0x44:
+        case 0x43:
+            mc = 4;
+            nc = 3;
+            gemm<4, 3>(m0, m, n0, n);
+            break;
+        case 0x35:
+#endif
+        case 0x34:
+            mc = 3;
+            nc = 4;
+            gemm<3, 4>(m0, m, n0, n);
+            break;
+        [… cases 0x52 down to 0x12 follow the same pattern, each setting mc/nc and calling gemm<mc, nc> …]
+        case 0x11:
+            mc = 1;
+            nc = 1;
+            gemm<1, 1>(m0, m, n0, n);
+            break;
+        default:
+            return;
         }
         mp = m0 + (m - m0) / mc * mc;
         np = n0 + (n - n0) / nc * nc;
         mnpack(mp, m, n0, np);
-        mnpack(m0, …
-    [… the old hand-unrolled kernels are deleted here: the 5x5 body plus gemm3x4, gemm1x4, gemm4x1 and gemm1x1, written against BEGIN_KERNEL/END_KERNEL with explicit c00…c44 accumulators, madd/madder and per-element hsum stores …]
+        mnpack(m0, m, np, n);
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            D Cv[RN][RM] = {};
+            for (int64_t l = 0; l < k; l += KN)
+                for (int64_t j = 0; j < RN; ++j)
+                    for (int64_t i = 0; i < RM; ++i)
+                        Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
+                                        load<V>(B + ldb * (jj + j) + l),
+                                        Cv[j][i]);
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < RM; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+        }
     }

     const TA *const A;
     const TB *const B;
     TC *const C;
-    const …
-    const …
-    const …
-    const …
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
     const int ith;
     const int nth;
 };
@@ -505,136 +451,113 @@ class tinyBLAS {
 template <typename TA>
 class tinyBLAS_Q0_ARM {
   public:
-    tinyBLAS_Q0_ARM(…
-                    const TA *A, …
-                    const block_q8_0 *B, …
-                    float *C, …
+    tinyBLAS_Q0_ARM(int64_t k,
+                    const TA *A, int64_t lda,
+                    const block_q8_0 *B, int64_t ldb,
+                    float *C, int64_t ldc,
                     int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }

-    void matmul(…
+    void matmul(int64_t m, int64_t n, int task) {
         if (task == GGML_TASK_TYPE_COMPUTE)
             mnpack(0, m, 0, n);
     }

   private:
-    NOINLINE void mnpack(…
-        if (m - m0 >= 3 && n - n0 >= 3) {
-        […]
+    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
+        case 0x33:
+            mc = 3;
+            nc = 3;
+            gemm<3, 3>(m0, m, n0, n);
+            break;
+        [… cases 0x32, 0x23, 0x22, 0x31, 0x13, 0x21 and 0x12 follow the same pattern …]
+        case 0x11:
+            mc = 1;
+            nc = 1;
+            gemm<1, 1>(m0, m, n0, n);
+            break;
+        default:
+            return;
         }
         mp = m0 + (m - m0) / mc * mc;
         np = n0 + (n - n0) / nc * nc;
         mnpack(mp, m, n0, np);
-        mnpack(m0, …
-    [… the old hand-unrolled 3x3 kernel and gemm1x1 are deleted here: explicit Ap/Bp block pointers, vmlaq_n_f32 over chained vdotq_s32 dot products, per-element hsum stores …]
+        mnpack(m0, m, np, n);
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            float32x4_t Cv[RN][RM] = {};
+            for (int64_t l = 0; l < k; ++l)
+                for (int64_t j = 0; j < RN; ++j)
+                    for (int64_t i = 0; i < RM; ++i)
+                        Cv[j][i] = vmlaq_n_f32(Cv[j][i],
+                                               vcvtq_f32_s32(vdotq_s32(
+                                                   vdotq_s32(vdupq_n_s32(0),
+                                                             load_lo(A + lda * (ii + i) + l),
+                                                             load_lo(B + ldb * (jj + j) + l)),
+                                                   load_hi(A + lda * (ii + i) + l),
+                                                   load_hi(B + ldb * (jj + j) + l))),
+                                               unhalf(A[lda * (ii + i) + l].d) *
+                                               unhalf(B[ldb * (jj + j) + l].d));
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < RM; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+        }
     }

     inline int8x16_t load_lo(const block_q8_0 *b) {
         return vld1q_s8(b->qs);
     }
+
     inline int8x16_t load_hi(const block_q8_0 *b) {
         return vld1q_s8(b->qs + 16);
     }
@@ -644,6 +567,7 @@ class tinyBLAS_Q0_ARM {
                                        vdupq_n_u8(0x0f))),
                          vdupq_n_s8(0x8));
     }
+
     inline int8x16_t load_hi(const block_q4_0 *b) {
         return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
                         vdupq_n_s8(0x8));
@@ -652,10 +576,10 @@ class tinyBLAS_Q0_ARM {
     const TA *const A;
     const block_q8_0 *const B;
     float *const C;
-    const …
-    const …
-    const …
-    const …
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
     const int ith;
     const int nth;
 };
@@ -665,231 +589,157 @@ class tinyBLAS_Q0_ARM {
 template <typename TA, typename TB, typename TC>
 class tinyBLAS_Q0_AVX2 {
   public:
-    tinyBLAS_Q0_AVX2(…
-                     const TA *A, …
-                     const TB *B, …
-                     TC *C, …
+    tinyBLAS_Q0_AVX2(int64_t k,
+                     const TA *A, int64_t lda,
+                     const TB *B, int64_t ldb,
+                     TC *C, int64_t ldc,
                      int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }

-    void matmul(…
+    void matmul(int64_t m, int64_t n, int task) {
         if (task == GGML_TASK_TYPE_COMPUTE)
             mnpack(0, m, 0, n);
     }

   private:
-    […]
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
+#if VECTOR_REGISTERS == 32
+        case 0x44:
+            mc = 4;
+            nc = 4;
+            gemm<4, 4>(m0, m, n0, n);
+            break;
+        [… cases 0x43, 0x34, 0x33, 0x42 and 0x24 follow the same pattern …]
+#else
+        case 0x44:
+        case 0x43:
+        case 0x42:
+            mc = 4;
+            nc = 2;
+            gemm<4, 2>(m0, m, n0, n);
+            break;
+        case 0x34:
+        case 0x24:
+            mc = 2;
+            nc = 4;
+            gemm<2, 4>(m0, m, n0, n);
+            break;
+        case 0x33:
+#endif
+        case 0x32:
+            mc = 3;
+            nc = 2;
+            gemm<3, 2>(m0, m, n0, n);
+            break;
+        [… cases 0x23 down to 0x12 follow the same pattern, each setting mc/nc and calling gemm<mc, nc> …]
+        case 0x11:
+            mc = 1;
+            nc = 1;
+            gemm<1, 1>(m0, m, n0, n);
+            break;
+        default:
+            return;
         }
         mp = m0 + (m - m0) / mc * mc;
         np = n0 + (n - n0) / nc * nc;
         mnpack(mp, m, n0, np);
-        mnpack(m0, …
-    [… the old hand-unrolled 4x3 kernel plus gemm4x1, gemm1x4 and gemm1x1 are deleted here: per-column _mm256_set1_ps scale products, _mm256_sign_epi8/updot chains and per-element hsum stores …]
+        mnpack(m0, m, np, n);
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            __m256 Cv[RN][RM] = {};
+            for (int64_t l = 0; l < k; ++l)
+                for (int64_t j = 0; j < RN; ++j)
+                    for (int64_t i = 0; i < RM; ++i)
+                        Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
+                                                       unhalf(B[ldb * (jj + j) + l].d)),
+                                        updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                               load(A + lda * (ii + i) + l)),
+                                              _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
+                                                               load(A + lda * (ii + i) + l))),
+                                        Cv[j][i]);
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < RM; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+        }
     }

     inline __m256i load(const block_q8_0 *b) {
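The AVX2 kernel feeds updot with `_mm256_sign_epi8(a, a)` and `_mm256_sign_epi8(b, a)`: the first operand becomes |a| (non-negative), the second carries a's sign over to b, which is the unsigned-by-signed operand shape the underlying maddubs/VNNI-style dot instructions expect, while the lane-wise products stay unchanged. The identity in scalar form (illustrative only):

```c++
#include <cstdint>
#include <cstdlib>

// Illustrative: why sign(a, a) and sign(b, a) preserve the product. For the int8
// ranges ggml quantization produces (|a|, |b| <= 127), u * s == a * b per lane.
int32_t signed_product_via_unsigned(int8_t a, int8_t b) {
    uint8_t u = (uint8_t)std::abs((int)a);               // like _mm256_sign_epi8(a, a)
    int8_t  s = (a > 0) ? b : (a < 0) ? (int8_t)-b : 0;  // like _mm256_sign_epi8(b, a)
    return (int32_t)u * (int32_t)s;                      // equals (int32_t)a * (int32_t)b
}
```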
@@ -911,19 +761,19 @@ class tinyBLAS_Q0_AVX2 {
     }

     static inline __m256i denibble(const uint8_t *p) {
-        […]
+        __m128i x = _mm_loadu_si128((const __m128i *)p);
+        return _mm256_and_si256(_mm256_set1_epi8(15),
+                                _mm256_insertf128_si256(_mm256_castsi128_si256(x),
+                                                        _mm_srli_epi16(x, 4), 1));
     }

     const TA *const A;
     const TB *const B;
     TC *const C;
-    const …
-    const …
-    const …
-    const …
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
     const int ith;
     const int nth;
 };
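The new denibble expands 16 packed bytes (32 four-bit quants) into 32 bytes: low nibbles land in the lower 128-bit lane, high nibbles in the upper lane. A scalar reference of the same unpacking (illustrative only, not part of the package):

```c++
#include <cstdint>

// Illustrative scalar equivalent of denibble(): 16 input bytes hold 32 nibbles;
// out[0..15] receive the low nibbles, out[16..31] the high nibbles.
void denibble_ref(const uint8_t *p, uint8_t out[32]) {
    for (int i = 0; i < 16; ++i) {
        out[i]      = p[i] & 15;   // low nibble  -> lower 128-bit lane
        out[16 + i] = p[i] >> 4;   // high nibble -> upper 128-bit lane
    }
}
```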
@@ -962,8 +812,8 @@ class tinyBLAS_Q0_AVX2 {
  * @param Ctype is GGML data type of `C`
  * @return true if this function was able to service the matmul request
  */
-bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
-                     …
+bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
+                     int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {

     assert(m >= 0);
     assert(n >= 0);
@@ -973,9 +823,6 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
     assert(ldc >= m);
     assert(nth > 0);
     assert(ith < nth);
-    assert(1ll * lda * m <= 0x7fffffff);
-    assert(1ll * ldb * n <= 0x7fffffff);
-    assert(1ll * ldc * n <= 0x7fffffff);

     if (Ctype != GGML_TYPE_F32)
         return false;
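With the int64_t signature above and the removed 31-bit overflow asserts, dimensions and strides are no longer confined to products that fit in an int. A hedged sketch of driving the function from worker threads (the exact operand layout and lda/ldb/ldc conventions follow the doc comment in sgemm.h, which is not reproduced in this excerpt; the thread-pool wrapper below is illustrative only):

```c++
#include <cstdint>
#include <thread>
#include <vector>
#include "ggml.h"    // GGML_TYPE_F32, GGML_TASK_TYPE_COMPUTE (vendored header)
#include "sgemm.h"   // llamafile_sgemm (vendored header)

// Sketch: each of nth workers calls llamafile_sgemm with its own ith; every
// call computes a disjoint share of C, so no synchronization is needed inside.
void sgemm_f32(int64_t m, int64_t n, int64_t k,
               const float *A, int64_t lda,
               const float *B, int64_t ldb,
               float *C, int64_t ldc, int nth) {
    std::vector<std::thread> pool;
    for (int ith = 0; ith < nth; ++ith)
        pool.emplace_back([=] {
            llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
                            ith, nth, GGML_TASK_TYPE_COMPUTE,
                            GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
        });
    for (auto &t : pool) t.join();
}
```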