scalar-autograd 0.1.7 → 0.1.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +127 -2
- package/dist/CompiledFunctions.d.ts +111 -0
- package/dist/CompiledFunctions.js +268 -0
- package/dist/CompiledResiduals.d.ts +74 -0
- package/dist/CompiledResiduals.js +94 -0
- package/dist/EigenvalueHelpers.d.ts +14 -0
- package/dist/EigenvalueHelpers.js +93 -0
- package/dist/Geometry.d.ts +131 -0
- package/dist/Geometry.js +213 -0
- package/dist/GraphBuilder.d.ts +64 -0
- package/dist/GraphBuilder.js +237 -0
- package/dist/GraphCanonicalizerNoSort.d.ts +20 -0
- package/dist/GraphCanonicalizerNoSort.js +190 -0
- package/dist/GraphHashCanonicalizer.d.ts +46 -0
- package/dist/GraphHashCanonicalizer.js +220 -0
- package/dist/GraphSignature.d.ts +7 -0
- package/dist/GraphSignature.js +7 -0
- package/dist/KernelPool.d.ts +55 -0
- package/dist/KernelPool.js +124 -0
- package/dist/LBFGS.d.ts +84 -0
- package/dist/LBFGS.js +313 -0
- package/dist/LinearSolver.d.ts +69 -0
- package/dist/LinearSolver.js +213 -0
- package/dist/Losses.d.ts +9 -0
- package/dist/Losses.js +42 -37
- package/dist/Matrix3x3.d.ts +50 -0
- package/dist/Matrix3x3.js +146 -0
- package/dist/NonlinearLeastSquares.d.ts +33 -0
- package/dist/NonlinearLeastSquares.js +252 -0
- package/dist/Optimizers.d.ts +70 -14
- package/dist/Optimizers.js +42 -19
- package/dist/V.d.ts +0 -0
- package/dist/V.js +0 -0
- package/dist/Value.d.ts +84 -2
- package/dist/Value.js +296 -58
- package/dist/ValueActivation.js +10 -14
- package/dist/ValueArithmetic.d.ts +1 -0
- package/dist/ValueArithmetic.js +58 -50
- package/dist/ValueComparison.js +9 -13
- package/dist/ValueRegistry.d.ts +38 -0
- package/dist/ValueRegistry.js +88 -0
- package/dist/ValueTrig.js +14 -18
- package/dist/Vec2.d.ts +45 -0
- package/dist/Vec2.js +93 -0
- package/dist/Vec3.d.ts +78 -0
- package/dist/Vec3.js +169 -0
- package/dist/Vec4.d.ts +45 -0
- package/dist/Vec4.js +126 -0
- package/dist/__tests__/duplicate-inputs.test.js +33 -0
- package/dist/cli/gradient-gen.d.ts +19 -0
- package/dist/cli/gradient-gen.js +264 -0
- package/dist/compileIndirectKernel.d.ts +24 -0
- package/dist/compileIndirectKernel.js +148 -0
- package/dist/index.d.ts +20 -0
- package/dist/index.js +20 -0
- package/dist/scalar-autograd.d.ts +1157 -0
- package/dist/symbolic/AST.d.ts +113 -0
- package/dist/symbolic/AST.js +128 -0
- package/dist/symbolic/CodeGen.d.ts +35 -0
- package/dist/symbolic/CodeGen.js +280 -0
- package/dist/symbolic/Parser.d.ts +64 -0
- package/dist/symbolic/Parser.js +329 -0
- package/dist/symbolic/Simplify.d.ts +10 -0
- package/dist/symbolic/Simplify.js +244 -0
- package/dist/symbolic/SymbolicDiff.d.ts +35 -0
- package/dist/symbolic/SymbolicDiff.js +339 -0
- package/dist/tsdoc-metadata.json +11 -0
- package/package.json +29 -5
- package/dist/Losses.spec.js +0 -54
- package/dist/Optimizers.edge-cases.spec.d.ts +0 -1
- package/dist/Optimizers.edge-cases.spec.js +0 -29
- package/dist/Optimizers.spec.d.ts +0 -1
- package/dist/Optimizers.spec.js +0 -56
- package/dist/Value.edge-cases.spec.d.ts +0 -1
- package/dist/Value.edge-cases.spec.js +0 -54
- package/dist/Value.grad-flow.spec.d.ts +0 -1
- package/dist/Value.grad-flow.spec.js +0 -24
- package/dist/Value.losses-edge-cases.spec.d.ts +0 -1
- package/dist/Value.losses-edge-cases.spec.js +0 -30
- package/dist/Value.memory.spec.d.ts +0 -1
- package/dist/Value.memory.spec.js +0 -23
- package/dist/Value.nn.spec.d.ts +0 -1
- package/dist/Value.nn.spec.js +0 -111
- package/dist/Value.spec.d.ts +0 -1
- package/dist/Value.spec.js +0 -245
- /package/dist/{Losses.spec.d.ts → __tests__/duplicate-inputs.test.d.ts} +0 -0
package/dist/Value.js
CHANGED
|
@@ -1,32 +1,86 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
var Optimizers_1 = require("./Optimizers");
|
|
7
|
-
Object.defineProperty(exports, "Optimizer", { enumerable: true, get: function () { return Optimizers_1.Optimizer; } });
|
|
8
|
-
Object.defineProperty(exports, "SGD", { enumerable: true, get: function () { return Optimizers_1.SGD; } });
|
|
9
|
-
Object.defineProperty(exports, "Adam", { enumerable: true, get: function () { return Optimizers_1.Adam; } });
|
|
10
|
-
Object.defineProperty(exports, "AdamW", { enumerable: true, get: function () { return Optimizers_1.AdamW; } });
|
|
11
|
-
var Losses_1 = require("./Losses");
|
|
12
|
-
Object.defineProperty(exports, "Losses", { enumerable: true, get: function () { return Losses_1.Losses; } });
|
|
1
|
+
export { V } from './V';
|
|
2
|
+
export { Optimizer, SGD, Adam, AdamW } from './Optimizers';
|
|
3
|
+
export { Losses } from './Losses';
|
|
4
|
+
export { Vec2 } from './Vec2';
|
|
5
|
+
export { Vec3 } from './Vec3';
|
|
13
6
|
const EPS = 1e-12;
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
7
|
+
import { ValueActivation } from './ValueActivation';
|
|
8
|
+
import { ValueArithmetic } from './ValueArithmetic';
|
|
9
|
+
import { ValueComparison } from './ValueComparison';
|
|
10
|
+
import { ValueTrig } from './ValueTrig';
|
|
11
|
+
/**
|
|
12
|
+
* Represents a scalar value in the computational graph for automatic differentiation.
|
|
13
|
+
* Supports forward computation and reverse-mode autodiff (backpropagation).
|
|
14
|
+
* @public
|
|
15
|
+
*/
|
|
16
|
+
export class Value {
|
|
17
|
+
/**
|
|
18
|
+
* Global flag to disable gradient tracking. Use Value.withNoGrad() instead of setting directly.
|
|
19
|
+
* @public
|
|
20
|
+
*/
|
|
19
21
|
static no_grad_mode = false;
|
|
22
|
+
/**
|
|
23
|
+
* Current graph builder for incremental hash tracking (null when not building).
|
|
24
|
+
* @internal
|
|
25
|
+
*/
|
|
26
|
+
static currentBuilder = null;
|
|
27
|
+
/**
|
|
28
|
+
* Global counter for unique Value IDs.
|
|
29
|
+
* @internal
|
|
30
|
+
*/
|
|
31
|
+
static nextId = 0;
|
|
32
|
+
/**
|
|
33
|
+
* Unique ID for this Value instance (for hashing intermediate nodes).
|
|
34
|
+
* @internal
|
|
35
|
+
*/
|
|
36
|
+
_id;
|
|
37
|
+
/**
|
|
38
|
+
* The numeric value stored in this node.
|
|
39
|
+
* @public
|
|
40
|
+
*/
|
|
20
41
|
data;
|
|
42
|
+
/**
|
|
43
|
+
* The gradient of the output with respect to this value.
|
|
44
|
+
* @public
|
|
45
|
+
*/
|
|
21
46
|
grad = 0;
|
|
47
|
+
/**
|
|
48
|
+
* Whether this value participates in gradient computation.
|
|
49
|
+
* @public
|
|
50
|
+
*/
|
|
22
51
|
requiresGrad;
|
|
23
52
|
backwardFn = () => { };
|
|
24
|
-
prev = [];
|
|
53
|
+
/** @internal */ prev = [];
|
|
54
|
+
/**
|
|
55
|
+
* Optional label for debugging and visualization.
|
|
56
|
+
* @public
|
|
57
|
+
*/
|
|
25
58
|
label;
|
|
59
|
+
/**
|
|
60
|
+
* Operation type for JIT compilation (e.g., '+', 'exp', 'sin').
|
|
61
|
+
* @internal
|
|
62
|
+
*/
|
|
63
|
+
_op;
|
|
64
|
+
/**
|
|
65
|
+
* Parameter name for JIT compilation inputs.
|
|
66
|
+
* @internal
|
|
67
|
+
*/
|
|
68
|
+
paramName;
|
|
69
|
+
/**
|
|
70
|
+
* Registry ID for kernel reuse system.
|
|
71
|
+
* @internal
|
|
72
|
+
*/
|
|
73
|
+
_registryId;
|
|
74
|
+
/**
|
|
75
|
+
* Operation constants (e.g., min/max for clamp, exponent for pow).
|
|
76
|
+
* @internal
|
|
77
|
+
*/
|
|
78
|
+
_opConstants;
|
|
26
79
|
constructor(data, label = "", requiresGrad = false) {
|
|
27
80
|
if (typeof data !== 'number' || Number.isNaN(data) || !Number.isFinite(data)) {
|
|
28
81
|
throw new Error(`Invalid number passed to Value: ${data}`);
|
|
29
82
|
}
|
|
83
|
+
this._id = Value.nextId++;
|
|
30
84
|
this.data = data;
|
|
31
85
|
this.label = label;
|
|
32
86
|
this.requiresGrad = requiresGrad;
|
|
@@ -39,70 +93,70 @@ class Value {
|
|
|
39
93
|
* @returns New Value with sin.
|
|
40
94
|
*/
|
|
41
95
|
sin() {
|
|
42
|
-
return
|
|
96
|
+
return ValueTrig.sin(this);
|
|
43
97
|
}
|
|
44
98
|
/**
|
|
45
99
|
* Returns cos(this).
|
|
46
100
|
* @returns New Value with cos.
|
|
47
101
|
*/
|
|
48
102
|
cos() {
|
|
49
|
-
return
|
|
103
|
+
return ValueTrig.cos(this);
|
|
50
104
|
}
|
|
51
105
|
/**
|
|
52
106
|
* Returns tan(this).
|
|
53
107
|
* @returns New Value with tan.
|
|
54
108
|
*/
|
|
55
109
|
tan() {
|
|
56
|
-
return
|
|
110
|
+
return ValueTrig.tan(this);
|
|
57
111
|
}
|
|
58
112
|
/**
|
|
59
113
|
* Returns asin(this).
|
|
60
114
|
* @returns New Value with asin.
|
|
61
115
|
*/
|
|
62
116
|
asin() {
|
|
63
|
-
return
|
|
117
|
+
return ValueTrig.asin(this);
|
|
64
118
|
}
|
|
65
119
|
/**
|
|
66
120
|
* Returns acos(this).
|
|
67
121
|
* @returns New Value with acos.
|
|
68
122
|
*/
|
|
69
123
|
acos() {
|
|
70
|
-
return
|
|
124
|
+
return ValueTrig.acos(this);
|
|
71
125
|
}
|
|
72
126
|
/**
|
|
73
127
|
* Returns atan(this).
|
|
74
128
|
* @returns New Value with atan.
|
|
75
129
|
*/
|
|
76
130
|
atan() {
|
|
77
|
-
return
|
|
131
|
+
return ValueTrig.atan(this);
|
|
78
132
|
}
|
|
79
133
|
/**
|
|
80
134
|
* Returns relu(this).
|
|
81
135
|
* @returns New Value with relu.
|
|
82
136
|
*/
|
|
83
137
|
relu() {
|
|
84
|
-
return
|
|
138
|
+
return ValueActivation.relu(this);
|
|
85
139
|
}
|
|
86
140
|
/**
|
|
87
141
|
* Returns abs(this).
|
|
88
142
|
* @returns New Value with abs.
|
|
89
143
|
*/
|
|
90
144
|
abs() {
|
|
91
|
-
return
|
|
145
|
+
return ValueArithmetic.abs(this);
|
|
92
146
|
}
|
|
93
147
|
/**
|
|
94
148
|
* Returns exp(this).
|
|
95
149
|
* @returns New Value with exp.
|
|
96
150
|
*/
|
|
97
151
|
exp() {
|
|
98
|
-
return
|
|
152
|
+
return ValueArithmetic.exp(this);
|
|
99
153
|
}
|
|
100
154
|
/**
|
|
101
155
|
* Returns log(this).
|
|
102
156
|
* @returns New Value with log.
|
|
103
157
|
*/
|
|
104
158
|
log() {
|
|
105
|
-
return
|
|
159
|
+
return ValueArithmetic.log(this, EPS);
|
|
106
160
|
}
|
|
107
161
|
/**
|
|
108
162
|
* Returns min(this, other).
|
|
@@ -110,7 +164,7 @@ class Value {
|
|
|
110
164
|
* @returns New Value with min.
|
|
111
165
|
*/
|
|
112
166
|
min(other) {
|
|
113
|
-
return
|
|
167
|
+
return ValueArithmetic.min(this, other);
|
|
114
168
|
}
|
|
115
169
|
/**
|
|
116
170
|
* Returns max(this, other).
|
|
@@ -118,7 +172,7 @@ class Value {
|
|
|
118
172
|
* @returns New Value with max.
|
|
119
173
|
*/
|
|
120
174
|
max(other) {
|
|
121
|
-
return
|
|
175
|
+
return ValueArithmetic.max(this, other);
|
|
122
176
|
}
|
|
123
177
|
/**
|
|
124
178
|
* Adds this and other.
|
|
@@ -126,7 +180,7 @@ class Value {
|
|
|
126
180
|
* @returns New Value with sum.
|
|
127
181
|
*/
|
|
128
182
|
add(other) {
|
|
129
|
-
return
|
|
183
|
+
return ValueArithmetic.add(this, Value.ensureValue(other));
|
|
130
184
|
}
|
|
131
185
|
/**
|
|
132
186
|
* Multiplies this and other.
|
|
@@ -134,7 +188,7 @@ class Value {
|
|
|
134
188
|
* @returns New Value with product.
|
|
135
189
|
*/
|
|
136
190
|
mul(other) {
|
|
137
|
-
return
|
|
191
|
+
return ValueArithmetic.mul(this, Value.ensureValue(other));
|
|
138
192
|
}
|
|
139
193
|
/**
|
|
140
194
|
* Subtracts other from this.
|
|
@@ -142,7 +196,7 @@ class Value {
|
|
|
142
196
|
* @returns New Value with difference.
|
|
143
197
|
*/
|
|
144
198
|
sub(other) {
|
|
145
|
-
return
|
|
199
|
+
return ValueArithmetic.sub(this, Value.ensureValue(other));
|
|
146
200
|
}
|
|
147
201
|
/**
|
|
148
202
|
* Divides this by other.
|
|
@@ -150,7 +204,7 @@ class Value {
|
|
|
150
204
|
* @returns New Value with quotient.
|
|
151
205
|
*/
|
|
152
206
|
div(other) {
|
|
153
|
-
return
|
|
207
|
+
return ValueArithmetic.div(this, Value.ensureValue(other), EPS);
|
|
154
208
|
}
|
|
155
209
|
/**
|
|
156
210
|
* Raises this to the power exp.
|
|
@@ -158,7 +212,7 @@ class Value {
|
|
|
158
212
|
* @returns New Value with pow(this, exp)
|
|
159
213
|
*/
|
|
160
214
|
pow(exp) {
|
|
161
|
-
return
|
|
215
|
+
return ValueArithmetic.pow(this, exp);
|
|
162
216
|
}
|
|
163
217
|
/**
|
|
164
218
|
* Raises this to a dynamic Value (other).
|
|
@@ -166,7 +220,7 @@ class Value {
|
|
|
166
220
|
* @returns New Value with pow(this, other)
|
|
167
221
|
*/
|
|
168
222
|
powValue(other) {
|
|
169
|
-
return
|
|
223
|
+
return ValueArithmetic.powValue(this, Value.ensureValue(other), EPS);
|
|
170
224
|
}
|
|
171
225
|
/**
|
|
172
226
|
* Returns this modulo other.
|
|
@@ -174,7 +228,7 @@ class Value {
|
|
|
174
228
|
* @returns New Value with modulo.
|
|
175
229
|
*/
|
|
176
230
|
mod(other) {
|
|
177
|
-
return
|
|
231
|
+
return ValueArithmetic.mod(this, other);
|
|
178
232
|
}
|
|
179
233
|
/**
|
|
180
234
|
* Returns Value indicating if this equals other.
|
|
@@ -182,7 +236,7 @@ class Value {
|
|
|
182
236
|
* @returns New Value (1 if equal, else 0)
|
|
183
237
|
*/
|
|
184
238
|
eq(other) {
|
|
185
|
-
return
|
|
239
|
+
return ValueComparison.eq(this, other);
|
|
186
240
|
}
|
|
187
241
|
/**
|
|
188
242
|
* Returns Value indicating if this not equals other.
|
|
@@ -190,7 +244,7 @@ class Value {
|
|
|
190
244
|
* @returns New Value (1 if not equal, else 0)
|
|
191
245
|
*/
|
|
192
246
|
neq(other) {
|
|
193
|
-
return
|
|
247
|
+
return ValueComparison.neq(this, other);
|
|
194
248
|
}
|
|
195
249
|
/**
|
|
196
250
|
* Returns Value indicating if this greater than other.
|
|
@@ -198,7 +252,7 @@ class Value {
|
|
|
198
252
|
* @returns New Value (1 if true, else 0)
|
|
199
253
|
*/
|
|
200
254
|
gt(other) {
|
|
201
|
-
return
|
|
255
|
+
return ValueComparison.gt(this, other);
|
|
202
256
|
}
|
|
203
257
|
/**
|
|
204
258
|
* Returns Value indicating if this less than other.
|
|
@@ -206,7 +260,7 @@ class Value {
|
|
|
206
260
|
* @returns New Value (1 if true, else 0)
|
|
207
261
|
*/
|
|
208
262
|
lt(other) {
|
|
209
|
-
return
|
|
263
|
+
return ValueComparison.lt(this, other);
|
|
210
264
|
}
|
|
211
265
|
/**
|
|
212
266
|
* Returns Value indicating if this greater than or equal to other.
|
|
@@ -214,7 +268,7 @@ class Value {
|
|
|
214
268
|
* @returns New Value (1 if true, else 0)
|
|
215
269
|
*/
|
|
216
270
|
gte(other) {
|
|
217
|
-
return
|
|
271
|
+
return ValueComparison.gte(this, other);
|
|
218
272
|
}
|
|
219
273
|
/**
|
|
220
274
|
* Returns Value indicating if this less than or equal to other.
|
|
@@ -222,56 +276,56 @@ class Value {
|
|
|
222
276
|
* @returns New Value (1 if true, else 0)
|
|
223
277
|
*/
|
|
224
278
|
lte(other) {
|
|
225
|
-
return
|
|
279
|
+
return ValueComparison.lte(this, other);
|
|
226
280
|
}
|
|
227
281
|
/**
|
|
228
282
|
* Returns softplus(this).
|
|
229
283
|
* @returns New Value with softplus.
|
|
230
284
|
*/
|
|
231
285
|
softplus() {
|
|
232
|
-
return
|
|
286
|
+
return ValueActivation.softplus(this);
|
|
233
287
|
}
|
|
234
288
|
/**
|
|
235
289
|
* Returns the floor of this Value.
|
|
236
290
|
* @returns New Value with floor(data).
|
|
237
291
|
*/
|
|
238
292
|
floor() {
|
|
239
|
-
return
|
|
293
|
+
return ValueArithmetic.floor(this);
|
|
240
294
|
}
|
|
241
295
|
/**
|
|
242
296
|
* Returns the ceiling of this Value.
|
|
243
297
|
* @returns New Value with ceil(data).
|
|
244
298
|
*/
|
|
245
299
|
ceil() {
|
|
246
|
-
return
|
|
300
|
+
return ValueArithmetic.ceil(this);
|
|
247
301
|
}
|
|
248
302
|
/**
|
|
249
303
|
* Returns the rounded value of this Value.
|
|
250
304
|
* @returns New Value with rounded data.
|
|
251
305
|
*/
|
|
252
306
|
round() {
|
|
253
|
-
return
|
|
307
|
+
return ValueArithmetic.round(this);
|
|
254
308
|
}
|
|
255
309
|
/**
|
|
256
310
|
* Returns the square of this Value.
|
|
257
311
|
* @returns New Value with squared data.
|
|
258
312
|
*/
|
|
259
313
|
square() {
|
|
260
|
-
return
|
|
314
|
+
return ValueArithmetic.square(this);
|
|
261
315
|
}
|
|
262
316
|
/**
|
|
263
317
|
* Returns the cube of this Value.
|
|
264
318
|
* @returns New Value with cubed data.
|
|
265
319
|
*/
|
|
266
320
|
cube() {
|
|
267
|
-
return
|
|
321
|
+
return ValueArithmetic.cube(this);
|
|
268
322
|
}
|
|
269
323
|
/**
|
|
270
324
|
* Returns the reciprocal (1/x) of this Value.
|
|
271
325
|
* @returns New Value with reciprocal.
|
|
272
326
|
*/
|
|
273
327
|
reciprocal() {
|
|
274
|
-
return
|
|
328
|
+
return ValueArithmetic.reciprocal(this, EPS);
|
|
275
329
|
}
|
|
276
330
|
/**
|
|
277
331
|
* Clamps this between min and max.
|
|
@@ -280,14 +334,21 @@ class Value {
|
|
|
280
334
|
* @returns New clamped Value
|
|
281
335
|
*/
|
|
282
336
|
clamp(min, max) {
|
|
283
|
-
return
|
|
337
|
+
return ValueArithmetic.clamp(this, min, max);
|
|
284
338
|
}
|
|
285
339
|
/**
|
|
286
340
|
* Returns the negation (-this) Value.
|
|
287
341
|
* @returns New Value which is the negation.
|
|
288
342
|
*/
|
|
289
343
|
neg() {
|
|
290
|
-
return
|
|
344
|
+
return ValueArithmetic.neg(this);
|
|
345
|
+
}
|
|
346
|
+
/**
|
|
347
|
+
* Returns sign(this).
|
|
348
|
+
* @returns New Value with sign.
|
|
349
|
+
*/
|
|
350
|
+
sign() {
|
|
351
|
+
return ValueArithmetic.sign(this);
|
|
291
352
|
}
|
|
292
353
|
/**
|
|
293
354
|
* Returns the sum of the given Values.
|
|
@@ -295,7 +356,7 @@ class Value {
|
|
|
295
356
|
* @returns New Value holding their sum.
|
|
296
357
|
*/
|
|
297
358
|
static sum(vals) {
|
|
298
|
-
return
|
|
359
|
+
return ValueArithmetic.sum(vals);
|
|
299
360
|
}
|
|
300
361
|
/**
|
|
301
362
|
* Returns the mean of the given Values.
|
|
@@ -303,21 +364,21 @@ class Value {
|
|
|
303
364
|
* @returns New Value holding their mean.
|
|
304
365
|
*/
|
|
305
366
|
static mean(vals) {
|
|
306
|
-
return
|
|
367
|
+
return ValueArithmetic.mean(vals);
|
|
307
368
|
}
|
|
308
369
|
/**
|
|
309
370
|
* Returns tanh(this).
|
|
310
371
|
* @returns New Value with tanh.
|
|
311
372
|
*/
|
|
312
373
|
tanh() {
|
|
313
|
-
return
|
|
374
|
+
return ValueActivation.tanh(this);
|
|
314
375
|
}
|
|
315
376
|
/**
|
|
316
377
|
* Returns sigmoid(this).
|
|
317
378
|
* @returns New Value with sigmoid.
|
|
318
379
|
*/
|
|
319
380
|
sigmoid() {
|
|
320
|
-
return
|
|
381
|
+
return ValueActivation.sigmoid(this);
|
|
321
382
|
}
|
|
322
383
|
/**
|
|
323
384
|
* Performs a reverse-mode autodiff backward pass from this Value.
|
|
@@ -388,15 +449,39 @@ class Value {
|
|
|
388
449
|
* @param right Right operand Value or null
|
|
389
450
|
* @param backwardFnBuilder Function to create backward closure
|
|
390
451
|
* @param label Node label for debugging
|
|
452
|
+
* @param op Operation name for JIT compilation
|
|
391
453
|
* @returns New Value node
|
|
392
454
|
*/
|
|
393
|
-
static make(data, left, right, backwardFnBuilder, label) {
|
|
455
|
+
static make(data, left, right, backwardFnBuilder, label, op) {
|
|
394
456
|
const requiresGrad = !Value.no_grad_mode && [left, right].filter(Boolean).some(v => v.requiresGrad);
|
|
395
457
|
const out = new Value(data, label, requiresGrad);
|
|
396
458
|
out.prev = Value.no_grad_mode ? [] : [left, right].filter(Boolean);
|
|
459
|
+
out._op = op;
|
|
397
460
|
if (requiresGrad) {
|
|
398
461
|
out.backwardFn = backwardFnBuilder(out);
|
|
399
462
|
}
|
|
463
|
+
if (Value.currentBuilder) {
|
|
464
|
+
Value.currentBuilder.recordOp(out);
|
|
465
|
+
}
|
|
466
|
+
return out;
|
|
467
|
+
}
|
|
468
|
+
/**
|
|
469
|
+
* N-ary operation helper for operations with multiple inputs
|
|
470
|
+
*
|
|
471
|
+
* TODO: Move code generation logic into makeNary instead of centralized switches.
|
|
472
|
+
* This would co-locate runtime and codegen logic at operation definition sites.
|
|
473
|
+
*/
|
|
474
|
+
static makeNary(data, inputs, backwardFnBuilder, label, op) {
|
|
475
|
+
const requiresGrad = !Value.no_grad_mode && inputs.some(v => v.requiresGrad);
|
|
476
|
+
const out = new Value(data, label, requiresGrad);
|
|
477
|
+
out.prev = Value.no_grad_mode ? [] : inputs;
|
|
478
|
+
out._op = op;
|
|
479
|
+
if (requiresGrad) {
|
|
480
|
+
out.backwardFn = backwardFnBuilder(out);
|
|
481
|
+
}
|
|
482
|
+
if (Value.currentBuilder) {
|
|
483
|
+
Value.currentBuilder.recordOp(out);
|
|
484
|
+
}
|
|
400
485
|
return out;
|
|
401
486
|
}
|
|
402
487
|
/**
|
|
@@ -420,5 +505,158 @@ class Value {
|
|
|
420
505
|
Value.no_grad_mode = prev;
|
|
421
506
|
}
|
|
422
507
|
}
|
|
508
|
+
// TODO: Move code generation into make/makeNary to co-locate with runtime logic
|
|
509
|
+
getForwardCode(childCodes) {
|
|
510
|
+
if (this.paramName)
|
|
511
|
+
return this.paramName;
|
|
512
|
+
if (this.prev.length === 1) {
|
|
513
|
+
const [child] = childCodes;
|
|
514
|
+
switch (this._op) {
|
|
515
|
+
case 'exp': return `Math.exp(${child})`;
|
|
516
|
+
case 'log': return `Math.log(${child})`;
|
|
517
|
+
case 'sqrt': return `Math.sqrt(${child})`;
|
|
518
|
+
case 'tanh': return `Math.tanh(${child})`;
|
|
519
|
+
case 'sigmoid': return `(1 / (1 + Math.exp(-${child})))`;
|
|
520
|
+
case 'relu': return `Math.max(0, ${child})`;
|
|
521
|
+
case 'sin': return `Math.sin(${child})`;
|
|
522
|
+
case 'cos': return `Math.cos(${child})`;
|
|
523
|
+
case 'tan': return `Math.tan(${child})`;
|
|
524
|
+
case 'asin': return `Math.asin(${child})`;
|
|
525
|
+
case 'acos': return `Math.acos(${child})`;
|
|
526
|
+
case 'atan': return `Math.atan(${child})`;
|
|
527
|
+
case 'neg': return `(-${child})`;
|
|
528
|
+
case 'abs': return `Math.abs(${child})`;
|
|
529
|
+
case 'square': return `(${child} * ${child})`;
|
|
530
|
+
case 'cube': return `(${child} * ${child} * ${child})`;
|
|
531
|
+
case 'reciprocal': return `(1 / ${child})`;
|
|
532
|
+
case 'sign': return `Math.sign(${child})`;
|
|
533
|
+
case 'softplus': return `Math.log(1 + Math.exp(${child}))`;
|
|
534
|
+
case 'floor': return `Math.floor(${child})`;
|
|
535
|
+
case 'ceil': return `Math.ceil(${child})`;
|
|
536
|
+
case 'round': return `Math.round(${child})`;
|
|
537
|
+
case 'clamp': {
|
|
538
|
+
const [min, max] = this._opConstants || [0, 1];
|
|
539
|
+
return `Math.max(${min}, Math.min(${child}, ${max}))`;
|
|
540
|
+
}
|
|
541
|
+
default: return String(this.data);
|
|
542
|
+
}
|
|
543
|
+
}
|
|
544
|
+
// N-ary operations (6 inputs)
|
|
545
|
+
if (this.prev.length === 6 && this._op === 'eigenvalue_custom') {
|
|
546
|
+
const [c00, c01, c02, c11, c12, c22] = childCodes;
|
|
547
|
+
return `window.__eigenHelpers.computeSmallestEigenvalue(${c00}, ${c01}, ${c02}, ${c11}, ${c12}, ${c22})`;
|
|
548
|
+
}
|
|
549
|
+
const [left, right] = childCodes;
|
|
550
|
+
switch (this._op) {
|
|
551
|
+
case '+': return `(${left} + ${right})`;
|
|
552
|
+
case '-': return `(${left} - ${right})`;
|
|
553
|
+
case '*': return `(${left} * ${right})`;
|
|
554
|
+
case '/': return `(${left} / ${right})`;
|
|
555
|
+
case 'powValue': return `Math.pow(${left}, ${right})`;
|
|
556
|
+
case 'mod': return `(${left} % ${right})`;
|
|
557
|
+
case 'min': return `Math.min(${left}, ${right})`;
|
|
558
|
+
case 'max': return `Math.max(${left}, ${right})`;
|
|
559
|
+
default: return String(this.data);
|
|
560
|
+
}
|
|
561
|
+
}
|
|
562
|
+
// TODO: Move code generation into make/makeNary to co-locate with runtime logic
|
|
563
|
+
getBackwardCode(gradVar, childGrads, childVars) {
|
|
564
|
+
if (this.prev.length === 1) {
|
|
565
|
+
const [childGrad] = childGrads;
|
|
566
|
+
const [child] = childVars;
|
|
567
|
+
switch (this._op) {
|
|
568
|
+
case 'exp':
|
|
569
|
+
return `${childGrad} += ${gradVar} * Math.exp(${child});`;
|
|
570
|
+
case 'log':
|
|
571
|
+
return `${childGrad} += ${gradVar} / ${child};`;
|
|
572
|
+
case 'sqrt':
|
|
573
|
+
return `${childGrad} += ${gradVar} * 0.5 / Math.sqrt(${child});`;
|
|
574
|
+
case 'tanh': {
|
|
575
|
+
const tanhChild = `Math.tanh(${child})`;
|
|
576
|
+
return `${childGrad} += ${gradVar} * (1 - ${tanhChild} * ${tanhChild});`;
|
|
577
|
+
}
|
|
578
|
+
case 'sigmoid': {
|
|
579
|
+
const sigChild = `(1 / (1 + Math.exp(-${child})))`;
|
|
580
|
+
return `${childGrad} += ${gradVar} * ${sigChild} * (1 - ${sigChild});`;
|
|
581
|
+
}
|
|
582
|
+
case 'relu':
|
|
583
|
+
return `${childGrad} += ${gradVar} * (${child} > 0 ? 1 : 0);`;
|
|
584
|
+
case 'sin':
|
|
585
|
+
return `${childGrad} += ${gradVar} * Math.cos(${child});`;
|
|
586
|
+
case 'cos':
|
|
587
|
+
return `${childGrad} += ${gradVar} * (-Math.sin(${child}));`;
|
|
588
|
+
case 'tan': {
|
|
589
|
+
const cosChild = `Math.cos(${child})`;
|
|
590
|
+
return `${childGrad} += ${gradVar} / (${cosChild} * ${cosChild});`;
|
|
591
|
+
}
|
|
592
|
+
case 'asin':
|
|
593
|
+
return `${childGrad} += ${gradVar} / Math.sqrt(1 - ${child} * ${child});`;
|
|
594
|
+
case 'acos':
|
|
595
|
+
return `${childGrad} += ${gradVar} / (-Math.sqrt(1 - ${child} * ${child}));`;
|
|
596
|
+
case 'atan':
|
|
597
|
+
return `${childGrad} += ${gradVar} / (1 + ${child} * ${child});`;
|
|
598
|
+
case 'neg':
|
|
599
|
+
return `${childGrad} -= ${gradVar};`;
|
|
600
|
+
case 'abs':
|
|
601
|
+
return `${childGrad} += ${gradVar} * (${child} >= 0 ? 1 : -1);`;
|
|
602
|
+
case 'square':
|
|
603
|
+
return `${childGrad} += ${gradVar} * 2 * ${child};`;
|
|
604
|
+
case 'cube':
|
|
605
|
+
return `${childGrad} += ${gradVar} * 3 * ${child} * ${child};`;
|
|
606
|
+
case 'reciprocal':
|
|
607
|
+
return `${childGrad} -= ${gradVar} / (${child} * ${child});`;
|
|
608
|
+
case 'sign':
|
|
609
|
+
return `${childGrad} += 0;`;
|
|
610
|
+
case 'softplus': {
|
|
611
|
+
const expChild = `Math.exp(${child})`;
|
|
612
|
+
return `${childGrad} += ${gradVar} * ${expChild} / (1 + ${expChild});`;
|
|
613
|
+
}
|
|
614
|
+
case 'floor':
|
|
615
|
+
case 'ceil':
|
|
616
|
+
case 'round':
|
|
617
|
+
return `${childGrad} += 0;`;
|
|
618
|
+
case 'clamp': {
|
|
619
|
+
const [min, max] = this._opConstants || [0, 1];
|
|
620
|
+
return `${childGrad} += ${gradVar} * (${child} > ${min} && ${child} < ${max} ? 1 : 0);`;
|
|
621
|
+
}
|
|
622
|
+
default:
|
|
623
|
+
return '';
|
|
624
|
+
}
|
|
625
|
+
}
|
|
626
|
+
// N-ary operations (6 inputs)
|
|
627
|
+
if (this.prev.length === 6 && this._op === 'eigenvalue_custom') {
|
|
628
|
+
const [g00, g01, g02, g11, g12, g22] = childGrads;
|
|
629
|
+
const [v00, v01, v02, v11, v12, v22] = childVars;
|
|
630
|
+
return `{
|
|
631
|
+
const grads = window.__eigenHelpers.applyEigenvalueGradients(
|
|
632
|
+
${gradVar}, ${v00}, ${v01}, ${v02}, ${v11}, ${v12}, ${v22},
|
|
633
|
+
${g00}, ${g01}, ${g02}, ${g11}, ${g12}, ${g22}
|
|
634
|
+
);
|
|
635
|
+
${g00} = grads[0]; ${g01} = grads[1]; ${g02} = grads[2];
|
|
636
|
+
${g11} = grads[3]; ${g12} = grads[4]; ${g22} = grads[5];
|
|
637
|
+
}`;
|
|
638
|
+
}
|
|
639
|
+
const [leftGrad, rightGrad] = childGrads;
|
|
640
|
+
const [left, right] = childVars;
|
|
641
|
+
switch (this._op) {
|
|
642
|
+
case '+':
|
|
643
|
+
return `${leftGrad} += ${gradVar}; ${rightGrad} += ${gradVar};`;
|
|
644
|
+
case '-':
|
|
645
|
+
return `${leftGrad} += ${gradVar}; ${rightGrad} -= ${gradVar};`;
|
|
646
|
+
case '*':
|
|
647
|
+
return `${leftGrad} += ${gradVar} * ${right}; ${rightGrad} += ${gradVar} * ${left};`;
|
|
648
|
+
case '/':
|
|
649
|
+
return `${leftGrad} += ${gradVar} / ${right}; ${rightGrad} -= ${gradVar} * ${left} / (${right} * ${right});`;
|
|
650
|
+
case 'powValue':
|
|
651
|
+
return `${leftGrad} += ${gradVar} * ${right} * Math.pow(${left}, ${right} - 1); ${rightGrad} += ${gradVar} * Math.pow(${left}, ${right}) * Math.log(${left});`;
|
|
652
|
+
case 'mod':
|
|
653
|
+
return `${leftGrad} += ${gradVar}; ${rightGrad} += 0;`;
|
|
654
|
+
case 'min':
|
|
655
|
+
return `${leftGrad} += ${gradVar} * (${left} < ${right} ? 1 : 0); ${rightGrad} += ${gradVar} * (${right} < ${left} ? 1 : 0);`;
|
|
656
|
+
case 'max':
|
|
657
|
+
return `${leftGrad} += ${gradVar} * (${left} > ${right} ? 1 : 0); ${rightGrad} += ${gradVar} * (${right} > ${left} ? 1 : 0);`;
|
|
658
|
+
default:
|
|
659
|
+
return '';
|
|
660
|
+
}
|
|
661
|
+
}
|
|
423
662
|
}
|
|
424
|
-
exports.Value = Value;
|
package/dist/ValueActivation.js
CHANGED
|
@@ -1,34 +1,30 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
exports.ValueActivation = void 0;
|
|
4
|
-
const Value_1 = require("./Value");
|
|
5
|
-
class ValueActivation {
|
|
1
|
+
import { Value } from './Value';
|
|
2
|
+
export class ValueActivation {
|
|
6
3
|
static relu(x) {
|
|
7
4
|
const r = Math.max(0, x.data);
|
|
8
|
-
return
|
|
5
|
+
return Value.make(r, x, null, (out) => () => {
|
|
9
6
|
if (x.requiresGrad)
|
|
10
7
|
x.grad += (x.data > 0 ? 1 : 0) * out.grad;
|
|
11
|
-
}, `relu(${x.label})
|
|
8
|
+
}, `relu(${x.label})`, 'relu');
|
|
12
9
|
}
|
|
13
10
|
static softplus(x) {
|
|
14
11
|
const s = Math.log(1 + Math.exp(x.data));
|
|
15
|
-
return
|
|
12
|
+
return Value.make(s, x, null, (out) => () => {
|
|
16
13
|
x.grad += 1 / (1 + Math.exp(-x.data)) * out.grad;
|
|
17
|
-
}, `softplus(${x.label})
|
|
14
|
+
}, `softplus(${x.label})`, 'softplus');
|
|
18
15
|
}
|
|
19
16
|
static tanh(x) {
|
|
20
17
|
const t = Math.tanh(x.data);
|
|
21
|
-
return
|
|
18
|
+
return Value.make(t, x, null, (out) => () => {
|
|
22
19
|
if (x.requiresGrad)
|
|
23
20
|
x.grad += (1 - t ** 2) * out.grad;
|
|
24
|
-
}, `tanh(${x.label})
|
|
21
|
+
}, `tanh(${x.label})`, 'tanh');
|
|
25
22
|
}
|
|
26
23
|
static sigmoid(x) {
|
|
27
24
|
const s = 1 / (1 + Math.exp(-x.data));
|
|
28
|
-
return
|
|
25
|
+
return Value.make(s, x, null, (out) => () => {
|
|
29
26
|
if (x.requiresGrad)
|
|
30
27
|
x.grad += s * (1 - s) * out.grad;
|
|
31
|
-
}, `sigmoid(${x.label})
|
|
28
|
+
}, `sigmoid(${x.label})`, 'sigmoid');
|
|
32
29
|
}
|
|
33
30
|
}
|
|
34
|
-
exports.ValueActivation = ValueActivation;
|