cui-llama.rn 1.2.4 → 1.2.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +0 -2
- package/android/src/main/CMakeLists.txt +1 -0
- package/android/src/main/java/com/rnllama/LlamaContext.java +0 -3
- package/android/src/main/jni.cpp +2 -4
- package/cpp/common.cpp +6 -14
- package/cpp/common.h +59 -40
- package/cpp/ggml-aarch64.c +269 -0
- package/cpp/ggml-backend-impl.h +4 -15
- package/cpp/ggml-backend.cpp +1640 -1604
- package/cpp/ggml-backend.h +13 -25
- package/cpp/ggml-cpp.h +38 -0
- package/cpp/ggml-cpu.c +13720 -0
- package/cpp/ggml-cpu.h +150 -0
- package/cpp/ggml-impl.h +87 -0
- package/cpp/ggml-metal.m +185 -71
- package/cpp/ggml-quants.c +38 -51
- package/cpp/ggml.c +4442 -19516
- package/cpp/ggml.h +25 -146
- package/cpp/llama-sampling.cpp +392 -241
- package/cpp/llama-sampling.h +18 -0
- package/cpp/llama-vocab.cpp +16 -0
- package/cpp/llama-vocab.h +5 -0
- package/cpp/llama.cpp +2084 -2007
- package/cpp/llama.h +13 -11
- package/cpp/sampling.cpp +19 -11
- package/cpp/sgemm.cpp +57 -0
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +0 -1
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +0 -1
package/README.md
CHANGED
```diff
@@ -11,8 +11,6 @@ The following features have been added for Android:
 - `vocab_only` mode: utilize the llama.cpp tokenizer
 - tokenizeSync: non-blocking, synchronous tokenizer function
 - Context Shift taken from [kobold.cpp](https://github.com/LostRuins/koboldcpp)
-- XTC sampling
-- Progress callback
 - Retrieving CPU Features to check for i8mm and dotprod flags
 
 Original repo README.md below.
```
package/android/src/main/java/com/rnllama/LlamaContext.java
CHANGED
```diff
@@ -248,8 +248,6 @@ public class LlamaContext {
      params.hasKey("xtc_t") ? (float) params.getDouble("xtc_t") : 0.00f,
      // float xtc_p,
      params.hasKey("xtc_p") ? (float) params.getDouble("xtc_p") : 0.00f,
-     // float tfs_z,
-     params.hasKey("tfs_z") ? (float) params.getDouble("tfs_z") : 1.00f,
      // float typical_p,
      params.hasKey("typical_p") ? (float) params.getDouble("typical_p") : 1.00f,
      // int seed,
@@ -438,7 +436,6 @@ public class LlamaContext {
    float min_p,
    float xtc_t,
    float xtc_p,
-   float tfs_z,
    float typical_p,
    int seed,
    String[] stop,
```
package/android/src/main/jni.cpp
CHANGED
```diff
@@ -399,7 +399,6 @@ Java_com_rnllama_LlamaContext_doCompletion(
     jfloat min_p,
     jfloat xtc_t,
     jfloat xtc_p,
-    jfloat tfs_z,
     jfloat typical_p,
     jint seed,
     jobjectArray stop,
@@ -438,12 +437,11 @@ Java_com_rnllama_LlamaContext_doCompletion(
     sparams.top_k = top_k;
     sparams.top_p = top_p;
     sparams.min_p = min_p;
-    sparams.tfs_z = tfs_z;
     sparams.typ_p = typical_p;
     sparams.n_probs = n_probs;
     sparams.grammar = env->GetStringUTFChars(grammar, nullptr);
-    sparams.
-    sparams.
+    sparams.xtc_threshold = xtc_t;
+    sparams.xtc_probability = xtc_p;
 
     sparams.logit_bias.clear();
     if (ignore_eos) {
```
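Note: the two additions above wire the XTC ("exclude top choices") parameters into `common_sampler_params`, replacing the removed tail-free (`tfs_z`) plumbing. For orientation only, below is a minimal sketch of the XTC idea rather than the actual sampler in `llama-sampling.cpp`; `Candidate` and `xtc_apply` are made-up names. With probability `xtc_probability`, every candidate whose probability reaches `xtc_threshold` is dropped except the least likely of them.

```cpp
// Illustrative sketch only, not the llama.cpp implementation.
#include <cstddef>
#include <random>
#include <vector>

struct Candidate { int id; float prob; }; // hypothetical candidate type

// `cands` is assumed to be sorted by descending probability.
void xtc_apply(std::vector<Candidate> & cands, float threshold, float probability, std::mt19937 & rng) {
    std::uniform_real_distribution<float> coin(0.0f, 1.0f);
    if (threshold <= 0.0f || probability <= 0.0f || coin(rng) >= probability) {
        return; // XTC disabled, or not triggered on this draw
    }
    size_t last_above = 0;  // index of the least likely candidate at or above the threshold
    bool   found      = false;
    for (size_t i = 0; i < cands.size(); ++i) {
        if (cands[i].prob >= threshold) { last_above = i; found = true; }
    }
    if (!found || last_above == 0) {
        return; // fewer than two "top choices": nothing to exclude
    }
    // drop every top choice except the last (least likely) one
    cands.erase(cands.begin(), cands.begin() + static_cast<std::ptrdiff_t>(last_above));
}
```

Since at most one token can have probability above 0.5, a threshold greater than 0.5 never yields a second "top choice", which matches the `> 0.5 disables XTC` comment in the common.h hunk below.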
package/cpp/common.cpp
CHANGED
```diff
@@ -422,19 +422,6 @@ std::string string_format(const char * fmt, ...) {
     return std::string(buf.data(), size);
 }
 
-std::vector<std::string> string_split(std::string input, char separator) {
-    std::vector<std::string> parts;
-    size_t separator_pos = input.find(separator);
-    while (separator_pos != std::string::npos) {
-        std::string part = input.substr(0, separator_pos);
-        parts.emplace_back(part);
-        input = input.substr(separator_pos + 1);
-        separator_pos = input.find(separator);
-    }
-    parts.emplace_back(input);
-    return parts;
-}
-
 std::string string_strip(const std::string & str) {
     size_t start = 0;
     size_t end = str.size();
@@ -1974,6 +1961,8 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha
 
 void yaml_dump_non_result_info(FILE * stream, const common_params & params, const llama_context * lctx,
                                const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
+    lm_ggml_cpu_init(); // some ARM features are detected at runtime
+
     const auto & sparams = params.sparams;
 
     fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT);
@@ -2029,6 +2018,10 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
     fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
     fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
     fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
+    fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length);
+    fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base);
+    fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier);
+    fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n);
     fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
     fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
     fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
@@ -2109,7 +2102,6 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
     const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
     yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector);
 
-    fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z);
     fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency());
     fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
     fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
```
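Note: the four new `dry_*` dump lines mirror the DRY sampler fields added to `common_sampler_params` (see the common.h hunk below), which documents the penalty as `multiplier * base ^ (length of sequence before token - allowed length)`. A quick, hypothetical illustration of the magnitude (with the defaults above, `dry_multiplier = 0.0` keeps DRY disabled):

```cpp
// Hypothetical illustration of the DRY penalty size; names follow common_sampler_params,
// but this helper is not part of the library.
#include <cmath>
#include <cstdio>

int main() {
    const float dry_multiplier     = 0.8f;  // example value; the default 0.0f disables DRY
    const float dry_base           = 1.75f; // default
    const int   dry_allowed_length = 2;     // default
    const int   repeat_length      = 5;     // tokens of repetition preceding the candidate token

    // multiplier * base ^ (length of sequence before token - allowed length)
    const float penalty = dry_multiplier * std::pow(dry_base, float(repeat_length - dry_allowed_length));
    std::printf("logit penalty: %.2f\n", penalty); // ~4.29
    return 0;
}
```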
package/cpp/common.h
CHANGED
```diff
@@ -95,14 +95,15 @@ enum llama_example {
 
 enum common_sampler_type {
     COMMON_SAMPLER_TYPE_NONE = 0,
-
-
-
-
-
-
-
-
+    COMMON_SAMPLER_TYPE_DRY = 1,
+    COMMON_SAMPLER_TYPE_TOP_K = 2,
+    COMMON_SAMPLER_TYPE_TOP_P = 3,
+    COMMON_SAMPLER_TYPE_MIN_P = 4,
+    //COMMON_SAMPLER_TYPE_TFS_Z = 5,
+    COMMON_SAMPLER_TYPE_TYPICAL_P = 6,
+    COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
+    COMMON_SAMPLER_TYPE_XTC = 8,
+    COMMON_SAMPLER_TYPE_INFILL = 9,
 };
 
 // dimensionality reduction methods, used by cvector-generator
@@ -114,37 +115,40 @@ enum dimre_method {
 // sampler parameters
 struct common_sampler_params {
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
-
-    int32_t n_prev
-    int32_t n_probs
-    int32_t min_keep
-    int32_t top_k
-    float top_p
-    float min_p
-    float xtc_probability
-    float xtc_threshold
-    float
-    float
-    float
-    float
-
-    float
-    float
-
-    float
-    float
-
-    int32_t
-
-    float
-
-    bool
-    bool
+
+    int32_t n_prev = 64; // number of previous tokens to remember
+    int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
+    int32_t top_k = 40; // <= 0 to use vocab size
+    float top_p = 0.95f; // 1.0 = disabled
+    float min_p = 0.05f; // 0.0 = disabled
+    float xtc_probability = 0.00f; // 0.0 = disabled
+    float xtc_threshold = 0.10f; // > 0.5 disables XTC
+    float typ_p = 1.00f; // typical_p, 1.0 = disabled
+    float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
+    float dynatemp_range = 0.00f; // 0.0 = disabled
+    float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
+    int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float penalty_repeat = 1.00f; // 1.0 = disabled
+    float penalty_freq = 0.00f; // 0.0 = disabled
+    float penalty_present = 0.00f; // 0.0 = disabled
+    float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
+    float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
+    int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
+    int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
+    int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float mirostat_tau = 5.00f; // target entropy
+    float mirostat_eta = 0.10f; // learning rate
+    bool penalize_nl = false; // consider newlines as a repeatable token
+    bool ignore_eos = false;
+    bool no_perf = false; // disable performance metrics
+
+    std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
 
 
     std::vector<enum common_sampler_type> samplers = {
+        COMMON_SAMPLER_TYPE_DRY,
         COMMON_SAMPLER_TYPE_TOP_K,
-        COMMON_SAMPLER_TYPE_TFS_Z,
         COMMON_SAMPLER_TYPE_TYPICAL_P,
         COMMON_SAMPLER_TYPE_TOP_P,
         COMMON_SAMPLER_TYPE_MIN_P,
@@ -166,7 +170,7 @@ struct common_params {
     llama_progress_callback progress_callback = nullptr;
     bool vocab_only = false;
     int32_t n_predict = -1; // new tokens to predict
-    int32_t n_ctx =
+    int32_t n_ctx = 4096; // context size
     int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep = 0; // number of tokens to keep from initial prompt
@@ -291,9 +295,9 @@ struct common_params {
 
     // embedding
     bool embedding = false; // get only sentence embedding
-    int32_t embd_normalize = 2; // normalisation for
+    int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
-    std::string embd_sep = "\n"; // separator of
+    std::string embd_sep = "\n"; // separator of embeddings
     bool reranking = false; // enable reranking support on server
 
     // server params
@@ -397,8 +401,6 @@ bool set_process_priority(enum lm_ggml_sched_priority prio);
 LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
 std::string string_format(const char * fmt, ...);
 
-std::vector<std::string> string_split(std::string input, char separator);
-
 std::string string_strip(const std::string & str);
 std::string string_get_sortable_timestamp();
 
@@ -406,6 +408,7 @@ void string_replace_all(std::string & s, const std::string & search, const std::
 
 template<class T>
 static std::vector<T> string_split(const std::string & str, char delim) {
+    static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
     std::vector<T> values;
     std::istringstream str_stream(str);
     std::string token;
@@ -418,6 +421,22 @@ static std::vector<T> string_split(const std::string & str, char delim) {
     return values;
 }
 
+template<>
+std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
+{
+    std::vector<std::string> parts;
+    size_t begin_pos = 0;
+    size_t separator_pos = input.find(separator);
+    while (separator_pos != std::string::npos) {
+        std::string part = input.substr(begin_pos, separator_pos - begin_pos);
+        parts.emplace_back(part);
+        begin_pos = separator_pos + 1;
+        separator_pos = input.find(separator, begin_pos);
+    }
+    parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
+    return parts;
+}
+
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
 void string_process_escapes(std::string & input);
 
```
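Note: the new `static_assert` plus the `string_split<std::string>` specialization replace the non-template overload that was deleted from common.cpp above; string splitting is now selected through the template argument. A small, illustrative usage sketch (the `string_split` names come from this header, everything else is hypothetical):

```cpp
// Assumes common.h is on the include path.
#include <string>
#include <vector>
#include "common.h"

void string_split_example() {
    // specialization: plain substring splitting, empty fields preserved -> {"a", "b", "", "c"}
    std::vector<std::string> parts = string_split<std::string>("a,b,,c", ',');

    // primary template: each field is parsed via a stream -> {1, 2, 3}
    std::vector<int> numbers = string_split<int>("1,2,3", ',');

    // Instantiating the primary template with std::string would now trip the static_assert;
    // the explicit specialization above is what keeps these calls compiling.
    (void) parts;
    (void) numbers;
}
```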
package/cpp/ggml-aarch64.c
CHANGED
```diff
@@ -7,6 +7,7 @@
 
 #include "ggml-quants.h"
 #include "ggml-impl.h"
+#include "ggml-cpu.h"
 #include "ggml-cpu-impl.h"
 
 #include <math.h>
@@ -991,6 +992,73 @@ void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
             }
         }
         return;
+#elif defined(__riscv_v_intrinsic)
+    if (__riscv_vlenb() >= QK4_0) {
+        const size_t vl = QK4_0;
+
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+
+            vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+            for (int l = 0; l < nb; l++) {
+                const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0];
+                const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8];
+                const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16];
+                const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24];
+                __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, vl / 4));
+                const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, vl / 4));
+                const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, vl / 4));
+                const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, vl / 4));
+
+                const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
+                const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
+                const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
+                const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);
+                const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);
+                const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);
+                const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);
+
+                const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_hi_m));
+                const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                // vector version needs Zvfhmin extension
+                const float a_scale = LM_GGML_FP16_TO_FP32(a_ptr[l].d);
+                const float b_scales[8] = {
+                    LM_GGML_FP16_TO_FP32(b_ptr[l].d[0]),
+                    LM_GGML_FP16_TO_FP32(b_ptr[l].d[1]),
+                    LM_GGML_FP16_TO_FP32(b_ptr[l].d[2]),
+                    LM_GGML_FP16_TO_FP32(b_ptr[l].d[3]),
+                    LM_GGML_FP16_TO_FP32(b_ptr[l].d[4]),
+                    LM_GGML_FP16_TO_FP32(b_ptr[l].d[5]),
+                    LM_GGML_FP16_TO_FP32(b_ptr[l].d[6]),
+                    LM_GGML_FP16_TO_FP32(b_ptr[l].d[7])
+                };
+                const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
+                const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4);
+                sumf = __riscv_vfmacc_vv_f32m1(sumf, tmp1, b_scales_vec, vl / 4);
+            }
+            __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4);
+        }
+        return;
+    }
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
     {
         float sumf[8];
@@ -3171,6 +3239,207 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
                 }
             }
         }
+        return;
+    }
+#elif defined(__riscv_v_intrinsic)
+    if (__riscv_vlenb() >= QK4_0) {
+        const size_t vl = QK4_0;
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+                vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                for (int l = 0; l < nb; l++) {
+                    const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
+                    const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
+                    const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
+                    const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);
+                    const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);
+                    const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);
+                    const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);
+
+                    // vector version needs Zvfhmin extension
+                    const float a_scales[4] = {
+                        LM_GGML_FP16_TO_FP32(a_ptr[l].d[0]),
+                        LM_GGML_FP16_TO_FP32(a_ptr[l].d[1]),
+                        LM_GGML_FP16_TO_FP32(a_ptr[l].d[2]),
+                        LM_GGML_FP16_TO_FP32(a_ptr[l].d[3])
+                    };
+                    const float b_scales[8] = {
+                        LM_GGML_FP16_TO_FP32(b_ptr[l].d[0]),
+                        LM_GGML_FP16_TO_FP32(b_ptr[l].d[1]),
+                        LM_GGML_FP16_TO_FP32(b_ptr[l].d[2]),
+                        LM_GGML_FP16_TO_FP32(b_ptr[l].d[3]),
+                        LM_GGML_FP16_TO_FP32(b_ptr[l].d[4]),
+                        LM_GGML_FP16_TO_FP32(b_ptr[l].d[5]),
+                        LM_GGML_FP16_TO_FP32(b_ptr[l].d[6]),
+                        LM_GGML_FP16_TO_FP32(b_ptr[l].d[7])
+                    };
+                    const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
+
+                    const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0];
+                    const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32];
+                    const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64];
+                    const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96];
+                    __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                    vint16m4_t sumi_l0;
+                    {
+                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4));
+                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4));
+                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4));
+                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4));
+                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                        const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                        const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                        const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                        sumi_l0 = sumi_hi_m;
+                    }
+
+                    {
+                        const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0));
+                        const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                        const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                        const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                        const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                        const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                        const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                        const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                        const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                        const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                        const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                        const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                        const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4);
+                        sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4);
+                    }
+
+                    const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8];
+                    const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40];
+                    const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72];
+                    const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104];
+                    __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                    vint16m4_t sumi_l1;
+                    {
+                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4));
+                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4));
+                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4));
+                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4));
+                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                        const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                        const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                        const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                        sumi_l1 = sumi_hi_m;
+                    }
+
+                    {
+                        const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1));
+                        const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                        const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                        const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                        const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                        const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                        const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                        const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                        const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                        const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                        const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                        const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                        const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4);
+                        sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4);
+                    }
+
+                    const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16];
+                    const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48];
+                    const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80];
+                    const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112];
+                    __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                    vint16m4_t sumi_l2;
+                    {
+                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4));
+                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4));
+                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4));
+                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4));
+                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                        const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                        const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                        const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                        sumi_l2 = sumi_hi_m;
+                    }
+
+                    {
+                        const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2));
+                        const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                        const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                        const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                        const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                        const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                        const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                        const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                        const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                        const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                        const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                        const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                        const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4);
+                        sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4);
+                    }
+
+                    const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24];
+                    const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56];
+                    const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88];
+                    const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120];
+                    __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                    vint16m4_t sumi_l3;
+                    {
+                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4));
+                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4));
+                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4));
+                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4));
+                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                        const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                        const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                        const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                        sumi_l3 = sumi_hi_m;
+                    }
+
+                    {
+                        const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3));
+                        const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                        const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                        const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                        const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                        const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                        const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                        const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                        const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                        const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                        const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                        const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                        const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4);
+                        sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4);
+                    }
+                }
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4);
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4);
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4);
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4);
+            }
+        }
+
         return;
     }
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
```
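Note: both RISC-V paths unpack the packed Q4_0 nibbles the same way: `__riscv_vsll_vx_i8m4` by 4 followed by `__riscv_vsra_vx_i8m4` by 4 sign-extends the low nibble of each byte, and an arithmetic right shift alone yields the high nibble. A scalar sketch of just that unpacking step (hypothetical helper, not part of ggml; like the vector code, it assumes arithmetic shifting of signed values):

```cpp
#include <cstdint>

// Recover the two signed 4-bit values packed into one byte of the interleaved Q4_0 data.
static inline void unpack_q4_nibbles(uint8_t packed, int8_t * lo, int8_t * hi) {
    *lo = (int8_t)((int8_t)(packed << 4) >> 4); // like vsll 4 + vsra 4: sign-extended low nibble (-8..7)
    *hi = (int8_t)((int8_t) packed       >> 4); // like vsra 4: sign-extended high nibble (-8..7)
}
```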
package/cpp/ggml-backend-impl.h
CHANGED
```diff
@@ -22,7 +22,7 @@ extern "C" {
         size_t (*get_max_size) (lm_ggml_backend_buffer_type_t buft);
         // (optional) data size needed to allocate the tensor, including padding (defaults to lm_ggml_nbytes)
         size_t (*get_alloc_size)(lm_ggml_backend_buffer_type_t buft, const struct lm_ggml_tensor * tensor);
-        // (optional) check if tensor data is in host memory (defaults to false)
+        // (optional) check if tensor data is in host memory and uses standard ggml tensor layout (defaults to false)
         bool (*is_host) (lm_ggml_backend_buffer_type_t buft);
     };
 
@@ -37,7 +37,6 @@ extern "C" {
     //
 
     struct lm_ggml_backend_buffer_i {
-        const char * (*get_name) (lm_ggml_backend_buffer_t buffer);
         // (optional) free the buffer
         void (*free_buffer) (lm_ggml_backend_buffer_t buffer);
         // base address of the buffer
@@ -88,19 +87,16 @@ extern "C" {
 
         void (*free)(lm_ggml_backend_t backend);
 
-        // Will be moved to the device interface
-        // buffer allocation
-        lm_ggml_backend_buffer_type_t (*get_default_buffer_type)(lm_ggml_backend_t backend);
-
         // (optional) asynchronous tensor data access
         void (*set_tensor_async)(lm_ggml_backend_t backend, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size);
         void (*get_tensor_async)(lm_ggml_backend_t backend, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size);
         bool (*cpy_tensor_async)(lm_ggml_backend_t backend_src, lm_ggml_backend_t backend_dst, const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst);
 
-        // (optional) complete all pending operations
+        // (optional) complete all pending operations (required if the backend supports async operations)
         void (*synchronize)(lm_ggml_backend_t backend);
 
-        // (optional)
+        // (optional) graph plans (not used currently)
+        // compute graph with a plan
         lm_ggml_backend_graph_plan_t (*graph_plan_create) (lm_ggml_backend_t backend, const struct lm_ggml_cgraph * cgraph);
         void (*graph_plan_free) (lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan);
         // update the plan with a new graph - this should be faster than creating a new plan when the graph has the same topology
@@ -111,13 +107,6 @@ extern "C" {
         // compute graph (always async if supported by the backend)
         enum lm_ggml_status (*graph_compute) (lm_ggml_backend_t backend, struct lm_ggml_cgraph * cgraph);
 
-        // IMPORTANT: these functions have been moved to the device interface and will be removed from the backend interface
-        // new backends should implement the device interface instead
-        // These functions are being moved to the device interface
-        bool (*supports_op) (lm_ggml_backend_t backend, const struct lm_ggml_tensor * op);
-        bool (*supports_buft)(lm_ggml_backend_t backend, lm_ggml_backend_buffer_type_t buft);
-        bool (*offload_op) (lm_ggml_backend_t backend, const struct lm_ggml_tensor * op);
-
         // (optional) event synchronization
         // record an event on this stream
         void (*event_record)(lm_ggml_backend_t backend, lm_ggml_backend_event_t event);
```
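Note: the reworded comment makes the contract explicit: a backend that implements the asynchronous tensor-access hooks must also implement `synchronize`. A hedged caller-side sketch, assuming the public wrappers `lm_ggml_backend_tensor_set_async` and `lm_ggml_backend_synchronize` declared in ggml-backend.h (the helper itself is made up):

```cpp
#include <cstddef>
#include "ggml-backend.h"

// Upload tensor data and block until the backend has finished all pending async work.
static void upload_and_wait(lm_ggml_backend_t backend, struct lm_ggml_tensor * tensor,
                            const void * data, size_t nbytes) {
    lm_ggml_backend_tensor_set_async(backend, tensor, data, 0, nbytes); // may return before the copy completes
    lm_ggml_backend_synchronize(backend);                               // waits for all pending async operations
}
```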