catniff 0.1.8 → 0.1.9
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- package/dist/autograd.d.ts +27 -23
- package/dist/autograd.js +49 -7
- package/dist/tensor.d.ts +3 -1
- package/dist/tensor.js +37 -10
- package/package.json +1 -1
package/dist/autograd.d.ts
CHANGED
@@ -45,8 +45,10 @@ export declare enum OP {
     SIGMOID = 42,
     TANH = 43,
     T = 44,
-    …
-    …
+    DOT = 45,
+    MM = 46,
+    MV = 47,
+    MATMUL = 48
 }
 export declare class Node {
     value: Tensor;
@@ -56,26 +58,26 @@ export declare class Node {
     op: OP;
     feedBackward: Function;
     constructor(value: Tensor, children?: Node[], op?: OP);
-    add(other: Node | …
-    sub(other: Node | …
-    mul(other: Node | …
-    pow(other: Node | …
-    div(other: Node | …
-    ge(other: Node | …
-    le(other: Node | …
-    gt(other: Node | …
-    lt(other: Node | …
-    eq(other: Node | …
-    logicalAnd(other: Node | …
-    logicalOr(other: Node | …
-    logicalXor(other: Node | …
+    add(other: Node | Tensor): Node;
+    sub(other: Node | Tensor): Node;
+    mul(other: Node | Tensor): Node;
+    pow(other: Node | Tensor): Node;
+    div(other: Node | Tensor): Node;
+    ge(other: Node | Tensor): Node;
+    le(other: Node | Tensor): Node;
+    gt(other: Node | Tensor): Node;
+    lt(other: Node | Tensor): Node;
+    eq(other: Node | Tensor): Node;
+    logicalAnd(other: Node | Tensor): Node;
+    logicalOr(other: Node | Tensor): Node;
+    logicalXor(other: Node | Tensor): Node;
     logicalNot(): Node;
-    bitwiseAnd(other: Node | …
-    bitwiseOr(other: Node | …
-    bitwiseXor(other: Node | …
+    bitwiseAnd(other: Node | Tensor): Node;
+    bitwiseOr(other: Node | Tensor): Node;
+    bitwiseXor(other: Node | Tensor): Node;
     bitwiseNot(): Node;
-    bitwiseLeftShift(other: Node | …
-    bitwiseRightShift(other: Node | …
+    bitwiseLeftShift(other: Node | Tensor): Node;
+    bitwiseRightShift(other: Node | Tensor): Node;
     neg(): Node;
     abs(): Node;
     sign(): Node;
@@ -100,9 +102,11 @@ export declare class Node {
     sigmoid(): Node;
     tanh(): Node;
     t(): Node;
-    …
-    …
+    dot(other: Node | Tensor): Node;
+    mm(other: Node | Tensor): Node;
+    mv(other: Node | Tensor): Node;
+    matmul(other: Node | Tensor): Node;
     backward(): void;
-    static forceNode(value: Node | …
+    static forceNode(value: Node | Tensor): Node;
     static addGrad(node: Node, accumGrad: Tensor): void;
 }
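The net effect in the typings: Node gains dot, mm, mv, and matmul, and the binary-op signatures now spell out Node | Tensor. A minimal usage sketch, assuming Node is exported from the package root and that Tensor values are plain nested number arrays (both assumptions; the diff only shows the dist typings):

    // Hypothetical usage of the new matrix ops; the import path is an assumption.
    import { Node } from "catniff";

    const x = new Node([1, 2, 3]);                // 1D tensor in an autograd node
    const w = new Node([[1, 0], [0, 1], [1, 1]]); // 3x2 matrix

    const s = x.dot(x);     // scalar: 1 + 4 + 9 = 14
    const y = x.matmul(w);  // 1D x 2D, row-vector-style product

    s.backward();           // d(x·x)/dx = 2x, so x should accumulate grad [2, 4, 6]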
package/dist/autograd.js
CHANGED
@@ -2,7 +2,7 @@
 Object.defineProperty(exports, "__esModule", { value: true });
 exports.Node = exports.OP = void 0;
 const tensor_1 = require("./tensor");
-const { add, sub, mul, pow, div, gt, lt, ge, le, eq, logicalAnd, logicalOr, logicalXor, logicalNot, bitwiseAnd, bitwiseOr, bitwiseXor, bitwiseNot, bitwiseLeftShift, bitwiseRightShift, neg, abs, sign, sin, cos, tan, asin, acos, atan, sinh, cosh, asinh, acosh, atanh, sqrt, exp, log, log2, log10, log1p, relu, sigmoid, tanh, t, mm, …
+const { add, sub, mul, pow, div, gt, lt, ge, le, eq, logicalAnd, logicalOr, logicalXor, logicalNot, bitwiseAnd, bitwiseOr, bitwiseXor, bitwiseNot, bitwiseLeftShift, bitwiseRightShift, neg, abs, sign, sin, cos, tan, asin, acos, atan, sinh, cosh, asinh, acosh, atanh, sqrt, exp, log, log2, log10, log1p, relu, sigmoid, tanh, t, dot, mm, mv, matmul } = tensor_1.TensorMath;
 var OP;
 (function (OP) {
     OP[OP["NONE"] = 0] = "NONE";
@@ -50,8 +50,10 @@ var OP;
     OP[OP["SIGMOID"] = 42] = "SIGMOID";
     OP[OP["TANH"] = 43] = "TANH";
     OP[OP["T"] = 44] = "T";
-    OP[OP["…
-    OP[OP["…
+    OP[OP["DOT"] = 45] = "DOT";
+    OP[OP["MM"] = 46] = "MM";
+    OP[OP["MV"] = 47] = "MV";
+    OP[OP["MATMUL"] = 48] = "MATMUL";
 })(OP || (exports.OP = OP = {}));
 class Node {
     value;
@@ -439,6 +441,15 @@ class Node {
         };
         return out;
     }
+    dot(other) {
+        other = Node.forceNode(other);
+        const out = new Node(dot(this.value, other.value), [this, other], OP.DOT);
+        out.feedBackward = () => {
+            Node.addGrad(this, mul(out.grad, other.value));
+            Node.addGrad(other, mul(out.grad, this.value));
+        };
+        return out;
+    }
     mm(other) {
         other = Node.forceNode(other);
         const out = new Node(mm(this.value, other.value), [this, other], OP.MM);
@@ -448,15 +459,46 @@ class Node {
         };
         return out;
     }
-    …
+    mv(other) {
         other = Node.forceNode(other);
-        const out = new Node(…
+        const out = new Node(mv(this.value, other.value), [this, other], OP.MV);
         out.feedBackward = () => {
-            …
-            Node.addGrad(…
+            const outGradMat = out.grad.map(el => [el]);
+            Node.addGrad(this, mm(outGradMat, [other.value]));
+            Node.addGrad(other, mv(t(this.value), out.grad));
         };
         return out;
     }
+    matmul(other) {
+        other = Node.forceNode(other);
+        const out = new Node(matmul(this.value, other.value), [this, other], OP.MATMUL);
+        if (this.shape.length === 1 && other.shape.length === 1) {
+            out.feedBackward = () => {
+                Node.addGrad(this, mul(out.grad, other.value));
+                Node.addGrad(other, mul(out.grad, this.value));
+            };
+        }
+        else if (this.shape.length === 1 && other.shape.length === 2) {
+            out.feedBackward = () => {
+                Node.addGrad(this, matmul(out.grad, t(other.value)));
+                Node.addGrad(other, mm(t([this.value]), [out.grad]));
+            };
+        }
+        else if (this.shape.length === 2 && other.shape.length === 1) {
+            out.feedBackward = () => {
+                const outGradMat = out.grad.map(el => [el]);
+                Node.addGrad(this, mm(outGradMat, [other.value]));
+                Node.addGrad(other, mv(t(this.value), out.grad));
+            };
+        }
+        else if (this.shape.length === 2 && other.shape.length === 2) {
+            out.feedBackward = () => {
+                Node.addGrad(this, mm(out.grad, t(other.value)));
+                Node.addGrad(other, mm(t(this.value), out.grad));
+            };
+        }
+        return out;
+    }
     backward() {
         // Build topological order
         const topo = [];
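The backward passes follow the usual matrix-calculus rules: for y = A·x (the mv case), the upstream gradient g contributes the outer product g·xᵀ to A and Aᵀ·g to x, which is why the code lifts out.grad into an n×1 column (outGradMat) before calling mm. A self-contained illustration of that outer-product step (helper names here are illustrative, not package exports):

    // dL/dA = g · xᵀ for y = A·x, written out with plain arrays.
    type Vec = number[];
    type Mat = number[][];

    const outer = (g: Vec, x: Vec): Mat =>
        g.map(gi => x.map(xj => gi * xj)); // same result as mm(column(g), [x])

    const g: Vec = [1, 2];    // upstream gradient, length m
    const x: Vec = [3, 4, 5]; // input vector, length n

    console.log(outer(g, x)); // [[3, 4, 5], [6, 8, 10]], the m×n grad for A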
package/dist/tensor.d.ts
CHANGED
@@ -53,6 +53,8 @@ export declare class TensorMath {
     static sumAxis(tA: Tensor, axis: number): Tensor;
     static sum(tA: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
     static t(tA: Tensor): Tensor;
-    static mm(tA: Tensor, tB: Tensor): Tensor;
     static dot(tA: Tensor, tB: Tensor): Tensor;
+    static mm(tA: Tensor, tB: Tensor): Tensor;
+    static mv(tA: Tensor, tB: Tensor): Tensor;
+    static matmul(tA: Tensor, tB: Tensor): Tensor;
 }
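The new statics mirror NumPy/PyTorch matmul semantics: matmul dispatches on input rank (1D·1D to dot, 2D·1D to mv, and so on; see the tensor.js diff below). Expected results, assuming tensors are plain nested arrays and that the dist module is importable at this path (both assumptions):

    // Dispatch behaviour of TensorMath.matmul per the rank checks in tensor.js.
    import { TensorMath } from "catniff/dist/tensor"; // hypothetical deep import

    TensorMath.dot([1, 2], [3, 4]);              // 1*3 + 2*4 = 11
    TensorMath.matmul([1, 2], [3, 4]);           // 1D·1D -> dot: 11
    TensorMath.matmul([[1, 2], [3, 4]], [5, 6]); // 2D·1D -> mv: [17, 39]
    TensorMath.matmul([1, 2], [[3], [4]]);       // 1D·2D -> mm([tA], tB)[0]: [11]
    TensorMath.matmul([[1, 2]], [[3], [4]]);     // 2D·2D -> mm: [[11]]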
package/dist/tensor.js
CHANGED
@@ -266,6 +266,20 @@ class TensorMath {
         }
         return matATranspose;
     }
+    static dot(tA, tB) {
+        const shapeA = TensorMath.getShape(tA);
+        const shapeB = TensorMath.getShape(tB);
+        if (shapeA.length !== 1 || shapeB.length !== 1 || shapeA[0] !== shapeB[0])
+            throw new Error("Inputs are not 1D tensors");
+        const vectLen = shapeA[0];
+        const vectA = tA;
+        const vectB = tB;
+        let sum = 0;
+        for (let index = 0; index < vectLen; index++) {
+            sum += vectA[index] * vectB[index];
+        }
+        return sum;
+    }
     static mm(tA, tB) {
         const shapeA = TensorMath.getShape(tA);
         const shapeB = TensorMath.getShape(tB);
@@ -289,19 +303,32 @@ class TensorMath {
         }
         return matC;
     }
-    static …
+    static mv(tA, tB) {
         const shapeA = TensorMath.getShape(tA);
         const shapeB = TensorMath.getShape(tB);
-        if (shapeA.length !== …
-        throw new Error("…
-        const …
-        const …
-        …
-        …
-        …
-        …
+        if (shapeA.length !== 2 || shapeB.length !== 1)
+            throw new Error("Input is not a 2D and 1D tensor pair");
+        const matA = tA;
+        const matB = tB.map(el => [el]); // Turn the 1D tensor into a nx1 matrix (vector)
+        return TensorMath.mm(matA, matB).map(el => el[0]);
+    }
+    static matmul(tA, tB) {
+        const shapeA = TensorMath.getShape(tA);
+        const shapeB = TensorMath.getShape(tB);
+        if (shapeA.length === 1 && shapeB.length === 1) {
+            return TensorMath.dot(tA, tB);
         }
-    …
+        else if (shapeA.length === 1 && shapeB.length === 2) {
+            return TensorMath.mm([tA], tB)[0];
+        }
+        else if (shapeA.length === 2 && shapeB.length === 1) {
+            return TensorMath.mv(tA, tB);
+        }
+        else if (shapeA.length === 2 && shapeB.length === 2) {
+            return TensorMath.mm(tA, tB);
+        }
+        // Batched matmul will come when general nD transpose is done
+        throw new Error(`Shapes [${shapeA}] and [${shapeB}] are not supported`);
     }
 }
 exports.TensorMath = TensorMath;
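mv reuses mm by wrapping the vector as an n×1 column and unwrapping the m×1 result. The same trick in isolation, as a standalone sketch (not the package's code, just the idea):

    // mv via mm: wrap x as a column, multiply, then flatten the column result.
    type Mat = number[][];

    const mm = (a: Mat, b: Mat): Mat =>
        a.map(row => b[0].map((_, j) =>
            row.reduce((acc, aik, k) => acc + aik * b[k][j], 0)));

    const mv = (a: Mat, x: number[]): number[] =>
        mm(a, x.map(el => [el])).map(row => row[0]);

    console.log(mv([[1, 2], [3, 4]], [5, 6])); // [17, 39]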