catniff 0.1.2 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -2
- package/dist/autograd.d.ts +2 -1
- package/dist/autograd.js +41 -22
- package/dist/tensor.d.ts +5 -0
- package/dist/tensor.js +175 -180
- package/package.json +1 -1
package/README.md
CHANGED
|
@@ -31,7 +31,7 @@ There is a built-in `TensorMath` class to help with Tensor arithmetic, for examp
|
|
|
31
31
|
const { TensorMath } = require("catniff");
|
|
32
32
|
|
|
33
33
|
const A = [ 1, 2, 3 ];
|
|
34
|
-
const B = 3
|
|
34
|
+
const B = 3;
|
|
35
35
|
console.log(TensorMath.add(A, B));
|
|
36
36
|
```
|
|
37
37
|
|
|
@@ -41,7 +41,7 @@ All available APIs are in `./src/tensor.ts`.
|
|
|
41
41
|
|
|
42
42
|
To compute the gradient of our mathematical expression, we use the `Node` class to dynamically build our DAG:
|
|
43
43
|
```js
|
|
44
|
-
const { Node } = require("
|
|
44
|
+
const { Node } = require("catniff");
|
|
45
45
|
|
|
46
46
|
const X = new Node([
|
|
47
47
|
[ 0.5, -1.0 ],
|
|
@@ -69,6 +69,7 @@ All available APIs are in `./src/autograd.ts`.
|
|
|
69
69
|
|
|
70
70
|
I'm mostly just learning and playing with this currently, so there are no concrete plans yet, but here are what I currently have in mind:
|
|
71
71
|
|
|
72
|
+
* Fix whatever is the problem right now (there are a lot of problems right now lol).
|
|
72
73
|
* GPU acceleration.
|
|
73
74
|
* Some general neural net APIs.
|
|
74
75
|
|
package/dist/autograd.d.ts
CHANGED
package/dist/autograd.js
CHANGED
|
@@ -34,35 +34,35 @@ class Node {
|
|
|
34
34
|
this.feedBackward = () => { };
|
|
35
35
|
}
|
|
36
36
|
add(other) {
|
|
37
|
-
other =
|
|
37
|
+
other = Node.forceNode(other);
|
|
38
38
|
const out = new Node(add(this.value, other.value), [this, other], OP.ADD);
|
|
39
39
|
out.feedBackward = () => {
|
|
40
40
|
// x + y d/dx = 1, note that we apply the chain rule continuously so out.grad is multiplied into our derivative
|
|
41
|
-
|
|
41
|
+
Node.addGrad(this, out.grad);
|
|
42
42
|
// x + y d/dy = 1
|
|
43
|
-
|
|
43
|
+
Node.addGrad(other, out.grad);
|
|
44
44
|
};
|
|
45
45
|
return out;
|
|
46
46
|
}
|
|
47
47
|
sub(other) {
|
|
48
|
-
other =
|
|
48
|
+
other = Node.forceNode(other);
|
|
49
49
|
const out = new Node(sub(this.value, other.value), [this, other], OP.SUB);
|
|
50
50
|
out.feedBackward = () => {
|
|
51
51
|
// x - y d/dx = 1
|
|
52
|
-
|
|
52
|
+
Node.addGrad(this, out.grad);
|
|
53
53
|
// x - y d/dy = -1
|
|
54
|
-
|
|
54
|
+
Node.addGrad(other, neg(out.grad));
|
|
55
55
|
};
|
|
56
56
|
return out;
|
|
57
57
|
}
|
|
58
58
|
mul(other) {
|
|
59
|
-
other =
|
|
59
|
+
other = Node.forceNode(other);
|
|
60
60
|
const out = new Node(mul(this.value, other.value), [this, other], OP.MUL);
|
|
61
61
|
out.feedBackward = () => {
|
|
62
62
|
// x * y d/dx = y
|
|
63
|
-
|
|
63
|
+
Node.addGrad(this, mul(out.grad, other.value));
|
|
64
64
|
// x + y d/dy = x
|
|
65
|
-
|
|
65
|
+
Node.addGrad(other, mul(out.grad, this.value));
|
|
66
66
|
};
|
|
67
67
|
return out;
|
|
68
68
|
}
|
|
@@ -71,26 +71,26 @@ class Node {
|
|
|
71
71
|
const out = new Node(pow(this.value, other.value), [this, other], OP.POW);
|
|
72
72
|
out.feedBackward = () => {
|
|
73
73
|
// x^a d/dx = a*x^(a-1)
|
|
74
|
-
|
|
74
|
+
Node.addGrad(this, mul(out.grad, mul(other.value, pow(this.value, sub(other.value, 1)))));
|
|
75
75
|
// x^a d/da = x^a*lnx
|
|
76
|
-
|
|
76
|
+
Node.addGrad(other, mul(out.grad, mul(pow(this.value, other.value), log(this.value))));
|
|
77
77
|
};
|
|
78
78
|
return out;
|
|
79
79
|
}
|
|
80
80
|
const out = new Node(pow(this.value, other), [this], OP.POW);
|
|
81
81
|
out.feedBackward = () => {
|
|
82
|
-
|
|
82
|
+
Node.addGrad(this, mul(out.grad, mul(other, pow(this.value, sub(other, 1)))));
|
|
83
83
|
};
|
|
84
84
|
return out;
|
|
85
85
|
}
|
|
86
86
|
div(other) {
|
|
87
|
-
other =
|
|
87
|
+
other = Node.forceNode(other);
|
|
88
88
|
const out = new Node(div(this.value, other.value), [this, other], OP.DIV);
|
|
89
89
|
out.feedBackward = () => {
|
|
90
90
|
// x/y d/dx = 1/y
|
|
91
|
-
|
|
91
|
+
Node.addGrad(this, div(out.grad, other.value));
|
|
92
92
|
// x/y d/dy = -x/y^2
|
|
93
|
-
|
|
93
|
+
Node.addGrad(other, mul(out.grad, div(neg(this.value), pow(other.value, 2))));
|
|
94
94
|
};
|
|
95
95
|
return out;
|
|
96
96
|
}
|
|
@@ -98,7 +98,7 @@ class Node {
|
|
|
98
98
|
const out = new Node(neg(this.value), [this], OP.NEG);
|
|
99
99
|
out.feedBackward = () => {
|
|
100
100
|
// -x d/dx = -1
|
|
101
|
-
|
|
101
|
+
Node.addGrad(this, neg(out.grad));
|
|
102
102
|
};
|
|
103
103
|
return out;
|
|
104
104
|
}
|
|
@@ -107,7 +107,7 @@ class Node {
|
|
|
107
107
|
const out = new Node(expResult, [this], OP.EXP);
|
|
108
108
|
out.feedBackward = () => {
|
|
109
109
|
// e^x d/dx = e^x
|
|
110
|
-
|
|
110
|
+
Node.addGrad(this, mul(out.grad, expResult));
|
|
111
111
|
};
|
|
112
112
|
return out;
|
|
113
113
|
}
|
|
@@ -115,14 +115,14 @@ class Node {
|
|
|
115
115
|
const out = new Node(log(this.value), [this], OP.LOG);
|
|
116
116
|
out.feedBackward = () => {
|
|
117
117
|
// lnx d/dx = 1/x
|
|
118
|
-
|
|
118
|
+
Node.addGrad(this, div(out.grad, this.value));
|
|
119
119
|
};
|
|
120
120
|
return out;
|
|
121
121
|
}
|
|
122
122
|
relu() {
|
|
123
123
|
const out = new Node(relu(this.value), [this], OP.RELU);
|
|
124
124
|
out.feedBackward = () => {
|
|
125
|
-
|
|
125
|
+
Node.addGrad(this, mul(out.grad, ge(this.value, 0)));
|
|
126
126
|
};
|
|
127
127
|
return out;
|
|
128
128
|
}
|
|
@@ -130,7 +130,7 @@ class Node {
|
|
|
130
130
|
const sigmoidResult = sigmoid(this.value);
|
|
131
131
|
const out = new Node(sigmoidResult, [this], OP.SIGMOID);
|
|
132
132
|
out.feedBackward = () => {
|
|
133
|
-
|
|
133
|
+
Node.addGrad(this, mul(mul(out.grad, sigmoidResult), sub(1, sigmoidResult)));
|
|
134
134
|
};
|
|
135
135
|
return out;
|
|
136
136
|
}
|
|
@@ -138,7 +138,7 @@ class Node {
|
|
|
138
138
|
const tanhResult = tanh(this.value);
|
|
139
139
|
const out = new Node(tanhResult, [this], OP.TANH);
|
|
140
140
|
out.feedBackward = () => {
|
|
141
|
-
|
|
141
|
+
Node.addGrad(this, mul(out.grad, sub(1, mul(tanhResult, tanhResult))));
|
|
142
142
|
};
|
|
143
143
|
return out;
|
|
144
144
|
}
|
|
@@ -162,10 +162,29 @@ class Node {
|
|
|
162
162
|
topo[index].feedBackward();
|
|
163
163
|
}
|
|
164
164
|
}
|
|
165
|
-
forceNode(value) {
|
|
165
|
+
static forceNode(value) {
|
|
166
166
|
if (value instanceof Node)
|
|
167
167
|
return value;
|
|
168
168
|
return new Node(value);
|
|
169
169
|
}
|
|
170
|
+
static addGrad(node, accumGrad) {
|
|
171
|
+
const axesToSqueeze = [];
|
|
172
|
+
const axesToReduce = [];
|
|
173
|
+
const shape = node.shape;
|
|
174
|
+
const gradShape = tensor_1.TensorMath.getShape(accumGrad);
|
|
175
|
+
const paddedDims = gradShape.length - shape.length;
|
|
176
|
+
for (let i = 0; i < paddedDims; i++) {
|
|
177
|
+
axesToReduce.push(i);
|
|
178
|
+
axesToSqueeze.push(i);
|
|
179
|
+
}
|
|
180
|
+
for (let i = 0; i < shape.length; i++) {
|
|
181
|
+
if (shape[i] === 1 && gradShape[i + paddedDims] > 1) {
|
|
182
|
+
axesToReduce.push(i + paddedDims);
|
|
183
|
+
}
|
|
184
|
+
}
|
|
185
|
+
const reducedGrad = tensor_1.TensorMath.sum(accumGrad, axesToReduce, true);
|
|
186
|
+
const squeezedGrad = tensor_1.TensorMath.squeeze(reducedGrad, axesToSqueeze);
|
|
187
|
+
node.grad = add(squeezedGrad, node.grad);
|
|
188
|
+
}
|
|
170
189
|
}
|
|
171
190
|
exports.Node = Node;
|
package/dist/tensor.d.ts
CHANGED
|
@@ -2,6 +2,7 @@ export type Tensor = number | Tensor[];
|
|
|
2
2
|
export declare class TensorMath {
|
|
3
3
|
static create(num: number, shape: number[]): Tensor;
|
|
4
4
|
static getShape(tA: Tensor): number[];
|
|
5
|
+
static padShape(tA: Tensor, tB: Tensor): [Tensor[], Tensor[]];
|
|
5
6
|
static add(tA: Tensor, tB: Tensor): Tensor;
|
|
6
7
|
static sub(tA: Tensor, tB: Tensor): Tensor;
|
|
7
8
|
static mul(tA: Tensor, tB: Tensor): Tensor;
|
|
@@ -18,4 +19,8 @@ export declare class TensorMath {
|
|
|
18
19
|
static relu(tA: Tensor): Tensor;
|
|
19
20
|
static sigmoid(tA: Tensor): Tensor;
|
|
20
21
|
static tanh(tA: Tensor): Tensor;
|
|
22
|
+
static squeezeAxis(tA: Tensor, axis: number): Tensor;
|
|
23
|
+
static squeeze(tA: Tensor, dims?: number[] | number): Tensor;
|
|
24
|
+
static sumAxis(tA: Tensor, axis: number): Tensor;
|
|
25
|
+
static sum(tA: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
|
|
21
26
|
}
|
package/dist/tensor.js
CHANGED
|
@@ -22,255 +22,188 @@ class TensorMath {
|
|
|
22
22
|
}
|
|
23
23
|
return shape;
|
|
24
24
|
}
|
|
25
|
+
static padShape(tA, tB) {
|
|
26
|
+
let dimA = TensorMath.getShape(tA).length;
|
|
27
|
+
let dimB = TensorMath.getShape(tB).length;
|
|
28
|
+
while (dimA < dimB) {
|
|
29
|
+
dimA++;
|
|
30
|
+
tA = [tA];
|
|
31
|
+
}
|
|
32
|
+
while (dimA > dimB) {
|
|
33
|
+
dimB++;
|
|
34
|
+
tB = [tB];
|
|
35
|
+
}
|
|
36
|
+
return [tA, tB];
|
|
37
|
+
}
|
|
25
38
|
static add(tA, tB) {
|
|
26
39
|
if (typeof tA === "number" && typeof tB === "number") {
|
|
27
40
|
return tA + tB;
|
|
28
41
|
}
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
}
|
|
34
|
-
const result = [];
|
|
35
|
-
for (let i = 0; i < outLen; i++) {
|
|
36
|
-
const subA = tA[tA.length === 1 ? 0 : i];
|
|
37
|
-
const subB = tB[tB.length === 1 ? 0 : i];
|
|
38
|
-
result.push(TensorMath.add(subA, subB));
|
|
39
|
-
}
|
|
40
|
-
return result;
|
|
41
|
-
}
|
|
42
|
-
else if (Array.isArray(tA) && typeof tB === "number") {
|
|
43
|
-
return tA.map(subA => TensorMath.add(subA, tB));
|
|
42
|
+
[tA, tB] = TensorMath.padShape(tA, tB);
|
|
43
|
+
const outLen = Math.max(tA.length, tB.length);
|
|
44
|
+
if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
|
|
45
|
+
throw new Error("Inputs are incompatible tensors");
|
|
44
46
|
}
|
|
45
|
-
|
|
46
|
-
|
|
47
|
+
const result = [];
|
|
48
|
+
for (let i = 0; i < outLen; i++) {
|
|
49
|
+
const subA = tA[tA.length === 1 ? 0 : i];
|
|
50
|
+
const subB = tB[tB.length === 1 ? 0 : i];
|
|
51
|
+
result.push(TensorMath.add(subA, subB));
|
|
47
52
|
}
|
|
48
|
-
|
|
53
|
+
return result;
|
|
49
54
|
}
|
|
50
55
|
static sub(tA, tB) {
|
|
51
56
|
if (typeof tA === "number" && typeof tB === "number") {
|
|
52
57
|
return tA - tB;
|
|
53
58
|
}
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
}
|
|
59
|
-
const result = [];
|
|
60
|
-
for (let i = 0; i < outLen; i++) {
|
|
61
|
-
const subA = tA[tA.length === 1 ? 0 : i];
|
|
62
|
-
const subB = tB[tB.length === 1 ? 0 : i];
|
|
63
|
-
result.push(TensorMath.sub(subA, subB));
|
|
64
|
-
}
|
|
65
|
-
return result;
|
|
66
|
-
}
|
|
67
|
-
else if (Array.isArray(tA) && typeof tB === "number") {
|
|
68
|
-
return tA.map(subA => TensorMath.sub(subA, tB));
|
|
59
|
+
[tA, tB] = TensorMath.padShape(tA, tB);
|
|
60
|
+
const outLen = Math.max(tA.length, tB.length);
|
|
61
|
+
if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
|
|
62
|
+
throw new Error("Inputs are incompatible tensors");
|
|
69
63
|
}
|
|
70
|
-
|
|
71
|
-
|
|
64
|
+
const result = [];
|
|
65
|
+
for (let i = 0; i < outLen; i++) {
|
|
66
|
+
const subA = tA[tA.length === 1 ? 0 : i];
|
|
67
|
+
const subB = tB[tB.length === 1 ? 0 : i];
|
|
68
|
+
result.push(TensorMath.sub(subA, subB));
|
|
72
69
|
}
|
|
73
|
-
|
|
70
|
+
return result;
|
|
74
71
|
}
|
|
75
72
|
static mul(tA, tB) {
|
|
76
73
|
if (typeof tA === "number" && typeof tB === "number") {
|
|
77
74
|
return tA * tB;
|
|
78
75
|
}
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
}
|
|
84
|
-
const result = [];
|
|
85
|
-
for (let i = 0; i < outLen; i++) {
|
|
86
|
-
const subA = tA[tA.length === 1 ? 0 : i];
|
|
87
|
-
const subB = tB[tB.length === 1 ? 0 : i];
|
|
88
|
-
result.push(TensorMath.mul(subA, subB));
|
|
89
|
-
}
|
|
90
|
-
return result;
|
|
91
|
-
}
|
|
92
|
-
else if (Array.isArray(tA) && typeof tB === "number") {
|
|
93
|
-
return tA.map(subA => TensorMath.mul(subA, tB));
|
|
76
|
+
[tA, tB] = TensorMath.padShape(tA, tB);
|
|
77
|
+
const outLen = Math.max(tA.length, tB.length);
|
|
78
|
+
if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
|
|
79
|
+
throw new Error("Inputs are incompatible tensors");
|
|
94
80
|
}
|
|
95
|
-
|
|
96
|
-
|
|
81
|
+
const result = [];
|
|
82
|
+
for (let i = 0; i < outLen; i++) {
|
|
83
|
+
const subA = tA[tA.length === 1 ? 0 : i];
|
|
84
|
+
const subB = tB[tB.length === 1 ? 0 : i];
|
|
85
|
+
result.push(TensorMath.mul(subA, subB));
|
|
97
86
|
}
|
|
98
|
-
|
|
87
|
+
return result;
|
|
99
88
|
}
|
|
100
89
|
static pow(tA, tB) {
|
|
101
90
|
if (typeof tA === "number" && typeof tB === "number") {
|
|
102
91
|
return tA ** tB;
|
|
103
92
|
}
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
}
|
|
109
|
-
const result = [];
|
|
110
|
-
for (let i = 0; i < outLen; i++) {
|
|
111
|
-
const subA = tA[tA.length === 1 ? 0 : i];
|
|
112
|
-
const subB = tB[tB.length === 1 ? 0 : i];
|
|
113
|
-
result.push(TensorMath.pow(subA, subB));
|
|
114
|
-
}
|
|
115
|
-
return result;
|
|
116
|
-
}
|
|
117
|
-
else if (Array.isArray(tA) && typeof tB === "number") {
|
|
118
|
-
return tA.map(subA => TensorMath.pow(subA, tB));
|
|
93
|
+
[tA, tB] = TensorMath.padShape(tA, tB);
|
|
94
|
+
const outLen = Math.max(tA.length, tB.length);
|
|
95
|
+
if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
|
|
96
|
+
throw new Error("Inputs are incompatible tensors");
|
|
119
97
|
}
|
|
120
|
-
|
|
121
|
-
|
|
98
|
+
const result = [];
|
|
99
|
+
for (let i = 0; i < outLen; i++) {
|
|
100
|
+
const subA = tA[tA.length === 1 ? 0 : i];
|
|
101
|
+
const subB = tB[tB.length === 1 ? 0 : i];
|
|
102
|
+
result.push(TensorMath.pow(subA, subB));
|
|
122
103
|
}
|
|
123
|
-
|
|
104
|
+
return result;
|
|
124
105
|
}
|
|
125
106
|
static div(tA, tB) {
|
|
126
107
|
if (typeof tA === "number" && typeof tB === "number") {
|
|
127
108
|
return tA / tB;
|
|
128
109
|
}
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
}
|
|
134
|
-
const result = [];
|
|
135
|
-
for (let i = 0; i < outLen; i++) {
|
|
136
|
-
const subA = tA[tA.length === 1 ? 0 : i];
|
|
137
|
-
const subB = tB[tB.length === 1 ? 0 : i];
|
|
138
|
-
result.push(TensorMath.div(subA, subB));
|
|
139
|
-
}
|
|
140
|
-
return result;
|
|
141
|
-
}
|
|
142
|
-
else if (Array.isArray(tA) && typeof tB === "number") {
|
|
143
|
-
return tA.map(subA => TensorMath.div(subA, tB));
|
|
110
|
+
[tA, tB] = TensorMath.padShape(tA, tB);
|
|
111
|
+
const outLen = Math.max(tA.length, tB.length);
|
|
112
|
+
if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
|
|
113
|
+
throw new Error("Inputs are incompatible tensors");
|
|
144
114
|
}
|
|
145
|
-
|
|
146
|
-
|
|
115
|
+
const result = [];
|
|
116
|
+
for (let i = 0; i < outLen; i++) {
|
|
117
|
+
const subA = tA[tA.length === 1 ? 0 : i];
|
|
118
|
+
const subB = tB[tB.length === 1 ? 0 : i];
|
|
119
|
+
result.push(TensorMath.div(subA, subB));
|
|
147
120
|
}
|
|
148
|
-
|
|
121
|
+
return result;
|
|
149
122
|
}
|
|
150
123
|
static gt(tA, tB) {
|
|
151
124
|
if (typeof tA === "number" && typeof tB === "number") {
|
|
152
125
|
return tA > tB ? 1 : 0;
|
|
153
126
|
}
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
}
|
|
159
|
-
const result = [];
|
|
160
|
-
for (let i = 0; i < outLen; i++) {
|
|
161
|
-
const subA = tA[tA.length === 1 ? 0 : i];
|
|
162
|
-
const subB = tB[tB.length === 1 ? 0 : i];
|
|
163
|
-
result.push(TensorMath.gt(subA, subB));
|
|
164
|
-
}
|
|
165
|
-
return result;
|
|
166
|
-
}
|
|
167
|
-
else if (Array.isArray(tA) && typeof tB === "number") {
|
|
168
|
-
return tA.map(subA => TensorMath.gt(subA, tB));
|
|
127
|
+
[tA, tB] = TensorMath.padShape(tA, tB);
|
|
128
|
+
const outLen = Math.max(tA.length, tB.length);
|
|
129
|
+
if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
|
|
130
|
+
throw new Error("Inputs are incompatible tensors");
|
|
169
131
|
}
|
|
170
|
-
|
|
171
|
-
|
|
132
|
+
const result = [];
|
|
133
|
+
for (let i = 0; i < outLen; i++) {
|
|
134
|
+
const subA = tA[tA.length === 1 ? 0 : i];
|
|
135
|
+
const subB = tB[tB.length === 1 ? 0 : i];
|
|
136
|
+
result.push(TensorMath.gt(subA, subB));
|
|
172
137
|
}
|
|
173
|
-
|
|
138
|
+
return result;
|
|
174
139
|
}
|
|
175
140
|
static lt(tA, tB) {
|
|
176
141
|
if (typeof tA === "number" && typeof tB === "number") {
|
|
177
142
|
return tA < tB ? 1 : 0;
|
|
178
143
|
}
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
}
|
|
184
|
-
const result = [];
|
|
185
|
-
for (let i = 0; i < outLen; i++) {
|
|
186
|
-
const subA = tA[tA.length === 1 ? 0 : i];
|
|
187
|
-
const subB = tB[tB.length === 1 ? 0 : i];
|
|
188
|
-
result.push(TensorMath.lt(subA, subB));
|
|
189
|
-
}
|
|
190
|
-
return result;
|
|
191
|
-
}
|
|
192
|
-
else if (Array.isArray(tA) && typeof tB === "number") {
|
|
193
|
-
return tA.map(subA => TensorMath.lt(subA, tB));
|
|
144
|
+
[tA, tB] = TensorMath.padShape(tA, tB);
|
|
145
|
+
const outLen = Math.max(tA.length, tB.length);
|
|
146
|
+
if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
|
|
147
|
+
throw new Error("Inputs are incompatible tensors");
|
|
194
148
|
}
|
|
195
|
-
|
|
196
|
-
|
|
149
|
+
const result = [];
|
|
150
|
+
for (let i = 0; i < outLen; i++) {
|
|
151
|
+
const subA = tA[tA.length === 1 ? 0 : i];
|
|
152
|
+
const subB = tB[tB.length === 1 ? 0 : i];
|
|
153
|
+
result.push(TensorMath.lt(subA, subB));
|
|
197
154
|
}
|
|
198
|
-
|
|
155
|
+
return result;
|
|
199
156
|
}
|
|
200
157
|
static ge(tA, tB) {
|
|
201
158
|
if (typeof tA === "number" && typeof tB === "number") {
|
|
202
159
|
return tA >= tB ? 1 : 0;
|
|
203
160
|
}
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
}
|
|
209
|
-
const result = [];
|
|
210
|
-
for (let i = 0; i < outLen; i++) {
|
|
211
|
-
const subA = tA[tA.length === 1 ? 0 : i];
|
|
212
|
-
const subB = tB[tB.length === 1 ? 0 : i];
|
|
213
|
-
result.push(TensorMath.ge(subA, subB));
|
|
214
|
-
}
|
|
215
|
-
return result;
|
|
216
|
-
}
|
|
217
|
-
else if (Array.isArray(tA) && typeof tB === "number") {
|
|
218
|
-
return tA.map(subA => TensorMath.ge(subA, tB));
|
|
161
|
+
[tA, tB] = TensorMath.padShape(tA, tB);
|
|
162
|
+
const outLen = Math.max(tA.length, tB.length);
|
|
163
|
+
if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
|
|
164
|
+
throw new Error("Inputs are incompatible tensors");
|
|
219
165
|
}
|
|
220
|
-
|
|
221
|
-
|
|
166
|
+
const result = [];
|
|
167
|
+
for (let i = 0; i < outLen; i++) {
|
|
168
|
+
const subA = tA[tA.length === 1 ? 0 : i];
|
|
169
|
+
const subB = tB[tB.length === 1 ? 0 : i];
|
|
170
|
+
result.push(TensorMath.ge(subA, subB));
|
|
222
171
|
}
|
|
223
|
-
|
|
172
|
+
return result;
|
|
224
173
|
}
|
|
225
174
|
static le(tA, tB) {
|
|
226
175
|
if (typeof tA === "number" && typeof tB === "number") {
|
|
227
176
|
return tA <= tB ? 1 : 0;
|
|
228
177
|
}
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
}
|
|
234
|
-
const result = [];
|
|
235
|
-
for (let i = 0; i < outLen; i++) {
|
|
236
|
-
const subA = tA[tA.length === 1 ? 0 : i];
|
|
237
|
-
const subB = tB[tB.length === 1 ? 0 : i];
|
|
238
|
-
result.push(TensorMath.le(subA, subB));
|
|
239
|
-
}
|
|
240
|
-
return result;
|
|
241
|
-
}
|
|
242
|
-
else if (Array.isArray(tA) && typeof tB === "number") {
|
|
243
|
-
return tA.map(subA => TensorMath.le(subA, tB));
|
|
178
|
+
[tA, tB] = TensorMath.padShape(tA, tB);
|
|
179
|
+
const outLen = Math.max(tA.length, tB.length);
|
|
180
|
+
if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
|
|
181
|
+
throw new Error("Inputs are incompatible tensors");
|
|
244
182
|
}
|
|
245
|
-
|
|
246
|
-
|
|
183
|
+
const result = [];
|
|
184
|
+
for (let i = 0; i < outLen; i++) {
|
|
185
|
+
const subA = tA[tA.length === 1 ? 0 : i];
|
|
186
|
+
const subB = tB[tB.length === 1 ? 0 : i];
|
|
187
|
+
result.push(TensorMath.le(subA, subB));
|
|
247
188
|
}
|
|
248
|
-
|
|
189
|
+
return result;
|
|
249
190
|
}
|
|
250
191
|
static eq(tA, tB) {
|
|
251
192
|
if (typeof tA === "number" && typeof tB === "number") {
|
|
252
193
|
return tA === tB ? 1 : 0;
|
|
253
194
|
}
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
}
|
|
259
|
-
const result = [];
|
|
260
|
-
for (let i = 0; i < outLen; i++) {
|
|
261
|
-
const subA = tA[tA.length === 1 ? 0 : i];
|
|
262
|
-
const subB = tB[tB.length === 1 ? 0 : i];
|
|
263
|
-
result.push(TensorMath.eq(subA, subB));
|
|
264
|
-
}
|
|
265
|
-
return result;
|
|
266
|
-
}
|
|
267
|
-
else if (Array.isArray(tA) && typeof tB === "number") {
|
|
268
|
-
return tA.map(subA => TensorMath.eq(subA, tB));
|
|
195
|
+
[tA, tB] = TensorMath.padShape(tA, tB);
|
|
196
|
+
const outLen = Math.max(tA.length, tB.length);
|
|
197
|
+
if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
|
|
198
|
+
throw new Error("Inputs are incompatible tensors");
|
|
269
199
|
}
|
|
270
|
-
|
|
271
|
-
|
|
200
|
+
const result = [];
|
|
201
|
+
for (let i = 0; i < outLen; i++) {
|
|
202
|
+
const subA = tA[tA.length === 1 ? 0 : i];
|
|
203
|
+
const subB = tB[tB.length === 1 ? 0 : i];
|
|
204
|
+
result.push(TensorMath.eq(subA, subB));
|
|
272
205
|
}
|
|
273
|
-
|
|
206
|
+
return result;
|
|
274
207
|
}
|
|
275
208
|
static neg(tA) {
|
|
276
209
|
if (typeof tA === "number") {
|
|
@@ -320,5 +253,67 @@ class TensorMath {
|
|
|
320
253
|
return tA.map(subA => TensorMath.tanh(subA));
|
|
321
254
|
}
|
|
322
255
|
}
|
|
256
|
+
static squeezeAxis(tA, axis) {
|
|
257
|
+
if (typeof tA === "number")
|
|
258
|
+
return tA;
|
|
259
|
+
if (axis === 0) {
|
|
260
|
+
return tA[0];
|
|
261
|
+
}
|
|
262
|
+
else {
|
|
263
|
+
return tA.map(slice => TensorMath.squeezeAxis(slice, axis - 1));
|
|
264
|
+
}
|
|
265
|
+
}
|
|
266
|
+
static squeeze(tA, dims) {
|
|
267
|
+
if (typeof tA === "number")
|
|
268
|
+
return tA;
|
|
269
|
+
if (typeof dims === "number") {
|
|
270
|
+
dims = [dims];
|
|
271
|
+
}
|
|
272
|
+
if (typeof dims === "undefined") {
|
|
273
|
+
const shape = TensorMath.getShape(tA);
|
|
274
|
+
dims = [];
|
|
275
|
+
for (let index = 0; index < shape.length; index++) {
|
|
276
|
+
if (shape[index] === 1) {
|
|
277
|
+
dims.push(index);
|
|
278
|
+
}
|
|
279
|
+
}
|
|
280
|
+
}
|
|
281
|
+
dims = [...dims].sort((a, b) => b - a);
|
|
282
|
+
let out = tA;
|
|
283
|
+
for (const axis of dims) {
|
|
284
|
+
out = TensorMath.squeezeAxis(out, axis);
|
|
285
|
+
}
|
|
286
|
+
return out;
|
|
287
|
+
}
|
|
288
|
+
static sumAxis(tA, axis) {
|
|
289
|
+
if (typeof tA === "number")
|
|
290
|
+
return tA;
|
|
291
|
+
if (axis === 0) {
|
|
292
|
+
let result = tA[0];
|
|
293
|
+
for (let i = 1; i < tA.length; i++) {
|
|
294
|
+
result = TensorMath.add(result, tA[i]);
|
|
295
|
+
}
|
|
296
|
+
return [result];
|
|
297
|
+
}
|
|
298
|
+
else {
|
|
299
|
+
return tA.map(slice => TensorMath.sumAxis(slice, axis - 1));
|
|
300
|
+
}
|
|
301
|
+
}
|
|
302
|
+
static sum(tA, dims, keepDims = false) {
|
|
303
|
+
if (typeof tA === "number")
|
|
304
|
+
return tA;
|
|
305
|
+
if (typeof dims === "number") {
|
|
306
|
+
dims = [dims];
|
|
307
|
+
}
|
|
308
|
+
if (typeof dims === "undefined") {
|
|
309
|
+
dims = Array.from({ length: TensorMath.getShape(tA).length }, (_, index) => index);
|
|
310
|
+
}
|
|
311
|
+
dims = [...dims].sort((a, b) => b - a);
|
|
312
|
+
let out = tA;
|
|
313
|
+
for (const axis of dims) {
|
|
314
|
+
out = TensorMath.sumAxis(out, axis);
|
|
315
|
+
}
|
|
316
|
+
return keepDims ? out : TensorMath.squeeze(out, dims);
|
|
317
|
+
}
|
|
323
318
|
}
|
|
324
319
|
exports.TensorMath = TensorMath;
|