@reicek/neataptic-ts 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (272)
  1. package/.github/ISSUE_TEMPLATE/bug_report.md +33 -0
  2. package/.github/ISSUE_TEMPLATE/feature_request.md +27 -0
  3. package/.github/PULL_REQUEST_TEMPLATE.md +28 -0
  4. package/.github/workflows/ci.yml +41 -0
  5. package/.github/workflows/deploy-pages.yml +29 -0
  6. package/.github/workflows/manual_release_pipeline.yml +62 -0
  7. package/.github/workflows/publish.yml +85 -0
  8. package/.github/workflows/release_dispatch.yml +38 -0
  9. package/.travis.yml +5 -0
  10. package/CONTRIBUTING.md +92 -0
  11. package/LICENSE +24 -0
  12. package/ONNX_EXPORT.md +87 -0
  13. package/README.md +1173 -0
  14. package/RELEASE.md +54 -0
  15. package/dist-docs/package.json +1 -0
  16. package/dist-docs/scripts/generate-docs.d.ts +2 -0
  17. package/dist-docs/scripts/generate-docs.d.ts.map +1 -0
  18. package/dist-docs/scripts/generate-docs.js +536 -0
  19. package/dist-docs/scripts/generate-docs.js.map +1 -0
  20. package/dist-docs/scripts/render-docs-html.d.ts +2 -0
  21. package/dist-docs/scripts/render-docs-html.d.ts.map +1 -0
  22. package/dist-docs/scripts/render-docs-html.js +148 -0
  23. package/dist-docs/scripts/render-docs-html.js.map +1 -0
  24. package/docs/FOLDERS.md +14 -0
  25. package/docs/README.md +1173 -0
  26. package/docs/architecture/README.md +1391 -0
  27. package/docs/architecture/index.html +938 -0
  28. package/docs/architecture/network/README.md +1210 -0
  29. package/docs/architecture/network/index.html +908 -0
  30. package/docs/assets/ascii-maze.bundle.js +16542 -0
  31. package/docs/assets/ascii-maze.bundle.js.map +7 -0
  32. package/docs/index.html +1419 -0
  33. package/docs/methods/README.md +670 -0
  34. package/docs/methods/index.html +477 -0
  35. package/docs/multithreading/README.md +274 -0
  36. package/docs/multithreading/index.html +215 -0
  37. package/docs/multithreading/workers/README.md +23 -0
  38. package/docs/multithreading/workers/browser/README.md +39 -0
  39. package/docs/multithreading/workers/browser/index.html +70 -0
  40. package/docs/multithreading/workers/index.html +57 -0
  41. package/docs/multithreading/workers/node/README.md +33 -0
  42. package/docs/multithreading/workers/node/index.html +66 -0
  43. package/docs/neat/README.md +1284 -0
  44. package/docs/neat/index.html +906 -0
  45. package/docs/src/README.md +2659 -0
  46. package/docs/src/index.html +1579 -0
  47. package/jest.config.ts +32 -0
  48. package/package.json +99 -0
  49. package/plans/HyperMorphoNEAT.md +293 -0
  50. package/plans/ONNX_EXPORT_PLAN.md +46 -0
  51. package/scripts/generate-docs.ts +486 -0
  52. package/scripts/render-docs-html.ts +138 -0
  53. package/scripts/types.d.ts +2 -0
  54. package/src/README.md +2659 -0
  55. package/src/architecture/README.md +1391 -0
  56. package/src/architecture/activationArrayPool.ts +135 -0
  57. package/src/architecture/architect.ts +635 -0
  58. package/src/architecture/connection.ts +148 -0
  59. package/src/architecture/group.ts +406 -0
  60. package/src/architecture/layer.ts +804 -0
  61. package/src/architecture/network/README.md +1210 -0
  62. package/src/architecture/network/network.activate.ts +223 -0
  63. package/src/architecture/network/network.connect.ts +157 -0
  64. package/src/architecture/network/network.deterministic.ts +167 -0
  65. package/src/architecture/network/network.evolve.ts +426 -0
  66. package/src/architecture/network/network.gating.ts +186 -0
  67. package/src/architecture/network/network.genetic.ts +247 -0
  68. package/src/architecture/network/network.mutate.ts +624 -0
  69. package/src/architecture/network/network.onnx.ts +463 -0
  70. package/src/architecture/network/network.prune.ts +216 -0
  71. package/src/architecture/network/network.remove.ts +96 -0
  72. package/src/architecture/network/network.serialize.ts +309 -0
  73. package/src/architecture/network/network.slab.ts +262 -0
  74. package/src/architecture/network/network.standalone.ts +246 -0
  75. package/src/architecture/network/network.stats.ts +59 -0
  76. package/src/architecture/network/network.topology.ts +86 -0
  77. package/src/architecture/network/network.training.ts +1278 -0
  78. package/src/architecture/network.ts +1302 -0
  79. package/src/architecture/node.ts +1288 -0
  80. package/src/architecture/onnx.ts +3 -0
  81. package/src/config.ts +83 -0
  82. package/src/methods/README.md +670 -0
  83. package/src/methods/activation.ts +372 -0
  84. package/src/methods/connection.ts +31 -0
  85. package/src/methods/cost.ts +347 -0
  86. package/src/methods/crossover.ts +63 -0
  87. package/src/methods/gating.ts +43 -0
  88. package/src/methods/methods.ts +8 -0
  89. package/src/methods/mutation.ts +300 -0
  90. package/src/methods/rate.ts +257 -0
  91. package/src/methods/selection.ts +65 -0
  92. package/src/multithreading/README.md +274 -0
  93. package/src/multithreading/multi.ts +339 -0
  94. package/src/multithreading/workers/README.md +23 -0
  95. package/src/multithreading/workers/browser/README.md +39 -0
  96. package/src/multithreading/workers/browser/testworker.ts +99 -0
  97. package/src/multithreading/workers/node/README.md +33 -0
  98. package/src/multithreading/workers/node/testworker.ts +72 -0
  99. package/src/multithreading/workers/node/worker.ts +70 -0
  100. package/src/multithreading/workers/workers.ts +22 -0
  101. package/src/neat/README.md +1284 -0
  102. package/src/neat/neat.adaptive.ts +544 -0
  103. package/src/neat/neat.compat.ts +164 -0
  104. package/src/neat/neat.constants.ts +20 -0
  105. package/src/neat/neat.diversity.ts +217 -0
  106. package/src/neat/neat.evaluate.ts +328 -0
  107. package/src/neat/neat.evolve.ts +1026 -0
  108. package/src/neat/neat.export.ts +249 -0
  109. package/src/neat/neat.helpers.ts +235 -0
  110. package/src/neat/neat.lineage.ts +220 -0
  111. package/src/neat/neat.multiobjective.ts +260 -0
  112. package/src/neat/neat.mutation.ts +718 -0
  113. package/src/neat/neat.objectives.ts +157 -0
  114. package/src/neat/neat.pruning.ts +190 -0
  115. package/src/neat/neat.selection.ts +269 -0
  116. package/src/neat/neat.speciation.ts +460 -0
  117. package/src/neat/neat.species.ts +151 -0
  118. package/src/neat/neat.telemetry.exports.ts +469 -0
  119. package/src/neat/neat.telemetry.ts +933 -0
  120. package/src/neat/neat.types.ts +275 -0
  121. package/src/neat.ts +1042 -0
  122. package/src/neataptic.ts +10 -0
  123. package/test/architecture/activationArrayPool.capacity.test.ts +19 -0
  124. package/test/architecture/activationArrayPool.test.ts +46 -0
  125. package/test/architecture/connection.test.ts +290 -0
  126. package/test/architecture/group.test.ts +950 -0
  127. package/test/architecture/layer.test.ts +1535 -0
  128. package/test/architecture/network.pruning.test.ts +65 -0
  129. package/test/architecture/node.test.ts +1602 -0
  130. package/test/examples/asciiMaze/asciiMaze.e2e.test.ts +499 -0
  131. package/test/examples/asciiMaze/asciiMaze.ts +41 -0
  132. package/test/examples/asciiMaze/browser-entry.ts +164 -0
  133. package/test/examples/asciiMaze/browserLogger.ts +221 -0
  134. package/test/examples/asciiMaze/browserTerminalUtility.ts +48 -0
  135. package/test/examples/asciiMaze/colors.ts +119 -0
  136. package/test/examples/asciiMaze/dashboardManager.ts +968 -0
  137. package/test/examples/asciiMaze/evolutionEngine.ts +1248 -0
  138. package/test/examples/asciiMaze/fitness.ts +136 -0
  139. package/test/examples/asciiMaze/index.html +128 -0
  140. package/test/examples/asciiMaze/index.ts +26 -0
  141. package/test/examples/asciiMaze/interfaces.ts +235 -0
  142. package/test/examples/asciiMaze/mazeMovement.ts +996 -0
  143. package/test/examples/asciiMaze/mazeUtils.ts +278 -0
  144. package/test/examples/asciiMaze/mazeVision.ts +402 -0
  145. package/test/examples/asciiMaze/mazeVisualization.ts +585 -0
  146. package/test/examples/asciiMaze/mazes.ts +245 -0
  147. package/test/examples/asciiMaze/networkRefinement.ts +76 -0
  148. package/test/examples/asciiMaze/networkVisualization.ts +901 -0
  149. package/test/examples/asciiMaze/terminalUtility.ts +73 -0
  150. package/test/methods/activation.test.ts +1142 -0
  151. package/test/methods/connection.test.ts +146 -0
  152. package/test/methods/cost.test.ts +1123 -0
  153. package/test/methods/crossover.test.ts +202 -0
  154. package/test/methods/gating.test.ts +144 -0
  155. package/test/methods/mutation.test.ts +451 -0
  156. package/test/methods/optimizers.advanced.test.ts +80 -0
  157. package/test/methods/optimizers.behavior.test.ts +105 -0
  158. package/test/methods/optimizers.formula.test.ts +89 -0
  159. package/test/methods/rate.cosineWarmRestarts.test.ts +44 -0
  160. package/test/methods/rate.linearWarmupDecay.test.ts +41 -0
  161. package/test/methods/rate.reduceOnPlateau.test.ts +45 -0
  162. package/test/methods/rate.test.ts +684 -0
  163. package/test/methods/selection.test.ts +245 -0
  164. package/test/multithreading/activations.functions.test.ts +54 -0
  165. package/test/multithreading/multi.test.ts +290 -0
  166. package/test/multithreading/worker.node.process.test.ts +39 -0
  167. package/test/multithreading/workers.coverage.test.ts +36 -0
  168. package/test/multithreading/workers.dynamic.import.test.ts +8 -0
  169. package/test/neat/neat.adaptive.complexityBudget.test.ts +34 -0
  170. package/test/neat/neat.adaptive.criterion.complexity.test.ts +50 -0
  171. package/test/neat/neat.adaptive.mutation.strategy.test.ts +37 -0
  172. package/test/neat/neat.adaptive.operator.decay.test.ts +31 -0
  173. package/test/neat/neat.adaptive.phasedComplexity.test.ts +25 -0
  174. package/test/neat/neat.adaptive.pruning.test.ts +25 -0
  175. package/test/neat/neat.adaptive.targetSpecies.test.ts +43 -0
  176. package/test/neat/neat.additional.coverage.test.ts +126 -0
  177. package/test/neat/neat.advanced.enhancements.test.ts +85 -0
  178. package/test/neat/neat.advanced.test.ts +589 -0
  179. package/test/neat/neat.diversity.autocompat.test.ts +47 -0
  180. package/test/neat/neat.diversity.metrics.test.ts +21 -0
  181. package/test/neat/neat.diversity.stats.test.ts +44 -0
  182. package/test/neat/neat.enhancements.test.ts +79 -0
  183. package/test/neat/neat.entropy.ancestorAdaptive.test.ts +133 -0
  184. package/test/neat/neat.entropy.compat.csv.test.ts +108 -0
  185. package/test/neat/neat.evolution.pruning.test.ts +39 -0
  186. package/test/neat/neat.fastmode.autotune.test.ts +42 -0
  187. package/test/neat/neat.innovation.test.ts +134 -0
  188. package/test/neat/neat.lineage.antibreeding.test.ts +35 -0
  189. package/test/neat/neat.lineage.entropy.test.ts +56 -0
  190. package/test/neat/neat.lineage.inbreeding.test.ts +49 -0
  191. package/test/neat/neat.lineage.pressure.test.ts +29 -0
  192. package/test/neat/neat.multiobjective.adaptive.test.ts +57 -0
  193. package/test/neat/neat.multiobjective.dynamic.schedule.test.ts +46 -0
  194. package/test/neat/neat.multiobjective.dynamic.test.ts +31 -0
  195. package/test/neat/neat.multiobjective.fastsort.delegation.test.ts +51 -0
  196. package/test/neat/neat.multiobjective.prune.test.ts +39 -0
  197. package/test/neat/neat.multiobjective.test.ts +21 -0
  198. package/test/neat/neat.mutation.undefined.pool.test.ts +24 -0
  199. package/test/neat/neat.objective.events.test.ts +26 -0
  200. package/test/neat/neat.objective.importance.test.ts +21 -0
  201. package/test/neat/neat.objective.lifetimes.test.ts +33 -0
  202. package/test/neat/neat.offspring.allocation.test.ts +22 -0
  203. package/test/neat/neat.operator.bandit.test.ts +17 -0
  204. package/test/neat/neat.operator.phases.test.ts +38 -0
  205. package/test/neat/neat.pruneInactive.behavior.test.ts +54 -0
  206. package/test/neat/neat.reenable.adaptation.test.ts +18 -0
  207. package/test/neat/neat.rng.state.test.ts +22 -0
  208. package/test/neat/neat.spawn.add.test.ts +123 -0
  209. package/test/neat/neat.speciation.test.ts +96 -0
  210. package/test/neat/neat.species.allocation.telemetry.test.ts +26 -0
  211. package/test/neat/neat.species.history.csv.test.ts +24 -0
  212. package/test/neat/neat.telemetry.advanced.test.ts +226 -0
  213. package/test/neat/neat.telemetry.csv.lineage.test.ts +19 -0
  214. package/test/neat/neat.telemetry.parity.test.ts +42 -0
  215. package/test/neat/neat.telemetry.stream.test.ts +19 -0
  216. package/test/neat/neat.telemetry.test.ts +16 -0
  217. package/test/neat/neat.test.ts +422 -0
  218. package/test/neat/neat.utilities.test.ts +44 -0
  219. package/test/network/__suppress_console.ts +9 -0
  220. package/test/network/acyclic.topoorder.test.ts +17 -0
  221. package/test/network/checkpoint.metricshook.test.ts +36 -0
  222. package/test/network/error.handling.test.ts +581 -0
  223. package/test/network/evolution.test.ts +285 -0
  224. package/test/network/genetic.test.ts +208 -0
  225. package/test/network/learning.capability.test.ts +244 -0
  226. package/test/network/mutation.effects.test.ts +492 -0
  227. package/test/network/network.activate.test.ts +115 -0
  228. package/test/network/network.activateBatch.test.ts +30 -0
  229. package/test/network/network.deterministic.test.ts +64 -0
  230. package/test/network/network.evolve.branches.test.ts +75 -0
  231. package/test/network/network.evolve.multithread.branches.test.ts +83 -0
  232. package/test/network/network.evolve.test.ts +100 -0
  233. package/test/network/network.gating.removal.test.ts +93 -0
  234. package/test/network/network.mutate.additional.test.ts +145 -0
  235. package/test/network/network.mutate.edgecases.test.ts +101 -0
  236. package/test/network/network.mutate.test.ts +101 -0
  237. package/test/network/network.prune.earlyexit.test.ts +38 -0
  238. package/test/network/network.remove.errors.test.ts +45 -0
  239. package/test/network/network.slab.fallbacks.test.ts +22 -0
  240. package/test/network/network.stats.test.ts +45 -0
  241. package/test/network/network.training.advanced.test.ts +149 -0
  242. package/test/network/network.training.basic.test.ts +228 -0
  243. package/test/network/network.training.helpers.test.ts +183 -0
  244. package/test/network/onnx.export.test.ts +310 -0
  245. package/test/network/onnx.import.test.ts +129 -0
  246. package/test/network/pruning.topology.test.ts +282 -0
  247. package/test/network/regularization.determinism.test.ts +83 -0
  248. package/test/network/regularization.dropconnect.test.ts +17 -0
  249. package/test/network/regularization.dropconnect.validation.test.ts +18 -0
  250. package/test/network/regularization.stochasticdepth.test.ts +27 -0
  251. package/test/network/regularization.test.ts +843 -0
  252. package/test/network/regularization.weightnoise.test.ts +30 -0
  253. package/test/network/setupTests.ts +2 -0
  254. package/test/network/standalone.test.ts +332 -0
  255. package/test/network/structure.serialization.test.ts +660 -0
  256. package/test/training/training.determinism.mixed-precision.test.ts +134 -0
  257. package/test/training/training.earlystopping.test.ts +91 -0
  258. package/test/training/training.edge-cases.test.ts +91 -0
  259. package/test/training/training.extensions.test.ts +47 -0
  260. package/test/training/training.gradient.features.test.ts +110 -0
  261. package/test/training/training.gradient.refinements.test.ts +170 -0
  262. package/test/training/training.gradient.separate-bias.test.ts +41 -0
  263. package/test/training/training.optimizer.test.ts +48 -0
  264. package/test/training/training.plateau.smoothing.test.ts +58 -0
  265. package/test/training/training.smoothing.types.test.ts +174 -0
  266. package/test/training/training.train.options.coverage.test.ts +52 -0
  267. package/test/utils/console-helper.ts +76 -0
  268. package/test/utils/jest-setup.ts +60 -0
  269. package/test/utils/test-helpers.ts +175 -0
  270. package/tsconfig.docs.json +12 -0
  271. package/tsconfig.json +21 -0
  272. package/webpack.config.js +49 -0
@@ -0,0 +1,1278 @@
1
+ /**
2
+ * Training pipeline utilities (migrated from legacy architecture/network.train.ts).
3
+ *
4
+ * Provides:
5
+ * - Gradient clipping (global / layerwise; norm / percentile variants).
6
+ * - Mini & micro-batch gradient accumulation.
7
+ * - Optimizer step dispatch (SGD + adaptive optimizers + lookahead wrapper).
8
+ * - Simple mixed precision dynamic loss scaling (overflow detection heuristic).
9
+ * - Multiple moving-average smoothing strategies for error monitoring (SMA, EMA, adaptive EMA,
10
+ * median, gaussian, trimmed mean, WMA) plus separate plateau averaging.
11
+ * - Early stopping, schedule hooks, pruning hooks, and checkpoint callbacks.
12
+ *
13
+ * Notes:
14
+ * - This module intentionally keeps imperative style for clarity/perf (avoids heap churn in hot loops).
15
+ * - Refactor changes here are documentation & naming only; numerical behavior preserved.
16
+ */
17
+ import * as methods from '../../methods/methods';
18
+ import { config } from '../../config';
19
+ import type Network from '../network';
20
+
21
+ /**
22
+ * -----------------------------------------------------------------------------
23
+ * Internal Type Definitions (documentation only; optional for callers)
24
+ * -----------------------------------------------------------------------------
25
+ */
26
/**
 * Cost function signature used by training.
 *
 * @param target Expected output vector for one sample.
 * @param output Actual network output vector for the same sample.
 * @returns Scalar error value (lower is better).
 */
export type CostFunction = (target: number[], output: number[]) => number;
28
+
29
/** Gradient clipping configuration accepted by options.gradientClip. */
export interface GradientClipConfig {
  /** Clipping strategy: global vs. per-layer, by L2 norm vs. absolute-value percentile. */
  mode?: 'norm' | 'percentile' | 'layerwiseNorm' | 'layerwisePercentile';
  /** Max L2 norm (for *Norm modes). */
  maxNorm?: number;
  /** Percentile threshold (0-100) for *Percentile modes (clamps absolute values). */
  percentile?: number;
  /** Whether to treat bias separately (currently informational flag – behavior parity preserved). */
  separateBias?: boolean;
}
39
+
40
/** Dynamic loss-scaling tuning knobs for mixed precision training. */
export interface MixedPrecisionDynamicConfig {
  /** Minimum loss scale when scaling down after overflows. */
  minScale?: number;
  /** Maximum allowed loss scale for automatic increases. */
  maxScale?: number;
  /** Steps of stable (non-overflow) updates before doubling loss scale. */
  increaseEvery?: number; // alias stableStepsForIncrease
  /** Legacy alias: stable steps threshold for increase. */
  stableStepsForIncrease?: number;
}
/** Mixed precision configuration: initial loss scale plus optional dynamic adjustment. */
export interface MixedPrecisionConfig {
  /** Initial loss scale (larger -> more mantissa preservation but higher overflow risk). */
  lossScale?: number;
  /** Enable dynamic (auto increase/decrease) logic. */
  dynamic?: MixedPrecisionDynamicConfig;
}
57
+
58
/** Optimizer configuration (subset – delegated to node.applyBatchUpdatesWithOptimizer). */
export interface OptimizerConfigBase {
  /** Optimizer identifier (e.g. an adaptive optimizer name). */
  type: string; // normalized to lowercase
  /** Inner optimizer type when wrapped (e.g. by lookahead). */
  baseType?: string; // for lookahead
  beta1?: number;
  beta2?: number;
  eps?: number;
  weightDecay?: number;
  momentum?: number;
  la_k?: number; // lookahead sync interval
  la_alpha?: number; // lookahead interpolation factor
}
70
+
71
/** Checkpoint callback spec. */
export interface CheckpointConfig {
  /** Save final state each iteration. */
  last?: boolean;
  /** Save best (lowest error) state. */
  best?: boolean;
  /**
   * Persist function invoked with metadata + serialized network.
   * `type` distinguishes routine ('last') saves from best-so-far ('best') saves.
   */
  save: (payload: {
    type: 'last' | 'best';
    iteration: number;
    error: number;
    network: any;
  }) => void;
}
85
+
86
/** Schedule hook executed every N iterations. */
export interface ScheduleConfig {
  /** Invocation frequency, in iterations. */
  iterations: number; // frequency
  /** Callback receiving the current error and iteration index. */
  function: (info: { error: number; iteration: number }) => void;
}

/** Metrics hook signature: invoked with per-iteration telemetry. */
export type MetricsHook = (m: {
  iteration: number;
  error: number;
  plateauError?: number;
  gradNorm: number;
}) => void;
99
+
100
/**
 * Moving average strategy identifiers for error smoothing:
 * simple, exponential, variance-adaptive exponential, median, gaussian-kernel,
 * trimmed mean, and linearly weighted variants.
 */
export type MovingAverageType =
  | 'sma'
  | 'ema'
  | 'adaptive-ema'
  | 'median'
  | 'gaussian'
  | 'trimmed'
  | 'wma';
109
+
110
/**
 * Primary training options object (public shape).
 * Groups stopping conditions, optimizer/batching controls, smoothing
 * configuration, and lifecycle hooks consumed by the training pipeline.
 */
export interface TrainingOptions {
  iterations?: number; // stopping condition: max passes
  error?: number; // stopping condition: target monitored (smoothed) error
  rate?: number; // base learning rate
  momentum?: number; // momentum for SGD / sometimes consumed by wrappers
  optimizer?: string | OptimizerConfigBase; // adaptive optimizer choice
  dropout?: number; // dropout probability applied per forward (mutable net.dropout)
  batchSize?: number; // mini-batch size; if > dataset length => error
  accumulationSteps?: number; // gradient accumulation factor (micro-batches per optimizer step)
  accumulationReduction?: 'average' | 'sum'; // scaling mode for accumulated gradients
  gradientClip?: GradientClipConfig; // gradient clipping configuration
  mixedPrecision?: boolean | MixedPrecisionConfig; // enable FP16-like scaling logic
  cost?: CostFunction | { fn?: CostFunction; calculate?: CostFunction }; // cost interface variants
  movingAverageWindow?: number; // smoothing window size
  movingAverageType?: MovingAverageType; // smoothing algorithm
  emaAlpha?: number; // override alpha for EMA
  adaptiveEmaBaseAlpha?: number; // (not currently used – placeholder)
  trimmedRatio?: number; // fraction dropped from each tail for trimmed mean (0..0.49)
  plateauMovingAverageWindow?: number; // independent plateau window
  plateauMovingAverageType?: MovingAverageType; // independent plateau strategy
  plateauEmaAlpha?: number; // plateau EMA alpha override
  earlyStopPatience?: number; // iterations with no improvement before stop
  earlyStopMinDelta?: number; // required improvement beyond previous best
  checkpoint?: CheckpointConfig; // persistence callbacks
  schedule?: ScheduleConfig; // periodic hook
  metricsHook?: MetricsHook; // telemetry per iteration
}
138
+
139
+ /** ---------------------------------------------------------------------------
140
+ * Internal Helper Utilities (non-exported)
141
+ * ---------------------------------------------------------------------------
142
+ * These functions encapsulate cohesive sub-steps of the training pipeline so the
143
+ * main exported functions remain readable while preserving original behavior.
144
+ * Each helper is intentionally pure where reasonable or documents its side-effects.
145
+ */
146
+
147
/** State container for EMA / Adaptive EMA smoothing values. */
interface PrimarySmoothingState {
  /** Classic EMA value (when movingAverageType === 'ema'). */
  emaValue?: number;
  /** Baseline EMA part of adaptive EMA (slower). */
  adaptiveBaseEmaValue?: number;
  /** Fast adaptive EMA (higher alpha under variance). */
  adaptiveEmaValue?: number;
}

/** State container for plateau EMA smoothing. */
interface PlateauSmoothingState {
  /** Running EMA of the plateau metric (seeded on first use). */
  plateauEmaValue?: number;
}

/** Configuration passed to monitored (primary) smoothing computation. */
interface MonitoredSmoothingConfig {
  /** Smoothing algorithm to apply. */
  type: MovingAverageType;
  /** Window length (number of recent errors considered). */
  window: number;
  emaAlpha?: number; // optional override (only for EMA types)
  trimmedRatio?: number; // for trimmed mean strategy
}

/** Configuration for plateau smoothing computation. */
interface PlateauSmoothingConfig {
  /** Smoothing algorithm; only 'median' and 'ema' have dedicated paths, others fall back to mean. */
  type: MovingAverageType;
  /** Window length for the plateau metric. */
  window: number;
  /** EMA alpha (only consumed when type === 'ema'). */
  emaAlpha?: number;
}
176
+
177
+ /**
178
+ * Compute the monitored (primary) smoothed error given recent raw errors.
179
+ *
180
+ * Behavior:
181
+ * - For SMA-like strategies uses the supplied window slice directly.
182
+ * - For EMA it mutates state.emaValue.
183
+ * - For adaptive-ema maintains dual EMA tracks inside state and returns the min for stability.
184
+ * - For median / gaussian / trimmed / wma applies algorithmic weighting as documented inline.
185
+ *
186
+ * Inputs:
187
+ * - trainError: Current raw mean error for this iteration.
188
+ * - recentErrors: Chronological array (oldest->newest) of last N raw errors.
189
+ * - cfg: Algorithm selection + parameters.
190
+ * - state: Mutable smoothing state (ema / adaptive fields updated in-place).
191
+ *
192
+ * Returns: Smoothed/monitored error metric (may equal trainError if no smoothing active).
193
+ */
194
+ function computeMonitoredError(
195
+ trainError: number,
196
+ recentErrors: number[],
197
+ cfg: MonitoredSmoothingConfig,
198
+ state: PrimarySmoothingState
199
+ ): number {
200
+ // Fast path: no smoothing window / algorithm requiring history.
201
+ if (cfg.window <= 1 && cfg.type !== 'ema' && cfg.type !== 'adaptive-ema') {
202
+ return trainError;
203
+ }
204
+ const type = cfg.type;
205
+ if (type === 'median') {
206
+ const sorted = [...recentErrors].sort((a, b) => a - b);
207
+ const midIndex = Math.floor(sorted.length / 2);
208
+ return sorted.length % 2
209
+ ? sorted[midIndex]
210
+ : (sorted[midIndex - 1] + sorted[midIndex]) / 2;
211
+ }
212
+ if (type === 'ema') {
213
+ // Standard exponential moving average.
214
+ if (state.emaValue == null) state.emaValue = trainError;
215
+ else
216
+ state.emaValue =
217
+ state.emaValue + cfg.emaAlpha! * (trainError - state.emaValue);
218
+ return state.emaValue;
219
+ }
220
+ if (type === 'adaptive-ema') {
221
+ // Adaptive EMA: baseline alpha + volatility-inflated alpha, final metric is more conservative (min).
222
+ const mean = recentErrors.reduce((a, b) => a + b, 0) / recentErrors.length;
223
+ const variance =
224
+ recentErrors.reduce((a, b) => a + (b - mean) * (b - mean), 0) /
225
+ recentErrors.length;
226
+ const baseAlpha = cfg.emaAlpha || 2 / (cfg.window + 1);
227
+ const varianceScaled = variance / Math.max(mean * mean, 1e-8);
228
+ const adaptiveAlpha = Math.min(
229
+ 0.95,
230
+ Math.max(baseAlpha, baseAlpha * (1 + 2 * varianceScaled))
231
+ );
232
+ if (state.adaptiveBaseEmaValue == null) {
233
+ state.adaptiveBaseEmaValue = trainError;
234
+ state.adaptiveEmaValue = trainError;
235
+ } else {
236
+ state.adaptiveBaseEmaValue =
237
+ state.adaptiveBaseEmaValue +
238
+ baseAlpha * (trainError - state.adaptiveBaseEmaValue);
239
+ state.adaptiveEmaValue =
240
+ state.adaptiveEmaValue! +
241
+ adaptiveAlpha * (trainError - state.adaptiveEmaValue!);
242
+ }
243
+ return Math.min(state.adaptiveEmaValue!, state.adaptiveBaseEmaValue!);
244
+ }
245
+ if (type === 'gaussian') {
246
+ // Gaussian kernel weights centered at newest element (index length-1).
247
+ const sigma = cfg.window / 3 || 1; // heuristic: cover window ~3 sigma
248
+ let weightSum = 0;
249
+ let weightedAccumulator = 0;
250
+ const length = recentErrors.length;
251
+ for (let i = 0; i < length; i++) {
252
+ const weight = Math.exp(-0.5 * Math.pow((i - (length - 1)) / sigma, 2));
253
+ weightSum += weight;
254
+ weightedAccumulator += weight * recentErrors[i];
255
+ }
256
+ return weightedAccumulator / (weightSum || 1);
257
+ }
258
+ if (type === 'trimmed') {
259
+ // Trim symmetric tails before averaging to reduce outlier influence.
260
+ const ratio = Math.min(0.49, Math.max(0, cfg.trimmedRatio || 0.1));
261
+ const sorted = [...recentErrors].sort((a, b) => a - b);
262
+ const drop = Math.floor(sorted.length * ratio);
263
+ const trimmed = sorted.slice(drop, sorted.length - drop);
264
+ return trimmed.reduce((a, b) => a + b, 0) / (trimmed.length || 1);
265
+ }
266
+ if (type === 'wma') {
267
+ // Linear weighting (oldest weight=1 ... newest weight=n).
268
+ let weightSum = 0;
269
+ let weightedAccumulator = 0;
270
+ for (let i = 0; i < recentErrors.length; i++) {
271
+ const weight = i + 1;
272
+ weightSum += weight;
273
+ weightedAccumulator += weight * recentErrors[i];
274
+ }
275
+ return weightedAccumulator / (weightSum || 1);
276
+ }
277
+ // Default: arithmetic mean (SMA).
278
+ return recentErrors.reduce((a, b) => a + b, 0) / recentErrors.length;
279
+ }
280
+
281
+ /**
282
+ * Compute plateau metric (may differ in strategy from primary monitored error).
283
+ * Only algorithms actually supported for plateau in current pipeline are SMA, median and EMA.
284
+ * Provided flexibility keeps room for extension; unsupported types silently fallback to mean.
285
+ */
286
+ function computePlateauMetric(
287
+ trainError: number,
288
+ plateauErrors: number[],
289
+ cfg: PlateauSmoothingConfig,
290
+ state: PlateauSmoothingState
291
+ ): number {
292
+ if (cfg.window <= 1 && cfg.type !== 'ema') return trainError;
293
+ if (cfg.type === 'median') {
294
+ const sorted = [...plateauErrors].sort((a, b) => a - b);
295
+ const mid = Math.floor(sorted.length / 2);
296
+ return sorted.length % 2
297
+ ? sorted[mid]
298
+ : (sorted[mid - 1] + sorted[mid]) / 2;
299
+ }
300
+ if (cfg.type === 'ema') {
301
+ if (state.plateauEmaValue == null) state.plateauEmaValue = trainError;
302
+ else
303
+ state.plateauEmaValue =
304
+ state.plateauEmaValue +
305
+ cfg.emaAlpha! * (trainError - state.plateauEmaValue);
306
+ return state.plateauEmaValue;
307
+ }
308
+ // Fallback default mean.
309
+ return plateauErrors.reduce((a, b) => a + b, 0) / plateauErrors.length;
310
+ }
311
+
312
// Internal export bundle (test-only usage) to enable direct branch coverage of
// the smoothing helpers. The double-underscore prefix signals that this is not
// part of the public API and must not be relied upon in production code.
export const __trainingInternals = {
  computeMonitoredError,
  computePlateauMetric,
};
318
+
319
+ /**
320
+ * Detect mixed precision overflow (NaN / Inf) in bias values if mixed precision enabled.
321
+ * Side-effect: may clear internal trigger _forceNextOverflow.
322
+ */
323
+ function detectMixedPrecisionOverflow(net: Network, internalNet: any): boolean {
324
+ if (!internalNet._mixedPrecision.enabled) return false;
325
+ if (internalNet._forceNextOverflow) {
326
+ internalNet._forceNextOverflow = false;
327
+ return true;
328
+ }
329
+ let overflow = false;
330
+ net.nodes.forEach((node) => {
331
+ if ((node as any)._fp32Bias !== undefined) {
332
+ if (!Number.isFinite((node as any).bias)) overflow = true;
333
+ }
334
+ });
335
+ return overflow;
336
+ }
337
+
338
+ /** Zero-out accumulated gradient buffers after an overflow to discard invalid updates. */
339
+ function zeroAccumulatedGradients(net: Network) {
340
+ net.nodes.forEach((node) => {
341
+ (node as any).connections.in.forEach((c: any) => {
342
+ c.totalDeltaWeight = 0;
343
+ });
344
+ (node as any).connections.self.forEach((c: any) => {
345
+ c.totalDeltaWeight = 0;
346
+ });
347
+ if (typeof (node as any).totalDeltaBias === 'number')
348
+ (node as any).totalDeltaBias = 0;
349
+ (node as any).previousDeltaBias = 0;
350
+ });
351
+ }
352
+
353
+ /** Divide accumulated gradients by accumulationSteps (average reduction mode). */
354
+ function averageAccumulatedGradients(net: Network, accumulationSteps: number) {
355
+ if (accumulationSteps <= 1) return;
356
+ net.nodes.forEach((node) => {
357
+ (node as any).connections.in.forEach((c: any) => {
358
+ if (typeof c.totalDeltaWeight === 'number')
359
+ c.totalDeltaWeight /= accumulationSteps;
360
+ });
361
+ (node as any).connections.self.forEach((c: any) => {
362
+ if (typeof c.totalDeltaWeight === 'number')
363
+ c.totalDeltaWeight /= accumulationSteps;
364
+ });
365
+ if (typeof (node as any).totalDeltaBias === 'number')
366
+ (node as any).totalDeltaBias /= accumulationSteps;
367
+ });
368
+ }
369
+
370
+ /** Apply optimizer update step across all nodes; returns gradient L2 norm (approx). */
371
+ function applyOptimizerStep(
372
+ net: Network,
373
+ optimizer: any,
374
+ currentRate: number,
375
+ momentum: number,
376
+ internalNet: any
377
+ ): number {
378
+ let sumSq = 0;
379
+ net.nodes.forEach((node) => {
380
+ if (node.type === 'input') return;
381
+ (node as any).applyBatchUpdatesWithOptimizer({
382
+ type: optimizer.type,
383
+ baseType: optimizer.baseType,
384
+ beta1: optimizer.beta1,
385
+ beta2: optimizer.beta2,
386
+ eps: optimizer.eps,
387
+ weightDecay: optimizer.weightDecay,
388
+ momentum: optimizer.momentum ?? momentum,
389
+ lrScale: currentRate,
390
+ t: internalNet._optimizerStep,
391
+ la_k: optimizer.la_k,
392
+ la_alpha: optimizer.la_alpha,
393
+ });
394
+ (node as any).connections.in.forEach((c: any) => {
395
+ if (typeof c.previousDeltaWeight === 'number')
396
+ sumSq += c.previousDeltaWeight * c.previousDeltaWeight;
397
+ });
398
+ (node as any).connections.self.forEach((c: any) => {
399
+ if (typeof c.previousDeltaWeight === 'number')
400
+ sumSq += c.previousDeltaWeight * c.previousDeltaWeight;
401
+ });
402
+ });
403
+ return Math.sqrt(sumSq);
404
+ }
405
+
406
+ /** Update dynamic loss scaling after a successful (non-overflow) optimizer step. */
407
+ function maybeIncreaseLossScale(internalNet: any) {
408
+ internalNet._mixedPrecisionState.goodSteps++;
409
+ const incEvery = internalNet._mpIncreaseEvery || 200;
410
+ if (
411
+ internalNet._mixedPrecisionState.goodSteps >= incEvery &&
412
+ internalNet._mixedPrecision.lossScale <
413
+ internalNet._mixedPrecisionState.maxLossScale
414
+ ) {
415
+ internalNet._mixedPrecision.lossScale *= 2;
416
+ internalNet._mixedPrecisionState.goodSteps = 0;
417
+ internalNet._mixedPrecisionState.scaleUpEvents =
418
+ (internalNet._mixedPrecisionState.scaleUpEvents || 0) + 1;
419
+ }
420
+ }
421
+
422
+ /** Respond to a mixed precision overflow by shrinking loss scale & bookkeeping. */
423
+ function handleOverflow(internalNet: any) {
424
+ internalNet._mixedPrecisionState.badSteps++;
425
+ internalNet._mixedPrecisionState.goodSteps = 0;
426
+ internalNet._mixedPrecision.lossScale = Math.max(
427
+ internalNet._mixedPrecisionState.minLossScale,
428
+ Math.floor(internalNet._mixedPrecision.lossScale / 2) || 1
429
+ );
430
+ internalNet._mixedPrecisionState.overflowCount =
431
+ (internalNet._mixedPrecisionState.overflowCount || 0) + 1;
432
+ internalNet._mixedPrecisionState.scaleDownEvents =
433
+ (internalNet._mixedPrecisionState.scaleDownEvents || 0) + 1;
434
+ internalNet._lastOverflowStep = internalNet._optimizerStep;
435
+ }
436
+
437
+ /**
438
+ * Apply gradient clipping to accumulated connection deltas / bias deltas.
439
+ *
440
+ * Modes:
441
+ * - norm / layerwiseNorm: L2 norm scaling (global vs per group).
442
+ * - percentile / layerwisePercentile: element-wise clamp at absolute percentile threshold.
443
+ *
444
+ * Grouping:
445
+ * - If layerwise* and net.layers exists -> each defined layer is a group.
446
+ * - Else if layerwise* -> each non-input node becomes its own group.
447
+ * - Otherwise a single global group containing all learnable params.
448
+ */
449
+ export function applyGradientClippingImpl(
450
+ net: Network,
451
+ cfg: {
452
+ mode: 'norm' | 'percentile' | 'layerwiseNorm' | 'layerwisePercentile';
453
+ maxNorm?: number;
454
+ percentile?: number;
455
+ }
456
+ ) {
457
+ const internalNet = net as any;
458
+ /**
459
+ * Build arrays of gradient values grouped according to chosen clipping mode.
460
+ * Each group is later processed independently (layerwise modes) or as a single global set.
461
+ */
462
+ const collectGroups = () => {
463
+ const collected: number[][] = [];
464
+ if (cfg.mode.startsWith('layerwise')) {
465
+ if ((net as any).layers && (net as any).layers.length > 0) {
466
+ for (let li = 0; li < (net as any).layers.length; li++) {
467
+ const layer = (net as any).layers[li];
468
+ if (!layer || !layer.nodes) continue;
469
+ const groupVals: number[] = [];
470
+ layer.nodes.forEach((node: any) => {
471
+ if (!node || node.type === 'input') return;
472
+ node.connections.in.forEach((c: any) => {
473
+ if (typeof c.totalDeltaWeight === 'number')
474
+ groupVals.push(c.totalDeltaWeight);
475
+ });
476
+ node.connections.self.forEach((c: any) => {
477
+ if (typeof c.totalDeltaWeight === 'number')
478
+ groupVals.push(c.totalDeltaWeight);
479
+ });
480
+ if (typeof node.totalDeltaBias === 'number')
481
+ groupVals.push(node.totalDeltaBias);
482
+ });
483
+ if (groupVals.length) collected.push(groupVals);
484
+ }
485
+ } else {
486
+ net.nodes.forEach((node) => {
487
+ if (node.type === 'input') return;
488
+ const groupVals: number[] = [];
489
+ (node as any).connections.in.forEach((c: any) => {
490
+ if (typeof c.totalDeltaWeight === 'number')
491
+ groupVals.push(c.totalDeltaWeight);
492
+ });
493
+ (node as any).connections.self.forEach((c: any) => {
494
+ if (typeof c.totalDeltaWeight === 'number')
495
+ groupVals.push(c.totalDeltaWeight);
496
+ });
497
+ if (typeof (node as any).totalDeltaBias === 'number')
498
+ groupVals.push((node as any).totalDeltaBias);
499
+ if (groupVals.length) collected.push(groupVals);
500
+ });
501
+ }
502
+ } else {
503
+ const globalVals: number[] = [];
504
+ net.nodes.forEach((node) => {
505
+ (node as any).connections.in.forEach((c: any) => {
506
+ if (typeof c.totalDeltaWeight === 'number')
507
+ globalVals.push(c.totalDeltaWeight);
508
+ });
509
+ (node as any).connections.self.forEach((c: any) => {
510
+ if (typeof c.totalDeltaWeight === 'number')
511
+ globalVals.push(c.totalDeltaWeight);
512
+ });
513
+ if (typeof (node as any).totalDeltaBias === 'number')
514
+ globalVals.push((node as any).totalDeltaBias);
515
+ });
516
+ if (globalVals.length) collected.push(globalVals);
517
+ }
518
+ return collected;
519
+ };
520
+ /**
521
+ * Gradient groups discovered for clipping (size: 1 for global modes).
522
+ * Each entry is an array of parameter delta values belonging to a logical group (layer or node level).
523
+ */
524
+ const groups = collectGroups();
525
+ /** Tracking for diagnostics / potential external tooling. */
526
+ internalNet._lastGradClipGroupCount = groups.length;
527
+ /**
528
+ * Compute absolute percentile threshold (e.g. percentile=99 => value whose |value| is at the 99th percentile).
529
+ * Sorting by absolute value guarantees consistent clipping for symmetric distributions.
530
+ */
531
+ const computeAbsolutePercentileThreshold = (
532
+ values: number[],
533
+ percentile: number
534
+ ) => {
535
+ if (!values.length) return 0;
536
+ const sortedByAbs = [...values].sort((a, b) => Math.abs(a) - Math.abs(b));
537
+ const rank = Math.min(
538
+ sortedByAbs.length - 1,
539
+ Math.max(0, Math.floor((percentile / 100) * sortedByAbs.length - 1))
540
+ );
541
+ return Math.abs(sortedByAbs[rank]);
542
+ };
543
+ /**
544
+ * Iterate all learnable parameters applying a transform function.
545
+ * The transform receives the current value and the owning group so it can selectively scale only
546
+ * the active group (when computing per-group scaling factor yet iterating entire model).
547
+ */
548
+ const applyScale = (
549
+ scaleFn: (currentValue: number, owningGroup: number[]) => number
550
+ ) => {
551
+ let groupIndex = 0; // advances only for layerwise modes
552
+ net.nodes.forEach((node) => {
553
+ if (cfg.mode.startsWith('layerwise') && node.type === 'input') return; // skip input nodes in layerwise grouping
554
+ const activeGroup = cfg.mode.startsWith('layerwise')
555
+ ? groups[groupIndex++]
556
+ : groups[0];
557
+ (node as any).connections.in.forEach((c: any) => {
558
+ if (typeof c.totalDeltaWeight === 'number')
559
+ c.totalDeltaWeight = scaleFn(c.totalDeltaWeight, activeGroup);
560
+ });
561
+ (node as any).connections.self.forEach((c: any) => {
562
+ if (typeof c.totalDeltaWeight === 'number')
563
+ c.totalDeltaWeight = scaleFn(c.totalDeltaWeight, activeGroup);
564
+ });
565
+ if (typeof (node as any).totalDeltaBias === 'number')
566
+ (node as any).totalDeltaBias = scaleFn(
567
+ (node as any).totalDeltaBias,
568
+ activeGroup
569
+ );
570
+ });
571
+ };
572
+ if (cfg.mode === 'norm' || cfg.mode === 'layerwiseNorm') {
573
+ /** Maximum allowed L2 norm per group (or global). */
574
+ const maxAllowedNorm = cfg.maxNorm || 1;
575
+ groups.forEach((groupValues) => {
576
+ /** Current group L2 norm. */
577
+ const groupL2Norm = Math.sqrt(
578
+ groupValues.reduce((sum, v) => sum + v * v, 0)
579
+ );
580
+ if (groupL2Norm > maxAllowedNorm && groupL2Norm > 0) {
581
+ /** Scaling factor applied uniformly to bring norm to boundary. */
582
+ const normScaleFactor = maxAllowedNorm / groupL2Norm;
583
+ applyScale((currentValue, owningGroup) =>
584
+ owningGroup === groupValues
585
+ ? currentValue * normScaleFactor
586
+ : currentValue
587
+ );
588
+ }
589
+ });
590
+ } else if (cfg.mode === 'percentile' || cfg.mode === 'layerwisePercentile') {
591
+ /** Percentile specifying absolute magnitude cutoff (values above are clamped). */
592
+ const percentileSetting = cfg.percentile || 99;
593
+ groups.forEach((groupValues) => {
594
+ const percentileThreshold = computeAbsolutePercentileThreshold(
595
+ groupValues,
596
+ percentileSetting
597
+ );
598
+ if (percentileThreshold <= 0) return;
599
+ applyScale((currentValue, owningGroup) =>
600
+ owningGroup === groupValues &&
601
+ Math.abs(currentValue) > percentileThreshold
602
+ ? percentileThreshold * Math.sign(currentValue)
603
+ : currentValue
604
+ );
605
+ });
606
+ }
607
+ }
608
+
609
/**
 * Execute one full pass over dataset (epoch) with optional accumulation & adaptive optimizer.
 * Returns mean cost across processed samples.
 *
 * @param net Network being trained; internal training state is read/written via `net as any`.
 * @param set Training samples; entries with mismatched input/output dimensions are skipped (with a warning when `config.warnings`).
 * @param batchSize Samples per mini-batch flush.
 * @param accumulationSteps Micro-batches accumulated before an adaptive optimizer step.
 * @param currentRate Learning rate forwarded to propagation / optimizer.
 * @param momentum Momentum factor forwarded to propagation / optimizer.
 * @param regularization Regularization settings forwarded to node propagation.
 * @param costFunction Cost variant: a plain function, or an object exposing `fn` or `calculate`.
 * @param optimizer Optional optimizer config; absent or type 'sgd' means propagate updates weights per sample.
 * @returns Mean cost over all dimension-valid samples (0 when none were processed).
 */
export function trainSetImpl(
  net: Network,
  set: { input: number[]; output: number[] }[],
  batchSize: number,
  accumulationSteps: number,
  currentRate: number,
  momentum: number,
  regularization: any,
  costFunction: (target: number[], output: number[]) => number,
  optimizer?: any
): number {
  const internalNet = net as any;
  /** Sum of raw (unsmoothed) cost values across valid samples. */
  let cumulativeError = 0;
  /** Number of samples processed in current mini-batch (resets after potential optimizer step). */
  let batchSampleCount = 0;
  /** Counter of micro-batches contributing to current accumulated gradient set. */
  internalNet._gradAccumMicroBatches = 0;
  /** Total number of dataset samples actually processed (dimension-valid). */
  let totalProcessedSamples = 0;
  /** Cached list of output layer nodes (backprop order requires targets). */
  const outputNodes = net.nodes.filter((n) => n.type === 'output');
  /** Unified cost evaluation function resolved from provided cost variant. */
  let computeError: (t: number[], o: number[]) => number;
  // Accept a bare function, an object with `fn`, or an object with `calculate`;
  // anything else degrades to a constant-zero cost (training still runs).
  if (typeof costFunction === 'function') computeError = costFunction as any;
  else if (
    (costFunction as any) &&
    typeof (costFunction as any).fn === 'function'
  )
    computeError = (costFunction as any).fn;
  else if (
    (costFunction as any) &&
    typeof (costFunction as any).calculate === 'function'
  )
    computeError = (costFunction as any).calculate;
  else computeError = () => 0;

  for (let sampleIndex = 0; sampleIndex < set.length; sampleIndex++) {
    /** Current training sample record (input + target). */
    const dataPoint = set[sampleIndex];
    /** Input feature vector (validated for dimension). */
    const input = dataPoint.input;
    /** Target output vector (validated for dimension). */
    const target = dataPoint.output;
    // Dimension guard: skip malformed samples instead of aborting the epoch.
    if (input.length !== net.input || target.length !== net.output) {
      if (config.warnings)
        console.warn(
          `Data point ${sampleIndex} has incorrect dimensions (input: ${input.length}/${net.input}, output: ${target.length}/${net.output}), skipping.`
        );
      continue;
    }
    try {
      // Forward pass with training flag (enables dropout / any stochastic layers).
      const output = (net as any).activate(input, true);
      if (optimizer && optimizer.type && optimizer.type !== 'sgd') {
        // Accumulate gradients for adaptive optimizers (no immediate weight update inside propagate).
        // Output nodes go first (they consume targets), then the remaining
        // non-input nodes in reverse topological order.
        for (let outIndex = 0; outIndex < outputNodes.length; outIndex++)
          (outputNodes[outIndex] as any).propagate(
            currentRate,
            momentum,
            false,
            regularization,
            target[outIndex]
          );
        for (
          let reverseIndex = net.nodes.length - 1;
          reverseIndex >= 0;
          reverseIndex--
        ) {
          const node = net.nodes[reverseIndex];
          if (node.type === 'output' || node.type === 'input') continue;
          (node as any).propagate(currentRate, momentum, false, regularization);
        }
      } else {
        // SGD mode: propagate performs immediate parameter updates using deltas
        // (third argument `true` = update now).
        for (let outIndex = 0; outIndex < outputNodes.length; outIndex++)
          (outputNodes[outIndex] as any).propagate(
            currentRate,
            momentum,
            true,
            regularization,
            target[outIndex]
          );
        for (
          let reverseIndex = net.nodes.length - 1;
          reverseIndex >= 0;
          reverseIndex--
        ) {
          const node = net.nodes[reverseIndex];
          if (node.type === 'output' || node.type === 'input') continue;
          (node as any).propagate(currentRate, momentum, true, regularization);
        }
      }
      cumulativeError += computeError(target, output);
      batchSampleCount++;
      totalProcessedSamples++;
    } catch (e: any) {
      // Best-effort: a sample that throws is skipped; training continues.
      if (config.warnings)
        console.warn(
          `Error processing data point ${sampleIndex} (input: ${JSON.stringify(
            input
          )}): ${e.message}. Skipping.`
        );
    }
    // Mini-batch / end-of-dataset flush condition.
    if (
      batchSampleCount > 0 &&
      ((sampleIndex + 1) % batchSize === 0 || sampleIndex === set.length - 1)
    ) {
      if (optimizer && optimizer.type && optimizer.type !== 'sgd') {
        // Only adaptive optimizers delay the step; vanilla SGD already updated weights per sample.
        internalNet._gradAccumMicroBatches++;
        /** True when we have accumulated sufficient micro-batches or reached dataset end. */
        const readyForStep =
          internalNet._gradAccumMicroBatches % accumulationSteps === 0 ||
          sampleIndex === set.length - 1;
        if (readyForStep) {
          /** 1-based optimizer step counter (used for bias-correction terms by adaptive methods). */
          internalNet._optimizerStep = (internalNet._optimizerStep || 0) + 1;
          /** Detect overflow under mixed precision (NaN/Inf). */
          const overflowDetected = detectMixedPrecisionOverflow(
            net,
            internalNet
          );
          if (overflowDetected) {
            // Discard invalid gradients & shrink loss scale.
            zeroAccumulatedGradients(net);
            if (internalNet._mixedPrecision.enabled)
              handleOverflow(internalNet);
            internalNet._lastGradNorm = 0;
          } else {
            // Optional gradient clipping before optimizer math.
            if (internalNet._currentGradClip)
              applyGradientClippingImpl(net, internalNet._currentGradClip);
            // Average accumulated micro-batch gradients if configured.
            if (
              accumulationSteps > 1 &&
              internalNet._accumulationReduction === 'average'
            ) {
              averageAccumulatedGradients(net, accumulationSteps);
            }
            // Apply optimizer updates and compute gradient norm.
            internalNet._lastGradNorm = applyOptimizerStep(
              net,
              optimizer,
              currentRate,
              momentum,
              internalNet
            );
            // Dynamic loss scaling increase if conditions satisfied.
            if (internalNet._mixedPrecision.enabled)
              maybeIncreaseLossScale(internalNet);
          }
        }
        // NOTE(review): this reset runs only on the adaptive-optimizer path; for
        // plain SGD batchSampleCount grows across the whole epoch. Harmless as
        // written, since only (sampleIndex + 1) % batchSize gates the flush —
        // confirm the placement is intentional.
        batchSampleCount = 0; // reset mini-batch sample counter
      }
    }
  }
  // Ensure the diagnostic gradient norm is always a number after the epoch.
  if (internalNet._lastGradNorm == null) internalNet._lastGradNorm = 0;
  return totalProcessedSamples > 0
    ? cumulativeError / totalProcessedSamples
    : 0;
}
776
+
777
+ /**
778
+ * High-level training orchestration with early stopping, smoothing & callbacks.
779
+ */
780
+ export function trainImpl(
781
+ net: Network,
782
+ set: { input: number[]; output: number[] }[],
783
+ options: TrainingOptions
784
+ ): { error: number; iterations: number; time: number } {
785
+ const internalNet = net as any;
786
+ if (
787
+ !set ||
788
+ set.length === 0 ||
789
+ set[0].input.length !== net.input ||
790
+ set[0].output.length !== net.output
791
+ ) {
792
+ throw new Error(
793
+ 'Dataset is invalid or dimensions do not match network input/output size!'
794
+ );
795
+ }
796
+ options = options || {};
797
+ if (
798
+ typeof options.iterations === 'undefined' &&
799
+ typeof options.error === 'undefined'
800
+ ) {
801
+ if (config.warnings)
802
+ console.warn('Missing `iterations` or `error` option.');
803
+ throw new Error(
804
+ 'Missing `iterations` or `error` option. Training requires a stopping condition.'
805
+ );
806
+ }
807
+ if (config.warnings) {
808
+ if (typeof options.rate === 'undefined') {
809
+ console.warn('Missing `rate` option');
810
+ console.warn('Missing `rate` option, using default learning rate 0.3.');
811
+ }
812
+ if (typeof options.iterations === 'undefined')
813
+ console.warn(
814
+ 'Missing `iterations` option. Training will run potentially indefinitely until `error` threshold is met.'
815
+ );
816
+ }
817
+ /** Target monitored (smoothed) error threshold for early termination. */
818
+ let targetError = options.error ?? -Infinity;
819
+ /** Cost function (defaults to MSE) resolved from provided variant. */
820
+ const cost = options.cost || methods.Cost.mse;
821
+ if (
822
+ typeof cost !== 'function' &&
823
+ !(
824
+ typeof cost === 'object' &&
825
+ (typeof (cost as any).fn === 'function' ||
826
+ typeof (cost as any).calculate === 'function')
827
+ )
828
+ ) {
829
+ throw new Error('Invalid cost function provided to Network.train.');
830
+ }
831
+ /** Base learning rate used as scaling factor for optimizer weight updates. */
832
+ const baseRate = options.rate ?? 0.3;
833
+ /** Dropout probability applied each forward pass (0 disables). */
834
+ const dropout = options.dropout || 0;
835
+ if (dropout < 0 || dropout >= 1) throw new Error('dropout must be in [0,1)');
836
+ /** Momentum factor for SGD or reused by optimizers expecting momentum param. */
837
+ const momentum = options.momentum || 0;
838
+ /** Mini-batch size (#samples per gradient accumulation flush). */
839
+ const batchSize = options.batchSize || 1;
840
+ if (batchSize > set.length)
841
+ throw new Error('Batch size cannot be larger than the dataset length.');
842
+ /** Gradient accumulation factor (micro-batches per optimizer step). */
843
+ const accumulationSteps = options.accumulationSteps || 1;
844
+ internalNet._accumulationReduction =
845
+ options.accumulationReduction === 'sum' ? 'sum' : 'average';
846
+ if (accumulationSteps < 1 || !Number.isFinite(accumulationSteps))
847
+ throw new Error('accumulationSteps must be >=1');
848
+ if (options.gradientClip) {
849
+ const gc = options.gradientClip;
850
+ if (gc.mode)
851
+ internalNet._currentGradClip = {
852
+ mode: gc.mode,
853
+ maxNorm: gc.maxNorm,
854
+ percentile: gc.percentile,
855
+ } as any;
856
+ else if (typeof gc.maxNorm === 'number')
857
+ internalNet._currentGradClip = { mode: 'norm', maxNorm: gc.maxNorm };
858
+ else if (typeof gc.percentile === 'number')
859
+ internalNet._currentGradClip = {
860
+ mode: 'percentile',
861
+ percentile: gc.percentile,
862
+ } as any;
863
+ internalNet._gradClipSeparateBias = !!gc.separateBias;
864
+ } else {
865
+ internalNet._currentGradClip = undefined;
866
+ internalNet._gradClipSeparateBias = false;
867
+ }
868
+ if (options.mixedPrecision) {
869
+ const mp =
870
+ options.mixedPrecision === true
871
+ ? { lossScale: 1024 }
872
+ : options.mixedPrecision;
873
+ internalNet._mixedPrecision.enabled = true;
874
+ internalNet._mixedPrecision.lossScale = mp.lossScale || 1024;
875
+ const dyn = mp.dynamic || {};
876
+ internalNet._mixedPrecisionState.minLossScale = dyn.minScale || 1;
877
+ internalNet._mixedPrecisionState.maxLossScale = dyn.maxScale || 65536;
878
+ internalNet._mpIncreaseEvery =
879
+ dyn.increaseEvery || dyn.stableStepsForIncrease || 200;
880
+ net.connections.forEach((c) => {
881
+ (c as any)._fp32Weight = c.weight;
882
+ });
883
+ net.nodes.forEach((n) => {
884
+ if (n.type !== 'input') (n as any)._fp32Bias = n.bias;
885
+ });
886
+ } else {
887
+ internalNet._mixedPrecision.enabled = false;
888
+ internalNet._mixedPrecision.lossScale = 1;
889
+ internalNet._mpIncreaseEvery = 200;
890
+ }
891
+ /** Supported optimizer algorithm identifiers (lowercased). */
892
+ const allowedOptimizers = new Set([
893
+ 'sgd',
894
+ 'rmsprop',
895
+ 'adagrad',
896
+ 'adam',
897
+ 'adamw',
898
+ 'amsgrad',
899
+ 'adamax',
900
+ 'nadam',
901
+ 'radam',
902
+ 'lion',
903
+ 'adabelief',
904
+ 'lookahead',
905
+ ]);
906
+ /** Normalized optimizer configuration or undefined for pure SGD mode. */
907
+ let optimizerConfig: any = undefined;
908
+ if (typeof options.optimizer !== 'undefined') {
909
+ if (typeof options.optimizer === 'string')
910
+ optimizerConfig = { type: options.optimizer.toLowerCase() };
911
+ else if (
912
+ typeof options.optimizer === 'object' &&
913
+ options.optimizer !== null
914
+ ) {
915
+ optimizerConfig = { ...options.optimizer };
916
+ if (typeof optimizerConfig.type === 'string')
917
+ optimizerConfig.type = optimizerConfig.type.toLowerCase();
918
+ } else
919
+ throw new Error('Invalid optimizer option; must be string or object');
920
+ if (!allowedOptimizers.has(optimizerConfig.type))
921
+ throw new Error(`Unknown optimizer type: ${optimizerConfig.type}`);
922
+ if (optimizerConfig.type === 'lookahead') {
923
+ if (!optimizerConfig.baseType) optimizerConfig.baseType = 'adam';
924
+ if (optimizerConfig.baseType === 'lookahead')
925
+ throw new Error(
926
+ 'Nested lookahead (baseType lookahead) is not supported'
927
+ );
928
+ if (!allowedOptimizers.has(optimizerConfig.baseType))
929
+ throw new Error(
930
+ `Unknown baseType for lookahead: ${optimizerConfig.baseType}`
931
+ );
932
+ optimizerConfig.la_k = optimizerConfig.la_k || 5;
933
+ optimizerConfig.la_alpha = optimizerConfig.la_alpha ?? 0.5;
934
+ }
935
+ }
936
+ /** Maximum training iterations permitted (guard against infinite loops w/ only error criterion). */
937
+ const iterations = options.iterations ?? Number.MAX_SAFE_INTEGER;
938
+ /** Wall-clock start time for duration metric. */
939
+ const start = Date.now();
940
+ /** Most recent monitored (smoothed) error value. */
941
+ let finalError = Infinity;
942
+ /** Window length for primary moving average smoothing. */
943
+ const movingAverageWindow = Math.max(1, options.movingAverageWindow || 1);
944
+ /** Selected smoothing algorithm kind. */
945
+ const movingAverageType = options.movingAverageType || 'sma';
946
+ /** EMA alpha (if EMA selected) computed via CMA formula unless explicitly overridden. */
947
+ const emaAlpha = (() => {
948
+ if (movingAverageType !== 'ema') return undefined;
949
+ if (options.emaAlpha && options.emaAlpha > 0 && options.emaAlpha <= 1)
950
+ return options.emaAlpha;
951
+ return 2 / (movingAverageWindow + 1);
952
+ })();
953
+ /** Separate window for plateau detection (defaults to primary window). */
954
+ const plateauWindow = Math.max(
955
+ 1,
956
+ options.plateauMovingAverageWindow || movingAverageWindow
957
+ );
958
+ /** Smoothing algorithm used specifically for plateau (scheduler / early-stop) metrics. */
959
+ const plateauType = options.plateauMovingAverageType || movingAverageType;
960
+ /** EMA alpha for plateau smoothing if needed. */
961
+ const plateauEmaAlpha = (() => {
962
+ if (plateauType !== 'ema') return undefined;
963
+ if (
964
+ options.plateauEmaAlpha &&
965
+ options.plateauEmaAlpha > 0 &&
966
+ options.plateauEmaAlpha <= 1
967
+ )
968
+ return options.plateauEmaAlpha;
969
+ return 2 / (plateauWindow + 1);
970
+ })();
971
+ /** Max consecutive non-improving iterations tolerated before early stop (undefined => disabled). */
972
+ const earlyStopPatience = options.earlyStopPatience;
973
+ /** Minimal decrease required to qualify as improvement. */
974
+ const earlyStopMinDelta = options.earlyStopMinDelta || 0;
975
+ /** Best (lowest) monitored error observed so far. */
976
+ let bestError = Infinity;
977
+ /** Count of successive iterations without sufficient improvement. */
978
+ let noImproveCount = 0;
979
+ /** Capacity of circular buffer for recent errors. */
980
+ const recentErrorsCapacity = movingAverageWindow;
981
+ /** Circular buffer holding recent raw training errors (for smoothing). */
982
+ const recentErrorsBuf: number[] = new Array(recentErrorsCapacity);
983
+ /** Current number of valid entries in buffer (grows until capacity). */
984
+ let recentErrorsCount = 0;
985
+ /** Next write index within circular buffer. */
986
+ let recentErrorsWriteIdx = 0;
987
+ /** Push a new error value into circular buffer (overwriting oldest when full). */
988
+ const recentErrorsPush = (value: number) => {
989
+ if (recentErrorsCapacity === 1) {
990
+ recentErrorsBuf[0] = value;
991
+ recentErrorsCount = 1;
992
+ recentErrorsWriteIdx = 0;
993
+ return;
994
+ }
995
+ recentErrorsBuf[recentErrorsWriteIdx] = value;
996
+ recentErrorsWriteIdx = (recentErrorsWriteIdx + 1) % recentErrorsCapacity;
997
+ if (recentErrorsCount < recentErrorsCapacity) recentErrorsCount++;
998
+ };
999
+ /** Produce chronologically ordered snapshot of buffered errors. */
1000
+ const recentErrorsChrono = (): number[] => {
1001
+ if (recentErrorsCount === 0) return [];
1002
+ if (recentErrorsCount < recentErrorsCapacity)
1003
+ return recentErrorsBuf.slice(0, recentErrorsCount);
1004
+ const out = new Array(recentErrorsCount);
1005
+ const start = recentErrorsWriteIdx;
1006
+ for (let i = 0; i < recentErrorsCount; i++)
1007
+ out[i] = recentErrorsBuf[(start + i) % recentErrorsCapacity];
1008
+ return out;
1009
+ };
1010
+ /** Exponential moving average state for classic EMA smoothing. */
1011
+ let emaValue: number | undefined = undefined;
1012
+ /** Base EMA state for adaptive EMA (lower variance baseline). */
1013
+ let adaptiveBaseEmaValue: number | undefined = undefined;
1014
+ /** Adaptive EMA state (higher alpha when volatility detected). */
1015
+ let adaptiveEmaValue: number | undefined = undefined;
1016
+ /** Capacity of plateau circular buffer. */
1017
+ const plateauCapacity = plateauWindow;
1018
+ /** Raw errors buffer for plateau smoothing. */
1019
+ const plateauBuf: number[] = new Array(plateauCapacity);
1020
+ /** Current number of plateau entries filled. */
1021
+ let plateauCount = 0;
1022
+ /** Next write index for plateau buffer. */
1023
+ let plateauWriteIdx = 0;
1024
+ /** Insert new training error into plateau buffer. */
1025
+ const plateauPush = (value: number) => {
1026
+ if (plateauCapacity === 1) {
1027
+ plateauBuf[0] = value;
1028
+ plateauCount = 1;
1029
+ plateauWriteIdx = 0;
1030
+ return;
1031
+ }
1032
+ plateauBuf[plateauWriteIdx] = value;
1033
+ plateauWriteIdx = (plateauWriteIdx + 1) % plateauCapacity;
1034
+ if (plateauCount < plateauCapacity) plateauCount++;
1035
+ };
1036
+ /** Chronologically ordered plateau buffer snapshot. */
1037
+ const plateauChrono = (): number[] => {
1038
+ if (plateauCount === 0) return [];
1039
+ if (plateauCount < plateauCapacity)
1040
+ return plateauBuf.slice(0, plateauCount);
1041
+ const out = new Array(plateauCount);
1042
+ const start = plateauWriteIdx;
1043
+ for (let i = 0; i < plateauCount; i++)
1044
+ out[i] = plateauBuf[(start + i) % plateauCapacity];
1045
+ return out;
1046
+ };
1047
+ /** Plateau-specific EMA state (if plateauType === 'ema'). */
1048
+ let plateauEmaValue: number | undefined = undefined;
1049
+ /** Mutate network dropout probability for upcoming epoch iterations. */
1050
+ net.dropout = dropout;
1051
+ /** Number of iterations actually executed (in case of early stopping). */
1052
+ let performedIterations = 0;
1053
+ for (let iter = 1; iter <= iterations; iter++) {
1054
+ // -----------------------------
1055
+ // Iteration prologue
1056
+ // -----------------------------
1057
+ // 'iter' is 1-based to align with common optimizer bias-correction formulae (Adam etc.).
1058
+ if ((net as any)._maybePrune) {
1059
+ (net as any)._maybePrune((internalNet._globalEpoch || 0) + iter);
1060
+ }
1061
+ // Run one epoch pass over dataset (mini-batching handled internally) and obtain raw mean error.
1062
+ const trainError = trainSetImpl(
1063
+ net,
1064
+ set,
1065
+ batchSize,
1066
+ accumulationSteps,
1067
+ baseRate,
1068
+ momentum,
1069
+ {},
1070
+ cost as any,
1071
+ optimizerConfig
1072
+ );
1073
+ // Record that this iteration was fully executed (used if we early break afterwards).
1074
+ performedIterations = iter;
1075
+ // Push raw error into smoothing buffer(s) for subsequent moving-average computation.
1076
+ recentErrorsPush(trainError);
1077
+ /** Monitored error value after smoothing strategy is applied (initially raw). */
1078
+ let monitored = trainError;
1079
+ // -----------------------------
1080
+ // Primary moving-average smoothing block
1081
+ // -----------------------------
1082
+ // Conditions: apply if window > 1 or a strategy that inherently disregards window size (ema/adaptive).
1083
+ if (
1084
+ movingAverageWindow > 1 ||
1085
+ movingAverageType === 'ema' ||
1086
+ movingAverageType === 'adaptive-ema'
1087
+ ) {
1088
+ const recentArr = recentErrorsChrono();
1089
+ if (movingAverageType === 'median') {
1090
+ // Robust central tendency; reduces influence of transient spikes.
1091
+ const sorted = [...recentArr].sort((a, b) => a - b);
1092
+ const mid = Math.floor(sorted.length / 2); // middle index
1093
+ monitored =
1094
+ sorted.length % 2 ? sorted[mid] : (sorted[mid - 1] + sorted[mid]) / 2;
1095
+ } else if (movingAverageType === 'ema') {
1096
+ // Classic exponentially weighted moving average (constant alpha).
1097
+ if (emaValue == null) emaValue = trainError;
1098
+ else emaValue = emaValue + emaAlpha! * (trainError - emaValue);
1099
+ monitored = emaValue;
1100
+ } else if (movingAverageType === 'adaptive-ema') {
1101
+ // Dual EMA: baseline + adaptive alpha that expands under variance to speed reaction, then we keep min.
1102
+ const mean = recentArr.reduce((a, b) => a + b, 0) / recentArr.length;
1103
+ const variance =
1104
+ recentArr.reduce((a, b) => a + (b - mean) * (b - mean), 0) /
1105
+ recentArr.length;
1106
+ const baseAlpha = emaAlpha || 2 / (movingAverageWindow + 1);
1107
+ const varScaled = variance / Math.max(mean * mean, 1e-8);
1108
+ const adaptAlpha = Math.min(
1109
+ 0.95,
1110
+ Math.max(baseAlpha, baseAlpha * (1 + 2 * varScaled))
1111
+ );
1112
+ if (adaptiveBaseEmaValue == null) {
1113
+ adaptiveBaseEmaValue = trainError;
1114
+ adaptiveEmaValue = trainError;
1115
+ } else {
1116
+ adaptiveBaseEmaValue =
1117
+ adaptiveBaseEmaValue +
1118
+ baseAlpha * (trainError - adaptiveBaseEmaValue);
1119
+ adaptiveEmaValue =
1120
+ adaptiveEmaValue! + adaptAlpha * (trainError - adaptiveEmaValue!);
1121
+ }
1122
+ monitored = Math.min(adaptiveEmaValue!, adaptiveBaseEmaValue!);
1123
+ } else if (movingAverageType === 'gaussian') {
1124
+ // Weighted by Gaussian kernel centered at newest point; older (earlier) points get progressively less weight.
1125
+ const gaussianWindow = recentArr;
1126
+ const windowLength = gaussianWindow.length;
1127
+ const sigma = movingAverageWindow / 3 || 1; // heuristic: cover window with ~3 sigma
1128
+ let gaussianWeightSum = 0;
1129
+ let gaussianWeightedAccumulator = 0;
1130
+ for (let gi = 0; gi < windowLength; gi++) {
1131
+ const weight = Math.exp(
1132
+ -0.5 * Math.pow((gi - (windowLength - 1)) / sigma, 2)
1133
+ );
1134
+ gaussianWeightSum += weight;
1135
+ gaussianWeightedAccumulator += weight * gaussianWindow[gi];
1136
+ }
1137
+ monitored = gaussianWeightedAccumulator / (gaussianWeightSum || 1);
1138
+ } else if (movingAverageType === 'trimmed') {
1139
+ // Trim symmetrical tails to damp outliers before averaging.
1140
+ const tailTrimRatio = Math.min(
1141
+ 0.49,
1142
+ Math.max(0, options.trimmedRatio || 0.1)
1143
+ );
1144
+ const sorted = [...recentArr].sort((a, b) => a - b);
1145
+ const elementsToDropEachSide = Math.floor(
1146
+ sorted.length * tailTrimRatio
1147
+ );
1148
+ const trimmedSegment = sorted.slice(
1149
+ elementsToDropEachSide,
1150
+ sorted.length - elementsToDropEachSide
1151
+ );
1152
+ monitored =
1153
+ trimmedSegment.reduce((a, b) => a + b, 0) /
1154
+ (trimmedSegment.length || 1);
1155
+ } else if (movingAverageType === 'wma') {
1156
+ // Linear weights: newer samples more influential.
1157
+ let linearWeightSum = 0;
1158
+ let linearWeightedAccumulator = 0;
1159
+ for (let li = 0; li < recentArr.length; li++) {
1160
+ const weight = li + 1; // oldest gets 1, newest gets N
1161
+ linearWeightSum += weight;
1162
+ linearWeightedAccumulator += weight * recentArr[li];
1163
+ }
1164
+ monitored = linearWeightedAccumulator / (linearWeightSum || 1);
1165
+ } else {
1166
+ // Simple arithmetic mean (SMA).
1167
+ monitored = recentArr.reduce((a, b) => a + b, 0) / recentArr.length;
1168
+ }
1169
+ }
1170
+ // Update finalError with the smoothed/selected monitored metric.
1171
+ finalError = monitored;
1172
+ // Store raw trainError (not smoothed) for plateau evaluation buffer.
1173
+ plateauPush(trainError);
1174
+ /** Plateau-smoothed error (could use different smoothing strategy than monitored). */
1175
+ let plateauError: number | undefined = trainError;
1176
+ if (plateauWindow > 1 || plateauType === 'ema') {
1177
+ if (plateauType === 'median') {
1178
+ // Median for plateau stability over variable noise.
1179
+ const sorted = [...plateauChrono()].sort((a, b) => a - b);
1180
+ const mid = Math.floor(sorted.length / 2);
1181
+ plateauError =
1182
+ sorted.length % 2 ? sorted[mid] : (sorted[mid - 1] + sorted[mid]) / 2;
1183
+ } else if (plateauType === 'ema') {
1184
+ // EMA variant for plateau detection (faster adaptation with controlled lag).
1185
+ if (plateauEmaValue == null) plateauEmaValue = trainError;
1186
+ else
1187
+ plateauEmaValue =
1188
+ plateauEmaValue + plateauEmaAlpha! * (trainError - plateauEmaValue);
1189
+ plateauError = plateauEmaValue;
1190
+ } else {
1191
+ // Default plateau = arithmetic mean over plateau window.
1192
+ const arr = plateauChrono();
1193
+ plateauError = arr.reduce((a, b) => a + b, 0) / arr.length;
1194
+ }
1195
+ }
1196
+ if (typeof options.metricsHook === 'function') {
1197
+ try {
1198
+ // User hook for live metrics logging / dashboards / adaptive schedulers.
1199
+ options.metricsHook({
1200
+ iteration: iter,
1201
+ error: finalError,
1202
+ plateauError,
1203
+ gradNorm: internalNet._lastGradNorm ?? 0,
1204
+ });
1205
+ } catch {}
1206
+ }
1207
+ if (options.checkpoint && typeof options.checkpoint.save === 'function') {
1208
+ if (options.checkpoint.last) {
1209
+ try {
1210
+ // Always save most recent network state.
1211
+ options.checkpoint.save({
1212
+ type: 'last',
1213
+ iteration: iter,
1214
+ error: finalError,
1215
+ network: net.toJSON(),
1216
+ });
1217
+ } catch {}
1218
+ }
1219
+ if (options.checkpoint.best) {
1220
+ if (
1221
+ finalError < (net as any)._checkpointBestError ||
1222
+ (net as any)._checkpointBestError == null
1223
+ ) {
1224
+ // New best model discovered under monitored error metric.
1225
+ (net as any)._checkpointBestError = finalError;
1226
+ try {
1227
+ options.checkpoint.save({
1228
+ type: 'best',
1229
+ iteration: iter,
1230
+ error: finalError,
1231
+ network: net.toJSON(),
1232
+ });
1233
+ } catch {}
1234
+ }
1235
+ }
1236
+ }
1237
+ if (
1238
+ options.schedule &&
1239
+ options.schedule.iterations &&
1240
+ iter % options.schedule.iterations === 0
1241
+ ) {
1242
+ try {
1243
+ // Periodic user-defined callback (e.g., adjust LR, print status, inject curriculum changes).
1244
+ options.schedule.function({ error: finalError, iteration: iter });
1245
+ } catch {}
1246
+ }
1247
+ // -----------------------------
1248
+ // Early stopping logic
1249
+ // -----------------------------
1250
+ if (finalError < bestError - earlyStopMinDelta) {
1251
+ // Sufficient improvement: update best and reset stagnation counter.
1252
+ bestError = finalError;
1253
+ noImproveCount = 0;
1254
+ } else if (earlyStopPatience) {
1255
+ // Track consecutive non-improving iterations.
1256
+ noImproveCount++;
1257
+ }
1258
+ // Patience exhaustion: terminate.
1259
+ if (earlyStopPatience && noImproveCount >= earlyStopPatience) break;
1260
+ // Target error reached: terminate.
1261
+ if (finalError <= targetError) break;
1262
+ }
1263
+ net.nodes.forEach((n) => {
1264
+ if (n.type === 'hidden') n.mask = 1;
1265
+ });
1266
+ // Clear dropout for inference after training completes.
1267
+ net.dropout = 0;
1268
+ internalNet._globalEpoch =
1269
+ (internalNet._globalEpoch || 0) + performedIterations;
1270
+ return {
1271
+ /** Final monitored (possibly smoothed) error achieved at termination. */
1272
+ error: finalError,
1273
+ /** Number of iterations actually executed (could be < requested iterations due to early stop). */
1274
+ iterations: performedIterations,
1275
+ /** Wall-clock training duration in milliseconds. */
1276
+ time: Date.now() - start,
1277
+ };
1278
+ }