scalar-autograd 0.1.7 → 0.1.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. package/README.md +127 -2
  2. package/dist/CompiledFunctions.d.ts +111 -0
  3. package/dist/CompiledFunctions.js +268 -0
  4. package/dist/CompiledResiduals.d.ts +74 -0
  5. package/dist/CompiledResiduals.js +94 -0
  6. package/dist/EigenvalueHelpers.d.ts +14 -0
  7. package/dist/EigenvalueHelpers.js +93 -0
  8. package/dist/Geometry.d.ts +131 -0
  9. package/dist/Geometry.js +213 -0
  10. package/dist/GraphBuilder.d.ts +64 -0
  11. package/dist/GraphBuilder.js +237 -0
  12. package/dist/GraphCanonicalizerNoSort.d.ts +20 -0
  13. package/dist/GraphCanonicalizerNoSort.js +190 -0
  14. package/dist/GraphHashCanonicalizer.d.ts +46 -0
  15. package/dist/GraphHashCanonicalizer.js +220 -0
  16. package/dist/GraphSignature.d.ts +7 -0
  17. package/dist/GraphSignature.js +7 -0
  18. package/dist/KernelPool.d.ts +55 -0
  19. package/dist/KernelPool.js +124 -0
  20. package/dist/LBFGS.d.ts +84 -0
  21. package/dist/LBFGS.js +313 -0
  22. package/dist/LinearSolver.d.ts +69 -0
  23. package/dist/LinearSolver.js +213 -0
  24. package/dist/Losses.d.ts +9 -0
  25. package/dist/Losses.js +42 -37
  26. package/dist/Matrix3x3.d.ts +50 -0
  27. package/dist/Matrix3x3.js +146 -0
  28. package/dist/NonlinearLeastSquares.d.ts +33 -0
  29. package/dist/NonlinearLeastSquares.js +252 -0
  30. package/dist/Optimizers.d.ts +70 -14
  31. package/dist/Optimizers.js +42 -19
  32. package/dist/V.d.ts +0 -0
  33. package/dist/V.js +0 -0
  34. package/dist/Value.d.ts +84 -2
  35. package/dist/Value.js +296 -58
  36. package/dist/ValueActivation.js +10 -14
  37. package/dist/ValueArithmetic.d.ts +1 -0
  38. package/dist/ValueArithmetic.js +58 -50
  39. package/dist/ValueComparison.js +9 -13
  40. package/dist/ValueRegistry.d.ts +38 -0
  41. package/dist/ValueRegistry.js +88 -0
  42. package/dist/ValueTrig.js +14 -18
  43. package/dist/Vec2.d.ts +45 -0
  44. package/dist/Vec2.js +93 -0
  45. package/dist/Vec3.d.ts +78 -0
  46. package/dist/Vec3.js +169 -0
  47. package/dist/Vec4.d.ts +45 -0
  48. package/dist/Vec4.js +126 -0
  49. package/dist/__tests__/duplicate-inputs.test.js +33 -0
  50. package/dist/cli/gradient-gen.d.ts +19 -0
  51. package/dist/cli/gradient-gen.js +264 -0
  52. package/dist/compileIndirectKernel.d.ts +24 -0
  53. package/dist/compileIndirectKernel.js +148 -0
  54. package/dist/index.d.ts +20 -0
  55. package/dist/index.js +20 -0
  56. package/dist/scalar-autograd.d.ts +1157 -0
  57. package/dist/symbolic/AST.d.ts +113 -0
  58. package/dist/symbolic/AST.js +128 -0
  59. package/dist/symbolic/CodeGen.d.ts +35 -0
  60. package/dist/symbolic/CodeGen.js +280 -0
  61. package/dist/symbolic/Parser.d.ts +64 -0
  62. package/dist/symbolic/Parser.js +329 -0
  63. package/dist/symbolic/Simplify.d.ts +10 -0
  64. package/dist/symbolic/Simplify.js +244 -0
  65. package/dist/symbolic/SymbolicDiff.d.ts +35 -0
  66. package/dist/symbolic/SymbolicDiff.js +339 -0
  67. package/dist/tsdoc-metadata.json +11 -0
  68. package/package.json +29 -5
  69. package/dist/Losses.spec.js +0 -54
  70. package/dist/Optimizers.edge-cases.spec.d.ts +0 -1
  71. package/dist/Optimizers.edge-cases.spec.js +0 -29
  72. package/dist/Optimizers.spec.d.ts +0 -1
  73. package/dist/Optimizers.spec.js +0 -56
  74. package/dist/Value.edge-cases.spec.d.ts +0 -1
  75. package/dist/Value.edge-cases.spec.js +0 -54
  76. package/dist/Value.grad-flow.spec.d.ts +0 -1
  77. package/dist/Value.grad-flow.spec.js +0 -24
  78. package/dist/Value.losses-edge-cases.spec.d.ts +0 -1
  79. package/dist/Value.losses-edge-cases.spec.js +0 -30
  80. package/dist/Value.memory.spec.d.ts +0 -1
  81. package/dist/Value.memory.spec.js +0 -23
  82. package/dist/Value.nn.spec.d.ts +0 -1
  83. package/dist/Value.nn.spec.js +0 -111
  84. package/dist/Value.spec.d.ts +0 -1
  85. package/dist/Value.spec.js +0 -245
  86. /package/dist/{Losses.spec.d.ts → __tests__/duplicate-inputs.test.d.ts} +0 -0
package/dist/Value.js CHANGED
@@ -1,32 +1,86 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Value = exports.Losses = exports.AdamW = exports.Adam = exports.SGD = exports.Optimizer = exports.V = void 0;
4
- var V_1 = require("./V");
5
- Object.defineProperty(exports, "V", { enumerable: true, get: function () { return V_1.V; } });
6
- var Optimizers_1 = require("./Optimizers");
7
- Object.defineProperty(exports, "Optimizer", { enumerable: true, get: function () { return Optimizers_1.Optimizer; } });
8
- Object.defineProperty(exports, "SGD", { enumerable: true, get: function () { return Optimizers_1.SGD; } });
9
- Object.defineProperty(exports, "Adam", { enumerable: true, get: function () { return Optimizers_1.Adam; } });
10
- Object.defineProperty(exports, "AdamW", { enumerable: true, get: function () { return Optimizers_1.AdamW; } });
11
- var Losses_1 = require("./Losses");
12
- Object.defineProperty(exports, "Losses", { enumerable: true, get: function () { return Losses_1.Losses; } });
1
+ export { V } from './V';
2
+ export { Optimizer, SGD, Adam, AdamW } from './Optimizers';
3
+ export { Losses } from './Losses';
4
+ export { Vec2 } from './Vec2';
5
+ export { Vec3 } from './Vec3';
13
6
  const EPS = 1e-12;
14
- const ValueTrig_1 = require("./ValueTrig");
15
- const ValueActivation_1 = require("./ValueActivation");
16
- const ValueArithmetic_1 = require("./ValueArithmetic");
17
- const ValueComparison_1 = require("./ValueComparison");
18
- class Value {
7
+ import { ValueActivation } from './ValueActivation';
8
+ import { ValueArithmetic } from './ValueArithmetic';
9
+ import { ValueComparison } from './ValueComparison';
10
+ import { ValueTrig } from './ValueTrig';
11
+ /**
12
+ * Represents a scalar value in the computational graph for automatic differentiation.
13
+ * Supports forward computation and reverse-mode autodiff (backpropagation).
14
+ * @public
15
+ */
16
+ export class Value {
17
+ /**
18
+ * Global flag to disable gradient tracking. Use Value.withNoGrad() instead of setting directly.
19
+ * @public
20
+ */
19
21
  static no_grad_mode = false;
22
+ /**
23
+ * Current graph builder for incremental hash tracking (null when not building).
24
+ * @internal
25
+ */
26
+ static currentBuilder = null;
27
+ /**
28
+ * Global counter for unique Value IDs.
29
+ * @internal
30
+ */
31
+ static nextId = 0;
32
+ /**
33
+ * Unique ID for this Value instance (for hashing intermediate nodes).
34
+ * @internal
35
+ */
36
+ _id;
37
+ /**
38
+ * The numeric value stored in this node.
39
+ * @public
40
+ */
20
41
  data;
42
+ /**
43
+ * The gradient of the output with respect to this value.
44
+ * @public
45
+ */
21
46
  grad = 0;
47
+ /**
48
+ * Whether this value participates in gradient computation.
49
+ * @public
50
+ */
22
51
  requiresGrad;
23
52
  backwardFn = () => { };
24
- prev = [];
53
+ /** @internal */ prev = [];
54
+ /**
55
+ * Optional label for debugging and visualization.
56
+ * @public
57
+ */
25
58
  label;
59
+ /**
60
+ * Operation type for JIT compilation (e.g., '+', 'exp', 'sin').
61
+ * @internal
62
+ */
63
+ _op;
64
+ /**
65
+ * Parameter name for JIT compilation inputs.
66
+ * @internal
67
+ */
68
+ paramName;
69
+ /**
70
+ * Registry ID for kernel reuse system.
71
+ * @internal
72
+ */
73
+ _registryId;
74
+ /**
75
+ * Operation constants (e.g., min/max for clamp, exponent for pow).
76
+ * @internal
77
+ */
78
+ _opConstants;
26
79
  constructor(data, label = "", requiresGrad = false) {
27
80
  if (typeof data !== 'number' || Number.isNaN(data) || !Number.isFinite(data)) {
28
81
  throw new Error(`Invalid number passed to Value: ${data}`);
29
82
  }
83
+ this._id = Value.nextId++;
30
84
  this.data = data;
31
85
  this.label = label;
32
86
  this.requiresGrad = requiresGrad;
@@ -39,70 +93,70 @@ class Value {
39
93
  * @returns New Value with sin.
40
94
  */
41
95
  sin() {
42
- return ValueTrig_1.ValueTrig.sin(this);
96
+ return ValueTrig.sin(this);
43
97
  }
44
98
  /**
45
99
  * Returns cos(this).
46
100
  * @returns New Value with cos.
47
101
  */
48
102
  cos() {
49
- return ValueTrig_1.ValueTrig.cos(this);
103
+ return ValueTrig.cos(this);
50
104
  }
51
105
  /**
52
106
  * Returns tan(this).
53
107
  * @returns New Value with tan.
54
108
  */
55
109
  tan() {
56
- return ValueTrig_1.ValueTrig.tan(this);
110
+ return ValueTrig.tan(this);
57
111
  }
58
112
  /**
59
113
  * Returns asin(this).
60
114
  * @returns New Value with asin.
61
115
  */
62
116
  asin() {
63
- return ValueTrig_1.ValueTrig.asin(this);
117
+ return ValueTrig.asin(this);
64
118
  }
65
119
  /**
66
120
  * Returns acos(this).
67
121
  * @returns New Value with acos.
68
122
  */
69
123
  acos() {
70
- return ValueTrig_1.ValueTrig.acos(this);
124
+ return ValueTrig.acos(this);
71
125
  }
72
126
  /**
73
127
  * Returns atan(this).
74
128
  * @returns New Value with atan.
75
129
  */
76
130
  atan() {
77
- return ValueTrig_1.ValueTrig.atan(this);
131
+ return ValueTrig.atan(this);
78
132
  }
79
133
  /**
80
134
  * Returns relu(this).
81
135
  * @returns New Value with relu.
82
136
  */
83
137
  relu() {
84
- return ValueActivation_1.ValueActivation.relu(this);
138
+ return ValueActivation.relu(this);
85
139
  }
86
140
  /**
87
141
  * Returns abs(this).
88
142
  * @returns New Value with abs.
89
143
  */
90
144
  abs() {
91
- return ValueArithmetic_1.ValueArithmetic.abs(this);
145
+ return ValueArithmetic.abs(this);
92
146
  }
93
147
  /**
94
148
  * Returns exp(this).
95
149
  * @returns New Value with exp.
96
150
  */
97
151
  exp() {
98
- return ValueArithmetic_1.ValueArithmetic.exp(this);
152
+ return ValueArithmetic.exp(this);
99
153
  }
100
154
  /**
101
155
  * Returns log(this).
102
156
  * @returns New Value with log.
103
157
  */
104
158
  log() {
105
- return ValueArithmetic_1.ValueArithmetic.log(this, EPS);
159
+ return ValueArithmetic.log(this, EPS);
106
160
  }
107
161
  /**
108
162
  * Returns min(this, other).
@@ -110,7 +164,7 @@ class Value {
110
164
  * @returns New Value with min.
111
165
  */
112
166
  min(other) {
113
- return ValueArithmetic_1.ValueArithmetic.min(this, other);
167
+ return ValueArithmetic.min(this, other);
114
168
  }
115
169
  /**
116
170
  * Returns max(this, other).
@@ -118,7 +172,7 @@ class Value {
118
172
  * @returns New Value with max.
119
173
  */
120
174
  max(other) {
121
- return ValueArithmetic_1.ValueArithmetic.max(this, other);
175
+ return ValueArithmetic.max(this, other);
122
176
  }
123
177
  /**
124
178
  * Adds this and other.
@@ -126,7 +180,7 @@ class Value {
126
180
  * @returns New Value with sum.
127
181
  */
128
182
  add(other) {
129
- return ValueArithmetic_1.ValueArithmetic.add(this, Value.ensureValue(other));
183
+ return ValueArithmetic.add(this, Value.ensureValue(other));
130
184
  }
131
185
  /**
132
186
  * Multiplies this and other.
@@ -134,7 +188,7 @@ class Value {
134
188
  * @returns New Value with product.
135
189
  */
136
190
  mul(other) {
137
- return ValueArithmetic_1.ValueArithmetic.mul(this, Value.ensureValue(other));
191
+ return ValueArithmetic.mul(this, Value.ensureValue(other));
138
192
  }
139
193
  /**
140
194
  * Subtracts other from this.
@@ -142,7 +196,7 @@ class Value {
142
196
  * @returns New Value with difference.
143
197
  */
144
198
  sub(other) {
145
- return ValueArithmetic_1.ValueArithmetic.sub(this, Value.ensureValue(other));
199
+ return ValueArithmetic.sub(this, Value.ensureValue(other));
146
200
  }
147
201
  /**
148
202
  * Divides this by other.
@@ -150,7 +204,7 @@ class Value {
150
204
  * @returns New Value with quotient.
151
205
  */
152
206
  div(other) {
153
- return ValueArithmetic_1.ValueArithmetic.div(this, Value.ensureValue(other), EPS);
207
+ return ValueArithmetic.div(this, Value.ensureValue(other), EPS);
154
208
  }
155
209
  /**
156
210
  * Raises this to the power exp.
@@ -158,7 +212,7 @@ class Value {
158
212
  * @returns New Value with pow(this, exp)
159
213
  */
160
214
  pow(exp) {
161
- return ValueArithmetic_1.ValueArithmetic.pow(this, exp);
215
+ return ValueArithmetic.pow(this, exp);
162
216
  }
163
217
  /**
164
218
  * Raises this to a dynamic Value (other).
@@ -166,7 +220,7 @@ class Value {
166
220
  * @returns New Value with pow(this, other)
167
221
  */
168
222
  powValue(other) {
169
- return ValueArithmetic_1.ValueArithmetic.powValue(this, Value.ensureValue(other), EPS);
223
+ return ValueArithmetic.powValue(this, Value.ensureValue(other), EPS);
170
224
  }
171
225
  /**
172
226
  * Returns this modulo other.
@@ -174,7 +228,7 @@ class Value {
174
228
  * @returns New Value with modulo.
175
229
  */
176
230
  mod(other) {
177
- return ValueArithmetic_1.ValueArithmetic.mod(this, other);
231
+ return ValueArithmetic.mod(this, other);
178
232
  }
179
233
  /**
180
234
  * Returns Value indicating if this equals other.
@@ -182,7 +236,7 @@ class Value {
182
236
  * @returns New Value (1 if equal, else 0)
183
237
  */
184
238
  eq(other) {
185
- return ValueComparison_1.ValueComparison.eq(this, other);
239
+ return ValueComparison.eq(this, other);
186
240
  }
187
241
  /**
188
242
  * Returns Value indicating if this not equals other.
@@ -190,7 +244,7 @@ class Value {
190
244
  * @returns New Value (1 if not equal, else 0)
191
245
  */
192
246
  neq(other) {
193
- return ValueComparison_1.ValueComparison.neq(this, other);
247
+ return ValueComparison.neq(this, other);
194
248
  }
195
249
  /**
196
250
  * Returns Value indicating if this greater than other.
@@ -198,7 +252,7 @@ class Value {
198
252
  * @returns New Value (1 if true, else 0)
199
253
  */
200
254
  gt(other) {
201
- return ValueComparison_1.ValueComparison.gt(this, other);
255
+ return ValueComparison.gt(this, other);
202
256
  }
203
257
  /**
204
258
  * Returns Value indicating if this less than other.
@@ -206,7 +260,7 @@ class Value {
206
260
  * @returns New Value (1 if true, else 0)
207
261
  */
208
262
  lt(other) {
209
- return ValueComparison_1.ValueComparison.lt(this, other);
263
+ return ValueComparison.lt(this, other);
210
264
  }
211
265
  /**
212
266
  * Returns Value indicating if this greater than or equal to other.
@@ -214,7 +268,7 @@ class Value {
214
268
  * @returns New Value (1 if true, else 0)
215
269
  */
216
270
  gte(other) {
217
- return ValueComparison_1.ValueComparison.gte(this, other);
271
+ return ValueComparison.gte(this, other);
218
272
  }
219
273
  /**
220
274
  * Returns Value indicating if this less than or equal to other.
@@ -222,56 +276,56 @@ class Value {
222
276
  * @returns New Value (1 if true, else 0)
223
277
  */
224
278
  lte(other) {
225
- return ValueComparison_1.ValueComparison.lte(this, other);
279
+ return ValueComparison.lte(this, other);
226
280
  }
227
281
  /**
228
282
  * Returns softplus(this).
229
283
  * @returns New Value with softplus.
230
284
  */
231
285
  softplus() {
232
- return ValueActivation_1.ValueActivation.softplus(this);
286
+ return ValueActivation.softplus(this);
233
287
  }
234
288
  /**
235
289
  * Returns the floor of this Value.
236
290
  * @returns New Value with floor(data).
237
291
  */
238
292
  floor() {
239
- return ValueArithmetic_1.ValueArithmetic.floor(this);
293
+ return ValueArithmetic.floor(this);
240
294
  }
241
295
  /**
242
296
  * Returns the ceiling of this Value.
243
297
  * @returns New Value with ceil(data).
244
298
  */
245
299
  ceil() {
246
- return ValueArithmetic_1.ValueArithmetic.ceil(this);
300
+ return ValueArithmetic.ceil(this);
247
301
  }
248
302
  /**
249
303
  * Returns the rounded value of this Value.
250
304
  * @returns New Value with rounded data.
251
305
  */
252
306
  round() {
253
- return ValueArithmetic_1.ValueArithmetic.round(this);
307
+ return ValueArithmetic.round(this);
254
308
  }
255
309
  /**
256
310
  * Returns the square of this Value.
257
311
  * @returns New Value with squared data.
258
312
  */
259
313
  square() {
260
- return ValueArithmetic_1.ValueArithmetic.square(this);
314
+ return ValueArithmetic.square(this);
261
315
  }
262
316
  /**
263
317
  * Returns the cube of this Value.
264
318
  * @returns New Value with cubed data.
265
319
  */
266
320
  cube() {
267
- return ValueArithmetic_1.ValueArithmetic.cube(this);
321
+ return ValueArithmetic.cube(this);
268
322
  }
269
323
  /**
270
324
  * Returns the reciprocal (1/x) of this Value.
271
325
  * @returns New Value with reciprocal.
272
326
  */
273
327
  reciprocal() {
274
- return ValueArithmetic_1.ValueArithmetic.reciprocal(this, EPS);
328
+ return ValueArithmetic.reciprocal(this, EPS);
275
329
  }
276
330
  /**
277
331
  * Clamps this between min and max.
@@ -280,14 +334,21 @@ class Value {
280
334
  * @returns New clamped Value
281
335
  */
282
336
  clamp(min, max) {
283
- return ValueArithmetic_1.ValueArithmetic.clamp(this, min, max);
337
+ return ValueArithmetic.clamp(this, min, max);
284
338
  }
285
339
  /**
286
340
  * Returns the negation (-this) Value.
287
341
  * @returns New Value which is the negation.
288
342
  */
289
343
  neg() {
290
- return ValueArithmetic_1.ValueArithmetic.neg(this);
344
+ return ValueArithmetic.neg(this);
345
+ }
346
+ /**
347
+ * Returns sign(this).
348
+ * @returns New Value with sign.
349
+ */
350
+ sign() {
351
+ return ValueArithmetic.sign(this);
291
352
  }
292
353
  /**
293
354
  * Returns the sum of the given Values.
@@ -295,7 +356,7 @@ class Value {
295
356
  * @returns New Value holding their sum.
296
357
  */
297
358
  static sum(vals) {
298
- return ValueArithmetic_1.ValueArithmetic.sum(vals);
359
+ return ValueArithmetic.sum(vals);
299
360
  }
300
361
  /**
301
362
  * Returns the mean of the given Values.
@@ -303,21 +364,21 @@ class Value {
303
364
  * @returns New Value holding their mean.
304
365
  */
305
366
  static mean(vals) {
306
- return ValueArithmetic_1.ValueArithmetic.mean(vals);
367
+ return ValueArithmetic.mean(vals);
307
368
  }
308
369
  /**
309
370
  * Returns tanh(this).
310
371
  * @returns New Value with tanh.
311
372
  */
312
373
  tanh() {
313
- return ValueActivation_1.ValueActivation.tanh(this);
374
+ return ValueActivation.tanh(this);
314
375
  }
315
376
  /**
316
377
  * Returns sigmoid(this).
317
378
  * @returns New Value with sigmoid.
318
379
  */
319
380
  sigmoid() {
320
- return ValueActivation_1.ValueActivation.sigmoid(this);
381
+ return ValueActivation.sigmoid(this);
321
382
  }
322
383
  /**
323
384
  * Performs a reverse-mode autodiff backward pass from this Value.
@@ -388,15 +449,39 @@ class Value {
388
449
  * @param right Right operand Value or null
389
450
  * @param backwardFnBuilder Function to create backward closure
390
451
  * @param label Node label for debugging
452
+ * @param op Operation name for JIT compilation
391
453
  * @returns New Value node
392
454
  */
393
- static make(data, left, right, backwardFnBuilder, label) {
455
+ static make(data, left, right, backwardFnBuilder, label, op) {
394
456
  const requiresGrad = !Value.no_grad_mode && [left, right].filter(Boolean).some(v => v.requiresGrad);
395
457
  const out = new Value(data, label, requiresGrad);
396
458
  out.prev = Value.no_grad_mode ? [] : [left, right].filter(Boolean);
459
+ out._op = op;
397
460
  if (requiresGrad) {
398
461
  out.backwardFn = backwardFnBuilder(out);
399
462
  }
463
+ if (Value.currentBuilder) {
464
+ Value.currentBuilder.recordOp(out);
465
+ }
466
+ return out;
467
+ }
468
+ /**
469
+ * N-ary operation helper for operations with multiple inputs
470
+ *
471
+ * TODO: Move code generation logic into makeNary instead of centralized switches.
472
+ * This would co-locate runtime and codegen logic at operation definition sites.
473
+ */
474
+ static makeNary(data, inputs, backwardFnBuilder, label, op) {
475
+ const requiresGrad = !Value.no_grad_mode && inputs.some(v => v.requiresGrad);
476
+ const out = new Value(data, label, requiresGrad);
477
+ out.prev = Value.no_grad_mode ? [] : inputs;
478
+ out._op = op;
479
+ if (requiresGrad) {
480
+ out.backwardFn = backwardFnBuilder(out);
481
+ }
482
+ if (Value.currentBuilder) {
483
+ Value.currentBuilder.recordOp(out);
484
+ }
400
485
  return out;
401
486
  }
402
487
  /**
@@ -420,5 +505,158 @@ class Value {
420
505
  Value.no_grad_mode = prev;
421
506
  }
422
507
  }
508
+ // TODO: Move code generation into make/makeNary to co-locate with runtime logic
509
+ getForwardCode(childCodes) {
510
+ if (this.paramName)
511
+ return this.paramName;
512
+ if (this.prev.length === 1) {
513
+ const [child] = childCodes;
514
+ switch (this._op) {
515
+ case 'exp': return `Math.exp(${child})`;
516
+ case 'log': return `Math.log(${child})`;
517
+ case 'sqrt': return `Math.sqrt(${child})`;
518
+ case 'tanh': return `Math.tanh(${child})`;
519
+ case 'sigmoid': return `(1 / (1 + Math.exp(-${child})))`;
520
+ case 'relu': return `Math.max(0, ${child})`;
521
+ case 'sin': return `Math.sin(${child})`;
522
+ case 'cos': return `Math.cos(${child})`;
523
+ case 'tan': return `Math.tan(${child})`;
524
+ case 'asin': return `Math.asin(${child})`;
525
+ case 'acos': return `Math.acos(${child})`;
526
+ case 'atan': return `Math.atan(${child})`;
527
+ case 'neg': return `(-${child})`;
528
+ case 'abs': return `Math.abs(${child})`;
529
+ case 'square': return `(${child} * ${child})`;
530
+ case 'cube': return `(${child} * ${child} * ${child})`;
531
+ case 'reciprocal': return `(1 / ${child})`;
532
+ case 'sign': return `Math.sign(${child})`;
533
+ case 'softplus': return `Math.log(1 + Math.exp(${child}))`;
534
+ case 'floor': return `Math.floor(${child})`;
535
+ case 'ceil': return `Math.ceil(${child})`;
536
+ case 'round': return `Math.round(${child})`;
537
+ case 'clamp': {
538
+ const [min, max] = this._opConstants || [0, 1];
539
+ return `Math.max(${min}, Math.min(${child}, ${max}))`;
540
+ }
541
+ default: return String(this.data);
542
+ }
543
+ }
544
+ // N-ary operations (6 inputs)
545
+ if (this.prev.length === 6 && this._op === 'eigenvalue_custom') {
546
+ const [c00, c01, c02, c11, c12, c22] = childCodes;
547
+ return `window.__eigenHelpers.computeSmallestEigenvalue(${c00}, ${c01}, ${c02}, ${c11}, ${c12}, ${c22})`;
548
+ }
549
+ const [left, right] = childCodes;
550
+ switch (this._op) {
551
+ case '+': return `(${left} + ${right})`;
552
+ case '-': return `(${left} - ${right})`;
553
+ case '*': return `(${left} * ${right})`;
554
+ case '/': return `(${left} / ${right})`;
555
+ case 'powValue': return `Math.pow(${left}, ${right})`;
556
+ case 'mod': return `(${left} % ${right})`;
557
+ case 'min': return `Math.min(${left}, ${right})`;
558
+ case 'max': return `Math.max(${left}, ${right})`;
559
+ default: return String(this.data);
560
+ }
561
+ }
562
+ // TODO: Move code generation into make/makeNary to co-locate with runtime logic
563
+ getBackwardCode(gradVar, childGrads, childVars) {
564
+ if (this.prev.length === 1) {
565
+ const [childGrad] = childGrads;
566
+ const [child] = childVars;
567
+ switch (this._op) {
568
+ case 'exp':
569
+ return `${childGrad} += ${gradVar} * Math.exp(${child});`;
570
+ case 'log':
571
+ return `${childGrad} += ${gradVar} / ${child};`;
572
+ case 'sqrt':
573
+ return `${childGrad} += ${gradVar} * 0.5 / Math.sqrt(${child});`;
574
+ case 'tanh': {
575
+ const tanhChild = `Math.tanh(${child})`;
576
+ return `${childGrad} += ${gradVar} * (1 - ${tanhChild} * ${tanhChild});`;
577
+ }
578
+ case 'sigmoid': {
579
+ const sigChild = `(1 / (1 + Math.exp(-${child})))`;
580
+ return `${childGrad} += ${gradVar} * ${sigChild} * (1 - ${sigChild});`;
581
+ }
582
+ case 'relu':
583
+ return `${childGrad} += ${gradVar} * (${child} > 0 ? 1 : 0);`;
584
+ case 'sin':
585
+ return `${childGrad} += ${gradVar} * Math.cos(${child});`;
586
+ case 'cos':
587
+ return `${childGrad} += ${gradVar} * (-Math.sin(${child}));`;
588
+ case 'tan': {
589
+ const cosChild = `Math.cos(${child})`;
590
+ return `${childGrad} += ${gradVar} / (${cosChild} * ${cosChild});`;
591
+ }
592
+ case 'asin':
593
+ return `${childGrad} += ${gradVar} / Math.sqrt(1 - ${child} * ${child});`;
594
+ case 'acos':
595
+ return `${childGrad} += ${gradVar} / (-Math.sqrt(1 - ${child} * ${child}));`;
596
+ case 'atan':
597
+ return `${childGrad} += ${gradVar} / (1 + ${child} * ${child});`;
598
+ case 'neg':
599
+ return `${childGrad} -= ${gradVar};`;
600
+ case 'abs':
601
+ return `${childGrad} += ${gradVar} * (${child} >= 0 ? 1 : -1);`;
602
+ case 'square':
603
+ return `${childGrad} += ${gradVar} * 2 * ${child};`;
604
+ case 'cube':
605
+ return `${childGrad} += ${gradVar} * 3 * ${child} * ${child};`;
606
+ case 'reciprocal':
607
+ return `${childGrad} -= ${gradVar} / (${child} * ${child});`;
608
+ case 'sign':
609
+ return `${childGrad} += 0;`;
610
+ case 'softplus': {
611
+ const expChild = `Math.exp(${child})`;
612
+ return `${childGrad} += ${gradVar} * ${expChild} / (1 + ${expChild});`;
613
+ }
614
+ case 'floor':
615
+ case 'ceil':
616
+ case 'round':
617
+ return `${childGrad} += 0;`;
618
+ case 'clamp': {
619
+ const [min, max] = this._opConstants || [0, 1];
620
+ return `${childGrad} += ${gradVar} * (${child} > ${min} && ${child} < ${max} ? 1 : 0);`;
621
+ }
622
+ default:
623
+ return '';
624
+ }
625
+ }
626
+ // N-ary operations (6 inputs)
627
+ if (this.prev.length === 6 && this._op === 'eigenvalue_custom') {
628
+ const [g00, g01, g02, g11, g12, g22] = childGrads;
629
+ const [v00, v01, v02, v11, v12, v22] = childVars;
630
+ return `{
631
+ const grads = window.__eigenHelpers.applyEigenvalueGradients(
632
+ ${gradVar}, ${v00}, ${v01}, ${v02}, ${v11}, ${v12}, ${v22},
633
+ ${g00}, ${g01}, ${g02}, ${g11}, ${g12}, ${g22}
634
+ );
635
+ ${g00} = grads[0]; ${g01} = grads[1]; ${g02} = grads[2];
636
+ ${g11} = grads[3]; ${g12} = grads[4]; ${g22} = grads[5];
637
+ }`;
638
+ }
639
+ const [leftGrad, rightGrad] = childGrads;
640
+ const [left, right] = childVars;
641
+ switch (this._op) {
642
+ case '+':
643
+ return `${leftGrad} += ${gradVar}; ${rightGrad} += ${gradVar};`;
644
+ case '-':
645
+ return `${leftGrad} += ${gradVar}; ${rightGrad} -= ${gradVar};`;
646
+ case '*':
647
+ return `${leftGrad} += ${gradVar} * ${right}; ${rightGrad} += ${gradVar} * ${left};`;
648
+ case '/':
649
+ return `${leftGrad} += ${gradVar} / ${right}; ${rightGrad} -= ${gradVar} * ${left} / (${right} * ${right});`;
650
+ case 'powValue':
651
+ return `${leftGrad} += ${gradVar} * ${right} * Math.pow(${left}, ${right} - 1); ${rightGrad} += ${gradVar} * Math.pow(${left}, ${right}) * Math.log(${left});`;
652
+ case 'mod':
653
+ return `${leftGrad} += ${gradVar}; ${rightGrad} += 0;`;
654
+ case 'min':
655
+ return `${leftGrad} += ${gradVar} * (${left} < ${right} ? 1 : 0); ${rightGrad} += ${gradVar} * (${right} < ${left} ? 1 : 0);`;
656
+ case 'max':
657
+ return `${leftGrad} += ${gradVar} * (${left} > ${right} ? 1 : 0); ${rightGrad} += ${gradVar} * (${right} > ${left} ? 1 : 0);`;
658
+ default:
659
+ return '';
660
+ }
661
+ }
423
662
  }
424
- exports.Value = Value;
package/dist/ValueActivation.js CHANGED
@@ -1,34 +1,30 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.ValueActivation = void 0;
4
- const Value_1 = require("./Value");
5
- class ValueActivation {
1
+ import { Value } from './Value';
2
export class ValueActivation {
  /**
   * Rectified linear unit: max(0, x).
   * @param x Input node.
   * @returns New Value tagged with op 'relu'. The gradient flows only where x.data > 0.
   */
  static relu(x) {
    const r = Math.max(0, x.data);
    return Value.make(r, x, null, (out) => () => {
      if (x.requiresGrad)
        x.grad += (x.data > 0 ? 1 : 0) * out.grad;
    }, `relu(${x.label})`, 'relu');
  }
  /**
   * Softplus: log(1 + e^x).
   *
   * Evaluated as max(x, 0) + log1p(e^-|x|), which is algebraically identical.
   * The naive form overflows to Infinity once e^x does (x above roughly 709),
   * and the Value constructor rejects non-finite numbers. It also loses
   * precision when e^x is tiny.
   * @param x Input node.
   * @returns New Value tagged with op 'softplus'. Its derivative is sigmoid(x).
   */
  static softplus(x) {
    const v = x.data;
    const s = Math.max(v, 0) + Math.log1p(Math.exp(-Math.abs(v)));
    return Value.make(s, x, null, (out) => () => {
      // Same guard as the sibling activations.
      if (x.requiresGrad)
        x.grad += 1 / (1 + Math.exp(-x.data)) * out.grad;
    }, `softplus(${x.label})`, 'softplus');
  }
  /**
   * Hyperbolic tangent.
   * @param x Input node.
   * @returns New Value tagged with op 'tanh'. The backward step uses 1 - tanh(x)^2.
   */
  static tanh(x) {
    const t = Math.tanh(x.data);
    return Value.make(t, x, null, (out) => () => {
      if (x.requiresGrad)
        x.grad += (1 - t ** 2) * out.grad;
    }, `tanh(${x.label})`, 'tanh');
  }
  /**
   * Logistic sigmoid: 1 / (1 + e^-x).
   * @param x Input node.
   * @returns New Value tagged with op 'sigmoid'. The backward step uses s * (1 - s).
   */
  static sigmoid(x) {
    const s = 1 / (1 + Math.exp(-x.data));
    return Value.make(s, x, null, (out) => () => {
      if (x.requiresGrad)
        x.grad += s * (1 - s) * out.grad;
    }, `sigmoid(${x.label})`, 'sigmoid');
  }
}
34
- exports.ValueActivation = ValueActivation;
package/dist/ValueArithmetic.d.ts CHANGED
@@ -23,4 +23,5 @@ export declare class ValueArithmetic {
23
23
  static sum(vals: Value[]): Value;
24
24
  static mean(vals: Value[]): Value;
25
25
  static neg(a: Value): Value;
26
+ static sign(a: Value): Value;
26
27
  }