catniff 0.1.9 → 0.2.0

package/index.js CHANGED
@@ -1,4 +1,3 @@
  module.exports = {
-     ...require("./dist/autograd"),
-     ...require("./dist/tensor")
+     ...require("./dist/core")
  };
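
The entry point now spreads a single ./dist/core module instead of merging ./dist/autograd and ./dist/tensor, so the package-root surface is unchanged as long as core re-exports the same names. A minimal consumer sketch under that assumption (Node is a confirmed export of the deleted 0.1.9 modules below):

const { Node } = require("catniff");

const x = new Node(2);
const y = x.mul(3).add(1); // y = 3x + 1
y.backward();              // populates grads along the graph
console.log(x.grad);       // expected: 3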
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
    "name": "catniff",
-   "version": "0.1.9",
+   "version": "0.2.0",
    "description": "A cute autograd engine for Javascript",
    "main": "index.js",
    "scripts": {
@@ -25,7 +25,8 @@
      "framework",
      "neural-network",
      "machine-learning",
-     "deep-learning"
+     "deep-learning",
+     "micrograd"
    ],
    "author": "nguyenphuminh",
    "license": "GPL-3.0",
package/dist/autograd.d.ts DELETED
@@ -1,112 +0,0 @@
- import { Tensor } from "./tensor";
- export declare enum OP {
-     NONE = 0,
-     ADD = 1,
-     SUB = 2,
-     MUL = 3,
-     POW = 4,
-     DIV = 5,
-     GE = 6,
-     LE = 7,
-     GT = 8,
-     LT = 9,
-     EQ = 10,
-     LOGICALAND = 11,
-     LOGICALOR = 12,
-     LOGICALXOR = 13,
-     LOGICALNOT = 14,
-     BITWISEAND = 15,
-     BITWISEOR = 16,
-     BITWISEXOR = 17,
-     BITWISENOT = 18,
-     BITWISELEFTSHIFT = 19,
-     BITWISERIGHTSHIFT = 20,
-     NEG = 21,
-     ABS = 22,
-     SIGN = 23,
-     SIN = 24,
-     COS = 25,
-     TAN = 26,
-     ASIN = 27,
-     ACOS = 28,
-     ATAN = 29,
-     SINH = 30,
-     COSH = 31,
-     ASINH = 32,
-     ACOSH = 33,
-     ATANH = 34,
-     SQRT = 35,
-     EXP = 36,
-     LOG = 37,
-     LOG2 = 38,
-     LOG10 = 39,
-     LOG1P = 40,
-     RELU = 41,
-     SIGMOID = 42,
-     TANH = 43,
-     T = 44,
-     DOT = 45,
-     MM = 46,
-     MV = 47,
-     MATMUL = 48
- }
- export declare class Node {
-     value: Tensor;
-     shape: number[];
-     grad: Tensor;
-     children: Node[];
-     op: OP;
-     feedBackward: Function;
-     constructor(value: Tensor, children?: Node[], op?: OP);
-     add(other: Node | Tensor): Node;
-     sub(other: Node | Tensor): Node;
-     mul(other: Node | Tensor): Node;
-     pow(other: Node | Tensor): Node;
-     div(other: Node | Tensor): Node;
-     ge(other: Node | Tensor): Node;
-     le(other: Node | Tensor): Node;
-     gt(other: Node | Tensor): Node;
-     lt(other: Node | Tensor): Node;
-     eq(other: Node | Tensor): Node;
-     logicalAnd(other: Node | Tensor): Node;
-     logicalOr(other: Node | Tensor): Node;
-     logicalXor(other: Node | Tensor): Node;
-     logicalNot(): Node;
-     bitwiseAnd(other: Node | Tensor): Node;
-     bitwiseOr(other: Node | Tensor): Node;
-     bitwiseXor(other: Node | Tensor): Node;
-     bitwiseNot(): Node;
-     bitwiseLeftShift(other: Node | Tensor): Node;
-     bitwiseRightShift(other: Node | Tensor): Node;
-     neg(): Node;
-     abs(): Node;
-     sign(): Node;
-     sin(): Node;
-     cos(): Node;
-     tan(): Node;
-     asin(): Node;
-     acos(): Node;
-     atan(): Node;
-     sinh(): Node;
-     cosh(): Node;
-     asinh(): Node;
-     acosh(): Node;
-     atanh(): Node;
-     sqrt(): Node;
-     exp(): Node;
-     log(): Node;
-     log2(): Node;
-     log10(): Node;
-     log1p(): Node;
-     relu(): Node;
-     sigmoid(): Node;
-     tanh(): Node;
-     t(): Node;
-     dot(other: Node | Tensor): Node;
-     mm(other: Node | Tensor): Node;
-     mv(other: Node | Tensor): Node;
-     matmul(other: Node | Tensor): Node;
-     backward(): void;
-     static forceNode(value: Node | Tensor): Node;
-     static addGrad(node: Node, accumGrad: Tensor): void;
- }
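
These declarations describe the 0.1.9 graph API: every operator returns a new Node, and binary operators accept either a Node or a raw Tensor (wrapped via Node.forceNode). A hedged sketch of chaining that surface, assuming plain numbers and nested arrays are what the compiled module below treats as tensors:

const { Node } = require("catniff"); // 0.1.9 surface

const x = new Node(0.5);
const z = x.sin().mul(x.exp()); // z = sin(x) * e^x
z.backward();
// Product + chain rule: dz/dx = cos(x)*e^x + sin(x)*e^x at x = 0.5
console.log(x.grad);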
package/dist/autograd.js DELETED
@@ -1,547 +0,0 @@
- "use strict";
- Object.defineProperty(exports, "__esModule", { value: true });
- exports.Node = exports.OP = void 0;
- const tensor_1 = require("./tensor");
- const { add, sub, mul, pow, div, gt, lt, ge, le, eq, logicalAnd, logicalOr, logicalXor, logicalNot, bitwiseAnd, bitwiseOr, bitwiseXor, bitwiseNot, bitwiseLeftShift, bitwiseRightShift, neg, abs, sign, sin, cos, tan, asin, acos, atan, sinh, cosh, asinh, acosh, atanh, sqrt, exp, log, log2, log10, log1p, relu, sigmoid, tanh, t, dot, mm, mv, matmul } = tensor_1.TensorMath;
- var OP;
- (function (OP) {
-     OP[OP["NONE"] = 0] = "NONE";
-     OP[OP["ADD"] = 1] = "ADD";
-     OP[OP["SUB"] = 2] = "SUB";
-     OP[OP["MUL"] = 3] = "MUL";
-     OP[OP["POW"] = 4] = "POW";
-     OP[OP["DIV"] = 5] = "DIV";
-     OP[OP["GE"] = 6] = "GE";
-     OP[OP["LE"] = 7] = "LE";
-     OP[OP["GT"] = 8] = "GT";
-     OP[OP["LT"] = 9] = "LT";
-     OP[OP["EQ"] = 10] = "EQ";
-     OP[OP["LOGICALAND"] = 11] = "LOGICALAND";
-     OP[OP["LOGICALOR"] = 12] = "LOGICALOR";
-     OP[OP["LOGICALXOR"] = 13] = "LOGICALXOR";
-     OP[OP["LOGICALNOT"] = 14] = "LOGICALNOT";
-     OP[OP["BITWISEAND"] = 15] = "BITWISEAND";
-     OP[OP["BITWISEOR"] = 16] = "BITWISEOR";
-     OP[OP["BITWISEXOR"] = 17] = "BITWISEXOR";
-     OP[OP["BITWISENOT"] = 18] = "BITWISENOT";
-     OP[OP["BITWISELEFTSHIFT"] = 19] = "BITWISELEFTSHIFT";
-     OP[OP["BITWISERIGHTSHIFT"] = 20] = "BITWISERIGHTSHIFT";
-     OP[OP["NEG"] = 21] = "NEG";
-     OP[OP["ABS"] = 22] = "ABS";
-     OP[OP["SIGN"] = 23] = "SIGN";
-     OP[OP["SIN"] = 24] = "SIN";
-     OP[OP["COS"] = 25] = "COS";
-     OP[OP["TAN"] = 26] = "TAN";
-     OP[OP["ASIN"] = 27] = "ASIN";
-     OP[OP["ACOS"] = 28] = "ACOS";
-     OP[OP["ATAN"] = 29] = "ATAN";
-     OP[OP["SINH"] = 30] = "SINH";
-     OP[OP["COSH"] = 31] = "COSH";
-     OP[OP["ASINH"] = 32] = "ASINH";
-     OP[OP["ACOSH"] = 33] = "ACOSH";
-     OP[OP["ATANH"] = 34] = "ATANH";
-     OP[OP["SQRT"] = 35] = "SQRT";
-     OP[OP["EXP"] = 36] = "EXP";
-     OP[OP["LOG"] = 37] = "LOG";
-     OP[OP["LOG2"] = 38] = "LOG2";
-     OP[OP["LOG10"] = 39] = "LOG10";
-     OP[OP["LOG1P"] = 40] = "LOG1P";
-     OP[OP["RELU"] = 41] = "RELU";
-     OP[OP["SIGMOID"] = 42] = "SIGMOID";
-     OP[OP["TANH"] = 43] = "TANH";
-     OP[OP["T"] = 44] = "T";
-     OP[OP["DOT"] = 45] = "DOT";
-     OP[OP["MM"] = 46] = "MM";
-     OP[OP["MV"] = 47] = "MV";
-     OP[OP["MATMUL"] = 48] = "MATMUL";
- })(OP || (exports.OP = OP = {}));
- class Node {
-     value;
-     shape;
-     grad;
-     children;
-     op;
-     feedBackward;
-     constructor(value, children = [], op = OP.NONE) {
-         this.value = value;
-         this.shape = tensor_1.TensorMath.getShape(value);
-         this.grad = tensor_1.TensorMath.create(0, this.shape);
-         this.children = children;
-         this.op = op;
-         this.feedBackward = () => { };
-     }
-     add(other) {
-         other = Node.forceNode(other);
-         const out = new Node(add(this.value, other.value), [this, other], OP.ADD);
-         out.feedBackward = () => {
-             // x + y d/dx = 1, note that we apply the chain rule continuously so out.grad is multiplied into our derivative
-             Node.addGrad(this, out.grad);
-             // x + y d/dy = 1
-             Node.addGrad(other, out.grad);
-         };
-         return out;
-     }
-     sub(other) {
-         other = Node.forceNode(other);
-         const out = new Node(sub(this.value, other.value), [this, other], OP.SUB);
-         out.feedBackward = () => {
-             // x - y d/dx = 1
-             Node.addGrad(this, out.grad);
-             // x - y d/dy = -1
-             Node.addGrad(other, neg(out.grad));
-         };
-         return out;
-     }
-     mul(other) {
-         other = Node.forceNode(other);
-         const out = new Node(mul(this.value, other.value), [this, other], OP.MUL);
-         out.feedBackward = () => {
-             // x * y d/dx = y
-             Node.addGrad(this, mul(out.grad, other.value));
-             // x * y d/dy = x
-             Node.addGrad(other, mul(out.grad, this.value));
-         };
-         return out;
-     }
-     pow(other) {
-         if (other instanceof Node) {
-             const out = new Node(pow(this.value, other.value), [this, other], OP.POW);
-             out.feedBackward = () => {
-                 // x^a d/dx = a*x^(a-1)
-                 Node.addGrad(this, mul(out.grad, mul(other.value, pow(this.value, sub(other.value, 1)))));
-                 // x^a d/da = x^a*lnx
-                 Node.addGrad(other, mul(out.grad, mul(pow(this.value, other.value), log(this.value))));
-             };
-             return out;
-         }
-         const out = new Node(pow(this.value, other), [this], OP.POW);
-         out.feedBackward = () => {
-             Node.addGrad(this, mul(out.grad, mul(other, pow(this.value, sub(other, 1)))));
-         };
-         return out;
-     }
-     div(other) {
-         other = Node.forceNode(other);
-         const out = new Node(div(this.value, other.value), [this, other], OP.DIV);
-         out.feedBackward = () => {
-             // x/y d/dx = 1/y
-             Node.addGrad(this, div(out.grad, other.value));
-             // x/y d/dy = -x/y^2
-             Node.addGrad(other, mul(out.grad, div(neg(this.value), pow(other.value, 2))));
-         };
-         return out;
-     }
-     ge(other) {
-         other = Node.forceNode(other);
-         const out = new Node(ge(this.value, other.value), [this, other], OP.GE);
-         out.feedBackward = () => {
-             // We consider the derivative of ge to be 0, which does not add to current grad, so this function is just empty
-         };
-         return out;
-     }
-     le(other) {
-         other = Node.forceNode(other);
-         const out = new Node(le(this.value, other.value), [this, other], OP.LE);
-         out.feedBackward = () => {
-             // We consider the derivative of le to be 0, which does not add to current grad, so this function is just empty
-         };
-         return out;
-     }
-     gt(other) {
-         other = Node.forceNode(other);
-         const out = new Node(gt(this.value, other.value), [this, other], OP.GT);
-         out.feedBackward = () => {
-             // We consider the derivative of gt to be 0, which does not add to current grad, so this function is just empty
-         };
-         return out;
-     }
-     lt(other) {
-         other = Node.forceNode(other);
-         const out = new Node(lt(this.value, other.value), [this, other], OP.LT);
-         out.feedBackward = () => {
-             // We consider the derivative of lt to be 0, which does not add to current grad, so this function is just empty
-         };
-         return out;
-     }
-     eq(other) {
-         other = Node.forceNode(other);
-         const out = new Node(eq(this.value, other.value), [this, other], OP.EQ);
-         out.feedBackward = () => {
-             // We consider the derivative of eq to be 0, which does not add to current grad, so this function is just empty
-         };
-         return out;
-     }
-     logicalAnd(other) {
-         other = Node.forceNode(other);
-         const out = new Node(logicalAnd(this.value, other.value), [this, other], OP.LOGICALAND);
-         out.feedBackward = () => {
-             // We consider the derivative of this to be 0, which does not add to current grad, so this function is just empty
-         };
-         return out;
-     }
-     logicalOr(other) {
-         other = Node.forceNode(other);
-         const out = new Node(logicalOr(this.value, other.value), [this, other], OP.LOGICALOR);
-         out.feedBackward = () => {
-             // We consider the derivative of this to be 0, which does not add to current grad, so this function is just empty
-         };
-         return out;
-     }
-     logicalXor(other) {
-         other = Node.forceNode(other);
-         const out = new Node(logicalXor(this.value, other.value), [this, other], OP.LOGICALXOR);
-         out.feedBackward = () => {
-             // We consider the derivative of this to be 0, which does not add to current grad, so this function is just empty
-         };
-         return out;
-     }
-     logicalNot() {
-         const out = new Node(logicalNot(this.value), [this], OP.LOGICALNOT);
-         out.feedBackward = () => {
-             // We consider the derivative of this to be 0, which does not add to current grad, so this function is just empty
-         };
-         return out;
-     }
-     bitwiseAnd(other) {
-         other = Node.forceNode(other);
-         const out = new Node(bitwiseAnd(this.value, other.value), [this, other], OP.BITWISEAND);
-         out.feedBackward = () => {
-             // We consider the derivative of this to be 0, which does not add to current grad, so this function is just empty
-         };
-         return out;
-     }
-     bitwiseOr(other) {
-         other = Node.forceNode(other);
-         const out = new Node(bitwiseOr(this.value, other.value), [this, other], OP.BITWISEOR);
-         out.feedBackward = () => {
-             // We consider the derivative of this to be 0, which does not add to current grad, so this function is just empty
-         };
-         return out;
-     }
-     bitwiseXor(other) {
-         other = Node.forceNode(other);
-         const out = new Node(bitwiseXor(this.value, other.value), [this, other], OP.BITWISEXOR);
-         out.feedBackward = () => {
-             // We consider the derivative of this to be 0, which does not add to current grad, so this function is just empty
-         };
-         return out;
-     }
-     bitwiseNot() {
-         const out = new Node(bitwiseNot(this.value), [this], OP.BITWISENOT);
-         out.feedBackward = () => {
-             // We consider the derivative of this to be 0, which does not add to current grad, so this function is just empty
-         };
-         return out;
-     }
-     bitwiseLeftShift(other) {
-         other = Node.forceNode(other);
-         const out = new Node(bitwiseLeftShift(this.value, other.value), [this, other], OP.BITWISELEFTSHIFT);
-         out.feedBackward = () => {
-             // We consider the derivative of this to be 0, which does not add to current grad, so this function is just empty
-         };
-         return out;
-     }
-     bitwiseRightShift(other) {
-         other = Node.forceNode(other);
-         const out = new Node(bitwiseRightShift(this.value, other.value), [this, other], OP.BITWISERIGHTSHIFT);
-         out.feedBackward = () => {
-             // We consider the derivative of this to be 0, which does not add to current grad, so this function is just empty
-         };
-         return out;
-     }
-     neg() {
-         const out = new Node(neg(this.value), [this], OP.NEG);
-         out.feedBackward = () => {
-             // -x d/dx = -1
-             Node.addGrad(this, neg(out.grad));
-         };
-         return out;
-     }
-     abs() {
-         const out = new Node(abs(this.value), [this], OP.ABS);
-         out.feedBackward = () => {
-             // |x| d/dx = sign(x)
-             Node.addGrad(this, mul(out.grad, sign(this.value)));
-         };
-         return out;
-     }
-     sign() {
-         const out = new Node(sign(this.value), [this], OP.SIGN);
-         out.feedBackward = () => {
-             // We consider the derivative of sign to be 0, which does not add to current grad, so this function is just empty
-         };
-         return out;
-     }
-     sin() {
-         const out = new Node(sin(this.value), [this], OP.SIN);
-         out.feedBackward = () => {
-             // sinx d/dx = cosx
-             Node.addGrad(this, mul(out.grad, cos(this.value)));
-         };
-         return out;
-     }
-     cos() {
-         const out = new Node(cos(this.value), [this], OP.COS);
-         out.feedBackward = () => {
-             // cosx d/dx = -sinx
-             Node.addGrad(this, mul(out.grad, neg(sin(this.value))));
-         };
-         return out;
-     }
-     tan() {
-         const tanResult = tan(this.value);
-         const out = new Node(tanResult, [this], OP.TAN);
-         out.feedBackward = () => {
-             // tanx d/dx = 1+(tanx)^2
-             Node.addGrad(this, mul(out.grad, add(1, pow(tanResult, 2))));
-         };
-         return out;
-     }
-     asin() {
-         const out = new Node(asin(this.value), [this], OP.ASIN);
-         out.feedBackward = () => {
-             // asinx d/dx = 1/sqrt(1-x^2)
-             Node.addGrad(this, div(out.grad, sqrt(sub(1, pow(this.value, 2)))));
-         };
-         return out;
-     }
-     acos() {
-         const out = new Node(acos(this.value), [this], OP.ACOS);
-         out.feedBackward = () => {
-             // acosx d/dx = -1/sqrt(1-x^2)
-             Node.addGrad(this, neg(div(out.grad, sqrt(sub(1, pow(this.value, 2))))));
-         };
-         return out;
-     }
-     atan() {
-         const out = new Node(atan(this.value), [this], OP.ATAN);
-         out.feedBackward = () => {
-             // atanx d/dx = 1/(1+x^2)
-             Node.addGrad(this, div(out.grad, add(1, pow(this.value, 2))));
-         };
-         return out;
-     }
-     sinh() {
-         const out = new Node(sinh(this.value), [this], OP.SINH);
-         out.feedBackward = () => {
-             // sinhx d/dx = coshx
-             Node.addGrad(this, mul(out.grad, cosh(this.value)));
-         };
-         return out;
-     }
-     cosh() {
-         const out = new Node(cosh(this.value), [this], OP.COSH);
-         out.feedBackward = () => {
-             // coshx d/dx = sinhx
-             Node.addGrad(this, mul(out.grad, sinh(this.value)));
-         };
-         return out;
-     }
-     asinh() {
-         const out = new Node(asinh(this.value), [this], OP.ASINH);
-         out.feedBackward = () => {
-             // asinhx d/dx = 1/sqrt(1+x^2)
-             Node.addGrad(this, div(out.grad, sqrt(add(1, pow(this.value, 2)))));
-         };
-         return out;
-     }
-     acosh() {
-         const out = new Node(acosh(this.value), [this], OP.ACOSH);
-         out.feedBackward = () => {
-             // acoshx d/dx = 1/(sqrt(x-1)*sqrt(x+1))
-             Node.addGrad(this, div(out.grad, mul(sqrt(sub(this.value, 1)), sqrt(add(this.value, 1)))));
-         };
-         return out;
-     }
-     atanh() {
-         const out = new Node(atanh(this.value), [this], OP.ATANH);
-         out.feedBackward = () => {
-             // atanhx d/dx = 1/(1-x^2)
-             Node.addGrad(this, div(out.grad, sub(1, pow(this.value, 2))));
-         };
-         return out;
-     }
-     sqrt() {
-         const sqrtResult = sqrt(this.value);
-         const out = new Node(sqrtResult, [this], OP.SQRT);
-         out.feedBackward = () => {
-             // sqrt(x) d/dx = 1/(2*sqrt(x))
-             Node.addGrad(this, div(out.grad, mul(2, sqrtResult)));
-         };
-         return out;
-     }
-     exp() {
-         const expResult = exp(this.value);
-         const out = new Node(expResult, [this], OP.EXP);
-         out.feedBackward = () => {
-             // e^x d/dx = e^x
-             Node.addGrad(this, mul(out.grad, expResult));
-         };
-         return out;
-     }
-     log() {
-         const out = new Node(log(this.value), [this], OP.LOG);
-         out.feedBackward = () => {
-             // lnx d/dx = 1/x
-             Node.addGrad(this, div(out.grad, this.value));
-         };
-         return out;
-     }
-     log2() {
-         const out = new Node(log2(this.value), [this], OP.LOG2);
-         out.feedBackward = () => {
-             // log2(x) d/dx = 1/(xln2)
-             Node.addGrad(this, div(out.grad, mul(this.value, Math.log(2))));
-         };
-         return out;
-     }
-     log10() {
-         const out = new Node(log10(this.value), [this], OP.LOG10);
-         out.feedBackward = () => {
-             // log10(x) d/dx = 1/(xln10)
-             Node.addGrad(this, div(out.grad, mul(this.value, Math.log(10))));
-         };
-         return out;
-     }
-     log1p() {
-         const out = new Node(log1p(this.value), [this], OP.LOG1P);
-         out.feedBackward = () => {
-             // ln(1+x) d/dx = 1/(1+x)
-             Node.addGrad(this, div(out.grad, add(this.value, 1)));
-         };
-         return out;
-     }
-     relu() {
-         const out = new Node(relu(this.value), [this], OP.RELU);
-         out.feedBackward = () => {
-             // relu(x) d/dx = 1 where x >= 0, 0 elsewhere
-             Node.addGrad(this, mul(out.grad, ge(this.value, 0)));
-         };
-         return out;
-     }
-     sigmoid() {
-         const sigmoidResult = sigmoid(this.value);
-         const out = new Node(sigmoidResult, [this], OP.SIGMOID);
-         out.feedBackward = () => {
-             // sigmoid(x) d/dx = sigmoid(x)*(1-sigmoid(x))
-             Node.addGrad(this, mul(mul(out.grad, sigmoidResult), sub(1, sigmoidResult)));
-         };
-         return out;
-     }
-     tanh() {
-         const tanhResult = tanh(this.value);
-         const out = new Node(tanhResult, [this], OP.TANH);
-         out.feedBackward = () => {
-             // tanhx d/dx = 1-(tanhx)^2
-             Node.addGrad(this, mul(out.grad, sub(1, mul(tanhResult, tanhResult))));
-         };
-         return out;
-     }
-     t() {
-         const out = new Node(t(this.value), [this], OP.T);
-         out.feedBackward = () => {
-             // The gradient of a transpose is the transpose of the incoming gradient
-             Node.addGrad(this, t(out.grad));
-         };
-         return out;
-     }
-     dot(other) {
-         other = Node.forceNode(other);
-         const out = new Node(dot(this.value, other.value), [this, other], OP.DOT);
-         out.feedBackward = () => {
-             // x . y d/dx = y, d/dy = x
-             Node.addGrad(this, mul(out.grad, other.value));
-             Node.addGrad(other, mul(out.grad, this.value));
-         };
-         return out;
-     }
-     mm(other) {
-         other = Node.forceNode(other);
-         const out = new Node(mm(this.value, other.value), [this, other], OP.MM);
-         out.feedBackward = () => {
-             // Z = XY => dX = dZ * Y^T, dY = X^T * dZ
-             Node.addGrad(this, mm(out.grad, t(other.value)));
-             Node.addGrad(other, mm(t(this.value), out.grad));
-         };
-         return out;
-     }
-     mv(other) {
-         other = Node.forceNode(other);
-         const out = new Node(mv(this.value, other.value), [this, other], OP.MV);
-         out.feedBackward = () => {
-             // Lift the incoming gradient to a column matrix for the outer product with the vector operand
-             const outGradMat = out.grad.map(el => [el]);
-             Node.addGrad(this, mm(outGradMat, [other.value]));
-             Node.addGrad(other, mv(t(this.value), out.grad));
-         };
-         return out;
-     }
-     matmul(other) {
-         other = Node.forceNode(other);
-         const out = new Node(matmul(this.value, other.value), [this, other], OP.MATMUL);
-         // Dispatch on operand ranks: vector-vector, vector-matrix, matrix-vector, matrix-matrix
-         if (this.shape.length === 1 && other.shape.length === 1) {
-             out.feedBackward = () => {
-                 Node.addGrad(this, mul(out.grad, other.value));
-                 Node.addGrad(other, mul(out.grad, this.value));
-             };
-         }
-         else if (this.shape.length === 1 && other.shape.length === 2) {
-             out.feedBackward = () => {
-                 Node.addGrad(this, matmul(out.grad, t(other.value)));
-                 Node.addGrad(other, mm(t([this.value]), [out.grad]));
-             };
-         }
-         else if (this.shape.length === 2 && other.shape.length === 1) {
-             out.feedBackward = () => {
-                 const outGradMat = out.grad.map(el => [el]);
-                 Node.addGrad(this, mm(outGradMat, [other.value]));
-                 Node.addGrad(other, mv(t(this.value), out.grad));
-             };
-         }
-         else if (this.shape.length === 2 && other.shape.length === 2) {
-             out.feedBackward = () => {
-                 Node.addGrad(this, mm(out.grad, t(other.value)));
-                 Node.addGrad(other, mm(t(this.value), out.grad));
-             };
-         }
-         return out;
-     }
-     backward() {
-         // Build topological order
-         const topo = [];
-         const visited = new Set();
-         function build(node) {
-             if (!visited.has(node)) {
-                 visited.add(node);
-                 node.grad = tensor_1.TensorMath.create(0, node.shape);
-                 for (let child of node.children)
-                     build(child);
-                 topo.push(node);
-             }
-         }
-         build(this);
-         // Feed backward to calculate gradients
-         this.grad = tensor_1.TensorMath.create(1, this.shape); // Derivative of itself with respect to itself
-         for (let index = topo.length - 1; index > -1; index--) {
-             topo[index].feedBackward();
-         }
-     }
-     static forceNode(value) {
-         if (value instanceof Node)
-             return value;
-         return new Node(value);
-     }
-     static addGrad(node, accumGrad) {
-         // Sum the incoming gradient over broadcast axes so it matches the node's shape, then accumulate it
-         const axesToSqueeze = [];
-         const axesToReduce = [];
-         const shape = node.shape;
-         const gradShape = tensor_1.TensorMath.getShape(accumGrad);
-         const paddedDims = gradShape.length - shape.length;
-         for (let i = 0; i < paddedDims; i++) {
-             axesToReduce.push(i);
-             axesToSqueeze.push(i);
-         }
-         for (let i = 0; i < shape.length; i++) {
-             if (shape[i] === 1 && gradShape[i + paddedDims] > 1) {
-                 axesToReduce.push(i + paddedDims);
-             }
-         }
-         const reducedGrad = tensor_1.TensorMath.sum(accumGrad, axesToReduce, true);
-         const squeezedGrad = tensor_1.TensorMath.squeeze(reducedGrad, axesToSqueeze);
-         node.grad = add(squeezedGrad, node.grad);
-     }
- }
- exports.Node = Node;
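
For reference, the deleted reverse pass above seeds this.grad with ones, walks the topologically sorted graph in reverse, and lets addGrad sum incoming gradients over broadcast axes. A small sketch of that broadcasting behavior, assuming TensorMath broadcasts a scalar across a matrix (which addGrad's reduction logic implies):

const { Node } = require("catniff"); // 0.1.9 surface

const m = new Node([[1, 2], [3, 4]]); // shape [2, 2]
const s = new Node(10);               // scalar, broadcast over m
const out = m.mul(s);
out.backward();                       // out.grad seeded with ones of shape [2, 2]

console.log(m.grad); // [[10, 10], [10, 10]]  (d out / d m = s)
console.log(s.grad); // summed over the broadcast axes: 1 + 2 + 3 + 4 = 10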