catniff 0.1.0 → 0.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -1,6 +1,6 @@
1
- ## Catniff
1
+ # Catniff
2
2
 
3
- Catniff is a small, experimental autograd engine inspired by [micrograd](https://github.com/karpathy/micrograd). The name is a play on "catnip" and "differentiation".
3
+ Catniff is a small and experimental tensor library and autograd engine inspired by [micrograd](https://github.com/karpathy/micrograd). The name is a play on "catnip" and "differentiation".
4
4
 
5
5
  ## Setup
6
6
 
@@ -16,20 +16,59 @@ Here is a little demo of a quadratic function:
16
16
  const { Node } = require("catniff");
17
17
 
18
18
  const x = new Node(2);
19
- const L = x.pow(2).add(x)
19
+ const L = x.pow(2).add(x); // x^2 + x
20
20
 
21
21
  L.backward();
22
- console.log(x.grad, L.grad);
22
+ console.log(x.grad); // 5
23
23
  ```
24
24
 
25
- All available APIs are in `./src/core.ts`.
25
+ ## Tensors
26
+
27
+ Tensors in Catniff are either numbers (scalars/0-D tensors) or multidimensional number arrays (n-D tensors).
28
+
29
+ There is a built-in `TensorMath` class to help with Tensor arithmetic, for example:
30
+ ```js
31
+ const { TensorMath } = require("catniff");
32
+
33
+ const A = [ 1, 2, 3 ];
34
+ const B = 3;
35
+ console.log(TensorMath.add(A, B)); // [ 4, 5, 6 ]
36
+ ```
37
+
38
+ All available APIs are in `./src/tensor.ts`.
39
+
40
+ ## Autograd
41
+
42
+ To compute the gradient of our mathematical expression, we use the `Node` class to dynamically build our DAG:
43
+ ```js
44
+ const { Node } = require("catniff");
45
+
46
+ const X = new Node([
47
+ [ 0.5, -1.0 ],
48
+ [ 2.0, 0.0 ]
49
+ ]);
50
+
51
+ const Y = new Node([
52
+ [ 1.0, -2.0 ],
53
+ [ 0.5, 1.5 ]
54
+ ]);
55
+
56
+ const D = X.sub(Y);
57
+ const E = D.exp();
58
+ const F = E.add(1);
59
+ const G = F.log();
60
+
61
+ G.backward();
62
+
63
+ console.log(X.grad, Y.grad);
64
+ ```
65
+
66
+ All available APIs are in `./src/autograd.ts`.
26
67
 
27
68
  ## Todos
28
69
 
29
70
  I'm mostly just learning and playing with this currently, so there are no concrete plans yet, but here is what I currently have in mind:
30
71
 
31
- * A built-in Tensor maths lib.
32
- * Support for Tensors in the autograd engine.
33
72
  * GPU acceleration.
34
73
  * Some general neural net APIs.
35
74
 
@@ -1,24 +1,28 @@
1
+ import { Tensor } from "./tensor";
1
2
  export declare enum OP {
2
3
  NONE = 0,
3
4
  ADD = 1,
4
- MUL = 2,
5
- POW = 3,
6
- DIV = 4,
7
- NEG = 5,
8
- EXP = 6,
9
- LOG = 7,
10
- RELU = 8,
11
- SIGMOID = 9,
12
- TANH = 10
5
+ SUB = 2,
6
+ MUL = 3,
7
+ POW = 4,
8
+ DIV = 5,
9
+ NEG = 6,
10
+ EXP = 7,
11
+ LOG = 8,
12
+ RELU = 9,
13
+ SIGMOID = 10,
14
+ TANH = 11
13
15
  }
14
16
  export declare class Node {
15
- value: number;
16
- grad: number;
17
+ value: Tensor;
18
+ shape: number[];
19
+ grad: Tensor;
17
20
  children: Node[];
18
- feedBackward: Function;
19
21
  op: OP;
20
- constructor(value: number, children?: Node[], op?: OP);
22
+ feedBackward: Function;
23
+ constructor(value: Tensor, children?: Node[], op?: OP);
21
24
  add(other: Node | number): Node;
25
+ sub(other: Node | number): Node;
22
26
  mul(other: Node | number): Node;
23
27
  pow(other: Node | number): Node;
24
28
  div(other: Node | number): Node;
@@ -0,0 +1,171 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.Node = exports.OP = void 0;
4
+ const tensor_1 = require("./tensor");
5
+ const { add, sub, mul, pow, div, neg, exp, log, relu, sigmoid, tanh, ge } = tensor_1.TensorMath;
6
+ var OP;
7
+ (function (OP) {
8
+ OP[OP["NONE"] = 0] = "NONE";
9
+ OP[OP["ADD"] = 1] = "ADD";
10
+ OP[OP["SUB"] = 2] = "SUB";
11
+ OP[OP["MUL"] = 3] = "MUL";
12
+ OP[OP["POW"] = 4] = "POW";
13
+ OP[OP["DIV"] = 5] = "DIV";
14
+ OP[OP["NEG"] = 6] = "NEG";
15
+ OP[OP["EXP"] = 7] = "EXP";
16
+ OP[OP["LOG"] = 8] = "LOG";
17
+ OP[OP["RELU"] = 9] = "RELU";
18
+ OP[OP["SIGMOID"] = 10] = "SIGMOID";
19
+ OP[OP["TANH"] = 11] = "TANH";
20
+ })(OP || (exports.OP = OP = {}));
21
+ class Node {
22
+ value;
23
+ shape;
24
+ grad;
25
+ children;
26
+ op;
27
+ feedBackward;
28
+ constructor(value, children = [], op = OP.NONE) {
29
+ this.value = value;
30
+ this.shape = tensor_1.TensorMath.getShape(value);
31
+ this.grad = tensor_1.TensorMath.create(0, this.shape);
32
+ this.children = children;
33
+ this.op = op;
34
+ this.feedBackward = () => { };
35
+ }
36
+ add(other) {
37
+ other = this.forceNode(other);
38
+ const out = new Node(add(this.value, other.value), [this, other], OP.ADD);
39
+ out.feedBackward = () => {
40
+ // x + y d/dx = 1, note that we apply the chain rule continuously so out.grad is multiplied into our derivative
41
+ this.grad = add(this.grad, out.grad);
42
+ // x + y d/dy = 1
43
+ other.grad = add(other.grad, out.grad);
44
+ };
45
+ return out;
46
+ }
47
+ sub(other) {
48
+ other = this.forceNode(other);
49
+ const out = new Node(sub(this.value, other.value), [this, other], OP.SUB);
50
+ out.feedBackward = () => {
51
+ // x - y d/dx = 1
52
+ this.grad = add(this.grad, out.grad);
53
+ // x - y d/dy = -1
54
+ other.grad = add(other.grad, neg(out.grad));
55
+ };
56
+ return out;
57
+ }
58
+ mul(other) {
59
+ other = this.forceNode(other);
60
+ const out = new Node(mul(this.value, other.value), [this, other], OP.MUL);
61
+ out.feedBackward = () => {
62
+ // x * y d/dx = y
63
+ this.grad = add(this.grad, mul(out.grad, other.value));
64
+ // x + y d/dy = x
65
+ other.grad = add(other.grad, mul(out.grad, this.value));
66
+ };
67
+ return out;
68
+ }
69
+ pow(other) {
70
+ if (other instanceof Node) {
71
+ const out = new Node(pow(this.value, other.value), [this, other], OP.POW);
72
+ out.feedBackward = () => {
73
+ // x^a d/dx = a*x^(a-1)
74
+ this.grad = add(this.grad, mul(out.grad, mul(other.value, pow(this.value, sub(other.value, 1)))));
75
+ // x^a d/da = x^a*lnx
76
+ other.grad = add(other.grad, mul(out.grad, mul(pow(this.value, other.value), log(this.value))));
77
+ };
78
+ return out;
79
+ }
80
+ const out = new Node(pow(this.value, other), [this], OP.POW);
81
+ out.feedBackward = () => {
82
+ this.grad = add(this.grad, mul(out.grad, mul(other, pow(this.value, sub(other, 1)))));
83
+ };
84
+ return out;
85
+ }
86
+ div(other) {
87
+ other = this.forceNode(other);
88
+ const out = new Node(div(this.value, other.value), [this, other], OP.DIV);
89
+ out.feedBackward = () => {
90
+ // x/y d/dx = 1/y
91
+ this.grad = add(this.grad, div(out.grad, other.value));
92
+ // x/y d/dy = -x/y^2
93
+ other.grad = add(other.grad, mul(out.grad, div(neg(this.value), pow(other.value, 2))));
94
+ };
95
+ return out;
96
+ }
97
+ neg() {
98
+ const out = new Node(neg(this.value), [this], OP.NEG);
99
+ out.feedBackward = () => {
100
+ // -x d/dx = -1
101
+ this.grad = add(this.grad, neg(out.grad));
102
+ };
103
+ return out;
104
+ }
105
+ exp() {
106
+ const expResult = exp(this.value);
107
+ const out = new Node(expResult, [this], OP.EXP);
108
+ out.feedBackward = () => {
109
+ // e^x d/dx = e^x
110
+ this.grad = add(this.grad, mul(out.grad, expResult));
111
+ };
112
+ return out;
113
+ }
114
+ log() {
115
+ const out = new Node(log(this.value), [this], OP.LOG);
116
+ out.feedBackward = () => {
117
+ // lnx d/dx = 1/x
118
+ this.grad = add(this.grad, div(out.grad, this.value));
119
+ };
120
+ return out;
121
+ }
122
+ relu() {
123
+ const out = new Node(relu(this.value), [this], OP.RELU);
124
+ out.feedBackward = () => {
125
+ this.grad = add(this.grad, mul(out.grad, ge(this.value, 0)));
126
+ };
127
+ return out;
128
+ }
129
+ sigmoid() {
130
+ const sigmoidResult = sigmoid(this.value);
131
+ const out = new Node(sigmoidResult, [this], OP.SIGMOID);
132
+ out.feedBackward = () => {
133
+ this.grad = add(this.grad, mul(mul(out.grad, sigmoidResult), sub(1, sigmoidResult)));
134
+ };
135
+ return out;
136
+ }
137
+ tanh() {
138
+ const tanhResult = tanh(this.value);
139
+ const out = new Node(tanhResult, [this], OP.TANH);
140
+ out.feedBackward = () => {
141
+ this.grad = add(this.grad, mul(out.grad, sub(1, mul(tanhResult, tanhResult))));
142
+ };
143
+ return out;
144
+ }
145
+ backward() {
146
+ // Build topological order
147
+ const topo = [];
148
+ const visited = new Set();
149
+ function build(node) {
150
+ if (!visited.has(node)) {
151
+ visited.add(node);
152
+ node.grad = tensor_1.TensorMath.create(0, node.shape);
153
+ for (let child of node.children)
154
+ build(child);
155
+ topo.push(node);
156
+ }
157
+ }
158
+ build(this);
159
+ // Feed backward to calculate gradient
160
+ this.grad = tensor_1.TensorMath.create(1, this.shape); // Derivative of itself with respect to itself
161
+ for (let index = topo.length - 1; index > -1; index--) {
162
+ topo[index].feedBackward();
163
+ }
164
+ }
165
+ forceNode(value) {
166
+ if (value instanceof Node)
167
+ return value;
168
+ return new Node(value);
169
+ }
170
+ }
171
+ exports.Node = Node;
@@ -0,0 +1,21 @@
1
+ export type Tensor = number | Tensor[];
2
+ export declare class TensorMath {
3
+ static create(num: number, shape: number[]): Tensor;
4
+ static getShape(tA: Tensor): number[];
5
+ static add(tA: Tensor, tB: Tensor): Tensor;
6
+ static sub(tA: Tensor, tB: Tensor): Tensor;
7
+ static mul(tA: Tensor, tB: Tensor): Tensor;
8
+ static pow(tA: Tensor, tB: Tensor): Tensor;
9
+ static div(tA: Tensor, tB: Tensor): Tensor;
10
+ static gt(tA: Tensor, tB: Tensor): Tensor;
11
+ static lt(tA: Tensor, tB: Tensor): Tensor;
12
+ static ge(tA: Tensor, tB: Tensor): Tensor;
13
+ static le(tA: Tensor, tB: Tensor): Tensor;
14
+ static eq(tA: Tensor, tB: Tensor): Tensor;
15
+ static neg(tA: Tensor): Tensor;
16
+ static exp(tA: Tensor): Tensor;
17
+ static log(tA: Tensor): Tensor;
18
+ static relu(tA: Tensor): Tensor;
19
+ static sigmoid(tA: Tensor): Tensor;
20
+ static tanh(tA: Tensor): Tensor;
21
+ }
package/dist/tensor.js ADDED
@@ -0,0 +1,324 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.TensorMath = void 0;
4
+ class TensorMath {
5
+ static create(num, shape) {
6
+ if (shape.length === 0) {
7
+ return num;
8
+ }
9
+ const [dim, ...rest] = shape;
10
+ const out = [];
11
+ for (let i = 0; i < dim; i++) {
12
+ out.push(TensorMath.create(num, rest));
13
+ }
14
+ return out;
15
+ }
16
+ static getShape(tA) {
17
+ const shape = [];
18
+ let subA = tA;
19
+ while (Array.isArray(subA)) {
20
+ shape.push(subA.length);
21
+ subA = subA[0];
22
+ }
23
+ return shape;
24
+ }
25
+ static add(tA, tB) {
26
+ if (typeof tA === "number" && typeof tB === "number") {
27
+ return tA + tB;
28
+ }
29
+ else if (Array.isArray(tA) && Array.isArray(tB)) {
30
+ const outLen = Math.max(tA.length, tB.length);
31
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
32
+ throw new Error("Inputs are incompatible tensors");
33
+ }
34
+ const result = [];
35
+ for (let i = 0; i < outLen; i++) {
36
+ const subA = tA[tA.length === 1 ? 0 : i];
37
+ const subB = tB[tB.length === 1 ? 0 : i];
38
+ result.push(TensorMath.add(subA, subB));
39
+ }
40
+ return result;
41
+ }
42
+ else if (Array.isArray(tA) && typeof tB === "number") {
43
+ return tA.map(subA => TensorMath.add(subA, tB));
44
+ }
45
+ else if (typeof tA === "number" && Array.isArray(tB)) {
46
+ return tB.map(subB => TensorMath.add(tA, subB));
47
+ }
48
+ throw new Error("Inputs are not tensors");
49
+ }
50
+ static sub(tA, tB) {
51
+ if (typeof tA === "number" && typeof tB === "number") {
52
+ return tA - tB;
53
+ }
54
+ else if (Array.isArray(tA) && Array.isArray(tB)) {
55
+ const outLen = Math.max(tA.length, tB.length);
56
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
57
+ throw new Error("Inputs are incompatible tensors");
58
+ }
59
+ const result = [];
60
+ for (let i = 0; i < outLen; i++) {
61
+ const subA = tA[tA.length === 1 ? 0 : i];
62
+ const subB = tB[tB.length === 1 ? 0 : i];
63
+ result.push(TensorMath.sub(subA, subB));
64
+ }
65
+ return result;
66
+ }
67
+ else if (Array.isArray(tA) && typeof tB === "number") {
68
+ return tA.map(subA => TensorMath.sub(subA, tB));
69
+ }
70
+ else if (typeof tA === "number" && Array.isArray(tB)) {
71
+ return tB.map(subB => TensorMath.sub(tA, subB));
72
+ }
73
+ throw new Error("Inputs are not tensors");
74
+ }
75
+ static mul(tA, tB) {
76
+ if (typeof tA === "number" && typeof tB === "number") {
77
+ return tA * tB;
78
+ }
79
+ else if (Array.isArray(tA) && Array.isArray(tB)) {
80
+ const outLen = Math.max(tA.length, tB.length);
81
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
82
+ throw new Error("Inputs are incompatible tensors");
83
+ }
84
+ const result = [];
85
+ for (let i = 0; i < outLen; i++) {
86
+ const subA = tA[tA.length === 1 ? 0 : i];
87
+ const subB = tB[tB.length === 1 ? 0 : i];
88
+ result.push(TensorMath.mul(subA, subB));
89
+ }
90
+ return result;
91
+ }
92
+ else if (Array.isArray(tA) && typeof tB === "number") {
93
+ return tA.map(subA => TensorMath.mul(subA, tB));
94
+ }
95
+ else if (typeof tA === "number" && Array.isArray(tB)) {
96
+ return tB.map(subB => TensorMath.mul(tA, subB));
97
+ }
98
+ throw new Error("Inputs are not tensors");
99
+ }
100
+ static pow(tA, tB) {
101
+ if (typeof tA === "number" && typeof tB === "number") {
102
+ return tA ** tB;
103
+ }
104
+ else if (Array.isArray(tA) && Array.isArray(tB)) {
105
+ const outLen = Math.max(tA.length, tB.length);
106
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
107
+ throw new Error("Inputs are incompatible tensors");
108
+ }
109
+ const result = [];
110
+ for (let i = 0; i < outLen; i++) {
111
+ const subA = tA[tA.length === 1 ? 0 : i];
112
+ const subB = tB[tB.length === 1 ? 0 : i];
113
+ result.push(TensorMath.pow(subA, subB));
114
+ }
115
+ return result;
116
+ }
117
+ else if (Array.isArray(tA) && typeof tB === "number") {
118
+ return tA.map(subA => TensorMath.pow(subA, tB));
119
+ }
120
+ else if (typeof tA === "number" && Array.isArray(tB)) {
121
+ return tB.map(subB => TensorMath.pow(tA, subB));
122
+ }
123
+ throw new Error("Inputs are not tensors");
124
+ }
125
+ static div(tA, tB) {
126
+ if (typeof tA === "number" && typeof tB === "number") {
127
+ return tA / tB;
128
+ }
129
+ else if (Array.isArray(tA) && Array.isArray(tB)) {
130
+ const outLen = Math.max(tA.length, tB.length);
131
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
132
+ throw new Error("Inputs are incompatible tensors");
133
+ }
134
+ const result = [];
135
+ for (let i = 0; i < outLen; i++) {
136
+ const subA = tA[tA.length === 1 ? 0 : i];
137
+ const subB = tB[tB.length === 1 ? 0 : i];
138
+ result.push(TensorMath.div(subA, subB));
139
+ }
140
+ return result;
141
+ }
142
+ else if (Array.isArray(tA) && typeof tB === "number") {
143
+ return tA.map(subA => TensorMath.div(subA, tB));
144
+ }
145
+ else if (typeof tA === "number" && Array.isArray(tB)) {
146
+ return tB.map(subB => TensorMath.div(tA, subB));
147
+ }
148
+ throw new Error("Inputs are not tensors");
149
+ }
150
+ static gt(tA, tB) {
151
+ if (typeof tA === "number" && typeof tB === "number") {
152
+ return tA > tB ? 1 : 0;
153
+ }
154
+ else if (Array.isArray(tA) && Array.isArray(tB)) {
155
+ const outLen = Math.max(tA.length, tB.length);
156
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
157
+ throw new Error("Inputs are incompatible tensors");
158
+ }
159
+ const result = [];
160
+ for (let i = 0; i < outLen; i++) {
161
+ const subA = tA[tA.length === 1 ? 0 : i];
162
+ const subB = tB[tB.length === 1 ? 0 : i];
163
+ result.push(TensorMath.gt(subA, subB));
164
+ }
165
+ return result;
166
+ }
167
+ else if (Array.isArray(tA) && typeof tB === "number") {
168
+ return tA.map(subA => TensorMath.gt(subA, tB));
169
+ }
170
+ else if (typeof tA === "number" && Array.isArray(tB)) {
171
+ return tB.map(subB => TensorMath.gt(tA, subB));
172
+ }
173
+ throw new Error("Inputs are not tensors");
174
+ }
175
+ static lt(tA, tB) {
176
+ if (typeof tA === "number" && typeof tB === "number") {
177
+ return tA < tB ? 1 : 0;
178
+ }
179
+ else if (Array.isArray(tA) && Array.isArray(tB)) {
180
+ const outLen = Math.max(tA.length, tB.length);
181
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
182
+ throw new Error("Inputs are incompatible tensors");
183
+ }
184
+ const result = [];
185
+ for (let i = 0; i < outLen; i++) {
186
+ const subA = tA[tA.length === 1 ? 0 : i];
187
+ const subB = tB[tB.length === 1 ? 0 : i];
188
+ result.push(TensorMath.lt(subA, subB));
189
+ }
190
+ return result;
191
+ }
192
+ else if (Array.isArray(tA) && typeof tB === "number") {
193
+ return tA.map(subA => TensorMath.lt(subA, tB));
194
+ }
195
+ else if (typeof tA === "number" && Array.isArray(tB)) {
196
+ return tB.map(subB => TensorMath.lt(tA, subB));
197
+ }
198
+ throw new Error("Inputs are not tensors");
199
+ }
200
+ static ge(tA, tB) {
201
+ if (typeof tA === "number" && typeof tB === "number") {
202
+ return tA >= tB ? 1 : 0;
203
+ }
204
+ else if (Array.isArray(tA) && Array.isArray(tB)) {
205
+ const outLen = Math.max(tA.length, tB.length);
206
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
207
+ throw new Error("Inputs are incompatible tensors");
208
+ }
209
+ const result = [];
210
+ for (let i = 0; i < outLen; i++) {
211
+ const subA = tA[tA.length === 1 ? 0 : i];
212
+ const subB = tB[tB.length === 1 ? 0 : i];
213
+ result.push(TensorMath.ge(subA, subB));
214
+ }
215
+ return result;
216
+ }
217
+ else if (Array.isArray(tA) && typeof tB === "number") {
218
+ return tA.map(subA => TensorMath.ge(subA, tB));
219
+ }
220
+ else if (typeof tA === "number" && Array.isArray(tB)) {
221
+ return tB.map(subB => TensorMath.ge(tA, subB));
222
+ }
223
+ throw new Error("Inputs are not tensors");
224
+ }
225
+ static le(tA, tB) {
226
+ if (typeof tA === "number" && typeof tB === "number") {
227
+ return tA <= tB ? 1 : 0;
228
+ }
229
+ else if (Array.isArray(tA) && Array.isArray(tB)) {
230
+ const outLen = Math.max(tA.length, tB.length);
231
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
232
+ throw new Error("Inputs are incompatible tensors");
233
+ }
234
+ const result = [];
235
+ for (let i = 0; i < outLen; i++) {
236
+ const subA = tA[tA.length === 1 ? 0 : i];
237
+ const subB = tB[tB.length === 1 ? 0 : i];
238
+ result.push(TensorMath.le(subA, subB));
239
+ }
240
+ return result;
241
+ }
242
+ else if (Array.isArray(tA) && typeof tB === "number") {
243
+ return tA.map(subA => TensorMath.le(subA, tB));
244
+ }
245
+ else if (typeof tA === "number" && Array.isArray(tB)) {
246
+ return tB.map(subB => TensorMath.le(tA, subB));
247
+ }
248
+ throw new Error("Inputs are not tensors");
249
+ }
250
+ static eq(tA, tB) {
251
+ if (typeof tA === "number" && typeof tB === "number") {
252
+ return tA === tB ? 1 : 0;
253
+ }
254
+ else if (Array.isArray(tA) && Array.isArray(tB)) {
255
+ const outLen = Math.max(tA.length, tB.length);
256
+ if (tA.length !== tB.length && tA.length !== 1 && tB.length !== 1) {
257
+ throw new Error("Inputs are incompatible tensors");
258
+ }
259
+ const result = [];
260
+ for (let i = 0; i < outLen; i++) {
261
+ const subA = tA[tA.length === 1 ? 0 : i];
262
+ const subB = tB[tB.length === 1 ? 0 : i];
263
+ result.push(TensorMath.eq(subA, subB));
264
+ }
265
+ return result;
266
+ }
267
+ else if (Array.isArray(tA) && typeof tB === "number") {
268
+ return tA.map(subA => TensorMath.eq(subA, tB));
269
+ }
270
+ else if (typeof tA === "number" && Array.isArray(tB)) {
271
+ return tB.map(subB => TensorMath.eq(tA, subB));
272
+ }
273
+ throw new Error("Inputs are not tensors");
274
+ }
275
+ static neg(tA) {
276
+ if (typeof tA === "number") {
277
+ return -tA;
278
+ }
279
+ else {
280
+ return tA.map(subA => TensorMath.neg(subA));
281
+ }
282
+ }
283
+ static exp(tA) {
284
+ if (typeof tA === "number") {
285
+ return Math.exp(tA);
286
+ }
287
+ else {
288
+ return tA.map(subA => TensorMath.exp(subA));
289
+ }
290
+ }
291
+ static log(tA) {
292
+ if (typeof tA === "number") {
293
+ return Math.log(tA);
294
+ }
295
+ else {
296
+ return tA.map(subA => TensorMath.log(subA));
297
+ }
298
+ }
299
+ static relu(tA) {
300
+ if (typeof tA === "number") {
301
+ return Math.max(tA, 0);
302
+ }
303
+ else {
304
+ return tA.map(subA => TensorMath.relu(subA));
305
+ }
306
+ }
307
+ static sigmoid(tA) {
308
+ if (typeof tA === "number") {
309
+ return 1 / (1 + Math.exp(-tA));
310
+ }
311
+ else {
312
+ return tA.map(subA => TensorMath.sigmoid(subA));
313
+ }
314
+ }
315
+ static tanh(tA) {
316
+ if (typeof tA === "number") {
317
+ return Math.tanh(tA);
318
+ }
319
+ else {
320
+ return tA.map(subA => TensorMath.tanh(subA));
321
+ }
322
+ }
323
+ }
324
+ exports.TensorMath = TensorMath;
package/index.js CHANGED
@@ -1 +1,4 @@
1
- module.exports = require("./dist/core");
1
+ module.exports = {
2
+ ...require("./dist/autograd"),
3
+ ...require("./dist/tensor")
4
+ };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.1.0",
3
+ "version": "0.1.2",
4
4
  "description": "A cute autograd engine for Javascript",
5
5
  "main": "index.js",
6
6
  "scripts": {
@@ -18,7 +18,8 @@
18
18
  "dl",
19
19
  "ai",
20
20
  "maths",
21
- "gradient"
21
+ "gradient",
22
+ "tensors"
22
23
  ],
23
24
  "author": "nguyenphuminh",
24
25
  "license": "GPL-3.0",
@@ -28,5 +29,12 @@
28
29
  "homepage": "https://github.com/nguyenphuminh/catniff#readme",
29
30
  "devDependencies": {
30
31
  "typescript": "^5.8.3"
31
- }
32
+ },
33
+ "files": [
34
+ "dist/",
35
+ "index.d.ts",
36
+ "index.js",
37
+ "LICENSE",
38
+ "README.md"
39
+ ]
32
40
  }
package/dist/core.js DELETED
@@ -1,145 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Node = exports.OP = void 0;
4
- var OP;
5
- (function (OP) {
6
- OP[OP["NONE"] = 0] = "NONE";
7
- OP[OP["ADD"] = 1] = "ADD";
8
- OP[OP["MUL"] = 2] = "MUL";
9
- OP[OP["POW"] = 3] = "POW";
10
- OP[OP["DIV"] = 4] = "DIV";
11
- OP[OP["NEG"] = 5] = "NEG";
12
- OP[OP["EXP"] = 6] = "EXP";
13
- OP[OP["LOG"] = 7] = "LOG";
14
- OP[OP["RELU"] = 8] = "RELU";
15
- OP[OP["SIGMOID"] = 9] = "SIGMOID";
16
- OP[OP["TANH"] = 10] = "TANH";
17
- })(OP || (exports.OP = OP = {}));
18
- class Node {
19
- // Only scalars are supported for now
20
- value;
21
- grad;
22
- children;
23
- feedBackward;
24
- op;
25
- constructor(value, children = [], op = OP.NONE) {
26
- this.value = value;
27
- this.grad = 0;
28
- this.children = children;
29
- this.op = op;
30
- this.feedBackward = () => { };
31
- }
32
- add(other) {
33
- other = this.forceNode(other);
34
- const out = new Node(this.value + other.value, [this, other], OP.ADD);
35
- out.feedBackward = () => {
36
- this.grad += out.grad;
37
- other.grad += out.grad;
38
- };
39
- return out;
40
- }
41
- mul(other) {
42
- other = this.forceNode(other);
43
- const out = new Node(this.value * other.value, [this, other], OP.MUL);
44
- out.feedBackward = () => {
45
- this.grad += out.grad * other.value;
46
- other.grad += out.grad * this.value;
47
- };
48
- return out;
49
- }
50
- pow(other) {
51
- if (other instanceof Node) {
52
- const out = new Node(this.value ** other.value, [this, other], OP.POW);
53
- out.feedBackward = () => {
54
- this.grad += out.grad * other.value * this.value ** (other.value - 1);
55
- other.grad += out.grad * this.value ** other.value * Math.log(this.value);
56
- };
57
- return out;
58
- }
59
- const out = new Node(this.value ** other, [this], OP.POW);
60
- out.feedBackward = () => {
61
- this.grad += out.grad * other * this.value ** (other - 1);
62
- };
63
- return out;
64
- }
65
- div(other) {
66
- other = this.forceNode(other);
67
- const out = new Node(this.value / other.value, [this, other], OP.DIV);
68
- out.feedBackward = () => {
69
- this.grad += out.grad * (1 / other.value);
70
- other.grad += out.grad * -this.value / other.value ** 2;
71
- };
72
- return out;
73
- }
74
- neg() {
75
- const out = new Node(-this.value, [this], OP.NEG);
76
- out.feedBackward = () => {
77
- this.grad += -out.grad;
78
- };
79
- return out;
80
- }
81
- exp() {
82
- const exp = Math.exp(this.value);
83
- const out = new Node(exp, [this], OP.EXP);
84
- out.feedBackward = () => {
85
- this.grad += exp;
86
- };
87
- return out;
88
- }
89
- log() {
90
- const out = new Node(Math.log(this.value), [this], OP.LOG);
91
- out.feedBackward = () => {
92
- this.grad += 1 / this.value;
93
- };
94
- return out;
95
- }
96
- relu() {
97
- const out = new Node(Math.max(this.value, 0), [this], OP.RELU);
98
- out.feedBackward = () => {
99
- this.grad += out.grad * (this.value < 0 ? 0 : 1);
100
- };
101
- return out;
102
- }
103
- sigmoid() {
104
- const sigmoid = 1 / (1 + Math.exp(-this.value));
105
- const out = new Node(sigmoid, [this], OP.SIGMOID);
106
- out.feedBackward = () => {
107
- this.grad += out.grad * sigmoid * (1 - sigmoid);
108
- };
109
- return out;
110
- }
111
- tanh() {
112
- const tanh = Math.tanh(this.value);
113
- const out = new Node(tanh, [this], OP.TANH);
114
- out.feedBackward = () => {
115
- this.grad += out.grad * (1 - tanh * tanh);
116
- };
117
- return out;
118
- }
119
- backward() {
120
- // Build topological order
121
- const topo = [];
122
- const visited = new Set();
123
- function build(node) {
124
- if (!visited.has(node)) {
125
- visited.add(node);
126
- node.grad = 0;
127
- for (let child of node.children)
128
- build(child);
129
- topo.push(node);
130
- }
131
- }
132
- build(this);
133
- // Feed backward to calculate gradient
134
- this.grad = 1; // Derivative of itself with respect to itself
135
- for (let index = topo.length - 1; index > -1; index--) {
136
- topo[index].feedBackward();
137
- }
138
- }
139
- forceNode(value) {
140
- if (value instanceof Node)
141
- return value;
142
- return new Node(value);
143
- }
144
- }
145
- exports.Node = Node;
@@ -1,7 +0,0 @@
1
- const { Node } = require("../index");
2
-
3
- const x = new Node(2);
4
- const L = x.pow(2).add(x)
5
-
6
- L.backward();
7
- console.log(x.grad, L.grad);
package/src/core.ts DELETED
@@ -1,180 +0,0 @@
1
- export enum OP {
2
- NONE,
3
- ADD,
4
- MUL,
5
- POW,
6
- DIV,
7
- NEG,
8
- EXP,
9
- LOG,
10
- RELU,
11
- SIGMOID,
12
- TANH
13
- }
14
-
15
- export class Node {
16
- // Only scalars are supported for now
17
- public value: number;
18
- public grad: number;
19
- public children: Node[];
20
- public feedBackward: Function;
21
- public op: OP;
22
-
23
- constructor(value: number, children: Node[] = [], op: OP = OP.NONE) {
24
- this.value = value;
25
- this.grad = 0;
26
- this.children = children;
27
- this.op = op;
28
- this.feedBackward = () => {};
29
- }
30
-
31
- add(other: Node | number): Node {
32
- other = this.forceNode(other);
33
- const out = new Node(this.value + other.value, [this, other], OP.ADD);
34
-
35
- out.feedBackward = () => {
36
- this.grad += out.grad;
37
- other.grad += out.grad;
38
- }
39
-
40
- return out;
41
- }
42
-
43
- mul(other: Node | number): Node {
44
- other = this.forceNode(other);
45
- const out = new Node(this.value * other.value, [this, other], OP.MUL);
46
-
47
- out.feedBackward = () => {
48
- this.grad += out.grad * other.value;
49
- other.grad += out.grad * this.value;
50
- }
51
-
52
- return out;
53
- }
54
-
55
- pow(other: Node | number): Node {
56
- if (other instanceof Node) {
57
- const out = new Node(this.value ** other.value, [this, other], OP.POW);
58
-
59
- out.feedBackward = () => {
60
- this.grad += out.grad * other.value * this.value ** (other.value - 1);
61
- other.grad += out.grad * this.value ** other.value * Math.log(this.value);
62
- }
63
-
64
- return out;
65
- }
66
-
67
- const out = new Node(this.value ** other, [this], OP.POW);
68
-
69
- out.feedBackward = () => {
70
- this.grad += out.grad * other * this.value ** (other - 1);
71
- }
72
-
73
- return out;
74
- }
75
-
76
- div(other: Node | number): Node {
77
- other = this.forceNode(other);
78
- const out = new Node(this.value / other.value, [this, other], OP.DIV);
79
-
80
- out.feedBackward = () => {
81
- this.grad += out.grad / other.value;
82
- other.grad += out.grad * -this.value / other.value**2;
83
- }
84
-
85
- return out;
86
- }
87
-
88
- neg(): Node {
89
- const out = new Node(-this.value, [this], OP.NEG);
90
-
91
- out.feedBackward = () => {
92
- this.grad += -out.grad;
93
- }
94
-
95
- return out;
96
- }
97
-
98
- exp(): Node {
99
- const exp = Math.exp(this.value);
100
- const out = new Node(exp, [this], OP.EXP);
101
-
102
- out.feedBackward = () => {
103
- this.grad += out.grad * exp;
104
- }
105
-
106
- return out;
107
- }
108
-
109
- log(): Node {
110
- const out = new Node(Math.log(this.value), [this], OP.LOG);
111
-
112
- out.feedBackward = () => {
113
- this.grad += out.grad / this.value;
114
- }
115
-
116
- return out;
117
- }
118
-
119
- relu(): Node {
120
- const out = new Node(Math.max(this.value, 0), [this], OP.RELU);
121
- out.feedBackward = () => {
122
- this.grad += out.grad * (this.value < 0 ? 0 : 1);
123
- };
124
-
125
- return out;
126
- }
127
-
128
- sigmoid(): Node {
129
- const sigmoid = 1 / (1 + Math.exp(-this.value));
130
-
131
- const out = new Node(sigmoid, [this], OP.SIGMOID);
132
- out.feedBackward = () => {
133
- this.grad += out.grad * sigmoid * (1 - sigmoid);
134
- };
135
-
136
- return out;
137
- }
138
-
139
- tanh(): Node {
140
- const tanh = Math.tanh(this.value);
141
-
142
- const out = new Node(tanh, [this], OP.TANH);
143
- out.feedBackward = () => {
144
- this.grad += out.grad * (1 - tanh * tanh);
145
- };
146
-
147
- return out;
148
- }
149
-
150
- backward() {
151
- // Build topological order
152
- const topo: Node[] = [];
153
- const visited: Set<Node> = new Set();
154
-
155
- function build(node: Node) {
156
- if (!visited.has(node)) {
157
- visited.add(node);
158
- node.grad = 0;
159
-
160
- for (let child of node.children) build(child);
161
- topo.push(node);
162
- }
163
- }
164
-
165
- build(this);
166
-
167
- // Feed backward to calculate gradient
168
- this.grad = 1; // Derivative of itself with respect to itself
169
-
170
- for (let index = topo.length - 1; index > -1; index--) {
171
- topo[index].feedBackward();
172
- }
173
- }
174
-
175
- forceNode(value: Node | number): Node {
176
- if (value instanceof Node) return value;
177
-
178
- return new Node(value);
179
- }
180
- }
package/tsconfig.json DELETED
@@ -1,108 +0,0 @@
1
- {
2
- "include": [
3
- "src/**/*"
4
- ],
5
- "compilerOptions": {
6
- /* Visit https://aka.ms/tsconfig to read more about this file */
7
- /* Projects */
8
- // "incremental": true, /* Save .tsbuildinfo files to allow for incremental compilation of projects. */
9
- // "composite": true, /* Enable constraints that allow a TypeScript project to be used with project references. */
10
- // "tsBuildInfoFile": "./.tsbuildinfo", /* Specify the path to .tsbuildinfo incremental compilation file. */
11
- // "disableSourceOfProjectReferenceRedirect": true, /* Disable preferring source files instead of declaration files when referencing composite projects. */
12
- // "disableSolutionSearching": true, /* Opt a project out of multi-project reference checking when editing. */
13
- // "disableReferencedProjectLoad": true, /* Reduce the number of projects loaded automatically by TypeScript. */
14
- /* Language and Environment */
15
- "target": "esnext", /* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */
16
- // "lib": [], /* Specify a set of bundled library declaration files that describe the target runtime environment. */
17
- // "jsx": "preserve", /* Specify what JSX code is generated. */
18
- // "experimentalDecorators": true, /* Enable experimental support for legacy experimental decorators. */
19
- // "emitDecoratorMetadata": true, /* Emit design-type metadata for decorated declarations in source files. */
20
- // "jsxFactory": "", /* Specify the JSX factory function used when targeting React JSX emit, e.g. 'React.createElement' or 'h'. */
21
- // "jsxFragmentFactory": "", /* Specify the JSX Fragment reference used for fragments when targeting React JSX emit e.g. 'React.Fragment' or 'Fragment'. */
22
- // "jsxImportSource": "", /* Specify module specifier used to import the JSX factory functions when using 'jsx: react-jsx*'. */
23
- // "reactNamespace": "", /* Specify the object invoked for 'createElement'. This only applies when targeting 'react' JSX emit. */
24
- // "noLib": true, /* Disable including any library files, including the default lib.d.ts. */
25
- // "useDefineForClassFields": true, /* Emit ECMAScript-standard-compliant class fields. */
26
- // "moduleDetection": "auto", /* Control what method is used to detect module-format JS files. */
27
- /* Modules */
28
- "module": "commonjs", /* Specify what module code is generated. */
29
- "rootDir": "./src/", /* Specify the root folder within your source files. */
30
- "moduleResolution": "node10", /* Specify how TypeScript looks up a file from a given module specifier. */
31
- "baseUrl": "./", /* Specify the base directory to resolve non-relative module names. */
32
- //"paths": {
33
- // "*": [
34
- // "node_modules/*"
35
- // ]
36
- //}, /* Specify a set of entries that re-map imports to additional lookup locations. */
37
- // "rootDirs": [], /* Allow multiple folders to be treated as one when resolving modules. */
38
- // "typeRoots": [], /* Specify multiple folders that act like './node_modules/@types'. */
39
- // "types": [], /* Specify type package names to be included without being referenced in a source file. */
40
- // "allowUmdGlobalAccess": true, /* Allow accessing UMD globals from modules. */
41
- // "moduleSuffixes": [], /* List of file name suffixes to search when resolving a module. */
42
- // "allowImportingTsExtensions": true, /* Allow imports to include TypeScript file extensions. Requires '--moduleResolution bundler' and either '--noEmit' or '--emitDeclarationOnly' to be set. */
43
- // "resolvePackageJsonExports": true, /* Use the package.json 'exports' field when resolving package imports. */
44
- // "resolvePackageJsonImports": true, /* Use the package.json 'imports' field when resolving imports. */
45
- // "customConditions": [], /* Conditions to set in addition to the resolver-specific defaults when resolving imports. */
46
- // "resolveJsonModule": true, /* Enable importing .json files. */
47
- // "allowArbitraryExtensions": true, /* Enable importing files with any extension, provided a declaration file is present. */
48
- // "noResolve": true, /* Disallow 'import's, 'require's or '<reference>'s from expanding the number of files TypeScript should add to a project. */
49
- /* JavaScript Support */
50
- "allowJs": true, /* Allow JavaScript files to be a part of your program. Use the 'checkJS' option to get errors from these files. */
51
- // "checkJs": true, /* Enable error reporting in type-checked JavaScript files. */
52
- // "maxNodeModuleJsDepth": 1, /* Specify the maximum folder depth used for checking JavaScript files from 'node_modules'. Only applicable with 'allowJs'. */
53
- /* Emit */
54
- "declaration": true, /* Generate .d.ts files from TypeScript and JavaScript files in your project. */
55
- // "declarationMap": true, /* Create sourcemaps for d.ts files. */
56
- // "emitDeclarationOnly": true, /* Only output d.ts files and not JavaScript files. */
57
- // "sourceMap": true, /* Create source map files for emitted JavaScript files. */
58
- // "inlineSourceMap": true, /* Include sourcemap files inside the emitted JavaScript. */
59
- // "outFile": "./", /* Specify a file that bundles all outputs into one JavaScript file. If 'declaration' is true, also designates a file that bundles all .d.ts output. */
60
- "outDir": "./dist", /* Specify an output folder for all emitted files. */
61
- // "removeComments": true, /* Disable emitting comments. */
62
- // "noEmit": true, /* Disable emitting files from a compilation. */
63
- // "importHelpers": true, /* Allow importing helper functions from tslib once per project, instead of including them per-file. */
64
- // "importsNotUsedAsValues": "remove", /* Specify emit/checking behavior for imports that are only used for types. */
65
- // "downlevelIteration": true, /* Emit more compliant, but verbose and less performant JavaScript for iteration. */
66
- // "sourceRoot": "", /* Specify the root path for debuggers to find the reference source code. */
67
- // "mapRoot": "", /* Specify the location where debugger should locate map files instead of generated locations. */
68
- // "inlineSources": true, /* Include source code in the sourcemaps inside the emitted JavaScript. */
69
- // "emitBOM": true, /* Emit a UTF-8 Byte Order Mark (BOM) in the beginning of output files. */
70
- // "newLine": "crlf", /* Set the newline character for emitting files. */
71
- // "stripInternal": true, /* Disable emitting declarations that have '@internal' in their JSDoc comments. */
72
- // "noEmitHelpers": true, /* Disable generating custom helper functions like '__extends' in compiled output. */
73
- // "noEmitOnError": true, /* Disable emitting files if any type checking errors are reported. */
74
- // "preserveConstEnums": true, /* Disable erasing 'const enum' declarations in generated code. */
75
- // "declarationDir": "./", /* Specify the output directory for generated declaration files. */
76
- // "preserveValueImports": true, /* Preserve unused imported values in the JavaScript output that would otherwise be removed. */
77
- /* Interop Constraints */
78
- // "isolatedModules": true, /* Ensure that each file can be safely transpiled without relying on other imports. */
79
- // "verbatimModuleSyntax": true, /* Do not transform or elide any imports or exports not marked as type-only, ensuring they are written in the output file's format based on the 'module' setting. */
80
- // "allowSyntheticDefaultImports": true, /* Allow 'import x from y' when a module doesn't have a default export. */
81
- "esModuleInterop": true, /* Emit additional JavaScript to ease support for importing CommonJS modules. This enables 'allowSyntheticDefaultImports' for type compatibility. */
82
- // "preserveSymlinks": true, /* Disable resolving symlinks to their realpath. This correlates to the same flag in node. */
83
- "forceConsistentCasingInFileNames": true, /* Ensure that casing is correct in imports. */
84
- /* Type Checking */
85
- "strict": true, /* Enable all strict type-checking options. */
86
- // "noImplicitAny": true, /* Enable error reporting for expressions and declarations with an implied 'any' type. */
87
- // "strictNullChecks": true, /* When type checking, take into account 'null' and 'undefined'. */
88
- // "strictFunctionTypes": true, /* When assigning functions, check to ensure parameters and the return values are subtype-compatible. */
89
- // "strictBindCallApply": true, /* Check that the arguments for 'bind', 'call', and 'apply' methods match the original function. */
90
- // "strictPropertyInitialization": true, /* Check for class properties that are declared but not set in the constructor. */
91
- // "noImplicitThis": true, /* Enable error reporting when 'this' is given the type 'any'. */
92
- // "useUnknownInCatchVariables": true, /* Default catch clause variables as 'unknown' instead of 'any'. */
93
- // "alwaysStrict": true, /* Ensure 'use strict' is always emitted. */
94
- // "noUnusedLocals": true, /* Enable error reporting when local variables aren't read. */
95
- // "noUnusedParameters": true, /* Raise an error when a function parameter isn't read. */
96
- // "exactOptionalPropertyTypes": true, /* Interpret optional property types as written, rather than adding 'undefined'. */
97
- // "noImplicitReturns": true, /* Enable error reporting for codepaths that do not explicitly return in a function. */
98
- // "noFallthroughCasesInSwitch": true, /* Enable error reporting for fallthrough cases in switch statements. */
99
- // "noUncheckedIndexedAccess": true, /* Add 'undefined' to a type when accessed using an index. */
100
- // "noImplicitOverride": true, /* Ensure overriding members in derived classes are marked with an override modifier. */
101
- // "noPropertyAccessFromIndexSignature": true, /* Enforces using indexed accessors for keys declared using an indexed type. */
102
- // "allowUnusedLabels": true, /* Disable error reporting for unused labels. */
103
- // "allowUnreachableCode": true, /* Disable error reporting for unreachable code. */
104
- /* Completeness */
105
- // "skipDefaultLibCheck": true, /* Skip type checking .d.ts files that are included with TypeScript. */
106
- "skipLibCheck": true /* Skip type checking all .d.ts files. */
107
- }
108
- }