@fugood/llama.node 0.3.7 → 0.3.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +17 -2
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +8 -0
- package/lib/index.js +16 -1
- package/lib/index.ts +16 -0
- package/package.json +1 -1
- package/src/EmbeddingWorker.cpp +4 -3
- package/src/LlamaCompletionWorker.cpp +4 -2
- package/src/LlamaContext.cpp +156 -6
- package/src/LlamaContext.h +5 -0
- package/src/common.hpp +6 -11
- package/src/llama.cpp/.github/workflows/build.yml +19 -17
- package/src/llama.cpp/.github/workflows/docker.yml +77 -30
- package/src/llama.cpp/.github/workflows/editorconfig.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +22 -3
- package/src/llama.cpp/CMakeLists.txt +49 -24
- package/src/llama.cpp/common/arg.cpp +82 -26
- package/src/llama.cpp/common/arg.h +3 -0
- package/src/llama.cpp/common/common.cpp +192 -72
- package/src/llama.cpp/common/common.h +51 -18
- package/src/llama.cpp/common/ngram-cache.cpp +12 -12
- package/src/llama.cpp/common/ngram-cache.h +2 -2
- package/src/llama.cpp/common/sampling.cpp +11 -6
- package/src/llama.cpp/common/speculative.cpp +18 -15
- package/src/llama.cpp/docs/build.md +2 -0
- package/src/llama.cpp/examples/batched/batched.cpp +9 -7
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +3 -3
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +10 -8
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +11 -8
- package/src/llama.cpp/examples/cvector-generator/mean.hpp +1 -1
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +8 -7
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +7 -6
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +8 -7
- package/src/llama.cpp/examples/gguf/gguf.cpp +10 -6
- package/src/llama.cpp/examples/gguf-hash/gguf-hash.cpp +1 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +8 -7
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +13 -10
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +13 -12
- package/src/llama.cpp/examples/infill/infill.cpp +23 -24
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +44 -13
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -6
- package/src/llama.cpp/examples/llava/clip.cpp +4 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +9 -6
- package/src/llama.cpp/examples/llava/llava.cpp +2 -2
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +8 -4
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +11 -8
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -7
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +4 -9
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +3 -7
- package/src/llama.cpp/examples/lookup/lookup.cpp +5 -6
- package/src/llama.cpp/examples/main/main.cpp +51 -29
- package/src/llama.cpp/examples/parallel/parallel.cpp +5 -6
- package/src/llama.cpp/examples/passkey/passkey.cpp +7 -5
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +37 -23
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -14
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +8 -8
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +12 -0
- package/src/llama.cpp/examples/run/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +1351 -0
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +114 -0
- package/src/llama.cpp/examples/run/run.cpp +175 -61
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -25
- package/src/llama.cpp/examples/server/CMakeLists.txt +1 -0
- package/src/llama.cpp/examples/server/httplib.h +1295 -409
- package/src/llama.cpp/examples/server/server.cpp +387 -181
- package/src/llama.cpp/examples/server/tests/requirements.txt +1 -0
- package/src/llama.cpp/examples/server/utils.hpp +170 -58
- package/src/llama.cpp/examples/simple/simple.cpp +9 -8
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +16 -12
- package/src/llama.cpp/examples/speculative/speculative.cpp +22 -23
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +8 -12
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +17 -5
- package/src/llama.cpp/examples/tts/tts.cpp +64 -23
- package/src/llama.cpp/ggml/CMakeLists.txt +5 -21
- package/src/llama.cpp/ggml/include/ggml-backend.h +2 -0
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -0
- package/src/llama.cpp/ggml/include/ggml.h +36 -145
- package/src/llama.cpp/ggml/include/gguf.h +202 -0
- package/src/llama.cpp/ggml/src/CMakeLists.txt +6 -3
- package/src/llama.cpp/ggml/src/ggml-alloc.c +5 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +0 -1
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +79 -49
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +33 -23
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +57 -72
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +87 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +335 -66
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +10 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1090 -378
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +2 -2
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +3 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-impl.h +11 -16
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +16 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +154 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +9 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +18 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +40 -95
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +48 -48
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +24 -24
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -164
- package/src/llama.cpp/ggml/src/ggml-sycl/gla.cpp +105 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gla.hpp +8 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +3 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +7 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +74 -4
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +314 -116
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +9 -3
- package/src/llama.cpp/ggml/src/ggml.c +117 -1327
- package/src/llama.cpp/ggml/src/gguf.cpp +1329 -0
- package/src/llama.cpp/include/llama-cpp.h +6 -1
- package/src/llama.cpp/include/llama.h +138 -75
- package/src/llama.cpp/src/CMakeLists.txt +13 -1
- package/src/llama.cpp/src/llama-adapter.cpp +347 -0
- package/src/llama.cpp/src/llama-adapter.h +74 -0
- package/src/llama.cpp/src/llama-arch.cpp +1487 -0
- package/src/llama.cpp/src/llama-arch.h +400 -0
- package/src/llama.cpp/src/llama-batch.cpp +368 -0
- package/src/llama.cpp/src/llama-batch.h +88 -0
- package/src/llama.cpp/src/llama-chat.cpp +578 -0
- package/src/llama.cpp/src/llama-chat.h +52 -0
- package/src/llama.cpp/src/llama-context.cpp +1775 -0
- package/src/llama.cpp/src/llama-context.h +128 -0
- package/src/llama.cpp/src/llama-cparams.cpp +1 -0
- package/src/llama.cpp/src/llama-cparams.h +37 -0
- package/src/llama.cpp/src/llama-grammar.cpp +5 -4
- package/src/llama.cpp/src/llama-grammar.h +3 -1
- package/src/llama.cpp/src/llama-hparams.cpp +71 -0
- package/src/llama.cpp/src/llama-hparams.h +139 -0
- package/src/llama.cpp/src/llama-impl.cpp +167 -0
- package/src/llama.cpp/src/llama-impl.h +16 -136
- package/src/llama.cpp/src/llama-kv-cache.cpp +718 -0
- package/src/llama.cpp/src/llama-kv-cache.h +218 -0
- package/src/llama.cpp/src/llama-mmap.cpp +589 -0
- package/src/llama.cpp/src/llama-mmap.h +67 -0
- package/src/llama.cpp/src/llama-model-loader.cpp +1124 -0
- package/src/llama.cpp/src/llama-model-loader.h +167 -0
- package/src/llama.cpp/src/llama-model.cpp +3953 -0
- package/src/llama.cpp/src/llama-model.h +370 -0
- package/src/llama.cpp/src/llama-quant.cpp +934 -0
- package/src/llama.cpp/src/llama-quant.h +1 -0
- package/src/llama.cpp/src/llama-sampling.cpp +147 -32
- package/src/llama.cpp/src/llama-sampling.h +3 -19
- package/src/llama.cpp/src/llama-vocab.cpp +1832 -575
- package/src/llama.cpp/src/llama-vocab.h +97 -142
- package/src/llama.cpp/src/llama.cpp +7160 -20314
- package/src/llama.cpp/src/unicode.cpp +8 -3
- package/src/llama.cpp/tests/CMakeLists.txt +2 -0
- package/src/llama.cpp/tests/test-autorelease.cpp +3 -3
- package/src/llama.cpp/tests/test-backend-ops.cpp +370 -59
- package/src/llama.cpp/tests/test-chat-template.cpp +162 -125
- package/src/llama.cpp/tests/test-gguf.cpp +222 -187
- package/src/llama.cpp/tests/test-model-load-cancel.cpp +1 -1
- package/src/llama.cpp/tests/test-sampling.cpp +0 -1
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +4 -4
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +9 -7
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +8 -6
package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp

```diff
@@ -53,6 +53,9 @@
 #include "ggml-cpu-impl.h"
 #include "ggml-quants.h"
 
+#include <atomic>
+#include <array>
+
 #ifdef _MSC_VER
 #define NOINLINE __declspec(noinline)
 #else
```
```diff
@@ -134,6 +137,16 @@ inline __m512 madd(__m512 a, __m512 b, __m512 c) {
     return _mm512_fmadd_ps(a, b, c);
 }
 #endif
+#if defined(__AVX512BF16__)
+template <>
+inline __m512 madd(__m512bh a, __m512bh b, __m512 c) {
+    return _mm512_dpbf16_ps(c, a, b);
+}
+template <>
+inline __m256 madd(__m256bh a, __m256bh b, __m256 c) {
+    return _mm256_dpbf16_ps(c, a, b);
+}
+#endif
 #endif
 
 #if defined(__ARM_FEATURE_FMA)
```
```diff
@@ -226,6 +239,13 @@ template <> inline __m256 load(const float *p) {
 }
 #endif // __AVX__
 
+#if defined(__AVX2__) || defined(__AVX512F__)
+template <> inline __m256 load(const ggml_bf16_t *p) {
+    return _mm256_castsi256_ps(
+        _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)p)), 16));
+}
+#endif // __AVX2__
+
 #if defined(__F16C__)
 template <> inline __m256 load(const ggml_fp16_t *p) {
     return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
```
```diff
@@ -239,8 +259,27 @@ template <> inline __m512 load(const float *p) {
 template <> inline __m512 load(const ggml_fp16_t *p) {
     return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
 }
+template <> inline __m512 load(const ggml_bf16_t *p) {
+    return _mm512_castsi512_ps(
+        _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)p)), 16));
+}
 #endif // __AVX512F__
 
+#if defined(__AVX512BF16__)
+template <> inline __m512bh load(const ggml_bf16_t *p) {
+    return (__m512bh)_mm512_loadu_ps((const float *)p);
+}
+template <> inline __m256bh load(const ggml_bf16_t *p) {
+    return (__m256bh)_mm256_loadu_ps((const float *)p);
+}
+template <> inline __m512bh load(const float *p) {
+    return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p));
+}
+template <> inline __m256bh load(const float *p) {
+    return _mm512_cvtneps_pbh(_mm512_loadu_ps(p));
+}
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // CONSTANTS
 
```
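The bf16 `load` specializations in the hunk above widen bfloat16 values to `float` by moving the 16 stored bits into the upper half of a 32-bit word. Below is a minimal scalar sketch of the same bit trick, in plain standard C++ with no intrinsics; `bf16_to_float` and `float_to_bf16` are illustrative helper names, not identifiers from the package.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// bfloat16 keeps the sign, exponent and top 7 mantissa bits of an IEEE-754
// float, so shifting the 16 stored bits into the high half of a uint32_t
// reproduces the original value (with the low mantissa bits zeroed).
static float bf16_to_float(uint16_t bits) {
    uint32_t widened = (uint32_t)bits << 16;
    float out;
    std::memcpy(&out, &widened, sizeof(out));   // bit-cast, no numeric conversion
    return out;
}

static uint16_t float_to_bf16(float value) {    // truncating (round-toward-zero) conversion
    uint32_t bits;
    std::memcpy(&bits, &value, sizeof(bits));
    return (uint16_t)(bits >> 16);
}

int main() {
    float x = 3.14159f;
    uint16_t h = float_to_bf16(x);
    std::printf("%f -> 0x%04x -> %f\n", x, h, bf16_to_float(h));
    return 0;
}
```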
```diff
@@ -252,199 +291,170 @@ static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // FLOATING POINT MATRIX MULTIPLICATION
 
+template <int M>
+static inline int64_t BLOCK_SIZE(size_t m) {
+    const int64_t NB_BLOC_M = (m + M - 1) / M;
+    return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
+}
+
+static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
+    return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
+}
+
 template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
 class tinyBLAS {
   public:
-    tinyBLAS(int64_t k,
+    tinyBLAS(const ggml_compute_params * params, int64_t k,
         const TA *A, int64_t lda,
         const TB *B, int64_t ldb,
-        TC *C, int64_t ldc
-
-        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+        TC *C, int64_t ldc)
+        : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
     }
 
-
-
+    bool matmul(int64_t m, int64_t n) {
+        if (k % KN != 0)
+            return false;
+        // compute RM for only need tile with size RM&RM-1
+#if VECTOR_REGISTERS == 32
+        if (m % 16 == 0 && (m/16 >= params->nth)) {
+            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+            mnpack<4, 6, 4>(m, n, SIZE_N, 12);
+            return true;
+        }
+        if (m % 8 == 0 ) {
+            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+            mnpack<4, 6, 2>(m, n, SIZE_N, 12);
+            return true;
+        }
+        if (m % 4 == 0) {
+            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+            mnpack<4, 6, 1>(m, n, SIZE_N, 12);
+            return true;
+        }
+#else // VECTOR_REGISTERS == 16
+        if (m % 16 == 0 && (m/16 >= params->nth)) {
+            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 4>(m, n, SIZE_N, 24);
+            return true;
+        }
+        if (m % 8 == 0 ) {
+            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 2>(m, n, SIZE_N, 24);
+            return true;
+        }
+        if (m % 4 == 0) {
+            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 1>(m, n, SIZE_N, 24);
+            return true;
+        }
+#endif
+        return false;
     }
 
   private:
-
-
-
-
-
-
-
-
-
-
-            mc = 4;
-            nc = 5;
-            gemm<4, 5>(m0, m, n0, n);
-            break;
-        case 0x54:
-            mc = 5;
-            nc = 4;
-            gemm<5, 4>(m0, m, n0, n);
-            break;
-        case 0x44:
-            mc = 4;
-            nc = 4;
-            gemm<4, 4>(m0, m, n0, n);
-            break;
-        case 0x53:
-            mc = 5;
-            nc = 3;
-            gemm<5, 3>(m0, m, n0, n);
-            break;
-        case 0x35:
-            mc = 3;
-            nc = 5;
-            gemm<3, 5>(m0, m, n0, n);
-            break;
-        case 0x43:
-            mc = 4;
-            nc = 3;
-            gemm<4, 3>(m0, m, n0, n);
-            break;
-#else
-        case 0x55:
-        case 0x54:
-        case 0x53:
-        case 0x45:
-        case 0x44:
-        case 0x43:
-            mc = 4;
-            nc = 3;
-            gemm<4, 3>(m0, m, n0, n);
-            break;
-        case 0x35:
-#endif
-        case 0x34:
-            mc = 3;
-            nc = 4;
-            gemm<3, 4>(m0, m, n0, n);
-            break;
-        case 0x52:
-            mc = 5;
-            nc = 2;
-            gemm<5, 2>(m0, m, n0, n);
-            break;
-        case 0x33:
-            mc = 3;
-            nc = 3;
-            gemm<3, 3>(m0, m, n0, n);
-            break;
-        case 0x25:
-            mc = 2;
-            nc = 5;
-            gemm<2, 5>(m0, m, n0, n);
-            break;
-        case 0x42:
-            mc = 4;
-            nc = 2;
-            gemm<4, 2>(m0, m, n0, n);
-            break;
-        case 0x24:
-            mc = 2;
-            nc = 4;
-            gemm<2, 4>(m0, m, n0, n);
-            break;
-        case 0x32:
-            mc = 3;
-            nc = 2;
-            gemm<3, 2>(m0, m, n0, n);
-            break;
-        case 0x23:
-            mc = 2;
-            nc = 3;
-            gemm<2, 3>(m0, m, n0, n);
-            break;
-        case 0x51:
-            mc = 5;
-            nc = 1;
-            gemm<5, 1>(m0, m, n0, n);
-            break;
-        case 0x41:
-            mc = 4;
-            nc = 1;
-            gemm<4, 1>(m0, m, n0, n);
-            break;
-        case 0x22:
-            mc = 2;
-            nc = 2;
-            gemm<2, 2>(m0, m, n0, n);
-            break;
-        case 0x15:
-            mc = 1;
-            nc = 5;
-            gemm<1, 5>(m0, m, n0, n);
-            break;
-        case 0x14:
-            mc = 1;
-            nc = 4;
-            gemm<1, 4>(m0, m, n0, n);
-            break;
-        case 0x31:
-            mc = 3;
-            nc = 1;
-            gemm<3, 1>(m0, m, n0, n);
-            break;
-        case 0x13:
-            mc = 1;
-            nc = 3;
-            gemm<1, 3>(m0, m, n0, n);
-            break;
-        case 0x21:
-            mc = 2;
-            nc = 1;
-            gemm<2, 1>(m0, m, n0, n);
-            break;
-        case 0x12:
-            mc = 1;
-            nc = 2;
-            gemm<1, 2>(m0, m, n0, n);
-            break;
-        case 0x11:
-            mc = 1;
-            nc = 1;
-            gemm<1, 1>(m0, m, n0, n);
-            break;
-        default:
-            return;
+    template <int RM, int RN, int BM>
+    inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
+        if (SIZE_N == RN) {
+            return gemm<RM, RN, BM>(m, n, BN);
+        }
+        if constexpr (RN > 1) {
+            return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
+        } else {
+            GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
+            GGML_ASSERT(false); // we have miss something.
         }
-        mp = m0 + (m - m0) / mc * mc;
-        np = n0 + (n - n0) / nc * nc;
-        mnpack(mp, m, n0, np);
-        mnpack(m0, m, np, n);
     }
 
     template <int RM, int RN>
-
-
-        int64_t
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            for (int64_t i = 0; i < RM; ++i)
-
+    inline void gemm_bloc(int64_t ii, int64_t jj) {
+        D Cv[RN][RM] = {};
+        for (int64_t l = 0; l < k; l += KN) {
+            // help compiler for op order.
+            if constexpr (RM <= RN) {
+                V Av[RM];
+                for (int64_t i = 0; i < RM; ++i) {
+                    Av[i] = load<V>(A + lda * (ii + i) + l);
+                }
+                for (int64_t j = 0; j < RN; ++j) {
+                    V Bv = load<V>(B + ldb * (jj + j) + l);
+                    for (int64_t i = 0; i < RM; ++i) {
+                        Cv[j][i] = madd(Av[i], Bv, Cv[j][i]);
+                    }
+                }
+            } else {
+                V Bv[RN];
+                for (int64_t j = 0; j < RN; ++j) {
+                    Bv[j] = load<V>(B + ldb * (jj + j) + l);
+                }
+                for (int64_t i = 0; i < RM; ++i) {
+                    V Av = load<V>(A + lda * (ii + i) + l);
+                    for (int64_t j = 0; j < RN; ++j) {
+                        Cv[j][i] = madd(Av, Bv[j], Cv[j][i]);
+                    }
+                }
+            }
+        }
+        for (int64_t j = 0; j < RN; ++j)
+            for (int64_t i = 0; i < RM; ++i)
+                C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+    }
+
+    template <int RM, int RN, int BM>
+    NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
+        static std::atomic<int64_t> current_chunk;
+
+        GGML_ASSERT(m % (RM * BM) == 0);
+        const int64_t ytiles = m / (RM * BM);
+        const int64_t xtiles = (n + RN -1) / RN;
+        const int64_t jj_RN = (xtiles - (xtiles * RN - n));
+
+        // "round" bloc_size to "nearest" BN
+        const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
+        const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
+        const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
+        const int64_t nb_job = ytiles * NB_BN;
+
+        if (params->ith == 0) {
+            GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
+            // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
+            std::atomic_store_explicit(&current_chunk, (int64_t)params->nth, std::memory_order_relaxed);
+        }
+
+        ggml_barrier(params->threadpool);
+
+        int64_t job = params->ith;
+        while (job < nb_job) {
+            const int64_t ii = (job % ytiles) * RM * BM;
+            const int64_t jb = job / ytiles;
+            const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN);
+            const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
+
+            const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
+            const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
+            const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
+
+            for (int64_t bi = 0; bi < BM * RM; bi += RM) {
+                int64_t jj = jj0;
+                for (; jj < jj1; jj += RN) {
+                    gemm_bloc<RM, RN>(ii + bi, jj);
+                }
+                if constexpr (RN > 1) {
+                    for (; jj < jj2; jj += RN - 1) {
+                        gemm_bloc<RM, RN-1>(ii + bi, jj);
+                    }
+                }
+                GGML_ASSERT(jj == jj2);
+            }
+
+            // next step.
+            job = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed);
         }
+
+        ggml_barrier(params->threadpool);
+        return;
     }
 
+    const ggml_compute_params * params;
     const TA *const A;
     const TB *const B;
     TC *const C;
```
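The reworked `tinyBLAS::gemm` above hands out tile blocks to threads through a shared atomic counter instead of a fixed per-thread slice. Below is a stand-alone sketch of that scheduling pattern, assuming plain `std::thread` workers in place of ggml's threadpool and barrier; `run_jobs` and its output are illustrative, not the package's API.

```cpp
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

// Each worker starts on the job equal to its thread index; every further job
// is claimed by atomically bumping a shared counter, so faster threads simply
// pick up more chunks and no other coordination is needed.
static void run_jobs(int nth, int nb_job) {
    std::atomic<int> current_chunk{nth};        // first job nobody owns yet
    std::vector<std::thread> pool;
    for (int ith = 0; ith < nth; ++ith) {
        pool.emplace_back([&, ith] {
            int job = ith;
            while (job < nb_job) {
                std::printf("thread %d -> job %d\n", ith, job);   // process one tile block here
                job = current_chunk.fetch_add(1, std::memory_order_relaxed);
            }
        });
    }
    for (auto &t : pool) t.join();
}

int main() {
    run_jobs(/*nth=*/4, /*nb_job=*/10);
    return 0;
}
```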
```diff
@@ -452,8 +462,6 @@ class tinyBLAS {
     const int64_t lda;
     const int64_t ldb;
     const int64_t ldc;
-    const int ith;
-    const int nth;
 };
 
 //////////////////////////////////////////////////////////////////////////////////////////
```
```diff
@@ -993,8 +1001,10 @@ class tinyBLAS_Q0_AVX {
 
     inline __m256 updot(__m256i u, __m256i s) {
         __m256i res;
-#if defined(
+#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
         res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
+#elif defined(__AVXVNNI__)
+        res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
 #else
         res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
 #endif
```
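The `updot` change above prefers the VNNI `dpbusd` instruction and keeps the `maddubs`/`madd` pair as the fallback. Below is a small sketch of what that fallback path computes, checked against a scalar dot product; it assumes an AVX2-enabled build (e.g. `-mavx2`), and `dot_u8_s8_avx2` is an illustrative name rather than anything from the package.

```cpp
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Dot product of 32 unsigned bytes with 32 signed bytes using the same
// fallback sequence as updot(): maddubs forms 16 saturating u8*s8 pair sums
// as int16, and madd against a vector of 1s widens/pairs them into 8 int32.
static int32_t dot_u8_s8_avx2(const uint8_t *u, const int8_t *s) {
    __m256i uv   = _mm256_loadu_si256((const __m256i *)u);
    __m256i sv   = _mm256_loadu_si256((const __m256i *)s);
    __m256i prod = _mm256_madd_epi16(_mm256_set1_epi16(1),
                                     _mm256_maddubs_epi16(uv, sv));
    // horizontal sum of the 8 int32 lanes
    __m128i lo  = _mm256_castsi256_si128(prod);
    __m128i hi  = _mm256_extracti128_si256(prod, 1);
    __m128i sum = _mm_add_epi32(lo, hi);
    sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2)));
    sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(2, 3, 0, 1)));
    return _mm_cvtsi128_si32(sum);
}

int main() {
    uint8_t u[32];
    int8_t  s[32];
    int32_t ref = 0;
    for (int i = 0; i < 32; ++i) {
        u[i] = (uint8_t)(i * 3);          // small values, so no int16 saturation
        s[i] = (int8_t)(i - 16);
        ref += (int32_t)u[i] * s[i];
    }
    std::printf("simd=%d scalar=%d\n", dot_u8_s8_avx2(u, s), ref);
    return 0;
}
```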
```diff
@@ -1043,9 +1053,9 @@ class tinyBLAS_Q0_AVX {
     } \
 
 template <typename TA, typename TB, typename TC>
-class
+class tinyBLAS_Q0_PPC {
   public:
-
+    tinyBLAS_Q0_PPC(int64_t k,
         const TA *A, int64_t lda,
         const TB *B, int64_t ldb,
         TC *C, int64_t ldc,
```
@@ -1054,74 +1064,773 @@ class tinyBLAS_PPC {
|
|
|
1054
1064
|
}
|
|
1055
1065
|
|
|
1056
1066
|
void matmul(int64_t m, int64_t n) {
|
|
1057
|
-
|
|
1067
|
+
mnpack(0, m, 0, n);
|
|
1058
1068
|
}
|
|
1059
1069
|
|
|
1060
1070
|
private:
|
|
1061
1071
|
|
|
1062
|
-
|
|
1072
|
+
template<int RM, int RN>
|
|
1073
|
+
inline void save_res(int ii, int jj, int idx, vector float* fin_res) {
|
|
1074
|
+
for (int I = 0; I < RM; I++) {
|
|
1075
|
+
for (int J = 0; J < RN; J++) {
|
|
1076
|
+
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
|
|
1077
|
+
}
|
|
1078
|
+
}
|
|
1079
|
+
}
|
|
1063
1080
|
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1081
|
+
template<int size>
|
|
1082
|
+
inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array<int, size>& comparray, vector float* vs, vector float* fin_res) {
|
|
1083
|
+
vector signed int vec_C[4];
|
|
1084
|
+
vector float CA[4] = {0};
|
|
1085
|
+
vector float res[4] = {0};
|
|
1086
|
+
__builtin_mma_disassemble_acc(vec_C, ACC);
|
|
1087
|
+
for (int i = 0; i < 4; i++) {
|
|
1088
|
+
CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
|
|
1089
|
+
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
|
|
1090
|
+
fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
|
|
1091
|
+
}
|
|
1092
|
+
}
|
|
1069
1093
|
|
|
1070
|
-
|
|
1071
|
-
|
|
1094
|
+
template<typename VA, typename VB>
|
|
1095
|
+
void packNormal(const TA* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
|
|
1096
|
+
int64_t i, j;
|
|
1097
|
+
TA *aoffset = NULL;
|
|
1098
|
+
VA *vecOffset = NULL;
|
|
1099
|
+
TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
|
|
1100
|
+
TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
|
|
1101
|
+
__vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
|
|
1102
|
+
VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
|
|
1103
|
+
VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
|
|
1104
|
+
VB t1, t2, t3, t4, t5, t6, t7, t8;
|
|
1105
|
+
vector unsigned char xor_vector;
|
|
1106
|
+
uint8_t flip_vec = 0x80;
|
|
1107
|
+
xor_vector = vec_splats(flip_vec);
|
|
1108
|
+
vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
|
1109
|
+
vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
|
1110
|
+
vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
|
|
1111
|
+
vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
|
|
1112
|
+
|
|
1113
|
+
aoffset = const_cast<TA*>(a);
|
|
1114
|
+
vecOffset = vec;
|
|
1072
1115
|
j = (rows >> 3);
|
|
1073
1116
|
if (j > 0) {
|
|
1074
1117
|
do {
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
i = (cols >> 3);
|
|
1085
|
-
if (i > 0) {
|
|
1086
|
-
__vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
|
|
1087
|
-
vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
|
|
1088
|
-
vector float t1, t2, t3, t4, t5, t6, t7, t8;
|
|
1089
|
-
do {
|
|
1090
|
-
C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
|
|
1091
|
-
C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
|
|
1092
|
-
C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
|
|
1093
|
-
C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
|
|
1094
|
-
C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
|
|
1095
|
-
C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
|
|
1096
|
-
C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
|
|
1097
|
-
C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
|
|
1098
|
-
__builtin_vsx_disassemble_pair(c1, &C1);
|
|
1099
|
-
__builtin_vsx_disassemble_pair(c2, &C2);
|
|
1100
|
-
__builtin_vsx_disassemble_pair(c3, &C3);
|
|
1101
|
-
__builtin_vsx_disassemble_pair(c4, &C4);
|
|
1102
|
-
__builtin_vsx_disassemble_pair(c5, &C5);
|
|
1103
|
-
__builtin_vsx_disassemble_pair(c6, &C6);
|
|
1104
|
-
__builtin_vsx_disassemble_pair(c7, &C7);
|
|
1105
|
-
__builtin_vsx_disassemble_pair(c8, &C8);
|
|
1118
|
+
aoffset1 = aoffset;
|
|
1119
|
+
aoffset2 = aoffset1 + lda;
|
|
1120
|
+
aoffset3 = aoffset2 + lda;
|
|
1121
|
+
aoffset4 = aoffset3 + lda;
|
|
1122
|
+
aoffset5 = aoffset4 + lda;
|
|
1123
|
+
aoffset6 = aoffset5 + lda;
|
|
1124
|
+
aoffset7 = aoffset6 + lda;
|
|
1125
|
+
aoffset8 = aoffset7 + lda;
|
|
1126
|
+
aoffset += 8 * lda;
|
|
1106
1127
|
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
vec_xst(t8, 0, boffset+12);
|
|
1128
|
+
i = (cols >> 3);
|
|
1129
|
+
if (i > 0) {
|
|
1130
|
+
do {
|
|
1131
|
+
C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
|
|
1132
|
+
C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
|
|
1133
|
+
C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
|
|
1134
|
+
C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
|
|
1135
|
+
C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs);
|
|
1136
|
+
C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs);
|
|
1137
|
+
C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs);
|
|
1138
|
+
C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs);
|
|
1119
1139
|
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1140
|
+
__builtin_vsx_disassemble_pair(c1, &C1);
|
|
1141
|
+
__builtin_vsx_disassemble_pair(c2, &C2);
|
|
1142
|
+
__builtin_vsx_disassemble_pair(c3, &C3);
|
|
1143
|
+
__builtin_vsx_disassemble_pair(c4, &C4);
|
|
1144
|
+
__builtin_vsx_disassemble_pair(c5, &C5);
|
|
1145
|
+
__builtin_vsx_disassemble_pair(c6, &C6);
|
|
1146
|
+
__builtin_vsx_disassemble_pair(c7, &C7);
|
|
1147
|
+
__builtin_vsx_disassemble_pair(c8, &C8);
|
|
1148
|
+
|
|
1149
|
+
t1 = vec_perm(c1[0], c2[0], swiz1);
|
|
1150
|
+
t2 = vec_perm(c1[0], c2[0], swiz2);
|
|
1151
|
+
t3 = vec_perm(c3[0], c4[0], swiz1);
|
|
1152
|
+
t4 = vec_perm(c3[0], c4[0], swiz2);
|
|
1153
|
+
t5 = vec_perm(t1, t3, swiz3);
|
|
1154
|
+
t6 = vec_perm(t1, t3, swiz4);
|
|
1155
|
+
t7 = vec_perm(t2, t4, swiz3);
|
|
1156
|
+
t8 = vec_perm(t2, t4, swiz4);
|
|
1157
|
+
if (flip == true) {
|
|
1158
|
+
t5 = vec_xor(t5, xor_vector);
|
|
1159
|
+
t6 = vec_xor(t6, xor_vector);
|
|
1160
|
+
t7 = vec_xor(t7, xor_vector);
|
|
1161
|
+
t8 = vec_xor(t8, xor_vector);
|
|
1162
|
+
}
|
|
1163
|
+
vec_xst(t5, 0, vecOffset);
|
|
1164
|
+
vec_xst(t6, 0, vecOffset+16);
|
|
1165
|
+
vec_xst(t7, 0, vecOffset+32);
|
|
1166
|
+
vec_xst(t8, 0, vecOffset+48);
|
|
1167
|
+
|
|
1168
|
+
t1 = vec_perm(c1[1], c2[1], swiz1);
|
|
1169
|
+
t2 = vec_perm(c1[1], c2[1], swiz2);
|
|
1170
|
+
t3 = vec_perm(c3[1], c4[1], swiz1);
|
|
1171
|
+
t4 = vec_perm(c3[1], c4[1], swiz2);
|
|
1172
|
+
t5 = vec_perm(t1, t3, swiz3);
|
|
1173
|
+
t6 = vec_perm(t1, t3, swiz4);
|
|
1174
|
+
t7 = vec_perm(t2, t4, swiz3);
|
|
1175
|
+
t8 = vec_perm(t2, t4, swiz4);
|
|
1176
|
+
if (flip == true) {
|
|
1177
|
+
t5 = vec_xor(t5, xor_vector);
|
|
1178
|
+
t6 = vec_xor(t6, xor_vector);
|
|
1179
|
+
t7 = vec_xor(t7, xor_vector);
|
|
1180
|
+
t8 = vec_xor(t8, xor_vector);
|
|
1181
|
+
}
|
|
1182
|
+
vec_xst(t5, 0, vecOffset+64);
|
|
1183
|
+
vec_xst(t6, 0, vecOffset+80);
|
|
1184
|
+
vec_xst(t7, 0, vecOffset+96);
|
|
1185
|
+
vec_xst(t8, 0, vecOffset+112);
|
|
1186
|
+
|
|
1187
|
+
t1 = vec_perm(c5[0], c6[0], swiz1);
|
|
1188
|
+
t2 = vec_perm(c5[0], c6[0], swiz2);
|
|
1189
|
+
t3 = vec_perm(c7[0], c8[0], swiz1);
|
|
1190
|
+
t4 = vec_perm(c7[0], c8[0], swiz2);
|
|
1191
|
+
t5 = vec_perm(t1, t3, swiz3);
|
|
1192
|
+
t6 = vec_perm(t1, t3, swiz4);
|
|
1193
|
+
t7 = vec_perm(t2, t4, swiz3);
|
|
1194
|
+
t8 = vec_perm(t2, t4, swiz4);
|
|
1195
|
+
if (flip == true) {
|
|
1196
|
+
t5 = vec_xor(t5, xor_vector);
|
|
1197
|
+
t6 = vec_xor(t6, xor_vector);
|
|
1198
|
+
t7 = vec_xor(t7, xor_vector);
|
|
1199
|
+
t8 = vec_xor(t8, xor_vector);
|
|
1200
|
+
}
|
|
1201
|
+
vec_xst(t5, 0, vecOffset+128);
|
|
1202
|
+
vec_xst(t6, 0, vecOffset+144);
|
|
1203
|
+
vec_xst(t7, 0, vecOffset+160);
|
|
1204
|
+
vec_xst(t8, 0, vecOffset+176);
|
|
1205
|
+
|
|
1206
|
+
t1 = vec_perm(c5[1], c6[1], swiz1);
|
|
1207
|
+
t2 = vec_perm(c5[1], c6[1], swiz2);
|
|
1208
|
+
t3 = vec_perm(c7[1], c8[1], swiz1);
|
|
1209
|
+
t4 = vec_perm(c7[1], c8[1], swiz2);
|
|
1210
|
+
t5 = vec_perm(t1, t3, swiz3);
|
|
1211
|
+
t6 = vec_perm(t1, t3, swiz4);
|
|
1212
|
+
t7 = vec_perm(t2, t4, swiz3);
|
|
1213
|
+
t8 = vec_perm(t2, t4, swiz4);
|
|
1214
|
+
if (flip == true) {
|
|
1215
|
+
t5 = vec_xor(t5, xor_vector);
|
|
1216
|
+
t6 = vec_xor(t6, xor_vector);
|
|
1217
|
+
t7 = vec_xor(t7, xor_vector);
|
|
1218
|
+
t8 = vec_xor(t8, xor_vector);
|
|
1219
|
+
}
|
|
1220
|
+
vec_xst(t5, 0, vecOffset+192);
|
|
1221
|
+
vec_xst(t6, 0, vecOffset+208);
|
|
1222
|
+
vec_xst(t7, 0, vecOffset+224);
|
|
1223
|
+
vec_xst(t8, 0, vecOffset+240);
|
|
1224
|
+
|
|
1225
|
+
aoffset1 += lda;
|
|
1226
|
+
aoffset2 += lda;
|
|
1227
|
+
aoffset3 += lda;
|
|
1228
|
+
aoffset4 += lda;
|
|
1229
|
+
aoffset5 += lda;
|
|
1230
|
+
aoffset6 += lda;
|
|
1231
|
+
aoffset7 += lda;
|
|
1232
|
+
aoffset8 += lda;
|
|
1233
|
+
vecOffset += 256;
|
|
1234
|
+
i--;
|
|
1235
|
+
} while(i > 0);
|
|
1236
|
+
}
|
|
1237
|
+
j--;
|
|
1238
|
+
} while(j > 0);
|
|
1239
|
+
}
|
|
1240
|
+
|
|
1241
|
+
if (rows & 4) {
|
|
1242
|
+
aoffset1 = aoffset;
|
|
1243
|
+
aoffset2 = aoffset1 + lda;
|
|
1244
|
+
aoffset3 = aoffset2 + lda;
|
|
1245
|
+
aoffset4 = aoffset3 + lda;
|
|
1246
|
+
aoffset += 4 * lda;
|
|
1247
|
+
|
|
1248
|
+
i = (cols >> 3);
|
|
1249
|
+
if (i > 0) {
|
|
1250
|
+
do {
|
|
1251
|
+
C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
|
|
1252
|
+
C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
|
|
1253
|
+
C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
|
|
1254
|
+
C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
|
|
1255
|
+
|
|
1256
|
+
__builtin_vsx_disassemble_pair(c1, &C1);
|
|
1257
|
+
__builtin_vsx_disassemble_pair(c2, &C2);
|
|
1258
|
+
__builtin_vsx_disassemble_pair(c3, &C3);
|
|
1259
|
+
__builtin_vsx_disassemble_pair(c4, &C4);
|
|
1260
|
+
|
|
1261
|
+
t1 = vec_perm(c1[0], c2[0], swiz1);
|
|
1262
|
+
t2 = vec_perm(c1[0], c2[0], swiz2);
|
|
1263
|
+
t3 = vec_perm(c3[0], c4[0], swiz1);
|
|
1264
|
+
t4 = vec_perm(c3[0], c4[0], swiz2);
|
|
1265
|
+
t5 = vec_perm(t1, t3, swiz3);
|
|
1266
|
+
t6 = vec_perm(t1, t3, swiz4);
|
|
1267
|
+
t7 = vec_perm(t2, t4, swiz3);
|
|
1268
|
+
t8 = vec_perm(t2, t4, swiz4);
|
|
1269
|
+
if (flip == true) {
|
|
1270
|
+
t5 = vec_xor(t5, xor_vector);
|
|
1271
|
+
t6 = vec_xor(t6, xor_vector);
|
|
1272
|
+
t7 = vec_xor(t7, xor_vector);
|
|
1273
|
+
t8 = vec_xor(t8, xor_vector);
|
|
1274
|
+
}
|
|
1275
|
+
vec_xst(t5, 0, vecOffset);
|
|
1276
|
+
vec_xst(t6, 0, vecOffset+16);
|
|
1277
|
+
vec_xst(t7, 0, vecOffset+32);
|
|
1278
|
+
vec_xst(t8, 0, vecOffset+48);
|
|
1279
|
+
|
|
1280
|
+
t1 = vec_perm(c1[1], c2[1], swiz1);
|
|
1281
|
+
t2 = vec_perm(c1[1], c2[1], swiz2);
|
|
1282
|
+
t3 = vec_perm(c3[1], c4[1], swiz1);
|
|
1283
|
+
t4 = vec_perm(c3[1], c4[1], swiz2);
|
|
1284
|
+
t5 = vec_perm(t1, t3, swiz3);
|
|
1285
|
+
t6 = vec_perm(t1, t3, swiz4);
|
|
1286
|
+
t7 = vec_perm(t2, t4, swiz3);
|
|
1287
|
+
t8 = vec_perm(t2, t4, swiz4);
|
|
1288
|
+
if (flip == true) {
|
|
1289
|
+
t5 = vec_xor(t5, xor_vector);
|
|
1290
|
+
t6 = vec_xor(t6, xor_vector);
|
|
1291
|
+
t7 = vec_xor(t7, xor_vector);
|
|
1292
|
+
t8 = vec_xor(t8, xor_vector);
|
|
1293
|
+
}
|
|
1294
|
+
vec_xst(t5, 0, vecOffset+64);
|
|
1295
|
+
vec_xst(t6, 0, vecOffset+80);
|
|
1296
|
+
vec_xst(t7, 0, vecOffset+96);
|
|
1297
|
+
vec_xst(t8, 0, vecOffset+112);
|
|
1298
|
+
|
|
1299
|
+
aoffset1 += lda;
|
|
1300
|
+
aoffset2 += lda;
|
|
1301
|
+
aoffset3 += lda;
|
|
1302
|
+
aoffset4 += lda;
|
|
1303
|
+
vecOffset += 128;
|
|
1304
|
+
i--;
|
|
1305
|
+
} while(i > 0);
|
|
1306
|
+
}
|
|
1307
|
+
}
|
|
1308
|
+
if (rows & 3) {
|
|
1309
|
+
aoffset1 = aoffset;
|
|
1310
|
+
aoffset2 = aoffset1 + lda;
|
|
1311
|
+
aoffset3 = aoffset2 + lda;
|
|
1312
|
+
i = (cols >> 3);
|
|
1313
|
+
if (i > 0) {
|
|
1314
|
+
do {
|
|
1315
|
+
switch(rows) {
|
|
1316
|
+
case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
|
|
1317
|
+
__builtin_vsx_disassemble_pair(c3, &C3);
|
|
1318
|
+
case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
|
|
1319
|
+
__builtin_vsx_disassemble_pair(c2, &C2);
|
|
1320
|
+
case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
|
|
1321
|
+
__builtin_vsx_disassemble_pair(c1, &C1);
|
|
1322
|
+
break;
|
|
1323
|
+
}
|
|
1324
|
+
t1 = vec_perm(c1[0], c2[0], swiz1);
|
|
1325
|
+
t2 = vec_perm(c1[0], c2[0], swiz2);
|
|
1326
|
+
t3 = vec_perm(c3[0], c4[0], swiz1);
|
|
1327
|
+
t4 = vec_perm(c3[0], c4[0], swiz2);
|
|
1328
|
+
t5 = vec_perm(t1, t3, swiz3);
|
|
1329
|
+
t6 = vec_perm(t1, t3, swiz4);
|
|
1330
|
+
t7 = vec_perm(t2, t4, swiz3);
|
|
1331
|
+
t8 = vec_perm(t2, t4, swiz4);
|
|
1332
|
+
if (flip == true) {
|
|
1333
|
+
t5 = vec_xor(t5, xor_vector);
|
|
1334
|
+
t6 = vec_xor(t6, xor_vector);
|
|
1335
|
+
t7 = vec_xor(t7, xor_vector);
|
|
1336
|
+
t8 = vec_xor(t8, xor_vector);
|
|
1337
|
+
}
|
|
1338
|
+
vec_xst(t5, 0, vecOffset);
|
|
1339
|
+
vec_xst(t6, 0, vecOffset+16);
|
|
1340
|
+
vec_xst(t7, 0, vecOffset+32);
|
|
1341
|
+
vec_xst(t8, 0, vecOffset+48);
|
|
1342
|
+
|
|
1343
|
+
t1 = vec_perm(c1[1], c2[1], swiz1);
|
|
1344
|
+
t2 = vec_perm(c1[1], c2[1], swiz2);
|
|
1345
|
+
t3 = vec_perm(c3[1], c4[1], swiz1);
|
|
1346
|
+
t4 = vec_perm(c3[1], c4[1], swiz2);
|
|
1347
|
+
t5 = vec_perm(t1, t3, swiz3);
|
|
1348
|
+
t6 = vec_perm(t1, t3, swiz4);
|
|
1349
|
+
t7 = vec_perm(t2, t4, swiz3);
|
|
1350
|
+
t8 = vec_perm(t2, t4, swiz4);
|
|
1351
|
+
if (flip == true) {
|
|
1352
|
+
t5 = vec_xor(t5, xor_vector);
|
|
1353
|
+
t6 = vec_xor(t6, xor_vector);
|
|
1354
|
+
t7 = vec_xor(t7, xor_vector);
|
|
1355
|
+
t8 = vec_xor(t8, xor_vector);
|
|
1356
|
+
}
|
|
1357
|
+
vec_xst(t5, 0, vecOffset+64);
|
|
1358
|
+
vec_xst(t6, 0, vecOffset+80);
|
|
1359
|
+
vec_xst(t7, 0, vecOffset+96);
|
|
1360
|
+
vec_xst(t8, 0, vecOffset+112);
|
|
1361
|
+
|
|
1362
|
+
aoffset1 += lda;
|
|
1363
|
+
aoffset2 += lda;
|
|
1364
|
+
aoffset3 += lda;
|
|
1365
|
+
vecOffset += 128;
|
|
1366
|
+
i--;
|
|
1367
|
+
} while(i > 0);
|
|
1368
|
+
}
|
|
1369
|
+
}
|
|
1370
|
+
}
|
|
1371
|
+
|
|
1372
|
+
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
|
1373
|
+
int64_t mc, nc, mp, np;
|
|
1374
|
+
int m_rem = MIN(m - m0, 8);
|
|
1375
|
+
int n_rem = MIN(n - n0, 8);
|
|
1376
|
+
// TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance
|
|
1377
|
+
// issues. After resolving them, below code will be enabled.
|
|
1378
|
+
/*if (m_rem >= 16 && n_rem >= 8) {
|
|
1379
|
+
mc = 16;
|
|
1380
|
+
nc = 8;
|
|
1381
|
+
gemm<16,8>(m0, m, n0, n);
|
|
1382
|
+
} else if(m_rem >= 8 && n_rem >= 16) {
|
|
1383
|
+
mc = 8;
|
|
1384
|
+
nc = 16;
|
|
1385
|
+
gemm<8,16>(m0, m, n0, n);
|
|
1386
|
+
}*/
|
|
1387
|
+
if (m_rem >= 8 && n_rem >= 8) {
|
|
1388
|
+
mc = 8;
|
|
1389
|
+
nc = 8;
|
|
1390
|
+
gemm<8,8>(m0, m, n0, n);
|
|
1391
|
+
} else if (m_rem >= 4 && n_rem >= 8) {
|
|
1392
|
+
mc = 4;
|
|
1393
|
+
nc = 8;
|
|
1394
|
+
gemm<4,8>(m0, m, n0, n);
|
|
1395
|
+
} else if (m_rem >= 8 && n_rem >= 4) {
|
|
1396
|
+
mc = 8;
|
|
1397
|
+
nc = 4;
|
|
1398
|
+
gemm<8,4>(m0, m, n0, n);
|
|
1399
|
+
} else if (m_rem >= 4 && n_rem >= 4) {
|
|
1400
|
+
mc = 4;
|
|
1401
|
+
nc = 4;
|
|
1402
|
+
gemm_small<4, 4>(m0, m, n0, n);
|
|
1403
|
+
} else if ((m_rem < 4) && (n_rem > 4)) {
|
|
1404
|
+
nc = 4;
|
|
1405
|
+
switch(m_rem) {
|
|
1406
|
+
case 1:
|
|
1407
|
+
mc = 1;
|
|
1408
|
+
gemm_small<1, 4>(m0, m, n0, n);
|
|
1409
|
+
break;
|
|
1410
|
+
case 2:
|
|
1411
|
+
mc = 2;
|
|
1412
|
+
gemm_small<2, 4>(m0, m, n0, n);
|
|
1413
|
+
break;
|
|
1414
|
+
case 3:
|
|
1415
|
+
mc = 3;
|
|
1416
|
+
gemm_small<3, 4>(m0, m, n0, n);
|
|
1417
|
+
break;
|
|
1418
|
+
default:
|
|
1419
|
+
return;
|
|
1420
|
+
}
|
|
1421
|
+
} else if ((m_rem > 4) && (n_rem < 4)) {
|
|
1422
|
+
mc = 4;
|
|
1423
|
+
switch(n_rem) {
|
|
1424
|
+
case 1:
|
|
1425
|
+
nc = 1;
|
|
1426
|
+
gemm_small<4, 1>(m0, m, n0, n);
|
|
1427
|
+
break;
|
|
1428
|
+
case 2:
|
|
1429
|
+
nc = 2;
|
|
1430
|
+
gemm_small<4, 2>(m0, m, n0, n);
|
|
1431
|
+
break;
|
|
1432
|
+
case 3:
|
|
1433
|
+
nc = 3;
|
|
1434
|
+
gemm_small<4, 3>(m0, m, n0, n);
|
|
1435
|
+
break;
|
|
1436
|
+
default:
|
|
1437
|
+
return;
|
|
1438
|
+
}
|
|
1439
|
+
} else {
|
|
1440
|
+
switch((m_rem << 4) | n_rem) {
|
|
1441
|
+
case 0x43:
|
|
1442
|
+
mc = 4;
|
|
1443
|
+
nc = 3;
|
|
1444
|
+
gemm_small<4, 3>(m0, m, n0, n);
|
|
1445
|
+
break;
|
|
1446
|
+
case 0x42:
|
|
1447
|
+
mc = 4;
|
|
1448
|
+
nc = 2;
|
|
1449
|
+
gemm_small<4, 2>(m0, m, n0, n);
|
|
1450
|
+
break;
|
|
1451
|
+
case 0x41:
|
|
1452
|
+
mc = 4;
|
|
1453
|
+
nc = 1;
|
|
1454
|
+
gemm_small<4, 1>(m0, m, n0, n);
|
|
1455
|
+
break;
|
|
1456
|
+
case 0x34:
|
|
1457
|
+
mc = 3;
|
|
1458
|
+
nc = 4;
|
|
1459
|
+
gemm_small<3, 4>(m0, m, n0, n);
|
|
1460
|
+
break;
|
|
1461
|
+
case 0x33:
|
|
1462
|
+
mc = 3;
|
|
1463
|
+
nc = 3;
|
|
1464
|
+
gemm_small<3, 3>(m0, m, n0, n);
|
|
1465
|
+
break;
|
|
1466
|
+
case 0x32:
|
|
1467
|
+
mc = 3;
|
|
1468
|
+
nc = 2;
|
|
1469
|
+
gemm_small<3, 2>(m0, m, n0, n);
|
|
1470
|
+
break;
|
|
1471
|
+
case 0x31:
|
|
1472
|
+
mc = 3;
|
|
1473
|
+
nc = 1;
|
|
1474
|
+
gemm_small<3, 1>(m0, m, n0, n);
|
|
1475
|
+
break;
|
|
1476
|
+
case 0x24:
|
|
1477
|
+
mc = 2;
|
|
1478
|
+
nc = 4;
|
|
1479
|
+
gemm_small<2, 4>(m0, m, n0, n);
|
|
1480
|
+
break;
|
|
1481
|
+
case 0x23:
|
|
1482
|
+
mc = 2;
|
|
1483
|
+
nc = 3;
|
|
1484
|
+
gemm_small<2, 3>(m0, m, n0, n);
|
|
1485
|
+
break;
|
|
1486
|
+
case 0x22:
|
|
1487
|
+
mc = 2;
|
|
1488
|
+
nc = 2;
|
|
1489
|
+
gemm_small<2, 2>(m0, m, n0, n);
|
|
1490
|
+
break;
|
|
1491
|
+
case 0x21:
|
|
1492
|
+
mc = 2;
|
|
1493
|
+
nc = 1;
|
|
1494
|
+
gemm_small<2, 1>(m0, m, n0, n);
|
|
1495
|
+
break;
|
|
1496
|
+
case 0x14:
|
|
1497
|
+
mc = 1;
|
|
1498
|
+
nc = 4;
|
|
1499
|
+
gemm_small<1, 4>(m0, m, n0, n);
|
|
1500
|
+
break;
|
|
1501
|
+
case 0x13:
|
|
1502
|
+
mc = 1;
|
|
1503
|
+
nc = 3;
|
|
1504
|
+
gemm_small<1, 3>(m0, m, n0, n);
|
|
1505
|
+
break;
|
|
1506
|
+
case 0x12:
|
|
1507
|
+
mc = 1;
|
|
1508
|
+
nc = 2;
|
|
1509
|
+
gemm_small<1, 2>(m0, m, n0, n);
|
|
1510
|
+
break;
|
|
1511
|
+
case 0x11:
|
|
1512
|
+
mc = 1;
|
|
1513
|
+
nc = 1;
|
|
1514
|
+
gemm_small<1, 1>(m0, m, n0, n);
|
|
1515
|
+
break;
|
|
1516
|
+
default:
|
|
1517
|
+
return;
|
|
1518
|
+
}
|
|
1519
|
+
}
|
|
1520
|
+
mp = m0 + (m - m0) / mc * mc;
|
|
1521
|
+
np = n0 + (n - n0) / nc * nc;
|
|
1522
|
+
mnpack(mp, m, n0, np);
|
|
1523
|
+
mnpack(m0, m, np, n);
|
|
1524
|
+
}
|
|
1525
|
+
|
|
1526
|
+
void KERNEL_4x8(int64_t ii, int64_t jj) {
|
|
1527
|
+
vec_t vec_A[8], vec_B[16] = {0};
|
|
1528
|
+
acc_t acc_0, acc_1;
|
|
1529
|
+
std::array<int, 4> comparray;
|
|
1530
|
+
vector float fin_res[8] = {0};
|
|
1531
|
+
vector float vs[8] = {0};
|
|
1532
|
+
for (int l = 0; l < k; l++) {
|
|
1533
|
+
__builtin_mma_xxsetaccz(&acc_0);
|
|
1534
|
+
__builtin_mma_xxsetaccz(&acc_1);
|
|
1535
|
+
packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
|
|
1536
|
+
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
|
|
1537
|
+
for(int x = 0; x < 8; x++) {
|
|
1538
|
+
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
|
1539
|
+
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
|
|
1540
|
+
}
|
|
1541
|
+
for (int I = 0; I<4; I++) {
|
|
1542
|
+
for (int J = 0; J<4; J++) {
|
|
1543
|
+
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
|
1544
|
+
*((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
|
|
1545
|
+
}
|
|
1546
|
+
}
|
|
1547
|
+
auto aoffset = A+(ii*lda)+l;
|
|
1548
|
+
for (int i = 0; i < 4; i++) {
|
|
1549
|
+
comparray[i] = 0;
|
|
1550
|
+
int ca = 0;
|
|
1551
|
+
const int8_t *at = aoffset->qs;
|
|
1552
|
+
for (int j = 0; j < 32; j++)
|
|
1553
|
+
ca += (int)*at++;
|
|
1554
|
+
comparray[i] = ca;
|
|
1555
|
+
aoffset += lda;
|
|
1556
|
+
}
|
|
1557
|
+
compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
|
|
1558
|
+
compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
|
|
1559
|
+
}
|
|
1560
|
+
save_res<4, 4>(ii, jj, 0, fin_res);
|
|
1561
|
+
save_res<4, 4>(ii, jj+4, 4, fin_res);
|
|
1562
|
+
}
|
|
1563
|
+
|
|
1564
|
+
void KERNEL_8x4(int64_t ii, int64_t jj) {
|
|
1565
|
+
vec_t vec_A[16], vec_B[8] = {0};
|
|
1566
|
+
acc_t acc_0, acc_1;
|
|
1567
|
+
std::array<int, 8> comparray;
|
|
1568
|
+
vector float fin_res[8] = {0};
|
|
1569
|
+
vector float vs[8] = {0};
|
|
1570
|
+
for (int l = 0; l < k; l++) {
|
|
1571
|
+
__builtin_mma_xxsetaccz(&acc_0);
|
|
1572
|
+
__builtin_mma_xxsetaccz(&acc_1);
|
|
1573
|
+
packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
|
|
1574
|
+
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
|
|
1575
|
+
for(int x = 0; x < 8; x++) {
|
|
1576
|
+
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
|
1577
|
+
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
|
|
1578
|
+
}
|
|
1579
|
+
for (int I = 0; I<8; I++) {
|
|
1580
|
+
for (int J = 0; J<4; J++) {
|
|
1581
|
+
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
|
1582
|
+
}
|
|
1583
|
+
}
|
|
1584
|
+
auto aoffset = A+(ii*lda)+l;
|
|
1585
|
+
for (int i = 0; i < 8; i++) {
|
|
1586
|
+
comparray[i] = 0;
|
|
1587
|
+
int ca = 0;
|
|
1588
|
+
const int8_t *at = aoffset->qs;
|
|
1589
|
+
for (int j = 0; j < 32; j++)
|
|
1590
|
+
ca += (int)*at++;
|
|
1591
|
+
comparray[i] = ca;
|
|
1592
|
+
aoffset += lda;
|
|
1593
|
+
}
|
|
1594
|
+
compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
|
|
1595
|
+
compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
|
|
1596
|
+
}
|
|
1597
|
+
save_res<4, 4>(ii, jj, 0, fin_res);
|
|
1598
|
+
save_res<4, 4>(ii+4, jj, 4, fin_res);
|
|
1599
|
+
}
|
|
1600
|
+
|
|
1601
|
+
void KERNEL_8x8(int64_t ii, int64_t jj) {
|
|
1602
|
+
vec_t vec_A[16], vec_B[16] = {0};
|
|
1603
|
+
acc_t acc_0, acc_1, acc_2, acc_3;
|
|
1604
|
+
std::array<int, 8> comparray;
|
|
1605
|
+
vector float fin_res[16] = {0};
|
|
1606
|
+
vector float vs[16] = {0};
|
|
1607
|
+
for (int l = 0; l < k; l++) {
|
|
1608
|
+
__builtin_mma_xxsetaccz(&acc_0);
|
|
1609
|
+
__builtin_mma_xxsetaccz(&acc_1);
|
|
1610
|
+
__builtin_mma_xxsetaccz(&acc_2);
|
|
1611
|
+
__builtin_mma_xxsetaccz(&acc_3);
|
|
1612
|
+
packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
|
|
1613
|
+
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
|
|
1614
|
+
for(int x = 0; x < 8; x++) {
|
|
1615
|
+
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
|
1616
|
+
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
|
|
1617
|
+
__builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
|
|
1618
|
+
__builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
|
|
1619
|
+
}
|
|
1620
|
+
for (int I = 0; I<8; I++) {
|
|
1621
|
+
for (int J = 0; J<4; J++) {
|
|
1622
|
+
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
|
1623
|
+
*((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
|
|
1624
|
+
}
|
|
1625
|
+
}
|
|
1626
|
+
auto aoffset = A+(ii*lda)+l;
|
|
1627
|
+
for (int i = 0; i < 8; i++) {
|
|
1628
|
+
comparray[i] = 0;
|
|
1629
|
+
int ca = 0;
|
|
1630
|
+
const int8_t *at = aoffset->qs;
|
|
1631
|
+
for (int j = 0; j < 32; j++)
|
|
1632
|
+
ca += (int)*at++;
|
|
1633
|
+
comparray[i] = ca;
|
|
1634
|
+
aoffset += lda;
|
|
1635
|
+
}
|
|
1636
|
+
compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
|
|
1637
|
+
compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
|
|
1638
|
+
compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
|
|
1639
|
+
compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
|
|
1640
|
+
}
|
|
1641
|
+
save_res<4, 4>(ii, jj, 0, fin_res);
|
|
1642
|
+
save_res<4, 4>(ii+4, jj, 4, fin_res);
|
|
1643
|
+
save_res<4, 4>(ii, jj+4, 8, fin_res);
|
|
1644
|
+
save_res<4, 4>(ii+4, jj+4, 12, fin_res);
|
|
1645
|
+
}
|
|
1646
|
+
|
|
1647
|
+
template<int RM, int RN>
|
|
1648
|
+
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
|
1649
|
+
int64_t ytiles = (m - m0) / RM;
|
|
1650
|
+
int64_t xtiles = (n - n0) / RN;
|
|
1651
|
+
int64_t tiles = xtiles * ytiles;
|
|
1652
|
+
int64_t duty = (tiles + nth - 1) / nth;
|
|
1653
|
+
int64_t start = duty * ith;
|
|
1654
|
+
int64_t end = start + duty;
|
|
1655
|
+
vec_t vec_A[8], vec_B[8] = {0};
|
|
1656
|
+
vector signed int vec_C[4];
|
|
1657
|
+
acc_t acc_0;
|
|
1658
|
+
|
|
1659
|
+
if (end > tiles)
|
|
1660
|
+
end = tiles;
|
|
1661
|
+
for (int64_t job = start; job < end; ++job) {
|
|
1662
|
+
int64_t ii = m0 + job / xtiles * RM;
|
|
1663
|
+
int64_t jj = n0 + job % xtiles * RN;
|
|
1664
|
+
std::array<int, RM> comparray;
|
|
1665
|
+
vector float res[4] = {0};
|
|
1666
|
+
vector float fin_res[4] = {0};
|
|
1667
|
+
vector float vs[4] = {0};
|
|
1668
|
+
vector float CA[4] = {0};
|
|
1669
|
+
__builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
|
|
1670
|
+
__builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
|
|
1671
|
+
for (int l = 0; l < k; l++) {
|
|
1672
|
+
__builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
|
|
1673
|
+
__builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
|
|
1674
|
+
__builtin_mma_xxsetaccz(&acc_0);
|
|
1675
|
+
packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
|
|
1676
|
+
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
|
|
1677
|
+
for(int x = 0; x < 8; x+=4) {
|
|
1678
|
+
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
|
1679
|
+
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
|
|
1680
|
+
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
|
|
1681
|
+
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
|
|
1682
|
+
}
|
|
1683
|
+
for (int I = 0; I<RM; I++) {
|
|
1684
|
+
for (int J = 0; J<RN; J++) {
|
|
1685
|
+
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
|
1686
|
+
}
|
|
1687
|
+
}
|
|
1688
|
+
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
|
1689
|
+
auto aoffset = A+(ii*lda)+l;
|
|
1690
|
+
for (int i = 0; i < RM; i++) {
|
|
1691
|
+
comparray[i] = 0;
|
|
1692
|
+
int ca = 0;
|
|
1693
|
+
const int8_t *at = aoffset->qs;
|
|
1694
|
+
for (int j = 0; j < 32; j++)
|
|
1695
|
+
ca += (int)*at++;
|
|
1696
|
+
comparray[i] = ca;
|
|
1697
|
+
aoffset += lda;
|
|
1698
|
+
}
|
|
1699
|
+
|
|
1700
|
+
for (int i = 0; i < RM; i++) {
|
|
1701
|
+
CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
|
|
1702
|
+
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
|
|
1703
|
+
fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
|
|
1704
|
+
}
|
|
1705
|
+
}
|
|
1706
|
+
save_res<RM, RN>(ii, jj, 0, fin_res);
|
|
1707
|
+
}
|
|
1708
|
+
}
|
|
1709
|
+
|
|
1710
|
+
template<int RM, int RN>
|
|
1711
|
+
inline void kernel(int64_t ii, int64_t jj) {
|
|
1712
|
+
if constexpr(RM == 4 && RN == 8) {
|
|
1713
|
+
KERNEL_4x8(ii,jj);
|
|
1714
|
+
} else if constexpr(RM == 8 && RN == 4) {
|
+            KERNEL_8x4(ii,jj);
+        } else if constexpr(RM == 8 && RN == 8) {
+            KERNEL_8x8(ii,jj);
+        } else {
+            static_assert(false, "RN/RM values not supported");
+        }
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            kernel<RM, RN>(ii, jj);
+        }
+    }
+
+    const TA *const A;
+    const TB *const B;
+    TC *C;
+    TA *At;
+    TB *Bt;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+
+template <typename TA, typename TB, typename TC>
+class tinyBLAS_PPC {
+  public:
+    tinyBLAS_PPC(int64_t k,
+                 const TA *A, int64_t lda,
+                 const TB *B, int64_t ldb,
+                 TC *C, int64_t ldc,
+                 int ith, int nth)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+    }
+
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
+    }
+
+  private:
+
+    void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
+
+    template<typename VA>
+    void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
+        int64_t i, j;
+        TA *aoffset = NULL, *boffset = NULL;
+        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
+        VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+        VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+        VA t1, t2, t3, t4, t5, t6, t7, t8;
+        aoffset = const_cast<TA*>(a);
+        boffset = vec;
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+                aoffset1 = aoffset;
+                aoffset2 = aoffset1 + lda;
+                aoffset3 = aoffset2 + lda;
+                aoffset4 = aoffset3 + lda;
+                aoffset5 = aoffset4 + lda;
+                aoffset6 = aoffset5 + lda;
+                aoffset7 = aoffset6 + lda;
+                aoffset8 = aoffset7 + lda;
+                aoffset += 8 * lda;
+                i = (cols >> 3);
+                if (i > 0) {
+                    do {
+                        C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
+                        C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
+                        C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
+                        C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
+                        C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
+                        C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
+                        C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
+                        C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
+                        __builtin_vsx_disassemble_pair(c1, &C1);
+                        __builtin_vsx_disassemble_pair(c2, &C2);
+                        __builtin_vsx_disassemble_pair(c3, &C3);
+                        __builtin_vsx_disassemble_pair(c4, &C4);
+                        __builtin_vsx_disassemble_pair(c5, &C5);
+                        __builtin_vsx_disassemble_pair(c6, &C6);
+                        __builtin_vsx_disassemble_pair(c7, &C7);
+                        __builtin_vsx_disassemble_pair(c8, &C8);
+
+                        t1 = vec_mergeh(c1[0], c2[0]);
+                        t2 = vec_mergeh(c3[0], c4[0]);
+                        t3 = vec_mergeh(c5[0], c6[0]);
+                        t4 = vec_mergeh(c7[0], c8[0]);
+                        t5 = vec_xxpermdi(t1, t2, 0);
+                        t6 = vec_xxpermdi(t3, t4, 0);
+                        t7 = vec_xxpermdi(t1, t2, 3);
+                        t8 = vec_xxpermdi(t3, t4, 3);
+                        vec_xst(t5, 0, boffset);
+                        vec_xst(t6, 0, boffset+4);
+                        vec_xst(t7, 0, boffset+8);
+                        vec_xst(t8, 0, boffset+12);
+
+                        t1 = vec_mergel(c1[0], c2[0]);
+                        t2 = vec_mergel(c3[0], c4[0]);
+                        t3 = vec_mergel(c5[0], c6[0]);
+                        t4 = vec_mergel(c7[0], c8[0]);
+                        t5 = vec_xxpermdi(t1, t2, 0);
                         t6 = vec_xxpermdi(t3, t4, 0);
                         t7 = vec_xxpermdi(t1, t2, 3);
                         t8 = vec_xxpermdi(t3, t4, 3);
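The gemm<RM, RN> helper added above splits the output matrix into RM x RN tiles, hands each thread a contiguous block of roughly tiles/nth jobs, and maps each job index back to a tile origin (ii, jj). The following standalone sketch (not part of the diff; the function name list_tiles and the printing are illustrative only) mirrors that arithmetic so the work split can be inspected in isolation:

    #include <cstdint>
    #include <cstdio>

    // Sketch: enumerate the (ii, jj) tile origins a given thread would process,
    // mirroring the duty/start/end arithmetic in gemm<RM, RN> above (m0 = n0 = 0).
    template <int RM, int RN>
    void list_tiles(int64_t m, int64_t n, int ith, int nth) {
        int64_t ytiles = m / RM;                   // tiles along the rows of C
        int64_t xtiles = n / RN;                   // tiles along the columns of C
        int64_t tiles  = xtiles * ytiles;
        int64_t duty   = (tiles + nth - 1) / nth;  // ceil(tiles / nth) jobs per thread
        int64_t start  = duty * ith;
        int64_t end    = start + duty > tiles ? tiles : start + duty;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = job / xtiles * RM;        // row origin of this tile
            int64_t jj = job % xtiles * RN;        // column origin of this tile
            std::printf("thread %d -> tile at (%lld, %lld)\n",
                        ith, (long long)ii, (long long)jj);
        }
    }

    int main() {
        list_tiles<4, 4>(16, 16, /*ith=*/0, /*nth=*/4);
    }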
@@ -1165,21 +1874,19 @@ class tinyBLAS_PPC {
                 } while(i > 0);
             }
             if (cols & 4) {
-
-
-
-
-
-
-
-
-
-
-
-
-
-                t3 = vec_mergeh(c5, c6);
-                t4 = vec_mergeh(c7, c8);
+                c1[0] = vec_xl(0, aoffset1);
+                c2[0] = vec_xl(0, aoffset2);
+                c3[0] = vec_xl(0, aoffset3);
+                c4[0] = vec_xl(0, aoffset4);
+                c5[0] = vec_xl(0, aoffset5);
+                c6[0] = vec_xl(0, aoffset6);
+                c7[0] = vec_xl(0, aoffset7);
+                c8[0] = vec_xl(0, aoffset8);
+
+                t1 = vec_mergeh(c1[0], c2[0]);
+                t2 = vec_mergeh(c3[0], c4[0]);
+                t3 = vec_mergeh(c5[0], c6[0]);
+                t4 = vec_mergeh(c7[0], c8[0]);
                 t5 = vec_xxpermdi(t1, t2, 0);
                 t6 = vec_xxpermdi(t3, t4, 0);
                 t7 = vec_xxpermdi(t1, t2, 3);
@@ -1189,10 +1896,10 @@ class tinyBLAS_PPC {
                 vec_xst(t7, 0, boffset+8);
                 vec_xst(t8, 0, boffset+12);
 
-                t1 = vec_mergel(c1, c2);
-                t2 = vec_mergel(c3, c4);
-                t3 = vec_mergel(c5, c6);
-                t4 = vec_mergel(c7, c8);
+                t1 = vec_mergel(c1[0], c2[0]);
+                t2 = vec_mergel(c3[0], c4[0]);
+                t3 = vec_mergel(c5[0], c6[0]);
+                t4 = vec_mergel(c7[0], c8[0]);
                 t5 = vec_xxpermdi(t1, t2, 0);
                 t6 = vec_xxpermdi(t3, t4, 0);
                 t7 = vec_xxpermdi(t1, t2, 3);
@@ -1214,9 +1921,6 @@ class tinyBLAS_PPC {
             aoffset += 4 * lda;
             i = (cols >> 3);
             if (i > 0) {
-                __vector_pair C1, C2, C3, C4;
-                vector float c1[2], c2[2], c3[2], c4[2];
-                vector float t1, t2, t3, t4, t5, t6, t7, t8;
                 do {
                     C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
                     C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
@@ -1263,22 +1967,20 @@ class tinyBLAS_PPC {
             }
 
             if (cols & 4) {
-
-
-
-
-
-
-
-                t1 = vec_mergeh(c1, c2);
-                t2 = vec_mergeh(c3, c4);
+                c1[0] = vec_xl(0, aoffset1);
+                c2[0] = vec_xl(0, aoffset2);
+                c3[0] = vec_xl(0, aoffset3);
+                c4[0] = vec_xl(0, aoffset4);
+
+                t1 = vec_mergeh(c1[0], c2[0]);
+                t2 = vec_mergeh(c3[0], c4[0]);
                 t3 = vec_xxpermdi(t1, t2, 0);
                 t4 = vec_xxpermdi(t1, t2, 3);
                 vec_xst(t3, 0, boffset);
                 vec_xst(t4, 0, boffset+4);
 
-                t1 = vec_mergel(c1, c2);
-                t2 = vec_mergel(c3, c4);
+                t1 = vec_mergel(c1[0], c2[0]);
+                t2 = vec_mergel(c3[0], c4[0]);
                 t3 = vec_xxpermdi(t1, t2, 0);
                 t4 = vec_xxpermdi(t1, t2, 3);
                 vec_xst(t3, 0, boffset+8);
@@ -1290,21 +1992,19 @@ class tinyBLAS_PPC {
             aoffset2 = aoffset1 + lda;
             aoffset3 = aoffset2 + lda;
             if (cols & 4) {
-
-
-
-
-
-
-                t1 = vec_mergeh(c1, c2);
-                t2 = vec_mergeh(c3, c4);
+                c1[0] = vec_xl(0, aoffset1);
+                c2[0] = vec_xl(0, aoffset2);
+                c3[0] = vec_xl(0, aoffset3);
+
+                t1 = vec_mergeh(c1[0], c2[0]);
+                t2 = vec_mergeh(c3[0], c4[0]);
                 t3 = vec_xxpermdi(t1, t2, 0);
                 t4 = vec_xxpermdi(t1, t2, 3);
                 vec_xst(t3, 0, boffset);
                 vec_xst(t4, 0, boffset+4);
 
-                t1 = vec_mergel(c1, c2);
-                t2 = vec_mergel(c3, c4);
+                t1 = vec_mergel(c1[0], c2[0]);
+                t2 = vec_mergel(c3[0], c4[0]);
                 t3 = vec_xxpermdi(t1, t2, 0);
                 t4 = vec_xxpermdi(t1, t2, 3);
                 vec_xst(t3, 0, boffset+8);
@@ -1312,14 +2012,13 @@ class tinyBLAS_PPC {
             }
         }
     }
-
     void KERNEL_4x4(int64_t ii, int64_t jj) {
         vec_t vec_A[4], vec_B[4], vec_C[4];
         acc_t acc_0;
         __builtin_mma_xxsetaccz(&acc_0);
         for (int l = 0; l < k; l+=4) {
-
-
+            packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
+            packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
@@ -1334,8 +2033,8 @@ class tinyBLAS_PPC {
         __builtin_mma_xxsetaccz(&acc_0);
         __builtin_mma_xxsetaccz(&acc_1);
         for (int64_t l = 0; l < k; l+=4) {
-
-
+            packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
+            packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@@ -1355,8 +2054,8 @@ class tinyBLAS_PPC {
         __builtin_mma_xxsetaccz(&acc_0);
         __builtin_mma_xxsetaccz(&acc_1);
         for (int64_t l = 0; l < k; l+=4) {
-
-
+            packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A);
+            packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
             __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@@ -1378,8 +2077,8 @@ class tinyBLAS_PPC {
         __builtin_mma_xxsetaccz(&acc_2);
         __builtin_mma_xxsetaccz(&acc_3);
         for (int l = 0; l < k; l+=8) {
-
-
+            packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A);
+            packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B);
             for(int x = 0; x < 16; x+=2) {
                 __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
                 __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
@@ -1562,15 +2261,15 @@ class tinyBLAS_PPC {
         vec_t vec_A[4], vec_B[4];
         for (int l=0; l<k; l+=4) {
             if (RN >= 4 && RM == 1) {
-
-
+                TA* a = const_cast<TA*>(A+(ii)*lda+l);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
                 vec_A[0] = (vec_t)vec_xl(0,a);
-                vec_A[1] = (vec_t)vec_splats(*((
-                vec_A[2] = (vec_t)vec_splats(*((
-                vec_A[3] = (vec_t)vec_splats(*((
+                vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
+                vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
+                vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
             } else {
-
-
+                packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
             }
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@@ -1580,7 +2279,7 @@ class tinyBLAS_PPC {
         __builtin_mma_disassemble_acc(vec_C, &acc_0);
         for (int I = 0; I < RM; I++) {
             for (int J = 0; J < RN; J++) {
-                *((
+                *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
             }
         }
     }
@@ -1657,8 +2356,9 @@ class tinyBLAS_PPC {
  * @param Ctype is GGML data type of `C`
  * @return true if this function was able to service the matmul request
  */
-bool llamafile_sgemm(
-
+bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
+                     const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
+                     int64_t ldc, int Atype, int Btype, int Ctype) {
 
     assert(m >= 0);
     assert(n >= 0);
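The rewritten prototype above is the interface change of this hunk: llamafile_sgemm no longer takes ith/nth arguments but reads the thread index and thread count from a ggml_compute_params pointer (see the params->ith / params->nth asserts in the next hunk). A rough sketch of a call site under the new signature follows; it is illustrative only, the header names are assumptions, and in llama.cpp the params pointer comes from the ggml-cpu worker threads rather than from user code:

    // Sketch only (not part of the diff): one compute thread invoking the new entry point.
    // ggml_compute_params is an internal ggml-cpu type whose ith/nth fields are the
    // thread index and thread count consumed by the kernels above.
    #include "ggml.h"    // GGML_TYPE_F32
    #include "sgemm.h"   // llamafile_sgemm declaration (header name indicative)

    static void sgemm_f32_tile(const struct ggml_compute_params * params,
                               const float * A, const float * B, float * C,
                               int64_t m, int64_t n, int64_t k) {
        // Per the asserts in this file: lda >= k, ldb >= k, ldc >= m.
        bool ok = llamafile_sgemm(params, m, n, k,
                                  A, /*lda=*/k,
                                  B, /*ldb=*/k,
                                  C, /*ldc=*/m,
                                  GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
        if (!ok) {
            // llamafile_sgemm declined this shape/type; the caller falls back to
            // the generic ggml matmul path.
        }
    }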
@@ -1666,8 +2366,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     assert(lda >= k);
     assert(ldb >= k);
     assert(ldc >= m);
-    assert(nth > 0);
-    assert(ith < nth);
+    assert(params->nth > 0);
+    assert(params->ith < params->nth);
 
     // only enable sgemm for prompt processing
     if (n < 2)
@@ -1682,37 +2382,25 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
         if (Btype != GGML_TYPE_F32)
             return false;
 #if defined(__AVX512F__)
-
-        return false;
-        tinyBLAS<16, __m512, __m512, float, float, float> tb{
+        tinyBLAS<16, __m512, __m512, float, float, float> tb{ params,
             k, (const float *)A, lda,
             (const float *)B, ldb,
-            (float *)C, ldc
-
-        tb.matmul(m, n);
-        return true;
+            (float *)C, ldc};
+        return tb.matmul(m, n);
 #elif defined(__AVX__) || defined(__AVX2__)
-
-        return false;
-        tinyBLAS<8, __m256, __m256, float, float, float> tb{
+        tinyBLAS<8, __m256, __m256, float, float, float> tb{ params,
             k, (const float *)A, lda,
             (const float *)B, ldb,
-            (float *)C, ldc
-
-        tb.matmul(m, n);
-        return true;
+            (float *)C, ldc};
+        return tb.matmul(m, n);
 #elif defined(__ARM_NEON)
         if (n < 4)
             return false;
-
-        return false;
-        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
+        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
             k, (const float *)A, lda,
             (const float *)B, ldb,
-            (float *)C, ldc
-
-        tb.matmul(m, n);
-        return true;
+            (float *)C, ldc};
+        return tb.matmul(m, n);
 #elif defined(__MMA__)
         if (k % 8)
             return false;
@@ -1720,7 +2408,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const float *)A, lda,
             (const float *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1728,60 +2416,71 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
 #endif
     }
 
+    case GGML_TYPE_BF16: {
+#if defined(__AVX512BF16__)
+        if (Btype == GGML_TYPE_BF16) {
+            tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
+                (const ggml_bf16_t *)A, lda,
+                (const ggml_bf16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
+#elif defined(__AVX512F__)
+        if (Btype == GGML_TYPE_BF16) {
+            tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
+                (const ggml_bf16_t *)A, lda,
+                (const ggml_bf16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
+#elif defined(__AVX2__)
+        if (Btype == GGML_TYPE_BF16) {
+            tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
+                (const ggml_bf16_t *)A, lda,
+                (const ggml_bf16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
+#endif
+        return false;
+    }
     case GGML_TYPE_F16: {
 #if defined(__AVX512F__)
-        if (
-
-
-
-
-
-
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
+        if (Btype == GGML_TYPE_F16) {
+            tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
+                (const ggml_fp16_t *)A, lda,
+                (const ggml_fp16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
 #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
-        if (
-
-
-
-
-
-
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
+        if (Btype == GGML_TYPE_F16) {
+            tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
+                (const ggml_fp16_t *)A, lda,
+                (const ggml_fp16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
 #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
         if (n < 8)
             return false;
-        if (
-
-
-
-
-
-
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
+        if (Btype == GGML_TYPE_F16) {
+            tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
+                k, (const ggml_fp16_t *)A, lda,
+                (const ggml_fp16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
 #elif defined(__ARM_NEON) && !defined(_MSC_VER)
-        if (
-
-
-
-
-
-
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#else
-        return false;
+        if (Btype == GGML_TYPE_F32) {
+            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ params,
+                k, (const ggml_fp16_t *)A, lda,
+                (const float *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
 #endif
+        return false;
     }
 
     case GGML_TYPE_Q8_0: {
@@ -1792,7 +2491,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q8_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
@@ -1800,9 +2499,23 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q8_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
+        tb.matmul(m, n);
+        return true;
+
+#elif defined(__MMA__)
+        if (n < 8 && n != 4)
+            return false;
+        if (m < 8 && m != 4)
+            return false;
+        tinyBLAS_Q0_PPC<block_q8_0, block_q8_0, float> tb{
+            k, (const block_q8_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
+
 #else
         return false;
 #endif
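The new __MMA__ branch above routes Q8_0 x Q8_0 matmuls to tinyBLAS_Q0_PPC, but only for shapes its 4- and 8-wide tiles can cover. As a small illustration (not part of the diff; the helper name is hypothetical), that shape guard reads as:

    #include <cstdint>

    // Sketch: the dimension check used by the new PowerPC __MMA__ Q8_0 path.
    // A dimension below 8 is only accepted when it is exactly 4.
    static bool ppc_mma_q8_shape_ok(int64_t m, int64_t n) {
        if (n < 8 && n != 4) return false;
        if (m < 8 && m != 4) return false;
        return true;
    }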
@@ -1816,7 +2529,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q4_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
@@ -1824,7 +2537,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q4_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1840,7 +2553,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q5_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1856,7 +2569,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_iq4_nl *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1868,6 +2581,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
         return false;
     }
 
+    (void)params;
     (void)m;
     (void)n;
     (void)k;
@@ -1877,8 +2591,6 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     (void)ldb;
     (void)C;
     (void)ldc;
-    (void)ith;
-    (void)nth;
     (void)Atype;
     (void)Btype;
     (void)Ctype;