llama_cpp 0.15.0 → 0.15.1
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/ext/llama_cpp/llama_cpp.cpp +6 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +6 -0
- data/vendor/tmp/llama.cpp/Makefile +3 -4
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +289 -17
- data/vendor/tmp/llama.cpp/ggml-impl.h +77 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +23 -8
- data/vendor/tmp/llama.cpp/ggml-metal.metal +1 -1
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +1 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +18 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +11 -9
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +46843 -39205
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +950 -267
- data/vendor/tmp/llama.cpp/ggml.c +1090 -89
- data/vendor/tmp/llama.cpp/ggml.h +15 -7
- data/vendor/tmp/llama.cpp/llama.cpp +57 -17
- data/vendor/tmp/llama.cpp/llama.h +7 -1
- data/vendor/tmp/llama.cpp/sgemm.cpp +56 -21
- data/vendor/tmp/llama.cpp/unicode-data.cpp +1187 -655
- data/vendor/tmp/llama.cpp/unicode-data.h +2 -1
- data/vendor/tmp/llama.cpp/unicode.cpp +254 -122
- data/vendor/tmp/llama.cpp/unicode.h +4 -2
- metadata +2 -2
data/vendor/tmp/llama.cpp/ggml.h
CHANGED
@@ -326,14 +326,20 @@ extern "C" {
     // get ggml_status name string
     GGML_API GGML_CALL const char * ggml_status_to_string(enum ggml_status status);
 
+    // ieee 754-2008 half-precision float16
+    // todo: make this not an integral type
     typedef uint16_t ggml_fp16_t;
-
-    // convert FP16 <-> FP32
-    GGML_API float       ggml_fp16_to_fp32(ggml_fp16_t x);
-    GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
-
-    GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n);
-    GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n);
+    GGML_API float       ggml_fp16_to_fp32(ggml_fp16_t);
+    GGML_API ggml_fp16_t ggml_fp32_to_fp16(float);
+    GGML_API void        ggml_fp16_to_fp32_row(const ggml_fp16_t *, float *, int64_t);
+    GGML_API void        ggml_fp32_to_fp16_row(const float *, ggml_fp16_t *, int64_t);
+
+    // google brain half-precision bfloat16
+    typedef struct { uint16_t bits; } ggml_bf16_t;
+    GGML_API ggml_bf16_t ggml_fp32_to_bf16(float);
+    GGML_API float       ggml_bf16_to_fp32(ggml_bf16_t);  // consider just doing << 16
+    GGML_API void        ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t);
+    GGML_API void        ggml_fp32_to_bf16_row(const float *, ggml_bf16_t *, int64_t);
 
     struct ggml_object;
     struct ggml_context;
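The new ggml_bf16_t is just the high half of an IEEE-754 binary32, which is what the "consider just doing << 16" comment above alludes to. A minimal standalone sketch of the two scalar conversions (ggml's real encoder also handles NaN and round-to-nearest-even, which this truncating version omits):

#include <cstdint>
#include <cstring>

// bf16 -> fp32: widen the 16 stored bits into the top half of a binary32.
static float bf16_to_fp32(uint16_t bits) {
    uint32_t u = (uint32_t) bits << 16;
    float f;
    std::memcpy(&f, &u, sizeof f);  // bit-cast without aliasing UB
    return f;
}

// fp32 -> bf16: keep sign, exponent, and the top 7 mantissa bits.
// Truncates instead of rounding, unlike ggml's ggml_fp32_to_bf16.
static uint16_t fp32_to_bf16(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    return (uint16_t) (u >> 16);
}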
@@ -370,6 +376,7 @@ extern "C" {
         GGML_TYPE_I64   = 27,
         GGML_TYPE_F64   = 28,
         GGML_TYPE_IQ1_M = 29,
+        GGML_TYPE_BF16  = 30,
         GGML_TYPE_COUNT,
     };
 
@@ -410,6 +417,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ2_S  = 21, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ1_M  = 23, // except 1d tensors
+        GGML_FTYPE_MOSTLY_BF16   = 24, // except 1d tensors
     };
 
     // available tensor operations:
data/vendor/tmp/llama.cpp/llama.cpp
CHANGED
@@ -3175,6 +3175,7 @@ struct llama_model_loader {
         switch (type_max) {
             case GGML_TYPE_F32:  ftype = LLAMA_FTYPE_ALL_F32;     break;
             case GGML_TYPE_F16:  ftype = LLAMA_FTYPE_MOSTLY_F16;  break;
+            case GGML_TYPE_BF16: ftype = LLAMA_FTYPE_MOSTLY_BF16; break;
             case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break;
             case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break;
             case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break;
@@ -3666,6 +3667,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
     switch (ftype) {
         case LLAMA_FTYPE_ALL_F32:     return "all F32";
         case LLAMA_FTYPE_MOSTLY_F16:  return "F16";
+        case LLAMA_FTYPE_MOSTLY_BF16: return "BF16";
        case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0";
        case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1";
        case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
@@ -4383,6 +4385,21 @@ static void llm_load_vocab(
             } else if (
                 tokenizer_pre == "gpt-2") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2;
+            } else if (
+                tokenizer_pre == "refact") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_REFACT;
+            } else if (
+                tokenizer_pre == "command-r") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R;
+            } else if (
+                tokenizer_pre == "qwen2") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2;
+            } else if (
+                tokenizer_pre == "olmo") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_OLMO;
+            } else if (
+                tokenizer_pre == "dbrx") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX;
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
@@ -6120,6 +6137,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
         || !(
             model.ftype == LLAMA_FTYPE_ALL_F32 ||
             model.ftype == LLAMA_FTYPE_MOSTLY_F16 ||
+            model.ftype == LLAMA_FTYPE_MOSTLY_BF16 ||
             model.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ||
             model.ftype == LLAMA_FTYPE_MOSTLY_Q4_1
         )
@@ -11952,7 +11970,7 @@ static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id
 static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
     GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
     GGML_ASSERT(llama_is_byte_token(vocab, id));
-    const auto& token_data = vocab.id_to_token.at(id);
+    const auto & token_data = vocab.id_to_token.at(id);
     switch (llama_vocab_get_type(vocab)) {
         case LLAMA_VOCAB_TYPE_SPM: {
             auto buf = token_data.text.substr(3, 2);
@@ -12188,6 +12206,7 @@ struct llm_tokenizer_bpe {
         case LLAMA_VOCAB_TYPE_BPE:
             switch (vocab.type_pre) {
                 case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
+                case LLAMA_VOCAB_PRE_TYPE_DBRX:
                     word_collection = unicode_regex_split(text, {
                         // original regex from tokenizer.json
                         //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
@@ -12212,14 +12231,13 @@ struct llm_tokenizer_bpe {
                         "\\s?\\p{L}+",
                         "\\s?\\p{P}+",
                         "[一-龥ࠀ-一가-]+",
-                        "\\p{N}+",
+                        "\\p{N}",
                     });
                     break;
                 case LLAMA_VOCAB_PRE_TYPE_FALCON:
                     word_collection = unicode_regex_split(text, {
                         "[\\p{P}\\$\\+<=>\\^~\\|]+",
                         "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                        "\\p{N}+",
                         "[0-9][0-9][0-9]",
                     });
                     break;
@@ -12235,11 +12253,26 @@ struct llm_tokenizer_bpe {
                     });
                     break;
                 case LLAMA_VOCAB_PRE_TYPE_STARCODER:
+                case LLAMA_VOCAB_PRE_TYPE_REFACT:
+                case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
+                    word_collection = unicode_regex_split(text, {
+                        "\\p{N}",
+                        "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                    });
+                    break;
                 case LLAMA_VOCAB_PRE_TYPE_GPT2:
+                case LLAMA_VOCAB_PRE_TYPE_OLMO:
                     word_collection = unicode_regex_split(text, {
                         "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
                     });
                     break;
+                case LLAMA_VOCAB_PRE_TYPE_QWEN2:
+                    word_collection = unicode_regex_split(text, {
+                        // original regex from tokenizer.json
+                        // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+                        "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                    });
+                    break;
                 default:
                     // default regex for BPE tokenization pre-processing
                     word_collection = unicode_regex_split(text, {
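The Qwen2 branch rewrites the tokenizer's "(?i:'s|'t|...)" group into explicit "[sS]"-style classes because the bundled regex engine has no case-insensitive group, as the commented-out original above shows. A hypothetical helper (not part of llama.cpp) performing the same mechanical expansion for ASCII patterns:

#include <cctype>
#include <string>

// Expand each ASCII letter X into the two-character class [xX], e.g.
// "'re" -> "'[rR][eE]", mirroring the hand-expanded Qwen2 regex above.
static std::string expand_ascii_case(const std::string & pattern) {
    std::string out;
    for (unsigned char c : pattern) {
        if (std::isalpha(c)) {
            out += '[';
            out += (char) std::tolower(c);
            out += (char) std::toupper(c);
            out += ']';
        } else {
            out += (char) c;
        }
    }
    return out;
}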
@@ -12455,7 +12488,7 @@ struct llm_tokenizer_wpm {
                 continue;
             }
             code = unicode_tolower(code);
-            if (type == CODEPOINT_TYPE_WHITESPACE) {
+            if (type == CODEPOINT_TYPE_SEPARATOR) {
                 code = ' ';
             }
             std::string s = unicode_cpt_to_utf8(code);
@@ -14142,13 +14175,16 @@ static void llama_tensor_dequantize_internal(
         if (qtype.to_float == NULL) {
             throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type)));
         }
-    } else if (tensor->type != GGML_TYPE_F16) {
+    } else if (tensor->type != GGML_TYPE_F16 &&
+               tensor->type != GGML_TYPE_BF16) {
         throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type)));
     }
 
     if (nthread < 2) {
         if (tensor->type == GGML_TYPE_F16) {
             ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements);
+        } else if (tensor->type == GGML_TYPE_BF16) {
+            ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements);
         } else if (ggml_is_quantized(tensor->type)) {
             qtype.to_float(tensor->data, f32_output, nelements);
         } else {
@@ -14157,7 +14193,14 @@ static void llama_tensor_dequantize_internal(
         return;
     }
 
-    size_t block_size = tensor->type == GGML_TYPE_F16 ? 1 : (size_t)ggml_blck_size(tensor->type);
+    size_t block_size;
+    if (tensor->type == GGML_TYPE_F16 ||
+        tensor->type == GGML_TYPE_BF16) {
+        block_size = 1;
+    } else {
+        block_size = (size_t)ggml_blck_size(tensor->type);
+    }
+
     size_t block_size_bytes = ggml_type_size(tensor->type);
 
     GGML_ASSERT(nelements % block_size == 0);
|
@@ -14176,6 +14219,8 @@ static void llama_tensor_dequantize_internal(
|
|
14176
14219
|
auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
|
14177
14220
|
if (typ == GGML_TYPE_F16) {
|
14178
14221
|
ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
|
14222
|
+
} else if (typ == GGML_TYPE_BF16) {
|
14223
|
+
ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels);
|
14179
14224
|
} else {
|
14180
14225
|
qtype.to_float(inbuf, outbuf, nels);
|
14181
14226
|
}
|
@@ -14536,6 +14581,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
         case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
         case LLAMA_FTYPE_MOSTLY_F16:  default_type = GGML_TYPE_F16;  break;
+        case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
         case LLAMA_FTYPE_ALL_F32:     default_type = GGML_TYPE_F32;  break;
 
         // K-quants
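With LLAMA_FTYPE_MOSTLY_BF16 wired into the quantizer, a BF16 GGUF can be produced through the existing C API. A sketch (the file names are placeholders; per llama.h, a return of 0 means success):

#include "llama.h"

int main() {
    llama_model_quantize_params params = llama_model_quantize_default_params();
    params.ftype = LLAMA_FTYPE_MOSTLY_BF16;  // added in this release
    return (int) llama_model_quantize("model-f32.gguf", "model-bf16.gguf", &params);
}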
@@ -15473,13 +15519,6 @@ struct llama_context * llama_new_context_with_model(
         cparams.flash_attn = false;
     }
 
-#ifdef GGML_USE_HIPBLAS
-    if (cparams.flash_attn) {
-        LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with HIPBLAS builds - forcing off\n", __func__);
-        cparams.flash_attn = false;
-    }
-#endif
-
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
@@ -17466,9 +17505,10 @@ int32_t llama_tokenize(
 
 static std::string llama_decode_text(const std::string & text) {
     std::string decoded_text;
-    auto unicode_sequences = unicode_cpts_from_utf8(text);
-    for (auto & unicode_sequence : unicode_sequences) {
-        decoded_text += unicode_utf8_to_byte(unicode_cpt_to_utf8(unicode_sequence));
+
+    const auto cpts = unicode_cpts_from_utf8(text);
+    for (const auto cpt : cpts) {
+        decoded_text += unicode_utf8_to_byte(unicode_cpt_to_utf8(cpt));
     }
 
     return decoded_text;
@@ -17832,7 +17872,7 @@ struct llama_timings llama_get_timings(struct llama_context * ctx) {
         /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us,
 
         /*.n_sample =*/ std::max(1, ctx->n_sample),
-        /*.n_p_eval =*/ std::max(1, ctx->n_p_eval),
+        /*.n_p_eval =*/ std::max(0, ctx->n_p_eval),
         /*.n_eval   =*/ std::max(1, ctx->n_eval),
     };
 
data/vendor/tmp/llama.cpp/llama.h
CHANGED
@@ -79,6 +79,11 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_MPT       = 5,
         LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
         LLAMA_VOCAB_PRE_TYPE_GPT2      = 7,
+        LLAMA_VOCAB_PRE_TYPE_REFACT    = 8,
+        LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
+        LLAMA_VOCAB_PRE_TYPE_QWEN2     = 10,
+        LLAMA_VOCAB_PRE_TYPE_OLMO      = 11,
+        LLAMA_VOCAB_PRE_TYPE_DBRX      = 12,
     };
 
     // note: these values should be synchronized with ggml_rope
@@ -134,6 +139,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_IQ2_M  = 29, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ1_M  = 31, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_BF16   = 32, // except 1d tensors
 
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
@@ -171,7 +177,7 @@ extern "C" {
         bool sorted;
     } llama_token_data_array;
 
-    typedef bool (*llama_progress_callback)(float progress, void * ctx);
+    typedef bool (*llama_progress_callback)(float progress, void * user_data);
 
     // Input data for llama_decode
     // A llama_batch object can contain input about one or many sequences
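The callback's second parameter is now named user_data, matching how it is actually used. A minimal sketch of an implementation, assuming only the signature shown above; per llama.h, returning false aborts model loading:

#include <cstdio>

// Print a crude progress meter; return true to keep loading.
static bool on_progress(float progress, void * user_data) {
    (void) user_data;  // unused in this sketch
    std::fprintf(stderr, "\rloading: %3.0f%%", progress * 100.0f);
    return true;
}

The pointer pair llama_model_params::progress_callback / progress_callback_user_data carries this into llama_load_model_from_file.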
data/vendor/tmp/llama.cpp/sgemm.cpp
CHANGED
@@ -1,6 +1,3 @@
-// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
-// vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
-//
 // Copyright 2024 Mozilla Foundation
 //
 // Permission is hereby granted, free of charge, to any person obtaining
@@ -585,15 +582,15 @@ class tinyBLAS_Q0_ARM {
 };
 #endif // __ARM_FEATURE_DOTPROD
 
-#if defined(__AVX2__) || defined(__AVX512F__)
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
 template <typename TA, typename TB, typename TC>
-class tinyBLAS_Q0_AVX2 {
+class tinyBLAS_Q0_AVX {
   public:
-    tinyBLAS_Q0_AVX2(int64_t k,
-                     const TA *A, int64_t lda,
-                     const TB *B, int64_t ldb,
-                     TC *C, int64_t ldc,
-                     int ith, int nth)
+    tinyBLAS_Q0_AVX(int64_t k,
+                    const TA *A, int64_t lda,
+                    const TB *B, int64_t ldb,
+                    TC *C, int64_t ldc,
+                    int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
 
@@ -728,14 +725,34 @@ class tinyBLAS_Q0_AVX2 {
         __m256 Cv[RN][RM] = {};
         for (int64_t l = 0; l < k; ++l)
             for (int64_t j = 0; j < RN; ++j)
-                for (int64_t i = 0; i < RM; ++i)
+                for (int64_t i = 0; i < RM; ++i) {
+#if defined(__AVX2__)
+                    __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                          load(A + lda * (ii + i) + l)),
+                                         _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
+                                                          load(A + lda * (ii + i) + l)));
+#else
+                    __m128i ali0 = load0(A + lda * (ii + i) + l);
+                    __m128i ali1 = load1(A + lda * (ii + i) + l);
+                    __m128i blj0 = load0(B + ldb * (jj + j) + l);
+                    __m128i blj1 = load1(B + ldb * (jj + j) + l);
+
+                    __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
+                    __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
+                    __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
+                    __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
+
+                    // updot
+                    const __m128i oneFill = _mm_set1_epi16(1);
+                    __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
+                    __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
+                    __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
+#endif
                     Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
                                                    unhalf(B[ldb * (jj + j) + l].d)),
-                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
-                                                           load(A + lda * (ii + i) + l)),
-                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
-                                                           load(A + lda * (ii + i) + l))),
-                                    Cv[j][i]);
+                                    udTmp,
+                                    Cv[j][i]);
+                }
         for (int64_t j = 0; j < RN; ++j)
             for (int64_t i = 0; i < RM; ++i)
                 C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
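The #else branch reimplements updot for plain AVX, where 256-bit integer ops are unavailable. _mm_maddubs_epi16 multiplies unsigned bytes by signed bytes, so both operands are first run through _mm_sign_epi8: |a| times (b with a's sign folded in) equals a times b element-wise. A self-contained sketch of the same idiom on one 128-bit lane (hypothetical function, compile with SSSE3):

#include <cstdint>
#include <tmmintrin.h>  // SSSE3: _mm_sign_epi8, _mm_maddubs_epi16

// Dot product of two 16-element int8 vectors using the maddubs trick.
static int32_t dot16_i8(const int8_t * a, const int8_t * b) {
    __m128i va   = _mm_loadu_si128((const __m128i *) a);
    __m128i vb   = _mm_loadu_si128((const __m128i *) b);
    __m128i absa = _mm_sign_epi8(va, va);  // |a|, viewed as unsigned bytes
    __m128i sgnb = _mm_sign_epi8(vb, va);  // b with a's sign applied
    __m128i s16  = _mm_maddubs_epi16(absa, sgnb);           // u8*s8 pairs -> s16
    __m128i s32  = _mm_madd_epi16(s16, _mm_set1_epi16(1));  // s16 pairs -> s32
    // horizontal sum of the four 32-bit lanes
    s32 = _mm_add_epi32(s32, _mm_shuffle_epi32(s32, _MM_SHUFFLE(1, 0, 3, 2)));
    s32 = _mm_add_epi32(s32, _mm_shuffle_epi32(s32, _MM_SHUFFLE(2, 3, 0, 1)));
    return _mm_cvtsi128_si32(s32);
}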
@@ -746,10 +763,28 @@ class tinyBLAS_Q0_AVX2 {
         return _mm256_loadu_si256((const __m256i *)b->qs);
     }
 
+    inline __m128i load0(const block_q8_0 *b) {
+        return _mm_loadu_si128((const __m128i *)b->qs);
+    }
+
+    inline __m128i load1(const block_q8_0 *b) {
+        return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
+    }
+
     inline __m256i load(const block_q4_0 *b) {
         return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
     }
 
+    inline __m128i load0(const block_q4_0 *b) {
+        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
+    }
+
+    inline __m128i load1(const block_q4_0 *b) {
+        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
+    }
+
     inline __m256 updot(__m256i u, __m256i s) {
         __m256i res;
 #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
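The new load0/load1 overloads split a q4_0 block's 32 packed weights into two 128-bit halves: low nibbles first, then high nibbles (masking after _mm_srli_epi16 strips the bits shifted in from the neighbouring byte of each 16-bit lane). The scalar equivalent for one packed byte, as a reference sketch:

#include <cstdint>

// Each q4_0 byte stores two 4-bit weights biased by 8; load0 reads the low
// nibbles, load1 the high nibbles, both mapped into [-8, 7].
static void unpack_q4_0_byte(uint8_t packed, int8_t & lo, int8_t & hi) {
    lo = (int8_t) (packed & 15) - 8;
    hi = (int8_t) (packed >> 4) - 8;
}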
@@ -777,7 +812,7 @@ class tinyBLAS_Q0_AVX2 {
     const int ith;
     const int nth;
 };
-#endif // __AVX2__
+#endif // __AVX__
 
 } // namespace
 
@@ -928,8 +963,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     case GGML_TYPE_Q8_0: {
         if (Btype != GGML_TYPE_Q8_0)
            return false;
-#if defined(__AVX2__) || defined(__AVX512F__)
-        tinyBLAS_Q0_AVX2<block_q8_0, block_q8_0, float> tb{
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+        tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
             k, (const block_q8_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
@@ -952,8 +987,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     case GGML_TYPE_Q4_0: {
         if (Btype != GGML_TYPE_Q8_0)
            return false;
-#if defined(__AVX2__) || defined(__AVX512F__)
-        tinyBLAS_Q0_AVX2<block_q4_0, block_q8_0, float> tb{
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+        tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,