llama_cpp 0.14.5 → 0.14.6

@@ -0,0 +1,1148 @@
1
+ // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
2
+ // vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
3
+ //
4
+ // Copyright 2024 Mozilla Foundation
5
+ //
6
+ // Permission is hereby granted, free of charge, to any person obtaining
7
+ // a copy of this software and associated documentation files (the
8
+ // "Software"), to deal in the Software without restriction, including
9
+ // without limitation the rights to use, copy, modify, merge, publish,
10
+ // distribute, sublicense, and/or sell copies of the Software, and to
11
+ // permit persons to whom the Software is furnished to do so, subject to
12
+ // the following conditions:
13
+ //
14
+ // The above copyright notice and this permission notice shall be
15
+ // included in all copies or substantial portions of the Software.
16
+ //
17
+ // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18
+ // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19
+ // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20
+ // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
21
+ // BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
22
+ // ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
23
+ // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ // SOFTWARE.
25
+
26
+ //
27
+ // _ _ ___ _ _ ___
28
+ // | |_(_)_ _ _ _| _ ) | /_\ / __|
29
+ // | _| | ' \ || | _ \ |__ / _ \\__ \.
30
+ // \__|_|_||_\_, |___/____/_/ \_\___/
31
+ // |__/
32
+ //
33
+ // BASIC LINEAR ALGEBRA SUBPROGRAMS
34
+ //
35
+ //
36
+ // This file implements multithreaded CPU matrix multiplication for the
37
+ // common contiguous use case C = Aᵀ * B. These kernels are designed to
38
+ // have excellent performance[1] for matrices that fit in the CPU cache
39
+ // without imposing any overhead such as cache filling or malloc calls.
40
+ //
41
+ // This implementation does not guarantee any upper bound on rounding
42
+ // error, which grows with k. Our goal is to maximally exploit the
43
+ // hardware for performance, and then use whatever resources remain to
44
+ // improve numerical accuracy.
45
+ //
46
+ // [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
47
+ // Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
48
+
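+ // For reference, the contraction computed by the floating-point kernels
+ // below is, in minimal scalar form (index names are illustrative; the
+ // quantized kernels do the same with per-block scaling, and lda, ldb,
+ // ldc are the row strides documented at llamafile_sgemm at the bottom
+ // of this file):
+ //
+ //     for (int i = 0; i < m; ++i)
+ //         for (int j = 0; j < n; ++j) {
+ //             float sum = 0;
+ //             for (int l = 0; l < k; ++l)
+ //                 sum += A[lda * i + l] * B[ldb * j + l];
+ //             C[ldc * j + i] = sum;
+ //         }
+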
49
+ #pragma GCC diagnostic ignored "-Wpedantic"
50
+ #pragma GCC diagnostic ignored "-Wignored-attributes"
51
+
52
+ #include "sgemm.h"
53
+ #include "ggml-impl.h"
54
+ #include "ggml-quants.h"
55
+
56
+ #ifdef _MSC_VER
57
+ #define NOINLINE __declspec(noinline)
58
+ #else
59
+ #define NOINLINE __attribute__((__noinline__))
60
+ #endif
61
+
62
+ #if defined(__ARM_NEON) || defined(__AVX512F__)
63
+ #define VECTOR_REGISTERS 32
64
+ #else
65
+ #define VECTOR_REGISTERS 16
66
+ #endif
67
+
68
+ // there will be blocks
69
+ #define BEGIN_KERNEL(RM, RN) \
70
+ int ytiles = (m - m0) / RM; \
71
+ int xtiles = (n - n0) / RN; \
72
+ int tiles = ytiles * xtiles; \
73
+ int duty = (tiles + nth - 1) / nth; \
74
+ int start = duty * ith; \
75
+ int end = start + duty; \
76
+ if (end > tiles) \
77
+ end = tiles; \
78
+ for (int job = start; job < end; ++job) { \
79
+ int i = m0 + job / xtiles * RM; \
80
+ int j = n0 + job % xtiles * RN;
81
+
82
+ #define END_KERNEL() }
83
+
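+ // BEGIN_KERNEL statically partitions the RM×RN output tiles across nth
+ // threads. A worked example (numbers are illustrative): with m - m0 = 15,
+ // n - n0 = 20, RM = RN = 5 and nth = 8, we get ytiles = 3, xtiles = 4,
+ // tiles = 12 and duty = ceil(12 / 8) = 2, so thread ith = 3 runs jobs 6
+ // and 7, i.e. the 5x5 tiles whose top-left corners are (i, j) = (5, 10)
+ // and (5, 15); threads 6 and 7 receive no tiles because start >= tiles.
+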
84
+ #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
85
+
86
+ namespace {
87
+
88
+ inline float unhalf(ggml_fp16_t d) {
89
+ return GGML_FP16_TO_FP32(d);
90
+ }
91
+
92
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
93
+ // VECTORIZED ARITHMETIC OPERATIONS
94
+
95
+ #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
96
+ inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
97
+ inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
98
+ inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
99
+ #endif // __SSE__
100
+
101
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
102
+ inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
103
+ inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
104
+ inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
105
+ #endif // __AVX__
106
+
107
+ #if defined(__AVX512F__)
108
+ inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
109
+ inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
110
+ inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
111
+ #endif // __AVX512F__
112
+
113
+ #if defined(__ARM_NEON)
114
+ inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
115
+ inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
116
+ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
117
+ #endif // __ARM_NEON
118
+
119
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
120
+ inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
121
+ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
122
+ inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
123
+ #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
124
+
125
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
126
+ // VECTORIZED HORIZONTAL SUM
127
+
128
+ #if defined(__ARM_NEON)
129
+ inline float hsum(float32x4_t x) {
130
+ return vaddvq_f32(x);
131
+ }
132
+ #endif // __ARM_NEON
133
+
134
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
135
+ inline float hsum(float16x8_t x) {
136
+ return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
137
+ vcvt_f32_f16(vget_high_f16(x))));
138
+ }
139
+ #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
140
+
141
+ #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
142
+ inline float hsum(__m128 x) {
143
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
144
+ x = _mm_add_ps(x, _mm_movehl_ps(x, x));
145
+ x = _mm_add_ss(x, _mm_movehdup_ps(x));
146
+ #else
147
+ __m128 t;
148
+ t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
149
+ x = _mm_add_ps(x, t);
150
+ t = _mm_movehl_ps(t, x);
151
+ x = _mm_add_ss(x, t);
152
+ #endif
153
+ return _mm_cvtss_f32(x);
154
+ }
155
+ #endif
156
+
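+ // Both branches above reduce the four lanes pairwise: the vector is
+ // folded onto itself once to form two partial sums and once more to
+ // leave the full sum a + b + c + d in lane 0, which _mm_cvtss_f32
+ // extracts.
+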
157
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
158
+ inline float hsum(__m256 x) {
159
+ return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
160
+ _mm256_castps256_ps128(x)));
161
+ }
162
+ #endif // __AVX__
163
+
164
+ #if defined(__AVX512F__)
165
+ inline float hsum(__m512 x) {
166
+ return _mm512_reduce_add_ps(x);
167
+ }
168
+ #endif // __AVX512F__
169
+
170
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
171
+ // VECTORIZED MEMORY LOADING
172
+
173
+ template <typename T, typename U> T load(const U *);
174
+
175
+ #if defined(__ARM_NEON)
176
+ template <> inline float32x4_t load(const float *p) {
177
+ return vld1q_f32(p);
178
+ }
179
+ #if !defined(_MSC_VER)
180
+ template <> inline float16x8_t load(const ggml_fp16_t *p) {
181
+ return vld1q_f16((const float16_t *)p);
182
+ }
183
+ template <> inline float32x4_t load(const ggml_fp16_t *p) {
184
+ return vcvt_f32_f16(vld1_f16((const float16_t *)p));
185
+ }
186
+ #endif // _MSC_VER
187
+ #endif // __ARM_NEON
188
+
189
+ #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
190
+ template <> inline __m128 load(const float *p) {
191
+ return _mm_loadu_ps(p);
192
+ }
193
+ #endif // __SSE__
194
+
195
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
196
+ template <> inline __m256 load(const float *p) {
197
+ return _mm256_loadu_ps(p);
198
+ }
199
+ #endif // __AVX__
200
+
201
+ #if defined(__F16C__)
202
+ template <> inline __m256 load(const ggml_fp16_t *p) {
203
+ return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
204
+ }
205
+ #endif // __F16C__
206
+
207
+ #if defined(__AVX512F__)
208
+ template <> inline __m512 load(const float *p) {
209
+ return _mm512_loadu_ps(p);
210
+ }
211
+ template <> inline __m512 load(const ggml_fp16_t *p) {
212
+ return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
213
+ }
214
+ #endif // __AVX512F__
215
+
216
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
217
+ // ABSTRACTIONS
218
+
219
+ /**
220
+ * Computes a * b + c.
221
+ *
222
+ * This operation may be fused into a single arithmetic instruction when
223
+ * the hardware and compiler support it, e.g. Intel Haswell+ (c. 2013),
224
+ * AMD Bulldozer+ (c. 2011), etc.
225
+ */
226
+ template <typename T, typename U>
227
+ inline U madd(T a, T b, U c) {
228
+ return add(mul(a, b), c);
229
+ }
230
+
231
+ /**
232
+ * Computes a * b + c with error correction.
233
+ *
234
+ * @see W. Kahan, "Further remarks on reducing truncation errors,"
235
+ * Communications of the ACM, vol. 8, no. 1, p. 40, Jan. 1965,
236
+ * doi: 10.1145/363707.363723.
237
+ */
238
+ template <typename T, typename U>
239
+ inline U madder(T a, T b, U c, U *e) {
240
+ U y = sub(mul(a, b), *e);
241
+ U t = add(c, y);
242
+ *e = sub(sub(t, c), y);
243
+ return t;
244
+ }
245
+
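+ // The same compensation in minimal scalar form (names are illustrative):
+ //
+ //     float c = 0, e = 0;
+ //     for (int l = 0; l < k; ++l) {
+ //         float y = a[l] * b[l] - e; // reapply the error captured last time
+ //         float t = c + y;           // low-order bits of y may be lost here
+ //         e = (t - c) - y;           // recover what was lost
+ //         c = t;                     // carry the running sum forward
+ //     }
+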
246
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
247
+ // FLOATING POINT MATRIX MULTIPLICATION
248
+
249
+ template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
250
+ class tinyBLAS {
251
+ public:
252
+ tinyBLAS(int k,
253
+ const TA *A, int lda,
254
+ const TB *B, int ldb,
255
+ TC *C, int ldc,
256
+ int ith, int nth)
257
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
258
+ }
259
+
260
+ void matmul(int m, int n, int task) {
261
+ if (task == GGML_TASK_TYPE_COMPUTE)
262
+ mnpack(0, m, 0, n);
263
+ }
264
+
265
+ private:
266
+ NOINLINE void mnpack(int m0, int m, int n0, int n) {
267
+ int mc, nc, mp, np;
268
+ if (m - m0 <= 0 || n - n0 <= 0)
269
+ return;
270
+ if (VECTOR_REGISTERS >= 32 && n - n0 >= 5 && m - m0 >= 5) {
271
+ mc = 5;
272
+ nc = 5;
273
+ gemm5x5(m0, m, n0, n);
274
+ } else if (n - n0 >= 4 && m - m0 >= 3) {
275
+ mc = 3;
276
+ nc = 4;
277
+ gemm3x4(m0, m, n0, n);
278
+ } else if (n - n0 >= 4) {
279
+ mc = 1;
280
+ nc = 4;
281
+ gemm1x4(m0, m, n0, n);
282
+ } else if (m - m0 >= 4) {
283
+ mc = 4;
284
+ nc = 1;
285
+ gemm4x1(m0, m, n0, n);
286
+ } else {
287
+ mc = 1;
288
+ nc = 1;
289
+ gemm1x1(m0, m, n0, n);
290
+ }
291
+ mp = m0 + (m - m0) / mc * mc;
292
+ np = n0 + (n - n0) / nc * nc;
293
+ mnpack(mp, m, n0, np);
294
+ mnpack(m0, mp, np, n);
295
+ mnpack(mp, m, np, n);
296
+ }
297
+
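+ // mnpack() tiles the output with the widest kernel that fits, then
+ // recurses on the two edges and the corner it leaves behind. A worked
+ // example (sizes are illustrative): with 32 vector registers and an
+ // 11x7 output, gemm5x5 covers rows 0..9 and columns 0..4, after which
+ // mp = 10 and np = 5, so the three recursive calls handle row 10 under
+ // columns 0..4, rows 0..9 against columns 5..6, and finally the 1x2
+ // corner at (10, 5).
+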
298
+ NOINLINE void gemm5x5(int m0, int m, int n0, int n) {
299
+ BEGIN_KERNEL(5, 5)
300
+ D c00 = {0};
301
+ D c01 = {0};
302
+ D c02 = {0};
303
+ D c03 = {0};
304
+ D c04 = {0};
305
+ D c10 = {0};
306
+ D c11 = {0};
307
+ D c12 = {0};
308
+ D c13 = {0};
309
+ D c14 = {0};
310
+ D c20 = {0};
311
+ D c21 = {0};
312
+ D c22 = {0};
313
+ D c23 = {0};
314
+ D c24 = {0};
315
+ D c30 = {0};
316
+ D c31 = {0};
317
+ D c32 = {0};
318
+ D c33 = {0};
319
+ D c34 = {0};
320
+ D c40 = {0};
321
+ D c41 = {0};
322
+ D c42 = {0};
323
+ D c43 = {0};
324
+ D c44 = {0};
325
+ for (int l = 0; l < k; l += KN) {
326
+ V k0 = load<V>(B + ldb * (j + 0) + l);
327
+ V k1 = load<V>(B + ldb * (j + 1) + l);
328
+ V k2 = load<V>(B + ldb * (j + 2) + l);
329
+ V k3 = load<V>(B + ldb * (j + 3) + l);
330
+ V k4 = load<V>(B + ldb * (j + 4) + l);
331
+ V a0 = load<V>(A + lda * (i + 0) + l);
332
+ c00 = madd(a0, k0, c00);
333
+ c01 = madd(a0, k1, c01);
334
+ c02 = madd(a0, k2, c02);
335
+ c03 = madd(a0, k3, c03);
336
+ c04 = madd(a0, k4, c04);
337
+ V a1 = load<V>(A + lda * (i + 1) + l);
338
+ c10 = madd(a1, k0, c10);
339
+ c11 = madd(a1, k1, c11);
340
+ c12 = madd(a1, k2, c12);
341
+ c13 = madd(a1, k3, c13);
342
+ c14 = madd(a1, k4, c14);
343
+ V a2 = load<V>(A + lda * (i + 2) + l);
344
+ c20 = madd(a2, k0, c20);
345
+ c21 = madd(a2, k1, c21);
346
+ c22 = madd(a2, k2, c22);
347
+ c23 = madd(a2, k3, c23);
348
+ c24 = madd(a2, k4, c24);
349
+ V a3 = load<V>(A + lda * (i + 3) + l);
350
+ c30 = madd(a3, k0, c30);
351
+ c31 = madd(a3, k1, c31);
352
+ c32 = madd(a3, k2, c32);
353
+ c33 = madd(a3, k3, c33);
354
+ c34 = madd(a3, k4, c34);
355
+ V a4 = load<V>(A + lda * (i + 4) + l);
356
+ c40 = madd(a4, k0, c40);
357
+ c41 = madd(a4, k1, c41);
358
+ c42 = madd(a4, k2, c42);
359
+ c43 = madd(a4, k3, c43);
360
+ c44 = madd(a4, k4, c44);
361
+ }
362
+ C[ldc * (j + 0) + (i + 0)] = hsum(c00);
363
+ C[ldc * (j + 0) + (i + 1)] = hsum(c10);
364
+ C[ldc * (j + 0) + (i + 2)] = hsum(c20);
365
+ C[ldc * (j + 0) + (i + 3)] = hsum(c30);
366
+ C[ldc * (j + 0) + (i + 4)] = hsum(c40);
367
+ C[ldc * (j + 1) + (i + 0)] = hsum(c01);
368
+ C[ldc * (j + 1) + (i + 1)] = hsum(c11);
369
+ C[ldc * (j + 1) + (i + 2)] = hsum(c21);
370
+ C[ldc * (j + 1) + (i + 3)] = hsum(c31);
371
+ C[ldc * (j + 1) + (i + 4)] = hsum(c41);
372
+ C[ldc * (j + 2) + (i + 0)] = hsum(c02);
373
+ C[ldc * (j + 2) + (i + 1)] = hsum(c12);
374
+ C[ldc * (j + 2) + (i + 2)] = hsum(c22);
375
+ C[ldc * (j + 2) + (i + 3)] = hsum(c32);
376
+ C[ldc * (j + 2) + (i + 4)] = hsum(c42);
377
+ C[ldc * (j + 3) + (i + 0)] = hsum(c03);
378
+ C[ldc * (j + 3) + (i + 1)] = hsum(c13);
379
+ C[ldc * (j + 3) + (i + 2)] = hsum(c23);
380
+ C[ldc * (j + 3) + (i + 3)] = hsum(c33);
381
+ C[ldc * (j + 3) + (i + 4)] = hsum(c43);
382
+ C[ldc * (j + 4) + (i + 0)] = hsum(c04);
383
+ C[ldc * (j + 4) + (i + 1)] = hsum(c14);
384
+ C[ldc * (j + 4) + (i + 2)] = hsum(c24);
385
+ C[ldc * (j + 4) + (i + 3)] = hsum(c34);
386
+ C[ldc * (j + 4) + (i + 4)] = hsum(c44);
387
+ END_KERNEL()
388
+ }
389
+
390
+ NOINLINE void gemm3x4(int m0, int m, int n0, int n) {
391
+ BEGIN_KERNEL(3, 4)
392
+ D c00 = {0};
393
+ D c01 = {0};
394
+ D c02 = {0};
395
+ D c03 = {0};
396
+ D c10 = {0};
397
+ D c11 = {0};
398
+ D c12 = {0};
399
+ D c13 = {0};
400
+ D c20 = {0};
401
+ D c21 = {0};
402
+ D c22 = {0};
403
+ D c23 = {0};
404
+ for (int l = 0; l < k; l += KN) {
405
+ V k0 = load<V>(B + ldb * (j + 0) + l);
406
+ V k1 = load<V>(B + ldb * (j + 1) + l);
407
+ V k2 = load<V>(B + ldb * (j + 2) + l);
408
+ V k3 = load<V>(B + ldb * (j + 3) + l);
409
+ V a0 = load<V>(A + lda * (i + 0) + l);
410
+ c00 = madd(a0, k0, c00);
411
+ c01 = madd(a0, k1, c01);
412
+ c02 = madd(a0, k2, c02);
413
+ c03 = madd(a0, k3, c03);
414
+ V a1 = load<V>(A + lda * (i + 1) + l);
415
+ c10 = madd(a1, k0, c10);
416
+ c11 = madd(a1, k1, c11);
417
+ c12 = madd(a1, k2, c12);
418
+ c13 = madd(a1, k3, c13);
419
+ V a2 = load<V>(A + lda * (i + 2) + l);
420
+ c20 = madd(a2, k0, c20);
421
+ c21 = madd(a2, k1, c21);
422
+ c22 = madd(a2, k2, c22);
423
+ c23 = madd(a2, k3, c23);
424
+ }
425
+ C[ldc * (j + 0) + (i + 0)] = hsum(c00);
426
+ C[ldc * (j + 0) + (i + 1)] = hsum(c10);
427
+ C[ldc * (j + 0) + (i + 2)] = hsum(c20);
428
+ C[ldc * (j + 1) + (i + 0)] = hsum(c01);
429
+ C[ldc * (j + 1) + (i + 1)] = hsum(c11);
430
+ C[ldc * (j + 1) + (i + 2)] = hsum(c21);
431
+ C[ldc * (j + 2) + (i + 0)] = hsum(c02);
432
+ C[ldc * (j + 2) + (i + 1)] = hsum(c12);
433
+ C[ldc * (j + 2) + (i + 2)] = hsum(c22);
434
+ C[ldc * (j + 3) + (i + 0)] = hsum(c03);
435
+ C[ldc * (j + 3) + (i + 1)] = hsum(c13);
436
+ C[ldc * (j + 3) + (i + 2)] = hsum(c23);
437
+ END_KERNEL()
438
+ }
439
+
440
+ NOINLINE void gemm1x4(int m0, int m, int n0, int n) {
441
+ BEGIN_KERNEL(1, 4)
442
+ D c00 = {0}, e00 = {0};
443
+ D c01 = {0}, e01 = {0};
444
+ D c02 = {0}, e02 = {0};
445
+ D c03 = {0}, e03 = {0};
446
+ for (int l = 0; l < k; l += KN) {
447
+ V a = load<V>(A + lda * (i + 0) + l);
448
+ c00 = madder(a, load<V>(B + ldb * (j + 0) + l), c00, &e00);
449
+ c01 = madder(a, load<V>(B + ldb * (j + 1) + l), c01, &e01);
450
+ c02 = madder(a, load<V>(B + ldb * (j + 2) + l), c02, &e02);
451
+ c03 = madder(a, load<V>(B + ldb * (j + 3) + l), c03, &e03);
452
+ }
453
+ C[ldc * (j + 0) + (i + 0)] = hsum(c00);
454
+ C[ldc * (j + 1) + (i + 0)] = hsum(c01);
455
+ C[ldc * (j + 2) + (i + 0)] = hsum(c02);
456
+ C[ldc * (j + 3) + (i + 0)] = hsum(c03);
457
+ END_KERNEL()
458
+ }
459
+
460
+ NOINLINE void gemm4x1(int m0, int m, int n0, int n) {
461
+ BEGIN_KERNEL(4, 1)
462
+ D c00 = {0}, e00 = {0};
463
+ D c10 = {0}, e10 = {0};
464
+ D c20 = {0}, e20 = {0};
465
+ D c30 = {0}, e30 = {0};
466
+ for (int l = 0; l < k; l += KN) {
467
+ V b = load<V>(B + ldb * (j + 0) + l);
468
+ c00 = madder(load<V>(A + lda * (i + 0) + l), b, c00, &e00);
469
+ c10 = madder(load<V>(A + lda * (i + 1) + l), b, c10, &e10);
470
+ c20 = madder(load<V>(A + lda * (i + 2) + l), b, c20, &e20);
471
+ c30 = madder(load<V>(A + lda * (i + 3) + l), b, c30, &e30);
472
+ }
473
+ C[ldc * (j + 0) + (i + 0)] = hsum(c00);
474
+ C[ldc * (j + 0) + (i + 1)] = hsum(c10);
475
+ C[ldc * (j + 0) + (i + 2)] = hsum(c20);
476
+ C[ldc * (j + 0) + (i + 3)] = hsum(c30);
477
+ END_KERNEL()
478
+ }
479
+
480
+ NOINLINE void gemm1x1(int m0, int m, int n0, int n) {
481
+ BEGIN_KERNEL(1, 1)
482
+ D c = {0}, e = {0};
483
+ for (int l = 0; l < k; l += KN)
484
+ c = madder(load<V>(A + lda * i + l),
485
+ load<V>(B + ldb * j + l), c, &e);
486
+ C[ldc * j + i] = hsum(c);
487
+ END_KERNEL()
488
+ }
489
+
490
+ const TA *const A;
491
+ const TB *const B;
492
+ TC *const C;
493
+ const int k;
494
+ const int lda;
495
+ const int ldb;
496
+ const int ldc;
497
+ const int ith;
498
+ const int nth;
499
+ };
500
+
501
+ //////////////////////////////////////////////////////////////////////////////////////////
502
+ // QUANT ZERO MATRIX MULTIPLICATION
503
+
504
+ #if defined(__ARM_FEATURE_DOTPROD)
505
+ template <typename TA>
506
+ class tinyBLAS_Q0_ARM {
507
+ public:
508
+ tinyBLAS_Q0_ARM(int k,
509
+ const TA *A, int lda,
510
+ const block_q8_0 *B, int ldb,
511
+ float *C, int ldc,
512
+ int ith, int nth)
513
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
514
+ }
515
+
516
+ void matmul(int m, int n, int task) {
517
+ if (task == GGML_TASK_TYPE_COMPUTE)
518
+ mnpack(0, m, 0, n);
519
+ }
520
+
521
+ private:
522
+ NOINLINE void mnpack(int m0, int m, int n0, int n) {
523
+ int mc, nc, mp, np;
524
+ if (m - m0 <= 0 || n - n0 <= 0)
525
+ return;
526
+ if (m - m0 >= 3 && n - n0 >= 3) {
527
+ mc = 3;
528
+ nc = 3;
529
+ gemm3x3(m0, m, n0, n);
530
+ } else {
531
+ mc = 1;
532
+ nc = 1;
533
+ gemm1x1(m0, m, n0, n);
534
+ }
535
+ mp = m0 + (m - m0) / mc * mc;
536
+ np = n0 + (n - n0) / nc * nc;
537
+ mnpack(mp, m, n0, np);
538
+ mnpack(m0, mp, np, n);
539
+ mnpack(mp, m, np, n);
540
+ }
541
+
542
+ NOINLINE void gemm3x3(int m0, int m, int n0, int n) {
543
+ BEGIN_KERNEL(3, 3)
544
+ int32x4_t zero = vdupq_n_s32(0);
545
+ float32x4_t c00 = vdupq_n_f32(0.f);
546
+ float32x4_t c01 = vdupq_n_f32(0.f);
547
+ float32x4_t c02 = vdupq_n_f32(0.f);
548
+ float32x4_t c10 = vdupq_n_f32(0.f);
549
+ float32x4_t c11 = vdupq_n_f32(0.f);
550
+ float32x4_t c12 = vdupq_n_f32(0.f);
551
+ float32x4_t c20 = vdupq_n_f32(0.f);
552
+ float32x4_t c21 = vdupq_n_f32(0.f);
553
+ float32x4_t c22 = vdupq_n_f32(0.f);
554
+ const TA *Ap0 = A + lda * (i + 0);
555
+ const TA *Ap1 = A + lda * (i + 1);
556
+ const TA *Ap2 = A + lda * (i + 2);
557
+ const block_q8_0 *Bp0 = B + ldb * (j + 0);
558
+ const block_q8_0 *Bp1 = B + ldb * (j + 1);
559
+ const block_q8_0 *Bp2 = B + ldb * (j + 2);
560
+ for (int l = 0; l < k; ++l) {
561
+ c00 = vmlaq_n_f32(
562
+ c00,
563
+ vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp0 + l)),
564
+ load_hi(Ap0 + l), load_hi(Bp0 + l))),
565
+ unhalf(Ap0[l].d) * unhalf(Bp0[l].d));
566
+ c01 = vmlaq_n_f32(
567
+ c01,
568
+ vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp1 + l)),
569
+ load_hi(Ap0 + l), load_hi(Bp1 + l))),
570
+ unhalf(Ap0[l].d) * unhalf(Bp1[l].d));
571
+ c02 = vmlaq_n_f32(
572
+ c02,
573
+ vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp2 + l)),
574
+ load_hi(Ap0 + l), load_hi(Bp2 + l))),
575
+ unhalf(Ap0[l].d) * unhalf(Bp2[l].d));
576
+ c10 = vmlaq_n_f32(
577
+ c10,
578
+ vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp0 + l)),
579
+ load_hi(Ap1 + l), load_hi(Bp0 + l))),
580
+ unhalf(Ap1[l].d) * unhalf(Bp0[l].d));
581
+ c11 = vmlaq_n_f32(
582
+ c11,
583
+ vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp1 + l)),
584
+ load_hi(Ap1 + l), load_hi(Bp1 + l))),
585
+ unhalf(Ap1[l].d) * unhalf(Bp1[l].d));
586
+ c12 = vmlaq_n_f32(
587
+ c12,
588
+ vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp2 + l)),
589
+ load_hi(Ap1 + l), load_hi(Bp2 + l))),
590
+ unhalf(Ap1[l].d) * unhalf(Bp2[l].d));
591
+ c20 = vmlaq_n_f32(
592
+ c20,
593
+ vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp0 + l)),
594
+ load_hi(Ap2 + l), load_hi(Bp0 + l))),
595
+ unhalf(Ap2[l].d) * unhalf(Bp0[l].d));
596
+ c21 = vmlaq_n_f32(
597
+ c21,
598
+ vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp1 + l)),
599
+ load_hi(Ap2 + l), load_hi(Bp1 + l))),
600
+ unhalf(Ap2[l].d) * unhalf(Bp1[l].d));
601
+ c22 = vmlaq_n_f32(
602
+ c22,
603
+ vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp2 + l)),
604
+ load_hi(Ap2 + l), load_hi(Bp2 + l))),
605
+ unhalf(Ap2[l].d) * unhalf(Bp2[l].d));
606
+ }
607
+ C[ldc * (j + 0) + (i + 0)] = hsum(c00);
608
+ C[ldc * (j + 0) + (i + 1)] = hsum(c10);
609
+ C[ldc * (j + 0) + (i + 2)] = hsum(c20);
610
+ C[ldc * (j + 1) + (i + 0)] = hsum(c01);
611
+ C[ldc * (j + 1) + (i + 1)] = hsum(c11);
612
+ C[ldc * (j + 1) + (i + 2)] = hsum(c21);
613
+ C[ldc * (j + 2) + (i + 0)] = hsum(c02);
614
+ C[ldc * (j + 2) + (i + 1)] = hsum(c12);
615
+ C[ldc * (j + 2) + (i + 2)] = hsum(c22);
616
+ END_KERNEL()
617
+ }
618
+
619
+ NOINLINE void gemm1x1(int m0, int m, int n0, int n) {
620
+ BEGIN_KERNEL(1, 1)
621
+ float32x4_t acc = vdupq_n_f32(0.f);
622
+ const TA *Ap = A + lda * i;
623
+ const block_q8_0 *Bp = B + ldb * j;
624
+ for (int l = 0; l < k; ++l) {
625
+ acc = vmlaq_n_f32(acc,
626
+ vcvtq_f32_s32(vdotq_s32(
627
+ vdotq_s32(vdupq_n_s32(0), load_lo(Ap + l), load_lo(Bp + l)),
628
+ load_hi(Ap + l), load_hi(Bp + l))),
629
+ unhalf(Ap[l].d) * unhalf(Bp[l].d));
630
+ }
631
+ C[ldc * j + i] = hsum(acc);
632
+ END_KERNEL()
633
+ }
634
+
635
+ inline int8x16_t load_lo(const block_q8_0 *b) {
636
+ return vld1q_s8(b->qs);
637
+ }
638
+ inline int8x16_t load_hi(const block_q8_0 *b) {
639
+ return vld1q_s8(b->qs + 16);
640
+ }
641
+
642
+ inline int8x16_t load_lo(const block_q4_0 *b) {
643
+ return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
644
+ vdupq_n_u8(0x0f))),
645
+ vdupq_n_s8(0x8));
646
+ }
647
+ inline int8x16_t load_hi(const block_q4_0 *b) {
648
+ return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
649
+ vdupq_n_s8(0x8));
650
+ }
651
+
652
+ const TA *const A;
653
+ const block_q8_0 *const B;
654
+ float *const C;
655
+ const int k;
656
+ const int lda;
657
+ const int ldb;
658
+ const int ldc;
659
+ const int ith;
660
+ const int nth;
661
+ };
662
+ #endif // __ARM_FEATURE_DOTPROD
663
+
664
+ #if defined(__AVX2__) || defined(__AVX512F__)
665
+ template <typename TA, typename TB, typename TC>
666
+ class tinyBLAS_Q0_AVX2 {
667
+ public:
668
+ tinyBLAS_Q0_AVX2(int k,
669
+ const TA *A, int lda,
670
+ const TB *B, int ldb,
671
+ TC *C, int ldc,
672
+ int ith, int nth)
673
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
674
+ }
675
+
676
+ void matmul(int m, int n, int task) {
677
+ if (task == GGML_TASK_TYPE_COMPUTE)
678
+ mnpack(0, m, 0, n);
679
+ }
680
+
681
+ private:
682
+ NOINLINE void mnpack(int m0, int m, int n0, int n) {
683
+ int mc, nc, mp, np;
684
+ if (m - m0 <= 0 || n - n0 <= 0)
685
+ return;
686
+ if (m - m0 >= 4 && n - n0 >= 3) {
687
+ mc = 4;
688
+ nc = 3;
689
+ gemm4x3(m0, m, n0, n);
690
+ } else if (m - m0 >= 4 && n - n0 >= 1) {
691
+ mc = 4;
692
+ nc = 1;
693
+ gemm4x1(m0, m, n0, n);
694
+ } else if (m - m0 >= 1 && n - n0 >= 4) {
695
+ mc = 1;
696
+ nc = 4;
697
+ gemm1x4(m0, m, n0, n);
698
+ } else {
699
+ mc = 1;
700
+ nc = 1;
701
+ gemm1x1(m0, m, n0, n);
702
+ }
703
+ mp = m0 + (m - m0) / mc * mc;
704
+ np = n0 + (n - n0) / nc * nc;
705
+ mnpack(mp, m, n0, np);
706
+ mnpack(m0, mp, np, n);
707
+ mnpack(mp, m, np, n);
708
+ }
709
+
710
+ NOINLINE void gemm4x3(int m0, int m, int n0, int n) {
711
+ BEGIN_KERNEL(4, 3)
712
+ __m256 c00 = _mm256_setzero_ps();
713
+ __m256 c10 = _mm256_setzero_ps();
714
+ __m256 c20 = _mm256_setzero_ps();
715
+ __m256 c30 = _mm256_setzero_ps();
716
+ __m256 c01 = _mm256_setzero_ps();
717
+ __m256 c11 = _mm256_setzero_ps();
718
+ __m256 c21 = _mm256_setzero_ps();
719
+ __m256 c31 = _mm256_setzero_ps();
720
+ __m256 c02 = _mm256_setzero_ps();
721
+ __m256 c12 = _mm256_setzero_ps();
722
+ __m256 c22 = _mm256_setzero_ps();
723
+ __m256 c32 = _mm256_setzero_ps();
724
+ const TA *Ap0 = A + lda * (i + 0);
725
+ const TA *Ap1 = A + lda * (i + 1);
726
+ const TA *Ap2 = A + lda * (i + 2);
727
+ const TA *Ap3 = A + lda * (i + 3);
728
+ const TB *Bp0 = B + ldb * (j + 0);
729
+ const TB *Bp1 = B + ldb * (j + 1);
730
+ const TB *Bp2 = B + ldb * (j + 2);
731
+ for (int l = 0; l < k; ++l) {
732
+ float da0 = unhalf(Ap0[l].d);
733
+ float da1 = unhalf(Ap1[l].d);
734
+ float da2 = unhalf(Ap2[l].d);
735
+ float da3 = unhalf(Ap3[l].d);
736
+ __m256i e0 = load(Ap0 + l);
737
+ __m256i e1 = load(Ap1 + l);
738
+ __m256i e2 = load(Ap2 + l);
739
+ __m256i e3 = load(Ap3 + l);
740
+ float db0 = unhalf(Bp0[l].d);
741
+ __m256 d00 = _mm256_set1_ps(da0 * db0);
742
+ __m256 d10 = _mm256_set1_ps(da1 * db0);
743
+ __m256 d20 = _mm256_set1_ps(da2 * db0);
744
+ __m256 d30 = _mm256_set1_ps(da3 * db0);
745
+ __m256i f0 = load(Bp0 + l);
746
+ __m256i u0 = _mm256_sign_epi8(f0, f0);
747
+ __m256i s00 = _mm256_sign_epi8(e0, f0);
748
+ __m256i s10 = _mm256_sign_epi8(e1, f0);
749
+ __m256i s20 = _mm256_sign_epi8(e2, f0);
750
+ __m256i s30 = _mm256_sign_epi8(e3, f0);
751
+ c00 = madd(d00, updot(u0, s00), c00);
752
+ c10 = madd(d10, updot(u0, s10), c10);
753
+ c20 = madd(d20, updot(u0, s20), c20);
754
+ c30 = madd(d30, updot(u0, s30), c30);
755
+ float db1 = unhalf(Bp1[l].d);
756
+ __m256 d01 = _mm256_set1_ps(da0 * db1);
757
+ __m256 d11 = _mm256_set1_ps(da1 * db1);
758
+ __m256 d21 = _mm256_set1_ps(da2 * db1);
759
+ __m256 d31 = _mm256_set1_ps(da3 * db1);
760
+ __m256i f1 = load(Bp1 + l);
761
+ __m256i u1 = _mm256_sign_epi8(f1, f1);
762
+ __m256i s01 = _mm256_sign_epi8(e0, f1);
763
+ __m256i s11 = _mm256_sign_epi8(e1, f1);
764
+ __m256i s21 = _mm256_sign_epi8(e2, f1);
765
+ __m256i s31 = _mm256_sign_epi8(e3, f1);
766
+ c01 = madd(d01, updot(u1, s01), c01);
767
+ c11 = madd(d11, updot(u1, s11), c11);
768
+ c21 = madd(d21, updot(u1, s21), c21);
769
+ c31 = madd(d31, updot(u1, s31), c31);
770
+ float db2 = unhalf(Bp2[l].d);
771
+ __m256 d02 = _mm256_set1_ps(da0 * db2);
772
+ __m256 d12 = _mm256_set1_ps(da1 * db2);
773
+ __m256 d22 = _mm256_set1_ps(da2 * db2);
774
+ __m256 d32 = _mm256_set1_ps(da3 * db2);
775
+ __m256i f2 = load(Bp2 + l);
776
+ __m256i u2 = _mm256_sign_epi8(f2, f2);
777
+ __m256i s02 = _mm256_sign_epi8(e0, f2);
778
+ __m256i s12 = _mm256_sign_epi8(e1, f2);
779
+ __m256i s22 = _mm256_sign_epi8(e2, f2);
780
+ __m256i s32 = _mm256_sign_epi8(e3, f2);
781
+ c02 = madd(d02, updot(u2, s02), c02);
782
+ c12 = madd(d12, updot(u2, s12), c12);
783
+ c22 = madd(d22, updot(u2, s22), c22);
784
+ c32 = madd(d32, updot(u2, s32), c32);
785
+ }
786
+ C[ldc * (j + 0) + (i + 0)] = hsum(c00);
787
+ C[ldc * (j + 0) + (i + 1)] = hsum(c10);
788
+ C[ldc * (j + 0) + (i + 2)] = hsum(c20);
789
+ C[ldc * (j + 0) + (i + 3)] = hsum(c30);
790
+ C[ldc * (j + 1) + (i + 0)] = hsum(c01);
791
+ C[ldc * (j + 1) + (i + 1)] = hsum(c11);
792
+ C[ldc * (j + 1) + (i + 2)] = hsum(c21);
793
+ C[ldc * (j + 1) + (i + 3)] = hsum(c31);
794
+ C[ldc * (j + 2) + (i + 0)] = hsum(c02);
795
+ C[ldc * (j + 2) + (i + 1)] = hsum(c12);
796
+ C[ldc * (j + 2) + (i + 2)] = hsum(c22);
797
+ C[ldc * (j + 2) + (i + 3)] = hsum(c32);
798
+ END_KERNEL()
799
+ }
800
+
801
+ NOINLINE void gemm4x1(int m0, int m, int n0, int n) {
802
+ BEGIN_KERNEL(4, 1)
803
+ __m256 c0 = _mm256_setzero_ps();
804
+ __m256 c1 = _mm256_setzero_ps();
805
+ __m256 c2 = _mm256_setzero_ps();
806
+ __m256 c3 = _mm256_setzero_ps();
807
+ const TA *Ap0 = A + lda * (i + 0);
808
+ const TA *Ap1 = A + lda * (i + 1);
809
+ const TA *Ap2 = A + lda * (i + 2);
810
+ const TA *Ap3 = A + lda * (i + 3);
811
+ const TB *Bp = B + ldb * j;
812
+ for (int l = 0; l < k; ++l) {
813
+ float db0 = unhalf(Bp[l].d);
814
+ __m256i f = load(Bp + l);
815
+ __m256i u = _mm256_sign_epi8(f, f);
816
+ __m256 d0 = _mm256_set1_ps(unhalf(Ap0[l].d) * db0);
817
+ __m256 d1 = _mm256_set1_ps(unhalf(Ap1[l].d) * db0);
818
+ __m256 d2 = _mm256_set1_ps(unhalf(Ap2[l].d) * db0);
819
+ __m256 d3 = _mm256_set1_ps(unhalf(Ap3[l].d) * db0);
820
+ __m256i e0 = load(Ap0 + l);
821
+ __m256i e1 = load(Ap1 + l);
822
+ __m256i e2 = load(Ap2 + l);
823
+ __m256i e3 = load(Ap3 + l);
824
+ __m256i s0 = _mm256_sign_epi8(e0, f);
825
+ __m256i s1 = _mm256_sign_epi8(e1, f);
826
+ __m256i s2 = _mm256_sign_epi8(e2, f);
827
+ __m256i s3 = _mm256_sign_epi8(e3, f);
828
+ __m256 g0 = updot(u, s0);
829
+ __m256 g1 = updot(u, s1);
830
+ __m256 g2 = updot(u, s2);
831
+ __m256 g3 = updot(u, s3);
832
+ c0 = madd(d0, g0, c0);
833
+ c1 = madd(d1, g1, c1);
834
+ c2 = madd(d2, g2, c2);
835
+ c3 = madd(d3, g3, c3);
836
+ }
837
+ C[ldc * j + (i + 0)] = hsum(c0);
838
+ C[ldc * j + (i + 1)] = hsum(c1);
839
+ C[ldc * j + (i + 2)] = hsum(c2);
840
+ C[ldc * j + (i + 3)] = hsum(c3);
841
+ END_KERNEL()
842
+ }
843
+
844
+ NOINLINE void gemm1x4(int m0, int m, int n0, int n) {
845
+ BEGIN_KERNEL(1, 4)
846
+ __m256 c0 = _mm256_setzero_ps();
847
+ __m256 c1 = _mm256_setzero_ps();
848
+ __m256 c2 = _mm256_setzero_ps();
849
+ __m256 c3 = _mm256_setzero_ps();
850
+ const TB *Bp0 = B + ldb * (j + 0);
851
+ const TB *Bp1 = B + ldb * (j + 1);
852
+ const TB *Bp2 = B + ldb * (j + 2);
853
+ const TB *Bp3 = B + ldb * (j + 3);
854
+ const TA *Ap = A + lda * i;
855
+ for (int l = 0; l < k; ++l) {
856
+ float da0 = unhalf(Ap[l].d);
857
+ __m256i f = load(Ap + l);
858
+ __m256i u = _mm256_sign_epi8(f, f);
859
+ __m256 d0 = _mm256_set1_ps(unhalf(Bp0[l].d) * da0);
860
+ __m256 d1 = _mm256_set1_ps(unhalf(Bp1[l].d) * da0);
861
+ __m256 d2 = _mm256_set1_ps(unhalf(Bp2[l].d) * da0);
862
+ __m256 d3 = _mm256_set1_ps(unhalf(Bp3[l].d) * da0);
863
+ __m256 g0 = updot(u, _mm256_sign_epi8(load(Bp0 + l), f));
864
+ __m256 g1 = updot(u, _mm256_sign_epi8(load(Bp1 + l), f));
865
+ __m256 g2 = updot(u, _mm256_sign_epi8(load(Bp2 + l), f));
866
+ __m256 g3 = updot(u, _mm256_sign_epi8(load(Bp3 + l), f));
867
+ c0 = madd(d0, g0, c0);
868
+ c1 = madd(d1, g1, c1);
869
+ c2 = madd(d2, g2, c2);
870
+ c3 = madd(d3, g3, c3);
871
+ }
872
+ C[ldc * (j + 0) + i] = hsum(c0);
873
+ C[ldc * (j + 1) + i] = hsum(c1);
874
+ C[ldc * (j + 2) + i] = hsum(c2);
875
+ C[ldc * (j + 3) + i] = hsum(c3);
876
+ END_KERNEL()
877
+ }
878
+
879
+ NOINLINE void gemm1x1(int m0, int m, int n0, int n) {
880
+ BEGIN_KERNEL(1, 1)
881
+ __m256 c = _mm256_setzero_ps();
882
+ const TA *Ap = A + lda * i;
883
+ const TB *Bp = B + ldb * j;
884
+ for (int l = 0; l < k; ++l) {
885
+ __m256 d = _mm256_set1_ps(unhalf(Ap[l].d) * unhalf(Bp[l].d));
886
+ __m256i e = load(Ap + l);
887
+ __m256i f = load(Bp + l);
888
+ __m256 g = updot(_mm256_sign_epi8(e, e), _mm256_sign_epi8(f, e));
889
+ c = madd(d, g, c);
890
+ }
891
+ C[ldc * j + i] = hsum(c);
892
+ END_KERNEL()
893
+ }
894
+
895
+ inline __m256i load(const block_q8_0 *b) {
896
+ return _mm256_loadu_si256((const __m256i *)b->qs);
897
+ }
898
+
899
+ inline __m256i load(const block_q4_0 *b) {
900
+ return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
901
+ }
902
+
903
+ inline __m256 updot(__m256i u, __m256i s) {
904
+ __m256i res;
905
+ #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
906
+ res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
907
+ #else
908
+ res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
909
+ #endif
910
+ return _mm256_cvtepi32_ps(res);
911
+ }
912
+
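+ // The _mm256_sign_epi8 trick used with updot() above: given signed byte
+ // vectors x and y, u = sign(y, y) = |y| is a valid unsigned operand and
+ // s = sign(x, y) carries y's sign onto x (and zeroes lanes where y == 0),
+ // so u[i] * s[i] == x[i] * y[i] for every lane. updot() then sums each
+ // group of four such byte products into a 32-bit lane (via VNNI's dpbusd
+ // when available, otherwise maddubs + madd) and converts the lanes to
+ // float.
+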
913
+ static inline __m256i denibble(const uint8_t *p) {
914
+ const __m128i tmp = _mm_loadu_si128((const __m128i *)p);
915
+ const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
916
+ const __m256i lowMask = _mm256_set1_epi8(15);
917
+ return _mm256_and_si256(lowMask, bytes);
918
+ }
919
+
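+ // denibble() expands the 16 packed bytes of a Q4_0 block into 32 byte
+ // lanes: the low 128-bit lane keeps the low nibbles (elements 0..15) and
+ // the high lane receives the bytes shifted right by 4 (elements 16..31),
+ // with the 0x0f mask discarding bits that bleed across byte boundaries.
+ // load(const block_q4_0 *) then subtracts 8 to map the stored range
+ // 0..15 onto the signed range -8..7; e.g. a packed byte 0xc3 yields
+ // 3 - 8 = -5 and 12 - 8 = 4.
+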
920
+ const TA *const A;
921
+ const TB *const B;
922
+ TC *const C;
923
+ const int k;
924
+ const int lda;
925
+ const int ldb;
926
+ const int ldc;
927
+ const int ith;
928
+ const int nth;
929
+ };
930
+ #endif // __AVX2__
931
+
932
+ } // namespace
933
+
934
+ /**
935
+ * Performs optimized matrix multiplication on CPU.
936
+ *
937
+ * This subroutine may compute C = Aᵀ * B with column major ordering.
938
+ * Despite its name, this isn't a generalized implementation. Work is
939
+ * performed only when a handwritten kernel exists for the request;
940
+ * otherwise the caller should fall back to a general matmul routine.
941
+ *
942
+ * For example, for single-threaded single-precision GEMM you can say
943
+ *
944
+ * llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
945
+ * 0, 1, GGML_TASK_TYPE_COMPUTE,
946
+ * GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
947
+ *
948
+ * @param m is rows in `A` and `C`
949
+ * @param n is cols in `B` and `C`
950
+ * @param k is cols in `A` and rows in `B`
951
+ * @param A is first input matrix (always transposed)
952
+ * @param lda is row stride of `A`
953
+ * @param B is second input matrix (never transposed)
954
+ * @param ldb is row stride of `B`
955
+ * @param C is the output matrix (overwritten; not accumulated into)
956
+ * @param ldc is row stride of `C`
957
+ * @param ith is thread id (must be less than `nth`)
958
+ * @param nth is number of threads (must be greater than zero)
959
+ * @param task is GGML task type
960
+ * @param Atype is GGML data type of `A`
961
+ * @param Btype is GGML data type of `B`
962
+ * @param Ctype is GGML data type of `C`
963
+ * @return true if this function was able to service the matmul request
964
+ */
965
+ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, int ldb, void *C,
966
+ int ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
967
+
968
+ assert(m >= 0);
969
+ assert(n >= 0);
970
+ assert(k >= 0);
971
+ assert(lda >= k);
972
+ assert(ldb >= k);
973
+ assert(ldc >= m);
974
+ assert(nth > 0);
975
+ assert(ith < nth);
976
+ assert(1ll * lda * m <= 0x7fffffff);
977
+ assert(1ll * ldb * n <= 0x7fffffff);
978
+ assert(1ll * ldc * n <= 0x7fffffff);
979
+
980
+ if (Ctype != GGML_TYPE_F32)
981
+ return false;
982
+
983
+ switch (Atype) {
984
+
985
+ case GGML_TYPE_F32: {
986
+ if (Btype != GGML_TYPE_F32)
987
+ return false;
988
+ #if defined(__AVX512F__)
989
+ if (k % 16)
990
+ return false;
991
+ tinyBLAS<16, __m512, __m512, float, float, float> tb{
992
+ k, (const float *)A, lda,
993
+ (const float *)B, ldb,
994
+ (float *)C, ldc,
995
+ ith, nth};
996
+ tb.matmul(m, n, task);
997
+ return true;
998
+ #elif defined(__AVX__) || defined(__AVX2__)
999
+ if (k % 8)
1000
+ return false;
1001
+ tinyBLAS<8, __m256, __m256, float, float, float> tb{
1002
+ k, (const float *)A, lda,
1003
+ (const float *)B, ldb,
1004
+ (float *)C, ldc,
1005
+ ith, nth};
1006
+ tb.matmul(m, n, task);
1007
+ return true;
1008
+ #elif defined(__ARM_NEON)
1009
+ if (n < 4)
1010
+ return false;
1011
+ if (k % 4)
1012
+ return false;
1013
+ tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
1014
+ k, (const float *)A, lda,
1015
+ (const float *)B, ldb,
1016
+ (float *)C, ldc,
1017
+ ith, nth};
1018
+ tb.matmul(m, n, task);
1019
+ return true;
1020
+ #else
1021
+ return false;
1022
+ #endif
1023
+ }
1024
+
1025
+ case GGML_TYPE_F16: {
1026
+ #if defined(__AVX512F__)
1027
+ if (k % 16)
1028
+ return false;
1029
+ if (Btype != GGML_TYPE_F32)
1030
+ return false;
1031
+ tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
1032
+ k, (const ggml_fp16_t *)A, lda,
1033
+ (const float *)B, ldb,
1034
+ (float *)C, ldc,
1035
+ ith, nth};
1036
+ tb.matmul(m, n, task);
1037
+ return true;
1038
+ #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
1039
+ if (k % 8)
1040
+ return false;
1041
+ if (Btype != GGML_TYPE_F32)
1042
+ return false;
1043
+ tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
1044
+ k, (const ggml_fp16_t *)A, lda,
1045
+ (const float *)B, ldb,
1046
+ (float *)C, ldc,
1047
+ ith, nth};
1048
+ tb.matmul(m, n, task);
1049
+ return true;
1050
+ #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
1051
+ if (n < 8)
1052
+ return false;
1053
+ if (k % 8)
1054
+ return false;
1055
+ if (Btype != GGML_TYPE_F16)
1056
+ return false;
1057
+ tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
1058
+ k, (const ggml_fp16_t *)A, lda,
1059
+ (const ggml_fp16_t *)B, ldb,
1060
+ (float *)C, ldc,
1061
+ ith, nth};
1062
+ tb.matmul(m, n, task);
1063
+ return true;
1064
+ #elif defined(__ARM_NEON) && !defined(_MSC_VER)
1065
+ if (k % 4)
1066
+ return false;
1067
+ if (Btype != GGML_TYPE_F32)
1068
+ return false;
1069
+ tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
1070
+ k, (const ggml_fp16_t *)A, lda,
1071
+ (const float *)B, ldb,
1072
+ (float *)C, ldc,
1073
+ ith, nth};
1074
+ tb.matmul(m, n, task);
1075
+ return true;
1076
+ #else
1077
+ return false;
1078
+ #endif
1079
+ }
1080
+
1081
+ case GGML_TYPE_Q8_0: {
1082
+ if (Btype != GGML_TYPE_Q8_0)
1083
+ return false;
1084
+ #if defined(__AVX2__) || defined(__AVX512F__)
1085
+ tinyBLAS_Q0_AVX2<block_q8_0, block_q8_0, float> tb{
1086
+ k, (const block_q8_0 *)A, lda,
1087
+ (const block_q8_0 *)B, ldb,
1088
+ (float *)C, ldc,
1089
+ ith, nth};
1090
+ tb.matmul(m, n, task);
1091
+ return true;
1092
+ #elif defined(__ARM_FEATURE_DOTPROD)
1093
+ tinyBLAS_Q0_ARM<block_q8_0> tb{
1094
+ k, (const block_q8_0 *)A, lda,
1095
+ (const block_q8_0 *)B, ldb,
1096
+ (float *)C, ldc,
1097
+ ith, nth};
1098
+ tb.matmul(m, n, task);
1099
+ return true;
1100
+ #else
1101
+ return false;
1102
+ #endif
1103
+ }
1104
+
1105
+ case GGML_TYPE_Q4_0: {
1106
+ if (Btype != GGML_TYPE_Q8_0)
1107
+ return false;
1108
+ #if defined(__AVX2__) || defined(__AVX512F__)
1109
+ tinyBLAS_Q0_AVX2<block_q4_0, block_q8_0, float> tb{
1110
+ k, (const block_q4_0 *)A, lda,
1111
+ (const block_q8_0 *)B, ldb,
1112
+ (float *)C, ldc,
1113
+ ith, nth};
1114
+ tb.matmul(m, n, task);
1115
+ return true;
1116
+ #elif defined(__ARM_FEATURE_DOTPROD)
1117
+ tinyBLAS_Q0_ARM<block_q4_0> tb{
1118
+ k, (const block_q4_0 *)A, lda,
1119
+ (const block_q8_0 *)B, ldb,
1120
+ (float *)C, ldc,
1121
+ ith, nth};
1122
+ tb.matmul(m, n, task);
1123
+ return true;
1124
+ #else
1125
+ return false;
1126
+ #endif
1127
+ }
1128
+
1129
+ default:
1130
+ return false;
1131
+ }
1132
+
1133
+ (void)m;
1134
+ (void)n;
1135
+ (void)k;
1136
+ (void)A;
1137
+ (void)lda;
1138
+ (void)B;
1139
+ (void)ldb;
1140
+ (void)C;
1141
+ (void)ldc;
1142
+ (void)ith;
1143
+ (void)nth;
1144
+ (void)task;
1145
+ (void)Atype;
1146
+ (void)Btype;
1147
+ (void)Ctype;
1148
+ }
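+
+ // A minimal sketch of how a caller might use this entry point; the
+ // fallback routine named here is hypothetical and not part of this file:
+ //
+ //     if (!llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
+ //                          ith, nth, GGML_TASK_TYPE_COMPUTE,
+ //                          GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32))
+ //         generic_matmul_fallback(m, n, k, A, lda, B, ldb, C, ldc);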