llama_cpp 0.14.7 → 0.15.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +19 -0
- data/README.md +2 -2
- data/ext/llama_cpp/extconf.rb +2 -1
- data/ext/llama_cpp/llama_cpp.cpp +59 -9
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +24 -3
- data/vendor/tmp/llama.cpp/Makefile +42 -18
- data/vendor/tmp/llama.cpp/ggml-backend.c +7 -5
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +295 -17
- data/vendor/tmp/llama.cpp/ggml-impl.h +78 -1
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +7 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +399 -184
- data/vendor/tmp/llama.cpp/ggml-metal.metal +654 -18
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +1 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +302 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +28 -16
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +46843 -39205
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +951 -263
- data/vendor/tmp/llama.cpp/ggml.c +1457 -92
- data/vendor/tmp/llama.cpp/ggml.h +37 -7
- data/vendor/tmp/llama.cpp/llama.cpp +671 -403
- data/vendor/tmp/llama.cpp/llama.h +34 -10
- data/vendor/tmp/llama.cpp/sgemm.cpp +134 -103
- data/vendor/tmp/llama.cpp/sgemm.h +4 -2
- data/vendor/tmp/llama.cpp/unicode-data.cpp +1188 -656
- data/vendor/tmp/llama.cpp/unicode-data.h +4 -3
- data/vendor/tmp/llama.cpp/unicode.cpp +590 -49
- data/vendor/tmp/llama.cpp/unicode.h +6 -3
- metadata +3 -3
@@ -40,7 +40,7 @@
|
|
40
40
|
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
|
41
41
|
|
42
42
|
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
43
|
-
#define LLAMA_SESSION_VERSION
|
43
|
+
#define LLAMA_SESSION_VERSION 6
|
44
44
|
|
45
45
|
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
46
46
|
#define LLAMA_STATE_SEQ_VERSION 1
|
@@ -69,6 +69,23 @@ extern "C" {
|
|
69
69
|
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
|
70
70
|
};
|
71
71
|
|
72
|
+
// pre-tokenization types
|
73
|
+
enum llama_vocab_pre_type {
|
74
|
+
LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
|
75
|
+
LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
|
76
|
+
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
|
77
|
+
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
|
78
|
+
LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
|
79
|
+
LLAMA_VOCAB_PRE_TYPE_MPT = 5,
|
80
|
+
LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
|
81
|
+
LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
|
82
|
+
LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
|
83
|
+
LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
|
84
|
+
LLAMA_VOCAB_PRE_TYPE_QWEN2 = 10,
|
85
|
+
LLAMA_VOCAB_PRE_TYPE_OLMO = 11,
|
86
|
+
LLAMA_VOCAB_PRE_TYPE_DBRX = 12,
|
87
|
+
};
|
88
|
+
|
72
89
|
// note: these values should be synchronized with ggml_rope
|
73
90
|
// TODO: maybe move this enum to ggml.h (ggml_rope_type)
|
74
91
|
enum llama_rope_type {
|
@@ -122,6 +139,7 @@ extern "C" {
|
|
122
139
|
LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors
|
123
140
|
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
|
124
141
|
LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
|
142
|
+
LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
|
125
143
|
|
126
144
|
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
127
145
|
};
|
@@ -159,7 +177,7 @@ extern "C" {
|
|
159
177
|
bool sorted;
|
160
178
|
} llama_token_data_array;
|
161
179
|
|
162
|
-
typedef bool (*llama_progress_callback)(float progress, void *
|
180
|
+
typedef bool (*llama_progress_callback)(float progress, void * user_data);
|
163
181
|
|
164
182
|
// Input data for llama_decode
|
165
183
|
// A llama_batch object can contain input about one or many sequences
|
@@ -195,15 +213,19 @@ extern "C" {
|
|
195
213
|
LLAMA_KV_OVERRIDE_TYPE_INT,
|
196
214
|
LLAMA_KV_OVERRIDE_TYPE_FLOAT,
|
197
215
|
LLAMA_KV_OVERRIDE_TYPE_BOOL,
|
216
|
+
LLAMA_KV_OVERRIDE_TYPE_STR,
|
198
217
|
};
|
199
218
|
|
200
219
|
struct llama_model_kv_override {
|
201
|
-
char key[128];
|
202
220
|
enum llama_model_kv_override_type tag;
|
221
|
+
|
222
|
+
char key[128];
|
223
|
+
|
203
224
|
union {
|
204
|
-
int64_t
|
205
|
-
double
|
206
|
-
bool
|
225
|
+
int64_t val_i64;
|
226
|
+
double val_f64;
|
227
|
+
bool val_bool;
|
228
|
+
char val_str[128];
|
207
229
|
};
|
208
230
|
};
|
209
231
|
|
@@ -232,9 +254,10 @@ extern "C" {
|
|
232
254
|
const struct llama_model_kv_override * kv_overrides;
|
233
255
|
|
234
256
|
// Keep the booleans together to avoid misalignment during copy-by-value.
|
235
|
-
bool vocab_only;
|
236
|
-
bool use_mmap;
|
237
|
-
bool use_mlock;
|
257
|
+
bool vocab_only; // only load the vocabulary, no weights
|
258
|
+
bool use_mmap; // use mmap if possible
|
259
|
+
bool use_mlock; // force system to keep model in RAM
|
260
|
+
bool check_tensors; // validate model tensor data
|
238
261
|
};
|
239
262
|
|
240
263
|
struct llama_context_params {
|
@@ -270,6 +293,7 @@ extern "C" {
|
|
270
293
|
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
271
294
|
bool embeddings; // if true, extract embeddings (together with logits)
|
272
295
|
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
296
|
+
bool flash_attn; // whether to use flash attention
|
273
297
|
|
274
298
|
// Abort callback
|
275
299
|
// if it returns true, execution of llama_decode() will be aborted
|
@@ -525,7 +549,7 @@ extern "C" {
|
|
525
549
|
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
526
550
|
LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
|
527
551
|
|
528
|
-
// Clear the KV cache
|
552
|
+
// Clear the KV cache - both cell info is erased and KV data is zeroed
|
529
553
|
LLAMA_API void llama_kv_cache_clear(
|
530
554
|
struct llama_context * ctx);
|
531
555
|
|
@@ -1,6 +1,3 @@
|
|
1
|
-
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
|
2
|
-
// vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
|
3
|
-
//
|
4
1
|
// Copyright 2024 Mozilla Foundation
|
5
2
|
//
|
6
3
|
// Permission is hereby granted, free of charge, to any person obtaining
|
@@ -50,7 +47,6 @@
|
|
50
47
|
#pragma GCC diagnostic ignored "-Wignored-attributes"
|
51
48
|
|
52
49
|
#include "sgemm.h"
|
53
|
-
#include <algorithm>
|
54
50
|
#include "ggml-impl.h"
|
55
51
|
#include "ggml-quants.h"
|
56
52
|
|
@@ -243,23 +239,23 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
|
|
243
239
|
template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
|
244
240
|
class tinyBLAS {
|
245
241
|
public:
|
246
|
-
tinyBLAS(
|
247
|
-
const TA *A,
|
248
|
-
const TB *B,
|
249
|
-
TC *C,
|
242
|
+
tinyBLAS(int64_t k,
|
243
|
+
const TA *A, int64_t lda,
|
244
|
+
const TB *B, int64_t ldb,
|
245
|
+
TC *C, int64_t ldc,
|
250
246
|
int ith, int nth)
|
251
247
|
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
252
248
|
}
|
253
249
|
|
254
|
-
void matmul(
|
250
|
+
void matmul(int64_t m, int64_t n, int task) {
|
255
251
|
if (task == GGML_TASK_TYPE_COMPUTE)
|
256
252
|
mnpack(0, m, 0, n);
|
257
253
|
}
|
258
254
|
|
259
255
|
private:
|
260
|
-
NOINLINE void mnpack(
|
261
|
-
|
262
|
-
switch ((
|
256
|
+
NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
257
|
+
int64_t mc, nc, mp, np;
|
258
|
+
switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
|
263
259
|
#if VECTOR_REGISTERS == 32
|
264
260
|
case 0x55:
|
265
261
|
mc = 5;
|
@@ -409,27 +405,27 @@ class tinyBLAS {
|
|
409
405
|
}
|
410
406
|
|
411
407
|
template <int RM, int RN>
|
412
|
-
NOINLINE void gemm(
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
408
|
+
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
409
|
+
int64_t ytiles = (m - m0) / RM;
|
410
|
+
int64_t xtiles = (n - n0) / RN;
|
411
|
+
int64_t tiles = xtiles * ytiles;
|
412
|
+
int64_t duty = (tiles + nth - 1) / nth;
|
413
|
+
int64_t start = duty * ith;
|
414
|
+
int64_t end = start + duty;
|
419
415
|
if (end > tiles)
|
420
416
|
end = tiles;
|
421
|
-
for (
|
422
|
-
|
423
|
-
|
417
|
+
for (int64_t job = start; job < end; ++job) {
|
418
|
+
int64_t ii = m0 + job / xtiles * RM;
|
419
|
+
int64_t jj = n0 + job % xtiles * RN;
|
424
420
|
D Cv[RN][RM] = {};
|
425
|
-
for (
|
426
|
-
for (
|
427
|
-
for (
|
421
|
+
for (int64_t l = 0; l < k; l += KN)
|
422
|
+
for (int64_t j = 0; j < RN; ++j)
|
423
|
+
for (int64_t i = 0; i < RM; ++i)
|
428
424
|
Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
|
429
425
|
load<V>(B + ldb * (jj + j) + l),
|
430
426
|
Cv[j][i]);
|
431
|
-
for (
|
432
|
-
for (
|
427
|
+
for (int64_t j = 0; j < RN; ++j)
|
428
|
+
for (int64_t i = 0; i < RM; ++i)
|
433
429
|
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
|
434
430
|
}
|
435
431
|
}
|
@@ -437,10 +433,10 @@ class tinyBLAS {
|
|
437
433
|
const TA *const A;
|
438
434
|
const TB *const B;
|
439
435
|
TC *const C;
|
440
|
-
const
|
441
|
-
const
|
442
|
-
const
|
443
|
-
const
|
436
|
+
const int64_t k;
|
437
|
+
const int64_t lda;
|
438
|
+
const int64_t ldb;
|
439
|
+
const int64_t ldc;
|
444
440
|
const int ith;
|
445
441
|
const int nth;
|
446
442
|
};
|
@@ -452,23 +448,23 @@ class tinyBLAS {
|
|
452
448
|
template <typename TA>
|
453
449
|
class tinyBLAS_Q0_ARM {
|
454
450
|
public:
|
455
|
-
tinyBLAS_Q0_ARM(
|
456
|
-
const TA *A,
|
457
|
-
const block_q8_0 *B,
|
458
|
-
float *C,
|
451
|
+
tinyBLAS_Q0_ARM(int64_t k,
|
452
|
+
const TA *A, int64_t lda,
|
453
|
+
const block_q8_0 *B, int64_t ldb,
|
454
|
+
float *C, int64_t ldc,
|
459
455
|
int ith, int nth)
|
460
456
|
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
461
457
|
}
|
462
458
|
|
463
|
-
void matmul(
|
459
|
+
void matmul(int64_t m, int64_t n, int task) {
|
464
460
|
if (task == GGML_TASK_TYPE_COMPUTE)
|
465
461
|
mnpack(0, m, 0, n);
|
466
462
|
}
|
467
463
|
|
468
464
|
private:
|
469
|
-
NOINLINE void mnpack(
|
470
|
-
|
471
|
-
switch ((
|
465
|
+
NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
466
|
+
int64_t mc, nc, mp, np;
|
467
|
+
switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
|
472
468
|
case 0x33:
|
473
469
|
mc = 3;
|
474
470
|
nc = 3;
|
@@ -524,22 +520,22 @@ class tinyBLAS_Q0_ARM {
|
|
524
520
|
}
|
525
521
|
|
526
522
|
template <int RM, int RN>
|
527
|
-
NOINLINE void gemm(
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
523
|
+
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
524
|
+
int64_t ytiles = (m - m0) / RM;
|
525
|
+
int64_t xtiles = (n - n0) / RN;
|
526
|
+
int64_t tiles = xtiles * ytiles;
|
527
|
+
int64_t duty = (tiles + nth - 1) / nth;
|
528
|
+
int64_t start = duty * ith;
|
529
|
+
int64_t end = start + duty;
|
534
530
|
if (end > tiles)
|
535
531
|
end = tiles;
|
536
|
-
for (
|
537
|
-
|
538
|
-
|
532
|
+
for (int64_t job = start; job < end; ++job) {
|
533
|
+
int64_t ii = m0 + job / xtiles * RM;
|
534
|
+
int64_t jj = n0 + job % xtiles * RN;
|
539
535
|
float32x4_t Cv[RN][RM] = {};
|
540
|
-
for (
|
541
|
-
for (
|
542
|
-
for (
|
536
|
+
for (int64_t l = 0; l < k; ++l)
|
537
|
+
for (int64_t j = 0; j < RN; ++j)
|
538
|
+
for (int64_t i = 0; i < RM; ++i)
|
543
539
|
Cv[j][i] = vmlaq_n_f32(Cv[j][i],
|
544
540
|
vcvtq_f32_s32(vdotq_s32(
|
545
541
|
vdotq_s32(vdupq_n_s32(0),
|
@@ -549,8 +545,8 @@ class tinyBLAS_Q0_ARM {
|
|
549
545
|
load_hi(B + ldb * (jj + j) + l))),
|
550
546
|
unhalf(A[lda * (ii + i) + l].d) *
|
551
547
|
unhalf(B[ldb * (jj + j) + l].d));
|
552
|
-
for (
|
553
|
-
for (
|
548
|
+
for (int64_t j = 0; j < RN; ++j)
|
549
|
+
for (int64_t i = 0; i < RM; ++i)
|
554
550
|
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
|
555
551
|
}
|
556
552
|
}
|
@@ -577,36 +573,36 @@ class tinyBLAS_Q0_ARM {
|
|
577
573
|
const TA *const A;
|
578
574
|
const block_q8_0 *const B;
|
579
575
|
float *const C;
|
580
|
-
const
|
581
|
-
const
|
582
|
-
const
|
583
|
-
const
|
576
|
+
const int64_t k;
|
577
|
+
const int64_t lda;
|
578
|
+
const int64_t ldb;
|
579
|
+
const int64_t ldc;
|
584
580
|
const int ith;
|
585
581
|
const int nth;
|
586
582
|
};
|
587
583
|
#endif // __ARM_FEATURE_DOTPROD
|
588
584
|
|
589
|
-
#if defined(__AVX2__) || defined(__AVX512F__)
|
585
|
+
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
|
590
586
|
template <typename TA, typename TB, typename TC>
|
591
|
-
class
|
587
|
+
class tinyBLAS_Q0_AVX {
|
592
588
|
public:
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
589
|
+
tinyBLAS_Q0_AVX(int64_t k,
|
590
|
+
const TA *A, int64_t lda,
|
591
|
+
const TB *B, int64_t ldb,
|
592
|
+
TC *C, int64_t ldc,
|
593
|
+
int ith, int nth)
|
598
594
|
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
599
595
|
}
|
600
596
|
|
601
|
-
void matmul(
|
597
|
+
void matmul(int64_t m, int64_t n, int task) {
|
602
598
|
if (task == GGML_TASK_TYPE_COMPUTE)
|
603
599
|
mnpack(0, m, 0, n);
|
604
600
|
}
|
605
601
|
|
606
602
|
private:
|
607
|
-
void mnpack(
|
608
|
-
|
609
|
-
switch ((
|
603
|
+
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
604
|
+
int64_t mc, nc, mp, np;
|
605
|
+
switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
|
610
606
|
#if VECTOR_REGISTERS == 32
|
611
607
|
case 0x44:
|
612
608
|
mc = 4;
|
@@ -714,31 +710,51 @@ class tinyBLAS_Q0_AVX2 {
|
|
714
710
|
}
|
715
711
|
|
716
712
|
template <int RM, int RN>
|
717
|
-
NOINLINE void gemm(
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
713
|
+
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
714
|
+
int64_t ytiles = (m - m0) / RM;
|
715
|
+
int64_t xtiles = (n - n0) / RN;
|
716
|
+
int64_t tiles = xtiles * ytiles;
|
717
|
+
int64_t duty = (tiles + nth - 1) / nth;
|
718
|
+
int64_t start = duty * ith;
|
719
|
+
int64_t end = start + duty;
|
724
720
|
if (end > tiles)
|
725
721
|
end = tiles;
|
726
|
-
for (
|
727
|
-
|
728
|
-
|
722
|
+
for (int64_t job = start; job < end; ++job) {
|
723
|
+
int64_t ii = m0 + job / xtiles * RM;
|
724
|
+
int64_t jj = n0 + job % xtiles * RN;
|
729
725
|
__m256 Cv[RN][RM] = {};
|
730
|
-
for (
|
731
|
-
for (
|
732
|
-
for (
|
726
|
+
for (int64_t l = 0; l < k; ++l)
|
727
|
+
for (int64_t j = 0; j < RN; ++j)
|
728
|
+
for (int64_t i = 0; i < RM; ++i) {
|
729
|
+
#if defined(__AVX2__)
|
730
|
+
__m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
731
|
+
load(A + lda * (ii + i) + l)),
|
732
|
+
_mm256_sign_epi8(load(B + ldb * (jj + j) + l),
|
733
|
+
load(A + lda * (ii + i) + l)));
|
734
|
+
#else
|
735
|
+
__m128i ali0 = load0(A + lda * (ii + i) + l);
|
736
|
+
__m128i ali1 = load1(A + lda * (ii + i) + l);
|
737
|
+
__m128i blj0 = load0(B + ldb * (jj + j) + l);
|
738
|
+
__m128i blj1 = load1(B + ldb * (jj + j) + l);
|
739
|
+
|
740
|
+
__m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
|
741
|
+
__m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
|
742
|
+
__m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
|
743
|
+
__m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
|
744
|
+
|
745
|
+
// updot
|
746
|
+
const __m128i oneFill = _mm_set1_epi16(1);
|
747
|
+
__m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
|
748
|
+
__m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
|
749
|
+
__m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
|
750
|
+
#endif
|
733
751
|
Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
|
734
752
|
unhalf(B[ldb * (jj + j) + l].d)),
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
for (int j = 0; j < RN; ++j)
|
741
|
-
for (int i = 0; i < RM; ++i)
|
753
|
+
udTmp,
|
754
|
+
Cv[j][i]);
|
755
|
+
}
|
756
|
+
for (int64_t j = 0; j < RN; ++j)
|
757
|
+
for (int64_t i = 0; i < RM; ++i)
|
742
758
|
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
|
743
759
|
}
|
744
760
|
}
|
@@ -747,10 +763,28 @@ class tinyBLAS_Q0_AVX2 {
|
|
747
763
|
return _mm256_loadu_si256((const __m256i *)b->qs);
|
748
764
|
}
|
749
765
|
|
766
|
+
inline __m128i load0(const block_q8_0 *b) {
|
767
|
+
return _mm_loadu_si128((const __m128i *)b->qs);
|
768
|
+
}
|
769
|
+
|
770
|
+
inline __m128i load1(const block_q8_0 *b) {
|
771
|
+
return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
|
772
|
+
}
|
773
|
+
|
750
774
|
inline __m256i load(const block_q4_0 *b) {
|
751
775
|
return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
|
752
776
|
}
|
753
777
|
|
778
|
+
inline __m128i load0(const block_q4_0 *b) {
|
779
|
+
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
|
780
|
+
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
|
781
|
+
}
|
782
|
+
|
783
|
+
inline __m128i load1(const block_q4_0 *b) {
|
784
|
+
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
|
785
|
+
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
|
786
|
+
}
|
787
|
+
|
754
788
|
inline __m256 updot(__m256i u, __m256i s) {
|
755
789
|
__m256i res;
|
756
790
|
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
|
@@ -771,14 +805,14 @@ class tinyBLAS_Q0_AVX2 {
|
|
771
805
|
const TA *const A;
|
772
806
|
const TB *const B;
|
773
807
|
TC *const C;
|
774
|
-
const
|
775
|
-
const
|
776
|
-
const
|
777
|
-
const
|
808
|
+
const int64_t k;
|
809
|
+
const int64_t lda;
|
810
|
+
const int64_t ldb;
|
811
|
+
const int64_t ldc;
|
778
812
|
const int ith;
|
779
813
|
const int nth;
|
780
814
|
};
|
781
|
-
#endif //
|
815
|
+
#endif // __AVX__
|
782
816
|
|
783
817
|
} // namespace
|
784
818
|
|
@@ -813,8 +847,8 @@ class tinyBLAS_Q0_AVX2 {
|
|
813
847
|
* @param Ctype is GGML data type of `C`
|
814
848
|
* @return true if this function was able to service the matmul request
|
815
849
|
*/
|
816
|
-
bool llamafile_sgemm(
|
817
|
-
|
850
|
+
bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
|
851
|
+
int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
|
818
852
|
|
819
853
|
assert(m >= 0);
|
820
854
|
assert(n >= 0);
|
@@ -824,9 +858,6 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
|
|
824
858
|
assert(ldc >= m);
|
825
859
|
assert(nth > 0);
|
826
860
|
assert(ith < nth);
|
827
|
-
assert(1ll * lda * m <= 0x7fffffff);
|
828
|
-
assert(1ll * ldb * n <= 0x7fffffff);
|
829
|
-
assert(1ll * ldc * n <= 0x7fffffff);
|
830
861
|
|
831
862
|
if (Ctype != GGML_TYPE_F32)
|
832
863
|
return false;
|
@@ -932,8 +963,8 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
|
|
932
963
|
case GGML_TYPE_Q8_0: {
|
933
964
|
if (Btype != GGML_TYPE_Q8_0)
|
934
965
|
return false;
|
935
|
-
#if defined(__AVX2__) || defined(__AVX512F__)
|
936
|
-
|
966
|
+
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
|
967
|
+
tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
|
937
968
|
k, (const block_q8_0 *)A, lda,
|
938
969
|
(const block_q8_0 *)B, ldb,
|
939
970
|
(float *)C, ldc,
|
@@ -956,8 +987,8 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
|
|
956
987
|
case GGML_TYPE_Q4_0: {
|
957
988
|
if (Btype != GGML_TYPE_Q8_0)
|
958
989
|
return false;
|
959
|
-
#if defined(__AVX2__) || defined(__AVX512F__)
|
960
|
-
|
990
|
+
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
|
991
|
+
tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
|
961
992
|
k, (const block_q4_0 *)A, lda,
|
962
993
|
(const block_q8_0 *)B, ldb,
|
963
994
|
(float *)C, ldc,
|
@@ -1,11 +1,13 @@
|
|
1
1
|
#pragma once
|
2
|
+
#include <stdint.h>
|
2
3
|
#include <stdbool.h>
|
3
4
|
#ifdef __cplusplus
|
4
5
|
extern "C" {
|
5
6
|
#endif
|
6
7
|
|
7
|
-
bool llamafile_sgemm(
|
8
|
-
void *,
|
8
|
+
bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
|
9
|
+
const void *, int64_t, void *, int64_t, int, int,
|
10
|
+
int, int, int, int);
|
9
11
|
|
10
12
|
#ifdef __cplusplus
|
11
13
|
}
|