@reicek/neataptic-ts 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.github/ISSUE_TEMPLATE/bug_report.md +33 -0
- package/.github/ISSUE_TEMPLATE/feature_request.md +27 -0
- package/.github/PULL_REQUEST_TEMPLATE.md +28 -0
- package/.github/workflows/ci.yml +41 -0
- package/.github/workflows/deploy-pages.yml +29 -0
- package/.github/workflows/manual_release_pipeline.yml +62 -0
- package/.github/workflows/publish.yml +85 -0
- package/.github/workflows/release_dispatch.yml +38 -0
- package/.travis.yml +5 -0
- package/CONTRIBUTING.md +92 -0
- package/LICENSE +24 -0
- package/ONNX_EXPORT.md +87 -0
- package/README.md +1173 -0
- package/RELEASE.md +54 -0
- package/dist-docs/package.json +1 -0
- package/dist-docs/scripts/generate-docs.d.ts +2 -0
- package/dist-docs/scripts/generate-docs.d.ts.map +1 -0
- package/dist-docs/scripts/generate-docs.js +536 -0
- package/dist-docs/scripts/generate-docs.js.map +1 -0
- package/dist-docs/scripts/render-docs-html.d.ts +2 -0
- package/dist-docs/scripts/render-docs-html.d.ts.map +1 -0
- package/dist-docs/scripts/render-docs-html.js +148 -0
- package/dist-docs/scripts/render-docs-html.js.map +1 -0
- package/docs/FOLDERS.md +14 -0
- package/docs/README.md +1173 -0
- package/docs/architecture/README.md +1391 -0
- package/docs/architecture/index.html +938 -0
- package/docs/architecture/network/README.md +1210 -0
- package/docs/architecture/network/index.html +908 -0
- package/docs/assets/ascii-maze.bundle.js +16542 -0
- package/docs/assets/ascii-maze.bundle.js.map +7 -0
- package/docs/index.html +1419 -0
- package/docs/methods/README.md +670 -0
- package/docs/methods/index.html +477 -0
- package/docs/multithreading/README.md +274 -0
- package/docs/multithreading/index.html +215 -0
- package/docs/multithreading/workers/README.md +23 -0
- package/docs/multithreading/workers/browser/README.md +39 -0
- package/docs/multithreading/workers/browser/index.html +70 -0
- package/docs/multithreading/workers/index.html +57 -0
- package/docs/multithreading/workers/node/README.md +33 -0
- package/docs/multithreading/workers/node/index.html +66 -0
- package/docs/neat/README.md +1284 -0
- package/docs/neat/index.html +906 -0
- package/docs/src/README.md +2659 -0
- package/docs/src/index.html +1579 -0
- package/jest.config.ts +32 -0
- package/package.json +99 -0
- package/plans/HyperMorphoNEAT.md +293 -0
- package/plans/ONNX_EXPORT_PLAN.md +46 -0
- package/scripts/generate-docs.ts +486 -0
- package/scripts/render-docs-html.ts +138 -0
- package/scripts/types.d.ts +2 -0
- package/src/README.md +2659 -0
- package/src/architecture/README.md +1391 -0
- package/src/architecture/activationArrayPool.ts +135 -0
- package/src/architecture/architect.ts +635 -0
- package/src/architecture/connection.ts +148 -0
- package/src/architecture/group.ts +406 -0
- package/src/architecture/layer.ts +804 -0
- package/src/architecture/network/README.md +1210 -0
- package/src/architecture/network/network.activate.ts +223 -0
- package/src/architecture/network/network.connect.ts +157 -0
- package/src/architecture/network/network.deterministic.ts +167 -0
- package/src/architecture/network/network.evolve.ts +426 -0
- package/src/architecture/network/network.gating.ts +186 -0
- package/src/architecture/network/network.genetic.ts +247 -0
- package/src/architecture/network/network.mutate.ts +624 -0
- package/src/architecture/network/network.onnx.ts +463 -0
- package/src/architecture/network/network.prune.ts +216 -0
- package/src/architecture/network/network.remove.ts +96 -0
- package/src/architecture/network/network.serialize.ts +309 -0
- package/src/architecture/network/network.slab.ts +262 -0
- package/src/architecture/network/network.standalone.ts +246 -0
- package/src/architecture/network/network.stats.ts +59 -0
- package/src/architecture/network/network.topology.ts +86 -0
- package/src/architecture/network/network.training.ts +1278 -0
- package/src/architecture/network.ts +1302 -0
- package/src/architecture/node.ts +1288 -0
- package/src/architecture/onnx.ts +3 -0
- package/src/config.ts +83 -0
- package/src/methods/README.md +670 -0
- package/src/methods/activation.ts +372 -0
- package/src/methods/connection.ts +31 -0
- package/src/methods/cost.ts +347 -0
- package/src/methods/crossover.ts +63 -0
- package/src/methods/gating.ts +43 -0
- package/src/methods/methods.ts +8 -0
- package/src/methods/mutation.ts +300 -0
- package/src/methods/rate.ts +257 -0
- package/src/methods/selection.ts +65 -0
- package/src/multithreading/README.md +274 -0
- package/src/multithreading/multi.ts +339 -0
- package/src/multithreading/workers/README.md +23 -0
- package/src/multithreading/workers/browser/README.md +39 -0
- package/src/multithreading/workers/browser/testworker.ts +99 -0
- package/src/multithreading/workers/node/README.md +33 -0
- package/src/multithreading/workers/node/testworker.ts +72 -0
- package/src/multithreading/workers/node/worker.ts +70 -0
- package/src/multithreading/workers/workers.ts +22 -0
- package/src/neat/README.md +1284 -0
- package/src/neat/neat.adaptive.ts +544 -0
- package/src/neat/neat.compat.ts +164 -0
- package/src/neat/neat.constants.ts +20 -0
- package/src/neat/neat.diversity.ts +217 -0
- package/src/neat/neat.evaluate.ts +328 -0
- package/src/neat/neat.evolve.ts +1026 -0
- package/src/neat/neat.export.ts +249 -0
- package/src/neat/neat.helpers.ts +235 -0
- package/src/neat/neat.lineage.ts +220 -0
- package/src/neat/neat.multiobjective.ts +260 -0
- package/src/neat/neat.mutation.ts +718 -0
- package/src/neat/neat.objectives.ts +157 -0
- package/src/neat/neat.pruning.ts +190 -0
- package/src/neat/neat.selection.ts +269 -0
- package/src/neat/neat.speciation.ts +460 -0
- package/src/neat/neat.species.ts +151 -0
- package/src/neat/neat.telemetry.exports.ts +469 -0
- package/src/neat/neat.telemetry.ts +933 -0
- package/src/neat/neat.types.ts +275 -0
- package/src/neat.ts +1042 -0
- package/src/neataptic.ts +10 -0
- package/test/architecture/activationArrayPool.capacity.test.ts +19 -0
- package/test/architecture/activationArrayPool.test.ts +46 -0
- package/test/architecture/connection.test.ts +290 -0
- package/test/architecture/group.test.ts +950 -0
- package/test/architecture/layer.test.ts +1535 -0
- package/test/architecture/network.pruning.test.ts +65 -0
- package/test/architecture/node.test.ts +1602 -0
- package/test/examples/asciiMaze/asciiMaze.e2e.test.ts +499 -0
- package/test/examples/asciiMaze/asciiMaze.ts +41 -0
- package/test/examples/asciiMaze/browser-entry.ts +164 -0
- package/test/examples/asciiMaze/browserLogger.ts +221 -0
- package/test/examples/asciiMaze/browserTerminalUtility.ts +48 -0
- package/test/examples/asciiMaze/colors.ts +119 -0
- package/test/examples/asciiMaze/dashboardManager.ts +968 -0
- package/test/examples/asciiMaze/evolutionEngine.ts +1248 -0
- package/test/examples/asciiMaze/fitness.ts +136 -0
- package/test/examples/asciiMaze/index.html +128 -0
- package/test/examples/asciiMaze/index.ts +26 -0
- package/test/examples/asciiMaze/interfaces.ts +235 -0
- package/test/examples/asciiMaze/mazeMovement.ts +996 -0
- package/test/examples/asciiMaze/mazeUtils.ts +278 -0
- package/test/examples/asciiMaze/mazeVision.ts +402 -0
- package/test/examples/asciiMaze/mazeVisualization.ts +585 -0
- package/test/examples/asciiMaze/mazes.ts +245 -0
- package/test/examples/asciiMaze/networkRefinement.ts +76 -0
- package/test/examples/asciiMaze/networkVisualization.ts +901 -0
- package/test/examples/asciiMaze/terminalUtility.ts +73 -0
- package/test/methods/activation.test.ts +1142 -0
- package/test/methods/connection.test.ts +146 -0
- package/test/methods/cost.test.ts +1123 -0
- package/test/methods/crossover.test.ts +202 -0
- package/test/methods/gating.test.ts +144 -0
- package/test/methods/mutation.test.ts +451 -0
- package/test/methods/optimizers.advanced.test.ts +80 -0
- package/test/methods/optimizers.behavior.test.ts +105 -0
- package/test/methods/optimizers.formula.test.ts +89 -0
- package/test/methods/rate.cosineWarmRestarts.test.ts +44 -0
- package/test/methods/rate.linearWarmupDecay.test.ts +41 -0
- package/test/methods/rate.reduceOnPlateau.test.ts +45 -0
- package/test/methods/rate.test.ts +684 -0
- package/test/methods/selection.test.ts +245 -0
- package/test/multithreading/activations.functions.test.ts +54 -0
- package/test/multithreading/multi.test.ts +290 -0
- package/test/multithreading/worker.node.process.test.ts +39 -0
- package/test/multithreading/workers.coverage.test.ts +36 -0
- package/test/multithreading/workers.dynamic.import.test.ts +8 -0
- package/test/neat/neat.adaptive.complexityBudget.test.ts +34 -0
- package/test/neat/neat.adaptive.criterion.complexity.test.ts +50 -0
- package/test/neat/neat.adaptive.mutation.strategy.test.ts +37 -0
- package/test/neat/neat.adaptive.operator.decay.test.ts +31 -0
- package/test/neat/neat.adaptive.phasedComplexity.test.ts +25 -0
- package/test/neat/neat.adaptive.pruning.test.ts +25 -0
- package/test/neat/neat.adaptive.targetSpecies.test.ts +43 -0
- package/test/neat/neat.additional.coverage.test.ts +126 -0
- package/test/neat/neat.advanced.enhancements.test.ts +85 -0
- package/test/neat/neat.advanced.test.ts +589 -0
- package/test/neat/neat.diversity.autocompat.test.ts +47 -0
- package/test/neat/neat.diversity.metrics.test.ts +21 -0
- package/test/neat/neat.diversity.stats.test.ts +44 -0
- package/test/neat/neat.enhancements.test.ts +79 -0
- package/test/neat/neat.entropy.ancestorAdaptive.test.ts +133 -0
- package/test/neat/neat.entropy.compat.csv.test.ts +108 -0
- package/test/neat/neat.evolution.pruning.test.ts +39 -0
- package/test/neat/neat.fastmode.autotune.test.ts +42 -0
- package/test/neat/neat.innovation.test.ts +134 -0
- package/test/neat/neat.lineage.antibreeding.test.ts +35 -0
- package/test/neat/neat.lineage.entropy.test.ts +56 -0
- package/test/neat/neat.lineage.inbreeding.test.ts +49 -0
- package/test/neat/neat.lineage.pressure.test.ts +29 -0
- package/test/neat/neat.multiobjective.adaptive.test.ts +57 -0
- package/test/neat/neat.multiobjective.dynamic.schedule.test.ts +46 -0
- package/test/neat/neat.multiobjective.dynamic.test.ts +31 -0
- package/test/neat/neat.multiobjective.fastsort.delegation.test.ts +51 -0
- package/test/neat/neat.multiobjective.prune.test.ts +39 -0
- package/test/neat/neat.multiobjective.test.ts +21 -0
- package/test/neat/neat.mutation.undefined.pool.test.ts +24 -0
- package/test/neat/neat.objective.events.test.ts +26 -0
- package/test/neat/neat.objective.importance.test.ts +21 -0
- package/test/neat/neat.objective.lifetimes.test.ts +33 -0
- package/test/neat/neat.offspring.allocation.test.ts +22 -0
- package/test/neat/neat.operator.bandit.test.ts +17 -0
- package/test/neat/neat.operator.phases.test.ts +38 -0
- package/test/neat/neat.pruneInactive.behavior.test.ts +54 -0
- package/test/neat/neat.reenable.adaptation.test.ts +18 -0
- package/test/neat/neat.rng.state.test.ts +22 -0
- package/test/neat/neat.spawn.add.test.ts +123 -0
- package/test/neat/neat.speciation.test.ts +96 -0
- package/test/neat/neat.species.allocation.telemetry.test.ts +26 -0
- package/test/neat/neat.species.history.csv.test.ts +24 -0
- package/test/neat/neat.telemetry.advanced.test.ts +226 -0
- package/test/neat/neat.telemetry.csv.lineage.test.ts +19 -0
- package/test/neat/neat.telemetry.parity.test.ts +42 -0
- package/test/neat/neat.telemetry.stream.test.ts +19 -0
- package/test/neat/neat.telemetry.test.ts +16 -0
- package/test/neat/neat.test.ts +422 -0
- package/test/neat/neat.utilities.test.ts +44 -0
- package/test/network/__suppress_console.ts +9 -0
- package/test/network/acyclic.topoorder.test.ts +17 -0
- package/test/network/checkpoint.metricshook.test.ts +36 -0
- package/test/network/error.handling.test.ts +581 -0
- package/test/network/evolution.test.ts +285 -0
- package/test/network/genetic.test.ts +208 -0
- package/test/network/learning.capability.test.ts +244 -0
- package/test/network/mutation.effects.test.ts +492 -0
- package/test/network/network.activate.test.ts +115 -0
- package/test/network/network.activateBatch.test.ts +30 -0
- package/test/network/network.deterministic.test.ts +64 -0
- package/test/network/network.evolve.branches.test.ts +75 -0
- package/test/network/network.evolve.multithread.branches.test.ts +83 -0
- package/test/network/network.evolve.test.ts +100 -0
- package/test/network/network.gating.removal.test.ts +93 -0
- package/test/network/network.mutate.additional.test.ts +145 -0
- package/test/network/network.mutate.edgecases.test.ts +101 -0
- package/test/network/network.mutate.test.ts +101 -0
- package/test/network/network.prune.earlyexit.test.ts +38 -0
- package/test/network/network.remove.errors.test.ts +45 -0
- package/test/network/network.slab.fallbacks.test.ts +22 -0
- package/test/network/network.stats.test.ts +45 -0
- package/test/network/network.training.advanced.test.ts +149 -0
- package/test/network/network.training.basic.test.ts +228 -0
- package/test/network/network.training.helpers.test.ts +183 -0
- package/test/network/onnx.export.test.ts +310 -0
- package/test/network/onnx.import.test.ts +129 -0
- package/test/network/pruning.topology.test.ts +282 -0
- package/test/network/regularization.determinism.test.ts +83 -0
- package/test/network/regularization.dropconnect.test.ts +17 -0
- package/test/network/regularization.dropconnect.validation.test.ts +18 -0
- package/test/network/regularization.stochasticdepth.test.ts +27 -0
- package/test/network/regularization.test.ts +843 -0
- package/test/network/regularization.weightnoise.test.ts +30 -0
- package/test/network/setupTests.ts +2 -0
- package/test/network/standalone.test.ts +332 -0
- package/test/network/structure.serialization.test.ts +660 -0
- package/test/training/training.determinism.mixed-precision.test.ts +134 -0
- package/test/training/training.earlystopping.test.ts +91 -0
- package/test/training/training.edge-cases.test.ts +91 -0
- package/test/training/training.extensions.test.ts +47 -0
- package/test/training/training.gradient.features.test.ts +110 -0
- package/test/training/training.gradient.refinements.test.ts +170 -0
- package/test/training/training.gradient.separate-bias.test.ts +41 -0
- package/test/training/training.optimizer.test.ts +48 -0
- package/test/training/training.plateau.smoothing.test.ts +58 -0
- package/test/training/training.smoothing.types.test.ts +174 -0
- package/test/training/training.train.options.coverage.test.ts +52 -0
- package/test/utils/console-helper.ts +76 -0
- package/test/utils/jest-setup.ts +60 -0
- package/test/utils/test-helpers.ts +175 -0
- package/tsconfig.docs.json +12 -0
- package/tsconfig.json +21 -0
- package/webpack.config.js +49 -0
|
@@ -0,0 +1,1535 @@
|
|
|
1
|
+
import Layer from '../../src/architecture/layer';
|
|
2
|
+
import Node from '../../src/architecture/node';
|
|
3
|
+
import Group from '../../src/architecture/group';
|
|
4
|
+
import Connection from '../../src/architecture/connection';
|
|
5
|
+
import * as methods from '../../src/methods/methods';
|
|
6
|
+
import { config } from '../../src/config';
|
|
7
|
+
|
|
8
|
+
// Retry failed tests
|
|
9
|
+
jest.retryTimes(2, { logErrorsBeforeRetry: true });
|
|
10
|
+
|
|
11
|
+
// Helper function to check group connectivity
|
|
12
|
+
/**
 * Test helper: reports whether `groupA` is connected to `groupB`.
 *
 * - `ONE_TO_ONE`: both groups must have the same, non-zero number of nodes and
 *   every `groupA.nodes[i]` must connect to `groupB.nodes[i]`.
 * - Any other method (or none): true when at least one node of `groupA`
 *   connects to at least one node of `groupB`.
 *
 * A pair of nodes counts as connected when the first has an outgoing
 * connection to the second, or, when both are the same node, when it has a
 * self-connection.
 *
 * @param groupA - Source group.
 * @param groupB - Target group.
 * @param method - Optional connection method; only compared against
 *   `methods.groupConnection.ONE_TO_ONE`.
 * @returns Whether the groups are connected under `method`.
 */
const isGroupConnectedTo = (
  groupA: Group,
  groupB: Group,
  method?: any
): boolean => {
  if (!groupA?.nodes || !groupB?.nodes) return false; // Basic validation

  // Shared pair check used by both connection modes below.
  const isPairConnected = (nodeA: Node, nodeB: Node): boolean => {
    if (!nodeA || !nodeB) return false; // Node validation
    if (nodeA.connections.out.some((conn: Connection) => conn.to === nodeB))
      return true;
    return (
      nodeA === nodeB &&
      nodeA.connections.self.some((conn: Connection) => conn.to === nodeB)
    );
  };

  if (method === methods.groupConnection.ONE_TO_ONE) {
    // `every` is vacuously true on an empty array; treat empty groups as
    // unconnected, matching the `some`-based branch below.
    if (groupA.nodes.length === 0) return false;
    if (groupA.nodes.length !== groupB.nodes.length) return false;
    return groupA.nodes.every((nodeA, i) =>
      isPairConnected(nodeA, groupB.nodes[i])
    );
  }

  return groupA.nodes.some((nodeA) =>
    groupB.nodes.some((nodeB) => isPairConnected(nodeA, nodeB))
  );
};
|
|
50
|
+
|
|
51
|
+
describe('Layer', () => {
|
|
52
|
+
const epsilon = 1e-9; // Tolerance for float comparisons
|
|
53
|
+
|
|
54
|
+
describe('Constructor', () => {
  // Each assertion inspects a brand-new, untouched Layer; build one per test
  // instead of sharing a mutable binding through a beforeEach hook.
  const makeLayer = (): Layer => new Layer();

  it('should initialize output to null', () => {
    // Arrange & Act
    const fresh = makeLayer();
    // Assert
    expect(fresh.output).toBeNull();
  });

  it('should initialize nodes as an empty array', () => {
    // Arrange & Act
    const fresh = makeLayer();
    // Assert
    expect(fresh.nodes).toEqual([]);
  });

  it('should initialize connections.in as an empty array', () => {
    // Arrange & Act
    const { connections } = makeLayer();
    // Assert
    expect(connections.in).toEqual([]);
  });

  it('should initialize connections.out as an empty array', () => {
    // Arrange & Act
    const { connections } = makeLayer();
    // Assert
    expect(connections.out).toEqual([]);
  });

  it('should initialize connections.self as an empty array', () => {
    // Arrange & Act
    const { connections } = makeLayer();
    // Assert
    expect(connections.self).toEqual([]);
  });
});
|
|
86
|
+
|
|
87
|
+
/**
 * Builds a Layer fixture holding `size` freshly constructed nodes, with a new
 * `size`-node Group assigned as its output.
 *
 * NOTE(review): the output Group is constructed independently and is not
 * handed `layer.nodes`, so tests using this fixture exercise `layer.nodes`
 * and `layer.output` separately.
 *
 * @param size - Number of nodes to create, also used as the output group size.
 * @returns The populated Layer.
 */
const createTestLayer = (size: number): Layer => {
  const fixture = new Layer();
  let created = 0;
  while (created < size) {
    fixture.nodes.push(new Node());
    created += 1;
  }
  fixture.output = new Group(size);
  return fixture;
};
|
|
97
|
+
|
|
98
|
+
describe('Instance Methods', () => {
|
|
99
|
+
describe('activate()', () => {
  // Builds a bare Layer containing `count` nodes produced by `makeNode`.
  const layerOf = (count: number, makeNode: () => Node): Layer => {
    const built = new Layer();
    for (let n = 0; n < count; n += 1) {
      built.nodes.push(makeNode());
    }
    return built;
  };

  describe('Scenario: input layer', () => {
    it('should return activation values from input nodes with input', () => {
      // Arrange
      const size = 3;
      const inputValues = [0.5, -0.2, 0.9];
      const layer = layerOf(size, () => new Node('input'));
      // Act
      const activations = layer.activate(inputValues);
      // Assert: input nodes pass their input straight through.
      expect(activations).toHaveLength(size);
      activations.forEach((value, idx) => {
        expect(value).toBe(inputValues[idx]);
      });
    });
  });

  describe('Scenario: hidden layer', () => {
    it('should return activation values after applying activation function', () => {
      // Arrange
      const size = 3;
      const inputValues = [0.5, -0.2, 0.9];
      const layer = layerOf(size, () => new Node('hidden'));
      // Act
      const activations = layer.activate(inputValues);
      // Assert
      expect(activations).toHaveLength(size);
      activations.forEach((value, idx) => {
        // Default activation is sigmoid
        const expected = 1 / (1 + Math.exp(-inputValues[idx]));
        expect(value).toBeCloseTo(expected, 10);
      });
    });
  });
});
|
|
140
|
+
|
|
141
|
+
describe('propagate()', () => {
  const size = 3;
  const rate = 0.1;
  const momentum = 0.9;
  let layer: Layer;
  let nodeSpies: jest.SpyInstance[];

  beforeEach(() => {
    layer = createTestLayer(size);
    // Spies call through to Node.propagate unless a test stubs them;
    // nodeSpies[i] observes layer.nodes[i].
    nodeSpies = layer.nodes.map((node) => jest.spyOn(node, 'propagate'));
  });

  afterEach(() => {
    nodeSpies.forEach((spy) => spy.mockRestore());
  });

  it('should call propagate on all nodes without target values', () => {
    // Act
    layer.propagate(rate, momentum);
    // Assert
    expect(nodeSpies).toHaveLength(size);
    nodeSpies.forEach((spy) => {
      expect(spy).toHaveBeenCalledTimes(1);
      expect(spy).toHaveBeenCalledWith(rate, momentum, true, 0);
    });
  });

  it('should call propagate on nodes in reverse order', () => {
    // Arrange: stub each node's propagate to record its index instead.
    const callOrder: number[] = [];
    nodeSpies.forEach((spy, i) =>
      spy.mockImplementation(() => callOrder.push(i))
    );
    // Act
    layer.propagate(rate, momentum);
    // Assert: last node first, i.e. [size - 1, ..., 0].
    const expectedOrder = Array.from(
      { length: size },
      (_, i) => size - 1 - i
    );
    expect(callOrder).toEqual(expectedOrder);
  });

  it('should call propagate on all nodes with target values', () => {
    // Arrange
    const targetValues = [0.8, 0.1, 0.5];
    // Act
    layer.propagate(rate, momentum, targetValues);
    // Assert: node i receives targetValues[i]. Iterate the spies as-is
    // rather than reversing the shared `nodeSpies` array in place.
    expect(nodeSpies).toHaveLength(size);
    nodeSpies.forEach((spy, nodeIndex) => {
      expect(spy).toHaveBeenCalledTimes(1);
      expect(spy).toHaveBeenCalledWith(
        rate,
        momentum,
        true,
        0,
        targetValues[nodeIndex]
      );
    });
  });

  it('should throw error if target value array length mismatches layer node count', () => {
    // Arrange
    const invalidTarget = [0.1, 0.2];
    // Act & Assert
    expect(() => layer.propagate(rate, momentum, invalidTarget)).toThrow(
      'Array with values should be same as the amount of nodes!'
    );
  });
});
|
|
210
|
+
|
|
211
|
+
describe('connect()', () => {
  let sourceLayer: Layer;
  let targetLayer: Layer;
  let targetGroup: Group;
  let targetNode: Node;
  let sourceOutputSpy: jest.SpyInstance | undefined;
  let targetInputSpy: jest.SpyInstance;

  beforeEach(() => {
    // Fixtures: two 3-node layers, plus a 2-node group and a lone node as
    // alternative connection targets.
    sourceLayer = createTestLayer(3);
    targetLayer = createTestLayer(3);
    targetGroup = new Group(2);
    targetNode = new Node();

    // Call-through spies on the collaborators Layer.connect is expected to use.
    const sourceOutput = sourceLayer.output;
    if (sourceOutput) {
      sourceOutputSpy = jest.spyOn(sourceOutput, 'connect');
    }
    targetInputSpy = jest.spyOn(targetLayer, 'input');
  });

  afterEach(() => {
    sourceOutputSpy?.mockRestore();
    targetInputSpy?.mockRestore();
  });

  it('should throw error if source layer output is not defined', () => {
    // Arrange
    const detached = new Layer();
    const attempt = () => detached.connect(targetGroup);
    // Act & Assert
    expect(attempt).toThrow(
      'Layer output is not defined. Cannot connect from this layer.'
    );
  });

  it('should call output.connect when connecting to a Group', () => {
    // Arrange
    const method = methods.groupConnection.ALL_TO_ALL;
    const weight = 0.5;
    // Act
    sourceLayer.connect(targetGroup, method, weight);
    // Assert
    expect(sourceOutputSpy).toHaveBeenCalledTimes(1);
    expect(sourceOutputSpy).toHaveBeenCalledWith(targetGroup, method, weight);
  });

  it('should call output.connect when connecting to a Node', () => {
    // Arrange
    const weight = 0.6;
    // Act
    sourceLayer.connect(targetNode, undefined, weight);
    // Assert
    expect(sourceOutputSpy).toHaveBeenCalledTimes(1);
    expect(sourceOutputSpy).toHaveBeenCalledWith(targetNode, undefined, weight);
  });

  it('should call targetLayer.input when connecting to another Layer', () => {
    // Arrange
    const method = methods.groupConnection.ONE_TO_ONE;
    const weight = 0.7;
    // Act
    sourceLayer.connect(targetLayer, method, weight);
    // Assert: the call is routed through the target's input(), and the
    // source output is expected to end up connected to the target output.
    expect(targetInputSpy).toHaveBeenCalledTimes(1);
    expect(targetInputSpy).toHaveBeenCalledWith(sourceLayer, method, weight);
    expect(sourceOutputSpy).toHaveBeenCalledTimes(1);
    expect(sourceOutputSpy).toHaveBeenCalledWith(
      targetLayer.output,
      method,
      weight
    );
  });

  it('should return the connections created by output.connect (Group target)', () => {
    // Arrange
    const stubbed = [new Connection(new Node(), new Node())];
    sourceOutputSpy?.mockReturnValue(stubbed);
    // Act & Assert
    expect(sourceLayer.connect(targetGroup)).toBe(stubbed);
  });

  it('should return the connections created by targetLayer.input (Layer target)', () => {
    // Arrange
    const stubbed = [new Connection(new Node(), new Node())];
    targetInputSpy.mockReturnValue(stubbed);
    // Act & Assert
    expect(sourceLayer.connect(targetLayer)).toBe(stubbed);
  });
});
|
|
312
|
+
|
|
313
|
+
describe('gate()', () => {
  let layer: Layer;
  let connectionsToGate: Connection[];
  let outputGroupSpy: jest.SpyInstance;

  beforeEach(() => {
    // A 2-node layer plus one unrelated connection to hand to gate().
    layer = createTestLayer(2);
    connectionsToGate = [new Connection(new Node(), new Node())];

    const { output } = layer;
    if (output) {
      outputGroupSpy = jest.spyOn(output, 'gate');
    }
  });

  afterEach(() => {
    outputGroupSpy?.mockRestore();
  });

  it('should throw error if layer output is not defined', () => {
    // Arrange
    const detached = new Layer();
    const attempt = () =>
      detached.gate(connectionsToGate, methods.gating.INPUT);
    // Act & Assert
    expect(attempt).toThrow(
      'Layer output is not defined. Cannot gate from this layer.'
    );
  });

  it('should call output.gate with the provided connections and method', () => {
    // Arrange
    const method = methods.gating.OUTPUT;
    // Act
    layer.gate(connectionsToGate, method);
    // Assert: the layer forwards the call unchanged to its output group.
    expect(outputGroupSpy).toHaveBeenCalledTimes(1);
    expect(outputGroupSpy).toHaveBeenCalledWith(connectionsToGate, method);
  });
});
|
|
351
|
+
|
|
352
|
+
describe('set()', () => {
  const size = 3;
  let layer: Layer;

  beforeEach(() => {
    // Give every node a known starting state so tests can tell which
    // properties set() changed.
    layer = createTestLayer(size);
    layer.nodes.forEach((node, i) => {
      node.bias = i * 0.1;
      node.squash = methods.Activation.logistic;
      node.type = 'hidden';
    });
  });

  it('should set bias for all nodes', () => {
    // Arrange, Act
    const biasValue = 0.7;
    layer.set({ bias: biasValue });
    // Assert
    layer.nodes.forEach((node) => {
      expect(node.bias).toBe(biasValue);
    });
  });

  it('should set squash function for all nodes', () => {
    // Arrange, Act
    const squashFn = methods.Activation.relu;
    layer.set({ squash: squashFn });
    // Assert
    layer.nodes.forEach((node) => {
      expect(node.squash).toBe(squashFn);
    });
  });

  it('should set type for all nodes', () => {
    // Arrange, Act
    const typeValue = 'output';
    layer.set({ type: typeValue });
    // Assert
    layer.nodes.forEach((node) => {
      expect(node.type).toBe(typeValue);
    });
  });

  it('should set multiple properties at once', () => {
    // Arrange, Act
    const biasValue = -0.2;
    const squashFn = methods.Activation.tanh;
    const typeValue = 'input';
    layer.set({ bias: biasValue, squash: squashFn, type: typeValue });
    // Assert
    layer.nodes.forEach((node) => {
      expect(node.bias).toBe(biasValue);
      expect(node.squash).toBe(squashFn);
      expect(node.type).toBe(typeValue);
    });
  });

  it('should not change properties if not provided in values object', () => {
    // Arrange: snapshot only the properties that must survive untouched.
    const initialSquashes = layer.nodes.map((node) => node.squash);
    const initialTypes = layer.nodes.map((node) => node.type);
    // Act
    layer.set({ bias: 0.99 });
    // Assert
    layer.nodes.forEach((node, i) => {
      expect(node.bias).toBe(0.99);
      expect(node.squash).toBe(initialSquashes[i]);
      expect(node.type).toBe(initialTypes[i]);
    });
  });

  it('should call set on Group instances within the layer nodes', () => {
    // Arrange: the memory layer's `nodes` entries are treated as Groups here.
    const memoryLayer = Layer.memory(2, 2);
    const groupSetSpies = memoryLayer.nodes.map((groupNode) =>
      jest.spyOn((groupNode as unknown) as Group, 'set')
    );
    const settings = { bias: 0.1, squash: methods.Activation.relu };

    try {
      // Act
      memoryLayer.set(settings);
      // Assert
      groupSetSpies.forEach((spy) => {
        expect(spy).toHaveBeenCalledTimes(1);
        expect(spy).toHaveBeenCalledWith(settings);
      });
    } finally {
      // Restore the spies even when an assertion above throws.
      groupSetSpies.forEach((spy) => spy.mockRestore());
    }
  });
});
|
|
441
|
+
|
|
442
|
+
describe('disconnect()', () => {
  let layer: Layer;
  let targetGroup: Group;
  let targetNode: Node;
  let nodeDisconnectSpies: jest.SpyInstance[];

  beforeEach(() => {
    // Fixture: a 2-node layer in which every node connects to each node of a
    // 2-node target group and then to one standalone target node.
    layer = createTestLayer(2);
    targetGroup = new Group(2);
    targetNode = new Node();

    for (const source of layer.nodes) {
      for (const target of targetGroup.nodes) {
        source.connect(target);
      }
      source.connect(targetNode);
    }

    // Mirror each node's outgoing connections, node by node, into the
    // layer-level bookkeeping list so disconnect() has entries to prune.
    for (const source of layer.nodes) {
      layer.connections.out.push(...source.connections.out);
    }

    nodeDisconnectSpies = layer.nodes.map((source) =>
      jest.spyOn(source, 'disconnect')
    );
  });

  afterEach(() => {
    for (const spy of nodeDisconnectSpies) {
      spy.mockRestore();
    }
  });

  describe('Disconnecting from Group', () => {
    it('should call disconnect on each layer node for each target group node (one-sided)', () => {
      layer.disconnect(targetGroup, false);
      for (const spy of nodeDisconnectSpies) {
        expect(spy).toHaveBeenCalledTimes(targetGroup.nodes.length);
        for (const target of targetGroup.nodes) {
          expect(spy).toHaveBeenCalledWith(target, false);
        }
      }
    });

    it('should call disconnect on each layer node for each target group node (two-sided)', () => {
      layer.disconnect(targetGroup, true);
      for (const spy of nodeDisconnectSpies) {
        expect(spy).toHaveBeenCalledTimes(targetGroup.nodes.length);
        for (const target of targetGroup.nodes) {
          expect(spy).toHaveBeenCalledWith(target, true);
        }
      }
    });

    it('should remove outgoing connections from layer.connections.out (one-sided)', () => {
      const outCountBefore = layer.connections.out.length;
      expect(outCountBefore).toBe(6);

      layer.disconnect(targetGroup, false);

      // Only the connections into the standalone node should survive.
      expect(layer.connections.out).toHaveLength(outCountBefore - 4);
      for (const conn of layer.connections.out) {
        expect(conn.to).toBe(targetNode);
      }
    });

    it('should remove incoming connections from layer.connections.in if two-sided', () => {
      const [firstSource, secondSource] = targetGroup.nodes;
      firstSource.connect(layer.nodes[0]);
      secondSource.connect(layer.nodes[1]);
      layer.connections.in.push(
        firstSource.connections.out[0],
        secondSource.connections.out[0]
      );
      const inCountBefore = layer.connections.in.length;
      expect(inCountBefore).toBe(2);

      layer.disconnect(targetGroup, true);

      expect(layer.connections.in).toHaveLength(0);
    });
  });

  describe('Disconnecting from Node', () => {
    it('should call disconnect on each layer node for the target node (one-sided)', () => {
      layer.disconnect(targetNode, false);
      for (const spy of nodeDisconnectSpies) {
        expect(spy).toHaveBeenCalledTimes(1);
        expect(spy).toHaveBeenCalledWith(targetNode, false);
      }
    });

    it('should call disconnect on each layer node for the target node (two-sided)', () => {
      layer.disconnect(targetNode, true);
      for (const spy of nodeDisconnectSpies) {
        expect(spy).toHaveBeenCalledTimes(1);
        expect(spy).toHaveBeenCalledWith(targetNode, true);
      }
    });

    it('should remove outgoing connections to the node from layer.connections.out (one-sided)', () => {
      const outCountBefore = layer.connections.out.length;

      layer.disconnect(targetNode, false);

      // One connection per layer node pointed at the standalone node.
      expect(layer.connections.out).toHaveLength(outCountBefore - 2);
      for (const conn of layer.connections.out) {
        expect(targetGroup.nodes).toContain(conn.to);
      }
    });

    it('should remove incoming connections from the node in layer.connections.in if two-sided', () => {
      targetNode.connect(layer.nodes[0]);
      layer.connections.in.push(targetNode.connections.out[0]);
      const inCountBefore = layer.connections.in.length;
      expect(inCountBefore).toBe(1);

      layer.disconnect(targetNode, true);

      expect(layer.connections.in).toHaveLength(0);
    });
  });
});
|
|
570
|
+
|
|
571
|
+
describe('clear()', () => {
  const size = 3;
  let layer: Layer;
  let nodeSpies: jest.SpyInstance[];

  beforeEach(() => {
    layer = createTestLayer(size);
    nodeSpies = layer.nodes.map((node) => jest.spyOn(node, 'clear'));
  });

  afterEach(() => {
    nodeSpies.forEach((spy) => spy.mockRestore());
  });

  it('should call clear on all nodes in the layer', () => {
    layer.clear();
    expect(nodeSpies).toHaveLength(size);
    nodeSpies.forEach((spy) => {
      expect(spy).toHaveBeenCalledTimes(1);
    });
  });

  it('should call clear on Group instances within the layer nodes (e.g., Memory layer)', () => {
    // A memory layer (2 blocks) stores Group instances in `nodes`.
    const memoryLayer = Layer.memory(2, 2);
    const groupClearSpies = memoryLayer.nodes.map((groupNode) =>
      jest.spyOn((groupNode as unknown) as Group, 'clear')
    );

    try {
      // Guard against a vacuous pass if the layer unexpectedly has no groups.
      expect(groupClearSpies).toHaveLength(2);

      memoryLayer.clear();

      groupClearSpies.forEach((spy) => {
        expect(spy).toHaveBeenCalledTimes(1);
      });
    } finally {
      // Restore even when an assertion throws so spies never leak into later tests.
      groupClearSpies.forEach((spy) => spy.mockRestore());
    }
  });
});
|
|
608
|
+
|
|
609
|
+
describe('input()', () => {
  let targetLayer: Layer;
  let sourceLayer: Layer;
  let sourceGroup: Group;
  let sourceOutputConnectSpy: jest.SpyInstance | undefined;
  let sourceGroupConnectSpy: jest.SpyInstance;
  let targetOutputConnectSpy: jest.SpyInstance | undefined;

  beforeEach(() => {
    targetLayer = createTestLayer(2);
    sourceLayer = createTestLayer(3);
    sourceGroup = new Group(3);

    // Assign the optional spies unconditionally. If a layer has no output
    // group, an earlier test's already-restored spy cannot linger in the
    // variable and make `not.toHaveBeenCalled()` pass vacuously.
    sourceOutputConnectSpy = sourceLayer.output
      ? jest.spyOn(sourceLayer.output, 'connect')
      : undefined;
    sourceGroupConnectSpy = jest.spyOn(sourceGroup, 'connect');
    targetOutputConnectSpy = targetLayer.output
      ? jest.spyOn(targetLayer.output, 'connect')
      : undefined;
  });

  afterEach(() => {
    sourceOutputConnectSpy?.mockRestore();
    sourceGroupConnectSpy.mockRestore();
    targetOutputConnectSpy?.mockRestore();
  });

  it('should throw error if target layer output (acting as input) is not defined', () => {
    const layerWithoutOutput = new Layer();
    expect(() => layerWithoutOutput.input(sourceGroup)).toThrow(
      'Layer output (acting as input target) is not defined.'
    );
  });

  it('should use source Layer output group when connecting from a Layer', () => {
    targetLayer.input(sourceLayer);
    expect(sourceOutputConnectSpy).toHaveBeenCalledTimes(1);
    expect(sourceOutputConnectSpy).toHaveBeenCalledWith(
      targetLayer.output,
      methods.groupConnection.ALL_TO_ALL,
      undefined
    );
    expect(sourceGroupConnectSpy).not.toHaveBeenCalled();
  });

  it('should use source Group directly when connecting from a Group', () => {
    targetLayer.input(sourceGroup);
    expect(sourceGroupConnectSpy).toHaveBeenCalledTimes(1);
    expect(sourceGroupConnectSpy).toHaveBeenCalledWith(
      targetLayer.output,
      methods.groupConnection.ALL_TO_ALL,
      undefined
    );
    expect(sourceOutputConnectSpy).not.toHaveBeenCalled();
  });

  it('should use default connection method ALL_TO_ALL if none provided', () => {
    targetLayer.input(sourceGroup);
    expect(sourceGroupConnectSpy).toHaveBeenCalledWith(
      expect.anything(),
      methods.groupConnection.ALL_TO_ALL,
      undefined
    );
  });

  it('should use provided connection method', () => {
    const method = methods.groupConnection.ONE_TO_ONE;
    // ONE_TO_ONE needs equally sized groups, so swap the 3-node fixture for a
    // 2-node group matching targetLayer. Restore the fixture's spy first so it
    // is not left installed on the discarded group; afterEach restores the new one.
    sourceGroupConnectSpy.mockRestore();
    sourceGroup = new Group(2);
    sourceGroupConnectSpy = jest.spyOn(sourceGroup, 'connect');

    targetLayer.input(sourceGroup, method);
    expect(sourceGroupConnectSpy).toHaveBeenCalledWith(
      targetLayer.output,
      method,
      undefined
    );
  });

  it('should use provided weight', () => {
    const weight = 0.88;
    targetLayer.input(sourceGroup, undefined, weight);
    expect(sourceGroupConnectSpy).toHaveBeenCalledWith(
      targetLayer.output,
      methods.groupConnection.ALL_TO_ALL,
      weight
    );
  });

  it('should return the connections created by the source connect call', () => {
    const mockConnections = [new Connection(new Node(), new Node())];
    sourceGroupConnectSpy.mockReturnValue(mockConnections);
    const result = targetLayer.input(sourceGroup);
    expect(result).toBe(mockConnections);
  });
});
|
|
706
|
+
});
|
|
707
|
+
|
|
708
|
+
describe('Static Factory Methods', () => {
|
|
709
|
+
describe('Layer.dense()', () => {
  const size = 5;
  let layer: Layer;

  beforeEach(() => {
    layer = Layer.dense(size);
  });

  it('should create a layer with the specified number of nodes', () => {
    expect(layer.nodes).toHaveLength(size);
    layer.nodes.forEach((node) => expect(node).toBeInstanceOf(Node));
  });

  it('should set the output to a Group containing all nodes', () => {
    expect(layer.output).toBeInstanceOf(Group);
    expect(layer.output?.nodes).toHaveLength(size);
    expect(layer.output?.nodes).toEqual(layer.nodes);
  });

  it('should have a custom input method', () => {
    expect(typeof layer.input).toBe('function');
    expect(layer.input).not.toBe(Layer.prototype.input);
  });

  describe('Dense Layer input() method', () => {
    let sourceLayer: Layer;
    let sourceGroup: Group;
    let sourceOutputConnectSpy: jest.SpyInstance | undefined;
    let sourceGroupConnectSpy: jest.SpyInstance;

    beforeEach(() => {
      sourceLayer = Layer.dense(3);
      sourceGroup = new Group(3);

      // Assign unconditionally so an earlier test's already-restored spy can
      // never linger here and make `not.toHaveBeenCalled()` pass vacuously.
      sourceOutputConnectSpy = sourceLayer.output
        ? jest.spyOn(sourceLayer.output, 'connect')
        : undefined;
      sourceGroupConnectSpy = jest.spyOn(sourceGroup, 'connect');
    });

    afterEach(() => {
      sourceOutputConnectSpy?.mockRestore();
      sourceGroupConnectSpy.mockRestore();
    });

    it('should connect source Layer output to the dense layers internal block', () => {
      const method = methods.groupConnection.ALL_TO_ALL;
      const weight = 0.1;
      layer.input(sourceLayer, method, weight);

      expect(sourceOutputConnectSpy).toHaveBeenCalledTimes(1);
      expect(sourceOutputConnectSpy).toHaveBeenCalledWith(
        layer.output,
        method,
        weight
      );
      expect(sourceGroupConnectSpy).not.toHaveBeenCalled();
    });

    it('should connect source Group to the dense layers internal block', () => {
      const method = methods.groupConnection.ONE_TO_ONE;
      const weight = 0.2;
      // ONE_TO_ONE needs equally sized groups, so swap the 3-node fixture for
      // one matching the layer size. Restore the fixture's spy first so it is
      // not left installed on the discarded group; afterEach restores the new one.
      sourceGroupConnectSpy.mockRestore();
      sourceGroup = new Group(size);
      sourceGroupConnectSpy = jest.spyOn(sourceGroup, 'connect');

      layer.input(sourceGroup, method, weight);

      expect(sourceGroupConnectSpy).toHaveBeenCalledTimes(1);
      expect(sourceGroupConnectSpy).toHaveBeenCalledWith(
        layer.output,
        method,
        weight
      );
      expect(sourceOutputConnectSpy).not.toHaveBeenCalled();
    });

    it('should use ALL_TO_ALL by default if method not provided', () => {
      layer.input(sourceGroup);
      expect(sourceGroupConnectSpy).toHaveBeenCalledWith(
        expect.anything(),
        methods.groupConnection.ALL_TO_ALL,
        undefined
      );
    });
  });
});
|
|
795
|
+
|
|
796
|
+
describe('Layer.lstm()', () => {
  const size = 2;
  let layer: Layer;
  let inputGate: Group;
  let forgetGate: Group;
  let memoryCell: Group;
  let outputGate: Group;
  let outputBlock: Group;

  /** Wraps the slice [start, end) of `layer.nodes` in a Group for assertions. */
  const sliceGroup = (start: number, end: number): Group => {
    const group = new Group(size);
    group.nodes = layer.nodes.slice(start, end);
    return group;
  };

  beforeEach(() => {
    layer = Layer.lstm(size);
    // These tests read the layer's nodes as consecutive `size`-long slices:
    // [inputGate | forgetGate | memoryCell | outputGate | outputBlock].
    inputGate = sliceGroup(0, size);
    forgetGate = sliceGroup(size, size * 2);
    memoryCell = sliceGroup(size * 2, size * 3);
    outputGate = sliceGroup(size * 3, size * 4);
    outputBlock = sliceGroup(size * 4, size * 5);
    // layer.output is intentionally NOT reassigned here, so the output test
    // verifies what the factory produced instead of a value set by the fixture.
  });

  it('should create a layer with 5 * size nodes', () => {
    expect(layer.nodes).toHaveLength(5 * size);
  });

  it('should set the output to the outputBlock group', () => {
    expect(layer.output).toBeInstanceOf(Group);
    expect(layer.output?.nodes).toHaveLength(size);
    layer.output?.nodes.forEach((node, i) => {
      expect(node).toBe(outputBlock.nodes[i]);
    });
  });

  it('should set initial bias for gates', () => {
    inputGate.nodes.forEach((node) => expect(node.bias).toBe(1));
    forgetGate.nodes.forEach((node) => expect(node.bias).toBe(1));
    outputGate.nodes.forEach((node) => expect(node.bias).toBe(1));
    memoryCell.nodes.forEach((node) => expect(node.bias).toBe(0));
    outputBlock.nodes.forEach((node) => expect(node.bias).toBe(0));
  });

  it('should establish internal connections (memoryCell to gates)', () => {
    expect(isGroupConnectedTo(memoryCell, inputGate)).toBe(true);
    expect(isGroupConnectedTo(memoryCell, forgetGate)).toBe(true);
    expect(isGroupConnectedTo(memoryCell, outputGate)).toBe(true);
  });

  it('should establish internal connections (memoryCell self-connection)', () => {
    expect(
      isGroupConnectedTo(
        memoryCell,
        memoryCell,
        methods.groupConnection.ONE_TO_ONE
      )
    ).toBe(true);
  });

  it('should establish internal connections (memoryCell to outputBlock)', () => {
    expect(isGroupConnectedTo(memoryCell, outputBlock)).toBe(true);
  });

  it('should gate the memoryCell self-connection with the forgetGate', () => {
    memoryCell.nodes.forEach((node, i) => {
      const selfConnection = node.connections.self.find(
        (conn: Connection) => conn.to === node
      );
      expect(selfConnection).toBeDefined();
      expect(selfConnection?.gater).toBe(forgetGate.nodes[i]);
    });
  });

  it('should gate the memoryCell to outputBlock connection with the outputGate', () => {
    // Every memoryCell[i] -> outputBlock[*] connection is gated by outputGate[i];
    // check all targets rather than only the first one or two.
    memoryCell.nodes.forEach((node, i) => {
      outputBlock.nodes.forEach((target) => {
        const outputConnection = node.connections.out.find(
          (conn) => conn.to === target
        );
        expect(outputConnection).toBeDefined();
        expect(outputConnection?.gater).toBe(outputGate.nodes[i]);
      });
    });
  });

  it('should have a custom input method', () => {
    expect(typeof layer.input).toBe('function');
    expect(layer.input).not.toBe(Layer.prototype.input);
  });

  describe('LSTM Layer input() method', () => {
    let sourceLayer: Layer;
    let sourceOutputConnectSpy: jest.SpyInstance | undefined;

    beforeEach(() => {
      sourceLayer = Layer.dense(3);
      // Assign unconditionally so an earlier test's restored spy never lingers.
      sourceOutputConnectSpy = sourceLayer.output
        ? jest.spyOn(sourceLayer.output, 'connect')
        : undefined;
    });

    afterEach(() => {
      sourceOutputConnectSpy?.mockRestore();
    });

    it('should connect source to inputGate, forgetGate, memoryCell, and outputGate', () => {
      layer.input(sourceLayer);

      expect(sourceOutputConnectSpy).toHaveBeenCalledTimes(4);
      expect(sourceOutputConnectSpy).toHaveBeenCalledWith(
        expect.objectContaining({ nodes: inputGate.nodes }),
        methods.groupConnection.ALL_TO_ALL,
        undefined
      );
      expect(sourceOutputConnectSpy).toHaveBeenCalledWith(
        expect.objectContaining({ nodes: forgetGate.nodes }),
        methods.groupConnection.ALL_TO_ALL,
        undefined
      );
      expect(sourceOutputConnectSpy).toHaveBeenCalledWith(
        expect.objectContaining({ nodes: memoryCell.nodes }),
        methods.groupConnection.ALL_TO_ALL,
        undefined
      );
      expect(sourceOutputConnectSpy).toHaveBeenCalledWith(
        expect.objectContaining({ nodes: outputGate.nodes }),
        methods.groupConnection.ALL_TO_ALL,
        undefined
      );
    });

    it('should gate the source-to-memoryCell connection with the inputGate', () => {
      const connections = layer.input(sourceLayer);
      const sourceNodes = sourceLayer.output?.nodes ?? [];
      expect(sourceNodes).toHaveLength(3);

      const toMemory = connections.filter(
        (conn) =>
          sourceNodes.includes(conn.from) && memoryCell.nodes.includes(conn.to)
      );
      // ALL_TO_ALL yields one connection per (source node, memory cell) pair;
      // every one must be gated by the inputGate node at its target's index.
      expect(toMemory).toHaveLength(sourceNodes.length * memoryCell.nodes.length);
      toMemory.forEach((conn) => {
        const targetNodeIndex = memoryCell.nodes.indexOf(conn.to);
        expect(conn.gater).toBe(inputGate.nodes[targetNodeIndex]);
      });
    });
  });
});
|
|
952
|
+
|
|
953
|
+
describe('Layer.gru()', () => {
  const size = 2;
  let layer: Layer;
  let updateGate: Group;
  let inverseUpdateGate: Group;
  let resetGate: Group;
  let memoryCell: Group;
  let output: Group;
  let previousOutput: Group;

  /** Wraps the slice [start, end) of `layer.nodes` in a Group for assertions. */
  const sliceGroup = (start: number, end: number): Group => {
    const group = new Group(size);
    group.nodes = layer.nodes.slice(start, end);
    return group;
  };

  /**
   * Asserts that every `from` node connects to every `to` node and that each
   * connection is gated by the `gate` node sharing its source node's index.
   */
  const expectGatedBySourceIndex = (from: Group, to: Group, gate: Group): void => {
    from.nodes.forEach((node, i) => {
      to.nodes.forEach((target) => {
        const conn = node.connections.out.find((c) => c.to === target);
        expect(conn).toBeDefined();
        expect(conn?.gater).toBe(gate.nodes[i]);
      });
    });
  };

  beforeEach(() => {
    layer = Layer.gru(size);
    // These tests read the layer's nodes as consecutive `size`-long slices:
    // [updateGate | inverseUpdateGate | resetGate | memoryCell | output | previousOutput].
    updateGate = sliceGroup(0, size);
    inverseUpdateGate = sliceGroup(size, size * 2);
    resetGate = sliceGroup(size * 2, size * 3);
    memoryCell = sliceGroup(size * 3, size * 4);
    output = sliceGroup(size * 4, size * 5);
    previousOutput = sliceGroup(size * 5, size * 6);
    // layer.output is intentionally NOT reassigned here, so the output test
    // verifies what the factory produced instead of a value set by the fixture.
  });

  it('should create a layer with 6 * size nodes', () => {
    expect(layer.nodes).toHaveLength(6 * size);
  });

  it('should set the output to the output group', () => {
    expect(layer.output).toBeInstanceOf(Group);
    expect(layer.output?.nodes).toHaveLength(size);
    layer.output?.nodes.forEach((node, i) => {
      expect(node).toBe(output.nodes[i]);
    });
  });

  it('should set specific node properties', () => {
    previousOutput.nodes.forEach((node) => {
      expect(node.bias).toBe(0);
      expect(node.squash).toBe(methods.Activation.identity);
      expect(node.type).toBe('variant');
    });
    memoryCell.nodes.forEach((node) =>
      expect(node.squash).toBe(methods.Activation.tanh)
    );
    inverseUpdateGate.nodes.forEach((node) => {
      expect(node.bias).toBe(0);
      expect(node.squash).toBe(methods.Activation.inverse);
      expect(node.type).toBe('variant');
    });
    updateGate.nodes.forEach((node) => expect(node.bias).toBe(1));
    resetGate.nodes.forEach((node) => expect(node.bias).toBe(0));
  });

  it('should establish internal connections (previousOutput to gates)', () => {
    expect(isGroupConnectedTo(previousOutput, updateGate)).toBe(true);
    expect(isGroupConnectedTo(previousOutput, resetGate)).toBe(true);
  });

  it('should establish internal connections (updateGate to inverseUpdateGate)', () => {
    expect(
      isGroupConnectedTo(
        updateGate,
        inverseUpdateGate,
        methods.groupConnection.ONE_TO_ONE
      )
    ).toBe(true);
    updateGate.nodes.forEach((node, i) => {
      const conn = node.connections.out.find(
        (c) => c.to === inverseUpdateGate.nodes[i]
      );
      expect(conn).toBeDefined();
      expect(conn?.weight).toBe(1);
    });
  });

  it('should establish internal connections (previousOutput to memoryCell)', () => {
    expect(isGroupConnectedTo(previousOutput, memoryCell)).toBe(true);
  });

  it('should establish internal connections (previousOutput and memoryCell to output)', () => {
    expect(isGroupConnectedTo(previousOutput, output)).toBe(true);
    expect(isGroupConnectedTo(memoryCell, output)).toBe(true);
  });

  it('should establish internal connections (output to previousOutput)', () => {
    expect(
      isGroupConnectedTo(
        output,
        previousOutput,
        methods.groupConnection.ONE_TO_ONE
      )
    ).toBe(true);
    output.nodes.forEach((node, i) => {
      const conn = node.connections.out.find(
        (c) => c.to === previousOutput.nodes[i]
      );
      expect(conn).toBeDefined();
      expect(conn?.weight).toBe(1);
    });
  });

  it('should gate previousOutput->memoryCell connection with resetGate', () => {
    expectGatedBySourceIndex(previousOutput, memoryCell, resetGate);
  });

  it('should gate previousOutput->output connection with updateGate', () => {
    expectGatedBySourceIndex(previousOutput, output, updateGate);
  });

  it('should gate memoryCell->output connection with inverseUpdateGate', () => {
    expectGatedBySourceIndex(memoryCell, output, inverseUpdateGate);
  });

  it('should have a custom input method', () => {
    expect(typeof layer.input).toBe('function');
    expect(layer.input).not.toBe(Layer.prototype.input);
  });

  describe('GRU Layer input() method', () => {
    let sourceLayer: Layer;
    let sourceOutputConnectSpy: jest.SpyInstance | undefined;

    beforeEach(() => {
      sourceLayer = Layer.dense(3);
      // Assign unconditionally so an earlier test's restored spy never lingers.
      sourceOutputConnectSpy = sourceLayer.output
        ? jest.spyOn(sourceLayer.output, 'connect')
        : undefined;
    });

    afterEach(() => {
      sourceOutputConnectSpy?.mockRestore();
    });

    it('should connect source to updateGate, resetGate, and memoryCell', () => {
      layer.input(sourceLayer);

      expect(sourceOutputConnectSpy).toHaveBeenCalledTimes(3);
      expect(sourceOutputConnectSpy).toHaveBeenCalledWith(
        expect.objectContaining({ nodes: updateGate.nodes }),
        methods.groupConnection.ALL_TO_ALL,
        undefined
      );
      expect(sourceOutputConnectSpy).toHaveBeenCalledWith(
        expect.objectContaining({ nodes: resetGate.nodes }),
        methods.groupConnection.ALL_TO_ALL,
        undefined
      );
      expect(sourceOutputConnectSpy).toHaveBeenCalledWith(
        expect.objectContaining({ nodes: memoryCell.nodes }),
        methods.groupConnection.ALL_TO_ALL,
        undefined
      );
    });
  });
});
|
|
1132
|
+
|
|
1133
|
+
describe('Layer.memory()', () => {
|
|
1134
|
+
const size = 3;
const memoryDepth = 2;
let layer: Layer;

// Build a fresh memory layer for each test; the tests below expect
// `memoryDepth` Group entries of `size` nodes each.
beforeEach(() => {
  layer = Layer.memory(size, memoryDepth);
});

it('should create a layer with memoryDepth groups in nodes array', () => {
  expect(layer.nodes).toHaveLength(memoryDepth);
  for (const entry of layer.nodes) {
    const block = (entry as unknown) as Group;
    expect((layer as any).isGroup(entry)).toBe(true);
    expect(block.nodes).toHaveLength(size);
  }
});

it('should set specific properties for nodes within memory blocks', () => {
  const nodesPerBlock = layer.nodes.map(
    (entry) => ((entry as unknown) as Group).nodes
  );
  for (const blockNodes of nodesPerBlock) {
    for (const memoryNode of blockNodes) {
      expect(memoryNode.squash).toBe(methods.Activation.identity);
      expect(memoryNode.bias).toBe(0);
      expect(memoryNode.type).toBe('variant');
    }
  }
});
|
|
1159
|
+
|
|
1160
|
+
it('should connect previous memory block to current one (ONE_TO_ONE, weight 1)', () => {
  // After the factory's reversal, layer.nodes[0] is the newest block and each
  // higher index is one step older. Connections run from older to newer, so
  // every adjacent pair must satisfy layer.nodes[j + 1] -> layer.nodes[j].
  // Checking every pair (not just the first) keeps the test valid for any memoryDepth.
  for (let j = 0; j < memoryDepth - 1; j++) {
    const newerBlock = (layer.nodes[j] as unknown) as Group;
    const olderBlock = (layer.nodes[j + 1] as unknown) as Group;

    expect(
      isGroupConnectedTo(
        olderBlock,
        newerBlock,
        methods.groupConnection.ONE_TO_ONE
      )
    ).toBe(true);

    // Each node-level link of the pair carries a fixed weight of 1.
    olderBlock.nodes.forEach((node, i) => {
      const conn = node.connections.out.find(
        (c) => c.to === newerBlock.nodes[i]
      );
      expect(conn).toBeDefined();
      expect(conn?.weight).toBe(1);
    });
  }
});

it('should create a concatenated output group', () => {
  expect(layer.output).toBeInstanceOf(Group);
  expect(layer.output?.nodes).toHaveLength(size * memoryDepth);

  // Expected order: all nodes of layer.nodes[0], then layer.nodes[1], and so on.
  const expectedNodes: Node[] = [];
  layer.nodes.forEach((block) => {
    expectedNodes.push(...((block as unknown) as Group).nodes);
  });
  expect(layer.output?.nodes).toEqual(expectedNodes);
});

it('should have a custom input method', () => {
  expect(typeof layer.input).toBe('function');
  expect(layer.input).not.toBe(Layer.prototype.input);
});
|
|
1194
|
+
|
|
1195
|
+
describe('Memory Layer input() method', () => {
|
|
1196
|
+
let sourceLayer: Layer;
let sourceGroup: Group;
let sourceOutputConnectSpy: jest.SpyInstance | undefined;
let sourceGroupConnectSpy: jest.SpyInstance;
let lastBlock: Group;

beforeEach(() => {
  sourceLayer = Layer.dense(size);
  sourceGroup = new Group(size);
  // External input is expected to be wired into the final entry of layer.nodes.
  lastBlock = (layer.nodes[memoryDepth - 1] as unknown) as Group;

  // Assign unconditionally so an earlier test's already-restored spy can never
  // linger here and make `not.toHaveBeenCalled()` pass vacuously.
  sourceOutputConnectSpy = sourceLayer.output
    ? jest.spyOn(sourceLayer.output, 'connect')
    : undefined;
  sourceGroupConnectSpy = jest.spyOn(sourceGroup, 'connect');
});

afterEach(() => {
  sourceOutputConnectSpy?.mockRestore();
  sourceGroupConnectSpy.mockRestore();
});

it('should connect source Layer output to the last memory block (ONE_TO_ONE, weight 1)', () => {
  layer.input(sourceLayer);

  expect(sourceOutputConnectSpy).toHaveBeenCalledTimes(1);
  expect(sourceOutputConnectSpy).toHaveBeenCalledWith(
    expect.objectContaining({ nodes: lastBlock.nodes }),
    methods.groupConnection.ONE_TO_ONE,
    1
  );
  expect(sourceGroupConnectSpy).not.toHaveBeenCalled();
});

it('should connect source Group to the last memory block (ONE_TO_ONE, weight 1)', () => {
  layer.input(sourceGroup);

  expect(sourceGroupConnectSpy).toHaveBeenCalledTimes(1);
  expect(sourceGroupConnectSpy).toHaveBeenCalledWith(
    expect.objectContaining({ nodes: lastBlock.nodes }),
    methods.groupConnection.ONE_TO_ONE,
    1
  );
  expect(sourceOutputConnectSpy).not.toHaveBeenCalled();
});
|
|
1241
|
+
|
|
1242
|
+
it('should ignore provided method and weight, forcing ONE_TO_ONE and weight 1', () => {
|
|
1243
|
+
layer.input(sourceGroup, methods.groupConnection.ALL_TO_ALL, 0.5);
|
|
1244
|
+
|
|
1245
|
+
expect(sourceGroupConnectSpy).toHaveBeenCalledTimes(1);
|
|
1246
|
+
expect(sourceGroupConnectSpy).toHaveBeenCalledWith(
|
|
1247
|
+
expect.objectContaining({ nodes: lastBlock.nodes }),
|
|
1248
|
+
methods.groupConnection.ONE_TO_ONE,
|
|
1249
|
+
1
|
|
1250
|
+
);
|
|
1251
|
+
});
|
|
1252
|
+
|
|
1253
|
+
it('should throw error if source size does not match memory block size', () => {
|
|
1254
|
+
const wrongSizeSource = new Group(size + 1);
|
|
1255
|
+
expect(() => layer.input(wrongSizeSource)).toThrow(
|
|
1256
|
+
`Previous layer size (${wrongSizeSource.nodes.length}) must be same as memory size (${size})`
|
|
1257
|
+
);
|
|
1258
|
+
});
|
|
1259
|
+
|
|
1260
|
+
it('should throw error if the target input block is not a Group (edge case)', () => {
|
|
1261
|
+
layer.nodes[memoryDepth - 1] = new Node();
|
|
1262
|
+
expect(() => layer.input(sourceGroup)).toThrow(
|
|
1263
|
+
'Memory layer input block is not a Group.'
|
|
1264
|
+
);
|
|
1265
|
+
});
|
|
1266
|
+
});
|
|
1267
|
+
});
|
|
1268
|
+
|
|
1269
|
+
describe('Layer.batchNorm() and Layer.layerNorm()', () => {
  // Both normalizers share the same contract, so one scenario body covers both.
  const normalizers = [
    { label: 'batchNorm', build: (n: number) => Layer.batchNorm(n) },
    { label: 'layerNorm', build: (n: number) => Layer.layerNorm(n) },
  ];

  for (const { label, build } of normalizers) {
    describe(`Scenario: ${label} normalizes activations`, () => {
      const size = 10;

      /** Activates a freshly built layer on the ramp [1, 2, ..., size]. */
      const activateRamp = () =>
        build(size).activate(Array.from({ length: size }, (_, i) => i + 1));

      /** Left-to-right sum of `xs` divided by `size`. */
      const meanOf = (xs: number[]): number => {
        let total = 0;
        for (const x of xs) {
          total += x;
        }
        return total / size;
      };

      it('mean is ~0', () => {
        expect(meanOf(activateRamp())).toBeCloseTo(0, 5);
      });

      it('variance is ~1', () => {
        const activations = activateRamp();
        const mu = meanOf(activations);
        // Population variance: mean of squared deviations.
        const variance = meanOf(activations.map((x) => (x - mu) ** 2));
        expect(variance).toBeCloseTo(1, 2);
      });
    });
  }
});
|
|
1323
|
+
|
|
1324
|
+
describe('Layer.conv1d() and Layer.attention()', () => {
  it('conv1d constructs a layer and slices input as stub', () => {
    const size = 4;
    const kernel = 3;
    const conv = Layer.conv1d(size, kernel);

    // One node per output unit, mirrored by the output group.
    expect(conv.nodes).toHaveLength(size);
    expect(conv.output?.nodes).toHaveLength(size);

    // Called with only (size, kernel), the layer records stride 1 and padding 0.
    expect((conv as any).conv1d).toEqual({
      kernelSize: kernel,
      stride: 1,
      padding: 0,
    });

    // The stub keeps only the first `size` input values.
    expect(conv.activate([1, 2, 3, 4, 5, 6])).toEqual([1, 2, 3, 4]);
  });

  it('attention constructs a layer and averages input as stub', () => {
    const size = 3;
    const heads = 2;
    const attn = Layer.attention(size, heads);

    expect(attn.nodes).toHaveLength(size);
    expect(attn.output?.nodes).toHaveLength(size);
    expect((attn as any).attention).toEqual({ heads });

    // The stub broadcasts the input mean, (2 + 4 + 6) / 3 = 4, to every unit.
    expect(attn.activate([2, 4, 6])).toEqual([4, 4, 4]);
  });
});
|
|
1353
|
+
});
|
|
1354
|
+
|
|
1355
|
+
describe('isGroup (private helper)', () => {
  let layer: Layer;

  /** Invokes the private `isGroup` helper on the current layer instance. */
  const isGroup = (value: unknown): boolean => (layer as any).isGroup(value);

  beforeEach(() => {
    layer = new Layer();
  });

  it('should return true for a Group instance', () => {
    expect(isGroup(new Group(1))).toBe(true);
  });

  it('should return false for a Node instance', () => {
    expect(isGroup(new Node())).toBe(false);
  });

  // isGroup is duck-typed: a plain object exposing both `nodes` and `set` is
  // accepted, so the previous "false for a plain object" title was misleading.
  it('should return true for a plain object exposing both nodes and set', () => {
    expect(isGroup({ nodes: [], set: () => {} })).toBe(true);
  });

  it('should return false for a plain object missing nodes or set', () => {
    expect(isGroup({ set: () => {} })).toBe(false);
    expect(isGroup({ nodes: [] })).toBe(false);
  });

  it('should return false for null', () => {
    expect(isGroup(null)).toBe(false);
  });

  it('should return false for undefined', () => {
    expect(isGroup(undefined)).toBe(false);
  });

  it('should return false for primitive types', () => {
    for (const primitive of [123, 'string', true]) {
      expect(isGroup(primitive)).toBe(false);
    }
  });
});
|
|
1397
|
+
|
|
1398
|
+
describe('Layer-level Dropout', () => {
  /** Builds a bare Layer holding `size` hidden nodes and the given layer-level dropout rate. */
  const buildHiddenLayer = (size: number, dropout: number): Layer => {
    const layer = new Layer();
    for (let i = 0; i < size; i++) {
      layer.nodes.push(new Node('hidden'));
    }
    layer.dropout = dropout;
    return layer;
  };

  /**
   * Runs `trials` training activations and returns the distinct mask values
   * observed on the first node.
   *
   * Assuming a training pass keeps the layer with probability (1 - dropout),
   * 10 trials at dropout 0.8 miss a mask of 1 about 0.8^10 ≈ 11% of the time
   * (a flaky test). With 100 trials the miss chance is about 0.8^100 ≈ 2e-10.
   */
  const observedMasks = (layer: Layer, trials = 100): Set<number> => {
    const seen = new Set<number>();
    for (let i = 0; i < trials; i++) {
      layer.activate(undefined, true);
      seen.add(layer.nodes[0].mask);
    }
    return seen;
  };

  describe('Scenario: all nodes in a layer are masked together during training', () => {
    it('all masks are the same (either all 0 or all 1)', () => {
      // Arrange
      const layer = buildHiddenLayer(8, 0.7);
      // Act: a training pass is required; without it only the untouched default masks are seen.
      layer.activate(undefined, true);
      const masks = layer.nodes.map((n) => n.mask);
      // Assert
      expect(new Set(masks).size).toBe(1);
      expect([0, 1]).toContain(masks[0]);
    });
  });

  describe('Scenario: masks are reset to 1 after inference', () => {
    it('all masks are 1 after inference', () => {
      // Arrange
      const layer = buildHiddenLayer(8, 0.7);
      layer.activate(undefined, true); // simulate training
      // Act
      layer.activate(undefined, false);
      // Assert
      layer.nodes.forEach((n) => expect(n.mask).toBe(1));
    });
  });

  describe('Scenario: inference is unaffected by previous dropout', () => {
    it('all activations are valid numbers after inference', () => {
      // Arrange
      const layer = buildHiddenLayer(8, 0.7);
      layer.activate(undefined, true); // simulate training
      // Act
      layer.activate(undefined, false);
      const activations = layer.nodes.map((n) => n.activation);
      // Assert: typeof alone would also accept NaN.
      activations.forEach((a) => {
        expect(typeof a).toBe('number');
        expect(Number.isNaN(a)).toBe(false);
      });
    });
  });

  describe('Scenario: masks all nodes together during training', () => {
    for (const dropout of [0.8, 0, 1]) {
      describe(`when dropout = ${dropout}`, () => {
        it('all masks are the same in this activation', () => {
          // Arrange
          const layer = Layer.dense(5);
          layer.dropout = dropout;
          // Act
          layer.activate(undefined, true); // training=true
          const mask = layer.nodes[0].mask;
          // Assert
          layer.nodes.forEach((node) => {
            expect(node.mask).toBe(mask);
          });
        });
      });
    }

    describe('when dropout is strictly between 0 and 1', () => {
      it('0 mask occurred over multiple activations', () => {
        // Arrange
        const layer = Layer.dense(5);
        layer.dropout = 0.8;
        // Act + Assert
        expect(observedMasks(layer).has(0)).toBe(true);
      });

      it('1 mask occurred over multiple activations', () => {
        // Arrange
        const layer = Layer.dense(5);
        layer.dropout = 0.8;
        // Act + Assert
        expect(observedMasks(layer).has(1)).toBe(true);
      });
    });
  });

  describe('Scenario: resets all masks to 1 after training (inference)', () => {
    it('all masks are 1 after inference', () => {
      // Arrange
      const layer = Layer.dense(4);
      layer.dropout = 0.9;
      layer.activate(undefined, true); // training
      // Act
      layer.activate(undefined, false); // inference
      // Assert
      layer.nodes.forEach((node) => {
        expect(node.mask).toBe(1);
      });
    });
  });

  describe('Scenario: node-level dropout is not applied if layer-level dropout is set', () => {
    it('all node masks are 0 if layer.dropout = 1', () => {
      // Arrange
      const layer = Layer.dense(6);
      layer.dropout = 1; // always mask
      // Act
      layer.activate(undefined, true);
      // Assert
      layer.nodes.forEach((node) => {
        expect(node.mask).toBe(0);
      });
    });
  });
});
|
|
1535
|
+
});
|