@fugood/llama.node 0.3.17 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +3 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +39 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +366 -19
- package/src/LlamaCompletionWorker.h +30 -10
- package/src/LlamaContext.cpp +213 -5
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
- package/src/llama.cpp/.github/workflows/build.yml +41 -762
- package/src/llama.cpp/.github/workflows/docker.yml +5 -2
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +12 -12
- package/src/llama.cpp/CMakeLists.txt +5 -17
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +31 -3
- package/src/llama.cpp/common/arg.cpp +48 -29
- package/src/llama.cpp/common/chat.cpp +128 -106
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +37 -1
- package/src/llama.cpp/common/common.h +18 -9
- package/src/llama.cpp/common/llguidance.cpp +1 -0
- package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
- package/src/llama.cpp/common/minja/minja.hpp +69 -36
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +57 -50
- package/src/llama.cpp/examples/CMakeLists.txt +2 -23
- package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
- package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml.h +10 -7
- package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
- package/src/llama.cpp/ggml/src/ggml.c +29 -20
- package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
- package/src/llama.cpp/include/llama.h +52 -11
- package/src/llama.cpp/requirements/requirements-all.txt +3 -3
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +1 -0
- package/src/llama.cpp/src/llama-adapter.cpp +6 -0
- package/src/llama.cpp/src/llama-arch.cpp +3 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +17 -7
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-context.cpp +389 -501
- package/src/llama.cpp/src/llama-context.h +44 -32
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +20 -38
- package/src/llama.cpp/src/llama-graph.h +12 -8
- package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
- package/src/llama.cpp/src/llama-kv-cache.h +271 -85
- package/src/llama.cpp/src/llama-memory.h +11 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +316 -69
- package/src/llama.cpp/src/llama-model.h +8 -1
- package/src/llama.cpp/src/llama-quant.cpp +15 -13
- package/src/llama.cpp/src/llama-sampling.cpp +18 -6
- package/src/llama.cpp/src/llama-vocab.cpp +42 -4
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +14 -0
- package/src/llama.cpp/tests/CMakeLists.txt +10 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
- package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
- package/src/llama.cpp/tests/test-chat.cpp +3 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
- package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
- package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
- package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
- package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.h +0 -135
- package/src/llama.cpp/examples/llava/llava.cpp +0 -586
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/mtmd.h +0 -168
- package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp:

```diff
@@ -1054,6 +1054,493 @@ class tinyBLAS_Q0_AVX {
   } \
   } \
 
+template <typename TA, typename TB, typename TC>
+class tinyBLAS_BF16_PPC {
+  public:
+    tinyBLAS_BF16_PPC(int64_t k,
+                      const TA *A, int64_t lda,
+                      const TB *B, int64_t ldb,
+                      TC *C, int64_t ldc,
+                      int ith, int nth)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+    }
+
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
+    }
+
+  private:
+    void vector_permute_store(vec_t *c, int numVec, unsigned char *vecOffset) {
+        vec_t t[8], s[8];
+        vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
+        vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
+        vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+
+        if (numVec == 2) {
+            t[0] = vec_perm(c[0], c[1], swiz1);
+            t[1] = vec_perm(c[2], c[3], swiz1);
+            s[0] = vec_perm(t[0], t[1], swiz3);
+            s[1] = vec_perm(t[0], t[1], swiz4);
+            vec_xst(s[0], 0, (vec_t*)vecOffset);
+            vec_xst(s[1], 0, (vec_t*)(vecOffset + 16));
+        } else if (numVec == 4) {
+            t[0] = vec_perm(c[0], c[1], swiz1);
+            t[1] = vec_perm(c[0], c[1], swiz2);
+            t[2] = vec_perm(c[2], c[3], swiz1);
+            t[3] = vec_perm(c[2], c[3], swiz2);
+            s[0] = vec_perm(t[0], t[2], swiz3);
+            s[1] = vec_perm(t[0], t[2], swiz4);
+            s[2] = vec_perm(t[1], t[3], swiz3);
+            s[3] = vec_perm(t[1], t[3], swiz4);
+            for (int i = 0; i < 4; ++i)
+                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
+        } else if (numVec == 8) {
+            for (int i = 0; i < 4; i += 2) {
+                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
+                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
+            }
+            for (int i = 4; i < 8; i += 2) {
+                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
+                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
+            }
+            s[0] = vec_perm(t[0], t[2], swiz3);
+            s[1] = vec_perm(t[0], t[2], swiz4);
+            s[2] = vec_perm(t[1], t[3], swiz3);
+            s[3] = vec_perm(t[1], t[3], swiz4);
+            s[4] = vec_perm(t[4], t[6], swiz3);
+            s[5] = vec_perm(t[4], t[6], swiz4);
+            s[6] = vec_perm(t[5], t[7], swiz3);
+            s[7] = vec_perm(t[5], t[7], swiz4);
+            for (int i = 0; i < 8; ++i)
+                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
+        }
+    }
+
+    void packNormal(const TA* a, int64_t lda, int rows, int cols, unsigned char* vec) {
+        int64_t i, j;
+        TA *aoffset = NULL;
+        unsigned char *vecOffset = NULL;
+        TA * aoffsets[8];
+        vector unsigned char c_arr[8];
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+                if (cols == 4) {
+                    aoffsets[0] = aoffset;
+                    for (int it = 1; it < 4; ++it)
+                        aoffsets[it] = aoffsets[it-1] + lda;
+                    aoffset += 4 * lda;
+                    for (int i = 0; i < 4; ++i)
+                        c_arr[i] = vec_xl(0, (vector unsigned char*)aoffsets[i]);
+                    vector_permute_store(c_arr, 4, vecOffset);
+                    for (int i = 0; i<4; i++)
+                        aoffsets[i] = aoffsets[i]+lda;
+                    vecOffset +=64;
+                }
+                i = (cols >> 3);
+                if (i > 0) {
+                    aoffsets[0] = aoffset;
+                    for (int it = 1; it < 8; ++it) {
+                        aoffsets[it] = aoffsets[it-1] + lda;
+                    }
+                    aoffset += 8 * lda;
+                    do {
+                        for (int it = 0; it < 8; ++it)
+                            c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
+                        vector_permute_store(c_arr, 8, vecOffset);
+                        for (int it = 0; it < 8; ++it)
+                            aoffsets[it] = aoffsets[it] + 8*lda;
+                        vecOffset += 128;
+                        i--;
+                    } while(i > 0);
+                }
+                j--;
+            } while(j > 0);
+        }
+        if (rows & 4) {
+            aoffsets[0] = aoffset;
+            for (int it = 1; it < 4; ++it)
+                aoffsets[it] = aoffsets[it-1] + lda;
+            aoffset += 4 * lda;
+            if (cols == 4) {
+                for (int it = 0; it < 4; ++it)
+                    c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
+                vector_permute_store(c_arr, 2, vecOffset);
+                for (int it = 0; it< 4; it++)
+                    aoffsets[it] = aoffsets[it] + lda;
+                vecOffset += 32;
+            }
+            i = (cols >> 3);
+            if (i > 0) {
+                do {
+                    for (int it = 0; it < 4; ++it)
+                        c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
+                    vector_permute_store(c_arr, 4, vecOffset);
+                    for (int it = 0; it< 4; it++)
+                        aoffsets[it] = aoffsets[it] + 8*lda;
+                    vecOffset += 64;
+                    i--;
+                } while(i > 0);
+            }
+        }
+        if (rows & 3) {
+            aoffsets[0] = aoffset;
+            for (int it = 1; it < 4; ++it)
+                aoffsets[it] = aoffsets[it-1] + lda;
+            if (cols == 4) {
+                switch(rows) {
+                    case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
+                    case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
+                    case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
+                        break;
+                }
+                vector_permute_store(c_arr, 2, vecOffset);
+                for (int it = 0; it< 4; it++)
+                    aoffsets[it] = aoffsets[it] + lda;
+                vecOffset += 32;
+            }
+            i = (cols >> 3);
+            if (i > 0) {
+                do {
+                    switch(rows) {
+                        case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
+                        case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
+                        case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
+                            break;
+                    }
+                    vector_permute_store(c_arr, 4, vecOffset);
+                    for (int it = 0; it <4; it++)
+                        aoffsets[it] = aoffsets[it] + 8* lda;
+                    vecOffset += 64;
+                    i--;
+                } while(i > 0);
+            }
+        }
+    }
+
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        int m_rem = MIN(m - m0, 8);
+        int n_rem = MIN(n - n0, 8);
+
+        if (m_rem >= 8 && n_rem >= 8) {
+            mc = 8;
+            nc = 8;
+            gemm<8,8>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 8) {
+            mc = 4;
+            nc = 8;
+            gemm<4,8>(m0, m, n0, n);
+        } else if (m_rem >=8 && n_rem >=4){
+            mc = 8;
+            nc = 4;
+            gemm<8,4>(m0, m, n0, n);
+        } else if ((m_rem < 4) && (n_rem >= 8)) {
+            nc = 8;
+            switch(m_rem) {
+                case 1:
+                    mc = 1;
+                    gemm_Mx8<1>(m0, m, n0, n);
+                    break;
+                case 2:
+                    mc = 2;
+                    gemm_Mx8<2>(m0, m, n0, n);
+                    break;
+                case 3:
+                    mc = 3;
+                    gemm_Mx8<3>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        } else if (m_rem >= 4 && n_rem >= 4) {
+            mc = 4;
+            nc = 4;
+            gemm_small<4, 4>(m0, m, n0, n);
+        } else if ((m_rem > 4) && (n_rem < 4)) {
+            mc = 4;
+            switch(n_rem) {
+                case 1:
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 2:
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 3:
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+
+                default:
+                    return;
+            }
+        } else {
+            switch((m_rem << 4) | n_rem) {
+                case 0x43:
+                    mc = 4;
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+                case 0x42:
+                    mc = 4;
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 0x41:
+                    mc = 4;
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 0x34:
+                    mc = 3;
+                    nc = 4;
+                    gemm_small<3, 4>(m0, m, n0, n);
+                    break;
+                case 0x33:
+                    mc = 3;
+                    nc = 3;
+                    gemm_small<3, 3>(m0, m, n0, n);
+                    break;
+                case 0x32:
+                    mc = 3;
+                    nc = 2;
+                    gemm_small<3, 2>(m0, m, n0, n);
+                    break;
+                case 0x31:
+                    mc = 3;
+                    nc = 1;
+                    gemm_small<3, 1>(m0, m, n0, n);
+                    break;
+                case 0x24:
+                    mc = 2;
+                    nc = 4;
+                    gemm_small<2,4>(m0, m, n0, n);
+                    break;
+                case 0x23:
+                    mc = 2;
+                    nc = 3;
+                    gemm_small<2, 3>(m0, m, n0, n);
+                    break;
+                case 0x22:
+                    mc = 2;
+                    nc = 2;
+                    gemm_small<2, 2>(m0, m, n0, n);
+                    break;
+                case 0x21:
+                    mc = 2;
+                    nc = 1;
+                    gemm_small<2, 1>(m0, m, n0, n);
+                    break;
+                case 0x14:
+                    mc = 1;
+                    nc = 4;
+                    gemm_small<1, 4>(m0, m, n0, n);
+                    break;
+                case 0x13:
+                    mc = 1;
+                    nc = 3;
+                    gemm_small<1, 3>(m0, m, n0, n);
+                    break;
+                case 0x12:
+                    mc = 1;
+                    nc = 2;
+                    gemm_small<1, 2>(m0, m, n0, n);
+                    break;
+                case 0x11:
+                    mc = 1;
+                    nc = 1;
+                    gemm_small<1, 1>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        }
+        mp = m0 + (m - m0) / mc * mc;
+        np = n0 + (n - n0) / nc * nc;
+        mnpack(mp, m, n0, np);
+        mnpack(m0, m, np, n);
+    }
+
+    void KERNEL_4x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[4], vec_B[8] , vec_C[4];
+        acc_t acc_0, acc_1;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        for (int l = 0; l < k; l+=8) {
+            packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
+            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
+            for (int x = 0; x < 4; x++) {
+                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+            }
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii, jj+4);
+    }
+
+    void KERNEL_8x4(int64_t ii, int64_t jj) {
+        vec_t vec_A[8], vec_B[4] , vec_C[4];
+        acc_t acc_0, acc_1;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        for (int l = 0; l < k; l+=8) {
+            packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
+            packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
+            for (int x = 0; x < 4; x++) {
+                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
+            }
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii+4, jj);
+    }
+
+
+    void KERNEL_8x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[8], vec_B[8], vec_C[4];
+        acc_t acc_0, acc_1, acc_2, acc_3;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        __builtin_mma_xxsetaccz(&acc_2);
+        __builtin_mma_xxsetaccz(&acc_3);
+        for (int l = 0; l < k; l+=8) {
+            packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
+            packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
+            for (int x = 0; x < 4; x++) {
+                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
+                __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
+            }
+        }
+
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii, jj+4);
+        SAVE_ACC(&acc_2, ii+4, jj);
+        SAVE_ACC(&acc_3, ii+4, jj+4);
+    }
+
+    template<int RM, int RN>
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            vec_t vec_C[4];
+            acc_t acc_0;
+            __builtin_mma_xxsetaccz(&acc_0);
+            vec_t vec_A[2], vec_B[2];
+            for (int l=0; l<k; l+=4) {
+                packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
+                packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
+                for (int x = 0; x<2; x++) {
+                    __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                }
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_0);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < RN; J++) {
+                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                }
+            }
+        }
+    }
+
+    template<int RM>
+    void gemm_Mx8(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int RN = 8;
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            vec_t vec_C[4];
+            acc_t acc_0, acc_1;
+            __builtin_mma_xxsetaccz(&acc_0);
+            __builtin_mma_xxsetaccz(&acc_1);
+            vec_t vec_A[4], vec_B[8];
+            for (int l=0; l<k; l+=8) {
+                packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
+                packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
+                for (int x = 0; x<4; x++) {
+                    __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                    __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+                }
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_0);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < 4; J++) {
+                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                }
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_1);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < 4; J++) {
+                    *((TC*)(C+ii+((jj+4+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                }
+            }
+        }
+    }
+
+    template<int RM, int RN>
+    inline void kernel(int64_t ii, int64_t jj) {
+        if constexpr(RM == 4 && RN == 8) {
+            KERNEL_4x8(ii,jj);
+        } else if constexpr(RM == 8 && RN == 8) {
+            KERNEL_8x8(ii,jj);
+        } else if constexpr(RM == 8 && RN == 4) {
+            KERNEL_8x4(ii,jj);
+        } else {
+            static_assert(false, "RN/RM values not supported");
+        }
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            kernel<RM, RN>(ii, jj);
+        }
+    }
+
+    const TA *const A;
+    const TB *const B;
+    TC *C;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+
 template <typename TA, typename TB, typename TC>
 class tinyBLAS_Q0_PPC {
   public:
```
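The class above carves the output into RM × RN tiles and hands each thread a contiguous run of tiles through the `duty`/`start`/`end` arithmetic shared by `gemm()`, `gemm_small()` and `gemm_Mx8()`, while `mnpack()` recurses on the leftover edges with smaller tile shapes. The following is a minimal standalone sketch of just that partitioning; it is not code from the package, and `list_tiles` is a hypothetical name used only for illustration.

```cpp
// Standalone illustration of the tile-partitioning arithmetic used by the
// gemm()/gemm_small()/gemm_Mx8() methods in the diff above: the
// (m - m0) x (n - n0) block is cut into RM x RN tiles and the tile list is
// split evenly across nth threads, with thread ith taking a contiguous run.
#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the loop bounds shown in the diff.
static void list_tiles(int64_t m0, int64_t m, int64_t n0, int64_t n,
                       int64_t RM, int64_t RN, int ith, int nth) {
    int64_t ytiles = (m - m0) / RM;
    int64_t xtiles = (n - n0) / RN;
    int64_t tiles  = xtiles * ytiles;
    int64_t duty   = (tiles + nth - 1) / nth;  // tiles per thread, rounded up
    int64_t start  = duty * ith;
    int64_t end    = start + duty;
    if (end > tiles) {
        end = tiles;
    }
    for (int64_t job = start; job < end; ++job) {
        int64_t ii = m0 + job / xtiles * RM;   // top row of this tile
        int64_t jj = n0 + job % xtiles * RN;   // left column of this tile
        std::printf("thread %d: tile at (i=%lld, j=%lld)\n",
                    ith, (long long) ii, (long long) jj);
    }
}

int main() {
    // Example: a 16x16 output cut into 8x8 tiles, divided between 2 threads.
    for (int ith = 0; ith < 2; ++ith) {
        list_tiles(0, 16, 0, 16, 8, 8, ith, 2);
    }
    return 0;
}
```

The remaining sgemm.cpp hunks follow below.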
package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp (continued):

```diff
@@ -2202,6 +2689,7 @@ class tinyBLAS_PPC {
         boffset = vec;
         j = (rows >> 3);
         if (j > 0) {
+
             do {
                 aoffset1 = aoffset;
                 aoffset2 = aoffset1 + lda;
@@ -2875,9 +3363,22 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
             (float *)C, ldc};
         return tb.matmul(m, n);
     }
+#elif defined(__MMA__)
+    if ((k % 8))
+        return false;
+    if(Btype == GGML_TYPE_BF16) {
+        tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
+            (const ggml_bf16_t *)A, lda,
+            (const ggml_bf16_t *)B, ldb,
+            (float *)C, ldc,
+            params->ith, params->nth};
+        tb.matmul(m, n);
+        return true;
+    }
 #endif
     return false;
 }
+
     case GGML_TYPE_F16: {
 #if defined(__AVX512F__)
         if (Btype == GGML_TYPE_F16) {
```
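Under `__MMA__`, the new branch routes BF16 × BF16 → F32 multiplications whose K dimension is a multiple of 8 to `tinyBLAS_BF16_PPC`. Judging from the indexing in the kernels (`A + ii*lda + l`, `B + jj*ldb + l`, stores at `C + ii + jj*ldc`), the value it is expected to produce is the plain scalar product sketched below. This is an illustrative reference only, not code from the package, and `bf16_to_float` is a local stand-in rather than ggml's own conversion helper.

```cpp
// Illustrative scalar reference for the product the new BF16 MMA path
// computes, inferred from the indexing in the kernels above: A is read
// row-wise with stride lda, B row-wise with stride ldb, and the result is
// stored at C[jj*ldc + ii]. bf16 values are widened to float before the
// multiply-accumulate.
#include <cstdint>
#include <cstring>
#include <vector>

// Local stand-in for bf16: the high 16 bits of an IEEE-754 float.
static float bf16_to_float(uint16_t h) {
    uint32_t bits = (uint32_t) h << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// C[jj*ldc + ii] = sum over l of A[ii*lda + l] * B[jj*ldb + l]
static void sgemm_bf16_reference(int64_t m, int64_t n, int64_t k,
                                 const uint16_t *A, int64_t lda,
                                 const uint16_t *B, int64_t ldb,
                                 float *C, int64_t ldc) {
    for (int64_t ii = 0; ii < m; ++ii) {
        for (int64_t jj = 0; jj < n; ++jj) {
            float sum = 0.0f;
            for (int64_t l = 0; l < k; ++l) {
                sum += bf16_to_float(A[ii*lda + l]) * bf16_to_float(B[jj*ldb + l]);
            }
            C[jj*ldc + ii] = sum;
        }
    }
}

int main() {
    // 1x1x8 check: both operands hold eight copies of 1.0 (bf16 0x3F80),
    // so the single output element should be 8.
    std::vector<uint16_t> A(8, 0x3F80), B(8, 0x3F80);
    float c = 0.0f;
    sgemm_bf16_reference(1, 1, 8, A.data(), 8, B.data(), 8, &c, 1);
    return c == 8.0f ? 0 : 1;
}
```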
package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp:

```diff
@@ -8,19 +8,6 @@
 
 #include <float.h>
 
-#if defined(_MSC_VER)
-// disable "possible loss of data" to avoid hundreds of casts
-// we should just be careful :)
-#pragma warning(disable: 4244 4267)
-
-// disable POSIX deprecation warnings
-// these functions are never going away, anyway
-#pragma warning(disable: 4996)
-
-// unreachable code because of multiple instances of code after GGML_ABORT
-#pragma warning(disable: 4702)
-#endif
-
 // ggml_compute_forward_dup
 
 static void ggml_compute_forward_dup_same_cont(
```
package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp:

```diff
@@ -2,12 +2,6 @@
 
 #include <cassert>
 
-#if defined(_MSC_VER)
-// disable "possible loss of data" to avoid hundreds of casts
-// we should just be careful :)
-#pragma warning(disable: 4244 4267)
-#endif
-
 // precomputed gelu table for f16 (128 KB)
 ggml_fp16_t ggml_table_gelu_f16[1 << 16];
 
```
package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt (the content of the two removed `set(...)` lines below was not captured in the source diff and is left blank):

```diff
@@ -12,12 +12,30 @@ if (CUDAToolkit_FOUND)
     # 61 == Pascal, __dp4a instruction (per-byte integer dot product)
     # 70 == V100, FP16 tensor cores
     # 75 == Turing, int8 tensor cores
+    # 80 == Ampere, asynchronous data loading, faster tensor core instructions
+    # 86 == RTX 3000, needs CUDA v11.1
+    # 89 == RTX 4000, needs CUDA v11.8
+    #
+    # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
+    # XX-real == compile CUDA code as device code for this specific architecture
+    # no suffix == compile as both PTX and device code
+    #
+    # The default behavior for a non-native is to build virtual architectures as needed to cover all features needed
+    # for best performance and to also build real architectures for the most commonly used GPUs.
     if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
         set(CMAKE_CUDA_ARCHITECTURES "native")
     elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
-
+        if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
+            set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real")
+        else()
+            set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real")
+        endif()
     else()
-
+        if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
+            set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real")
+        else()
+            set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real")
+        endif()
     endif()
     endif()
     message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
@@ -100,7 +118,7 @@ if (CUDAToolkit_FOUND)
 
     set(CUDA_CXX_FLAGS "")
 
-    set(CUDA_FLAGS -use_fast_math)
+    set(CUDA_FLAGS -use_fast_math -extended-lambda)
 
     if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
         # Options are:
@@ -133,6 +151,7 @@ if (CUDAToolkit_FOUND)
            COMMAND ${NVCC_CMD} -Xcompiler "-dumpfullversion -dumpversion"
            OUTPUT_VARIABLE CUDA_CCVER
            ERROR_QUIET
+           OUTPUT_STRIP_TRAILING_WHITESPACE
         )
     else()
         if (CUDA_CCFULLVER MATCHES Apple)
@@ -143,7 +162,7 @@ if (CUDAToolkit_FOUND)
            string(REGEX REPLACE "^.* version ([0-9.]*).*$" "\\1" CUDA_CCVER ${CUDA_CCFULLVER})
         endif()
 
-        message("
+        message(STATUS "CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}")
 
         ggml_get_flags(${CUDA_CCID} ${CUDA_CCVER})
         list(APPEND CUDA_CXX_FLAGS ${CXX_FLAGS} ${GF_CXX_FLAGS}) # This is passed to -Xcompiler later
```
package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h (several removed field lines below were truncated to their type names in the source diff and are left as captured):

```diff
@@ -207,6 +207,10 @@ typedef struct {
     float attn_factor;
     float beta_fast;
     float beta_slow;
+    int32_t sect_0;
+    int32_t sect_1;
+    int32_t sect_2;
+    int32_t sect_3;
 } ggml_metal_kargs_rope;
 
 typedef struct {
@@ -299,21 +303,42 @@ typedef struct {
 } ggml_metal_kargs_mul_mv_ext;
 
 typedef struct {
-    int32_t
-    int32_t
-    uint64_t
+    int32_t ne10;
+    int32_t ne11; // n_expert_used (bcast)
+    uint64_t nb11;
+    uint64_t nb12;
+    int32_t neh11; // n_tokens
+    uint64_t nbh11;
+    int32_t ne20; // n_expert_used
+    uint64_t nb21;
+} ggml_metal_kargs_mul_mm_id_map0;
+
+typedef struct {
+    int32_t ne20; // n_expert_used
+    int32_t neh0;
+    int32_t neh1;
+    uint64_t nbh1;
+    uint64_t nbh2;
+    int32_t ne0;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_mul_mm_id_map1;
+
+typedef struct {
     int32_t ne00;
     int32_t ne02;
     uint64_t nb01;
     uint64_t nb02;
-
-    int32_t
-
-    uint64_t
-    uint64_t
-    uint64_t
-    int32_t
-    int32_t
+    uint64_t nb03;
+    int32_t neh12;
+    uint64_t nbh10;
+    uint64_t nbh11;
+    uint64_t nbh12;
+    uint64_t nbh13;
+    int32_t neh0;
+    int32_t neh1;
+    int16_t r2;
+    int16_t r3;
 } ggml_metal_kargs_mul_mm_id;
 
 typedef struct {
```
package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp:

```diff
@@ -4855,8 +4855,6 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             if (!any_on_device) {
                 return false;
             }
-            GGML_ASSERT(ggml_is_contiguous(src0));
-            GGML_ASSERT(ggml_is_contiguous(src1));
             func = ggml_cl_add;
             break;
         case GGML_OP_MUL:
```