scalar-autograd 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. package/dist/Losses.d.ts +51 -0
  2. package/dist/Losses.js +145 -0
  3. package/dist/Losses.spec.d.ts +1 -0
  4. package/dist/Losses.spec.js +54 -0
  5. package/dist/Optimizers.d.ts +114 -0
  6. package/dist/Optimizers.edge-cases.spec.d.ts +1 -0
  7. package/dist/Optimizers.edge-cases.spec.js +29 -0
  8. package/dist/Optimizers.js +177 -0
  9. package/dist/Optimizers.spec.d.ts +1 -0
  10. package/dist/Optimizers.spec.js +56 -0
  11. package/dist/V.d.ts +0 -0
  12. package/dist/V.js +0 -0
  13. package/dist/Value.d.ts +260 -0
  14. package/dist/Value.edge-cases.spec.d.ts +1 -0
  15. package/dist/Value.edge-cases.spec.js +54 -0
  16. package/dist/Value.grad-flow.spec.d.ts +1 -0
  17. package/dist/Value.grad-flow.spec.js +24 -0
  18. package/dist/Value.js +424 -0
  19. package/dist/Value.losses-edge-cases.spec.d.ts +1 -0
  20. package/dist/Value.losses-edge-cases.spec.js +30 -0
  21. package/dist/Value.memory.spec.d.ts +1 -0
  22. package/dist/Value.memory.spec.js +23 -0
  23. package/dist/Value.nn.spec.d.ts +1 -0
  24. package/dist/Value.nn.spec.js +111 -0
  25. package/dist/Value.spec.d.ts +1 -0
  26. package/dist/Value.spec.js +245 -0
  27. package/dist/ValueActivation.d.ts +7 -0
  28. package/dist/ValueActivation.js +34 -0
  29. package/dist/ValueArithmetic.d.ts +26 -0
  30. package/dist/ValueArithmetic.js +180 -0
  31. package/dist/ValueComparison.d.ts +10 -0
  32. package/dist/ValueComparison.js +47 -0
  33. package/dist/ValueTrig.d.ts +9 -0
  34. package/dist/ValueTrig.js +49 -0
  35. package/package.json +4 -12
  36. package/Losses.ts +0 -145
  37. package/Optimizers.ts +0 -222
  38. package/V.ts +0 -0
  39. package/Value.edge-cases.spec.ts +0 -60
  40. package/Value.grad-flow.spec.ts +0 -24
  41. package/Value.losses-edge-cases.spec.ts +0 -32
  42. package/Value.memory.spec.ts +0 -25
  43. package/Value.nn.spec.ts +0 -109
  44. package/Value.spec.ts +0 -268
  45. package/Value.ts +0 -461
  46. package/ValueActivation.ts +0 -51
  47. package/ValueArithmetic.ts +0 -272
  48. package/ValueComparison.ts +0 -85
  49. package/ValueTrig.ts +0 -70
package/Losses.ts DELETED
@@ -1,145 +0,0 @@
1
- import { Value } from "./Value";
2
- import { V } from "./V";
3
-
4
/**
 * Guards the loss functions against prediction/target arrays of different sizes.
 * @param outputs Predicted Values.
 * @param targets Target Values; must hold exactly as many entries as `outputs`.
 * @throws Error when the two arrays differ in length.
 */
function checkLengthMatch(outputs: Value[], targets: Value[]): void {
  const sameLength = outputs.length === targets.length;
  if (sameLength) return;
  throw new Error('Outputs and targets must have the same length');
}
14
-
15
/**
 * Static loss functions assembled from `Value` operations, so every returned loss
 * is part of the autograd graph and can be back-propagated with `backward()`.
 */
export class Losses {
  /**
   * Small constant that keeps probabilities away from 0 (and 1) before `log()`.
   * Public and mutable: changing it affects every loss below that reads it.
   */
  static EPS = 1e-12;

  /**
   * Shared argument validation for the element-wise losses.
   * The array check runs BEFORE the length check so that non-array input
   * (e.g. `undefined`) gets the descriptive TypeError instead of a generic
   * failure when `.length` is read.
   * @param name Loss name used in the TypeError message.
   * @throws TypeError if either argument is not an array.
   * @throws Error if the arrays differ in length.
   */
  private static validatePair(name: string, outputs: Value[], targets: Value[]): void {
    if (!Array.isArray(outputs) || !Array.isArray(targets)) {
      throw new TypeError(`${name} expects Value[] for both arguments.`);
    }
    checkLengthMatch(outputs, targets);
  }

  /**
   * Mean squared error: mean((output - target)^2).
   * @param outputs Array of Value predictions.
   * @param targets Array of Value targets (same length as `outputs`).
   * @returns MSE as a Value; `Value(0)` for empty input.
   * @throws TypeError for non-array input, Error for mismatched lengths.
   */
  public static mse(outputs: Value[], targets: Value[]): Value {
    Losses.validatePair('mse', outputs, targets);
    if (!outputs.length) return new Value(0);
    const diffs = outputs.map((out, i) => out.sub(targets[i]).square());
    return Value.mean(diffs);
  }

  /**
   * Mean absolute error: mean(|output - target|).
   * @param outputs Array of Value predictions.
   * @param targets Array of Value targets (same length as `outputs`).
   * @returns MAE as a Value; `Value(0)` for empty input.
   * @throws TypeError for non-array input, Error for mismatched lengths.
   */
  public static mae(outputs: Value[], targets: Value[]): Value {
    Losses.validatePair('mae', outputs, targets);
    if (!outputs.length) return new Value(0);
    const diffs = outputs.map((out, i) => out.sub(targets[i]).abs());
    return Value.mean(diffs);
  }

  /**
   * Binary cross-entropy on probabilities (e.g. sigmoid outputs):
   * -mean(t * log(p) + (1 - t) * log(1 - p)).
   * @param outputs Array of Value predictions, expected in (0, 1); clamped to [EPS, 1 - EPS].
   * @param targets Array of Value targets, typically 0 or 1.
   * @returns BCE as a Value; `Value(0)` for empty input.
   * @throws TypeError for non-array input, Error for mismatched lengths.
   */
  public static binaryCrossEntropy(outputs: Value[], targets: Value[]): Value {
    Losses.validatePair('binaryCrossEntropy', outputs, targets);
    if (!outputs.length) return new Value(0);
    const eps = Losses.EPS;
    const one = new Value(1);
    const losses = outputs.map((out, i) => {
      const t = targets[i];
      // Clamp so neither log(p) nor log(1 - p) is evaluated at exactly 0.
      const p = out.clamp(eps, 1 - eps);
      return t.mul(p.log()).add(one.sub(t).mul(one.sub(p).log()));
    });
    return Value.mean(losses).mul(-1);
  }

  /**
   * Categorical cross-entropy from raw logits and integer class indices.
   *
   * `outputs` is treated as ONE logit vector. Softmax is taken over the whole array,
   * each entry of `targets` selects one class probability from that softmax, and the
   * result is the mean of -log(p[target] + EPS) over `targets`.
   * NOTE(review): earlier docs described targets as "one per sample", but every target
   * indexes the same softmax. Confirm this is the intended contract.
   *
   * @param outputs Array of Value logits, one per class.
   * @param targets Array of integer class indices in [0, outputs.length).
   * @returns Cross-entropy as a Value; `Value(0)` if either array is empty.
   * @throws TypeError for non-array input; Error for an out-of-range or non-integer index.
   */
  public static categoricalCrossEntropy(outputs: Value[], targets: number[]): Value {
    if (!Array.isArray(outputs) || !Array.isArray(targets)) throw new TypeError('categoricalCrossEntropy expects Value[] and number[].');
    if (!outputs.length || !targets.length) return new Value(0);
    // Number.isInteger rejects NaN, ±Infinity and fractional values in one check.
    if (targets.some(t => typeof t !== 'number' || !Number.isInteger(t) || t < 0 || t >= outputs.length)) {
      throw new Error('Target indices must be valid integers in [0, outputs.length)');
    }
    const eps = Losses.EPS;
    // Subtract the largest logit before exp() for numerical stability; softmax is shift-invariant.
    const maxLogit = outputs.reduce((a, b) => (a.data > b.data ? a : b));
    const exps = outputs.map(out => out.sub(maxLogit).exp());
    const sumExp = Value.sum(exps).add(eps);
    const softmax = exps.map(e => e.div(sumExp));
    const picked = targets.map(t => softmax[t]);
    return Value.mean(picked.map(p => p.add(eps).log().mul(-1)));
  }

  /**
   * Huber loss: 0.5 * r^2 when |r| < delta, otherwise delta * (|r| - 0.5 * delta),
   * where r = output - target. The result is averaged over all pairs.
   * @param outputs Array of Value predictions.
   * @param targets Array of Value targets (same length as `outputs`).
   * @param delta Threshold at which the loss switches from quadratic to linear (default 1.0).
   * @returns Huber loss as a Value; `Value(0)` for empty input.
   * @throws TypeError for non-array input, Error for mismatched lengths.
   */
  public static huber(outputs: Value[], targets: Value[], delta = 1.0): Value {
    Losses.validatePair('huber', outputs, targets);
    if (!outputs.length) return new Value(0);

    const deltaValue = new Value(delta);
    const half = new Value(0.5);

    const losses = outputs.map((out, i) => {
      const residual = V.abs(V.sub(out, targets[i]));
      const condition = V.lt(residual, deltaValue);
      const quadraticLoss = V.mul(half, V.square(residual));
      const linearLoss = V.mul(deltaValue, V.sub(residual, V.mul(half, deltaValue)));
      return V.ifThenElse(condition, quadraticLoss, linearLoss);
    });

    return V.mean(losses);
  }

  /**
   * Tukey biweight loss. For |r| <= c it is (c^2 / 6) * (1 - (1 - (r / c)^2)^3);
   * beyond c it saturates at the constant c^2 / 6, so outliers stop contributing gradient.
   * @param outputs Array of Value predictions.
   * @param targets Array of Value targets (same length as `outputs`).
   * @param c Threshold constant (default 4.685).
   * @returns Tukey loss as a Value; `Value(0)` for empty input, matching the other losses.
   * @throws TypeError for non-array input, Error for mismatched lengths.
   */
  public static tukey(outputs: Value[], targets: Value[], c: number = 4.685): Value {
    Losses.validatePair('tukey', outputs, targets);
    if (!outputs.length) return new Value(0);
    const cValue = V.C(c);
    // Plateau value reached for residuals beyond c.
    const saturated = V.C((c * c) / 6);

    const losses = outputs.map((out, i) => {
      const residual = V.abs(V.sub(out, targets[i]));
      const inlier = V.lte(residual, cValue);
      const scaledSquared = V.square(V.div(residual, cValue));
      const inlierLoss = V.mul(saturated, V.sub(1, V.pow(V.sub(1, scaledSquared), 3)));
      return V.ifThenElse(inlier, inlierLoss, saturated);
    });
    return V.mean(losses);
  }
}
package/Optimizers.ts DELETED
@@ -1,222 +0,0 @@
1
- // Optimizers.ts
2
-
3
- import { Value } from "./Value";
4
-
5
/**
 * Base class shared by every optimizer in this module.
 * Parameters whose `requiresGrad` flag is falsy are dropped at construction,
 * so `zeroGrad`, `clipGradients` and `step` only ever touch trainable Values.
 */
export abstract class Optimizer {
  protected trainables: Value[];
  public learningRate: number;

  /**
   * @param trainables Candidate parameters; only those with `requiresGrad` are kept.
   * @param learningRate Step size consumed by `step()`.
   */
  constructor(trainables: Value[], learningRate: number) {
    this.trainables = trainables.filter(({ requiresGrad }) => requiresGrad);
    this.learningRate = learningRate;
  }

  /**
   * Applies one parameter update using the currently accumulated gradients.
   */
  abstract step(): void;

  /**
   * Resets the gradient of every kept parameter to 0.
   */
  zeroGrad(): void {
    this.trainables.forEach(param => {
      param.grad = 0;
    });
  }

  /**
   * Rescales all gradients in place so that their global L2 norm does not exceed `maxNorm`.
   * @param maxNorm Upper bound for the combined gradient norm; no-op when already within it.
   */
  clipGradients(maxNorm: number): void {
    let squaredSum = 0;
    for (const param of this.trainables) squaredSum += param.grad * param.grad;
    const norm = Math.sqrt(squaredSum);
    // Written as !(>) so NaN norms/limits leave gradients untouched.
    if (!(norm > maxNorm)) return;
    // The 1e-6 guard keeps the divisor strictly positive.
    const factor = maxNorm / (norm + 1e-6);
    this.trainables.forEach(param => {
      param.grad *= factor;
    });
  }
}
49
-
50
/**
 * Optional settings shared by the optimizers in this module.
 * Every field may be omitted; each optimizer supplies its own defaults.
 */
export interface OptimizerOptions {
  /** Step size for parameter updates. Default depends on the optimizer (SGD 1e-2, Adam/AdamW 1e-3). */
  learningRate?: number;
  /**
   * Weight-decay coefficient. Default 0 for Adam (coupled, added to the gradient) and
   * 0.01 for AdamW (decoupled, applied to the parameter). Accepted but ignored by SGD.
   */
  weightDecay?: number;
  /**
   * When > 0, Adam/AdamW clamp each per-parameter update step (after moment
   * normalisation) to [-gradientClip, gradientClip]. The raw gradient itself is not
   * clipped. Default 0 (no clipping). Accepted but ignored by SGD.
   */
  gradientClip?: number;
}
61
-
62
/**
 * Plain stochastic gradient descent: `param.data -= learningRate * param.grad`.
 * `weightDecay` and `gradientClip` are accepted and stored only so SGD shares the
 * options shape of the other optimizers; `step()` never reads them.
 */
export class SGD extends Optimizer {
  private weightDecay: number;
  private gradientClip: number;

  /**
   * @param trainables Parameters to optimize (only `requiresGrad` ones are kept).
   * @param opts learningRate (default 1e-2); weightDecay and gradientClip are stored but unused.
   */
  constructor(trainables: Value[], opts: OptimizerOptions = {}) {
    super(trainables, opts.learningRate ?? 1e-2);
    const { weightDecay, gradientClip } = opts;
    this.weightDecay = weightDecay ?? 0;
    this.gradientClip = gradientClip ?? 0;
  }

  /**
   * Moves every parameter one step against its gradient.
   */
  step(): void {
    // weightDecay/gradientClip deliberately play no part here.
    const lr = this.learningRate;
    this.trainables.forEach(param => {
      param.data -= lr * param.grad;
    });
  }
}
91
-
92
/**
 * Settings understood by Adam and AdamW, on top of the shared OptimizerOptions.
 */
export interface AdamOptions extends OptimizerOptions {
  /** Decay rate of the running first-moment (mean) estimate; default 0.9. */
  beta1?: number;
  /** Decay rate of the running second-moment estimate; default 0.999. */
  beta2?: number;
  /** Added to the square-rooted second moment to avoid division by zero; default 1e-8. */
  epsilon?: number;
}
104
-
105
/**
 * Adam optimizer with optional L2 weight decay and per-parameter update clipping.
 *
 * Weight decay here is *coupled* (classic L2 regularisation): `weightDecay * param.data`
 * is added to the gradient before the moment estimates are updated, so the adaptive
 * denominator rescales it. For decoupled weight decay use AdamW.
 */
export class Adam extends Optimizer {
  private beta1: number;
  private beta2: number;
  private epsilon: number;
  private weightDecay: number;
  private gradientClip: number;
  /** Running first-moment (mean) estimate per parameter. */
  private m: Map<Value, number> = new Map();
  /** Running second-moment (uncentred variance) estimate per parameter. */
  private v: Map<Value, number> = new Map();
  /** Number of completed `step()` calls; drives bias correction. */
  private stepCount: number = 0;

  /**
   * @param trainables Parameters to optimize (only those with `requiresGrad` are kept).
   * @param opts learningRate (default 0.001), beta1 (0.9), beta2 (0.999), epsilon (1e-8),
   *   weightDecay (0 = off), gradientClip (0 = off).
   */
  constructor(trainables: Value[], opts: AdamOptions = {}) {
    super(trainables, opts.learningRate ?? 0.001);
    this.beta1 = opts.beta1 ?? 0.9;
    this.beta2 = opts.beta2 ?? 0.999;
    this.epsilon = opts.epsilon ?? 1e-8;
    this.weightDecay = opts.weightDecay ?? 0;
    this.gradientClip = opts.gradientClip ?? 0;
    for (const param of this.trainables) {
      this.m.set(param, 0);
      this.v.set(param, 0);
    }
  }

  /**
   * Applies one Adam update to every trainable parameter using its current `grad`.
   * When `gradientClip > 0`, the normalised update (not the raw gradient) is clamped
   * to [-gradientClip, gradientClip] before it is scaled by the learning rate.
   */
  step(): void {
    this.stepCount++;
    // Bias corrections depend only on the step count, so compute them once per step.
    const bias1 = 1 - Math.pow(this.beta1, this.stepCount);
    const bias2 = 1 - Math.pow(this.beta2, this.stepCount);
    for (const param of this.trainables) {
      let grad = param.grad;
      if (this.weightDecay > 0) grad += this.weightDecay * param.data;

      // `?? 0` treats a parameter missing from the moment maps (e.g. one a subclass
      // pushed into `trainables` after construction) as fresh, instead of producing NaN.
      const m = this.beta1 * (this.m.get(param) ?? 0) + (1 - this.beta1) * grad;
      const vVal = this.beta2 * (this.v.get(param) ?? 0) + (1 - this.beta2) * grad * grad;

      const mHat = m / bias1;
      const vHat = vVal / bias2;
      let update = mHat / (Math.sqrt(vHat) + this.epsilon);
      if (this.gradientClip > 0) {
        update = Math.max(-this.gradientClip, Math.min(update, this.gradientClip));
      }
      param.data -= this.learningRate * update;

      this.m.set(param, m);
      this.v.set(param, vVal);
    }
  }
}
165
-
166
/**
 * AdamW optimizer: Adam moment estimates plus weight decay applied directly to the
 * parameter (decoupled) instead of being folded into the gradient.
 * Takes the same options as Adam, except that `weightDecay` defaults to 0.01 here.
 */
export class AdamW extends Optimizer {
  private beta1: number;
  private beta2: number;
  private epsilon: number;
  private weightDecay: number;
  private gradientClip: number;
  private m: Map<Value, number> = new Map();
  private v: Map<Value, number> = new Map();
  private stepCount: number = 0;

  /**
   * @param trainables Parameters to optimize (only `requiresGrad` ones are kept).
   * @param opts learningRate (0.001), beta1 (0.9), beta2 (0.999), epsilon (1e-8),
   *   weightDecay (0.01), gradientClip (0 = off).
   */
  constructor(trainables: Value[], opts: AdamOptions = {}) {
    super(trainables, opts.learningRate ?? 0.001);
    this.beta1 = opts.beta1 ?? 0.9;
    this.beta2 = opts.beta2 ?? 0.999;
    this.epsilon = opts.epsilon ?? 1e-8;
    this.weightDecay = opts.weightDecay ?? 0.01;
    this.gradientClip = opts.gradientClip ?? 0;
    this.trainables.forEach(param => {
      this.m.set(param, 0);
      this.v.set(param, 0);
    });
  }

  /**
   * Runs one AdamW update over all trainable parameters.
   */
  step(): void {
    this.stepCount += 1;
    const t = this.stepCount;
    const firstCorrection = 1 - Math.pow(this.beta1, t);
    const secondCorrection = 1 - Math.pow(this.beta2, t);
    const lr = this.learningRate;
    const clip = this.gradientClip;

    for (const param of this.trainables) {
      const g = param.grad;
      const mNext = this.beta1 * this.m.get(param)! + (1 - this.beta1) * g;
      const vNext = this.beta2 * this.v.get(param)! + (1 - this.beta2) * g * g;

      const stepSize = (mNext / firstCorrection) / (Math.sqrt(vNext / secondCorrection) + this.epsilon);
      const bounded = clip > 0 ? Math.max(-clip, Math.min(stepSize, clip)) : stepSize;

      // Decoupled decay: shrink the parameter by lr * weightDecay * data, independent of the moments.
      param.data -= lr * bounded + lr * this.weightDecay * param.data;
      this.m.set(param, mNext);
      this.v.set(param, vNext);
    }
  }
}
package/V.ts DELETED
Binary file
package/Value.edge-cases.spec.ts DELETED
@@ -1,60 +0,0 @@
1
- import { Value } from "./Value";
2
-
3
// Edge cases and error handling
describe('Value edge cases and error handling', () => {
  it('throws on invalid numeric inputs', () => {
    for (const bad of [NaN, Infinity, -Infinity]) {
      expect(() => new Value(bad)).toThrow();
    }
  });

  it('handles gradient accumulation correctly', () => {
    const x = new Value(2, 'x', true);
    const timesThree = x.mul(3);
    const timesFour = x.mul(4);
    const total = timesThree.add(timesFour);
    total.backward();
    // x feeds two branches, so its gradient is the sum of both coefficients: 3 + 4.
    expect(x.grad).toBe(7);
  });

  it('handles repeated use of same value in expression', () => {
    const x = new Value(3, 'x', true);
    const cube = x.mul(x).mul(x);
    cube.backward();
    // d(x^3)/dx = 3x^2 = 27 at x = 3
    expect(x.grad).toBeCloseTo(27);
  });

  it('throws on division by zero', () => {
    const numerator = new Value(1);
    const zero = new Value(0);
    expect(() => numerator.div(zero)).toThrow();
  });

  it('throws on log of negative number', () => {
    const negative = new Value(-1);
    expect(() => negative.log()).toThrow();
  });

  it('throws on negative base with fractional exponent', () => {
    const negativeBase = new Value(-2);
    expect(() => negativeBase.pow(0.5)).toThrow();
  });
});
43
-
44
// Complex expressions
describe('Complex mathematical expressions', () => {
  it('computes gradient of complex expression', () => {
    const input = new Value(0.5, 'x', true);
    // f(x) = sin(x)cos(x) + e^x  =>  f'(x) = cos^2(x) - sin^2(x) + e^x
    const f = input.sin().mul(input.cos()).add(input.exp());
    f.backward();
    const analytic = Math.cos(0.5) ** 2 - Math.sin(0.5) ** 2 + Math.exp(0.5);
    expect(input.grad).toBeCloseTo(analytic, 4);
  });

  it('handles nested activation functions', () => {
    const input = new Value(0.5, 'x', true);
    const activated = input.tanh().sigmoid().relu();
    activated.backward();
    // Every stage is increasing at this point, so the chained gradient must be positive.
    expect(input.grad).toBeGreaterThan(0);
  });
});
package/Value.grad-flow.spec.ts DELETED
@@ -1,24 +0,0 @@
1
- import { Value } from "./Value";
2
-
3
describe('Gradient flow control', () => {
  it('stops gradient at non-requiresGrad nodes', () => {
    const x = new Value(2, 'x', true);
    const frozen = new Value(3, 'y', false);
    const z = new Value(4, 'z', true);
    const result = x.mul(frozen).add(z);
    result.backward();
    // d(result)/dx = y = 3 and d(result)/dz = 1; y does not require grad, so it stays 0.
    expect(x.grad).toBe(3);
    expect(frozen.grad).toBe(0);
    expect(z.grad).toBe(1);
  });

  it('handles detached computation graphs', () => {
    const source = new Value(2, 'x', true);
    const upstream = source.mul(3);
    // Copying only `.data` into a fresh Value cuts the graph between `source` and `detached`.
    const detached = new Value(upstream.data, 'z', true);
    detached.mul(4).backward();
    expect(detached.grad).toBe(4);
    expect(source.grad).toBe(0);
  });
});
package/Value.losses-edge-cases.spec.ts DELETED
@@ -1,32 +0,0 @@
1
- import { Value } from "./Value";
2
- import { Losses } from "./Losses";
3
-
4
describe('Loss function edge cases', () => {
  it('handles empty arrays', () => {
    // Each loss is evaluated lazily so the call/assert order matches one-by-one checks.
    const emptyCases: Array<() => Value> = [
      () => Losses.mse([], []),
      () => Losses.mae([], []),
      () => Losses.binaryCrossEntropy([], []),
      () => Losses.categoricalCrossEntropy([], []),
    ];
    for (const compute of emptyCases) {
      expect(compute().data).toBe(0);
    }
  });

  it('throws on mismatched lengths', () => {
    const short = [new Value(1)];
    const long = [new Value(1), new Value(2)];
    expect(() => Losses.mse(short, long)).toThrow();
  });

  it('handles extreme values in binary cross entropy', () => {
    const prediction = new Value(0.999999, 'out', true);
    const label = new Value(1);
    const { data } = Losses.binaryCrossEntropy([prediction], [label]);
    expect(data).toBeGreaterThan(0);
    expect(data).toBeLessThan(0.1);
  });

  it('throws on invalid class indices in categorical cross entropy', () => {
    const logits = [new Value(1), new Value(2)];
    // Out of range, negative, and non-integer indices must all be rejected.
    for (const badIndex of [2, -1, 1.5]) {
      expect(() => Losses.categoricalCrossEntropy(logits, [badIndex])).toThrow();
    }
  });
});
package/Value.memory.spec.ts DELETED
@@ -1,25 +0,0 @@
1
- import { Value } from "./Value";
2
-
3
describe('Memory management', () => {
  it('handles large computation graphs', () => {
    // Build a 100-step chain (add 1, then scale by 1.01) and make sure backward() can walk it.
    let node = new Value(1, 'x', true);
    let remaining = 100;
    while (remaining-- > 0) {
      node = node.add(1).mul(1.01);
    }
    expect(() => node.backward()).not.toThrow();
  });

  it('zeroGradAll handles multiple disconnected graphs', () => {
    const left = new Value(1, 'x1', true);
    const leftOut = left.mul(2);
    const right = new Value(2, 'x2', true);
    const rightOut = right.mul(3);

    [leftOut, rightOut].forEach(out => out.backward());

    Value.zeroGradAll([leftOut, rightOut]);
    expect(left.grad).toBe(0);
    expect(right.grad).toBe(0);
  });
});
package/Value.nn.spec.ts DELETED
@@ -1,109 +0,0 @@
1
- import { Value } from "./Value";
2
- import { SGD, Adam } from "./Optimizers";
3
- import { Losses } from "./Losses";
4
-
5
describe("can train scalar neural networks on minimal problems", () => {

  // Parameters start from fixed values instead of Math.random(). With random starts,
  // the slowest-converging direction of the quadratic fit (test 2) has only decayed to
  // roughly a tenth after 400 SGD epochs at lr = 0.01. Unlucky initial values can
  // therefore leave `a` outside the toBeCloseTo(1, 1) tolerance, making the suite flaky.
  it("1. learns linear regression (y = 2x + 3) with SGD", () => {
    const w = new Value(0, "w", true);
    const b = new Value(0, "b", true);
    const examples = [
      { x: 1, y: 5 },
      { x: 2, y: 7 },
      { x: 3, y: 9 },
    ];
    const opt = new SGD([w, b], { learningRate: 0.1 });
    for (let epoch = 0; epoch < 300; ++epoch) {
      const preds: Value[] = [];
      const targets: Value[] = [];
      for (const ex of examples) {
        const x = new Value(ex.x, "x");
        preds.push(w.mul(x).add(b));
        targets.push(new Value(ex.y));
      }
      const loss = Losses.mse(preds, targets);
      if (loss.data < 1e-4) break;
      // Clears w.grad and b.grad (both require grad, so both are optimizer trainables).
      opt.zeroGrad();
      loss.backward();
      opt.step();
    }
    expect(w.data).toBeCloseTo(2, 1);
    expect(b.data).toBeCloseTo(3, 1);
  });

  it("2. learns quadratic fit (y = x^2) with SGD", () => {
    const a = new Value(0, "a", true);
    const b = new Value(0, "b", true);
    const c = new Value(0, "c", true);
    const examples = [
      { x: -1, y: 1 },
      { x: 0, y: 0 },
      { x: 2, y: 4 },
      { x: 3, y: 9 },
    ];
    const opt = new SGD([a, b, c], { learningRate: 0.01 });

    for (let epoch = 0; epoch < 400; ++epoch) {
      const preds: Value[] = [];
      const targets: Value[] = [];
      for (const ex of examples) {
        const x = new Value(ex.x);
        preds.push(a.mul(x.square()).add(b.mul(x)).add(c));
        targets.push(new Value(ex.y));
      }
      const loss = Losses.mse(preds, targets);
      if (loss.data < 1e-4) break;
      opt.zeroGrad();
      loss.backward();
      opt.step();
    }
    expect(a.data).toBeCloseTo(1, 1);
    expect(Math.abs(b.data)).toBeLessThan(0.5);
    expect(Math.abs(c.data)).toBeLessThan(0.5);
  });

  /*
  // This is hard to get to work reliably, I believe it's a difficult problem to solve!?
  it("3. learns XOR with tiny MLP (2-2-1) with SGD", () => {
    function mlp(x1: Value, x2: Value, params: Value[]): Value {
      const [w1, w2, w3, w4, b1, b2, v1, v2, c] = params;
      const h1 = w1.mul(x1).add(w2.mul(x2)).add(b1).tanh();
      const h2 = w3.mul(x1).add(w4.mul(x2)).add(b2).tanh();
      return v1.mul(h1).add(v2.mul(h2)).add(c).sigmoid();
    }
    const params = Array.from({ length: 9 }, (_, i) => new Value(Math.random() - 0.5, "p" + i, true));
    const data = [
      { x: [0, 0], y: 0 },
      { x: [0, 1], y: 1 },
      { x: [1, 0], y: 1 },
      { x: [1, 1], y: 0 },
    ];
    const opt = new SGD(params, { learningRate: 0.01 });
    for (let epoch = 0; epoch < 5000; ++epoch) {
      const preds: Value[] = [];
      const targets: Value[] = [];
      for (const ex of data) {
        const x1 = new Value(ex.x[0]);
        const x2 = new Value(ex.x[1]);
        preds.push(mlp(x1, x2, params));
        targets.push(new Value(ex.y));
      }
      // Was a bare `binaryCrossEntropy(...)`, which is not in scope in this file.
      const loss = Losses.binaryCrossEntropy(preds, targets);
      if (loss.data < 1e-3) break;
      opt.zeroGrad();
      loss.backward();
      opt.step();
    }
    const out00 = mlp(new Value(0), new Value(0), params).data;
    const out01 = mlp(new Value(0), new Value(1), params).data;
    const out10 = mlp(new Value(1), new Value(0), params).data;
    const out11 = mlp(new Value(1), new Value(1), params).data;
    expect((out00 < 0.4 || out00 > 0.6)).toBe(true);
    expect(out01).toBeGreaterThan(0.6);
    expect(out10).toBeGreaterThan(0.6);
    expect(out11).toBeLessThan(0.4);
  });*/
});