catniff 0.5.6 → 0.5.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/nn.d.ts +11 -1
- package/dist/nn.js +50 -4
- package/package.json +1 -1
package/dist/nn.d.ts
CHANGED
|
@@ -29,7 +29,7 @@ declare class GRUCell {
|
|
|
29
29
|
constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
|
|
30
30
|
forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
|
|
31
31
|
}
|
|
32
|
-
export declare class LSTMCell {
|
|
32
|
+
declare class LSTMCell {
|
|
33
33
|
weightII: Tensor;
|
|
34
34
|
weightIF: Tensor;
|
|
35
35
|
weightIG: Tensor;
|
|
@@ -49,6 +49,14 @@ export declare class LSTMCell {
|
|
|
49
49
|
constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
|
|
50
50
|
forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue, cell: Tensor | TensorValue): [Tensor, Tensor];
|
|
51
51
|
}
|
|
52
|
+
declare class LayerNorm {
|
|
53
|
+
weight?: Tensor;
|
|
54
|
+
bias?: Tensor;
|
|
55
|
+
eps: number;
|
|
56
|
+
normalizedShape: number[];
|
|
57
|
+
constructor(normalizedShape: number | number[], eps?: number, elementwiseAffine?: boolean, bias?: boolean, device?: string);
|
|
58
|
+
forward(input: Tensor | TensorValue): Tensor;
|
|
59
|
+
}
|
|
52
60
|
interface StateDict {
|
|
53
61
|
[key: string]: any;
|
|
54
62
|
}
|
|
@@ -56,6 +64,8 @@ export declare const nn: {
|
|
|
56
64
|
Linear: typeof Linear;
|
|
57
65
|
RNNCell: typeof RNNCell;
|
|
58
66
|
GRUCell: typeof GRUCell;
|
|
67
|
+
LSTMCell: typeof LSTMCell;
|
|
68
|
+
LayerNorm: typeof LayerNorm;
|
|
59
69
|
state: {
|
|
60
70
|
getParameters(model: any, visited?: WeakSet<object>): Tensor[];
|
|
61
71
|
getStateDict(model: any, prefix?: string, visited?: WeakSet<object>): StateDict;
|
package/dist/nn.js
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.nn = exports.LSTMCell = void 0;
|
|
3
|
+
exports.nn = void 0;
|
|
4
4
|
const core_1 = require("./core");
|
|
5
5
|
function linearTransform(input, weight, bias) {
|
|
6
6
|
let output = input.matmul(weight.t());
|
|
@@ -144,7 +144,51 @@ class LSTMCell {
|
|
|
144
144
|
return [h, c];
|
|
145
145
|
}
|
|
146
146
|
}
|
|
147
|
-
|
|
147
|
+
class LayerNorm {
|
|
148
|
+
weight;
|
|
149
|
+
bias;
|
|
150
|
+
eps;
|
|
151
|
+
normalizedShape;
|
|
152
|
+
constructor(normalizedShape, eps = 1e-5, elementwiseAffine = true, bias = true, device) {
|
|
153
|
+
this.eps = eps;
|
|
154
|
+
this.normalizedShape = Array.isArray(normalizedShape) ? normalizedShape : [normalizedShape];
|
|
155
|
+
if (this.normalizedShape.length === 0) {
|
|
156
|
+
throw new Error("Normalized shape cannot be empty");
|
|
157
|
+
}
|
|
158
|
+
if (elementwiseAffine) {
|
|
159
|
+
this.weight = core_1.Tensor.ones(this.normalizedShape, { requiresGrad: true, device });
|
|
160
|
+
if (bias) {
|
|
161
|
+
this.bias = core_1.Tensor.zeros(this.normalizedShape, { requiresGrad: true, device });
|
|
162
|
+
}
|
|
163
|
+
}
|
|
164
|
+
}
|
|
165
|
+
forward(input) {
|
|
166
|
+
input = core_1.Tensor.forceTensor(input);
|
|
167
|
+
// Normalize over the specified dimensions
|
|
168
|
+
const normalizedDims = this.normalizedShape.length;
|
|
169
|
+
const startDim = input.shape.length - normalizedDims;
|
|
170
|
+
if (startDim < 0) {
|
|
171
|
+
throw new Error("Input does not have enough dims to normalize");
|
|
172
|
+
}
|
|
173
|
+
const dims = [];
|
|
174
|
+
for (let i = 0; i < normalizedDims; i++) {
|
|
175
|
+
if (input.shape[startDim + i] !== this.normalizedShape[i]) {
|
|
176
|
+
throw new Error(`Shape mismatch at dim ${startDim + i}: expected ${this.normalizedShape[i]}, got ${input.shape[startDim + i]}`);
|
|
177
|
+
}
|
|
178
|
+
dims.push(startDim + i);
|
|
179
|
+
}
|
|
180
|
+
const mean = input.mean(dims, true);
|
|
181
|
+
const variance = input.sub(mean).pow(2).mean(dims, true);
|
|
182
|
+
let normalized = input.sub(mean).div(variance.add(this.eps).sqrt());
|
|
183
|
+
if (this.weight) {
|
|
184
|
+
normalized = normalized.mul(this.weight);
|
|
185
|
+
}
|
|
186
|
+
if (this.bias) {
|
|
187
|
+
normalized = normalized.add(this.bias);
|
|
188
|
+
}
|
|
189
|
+
return normalized;
|
|
190
|
+
}
|
|
191
|
+
}
|
|
148
192
|
const state = {
|
|
149
193
|
getParameters(model, visited = new WeakSet()) {
|
|
150
194
|
if (visited.has(model))
|
|
@@ -178,7 +222,7 @@ const state = {
|
|
|
178
222
|
stateDict[fullKey] = value.val();
|
|
179
223
|
}
|
|
180
224
|
else if (typeof value === "object" && value !== null) {
|
|
181
|
-
Object.assign(stateDict,
|
|
225
|
+
Object.assign(stateDict, state.getStateDict(value, fullKey, visited));
|
|
182
226
|
}
|
|
183
227
|
}
|
|
184
228
|
return stateDict;
|
|
@@ -196,7 +240,7 @@ const state = {
|
|
|
196
240
|
value.replace(new core_1.Tensor(stateDict[fullKey], { device: value.device }));
|
|
197
241
|
}
|
|
198
242
|
else if (typeof value === "object" && value !== null) {
|
|
199
|
-
|
|
243
|
+
state.loadStateDict(value, stateDict, fullKey, visited);
|
|
200
244
|
}
|
|
201
245
|
}
|
|
202
246
|
}
|
|
@@ -205,5 +249,7 @@ exports.nn = {
|
|
|
205
249
|
Linear,
|
|
206
250
|
RNNCell,
|
|
207
251
|
GRUCell,
|
|
252
|
+
LSTMCell,
|
|
253
|
+
LayerNorm,
|
|
208
254
|
state
|
|
209
255
|
};
|