@fugood/llama.node 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +1 -10
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +6 -4
- package/src/LlamaCompletionWorker.cpp +6 -6
- package/src/LlamaContext.cpp +7 -9
- package/src/common.hpp +2 -1
- package/src/llama.cpp/.github/workflows/build.yml +98 -24
- package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
- package/src/llama.cpp/.github/workflows/docker.yml +43 -34
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
- package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
- package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
- package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +7 -0
- package/src/llama.cpp/CMakeLists.txt +20 -8
- package/src/llama.cpp/common/CMakeLists.txt +12 -10
- package/src/llama.cpp/common/arg.cpp +2006 -0
- package/src/llama.cpp/common/arg.h +77 -0
- package/src/llama.cpp/common/common.cpp +496 -1632
- package/src/llama.cpp/common/common.h +161 -63
- package/src/llama.cpp/common/console.cpp +3 -0
- package/src/llama.cpp/common/log.cpp +401 -0
- package/src/llama.cpp/common/log.h +66 -698
- package/src/llama.cpp/common/ngram-cache.cpp +3 -0
- package/src/llama.cpp/common/sampling.cpp +348 -350
- package/src/llama.cpp/common/sampling.h +62 -139
- package/src/llama.cpp/common/stb_image.h +5990 -6398
- package/src/llama.cpp/common/train.cpp +2 -0
- package/src/llama.cpp/docs/build.md +36 -1
- package/src/llama.cpp/examples/CMakeLists.txt +0 -1
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
- package/src/llama.cpp/examples/batched/batched.cpp +39 -55
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
- package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
- package/src/llama.cpp/examples/infill/infill.cpp +117 -132
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +685 -150
- package/src/llama.cpp/examples/llava/clip.h +11 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
- package/src/llama.cpp/examples/llava/llava.cpp +110 -24
- package/src/llama.cpp/examples/llava/llava.h +2 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
- package/src/llama.cpp/examples/llava/requirements.txt +1 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
- package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
- package/src/llama.cpp/examples/main/main.cpp +210 -262
- package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
- package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
- package/src/llama.cpp/examples/server/server.cpp +1027 -1073
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
- package/src/llama.cpp/examples/server/utils.hpp +107 -105
- package/src/llama.cpp/examples/simple/simple.cpp +35 -41
- package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
- package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
- package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
- package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
- package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
- package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
- package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
- package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
- package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
- package/src/llama.cpp/ggml/include/ggml.h +293 -186
- package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
- package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
- package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
- package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
- package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
- package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
- package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
- package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
- package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
- package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
- package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
- package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
- package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
- package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
- package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
- package/src/llama.cpp/include/llama.h +241 -264
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
- package/src/llama.cpp/src/llama-grammar.cpp +721 -122
- package/src/llama.cpp/src/llama-grammar.h +120 -15
- package/src/llama.cpp/src/llama-impl.h +156 -1
- package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
- package/src/llama.cpp/src/llama-sampling.h +20 -47
- package/src/llama.cpp/src/llama-vocab.cpp +343 -120
- package/src/llama.cpp/src/llama-vocab.h +33 -17
- package/src/llama.cpp/src/llama.cpp +4247 -1525
- package/src/llama.cpp/src/unicode-data.cpp +6 -4
- package/src/llama.cpp/src/unicode-data.h +4 -4
- package/src/llama.cpp/src/unicode.cpp +15 -7
- package/src/llama.cpp/tests/CMakeLists.txt +3 -0
- package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
- package/src/llama.cpp/tests/test-barrier.cpp +93 -0
- package/src/llama.cpp/tests/test-grad0.cpp +187 -70
- package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
- package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
- package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
- package/src/llama.cpp/tests/test-log.cpp +39 -0
- package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
- package/src/llama.cpp/tests/test-rope.cpp +1 -1
- package/src/llama.cpp/tests/test-sampling.cpp +157 -98
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
- package/patches/llama.patch +0 -22
- package/src/llama.cpp/.github/workflows/bench.yml +0 -310
- package/src/llama.cpp/common/grammar-parser.cpp +0 -536
- package/src/llama.cpp/common/grammar-parser.h +0 -29
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
|
@@ -50,6 +50,7 @@
|
|
|
50
50
|
|
|
51
51
|
#include "sgemm.h"
|
|
52
52
|
#include "ggml-impl.h"
|
|
53
|
+
#include "ggml-cpu-impl.h"
|
|
53
54
|
#include "ggml-quants.h"
|
|
54
55
|
|
|
55
56
|
#ifdef _MSC_VER
|
|
@@ -235,6 +236,14 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
|
|
|
235
236
|
}
|
|
236
237
|
#endif // __AVX512F__
|
|
237
238
|
|
|
239
|
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
240
|
+
// CONSTANTS
|
|
241
|
+
|
|
242
|
+
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
|
243
|
+
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
|
244
|
+
static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
|
|
245
|
+
#endif
|
|
246
|
+
|
|
238
247
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
239
248
|
// FLOATING POINT MATRIX MULTIPLICATION
|
|
240
249
|
|
|
@@ -606,17 +615,29 @@ class tinyBLAS_Q0_AVX {
|
|
|
606
615
|
case 0x44:
|
|
607
616
|
mc = 4;
|
|
608
617
|
nc = 4;
|
|
618
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
|
619
|
+
gemm4xN<4>(m0, m, n0, n);
|
|
620
|
+
#else
|
|
609
621
|
gemm<4, 4>(m0, m, n0, n);
|
|
622
|
+
#endif
|
|
610
623
|
break;
|
|
611
624
|
case 0x43:
|
|
612
625
|
mc = 4;
|
|
613
626
|
nc = 3;
|
|
627
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
|
628
|
+
gemm4xN<3>(m0, m, n0, n);
|
|
629
|
+
#else
|
|
614
630
|
gemm<4, 3>(m0, m, n0, n);
|
|
631
|
+
#endif
|
|
615
632
|
break;
|
|
616
633
|
case 0x34:
|
|
617
634
|
mc = 3;
|
|
618
635
|
nc = 4;
|
|
636
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
|
637
|
+
gemmMx4<3>(m0, m, n0, n);
|
|
638
|
+
#else
|
|
619
639
|
gemm<3, 4>(m0, m, n0, n);
|
|
640
|
+
#endif
|
|
620
641
|
break;
|
|
621
642
|
case 0x33:
|
|
622
643
|
mc = 3;
|
|
@@ -626,12 +647,20 @@ class tinyBLAS_Q0_AVX {
|
|
|
626
647
|
case 0x42:
|
|
627
648
|
mc = 4;
|
|
628
649
|
nc = 2;
|
|
650
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
|
651
|
+
gemm4xN<2>(m0, m, n0, n);
|
|
652
|
+
#else
|
|
629
653
|
gemm<4, 2>(m0, m, n0, n);
|
|
654
|
+
#endif
|
|
630
655
|
break;
|
|
631
656
|
case 0x24:
|
|
632
657
|
mc = 2;
|
|
633
658
|
nc = 4;
|
|
659
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
|
660
|
+
gemmMx4<2>(m0, m, n0, n);
|
|
661
|
+
#else
|
|
634
662
|
gemm<2, 4>(m0, m, n0, n);
|
|
663
|
+
#endif
|
|
635
664
|
break;
|
|
636
665
|
#else
|
|
637
666
|
case 0x44:
|
|
@@ -639,13 +668,21 @@ class tinyBLAS_Q0_AVX {
|
|
|
639
668
|
case 0x42:
|
|
640
669
|
mc = 4;
|
|
641
670
|
nc = 2;
|
|
671
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
|
672
|
+
gemm4xN<2>(m0, m, n0, n);
|
|
673
|
+
#else
|
|
642
674
|
gemm<4, 2>(m0, m, n0, n);
|
|
675
|
+
#endif
|
|
643
676
|
break;
|
|
644
677
|
case 0x34:
|
|
645
678
|
case 0x24:
|
|
646
679
|
mc = 2;
|
|
647
680
|
nc = 4;
|
|
681
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
|
682
|
+
gemmMx4<2>(m0, m, n0, n);
|
|
683
|
+
#else
|
|
648
684
|
gemm<2, 4>(m0, m, n0, n);
|
|
685
|
+
#endif
|
|
649
686
|
break;
|
|
650
687
|
case 0x33:
|
|
651
688
|
#endif
|
|
@@ -662,7 +699,11 @@ class tinyBLAS_Q0_AVX {
|
|
|
662
699
|
case 0x41:
|
|
663
700
|
mc = 4;
|
|
664
701
|
nc = 1;
|
|
702
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
|
703
|
+
gemm4xN<1>(m0, m, n0, n);
|
|
704
|
+
#else
|
|
665
705
|
gemm<4, 1>(m0, m, n0, n);
|
|
706
|
+
#endif
|
|
666
707
|
break;
|
|
667
708
|
case 0x22:
|
|
668
709
|
mc = 2;
|
|
@@ -672,7 +713,11 @@ class tinyBLAS_Q0_AVX {
|
|
|
672
713
|
case 0x14:
|
|
673
714
|
mc = 1;
|
|
674
715
|
nc = 4;
|
|
716
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
|
717
|
+
gemmMx4<1>(m0, m, n0, n);
|
|
718
|
+
#else
|
|
675
719
|
gemm<1, 4>(m0, m, n0, n);
|
|
720
|
+
#endif
|
|
676
721
|
break;
|
|
677
722
|
case 0x31:
|
|
678
723
|
mc = 3;
|
|
@@ -708,6 +753,119 @@ class tinyBLAS_Q0_AVX {
|
|
|
708
753
|
mnpack(m0, m, np, n);
|
|
709
754
|
}
|
|
710
755
|
|
|
756
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
|
757
|
+
// Templated functions for gemm of dimensions 4xN
|
|
758
|
+
template <int RN>
|
|
759
|
+
NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
|
760
|
+
int64_t ytiles = (m - m0) / 4;
|
|
761
|
+
int64_t xtiles = (n - n0) / RN;
|
|
762
|
+
int64_t tiles = xtiles * ytiles;
|
|
763
|
+
int64_t duty = (tiles + nth - 1) / nth;
|
|
764
|
+
int64_t start = duty * ith;
|
|
765
|
+
int64_t end = start + duty;
|
|
766
|
+
if (end > tiles)
|
|
767
|
+
end = tiles;
|
|
768
|
+
for (int64_t job = start; job < end; ++job) {
|
|
769
|
+
int64_t ii = m0 + job / xtiles * 4;
|
|
770
|
+
int64_t jj = n0 + job % xtiles * RN;
|
|
771
|
+
__m256 Cv[RN][4] = {};
|
|
772
|
+
for (int64_t l = 0; l < k; ++l) {
|
|
773
|
+
uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
|
|
774
|
+
// Convert delta values for four blocks to float values
|
|
775
|
+
__m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
|
|
776
|
+
__m256i avec0 = load(A + lda * (ii + 0) + l);
|
|
777
|
+
__m256i avec1 = load(A + lda * (ii + 1) + l);
|
|
778
|
+
__m256i avec2 = load(A + lda * (ii + 2) + l);
|
|
779
|
+
__m256i avec3 = load(A + lda * (ii + 3) + l);
|
|
780
|
+
for (int64_t j = 0; j < RN; ++j) {
|
|
781
|
+
__m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
|
|
782
|
+
// Computation of product of delta values for four blocks and replicate it across 256 bit lane
|
|
783
|
+
__m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
|
|
784
|
+
dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
|
|
785
|
+
// Computation of dot product and multiplication with appropriate delta value products
|
|
786
|
+
Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
|
|
787
|
+
updot(_mm256_sign_epi8(avec0, avec0),
|
|
788
|
+
_mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
|
|
789
|
+
Cv[j][0]);
|
|
790
|
+
Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
|
|
791
|
+
updot(_mm256_sign_epi8(avec1, avec1),
|
|
792
|
+
_mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
|
|
793
|
+
Cv[j][1]);
|
|
794
|
+
Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
|
|
795
|
+
updot(_mm256_sign_epi8(avec2, avec2),
|
|
796
|
+
_mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
|
|
797
|
+
Cv[j][2]);
|
|
798
|
+
Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
|
|
799
|
+
updot(_mm256_sign_epi8(avec3, avec3),
|
|
800
|
+
_mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
|
|
801
|
+
Cv[j][3]);
|
|
802
|
+
}
|
|
803
|
+
}
|
|
804
|
+
|
|
805
|
+
for (int64_t j = 0; j < RN; ++j)
|
|
806
|
+
for (int64_t i = 0; i < 4; ++i)
|
|
807
|
+
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
|
|
808
|
+
}
|
|
809
|
+
}
|
|
810
|
+
|
|
811
|
+
// Templated functions for gemm of dimensions Mx4
|
|
812
|
+
template <int RM>
|
|
813
|
+
NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
|
814
|
+
int64_t ytiles = (m - m0) / RM;
|
|
815
|
+
int64_t xtiles = (n - n0) / 4;
|
|
816
|
+
int64_t tiles = xtiles * ytiles;
|
|
817
|
+
int64_t duty = (tiles + nth - 1) / nth;
|
|
818
|
+
int64_t start = duty * ith;
|
|
819
|
+
int64_t end = start + duty;
|
|
820
|
+
if (end > tiles)
|
|
821
|
+
end = tiles;
|
|
822
|
+
for (int64_t job = start; job < end; ++job) {
|
|
823
|
+
int64_t ii = m0 + job / xtiles * RM;
|
|
824
|
+
int64_t jj = n0 + job % xtiles * 4;
|
|
825
|
+
__m256 Cv[4][RM] = {};
|
|
826
|
+
for (int64_t l = 0; l < k; ++l) {
|
|
827
|
+
uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
|
|
828
|
+
// Convert delta values for four blocks to float values
|
|
829
|
+
__m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
|
|
830
|
+
__m256i bvec0 = load(B + ldb * (jj + 0) + l);
|
|
831
|
+
__m256i bvec1 = load(B + ldb * (jj + 1) + l);
|
|
832
|
+
__m256i bvec2 = load(B + ldb * (jj + 2) + l);
|
|
833
|
+
__m256i bvec3 = load(B + ldb * (jj + 3) + l);
|
|
834
|
+
for (int64_t i = 0; i < RM; ++i) {
|
|
835
|
+
__m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));
|
|
836
|
+
// Computation of product of delta values for four blocks and replicate it across 256 bit lane
|
|
837
|
+
__m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
|
|
838
|
+
dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
|
|
839
|
+
// Computation of dot product and multiplication with appropriate delta value products
|
|
840
|
+
Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
|
|
841
|
+
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
|
842
|
+
load(A + lda * (ii + i) + l)),
|
|
843
|
+
_mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
|
|
844
|
+
Cv[0][i]);
|
|
845
|
+
Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
|
|
846
|
+
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
|
847
|
+
load(A + lda * (ii + i) + l)),
|
|
848
|
+
_mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
|
|
849
|
+
Cv[1][i]);
|
|
850
|
+
Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
|
|
851
|
+
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
|
852
|
+
load(A + lda * (ii + i) + l)),
|
|
853
|
+
_mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
|
|
854
|
+
Cv[2][i]);
|
|
855
|
+
Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
|
|
856
|
+
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
|
857
|
+
load(A + lda * (ii + i) + l)),
|
|
858
|
+
_mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
|
|
859
|
+
Cv[3][i]);
|
|
860
|
+
}
|
|
861
|
+
}
|
|
862
|
+
for (int64_t j = 0; j < 4; ++j)
|
|
863
|
+
for (int64_t i = 0; i < RM; ++i)
|
|
864
|
+
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
|
|
865
|
+
}
|
|
866
|
+
}
|
|
867
|
+
#endif
|
|
868
|
+
|
|
711
869
|
template <int RM, int RN>
|
|
712
870
|
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
|
713
871
|
int64_t ytiles = (m - m0) / RM;
|
|
@@ -784,6 +942,20 @@ class tinyBLAS_Q0_AVX {
|
|
|
784
942
|
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
|
|
785
943
|
}
|
|
786
944
|
|
|
945
|
+
inline __m256i load(const block_iq4_nl *b) {
|
|
946
|
+
return MM256_SET_M128I(load1(b), load0(b));
|
|
947
|
+
}
|
|
948
|
+
|
|
949
|
+
inline __m128i load0(const block_iq4_nl *b) {
|
|
950
|
+
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
|
|
951
|
+
return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x));
|
|
952
|
+
}
|
|
953
|
+
|
|
954
|
+
inline __m128i load1(const block_iq4_nl *b) {
|
|
955
|
+
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
|
|
956
|
+
return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)));
|
|
957
|
+
}
|
|
958
|
+
|
|
787
959
|
inline __m256 updot(__m256i u, __m256i s) {
|
|
788
960
|
__m256i res;
|
|
789
961
|
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
|
|
@@ -857,6 +1029,10 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|
|
857
1029
|
assert(nth > 0);
|
|
858
1030
|
assert(ith < nth);
|
|
859
1031
|
|
|
1032
|
+
// only enable sgemm for prompt processing
|
|
1033
|
+
if (n < 2)
|
|
1034
|
+
return false;
|
|
1035
|
+
|
|
860
1036
|
if (Ctype != GGML_TYPE_F32)
|
|
861
1037
|
return false;
|
|
862
1038
|
|
|
@@ -1006,6 +1182,22 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|
|
1006
1182
|
#endif
|
|
1007
1183
|
}
|
|
1008
1184
|
|
|
1185
|
+
case GGML_TYPE_IQ4_NL: {
|
|
1186
|
+
if (Btype != GGML_TYPE_Q8_0)
|
|
1187
|
+
return false;
|
|
1188
|
+
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
|
|
1189
|
+
tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float> tb{
|
|
1190
|
+
k, (const block_iq4_nl *)A, lda,
|
|
1191
|
+
(const block_q8_0 *)B, ldb,
|
|
1192
|
+
(float *)C, ldc,
|
|
1193
|
+
ith, nth};
|
|
1194
|
+
tb.matmul(m, n);
|
|
1195
|
+
return true;
|
|
1196
|
+
#else
|
|
1197
|
+
return false;
|
|
1198
|
+
#endif
|
|
1199
|
+
}
|
|
1200
|
+
|
|
1009
1201
|
default:
|
|
1010
1202
|
return false;
|
|
1011
1203
|
}
|
|
@@ -1,5 +1,7 @@
|
|
|
1
|
+
find_package (Threads REQUIRED)
|
|
1
2
|
|
|
2
3
|
set(TARGET vulkan-shaders-gen)
|
|
3
4
|
add_executable(${TARGET} vulkan-shaders-gen.cpp)
|
|
4
5
|
install(TARGETS ${TARGET} RUNTIME)
|
|
5
6
|
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
|
7
|
+
target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
|
|
@@ -22,6 +22,7 @@
|
|
|
22
22
|
#ifdef _WIN32
|
|
23
23
|
#include <windows.h>
|
|
24
24
|
#include <direct.h> // For _mkdir on Windows
|
|
25
|
+
#include <algorithm> // For std::replace on w64devkit
|
|
25
26
|
#else
|
|
26
27
|
#include <unistd.h>
|
|
27
28
|
#include <sys/wait.h>
|
|
@@ -30,20 +31,6 @@
|
|
|
30
31
|
|
|
31
32
|
#define ASYNCIO_CONCURRENCY 64
|
|
32
33
|
|
|
33
|
-
// define prototypes
|
|
34
|
-
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str);
|
|
35
|
-
bool directory_exists(const std::string& path);
|
|
36
|
-
bool create_directory(const std::string& path);
|
|
37
|
-
std::string to_uppercase(const std::string& input);
|
|
38
|
-
bool string_ends_with(const std::string& str, const std::string& suffix);
|
|
39
|
-
std::string join_paths(const std::string& path1, const std::string& path2);
|
|
40
|
-
std::string basename(const std::string &path);
|
|
41
|
-
void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16);
|
|
42
|
-
std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b);
|
|
43
|
-
void matmul_shaders(std::vector<std::future<void>>& tasks, bool fp16, bool matmul_id);
|
|
44
|
-
void process_shaders(std::vector<std::future<void>>& tasks);
|
|
45
|
-
void write_output_files();
|
|
46
|
-
|
|
47
34
|
std::mutex lock;
|
|
48
35
|
std::vector<std::pair<std::string, std::string>> shader_fnames;
|
|
49
36
|
|
|
@@ -52,7 +39,7 @@ std::string input_dir = "vulkan-shaders";
|
|
|
52
39
|
std::string output_dir = "/tmp";
|
|
53
40
|
std::string target_hpp = "ggml-vulkan-shaders.hpp";
|
|
54
41
|
std::string target_cpp = "ggml-vulkan-shaders.cpp";
|
|
55
|
-
bool
|
|
42
|
+
bool no_clean = false;
|
|
56
43
|
|
|
57
44
|
const std::vector<std::string> type_names = {
|
|
58
45
|
"f32",
|
|
@@ -193,11 +180,7 @@ bool string_ends_with(const std::string& str, const std::string& suffix) {
|
|
|
193
180
|
return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
|
|
194
181
|
}
|
|
195
182
|
|
|
196
|
-
|
|
197
|
-
static const char path_separator = '\\';
|
|
198
|
-
#else
|
|
199
|
-
static const char path_separator = '/';
|
|
200
|
-
#endif
|
|
183
|
+
static const char path_separator = '/';
|
|
201
184
|
|
|
202
185
|
std::string join_paths(const std::string& path1, const std::string& path2) {
|
|
203
186
|
return path1 + path_separator + path2;
|
|
@@ -212,7 +195,16 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
|
|
|
212
195
|
std::string out_fname = join_paths(output_dir, name + ".spv");
|
|
213
196
|
std::string in_path = join_paths(input_dir, in_fname);
|
|
214
197
|
|
|
215
|
-
|
|
198
|
+
#ifdef _WIN32
|
|
199
|
+
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
|
|
200
|
+
#else
|
|
201
|
+
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o", out_fname};
|
|
202
|
+
#endif
|
|
203
|
+
|
|
204
|
+
#ifdef GGML_VULKAN_SHADER_DEBUG_INFO
|
|
205
|
+
cmd.push_back("-g");
|
|
206
|
+
#endif
|
|
207
|
+
|
|
216
208
|
for (const auto& define : defines) {
|
|
217
209
|
cmd.push_back("-D" + define.first + "=" + define.second);
|
|
218
210
|
}
|
|
@@ -283,9 +275,12 @@ void matmul_shaders(std::vector<std::future<void>>& tasks, bool fp16, bool matmu
|
|
|
283
275
|
|
|
284
276
|
for (const auto& tname : type_names) {
|
|
285
277
|
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
|
278
|
+
// For unaligned, load one at a time for f32/f16, or two at a time for quants
|
|
279
|
+
std::string load_vec_a_unaligned = (tname == "f32" || tname == "f16") ? "1" : "2";
|
|
280
|
+
// For aligned matmul loads
|
|
286
281
|
std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2";
|
|
287
282
|
tasks.push_back(std::async(std::launch::async, [=] {
|
|
288
|
-
string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A",
|
|
283
|
+
string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
|
|
289
284
|
}));
|
|
290
285
|
tasks.push_back(std::async(std::launch::async, [=] {
|
|
291
286
|
string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16);
|
|
@@ -354,6 +349,9 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
|
|
|
354
349
|
tasks.push_back(std::async(std::launch::async, [=] {
|
|
355
350
|
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
|
356
351
|
}));
|
|
352
|
+
tasks.push_back(std::async(std::launch::async, [=] {
|
|
353
|
+
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
|
354
|
+
}));
|
|
357
355
|
tasks.push_back(std::async(std::launch::async, [=] {
|
|
358
356
|
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
|
359
357
|
}));
|
|
@@ -371,6 +369,13 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
|
|
|
371
369
|
tasks.push_back(std::async(std::launch::async, [] {
|
|
372
370
|
string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
|
373
371
|
}));
|
|
372
|
+
tasks.push_back(std::async(std::launch::async, [] {
|
|
373
|
+
string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
|
|
374
|
+
}));
|
|
375
|
+
|
|
376
|
+
tasks.push_back(std::async(std::launch::async, [] {
|
|
377
|
+
string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
|
378
|
+
}));
|
|
374
379
|
|
|
375
380
|
tasks.push_back(std::async(std::launch::async, [] {
|
|
376
381
|
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
|
|
@@ -384,6 +389,10 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
|
|
|
384
389
|
string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
|
385
390
|
}));
|
|
386
391
|
|
|
392
|
+
tasks.push_back(std::async(std::launch::async, [] {
|
|
393
|
+
string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
|
394
|
+
}));
|
|
395
|
+
|
|
387
396
|
tasks.push_back(std::async(std::launch::async, [] {
|
|
388
397
|
string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
|
389
398
|
}));
|
|
@@ -392,19 +401,54 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
|
|
|
392
401
|
string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
|
393
402
|
}));
|
|
394
403
|
|
|
404
|
+
tasks.push_back(std::async(std::launch::async, [] {
|
|
405
|
+
string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
|
406
|
+
}));
|
|
407
|
+
|
|
408
|
+
tasks.push_back(std::async(std::launch::async, [] {
|
|
409
|
+
string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
|
410
|
+
}));
|
|
411
|
+
|
|
395
412
|
tasks.push_back(std::async(std::launch::async, [] {
|
|
396
413
|
string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
|
397
414
|
}));
|
|
398
415
|
|
|
416
|
+
tasks.push_back(std::async(std::launch::async, [] {
|
|
417
|
+
string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
|
418
|
+
}));
|
|
419
|
+
|
|
420
|
+
tasks.push_back(std::async(std::launch::async, [] {
|
|
421
|
+
string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
|
422
|
+
}));
|
|
423
|
+
tasks.push_back(std::async(std::launch::async, [] {
|
|
424
|
+
string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
|
|
425
|
+
}));
|
|
426
|
+
tasks.push_back(std::async(std::launch::async, [] {
|
|
427
|
+
string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}});
|
|
428
|
+
}));
|
|
429
|
+
|
|
430
|
+
tasks.push_back(std::async(std::launch::async, [] {
|
|
431
|
+
string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
|
432
|
+
}));
|
|
433
|
+
|
|
399
434
|
tasks.push_back(std::async(std::launch::async, [] {
|
|
400
435
|
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
|
401
436
|
}));
|
|
437
|
+
tasks.push_back(std::async(std::launch::async, [] {
|
|
438
|
+
string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
|
439
|
+
}));
|
|
402
440
|
tasks.push_back(std::async(std::launch::async, [] {
|
|
403
441
|
string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
|
404
442
|
}));
|
|
405
443
|
tasks.push_back(std::async(std::launch::async, [] {
|
|
406
444
|
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
|
407
445
|
}));
|
|
446
|
+
tasks.push_back(std::async(std::launch::async, [] {
|
|
447
|
+
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
|
448
|
+
}));
|
|
449
|
+
tasks.push_back(std::async(std::launch::async, [] {
|
|
450
|
+
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
|
451
|
+
}));
|
|
408
452
|
|
|
409
453
|
tasks.push_back(std::async(std::launch::async, [] {
|
|
410
454
|
string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
|
@@ -438,6 +482,17 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
|
|
|
438
482
|
tasks.push_back(std::async(std::launch::async, [=] {
|
|
439
483
|
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
|
440
484
|
}));
|
|
485
|
+
|
|
486
|
+
tasks.push_back(std::async(std::launch::async, [=] {
|
|
487
|
+
string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
|
488
|
+
}));
|
|
489
|
+
tasks.push_back(std::async(std::launch::async, [=] {
|
|
490
|
+
string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
|
|
491
|
+
}));
|
|
492
|
+
|
|
493
|
+
tasks.push_back(std::async(std::launch::async, [=] {
|
|
494
|
+
string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
|
495
|
+
}));
|
|
441
496
|
}
|
|
442
497
|
|
|
443
498
|
void write_output_files() {
|
|
@@ -449,10 +504,16 @@ void write_output_files() {
|
|
|
449
504
|
|
|
450
505
|
for (const auto& pair : shader_fnames) {
|
|
451
506
|
const std::string& name = pair.first;
|
|
452
|
-
|
|
507
|
+
#ifdef _WIN32
|
|
508
|
+
std::string path = pair.second;
|
|
509
|
+
std::replace(path.begin(), path.end(), '/', '\\' );
|
|
510
|
+
#else
|
|
511
|
+
const std::string& path = pair.second;
|
|
512
|
+
#endif
|
|
513
|
+
|
|
453
514
|
FILE* spv = fopen(path.c_str(), "rb");
|
|
454
515
|
if (!spv) {
|
|
455
|
-
std::cerr << "Error opening SPIR-V file: " << path << "\n";
|
|
516
|
+
std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
|
|
456
517
|
continue;
|
|
457
518
|
}
|
|
458
519
|
|
|
@@ -464,7 +525,7 @@ void write_output_files() {
|
|
|
464
525
|
size_t read_size = fread(data.data(), 1, size, spv);
|
|
465
526
|
fclose(spv);
|
|
466
527
|
if (read_size != size) {
|
|
467
|
-
std::cerr << "Error reading SPIR-V file: " << path << "\n";
|
|
528
|
+
std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
|
|
468
529
|
continue;
|
|
469
530
|
}
|
|
470
531
|
|
|
@@ -478,9 +539,8 @@ void write_output_files() {
|
|
|
478
539
|
}
|
|
479
540
|
fprintf(src, "\n};\n\n");
|
|
480
541
|
|
|
481
|
-
if (
|
|
542
|
+
if (!no_clean) {
|
|
482
543
|
std::remove(path.c_str());
|
|
483
|
-
// fprintf(stderr, "Removed: %s\n", path.c_str());
|
|
484
544
|
}
|
|
485
545
|
}
|
|
486
546
|
|
|
@@ -496,18 +556,6 @@ int main(int argc, char** argv) {
|
|
|
496
556
|
}
|
|
497
557
|
}
|
|
498
558
|
|
|
499
|
-
if (argc <= 1 || args.find("--help") != args.end()) {
|
|
500
|
-
std::cout << "Usage:\n"
|
|
501
|
-
"\tvulkan-shaders-gen [options]\n\n"
|
|
502
|
-
"Options:\n"
|
|
503
|
-
"\t--glslc <path> Path to glslc executable (default: /usr/bin/glslc)\n"
|
|
504
|
-
"\t--input-dir Directory containing shader sources (required)\n"
|
|
505
|
-
"\t--output-dir Output directory for generated SPIR-V files and optional C++ headers\n"
|
|
506
|
-
"\t--target-hpp <path> Path to generate a header file with shader declarations in C++ format\n"
|
|
507
|
-
"\t--target-cpp <path> Path to generate a source code file implementing the declared shaders (optional)\n"
|
|
508
|
-
"\t--no-clean Keep temporary SPIR-V files after build (default: remove them)\n";
|
|
509
|
-
return EXIT_SUCCESS;
|
|
510
|
-
}
|
|
511
559
|
if (args.find("--glslc") != args.end()) {
|
|
512
560
|
GLSLC = args["--glslc"]; // Path to glslc
|
|
513
561
|
}
|
|
@@ -524,7 +572,7 @@ int main(int argc, char** argv) {
|
|
|
524
572
|
target_cpp = args["--target-cpp"]; // Path to generated cpp file
|
|
525
573
|
}
|
|
526
574
|
if (args.find("--no-clean") != args.end()) {
|
|
527
|
-
|
|
575
|
+
no_clean = true; // Keep temporary SPIR-V files in output-dir after build
|
|
528
576
|
}
|
|
529
577
|
|
|
530
578
|
if (!directory_exists(input_dir)) {
|