catniff 0.5.1 → 0.5.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -76,6 +76,32 @@ optim.step();
76
76
  console.log("Updated weight:", w.data); // Should move toward 3.0
77
77
  ```
78
78
 
79
+ ## Neural networks
80
+
81
+ There are built-in neural network constructs in Catniff as well:
82
+ ```js
83
+ const { Tensor, nn } = require("catniff");
84
+
85
+ // Linear layer with input size of 20 and output size of 10
86
+ const linear = new nn.Linear(20, 10);
87
+ // RNN cell with input size of 32 and hidden size of 64
88
+ const rnnCell = new nn.RNNCell(32, 64);
89
+ // Same thing but using GRU
90
+ const gruCell = new nn.GRUCell(32, 64);
91
+ // Same thing but using LSTM
92
+ const lstmCell = new nn.LSTMCell(32, 64);
93
+
94
+ // Forward passes
95
+ const a = Tensor.randn([20]);
96
+ const b = Tensor.randn([32]);
97
+ const c = Tensor.randn([64]);
98
+
99
+ linear.forward(a);
100
+ rnnCell.forward(b, c);
101
+ gruCell.forward(b, c);
102
+ lstmCell.forward(b, c, c);
103
+ ```
104
+
79
105
  And it can still do much more; check out the docs and examples below for more information.
80
106
 
81
107
  ## Documentation
package/dist/core.js CHANGED
@@ -1534,7 +1534,6 @@ class Tensor {
1534
1534
  return new Tensor(typeof this.value === "number" ? this.value : [...this.value], {
1535
1535
  shape: this.shape,
1536
1536
  strides: this.strides,
1537
- device: this.device,
1538
1537
  requiresGrad: this.requiresGrad
1539
1538
  });
1540
1539
  }
package/dist/nn.d.ts CHANGED
@@ -13,11 +13,53 @@ declare class RNNCell {
13
13
  constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
14
14
  forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
15
15
  }
16
+ declare class GRUCell {
17
+ weightIR: Tensor;
18
+ weightIZ: Tensor;
19
+ weightIN: Tensor;
20
+ weightHR: Tensor;
21
+ weightHZ: Tensor;
22
+ weightHN: Tensor;
23
+ biasIR?: Tensor;
24
+ biasIZ?: Tensor;
25
+ biasIN?: Tensor;
26
+ biasHR?: Tensor;
27
+ biasHZ?: Tensor;
28
+ biasHN?: Tensor;
29
+ constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
30
+ forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
31
+ }
32
+ export declare class LSTMCell {
33
+ weightII: Tensor;
34
+ weightIF: Tensor;
35
+ weightIG: Tensor;
36
+ weightIO: Tensor;
37
+ weightHI: Tensor;
38
+ weightHF: Tensor;
39
+ weightHG: Tensor;
40
+ weightHO: Tensor;
41
+ biasII?: Tensor;
42
+ biasIF?: Tensor;
43
+ biasIG?: Tensor;
44
+ biasIO?: Tensor;
45
+ biasHI?: Tensor;
46
+ biasHF?: Tensor;
47
+ biasHG?: Tensor;
48
+ biasHO?: Tensor;
49
+ constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
50
+ forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue, cell: Tensor | TensorValue): [Tensor, Tensor];
51
+ }
52
+ interface StateDict {
53
+ [key: string]: any;
54
+ }
16
55
  export declare const nn: {
17
56
  Linear: typeof Linear;
18
57
  RNNCell: typeof RNNCell;
58
+ GRUCell: typeof GRUCell;
19
59
  state: {
20
60
  getParameters(model: any, visited?: WeakSet<object>): Tensor[];
61
+ getStateDict(model: any, prefix?: string, visited?: WeakSet<object>): StateDict;
62
+ loadStateDict(model: any, stateDict: StateDict, prefix?: string, visited?: WeakSet<object>): void;
21
63
  };
22
64
  };
23
65
  export {};
package/dist/nn.js CHANGED
@@ -1,7 +1,14 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.nn = void 0;
3
+ exports.nn = exports.LSTMCell = void 0;
4
4
  const core_1 = require("./core");
5
+ function linearTransform(input, weight, bias) {
6
+ let output = input.matmul(weight.t());
7
+ if (bias) {
8
+ output = output.add(bias);
9
+ }
10
+ return output;
11
+ }
5
12
  class Linear {
6
13
  weight;
7
14
  bias;
@@ -14,12 +21,18 @@ class Linear {
14
21
  }
15
22
  forward(input) {
16
23
  input = core_1.Tensor.forceTensor(input);
17
- let output = input.matmul(this.weight.t());
18
- if (this.bias) {
19
- output = output.add(this.bias);
20
- }
21
- return output;
24
+ return linearTransform(input, this.weight, this.bias);
25
+ }
26
+ }
27
+ function rnnTransform(input, hidden, inputWeight, hiddenWeight, inputBias, hiddenBias) {
28
+ let output = input.matmul(inputWeight.t()).add(hidden.matmul(hiddenWeight.t()));
29
+ if (inputBias) {
30
+ output = output.add(inputBias);
31
+ }
32
+ if (hiddenBias) {
33
+ output = output.add(hiddenBias);
22
34
  }
35
+ return output;
23
36
  }
24
37
  class RNNCell {
25
38
  weightIH;
@@ -38,19 +51,104 @@ class RNNCell {
38
51
  forward(input, hidden) {
39
52
  input = core_1.Tensor.forceTensor(input);
40
53
  hidden = core_1.Tensor.forceTensor(hidden);
41
- let output = input.matmul(this.weightIH.t())
42
- .add(hidden.matmul(this.weightHH.t()));
43
- if (this.biasIH && this.biasHH) {
44
- output = output.add(this.biasIH).add(this.biasHH);
54
+ return rnnTransform(input, hidden, this.weightIH, this.weightHH, this.biasIH, this.biasHH).tanh();
55
+ }
56
+ }
57
+ class GRUCell {
58
+ weightIR;
59
+ weightIZ;
60
+ weightIN;
61
+ weightHR;
62
+ weightHZ;
63
+ weightHN;
64
+ biasIR;
65
+ biasIZ;
66
+ biasIN;
67
+ biasHR;
68
+ biasHZ;
69
+ biasHN;
70
+ constructor(inputSize, hiddenSize, bias = true, device) {
71
+ const bound = 1 / Math.sqrt(hiddenSize);
72
+ this.weightIR = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
73
+ this.weightIZ = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
74
+ this.weightIN = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
75
+ this.weightHR = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
76
+ this.weightHZ = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
77
+ this.weightHN = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
78
+ if (bias) {
79
+ this.biasIR = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
80
+ this.biasIZ = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
81
+ this.biasIN = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
82
+ this.biasHR = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
83
+ this.biasHZ = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
84
+ this.biasHN = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
85
+ }
86
+ }
87
+ forward(input, hidden) {
88
+ input = core_1.Tensor.forceTensor(input);
89
+ hidden = core_1.Tensor.forceTensor(hidden);
90
+ const r = rnnTransform(input, hidden, this.weightIR, this.weightHR, this.biasIR, this.biasHR).sigmoid();
91
+ const z = rnnTransform(input, hidden, this.weightIZ, this.weightHZ, this.biasIZ, this.biasHZ).sigmoid();
92
+ const n = linearTransform(input, this.weightIN, this.biasIN).add(r.mul(linearTransform(hidden, this.weightHN, this.biasHN))).tanh();
93
+ return (z.neg().add(1).mul(n).add(z.mul(hidden)));
94
+ }
95
+ }
96
+ class LSTMCell {
97
+ weightII;
98
+ weightIF;
99
+ weightIG;
100
+ weightIO;
101
+ weightHI;
102
+ weightHF;
103
+ weightHG;
104
+ weightHO;
105
+ biasII;
106
+ biasIF;
107
+ biasIG;
108
+ biasIO;
109
+ biasHI;
110
+ biasHF;
111
+ biasHG;
112
+ biasHO;
113
+ constructor(inputSize, hiddenSize, bias = true, device) {
114
+ const bound = 1 / Math.sqrt(hiddenSize);
115
+ this.weightII = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
116
+ this.weightIF = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
117
+ this.weightIG = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
118
+ this.weightIO = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
119
+ this.weightHI = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
120
+ this.weightHF = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
121
+ this.weightHG = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
122
+ this.weightHO = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
123
+ if (bias) {
124
+ this.biasII = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
125
+ this.biasIF = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
126
+ this.biasIG = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
127
+ this.biasIO = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
128
+ this.biasHI = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
129
+ this.biasHF = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
130
+ this.biasHG = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
131
+ this.biasHO = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
45
132
  }
46
- return output.tanh();
133
+ }
134
+ forward(input, hidden, cell) {
135
+ input = core_1.Tensor.forceTensor(input);
136
+ hidden = core_1.Tensor.forceTensor(hidden);
137
+ cell = core_1.Tensor.forceTensor(cell);
138
+ const i = rnnTransform(input, hidden, this.weightII, this.weightHI, this.biasII, this.biasHI).sigmoid();
139
+ const f = rnnTransform(input, hidden, this.weightIF, this.weightHF, this.biasIF, this.biasHF).sigmoid();
140
+ const g = rnnTransform(input, hidden, this.weightIG, this.weightHG, this.biasIG, this.biasHG).tanh();
141
+ const o = rnnTransform(input, hidden, this.weightIO, this.weightHO, this.biasIO, this.biasHO).sigmoid();
142
+ const c = f.mul(cell).add(i.mul(g));
143
+ const h = o.mul(c.tanh());
144
+ return [h, c];
47
145
  }
48
146
  }
147
+ exports.LSTMCell = LSTMCell;
49
148
  const state = {
50
149
  getParameters(model, visited = new WeakSet()) {
51
- if (visited.has(model)) {
150
+ if (visited.has(model))
52
151
  return [];
53
- }
54
152
  visited.add(model);
55
153
  const parameters = [];
56
154
  for (const key in model) {
@@ -65,10 +163,47 @@ const state = {
65
163
  }
66
164
  }
67
165
  return parameters;
166
+ },
167
+ getStateDict(model, prefix = "", visited = new WeakSet()) {
168
+ if (visited.has(model))
169
+ return {};
170
+ visited.add(model);
171
+ const stateDict = {};
172
+ for (const key in model) {
173
+ if (!model.hasOwnProperty(key))
174
+ continue;
175
+ const value = model[key];
176
+ const fullKey = prefix ? `${prefix}.${key}` : key;
177
+ if (value instanceof core_1.Tensor) {
178
+ stateDict[fullKey] = value.val();
179
+ }
180
+ else if (typeof value === "object" && value !== null) {
181
+ Object.assign(stateDict, this.getStateDict(value, fullKey, visited));
182
+ }
183
+ }
184
+ return stateDict;
185
+ },
186
+ loadStateDict(model, stateDict, prefix = "", visited = new WeakSet()) {
187
+ if (visited.has(model))
188
+ return;
189
+ visited.add(model);
190
+ for (const key in model) {
191
+ if (!model.hasOwnProperty(key))
192
+ continue;
193
+ const value = model[key];
194
+ const fullKey = prefix ? `${prefix}.${key}` : key;
195
+ if (value instanceof core_1.Tensor && stateDict[fullKey]) {
196
+ value.replace(new core_1.Tensor(stateDict[fullKey]));
197
+ }
198
+ else if (typeof value === "object" && value !== null) {
199
+ this.loadStateDict(value, stateDict, fullKey, visited);
200
+ }
201
+ }
68
202
  }
69
203
  };
70
204
  exports.nn = {
71
205
  Linear,
72
206
  RNNCell,
207
+ GRUCell,
73
208
  state
74
209
  };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.5.1",
3
+ "version": "0.5.2",
4
4
  "description": "A small Torch-like deep learning framework for Javascript",
5
5
  "main": "index.js",
6
6
  "scripts": {