@reicek/neataptic-ts 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.github/ISSUE_TEMPLATE/bug_report.md +33 -0
- package/.github/ISSUE_TEMPLATE/feature_request.md +27 -0
- package/.github/PULL_REQUEST_TEMPLATE.md +28 -0
- package/.github/workflows/ci.yml +41 -0
- package/.github/workflows/deploy-pages.yml +29 -0
- package/.github/workflows/manual_release_pipeline.yml +62 -0
- package/.github/workflows/publish.yml +85 -0
- package/.github/workflows/release_dispatch.yml +38 -0
- package/.travis.yml +5 -0
- package/CONTRIBUTING.md +92 -0
- package/LICENSE +24 -0
- package/ONNX_EXPORT.md +87 -0
- package/README.md +1173 -0
- package/RELEASE.md +54 -0
- package/dist-docs/package.json +1 -0
- package/dist-docs/scripts/generate-docs.d.ts +2 -0
- package/dist-docs/scripts/generate-docs.d.ts.map +1 -0
- package/dist-docs/scripts/generate-docs.js +536 -0
- package/dist-docs/scripts/generate-docs.js.map +1 -0
- package/dist-docs/scripts/render-docs-html.d.ts +2 -0
- package/dist-docs/scripts/render-docs-html.d.ts.map +1 -0
- package/dist-docs/scripts/render-docs-html.js +148 -0
- package/dist-docs/scripts/render-docs-html.js.map +1 -0
- package/docs/FOLDERS.md +14 -0
- package/docs/README.md +1173 -0
- package/docs/architecture/README.md +1391 -0
- package/docs/architecture/index.html +938 -0
- package/docs/architecture/network/README.md +1210 -0
- package/docs/architecture/network/index.html +908 -0
- package/docs/assets/ascii-maze.bundle.js +16542 -0
- package/docs/assets/ascii-maze.bundle.js.map +7 -0
- package/docs/index.html +1419 -0
- package/docs/methods/README.md +670 -0
- package/docs/methods/index.html +477 -0
- package/docs/multithreading/README.md +274 -0
- package/docs/multithreading/index.html +215 -0
- package/docs/multithreading/workers/README.md +23 -0
- package/docs/multithreading/workers/browser/README.md +39 -0
- package/docs/multithreading/workers/browser/index.html +70 -0
- package/docs/multithreading/workers/index.html +57 -0
- package/docs/multithreading/workers/node/README.md +33 -0
- package/docs/multithreading/workers/node/index.html +66 -0
- package/docs/neat/README.md +1284 -0
- package/docs/neat/index.html +906 -0
- package/docs/src/README.md +2659 -0
- package/docs/src/index.html +1579 -0
- package/jest.config.ts +32 -0
- package/package.json +99 -0
- package/plans/HyperMorphoNEAT.md +293 -0
- package/plans/ONNX_EXPORT_PLAN.md +46 -0
- package/scripts/generate-docs.ts +486 -0
- package/scripts/render-docs-html.ts +138 -0
- package/scripts/types.d.ts +2 -0
- package/src/README.md +2659 -0
- package/src/architecture/README.md +1391 -0
- package/src/architecture/activationArrayPool.ts +135 -0
- package/src/architecture/architect.ts +635 -0
- package/src/architecture/connection.ts +148 -0
- package/src/architecture/group.ts +406 -0
- package/src/architecture/layer.ts +804 -0
- package/src/architecture/network/README.md +1210 -0
- package/src/architecture/network/network.activate.ts +223 -0
- package/src/architecture/network/network.connect.ts +157 -0
- package/src/architecture/network/network.deterministic.ts +167 -0
- package/src/architecture/network/network.evolve.ts +426 -0
- package/src/architecture/network/network.gating.ts +186 -0
- package/src/architecture/network/network.genetic.ts +247 -0
- package/src/architecture/network/network.mutate.ts +624 -0
- package/src/architecture/network/network.onnx.ts +463 -0
- package/src/architecture/network/network.prune.ts +216 -0
- package/src/architecture/network/network.remove.ts +96 -0
- package/src/architecture/network/network.serialize.ts +309 -0
- package/src/architecture/network/network.slab.ts +262 -0
- package/src/architecture/network/network.standalone.ts +246 -0
- package/src/architecture/network/network.stats.ts +59 -0
- package/src/architecture/network/network.topology.ts +86 -0
- package/src/architecture/network/network.training.ts +1278 -0
- package/src/architecture/network.ts +1302 -0
- package/src/architecture/node.ts +1288 -0
- package/src/architecture/onnx.ts +3 -0
- package/src/config.ts +83 -0
- package/src/methods/README.md +670 -0
- package/src/methods/activation.ts +372 -0
- package/src/methods/connection.ts +31 -0
- package/src/methods/cost.ts +347 -0
- package/src/methods/crossover.ts +63 -0
- package/src/methods/gating.ts +43 -0
- package/src/methods/methods.ts +8 -0
- package/src/methods/mutation.ts +300 -0
- package/src/methods/rate.ts +257 -0
- package/src/methods/selection.ts +65 -0
- package/src/multithreading/README.md +274 -0
- package/src/multithreading/multi.ts +339 -0
- package/src/multithreading/workers/README.md +23 -0
- package/src/multithreading/workers/browser/README.md +39 -0
- package/src/multithreading/workers/browser/testworker.ts +99 -0
- package/src/multithreading/workers/node/README.md +33 -0
- package/src/multithreading/workers/node/testworker.ts +72 -0
- package/src/multithreading/workers/node/worker.ts +70 -0
- package/src/multithreading/workers/workers.ts +22 -0
- package/src/neat/README.md +1284 -0
- package/src/neat/neat.adaptive.ts +544 -0
- package/src/neat/neat.compat.ts +164 -0
- package/src/neat/neat.constants.ts +20 -0
- package/src/neat/neat.diversity.ts +217 -0
- package/src/neat/neat.evaluate.ts +328 -0
- package/src/neat/neat.evolve.ts +1026 -0
- package/src/neat/neat.export.ts +249 -0
- package/src/neat/neat.helpers.ts +235 -0
- package/src/neat/neat.lineage.ts +220 -0
- package/src/neat/neat.multiobjective.ts +260 -0
- package/src/neat/neat.mutation.ts +718 -0
- package/src/neat/neat.objectives.ts +157 -0
- package/src/neat/neat.pruning.ts +190 -0
- package/src/neat/neat.selection.ts +269 -0
- package/src/neat/neat.speciation.ts +460 -0
- package/src/neat/neat.species.ts +151 -0
- package/src/neat/neat.telemetry.exports.ts +469 -0
- package/src/neat/neat.telemetry.ts +933 -0
- package/src/neat/neat.types.ts +275 -0
- package/src/neat.ts +1042 -0
- package/src/neataptic.ts +10 -0
- package/test/architecture/activationArrayPool.capacity.test.ts +19 -0
- package/test/architecture/activationArrayPool.test.ts +46 -0
- package/test/architecture/connection.test.ts +290 -0
- package/test/architecture/group.test.ts +950 -0
- package/test/architecture/layer.test.ts +1535 -0
- package/test/architecture/network.pruning.test.ts +65 -0
- package/test/architecture/node.test.ts +1602 -0
- package/test/examples/asciiMaze/asciiMaze.e2e.test.ts +499 -0
- package/test/examples/asciiMaze/asciiMaze.ts +41 -0
- package/test/examples/asciiMaze/browser-entry.ts +164 -0
- package/test/examples/asciiMaze/browserLogger.ts +221 -0
- package/test/examples/asciiMaze/browserTerminalUtility.ts +48 -0
- package/test/examples/asciiMaze/colors.ts +119 -0
- package/test/examples/asciiMaze/dashboardManager.ts +968 -0
- package/test/examples/asciiMaze/evolutionEngine.ts +1248 -0
- package/test/examples/asciiMaze/fitness.ts +136 -0
- package/test/examples/asciiMaze/index.html +128 -0
- package/test/examples/asciiMaze/index.ts +26 -0
- package/test/examples/asciiMaze/interfaces.ts +235 -0
- package/test/examples/asciiMaze/mazeMovement.ts +996 -0
- package/test/examples/asciiMaze/mazeUtils.ts +278 -0
- package/test/examples/asciiMaze/mazeVision.ts +402 -0
- package/test/examples/asciiMaze/mazeVisualization.ts +585 -0
- package/test/examples/asciiMaze/mazes.ts +245 -0
- package/test/examples/asciiMaze/networkRefinement.ts +76 -0
- package/test/examples/asciiMaze/networkVisualization.ts +901 -0
- package/test/examples/asciiMaze/terminalUtility.ts +73 -0
- package/test/methods/activation.test.ts +1142 -0
- package/test/methods/connection.test.ts +146 -0
- package/test/methods/cost.test.ts +1123 -0
- package/test/methods/crossover.test.ts +202 -0
- package/test/methods/gating.test.ts +144 -0
- package/test/methods/mutation.test.ts +451 -0
- package/test/methods/optimizers.advanced.test.ts +80 -0
- package/test/methods/optimizers.behavior.test.ts +105 -0
- package/test/methods/optimizers.formula.test.ts +89 -0
- package/test/methods/rate.cosineWarmRestarts.test.ts +44 -0
- package/test/methods/rate.linearWarmupDecay.test.ts +41 -0
- package/test/methods/rate.reduceOnPlateau.test.ts +45 -0
- package/test/methods/rate.test.ts +684 -0
- package/test/methods/selection.test.ts +245 -0
- package/test/multithreading/activations.functions.test.ts +54 -0
- package/test/multithreading/multi.test.ts +290 -0
- package/test/multithreading/worker.node.process.test.ts +39 -0
- package/test/multithreading/workers.coverage.test.ts +36 -0
- package/test/multithreading/workers.dynamic.import.test.ts +8 -0
- package/test/neat/neat.adaptive.complexityBudget.test.ts +34 -0
- package/test/neat/neat.adaptive.criterion.complexity.test.ts +50 -0
- package/test/neat/neat.adaptive.mutation.strategy.test.ts +37 -0
- package/test/neat/neat.adaptive.operator.decay.test.ts +31 -0
- package/test/neat/neat.adaptive.phasedComplexity.test.ts +25 -0
- package/test/neat/neat.adaptive.pruning.test.ts +25 -0
- package/test/neat/neat.adaptive.targetSpecies.test.ts +43 -0
- package/test/neat/neat.additional.coverage.test.ts +126 -0
- package/test/neat/neat.advanced.enhancements.test.ts +85 -0
- package/test/neat/neat.advanced.test.ts +589 -0
- package/test/neat/neat.diversity.autocompat.test.ts +47 -0
- package/test/neat/neat.diversity.metrics.test.ts +21 -0
- package/test/neat/neat.diversity.stats.test.ts +44 -0
- package/test/neat/neat.enhancements.test.ts +79 -0
- package/test/neat/neat.entropy.ancestorAdaptive.test.ts +133 -0
- package/test/neat/neat.entropy.compat.csv.test.ts +108 -0
- package/test/neat/neat.evolution.pruning.test.ts +39 -0
- package/test/neat/neat.fastmode.autotune.test.ts +42 -0
- package/test/neat/neat.innovation.test.ts +134 -0
- package/test/neat/neat.lineage.antibreeding.test.ts +35 -0
- package/test/neat/neat.lineage.entropy.test.ts +56 -0
- package/test/neat/neat.lineage.inbreeding.test.ts +49 -0
- package/test/neat/neat.lineage.pressure.test.ts +29 -0
- package/test/neat/neat.multiobjective.adaptive.test.ts +57 -0
- package/test/neat/neat.multiobjective.dynamic.schedule.test.ts +46 -0
- package/test/neat/neat.multiobjective.dynamic.test.ts +31 -0
- package/test/neat/neat.multiobjective.fastsort.delegation.test.ts +51 -0
- package/test/neat/neat.multiobjective.prune.test.ts +39 -0
- package/test/neat/neat.multiobjective.test.ts +21 -0
- package/test/neat/neat.mutation.undefined.pool.test.ts +24 -0
- package/test/neat/neat.objective.events.test.ts +26 -0
- package/test/neat/neat.objective.importance.test.ts +21 -0
- package/test/neat/neat.objective.lifetimes.test.ts +33 -0
- package/test/neat/neat.offspring.allocation.test.ts +22 -0
- package/test/neat/neat.operator.bandit.test.ts +17 -0
- package/test/neat/neat.operator.phases.test.ts +38 -0
- package/test/neat/neat.pruneInactive.behavior.test.ts +54 -0
- package/test/neat/neat.reenable.adaptation.test.ts +18 -0
- package/test/neat/neat.rng.state.test.ts +22 -0
- package/test/neat/neat.spawn.add.test.ts +123 -0
- package/test/neat/neat.speciation.test.ts +96 -0
- package/test/neat/neat.species.allocation.telemetry.test.ts +26 -0
- package/test/neat/neat.species.history.csv.test.ts +24 -0
- package/test/neat/neat.telemetry.advanced.test.ts +226 -0
- package/test/neat/neat.telemetry.csv.lineage.test.ts +19 -0
- package/test/neat/neat.telemetry.parity.test.ts +42 -0
- package/test/neat/neat.telemetry.stream.test.ts +19 -0
- package/test/neat/neat.telemetry.test.ts +16 -0
- package/test/neat/neat.test.ts +422 -0
- package/test/neat/neat.utilities.test.ts +44 -0
- package/test/network/__suppress_console.ts +9 -0
- package/test/network/acyclic.topoorder.test.ts +17 -0
- package/test/network/checkpoint.metricshook.test.ts +36 -0
- package/test/network/error.handling.test.ts +581 -0
- package/test/network/evolution.test.ts +285 -0
- package/test/network/genetic.test.ts +208 -0
- package/test/network/learning.capability.test.ts +244 -0
- package/test/network/mutation.effects.test.ts +492 -0
- package/test/network/network.activate.test.ts +115 -0
- package/test/network/network.activateBatch.test.ts +30 -0
- package/test/network/network.deterministic.test.ts +64 -0
- package/test/network/network.evolve.branches.test.ts +75 -0
- package/test/network/network.evolve.multithread.branches.test.ts +83 -0
- package/test/network/network.evolve.test.ts +100 -0
- package/test/network/network.gating.removal.test.ts +93 -0
- package/test/network/network.mutate.additional.test.ts +145 -0
- package/test/network/network.mutate.edgecases.test.ts +101 -0
- package/test/network/network.mutate.test.ts +101 -0
- package/test/network/network.prune.earlyexit.test.ts +38 -0
- package/test/network/network.remove.errors.test.ts +45 -0
- package/test/network/network.slab.fallbacks.test.ts +22 -0
- package/test/network/network.stats.test.ts +45 -0
- package/test/network/network.training.advanced.test.ts +149 -0
- package/test/network/network.training.basic.test.ts +228 -0
- package/test/network/network.training.helpers.test.ts +183 -0
- package/test/network/onnx.export.test.ts +310 -0
- package/test/network/onnx.import.test.ts +129 -0
- package/test/network/pruning.topology.test.ts +282 -0
- package/test/network/regularization.determinism.test.ts +83 -0
- package/test/network/regularization.dropconnect.test.ts +17 -0
- package/test/network/regularization.dropconnect.validation.test.ts +18 -0
- package/test/network/regularization.stochasticdepth.test.ts +27 -0
- package/test/network/regularization.test.ts +843 -0
- package/test/network/regularization.weightnoise.test.ts +30 -0
- package/test/network/setupTests.ts +2 -0
- package/test/network/standalone.test.ts +332 -0
- package/test/network/structure.serialization.test.ts +660 -0
- package/test/training/training.determinism.mixed-precision.test.ts +134 -0
- package/test/training/training.earlystopping.test.ts +91 -0
- package/test/training/training.edge-cases.test.ts +91 -0
- package/test/training/training.extensions.test.ts +47 -0
- package/test/training/training.gradient.features.test.ts +110 -0
- package/test/training/training.gradient.refinements.test.ts +170 -0
- package/test/training/training.gradient.separate-bias.test.ts +41 -0
- package/test/training/training.optimizer.test.ts +48 -0
- package/test/training/training.plateau.smoothing.test.ts +58 -0
- package/test/training/training.smoothing.types.test.ts +174 -0
- package/test/training/training.train.options.coverage.test.ts +52 -0
- package/test/utils/console-helper.ts +76 -0
- package/test/utils/jest-setup.ts +60 -0
- package/test/utils/test-helpers.ts +175 -0
- package/tsconfig.docs.json +12 -0
- package/tsconfig.json +21 -0
- package/webpack.config.js +49 -0
|
@@ -0,0 +1,1278 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Training pipeline utilities (migrated from legacy architecture/network.train.ts).
|
|
3
|
+
*
|
|
4
|
+
* Provides:
|
|
5
|
+
* - Gradient clipping (global / layerwise; norm / percentile variants).
|
|
6
|
+
* - Mini & micro-batch gradient accumulation.
|
|
7
|
+
* - Optimizer step dispatch (SGD + adaptive optimizers + lookahead wrapper).
|
|
8
|
+
* - Simple mixed precision dynamic loss scaling (overflow detection heuristic).
|
|
9
|
+
* - Multiple moving-average smoothing strategies for error monitoring (SMA, EMA, adaptive EMA,
|
|
10
|
+
* median, gaussian, trimmed mean, WMA) plus separate plateau averaging.
|
|
11
|
+
* - Early stopping, schedule hooks, pruning hooks, and checkpoint callbacks.
|
|
12
|
+
*
|
|
13
|
+
* Notes:
|
|
14
|
+
* - This module intentionally keeps imperative style for clarity/perf (avoids heap churn in hot loops).
|
|
15
|
+
* - Refactor changes here are documentation & naming only; numerical behavior preserved.
|
|
16
|
+
*/
|
|
17
|
+
import * as methods from '../../methods/methods';
|
|
18
|
+
import { config } from '../../config';
|
|
19
|
+
import type Network from '../network';
|
|
20
|
+
|
|
21
|
+
/**
|
|
22
|
+
* -----------------------------------------------------------------------------
|
|
23
|
+
* Internal Type Definitions (documentation only; optional for callers)
|
|
24
|
+
* -----------------------------------------------------------------------------
|
|
25
|
+
*/
|
|
26
|
+
/** Cost function signature used by training: maps (target, output) pairs to a scalar loss. */
export type CostFunction = (target: number[], output: number[]) => number;

/** Gradient clipping configuration accepted by options.gradientClip. */
export interface GradientClipConfig {
  /** Clipping algorithm: global vs per-layer ("layerwise*"), L2-norm vs percentile based. */
  mode?: 'norm' | 'percentile' | 'layerwiseNorm' | 'layerwisePercentile';
  /** Max L2 norm (for *Norm modes). */
  maxNorm?: number;
  /** Percentile threshold (0-100) for *Percentile modes (clamps absolute values). */
  percentile?: number;
  /** Whether to treat bias separately (currently informational flag – behavior parity preserved). */
  separateBias?: boolean;
}

/** Dynamic loss-scaling sub-configuration for mixed precision (auto increase/decrease logic). */
export interface MixedPrecisionDynamicConfig {
  /** Minimum loss scale when scaling down after overflows. */
  minScale?: number;
  /** Maximum allowed loss scale for automatic increases. */
  maxScale?: number;
  /** Steps of stable (non-overflow) updates before doubling loss scale. */
  increaseEvery?: number; // alias stableStepsForIncrease
  /** Legacy alias: stable steps threshold for increase. */
  stableStepsForIncrease?: number;
}

/** Top-level mixed precision configuration. */
export interface MixedPrecisionConfig {
  /** Initial loss scale (larger -> more mantissa preservation but higher overflow risk). */
  lossScale?: number;
  /** Enable dynamic (auto increase/decrease) logic. */
  dynamic?: MixedPrecisionDynamicConfig;
}

/** Optimizer configuration (subset – delegated to node.applyBatchUpdatesWithOptimizer). */
export interface OptimizerConfigBase {
  type: string; // normalized to lowercase
  baseType?: string; // inner optimizer type when wrapped by lookahead
  beta1?: number;
  beta2?: number;
  eps?: number;
  weightDecay?: number;
  momentum?: number;
  la_k?: number; // lookahead sync interval
  la_alpha?: number; // lookahead interpolation factor
}

/** Checkpoint callback spec. */
export interface CheckpointConfig {
  /** Save final state each iteration. */
  last?: boolean;
  /** Save best (lowest error) state. */
  best?: boolean;
  /** Persist function invoked with metadata + serialized network. */
  save: (payload: {
    type: 'last' | 'best';
    iteration: number;
    error: number;
    network: any;
  }) => void;
}

/** Schedule hook executed every N iterations. */
export interface ScheduleConfig {
  iterations: number; // invocation frequency (every N iterations)
  function: (info: { error: number; iteration: number }) => void;
}

/** Metrics hook signature: invoked once per iteration with telemetry values. */
export type MetricsHook = (m: {
  iteration: number;
  error: number;
  plateauError?: number;
  gradNorm: number;
}) => void;

/** Moving average strategy identifiers. */
export type MovingAverageType =
  | 'sma'
  | 'ema'
  | 'adaptive-ema'
  | 'median'
  | 'gaussian'
  | 'trimmed'
  | 'wma';

/** Primary training options object (public shape). */
export interface TrainingOptions {
  iterations?: number; // stopping condition: max passes
  error?: number; // stopping condition: target monitored (smoothed) error
  rate?: number; // base learning rate
  momentum?: number; // momentum for SGD / sometimes consumed by wrappers
  optimizer?: string | OptimizerConfigBase; // adaptive optimizer choice
  dropout?: number; // dropout probability applied per forward (mutable net.dropout)
  batchSize?: number; // mini-batch size; if > dataset length => error
  accumulationSteps?: number; // gradient accumulation factor (micro-batches per optimizer step)
  accumulationReduction?: 'average' | 'sum'; // scaling mode for accumulated gradients
  gradientClip?: GradientClipConfig; // gradient clipping configuration
  mixedPrecision?: boolean | MixedPrecisionConfig; // enable FP16-like scaling logic
  cost?: CostFunction | { fn?: CostFunction; calculate?: CostFunction }; // cost interface variants
  movingAverageWindow?: number; // smoothing window size
  movingAverageType?: MovingAverageType; // smoothing algorithm
  emaAlpha?: number; // override alpha for EMA
  adaptiveEmaBaseAlpha?: number; // (not currently used – placeholder)
  trimmedRatio?: number; // fraction dropped from each tail for trimmed mean (0..0.49)
  plateauMovingAverageWindow?: number; // independent plateau window
  plateauMovingAverageType?: MovingAverageType; // independent plateau strategy
  plateauEmaAlpha?: number; // plateau EMA alpha override
  earlyStopPatience?: number; // iterations with no improvement before stop
  earlyStopMinDelta?: number; // required improvement beyond previous best
  checkpoint?: CheckpointConfig; // persistence callbacks
  schedule?: ScheduleConfig; // periodic hook
  metricsHook?: MetricsHook; // telemetry per iteration
}
|
|
138
|
+
|
|
139
|
+
/** ---------------------------------------------------------------------------
|
|
140
|
+
* Internal Helper Utilities (non-exported)
|
|
141
|
+
* ---------------------------------------------------------------------------
|
|
142
|
+
* These functions encapsulate cohesive sub-steps of the training pipeline so the
|
|
143
|
+
* main exported functions remain readable while preserving original behavior.
|
|
144
|
+
* Each helper is intentionally pure where reasonable or documents its side-effects.
|
|
145
|
+
*/
|
|
146
|
+
|
|
147
|
+
/** State container for EMA / Adaptive EMA smoothing values (mutated in place per iteration). */
interface PrimarySmoothingState {
  /** Classic EMA value (when movingAverageType === 'ema'). */
  emaValue?: number;
  /** Baseline EMA part of adaptive EMA (slower). */
  adaptiveBaseEmaValue?: number;
  /** Fast adaptive EMA (higher alpha under variance). */
  adaptiveEmaValue?: number;
}

/** State container for plateau EMA smoothing. */
interface PlateauSmoothingState {
  /** Running EMA of the plateau metric (only used when plateau type === 'ema'). */
  plateauEmaValue?: number;
}

/** Configuration passed to monitored (primary) smoothing computation. */
interface MonitoredSmoothingConfig {
  /** Smoothing algorithm to apply. */
  type: MovingAverageType;
  /** History window length; values <= 1 disable history-based smoothing. */
  window: number;
  emaAlpha?: number; // optional override (only for EMA types)
  trimmedRatio?: number; // for trimmed mean strategy
}

/** Configuration for plateau smoothing computation. */
interface PlateauSmoothingConfig {
  /** Smoothing algorithm (only sma/median/ema are effective; others fall back to mean). */
  type: MovingAverageType;
  /** Plateau window length; values <= 1 disable history-based smoothing. */
  window: number;
  /** EMA alpha override for the plateau track. */
  emaAlpha?: number;
}
|
|
176
|
+
|
|
177
|
+
/**
|
|
178
|
+
* Compute the monitored (primary) smoothed error given recent raw errors.
|
|
179
|
+
*
|
|
180
|
+
* Behavior:
|
|
181
|
+
* - For SMA-like strategies uses the supplied window slice directly.
|
|
182
|
+
* - For EMA it mutates state.emaValue.
|
|
183
|
+
* - For adaptive-ema maintains dual EMA tracks inside state and returns the min for stability.
|
|
184
|
+
* - For median / gaussian / trimmed / wma applies algorithmic weighting as documented inline.
|
|
185
|
+
*
|
|
186
|
+
* Inputs:
|
|
187
|
+
* - trainError: Current raw mean error for this iteration.
|
|
188
|
+
* - recentErrors: Chronological array (oldest->newest) of last N raw errors.
|
|
189
|
+
* - cfg: Algorithm selection + parameters.
|
|
190
|
+
* - state: Mutable smoothing state (ema / adaptive fields updated in-place).
|
|
191
|
+
*
|
|
192
|
+
* Returns: Smoothed/monitored error metric (may equal trainError if no smoothing active).
|
|
193
|
+
*/
|
|
194
|
+
function computeMonitoredError(
|
|
195
|
+
trainError: number,
|
|
196
|
+
recentErrors: number[],
|
|
197
|
+
cfg: MonitoredSmoothingConfig,
|
|
198
|
+
state: PrimarySmoothingState
|
|
199
|
+
): number {
|
|
200
|
+
// Fast path: no smoothing window / algorithm requiring history.
|
|
201
|
+
if (cfg.window <= 1 && cfg.type !== 'ema' && cfg.type !== 'adaptive-ema') {
|
|
202
|
+
return trainError;
|
|
203
|
+
}
|
|
204
|
+
const type = cfg.type;
|
|
205
|
+
if (type === 'median') {
|
|
206
|
+
const sorted = [...recentErrors].sort((a, b) => a - b);
|
|
207
|
+
const midIndex = Math.floor(sorted.length / 2);
|
|
208
|
+
return sorted.length % 2
|
|
209
|
+
? sorted[midIndex]
|
|
210
|
+
: (sorted[midIndex - 1] + sorted[midIndex]) / 2;
|
|
211
|
+
}
|
|
212
|
+
if (type === 'ema') {
|
|
213
|
+
// Standard exponential moving average.
|
|
214
|
+
if (state.emaValue == null) state.emaValue = trainError;
|
|
215
|
+
else
|
|
216
|
+
state.emaValue =
|
|
217
|
+
state.emaValue + cfg.emaAlpha! * (trainError - state.emaValue);
|
|
218
|
+
return state.emaValue;
|
|
219
|
+
}
|
|
220
|
+
if (type === 'adaptive-ema') {
|
|
221
|
+
// Adaptive EMA: baseline alpha + volatility-inflated alpha, final metric is more conservative (min).
|
|
222
|
+
const mean = recentErrors.reduce((a, b) => a + b, 0) / recentErrors.length;
|
|
223
|
+
const variance =
|
|
224
|
+
recentErrors.reduce((a, b) => a + (b - mean) * (b - mean), 0) /
|
|
225
|
+
recentErrors.length;
|
|
226
|
+
const baseAlpha = cfg.emaAlpha || 2 / (cfg.window + 1);
|
|
227
|
+
const varianceScaled = variance / Math.max(mean * mean, 1e-8);
|
|
228
|
+
const adaptiveAlpha = Math.min(
|
|
229
|
+
0.95,
|
|
230
|
+
Math.max(baseAlpha, baseAlpha * (1 + 2 * varianceScaled))
|
|
231
|
+
);
|
|
232
|
+
if (state.adaptiveBaseEmaValue == null) {
|
|
233
|
+
state.adaptiveBaseEmaValue = trainError;
|
|
234
|
+
state.adaptiveEmaValue = trainError;
|
|
235
|
+
} else {
|
|
236
|
+
state.adaptiveBaseEmaValue =
|
|
237
|
+
state.adaptiveBaseEmaValue +
|
|
238
|
+
baseAlpha * (trainError - state.adaptiveBaseEmaValue);
|
|
239
|
+
state.adaptiveEmaValue =
|
|
240
|
+
state.adaptiveEmaValue! +
|
|
241
|
+
adaptiveAlpha * (trainError - state.adaptiveEmaValue!);
|
|
242
|
+
}
|
|
243
|
+
return Math.min(state.adaptiveEmaValue!, state.adaptiveBaseEmaValue!);
|
|
244
|
+
}
|
|
245
|
+
if (type === 'gaussian') {
|
|
246
|
+
// Gaussian kernel weights centered at newest element (index length-1).
|
|
247
|
+
const sigma = cfg.window / 3 || 1; // heuristic: cover window ~3 sigma
|
|
248
|
+
let weightSum = 0;
|
|
249
|
+
let weightedAccumulator = 0;
|
|
250
|
+
const length = recentErrors.length;
|
|
251
|
+
for (let i = 0; i < length; i++) {
|
|
252
|
+
const weight = Math.exp(-0.5 * Math.pow((i - (length - 1)) / sigma, 2));
|
|
253
|
+
weightSum += weight;
|
|
254
|
+
weightedAccumulator += weight * recentErrors[i];
|
|
255
|
+
}
|
|
256
|
+
return weightedAccumulator / (weightSum || 1);
|
|
257
|
+
}
|
|
258
|
+
if (type === 'trimmed') {
|
|
259
|
+
// Trim symmetric tails before averaging to reduce outlier influence.
|
|
260
|
+
const ratio = Math.min(0.49, Math.max(0, cfg.trimmedRatio || 0.1));
|
|
261
|
+
const sorted = [...recentErrors].sort((a, b) => a - b);
|
|
262
|
+
const drop = Math.floor(sorted.length * ratio);
|
|
263
|
+
const trimmed = sorted.slice(drop, sorted.length - drop);
|
|
264
|
+
return trimmed.reduce((a, b) => a + b, 0) / (trimmed.length || 1);
|
|
265
|
+
}
|
|
266
|
+
if (type === 'wma') {
|
|
267
|
+
// Linear weighting (oldest weight=1 ... newest weight=n).
|
|
268
|
+
let weightSum = 0;
|
|
269
|
+
let weightedAccumulator = 0;
|
|
270
|
+
for (let i = 0; i < recentErrors.length; i++) {
|
|
271
|
+
const weight = i + 1;
|
|
272
|
+
weightSum += weight;
|
|
273
|
+
weightedAccumulator += weight * recentErrors[i];
|
|
274
|
+
}
|
|
275
|
+
return weightedAccumulator / (weightSum || 1);
|
|
276
|
+
}
|
|
277
|
+
// Default: arithmetic mean (SMA).
|
|
278
|
+
return recentErrors.reduce((a, b) => a + b, 0) / recentErrors.length;
|
|
279
|
+
}
|
|
280
|
+
|
|
281
|
+
/**
|
|
282
|
+
* Compute plateau metric (may differ in strategy from primary monitored error).
|
|
283
|
+
* Only algorithms actually supported for plateau in current pipeline are SMA, median and EMA.
|
|
284
|
+
* Provided flexibility keeps room for extension; unsupported types silently fallback to mean.
|
|
285
|
+
*/
|
|
286
|
+
function computePlateauMetric(
|
|
287
|
+
trainError: number,
|
|
288
|
+
plateauErrors: number[],
|
|
289
|
+
cfg: PlateauSmoothingConfig,
|
|
290
|
+
state: PlateauSmoothingState
|
|
291
|
+
): number {
|
|
292
|
+
if (cfg.window <= 1 && cfg.type !== 'ema') return trainError;
|
|
293
|
+
if (cfg.type === 'median') {
|
|
294
|
+
const sorted = [...plateauErrors].sort((a, b) => a - b);
|
|
295
|
+
const mid = Math.floor(sorted.length / 2);
|
|
296
|
+
return sorted.length % 2
|
|
297
|
+
? sorted[mid]
|
|
298
|
+
: (sorted[mid - 1] + sorted[mid]) / 2;
|
|
299
|
+
}
|
|
300
|
+
if (cfg.type === 'ema') {
|
|
301
|
+
if (state.plateauEmaValue == null) state.plateauEmaValue = trainError;
|
|
302
|
+
else
|
|
303
|
+
state.plateauEmaValue =
|
|
304
|
+
state.plateauEmaValue +
|
|
305
|
+
cfg.emaAlpha! * (trainError - state.plateauEmaValue);
|
|
306
|
+
return state.plateauEmaValue;
|
|
307
|
+
}
|
|
308
|
+
// Fallback default mean.
|
|
309
|
+
return plateauErrors.reduce((a, b) => a + b, 0) / plateauErrors.length;
|
|
310
|
+
}
|
|
311
|
+
|
|
312
|
+
// Internal export bundle (test-only usage) to enable direct branch coverage of smoothing helpers.
|
|
313
|
+
// Marked with double underscore to discourage production use.
|
|
314
|
+
export const __trainingInternals = {
|
|
315
|
+
computeMonitoredError,
|
|
316
|
+
computePlateauMetric,
|
|
317
|
+
};
|
|
318
|
+
|
|
319
|
+
/**
|
|
320
|
+
* Detect mixed precision overflow (NaN / Inf) in bias values if mixed precision enabled.
|
|
321
|
+
* Side-effect: may clear internal trigger _forceNextOverflow.
|
|
322
|
+
*/
|
|
323
|
+
function detectMixedPrecisionOverflow(net: Network, internalNet: any): boolean {
|
|
324
|
+
if (!internalNet._mixedPrecision.enabled) return false;
|
|
325
|
+
if (internalNet._forceNextOverflow) {
|
|
326
|
+
internalNet._forceNextOverflow = false;
|
|
327
|
+
return true;
|
|
328
|
+
}
|
|
329
|
+
let overflow = false;
|
|
330
|
+
net.nodes.forEach((node) => {
|
|
331
|
+
if ((node as any)._fp32Bias !== undefined) {
|
|
332
|
+
if (!Number.isFinite((node as any).bias)) overflow = true;
|
|
333
|
+
}
|
|
334
|
+
});
|
|
335
|
+
return overflow;
|
|
336
|
+
}
|
|
337
|
+
|
|
338
|
+
/** Zero-out accumulated gradient buffers after an overflow to discard invalid updates. */
|
|
339
|
+
function zeroAccumulatedGradients(net: Network) {
|
|
340
|
+
net.nodes.forEach((node) => {
|
|
341
|
+
(node as any).connections.in.forEach((c: any) => {
|
|
342
|
+
c.totalDeltaWeight = 0;
|
|
343
|
+
});
|
|
344
|
+
(node as any).connections.self.forEach((c: any) => {
|
|
345
|
+
c.totalDeltaWeight = 0;
|
|
346
|
+
});
|
|
347
|
+
if (typeof (node as any).totalDeltaBias === 'number')
|
|
348
|
+
(node as any).totalDeltaBias = 0;
|
|
349
|
+
(node as any).previousDeltaBias = 0;
|
|
350
|
+
});
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
/** Divide accumulated gradients by accumulationSteps (average reduction mode). */
|
|
354
|
+
function averageAccumulatedGradients(net: Network, accumulationSteps: number) {
|
|
355
|
+
if (accumulationSteps <= 1) return;
|
|
356
|
+
net.nodes.forEach((node) => {
|
|
357
|
+
(node as any).connections.in.forEach((c: any) => {
|
|
358
|
+
if (typeof c.totalDeltaWeight === 'number')
|
|
359
|
+
c.totalDeltaWeight /= accumulationSteps;
|
|
360
|
+
});
|
|
361
|
+
(node as any).connections.self.forEach((c: any) => {
|
|
362
|
+
if (typeof c.totalDeltaWeight === 'number')
|
|
363
|
+
c.totalDeltaWeight /= accumulationSteps;
|
|
364
|
+
});
|
|
365
|
+
if (typeof (node as any).totalDeltaBias === 'number')
|
|
366
|
+
(node as any).totalDeltaBias /= accumulationSteps;
|
|
367
|
+
});
|
|
368
|
+
}
|
|
369
|
+
|
|
370
|
+
/** Apply optimizer update step across all nodes; returns gradient L2 norm (approx). */
|
|
371
|
+
function applyOptimizerStep(
|
|
372
|
+
net: Network,
|
|
373
|
+
optimizer: any,
|
|
374
|
+
currentRate: number,
|
|
375
|
+
momentum: number,
|
|
376
|
+
internalNet: any
|
|
377
|
+
): number {
|
|
378
|
+
let sumSq = 0;
|
|
379
|
+
net.nodes.forEach((node) => {
|
|
380
|
+
if (node.type === 'input') return;
|
|
381
|
+
(node as any).applyBatchUpdatesWithOptimizer({
|
|
382
|
+
type: optimizer.type,
|
|
383
|
+
baseType: optimizer.baseType,
|
|
384
|
+
beta1: optimizer.beta1,
|
|
385
|
+
beta2: optimizer.beta2,
|
|
386
|
+
eps: optimizer.eps,
|
|
387
|
+
weightDecay: optimizer.weightDecay,
|
|
388
|
+
momentum: optimizer.momentum ?? momentum,
|
|
389
|
+
lrScale: currentRate,
|
|
390
|
+
t: internalNet._optimizerStep,
|
|
391
|
+
la_k: optimizer.la_k,
|
|
392
|
+
la_alpha: optimizer.la_alpha,
|
|
393
|
+
});
|
|
394
|
+
(node as any).connections.in.forEach((c: any) => {
|
|
395
|
+
if (typeof c.previousDeltaWeight === 'number')
|
|
396
|
+
sumSq += c.previousDeltaWeight * c.previousDeltaWeight;
|
|
397
|
+
});
|
|
398
|
+
(node as any).connections.self.forEach((c: any) => {
|
|
399
|
+
if (typeof c.previousDeltaWeight === 'number')
|
|
400
|
+
sumSq += c.previousDeltaWeight * c.previousDeltaWeight;
|
|
401
|
+
});
|
|
402
|
+
});
|
|
403
|
+
return Math.sqrt(sumSq);
|
|
404
|
+
}
|
|
405
|
+
|
|
406
|
+
/** Update dynamic loss scaling after a successful (non-overflow) optimizer step. */
|
|
407
|
+
function maybeIncreaseLossScale(internalNet: any) {
|
|
408
|
+
internalNet._mixedPrecisionState.goodSteps++;
|
|
409
|
+
const incEvery = internalNet._mpIncreaseEvery || 200;
|
|
410
|
+
if (
|
|
411
|
+
internalNet._mixedPrecisionState.goodSteps >= incEvery &&
|
|
412
|
+
internalNet._mixedPrecision.lossScale <
|
|
413
|
+
internalNet._mixedPrecisionState.maxLossScale
|
|
414
|
+
) {
|
|
415
|
+
internalNet._mixedPrecision.lossScale *= 2;
|
|
416
|
+
internalNet._mixedPrecisionState.goodSteps = 0;
|
|
417
|
+
internalNet._mixedPrecisionState.scaleUpEvents =
|
|
418
|
+
(internalNet._mixedPrecisionState.scaleUpEvents || 0) + 1;
|
|
419
|
+
}
|
|
420
|
+
}
|
|
421
|
+
|
|
422
|
+
/** Respond to a mixed precision overflow by shrinking loss scale & bookkeeping. */
|
|
423
|
+
function handleOverflow(internalNet: any) {
|
|
424
|
+
internalNet._mixedPrecisionState.badSteps++;
|
|
425
|
+
internalNet._mixedPrecisionState.goodSteps = 0;
|
|
426
|
+
internalNet._mixedPrecision.lossScale = Math.max(
|
|
427
|
+
internalNet._mixedPrecisionState.minLossScale,
|
|
428
|
+
Math.floor(internalNet._mixedPrecision.lossScale / 2) || 1
|
|
429
|
+
);
|
|
430
|
+
internalNet._mixedPrecisionState.overflowCount =
|
|
431
|
+
(internalNet._mixedPrecisionState.overflowCount || 0) + 1;
|
|
432
|
+
internalNet._mixedPrecisionState.scaleDownEvents =
|
|
433
|
+
(internalNet._mixedPrecisionState.scaleDownEvents || 0) + 1;
|
|
434
|
+
internalNet._lastOverflowStep = internalNet._optimizerStep;
|
|
435
|
+
}
|
|
436
|
+
|
|
437
|
+
/**
 * Apply gradient clipping to accumulated connection deltas / bias deltas.
 *
 * Modes:
 * - norm / layerwiseNorm: L2 norm scaling (global vs per group).
 * - percentile / layerwisePercentile: element-wise clamp at absolute percentile threshold.
 *
 * Grouping:
 * - If layerwise* and net.layers exists -> each defined layer is a group.
 * - Else if layerwise* -> each non-input node becomes its own group.
 * - Otherwise a single global group containing all learnable params.
 *
 * Side effect: records the number of discovered groups on
 * `internalNet._lastGradClipGroupCount` for diagnostics.
 */
export function applyGradientClippingImpl(
  net: Network,
  cfg: {
    mode: 'norm' | 'percentile' | 'layerwiseNorm' | 'layerwisePercentile';
    maxNorm?: number;
    percentile?: number;
  }
) {
  const internalNet = net as any;
  /**
   * Build arrays of gradient values grouped according to chosen clipping mode.
   * Each group is later processed independently (layerwise modes) or as a single global set.
   * Note: groups containing no learnable values are NOT pushed (see the
   * `groupVals.length` guards) — this matters for applyScale's index walk below.
   */
  const collectGroups = () => {
    const collected: number[][] = [];
    if (cfg.mode.startsWith('layerwise')) {
      if ((net as any).layers && (net as any).layers.length > 0) {
        // Layer-aware network: one group per defined layer.
        for (let li = 0; li < (net as any).layers.length; li++) {
          const layer = (net as any).layers[li];
          if (!layer || !layer.nodes) continue;
          const groupVals: number[] = [];
          layer.nodes.forEach((node: any) => {
            // Input nodes carry no learnable parameters.
            if (!node || node.type === 'input') return;
            node.connections.in.forEach((c: any) => {
              if (typeof c.totalDeltaWeight === 'number')
                groupVals.push(c.totalDeltaWeight);
            });
            node.connections.self.forEach((c: any) => {
              if (typeof c.totalDeltaWeight === 'number')
                groupVals.push(c.totalDeltaWeight);
            });
            if (typeof node.totalDeltaBias === 'number')
              groupVals.push(node.totalDeltaBias);
          });
          if (groupVals.length) collected.push(groupVals);
        }
      } else {
        // No explicit layers: treat each non-input node as its own group.
        net.nodes.forEach((node) => {
          if (node.type === 'input') return;
          const groupVals: number[] = [];
          (node as any).connections.in.forEach((c: any) => {
            if (typeof c.totalDeltaWeight === 'number')
              groupVals.push(c.totalDeltaWeight);
          });
          (node as any).connections.self.forEach((c: any) => {
            if (typeof c.totalDeltaWeight === 'number')
              groupVals.push(c.totalDeltaWeight);
          });
          if (typeof (node as any).totalDeltaBias === 'number')
            groupVals.push((node as any).totalDeltaBias);
          if (groupVals.length) collected.push(groupVals);
        });
      }
    } else {
      // Global mode: every learnable delta in the network forms one group.
      const globalVals: number[] = [];
      net.nodes.forEach((node) => {
        (node as any).connections.in.forEach((c: any) => {
          if (typeof c.totalDeltaWeight === 'number')
            globalVals.push(c.totalDeltaWeight);
        });
        (node as any).connections.self.forEach((c: any) => {
          if (typeof c.totalDeltaWeight === 'number')
            globalVals.push(c.totalDeltaWeight);
        });
        if (typeof (node as any).totalDeltaBias === 'number')
          globalVals.push((node as any).totalDeltaBias);
      });
      if (globalVals.length) collected.push(globalVals);
    }
    return collected;
  };
  /**
   * Gradient groups discovered for clipping (size: 1 for global modes).
   * Each entry is an array of parameter delta values belonging to a logical group (layer or node level).
   */
  const groups = collectGroups();
  /** Tracking for diagnostics / potential external tooling. */
  internalNet._lastGradClipGroupCount = groups.length;
  /**
   * Compute absolute percentile threshold (e.g. percentile=99 => value whose |value| is at the 99th percentile).
   * Sorting by absolute value guarantees consistent clipping for symmetric distributions.
   */
  const computeAbsolutePercentileThreshold = (
    values: number[],
    percentile: number
  ) => {
    if (!values.length) return 0;
    // Copy before sorting: sort mutates, and callers own `values`.
    const sortedByAbs = [...values].sort((a, b) => Math.abs(a) - Math.abs(b));
    const rank = Math.min(
      sortedByAbs.length - 1,
      Math.max(0, Math.floor((percentile / 100) * sortedByAbs.length - 1))
    );
    return Math.abs(sortedByAbs[rank]);
  };
  /**
   * Iterate all learnable parameters applying a transform function.
   * The transform receives the current value and the owning group so it can selectively scale only
   * the active group (when computing per-group scaling factor yet iterating entire model).
   *
   * NOTE(review): `groupIndex` advances once per non-input NODE, but when
   * `net.layers` exists collectGroups produced one group per LAYER, and nodes
   * with no learnable values never contributed a group at all. The index walk
   * can therefore run past `groups` (activeGroup === undefined), in which case
   * the `owningGroup === groupValues` identity check in the callers fails and
   * those parameters are silently left unclipped. Confirm whether per-layer
   * alignment is intended here.
   */
  const applyScale = (
    scaleFn: (currentValue: number, owningGroup: number[]) => number
  ) => {
    let groupIndex = 0; // advances only for layerwise modes
    net.nodes.forEach((node) => {
      if (cfg.mode.startsWith('layerwise') && node.type === 'input') return; // skip input nodes in layerwise grouping
      const activeGroup = cfg.mode.startsWith('layerwise')
        ? groups[groupIndex++]
        : groups[0];
      (node as any).connections.in.forEach((c: any) => {
        if (typeof c.totalDeltaWeight === 'number')
          c.totalDeltaWeight = scaleFn(c.totalDeltaWeight, activeGroup);
      });
      (node as any).connections.self.forEach((c: any) => {
        if (typeof c.totalDeltaWeight === 'number')
          c.totalDeltaWeight = scaleFn(c.totalDeltaWeight, activeGroup);
      });
      if (typeof (node as any).totalDeltaBias === 'number')
        (node as any).totalDeltaBias = scaleFn(
          (node as any).totalDeltaBias,
          activeGroup
        );
    });
  };
  if (cfg.mode === 'norm' || cfg.mode === 'layerwiseNorm') {
    /** Maximum allowed L2 norm per group (or global). */
    const maxAllowedNorm = cfg.maxNorm || 1;
    groups.forEach((groupValues) => {
      /** Current group L2 norm. */
      const groupL2Norm = Math.sqrt(
        groupValues.reduce((sum, v) => sum + v * v, 0)
      );
      if (groupL2Norm > maxAllowedNorm && groupL2Norm > 0) {
        /** Scaling factor applied uniformly to bring norm to boundary. */
        const normScaleFactor = maxAllowedNorm / groupL2Norm;
        // Identity comparison (===) against the collected group array decides
        // which parameters belong to the group currently being rescaled.
        applyScale((currentValue, owningGroup) =>
          owningGroup === groupValues
            ? currentValue * normScaleFactor
            : currentValue
        );
      }
    });
  } else if (cfg.mode === 'percentile' || cfg.mode === 'layerwisePercentile') {
    /** Percentile specifying absolute magnitude cutoff (values above are clamped). */
    const percentileSetting = cfg.percentile || 99;
    groups.forEach((groupValues) => {
      const percentileThreshold = computeAbsolutePercentileThreshold(
        groupValues,
        percentileSetting
      );
      // A zero threshold (all-zero gradients) would clamp everything to 0; skip.
      if (percentileThreshold <= 0) return;
      applyScale((currentValue, owningGroup) =>
        owningGroup === groupValues &&
        Math.abs(currentValue) > percentileThreshold
          ? percentileThreshold * Math.sign(currentValue)
          : currentValue
      );
    });
  }
}
|
|
608
|
+
|
|
609
|
+
/**
 * Execute one full pass over dataset (epoch) with optional accumulation & adaptive optimizer.
 * Returns mean cost across processed samples.
 *
 * @param net - Network to train.
 * @param set - Dataset of { input, output } records; dimension-mismatched samples are skipped with a warning.
 * @param batchSize - Samples per mini-batch flush.
 * @param accumulationSteps - Micro-batches accumulated per adaptive-optimizer step.
 * @param currentRate - Learning rate for this epoch.
 * @param momentum - Momentum factor passed through to node propagation / optimizer.
 * @param regularization - Regularization config forwarded to each node's propagate.
 * @param costFunction - Cost as a plain function, or an object exposing `.fn` or `.calculate`.
 * @param optimizer - Optional normalized optimizer config; absent or type 'sgd' means per-sample SGD updates.
 * @returns Mean cost over the samples actually processed (0 when none were valid).
 */
export function trainSetImpl(
  net: Network,
  set: { input: number[]; output: number[] }[],
  batchSize: number,
  accumulationSteps: number,
  currentRate: number,
  momentum: number,
  regularization: any,
  costFunction: (target: number[], output: number[]) => number,
  optimizer?: any
): number {
  const internalNet = net as any;
  /** Sum of raw (unsmoothed) cost values across valid samples. */
  let cumulativeError = 0;
  /** Number of samples processed in current mini-batch (resets after potential optimizer step). */
  let batchSampleCount = 0;
  /** Counter of micro-batches contributing to current accumulated gradient set. */
  internalNet._gradAccumMicroBatches = 0;
  /** Total number of dataset samples actually processed (dimension-valid). */
  let totalProcessedSamples = 0;
  /** Cached list of output layer nodes (backprop order requires targets). */
  const outputNodes = net.nodes.filter((n) => n.type === 'output');
  /** Unified cost evaluation function resolved from provided cost variant. */
  let computeError: (t: number[], o: number[]) => number;
  // Accept a bare function, an object with `.fn`, or an object with `.calculate`;
  // anything else degrades to a zero cost (training proceeds, error stays 0).
  if (typeof costFunction === 'function') computeError = costFunction as any;
  else if (
    (costFunction as any) &&
    typeof (costFunction as any).fn === 'function'
  )
    computeError = (costFunction as any).fn;
  else if (
    (costFunction as any) &&
    typeof (costFunction as any).calculate === 'function'
  )
    computeError = (costFunction as any).calculate;
  else computeError = () => 0;

  for (let sampleIndex = 0; sampleIndex < set.length; sampleIndex++) {
    /** Current training sample record (input + target). */
    const dataPoint = set[sampleIndex];
    /** Input feature vector (validated for dimension). */
    const input = dataPoint.input;
    /** Target output vector (validated for dimension). */
    const target = dataPoint.output;
    // Skip (don't throw on) samples whose dimensions don't match the network.
    if (input.length !== net.input || target.length !== net.output) {
      if (config.warnings)
        console.warn(
          `Data point ${sampleIndex} has incorrect dimensions (input: ${input.length}/${net.input}, output: ${target.length}/${net.output}), skipping.`
        );
      continue;
    }
    try {
      // Forward pass with training flag (enables dropout / any stochastic layers).
      const output = (net as any).activate(input, true);
      if (optimizer && optimizer.type && optimizer.type !== 'sgd') {
        // Accumulate gradients for adaptive optimizers (no immediate weight update inside propagate).
        // Output nodes first (they need targets), then remaining nodes in reverse order.
        for (let outIndex = 0; outIndex < outputNodes.length; outIndex++)
          (outputNodes[outIndex] as any).propagate(
            currentRate,
            momentum,
            false,
            regularization,
            target[outIndex]
          );
        for (
          let reverseIndex = net.nodes.length - 1;
          reverseIndex >= 0;
          reverseIndex--
        ) {
          const node = net.nodes[reverseIndex];
          if (node.type === 'output' || node.type === 'input') continue;
          (node as any).propagate(currentRate, momentum, false, regularization);
        }
      } else {
        // SGD mode: propagate performs immediate parameter updates using deltas
        // (third argument `true` requests the in-place update).
        for (let outIndex = 0; outIndex < outputNodes.length; outIndex++)
          (outputNodes[outIndex] as any).propagate(
            currentRate,
            momentum,
            true,
            regularization,
            target[outIndex]
          );
        for (
          let reverseIndex = net.nodes.length - 1;
          reverseIndex >= 0;
          reverseIndex--
        ) {
          const node = net.nodes[reverseIndex];
          if (node.type === 'output' || node.type === 'input') continue;
          (node as any).propagate(currentRate, momentum, true, regularization);
        }
      }
      cumulativeError += computeError(target, output);
      batchSampleCount++;
      totalProcessedSamples++;
    } catch (e: any) {
      // A failing sample is skipped (best-effort training), optionally logged.
      if (config.warnings)
        console.warn(
          `Error processing data point ${sampleIndex} (input: ${JSON.stringify(
            input
          )}): ${e.message}. Skipping.`
        );
    }
    // Mini-batch / end-of-dataset flush condition.
    if (
      batchSampleCount > 0 &&
      ((sampleIndex + 1) % batchSize === 0 || sampleIndex === set.length - 1)
    ) {
      if (optimizer && optimizer.type && optimizer.type !== 'sgd') {
        // Only adaptive optimizers delay the step; vanilla SGD already updated weights per sample.
        internalNet._gradAccumMicroBatches++;
        /** True when we have accumulated sufficient micro-batches or reached dataset end. */
        const readyForStep =
          internalNet._gradAccumMicroBatches % accumulationSteps === 0 ||
          sampleIndex === set.length - 1;
        if (readyForStep) {
          /** 1-based optimizer step counter (used for bias-correction terms by adaptive methods). */
          internalNet._optimizerStep = (internalNet._optimizerStep || 0) + 1;
          /** Detect overflow under mixed precision (NaN/Inf). */
          const overflowDetected = detectMixedPrecisionOverflow(
            net,
            internalNet
          );
          if (overflowDetected) {
            // Discard invalid gradients & shrink loss scale.
            zeroAccumulatedGradients(net);
            if (internalNet._mixedPrecision.enabled)
              handleOverflow(internalNet);
            internalNet._lastGradNorm = 0;
          } else {
            // Optional gradient clipping before optimizer math.
            if (internalNet._currentGradClip)
              applyGradientClippingImpl(net, internalNet._currentGradClip);
            // Average accumulated micro-batch gradients if configured.
            if (
              accumulationSteps > 1 &&
              internalNet._accumulationReduction === 'average'
            ) {
              averageAccumulatedGradients(net, accumulationSteps);
            }
            // Apply optimizer updates and compute gradient norm.
            internalNet._lastGradNorm = applyOptimizerStep(
              net,
              optimizer,
              currentRate,
              momentum,
              internalNet
            );
            // Dynamic loss scaling increase if conditions satisfied.
            if (internalNet._mixedPrecision.enabled)
              maybeIncreaseLossScale(internalNet);
          }
        }
        // NOTE(review): this reset sits inside the adaptive-optimizer branch,
        // so in pure SGD mode batchSampleCount is never reset and grows across
        // the epoch. Appears harmless today (the flush block is otherwise a
        // no-op for SGD) — confirm whether the reset was meant to be one level up.
        batchSampleCount = 0; // reset mini-batch sample counter
      }
    }
  }
  // Guarantee the diagnostic field exists even when no optimizer step ran.
  if (internalNet._lastGradNorm == null) internalNet._lastGradNorm = 0;
  return totalProcessedSamples > 0
    ? cumulativeError / totalProcessedSamples
    : 0;
}
|
|
776
|
+
|
|
777
|
+
/**
|
|
778
|
+
* High-level training orchestration with early stopping, smoothing & callbacks.
|
|
779
|
+
*/
|
|
780
|
+
export function trainImpl(
|
|
781
|
+
net: Network,
|
|
782
|
+
set: { input: number[]; output: number[] }[],
|
|
783
|
+
options: TrainingOptions
|
|
784
|
+
): { error: number; iterations: number; time: number } {
|
|
785
|
+
const internalNet = net as any;
|
|
786
|
+
if (
|
|
787
|
+
!set ||
|
|
788
|
+
set.length === 0 ||
|
|
789
|
+
set[0].input.length !== net.input ||
|
|
790
|
+
set[0].output.length !== net.output
|
|
791
|
+
) {
|
|
792
|
+
throw new Error(
|
|
793
|
+
'Dataset is invalid or dimensions do not match network input/output size!'
|
|
794
|
+
);
|
|
795
|
+
}
|
|
796
|
+
options = options || {};
|
|
797
|
+
if (
|
|
798
|
+
typeof options.iterations === 'undefined' &&
|
|
799
|
+
typeof options.error === 'undefined'
|
|
800
|
+
) {
|
|
801
|
+
if (config.warnings)
|
|
802
|
+
console.warn('Missing `iterations` or `error` option.');
|
|
803
|
+
throw new Error(
|
|
804
|
+
'Missing `iterations` or `error` option. Training requires a stopping condition.'
|
|
805
|
+
);
|
|
806
|
+
}
|
|
807
|
+
if (config.warnings) {
|
|
808
|
+
if (typeof options.rate === 'undefined') {
|
|
809
|
+
console.warn('Missing `rate` option');
|
|
810
|
+
console.warn('Missing `rate` option, using default learning rate 0.3.');
|
|
811
|
+
}
|
|
812
|
+
if (typeof options.iterations === 'undefined')
|
|
813
|
+
console.warn(
|
|
814
|
+
'Missing `iterations` option. Training will run potentially indefinitely until `error` threshold is met.'
|
|
815
|
+
);
|
|
816
|
+
}
|
|
817
|
+
/** Target monitored (smoothed) error threshold for early termination. */
|
|
818
|
+
let targetError = options.error ?? -Infinity;
|
|
819
|
+
/** Cost function (defaults to MSE) resolved from provided variant. */
|
|
820
|
+
const cost = options.cost || methods.Cost.mse;
|
|
821
|
+
if (
|
|
822
|
+
typeof cost !== 'function' &&
|
|
823
|
+
!(
|
|
824
|
+
typeof cost === 'object' &&
|
|
825
|
+
(typeof (cost as any).fn === 'function' ||
|
|
826
|
+
typeof (cost as any).calculate === 'function')
|
|
827
|
+
)
|
|
828
|
+
) {
|
|
829
|
+
throw new Error('Invalid cost function provided to Network.train.');
|
|
830
|
+
}
|
|
831
|
+
/** Base learning rate used as scaling factor for optimizer weight updates. */
|
|
832
|
+
const baseRate = options.rate ?? 0.3;
|
|
833
|
+
/** Dropout probability applied each forward pass (0 disables). */
|
|
834
|
+
const dropout = options.dropout || 0;
|
|
835
|
+
if (dropout < 0 || dropout >= 1) throw new Error('dropout must be in [0,1)');
|
|
836
|
+
/** Momentum factor for SGD or reused by optimizers expecting momentum param. */
|
|
837
|
+
const momentum = options.momentum || 0;
|
|
838
|
+
/** Mini-batch size (#samples per gradient accumulation flush). */
|
|
839
|
+
const batchSize = options.batchSize || 1;
|
|
840
|
+
if (batchSize > set.length)
|
|
841
|
+
throw new Error('Batch size cannot be larger than the dataset length.');
|
|
842
|
+
/** Gradient accumulation factor (micro-batches per optimizer step). */
|
|
843
|
+
const accumulationSteps = options.accumulationSteps || 1;
|
|
844
|
+
internalNet._accumulationReduction =
|
|
845
|
+
options.accumulationReduction === 'sum' ? 'sum' : 'average';
|
|
846
|
+
if (accumulationSteps < 1 || !Number.isFinite(accumulationSteps))
|
|
847
|
+
throw new Error('accumulationSteps must be >=1');
|
|
848
|
+
if (options.gradientClip) {
|
|
849
|
+
const gc = options.gradientClip;
|
|
850
|
+
if (gc.mode)
|
|
851
|
+
internalNet._currentGradClip = {
|
|
852
|
+
mode: gc.mode,
|
|
853
|
+
maxNorm: gc.maxNorm,
|
|
854
|
+
percentile: gc.percentile,
|
|
855
|
+
} as any;
|
|
856
|
+
else if (typeof gc.maxNorm === 'number')
|
|
857
|
+
internalNet._currentGradClip = { mode: 'norm', maxNorm: gc.maxNorm };
|
|
858
|
+
else if (typeof gc.percentile === 'number')
|
|
859
|
+
internalNet._currentGradClip = {
|
|
860
|
+
mode: 'percentile',
|
|
861
|
+
percentile: gc.percentile,
|
|
862
|
+
} as any;
|
|
863
|
+
internalNet._gradClipSeparateBias = !!gc.separateBias;
|
|
864
|
+
} else {
|
|
865
|
+
internalNet._currentGradClip = undefined;
|
|
866
|
+
internalNet._gradClipSeparateBias = false;
|
|
867
|
+
}
|
|
868
|
+
if (options.mixedPrecision) {
|
|
869
|
+
const mp =
|
|
870
|
+
options.mixedPrecision === true
|
|
871
|
+
? { lossScale: 1024 }
|
|
872
|
+
: options.mixedPrecision;
|
|
873
|
+
internalNet._mixedPrecision.enabled = true;
|
|
874
|
+
internalNet._mixedPrecision.lossScale = mp.lossScale || 1024;
|
|
875
|
+
const dyn = mp.dynamic || {};
|
|
876
|
+
internalNet._mixedPrecisionState.minLossScale = dyn.minScale || 1;
|
|
877
|
+
internalNet._mixedPrecisionState.maxLossScale = dyn.maxScale || 65536;
|
|
878
|
+
internalNet._mpIncreaseEvery =
|
|
879
|
+
dyn.increaseEvery || dyn.stableStepsForIncrease || 200;
|
|
880
|
+
net.connections.forEach((c) => {
|
|
881
|
+
(c as any)._fp32Weight = c.weight;
|
|
882
|
+
});
|
|
883
|
+
net.nodes.forEach((n) => {
|
|
884
|
+
if (n.type !== 'input') (n as any)._fp32Bias = n.bias;
|
|
885
|
+
});
|
|
886
|
+
} else {
|
|
887
|
+
internalNet._mixedPrecision.enabled = false;
|
|
888
|
+
internalNet._mixedPrecision.lossScale = 1;
|
|
889
|
+
internalNet._mpIncreaseEvery = 200;
|
|
890
|
+
}
|
|
891
|
+
/** Supported optimizer algorithm identifiers (lowercased). */
|
|
892
|
+
const allowedOptimizers = new Set([
|
|
893
|
+
'sgd',
|
|
894
|
+
'rmsprop',
|
|
895
|
+
'adagrad',
|
|
896
|
+
'adam',
|
|
897
|
+
'adamw',
|
|
898
|
+
'amsgrad',
|
|
899
|
+
'adamax',
|
|
900
|
+
'nadam',
|
|
901
|
+
'radam',
|
|
902
|
+
'lion',
|
|
903
|
+
'adabelief',
|
|
904
|
+
'lookahead',
|
|
905
|
+
]);
|
|
906
|
+
/** Normalized optimizer configuration or undefined for pure SGD mode. */
|
|
907
|
+
let optimizerConfig: any = undefined;
|
|
908
|
+
if (typeof options.optimizer !== 'undefined') {
|
|
909
|
+
if (typeof options.optimizer === 'string')
|
|
910
|
+
optimizerConfig = { type: options.optimizer.toLowerCase() };
|
|
911
|
+
else if (
|
|
912
|
+
typeof options.optimizer === 'object' &&
|
|
913
|
+
options.optimizer !== null
|
|
914
|
+
) {
|
|
915
|
+
optimizerConfig = { ...options.optimizer };
|
|
916
|
+
if (typeof optimizerConfig.type === 'string')
|
|
917
|
+
optimizerConfig.type = optimizerConfig.type.toLowerCase();
|
|
918
|
+
} else
|
|
919
|
+
throw new Error('Invalid optimizer option; must be string or object');
|
|
920
|
+
if (!allowedOptimizers.has(optimizerConfig.type))
|
|
921
|
+
throw new Error(`Unknown optimizer type: ${optimizerConfig.type}`);
|
|
922
|
+
if (optimizerConfig.type === 'lookahead') {
|
|
923
|
+
if (!optimizerConfig.baseType) optimizerConfig.baseType = 'adam';
|
|
924
|
+
if (optimizerConfig.baseType === 'lookahead')
|
|
925
|
+
throw new Error(
|
|
926
|
+
'Nested lookahead (baseType lookahead) is not supported'
|
|
927
|
+
);
|
|
928
|
+
if (!allowedOptimizers.has(optimizerConfig.baseType))
|
|
929
|
+
throw new Error(
|
|
930
|
+
`Unknown baseType for lookahead: ${optimizerConfig.baseType}`
|
|
931
|
+
);
|
|
932
|
+
optimizerConfig.la_k = optimizerConfig.la_k || 5;
|
|
933
|
+
optimizerConfig.la_alpha = optimizerConfig.la_alpha ?? 0.5;
|
|
934
|
+
}
|
|
935
|
+
}
|
|
936
|
+
/** Maximum training iterations permitted (guard against infinite loops w/ only error criterion). */
|
|
937
|
+
const iterations = options.iterations ?? Number.MAX_SAFE_INTEGER;
|
|
938
|
+
/** Wall-clock start time for duration metric. */
|
|
939
|
+
const start = Date.now();
|
|
940
|
+
/** Most recent monitored (smoothed) error value. */
|
|
941
|
+
let finalError = Infinity;
|
|
942
|
+
/** Window length for primary moving average smoothing. */
|
|
943
|
+
const movingAverageWindow = Math.max(1, options.movingAverageWindow || 1);
|
|
944
|
+
/** Selected smoothing algorithm kind. */
|
|
945
|
+
const movingAverageType = options.movingAverageType || 'sma';
|
|
946
|
+
/** EMA alpha (if EMA selected) computed via CMA formula unless explicitly overridden. */
|
|
947
|
+
const emaAlpha = (() => {
|
|
948
|
+
if (movingAverageType !== 'ema') return undefined;
|
|
949
|
+
if (options.emaAlpha && options.emaAlpha > 0 && options.emaAlpha <= 1)
|
|
950
|
+
return options.emaAlpha;
|
|
951
|
+
return 2 / (movingAverageWindow + 1);
|
|
952
|
+
})();
|
|
953
|
+
/** Separate window for plateau detection (defaults to primary window). */
|
|
954
|
+
const plateauWindow = Math.max(
|
|
955
|
+
1,
|
|
956
|
+
options.plateauMovingAverageWindow || movingAverageWindow
|
|
957
|
+
);
|
|
958
|
+
/** Smoothing algorithm used specifically for plateau (scheduler / early-stop) metrics. */
|
|
959
|
+
const plateauType = options.plateauMovingAverageType || movingAverageType;
|
|
960
|
+
/** EMA alpha for plateau smoothing if needed. */
|
|
961
|
+
const plateauEmaAlpha = (() => {
|
|
962
|
+
if (plateauType !== 'ema') return undefined;
|
|
963
|
+
if (
|
|
964
|
+
options.plateauEmaAlpha &&
|
|
965
|
+
options.plateauEmaAlpha > 0 &&
|
|
966
|
+
options.plateauEmaAlpha <= 1
|
|
967
|
+
)
|
|
968
|
+
return options.plateauEmaAlpha;
|
|
969
|
+
return 2 / (plateauWindow + 1);
|
|
970
|
+
})();
|
|
971
|
+
/** Max consecutive non-improving iterations tolerated before early stop (undefined => disabled). */
|
|
972
|
+
const earlyStopPatience = options.earlyStopPatience;
|
|
973
|
+
/** Minimal decrease required to qualify as improvement. */
|
|
974
|
+
const earlyStopMinDelta = options.earlyStopMinDelta || 0;
|
|
975
|
+
/** Best (lowest) monitored error observed so far. */
|
|
976
|
+
let bestError = Infinity;
|
|
977
|
+
/** Count of successive iterations without sufficient improvement. */
|
|
978
|
+
let noImproveCount = 0;
|
|
979
|
+
/** Capacity of circular buffer for recent errors. */
|
|
980
|
+
const recentErrorsCapacity = movingAverageWindow;
|
|
981
|
+
/** Circular buffer holding recent raw training errors (for smoothing). */
|
|
982
|
+
const recentErrorsBuf: number[] = new Array(recentErrorsCapacity);
|
|
983
|
+
/** Current number of valid entries in buffer (grows until capacity). */
|
|
984
|
+
let recentErrorsCount = 0;
|
|
985
|
+
/** Next write index within circular buffer. */
|
|
986
|
+
let recentErrorsWriteIdx = 0;
|
|
987
|
+
/** Push a new error value into circular buffer (overwriting oldest when full). */
|
|
988
|
+
const recentErrorsPush = (value: number) => {
|
|
989
|
+
if (recentErrorsCapacity === 1) {
|
|
990
|
+
recentErrorsBuf[0] = value;
|
|
991
|
+
recentErrorsCount = 1;
|
|
992
|
+
recentErrorsWriteIdx = 0;
|
|
993
|
+
return;
|
|
994
|
+
}
|
|
995
|
+
recentErrorsBuf[recentErrorsWriteIdx] = value;
|
|
996
|
+
recentErrorsWriteIdx = (recentErrorsWriteIdx + 1) % recentErrorsCapacity;
|
|
997
|
+
if (recentErrorsCount < recentErrorsCapacity) recentErrorsCount++;
|
|
998
|
+
};
|
|
999
|
+
/** Produce chronologically ordered snapshot of buffered errors. */
|
|
1000
|
+
const recentErrorsChrono = (): number[] => {
|
|
1001
|
+
if (recentErrorsCount === 0) return [];
|
|
1002
|
+
if (recentErrorsCount < recentErrorsCapacity)
|
|
1003
|
+
return recentErrorsBuf.slice(0, recentErrorsCount);
|
|
1004
|
+
const out = new Array(recentErrorsCount);
|
|
1005
|
+
const start = recentErrorsWriteIdx;
|
|
1006
|
+
for (let i = 0; i < recentErrorsCount; i++)
|
|
1007
|
+
out[i] = recentErrorsBuf[(start + i) % recentErrorsCapacity];
|
|
1008
|
+
return out;
|
|
1009
|
+
};
|
|
1010
|
+
/** Exponential moving average state for classic EMA smoothing. */
|
|
1011
|
+
let emaValue: number | undefined = undefined;
|
|
1012
|
+
/** Base EMA state for adaptive EMA (lower variance baseline). */
|
|
1013
|
+
let adaptiveBaseEmaValue: number | undefined = undefined;
|
|
1014
|
+
/** Adaptive EMA state (higher alpha when volatility detected). */
|
|
1015
|
+
let adaptiveEmaValue: number | undefined = undefined;
|
|
1016
|
+
/** Capacity of plateau circular buffer. */
|
|
1017
|
+
const plateauCapacity = plateauWindow;
|
|
1018
|
+
/** Raw errors buffer for plateau smoothing. */
|
|
1019
|
+
const plateauBuf: number[] = new Array(plateauCapacity);
|
|
1020
|
+
/** Current number of plateau entries filled. */
|
|
1021
|
+
let plateauCount = 0;
|
|
1022
|
+
/** Next write index for plateau buffer. */
|
|
1023
|
+
let plateauWriteIdx = 0;
|
|
1024
|
+
/** Insert new training error into plateau buffer. */
|
|
1025
|
+
const plateauPush = (value: number) => {
|
|
1026
|
+
if (plateauCapacity === 1) {
|
|
1027
|
+
plateauBuf[0] = value;
|
|
1028
|
+
plateauCount = 1;
|
|
1029
|
+
plateauWriteIdx = 0;
|
|
1030
|
+
return;
|
|
1031
|
+
}
|
|
1032
|
+
plateauBuf[plateauWriteIdx] = value;
|
|
1033
|
+
plateauWriteIdx = (plateauWriteIdx + 1) % plateauCapacity;
|
|
1034
|
+
if (plateauCount < plateauCapacity) plateauCount++;
|
|
1035
|
+
};
|
|
1036
|
+
/** Chronologically ordered plateau buffer snapshot. */
|
|
1037
|
+
const plateauChrono = (): number[] => {
|
|
1038
|
+
if (plateauCount === 0) return [];
|
|
1039
|
+
if (plateauCount < plateauCapacity)
|
|
1040
|
+
return plateauBuf.slice(0, plateauCount);
|
|
1041
|
+
const out = new Array(plateauCount);
|
|
1042
|
+
const start = plateauWriteIdx;
|
|
1043
|
+
for (let i = 0; i < plateauCount; i++)
|
|
1044
|
+
out[i] = plateauBuf[(start + i) % plateauCapacity];
|
|
1045
|
+
return out;
|
|
1046
|
+
};
|
|
1047
|
+
  /**
   * Plateau-specific EMA state (only used when plateauType === 'ema').
   * Seeded lazily by the first raw training error.
   */
  let plateauEmaValue: number | undefined = undefined;
  /** Mutate network dropout probability for upcoming epoch iterations. */
  net.dropout = dropout;
  /** Number of iterations actually executed (in case of early stopping). */
  let performedIterations = 0;
  for (let iter = 1; iter <= iterations; iter++) {
    // -----------------------------
    // Iteration prologue
    // -----------------------------
    // 'iter' is 1-based to align with common optimizer bias-correction formulae (Adam etc.).
    // Optional structural pruning hook; guarded because it only exists on some networks.
    if ((net as any)._maybePrune) {
      (net as any)._maybePrune((internalNet._globalEpoch || 0) + iter);
    }
    // Run one epoch pass over dataset (mini-batching handled internally) and obtain raw mean error.
    const trainError = trainSetImpl(
      net,
      set,
      batchSize,
      accumulationSteps,
      baseRate,
      momentum,
      {},
      cost as any,
      optimizerConfig
    );
    // Record that this iteration was fully executed (used if we early break afterwards).
    performedIterations = iter;
    // Push raw error into smoothing buffer(s) for subsequent moving-average computation.
    recentErrorsPush(trainError);
    /** Monitored error value after smoothing strategy is applied (initially raw). */
    let monitored = trainError;
    // -----------------------------
    // Primary moving-average smoothing block
    // -----------------------------
    // Conditions: apply if window > 1 or a strategy that inherently disregards window size (ema/adaptive).
    if (
      movingAverageWindow > 1 ||
      movingAverageType === 'ema' ||
      movingAverageType === 'adaptive-ema'
    ) {
      const recentArr = recentErrorsChrono();
      if (movingAverageType === 'median') {
        // Robust central tendency; reduces influence of transient spikes.
        const sorted = [...recentArr].sort((a, b) => a - b);
        const mid = Math.floor(sorted.length / 2); // middle index
        // Even length: average the two central samples; odd: take the middle one.
        monitored =
          sorted.length % 2 ? sorted[mid] : (sorted[mid - 1] + sorted[mid]) / 2;
      } else if (movingAverageType === 'ema') {
        // Classic exponentially weighted moving average (constant alpha);
        // seeded with the first raw error.
        if (emaValue == null) emaValue = trainError;
        else emaValue = emaValue + emaAlpha! * (trainError - emaValue);
        monitored = emaValue;
      } else if (movingAverageType === 'adaptive-ema') {
        // Dual EMA: baseline + adaptive alpha that expands under variance to speed reaction, then we keep min.
        const mean = recentArr.reduce((a, b) => a + b, 0) / recentArr.length;
        const variance =
          recentArr.reduce((a, b) => a + (b - mean) * (b - mean), 0) /
          recentArr.length;
        const baseAlpha = emaAlpha || 2 / (movingAverageWindow + 1);
        // Normalize variance by mean^2 (coefficient-of-variation style); epsilon avoids div-by-zero.
        const varScaled = variance / Math.max(mean * mean, 1e-8);
        // Inflate alpha with volatility, clamped to [baseAlpha, 0.95].
        const adaptAlpha = Math.min(
          0.95,
          Math.max(baseAlpha, baseAlpha * (1 + 2 * varScaled))
        );
        if (adaptiveBaseEmaValue == null) {
          // First sample seeds both tracks.
          adaptiveBaseEmaValue = trainError;
          adaptiveEmaValue = trainError;
        } else {
          adaptiveBaseEmaValue =
            adaptiveBaseEmaValue +
            baseAlpha * (trainError - adaptiveBaseEmaValue);
          adaptiveEmaValue =
            adaptiveEmaValue! + adaptAlpha * (trainError - adaptiveEmaValue!);
        }
        // Keep the more optimistic (lower) of the two tracks as the monitored value.
        monitored = Math.min(adaptiveEmaValue!, adaptiveBaseEmaValue!);
      } else if (movingAverageType === 'gaussian') {
        // Weighted by Gaussian kernel centered at newest point; older (earlier) points get progressively less weight.
        const gaussianWindow = recentArr;
        const windowLength = gaussianWindow.length;
        const sigma = movingAverageWindow / 3 || 1; // heuristic: cover window with ~3 sigma
        let gaussianWeightSum = 0;
        let gaussianWeightedAccumulator = 0;
        for (let gi = 0; gi < windowLength; gi++) {
          // Distance measured from the newest (last) sample.
          const weight = Math.exp(
            -0.5 * Math.pow((gi - (windowLength - 1)) / sigma, 2)
          );
          gaussianWeightSum += weight;
          gaussianWeightedAccumulator += weight * gaussianWindow[gi];
        }
        // Guard against a zero weight sum (degenerate window).
        monitored = gaussianWeightedAccumulator / (gaussianWeightSum || 1);
      } else if (movingAverageType === 'trimmed') {
        // Trim symmetrical tails to damp outliers before averaging.
        // Ratio clamped to [0, 0.49] so at least one sample always survives.
        const tailTrimRatio = Math.min(
          0.49,
          Math.max(0, options.trimmedRatio || 0.1)
        );
        const sorted = [...recentArr].sort((a, b) => a - b);
        const elementsToDropEachSide = Math.floor(
          sorted.length * tailTrimRatio
        );
        const trimmedSegment = sorted.slice(
          elementsToDropEachSide,
          sorted.length - elementsToDropEachSide
        );
        monitored =
          trimmedSegment.reduce((a, b) => a + b, 0) /
          (trimmedSegment.length || 1);
      } else if (movingAverageType === 'wma') {
        // Linear weights: newer samples more influential.
        let linearWeightSum = 0;
        let linearWeightedAccumulator = 0;
        for (let li = 0; li < recentArr.length; li++) {
          const weight = li + 1; // oldest gets 1, newest gets N
          linearWeightSum += weight;
          linearWeightedAccumulator += weight * recentArr[li];
        }
        monitored = linearWeightedAccumulator / (linearWeightSum || 1);
      } else {
        // Simple arithmetic mean (SMA).
        monitored = recentArr.reduce((a, b) => a + b, 0) / recentArr.length;
      }
    }
    // Update finalError with the smoothed/selected monitored metric.
    finalError = monitored;
    // Store raw trainError (not smoothed) for plateau evaluation buffer.
    plateauPush(trainError);
    /** Plateau-smoothed error (could use different smoothing strategy than monitored). */
    let plateauError: number | undefined = trainError;
    if (plateauWindow > 1 || plateauType === 'ema') {
      if (plateauType === 'median') {
        // Median for plateau stability over variable noise.
        const sorted = [...plateauChrono()].sort((a, b) => a - b);
        const mid = Math.floor(sorted.length / 2);
        plateauError =
          sorted.length % 2 ? sorted[mid] : (sorted[mid - 1] + sorted[mid]) / 2;
      } else if (plateauType === 'ema') {
        // EMA variant for plateau detection (faster adaptation with controlled lag);
        // seeded with the first raw error.
        if (plateauEmaValue == null) plateauEmaValue = trainError;
        else
          plateauEmaValue =
            plateauEmaValue + plateauEmaAlpha! * (trainError - plateauEmaValue);
        plateauError = plateauEmaValue;
      } else {
        // Default plateau = arithmetic mean over plateau window.
        const arr = plateauChrono();
        plateauError = arr.reduce((a, b) => a + b, 0) / arr.length;
      }
    }
    if (typeof options.metricsHook === 'function') {
      try {
        // User hook for live metrics logging / dashboards / adaptive schedulers.
        // Errors thrown by user code are deliberately swallowed so training continues.
        options.metricsHook({
          iteration: iter,
          error: finalError,
          plateauError,
          gradNorm: internalNet._lastGradNorm ?? 0,
        });
      } catch {}
    }
    if (options.checkpoint && typeof options.checkpoint.save === 'function') {
      if (options.checkpoint.last) {
        try {
          // Always save most recent network state.
          // Save failures are non-fatal (best-effort persistence).
          options.checkpoint.save({
            type: 'last',
            iteration: iter,
            error: finalError,
            network: net.toJSON(),
          });
        } catch {}
      }
      if (options.checkpoint.best) {
        // Null check second so the very first iteration also qualifies as "best".
        if (
          finalError < (net as any)._checkpointBestError ||
          (net as any)._checkpointBestError == null
        ) {
          // New best model discovered under monitored error metric.
          (net as any)._checkpointBestError = finalError;
          try {
            options.checkpoint.save({
              type: 'best',
              iteration: iter,
              error: finalError,
              network: net.toJSON(),
            });
          } catch {}
        }
      }
    }
    if (
      options.schedule &&
      options.schedule.iterations &&
      iter % options.schedule.iterations === 0
    ) {
      try {
        // Periodic user-defined callback (e.g., adjust LR, print status, inject curriculum changes).
        options.schedule.function({ error: finalError, iteration: iter });
      } catch {}
    }
    // -----------------------------
    // Early stopping logic
    // -----------------------------
    if (finalError < bestError - earlyStopMinDelta) {
      // Sufficient improvement: update best and reset stagnation counter.
      bestError = finalError;
      noImproveCount = 0;
    } else if (earlyStopPatience) {
      // Track consecutive non-improving iterations.
      noImproveCount++;
    }
    // Patience exhaustion: terminate.
    if (earlyStopPatience && noImproveCount >= earlyStopPatience) break;
    // Target error reached: terminate.
    if (finalError <= targetError) break;
  }
  // Restore full connectivity: clear any dropout masks left on hidden nodes.
  net.nodes.forEach((n) => {
    if (n.type === 'hidden') n.mask = 1;
  });
  // Clear dropout for inference after training completes.
  net.dropout = 0;
  // Advance the cumulative epoch counter by the iterations actually run.
  internalNet._globalEpoch =
    (internalNet._globalEpoch || 0) + performedIterations;
  return {
    /** Final monitored (possibly smoothed) error achieved at termination. */
    error: finalError,
    /** Number of iterations actually executed (could be < requested iterations due to early stop). */
    iterations: performedIterations,
    /** Wall-clock training duration in milliseconds. */
    time: Date.now() - start,
  };
}
|