catniff 0.1.8 → 0.1.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -45,8 +45,10 @@ export declare enum OP {
45
45
  SIGMOID = 42,
46
46
  TANH = 43,
47
47
  T = 44,
48
- MM = 45,
49
- DOT = 46
48
+ DOT = 45,
49
+ MM = 46,
50
+ MV = 47,
51
+ MATMUL = 48
50
52
  }
51
53
  export declare class Node {
52
54
  value: Tensor;
@@ -56,26 +58,26 @@ export declare class Node {
56
58
  op: OP;
57
59
  feedBackward: Function;
58
60
  constructor(value: Tensor, children?: Node[], op?: OP);
59
- add(other: Node | number): Node;
60
- sub(other: Node | number): Node;
61
- mul(other: Node | number): Node;
62
- pow(other: Node | number): Node;
63
- div(other: Node | number): Node;
64
- ge(other: Node | number): Node;
65
- le(other: Node | number): Node;
66
- gt(other: Node | number): Node;
67
- lt(other: Node | number): Node;
68
- eq(other: Node | number): Node;
69
- logicalAnd(other: Node | number): Node;
70
- logicalOr(other: Node | number): Node;
71
- logicalXor(other: Node | number): Node;
61
+ add(other: Node | Tensor): Node;
62
+ sub(other: Node | Tensor): Node;
63
+ mul(other: Node | Tensor): Node;
64
+ pow(other: Node | Tensor): Node;
65
+ div(other: Node | Tensor): Node;
66
+ ge(other: Node | Tensor): Node;
67
+ le(other: Node | Tensor): Node;
68
+ gt(other: Node | Tensor): Node;
69
+ lt(other: Node | Tensor): Node;
70
+ eq(other: Node | Tensor): Node;
71
+ logicalAnd(other: Node | Tensor): Node;
72
+ logicalOr(other: Node | Tensor): Node;
73
+ logicalXor(other: Node | Tensor): Node;
72
74
  logicalNot(): Node;
73
- bitwiseAnd(other: Node | number): Node;
74
- bitwiseOr(other: Node | number): Node;
75
- bitwiseXor(other: Node | number): Node;
75
+ bitwiseAnd(other: Node | Tensor): Node;
76
+ bitwiseOr(other: Node | Tensor): Node;
77
+ bitwiseXor(other: Node | Tensor): Node;
76
78
  bitwiseNot(): Node;
77
- bitwiseLeftShift(other: Node | number): Node;
78
- bitwiseRightShift(other: Node | number): Node;
79
+ bitwiseLeftShift(other: Node | Tensor): Node;
80
+ bitwiseRightShift(other: Node | Tensor): Node;
79
81
  neg(): Node;
80
82
  abs(): Node;
81
83
  sign(): Node;
@@ -100,9 +102,11 @@ export declare class Node {
100
102
  sigmoid(): Node;
101
103
  tanh(): Node;
102
104
  t(): Node;
103
- mm(other: Node | number): Node;
104
- dot(other: Node | number): Node;
105
+ dot(other: Node | Tensor): Node;
106
+ mm(other: Node | Tensor): Node;
107
+ mv(other: Node | Tensor): Node;
108
+ matmul(other: Node | Tensor): Node;
105
109
  backward(): void;
106
- static forceNode(value: Node | number): Node;
110
+ static forceNode(value: Node | Tensor): Node;
107
111
  static addGrad(node: Node, accumGrad: Tensor): void;
108
112
  }
package/dist/autograd.js CHANGED
@@ -2,7 +2,7 @@
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
3
  exports.Node = exports.OP = void 0;
4
4
  const tensor_1 = require("./tensor");
5
- const { add, sub, mul, pow, div, gt, lt, ge, le, eq, logicalAnd, logicalOr, logicalXor, logicalNot, bitwiseAnd, bitwiseOr, bitwiseXor, bitwiseNot, bitwiseLeftShift, bitwiseRightShift, neg, abs, sign, sin, cos, tan, asin, acos, atan, sinh, cosh, asinh, acosh, atanh, sqrt, exp, log, log2, log10, log1p, relu, sigmoid, tanh, t, mm, dot } = tensor_1.TensorMath;
5
+ const { add, sub, mul, pow, div, gt, lt, ge, le, eq, logicalAnd, logicalOr, logicalXor, logicalNot, bitwiseAnd, bitwiseOr, bitwiseXor, bitwiseNot, bitwiseLeftShift, bitwiseRightShift, neg, abs, sign, sin, cos, tan, asin, acos, atan, sinh, cosh, asinh, acosh, atanh, sqrt, exp, log, log2, log10, log1p, relu, sigmoid, tanh, t, dot, mm, mv, matmul } = tensor_1.TensorMath;
6
6
  var OP;
7
7
  (function (OP) {
8
8
  OP[OP["NONE"] = 0] = "NONE";
@@ -50,8 +50,10 @@ var OP;
50
50
  OP[OP["SIGMOID"] = 42] = "SIGMOID";
51
51
  OP[OP["TANH"] = 43] = "TANH";
52
52
  OP[OP["T"] = 44] = "T";
53
- OP[OP["MM"] = 45] = "MM";
54
- OP[OP["DOT"] = 46] = "DOT";
53
+ OP[OP["DOT"] = 45] = "DOT";
54
+ OP[OP["MM"] = 46] = "MM";
55
+ OP[OP["MV"] = 47] = "MV";
56
+ OP[OP["MATMUL"] = 48] = "MATMUL";
55
57
  })(OP || (exports.OP = OP = {}));
56
58
  class Node {
57
59
  value;
@@ -439,6 +441,15 @@ class Node {
439
441
  };
440
442
  return out;
441
443
  }
444
+ dot(other) {
445
+ other = Node.forceNode(other);
446
+ const out = new Node(dot(this.value, other.value), [this, other], OP.DOT);
447
+ out.feedBackward = () => {
448
+ Node.addGrad(this, mul(out.grad, other.value));
449
+ Node.addGrad(other, mul(out.grad, this.value));
450
+ };
451
+ return out;
452
+ }
442
453
  mm(other) {
443
454
  other = Node.forceNode(other);
444
455
  const out = new Node(mm(this.value, other.value), [this, other], OP.MM);
@@ -448,15 +459,46 @@ class Node {
448
459
  };
449
460
  return out;
450
461
  }
451
- dot(other) {
462
+ mv(other) {
452
463
  other = Node.forceNode(other);
453
- const out = new Node(dot(this.value, other.value), [this, other], OP.DOT);
464
+ const out = new Node(mv(this.value, other.value), [this, other], OP.MV);
454
465
  out.feedBackward = () => {
455
- Node.addGrad(this, mul(out.grad, other.value));
456
- Node.addGrad(other, mul(out.grad, this.value));
466
+ const outGradMat = out.grad.map(el => [el]);
467
+ Node.addGrad(this, mm(outGradMat, [other.value]));
468
+ Node.addGrad(other, mv(t(this.value), out.grad));
457
469
  };
458
470
  return out;
459
471
  }
472
+ matmul(other) {
473
+ other = Node.forceNode(other);
474
+ const out = new Node(matmul(this.value, other.value), [this, other], OP.MATMUL);
475
+ if (this.shape.length === 1 && other.shape.length === 1) {
476
+ out.feedBackward = () => {
477
+ Node.addGrad(this, mul(out.grad, other.value));
478
+ Node.addGrad(other, mul(out.grad, this.value));
479
+ };
480
+ }
481
+ else if (this.shape.length === 1 && other.shape.length === 2) {
482
+ out.feedBackward = () => {
483
+ Node.addGrad(this, matmul(out.grad, t(other.value)));
484
+ Node.addGrad(other, mm(t([this.value]), [out.grad]));
485
+ };
486
+ }
487
+ else if (this.shape.length === 2 && other.shape.length === 1) {
488
+ out.feedBackward = () => {
489
+ const outGradMat = out.grad.map(el => [el]);
490
+ Node.addGrad(this, mm(outGradMat, [other.value]));
491
+ Node.addGrad(other, mv(t(this.value), out.grad));
492
+ };
493
+ }
494
+ else if (this.shape.length === 2 && other.shape.length === 2) {
495
+ out.feedBackward = () => {
496
+ Node.addGrad(this, mm(out.grad, t(other.value)));
497
+ Node.addGrad(other, mm(t(this.value), out.grad));
498
+ };
499
+ }
500
+ return out;
501
+ }
460
502
  backward() {
461
503
  // Build topological order
462
504
  const topo = [];
package/dist/tensor.d.ts CHANGED
@@ -53,6 +53,8 @@ export declare class TensorMath {
53
53
  static sumAxis(tA: Tensor, axis: number): Tensor;
54
54
  static sum(tA: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
55
55
  static t(tA: Tensor): Tensor;
56
- static mm(tA: Tensor, tB: Tensor): Tensor;
57
56
  static dot(tA: Tensor, tB: Tensor): Tensor;
57
+ static mm(tA: Tensor, tB: Tensor): Tensor;
58
+ static mv(tA: Tensor, tB: Tensor): Tensor;
59
+ static matmul(tA: Tensor, tB: Tensor): Tensor;
58
60
  }
package/dist/tensor.js CHANGED
@@ -266,6 +266,20 @@ class TensorMath {
266
266
  }
267
267
  return matATranspose;
268
268
  }
269
+ static dot(tA, tB) {
270
+ const shapeA = TensorMath.getShape(tA);
271
+ const shapeB = TensorMath.getShape(tB);
272
+ if (shapeA.length !== 1 || shapeB.length !== 1 || shapeA[0] !== shapeB[0])
273
+ throw new Error("Inputs are not 1D tensors");
274
+ const vectLen = shapeA[0];
275
+ const vectA = tA;
276
+ const vectB = tB;
277
+ let sum = 0;
278
+ for (let index = 0; index < vectLen; index++) {
279
+ sum += vectA[index] * vectB[index];
280
+ }
281
+ return sum;
282
+ }
269
283
  static mm(tA, tB) {
270
284
  const shapeA = TensorMath.getShape(tA);
271
285
  const shapeB = TensorMath.getShape(tB);
@@ -289,19 +303,32 @@ class TensorMath {
289
303
  }
290
304
  return matC;
291
305
  }
292
- static dot(tA, tB) {
306
+ static mv(tA, tB) {
293
307
  const shapeA = TensorMath.getShape(tA);
294
308
  const shapeB = TensorMath.getShape(tB);
295
- if (shapeA.length !== 1 || shapeB.length !== 1 || shapeA[0] !== shapeB[0])
296
- throw new Error("Inputs are not 1D tensors");
297
- const vectLen = shapeA[0];
298
- const vectA = tA;
299
- const vectB = tB;
300
- let sum = 0;
301
- for (let index = 0; index < vectLen; index++) {
302
- sum += vectA[index] * vectB[index];
309
+ if (shapeA.length !== 2 || shapeB.length !== 1)
310
+ throw new Error("Input is not a 2D and 1D tensor pair");
311
+ const matA = tA;
312
+ const matB = tB.map(el => [el]); // Turn the 1D tensor into a nx1 matrix (vector)
313
+ return TensorMath.mm(matA, matB).map(el => el[0]);
314
+ }
315
+ static matmul(tA, tB) {
316
+ const shapeA = TensorMath.getShape(tA);
317
+ const shapeB = TensorMath.getShape(tB);
318
+ if (shapeA.length === 1 && shapeB.length === 1) {
319
+ return TensorMath.dot(tA, tB);
303
320
  }
304
- return sum;
321
+ else if (shapeA.length === 1 && shapeB.length === 2) {
322
+ return TensorMath.mm([tA], tB)[0];
323
+ }
324
+ else if (shapeA.length === 2 && shapeB.length === 1) {
325
+ return TensorMath.mv(tA, tB);
326
+ }
327
+ else if (shapeA.length === 2 && shapeB.length === 2) {
328
+ return TensorMath.mm(tA, tB);
329
+ }
330
+ // Batched matmul will come when general nD transpose is done
331
+ throw new Error(`Shapes [] and [] are not supported`);
305
332
  }
306
333
  }
307
334
  exports.TensorMath = TensorMath;
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.1.8",
3
+ "version": "0.1.9",
4
4
  "description": "A cute autograd engine for Javascript",
5
5
  "main": "index.js",
6
6
  "scripts": {