@fugood/llama.node 0.3.17 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193) hide show
  1. package/CMakeLists.txt +3 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +39 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +366 -19
  24. package/src/LlamaCompletionWorker.h +30 -10
  25. package/src/LlamaContext.cpp +213 -5
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
  29. package/src/llama.cpp/.github/workflows/build.yml +41 -762
  30. package/src/llama.cpp/.github/workflows/docker.yml +5 -2
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +12 -12
  33. package/src/llama.cpp/CMakeLists.txt +5 -17
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +31 -3
  37. package/src/llama.cpp/common/arg.cpp +48 -29
  38. package/src/llama.cpp/common/chat.cpp +128 -106
  39. package/src/llama.cpp/common/chat.h +2 -0
  40. package/src/llama.cpp/common/common.cpp +37 -1
  41. package/src/llama.cpp/common/common.h +18 -9
  42. package/src/llama.cpp/common/llguidance.cpp +1 -0
  43. package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
  44. package/src/llama.cpp/common/minja/minja.hpp +69 -36
  45. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  46. package/src/llama.cpp/common/regex-partial.h +56 -0
  47. package/src/llama.cpp/common/sampling.cpp +57 -50
  48. package/src/llama.cpp/examples/CMakeLists.txt +2 -23
  49. package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
  50. package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
  51. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  52. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  53. package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
  54. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  55. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  56. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  57. package/src/llama.cpp/ggml/include/ggml.h +10 -7
  58. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  60. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  61. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
  62. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  63. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
  64. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
  65. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
  66. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  67. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  68. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  69. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
  70. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
  71. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
  72. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  73. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
  74. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
  75. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  76. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  77. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
  78. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  79. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
  80. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
  81. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
  82. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  83. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  84. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  85. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  86. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
  87. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  88. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
  89. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  90. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  91. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  92. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
  93. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
  94. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
  95. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
  96. package/src/llama.cpp/ggml/src/ggml.c +29 -20
  97. package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
  98. package/src/llama.cpp/include/llama.h +52 -11
  99. package/src/llama.cpp/requirements/requirements-all.txt +3 -3
  100. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  101. package/src/llama.cpp/src/CMakeLists.txt +1 -0
  102. package/src/llama.cpp/src/llama-adapter.cpp +6 -0
  103. package/src/llama.cpp/src/llama-arch.cpp +3 -0
  104. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  105. package/src/llama.cpp/src/llama-batch.h +2 -1
  106. package/src/llama.cpp/src/llama-chat.cpp +17 -7
  107. package/src/llama.cpp/src/llama-chat.h +1 -0
  108. package/src/llama.cpp/src/llama-context.cpp +389 -501
  109. package/src/llama.cpp/src/llama-context.h +44 -32
  110. package/src/llama.cpp/src/llama-cparams.h +1 -0
  111. package/src/llama.cpp/src/llama-graph.cpp +20 -38
  112. package/src/llama.cpp/src/llama-graph.h +12 -8
  113. package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
  114. package/src/llama.cpp/src/llama-kv-cache.h +271 -85
  115. package/src/llama.cpp/src/llama-memory.h +11 -1
  116. package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
  117. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  118. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  119. package/src/llama.cpp/src/llama-model.cpp +316 -69
  120. package/src/llama.cpp/src/llama-model.h +8 -1
  121. package/src/llama.cpp/src/llama-quant.cpp +15 -13
  122. package/src/llama.cpp/src/llama-sampling.cpp +18 -6
  123. package/src/llama.cpp/src/llama-vocab.cpp +42 -4
  124. package/src/llama.cpp/src/llama-vocab.h +6 -0
  125. package/src/llama.cpp/src/llama.cpp +14 -0
  126. package/src/llama.cpp/tests/CMakeLists.txt +10 -2
  127. package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
  128. package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
  129. package/src/llama.cpp/tests/test-chat.cpp +3 -1
  130. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  131. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  132. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  133. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  134. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  135. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
  136. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  137. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
  138. package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
  139. package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
  140. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
  141. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
  142. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  143. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
  144. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  145. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
  146. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  147. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
  148. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
  149. package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
  150. package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
  151. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  152. package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
  153. package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
  154. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  155. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  156. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  157. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  158. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  159. package/src/llama.cpp/examples/llava/clip.h +0 -135
  160. package/src/llama.cpp/examples/llava/llava.cpp +0 -586
  161. package/src/llama.cpp/examples/llava/llava.h +0 -49
  162. package/src/llama.cpp/examples/llava/mtmd.h +0 -168
  163. package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
  164. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  165. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  166. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  167. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  168. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  169. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  170. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  171. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  172. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  173. /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
  174. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  175. /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
  176. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  177. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  178. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  179. /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
  180. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  181. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  182. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  183. /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
  184. /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
  185. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  186. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  187. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  188. /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
  189. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  190. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  191. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  192. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
  193. /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
@@ -28,16 +28,19 @@ struct ggml_opt_dataset {
28
28
  };
29
29
 
30
30
  struct ggml_opt_context {
31
- ggml_backend_sched_t backend_sched = nullptr;
32
- ggml_cgraph * allocated_graph = nullptr;
33
- ggml_cgraph * allocated_graph_copy = nullptr;
34
- struct ggml_context * ctx_static = nullptr;
35
- struct ggml_context * ctx_static_cpu = nullptr;
36
- struct ggml_context * ctx_compute = nullptr;
37
- struct ggml_context * ctx_copy = nullptr;
38
- ggml_backend_buffer_t buf_static = nullptr;
39
- ggml_backend_buffer_t buf_static_cpu = nullptr;
40
- std::mt19937 rng;
31
+ ggml_backend_sched_t backend_sched = nullptr;
32
+ ggml_cgraph * allocated_graph = nullptr;
33
+ ggml_cgraph * allocated_graph_copy = nullptr;
34
+ struct ggml_context * ctx_static = nullptr;
35
+ struct ggml_context * ctx_cpu = nullptr;
36
+ struct ggml_context * ctx_compute = nullptr;
37
+ struct ggml_context * ctx_copy = nullptr;
38
+ ggml_backend_buffer_t buf_static = nullptr;
39
+ ggml_backend_buffer_t buf_cpu = nullptr;
40
+ std::mt19937 rng;
41
+ enum ggml_opt_loss_type loss_type;
42
+ enum ggml_opt_build_type build_type;
43
+ enum ggml_opt_build_type build_type_alloc;
41
44
 
42
45
  struct ggml_tensor * inputs = nullptr;
43
46
  struct ggml_tensor * outputs = nullptr;
@@ -50,6 +53,11 @@ struct ggml_opt_context {
50
53
  struct ggml_cgraph * gf = nullptr;
51
54
  struct ggml_cgraph * gb_grad = nullptr;
52
55
  struct ggml_cgraph * gb_opt = nullptr;
56
+ bool static_graphs = false;
57
+ bool eval_ready = false;
58
+ std::vector<struct ggml_tensor *> grad_accs;
59
+ std::vector<struct ggml_tensor *> grad_m;
60
+ std::vector<struct ggml_tensor *> grad_v;
53
61
 
54
62
  int64_t iter = 1;
55
63
  int32_t opt_period = 1;
@@ -73,7 +81,13 @@ struct ggml_opt_result {
73
81
 
74
82
  // ====== Dataset ======
75
83
 
76
- ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) {
84
+ ggml_opt_dataset_t ggml_opt_dataset_init(
85
+ enum ggml_type type_data,
86
+ enum ggml_type type_label,
87
+ int64_t ne_datapoint,
88
+ int64_t ne_label,
89
+ int64_t ndata,
90
+ int64_t ndata_shard) {
77
91
  GGML_ASSERT(ne_datapoint > 0);
78
92
  GGML_ASSERT(ne_label >= 0);
79
93
  GGML_ASSERT(ndata > 0);
@@ -92,11 +106,11 @@ ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label,
92
106
  result->ctx = ggml_init(params);
93
107
  }
94
108
 
95
- result->data = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_datapoint, ndata);
109
+ result->data = ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata);
96
110
  result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata;
97
111
 
98
112
  if (ne_label > 0) {
99
- result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata);
113
+ result->labels = ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata);
100
114
  result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata;
101
115
  } else {
102
116
  result->labels = nullptr;
@@ -119,6 +133,10 @@ void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) {
119
133
  delete dataset;
120
134
  }
121
135
 
136
+ int64_t ggml_opt_dataset_ndata(ggml_opt_dataset_t dataset) {
137
+ return dataset->ndata;
138
+ }
139
+
122
140
  struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) {
123
141
  return dataset->data;
124
142
  }
@@ -144,6 +162,8 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor *
144
162
  GGML_ASSERT( data_batch && ggml_is_contiguous(data_batch));
145
163
  GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch));
146
164
  GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
165
+ GGML_ASSERT( data_batch->type == dataset->data->type);
166
+ GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type);
147
167
 
148
168
  const size_t nb_data_batch = ggml_nbytes(data_batch);
149
169
  GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
@@ -171,6 +191,31 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor *
171
191
  }
172
192
  }
173
193
 
194
+ void ggml_opt_dataset_get_batch_host(ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) {
195
+ GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
196
+ GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
197
+
198
+ const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;
199
+
200
+ GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));
201
+
202
+ for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
203
+ const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];
204
+
205
+ const char * ptr_data = (const char *) dataset->data->data + ishard *dataset->nbs_data;
206
+ char * ptr_data_batch = (char *) data_batch + ishard_batch*dataset->nbs_data;
207
+ memcpy(ptr_data_batch, ptr_data, dataset->nbs_data);
208
+
209
+ if (!labels_batch) {
210
+ continue;
211
+ }
212
+
213
+ const char * ptr_labels = (const char *) dataset->labels->data + ishard *dataset->nbs_labels;
214
+ char * ptr_labels_batch = (char *) labels_batch + ishard_batch*dataset->nbs_labels;
215
+ memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels);
216
+ }
217
+ }
218
+
174
219
  // ====== Model / Context ======
175
220
 
176
221
  struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) {
@@ -187,17 +232,18 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
187
232
  return result;
188
233
  }
189
234
 
235
+ struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
236
+ return *((struct ggml_opt_optimizer_params *) userdata);
237
+ }
238
+
190
239
  struct ggml_opt_params ggml_opt_default_params(
191
240
  ggml_backend_sched_t backend_sched,
192
- struct ggml_context * ctx_compute,
193
- struct ggml_tensor * inputs,
194
- struct ggml_tensor * outputs,
195
241
  enum ggml_opt_loss_type loss_type) {
196
242
  return {
197
243
  /*backend_sched =*/ backend_sched,
198
- /*ctx_compute =*/ ctx_compute,
199
- /*inputs =*/ inputs,
200
- /*logits =*/ outputs,
244
+ /*ctx_compute =*/ nullptr,
245
+ /*inputs =*/ nullptr,
246
+ /*logits =*/ nullptr,
201
247
  /*loss_type =*/ loss_type,
202
248
  /*build_type =*/ GGML_OPT_BUILD_TYPE_OPT,
203
249
  /*opt_period =*/ 1,
@@ -266,195 +312,246 @@ static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) {
266
312
  return dst;
267
313
  }
268
314
 
269
- static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) {
270
- GGML_ASSERT(graph);
271
- if (opt_ctx->allocated_graph == graph) {
272
- return;
273
- }
274
-
275
- ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
276
-
277
- {
278
- ggml_init_params params = {
279
- /*.mem_size =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE,
280
- /*.mem_buffer =*/ nullptr,
281
- /*.no_alloc =*/ true,
282
- };
283
- ggml_free(opt_ctx->ctx_copy);
284
- opt_ctx->ctx_copy = ggml_init(params);
285
- }
286
-
287
- opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
288
-
289
- ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
290
- opt_ctx->allocated_graph = graph;
291
- }
292
-
293
- ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
294
- ggml_opt_context_t result = new struct ggml_opt_context;
295
- result->backend_sched = params.backend_sched;
296
- result->ctx_compute = params.ctx_compute;
297
- result->inputs = params.inputs;
298
- result->outputs = params.outputs;
299
- result->opt_period = params.opt_period;
300
- result->get_opt_pars = params.get_opt_pars;
301
- result->get_opt_pars_ud = params.get_opt_pars_ud;
302
-
303
- GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically");
304
- GGML_ASSERT(result->opt_period >= 1);
305
-
306
- const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD ||
307
- (params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1);
315
+ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
316
+ GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
317
+ GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
308
318
 
309
- ggml_set_input(result->inputs);
310
- ggml_set_output(result->outputs);
319
+ const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
320
+ !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
311
321
 
312
- result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
313
- ggml_build_forward_expand(result->gf, result->outputs);
322
+ ggml_set_input(opt_ctx->inputs);
323
+ ggml_set_output(opt_ctx->outputs);
314
324
 
315
325
  int n_param = 0;
316
- for (int i = 0; i < result->gf->n_nodes; ++i) {
317
- if (result->gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
326
+ for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) {
327
+ const struct ggml_tensor * node = opt_ctx->gf->nodes[i];
328
+ if (node->flags & GGML_TENSOR_FLAG_PARAM) {
318
329
  n_param++;
319
330
  }
331
+ GGML_ASSERT(!(node->flags & GGML_TENSOR_FLAG_LOSS) && "support for extra loss terms not implemented");
320
332
  }
321
333
 
322
- {
334
+ if (!opt_ctx->ctx_static) {
323
335
  // The static context is used for:
324
- // - gradients (1 tensor per param if using gradient accumulation)
336
+ // - gradients (1 per loss, 1 tensor per param if using gradient accumulation)
325
337
  // - optimizer momenta (2 tensors per param)
326
- // - labels
327
- // - loss + its gradient (up to 5 tensors)
328
- // - pred
329
- // - ncorrect (2 tensors).
330
- const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
331
- const size_t size_meta = (tensors_per_param*n_param + 9) * ggml_tensor_overhead();
338
+ // - labels (if using static graphs)
339
+ // - loss (if using static graphs, up to 5 tensors)
340
+ // - pred (if using static graphs)
341
+ // - ncorrect (if using static graphs, 2 tensors).
342
+ constexpr size_t n_loss = 1;
343
+ const size_t tensors_per_param = (accumulate ? 1 : 0) +
344
+ (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
345
+ const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
346
+ const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
332
347
  struct ggml_init_params params = {
333
348
  /*.mem_size =*/ size_meta,
334
349
  /*.mem_buffer =*/ nullptr,
335
350
  /*.no_alloc =*/ true,
336
351
  };
337
- result->ctx_static = ggml_init(params);
352
+ opt_ctx->ctx_static = ggml_init(params);
338
353
  }
354
+ GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc);
355
+
339
356
  {
340
- // The static cpu context is used for:
341
- // - optimizer parameters (1 for the entire context)
357
+ // The cpu context is allocated statically if using static graphs, dynamically otherwise.
358
+ // It is used for:
359
+ // - optimizer parameters (1 shared for all optimizer invocations)
342
360
  const size_t size_meta = 1 * ggml_tensor_overhead();
343
361
  struct ggml_init_params params = {
344
362
  /*.mem_size =*/ size_meta,
345
363
  /*.mem_buffer =*/ nullptr,
346
364
  /*.no_alloc =*/ true,
347
365
  };
348
- result->ctx_static_cpu = ggml_init(params);
366
+ ggml_free(opt_ctx->ctx_cpu);
367
+ opt_ctx->ctx_cpu = ggml_init(params);
368
+
369
+ ggml_backend_buffer_free(opt_ctx->buf_cpu);
370
+ opt_ctx->buf_cpu = nullptr;
349
371
  }
350
372
 
373
+ struct ggml_context * ctx_results = opt_ctx->static_graphs ? opt_ctx->ctx_static : opt_ctx->ctx_compute;
351
374
 
352
- switch (params.loss_type) {
375
+ switch (opt_ctx->loss_type) {
353
376
  case GGML_OPT_LOSS_TYPE_MEAN: {
354
- result->loss = ggml_sum(result->ctx_static, result->outputs);
355
- ggml_set_name(result->loss, "loss_sum");
356
- const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
357
- result->loss = ggml_scale(result->ctx_static, result->loss, scale);
358
- ggml_set_name(result->loss, "loss_mean");
359
- result->loss_per_datapoint = true;
377
+ opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
378
+ ggml_set_name(opt_ctx->loss, "loss_sum");
379
+ const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
380
+ opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
381
+ ggml_set_name(opt_ctx->loss, "loss_mean");
382
+ opt_ctx->loss_per_datapoint = true;
360
383
  break;
361
384
  }
362
385
  case GGML_OPT_LOSS_TYPE_SUM: {
363
- result->loss = ggml_sum(result->ctx_static, result->outputs);
364
- ggml_set_name(result->loss, "loss_sum");
365
- result->loss_per_datapoint = false;
386
+ opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
387
+ ggml_set_name(opt_ctx->loss, "loss_sum");
388
+ opt_ctx->loss_per_datapoint = false;
366
389
  break;
367
390
  }
368
391
  case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {
369
- result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
370
- ggml_set_input(result->labels);
371
- ggml_set_name(result->labels, "labels");
372
- result->loss = ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels);
373
- ggml_set_name(result->loss, "loss_cross_entropy");
374
- if (result->opt_period > 1) {
375
- result->loss = ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period);
376
- ggml_set_name(result->loss, "loss_cross_entropy_scaled");
392
+ opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
393
+ ggml_set_input(opt_ctx->labels);
394
+ ggml_set_name(opt_ctx->labels, "labels");
395
+ opt_ctx->loss = ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels);
396
+ ggml_set_name(opt_ctx->loss, "loss_cross_entropy");
397
+ if (opt_ctx->opt_period > 1) {
398
+ opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period);
399
+ ggml_set_name(opt_ctx->loss, "loss_cross_entropy_scaled");
377
400
  }
378
- result->loss_per_datapoint = true;
401
+ opt_ctx->loss_per_datapoint = true;
379
402
  break;
380
403
  }
381
404
  case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {
382
- result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
383
- ggml_set_input(result->labels);
384
- ggml_set_name(result->labels, "labels");
385
- result->loss = ggml_sub(result->ctx_static, result->outputs, result->labels);
386
- ggml_set_name(result->loss, "loss_error");
387
- result->loss = ggml_sqr(result->ctx_static, result->loss);
388
- ggml_set_name(result->loss, "loss_squared_error");
389
- result->loss = ggml_sum(result->ctx_static, result->loss);
390
- ggml_set_name(result->loss, "loss_sum_squared_error");
391
- const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
392
- result->loss = ggml_scale(result->ctx_static, result->loss, scale);
393
- ggml_set_name(result->loss, "loss_mean_squared_error");
394
- result->loss_per_datapoint = true;
405
+ opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
406
+ ggml_set_input(opt_ctx->labels);
407
+ ggml_set_name(opt_ctx->labels, "labels");
408
+ opt_ctx->loss = ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels);
409
+ ggml_set_name(opt_ctx->loss, "loss_error");
410
+ opt_ctx->loss = ggml_sqr(ctx_results, opt_ctx->loss);
411
+ ggml_set_name(opt_ctx->loss, "loss_squared_error");
412
+ opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->loss);
413
+ ggml_set_name(opt_ctx->loss, "loss_sum_squared_error");
414
+ const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
415
+ opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
416
+ ggml_set_name(opt_ctx->loss, "loss_mean_squared_error");
417
+ opt_ctx->loss_per_datapoint = true;
395
418
  break;
396
419
  }
397
420
  }
398
- ggml_set_output(result->loss);
399
- ggml_set_loss(result->loss);
400
- ggml_build_forward_expand(result->gf, result->loss);
401
-
402
- result->pred = ggml_argmax(result->ctx_static, result->outputs);
403
- ggml_set_name(result->pred, "pred");
404
- ggml_set_output(result->pred);
405
- ggml_build_forward_expand(result->gf, result->pred);
421
+ ggml_set_output(opt_ctx->loss);
422
+ ggml_set_loss(opt_ctx->loss);
423
+ ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss);
424
+
425
+ if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) {
426
+ opt_ctx->pred = ggml_argmax(ctx_results, opt_ctx->outputs);
427
+ ggml_set_name(opt_ctx->pred, "pred");
428
+ ggml_set_output(opt_ctx->pred);
429
+ ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred);
430
+
431
+ opt_ctx->ncorrect = ggml_count_equal(ctx_results, opt_ctx->pred, ggml_argmax(ctx_results, opt_ctx->labels));
432
+ ggml_set_name(opt_ctx->ncorrect, "ncorrect");
433
+ ggml_set_output(opt_ctx->ncorrect);
434
+ ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect);
435
+ }
406
436
 
407
- if (result->labels) {
408
- result->ncorrect = ggml_count_equal(result->ctx_static, result->pred, ggml_argmax(result->ctx_static, result->labels));
409
- ggml_set_name(result->ncorrect, "ncorrect");
410
- ggml_set_output(result->ncorrect);
411
- ggml_build_forward_expand(result->gf, result->ncorrect);
412
- } else {
413
- result->ncorrect = nullptr;
437
+ if (opt_ctx->buf_static) {
438
+ if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
439
+ return;
440
+ }
441
+ } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_FORWARD) {
442
+ opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
443
+ opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
444
+ return;
414
445
  }
415
446
 
416
- if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
417
- result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
418
- return result;
447
+ if (opt_ctx->grad_accs.empty()) {
448
+ GGML_ASSERT(opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD);
449
+
450
+ const int n_nodes = opt_ctx->gf->n_nodes;
451
+ opt_ctx->grad_accs.resize(n_nodes);
452
+ for (int i = 0; i < n_nodes; ++i) {
453
+ ggml_tensor * node = opt_ctx->gf->nodes[i];
454
+ if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
455
+ opt_ctx->grad_accs[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
456
+ } else {
457
+ opt_ctx->grad_accs[i] = nullptr;
458
+ }
459
+ }
460
+
461
+ if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
462
+ opt_ctx->grad_m.resize(n_nodes);
463
+ opt_ctx->grad_v.resize(n_nodes);
464
+ for (int i = 0; i < n_nodes; ++i) {
465
+ ggml_tensor * node = opt_ctx->gf->nodes[i];
466
+ if (node->flags & GGML_TENSOR_FLAG_PARAM) {
467
+ opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
468
+ opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
469
+ } else {
470
+ opt_ctx->grad_m[i] = nullptr;
471
+ opt_ctx->grad_v[i] = nullptr;
472
+ }
473
+ }
474
+ }
419
475
  }
420
476
 
421
477
  // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
422
- result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf);
423
- ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate);
478
+ opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true);
479
+ ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data());
424
480
 
425
- if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) {
426
- result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
427
- ggml_graph_reset(result->gb_grad);
428
- return result;
481
+ if (opt_ctx->buf_static) {
482
+ if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_GRAD) {
483
+ return;
484
+ }
485
+ } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_GRAD) {
486
+ opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
487
+ ggml_graph_reset(opt_ctx->gb_grad);
429
488
  }
430
489
 
431
- GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT);
490
+ GGML_ASSERT(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT);
432
491
 
433
492
  // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
434
- result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad);
493
+ opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
435
494
 
436
- result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7);
437
- ggml_set_input(result->adamw_params);
438
- ggml_set_name(result->adamw_params, "adamw_params");
495
+ opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7);
496
+ ggml_set_input(opt_ctx->adamw_params);
497
+ ggml_set_name(opt_ctx->adamw_params, "adamw_params");
439
498
 
440
- for (int i = result->gf->n_nodes-1; i >= 0; --i) {
441
- struct ggml_tensor * node = result->gb_opt->nodes[i];
442
- struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node);
499
+ for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
500
+ struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
501
+ struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);
443
502
 
444
- if (node->flags & GGML_TENSOR_FLAG_PARAM) {
445
- struct ggml_tensor * m = ggml_dup_tensor(result->ctx_static, node);
446
- struct ggml_tensor * v = ggml_dup_tensor(result->ctx_static, node);
447
- struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params);
448
- ggml_build_forward_expand(result->gb_opt, opt_step);
503
+ if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
504
+ struct ggml_tensor * m = opt_ctx->grad_m[i];
505
+ struct ggml_tensor * v = opt_ctx->grad_v[i];
506
+ struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
507
+
508
+ ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str());
509
+ ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str());
510
+ ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
511
+
512
+ ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
449
513
  }
450
514
  }
451
515
 
452
- result->buf_static = ggml_backend_alloc_ctx_tensors(
453
- result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
516
+ if (!opt_ctx->buf_static) {
517
+ opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
518
+ opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
519
+ ggml_graph_reset(opt_ctx->gb_opt);
520
+ }
454
521
 
455
- result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type());
522
+ opt_ctx->buf_cpu = ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, ggml_backend_cpu_buffer_type());
523
+ }
456
524
 
457
- ggml_graph_reset(result->gb_opt);
525
+ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
526
+ ggml_opt_context_t result = new struct ggml_opt_context;
527
+ result->backend_sched = params.backend_sched;
528
+ result->ctx_compute = params.ctx_compute;
529
+ result->loss_type = params.loss_type;
530
+ result->build_type = params.build_type;
531
+ result->build_type_alloc = params.build_type;
532
+ result->inputs = params.inputs;
533
+ result->outputs = params.outputs;
534
+ result->opt_period = params.opt_period;
535
+ result->get_opt_pars = params.get_opt_pars;
536
+ result->get_opt_pars_ud = params.get_opt_pars_ud;
537
+
538
+ GGML_ASSERT(result->opt_period >= 1);
539
+
540
+ result->static_graphs = result->ctx_compute;
541
+
542
+ if (!result->static_graphs) {
543
+ GGML_ASSERT(!result->inputs);
544
+ GGML_ASSERT(!result->outputs);
545
+ return result;
546
+ }
547
+
548
+ GGML_ASSERT(result->inputs);
549
+ GGML_ASSERT(result->outputs);
550
+
551
+ result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
552
+ ggml_build_forward_expand(result->gf, result->outputs);
553
+
554
+ ggml_opt_build(result);
458
555
 
459
556
  return result;
460
557
  }
@@ -464,9 +561,9 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) {
464
561
  return;
465
562
  }
466
563
  ggml_backend_buffer_free(opt_ctx->buf_static);
467
- ggml_backend_buffer_free(opt_ctx->buf_static_cpu);
564
+ ggml_backend_buffer_free(opt_ctx->buf_cpu);
468
565
  ggml_free(opt_ctx->ctx_static);
469
- ggml_free(opt_ctx->ctx_static_cpu);
566
+ ggml_free(opt_ctx->ctx_cpu);
470
567
  delete opt_ctx;
471
568
  }
472
569
 
@@ -582,8 +679,79 @@ void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, doubl
582
679
 
583
680
  // ====== Computation ======
584
681
 
585
- static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) {
586
- if (graph != opt_ctx->gf) {
682
+ void ggml_opt_prepare_alloc(
683
+ ggml_opt_context_t opt_ctx,
684
+ struct ggml_context * ctx_compute,
685
+ struct ggml_cgraph * gf,
686
+ struct ggml_tensor * inputs,
687
+ struct ggml_tensor * outputs) {
688
+ GGML_ASSERT(!opt_ctx->static_graphs);
689
+ opt_ctx->ctx_compute = ctx_compute;
690
+ opt_ctx->gf = gf;
691
+ opt_ctx->inputs = inputs;
692
+ opt_ctx->outputs = outputs;
693
+ }
694
+
695
+ void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
696
+ GGML_ASSERT(!opt_ctx->eval_ready);
697
+ if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) {
698
+ ggml_graph_reset(opt_ctx->gb_grad);
699
+ }
700
+ if (backward) {
701
+ const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
702
+ opt_ctx->build_type = opt_i_next == 0 ? GGML_OPT_BUILD_TYPE_OPT : GGML_OPT_BUILD_TYPE_GRAD;
703
+ } else {
704
+ opt_ctx->build_type = GGML_OPT_BUILD_TYPE_FORWARD;
705
+ }
706
+
707
+ if (!opt_ctx->static_graphs) {
708
+ ggml_opt_build(opt_ctx);
709
+ }
710
+
711
+ struct ggml_cgraph * graph = nullptr;
712
+ switch (opt_ctx->build_type) {
713
+ case GGML_OPT_BUILD_TYPE_FORWARD: {
714
+ graph = opt_ctx->gf;
715
+ } break;
716
+ case GGML_OPT_BUILD_TYPE_GRAD: {
717
+ graph = opt_ctx->gb_grad;
718
+ } break;
719
+ case GGML_OPT_BUILD_TYPE_OPT: {
720
+ graph = opt_ctx->gb_opt;
721
+ } break;
722
+ }
723
+ GGML_ASSERT(graph);
724
+
725
+ if (opt_ctx->allocated_graph == graph) {
726
+ opt_ctx->eval_ready = true;
727
+ return;
728
+ }
729
+
730
+ ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
731
+
732
+ if (opt_ctx->static_graphs) {
733
+ ggml_init_params params = {
734
+ /*.mem_size =*/ graph->size*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph->size, graph->grads),
735
+ /*.mem_buffer =*/ nullptr,
736
+ /*.no_alloc =*/ true,
737
+ };
738
+ ggml_free(opt_ctx->ctx_copy);
739
+ opt_ctx->ctx_copy = ggml_init(params);
740
+
741
+ opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
742
+ } else {
743
+ opt_ctx->allocated_graph_copy = graph;
744
+ }
745
+
746
+ ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
747
+ opt_ctx->allocated_graph = graph;
748
+
749
+ opt_ctx->eval_ready = true;
750
+ }
751
+
752
+ void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
753
+ GGML_ASSERT(opt_ctx->eval_ready);
754
+ if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
587
755
  struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
588
756
 
589
757
  GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
@@ -609,9 +777,19 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
609
777
  adamw_par_data[6] = beta2h;
610
778
  }
611
779
 
612
- ggml_opt_alloc_graph(opt_ctx, graph);
613
780
  ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
614
781
  opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt;
782
+ opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
783
+
784
+ if (!opt_ctx->static_graphs) {
785
+ opt_ctx->gf = nullptr;
786
+ opt_ctx->gb_grad = nullptr;
787
+ opt_ctx->gb_opt = nullptr;
788
+ opt_ctx->allocated_graph = nullptr;
789
+ opt_ctx->allocated_graph_copy = nullptr;
790
+ }
791
+
792
+ opt_ctx->eval_ready = false;
615
793
 
616
794
  if (!result) {
617
795
  return;
@@ -635,12 +813,14 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
635
813
  ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss));
636
814
  result->loss.push_back(loss);
637
815
 
638
- GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
639
- std::vector<int32_t> pred(ndata);
640
- ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
641
- result->pred.insert(result->pred.end(), pred.begin(), pred.end());
816
+ if (opt_ctx->pred) {
817
+ GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
818
+ std::vector<int32_t> pred(ndata);
819
+ ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
820
+ result->pred.insert(result->pred.end(), pred.begin(), pred.end());
821
+ }
642
822
 
643
- if (!opt_ctx->labels || result->ncorrect < 0) {
823
+ if (!opt_ctx->ncorrect || result->ncorrect < 0) {
644
824
  result->ncorrect = -1;
645
825
  return;
646
826
  }
@@ -652,26 +832,6 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
652
832
  result->ncorrect += ncorrect;
653
833
  }
654
834
 
655
- void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
656
- ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result);
657
- }
658
-
659
- void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
660
- if (opt_ctx->opt_period == 1) {
661
- ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
662
- return;
663
- }
664
-
665
- const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
666
- if (opt_i_next == 0) {
667
- ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
668
- ggml_opt_reset(opt_ctx, /*optimizer =*/ false);
669
- } else {
670
- ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result);
671
- }
672
- opt_ctx->opt_i = opt_i_next;
673
- }
674
-
675
835
  // ====== High-Level Functions ======
676
836
 
677
837
  void ggml_opt_epoch(
@@ -700,16 +860,18 @@ void ggml_opt_epoch(
700
860
  int64_t ibatch = 0;
701
861
  int64_t t_loop_start = ggml_time_us();
702
862
  for (; ibatch < ibatch_split; ++ibatch) {
863
+ ggml_opt_alloc(opt_ctx, /*backward =*/ true);
703
864
  ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
704
- ggml_opt_forward_backward(opt_ctx, result_train);
865
+ ggml_opt_eval(opt_ctx, result_train);
705
866
  if (callback_train) {
706
867
  callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);
707
868
  }
708
869
  }
709
870
  t_loop_start = ggml_time_us();
710
871
  for (; ibatch < nbatches; ++ibatch) {
872
+ ggml_opt_alloc(opt_ctx, /*backward =*/ false);
711
873
  ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
712
- ggml_opt_forward(opt_ctx, result_eval);
874
+ ggml_opt_eval(opt_ctx, result_eval);
713
875
  if (callback_eval) {
714
876
  callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);
715
877
  }
@@ -726,13 +888,26 @@ void ggml_opt_epoch_callback_progress_bar(
726
888
  int64_t t_start_us) {
727
889
  fprintf(stderr, "%s[", train ? "train: " : "val: ");
728
890
 
729
- constexpr int64_t bar_length = 25;
891
+ // The progress bar consists of partially filled blocks, unicode has 8 separate fill levels.
892
+ constexpr int64_t bar_length = 8;
893
+ const int64_t ibatch8 = 8 * ibatch;
730
894
  for (int64_t j = 0; j < bar_length; ++j) {
731
- const int64_t ibatch_j = ibatch_max * j/bar_length;
732
- if (ibatch_j < ibatch) {
733
- fprintf(stderr, "=");
734
- } else if (ibatch_max * (j - 1)/bar_length < ibatch) {
735
- fprintf(stderr, ">");
895
+ if (ibatch_max * (8*j + 8) / bar_length < ibatch8) {
896
+ fprintf(stderr, "\u2588"); // full block
897
+ } else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) {
898
+ fprintf(stderr, "\u2589"); // 7/8 filled
899
+ } else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) {
900
+ fprintf(stderr, "\u258A"); // 6/8 filled
901
+ } else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) {
902
+ fprintf(stderr, "\u258B"); // 5/8 filled
903
+ } else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) {
904
+ fprintf(stderr, "\u258C"); // 4/8 filled
905
+ } else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) {
906
+ fprintf(stderr, "\u258D"); // 3/8 filled
907
+ } else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) {
908
+ fprintf(stderr, "\u258E"); // 2/8 filled
909
+ } else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) {
910
+ fprintf(stderr, "\u258F"); // 1/8 filled
736
911
  } else {
737
912
  fprintf(stderr, " ");
738
913
  }
@@ -764,8 +939,8 @@ void ggml_opt_epoch_callback_progress_bar(
764
939
  const int64_t t_eta_m = t_eta_s / 60;
765
940
  t_eta_s -= t_eta_m * 60;
766
941
 
767
- fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, "
768
- "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r",
942
+ fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% "
943
+ "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r",
769
944
  idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
770
945
  t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
771
946
  if (ibatch == ibatch_max) {
@@ -806,7 +981,10 @@ void ggml_opt_fit(
806
981
 
807
982
  int64_t epoch = 1;
808
983
 
809
- ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
984
+ ggml_opt_params params = ggml_opt_default_params(backend_sched, loss_type);
985
+ params.ctx_compute = ctx_compute;
986
+ params.inputs = inputs;
987
+ params.outputs = outputs;
810
988
  params.opt_period = opt_period;
811
989
  params.get_opt_pars = get_opt_pars;
812
990
  params.get_opt_pars_ud = &epoch;