whisper.rn 0.4.0-rc.1 → 0.4.0-rc.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +21 -1
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +27 -92
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +86 -40
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +85 -131
  7. package/android/src/main/jni-utils.h +76 -0
  8. package/android/src/main/jni.cpp +226 -109
  9. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  10. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  11. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  12. package/cpp/coreml/whisper-encoder.h +4 -0
  13. package/cpp/coreml/whisper-encoder.mm +5 -3
  14. package/cpp/ggml-alloc.c +797 -400
  15. package/cpp/ggml-alloc.h +60 -10
  16. package/cpp/ggml-backend-impl.h +255 -0
  17. package/cpp/ggml-backend-reg.cpp +582 -0
  18. package/cpp/ggml-backend.cpp +2002 -0
  19. package/cpp/ggml-backend.h +354 -0
  20. package/cpp/ggml-common.h +1851 -0
  21. package/cpp/ggml-cpp.h +39 -0
  22. package/cpp/ggml-cpu-aarch64.cpp +4247 -0
  23. package/cpp/ggml-cpu-aarch64.h +8 -0
  24. package/cpp/ggml-cpu-impl.h +531 -0
  25. package/cpp/ggml-cpu-quants.c +12245 -0
  26. package/cpp/ggml-cpu-quants.h +63 -0
  27. package/cpp/ggml-cpu-traits.cpp +36 -0
  28. package/cpp/ggml-cpu-traits.h +38 -0
  29. package/cpp/ggml-cpu.c +14792 -0
  30. package/cpp/ggml-cpu.cpp +653 -0
  31. package/cpp/ggml-cpu.h +137 -0
  32. package/cpp/ggml-impl.h +567 -0
  33. package/cpp/ggml-metal-impl.h +288 -0
  34. package/cpp/ggml-metal.h +24 -43
  35. package/cpp/ggml-metal.m +4867 -1080
  36. package/cpp/ggml-opt.cpp +854 -0
  37. package/cpp/ggml-opt.h +216 -0
  38. package/cpp/ggml-quants.c +5238 -0
  39. package/cpp/ggml-quants.h +100 -0
  40. package/cpp/ggml-threading.cpp +12 -0
  41. package/cpp/ggml-threading.h +14 -0
  42. package/cpp/ggml-whisper.metallib +0 -0
  43. package/cpp/ggml.c +5106 -19431
  44. package/cpp/ggml.h +847 -669
  45. package/cpp/gguf.cpp +1329 -0
  46. package/cpp/gguf.h +202 -0
  47. package/cpp/rn-audioutils.cpp +68 -0
  48. package/cpp/rn-audioutils.h +14 -0
  49. package/cpp/rn-whisper-log.h +11 -0
  50. package/cpp/rn-whisper.cpp +221 -52
  51. package/cpp/rn-whisper.h +50 -15
  52. package/cpp/whisper.cpp +3174 -1533
  53. package/cpp/whisper.h +176 -44
  54. package/ios/RNWhisper.mm +139 -46
  55. package/ios/RNWhisperAudioUtils.h +1 -2
  56. package/ios/RNWhisperAudioUtils.m +18 -67
  57. package/ios/RNWhisperContext.h +11 -8
  58. package/ios/RNWhisperContext.mm +195 -150
  59. package/jest/mock.js +15 -2
  60. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  61. package/lib/commonjs/index.js +76 -28
  62. package/lib/commonjs/index.js.map +1 -1
  63. package/lib/commonjs/version.json +1 -1
  64. package/lib/module/NativeRNWhisper.js.map +1 -1
  65. package/lib/module/index.js +76 -28
  66. package/lib/module/index.js.map +1 -1
  67. package/lib/module/version.json +1 -1
  68. package/lib/typescript/NativeRNWhisper.d.ts +13 -4
  69. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  70. package/lib/typescript/index.d.ts +37 -5
  71. package/lib/typescript/index.d.ts.map +1 -1
  72. package/package.json +9 -7
  73. package/src/NativeRNWhisper.ts +20 -4
  74. package/src/index.ts +98 -42
  75. package/src/version.json +1 -1
  76. package/whisper-rn.podspec +13 -20
  77. package/cpp/README.md +0 -4
  78. package/cpp/ggml-metal.metal +0 -2353
package/cpp/ggml-opt.cpp
@@ -0,0 +1,854 @@
+#include "ggml-opt.h"
+
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <cinttypes>
+#include <map>
+#include <random>
+#include <vector>
+
+struct wsp_ggml_opt_dataset {
+    struct wsp_ggml_context * ctx = nullptr;
+    wsp_ggml_backend_buffer_t buf = nullptr;
+    struct wsp_ggml_tensor * data = nullptr;
+    struct wsp_ggml_tensor * labels = nullptr;
+
+    int64_t ndata = -1;
+    int64_t ndata_shard = -1;
+    size_t nbs_data = -1;
+    size_t nbs_labels = -1;
+
+    std::vector<int64_t> permutation;
+};
+
+struct wsp_ggml_opt_context {
+    wsp_ggml_backend_sched_t backend_sched = nullptr;
+    wsp_ggml_cgraph * allocated_graph = nullptr;
+    wsp_ggml_cgraph * allocated_graph_copy = nullptr;
+    struct wsp_ggml_context * ctx_static = nullptr;
+    struct wsp_ggml_context * ctx_static_cpu = nullptr;
+    struct wsp_ggml_context * ctx_compute = nullptr;
+    struct wsp_ggml_context * ctx_copy = nullptr;
+    wsp_ggml_backend_buffer_t buf_static = nullptr;
+    wsp_ggml_backend_buffer_t buf_static_cpu = nullptr;
+    std::mt19937 rng;
+
+    struct wsp_ggml_tensor * inputs = nullptr;
+    struct wsp_ggml_tensor * outputs = nullptr;
+    struct wsp_ggml_tensor * labels = nullptr;
+
+    struct wsp_ggml_tensor * loss = nullptr;
+    struct wsp_ggml_tensor * pred = nullptr;
+    struct wsp_ggml_tensor * ncorrect = nullptr;
+
+    struct wsp_ggml_cgraph * gf = nullptr;
+    struct wsp_ggml_cgraph * gb_grad = nullptr;
+    struct wsp_ggml_cgraph * gb_opt = nullptr;
+
+    int64_t iter = 1;
+    int32_t opt_period = 1;
+    int32_t opt_i = 0;
+    bool loss_per_datapoint = false;
+
+    wsp_ggml_opt_get_optimizer_params get_opt_pars = nullptr;
+    void * get_opt_pars_ud = nullptr;
+    struct wsp_ggml_tensor * adamw_params = nullptr;
+};
+
+struct wsp_ggml_opt_result {
+    int64_t ndata = 0;
+    std::vector<float> loss;
+    std::vector<int32_t> pred;
+    int64_t ncorrect = 0;
+
+    int64_t opt_period = -1;
+    bool loss_per_datapoint = false;
+};
+
+// ====== Dataset ======
+
+wsp_ggml_opt_dataset_t wsp_ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) {
+    WSP_GGML_ASSERT(ne_datapoint > 0);
+    WSP_GGML_ASSERT(ne_label >= 0);
+    WSP_GGML_ASSERT(ndata > 0);
+    WSP_GGML_ASSERT(ndata_shard > 0);
+
+    wsp_ggml_opt_dataset_t result = new wsp_ggml_opt_dataset;
+    result->ndata = ndata;
+    result->ndata_shard = ndata_shard;
+
+    {
+        struct wsp_ggml_init_params params = {
+            /*.mem_size =*/ 2*wsp_ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc =*/ true,
+        };
+        result->ctx = wsp_ggml_init(params);
+    }
+
+    result->data = wsp_ggml_new_tensor_2d(result->ctx, WSP_GGML_TYPE_F32, ne_datapoint, ndata);
+    result->nbs_data = wsp_ggml_nbytes(result->data) * ndata_shard/ndata;
+
+    if (ne_label > 0) {
+        result->labels = wsp_ggml_new_tensor_2d(result->ctx, WSP_GGML_TYPE_F32, ne_label, ndata);
+        result->nbs_labels = wsp_ggml_nbytes(result->labels) * ndata_shard/ndata;
+    } else {
+        result->labels = nullptr;
+        result->nbs_labels = 0;
+    }
+
+    result->buf = wsp_ggml_backend_alloc_ctx_tensors_from_buft(result->ctx, wsp_ggml_backend_cpu_buffer_type());
+
+    const int64_t nshards = ndata/ndata_shard;
+    result->permutation.resize(nshards);
+    for (int64_t i = 0; i < nshards; ++i) {
+        result->permutation[i] = i;
+    }
+    return result;
+}
+
+void wsp_ggml_opt_dataset_free(wsp_ggml_opt_dataset_t dataset) {
+    wsp_ggml_backend_buffer_free(dataset->buf);
+    wsp_ggml_free(dataset->ctx);
+    delete dataset;
+}
+
+struct wsp_ggml_tensor * wsp_ggml_opt_dataset_data(wsp_ggml_opt_dataset_t dataset) {
+    return dataset->data;
+}
+
+struct wsp_ggml_tensor * wsp_ggml_opt_dataset_labels(wsp_ggml_opt_dataset_t dataset) {
+    return dataset->labels;
+}
+
+void wsp_ggml_opt_dataset_shuffle(wsp_ggml_opt_context_t opt_ctx, wsp_ggml_opt_dataset_t dataset, int64_t idata) {
+    WSP_GGML_ASSERT(idata <= dataset->ndata);
+
+    if (idata < 0) {
+        std::shuffle(dataset->permutation.begin(), dataset->permutation.end(), opt_ctx->rng);
+        return;
+    }
+
+    WSP_GGML_ASSERT(idata % dataset->ndata_shard == 0);
+    const int64_t ishard_max = idata / dataset->ndata_shard;
+    std::shuffle(dataset->permutation.begin(), dataset->permutation.begin() + ishard_max, opt_ctx->rng);
+}
+
+void wsp_ggml_opt_dataset_get_batch(wsp_ggml_opt_dataset_t dataset, struct wsp_ggml_tensor * data_batch, struct wsp_ggml_tensor * labels_batch, int64_t ibatch) {
+    WSP_GGML_ASSERT( data_batch && wsp_ggml_is_contiguous(data_batch));
+    WSP_GGML_ASSERT(!labels_batch || wsp_ggml_is_contiguous(labels_batch));
+    WSP_GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
+
+    const size_t nb_data_batch = wsp_ggml_nbytes(data_batch);
+    WSP_GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
+    const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;
+
+    if (labels_batch) {
+        const size_t nb_labels_batch = wsp_ggml_nbytes(labels_batch);
+        WSP_GGML_ASSERT(nb_labels_batch == shards_per_batch*dataset->nbs_labels);
+    }
+
+    WSP_GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));
+
+    for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
+        const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];
+
+        const char * ptr_data = (const char *) dataset->data->data + ishard*dataset->nbs_data;
+        wsp_ggml_backend_tensor_set(data_batch, ptr_data, ishard_batch*dataset->nbs_data, dataset->nbs_data);
+
+        if (!labels_batch) {
+            continue;
+        }
+
+        const char * ptr_labels = (const char *) dataset->labels->data + ishard*dataset->nbs_labels;
+        wsp_ggml_backend_tensor_set(labels_batch, ptr_labels, ishard_batch*dataset->nbs_labels, dataset->nbs_labels);
+    }
+}
+
+// ====== Model / Context ======
+
+struct wsp_ggml_opt_optimizer_params wsp_ggml_opt_get_default_optimizer_params(void * userdata) {
+    WSP_GGML_UNUSED(userdata);
+
+    wsp_ggml_opt_optimizer_params result;
+
+    result.adamw.alpha = 0.001f;
+    result.adamw.beta1 = 0.9f;
+    result.adamw.beta2 = 0.999f;
+    result.adamw.eps = 1e-8f;
+    result.adamw.wd = 0.0f;
+
+    return result;
+}
+
+struct wsp_ggml_opt_params wsp_ggml_opt_default_params(
+        wsp_ggml_backend_sched_t backend_sched,
+        struct wsp_ggml_context * ctx_compute,
+        struct wsp_ggml_tensor * inputs,
+        struct wsp_ggml_tensor * outputs,
+        enum wsp_ggml_opt_loss_type loss_type) {
+    return {
+        /*backend_sched =*/ backend_sched,
+        /*ctx_compute =*/ ctx_compute,
+        /*inputs =*/ inputs,
+        /*logits =*/ outputs,
+        /*loss_type =*/ loss_type,
+        /*build_type =*/ WSP_GGML_OPT_BUILD_TYPE_OPT,
+        /*opt_period =*/ 1,
+        /*get_opt_pars =*/ wsp_ggml_opt_get_default_optimizer_params,
+        /*get_opt_pars_ud =*/ nullptr,
+    };
+}
+
+static wsp_ggml_tensor * map_tensor(std::map<wsp_ggml_tensor *, wsp_ggml_tensor *> & tensor_map, wsp_ggml_context * ctx, wsp_ggml_tensor * tensor) {
+    if (!tensor) {
+        return nullptr;
+    }
+
+    if (tensor_map.find(tensor) != tensor_map.end()) {
+        return tensor_map[tensor];
+    }
+
+    wsp_ggml_tensor * new_tensor = wsp_ggml_dup_tensor(ctx, tensor);
+    tensor_map[tensor] = new_tensor;
+
+    new_tensor->op = tensor->op;
+    for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
+        new_tensor->nb[i] = tensor->nb[i];
+    }
+    new_tensor->flags = tensor->flags;
+    memcpy(new_tensor->op_params, tensor->op_params, sizeof(tensor->op_params));
+    strcpy(new_tensor->name, tensor->name);
+    new_tensor->data = tensor->data;
+    new_tensor->buffer = tensor->buffer;
+    new_tensor->extra = tensor->extra;
+    new_tensor->view_offs = tensor->view_offs;
+    new_tensor->view_src = map_tensor(tensor_map, ctx, tensor->view_src);
+    for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
+        new_tensor->src[i] = map_tensor(tensor_map, ctx, tensor->src[i]);
+    }
+
+    return new_tensor;
+}
+
+static wsp_ggml_cgraph * dup_graph(wsp_ggml_context * ctx, wsp_ggml_cgraph * src) {
+    std::map<wsp_ggml_tensor *, wsp_ggml_tensor *> tensor_map;
+
+    wsp_ggml_cgraph * dst = wsp_ggml_new_graph_custom(ctx, src->size, /*grads =*/ true);
+
+    for (int i = 0; i < src->n_leafs; i++) {
+        wsp_ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->leafs[i]));
+    }
+    WSP_GGML_ASSERT(dst->n_leafs == src->n_leafs);
+    for (int i = 0; i < src->n_nodes; i++) {
+        wsp_ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->nodes[i]));
+    }
+    WSP_GGML_ASSERT(dst->n_nodes == src->n_nodes);
+    for (int i = 0; i < src->n_nodes; ++i) {
+        const size_t igrad_src = wsp_ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
+        const size_t igrad_dst = wsp_ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);
+
+        WSP_GGML_ASSERT(igrad_src != WSP_GGML_HASHSET_FULL);
+        WSP_GGML_ASSERT(wsp_ggml_bitset_get(src->visited_hash_set.used, igrad_src));
+        WSP_GGML_ASSERT(igrad_dst != WSP_GGML_HASHSET_FULL);
+        WSP_GGML_ASSERT(wsp_ggml_bitset_get(dst->visited_hash_set.used, igrad_dst));
+
+        dst->grads[igrad_dst] = src->grads[igrad_src];
+        dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
+    }
+
+    return dst;
+}
+
+static void wsp_ggml_opt_alloc_graph(wsp_ggml_opt_context_t opt_ctx, wsp_ggml_cgraph * graph) {
+    WSP_GGML_ASSERT(graph);
+    if (opt_ctx->allocated_graph == graph) {
+        return;
+    }
+
+    wsp_ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
+
+    {
+        wsp_ggml_init_params params = {
+            /*.mem_size =*/ wsp_ggml_tensor_overhead() * WSP_GGML_DEFAULT_GRAPH_SIZE,
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc =*/ true,
+        };
+        wsp_ggml_free(opt_ctx->ctx_copy);
+        opt_ctx->ctx_copy = wsp_ggml_init(params);
+    }
+
+    opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
+
+    wsp_ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
+    opt_ctx->allocated_graph = graph;
+}
+
+wsp_ggml_opt_context_t wsp_ggml_opt_init(struct wsp_ggml_opt_params params) {
+    wsp_ggml_opt_context_t result = new struct wsp_ggml_opt_context;
+    result->backend_sched = params.backend_sched;
+    result->ctx_compute = params.ctx_compute;
+    result->inputs = params.inputs;
+    result->outputs = params.outputs;
+    result->opt_period = params.opt_period;
+    result->get_opt_pars = params.get_opt_pars;
+    result->get_opt_pars_ud = params.get_opt_pars_ud;
+
+    WSP_GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically");
+    WSP_GGML_ASSERT(result->opt_period >= 1);
+
+    const bool accumulate = params.build_type == WSP_GGML_OPT_BUILD_TYPE_GRAD ||
+        (params.build_type == WSP_GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1);
+
+    wsp_ggml_set_input(result->inputs);
+    wsp_ggml_set_output(result->outputs);
+
+    result->gf = wsp_ggml_new_graph_custom(result->ctx_compute, WSP_GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
+    wsp_ggml_build_forward_expand(result->gf, result->outputs);
+
+    int n_param = 0;
+    for (int i = 0; i < result->gf->n_nodes; ++i) {
+        if (result->gf->nodes[i]->flags & WSP_GGML_TENSOR_FLAG_PARAM) {
+            n_param++;
+        }
+    }
+
+    {
+        // The static context is used for:
+        //   - gradients (1 tensor per param if using gradient accumulation)
+        //   - optimizer momenta (2 tensors per param)
+        //   - labels
+        //   - loss + its gradient (up to 5 tensors)
+        //   - pred
+        //   - ncorrect (2 tensors).
+        const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == WSP_GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
+        const size_t size_meta = (tensors_per_param*n_param + 9) * wsp_ggml_tensor_overhead();
+        struct wsp_ggml_init_params params = {
+            /*.mem_size =*/ size_meta,
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc =*/ true,
+        };
+        result->ctx_static = wsp_ggml_init(params);
+    }
+    {
+        // The static cpu context is used for:
+        //   - optimizer parameters (1 for the entire context)
+        const size_t size_meta = 1 * wsp_ggml_tensor_overhead();
+        struct wsp_ggml_init_params params = {
+            /*.mem_size =*/ size_meta,
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc =*/ true,
+        };
+        result->ctx_static_cpu = wsp_ggml_init(params);
+    }
+
+
+    switch (params.loss_type) {
+        case WSP_GGML_OPT_LOSS_TYPE_MEAN: {
+            result->loss = wsp_ggml_sum(result->ctx_static, result->outputs);
+            wsp_ggml_set_name(result->loss, "loss_sum");
+            const float scale = 1.0f / (result->opt_period * wsp_ggml_nelements(result->outputs));
+            result->loss = wsp_ggml_scale(result->ctx_static, result->loss, scale);
+            wsp_ggml_set_name(result->loss, "loss_mean");
+            result->loss_per_datapoint = true;
+            break;
+        }
+        case WSP_GGML_OPT_LOSS_TYPE_SUM: {
+            result->loss = wsp_ggml_sum(result->ctx_static, result->outputs);
+            wsp_ggml_set_name(result->loss, "loss_sum");
+            result->loss_per_datapoint = false;
+            break;
+        }
+        case WSP_GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {
+            result->labels = wsp_ggml_dup_tensor(result->ctx_static, result->outputs);
+            wsp_ggml_set_input(result->labels);
+            wsp_ggml_set_name(result->labels, "labels");
+            result->loss = wsp_ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels);
+            wsp_ggml_set_name(result->loss, "loss_cross_entropy");
+            if (result->opt_period > 1) {
+                result->loss = wsp_ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period);
+                wsp_ggml_set_name(result->loss, "loss_cross_entropy_scaled");
+            }
+            result->loss_per_datapoint = true;
+            break;
+        }
+        case WSP_GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {
+            result->labels = wsp_ggml_dup_tensor(result->ctx_static, result->outputs);
+            wsp_ggml_set_input(result->labels);
+            wsp_ggml_set_name(result->labels, "labels");
+            result->loss = wsp_ggml_sub(result->ctx_static, result->outputs, result->labels);
+            wsp_ggml_set_name(result->loss, "loss_error");
+            result->loss = wsp_ggml_sqr(result->ctx_static, result->loss);
+            wsp_ggml_set_name(result->loss, "loss_squared_error");
+            result->loss = wsp_ggml_sum(result->ctx_static, result->loss);
+            wsp_ggml_set_name(result->loss, "loss_sum_squared_error");
+            const float scale = 1.0f / (result->opt_period * wsp_ggml_nelements(result->outputs));
+            result->loss = wsp_ggml_scale(result->ctx_static, result->loss, scale);
+            wsp_ggml_set_name(result->loss, "loss_mean_squared_error");
+            result->loss_per_datapoint = true;
+            break;
+        }
+    }
+    wsp_ggml_set_output(result->loss);
+    wsp_ggml_set_loss(result->loss);
+    wsp_ggml_build_forward_expand(result->gf, result->loss);
+
+    result->pred = wsp_ggml_argmax(result->ctx_static, result->outputs);
+    wsp_ggml_set_name(result->pred, "pred");
+    wsp_ggml_set_output(result->pred);
+    wsp_ggml_build_forward_expand(result->gf, result->pred);
+
+    if (result->labels) {
+        result->ncorrect = wsp_ggml_count_equal(result->ctx_static, result->pred, wsp_ggml_argmax(result->ctx_static, result->labels));
+        wsp_ggml_set_name(result->ncorrect, "ncorrect");
+        wsp_ggml_set_output(result->ncorrect);
+        wsp_ggml_build_forward_expand(result->gf, result->ncorrect);
+    } else {
+        result->ncorrect = nullptr;
+    }
+
+    if (params.build_type == WSP_GGML_OPT_BUILD_TYPE_FORWARD) {
+        result->buf_static = wsp_ggml_backend_alloc_ctx_tensors(result->ctx_static, wsp_ggml_backend_sched_get_backend(result->backend_sched, 0));
+        return result;
+    }
+
+    // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
+    result->gb_grad = wsp_ggml_graph_dup(result->ctx_compute, result->gf);
+    wsp_ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate);
+
+    if (params.build_type == WSP_GGML_OPT_BUILD_TYPE_GRAD) {
+        result->buf_static = wsp_ggml_backend_alloc_ctx_tensors(result->ctx_static, wsp_ggml_backend_sched_get_backend(result->backend_sched, 0));
+        wsp_ggml_graph_reset(result->gb_grad);
+        return result;
+    }
+
+    WSP_GGML_ASSERT(params.build_type == WSP_GGML_OPT_BUILD_TYPE_OPT);
+
+    // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
+    result->gb_opt = wsp_ggml_graph_dup(result->ctx_compute, result->gb_grad);
+
+    result->adamw_params = wsp_ggml_new_tensor_1d(result->ctx_static_cpu, WSP_GGML_TYPE_F32, 7);
+    wsp_ggml_set_input(result->adamw_params);
+    wsp_ggml_set_name(result->adamw_params, "adamw_params");
+
+    for (int i = result->gf->n_nodes-1; i >= 0; --i) {
+        struct wsp_ggml_tensor * node = result->gb_opt->nodes[i];
+        struct wsp_ggml_tensor * grad = wsp_ggml_graph_get_grad(result->gb_opt, node);
+
+        if (node->flags & WSP_GGML_TENSOR_FLAG_PARAM) {
+            struct wsp_ggml_tensor * m = wsp_ggml_dup_tensor(result->ctx_static, node);
+            struct wsp_ggml_tensor * v = wsp_ggml_dup_tensor(result->ctx_static, node);
+            struct wsp_ggml_tensor * opt_step = wsp_ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params);
+            wsp_ggml_build_forward_expand(result->gb_opt, opt_step);
+        }
+    }
+
+    result->buf_static = wsp_ggml_backend_alloc_ctx_tensors(
+        result->ctx_static, wsp_ggml_backend_sched_get_backend(result->backend_sched, 0));
+
+    result->buf_static_cpu = wsp_ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, wsp_ggml_backend_cpu_buffer_type());
+
+    wsp_ggml_graph_reset(result->gb_opt);
+
+    return result;
+}
+
+void wsp_ggml_opt_free(wsp_ggml_opt_context_t opt_ctx) {
+    if (opt_ctx == nullptr) {
+        return;
+    }
+    wsp_ggml_backend_buffer_free(opt_ctx->buf_static);
+    wsp_ggml_backend_buffer_free(opt_ctx->buf_static_cpu);
+    wsp_ggml_free(opt_ctx->ctx_static);
+    wsp_ggml_free(opt_ctx->ctx_static_cpu);
+    delete opt_ctx;
+}
+
+void wsp_ggml_opt_reset(wsp_ggml_opt_context_t opt_ctx, bool optimizer) {
+    if (optimizer) {
+        wsp_ggml_graph_reset(opt_ctx->gb_opt);
+        opt_ctx->iter = 1;
+    } else {
+        wsp_ggml_graph_reset(opt_ctx->gb_grad);
+    }
+}
+
+struct wsp_ggml_tensor * wsp_ggml_opt_inputs(wsp_ggml_opt_context_t opt_ctx) {
+    return opt_ctx->inputs;
+}
+
+struct wsp_ggml_tensor * wsp_ggml_opt_outputs(wsp_ggml_opt_context_t opt_ctx) {
+    return opt_ctx->outputs;
+}
+
+struct wsp_ggml_tensor * wsp_ggml_opt_labels(wsp_ggml_opt_context_t opt_ctx) {
+    return opt_ctx->labels;
+}
+
+struct wsp_ggml_tensor * wsp_ggml_opt_loss(wsp_ggml_opt_context_t opt_ctx) {
+    return opt_ctx->loss;
+}
+
+struct wsp_ggml_tensor * wsp_ggml_opt_pred(wsp_ggml_opt_context_t opt_ctx) {
+    return opt_ctx->pred;
+}
+
+struct wsp_ggml_tensor * wsp_ggml_opt_ncorrect(wsp_ggml_opt_context_t opt_ctx) {
+    return opt_ctx->ncorrect;
+}
+
+struct wsp_ggml_tensor * wsp_ggml_opt_grad_acc(wsp_ggml_opt_context_t opt_ctx, struct wsp_ggml_tensor * node) {
+    return wsp_ggml_graph_get_grad_acc(opt_ctx->gb_opt, node);
+}
+
+// ====== Optimization Result ======
+
+wsp_ggml_opt_result_t wsp_ggml_opt_result_init() {
+    return new wsp_ggml_opt_result;
+}
+
+void wsp_ggml_opt_result_free(wsp_ggml_opt_result_t result) {
+    delete result;
+}
+
+void wsp_ggml_opt_result_reset(wsp_ggml_opt_result_t result) {
+    result->ndata = 0;
+    result->loss.clear();
+    result->pred.clear();
+    result->ncorrect = 0;
+}
+
+void wsp_ggml_opt_result_ndata(wsp_ggml_opt_result_t result, int64_t * ndata) {
+    *ndata = result->ndata;
+}
+
+void wsp_ggml_opt_result_loss(wsp_ggml_opt_result_t result, double * loss, double * unc) {
+    const int64_t nbatches = result->loss.size(); // Number of physical batches.
+
+    if (nbatches == 0) {
+        *loss = 0.0;
+        *unc = NAN;
+        return;
+    }
+
+    double sum = 0.0;
+    double sum_squared = 0.0;
+
+    for (const float & loss : result->loss) {
+        // If the loss is per datapoint it was scaled by 1.0f/opt_period for each physical batch.
+        const float loss_scaled = result->loss_per_datapoint ? loss*result->opt_period : loss;
+        sum += loss_scaled;
+        sum_squared += loss_scaled*loss_scaled;
+    }
+
+    const double mean = sum/nbatches;
+    *loss = result->loss_per_datapoint ? mean : sum;
+
+    if (!unc) {
+        return;
+    }
+
+    if (nbatches < 2) {
+        *unc = NAN;
+        return;
+    }
+
+    const double var_sum = sum_squared/nbatches - mean*mean; // variance without Bessel's correction, i.e. nbatches/(nbatches-1)
+    *unc = result->loss_per_datapoint ? sqrt(var_sum / (nbatches - 1)) : sqrt(var_sum * nbatches/(nbatches - 1));
+}
+
+void wsp_ggml_opt_result_pred(wsp_ggml_opt_result_t result, int32_t * pred) {
+    for (size_t i = 0; i < result->pred.size(); ++i) {
+        pred[i] = result->pred[i];
+    }
+}
+
+void wsp_ggml_opt_result_accuracy(wsp_ggml_opt_result_t result, double * accuracy, double * unc) {
+    *accuracy = result->ncorrect >= 0 ? double(result->ncorrect) / double(result->ndata) : NAN;
+
+    if (!unc) {
+        return;
+    }
+
+    *unc = result->ncorrect >= 0 && result->ndata >= 2 ?
+        sqrt((*accuracy) * (1.0 - (*accuracy)) / double(result->ndata - 1)) : NAN;
+}
+
+// ====== Computation ======
+
+static void wsp_ggml_opt_eval_graph(wsp_ggml_opt_context_t opt_ctx, wsp_ggml_cgraph * graph, wsp_ggml_opt_result * result) {
+    if (graph != opt_ctx->gf) {
+        struct wsp_ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
+
+        WSP_GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
+        WSP_GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
+        WSP_GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
+        WSP_GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
+        WSP_GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
+        WSP_GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
+        WSP_GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
+        WSP_GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
+
+        // beta1, beta2 after applying warmup
+        const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
+        const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
+
+        float * adamw_par_data = wsp_ggml_get_data_f32(opt_ctx->adamw_params);
+        adamw_par_data[0] = opt_pars.adamw.alpha;
+        adamw_par_data[1] = opt_pars.adamw.beta1;
+        adamw_par_data[2] = opt_pars.adamw.beta2;
+        adamw_par_data[3] = opt_pars.adamw.eps;
+        adamw_par_data[4] = opt_pars.adamw.wd;
+        adamw_par_data[5] = beta1h;
+        adamw_par_data[6] = beta2h;
+    }
+
+    wsp_ggml_opt_alloc_graph(opt_ctx, graph);
+    wsp_ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
+    opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt;
+
+    if (!result) {
+        return;
+    }
+
+    if (result->ndata == 0) {
+        result->loss_per_datapoint = opt_ctx->loss_per_datapoint;
+        result->opt_period = opt_ctx->opt_period;
+    } else {
+        WSP_GGML_ASSERT(result->loss_per_datapoint == opt_ctx->loss_per_datapoint);
+        WSP_GGML_ASSERT(result->opt_period == opt_ctx->opt_period);
+    }
+
+    const int64_t ndata = opt_ctx->outputs->ne[1];
+    WSP_GGML_ASSERT(result->ndata == ndata*int64_t(result->loss.size()) && "varying batch size not supported");
+    result->ndata += ndata;
+
+    WSP_GGML_ASSERT(wsp_ggml_is_scalar(opt_ctx->loss));
+    WSP_GGML_ASSERT(opt_ctx->loss->type == WSP_GGML_TYPE_F32);
+    float loss;
+    wsp_ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, wsp_ggml_nbytes(opt_ctx->loss));
+    result->loss.push_back(loss);
+
+    WSP_GGML_ASSERT(opt_ctx->pred->type == WSP_GGML_TYPE_I32);
+    std::vector<int32_t> pred(ndata);
+    wsp_ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, wsp_ggml_nbytes(opt_ctx->pred));
+    result->pred.insert(result->pred.end(), pred.begin(), pred.end());
+
+    if (!opt_ctx->labels || result->ncorrect < 0) {
+        result->ncorrect = -1;
+        return;
+    }
+
+    WSP_GGML_ASSERT(wsp_ggml_is_scalar(opt_ctx->ncorrect));
+    WSP_GGML_ASSERT(opt_ctx->ncorrect->type == WSP_GGML_TYPE_I64);
+    int64_t ncorrect;
+    wsp_ggml_backend_tensor_get(opt_ctx->ncorrect, &ncorrect, 0, wsp_ggml_nbytes(opt_ctx->ncorrect));
+    result->ncorrect += ncorrect;
+}
+
+void wsp_ggml_opt_forward(wsp_ggml_opt_context_t opt_ctx, wsp_ggml_opt_result * result) {
+    wsp_ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result);
+}
+
+void wsp_ggml_opt_forward_backward(wsp_ggml_opt_context_t opt_ctx, wsp_ggml_opt_result * result) {
+    if (opt_ctx->opt_period == 1) {
+        wsp_ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
+        return;
+    }
+
+    const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
+    if (opt_i_next == 0) {
+        wsp_ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
+        wsp_ggml_opt_reset(opt_ctx, /*optimizer =*/ false);
+    } else {
+        wsp_ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result);
+    }
+    opt_ctx->opt_i = opt_i_next;
+}
+
+// ====== High-Level Functions ======
+
+void wsp_ggml_opt_epoch(
+        wsp_ggml_opt_context_t opt_ctx,
+        wsp_ggml_opt_dataset_t dataset,
+        wsp_ggml_opt_result_t result_train,
+        wsp_ggml_opt_result_t result_eval,
+        int64_t idata_split,
+        wsp_ggml_opt_epoch_callback callback_train,
+        wsp_ggml_opt_epoch_callback callback_eval) {
+    struct wsp_ggml_tensor * inputs = wsp_ggml_opt_inputs(opt_ctx);
+    struct wsp_ggml_tensor * labels = wsp_ggml_opt_labels(opt_ctx);
+    struct wsp_ggml_tensor * data = wsp_ggml_opt_dataset_data(dataset);
+    WSP_GGML_ASSERT(data->ne[0] == inputs->ne[0]);
+
+    const int64_t ndata = data->ne[1];
+    const int64_t ndata_batch = inputs->ne[1];
+
+    WSP_GGML_ASSERT(data->ne[1] % inputs->ne[1] == 0);
+    const int64_t nbatches = ndata/ndata_batch;
+
+    idata_split = idata_split < 0 ? ndata : idata_split;
+    WSP_GGML_ASSERT(idata_split % ndata_batch == 0);
+    const int64_t ibatch_split = idata_split / ndata_batch;
+
+    int64_t ibatch = 0;
+    int64_t t_loop_start = wsp_ggml_time_us();
+    for (; ibatch < ibatch_split; ++ibatch) {
+        wsp_ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
+        wsp_ggml_opt_forward_backward(opt_ctx, result_train);
+        if (callback_train) {
+            callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);
+        }
+    }
+    t_loop_start = wsp_ggml_time_us();
+    for (; ibatch < nbatches; ++ibatch) {
+        wsp_ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
+        wsp_ggml_opt_forward(opt_ctx, result_eval);
+        if (callback_eval) {
+            callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);
+        }
+    }
+}
+
+void wsp_ggml_opt_epoch_callback_progress_bar(
+        bool train,
+        wsp_ggml_opt_context_t opt_ctx,
+        wsp_ggml_opt_dataset_t dataset,
+        wsp_ggml_opt_result_t result,
+        int64_t ibatch,
+        int64_t ibatch_max,
+        int64_t t_start_us) {
+    fprintf(stderr, "%s[", train ? "train: " : "val: ");
+
+    constexpr int64_t bar_length = 25;
+    for (int64_t j = 0; j < bar_length; ++j) {
+        const int64_t ibatch_j = ibatch_max * j/bar_length;
+        if (ibatch_j < ibatch) {
+            fprintf(stderr, "=");
+        } else if (ibatch_max * (j - 1)/bar_length < ibatch) {
+            fprintf(stderr, ">");
+        } else {
+            fprintf(stderr, " ");
+        }
+    }
+
+    const int64_t batch_size = wsp_ggml_opt_inputs(opt_ctx)->ne[1];
+    const int64_t idata = ibatch*batch_size;
+    const int64_t idata_max = ibatch_max*batch_size;
+
+    double loss;
+    double loss_unc;
+    wsp_ggml_opt_result_loss(result, &loss, &loss_unc);
+
+    double accuracy;
+    double accuracy_unc;
+    wsp_ggml_opt_result_accuracy(result, &accuracy, &accuracy_unc);
+
+    const int64_t t_ibatch_us = wsp_ggml_time_us() - t_start_us;
+    int64_t t_ibatch_s = t_ibatch_us / 1000000;
+    const int64_t t_ibatch_h = t_ibatch_s / 3600;
+    t_ibatch_s -= t_ibatch_h * 3600;
+    const int64_t t_ibatch_m = t_ibatch_s / 60;
+    t_ibatch_s -= t_ibatch_m * 60;
+
+    const int64_t t_eta_us = t_ibatch_us * (ibatch_max - ibatch)/ibatch;
+    int64_t t_eta_s = t_eta_us / 1000000;
+    const int64_t t_eta_h = t_eta_s / 3600;
+    t_eta_s -= t_eta_h * 3600;
+    const int64_t t_eta_m = t_eta_s / 60;
+    t_eta_s -= t_eta_m * 60;
+
+    fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, "
+            "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r",
+            idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
+            t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
+    if (ibatch == ibatch_max) {
+        fprintf(stderr, "\n");
+    }
+    fflush(stderr);
+
+    WSP_GGML_UNUSED(dataset);
+}
+
+void wsp_ggml_opt_fit(
+        wsp_ggml_backend_sched_t backend_sched,
+        wsp_ggml_context * ctx_compute,
+        wsp_ggml_tensor * inputs,
+        wsp_ggml_tensor * outputs,
+        wsp_ggml_opt_dataset_t dataset,
+        enum wsp_ggml_opt_loss_type loss_type,
+        wsp_ggml_opt_get_optimizer_params get_opt_pars,
+        int64_t nepoch,
+        int64_t nbatch_logical,
+        float val_split,
+        bool silent) {
+    wsp_ggml_time_init();
+    const int64_t t_start_us = wsp_ggml_time_us();
+
+    const int64_t ndata = wsp_ggml_opt_dataset_data(dataset)->ne[1];
+    const int64_t nbatch_physical = inputs->ne[1];
+    WSP_GGML_ASSERT(ndata % nbatch_logical == 0);
+    WSP_GGML_ASSERT(nbatch_logical % nbatch_physical == 0);
+
+    const int64_t opt_period = nbatch_logical / nbatch_physical;
+    const int64_t nbatches_logical = ndata / nbatch_logical;
+
+    WSP_GGML_ASSERT(val_split >= 0.0f);
+    WSP_GGML_ASSERT(val_split < 1.0f);
+    const int64_t ibatch_split = int64_t(((1.0f - val_split) * nbatches_logical)) * opt_period; // train <-> val split index (physical)
+    const int64_t idata_split = ibatch_split * nbatch_physical;
+
+    int64_t epoch = 1;
+
+    wsp_ggml_opt_params params = wsp_ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
+    params.opt_period = opt_period;
+    params.get_opt_pars = get_opt_pars;
+    params.get_opt_pars_ud = &epoch;
+    wsp_ggml_opt_context_t opt_ctx = wsp_ggml_opt_init(params);
+
+    // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
+    if (nbatch_logical < ndata) {
+        wsp_ggml_opt_dataset_shuffle(opt_ctx, dataset, -1); // Shuffle all data (train + validation).
+    }
+
+    wsp_ggml_opt_result_t result_train = wsp_ggml_opt_result_init();
+    wsp_ggml_opt_result_t result_val = wsp_ggml_opt_result_init();
+
+    wsp_ggml_opt_epoch_callback epoch_callback = silent ? nullptr : wsp_ggml_opt_epoch_callback_progress_bar;
+
+    for (; epoch <= nepoch; ++epoch) {
+        if (nbatch_logical < idata_split) {
+            wsp_ggml_opt_dataset_shuffle(opt_ctx, dataset, idata_split);
+        }
+
+        wsp_ggml_opt_result_reset(result_train);
+        wsp_ggml_opt_result_reset(result_val);
+
+        if (!silent) {
+            fprintf(stderr, "%s: epoch %04" PRId64 "/%04" PRId64 ":\n", __func__, epoch, nepoch);
+        }
+        wsp_ggml_opt_epoch(opt_ctx, dataset, result_train, result_val, idata_split, epoch_callback, epoch_callback);
+        if (!silent) {
+            fprintf(stderr, "\n");
+        }
+    }
+
+    if (!silent) {
+        int64_t t_total_s = (wsp_ggml_time_us() - t_start_us) / 1000000;
+        const int64_t t_total_h = t_total_s / 3600;
+        t_total_s -= t_total_h * 3600;
+        const int64_t t_total_m = t_total_s / 60;
+        t_total_s -= t_total_m * 60;
+        fprintf(stderr, "%s: training took %02" PRId64 ":%02" PRId64 ":%02" PRId64 "\n", __func__, t_total_h, t_total_m, t_total_s);
+    }
+
+    wsp_ggml_opt_free(opt_ctx);
+    wsp_ggml_opt_result_free(result_train);
+    wsp_ggml_opt_result_free(result_val);
+}
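
Note (not part of the diff): the hunk above vendors ggml's new training module under the wsp_ prefix, exposing a dataset/batching layer, an optimizer context with AdamW (including the bias-corrected beta1h/beta2h warmup factors written into adamw_params each step), and the high-level epoch/fit drivers. The following is a minimal usage sketch of that API, a toy least-squares fit on CPU. It assumes the wsp_-prefixed helpers mirror the upstream ggml API of this vintage (wsp_ggml_backend_cpu_init, the five-argument wsp_ggml_backend_sched_new, the two-argument wsp_ggml_set_param); all tensor sizes and the synthetic data are made up for illustration.

#include <stdlib.h>

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cpu.h"
#include "ggml-opt.h"

int main(void) {
    // Toy dataset: 64 datapoints, 4 features and 1 label each, sharded per datapoint.
    wsp_ggml_opt_dataset_t dataset = wsp_ggml_opt_dataset_init(
        /*ne_datapoint =*/ 4, /*ne_label =*/ 1, /*ndata =*/ 64, /*ndata_shard =*/ 1);

    // Synthetic target: label = sum of the 4 features. The dataset buffer lives in
    // host memory (CPU buffer type), so the tensor data can be written directly.
    float * data   = wsp_ggml_get_data_f32(wsp_ggml_opt_dataset_data(dataset));
    float * labels = wsp_ggml_get_data_f32(wsp_ggml_opt_dataset_labels(dataset));
    for (int i = 0; i < 64; ++i) {
        float sum = 0.0f;
        for (int j = 0; j < 4; ++j) {
            data[i*4 + j] = (float) rand() / (float) RAND_MAX;
            sum += data[i*4 + j];
        }
        labels[i] = sum;
    }

    // Compute context for the model graph; no_alloc because a backend buffer owns the
    // tensor data. The meta size is a generous upper bound for this tiny model plus the
    // three graphs (forward, backward, optimizer) that wsp_ggml_opt_init builds in it.
    struct wsp_ggml_init_params ip = {
        /*.mem_size   =*/ wsp_ggml_tensor_overhead()*128 +
                          3*wsp_ggml_graph_overhead_custom(WSP_GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct wsp_ggml_context * ctx = wsp_ggml_init(ip);

    // One physical batch of 8 datapoints and a trainable 4 -> 1 linear map.
    struct wsp_ggml_tensor * inputs  = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 4, 8);
    struct wsp_ggml_tensor * weights = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 4, 1);
    wsp_ggml_set_param(ctx, weights); // mark as trainable
    struct wsp_ggml_tensor * outputs = wsp_ggml_mul_mat(ctx, weights, inputs); // shape [1, 8]

    // Single CPU backend wrapped in the scheduler that wsp_ggml_opt_fit expects, plus
    // static allocation of the tensors created so far (wsp_ggml_opt_init asserts that
    // inputs->data is non-NULL).
    wsp_ggml_backend_t cpu = wsp_ggml_backend_cpu_init();
    wsp_ggml_backend_sched_t sched = wsp_ggml_backend_sched_new(&cpu, NULL, 1, WSP_GGML_DEFAULT_GRAPH_SIZE, false);
    wsp_ggml_backend_buffer_t buf = wsp_ggml_backend_alloc_ctx_tensors(ctx, cpu);

    // 4 epochs of AdamW with default hyperparameters; logical batch = 16 datapoints, so
    // gradients are accumulated over opt_period = 16/8 = 2 physical batches, and the
    // last 16 datapoints (val_split = 0.25) are held out for validation.
    wsp_ggml_opt_fit(sched, ctx, inputs, outputs, dataset,
        WSP_GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR,
        wsp_ggml_opt_get_default_optimizer_params,
        /*nepoch =*/ 4, /*nbatch_logical =*/ 16, /*val_split =*/ 0.25f, /*silent =*/ false);

    wsp_ggml_backend_sched_free(sched);
    wsp_ggml_backend_buffer_free(buf);
    wsp_ggml_backend_free(cpu);
    wsp_ggml_opt_dataset_free(dataset);
    wsp_ggml_free(ctx);
    return 0;
}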