scalar-autograd 0.1.4 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/Value.ts DELETED
@@ -1,461 +0,0 @@
1
- export type BackwardFn = () => void;
2
- export { V } from './V';
3
- export { Optimizer, SGD, Adam, AdamW } from './Optimizers';
4
- export type { OptimizerOptions, AdamOptions } from './Optimizers';
5
- export { Losses } from './Losses';
6
-
7
- const EPS = 1e-12;
8
-
9
-
10
- import { ValueTrig } from './ValueTrig';
11
- import { ValueActivation } from './ValueActivation';
12
- import { ValueArithmetic } from './ValueArithmetic';
13
- import { ValueComparison } from './ValueComparison';
14
-
15
/**
 * Scalar node of a reverse-mode automatic-differentiation graph.
 *
 * A Value holds a number (`data`), the gradient accumulated for it (`grad`),
 * and - when it was produced through `Value.make` - the operand nodes it was
 * computed from together with a closure that propagates its gradient to them.
 * The math itself lives in the ValueTrig / ValueActivation / ValueArithmetic /
 * ValueComparison helpers; the instance methods here only delegate to them.
 */
export class Value {
  /** While true, `Value.make` records neither operands nor backward closures. */
  static no_grad_mode = false;
  /** Numeric payload of this node. */
  data: number;
  /** Gradient accumulated by backward passes; starts at 0. */
  grad: number = 0;
  /** Whether backward passes should invoke this node's backward closure. */
  requiresGrad: boolean;
  /** Closure that pushes `grad` to the operands; a no-op unless set by `make`. */
  private backwardFn: BackwardFn = () => {};
  /** Operand nodes this node was computed from (empty for leaves). */
  private prev: Value[] = [];
  /** Free-form label used for debugging output. */
  public label: string;

  /**
   * Creates a leaf node.
   * @param data Finite number to store
   * @param label Optional debugging label
   * @param requiresGrad Whether gradients should be tracked for this node
   * @throws Error if `data` is not a finite number (NaN and +/-Infinity included)
   */
  constructor(data: number, label = "", requiresGrad = false) {
    // Number.isFinite is already false for NaN, so no separate NaN test is needed.
    if (typeof data !== 'number' || !Number.isFinite(data)) {
      throw new Error(`Invalid number passed to Value: ${data}`);
    }
    this.data = data;
    this.label = label;
    this.requiresGrad = requiresGrad;
  }

  /** Wraps a plain number in a Value (default label, no grad); passes Values through. */
  private static ensureValue(x: Value | number): Value {
    return typeof x === 'number' ? new Value(x) : x;
  }

  /**
   * Collects every node reachable from `roots` through `prev` links in
   * depth-first post-order (operands before the node that uses them), visiting
   * each node once across all roots.
   *
   * Uses an explicit stack instead of recursion so that very deep graphs
   * (for example a long accumulation chain) cannot overflow the call stack.
   * The resulting order matches a recursive post-order walk that iterates
   * `prev` front to back.
   */
  private static postOrder(roots: readonly Value[]): Value[] {
    const order: Value[] = [];
    const seen = new Set<Value>();
    const stack: { node: Value; next: number }[] = [];

    for (const root of roots) {
      if (seen.has(root)) continue;
      seen.add(root);
      stack.push({ node: root, next: 0 });

      while (stack.length > 0) {
        const frame = stack[stack.length - 1];
        if (frame.next < frame.node.prev.length) {
          // Descend into the next not-yet-examined operand of this node.
          const child = frame.node.prev[frame.next];
          frame.next += 1;
          if (!seen.has(child)) {
            seen.add(child);
            stack.push({ node: child, next: 0 });
          }
        } else {
          // All operands handled: emit the node itself.
          order.push(frame.node);
          stack.pop();
        }
      }
    }
    return order;
  }

  /**
   * Returns sin(this).
   * @returns New Value with sin.
   */
  sin(): Value {
    return ValueTrig.sin(this);
  }

  /**
   * Returns cos(this).
   * @returns New Value with cos.
   */
  cos(): Value {
    return ValueTrig.cos(this);
  }

  /**
   * Returns tan(this).
   * @returns New Value with tan.
   */
  tan(): Value {
    return ValueTrig.tan(this);
  }

  /**
   * Returns asin(this).
   * @returns New Value with asin.
   */
  asin(): Value {
    return ValueTrig.asin(this);
  }

  /**
   * Returns acos(this).
   * @returns New Value with acos.
   */
  acos(): Value {
    return ValueTrig.acos(this);
  }

  /**
   * Returns atan(this).
   * @returns New Value with atan.
   */
  atan(): Value {
    return ValueTrig.atan(this);
  }

  /**
   * Returns relu(this).
   * @returns New Value with relu.
   */
  relu(): Value {
    return ValueActivation.relu(this);
  }

  /**
   * Returns abs(this).
   * @returns New Value with abs.
   */
  abs(): Value {
    return ValueArithmetic.abs(this);
  }

  /**
   * Returns exp(this).
   * @returns New Value with exp.
   */
  exp(): Value {
    return ValueArithmetic.exp(this);
  }

  /**
   * Returns log(this). The module-level EPS constant is forwarded to the helper.
   * @returns New Value with log.
   */
  log(): Value {
    return ValueArithmetic.log(this, EPS);
  }

  /**
   * Returns min(this, other).
   * @param other Value or number to compare
   * @returns New Value with min.
   */
  min(other: Value | number): Value {
    return ValueArithmetic.min(this, Value.ensureValue(other));
  }

  /**
   * Returns max(this, other).
   * @param other Value or number to compare
   * @returns New Value with max.
   */
  max(other: Value | number): Value {
    return ValueArithmetic.max(this, Value.ensureValue(other));
  }

  /**
   * Adds this and other.
   * @param other Value or number to add
   * @returns New Value with sum.
   */
  add(other: Value | number): Value {
    return ValueArithmetic.add(this, Value.ensureValue(other));
  }

  /**
   * Multiplies this and other.
   * @param other Value or number to multiply
   * @returns New Value with product.
   */
  mul(other: Value | number): Value {
    return ValueArithmetic.mul(this, Value.ensureValue(other));
  }

  /**
   * Subtracts other from this.
   * @param other Value or number to subtract
   * @returns New Value with difference.
   */
  sub(other: Value | number): Value {
    return ValueArithmetic.sub(this, Value.ensureValue(other));
  }

  /**
   * Divides this by other. The module-level EPS constant is forwarded to the helper.
   * @param other Value or number divisor
   * @returns New Value with quotient.
   */
  div(other: Value | number): Value {
    return ValueArithmetic.div(this, Value.ensureValue(other), EPS);
  }

  /**
   * Raises this to the power exp.
   * @param exp Exponent
   * @returns New Value with pow(this, exp)
   */
  pow(exp: number): Value {
    return ValueArithmetic.pow(this, exp);
  }

  /**
   * Raises this to a dynamic Value (other).
   * @param other Exponent Value or number
   * @returns New Value with pow(this, other)
   */
  powValue(other: Value | number): Value {
    return ValueArithmetic.powValue(this, Value.ensureValue(other), EPS);
  }

  /**
   * Returns this modulo other.
   * @param other Divisor Value or number
   * @returns New Value with modulo.
   */
  mod(other: Value | number): Value {
    return ValueArithmetic.mod(this, Value.ensureValue(other));
  }

  /**
   * Returns Value indicating if this equals other.
   * @param other Value or number to compare
   * @returns New Value (1 if equal, else 0)
   */
  eq(other: Value | number): Value {
    return ValueComparison.eq(this, Value.ensureValue(other));
  }

  /**
   * Returns Value indicating if this not equals other.
   * @param other Value or number to compare
   * @returns New Value (1 if not equal, else 0)
   */
  neq(other: Value | number): Value {
    return ValueComparison.neq(this, Value.ensureValue(other));
  }

  /**
   * Returns Value indicating if this greater than other.
   * @param other Value or number to compare
   * @returns New Value (1 if true, else 0)
   */
  gt(other: Value | number): Value {
    return ValueComparison.gt(this, Value.ensureValue(other));
  }

  /**
   * Returns Value indicating if this less than other.
   * @param other Value or number to compare
   * @returns New Value (1 if true, else 0)
   */
  lt(other: Value | number): Value {
    return ValueComparison.lt(this, Value.ensureValue(other));
  }

  /**
   * Returns Value indicating if this greater than or equal to other.
   * @param other Value or number to compare
   * @returns New Value (1 if true, else 0)
   */
  gte(other: Value | number): Value {
    return ValueComparison.gte(this, Value.ensureValue(other));
  }

  /**
   * Returns Value indicating if this less than or equal to other.
   * @param other Value or number to compare
   * @returns New Value (1 if true, else 0)
   */
  lte(other: Value | number): Value {
    return ValueComparison.lte(this, Value.ensureValue(other));
  }

  /**
   * Returns softplus(this).
   * @returns New Value with softplus.
   */
  softplus(): Value {
    return ValueActivation.softplus(this);
  }

  /**
   * Returns the floor of this Value.
   * @returns New Value with floor(data).
   */
  floor(): Value {
    return ValueArithmetic.floor(this);
  }

  /**
   * Returns the ceiling of this Value.
   * @returns New Value with ceil(data).
   */
  ceil(): Value {
    return ValueArithmetic.ceil(this);
  }

  /**
   * Returns the rounded value of this Value.
   * @returns New Value with rounded data.
   */
  round(): Value {
    return ValueArithmetic.round(this);
  }

  /**
   * Returns the square of this Value.
   * @returns New Value with squared data.
   */
  square(): Value {
    return ValueArithmetic.square(this);
  }

  /**
   * Returns the cube of this Value.
   * @returns New Value with cubed data.
   */
  cube(): Value {
    return ValueArithmetic.cube(this);
  }

  /**
   * Returns the reciprocal (1/x) of this Value. The module-level EPS constant
   * is forwarded to the helper.
   * @returns New Value with reciprocal.
   */
  reciprocal(): Value {
    return ValueArithmetic.reciprocal(this, EPS);
  }

  /**
   * Clamps this between min and max.
   * @param min Minimum value
   * @param max Maximum value
   * @returns New clamped Value
   */
  clamp(min: number, max: number): Value {
    return ValueArithmetic.clamp(this, min, max);
  }

  /**
   * Returns the negation (-this) Value.
   * @returns New Value which is the negation.
   */
  neg(): Value {
    return ValueArithmetic.neg(this);
  }

  /**
   * Returns the sum of the given Values.
   * @param vals Array of Value objects
   * @returns New Value holding their sum.
   */
  static sum(vals: Value[]): Value {
    return ValueArithmetic.sum(vals);
  }

  /**
   * Returns the mean of the given Values.
   * @param vals Array of Value objects
   * @returns New Value holding their mean.
   */
  static mean(vals: Value[]): Value {
    return ValueArithmetic.mean(vals);
  }

  /**
   * Returns tanh(this).
   * @returns New Value with tanh.
   */
  tanh(): Value {
    return ValueActivation.tanh(this);
  }

  /**
   * Returns sigmoid(this).
   * @returns New Value with sigmoid.
   */
  sigmoid(): Value {
    return ValueActivation.sigmoid(this);
  }

  /**
   * Performs a reverse-mode autodiff backward pass from this Value.
   *
   * Sets this node's grad to 1, then walks the graph from this node towards
   * the leaves (reverse post-order) and runs the backward closure of every
   * node that has `requiresGrad` set. Gradients are accumulated with `+=` by
   * those closures, so repeated calls add up unless `zeroGrad` is true.
   * @param zeroGrad If true, zeroes all grads in the graph before backward
   */
  backward(zeroGrad = false): void {
    if (zeroGrad) Value.zeroGradTree(this);

    const topo = Value.postOrder([this]);
    this.grad = 1;

    // Reverse post-order: every node runs before any of its operands.
    for (let i = topo.length - 1; i >= 0; i--) {
      const node = topo[i];
      if (node.requiresGrad) {
        node.backwardFn();
      }
    }
  }

  /**
   * Sets all grad fields in the computation tree (from root) to 0.
   * @param root Value to zero tree from
   */
  static zeroGradTree(root: Value): void {
    for (const node of Value.postOrder([root])) {
      node.grad = 0;
    }
  }

  /**
   * Sets all grad fields in all supplied trees to 0. Nodes shared between
   * several trees are visited only once.
   * @param vals Values whose trees to zero
   */
  static zeroGradAll(vals: Value[]): void {
    for (const node of Value.postOrder(vals)) {
      node.grad = 0;
    }
  }

  /**
   * Internal helper to construct a Value with correct backward fn and grads.
   *
   * The result requires grad when grad tracking is enabled and at least one
   * operand requires grad; only then is `backwardFnBuilder` invoked. While
   * `no_grad_mode` is on, the result keeps no reference to its operands.
   * @param data Output value data
   * @param left Left operand Value
   * @param right Right operand Value or null
   * @param backwardFnBuilder Function to create backward closure
   * @param label Node label for debugging
   * @returns New Value node
   */
  static make(
    data: number,
    left: Value,
    right: Value | null,
    backwardFnBuilder: (out: Value) => BackwardFn,
    label: string
  ): Value {
    // Drop the absent right operand of unary operations.
    const operands = [left, right].filter((v): v is Value => v != null);
    const requiresGrad = !Value.no_grad_mode && operands.some((v) => v.requiresGrad);
    const out = new Value(data, label, requiresGrad);
    out.prev = Value.no_grad_mode ? [] : operands;
    if (requiresGrad) {
      out.backwardFn = backwardFnBuilder(out);
    }
    return out;
  }

  /**
   * Returns string representation for debugging.
   * @returns String summary of Value (data and grad with 4 decimals, plus label)
   */
  toString(): string {
    return `Value(data=${this.data.toFixed(4)}, grad=${this.grad.toFixed(4)}, label=${this.label})`;
  }

  /**
   * Temporarily disables gradient tracking within the callback scope, like torch.no_grad().
   * Restores the previous state after running fn, even if fn throws.
   * @param fn Callback to run with gradient tracking disabled
   * @returns Whatever fn returns
   */
  static withNoGrad<T>(fn: () => T): T {
    const prev = Value.no_grad_mode;
    Value.no_grad_mode = true;
    try {
      return fn();
    } finally {
      Value.no_grad_mode = prev;
    }
  }
}
@@ -1,51 +0,0 @@
1
- import { Value } from './Value';
2
-
3
/**
 * Activation functions for Value nodes.
 *
 * Each function computes the forward result from `x.data` and builds the
 * output node through `Value.make`, supplying a closure that adds the local
 * derivative times the output gradient to `x.grad`.
 */
export class ValueActivation {
  /**
   * Rectified linear unit: max(0, x).
   * @param x Input Value
   * @returns New Value labelled `relu(<x.label>)`; the derivative used is 1 for x > 0, else 0.
   */
  static relu(x: Value): Value {
    const r = Math.max(0, x.data);
    return Value.make(
      r,
      x, null,
      (out) => () => {
        if (x.requiresGrad) x.grad += (x.data > 0 ? 1 : 0) * out.grad;
      },
      `relu(${x.label})`
    );
  }

  /**
   * Softplus: log(1 + e^x).
   *
   * Evaluated as max(x, 0) + log1p(e^-|x|), which is algebraically the same
   * but never exponentiates a large positive number; the naive form yields
   * Infinity once e^x overflows.
   * @param x Input Value
   * @returns New Value labelled `softplus(<x.label>)`; the derivative used is the logistic sigmoid of x.
   */
  static softplus(x: Value): Value {
    const s = Math.max(x.data, 0) + Math.log1p(Math.exp(-Math.abs(x.data)));
    return Value.make(
      s,
      x, null,
      (out) => () => {
        // Guarded like the other activations in this class.
        if (x.requiresGrad) x.grad += (1 / (1 + Math.exp(-x.data))) * out.grad;
      },
      `softplus(${x.label})`
    );
  }

  /**
   * Hyperbolic tangent.
   * @param x Input Value
   * @returns New Value labelled `tanh(<x.label>)`; the derivative used is 1 - tanh(x)^2.
   */
  static tanh(x: Value): Value {
    const t = Math.tanh(x.data);
    return Value.make(
      t,
      x, null,
      (out) => () => {
        if (x.requiresGrad) x.grad += (1 - t ** 2) * out.grad;
      },
      `tanh(${x.label})`
    );
  }

  /**
   * Logistic sigmoid: 1 / (1 + e^-x).
   * @param x Input Value
   * @returns New Value labelled `sigmoid(<x.label>)`; the derivative used is s * (1 - s).
   */
  static sigmoid(x: Value): Value {
    const s = 1 / (1 + Math.exp(-x.data));
    return Value.make(
      s,
      x, null,
      (out) => () => {
        if (x.requiresGrad) x.grad += s * (1 - s) * out.grad;
      },
      `sigmoid(${x.label})`
    );
  }
}