scalar-autograd 0.1.7 → 0.1.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +127 -2
- package/dist/CompiledFunctions.d.ts +111 -0
- package/dist/CompiledFunctions.js +268 -0
- package/dist/CompiledResiduals.d.ts +74 -0
- package/dist/CompiledResiduals.js +94 -0
- package/dist/EigenvalueHelpers.d.ts +14 -0
- package/dist/EigenvalueHelpers.js +93 -0
- package/dist/Geometry.d.ts +131 -0
- package/dist/Geometry.js +213 -0
- package/dist/GraphBuilder.d.ts +64 -0
- package/dist/GraphBuilder.js +237 -0
- package/dist/GraphCanonicalizerNoSort.d.ts +20 -0
- package/dist/GraphCanonicalizerNoSort.js +190 -0
- package/dist/GraphHashCanonicalizer.d.ts +46 -0
- package/dist/GraphHashCanonicalizer.js +220 -0
- package/dist/GraphSignature.d.ts +7 -0
- package/dist/GraphSignature.js +7 -0
- package/dist/KernelPool.d.ts +55 -0
- package/dist/KernelPool.js +124 -0
- package/dist/LBFGS.d.ts +84 -0
- package/dist/LBFGS.js +313 -0
- package/dist/LinearSolver.d.ts +69 -0
- package/dist/LinearSolver.js +213 -0
- package/dist/Losses.d.ts +9 -0
- package/dist/Losses.js +42 -37
- package/dist/Matrix3x3.d.ts +50 -0
- package/dist/Matrix3x3.js +146 -0
- package/dist/NonlinearLeastSquares.d.ts +33 -0
- package/dist/NonlinearLeastSquares.js +252 -0
- package/dist/Optimizers.d.ts +70 -14
- package/dist/Optimizers.js +42 -19
- package/dist/V.d.ts +0 -0
- package/dist/V.js +0 -0
- package/dist/Value.d.ts +84 -2
- package/dist/Value.js +296 -58
- package/dist/ValueActivation.js +10 -14
- package/dist/ValueArithmetic.d.ts +1 -0
- package/dist/ValueArithmetic.js +58 -50
- package/dist/ValueComparison.js +9 -13
- package/dist/ValueRegistry.d.ts +38 -0
- package/dist/ValueRegistry.js +88 -0
- package/dist/ValueTrig.js +14 -18
- package/dist/Vec2.d.ts +45 -0
- package/dist/Vec2.js +93 -0
- package/dist/Vec3.d.ts +78 -0
- package/dist/Vec3.js +169 -0
- package/dist/Vec4.d.ts +45 -0
- package/dist/Vec4.js +126 -0
- package/dist/__tests__/duplicate-inputs.test.js +33 -0
- package/dist/cli/gradient-gen.d.ts +19 -0
- package/dist/cli/gradient-gen.js +264 -0
- package/dist/compileIndirectKernel.d.ts +24 -0
- package/dist/compileIndirectKernel.js +148 -0
- package/dist/index.d.ts +20 -0
- package/dist/index.js +20 -0
- package/dist/scalar-autograd.d.ts +1157 -0
- package/dist/symbolic/AST.d.ts +113 -0
- package/dist/symbolic/AST.js +128 -0
- package/dist/symbolic/CodeGen.d.ts +35 -0
- package/dist/symbolic/CodeGen.js +280 -0
- package/dist/symbolic/Parser.d.ts +64 -0
- package/dist/symbolic/Parser.js +329 -0
- package/dist/symbolic/Simplify.d.ts +10 -0
- package/dist/symbolic/Simplify.js +244 -0
- package/dist/symbolic/SymbolicDiff.d.ts +35 -0
- package/dist/symbolic/SymbolicDiff.js +339 -0
- package/dist/tsdoc-metadata.json +11 -0
- package/package.json +29 -5
- package/dist/Losses.spec.js +0 -54
- package/dist/Optimizers.edge-cases.spec.d.ts +0 -1
- package/dist/Optimizers.edge-cases.spec.js +0 -29
- package/dist/Optimizers.spec.d.ts +0 -1
- package/dist/Optimizers.spec.js +0 -56
- package/dist/Value.edge-cases.spec.d.ts +0 -1
- package/dist/Value.edge-cases.spec.js +0 -54
- package/dist/Value.grad-flow.spec.d.ts +0 -1
- package/dist/Value.grad-flow.spec.js +0 -24
- package/dist/Value.losses-edge-cases.spec.d.ts +0 -1
- package/dist/Value.losses-edge-cases.spec.js +0 -30
- package/dist/Value.memory.spec.d.ts +0 -1
- package/dist/Value.memory.spec.js +0 -23
- package/dist/Value.nn.spec.d.ts +0 -1
- package/dist/Value.nn.spec.js +0 -111
- package/dist/Value.spec.d.ts +0 -1
- package/dist/Value.spec.js +0 -245
- /package/dist/{Losses.spec.d.ts → __tests__/duplicate-inputs.test.d.ts} +0 -0
package/dist/ValueArithmetic.js
CHANGED
|
@@ -1,66 +1,57 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
exports.ValueArithmetic = void 0;
|
|
4
|
-
const Value_1 = require("./Value");
|
|
5
|
-
class ValueArithmetic {
|
|
1
|
+
import { Value } from './Value';
|
|
2
|
+
export class ValueArithmetic {
|
|
6
3
|
static add(a, b) {
|
|
7
|
-
return
|
|
4
|
+
return Value.make(a.data + b.data, a, b, (out) => () => {
|
|
8
5
|
if (a.requiresGrad)
|
|
9
6
|
a.grad += 1 * out.grad;
|
|
10
7
|
if (b.requiresGrad)
|
|
11
8
|
b.grad += 1 * out.grad;
|
|
12
|
-
}, `(${a.label}+${b.label})
|
|
9
|
+
}, `(${a.label}+${b.label})`, '+');
|
|
13
10
|
}
|
|
14
11
|
static sqrt(a) {
|
|
15
12
|
if (a.data < 0) {
|
|
16
13
|
throw new Error(`Cannot take sqrt of negative number: ${a.data}`);
|
|
17
14
|
}
|
|
18
15
|
const root = Math.sqrt(a.data);
|
|
19
|
-
return
|
|
16
|
+
return Value.make(root, a, null, (out) => () => {
|
|
20
17
|
if (a.requiresGrad)
|
|
21
18
|
a.grad += 0.5 / root * out.grad;
|
|
22
|
-
}, `sqrt(${a.label})
|
|
19
|
+
}, `sqrt(${a.label})`, 'sqrt');
|
|
23
20
|
}
|
|
24
21
|
static mul(a, b) {
|
|
25
|
-
return
|
|
22
|
+
return Value.make(a.data * b.data, a, b, (out) => () => {
|
|
26
23
|
if (a.requiresGrad)
|
|
27
24
|
a.grad += b.data * out.grad;
|
|
28
25
|
if (b.requiresGrad)
|
|
29
26
|
b.grad += a.data * out.grad;
|
|
30
|
-
}, `(${a.label}*${b.label})
|
|
27
|
+
}, `(${a.label}*${b.label})`, '*');
|
|
31
28
|
}
|
|
32
29
|
static sub(a, b) {
|
|
33
|
-
return
|
|
30
|
+
return Value.make(a.data - b.data, a, b, (out) => () => {
|
|
34
31
|
if (a.requiresGrad)
|
|
35
32
|
a.grad += 1 * out.grad;
|
|
36
33
|
if (b.requiresGrad)
|
|
37
34
|
b.grad -= 1 * out.grad;
|
|
38
|
-
}, `(${a.label}-${b.label})
|
|
35
|
+
}, `(${a.label}-${b.label})`, '-');
|
|
39
36
|
}
|
|
40
37
|
static div(a, b, eps = 1e-12) {
|
|
41
38
|
if (Math.abs(b.data) < eps) {
|
|
42
39
|
throw new Error(`Division by zero or near-zero encountered in div: denominator=${b.data}`);
|
|
43
40
|
}
|
|
44
41
|
const safe = b.data;
|
|
45
|
-
return
|
|
42
|
+
return Value.make(a.data / safe, a, b, (out) => () => {
|
|
46
43
|
if (a.requiresGrad)
|
|
47
44
|
a.grad += (1 / safe) * out.grad;
|
|
48
45
|
if (b.requiresGrad)
|
|
49
46
|
b.grad -= (a.data / (safe ** 2)) * out.grad;
|
|
50
|
-
}, `(${a.label}/${b.label})
|
|
47
|
+
}, `(${a.label}/${b.label})`, '/');
|
|
51
48
|
}
|
|
52
49
|
static pow(a, exp) {
|
|
53
50
|
if (typeof exp !== "number" || Number.isNaN(exp) || !Number.isFinite(exp)) {
|
|
54
51
|
throw new Error(`Exponent must be a finite number, got ${exp}`);
|
|
55
52
|
}
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
}
|
|
59
|
-
const safeBase = a.data;
|
|
60
|
-
return Value_1.Value.make(Math.pow(safeBase, exp), a, null, (out) => () => {
|
|
61
|
-
if (a.requiresGrad)
|
|
62
|
-
a.grad += exp * Math.pow(safeBase, exp - 1) * out.grad;
|
|
63
|
-
}, `(${a.label}^${exp})`);
|
|
53
|
+
const expValue = new Value(exp, String(exp), false);
|
|
54
|
+
return ValueArithmetic.powValue(a, expValue);
|
|
64
55
|
}
|
|
65
56
|
static powValue(a, b, eps = 1e-12) {
|
|
66
57
|
if (a.data < 0 && Math.abs(b.data % 1) > eps) {
|
|
@@ -70,33 +61,33 @@ class ValueArithmetic {
|
|
|
70
61
|
throw new Error(`0 cannot be raised to zero or negative power: ${b.data}`);
|
|
71
62
|
}
|
|
72
63
|
const safeBase = a.data;
|
|
73
|
-
return
|
|
64
|
+
return Value.make(Math.pow(safeBase, b.data), a, b, (out) => () => {
|
|
74
65
|
a.grad += b.data * Math.pow(safeBase, b.data - 1) * out.grad;
|
|
75
66
|
b.grad += Math.log(Math.max(eps, safeBase)) * Math.pow(safeBase, b.data) * out.grad;
|
|
76
|
-
}, `(${a.label}^${b.label})
|
|
67
|
+
}, `(${a.label}^${b.label})`, 'powValue');
|
|
77
68
|
}
|
|
78
69
|
static mod(a, b) {
|
|
79
70
|
if (typeof b.data !== 'number' || b.data === 0) {
|
|
80
71
|
throw new Error(`Modulo by zero encountered`);
|
|
81
72
|
}
|
|
82
|
-
return
|
|
73
|
+
return Value.make(a.data % b.data, a, b, (out) => () => {
|
|
83
74
|
a.grad += 1 * out.grad;
|
|
84
75
|
// No grad to b (modulus not used in most diff cases)
|
|
85
|
-
}, `(${a.label}%${b.label})
|
|
76
|
+
}, `(${a.label}%${b.label})`, 'mod');
|
|
86
77
|
}
|
|
87
78
|
static abs(a) {
|
|
88
79
|
const d = Math.abs(a.data);
|
|
89
|
-
return
|
|
80
|
+
return Value.make(d, a, null, (out) => () => {
|
|
90
81
|
if (a.requiresGrad)
|
|
91
82
|
a.grad += (a.data >= 0 ? 1 : -1) * out.grad;
|
|
92
|
-
}, `abs(${a.label})
|
|
83
|
+
}, `abs(${a.label})`, 'abs');
|
|
93
84
|
}
|
|
94
85
|
static exp(a) {
|
|
95
86
|
const e = Math.exp(a.data);
|
|
96
|
-
return
|
|
87
|
+
return Value.make(e, a, null, (out) => () => {
|
|
97
88
|
if (a.requiresGrad)
|
|
98
89
|
a.grad += e * out.grad;
|
|
99
|
-
}, `exp(${a.label})
|
|
90
|
+
}, `exp(${a.label})`, 'exp');
|
|
100
91
|
}
|
|
101
92
|
static log(a, eps = 1e-12) {
|
|
102
93
|
if (a.data <= 0) {
|
|
@@ -104,40 +95,40 @@ class ValueArithmetic {
|
|
|
104
95
|
}
|
|
105
96
|
const safe = Math.max(a.data, eps);
|
|
106
97
|
const l = Math.log(safe);
|
|
107
|
-
return
|
|
98
|
+
return Value.make(l, a, null, (out) => () => {
|
|
108
99
|
if (a.requiresGrad)
|
|
109
100
|
a.grad += (1 / safe) * out.grad;
|
|
110
|
-
}, `log(${a.label})
|
|
101
|
+
}, `log(${a.label})`, 'log');
|
|
111
102
|
}
|
|
112
103
|
static min(a, b) {
|
|
113
104
|
const d = Math.min(a.data, b.data);
|
|
114
|
-
return
|
|
105
|
+
return Value.make(d, a, b, (out) => () => {
|
|
115
106
|
if (a.requiresGrad)
|
|
116
107
|
a.grad += (a.data < b.data ? 1 : 0) * out.grad;
|
|
117
108
|
if (b.requiresGrad)
|
|
118
109
|
b.grad += (b.data < a.data ? 1 : 0) * out.grad;
|
|
119
|
-
}, `min(${a.label},${b.label})
|
|
110
|
+
}, `min(${a.label},${b.label})`, 'min');
|
|
120
111
|
}
|
|
121
112
|
static max(a, b) {
|
|
122
113
|
const d = Math.max(a.data, b.data);
|
|
123
|
-
return
|
|
114
|
+
return Value.make(d, a, b, (out) => () => {
|
|
124
115
|
if (a.requiresGrad)
|
|
125
116
|
a.grad += (a.data > b.data ? 1 : 0) * out.grad;
|
|
126
117
|
if (b.requiresGrad)
|
|
127
118
|
b.grad += (b.data > a.data ? 1 : 0) * out.grad;
|
|
128
|
-
}, `max(${a.label},${b.label})
|
|
119
|
+
}, `max(${a.label},${b.label})`, 'max');
|
|
129
120
|
}
|
|
130
121
|
static floor(a) {
|
|
131
122
|
const fl = Math.floor(a.data);
|
|
132
|
-
return
|
|
123
|
+
return Value.make(fl, a, null, () => () => { }, `floor(${a.label})`, 'floor');
|
|
133
124
|
}
|
|
134
125
|
static ceil(a) {
|
|
135
126
|
const cl = Math.ceil(a.data);
|
|
136
|
-
return
|
|
127
|
+
return Value.make(cl, a, null, () => () => { }, `ceil(${a.label})`, 'ceil');
|
|
137
128
|
}
|
|
138
129
|
static round(a) {
|
|
139
130
|
const rd = Math.round(a.data);
|
|
140
|
-
return
|
|
131
|
+
return Value.make(rd, a, null, () => () => { }, `round(${a.label})`, 'round');
|
|
141
132
|
}
|
|
142
133
|
static square(a) {
|
|
143
134
|
return ValueArithmetic.pow(a, 2);
|
|
@@ -149,32 +140,49 @@ class ValueArithmetic {
|
|
|
149
140
|
if (Math.abs(a.data) < eps) {
|
|
150
141
|
throw new Error(`Reciprocal of zero or near-zero detected`);
|
|
151
142
|
}
|
|
152
|
-
return
|
|
143
|
+
return Value.make(1 / a.data, a, null, (out) => () => {
|
|
153
144
|
if (a.requiresGrad)
|
|
154
145
|
a.grad += -1 / (a.data * a.data) * out.grad;
|
|
155
|
-
}, `reciprocal(${a.label})
|
|
146
|
+
}, `reciprocal(${a.label})`, 'reciprocal');
|
|
156
147
|
}
|
|
157
148
|
static clamp(a, min, max) {
|
|
158
149
|
let val = Math.max(min, Math.min(a.data, max));
|
|
159
|
-
|
|
150
|
+
const out = Value.make(val, a, null, (out) => () => {
|
|
160
151
|
a.grad += (a.data > min && a.data < max ? 1 : 0) * out.grad;
|
|
161
|
-
}, `clamp(${a.label},${min},${max})
|
|
152
|
+
}, `clamp(${a.label},${min},${max})`, 'clamp');
|
|
153
|
+
out._opConstants = [min, max];
|
|
154
|
+
return out;
|
|
162
155
|
}
|
|
163
156
|
static sum(vals) {
|
|
164
157
|
if (!vals.length)
|
|
165
|
-
return new
|
|
166
|
-
|
|
158
|
+
return new Value(0);
|
|
159
|
+
if (vals.length === 1)
|
|
160
|
+
return vals[0];
|
|
161
|
+
// N-ary sum to avoid deep chains
|
|
162
|
+
const sum = vals.reduce((acc, v) => acc + v.data, 0);
|
|
163
|
+
return Value.makeNary(sum, vals, (out) => () => {
|
|
164
|
+
for (const v of vals) {
|
|
165
|
+
if (v.requiresGrad)
|
|
166
|
+
v.grad += out.grad;
|
|
167
|
+
}
|
|
168
|
+
}, `sum(${vals.length})`, 'sum');
|
|
167
169
|
}
|
|
168
170
|
static mean(vals) {
|
|
169
171
|
if (!vals.length)
|
|
170
|
-
return new
|
|
172
|
+
return new Value(0);
|
|
171
173
|
return ValueArithmetic.sum(vals).div(vals.length);
|
|
172
174
|
}
|
|
173
175
|
static neg(a) {
|
|
174
|
-
return
|
|
176
|
+
return Value.make(-a.data, a, null, (out) => () => {
|
|
175
177
|
if (a.requiresGrad)
|
|
176
178
|
a.grad -= out.grad;
|
|
177
|
-
}, `(-${a.label})
|
|
179
|
+
}, `(-${a.label})`, 'neg');
|
|
180
|
+
}
|
|
181
|
+
static sign(a) {
|
|
182
|
+
const s = Math.sign(a.data);
|
|
183
|
+
return Value.make(s, a, null, (out) => () => {
|
|
184
|
+
if (a.requiresGrad)
|
|
185
|
+
a.grad += 0 * out.grad;
|
|
186
|
+
}, `sign(${a.label})`, 'sign');
|
|
178
187
|
}
|
|
179
188
|
}
|
|
180
|
-
exports.ValueArithmetic = ValueArithmetic;
|
package/dist/ValueComparison.js
CHANGED
|
@@ -1,15 +1,12 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
exports.ValueComparison = void 0;
|
|
4
|
-
const Value_1 = require("./Value");
|
|
5
|
-
class ValueComparison {
|
|
1
|
+
import { Value } from './Value';
|
|
2
|
+
export class ValueComparison {
|
|
6
3
|
static eq(a, b) {
|
|
7
|
-
return
|
|
4
|
+
return Value.make(a.data === b.data ? 1 : 0, a, b, (out) => () => {
|
|
8
5
|
// No gradient - discrete operation
|
|
9
6
|
}, `(${a.label}==${b.label})`);
|
|
10
7
|
}
|
|
11
8
|
static ifThenElse(cond, thenVal, elseVal) {
|
|
12
|
-
return
|
|
9
|
+
return Value.make(cond.data ? thenVal.data : elseVal.data, cond, cond.data ? thenVal : elseVal, (out) => () => {
|
|
13
10
|
if (cond.data) {
|
|
14
11
|
thenVal.grad += out.grad;
|
|
15
12
|
}
|
|
@@ -19,29 +16,28 @@ class ValueComparison {
|
|
|
19
16
|
}, `if(${cond.label}){${thenVal.label}}else{${elseVal.label}}`);
|
|
20
17
|
}
|
|
21
18
|
static neq(a, b) {
|
|
22
|
-
return
|
|
19
|
+
return Value.make(a.data !== b.data ? 1 : 0, a, b, (out) => () => {
|
|
23
20
|
// No gradient - discrete operation
|
|
24
21
|
}, `(${a.label}!=${b.label})`);
|
|
25
22
|
}
|
|
26
23
|
static gt(a, b) {
|
|
27
|
-
return
|
|
24
|
+
return Value.make(a.data > b.data ? 1 : 0, a, b, (out) => () => {
|
|
28
25
|
// No gradient - discrete operation
|
|
29
26
|
}, `(${a.label}>${b.label})`);
|
|
30
27
|
}
|
|
31
28
|
static lt(a, b) {
|
|
32
|
-
return
|
|
29
|
+
return Value.make(a.data < b.data ? 1 : 0, a, b, (out) => () => {
|
|
33
30
|
// No gradient - discrete operation
|
|
34
31
|
}, `(${a.label}<${b.label})`);
|
|
35
32
|
}
|
|
36
33
|
static gte(a, b) {
|
|
37
|
-
return
|
|
34
|
+
return Value.make(a.data >= b.data ? 1 : 0, a, b, (out) => () => {
|
|
38
35
|
// No gradient - discrete operation
|
|
39
36
|
}, `(${a.label}>=${b.label})`);
|
|
40
37
|
}
|
|
41
38
|
static lte(a, b) {
|
|
42
|
-
return
|
|
39
|
+
return Value.make(a.data <= b.data ? 1 : 0, a, b, (out) => () => {
|
|
43
40
|
// No gradient - discrete operation
|
|
44
41
|
}, `(${a.label}<=${b.label})`);
|
|
45
42
|
}
|
|
46
43
|
}
|
|
47
|
-
exports.ValueComparison = ValueComparison;
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import { Value } from './Value';
|
|
2
|
+
/**
|
|
3
|
+
* Registry for tracking unique Value objects across residual graphs.
|
|
4
|
+
* Handles deduplication of constants and optionally variables.
|
|
5
|
+
* @internal
|
|
6
|
+
*/
|
|
7
|
+
export declare class ValueRegistry {
|
|
8
|
+
private values;
|
|
9
|
+
private valueToId;
|
|
10
|
+
/**
|
|
11
|
+
* Register a Value and return its unique ID.
|
|
12
|
+
* Deduplication rules:
|
|
13
|
+
* - Constants (requiresGrad=false, no prev): dedupe by data value only
|
|
14
|
+
* - Variables (requiresGrad=true, has paramName): dedupe by paramName
|
|
15
|
+
* - Weights & computed values: always unique
|
|
16
|
+
*/
|
|
17
|
+
register(value: Value): number;
|
|
18
|
+
/**
|
|
19
|
+
* Get Value by ID
|
|
20
|
+
*/
|
|
21
|
+
getValue(id: number): Value;
|
|
22
|
+
/**
|
|
23
|
+
* Get all registered values as data array
|
|
24
|
+
*/
|
|
25
|
+
getDataArray(): number[];
|
|
26
|
+
/**
|
|
27
|
+
* Update value array from current Value.data
|
|
28
|
+
*/
|
|
29
|
+
updateDataArray(dataArray: number[]): void;
|
|
30
|
+
/**
|
|
31
|
+
* Get ID for a Value (must be registered)
|
|
32
|
+
*/
|
|
33
|
+
getId(value: Value): number;
|
|
34
|
+
/**
|
|
35
|
+
* Total number of unique values
|
|
36
|
+
*/
|
|
37
|
+
get size(): number;
|
|
38
|
+
}
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Registry for tracking unique Value objects across residual graphs.
|
|
3
|
+
* Handles deduplication of constants and optionally variables.
|
|
4
|
+
* @internal
|
|
5
|
+
*/
|
|
6
|
+
export class ValueRegistry {
|
|
7
|
+
values = [];
|
|
8
|
+
valueToId = new Map();
|
|
9
|
+
/**
|
|
10
|
+
* Register a Value and return its unique ID.
|
|
11
|
+
* Deduplication rules:
|
|
12
|
+
* - Constants (requiresGrad=false, no prev): dedupe by data value only
|
|
13
|
+
* - Variables (requiresGrad=true, has paramName): dedupe by paramName
|
|
14
|
+
* - Weights & computed values: always unique
|
|
15
|
+
*/
|
|
16
|
+
register(value) {
|
|
17
|
+
// Check if already registered
|
|
18
|
+
if (this.valueToId.has(value)) {
|
|
19
|
+
return this.valueToId.get(value);
|
|
20
|
+
}
|
|
21
|
+
// Constants: dedupe by value only (ignore labels/paramNames)
|
|
22
|
+
if (!value.requiresGrad && value.prev.length === 0) {
|
|
23
|
+
const existing = this.values.find(v => !v.requiresGrad &&
|
|
24
|
+
v.prev.length === 0 &&
|
|
25
|
+
v.data === value.data);
|
|
26
|
+
if (existing) {
|
|
27
|
+
const existingId = this.valueToId.get(existing);
|
|
28
|
+
this.valueToId.set(value, existingId);
|
|
29
|
+
value._registryId = existingId;
|
|
30
|
+
return existingId;
|
|
31
|
+
}
|
|
32
|
+
}
|
|
33
|
+
// Variables: dedupe by paramName if present
|
|
34
|
+
if (value.requiresGrad && value.paramName && value.prev.length === 0) {
|
|
35
|
+
const existing = this.values.find(v => v.requiresGrad &&
|
|
36
|
+
v.paramName === value.paramName &&
|
|
37
|
+
v.prev.length === 0);
|
|
38
|
+
if (existing) {
|
|
39
|
+
const existingId = this.valueToId.get(existing);
|
|
40
|
+
this.valueToId.set(value, existingId);
|
|
41
|
+
value._registryId = existingId;
|
|
42
|
+
return existingId;
|
|
43
|
+
}
|
|
44
|
+
}
|
|
45
|
+
// Weights & computed values: always unique
|
|
46
|
+
const id = this.values.length;
|
|
47
|
+
this.values.push(value);
|
|
48
|
+
this.valueToId.set(value, id);
|
|
49
|
+
value._registryId = id;
|
|
50
|
+
return id;
|
|
51
|
+
}
|
|
52
|
+
/**
|
|
53
|
+
* Get Value by ID
|
|
54
|
+
*/
|
|
55
|
+
getValue(id) {
|
|
56
|
+
return this.values[id];
|
|
57
|
+
}
|
|
58
|
+
/**
|
|
59
|
+
* Get all registered values as data array
|
|
60
|
+
*/
|
|
61
|
+
getDataArray() {
|
|
62
|
+
return this.values.map(v => v.data);
|
|
63
|
+
}
|
|
64
|
+
/**
|
|
65
|
+
* Update value array from current Value.data
|
|
66
|
+
*/
|
|
67
|
+
updateDataArray(dataArray) {
|
|
68
|
+
for (let i = 0; i < this.values.length; i++) {
|
|
69
|
+
dataArray[i] = this.values[i].data;
|
|
70
|
+
}
|
|
71
|
+
}
|
|
72
|
+
/**
|
|
73
|
+
* Get ID for a Value (must be registered)
|
|
74
|
+
*/
|
|
75
|
+
getId(value) {
|
|
76
|
+
const id = this.valueToId.get(value);
|
|
77
|
+
if (id === undefined) {
|
|
78
|
+
throw new Error('Value not registered');
|
|
79
|
+
}
|
|
80
|
+
return id;
|
|
81
|
+
}
|
|
82
|
+
/**
|
|
83
|
+
* Total number of unique values
|
|
84
|
+
*/
|
|
85
|
+
get size() {
|
|
86
|
+
return this.values.length;
|
|
87
|
+
}
|
|
88
|
+
}
|
package/dist/ValueTrig.js
CHANGED
|
@@ -1,49 +1,45 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
exports.ValueTrig = void 0;
|
|
4
|
-
const Value_1 = require("./Value");
|
|
5
|
-
class ValueTrig {
|
|
1
|
+
import { Value } from './Value';
|
|
2
|
+
export class ValueTrig {
|
|
6
3
|
static sin(x) {
|
|
7
4
|
const s = Math.sin(x.data);
|
|
8
|
-
return
|
|
5
|
+
return Value.make(s, x, null, (out) => () => {
|
|
9
6
|
if (x.requiresGrad)
|
|
10
7
|
x.grad += Math.cos(x.data) * out.grad;
|
|
11
|
-
}, `sin(${x.label})
|
|
8
|
+
}, `sin(${x.label})`, 'sin');
|
|
12
9
|
}
|
|
13
10
|
static cos(x) {
|
|
14
11
|
const c = Math.cos(x.data);
|
|
15
|
-
return
|
|
12
|
+
return Value.make(c, x, null, (out) => () => {
|
|
16
13
|
if (x.requiresGrad)
|
|
17
14
|
x.grad += -Math.sin(x.data) * out.grad;
|
|
18
|
-
}, `cos(${x.label})
|
|
15
|
+
}, `cos(${x.label})`, 'cos');
|
|
19
16
|
}
|
|
20
17
|
static tan(x) {
|
|
21
18
|
const t = Math.tan(x.data);
|
|
22
|
-
return
|
|
19
|
+
return Value.make(t, x, null, (out) => () => {
|
|
23
20
|
if (x.requiresGrad)
|
|
24
21
|
x.grad += (1 / (Math.cos(x.data) ** 2)) * out.grad;
|
|
25
|
-
}, `tan(${x.label})
|
|
22
|
+
}, `tan(${x.label})`, 'tan');
|
|
26
23
|
}
|
|
27
24
|
static asin(x) {
|
|
28
25
|
const s = Math.asin(x.data);
|
|
29
|
-
return
|
|
26
|
+
return Value.make(s, x, null, (out) => () => {
|
|
30
27
|
if (x.requiresGrad)
|
|
31
28
|
x.grad += (1 / Math.sqrt(1 - x.data * x.data)) * out.grad;
|
|
32
|
-
}, `asin(${x.label})
|
|
29
|
+
}, `asin(${x.label})`, 'asin');
|
|
33
30
|
}
|
|
34
31
|
static acos(x) {
|
|
35
32
|
const c = Math.acos(x.data);
|
|
36
|
-
return
|
|
33
|
+
return Value.make(c, x, null, (out) => () => {
|
|
37
34
|
if (x.requiresGrad)
|
|
38
35
|
x.grad += (-1 / Math.sqrt(1 - x.data * x.data)) * out.grad;
|
|
39
|
-
}, `acos(${x.label})
|
|
36
|
+
}, `acos(${x.label})`, 'acos');
|
|
40
37
|
}
|
|
41
38
|
static atan(x) {
|
|
42
39
|
const a = Math.atan(x.data);
|
|
43
|
-
return
|
|
40
|
+
return Value.make(a, x, null, (out) => () => {
|
|
44
41
|
if (x.requiresGrad)
|
|
45
42
|
x.grad += (1 / (1 + x.data * x.data)) * out.grad;
|
|
46
|
-
}, `atan(${x.label})
|
|
43
|
+
}, `atan(${x.label})`, 'atan');
|
|
47
44
|
}
|
|
48
45
|
}
|
|
49
|
-
exports.ValueTrig = ValueTrig;
|
package/dist/Vec2.d.ts
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import { Value } from './Value';
|
|
2
|
+
/**
|
|
3
|
+
* 2D vector class with differentiable operations.
|
|
4
|
+
* @public
|
|
5
|
+
*/
|
|
6
|
+
export declare class Vec2 {
|
|
7
|
+
x: Value;
|
|
8
|
+
y: Value;
|
|
9
|
+
constructor(x: Value | number, y: Value | number);
|
|
10
|
+
static W(x: number, y: number): Vec2;
|
|
11
|
+
static C(x: number, y: number): Vec2;
|
|
12
|
+
static zero(): Vec2;
|
|
13
|
+
static one(): Vec2;
|
|
14
|
+
get magnitude(): Value;
|
|
15
|
+
get sqrMagnitude(): Value;
|
|
16
|
+
get normalized(): Vec2;
|
|
17
|
+
static dot(a: Vec2, b: Vec2): Value;
|
|
18
|
+
add(other: Vec2): Vec2;
|
|
19
|
+
sub(other: Vec2): Vec2;
|
|
20
|
+
mul(scalar: Value | number): Vec2;
|
|
21
|
+
div(scalar: Value | number): Vec2;
|
|
22
|
+
/**
|
|
23
|
+
* 2D cross product (returns scalar z-component of 3D cross product).
|
|
24
|
+
* Also known as the "perpendicular dot product" or "wedge product".
|
|
25
|
+
* Returns positive if b is counter-clockwise from a, negative if clockwise.
|
|
26
|
+
*/
|
|
27
|
+
static cross(a: Vec2, b: Vec2): Value;
|
|
28
|
+
/**
|
|
29
|
+
* Get perpendicular vector (rotated 90° counter-clockwise).
|
|
30
|
+
* Useful for computing distances to lines.
|
|
31
|
+
*/
|
|
32
|
+
get perpendicular(): Vec2;
|
|
33
|
+
/**
|
|
34
|
+
* Distance from this point to a line defined by two points.
|
|
35
|
+
* Uses the perpendicular distance formula.
|
|
36
|
+
*/
|
|
37
|
+
static distanceToLine(point: Vec2, lineStart: Vec2, lineEnd: Vec2): Value;
|
|
38
|
+
/**
|
|
39
|
+
* Angle between two vectors in radians.
|
|
40
|
+
* Returns value in range [0, π].
|
|
41
|
+
*/
|
|
42
|
+
static angleBetween(a: Vec2, b: Vec2): Value;
|
|
43
|
+
get trainables(): Value[];
|
|
44
|
+
toString(): string;
|
|
45
|
+
}
|
package/dist/Vec2.js
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
import { V } from './V';
|
|
2
|
+
import { Value } from './Value';
|
|
3
|
+
/**
|
|
4
|
+
* 2D vector class with differentiable operations.
|
|
5
|
+
* @public
|
|
6
|
+
*/
|
|
7
|
+
export class Vec2 {
|
|
8
|
+
x;
|
|
9
|
+
y;
|
|
10
|
+
constructor(x, y) {
|
|
11
|
+
this.x = typeof x === 'number' ? new Value(x) : x;
|
|
12
|
+
this.y = typeof y === 'number' ? new Value(y) : y;
|
|
13
|
+
}
|
|
14
|
+
static W(x, y) {
|
|
15
|
+
return new Vec2(V.W(x), V.W(y));
|
|
16
|
+
}
|
|
17
|
+
static C(x, y) {
|
|
18
|
+
return new Vec2(V.C(x), V.C(y));
|
|
19
|
+
}
|
|
20
|
+
static zero() {
|
|
21
|
+
return new Vec2(V.C(0), V.C(0));
|
|
22
|
+
}
|
|
23
|
+
static one() {
|
|
24
|
+
return new Vec2(V.C(1), V.C(1));
|
|
25
|
+
}
|
|
26
|
+
get magnitude() {
|
|
27
|
+
return V.sqrt(V.add(V.square(this.x), V.square(this.y)));
|
|
28
|
+
}
|
|
29
|
+
get sqrMagnitude() {
|
|
30
|
+
return V.add(V.square(this.x), V.square(this.y));
|
|
31
|
+
}
|
|
32
|
+
get normalized() {
|
|
33
|
+
const mag = this.magnitude;
|
|
34
|
+
return new Vec2(V.div(this.x, mag), V.div(this.y, mag));
|
|
35
|
+
}
|
|
36
|
+
static dot(a, b) {
|
|
37
|
+
return V.add(V.mul(a.x, b.x), V.mul(a.y, b.y));
|
|
38
|
+
}
|
|
39
|
+
add(other) {
|
|
40
|
+
return new Vec2(V.add(this.x, other.x), V.add(this.y, other.y));
|
|
41
|
+
}
|
|
42
|
+
sub(other) {
|
|
43
|
+
return new Vec2(V.sub(this.x, other.x), V.sub(this.y, other.y));
|
|
44
|
+
}
|
|
45
|
+
mul(scalar) {
|
|
46
|
+
return new Vec2(V.mul(this.x, scalar), V.mul(this.y, scalar));
|
|
47
|
+
}
|
|
48
|
+
div(scalar) {
|
|
49
|
+
return new Vec2(V.div(this.x, scalar), V.div(this.y, scalar));
|
|
50
|
+
}
|
|
51
|
+
/**
|
|
52
|
+
* 2D cross product (returns scalar z-component of 3D cross product).
|
|
53
|
+
* Also known as the "perpendicular dot product" or "wedge product".
|
|
54
|
+
* Returns positive if b is counter-clockwise from a, negative if clockwise.
|
|
55
|
+
*/
|
|
56
|
+
static cross(a, b) {
|
|
57
|
+
return V.sub(V.mul(a.x, b.y), V.mul(a.y, b.x));
|
|
58
|
+
}
|
|
59
|
+
/**
|
|
60
|
+
* Get perpendicular vector (rotated 90° counter-clockwise).
|
|
61
|
+
* Useful for computing distances to lines.
|
|
62
|
+
*/
|
|
63
|
+
get perpendicular() {
|
|
64
|
+
return new Vec2(V.neg(this.y), this.x);
|
|
65
|
+
}
|
|
66
|
+
/**
|
|
67
|
+
* Distance from this point to a line defined by two points.
|
|
68
|
+
* Uses the perpendicular distance formula.
|
|
69
|
+
*/
|
|
70
|
+
static distanceToLine(point, lineStart, lineEnd) {
|
|
71
|
+
const lineDir = lineEnd.sub(lineStart);
|
|
72
|
+
const pointDir = point.sub(lineStart);
|
|
73
|
+
const lineLength = lineDir.magnitude;
|
|
74
|
+
const cross = Vec2.cross(lineDir, pointDir);
|
|
75
|
+
return V.div(V.abs(cross), lineLength);
|
|
76
|
+
}
|
|
77
|
+
/**
|
|
78
|
+
* Angle between two vectors in radians.
|
|
79
|
+
* Returns value in range [0, π].
|
|
80
|
+
*/
|
|
81
|
+
static angleBetween(a, b) {
|
|
82
|
+
const dotProd = Vec2.dot(a, b);
|
|
83
|
+
const magProduct = V.mul(a.magnitude, b.magnitude);
|
|
84
|
+
const cosAngle = V.div(dotProd, magProduct);
|
|
85
|
+
return V.acos(cosAngle);
|
|
86
|
+
}
|
|
87
|
+
get trainables() {
|
|
88
|
+
return [this.x, this.y];
|
|
89
|
+
}
|
|
90
|
+
toString() {
|
|
91
|
+
return `Vec2(${this.x.data}, ${this.y.data})`;
|
|
92
|
+
}
|
|
93
|
+
}
|