@fugood/llama.node 0.3.2 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (286)
  1. package/CMakeLists.txt +7 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/DetokenizeWorker.cpp +1 -1
  19. package/src/EmbeddingWorker.cpp +17 -7
  20. package/src/EmbeddingWorker.h +2 -1
  21. package/src/LlamaCompletionWorker.cpp +8 -8
  22. package/src/LlamaCompletionWorker.h +2 -2
  23. package/src/LlamaContext.cpp +89 -27
  24. package/src/LlamaContext.h +2 -0
  25. package/src/TokenizeWorker.cpp +1 -1
  26. package/src/common.hpp +4 -4
  27. package/src/llama.cpp/.github/workflows/build.yml +240 -168
  28. package/src/llama.cpp/.github/workflows/docker.yml +8 -8
  29. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  31. package/src/llama.cpp/CMakeLists.txt +14 -6
  32. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/cmake/common.cmake +33 -0
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  36. package/src/llama.cpp/common/CMakeLists.txt +6 -4
  37. package/src/llama.cpp/common/arg.cpp +986 -770
  38. package/src/llama.cpp/common/arg.h +22 -22
  39. package/src/llama.cpp/common/common.cpp +212 -351
  40. package/src/llama.cpp/common/common.h +204 -117
  41. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  42. package/src/llama.cpp/common/log.cpp +50 -50
  43. package/src/llama.cpp/common/log.h +18 -18
  44. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  45. package/src/llama.cpp/common/ngram-cache.h +19 -19
  46. package/src/llama.cpp/common/sampling.cpp +163 -121
  47. package/src/llama.cpp/common/sampling.h +41 -20
  48. package/src/llama.cpp/common/speculative.cpp +274 -0
  49. package/src/llama.cpp/common/speculative.h +28 -0
  50. package/src/llama.cpp/docs/build.md +134 -161
  51. package/src/llama.cpp/examples/CMakeLists.txt +33 -14
  52. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/batched/batched.cpp +19 -18
  54. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  55. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  56. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  57. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  58. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  60. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  61. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  63. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  64. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  65. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  66. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  67. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  68. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  69. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  70. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  71. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  72. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  73. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  75. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  76. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  77. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  78. package/src/llama.cpp/examples/imatrix/imatrix.cpp +31 -13
  79. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  80. package/src/llama.cpp/examples/infill/infill.cpp +41 -87
  81. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +439 -459
  83. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +2 -0
  84. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  85. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  86. package/src/llama.cpp/examples/llava/clip.cpp +263 -66
  87. package/src/llama.cpp/examples/llava/clip.h +8 -2
  88. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  89. package/src/llama.cpp/examples/llava/llava.cpp +83 -22
  90. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  91. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  92. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  94. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  95. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  96. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  97. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +16 -15
  98. package/src/llama.cpp/examples/lookup/lookup.cpp +30 -30
  99. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  100. package/src/llama.cpp/examples/main/main.cpp +73 -114
  101. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  102. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  104. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  105. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  106. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  108. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  109. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  110. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  111. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  112. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  113. package/src/llama.cpp/examples/retrieval/retrieval.cpp +16 -16
  114. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  115. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  116. package/src/llama.cpp/examples/run/run.cpp +911 -0
  117. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  118. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +38 -21
  119. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -16
  120. package/src/llama.cpp/examples/server/server.cpp +2073 -1339
  121. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  122. package/src/llama.cpp/examples/server/utils.hpp +354 -277
  123. package/src/llama.cpp/examples/simple/CMakeLists.txt +2 -2
  124. package/src/llama.cpp/examples/simple/simple.cpp +130 -94
  125. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  126. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +200 -0
  127. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  128. package/src/llama.cpp/examples/speculative/speculative.cpp +68 -64
  129. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  130. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  131. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  132. package/src/llama.cpp/examples/tokenize/tokenize.cpp +3 -3
  133. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  134. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  135. package/src/llama.cpp/ggml/CMakeLists.txt +54 -36
  136. package/src/llama.cpp/ggml/include/ggml-backend.h +63 -34
  137. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  138. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  139. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  140. package/src/llama.cpp/ggml/include/ggml-cpu.h +135 -0
  141. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  142. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  143. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  144. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  145. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  146. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  147. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  148. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  149. package/src/llama.cpp/ggml/include/ggml.h +159 -417
  150. package/src/llama.cpp/ggml/src/CMakeLists.txt +121 -1155
  151. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -28
  152. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +57 -36
  153. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +552 -0
  154. package/src/llama.cpp/ggml/src/ggml-backend.cpp +306 -867
  155. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  156. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +216 -65
  157. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +76 -0
  158. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  159. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  160. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +343 -177
  161. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  162. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  163. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  164. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  165. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  166. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  167. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  168. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  169. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  170. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +336 -0
  171. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  172. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  173. package/src/llama.cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
  174. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  175. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  176. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  177. package/src/llama.cpp/ggml/src/{ggml-aarch64.c → ggml-cpu/ggml-cpu-aarch64.cpp} +1299 -246
  178. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  179. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  180. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  181. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +14 -242
  182. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  183. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  184. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  185. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  186. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  187. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +628 -0
  188. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +666 -0
  189. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +152 -0
  190. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +104 -0
  192. package/src/llama.cpp/ggml/src/ggml-impl.h +393 -22
  193. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  194. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +360 -127
  195. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +105 -0
  196. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  197. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +107 -0
  198. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  199. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  200. package/src/llama.cpp/ggml/src/ggml-opt.cpp +854 -0
  201. package/src/llama.cpp/ggml/src/ggml-quants.c +188 -10702
  202. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  203. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  204. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +478 -300
  205. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +84 -0
  206. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  207. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +36 -5
  208. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +259 -0
  209. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  210. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  211. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  212. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +34 -35
  213. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  214. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  215. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  216. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3638 -4151
  217. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  218. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  219. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -87
  220. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +7 -6
  221. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  222. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  223. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  224. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  225. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  226. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  227. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  228. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  229. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  230. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  231. package/src/llama.cpp/ggml/src/ggml-threading.h +14 -0
  232. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +92 -0
  233. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2138 -887
  234. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +3 -1
  235. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  236. package/src/llama.cpp/ggml/src/ggml.c +4427 -20125
  237. package/src/llama.cpp/include/llama-cpp.h +25 -0
  238. package/src/llama.cpp/include/llama.h +93 -52
  239. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  240. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  241. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  242. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  243. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  244. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  245. package/src/llama.cpp/src/CMakeLists.txt +4 -8
  246. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  247. package/src/llama.cpp/src/llama-grammar.h +2 -5
  248. package/src/llama.cpp/src/llama-sampling.cpp +779 -194
  249. package/src/llama.cpp/src/llama-sampling.h +21 -2
  250. package/src/llama.cpp/src/llama-vocab.cpp +55 -10
  251. package/src/llama.cpp/src/llama-vocab.h +35 -11
  252. package/src/llama.cpp/src/llama.cpp +4317 -2979
  253. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  254. package/src/llama.cpp/src/unicode.cpp +62 -51
  255. package/src/llama.cpp/src/unicode.h +9 -10
  256. package/src/llama.cpp/tests/CMakeLists.txt +48 -38
  257. package/src/llama.cpp/tests/test-arg-parser.cpp +15 -15
  258. package/src/llama.cpp/tests/test-backend-ops.cpp +324 -80
  259. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  260. package/src/llama.cpp/tests/test-chat-template.cpp +59 -9
  261. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  262. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  263. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  264. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  265. package/src/llama.cpp/tests/test-log.cpp +2 -2
  266. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  267. package/src/llama.cpp/tests/test-quantize-fns.cpp +24 -21
  268. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  269. package/src/llama.cpp/tests/test-rope.cpp +62 -20
  270. package/src/llama.cpp/tests/test-sampling.cpp +163 -138
  271. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  272. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  273. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  274. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  275. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  276. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  277. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  278. package/src/llama.cpp/common/train.cpp +0 -1515
  279. package/src/llama.cpp/common/train.h +0 -233
  280. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  281. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  282. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -39
  283. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +0 -600
  284. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  285. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  286. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
package/src/llama.cpp/examples/run/run.cpp (new file)
@@ -0,0 +1,911 @@
+ #if defined(_WIN32)
+ # include <windows.h>
+ #else
+ # include <sys/file.h>
+ # include <sys/ioctl.h>
+ # include <unistd.h>
+ #endif
+
+ #if defined(LLAMA_USE_CURL)
+ # include <curl/curl.h>
+ #endif
+
+ #include <climits>
+ #include <cstdarg>
+ #include <cstdio>
+ #include <cstring>
+ #include <filesystem>
+ #include <iostream>
+ #include <sstream>
+ #include <string>
+ #include <vector>
+
+ #include "common.h"
+ #include "json.hpp"
+ #include "llama-cpp.h"
+
+ GGML_ATTRIBUTE_FORMAT(1, 2)
+ static std::string fmt(const char * fmt, ...) {
+     va_list ap;
+     va_list ap2;
+     va_start(ap, fmt);
+     va_copy(ap2, ap);
+     const int size = vsnprintf(NULL, 0, fmt, ap);
+     GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+     std::string buf;
+     buf.resize(size);
+     const int size2 = vsnprintf(const_cast<char *>(buf.data()), buf.size() + 1, fmt, ap2);
+     GGML_ASSERT(size2 == size);
+     va_end(ap2);
+     va_end(ap);
+
+     return buf;
+ }
+
+ GGML_ATTRIBUTE_FORMAT(1, 2)
+ static int printe(const char * fmt, ...) {
+     va_list args;
+     va_start(args, fmt);
+     const int ret = vfprintf(stderr, fmt, args);
+     va_end(args);
+
+     return ret;
+ }
+
+ class Opt {
+   public:
+     int init(int argc, const char ** argv) {
+         // Parse arguments
+         if (parse(argc, argv)) {
+             printe("Error: Failed to parse arguments.\n");
+             help();
+             return 1;
+         }
+
+         // If help is requested, show help and exit
+         if (help_) {
+             help();
+             return 2;
+         }
+
+         return 0; // Success
+     }
+
+     std::string model_;
+     std::string user_;
+     int context_size_ = -1, ngl_ = -1;
+     bool verbose_ = false;
+
+   private:
+     bool help_ = false;
+
+     bool parse_flag(const char ** argv, int i, const char * short_opt, const char * long_opt) {
+         return strcmp(argv[i], short_opt) == 0 || strcmp(argv[i], long_opt) == 0;
+     }
+
+     int handle_option_with_value(int argc, const char ** argv, int & i, int & option_value) {
+         if (i + 1 >= argc) {
+             return 1;
+         }
+
+         option_value = std::atoi(argv[++i]);
+         return 0;
+     }
+
+     int parse(int argc, const char ** argv) {
+         bool options_parsing = true;
+         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
+             if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
+                 if (handle_option_with_value(argc, argv, i, context_size_) == 1) {
+                     return 1;
+                 }
+             } else if (options_parsing && (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0)) {
+                 if (handle_option_with_value(argc, argv, i, ngl_) == 1) {
+                     return 1;
+                 }
+             } else if (options_parsing &&
+                        (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
+                 verbose_ = true;
+             } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
+                 help_ = true;
+                 return 0;
+             } else if (options_parsing && strcmp(argv[i], "--") == 0) {
+                 options_parsing = false;
+             } else if (positional_args_i == 0) {
+                 if (!argv[i][0] || argv[i][0] == '-') {
+                     return 1;
+                 }
+
+                 ++positional_args_i;
+                 model_ = argv[i];
+             } else if (positional_args_i == 1) {
+                 ++positional_args_i;
+                 user_ = argv[i];
+             } else {
+                 user_ += " " + std::string(argv[i]);
+             }
+         }
+
+         return 0;
+     }
+
+     void help() const {
+         printf(
+             "Description:\n"
+             " Runs a llm\n"
+             "\n"
+             "Usage:\n"
+             " llama-run [options] model [prompt]\n"
+             "\n"
+             "Options:\n"
+             " -c, --context-size <value>\n"
+             " Context size (default: %d)\n"
+             " -n, --ngl <value>\n"
+             " Number of GPU layers (default: %d)\n"
+             " -v, --verbose, --log-verbose\n"
+             " Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n"
+             " -h, --help\n"
+             " Show help message\n"
+             "\n"
+             "Commands:\n"
+             " model\n"
+             " Model is a string with an optional prefix of \n"
+             " huggingface:// (hf://), ollama://, https:// or file://.\n"
+             " If no protocol is specified and a file exists in the specified\n"
+             " path, file:// is assumed, otherwise if a file does not exist in\n"
+             " the specified path, ollama:// is assumed. Models that are being\n"
+             " pulled are downloaded with .partial extension while being\n"
+             " downloaded and then renamed as the file without the .partial\n"
+             " extension when complete.\n"
+             "\n"
+             "Examples:\n"
+             " llama-run llama3\n"
+             " llama-run ollama://granite-code\n"
+             " llama-run ollama://smollm:135m\n"
+             " llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf\n"
+             " llama-run "
+             "huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf\n"
+             " llama-run https://example.com/some-file1.gguf\n"
+             " llama-run some-file2.gguf\n"
+             " llama-run file://some-file3.gguf\n"
+             " llama-run --ngl 999 some-file4.gguf\n"
+             " llama-run --ngl 999 some-file5.gguf Hello World\n",
+             llama_context_default_params().n_batch, llama_model_default_params().n_gpu_layers);
+     }
+ };
+
+ struct progress_data {
+     size_t file_size = 0;
+     std::chrono::steady_clock::time_point start_time = std::chrono::steady_clock::now();
+     bool printed = false;
+ };
+
+ static int get_terminal_width() {
+ #if defined(_WIN32)
+     CONSOLE_SCREEN_BUFFER_INFO csbi;
+     GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi);
+     return csbi.srWindow.Right - csbi.srWindow.Left + 1;
+ #else
+     struct winsize w;
+     ioctl(STDOUT_FILENO, TIOCGWINSZ, &w);
+     return w.ws_col;
+ #endif
+ }
+
+ #ifdef LLAMA_USE_CURL
+ class File {
+   public:
+     FILE * file = nullptr;
+
+     FILE * open(const std::string & filename, const char * mode) {
+         file = fopen(filename.c_str(), mode);
+
+         return file;
+     }
+
+     int lock() {
+         if (file) {
+ # ifdef _WIN32
+             fd = _fileno(file);
+             hFile = (HANDLE) _get_osfhandle(fd);
+             if (hFile == INVALID_HANDLE_VALUE) {
+                 fd = -1;
+
+                 return 1;
+             }
+
+             OVERLAPPED overlapped = { 0 };
+             if (!LockFileEx(hFile, LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY, 0, MAXDWORD, MAXDWORD,
+                             &overlapped)) {
+                 fd = -1;
+
+                 return 1;
+             }
+ # else
+             fd = fileno(file);
+             if (flock(fd, LOCK_EX | LOCK_NB) != 0) {
+                 fd = -1;
+
+                 return 1;
+             }
+ # endif
+         }
+
+         return 0;
+     }
+
+     ~File() {
+         if (fd >= 0) {
+ # ifdef _WIN32
+             if (hFile != INVALID_HANDLE_VALUE) {
+                 OVERLAPPED overlapped = { 0 };
+                 UnlockFileEx(hFile, 0, MAXDWORD, MAXDWORD, &overlapped);
+             }
+ # else
+             flock(fd, LOCK_UN);
+ # endif
+         }
+
+         if (file) {
+             fclose(file);
+         }
+     }
+
+   private:
+     int fd = -1;
+ # ifdef _WIN32
+     HANDLE hFile;
+ # endif
+ };
+
+ class HttpClient {
+   public:
+     int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
+              const bool progress, std::string * response_str = nullptr) {
+         std::string output_file_partial;
+         curl = curl_easy_init();
+         if (!curl) {
+             return 1;
+         }
+
+         progress_data data;
+         File out;
+         if (!output_file.empty()) {
+             output_file_partial = output_file + ".partial";
+             if (!out.open(output_file_partial, "ab")) {
+                 printe("Failed to open file\n");
+
+                 return 1;
+             }
+
+             if (out.lock()) {
+                 printe("Failed to exclusively lock file\n");
+
+                 return 1;
+             }
+         }
+
+         set_write_options(response_str, out);
+         data.file_size = set_resume_point(output_file_partial);
+         set_progress_options(progress, data);
+         set_headers(headers);
+         perform(url);
+         if (!output_file.empty()) {
+             std::filesystem::rename(output_file_partial, output_file);
+         }
+
+         return 0;
+     }
+
+     ~HttpClient() {
+         if (chunk) {
+             curl_slist_free_all(chunk);
+         }
+
+         if (curl) {
+             curl_easy_cleanup(curl);
+         }
+     }
+
+   private:
+     CURL * curl = nullptr;
+     struct curl_slist * chunk = nullptr;
+
+     void set_write_options(std::string * response_str, const File & out) {
+         if (response_str) {
+             curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, capture_data);
+             curl_easy_setopt(curl, CURLOPT_WRITEDATA, response_str);
+         } else {
+             curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data);
+             curl_easy_setopt(curl, CURLOPT_WRITEDATA, out.file);
+         }
+     }
+
+     size_t set_resume_point(const std::string & output_file) {
+         size_t file_size = 0;
+         if (std::filesystem::exists(output_file)) {
+             file_size = std::filesystem::file_size(output_file);
+             curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE, static_cast<curl_off_t>(file_size));
+         }
+
+         return file_size;
+     }
+
+     void set_progress_options(bool progress, progress_data & data) {
+         if (progress) {
+             curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
+             curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &data);
+             curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, update_progress);
+         }
+     }
+
+     void set_headers(const std::vector<std::string> & headers) {
+         if (!headers.empty()) {
+             if (chunk) {
+                 curl_slist_free_all(chunk);
+                 chunk = 0;
+             }
+
+             for (const auto & header : headers) {
+                 chunk = curl_slist_append(chunk, header.c_str());
+             }
+
+             curl_easy_setopt(curl, CURLOPT_HTTPHEADER, chunk);
+         }
+     }
+
+     void perform(const std::string & url) {
+         CURLcode res;
+         curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
+         curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
+         curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https");
+         curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L);
+         res = curl_easy_perform(curl);
+         if (res != CURLE_OK) {
+             printe("curl_easy_perform() failed: %s\n", curl_easy_strerror(res));
+         }
+     }
+
+     static std::string human_readable_time(double seconds) {
+         int hrs = static_cast<int>(seconds) / 3600;
+         int mins = (static_cast<int>(seconds) % 3600) / 60;
+         int secs = static_cast<int>(seconds) % 60;
+
+         if (hrs > 0) {
+             return fmt("%dh %02dm %02ds", hrs, mins, secs);
+         } else if (mins > 0) {
+             return fmt("%dm %02ds", mins, secs);
+         } else {
+             return fmt("%ds", secs);
+         }
+     }
+
+     static std::string human_readable_size(curl_off_t size) {
+         static const char * suffix[] = { "B", "KB", "MB", "GB", "TB" };
+         char length = sizeof(suffix) / sizeof(suffix[0]);
+         int i = 0;
+         double dbl_size = size;
+         if (size > 1024) {
+             for (i = 0; (size / 1024) > 0 && i < length - 1; i++, size /= 1024) {
+                 dbl_size = size / 1024.0;
+             }
+         }
+
+         return fmt("%.2f %s", dbl_size, suffix[i]);
+     }
+
+     static int update_progress(void * ptr, curl_off_t total_to_download, curl_off_t now_downloaded, curl_off_t,
+                                curl_off_t) {
+         progress_data * data = static_cast<progress_data *>(ptr);
+         if (total_to_download <= 0) {
+             return 0;
+         }
+
+         total_to_download += data->file_size;
+         const curl_off_t now_downloaded_plus_file_size = now_downloaded + data->file_size;
+         const curl_off_t percentage = calculate_percentage(now_downloaded_plus_file_size, total_to_download);
+         std::string progress_prefix = generate_progress_prefix(percentage);
+
+         const double speed = calculate_speed(now_downloaded, data->start_time);
+         const double tim = (total_to_download - now_downloaded) / speed;
+         std::string progress_suffix =
+             generate_progress_suffix(now_downloaded_plus_file_size, total_to_download, speed, tim);
+
+         int progress_bar_width = calculate_progress_bar_width(progress_prefix, progress_suffix);
+         std::string progress_bar;
+         generate_progress_bar(progress_bar_width, percentage, progress_bar);
+
+         print_progress(progress_prefix, progress_bar, progress_suffix);
+         data->printed = true;
+
+         return 0;
+     }
+
+     static curl_off_t calculate_percentage(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download) {
+         return (now_downloaded_plus_file_size * 100) / total_to_download;
+     }
+
+     static std::string generate_progress_prefix(curl_off_t percentage) { return fmt("%3ld%% |", percentage); }
+
+     static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) {
+         const auto now = std::chrono::steady_clock::now();
+         const std::chrono::duration<double> elapsed_seconds = now - start_time;
+         return now_downloaded / elapsed_seconds.count();
+     }
+
+     static std::string generate_progress_suffix(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download,
+                                                 double speed, double estimated_time) {
+         const int width = 10;
+         return fmt("%*s/%*s%*s/s%*s", width, human_readable_size(now_downloaded_plus_file_size).c_str(), width,
+                    human_readable_size(total_to_download).c_str(), width, human_readable_size(speed).c_str(), width,
+                    human_readable_time(estimated_time).c_str());
+     }
+
+     static int calculate_progress_bar_width(const std::string & progress_prefix, const std::string & progress_suffix) {
+         int progress_bar_width = get_terminal_width() - progress_prefix.size() - progress_suffix.size() - 3;
+         if (progress_bar_width < 1) {
+             progress_bar_width = 1;
+         }
+
+         return progress_bar_width;
+     }
+
+     static std::string generate_progress_bar(int progress_bar_width, curl_off_t percentage,
+                                              std::string & progress_bar) {
+         const curl_off_t pos = (percentage * progress_bar_width) / 100;
+         for (int i = 0; i < progress_bar_width; ++i) {
+             progress_bar.append((i < pos) ? "█" : " ");
+         }
+
+         return progress_bar;
+     }
+
+     static void print_progress(const std::string & progress_prefix, const std::string & progress_bar,
+                                const std::string & progress_suffix) {
+         printe("\r%*s\r%s%s| %s", get_terminal_width(), " ", progress_prefix.c_str(), progress_bar.c_str(),
+                progress_suffix.c_str());
+     }
+     // Function to write data to a file
+     static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
+         FILE * out = static_cast<FILE *>(stream);
+         return fwrite(ptr, size, nmemb, out);
+     }
+
+     // Function to capture data into a string
+     static size_t capture_data(void * ptr, size_t size, size_t nmemb, void * stream) {
+         std::string * str = static_cast<std::string *>(stream);
+         str->append(static_cast<char *>(ptr), size * nmemb);
+         return size * nmemb;
+     }
+ };
+ #endif
+
+ class LlamaData {
+   public:
+     llama_model_ptr model;
+     llama_sampler_ptr sampler;
+     llama_context_ptr context;
+     std::vector<llama_chat_message> messages;
+     std::vector<std::string> msg_strs;
+     std::vector<char> fmtted;
+
+     int init(Opt & opt) {
+         model = initialize_model(opt);
+         if (!model) {
+             return 1;
+         }
+
+         context = initialize_context(model, opt.context_size_);
+         if (!context) {
+             return 1;
+         }
+
+         sampler = initialize_sampler();
+         return 0;
+     }
+
+   private:
+ #ifdef LLAMA_USE_CURL
+     int download(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
+                  const bool progress, std::string * response_str = nullptr) {
+         HttpClient http;
+         if (http.init(url, headers, output_file, progress, response_str)) {
+             return 1;
+         }
+
+         return 0;
+     }
+ #else
+     int download(const std::string &, const std::vector<std::string> &, const std::string &, const bool,
+                  std::string * = nullptr) {
+         printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
+         return 1;
+     }
+ #endif
+
+     int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+         // Find the second occurrence of '/' after protocol string
+         size_t pos = model.find('/');
+         pos = model.find('/', pos + 1);
+         if (pos == std::string::npos) {
+             return 1;
+         }
+
+         const std::string hfr = model.substr(0, pos);
+         const std::string hff = model.substr(pos + 1);
+         const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
+         return download(url, headers, bn, true);
+     }
+
+     int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+         if (model.find('/') == std::string::npos) {
+             model = "library/" + model;
+         }
+
+         std::string model_tag = "latest";
+         size_t colon_pos = model.find(':');
+         if (colon_pos != std::string::npos) {
+             model_tag = model.substr(colon_pos + 1);
+             model = model.substr(0, colon_pos);
+         }
+
+         std::string manifest_url = "https://registry.ollama.ai/v2/" + model + "/manifests/" + model_tag;
+         std::string manifest_str;
+         const int ret = download(manifest_url, headers, "", false, &manifest_str);
+         if (ret) {
+             return ret;
+         }
+
+         nlohmann::json manifest = nlohmann::json::parse(manifest_str);
+         std::string layer;
+         for (const auto & l : manifest["layers"]) {
+             if (l["mediaType"] == "application/vnd.ollama.image.model") {
+                 layer = l["digest"];
+                 break;
+             }
+         }
+
+         std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
+         return download(blob_url, headers, bn, true);
+     }
+
+     std::string basename(const std::string & path) {
+         const size_t pos = path.find_last_of("/\\");
+         if (pos == std::string::npos) {
+             return path;
+         }
+
+         return path.substr(pos + 1);
+     }
+
+     int remove_proto(std::string & model_) {
+         const std::string::size_type pos = model_.find("://");
+         if (pos == std::string::npos) {
+             return 1;
+         }
+
+         model_ = model_.substr(pos + 3); // Skip past "://"
+         return 0;
+     }
+
+     int resolve_model(std::string & model_) {
+         int ret = 0;
+         if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
+             remove_proto(model_);
+
+             return ret;
+         }
+
+         const std::string bn = basename(model_);
+         const std::vector<std::string> headers = { "--header",
+                                                    "Accept: application/vnd.docker.distribution.manifest.v2+json" };
+         if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
+             remove_proto(model_);
+             ret = huggingface_dl(model_, headers, bn);
+         } else if (string_starts_with(model_, "ollama://")) {
+             remove_proto(model_);
+             ret = ollama_dl(model_, headers, bn);
+         } else if (string_starts_with(model_, "https://")) {
+             download(model_, headers, bn, true);
+         } else {
+             ret = ollama_dl(model_, headers, bn);
+         }
+
+         model_ = bn;
+
+         return ret;
+     }
+
+     // Initializes the model and returns a unique pointer to it
+     llama_model_ptr initialize_model(Opt & opt) {
+         ggml_backend_load_all();
+         llama_model_params model_params = llama_model_default_params();
+         model_params.n_gpu_layers = opt.ngl_ >= 0 ? opt.ngl_ : model_params.n_gpu_layers;
+         resolve_model(opt.model_);
+         printe(
+             "\r%*s"
+             "\rLoading model",
+             get_terminal_width(), " ");
+         llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), model_params));
+         if (!model) {
+             printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str());
+         }
+
+         printe("\r%*s\r", static_cast<int>(sizeof("Loading model")), " ");
+         return model;
+     }
+
+     // Initializes the context with the specified parameters
+     llama_context_ptr initialize_context(const llama_model_ptr & model, const int n_ctx) {
+         llama_context_params ctx_params = llama_context_default_params();
+         ctx_params.n_ctx = ctx_params.n_batch = n_ctx >= 0 ? n_ctx : ctx_params.n_batch;
+         llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params));
+         if (!context) {
+             printe("%s: error: failed to create the llama_context\n", __func__);
+         }
+
+         return context;
+     }
+
+     // Initializes and configures the sampler
+     llama_sampler_ptr initialize_sampler() {
+         llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
+         llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
+         llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f));
+         llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
+
+         return sampler;
+     }
+ };
+
+ // Add a message to `messages` and store its content in `msg_strs`
+ static void add_message(const char * role, const std::string & text, LlamaData & llama_data) {
+     llama_data.msg_strs.push_back(std::move(text));
+     llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() });
+ }
+
+ // Function to apply the chat template and resize `formatted` if needed
+ static int apply_chat_template(LlamaData & llama_data, const bool append) {
+     int result = llama_chat_apply_template(
+         llama_data.model.get(), nullptr, llama_data.messages.data(), llama_data.messages.size(), append,
+         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
+     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
+         llama_data.fmtted.resize(result);
+         result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(),
+                                            llama_data.messages.size(), append, llama_data.fmtted.data(),
+                                            llama_data.fmtted.size());
+     }
+
+     return result;
+ }
+
+ // Function to tokenize the prompt
+ static int tokenize_prompt(const llama_model_ptr & model, const std::string & prompt,
+                            std::vector<llama_token> & prompt_tokens) {
+     const int n_prompt_tokens = -llama_tokenize(model.get(), prompt.c_str(), prompt.size(), NULL, 0, true, true);
+     prompt_tokens.resize(n_prompt_tokens);
+     if (llama_tokenize(model.get(), prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
+                        true) < 0) {
+         printe("failed to tokenize the prompt\n");
+         return -1;
+     }
+
+     return n_prompt_tokens;
+ }
+
+ // Check if we have enough space in the context to evaluate this batch
+ static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) {
+     const int n_ctx = llama_n_ctx(ctx.get());
+     const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get());
+     if (n_ctx_used + batch.n_tokens > n_ctx) {
+         printf("\033[0m\n");
+         printe("context size exceeded\n");
+         return 1;
+     }
+
+     return 0;
+ }
+
+ // convert the token to a string
+ static int convert_token_to_string(const llama_model_ptr & model, const llama_token token_id, std::string & piece) {
+     char buf[256];
+     int n = llama_token_to_piece(model.get(), token_id, buf, sizeof(buf), 0, true);
+     if (n < 0) {
+         printe("failed to convert token to piece\n");
+         return 1;
+     }
+
+     piece = std::string(buf, n);
+     return 0;
+ }
+
+ static void print_word_and_concatenate_to_response(const std::string & piece, std::string & response) {
+     printf("%s", piece.c_str());
+     fflush(stdout);
+     response += piece;
+ }
+
+ // helper function to evaluate a prompt and generate a response
+ static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) {
+     std::vector<llama_token> tokens;
+     if (tokenize_prompt(llama_data.model, prompt, tokens) < 0) {
+         return 1;
+     }
+
+     // prepare a batch for the prompt
+     llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size());
+     llama_token new_token_id;
+     while (true) {
+         check_context_size(llama_data.context, batch);
+         if (llama_decode(llama_data.context.get(), batch)) {
+             printe("failed to decode\n");
+             return 1;
+         }
+
+         // sample the next token, check is it an end of generation?
+         new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1);
+         if (llama_token_is_eog(llama_data.model.get(), new_token_id)) {
+             break;
+         }
+
+         std::string piece;
+         if (convert_token_to_string(llama_data.model, new_token_id, piece)) {
+             return 1;
+         }
+
+         print_word_and_concatenate_to_response(piece, response);
+
+         // prepare the next batch with the sampled token
+         batch = llama_batch_get_one(&new_token_id, 1);
+     }
+
+     return 0;
+ }
+
+ static int read_user_input(std::string & user) {
+     std::getline(std::cin, user);
+     return user.empty(); // Should have data in happy path
+ }
+
+ // Function to generate a response based on the prompt
+ static int generate_response(LlamaData & llama_data, const std::string & prompt, std::string & response,
+                              const bool stdout_a_terminal) {
+     // Set response color
+     if (stdout_a_terminal) {
+         printf("\033[33m");
+     }
+
+     if (generate(llama_data, prompt, response)) {
+         printe("failed to generate response\n");
+         return 1;
+     }
+
+     // End response with color reset and newline
+     printf("\n%s", stdout_a_terminal ? "\033[0m" : "");
+     return 0;
+ }
+
+ // Helper function to apply the chat template and handle errors
+ static int apply_chat_template_with_error_handling(LlamaData & llama_data, const bool append, int & output_length) {
+     const int new_len = apply_chat_template(llama_data, append);
+     if (new_len < 0) {
+         printe("failed to apply the chat template\n");
+         return -1;
+     }
+
+     output_length = new_len;
+     return 0;
+ }
+
+ // Helper function to handle user input
+ static int handle_user_input(std::string & user_input, const std::string & user_) {
+     if (!user_.empty()) {
+         user_input = user_;
+         return 0; // No need for interactive input
+     }
+
+     printf(
+         "\r%*s"
+         "\r\033[32m> \033[0m",
+         get_terminal_width(), " ");
+     return read_user_input(user_input); // Returns true if input ends the loop
+ }
+
+ static bool is_stdin_a_terminal() {
+ #if defined(_WIN32)
+     HANDLE hStdin = GetStdHandle(STD_INPUT_HANDLE);
+     DWORD mode;
+     return GetConsoleMode(hStdin, &mode);
+ #else
+     return isatty(STDIN_FILENO);
+ #endif
+ }
+
+ static bool is_stdout_a_terminal() {
+ #if defined(_WIN32)
+     HANDLE hStdout = GetStdHandle(STD_OUTPUT_HANDLE);
+     DWORD mode;
+     return GetConsoleMode(hStdout, &mode);
+ #else
+     return isatty(STDOUT_FILENO);
+ #endif
+ }
+
+ // Function to tokenize the prompt
+ static int chat_loop(LlamaData & llama_data, const std::string & user_) {
+     int prev_len = 0;
+     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
+     static const bool stdout_a_terminal = is_stdout_a_terminal();
+     while (true) {
+         // Get user input
+         std::string user_input;
+         while (handle_user_input(user_input, user_)) {
+         }
+
+         add_message("user", user_.empty() ? user_input : user_, llama_data);
+         int new_len;
+         if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) {
+             return 1;
+         }
+
+         std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
+         std::string response;
+         if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
+             return 1;
+         }
+
+         if (!user_.empty()) {
+             break;
+         }
+
+         add_message("assistant", response, llama_data);
+         if (apply_chat_template_with_error_handling(llama_data, false, prev_len) < 0) {
+             return 1;
+         }
+     }
+
+     return 0;
+ }
+
+ static void log_callback(const enum ggml_log_level level, const char * text, void * p) {
+     const Opt * opt = static_cast<Opt *>(p);
+     if (opt->verbose_ || level == GGML_LOG_LEVEL_ERROR) {
+         printe("%s", text);
+     }
+ }
+
+ static std::string read_pipe_data() {
+     std::ostringstream result;
+     result << std::cin.rdbuf(); // Read all data from std::cin
+     return result.str();
+ }
+
+ int main(int argc, const char ** argv) {
+     Opt opt;
+     const int ret = opt.init(argc, argv);
+     if (ret == 2) {
+         return 0;
+     } else if (ret) {
+         return 1;
+     }
+
+     if (!is_stdin_a_terminal()) {
+         if (!opt.user_.empty()) {
+             opt.user_ += "\n\n";
+         }
+
+         opt.user_ += read_pipe_data();
+     }
+
+     llama_log_set(log_callback, &opt);
+     LlamaData llama_data;
+     if (llama_data.init(opt)) {
+         return 1;
+     }
+
+     if (chat_loop(llama_data, opt.user_)) {
+         return 1;
+     }
+
+     return 0;
+ }