llama_cpp 0.14.7 → 0.15.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +19 -0
- data/README.md +2 -2
- data/ext/llama_cpp/extconf.rb +2 -1
- data/ext/llama_cpp/llama_cpp.cpp +59 -9
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +24 -3
- data/vendor/tmp/llama.cpp/Makefile +42 -18
- data/vendor/tmp/llama.cpp/ggml-backend.c +7 -5
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +295 -17
- data/vendor/tmp/llama.cpp/ggml-impl.h +78 -1
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +7 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +399 -184
- data/vendor/tmp/llama.cpp/ggml-metal.metal +654 -18
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +1 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +302 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +28 -16
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +46843 -39205
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +951 -263
- data/vendor/tmp/llama.cpp/ggml.c +1457 -92
- data/vendor/tmp/llama.cpp/ggml.h +37 -7
- data/vendor/tmp/llama.cpp/llama.cpp +671 -403
- data/vendor/tmp/llama.cpp/llama.h +34 -10
- data/vendor/tmp/llama.cpp/sgemm.cpp +134 -103
- data/vendor/tmp/llama.cpp/sgemm.h +4 -2
- data/vendor/tmp/llama.cpp/unicode-data.cpp +1188 -656
- data/vendor/tmp/llama.cpp/unicode-data.h +4 -3
- data/vendor/tmp/llama.cpp/unicode.cpp +590 -49
- data/vendor/tmp/llama.cpp/unicode.h +6 -3
- metadata +3 -3
@@ -40,7 +40,7 @@
|
|
40
40
|
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
|
41
41
|
|
42
42
|
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
43
|
-
#define LLAMA_SESSION_VERSION
|
43
|
+
#define LLAMA_SESSION_VERSION 6
|
44
44
|
|
45
45
|
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
46
46
|
#define LLAMA_STATE_SEQ_VERSION 1
|
@@ -69,6 +69,23 @@ extern "C" {
|
|
69
69
|
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
|
70
70
|
};
|
71
71
|
|
72
|
+
// pre-tokenization types
|
73
|
+
enum llama_vocab_pre_type {
|
74
|
+
LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
|
75
|
+
LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
|
76
|
+
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
|
77
|
+
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
|
78
|
+
LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
|
79
|
+
LLAMA_VOCAB_PRE_TYPE_MPT = 5,
|
80
|
+
LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
|
81
|
+
LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
|
82
|
+
LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
|
83
|
+
LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
|
84
|
+
LLAMA_VOCAB_PRE_TYPE_QWEN2 = 10,
|
85
|
+
LLAMA_VOCAB_PRE_TYPE_OLMO = 11,
|
86
|
+
LLAMA_VOCAB_PRE_TYPE_DBRX = 12,
|
87
|
+
};
|
88
|
+
|
72
89
|
// note: these values should be synchronized with ggml_rope
|
73
90
|
// TODO: maybe move this enum to ggml.h (ggml_rope_type)
|
74
91
|
enum llama_rope_type {
|
@@ -122,6 +139,7 @@ extern "C" {
|
|
122
139
|
LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors
|
123
140
|
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
|
124
141
|
LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
|
142
|
+
LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
|
125
143
|
|
126
144
|
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
127
145
|
};
|
@@ -159,7 +177,7 @@ extern "C" {
|
|
159
177
|
bool sorted;
|
160
178
|
} llama_token_data_array;
|
161
179
|
|
162
|
-
typedef bool (*llama_progress_callback)(float progress, void *
|
180
|
+
typedef bool (*llama_progress_callback)(float progress, void * user_data);
|
163
181
|
|
164
182
|
// Input data for llama_decode
|
165
183
|
// A llama_batch object can contain input about one or many sequences
|
@@ -195,15 +213,19 @@ extern "C" {
|
|
195
213
|
LLAMA_KV_OVERRIDE_TYPE_INT,
|
196
214
|
LLAMA_KV_OVERRIDE_TYPE_FLOAT,
|
197
215
|
LLAMA_KV_OVERRIDE_TYPE_BOOL,
|
216
|
+
LLAMA_KV_OVERRIDE_TYPE_STR,
|
198
217
|
};
|
199
218
|
|
200
219
|
struct llama_model_kv_override {
|
201
|
-
char key[128];
|
202
220
|
enum llama_model_kv_override_type tag;
|
221
|
+
|
222
|
+
char key[128];
|
223
|
+
|
203
224
|
union {
|
204
|
-
int64_t
|
205
|
-
double
|
206
|
-
bool
|
225
|
+
int64_t val_i64;
|
226
|
+
double val_f64;
|
227
|
+
bool val_bool;
|
228
|
+
char val_str[128];
|
207
229
|
};
|
208
230
|
};
|
209
231
|
|
@@ -232,9 +254,10 @@ extern "C" {
|
|
232
254
|
const struct llama_model_kv_override * kv_overrides;
|
233
255
|
|
234
256
|
// Keep the booleans together to avoid misalignment during copy-by-value.
|
235
|
-
bool vocab_only;
|
236
|
-
bool use_mmap;
|
237
|
-
bool use_mlock;
|
257
|
+
bool vocab_only; // only load the vocabulary, no weights
|
258
|
+
bool use_mmap; // use mmap if possible
|
259
|
+
bool use_mlock; // force system to keep model in RAM
|
260
|
+
bool check_tensors; // validate model tensor data
|
238
261
|
};
|
239
262
|
|
240
263
|
struct llama_context_params {
|
@@ -270,6 +293,7 @@ extern "C" {
|
|
270
293
|
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
271
294
|
bool embeddings; // if true, extract embeddings (together with logits)
|
272
295
|
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
296
|
+
bool flash_attn; // whether to use flash attention
|
273
297
|
|
274
298
|
// Abort callback
|
275
299
|
// if it returns true, execution of llama_decode() will be aborted
|
@@ -525,7 +549,7 @@ extern "C" {
|
|
525
549
|
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
526
550
|
LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
|
527
551
|
|
528
|
-
// Clear the KV cache
|
552
|
+
// Clear the KV cache - both cell info is erased and KV data is zeroed
|
529
553
|
LLAMA_API void llama_kv_cache_clear(
|
530
554
|
struct llama_context * ctx);
|
531
555
|
|
@@ -1,6 +1,3 @@
|
|
1
|
-
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
|
2
|
-
// vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
|
3
|
-
//
|
4
1
|
// Copyright 2024 Mozilla Foundation
|
5
2
|
//
|
6
3
|
// Permission is hereby granted, free of charge, to any person obtaining
|
@@ -50,7 +47,6 @@
|
|
50
47
|
#pragma GCC diagnostic ignored "-Wignored-attributes"
|
51
48
|
|
52
49
|
#include "sgemm.h"
|
53
|
-
#include <algorithm>
|
54
50
|
#include "ggml-impl.h"
|
55
51
|
#include "ggml-quants.h"
|
56
52
|
|
@@ -243,23 +239,23 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
|
|
243
239
|
template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
|
244
240
|
class tinyBLAS {
|
245
241
|
public:
|
246
|
-
tinyBLAS(
|
247
|
-
const TA *A,
|
248
|
-
const TB *B,
|
249
|
-
TC *C,
|
242
|
+
tinyBLAS(int64_t k,
|
243
|
+
const TA *A, int64_t lda,
|
244
|
+
const TB *B, int64_t ldb,
|
245
|
+
TC *C, int64_t ldc,
|
250
246
|
int ith, int nth)
|
251
247
|
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
252
248
|
}
|
253
249
|
|
254
|
-
void matmul(
|
250
|
+
void matmul(int64_t m, int64_t n, int task) {
|
255
251
|
if (task == GGML_TASK_TYPE_COMPUTE)
|
256
252
|
mnpack(0, m, 0, n);
|
257
253
|
}
|
258
254
|
|
259
255
|
private:
|
260
|
-
NOINLINE void mnpack(
|
261
|
-
|
262
|
-
switch ((
|
256
|
+
NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
257
|
+
int64_t mc, nc, mp, np;
|
258
|
+
switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
|
263
259
|
#if VECTOR_REGISTERS == 32
|
264
260
|
case 0x55:
|
265
261
|
mc = 5;
|
@@ -409,27 +405,27 @@ class tinyBLAS {
|
|
409
405
|
}
|
410
406
|
|
411
407
|
template <int RM, int RN>
|
412
|
-
NOINLINE void gemm(
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
408
|
+
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
409
|
+
int64_t ytiles = (m - m0) / RM;
|
410
|
+
int64_t xtiles = (n - n0) / RN;
|
411
|
+
int64_t tiles = xtiles * ytiles;
|
412
|
+
int64_t duty = (tiles + nth - 1) / nth;
|
413
|
+
int64_t start = duty * ith;
|
414
|
+
int64_t end = start + duty;
|
419
415
|
if (end > tiles)
|
420
416
|
end = tiles;
|
421
|
-
for (
|
422
|
-
|
423
|
-
|
417
|
+
for (int64_t job = start; job < end; ++job) {
|
418
|
+
int64_t ii = m0 + job / xtiles * RM;
|
419
|
+
int64_t jj = n0 + job % xtiles * RN;
|
424
420
|
D Cv[RN][RM] = {};
|
425
|
-
for (
|
426
|
-
for (
|
427
|
-
for (
|
421
|
+
for (int64_t l = 0; l < k; l += KN)
|
422
|
+
for (int64_t j = 0; j < RN; ++j)
|
423
|
+
for (int64_t i = 0; i < RM; ++i)
|
428
424
|
Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
|
429
425
|
load<V>(B + ldb * (jj + j) + l),
|
430
426
|
Cv[j][i]);
|
431
|
-
for (
|
432
|
-
for (
|
427
|
+
for (int64_t j = 0; j < RN; ++j)
|
428
|
+
for (int64_t i = 0; i < RM; ++i)
|
433
429
|
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
|
434
430
|
}
|
435
431
|
}
|
@@ -437,10 +433,10 @@ class tinyBLAS {
|
|
437
433
|
const TA *const A;
|
438
434
|
const TB *const B;
|
439
435
|
TC *const C;
|
440
|
-
const
|
441
|
-
const
|
442
|
-
const
|
443
|
-
const
|
436
|
+
const int64_t k;
|
437
|
+
const int64_t lda;
|
438
|
+
const int64_t ldb;
|
439
|
+
const int64_t ldc;
|
444
440
|
const int ith;
|
445
441
|
const int nth;
|
446
442
|
};
|
@@ -452,23 +448,23 @@ class tinyBLAS {
|
|
452
448
|
template <typename TA>
|
453
449
|
class tinyBLAS_Q0_ARM {
|
454
450
|
public:
|
455
|
-
tinyBLAS_Q0_ARM(
|
456
|
-
const TA *A,
|
457
|
-
const block_q8_0 *B,
|
458
|
-
float *C,
|
451
|
+
tinyBLAS_Q0_ARM(int64_t k,
|
452
|
+
const TA *A, int64_t lda,
|
453
|
+
const block_q8_0 *B, int64_t ldb,
|
454
|
+
float *C, int64_t ldc,
|
459
455
|
int ith, int nth)
|
460
456
|
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
461
457
|
}
|
462
458
|
|
463
|
-
void matmul(
|
459
|
+
void matmul(int64_t m, int64_t n, int task) {
|
464
460
|
if (task == GGML_TASK_TYPE_COMPUTE)
|
465
461
|
mnpack(0, m, 0, n);
|
466
462
|
}
|
467
463
|
|
468
464
|
private:
|
469
|
-
NOINLINE void mnpack(
|
470
|
-
|
471
|
-
switch ((
|
465
|
+
NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
466
|
+
int64_t mc, nc, mp, np;
|
467
|
+
switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
|
472
468
|
case 0x33:
|
473
469
|
mc = 3;
|
474
470
|
nc = 3;
|
@@ -524,22 +520,22 @@ class tinyBLAS_Q0_ARM {
|
|
524
520
|
}
|
525
521
|
|
526
522
|
template <int RM, int RN>
|
527
|
-
NOINLINE void gemm(
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
523
|
+
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
524
|
+
int64_t ytiles = (m - m0) / RM;
|
525
|
+
int64_t xtiles = (n - n0) / RN;
|
526
|
+
int64_t tiles = xtiles * ytiles;
|
527
|
+
int64_t duty = (tiles + nth - 1) / nth;
|
528
|
+
int64_t start = duty * ith;
|
529
|
+
int64_t end = start + duty;
|
534
530
|
if (end > tiles)
|
535
531
|
end = tiles;
|
536
|
-
for (
|
537
|
-
|
538
|
-
|
532
|
+
for (int64_t job = start; job < end; ++job) {
|
533
|
+
int64_t ii = m0 + job / xtiles * RM;
|
534
|
+
int64_t jj = n0 + job % xtiles * RN;
|
539
535
|
float32x4_t Cv[RN][RM] = {};
|
540
|
-
for (
|
541
|
-
for (
|
542
|
-
for (
|
536
|
+
for (int64_t l = 0; l < k; ++l)
|
537
|
+
for (int64_t j = 0; j < RN; ++j)
|
538
|
+
for (int64_t i = 0; i < RM; ++i)
|
543
539
|
Cv[j][i] = vmlaq_n_f32(Cv[j][i],
|
544
540
|
vcvtq_f32_s32(vdotq_s32(
|
545
541
|
vdotq_s32(vdupq_n_s32(0),
|
@@ -549,8 +545,8 @@ class tinyBLAS_Q0_ARM {
|
|
549
545
|
load_hi(B + ldb * (jj + j) + l))),
|
550
546
|
unhalf(A[lda * (ii + i) + l].d) *
|
551
547
|
unhalf(B[ldb * (jj + j) + l].d));
|
552
|
-
for (
|
553
|
-
for (
|
548
|
+
for (int64_t j = 0; j < RN; ++j)
|
549
|
+
for (int64_t i = 0; i < RM; ++i)
|
554
550
|
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
|
555
551
|
}
|
556
552
|
}
|
@@ -577,36 +573,36 @@ class tinyBLAS_Q0_ARM {
|
|
577
573
|
const TA *const A;
|
578
574
|
const block_q8_0 *const B;
|
579
575
|
float *const C;
|
580
|
-
const
|
581
|
-
const
|
582
|
-
const
|
583
|
-
const
|
576
|
+
const int64_t k;
|
577
|
+
const int64_t lda;
|
578
|
+
const int64_t ldb;
|
579
|
+
const int64_t ldc;
|
584
580
|
const int ith;
|
585
581
|
const int nth;
|
586
582
|
};
|
587
583
|
#endif // __ARM_FEATURE_DOTPROD
|
588
584
|
|
589
|
-
#if defined(__AVX2__) || defined(__AVX512F__)
|
585
|
+
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
|
590
586
|
template <typename TA, typename TB, typename TC>
|
591
|
-
class
|
587
|
+
class tinyBLAS_Q0_AVX {
|
592
588
|
public:
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
589
|
+
tinyBLAS_Q0_AVX(int64_t k,
|
590
|
+
const TA *A, int64_t lda,
|
591
|
+
const TB *B, int64_t ldb,
|
592
|
+
TC *C, int64_t ldc,
|
593
|
+
int ith, int nth)
|
598
594
|
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
599
595
|
}
|
600
596
|
|
601
|
-
void matmul(
|
597
|
+
void matmul(int64_t m, int64_t n, int task) {
|
602
598
|
if (task == GGML_TASK_TYPE_COMPUTE)
|
603
599
|
mnpack(0, m, 0, n);
|
604
600
|
}
|
605
601
|
|
606
602
|
private:
|
607
|
-
void mnpack(
|
608
|
-
|
609
|
-
switch ((
|
603
|
+
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
604
|
+
int64_t mc, nc, mp, np;
|
605
|
+
switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
|
610
606
|
#if VECTOR_REGISTERS == 32
|
611
607
|
case 0x44:
|
612
608
|
mc = 4;
|
@@ -714,31 +710,51 @@ class tinyBLAS_Q0_AVX2 {
|
|
714
710
|
}
|
715
711
|
|
716
712
|
template <int RM, int RN>
|
717
|
-
NOINLINE void gemm(
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
713
|
+
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
714
|
+
int64_t ytiles = (m - m0) / RM;
|
715
|
+
int64_t xtiles = (n - n0) / RN;
|
716
|
+
int64_t tiles = xtiles * ytiles;
|
717
|
+
int64_t duty = (tiles + nth - 1) / nth;
|
718
|
+
int64_t start = duty * ith;
|
719
|
+
int64_t end = start + duty;
|
724
720
|
if (end > tiles)
|
725
721
|
end = tiles;
|
726
|
-
for (
|
727
|
-
|
728
|
-
|
722
|
+
for (int64_t job = start; job < end; ++job) {
|
723
|
+
int64_t ii = m0 + job / xtiles * RM;
|
724
|
+
int64_t jj = n0 + job % xtiles * RN;
|
729
725
|
__m256 Cv[RN][RM] = {};
|
730
|
-
for (
|
731
|
-
for (
|
732
|
-
for (
|
726
|
+
for (int64_t l = 0; l < k; ++l)
|
727
|
+
for (int64_t j = 0; j < RN; ++j)
|
728
|
+
for (int64_t i = 0; i < RM; ++i) {
|
729
|
+
#if defined(__AVX2__)
|
730
|
+
__m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
731
|
+
load(A + lda * (ii + i) + l)),
|
732
|
+
_mm256_sign_epi8(load(B + ldb * (jj + j) + l),
|
733
|
+
load(A + lda * (ii + i) + l)));
|
734
|
+
#else
|
735
|
+
__m128i ali0 = load0(A + lda * (ii + i) + l);
|
736
|
+
__m128i ali1 = load1(A + lda * (ii + i) + l);
|
737
|
+
__m128i blj0 = load0(B + ldb * (jj + j) + l);
|
738
|
+
__m128i blj1 = load1(B + ldb * (jj + j) + l);
|
739
|
+
|
740
|
+
__m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
|
741
|
+
__m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
|
742
|
+
__m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
|
743
|
+
__m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
|
744
|
+
|
745
|
+
// updot
|
746
|
+
const __m128i oneFill = _mm_set1_epi16(1);
|
747
|
+
__m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
|
748
|
+
__m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
|
749
|
+
__m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
|
750
|
+
#endif
|
733
751
|
Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
|
734
752
|
unhalf(B[ldb * (jj + j) + l].d)),
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
for (int j = 0; j < RN; ++j)
|
741
|
-
for (int i = 0; i < RM; ++i)
|
753
|
+
udTmp,
|
754
|
+
Cv[j][i]);
|
755
|
+
}
|
756
|
+
for (int64_t j = 0; j < RN; ++j)
|
757
|
+
for (int64_t i = 0; i < RM; ++i)
|
742
758
|
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
|
743
759
|
}
|
744
760
|
}
|
@@ -747,10 +763,28 @@ class tinyBLAS_Q0_AVX2 {
|
|
747
763
|
return _mm256_loadu_si256((const __m256i *)b->qs);
|
748
764
|
}
|
749
765
|
|
766
|
+
inline __m128i load0(const block_q8_0 *b) {
|
767
|
+
return _mm_loadu_si128((const __m128i *)b->qs);
|
768
|
+
}
|
769
|
+
|
770
|
+
inline __m128i load1(const block_q8_0 *b) {
|
771
|
+
return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
|
772
|
+
}
|
773
|
+
|
750
774
|
inline __m256i load(const block_q4_0 *b) {
|
751
775
|
return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
|
752
776
|
}
|
753
777
|
|
778
|
+
inline __m128i load0(const block_q4_0 *b) {
|
779
|
+
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
|
780
|
+
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
|
781
|
+
}
|
782
|
+
|
783
|
+
inline __m128i load1(const block_q4_0 *b) {
|
784
|
+
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
|
785
|
+
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
|
786
|
+
}
|
787
|
+
|
754
788
|
inline __m256 updot(__m256i u, __m256i s) {
|
755
789
|
__m256i res;
|
756
790
|
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
|
@@ -771,14 +805,14 @@ class tinyBLAS_Q0_AVX2 {
|
|
771
805
|
const TA *const A;
|
772
806
|
const TB *const B;
|
773
807
|
TC *const C;
|
774
|
-
const
|
775
|
-
const
|
776
|
-
const
|
777
|
-
const
|
808
|
+
const int64_t k;
|
809
|
+
const int64_t lda;
|
810
|
+
const int64_t ldb;
|
811
|
+
const int64_t ldc;
|
778
812
|
const int ith;
|
779
813
|
const int nth;
|
780
814
|
};
|
781
|
-
#endif //
|
815
|
+
#endif // __AVX__
|
782
816
|
|
783
817
|
} // namespace
|
784
818
|
|
@@ -813,8 +847,8 @@ class tinyBLAS_Q0_AVX2 {
|
|
813
847
|
* @param Ctype is GGML data type of `C`
|
814
848
|
* @return true if this function was able to service the matmul request
|
815
849
|
*/
|
816
|
-
bool llamafile_sgemm(
|
817
|
-
|
850
|
+
bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
|
851
|
+
int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
|
818
852
|
|
819
853
|
assert(m >= 0);
|
820
854
|
assert(n >= 0);
|
@@ -824,9 +858,6 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
|
|
824
858
|
assert(ldc >= m);
|
825
859
|
assert(nth > 0);
|
826
860
|
assert(ith < nth);
|
827
|
-
assert(1ll * lda * m <= 0x7fffffff);
|
828
|
-
assert(1ll * ldb * n <= 0x7fffffff);
|
829
|
-
assert(1ll * ldc * n <= 0x7fffffff);
|
830
861
|
|
831
862
|
if (Ctype != GGML_TYPE_F32)
|
832
863
|
return false;
|
@@ -932,8 +963,8 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
|
|
932
963
|
case GGML_TYPE_Q8_0: {
|
933
964
|
if (Btype != GGML_TYPE_Q8_0)
|
934
965
|
return false;
|
935
|
-
#if defined(__AVX2__) || defined(__AVX512F__)
|
936
|
-
|
966
|
+
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
|
967
|
+
tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
|
937
968
|
k, (const block_q8_0 *)A, lda,
|
938
969
|
(const block_q8_0 *)B, ldb,
|
939
970
|
(float *)C, ldc,
|
@@ -956,8 +987,8 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
|
|
956
987
|
case GGML_TYPE_Q4_0: {
|
957
988
|
if (Btype != GGML_TYPE_Q8_0)
|
958
989
|
return false;
|
959
|
-
#if defined(__AVX2__) || defined(__AVX512F__)
|
960
|
-
|
990
|
+
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
|
991
|
+
tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
|
961
992
|
k, (const block_q4_0 *)A, lda,
|
962
993
|
(const block_q8_0 *)B, ldb,
|
963
994
|
(float *)C, ldc,
|
@@ -1,11 +1,13 @@
|
|
1
1
|
#pragma once
|
2
|
+
#include <stdint.h>
|
2
3
|
#include <stdbool.h>
|
3
4
|
#ifdef __cplusplus
|
4
5
|
extern "C" {
|
5
6
|
#endif
|
6
7
|
|
7
|
-
bool llamafile_sgemm(
|
8
|
-
void *,
|
8
|
+
bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
|
9
|
+
const void *, int64_t, void *, int64_t, int, int,
|
10
|
+
int, int, int, int);
|
9
11
|
|
10
12
|
#ifdef __cplusplus
|
11
13
|
}
|