cui-llama.rn 1.6.1 → 1.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. package/android/src/main/CMakeLists.txt +6 -0
  2. package/android/src/main/java/com/rnllama/LlamaContext.java +38 -5
  3. package/android/src/main/java/com/rnllama/RNLlama.java +139 -4
  4. package/android/src/main/jni.cpp +153 -14
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  13. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
  14. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
  15. package/cpp/chat.cpp +128 -106
  16. package/cpp/chat.h +2 -0
  17. package/cpp/common.cpp +41 -76
  18. package/cpp/common.h +23 -19
  19. package/cpp/ggml-backend.cpp +9 -5
  20. package/cpp/ggml-backend.h +4 -4
  21. package/cpp/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  22. package/cpp/ggml-cpu/ggml-cpu-quants.c +306 -6
  23. package/cpp/ggml-cpu/ggml-cpu.c +5 -13
  24. package/cpp/ggml-cpu/ggml-cpu.cpp +29 -16
  25. package/cpp/ggml-cpu/ops.cpp +107 -13
  26. package/cpp/ggml-cpu/vec.cpp +0 -6
  27. package/cpp/ggml-cpu/vec.h +16 -0
  28. package/cpp/ggml-llama-sim.metallib +0 -0
  29. package/cpp/ggml-llama.metallib +0 -0
  30. package/cpp/ggml-metal-impl.h +36 -11
  31. package/cpp/ggml-metal.m +321 -132
  32. package/cpp/ggml-opt.cpp +373 -190
  33. package/cpp/ggml-opt.h +49 -28
  34. package/cpp/ggml-quants.c +0 -6
  35. package/cpp/ggml.c +93 -38
  36. package/cpp/ggml.h +21 -7
  37. package/cpp/gguf.cpp +33 -33
  38. package/cpp/llama-adapter.cpp +6 -0
  39. package/cpp/llama-arch.cpp +3 -0
  40. package/cpp/llama-batch.cpp +3 -1
  41. package/cpp/llama-chat.cpp +8 -6
  42. package/cpp/llama-chat.h +1 -0
  43. package/cpp/llama-context.cpp +349 -135
  44. package/cpp/llama-context.h +30 -3
  45. package/cpp/llama-cparams.h +1 -0
  46. package/cpp/llama-graph.cpp +150 -234
  47. package/cpp/llama-graph.h +52 -7
  48. package/cpp/llama-hparams.cpp +17 -1
  49. package/cpp/llama-hparams.h +34 -5
  50. package/cpp/llama-kv-cache.cpp +662 -321
  51. package/cpp/llama-kv-cache.h +203 -93
  52. package/cpp/llama-memory.h +3 -2
  53. package/cpp/llama-model-loader.cpp +24 -15
  54. package/cpp/llama-model-saver.cpp +281 -0
  55. package/cpp/llama-model-saver.h +37 -0
  56. package/cpp/llama-model.cpp +536 -132
  57. package/cpp/llama-model.h +7 -1
  58. package/cpp/llama-sampling.cpp +18 -6
  59. package/cpp/llama-vocab.cpp +46 -8
  60. package/cpp/llama-vocab.h +6 -0
  61. package/cpp/llama.cpp +14 -0
  62. package/cpp/llama.h +72 -131
  63. package/cpp/minja/chat-template.hpp +9 -5
  64. package/cpp/minja/minja.hpp +69 -36
  65. package/cpp/rn-llama.cpp +611 -47
  66. package/cpp/rn-llama.h +33 -3
  67. package/cpp/sampling.cpp +57 -50
  68. package/cpp/tools/mtmd/clip-impl.h +462 -0
  69. package/cpp/tools/mtmd/clip.cpp +4024 -0
  70. package/cpp/tools/mtmd/clip.h +101 -0
  71. package/cpp/tools/mtmd/miniaudio.h +93468 -0
  72. package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
  73. package/cpp/tools/mtmd/mtmd-audio.h +62 -0
  74. package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
  75. package/cpp/tools/mtmd/mtmd.cpp +942 -0
  76. package/cpp/tools/mtmd/mtmd.h +362 -0
  77. package/cpp/tools/mtmd/stb_image.h +7988 -0
  78. package/ios/CMakeLists.txt +7 -0
  79. package/ios/RNLlama.mm +77 -3
  80. package/ios/RNLlamaContext.h +5 -1
  81. package/ios/RNLlamaContext.mm +105 -10
  82. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
  83. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +23 -19
  84. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  85. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  86. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  87. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +21 -7
  88. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  89. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +30 -3
  90. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  91. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
  92. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
  93. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  94. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
  95. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  96. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +7 -1
  97. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  98. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +72 -131
  99. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  100. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  101. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
  102. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  103. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  104. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  105. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  106. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
  107. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  108. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  109. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  110. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
  111. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  112. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
  113. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  114. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
  115. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
  116. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  117. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
  118. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  119. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
  120. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  121. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
  122. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  123. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  124. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
  125. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  126. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  127. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  128. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  129. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
  130. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +23 -19
  131. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  132. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  133. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  134. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +21 -7
  135. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  136. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +30 -3
  137. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  138. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
  139. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
  140. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  141. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
  142. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  143. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +7 -1
  144. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  145. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +72 -131
  146. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  147. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  148. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
  149. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  150. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  151. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  152. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  153. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
  154. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  155. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  156. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  157. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
  158. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  159. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
  160. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  161. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
  162. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
  163. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  164. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
  165. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  166. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
  167. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  168. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
  169. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  170. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  171. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
  172. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  173. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  174. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  175. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  176. package/jest/mock.js +33 -7
  177. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  178. package/lib/commonjs/index.js +153 -21
  179. package/lib/commonjs/index.js.map +1 -1
  180. package/lib/module/NativeRNLlama.js.map +1 -1
  181. package/lib/module/index.js +152 -20
  182. package/lib/module/index.js.map +1 -1
  183. package/lib/typescript/NativeRNLlama.d.ts +50 -4
  184. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  185. package/lib/typescript/index.d.ts +72 -6
  186. package/lib/typescript/index.d.ts.map +1 -1
  187. package/package.json +1 -1
  188. package/src/NativeRNLlama.ts +67 -4
  189. package/src/index.ts +212 -38
  190. package/lib/commonjs/chat.js +0 -37
  191. package/lib/commonjs/chat.js.map +0 -1
  192. package/lib/module/chat.js +0 -33
  193. package/lib/module/chat.js.map +0 -1
  194. package/lib/typescript/chat.d.ts +0 -10
  195. package/lib/typescript/chat.d.ts.map +0 -1
  196. package/src/chat.ts +0 -44
package/cpp/ggml-opt.cpp CHANGED
@@ -28,16 +28,19 @@ struct lm_ggml_opt_dataset {
28
28
  };
29
29
 
30
30
  struct lm_ggml_opt_context {
31
- lm_ggml_backend_sched_t backend_sched = nullptr;
32
- lm_ggml_cgraph * allocated_graph = nullptr;
33
- lm_ggml_cgraph * allocated_graph_copy = nullptr;
34
- struct lm_ggml_context * ctx_static = nullptr;
35
- struct lm_ggml_context * ctx_static_cpu = nullptr;
36
- struct lm_ggml_context * ctx_compute = nullptr;
37
- struct lm_ggml_context * ctx_copy = nullptr;
38
- lm_ggml_backend_buffer_t buf_static = nullptr;
39
- lm_ggml_backend_buffer_t buf_static_cpu = nullptr;
40
- std::mt19937 rng;
31
+ lm_ggml_backend_sched_t backend_sched = nullptr;
32
+ lm_ggml_cgraph * allocated_graph = nullptr;
33
+ lm_ggml_cgraph * allocated_graph_copy = nullptr;
34
+ struct lm_ggml_context * ctx_static = nullptr;
35
+ struct lm_ggml_context * ctx_cpu = nullptr;
36
+ struct lm_ggml_context * ctx_compute = nullptr;
37
+ struct lm_ggml_context * ctx_copy = nullptr;
38
+ lm_ggml_backend_buffer_t buf_static = nullptr;
39
+ lm_ggml_backend_buffer_t buf_cpu = nullptr;
40
+ std::mt19937 rng;
41
+ enum lm_ggml_opt_loss_type loss_type;
42
+ enum lm_ggml_opt_build_type build_type;
43
+ enum lm_ggml_opt_build_type build_type_alloc;
41
44
 
42
45
  struct lm_ggml_tensor * inputs = nullptr;
43
46
  struct lm_ggml_tensor * outputs = nullptr;
@@ -50,6 +53,11 @@ struct lm_ggml_opt_context {
50
53
  struct lm_ggml_cgraph * gf = nullptr;
51
54
  struct lm_ggml_cgraph * gb_grad = nullptr;
52
55
  struct lm_ggml_cgraph * gb_opt = nullptr;
56
+ bool static_graphs = false;
57
+ bool eval_ready = false;
58
+ std::vector<struct lm_ggml_tensor *> grad_accs;
59
+ std::vector<struct lm_ggml_tensor *> grad_m;
60
+ std::vector<struct lm_ggml_tensor *> grad_v;
53
61
 
54
62
  int64_t iter = 1;
55
63
  int32_t opt_period = 1;
@@ -73,7 +81,13 @@ struct lm_ggml_opt_result {
73
81
 
74
82
  // ====== Dataset ======
75
83
 
76
- lm_ggml_opt_dataset_t lm_ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) {
84
+ lm_ggml_opt_dataset_t lm_ggml_opt_dataset_init(
85
+ enum lm_ggml_type type_data,
86
+ enum lm_ggml_type type_label,
87
+ int64_t ne_datapoint,
88
+ int64_t ne_label,
89
+ int64_t ndata,
90
+ int64_t ndata_shard) {
77
91
  LM_GGML_ASSERT(ne_datapoint > 0);
78
92
  LM_GGML_ASSERT(ne_label >= 0);
79
93
  LM_GGML_ASSERT(ndata > 0);
@@ -92,11 +106,11 @@ lm_ggml_opt_dataset_t lm_ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_
92
106
  result->ctx = lm_ggml_init(params);
93
107
  }
94
108
 
95
- result->data = lm_ggml_new_tensor_2d(result->ctx, LM_GGML_TYPE_F32, ne_datapoint, ndata);
109
+ result->data = lm_ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata);
96
110
  result->nbs_data = lm_ggml_nbytes(result->data) * ndata_shard/ndata;
97
111
 
98
112
  if (ne_label > 0) {
99
- result->labels = lm_ggml_new_tensor_2d(result->ctx, LM_GGML_TYPE_F32, ne_label, ndata);
113
+ result->labels = lm_ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata);
100
114
  result->nbs_labels = lm_ggml_nbytes(result->labels) * ndata_shard/ndata;
101
115
  } else {
102
116
  result->labels = nullptr;
@@ -119,6 +133,10 @@ void lm_ggml_opt_dataset_free(lm_ggml_opt_dataset_t dataset) {
119
133
  delete dataset;
120
134
  }
121
135
 
136
+ int64_t lm_ggml_opt_dataset_ndata(lm_ggml_opt_dataset_t dataset) {
137
+ return dataset->ndata;
138
+ }
139
+
122
140
  struct lm_ggml_tensor * lm_ggml_opt_dataset_data(lm_ggml_opt_dataset_t dataset) {
123
141
  return dataset->data;
124
142
  }
@@ -144,6 +162,8 @@ void lm_ggml_opt_dataset_get_batch(lm_ggml_opt_dataset_t dataset, struct lm_ggml
144
162
  LM_GGML_ASSERT( data_batch && lm_ggml_is_contiguous(data_batch));
145
163
  LM_GGML_ASSERT(!labels_batch || lm_ggml_is_contiguous(labels_batch));
146
164
  LM_GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
165
+ LM_GGML_ASSERT( data_batch->type == dataset->data->type);
166
+ LM_GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type);
147
167
 
148
168
  const size_t nb_data_batch = lm_ggml_nbytes(data_batch);
149
169
  LM_GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
@@ -171,6 +191,31 @@ void lm_ggml_opt_dataset_get_batch(lm_ggml_opt_dataset_t dataset, struct lm_ggml
171
191
  }
172
192
  }
173
193
 
194
+ void lm_ggml_opt_dataset_get_batch_host(lm_ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) {
195
+ LM_GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
196
+ LM_GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
197
+
198
+ const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;
199
+
200
+ LM_GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));
201
+
202
+ for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
203
+ const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];
204
+
205
+ const char * ptr_data = (const char *) dataset->data->data + ishard *dataset->nbs_data;
206
+ char * ptr_data_batch = (char *) data_batch + ishard_batch*dataset->nbs_data;
207
+ memcpy(ptr_data_batch, ptr_data, dataset->nbs_data);
208
+
209
+ if (!labels_batch) {
210
+ continue;
211
+ }
212
+
213
+ const char * ptr_labels = (const char *) dataset->labels->data + ishard *dataset->nbs_labels;
214
+ char * ptr_labels_batch = (char *) labels_batch + ishard_batch*dataset->nbs_labels;
215
+ memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels);
216
+ }
217
+ }
218
+
174
219
  // ====== Model / Context ======
175
220
 
176
221
  struct lm_ggml_opt_optimizer_params lm_ggml_opt_get_default_optimizer_params(void * userdata) {
@@ -187,17 +232,18 @@ struct lm_ggml_opt_optimizer_params lm_ggml_opt_get_default_optimizer_params(voi
187
232
  return result;
188
233
  }
189
234
 
235
+ struct lm_ggml_opt_optimizer_params lm_ggml_opt_get_constant_optimizer_params(void * userdata) {
236
+ return *((struct lm_ggml_opt_optimizer_params *) userdata);
237
+ }
238
+
190
239
  struct lm_ggml_opt_params lm_ggml_opt_default_params(
191
240
  lm_ggml_backend_sched_t backend_sched,
192
- struct lm_ggml_context * ctx_compute,
193
- struct lm_ggml_tensor * inputs,
194
- struct lm_ggml_tensor * outputs,
195
241
  enum lm_ggml_opt_loss_type loss_type) {
196
242
  return {
197
243
  /*backend_sched =*/ backend_sched,
198
- /*ctx_compute =*/ ctx_compute,
199
- /*inputs =*/ inputs,
200
- /*logits =*/ outputs,
244
+ /*ctx_compute =*/ nullptr,
245
+ /*inputs =*/ nullptr,
246
+ /*logits =*/ nullptr,
201
247
  /*loss_type =*/ loss_type,
202
248
  /*build_type =*/ LM_GGML_OPT_BUILD_TYPE_OPT,
203
249
  /*opt_period =*/ 1,
@@ -266,195 +312,246 @@ static lm_ggml_cgraph * dup_graph(lm_ggml_context * ctx, lm_ggml_cgraph * src) {
266
312
  return dst;
267
313
  }
268
314
 
269
- static void lm_ggml_opt_alloc_graph(lm_ggml_opt_context_t opt_ctx, lm_ggml_cgraph * graph) {
270
- LM_GGML_ASSERT(graph);
271
- if (opt_ctx->allocated_graph == graph) {
272
- return;
273
- }
315
+ static void lm_ggml_opt_build(lm_ggml_opt_context_t opt_ctx) {
316
+ LM_GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with lm_ggml_opt_prepare_alloc");
317
+ LM_GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
274
318
 
275
- lm_ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
319
+ const bool accumulate = opt_ctx->build_type_alloc >= LM_GGML_OPT_BUILD_TYPE_GRAD &&
320
+ !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == LM_GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
276
321
 
277
- {
278
- lm_ggml_init_params params = {
279
- /*.mem_size =*/ lm_ggml_tensor_overhead() * LM_GGML_DEFAULT_GRAPH_SIZE,
280
- /*.mem_buffer =*/ nullptr,
281
- /*.no_alloc =*/ true,
282
- };
283
- lm_ggml_free(opt_ctx->ctx_copy);
284
- opt_ctx->ctx_copy = lm_ggml_init(params);
285
- }
286
-
287
- opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
288
-
289
- lm_ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
290
- opt_ctx->allocated_graph = graph;
291
- }
292
-
293
- lm_ggml_opt_context_t lm_ggml_opt_init(struct lm_ggml_opt_params params) {
294
- lm_ggml_opt_context_t result = new struct lm_ggml_opt_context;
295
- result->backend_sched = params.backend_sched;
296
- result->ctx_compute = params.ctx_compute;
297
- result->inputs = params.inputs;
298
- result->outputs = params.outputs;
299
- result->opt_period = params.opt_period;
300
- result->get_opt_pars = params.get_opt_pars;
301
- result->get_opt_pars_ud = params.get_opt_pars_ud;
302
-
303
- LM_GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically");
304
- LM_GGML_ASSERT(result->opt_period >= 1);
305
-
306
- const bool accumulate = params.build_type == LM_GGML_OPT_BUILD_TYPE_GRAD ||
307
- (params.build_type == LM_GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1);
308
-
309
- lm_ggml_set_input(result->inputs);
310
- lm_ggml_set_output(result->outputs);
311
-
312
- result->gf = lm_ggml_new_graph_custom(result->ctx_compute, LM_GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
313
- lm_ggml_build_forward_expand(result->gf, result->outputs);
322
+ lm_ggml_set_input(opt_ctx->inputs);
323
+ lm_ggml_set_output(opt_ctx->outputs);
314
324
 
315
325
  int n_param = 0;
316
- for (int i = 0; i < result->gf->n_nodes; ++i) {
317
- if (result->gf->nodes[i]->flags & LM_GGML_TENSOR_FLAG_PARAM) {
326
+ for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) {
327
+ const struct lm_ggml_tensor * node = opt_ctx->gf->nodes[i];
328
+ if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) {
318
329
  n_param++;
319
330
  }
331
+ LM_GGML_ASSERT(!(node->flags & LM_GGML_TENSOR_FLAG_LOSS) && "support for extra loss terms not implemented");
320
332
  }
321
333
 
322
- {
334
+ if (!opt_ctx->ctx_static) {
323
335
  // The static context is used for:
324
- // - gradients (1 tensor per param if using gradient accumulation)
336
+ // - gradients (1 per loss, 1 tensor per param if using gradient accumulation)
325
337
  // - optimizer momenta (2 tensors per param)
326
- // - labels
327
- // - loss + its gradient (up to 5 tensors)
328
- // - pred
329
- // - ncorrect (2 tensors).
330
- const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == LM_GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
331
- const size_t size_meta = (tensors_per_param*n_param + 9) * lm_ggml_tensor_overhead();
338
+ // - labels (if using static graphs)
339
+ // - loss (if using static graphs, up to 5 tensors)
340
+ // - pred (if using static graphs)
341
+ // - ncorrect (if using static graphs, 2 tensors).
342
+ constexpr size_t n_loss = 1;
343
+ const size_t tensors_per_param = (accumulate ? 1 : 0) +
344
+ (opt_ctx->build_type_alloc == LM_GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
345
+ const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
346
+ const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * lm_ggml_tensor_overhead();
332
347
  struct lm_ggml_init_params params = {
333
348
  /*.mem_size =*/ size_meta,
334
349
  /*.mem_buffer =*/ nullptr,
335
350
  /*.no_alloc =*/ true,
336
351
  };
337
- result->ctx_static = lm_ggml_init(params);
352
+ opt_ctx->ctx_static = lm_ggml_init(params);
338
353
  }
354
+ LM_GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc);
355
+
339
356
  {
340
- // The static cpu context is used for:
341
- // - optimizer parameters (1 for the entire context)
357
+ // The cpu context is allocated statically if using static graphs, dynamically otherwise.
358
+ // It is used for:
359
+ // - optimizer parameters (1 shared for all optimizer invocations)
342
360
  const size_t size_meta = 1 * lm_ggml_tensor_overhead();
343
361
  struct lm_ggml_init_params params = {
344
362
  /*.mem_size =*/ size_meta,
345
363
  /*.mem_buffer =*/ nullptr,
346
364
  /*.no_alloc =*/ true,
347
365
  };
348
- result->ctx_static_cpu = lm_ggml_init(params);
366
+ lm_ggml_free(opt_ctx->ctx_cpu);
367
+ opt_ctx->ctx_cpu = lm_ggml_init(params);
368
+
369
+ lm_ggml_backend_buffer_free(opt_ctx->buf_cpu);
370
+ opt_ctx->buf_cpu = nullptr;
349
371
  }
350
372
 
373
+ struct lm_ggml_context * ctx_results = opt_ctx->static_graphs ? opt_ctx->ctx_static : opt_ctx->ctx_compute;
351
374
 
352
- switch (params.loss_type) {
375
+ switch (opt_ctx->loss_type) {
353
376
  case LM_GGML_OPT_LOSS_TYPE_MEAN: {
354
- result->loss = lm_ggml_sum(result->ctx_static, result->outputs);
355
- lm_ggml_set_name(result->loss, "loss_sum");
356
- const float scale = 1.0f / (result->opt_period * lm_ggml_nelements(result->outputs));
357
- result->loss = lm_ggml_scale(result->ctx_static, result->loss, scale);
358
- lm_ggml_set_name(result->loss, "loss_mean");
359
- result->loss_per_datapoint = true;
377
+ opt_ctx->loss = lm_ggml_sum(ctx_results, opt_ctx->outputs);
378
+ lm_ggml_set_name(opt_ctx->loss, "loss_sum");
379
+ const float scale = 1.0f / (opt_ctx->opt_period * lm_ggml_nelements(opt_ctx->outputs));
380
+ opt_ctx->loss = lm_ggml_scale(ctx_results, opt_ctx->loss, scale);
381
+ lm_ggml_set_name(opt_ctx->loss, "loss_mean");
382
+ opt_ctx->loss_per_datapoint = true;
360
383
  break;
361
384
  }
362
385
  case LM_GGML_OPT_LOSS_TYPE_SUM: {
363
- result->loss = lm_ggml_sum(result->ctx_static, result->outputs);
364
- lm_ggml_set_name(result->loss, "loss_sum");
365
- result->loss_per_datapoint = false;
386
+ opt_ctx->loss = lm_ggml_sum(ctx_results, opt_ctx->outputs);
387
+ lm_ggml_set_name(opt_ctx->loss, "loss_sum");
388
+ opt_ctx->loss_per_datapoint = false;
366
389
  break;
367
390
  }
368
391
  case LM_GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {
369
- result->labels = lm_ggml_dup_tensor(result->ctx_static, result->outputs);
370
- lm_ggml_set_input(result->labels);
371
- lm_ggml_set_name(result->labels, "labels");
372
- result->loss = lm_ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels);
373
- lm_ggml_set_name(result->loss, "loss_cross_entropy");
374
- if (result->opt_period > 1) {
375
- result->loss = lm_ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period);
376
- lm_ggml_set_name(result->loss, "loss_cross_entropy_scaled");
392
+ opt_ctx->labels = lm_ggml_dup_tensor(ctx_results, opt_ctx->outputs);
393
+ lm_ggml_set_input(opt_ctx->labels);
394
+ lm_ggml_set_name(opt_ctx->labels, "labels");
395
+ opt_ctx->loss = lm_ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels);
396
+ lm_ggml_set_name(opt_ctx->loss, "loss_cross_entropy");
397
+ if (opt_ctx->opt_period > 1) {
398
+ opt_ctx->loss = lm_ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period);
399
+ lm_ggml_set_name(opt_ctx->loss, "loss_cross_entropy_scaled");
377
400
  }
378
- result->loss_per_datapoint = true;
401
+ opt_ctx->loss_per_datapoint = true;
379
402
  break;
380
403
  }
381
404
  case LM_GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {
382
- result->labels = lm_ggml_dup_tensor(result->ctx_static, result->outputs);
383
- lm_ggml_set_input(result->labels);
384
- lm_ggml_set_name(result->labels, "labels");
385
- result->loss = lm_ggml_sub(result->ctx_static, result->outputs, result->labels);
386
- lm_ggml_set_name(result->loss, "loss_error");
387
- result->loss = lm_ggml_sqr(result->ctx_static, result->loss);
388
- lm_ggml_set_name(result->loss, "loss_squared_error");
389
- result->loss = lm_ggml_sum(result->ctx_static, result->loss);
390
- lm_ggml_set_name(result->loss, "loss_sum_squared_error");
391
- const float scale = 1.0f / (result->opt_period * lm_ggml_nelements(result->outputs));
392
- result->loss = lm_ggml_scale(result->ctx_static, result->loss, scale);
393
- lm_ggml_set_name(result->loss, "loss_mean_squared_error");
394
- result->loss_per_datapoint = true;
405
+ opt_ctx->labels = lm_ggml_dup_tensor(ctx_results, opt_ctx->outputs);
406
+ lm_ggml_set_input(opt_ctx->labels);
407
+ lm_ggml_set_name(opt_ctx->labels, "labels");
408
+ opt_ctx->loss = lm_ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels);
409
+ lm_ggml_set_name(opt_ctx->loss, "loss_error");
410
+ opt_ctx->loss = lm_ggml_sqr(ctx_results, opt_ctx->loss);
411
+ lm_ggml_set_name(opt_ctx->loss, "loss_squared_error");
412
+ opt_ctx->loss = lm_ggml_sum(ctx_results, opt_ctx->loss);
413
+ lm_ggml_set_name(opt_ctx->loss, "loss_sum_squared_error");
414
+ const float scale = 1.0f / (opt_ctx->opt_period * lm_ggml_nelements(opt_ctx->outputs));
415
+ opt_ctx->loss = lm_ggml_scale(ctx_results, opt_ctx->loss, scale);
416
+ lm_ggml_set_name(opt_ctx->loss, "loss_mean_squared_error");
417
+ opt_ctx->loss_per_datapoint = true;
395
418
  break;
396
419
  }
397
420
  }
398
- lm_ggml_set_output(result->loss);
399
- lm_ggml_set_loss(result->loss);
400
- lm_ggml_build_forward_expand(result->gf, result->loss);
401
-
402
- result->pred = lm_ggml_argmax(result->ctx_static, result->outputs);
403
- lm_ggml_set_name(result->pred, "pred");
404
- lm_ggml_set_output(result->pred);
405
- lm_ggml_build_forward_expand(result->gf, result->pred);
421
+ lm_ggml_set_output(opt_ctx->loss);
422
+ lm_ggml_set_loss(opt_ctx->loss);
423
+ lm_ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss);
424
+
425
+ if (opt_ctx->loss_type == LM_GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) {
426
+ opt_ctx->pred = lm_ggml_argmax(ctx_results, opt_ctx->outputs);
427
+ lm_ggml_set_name(opt_ctx->pred, "pred");
428
+ lm_ggml_set_output(opt_ctx->pred);
429
+ lm_ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred);
430
+
431
+ opt_ctx->ncorrect = lm_ggml_count_equal(ctx_results, opt_ctx->pred, lm_ggml_argmax(ctx_results, opt_ctx->labels));
432
+ lm_ggml_set_name(opt_ctx->ncorrect, "ncorrect");
433
+ lm_ggml_set_output(opt_ctx->ncorrect);
434
+ lm_ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect);
435
+ }
406
436
 
407
- if (result->labels) {
408
- result->ncorrect = lm_ggml_count_equal(result->ctx_static, result->pred, lm_ggml_argmax(result->ctx_static, result->labels));
409
- lm_ggml_set_name(result->ncorrect, "ncorrect");
410
- lm_ggml_set_output(result->ncorrect);
411
- lm_ggml_build_forward_expand(result->gf, result->ncorrect);
412
- } else {
413
- result->ncorrect = nullptr;
437
+ if (opt_ctx->buf_static) {
438
+ if (opt_ctx->build_type == LM_GGML_OPT_BUILD_TYPE_FORWARD) {
439
+ return;
440
+ }
441
+ } else if (opt_ctx->build_type_alloc == LM_GGML_OPT_BUILD_TYPE_FORWARD) {
442
+ opt_ctx->buf_static = lm_ggml_backend_alloc_ctx_tensors(
443
+ opt_ctx->ctx_static, lm_ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
444
+ return;
414
445
  }
415
446
 
416
- if (params.build_type == LM_GGML_OPT_BUILD_TYPE_FORWARD) {
417
- result->buf_static = lm_ggml_backend_alloc_ctx_tensors(result->ctx_static, lm_ggml_backend_sched_get_backend(result->backend_sched, 0));
418
- return result;
447
+ if (opt_ctx->grad_accs.empty()) {
448
+ LM_GGML_ASSERT(opt_ctx->build_type_alloc >= LM_GGML_OPT_BUILD_TYPE_GRAD);
449
+
450
+ const int n_nodes = opt_ctx->gf->n_nodes;
451
+ opt_ctx->grad_accs.resize(n_nodes);
452
+ for (int i = 0; i < n_nodes; ++i) {
453
+ lm_ggml_tensor * node = opt_ctx->gf->nodes[i];
454
+ if ((accumulate && (node->flags & LM_GGML_TENSOR_FLAG_PARAM)) || (node->flags & LM_GGML_TENSOR_FLAG_LOSS)) {
455
+ opt_ctx->grad_accs[i] = lm_ggml_new_tensor(opt_ctx->ctx_static, LM_GGML_TYPE_F32, LM_GGML_MAX_DIMS, node->ne);
456
+ } else {
457
+ opt_ctx->grad_accs[i] = nullptr;
458
+ }
459
+ }
460
+
461
+ if (opt_ctx->build_type_alloc >= LM_GGML_OPT_BUILD_TYPE_OPT) {
462
+ opt_ctx->grad_m.resize(n_nodes);
463
+ opt_ctx->grad_v.resize(n_nodes);
464
+ for (int i = 0; i < n_nodes; ++i) {
465
+ lm_ggml_tensor * node = opt_ctx->gf->nodes[i];
466
+ if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) {
467
+ opt_ctx->grad_m[i] = lm_ggml_new_tensor(opt_ctx->ctx_static, LM_GGML_TYPE_F32, LM_GGML_MAX_DIMS, node->ne);
468
+ opt_ctx->grad_v[i] = lm_ggml_new_tensor(opt_ctx->ctx_static, LM_GGML_TYPE_F32, LM_GGML_MAX_DIMS, node->ne);
469
+ } else {
470
+ opt_ctx->grad_m[i] = nullptr;
471
+ opt_ctx->grad_v[i] = nullptr;
472
+ }
473
+ }
474
+ }
419
475
  }
420
476
 
421
477
  // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
422
- result->gb_grad = lm_ggml_graph_dup(result->ctx_compute, result->gf);
423
- lm_ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate);
478
+ opt_ctx->gb_grad = lm_ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true);
479
+ lm_ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data());
424
480
 
425
- if (params.build_type == LM_GGML_OPT_BUILD_TYPE_GRAD) {
426
- result->buf_static = lm_ggml_backend_alloc_ctx_tensors(result->ctx_static, lm_ggml_backend_sched_get_backend(result->backend_sched, 0));
427
- lm_ggml_graph_reset(result->gb_grad);
428
- return result;
481
+ if (opt_ctx->buf_static) {
482
+ if (opt_ctx->build_type == LM_GGML_OPT_BUILD_TYPE_GRAD) {
483
+ return;
484
+ }
485
+ } else if (opt_ctx->build_type_alloc == LM_GGML_OPT_BUILD_TYPE_GRAD) {
486
+ opt_ctx->buf_static = lm_ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, lm_ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
487
+ lm_ggml_graph_reset(opt_ctx->gb_grad);
429
488
  }
430
489
 
431
- LM_GGML_ASSERT(params.build_type == LM_GGML_OPT_BUILD_TYPE_OPT);
490
+ LM_GGML_ASSERT(opt_ctx->build_type_alloc == LM_GGML_OPT_BUILD_TYPE_OPT);
432
491
 
433
492
  // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
434
- result->gb_opt = lm_ggml_graph_dup(result->ctx_compute, result->gb_grad);
493
+ opt_ctx->gb_opt = lm_ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
435
494
 
436
- result->adamw_params = lm_ggml_new_tensor_1d(result->ctx_static_cpu, LM_GGML_TYPE_F32, 7);
437
- lm_ggml_set_input(result->adamw_params);
438
- lm_ggml_set_name(result->adamw_params, "adamw_params");
495
+ opt_ctx->adamw_params = lm_ggml_new_tensor_1d(opt_ctx->ctx_cpu, LM_GGML_TYPE_F32, 7);
496
+ lm_ggml_set_input(opt_ctx->adamw_params);
497
+ lm_ggml_set_name(opt_ctx->adamw_params, "adamw_params");
439
498
 
440
- for (int i = result->gf->n_nodes-1; i >= 0; --i) {
441
- struct lm_ggml_tensor * node = result->gb_opt->nodes[i];
442
- struct lm_ggml_tensor * grad = lm_ggml_graph_get_grad(result->gb_opt, node);
499
+ for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
500
+ struct lm_ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
501
+ struct lm_ggml_tensor * grad = lm_ggml_graph_get_grad(opt_ctx->gb_opt, node);
443
502
 
444
- if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) {
445
- struct lm_ggml_tensor * m = lm_ggml_dup_tensor(result->ctx_static, node);
446
- struct lm_ggml_tensor * v = lm_ggml_dup_tensor(result->ctx_static, node);
447
- struct lm_ggml_tensor * opt_step = lm_ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params);
448
- lm_ggml_build_forward_expand(result->gb_opt, opt_step);
503
+ if (grad && (node->flags & LM_GGML_TENSOR_FLAG_PARAM)) {
504
+ struct lm_ggml_tensor * m = opt_ctx->grad_m[i];
505
+ struct lm_ggml_tensor * v = opt_ctx->grad_v[i];
506
+ struct lm_ggml_tensor * opt_step = lm_ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
507
+
508
+ lm_ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str());
509
+ lm_ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str());
510
+ lm_ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
511
+
512
+ lm_ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
449
513
  }
450
514
  }
451
515
 
452
- result->buf_static = lm_ggml_backend_alloc_ctx_tensors(
453
- result->ctx_static, lm_ggml_backend_sched_get_backend(result->backend_sched, 0));
516
+ if (!opt_ctx->buf_static) {
517
+ opt_ctx->buf_static = lm_ggml_backend_alloc_ctx_tensors(
518
+ opt_ctx->ctx_static, lm_ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
519
+ lm_ggml_graph_reset(opt_ctx->gb_opt);
520
+ }
454
521
 
455
- result->buf_static_cpu = lm_ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, lm_ggml_backend_cpu_buffer_type());
522
+ opt_ctx->buf_cpu = lm_ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, lm_ggml_backend_cpu_buffer_type());
523
+ }
456
524
 
457
- lm_ggml_graph_reset(result->gb_opt);
525
+ lm_ggml_opt_context_t lm_ggml_opt_init(struct lm_ggml_opt_params params) {
526
+ lm_ggml_opt_context_t result = new struct lm_ggml_opt_context;
527
+ result->backend_sched = params.backend_sched;
528
+ result->ctx_compute = params.ctx_compute;
529
+ result->loss_type = params.loss_type;
530
+ result->build_type = params.build_type;
531
+ result->build_type_alloc = params.build_type;
532
+ result->inputs = params.inputs;
533
+ result->outputs = params.outputs;
534
+ result->opt_period = params.opt_period;
535
+ result->get_opt_pars = params.get_opt_pars;
536
+ result->get_opt_pars_ud = params.get_opt_pars_ud;
537
+
538
+ LM_GGML_ASSERT(result->opt_period >= 1);
539
+
540
+ result->static_graphs = result->ctx_compute;
541
+
542
+ if (!result->static_graphs) {
543
+ LM_GGML_ASSERT(!result->inputs);
544
+ LM_GGML_ASSERT(!result->outputs);
545
+ return result;
546
+ }
547
+
548
+ LM_GGML_ASSERT(result->inputs);
549
+ LM_GGML_ASSERT(result->outputs);
550
+
551
+ result->gf = lm_ggml_new_graph_custom(result->ctx_compute, LM_GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
552
+ lm_ggml_build_forward_expand(result->gf, result->outputs);
553
+
554
+ lm_ggml_opt_build(result);
458
555
 
459
556
  return result;
460
557
  }
@@ -464,9 +561,9 @@ void lm_ggml_opt_free(lm_ggml_opt_context_t opt_ctx) {
464
561
  return;
465
562
  }
466
563
  lm_ggml_backend_buffer_free(opt_ctx->buf_static);
467
- lm_ggml_backend_buffer_free(opt_ctx->buf_static_cpu);
564
+ lm_ggml_backend_buffer_free(opt_ctx->buf_cpu);
468
565
  lm_ggml_free(opt_ctx->ctx_static);
469
- lm_ggml_free(opt_ctx->ctx_static_cpu);
566
+ lm_ggml_free(opt_ctx->ctx_cpu);
470
567
  delete opt_ctx;
471
568
  }
472
569
 
@@ -479,6 +576,10 @@ void lm_ggml_opt_reset(lm_ggml_opt_context_t opt_ctx, bool optimizer) {
479
576
  }
480
577
  }
481
578
 
579
+ bool lm_ggml_opt_static_graphs(lm_ggml_opt_context_t opt_ctx) {
580
+ return opt_ctx->static_graphs;
581
+ }
582
+
482
583
  struct lm_ggml_tensor * lm_ggml_opt_inputs(lm_ggml_opt_context_t opt_ctx) {
483
584
  return opt_ctx->inputs;
484
585
  }
@@ -582,8 +683,79 @@ void lm_ggml_opt_result_accuracy(lm_ggml_opt_result_t result, double * accuracy,
582
683
 
583
684
  // ====== Computation ======
584
685
 
585
- static void lm_ggml_opt_eval_graph(lm_ggml_opt_context_t opt_ctx, lm_ggml_cgraph * graph, lm_ggml_opt_result * result) {
586
- if (graph != opt_ctx->gf) {
686
+ void lm_ggml_opt_prepare_alloc(
687
+ lm_ggml_opt_context_t opt_ctx,
688
+ struct lm_ggml_context * ctx_compute,
689
+ struct lm_ggml_cgraph * gf,
690
+ struct lm_ggml_tensor * inputs,
691
+ struct lm_ggml_tensor * outputs) {
692
+ LM_GGML_ASSERT(!opt_ctx->static_graphs);
693
+ opt_ctx->ctx_compute = ctx_compute;
694
+ opt_ctx->gf = gf;
695
+ opt_ctx->inputs = inputs;
696
+ opt_ctx->outputs = outputs;
697
+ }
698
+
699
+ void lm_ggml_opt_alloc(lm_ggml_opt_context_t opt_ctx, bool backward) {
700
+ LM_GGML_ASSERT(!opt_ctx->eval_ready);
701
+ if (opt_ctx->build_type == LM_GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) {
702
+ lm_ggml_graph_reset(opt_ctx->gb_grad);
703
+ }
704
+ if (backward) {
705
+ const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
706
+ opt_ctx->build_type = opt_i_next == 0 ? LM_GGML_OPT_BUILD_TYPE_OPT : LM_GGML_OPT_BUILD_TYPE_GRAD;
707
+ } else {
708
+ opt_ctx->build_type = LM_GGML_OPT_BUILD_TYPE_FORWARD;
709
+ }
710
+
711
+ if (!opt_ctx->static_graphs) {
712
+ lm_ggml_opt_build(opt_ctx);
713
+ }
714
+
715
+ struct lm_ggml_cgraph * graph = nullptr;
716
+ switch (opt_ctx->build_type) {
717
+ case LM_GGML_OPT_BUILD_TYPE_FORWARD: {
718
+ graph = opt_ctx->gf;
719
+ } break;
720
+ case LM_GGML_OPT_BUILD_TYPE_GRAD: {
721
+ graph = opt_ctx->gb_grad;
722
+ } break;
723
+ case LM_GGML_OPT_BUILD_TYPE_OPT: {
724
+ graph = opt_ctx->gb_opt;
725
+ } break;
726
+ }
727
+ LM_GGML_ASSERT(graph);
728
+
729
+ if (opt_ctx->allocated_graph == graph) {
730
+ opt_ctx->eval_ready = true;
731
+ return;
732
+ }
733
+
734
+ lm_ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
735
+
736
+ if (opt_ctx->static_graphs) {
737
+ lm_ggml_init_params params = {
738
+ /*.mem_size =*/ graph->size*lm_ggml_tensor_overhead() + lm_ggml_graph_overhead_custom(graph->size, graph->grads),
739
+ /*.mem_buffer =*/ nullptr,
740
+ /*.no_alloc =*/ true,
741
+ };
742
+ lm_ggml_free(opt_ctx->ctx_copy);
743
+ opt_ctx->ctx_copy = lm_ggml_init(params);
744
+
745
+ opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
746
+ } else {
747
+ opt_ctx->allocated_graph_copy = graph;
748
+ }
749
+
750
+ lm_ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
751
+ opt_ctx->allocated_graph = graph;
752
+
753
+ opt_ctx->eval_ready = true;
754
+ }
755
+
756
+ void lm_ggml_opt_eval(lm_ggml_opt_context_t opt_ctx, lm_ggml_opt_result_t result) {
757
+ LM_GGML_ASSERT(opt_ctx->eval_ready);
758
+ if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
587
759
  struct lm_ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
588
760
 
589
761
  LM_GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
@@ -609,9 +781,19 @@ static void lm_ggml_opt_eval_graph(lm_ggml_opt_context_t opt_ctx, lm_ggml_cgraph
609
781
  adamw_par_data[6] = beta2h;
610
782
  }
611
783
 
612
- lm_ggml_opt_alloc_graph(opt_ctx, graph);
613
784
  lm_ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
614
785
  opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt;
786
+ opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
787
+
788
+ if (!opt_ctx->static_graphs) {
789
+ opt_ctx->gf = nullptr;
790
+ opt_ctx->gb_grad = nullptr;
791
+ opt_ctx->gb_opt = nullptr;
792
+ opt_ctx->allocated_graph = nullptr;
793
+ opt_ctx->allocated_graph_copy = nullptr;
794
+ }
795
+
796
+ opt_ctx->eval_ready = false;
615
797
 
616
798
  if (!result) {
617
799
  return;
@@ -635,12 +817,14 @@ static void lm_ggml_opt_eval_graph(lm_ggml_opt_context_t opt_ctx, lm_ggml_cgraph
635
817
  lm_ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, lm_ggml_nbytes(opt_ctx->loss));
636
818
  result->loss.push_back(loss);
637
819
 
638
- LM_GGML_ASSERT(opt_ctx->pred->type == LM_GGML_TYPE_I32);
639
- std::vector<int32_t> pred(ndata);
640
- lm_ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, lm_ggml_nbytes(opt_ctx->pred));
641
- result->pred.insert(result->pred.end(), pred.begin(), pred.end());
820
+ if (opt_ctx->pred) {
821
+ LM_GGML_ASSERT(opt_ctx->pred->type == LM_GGML_TYPE_I32);
822
+ std::vector<int32_t> pred(ndata);
823
+ lm_ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, lm_ggml_nbytes(opt_ctx->pred));
824
+ result->pred.insert(result->pred.end(), pred.begin(), pred.end());
825
+ }
642
826
 
643
- if (!opt_ctx->labels || result->ncorrect < 0) {
827
+ if (!opt_ctx->ncorrect || result->ncorrect < 0) {
644
828
  result->ncorrect = -1;
645
829
  return;
646
830
  }
@@ -652,26 +836,6 @@ static void lm_ggml_opt_eval_graph(lm_ggml_opt_context_t opt_ctx, lm_ggml_cgraph
652
836
  result->ncorrect += ncorrect;
653
837
  }
654
838
 
655
- void lm_ggml_opt_forward(lm_ggml_opt_context_t opt_ctx, lm_ggml_opt_result * result) {
656
- lm_ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result);
657
- }
658
-
659
- void lm_ggml_opt_forward_backward(lm_ggml_opt_context_t opt_ctx, lm_ggml_opt_result * result) {
660
- if (opt_ctx->opt_period == 1) {
661
- lm_ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
662
- return;
663
- }
664
-
665
- const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
666
- if (opt_i_next == 0) {
667
- lm_ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
668
- lm_ggml_opt_reset(opt_ctx, /*optimizer =*/ false);
669
- } else {
670
- lm_ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result);
671
- }
672
- opt_ctx->opt_i = opt_i_next;
673
- }
674
-
675
839
  // ====== High-Level Functions ======
676
840
 
677
841
  void lm_ggml_opt_epoch(
@@ -682,6 +846,7 @@ void lm_ggml_opt_epoch(
682
846
  int64_t idata_split,
683
847
  lm_ggml_opt_epoch_callback callback_train,
684
848
  lm_ggml_opt_epoch_callback callback_eval) {
849
+ LM_GGML_ASSERT(lm_ggml_opt_static_graphs(opt_ctx) && "lm_ggml_opt_epoch requires static graphs");
685
850
  struct lm_ggml_tensor * inputs = lm_ggml_opt_inputs(opt_ctx);
686
851
  struct lm_ggml_tensor * labels = lm_ggml_opt_labels(opt_ctx);
687
852
  struct lm_ggml_tensor * data = lm_ggml_opt_dataset_data(dataset);
@@ -700,16 +865,18 @@ void lm_ggml_opt_epoch(
700
865
  int64_t ibatch = 0;
701
866
  int64_t t_loop_start = lm_ggml_time_us();
702
867
  for (; ibatch < ibatch_split; ++ibatch) {
868
+ lm_ggml_opt_alloc(opt_ctx, /*backward =*/ true);
703
869
  lm_ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
704
- lm_ggml_opt_forward_backward(opt_ctx, result_train);
870
+ lm_ggml_opt_eval(opt_ctx, result_train);
705
871
  if (callback_train) {
706
872
  callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);
707
873
  }
708
874
  }
709
875
  t_loop_start = lm_ggml_time_us();
710
876
  for (; ibatch < nbatches; ++ibatch) {
877
+ lm_ggml_opt_alloc(opt_ctx, /*backward =*/ false);
711
878
  lm_ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
712
- lm_ggml_opt_forward(opt_ctx, result_eval);
879
+ lm_ggml_opt_eval(opt_ctx, result_eval);
713
880
  if (callback_eval) {
714
881
  callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);
715
882
  }
@@ -726,13 +893,26 @@ void lm_ggml_opt_epoch_callback_progress_bar(
726
893
  int64_t t_start_us) {
727
894
  fprintf(stderr, "%s[", train ? "train: " : "val: ");
728
895
 
729
- constexpr int64_t bar_length = 25;
896
+ // The progress bar consists of partially filled blocks, unicode has 8 separate fill levels.
897
+ constexpr int64_t bar_length = 8;
898
+ const int64_t ibatch8 = 8 * ibatch;
730
899
  for (int64_t j = 0; j < bar_length; ++j) {
731
- const int64_t ibatch_j = ibatch_max * j/bar_length;
732
- if (ibatch_j < ibatch) {
733
- fprintf(stderr, "=");
734
- } else if (ibatch_max * (j - 1)/bar_length < ibatch) {
735
- fprintf(stderr, ">");
900
+ if (ibatch_max * (8*j + 8) / bar_length < ibatch8) {
901
+ fprintf(stderr, "\u2588"); // full block
902
+ } else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) {
903
+ fprintf(stderr, "\u2589"); // 7/8 filled
904
+ } else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) {
905
+ fprintf(stderr, "\u258A"); // 6/8 filled
906
+ } else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) {
907
+ fprintf(stderr, "\u258B"); // 5/8 filled
908
+ } else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) {
909
+ fprintf(stderr, "\u258C"); // 4/8 filled
910
+ } else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) {
911
+ fprintf(stderr, "\u258D"); // 3/8 filled
912
+ } else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) {
913
+ fprintf(stderr, "\u258E"); // 2/8 filled
914
+ } else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) {
915
+ fprintf(stderr, "\u258F"); // 1/8 filled
736
916
  } else {
737
917
  fprintf(stderr, " ");
738
918
  }
@@ -764,8 +944,8 @@ void lm_ggml_opt_epoch_callback_progress_bar(
764
944
  const int64_t t_eta_m = t_eta_s / 60;
765
945
  t_eta_s -= t_eta_m * 60;
766
946
 
767
- fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, "
768
- "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r",
947
+ fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% "
948
+ "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r",
769
949
  idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
770
950
  t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
771
951
  if (ibatch == ibatch_max) {
@@ -806,7 +986,10 @@ void lm_ggml_opt_fit(
806
986
 
807
987
  int64_t epoch = 1;
808
988
 
809
- lm_ggml_opt_params params = lm_ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
989
+ lm_ggml_opt_params params = lm_ggml_opt_default_params(backend_sched, loss_type);
990
+ params.ctx_compute = ctx_compute;
991
+ params.inputs = inputs;
992
+ params.outputs = outputs;
810
993
  params.opt_period = opt_period;
811
994
  params.get_opt_pars = get_opt_pars;
812
995
  params.get_opt_pars_ud = &epoch;