cui-llama.rn 1.3.4 → 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/cpp/common.cpp +7 -4
- package/cpp/common.h +14 -2
- package/cpp/ggml-alloc.c +0 -1
- package/cpp/ggml-backend-reg.cpp +74 -49
- package/cpp/ggml-cpu-aarch64.cpp +51 -71
- package/cpp/ggml-cpu.c +6 -6
- package/cpp/ggml-cpu.cpp +9 -0
- package/cpp/ggml-impl.h +16 -0
- package/cpp/ggml.c +153 -136
- package/cpp/ggml.h +29 -12
- package/cpp/llama-grammar.cpp +15 -15
- package/cpp/llama-grammar.h +2 -5
- package/cpp/llama-vocab.cpp +5 -1
- package/cpp/llama-vocab.h +1 -1
- package/cpp/llama.cpp +992 -300
- package/cpp/llama.h +0 -3
- package/cpp/sgemm.cpp +265 -258
- package/cpp/sgemm.h +2 -2
- package/package.json +1 -1
package/cpp/sgemm.cpp
CHANGED
@@ -53,6 +53,8 @@
 #include "ggml-cpu-impl.h"
 #include "ggml-quants.h"
 
+#include <atomic>
+
 #ifdef _MSC_VER
 #define NOINLINE __declspec(noinline)
 #else
@@ -134,6 +136,16 @@ inline __m512 madd(__m512 a, __m512 b, __m512 c) {
     return _mm512_fmadd_ps(a, b, c);
 }
 #endif
+#if defined(__AVX512BF16__)
+template <>
+inline __m512 madd(__m512bh a, __m512bh b, __m512 c) {
+    return _mm512_dpbf16_ps(c, a, b);
+}
+template <>
+inline __m256 madd(__m256bh a, __m256bh b, __m256 c) {
+    return _mm256_dpbf16_ps(c, a, b);
+}
+#endif
 #endif
 
 #if defined(__ARM_FEATURE_FMA)
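For orientation: the new `madd` overloads above wrap the AVX-512 BF16 dot-product instructions. Per 32-bit lane, `_mm512_dpbf16_ps`/`_mm256_dpbf16_ps` multiply two adjacent bf16 pairs and add the results into an f32 accumulator. A minimal scalar sketch of that per-lane behaviour; the helper names `bf16_to_f32` and `dpbf16_lane` are ours, not part of sgemm.cpp:

```cpp
// Scalar reference for what the new bf16 madd() specializations compute per
// 32-bit lane: accumulate the products of two adjacent bf16 pairs into f32.
#include <cstdint>
#include <cstring>
#include <cstdio>

static float bf16_to_f32(uint16_t h) {
    uint32_t bits = (uint32_t)h << 16;   // bf16 is the top 16 bits of an IEEE f32
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// c += a[0]*b[0] + a[1]*b[1], the per-lane effect of the dpbf16 instructions.
static float dpbf16_lane(const uint16_t a[2], const uint16_t b[2], float c) {
    return c + bf16_to_f32(a[0]) * bf16_to_f32(b[0])
             + bf16_to_f32(a[1]) * bf16_to_f32(b[1]);
}

int main() {
    const uint16_t a[2] = {0x3F80, 0x4000};  // 1.0f, 2.0f in bf16
    const uint16_t b[2] = {0x4040, 0x3F80};  // 3.0f, 1.0f in bf16
    std::printf("%f\n", dpbf16_lane(a, b, 0.0f));  // 1*3 + 2*1 = 5
    return 0;
}
```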
@@ -204,6 +216,7 @@ template <> inline float32x4_t load(const float *p) {
     return vld1q_f32(p);
 }
 #if !defined(_MSC_VER)
+// FIXME: this should check for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 template <> inline float16x8_t load(const lm_ggml_fp16_t *p) {
     return vld1q_f16((const float16_t *)p);
 }
@@ -225,6 +238,13 @@ template <> inline __m256 load(const float *p) {
 }
 #endif // __AVX__
 
+#if defined(__AVX2__) || defined(__AVX512F__)
+template <> inline __m256 load(const lm_ggml_bf16_t *p) {
+    return _mm256_castsi256_ps(
+        _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)p)), 16));
+}
+#endif // __AVX2__
+
 #if defined(__F16C__)
 template <> inline __m256 load(const lm_ggml_fp16_t *p) {
     return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
@@ -238,8 +258,27 @@ template <> inline __m512 load(const float *p) {
 template <> inline __m512 load(const lm_ggml_fp16_t *p) {
     return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
 }
+template <> inline __m512 load(const lm_ggml_bf16_t *p) {
+    return _mm512_castsi512_ps(
+        _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)p)), 16));
+}
 #endif // __AVX512F__
 
+#if defined(__AVX512BF16__)
+template <> inline __m512bh load(const lm_ggml_bf16_t *p) {
+    return (__m512bh)_mm512_loadu_ps((const float *)p);
+}
+template <> inline __m256bh load(const lm_ggml_bf16_t *p) {
+    return (__m256bh)_mm256_loadu_ps((const float *)p);
+}
+template <> inline __m512bh load(const float *p) {
+    return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p));
+}
+template <> inline __m256bh load(const float *p) {
+    return _mm512_cvtneps_pbh(_mm512_loadu_ps(p));
+}
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // CONSTANTS
 
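These `load<>` specializations are what let the templated kernel read bf16 storage: on AVX2/AVX-512F the 16 stored bits are widened to f32 by a 16-bit left shift, while AVX512BF16 loads them natively as `__m512bh`/`__m256bh` and converts f32 inputs with `_mm512_cvtne2ps_pbh`/`_mm512_cvtneps_pbh`. A compressed scalar sketch of the same specialization pattern, with hypothetical names (`scalar_load`, `bf16_t`) not taken from the file:

```cpp
// Scalar sketch of the load<V>(const T *) dispatch used above: one templated
// kernel, different storage types widened on load.
#include <cstdint>
#include <cstring>
#include <cstdio>

struct bf16_t { uint16_t bits; };

template <typename V, typename T> V scalar_load(const T *p);

template <> float scalar_load<float, float>(const float *p) { return *p; }

template <> float scalar_load<float, bf16_t>(const bf16_t *p) {
    // Same trick as the AVX2/AVX-512F paths: put the 16 stored bits in the
    // high half of an f32 and reinterpret (the left shift by 16 in SIMD code).
    uint32_t bits = (uint32_t)p->bits << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

template <typename T>
float dot(const T *a, const float *b, int n) {
    float acc = 0.0f;
    for (int i = 0; i < n; ++i)
        acc += scalar_load<float>(a + i) * b[i];   // storage type picked at compile time
    return acc;
}

int main() {
    bf16_t a[2] = {{0x3F80}, {0x4000}};            // 1.0f, 2.0f
    float  b[2] = {3.0f, 0.5f};
    std::printf("%f\n", dot(a, b, 2));             // 1*3 + 2*0.5 = 4
    return 0;
}
```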
@@ -251,199 +290,170 @@ static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // FLOATING POINT MATRIX MULTIPLICATION
 
+template <int M>
+static inline int64_t BLOCK_SIZE(size_t m) {
+    const int64_t NB_BLOC_M = (m + M - 1) / M;
+    return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
+}
+
+static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
+    return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
+}
+
 template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
 class tinyBLAS {
   public:
-    tinyBLAS(int64_t k,
+    tinyBLAS(const lm_ggml_compute_params * params, int64_t k,
              const TA *A, int64_t lda,
              const TB *B, int64_t ldb,
-             TC *C, int64_t ldc
-             […]
-        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+             TC *C, int64_t ldc)
+        : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
     }
 
-    […]
-    […]
+    bool matmul(int64_t m, int64_t n) {
+        if (k % KN != 0)
+            return false;
+        // compute RM for only need tile with size RM&RM-1
+#if VECTOR_REGISTERS == 32
+        if (m % 16 == 0 && (m/16 >= params->nth)) {
+            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+            mnpack<4, 6, 4>(m, n, SIZE_N, 12);
+            return true;
+        }
+        if (m % 8 == 0 ) {
+            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+            mnpack<4, 6, 2>(m, n, SIZE_N, 12);
+            return true;
+        }
+        if (m % 4 == 0) {
+            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+            mnpack<4, 6, 1>(m, n, SIZE_N, 12);
+            return true;
+        }
+#else // VECTOR_REGISTERS == 16
+        if (m % 16 == 0 && (m/16 >= params->nth)) {
+            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 4>(m, n, SIZE_N, 24);
+            return true;
+        }
+        if (m % 8 == 0 ) {
+            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 2>(m, n, SIZE_N, 24);
+            return true;
+        }
+        if (m % 4 == 0) {
+            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 1>(m, n, SIZE_N, 24);
+            return true;
+        }
+#endif
+        return false;
     }
 
   private:
-    […]
-            mc = 4;
-            nc = 5;
-            gemm<4, 5>(m0, m, n0, n);
-            break;
-        case 0x54:
-            mc = 5;
-            nc = 4;
-            gemm<5, 4>(m0, m, n0, n);
-            break;
-        case 0x44:
-            mc = 4;
-            nc = 4;
-            gemm<4, 4>(m0, m, n0, n);
-            break;
-        case 0x53:
-            mc = 5;
-            nc = 3;
-            gemm<5, 3>(m0, m, n0, n);
-            break;
-        case 0x35:
-            mc = 3;
-            nc = 5;
-            gemm<3, 5>(m0, m, n0, n);
-            break;
-        case 0x43:
-            mc = 4;
-            nc = 3;
-            gemm<4, 3>(m0, m, n0, n);
-            break;
-#else
-        case 0x55:
-        case 0x54:
-        case 0x53:
-        case 0x45:
-        case 0x44:
-        case 0x43:
-            mc = 4;
-            nc = 3;
-            gemm<4, 3>(m0, m, n0, n);
-            break;
-        case 0x35:
-#endif
-        case 0x34:
-            mc = 3;
-            nc = 4;
-            gemm<3, 4>(m0, m, n0, n);
-            break;
-        case 0x52:
-            mc = 5;
-            nc = 2;
-            gemm<5, 2>(m0, m, n0, n);
-            break;
-        case 0x33:
-            mc = 3;
-            nc = 3;
-            gemm<3, 3>(m0, m, n0, n);
-            break;
-        case 0x25:
-            mc = 2;
-            nc = 5;
-            gemm<2, 5>(m0, m, n0, n);
-            break;
-        case 0x42:
-            mc = 4;
-            nc = 2;
-            gemm<4, 2>(m0, m, n0, n);
-            break;
-        case 0x24:
-            mc = 2;
-            nc = 4;
-            gemm<2, 4>(m0, m, n0, n);
-            break;
-        case 0x32:
-            mc = 3;
-            nc = 2;
-            gemm<3, 2>(m0, m, n0, n);
-            break;
-        case 0x23:
-            mc = 2;
-            nc = 3;
-            gemm<2, 3>(m0, m, n0, n);
-            break;
-        case 0x51:
-            mc = 5;
-            nc = 1;
-            gemm<5, 1>(m0, m, n0, n);
-            break;
-        case 0x41:
-            mc = 4;
-            nc = 1;
-            gemm<4, 1>(m0, m, n0, n);
-            break;
-        case 0x22:
-            mc = 2;
-            nc = 2;
-            gemm<2, 2>(m0, m, n0, n);
-            break;
-        case 0x15:
-            mc = 1;
-            nc = 5;
-            gemm<1, 5>(m0, m, n0, n);
-            break;
-        case 0x14:
-            mc = 1;
-            nc = 4;
-            gemm<1, 4>(m0, m, n0, n);
-            break;
-        case 0x31:
-            mc = 3;
-            nc = 1;
-            gemm<3, 1>(m0, m, n0, n);
-            break;
-        case 0x13:
-            mc = 1;
-            nc = 3;
-            gemm<1, 3>(m0, m, n0, n);
-            break;
-        case 0x21:
-            mc = 2;
-            nc = 1;
-            gemm<2, 1>(m0, m, n0, n);
-            break;
-        case 0x12:
-            mc = 1;
-            nc = 2;
-            gemm<1, 2>(m0, m, n0, n);
-            break;
-        case 0x11:
-            mc = 1;
-            nc = 1;
-            gemm<1, 1>(m0, m, n0, n);
-            break;
-        default:
-            return;
+    template <int RM, int RN, int BM>
+    inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
+        if (SIZE_N == RN) {
+            return gemm<RM, RN, BM>(m, n, BN);
+        }
+        if constexpr (RN > 1) {
+            return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
+        } else {
+            LM_GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
+            LM_GGML_ASSERT(false); // we have miss something.
         }
-        mp = m0 + (m - m0) / mc * mc;
-        np = n0 + (n - n0) / nc * nc;
-        mnpack(mp, m, n0, np);
-        mnpack(m0, m, np, n);
     }
 
     template <int RM, int RN>
-    […]
-        int64_t
-    […]
-            for (int64_t i = 0; i < RM; ++i)
-    […]
+    inline void gemm_bloc(int64_t ii, int64_t jj) {
+        D Cv[RN][RM] = {};
+        for (int64_t l = 0; l < k; l += KN) {
+            // help compiler for op order.
+            if constexpr (RM <= RN) {
+                V Av[RM];
+                for (int64_t i = 0; i < RM; ++i) {
+                    Av[i] = load<V>(A + lda * (ii + i) + l);
+                }
+                for (int64_t j = 0; j < RN; ++j) {
+                    V Bv = load<V>(B + ldb * (jj + j) + l);
+                    for (int64_t i = 0; i < RM; ++i) {
+                        Cv[j][i] = madd(Av[i], Bv, Cv[j][i]);
+                    }
+                }
+            } else {
+                V Bv[RN];
+                for (int64_t j = 0; j < RN; ++j) {
+                    Bv[j] = load<V>(B + ldb * (jj + j) + l);
+                }
+                for (int64_t i = 0; i < RM; ++i) {
+                    V Av = load<V>(A + lda * (ii + i) + l);
+                    for (int64_t j = 0; j < RN; ++j) {
+                        Cv[j][i] = madd(Av, Bv[j], Cv[j][i]);
+                    }
+                }
+            }
         }
+        for (int64_t j = 0; j < RN; ++j)
+            for (int64_t i = 0; i < RM; ++i)
+                C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
     }
 
+    template <int RM, int RN, int BM>
+    NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
+        static std::atomic<int64_t> current_chunk;
+
+        LM_GGML_ASSERT(m % (RM * BM) == 0);
+        const int64_t ytiles = m / (RM * BM);
+        const int64_t xtiles = (n + RN -1) / RN;
+        const int64_t jj_RN = (xtiles - (xtiles * RN - n));
+
+        // "round" bloc_size to "nearest" BN
+        const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
+        const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
+        const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
+        const int64_t nb_job = ytiles * NB_BN;
+
+        if (params->ith == 0) {
+            LM_GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
+            // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
+            std::atomic_store_explicit(&current_chunk, (int64_t)params->nth, std::memory_order_relaxed);
+        }
+
+        lm_ggml_barrier(params->threadpool);
+
+        int64_t job = params->ith;
+        while (job < nb_job) {
+            const int64_t ii = (job % ytiles) * RM * BM;
+            const int64_t jb = job / ytiles;
+            const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN);
+            const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
+
+            const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
+            const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
+            const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
+
+            for (int64_t bi = 0; bi < BM * RM; bi += RM) {
+                int64_t jj = jj0;
+                for (; jj < jj1; jj += RN) {
+                    gemm_bloc<RM, RN>(ii + bi, jj);
+                }
+                if constexpr (RN > 1) {
+                    for (; jj < jj2; jj += RN - 1) {
+                        gemm_bloc<RM, RN-1>(ii + bi, jj);
+                    }
+                }
+                LM_GGML_ASSERT(jj == jj2);
+            }
+
+            // next step.
+            job = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed);
+        }
+
+        lm_ggml_barrier(params->threadpool);
+        return;
+    }
+
+    const lm_ggml_compute_params * params;
     const TA *const A;
     const TB *const B;
     TC *const C;
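The hunk above also introduces the two partition helpers `BLOCK_SIZE<M>` and `BLOC_POS`, which split a dimension into nearly equal blocks (the first few full-sized, the rest one element smaller). Below is a standalone copy with a small worked driver to make the arithmetic concrete; the `main()` is ours, only the two helpers come from the diff:

```cpp
// BLOCK_SIZE<M>(m) returns a block width that splits m into ceil(m/M) nearly
// equal blocks; BLOC_POS maps a block index to its starting offset when the
// first ibN blocks are full-sized and the rest are one element smaller.
#include <cstdint>
#include <cstdio>

template <int M>
static inline int64_t BLOCK_SIZE(size_t m) {
    const int64_t NB_BLOC_M = (m + M - 1) / M;
    return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
}

static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
    return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
}

int main() {
    // 100 columns, target block width 6 -> 17 blocks: 15 of size 6, then 2 of size 5.
    const int64_t n      = 100;
    const int64_t size_n = BLOCK_SIZE<6>(n);           // 6
    const int64_t nb     = (n + size_n - 1) / size_n;  // 17 blocks
    const int64_t full   = nb - (nb * size_n - n);     // 15 full-sized blocks
    std::printf("size_n=%lld nb=%lld full=%lld\n",
                (long long)size_n, (long long)nb, (long long)full);
    for (int64_t ib = 0; ib <= nb; ++ib)
        std::printf("block %2lld starts at %lld\n",
                    (long long)ib, (long long)BLOC_POS(ib, full, size_n));
    return 0;
}
```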
@@ -451,8 +461,6 @@ class tinyBLAS {
     const int64_t lda;
     const int64_t ldb;
     const int64_t ldc;
-    const int ith;
-    const int nth;
 };
 
 //////////////////////////////////////////////////////////////////////////////////////////
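Removing the per-instance `ith`/`nth` members goes hand in hand with the new `gemm()` above, which reads thread identity from `params` and hands out tile jobs through a shared atomic counter: thread `ith` starts at job `ith`, the first unclaimed job is `nth`, and each thread grabs its next job with a relaxed `fetch_add`. A self-contained sketch of that scheduling pattern, using plain `std::thread` instead of the ggml threadpool and barrier the real code relies on:

```cpp
// Standalone illustration of the job-distribution pattern in the new gemm().
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    const int nth = 4;                       // number of worker threads
    const int nb_job = 13;                   // number of independent tile jobs
    std::atomic<int> current_chunk{nth};     // first unclaimed job
    std::vector<int> owner(nb_job, -1);      // which thread processed each job

    auto worker = [&](int ith) {
        int job = ith;                       // every thread starts at its own index
        while (job < nb_job) {
            owner[job] = ith;                // stand-in for gemm_bloc<RM, RN>(...)
            job = current_chunk.fetch_add(1, std::memory_order_relaxed);
        }
    };

    std::vector<std::thread> threads;
    for (int ith = 0; ith < nth; ++ith) threads.emplace_back(worker, ith);
    for (auto &t : threads) t.join();

    for (int j = 0; j < nb_job; ++j)
        std::printf("job %2d -> thread %d\n", j, owner[j]);
    return 0;
}
```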
@@ -1656,8 +1664,9 @@ class tinyBLAS_PPC {
  * @param Ctype is GGML data type of `C`
  * @return true if this function was able to service the matmul request
  */
-bool llamafile_sgemm(
-[…]
+bool llamafile_sgemm(const struct lm_ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
+                     const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
+                     int64_t ldc, int Atype, int Btype, int Ctype) {
 
     assert(m >= 0);
     assert(n >= 0);
@@ -1665,8 +1674,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     assert(lda >= k);
     assert(ldb >= k);
     assert(ldc >= m);
-    assert(nth > 0);
-    assert(ith < nth);
+    assert(params->nth > 0);
+    assert(params->ith < params->nth);
 
     // only enable sgemm for prompt processing
     if (n < 2)
@@ -1681,37 +1690,25 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
         if (Btype != LM_GGML_TYPE_F32)
             return false;
 #if defined(__AVX512F__)
-        […]
-            return false;
-        tinyBLAS<16, __m512, __m512, float, float, float> tb{
+        tinyBLAS<16, __m512, __m512, float, float, float> tb{ params,
             k, (const float *)A, lda,
             (const float *)B, ldb,
-            (float *)C, ldc
-        […]
-        tb.matmul(m, n);
-        return true;
+            (float *)C, ldc};
+        return tb.matmul(m, n);
 #elif defined(__AVX__) || defined(__AVX2__)
-        […]
-            return false;
-        tinyBLAS<8, __m256, __m256, float, float, float> tb{
+        tinyBLAS<8, __m256, __m256, float, float, float> tb{ params,
             k, (const float *)A, lda,
             (const float *)B, ldb,
-            (float *)C, ldc
-        […]
-        tb.matmul(m, n);
-        return true;
+            (float *)C, ldc};
+        return tb.matmul(m, n);
 #elif defined(__ARM_NEON)
         if (n < 4)
             return false;
-        […]
-            return false;
-        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
+        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
             k, (const float *)A, lda,
             (const float *)B, ldb,
-            (float *)C, ldc
-        […]
-        tb.matmul(m, n);
-        return true;
+            (float *)C, ldc};
+        return tb.matmul(m, n);
 #elif defined(__MMA__)
         if (k % 8)
             return false;
@@ -1719,7 +1716,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const float *)A, lda,
             (const float *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1727,60 +1724,71 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
 #endif
     }
 
+    case LM_GGML_TYPE_BF16: {
+#if defined(__AVX512BF16__)
+        if (Btype == LM_GGML_TYPE_BF16) {
+            tinyBLAS<32, __m512, __m512bh, lm_ggml_bf16_t, lm_ggml_bf16_t, float> tb{ params, k,
+                (const lm_ggml_bf16_t *)A, lda,
+                (const lm_ggml_bf16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
+#elif defined(__AVX512F__)
+        if (Btype == LM_GGML_TYPE_BF16) {
+            tinyBLAS<16, __m512, __m512, lm_ggml_bf16_t, lm_ggml_bf16_t, float> tb{ params, k,
+                (const lm_ggml_bf16_t *)A, lda,
+                (const lm_ggml_bf16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
+#elif defined(__AVX2__)
+        if (Btype == LM_GGML_TYPE_BF16) {
+            tinyBLAS<8, __m256, __m256, lm_ggml_bf16_t, lm_ggml_bf16_t, float> tb{ params, k,
+                (const lm_ggml_bf16_t *)A, lda,
+                (const lm_ggml_bf16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
+#endif
+        return false;
+    }
     case LM_GGML_TYPE_F16: {
 #if defined(__AVX512F__)
-        if (
-        […]
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
+        if (Btype == LM_GGML_TYPE_F16) {
+            tinyBLAS<16, __m512, __m512, lm_ggml_fp16_t, lm_ggml_fp16_t, float> tb{ params, k,
+                (const lm_ggml_fp16_t *)A, lda,
+                (const lm_ggml_fp16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
 #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
-        if (
-        […]
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
+        if (Btype == LM_GGML_TYPE_F16) {
+            tinyBLAS<8, __m256, __m256, lm_ggml_fp16_t, lm_ggml_fp16_t, float> tb{ params, k,
+                (const lm_ggml_fp16_t *)A, lda,
+                (const lm_ggml_fp16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
 #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
         if (n < 8)
             return false;
-        if (
-        […]
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
+        if (Btype == LM_GGML_TYPE_F16) {
+            tinyBLAS<8, float16x8_t, float16x8_t, lm_ggml_fp16_t, lm_ggml_fp16_t, float> tb{ params,
+                k, (const lm_ggml_fp16_t *)A, lda,
+                (const lm_ggml_fp16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
 #elif defined(__ARM_NEON) && !defined(_MSC_VER)
-        if (
-        […]
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#else
-        return false;
+        if (Btype == LM_GGML_TYPE_F32) {
+            tinyBLAS<4, float32x4_t, float32x4_t, lm_ggml_fp16_t, float, float> tb{ params,
+                k, (const lm_ggml_fp16_t *)A, lda,
+                (const float *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
 #endif
+        return false;
     }
 
     case LM_GGML_TYPE_Q8_0: {
@@ -1791,7 +1799,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q8_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
@@ -1799,7 +1807,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q8_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1815,7 +1823,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q4_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
@@ -1823,7 +1831,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q4_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1839,7 +1847,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q5_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1855,7 +1863,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_iq4_nl *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1867,6 +1875,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
         return false;
     }
 
+    (void)params;
     (void)m;
     (void)n;
     (void)k;
@@ -1876,8 +1885,6 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     (void)ldb;
     (void)C;
     (void)ldc;
-    (void)ith;
-    (void)nth;
     (void)Atype;
     (void)Btype;
     (void)Ctype;