catniff 0.5.6 → 0.5.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (3)
  1. package/dist/nn.d.ts +11 -1
  2. package/dist/nn.js +50 -4
  3. package/package.json +1 -1
package/dist/nn.d.ts CHANGED
@@ -29,7 +29,7 @@ declare class GRUCell {
29
29
  constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
30
30
  forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
31
31
  }
32
- export declare class LSTMCell {
32
+ declare class LSTMCell {
33
33
  weightII: Tensor;
34
34
  weightIF: Tensor;
35
35
  weightIG: Tensor;
@@ -49,6 +49,14 @@ export declare class LSTMCell {
49
49
  constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
50
50
  forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue, cell: Tensor | TensorValue): [Tensor, Tensor];
51
51
  }
52
/**
 * Layer normalization over the trailing `normalizedShape` dimensions of the input:
 * the input is centered by its mean and divided by sqrt(variance + eps) over those
 * dims, then optionally scaled by `weight` and shifted by `bias`.
 */
declare class LayerNorm {
    /** Learnable elementwise scale of shape `normalizedShape` (initialized to ones); present only when `elementwiseAffine` is true. */
    weight?: Tensor;
    /** Learnable elementwise shift of shape `normalizedShape` (initialized to zeros); present only when both `elementwiseAffine` and `bias` are true. */
    bias?: Tensor;
    /** Added to the variance before the square root for numerical stability. */
    eps: number;
    /** Sizes of the trailing input dims to normalize over; a scalar constructor argument is stored as a 1-element array. */
    normalizedShape: number[];
    /**
     * @param normalizedShape Size(s) of the trailing dims to normalize over; must be non-empty.
     * @param eps Stability term added to the variance (implementation default 1e-5).
     * @param elementwiseAffine Whether to create a learnable `weight` (implementation default true).
     * @param bias Whether to also create a learnable `bias`; only honored when `elementwiseAffine` is true (implementation default true).
     * @param device Device for the parameter tensors.
     * @throws Error if `normalizedShape` is an empty array.
     */
    constructor(normalizedShape: number | number[], eps?: number, elementwiseAffine?: boolean, bias?: boolean, device?: string);
    /**
     * Normalize `input` over its trailing dims.
     * @throws Error if `input` has fewer dims than `normalizedShape`, or its trailing sizes differ from `normalizedShape`.
     */
    forward(input: Tensor | TensorValue): Tensor;
}
52
60
  /**
   * Flat key/value map returned by `state.getStateDict`. In nn.js each entry is
   * written as `stateDict[fullKey] = value.val()` and read back by
   * `state.loadStateDict` via `stateDict[fullKey]`.
   * NOTE(review): keys are presumably prefix-joined attribute paths (the exact
   * `fullKey` format is built outside this view) — confirm. `any` is kept for
   * backward compatibility; changing it to `unknown` would break consumers.
   */
  interface StateDict {
      [key: string]: any;
  }
@@ -56,6 +64,8 @@ export declare const nn: {
56
64
  Linear: typeof Linear;
57
65
  RNNCell: typeof RNNCell;
58
66
  GRUCell: typeof GRUCell;
67
+ LSTMCell: typeof LSTMCell;
68
+ LayerNorm: typeof LayerNorm;
59
69
  state: {
60
70
  getParameters(model: any, visited?: WeakSet<object>): Tensor[];
61
71
  getStateDict(model: any, prefix?: string, visited?: WeakSet<object>): StateDict;
package/dist/nn.js CHANGED
@@ -1,6 +1,6 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.nn = exports.LSTMCell = void 0;
3
+ exports.nn = void 0;
4
4
  const core_1 = require("./core");
5
5
  function linearTransform(input, weight, bias) {
6
6
  let output = input.matmul(weight.t());
@@ -144,7 +144,51 @@ class LSTMCell {
144
144
  return [h, c];
145
145
  }
146
146
  }
147
- exports.LSTMCell = LSTMCell;
147
/**
 * Layer normalization over the trailing `normalizedShape` dimensions of the input.
 * Computes (x - mean) / sqrt(var + eps) over those dims, where var is the
 * population (biased) variance, i.e. the mean of squared deviations. The result
 * is then optionally scaled by a learnable `weight` and shifted by a learnable `bias`.
 */
class LayerNorm {
    weight;
    bias;
    eps;
    normalizedShape;
    /**
     * @param normalizedShape Size(s) of the trailing input dims to normalize over.
     *   A single number is treated as a 1-element shape. Must be non-empty.
     * @param eps Added to the variance before the square root for numerical stability.
     * @param elementwiseAffine If true, create a learnable `weight` initialized to ones.
     * @param bias If true and `elementwiseAffine` is true, also create a learnable
     *   `bias` initialized to zeros. Ignored when `elementwiseAffine` is false.
     * @param device Device passed when creating the parameter tensors.
     * @throws Error if `normalizedShape` is an empty array.
     */
    constructor(normalizedShape, eps = 1e-5, elementwiseAffine = true, bias = true, device) {
        this.eps = eps;
        // Store a copy so later mutation of the caller's array cannot silently
        // change the shape this layer validates against.
        this.normalizedShape = Array.isArray(normalizedShape) ? [...normalizedShape] : [normalizedShape];
        if (this.normalizedShape.length === 0) {
            throw new Error("Normalized shape cannot be empty");
        }
        if (elementwiseAffine) {
            this.weight = core_1.Tensor.ones(this.normalizedShape, { requiresGrad: true, device });
            if (bias) {
                this.bias = core_1.Tensor.zeros(this.normalizedShape, { requiresGrad: true, device });
            }
        }
    }
    /**
     * Normalize `input` over its trailing `normalizedShape.length` dims.
     * @param input Tensor or raw value (converted with `Tensor.forceTensor`).
     * @returns The normalized (and, if configured, scaled/shifted) tensor.
     *   NOTE(review): this assumes `mean(dims, true)` keeps the reduced dims so the
     *   statistics broadcast back over `input` — confirm against core.
     * @throws Error if `input` has fewer dims than `normalizedShape`, or if any
     *   trailing size differs from the corresponding `normalizedShape` entry.
     */
    forward(input) {
        input = core_1.Tensor.forceTensor(input);
        // Normalize over the last `normalizedDims` dimensions of the input.
        const normalizedDims = this.normalizedShape.length;
        const startDim = input.shape.length - normalizedDims;
        if (startDim < 0) {
            throw new Error("Input does not have enough dims to normalize");
        }
        const dims = [];
        for (let i = 0; i < normalizedDims; i++) {
            if (input.shape[startDim + i] !== this.normalizedShape[i]) {
                throw new Error(`Shape mismatch at dim ${startDim + i}: expected ${this.normalizedShape[i]}, got ${input.shape[startDim + i]}`);
            }
            dims.push(startDim + i);
        }
        const mean = input.mean(dims, true);
        // Center once and reuse the result for both the variance and the output,
        // instead of building two identical `sub` ops (and autograd nodes).
        const centered = input.sub(mean);
        const variance = centered.pow(2).mean(dims, true);
        let normalized = centered.div(variance.add(this.eps).sqrt());
        if (this.weight) {
            normalized = normalized.mul(this.weight);
        }
        if (this.bias) {
            normalized = normalized.add(this.bias);
        }
        return normalized;
    }
}
148
192
  const state = {
149
193
  getParameters(model, visited = new WeakSet()) {
150
194
  if (visited.has(model))
@@ -178,7 +222,7 @@ const state = {
178
222
  stateDict[fullKey] = value.val();
179
223
  }
180
224
  else if (typeof value === "object" && value !== null) {
181
- Object.assign(stateDict, this.getStateDict(value, fullKey, visited));
225
+ Object.assign(stateDict, state.getStateDict(value, fullKey, visited));
182
226
  }
183
227
  }
184
228
  return stateDict;
@@ -196,7 +240,7 @@ const state = {
196
240
  value.replace(new core_1.Tensor(stateDict[fullKey], { device: value.device }));
197
241
  }
198
242
  else if (typeof value === "object" && value !== null) {
199
- this.loadStateDict(value, stateDict, fullKey, visited);
243
+ state.loadStateDict(value, stateDict, fullKey, visited);
200
244
  }
201
245
  }
202
246
  }
@@ -205,5 +249,7 @@ exports.nn = {
205
249
  Linear,
206
250
  RNNCell,
207
251
  GRUCell,
252
+ LSTMCell,
253
+ LayerNorm,
208
254
  state
209
255
  };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.5.6",
3
+ "version": "0.5.7",
4
4
  "description": "A small Torch-like deep learning framework for Javascript",
5
5
  "main": "index.js",
6
6
  "scripts": {