scalar-autograd 0.1.4 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Losses.js +145 -0
- package/dist/Losses.spec.js +54 -0
- package/dist/Optimizers.edge-cases.spec.js +29 -0
- package/dist/Optimizers.js +177 -0
- package/dist/Optimizers.spec.js +56 -0
- package/dist/V.js +0 -0
- package/dist/Value.edge-cases.spec.js +54 -0
- package/dist/Value.grad-flow.spec.js +24 -0
- package/dist/Value.js +424 -0
- package/dist/Value.losses-edge-cases.spec.js +30 -0
- package/dist/Value.memory.spec.js +23 -0
- package/dist/Value.nn.spec.js +111 -0
- package/dist/Value.spec.js +245 -0
- package/dist/ValueActivation.js +34 -0
- package/dist/ValueArithmetic.js +180 -0
- package/dist/ValueComparison.js +47 -0
- package/dist/ValueTrig.js +49 -0
- package/package.json +4 -12
- package/Losses.ts +0 -145
- package/Optimizers.ts +0 -222
- package/V.ts +0 -0
- package/Value.edge-cases.spec.ts +0 -60
- package/Value.grad-flow.spec.ts +0 -24
- package/Value.losses-edge-cases.spec.ts +0 -32
- package/Value.memory.spec.ts +0 -25
- package/Value.nn.spec.ts +0 -109
- package/Value.spec.ts +0 -268
- package/Value.ts +0 -461
- package/ValueActivation.ts +0 -51
- package/ValueArithmetic.ts +0 -272
- package/ValueComparison.ts +0 -85
- package/ValueTrig.ts +0 -70
package/dist/Value.js
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.Value = exports.Losses = exports.AdamW = exports.Adam = exports.SGD = exports.Optimizer = exports.V = void 0;
|
|
4
|
+
var V_1 = require("./V");
|
|
5
|
+
Object.defineProperty(exports, "V", { enumerable: true, get: function () { return V_1.V; } });
|
|
6
|
+
var Optimizers_1 = require("./Optimizers");
|
|
7
|
+
Object.defineProperty(exports, "Optimizer", { enumerable: true, get: function () { return Optimizers_1.Optimizer; } });
|
|
8
|
+
Object.defineProperty(exports, "SGD", { enumerable: true, get: function () { return Optimizers_1.SGD; } });
|
|
9
|
+
Object.defineProperty(exports, "Adam", { enumerable: true, get: function () { return Optimizers_1.Adam; } });
|
|
10
|
+
Object.defineProperty(exports, "AdamW", { enumerable: true, get: function () { return Optimizers_1.AdamW; } });
|
|
11
|
+
var Losses_1 = require("./Losses");
|
|
12
|
+
Object.defineProperty(exports, "Losses", { enumerable: true, get: function () { return Losses_1.Losses; } });
|
|
13
|
+
const EPS = 1e-12;
|
|
14
|
+
const ValueTrig_1 = require("./ValueTrig");
|
|
15
|
+
const ValueActivation_1 = require("./ValueActivation");
|
|
16
|
+
const ValueArithmetic_1 = require("./ValueArithmetic");
|
|
17
|
+
const ValueComparison_1 = require("./ValueComparison");
|
|
18
|
+
class Value {
|
|
19
|
+
static no_grad_mode = false;
|
|
20
|
+
data;
|
|
21
|
+
grad = 0;
|
|
22
|
+
requiresGrad;
|
|
23
|
+
backwardFn = () => { };
|
|
24
|
+
prev = [];
|
|
25
|
+
label;
|
|
26
|
+
constructor(data, label = "", requiresGrad = false) {
|
|
27
|
+
if (typeof data !== 'number' || Number.isNaN(data) || !Number.isFinite(data)) {
|
|
28
|
+
throw new Error(`Invalid number passed to Value: ${data}`);
|
|
29
|
+
}
|
|
30
|
+
this.data = data;
|
|
31
|
+
this.label = label;
|
|
32
|
+
this.requiresGrad = requiresGrad;
|
|
33
|
+
}
|
|
34
|
+
static ensureValue(x) {
|
|
35
|
+
return typeof x === 'number' ? new Value(x) : x;
|
|
36
|
+
}
|
|
37
|
+
/**
|
|
38
|
+
* Returns sin(this).
|
|
39
|
+
* @returns New Value with sin.
|
|
40
|
+
*/
|
|
41
|
+
sin() {
|
|
42
|
+
return ValueTrig_1.ValueTrig.sin(this);
|
|
43
|
+
}
|
|
44
|
+
/**
|
|
45
|
+
* Returns cos(this).
|
|
46
|
+
* @returns New Value with cos.
|
|
47
|
+
*/
|
|
48
|
+
cos() {
|
|
49
|
+
return ValueTrig_1.ValueTrig.cos(this);
|
|
50
|
+
}
|
|
51
|
+
/**
|
|
52
|
+
* Returns tan(this).
|
|
53
|
+
* @returns New Value with tan.
|
|
54
|
+
*/
|
|
55
|
+
tan() {
|
|
56
|
+
return ValueTrig_1.ValueTrig.tan(this);
|
|
57
|
+
}
|
|
58
|
+
/**
|
|
59
|
+
* Returns asin(this).
|
|
60
|
+
* @returns New Value with asin.
|
|
61
|
+
*/
|
|
62
|
+
asin() {
|
|
63
|
+
return ValueTrig_1.ValueTrig.asin(this);
|
|
64
|
+
}
|
|
65
|
+
/**
|
|
66
|
+
* Returns acos(this).
|
|
67
|
+
* @returns New Value with acos.
|
|
68
|
+
*/
|
|
69
|
+
acos() {
|
|
70
|
+
return ValueTrig_1.ValueTrig.acos(this);
|
|
71
|
+
}
|
|
72
|
+
/**
|
|
73
|
+
* Returns atan(this).
|
|
74
|
+
* @returns New Value with atan.
|
|
75
|
+
*/
|
|
76
|
+
atan() {
|
|
77
|
+
return ValueTrig_1.ValueTrig.atan(this);
|
|
78
|
+
}
|
|
79
|
+
/**
|
|
80
|
+
* Returns relu(this).
|
|
81
|
+
* @returns New Value with relu.
|
|
82
|
+
*/
|
|
83
|
+
relu() {
|
|
84
|
+
return ValueActivation_1.ValueActivation.relu(this);
|
|
85
|
+
}
|
|
86
|
+
/**
|
|
87
|
+
* Returns abs(this).
|
|
88
|
+
* @returns New Value with abs.
|
|
89
|
+
*/
|
|
90
|
+
abs() {
|
|
91
|
+
return ValueArithmetic_1.ValueArithmetic.abs(this);
|
|
92
|
+
}
|
|
93
|
+
/**
|
|
94
|
+
* Returns exp(this).
|
|
95
|
+
* @returns New Value with exp.
|
|
96
|
+
*/
|
|
97
|
+
exp() {
|
|
98
|
+
return ValueArithmetic_1.ValueArithmetic.exp(this);
|
|
99
|
+
}
|
|
100
|
+
/**
|
|
101
|
+
* Returns log(this).
|
|
102
|
+
* @returns New Value with log.
|
|
103
|
+
*/
|
|
104
|
+
log() {
|
|
105
|
+
return ValueArithmetic_1.ValueArithmetic.log(this, EPS);
|
|
106
|
+
}
|
|
107
|
+
/**
|
|
108
|
+
* Returns min(this, other).
|
|
109
|
+
* @param other Value to compare
|
|
110
|
+
* @returns New Value with min.
|
|
111
|
+
*/
|
|
112
|
+
min(other) {
|
|
113
|
+
return ValueArithmetic_1.ValueArithmetic.min(this, other);
|
|
114
|
+
}
|
|
115
|
+
/**
|
|
116
|
+
* Returns max(this, other).
|
|
117
|
+
* @param other Value to compare
|
|
118
|
+
* @returns New Value with max.
|
|
119
|
+
*/
|
|
120
|
+
max(other) {
|
|
121
|
+
return ValueArithmetic_1.ValueArithmetic.max(this, other);
|
|
122
|
+
}
|
|
123
|
+
/**
|
|
124
|
+
* Adds this and other.
|
|
125
|
+
* @param other Value or number to add
|
|
126
|
+
* @returns New Value with sum.
|
|
127
|
+
*/
|
|
128
|
+
add(other) {
|
|
129
|
+
return ValueArithmetic_1.ValueArithmetic.add(this, Value.ensureValue(other));
|
|
130
|
+
}
|
|
131
|
+
/**
|
|
132
|
+
* Multiplies this and other.
|
|
133
|
+
* @param other Value or number to multiply
|
|
134
|
+
* @returns New Value with product.
|
|
135
|
+
*/
|
|
136
|
+
mul(other) {
|
|
137
|
+
return ValueArithmetic_1.ValueArithmetic.mul(this, Value.ensureValue(other));
|
|
138
|
+
}
|
|
139
|
+
/**
|
|
140
|
+
* Subtracts other from this.
|
|
141
|
+
* @param other Value or number to subtract
|
|
142
|
+
* @returns New Value with difference.
|
|
143
|
+
*/
|
|
144
|
+
sub(other) {
|
|
145
|
+
return ValueArithmetic_1.ValueArithmetic.sub(this, Value.ensureValue(other));
|
|
146
|
+
}
|
|
147
|
+
/**
|
|
148
|
+
* Divides this by other.
|
|
149
|
+
* @param other Value or number divisor
|
|
150
|
+
* @returns New Value with quotient.
|
|
151
|
+
*/
|
|
152
|
+
div(other) {
|
|
153
|
+
return ValueArithmetic_1.ValueArithmetic.div(this, Value.ensureValue(other), EPS);
|
|
154
|
+
}
|
|
155
|
+
/**
|
|
156
|
+
* Raises this to the power exp.
|
|
157
|
+
* @param exp Exponent
|
|
158
|
+
* @returns New Value with pow(this, exp)
|
|
159
|
+
*/
|
|
160
|
+
pow(exp) {
|
|
161
|
+
return ValueArithmetic_1.ValueArithmetic.pow(this, exp);
|
|
162
|
+
}
|
|
163
|
+
/**
|
|
164
|
+
* Raises this to a dynamic Value (other).
|
|
165
|
+
* @param other Exponent Value or number
|
|
166
|
+
* @returns New Value with pow(this, other)
|
|
167
|
+
*/
|
|
168
|
+
powValue(other) {
|
|
169
|
+
return ValueArithmetic_1.ValueArithmetic.powValue(this, Value.ensureValue(other), EPS);
|
|
170
|
+
}
|
|
171
|
+
/**
|
|
172
|
+
* Returns this modulo other.
|
|
173
|
+
* @param other Divisor Value
|
|
174
|
+
* @returns New Value with modulo.
|
|
175
|
+
*/
|
|
176
|
+
mod(other) {
|
|
177
|
+
return ValueArithmetic_1.ValueArithmetic.mod(this, other);
|
|
178
|
+
}
|
|
179
|
+
/**
|
|
180
|
+
* Returns Value indicating if this equals other.
|
|
181
|
+
* @param other Value to compare
|
|
182
|
+
* @returns New Value (1 if equal, else 0)
|
|
183
|
+
*/
|
|
184
|
+
eq(other) {
|
|
185
|
+
return ValueComparison_1.ValueComparison.eq(this, other);
|
|
186
|
+
}
|
|
187
|
+
/**
|
|
188
|
+
* Returns Value indicating if this not equals other.
|
|
189
|
+
* @param other Value to compare
|
|
190
|
+
* @returns New Value (1 if not equal, else 0)
|
|
191
|
+
*/
|
|
192
|
+
neq(other) {
|
|
193
|
+
return ValueComparison_1.ValueComparison.neq(this, other);
|
|
194
|
+
}
|
|
195
|
+
/**
|
|
196
|
+
* Returns Value indicating if this greater than other.
|
|
197
|
+
* @param other Value to compare
|
|
198
|
+
* @returns New Value (1 if true, else 0)
|
|
199
|
+
*/
|
|
200
|
+
gt(other) {
|
|
201
|
+
return ValueComparison_1.ValueComparison.gt(this, other);
|
|
202
|
+
}
|
|
203
|
+
/**
|
|
204
|
+
* Returns Value indicating if this less than other.
|
|
205
|
+
* @param other Value to compare
|
|
206
|
+
* @returns New Value (1 if true, else 0)
|
|
207
|
+
*/
|
|
208
|
+
lt(other) {
|
|
209
|
+
return ValueComparison_1.ValueComparison.lt(this, other);
|
|
210
|
+
}
|
|
211
|
+
/**
|
|
212
|
+
* Returns Value indicating if this greater than or equal to other.
|
|
213
|
+
* @param other Value to compare
|
|
214
|
+
* @returns New Value (1 if true, else 0)
|
|
215
|
+
*/
|
|
216
|
+
gte(other) {
|
|
217
|
+
return ValueComparison_1.ValueComparison.gte(this, other);
|
|
218
|
+
}
|
|
219
|
+
/**
|
|
220
|
+
* Returns Value indicating if this less than or equal to other.
|
|
221
|
+
* @param other Value to compare
|
|
222
|
+
* @returns New Value (1 if true, else 0)
|
|
223
|
+
*/
|
|
224
|
+
lte(other) {
|
|
225
|
+
return ValueComparison_1.ValueComparison.lte(this, other);
|
|
226
|
+
}
|
|
227
|
+
/**
|
|
228
|
+
* Returns softplus(this).
|
|
229
|
+
* @returns New Value with softplus.
|
|
230
|
+
*/
|
|
231
|
+
softplus() {
|
|
232
|
+
return ValueActivation_1.ValueActivation.softplus(this);
|
|
233
|
+
}
|
|
234
|
+
/**
|
|
235
|
+
* Returns the floor of this Value.
|
|
236
|
+
* @returns New Value with floor(data).
|
|
237
|
+
*/
|
|
238
|
+
floor() {
|
|
239
|
+
return ValueArithmetic_1.ValueArithmetic.floor(this);
|
|
240
|
+
}
|
|
241
|
+
/**
|
|
242
|
+
* Returns the ceiling of this Value.
|
|
243
|
+
* @returns New Value with ceil(data).
|
|
244
|
+
*/
|
|
245
|
+
ceil() {
|
|
246
|
+
return ValueArithmetic_1.ValueArithmetic.ceil(this);
|
|
247
|
+
}
|
|
248
|
+
/**
|
|
249
|
+
* Returns the rounded value of this Value.
|
|
250
|
+
* @returns New Value with rounded data.
|
|
251
|
+
*/
|
|
252
|
+
round() {
|
|
253
|
+
return ValueArithmetic_1.ValueArithmetic.round(this);
|
|
254
|
+
}
|
|
255
|
+
/**
|
|
256
|
+
* Returns the square of this Value.
|
|
257
|
+
* @returns New Value with squared data.
|
|
258
|
+
*/
|
|
259
|
+
square() {
|
|
260
|
+
return ValueArithmetic_1.ValueArithmetic.square(this);
|
|
261
|
+
}
|
|
262
|
+
/**
|
|
263
|
+
* Returns the cube of this Value.
|
|
264
|
+
* @returns New Value with cubed data.
|
|
265
|
+
*/
|
|
266
|
+
cube() {
|
|
267
|
+
return ValueArithmetic_1.ValueArithmetic.cube(this);
|
|
268
|
+
}
|
|
269
|
+
/**
|
|
270
|
+
* Returns the reciprocal (1/x) of this Value.
|
|
271
|
+
* @returns New Value with reciprocal.
|
|
272
|
+
*/
|
|
273
|
+
reciprocal() {
|
|
274
|
+
return ValueArithmetic_1.ValueArithmetic.reciprocal(this, EPS);
|
|
275
|
+
}
|
|
276
|
+
/**
|
|
277
|
+
* Clamps this between min and max.
|
|
278
|
+
* @param min Minimum value
|
|
279
|
+
* @param max Maximum value
|
|
280
|
+
* @returns New clamped Value
|
|
281
|
+
*/
|
|
282
|
+
clamp(min, max) {
|
|
283
|
+
return ValueArithmetic_1.ValueArithmetic.clamp(this, min, max);
|
|
284
|
+
}
|
|
285
|
+
/**
|
|
286
|
+
* Returns the negation (-this) Value.
|
|
287
|
+
* @returns New Value which is the negation.
|
|
288
|
+
*/
|
|
289
|
+
neg() {
|
|
290
|
+
return ValueArithmetic_1.ValueArithmetic.neg(this);
|
|
291
|
+
}
|
|
292
|
+
/**
|
|
293
|
+
* Returns the sum of the given Values.
|
|
294
|
+
* @param vals Array of Value objects
|
|
295
|
+
* @returns New Value holding their sum.
|
|
296
|
+
*/
|
|
297
|
+
static sum(vals) {
|
|
298
|
+
return ValueArithmetic_1.ValueArithmetic.sum(vals);
|
|
299
|
+
}
|
|
300
|
+
/**
|
|
301
|
+
* Returns the mean of the given Values.
|
|
302
|
+
* @param vals Array of Value objects
|
|
303
|
+
* @returns New Value holding their mean.
|
|
304
|
+
*/
|
|
305
|
+
static mean(vals) {
|
|
306
|
+
return ValueArithmetic_1.ValueArithmetic.mean(vals);
|
|
307
|
+
}
|
|
308
|
+
/**
|
|
309
|
+
* Returns tanh(this).
|
|
310
|
+
* @returns New Value with tanh.
|
|
311
|
+
*/
|
|
312
|
+
tanh() {
|
|
313
|
+
return ValueActivation_1.ValueActivation.tanh(this);
|
|
314
|
+
}
|
|
315
|
+
/**
|
|
316
|
+
* Returns sigmoid(this).
|
|
317
|
+
* @returns New Value with sigmoid.
|
|
318
|
+
*/
|
|
319
|
+
sigmoid() {
|
|
320
|
+
return ValueActivation_1.ValueActivation.sigmoid(this);
|
|
321
|
+
}
|
|
322
|
+
/**
|
|
323
|
+
* Performs a reverse-mode autodiff backward pass from this Value.
|
|
324
|
+
* @param zeroGrad If true, zeroes all grads in the graph before backward
|
|
325
|
+
*/
|
|
326
|
+
backward(zeroGrad = false) {
|
|
327
|
+
// Only allow backward on scalars (not arrays), i.e. single value outputs
|
|
328
|
+
// (output shape check is redundant for this codebase, but keep to scalar-by-convention)
|
|
329
|
+
if (zeroGrad)
|
|
330
|
+
Value.zeroGradTree(this);
|
|
331
|
+
const topo = [];
|
|
332
|
+
const visited = new Set();
|
|
333
|
+
const buildTopo = (v) => {
|
|
334
|
+
if (!visited.has(v)) {
|
|
335
|
+
visited.add(v);
|
|
336
|
+
for (const child of v.prev) {
|
|
337
|
+
buildTopo(child);
|
|
338
|
+
}
|
|
339
|
+
topo.push(v);
|
|
340
|
+
}
|
|
341
|
+
};
|
|
342
|
+
buildTopo(this);
|
|
343
|
+
this.grad = 1;
|
|
344
|
+
for (let i = topo.length - 1; i >= 0; i--) {
|
|
345
|
+
if (topo[i].requiresGrad) {
|
|
346
|
+
topo[i].backwardFn();
|
|
347
|
+
}
|
|
348
|
+
}
|
|
349
|
+
}
|
|
350
|
+
/**
|
|
351
|
+
* Sets all grad fields in the computation tree (from root) to 0.
|
|
352
|
+
* @param root Value to zero tree from
|
|
353
|
+
*/
|
|
354
|
+
static zeroGradTree(root) {
|
|
355
|
+
const visited = new Set();
|
|
356
|
+
const visit = (v) => {
|
|
357
|
+
if (!visited.has(v)) {
|
|
358
|
+
visited.add(v);
|
|
359
|
+
v.grad = 0;
|
|
360
|
+
for (const child of v.prev)
|
|
361
|
+
visit(child);
|
|
362
|
+
}
|
|
363
|
+
};
|
|
364
|
+
visit(root);
|
|
365
|
+
}
|
|
366
|
+
/**
|
|
367
|
+
* Sets all grad fields in all supplied trees to 0.
|
|
368
|
+
* @param vals Values whose trees to zero
|
|
369
|
+
*/
|
|
370
|
+
static zeroGradAll(vals) {
|
|
371
|
+
const visited = new Set();
|
|
372
|
+
for (const v of vals) {
|
|
373
|
+
const visit = (u) => {
|
|
374
|
+
if (!visited.has(u)) {
|
|
375
|
+
visited.add(u);
|
|
376
|
+
u.grad = 0;
|
|
377
|
+
for (const child of u.prev)
|
|
378
|
+
visit(child);
|
|
379
|
+
}
|
|
380
|
+
};
|
|
381
|
+
visit(v);
|
|
382
|
+
}
|
|
383
|
+
}
|
|
384
|
+
/**
|
|
385
|
+
* Internal helper to construct a Value with correct backward fn and grads.
|
|
386
|
+
* @param data Output value data
|
|
387
|
+
* @param left Left operand Value
|
|
388
|
+
* @param right Right operand Value or null
|
|
389
|
+
* @param backwardFnBuilder Function to create backward closure
|
|
390
|
+
* @param label Node label for debugging
|
|
391
|
+
* @returns New Value node
|
|
392
|
+
*/
|
|
393
|
+
static make(data, left, right, backwardFnBuilder, label) {
|
|
394
|
+
const requiresGrad = !Value.no_grad_mode && [left, right].filter(Boolean).some(v => v.requiresGrad);
|
|
395
|
+
const out = new Value(data, label, requiresGrad);
|
|
396
|
+
out.prev = Value.no_grad_mode ? [] : [left, right].filter(Boolean);
|
|
397
|
+
if (requiresGrad) {
|
|
398
|
+
out.backwardFn = backwardFnBuilder(out);
|
|
399
|
+
}
|
|
400
|
+
return out;
|
|
401
|
+
}
|
|
402
|
+
/**
|
|
403
|
+
* Returns string representation for debugging.
|
|
404
|
+
* @returns String summary of Value
|
|
405
|
+
*/
|
|
406
|
+
toString() {
|
|
407
|
+
return `Value(data=${this.data.toFixed(4)}, grad=${this.grad.toFixed(4)}, label=${this.label})`;
|
|
408
|
+
}
|
|
409
|
+
/**
|
|
410
|
+
* Temporarily disables gradient tracking within the callback scope, like torch.no_grad().
|
|
411
|
+
* Restores the previous state after running fn.
|
|
412
|
+
*/
|
|
413
|
+
static withNoGrad(fn) {
|
|
414
|
+
const prev = Value.no_grad_mode;
|
|
415
|
+
Value.no_grad_mode = true;
|
|
416
|
+
try {
|
|
417
|
+
return fn();
|
|
418
|
+
}
|
|
419
|
+
finally {
|
|
420
|
+
Value.no_grad_mode = prev;
|
|
421
|
+
}
|
|
422
|
+
}
|
|
423
|
+
}
|
|
424
|
+
exports.Value = Value;
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
const Value_1 = require("./Value");
|
|
4
|
+
const Losses_1 = require("./Losses");
|
|
5
|
+
// Edge-case coverage for the Losses helpers: empty batches, length
// mismatches, near-saturated probabilities and invalid class indices.
describe('Loss function edge cases', () => {
    it('handles empty arrays', () => {
        // Each loss is called as a member of Losses so its `this` binding is untouched.
        const lossNames = ['mse', 'mae', 'binaryCrossEntropy', 'categoricalCrossEntropy'];
        for (const lossName of lossNames) {
            const result = Losses_1.Losses[lossName]([], []);
            expect(result.data).toBe(0);
        }
    });
    it('throws on mismatched lengths', () => {
        const single = [new Value_1.Value(1)];
        const pair = [new Value_1.Value(1), new Value_1.Value(2)];
        const attempt = () => Losses_1.Losses.mse(single, pair);
        expect(attempt).toThrow();
    });
    it('handles extreme values in binary cross entropy', () => {
        // A prediction just below 1 against a target of 1 must give a small positive loss.
        const prediction = new Value_1.Value(0.999999, 'out', true);
        const expected = new Value_1.Value(1);
        const { data: lossValue } = Losses_1.Losses.binaryCrossEntropy([prediction], [expected]);
        expect(lossValue).toBeGreaterThan(0);
        expect(lossValue).toBeLessThan(0.1);
    });
    it('throws on invalid class indices in categorical cross entropy', () => {
        const logits = [new Value_1.Value(1), new Value_1.Value(2)];
        // Out of range, negative and fractional indices respectively.
        for (const badIndex of [2, -1, 1.5]) {
            expect(() => Losses_1.Losses.categoricalCrossEntropy(logits, [badIndex])).toThrow();
        }
    });
});
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
const Value_1 = require("./Value");
|
|
4
|
+
// Checks that backward() copes with a long chain of operations and that
// zeroGradAll resets gradients across graphs that share no nodes.
describe('Memory management', () => {
    it('handles large computation graphs', () => {
        let node = new Value_1.Value(1, 'x', true);
        let remaining = 100;
        while (remaining-- > 0) {
            // Two operations per step, so the final chain is 200 nodes deep.
            const shifted = node.add(1);
            node = shifted.mul(1.01);
        }
        const runBackward = () => node.backward();
        expect(runBackward).not.toThrow();
    });
    it('zeroGradAll handles multiple disconnected graphs', () => {
        const firstLeaf = new Value_1.Value(1, 'x1', true);
        const firstRoot = firstLeaf.mul(2);
        const secondLeaf = new Value_1.Value(2, 'x2', true);
        const secondRoot = secondLeaf.mul(3);
        // Populate gradients on both graphs before clearing them together.
        firstRoot.backward();
        secondRoot.backward();
        Value_1.Value.zeroGradAll([firstRoot, secondRoot]);
        for (const leaf of [firstLeaf, secondLeaf]) {
            expect(leaf.grad).toBe(0);
        }
    });
});
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
const Value_1 = require("./Value");
|
|
4
|
+
const Optimizers_1 = require("./Optimizers");
|
|
5
|
+
const Losses_1 = require("./Losses");
|
|
6
|
+
// End-to-end training checks: fit tiny models with SGD on hand-written
// datasets and verify that the learned parameters land near the known answer.
describe("can train scalar neural networks on minimal problems", () => {
    it("1. learns linear regression (y = 2x + 3) with SGD", () => {
        const weight = new Value_1.Value(Math.random(), "w", true);
        const bias = new Value_1.Value(Math.random(), "b", true);
        const samples = [
            { x: 1, y: 5 },
            { x: 2, y: 7 },
            { x: 3, y: 9 },
        ];
        const optimizer = new Optimizers_1.SGD([weight, bias], { learningRate: 0.1 });
        let epochsLeft = 300;
        while (epochsLeft-- > 0) {
            const predictions = [];
            const expected = [];
            samples.forEach(({ x: input, y: answer }) => {
                const inputValue = new Value_1.Value(input, "x");
                predictions.push(weight.mul(inputValue).add(bias));
                expected.push(new Value_1.Value(answer));
            });
            const loss = Losses_1.Losses.mse(predictions, expected);
            // Stop early once the fit is good enough.
            if (loss.data < 1e-4) {
                break;
            }
            // Parameters persist across epochs, so their grads are reset by hand.
            weight.grad = 0;
            bias.grad = 0;
            loss.backward();
            optimizer.step();
        }
        expect(weight.data).toBeCloseTo(2, 1);
        expect(bias.data).toBeCloseTo(3, 1);
    });
    it("2. learns quadratic fit (y = x^2) with SGD", () => {
        const quadratic = new Value_1.Value(Math.random(), "a", true);
        const linear = new Value_1.Value(Math.random(), "b", true);
        const constant = new Value_1.Value(Math.random(), "c", true);
        const coefficients = [quadratic, linear, constant];
        const samples = [
            { x: -1, y: 1 },
            { x: 0, y: 0 },
            { x: 2, y: 4 },
            { x: 3, y: 9 },
        ];
        const optimizer = new Optimizers_1.SGD(coefficients, { learningRate: 0.01 });
        let epochsLeft = 400;
        while (epochsLeft-- > 0) {
            const predictions = [];
            const expected = [];
            samples.forEach(({ x: input, y: answer }) => {
                const inputValue = new Value_1.Value(input);
                const squaredTerm = quadratic.mul(inputValue.square());
                const linearTerm = linear.mul(inputValue);
                predictions.push(squaredTerm.add(linearTerm).add(constant));
                expected.push(new Value_1.Value(answer));
            });
            const loss = Losses_1.Losses.mse(predictions, expected);
            if (loss.data < 1e-4) {
                break;
            }
            for (const coefficient of coefficients) {
                coefficient.grad = 0;
            }
            loss.backward();
            optimizer.step();
        }
        expect(quadratic.data).toBeCloseTo(1, 1);
        expect(Math.abs(linear.data)).toBeLessThan(0.5);
        expect(Math.abs(constant.data)).toBeLessThan(0.5);
    });
    /*
    // This is hard to get to work reliably, I believe it's a difficult problem to solve!?
    it("3. learns XOR with tiny MLP (2-2-1) with SGD", () => {
      function mlp(x1: Value, x2: Value, params: Value[]): Value {
        const [w1, w2, w3, w4, b1, b2, v1, v2, c] = params;
        const h1 = w1.mul(x1).add(w2.mul(x2)).add(b1).tanh();
        const h2 = w3.mul(x1).add(w4.mul(x2)).add(b2).tanh();
        return v1.mul(h1).add(v2.mul(h2)).add(c).sigmoid();
      }
      let params = Array.from({ length: 9 }, (_, i) => new Value(Math.random() - 0.5, "p" + i, true));
      const data = [
        { x: [0, 0], y: 0 },
        { x: [0, 1], y: 1 },
        { x: [1, 0], y: 1 },
        { x: [1, 1], y: 0 },
      ];
      const opt = new SGD(params, { learningRate: 0.01 });
      for (let epoch = 0; epoch < 5000; ++epoch) {
        let preds: Value[] = [];
        let targets: Value[] = [];
        for (const ex of data) {
          const x1 = new Value(ex.x[0]);
          const x2 = new Value(ex.x[1]);
          const pred = mlp(x1, x2, params);
          preds.push(pred);
          targets.push(new Value(ex.y));
        }
        let loss = binaryCrossEntropy(preds, targets);
        if (loss.data < 1e-3) break;
        for (const p of params) p.grad = 0;
        loss.backward();
        opt.step();
      }
      const out00 = mlp(new Value(0), new Value(0), params).data;
      const out01 = mlp(new Value(0), new Value(1), params).data;
      const out10 = mlp(new Value(1), new Value(0), params).data;
      const out11 = mlp(new Value(1), new Value(1), params).data;
      expect((out00 < 0.4 || out00 > 0.6)).toBe(true);
      expect(out01).toBeGreaterThan(0.6);
      expect(out10).toBeGreaterThan(0.6);
      expect(out11).toBeLessThan(0.4);
    });*/
});
|