scalar-autograd 0.1.4 → 0.1.6

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (49)
  1. package/dist/Losses.d.ts +51 -0
  2. package/dist/Losses.js +145 -0
  3. package/dist/Losses.spec.d.ts +1 -0
  4. package/dist/Losses.spec.js +54 -0
  5. package/dist/Optimizers.d.ts +114 -0
  6. package/dist/Optimizers.edge-cases.spec.d.ts +1 -0
  7. package/dist/Optimizers.edge-cases.spec.js +29 -0
  8. package/dist/Optimizers.js +177 -0
  9. package/dist/Optimizers.spec.d.ts +1 -0
  10. package/dist/Optimizers.spec.js +56 -0
  11. package/dist/V.d.ts +0 -0
  12. package/dist/V.js +0 -0
  13. package/dist/Value.d.ts +260 -0
  14. package/dist/Value.edge-cases.spec.d.ts +1 -0
  15. package/dist/Value.edge-cases.spec.js +54 -0
  16. package/dist/Value.grad-flow.spec.d.ts +1 -0
  17. package/dist/Value.grad-flow.spec.js +24 -0
  18. package/dist/Value.js +424 -0
  19. package/dist/Value.losses-edge-cases.spec.d.ts +1 -0
  20. package/dist/Value.losses-edge-cases.spec.js +30 -0
  21. package/dist/Value.memory.spec.d.ts +1 -0
  22. package/dist/Value.memory.spec.js +23 -0
  23. package/dist/Value.nn.spec.d.ts +1 -0
  24. package/dist/Value.nn.spec.js +111 -0
  25. package/dist/Value.spec.d.ts +1 -0
  26. package/dist/Value.spec.js +245 -0
  27. package/dist/ValueActivation.d.ts +7 -0
  28. package/dist/ValueActivation.js +34 -0
  29. package/dist/ValueArithmetic.d.ts +26 -0
  30. package/dist/ValueArithmetic.js +180 -0
  31. package/dist/ValueComparison.d.ts +10 -0
  32. package/dist/ValueComparison.js +47 -0
  33. package/dist/ValueTrig.d.ts +9 -0
  34. package/dist/ValueTrig.js +49 -0
  35. package/package.json +4 -12
  36. package/Losses.ts +0 -145
  37. package/Optimizers.ts +0 -222
  38. package/V.ts +0 -0
  39. package/Value.edge-cases.spec.ts +0 -60
  40. package/Value.grad-flow.spec.ts +0 -24
  41. package/Value.losses-edge-cases.spec.ts +0 -32
  42. package/Value.memory.spec.ts +0 -25
  43. package/Value.nn.spec.ts +0 -109
  44. package/Value.spec.ts +0 -268
  45. package/Value.ts +0 -461
  46. package/ValueActivation.ts +0 -51
  47. package/ValueArithmetic.ts +0 -272
  48. package/ValueComparison.ts +0 -85
  49. package/ValueTrig.ts +0 -70
package/dist/Value.d.ts
@@ -0,0 +1,260 @@
1
+ export type BackwardFn = () => void;
2
+ export { V } from './V';
3
+ export { Optimizer, SGD, Adam, AdamW } from './Optimizers';
4
+ export type { OptimizerOptions, AdamOptions } from './Optimizers';
5
+ export { Losses } from './Losses';
6
+ export declare class Value {
7
+ static no_grad_mode: boolean;
8
+ data: number;
9
+ grad: number;
10
+ requiresGrad: boolean;
11
+ private backwardFn;
12
+ private prev;
13
+ label: string;
14
+ constructor(data: number, label?: string, requiresGrad?: boolean);
15
+ private static ensureValue;
16
+ /**
17
+ * Returns sin(this).
18
+ * @returns New Value with sin.
19
+ */
20
+ sin(): Value;
21
+ /**
22
+ * Returns cos(this).
23
+ * @returns New Value with cos.
24
+ */
25
+ cos(): Value;
26
+ /**
27
+ * Returns tan(this).
28
+ * @returns New Value with tan.
29
+ */
30
+ tan(): Value;
31
+ /**
32
+ * Returns asin(this).
33
+ * @returns New Value with asin.
34
+ */
35
+ asin(): Value;
36
+ /**
37
+ * Returns acos(this).
38
+ * @returns New Value with acos.
39
+ */
40
+ acos(): Value;
41
+ /**
42
+ * Returns atan(this).
43
+ * @returns New Value with atan.
44
+ */
45
+ atan(): Value;
46
+ /**
47
+ * Returns relu(this).
48
+ * @returns New Value with relu.
49
+ */
50
+ relu(): Value;
51
+ /**
52
+ * Returns abs(this).
53
+ * @returns New Value with abs.
54
+ */
55
+ abs(): Value;
56
+ /**
57
+ * Returns exp(this).
58
+ * @returns New Value with exp.
59
+ */
60
+ exp(): Value;
61
+ /**
62
+ * Returns log(this).
63
+ * @returns New Value with log.
64
+ */
65
+ log(): Value;
66
+ /**
67
+ * Returns min(this, other).
68
+ * @param other Value to compare
69
+ * @returns New Value with min.
70
+ */
71
+ min(other: Value): Value;
72
+ /**
73
+ * Returns max(this, other).
74
+ * @param other Value to compare
75
+ * @returns New Value with max.
76
+ */
77
+ max(other: Value): Value;
78
+ /**
79
+ * Adds this and other.
80
+ * @param other Value or number to add
81
+ * @returns New Value with sum.
82
+ */
83
+ add(other: Value | number): Value;
84
+ /**
85
+ * Multiplies this and other.
86
+ * @param other Value or number to multiply
87
+ * @returns New Value with product.
88
+ */
89
+ mul(other: Value | number): Value;
90
+ /**
91
+ * Subtracts other from this.
92
+ * @param other Value or number to subtract
93
+ * @returns New Value with difference.
94
+ */
95
+ sub(other: Value | number): Value;
96
+ /**
97
+ * Divides this by other.
98
+ * @param other Value or number divisor
99
+ * @returns New Value with quotient.
100
+ */
101
+ div(other: Value | number): Value;
102
+ /**
103
+ * Raises this to the power exp.
104
+ * @param exp Exponent
105
+ * @returns New Value with pow(this, exp)
106
+ */
107
+ pow(exp: number): Value;
108
+ /**
109
+ * Raises this to a dynamic Value (other).
110
+ * @param other Exponent Value or number
111
+ * @returns New Value with pow(this, other)
112
+ */
113
+ powValue(other: Value | number): Value;
114
+ /**
115
+ * Returns this modulo other.
116
+ * @param other Divisor Value
117
+ * @returns New Value with modulo.
118
+ */
119
+ mod(other: Value): Value;
120
+ /**
121
+ * Returns Value indicating if this equals other.
122
+ * @param other Value to compare
123
+ * @returns New Value (1 if equal, else 0)
124
+ */
125
+ eq(other: Value): Value;
126
+ /**
127
+ * Returns Value indicating if this not equals other.
128
+ * @param other Value to compare
129
+ * @returns New Value (1 if not equal, else 0)
130
+ */
131
+ neq(other: Value): Value;
132
+ /**
133
+ * Returns Value indicating if this greater than other.
134
+ * @param other Value to compare
135
+ * @returns New Value (1 if true, else 0)
136
+ */
137
+ gt(other: Value): Value;
138
+ /**
139
+ * Returns Value indicating if this less than other.
140
+ * @param other Value to compare
141
+ * @returns New Value (1 if true, else 0)
142
+ */
143
+ lt(other: Value): Value;
144
+ /**
145
+ * Returns Value indicating if this greater than or equal to other.
146
+ * @param other Value to compare
147
+ * @returns New Value (1 if true, else 0)
148
+ */
149
+ gte(other: Value): Value;
150
+ /**
151
+ * Returns Value indicating if this less than or equal to other.
152
+ * @param other Value to compare
153
+ * @returns New Value (1 if true, else 0)
154
+ */
155
+ lte(other: Value): Value;
156
+ /**
157
+ * Returns softplus(this).
158
+ * @returns New Value with softplus.
159
+ */
160
+ softplus(): Value;
161
+ /**
162
+ * Returns the floor of this Value.
163
+ * @returns New Value with floor(data).
164
+ */
165
+ floor(): Value;
166
+ /**
167
+ * Returns the ceiling of this Value.
168
+ * @returns New Value with ceil(data).
169
+ */
170
+ ceil(): Value;
171
+ /**
172
+ * Returns the rounded value of this Value.
173
+ * @returns New Value with rounded data.
174
+ */
175
+ round(): Value;
176
+ /**
177
+ * Returns the square of this Value.
178
+ * @returns New Value with squared data.
179
+ */
180
+ square(): Value;
181
+ /**
182
+ * Returns the cube of this Value.
183
+ * @returns New Value with cubed data.
184
+ */
185
+ cube(): Value;
186
+ /**
187
+ * Returns the reciprocal (1/x) of this Value.
188
+ * @returns New Value with reciprocal.
189
+ */
190
+ reciprocal(): Value;
191
+ /**
192
+ * Clamps this between min and max.
193
+ * @param min Minimum value
194
+ * @param max Maximum value
195
+ * @returns New clamped Value
196
+ */
197
+ clamp(min: number, max: number): Value;
198
+ /**
199
+ * Returns the negation (-this) Value.
200
+ * @returns New Value which is the negation.
201
+ */
202
+ neg(): Value;
203
+ /**
204
+ * Returns the sum of the given Values.
205
+ * @param vals Array of Value objects
206
+ * @returns New Value holding their sum.
207
+ */
208
+ static sum(vals: Value[]): Value;
209
+ /**
210
+ * Returns the mean of the given Values.
211
+ * @param vals Array of Value objects
212
+ * @returns New Value holding their mean.
213
+ */
214
+ static mean(vals: Value[]): Value;
215
+ /**
216
+ * Returns tanh(this).
217
+ * @returns New Value with tanh.
218
+ */
219
+ tanh(): Value;
220
+ /**
221
+ * Returns sigmoid(this).
222
+ * @returns New Value with sigmoid.
223
+ */
224
+ sigmoid(): Value;
225
+ /**
226
+ * Performs a reverse-mode autodiff backward pass from this Value.
227
+ * @param zeroGrad If true, zeroes all grads in the graph before backward
228
+ */
229
+ backward(zeroGrad?: boolean): void;
230
+ /**
231
+ * Sets all grad fields in the computation tree (from root) to 0.
232
+ * @param root Value to zero tree from
233
+ */
234
+ static zeroGradTree(root: Value): void;
235
+ /**
236
+ * Sets all grad fields in all supplied trees to 0.
237
+ * @param vals Values whose trees to zero
238
+ */
239
+ static zeroGradAll(vals: Value[]): void;
240
+ /**
241
+ * Internal helper to construct a Value with correct backward fn and grads.
242
+ * @param data Output value data
243
+ * @param left Left operand Value
244
+ * @param right Right operand Value or null
245
+ * @param backwardFnBuilder Function to create backward closure
246
+ * @param label Node label for debugging
247
+ * @returns New Value node
248
+ */
249
+ static make(data: number, left: Value, right: Value | null, backwardFnBuilder: (out: Value) => BackwardFn, label: string): Value;
250
+ /**
251
+ * Returns string representation for debugging.
252
+ * @returns String summary of Value
253
+ */
254
+ toString(): string;
255
+ /**
256
+ * Temporarily disables gradient tracking within the callback scope, like torch.no_grad().
257
+ * Restores the previous state after running fn.
258
+ */
259
+ static withNoGrad<T>(fn: () => T): T;
260
+ }
package/dist/Value.edge-cases.spec.d.ts
@@ -0,0 +1 @@
1
+ export {};
package/dist/Value.edge-cases.spec.js
@@ -0,0 +1,54 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ const Value_1 = require("./Value");
4
+ // Edge cases and error handling
5
+ describe('Value edge cases and error handling', () => {
6
+ it('throws on invalid numeric inputs', () => {
7
+ expect(() => new Value_1.Value(NaN)).toThrow();
8
+ expect(() => new Value_1.Value(Infinity)).toThrow();
9
+ expect(() => new Value_1.Value(-Infinity)).toThrow();
10
+ });
11
+ it('handles gradient accumulation correctly', () => {
12
+ const x = new Value_1.Value(2, 'x', true);
13
+ const y = x.mul(3);
14
+ const z = x.mul(4);
15
+ const out = y.add(z);
16
+ out.backward();
17
+ expect(x.grad).toBe(7); // 3 + 4
18
+ });
19
+ it('handles repeated use of same value in expression', () => {
20
+ const x = new Value_1.Value(3, 'x', true);
21
+ const y = x.mul(x).mul(x); // x^3
22
+ y.backward();
23
+ expect(x.grad).toBeCloseTo(27); // 3*x^2 = 27
24
+ });
25
+ it('throws on division by zero', () => {
26
+ const a = new Value_1.Value(1);
27
+ const b = new Value_1.Value(0);
28
+ expect(() => a.div(b)).toThrow();
29
+ });
30
+ it('throws on log of negative number', () => {
31
+ const x = new Value_1.Value(-1);
32
+ expect(() => x.log()).toThrow();
33
+ });
34
+ it('throws on negative base with fractional exponent', () => {
35
+ const x = new Value_1.Value(-2);
36
+ expect(() => x.pow(0.5)).toThrow();
37
+ });
38
+ });
39
+ // Complex expressions
40
+ describe('Complex mathematical expressions', () => {
41
+ it('computes gradient of complex expression', () => {
42
+ const x = new Value_1.Value(0.5, 'x', true);
43
+ const y = x.sin().mul(x.cos()).add(x.exp());
44
+ y.backward();
45
+ const expected = Math.cos(0.5) ** 2 - Math.sin(0.5) ** 2 + Math.exp(0.5);
46
+ expect(x.grad).toBeCloseTo(expected, 4);
47
+ });
48
+ it('handles nested activation functions', () => {
49
+ const x = new Value_1.Value(0.5, 'x', true);
50
+ const y = x.tanh().sigmoid().relu();
51
+ y.backward();
52
+ expect(x.grad).toBeGreaterThan(0);
53
+ });
54
+ });
package/dist/Value.grad-flow.spec.d.ts
@@ -0,0 +1 @@
1
+ export {};
package/dist/Value.grad-flow.spec.js
@@ -0,0 +1,24 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ const Value_1 = require("./Value");
4
+ describe('Gradient flow control', () => {
5
+ it('stops gradient at non-requiresGrad nodes', () => {
6
+ const x = new Value_1.Value(2, 'x', true);
7
+ const y = new Value_1.Value(3, 'y', false);
8
+ const z = new Value_1.Value(4, 'z', true);
9
+ const out = x.mul(y).add(z);
10
+ out.backward();
11
+ expect(x.grad).toBe(3);
12
+ expect(y.grad).toBe(0);
13
+ expect(z.grad).toBe(1);
14
+ });
15
+ it('handles detached computation graphs', () => {
16
+ const x = new Value_1.Value(2, 'x', true);
17
+ const y = x.mul(3);
18
+ const z = new Value_1.Value(y.data, 'z', true); // detached
19
+ const out = z.mul(4);
20
+ out.backward();
21
+ expect(z.grad).toBe(4);
22
+ expect(x.grad).toBe(0); // no gradient flows to x
23
+ });
24
+ });