@reicek/neataptic-ts 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.github/ISSUE_TEMPLATE/bug_report.md +33 -0
- package/.github/ISSUE_TEMPLATE/feature_request.md +27 -0
- package/.github/PULL_REQUEST_TEMPLATE.md +28 -0
- package/.github/workflows/ci.yml +41 -0
- package/.github/workflows/deploy-pages.yml +29 -0
- package/.github/workflows/manual_release_pipeline.yml +62 -0
- package/.github/workflows/publish.yml +85 -0
- package/.github/workflows/release_dispatch.yml +38 -0
- package/.travis.yml +5 -0
- package/CONTRIBUTING.md +92 -0
- package/LICENSE +24 -0
- package/ONNX_EXPORT.md +87 -0
- package/README.md +1173 -0
- package/RELEASE.md +54 -0
- package/dist-docs/package.json +1 -0
- package/dist-docs/scripts/generate-docs.d.ts +2 -0
- package/dist-docs/scripts/generate-docs.d.ts.map +1 -0
- package/dist-docs/scripts/generate-docs.js +536 -0
- package/dist-docs/scripts/generate-docs.js.map +1 -0
- package/dist-docs/scripts/render-docs-html.d.ts +2 -0
- package/dist-docs/scripts/render-docs-html.d.ts.map +1 -0
- package/dist-docs/scripts/render-docs-html.js +148 -0
- package/dist-docs/scripts/render-docs-html.js.map +1 -0
- package/docs/FOLDERS.md +14 -0
- package/docs/README.md +1173 -0
- package/docs/architecture/README.md +1391 -0
- package/docs/architecture/index.html +938 -0
- package/docs/architecture/network/README.md +1210 -0
- package/docs/architecture/network/index.html +908 -0
- package/docs/assets/ascii-maze.bundle.js +16542 -0
- package/docs/assets/ascii-maze.bundle.js.map +7 -0
- package/docs/index.html +1419 -0
- package/docs/methods/README.md +670 -0
- package/docs/methods/index.html +477 -0
- package/docs/multithreading/README.md +274 -0
- package/docs/multithreading/index.html +215 -0
- package/docs/multithreading/workers/README.md +23 -0
- package/docs/multithreading/workers/browser/README.md +39 -0
- package/docs/multithreading/workers/browser/index.html +70 -0
- package/docs/multithreading/workers/index.html +57 -0
- package/docs/multithreading/workers/node/README.md +33 -0
- package/docs/multithreading/workers/node/index.html +66 -0
- package/docs/neat/README.md +1284 -0
- package/docs/neat/index.html +906 -0
- package/docs/src/README.md +2659 -0
- package/docs/src/index.html +1579 -0
- package/jest.config.ts +32 -0
- package/package.json +99 -0
- package/plans/HyperMorphoNEAT.md +293 -0
- package/plans/ONNX_EXPORT_PLAN.md +46 -0
- package/scripts/generate-docs.ts +486 -0
- package/scripts/render-docs-html.ts +138 -0
- package/scripts/types.d.ts +2 -0
- package/src/README.md +2659 -0
- package/src/architecture/README.md +1391 -0
- package/src/architecture/activationArrayPool.ts +135 -0
- package/src/architecture/architect.ts +635 -0
- package/src/architecture/connection.ts +148 -0
- package/src/architecture/group.ts +406 -0
- package/src/architecture/layer.ts +804 -0
- package/src/architecture/network/README.md +1210 -0
- package/src/architecture/network/network.activate.ts +223 -0
- package/src/architecture/network/network.connect.ts +157 -0
- package/src/architecture/network/network.deterministic.ts +167 -0
- package/src/architecture/network/network.evolve.ts +426 -0
- package/src/architecture/network/network.gating.ts +186 -0
- package/src/architecture/network/network.genetic.ts +247 -0
- package/src/architecture/network/network.mutate.ts +624 -0
- package/src/architecture/network/network.onnx.ts +463 -0
- package/src/architecture/network/network.prune.ts +216 -0
- package/src/architecture/network/network.remove.ts +96 -0
- package/src/architecture/network/network.serialize.ts +309 -0
- package/src/architecture/network/network.slab.ts +262 -0
- package/src/architecture/network/network.standalone.ts +246 -0
- package/src/architecture/network/network.stats.ts +59 -0
- package/src/architecture/network/network.topology.ts +86 -0
- package/src/architecture/network/network.training.ts +1278 -0
- package/src/architecture/network.ts +1302 -0
- package/src/architecture/node.ts +1288 -0
- package/src/architecture/onnx.ts +3 -0
- package/src/config.ts +83 -0
- package/src/methods/README.md +670 -0
- package/src/methods/activation.ts +372 -0
- package/src/methods/connection.ts +31 -0
- package/src/methods/cost.ts +347 -0
- package/src/methods/crossover.ts +63 -0
- package/src/methods/gating.ts +43 -0
- package/src/methods/methods.ts +8 -0
- package/src/methods/mutation.ts +300 -0
- package/src/methods/rate.ts +257 -0
- package/src/methods/selection.ts +65 -0
- package/src/multithreading/README.md +274 -0
- package/src/multithreading/multi.ts +339 -0
- package/src/multithreading/workers/README.md +23 -0
- package/src/multithreading/workers/browser/README.md +39 -0
- package/src/multithreading/workers/browser/testworker.ts +99 -0
- package/src/multithreading/workers/node/README.md +33 -0
- package/src/multithreading/workers/node/testworker.ts +72 -0
- package/src/multithreading/workers/node/worker.ts +70 -0
- package/src/multithreading/workers/workers.ts +22 -0
- package/src/neat/README.md +1284 -0
- package/src/neat/neat.adaptive.ts +544 -0
- package/src/neat/neat.compat.ts +164 -0
- package/src/neat/neat.constants.ts +20 -0
- package/src/neat/neat.diversity.ts +217 -0
- package/src/neat/neat.evaluate.ts +328 -0
- package/src/neat/neat.evolve.ts +1026 -0
- package/src/neat/neat.export.ts +249 -0
- package/src/neat/neat.helpers.ts +235 -0
- package/src/neat/neat.lineage.ts +220 -0
- package/src/neat/neat.multiobjective.ts +260 -0
- package/src/neat/neat.mutation.ts +718 -0
- package/src/neat/neat.objectives.ts +157 -0
- package/src/neat/neat.pruning.ts +190 -0
- package/src/neat/neat.selection.ts +269 -0
- package/src/neat/neat.speciation.ts +460 -0
- package/src/neat/neat.species.ts +151 -0
- package/src/neat/neat.telemetry.exports.ts +469 -0
- package/src/neat/neat.telemetry.ts +933 -0
- package/src/neat/neat.types.ts +275 -0
- package/src/neat.ts +1042 -0
- package/src/neataptic.ts +10 -0
- package/test/architecture/activationArrayPool.capacity.test.ts +19 -0
- package/test/architecture/activationArrayPool.test.ts +46 -0
- package/test/architecture/connection.test.ts +290 -0
- package/test/architecture/group.test.ts +950 -0
- package/test/architecture/layer.test.ts +1535 -0
- package/test/architecture/network.pruning.test.ts +65 -0
- package/test/architecture/node.test.ts +1602 -0
- package/test/examples/asciiMaze/asciiMaze.e2e.test.ts +499 -0
- package/test/examples/asciiMaze/asciiMaze.ts +41 -0
- package/test/examples/asciiMaze/browser-entry.ts +164 -0
- package/test/examples/asciiMaze/browserLogger.ts +221 -0
- package/test/examples/asciiMaze/browserTerminalUtility.ts +48 -0
- package/test/examples/asciiMaze/colors.ts +119 -0
- package/test/examples/asciiMaze/dashboardManager.ts +968 -0
- package/test/examples/asciiMaze/evolutionEngine.ts +1248 -0
- package/test/examples/asciiMaze/fitness.ts +136 -0
- package/test/examples/asciiMaze/index.html +128 -0
- package/test/examples/asciiMaze/index.ts +26 -0
- package/test/examples/asciiMaze/interfaces.ts +235 -0
- package/test/examples/asciiMaze/mazeMovement.ts +996 -0
- package/test/examples/asciiMaze/mazeUtils.ts +278 -0
- package/test/examples/asciiMaze/mazeVision.ts +402 -0
- package/test/examples/asciiMaze/mazeVisualization.ts +585 -0
- package/test/examples/asciiMaze/mazes.ts +245 -0
- package/test/examples/asciiMaze/networkRefinement.ts +76 -0
- package/test/examples/asciiMaze/networkVisualization.ts +901 -0
- package/test/examples/asciiMaze/terminalUtility.ts +73 -0
- package/test/methods/activation.test.ts +1142 -0
- package/test/methods/connection.test.ts +146 -0
- package/test/methods/cost.test.ts +1123 -0
- package/test/methods/crossover.test.ts +202 -0
- package/test/methods/gating.test.ts +144 -0
- package/test/methods/mutation.test.ts +451 -0
- package/test/methods/optimizers.advanced.test.ts +80 -0
- package/test/methods/optimizers.behavior.test.ts +105 -0
- package/test/methods/optimizers.formula.test.ts +89 -0
- package/test/methods/rate.cosineWarmRestarts.test.ts +44 -0
- package/test/methods/rate.linearWarmupDecay.test.ts +41 -0
- package/test/methods/rate.reduceOnPlateau.test.ts +45 -0
- package/test/methods/rate.test.ts +684 -0
- package/test/methods/selection.test.ts +245 -0
- package/test/multithreading/activations.functions.test.ts +54 -0
- package/test/multithreading/multi.test.ts +290 -0
- package/test/multithreading/worker.node.process.test.ts +39 -0
- package/test/multithreading/workers.coverage.test.ts +36 -0
- package/test/multithreading/workers.dynamic.import.test.ts +8 -0
- package/test/neat/neat.adaptive.complexityBudget.test.ts +34 -0
- package/test/neat/neat.adaptive.criterion.complexity.test.ts +50 -0
- package/test/neat/neat.adaptive.mutation.strategy.test.ts +37 -0
- package/test/neat/neat.adaptive.operator.decay.test.ts +31 -0
- package/test/neat/neat.adaptive.phasedComplexity.test.ts +25 -0
- package/test/neat/neat.adaptive.pruning.test.ts +25 -0
- package/test/neat/neat.adaptive.targetSpecies.test.ts +43 -0
- package/test/neat/neat.additional.coverage.test.ts +126 -0
- package/test/neat/neat.advanced.enhancements.test.ts +85 -0
- package/test/neat/neat.advanced.test.ts +589 -0
- package/test/neat/neat.diversity.autocompat.test.ts +47 -0
- package/test/neat/neat.diversity.metrics.test.ts +21 -0
- package/test/neat/neat.diversity.stats.test.ts +44 -0
- package/test/neat/neat.enhancements.test.ts +79 -0
- package/test/neat/neat.entropy.ancestorAdaptive.test.ts +133 -0
- package/test/neat/neat.entropy.compat.csv.test.ts +108 -0
- package/test/neat/neat.evolution.pruning.test.ts +39 -0
- package/test/neat/neat.fastmode.autotune.test.ts +42 -0
- package/test/neat/neat.innovation.test.ts +134 -0
- package/test/neat/neat.lineage.antibreeding.test.ts +35 -0
- package/test/neat/neat.lineage.entropy.test.ts +56 -0
- package/test/neat/neat.lineage.inbreeding.test.ts +49 -0
- package/test/neat/neat.lineage.pressure.test.ts +29 -0
- package/test/neat/neat.multiobjective.adaptive.test.ts +57 -0
- package/test/neat/neat.multiobjective.dynamic.schedule.test.ts +46 -0
- package/test/neat/neat.multiobjective.dynamic.test.ts +31 -0
- package/test/neat/neat.multiobjective.fastsort.delegation.test.ts +51 -0
- package/test/neat/neat.multiobjective.prune.test.ts +39 -0
- package/test/neat/neat.multiobjective.test.ts +21 -0
- package/test/neat/neat.mutation.undefined.pool.test.ts +24 -0
- package/test/neat/neat.objective.events.test.ts +26 -0
- package/test/neat/neat.objective.importance.test.ts +21 -0
- package/test/neat/neat.objective.lifetimes.test.ts +33 -0
- package/test/neat/neat.offspring.allocation.test.ts +22 -0
- package/test/neat/neat.operator.bandit.test.ts +17 -0
- package/test/neat/neat.operator.phases.test.ts +38 -0
- package/test/neat/neat.pruneInactive.behavior.test.ts +54 -0
- package/test/neat/neat.reenable.adaptation.test.ts +18 -0
- package/test/neat/neat.rng.state.test.ts +22 -0
- package/test/neat/neat.spawn.add.test.ts +123 -0
- package/test/neat/neat.speciation.test.ts +96 -0
- package/test/neat/neat.species.allocation.telemetry.test.ts +26 -0
- package/test/neat/neat.species.history.csv.test.ts +24 -0
- package/test/neat/neat.telemetry.advanced.test.ts +226 -0
- package/test/neat/neat.telemetry.csv.lineage.test.ts +19 -0
- package/test/neat/neat.telemetry.parity.test.ts +42 -0
- package/test/neat/neat.telemetry.stream.test.ts +19 -0
- package/test/neat/neat.telemetry.test.ts +16 -0
- package/test/neat/neat.test.ts +422 -0
- package/test/neat/neat.utilities.test.ts +44 -0
- package/test/network/__suppress_console.ts +9 -0
- package/test/network/acyclic.topoorder.test.ts +17 -0
- package/test/network/checkpoint.metricshook.test.ts +36 -0
- package/test/network/error.handling.test.ts +581 -0
- package/test/network/evolution.test.ts +285 -0
- package/test/network/genetic.test.ts +208 -0
- package/test/network/learning.capability.test.ts +244 -0
- package/test/network/mutation.effects.test.ts +492 -0
- package/test/network/network.activate.test.ts +115 -0
- package/test/network/network.activateBatch.test.ts +30 -0
- package/test/network/network.deterministic.test.ts +64 -0
- package/test/network/network.evolve.branches.test.ts +75 -0
- package/test/network/network.evolve.multithread.branches.test.ts +83 -0
- package/test/network/network.evolve.test.ts +100 -0
- package/test/network/network.gating.removal.test.ts +93 -0
- package/test/network/network.mutate.additional.test.ts +145 -0
- package/test/network/network.mutate.edgecases.test.ts +101 -0
- package/test/network/network.mutate.test.ts +101 -0
- package/test/network/network.prune.earlyexit.test.ts +38 -0
- package/test/network/network.remove.errors.test.ts +45 -0
- package/test/network/network.slab.fallbacks.test.ts +22 -0
- package/test/network/network.stats.test.ts +45 -0
- package/test/network/network.training.advanced.test.ts +149 -0
- package/test/network/network.training.basic.test.ts +228 -0
- package/test/network/network.training.helpers.test.ts +183 -0
- package/test/network/onnx.export.test.ts +310 -0
- package/test/network/onnx.import.test.ts +129 -0
- package/test/network/pruning.topology.test.ts +282 -0
- package/test/network/regularization.determinism.test.ts +83 -0
- package/test/network/regularization.dropconnect.test.ts +17 -0
- package/test/network/regularization.dropconnect.validation.test.ts +18 -0
- package/test/network/regularization.stochasticdepth.test.ts +27 -0
- package/test/network/regularization.test.ts +843 -0
- package/test/network/regularization.weightnoise.test.ts +30 -0
- package/test/network/setupTests.ts +2 -0
- package/test/network/standalone.test.ts +332 -0
- package/test/network/structure.serialization.test.ts +660 -0
- package/test/training/training.determinism.mixed-precision.test.ts +134 -0
- package/test/training/training.earlystopping.test.ts +91 -0
- package/test/training/training.edge-cases.test.ts +91 -0
- package/test/training/training.extensions.test.ts +47 -0
- package/test/training/training.gradient.features.test.ts +110 -0
- package/test/training/training.gradient.refinements.test.ts +170 -0
- package/test/training/training.gradient.separate-bias.test.ts +41 -0
- package/test/training/training.optimizer.test.ts +48 -0
- package/test/training/training.plateau.smoothing.test.ts +58 -0
- package/test/training/training.smoothing.types.test.ts +174 -0
- package/test/training/training.train.options.coverage.test.ts +52 -0
- package/test/utils/console-helper.ts +76 -0
- package/test/utils/jest-setup.ts +60 -0
- package/test/utils/test-helpers.ts +175 -0
- package/tsconfig.docs.json +12 -0
- package/tsconfig.json +21 -0
- package/webpack.config.js +49 -0
|
@@ -0,0 +1,950 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @jest-environment node
|
|
3
|
+
*/
|
|
4
|
+
import Group from '../../src/architecture/group';
|
|
5
|
+
import Node from '../../src/architecture/node';
|
|
6
|
+
import Layer from '../../src/architecture/layer';
|
|
7
|
+
import Connection from '../../src/architecture/connection';
|
|
8
|
+
import * as methods from '../../src/methods/methods';
|
|
9
|
+
import { config } from '../../src/config';
|
|
10
|
+
|
|
11
|
+
// Retry failed tests: a failing test is re-run (up to 2 retries) and the error
// from each failed attempt is logged before the retry happens.
jest.retryTimes(2, {
  logErrorsBeforeRetry: true,
});
|
|
13
|
+
|
|
14
|
+
/**
 * Replace any logged value that exposes a callable `toJSON` with the result of
 * calling it, so objects with a custom JSON form (e.g. nodes/connections)
 * print that form instead of their raw structure. Falsy values and values
 * without a callable `toJSON` are passed through untouched.
 */
const toLoggable = (value: unknown): unknown => {
  const candidate = value as { toJSON?: unknown } | null | undefined;
  return candidate && typeof candidate.toJSON === 'function'
    ? candidate.toJSON()
    : value;
};

/** Undo hook for the console patch installed by the current test (if any). */
let restoreConsole: (() => void) | undefined;

/**
 * Before each test, wrap `console.log` and `console.dir` so their arguments
 * are passed through `toLoggable` before reaching the real implementation.
 *
 * The wrapped functions are captured per test and restored in `afterEach`;
 * without that restore every test would wrap the previous test's wrapper
 * again, growing the call chain by one layer per test and leaking the patch
 * past the end of this file.
 */
beforeEach(() => {
  const originalLog = console.log;
  const originalDir = console.dir;

  console.log = function (...args) {
    // Reflect.apply keeps `console` as the receiver, like the original `.apply`.
    Reflect.apply(originalLog, console, args.map(toLoggable));
  };
  console.dir = function (obj, options) {
    Reflect.apply(originalDir, console, [toLoggable(obj), options]);
  };

  restoreConsole = () => {
    console.log = originalLog;
    console.dir = originalDir;
  };
});

/** Put the real console methods back so patches never stack across tests. */
afterEach(() => {
  restoreConsole?.();
  restoreConsole = undefined;
});
|
|
32
|
+
|
|
33
|
+
describe('Group', () => {
|
|
34
|
+
const epsilon = 1e-9; // Tolerance for float comparisons
|
|
35
|
+
|
|
36
|
+
describe('Constructor', () => {
  const size = 5;
  let group: Group;

  // Arrange: every test gets a brand-new group of `size` nodes.
  beforeEach(() => {
    group = new Group(size);
  });

  it('should create a group with the specified number of nodes', () => {
    expect(group.nodes.length).toBe(size);
  });

  it('should initialize all nodes as instances of Node', () => {
    const everyMemberIsNode = group.nodes.every(
      (member) => member instanceof Node
    );
    expect(everyMemberIsNode).toBe(true);
  });

  // One test per connection list; titles expand to "(in)", "(out)", "(self)".
  it.each(['in', 'out', 'self'] as const)(
    'should initialize connection properties as empty arrays (%s)',
    (direction) => {
      expect(group.connections[direction]).toEqual([]);
    }
  );
});
|
|
85
|
+
|
|
86
|
+
describe('activate()', () => {
  const mismatchMessage =
    'Array with values should be same as the amount of nodes!';

  describe('Scenario: input group', () => {
    const size = 3;
    const inputValues = [0.5, -0.2, 0.9];
    let group: Group;

    // Arrange: a fresh group whose members are all input nodes.
    beforeEach(() => {
      group = new Group(size);
      for (const member of group.nodes) {
        member.type = 'input';
      }
    });

    describe('when activating all input nodes with input values', () => {
      let activations: number[];

      // Act
      beforeEach(() => {
        activations = group.activate(inputValues);
      });

      it('returns an array of length equal to group size', () => {
        expect(activations).toHaveLength(size);
      });

      // Input nodes are expected to echo their input value exactly.
      inputValues.forEach((expected, position) => {
        it(`assigns correct activation value to node ${position}`, () => {
          expect(activations[position]).toBe(expected);
        });
      });
    });

    describe('when input array length does not match group size', () => {
      it('throws an error', () => {
        const activateWithTooFew = () => group.activate([1, 2]);
        expect(activateWithTooFew).toThrow(mismatchMessage);
      });
    });
  });

  describe('Scenario: hidden group', () => {
    const size = 3;
    const inputValues = [0.5, -0.2, 0.9];
    const logistic = (x: number) => 1 / (1 + Math.exp(-x));
    let group: Group;

    // Arrange: a fresh group whose members are all hidden nodes.
    beforeEach(() => {
      group = new Group(size);
      for (const member of group.nodes) {
        member.type = 'hidden';
      }
    });

    describe('when activating all hidden nodes with input values', () => {
      let activations: number[];

      // Act
      beforeEach(() => {
        activations = group.activate(inputValues);
      });

      it('returns an array of length equal to group size', () => {
        expect(activations).toHaveLength(size);
      });

      // Hidden nodes are expected to output the logistic sigmoid of the input.
      inputValues.forEach((input, position) => {
        it(`assigns correct activation value to node ${position}`, () => {
          expect(activations[position]).toBeCloseTo(logistic(input), 10);
        });
      });
    });

    describe('when input array length does not match group size', () => {
      it('throws an error', () => {
        const activateWithTooFew = () => group.activate([1, 2]);
        expect(activateWithTooFew).toThrow(mismatchMessage);
      });
    });
  });
});
|
|
184
|
+
|
|
185
|
+
describe('propagate()', () => {
  const size = 3;
  const rate = 0.1;
  const momentum = 0.9;
  let group: Group;

  // Arrange: a fresh group for each test.
  beforeEach(() => {
    group = new Group(size);
  });

  describe('when propagating without target values', () => {
    it('does not throw', () => {
      const propagateUntargeted = () => group.propagate(rate, momentum);
      expect(propagateUntargeted).not.toThrow();
    });
  });

  describe('when propagating with target values', () => {
    it('does not throw if target length matches group size', () => {
      const matchingTargets = [0.1, 0.2, 0.3];
      const propagateMatching = () =>
        group.propagate(rate, momentum, matchingTargets);
      expect(propagateMatching).not.toThrow();
    });

    it('throws if target length does not match group size', () => {
      const shortTargets = [0.1, 0.2];
      const propagateShort = () =>
        group.propagate(rate, momentum, shortTargets);
      expect(propagateShort).toThrow(
        'Array with values should be same as the amount of nodes!'
      );
    });
  });
});
|
|
220
|
+
|
|
221
|
+
describe('gate()', () => {
  let gatingGroup: Group;
  let sourceNode1: Node;
  let sourceNode2: Node;
  let targetNode1: Node;
  let targetNode2: Node;
  let conn1: Connection;
  let conn2: Connection;
  let selfConn: Connection;
  let connections: Connection[];

  // Arrange: a two-node gating group plus two plain connections and one
  // self-connection, all rebuilt for every test.
  beforeEach(() => {
    gatingGroup = new Group(2);
    sourceNode1 = new Node();
    sourceNode2 = new Node();
    targetNode1 = new Node();
    targetNode2 = new Node();
    [conn1] = sourceNode1.connect(targetNode1);
    [conn2] = sourceNode2.connect(targetNode2);
    [selfConn] = sourceNode1.connect(sourceNode1);
    connections = [conn1, conn2];
  });

  describe('when no gating method is specified', () => {
    it('throws an error', () => {
      const gateWithoutMethod = () => gatingGroup.gate(connections, undefined);
      expect(gateWithoutMethod).toThrow(
        'Please specify a gating method: Gating.INPUT, Gating.OUTPUT, or Gating.SELF'
      );
    });
  });

  describe('when gating a single connection (INPUT)', () => {
    // Act
    beforeEach(() => {
      gatingGroup.gate(conn1, methods.gating.INPUT);
    });

    it('assigns the first node as gater', () => {
      expect(conn1.gater).toBe(gatingGroup.nodes[0]);
    });
  });

  describe('when gating multiple connections (INPUT)', () => {
    // Act
    beforeEach(() => {
      gatingGroup.gate(connections, methods.gating.INPUT);
    });

    it('assigns the first node as gater for conn1', () => {
      expect(conn1.gater).toBe(gatingGroup.nodes[0]);
    });

    it('assigns the second node as gater for conn2', () => {
      expect(conn2.gater).toBe(gatingGroup.nodes[1]);
    });
  });

  describe('when gating multiple connections (OUTPUT)', () => {
    // Act
    beforeEach(() => {
      gatingGroup.gate(connections, methods.gating.OUTPUT);
    });

    it('assigns the first node as gater for conn1', () => {
      expect(conn1.gater).toBe(gatingGroup.nodes[0]);
    });

    it('assigns the second node as gater for conn2', () => {
      expect(conn2.gater).toBe(gatingGroup.nodes[1]);
    });
  });

  describe('when gating a self connection (SELF)', () => {
    // Act
    beforeEach(() => {
      gatingGroup.gate(selfConn, methods.gating.SELF);
    });

    it('assigns the first node as gater for selfConn', () => {
      expect(selfConn.gater).toBe(gatingGroup.nodes[0]);
    });
  });

  // Three connections but only two gaters: the third is expected to wrap
  // back around to the first gater node.
  describe('when more connections than gaters (cycle)', () => {
    let conn3: Connection;

    beforeEach(() => {
      [conn3] = sourceNode1.connect(targetNode2);
      gatingGroup.gate([conn1, conn2, conn3], methods.gating.INPUT);
    });

    it('cycles gater assignment for conn1', () => {
      expect(conn1.gater).toBe(gatingGroup.nodes[0]);
    });

    it('cycles gater assignment for conn2', () => {
      expect(conn2.gater).toBe(gatingGroup.nodes[1]);
    });

    it('cycles gater assignment for conn3', () => {
      expect(conn3.gater).toBe(gatingGroup.nodes[0]);
    });
  });
});
|
|
313
|
+
|
|
314
|
+
describe('set()', () => {
  const size = 4;
  let group: Group;

  // Arrange: a fresh four-node group for every test.
  beforeEach(() => {
    group = new Group(size);
  });

  describe('when setting bias for all nodes', () => {
    const biasValue = 0.5;

    // Act
    beforeEach(() => {
      group.set({ bias: biasValue });
    });

    // One test per node: "sets bias for node 0" … "sets bias for node 3".
    for (let index = 0; index < size; index++) {
      it(`sets bias for node ${index}`, () => {
        expect(group.nodes[index].bias).toBe(biasValue);
      });
    }
  });

  describe('when setting squash function for all nodes', () => {
    const squashFn = methods.Activation.relu;

    // Act
    beforeEach(() => {
      group.set({ squash: squashFn });
    });

    for (let index = 0; index < size; index++) {
      it(`sets squash for node ${index}`, () => {
        expect(group.nodes[index].squash).toBe(squashFn);
      });
    }
  });

  describe('when setting type for all nodes', () => {
    const typeValue = 'memory';

    // Act
    beforeEach(() => {
      group.set({ type: typeValue });
    });

    for (let index = 0; index < size; index++) {
      it(`sets type for node ${index}`, () => {
        expect(group.nodes[index].type).toBe(typeValue);
      });
    }
  });

  describe('when setting multiple properties at once', () => {
    const biasValue = -0.1;
    const squashFn = methods.Activation.tanh;
    const typeValue = 'output';

    // Act
    beforeEach(() => {
      group.set({ bias: biasValue, squash: squashFn, type: typeValue });
    });

    it('sets bias for all nodes', () => {
      group.nodes.forEach((node) => {
        expect(node.bias).toBe(biasValue);
      });
    });

    it('sets squash for all nodes', () => {
      group.nodes.forEach((node) => {
        expect(node.squash).toBe(squashFn);
      });
    });

    it('sets type for all nodes', () => {
      group.nodes.forEach((node) => {
        expect(node.type).toBe(typeValue);
      });
    });
  });

  describe('when not changing properties if not provided', () => {
    // Snapshots are typed from Node itself instead of `any[]`, so the
    // comparisons below stay type-checked if Node's property types change.
    let initialBiases: Array<Node['bias']>;
    let initialSquashes: Array<Node['squash']>;
    let initialTypes: Array<Node['type']>;

    beforeEach(() => {
      // Arrange: snapshot every node's properties before the no-op call.
      initialBiases = group.nodes.map((node) => node.bias);
      initialSquashes = group.nodes.map((node) => node.squash);
      initialTypes = group.nodes.map((node) => node.type);
      // Act: an empty options object should leave every node untouched.
      group.set({});
    });

    it('does not change bias', () => {
      group.nodes.forEach((node, i) => {
        expect(node.bias).toBe(initialBiases[i]);
      });
    });

    it('does not change squash', () => {
      group.nodes.forEach((node, i) => {
        expect(node.squash).toBe(initialSquashes[i]);
      });
    });

    it('does not change type', () => {
      group.nodes.forEach((node, i) => {
        expect(node.type).toBe(initialTypes[i]);
      });
    });
  });

  describe('when setting only bias', () => {
    let initialSquashes: Array<Node['squash']>;
    let initialTypes: Array<Node['type']>;

    beforeEach(() => {
      // Arrange: snapshot the properties that must survive a bias-only update.
      initialSquashes = group.nodes.map((node) => node.squash);
      initialTypes = group.nodes.map((node) => node.type);
      // Act
      group.set({ bias: 0.9 });
    });

    it('sets bias for all nodes', () => {
      group.nodes.forEach((node) => {
        expect(node.bias).toBe(0.9);
      });
    });

    it('does not change squash', () => {
      group.nodes.forEach((node, i) => {
        expect(node.squash).toBe(initialSquashes[i]);
      });
    });

    it('does not change type', () => {
      group.nodes.forEach((node, i) => {
        expect(node.type).toBe(initialTypes[i]);
      });
    });
  });
});
|
|
464
|
+
|
|
465
|
+
describe('disconnect()', () => {
  const size1 = 2;
  const size2 = 2;
  // Number of links produced by one ALL_TO_ALL wiring between group1 and group2.
  const meshSize = size1 * size2;
  let group1: Group;
  let group2: Group;
  let node: Node;

  beforeEach(() => {
    // Arrange: group1 -> group2 (mesh), group1 -> node, group2 -> group1 (mesh).
    group1 = new Group(size1);
    group2 = new Group(size2);
    node = new Node();
    group1.connect(group2, methods.groupConnection.ALL_TO_ALL);
    group1.connect(node);
    group2.connect(group1, methods.groupConnection.ALL_TO_ALL);
  });

  describe('Scenario: From Group', () => {
    describe('when disconnecting one-sided (default)', () => {
      beforeEach(() => {
        // Act
        group1.disconnect(group2);
      });

      it('removes out connections from group1 to group2', () => {
        // Only the size1 links from group1 to `node` should be left.
        expect(group1.connections.out).toHaveLength(size1);
      });

      it('does not change group2 in connections', () => {
        expect(group2.connections.in).toHaveLength(meshSize);
      });

      it('does not change group2 out connections', () => {
        expect(group2.connections.out).toHaveLength(meshSize);
      });

      it('does not change group1 in connections', () => {
        expect(group1.connections.in).toHaveLength(meshSize);
      });
    });

    describe('when disconnecting two-sided', () => {
      beforeEach(() => {
        // Act
        group1.disconnect(group2, true);
      });

      it('removes out connections from group1 to group2', () => {
        // Only the size1 links from group1 to `node` should be left.
        expect(group1.connections.out).toHaveLength(size1);
      });

      it('removes in connections from group2', () => {
        expect(group2.connections.in).toHaveLength(0);
      });

      it('removes out connections from group2', () => {
        expect(group2.connections.out).toHaveLength(0);
      });

      it('removes in connections from group1', () => {
        expect(group1.connections.in).toHaveLength(0);
      });
    });
  });

  describe('Scenario: From Node', () => {
    describe('when disconnecting one-sided (default)', () => {
      beforeEach(() => {
        // Act
        group1.disconnect(node);
      });

      it('does not change group1 out connections to group2', () => {
        expect(group1.connections.out).toHaveLength(meshSize);
      });

      it('removes node in connections', () => {
        expect(node.connections.in).toHaveLength(0);
      });
    });

    describe('when disconnecting two-sided', () => {
      beforeEach(() => {
        // Arrange: add a node -> group1 link and register it on the group by hand.
        node.connect(group1.nodes[0]);
        group1.connections.in.push(node.connections.out[0]);
        // Act
        group1.disconnect(node, true);
      });

      it('does not change group1 out connections to group2', () => {
        expect(group1.connections.out).toHaveLength(meshSize);
      });

      it('removes group1 in connections from node', () => {
        // Only the mesh coming from group2 should remain.
        expect(group1.connections.in).toHaveLength(meshSize);
      });

      it('removes node in connections', () => {
        expect(node.connections.in).toHaveLength(0);
      });

      it('removes node out connections', () => {
        expect(node.connections.out).toHaveLength(0);
      });
    });
  });
});
|
|
559
|
+
|
|
560
|
+
describe('clear()', () => {
  const size = 3;
  // Non-default values written into every node before clear() runs.
  const seeded = { state: 1, old: 2, activation: 3, derivative: 4 };
  let group: Group;

  beforeEach(() => {
    // Arrange
    group = new Group(size);
    group.nodes.forEach((node) => {
      node.state = seeded.state;
      node.old = seeded.old;
      node.activation = seeded.activation;
      node.derivative = seeded.derivative;
    });
    // Act
    group.clear();
  });

  it('resets state for all nodes', () => {
    group.nodes.forEach((node) => {
      expect(node.state).toBe(0);
    });
  });

  it('resets old for all nodes', () => {
    group.nodes.forEach((node) => {
      expect(node.old).toBe(0);
    });
  });

  it('resets activation for all nodes', () => {
    group.nodes.forEach((node) => {
      expect(node.activation).toBe(0);
    });
  });

  // The assertion expects the seeded derivative to survive clear(), so the
  // title states exactly that (it previously claimed the value was reset).
  it('leaves derivative unchanged for all nodes', () => {
    group.nodes.forEach((node) => {
      expect(node.derivative).toBe(seeded.derivative);
    });
  });
});
|
|
602
|
+
|
|
603
|
+
describe('toJSON()', () => {
  // Two nodes and no connections: every connection count should serialize as 0.
  describe('when serializing an unconnected group', () => {
    let group: Group;
    let json: any;

    beforeEach(() => {
      // Arrange
      group = new Group(2);
      group.nodes[0].index = 10;
      group.nodes[1].index = 11;
      // Act
      json = group.toJSON();
    });

    it('serializes size', () => {
      expect(json.size).toBe(2);
    });

    it('serializes nodeIndices', () => {
      expect(json.nodeIndices).toEqual([10, 11]);
    });

    it('serializes connections.in', () => {
      expect(json.connections.in).toBe(0);
    });

    it('serializes connections.out', () => {
      expect(json.connections.out).toBe(0);
    });

    it('serializes connections.self', () => {
      expect(json.connections.self).toBe(0);
    });
  });

  describe('when serializing group after connections', () => {
    let group1: Group;
    let group2: Group;
    let json1: any;
    let json2: any;

    beforeEach(() => {
      // Arrange: 2x2 ALL_TO_ALL wiring gives 4 connections; indices 0-1 and 2-3.
      group1 = new Group(2);
      group2 = new Group(2);
      group1.nodes.forEach((n, i) => (n.index = i));
      group2.nodes.forEach((n, i) => (n.index = i + 2));
      group1.connect(group2, methods.groupConnection.ALL_TO_ALL);
      // Act
      json1 = group1.toJSON();
      json2 = group2.toJSON();
    });

    it('serializes size for group1', () => {
      expect(json1.size).toBe(2);
    });

    it('serializes out connections for group1', () => {
      expect(json1.connections.out).toBe(4);
    });

    it('serializes in connections for group2', () => {
      expect(json2.connections.in).toBe(4);
    });

    it('serializes nodeIndices for group1', () => {
      expect(json1.nodeIndices).toEqual([0, 1]);
    });

    it('serializes nodeIndices for group2', () => {
      expect(json2.nodeIndices).toEqual([2, 3]);
    });
  });

  describe('when serializing group after gating', () => {
    let group: Group;
    let node1: Node;
    let node2: Node;
    let conn1: Connection;
    let json: any;

    beforeEach(() => {
      // Arrange: the group gates an external node1 -> node2 connection.
      group = new Group(2);
      node1 = new Node();
      node2 = new Node();
      node1.index = 10;
      node2.index = 11;
      conn1 = node1.connect(node2)[0];
      group.nodes[0].index = 20;
      group.nodes[1].index = 21;
      group.gate([conn1], methods.gating.INPUT);
      // Act
      json = group.toJSON();
    });

    it('serializes size', () => {
      expect(json.size).toBe(2);
    });

    it('serializes nodeIndices', () => {
      expect(json.nodeIndices).toEqual([20, 21]);
    });

    it('serializes connections.in', () => {
      expect(json.connections.in).toBe(0);
    });

    it('serializes connections.out', () => {
      expect(json.connections.out).toBe(0);
    });

    it('serializes connections.self', () => {
      expect(json.connections.self).toBe(0);
    });
  });
});
|
|
702
|
+
|
|
703
|
+
describe('connect()', () => {
  const size1 = 3;
  const size2 = 2;
  let group1: Group;
  let group2: Group;
  let node: Node;
  let layer: Layer;
  // Saved so the global warnings flag is restored after every test.
  let originalWarnings: boolean;
  // Handle to the silenced console.warn spy; asserting on it avoids casting console.warn.
  let warnSpy: jest.SpyInstance;

  beforeEach(() => {
    // Arrange: fresh fixtures per test, with warnings enabled but silenced.
    group1 = new Group(size1);
    group2 = new Group(size2);
    node = new Node();
    layer = new Layer();
    originalWarnings = config.warnings;
    config.warnings = true;
    warnSpy = jest.spyOn(console, 'warn').mockImplementation(() => {});
  });

  afterEach(() => {
    config.warnings = originalWarnings;
    // Restores every jest.spyOn spy made during the test (warnSpy included),
    // even when an assertion failed part-way through.
    jest.restoreAllMocks();
  });

  describe('Scenario: To Group', () => {
    describe('when connecting ALL_TO_ALL by default to a different group', () => {
      let connections: any[];

      beforeEach(() => {
        // Act
        connections = group1.connect(group2);
      });

      it('creates the correct number of connections', () => {
        // Assert
        expect(connections).toHaveLength(size1 * size2);
      });

      it('updates group1.connections.out', () => {
        expect(group1.connections.out).toHaveLength(size1 * size2);
      });

      it('warns about default ALL_TO_ALL', () => {
        expect(warnSpy).toHaveBeenCalledWith(
          'No group connection specified, using ALL_TO_ALL by default.'
        );
      });

      it('forms connections between all node pairs', () => {
        let connCount = 0;
        group1.nodes.forEach((fromNode: Node) => {
          group2.nodes.forEach((toNode: Node) => {
            if (fromNode.isConnectedTo(toNode)) {
              connCount++;
            }
          });
        });
        expect(connCount).toBe(size1 * size2);
      });
    });

    describe('when connecting ONE_TO_ONE by default to the same group', () => {
      let sameSizeGroup: Group;
      let connections: any[];

      beforeEach(() => {
        // Arrange
        sameSizeGroup = new Group(size1);
        // Act: a self-connection without an explicit method.
        connections = sameSizeGroup.connect(sameSizeGroup);
      });

      it('creates the correct number of connections', () => {
        expect(connections).toHaveLength(size1);
      });

      it('updates self connections', () => {
        expect(sameSizeGroup.connections.self).toHaveLength(size1);
      });

      it('warns about default ONE_TO_ONE', () => {
        expect(warnSpy).toHaveBeenCalledWith(
          'Connecting group to itself, using ONE_TO_ONE by default.'
        );
      });

      it('stores self-connection in group', () => {
        sameSizeGroup.nodes.forEach((member: Node, i: number) => {
          const selfConn = sameSizeGroup.connections.self[i];
          expect(selfConn).toBeInstanceOf(Connection);
          expect(selfConn.from).toBe(member);
          expect(selfConn.to).toBe(member);
          expect(member.connections.self[0]).toBe(selfConn);
        });
      });
    });

    describe('when connecting using specified ALL_TO_ALL method', () => {
      let connections: any[];

      beforeEach(() => {
        // Act
        connections = group1.connect(
          group2,
          methods.groupConnection.ALL_TO_ALL
        );
      });

      it('creates the correct number of connections', () => {
        expect(connections).toHaveLength(size1 * size2);
      });

      it('updates group1.connections.out', () => {
        expect(group1.connections.out).toHaveLength(size1 * size2);
      });

      it('updates group2.connections.in', () => {
        expect(group2.connections.in).toHaveLength(size1 * size2);
      });
    });

    describe('when connecting using specified ALL_TO_ELSE method', () => {
      let sameSizeGroup: Group;
      let connections: any[];
      // Every ordered pair of distinct nodes: n^2 - n.
      const expectedConns = size1 * size1 - size1;

      beforeEach(() => {
        // Arrange
        sameSizeGroup = new Group(size1);
        // Act
        connections = sameSizeGroup.connect(
          sameSizeGroup,
          methods.groupConnection.ALL_TO_ELSE
        );
      });

      it('creates the correct number of connections', () => {
        expect(connections).toHaveLength(expectedConns);
      });

      it('updates out and in connections', () => {
        expect(sameSizeGroup.connections.out).toHaveLength(expectedConns);
        expect(sameSizeGroup.connections.in).toHaveLength(expectedConns);
      });

      it('does not create self-connections', () => {
        sameSizeGroup.nodes.forEach((member: Node) => {
          expect(member.isConnectedTo(member)).toBe(false);
        });
      });
    });

    describe('when connecting using specified ONE_TO_ONE method', () => {
      let sameSizeGroup: Group;
      let connections: any[];

      beforeEach(() => {
        // Arrange
        sameSizeGroup = new Group(size1);
        // Act
        connections = group1.connect(
          sameSizeGroup,
          methods.groupConnection.ONE_TO_ONE
        );
      });

      it('creates the correct number of connections', () => {
        expect(connections).toHaveLength(size1);
      });

      it('updates out and in connections', () => {
        expect(group1.connections.out).toHaveLength(size1);
        expect(sameSizeGroup.connections.in).toHaveLength(size1);
      });

      it('does not update self connections for different groups', () => {
        expect(group1.connections.self).toHaveLength(0);
        expect(sameSizeGroup.connections.self).toHaveLength(0);
      });

      it('connects corresponding nodes', () => {
        group1.nodes.forEach((fromNode: Node, i: number) => {
          expect(fromNode.isConnectedTo(sameSizeGroup.nodes[i])).toBe(true);
        });
      });
    });

    describe('when connecting ONE_TO_ONE with different group sizes', () => {
      it('throws an error', () => {
        expect(() =>
          group1.connect(group2, methods.groupConnection.ONE_TO_ONE)
        ).toThrow(
          'Cannot create ONE_TO_ONE connection: source and target groups must have the same size.'
        );
      });
    });

    describe('when connecting with specified weight', () => {
      const weight = 0.75;
      let connections: any[];

      beforeEach(() => {
        // Act
        connections = group1.connect(
          group2,
          methods.groupConnection.ALL_TO_ALL,
          weight
        );
      });

      it('sets the weight for all connections', () => {
        connections.forEach((conn: Connection) => {
          expect(conn.weight).toBe(weight);
        });
      });
    });
  });

  describe('Scenario: To Layer', () => {
    it('delegates connection to Layer.input()', () => {
      // Uses the per-test `layer` and `group1` (size1 nodes) fixtures; the spy
      // is restored by afterEach's jest.restoreAllMocks().
      const layerInputSpy = jest
        .spyOn(layer, 'input')
        .mockImplementation(() => []);
      const method = methods.groupConnection.ALL_TO_ALL;
      const weight = 0.5;

      group1.connect(layer, method, weight);

      expect(layerInputSpy).toHaveBeenCalledTimes(1);
      expect(layerInputSpy).toHaveBeenCalledWith(group1, method, weight);
    });
  });

  describe('Scenario: To Node', () => {
    let connections: any[];

    beforeEach(() => {
      // Act
      connections = group1.connect(node);
    });

    it('creates the correct number of connections', () => {
      expect(connections).toHaveLength(size1);
    });

    it('updates group1.connections.out', () => {
      expect(group1.connections.out).toHaveLength(size1);
    });

    it('connects all nodes in group to the target node', () => {
      group1.nodes.forEach((fromNode: Node) => {
        expect(fromNode.isConnectedTo(node)).toBe(true);
      });
    });

    it('updates node.connections.in', () => {
      expect(node.connections.in).toHaveLength(size1);
    });
  });

  describe('when connecting to node with specified weight', () => {
    const weight = -0.3;
    let connections: any[];

    beforeEach(() => {
      // Act
      connections = group1.connect(node, undefined, weight);
    });

    it('sets the weight for all connections', () => {
      connections.forEach((conn: Connection) => {
        expect(conn.weight).toBe(weight);
      });
    });
  });
});
|
|
950
|
+
});
|