@fugood/llama.node 1.3.1 → 1.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +14 -14
- package/scripts/llama.cpp.patch +6 -6
- package/src/llama.cpp/common/arg.cpp +7 -0
- package/src/llama.cpp/common/common.h +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +19 -8
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +21 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -40
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +0 -4
- package/src/llama.cpp/src/llama-memory-recurrent.cpp +4 -3
- package/src/llama.cpp/src/models/ernie4-5.cpp +4 -5
- package/src/llama.cpp/src/models/openai-moe-iswa.cpp +2 -1
package/package.json
CHANGED

@@ -1,7 +1,7 @@
 {
   "name": "@fugood/llama.node",
   "access": "public",
-  "version": "1.3.1",
+  "version": "1.3.2",
   "description": "An another Node binding of llama.cpp",
   "main": "lib/index.js",
   "scripts": {
@@ -72,19 +72,19 @@
     "CMakeLists.txt"
   ],
   "optionalDependencies": {
-    "@fugood/node-llama-linux-x64": "1.3.1",
-    "@fugood/node-llama-linux-x64-vulkan": "1.3.1",
-    "@fugood/node-llama-linux-x64-cuda": "1.3.1",
-    "@fugood/node-llama-linux-arm64": "1.3.1",
-    "@fugood/node-llama-linux-arm64-vulkan": "1.3.1",
-    "@fugood/node-llama-linux-arm64-cuda": "1.3.1",
-    "@fugood/node-llama-win32-x64": "1.3.1",
-    "@fugood/node-llama-win32-x64-vulkan": "1.3.1",
-    "@fugood/node-llama-win32-x64-cuda": "1.3.1",
-    "@fugood/node-llama-win32-arm64": "1.3.1",
-    "@fugood/node-llama-win32-arm64-vulkan": "1.3.1",
-    "@fugood/node-llama-darwin-x64": "1.3.1",
-    "@fugood/node-llama-darwin-arm64": "1.3.1"
+    "@fugood/node-llama-linux-x64": "1.3.2",
+    "@fugood/node-llama-linux-x64-vulkan": "1.3.2",
+    "@fugood/node-llama-linux-x64-cuda": "1.3.2",
+    "@fugood/node-llama-linux-arm64": "1.3.2",
+    "@fugood/node-llama-linux-arm64-vulkan": "1.3.2",
+    "@fugood/node-llama-linux-arm64-cuda": "1.3.2",
+    "@fugood/node-llama-win32-x64": "1.3.2",
+    "@fugood/node-llama-win32-x64-vulkan": "1.3.2",
+    "@fugood/node-llama-win32-x64-cuda": "1.3.2",
+    "@fugood/node-llama-win32-arm64": "1.3.2",
+    "@fugood/node-llama-win32-arm64-vulkan": "1.3.2",
+    "@fugood/node-llama-darwin-x64": "1.3.2",
+    "@fugood/node-llama-darwin-arm64": "1.3.2"
   },
   "devDependencies": {
     "@babel/preset-env": "^7.24.4",
package/scripts/llama.cpp.patch
CHANGED

@@ -1,8 +1,8 @@
 diff --git a/src/llama.cpp/common/CMakeLists.txt b/src/llama.cpp/common/CMakeLists.txt
-index
+index 7086d08e5..9a727bcf8 100644
 --- a/src/llama.cpp/common/CMakeLists.txt
 +++ b/src/llama.cpp/common/CMakeLists.txt
-@@ -
+@@ -172,9 +172,16 @@ if (LLAMA_LLGUIDANCE)
      set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS})
  endif ()

@@ -85,10 +85,10 @@ index 50efb0d4e..f471a84c7 100644
      struct common_chat_tool_call {
          std::string name;
 diff --git a/src/llama.cpp/common/common.cpp b/src/llama.cpp/common/common.cpp
-index
+index a8d709ab1..d8aed9c7e 100644
 --- a/src/llama.cpp/common/common.cpp
 +++ b/src/llama.cpp/common/common.cpp
-@@ -
+@@ -1159,6 +1159,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
      mparams.n_gpu_layers = params.n_gpu_layers;
  }

@@ -97,7 +97,7 @@ index b0591e84b..93759f884 100644
      mparams.split_mode = params.split_mode;
      mparams.tensor_split = params.tensor_split;
 diff --git a/src/llama.cpp/common/common.h b/src/llama.cpp/common/common.h
-index
+index f42c083fa..c573cc812 100644
 --- a/src/llama.cpp/common/common.h
 +++ b/src/llama.cpp/common/common.h
 @@ -274,6 +274,7 @@ struct lr_opt {
@@ -109,7 +109,7 @@ index a8cb630ea..0919ec5d3 100644
      int32_t n_ctx = 4096; // context size
      int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
 diff --git a/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt b/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt
-index
+index a55191aed..53e318c62 100644
 --- a/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt
 +++ b/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt
 @@ -106,7 +106,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
package/src/llama.cpp/common/arg.cpp
CHANGED

@@ -2253,6 +2253,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.is_pp_shared = true;
         }
     ).set_examples({LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
+    add_opt(common_arg(
+        {"-tgs"},
+        string_format("is the text generation separated across the different sequences (default: %s)", params.is_tg_separate ? "true" : "false"),
+        [](common_params & params) {
+            params.is_tg_separate = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
     add_opt(common_arg(
        {"-npp"}, "n0,n1,...",
        "number of prompt tokens",
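Note: the new -tgs switch only flips params.is_tg_separate (declared in the common.h hunk below); how the bench/parallel examples react to it is not part of this diff. As a hedged illustration of the intent described by the flag text ("text generation separated across the different sequences"), a consumer might branch on the two batched-bench flags roughly like this:

    // Hypothetical consumer-side sketch, not code from this package.
    if (params.is_pp_shared) {
        // evaluate the shared prompt once and reuse it for every sequence
    }
    if (params.is_tg_separate) {
        // run each sequence's text-generation phase in its own decode calls,
        // instead of batching all sequences' generated tokens together
    }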
package/src/llama.cpp/common/common.h
CHANGED

@@ -461,7 +461,8 @@ struct common_params {
     float slot_prompt_similarity = 0.1f;

     // batched-bench params
-    bool is_pp_shared
+    bool is_pp_shared = false;
+    bool is_tg_separate = false;

     std::vector<int32_t> n_pp;
     std::vector<int32_t> n_tg;
package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt
CHANGED

@@ -126,25 +126,36 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
        )
        if (NOT ARM_MCPU_RESULT)
            string(REGEX MATCH "-mcpu=[^ ']+" ARM_MCPU_FLAG "${ARM_MCPU}")
+           string(REGEX MATCH "-march=[^ ']+" ARM_MARCH_FLAG "${ARM_MCPU}")
+
+           # on some old GCC we need to read -march=
+           if (ARM_MARCH_FLAG AND NOT "${ARM_MARCH_FLAG}" STREQUAL "-march=native")
+               set(ARM_NATIVE_FLAG "${ARM_MARCH_FLAG}")
+           elseif(ARM_MCPU_FLAG AND NOT "${ARM_MCPU_FLAG}" STREQUAL "-mcpu=native")
+               set(ARM_NATIVE_FLAG "${ARM_MCPU_FLAG}")
+           endif()
        endif()
-
-
-
+
+       if ("${ARM_NATIVE_FLAG}" STREQUAL "")
+           set(ARM_NATIVE_FLAG -mcpu=native)
+           message(WARNING "ARM -march/-mcpu not found, -mcpu=native will be used")
+       else()
+           message(STATUS "ARM detected flags: ${ARM_NATIVE_FLAG}")
        endif()

        include(CheckCXXSourceRuns)

        function(check_arm_feature tag code)
            set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
-           set(CMAKE_REQUIRED_FLAGS "${
+           set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+${tag}")
            check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag})
            if (GGML_MACHINE_SUPPORTS_${tag})
-               set(
+               set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+${tag}" PARENT_SCOPE)
            else()
-               set(CMAKE_REQUIRED_FLAGS "${
+               set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+no${tag}")
                check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag})
                if (GGML_MACHINE_SUPPORTS_no${tag})
-                   set(
+                   set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+no${tag}" PARENT_SCOPE)
                endif()
            endif()
            set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
@@ -155,7 +166,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
        check_arm_feature(sve "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
        check_arm_feature(sme "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")

-       list(APPEND ARCH_FLAGS "${
+       list(APPEND ARCH_FLAGS "${ARM_NATIVE_FLAG}${ARM_NATIVE_FLAG_FIX}")
    else()
        if (GGML_CPU_ARM_ARCH)
            list(APPEND ARCH_FLAGS -march=${GGML_CPU_ARM_ARCH})
package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c
CHANGED

@@ -2044,6 +2044,26 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

 }

+#ifdef __ARM_FEATURE_SVE
+static inline svuint32_t ggml_decode_q4scales_and_mins_for_mmla(const uint32_t * vx_scales) {
+    const svbool_t pg_all = svptrue_pat_b32(SV_VL4);
+    const svbool_t pg_false = svpfalse_b(); // 0x0000
+    const svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8); // 0x00ff
+    const svbool_t pg_odd = svzip1_b32(pg_false, pg_lo_8);
+
+    svuint32_t vutmp_hi, vutmp_lo;
+    svuint32_t vx01 = svld1_u32(pg_lo_8, vx_scales);
+    vutmp_hi = svzip1_u32(vx01, vx01);
+    vutmp_hi = svlsr_n_u32_m(pg_odd, vutmp_hi, 2);
+    vutmp_hi = svreinterpret_u32_u64(svand_n_u64_x(pg_all, svreinterpret_u64_u32(vutmp_hi), UINT64_C(0x303030303f3f3f3f)));
+    const svuint32_t vx2 = svdup_u32(vx_scales[2]);
+    vutmp_lo = svlsr_u32_x(pg_all, vx2, svreinterpret_u32_s32(svindex_s32(-2, 2)));
+    vutmp_lo = svand_n_u32_z(pg_odd, vutmp_lo, UINT32_C(0x0f0f0f0f));
+    svuint32_t vutmp = svorr_u32_z(pg_all, vutmp_hi, vutmp_lo);
+    return vutmp;
+}
+#endif
+
 void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
 #ifdef __ARM_FEATURE_MATMUL_INT8
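Note: the helper above is a vector version of the scalar scale/min unpacking that the q4_K kernels in ggml traditionally do through a small utmp buffer and the kmask constants (kmask3 = 0x03030303 appears as context in the next hunk). A rough scalar sketch of that decode, for reference only and assuming <stdint.h>/<string.h>:

    // Scalar sketch of the q4_K scales/mins decode that the SVE helper vectorizes
    // (illustrative only; the generic kernel in the tree is the authoritative version).
    // "scales" is the 12-byte packed field of a block_q4_K: 8 x 6-bit scales + 8 x 6-bit mins.
    static void decode_q4_scales_mins_ref(const uint8_t scales[12], uint8_t sc[8], uint8_t m[8]) {
        uint32_t utmp[4];
        const uint32_t kmask1 = 0x3f3f3f3f, kmask2 = 0x0f0f0f0f, kmask3 = 0x03030303;
        memcpy(utmp, scales, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;
        memcpy(sc, utmp,     8);   // eight 6-bit block scales
        memcpy(m,  utmp + 2, 8);   // eight 6-bit block mins
    }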
@@ -2066,8 +2086,220 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     static const uint32_t kmask3 = 0x03030303;

     uint32_t utmp[4];
+#ifdef __ARM_FEATURE_SVE
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
+#endif

-#if defined(__ARM_FEATURE_MATMUL_INT8)
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
+
+        const block_q4_K * GGML_RESTRICT vx0 = vx;
+        const block_q8_K * GGML_RESTRICT vy0 = vy;
+        const block_q4_K * GGML_RESTRICT vx1 = (const block_q4_K *) ((const uint8_t*)vx + bx);
+        const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
+
+        union {
+            uint32_t u32[8];
+            uint64_t u64[4];
+        } new_utmp;
+
+        svfloat32_t sumf1 = svdup_n_f32(0);
+
+        switch (vector_length) {
+        case 128:
+            {
+                svbool_t pg_false = svpfalse_b();
+                svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8);
+                svbool_t vmins_mask1= svzip1_b32(pg_lo_8, pg_false);
+                svbool_t vmins_mask2 = svzip1_b32(pg_false, pg_lo_8);
+                svbool_t pg128_all = svptrue_pat_b8(SV_VL16);
+                for (int i = 0; i < nb; ++i) {
+                    svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
+                    svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
+                    svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
+                    svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));
+                    svfloat32_t vy_dmins = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
+                    svfloat32_t svdmins = svmul_n_f32_x(pg128_all, svmul_f32_x(pg128_all, vy_dmins, vx_dmins), -1);
+                    const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;
+                    const int8_t * GGML_RESTRICT q8_0 = vy0[i].qs;
+                    const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;
+                    const int8_t * GGML_RESTRICT q8_1 = vy1[i].qs;
+                    svint16_t lo = svld1_s16(pg128_all, vy0[i].bsums + 0);
+                    svint16_t hi = svld1_s16(pg128_all, vy0[i].bsums + 8);
+                    svint16_t sum_tmp1 = svuzp1_s16(lo, hi);
+                    svint16_t sum_tmp2 = svuzp2_s16(lo, hi);
+                    svint16_t svq8sums_0 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
+                    lo = svld1_s16(pg128_all, vy1[i].bsums + 0);
+                    hi = svld1_s16(pg128_all, vy1[i].bsums + 8);
+                    sum_tmp1 = svuzp1(lo, hi);
+                    sum_tmp2 = svuzp2(lo, hi);
+                    svint16_t svq8sums_1 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
+                    svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
+                    svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
+                    svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
+                    svst2_u32(pg128_all, new_utmp.u32, decoded_scales);
+                    svint16_t svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp1_u32(svld1_u32(vmins_mask1, new_utmp.u32+4), svdup_n_u32(0)))));
+                    svint16_t svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp2_u32(svld1_u32(vmins_mask2, new_utmp.u32+4), svdup_n_u32(0)))));
+                    svint32_t svsumfs_tmp1 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_0));
+                    svint32_t svsumfs_tmp2 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_1));
+                    svint32_t svsumfs_tmp3 = svtrn1_s32(svsumfs_tmp1, svsumfs_tmp2);
+                    svint32_t svsumfs_tmp4 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_0));
+                    svint32_t svsumfs_tmp5 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_1));
+                    svint32_t svsumfs_tmp6 = svtrn1_s32(svsumfs_tmp4, svsumfs_tmp5);
+                    svint32_t svsumfs_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
+                    svint32_t svsumfs_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
+                    svint32_t svsumfs_tmp = svadd_s32_x(pg128_all, svsumfs_tmp7, svsumfs_tmp8);
+                    svint32_t svscales, sumi1, sumi2;
+                    svint32_t acc_sumif1 = svdup_n_s32(0);
+                    svint32_t acc_sumif2 = svdup_n_s32(0);
+                    svint8_t q4bytes_0_l, q4bytes_0_h, q4bytes_1_l, q4bytes_1_h, l0, l1, l2, l3,
+                             q8bytes_0_h, q8bytes_0_l, q8bytes_1_h, q8bytes_1_l, r0, r1, r2, r3;
+                    #pragma GCC unroll 1
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        q4bytes_0_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 0xf));
+                        q4bytes_1_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 0xf));
+                        q4bytes_0_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 0xf));
+                        q4bytes_1_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 0xf));
+                        l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+                        l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+                        l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+                        l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+                        q8bytes_0_h = svld1_s8(pg128_all, q8_0);
+                        q8bytes_1_h = svld1_s8(pg128_all, q8_1);
+                        q8bytes_0_l = svld1_s8(pg128_all, q8_0+16);
+                        q8bytes_1_l = svld1_s8(pg128_all, q8_1+16);
+                        r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+                        r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+                        r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+                        r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+                        sumi1 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
+                        svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
+                        acc_sumif1 = svmla_s32_x(pg128_all, acc_sumif1, svscales, sumi1);
+
+                        q4bytes_0_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 4));
+                        q4bytes_1_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 4));
+                        q4bytes_0_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 4));
+                        q4bytes_1_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 4));
+                        l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+                        l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+                        l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+                        l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+                        q8bytes_0_h = svld1_s8(pg128_all, q8_0+32);
+                        q8bytes_1_h = svld1_s8(pg128_all, q8_1+32);
+                        q8bytes_0_l = svld1_s8(pg128_all, q8_0+48);
+                        q8bytes_1_l = svld1_s8(pg128_all, q8_1+48);
+                        r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+                        r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+                        r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+                        r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+                        sumi2 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
+                        svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
+                        acc_sumif2 = svmla_s32_x(pg128_all, acc_sumif2, svscales, sumi2);
+                        q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
+                    }
+                    sumf1 = svmla_f32_x(pg128_all,
+                                svmla_f32_x(pg128_all,
+                                    sumf1,
+                                    svcvt_f32_x(pg128_all,
+                                        svadd_s32_x(pg128_all, acc_sumif1, acc_sumif2)),
+                                    svsuper_block_scales),
+                                svdmins,
+                                svcvt_f32_s32_x(pg128_all, svsumfs_tmp));
+                } //end of for nb
+            } // end of case 128
+            break;
+        case 256:
+        case 512:
+            {
+                const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+                const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
+                const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
+                for (int i = 0; i < nb; ++i) {
+                    const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;
+                    const int8_t * GGML_RESTRICT q8_0 = vy0[i].qs;
+                    const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;
+                    const int8_t * GGML_RESTRICT q8_1 = vy1[i].qs;
+                    svint32_t svscales, sumi1, sumi2;
+                    svint32_t acc_sumif1 = svdup_n_s32(0);
+                    svint32_t acc_sumif2 = svdup_n_s32(0);
+                    svint8_t l0, l1, l2, l3, r0, r1, r2, r3;
+                    svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
+                    svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
+                    svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
+                    svfloat32_t svsuper_block_scales = svmul_f32_z(pg32_4, vy_d, vx_d);
+                    svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));
+                    svfloat64_t vy_dmins_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
+                    svfloat32_t vy_dmins = svreinterpret_f32_f64(svuzp1_f64(vy_dmins_tmp, vy_dmins_tmp));
+                    svfloat32_t svdmins = svmul_n_f32_x(pg32_4, svmul_f32_x(pg32_4, vx_dmins, vy_dmins), -1);
+                    svint16_t rc1 = svuzp1_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
+                    svint16_t rc2 = svuzp2_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
+                    svint16_t svq8sums = svadd_s16_x(pg256_all, rc1, rc2);
+                    svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
+                    svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
+                    svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
+                    svst2_u32(pg8_16, new_utmp.u32, decoded_scales);
+                    svint16_t new_svq8sums_0 = svreinterpret_s16_u64(svtrn1_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
+                    svint16_t new_svq8sums_1 = svreinterpret_s16_u64(svtrn2_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
+                    svuint64_t new_mins_0 = svdup_u64(new_utmp.u64[2]);
+                    svuint64_t new_mins_1 = svdup_u64(new_utmp.u64[3]);
+                    svint16_t new_svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_0)));
+                    svint16_t new_svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_1)));
+                    svint64_t dot_prod_0 = svdot_s64(svdup_s64(0), new_svmins8_0, new_svq8sums_0);
+                    svint64_t dot_prod_1 = svdot_s64(dot_prod_0, new_svmins8_1, new_svq8sums_1);
+                    svfloat32_t converted_dot_prod_1 = svcvt_f32_s64_x(pg256_all, dot_prod_1);
+                    svfloat32_t svsumfs_tmp = svuzp1_f32(converted_dot_prod_1, converted_dot_prod_1);
+
+                    #pragma GCC unroll 1
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        svuint8_t q4bytes_0 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 0xf);
+                        svuint8_t q4bytes_1 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 0xf);
+                        svuint8_t q4bytes_2 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 4);
+                        svuint8_t q4bytes_3 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 4);
+                        l0 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
+                        l1 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
+                        l2 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
+                        l3 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
+                        svint8_t q8bytes_0 = svld1_s8(pg256_all, q8_0);
+                        svint8_t q8bytes_1 = svld1_s8(pg256_all, q8_1);
+                        svint8_t q8bytes_2 = svld1_s8(pg256_all, q8_0+32);
+                        svint8_t q8bytes_3 = svld1_s8(pg256_all, q8_1+32);
+                        r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+                        r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+                        r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
+                        r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
+                        sumi1 = svmmla(svmmla(svdup_n_s32(0), r0, l0), r1, l1);
+                        svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
+                        acc_sumif1 = svmla_s32_x(pg256_all, acc_sumif1, svscales, sumi1);
+                        sumi2 = svmmla(svmmla(svdup_n_s32(0), r2, l2), r3, l3);
+                        svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
+                        acc_sumif2 = svmla_s32_x(pg256_all, acc_sumif2, svscales, sumi2);
+                        q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
+                    }
+                    svint32_t acc_sumif = svadd_s32_x(pg256_all, acc_sumif1, acc_sumif2);
+                    svint32_t swap_acc_sumif = svext_s32(acc_sumif, acc_sumif, 4);
+                    acc_sumif = svadd_s32_x(pg32_4, acc_sumif, swap_acc_sumif);
+                    sumf1 = svmla_f32_x(pg32_4,
+                                svmla_f32_x(pg32_4,
+                                    sumf1,
+                                    svcvt_f32_x(pg32_4, acc_sumif),
+                                    svsuper_block_scales),
+                                svdmins,
+                                svsumfs_tmp);
+                } // end of for nb
+            } // end of case 256-512
+            break;
+        default:
+            assert(false && "Unsupported vector length");
+            break;
+        }
+
+        svst1_f32(pg32_2, s, sumf1);
+        svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sumf1), svdup_n_u8(0), 8)));
+
+        return;
+    }
+#elif defined(__ARM_FEATURE_MATMUL_INT8)
     if (nrc == 2) {
         const block_q4_K * GGML_RESTRICT x0 = x;
         const block_q4_K * GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx);
@@ -2235,7 +2467,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
         const uint8_t * GGML_RESTRICT q4 = x[i].qs;
         const int8_t * GGML_RESTRICT q8 = y[i].qs;

-        const int vector_length = ggml_cpu_get_sve_cnt()*8;
         const svuint8_t m4b = svdup_n_u8(0xf);
         const svint32_t mzero = svdup_n_s32(0);
         svint32_t sumi1 = svdup_n_s32(0);
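Note: the new nrc == 2 path above interleaves two q4_K rows with two q8_K rows so that each svmmla_s32 call produces a 2x2 tile of int32 dot products per 128-bit segment. A scalar model of what one such segment accumulates (per the Arm i8mm SMMLA definition) may help when reading the zip1/zip2 shuffles; it is only an illustration of the instruction's semantics, not code from the diff:

    #include <stdint.h>
    // acc (2x2, int32) += A (2x8, int8) * B^T (2x8, int8) for one 128-bit SMMLA segment.
    static void smmla_ref(int32_t acc[2][2], const int8_t a[2][8], const int8_t b[2][8]) {
        for (int i = 0; i < 2; ++i)
            for (int j = 0; j < 2; ++j)
                for (int k = 0; k < 8; ++k)
                    acc[i][j] += (int32_t) a[i][k] * (int32_t) b[j][k];
    }

The 64-bit zip1/zip2 operations are what lay the two rows of each operand out in that 2x8-by-2x8 shape before the mmla accumulate.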
@@ -2480,7 +2711,201 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

     const int nb = n / QK_K;

-#
+#ifdef __ARM_FEATURE_SVE
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
+#endif
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
+
+        svfloat32_t sum = svdup_n_f32(0);
+
+        const block_q6_K * GGML_RESTRICT vx0 = vx;
+        const block_q8_K * GGML_RESTRICT vy0 = vy;
+        const block_q6_K * GGML_RESTRICT vx1 = (const block_q6_K *) ((const uint8_t*)vx + bx);
+        const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
+
+        switch (vector_length) {
+        case 128:
+            {
+                const svbool_t pg128_all = svptrue_pat_b8(SV_ALL);
+                for (int i = 0; i < nb; ++i) {
+                    const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;
+                    const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;
+                    const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;
+                    const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;
+                    const int8_t * GGML_RESTRICT q80 = vy0[i].qs;
+                    const int8_t * GGML_RESTRICT q81 = vy1[i].qs;
+
+                    const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;
+                    const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;
+
+                    svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
+                    svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
+                    svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
+                    // process q8sum summation 128 bit route
+                    const svint16_t q8sums_01 = svld1_s16(pg128_all, vy0[i].bsums);
+                    const svint16_t q8sums_02 = svld1_s16(pg128_all, vy0[i].bsums + 8);
+                    const svint16_t q8sums_11 = svld1_s16(pg128_all, vy1[i].bsums);
+                    const svint16_t q8sums_12 = svld1_s16(pg128_all, vy1[i].bsums + 8);
+                    const svint64x2_t q6scales_0_tmp = svld2_s64(pg128_all, (const int64_t *)scale0);
+                    const svint16_t q6scales_01 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 0)));
+                    const svint16_t q6scales_02 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 1)));
+                    const svint64x2_t q6scales_1_tmp = svld2_s64(pg128_all, (const int64_t *)scale1);
+                    const svint16_t q6scales_11 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 0)));
+                    const svint16_t q6scales_12 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 1)));
+                    const svint64_t prod = svdup_n_s64(0);
+
+                    svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_01), q8sums_02, q6scales_02));
+                    svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_11), q8sums_02, q6scales_12));
+                    svint32_t isum_tmp3 = svtrn1_s32(isum_tmp1, isum_tmp2);
+                    svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_01), q8sums_12, q6scales_02));
+                    svint32_t isum_tmp5 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_11), q8sums_12, q6scales_12));
+                    svint32_t isum_tmp6 = svtrn1_s32(isum_tmp4, isum_tmp5);
+                    svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
+                    svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
+                    svint32_t svisum_mins = svadd_s32_x(pg128_all, isum_tmp7, isum_tmp8);
+
+                    // process mmla
+                    svint8_t l0, l1, r0, r1;
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; ++j) {
+                        for (int k = 0; k < 8; ++k) {
+                            svuint8_t qhbits_0 = svld1_u8(pg128_all, qh0+16*(k%2));
+                            svuint8_t qhbits_1 = svld1_u8(pg128_all, qh1+16*(k%2));
+                            svuint8_t q6bits_0 = svld1_u8(pg128_all, ql0+16*(k%4));
+                            svuint8_t q6bits_1 = svld1_u8(pg128_all, ql1+16*(k%4));
+                            const int ql_pos = (k/4)*4;
+                            svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_0, 4);
+                            svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_1, 4);
+                            const int qh_pos = (k/2)*2;
+                            svuint8_t q6bytes_0_hi = svand_n_u8_x(pg128_all, qhbits_0, 0x3 << qh_pos);
+                            svuint8_t q6bytes_1_hi = svand_n_u8_x(pg128_all, qhbits_1, 0x3 << qh_pos);
+                            svint8_t q6bytes_0, q6bytes_1;
+                            if (qh_pos <= 4) {
+                                q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
+                                q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
+                            } else {
+                                q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_0_lo, svlsr_n_u8_x(pg128_all, q6bytes_0_hi, (qh_pos - 4))));
+                                q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_1_lo, svlsr_n_u8_x(pg128_all, q6bytes_1_hi, (qh_pos - 4))));
+                            }
+                            svint8_t q8bytes_0 = svld1_s8(pg128_all, q80+16*(k%8));
+                            svint8_t q8bytes_1 = svld1_s8(pg128_all, q81+16*(k%8));
+                            l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+                            l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+                            r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+                            r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+                            svint32_t svscale = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
+                            isum_tmp = svmla_s32_x(pg128_all, isum_tmp, svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), svscale);
+                        }
+                        qh0 += 32; qh1 += 32;
+                        ql0 += 64; ql1 += 64;
+                        q80 += 128; q81 += 128;
+                        scale0 += 8; scale1 += 8;
+                    }
+                    sum = svmla_f32_x(pg128_all, sum,
+                              svcvt_f32_x(pg128_all, svmla_s32_x(pg128_all, isum_tmp,
+                                  svisum_mins, svdup_n_s32(-32))),
+                              svsuper_block_scales);
+                }
+            } // end of case 128
+            break;
+        case 256:
+        case 512:
+            {
+                const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
+                const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+                for (int i = 0; i < nb; ++i) {
+                    const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;
+                    const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;
+                    const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;
+                    const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;
+                    const int8_t * GGML_RESTRICT q80 = vy0[i].qs;
+                    const int8_t * GGML_RESTRICT q81 = vy1[i].qs;
+
+                    const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;
+                    const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;
+                    svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
+                    svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
+                    svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
+                    svfloat32_t svsuper_block_scales = svmul_f32_x(pg32_4, vy_d, vx_d);
+                    // process q8sum summation 256 bit route
+                    const svint16_t q8sums_0 = svld1_s16(pg256_all, vy0[i].bsums);
+                    const svint16_t q8sums_1 = svld1_s16(pg256_all, vy1[i].bsums);
+                    const svint16_t q6scales_0 = svunpklo_s16(svld1_s8(pg256_all, scale0));
+                    const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(pg256_all, scale1));
+                    const svint64_t prod = svdup_n_s64(0);
+                    svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_0));
+                    svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_1));
+                    svint32_t isum_tmp3 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_0));
+                    svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_1));
+                    svint32_t isum_tmp5 = svtrn1_s32(isum_tmp1, isum_tmp2);
+                    svint32_t isum_tmp6 = svtrn1_s32(isum_tmp3, isum_tmp4);
+                    svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
+                    svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
+                    svint32_t isum_tmp9 = svadd_s32_x(pg256_all, isum_tmp7, isum_tmp8);
+                    svint32_t isum_tmp10 = svreinterpret_s32_u8(svext_u8(svreinterpret_u8_s32(isum_tmp9), svreinterpret_u8_s32(isum_tmp9), 16));
+                    svint32_t svisum_mins = svadd_s32_z(pg32_4, isum_tmp9, isum_tmp10);
+
+                    // process mmla
+                    svint8_t l0, l1, r0, r1;
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; ++j) {
+                        for (int k = 0; k < 8; k+=2) { // process 2 block
+                            svuint8_t qhbits_0 = svld1_u8(pg256_all, qh0);
+                            svuint8_t qhbits_1 = svld1_u8(pg256_all, qh1);
+                            svuint8_t q6bits_0 = svld1_u8(pg256_all, ql0+32*((k%4)/2));
+                            svuint8_t q6bits_1 = svld1_u8(pg256_all, ql1+32*((k%4)/2));
+                            const int ql_pos = (k/4)*4;
+                            svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_0, 4);
+                            svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_1, 4);
+                            const int qh_pos = (k/2)*2;
+                            svuint8_t q6bytes_0_hi = svand_n_u8_x(pg256_all, qhbits_0, 0x3 << qh_pos);
+                            svuint8_t q6bytes_1_hi = svand_n_u8_x(pg256_all, qhbits_1, 0x3 << qh_pos);
+                            svint8_t q6bytes_0, q6bytes_1;
+                            if (qh_pos <= 4) {
+                                q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
+                                q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
+                            } else {
+                                q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_0_lo, svlsr_n_u8_x(pg256_all, q6bytes_0_hi, (qh_pos - 4))));
+                                q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_1_lo, svlsr_n_u8_x(pg256_all, q6bytes_1_hi, (qh_pos - 4))));
+                            }
+                            svint8_t q8bytes_0 = svld1_s8(pg256_all, q80+32*(k/2));
+                            svint8_t q8bytes_1 = svld1_s8(pg256_all, q81+32*(k/2));
+                            l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+                            l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+                            r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+                            r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+                            svint32_t svscale0 = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
+                            svint32_t svscale1 = svzip1_s32(svdup_n_s32(scale0[k+1]), svdup_n_s32(scale1[k+1]));
+                            isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r0, l0), svscale0);
+                            isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r1, l1), svscale1);
+                        }
+                        qh0 += 32; qh1 += 32;
+                        ql0 += 64; ql1 += 64;
+                        q80 += 128; q81 += 128;
+                        scale0 += 8; scale1 += 8;
+                    } // end of for
+                    svint32_t swap_isum_tmp = svext_s32(isum_tmp, isum_tmp, 4);
+                    isum_tmp = svadd_s32_x(pg32_4, isum_tmp, swap_isum_tmp);
+                    sum = svmla_f32_x(pg32_4, sum,
+                              svcvt_f32_x(pg32_4, svmla_s32_x(pg32_4, isum_tmp,
+                                  svisum_mins, svdup_n_s32(-32))),
+                              svsuper_block_scales);
+                }
+            } // end of case 256
+            break;
+        default:
+            assert(false && "Unsupported vector length");
+            break;
+        } // end of switch
+
+        svst1_f32(pg32_2, s, sum);
+        svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sum), svdup_n_u8(0), 8)));
+
+        return;
+    }
+#elif defined(__ARM_FEATURE_MATMUL_INT8)
     if (nrc == 2) {
         const block_q6_K * GGML_RESTRICT x0 = x;
         const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
@@ -2594,27 +3019,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
         // adjust bias, apply superblock scale
         {
             int32_t bias[4];
-#ifdef __ARM_FEATURE_SVE
-            const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
-            const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
-            const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
-            const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
-            const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
-            const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
-            const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
-            const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
-            const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
-            const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
-            const svint64_t zero = svdup_n_s64(0);
-            bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
-                                 svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
-            bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
-                                 svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
-            bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
-                                 svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
-            bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
-                                 svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
-#else
             // NEON doesn't support int16 dot product, fallback to separated mul and add
             const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
             const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
@@ -2646,7 +3050,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
                                  vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
             bias[3] = vaddvq_s32(prod);

-#endif
             const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);

             const float32x4_t superblock_scale = {
@@ -2672,7 +3075,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 #endif

 #ifdef __ARM_FEATURE_SVE
-    const int vector_length = ggml_cpu_get_sve_cnt()*8;
     float sum = 0;
     svuint8_t m4b = svdup_n_u8(0xf);
     svint32_t vzero = svdup_n_s32(0);
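Note: both new q6_K branches fold the fixed -32 offset of the 6-bit quants through the cached bsums, which is why the old per-pair SVE bias block further down could be removed. In scalar form the identity being used is roughly the following (illustrative sketch, not code from the diff): q6_K stores 6-bit quants q in [0, 63] and dequantizes as d * sc * (q - 32), so the -32 term reduces to a dot product of the 16 block scales with the 16 per-block q8 sums that block_q8_K already caches in bsums[].

    // Scalar view of the bias term that svmla_s32_x(..., svisum_mins, svdup_n_s32(-32)) applies.
    static int32_t q6_k_bias_ref(const int8_t scales[16], const int16_t bsums[16]) {
        int32_t bias = 0;
        for (int k = 0; k < 16; ++k) {
            bias += (int32_t) scales[k] * (int32_t) bsums[k];
        }
        return -32 * bias;   // added to the integer mmla accumulator before scaling by d
    }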
package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c
CHANGED

@@ -1807,22 +1807,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_cont(params, tensor);
             } break;
-        case GGML_OP_RESHAPE:
-            {
-                ggml_compute_forward_reshape(params, tensor);
-            } break;
-        case GGML_OP_VIEW:
-            {
-                ggml_compute_forward_view(params, tensor);
-            } break;
-        case GGML_OP_PERMUTE:
-            {
-                ggml_compute_forward_permute(params, tensor);
-            } break;
-        case GGML_OP_TRANSPOSE:
-            {
-                ggml_compute_forward_transpose(params, tensor);
-            } break;
         case GGML_OP_GET_ROWS:
             {
                 ggml_compute_forward_get_rows(params, tensor);
@@ -2042,6 +2026,22 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 // nop
             } break;
+        case GGML_OP_RESHAPE:
+            {
+                // nop
+            } break;
+        case GGML_OP_PERMUTE:
+            {
+                // nop
+            } break;
+        case GGML_OP_VIEW:
+            {
+                // nop
+            } break;
+        case GGML_OP_TRANSPOSE:
+            {
+                // nop
+            } break;
         case GGML_OP_COUNT:
             {
                 GGML_ABORT("fatal error");
@@ -2884,6 +2884,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
     for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
         struct ggml_tensor * node = cgraph->nodes[node_n];

+        if (ggml_op_is_empty(node->op)) {
+            // skip NOPs
+            continue;
+        }
+
         ggml_compute_forward(&params, node);

         if (state->ith == 0 && cplan->abort_callback &&
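Note: the skip relies on ggml_op_is_empty() to flag ops that never touch tensor data; that helper is not shown in this diff. A predicate of the following shape (assumed here, covering the same four ops moved to the nop section above plus GGML_OP_NONE) is what the change implies:

    // Assumed shape of the predicate used above; the real helper lives elsewhere in ggml.
    static bool op_is_empty_sketch(enum ggml_op op) {
        switch (op) {
            case GGML_OP_NONE:
            case GGML_OP_RESHAPE:
            case GGML_OP_VIEW:
            case GGML_OP_PERMUTE:
            case GGML_OP_TRANSPOSE:
                return true;    // metadata-only ops: shape/stride changes, no kernel work
            default:
                return false;
        }
    }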
package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp
CHANGED

@@ -4455,46 +4455,6 @@ void ggml_compute_forward_cont(
     ggml_compute_forward_dup(params, dst);
 }

-// ggml_compute_forward_reshape
-
-void ggml_compute_forward_reshape(
-        const ggml_compute_params * params,
-        ggml_tensor * dst) {
-    // NOP
-    GGML_UNUSED(params);
-    GGML_UNUSED(dst);
-}
-
-// ggml_compute_forward_view
-
-void ggml_compute_forward_view(
-        const ggml_compute_params * params,
-        ggml_tensor * dst) {
-    // NOP
-    GGML_UNUSED(params);
-    GGML_UNUSED(dst);
-}
-
-// ggml_compute_forward_permute
-
-void ggml_compute_forward_permute(
-        const ggml_compute_params * params,
-        ggml_tensor * dst) {
-    // NOP
-    GGML_UNUSED(params);
-    GGML_UNUSED(dst);
-}
-
-// ggml_compute_forward_transpose
-
-void ggml_compute_forward_transpose(
-        const ggml_compute_params * params,
-        ggml_tensor * dst) {
-    // NOP
-    GGML_UNUSED(params);
-    GGML_UNUSED(dst);
-}
-
 // ggml_compute_forward_get_rows

 static void ggml_compute_forward_get_rows_q(
package/src/llama.cpp/ggml/src/ggml-cpu/ops.h
CHANGED

@@ -51,10 +51,6 @@ void ggml_compute_forward_scale(const struct ggml_compute_params * params, struc
 void ggml_compute_forward_set(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cpy(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cont(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_reshape(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_view(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_permute(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
package/src/llama.cpp/src/llama-memory-recurrent.cpp
CHANGED

@@ -151,7 +151,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         p1 = std::numeric_limits<llama_pos>::max();
     }

-    // models like Mamba or RWKV can't have a state partially erased
+    // models like Mamba or RWKV can't have a state partially erased at the end
+    // of the sequence because their state isn't preserved for previous tokens
     if (seq_id >= (int64_t) size) {
         // could be fatal
         return false;
@@ -160,8 +161,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         int32_t & tail_id = cells[seq_id].tail;
         if (tail_id >= 0) {
             const auto & cell = cells[tail_id];
-            // partial intersection is invalid
-            if (
+            // partial intersection is invalid if it includes the final pos
+            if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
                 //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
                 return false;
             }
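Note: reading the new condition, removal of [p0, p1) is now rejected only when it keeps some prefix (p0 > 0) yet erases the tail cell's final position (p0 <= cell.pos < p1), because a recurrent state cannot be rolled back to an earlier position. A standalone restatement of the predicate, assuming the same variable meanings:

    // Hedged restatement of the predicate used above (positions assumed to fit in int64_t).
    static bool rejects_partial_tail_erase(int64_t p0, int64_t p1, int64_t tail_pos) {
        const bool keeps_a_prefix  = 0 < p0;
        const bool covers_tail_pos = p0 <= tail_pos && tail_pos < p1;
        return keeps_a_prefix && covers_tail_pos;   // erasing the end but not the whole sequence
    }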
package/src/llama.cpp/src/models/ernie4-5.cpp
CHANGED

@@ -1,7 +1,5 @@
 #include "models.h"

-
-
 llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -19,6 +17,8 @@ llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_grap

     auto * inp_attn = build_attn_inp_kv();

+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
     for (int il = 0; il < n_layer; ++il) {
         ggml_tensor * inpSA = inpL;

@@ -67,9 +67,8 @@ llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_grap
         }
         if (il == n_layer - 1) {
             // skip computing output for unused tokens
-
-
-            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
         }
         ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
         cb(ffn_inp, "ffn_inp", il);
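Note: both this builder and the openai-moe-iswa builder below apply the same refactor: the output-row index tensor is layer-independent, so it is now built once before the layer loop instead of inside the last-layer branch. Schematically (not a literal excerpt):

    // Pattern shared by the two model builders in this diff (schematic only).
    ggml_tensor * inp_out_ids = build_inp_out_ids();      // built once, before the loop

    for (int il = 0; il < n_layer; ++il) {
        // ... attention / FFN for layer il ...
        if (il == n_layer - 1) {
            // keep only the rows whose logits are actually needed
            cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
        }
    }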
package/src/llama.cpp/src/models/openai-moe-iswa.cpp
CHANGED

@@ -11,6 +11,8 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model,

     auto * inp_attn = build_attn_inp_kv_iswa();

+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
     for (int il = 0; il < n_layer; ++il) {
         ggml_tensor * inpSA = inpL;

@@ -69,7 +71,6 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model,
         }
         if (il == n_layer - 1) {
             // skip computing output for unused tokens
-            ggml_tensor * inp_out_ids = build_inp_out_ids();
             cur = ggml_get_rows(ctx0, cur, inp_out_ids);
             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
         }