catniff 0.1.1 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -1,4 +1,4 @@
1
- ## Catniff
1
+ # Catniff
2
2
 
3
3
  Catniff is a small and experimental tensor library and autograd engine inspired by [micrograd](https://github.com/karpathy/micrograd). The name is a play on "catnip" and "differentiation".
4
4
 
@@ -31,7 +31,7 @@ There is a built-in `TensorMath` class to help with Tensor arithmetic, for examp
31
31
  const { TensorMath } = require("catniff");
32
32
 
33
33
  const A = [ 1, 2, 3 ];
34
- const B = 3
34
+ const B = 3;
35
35
  console.log(TensorMath.add(A, B));
36
36
  ```
37
37
 
@@ -43,11 +43,24 @@ To compute the gradient of our mathematical expression, we use the `Node` class
43
43
  ```js
44
44
  const { Node } = require("catniff");
45
45
 
46
- const X = new Node([ 1, 2, 3 ]);
47
- const L = x.pow(2).add(x); // X^2 + X
46
+ const X = new Node([
47
+ [ 0.5, -1.0 ],
48
+ [ 2.0, 0.0 ]
49
+ ]);
48
50
 
49
- L.backward();
50
- console.log(x.grad);
51
+ const Y = new Node([
52
+ [ 1.0, -2.0 ],
53
+ [ 0.5, 1.5 ]
54
+ ]);
55
+
56
+ const D = X.sub(Y);
57
+ const E = D.exp();
58
+ const F = E.add(1);
59
+ const G = F.log();
60
+
61
+ G.backward();
62
+
63
+ console.log(X.grad, Y.grad);
51
64
  ```
52
65
 
53
66
  All available APIs are in `./src/autograd.ts`.
package/dist/tensor.d.ts CHANGED
@@ -2,6 +2,7 @@ export type Tensor = number | Tensor[];
2
2
  export declare class TensorMath {
3
3
  static create(num: number, shape: number[]): Tensor;
4
4
  static getShape(tA: Tensor): number[];
5
+ static padShape(tA: Tensor, tB: Tensor): [Tensor[], Tensor[]];
5
6
  static add(tA: Tensor, tB: Tensor): Tensor;
6
7
  static sub(tA: Tensor, tB: Tensor): Tensor;
7
8
  static mul(tA: Tensor, tB: Tensor): Tensor;
@@ -11,6 +12,7 @@ export declare class TensorMath {
11
12
  static lt(tA: Tensor, tB: Tensor): Tensor;
12
13
  static ge(tA: Tensor, tB: Tensor): Tensor;
13
14
  static le(tA: Tensor, tB: Tensor): Tensor;
15
+ static eq(tA: Tensor, tB: Tensor): Tensor;
14
16
  static neg(tA: Tensor): Tensor;
15
17
  static exp(tA: Tensor): Tensor;
16
18
  static log(tA: Tensor): Tensor;
package/dist/tensor.js CHANGED
@@ -22,230 +22,188 @@ class TensorMath {
22
22
  }
23
23
  return shape;
24
24
  }
25
+ static padShape(tA, tB) {
26
+ let dimA = TensorMath.getShape(tA).length;
27
+ let dimB = TensorMath.getShape(tB).length;
28
+ while (dimA < dimB) {
29
+ dimA++;
30
+ tA = [tA];
31
+ }
32
+ while (dimA > dimB) {
33
+ dimB++;
34
+ tB = [tB];
35
+ }
36
+ return [tA, tB];
37
+ }
25
38
  static add(tA, tB) {
26
39
  if (typeof tA === "number" && typeof tB === "number") {
27
40
  return tA + tB;
28
41
  }
29
- else if (Array.isArray(tA) && Array.isArray(tB)) {
30
- const outLen = Math.max(tA.length, tB.length);
31
- if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
32
- throw new Error("Inputs are incompatible tensors");
33
- }
34
- const result = [];
35
- for (let i = 0; i < outLen; i++) {
36
- const subA = tA[tA.length === 1 ? 0 : i];
37
- const subB = tB[tB.length === 1 ? 0 : i];
38
- result.push(TensorMath.add(subA, subB));
39
- }
40
- return result;
41
- }
42
- else if (Array.isArray(tA) && typeof tB === "number") {
43
- return tA.map(subA => TensorMath.add(subA, tB));
44
- }
45
- else if (typeof tA === "number" && Array.isArray(tB)) {
46
- return tB.map(subB => TensorMath.add(tA, subB));
47
- }
48
- throw new Error("Inputs are not tensors");
42
+ [tA, tB] = TensorMath.padShape(tA, tB);
43
+ const outLen = Math.max(tA.length, tB.length);
44
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
45
+ throw new Error("Inputs are incompatible tensors");
46
+ }
47
+ const result = [];
48
+ for (let i = 0; i < outLen; i++) {
49
+ const subA = tA[tA.length === 1 ? 0 : i];
50
+ const subB = tB[tB.length === 1 ? 0 : i];
51
+ result.push(TensorMath.add(subA, subB));
52
+ }
53
+ return result;
49
54
  }
50
55
  static sub(tA, tB) {
51
56
  if (typeof tA === "number" && typeof tB === "number") {
52
57
  return tA - tB;
53
58
  }
54
- else if (Array.isArray(tA) && Array.isArray(tB)) {
55
- const outLen = Math.max(tA.length, tB.length);
56
- if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
57
- throw new Error("Inputs are incompatible tensors");
58
- }
59
- const result = [];
60
- for (let i = 0; i < outLen; i++) {
61
- const subA = tA[tA.length === 1 ? 0 : i];
62
- const subB = tB[tB.length === 1 ? 0 : i];
63
- result.push(TensorMath.sub(subA, subB));
64
- }
65
- return result;
66
- }
67
- else if (Array.isArray(tA) && typeof tB === "number") {
68
- return tA.map(subA => TensorMath.sub(subA, tB));
69
- }
70
- else if (typeof tA === "number" && Array.isArray(tB)) {
71
- return tB.map(subB => TensorMath.sub(tA, subB));
72
- }
73
- throw new Error("Inputs are not tensors");
59
+ [tA, tB] = TensorMath.padShape(tA, tB);
60
+ const outLen = Math.max(tA.length, tB.length);
61
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
62
+ throw new Error("Inputs are incompatible tensors");
63
+ }
64
+ const result = [];
65
+ for (let i = 0; i < outLen; i++) {
66
+ const subA = tA[tA.length === 1 ? 0 : i];
67
+ const subB = tB[tB.length === 1 ? 0 : i];
68
+ result.push(TensorMath.sub(subA, subB));
69
+ }
70
+ return result;
74
71
  }
75
72
  static mul(tA, tB) {
76
73
  if (typeof tA === "number" && typeof tB === "number") {
77
74
  return tA * tB;
78
75
  }
79
- else if (Array.isArray(tA) && Array.isArray(tB)) {
80
- const outLen = Math.max(tA.length, tB.length);
81
- if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
82
- throw new Error("Inputs are incompatible tensors");
83
- }
84
- const result = [];
85
- for (let i = 0; i < outLen; i++) {
86
- const subA = tA[tA.length === 1 ? 0 : i];
87
- const subB = tB[tB.length === 1 ? 0 : i];
88
- result.push(TensorMath.mul(subA, subB));
89
- }
90
- return result;
91
- }
92
- else if (Array.isArray(tA) && typeof tB === "number") {
93
- return tA.map(subA => TensorMath.mul(subA, tB));
94
- }
95
- else if (typeof tA === "number" && Array.isArray(tB)) {
96
- return tB.map(subB => TensorMath.mul(tA, subB));
97
- }
98
- throw new Error("Inputs are not tensors");
76
+ [tA, tB] = TensorMath.padShape(tA, tB);
77
+ const outLen = Math.max(tA.length, tB.length);
78
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
79
+ throw new Error("Inputs are incompatible tensors");
80
+ }
81
+ const result = [];
82
+ for (let i = 0; i < outLen; i++) {
83
+ const subA = tA[tA.length === 1 ? 0 : i];
84
+ const subB = tB[tB.length === 1 ? 0 : i];
85
+ result.push(TensorMath.mul(subA, subB));
86
+ }
87
+ return result;
99
88
  }
100
89
  static pow(tA, tB) {
101
90
  if (typeof tA === "number" && typeof tB === "number") {
102
91
  return tA ** tB;
103
92
  }
104
- else if (Array.isArray(tA) && Array.isArray(tB)) {
105
- const outLen = Math.max(tA.length, tB.length);
106
- if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
107
- throw new Error("Inputs are incompatible tensors");
108
- }
109
- const result = [];
110
- for (let i = 0; i < outLen; i++) {
111
- const subA = tA[tA.length === 1 ? 0 : i];
112
- const subB = tB[tB.length === 1 ? 0 : i];
113
- result.push(TensorMath.pow(subA, subB));
114
- }
115
- return result;
116
- }
117
- else if (Array.isArray(tA) && typeof tB === "number") {
118
- return tA.map(subA => TensorMath.pow(subA, tB));
119
- }
120
- else if (typeof tA === "number" && Array.isArray(tB)) {
121
- return tB.map(subB => TensorMath.pow(tA, subB));
122
- }
123
- throw new Error("Inputs are not tensors");
93
+ [tA, tB] = TensorMath.padShape(tA, tB);
94
+ const outLen = Math.max(tA.length, tB.length);
95
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
96
+ throw new Error("Inputs are incompatible tensors");
97
+ }
98
+ const result = [];
99
+ for (let i = 0; i < outLen; i++) {
100
+ const subA = tA[tA.length === 1 ? 0 : i];
101
+ const subB = tB[tB.length === 1 ? 0 : i];
102
+ result.push(TensorMath.pow(subA, subB));
103
+ }
104
+ return result;
124
105
  }
125
106
  static div(tA, tB) {
126
107
  if (typeof tA === "number" && typeof tB === "number") {
127
108
  return tA / tB;
128
109
  }
129
- else if (Array.isArray(tA) && Array.isArray(tB)) {
130
- const outLen = Math.max(tA.length, tB.length);
131
- if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
132
- throw new Error("Inputs are incompatible tensors");
133
- }
134
- const result = [];
135
- for (let i = 0; i < outLen; i++) {
136
- const subA = tA[tA.length === 1 ? 0 : i];
137
- const subB = tB[tB.length === 1 ? 0 : i];
138
- result.push(TensorMath.div(subA, subB));
139
- }
140
- return result;
141
- }
142
- else if (Array.isArray(tA) && typeof tB === "number") {
143
- return tA.map(subA => TensorMath.div(subA, tB));
144
- }
145
- else if (typeof tA === "number" && Array.isArray(tB)) {
146
- return tB.map(subB => TensorMath.div(tA, subB));
147
- }
148
- throw new Error("Inputs are not tensors");
110
+ [tA, tB] = TensorMath.padShape(tA, tB);
111
+ const outLen = Math.max(tA.length, tB.length);
112
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
113
+ throw new Error("Inputs are incompatible tensors");
114
+ }
115
+ const result = [];
116
+ for (let i = 0; i < outLen; i++) {
117
+ const subA = tA[tA.length === 1 ? 0 : i];
118
+ const subB = tB[tB.length === 1 ? 0 : i];
119
+ result.push(TensorMath.div(subA, subB));
120
+ }
121
+ return result;
149
122
  }
150
123
  static gt(tA, tB) {
151
124
  if (typeof tA === "number" && typeof tB === "number") {
152
125
  return tA > tB ? 1 : 0;
153
126
  }
154
- else if (Array.isArray(tA) && Array.isArray(tB)) {
155
- const outLen = Math.max(tA.length, tB.length);
156
- if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
157
- throw new Error("Inputs are incompatible tensors");
158
- }
159
- const result = [];
160
- for (let i = 0; i < outLen; i++) {
161
- const subA = tA[tA.length === 1 ? 0 : i];
162
- const subB = tB[tB.length === 1 ? 0 : i];
163
- result.push(TensorMath.gt(subA, subB));
164
- }
165
- return result;
166
- }
167
- else if (Array.isArray(tA) && typeof tB === "number") {
168
- return tA.map(subA => TensorMath.gt(subA, tB));
169
- }
170
- else if (typeof tA === "number" && Array.isArray(tB)) {
171
- return tB.map(subB => TensorMath.gt(tA, subB));
172
- }
173
- throw new Error("Inputs are not tensors");
127
+ [tA, tB] = TensorMath.padShape(tA, tB);
128
+ const outLen = Math.max(tA.length, tB.length);
129
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
130
+ throw new Error("Inputs are incompatible tensors");
131
+ }
132
+ const result = [];
133
+ for (let i = 0; i < outLen; i++) {
134
+ const subA = tA[tA.length === 1 ? 0 : i];
135
+ const subB = tB[tB.length === 1 ? 0 : i];
136
+ result.push(TensorMath.gt(subA, subB));
137
+ }
138
+ return result;
174
139
  }
175
140
  static lt(tA, tB) {
176
141
  if (typeof tA === "number" && typeof tB === "number") {
177
142
  return tA < tB ? 1 : 0;
178
143
  }
179
- else if (Array.isArray(tA) && Array.isArray(tB)) {
180
- const outLen = Math.max(tA.length, tB.length);
181
- if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
182
- throw new Error("Inputs are incompatible tensors");
183
- }
184
- const result = [];
185
- for (let i = 0; i < outLen; i++) {
186
- const subA = tA[tA.length === 1 ? 0 : i];
187
- const subB = tB[tB.length === 1 ? 0 : i];
188
- result.push(TensorMath.lt(subA, subB));
189
- }
190
- return result;
191
- }
192
- else if (Array.isArray(tA) && typeof tB === "number") {
193
- return tA.map(subA => TensorMath.lt(subA, tB));
194
- }
195
- else if (typeof tA === "number" && Array.isArray(tB)) {
196
- return tB.map(subB => TensorMath.lt(tA, subB));
197
- }
198
- throw new Error("Inputs are not tensors");
144
+ [tA, tB] = TensorMath.padShape(tA, tB);
145
+ const outLen = Math.max(tA.length, tB.length);
146
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
147
+ throw new Error("Inputs are incompatible tensors");
148
+ }
149
+ const result = [];
150
+ for (let i = 0; i < outLen; i++) {
151
+ const subA = tA[tA.length === 1 ? 0 : i];
152
+ const subB = tB[tB.length === 1 ? 0 : i];
153
+ result.push(TensorMath.lt(subA, subB));
154
+ }
155
+ return result;
199
156
  }
200
157
  static ge(tA, tB) {
201
158
  if (typeof tA === "number" && typeof tB === "number") {
202
159
  return tA >= tB ? 1 : 0;
203
160
  }
204
- else if (Array.isArray(tA) && Array.isArray(tB)) {
205
- const outLen = Math.max(tA.length, tB.length);
206
- if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
207
- throw new Error("Inputs are incompatible tensors");
208
- }
209
- const result = [];
210
- for (let i = 0; i < outLen; i++) {
211
- const subA = tA[tA.length === 1 ? 0 : i];
212
- const subB = tB[tB.length === 1 ? 0 : i];
213
- result.push(TensorMath.ge(subA, subB));
214
- }
215
- return result;
216
- }
217
- else if (Array.isArray(tA) && typeof tB === "number") {
218
- return tA.map(subA => TensorMath.ge(subA, tB));
219
- }
220
- else if (typeof tA === "number" && Array.isArray(tB)) {
221
- return tB.map(subB => TensorMath.ge(tA, subB));
222
- }
223
- throw new Error("Inputs are not tensors");
161
+ [tA, tB] = TensorMath.padShape(tA, tB);
162
+ const outLen = Math.max(tA.length, tB.length);
163
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
164
+ throw new Error("Inputs are incompatible tensors");
165
+ }
166
+ const result = [];
167
+ for (let i = 0; i < outLen; i++) {
168
+ const subA = tA[tA.length === 1 ? 0 : i];
169
+ const subB = tB[tB.length === 1 ? 0 : i];
170
+ result.push(TensorMath.ge(subA, subB));
171
+ }
172
+ return result;
224
173
  }
225
174
  static le(tA, tB) {
226
175
  if (typeof tA === "number" && typeof tB === "number") {
227
176
  return tA <= tB ? 1 : 0;
228
177
  }
229
- else if (Array.isArray(tA) && Array.isArray(tB)) {
230
- const outLen = Math.max(tA.length, tB.length);
231
- if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
232
- throw new Error("Inputs are incompatible tensors");
233
- }
234
- const result = [];
235
- for (let i = 0; i < outLen; i++) {
236
- const subA = tA[tA.length === 1 ? 0 : i];
237
- const subB = tB[tB.length === 1 ? 0 : i];
238
- result.push(TensorMath.le(subA, subB));
239
- }
240
- return result;
241
- }
242
- else if (Array.isArray(tA) && typeof tB === "number") {
243
- return tA.map(subA => TensorMath.le(subA, tB));
244
- }
245
- else if (typeof tA === "number" && Array.isArray(tB)) {
246
- return tB.map(subB => TensorMath.le(tA, subB));
247
- }
248
- throw new Error("Inputs are not tensors");
178
+ [tA, tB] = TensorMath.padShape(tA, tB);
179
+ const outLen = Math.max(tA.length, tB.length);
180
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
181
+ throw new Error("Inputs are incompatible tensors");
182
+ }
183
+ const result = [];
184
+ for (let i = 0; i < outLen; i++) {
185
+ const subA = tA[tA.length === 1 ? 0 : i];
186
+ const subB = tB[tB.length === 1 ? 0 : i];
187
+ result.push(TensorMath.le(subA, subB));
188
+ }
189
+ return result;
190
+ }
191
+ static eq(tA, tB) {
192
+ if (typeof tA === "number" && typeof tB === "number") {
193
+ return tA === tB ? 1 : 0;
194
+ }
195
+ [tA, tB] = TensorMath.padShape(tA, tB);
196
+ const outLen = Math.max(tA.length, tB.length);
197
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
198
+ throw new Error("Inputs are incompatible tensors");
199
+ }
200
+ const result = [];
201
+ for (let i = 0; i < outLen; i++) {
202
+ const subA = tA[tA.length === 1 ? 0 : i];
203
+ const subB = tB[tB.length === 1 ? 0 : i];
204
+ result.push(TensorMath.eq(subA, subB));
205
+ }
206
+ return result;
249
207
  }
250
208
  static neg(tA) {
251
209
  if (typeof tA === "number") {
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.1.1",
3
+ "version": "0.1.3",
4
4
  "description": "A cute autograd engine for Javascript",
5
5
  "main": "index.js",
6
6
  "scripts": {