@fugood/llama.node 0.3.2 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (286) hide show
  1. package/CMakeLists.txt +7 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/DetokenizeWorker.cpp +1 -1
  19. package/src/EmbeddingWorker.cpp +17 -7
  20. package/src/EmbeddingWorker.h +2 -1
  21. package/src/LlamaCompletionWorker.cpp +8 -8
  22. package/src/LlamaCompletionWorker.h +2 -2
  23. package/src/LlamaContext.cpp +89 -27
  24. package/src/LlamaContext.h +2 -0
  25. package/src/TokenizeWorker.cpp +1 -1
  26. package/src/common.hpp +4 -4
  27. package/src/llama.cpp/.github/workflows/build.yml +240 -168
  28. package/src/llama.cpp/.github/workflows/docker.yml +8 -8
  29. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  31. package/src/llama.cpp/CMakeLists.txt +14 -6
  32. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/cmake/common.cmake +33 -0
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  36. package/src/llama.cpp/common/CMakeLists.txt +6 -4
  37. package/src/llama.cpp/common/arg.cpp +986 -770
  38. package/src/llama.cpp/common/arg.h +22 -22
  39. package/src/llama.cpp/common/common.cpp +212 -351
  40. package/src/llama.cpp/common/common.h +204 -117
  41. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  42. package/src/llama.cpp/common/log.cpp +50 -50
  43. package/src/llama.cpp/common/log.h +18 -18
  44. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  45. package/src/llama.cpp/common/ngram-cache.h +19 -19
  46. package/src/llama.cpp/common/sampling.cpp +163 -121
  47. package/src/llama.cpp/common/sampling.h +41 -20
  48. package/src/llama.cpp/common/speculative.cpp +274 -0
  49. package/src/llama.cpp/common/speculative.h +28 -0
  50. package/src/llama.cpp/docs/build.md +134 -161
  51. package/src/llama.cpp/examples/CMakeLists.txt +33 -14
  52. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/batched/batched.cpp +19 -18
  54. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  55. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  56. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  57. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  58. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  60. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  61. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  63. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  64. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  65. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  66. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  67. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  68. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  69. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  70. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  71. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  72. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  73. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  75. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  76. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  77. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  78. package/src/llama.cpp/examples/imatrix/imatrix.cpp +31 -13
  79. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  80. package/src/llama.cpp/examples/infill/infill.cpp +41 -87
  81. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +439 -459
  83. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +2 -0
  84. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  85. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  86. package/src/llama.cpp/examples/llava/clip.cpp +263 -66
  87. package/src/llama.cpp/examples/llava/clip.h +8 -2
  88. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  89. package/src/llama.cpp/examples/llava/llava.cpp +83 -22
  90. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  91. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  92. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  94. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  95. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  96. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  97. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +16 -15
  98. package/src/llama.cpp/examples/lookup/lookup.cpp +30 -30
  99. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  100. package/src/llama.cpp/examples/main/main.cpp +73 -114
  101. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  102. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  104. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  105. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  106. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  108. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  109. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  110. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  111. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  112. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  113. package/src/llama.cpp/examples/retrieval/retrieval.cpp +16 -16
  114. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  115. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  116. package/src/llama.cpp/examples/run/run.cpp +911 -0
  117. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  118. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +38 -21
  119. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -16
  120. package/src/llama.cpp/examples/server/server.cpp +2073 -1339
  121. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  122. package/src/llama.cpp/examples/server/utils.hpp +354 -277
  123. package/src/llama.cpp/examples/simple/CMakeLists.txt +2 -2
  124. package/src/llama.cpp/examples/simple/simple.cpp +130 -94
  125. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  126. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +200 -0
  127. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  128. package/src/llama.cpp/examples/speculative/speculative.cpp +68 -64
  129. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  130. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  131. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  132. package/src/llama.cpp/examples/tokenize/tokenize.cpp +3 -3
  133. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  134. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  135. package/src/llama.cpp/ggml/CMakeLists.txt +54 -36
  136. package/src/llama.cpp/ggml/include/ggml-backend.h +63 -34
  137. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  138. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  139. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  140. package/src/llama.cpp/ggml/include/ggml-cpu.h +135 -0
  141. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  142. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  143. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  144. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  145. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  146. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  147. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  148. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  149. package/src/llama.cpp/ggml/include/ggml.h +159 -417
  150. package/src/llama.cpp/ggml/src/CMakeLists.txt +121 -1155
  151. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -28
  152. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +57 -36
  153. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +552 -0
  154. package/src/llama.cpp/ggml/src/ggml-backend.cpp +306 -867
  155. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  156. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +216 -65
  157. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +76 -0
  158. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  159. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  160. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +343 -177
  161. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  162. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  163. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  164. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  165. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  166. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  167. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  168. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  169. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  170. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +336 -0
  171. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  172. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  173. package/src/llama.cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
  174. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  175. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  176. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  177. package/src/llama.cpp/ggml/src/{ggml-aarch64.c → ggml-cpu/ggml-cpu-aarch64.cpp} +1299 -246
  178. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  179. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  180. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  181. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +14 -242
  182. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  183. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  184. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  185. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  186. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  187. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +628 -0
  188. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +666 -0
  189. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +152 -0
  190. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +104 -0
  192. package/src/llama.cpp/ggml/src/ggml-impl.h +393 -22
  193. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  194. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +360 -127
  195. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +105 -0
  196. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  197. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +107 -0
  198. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  199. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  200. package/src/llama.cpp/ggml/src/ggml-opt.cpp +854 -0
  201. package/src/llama.cpp/ggml/src/ggml-quants.c +188 -10702
  202. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  203. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  204. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +478 -300
  205. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +84 -0
  206. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  207. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +36 -5
  208. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +259 -0
  209. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  210. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  211. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  212. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +34 -35
  213. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  214. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  215. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  216. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3638 -4151
  217. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  218. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  219. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -87
  220. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +7 -6
  221. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  222. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  223. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  224. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  225. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  226. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  227. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  228. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  229. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  230. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  231. package/src/llama.cpp/ggml/src/ggml-threading.h +14 -0
  232. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +92 -0
  233. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2138 -887
  234. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +3 -1
  235. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  236. package/src/llama.cpp/ggml/src/ggml.c +4427 -20125
  237. package/src/llama.cpp/include/llama-cpp.h +25 -0
  238. package/src/llama.cpp/include/llama.h +93 -52
  239. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  240. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  241. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  242. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  243. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  244. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  245. package/src/llama.cpp/src/CMakeLists.txt +4 -8
  246. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  247. package/src/llama.cpp/src/llama-grammar.h +2 -5
  248. package/src/llama.cpp/src/llama-sampling.cpp +779 -194
  249. package/src/llama.cpp/src/llama-sampling.h +21 -2
  250. package/src/llama.cpp/src/llama-vocab.cpp +55 -10
  251. package/src/llama.cpp/src/llama-vocab.h +35 -11
  252. package/src/llama.cpp/src/llama.cpp +4317 -2979
  253. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  254. package/src/llama.cpp/src/unicode.cpp +62 -51
  255. package/src/llama.cpp/src/unicode.h +9 -10
  256. package/src/llama.cpp/tests/CMakeLists.txt +48 -38
  257. package/src/llama.cpp/tests/test-arg-parser.cpp +15 -15
  258. package/src/llama.cpp/tests/test-backend-ops.cpp +324 -80
  259. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  260. package/src/llama.cpp/tests/test-chat-template.cpp +59 -9
  261. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  262. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  263. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  264. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  265. package/src/llama.cpp/tests/test-log.cpp +2 -2
  266. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  267. package/src/llama.cpp/tests/test-quantize-fns.cpp +24 -21
  268. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  269. package/src/llama.cpp/tests/test-rope.cpp +62 -20
  270. package/src/llama.cpp/tests/test-sampling.cpp +163 -138
  271. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  272. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  273. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  274. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  275. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  276. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  277. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  278. package/src/llama.cpp/common/train.cpp +0 -1515
  279. package/src/llama.cpp/common/train.h +0 -233
  280. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  281. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  282. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -39
  283. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +0 -600
  284. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  285. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  286. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
@@ -1,39 +0,0 @@
1
- // SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
2
- #pragma once
3
-
4
- #define GGML_COMMON_DECL_C
5
- #include "ggml-common.h"
6
-
7
- #include "ggml.h"
8
-
9
- // GGML internal header
10
-
11
- #ifdef __cplusplus
12
- extern "C" {
13
- #endif
14
-
15
- // Quantization
16
- void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
17
- void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
18
-
19
- void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave);
20
-
21
- // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
22
- size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
23
- size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
24
- size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
25
-
26
- // GEMV
27
- void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
28
- void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
29
- void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
30
-
31
- // GEMM
32
- void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
33
- void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
34
- void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
35
-
36
- #ifdef __cplusplus
37
- }
38
- #endif
39
-
@@ -1,600 +0,0 @@
1
-
2
-
3
- #include <iostream>
4
- #include <fstream>
5
- #include <sstream>
6
- #include <string>
7
- #include <stdexcept>
8
- #include <array>
9
- #include <vector>
10
- #include <map>
11
- #include <thread>
12
- #include <mutex>
13
- #include <future>
14
- #include <queue>
15
- #include <condition_variable>
16
- #include <cstdio>
17
- #include <cstring>
18
- #include <cstdlib>
19
- #include <sys/stat.h>
20
- #include <sys/types.h>
21
-
22
- #ifdef _WIN32
23
- #include <windows.h>
24
- #include <direct.h> // For _mkdir on Windows
25
- #include <algorithm> // For std::replace on w64devkit
26
- #else
27
- #include <unistd.h>
28
- #include <sys/wait.h>
29
- #include <fcntl.h>
30
- #endif
31
-
32
- #define ASYNCIO_CONCURRENCY 64
33
-
34
- std::mutex lock;
35
- std::vector<std::pair<std::string, std::string>> shader_fnames;
36
-
37
- std::string GLSLC = "glslc";
38
- std::string input_dir = "vulkan-shaders";
39
- std::string output_dir = "/tmp";
40
- std::string target_hpp = "ggml-vulkan-shaders.hpp";
41
- std::string target_cpp = "ggml-vulkan-shaders.cpp";
42
- bool no_clean = false;
43
-
44
- const std::vector<std::string> type_names = {
45
- "f32",
46
- "f16",
47
- "q4_0",
48
- "q4_1",
49
- "q5_0",
50
- "q5_1",
51
- "q8_0",
52
- "q2_k",
53
- "q3_k",
54
- "q4_k",
55
- "q5_k",
56
- "q6_k",
57
- "iq4_nl"
58
- };
59
-
60
- void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
61
- #ifdef _WIN32
62
- HANDLE stdout_read, stdout_write;
63
- HANDLE stderr_read, stderr_write;
64
- SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE };
65
-
66
- if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) ||
67
- !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) {
68
- throw std::runtime_error("Failed to create stdout pipe");
69
- }
70
-
71
- if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) ||
72
- !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) {
73
- throw std::runtime_error("Failed to create stderr pipe");
74
- }
75
-
76
- PROCESS_INFORMATION pi;
77
- STARTUPINFOA si = { sizeof(STARTUPINFOA) };
78
- si.dwFlags = STARTF_USESTDHANDLES;
79
- si.hStdOutput = stdout_write;
80
- si.hStdError = stderr_write;
81
-
82
- std::vector<char> cmd(command.begin(), command.end());
83
- cmd.push_back('\0');
84
-
85
- if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) {
86
- throw std::runtime_error("Failed to create process");
87
- }
88
-
89
- CloseHandle(stdout_write);
90
- CloseHandle(stderr_write);
91
-
92
- std::array<char, 128> buffer;
93
- DWORD bytes_read;
94
-
95
- while (ReadFile(stdout_read, buffer.data(), buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
96
- stdout_str.append(buffer.data(), bytes_read);
97
- }
98
-
99
- while (ReadFile(stderr_read, buffer.data(), buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
100
- stderr_str.append(buffer.data(), bytes_read);
101
- }
102
-
103
- CloseHandle(stdout_read);
104
- CloseHandle(stderr_read);
105
- WaitForSingleObject(pi.hProcess, INFINITE);
106
- CloseHandle(pi.hProcess);
107
- CloseHandle(pi.hThread);
108
- #else
109
- int stdout_pipe[2];
110
- int stderr_pipe[2];
111
-
112
- if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) {
113
- throw std::runtime_error("Failed to create pipes");
114
- }
115
-
116
- pid_t pid = fork();
117
- if (pid < 0) {
118
- throw std::runtime_error("Failed to fork process");
119
- }
120
-
121
- if (pid == 0) {
122
- close(stdout_pipe[0]);
123
- close(stderr_pipe[0]);
124
- dup2(stdout_pipe[1], STDOUT_FILENO);
125
- dup2(stderr_pipe[1], STDERR_FILENO);
126
- close(stdout_pipe[1]);
127
- close(stderr_pipe[1]);
128
- execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr);
129
- _exit(EXIT_FAILURE);
130
- } else {
131
- close(stdout_pipe[1]);
132
- close(stderr_pipe[1]);
133
-
134
- std::array<char, 128> buffer;
135
- ssize_t bytes_read;
136
-
137
- while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) {
138
- stdout_str.append(buffer.data(), bytes_read);
139
- }
140
-
141
- while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) {
142
- stderr_str.append(buffer.data(), bytes_read);
143
- }
144
-
145
- close(stdout_pipe[0]);
146
- close(stderr_pipe[0]);
147
- waitpid(pid, nullptr, 0);
148
- }
149
- #endif
150
- }
151
-
152
- bool directory_exists(const std::string& path) {
153
- struct stat info;
154
- if (stat(path.c_str(), &info) != 0) {
155
- return false; // Path doesn't exist or can't be accessed
156
- }
157
- return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory
158
- }
159
-
160
- bool create_directory(const std::string& path) {
161
- #ifdef _WIN32
162
- return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists
163
- #else
164
- return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the directory permissions
165
- #endif
166
- }
167
-
168
- std::string to_uppercase(const std::string& input) {
169
- std::string result = input;
170
- for (char& c : result) {
171
- c = std::toupper(c);
172
- }
173
- return result;
174
- }
175
-
176
- bool string_ends_with(const std::string& str, const std::string& suffix) {
177
- if (suffix.size() > str.size()) {
178
- return false;
179
- }
180
- return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
181
- }
182
-
183
- static const char path_separator = '/';
184
-
185
- std::string join_paths(const std::string& path1, const std::string& path2) {
186
- return path1 + path_separator + path2;
187
- }
188
-
189
- std::string basename(const std::string &path) {
190
- return path.substr(path.find_last_of("/\\") + 1);
191
- }
192
-
193
- void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) {
194
- std::string name = _name + (fp16 ? "" : "_fp32");
195
- std::string out_fname = join_paths(output_dir, name + ".spv");
196
- std::string in_path = join_paths(input_dir, in_fname);
197
-
198
- #ifdef _WIN32
199
- std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
200
- #else
201
- std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o", out_fname};
202
- #endif
203
-
204
- #ifdef GGML_VULKAN_SHADER_DEBUG_INFO
205
- cmd.push_back("-g");
206
- #endif
207
-
208
- for (const auto& define : defines) {
209
- cmd.push_back("-D" + define.first + "=" + define.second);
210
- }
211
-
212
- std::string command;
213
- for (const auto& part : cmd) {
214
- command += part + " ";
215
- }
216
-
217
- std::string stdout_str, stderr_str;
218
- try {
219
- // std::cout << "Executing command: ";
220
- // for (const auto& part : cmd) {
221
- // std::cout << part << " ";
222
- // }
223
- // std::cout << std::endl;
224
-
225
- execute_command(command, stdout_str, stderr_str);
226
- if (!stderr_str.empty()) {
227
- std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl;
228
- return;
229
- }
230
-
231
- std::lock_guard<std::mutex> guard(lock);
232
- shader_fnames.push_back(std::make_pair(name, out_fname));
233
- } catch (const std::exception& e) {
234
- std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
235
- }
236
- }
237
-
238
- std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
239
- std::map<std::string, std::string> result = a;
240
- result.insert(b.begin(), b.end());
241
- return result;
242
- }
243
-
244
- void matmul_shaders(std::vector<std::future<void>>& tasks, bool fp16, bool matmul_id) {
245
- std::string load_vec = fp16 ? "8" : "4";
246
- std::string aligned_b_type_f32 = fp16 ? "mat2x4" : "vec4";
247
- std::string aligned_b_type_f16 = fp16 ? "f16mat2x4" : "f16vec4";
248
-
249
- std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", fp16 ? "float16_t" : "float"}};
250
- std::string shader_name = "matmul";
251
-
252
- if (matmul_id) {
253
- base_dict["MUL_MAT_ID"] = "1";
254
- shader_name = "matmul_id";
255
- }
256
-
257
- if (fp16) {
258
- base_dict["FLOAT16"] = "1";
259
- }
260
-
261
- // Shaders with f16 B_TYPE
262
- tasks.push_back(std::async(std::launch::async, [=] {
263
- string_to_spv(shader_name + "_f32_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16);
264
- }));
265
- tasks.push_back(std::async(std::launch::async, [=] {
266
- string_to_spv(shader_name + "_f32_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16);
267
- }));
268
-
269
- tasks.push_back(std::async(std::launch::async, [=] {
270
- string_to_spv(shader_name + "_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16);
271
- }));
272
- tasks.push_back(std::async(std::launch::async, [=] {
273
- string_to_spv(shader_name + "_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16);
274
- }));
275
-
276
- for (const auto& tname : type_names) {
277
- std::string data_a_key = "DATA_A_" + to_uppercase(tname);
278
- // For unaligned, load one at a time for f32/f16, or two at a time for quants
279
- std::string load_vec_a_unaligned = (tname == "f32" || tname == "f16") ? "1" : "2";
280
- // For aligned matmul loads
281
- std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2";
282
- tasks.push_back(std::async(std::launch::async, [=] {
283
- string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
284
- }));
285
- tasks.push_back(std::async(std::launch::async, [=] {
286
- string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16);
287
- }));
288
- }
289
- }
290
-
291
- void process_shaders(std::vector<std::future<void>>& tasks) {
292
- std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
293
- std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
294
-
295
- for (const auto& fp16 : {false, true}) {
296
- matmul_shaders(tasks, fp16, false);
297
- matmul_shaders(tasks, fp16, true);
298
- }
299
-
300
- for (const auto& tname : type_names) {
301
- // mul mat vec
302
- std::string data_a_key = "DATA_A_" + to_uppercase(tname);
303
- std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
304
-
305
- tasks.push_back(std::async(std::launch::async, [=] {
306
- string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
307
- }));
308
- tasks.push_back(std::async(std::launch::async, [=] {
309
- string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
310
- }));
311
-
312
- tasks.push_back(std::async(std::launch::async, [=] {
313
- string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
314
- }));
315
-
316
- // Dequant shaders
317
- if (tname != "f16") {
318
- tasks.push_back(std::async(std::launch::async, [=] {
319
- string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
320
- }));
321
- }
322
-
323
- if (!string_ends_with(tname, "_k")) {
324
- shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
325
-
326
- if (tname == "f16") {
327
- tasks.push_back(std::async(std::launch::async, [=] {
328
- string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
329
- }));
330
- } else {
331
- tasks.push_back(std::async(std::launch::async, [=] {
332
- string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}});
333
- }));
334
- }
335
- tasks.push_back(std::async(std::launch::async, [=] {
336
- string_to_spv("get_rows_" + tname + "_f32", shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
337
- }));
338
- }
339
- }
340
-
341
- tasks.push_back(std::async(std::launch::async, [] {
342
- string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
343
- }));
344
- tasks.push_back(std::async(std::launch::async, [] {
345
- string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
346
- }));
347
-
348
- // Norms
349
- tasks.push_back(std::async(std::launch::async, [=] {
350
- string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
351
- }));
352
- tasks.push_back(std::async(std::launch::async, [=] {
353
- string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
354
- }));
355
- tasks.push_back(std::async(std::launch::async, [=] {
356
- string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
357
- }));
358
-
359
- tasks.push_back(std::async(std::launch::async, [] {
360
- string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
361
- }));
362
- tasks.push_back(std::async(std::launch::async, [] {
363
- string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
364
- }));
365
- tasks.push_back(std::async(std::launch::async, [] {
366
- string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
367
- }));
368
-
369
- tasks.push_back(std::async(std::launch::async, [] {
370
- string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
371
- }));
372
- tasks.push_back(std::async(std::launch::async, [] {
373
- string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
374
- }));
375
-
376
- tasks.push_back(std::async(std::launch::async, [] {
377
- string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
378
- }));
379
-
380
- tasks.push_back(std::async(std::launch::async, [] {
381
- string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
382
- }));
383
-
384
- tasks.push_back(std::async(std::launch::async, [] {
385
- string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
386
- }));
387
-
388
- tasks.push_back(std::async(std::launch::async, [] {
389
- string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
390
- }));
391
-
392
- tasks.push_back(std::async(std::launch::async, [] {
393
- string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
394
- }));
395
-
396
- tasks.push_back(std::async(std::launch::async, [] {
397
- string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
398
- }));
399
-
400
- tasks.push_back(std::async(std::launch::async, [] {
401
- string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
402
- }));
403
-
404
- tasks.push_back(std::async(std::launch::async, [] {
405
- string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
406
- }));
407
-
408
- tasks.push_back(std::async(std::launch::async, [] {
409
- string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
410
- }));
411
-
412
- tasks.push_back(std::async(std::launch::async, [] {
413
- string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
414
- }));
415
-
416
- tasks.push_back(std::async(std::launch::async, [] {
417
- string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
418
- }));
419
-
420
- tasks.push_back(std::async(std::launch::async, [] {
421
- string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
422
- }));
423
- tasks.push_back(std::async(std::launch::async, [] {
424
- string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
425
- }));
426
- tasks.push_back(std::async(std::launch::async, [] {
427
- string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}});
428
- }));
429
-
430
- tasks.push_back(std::async(std::launch::async, [] {
431
- string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
432
- }));
433
-
434
- tasks.push_back(std::async(std::launch::async, [] {
435
- string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
436
- }));
437
- tasks.push_back(std::async(std::launch::async, [] {
438
- string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
439
- }));
440
- tasks.push_back(std::async(std::launch::async, [] {
441
- string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
442
- }));
443
- tasks.push_back(std::async(std::launch::async, [] {
444
- string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
445
- }));
446
- tasks.push_back(std::async(std::launch::async, [] {
447
- string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
448
- }));
449
- tasks.push_back(std::async(std::launch::async, [] {
450
- string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
451
- }));
452
-
453
- tasks.push_back(std::async(std::launch::async, [] {
454
- string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
455
- }));
456
-
457
- tasks.push_back(std::async(std::launch::async, [=] {
458
- string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
459
- }));
460
- tasks.push_back(std::async(std::launch::async, [=] {
461
- string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
462
- }));
463
-
464
- tasks.push_back(std::async(std::launch::async, [] {
465
- string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
466
- }));
467
- tasks.push_back(std::async(std::launch::async, [] {
468
- string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
469
- }));
470
-
471
- tasks.push_back(std::async(std::launch::async, [] {
472
- string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
473
- }));
474
- tasks.push_back(std::async(std::launch::async, [] {
475
- string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
476
- }));
477
-
478
- tasks.push_back(std::async(std::launch::async, [] {
479
- string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
480
- }));
481
-
482
- tasks.push_back(std::async(std::launch::async, [=] {
483
- string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
484
- }));
485
-
486
- tasks.push_back(std::async(std::launch::async, [=] {
487
- string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
488
- }));
489
- tasks.push_back(std::async(std::launch::async, [=] {
490
- string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
491
- }));
492
-
493
- tasks.push_back(std::async(std::launch::async, [=] {
494
- string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
495
- }));
496
- }
497
-
498
- void write_output_files() {
499
- FILE* hdr = fopen(target_hpp.c_str(), "w");
500
- FILE* src = fopen(target_cpp.c_str(), "w");
501
-
502
- fprintf(hdr, "#include <cstdint>\n\n");
503
- fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str());
504
-
505
- for (const auto& pair : shader_fnames) {
506
- const std::string& name = pair.first;
507
- #ifdef _WIN32
508
- std::string path = pair.second;
509
- std::replace(path.begin(), path.end(), '/', '\\' );
510
- #else
511
- const std::string& path = pair.second;
512
- #endif
513
-
514
- FILE* spv = fopen(path.c_str(), "rb");
515
- if (!spv) {
516
- std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
517
- continue;
518
- }
519
-
520
- fseek(spv, 0, SEEK_END);
521
- size_t size = ftell(spv);
522
- fseek(spv, 0, SEEK_SET);
523
-
524
- std::vector<unsigned char> data(size);
525
- size_t read_size = fread(data.data(), 1, size, spv);
526
- fclose(spv);
527
- if (read_size != size) {
528
- std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
529
- continue;
530
- }
531
-
532
- fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size);
533
- fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size);
534
-
535
- fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size);
536
- for (size_t i = 0; i < size; ++i) {
537
- fprintf(src, "0x%02x,", data[i]);
538
- if ((i + 1) % 12 == 0) fprintf(src, "\n");
539
- }
540
- fprintf(src, "\n};\n\n");
541
-
542
- if (!no_clean) {
543
- std::remove(path.c_str());
544
- }
545
- }
546
-
547
- fclose(hdr);
548
- fclose(src);
549
- }
550
-
551
- int main(int argc, char** argv) {
552
- std::map<std::string, std::string> args;
553
- for (int i = 1; i < argc; i += 2) {
554
- if (i + 1 < argc) {
555
- args[argv[i]] = argv[i + 1];
556
- }
557
- }
558
-
559
- if (args.find("--glslc") != args.end()) {
560
- GLSLC = args["--glslc"]; // Path to glslc
561
- }
562
- if (args.find("--input-dir") != args.end()) {
563
- input_dir = args["--input-dir"]; // Directory containing shader sources
564
- }
565
- if (args.find("--output-dir") != args.end()) {
566
- output_dir = args["--output-dir"]; // Directory for containing SPIR-V output
567
- }
568
- if (args.find("--target-hpp") != args.end()) {
569
- target_hpp = args["--target-hpp"]; // Path to generated header file
570
- }
571
- if (args.find("--target-cpp") != args.end()) {
572
- target_cpp = args["--target-cpp"]; // Path to generated cpp file
573
- }
574
- if (args.find("--no-clean") != args.end()) {
575
- no_clean = true; // Keep temporary SPIR-V files in output-dir after build
576
- }
577
-
578
- if (!directory_exists(input_dir)) {
579
- std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl;
580
- return EXIT_FAILURE;
581
- }
582
-
583
- if (!directory_exists(output_dir)) {
584
- if (!create_directory(output_dir)) {
585
- std::cerr << "Error creating output directory: " << output_dir << "\n";
586
- return EXIT_FAILURE;
587
- }
588
- }
589
-
590
- std::vector<std::future<void>> tasks;
591
- process_shaders(tasks);
592
-
593
- for (auto& task : tasks) {
594
- task.get();
595
- }
596
-
597
- write_output_files();
598
-
599
- return EXIT_SUCCESS;
600
- }