catniff 0.1.7 → 0.1.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -45,7 +45,10 @@ export declare enum OP {
45
45
  SIGMOID = 42,
46
46
  TANH = 43,
47
47
  T = 44,
48
- MM = 45
48
+ DOT = 45,
49
+ MM = 46,
50
+ MV = 47,
51
+ MATMUL = 48
49
52
  }
50
53
  export declare class Node {
51
54
  value: Tensor;
@@ -55,26 +58,26 @@ export declare class Node {
55
58
  op: OP;
56
59
  feedBackward: Function;
57
60
  constructor(value: Tensor, children?: Node[], op?: OP);
58
- add(other: Node | number): Node;
59
- sub(other: Node | number): Node;
60
- mul(other: Node | number): Node;
61
- pow(other: Node | number): Node;
62
- div(other: Node | number): Node;
63
- ge(other: Node | number): Node;
64
- le(other: Node | number): Node;
65
- gt(other: Node | number): Node;
66
- lt(other: Node | number): Node;
67
- eq(other: Node | number): Node;
68
- logicalAnd(other: Node | number): Node;
69
- logicalOr(other: Node | number): Node;
70
- logicalXor(other: Node | number): Node;
61
+ add(other: Node | Tensor): Node;
62
+ sub(other: Node | Tensor): Node;
63
+ mul(other: Node | Tensor): Node;
64
+ pow(other: Node | Tensor): Node;
65
+ div(other: Node | Tensor): Node;
66
+ ge(other: Node | Tensor): Node;
67
+ le(other: Node | Tensor): Node;
68
+ gt(other: Node | Tensor): Node;
69
+ lt(other: Node | Tensor): Node;
70
+ eq(other: Node | Tensor): Node;
71
+ logicalAnd(other: Node | Tensor): Node;
72
+ logicalOr(other: Node | Tensor): Node;
73
+ logicalXor(other: Node | Tensor): Node;
71
74
  logicalNot(): Node;
72
- bitwiseAnd(other: Node | number): Node;
73
- bitwiseOr(other: Node | number): Node;
74
- bitwiseXor(other: Node | number): Node;
75
+ bitwiseAnd(other: Node | Tensor): Node;
76
+ bitwiseOr(other: Node | Tensor): Node;
77
+ bitwiseXor(other: Node | Tensor): Node;
75
78
  bitwiseNot(): Node;
76
- bitwiseLeftShift(other: Node | number): Node;
77
- bitwiseRightShift(other: Node | number): Node;
79
+ bitwiseLeftShift(other: Node | Tensor): Node;
80
+ bitwiseRightShift(other: Node | Tensor): Node;
78
81
  neg(): Node;
79
82
  abs(): Node;
80
83
  sign(): Node;
@@ -99,8 +102,11 @@ export declare class Node {
99
102
  sigmoid(): Node;
100
103
  tanh(): Node;
101
104
  t(): Node;
102
- mm(other: Node | number): Node;
105
+ dot(other: Node | Tensor): Node;
106
+ mm(other: Node | Tensor): Node;
107
+ mv(other: Node | Tensor): Node;
108
+ matmul(other: Node | Tensor): Node;
103
109
  backward(): void;
104
- static forceNode(value: Node | number): Node;
110
+ static forceNode(value: Node | Tensor): Node;
105
111
  static addGrad(node: Node, accumGrad: Tensor): void;
106
112
  }
package/dist/autograd.js CHANGED
@@ -2,7 +2,7 @@
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
3
  exports.Node = exports.OP = void 0;
4
4
  const tensor_1 = require("./tensor");
5
- const { add, sub, mul, pow, div, gt, lt, ge, le, eq, logicalAnd, logicalOr, logicalXor, logicalNot, bitwiseAnd, bitwiseOr, bitwiseXor, bitwiseNot, bitwiseLeftShift, bitwiseRightShift, neg, abs, sign, sin, cos, tan, asin, acos, atan, sinh, cosh, asinh, acosh, atanh, sqrt, exp, log, log2, log10, log1p, relu, sigmoid, tanh, t, mm } = tensor_1.TensorMath;
5
+ const { add, sub, mul, pow, div, gt, lt, ge, le, eq, logicalAnd, logicalOr, logicalXor, logicalNot, bitwiseAnd, bitwiseOr, bitwiseXor, bitwiseNot, bitwiseLeftShift, bitwiseRightShift, neg, abs, sign, sin, cos, tan, asin, acos, atan, sinh, cosh, asinh, acosh, atanh, sqrt, exp, log, log2, log10, log1p, relu, sigmoid, tanh, t, dot, mm, mv, matmul } = tensor_1.TensorMath;
6
6
  var OP;
7
7
  (function (OP) {
8
8
  OP[OP["NONE"] = 0] = "NONE";
@@ -50,7 +50,10 @@ var OP;
50
50
  OP[OP["SIGMOID"] = 42] = "SIGMOID";
51
51
  OP[OP["TANH"] = 43] = "TANH";
52
52
  OP[OP["T"] = 44] = "T";
53
- OP[OP["MM"] = 45] = "MM";
53
+ OP[OP["DOT"] = 45] = "DOT";
54
+ OP[OP["MM"] = 46] = "MM";
55
+ OP[OP["MV"] = 47] = "MV";
56
+ OP[OP["MATMUL"] = 48] = "MATMUL";
54
57
  })(OP || (exports.OP = OP = {}));
55
58
  class Node {
56
59
  value;
@@ -438,6 +441,15 @@ class Node {
438
441
  };
439
442
  return out;
440
443
  }
444
+ dot(other) {
445
+ other = Node.forceNode(other);
446
+ const out = new Node(dot(this.value, other.value), [this, other], OP.DOT);
447
+ out.feedBackward = () => {
448
+ Node.addGrad(this, mul(out.grad, other.value));
449
+ Node.addGrad(other, mul(out.grad, this.value));
450
+ };
451
+ return out;
452
+ }
441
453
  mm(other) {
442
454
  other = Node.forceNode(other);
443
455
  const out = new Node(mm(this.value, other.value), [this, other], OP.MM);
@@ -447,6 +459,46 @@ class Node {
447
459
  };
448
460
  return out;
449
461
  }
462
+ mv(other) {
463
+ other = Node.forceNode(other);
464
+ const out = new Node(mv(this.value, other.value), [this, other], OP.MV);
465
+ out.feedBackward = () => {
466
+ const outGradMat = out.grad.map(el => [el]);
467
+ Node.addGrad(this, mm(outGradMat, [other.value]));
468
+ Node.addGrad(other, mv(t(this.value), out.grad));
469
+ };
470
+ return out;
471
+ }
472
+ matmul(other) {
473
+ other = Node.forceNode(other);
474
+ const out = new Node(matmul(this.value, other.value), [this, other], OP.MATMUL);
475
+ if (this.shape.length === 1 && other.shape.length === 1) {
476
+ out.feedBackward = () => {
477
+ Node.addGrad(this, mul(out.grad, other.value));
478
+ Node.addGrad(other, mul(out.grad, this.value));
479
+ };
480
+ }
481
+ else if (this.shape.length === 1 && other.shape.length === 2) {
482
+ out.feedBackward = () => {
483
+ Node.addGrad(this, matmul(out.grad, t(other.value)));
484
+ Node.addGrad(other, mm(t([this.value]), [out.grad]));
485
+ };
486
+ }
487
+ else if (this.shape.length === 2 && other.shape.length === 1) {
488
+ out.feedBackward = () => {
489
+ const outGradMat = out.grad.map(el => [el]);
490
+ Node.addGrad(this, mm(outGradMat, [other.value]));
491
+ Node.addGrad(other, mv(t(this.value), out.grad));
492
+ };
493
+ }
494
+ else if (this.shape.length === 2 && other.shape.length === 2) {
495
+ out.feedBackward = () => {
496
+ Node.addGrad(this, mm(out.grad, t(other.value)));
497
+ Node.addGrad(other, mm(t(this.value), out.grad));
498
+ };
499
+ }
500
+ return out;
501
+ }
450
502
  backward() {
451
503
  // Build topological order
452
504
  const topo = [];
package/dist/tensor.d.ts CHANGED
@@ -53,5 +53,8 @@ export declare class TensorMath {
53
53
  static sumAxis(tA: Tensor, axis: number): Tensor;
54
54
  static sum(tA: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
55
55
  static t(tA: Tensor): Tensor;
56
+ static dot(tA: Tensor, tB: Tensor): Tensor;
56
57
  static mm(tA: Tensor, tB: Tensor): Tensor;
58
+ static mv(tA: Tensor, tB: Tensor): Tensor;
59
+ static matmul(tA: Tensor, tB: Tensor): Tensor;
57
60
  }
package/dist/tensor.js CHANGED
@@ -266,6 +266,20 @@ class TensorMath {
266
266
  }
267
267
  return matATranspose;
268
268
  }
269
+ static dot(tA, tB) {
270
+ const shapeA = TensorMath.getShape(tA);
271
+ const shapeB = TensorMath.getShape(tB);
272
+ if (shapeA.length !== 1 || shapeB.length !== 1 || shapeA[0] !== shapeB[0])
273
+ throw new Error("Inputs are not 1D tensors");
274
+ const vectLen = shapeA[0];
275
+ const vectA = tA;
276
+ const vectB = tB;
277
+ let sum = 0;
278
+ for (let index = 0; index < vectLen; index++) {
279
+ sum += vectA[index] * vectB[index];
280
+ }
281
+ return sum;
282
+ }
269
283
  static mm(tA, tB) {
270
284
  const shapeA = TensorMath.getShape(tA);
271
285
  const shapeB = TensorMath.getShape(tB);
@@ -289,5 +303,32 @@ class TensorMath {
289
303
  }
290
304
  return matC;
291
305
  }
306
+ static mv(tA, tB) {
307
+ const shapeA = TensorMath.getShape(tA);
308
+ const shapeB = TensorMath.getShape(tB);
309
+ if (shapeA.length !== 2 || shapeB.length !== 1)
310
+ throw new Error("Input is not a 2D and 1D tensor pair");
311
+ const matA = tA;
312
+ const matB = tB.map(el => [el]); // Turn the 1D tensor into a nx1 matrix (vector)
313
+ return TensorMath.mm(matA, matB).map(el => el[0]);
314
+ }
315
+ static matmul(tA, tB) {
316
+ const shapeA = TensorMath.getShape(tA);
317
+ const shapeB = TensorMath.getShape(tB);
318
+ if (shapeA.length === 1 && shapeB.length === 1) {
319
+ return TensorMath.dot(tA, tB);
320
+ }
321
+ else if (shapeA.length === 1 && shapeB.length === 2) {
322
+ return TensorMath.mm([tA], tB)[0];
323
+ }
324
+ else if (shapeA.length === 2 && shapeB.length === 1) {
325
+ return TensorMath.mv(tA, tB);
326
+ }
327
+ else if (shapeA.length === 2 && shapeB.length === 2) {
328
+ return TensorMath.mm(tA, tB);
329
+ }
330
+ // Batched matmul will come when general nD transpose is done
331
+ throw new Error(`Shapes [] and [] are not supported`);
332
+ }
292
333
  }
293
334
  exports.TensorMath = TensorMath;
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.1.7",
3
+ "version": "0.1.9",
4
4
  "description": "A cute autograd engine for Javascript",
5
5
  "main": "index.js",
6
6
  "scripts": {