@fugood/llama.node 0.3.2 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (286)
  1. package/CMakeLists.txt +7 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/DetokenizeWorker.cpp +1 -1
  19. package/src/EmbeddingWorker.cpp +17 -7
  20. package/src/EmbeddingWorker.h +2 -1
  21. package/src/LlamaCompletionWorker.cpp +8 -8
  22. package/src/LlamaCompletionWorker.h +2 -2
  23. package/src/LlamaContext.cpp +89 -27
  24. package/src/LlamaContext.h +2 -0
  25. package/src/TokenizeWorker.cpp +1 -1
  26. package/src/common.hpp +4 -4
  27. package/src/llama.cpp/.github/workflows/build.yml +240 -168
  28. package/src/llama.cpp/.github/workflows/docker.yml +8 -8
  29. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  31. package/src/llama.cpp/CMakeLists.txt +14 -6
  32. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/cmake/common.cmake +33 -0
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  36. package/src/llama.cpp/common/CMakeLists.txt +6 -4
  37. package/src/llama.cpp/common/arg.cpp +986 -770
  38. package/src/llama.cpp/common/arg.h +22 -22
  39. package/src/llama.cpp/common/common.cpp +212 -351
  40. package/src/llama.cpp/common/common.h +204 -117
  41. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  42. package/src/llama.cpp/common/log.cpp +50 -50
  43. package/src/llama.cpp/common/log.h +18 -18
  44. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  45. package/src/llama.cpp/common/ngram-cache.h +19 -19
  46. package/src/llama.cpp/common/sampling.cpp +163 -121
  47. package/src/llama.cpp/common/sampling.h +41 -20
  48. package/src/llama.cpp/common/speculative.cpp +274 -0
  49. package/src/llama.cpp/common/speculative.h +28 -0
  50. package/src/llama.cpp/docs/build.md +134 -161
  51. package/src/llama.cpp/examples/CMakeLists.txt +33 -14
  52. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/batched/batched.cpp +19 -18
  54. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  55. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  56. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  57. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  58. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  60. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  61. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  63. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  64. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  65. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  66. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  67. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  68. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  69. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  70. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  71. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  72. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  73. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  75. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  76. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  77. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  78. package/src/llama.cpp/examples/imatrix/imatrix.cpp +31 -13
  79. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  80. package/src/llama.cpp/examples/infill/infill.cpp +41 -87
  81. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +439 -459
  83. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +2 -0
  84. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  85. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  86. package/src/llama.cpp/examples/llava/clip.cpp +263 -66
  87. package/src/llama.cpp/examples/llava/clip.h +8 -2
  88. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  89. package/src/llama.cpp/examples/llava/llava.cpp +83 -22
  90. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  91. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  92. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  94. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  95. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  96. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  97. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +16 -15
  98. package/src/llama.cpp/examples/lookup/lookup.cpp +30 -30
  99. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  100. package/src/llama.cpp/examples/main/main.cpp +73 -114
  101. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  102. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  104. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  105. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  106. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  108. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  109. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  110. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  111. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  112. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  113. package/src/llama.cpp/examples/retrieval/retrieval.cpp +16 -16
  114. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  115. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  116. package/src/llama.cpp/examples/run/run.cpp +911 -0
  117. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  118. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +38 -21
  119. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -16
  120. package/src/llama.cpp/examples/server/server.cpp +2073 -1339
  121. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  122. package/src/llama.cpp/examples/server/utils.hpp +354 -277
  123. package/src/llama.cpp/examples/simple/CMakeLists.txt +2 -2
  124. package/src/llama.cpp/examples/simple/simple.cpp +130 -94
  125. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  126. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +200 -0
  127. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  128. package/src/llama.cpp/examples/speculative/speculative.cpp +68 -64
  129. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  130. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  131. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  132. package/src/llama.cpp/examples/tokenize/tokenize.cpp +3 -3
  133. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  134. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  135. package/src/llama.cpp/ggml/CMakeLists.txt +54 -36
  136. package/src/llama.cpp/ggml/include/ggml-backend.h +63 -34
  137. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  138. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  139. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  140. package/src/llama.cpp/ggml/include/ggml-cpu.h +135 -0
  141. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  142. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  143. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  144. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  145. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  146. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  147. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  148. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  149. package/src/llama.cpp/ggml/include/ggml.h +159 -417
  150. package/src/llama.cpp/ggml/src/CMakeLists.txt +121 -1155
  151. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -28
  152. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +57 -36
  153. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +552 -0
  154. package/src/llama.cpp/ggml/src/ggml-backend.cpp +306 -867
  155. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  156. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +216 -65
  157. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +76 -0
  158. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  159. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  160. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +343 -177
  161. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  162. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  163. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  164. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  165. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  166. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  167. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  168. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  169. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  170. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +336 -0
  171. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  172. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  173. package/src/llama.cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
  174. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  175. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  176. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  177. package/src/llama.cpp/ggml/src/{ggml-aarch64.c → ggml-cpu/ggml-cpu-aarch64.cpp} +1299 -246
  178. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  179. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  180. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  181. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +14 -242
  182. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  183. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  184. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  185. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  186. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  187. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +628 -0
  188. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +666 -0
  189. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +152 -0
  190. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +104 -0
  192. package/src/llama.cpp/ggml/src/ggml-impl.h +393 -22
  193. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  194. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +360 -127
  195. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +105 -0
  196. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  197. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +107 -0
  198. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  199. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  200. package/src/llama.cpp/ggml/src/ggml-opt.cpp +854 -0
  201. package/src/llama.cpp/ggml/src/ggml-quants.c +188 -10702
  202. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  203. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  204. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +478 -300
  205. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +84 -0
  206. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  207. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +36 -5
  208. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +259 -0
  209. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  210. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  211. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  212. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +34 -35
  213. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  214. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  215. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  216. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3638 -4151
  217. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  218. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  219. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -87
  220. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +7 -6
  221. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  222. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  223. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  224. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  225. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  226. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  227. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  228. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  229. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  230. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  231. package/src/llama.cpp/ggml/src/ggml-threading.h +14 -0
  232. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +92 -0
  233. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2138 -887
  234. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +3 -1
  235. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  236. package/src/llama.cpp/ggml/src/ggml.c +4427 -20125
  237. package/src/llama.cpp/include/llama-cpp.h +25 -0
  238. package/src/llama.cpp/include/llama.h +93 -52
  239. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  240. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  241. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  242. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  243. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  244. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  245. package/src/llama.cpp/src/CMakeLists.txt +4 -8
  246. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  247. package/src/llama.cpp/src/llama-grammar.h +2 -5
  248. package/src/llama.cpp/src/llama-sampling.cpp +779 -194
  249. package/src/llama.cpp/src/llama-sampling.h +21 -2
  250. package/src/llama.cpp/src/llama-vocab.cpp +55 -10
  251. package/src/llama.cpp/src/llama-vocab.h +35 -11
  252. package/src/llama.cpp/src/llama.cpp +4317 -2979
  253. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  254. package/src/llama.cpp/src/unicode.cpp +62 -51
  255. package/src/llama.cpp/src/unicode.h +9 -10
  256. package/src/llama.cpp/tests/CMakeLists.txt +48 -38
  257. package/src/llama.cpp/tests/test-arg-parser.cpp +15 -15
  258. package/src/llama.cpp/tests/test-backend-ops.cpp +324 -80
  259. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  260. package/src/llama.cpp/tests/test-chat-template.cpp +59 -9
  261. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  262. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  263. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  264. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  265. package/src/llama.cpp/tests/test-log.cpp +2 -2
  266. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  267. package/src/llama.cpp/tests/test-quantize-fns.cpp +24 -21
  268. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  269. package/src/llama.cpp/tests/test-rope.cpp +62 -20
  270. package/src/llama.cpp/tests/test-sampling.cpp +163 -138
  271. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  272. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  273. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  274. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  275. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  276. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  277. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  278. package/src/llama.cpp/common/train.cpp +0 -1515
  279. package/src/llama.cpp/common/train.h +0 -233
  280. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  281. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  282. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -39
  283. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +0 -600
  284. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  285. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  286. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
@@ -1,7 +1,9 @@
1
1
  # Threads::Threads is linked into vulkan-shaders-gen further down.
  find_package (Threads REQUIRED)
2
  # The generator now also needs the Vulkan package with its glslc component;
  # it links Vulkan::Vulkan (last line of this hunk).
+ find_package(Vulkan COMPONENTS glslc REQUIRED)
2
3
 
3
4
  set(TARGET vulkan-shaders-gen)
4
5
  add_executable(${TARGET} vulkan-shaders-gen.cpp)
5
6
  install(TARGETS ${TARGET} RUNTIME)
6
  # Language level of the generator raised from C++11 to C++17 in this change.
- target_compile_features(${TARGET} PRIVATE cxx_std_11)
7
+ target_compile_features(${TARGET} PRIVATE cxx_std_17)
7
8
  target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
9
+ target_link_libraries(vulkan-shaders-gen PRIVATE Vulkan::Vulkan)
@@ -0,0 +1,593 @@
1
+
2
+
3
#include <algorithm>
#include <array>
#include <cassert>
#include <cctype>
#include <cerrno>
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <future>
#include <iostream>
#include <map>
#include <mutex>
#include <queue>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>
#include <sys/stat.h>
#include <sys/types.h>

#ifdef _WIN32
#include <windows.h>
#include <direct.h> // For _mkdir on Windows
#include <algorithm> // For std::replace on w64devkit
#else
#include <unistd.h>
#include <sys/wait.h>
#include <fcntl.h>
#endif

#include <vulkan/vulkan_core.h>
34
+
35
// Concurrency constant; not referenced in the part of the generator shown here
// (NOTE(review): confirm it is used elsewhere before removing).
#define ASYNCIO_CONCURRENCY 64

// Guards shader_fnames, which compile workers append to as they finish.
std::mutex lock;
// One (shader variant name, compiled .spv path) pair per successful compile.
std::vector<std::pair<std::string, std::string>> shader_fnames;

// Tool / path defaults (presumably overridable from the command line; verify in main()).
std::string GLSLC{"glslc"};
std::string input_dir{"vulkan-shaders"};
std::string output_dir{"/tmp"};
std::string target_hpp{"ggml-vulkan-shaders.hpp"};
std::string target_cpp{"ggml-vulkan-shaders.cpp"};
bool no_clean{false};

// Tensor element types for which per-type shader variants are generated.
const std::vector<std::string> type_names{
    "f32", "f16",
    "q4_0", "q4_1", "q5_0", "q5_1", "q8_0",
    "q2_k", "q3_k", "q4_k", "q5_k", "q6_k",
    "iq4_nl",
};
62
+
63
+ namespace {
64
// Runs `command` and appends everything the child writes to its stdout and
// stderr to `stdout_str` and `stderr_str`.
//
// On Windows the command line goes straight to CreateProcessA; elsewhere it is
// run through `/bin/sh -c`. The child's exit status is not reported; callers
// (string_to_spv_func) treat any stderr output as failure.
//
// Both pipes are drained concurrently (stderr on a helper thread). Reading them
// one after the other can deadlock: if the child fills the stderr pipe buffer
// while the parent is still blocked reading stdout, neither side progresses.
//
// Throws std::runtime_error if the pipes or the child process cannot be
// created; every handle/descriptor opened so far is closed before throwing.
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32
    HANDLE stdout_read = NULL, stdout_write = NULL;
    HANDLE stderr_read = NULL, stderr_write = NULL;
    SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE };

    // Closes every pipe handle that is still open.
    auto close_pipes = [&]() {
        for (HANDLE* h : {&stdout_read, &stdout_write, &stderr_read, &stderr_write}) {
            if (*h != NULL) {
                CloseHandle(*h);
                *h = NULL;
            }
        }
    };

    // Read ends are made non-inheritable so the child only receives the write ends.
    if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) ||
        !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) {
        close_pipes();
        throw std::runtime_error("Failed to create stdout pipe");
    }

    if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) ||
        !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) {
        close_pipes();
        throw std::runtime_error("Failed to create stderr pipe");
    }

    PROCESS_INFORMATION pi;
    STARTUPINFOA si = { sizeof(STARTUPINFOA) };
    si.dwFlags = STARTF_USESTDHANDLES;
    si.hStdOutput = stdout_write;
    si.hStdError = stderr_write;

    // CreateProcessA takes a non-const LPSTR: pass a mutable, NUL-terminated copy.
    std::vector<char> cmd(command.begin(), command.end());
    cmd.push_back('\0');

    if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) {
        close_pipes();
        throw std::runtime_error("Failed to create process");
    }

    // Drop our copies of the write ends so ReadFile reports EOF once the child exits.
    CloseHandle(stdout_write);
    stdout_write = NULL;
    CloseHandle(stderr_write);
    stderr_write = NULL;

    auto drain = [](HANDLE h, std::string& out) {
        std::array<char, 4096> buffer;
        DWORD bytes_read = 0;
        while (ReadFile(h, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
            out.append(buffer.data(), bytes_read);
        }
    };

    std::thread stderr_reader([&] { drain(stderr_read, stderr_str); });
    drain(stdout_read, stdout_str);
    stderr_reader.join();

    close_pipes();
    WaitForSingleObject(pi.hProcess, INFINITE);
    CloseHandle(pi.hProcess);
    CloseHandle(pi.hThread);
#else
    int stdout_pipe[2] = {-1, -1};
    int stderr_pipe[2] = {-1, -1};

    // Closes every pipe descriptor that is still open.
    auto close_pipes = [&]() {
        for (int* fd : {&stdout_pipe[0], &stdout_pipe[1], &stderr_pipe[0], &stderr_pipe[1]}) {
            if (*fd >= 0) {
                close(*fd);
                *fd = -1;
            }
        }
    };

    if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) {
        close_pipes();
        throw std::runtime_error("Failed to create pipes");
    }

    // This function runs on several std::async workers at once. Marking the pipes
    // close-on-exec keeps children forked concurrently by other threads from
    // inheriting our write ends, which would delay EOF on our reads until those
    // unrelated children exit. (A small window remains between pipe() and fcntl();
    // dup2() below clears the flag on the child's own stdout/stderr.)
    for (int fd : {stdout_pipe[0], stdout_pipe[1], stderr_pipe[0], stderr_pipe[1]}) {
        fcntl(fd, F_SETFD, FD_CLOEXEC);
    }

    // Taken before fork(): the child should only make async-signal-safe calls.
    const char* cmd = command.c_str();

    pid_t pid = fork();
    if (pid < 0) {
        close_pipes();
        throw std::runtime_error("Failed to fork process");
    }

    if (pid == 0) {
        // Child: route stdout/stderr into the pipes, then exec the shell.
        close(stdout_pipe[0]);
        close(stderr_pipe[0]);
        dup2(stdout_pipe[1], STDOUT_FILENO);
        dup2(stderr_pipe[1], STDERR_FILENO);
        close(stdout_pipe[1]);
        close(stderr_pipe[1]);
        execl("/bin/sh", "sh", "-c", cmd, (char*) nullptr);
        _exit(EXIT_FAILURE);
    }

    // Parent: drop the write ends so read() reports EOF once the child exits.
    close(stdout_pipe[1]);
    stdout_pipe[1] = -1;
    close(stderr_pipe[1]);
    stderr_pipe[1] = -1;

    auto drain = [](int fd, std::string& out) {
        std::array<char, 4096> buffer;
        for (;;) {
            ssize_t bytes_read = read(fd, buffer.data(), buffer.size());
            if (bytes_read > 0) {
                out.append(buffer.data(), static_cast<size_t>(bytes_read));
            } else if (bytes_read < 0 && errno == EINTR) {
                continue;  // interrupted by a signal before any data arrived; retry
            } else {
                break;     // EOF or a genuine read error
            }
        }
    };

    std::thread stderr_reader([&] { drain(stderr_pipe[0], stderr_str); });
    drain(stdout_pipe[0], stdout_str);
    stderr_reader.join();

    close_pipes();
    while (waitpid(pid, nullptr, 0) < 0 && errno == EINTR) {
        // retry if interrupted by a signal
    }
#endif
}
155
+
156
// Returns true only if `path` exists and is a directory (stat() follows
// symlinks). Any stat() failure yields false.
bool directory_exists(const std::string& path) {
    struct stat info;
    if (stat(path.c_str(), &info) != 0) {
        return false; // Path doesn't exist or can't be accessed
    }
    // The file type is an enumerated field under S_IFMT, not a set of flags:
    // testing the S_IFDIR bit alone also matches S_IFBLK and S_IFSOCK entries.
    return (info.st_mode & S_IFMT) == S_IFDIR;
}
163
+
164
// Creates the directory `path` (mode 0755 on POSIX). Returns true when it was
// created, or when mkdir failed with EEXIST because the entry already exists.
bool create_directory(const std::string& path) {
#ifdef _WIN32
    const int rc = _mkdir(path.c_str());
#else
    const int rc = mkdir(path.c_str(), 0755); // rwxr-xr-x
#endif
    if (rc == 0) {
        return true;
    }
    // An already-existing entry is not reported as a failure.
    return errno == EEXIST;
}
171
+
172
// Returns a copy of `input` with each character passed through std::toupper
// (current C locale). Used to build define names such as DATA_A_Q4_K.
std::string to_uppercase(const std::string& input) {
    std::string result = input;
    for (char& c : result) {
        // std::toupper requires a value representable as unsigned char (or EOF);
        // a negative plain char (bytes >= 0x80 where char is signed) is UB.
        c = static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
    }
    return result;
}
179
+
180
// True when `str` ends with `suffix`; an empty suffix always matches.
bool string_ends_with(const std::string& str, const std::string& suffix) {
    const std::size_t tail = suffix.size();
    return str.size() >= tail && str.compare(str.size() - tail, tail, suffix) == 0;
}
186
+
187
// Separator inserted between path components by join_paths().
static const char path_separator = '/';

// Concatenates `path1` and `path2` with exactly one path_separator between them
// (no normalisation of existing separators is attempted).
std::string join_paths(const std::string& path1, const std::string& path2) {
    std::string joined;
    joined.reserve(path1.size() + 1 + path2.size());
    joined += path1;
    joined.push_back(path_separator);
    joined += path2;
    return joined;
}
192
+
193
// Returns the final path component, splitting on either '/' or '\\'.
// A path without separators is returned whole; a trailing separator yields "".
std::string basename(const std::string &path) {
    const std::string::size_type last_sep = path.find_last_of("/\\");
    return last_sep == std::string::npos ? path : path.substr(last_sep + 1);
}
196
+
197
// Number of compile slots currently reserved (string_to_spv() increments it,
// the compile worker decrements it). Guarded by compile_count_mutex;
// compile_count_cond is notified whenever it drops.
static uint32_t compile_count{0};
static std::mutex compile_count_mutex;
static std::condition_variable compile_count_cond;
201
+
202
+ void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
203
+ std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
204
+ std::string out_fname = join_paths(output_dir, name + ".spv");
205
+ std::string in_path = join_paths(input_dir, in_fname);
206
+
207
+ std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
208
+
209
+ // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
210
+ std::string opt_level = coopmat ? "" : "-O";
211
+
212
+ #ifdef _WIN32
213
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
214
+ #else
215
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_fname};
216
+ #endif
217
+
218
+ #ifdef GGML_VULKAN_SHADER_DEBUG_INFO
219
+ cmd.push_back("-g");
220
+ #endif
221
+
222
+ for (const auto& define : defines) {
223
+ cmd.push_back("-D" + define.first + "=" + define.second);
224
+ }
225
+
226
+ std::string command;
227
+ for (const auto& part : cmd) {
228
+ command += part + " ";
229
+ }
230
+
231
+ std::string stdout_str, stderr_str;
232
+ try {
233
+ // std::cout << "Executing command: ";
234
+ // for (const auto& part : cmd) {
235
+ // std::cout << part << " ";
236
+ // }
237
+ // std::cout << std::endl;
238
+
239
+ execute_command(command, stdout_str, stderr_str);
240
+ if (!stderr_str.empty()) {
241
+ std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl;
242
+ return;
243
+ }
244
+
245
+ std::lock_guard<std::mutex> guard(lock);
246
+ shader_fnames.push_back(std::make_pair(name, out_fname));
247
+ } catch (const std::exception& e) {
248
+ std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
249
+ }
250
+ {
251
+ std::lock_guard<std::mutex> guard(compile_count_mutex);
252
+ assert(compile_count > 0);
253
+ compile_count--;
254
+ }
255
+ compile_count_cond.notify_all();
256
+ }
257
+
258
// Returns the union of `a` and `b`. When a key appears in both maps the value
// from `a` is kept (entries from `b` never overwrite).
std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
    std::map<std::string, std::string> merged(a);
    for (const auto& entry : b) {
        merged.emplace(entry.first, entry.second); // no-op if the key is already present
    }
    return merged;
}
263
+
264
+ static std::vector<std::future<void>> compiles;
265
+ void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
266
+ {
267
+ // wait until fewer than N compiles are in progress.
268
+ // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
269
+ uint32_t N = 16;
270
+ std::unique_lock<std::mutex> guard(compile_count_mutex);
271
+ while (compile_count >= N) {
272
+ compile_count_cond.wait(guard);
273
+ }
274
+ compile_count++;
275
+ }
276
+ compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
277
+ }
278
+
279
+ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
280
+ std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
281
+ std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
282
+ std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
283
+
284
+ std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}};
285
+ std::string shader_name = "matmul";
286
+
287
+ if (matmul_id) {
288
+ base_dict["MUL_MAT_ID"] = "1";
289
+ shader_name = "matmul_id";
290
+ }
291
+
292
+ if (fp16) {
293
+ base_dict["FLOAT16"] = "1";
294
+ }
295
+
296
+ base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
297
+
298
+ if (coopmat) {
299
+ base_dict["COOPMAT"] = "1";
300
+ }
301
+
302
+ base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
303
+
304
+ std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
305
+
306
+ // Shaders with f16 B_TYPE
307
+ string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
308
+ string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
309
+
310
+ string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
311
+ string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
312
+
313
+ for (const auto& tname : type_names) {
314
+ std::string data_a_key = "DATA_A_" + to_uppercase(tname);
315
+ // For unaligned, load one at a time for f32/f16, or two at a time for quants
316
+ std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
317
+ // For aligned matmul loads
318
+ std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
319
+
320
+ string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
321
+ string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
322
+
323
+ if (tname != "f16" && tname != "f32") {
324
+ string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
325
+ string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
326
+ }
327
+ }
328
+ }
329
+
330
+ void process_shaders() {
331
+ std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
332
+ std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
333
+
334
+ // matmul
335
+ for (const auto& matmul_id : {false, true}) {
336
+ // No coopmats
337
+ // fp32
338
+ matmul_shaders(false, matmul_id, false, false, false);
339
+
340
+ // fp16, fp32acc and fp16acc
341
+ matmul_shaders(true, matmul_id, false, false, false);
342
+ matmul_shaders(true, matmul_id, false, false, true);
343
+
344
+ // Coopmat, fp32acc and fp16acc
345
+ matmul_shaders(true, matmul_id, true, false, false);
346
+ matmul_shaders(true, matmul_id, true, false, true);
347
+
348
+ #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
349
+ // Coopmat2, fp32acc and fp16acc
350
+ matmul_shaders(true, matmul_id, false, true, false);
351
+ matmul_shaders(true, matmul_id, false, true, true);
352
+ #endif
353
+ }
354
+
355
+ #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
356
+ // flash attention
357
+ for (const auto& f16acc : {false, true}) {
358
+ std::string acctype = f16acc ? "float16_t" : "float";
359
+
360
+ for (const auto& tname : type_names) {
361
+ if (tname == "f32") {
362
+ continue;
363
+ }
364
+
365
+ if (tname == "f16") {
366
+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
367
+ merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
368
+ } else {
369
+ std::string data_a_key = "DATA_A_" + to_uppercase(tname);
370
+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
371
+ merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
372
+ }
373
+ }
374
+ }
375
+ #endif
376
+
377
+ for (const auto& tname : type_names) {
378
+ // mul mat vec
379
+ std::string data_a_key = "DATA_A_" + to_uppercase(tname);
380
+ std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
381
+
382
+ string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
383
+ string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
384
+
385
+ string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
386
+
387
+ // Dequant shaders
388
+ if (tname != "f16") {
389
+ string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
390
+ }
391
+
392
+ if (!string_ends_with(tname, "_k")) {
393
+ shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
394
+
395
+ if (tname == "f16") {
396
+ string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
397
+ } else {
398
+ string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
399
+ }
400
+ string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
401
+ }
402
+ }
403
+
404
+ string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
405
+ string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
406
+
407
+ // Norms
408
+ string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
409
+ string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
410
+ string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
411
+
412
+ string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
413
+ string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
414
+ string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
415
+ string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
416
+ string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
417
+ string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
418
+
419
+ string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
420
+ string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
421
+
422
+ string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
423
+
424
+ string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
425
+
426
+ string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
427
+
428
+ string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
429
+
430
+ string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
431
+
432
+ string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
433
+
434
+ string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
435
+
436
+ string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
437
+
438
+ string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
439
+
440
+ string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
441
+
442
+ string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
443
+
444
+ string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
445
+ string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
446
+ string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}});
447
+
448
+ string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
449
+
450
+ string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
451
+ string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
452
+ string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
453
+ string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
454
+ string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
455
+ string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
456
+
457
+ string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
458
+
459
+ string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
460
+ string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
461
+
462
+ string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
463
+ string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
464
+ string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
465
+
466
+ string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
467
+ string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
468
+ string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
469
+
470
+ string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
471
+
472
+ string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
473
+
474
+ string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
475
+ string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
476
+ string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
477
+
478
+ string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
479
+
480
+ string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
481
+
482
+ string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
483
+
484
+ for (auto &c : compiles) {
485
+ c.wait();
486
+ }
487
+ }
488
+
489
+ void write_output_files() {
490
+ FILE* hdr = fopen(target_hpp.c_str(), "w");
491
+ FILE* src = fopen(target_cpp.c_str(), "w");
492
+
493
+ fprintf(hdr, "#include <cstdint>\n\n");
494
+ fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str());
495
+
496
+ for (const auto& pair : shader_fnames) {
497
+ const std::string& name = pair.first;
498
+ #ifdef _WIN32
499
+ std::string path = pair.second;
500
+ std::replace(path.begin(), path.end(), '/', '\\' );
501
+ #else
502
+ const std::string& path = pair.second;
503
+ #endif
504
+
505
+ FILE* spv = fopen(path.c_str(), "rb");
506
+ if (!spv) {
507
+ std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
508
+ continue;
509
+ }
510
+
511
+ fseek(spv, 0, SEEK_END);
512
+ size_t size = ftell(spv);
513
+ fseek(spv, 0, SEEK_SET);
514
+
515
+ std::vector<unsigned char> data(size);
516
+ size_t read_size = fread(data.data(), 1, size, spv);
517
+ fclose(spv);
518
+ if (read_size != size) {
519
+ std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
520
+ continue;
521
+ }
522
+
523
+ fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size);
524
+ fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size);
525
+
526
+ fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size);
527
+ for (size_t i = 0; i < size; ++i) {
528
+ fprintf(src, "0x%02x,", data[i]);
529
+ if ((i + 1) % 12 == 0) fprintf(src, "\n");
530
+ }
531
+ fprintf(src, "\n};\n\n");
532
+
533
+ if (!no_clean) {
534
+ std::remove(path.c_str());
535
+ }
536
+ }
537
+
538
+ fclose(hdr);
539
+ fclose(src);
540
+ }
541
+ }
542
+
543
+ int main(int argc, char** argv) {
544
+ std::map<std::string, std::string> args;
545
+ for (int i = 1; i < argc; ++i) {
546
+ std::string arg = argv[i];
547
+ if (arg.rfind("--", 0) == 0) {
548
+ if (i + 1 < argc && argv[i + 1][0] != '-') {
549
+ args[arg] = argv[i + 1];
550
+ ++i;
551
+ } else {
552
+ args[arg] = "";
553
+ }
554
+ }
555
+ }
556
+
557
+ if (args.find("--glslc") != args.end()) {
558
+ GLSLC = args["--glslc"]; // Path to glslc
559
+ }
560
+ if (args.find("--input-dir") != args.end()) {
561
+ input_dir = args["--input-dir"]; // Directory containing shader sources
562
+ }
563
+ if (args.find("--output-dir") != args.end()) {
564
+ output_dir = args["--output-dir"]; // Directory for containing SPIR-V output
565
+ }
566
+ if (args.find("--target-hpp") != args.end()) {
567
+ target_hpp = args["--target-hpp"]; // Path to generated header file
568
+ }
569
+ if (args.find("--target-cpp") != args.end()) {
570
+ target_cpp = args["--target-cpp"]; // Path to generated cpp file
571
+ }
572
+ if (args.find("--no-clean") != args.end()) {
573
+ no_clean = true; // Keep temporary SPIR-V files in output-dir after build
574
+ }
575
+
576
+ if (!directory_exists(input_dir)) {
577
+ std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl;
578
+ return EXIT_FAILURE;
579
+ }
580
+
581
+ if (!directory_exists(output_dir)) {
582
+ if (!create_directory(output_dir)) {
583
+ std::cerr << "Error creating output directory: " << output_dir << "\n";
584
+ return EXIT_FAILURE;
585
+ }
586
+ }
587
+
588
+ process_shaders();
589
+
590
+ write_output_files();
591
+
592
+ return EXIT_SUCCESS;
593
+ }