@fugood/llama.node 0.2.2 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (320)
  1. package/CMakeLists.txt +5 -2
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +8 -1
  17. package/package.json +1 -1
  18. package/patches/llama.patch +12 -12
  19. package/src/DetokenizeWorker.cpp +1 -1
  20. package/src/LlamaContext.cpp +33 -1
  21. package/src/LlamaContext.h +1 -0
  22. package/src/LoadSessionWorker.cpp +1 -0
  23. package/src/llama.cpp/.github/workflows/bench.yml +310 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +1315 -0
  25. package/src/llama.cpp/.github/workflows/close-issue.yml +23 -0
  26. package/src/llama.cpp/.github/workflows/docker.yml +116 -0
  27. package/src/llama.cpp/.github/workflows/editorconfig.yml +27 -0
  28. package/src/llama.cpp/.github/workflows/gguf-publish.yml +44 -0
  29. package/src/llama.cpp/.github/workflows/labeler.yml +17 -0
  30. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +65 -0
  31. package/src/llama.cpp/.github/workflows/nix-ci.yml +72 -0
  32. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +22 -0
  33. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +36 -0
  34. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +35 -0
  35. package/src/llama.cpp/.github/workflows/python-lint.yml +23 -0
  36. package/src/llama.cpp/.github/workflows/python-type-check.yml +38 -0
  37. package/src/llama.cpp/.github/workflows/server.yml +183 -0
  38. package/src/llama.cpp/CMakeLists.txt +91 -1245
  39. package/src/llama.cpp/cmake/arm64-windows-llvm.cmake +1 -1
  40. package/src/llama.cpp/cmake/build-info.cmake +58 -0
  41. package/src/llama.cpp/cmake/git-vars.cmake +22 -0
  42. package/src/llama.cpp/common/CMakeLists.txt +4 -3
  43. package/src/llama.cpp/common/build-info.cpp.in +4 -0
  44. package/src/llama.cpp/common/common.cpp +1116 -877
  45. package/src/llama.cpp/common/common.h +191 -77
  46. package/src/llama.cpp/common/grammar-parser.cpp +118 -31
  47. package/src/llama.cpp/common/json-schema-to-grammar.cpp +346 -65
  48. package/src/llama.cpp/common/log.h +1 -1
  49. package/src/llama.cpp/common/ngram-cache.h +10 -3
  50. package/src/llama.cpp/common/sampling.cpp +19 -10
  51. package/src/llama.cpp/docs/build.md +353 -0
  52. package/src/llama.cpp/examples/CMakeLists.txt +22 -22
  53. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +1 -1
  54. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +6 -6
  55. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  56. package/src/llama.cpp/examples/batched/batched.cpp +52 -55
  57. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  58. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +20 -72
  59. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +1 -1
  60. package/src/llama.cpp/examples/chat-13B.bat +57 -0
  61. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/{finetune → cvector-generator}/CMakeLists.txt +2 -2
  63. package/src/llama.cpp/examples/cvector-generator/completions.txt +582 -0
  64. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +503 -0
  65. package/src/llama.cpp/examples/cvector-generator/mean.hpp +48 -0
  66. package/src/llama.cpp/examples/cvector-generator/negative.txt +4 -0
  67. package/src/llama.cpp/examples/cvector-generator/pca.hpp +325 -0
  68. package/src/llama.cpp/examples/cvector-generator/positive.txt +4 -0
  69. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +35 -0
  70. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  71. package/src/llama.cpp/examples/embedding/embedding.cpp +94 -46
  72. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +2 -2
  73. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +4 -6
  74. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  75. package/src/llama.cpp/examples/export-lora/export-lora.cpp +344 -386
  76. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +2 -2
  77. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +30 -25
  78. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/gguf/gguf.cpp +5 -0
  80. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +15 -0
  81. package/src/llama.cpp/examples/gguf-hash/deps/rotate-bits/rotate-bits.h +46 -0
  82. package/src/llama.cpp/examples/gguf-hash/deps/sha1/sha1.c +295 -0
  83. package/src/llama.cpp/examples/gguf-hash/deps/sha1/sha1.h +52 -0
  84. package/src/llama.cpp/examples/gguf-hash/deps/sha256/sha256.c +221 -0
  85. package/src/llama.cpp/examples/gguf-hash/deps/sha256/sha256.h +24 -0
  86. package/src/llama.cpp/examples/gguf-hash/deps/xxhash/xxhash.c +42 -0
  87. package/src/llama.cpp/examples/gguf-hash/deps/xxhash/xxhash.h +7093 -0
  88. package/src/llama.cpp/examples/gguf-hash/gguf-hash.cpp +693 -0
  89. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  90. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +3 -3
  91. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  92. package/src/llama.cpp/examples/gritlm/gritlm.cpp +6 -2
  93. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  94. package/src/llama.cpp/examples/imatrix/imatrix.cpp +137 -176
  95. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  96. package/src/llama.cpp/examples/infill/infill.cpp +38 -153
  97. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +175 -94
  98. package/src/llama.cpp/examples/llama.android/app/build.gradle.kts +65 -0
  99. package/src/llama.cpp/examples/llama.android/build.gradle.kts +6 -0
  100. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +68 -0
  101. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +11 -7
  102. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +2 -2
  103. package/src/llama.cpp/examples/llama.android/settings.gradle.kts +18 -0
  104. package/src/llama.cpp/examples/llava/CMakeLists.txt +6 -5
  105. package/src/llama.cpp/examples/llava/android/build_64.sh +8 -0
  106. package/src/llama.cpp/examples/llava/clip.cpp +23 -14
  107. package/src/llama.cpp/examples/llava/llava-cli.cpp +8 -6
  108. package/src/llama.cpp/examples/llava/requirements.txt +3 -2
  109. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  110. package/src/llama.cpp/examples/lookahead/lookahead.cpp +2 -1
  111. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  112. package/src/llama.cpp/examples/lookup/lookup-create.cpp +2 -0
  113. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  114. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -2
  115. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  116. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  117. package/src/llama.cpp/examples/main/main.cpp +98 -75
  118. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +4 -5
  119. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  120. package/src/llama.cpp/examples/parallel/parallel.cpp +2 -1
  121. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  122. package/src/llama.cpp/examples/passkey/passkey.cpp +23 -43
  123. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  124. package/src/llama.cpp/examples/perplexity/perplexity.cpp +13 -10
  125. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  126. package/src/llama.cpp/examples/quantize/quantize.cpp +37 -34
  127. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  128. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +1 -1
  129. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  130. package/src/llama.cpp/examples/retrieval/retrieval.cpp +26 -77
  131. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  132. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +14 -7
  133. package/src/llama.cpp/examples/server/CMakeLists.txt +26 -2
  134. package/src/llama.cpp/examples/server/server.cpp +274 -671
  135. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  136. package/src/llama.cpp/examples/server/utils.hpp +28 -29
  137. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  138. package/src/llama.cpp/examples/simple/simple.cpp +21 -29
  139. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  140. package/src/llama.cpp/examples/speculative/speculative.cpp +2 -1
  141. package/src/llama.cpp/examples/sycl/CMakeLists.txt +1 -1
  142. package/src/llama.cpp/examples/sycl/build.sh +23 -0
  143. package/src/llama.cpp/examples/sycl/run-llama2.sh +36 -0
  144. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +33 -0
  145. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +9 -0
  146. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  147. package/src/llama.cpp/examples/tokenize/tokenize.cpp +16 -2
  148. package/src/llama.cpp/ggml/CMakeLists.txt +253 -0
  149. package/src/llama.cpp/{cmake → ggml/cmake}/FindSIMD.cmake +6 -6
  150. package/src/llama.cpp/{ggml-backend.h → ggml/include/ggml-backend.h} +22 -17
  151. package/src/llama.cpp/ggml/include/ggml-blas.h +23 -0
  152. package/src/llama.cpp/ggml/include/ggml-cann.h +125 -0
  153. package/src/llama.cpp/{ggml-cuda.h → ggml/include/ggml-cuda.h} +3 -0
  154. package/src/llama.cpp/{ggml-metal.h → ggml/include/ggml-metal.h} +1 -2
  155. package/src/llama.cpp/{ggml-sycl.h → ggml/include/ggml-sycl.h} +3 -10
  156. package/src/llama.cpp/{ggml.h → ggml/include/ggml.h} +80 -85
  157. package/src/llama.cpp/ggml/src/CMakeLists.txt +1329 -0
  158. package/src/llama.cpp/ggml/src/ggml-aarch64.c +2193 -0
  159. package/src/llama.cpp/ggml/src/ggml-aarch64.h +39 -0
  160. package/src/llama.cpp/{ggml-alloc.c → ggml/src/ggml-alloc.c} +100 -49
  161. package/src/llama.cpp/{ggml-backend-impl.h → ggml/src/ggml-backend-impl.h} +20 -8
  162. package/src/llama.cpp/{ggml-backend.c → ggml/src/ggml-backend.c} +307 -167
  163. package/src/llama.cpp/ggml/src/ggml-blas.cpp +367 -0
  164. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +198 -0
  165. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +230 -0
  166. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +2944 -0
  167. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  168. package/src/llama.cpp/ggml/src/ggml-cann/common.h +282 -0
  169. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +32 -0
  170. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +17 -0
  171. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +223 -0
  172. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +186 -0
  173. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +180 -0
  174. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +193 -0
  175. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  176. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +208 -0
  177. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +206 -0
  178. package/src/llama.cpp/ggml/src/ggml-cann.cpp +2023 -0
  179. package/src/llama.cpp/{ggml-common.h → ggml/src/ggml-common.h} +41 -7
  180. package/src/llama.cpp/{ggml-impl.h → ggml/src/ggml-impl.h} +113 -9
  181. package/src/llama.cpp/{ggml-kompute.cpp → ggml/src/ggml-kompute.cpp} +33 -18
  182. package/src/llama.cpp/{ggml-quants.c → ggml/src/ggml-quants.c} +1460 -940
  183. package/src/llama.cpp/{ggml-quants.h → ggml/src/ggml-quants.h} +19 -20
  184. package/src/llama.cpp/{ggml-rpc.cpp → ggml/src/ggml-rpc.cpp} +95 -72
  185. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +27 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +53 -0
  187. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +355 -0
  188. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +195 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp +21 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +547 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +27 -0
  192. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +698 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  195. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +3011 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  197. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.hpp +33 -0
  198. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1027 -0
  199. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  200. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +374 -0
  201. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +35 -0
  202. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +66 -0
  203. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +275 -0
  204. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +22 -0
  205. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +251 -0
  206. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.hpp +24 -0
  207. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +1140 -0
  208. package/src/llama.cpp/ggml/src/ggml-sycl.cpp +5314 -0
  209. package/src/llama.cpp/{ggml-vulkan.cpp → ggml/src/ggml-vulkan.cpp} +1781 -1868
  210. package/src/llama.cpp/{ggml.c → ggml/src/ggml.c} +1245 -2087
  211. package/src/llama.cpp/{sgemm.cpp → ggml/src/llamafile/sgemm.cpp} +21 -24
  212. package/src/llama.cpp/{sgemm.h → ggml/src/llamafile/sgemm.h} +1 -1
  213. package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +5 -0
  214. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +552 -0
  215. package/src/llama.cpp/{llama.h → include/llama.h} +175 -100
  216. package/src/llama.cpp/models/.editorconfig +1 -0
  217. package/src/llama.cpp/models/ggml-vocab-aquila.gguf +0 -0
  218. package/src/llama.cpp/models/ggml-vocab-baichuan.gguf +0 -0
  219. package/src/llama.cpp/models/ggml-vocab-bert-bge.gguf +0 -0
  220. package/src/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +112 -0
  221. package/src/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +46 -0
  222. package/src/llama.cpp/models/ggml-vocab-command-r.gguf +0 -0
  223. package/src/llama.cpp/models/ggml-vocab-command-r.gguf.inp +112 -0
  224. package/src/llama.cpp/models/ggml-vocab-command-r.gguf.out +46 -0
  225. package/src/llama.cpp/models/ggml-vocab-deepseek-coder.gguf +0 -0
  226. package/src/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +112 -0
  227. package/src/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +46 -0
  228. package/src/llama.cpp/models/ggml-vocab-deepseek-llm.gguf +0 -0
  229. package/src/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +112 -0
  230. package/src/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +46 -0
  231. package/src/llama.cpp/models/ggml-vocab-falcon.gguf +0 -0
  232. package/src/llama.cpp/models/ggml-vocab-falcon.gguf.inp +112 -0
  233. package/src/llama.cpp/models/ggml-vocab-falcon.gguf.out +46 -0
  234. package/src/llama.cpp/models/ggml-vocab-gpt-2.gguf +0 -0
  235. package/src/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +112 -0
  236. package/src/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +46 -0
  237. package/src/llama.cpp/models/ggml-vocab-gpt-neox.gguf +0 -0
  238. package/src/llama.cpp/models/ggml-vocab-llama-bpe.gguf +0 -0
  239. package/src/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +112 -0
  240. package/src/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +46 -0
  241. package/src/llama.cpp/models/ggml-vocab-llama-spm.gguf +0 -0
  242. package/src/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +112 -0
  243. package/src/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +46 -0
  244. package/src/llama.cpp/models/ggml-vocab-mpt.gguf +0 -0
  245. package/src/llama.cpp/models/ggml-vocab-mpt.gguf.inp +112 -0
  246. package/src/llama.cpp/models/ggml-vocab-mpt.gguf.out +46 -0
  247. package/src/llama.cpp/models/ggml-vocab-phi-3.gguf +0 -0
  248. package/src/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +112 -0
  249. package/src/llama.cpp/models/ggml-vocab-phi-3.gguf.out +46 -0
  250. package/src/llama.cpp/models/ggml-vocab-qwen2.gguf +0 -0
  251. package/src/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +112 -0
  252. package/src/llama.cpp/models/ggml-vocab-qwen2.gguf.out +46 -0
  253. package/src/llama.cpp/models/ggml-vocab-refact.gguf +0 -0
  254. package/src/llama.cpp/models/ggml-vocab-refact.gguf.inp +112 -0
  255. package/src/llama.cpp/models/ggml-vocab-refact.gguf.out +46 -0
  256. package/src/llama.cpp/models/ggml-vocab-starcoder.gguf +0 -0
  257. package/src/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +112 -0
  258. package/src/llama.cpp/models/ggml-vocab-starcoder.gguf.out +46 -0
  259. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  260. package/src/llama.cpp/requirements/requirements-all.txt +12 -0
  261. package/src/llama.cpp/requirements/requirements-compare-llama-bench.txt +2 -0
  262. package/src/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -0
  263. package/src/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +3 -0
  264. package/src/llama.cpp/requirements/{requirements-convert.txt → requirements-convert_legacy_llama.txt} +1 -1
  265. package/src/llama.cpp/requirements/requirements-convert_llama_ggml_to_gguf.txt +1 -0
  266. package/src/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
  267. package/src/llama.cpp/requirements/requirements-pydantic.txt +3 -0
  268. package/src/llama.cpp/requirements/requirements-test-tokenizer-random.txt +1 -0
  269. package/src/llama.cpp/requirements.txt +5 -4
  270. package/src/llama.cpp/scripts/build-info.sh +30 -0
  271. package/src/llama.cpp/scripts/install-oneapi.bat +19 -0
  272. package/src/llama.cpp/src/CMakeLists.txt +33 -0
  273. package/src/llama.cpp/src/llama-grammar.cpp +539 -0
  274. package/src/llama.cpp/src/llama-grammar.h +39 -0
  275. package/src/llama.cpp/src/llama-impl.h +26 -0
  276. package/src/llama.cpp/src/llama-sampling.cpp +635 -0
  277. package/src/llama.cpp/src/llama-sampling.h +56 -0
  278. package/src/llama.cpp/src/llama-vocab.cpp +1721 -0
  279. package/src/llama.cpp/src/llama-vocab.h +130 -0
  280. package/src/llama.cpp/{llama.cpp → src/llama.cpp} +5979 -5260
  281. package/src/llama.cpp/{unicode-data.cpp → src/unicode-data.cpp} +851 -802
  282. package/src/llama.cpp/{unicode.cpp → src/unicode.cpp} +52 -30
  283. package/src/llama.cpp/{unicode.h → src/unicode.h} +5 -1
  284. package/src/llama.cpp/tests/CMakeLists.txt +19 -20
  285. package/src/llama.cpp/tests/test-backend-ops.cpp +245 -67
  286. package/src/llama.cpp/tests/test-chat-template.cpp +57 -3
  287. package/src/llama.cpp/tests/test-double-float.cpp +2 -2
  288. package/src/llama.cpp/tests/test-grad0.cpp +2 -2
  289. package/src/llama.cpp/tests/test-grammar-integration.cpp +978 -31
  290. package/src/llama.cpp/tests/test-grammar-parser.cpp +423 -158
  291. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +508 -135
  292. package/src/llama.cpp/tests/test-llama-grammar.cpp +15 -9
  293. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -1
  294. package/src/llama.cpp/tests/test-quantize-perf.cpp +1 -1
  295. package/src/llama.cpp/tests/test-rope.cpp +3 -4
  296. package/src/llama.cpp/tests/test-sampling.cpp +5 -5
  297. package/src/llama.cpp/tests/test-tokenizer-0.cpp +6 -6
  298. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +20 -15
  299. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +22 -11
  300. package/bin/darwin/arm64/default.metallib +0 -0
  301. package/bin/darwin/x64/default.metallib +0 -0
  302. package/src/llama.cpp/examples/beam-search/CMakeLists.txt +0 -5
  303. package/src/llama.cpp/examples/beam-search/beam-search.cpp +0 -188
  304. package/src/llama.cpp/examples/finetune/finetune.cpp +0 -1862
  305. package/src/llama.cpp/examples/llama.android/llama/CMakeLists.txt +0 -55
  306. package/src/llama.cpp/examples/train-text-from-scratch/CMakeLists.txt +0 -5
  307. package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +0 -1253
  308. package/src/llama.cpp/ggml-opencl.cpp +0 -2305
  309. package/src/llama.cpp/ggml-opencl.h +0 -36
  310. package/src/llama.cpp/ggml-sycl.cpp +0 -17340
  311. package/src/llama.cpp/ggml-vulkan-shaders.hpp +0 -81211
  312. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf-update.txt +0 -2
  313. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf.txt +0 -2
  314. package/src/llama.cpp/requirements/requirements-convert-llama-ggml-to-gguf.txt +0 -1
  315. package/src/llama.cpp/scripts/gen-build-info-cpp.cmake +0 -24
  316. package/src/llama.cpp/{ggml-alloc.h → ggml/include/ggml-alloc.h} +0 -0
  317. package/src/llama.cpp/{ggml-kompute.h → ggml/include/ggml-kompute.h} +0 -0
  318. package/src/llama.cpp/{ggml-rpc.h → ggml/include/ggml-rpc.h} +0 -0
  319. package/src/llama.cpp/{ggml-vulkan.h → ggml/include/ggml-vulkan.h} +0 -0
  320. package/src/llama.cpp/{unicode-data.h → src/unicode-data.h} +0 -0
package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1,1253 +0,0 @@
- #include "ggml.h"
- #include "ggml-alloc.h"
- #include "ggml-backend.h"
- #include "common.h"
- #include "train.h"
- #include "llama.h"
- #include <unordered_map>
- #include <vector>
- #include <cassert>
- #include <climits>
- #include <cstring>
- #include <cstdarg>
- #include <ctime>
- #include <random>
- #include <stdexcept>
- #include <algorithm>
- #include <string>
-
- #if defined(_MSC_VER)
- #pragma warning(disable: 4244 4267) // possible loss of data
- #endif
-
- struct my_llama_hparams {
-     uint32_t n_vocab = 32000;
-     uint32_t n_ctx = 512;
-     uint32_t n_embd = 4096;
-     uint32_t n_head = 32;
-     uint32_t n_layer = 32;
-     uint32_t n_rot = 64;
-     uint32_t n_ff = 11008;
-
-     // float f_norm_eps = 1e-5f; // falcon
-     float f_norm_rms_eps = 1e-5f; // llama
-
-     float rope_freq_base = 10000.0f;
-     float rope_freq_scale = 1.0f;
- };
-
- struct my_llama_layer {
-     // normalization
-     struct ggml_tensor * attention_norm;
-
-     // attention
-     struct ggml_tensor * wq;
-     struct ggml_tensor * wk;
-     struct ggml_tensor * wv;
-     struct ggml_tensor * wo;
-
-     // normalization
-     struct ggml_tensor * ffn_norm;
-
-     // ff
-     struct ggml_tensor * ffn_gate; // w1
-     struct ggml_tensor * ffn_down; // w2
-     struct ggml_tensor * ffn_up; // w3
- };
-
- struct my_llama_model {
-     struct ggml_context * ctx = NULL;
-     ggml_backend_buffer_t data = NULL;
-
-     my_llama_hparams hparams;
-
-     struct ggml_tensor * tok_embeddings;
-
-     struct ggml_tensor * norm;
-     struct ggml_tensor * output;
-
-     std::vector<my_llama_layer> layers;
- };
-
- // gguf constants (sync with gguf.py)
- static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model";
- static const char * LLM_KV_TRAINING_TYPE = "training.type";
-
- static const char * LLM_KV_GENERAL_NAME = "general.name";
- static const char * LLM_KV_GENERAL_ARCHITECTURE = "general.architecture";
- static const char * LLM_KV_GENERAL_FILE_TYPE = "general.file_type";
-
- static const char * LLM_KV_CONTEXT_LENGTH = "%s.context_length";
- static const char * LLM_KV_EMBEDDING_LENGTH = "%s.embedding_length";
- static const char * LLM_KV_BLOCK_COUNT = "%s.block_count";
- static const char * LLM_KV_FEED_FORWARD_LENGTH = "%s.feed_forward_length";
- static const char * LLM_KV_ATTENTION_HEAD_COUNT = "%s.attention.head_count";
- static const char * LLM_KV_ATTENTION_LAYERNORM_RMS_EPS = "%s.attention.layer_norm_rms_epsilon";
- static const char * LLM_KV_ROPE_DIMENSION_COUNT = "%s.rope.dimension_count";
- static const char * LLM_KV_ROPE_FREQ_BASE = "%s.rope.freq_base"; // TODO load in llama.cpp
- static const char * LLM_KV_ROPE_SCALE_LINEAR = "%s.rope.scale_linear";
-
- static const char * LLM_KV_TOKENIZER_MODEL = "tokenizer.ggml.model";
- static const char * LLM_KV_TOKENIZER_LIST = "tokenizer.ggml.tokens";
- static const char * LLM_KV_TOKENIZER_TOKEN_TYPE = "tokenizer.ggml.token_type";
- static const char * LLM_KV_TOKENIZER_SCORES = "tokenizer.ggml.scores";
- static const char * LLM_KV_TOKENIZER_MERGES = "tokenizer.ggml.merges";
- static const char * LLM_KV_TOKENIZER_BOS_ID = "tokenizer.ggml.bos_token_id";
- static const char * LLM_KV_TOKENIZER_EOS_ID = "tokenizer.ggml.eos_token_id";
- static const char * LLM_KV_TOKENIZER_UNK_ID = "tokenizer.ggml.unknown_token_id";
- static const char * LLM_KV_TOKENIZER_SEP_ID = "tokenizer.ggml.seperator_token_id";
- static const char * LLM_KV_TOKENIZER_PAD_ID = "tokenizer.ggml.padding_token_id";
-
- static const char * LLM_TENSOR_TOKEN_EMBD = "token_embd";
- static const char * LLM_TENSOR_OUTPUT_NORM = "output_norm";
- static const char * LLM_TENSOR_OUTPUT = "output";
- static const char * LLM_TENSOR_ATTN_NORM = "blk.%d.attn_norm";
- static const char * LLM_TENSOR_ATTN_Q = "blk.%d.attn_q";
- static const char * LLM_TENSOR_ATTN_K = "blk.%d.attn_k";
- static const char * LLM_TENSOR_ATTN_V = "blk.%d.attn_v";
- static const char * LLM_TENSOR_ATTN_OUT = "blk.%d.attn_output";
- static const char * LLM_TENSOR_FFN_NORM = "blk.%d.ffn_norm";
- static const char * LLM_TENSOR_FFN_GATE = "blk.%d.ffn_gate";
- static const char * LLM_TENSOR_FFN_DOWN = "blk.%d.ffn_down";
- static const char * LLM_TENSOR_FFN_UP = "blk.%d.ffn_up";
-
- static void print_params(struct my_llama_hparams * params) {
-     printf("%s: n_vocab: %u\n", __func__, params->n_vocab);
-     printf("%s: n_ctx: %u\n", __func__, params->n_ctx);
-     printf("%s: n_embd: %u\n", __func__, params->n_embd);
-     printf("%s: n_head: %u\n", __func__, params->n_head);
-     printf("%s: n_ff: %u\n", __func__, params->n_ff);
-     printf("%s: n_layer: %u\n", __func__, params->n_layer);
-     printf("%s: n_rot: %u\n", __func__, params->n_rot);
- }
-
- static void set_param_model(struct my_llama_model * model) {
-     const auto& hparams = model->hparams;
-
-     const uint32_t n_layer = hparams.n_layer;
-
-     struct ggml_context* ctx = model->ctx;
-
-     ggml_set_param(ctx, model->tok_embeddings);
-     ggml_set_param(ctx, model->norm);
-     ggml_set_param(ctx, model->output);
-
-     for (uint32_t i = 0; i < n_layer; ++i) {
-         auto & layer = model->layers[i];
-
-         ggml_set_param(ctx, layer.attention_norm);
-         ggml_set_param(ctx, layer.wq);
-         ggml_set_param(ctx, layer.wk);
-         ggml_set_param(ctx, layer.wv);
-         ggml_set_param(ctx, layer.wo);
-         ggml_set_param(ctx, layer.ffn_norm);
-         ggml_set_param(ctx, layer.ffn_gate);
-         ggml_set_param(ctx, layer.ffn_down);
-         ggml_set_param(ctx, layer.ffn_up);
-     }
- }
-
- static void init_model(struct my_llama_model * model) {
-     const auto & hparams = model->hparams;
-
-     const uint32_t n_embd = hparams.n_embd;
-     const uint32_t n_layer = hparams.n_layer;
-     const uint32_t n_vocab = hparams.n_vocab;
-     const uint32_t n_ff = hparams.n_ff;
-
-
-     std::vector<char> tn_buf;
-     tn_buf.resize(GGML_MAX_NAME);
-     auto tn = [&tn_buf](const char * key) -> const char * {
-         snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", key);
-         return tn_buf.data();
-     };
-     auto tni = [&tn_buf](const char * key, int bid) -> const char * {
-         snprintf(tn_buf.data(), tn_buf.size(), key, bid);
-         std::string s = tn_buf.data();
-         snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", s.c_str());
-         return tn_buf.data();
-     };
-
-     // context for model tensors without their data
-     struct ggml_init_params ctx_model_params;
-     ctx_model_params.mem_size = ggml_tensor_overhead()*2*(6 + n_layer*18);
-     ctx_model_params.mem_buffer = NULL;
-     ctx_model_params.no_alloc = true;
-
-     struct ggml_context * ctx = ggml_init(ctx_model_params);
-     model->ctx = ctx;
-
-     model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
-     model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
-     model->output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
-
-     ggml_set_name(model->tok_embeddings, tn(LLM_TENSOR_TOKEN_EMBD));
-     ggml_set_name(model->norm, tn(LLM_TENSOR_OUTPUT_NORM));
-     ggml_set_name(model->output, tn(LLM_TENSOR_OUTPUT));
-
-     model->layers.resize(n_layer);
-     for (uint32_t i = 0; i < n_layer; ++i) {
-         auto & layer = model->layers[i];
-
-         layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
-
-         layer.wq = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
-         layer.wk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
-         layer.wv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
-         layer.wo = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
-
-         layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
-
-         layer.ffn_gate = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
-         layer.ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
-         layer.ffn_up = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
-
-         ggml_set_name(layer.attention_norm, tni(LLM_TENSOR_ATTN_NORM, i));
-
-         ggml_set_name(layer.wq, tni(LLM_TENSOR_ATTN_Q, i));
-         ggml_set_name(layer.wk, tni(LLM_TENSOR_ATTN_K, i));
-         ggml_set_name(layer.wv, tni(LLM_TENSOR_ATTN_V, i));
-         ggml_set_name(layer.wo, tni(LLM_TENSOR_ATTN_OUT, i));
-
-         ggml_set_name(layer.ffn_norm, tni(LLM_TENSOR_FFN_NORM, i));
-
-         ggml_set_name(layer.ffn_gate, tni(LLM_TENSOR_FFN_GATE, i));
-         ggml_set_name(layer.ffn_down, tni(LLM_TENSOR_FFN_DOWN, i));
-         ggml_set_name(layer.ffn_up, tni(LLM_TENSOR_FFN_UP, i));
-     }
-
-     set_param_model(model);
-
-     // allocate data
-     model->data = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_buffer_type());
- }
-
- static void randomize_model(struct my_llama_model * model, int seed, float mean, float std, float min, float max) {
-     const auto & hparams = model->hparams;
-
-     const uint32_t n_layer = hparams.n_layer;
-
-     struct random_normal_distribution * rnd = init_random_normal_distribution(seed, mean, std, min, max);
-
-     randomize_tensor_normal(model->tok_embeddings, rnd);
-     randomize_tensor_normal(model->norm, rnd);
-     randomize_tensor_normal(model->output, rnd);
-
-     for (uint32_t i = 0; i < n_layer; ++i) {
-         auto & layer = model->layers[i];
-         randomize_tensor_normal(layer.attention_norm, rnd);
-
-         randomize_tensor_normal(layer.wq, rnd);
-         randomize_tensor_normal(layer.wk, rnd);
-         randomize_tensor_normal(layer.wv, rnd);
-         randomize_tensor_normal(layer.wo, rnd);
-
-         randomize_tensor_normal(layer.ffn_norm, rnd);
-
-         randomize_tensor_normal(layer.ffn_gate, rnd);
-         randomize_tensor_normal(layer.ffn_down, rnd);
-         randomize_tensor_normal(layer.ffn_up, rnd);
-     }
-
-     free_random_normal_distribution(rnd);
- }
-
- static struct ggml_tensor * llama_build_train_graphs(
-         struct my_llama_model * model,
-         ggml_gallocr_t alloc,
-         struct ggml_context * ctx,
-         struct ggml_cgraph * gf,
-         struct ggml_cgraph * gb,
-         struct ggml_cgraph * gb_tmp,
-         struct ggml_tensor * * logits,
-         struct ggml_tensor * tokens_input,
-         struct ggml_tensor * targets,
-         const int n_tokens,
-         const int n_batch,
-         const bool enable_flash_attn,
-         const bool enable_checkpointing,
-         const bool measure_only) {
-
-     ggml_set_scratch(ctx, { 0, 0, nullptr, });
-     const int n_past = 0;
-     const int N = n_tokens;
-     const auto & hparams = model->hparams;
-     const int n_ctx = hparams.n_ctx;
-     const int n_vocab = hparams.n_vocab;
-     const int n_embd = hparams.n_embd;
-     const int n_layer = hparams.n_layer;
-     const int n_head = hparams.n_head;
-     const int n_rot = hparams.n_rot;
-     const int n_ff = hparams.n_ff;
-     const float f_norm_rms_eps = hparams.f_norm_rms_eps;
-     const float rope_freq_base = hparams.rope_freq_base;
-     const float rope_freq_scale = hparams.rope_freq_scale;
-
-     auto set_name = [](struct ggml_tensor * t, const char * n) {
-         ggml_set_name(t, n);
-         if (t->grad) {
-             ggml_format_name(t->grad, "%s->grad", n);
-         }
-     };
-
-     // KQ_pos - contains the positions
-     struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
-     ggml_set_input(KQ_pos);
-
-     // rope has so much parameters that we make a custom function for it
-     auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
-                 (struct ggml_tensor * t) -> struct ggml_tensor * {
-         // not capturing these, to silcence warnings
-         const int rope_mode = 0;
-
-         return ggml_rope_ext(
-             ctx, t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
-         );
-     };
-
-     set_name(tokens_input, "tokens_input");
-     set_name(targets, "targets");
-
-     GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
-     struct ggml_tensor * t00 = ggml_reshape_1d(ctx, tokens_input, N*n_batch); set_name(t00, "t00"); assert_shape_1d(t00, N*n_batch);
-     struct ggml_tensor * t01 = ggml_get_rows(ctx, model->tok_embeddings, t00); set_name(t01, "t01"); assert_shape_2d(t01, n_embd, N*n_batch);
-
-     struct ggml_tensor * cur = t01;
-
-     std::vector<struct ggml_tensor *> checkpoints;
-     checkpoints.push_back(tokens_input);
-     checkpoints.push_back(targets);
-     checkpoints.push_back(t00);
-     checkpoints.push_back(t01);
-
-     const float kv_scale = 1.0f/sqrtf(float(n_embd)/n_head);
-
-     for (int il = 0; il < n_layer; ++il) {
-         struct my_llama_layer & layer = model->layers[il];
-         struct ggml_tensor * t02 = ggml_rms_norm (ctx, cur, f_norm_rms_eps); set_name(t02, "t02"); assert_shape_2d(t02, n_embd, N*n_batch);
-         struct ggml_tensor * t03 = ggml_repeat (ctx, layer.attention_norm, t02); set_name(t03, "t03"); assert_shape_2d(t03, n_embd, N*n_batch);
-         struct ggml_tensor * t04 = ggml_mul (ctx, t03, t02); set_name(t04, "t04"); assert_shape_2d(t04, n_embd, N*n_batch);
-         struct ggml_tensor * t05 = ggml_mul_mat (ctx, layer.wq, t04); set_name(t05, "t05"); assert_shape_2d(t05, n_embd, N*n_batch);
-         struct ggml_tensor * t06 = ggml_reshape_4d (ctx, t05, n_embd/n_head, n_head, N, n_batch); set_name(t06, "t06"); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
-         struct ggml_tensor * t07 = rope (t06); set_name(t07, "t07"); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
-         struct ggml_tensor * t08 = ggml_mul_mat (ctx, layer.wk, t04); set_name(t08, "t08"); assert_shape_2d(t08, n_embd, N*n_batch);
-         struct ggml_tensor * t09 = ggml_reshape_4d (ctx, t08, n_embd/n_head, n_head, N, n_batch); set_name(t09, "t09"); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
-         struct ggml_tensor * t10 = rope (t09); set_name(t10, "t10"); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
-         struct ggml_tensor * t11 = ggml_mul_mat (ctx, t04, layer.wv); set_name(t11, "t11"); assert_shape_2d(t11, N*n_batch, n_embd);
-         struct ggml_tensor * t12 = ggml_reshape_4d (ctx, t11, N, n_batch, n_embd/n_head, n_head); set_name(t12, "t12"); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head);
-         struct ggml_tensor * t13 = ggml_permute (ctx, t07, 0, 2, 1, 3); set_name(t13, "t13"); assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch);
-         struct ggml_tensor * t14 = ggml_permute (ctx, t10, 0, 2, 1, 3); set_name(t14, "t14"); assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch);
-         struct ggml_tensor * t15 = ggml_permute (ctx, t12, 0, 3, 1, 2); set_name(t15, "t15"); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
-         struct ggml_tensor * t16;
-         if (enable_flash_attn) {
-             GGML_ASSERT(false && "TODO: ggml_flash_attn_ext() not yet supported");
-             //t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
-         } else {
-             struct ggml_tensor * t16_0 = ggml_mul_mat (ctx, t14, t13); set_name(t16_0, "t16_0"); assert_shape_4d(t16_0, N, N, n_head, n_batch);
-             struct ggml_tensor * t16_1 = ggml_scale_inplace (ctx, t16_0, kv_scale); set_name(t16_1, "t16_1"); assert_shape_4d(t16_1, N, N, n_head, n_batch);
-             struct ggml_tensor * t16_2 = ggml_diag_mask_inf_inplace(ctx, t16_1, n_past); set_name(t16_2, "t16_2"); assert_shape_4d(t16_2, N, N, n_head, n_batch);
-             struct ggml_tensor * t16_3 = ggml_soft_max_inplace (ctx, t16_2); set_name(t16_3, "t16_3"); assert_shape_4d(t16_3, N, N, n_head, n_batch);
-             t16 = ggml_mul_mat(ctx, t15, t16_3); set_name(t16, "t16"); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
-         }
-         struct ggml_tensor * t17 = ggml_permute (ctx, t16, 0, 2, 1, 3); set_name(t17, "t17"); assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch);
-         struct ggml_tensor * t18 = ggml_cont (ctx, t17); set_name(t18, "t18"); assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch);
-         struct ggml_tensor * t19 = ggml_reshape_2d (ctx, t18, n_embd, N*n_batch); set_name(t19, "t19"); assert_shape_2d(t19, n_embd, N*n_batch);
-         struct ggml_tensor * t20 = ggml_mul_mat (ctx, layer.wo, t19); set_name(t20, "t20"); assert_shape_2d(t20, n_embd, N*n_batch);
-         struct ggml_tensor * t21 = ggml_add (ctx, t20, cur); set_name(t21, "t21"); assert_shape_2d(t21, n_embd, N*n_batch);
-         struct ggml_tensor * t22 = ggml_rms_norm (ctx, t21, f_norm_rms_eps); set_name(t22, "t22"); assert_shape_2d(t22, n_embd, N*n_batch);
-         struct ggml_tensor * t23 = ggml_repeat (ctx, layer.ffn_norm, t22); set_name(t23, "t23"); assert_shape_2d(t23, n_embd, N*n_batch);
-         struct ggml_tensor * t24 = ggml_mul (ctx, t23, t22); set_name(t24, "t24"); assert_shape_2d(t24, n_embd, N*n_batch);
-         struct ggml_tensor * t25 = ggml_mul_mat (ctx, layer.ffn_up, t24); set_name(t25, "t25"); assert_shape_2d(t25, n_ff, N*n_batch);
-         struct ggml_tensor * t26 = ggml_mul_mat (ctx, layer.ffn_gate, t24); set_name(t26, "t26"); assert_shape_2d(t26, n_ff, N*n_batch);
-         struct ggml_tensor * t27 = ggml_silu (ctx, t26); set_name(t27, "t27"); assert_shape_2d(t27, n_ff, N*n_batch);
-         struct ggml_tensor * t28 = ggml_mul (ctx, t27, t25); set_name(t28, "t28"); assert_shape_2d(t28, n_ff, N*n_batch);
-         struct ggml_tensor * t29 = ggml_mul_mat (ctx, layer.ffn_down, t28); set_name(t29, "t29"); assert_shape_2d(t29, n_embd, N*n_batch);
-         struct ggml_tensor * t30 = ggml_add (ctx, t29, t21); set_name(t30, "t30"); assert_shape_2d(t30, n_embd, N*n_batch);
-         cur = t30;
-         checkpoints.push_back(cur);
-     }
-     struct ggml_tensor * t31 = ggml_rms_norm (ctx, cur, f_norm_rms_eps); set_name(t31, "t31"); assert_shape_2d(t31, n_embd, N*n_batch);
-     struct ggml_tensor * t32 = ggml_repeat (ctx, model->norm, t31); set_name(t32, "t32"); assert_shape_2d(t32, n_embd, N*n_batch);
-     struct ggml_tensor * t33 = ggml_mul (ctx, t32, t31); set_name(t33, "t33"); assert_shape_2d(t33, n_embd, N*n_batch);
-     struct ggml_tensor * t34 = ggml_mul_mat (ctx, model->output, t33); set_name(t34, "t34"); assert_shape_2d(t34, n_vocab, N*n_batch);
-     struct ggml_tensor * t35 = ggml_reshape_3d (ctx, t34, n_vocab, N, n_batch); set_name(t35, "t35"); assert_shape_3d(t35, n_vocab, N, n_batch);
-     struct ggml_tensor * t36 = ggml_cross_entropy_loss(ctx, t35, targets); set_name(t36, "t36"); assert_shape_1d(t36, 1);
-
-     checkpoints.push_back(t31);
-     checkpoints.push_back(t32);
-     checkpoints.push_back(t33);
-     checkpoints.push_back(t34);
-     checkpoints.push_back(t35);
-     checkpoints.push_back(t36);
-
-     ggml_build_forward_expand(gf, t36);
-
-     if (enable_checkpointing) {
-         ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints.data(), (int) checkpoints.size());
-     } else {
-         ggml_graph_cpy(gf, gb);
-         ggml_build_backward_expand(ctx, gf, gb, true);
-     }
-
-     if (alloc) {
-         // make sure some tensors are not reallocated by inserting new temporary nodes depending on them
-         int n_leafs_before = gb->n_leafs;
-         int n_nodes_before = gb->n_nodes;
-         // output tensors
-         ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, 1.0f));
-         ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, 1.0f));
-         // input gradient
-         ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
-         // KQ_pos
-         ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
-         GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
-         ggml_set_input(t36->grad);
-
-         // allocating checkpoints in one block to reduce memory fragmentation
-         // note: they will be freed in reverse order
-         for (int i = 0; i < (int) checkpoints.size(); ++i) {
-             if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) {
-                 ggml_set_input(checkpoints[i]);
-             }
-         }
-
-         //int n_leafs_after = gb->n_leafs;
-         //int n_nodes_after = gb->n_nodes;
-         if (measure_only) {
-             // FIXME: will still allocate
-             ggml_gallocr_reserve(alloc, gb);
-         } else {
-             ggml_gallocr_alloc_graph(alloc, gb);
-
-             if (!measure_only) {
-                 int * data = (int *) KQ_pos->data;
-                 for (int i = 0; i < N; ++i) {
-                     data[i] = n_past + i;
-                 }
-             }
-         }
-
-         // remove the additional nodes and leafs
-         for (int i = n_leafs_before; i < gb->n_leafs; ++i) {
-             gb->leafs[i] = NULL;
-         }
-         for (int i = n_nodes_before; i < gb->n_nodes; ++i) {
-             gb->nodes[i] = NULL;
-         }
-         gb->n_leafs = n_leafs_before;
-         gb->n_nodes = n_nodes_before;
-     }
-
-     *logits = t35;
-     return t36;
- }
-
- #define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
- do { \
-     const std::string skey(key); \
-     const int kid = gguf_find_key(ctx, skey.c_str()); \
-     if (kid >= 0) { \
-         enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
-         if (ktype != (type)) { \
-             die_fmt("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \
-         } \
-         (dst) = func(ctx, kid); \
-     } else if (req) { \
-         die_fmt("key not found in model: %s", skey.c_str()); \
-     } \
- } while (0)
-
- static void load_llama_model_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model) {
-     // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
-     std::string arch;
-
-     std::vector<char> keybuf;
-     keybuf.resize(512);
-     auto kv = [&arch, &keybuf](const char * key) -> const char * {
-         snprintf(keybuf.data(), keybuf.size(), key, arch.c_str());
-         return keybuf.data();
-     };
-
-     std::vector<char> tn_buf;
-     tn_buf.resize(GGML_MAX_NAME);
-     auto tn = [&tn_buf](const char * key) -> const char * {
-         snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", key);
-         return tn_buf.data();
-     };
-     auto tni = [&tn_buf](const char * key, int bid) -> const char * {
-         snprintf(tn_buf.data(), tn_buf.size(), key, bid);
-         std::string s = tn_buf.data();
-         snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", s.c_str());
-         return tn_buf.data();
-     };
-
-     GGUF_GET_KEY(fctx, arch, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_GENERAL_ARCHITECTURE);
-     GGML_ASSERT(arch == "llama");
-
-     uint32_t ftype_u;
-     GGUF_GET_KEY(fctx, ftype_u, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_GENERAL_FILE_TYPE);
-     GGML_ASSERT((enum llama_ftype) ftype_u == LLAMA_FTYPE_ALL_F32);
-
-     // n_ctx was not saved in earlier checkpoint file versions, so we make it optional here
-     GGUF_GET_KEY(fctx, model->hparams.n_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_CONTEXT_LENGTH));
-
-     GGUF_GET_KEY(fctx, model->hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH));
-     GGUF_GET_KEY(fctx, model->hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
-     GGUF_GET_KEY(fctx, model->hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
-     GGUF_GET_KEY(fctx, model->hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
-
-     model->hparams.n_rot = model->hparams.n_embd / model->hparams.n_head;
-     GGUF_GET_KEY(fctx, model->hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));
-
-     float rope_freq_scale = 1.0f;
-     GGUF_GET_KEY(fctx, model->hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
-     GGUF_GET_KEY(fctx, model->hparams.rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
-     GGUF_GET_KEY(fctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
-     if (rope_freq_scale != 1.0f) {
-         model->hparams.rope_freq_scale = 1.0f / rope_freq_scale;
-     }
-
-     init_model(model);
-
-     copy_tensor_by_name(model->tok_embeddings, f_ggml_ctx, tn(LLM_TENSOR_TOKEN_EMBD));
-     copy_tensor_by_name(model->norm, f_ggml_ctx, tn(LLM_TENSOR_OUTPUT_NORM));
-     copy_tensor_by_name(model->output, f_ggml_ctx, tn(LLM_TENSOR_OUTPUT));
-
-     for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
-         auto & layer = model->layers[i];
-
-         copy_tensor_by_name(layer.attention_norm, f_ggml_ctx, tni(LLM_TENSOR_ATTN_NORM, i));
-         copy_tensor_by_name(layer.wq, f_ggml_ctx, tni(LLM_TENSOR_ATTN_Q, i));
-         copy_tensor_by_name(layer.wk, f_ggml_ctx, tni(LLM_TENSOR_ATTN_K, i));
-         copy_tensor_by_name(layer.wv, f_ggml_ctx, tni(LLM_TENSOR_ATTN_V, i));
-         copy_tensor_by_name(layer.wo, f_ggml_ctx, tni(LLM_TENSOR_ATTN_OUT, i));
-         copy_tensor_by_name(layer.ffn_norm, f_ggml_ctx, tni(LLM_TENSOR_FFN_NORM, i));
-         copy_tensor_by_name(layer.ffn_gate, f_ggml_ctx, tni(LLM_TENSOR_FFN_GATE, i));
-         copy_tensor_by_name(layer.ffn_down, f_ggml_ctx, tni(LLM_TENSOR_FFN_DOWN, i));
-         copy_tensor_by_name(layer.ffn_up, f_ggml_ctx, tni(LLM_TENSOR_FFN_UP, i));
-     }
- }
-
- static void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model) {
-     const char * arch = "llama";
-
-     enum llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
-
-     std::vector<char> keybuf;
-     keybuf.resize(512);
-     auto kv = [arch, &keybuf](const char * key) -> const char * {
-         snprintf(keybuf.data(), keybuf.size(), key, arch);
-         return keybuf.data();
-     };
-
-     // set arch
-     gguf_set_val_str(fctx, LLM_KV_GENERAL_ARCHITECTURE, arch);
-     gguf_set_val_str(fctx, LLM_KV_GENERAL_NAME, arch);
-     gguf_set_val_u32(fctx, LLM_KV_GENERAL_FILE_TYPE, ftype);
-
-     // set hparams
-     gguf_set_val_u32(fctx, kv(LLM_KV_CONTEXT_LENGTH), model->hparams.n_ctx );
-     gguf_set_val_u32(fctx, kv(LLM_KV_EMBEDDING_LENGTH), model->hparams.n_embd );
-     gguf_set_val_u32(fctx, kv(LLM_KV_FEED_FORWARD_LENGTH), model->hparams.n_ff );
-     gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT), model->hparams.n_head );
-     gguf_set_val_u32(fctx, kv(LLM_KV_BLOCK_COUNT), model->hparams.n_layer );
-     gguf_set_val_u32(fctx, kv(LLM_KV_ROPE_DIMENSION_COUNT), model->hparams.n_rot );
-
-     gguf_set_val_f32(fctx, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS), model->hparams.f_norm_rms_eps );
-     gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_FREQ_BASE), model->hparams.rope_freq_base ); // TODO load in llama.cpp
-     gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_SCALE_LINEAR), 1.0f / model->hparams.rope_freq_scale );
-
-     // set vocab by copying from vocab_model gguf file
-     {
-         struct gguf_init_params params = {
-             /*.no_alloc = */ false,
-             /*.ctx = */ NULL,
-         };
-         struct gguf_context * vctx = gguf_init_from_file(fn_vocab_model, params);
-
-         const int token_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_LIST));
-         if (token_idx == -1) {
-             die("cannot find tokenizer vocab in model file");
-         }
-         const uint32_t n_vocab = gguf_get_arr_n(vctx, token_idx);
-
-         const int score_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_SCORES));
-         if (score_idx == -1) {
-             die("cannot find tokenizer scores in model file");
-         }
-
-         const float * scores = (const float * ) gguf_get_arr_data(vctx, score_idx);
-
-         const int toktype_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE));
-         if (toktype_idx == -1) {
-             die("cannot find token type list in GGUF file");
-         }
-
-         const int * toktypes = (const int * ) gguf_get_arr_data(vctx, toktype_idx);
-
-         std::string tokenizer_name;
-         GGUF_GET_KEY(vctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL));
-
-         gguf_set_val_str(fctx, kv(LLM_KV_TOKENIZER_MODEL), tokenizer_name.c_str());
-         gguf_set_arr_data(fctx, kv(LLM_KV_TOKENIZER_SCORES), GGUF_TYPE_FLOAT32, scores, n_vocab);
-         gguf_set_arr_data(fctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE), GGUF_TYPE_INT32, toktypes, n_vocab);
-
-         int32_t special_bos_id = 1;
-         int32_t special_eos_id = 2;
-         int32_t special_unk_id = 0;
-         int32_t special_sep_id = -1;
-         int32_t special_pad_id = -1;
-         if (tokenizer_name == "llama") {
-             // default special tokens
-             special_bos_id = 1;
-             special_eos_id = 2;
-             special_unk_id = 0;
-             special_sep_id = -1;
-             special_pad_id = -1;
-         } else if (tokenizer_name == "gpt2") {
-             // read and copy bpe merges
-             const int merges_keyidx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_MERGES));
-             if (merges_keyidx == -1) {
-                 die("cannot find tokenizer merges in model file");
-             }
-
-             const int n_merges = gguf_get_arr_n(vctx, merges_keyidx);
-
-             std::vector<const char*> merges;
-             merges.resize(n_merges);
-             for (int i = 0; i < n_merges; i++) {
-                 merges[i] = gguf_get_arr_str(vctx, merges_keyidx, i);
-             }
-             gguf_set_arr_str(fctx, kv(LLM_KV_TOKENIZER_MERGES), merges.data(), n_merges);
-
-             // default special tokens
-             special_bos_id = 11;
-             special_eos_id = 11;
-             special_unk_id = -1;
-             special_sep_id = -1;
-             special_pad_id = -1;
-         } else {
-             fprintf(stderr, "%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
-             fprintf(stderr, "%s: using default tokenizer: 'llama'", __func__);
-         }
-
-         std::vector<const char*> tokens;
-         tokens.resize(n_vocab);
-         for (uint32_t i = 0; i < n_vocab; i++) {
-             tokens[i] = gguf_get_arr_str(vctx, token_idx, i);
-         }
-         gguf_set_arr_str(fctx, kv(LLM_KV_TOKENIZER_LIST), tokens.data(), n_vocab);
-
-         GGUF_GET_KEY(vctx, special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID));
-         GGUF_GET_KEY(vctx, special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_EOS_ID));
-         GGUF_GET_KEY(vctx, special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID));
-         GGUF_GET_KEY(vctx, special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID));
-         GGUF_GET_KEY(vctx, special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID));
-
-         gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_BOS_ID), special_bos_id);
-         gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_EOS_ID), special_eos_id);
-         gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_UNK_ID), special_unk_id);
-         gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_SEP_ID), special_sep_id);
-         gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_PAD_ID), special_pad_id);
-
-         gguf_free(vctx);
-     }
-
-     // add tensors
-     gguf_add_tensor(fctx, model->tok_embeddings);
-     gguf_add_tensor(fctx, model->norm);
-     gguf_add_tensor(fctx, model->output);
-     for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
-         auto & layer = model->layers[i];
-
-
-         gguf_add_tensor(fctx, layer.attention_norm);
-         gguf_add_tensor(fctx, layer.wq);
-         gguf_add_tensor(fctx, layer.wk);
-         gguf_add_tensor(fctx, layer.wv);
-         gguf_add_tensor(fctx, layer.wo);
-         gguf_add_tensor(fctx, layer.ffn_norm);
-         gguf_add_tensor(fctx, layer.ffn_gate);
-         gguf_add_tensor(fctx, layer.ffn_down);
-         gguf_add_tensor(fctx, layer.ffn_up);
-     }
- }
-
- static void save_llama_model_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model) {
-     printf("%s: saving to %s\n", __func__, filename);
-     struct gguf_context * fctx = gguf_init_empty();
-
-     save_llama_model_gguf(fctx, fn_vocab_model, model);
-
-     // write file
-     const bool only_meta = false;
-     gguf_write_to_file(fctx, filename, only_meta);
-     gguf_free(fctx);
- }
-
- static void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct train_state * train) {
-     load_llama_model_gguf(fctx, f_ggml_ctx, model);
-     if (load_train_state_gguf(fctx, f_ggml_ctx, train)) {
-         std::string train_type = LLM_KV_TRAINING_TYPE_TRAIN_MODEL;
-         GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
-         GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_TRAIN_MODEL);
-     } else {
-         printf("%s: loaded llama model as checkpoint\n", __func__);
-     }
- }
-
- static void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct train_state * train) {
-     gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_TRAIN_MODEL);
-     save_llama_model_gguf(fctx, fn_vocab_model, model);
-     save_train_state_gguf(fctx, train);
- }
-
- static bool load_checkpoint_file(const char * filename, struct my_llama_model * model, struct train_state * train) {
-     struct ggml_context * f_ggml_ctx;
-     struct gguf_init_params params;
-     params.no_alloc = false;
-     params.ctx = &f_ggml_ctx;
-     struct gguf_context * fctx = gguf_init_from_file(filename, params);
-     if (fctx == NULL) {
-         return false;
-     }
-
-     load_checkpoint_gguf(fctx, f_ggml_ctx, model, train);
-
-     gguf_free(fctx);
-     return true;
- }
-
- static void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct train_state * train) {
-     printf("%s: saving to %s\n", __func__, filename);
-     struct gguf_context * fctx = gguf_init_empty();
-
-     save_checkpoint_gguf(fctx, fn_vocab_model, model, train);
-
-     // write file
-     const bool only_meta = false;
-     gguf_write_to_file(fctx, filename, only_meta);
-     gguf_free(fctx);
- }
-
- struct train_params {
-     struct train_params_common common;
-
-     const char * fn_vocab_model;
-     const char * fn_model_out;
-
-     bool only_write_model;
-
-     int n_ctx;
-     int n_embd;
-     int n_head;
-     int n_layer;
-     int n_ff;
-
-     float f_norm_rms_eps;
-     float rope_freq_base;
-     float rope_freq_scale;
- };
-
- static struct train_params get_default_train_params() {
-     struct train_params params;
-     params.common = get_default_train_params_common();
-     params.fn_vocab_model = "ggml-vic7b-uncensored-q4_0.bin";
-     params.fn_model_out = "ggml-checkpoint-f32.bin";
-
-     params.only_write_model = false;
-
-     params.n_ctx = 128;
-     params.n_embd = 256;
-     params.n_head = 8;
-     params.n_layer = 16;
-     params.n_ff = 768;
-
-     params.f_norm_rms_eps = 1e-5f;
-     params.rope_freq_base = 10000.0f;
-     params.rope_freq_scale = 1.0f;
-
-     return params;
- }
-
- static void train_print_usage(int argc, char ** argv, const struct train_params * params) {
-     fprintf(stderr, "usage: %s [options]\n", argv[0]);
-     fprintf(stderr, "\n");
-     fprintf(stderr, "options:\n");
-     fprintf(stderr, " -h, --help show this help message and exit\n");
-
-     fprintf(stderr, " --vocab-model FNAME model path from which to load vocab (default '%s')\n", params->fn_vocab_model);
-     fprintf(stderr, " --model-out FNAME path to save ggml model (default '%s')\n", params->fn_model_out);
-     fprintf(stderr, " --only-write-model only save llama model, don't do any training. use this if you only want to convert a checkpoint to a model.\n");
-     fprintf(stderr, " --embd N Embedding size used for new models (default %d)\n", params->n_embd);
-     fprintf(stderr, " --ff N Feedforward size used for new models. (default %d)\n", params->n_ff);
-     fprintf(stderr, " --head N Number of heads for new models (default %d)\n", params->n_head);
-     fprintf(stderr, " --layer N Number of layers for new models (default %d)\n", params->n_layer);
-     fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps);
-     fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base);
-     fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
-
-     print_common_train_usage(argc, argv, &params->common);
- }
-
- static bool train_params_parse(int argc, char ** argv, struct train_params * params) {
795
- bool invalid_param = false;
796
- std::string arg;
797
- struct train_params default_params = get_default_train_params();
798
- const std::string arg_prefix = "--";
799
-
800
- for (int i = 1; i < argc; i++) {
801
- arg = argv[i];
802
- if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
803
- std::replace(arg.begin(), arg.end(), '_', '-');
804
- }
805
-
806
- if (consume_common_train_arg(argc, argv, &i, &params->common, &invalid_param)) {
807
- if (invalid_param) {
808
- break;
809
- } else if (params->common.print_usage) {
810
- train_print_usage(argc, argv, &default_params);
811
- exit(0);
812
- }
813
- } else if (arg == "--vocab-model") {
814
- if (++i >= argc) {
815
- invalid_param = true;
816
- break;
817
- }
818
- params->fn_vocab_model = argv[i];
819
- } else if (arg == "--model-out") {
820
- if (++i >= argc) {
821
- invalid_param = true;
822
- break;
823
- }
824
- params->fn_model_out = argv[i];
825
- } else if (arg == "--only-write-model") {
826
- params->only_write_model = true;
827
- } else if (arg == "--embd") {
828
- if (++i >= argc) {
829
- invalid_param = true;
830
- break;
831
- }
832
- params->n_embd = std::stoi(argv[i]);
833
- } else if (arg == "--ff") {
834
- if (++i >= argc) {
835
- invalid_param = true;
836
- break;
837
- }
838
- params->n_ff = std::stoi(argv[i]);
839
- } else if (arg == "--head") {
840
- if (++i >= argc) {
841
- invalid_param = true;
842
- break;
843
- }
844
- params->n_head = std::stoi(argv[i]);
845
- } else if (arg == "--layer") {
846
- if (++i >= argc) {
847
- invalid_param = true;
848
- break;
849
- }
850
- params->n_layer = std::stoi(argv[i]);
851
- } else if (arg == "--norm-rms-eps") {
852
- if (++i >= argc) {
853
- invalid_param = true;
854
- break;
855
- }
856
- params->f_norm_rms_eps = std::stof(argv[i]);
857
- } else if (arg == "--rope-freq-base") {
858
- if (++i >= argc) {
859
- invalid_param = true;
860
- break;
861
- }
862
- params->rope_freq_base = std::stof(argv[i]);
863
- } else if (arg == "--rope-freq-scale") {
864
- if (++i >= argc) {
865
- invalid_param = true;
866
- break;
867
- }
868
- params->rope_freq_scale = std::stof(argv[i]);
869
- } else {
870
- fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
871
- train_print_usage(argc, argv, &default_params);
872
- exit(1);
873
- }
874
- }
875
- if (invalid_param) {
876
- fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
877
- train_print_usage(argc, argv, &default_params);
878
- exit(1);
879
- }
880
- finish_processing_train_args(&params->common);
881
-
882
- return true;
883
- }
884
-
885
- struct save_train_files_data {
886
- const char * fn_checkpoint_out;
887
- const char * fn_model_out;
888
- const char * fn_vocab_model;
889
- const char * pattern_fn_it;
890
- const char * fn_latest;
891
- struct my_llama_model * model;
892
- };
893
-
894
- static void save_train_files(void * vdata, struct train_state * train) {
895
- struct save_train_files_data * data = (struct save_train_files_data *) vdata;
896
- int64_t iter = train->opt->iter;
897
-
898
- if (strlen(data->fn_checkpoint_out) > 0) {
899
- save_checkpoint_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->fn_vocab_model, data->model, train);
900
- save_checkpoint_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, -1 ).c_str(), data->fn_vocab_model, data->model, train);
901
-
902
- }
903
- if (strlen(data->fn_model_out) > 0) {
904
- save_llama_model_file(get_train_filename(data->fn_model_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->fn_vocab_model, data->model);
905
- save_llama_model_file(get_train_filename(data->fn_model_out, data->pattern_fn_it, data->fn_latest, -1 ).c_str(), data->fn_vocab_model, data->model);
906
- }
907
- }
908
-
909
- static int64_t get_parameter_count(struct my_llama_model* model) {
910
- int64_t nx = 0;
911
- nx += ggml_nelements(model->tok_embeddings);
912
- nx += ggml_nelements(model->norm);
913
- nx += ggml_nelements(model->output);
914
-
915
- for (uint32_t i = 0; i < model->layers.size(); ++i) {
916
- auto & layer = model->layers[i];
917
- nx += ggml_nelements(layer.attention_norm);
918
- nx += ggml_nelements(layer.wq);
919
- nx += ggml_nelements(layer.wk);
920
- nx += ggml_nelements(layer.wv);
921
- nx += ggml_nelements(layer.wo);
922
- nx += ggml_nelements(layer.ffn_norm);
923
- nx += ggml_nelements(layer.ffn_gate);
924
- nx += ggml_nelements(layer.ffn_down);
925
- nx += ggml_nelements(layer.ffn_up);
926
- }
927
- return nx;
928
- }
929
-
930
- int main(int argc, char ** argv) {
931
- struct train_params params = get_default_train_params();
932
-
933
- if (!train_params_parse(argc, argv, &params)) {
934
- return 1;
935
- }
936
-
937
- if (params.common.seed == LLAMA_DEFAULT_SEED) {
938
- params.common.seed = time(NULL);
939
- }
940
- printf("%s: seed: %u\n", __func__, params.common.seed);
941
- srand(params.common.seed);
942
-
943
- struct llama_model_params mparams = llama_model_default_params();
944
- mparams.vocab_only = true;
945
-
946
- struct llama_context_params cparams = llama_context_default_params();
947
-
948
- struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, mparams);
949
- struct llama_context * lctx = llama_new_context_with_model(lmodel, cparams);
950
-
951
- struct my_llama_model model;
952
- model.hparams.n_vocab = llama_n_vocab(lmodel);
953
- model.hparams.n_ctx = params.common.n_ctx;
954
- model.hparams.n_embd = params.n_embd;
955
- model.hparams.n_head = params.n_head;
956
- model.hparams.n_layer = params.n_layer;
957
- model.hparams.n_ff = params.n_ff;
958
- // llama.cpp requires n_rot to be exactly n_embd / n_head
959
- model.hparams.n_rot = model.hparams.n_embd / model.hparams.n_head;
960
- model.hparams.f_norm_rms_eps = params.f_norm_rms_eps;
961
- model.hparams.rope_freq_base = params.rope_freq_base;
962
- model.hparams.rope_freq_scale = params.rope_freq_scale;
963
-
964
- struct train_state * train = init_train_state();
965
- struct ggml_opt_context * opt = train->opt;
966
-
967
- // set opt params from command line
968
- opt->params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
969
- opt->params.print_forward_graph = false;
970
- opt->params.print_backward_graph = false;
971
- opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
972
- opt->params.n_threads = params.common.n_threads;
973
- opt->params.past = params.common.opt_past;
974
- opt->params.delta = params.common.opt_delta;
975
- opt->params.max_no_improvement = params.common.opt_max_no_improvement;
976
- opt->params.n_gradient_accumulation = params.common.n_gradient_accumulation;
977
- opt->params.adam.n_iter = params.common.adam_n_iter;
978
- opt->params.adam.sched = 1.0f;
979
- opt->params.adam.alpha = params.common.adam_alpha;
980
- opt->params.adam.decay = params.common.adam_decay;
981
- opt->params.adam.decay_min_ndim = params.common.adam_decay_min_ndim;
982
- opt->params.adam.beta1 = params.common.adam_beta1;
983
- opt->params.adam.beta2 = params.common.adam_beta2;
984
- opt->params.adam.gclip = params.common.adam_gclip;
985
- opt->params.adam.eps_f = params.common.adam_eps_f;
986
-
987
- printf("%s: init model\n", __func__);
988
- bool existed = load_checkpoint_file(params.common.fn_checkpoint_in, &model, train);
989
- if (existed) {
990
- // overwrite last n_ctx with user provided n_ctx
991
- if (params.common.custom_n_ctx) {
992
- model.hparams.n_ctx = params.common.n_ctx;
993
- }
994
-
995
- const bool opt_past_changed = opt->params.past != params.common.opt_past;
996
-
997
- if (opt_past_changed) {
998
- die("Optimizer parameter '--opt-past N' differs from checkpoint file. To use different value train from scratch with empty input checkpoint, e.g --checkpoint-in ''. Aborting");
999
- // need to discard previous optimizer past function value statistics and opt_init with new shapes
1000
- // TODO
1001
- }
1002
- } else {
1003
- init_model(&model);
1004
- randomize_model(&model, params.common.seed, 0.0f, 1.0f, -1.0f, +1.0f);
1005
- if (!params.only_write_model) {
1006
- ggml_opt_init(opt->ctx, opt, opt->params, get_parameter_count(&model));
1007
- }
1008
- }
1009
- opt->iter = train->train_its;
1010
-
1011
- print_params(&model.hparams);
1012
- printf("%s: total train_iterations %llu\n", __func__, (long long unsigned) train->train_its);
1013
- printf("%s: seen train_samples %llu\n", __func__, (long long unsigned) train->train_samples);
1014
- printf("%s: seen train_tokens %llu\n", __func__, (long long unsigned) train->train_tokens);
1015
- printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs);
1016
- printf("%s: model_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(model.ctx) + ggml_backend_buffer_get_size(model.data)), (float) (ggml_used_mem(model.ctx) + ggml_backend_buffer_get_size(model.data)) / (1024.0f*1024.0f));
1017
-
1018
- if (params.only_write_model) {
1019
- save_train_files_data save_data;
1020
- save_data.fn_checkpoint_out = "";
1021
- save_data.fn_model_out = params.fn_model_out;
1022
- save_data.fn_vocab_model = params.fn_vocab_model;
1023
- save_data.pattern_fn_it = params.common.pattern_fn_it;
1024
- save_data.fn_latest = params.common.fn_latest;
1025
- save_data.model = &model;
1026
-
1027
- save_train_files(&save_data, train);
1028
-
1029
- free_train_state(train);
1030
- ggml_free(model.ctx);
1031
- llama_free(lctx);
1032
- llama_free_model(lmodel);
1033
- return 0;
1034
- }
1035
-
1036
- printf("%s: opt_size = %zu bytes (%.1f MB)\n", __func__, ggml_get_mem_size(opt->ctx), (float) ggml_get_mem_size(opt->ctx) / (1024.0f*1024.0f));
1037
- printf("%s: opt iter %d\n", __func__, opt->iter);
1038
-
1039
- int n_tokens = model.hparams.n_ctx;
1040
- int n_vocab = model.hparams.n_vocab;
1041
- int n_batch = params.common.n_batch;
1042
-
1043
- // context for input tensors without their data
1044
- struct ggml_init_params ctx_input_params = {
1045
- ggml_tensor_overhead() * 2, // mem_size
1046
- NULL, // mem_buffer
1047
- true, // no_alloc
1048
- };
1049
- struct ggml_context * ctx_input = ggml_init(ctx_input_params);
1050
-
1051
- // the input tensors
1052
- struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx_input, GGML_TYPE_I32, n_tokens, n_batch);
1053
- struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
1054
-
1055
- // measure required memory for input tensors
1056
- // allocate input tensors
1057
- ggml_backend_buffer_t input_data = ggml_backend_alloc_ctx_tensors_from_buft(ctx_input, ggml_backend_cpu_buffer_type());
1058
- size_t max_input_size = ggml_backend_buffer_get_size(input_data);
1059
- printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
1060
-
1061
- // context for compute tensors without their data
1062
- const size_t estimated_compute_size_wo_data = (
1063
- 2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() +
1064
- (params.common.use_checkpointing ? 3 : 2)*(GGML_OBJECT_SIZE+ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, true))
1065
- );
1066
- struct ggml_init_params ctx_compute_params = {
1067
- estimated_compute_size_wo_data, // mem_size
1068
- NULL, // mem_buffer
1069
- true, // no_alloc
1070
- };
1071
- struct ggml_context * ctx_compute = NULL;
1072
-
1073
- struct ggml_tensor * loss = NULL;
1074
- struct ggml_tensor * logits = NULL;
1075
-
1076
- struct ggml_cgraph * gf = NULL;
1077
- struct ggml_cgraph * gb = NULL;
1078
- struct ggml_cgraph * gb_tmp = NULL;
1079
-
1080
- // measure required memory for compute tensors
1081
- size_t best_compute_size = SIZE_MAX;
1082
- enum ggml_cgraph_eval_order best_order = GGML_CGRAPH_EVAL_ORDER_COUNT;
1083
- // find best evaluation order
1084
- for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
1085
- ctx_compute = ggml_init(ctx_compute_params);
1086
- ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
1087
- gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
1088
- gf->order = (enum ggml_cgraph_eval_order) order;
1089
- gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
1090
- gb_tmp = params.common.use_checkpointing
1091
- ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
1092
- : NULL;
1093
- loss = llama_build_train_graphs(
1094
- &model, alloc, ctx_compute,
1095
- gf, gb, gb_tmp,
1096
- &logits, tokens_input, target_probs,
1097
- n_tokens, n_batch,
1098
- params.common.use_flash,
1099
- params.common.use_checkpointing,
1100
- true
1101
- );
1102
- size_t max_compute_size = ggml_gallocr_get_buffer_size(alloc, 0); // FIXME: this will still allocate the buffer
1103
- if (max_compute_size < best_compute_size) {
1104
- best_compute_size = max_compute_size;
1105
- best_order = gf->order;
1106
- }
1107
- ggml_free(ctx_compute);
1108
- }
1109
- size_t max_compute_size = best_compute_size;
1110
- printf("%s: compute_size = %zu bytes (%.1f MB)\n", __func__, max_compute_size, (float) max_compute_size / (1024.0f*1024.0f));
1111
- printf("%s: evaluation order = %s\n", __func__,
1112
- (best_order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? "LEFT_TO_RIGHT" :
1113
- (best_order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? "RIGHT_TO_LEFT" :
1114
- "invalid");
1115
-
1116
- // allocate compute tensors
1117
- ctx_compute = ggml_init(ctx_compute_params);
1118
- ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
1119
- gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
1120
- gf->order = best_order;
1121
- gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
1122
- gb_tmp = params.common.use_checkpointing
1123
- ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
1124
- : NULL;
1125
- loss = llama_build_train_graphs(
1126
- &model, alloc, ctx_compute,
1127
- gf, gb, gb_tmp,
1128
- &logits, tokens_input, target_probs,
1129
- n_tokens, n_batch,
1130
- params.common.use_flash,
1131
- params.common.use_checkpointing,
1132
- false
1133
- );
1134
-
1135
- std::vector<llama_token> train_tokens;
1136
- std::vector<size_t> train_samples_begin;
1137
- std::vector<size_t> train_samples_size;
1138
- printf("%s: tokenize training data\n", __func__);
1139
- tokenize_file(lctx,
1140
- params.common.fn_train_data,
1141
- params.common.sample_start,
1142
- params.common.include_sample_start,
1143
- params.common.overlapping_samples,
1144
- n_tokens,
1145
- train_tokens,
1146
- train_samples_begin,
1147
- train_samples_size);
1148
- GGML_ASSERT(train_samples_begin.size() == train_samples_size.size());
1149
-
1150
- printf("%s: number of training tokens: %zu\n", __func__, train_tokens.size());
1151
-
1152
- size_t shuffle_samples_hash = compute_samples_hash(params.common.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
1153
- const bool changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size());
1154
- if (changed_train_data) {
1155
- printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__);
1156
- }
1157
- if (params.common.force_reshuffle) {
1158
- printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__);
1159
- }
1160
- if ((train->shuffle_rng_state_current == "") || changed_train_data || params.common.force_reshuffle) {
1161
- train->shuffle_rng_state_current = mt19937_seed_to_state(params.common.seed);
1162
- train->shuffle_sample_count = train_samples_size.size();
1163
- train->shuffle_next_sample = 0;
1164
- train->shuffle_samples_hash = shuffle_samples_hash;
1165
- }
1166
- std::vector<size_t> train_shuffled_samples_offs;
1167
- std::vector<size_t> train_shuffled_samples_begin;
1168
- std::vector<size_t> train_shuffled_samples_size;
1169
- train_shuffled_samples_offs.resize(train_samples_begin.size());
1170
- train_shuffled_samples_begin.resize(train_samples_begin.size());
1171
- train_shuffled_samples_size.resize(train_samples_size.size());
1172
- train->shuffle_rng_state_next = shuffle_samples(
1173
- train->shuffle_rng_state_current,
1174
- train_shuffled_samples_offs.data(),
1175
- train_shuffled_samples_begin.data(),
1176
- train_shuffled_samples_size.data(),
1177
- train_samples_begin.data(),
1178
- train_samples_size.data(),
1179
- train_samples_size.size());
1180
- printf("%s: begin training\n", __func__);
1181
-
1182
- save_train_files_data save_data;
1183
- save_data.fn_checkpoint_out = params.common.fn_checkpoint_out;
1184
- save_data.fn_model_out = params.fn_model_out;
1185
- save_data.fn_vocab_model = params.fn_vocab_model;
1186
- save_data.pattern_fn_it = params.common.pattern_fn_it;
1187
- save_data.fn_latest = params.common.fn_latest;
1188
- save_data.model = &model;
1189
-
1190
- struct train_opt_callback_data opt_cb_data;
1191
- opt_cb_data.params = &params.common;
1192
- opt_cb_data.train = train;
1193
- opt_cb_data.save_cb = &save_train_files;
1194
- opt_cb_data.save_data = &save_data;
1195
- opt_cb_data.lctx = lctx;
1196
- opt_cb_data.last_save_iter = opt->iter;
1197
- opt_cb_data.tokens_data = train_tokens.data();
1198
- opt_cb_data.tokens_size = train_tokens.size();
1199
- opt_cb_data.samples_begin = train_samples_begin.data();
1200
- opt_cb_data.samples_size = train_samples_size.data();
1201
- opt_cb_data.shuffled_samples_offs = train_shuffled_samples_offs.data();
1202
- opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data();
1203
- opt_cb_data.shuffled_samples_size = train_shuffled_samples_size.data();
1204
- opt_cb_data.samples_count = train_samples_size.size();
1205
- opt_cb_data.tokens_input = tokens_input;
1206
- opt_cb_data.target_probs = target_probs;
1207
- opt_cb_data.first_iter = opt->iter;
1208
- opt_cb_data.first_epoch = train->train_epochs;
1209
- opt_cb_data.iter_at_last_epoch = -1;
1210
- opt_cb_data.last_time = ggml_time_ms();
1211
- opt_cb_data.millis_per_iter = 0.0;
1212
-
1213
- // measure required memory for work buffer
1214
- size_t max_work_size = ggml_graph_plan(gb, params.common.n_threads).work_size + GGML_OBJECT_SIZE;
1215
- printf("%s: work_size = %zu bytes (%.1f MB)\n", __func__, max_work_size, (float) max_work_size / (1024.0f*1024.0f));
1216
-
1217
- // context for work buffer
1218
- struct ggml_init_params ctx_work_params = {
1219
- max_work_size, // mem_size
1220
- NULL, // mem_buffer
1221
- false, // no_alloc
1222
- };
1223
- struct ggml_context * ctx_work = ggml_init(ctx_work_params);
1224
-
1225
- int64_t t0 = ggml_time_ms();
1226
-
1227
- ggml_opt_resume_g(ctx_work, opt, loss, gf, gb, &train_opt_callback, (void *) &opt_cb_data);
1228
-
1229
- ggml_free(ctx_work);
1230
- ggml_free(ctx_compute);
1231
- ggml_free(ctx_input);
1232
-
1233
- int64_t t1 = ggml_time_ms();
1234
- printf("%s: total training time: ", __func__);
1235
- print_duration((double) (t1 - t0));
1236
- printf("\n");
1237
-
1238
- int new_iters = opt->iter - opt_cb_data.last_save_iter;
1239
- if (new_iters > 0) {
1240
- train->train_its += new_iters;
1241
- train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;
1242
-
1243
- save_train_files(&save_data, train);
1244
- opt_cb_data.last_save_iter = opt->iter;
1245
- }
1246
-
1247
- ggml_free(opt->ctx);
1248
- free_train_state(train);
1249
- ggml_free(model.ctx);
1250
- llama_free(lctx);
1251
- llama_free_model(lmodel);
1252
- return 0;
1253
- }