catniff 0.1.4 → 0.1.6

package/README.md CHANGED
@@ -26,7 +26,7 @@ console.log(x.grad); // 5
 
 Tensors in Catniff are either numbers (scalars/0-D tensors) or multidimensional number arrays (n-D tensors).
 
-There is a built-in `TensorMath` class to help with Tensor arithmetic, for example:
+There is a built-in `TensorMath` class to help with tensor arithmetic, for example:
 ```js
 const { TensorMath } = require("catniff");
 
@@ -70,6 +70,8 @@ All available APIs are in `./src/autograd.ts`.
 I'm mostly just learning and playing with this currently, so there are no concrete plans yet, but here are what I currently have in mind:
 
 * Fix whatever is the problem right now (there are a lot of problems right now lol).
+* Add more tensor ops.
+* Proper documentation.
 * GPU acceleration.
 * Some general neural net APIs.
 
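For reference, a minimal sketch of the `TensorMath` arithmetic mentioned above against the 0.1.6 surface. This assumes the package root exports `TensorMath` as the README shows; the expected outputs are inferred from the implementations later in this diff:

```js
const { TensorMath } = require("catniff");

// Element-wise ops recurse over nested arrays (shapes match here)
console.log(TensorMath.add([[1, 2], [3, 4]], [[10, 20], [30, 40]])); // [[11, 22], [33, 44]]

// New in 0.1.6: trig, transpose, and matrix multiplication
console.log(TensorMath.sin([0, Math.PI / 2]));    // [0, 1]
console.log(TensorMath.t([[1, 2], [3, 4]]));      // [[1, 3], [2, 4]]
console.log(TensorMath.mm([[1, 2]], [[3], [4]])); // [[11]]
```
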
package/dist/autograd.d.ts CHANGED
@@ -6,12 +6,36 @@ export declare enum OP {
     MUL = 3,
     POW = 4,
     DIV = 5,
-    NEG = 6,
-    EXP = 7,
-    LOG = 8,
-    RELU = 9,
-    SIGMOID = 10,
-    TANH = 11
+    GE = 6,
+    LE = 7,
+    GT = 8,
+    LT = 9,
+    EQ = 10,
+    NEG = 11,
+    ABS = 12,
+    SIGN = 13,
+    SIN = 14,
+    COS = 15,
+    TAN = 16,
+    ASIN = 17,
+    ACOS = 18,
+    ATAN = 19,
+    SINH = 20,
+    COSH = 21,
+    ASINH = 22,
+    ACOSH = 23,
+    ATANH = 24,
+    SQRT = 25,
+    EXP = 26,
+    LOG = 27,
+    LOG2 = 28,
+    LOG10 = 29,
+    LOG1P = 30,
+    RELU = 31,
+    SIGMOID = 32,
+    TANH = 33,
+    T = 34,
+    MM = 35
 }
 export declare class Node {
     value: Tensor;
@@ -26,12 +50,36 @@ export declare class Node {
     mul(other: Node | number): Node;
     pow(other: Node | number): Node;
     div(other: Node | number): Node;
+    ge(other: Node | number): Node;
+    le(other: Node | number): Node;
+    gt(other: Node | number): Node;
+    lt(other: Node | number): Node;
+    eq(other: Node | number): Node;
     neg(): Node;
+    abs(): Node;
+    sign(): Node;
+    sin(): Node;
+    cos(): Node;
+    tan(): Node;
+    asin(): Node;
+    acos(): Node;
+    atan(): Node;
+    sinh(): Node;
+    cosh(): Node;
+    asinh(): Node;
+    acosh(): Node;
+    atanh(): Node;
+    sqrt(): Node;
     exp(): Node;
     log(): Node;
+    log2(): Node;
+    log10(): Node;
+    log1p(): Node;
     relu(): Node;
     sigmoid(): Node;
     tanh(): Node;
+    t(): Node;
+    mm(other: Node | number): Node;
     backward(): void;
     static forceNode(value: Node | number): Node;
     static addGrad(node: Node, accumGrad: Tensor): void;
package/dist/autograd.js CHANGED
@@ -2,7 +2,7 @@
 Object.defineProperty(exports, "__esModule", { value: true });
 exports.Node = exports.OP = void 0;
 const tensor_1 = require("./tensor");
-const { add, sub, mul, pow, div, neg, exp, log, relu, sigmoid, tanh, ge } = tensor_1.TensorMath;
+const { add, sub, mul, pow, div, gt, lt, ge, le, eq, neg, abs, sign, sin, cos, tan, asin, acos, atan, sinh, cosh, asinh, acosh, atanh, sqrt, exp, log, log2, log10, log1p, relu, sigmoid, tanh, t, mm } = tensor_1.TensorMath;
 var OP;
 (function (OP) {
     OP[OP["NONE"] = 0] = "NONE";
@@ -11,12 +11,36 @@ var OP;
     OP[OP["MUL"] = 3] = "MUL";
     OP[OP["POW"] = 4] = "POW";
     OP[OP["DIV"] = 5] = "DIV";
-    OP[OP["NEG"] = 6] = "NEG";
-    OP[OP["EXP"] = 7] = "EXP";
-    OP[OP["LOG"] = 8] = "LOG";
-    OP[OP["RELU"] = 9] = "RELU";
-    OP[OP["SIGMOID"] = 10] = "SIGMOID";
-    OP[OP["TANH"] = 11] = "TANH";
+    OP[OP["GE"] = 6] = "GE";
+    OP[OP["LE"] = 7] = "LE";
+    OP[OP["GT"] = 8] = "GT";
+    OP[OP["LT"] = 9] = "LT";
+    OP[OP["EQ"] = 10] = "EQ";
+    OP[OP["NEG"] = 11] = "NEG";
+    OP[OP["ABS"] = 12] = "ABS";
+    OP[OP["SIGN"] = 13] = "SIGN";
+    OP[OP["SIN"] = 14] = "SIN";
+    OP[OP["COS"] = 15] = "COS";
+    OP[OP["TAN"] = 16] = "TAN";
+    OP[OP["ASIN"] = 17] = "ASIN";
+    OP[OP["ACOS"] = 18] = "ACOS";
+    OP[OP["ATAN"] = 19] = "ATAN";
+    OP[OP["SINH"] = 20] = "SINH";
+    OP[OP["COSH"] = 21] = "COSH";
+    OP[OP["ASINH"] = 22] = "ASINH";
+    OP[OP["ACOSH"] = 23] = "ACOSH";
+    OP[OP["ATANH"] = 24] = "ATANH";
+    OP[OP["SQRT"] = 25] = "SQRT";
+    OP[OP["EXP"] = 26] = "EXP";
+    OP[OP["LOG"] = 27] = "LOG";
+    OP[OP["LOG2"] = 28] = "LOG2";
+    OP[OP["LOG10"] = 29] = "LOG10";
+    OP[OP["LOG1P"] = 30] = "LOG1P";
+    OP[OP["RELU"] = 31] = "RELU";
+    OP[OP["SIGMOID"] = 32] = "SIGMOID";
+    OP[OP["TANH"] = 33] = "TANH";
+    OP[OP["T"] = 34] = "T";
+    OP[OP["MM"] = 35] = "MM";
 })(OP || (exports.OP = OP = {}));
 class Node {
     value;
@@ -61,7 +85,7 @@ class Node {
         out.feedBackward = () => {
             // x * y d/dx = y
             Node.addGrad(this, mul(out.grad, other.value));
-            // x + y d/dy = x
+            // x * y d/dy = x
             Node.addGrad(other, mul(out.grad, this.value));
         };
         return out;
@@ -94,6 +118,46 @@ class Node {
         };
         return out;
     }
+    ge(other) {
+        other = Node.forceNode(other);
+        const out = new Node(ge(this.value, other.value), [this, other], OP.GE);
+        out.feedBackward = () => {
+            // We consider the derivative of ge to be 0, which does not add to current grad, so this function is just empty
+        };
+        return out;
+    }
+    le(other) {
+        other = Node.forceNode(other);
+        const out = new Node(le(this.value, other.value), [this, other], OP.LE);
+        out.feedBackward = () => {
+            // We consider the derivative of le to be 0, which does not add to current grad, so this function is just empty
+        };
+        return out;
+    }
+    gt(other) {
+        other = Node.forceNode(other);
+        const out = new Node(gt(this.value, other.value), [this, other], OP.GT);
+        out.feedBackward = () => {
+            // We consider the derivative of gt to be 0, which does not add to current grad, so this function is just empty
+        };
+        return out;
+    }
+    lt(other) {
+        other = Node.forceNode(other);
+        const out = new Node(lt(this.value, other.value), [this, other], OP.LT);
+        out.feedBackward = () => {
+            // We consider the derivative of lt to be 0, which does not add to current grad, so this function is just empty
+        };
+        return out;
+    }
+    eq(other) {
+        other = Node.forceNode(other);
+        const out = new Node(eq(this.value, other.value), [this, other], OP.EQ);
+        out.feedBackward = () => {
+            // We consider the derivative of eq to be 0, which does not add to current grad, so this function is just empty
+        };
+        return out;
+    }
     neg() {
         const out = new Node(neg(this.value), [this], OP.NEG);
         out.feedBackward = () => {
@@ -102,6 +166,119 @@ class Node {
         };
         return out;
     }
+    abs() {
+        const out = new Node(abs(this.value), [this], OP.ABS);
+        out.feedBackward = () => {
+            // |x| d/dx = sign(x)
+            Node.addGrad(this, mul(out.grad, sign(this.value)));
+        };
+        return out;
+    }
+    sign() {
+        const out = new Node(sign(this.value), [this], OP.SIGN);
+        out.feedBackward = () => {
+            // We consider the derivative of sign to be 0, which does not add to current grad, so this function is just empty
+        };
+        return out;
+    }
+    sin() {
+        const out = new Node(sin(this.value), [this], OP.SIN);
+        out.feedBackward = () => {
+            // sinx d/dx = cosx
+            Node.addGrad(this, mul(out.grad, cos(this.value)));
+        };
+        return out;
+    }
+    cos() {
+        const out = new Node(cos(this.value), [this], OP.COS);
+        out.feedBackward = () => {
+            // cosx d/dx = -sinx
+            Node.addGrad(this, mul(out.grad, neg(sin(this.value))));
+        };
+        return out;
+    }
+    tan() {
+        const tanResult = tan(this.value);
+        const out = new Node(tanResult, [this], OP.TAN);
+        out.feedBackward = () => {
+            // tanx d/dx = 1+(tanx)^2
+            Node.addGrad(this, mul(out.grad, add(1, pow(tanResult, 2))));
+        };
+        return out;
+    }
+    asin() {
+        const out = new Node(asin(this.value), [this], OP.ASIN);
+        out.feedBackward = () => {
+            // asinx d/dx = 1/sqrt(1-x^2)
+            Node.addGrad(this, div(out.grad, sqrt(sub(1, pow(this.value, 2)))));
+        };
+        return out;
+    }
+    acos() {
+        const out = new Node(acos(this.value), [this], OP.ACOS);
+        out.feedBackward = () => {
+            // acosx d/dx = -1/sqrt(1-x^2)
+            Node.addGrad(this, neg(div(out.grad, sqrt(sub(1, pow(this.value, 2))))));
+        };
+        return out;
+    }
+    atan() {
+        const out = new Node(atan(this.value), [this], OP.ATAN);
+        out.feedBackward = () => {
+            // atanx d/dx = 1/(1+x^2)
+            Node.addGrad(this, div(out.grad, add(1, pow(this.value, 2))));
+        };
+        return out;
+    }
+    sinh() {
+        const out = new Node(sinh(this.value), [this], OP.SINH);
+        out.feedBackward = () => {
+            // sinhx d/dx = coshx
+            Node.addGrad(this, mul(out.grad, cosh(this.value)));
+        };
+        return out;
+    }
+    cosh() {
+        const out = new Node(cosh(this.value), [this], OP.COSH);
+        out.feedBackward = () => {
+            // coshx d/dx = sinhx
+            Node.addGrad(this, mul(out.grad, sinh(this.value)));
+        };
+        return out;
+    }
+    asinh() {
+        const out = new Node(asinh(this.value), [this], OP.ASINH);
+        out.feedBackward = () => {
+            // asinhx d/dx = 1/sqrt(1+x^2)
+            Node.addGrad(this, div(out.grad, sqrt(add(1, pow(this.value, 2)))));
+        };
+        return out;
+    }
+    acosh() {
+        const out = new Node(acosh(this.value), [this], OP.ACOSH);
+        out.feedBackward = () => {
+            // acoshx d/dx = 1/(sqrt(x-1)*sqrt(x+1))
+            Node.addGrad(this, div(out.grad, mul(sqrt(sub(this.value, 1)), sqrt(add(this.value, 1)))));
+        };
+        return out;
+    }
+    atanh() {
+        const out = new Node(atanh(this.value), [this], OP.ATANH);
+        out.feedBackward = () => {
+            // atanhx d/dx = 1/(1-x^2)
+            Node.addGrad(this, div(out.grad, sub(1, pow(this.value, 2))));
+        };
+        return out;
+    }
+    sqrt() {
+        const sqrtResult = sqrt(this.value);
+        const out = new Node(sqrtResult, [this], OP.SQRT);
+        out.feedBackward = () => {
+            // sqrt(x) d/dx = 1/(2*sqrt(x))
+            Node.addGrad(this, div(out.grad, mul(2, sqrtResult)));
+        };
+        return out;
+    }
     exp() {
         const expResult = exp(this.value);
         const out = new Node(expResult, [this], OP.EXP);
@@ -119,6 +296,30 @@ class Node {
         };
         return out;
     }
+    log2() {
+        const out = new Node(log2(this.value), [this], OP.LOG2);
+        out.feedBackward = () => {
+            // log2(x) d/dx = 1/(x*ln2)
+            Node.addGrad(this, div(out.grad, mul(this.value, Math.log(2))));
+        };
+        return out;
+    }
+    log10() {
+        const out = new Node(log10(this.value), [this], OP.LOG10);
+        out.feedBackward = () => {
+            // log10(x) d/dx = 1/(x*ln10)
+            Node.addGrad(this, div(out.grad, mul(this.value, Math.log(10))));
+        };
+        return out;
+    }
+    log1p() {
+        const out = new Node(log1p(this.value), [this], OP.LOG1P);
+        out.feedBackward = () => {
+            // ln(1+x) d/dx = 1/(1+x)
+            Node.addGrad(this, div(out.grad, add(this.value, 1)));
+        };
+        return out;
+    }
     relu() {
         const out = new Node(relu(this.value), [this], OP.RELU);
         out.feedBackward = () => {
@@ -142,6 +343,22 @@ class Node {
         };
         return out;
     }
+    t() {
+        const out = new Node(t(this.value), [this], OP.T);
+        out.feedBackward = () => {
+            // Transpose backward: route the upstream grad through another transpose
+            Node.addGrad(this, t(out.grad));
+        };
+        return out;
+    }
+    mm(other) {
+        other = Node.forceNode(other);
+        const out = new Node(mm(this.value, other.value), [this, other], OP.MM);
+        out.feedBackward = () => {
+            // C = A mm B: dL/dA = dL/dC mm B^T, dL/dB = A^T mm dL/dC
+            Node.addGrad(this, mm(out.grad, t(other.value)));
+            Node.addGrad(other, mm(t(this.value), out.grad));
+        };
+        return out;
+    }
     backward() {
         // Build topological order
         const topo = [];
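
The `mm` backward above implements the standard matrix-calculus rule: for C = A·B, dL/dA = dL/dC·Bᵀ and dL/dB = Aᵀ·dL/dC. A hedged sketch of the new autograd ops in action follows; it assumes `Node` can be loaded from the built file shown above, that the constructor accepts a bare value as in the README's scalar example, and that `backward()` seeds the output grad with 1 (none of which is confirmed by this diff beyond what the code shows):

```js
// Assumed import path: the diff only shows the file at package/dist/autograd.js
const { Node } = require("catniff/dist/autograd");

// sqrt backward: d/dx sqrt(x) = 1/(2*sqrt(x)), so at x = 4 the grad is 0.25
const x = new Node(4);
x.sqrt().backward();
console.log(x.grad); // 0.25

// Comparison ops are treated as constants: their feedBackward is empty,
// so no gradient flows back through gt/lt/ge/le/eq
const m = new Node(3);
m.gt(1).backward();
console.log(m.grad); // 0
```
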
package/dist/tensor.d.ts CHANGED
@@ -3,6 +3,8 @@ export declare class TensorMath {
     static create(num: number, shape: number[]): Tensor;
     static getShape(tA: Tensor): number[];
     static padShape(tA: Tensor, tB: Tensor): [Tensor[], Tensor[]];
+    static elementWiseAB(tA: Tensor, tB: Tensor, op: (tA: number, tB: number) => number): Tensor;
+    static elementWiseSelf(tA: Tensor, op: (tA: number) => number): Tensor;
     static add(tA: Tensor, tB: Tensor): Tensor;
     static sub(tA: Tensor, tB: Tensor): Tensor;
     static mul(tA: Tensor, tB: Tensor): Tensor;
@@ -14,8 +16,25 @@ export declare class TensorMath {
     static le(tA: Tensor, tB: Tensor): Tensor;
     static eq(tA: Tensor, tB: Tensor): Tensor;
     static neg(tA: Tensor): Tensor;
+    static abs(tA: Tensor): Tensor;
+    static sign(tA: Tensor): Tensor;
+    static sin(tA: Tensor): Tensor;
+    static cos(tA: Tensor): Tensor;
+    static tan(tA: Tensor): Tensor;
+    static asin(tA: Tensor): Tensor;
+    static acos(tA: Tensor): Tensor;
+    static atan(tA: Tensor): Tensor;
+    static sinh(tA: Tensor): Tensor;
+    static cosh(tA: Tensor): Tensor;
+    static asinh(tA: Tensor): Tensor;
+    static acosh(tA: Tensor): Tensor;
+    static atanh(tA: Tensor): Tensor;
+    static sqrt(tA: Tensor): Tensor;
     static exp(tA: Tensor): Tensor;
     static log(tA: Tensor): Tensor;
+    static log2(tA: Tensor): Tensor;
+    static log10(tA: Tensor): Tensor;
+    static log1p(tA: Tensor): Tensor;
     static relu(tA: Tensor): Tensor;
     static sigmoid(tA: Tensor): Tensor;
     static tanh(tA: Tensor): Tensor;
@@ -23,4 +42,6 @@ export declare class TensorMath {
     static squeeze(tA: Tensor, dims?: number[] | number): Tensor;
     static sumAxis(tA: Tensor, axis: number): Tensor;
     static sum(tA: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
+    static t(tA: Tensor): Tensor;
+    static mm(tA: Tensor, tB: Tensor): Tensor;
 }
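
The two helpers declared above, `elementWiseAB` and `elementWiseSelf`, are the generic recursions that every element-wise op now delegates to. Since they are part of the public surface, a custom op can be sketched on top of them; this `leakyRelu` is purely illustrative and not part of the package, and `alpha` is a made-up parameter:

```js
const { TensorMath } = require("catniff");

// Hypothetical helper: apply x > 0 ? x : alpha * x element-wise over any nesting
const leakyRelu = (tA, alpha = 0.01) =>
    TensorMath.elementWiseSelf(tA, (x) => (x > 0 ? x : alpha * x));

console.log(leakyRelu([[-1, 2], [3, -4]])); // [[-0.01, 2], [3, -0.04]]
```
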
package/dist/tensor.js CHANGED
@@ -35,9 +35,9 @@ class TensorMath {
         }
         return [tA, tB];
     }
-    static add(tA, tB) {
+    static elementWiseAB(tA, tB, op) {
         if (typeof tA === "number" && typeof tB === "number") {
-            return tA + tB;
+            return op(tA, tB);
         }
         [tA, tB] = TensorMath.padShape(tA, tB);
         const outLen = Math.max(tA.length, tB.length);
@@ -48,210 +48,116 @@ class TensorMath {
         for (let i = 0; i < outLen; i++) {
             const subA = tA[tA.length === 1 ? 0 : i];
             const subB = tB[tB.length === 1 ? 0 : i];
-            result.push(TensorMath.add(subA, subB));
+            result.push(TensorMath.elementWiseAB(subA, subB, op));
         }
         return result;
     }
-    static sub(tA, tB) {
-        if (typeof tA === "number" && typeof tB === "number") {
-            return tA - tB;
-        }
-        [tA, tB] = TensorMath.padShape(tA, tB);
-        const outLen = Math.max(tA.length, tB.length);
-        if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
-            throw new Error("Inputs are incompatible tensors");
+    static elementWiseSelf(tA, op) {
+        if (typeof tA === "number") {
+            return op(tA);
         }
-        const result = [];
-        for (let i = 0; i < outLen; i++) {
-            const subA = tA[tA.length === 1 ? 0 : i];
-            const subB = tB[tB.length === 1 ? 0 : i];
-            result.push(TensorMath.sub(subA, subB));
+        else {
+            return tA.map(subA => TensorMath.elementWiseSelf(subA, op));
         }
-        return result;
+    }
+    static add(tA, tB) {
+        return TensorMath.elementWiseAB(tA, tB, (tA, tB) => tA + tB);
+    }
+    static sub(tA, tB) {
+        return TensorMath.elementWiseAB(tA, tB, (tA, tB) => tA - tB);
     }
     static mul(tA, tB) {
-        if (typeof tA === "number" && typeof tB === "number") {
-            return tA * tB;
-        }
-        [tA, tB] = TensorMath.padShape(tA, tB);
-        const outLen = Math.max(tA.length, tB.length);
-        if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
-            throw new Error("Inputs are incompatible tensors");
-        }
-        const result = [];
-        for (let i = 0; i < outLen; i++) {
-            const subA = tA[tA.length === 1 ? 0 : i];
-            const subB = tB[tB.length === 1 ? 0 : i];
-            result.push(TensorMath.mul(subA, subB));
-        }
-        return result;
+        return TensorMath.elementWiseAB(tA, tB, (tA, tB) => tA * tB);
     }
     static pow(tA, tB) {
-        if (typeof tA === "number" && typeof tB === "number") {
-            return tA ** tB;
-        }
-        [tA, tB] = TensorMath.padShape(tA, tB);
-        const outLen = Math.max(tA.length, tB.length);
-        if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
-            throw new Error("Inputs are incompatible tensors");
-        }
-        const result = [];
-        for (let i = 0; i < outLen; i++) {
-            const subA = tA[tA.length === 1 ? 0 : i];
-            const subB = tB[tB.length === 1 ? 0 : i];
-            result.push(TensorMath.pow(subA, subB));
-        }
-        return result;
+        return TensorMath.elementWiseAB(tA, tB, (tA, tB) => tA ** tB);
    }
     static div(tA, tB) {
-        if (typeof tA === "number" && typeof tB === "number") {
-            return tA / tB;
-        }
-        [tA, tB] = TensorMath.padShape(tA, tB);
-        const outLen = Math.max(tA.length, tB.length);
-        if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
-            throw new Error("Inputs are incompatible tensors");
-        }
-        const result = [];
-        for (let i = 0; i < outLen; i++) {
-            const subA = tA[tA.length === 1 ? 0 : i];
-            const subB = tB[tB.length === 1 ? 0 : i];
-            result.push(TensorMath.div(subA, subB));
-        }
-        return result;
+        return TensorMath.elementWiseAB(tA, tB, (tA, tB) => tA / tB);
     }
     static gt(tA, tB) {
-        if (typeof tA === "number" && typeof tB === "number") {
-            return tA > tB ? 1 : 0;
-        }
-        [tA, tB] = TensorMath.padShape(tA, tB);
-        const outLen = Math.max(tA.length, tB.length);
-        if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
-            throw new Error("Inputs are incompatible tensors");
-        }
-        const result = [];
-        for (let i = 0; i < outLen; i++) {
-            const subA = tA[tA.length === 1 ? 0 : i];
-            const subB = tB[tB.length === 1 ? 0 : i];
-            result.push(TensorMath.gt(subA, subB));
-        }
-        return result;
+        return TensorMath.elementWiseAB(tA, tB, (tA, tB) => tA > tB ? 1 : 0);
     }
     static lt(tA, tB) {
-        if (typeof tA === "number" && typeof tB === "number") {
-            return tA < tB ? 1 : 0;
-        }
-        [tA, tB] = TensorMath.padShape(tA, tB);
-        const outLen = Math.max(tA.length, tB.length);
-        if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
-            throw new Error("Inputs are incompatible tensors");
-        }
-        const result = [];
-        for (let i = 0; i < outLen; i++) {
-            const subA = tA[tA.length === 1 ? 0 : i];
-            const subB = tB[tB.length === 1 ? 0 : i];
-            result.push(TensorMath.lt(subA, subB));
-        }
-        return result;
+        return TensorMath.elementWiseAB(tA, tB, (tA, tB) => tA < tB ? 1 : 0);
     }
     static ge(tA, tB) {
-        if (typeof tA === "number" && typeof tB === "number") {
-            return tA >= tB ? 1 : 0;
-        }
-        [tA, tB] = TensorMath.padShape(tA, tB);
-        const outLen = Math.max(tA.length, tB.length);
-        if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
-            throw new Error("Inputs are incompatible tensors");
-        }
-        const result = [];
-        for (let i = 0; i < outLen; i++) {
-            const subA = tA[tA.length === 1 ? 0 : i];
-            const subB = tB[tB.length === 1 ? 0 : i];
-            result.push(TensorMath.ge(subA, subB));
-        }
-        return result;
+        return TensorMath.elementWiseAB(tA, tB, (tA, tB) => tA >= tB ? 1 : 0);
     }
     static le(tA, tB) {
-        if (typeof tA === "number" && typeof tB === "number") {
-            return tA <= tB ? 1 : 0;
-        }
-        [tA, tB] = TensorMath.padShape(tA, tB);
-        const outLen = Math.max(tA.length, tB.length);
-        if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
-            throw new Error("Inputs are incompatible tensors");
-        }
-        const result = [];
-        for (let i = 0; i < outLen; i++) {
-            const subA = tA[tA.length === 1 ? 0 : i];
-            const subB = tB[tB.length === 1 ? 0 : i];
-            result.push(TensorMath.le(subA, subB));
-        }
-        return result;
+        return TensorMath.elementWiseAB(tA, tB, (tA, tB) => tA <= tB ? 1 : 0);
     }
     static eq(tA, tB) {
-        if (typeof tA === "number" && typeof tB === "number") {
-            return tA === tB ? 1 : 0;
-        }
-        [tA, tB] = TensorMath.padShape(tA, tB);
-        const outLen = Math.max(tA.length, tB.length);
-        if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
-            throw new Error("Inputs are incompatible tensors");
-        }
-        const result = [];
-        for (let i = 0; i < outLen; i++) {
-            const subA = tA[tA.length === 1 ? 0 : i];
-            const subB = tB[tB.length === 1 ? 0 : i];
-            result.push(TensorMath.eq(subA, subB));
-        }
-        return result;
+        return TensorMath.elementWiseAB(tA, tB, (tA, tB) => tA === tB ? 1 : 0);
     }
     static neg(tA) {
-        if (typeof tA === "number") {
-            return -tA;
-        }
-        else {
-            return tA.map(subA => TensorMath.neg(subA));
-        }
+        return TensorMath.elementWiseSelf(tA, (tA) => -tA);
+    }
+    static abs(tA) {
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.abs(tA));
+    }
+    static sign(tA) {
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.sign(tA));
+    }
+    static sin(tA) {
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.sin(tA));
+    }
+    static cos(tA) {
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.cos(tA));
+    }
+    static tan(tA) {
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.tan(tA));
+    }
+    static asin(tA) {
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.asin(tA));
+    }
+    static acos(tA) {
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.acos(tA));
+    }
+    static atan(tA) {
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.atan(tA));
+    }
+    static sinh(tA) {
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.sinh(tA));
+    }
+    static cosh(tA) {
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.cosh(tA));
+    }
+    static asinh(tA) {
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.asinh(tA));
+    }
+    static acosh(tA) {
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.acosh(tA));
+    }
+    static atanh(tA) {
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.atanh(tA));
+    }
+    static sqrt(tA) {
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.sqrt(tA));
     }
     static exp(tA) {
-        if (typeof tA === "number") {
-            return Math.exp(tA);
-        }
-        else {
-            return tA.map(subA => TensorMath.exp(subA));
-        }
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.exp(tA));
     }
     static log(tA) {
-        if (typeof tA === "number") {
-            return Math.log(tA);
-        }
-        else {
-            return tA.map(subA => TensorMath.log(subA));
-        }
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.log(tA));
+    }
+    static log2(tA) {
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.log2(tA));
+    }
+    static log10(tA) {
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.log10(tA));
+    }
+    static log1p(tA) {
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.log1p(tA));
     }
     static relu(tA) {
-        if (typeof tA === "number") {
-            return Math.max(tA, 0);
-        }
-        else {
-            return tA.map(subA => TensorMath.relu(subA));
-        }
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.max(tA, 0));
     }
     static sigmoid(tA) {
-        if (typeof tA === "number") {
-            return 1 / (1 + Math.exp(-tA));
-        }
-        else {
-            return tA.map(subA => TensorMath.sigmoid(subA));
-        }
+        return TensorMath.elementWiseSelf(tA, (tA) => 1 / (1 + Math.exp(-tA)));
     }
     static tanh(tA) {
-        if (typeof tA === "number") {
-            return Math.tanh(tA);
-        }
-        else {
-            return tA.map(subA => TensorMath.tanh(subA));
-        }
+        return TensorMath.elementWiseSelf(tA, (tA) => Math.tanh(tA));
     }
     static squeezeAxis(tA, axis) {
         if (typeof tA === "number")
@@ -315,5 +221,43 @@ class TensorMath {
         }
         return keepDims ? out : TensorMath.squeeze(out, dims);
     }
+    static t(tA) {
+        const shapeA = TensorMath.getShape(tA);
+        if (shapeA.length !== 2)
+            throw new Error("Input is not a matrix");
+        const matA = tA;
+        const matARows = matA.length;
+        const matACols = matA[0].length;
+        const matATranspose = Array.from({ length: matACols }, () => new Array(matARows).fill(0));
+        for (let i = 0; i < matARows; i++) {
+            for (let j = 0; j < matACols; j++) {
+                matATranspose[j][i] = matA[i][j];
+            }
+        }
+        return matATranspose;
+    }
+    static mm(tA, tB) {
+        const shapeA = TensorMath.getShape(tA);
+        const shapeB = TensorMath.getShape(tB);
+        if (shapeA.length !== 2 || shapeB.length !== 2)
+            throw new Error("Inputs are not matrices");
+        const matA = tA;
+        const matB = tB;
+        const matARows = matA.length;
+        const matACols = matA[0].length;
+        const matBRows = matB.length;
+        const matBCols = matB[0].length;
+        if (matACols !== matBRows)
+            throw new Error("Invalid matrix shapes for multiplication");
+        const matC = Array.from({ length: matARows }, () => new Array(matBCols).fill(0));
+        for (let i = 0; i < matARows; i++) {
+            for (let j = 0; j < matBCols; j++) {
+                for (let k = 0; k < matACols; k++) {
+                    matC[i][j] += matA[i][k] * matB[k][j];
+                }
+            }
+        }
+        return matC;
+    }
 }
 exports.TensorMath = TensorMath;
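
The new `mm` above is the classic triple-loop matrix-multiply kernel with a shape check, and `t` is a plain index swap. A quick hedged sanity check of the pair (assuming the root export of `TensorMath` as in the README): the transpose identity (A·B)ᵀ = Bᵀ·Aᵀ should hold.

```js
const { TensorMath } = require("catniff");

const A = [[1, 2], [3, 4]];
const B = [[5, 6], [7, 8]];

// (A @ B)^T should equal B^T @ A^T
const lhs = TensorMath.t(TensorMath.mm(A, B));                  // [[19, 43], [22, 50]]
const rhs = TensorMath.mm(TensorMath.t(B), TensorMath.t(A));    // [[19, 43], [22, 50]]
console.log(JSON.stringify(lhs) === JSON.stringify(rhs));       // true
```
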
package/package.json CHANGED
@@ -1,6 +1,6 @@
 {
   "name": "catniff",
-  "version": "0.1.4",
+  "version": "0.1.6",
   "description": "A cute autograd engine for Javascript",
   "main": "index.js",
   "scripts": {
@@ -14,12 +14,18 @@
     "cats",
     "catniff",
     "autograd",
+    "autodiff",
     "ml",
     "dl",
     "ai",
     "maths",
     "gradient",
-    "tensors"
+    "tensors",
+    "library",
+    "framework",
+    "neural-network",
+    "machine-learning",
+    "deep-learning"
   ],
   "author": "nguyenphuminh",
   "license": "GPL-3.0",