catniff 0.8.5 → 0.8.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/nn.d.ts +34 -0
- package/dist/nn.js +185 -1
- package/package.json +1 -1
package/dist/nn.d.ts
CHANGED
|
@@ -50,6 +50,37 @@ export declare class LSTMCell {
|
|
|
50
50
|
constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
|
|
51
51
|
forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue, cell: Tensor | TensorValue): [Tensor, Tensor];
|
|
52
52
|
}
|
|
53
|
+
/**
 * Batch normalization layer for inputs shaped (N, C, ...), where C must equal `numFeatures`.
 *
 * Members:
 * - `weight` / `bias`: per-channel affine parameters of length `numFeatures`; only created when `affine` is true.
 * - `runningMean` / `runningVar`: per-channel running statistics of length `numFeatures`; only created when
 *   `trackRunningStats` is true.
 * - `eps`: constant added to the variance before taking the square root (constructor default 1e-5).
 * - `momentum`: weight given to the current batch statistics when updating the running statistics
 *   (constructor default 0.1).
 * - `numBatchesTracked`: starts at 0 and is incremented on every running-statistics update.
 *
 * `forward` throws an `Error` when the input has fewer than 2 dimensions or when its second
 * dimension differs from `numFeatures`.
 */
export declare class BatchNorm {
|
|
54
|
+
weight?: Tensor;
|
|
55
|
+
bias?: Tensor;
|
|
56
|
+
runningMean?: Tensor;
|
|
57
|
+
runningVar?: Tensor;
|
|
58
|
+
eps: number;
|
|
59
|
+
momentum: number;
|
|
60
|
+
numFeatures: number;
|
|
61
|
+
affine: boolean;
|
|
62
|
+
trackRunningStats: boolean;
|
|
63
|
+
numBatchesTracked: number;
|
|
64
|
+
constructor(numFeatures: number, eps?: number, momentum?: number, affine?: boolean, trackRunningStats?: boolean, device?: string, dtype?: dtype);
|
|
65
|
+
forward(input: Tensor): Tensor;
|
|
66
|
+
}
|
|
67
|
+
/**
 * Instance normalization layer for inputs shaped [N, C, ...spatial], where C must equal `numFeatures`.
 *
 * Members:
 * - `weight` / `bias`: per-channel affine parameters of length `numFeatures`; only created when the
 *   constructor's `affine` argument is true (its default).
 * - `eps`: constant added to the variance before taking the square root (constructor default 1e-5).
 *
 * `forward` normalizes over every dimension after the channel dimension and throws an `Error`
 * when the input has fewer than 3 dimensions or when its second dimension differs from `numFeatures`.
 */
export declare class InstanceNorm {
|
|
68
|
+
weight?: Tensor;
|
|
69
|
+
bias?: Tensor;
|
|
70
|
+
eps: number;
|
|
71
|
+
numFeatures: number;
|
|
72
|
+
constructor(numFeatures: number, eps?: number, affine?: boolean, device?: string, dtype?: dtype);
|
|
73
|
+
forward(input: Tensor): Tensor;
|
|
74
|
+
}
|
|
75
|
+
/**
 * Group normalization layer for inputs shaped [N, C, ...spatial], where C must equal `numChannels`.
 *
 * Members:
 * - `weight` / `bias`: per-channel affine parameters of length `numChannels`; only created when the
 *   constructor's `affine` argument is true (its default).
 * - `eps`: constant added to the variance before taking the square root (constructor default 1e-5).
 * - `numGroups`: number of channel groups; the constructor throws an `Error` unless `numChannels`
 *   is divisible by it.
 *
 * `forward` throws an `Error` when the input has fewer than 3 dimensions or when its second
 * dimension differs from `numChannels`.
 */
export declare class GroupNorm {
|
|
76
|
+
weight?: Tensor;
|
|
77
|
+
bias?: Tensor;
|
|
78
|
+
eps: number;
|
|
79
|
+
numGroups: number;
|
|
80
|
+
numChannels: number;
|
|
81
|
+
constructor(numGroups: number, numChannels: number, eps?: number, affine?: boolean, device?: string, dtype?: dtype);
|
|
82
|
+
forward(input: Tensor): Tensor;
|
|
83
|
+
}
|
|
53
84
|
export declare class LayerNorm {
|
|
54
85
|
weight?: Tensor;
|
|
55
86
|
bias?: Tensor;
|
|
@@ -91,6 +122,9 @@ export declare const nn: {
|
|
|
91
122
|
RNNCell: typeof RNNCell;
|
|
92
123
|
GRUCell: typeof GRUCell;
|
|
93
124
|
LSTMCell: typeof LSTMCell;
|
|
125
|
+
BatchNorm: typeof BatchNorm;
|
|
126
|
+
InstanceNorm: typeof InstanceNorm;
|
|
127
|
+
GroupNorm: typeof GroupNorm;
|
|
94
128
|
LayerNorm: typeof LayerNorm;
|
|
95
129
|
RMSNorm: typeof RMSNorm;
|
|
96
130
|
Embedding: typeof Embedding;
|
package/dist/nn.js
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Linear = void 0;
|
|
3
|
+
exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Linear = void 0;
|
|
4
4
|
exports.scaledDotProductAttention = scaledDotProductAttention;
|
|
5
5
|
const core_1 = require("./core");
|
|
6
6
|
function linearTransform(input, weight, bias) {
|
|
@@ -149,6 +149,187 @@ class LSTMCell {
|
|
|
149
149
|
}
|
|
150
150
|
}
|
|
151
151
|
exports.LSTMCell = LSTMCell;
|
|
152
|
+
/**
 * Batch normalization over inputs shaped (N, C, ...), where C equals `numFeatures`.
 *
 * While `Tensor.training` is set (or when running statistics are disabled) the
 * per-channel mean and biased variance are computed from the current input.
 * Otherwise the stored running statistics are used.
 */
class BatchNorm {
    weight;
    bias;
    runningMean;
    runningVar;
    eps;
    momentum;
    numFeatures;
    affine;
    trackRunningStats;
    numBatchesTracked;
    /**
     * @param numFeatures Expected size of the channel dimension (dimension 1 of the input).
     * @param eps Constant added to the variance before the square root.
     * @param momentum Weight of the current batch statistics in the running-statistics update.
     * @param affine When true, creates learnable `weight` (ones) and `bias` (zeros).
     * @param trackRunningStats When true, creates `runningMean` (zeros) and `runningVar` (ones).
     * @param device Forwarded unchanged to the tensor factories.
     * @param dtype Forwarded unchanged to the tensor factories.
     */
    constructor(numFeatures, eps = 1e-5, momentum = 0.1, affine = true, trackRunningStats = true, device, dtype) {
        this.numFeatures = numFeatures;
        this.eps = eps;
        this.momentum = momentum;
        this.affine = affine;
        this.trackRunningStats = trackRunningStats;
        this.numBatchesTracked = 0;
        if (this.affine) {
            this.weight = core_1.Tensor.ones([numFeatures], { requiresGrad: true, device, dtype });
            this.bias = core_1.Tensor.zeros([numFeatures], { requiresGrad: true, device, dtype });
        }
        if (this.trackRunningStats) {
            this.runningMean = core_1.Tensor.zeros([numFeatures], { requiresGrad: false, device, dtype });
            this.runningVar = core_1.Tensor.ones([numFeatures], { requiresGrad: false, device, dtype });
        }
    }
    /**
     * Normalizes `input` per channel and applies the affine transform when enabled.
     *
     * @param input Tensor shaped (N, C, ...) with C equal to `numFeatures`.
     * @returns The normalized tensor.
     * @throws Error if the input has fewer than 2 dimensions or the wrong channel count.
     */
    forward(input) {
        if (input.shape.length < 2) {
            throw new Error("Input must have at least 2 dimensions (batch, features)");
        }
        if (input.shape[1] !== this.numFeatures) {
            throw new Error(`Expected ${this.numFeatures} features, got ${input.shape[1]}`);
        }
        const extraDims = input.shape.length - 2;
        // Shape [1, C, 1, 1, ...] that broadcasts per-channel vectors against the input.
        // Built fresh on every call so no shape array is shared between tensors.
        const channelShape = () => [1, this.numFeatures, ...Array(extraDims).fill(1)];
        let mean;
        let variance;
        if (core_1.Tensor.training || !this.trackRunningStats) {
            // Reduce over the batch dimension and every trailing dimension, keeping only the channel dimension.
            const dims = [0, ...Array.from({ length: extraDims }, (_, i) => i + 2)];
            mean = input.mean(dims, true);
            variance = input.sub(mean).pow(2).mean(dims, true);
            if (this.trackRunningStats && core_1.Tensor.training) {
                // NOTE(review): `mean` and `variance` are derived from `input`; if the tensor library
                // records operations for autograd, the running statistics may keep that history
                // alive across batches - confirm whether they should be detached first.
                const exponentialAverageFactor = this.momentum;
                this.runningMean = this.runningMean
                    .mul(1 - exponentialAverageFactor)
                    .add(mean.squeeze().mul(exponentialAverageFactor));
                // Number of values contributing to each channel's statistics.
                const n = input.shape.reduce((acc, val, idx) => idx === 1 ? acc : acc * val, 1);
                // The running estimate uses the unbiased variance (Bessel's correction). With a single
                // value per channel n / (n - 1) is Infinity, and 0 * Infinity would write NaN into
                // runningVar permanently, so the correction is skipped in that case.
                const unbiasingFactor = n > 1 ? n / (n - 1) : 1;
                this.runningVar = this.runningVar
                    .mul(1 - exponentialAverageFactor)
                    .add(variance.squeeze().mul(exponentialAverageFactor * unbiasingFactor));
                this.numBatchesTracked++;
            }
        }
        else {
            // Inference with running statistics enabled: use the stored estimates.
            mean = this.runningMean.reshape(channelShape());
            variance = this.runningVar.reshape(channelShape());
        }
        let normalized = input.sub(mean).div(variance.add(this.eps).sqrt());
        if (this.affine) {
            const weightReshaped = this.weight.reshape(channelShape());
            const biasReshaped = this.bias.reshape(channelShape());
            normalized = normalized.mul(weightReshaped).add(biasReshaped);
        }
        return normalized;
    }
}
|
|
228
|
+
exports.BatchNorm = BatchNorm;
|
|
229
|
+
/**
 * Instance normalization: every (sample, channel) slice of an [N, C, ...spatial]
 * input is normalized with its own mean and biased variance.
 */
class InstanceNorm {
    weight;
    bias;
    eps;
    numFeatures;
    /**
     * @param numFeatures Expected size of the channel dimension (dimension 1 of the input).
     * @param eps Constant added to the variance before the square root.
     * @param affine When true, creates learnable `weight` (ones) and `bias` (zeros).
     * @param device Forwarded unchanged to the tensor factories.
     * @param dtype Forwarded unchanged to the tensor factories.
     */
    constructor(numFeatures, eps = 1e-5, affine = true, device, dtype) {
        this.numFeatures = numFeatures;
        this.eps = eps;
        if (!affine) {
            return;
        }
        // Each parameter receives its own options object.
        const parameterOptions = () => ({ requiresGrad: true, device, dtype });
        this.weight = core_1.Tensor.ones([numFeatures], parameterOptions());
        this.bias = core_1.Tensor.zeros([numFeatures], parameterOptions());
    }
    /**
     * @param input Tensor shaped [N, C, ...spatial] with C equal to `numFeatures`.
     * @returns The normalized tensor, scaled and shifted when affine parameters exist.
     * @throws Error if the input has fewer than 3 dimensions or the wrong channel count.
     */
    forward(input) {
        const rank = input.shape.length;
        if (rank < 3) {
            throw new Error("InstanceNorm expects at least 3D input [N, C, ...spatial]");
        }
        if (input.shape[1] !== this.numFeatures) {
            throw new Error(`Expected ${this.numFeatures} channels, got ${input.shape[1]}`);
        }
        // Every axis after the channel axis is reduced.
        const spatialAxes = Array.from({ length: rank - 2 }, (_, offset) => offset + 2);
        const mean = input.mean(spatialAxes, true);
        const variance = input.sub(mean).pow(2).mean(spatialAxes, true);
        const centered = input.sub(mean);
        const std = variance.add(this.eps).sqrt();
        let result = centered.div(std);
        // Shape [1, C, 1, 1, ...] so per-channel vectors broadcast against the input.
        const perChannelShape = () => [1, this.numFeatures, ...Array(rank - 2).fill(1)];
        if (this.weight) {
            result = result.mul(this.weight.reshape(perChannelShape()));
        }
        if (this.bias) {
            result = result.add(this.bias.reshape(perChannelShape()));
        }
        return result;
    }
}
|
|
273
|
+
exports.InstanceNorm = InstanceNorm;
|
|
274
|
+
/**
 * Group normalization: the channels of an [N, C, ...spatial] input are split into
 * `numGroups` consecutive groups, and each (sample, group) slice is normalized with
 * its own mean and biased variance.
 */
class GroupNorm {
    weight;
    bias;
    eps;
    numGroups;
    numChannels;
    /**
     * @param numGroups Number of channel groups; must divide `numChannels`.
     * @param numChannels Expected size of the channel dimension (dimension 1 of the input).
     * @param eps Constant added to the variance before the square root.
     * @param affine When true, creates learnable `weight` (ones) and `bias` (zeros).
     * @param device Forwarded unchanged to the tensor factories.
     * @param dtype Forwarded unchanged to the tensor factories.
     * @throws Error if `numChannels` is not divisible by `numGroups`.
     */
    constructor(numGroups, numChannels, eps = 1e-5, affine = true, device, dtype) {
        const leftover = numChannels % numGroups;
        if (leftover !== 0) {
            throw new Error(`num_channels (${numChannels}) must be divisible by num_groups (${numGroups})`);
        }
        this.numGroups = numGroups;
        this.numChannels = numChannels;
        this.eps = eps;
        if (!affine) {
            return;
        }
        // Each parameter receives its own options object.
        const parameterOptions = () => ({ requiresGrad: true, device, dtype });
        this.weight = core_1.Tensor.ones([numChannels], parameterOptions());
        this.bias = core_1.Tensor.zeros([numChannels], parameterOptions());
    }
    /**
     * @param input Tensor shaped [N, C, ...spatial] with C equal to `numChannels`.
     * @returns The normalized tensor in the input's shape, scaled and shifted when affine parameters exist.
     * @throws Error if the input has fewer than 3 dimensions or the wrong channel count.
     */
    forward(input) {
        if (input.shape.length < 3) {
            throw new Error("GroupNorm expects at least 3D input [N, C, ...spatial]");
        }
        if (input.shape[1] !== this.numChannels) {
            throw new Error(`Expected ${this.numChannels} channels, got ${input.shape[1]}`);
        }
        const [batch, channels, ...spatial] = input.shape;
        // View the input as [N, G, C / G, ...spatial] so each group can be reduced on its own.
        const grouped = input.reshape([batch, this.numGroups, channels / this.numGroups, ...spatial]);
        // Reduce over the within-group channel axis and every spatial axis.
        const reduceAxes = Array.from({ length: grouped.shape.length - 2 }, (_, offset) => offset + 2);
        const mean = grouped.mean(reduceAxes, true);
        const variance = grouped.sub(mean).pow(2).mean(reduceAxes, true);
        const centered = grouped.sub(mean);
        const std = variance.add(this.eps).sqrt();
        // Restore the caller's [N, C, ...spatial] layout before the per-channel affine step.
        let result = centered.div(std).reshape(input.shape);
        // Shape [1, C, 1, 1, ...] so per-channel vectors broadcast against the input.
        const perChannelShape = () => [1, this.numChannels, ...Array(spatial.length).fill(1)];
        if (this.weight) {
            result = result.mul(this.weight.reshape(perChannelShape()));
        }
        if (this.bias) {
            result = result.add(this.bias.reshape(perChannelShape()));
        }
        return result;
    }
}
|
|
332
|
+
exports.GroupNorm = GroupNorm;
|
|
152
333
|
class LayerNorm {
|
|
153
334
|
weight;
|
|
154
335
|
bias;
|
|
@@ -383,6 +564,9 @@ exports.nn = {
|
|
|
383
564
|
RNNCell,
|
|
384
565
|
GRUCell,
|
|
385
566
|
LSTMCell,
|
|
567
|
+
BatchNorm,
|
|
568
|
+
InstanceNorm,
|
|
569
|
+
GroupNorm,
|
|
386
570
|
LayerNorm,
|
|
387
571
|
RMSNorm,
|
|
388
572
|
Embedding,
|