@novastera-oss/llamarn 0.2.1 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. package/README.md +80 -14
  2. package/RNLlamaCpp.podspec +10 -3
  3. package/android/CMakeLists.txt +8 -0
  4. package/android/src/main/cpp/include/llama.h +62 -125
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  13. package/cpp/PureCppImpl.cpp +9 -27
  14. package/cpp/SystemUtils.h +2 -2
  15. package/cpp/build-info.cpp +2 -2
  16. package/cpp/llama.cpp/README.md +11 -3
  17. package/cpp/llama.cpp/build-xcframework.sh +1 -0
  18. package/cpp/llama.cpp/common/CMakeLists.txt +8 -2
  19. package/cpp/llama.cpp/common/arg.cpp +153 -113
  20. package/cpp/llama.cpp/common/chat-parser.cpp +379 -0
  21. package/cpp/llama.cpp/common/chat-parser.h +117 -0
  22. package/cpp/llama.cpp/common/chat.cpp +847 -699
  23. package/cpp/llama.cpp/common/chat.h +73 -6
  24. package/cpp/llama.cpp/common/common.cpp +50 -82
  25. package/cpp/llama.cpp/common/common.h +21 -17
  26. package/cpp/llama.cpp/common/json-partial.cpp +255 -0
  27. package/cpp/llama.cpp/common/json-partial.h +37 -0
  28. package/cpp/llama.cpp/common/minja/chat-template.hpp +9 -5
  29. package/cpp/llama.cpp/common/minja/minja.hpp +69 -36
  30. package/cpp/llama.cpp/common/regex-partial.cpp +204 -0
  31. package/cpp/llama.cpp/common/regex-partial.h +56 -0
  32. package/cpp/llama.cpp/common/sampling.cpp +7 -8
  33. package/cpp/llama.cpp/convert_hf_to_gguf.py +453 -118
  34. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +120 -68
  35. package/cpp/llama.cpp/ggml/CMakeLists.txt +2 -1
  36. package/cpp/llama.cpp/ggml/cmake/common.cmake +25 -0
  37. package/cpp/llama.cpp/ggml/include/ggml-opt.h +49 -28
  38. package/cpp/llama.cpp/ggml/include/ggml.h +26 -7
  39. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +16 -10
  40. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +4 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +2 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +604 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +42 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +54 -2
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +50 -51
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -2
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -9
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +779 -19
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +22 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +322 -100
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +117 -1
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +220 -49
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/acc.cu +40 -26
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -1
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +11 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +15 -7
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +266 -64
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +49 -4
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +48 -4
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +2 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +5 -1
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +2 -0
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +7 -6
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +1 -1
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +10 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-impl.h +1 -1
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +4 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +99 -17
  75. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +200 -2
  76. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +8 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +6 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +972 -178
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +373 -190
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -10
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +101 -5
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +31 -33
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +1 -0
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +29 -2
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +4 -5
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  96. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +9 -1
  97. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +84 -72
  98. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  100. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -3
  101. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +324 -129
  102. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +1 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +31 -2
  104. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +95 -68
  105. package/cpp/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +1 -0
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +22 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -2
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -4
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +2 -3
  110. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +69 -43
  111. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +2 -14
  112. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -91
  113. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -181
  114. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +17 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  116. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +6 -152
  117. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  118. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  119. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +2 -118
  120. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +1 -1
  121. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +12 -1
  122. package/cpp/llama.cpp/ggml/src/ggml.c +107 -36
  123. package/cpp/llama.cpp/ggml/src/gguf.cpp +33 -33
  124. package/cpp/llama.cpp/gguf-py/gguf/constants.py +100 -15
  125. package/cpp/llama.cpp/gguf-py/gguf/gguf_reader.py +1 -1
  126. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +44 -12
  127. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py +21 -10
  128. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +5 -2
  129. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +128 -31
  130. package/cpp/llama.cpp/gguf-py/gguf/utility.py +1 -1
  131. package/cpp/llama.cpp/gguf-py/pyproject.toml +1 -1
  132. package/cpp/llama.cpp/include/llama.h +62 -125
  133. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +1 -1
  134. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +1 -1
  135. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +1 -1
  136. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +1 -1
  137. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +1 -1
  138. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +1 -1
  139. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +1 -1
  140. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +1 -1
  141. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +1 -1
  142. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +1 -1
  143. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +1 -1
  144. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +1 -1
  145. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +1 -1
  146. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +1 -1
  147. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +1 -1
  148. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +1 -1
  149. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +1 -1
  150. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +1 -1
  151. package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
  152. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +1 -1
  153. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +1 -1
  154. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +1 -1
  155. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +1 -1
  156. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +1 -1
  157. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +1 -1
  158. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +1 -1
  159. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +1 -1
  160. package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +62 -0
  161. package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +85 -0
  162. package/cpp/llama.cpp/models/templates/README.md +2 -0
  163. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +5 -1
  164. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +5 -1
  165. package/cpp/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
  166. package/cpp/llama.cpp/requirements/requirements-gguf_editor_gui.txt +1 -1
  167. package/cpp/llama.cpp/src/CMakeLists.txt +2 -0
  168. package/cpp/llama.cpp/src/llama-arch.cpp +6 -0
  169. package/cpp/llama.cpp/src/llama-arch.h +2 -0
  170. package/cpp/llama.cpp/src/llama-batch.cpp +3 -1
  171. package/cpp/llama.cpp/src/llama-context.cpp +340 -123
  172. package/cpp/llama.cpp/src/llama-context.h +30 -0
  173. package/cpp/llama.cpp/src/llama-cparams.cpp +4 -0
  174. package/cpp/llama.cpp/src/llama-cparams.h +2 -0
  175. package/cpp/llama.cpp/src/llama-grammar.cpp +12 -2
  176. package/cpp/llama.cpp/src/llama-graph.cpp +157 -247
  177. package/cpp/llama.cpp/src/llama-graph.h +52 -7
  178. package/cpp/llama.cpp/src/llama-hparams.cpp +17 -1
  179. package/cpp/llama.cpp/src/llama-hparams.h +37 -5
  180. package/cpp/llama.cpp/src/llama-kv-cache.cpp +742 -481
  181. package/cpp/llama.cpp/src/llama-kv-cache.h +196 -99
  182. package/cpp/llama.cpp/src/llama-kv-cells.h +379 -0
  183. package/cpp/llama.cpp/src/llama-memory.h +4 -3
  184. package/cpp/llama.cpp/src/llama-model-loader.cpp +22 -17
  185. package/cpp/llama.cpp/src/llama-model-saver.cpp +281 -0
  186. package/cpp/llama.cpp/src/llama-model-saver.h +37 -0
  187. package/cpp/llama.cpp/src/llama-model.cpp +529 -172
  188. package/cpp/llama.cpp/src/llama-model.h +6 -1
  189. package/cpp/llama.cpp/src/llama-quant.cpp +15 -13
  190. package/cpp/llama.cpp/src/llama-sampling.cpp +2 -2
  191. package/cpp/llama.cpp/src/llama-vocab.cpp +35 -8
  192. package/cpp/llama.cpp/src/llama-vocab.h +6 -0
  193. package/cpp/llama.cpp/src/llama.cpp +14 -0
  194. package/cpp/rn-completion.cpp +60 -5
  195. package/ios/include/chat.h +73 -6
  196. package/ios/include/common/minja/chat-template.hpp +9 -5
  197. package/ios/include/common/minja/minja.hpp +69 -36
  198. package/ios/include/common.h +21 -17
  199. package/ios/include/llama.h +62 -125
  200. package/ios/libs/llama.xcframework/Info.plist +19 -19
  201. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  202. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4617 -4487
  203. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  204. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +26 -7
  205. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +62 -125
  206. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  207. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  208. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  209. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3557 -3435
  210. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  211. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  212. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  213. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3559 -3437
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +237 -0
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +26 -7
  219. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +62 -125
  220. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +237 -0
  221. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +26 -7
  222. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +62 -125
  223. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  224. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +237 -0
  225. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +26 -7
  226. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +62 -125
  227. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  228. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  229. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4616 -4487
  231. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  232. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +26 -7
  233. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +62 -125
  234. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  235. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  236. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4637 -4508
  237. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3556 -3435
  238. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  239. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  240. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  241. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  242. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  243. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4653 -4523
  244. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  245. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +26 -7
  246. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +62 -125
  247. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  248. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  249. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4674 -4544
  250. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3587 -3465
  251. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  252. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  253. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  254. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  255. package/package.json +1 -1
  256. package/cpp/llama.cpp/common/stb_image.h +0 -7988
  257. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +0 -112
  258. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.out +0 -46
  259. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +0 -112
  260. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +0 -46
  261. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +0 -112
  262. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +0 -46
  263. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.inp +0 -112
  264. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.out +0 -46
  265. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +0 -112
  266. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.out +0 -46
  267. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +0 -112
  268. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +0 -46
@@ -28,16 +28,19 @@ struct ggml_opt_dataset {
28
28
  };
29
29
 
30
30
  struct ggml_opt_context {
31
- ggml_backend_sched_t backend_sched = nullptr;
32
- ggml_cgraph * allocated_graph = nullptr;
33
- ggml_cgraph * allocated_graph_copy = nullptr;
34
- struct ggml_context * ctx_static = nullptr;
35
- struct ggml_context * ctx_static_cpu = nullptr;
36
- struct ggml_context * ctx_compute = nullptr;
37
- struct ggml_context * ctx_copy = nullptr;
38
- ggml_backend_buffer_t buf_static = nullptr;
39
- ggml_backend_buffer_t buf_static_cpu = nullptr;
40
- std::mt19937 rng;
31
+ ggml_backend_sched_t backend_sched = nullptr;
32
+ ggml_cgraph * allocated_graph = nullptr;
33
+ ggml_cgraph * allocated_graph_copy = nullptr;
34
+ struct ggml_context * ctx_static = nullptr;
35
+ struct ggml_context * ctx_cpu = nullptr;
36
+ struct ggml_context * ctx_compute = nullptr;
37
+ struct ggml_context * ctx_copy = nullptr;
38
+ ggml_backend_buffer_t buf_static = nullptr;
39
+ ggml_backend_buffer_t buf_cpu = nullptr;
40
+ std::mt19937 rng;
41
+ enum ggml_opt_loss_type loss_type;
42
+ enum ggml_opt_build_type build_type;
43
+ enum ggml_opt_build_type build_type_alloc;
41
44
 
42
45
  struct ggml_tensor * inputs = nullptr;
43
46
  struct ggml_tensor * outputs = nullptr;
@@ -50,6 +53,11 @@ struct ggml_opt_context {
50
53
  struct ggml_cgraph * gf = nullptr;
51
54
  struct ggml_cgraph * gb_grad = nullptr;
52
55
  struct ggml_cgraph * gb_opt = nullptr;
56
+ bool static_graphs = false;
57
+ bool eval_ready = false;
58
+ std::vector<struct ggml_tensor *> grad_accs;
59
+ std::vector<struct ggml_tensor *> grad_m;
60
+ std::vector<struct ggml_tensor *> grad_v;
53
61
 
54
62
  int64_t iter = 1;
55
63
  int32_t opt_period = 1;
@@ -73,7 +81,13 @@ struct ggml_opt_result {
73
81
 
74
82
  // ====== Dataset ======
75
83
 
76
- ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) {
84
+ ggml_opt_dataset_t ggml_opt_dataset_init(
85
+ enum ggml_type type_data,
86
+ enum ggml_type type_label,
87
+ int64_t ne_datapoint,
88
+ int64_t ne_label,
89
+ int64_t ndata,
90
+ int64_t ndata_shard) {
77
91
  GGML_ASSERT(ne_datapoint > 0);
78
92
  GGML_ASSERT(ne_label >= 0);
79
93
  GGML_ASSERT(ndata > 0);
@@ -92,11 +106,11 @@ ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label,
92
106
  result->ctx = ggml_init(params);
93
107
  }
94
108
 
95
- result->data = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_datapoint, ndata);
109
+ result->data = ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata);
96
110
  result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata;
97
111
 
98
112
  if (ne_label > 0) {
99
- result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata);
113
+ result->labels = ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata);
100
114
  result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata;
101
115
  } else {
102
116
  result->labels = nullptr;
@@ -119,6 +133,10 @@ void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) {
119
133
  delete dataset;
120
134
  }
121
135
 
136
+ int64_t ggml_opt_dataset_ndata(ggml_opt_dataset_t dataset) {
137
+ return dataset->ndata;
138
+ }
139
+
122
140
  struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) {
123
141
  return dataset->data;
124
142
  }
@@ -144,6 +162,8 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor *
144
162
  GGML_ASSERT( data_batch && ggml_is_contiguous(data_batch));
145
163
  GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch));
146
164
  GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
165
+ GGML_ASSERT( data_batch->type == dataset->data->type);
166
+ GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type);
147
167
 
148
168
  const size_t nb_data_batch = ggml_nbytes(data_batch);
149
169
  GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
@@ -171,6 +191,31 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor *
171
191
  }
172
192
  }
173
193
 
194
+ void ggml_opt_dataset_get_batch_host(ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) {
195
+ GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
196
+ GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
197
+
198
+ const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;
199
+
200
+ GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));
201
+
202
+ for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
203
+ const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];
204
+
205
+ const char * ptr_data = (const char *) dataset->data->data + ishard *dataset->nbs_data;
206
+ char * ptr_data_batch = (char *) data_batch + ishard_batch*dataset->nbs_data;
207
+ memcpy(ptr_data_batch, ptr_data, dataset->nbs_data);
208
+
209
+ if (!labels_batch) {
210
+ continue;
211
+ }
212
+
213
+ const char * ptr_labels = (const char *) dataset->labels->data + ishard *dataset->nbs_labels;
214
+ char * ptr_labels_batch = (char *) labels_batch + ishard_batch*dataset->nbs_labels;
215
+ memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels);
216
+ }
217
+ }
218
+
174
219
  // ====== Model / Context ======
175
220
 
176
221
  struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) {
@@ -187,17 +232,18 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
187
232
  return result;
188
233
  }
189
234
 
235
+ struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
236
+ return *((struct ggml_opt_optimizer_params *) userdata);
237
+ }
238
+
190
239
  struct ggml_opt_params ggml_opt_default_params(
191
240
  ggml_backend_sched_t backend_sched,
192
- struct ggml_context * ctx_compute,
193
- struct ggml_tensor * inputs,
194
- struct ggml_tensor * outputs,
195
241
  enum ggml_opt_loss_type loss_type) {
196
242
  return {
197
243
  /*backend_sched =*/ backend_sched,
198
- /*ctx_compute =*/ ctx_compute,
199
- /*inputs =*/ inputs,
200
- /*logits =*/ outputs,
244
+ /*ctx_compute =*/ nullptr,
245
+ /*inputs =*/ nullptr,
246
+ /*logits =*/ nullptr,
201
247
  /*loss_type =*/ loss_type,
202
248
  /*build_type =*/ GGML_OPT_BUILD_TYPE_OPT,
203
249
  /*opt_period =*/ 1,
@@ -266,195 +312,246 @@ static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) {
266
312
  return dst;
267
313
  }
268
314
 
269
- static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) {
270
- GGML_ASSERT(graph);
271
- if (opt_ctx->allocated_graph == graph) {
272
- return;
273
- }
315
+ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
316
+ GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
317
+ GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
274
318
 
275
- ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
319
+ const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
320
+ !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
276
321
 
277
- {
278
- ggml_init_params params = {
279
- /*.mem_size =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE,
280
- /*.mem_buffer =*/ nullptr,
281
- /*.no_alloc =*/ true,
282
- };
283
- ggml_free(opt_ctx->ctx_copy);
284
- opt_ctx->ctx_copy = ggml_init(params);
285
- }
286
-
287
- opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
288
-
289
- ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
290
- opt_ctx->allocated_graph = graph;
291
- }
292
-
293
- ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
294
- ggml_opt_context_t result = new struct ggml_opt_context;
295
- result->backend_sched = params.backend_sched;
296
- result->ctx_compute = params.ctx_compute;
297
- result->inputs = params.inputs;
298
- result->outputs = params.outputs;
299
- result->opt_period = params.opt_period;
300
- result->get_opt_pars = params.get_opt_pars;
301
- result->get_opt_pars_ud = params.get_opt_pars_ud;
302
-
303
- GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically");
304
- GGML_ASSERT(result->opt_period >= 1);
305
-
306
- const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD ||
307
- (params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1);
308
-
309
- ggml_set_input(result->inputs);
310
- ggml_set_output(result->outputs);
311
-
312
- result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
313
- ggml_build_forward_expand(result->gf, result->outputs);
322
+ ggml_set_input(opt_ctx->inputs);
323
+ ggml_set_output(opt_ctx->outputs);
314
324
 
315
325
  int n_param = 0;
316
- for (int i = 0; i < result->gf->n_nodes; ++i) {
317
- if (result->gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
326
+ for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) {
327
+ const struct ggml_tensor * node = opt_ctx->gf->nodes[i];
328
+ if (node->flags & GGML_TENSOR_FLAG_PARAM) {
318
329
  n_param++;
319
330
  }
331
+ GGML_ASSERT(!(node->flags & GGML_TENSOR_FLAG_LOSS) && "support for extra loss terms not implemented");
320
332
  }
321
333
 
322
- {
334
+ if (!opt_ctx->ctx_static) {
323
335
  // The static context is used for:
324
- // - gradients (1 tensor per param if using gradient accumulation)
336
+ // - gradients (1 per loss, 1 tensor per param if using gradient accumulation)
325
337
  // - optimizer momenta (2 tensors per param)
326
- // - labels
327
- // - loss + its gradient (up to 5 tensors)
328
- // - pred
329
- // - ncorrect (2 tensors).
330
- const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
331
- const size_t size_meta = (tensors_per_param*n_param + 9) * ggml_tensor_overhead();
338
+ // - labels (if using static graphs)
339
+ // - loss (if using static graphs, up to 5 tensors)
340
+ // - pred (if using static graphs)
341
+ // - ncorrect (if using static graphs, 2 tensors).
342
+ constexpr size_t n_loss = 1;
343
+ const size_t tensors_per_param = (accumulate ? 1 : 0) +
344
+ (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
345
+ const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
346
+ const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
332
347
  struct ggml_init_params params = {
333
348
  /*.mem_size =*/ size_meta,
334
349
  /*.mem_buffer =*/ nullptr,
335
350
  /*.no_alloc =*/ true,
336
351
  };
337
- result->ctx_static = ggml_init(params);
352
+ opt_ctx->ctx_static = ggml_init(params);
338
353
  }
354
+ GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc);
355
+
339
356
  {
340
- // The static cpu context is used for:
341
- // - optimizer parameters (1 for the entire context)
357
+ // The cpu context is allocated statically if using static graphs, dynamically otherwise.
358
+ // It is used for:
359
+ // - optimizer parameters (1 shared for all optimizer invocations)
342
360
  const size_t size_meta = 1 * ggml_tensor_overhead();
343
361
  struct ggml_init_params params = {
344
362
  /*.mem_size =*/ size_meta,
345
363
  /*.mem_buffer =*/ nullptr,
346
364
  /*.no_alloc =*/ true,
347
365
  };
348
- result->ctx_static_cpu = ggml_init(params);
366
+ ggml_free(opt_ctx->ctx_cpu);
367
+ opt_ctx->ctx_cpu = ggml_init(params);
368
+
369
+ ggml_backend_buffer_free(opt_ctx->buf_cpu);
370
+ opt_ctx->buf_cpu = nullptr;
349
371
  }
350
372
 
373
+ struct ggml_context * ctx_results = opt_ctx->static_graphs ? opt_ctx->ctx_static : opt_ctx->ctx_compute;
351
374
 
352
- switch (params.loss_type) {
375
+ switch (opt_ctx->loss_type) {
353
376
  case GGML_OPT_LOSS_TYPE_MEAN: {
354
- result->loss = ggml_sum(result->ctx_static, result->outputs);
355
- ggml_set_name(result->loss, "loss_sum");
356
- const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
357
- result->loss = ggml_scale(result->ctx_static, result->loss, scale);
358
- ggml_set_name(result->loss, "loss_mean");
359
- result->loss_per_datapoint = true;
377
+ opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
378
+ ggml_set_name(opt_ctx->loss, "loss_sum");
379
+ const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
380
+ opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
381
+ ggml_set_name(opt_ctx->loss, "loss_mean");
382
+ opt_ctx->loss_per_datapoint = true;
360
383
  break;
361
384
  }
362
385
  case GGML_OPT_LOSS_TYPE_SUM: {
363
- result->loss = ggml_sum(result->ctx_static, result->outputs);
364
- ggml_set_name(result->loss, "loss_sum");
365
- result->loss_per_datapoint = false;
386
+ opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
387
+ ggml_set_name(opt_ctx->loss, "loss_sum");
388
+ opt_ctx->loss_per_datapoint = false;
366
389
  break;
367
390
  }
368
391
  case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {
369
- result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
370
- ggml_set_input(result->labels);
371
- ggml_set_name(result->labels, "labels");
372
- result->loss = ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels);
373
- ggml_set_name(result->loss, "loss_cross_entropy");
374
- if (result->opt_period > 1) {
375
- result->loss = ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period);
376
- ggml_set_name(result->loss, "loss_cross_entropy_scaled");
392
+ opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
393
+ ggml_set_input(opt_ctx->labels);
394
+ ggml_set_name(opt_ctx->labels, "labels");
395
+ opt_ctx->loss = ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels);
396
+ ggml_set_name(opt_ctx->loss, "loss_cross_entropy");
397
+ if (opt_ctx->opt_period > 1) {
398
+ opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period);
399
+ ggml_set_name(opt_ctx->loss, "loss_cross_entropy_scaled");
377
400
  }
378
- result->loss_per_datapoint = true;
401
+ opt_ctx->loss_per_datapoint = true;
379
402
  break;
380
403
  }
381
404
  case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {
382
- result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
383
- ggml_set_input(result->labels);
384
- ggml_set_name(result->labels, "labels");
385
- result->loss = ggml_sub(result->ctx_static, result->outputs, result->labels);
386
- ggml_set_name(result->loss, "loss_error");
387
- result->loss = ggml_sqr(result->ctx_static, result->loss);
388
- ggml_set_name(result->loss, "loss_squared_error");
389
- result->loss = ggml_sum(result->ctx_static, result->loss);
390
- ggml_set_name(result->loss, "loss_sum_squared_error");
391
- const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
392
- result->loss = ggml_scale(result->ctx_static, result->loss, scale);
393
- ggml_set_name(result->loss, "loss_mean_squared_error");
394
- result->loss_per_datapoint = true;
405
+ opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
406
+ ggml_set_input(opt_ctx->labels);
407
+ ggml_set_name(opt_ctx->labels, "labels");
408
+ opt_ctx->loss = ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels);
409
+ ggml_set_name(opt_ctx->loss, "loss_error");
410
+ opt_ctx->loss = ggml_sqr(ctx_results, opt_ctx->loss);
411
+ ggml_set_name(opt_ctx->loss, "loss_squared_error");
412
+ opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->loss);
413
+ ggml_set_name(opt_ctx->loss, "loss_sum_squared_error");
414
+ const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
415
+ opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
416
+ ggml_set_name(opt_ctx->loss, "loss_mean_squared_error");
417
+ opt_ctx->loss_per_datapoint = true;
395
418
  break;
396
419
  }
397
420
  }
398
- ggml_set_output(result->loss);
399
- ggml_set_loss(result->loss);
400
- ggml_build_forward_expand(result->gf, result->loss);
401
-
402
- result->pred = ggml_argmax(result->ctx_static, result->outputs);
403
- ggml_set_name(result->pred, "pred");
404
- ggml_set_output(result->pred);
405
- ggml_build_forward_expand(result->gf, result->pred);
421
+ ggml_set_output(opt_ctx->loss);
422
+ ggml_set_loss(opt_ctx->loss);
423
+ ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss);
424
+
425
+ if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) {
426
+ opt_ctx->pred = ggml_argmax(ctx_results, opt_ctx->outputs);
427
+ ggml_set_name(opt_ctx->pred, "pred");
428
+ ggml_set_output(opt_ctx->pred);
429
+ ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred);
430
+
431
+ opt_ctx->ncorrect = ggml_count_equal(ctx_results, opt_ctx->pred, ggml_argmax(ctx_results, opt_ctx->labels));
432
+ ggml_set_name(opt_ctx->ncorrect, "ncorrect");
433
+ ggml_set_output(opt_ctx->ncorrect);
434
+ ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect);
435
+ }
406
436
 
407
- if (result->labels) {
408
- result->ncorrect = ggml_count_equal(result->ctx_static, result->pred, ggml_argmax(result->ctx_static, result->labels));
409
- ggml_set_name(result->ncorrect, "ncorrect");
410
- ggml_set_output(result->ncorrect);
411
- ggml_build_forward_expand(result->gf, result->ncorrect);
412
- } else {
413
- result->ncorrect = nullptr;
437
+ if (opt_ctx->buf_static) {
438
+ if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
439
+ return;
440
+ }
441
+ } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_FORWARD) {
442
+ opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
443
+ opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
444
+ return;
414
445
  }
415
446
 
416
- if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
417
- result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
418
- return result;
447
+ if (opt_ctx->grad_accs.empty()) {
448
+ GGML_ASSERT(opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD);
449
+
450
+ const int n_nodes = opt_ctx->gf->n_nodes;
451
+ opt_ctx->grad_accs.resize(n_nodes);
452
+ for (int i = 0; i < n_nodes; ++i) {
453
+ ggml_tensor * node = opt_ctx->gf->nodes[i];
454
+ if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
455
+ opt_ctx->grad_accs[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
456
+ } else {
457
+ opt_ctx->grad_accs[i] = nullptr;
458
+ }
459
+ }
460
+
461
+ if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
462
+ opt_ctx->grad_m.resize(n_nodes);
463
+ opt_ctx->grad_v.resize(n_nodes);
464
+ for (int i = 0; i < n_nodes; ++i) {
465
+ ggml_tensor * node = opt_ctx->gf->nodes[i];
466
+ if (node->flags & GGML_TENSOR_FLAG_PARAM) {
467
+ opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
468
+ opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
469
+ } else {
470
+ opt_ctx->grad_m[i] = nullptr;
471
+ opt_ctx->grad_v[i] = nullptr;
472
+ }
473
+ }
474
+ }
419
475
  }
420
476
 
421
477
  // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
422
- result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf);
423
- ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate);
478
+ opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true);
479
+ ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data());
424
480
 
425
- if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) {
426
- result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
427
- ggml_graph_reset(result->gb_grad);
428
- return result;
481
+ if (opt_ctx->buf_static) {
482
+ if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_GRAD) {
483
+ return;
484
+ }
485
+ } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_GRAD) {
486
+ opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
487
+ ggml_graph_reset(opt_ctx->gb_grad);
429
488
  }
430
489
 
431
- GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT);
490
+ GGML_ASSERT(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT);
432
491
 
433
492
  // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
434
- result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad);
493
+ opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
435
494
 
436
- result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7);
437
- ggml_set_input(result->adamw_params);
438
- ggml_set_name(result->adamw_params, "adamw_params");
495
+ opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7);
496
+ ggml_set_input(opt_ctx->adamw_params);
497
+ ggml_set_name(opt_ctx->adamw_params, "adamw_params");
439
498
 
440
- for (int i = result->gf->n_nodes-1; i >= 0; --i) {
441
- struct ggml_tensor * node = result->gb_opt->nodes[i];
442
- struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node);
499
+ for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
500
+ struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
501
+ struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);
443
502
 
444
- if (node->flags & GGML_TENSOR_FLAG_PARAM) {
445
- struct ggml_tensor * m = ggml_dup_tensor(result->ctx_static, node);
446
- struct ggml_tensor * v = ggml_dup_tensor(result->ctx_static, node);
447
- struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params);
448
- ggml_build_forward_expand(result->gb_opt, opt_step);
503
+ if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
504
+ struct ggml_tensor * m = opt_ctx->grad_m[i];
505
+ struct ggml_tensor * v = opt_ctx->grad_v[i];
506
+ struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
507
+
508
+ ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str());
509
+ ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str());
510
+ ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
511
+
512
+ ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
449
513
  }
450
514
  }
451
515
 
452
- result->buf_static = ggml_backend_alloc_ctx_tensors(
453
- result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
516
+ if (!opt_ctx->buf_static) {
517
+ opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
518
+ opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
519
+ ggml_graph_reset(opt_ctx->gb_opt);
520
+ }
454
521
 
455
- result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type());
522
+ opt_ctx->buf_cpu = ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, ggml_backend_cpu_buffer_type());
523
+ }
456
524
 
457
- ggml_graph_reset(result->gb_opt);
525
+ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
526
+ ggml_opt_context_t result = new struct ggml_opt_context;
527
+ result->backend_sched = params.backend_sched;
528
+ result->ctx_compute = params.ctx_compute;
529
+ result->loss_type = params.loss_type;
530
+ result->build_type = params.build_type;
531
+ result->build_type_alloc = params.build_type;
532
+ result->inputs = params.inputs;
533
+ result->outputs = params.outputs;
534
+ result->opt_period = params.opt_period;
535
+ result->get_opt_pars = params.get_opt_pars;
536
+ result->get_opt_pars_ud = params.get_opt_pars_ud;
537
+
538
+ GGML_ASSERT(result->opt_period >= 1);
539
+
540
+ result->static_graphs = result->ctx_compute;
541
+
542
+ if (!result->static_graphs) {
543
+ GGML_ASSERT(!result->inputs);
544
+ GGML_ASSERT(!result->outputs);
545
+ return result;
546
+ }
547
+
548
+ GGML_ASSERT(result->inputs);
549
+ GGML_ASSERT(result->outputs);
550
+
551
+ result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
552
+ ggml_build_forward_expand(result->gf, result->outputs);
553
+
554
+ ggml_opt_build(result);
458
555
 
459
556
  return result;
460
557
  }
@@ -464,9 +561,9 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) {
464
561
  return;
465
562
  }
466
563
  ggml_backend_buffer_free(opt_ctx->buf_static);
467
- ggml_backend_buffer_free(opt_ctx->buf_static_cpu);
564
+ ggml_backend_buffer_free(opt_ctx->buf_cpu);
468
565
  ggml_free(opt_ctx->ctx_static);
469
- ggml_free(opt_ctx->ctx_static_cpu);
566
+ ggml_free(opt_ctx->ctx_cpu);
470
567
  delete opt_ctx;
471
568
  }
472
569
 
@@ -479,6 +576,10 @@ void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) {
479
576
  }
480
577
  }
481
578
 
579
+ bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx) {
580
+ return opt_ctx->static_graphs;
581
+ }
582
+
482
583
  struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
483
584
  return opt_ctx->inputs;
484
585
  }
@@ -582,8 +683,79 @@ void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, doubl
582
683
 
583
684
  // ====== Computation ======
584
685
 
585
- static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) {
586
- if (graph != opt_ctx->gf) {
686
+ void ggml_opt_prepare_alloc(
687
+ ggml_opt_context_t opt_ctx,
688
+ struct ggml_context * ctx_compute,
689
+ struct ggml_cgraph * gf,
690
+ struct ggml_tensor * inputs,
691
+ struct ggml_tensor * outputs) {
692
+ GGML_ASSERT(!opt_ctx->static_graphs);
693
+ opt_ctx->ctx_compute = ctx_compute;
694
+ opt_ctx->gf = gf;
695
+ opt_ctx->inputs = inputs;
696
+ opt_ctx->outputs = outputs;
697
+ }
698
+
699
+ void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
700
+ GGML_ASSERT(!opt_ctx->eval_ready);
701
+ if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) {
702
+ ggml_graph_reset(opt_ctx->gb_grad);
703
+ }
704
+ if (backward) {
705
+ const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
706
+ opt_ctx->build_type = opt_i_next == 0 ? GGML_OPT_BUILD_TYPE_OPT : GGML_OPT_BUILD_TYPE_GRAD;
707
+ } else {
708
+ opt_ctx->build_type = GGML_OPT_BUILD_TYPE_FORWARD;
709
+ }
710
+
711
+ if (!opt_ctx->static_graphs) {
712
+ ggml_opt_build(opt_ctx);
713
+ }
714
+
715
+ struct ggml_cgraph * graph = nullptr;
716
+ switch (opt_ctx->build_type) {
717
+ case GGML_OPT_BUILD_TYPE_FORWARD: {
718
+ graph = opt_ctx->gf;
719
+ } break;
720
+ case GGML_OPT_BUILD_TYPE_GRAD: {
721
+ graph = opt_ctx->gb_grad;
722
+ } break;
723
+ case GGML_OPT_BUILD_TYPE_OPT: {
724
+ graph = opt_ctx->gb_opt;
725
+ } break;
726
+ }
727
+ GGML_ASSERT(graph);
728
+
729
+ if (opt_ctx->allocated_graph == graph) {
730
+ opt_ctx->eval_ready = true;
731
+ return;
732
+ }
733
+
734
+ ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
735
+
736
+ if (opt_ctx->static_graphs) {
737
+ ggml_init_params params = {
738
+ /*.mem_size =*/ graph->size*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph->size, graph->grads),
739
+ /*.mem_buffer =*/ nullptr,
740
+ /*.no_alloc =*/ true,
741
+ };
742
+ ggml_free(opt_ctx->ctx_copy);
743
+ opt_ctx->ctx_copy = ggml_init(params);
744
+
745
+ opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
746
+ } else {
747
+ opt_ctx->allocated_graph_copy = graph;
748
+ }
749
+
750
+ ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
751
+ opt_ctx->allocated_graph = graph;
752
+
753
+ opt_ctx->eval_ready = true;
754
+ }
755
+
756
+ void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
757
+ GGML_ASSERT(opt_ctx->eval_ready);
758
+ if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
587
759
  struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
588
760
 
589
761
  GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
@@ -609,9 +781,19 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
609
781
  adamw_par_data[6] = beta2h;
610
782
  }
611
783
 
612
- ggml_opt_alloc_graph(opt_ctx, graph);
613
784
  ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
614
785
  opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt;
786
+ opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
787
+
788
+ if (!opt_ctx->static_graphs) {
789
+ opt_ctx->gf = nullptr;
790
+ opt_ctx->gb_grad = nullptr;
791
+ opt_ctx->gb_opt = nullptr;
792
+ opt_ctx->allocated_graph = nullptr;
793
+ opt_ctx->allocated_graph_copy = nullptr;
794
+ }
795
+
796
+ opt_ctx->eval_ready = false;
615
797
 
616
798
  if (!result) {
617
799
  return;
@@ -635,12 +817,14 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
635
817
  ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss));
636
818
  result->loss.push_back(loss);
637
819
 
638
- GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
639
- std::vector<int32_t> pred(ndata);
640
- ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
641
- result->pred.insert(result->pred.end(), pred.begin(), pred.end());
820
+ if (opt_ctx->pred) {
821
+ GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
822
+ std::vector<int32_t> pred(ndata);
823
+ ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
824
+ result->pred.insert(result->pred.end(), pred.begin(), pred.end());
825
+ }
642
826
 
643
- if (!opt_ctx->labels || result->ncorrect < 0) {
827
+ if (!opt_ctx->ncorrect || result->ncorrect < 0) {
644
828
  result->ncorrect = -1;
645
829
  return;
646
830
  }
@@ -652,26 +836,6 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
652
836
  result->ncorrect += ncorrect;
653
837
  }
654
838
 
655
- void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
656
- ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result);
657
- }
658
-
659
- void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
660
- if (opt_ctx->opt_period == 1) {
661
- ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
662
- return;
663
- }
664
-
665
- const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
666
- if (opt_i_next == 0) {
667
- ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
668
- ggml_opt_reset(opt_ctx, /*optimizer =*/ false);
669
- } else {
670
- ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result);
671
- }
672
- opt_ctx->opt_i = opt_i_next;
673
- }
674
-
675
839
  // ====== High-Level Functions ======
676
840
 
677
841
  void ggml_opt_epoch(
@@ -682,6 +846,7 @@ void ggml_opt_epoch(
682
846
  int64_t idata_split,
683
847
  ggml_opt_epoch_callback callback_train,
684
848
  ggml_opt_epoch_callback callback_eval) {
849
+ GGML_ASSERT(ggml_opt_static_graphs(opt_ctx) && "ggml_opt_epoch requires static graphs");
685
850
  struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
686
851
  struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
687
852
  struct ggml_tensor * data = ggml_opt_dataset_data(dataset);
@@ -700,16 +865,18 @@ void ggml_opt_epoch(
700
865
  int64_t ibatch = 0;
701
866
  int64_t t_loop_start = ggml_time_us();
702
867
  for (; ibatch < ibatch_split; ++ibatch) {
868
+ ggml_opt_alloc(opt_ctx, /*backward =*/ true);
703
869
  ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
704
- ggml_opt_forward_backward(opt_ctx, result_train);
870
+ ggml_opt_eval(opt_ctx, result_train);
705
871
  if (callback_train) {
706
872
  callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);
707
873
  }
708
874
  }
709
875
  t_loop_start = ggml_time_us();
710
876
  for (; ibatch < nbatches; ++ibatch) {
877
+ ggml_opt_alloc(opt_ctx, /*backward =*/ false);
711
878
  ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
712
- ggml_opt_forward(opt_ctx, result_eval);
879
+ ggml_opt_eval(opt_ctx, result_eval);
713
880
  if (callback_eval) {
714
881
  callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);
715
882
  }
@@ -726,13 +893,26 @@ void ggml_opt_epoch_callback_progress_bar(
726
893
  int64_t t_start_us) {
727
894
  fprintf(stderr, "%s[", train ? "train: " : "val: ");
728
895
 
729
- constexpr int64_t bar_length = 25;
896
+ // The progress bar consists of partially filled blocks, unicode has 8 separate fill levels.
897
+ constexpr int64_t bar_length = 8;
898
+ const int64_t ibatch8 = 8 * ibatch;
730
899
  for (int64_t j = 0; j < bar_length; ++j) {
731
- const int64_t ibatch_j = ibatch_max * j/bar_length;
732
- if (ibatch_j < ibatch) {
733
- fprintf(stderr, "=");
734
- } else if (ibatch_max * (j - 1)/bar_length < ibatch) {
735
- fprintf(stderr, ">");
900
+ if (ibatch_max * (8*j + 8) / bar_length < ibatch8) {
901
+ fprintf(stderr, "\u2588"); // full block
902
+ } else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) {
903
+ fprintf(stderr, "\u2589"); // 7/8 filled
904
+ } else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) {
905
+ fprintf(stderr, "\u258A"); // 6/8 filled
906
+ } else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) {
907
+ fprintf(stderr, "\u258B"); // 5/8 filled
908
+ } else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) {
909
+ fprintf(stderr, "\u258C"); // 4/8 filled
910
+ } else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) {
911
+ fprintf(stderr, "\u258D"); // 3/8 filled
912
+ } else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) {
913
+ fprintf(stderr, "\u258E"); // 2/8 filled
914
+ } else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) {
915
+ fprintf(stderr, "\u258F"); // 1/8 filled
736
916
  } else {
737
917
  fprintf(stderr, " ");
738
918
  }
@@ -764,8 +944,8 @@ void ggml_opt_epoch_callback_progress_bar(
764
944
  const int64_t t_eta_m = t_eta_s / 60;
765
945
  t_eta_s -= t_eta_m * 60;
766
946
 
767
- fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, "
768
- "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r",
947
+ fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% "
948
+ "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r",
769
949
  idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
770
950
  t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
771
951
  if (ibatch == ibatch_max) {
@@ -806,7 +986,10 @@ void ggml_opt_fit(
806
986
 
807
987
  int64_t epoch = 1;
808
988
 
809
- ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
989
+ ggml_opt_params params = ggml_opt_default_params(backend_sched, loss_type);
990
+ params.ctx_compute = ctx_compute;
991
+ params.inputs = inputs;
992
+ params.outputs = outputs;
810
993
  params.opt_period = opt_period;
811
994
  params.get_opt_pars = get_opt_pars;
812
995
  params.get_opt_pars_ud = &epoch;