@fugood/llama.node 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187):
  1. package/CMakeLists.txt +1 -10
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +6 -4
  17. package/src/LlamaCompletionWorker.cpp +6 -6
  18. package/src/LlamaContext.cpp +7 -9
  19. package/src/common.hpp +2 -1
  20. package/src/llama.cpp/.github/workflows/build.yml +98 -24
  21. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  22. package/src/llama.cpp/.github/workflows/docker.yml +43 -34
  23. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  24. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  25. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  26. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  27. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  28. package/src/llama.cpp/CMakeLists.txt +20 -8
  29. package/src/llama.cpp/common/CMakeLists.txt +12 -10
  30. package/src/llama.cpp/common/arg.cpp +2006 -0
  31. package/src/llama.cpp/common/arg.h +77 -0
  32. package/src/llama.cpp/common/common.cpp +496 -1632
  33. package/src/llama.cpp/common/common.h +161 -63
  34. package/src/llama.cpp/common/console.cpp +3 -0
  35. package/src/llama.cpp/common/log.cpp +401 -0
  36. package/src/llama.cpp/common/log.h +66 -698
  37. package/src/llama.cpp/common/ngram-cache.cpp +3 -0
  38. package/src/llama.cpp/common/sampling.cpp +348 -350
  39. package/src/llama.cpp/common/sampling.h +62 -139
  40. package/src/llama.cpp/common/stb_image.h +5990 -6398
  41. package/src/llama.cpp/common/train.cpp +2 -0
  42. package/src/llama.cpp/docs/build.md +36 -1
  43. package/src/llama.cpp/examples/CMakeLists.txt +0 -1
  44. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
  45. package/src/llama.cpp/examples/batched/batched.cpp +39 -55
  46. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
  47. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  48. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
  49. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  50. package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
  51. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
  52. package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
  53. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  54. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
  55. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  57. package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
  58. package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
  59. package/src/llama.cpp/examples/infill/infill.cpp +117 -132
  60. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
  61. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
  62. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  63. package/src/llama.cpp/examples/llava/clip.cpp +685 -150
  64. package/src/llama.cpp/examples/llava/clip.h +11 -2
  65. package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
  66. package/src/llama.cpp/examples/llava/llava.cpp +110 -24
  67. package/src/llama.cpp/examples/llava/llava.h +2 -3
  68. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  69. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  70. package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
  71. package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
  72. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
  73. package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
  74. package/src/llama.cpp/examples/main/main.cpp +210 -262
  75. package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
  76. package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
  77. package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
  78. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  80. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
  81. package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
  82. package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
  83. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
  84. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
  85. package/src/llama.cpp/examples/server/server.cpp +1027 -1073
  86. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  87. package/src/llama.cpp/examples/server/utils.hpp +107 -105
  88. package/src/llama.cpp/examples/simple/simple.cpp +35 -41
  89. package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
  90. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  91. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  92. package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
  93. package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
  94. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  95. package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
  96. package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
  97. package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
  98. package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
  99. package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
  100. package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
  101. package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
  102. package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
  103. package/src/llama.cpp/ggml/include/ggml.h +293 -186
  104. package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
  105. package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
  106. package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
  107. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
  108. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
  109. package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
  110. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  111. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  112. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  113. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  114. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  115. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  116. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  117. package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
  118. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
  120. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  121. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  122. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  123. package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
  124. package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
  125. package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
  126. package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
  127. package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
  128. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  129. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
  130. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
  131. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  132. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  133. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  134. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  135. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  136. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  137. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
  138. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  141. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
  142. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
  143. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
  144. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  145. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  146. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
  148. package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
  149. package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
  150. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
  151. package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
  152. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
  153. package/src/llama.cpp/include/llama.h +241 -264
  154. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  155. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  156. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  157. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  158. package/src/llama.cpp/src/llama-grammar.h +120 -15
  159. package/src/llama.cpp/src/llama-impl.h +156 -1
  160. package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
  161. package/src/llama.cpp/src/llama-sampling.h +20 -47
  162. package/src/llama.cpp/src/llama-vocab.cpp +343 -120
  163. package/src/llama.cpp/src/llama-vocab.h +33 -17
  164. package/src/llama.cpp/src/llama.cpp +4247 -1525
  165. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  166. package/src/llama.cpp/src/unicode-data.h +4 -4
  167. package/src/llama.cpp/src/unicode.cpp +15 -7
  168. package/src/llama.cpp/tests/CMakeLists.txt +3 -0
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
  171. package/src/llama.cpp/tests/test-barrier.cpp +93 -0
  172. package/src/llama.cpp/tests/test-grad0.cpp +187 -70
  173. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  174. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  175. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
  176. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  177. package/src/llama.cpp/tests/test-log.cpp +39 -0
  178. package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
  179. package/src/llama.cpp/tests/test-rope.cpp +1 -1
  180. package/src/llama.cpp/tests/test-sampling.cpp +157 -98
  181. package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
  182. package/patches/llama.patch +0 -22
  183. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  184. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  185. package/src/llama.cpp/common/grammar-parser.h +0 -29
  186. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  187. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
@@ -50,6 +50,7 @@
50
50
 
51
51
  #include "sgemm.h"
52
52
  #include "ggml-impl.h"
53
+ #include "ggml-cpu-impl.h"
53
54
  #include "ggml-quants.h"
54
55
 
55
56
  #ifdef _MSC_VER
@@ -235,6 +236,14 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
235
236
  }
236
237
  #endif // __AVX512F__
237
238
 
239
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
240
+ // CONSTANTS
241
+
242
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
243
+ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
244
+ static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
245
+ #endif
246
+
238
247
  ////////////////////////////////////////////////////////////////////////////////////////////////////
239
248
  // FLOATING POINT MATRIX MULTIPLICATION
240
249
 
@@ -606,17 +615,29 @@ class tinyBLAS_Q0_AVX {
606
615
  case 0x44:
607
616
  mc = 4;
608
617
  nc = 4;
618
+ #if defined(__AVX2__) && defined(__F16C__)
619
+ gemm4xN<4>(m0, m, n0, n);
620
+ #else
609
621
  gemm<4, 4>(m0, m, n0, n);
622
+ #endif
610
623
  break;
611
624
  case 0x43:
612
625
  mc = 4;
613
626
  nc = 3;
627
+ #if defined(__AVX2__) && defined(__F16C__)
628
+ gemm4xN<3>(m0, m, n0, n);
629
+ #else
614
630
  gemm<4, 3>(m0, m, n0, n);
631
+ #endif
615
632
  break;
616
633
  case 0x34:
617
634
  mc = 3;
618
635
  nc = 4;
636
+ #if defined(__AVX2__) && defined(__F16C__)
637
+ gemmMx4<3>(m0, m, n0, n);
638
+ #else
619
639
  gemm<3, 4>(m0, m, n0, n);
640
+ #endif
620
641
  break;
621
642
  case 0x33:
622
643
  mc = 3;
@@ -626,12 +647,20 @@ class tinyBLAS_Q0_AVX {
626
647
  case 0x42:
627
648
  mc = 4;
628
649
  nc = 2;
650
+ #if defined(__AVX2__) && defined(__F16C__)
651
+ gemm4xN<2>(m0, m, n0, n);
652
+ #else
629
653
  gemm<4, 2>(m0, m, n0, n);
654
+ #endif
630
655
  break;
631
656
  case 0x24:
632
657
  mc = 2;
633
658
  nc = 4;
659
+ #if defined(__AVX2__) && defined(__F16C__)
660
+ gemmMx4<2>(m0, m, n0, n);
661
+ #else
634
662
  gemm<2, 4>(m0, m, n0, n);
663
+ #endif
635
664
  break;
636
665
  #else
637
666
  case 0x44:
@@ -639,13 +668,21 @@ class tinyBLAS_Q0_AVX {
639
668
  case 0x42:
640
669
  mc = 4;
641
670
  nc = 2;
671
+ #if defined(__AVX2__) && defined(__F16C__)
672
+ gemm4xN<2>(m0, m, n0, n);
673
+ #else
642
674
  gemm<4, 2>(m0, m, n0, n);
675
+ #endif
643
676
  break;
644
677
  case 0x34:
645
678
  case 0x24:
646
679
  mc = 2;
647
680
  nc = 4;
681
+ #if defined(__AVX2__) && defined(__F16C__)
682
+ gemmMx4<2>(m0, m, n0, n);
683
+ #else
648
684
  gemm<2, 4>(m0, m, n0, n);
685
+ #endif
649
686
  break;
650
687
  case 0x33:
651
688
  #endif
@@ -662,7 +699,11 @@ class tinyBLAS_Q0_AVX {
662
699
  case 0x41:
663
700
  mc = 4;
664
701
  nc = 1;
702
+ #if defined(__AVX2__) && defined(__F16C__)
703
+ gemm4xN<1>(m0, m, n0, n);
704
+ #else
665
705
  gemm<4, 1>(m0, m, n0, n);
706
+ #endif
666
707
  break;
667
708
  case 0x22:
668
709
  mc = 2;
@@ -672,7 +713,11 @@ class tinyBLAS_Q0_AVX {
672
713
  case 0x14:
673
714
  mc = 1;
674
715
  nc = 4;
716
+ #if defined(__AVX2__) && defined(__F16C__)
717
+ gemmMx4<1>(m0, m, n0, n);
718
+ #else
675
719
  gemm<1, 4>(m0, m, n0, n);
720
+ #endif
676
721
  break;
677
722
  case 0x31:
678
723
  mc = 3;
@@ -708,6 +753,119 @@ class tinyBLAS_Q0_AVX {
708
753
  mnpack(m0, m, np, n);
709
754
  }
710
755
 
756
+ #if defined(__AVX2__) && defined(__F16C__)
757
+ // Templated functions for gemm of dimensions 4xN
758
+ template <int RN>
759
+ NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
760
+ int64_t ytiles = (m - m0) / 4;
761
+ int64_t xtiles = (n - n0) / RN;
762
+ int64_t tiles = xtiles * ytiles;
763
+ int64_t duty = (tiles + nth - 1) / nth;
764
+ int64_t start = duty * ith;
765
+ int64_t end = start + duty;
766
+ if (end > tiles)
767
+ end = tiles;
768
+ for (int64_t job = start; job < end; ++job) {
769
+ int64_t ii = m0 + job / xtiles * 4;
770
+ int64_t jj = n0 + job % xtiles * RN;
771
+ __m256 Cv[RN][4] = {};
772
+ for (int64_t l = 0; l < k; ++l) {
773
+ uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
774
+ // Convert delta values for four blocks to float values
775
+ __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
776
+ __m256i avec0 = load(A + lda * (ii + 0) + l);
777
+ __m256i avec1 = load(A + lda * (ii + 1) + l);
778
+ __m256i avec2 = load(A + lda * (ii + 2) + l);
779
+ __m256i avec3 = load(A + lda * (ii + 3) + l);
780
+ for (int64_t j = 0; j < RN; ++j) {
781
+ __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
782
+ // Computation of product of delta values for four blocks and replicate it across 256 bit lane
783
+ __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
784
+ dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
785
+ // Computation of dot product and multiplication with appropriate delta value products
786
+ Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
787
+ updot(_mm256_sign_epi8(avec0, avec0),
788
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
789
+ Cv[j][0]);
790
+ Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
791
+ updot(_mm256_sign_epi8(avec1, avec1),
792
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
793
+ Cv[j][1]);
794
+ Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
795
+ updot(_mm256_sign_epi8(avec2, avec2),
796
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
797
+ Cv[j][2]);
798
+ Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
799
+ updot(_mm256_sign_epi8(avec3, avec3),
800
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
801
+ Cv[j][3]);
802
+ }
803
+ }
804
+
805
+ for (int64_t j = 0; j < RN; ++j)
806
+ for (int64_t i = 0; i < 4; ++i)
807
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
808
+ }
809
+ }
810
+
811
+ // Templated functions for gemm of dimensions Mx4
812
+ template <int RM>
813
+ NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
814
+ int64_t ytiles = (m - m0) / RM;
815
+ int64_t xtiles = (n - n0) / 4;
816
+ int64_t tiles = xtiles * ytiles;
817
+ int64_t duty = (tiles + nth - 1) / nth;
818
+ int64_t start = duty * ith;
819
+ int64_t end = start + duty;
820
+ if (end > tiles)
821
+ end = tiles;
822
+ for (int64_t job = start; job < end; ++job) {
823
+ int64_t ii = m0 + job / xtiles * RM;
824
+ int64_t jj = n0 + job % xtiles * 4;
825
+ __m256 Cv[4][RM] = {};
826
+ for (int64_t l = 0; l < k; ++l) {
827
+ uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
828
+ // Convert delta values for four blocks to float values
829
+ __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
830
+ __m256i bvec0 = load(B + ldb * (jj + 0) + l);
831
+ __m256i bvec1 = load(B + ldb * (jj + 1) + l);
832
+ __m256i bvec2 = load(B + ldb * (jj + 2) + l);
833
+ __m256i bvec3 = load(B + ldb * (jj + 3) + l);
834
+ for (int64_t i = 0; i < RM; ++i) {
835
+ __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));
836
+ // Computation of product of delta values for four blocks and replicate it across 256 bit lane
837
+ __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
838
+ dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
839
+ // Computation of dot product and multiplication with appropriate delta value products
840
+ Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
841
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
842
+ load(A + lda * (ii + i) + l)),
843
+ _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
844
+ Cv[0][i]);
845
+ Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
846
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
847
+ load(A + lda * (ii + i) + l)),
848
+ _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
849
+ Cv[1][i]);
850
+ Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
851
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
852
+ load(A + lda * (ii + i) + l)),
853
+ _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
854
+ Cv[2][i]);
855
+ Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
856
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
857
+ load(A + lda * (ii + i) + l)),
858
+ _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
859
+ Cv[3][i]);
860
+ }
861
+ }
862
+ for (int64_t j = 0; j < 4; ++j)
863
+ for (int64_t i = 0; i < RM; ++i)
864
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
865
+ }
866
+ }
867
+ #endif
868
+
711
869
  template <int RM, int RN>
712
870
  NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
713
871
  int64_t ytiles = (m - m0) / RM;
@@ -784,6 +942,20 @@ class tinyBLAS_Q0_AVX {
784
942
  return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
785
943
  }
786
944
 
945
+ inline __m256i load(const block_iq4_nl *b) {
946
+ return MM256_SET_M128I(load1(b), load0(b));
947
+ }
948
+
949
+ inline __m128i load0(const block_iq4_nl *b) {
950
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
951
+ return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x));
952
+ }
953
+
954
+ inline __m128i load1(const block_iq4_nl *b) {
955
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
956
+ return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)));
957
+ }
958
+
787
959
  inline __m256 updot(__m256i u, __m256i s) {
788
960
  __m256i res;
789
961
  #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
@@ -857,6 +1029,10 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
857
1029
  assert(nth > 0);
858
1030
  assert(ith < nth);
859
1031
 
1032
+ // only enable sgemm for prompt processing
1033
+ if (n < 2)
1034
+ return false;
1035
+
860
1036
  if (Ctype != GGML_TYPE_F32)
861
1037
  return false;
862
1038
 
@@ -1006,6 +1182,22 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
1006
1182
  #endif
1007
1183
  }
1008
1184
 
1185
+ case GGML_TYPE_IQ4_NL: {
1186
+ if (Btype != GGML_TYPE_Q8_0)
1187
+ return false;
1188
+ #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
1189
+ tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float> tb{
1190
+ k, (const block_iq4_nl *)A, lda,
1191
+ (const block_q8_0 *)B, ldb,
1192
+ (float *)C, ldc,
1193
+ ith, nth};
1194
+ tb.matmul(m, n);
1195
+ return true;
1196
+ #else
1197
+ return false;
1198
+ #endif
1199
+ }
1200
+
1009
1201
  default:
1010
1202
  return false;
1011
1203
  }
@@ -1,5 +1,7 @@
1
+ find_package (Threads REQUIRED)
1
2
 
2
3
  set(TARGET vulkan-shaders-gen)
3
4
  add_executable(${TARGET} vulkan-shaders-gen.cpp)
4
5
  install(TARGETS ${TARGET} RUNTIME)
5
6
  target_compile_features(${TARGET} PRIVATE cxx_std_11)
7
+ target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
@@ -22,6 +22,7 @@
22
22
  #ifdef _WIN32
23
23
  #include <windows.h>
24
24
  #include <direct.h> // For _mkdir on Windows
25
+ #include <algorithm> // For std::replace on w64devkit
25
26
  #else
26
27
  #include <unistd.h>
27
28
  #include <sys/wait.h>
@@ -30,20 +31,6 @@
30
31
 
31
32
  #define ASYNCIO_CONCURRENCY 64
32
33
 
33
- // define prototypes
34
- void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str);
35
- bool directory_exists(const std::string& path);
36
- bool create_directory(const std::string& path);
37
- std::string to_uppercase(const std::string& input);
38
- bool string_ends_with(const std::string& str, const std::string& suffix);
39
- std::string join_paths(const std::string& path1, const std::string& path2);
40
- std::string basename(const std::string &path);
41
- void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16);
42
- std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b);
43
- void matmul_shaders(std::vector<std::future<void>>& tasks, bool fp16, bool matmul_id);
44
- void process_shaders(std::vector<std::future<void>>& tasks);
45
- void write_output_files();
46
-
47
34
  std::mutex lock;
48
35
  std::vector<std::pair<std::string, std::string>> shader_fnames;
49
36
 
@@ -52,7 +39,7 @@ std::string input_dir = "vulkan-shaders";
52
39
  std::string output_dir = "/tmp";
53
40
  std::string target_hpp = "ggml-vulkan-shaders.hpp";
54
41
  std::string target_cpp = "ggml-vulkan-shaders.cpp";
55
- bool clean = true;
42
+ bool no_clean = false;
56
43
 
57
44
  const std::vector<std::string> type_names = {
58
45
  "f32",
@@ -193,11 +180,7 @@ bool string_ends_with(const std::string& str, const std::string& suffix) {
193
180
  return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
194
181
  }
195
182
 
196
- #ifdef _WIN32
197
- static const char path_separator = '\\';
198
- #else
199
- static const char path_separator = '/';
200
- #endif
183
+ static const char path_separator = '/';
201
184
 
202
185
  std::string join_paths(const std::string& path1, const std::string& path2) {
203
186
  return path1 + path_separator + path2;
@@ -212,7 +195,16 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
212
195
  std::string out_fname = join_paths(output_dir, name + ".spv");
213
196
  std::string in_path = join_paths(input_dir, in_fname);
214
197
 
215
- std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o", out_fname};
198
+ #ifdef _WIN32
199
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
200
+ #else
201
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o", out_fname};
202
+ #endif
203
+
204
+ #ifdef GGML_VULKAN_SHADER_DEBUG_INFO
205
+ cmd.push_back("-g");
206
+ #endif
207
+
216
208
  for (const auto& define : defines) {
217
209
  cmd.push_back("-D" + define.first + "=" + define.second);
218
210
  }
@@ -283,9 +275,12 @@ void matmul_shaders(std::vector<std::future<void>>& tasks, bool fp16, bool matmu
283
275
 
284
276
  for (const auto& tname : type_names) {
285
277
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
278
+ // For unaligned, load one at a time for f32/f16, or two at a time for quants
279
+ std::string load_vec_a_unaligned = (tname == "f32" || tname == "f16") ? "1" : "2";
280
+ // For aligned matmul loads
286
281
  std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2";
287
282
  tasks.push_back(std::async(std::launch::async, [=] {
288
- string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
283
+ string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
289
284
  }));
290
285
  tasks.push_back(std::async(std::launch::async, [=] {
291
286
  string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16);
@@ -354,6 +349,9 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
354
349
  tasks.push_back(std::async(std::launch::async, [=] {
355
350
  string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
356
351
  }));
352
+ tasks.push_back(std::async(std::launch::async, [=] {
353
+ string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
354
+ }));
357
355
  tasks.push_back(std::async(std::launch::async, [=] {
358
356
  string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
359
357
  }));
@@ -371,6 +369,13 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
371
369
  tasks.push_back(std::async(std::launch::async, [] {
372
370
  string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
373
371
  }));
372
+ tasks.push_back(std::async(std::launch::async, [] {
373
+ string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
374
+ }));
375
+
376
+ tasks.push_back(std::async(std::launch::async, [] {
377
+ string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
378
+ }));
374
379
 
375
380
  tasks.push_back(std::async(std::launch::async, [] {
376
381
  string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
@@ -384,6 +389,10 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
384
389
  string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
385
390
  }));
386
391
 
392
+ tasks.push_back(std::async(std::launch::async, [] {
393
+ string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
394
+ }));
395
+
387
396
  tasks.push_back(std::async(std::launch::async, [] {
388
397
  string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
389
398
  }));
@@ -392,19 +401,54 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
392
401
  string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
393
402
  }));
394
403
 
404
+ tasks.push_back(std::async(std::launch::async, [] {
405
+ string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
406
+ }));
407
+
408
+ tasks.push_back(std::async(std::launch::async, [] {
409
+ string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
410
+ }));
411
+
395
412
  tasks.push_back(std::async(std::launch::async, [] {
396
413
  string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
397
414
  }));
398
415
 
416
+ tasks.push_back(std::async(std::launch::async, [] {
417
+ string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
418
+ }));
419
+
420
+ tasks.push_back(std::async(std::launch::async, [] {
421
+ string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
422
+ }));
423
+ tasks.push_back(std::async(std::launch::async, [] {
424
+ string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
425
+ }));
426
+ tasks.push_back(std::async(std::launch::async, [] {
427
+ string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}});
428
+ }));
429
+
430
+ tasks.push_back(std::async(std::launch::async, [] {
431
+ string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
432
+ }));
433
+
399
434
  tasks.push_back(std::async(std::launch::async, [] {
400
435
  string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
401
436
  }));
437
+ tasks.push_back(std::async(std::launch::async, [] {
438
+ string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
439
+ }));
402
440
  tasks.push_back(std::async(std::launch::async, [] {
403
441
  string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
404
442
  }));
405
443
  tasks.push_back(std::async(std::launch::async, [] {
406
444
  string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
407
445
  }));
446
+ tasks.push_back(std::async(std::launch::async, [] {
447
+ string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
448
+ }));
449
+ tasks.push_back(std::async(std::launch::async, [] {
450
+ string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
451
+ }));
408
452
 
409
453
  tasks.push_back(std::async(std::launch::async, [] {
410
454
  string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -438,6 +482,17 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
438
482
  tasks.push_back(std::async(std::launch::async, [=] {
439
483
  string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
440
484
  }));
485
+
486
+ tasks.push_back(std::async(std::launch::async, [=] {
487
+ string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
488
+ }));
489
+ tasks.push_back(std::async(std::launch::async, [=] {
490
+ string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
491
+ }));
492
+
493
+ tasks.push_back(std::async(std::launch::async, [=] {
494
+ string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
495
+ }));
441
496
  }
442
497
 
443
498
  void write_output_files() {
@@ -449,10 +504,16 @@ void write_output_files() {
449
504
 
450
505
  for (const auto& pair : shader_fnames) {
451
506
  const std::string& name = pair.first;
452
- const std::string& path = pair.second;
507
+ #ifdef _WIN32
508
+ std::string path = pair.second;
509
+ std::replace(path.begin(), path.end(), '/', '\\' );
510
+ #else
511
+ const std::string& path = pair.second;
512
+ #endif
513
+
453
514
  FILE* spv = fopen(path.c_str(), "rb");
454
515
  if (!spv) {
455
- std::cerr << "Error opening SPIR-V file: " << path << "\n";
516
+ std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
456
517
  continue;
457
518
  }
458
519
 
@@ -464,7 +525,7 @@ void write_output_files() {
464
525
  size_t read_size = fread(data.data(), 1, size, spv);
465
526
  fclose(spv);
466
527
  if (read_size != size) {
467
- std::cerr << "Error reading SPIR-V file: " << path << "\n";
528
+ std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
468
529
  continue;
469
530
  }
470
531
 
@@ -478,9 +539,8 @@ void write_output_files() {
478
539
  }
479
540
  fprintf(src, "\n};\n\n");
480
541
 
481
- if (clean) {
542
+ if (!no_clean) {
482
543
  std::remove(path.c_str());
483
- // fprintf(stderr, "Removed: %s\n", path.c_str());
484
544
  }
485
545
  }
486
546
 
@@ -496,18 +556,6 @@ int main(int argc, char** argv) {
496
556
  }
497
557
  }
498
558
 
499
- if (argc <= 1 || args.find("--help") != args.end()) {
500
- std::cout << "Usage:\n"
501
- "\tvulkan-shaders-gen [options]\n\n"
502
- "Options:\n"
503
- "\t--glslc <path> Path to glslc executable (default: /usr/bin/glslc)\n"
504
- "\t--input-dir Directory containing shader sources (required)\n"
505
- "\t--output-dir Output directory for generated SPIR-V files and optional C++ headers\n"
506
- "\t--target-hpp <path> Path to generate a header file with shader declarations in C++ format\n"
507
- "\t--target-cpp <path> Path to generate a source code file implementing the declared shaders (optional)\n"
508
- "\t--no-clean Keep temporary SPIR-V files after build (default: remove them)\n";
509
- return EXIT_SUCCESS;
510
- }
511
559
  if (args.find("--glslc") != args.end()) {
512
560
  GLSLC = args["--glslc"]; // Path to glslc
513
561
  }
@@ -524,7 +572,7 @@ int main(int argc, char** argv) {
524
572
  target_cpp = args["--target-cpp"]; // Path to generated cpp file
525
573
  }
526
574
  if (args.find("--no-clean") != args.end()) {
527
- clean = false; // Keep temporary SPIR-V files in output-dir after build
575
+ no_clean = true; // Keep temporary SPIR-V files in output-dir after build
528
576
  }
529
577
 
530
578
  if (!directory_exists(input_dir)) {