@reicek/neataptic-ts 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (272)
  1. package/.github/ISSUE_TEMPLATE/bug_report.md +33 -0
  2. package/.github/ISSUE_TEMPLATE/feature_request.md +27 -0
  3. package/.github/PULL_REQUEST_TEMPLATE.md +28 -0
  4. package/.github/workflows/ci.yml +41 -0
  5. package/.github/workflows/deploy-pages.yml +29 -0
  6. package/.github/workflows/manual_release_pipeline.yml +62 -0
  7. package/.github/workflows/publish.yml +85 -0
  8. package/.github/workflows/release_dispatch.yml +38 -0
  9. package/.travis.yml +5 -0
  10. package/CONTRIBUTING.md +92 -0
  11. package/LICENSE +24 -0
  12. package/ONNX_EXPORT.md +87 -0
  13. package/README.md +1173 -0
  14. package/RELEASE.md +54 -0
  15. package/dist-docs/package.json +1 -0
  16. package/dist-docs/scripts/generate-docs.d.ts +2 -0
  17. package/dist-docs/scripts/generate-docs.d.ts.map +1 -0
  18. package/dist-docs/scripts/generate-docs.js +536 -0
  19. package/dist-docs/scripts/generate-docs.js.map +1 -0
  20. package/dist-docs/scripts/render-docs-html.d.ts +2 -0
  21. package/dist-docs/scripts/render-docs-html.d.ts.map +1 -0
  22. package/dist-docs/scripts/render-docs-html.js +148 -0
  23. package/dist-docs/scripts/render-docs-html.js.map +1 -0
  24. package/docs/FOLDERS.md +14 -0
  25. package/docs/README.md +1173 -0
  26. package/docs/architecture/README.md +1391 -0
  27. package/docs/architecture/index.html +938 -0
  28. package/docs/architecture/network/README.md +1210 -0
  29. package/docs/architecture/network/index.html +908 -0
  30. package/docs/assets/ascii-maze.bundle.js +16542 -0
  31. package/docs/assets/ascii-maze.bundle.js.map +7 -0
  32. package/docs/index.html +1419 -0
  33. package/docs/methods/README.md +670 -0
  34. package/docs/methods/index.html +477 -0
  35. package/docs/multithreading/README.md +274 -0
  36. package/docs/multithreading/index.html +215 -0
  37. package/docs/multithreading/workers/README.md +23 -0
  38. package/docs/multithreading/workers/browser/README.md +39 -0
  39. package/docs/multithreading/workers/browser/index.html +70 -0
  40. package/docs/multithreading/workers/index.html +57 -0
  41. package/docs/multithreading/workers/node/README.md +33 -0
  42. package/docs/multithreading/workers/node/index.html +66 -0
  43. package/docs/neat/README.md +1284 -0
  44. package/docs/neat/index.html +906 -0
  45. package/docs/src/README.md +2659 -0
  46. package/docs/src/index.html +1579 -0
  47. package/jest.config.ts +32 -0
  48. package/package.json +99 -0
  49. package/plans/HyperMorphoNEAT.md +293 -0
  50. package/plans/ONNX_EXPORT_PLAN.md +46 -0
  51. package/scripts/generate-docs.ts +486 -0
  52. package/scripts/render-docs-html.ts +138 -0
  53. package/scripts/types.d.ts +2 -0
  54. package/src/README.md +2659 -0
  55. package/src/architecture/README.md +1391 -0
  56. package/src/architecture/activationArrayPool.ts +135 -0
  57. package/src/architecture/architect.ts +635 -0
  58. package/src/architecture/connection.ts +148 -0
  59. package/src/architecture/group.ts +406 -0
  60. package/src/architecture/layer.ts +804 -0
  61. package/src/architecture/network/README.md +1210 -0
  62. package/src/architecture/network/network.activate.ts +223 -0
  63. package/src/architecture/network/network.connect.ts +157 -0
  64. package/src/architecture/network/network.deterministic.ts +167 -0
  65. package/src/architecture/network/network.evolve.ts +426 -0
  66. package/src/architecture/network/network.gating.ts +186 -0
  67. package/src/architecture/network/network.genetic.ts +247 -0
  68. package/src/architecture/network/network.mutate.ts +624 -0
  69. package/src/architecture/network/network.onnx.ts +463 -0
  70. package/src/architecture/network/network.prune.ts +216 -0
  71. package/src/architecture/network/network.remove.ts +96 -0
  72. package/src/architecture/network/network.serialize.ts +309 -0
  73. package/src/architecture/network/network.slab.ts +262 -0
  74. package/src/architecture/network/network.standalone.ts +246 -0
  75. package/src/architecture/network/network.stats.ts +59 -0
  76. package/src/architecture/network/network.topology.ts +86 -0
  77. package/src/architecture/network/network.training.ts +1278 -0
  78. package/src/architecture/network.ts +1302 -0
  79. package/src/architecture/node.ts +1288 -0
  80. package/src/architecture/onnx.ts +3 -0
  81. package/src/config.ts +83 -0
  82. package/src/methods/README.md +670 -0
  83. package/src/methods/activation.ts +372 -0
  84. package/src/methods/connection.ts +31 -0
  85. package/src/methods/cost.ts +347 -0
  86. package/src/methods/crossover.ts +63 -0
  87. package/src/methods/gating.ts +43 -0
  88. package/src/methods/methods.ts +8 -0
  89. package/src/methods/mutation.ts +300 -0
  90. package/src/methods/rate.ts +257 -0
  91. package/src/methods/selection.ts +65 -0
  92. package/src/multithreading/README.md +274 -0
  93. package/src/multithreading/multi.ts +339 -0
  94. package/src/multithreading/workers/README.md +23 -0
  95. package/src/multithreading/workers/browser/README.md +39 -0
  96. package/src/multithreading/workers/browser/testworker.ts +99 -0
  97. package/src/multithreading/workers/node/README.md +33 -0
  98. package/src/multithreading/workers/node/testworker.ts +72 -0
  99. package/src/multithreading/workers/node/worker.ts +70 -0
  100. package/src/multithreading/workers/workers.ts +22 -0
  101. package/src/neat/README.md +1284 -0
  102. package/src/neat/neat.adaptive.ts +544 -0
  103. package/src/neat/neat.compat.ts +164 -0
  104. package/src/neat/neat.constants.ts +20 -0
  105. package/src/neat/neat.diversity.ts +217 -0
  106. package/src/neat/neat.evaluate.ts +328 -0
  107. package/src/neat/neat.evolve.ts +1026 -0
  108. package/src/neat/neat.export.ts +249 -0
  109. package/src/neat/neat.helpers.ts +235 -0
  110. package/src/neat/neat.lineage.ts +220 -0
  111. package/src/neat/neat.multiobjective.ts +260 -0
  112. package/src/neat/neat.mutation.ts +718 -0
  113. package/src/neat/neat.objectives.ts +157 -0
  114. package/src/neat/neat.pruning.ts +190 -0
  115. package/src/neat/neat.selection.ts +269 -0
  116. package/src/neat/neat.speciation.ts +460 -0
  117. package/src/neat/neat.species.ts +151 -0
  118. package/src/neat/neat.telemetry.exports.ts +469 -0
  119. package/src/neat/neat.telemetry.ts +933 -0
  120. package/src/neat/neat.types.ts +275 -0
  121. package/src/neat.ts +1042 -0
  122. package/src/neataptic.ts +10 -0
  123. package/test/architecture/activationArrayPool.capacity.test.ts +19 -0
  124. package/test/architecture/activationArrayPool.test.ts +46 -0
  125. package/test/architecture/connection.test.ts +290 -0
  126. package/test/architecture/group.test.ts +950 -0
  127. package/test/architecture/layer.test.ts +1535 -0
  128. package/test/architecture/network.pruning.test.ts +65 -0
  129. package/test/architecture/node.test.ts +1602 -0
  130. package/test/examples/asciiMaze/asciiMaze.e2e.test.ts +499 -0
  131. package/test/examples/asciiMaze/asciiMaze.ts +41 -0
  132. package/test/examples/asciiMaze/browser-entry.ts +164 -0
  133. package/test/examples/asciiMaze/browserLogger.ts +221 -0
  134. package/test/examples/asciiMaze/browserTerminalUtility.ts +48 -0
  135. package/test/examples/asciiMaze/colors.ts +119 -0
  136. package/test/examples/asciiMaze/dashboardManager.ts +968 -0
  137. package/test/examples/asciiMaze/evolutionEngine.ts +1248 -0
  138. package/test/examples/asciiMaze/fitness.ts +136 -0
  139. package/test/examples/asciiMaze/index.html +128 -0
  140. package/test/examples/asciiMaze/index.ts +26 -0
  141. package/test/examples/asciiMaze/interfaces.ts +235 -0
  142. package/test/examples/asciiMaze/mazeMovement.ts +996 -0
  143. package/test/examples/asciiMaze/mazeUtils.ts +278 -0
  144. package/test/examples/asciiMaze/mazeVision.ts +402 -0
  145. package/test/examples/asciiMaze/mazeVisualization.ts +585 -0
  146. package/test/examples/asciiMaze/mazes.ts +245 -0
  147. package/test/examples/asciiMaze/networkRefinement.ts +76 -0
  148. package/test/examples/asciiMaze/networkVisualization.ts +901 -0
  149. package/test/examples/asciiMaze/terminalUtility.ts +73 -0
  150. package/test/methods/activation.test.ts +1142 -0
  151. package/test/methods/connection.test.ts +146 -0
  152. package/test/methods/cost.test.ts +1123 -0
  153. package/test/methods/crossover.test.ts +202 -0
  154. package/test/methods/gating.test.ts +144 -0
  155. package/test/methods/mutation.test.ts +451 -0
  156. package/test/methods/optimizers.advanced.test.ts +80 -0
  157. package/test/methods/optimizers.behavior.test.ts +105 -0
  158. package/test/methods/optimizers.formula.test.ts +89 -0
  159. package/test/methods/rate.cosineWarmRestarts.test.ts +44 -0
  160. package/test/methods/rate.linearWarmupDecay.test.ts +41 -0
  161. package/test/methods/rate.reduceOnPlateau.test.ts +45 -0
  162. package/test/methods/rate.test.ts +684 -0
  163. package/test/methods/selection.test.ts +245 -0
  164. package/test/multithreading/activations.functions.test.ts +54 -0
  165. package/test/multithreading/multi.test.ts +290 -0
  166. package/test/multithreading/worker.node.process.test.ts +39 -0
  167. package/test/multithreading/workers.coverage.test.ts +36 -0
  168. package/test/multithreading/workers.dynamic.import.test.ts +8 -0
  169. package/test/neat/neat.adaptive.complexityBudget.test.ts +34 -0
  170. package/test/neat/neat.adaptive.criterion.complexity.test.ts +50 -0
  171. package/test/neat/neat.adaptive.mutation.strategy.test.ts +37 -0
  172. package/test/neat/neat.adaptive.operator.decay.test.ts +31 -0
  173. package/test/neat/neat.adaptive.phasedComplexity.test.ts +25 -0
  174. package/test/neat/neat.adaptive.pruning.test.ts +25 -0
  175. package/test/neat/neat.adaptive.targetSpecies.test.ts +43 -0
  176. package/test/neat/neat.additional.coverage.test.ts +126 -0
  177. package/test/neat/neat.advanced.enhancements.test.ts +85 -0
  178. package/test/neat/neat.advanced.test.ts +589 -0
  179. package/test/neat/neat.diversity.autocompat.test.ts +47 -0
  180. package/test/neat/neat.diversity.metrics.test.ts +21 -0
  181. package/test/neat/neat.diversity.stats.test.ts +44 -0
  182. package/test/neat/neat.enhancements.test.ts +79 -0
  183. package/test/neat/neat.entropy.ancestorAdaptive.test.ts +133 -0
  184. package/test/neat/neat.entropy.compat.csv.test.ts +108 -0
  185. package/test/neat/neat.evolution.pruning.test.ts +39 -0
  186. package/test/neat/neat.fastmode.autotune.test.ts +42 -0
  187. package/test/neat/neat.innovation.test.ts +134 -0
  188. package/test/neat/neat.lineage.antibreeding.test.ts +35 -0
  189. package/test/neat/neat.lineage.entropy.test.ts +56 -0
  190. package/test/neat/neat.lineage.inbreeding.test.ts +49 -0
  191. package/test/neat/neat.lineage.pressure.test.ts +29 -0
  192. package/test/neat/neat.multiobjective.adaptive.test.ts +57 -0
  193. package/test/neat/neat.multiobjective.dynamic.schedule.test.ts +46 -0
  194. package/test/neat/neat.multiobjective.dynamic.test.ts +31 -0
  195. package/test/neat/neat.multiobjective.fastsort.delegation.test.ts +51 -0
  196. package/test/neat/neat.multiobjective.prune.test.ts +39 -0
  197. package/test/neat/neat.multiobjective.test.ts +21 -0
  198. package/test/neat/neat.mutation.undefined.pool.test.ts +24 -0
  199. package/test/neat/neat.objective.events.test.ts +26 -0
  200. package/test/neat/neat.objective.importance.test.ts +21 -0
  201. package/test/neat/neat.objective.lifetimes.test.ts +33 -0
  202. package/test/neat/neat.offspring.allocation.test.ts +22 -0
  203. package/test/neat/neat.operator.bandit.test.ts +17 -0
  204. package/test/neat/neat.operator.phases.test.ts +38 -0
  205. package/test/neat/neat.pruneInactive.behavior.test.ts +54 -0
  206. package/test/neat/neat.reenable.adaptation.test.ts +18 -0
  207. package/test/neat/neat.rng.state.test.ts +22 -0
  208. package/test/neat/neat.spawn.add.test.ts +123 -0
  209. package/test/neat/neat.speciation.test.ts +96 -0
  210. package/test/neat/neat.species.allocation.telemetry.test.ts +26 -0
  211. package/test/neat/neat.species.history.csv.test.ts +24 -0
  212. package/test/neat/neat.telemetry.advanced.test.ts +226 -0
  213. package/test/neat/neat.telemetry.csv.lineage.test.ts +19 -0
  214. package/test/neat/neat.telemetry.parity.test.ts +42 -0
  215. package/test/neat/neat.telemetry.stream.test.ts +19 -0
  216. package/test/neat/neat.telemetry.test.ts +16 -0
  217. package/test/neat/neat.test.ts +422 -0
  218. package/test/neat/neat.utilities.test.ts +44 -0
  219. package/test/network/__suppress_console.ts +9 -0
  220. package/test/network/acyclic.topoorder.test.ts +17 -0
  221. package/test/network/checkpoint.metricshook.test.ts +36 -0
  222. package/test/network/error.handling.test.ts +581 -0
  223. package/test/network/evolution.test.ts +285 -0
  224. package/test/network/genetic.test.ts +208 -0
  225. package/test/network/learning.capability.test.ts +244 -0
  226. package/test/network/mutation.effects.test.ts +492 -0
  227. package/test/network/network.activate.test.ts +115 -0
  228. package/test/network/network.activateBatch.test.ts +30 -0
  229. package/test/network/network.deterministic.test.ts +64 -0
  230. package/test/network/network.evolve.branches.test.ts +75 -0
  231. package/test/network/network.evolve.multithread.branches.test.ts +83 -0
  232. package/test/network/network.evolve.test.ts +100 -0
  233. package/test/network/network.gating.removal.test.ts +93 -0
  234. package/test/network/network.mutate.additional.test.ts +145 -0
  235. package/test/network/network.mutate.edgecases.test.ts +101 -0
  236. package/test/network/network.mutate.test.ts +101 -0
  237. package/test/network/network.prune.earlyexit.test.ts +38 -0
  238. package/test/network/network.remove.errors.test.ts +45 -0
  239. package/test/network/network.slab.fallbacks.test.ts +22 -0
  240. package/test/network/network.stats.test.ts +45 -0
  241. package/test/network/network.training.advanced.test.ts +149 -0
  242. package/test/network/network.training.basic.test.ts +228 -0
  243. package/test/network/network.training.helpers.test.ts +183 -0
  244. package/test/network/onnx.export.test.ts +310 -0
  245. package/test/network/onnx.import.test.ts +129 -0
  246. package/test/network/pruning.topology.test.ts +282 -0
  247. package/test/network/regularization.determinism.test.ts +83 -0
  248. package/test/network/regularization.dropconnect.test.ts +17 -0
  249. package/test/network/regularization.dropconnect.validation.test.ts +18 -0
  250. package/test/network/regularization.stochasticdepth.test.ts +27 -0
  251. package/test/network/regularization.test.ts +843 -0
  252. package/test/network/regularization.weightnoise.test.ts +30 -0
  253. package/test/network/setupTests.ts +2 -0
  254. package/test/network/standalone.test.ts +332 -0
  255. package/test/network/structure.serialization.test.ts +660 -0
  256. package/test/training/training.determinism.mixed-precision.test.ts +134 -0
  257. package/test/training/training.earlystopping.test.ts +91 -0
  258. package/test/training/training.edge-cases.test.ts +91 -0
  259. package/test/training/training.extensions.test.ts +47 -0
  260. package/test/training/training.gradient.features.test.ts +110 -0
  261. package/test/training/training.gradient.refinements.test.ts +170 -0
  262. package/test/training/training.gradient.separate-bias.test.ts +41 -0
  263. package/test/training/training.optimizer.test.ts +48 -0
  264. package/test/training/training.plateau.smoothing.test.ts +58 -0
  265. package/test/training/training.smoothing.types.test.ts +174 -0
  266. package/test/training/training.train.options.coverage.test.ts +52 -0
  267. package/test/utils/console-helper.ts +76 -0
  268. package/test/utils/jest-setup.ts +60 -0
  269. package/test/utils/test-helpers.ts +175 -0
  270. package/tsconfig.docs.json +12 -0
  271. package/tsconfig.json +21 -0
  272. package/webpack.config.js +49 -0
@@ -0,0 +1,1535 @@
1
+ import Layer from '../../src/architecture/layer';
2
+ import Node from '../../src/architecture/node';
3
+ import Group from '../../src/architecture/group';
4
+ import Connection from '../../src/architecture/connection';
5
+ import * as methods from '../../src/methods/methods';
6
+ import { config } from '../../src/config';
7
+
8
+ // Retry failed tests
9
+ jest.retryTimes(2, { logErrorsBeforeRetry: true });
10
+
11
/**
 * Test helper: reports whether `groupA` has connections leading into `groupB`,
 * judged by inspecting each node's `connections.out` and `connections.self`.
 *
 * @param groupA - Candidate source group.
 * @param groupB - Candidate target group.
 * @param method - When it is `methods.groupConnection.ONE_TO_ONE`, both groups
 *   must have the same length and node `i` of `groupA` must link to node `i`
 *   of `groupB`. Any other value (including `undefined`) succeeds as soon as
 *   ANY node of `groupA` links to ANY node of `groupB`.
 * @returns `false` when either group, or either group's `nodes`, is missing.
 */
const isGroupConnectedTo = (
  groupA: Group,
  groupB: Group,
  method?: any
): boolean => {
  if (!(groupA && groupB && groupA.nodes && groupB.nodes)) return false;

  type Member = Group['nodes'][number];

  // `from` links to `to` via an outgoing edge, or, when both are the same
  // node, via a self-connection pointing back at it.
  const isLinked = (from: Member, to: Member): boolean =>
    from.connections.out.some((conn) => conn.to === to) ||
    (from === to &&
      from.connections.self.some((conn: Connection) => conn.to === to));

  if (method !== methods.groupConnection.ONE_TO_ONE) {
    return groupA.nodes.some(
      (from) => !!from && groupB.nodes.some((to) => !!to && isLinked(from, to))
    );
  }

  // ONE_TO_ONE requires a strict positional pairing.
  if (groupA.nodes.length !== groupB.nodes.length) return false;
  return groupA.nodes.every((from, i) => {
    const to = groupB.nodes[i];
    return !!from && !!to && isLinked(from, to);
  });
};
50
+
51
+ describe('Layer', () => {
52
+ const epsilon = 1e-9; // Tolerance for float comparisons
53
+
54
describe('Constructor', () => {
  let layer: Layer;

  // Every case inspects a freshly constructed, argument-less Layer.
  beforeEach(() => {
    layer = new Layer();
  });

  it('should initialize output to null', () => {
    // Arrange, Act & Assert
    expect(layer.output).toBeNull();
  });

  // Each of these collections must start out as an empty array. The labels
  // reproduce the original test names exactly.
  const emptyCollections: Array<[string, (subject: Layer) => unknown]> = [
    ['nodes', (subject) => subject.nodes],
    ['connections.in', (subject) => subject.connections.in],
    ['connections.out', (subject) => subject.connections.out],
    ['connections.self', (subject) => subject.connections.self],
  ];

  for (const [label, read] of emptyCollections) {
    it(`should initialize ${label} as an empty array`, () => {
      // Arrange, Act & Assert
      expect(read(layer)).toEqual([]);
    });
  }
});
86
+
87
/**
 * Builds a Layer holding `size` default-constructed Nodes and assigns a new
 * `Group(size)` as its output.
 *
 * NOTE(review): the output Group is constructed independently, so its nodes
 * are not the same objects as `layer.nodes`. Tests only use it as a spy target.
 */
const createTestLayer = (size: number): Layer => {
  const layer = new Layer();
  // The nodes are created before the Group, matching the original order.
  layer.nodes.push(...Array.from({ length: size }, () => new Node()));
  layer.output = new Group(size);
  return layer;
};
97
+
98
+ describe('Instance Methods', () => {
99
describe('activate()', () => {
  // Builds a Layer containing `count` freshly constructed nodes of `type`.
  const buildLayer = (type: 'input' | 'hidden', count: number): Layer => {
    const built = new Layer();
    for (let n = 0; n < count; n += 1) {
      built.nodes.push(new Node(type));
    }
    return built;
  };

  describe('Scenario: input layer', () => {
    it('should return activation values from input nodes with input', () => {
      // Arrange
      const inputValues = [0.5, -0.2, 0.9];
      const layer = buildLayer('input', inputValues.length);
      // Act
      const activations = layer.activate(inputValues);
      // Assert: input nodes hand back the supplied values unchanged.
      expect(activations).toHaveLength(inputValues.length);
      inputValues.forEach((expected, i) => {
        expect(activations[i]).toBe(expected);
      });
    });
  });

  describe('Scenario: hidden layer', () => {
    it('should return activation values after applying activation function', () => {
      // Arrange
      const inputValues = [0.5, -0.2, 0.9];
      const layer = buildLayer('hidden', inputValues.length);
      // Act
      const activations = layer.activate(inputValues);
      // Assert: default activation is sigmoid (logistic).
      expect(activations).toHaveLength(inputValues.length);
      inputValues.forEach((value, i) => {
        const logistic = 1 / (1 + Math.exp(-value));
        expect(activations[i]).toBeCloseTo(logistic, 10);
      });
    });
  });
});
140
+
141
describe('propagate()', () => {
  const size = 3;
  let layer: Layer;
  let nodeSpies: jest.SpyInstance[];
  const rate = 0.1;
  const momentum = 0.9;

  // Spy on (but do not stub) each node's propagate so calls pass through.
  beforeEach(() => {
    layer = createTestLayer(size);
    nodeSpies = layer.nodes.map((node) => jest.spyOn(node, 'propagate'));
  });

  afterEach(() => {
    nodeSpies.forEach((spy) => spy.mockRestore());
  });

  it('should call propagate on all nodes without target values', () => {
    // Act
    layer.propagate(rate, momentum);
    // Assert: every node is propagated exactly once with the positional
    // arguments (rate, momentum, true, 0) and no target.
    expect(nodeSpies).toHaveLength(size);
    nodeSpies.forEach((spy) => {
      expect(spy).toHaveBeenCalledTimes(1);
      expect(spy).toHaveBeenCalledWith(rate, momentum, true, 0);
    });
  });

  it('should call propagate on nodes in reverse order', () => {
    // Arrange: record the index of each node as it is propagated.
    const callOrder: number[] = [];
    nodeSpies.forEach((spy, i) =>
      spy.mockImplementation(() => callOrder.push(i))
    );
    // Act
    layer.propagate(rate, momentum);
    // Assert: the last node is propagated first.
    expect(callOrder).toEqual([2, 1, 0]);
  });

  it('should call propagate on all nodes with target values', () => {
    // Arrange
    const targetValues = [0.8, 0.1, 0.5];
    // Act
    layer.propagate(rate, momentum, targetValues);
    // Assert: node i receives targetValues[i]. Index the spies directly,
    // rather than reversing (and thereby mutating) the shared spy array.
    expect(nodeSpies).toHaveLength(size);
    nodeSpies.forEach((spy, i) => {
      expect(spy).toHaveBeenCalledTimes(1);
      expect(spy).toHaveBeenCalledWith(
        rate,
        momentum,
        true,
        0,
        targetValues[i]
      );
    });
  });

  it('should throw error if target value array length mismatches layer node count', () => {
    // Arrange: one value short of the layer's node count.
    const invalidTarget = [0.1, 0.2];
    // Act & Assert
    expect(() => layer.propagate(rate, momentum, invalidTarget)).toThrow(
      'Array with values should be same as the amount of nodes!'
    );
  });
});
210
+
211
describe('connect()', () => {
  let sourceLayer: Layer;
  let targetLayer: Layer;
  let targetGroup: Group;
  let targetNode: Node;
  let sourceOutputSpy: jest.SpyInstance | undefined;
  let targetInputSpy: jest.SpyInstance;

  // Asserts that `spy` was invoked exactly once, and with `args`.
  const expectSingleCall = (
    spy: jest.SpyInstance | undefined,
    ...args: unknown[]
  ): void => {
    expect(spy).toHaveBeenCalledTimes(1);
    expect(spy).toHaveBeenCalledWith(...args);
  };

  beforeEach(() => {
    sourceLayer = createTestLayer(3);
    targetLayer = createTestLayer(3);
    targetGroup = new Group(2);
    targetNode = new Node();

    // createTestLayer always assigns an output group, so this spy is set.
    const { output } = sourceLayer;
    if (output) sourceOutputSpy = jest.spyOn(output, 'connect');
    targetInputSpy = jest.spyOn(targetLayer, 'input');
  });

  afterEach(() => {
    [sourceOutputSpy, targetInputSpy].forEach((spy) => spy?.mockRestore());
  });

  it('should throw error if source layer output is not defined', () => {
    // Arrange: a bare Layer has a null output.
    const bareLayer = new Layer();
    // Act & Assert
    expect(() => bareLayer.connect(targetGroup)).toThrow(
      'Layer output is not defined. Cannot connect from this layer.'
    );
  });

  it('should call output.connect when connecting to a Group', () => {
    // Arrange
    const method = methods.groupConnection.ALL_TO_ALL;
    const weight = 0.5;
    // Act
    sourceLayer.connect(targetGroup, method, weight);
    // Assert
    expectSingleCall(sourceOutputSpy, targetGroup, method, weight);
  });

  it('should call output.connect when connecting to a Node', () => {
    // Arrange
    const weight = 0.6;
    // Act
    sourceLayer.connect(targetNode, undefined, weight);
    // Assert
    expectSingleCall(sourceOutputSpy, targetNode, undefined, weight);
  });

  it('should call targetLayer.input when connecting to another Layer', () => {
    // Arrange
    const method = methods.groupConnection.ONE_TO_ONE;
    const weight = 0.7;
    // Act
    sourceLayer.connect(targetLayer, method, weight);
    // Assert: the target layer wires itself to the source layer, which ends
    // in a single output.connect call toward the target layer's output.
    expectSingleCall(targetInputSpy, sourceLayer, method, weight);
    expectSingleCall(sourceOutputSpy, targetLayer.output, method, weight);
  });

  it('should return the connections created by output.connect (Group target)', () => {
    // Arrange
    const stubbed = [new Connection(new Node(), new Node())];
    sourceOutputSpy?.mockReturnValue(stubbed);
    // Act & Assert
    expect(sourceLayer.connect(targetGroup)).toBe(stubbed);
  });

  it('should return the connections created by targetLayer.input (Layer target)', () => {
    // Arrange
    const stubbed = [new Connection(new Node(), new Node())];
    targetInputSpy.mockReturnValue(stubbed);
    // Act & Assert
    expect(sourceLayer.connect(targetLayer)).toBe(stubbed);
  });
});
312
+
313
describe('gate()', () => {
  let layer: Layer;
  let connectionsToGate: Connection[];
  let outputGroupSpy: jest.SpyInstance;

  beforeEach(() => {
    layer = createTestLayer(2);
    // A single connection between two standalone nodes is enough to gate.
    connectionsToGate = [new Connection(new Node(), new Node())];

    const { output } = layer;
    if (output) outputGroupSpy = jest.spyOn(output, 'gate');
  });

  afterEach(() => {
    outputGroupSpy?.mockRestore();
  });

  it('should throw error if layer output is not defined', () => {
    // Arrange: a bare Layer has a null output.
    const bareLayer = new Layer();
    // Act & Assert
    const gateFromBare = () =>
      bareLayer.gate(connectionsToGate, methods.gating.INPUT);
    expect(gateFromBare).toThrow(
      'Layer output is not defined. Cannot gate from this layer.'
    );
  });

  it('should call output.gate with the provided connections and method', () => {
    // Arrange
    const method = methods.gating.OUTPUT;
    // Act
    layer.gate(connectionsToGate, method);
    // Assert: the layer delegates verbatim to its output group.
    expect(outputGroupSpy).toHaveBeenCalledTimes(1);
    expect(outputGroupSpy).toHaveBeenCalledWith(connectionsToGate, method);
  });
});
351
+
352
describe('set()', () => {
  const size = 3;
  let layer: Layer;

  // Give each node a distinct bias (i * 0.1), the logistic squash and the
  // 'hidden' type, so each test can detect exactly what `set` changed.
  beforeEach(() => {
    layer = createTestLayer(size);
    layer.nodes.forEach((node, i) => {
      node.bias = i * 0.1;
      node.squash = methods.Activation.logistic;
      node.type = 'hidden';
    });
  });

  it('should set bias for all nodes', () => {
    // Arrange
    const biasValue = 0.7;
    // Act
    layer.set({ bias: biasValue });
    // Assert
    layer.nodes.forEach((node) => {
      expect(node.bias).toBe(biasValue);
    });
  });

  it('should set squash function for all nodes', () => {
    // Arrange
    const squashFn = methods.Activation.relu;
    // Act
    layer.set({ squash: squashFn });
    // Assert
    layer.nodes.forEach((node) => {
      expect(node.squash).toBe(squashFn);
    });
  });

  it('should set type for all nodes', () => {
    // Arrange
    const typeValue = 'output';
    // Act
    layer.set({ type: typeValue });
    // Assert
    layer.nodes.forEach((node) => {
      expect(node.type).toBe(typeValue);
    });
  });

  it('should set multiple properties at once', () => {
    // Arrange
    const biasValue = -0.2;
    const squashFn = methods.Activation.tanh;
    const typeValue = 'input';
    // Act
    layer.set({ bias: biasValue, squash: squashFn, type: typeValue });
    // Assert
    layer.nodes.forEach((node) => {
      expect(node.bias).toBe(biasValue);
      expect(node.squash).toBe(squashFn);
      expect(node.type).toBe(typeValue);
    });
  });

  it('should not change properties if not provided in values object', () => {
    // Arrange: snapshot the fields that a bias-only update must leave alone.
    const initialSquashes = layer.nodes.map((node) => node.squash);
    const initialTypes = layer.nodes.map((node) => node.type);
    // Act
    layer.set({ bias: 0.99 });
    // Assert
    layer.nodes.forEach((node, i) => {
      expect(node.bias).toBe(0.99);
      expect(node.squash).toBe(initialSquashes[i]);
      expect(node.type).toBe(initialTypes[i]);
    });
  });

  it('should call set on Group instances within the layer nodes', () => {
    // Arrange: this test treats the entries of a memory layer's `nodes` as
    // Groups (hence the cast) and spies on their `set`.
    const memoryLayer = Layer.memory(2, 2);
    const groupSetSpies = memoryLayer.nodes.map((groupNode) =>
      jest.spyOn((groupNode as unknown) as Group, 'set')
    );
    const settings = { bias: 0.1, squash: methods.Activation.relu };

    try {
      // Act
      memoryLayer.set(settings);
      // Assert: every Group receives the same settings object once.
      groupSetSpies.forEach((spy) => {
        expect(spy).toHaveBeenCalledTimes(1);
        expect(spy).toHaveBeenCalledWith(settings);
      });
    } finally {
      // Restore the spies even when an assertion above fails.
      groupSetSpies.forEach((spy) => spy.mockRestore());
    }
  });
});
441
+
442
+ describe('disconnect()', () => {
443
+ let layer: Layer;
444
+ let targetGroup: Group;
445
+ let targetNode: Node;
446
+ let nodeDisconnectSpies: jest.SpyInstance[];
447
+
448
+ beforeEach(() => {
449
+ layer = createTestLayer(2);
450
+ targetGroup = new Group(2);
451
+ targetNode = new Node();
452
+
453
+ layer.nodes.forEach((node) => {
454
+ targetGroup.nodes.forEach((target) => node.connect(target));
455
+ node.connect(targetNode);
456
+ });
457
+ layer.nodes[0].connections.out.forEach((conn) =>
458
+ layer.connections.out.push(conn)
459
+ );
460
+ layer.nodes[1].connections.out.forEach((conn) =>
461
+ layer.connections.out.push(conn)
462
+ );
463
+
464
+ nodeDisconnectSpies = layer.nodes.map((node) =>
465
+ jest.spyOn(node, 'disconnect')
466
+ );
467
+ });
468
+
469
+ afterEach(() => {
470
+ nodeDisconnectSpies.forEach((spy) => spy.mockRestore());
471
+ });
472
+
473
describe('Disconnecting from Group', () => {
  /**
   * Asserts that both layer nodes called `disconnect` once per node of
   * `targetGroup`, each time forwarding the given `twosided` flag.
   */
  const expectDisconnectPerTarget = (twosided: boolean) => {
    const [firstSpy, secondSpy] = nodeDisconnectSpies;
    const expectedCalls = targetGroup.nodes.length;
    expect(firstSpy).toHaveBeenCalledTimes(expectedCalls);
    expect(secondSpy).toHaveBeenCalledTimes(expectedCalls);
    for (const target of targetGroup.nodes) {
      expect(firstSpy).toHaveBeenCalledWith(target, twosided);
      expect(secondSpy).toHaveBeenCalledWith(target, twosided);
    }
  };

  it('should call disconnect on each layer node for each target group node (one-sided)', () => {
    layer.disconnect(targetGroup, false);
    expectDisconnectPerTarget(false);
  });

  it('should call disconnect on each layer node for each target group node (two-sided)', () => {
    layer.disconnect(targetGroup, true);
    expectDisconnectPerTarget(true);
  });

  it('should remove outgoing connections from layer.connections.out (one-sided)', () => {
    const before = layer.connections.out.length;
    // 2 layer nodes x (2 targetGroup nodes + targetNode) = 6 outgoing links.
    expect(before).toBe(6);

    layer.disconnect(targetGroup, false);

    // The 4 links into targetGroup go away; only links to targetNode remain.
    expect(layer.connections.out).toHaveLength(before - 4);
    for (const { to } of layer.connections.out) {
      expect(to).toBe(targetNode);
    }
  });

  it('should remove incoming connections from layer.connections.in if two-sided', () => {
    const [sourceA, sourceB] = targetGroup.nodes;
    sourceA.connect(layer.nodes[0]);
    sourceB.connect(layer.nodes[1]);
    // Each group node had no outgoing links before, so index 0 is the new one.
    layer.connections.in.push(
      sourceA.connections.out[0],
      sourceB.connections.out[0]
    );
    expect(layer.connections.in.length).toBe(2);

    layer.disconnect(targetGroup, true);

    expect(layer.connections.in).toHaveLength(0);
  });
});
525
+
526
describe('Disconnecting from Node', () => {
  /**
   * Asserts that every layer node called `disconnect(targetNode, twosided)`
   * exactly once.
   */
  const expectSingleDisconnectFromTarget = (twosided: boolean) => {
    nodeDisconnectSpies.forEach((spy) => {
      expect(spy).toHaveBeenCalledTimes(1);
      expect(spy).toHaveBeenCalledWith(targetNode, twosided);
    });
  };

  it('should call disconnect on each layer node for the target node (one-sided)', () => {
    layer.disconnect(targetNode, false);
    expectSingleDisconnectFromTarget(false);
  });

  it('should call disconnect on each layer node for the target node (two-sided)', () => {
    layer.disconnect(targetNode, true);
    expectSingleDisconnectFromTarget(true);
  });

  it('should remove outgoing connections to the node from layer.connections.out (one-sided)', () => {
    const before = layer.connections.out.length;

    layer.disconnect(targetNode, false);

    // One link per layer node pointed at targetNode; both must be dropped,
    // leaving only links into targetGroup.
    expect(layer.connections.out).toHaveLength(before - 2);
    for (const { to } of layer.connections.out) {
      expect(targetGroup.nodes).toContain(to);
    }
  });

  it('should remove incoming connections from the node in layer.connections.in if two-sided', () => {
    targetNode.connect(layer.nodes[0]);
    // targetNode had no outgoing links before, so the first one is new.
    const [incoming] = targetNode.connections.out;
    layer.connections.in.push(incoming);
    expect(layer.connections.in.length).toBe(1);

    layer.disconnect(targetNode, true);

    expect(layer.connections.in).toHaveLength(0);
  });
});
569
+ });
570
+
571
describe('clear()', () => {
  const size = 3;
  let layer: Layer;
  let nodeSpies: jest.SpyInstance[];

  beforeEach(() => {
    layer = createTestLayer(size);
    nodeSpies = layer.nodes.map((member) => jest.spyOn(member, 'clear'));
  });

  afterEach(() => {
    for (const spy of nodeSpies) {
      spy.mockRestore();
    }
  });

  it('should call clear on all nodes in the layer', () => {
    layer.clear();

    expect(nodeSpies).toHaveLength(size);
    for (const spy of nodeSpies) {
      expect(spy).toHaveBeenCalledTimes(1);
    }
  });

  it('should call clear on Group instances within the layer nodes (e.g., Memory layer)', () => {
    // A memory layer stores Groups inside `nodes`, so cast before spying.
    const memoryLayer = Layer.memory(2, 2);
    const blocks = memoryLayer.nodes.map(
      (entry) => (entry as unknown) as Group
    );
    const groupClearSpies = blocks.map((block) => jest.spyOn(block, 'clear'));

    memoryLayer.clear();

    for (const spy of groupClearSpies) {
      expect(spy).toHaveBeenCalledTimes(1);
    }
    for (const spy of groupClearSpies) {
      spy.mockRestore();
    }
  });
});
608
+
609
describe('input()', () => {
  let targetLayer: Layer;
  let sourceLayer: Layer;
  let sourceGroup: Group;
  let sourceOutputConnectSpy: jest.SpyInstance | undefined;
  let sourceGroupConnectSpy: jest.SpyInstance;
  let targetOutputConnectSpy: jest.SpyInstance | undefined;

  beforeEach(() => {
    targetLayer = createTestLayer(2);
    sourceLayer = createTestLayer(3);
    sourceGroup = new Group(3);

    // Spy on each connect() the wiring could route through; the Layer-output
    // spies are only installed when that output group exists.
    if (sourceLayer.output) {
      sourceOutputConnectSpy = jest.spyOn(sourceLayer.output, 'connect');
    }
    sourceGroupConnectSpy = jest.spyOn(sourceGroup, 'connect');
    if (targetLayer.output) {
      targetOutputConnectSpy = jest.spyOn(targetLayer.output, 'connect');
    }
  });

  afterEach(() => {
    [
      sourceOutputConnectSpy,
      sourceGroupConnectSpy,
      targetOutputConnectSpy,
    ].forEach((spy) => spy?.mockRestore());
  });

  it('should throw error if target layer output (acting as input) is not defined', () => {
    const bareLayer = new Layer();
    const connectFromGroup = () => bareLayer.input(sourceGroup);
    expect(connectFromGroup).toThrow(
      'Layer output (acting as input target) is not defined.'
    );
  });

  it('should use source Layer output group when connecting from a Layer', () => {
    targetLayer.input(sourceLayer);

    expect(sourceOutputConnectSpy).toHaveBeenCalledTimes(1);
    expect(sourceOutputConnectSpy).toHaveBeenCalledWith(
      targetLayer.output,
      methods.groupConnection.ALL_TO_ALL,
      undefined
    );
    expect(sourceGroupConnectSpy).not.toHaveBeenCalled();
  });

  it('should use source Group directly when connecting from a Group', () => {
    targetLayer.input(sourceGroup);

    expect(sourceGroupConnectSpy).toHaveBeenCalledTimes(1);
    expect(sourceGroupConnectSpy).toHaveBeenCalledWith(
      targetLayer.output,
      methods.groupConnection.ALL_TO_ALL,
      undefined
    );
    expect(sourceOutputConnectSpy).not.toHaveBeenCalled();
  });

  it('should use default connection method ALL_TO_ALL if none provided', () => {
    targetLayer.input(sourceGroup);

    expect(sourceGroupConnectSpy).toHaveBeenCalledWith(
      expect.anything(),
      methods.groupConnection.ALL_TO_ALL,
      undefined
    );
  });

  it('should use provided connection method', () => {
    const method = methods.groupConnection.ONE_TO_ONE;
    // Swap in a 2-node source (presumably so ONE_TO_ONE sees matching sizes
    // against the 2-node target) and spy on it instead.
    sourceGroup = new Group(2);
    sourceGroupConnectSpy = jest.spyOn(sourceGroup, 'connect');

    targetLayer.input(sourceGroup, method);

    expect(sourceGroupConnectSpy).toHaveBeenCalledWith(
      targetLayer.output,
      method,
      undefined
    );
  });

  it('should use provided weight', () => {
    const weight = 0.88;

    targetLayer.input(sourceGroup, undefined, weight);

    expect(sourceGroupConnectSpy).toHaveBeenCalledWith(
      targetLayer.output,
      methods.groupConnection.ALL_TO_ALL,
      weight
    );
  });

  it('should return the connections created by the source connect call', () => {
    const stubbed = [new Connection(new Node(), new Node())];
    sourceGroupConnectSpy.mockReturnValue(stubbed);

    expect(targetLayer.input(sourceGroup)).toBe(stubbed);
  });
});
706
+ });
707
+
708
+ describe('Static Factory Methods', () => {
709
describe('Layer.dense()', () => {
  const size = 5;
  let layer: Layer;

  beforeEach(() => {
    layer = Layer.dense(size);
  });

  it('should create a layer with the specified number of nodes', () => {
    expect(layer.nodes).toHaveLength(size);
    for (const member of layer.nodes) {
      expect(member).toBeInstanceOf(Node);
    }
  });

  it('should set the output to a Group containing all nodes', () => {
    const { output } = layer;
    expect(output).toBeInstanceOf(Group);
    expect(output?.nodes).toHaveLength(size);
    expect(output?.nodes).toEqual(layer.nodes);
  });

  it('should have a custom input method', () => {
    // The factory installs its own input(), shadowing the prototype method.
    expect(typeof layer.input).toBe('function');
    expect(layer.input).not.toBe(Layer.prototype.input);
  });

  describe('Dense Layer input() method', () => {
    let sourceLayer: Layer;
    let sourceGroup: Group;
    let sourceOutputConnectSpy: jest.SpyInstance | undefined;
    let sourceGroupConnectSpy: jest.SpyInstance;

    beforeEach(() => {
      sourceLayer = Layer.dense(3);
      sourceGroup = new Group(3);
      if (sourceLayer.output) {
        sourceOutputConnectSpy = jest.spyOn(sourceLayer.output, 'connect');
      }
      sourceGroupConnectSpy = jest.spyOn(sourceGroup, 'connect');
    });

    afterEach(() => {
      sourceOutputConnectSpy?.mockRestore();
      sourceGroupConnectSpy.mockRestore();
    });

    it('should connect source Layer output to the dense layers internal block', () => {
      const method = methods.groupConnection.ALL_TO_ALL;
      const weight = 0.1;

      layer.input(sourceLayer, method, weight);

      expect(sourceOutputConnectSpy).toHaveBeenCalledTimes(1);
      expect(sourceOutputConnectSpy).toHaveBeenCalledWith(
        layer.output,
        method,
        weight
      );
      expect(sourceGroupConnectSpy).not.toHaveBeenCalled();
    });

    it('should connect source Group to the dense layers internal block', () => {
      const method = methods.groupConnection.ONE_TO_ONE;
      // Resize the source to match the dense layer so ONE_TO_ONE lines up.
      sourceGroup = new Group(size);
      sourceGroupConnectSpy = jest.spyOn(sourceGroup, 'connect');
      const weight = 0.2;

      layer.input(sourceGroup, method, weight);

      expect(sourceGroupConnectSpy).toHaveBeenCalledTimes(1);
      expect(sourceGroupConnectSpy).toHaveBeenCalledWith(
        layer.output,
        method,
        weight
      );
      expect(sourceOutputConnectSpy).not.toHaveBeenCalled();
    });

    it('should use ALL_TO_ALL by default if method not provided', () => {
      layer.input(sourceGroup);

      expect(sourceGroupConnectSpy).toHaveBeenCalledWith(
        expect.anything(),
        methods.groupConnection.ALL_TO_ALL,
        undefined
      );
    });
  });
});
795
+
796
describe('Layer.lstm()', () => {
  const size = 2;
  let layer: Layer;
  let inputGate: Group;
  let forgetGate: Group;
  let memoryCell: Group;
  let outputGate: Group;
  let outputBlock: Group;

  /**
   * Wraps the `index`-th `size`-wide slice of `layer.nodes` in a Group.
   * The slice order (input, forget, memory, output gate, output block)
   * mirrors the node layout these tests expect from the LSTM factory.
   */
  const sliceBlock = (index: number): Group => {
    const block = new Group(size);
    block.nodes = layer.nodes.slice(index * size, (index + 1) * size);
    return block;
  };

  beforeEach(() => {
    layer = Layer.lstm(size);
    inputGate = sliceBlock(0);
    forgetGate = sliceBlock(1);
    memoryCell = sliceBlock(2);
    outputGate = sliceBlock(3);
    outputBlock = sliceBlock(4);
    // layer.output is intentionally left as the factory assigned it so the
    // output test below actually checks the factory's wiring.
  });

  it('should create a layer with 5 * size nodes', () => {
    expect(layer.nodes).toHaveLength(5 * size);
  });

  it('should set the output to the outputBlock group', () => {
    expect(layer.output?.nodes).toEqual(outputBlock.nodes);
  });

  it('should set initial bias for gates', () => {
    inputGate.nodes.forEach((node) => expect(node.bias).toBe(1));
    forgetGate.nodes.forEach((node) => expect(node.bias).toBe(1));
    outputGate.nodes.forEach((node) => expect(node.bias).toBe(1));
    memoryCell.nodes.forEach((node) => expect(node.bias).toBe(0));
    outputBlock.nodes.forEach((node) => expect(node.bias).toBe(0));
  });

  it('should establish internal connections (memoryCell to gates)', () => {
    expect(isGroupConnectedTo(memoryCell, inputGate)).toBe(true);
    expect(isGroupConnectedTo(memoryCell, forgetGate)).toBe(true);
    expect(isGroupConnectedTo(memoryCell, outputGate)).toBe(true);
  });

  it('should establish internal connections (memoryCell self-connection)', () => {
    expect(
      isGroupConnectedTo(
        memoryCell,
        memoryCell,
        methods.groupConnection.ONE_TO_ONE
      )
    ).toBe(true);
  });

  it('should establish internal connections (memoryCell to outputBlock)', () => {
    expect(isGroupConnectedTo(memoryCell, outputBlock)).toBe(true);
  });

  it('should gate the memoryCell self-connection with the forgetGate', () => {
    memoryCell.nodes.forEach((node, i) => {
      const selfConnection = node.connections.self.find(
        (conn: Connection) => conn.to === node
      );
      expect(selfConnection).toBeDefined();
      expect(selfConnection?.gater).toBe(forgetGate.nodes[i]);
    });
  });

  it('should gate the memoryCell to outputBlock connection with the outputGate', () => {
    // The gater is chosen by the *source* memoryCell index, whichever
    // outputBlock node the connection targets; check every target so the
    // test holds for any `size`.
    memoryCell.nodes.forEach((node, i) => {
      outputBlock.nodes.forEach((target) => {
        const outputConnection = node.connections.out.find(
          (conn) => conn.to === target
        );
        expect(outputConnection).toBeDefined();
        expect(outputConnection?.gater).toBe(outputGate.nodes[i]);
      });
    });
  });

  it('should have a custom input method', () => {
    expect(typeof layer.input).toBe('function');
    expect(layer.input).not.toBe(Layer.prototype.input);
  });

  describe('LSTM Layer input() method', () => {
    let sourceLayer: Layer;
    let sourceOutputConnectSpy: jest.SpyInstance | undefined;

    beforeEach(() => {
      sourceLayer = Layer.dense(3);
      if (sourceLayer.output) {
        sourceOutputConnectSpy = jest.spyOn(sourceLayer.output, 'connect');
      }
    });

    afterEach(() => {
      sourceOutputConnectSpy?.mockRestore();
    });

    it('should connect source to inputGate, forgetGate, memoryCell, and outputGate', () => {
      layer.input(sourceLayer);

      expect(sourceOutputConnectSpy).toHaveBeenCalledTimes(4);
      for (const target of [inputGate, forgetGate, memoryCell, outputGate]) {
        expect(sourceOutputConnectSpy).toHaveBeenCalledWith(
          expect.objectContaining({ nodes: target.nodes }),
          methods.groupConnection.ALL_TO_ALL,
          undefined
        );
      }
    });

    it('should gate the source-to-memoryCell connection with the inputGate', () => {
      const connections = layer.input(sourceLayer);
      const sourceNodes = sourceLayer.output!.nodes;
      const inputToMemory = connections.filter(
        (conn) =>
          sourceNodes.includes(conn.from) && memoryCell.nodes.includes(conn.to)
      );

      // Every source->memoryCell link (not just the first one found) must be
      // gated by the inputGate node sharing the target's index.
      expect(inputToMemory.length).toBeGreaterThan(0);
      for (const conn of inputToMemory) {
        const targetNodeIndex = memoryCell.nodes.indexOf(conn.to);
        expect(conn.gater).toBe(inputGate.nodes[targetNodeIndex]);
      }
    });
  });
});
952
+
953
describe('Layer.gru()', () => {
  const size = 2;
  let layer: Layer;
  let updateGate: Group;
  let inverseUpdateGate: Group;
  let resetGate: Group;
  let memoryCell: Group;
  let output: Group;
  let previousOutput: Group;

  /**
   * Wraps the `index`-th `size`-wide slice of `layer.nodes` in a Group.
   * Slice order (update, inverse update, reset, memory, output, previous
   * output) mirrors the node layout these tests expect from the GRU factory.
   */
  const sliceBlock = (index: number): Group => {
    const block = new Group(size);
    block.nodes = layer.nodes.slice(index * size, (index + 1) * size);
    return block;
  };

  /**
   * Asserts that `from.nodes[i] -> to.nodes[i]` exists with weight 1 for
   * every index (the ONE_TO_ONE, unit-weight wiring).
   */
  const expectUnitOneToOne = (from: Group, to: Group) => {
    from.nodes.forEach((node, i) => {
      const conn = node.connections.out.find((c) => c.to === to.nodes[i]);
      expect(conn).toBeDefined();
      expect(conn?.weight).toBe(1);
    });
  };

  /**
   * Asserts that every connection from `sources.nodes[i]` into any node of
   * `targets` exists and is gated by `gates.nodes[i]`. The gater follows the
   * source index regardless of target, so all targets are checked rather
   * than only the first.
   */
  const expectGatedBySourceIndex = (
    sources: Group,
    targets: Group,
    gates: Group
  ) => {
    sources.nodes.forEach((node, i) => {
      targets.nodes.forEach((target) => {
        const conn = node.connections.out.find((c) => c.to === target);
        expect(conn).toBeDefined();
        expect(conn?.gater).toBe(gates.nodes[i]);
      });
    });
  };

  beforeEach(() => {
    layer = Layer.gru(size);
    updateGate = sliceBlock(0);
    inverseUpdateGate = sliceBlock(1);
    resetGate = sliceBlock(2);
    memoryCell = sliceBlock(3);
    output = sliceBlock(4);
    previousOutput = sliceBlock(5);
    // layer.output is intentionally left as the factory assigned it so the
    // output test below actually checks the factory's wiring.
  });

  it('should create a layer with 6 * size nodes', () => {
    expect(layer.nodes).toHaveLength(6 * size);
  });

  it('should set the output to the output group', () => {
    expect(layer.output?.nodes).toEqual(output.nodes);
  });

  it('should set specific node properties', () => {
    previousOutput.nodes.forEach((node) => {
      expect(node.bias).toBe(0);
      expect(node.squash).toBe(methods.Activation.identity);
      expect(node.type).toBe('variant');
    });
    memoryCell.nodes.forEach((node) =>
      expect(node.squash).toBe(methods.Activation.tanh)
    );
    inverseUpdateGate.nodes.forEach((node) => {
      expect(node.bias).toBe(0);
      expect(node.squash).toBe(methods.Activation.inverse);
      expect(node.type).toBe('variant');
    });
    updateGate.nodes.forEach((node) => expect(node.bias).toBe(1));
    resetGate.nodes.forEach((node) => expect(node.bias).toBe(0));
  });

  it('should establish internal connections (previousOutput to gates)', () => {
    expect(isGroupConnectedTo(previousOutput, updateGate)).toBe(true);
    expect(isGroupConnectedTo(previousOutput, resetGate)).toBe(true);
  });

  it('should establish internal connections (updateGate to inverseUpdateGate)', () => {
    expect(
      isGroupConnectedTo(
        updateGate,
        inverseUpdateGate,
        methods.groupConnection.ONE_TO_ONE
      )
    ).toBe(true);
    expectUnitOneToOne(updateGate, inverseUpdateGate);
  });

  it('should establish internal connections (previousOutput to memoryCell)', () => {
    expect(isGroupConnectedTo(previousOutput, memoryCell)).toBe(true);
  });

  it('should establish internal connections (previousOutput and memoryCell to output)', () => {
    expect(isGroupConnectedTo(previousOutput, output)).toBe(true);
    expect(isGroupConnectedTo(memoryCell, output)).toBe(true);
  });

  it('should establish internal connections (output to previousOutput)', () => {
    expect(
      isGroupConnectedTo(
        output,
        previousOutput,
        methods.groupConnection.ONE_TO_ONE
      )
    ).toBe(true);
    expectUnitOneToOne(output, previousOutput);
  });

  it('should gate previousOutput->memoryCell connection with resetGate', () => {
    expectGatedBySourceIndex(previousOutput, memoryCell, resetGate);
  });

  it('should gate previousOutput->output connection with updateGate', () => {
    expectGatedBySourceIndex(previousOutput, output, updateGate);
  });

  it('should gate memoryCell->output connection with inverseUpdateGate', () => {
    expectGatedBySourceIndex(memoryCell, output, inverseUpdateGate);
  });

  it('should have a custom input method', () => {
    expect(typeof layer.input).toBe('function');
    expect(layer.input).not.toBe(Layer.prototype.input);
  });

  describe('GRU Layer input() method', () => {
    let sourceLayer: Layer;
    let sourceOutputConnectSpy: jest.SpyInstance | undefined;

    beforeEach(() => {
      sourceLayer = Layer.dense(3);
      if (sourceLayer.output) {
        sourceOutputConnectSpy = jest.spyOn(sourceLayer.output, 'connect');
      }
    });

    afterEach(() => {
      sourceOutputConnectSpy?.mockRestore();
    });

    it('should connect source to updateGate, resetGate, and memoryCell', () => {
      layer.input(sourceLayer);

      expect(sourceOutputConnectSpy).toHaveBeenCalledTimes(3);
      for (const target of [updateGate, resetGate, memoryCell]) {
        expect(sourceOutputConnectSpy).toHaveBeenCalledWith(
          expect.objectContaining({ nodes: target.nodes }),
          methods.groupConnection.ALL_TO_ALL,
          undefined
        );
      }
    });
  });
});
1132
+
1133
describe('Layer.memory()', () => {
  const size = 3;
  const memoryDepth = 2;
  let layer: Layer;

  // A memory layer keeps one Group per depth step inside `nodes`.
  const blockAt = (index: number) => (layer.nodes[index] as unknown) as Group;

  beforeEach(() => {
    layer = Layer.memory(size, memoryDepth);
  });

  it('should create a layer with memoryDepth groups in nodes array', () => {
    expect(layer.nodes).toHaveLength(memoryDepth);
    for (const entry of layer.nodes) {
      expect((layer as any).isGroup(entry)).toBe(true);
      expect(((entry as unknown) as Group).nodes).toHaveLength(size);
    }
  });

  it('should set specific properties for nodes within memory blocks', () => {
    for (const entry of layer.nodes) {
      for (const node of ((entry as unknown) as Group).nodes) {
        expect(node.squash).toBe(methods.Activation.identity);
        expect(node.bias).toBe(0);
        expect(node.type).toBe('variant');
      }
    }
  });

  it('should connect previous memory block to current one (ONE_TO_ONE, weight 1)', () => {
    // The factory reverses its blocks, so index 0 is the newest and index 1
    // the one before it; links run from the older block into the newer one.
    const newer = blockAt(0);
    const older = blockAt(1);

    expect(
      isGroupConnectedTo(older, newer, methods.groupConnection.ONE_TO_ONE)
    ).toBe(true);

    older.nodes.forEach((node, i) => {
      const link = node.connections.out.find((c) => c.to === newer.nodes[i]);
      expect(link).toBeDefined();
      expect(link?.weight).toBe(1);
    });
  });

  it('should create a concatenated output group', () => {
    expect(layer.output).toBeInstanceOf(Group);
    expect(layer.output?.nodes).toHaveLength(size * memoryDepth);
    const expected = [...blockAt(0).nodes, ...blockAt(1).nodes];
    expect(layer.output?.nodes).toEqual(expected);
  });

  it('should have a custom input method', () => {
    expect(typeof layer.input).toBe('function');
    expect(layer.input).not.toBe(Layer.prototype.input);
  });

  describe('Memory Layer input() method', () => {
    let sourceLayer: Layer;
    let sourceGroup: Group;
    let sourceOutputConnectSpy: jest.SpyInstance | undefined;
    let sourceGroupConnectSpy: jest.SpyInstance;
    let lastBlock: Group;

    /** The memory layer must wire its source ONE_TO_ONE, weight 1, into lastBlock. */
    const expectOneToOneIntoLastBlock = (spy: jest.SpyInstance | undefined) => {
      expect(spy).toHaveBeenCalledTimes(1);
      expect(spy).toHaveBeenCalledWith(
        expect.objectContaining({ nodes: lastBlock.nodes }),
        methods.groupConnection.ONE_TO_ONE,
        1
      );
    };

    beforeEach(() => {
      sourceLayer = Layer.dense(size);
      sourceGroup = new Group(size);
      lastBlock = blockAt(memoryDepth - 1);

      if (sourceLayer.output) {
        sourceOutputConnectSpy = jest.spyOn(sourceLayer.output, 'connect');
      }
      sourceGroupConnectSpy = jest.spyOn(sourceGroup, 'connect');
    });

    afterEach(() => {
      sourceOutputConnectSpy?.mockRestore();
      sourceGroupConnectSpy.mockRestore();
    });

    it('should connect source Layer output to the last memory block (ONE_TO_ONE, weight 1)', () => {
      layer.input(sourceLayer);
      expectOneToOneIntoLastBlock(sourceOutputConnectSpy);
      expect(sourceGroupConnectSpy).not.toHaveBeenCalled();
    });

    it('should connect source Group to the last memory block (ONE_TO_ONE, weight 1)', () => {
      layer.input(sourceGroup);
      expectOneToOneIntoLastBlock(sourceGroupConnectSpy);
      expect(sourceOutputConnectSpy).not.toHaveBeenCalled();
    });

    it('should ignore provided method and weight, forcing ONE_TO_ONE and weight 1', () => {
      layer.input(sourceGroup, methods.groupConnection.ALL_TO_ALL, 0.5);
      expectOneToOneIntoLastBlock(sourceGroupConnectSpy);
    });

    it('should throw error if source size does not match memory block size', () => {
      const wrongSizeSource = new Group(size + 1);
      const connectWrongSize = () => layer.input(wrongSizeSource);
      expect(connectWrongSize).toThrow(
        `Previous layer size (${wrongSizeSource.nodes.length}) must be same as memory size (${size})`
      );
    });

    it('should throw error if the target input block is not a Group (edge case)', () => {
      // Corrupt the input block slot with a plain Node.
      layer.nodes[memoryDepth - 1] = new Node();
      const connectIntoNode = () => layer.input(sourceGroup);
      expect(connectIntoNode).toThrow(
        'Memory layer input block is not a Group.'
      );
    });
  });
});
1268
+
1269
describe('Layer.batchNorm() and Layer.layerNorm()', () => {
  const size = 10;

  // Inputs 1..size: distinct values whose mean is far from 0.
  const makeInput = () => Array.from({ length: size }, (_, i) => i + 1);

  // Population statistics over `size` values.
  const meanOf = (values: number[]) =>
    values.reduce((sum, value) => sum + value, 0) / size;
  const varianceOf = (values: number[]) => {
    const mu = meanOf(values);
    return values.reduce((sum, value) => sum + (value - mu) ** 2, 0) / size;
  };

  describe('Scenario: batchNorm normalizes activations', () => {
    it('mean is ~0', () => {
      const activations = Layer.batchNorm(size).activate(makeInput());
      expect(meanOf(activations)).toBeCloseTo(0, 5);
    });
    it('variance is ~1', () => {
      const activations = Layer.batchNorm(size).activate(makeInput());
      expect(varianceOf(activations)).toBeCloseTo(1, 2);
    });
  });
  describe('Scenario: layerNorm normalizes activations', () => {
    it('mean is ~0', () => {
      const activations = Layer.layerNorm(size).activate(makeInput());
      expect(meanOf(activations)).toBeCloseTo(0, 5);
    });
    it('variance is ~1', () => {
      const activations = Layer.layerNorm(size).activate(makeInput());
      expect(varianceOf(activations)).toBeCloseTo(1, 2);
    });
  });
});
1323
+
1324
describe('Layer.conv1d() and Layer.attention()', () => {
  it('conv1d constructs a layer and slices input as stub', () => {
    const size = 4;
    const kernel = 3;
    const conv = Layer.conv1d(size, kernel);

    expect(conv.nodes).toHaveLength(size);
    expect(conv.output?.nodes).toHaveLength(size);
    // The factory records its hyper-parameters on the instance.
    expect((conv as any).conv1d).toEqual({
      kernelSize: kernel,
      stride: 1,
      padding: 0,
    });
    // Stub activation: only the first `size` inputs pass through.
    expect(conv.activate([1, 2, 3, 4, 5, 6])).toEqual([1, 2, 3, 4]);
  });
  it('attention constructs a layer and averages input as stub', () => {
    const size = 3;
    const heads = 2;
    const attn = Layer.attention(size, heads);

    expect(attn.nodes).toHaveLength(size);
    expect(attn.output?.nodes).toHaveLength(size);
    expect((attn as any).attention).toEqual({ heads });
    // Stub activation: every output is the mean of the inputs (12 / 3 = 4).
    expect(attn.activate([2, 4, 6])).toEqual([4, 4, 4]);
  });
});
1353
+ });
1354
+
1355
describe('isGroup (private helper)', () => {
  let layer: Layer;

  beforeEach(() => {
    layer = new Layer();
  });

  it('should return true for a Group instance', () => {
    const group = new Group(1);
    expect((layer as any).isGroup(group)).toBe(true);
  });

  it('should return false for a Node instance', () => {
    const node = new Node();
    expect((layer as any).isGroup(node)).toBe(false);
  });

  // isGroup is a duck-type check: a plain object with both `nodes` and
  // `set` is accepted, so the true and false cases get separate titles
  // instead of a single "return false" test that asserts true.
  it('should return true for a plain object exposing both nodes and set', () => {
    const obj = { nodes: [], set: () => {} };
    expect((layer as any).isGroup(obj)).toBe(true);
  });

  it('should return false for a plain object missing nodes or set', () => {
    const objMissingNodes = { set: () => {} };
    expect((layer as any).isGroup(objMissingNodes)).toBe(false);

    const objMissingSet = { nodes: [] };
    expect((layer as any).isGroup(objMissingSet)).toBe(false);
  });

  it('should return false for null', () => {
    expect((layer as any).isGroup(null)).toBe(false);
  });

  it('should return false for undefined', () => {
    expect((layer as any).isGroup(undefined)).toBe(false);
  });

  it('should return false for primitive types', () => {
    expect((layer as any).isGroup(123)).toBe(false);
    expect((layer as any).isGroup('string')).toBe(false);
    expect((layer as any).isGroup(true)).toBe(false);
  });
});
1397
+
1398
+ describe('Layer-level Dropout', () => {
1399
describe('Scenario: all nodes in a layer are masked together during training', () => {
  it('all masks are the same (either all 0 or all 1)', () => {
    // Arrange: a bare layer holding `size` hidden nodes with 70% dropout.
    const size = 8;
    const layer = new (require('../../src/architecture/layer').default)();
    for (let i = 0; i < size; i++) {
      layer.nodes.push(
        new (require('../../src/architecture/node').default)('hidden')
      );
    }
    layer.dropout = 0.7;
    // Act: run a training-mode pass so dropout actually assigns masks.
    // Without it every node still carries its initial mask and the
    // assertion below would pass no matter how dropout is implemented.
    layer.activate(undefined, true);
    const masks = layer.nodes.map((n: any) => n.mask);
    // Assert: a single shared mask value across the whole layer.
    expect(new Set(masks).size).toBe(1);
  });
});
1416
+
1417
describe('Scenario: masks are reset to 1 after inference', () => {
  it('all masks are 1 after inference', () => {
    // Arrange: a hand-built layer of hidden nodes with layer-level dropout,
    // using the same Layer/Node classes as the rest of this file instead of
    // re-requiring the modules inline.
    const size = 8;
    const layer = new Layer();
    for (let i = 0; i < size; i++) {
      layer.nodes.push(new Node('hidden'));
    }
    layer.dropout = 0.7;
    // Simulate training (may zero the shared mask)
    layer.activate(undefined, true);
    // Act: inference pass
    layer.activate(undefined, false);
    const masks = layer.nodes.map((n: any) => n.mask);
    // Assert
    masks.forEach((m: number) => expect(m).toBe(1));
  });
});
1437
+
1438
describe('Scenario: inference is unaffected by previous dropout', () => {
  it('all activations are valid numbers after inference', () => {
    // Arrange
    const size = 8;
    const layer = new Layer();
    for (let i = 0; i < size; i++) {
      layer.nodes.push(new Node('hidden'));
    }
    layer.dropout = 0.7;
    // Simulate training
    layer.activate(undefined, true);
    // Act
    layer.activate(undefined, false);
    const activations = layer.nodes.map((n: any) => n.activation);
    // Assert: `typeof x === 'number'` also accepts NaN and ±Infinity, which
    // are not valid activations; Number.isFinite rejects them (and non-numbers).
    activations.forEach((a: number) => expect(Number.isFinite(a)).toBe(true));
  });
});
1458
+
1459
describe('Scenario: masks all nodes together during training', () => {
  for (const dropout of [0.8, 0, 1]) {
    describe(`when dropout = ${dropout}`, () => {
      it('all masks are the same in this activation', () => {
        // Arrange
        const layer = Layer.dense(5);
        layer.dropout = dropout;
        // Act
        layer.activate(undefined, true); // training=true
        const mask = layer.nodes[0].mask;
        // Assert
        layer.nodes.forEach((node) => {
          expect(node.mask).toBe(mask);
        });
      });
    });
  }
  describe('when dropout is strictly between 0 and 1', () => {
    // Number of training activations sampled before giving up. Assuming a
    // mask of 0 is drawn with probability `dropout` (0.8 here), a 1-mask has
    // probability 0.2 per draw: 10 draws miss it ~10.7% of the time (0.8^10),
    // which made the previous version flaky; 100 draws miss it with
    // probability ~2e-10. Loops exit early once the outcome is seen.
    const trials = 100;

    it('0 mask occurred over multiple activations', () => {
      // Arrange
      const layer = Layer.dense(5);
      layer.dropout = 0.8;
      // Act
      let zeroOccurred = false;
      for (let i = 0; i < trials && !zeroOccurred; i++) {
        layer.activate(undefined, true);
        if (layer.nodes[0].mask === 0) zeroOccurred = true;
      }
      // Assert
      expect(zeroOccurred).toBe(true);
    });
    it('1 mask occurred over multiple activations', () => {
      // Arrange
      const layer = Layer.dense(5);
      layer.dropout = 0.8;
      // Act
      let oneOccurred = false;
      for (let i = 0; i < trials && !oneOccurred; i++) {
        layer.activate(undefined, true);
        if (layer.nodes[0].mask === 1) oneOccurred = true;
      }
      // Assert
      expect(oneOccurred).toBe(true);
    });
  });
});
1505
+
1506
describe('Scenario: resets all masks to 1 after training (inference)', () => {
  it('all masks are 1 after inference', () => {
    // Arrange: a dense layer with a high dropout rate, run once in training mode.
    const training = true;
    const layer = Layer.dense(4);
    layer.dropout = 0.9;
    layer.activate(undefined, training);
    // Act: the following inference pass should restore every mask.
    layer.activate(undefined, !training);
    // Assert
    for (const node of layer.nodes) {
      expect(node.mask).toBe(1);
    }
  });
});
1520
+
1521
describe('Scenario: node-level dropout is not applied if layer-level dropout is set', () => {
  // NOTE(review): this only checks that layer.dropout = 1 zeroes every mask;
  // it never sets a node-level dropout, so the "not applied" half of the
  // title is not exercised here — consider extending.
  it('all node masks are 0 if layer.dropout = 1', () => {
    // Arrange: dropout of 1 means the layer is always masked.
    const layer = Layer.dense(6);
    layer.dropout = 1;
    // Act: a single training-mode activation.
    layer.activate(undefined, true);
    // Assert: every node carries the zero mask.
    for (const { mask } of layer.nodes) {
      expect(mask).toBe(0);
    }
  });
});
1534
+ });
1535
+ });