catniff 0.5.0 → 0.5.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -76,6 +76,32 @@ optim.step();
76
76
  console.log("Updated weight:", w.data); // Should move toward 3.0
77
77
  ```
78
78
 
79
+ ## Neural networks
80
+
81
+ There are built-in neural network constructs in Catniff as well:
82
+ ```js
83
+ const { Tensor, nn } = require("catniff");
84
+
85
+ // Linear layer with input size of 20 and output size of 10
86
+ const linear = new nn.Linear(20, 10);
87
+ // RNN cell with input size of 32 and hidden size of 64
88
+ const rnnCell = new nn.RNNCell(32, 64);
89
+ // Same thing but using GRU
90
+ const gruCell = new nn.GRUCell(32, 64);
91
+ // Same thing but using LSTM
92
+ const lstmCell = new nn.LSTMCell(32, 64);
93
+
94
+ // Forward passes
95
+ const a = Tensor.randn([20]);
96
+ const b = Tensor.randn([32]);
97
+ const c = Tensor.randn([64]);
98
+
99
+ linear.forward(a);
100
+ rnnCell.forward(b, c);
101
+ gruCell.forward(b, c);
102
+ lstmCell.forward(b, c, c);
103
+ ```
104
+
79
105
  And it can still do much more, check out the docs and examples below for more information.
80
106
 
81
107
  ## Documentation
package/dist/core.js CHANGED
@@ -20,6 +20,13 @@ class Tensor {
20
20
  this.gradFn = options.gradFn || (() => { });
21
21
  this.children = options.children || [];
22
22
  this.device = options.device || "cpu";
23
+ // Move tensor to device
24
+ if (this.device !== "cpu") {
25
+ const backend = Tensor.backends.get(this.device);
26
+ if (backend && backend.transfer) {
27
+ backend.transfer(this);
28
+ }
29
+ }
23
30
  }
24
31
  // Utility to flatten an nD array to be 1D
25
32
  static flatten(tensor) {
@@ -1549,7 +1556,8 @@ class Tensor {
1549
1556
  to(device) {
1550
1557
  const backend = Tensor.backends.get(device);
1551
1558
  if (backend && backend.transfer) {
1552
- return backend.transfer(this);
1559
+ backend.transfer(this);
1560
+ return this;
1553
1561
  }
1554
1562
  throw new Error(`No device found to transfer tensor to or a handler is not implemented for device.`);
1555
1563
  }
package/dist/nn.d.ts CHANGED
@@ -2,13 +2,64 @@ import { Tensor, TensorValue } from "./core";
2
2
  declare class Linear {
3
3
  weight: Tensor;
4
4
  bias?: Tensor;
5
- constructor(inFeatures: number, outFeatures: number, bias?: boolean, customInit?: (shape: number[]) => Tensor);
5
+ constructor(inFeatures: number, outFeatures: number, bias?: boolean, device?: string);
6
6
  forward(input: Tensor | TensorValue): Tensor;
7
7
  }
8
+ declare class RNNCell {
9
+ weightIH: Tensor;
10
+ weightHH: Tensor;
11
+ biasIH?: Tensor;
12
+ biasHH?: Tensor;
13
+ constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
14
+ forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
15
+ }
16
+ declare class GRUCell {
17
+ weightIR: Tensor;
18
+ weightIZ: Tensor;
19
+ weightIN: Tensor;
20
+ weightHR: Tensor;
21
+ weightHZ: Tensor;
22
+ weightHN: Tensor;
23
+ biasIR?: Tensor;
24
+ biasIZ?: Tensor;
25
+ biasIN?: Tensor;
26
+ biasHR?: Tensor;
27
+ biasHZ?: Tensor;
28
+ biasHN?: Tensor;
29
+ constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
30
+ forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
31
+ }
32
+ export declare class LSTMCell {
33
+ weightII: Tensor;
34
+ weightIF: Tensor;
35
+ weightIG: Tensor;
36
+ weightIO: Tensor;
37
+ weightHI: Tensor;
38
+ weightHF: Tensor;
39
+ weightHG: Tensor;
40
+ weightHO: Tensor;
41
+ biasII?: Tensor;
42
+ biasIF?: Tensor;
43
+ biasIG?: Tensor;
44
+ biasIO?: Tensor;
45
+ biasHI?: Tensor;
46
+ biasHF?: Tensor;
47
+ biasHG?: Tensor;
48
+ biasHO?: Tensor;
49
+ constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
50
+ forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue, cell: Tensor | TensorValue): [Tensor, Tensor];
51
+ }
52
+ interface StateDict {
53
+ [key: string]: any;
54
+ }
8
55
  export declare const nn: {
9
56
  Linear: typeof Linear;
57
+ RNNCell: typeof RNNCell;
58
+ GRUCell: typeof GRUCell;
10
59
  state: {
11
60
  getParameters(model: any, visited?: WeakSet<object>): Tensor[];
61
+ getStateDict(model: any, prefix?: string, visited?: WeakSet<object>): StateDict;
62
+ loadStateDict(model: any, stateDict: StateDict, prefix?: string, visited?: WeakSet<object>): void;
12
63
  };
13
64
  };
14
65
  export {};
package/dist/nn.js CHANGED
@@ -1,37 +1,154 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.nn = void 0;
3
+ exports.nn = exports.LSTMCell = void 0;
4
4
  const core_1 = require("./core");
5
+ function linearTransform(input, weight, bias) {
6
+ let output = input.matmul(weight.t());
7
+ if (bias) {
8
+ output = output.add(bias);
9
+ }
10
+ return output;
11
+ }
5
12
  class Linear {
6
13
  weight;
7
14
  bias;
8
- constructor(inFeatures, outFeatures, bias = true, customInit) {
9
- let initFunc = (shape) => {
10
- const bound = 1 / Math.sqrt(inFeatures);
11
- return core_1.Tensor.uniform(shape, -bound, bound, { requiresGrad: true });
12
- };
13
- if (customInit) {
14
- initFunc = customInit;
15
- }
16
- this.weight = initFunc([outFeatures, inFeatures]);
15
+ constructor(inFeatures, outFeatures, bias = true, device) {
16
+ const bound = 1 / Math.sqrt(inFeatures);
17
+ this.weight = core_1.Tensor.uniform([outFeatures, inFeatures], -bound, bound, { requiresGrad: true, device });
17
18
  if (bias) {
18
- this.bias = initFunc([outFeatures]);
19
+ this.bias = core_1.Tensor.uniform([outFeatures], -bound, bound, { requiresGrad: true, device });
19
20
  }
20
21
  }
21
22
  forward(input) {
22
23
  input = core_1.Tensor.forceTensor(input);
23
- let output = input.matmul(this.weight.t());
24
- if (this.bias) {
25
- output = output.add(this.bias);
24
+ return linearTransform(input, this.weight, this.bias);
25
+ }
26
+ }
27
+ function rnnTransform(input, hidden, inputWeight, hiddenWeight, inputBias, hiddenBias) {
28
+ let output = input.matmul(inputWeight.t()).add(hidden.matmul(hiddenWeight.t()));
29
+ if (inputBias) {
30
+ output = output.add(inputBias);
31
+ }
32
+ if (hiddenBias) {
33
+ output = output.add(hiddenBias);
34
+ }
35
+ return output;
36
+ }
37
+ class RNNCell {
38
+ weightIH;
39
+ weightHH;
40
+ biasIH;
41
+ biasHH;
42
+ constructor(inputSize, hiddenSize, bias = true, device) {
43
+ const bound = 1 / Math.sqrt(hiddenSize);
44
+ this.weightIH = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
45
+ this.weightHH = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
46
+ if (bias) {
47
+ this.biasIH = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
48
+ this.biasHH = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
49
+ }
50
+ }
51
+ forward(input, hidden) {
52
+ input = core_1.Tensor.forceTensor(input);
53
+ hidden = core_1.Tensor.forceTensor(hidden);
54
+ return rnnTransform(input, hidden, this.weightIH, this.weightHH, this.biasIH, this.biasHH).tanh();
55
+ }
56
+ }
57
+ class GRUCell {
58
+ weightIR;
59
+ weightIZ;
60
+ weightIN;
61
+ weightHR;
62
+ weightHZ;
63
+ weightHN;
64
+ biasIR;
65
+ biasIZ;
66
+ biasIN;
67
+ biasHR;
68
+ biasHZ;
69
+ biasHN;
70
+ constructor(inputSize, hiddenSize, bias = true, device) {
71
+ const bound = 1 / Math.sqrt(hiddenSize);
72
+ this.weightIR = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
73
+ this.weightIZ = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
74
+ this.weightIN = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
75
+ this.weightHR = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
76
+ this.weightHZ = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
77
+ this.weightHN = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
78
+ if (bias) {
79
+ this.biasIR = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
80
+ this.biasIZ = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
81
+ this.biasIN = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
82
+ this.biasHR = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
83
+ this.biasHZ = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
84
+ this.biasHN = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
26
85
  }
27
- return output;
86
+ }
87
+ forward(input, hidden) {
88
+ input = core_1.Tensor.forceTensor(input);
89
+ hidden = core_1.Tensor.forceTensor(hidden);
90
+ const r = rnnTransform(input, hidden, this.weightIR, this.weightHR, this.biasIR, this.biasHR).sigmoid();
91
+ const z = rnnTransform(input, hidden, this.weightIZ, this.weightHZ, this.biasIZ, this.biasHZ).sigmoid();
92
+ const n = linearTransform(input, this.weightIN, this.biasIN).add(r.mul(linearTransform(hidden, this.weightHN, this.biasHN))).tanh();
93
+ return (z.neg().add(1).mul(n).add(z.mul(hidden)));
28
94
  }
29
95
  }
96
+ class LSTMCell {
97
+ weightII;
98
+ weightIF;
99
+ weightIG;
100
+ weightIO;
101
+ weightHI;
102
+ weightHF;
103
+ weightHG;
104
+ weightHO;
105
+ biasII;
106
+ biasIF;
107
+ biasIG;
108
+ biasIO;
109
+ biasHI;
110
+ biasHF;
111
+ biasHG;
112
+ biasHO;
113
+ constructor(inputSize, hiddenSize, bias = true, device) {
114
+ const bound = 1 / Math.sqrt(hiddenSize);
115
+ this.weightII = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
116
+ this.weightIF = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
117
+ this.weightIG = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
118
+ this.weightIO = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
119
+ this.weightHI = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
120
+ this.weightHF = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
121
+ this.weightHG = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
122
+ this.weightHO = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
123
+ if (bias) {
124
+ this.biasII = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
125
+ this.biasIF = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
126
+ this.biasIG = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
127
+ this.biasIO = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
128
+ this.biasHI = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
129
+ this.biasHF = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
130
+ this.biasHG = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
131
+ this.biasHO = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
132
+ }
133
+ }
134
+ forward(input, hidden, cell) {
135
+ input = core_1.Tensor.forceTensor(input);
136
+ hidden = core_1.Tensor.forceTensor(hidden);
137
+ cell = core_1.Tensor.forceTensor(cell);
138
+ const i = rnnTransform(input, hidden, this.weightII, this.weightHI, this.biasII, this.biasHI).sigmoid();
139
+ const f = rnnTransform(input, hidden, this.weightIF, this.weightHF, this.biasIF, this.biasHF).sigmoid();
140
+ const g = rnnTransform(input, hidden, this.weightIG, this.weightHG, this.biasIG, this.biasHG).tanh();
141
+ const o = rnnTransform(input, hidden, this.weightIO, this.weightHO, this.biasIO, this.biasHO).sigmoid();
142
+ const c = f.mul(cell).add(i.mul(g));
143
+ const h = o.mul(c.tanh());
144
+ return [h, c];
145
+ }
146
+ }
147
+ exports.LSTMCell = LSTMCell;
30
148
  const state = {
31
149
  getParameters(model, visited = new WeakSet()) {
32
- if (visited.has(model)) {
150
+ if (visited.has(model))
33
151
  return [];
34
- }
35
152
  visited.add(model);
36
153
  const parameters = [];
37
154
  for (const key in model) {
@@ -46,9 +163,47 @@ const state = {
46
163
  }
47
164
  }
48
165
  return parameters;
166
+ },
167
+ getStateDict(model, prefix = "", visited = new WeakSet()) {
168
+ if (visited.has(model))
169
+ return {};
170
+ visited.add(model);
171
+ const stateDict = {};
172
+ for (const key in model) {
173
+ if (!model.hasOwnProperty(key))
174
+ continue;
175
+ const value = model[key];
176
+ const fullKey = prefix ? `${prefix}.${key}` : key;
177
+ if (value instanceof core_1.Tensor) {
178
+ stateDict[fullKey] = value.val();
179
+ }
180
+ else if (typeof value === "object" && value !== null) {
181
+ Object.assign(stateDict, this.getStateDict(value, fullKey, visited));
182
+ }
183
+ }
184
+ return stateDict;
185
+ },
186
+ loadStateDict(model, stateDict, prefix = "", visited = new WeakSet()) {
187
+ if (visited.has(model))
188
+ return;
189
+ visited.add(model);
190
+ for (const key in model) {
191
+ if (!model.hasOwnProperty(key))
192
+ continue;
193
+ const value = model[key];
194
+ const fullKey = prefix ? `${prefix}.${key}` : key;
195
+ if (value instanceof core_1.Tensor && stateDict[fullKey]) {
196
+ value.replace(new core_1.Tensor(stateDict[fullKey]));
197
+ }
198
+ else if (typeof value === "object" && value !== null) {
199
+ this.loadStateDict(value, stateDict, fullKey, visited);
200
+ }
201
+ }
49
202
  }
50
203
  };
51
204
  exports.nn = {
52
205
  Linear,
206
+ RNNCell,
207
+ GRUCell,
53
208
  state
54
209
  };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.5.0",
3
+ "version": "0.5.2",
4
4
  "description": "A small Torch-like deep learning framework for Javascript",
5
5
  "main": "index.js",
6
6
  "scripts": {