@fugood/llama.node 0.3.15 → 0.3.17
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +3 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +5 -0
- package/package.json +1 -1
- package/src/LlamaCompletionWorker.cpp +8 -0
- package/src/LlamaCompletionWorker.h +1 -0
- package/src/LlamaContext.cpp +3 -2
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
- package/src/llama.cpp/.github/workflows/build.yml +70 -27
- package/src/llama.cpp/.github/workflows/docker.yml +6 -6
- package/src/llama.cpp/.github/workflows/server.yml +7 -11
- package/src/llama.cpp/CMakeLists.txt +23 -1
- package/src/llama.cpp/common/CMakeLists.txt +6 -3
- package/src/llama.cpp/common/arg.cpp +809 -105
- package/src/llama.cpp/common/arg.h +9 -0
- package/src/llama.cpp/common/chat.cpp +1 -1
- package/src/llama.cpp/common/common.cpp +31 -521
- package/src/llama.cpp/common/common.h +17 -36
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
- package/src/llama.cpp/common/llguidance.cpp +30 -47
- package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
- package/src/llama.cpp/common/minja/minja.hpp +119 -93
- package/src/llama.cpp/common/sampling.cpp +3 -0
- package/src/llama.cpp/docs/build.md +122 -7
- package/src/llama.cpp/examples/CMakeLists.txt +0 -9
- package/src/llama.cpp/examples/batched/batched.cpp +1 -1
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
- package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
- package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
- package/src/llama.cpp/examples/llava/clip.h +39 -22
- package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
- package/src/llama.cpp/examples/llava/llava.cpp +64 -52
- package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
- package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
- package/src/llama.cpp/examples/llava/mtmd.h +168 -0
- package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
- package/src/llama.cpp/examples/main/main.cpp +16 -5
- package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
- package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
- package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
- package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
- package/src/llama.cpp/examples/run/run.cpp +14 -28
- package/src/llama.cpp/examples/server/httplib.h +313 -247
- package/src/llama.cpp/examples/server/server.cpp +243 -139
- package/src/llama.cpp/examples/server/utils.hpp +51 -2
- package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/build.sh +2 -2
- package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
- package/src/llama.cpp/examples/tts/tts.cpp +14 -9
- package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
- package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
- package/src/llama.cpp/ggml/include/ggml.h +66 -99
- package/src/llama.cpp/ggml/src/CMakeLists.txt +15 -8
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
- package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2413 -228
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1004 -13516
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +127 -33
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +29 -293
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +210 -286
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +692 -126
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +21 -10
- package/src/llama.cpp/ggml/src/ggml.c +141 -245
- package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
- package/src/llama.cpp/include/llama.h +30 -11
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +2 -0
- package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
- package/src/llama.cpp/src/CMakeLists.txt +3 -2
- package/src/llama.cpp/src/llama-adapter.cpp +37 -1
- package/src/llama.cpp/src/llama-arch.cpp +161 -17
- package/src/llama.cpp/src/llama-arch.h +16 -0
- package/src/llama.cpp/src/llama-chat.cpp +82 -17
- package/src/llama.cpp/src/llama-chat.h +6 -2
- package/src/llama.cpp/src/llama-context.cpp +108 -92
- package/src/llama.cpp/src/llama-context.h +1 -2
- package/src/llama.cpp/src/llama-graph.cpp +189 -119
- package/src/llama.cpp/src/llama-graph.h +26 -6
- package/src/llama.cpp/src/llama-hparams.h +13 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
- package/src/llama.cpp/src/llama-kv-cache.h +41 -115
- package/src/llama.cpp/src/llama-memory.h +1 -1
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
- package/src/llama.cpp/src/llama-model-loader.h +5 -3
- package/src/llama.cpp/src/llama-model.cpp +1544 -291
- package/src/llama.cpp/src/llama-model.h +13 -1
- package/src/llama.cpp/src/llama-quant.cpp +29 -8
- package/src/llama.cpp/src/llama-sampling.cpp +7 -1
- package/src/llama.cpp/src/llama-vocab.cpp +44 -6
- package/src/llama.cpp/src/llama.cpp +1 -1
- package/src/llama.cpp/tests/CMakeLists.txt +43 -30
- package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
- package/src/llama.cpp/tests/test-backend-ops.cpp +139 -57
- package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
- package/src/llama.cpp/tests/test-chat.cpp +12 -2
- package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
- package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
- package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
- package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
|
@@ -45,6 +45,24 @@ using block_q4_0x8 = block<4, 8>;
|
|
|
45
45
|
using block_q8_0x4 = block<8, 4>;
|
|
46
46
|
using block_q8_0x8 = block<8, 8>;
|
|
47
47
|
|
|
48
|
+
|
|
49
|
+
struct block_q4_Kx8 {
|
|
50
|
+
ggml_half d[8]; // super-block scale for quantized scales
|
|
51
|
+
ggml_half dmin[8]; // super-block scale for quantized mins
|
|
52
|
+
uint8_t scales[96]; // scales and mins, quantized with 6 bits
|
|
53
|
+
uint8_t qs[1024]; // 4--bit quants
|
|
54
|
+
};
|
|
55
|
+
|
|
56
|
+
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
|
|
57
|
+
|
|
58
|
+
struct block_q8_Kx4 {
|
|
59
|
+
float d[4]; // delta
|
|
60
|
+
int8_t qs[QK_K * 4]; // quants
|
|
61
|
+
int16_t bsums[QK_K / 4]; // sum of quants in groups of 16
|
|
62
|
+
};
|
|
63
|
+
|
|
64
|
+
static_assert(sizeof(block_q8_Kx4) == sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), "wrong q8_K block size/padding");
|
|
65
|
+
|
|
48
66
|
struct block_iq4_nlx4 {
|
|
49
67
|
ggml_half d[4]; // deltas for 4 iq4_nl blocks
|
|
50
68
|
uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
|
|
@@ -60,6 +78,13 @@ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wro
|
|
|
60
78
|
|
|
61
79
|
#define UNUSED GGML_UNUSED
|
|
62
80
|
|
|
81
|
+
static inline int nearest_int(float fval) {
|
|
82
|
+
assert(fabsf(fval) <= 4194303.f);
|
|
83
|
+
float val = fval + 12582912.f;
|
|
84
|
+
int i; memcpy(&i, &val, sizeof(int));
|
|
85
|
+
return (i & 0x007fffff) - 0x00400000;
|
|
86
|
+
}
|
|
87
|
+
|
|
63
88
|
// Functions to create the interleaved data layout formats
|
|
64
89
|
|
|
65
90
|
// interleave 4 block_q4_0s in blocks of blck_size_interleave
|
|
@@ -158,74 +183,70 @@ static inline __m256 __avx_rearranged_f32cx8_load(ggml_fp16_t *x, __m128i arrang
|
|
|
158
183
|
|
|
159
184
|
#if defined(__AVX2__) || defined(__AVX512F__)
|
|
160
185
|
#if defined(__AVX512F__)
|
|
161
|
-
// add int16_t pairwise and return as 512 bit int vector
|
|
162
|
-
static inline __m512i
|
|
186
|
+
// add int16_t pairwise and return as 512 bit int vector, then add the accumulator
|
|
187
|
+
static inline __m512i sum_i16_pairs_acc_int32x16(const __m512i acc, const __m512i x) {
|
|
163
188
|
const __m512i ones = _mm512_set1_epi16(1);
|
|
164
|
-
return _mm512_madd_epi16(ones, x);
|
|
189
|
+
return _mm512_add_epi32(acc, _mm512_madd_epi16(ones, x));
|
|
165
190
|
}
|
|
166
191
|
|
|
167
|
-
static inline __m512i
|
|
192
|
+
static inline __m512i mul_sum_us8_pairs_acc_int32x16(const __m512i acc, const __m512i ax, const __m512i sy) {
|
|
168
193
|
#if defined(__AVX512VNNI__)
|
|
169
|
-
|
|
170
|
-
return _mm512_dpbusd_epi32(zero, ax, sy);
|
|
194
|
+
return _mm512_dpbusd_epi32(acc, ax, sy);
|
|
171
195
|
#else
|
|
172
196
|
// Perform multiplication and create 16-bit values
|
|
173
197
|
const __m512i dot = _mm512_maddubs_epi16(ax, sy);
|
|
174
|
-
return
|
|
198
|
+
return sum_i16_pairs_acc_int32x16(acc, dot);
|
|
175
199
|
#endif
|
|
176
200
|
}
|
|
177
201
|
|
|
178
|
-
// multiply int8_t, add results pairwise twice and return as 512 bit int vector
|
|
179
|
-
static inline __m512i
|
|
202
|
+
// multiply int8_t, add results pairwise twice and return as 512 bit int vector,then add the accumulator
|
|
203
|
+
static inline __m512i mul_sum_i8_pairs_acc_int32x16(const __m512i acc, const __m512i x, const __m512i y) {
|
|
180
204
|
const __m512i zero = _mm512_setzero_si512();
|
|
181
205
|
// Get absolute values of x vectors
|
|
182
206
|
const __m512i ax = _mm512_abs_epi8(x);
|
|
183
207
|
// Sign the values of the y vectors
|
|
184
208
|
__mmask64 blt0 = _mm512_movepi8_mask(x);
|
|
185
209
|
const __m512i sy = _mm512_mask_sub_epi8(y, blt0, zero, y);
|
|
186
|
-
return
|
|
210
|
+
return mul_sum_us8_pairs_acc_int32x16(acc, ax, sy);
|
|
187
211
|
}
|
|
188
212
|
#endif
|
|
189
213
|
|
|
190
|
-
// add int16_t pairwise and return as 256 bit int vector
|
|
191
|
-
static inline __m256i
|
|
214
|
+
// add int16_t pairwise and return as 256 bit int vector, then add the accumulator
|
|
215
|
+
static inline __m256i sum_i16_pairs_acc_int32x8(const __m256i acc, const __m256i x) {
|
|
192
216
|
const __m256i ones = _mm256_set1_epi16(1);
|
|
193
|
-
return _mm256_madd_epi16(ones, x);
|
|
217
|
+
return _mm256_add_epi32(acc, _mm256_madd_epi16(ones, x));
|
|
194
218
|
}
|
|
195
219
|
|
|
196
|
-
static inline __m256i
|
|
220
|
+
static inline __m256i mul_sum_us8_pairs_acc_int32x8(const __m256i acc, const __m256i ax, const __m256i sy) {
|
|
197
221
|
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
|
|
198
|
-
|
|
199
|
-
return _mm256_dpbusd_epi32(zero, ax, sy);
|
|
222
|
+
return _mm256_dpbusd_epi32(acc, ax, sy);
|
|
200
223
|
#elif defined(__AVXVNNI__)
|
|
201
|
-
|
|
202
|
-
return _mm256_dpbusd_avx_epi32(zero, ax, sy);
|
|
224
|
+
return _mm256_dpbusd_avx_epi32(acc, ax, sy);
|
|
203
225
|
#else
|
|
204
226
|
// Perform multiplication and create 16-bit values
|
|
205
227
|
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
|
206
|
-
return
|
|
228
|
+
return sum_i16_pairs_acc_int32x8(acc, dot);
|
|
207
229
|
#endif
|
|
208
230
|
}
|
|
209
231
|
|
|
210
232
|
// Integer variant of the function defined in ggml-quants.c
|
|
211
|
-
// multiply int8_t, add results pairwise twice and return as 256 bit int vector
|
|
212
|
-
static inline __m256i
|
|
213
|
-
#if __AVXVNNIINT8__
|
|
214
|
-
|
|
215
|
-
return _mm256_dpbssd_epi32(zero, x, y);
|
|
233
|
+
// multiply int8_t, add results pairwise twice and return as 256 bit int vector, then add the accumulator
|
|
234
|
+
static inline __m256i mul_sum_i8_pairs_acc_int32x8(const __m256i acc, const __m256i x, const __m256i y) {
|
|
235
|
+
#if defined(__AVXVNNIINT8__)
|
|
236
|
+
return _mm256_dpbssd_epi32(acc, x, y);
|
|
216
237
|
#else
|
|
217
238
|
// Get absolute values of x vectors
|
|
218
239
|
const __m256i ax = _mm256_sign_epi8(x, x);
|
|
219
240
|
// Sign the values of the y vectors
|
|
220
241
|
const __m256i sy = _mm256_sign_epi8(y, x);
|
|
221
|
-
return
|
|
242
|
+
return mul_sum_us8_pairs_acc_int32x8(acc, ax, sy);
|
|
222
243
|
#endif
|
|
223
244
|
}
|
|
224
245
|
#endif
|
|
225
246
|
|
|
226
247
|
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
|
227
248
|
|
|
228
|
-
static void
|
|
249
|
+
static void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
|
229
250
|
assert(QK8_0 == 32);
|
|
230
251
|
assert(k % QK8_0 == 0);
|
|
231
252
|
const int nb = k / QK8_0;
|
|
@@ -319,7 +340,7 @@ static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRIC
|
|
|
319
340
|
#endif
|
|
320
341
|
}
|
|
321
342
|
|
|
322
|
-
static void
|
|
343
|
+
static void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
|
323
344
|
assert(QK8_0 == 32);
|
|
324
345
|
assert(k % QK8_0 == 0);
|
|
325
346
|
const int nb = k / QK8_0;
|
|
@@ -534,16 +555,289 @@ static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
|
|
|
534
555
|
#endif
|
|
535
556
|
}
|
|
536
557
|
|
|
537
|
-
static void
|
|
558
|
+
static void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
|
559
|
+
assert(QK_K == 256);
|
|
560
|
+
assert(k % QK_K == 0);
|
|
561
|
+
const int nb = k / QK_K;
|
|
562
|
+
|
|
563
|
+
block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
|
|
564
|
+
|
|
565
|
+
#if defined(__AVX2__)
|
|
566
|
+
float iscale[4];
|
|
567
|
+
__m256 srcv[4][32];
|
|
568
|
+
__m256 iscale_vec[4];
|
|
569
|
+
|
|
570
|
+
for (int i = 0; i < nb; i++) {
|
|
571
|
+
for (int row_iter = 0; row_iter < 4; row_iter++) {
|
|
572
|
+
// Load elements into 4 AVX vectors
|
|
573
|
+
__m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 );
|
|
574
|
+
__m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 8 );
|
|
575
|
+
__m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 16 );
|
|
576
|
+
__m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 24 );
|
|
577
|
+
|
|
578
|
+
// Compute max(abs(e)) for the block
|
|
579
|
+
const __m256 signBit = _mm256_set1_ps( -0.0f );
|
|
580
|
+
__m256 abs0 = _mm256_andnot_ps( signBit, v0 );
|
|
581
|
+
__m256 abs1 = _mm256_andnot_ps( signBit, v1 );
|
|
582
|
+
__m256 abs2 = _mm256_andnot_ps( signBit, v2 );
|
|
583
|
+
__m256 abs3 = _mm256_andnot_ps( signBit, v3 );
|
|
584
|
+
|
|
585
|
+
__m256 maxAbs = _mm256_max_ps( abs0, abs1 );
|
|
586
|
+
maxAbs = _mm256_max_ps( maxAbs, abs2 );
|
|
587
|
+
maxAbs = _mm256_max_ps( maxAbs, abs3 );
|
|
588
|
+
|
|
589
|
+
__m256 mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
|
|
590
|
+
__m256 mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
|
|
591
|
+
__m256 mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
|
|
592
|
+
__m256 mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
|
|
593
|
+
|
|
594
|
+
__m256 maskAbs = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
|
|
595
|
+
|
|
596
|
+
srcv[row_iter][0] = v0;
|
|
597
|
+
srcv[row_iter][1] = v1;
|
|
598
|
+
srcv[row_iter][2] = v2;
|
|
599
|
+
srcv[row_iter][3] = v3;
|
|
600
|
+
|
|
601
|
+
for (int sb = 1; sb < 8; sb++) {
|
|
602
|
+
// Temporarily stores absolute quant values
|
|
603
|
+
__m256 tempAbs = maxAbs;
|
|
604
|
+
|
|
605
|
+
// Load elements into 4 AVX vectors
|
|
606
|
+
__m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32);
|
|
607
|
+
__m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 8 );
|
|
608
|
+
__m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 16 );
|
|
609
|
+
__m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 24 );
|
|
610
|
+
|
|
611
|
+
// Compute max(abs(e)) for the block
|
|
612
|
+
__m256 abs0 = _mm256_andnot_ps( signBit, v0 );
|
|
613
|
+
__m256 abs1 = _mm256_andnot_ps( signBit, v1 );
|
|
614
|
+
__m256 abs2 = _mm256_andnot_ps( signBit, v2 );
|
|
615
|
+
__m256 abs3 = _mm256_andnot_ps( signBit, v3 );
|
|
616
|
+
|
|
617
|
+
maxAbs = _mm256_max_ps( maxAbs, abs0 );
|
|
618
|
+
maxAbs = _mm256_max_ps( maxAbs, abs1 );
|
|
619
|
+
maxAbs = _mm256_max_ps( maxAbs, abs2 );
|
|
620
|
+
maxAbs = _mm256_max_ps( maxAbs, abs3 );
|
|
621
|
+
|
|
622
|
+
__m256 mask_prev = _mm256_cmp_ps( tempAbs, maxAbs, _CMP_EQ_OQ );
|
|
623
|
+
maskAbs = _mm256_and_ps( maskAbs, mask_prev );
|
|
624
|
+
|
|
625
|
+
mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
|
|
626
|
+
mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
|
|
627
|
+
mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
|
|
628
|
+
mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
|
|
629
|
+
|
|
630
|
+
__m256 mask_curr = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
|
|
631
|
+
maskAbs = _mm256_or_ps(maskAbs, mask_curr);
|
|
632
|
+
|
|
633
|
+
srcv[row_iter][sb * 4] = v0;
|
|
634
|
+
srcv[row_iter][sb * 4 + 1] = v1;
|
|
635
|
+
srcv[row_iter][sb * 4 + 2] = v2;
|
|
636
|
+
srcv[row_iter][sb * 4 + 3] = v3;
|
|
637
|
+
}
|
|
638
|
+
|
|
639
|
+
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
|
|
640
|
+
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
|
641
|
+
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
|
642
|
+
const float maxScalar = _mm_cvtss_f32( max4 );
|
|
643
|
+
|
|
644
|
+
__m256 maxScalarVec = _mm256_set1_ps(maxScalar);
|
|
645
|
+
|
|
646
|
+
__m256 mask_next = _mm256_cmp_ps( maxScalarVec, maxAbs, _CMP_EQ_OQ );
|
|
647
|
+
__m256 finalMask = _mm256_and_ps(maskAbs, mask_next);
|
|
648
|
+
|
|
649
|
+
const int mask = _mm256_movemask_ps(finalMask);
|
|
650
|
+
iscale[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
|
|
651
|
+
|
|
652
|
+
if(mask) {
|
|
653
|
+
iscale[row_iter] = ( maxScalar != 0.0f ) ? -127.f / maxScalar: 0.0f;
|
|
654
|
+
}
|
|
655
|
+
|
|
656
|
+
y[i].d[row_iter] = maxScalar ? 1/iscale[row_iter] : 0;
|
|
657
|
+
iscale_vec[row_iter] = _mm256_set1_ps(iscale[row_iter]);
|
|
658
|
+
}
|
|
659
|
+
|
|
660
|
+
__m256i quants_interleaved[32];
|
|
661
|
+
for (int j = 0; j < 32; j++) {
|
|
662
|
+
// Apply the multiplier
|
|
663
|
+
__m256 v0 = _mm256_mul_ps(srcv[0][j], iscale_vec[0]);
|
|
664
|
+
__m256 v1 = _mm256_mul_ps(srcv[1][j], iscale_vec[1]);
|
|
665
|
+
__m256 v2 = _mm256_mul_ps(srcv[2][j], iscale_vec[2]);
|
|
666
|
+
__m256 v3 = _mm256_mul_ps(srcv[3][j], iscale_vec[3]);
|
|
667
|
+
|
|
668
|
+
// Round to nearest integer
|
|
669
|
+
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
|
|
670
|
+
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
|
|
671
|
+
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
|
|
672
|
+
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
|
|
673
|
+
|
|
674
|
+
// Convert floats to integers
|
|
675
|
+
__m256i i0 = _mm256_cvtps_epi32( v0 );
|
|
676
|
+
__m256i i1 = _mm256_cvtps_epi32( v1 );
|
|
677
|
+
__m256i i2 = _mm256_cvtps_epi32( v2 );
|
|
678
|
+
__m256i i3 = _mm256_cvtps_epi32( v3 );
|
|
679
|
+
|
|
680
|
+
// Convert int32 to int16
|
|
681
|
+
i0 = _mm256_packs_epi32( i0, i1 );
|
|
682
|
+
i2 = _mm256_packs_epi32( i2, i3 );
|
|
683
|
+
// Convert int16 to int8
|
|
684
|
+
i0 = _mm256_packs_epi16( i0, i2 );
|
|
685
|
+
|
|
686
|
+
// Permute and store the quantized weights in the required order after the pack instruction
|
|
687
|
+
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
|
|
688
|
+
i0 = _mm256_permutevar8x32_epi32( i0, perm );
|
|
689
|
+
|
|
690
|
+
_mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);
|
|
691
|
+
quants_interleaved[j] = i0;
|
|
692
|
+
}
|
|
693
|
+
|
|
694
|
+
// Masks to shuffle the quants of corresonding sub blocks for rearraning quants for vectorized bsums computation
|
|
695
|
+
__m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15));
|
|
696
|
+
shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0);
|
|
697
|
+
__m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15));
|
|
698
|
+
shuffle_mask_sb3 = _mm256_permute2f128_si256(shuffle_mask_sb3, shuffle_mask_sb3, 0);
|
|
699
|
+
__m256i shuffle_mask_sb4 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 0, 1, 8, 9, 10, 11, 12, 13, 8, 9));
|
|
700
|
+
shuffle_mask_sb4 = _mm256_permute2f128_si256(shuffle_mask_sb4, shuffle_mask_sb4, 0);
|
|
701
|
+
|
|
702
|
+
for (int k = 0; k < 4; k++) {
|
|
703
|
+
// Quants from four different sub blocks are taken
|
|
704
|
+
__m256i q0 = quants_interleaved[k * 8 + 0];
|
|
705
|
+
__m256i q1 = quants_interleaved[k * 8 + 1];
|
|
706
|
+
__m256i q2 = quants_interleaved[k * 8 + 2];
|
|
707
|
+
__m256i q3 = quants_interleaved[k * 8 + 3];
|
|
708
|
+
__m256i q4 = quants_interleaved[k * 8 + 4];
|
|
709
|
+
__m256i q5 = quants_interleaved[k * 8 + 5];
|
|
710
|
+
__m256i q6 = quants_interleaved[k * 8 + 6];
|
|
711
|
+
__m256i q7 = quants_interleaved[k * 8 + 7];
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
// The below code block has the first half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
|
|
715
|
+
__m256i sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
|
|
716
|
+
__m256i sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
|
|
717
|
+
__m256i sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
|
|
718
|
+
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
|
|
719
|
+
__m256i sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
|
|
720
|
+
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
|
|
721
|
+
|
|
722
|
+
__m256i one = _mm256_set1_epi8(1);
|
|
723
|
+
__m256i bsums_r1 = _mm256_maddubs_epi16(one, sb_h1_interleaved);
|
|
724
|
+
|
|
725
|
+
for (int l = 0; l < 3; l++) {
|
|
726
|
+
// Quants value shifted to process next two values from each sub block
|
|
727
|
+
q0 = _mm256_srli_epi64(q0, 16);
|
|
728
|
+
q2 = _mm256_srli_epi64(q2, 16);
|
|
729
|
+
q4 = _mm256_srli_epi64(q4, 16);
|
|
730
|
+
q6 = _mm256_srli_epi64(q6, 16);
|
|
731
|
+
|
|
732
|
+
sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
|
|
733
|
+
sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
|
|
734
|
+
sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
|
|
735
|
+
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
|
|
736
|
+
sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
|
|
737
|
+
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
|
|
738
|
+
|
|
739
|
+
bsums_r1 = _mm256_add_epi16(bsums_r1, _mm256_maddubs_epi16(one, sb_h1_interleaved));
|
|
740
|
+
}
|
|
741
|
+
|
|
742
|
+
// The below code block has the second half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
|
|
743
|
+
__m256i sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
|
|
744
|
+
__m256i sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
|
|
745
|
+
__m256i sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
|
|
746
|
+
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
|
|
747
|
+
__m256i sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
|
|
748
|
+
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
|
|
749
|
+
|
|
750
|
+
__m256i bsums_r2 = _mm256_maddubs_epi16(one, sb_h2_interleaved);
|
|
751
|
+
|
|
752
|
+
for (int l = 0; l < 3; l++) {
|
|
753
|
+
// Quants value shifted to process next two values from each sub block
|
|
754
|
+
q1 = _mm256_srli_epi64(q1, 16);
|
|
755
|
+
q3 = _mm256_srli_epi64(q3, 16);
|
|
756
|
+
q5 = _mm256_srli_epi64(q5, 16);
|
|
757
|
+
q7 = _mm256_srli_epi64(q7, 16);
|
|
758
|
+
|
|
759
|
+
sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
|
|
760
|
+
sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
|
|
761
|
+
sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
|
|
762
|
+
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
|
|
763
|
+
sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
|
|
764
|
+
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
|
|
765
|
+
|
|
766
|
+
bsums_r2 = _mm256_add_epi16(bsums_r2, _mm256_maddubs_epi16(one, sb_h2_interleaved));
|
|
767
|
+
}
|
|
768
|
+
|
|
769
|
+
// Overall bsums in interleaved fashion computed by adding results of both halves
|
|
770
|
+
__m256i bsums_r = _mm256_add_epi16(bsums_r1, bsums_r2);
|
|
771
|
+
_mm256_storeu_si256((__m256i *)(y[i].bsums + 16 * k), bsums_r);
|
|
772
|
+
}
|
|
773
|
+
}
|
|
774
|
+
|
|
775
|
+
#else
|
|
776
|
+
|
|
777
|
+
// scalar
|
|
778
|
+
const int blck_size_interleave = 8;
|
|
779
|
+
float srcv[4][QK_K];
|
|
780
|
+
float iscale[4];
|
|
781
|
+
|
|
782
|
+
for (int i = 0; i < nb; i++) {
|
|
783
|
+
for (int row_iter = 0; row_iter < 4; row_iter++) {
|
|
784
|
+
float amax = 0.0f; // absolute max
|
|
785
|
+
float max = 0;
|
|
786
|
+
|
|
787
|
+
for (int j = 0; j < QK_K; j++) {
|
|
788
|
+
srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
|
|
789
|
+
// Update the maximum value of the corresponding super block
|
|
790
|
+
if(amax < fabsf(srcv[row_iter][j])) {
|
|
791
|
+
amax = fabsf(srcv[row_iter][j]);
|
|
792
|
+
max = srcv[row_iter][j];
|
|
793
|
+
}
|
|
794
|
+
}
|
|
795
|
+
|
|
796
|
+
iscale[row_iter] = amax ? -127.f/max : 0;
|
|
797
|
+
|
|
798
|
+
y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
|
|
799
|
+
}
|
|
800
|
+
|
|
801
|
+
for (int j = 0; j < QK_K / 4; j++) {
|
|
802
|
+
y[i].bsums[j] = 0;
|
|
803
|
+
}
|
|
804
|
+
|
|
805
|
+
// Quants values are interleaved in sequence of eight bytes from corresponding super blocks
|
|
806
|
+
// Bsums values are interleaved in sequence of four bsums from each super block taken for interleaving
|
|
807
|
+
// i.e first four bsums from the first super block, followed by first four bsums from second super block and so on
|
|
808
|
+
for (int j = 0; j < QK_K * 4; j++) {
|
|
809
|
+
int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
|
|
810
|
+
int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
|
|
811
|
+
src_offset += (j % blck_size_interleave);
|
|
812
|
+
int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
|
|
813
|
+
|
|
814
|
+
float x0 = srcv[src_id][src_offset] * iscale[src_id];
|
|
815
|
+
y[i].qs[j] = nearest_int(x0);
|
|
816
|
+
y[i].bsums[index] += y[i].qs[j];
|
|
817
|
+
}
|
|
818
|
+
}
|
|
819
|
+
#endif
|
|
820
|
+
}
|
|
821
|
+
|
|
822
|
+
template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
|
|
823
|
+
void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);
|
|
824
|
+
|
|
825
|
+
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
|
538
826
|
assert(nrow == 4);
|
|
539
827
|
UNUSED(nrow);
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
828
|
+
ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
|
|
829
|
+
}
|
|
830
|
+
|
|
831
|
+
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
|
832
|
+
assert(nrow == 4);
|
|
833
|
+
UNUSED(nrow);
|
|
834
|
+
ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
|
|
835
|
+
}
|
|
836
|
+
|
|
837
|
+
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
|
838
|
+
assert(nrow == 4);
|
|
839
|
+
UNUSED(nrow);
|
|
840
|
+
ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
|
|
547
841
|
}
|
|
548
842
|
|
|
549
843
|
static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
@@ -877,17 +1171,17 @@ static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
|
|
|
877
1171
|
// ...........................................................................
|
|
878
1172
|
// B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
|
|
879
1173
|
|
|
880
|
-
iacc =
|
|
881
|
-
iacc =
|
|
1174
|
+
iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0));
|
|
1175
|
+
iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85));
|
|
882
1176
|
|
|
883
|
-
iacc =
|
|
884
|
-
iacc =
|
|
1177
|
+
iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170));
|
|
1178
|
+
iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255));
|
|
885
1179
|
|
|
886
|
-
iacc =
|
|
887
|
-
iacc =
|
|
1180
|
+
iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0));
|
|
1181
|
+
iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85));
|
|
888
1182
|
|
|
889
|
-
iacc =
|
|
890
|
-
iacc =
|
|
1183
|
+
iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170));
|
|
1184
|
+
iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255));
|
|
891
1185
|
|
|
892
1186
|
// Accumulated values multipled with appropriate scales
|
|
893
1187
|
acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
|
|
@@ -994,6 +1288,281 @@ static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
|
|
|
994
1288
|
}
|
|
995
1289
|
}
|
|
996
1290
|
|
|
1291
|
+
static void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
1292
|
+
const int qk = QK_K;
|
|
1293
|
+
const int nb = n / qk;
|
|
1294
|
+
const int ncols_interleaved = 8;
|
|
1295
|
+
const int blocklen = 8;
|
|
1296
|
+
static const uint32_t kmask1 = 0x3f3f3f3f;
|
|
1297
|
+
static const uint32_t kmask2 = 0x0f0f0f0f;
|
|
1298
|
+
static const uint32_t kmask3 = 0x03030303;
|
|
1299
|
+
|
|
1300
|
+
assert (n % qk == 0);
|
|
1301
|
+
assert (nc % ncols_interleaved == 0);
|
|
1302
|
+
|
|
1303
|
+
UNUSED(s);
|
|
1304
|
+
UNUSED(bs);
|
|
1305
|
+
UNUSED(vx);
|
|
1306
|
+
UNUSED(vy);
|
|
1307
|
+
UNUSED(nr);
|
|
1308
|
+
UNUSED(nc);
|
|
1309
|
+
UNUSED(nb);
|
|
1310
|
+
UNUSED(ncols_interleaved);
|
|
1311
|
+
UNUSED(blocklen);
|
|
1312
|
+
|
|
1313
|
+
#if defined(__AVX2__)
|
|
1314
|
+
// Lookup table to convert signed nibbles to signed bytes
|
|
1315
|
+
__m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
|
|
1316
|
+
signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
|
|
1317
|
+
// Shuffle masks to rearrange delta and scale values to multiply with appropriate scales
|
|
1318
|
+
__m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
|
|
1319
|
+
__m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);
|
|
1320
|
+
// Permute mask used for easier vector processing at later stages
|
|
1321
|
+
__m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
|
|
1322
|
+
|
|
1323
|
+
// Mask to extract nibbles from bytes
|
|
1324
|
+
const __m256i m4b = _mm256_set1_epi8(0x0F);
|
|
1325
|
+
|
|
1326
|
+
int64_t b_nb = n / QK_K;
|
|
1327
|
+
|
|
1328
|
+
const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 *)vx;
|
|
1329
|
+
const block_q8_K * a_ptr_start = (const block_q8_K *)vy;
|
|
1330
|
+
|
|
1331
|
+
// Process Q8_K blocks one by one
|
|
1332
|
+
for (int64_t y = 0; y < nr; y++) {
|
|
1333
|
+
|
|
1334
|
+
// Pointers to LHS blocks of block_q8_K format
|
|
1335
|
+
const block_q8_K * a_ptr = a_ptr_start + (y * nb);
|
|
1336
|
+
|
|
1337
|
+
// Take group of eight interleaved block_q4_K structures at each pass of the loop and perform dot product operation
|
|
1338
|
+
for (int64_t x = 0; x < nc / 8; x++) {
|
|
1339
|
+
|
|
1340
|
+
// Pointers to RHS blocks
|
|
1341
|
+
const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
|
|
1342
|
+
|
|
1343
|
+
// Master FP accumulators
|
|
1344
|
+
__m256 acc_row = _mm256_setzero_ps();
|
|
1345
|
+
__m256 acc_min_rows = _mm256_setzero_ps();
|
|
1346
|
+
|
|
1347
|
+
for (int64_t b = 0; b < nb; b++) {
|
|
1348
|
+
|
|
1349
|
+
// Load and convert to FP32 scale from block_q8_K
|
|
1350
|
+
const __m256 row_scale_f32 = _mm256_set1_ps((a_ptr[b].d));
|
|
1351
|
+
|
|
1352
|
+
// Load the scale values for the 8 blocks interleaved in block_q4_Kx8
|
|
1353
|
+
// col_scale_f32 rearranged so as to multiply with appropriate quants
|
|
1354
|
+
const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, deltamask);
|
|
1355
|
+
const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
|
|
1356
|
+
|
|
1357
|
+
__m256i iacc_b = _mm256_setzero_si256();
|
|
1358
|
+
__m256i iacc_min_b = _mm256_setzero_si256();
|
|
1359
|
+
|
|
1360
|
+
const __m256i q8sums = _mm256_loadu_si256((const __m256i * )(a_ptr[b].bsums));
|
|
1361
|
+
__m256i q8s = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(q8sums), _mm256_extracti128_si256(q8sums, 1)));
|
|
1362
|
+
q8s = _mm256_permute2f128_si256(q8s, q8s, 0);
|
|
1363
|
+
|
|
1364
|
+
// Processes two sub blocks from each Q4_K in each iteration
|
|
1365
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
|
1366
|
+
|
|
1367
|
+
// Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
|
|
1368
|
+
const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
|
|
1369
|
+
const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
|
|
1370
|
+
const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
|
|
1371
|
+
const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
|
|
1372
|
+
const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
|
|
1373
|
+
const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
|
|
1374
|
+
const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
|
|
1375
|
+
const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
|
|
1376
|
+
|
|
1377
|
+
// 4-bit -> 8-bit
|
|
1378
|
+
// Values of the first sub block of eight block_q4_K structures for the sb loop
|
|
1379
|
+
const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m4b);
|
|
1380
|
+
const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m4b);
|
|
1381
|
+
const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m4b);
|
|
1382
|
+
const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m4b);
|
|
1383
|
+
const __m256i rhs_vec_0123_02 = _mm256_and_si256(rhs_raw_vec_0123_2, m4b);
|
|
1384
|
+
const __m256i rhs_vec_4567_02 = _mm256_and_si256(rhs_raw_vec_4567_2, m4b);
|
|
1385
|
+
const __m256i rhs_vec_0123_03 = _mm256_and_si256(rhs_raw_vec_0123_3, m4b);
|
|
1386
|
+
const __m256i rhs_vec_4567_03 = _mm256_and_si256(rhs_raw_vec_4567_3, m4b);
|
|
1387
|
+
|
|
1388
|
+
// Values of the second sub block of eight block_q4_K structures when sb = 1
|
|
1389
|
+
const __m256i rhs_vec_0123_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b);
|
|
1390
|
+
const __m256i rhs_vec_4567_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b);
|
|
1391
|
+
const __m256i rhs_vec_0123_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b);
|
|
1392
|
+
const __m256i rhs_vec_4567_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b);
|
|
1393
|
+
const __m256i rhs_vec_0123_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m4b);
|
|
1394
|
+
const __m256i rhs_vec_4567_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m4b);
|
|
1395
|
+
const __m256i rhs_vec_0123_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m4b);
|
|
1396
|
+
const __m256i rhs_vec_4567_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m4b);
|
|
1397
|
+
|
|
1398
|
+
uint32_t utmp_0[4], utmp_1[4];
|
|
1399
|
+
|
|
1400
|
+
// Scales and Mins of corresponding sub blocks from different Q8_K structures are stored together
|
|
1401
|
+
// The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
|
|
1402
|
+
memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
|
|
1403
|
+
utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
|
|
1404
|
+
const uint32_t uaux_0 = utmp_0[1] & kmask1;
|
|
1405
|
+
utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
|
|
1406
|
+
utmp_0[2] = uaux_0;
|
|
1407
|
+
utmp_0[0] &= kmask1;
|
|
1408
|
+
|
|
1409
|
+
// The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
|
|
1410
|
+
memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
|
|
1411
|
+
utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
|
|
1412
|
+
const uint32_t uaux_1 = utmp_1[1] & kmask1;
|
|
1413
|
+
utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
|
|
1414
|
+
utmp_1[2] = uaux_1;
|
|
1415
|
+
utmp_1[0] &= kmask1;
|
|
1416
|
+
|
|
1417
|
+
// Scales of first sub block in the sb loop
|
|
1418
|
+
const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
|
|
1419
|
+
__m128i scales_rearrange_0 = _mm_shuffle_epi8(mins_and_scales_0, scalemask);
|
|
1420
|
+
__m256i scales_0 = _mm256_cvtepu8_epi16(scales_rearrange_0);
|
|
1421
|
+
|
|
1422
|
+
// Scales of second sub block in the sb loop
|
|
1423
|
+
__m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
|
|
1424
|
+
__m128i scales_rearrange_1 = _mm_shuffle_epi8(mins_and_scales_1, scalemask);
|
|
1425
|
+
__m256i scales_1 = _mm256_cvtepu8_epi16(scales_rearrange_1);
|
|
1426
|
+
|
|
1427
|
+
// Mins of first and second sub block of Q4_K block are arranged side by side
|
|
1428
|
+
__m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
|
|
1429
|
+
|
|
1430
|
+
// Load the two sub block values corresponding to sb in block_q8_K in batches of 16 bytes and replicate the same across 256 bit vector
|
|
1431
|
+
__m256i lhs_vec_00 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + sb * 64)));
|
|
1432
|
+
__m256i lhs_vec_01 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16 + sb * 64)));
|
|
1433
|
+
__m256i lhs_vec_10 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32 + sb * 64)));
|
|
1434
|
+
__m256i lhs_vec_11 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48 + sb * 64)));
|
|
1435
|
+
|
|
1436
|
+
lhs_vec_00 = _mm256_permute2f128_si256(lhs_vec_00, lhs_vec_00, 0);
|
|
1437
|
+
lhs_vec_01 = _mm256_permute2f128_si256(lhs_vec_01, lhs_vec_01, 0);
|
|
1438
|
+
lhs_vec_10 = _mm256_permute2f128_si256(lhs_vec_10, lhs_vec_10, 0);
|
|
1439
|
+
lhs_vec_11 = _mm256_permute2f128_si256(lhs_vec_11, lhs_vec_11, 0);
|
|
1440
|
+
|
|
1441
|
+
// Dot product done within 32 bit lanes and accumulated in the same vector
|
|
1442
|
+
// First done for first sub block and thenn for second sub block in each sb
|
|
1443
|
+
// B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
|
|
1444
|
+
// B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
|
|
1445
|
+
// ...........................................................................
|
|
1446
|
+
// B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
|
|
1447
|
+
|
|
1448
|
+
|
|
1449
|
+
__m256i iacc_0 = _mm256_setzero_si256();
|
|
1450
|
+
__m256i iacc_1 = _mm256_setzero_si256();
|
|
1451
|
+
|
|
1452
|
+
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 0)));
|
|
1453
|
+
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_shuffle_epi32(lhs_vec_00, 85)));
|
|
1454
|
+
|
|
1455
|
+
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 170)));
|
|
1456
|
+
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_shuffle_epi32(lhs_vec_00, 255)));
|
|
1457
|
+
|
|
1458
|
+
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_02 ,_mm256_shuffle_epi32(rhs_vec_4567_02, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 0)));
|
|
1459
|
+
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_02, 177) ,rhs_vec_4567_02, 170), _mm256_shuffle_epi32(lhs_vec_01, 85)));
|
|
1460
|
+
|
|
1461
|
+
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_03 ,_mm256_shuffle_epi32(rhs_vec_4567_03, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 170)));
|
|
1462
|
+
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_03, 177) ,rhs_vec_4567_03, 170), _mm256_shuffle_epi32(lhs_vec_01, 255)));
|
|
1463
|
+
|
|
1464
|
+
iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);
|
|
1465
|
+
|
|
1466
|
+
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 0)));
|
|
1467
|
+
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_shuffle_epi32(lhs_vec_10, 85)));
|
|
1468
|
+
|
|
1469
|
+
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 170)));
|
|
1470
|
+
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_shuffle_epi32(lhs_vec_10, 255)));
|
|
1471
|
+
|
|
1472
|
+
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_12 ,_mm256_shuffle_epi32(rhs_vec_4567_12, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 0)));
|
|
1473
|
+
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_12, 177) ,rhs_vec_4567_12, 170), _mm256_shuffle_epi32(lhs_vec_11, 85)));
|
|
1474
|
+
|
|
1475
|
+
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_13 ,_mm256_shuffle_epi32(rhs_vec_4567_13, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 170)));
|
|
1476
|
+
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_13, 177) ,rhs_vec_4567_13, 170), _mm256_shuffle_epi32(lhs_vec_11, 255)));
|
|
1477
|
+
|
|
1478
|
+
iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);
|
|
1479
|
+
|
|
1480
|
+
// Accumulate the iacc value for one sb
|
|
1481
|
+
__m256i iacc_sb = _mm256_add_epi32(iacc_0, iacc_1);
|
|
1482
|
+
|
|
1483
|
+
// Broadcast the bsums of the two sub blocks of the iteration of Q8_K across the vector
|
|
1484
|
+
// Multiply-Add with corresponding mins of Q4_Kx8 with bsums
|
|
1485
|
+
__m256i q8s_sb = _mm256_shuffle_epi32(q8s, 0);
|
|
1486
|
+
__m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);
|
|
1487
|
+
q8s = _mm256_bsrli_epi128(q8s, 4);
|
|
1488
|
+
|
|
1489
|
+
// Accumulate for the complete block
|
|
1490
|
+
iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);
|
|
1491
|
+
iacc_min_b = _mm256_add_epi32(iacc_min_b, iacc_min_sb);
|
|
1492
|
+
}
|
|
1493
|
+
|
|
1494
|
+
// Multiply-Add with scale values for the complete super block
|
|
1495
|
+
acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_b), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
|
|
1496
|
+
acc_min_rows = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_min_b), _mm256_mul_ps(col_dmin_f32, row_scale_f32), acc_min_rows);
|
|
1497
|
+
|
|
1498
|
+
}
|
|
1499
|
+
|
|
1500
|
+
// Accumulated output values permuted so as to be stored in appropriate order post accumulation
|
|
1501
|
+
acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);
|
|
1502
|
+
_mm256_storeu_ps(s + (y * nr + x * 8), _mm256_sub_ps(acc_row, acc_min_rows));
|
|
1503
|
+
}
|
|
1504
|
+
}
|
|
1505
|
+
|
|
1506
|
+
#else
|
|
1507
|
+
|
|
1508
|
+
float sumf[8];
|
|
1509
|
+
float sum_minf[8];
|
|
1510
|
+
uint32_t utmp[32];
|
|
1511
|
+
int sumi1;
|
|
1512
|
+
int sumi2;
|
|
1513
|
+
int sumi;
|
|
1514
|
+
|
|
1515
|
+
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
|
1516
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
1517
|
+
const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
|
1518
|
+
|
|
1519
|
+
for (int j = 0; j < ncols_interleaved; j++) {
|
|
1520
|
+
sumf[j] = 0.0;
|
|
1521
|
+
sum_minf[j] = 0.0;
|
|
1522
|
+
}
|
|
1523
|
+
for (int l = 0; l < nb; l++) {
|
|
1524
|
+
for (int sb = 0; sb < 8; sb++) {
|
|
1525
|
+
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
|
1526
|
+
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
|
1527
|
+
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
|
1528
|
+
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
|
1529
|
+
utmp[sb * 4 + 2] = uaux_0;
|
|
1530
|
+
utmp[sb * 4 + 0] &= kmask1;
|
|
1531
|
+
}
|
|
1532
|
+
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
|
1533
|
+
uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;
|
|
1534
|
+
uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16;
|
|
1535
|
+
for (int j = 0; j < ncols_interleaved; j++) {
|
|
1536
|
+
sumi1 = 0;
|
|
1537
|
+
sumi2 = 0;
|
|
1538
|
+
sumi = 0;
|
|
1539
|
+
for (int i = 0; i < blocklen; ++i) {
|
|
1540
|
+
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
|
|
1541
|
+
const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
|
|
1542
|
+
sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i]);
|
|
1543
|
+
sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i + 32]);
|
|
1544
|
+
sumi1 = sumi1 * scales_0[j];
|
|
1545
|
+
sumi2 = sumi2 * scales_1[j];
|
|
1546
|
+
sumi += sumi1 + sumi2;
|
|
1547
|
+
}
|
|
1548
|
+
sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
|
1549
|
+
}
|
|
1550
|
+
}
|
|
1551
|
+
for (int sb = 0; sb < 8; sb++) {
|
|
1552
|
+
uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;
|
|
1553
|
+
for (int j = 0; j < ncols_interleaved; j++) {
|
|
1554
|
+
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
|
|
1555
|
+
}
|
|
1556
|
+
}
|
|
1557
|
+
}
|
|
1558
|
+
for (int j = 0; j < ncols_interleaved; j++) {
|
|
1559
|
+
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
|
|
1560
|
+
}
|
|
1561
|
+
}
|
|
1562
|
+
#endif
|
|
1563
|
+
}
|
|
1564
|
+
|
|
1565
|
+
|
|
997
1566
|
static void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
998
1567
|
const int qk = QK8_0;
|
|
999
1568
|
const int nb = n / qk;
|
|
@@ -2666,22 +3235,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
|
|
|
2666
3235
|
|
|
2667
3236
|
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
|
2668
3237
|
// Resembles MMLAs into 2x2 matrices in ARM Version
|
|
2669
|
-
__m512i
|
|
2670
|
-
|
|
2671
|
-
__m512i iacc_mat_01_sp1 =
|
|
2672
|
-
|
|
2673
|
-
__m512i
|
|
2674
|
-
|
|
2675
|
-
__m512i
|
|
2676
|
-
|
|
2677
|
-
__m512i
|
|
2678
|
-
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
|
|
2679
|
-
__m512i iacc_mat_01_sp2 =
|
|
2680
|
-
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
|
|
2681
|
-
__m512i iacc_mat_10_sp2 =
|
|
2682
|
-
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
|
|
2683
|
-
__m512i iacc_mat_11_sp2 =
|
|
2684
|
-
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
|
|
3238
|
+
const __m512i zero = _mm512_setzero_epi32();
|
|
3239
|
+
__m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1);
|
|
3240
|
+
__m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1);
|
|
3241
|
+
__m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1);
|
|
3242
|
+
__m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1);
|
|
3243
|
+
__m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2);
|
|
3244
|
+
__m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2);
|
|
3245
|
+
__m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2);
|
|
3246
|
+
__m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2);
|
|
2685
3247
|
|
|
2686
3248
|
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
|
2687
3249
|
__m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
|
@@ -2857,22 +3419,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
|
|
|
2857
3419
|
|
|
2858
3420
|
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
|
2859
3421
|
// Resembles MMLAs into 2x2 matrices in ARM Version
|
|
2860
|
-
__m512i
|
|
2861
|
-
|
|
2862
|
-
__m512i iacc_mat_01_sp1 =
|
|
2863
|
-
|
|
2864
|
-
__m512i
|
|
2865
|
-
|
|
2866
|
-
__m512i
|
|
2867
|
-
|
|
2868
|
-
__m512i
|
|
2869
|
-
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
|
|
2870
|
-
__m512i iacc_mat_01_sp2 =
|
|
2871
|
-
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
|
|
2872
|
-
__m512i iacc_mat_10_sp2 =
|
|
2873
|
-
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
|
|
2874
|
-
__m512i iacc_mat_11_sp2 =
|
|
2875
|
-
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
|
|
3422
|
+
const __m512i zero = _mm512_setzero_epi32();
|
|
3423
|
+
__m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1);
|
|
3424
|
+
__m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1);
|
|
3425
|
+
__m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1);
|
|
3426
|
+
__m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1);
|
|
3427
|
+
__m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2);
|
|
3428
|
+
__m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2);
|
|
3429
|
+
__m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2);
|
|
3430
|
+
__m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2);
|
|
2876
3431
|
|
|
2877
3432
|
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
|
2878
3433
|
__m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
|
@@ -3032,22 +3587,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
|
|
|
3032
3587
|
|
|
3033
3588
|
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
|
3034
3589
|
// Resembles MMLAs into 2x2 matrices in ARM Version
|
|
3035
|
-
__m256i
|
|
3036
|
-
|
|
3037
|
-
__m256i iacc_mat_01_sp1 =
|
|
3038
|
-
|
|
3039
|
-
__m256i
|
|
3040
|
-
|
|
3041
|
-
__m256i
|
|
3042
|
-
|
|
3043
|
-
__m256i
|
|
3044
|
-
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
|
|
3045
|
-
__m256i iacc_mat_01_sp2 =
|
|
3046
|
-
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
|
|
3047
|
-
__m256i iacc_mat_10_sp2 =
|
|
3048
|
-
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
|
|
3049
|
-
__m256i iacc_mat_11_sp2 =
|
|
3050
|
-
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
|
|
3590
|
+
const __m256i zero = _mm256_setzero_si256();
|
|
3591
|
+
__m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1);
|
|
3592
|
+
__m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1);
|
|
3593
|
+
__m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1);
|
|
3594
|
+
__m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1);
|
|
3595
|
+
__m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2);
|
|
3596
|
+
__m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2);
|
|
3597
|
+
__m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2);
|
|
3598
|
+
__m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2);
|
|
3051
3599
|
|
|
3052
3600
|
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
|
3053
3601
|
__m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
|
@@ -3196,22 +3744,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
|
|
|
3196
3744
|
|
|
3197
3745
|
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
|
3198
3746
|
// Resembles MMLAs into 2x2 matrices in ARM Version
|
|
3199
|
-
__m256i
|
|
3200
|
-
|
|
3201
|
-
__m256i iacc_mat_01_sp1 =
|
|
3202
|
-
|
|
3203
|
-
__m256i
|
|
3204
|
-
|
|
3205
|
-
__m256i
|
|
3206
|
-
|
|
3207
|
-
__m256i
|
|
3208
|
-
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
|
|
3209
|
-
__m256i iacc_mat_01_sp2 =
|
|
3210
|
-
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
|
|
3211
|
-
__m256i iacc_mat_10_sp2 =
|
|
3212
|
-
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
|
|
3213
|
-
__m256i iacc_mat_11_sp2 =
|
|
3214
|
-
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
|
|
3747
|
+
const __m256i zero = _mm256_setzero_si256();
|
|
3748
|
+
__m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1);
|
|
3749
|
+
__m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1);
|
|
3750
|
+
__m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1);
|
|
3751
|
+
__m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1);
|
|
3752
|
+
__m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2);
|
|
3753
|
+
__m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2);
|
|
3754
|
+
__m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2);
|
|
3755
|
+
__m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2);
|
|
3215
3756
|
|
|
3216
3757
|
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
|
3217
3758
|
__m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
|
@@ -3480,11 +4021,14 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
|
|
|
3480
4021
|
}
|
|
3481
4022
|
}
|
|
3482
4023
|
|
|
3483
|
-
static void
|
|
3484
|
-
const int qk =
|
|
4024
|
+
static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
4025
|
+
const int qk = QK_K;
|
|
3485
4026
|
const int nb = n / qk;
|
|
3486
|
-
const int ncols_interleaved =
|
|
3487
|
-
const int blocklen =
|
|
4027
|
+
const int ncols_interleaved = 8;
|
|
4028
|
+
const int blocklen = 8;
|
|
4029
|
+
static const uint32_t kmask1 = 0x3f3f3f3f;
|
|
4030
|
+
static const uint32_t kmask2 = 0x0f0f0f0f;
|
|
4031
|
+
static const uint32_t kmask3 = 0x03030303;
|
|
3488
4032
|
|
|
3489
4033
|
assert (n % qk == 0);
|
|
3490
4034
|
assert (nr % 4 == 0);
|
|
@@ -3500,62 +4044,1574 @@ static void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs,
|
|
|
3500
4044
|
UNUSED(ncols_interleaved);
|
|
3501
4045
|
UNUSED(blocklen);
|
|
3502
4046
|
|
|
3503
|
-
#if
|
|
3504
|
-
|
|
3505
|
-
|
|
4047
|
+
#if defined(__AVX2__) || defined(__AVX512F__)
|
|
4048
|
+
const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 * ) vx;
|
|
4049
|
+
const block_q8_Kx4 * a_ptr_start = (const block_q8_Kx4 * ) vy;
|
|
4050
|
+
int64_t b_nb = n / QK_K;
|
|
4051
|
+
int64_t y = 0;
|
|
3506
4052
|
|
|
3507
|
-
|
|
3508
|
-
|
|
3509
|
-
|
|
3510
|
-
|
|
4053
|
+
// Mask to mask out nibbles from packed bytes
|
|
4054
|
+
const __m256i m4b = _mm256_set1_epi8(0x0F);
|
|
4055
|
+
// Permute mask used for easier vector processing at later stages
|
|
4056
|
+
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
|
|
4057
|
+
int64_t xstart = 0;
|
|
4058
|
+
int anr = nr - nr % 16;; // Used to align nr with boundary of 16
|
|
4059
|
+
#ifdef __AVX512F__
|
|
4060
|
+
int anc = nc - nc % 16; // Used to align nc with boundary of 16
|
|
4061
|
+
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
|
|
4062
|
+
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
|
|
4063
|
+
//Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
|
|
4064
|
+
for (; y < anr / 4; y += 4) {
|
|
4065
|
+
|
|
4066
|
+
const block_q8_Kx4 * a_ptrs[4];
|
|
4067
|
+
|
|
4068
|
+
a_ptrs[0] = a_ptr_start + (y * nb);
|
|
4069
|
+
for (int i = 0; i < 3; ++i) {
|
|
4070
|
+
a_ptrs[i + 1] = a_ptrs[i] + nb;
|
|
4071
|
+
}
|
|
3511
4072
|
|
|
3512
|
-
|
|
3513
|
-
|
|
3514
|
-
sumf[m] = vdupq_n_f32(0);
|
|
3515
|
-
}
|
|
4073
|
+
// Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation
|
|
4074
|
+
for (int64_t x = 0; x < anc / 8; x += 2) {
|
|
3516
4075
|
|
|
3517
|
-
|
|
3518
|
-
|
|
3519
|
-
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
|
|
4076
|
+
const block_q4_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
|
|
4077
|
+
const block_q4_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
|
|
3520
4078
|
|
|
3521
|
-
|
|
3522
|
-
|
|
3523
|
-
|
|
3524
|
-
|
|
4079
|
+
// Master FP accumulators
|
|
4080
|
+
__m512 acc_rows[16];
|
|
4081
|
+
for (int i = 0; i < 16; i++) {
|
|
4082
|
+
acc_rows[i] = _mm512_setzero_ps();
|
|
4083
|
+
}
|
|
3525
4084
|
|
|
3526
|
-
|
|
3527
|
-
|
|
3528
|
-
|
|
4085
|
+
__m512 acc_min_rows[16];
|
|
4086
|
+
for (int i = 0; i < 16; i++) {
|
|
4087
|
+
acc_min_rows[i] = _mm512_setzero_ps();
|
|
4088
|
+
}
|
|
3529
4089
|
|
|
3530
|
-
|
|
3531
|
-
|
|
3532
|
-
|
|
4090
|
+
// For super block
|
|
4091
|
+
for (int64_t b = 0; b < nb; b++) {
|
|
4092
|
+
// Scale values - Load the sixteen scale values from two block_q4_kx8 structures
|
|
4093
|
+
const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
|
|
4094
|
+
|
|
4095
|
+
// dmin values - Load the sixteen dmin values from two block_q4_kx8 structures
|
|
4096
|
+
const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin);
|
|
4097
|
+
|
|
4098
|
+
// Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
|
|
4099
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
|
4100
|
+
|
|
4101
|
+
const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256));
|
|
4102
|
+
const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256));
|
|
4103
|
+
const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256));
|
|
4104
|
+
const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 96 + sb * 256));
|
|
4105
|
+
const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 128 + sb * 256));
|
|
4106
|
+
const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 160 + sb * 256));
|
|
4107
|
+
const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 192 + sb * 256));
|
|
4108
|
+
const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 224 + sb * 256));
|
|
4109
|
+
|
|
4110
|
+
const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + sb * 256));
|
|
4111
|
+
const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 32 + sb * 256));
|
|
4112
|
+
const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 64 + sb * 256));
|
|
4113
|
+
const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 96 + sb * 256));
|
|
4114
|
+
const __m256i rhs_raw_mat_89AB_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 128 + sb * 256));
|
|
4115
|
+
const __m256i rhs_raw_mat_CDEF_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 160 + sb * 256));
|
|
4116
|
+
const __m256i rhs_raw_mat_89AB_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 192 + sb * 256));
|
|
4117
|
+
const __m256i rhs_raw_mat_CDEF_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 224 + sb * 256));
|
|
3533
4118
|
|
|
3534
|
-
|
|
3535
|
-
|
|
3536
|
-
|
|
3537
|
-
|
|
3538
|
-
|
|
3539
|
-
|
|
3540
|
-
|
|
3541
|
-
|
|
3542
|
-
}
|
|
4119
|
+
const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
|
|
4120
|
+
const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
|
|
4121
|
+
const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
|
|
4122
|
+
const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
|
|
4123
|
+
const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
|
|
4124
|
+
const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
|
|
4125
|
+
const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
|
|
4126
|
+
const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
|
|
3543
4127
|
|
|
3544
|
-
|
|
3545
|
-
|
|
3546
|
-
|
|
3547
|
-
|
|
3548
|
-
|
|
4128
|
+
const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
|
|
4129
|
+
const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
|
|
4130
|
+
const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
|
|
4131
|
+
const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
|
|
4132
|
+
const __m256i rhs_raw_mat_89CD_2 = _mm256_blend_epi32(rhs_raw_mat_89AB_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_2, requiredOrder), 240);
|
|
4133
|
+
const __m256i rhs_raw_mat_ABEF_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_2, requiredOrder), rhs_raw_mat_CDEF_2, 240);
|
|
4134
|
+
const __m256i rhs_raw_mat_89CD_3 = _mm256_blend_epi32(rhs_raw_mat_89AB_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_3, requiredOrder), 240);
|
|
4135
|
+
const __m256i rhs_raw_mat_ABEF_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_3, requiredOrder), rhs_raw_mat_CDEF_3, 240);
|
|
3549
4136
|
|
|
3550
|
-
|
|
3551
|
-
|
|
3552
|
-
|
|
3553
|
-
|
|
3554
|
-
|
|
3555
|
-
|
|
3556
|
-
|
|
3557
|
-
|
|
3558
|
-
|
|
4137
|
+
const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
|
|
4138
|
+
const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
|
|
4139
|
+
const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
|
|
4140
|
+
const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
|
|
4141
|
+
|
|
4142
|
+
const __m512i rhs_raw_mat_014589CD_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_2), rhs_raw_mat_89CD_2, 1);
|
|
4143
|
+
const __m512i rhs_raw_mat_2367ABEF_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_2), rhs_raw_mat_ABEF_2, 1);
|
|
4144
|
+
const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1);
|
|
4145
|
+
const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1);
|
|
4146
|
+
|
|
4147
|
+
//4-bit -> 8-bit
|
|
4148
|
+
const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7)
|
|
4149
|
+
const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7)
|
|
4150
|
+
const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15)
|
|
4151
|
+
const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15)
|
|
4152
|
+
|
|
4153
|
+
const __m512i rhs_mat_014589CD_02 = _mm512_and_si512(rhs_raw_mat_014589CD_2, m4bexpanded); //B00(16-23) B01(16-23) B04(16-23) B05(16-23) B08(16-23) B09(16-23) B0C(16-23) B0D(16-23)
|
|
4154
|
+
const __m512i rhs_mat_2367ABEF_02 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2, m4bexpanded); //B02(16-23) B03(16-23) B06(16-23) B07(16-23) B0A(16-23) B0B(16-23) B0E(16-23) B0F(16-23)
|
|
4155
|
+
const __m512i rhs_mat_014589CD_03 = _mm512_and_si512(rhs_raw_mat_014589CD_3, m4bexpanded); //B00(24-31) B01(24-31) B04(24-31) B05(24-31) B08(24-31) B09(24-31) B0C(24-31) B0D(24-31)
|
|
4156
|
+
const __m512i rhs_mat_2367ABEF_03 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3, m4bexpanded); //B02(24-31) B03(24-31) B06(24-31) B07(24-31) B0A(24-31) B0B(24-31) B0E(24-31) B0F(24-31)
|
|
4157
|
+
|
|
4158
|
+
const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7)
|
|
4159
|
+
const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7)
|
|
4160
|
+
const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15)
|
|
4161
|
+
const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15)
|
|
4162
|
+
|
|
4163
|
+
const __m512i rhs_mat_014589CD_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m4bexpanded); //B10(16-23) B11(16-23) B14(16-23) B15(16-23) B18(16-23) B19(16-23) B1C(16-23) B1D(16-23)
|
|
4164
|
+
const __m512i rhs_mat_2367ABEF_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m4bexpanded); //B12(16-23) B13(16-23) B16(16-23) B17(16-23) B1A(16-23) B1B(16-23) B1E(16-23) B1F(16-23)
|
|
4165
|
+
const __m512i rhs_mat_014589CD_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m4bexpanded); //B10(24-31) B11(24-31) B14(24-31) B15(24-31) B18(24-31) B19(24-31) B1C(24-31) B1D(24-31)
|
|
4166
|
+
const __m512i rhs_mat_2367ABEF_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m4bexpanded); //B12(24-31) B13(24-31) B16(24-31) B17(24-31) B1A(24-31) B1B(24-31) B1E(24-31) B1F(24-31)
|
|
4167
|
+
|
|
4168
|
+
// Shuffle pattern one - right side input
|
|
4169
|
+
const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3)
|
|
4170
|
+
const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3)
|
|
4171
|
+
const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11)
|
|
4172
|
+
const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11)
|
|
4173
|
+
const __m512i rhs_mat_014589CD_02_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19) B08(16-19) B09(16-19) B08(16-19) B09(16-19) B0C(16-19) B0D(16-19) B0C(16-19) B0D(16-19)
|
|
4174
|
+
const __m512i rhs_mat_2367ABEF_02_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19) B0A(16-19) B0B(16-19) B0A(16-19) B0B(16-19) B0E(16-19) B0F(16-19) B0E(16-19) B0F(16-19)
|
|
4175
|
+
const __m512i rhs_mat_014589CD_03_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27) B08(24-27) B09(24-27) B08(24-27) B09(24-27) B0C(24-27) B0D(24-27) B0C(24-27) B0D(24-27)
|
|
4176
|
+
const __m512i rhs_mat_2367ABEF_03_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27) B0A(24-27) B0B(24-27) B0A(24-27) B0B(24-27) B0E(24-27) B0F(24-27) B0E(24-27) B0F(24-27)
|
|
4177
|
+
|
|
4178
|
+
const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3)
|
|
4179
|
+
const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3)
|
|
4180
|
+
const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11)
|
|
4181
|
+
const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11)
|
|
4182
|
+
const __m512i rhs_mat_014589CD_12_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19) B18(16-19) B19(16-19) B18(16-19) B19(16-19) B1C(16-19) B1D(16-19) B1C(16-19) B1D(16-19)
|
|
4183
|
+
const __m512i rhs_mat_2367ABEF_12_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19) B1A(16-19) B1B(16-19) B1A(16-19) B1B(16-19) B1E(16-19) B1F(16-19) B1E(16-19) B1F(16-19)
|
|
4184
|
+
const __m512i rhs_mat_014589CD_13_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27) B18(24-27) B19(24-27) B18(24-27) B19(24-27) B1C(24-27) B1D(24-27) B1C(24-27) B1D(24-27)
|
|
4185
|
+
const __m512i rhs_mat_2367ABEF_13_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27) B1A(24-27) B1B(24-27) B1A(24-27) B1B(24-27) B1E(24-27) B1F(24-27) B1E(24-27) B1F(24-27)
|
|
4186
|
+
|
|
4187
|
+
// Shuffle pattern two - right side input
|
|
4188
|
+
const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7)
|
|
4189
|
+
const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7)
|
|
4190
|
+
const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15)
|
|
4191
|
+
const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15)
|
|
4192
|
+
const __m512i rhs_mat_014589CD_02_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23) B08(20-23) B09(20-23) B08(20-23) B09(20-23) B0C(20-23) B0D(20-23) B0C(20-23) B0D(20-23)
|
|
4193
|
+
const __m512i rhs_mat_2367ABEF_02_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23) B0A(20-23) B0B(20-23) B0A(20-23) B0B(20-23) B0E(20-23) B0F(20-23) B0E(20-23) B0F(20-23)
|
|
4194
|
+
const __m512i rhs_mat_014589CD_03_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31) B08(28-31) B09(28-31) B08(28-31) B09(28-31) B0C(28-31) B0D(28-31) B0C(28-31) 0BD(28-31)
|
|
4195
|
+
const __m512i rhs_mat_2367ABEF_03_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31) B0A(28-31) B0B(28-31) B0A(28-31) B0B(28-31) B0E(28-31) B0F(28-31) B0E(28-31) B0F(28-31)
|
|
4196
|
+
|
|
4197
|
+
const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7)
|
|
4198
|
+
const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7)
|
|
4199
|
+
const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15)
|
|
4200
|
+
const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15)
|
|
4201
|
+
const __m512i rhs_mat_014589CD_12_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23) B18(20-23) B19(20-23) B18(20-23) B19(20-23) B1C(20-23) B1D(20-23) B1C(20-23) B1D(20-23)
|
|
4202
|
+
const __m512i rhs_mat_2367ABEF_12_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23) B1A(20-23) B1B(20-23) B1A(20-23) B1B(20-23) B1E(20-23) B1F(20-23) B1E(20-23) B1F(20-23)
|
|
4203
|
+
const __m512i rhs_mat_014589CD_13_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31) B18(28-31) B19(28-31) B18(28-31) B19(28-31) B1C(28-31) B1D(28-31) B1C(28-31) B1D(28-31)
|
|
4204
|
+
const __m512i rhs_mat_2367ABEF_13_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31) B1A(28-31) B1B(28-31) B1A(28-31) B1B(28-31) B1E(28-31) B1F(28-31) B1E(28-31) B1F(28-31)
|
|
4205
|
+
|
|
4206
|
+
uint32_t utmp_00[4], utmp_01[4], utmp_10[4], utmp_11[4];
|
|
4207
|
+
|
|
4208
|
+
// Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
|
|
4209
|
+
// The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
|
|
4210
|
+
memcpy(utmp_00, b_ptr_0[b].scales + 24 * sb, 12);
|
|
4211
|
+
utmp_00[3] = ((utmp_00[2] >> 4) & kmask2) | (((utmp_00[1] >> 6) & kmask3) << 4);
|
|
4212
|
+
const uint32_t uaux_00 = utmp_00[1] & kmask1;
|
|
4213
|
+
utmp_00[1] = (utmp_00[2] & kmask2) | (((utmp_00[0] >> 6) & kmask3) << 4);
|
|
4214
|
+
utmp_00[2] = uaux_00;
|
|
4215
|
+
utmp_00[0] &= kmask1;
|
|
4216
|
+
|
|
4217
|
+
// The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
|
|
4218
|
+
memcpy(utmp_01, b_ptr_0[b].scales + 12 + sb * 24, 12);
|
|
4219
|
+
utmp_01[3] = ((utmp_01[2] >> 4) & kmask2) | (((utmp_01[1] >> 6) & kmask3) << 4);
|
|
4220
|
+
const uint32_t uaux_01 = utmp_01[1] & kmask1;
|
|
4221
|
+
utmp_01[1] = (utmp_01[2] & kmask2) | (((utmp_01[0] >> 6) & kmask3) << 4);
|
|
4222
|
+
utmp_01[2] = uaux_01;
|
|
4223
|
+
utmp_01[0] &= kmask1;
|
|
4224
|
+
|
|
4225
|
+
memcpy(utmp_10, b_ptr_1[b].scales + sb * 24, 12);
|
|
4226
|
+
utmp_10[3] = ((utmp_10[2] >> 4) & kmask2) | (((utmp_10[1] >> 6) & kmask3) << 4);
|
|
4227
|
+
const uint32_t uaux_10 = utmp_10[1] & kmask1;
|
|
4228
|
+
utmp_10[1] = (utmp_10[2] & kmask2) | (((utmp_10[0] >> 6) & kmask3) << 4);
|
|
4229
|
+
utmp_10[2] = uaux_10;
|
|
4230
|
+
utmp_10[0] &= kmask1;
|
|
4231
|
+
|
|
4232
|
+
// The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
|
|
4233
|
+
memcpy(utmp_11, b_ptr_1[b].scales + 12 + sb * 24, 12);
|
|
4234
|
+
utmp_11[3] = ((utmp_11[2] >> 4) & kmask2) | (((utmp_11[1] >> 6) & kmask3) << 4);
|
|
4235
|
+
const uint32_t uaux_11 = utmp_11[1] & kmask1;
|
|
4236
|
+
utmp_11[1] = (utmp_11[2] & kmask2) | (((utmp_11[0] >> 6) & kmask3) << 4);
|
|
4237
|
+
utmp_11[2] = uaux_11;
|
|
4238
|
+
utmp_11[0] &= kmask1;
|
|
4239
|
+
|
|
4240
|
+
// Scales of first sub block in the sb loop
|
|
4241
|
+
const __m256i mins_and_scales_0 = _mm256_set_epi32(utmp_10[3], utmp_10[2], utmp_10[1], utmp_10[0], utmp_00[3], utmp_00[2], utmp_00[1], utmp_00[0]);
|
|
4242
|
+
const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
|
|
4243
|
+
|
|
4244
|
+
// Scales of second sub block in the sb loop
|
|
4245
|
+
const __m256i mins_and_scales_1 = _mm256_set_epi32(utmp_11[3], utmp_11[2], utmp_11[1], utmp_11[0], utmp_01[3], utmp_01[2], utmp_01[1], utmp_01[0]);
|
|
4246
|
+
const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
|
|
4247
|
+
|
|
4248
|
+
// Mins of first and second sub block of Q4_K block are arranged side by side
|
|
4249
|
+
const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(_mm256_shuffle_epi32(mins_and_scales_0, 78), _mm256_shuffle_epi32(mins_and_scales_1, 78)));
|
|
4250
|
+
|
|
4251
|
+
const __m512i scale_014589CD_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68);
|
|
4252
|
+
const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238);
|
|
4253
|
+
|
|
4254
|
+
const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68);
|
|
4255
|
+
const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238);
|
|
4256
|
+
|
|
4257
|
+
for (int rp = 0; rp < 4; rp++) {
|
|
4258
|
+
|
|
4259
|
+
// Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
|
|
4260
|
+
// Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector
|
|
4261
|
+
__m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb)));
|
|
4262
|
+
__m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0);
|
|
4263
|
+
__m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17);
|
|
4264
|
+
__m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb)));
|
|
4265
|
+
__m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0);
|
|
4266
|
+
__m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17);
|
|
4267
|
+
__m256i lhs_mat_ymm_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb)));
|
|
4268
|
+
__m256i lhs_mat_ymm_01_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 0);
|
|
4269
|
+
__m256i lhs_mat_ymm_23_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 17);
|
|
4270
|
+
__m256i lhs_mat_ymm_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb)));
|
|
4271
|
+
__m256i lhs_mat_ymm_01_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 0);
|
|
4272
|
+
__m256i lhs_mat_ymm_23_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 17);
|
|
4273
|
+
__m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb)));
|
|
4274
|
+
__m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0);
|
|
4275
|
+
__m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17);
|
|
4276
|
+
__m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb)));
|
|
4277
|
+
__m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0);
|
|
4278
|
+
__m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17);
|
|
4279
|
+
__m256i lhs_mat_ymm_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb)));
|
|
4280
|
+
__m256i lhs_mat_ymm_01_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 0);
|
|
4281
|
+
__m256i lhs_mat_ymm_23_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 17);
|
|
4282
|
+
__m256i lhs_mat_ymm_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb)));
|
|
4283
|
+
__m256i lhs_mat_ymm_01_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 0);
|
|
4284
|
+
__m256i lhs_mat_ymm_23_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 17);
|
|
4285
|
+
|
|
4286
|
+
__m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1);
|
|
4287
|
+
__m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1);
|
|
4288
|
+
__m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1);
|
|
4289
|
+
__m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1);
|
|
4290
|
+
__m512i lhs_mat_01_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_02), lhs_mat_ymm_01_02, 1);
|
|
4291
|
+
__m512i lhs_mat_23_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_02), lhs_mat_ymm_23_02, 1);
|
|
4292
|
+
__m512i lhs_mat_01_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_03), lhs_mat_ymm_01_03, 1);
|
|
4293
|
+
__m512i lhs_mat_23_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_03), lhs_mat_ymm_23_03, 1);
|
|
4294
|
+
|
|
4295
|
+
__m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1);
|
|
4296
|
+
__m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1);
|
|
4297
|
+
__m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1);
|
|
4298
|
+
__m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1);
|
|
4299
|
+
__m512i lhs_mat_01_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_12), lhs_mat_ymm_01_12, 1);
|
|
4300
|
+
__m512i lhs_mat_23_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_12), lhs_mat_ymm_23_12, 1);
|
|
4301
|
+
__m512i lhs_mat_01_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_13), lhs_mat_ymm_01_13, 1);
|
|
4302
|
+
__m512i lhs_mat_23_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_13), lhs_mat_ymm_23_13, 1);
|
|
4303
|
+
|
|
4304
|
+
// Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
|
|
4305
|
+
__m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));
|
|
4306
|
+
__m256i lhs_bsums_hsum_ymm_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
|
|
4307
|
+
lhs_bsums_hsum_ymm_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_ymm_0123_01, lhs_bsums_hsum_ymm_0123_01, 0);
|
|
4308
|
+
__m512i lhs_bsums_hsum_0123_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_hsum_ymm_0123_01), lhs_bsums_hsum_ymm_0123_01, 1);
|
|
4309
|
+
|
|
4310
|
+
// Shuffle pattern one - left side input
|
|
4311
|
+
const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
|
|
4312
|
+
const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3)
|
|
4313
|
+
const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
|
|
4314
|
+
const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11)
|
|
4315
|
+
const __m512i lhs_mat_01_02_sp1 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
|
|
4316
|
+
const __m512i lhs_mat_23_02_sp1 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)160); //A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19)
|
|
4317
|
+
const __m512i lhs_mat_01_03_sp1 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
|
|
4318
|
+
const __m512i lhs_mat_23_03_sp1 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)160); //A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27)
|
|
4319
|
+
|
|
4320
|
+
const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
|
|
4321
|
+
const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3)
|
|
4322
|
+
const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
|
|
4323
|
+
const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11)
|
|
4324
|
+
const __m512i lhs_mat_01_12_sp1 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
|
|
4325
|
+
const __m512i lhs_mat_23_12_sp1 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)160); //A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19)
|
|
4326
|
+
const __m512i lhs_mat_01_13_sp1 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
|
|
4327
|
+
const __m512i lhs_mat_23_13_sp1 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)160); //A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27)
|
|
4328
|
+
|
|
4329
|
+
const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
|
|
4330
|
+
const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7)
|
|
4331
|
+
const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
|
|
4332
|
+
const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15)
|
|
4333
|
+
const __m512i lhs_mat_01_02_sp2 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
|
|
4334
|
+
const __m512i lhs_mat_23_02_sp2 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)245); //A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23)
|
|
4335
|
+
const __m512i lhs_mat_01_03_sp2 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
|
|
4336
|
+
const __m512i lhs_mat_23_03_sp2 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)245); //A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31)
|
|
4337
|
+
|
|
4338
|
+
const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
|
|
4339
|
+
const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7)
|
|
4340
|
+
const __m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
|
|
4341
|
+
const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15)
|
|
4342
|
+
const __m512i lhs_mat_01_12_sp2 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
|
|
4343
|
+
const __m512i lhs_mat_23_12_sp2 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)245); //A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23)
|
|
4344
|
+
const __m512i lhs_mat_01_13_sp2 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
|
|
4345
|
+
const __m512i lhs_mat_23_13_sp2 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)245); //A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31)
|
|
4346
|
+
|
|
4347
|
+
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
|
4348
|
+
__m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1));
|
|
4349
|
+
__m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1));
|
|
4350
|
+
__m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1));
|
|
4351
|
+
__m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1));
|
|
4352
|
+
__m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1));
|
|
4353
|
+
__m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1));
|
|
4354
|
+
__m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1));
|
|
4355
|
+
__m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1));
|
|
4356
|
+
|
|
4357
|
+
__m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2));
|
|
4358
|
+
__m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2));
|
|
4359
|
+
__m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2));
|
|
4360
|
+
__m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2));
|
|
4361
|
+
__m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2));
|
|
4362
|
+
__m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2));
|
|
4363
|
+
__m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2));
|
|
4364
|
+
__m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2));
|
|
4365
|
+
|
|
4366
|
+
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
|
4367
|
+
__m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
|
|
4368
|
+
__m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
|
|
4369
|
+
__m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
|
|
4370
|
+
__m512i iacc_mat_11_0 = _mm512_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
|
|
4371
|
+
|
|
4372
|
+
__m512i iacc_mat_00_1 = _mm512_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
|
|
4373
|
+
__m512i iacc_mat_01_1 = _mm512_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
|
|
4374
|
+
__m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
|
|
4375
|
+
__m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
|
|
4376
|
+
|
|
4377
|
+
iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0);
|
|
4378
|
+
iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0);
|
|
4379
|
+
iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0);
|
|
4380
|
+
iacc_mat_11_0 = _mm512_madd_epi16(iacc_mat_11_0, scale_2367ABEF_0);
|
|
4381
|
+
|
|
4382
|
+
iacc_mat_00_1 = _mm512_madd_epi16(iacc_mat_00_1, scale_014589CD_1);
|
|
4383
|
+
iacc_mat_01_1 = _mm512_madd_epi16(iacc_mat_01_1, scale_2367ABEF_1);
|
|
4384
|
+
iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1);
|
|
4385
|
+
iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1);
|
|
4386
|
+
|
|
4387
|
+
// Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)
|
|
4388
|
+
__m512i iacc_row_0_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_0, _mm512_shuffle_epi32(iacc_mat_01_0, (_MM_PERM_ENUM)78));
|
|
4389
|
+
__m512i iacc_row_1_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_0, (_MM_PERM_ENUM)78), iacc_mat_01_0);
|
|
4390
|
+
__m512i iacc_row_2_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_0, _mm512_shuffle_epi32(iacc_mat_11_0, (_MM_PERM_ENUM)78));
|
|
4391
|
+
__m512i iacc_row_3_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_0, (_MM_PERM_ENUM)78), iacc_mat_11_0);
|
|
4392
|
+
__m512i iacc_row_0_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_1, _mm512_shuffle_epi32(iacc_mat_01_1, (_MM_PERM_ENUM)78));
|
|
4393
|
+
__m512i iacc_row_1_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_1, (_MM_PERM_ENUM)78), iacc_mat_01_1);
|
|
4394
|
+
__m512i iacc_row_2_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_1, _mm512_shuffle_epi32(iacc_mat_11_1, (_MM_PERM_ENUM)78));
|
|
4395
|
+
__m512i iacc_row_3_1 = _mm512_mask_blend_epi32(0xCCCC,_mm512_shuffle_epi32(iacc_mat_10_1, (_MM_PERM_ENUM)78), iacc_mat_11_1);
|
|
4396
|
+
|
|
4397
|
+
__m512i iacc_row_0 = _mm512_add_epi32(iacc_row_0_0, iacc_row_0_1);
|
|
4398
|
+
__m512i iacc_row_1 = _mm512_add_epi32(iacc_row_1_0, iacc_row_1_1);
|
|
4399
|
+
__m512i iacc_row_2 = _mm512_add_epi32(iacc_row_2_0, iacc_row_2_1);
|
|
4400
|
+
__m512i iacc_row_3 = _mm512_add_epi32(iacc_row_3_0, iacc_row_3_1);
|
|
4401
|
+
|
|
4402
|
+
// Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes
|
|
4403
|
+
const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d);
|
|
4404
|
+
const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);
|
|
4405
|
+
const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1);
|
|
4406
|
+
|
|
4407
|
+
// Multiply with appropiate scales and accumulate (for both d and dmin) below
|
|
4408
|
+
acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
|
|
4409
|
+
acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
|
|
4410
|
+
acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
|
|
4411
|
+
acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
|
|
4412
|
+
|
|
4413
|
+
__m512i iacc_row_min_0 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)0), mins_01);
|
|
4414
|
+
__m512i iacc_row_min_1 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)85), mins_01);
|
|
4415
|
+
__m512i iacc_row_min_2 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)170), mins_01);
|
|
4416
|
+
__m512i iacc_row_min_3 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)255), mins_01);
|
|
4417
|
+
|
|
4418
|
+
acc_min_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]);
|
|
4419
|
+
acc_min_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);
|
|
4420
|
+
acc_min_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_2), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]);
|
|
4421
|
+
acc_min_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]);
|
|
4422
|
+
}
|
|
4423
|
+
}
|
|
4424
|
+
}
|
|
4425
|
+
// Store the accumulated values
|
|
4426
|
+
for (int i = 0; i < 16; i++) {
|
|
4427
|
+
_mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i]));
|
|
4428
|
+
}
|
|
4429
|
+
}
|
|
4430
|
+
}
|
|
4431
|
+
|
|
4432
|
+
for (; y < nr / 4; y++) {
|
|
4433
|
+
|
|
4434
|
+
const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);
|
|
4435
|
+
|
|
4436
|
+
// Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation
|
|
4437
|
+
for (int64_t x = 0; x < anc / 8; x += 2) {
|
|
4438
|
+
|
|
4439
|
+
const block_q4_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
|
|
4440
|
+
const block_q4_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
|
|
4441
|
+
|
|
4442
|
+
// Master FP accumulators
|
|
4443
|
+
__m512 acc_rows[4];
|
|
4444
|
+
for (int i = 0; i < 4; i++) {
|
|
4445
|
+
acc_rows[i] = _mm512_setzero_ps();
|
|
4446
|
+
}
|
|
4447
|
+
|
|
4448
|
+
__m512 acc_min_rows[4];
|
|
4449
|
+
for (int i = 0; i < 4; i++) {
|
|
4450
|
+
acc_min_rows[i] = _mm512_setzero_ps();
|
|
4451
|
+
}
|
|
4452
|
+
|
|
4453
|
+
// For super block
|
|
4454
|
+
for (int64_t b = 0; b < nb; b++) {
|
|
4455
|
+
// Scale values - Load the sixteen scale values from two block_q4_kx8 structures
|
|
4456
|
+
const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
|
|
4457
|
+
|
|
4458
|
+
// dmin values - Load the sixteen dmin values from two block_q4_kx8 structures
|
|
4459
|
+
const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin);
|
|
4460
|
+
|
|
4461
|
+
// Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
|
|
4462
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
|
4463
|
+
|
|
4464
|
+
const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256));
|
|
4465
|
+
const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256));
|
|
4466
|
+
const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256));
|
|
4467
|
+
const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 96 + sb * 256));
|
|
4468
|
+
const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 128 + sb * 256));
|
|
4469
|
+
const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 160 + sb * 256));
|
|
4470
|
+
const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 192 + sb * 256));
|
|
4471
|
+
const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 224 + sb * 256));
|
|
4472
|
+
|
|
4473
|
+
const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + sb * 256));
|
|
4474
|
+
const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 32 + sb * 256));
|
|
4475
|
+
const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 64 + sb * 256));
|
|
4476
|
+
const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 96 + sb * 256));
|
|
4477
|
+
const __m256i rhs_raw_mat_89AB_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 128 + sb * 256));
|
|
4478
|
+
const __m256i rhs_raw_mat_CDEF_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 160 + sb * 256));
|
|
4479
|
+
const __m256i rhs_raw_mat_89AB_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 192 + sb * 256));
|
|
4480
|
+
const __m256i rhs_raw_mat_CDEF_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 224 + sb * 256));
|
|
4481
|
+
|
|
4482
|
+
const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
|
|
4483
|
+
const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
|
|
4484
|
+
const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
|
|
4485
|
+
const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
|
|
4486
|
+
const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
|
|
4487
|
+
const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
|
|
4488
|
+
const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
|
|
4489
|
+
const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
|
|
4490
|
+
|
|
4491
|
+
const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
|
|
4492
|
+
const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
|
|
4493
|
+
const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
|
|
4494
|
+
const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
|
|
4495
|
+
const __m256i rhs_raw_mat_89CD_2 = _mm256_blend_epi32(rhs_raw_mat_89AB_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_2, requiredOrder), 240);
|
|
4496
|
+
const __m256i rhs_raw_mat_ABEF_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_2, requiredOrder), rhs_raw_mat_CDEF_2, 240);
|
|
4497
|
+
const __m256i rhs_raw_mat_89CD_3 = _mm256_blend_epi32(rhs_raw_mat_89AB_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_3, requiredOrder), 240);
|
|
4498
|
+
const __m256i rhs_raw_mat_ABEF_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_3, requiredOrder), rhs_raw_mat_CDEF_3, 240);
|
|
4499
|
+
|
|
4500
|
+
const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
|
|
4501
|
+
const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
|
|
4502
|
+
const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
|
|
4503
|
+
const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
|
|
4504
|
+
|
|
4505
|
+
const __m512i rhs_raw_mat_014589CD_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_2), rhs_raw_mat_89CD_2, 1);
|
|
4506
|
+
const __m512i rhs_raw_mat_2367ABEF_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_2), rhs_raw_mat_ABEF_2, 1);
|
|
4507
|
+
const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1);
|
|
4508
|
+
const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1);
|
|
4509
|
+
|
|
4510
|
+
//4-bit -> 8-bit
|
|
4511
|
+
const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7)
|
|
4512
|
+
const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7)
|
|
4513
|
+
const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15)
|
|
4514
|
+
const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15)
|
|
4515
|
+
|
|
4516
|
+
const __m512i rhs_mat_014589CD_02 = _mm512_and_si512(rhs_raw_mat_014589CD_2, m4bexpanded); //B00(16-23) B01(16-23) B04(16-23) B05(16-23) B08(16-23) B09(16-23) B0C(16-23) B0D(16-23)
|
|
4517
|
+
const __m512i rhs_mat_2367ABEF_02 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2, m4bexpanded); //B02(16-23) B03(16-23) B06(16-23) B07(16-23) B0A(16-23) B0B(16-23) B0E(16-23) B0F(16-23)
|
|
4518
|
+
const __m512i rhs_mat_014589CD_03 = _mm512_and_si512(rhs_raw_mat_014589CD_3, m4bexpanded); //B00(24-31) B01(24-31) B04(24-31) B05(24-31) B08(24-31) B09(24-31) B0C(24-31) B0D(24-31)
|
|
4519
|
+
const __m512i rhs_mat_2367ABEF_03 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3, m4bexpanded); //B02(24-31) B03(24-31) B06(24-31) B07(24-31) B0A(24-31) B0B(24-31) B0E(24-31) B0F(24-31)
|
|
4520
|
+
|
|
4521
|
+
const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7)
|
|
4522
|
+
const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7)
|
|
4523
|
+
const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15)
|
|
4524
|
+
const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15)
|
|
4525
|
+
|
|
4526
|
+
const __m512i rhs_mat_014589CD_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m4bexpanded); //B10(16-23) B11(16-23) B14(16-23) B15(16-23) B18(16-23) B19(16-23) B1C(16-23) B1D(16-23)
|
|
4527
|
+
const __m512i rhs_mat_2367ABEF_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m4bexpanded); //B12(16-23) B13(16-23) B16(16-23) B17(16-23) B1A(16-23) B1B(16-23) B1E(16-23) B1F(16-23)
|
|
4528
|
+
const __m512i rhs_mat_014589CD_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m4bexpanded); //B10(24-31) B11(24-31) B14(24-31) B15(24-31) B18(24-31) B19(24-31) B1C(24-31) B1D(24-31)
|
|
4529
|
+
const __m512i rhs_mat_2367ABEF_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m4bexpanded); //B12(24-31) B13(24-31) B16(24-31) B17(24-31) B1A(24-31) B1B(24-31) B1E(24-31) B1F(24-31)
|
|
4530
|
+
|
|
4531
|
+
// Shuffle pattern one - right side input
|
|
4532
|
+
const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3)
|
|
4533
|
+
const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3)
|
|
4534
|
+
const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11)
|
|
4535
|
+
const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11)
|
|
4536
|
+
const __m512i rhs_mat_014589CD_02_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19) B08(16-19) B09(16-19) B08(16-19) B09(16-19) B0C(16-19) B0D(16-19) B0C(16-19) B0D(16-19)
|
|
4537
|
+
const __m512i rhs_mat_2367ABEF_02_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19) B0A(16-19) B0B(16-19) B0A(16-19) B0B(16-19) B0E(16-19) B0F(16-19) B0E(16-19) B0F(16-19)
|
|
4538
|
+
const __m512i rhs_mat_014589CD_03_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27) B08(24-27) B09(24-27) B08(24-27) B09(24-27) B0C(24-27) B0D(24-27) B0C(24-27) B0D(24-27)
|
|
4539
|
+
const __m512i rhs_mat_2367ABEF_03_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27) B0A(24-27) B0B(24-27) B0A(24-27) B0B(24-27) B0E(24-27) B0F(24-27) B0E(24-27) B0F(24-27)
|
|
4540
|
+
|
|
4541
|
+
const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3)
|
|
4542
|
+
const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3)
|
|
4543
|
+
const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11)
|
|
4544
|
+
const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11)
|
|
4545
|
+
const __m512i rhs_mat_014589CD_12_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19) B18(16-19) B19(16-19) B18(16-19) B19(16-19) B1C(16-19) B1D(16-19) B1C(16-19) B1D(16-19)
|
|
4546
|
+
const __m512i rhs_mat_2367ABEF_12_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19) B1A(16-19) B1B(16-19) B1A(16-19) B1B(16-19) B1E(16-19) B1F(16-19) B1E(16-19) B1F(16-19)
|
|
4547
|
+
const __m512i rhs_mat_014589CD_13_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27) B18(24-27) B19(24-27) B18(24-27) B19(24-27) B1C(24-27) B1D(24-27) B1C(24-27) B1D(24-27)
|
|
4548
|
+
const __m512i rhs_mat_2367ABEF_13_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27) B1A(24-27) B1B(24-27) B1A(24-27) B1B(24-27) B1E(24-27) B1F(24-27) B1E(24-27) B1F(24-27)
|
|
4549
|
+
|
|
4550
|
+
// Shuffle pattern two - right side input
|
|
4551
|
+
const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7)
|
|
4552
|
+
const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7)
|
|
4553
|
+
const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15)
|
|
4554
|
+
const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15)
|
|
4555
|
+
const __m512i rhs_mat_014589CD_02_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23) B08(20-23) B09(20-23) B08(20-23) B09(20-23) B0C(20-23) B0D(20-23) B0C(20-23) B0D(20-23)
|
|
4556
|
+
const __m512i rhs_mat_2367ABEF_02_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23) B0A(20-23) B0B(20-23) B0A(20-23) B0B(20-23) B0E(20-23) B0F(20-23) B0E(20-23) B0F(20-23)
|
|
4557
|
+
const __m512i rhs_mat_014589CD_03_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31) B08(28-31) B09(28-31) B08(28-31) B09(28-31) B0C(28-31) B0D(28-31) B0C(28-31) 0BD(28-31)
|
|
4558
|
+
const __m512i rhs_mat_2367ABEF_03_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31) B0A(28-31) B0B(28-31) B0A(28-31) B0B(28-31) B0E(28-31) B0F(28-31) B0E(28-31) B0F(28-31)
|
|
4559
|
+
|
|
4560
|
+
const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7)
|
|
4561
|
+
const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7)
|
|
4562
|
+
const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15)
|
|
4563
|
+
const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15)
|
|
4564
|
+
const __m512i rhs_mat_014589CD_12_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23) B18(20-23) B19(20-23) B18(20-23) B19(20-23) B1C(20-23) B1D(20-23) B1C(20-23) B1D(20-23)
|
|
4565
|
+
const __m512i rhs_mat_2367ABEF_12_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23) B1A(20-23) B1B(20-23) B1A(20-23) B1B(20-23) B1E(20-23) B1F(20-23) B1E(20-23) B1F(20-23)
|
|
4566
|
+
const __m512i rhs_mat_014589CD_13_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31) B18(28-31) B19(28-31) B18(28-31) B19(28-31) B1C(28-31) B1D(28-31) B1C(28-31) B1D(28-31)
|
|
4567
|
+
const __m512i rhs_mat_2367ABEF_13_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31) B1A(28-31) B1B(28-31) B1A(28-31) B1B(28-31) B1E(28-31) B1F(28-31) B1E(28-31) B1F(28-31)
|
|
4568
|
+
|
|
4569
|
+
uint32_t utmp_00[4], utmp_01[4], utmp_10[4], utmp_11[4];
|
|
4570
|
+
|
|
4571
|
+
// Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
|
|
4572
|
+
// The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
|
|
4573
|
+
memcpy(utmp_00, b_ptr_0[b].scales + 24 * sb, 12);
|
|
4574
|
+
utmp_00[3] = ((utmp_00[2] >> 4) & kmask2) | (((utmp_00[1] >> 6) & kmask3) << 4);
|
|
4575
|
+
const uint32_t uaux_00 = utmp_00[1] & kmask1;
|
|
4576
|
+
utmp_00[1] = (utmp_00[2] & kmask2) | (((utmp_00[0] >> 6) & kmask3) << 4);
|
|
4577
|
+
utmp_00[2] = uaux_00;
|
|
4578
|
+
utmp_00[0] &= kmask1;
|
|
4579
|
+
|
|
4580
|
+
// The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
|
|
4581
|
+
memcpy(utmp_01, b_ptr_0[b].scales + 12 + sb * 24, 12);
|
|
4582
|
+
utmp_01[3] = ((utmp_01[2] >> 4) & kmask2) | (((utmp_01[1] >> 6) & kmask3) << 4);
|
|
4583
|
+
const uint32_t uaux_01 = utmp_01[1] & kmask1;
|
|
4584
|
+
utmp_01[1] = (utmp_01[2] & kmask2) | (((utmp_01[0] >> 6) & kmask3) << 4);
|
|
4585
|
+
utmp_01[2] = uaux_01;
|
|
4586
|
+
utmp_01[0] &= kmask1;
|
|
4587
|
+
|
|
4588
|
+
// The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
|
|
4589
|
+
memcpy(utmp_10, b_ptr_1[b].scales + sb * 24, 12);
|
|
4590
|
+
utmp_10[3] = ((utmp_10[2] >> 4) & kmask2) | (((utmp_10[1] >> 6) & kmask3) << 4);
|
|
4591
|
+
const uint32_t uaux_10 = utmp_10[1] & kmask1;
|
|
4592
|
+
utmp_10[1] = (utmp_10[2] & kmask2) | (((utmp_10[0] >> 6) & kmask3) << 4);
|
|
4593
|
+
utmp_10[2] = uaux_10;
|
|
4594
|
+
utmp_10[0] &= kmask1;
|
|
4595
|
+
|
|
4596
|
+
// The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
|
|
4597
|
+
memcpy(utmp_11, b_ptr_1[b].scales + 12 + sb * 24, 12);
|
|
4598
|
+
utmp_11[3] = ((utmp_11[2] >> 4) & kmask2) | (((utmp_11[1] >> 6) & kmask3) << 4);
|
|
4599
|
+
const uint32_t uaux_11 = utmp_11[1] & kmask1;
|
|
4600
|
+
utmp_11[1] = (utmp_11[2] & kmask2) | (((utmp_11[0] >> 6) & kmask3) << 4);
|
|
4601
|
+
utmp_11[2] = uaux_11;
|
|
4602
|
+
utmp_11[0] &= kmask1;
|
|
4603
|
+
|
|
4604
|
+
// Scales of first sub block in the sb loop
|
|
4605
|
+
const __m256i mins_and_scales_0 = _mm256_set_epi32(utmp_10[3], utmp_10[2], utmp_10[1], utmp_10[0], utmp_00[3], utmp_00[2], utmp_00[1], utmp_00[0]);
|
|
4606
|
+
const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
|
|
4607
|
+
|
|
4608
|
+
// Scales of second sub block in the sb loop
|
|
4609
|
+
const __m256i mins_and_scales_1 = _mm256_set_epi32(utmp_11[3], utmp_11[2], utmp_11[1], utmp_11[0], utmp_01[3], utmp_01[2], utmp_01[1], utmp_01[0]);
|
|
4610
|
+
const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
|
|
4611
|
+
|
|
4612
|
+
// Mins of first and second sub block of Q4_K block are arranged side by side
|
|
4613
|
+
const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(_mm256_shuffle_epi32(mins_and_scales_0, 78), _mm256_shuffle_epi32(mins_and_scales_1, 78)));
|
|
4614
|
+
|
|
4615
|
+
const __m512i scale_014589CD_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68);
|
|
4616
|
+
const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238);
|
|
4617
|
+
|
|
4618
|
+
const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68);
|
|
4619
|
+
const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238);
|
|
4620
|
+
|
|
4621
|
+
// Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
|
|
4622
|
+
// Loaded as set of 128 bit vectors and repeated into a 256 bit vector
|
|
4623
|
+
__m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 * sb)));
|
|
4624
|
+
__m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0);
|
|
4625
|
+
__m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17);
|
|
4626
|
+
__m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 256 * sb)));
|
|
4627
|
+
__m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0);
|
|
4628
|
+
__m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17);
|
|
4629
|
+
__m256i lhs_mat_ymm_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 256 * sb)));
|
|
4630
|
+
__m256i lhs_mat_ymm_01_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 0);
|
|
4631
|
+
__m256i lhs_mat_ymm_23_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 17);
|
|
4632
|
+
__m256i lhs_mat_ymm_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 256 * sb)));
|
|
4633
|
+
__m256i lhs_mat_ymm_01_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 0);
|
|
4634
|
+
__m256i lhs_mat_ymm_23_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 17);
|
|
4635
|
+
__m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 256 * sb)));
|
|
4636
|
+
__m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0);
|
|
4637
|
+
__m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17);
|
|
4638
|
+
__m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 256 * sb)));
|
|
4639
|
+
__m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0);
|
|
4640
|
+
__m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17);
|
|
4641
|
+
__m256i lhs_mat_ymm_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 256 * sb)));
|
|
4642
|
+
__m256i lhs_mat_ymm_01_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 0);
|
|
4643
|
+
__m256i lhs_mat_ymm_23_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 17);
|
|
4644
|
+
__m256i lhs_mat_ymm_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 256 * sb)));
|
|
4645
|
+
__m256i lhs_mat_ymm_01_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 0);
|
|
4646
|
+
__m256i lhs_mat_ymm_23_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 17);
|
|
4647
|
+
|
|
4648
|
+
//Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into a 512 bit vector
|
|
4649
|
+
__m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1);
|
|
4650
|
+
__m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1);
|
|
4651
|
+
__m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1);
|
|
4652
|
+
__m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1);
|
|
4653
|
+
__m512i lhs_mat_01_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_02), lhs_mat_ymm_01_02, 1);
|
|
4654
|
+
__m512i lhs_mat_23_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_02), lhs_mat_ymm_23_02, 1);
|
|
4655
|
+
__m512i lhs_mat_01_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_03), lhs_mat_ymm_01_03, 1);
|
|
4656
|
+
__m512i lhs_mat_23_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_03), lhs_mat_ymm_23_03, 1);
|
|
4657
|
+
|
|
4658
|
+
__m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1);
|
|
4659
|
+
__m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1);
|
|
4660
|
+
__m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1);
|
|
4661
|
+
__m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1);
|
|
4662
|
+
__m512i lhs_mat_01_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_12), lhs_mat_ymm_01_12, 1);
|
|
4663
|
+
__m512i lhs_mat_23_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_12), lhs_mat_ymm_23_12, 1);
|
|
4664
|
+
__m512i lhs_mat_01_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_13), lhs_mat_ymm_01_13, 1);
|
|
4665
|
+
__m512i lhs_mat_23_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_13), lhs_mat_ymm_23_13, 1);
|
|
4666
|
+
|
|
4667
|
+
// Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
|
|
4668
|
+
__m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].bsums + 16 * sb)));
|
|
4669
|
+
__m256i lhs_bsums_hsum_ymm_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
|
|
4670
|
+
lhs_bsums_hsum_ymm_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_ymm_0123_01, lhs_bsums_hsum_ymm_0123_01, 0);
|
|
4671
|
+
__m512i lhs_bsums_hsum_0123_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_hsum_ymm_0123_01), lhs_bsums_hsum_ymm_0123_01, 1);
|
|
4672
|
+
|
|
4673
|
+
// Shuffle pattern one - left side input
|
|
4674
|
+
const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
|
|
4675
|
+
const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3)
|
|
4676
|
+
const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
|
|
4677
|
+
const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11)
|
|
4678
|
+
const __m512i lhs_mat_01_02_sp1 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
|
|
4679
|
+
const __m512i lhs_mat_23_02_sp1 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)160); //A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19)
|
|
4680
|
+
const __m512i lhs_mat_01_03_sp1 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
|
|
4681
|
+
const __m512i lhs_mat_23_03_sp1 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)160); //A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27)
|
|
4682
|
+
|
|
4683
|
+
const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
|
|
4684
|
+
const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3)
|
|
4685
|
+
const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
|
|
4686
|
+
const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11)
|
|
4687
|
+
const __m512i lhs_mat_01_12_sp1 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
|
|
4688
|
+
const __m512i lhs_mat_23_12_sp1 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)160); //A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19)
|
|
4689
|
+
const __m512i lhs_mat_01_13_sp1 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
|
|
4690
|
+
const __m512i lhs_mat_23_13_sp1 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)160); //A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27)
|
|
4691
|
+
|
|
4692
|
+
const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
|
|
4693
|
+
const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7)
|
|
4694
|
+
const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
|
|
4695
|
+
const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15)
|
|
4696
|
+
const __m512i lhs_mat_01_02_sp2 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
|
|
4697
|
+
const __m512i lhs_mat_23_02_sp2 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)245); //A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23)
|
|
4698
|
+
const __m512i lhs_mat_01_03_sp2 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
|
|
4699
|
+
const __m512i lhs_mat_23_03_sp2 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)245); //A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31)
|
|
4700
|
+
|
|
4701
|
+
const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
|
|
4702
|
+
const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7)
|
|
4703
|
+
const __m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
|
|
4704
|
+
const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15)
|
|
4705
|
+
const __m512i lhs_mat_01_12_sp2 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
|
|
4706
|
+
const __m512i lhs_mat_23_12_sp2 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)245); //A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23)
|
|
4707
|
+
const __m512i lhs_mat_01_13_sp2 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
|
|
4708
|
+
const __m512i lhs_mat_23_13_sp2 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)245); //A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31)
|
|
4709
|
+
|
|
4710
|
+
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
|
4711
|
+
__m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1));
|
|
4712
|
+
__m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1));
|
|
4713
|
+
__m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1));
|
|
4714
|
+
__m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1));
|
|
4715
|
+
__m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1));
|
|
4716
|
+
__m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1));
|
|
4717
|
+
__m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1));
|
|
4718
|
+
__m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1));
|
|
4719
|
+
|
|
4720
|
+
__m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2));
|
|
4721
|
+
__m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2));
|
|
4722
|
+
__m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2));
|
|
4723
|
+
__m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2));
|
|
4724
|
+
__m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2));
|
|
4725
|
+
__m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2));
|
|
4726
|
+
__m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2));
|
|
4727
|
+
__m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2));
|
|
4728
|
+
|
|
4729
|
+
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
|
4730
|
+
__m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
|
|
4731
|
+
__m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
|
|
4732
|
+
__m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
|
|
4733
|
+
__m512i iacc_mat_11_0 = _mm512_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
|
|
4734
|
+
|
|
4735
|
+
__m512i iacc_mat_00_1 = _mm512_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
|
|
4736
|
+
__m512i iacc_mat_01_1 = _mm512_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
|
|
4737
|
+
__m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
|
|
4738
|
+
__m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
|
|
4739
|
+
|
|
4740
|
+
iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0);
|
|
4741
|
+
iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0);
|
|
4742
|
+
iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0);
|
|
4743
|
+
iacc_mat_11_0 = _mm512_madd_epi16(iacc_mat_11_0, scale_2367ABEF_0);
|
|
4744
|
+
|
|
4745
|
+
iacc_mat_00_1 = _mm512_madd_epi16(iacc_mat_00_1, scale_014589CD_1);
|
|
4746
|
+
iacc_mat_01_1 = _mm512_madd_epi16(iacc_mat_01_1, scale_2367ABEF_1);
|
|
4747
|
+
iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1);
|
|
4748
|
+
iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1);
|
|
4749
|
+
|
|
4750
|
+
// Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)
|
|
4751
|
+
__m512i iacc_row_0_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_0, _mm512_shuffle_epi32(iacc_mat_01_0, (_MM_PERM_ENUM)78));
|
|
4752
|
+
__m512i iacc_row_1_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_0, (_MM_PERM_ENUM)78), iacc_mat_01_0);
|
|
4753
|
+
__m512i iacc_row_2_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_0, _mm512_shuffle_epi32(iacc_mat_11_0, (_MM_PERM_ENUM)78));
|
|
4754
|
+
__m512i iacc_row_3_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_0, (_MM_PERM_ENUM)78), iacc_mat_11_0);
|
|
4755
|
+
__m512i iacc_row_0_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_1, _mm512_shuffle_epi32(iacc_mat_01_1, (_MM_PERM_ENUM)78));
|
|
4756
|
+
__m512i iacc_row_1_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_1, (_MM_PERM_ENUM)78), iacc_mat_01_1);
|
|
4757
|
+
__m512i iacc_row_2_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_1, _mm512_shuffle_epi32(iacc_mat_11_1, (_MM_PERM_ENUM)78));
|
|
4758
|
+
__m512i iacc_row_3_1 = _mm512_mask_blend_epi32(0xCCCC,_mm512_shuffle_epi32(iacc_mat_10_1, (_MM_PERM_ENUM)78), iacc_mat_11_1);
|
|
4759
|
+
|
|
4760
|
+
__m512i iacc_row_0 = _mm512_add_epi32(iacc_row_0_0, iacc_row_0_1);
|
|
4761
|
+
__m512i iacc_row_1 = _mm512_add_epi32(iacc_row_1_0, iacc_row_1_1);
|
|
4762
|
+
__m512i iacc_row_2 = _mm512_add_epi32(iacc_row_2_0, iacc_row_2_1);
|
|
4763
|
+
__m512i iacc_row_3 = _mm512_add_epi32(iacc_row_3_0, iacc_row_3_1);
|
|
4764
|
+
|
|
4765
|
+
// Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes
|
|
4766
|
+
const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);
|
|
4767
|
+
const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);
|
|
4768
|
+
const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1);
|
|
4769
|
+
|
|
4770
|
+
// Multiply with appropiate scales and accumulate (for both d and dmin) below
|
|
4771
|
+
acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
|
|
4772
|
+
acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
|
|
4773
|
+
acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
|
|
4774
|
+
acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
|
|
4775
|
+
|
|
4776
|
+
__m512i iacc_row_min_0 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)0), mins_01);
|
|
4777
|
+
__m512i iacc_row_min_1 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)85), mins_01);
|
|
4778
|
+
__m512i iacc_row_min_2 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)170), mins_01);
|
|
4779
|
+
__m512i iacc_row_min_3 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)255), mins_01);
|
|
4780
|
+
|
|
4781
|
+
acc_min_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]);
|
|
4782
|
+
acc_min_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]);
|
|
4783
|
+
acc_min_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_2), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[2]);
|
|
4784
|
+
acc_min_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]);
|
|
4785
|
+
}
|
|
4786
|
+
}
|
|
4787
|
+
// Store accumlated values
|
|
4788
|
+
for (int i = 0; i < 4; i++) {
|
|
4789
|
+
_mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i]));
|
|
4790
|
+
}
|
|
4791
|
+
}
|
|
4792
|
+
}
|
|
4793
|
+
if (anc != nc) {
|
|
4794
|
+
xstart = anc/8;
|
|
4795
|
+
y = 0;
|
|
4796
|
+
}
|
|
4797
|
+
#endif //AVX512F
|
|
4798
|
+
|
|
4799
|
+
// Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
|
|
4800
|
+
for (; y < anr / 4; y += 4) {
|
|
4801
|
+
|
|
4802
|
+
const block_q8_Kx4 * a_ptrs[4];
|
|
4803
|
+
|
|
4804
|
+
a_ptrs[0] = a_ptr_start + (y * nb);
|
|
4805
|
+
for (int i = 0; i < 3; ++i) {
|
|
4806
|
+
a_ptrs[i + 1] = a_ptrs[i] + nb;
|
|
4807
|
+
}
|
|
4808
|
+
|
|
4809
|
+
// Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation
|
|
4810
|
+
for (int64_t x = xstart; x < nc / 8; x++) {
|
|
4811
|
+
|
|
4812
|
+
const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
|
|
4813
|
+
|
|
4814
|
+
// Master FP accumulators
|
|
4815
|
+
__m256 acc_rows[16];
|
|
4816
|
+
for (int i = 0; i < 16; i++) {
|
|
4817
|
+
acc_rows[i] = _mm256_setzero_ps();
|
|
4818
|
+
}
|
|
4819
|
+
|
|
4820
|
+
__m256 acc_min_rows[16];
|
|
4821
|
+
for (int i = 0; i < 16; i++) {
|
|
4822
|
+
acc_min_rows[i] = _mm256_setzero_ps();
|
|
4823
|
+
}
|
|
4824
|
+
|
|
4825
|
+
// For super block
|
|
4826
|
+
for (int64_t b = 0; b < nb; b++) {
|
|
4827
|
+
|
|
4828
|
+
// Scale values - Load the eight scale values of block_q4_kx8
|
|
4829
|
+
const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
|
|
4830
|
+
|
|
4831
|
+
// dmin values - Load the eight dmin values of block_q4_kx8
|
|
4832
|
+
const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
|
|
4833
|
+
|
|
4834
|
+
// Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
|
|
4835
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
|
4836
|
+
|
|
4837
|
+
// Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7
|
|
4838
|
+
const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
|
|
4839
|
+
const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
|
|
4840
|
+
const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
|
|
4841
|
+
const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
|
|
4842
|
+
const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
|
|
4843
|
+
const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
|
|
4844
|
+
const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
|
|
4845
|
+
const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
|
|
4846
|
+
|
|
4847
|
+
// Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
|
|
4848
|
+
const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
|
|
4849
|
+
const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
|
|
4850
|
+
const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
|
|
4851
|
+
const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
|
|
4852
|
+
const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
|
|
4853
|
+
const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
|
|
4854
|
+
const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
|
|
4855
|
+
const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
|
|
4856
|
+
|
|
4857
|
+
// 4-bit -> 8-bit
|
|
4858
|
+
// First sub block of the two sub blocks processed in the iteration
|
|
4859
|
+
const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m4b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)
|
|
4860
|
+
const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m4b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)
|
|
4861
|
+
|
|
4862
|
+
const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m4b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)
|
|
4863
|
+
const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m4b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)
|
|
4864
|
+
|
|
4865
|
+
const __m256i rhs_mat_0145_02 = _mm256_and_si256(rhs_raw_mat_0145_2, m4b); //B00(16-23) B01(16-23) B04(16-23) B05(16-23)
|
|
4866
|
+
const __m256i rhs_mat_2367_02 = _mm256_and_si256(rhs_raw_mat_2367_2, m4b); //B02(16-23) B03(16-23) B06(16-23) B07(16-23)
|
|
4867
|
+
|
|
4868
|
+
const __m256i rhs_mat_0145_03 = _mm256_and_si256(rhs_raw_mat_0145_3, m4b); //B00(24-31) B01(24-31) B04(24-31) B05(24-31)
|
|
4869
|
+
const __m256i rhs_mat_2367_03 = _mm256_and_si256(rhs_raw_mat_2367_3, m4b); //B02(24-31) B03(24-31) B06(24-31) B07(24-31)
|
|
4870
|
+
|
|
4871
|
+
// Second sub block of the two sub blocks processed in the iteration
|
|
4872
|
+
const __m256i rhs_mat_0145_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)
|
|
4873
|
+
const __m256i rhs_mat_2367_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)
|
|
4874
|
+
|
|
4875
|
+
const __m256i rhs_mat_0145_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)
|
|
4876
|
+
const __m256i rhs_mat_2367_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)
|
|
4877
|
+
|
|
4878
|
+
const __m256i rhs_mat_0145_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m4b); //B10(16-23) B11(16-23) B14(16-23) B15(16-23)
|
|
4879
|
+
const __m256i rhs_mat_2367_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m4b); //B12(16-23) B13(16-23) B16(16-23) B17(16-23)
|
|
4880
|
+
|
|
4881
|
+
const __m256i rhs_mat_0145_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m4b); //B10(24-31) B11(24-31) B14(24-31) B15(24-31)
|
|
4882
|
+
const __m256i rhs_mat_2367_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m4b); //B12(24-31) B13(24-31) B16(24-31) B17(24-31)
|
|
4883
|
+
|
|
4884
|
+
// Shuffle pattern one - right side input
|
|
4885
|
+
const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)
|
|
4886
|
+
const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3)
|
|
4887
|
+
|
|
4888
|
+
const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11)
|
|
4889
|
+
const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11)
|
|
4890
|
+
|
|
4891
|
+
const __m256i rhs_mat_0145_02_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_02, 136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19)
|
|
4892
|
+
const __m256i rhs_mat_2367_02_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_02, 136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19)
|
|
4893
|
+
|
|
4894
|
+
const __m256i rhs_mat_0145_03_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_03, 136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27)
|
|
4895
|
+
const __m256i rhs_mat_2367_03_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_03, 136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27)
|
|
4896
|
+
|
|
4897
|
+
const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3)
|
|
4898
|
+
const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3)
|
|
4899
|
+
|
|
4900
|
+
const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11)
|
|
4901
|
+
const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11)
|
|
4902
|
+
|
|
4903
|
+
const __m256i rhs_mat_0145_12_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_12, 136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19)
|
|
4904
|
+
const __m256i rhs_mat_2367_12_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_12, 136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19)
|
|
4905
|
+
|
|
4906
|
+
const __m256i rhs_mat_0145_13_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_13, 136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27)
|
|
4907
|
+
const __m256i rhs_mat_2367_13_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_13, 136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27)
|
|
4908
|
+
|
|
4909
|
+
|
|
4910
|
+
// Shuffle pattern two - right side input
|
|
4911
|
+
const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7)
|
|
4912
|
+
const __m256i rhs_mat_2367_00_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_00, 221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7)
|
|
4913
|
+
|
|
4914
|
+
const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15)
|
|
4915
|
+
const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15)
|
|
4916
|
+
|
|
4917
|
+
const __m256i rhs_mat_0145_02_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_02, 221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23)
|
|
4918
|
+
const __m256i rhs_mat_2367_02_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_02, 221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23)
|
|
4919
|
+
|
|
4920
|
+
const __m256i rhs_mat_0145_03_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_03, 221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31)
|
|
4921
|
+
const __m256i rhs_mat_2367_03_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_03, 221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31)
|
|
4922
|
+
|
|
4923
|
+
const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7)
|
|
4924
|
+
const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7)
|
|
4925
|
+
|
|
4926
|
+
const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15)
|
|
4927
|
+
const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15)
|
|
4928
|
+
|
|
4929
|
+
const __m256i rhs_mat_0145_12_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_12, 221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23)
|
|
4930
|
+
const __m256i rhs_mat_2367_12_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_12, 221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23)
|
|
4931
|
+
|
|
4932
|
+
const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31)
|
|
4933
|
+
const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)
|
|
4934
|
+
|
|
4935
|
+
uint32_t utmp_0[4], utmp_1[4];
|
|
4936
|
+
|
|
4937
|
+
// Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
|
|
4938
|
+
// The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
|
|
4939
|
+
memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
|
|
4940
|
+
utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
|
|
4941
|
+
const uint32_t uaux_0 = utmp_0[1] & kmask1;
|
|
4942
|
+
utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
|
|
4943
|
+
utmp_0[2] = uaux_0;
|
|
4944
|
+
utmp_0[0] &= kmask1;
|
|
4945
|
+
|
|
4946
|
+
// The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
|
|
4947
|
+
memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
|
|
4948
|
+
utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
|
|
4949
|
+
const uint32_t uaux_1 = utmp_1[1] & kmask1;
|
|
4950
|
+
utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
|
|
4951
|
+
utmp_1[2] = uaux_1;
|
|
4952
|
+
utmp_1[0] &= kmask1;
|
|
4953
|
+
|
|
4954
|
+
// Scales of first sub block in the sb loop
|
|
4955
|
+
const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
|
|
4956
|
+
const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
|
|
4957
|
+
|
|
4958
|
+
// Scales of second sub block in the sb loop
|
|
4959
|
+
const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
|
|
4960
|
+
const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
|
|
4961
|
+
|
|
4962
|
+
// Mins of first and second sub block of Q4_K block are arranged side by side
|
|
4963
|
+
const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
|
|
4964
|
+
|
|
4965
|
+
const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68);
|
|
4966
|
+
const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238);
|
|
4967
|
+
|
|
4968
|
+
const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);
|
|
4969
|
+
const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);
|
|
4970
|
+
|
|
4971
|
+
for (int rp = 0; rp < 4; rp++) {
|
|
4972
|
+
|
|
4973
|
+
// Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
|
|
4974
|
+
// Loaded as set of 128 bit vectors and repeated into a 256 bit vector
|
|
4975
|
+
__m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb)));
|
|
4976
|
+
__m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);
|
|
4977
|
+
__m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);
|
|
4978
|
+
__m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb)));
|
|
4979
|
+
__m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);
|
|
4980
|
+
__m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);
|
|
4981
|
+
__m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb)));
|
|
4982
|
+
__m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);
|
|
4983
|
+
__m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);
|
|
4984
|
+
__m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb)));
|
|
4985
|
+
__m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);
|
|
4986
|
+
__m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);
|
|
4987
|
+
__m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb)));
|
|
4988
|
+
__m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);
|
|
4989
|
+
__m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);
|
|
4990
|
+
__m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb)));
|
|
4991
|
+
__m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);
|
|
4992
|
+
__m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);
|
|
4993
|
+
__m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb)));
|
|
4994
|
+
__m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);
|
|
4995
|
+
__m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);
|
|
4996
|
+
__m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb)));
|
|
4997
|
+
__m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);
|
|
4998
|
+
__m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);
|
|
4999
|
+
|
|
5000
|
+
// Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
|
|
5001
|
+
__m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));
|
|
5002
|
+
__m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
|
|
5003
|
+
lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
|
|
5004
|
+
|
|
5005
|
+
// Shuffle pattern one - left side input
|
|
5006
|
+
const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
|
|
5007
|
+
const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3)
|
|
5008
|
+
|
|
5009
|
+
const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
|
|
5010
|
+
const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160); //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11)
|
|
5011
|
+
|
|
5012
|
+
const __m256i lhs_mat_01_02_sp1 = _mm256_shuffle_epi32(lhs_mat_01_02, 160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
|
|
5013
|
+
const __m256i lhs_mat_23_02_sp1 = _mm256_shuffle_epi32(lhs_mat_23_02, 160); //A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19)
|
|
5014
|
+
|
|
5015
|
+
const __m256i lhs_mat_01_03_sp1 = _mm256_shuffle_epi32(lhs_mat_01_03, 160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
|
|
5016
|
+
const __m256i lhs_mat_23_03_sp1 = _mm256_shuffle_epi32(lhs_mat_23_03, 160); //A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27)
|
|
5017
|
+
|
|
5018
|
+
const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
|
|
5019
|
+
const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3)
|
|
5020
|
+
|
|
5021
|
+
const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
|
|
5022
|
+
const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160); //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)
|
|
5023
|
+
|
|
5024
|
+
const __m256i lhs_mat_01_12_sp1 = _mm256_shuffle_epi32(lhs_mat_01_12, 160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
|
|
5025
|
+
const __m256i lhs_mat_23_12_sp1 = _mm256_shuffle_epi32(lhs_mat_23_12, 160); //A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19)
|
|
5026
|
+
|
|
5027
|
+
const __m256i lhs_mat_01_13_sp1 = _mm256_shuffle_epi32(lhs_mat_01_13, 160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
|
|
5028
|
+
const __m256i lhs_mat_23_13_sp1 = _mm256_shuffle_epi32(lhs_mat_23_13, 160); //A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27)
|
|
5029
|
+
|
|
5030
|
+
// Shuffle pattern two - left side input
|
|
5031
|
+
const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
|
|
5032
|
+
const __m256i lhs_mat_23_00_sp2 = _mm256_shuffle_epi32(lhs_mat_23_00, 245); //A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7)
|
|
5033
|
+
|
|
5034
|
+
const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
|
|
5035
|
+
const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15)
|
|
5036
|
+
|
|
5037
|
+
const __m256i lhs_mat_01_02_sp2 = _mm256_shuffle_epi32(lhs_mat_01_02, 245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
|
|
5038
|
+
const __m256i lhs_mat_23_02_sp2 = _mm256_shuffle_epi32(lhs_mat_23_02, 245); //A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23)
|
|
5039
|
+
|
|
5040
|
+
const __m256i lhs_mat_01_03_sp2 = _mm256_shuffle_epi32(lhs_mat_01_03, 245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
|
|
5041
|
+
const __m256i lhs_mat_23_03_sp2 = _mm256_shuffle_epi32(lhs_mat_23_03, 245); //A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31)
|
|
5042
|
+
|
|
5043
|
+
const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
|
|
5044
|
+
const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7)
|
|
5045
|
+
|
|
5046
|
+
const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
|
|
5047
|
+
const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15)
|
|
5048
|
+
|
|
5049
|
+
const __m256i lhs_mat_01_12_sp2 = _mm256_shuffle_epi32(lhs_mat_01_12, 245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
|
|
5050
|
+
const __m256i lhs_mat_23_12_sp2 = _mm256_shuffle_epi32(lhs_mat_23_12, 245); //A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23)
|
|
5051
|
+
|
|
5052
|
+
const __m256i lhs_mat_01_13_sp2 = _mm256_shuffle_epi32(lhs_mat_01_13, 245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
|
|
5053
|
+
const __m256i lhs_mat_23_13_sp2 = _mm256_shuffle_epi32(lhs_mat_23_13, 245); //A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31)
|
|
5054
|
+
|
|
5055
|
+
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
|
5056
|
+
__m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1));
|
|
5057
|
+
__m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1));
|
|
5058
|
+
__m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1));
|
|
5059
|
+
__m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1));
|
|
5060
|
+
__m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1));
|
|
5061
|
+
__m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1));
|
|
5062
|
+
__m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1));
|
|
5063
|
+
__m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1));
|
|
5064
|
+
|
|
5065
|
+
__m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2));
|
|
5066
|
+
__m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2));
|
|
5067
|
+
__m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2));
|
|
5068
|
+
__m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2));
|
|
5069
|
+
__m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2));
|
|
5070
|
+
__m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2));
|
|
5071
|
+
__m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2));
|
|
5072
|
+
__m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2));
|
|
5073
|
+
|
|
5074
|
+
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
|
5075
|
+
__m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
|
|
5076
|
+
__m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
|
|
5077
|
+
__m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
|
|
5078
|
+
__m256i iacc_mat_11_0 = _mm256_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
|
|
5079
|
+
|
|
5080
|
+
__m256i iacc_mat_00_1 = _mm256_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
|
|
5081
|
+
__m256i iacc_mat_01_1 = _mm256_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
|
|
5082
|
+
__m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
|
|
5083
|
+
__m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
|
|
5084
|
+
|
|
5085
|
+
// Multiply the 16-bit dot product sums with the corresponding sub block scales (madd also adds adjacent pairs into 32-bit lanes)
|
|
5086
|
+
iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0);
|
|
5087
|
+
iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0);
|
|
5088
|
+
iacc_mat_10_0 = _mm256_madd_epi16(iacc_mat_10_0, scale_0145_0);
|
|
5089
|
+
iacc_mat_11_0 = _mm256_madd_epi16(iacc_mat_11_0, scale_2367_0);
|
|
5090
|
+
|
|
5091
|
+
iacc_mat_00_1 = _mm256_madd_epi16(iacc_mat_00_1, scale_0145_1);
|
|
5092
|
+
iacc_mat_01_1 = _mm256_madd_epi16(iacc_mat_01_1, scale_2367_1);
|
|
5093
|
+
iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1);
|
|
5094
|
+
iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1);
|
|
5095
|
+
|
|
5096
|
+
// Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)
|
|
5097
|
+
__m256i iacc_row_0_0 = _mm256_blend_epi32(iacc_mat_00_0, _mm256_shuffle_epi32(iacc_mat_01_0, 78), 204);
|
|
5098
|
+
__m256i iacc_row_1_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_0, 78), iacc_mat_01_0, 204);
|
|
5099
|
+
__m256i iacc_row_2_0 = _mm256_blend_epi32(iacc_mat_10_0, _mm256_shuffle_epi32(iacc_mat_11_0, 78), 204);
|
|
5100
|
+
__m256i iacc_row_3_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_0, 78), iacc_mat_11_0, 204);
|
|
5101
|
+
__m256i iacc_row_0_1 = _mm256_blend_epi32(iacc_mat_00_1, _mm256_shuffle_epi32(iacc_mat_01_1, 78), 204);
|
|
5102
|
+
__m256i iacc_row_1_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_1, 78), iacc_mat_01_1, 204);
|
|
5103
|
+
__m256i iacc_row_2_1 = _mm256_blend_epi32(iacc_mat_10_1, _mm256_shuffle_epi32(iacc_mat_11_1, 78), 204);
|
|
5104
|
+
__m256i iacc_row_3_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_1, 78), iacc_mat_11_1, 204);
|
|
5105
|
+
|
|
5106
|
+
__m256i iacc_row_0 = _mm256_add_epi32(iacc_row_0_0, iacc_row_0_1);
|
|
5107
|
+
__m256i iacc_row_1 = _mm256_add_epi32(iacc_row_1_0, iacc_row_1_1);
|
|
5108
|
+
__m256i iacc_row_2 = _mm256_add_epi32(iacc_row_2_0, iacc_row_2_1);
|
|
5109
|
+
__m256i iacc_row_3 = _mm256_add_epi32(iacc_row_3_0, iacc_row_3_1);
|
|
5110
|
+
|
|
5111
|
+
// Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes
|
|
5112
|
+
const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d);
|
|
5113
|
+
const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);//GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
|
|
5114
|
+
|
|
5115
|
+
// Multiply with appropriate scales and accumulate (for both d and dmin) below
|
|
5116
|
+
acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
|
|
5117
|
+
acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
|
|
5118
|
+
acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
|
|
5119
|
+
acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
|
|
5120
|
+
|
|
5121
|
+
__m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);
|
|
5122
|
+
__m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);
|
|
5123
|
+
__m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);
|
|
5124
|
+
__m256i iacc_row_min_3 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 255), mins_01);
|
|
5125
|
+
|
|
5126
|
+
acc_min_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]);
|
|
5127
|
+
acc_min_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);
|
|
5128
|
+
acc_min_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]);
|
|
5129
|
+
acc_min_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]);
|
|
5130
|
+
|
|
5131
|
+
}
|
|
5132
|
+
}
|
|
5133
|
+
}
|
|
5134
|
+
// Store the accumulated values
|
|
5135
|
+
for (int i = 0; i < 16; i++) {
|
|
5136
|
+
_mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i]));
|
|
5137
|
+
}
|
|
5138
|
+
}
|
|
5139
|
+
}
|
|
5140
|
+
for (; y < nr / 4; y++) {
|
|
5141
|
+
|
|
5142
|
+
const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);
|
|
5143
|
+
|
|
5144
|
+
for (int64_t x = xstart; x < nc / 8; x++) {
|
|
5145
|
+
|
|
5146
|
+
const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
|
|
5147
|
+
|
|
5148
|
+
// Master FP accumulators
|
|
5149
|
+
__m256 acc_rows[4];
|
|
5150
|
+
for (int i = 0; i < 4; i++) {
|
|
5151
|
+
acc_rows[i] = _mm256_setzero_ps();
|
|
5152
|
+
}
|
|
5153
|
+
|
|
5154
|
+
__m256 acc_min_rows[4];
|
|
5155
|
+
for (int i = 0; i < 4; i++) {
|
|
5156
|
+
acc_min_rows[i] = _mm256_setzero_ps();
|
|
5157
|
+
}
|
|
5158
|
+
|
|
5159
|
+
for (int64_t b = 0; b < nb; b++) {
|
|
5160
|
+
|
|
5161
|
+
// Scale values - Load the eight scale values of block_q4_Kx8
|
|
5162
|
+
const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
|
|
5163
|
+
|
|
5164
|
+
// dmin values - Load the eight dmin values of block_q4_Kx8
|
|
5165
|
+
const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
|
|
5166
|
+
|
|
5167
|
+
// Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
|
|
5168
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
|
5169
|
+
|
|
5170
|
+
// Load the eight block_q4_k for two sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7
|
|
5171
|
+
const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
|
|
5172
|
+
const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
|
|
5173
|
+
const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
|
|
5174
|
+
const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
|
|
5175
|
+
const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
|
|
5176
|
+
const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
|
|
5177
|
+
const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
|
|
5178
|
+
const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
|
|
5179
|
+
|
|
5180
|
+
// Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
|
|
5181
|
+
const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
|
|
5182
|
+
const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
|
|
5183
|
+
const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
|
|
5184
|
+
const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
|
|
5185
|
+
const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
|
|
5186
|
+
const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
|
|
5187
|
+
const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
|
|
5188
|
+
const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
|
|
5189
|
+
|
|
5190
|
+
// 4-bit -> 8-bit
|
|
5191
|
+
// First sub block of the two sub blocks processed in the iteration
|
|
5192
|
+
const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m4b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)
|
|
5193
|
+
const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m4b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)
|
|
5194
|
+
|
|
5195
|
+
const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m4b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)
|
|
5196
|
+
const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m4b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)
|
|
5197
|
+
|
|
5198
|
+
const __m256i rhs_mat_0145_02 = _mm256_and_si256(rhs_raw_mat_0145_2, m4b); //B00(16-23) B01(16-23) B04(16-23) B05(16-23)
|
|
5199
|
+
const __m256i rhs_mat_2367_02 = _mm256_and_si256(rhs_raw_mat_2367_2, m4b); //B02(16-23) B03(16-23) B06(16-23) B07(16-23)
|
|
5200
|
+
|
|
5201
|
+
const __m256i rhs_mat_0145_03 = _mm256_and_si256(rhs_raw_mat_0145_3, m4b); //B00(24-31) B01(24-31) B04(24-31) B05(24-31)
|
|
5202
|
+
const __m256i rhs_mat_2367_03 = _mm256_and_si256(rhs_raw_mat_2367_3, m4b); //B02(24-31) B03(24-31) B06(24-31) B07(24-31)
|
|
5203
|
+
|
|
5204
|
+
// Second sub block of the two sub blocks processed in the iteration
|
|
5205
|
+
const __m256i rhs_mat_0145_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)
|
|
5206
|
+
const __m256i rhs_mat_2367_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)
|
|
5207
|
+
|
|
5208
|
+
const __m256i rhs_mat_0145_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)
|
|
5209
|
+
const __m256i rhs_mat_2367_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)
|
|
5210
|
+
|
|
5211
|
+
const __m256i rhs_mat_0145_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m4b); //B10(16-23) B11(16-23) B14(16-23) B15(16-23)
|
|
5212
|
+
const __m256i rhs_mat_2367_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m4b); //B12(16-23) B13(16-23) B16(16-23) B17(16-23)
|
|
5213
|
+
|
|
5214
|
+
const __m256i rhs_mat_0145_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m4b); //B10(24-31) B11(24-31) B14(24-31) B15(24-31)
|
|
5215
|
+
const __m256i rhs_mat_2367_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m4b); //B12(24-31) B13(24-31) B16(24-31) B17(24-31)
|
|
5216
|
+
|
|
5217
|
+
// Shuffle pattern one - right side input
|
|
5218
|
+
const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)
|
|
5219
|
+
const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3)
|
|
5220
|
+
|
|
5221
|
+
const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11)
|
|
5222
|
+
const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11)
|
|
5223
|
+
|
|
5224
|
+
const __m256i rhs_mat_0145_02_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_02, 136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19)
|
|
5225
|
+
const __m256i rhs_mat_2367_02_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_02, 136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19)
|
|
5226
|
+
|
|
5227
|
+
const __m256i rhs_mat_0145_03_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_03, 136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27)
|
|
5228
|
+
const __m256i rhs_mat_2367_03_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_03, 136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27)
|
|
5229
|
+
|
|
5230
|
+
const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3)
|
|
5231
|
+
const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3)
|
|
5232
|
+
|
|
5233
|
+
const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11)
|
|
5234
|
+
const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11)
|
|
5235
|
+
|
|
5236
|
+
const __m256i rhs_mat_0145_12_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_12, 136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19)
|
|
5237
|
+
const __m256i rhs_mat_2367_12_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_12, 136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19)
|
|
5238
|
+
|
|
5239
|
+
const __m256i rhs_mat_0145_13_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_13, 136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27)
|
|
5240
|
+
const __m256i rhs_mat_2367_13_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_13, 136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27)
|
|
5241
|
+
|
|
5242
|
+
// Shuffle pattern two - right side input
|
|
5243
|
+
const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7)
|
|
5244
|
+
const __m256i rhs_mat_2367_00_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_00, 221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7)
|
|
5245
|
+
|
|
5246
|
+
const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15)
|
|
5247
|
+
const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15)
|
|
5248
|
+
|
|
5249
|
+
const __m256i rhs_mat_0145_02_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_02, 221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23)
|
|
5250
|
+
const __m256i rhs_mat_2367_02_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_02, 221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23)
|
|
5251
|
+
|
|
5252
|
+
const __m256i rhs_mat_0145_03_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_03, 221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31)
|
|
5253
|
+
const __m256i rhs_mat_2367_03_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_03, 221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31)
|
|
5254
|
+
|
|
5255
|
+
const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7)
|
|
5256
|
+
const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7)
|
|
5257
|
+
|
|
5258
|
+
const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15)
|
|
5259
|
+
const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15)
|
|
5260
|
+
|
|
5261
|
+
const __m256i rhs_mat_0145_12_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_12, 221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23)
|
|
5262
|
+
const __m256i rhs_mat_2367_12_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_12, 221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23)
|
|
5263
|
+
|
|
5264
|
+
const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31)
|
|
5265
|
+
const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)
|
|
5266
|
+
|
|
5267
|
+
uint32_t utmp_0[4], utmp_1[4];
|
|
5268
|
+
|
|
5269
|
+
// Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
|
|
5270
|
+
// The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
|
|
5271
|
+
memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
|
|
5272
|
+
utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
|
|
5273
|
+
const uint32_t uaux_0 = utmp_0[1] & kmask1;
|
|
5274
|
+
utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
|
|
5275
|
+
utmp_0[2] = uaux_0;
|
|
5276
|
+
utmp_0[0] &= kmask1;
|
|
5277
|
+
|
|
5278
|
+
// The below block is for eg to extract second sub block's scales and mins from different Q4_K structures when sb = 1
|
|
5279
|
+
memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
|
|
5280
|
+
utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
|
|
5281
|
+
const uint32_t uaux_1 = utmp_1[1] & kmask1;
|
|
5282
|
+
utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
|
|
5283
|
+
utmp_1[2] = uaux_1;
|
|
5284
|
+
utmp_1[0] &= kmask1;
|
|
5285
|
+
|
|
5286
|
+
// Scales of first sub block in the sb loop
|
|
5287
|
+
const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
|
|
5288
|
+
const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
|
|
5289
|
+
|
|
5290
|
+
// Scales of second sub block in the sb loop
|
|
5291
|
+
const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
|
|
5292
|
+
const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
|
|
5293
|
+
|
|
5294
|
+
// Mins of first and second sub block of Q4_K block are arranged side by side
|
|
5295
|
+
const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
|
|
5296
|
+
|
|
5297
|
+
const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68);
|
|
5298
|
+
const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238);
|
|
5299
|
+
|
|
5300
|
+
const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);
|
|
5301
|
+
const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);
|
|
5302
|
+
|
|
5303
|
+
// Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
|
|
5304
|
+
// Loaded as set of 128 bit vectors and repeated into a 256 bit vector
|
|
5305
|
+
__m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 * sb)));
|
|
5306
|
+
__m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);
|
|
5307
|
+
__m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);
|
|
5308
|
+
__m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 256 * sb)));
|
|
5309
|
+
__m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);
|
|
5310
|
+
__m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);
|
|
5311
|
+
__m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 256 * sb)));
|
|
5312
|
+
__m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);
|
|
5313
|
+
__m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);
|
|
5314
|
+
__m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 256 * sb)));
|
|
5315
|
+
__m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);
|
|
5316
|
+
__m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);
|
|
5317
|
+
__m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 256 * sb)));
|
|
5318
|
+
__m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);
|
|
5319
|
+
__m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);
|
|
5320
|
+
__m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 256 * sb)));
|
|
5321
|
+
__m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);
|
|
5322
|
+
__m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);
|
|
5323
|
+
__m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 256 * sb)));
|
|
5324
|
+
__m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);
|
|
5325
|
+
__m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);
|
|
5326
|
+
__m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 256 * sb)));
|
|
5327
|
+
__m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);
|
|
5328
|
+
__m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);
|
|
5329
|
+
|
|
5330
|
+
// Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
|
|
5331
|
+
__m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].bsums + 16 * sb)));
|
|
5332
|
+
__m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
|
|
5333
|
+
lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
|
|
5334
|
+
|
|
5335
|
+
// Shuffle pattern one - left side input
|
|
5336
|
+
const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
|
|
5337
|
+
const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3)
|
|
5338
|
+
|
|
5339
|
+
const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
|
|
5340
|
+
const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160); //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11)
|
|
5341
|
+
|
|
5342
|
+
const __m256i lhs_mat_01_02_sp1 = _mm256_shuffle_epi32(lhs_mat_01_02, 160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
|
|
5343
|
+
const __m256i lhs_mat_23_02_sp1 = _mm256_shuffle_epi32(lhs_mat_23_02, 160); //A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19)
|
|
5344
|
+
|
|
5345
|
+
const __m256i lhs_mat_01_03_sp1 = _mm256_shuffle_epi32(lhs_mat_01_03, 160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
|
|
5346
|
+
const __m256i lhs_mat_23_03_sp1 = _mm256_shuffle_epi32(lhs_mat_23_03, 160); //A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27)
|
|
5347
|
+
|
|
5348
|
+
const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
|
|
5349
|
+
const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3)
|
|
5350
|
+
|
|
5351
|
+
const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
|
|
5352
|
+
const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160); //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)
|
|
5353
|
+
|
|
5354
|
+
const __m256i lhs_mat_01_12_sp1 = _mm256_shuffle_epi32(lhs_mat_01_12, 160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
|
|
5355
|
+
const __m256i lhs_mat_23_12_sp1 = _mm256_shuffle_epi32(lhs_mat_23_12, 160); //A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19)
|
|
5356
|
+
|
|
5357
|
+
const __m256i lhs_mat_01_13_sp1 = _mm256_shuffle_epi32(lhs_mat_01_13, 160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
|
|
5358
|
+
const __m256i lhs_mat_23_13_sp1 = _mm256_shuffle_epi32(lhs_mat_23_13, 160); //A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27)
|
|
5359
|
+
|
|
5360
|
+
// Shuffle pattern two- left side input
|
|
5361
|
+
const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
|
|
5362
|
+
const __m256i lhs_mat_23_00_sp2 = _mm256_shuffle_epi32(lhs_mat_23_00, 245); //A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7)
|
|
5363
|
+
|
|
5364
|
+
const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
|
|
5365
|
+
const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15)
|
|
5366
|
+
|
|
5367
|
+
const __m256i lhs_mat_01_02_sp2 = _mm256_shuffle_epi32(lhs_mat_01_02, 245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
|
|
5368
|
+
const __m256i lhs_mat_23_02_sp2 = _mm256_shuffle_epi32(lhs_mat_23_02, 245); //A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23)
|
|
5369
|
+
|
|
5370
|
+
const __m256i lhs_mat_01_03_sp2 = _mm256_shuffle_epi32(lhs_mat_01_03, 245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
|
|
5371
|
+
const __m256i lhs_mat_23_03_sp2 = _mm256_shuffle_epi32(lhs_mat_23_03, 245); //A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31)
|
|
5372
|
+
|
|
5373
|
+
const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
|
|
5374
|
+
const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7)
|
|
5375
|
+
|
|
5376
|
+
const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
|
|
5377
|
+
const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15)
|
|
5378
|
+
|
|
5379
|
+
const __m256i lhs_mat_01_12_sp2 = _mm256_shuffle_epi32(lhs_mat_01_12, 245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
|
|
5380
|
+
const __m256i lhs_mat_23_12_sp2 = _mm256_shuffle_epi32(lhs_mat_23_12, 245); //A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23)
|
|
5381
|
+
|
|
5382
|
+
const __m256i lhs_mat_01_13_sp2 = _mm256_shuffle_epi32(lhs_mat_01_13, 245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
|
|
5383
|
+
const __m256i lhs_mat_23_13_sp2 = _mm256_shuffle_epi32(lhs_mat_23_13, 245); //A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31)
|
|
5384
|
+
|
|
5385
|
+
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
|
5386
|
+
__m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1));
|
|
5387
|
+
__m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1));
|
|
5388
|
+
__m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1));
|
|
5389
|
+
__m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1));
|
|
5390
|
+
__m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1));
|
|
5391
|
+
__m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1));
|
|
5392
|
+
__m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1));
|
|
5393
|
+
__m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1));
|
|
5394
|
+
|
|
5395
|
+
__m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2));
|
|
5396
|
+
__m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2));
|
|
5397
|
+
__m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2));
|
|
5398
|
+
__m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2));
|
|
5399
|
+
__m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2));
|
|
5400
|
+
__m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2));
|
|
5401
|
+
__m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2));
|
|
5402
|
+
__m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2));
|
|
5403
|
+
|
|
5404
|
+
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
|
5405
|
+
__m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
|
|
5406
|
+
__m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
|
|
5407
|
+
__m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
|
|
5408
|
+
__m256i iacc_mat_11_0 = _mm256_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
|
|
5409
|
+
|
|
5410
|
+
__m256i iacc_mat_00_1 = _mm256_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
|
|
5411
|
+
__m256i iacc_mat_01_1 = _mm256_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
|
|
5412
|
+
__m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
|
|
5413
|
+
__m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
|
|
5414
|
+
|
|
5415
|
+
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
|
5416
|
+
iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0);
|
|
5417
|
+
iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0);
|
|
5418
|
+
iacc_mat_10_0 = _mm256_madd_epi16(iacc_mat_10_0, scale_0145_0);
|
|
5419
|
+
iacc_mat_11_0 = _mm256_madd_epi16(iacc_mat_11_0, scale_2367_0);
|
|
5420
|
+
|
|
5421
|
+
iacc_mat_00_1 = _mm256_madd_epi16(iacc_mat_00_1, scale_0145_1);
|
|
5422
|
+
iacc_mat_01_1 = _mm256_madd_epi16(iacc_mat_01_1, scale_2367_1);
|
|
5423
|
+
iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1);
|
|
5424
|
+
iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1);
|
|
5425
|
+
|
|
5426
|
+
// Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)
|
|
5427
|
+
__m256i iacc_row_0_0 = _mm256_blend_epi32(iacc_mat_00_0, _mm256_shuffle_epi32(iacc_mat_01_0, 78), 204);
|
|
5428
|
+
__m256i iacc_row_1_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_0, 78), iacc_mat_01_0, 204);
|
|
5429
|
+
__m256i iacc_row_2_0 = _mm256_blend_epi32(iacc_mat_10_0, _mm256_shuffle_epi32(iacc_mat_11_0, 78), 204);
|
|
5430
|
+
__m256i iacc_row_3_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_0, 78), iacc_mat_11_0, 204);
|
|
5431
|
+
__m256i iacc_row_0_1 = _mm256_blend_epi32(iacc_mat_00_1, _mm256_shuffle_epi32(iacc_mat_01_1, 78), 204);
|
|
5432
|
+
__m256i iacc_row_1_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_1, 78), iacc_mat_01_1, 204);
|
|
5433
|
+
__m256i iacc_row_2_1 = _mm256_blend_epi32(iacc_mat_10_1, _mm256_shuffle_epi32(iacc_mat_11_1, 78), 204);
|
|
5434
|
+
__m256i iacc_row_3_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_1, 78), iacc_mat_11_1, 204);
|
|
5435
|
+
|
|
5436
|
+
__m256i iacc_row_0 = _mm256_add_epi32(iacc_row_0_0, iacc_row_0_1);
|
|
5437
|
+
__m256i iacc_row_1 = _mm256_add_epi32(iacc_row_1_0, iacc_row_1_1);
|
|
5438
|
+
__m256i iacc_row_2 = _mm256_add_epi32(iacc_row_2_0, iacc_row_2_1);
|
|
5439
|
+
__m256i iacc_row_3 = _mm256_add_epi32(iacc_row_3_0, iacc_row_3_1);
|
|
5440
|
+
|
|
5441
|
+
// Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes
|
|
5442
|
+
const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);
|
|
5443
|
+
const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); //GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
|
|
5444
|
+
|
|
5445
|
+
// Multiply with appropiate scales and accumulate (for both d and dmin) below
|
|
5446
|
+
acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
|
|
5447
|
+
acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
|
|
5448
|
+
acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
|
|
5449
|
+
acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
|
|
5450
|
+
|
|
5451
|
+
__m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);
|
|
5452
|
+
__m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);
|
|
5453
|
+
__m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);
|
|
5454
|
+
__m256i iacc_row_min_3 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 255), mins_01);
|
|
5455
|
+
|
|
5456
|
+
acc_min_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]);
|
|
5457
|
+
acc_min_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]);
|
|
5458
|
+
acc_min_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[2]);
|
|
5459
|
+
acc_min_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]);
|
|
5460
|
+
}
|
|
5461
|
+
}
|
|
5462
|
+
|
|
5463
|
+
// Store the accumulated values
|
|
5464
|
+
for (int i = 0; i < 4; i++) {
|
|
5465
|
+
_mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i]));
|
|
5466
|
+
}
|
|
5467
|
+
}
|
|
5468
|
+
}
|
|
5469
|
+
|
|
5470
|
+
#else
|
|
5471
|
+
|
|
5472
|
+
float sumf[4][8];
|
|
5473
|
+
float sum_minf[4][8];
|
|
5474
|
+
uint32_t utmp[32];
|
|
5475
|
+
int sumi1;
|
|
5476
|
+
int sumi2;
|
|
5477
|
+
int sumi;
|
|
5478
|
+
|
|
5479
|
+
for (int y = 0; y < nr / 4; y++) {
|
|
5480
|
+
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
|
5481
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
5482
|
+
const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
|
5483
|
+
for (int m = 0; m < 4; m++) {
|
|
5484
|
+
for (int j = 0; j < ncols_interleaved; j++) {
|
|
5485
|
+
sumf[m][j] = 0.0;
|
|
5486
|
+
sum_minf[m][j] = 0.0;
|
|
5487
|
+
}
|
|
5488
|
+
}
|
|
5489
|
+
for (int l = 0; l < nb; l++) {
|
|
5490
|
+
for (int sb = 0; sb < 8; sb++) {
|
|
5491
|
+
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
|
5492
|
+
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
|
5493
|
+
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
|
5494
|
+
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
|
5495
|
+
utmp[sb * 4 + 2] = uaux_0;
|
|
5496
|
+
utmp[sb * 4 + 0] &= kmask1;
|
|
5497
|
+
}
|
|
5498
|
+
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
|
5499
|
+
uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;
|
|
5500
|
+
uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16;
|
|
5501
|
+
for (int m = 0; m < 4; m++) {
|
|
5502
|
+
for (int j = 0; j < ncols_interleaved; j++) {
|
|
5503
|
+
sumi1 = 0;
|
|
5504
|
+
sumi2 = 0;
|
|
5505
|
+
sumi = 0;
|
|
5506
|
+
for (int i = 0; i < blocklen; ++i) {
|
|
5507
|
+
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
|
|
5508
|
+
const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
|
|
5509
|
+
sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i]);
|
|
5510
|
+
sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
|
|
5511
|
+
sumi1 = sumi1 * scales_0[j];
|
|
5512
|
+
sumi2 = sumi2 * scales_1[j];
|
|
5513
|
+
sumi += sumi1 + sumi2;
|
|
5514
|
+
}
|
|
5515
|
+
sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
|
|
5516
|
+
}
|
|
5517
|
+
}
|
|
5518
|
+
}
|
|
5519
|
+
for (int sb = 0; sb < 8; sb++) {
|
|
5520
|
+
uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;
|
|
5521
|
+
for(int m = 0; m < 4; m++) {
|
|
5522
|
+
const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
|
|
5523
|
+
for(int j = 0; j < ncols_interleaved; j++) {
|
|
5524
|
+
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
|
|
5525
|
+
}
|
|
5526
|
+
}
|
|
5527
|
+
}
|
|
5528
|
+
}
|
|
5529
|
+
for (int m = 0; m < 4; m++) {
|
|
5530
|
+
for (int j = 0; j < ncols_interleaved; j++) {
|
|
5531
|
+
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
|
|
5532
|
+
}
|
|
5533
|
+
}
|
|
5534
|
+
}
|
|
5535
|
+
}
|
|
5536
|
+
#endif
|
|
5537
|
+
}
|
|
5538
|
+
|
|
5539
|
+
static void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
5540
|
+
const int qk = QK8_0;
|
|
5541
|
+
const int nb = n / qk;
|
|
5542
|
+
const int ncols_interleaved = 4;
|
|
5543
|
+
const int blocklen = 4;
|
|
5544
|
+
|
|
5545
|
+
assert (n % qk == 0);
|
|
5546
|
+
assert (nr % 4 == 0);
|
|
5547
|
+
assert (nc % ncols_interleaved == 0);
|
|
5548
|
+
|
|
5549
|
+
UNUSED(s);
|
|
5550
|
+
UNUSED(bs);
|
|
5551
|
+
UNUSED(vx);
|
|
5552
|
+
UNUSED(vy);
|
|
5553
|
+
UNUSED(nr);
|
|
5554
|
+
UNUSED(nc);
|
|
5555
|
+
UNUSED(nb);
|
|
5556
|
+
UNUSED(ncols_interleaved);
|
|
5557
|
+
UNUSED(blocklen);
|
|
5558
|
+
|
|
5559
|
+
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
5560
|
+
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
|
5561
|
+
const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
|
|
5562
|
+
|
|
5563
|
+
for (int y = 0; y < nr / 4; y++) {
|
|
5564
|
+
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
|
5565
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
5566
|
+
const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
|
|
5567
|
+
|
|
5568
|
+
float32x4_t sumf[4];
|
|
5569
|
+
for (int m = 0; m < 4; m++) {
|
|
5570
|
+
sumf[m] = vdupq_n_f32(0);
|
|
5571
|
+
}
|
|
5572
|
+
|
|
5573
|
+
for (int l = 0; l < nb; l++) {
|
|
5574
|
+
float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
|
|
5575
|
+
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
|
|
5576
|
+
|
|
5577
|
+
int32x4_t sumi_0 = vdupq_n_s32(0);
|
|
5578
|
+
int32x4_t sumi_1 = vdupq_n_s32(0);
|
|
5579
|
+
int32x4_t sumi_2 = vdupq_n_s32(0);
|
|
5580
|
+
int32x4_t sumi_3 = vdupq_n_s32(0);
|
|
5581
|
+
|
|
5582
|
+
for (int k = 0; k < 4; k++) {
|
|
5583
|
+
int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
|
|
5584
|
+
int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
|
|
5585
|
+
|
|
5586
|
+
uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
|
|
5587
|
+
int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
|
|
5588
|
+
int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
|
|
5589
|
+
|
|
5590
|
+
sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
|
|
5591
|
+
sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
|
|
5592
|
+
sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
|
|
5593
|
+
sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
|
|
5594
|
+
sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
|
|
5595
|
+
sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
|
|
5596
|
+
sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
|
|
5597
|
+
sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
|
|
5598
|
+
}
|
|
5599
|
+
|
|
5600
|
+
sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
|
|
5601
|
+
sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
|
|
5602
|
+
sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
|
|
5603
|
+
sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
|
|
5604
|
+
}
|
|
5605
|
+
|
|
5606
|
+
for (int m = 0; m < 4; m++) {
|
|
5607
|
+
vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
|
|
5608
|
+
}
|
|
5609
|
+
}
|
|
5610
|
+
}
|
|
5611
|
+
return;
|
|
5612
|
+
}
|
|
5613
|
+
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
|
|
5614
|
+
{
|
|
3559
5615
|
float sumf[4][4];
|
|
3560
5616
|
int sumi;
|
|
3561
5617
|
|
|
@@ -3660,6 +5716,82 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in
|
|
|
3660
5716
|
return out;
|
|
3661
5717
|
}
|
|
3662
5718
|
|
|
5719
|
+
// Packs the 6-bit scales s[0..7] and 6-bit mins m[0..7] of one sub-block index
// (one value per source block_q4_K) into 12 bytes at dst. The bit layout
// mirrors the one decoded from block_q4_K::scales in make_block_q4_Kx8:
//   bytes 0-3 : low 6 bits of s[0..3]; bits 4-5 of s[4..7] in the top 2 bits
//   bytes 4-7 : low 6 bits of m[0..3]; bits 4-5 of m[4..7] in the top 2 bits
//   bytes 8-11: low 4 bits of s[4..7] | (low 4 bits of m[4..7] << 4)
static inline void q4_Kx8_pack_scale_group(uint8_t * dst, const uint8_t s[8], const uint8_t m[8]) {
    for (int j = 0; j < 4; j++) {
        dst[j]     = (s[j] & 63) + ((s[j + 4] & 48) << 2);
        dst[j + 4] = (m[j] & 63) + ((m[j + 4] & 48) << 2);
        dst[j + 8] = (s[j + 4] & 15) + ((m[j + 4] & 15) << 4);
    }
}

// Builds one interleaved block_q4_Kx8 from 8 block_q4_K structures in[0..7].
//  - d and dmin of input i are copied to out.d[i] / out.dmin[i].
//  - The quants are interleaved blck_size_interleave bytes at a time:
//    output chunk i comes from input (i % 8), chunk (i / 8) of its qs.
//  - The 6-bit scales/mins of every input are unpacked and re-packed so that
//    each 12-byte group of out.scales holds one sub-block index's scales and
//    mins for all 8 inputs (8 groups, 96 bytes).
static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) {
    block_q4_Kx8 out;

    for (int i = 0; i < 8; i++) {
        out.d[i]    = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
        out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
    }

    // out.qs is assumed to be QK_K * 4 bytes (8 inputs x QK_K / 2 bytes each).
    // The interleave must split one input's qs into whole chunks; otherwise
    // the copies below would run past the end of in[j].qs.
    GGML_ASSERT(blck_size_interleave > 0 && (QK_K / 2) % blck_size_interleave == 0);

    const int end = QK_K * 4 / blck_size_interleave;

    for (int i = 0; i < end; ++i) {
        const int src_id     = i % 8;
        const int src_offset = (i / 8) * blck_size_interleave;
        const int dst_offset = i * blck_size_interleave;

        // Copy exactly one interleave chunk. The size used to be hard-coded to
        // 8 bytes regardless of blck_size_interleave.
        memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave);
    }

    uint8_t s[8], m[8];

    // Sub-blocks 0..3: scale and min sit in the low 6 bits of scales[i] and
    // scales[i + 4] of each input.
    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 8; j++) {
            s[j] = in[j].scales[i] & 63;
            m[j] = in[j].scales[i + 4] & 63;
        }
        q4_Kx8_pack_scale_group(&out.scales[i * 12], s, m);
    }

    // Sub-blocks 4..7: the top 2 bits of scales[i] / scales[i + 4] give bits
    // 4-5, and scales[i + 8] gives the low 4 bits (low nibble scale, high
    // nibble min).
    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 8; j++) {
            s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15);
            m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4);
        }
        q4_Kx8_pack_scale_group(&out.scales[(i + 4) * 12], s, m);
    }

    return out;
}
|
|
5794
|
+
|
|
3663
5795
|
static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
|
3664
5796
|
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
|
3665
5797
|
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
|
|
@@ -3690,6 +5822,36 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
|
|
|
3690
5822
|
|
|
3691
5823
|
GGML_UNUSED(data_size);
|
|
3692
5824
|
}
|
|
5825
|
+
static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
|
5826
|
+
GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
|
|
5827
|
+
GGML_ASSERT(interleave_block == 8);
|
|
5828
|
+
constexpr int nrows_interleaved = 8;
|
|
5829
|
+
|
|
5830
|
+
block_q4_Kx8 * dst = (block_q4_Kx8*)t->data;
|
|
5831
|
+
const block_q4_K * src = (const block_q4_K*) data;
|
|
5832
|
+
block_q4_K dst_tmp[8];
|
|
5833
|
+
int nrow = ggml_nrows(t);
|
|
5834
|
+
int nblocks = t->ne[0] / QK_K;
|
|
5835
|
+
|
|
5836
|
+
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K));
|
|
5837
|
+
|
|
5838
|
+
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
|
5839
|
+
return -1;
|
|
5840
|
+
}
|
|
5841
|
+
|
|
5842
|
+
for (int b = 0; b < nrow; b += nrows_interleaved) {
|
|
5843
|
+
for (int64_t x = 0; x < nblocks; x++) {
|
|
5844
|
+
for (int i = 0; i < nrows_interleaved; i++ ) {
|
|
5845
|
+
dst_tmp[i] = src[x + i * nblocks];
|
|
5846
|
+
}
|
|
5847
|
+
*dst++ = make_block_q4_Kx8(dst_tmp, interleave_block);
|
|
5848
|
+
}
|
|
5849
|
+
src += nrows_interleaved * nblocks;
|
|
5850
|
+
}
|
|
5851
|
+
return 0;
|
|
5852
|
+
|
|
5853
|
+
GGML_UNUSED(data_size);
|
|
5854
|
+
}
|
|
3693
5855
|
|
|
3694
5856
|
static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
|
3695
5857
|
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
|
@@ -3807,6 +5969,10 @@ template <> int repack<block_q4_0, 8, 8>(struct ggml_tensor * t, const void * da
|
|
|
3807
5969
|
return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
|
|
3808
5970
|
}
|
|
3809
5971
|
|
|
5972
|
+
// Q4_K weights are repacked into block_q4_Kx8 using an 8-byte interleave.
template <>
int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    constexpr int interleave_bytes = 8;
    return repack_q4_K_to_q4_K_8_bl(t, interleave_bytes, data, data_size);
}
|
|
5975
|
+
|
|
3810
5976
|
template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
|
3811
5977
|
return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
|
|
3812
5978
|
}
|
|
@@ -3817,44 +5983,50 @@ template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void *
|
|
|
3817
5983
|
//}
|
|
3818
5984
|
|
|
3819
5985
|
// gemv
|
|
3820
|
-
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
|
|
5986
|
+
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
|
|
3821
5987
|
void gemv(int, float *, size_t, const void *, const void *, int, int);
|
|
3822
5988
|
|
|
3823
|
-
template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
5989
|
+
template <> void gemv<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
3824
5990
|
ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
3825
5991
|
}
|
|
3826
5992
|
|
|
3827
|
-
template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
5993
|
+
template <> void gemv<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
3828
5994
|
ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
3829
5995
|
}
|
|
3830
5996
|
|
|
3831
|
-
template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
5997
|
+
template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
3832
5998
|
ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
3833
5999
|
}
|
|
3834
6000
|
|
|
3835
|
-
template <>
|
|
3836
|
-
|
|
6001
|
+
template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
6002
|
+
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
|
6003
|
+
}
|
|
6004
|
+
|
|
6005
|
+
template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
3837
6006
|
ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
3838
6007
|
}
|
|
3839
6008
|
|
|
3840
6009
|
// gemm
|
|
3841
|
-
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
|
|
6010
|
+
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
|
|
3842
6011
|
void gemm(int, float *, size_t, const void *, const void *, int, int);
|
|
3843
6012
|
|
|
3844
|
-
template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
6013
|
+
template <> void gemm<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
3845
6014
|
ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
3846
6015
|
}
|
|
3847
6016
|
|
|
3848
|
-
template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
6017
|
+
template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
3849
6018
|
ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
3850
6019
|
}
|
|
3851
6020
|
|
|
3852
|
-
template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
6021
|
+
template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
3853
6022
|
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
3854
6023
|
}
|
|
3855
6024
|
|
|
3856
|
-
template <>
|
|
3857
|
-
|
|
6025
|
+
template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
6026
|
+
ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
|
6027
|
+
}
|
|
6028
|
+
|
|
6029
|
+
template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
3858
6030
|
ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
3859
6031
|
}
|
|
3860
6032
|
|
|
@@ -3863,37 +6035,37 @@ class tensor_traits_base : public ggml::cpu::tensor_traits {
|
|
|
3863
6035
|
virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
|
|
3864
6036
|
};
|
|
3865
6037
|
|
|
3866
|
-
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
|
|
6038
|
+
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE> class tensor_traits : public tensor_traits_base {
|
|
3867
6039
|
|
|
3868
6040
|
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
|
|
3869
6041
|
// not realy a GGML_TYPE_Q8_0 but same size.
|
|
3870
6042
|
switch (op->op) {
|
|
3871
|
-
|
|
3872
|
-
|
|
3873
|
-
|
|
3874
|
-
|
|
3875
|
-
|
|
3876
|
-
|
|
3877
|
-
|
|
3878
|
-
|
|
3879
|
-
|
|
3880
|
-
|
|
3881
|
-
|
|
6043
|
+
case GGML_OP_MUL_MAT:
|
|
6044
|
+
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
|
|
6045
|
+
return true;
|
|
6046
|
+
case GGML_OP_MUL_MAT_ID:
|
|
6047
|
+
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
|
|
6048
|
+
size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
|
|
6049
|
+
size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
|
|
6050
|
+
return true;
|
|
6051
|
+
default:
|
|
6052
|
+
// GGML_ABORT("fatal error");
|
|
6053
|
+
break;
|
|
3882
6054
|
}
|
|
3883
6055
|
return false;
|
|
3884
6056
|
}
|
|
3885
6057
|
|
|
3886
6058
|
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
|
|
3887
6059
|
switch (op->op) {
|
|
3888
|
-
|
|
3889
|
-
|
|
3890
|
-
|
|
3891
|
-
|
|
3892
|
-
|
|
3893
|
-
|
|
3894
|
-
|
|
3895
|
-
|
|
3896
|
-
|
|
6060
|
+
case GGML_OP_MUL_MAT:
|
|
6061
|
+
forward_mul_mat(params, op);
|
|
6062
|
+
return true;
|
|
6063
|
+
case GGML_OP_MUL_MAT_ID:
|
|
6064
|
+
forward_mul_mat_id(params, op);
|
|
6065
|
+
return true;
|
|
6066
|
+
default:
|
|
6067
|
+
// GGML_ABORT("fatal error");
|
|
6068
|
+
break;
|
|
3897
6069
|
}
|
|
3898
6070
|
return false;
|
|
3899
6071
|
}
|
|
@@ -3925,17 +6097,17 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|
|
3925
6097
|
// GGML_ASSERT(ggml_n_dims(op->src[1]) == 2);
|
|
3926
6098
|
|
|
3927
6099
|
char * wdata = static_cast<char *>(params->wdata);
|
|
3928
|
-
const size_t nbw1 = ggml_row_size(
|
|
6100
|
+
const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
|
|
3929
6101
|
|
|
3930
6102
|
assert(params->wsize >= nbw1 * ne11);
|
|
3931
6103
|
|
|
3932
|
-
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(
|
|
6104
|
+
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
|
|
3933
6105
|
|
|
3934
6106
|
int64_t i11_processed = 0;
|
|
3935
6107
|
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
|
|
3936
|
-
|
|
3937
|
-
INTER_SIZE);
|
|
6108
|
+
ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
|
|
3938
6109
|
}
|
|
6110
|
+
|
|
3939
6111
|
i11_processed = ne11 - ne11 % 4;
|
|
3940
6112
|
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
|
|
3941
6113
|
from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
|
|
@@ -3944,26 +6116,28 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|
|
3944
6116
|
ggml_barrier(params->threadpool);
|
|
3945
6117
|
|
|
3946
6118
|
const void * src1_wdata = params->wdata;
|
|
3947
|
-
const size_t src1_col_stride = ggml_row_size(
|
|
6119
|
+
const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
|
|
3948
6120
|
int64_t src0_start = (ith * ne01) / nth;
|
|
3949
6121
|
int64_t src0_end = ((ith + 1) * ne01) / nth;
|
|
3950
6122
|
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
|
|
3951
|
-
src0_end = (src0_end
|
|
6123
|
+
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
|
|
3952
6124
|
if (src0_start >= src0_end) {
|
|
3953
6125
|
return;
|
|
3954
6126
|
}
|
|
3955
6127
|
|
|
3956
6128
|
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
|
|
3957
6129
|
if (ne11 > 3) {
|
|
3958
|
-
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00,
|
|
3959
|
-
|
|
3960
|
-
|
|
6130
|
+
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
|
6131
|
+
(float *) ((char *) dst->data) + src0_start, ne01,
|
|
6132
|
+
(const char *) src0->data + src0_start * nb01,
|
|
6133
|
+
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
|
|
3961
6134
|
}
|
|
3962
6135
|
for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
|
|
3963
|
-
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00,
|
|
3964
|
-
|
|
3965
|
-
|
|
3966
|
-
|
|
6136
|
+
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
|
6137
|
+
(float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
|
|
6138
|
+
(const char *) src0->data + src0_start * nb01,
|
|
6139
|
+
(const char *) src1_wdata + (src1_col_stride * iter), 1,
|
|
6140
|
+
src0_end - src0_start);
|
|
3967
6141
|
}
|
|
3968
6142
|
}
|
|
3969
6143
|
|
|
@@ -3978,7 +6152,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|
|
3978
6152
|
const int ith = params->ith;
|
|
3979
6153
|
const int nth = params->nth;
|
|
3980
6154
|
|
|
3981
|
-
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(
|
|
6155
|
+
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
|
|
3982
6156
|
|
|
3983
6157
|
// we don't support permuted src0 or src1
|
|
3984
6158
|
GGML_ASSERT(nb00 == ggml_type_size(src0->type));
|
|
@@ -4000,7 +6174,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|
|
4000
6174
|
const int n_ids = ids->ne[0]; // n_expert_used
|
|
4001
6175
|
const int n_as = ne02; // n_expert
|
|
4002
6176
|
|
|
4003
|
-
const size_t nbw1 = ggml_row_size(
|
|
6177
|
+
const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
|
|
4004
6178
|
const size_t nbw2 = nbw1*ne11;
|
|
4005
6179
|
const size_t nbw3 = nbw2*ne12;
|
|
4006
6180
|
|
|
@@ -4012,12 +6186,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|
|
4012
6186
|
GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
|
|
4013
6187
|
n_as * ne12 * sizeof(mmid_row_mapping)));
|
|
4014
6188
|
|
|
4015
|
-
auto
|
|
4016
|
-
auto
|
|
4017
|
-
|
|
6189
|
+
auto * wdata = (char *) params->wdata;
|
|
6190
|
+
auto * wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
|
|
6191
|
+
auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
|
|
6192
|
+
|
|
4018
6193
|
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
|
|
4019
6194
|
|
|
4020
|
-
// src1: float32 =>
|
|
6195
|
+
// src1: float32 => param type
|
|
4021
6196
|
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
|
4022
6197
|
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
|
|
4023
6198
|
from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
|
|
@@ -4056,34 +6231,37 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|
|
4056
6231
|
continue;
|
|
4057
6232
|
}
|
|
4058
6233
|
|
|
4059
|
-
auto src0_cur = (const char *) src0->data + cur_a*nb02;
|
|
6234
|
+
const auto * src0_cur = (const char *) src0->data + cur_a*nb02;
|
|
4060
6235
|
|
|
4061
6236
|
//const int64_t nr0 = ne01; // src0 rows
|
|
4062
6237
|
const int64_t nr1 = cne1; // src1 rows
|
|
4063
6238
|
|
|
4064
6239
|
int64_t src0_cur_start = (ith * ne01) / nth;
|
|
4065
6240
|
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
|
|
4066
|
-
src0_cur_start =
|
|
4067
|
-
(src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
|
|
4068
|
-
src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
|
|
4069
6241
|
|
|
4070
|
-
|
|
6242
|
+
src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
|
|
6243
|
+
src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
|
|
6244
|
+
|
|
6245
|
+
if (src0_cur_start >= src0_cur_end) {
|
|
6246
|
+
return;
|
|
6247
|
+
}
|
|
4071
6248
|
|
|
4072
6249
|
for (int ir1 = 0; ir1 < nr1; ir1++) {
|
|
4073
6250
|
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
|
|
4074
|
-
const int id = row_mapping.i1; // selected expert index
|
|
4075
6251
|
|
|
4076
|
-
const
|
|
4077
|
-
|
|
6252
|
+
const int id = row_mapping.i1; // selected expert index
|
|
6253
|
+
|
|
6254
|
+
const int64_t i11 = id % ne11;
|
|
6255
|
+
const int64_t i12 = row_mapping.i2; // row index in src1
|
|
4078
6256
|
|
|
4079
|
-
const int64_t
|
|
4080
|
-
const int64_t
|
|
6257
|
+
const int64_t i1 = id; // selected expert index
|
|
6258
|
+
const int64_t i2 = i12; // row
|
|
4081
6259
|
|
|
4082
|
-
auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
|
|
6260
|
+
const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
|
|
4083
6261
|
|
|
4084
|
-
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
|
|
4085
|
-
|
|
4086
|
-
|
|
6262
|
+
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
|
6263
|
+
(float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
|
|
6264
|
+
src0_cur + src0_cur_start * nb01,
|
|
4087
6265
|
src1_col, 1, src0_cur_end - src0_cur_start);
|
|
4088
6266
|
}
|
|
4089
6267
|
}
|
|
@@ -4098,12 +6276,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|
|
4098
6276
|
};
|
|
4099
6277
|
|
|
4100
6278
|
// instance for Q4
|
|
4101
|
-
static const tensor_traits<block_q4_0, 4, 4> q4_0_4x4_q8_0;
|
|
4102
|
-
static const tensor_traits<block_q4_0, 8, 4> q4_0_4x8_q8_0;
|
|
4103
|
-
static const tensor_traits<block_q4_0, 8, 8> q4_0_8x8_q8_0;
|
|
6279
|
+
// Repacking trait singletons. Each name follows the pattern
// <weight type>_<NB_COLS>x<INTER_SIZE>_<activation type>.

// Q4_0 weights with Q8_0 activations
static const tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
static const tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
static const tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;

// Q4_K weights with Q8_K activations
static const tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;

// IQ4_NL weights with Q8_0 activations
static const tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
|
|
4107
6286
|
|
|
4108
6287
|
} // namespace ggml::cpu::aarch64
|
|
4109
6288
|
|
|
@@ -4124,6 +6303,12 @@ static const ggml::cpu::tensor_traits * ggml_aarch64_get_optimal_repack_type(con
|
|
|
4124
6303
|
return &ggml::cpu::aarch64::q4_0_4x4_q8_0;
|
|
4125
6304
|
}
|
|
4126
6305
|
}
|
|
6306
|
+
} else if (cur->type == GGML_TYPE_Q4_K) {
|
|
6307
|
+
if (ggml_cpu_has_avx2()) {
|
|
6308
|
+
if (cur->ne[1] % 8 == 0) {
|
|
6309
|
+
return &ggml::cpu::aarch64::q4_K_8x8_q8_K;
|
|
6310
|
+
}
|
|
6311
|
+
}
|
|
4127
6312
|
} else if (cur->type == GGML_TYPE_IQ4_NL) {
|
|
4128
6313
|
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
|
4129
6314
|
if (cur->ne[1] % 4 == 0) {
|