catniff 0.8.6 → 0.8.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/nn.d.ts +19 -0
- package/dist/nn.js +107 -1
- package/package.json +1 -1
package/dist/nn.d.ts
CHANGED
|
@@ -64,6 +64,23 @@ export declare class BatchNorm {
|
|
|
64
64
|
constructor(numFeatures: number, eps?: number, momentum?: number, affine?: boolean, trackRunningStats?: boolean, device?: string, dtype?: dtype);
|
|
65
65
|
forward(input: Tensor): Tensor;
|
|
66
66
|
}
|
|
67
|
+
/**
 * Instance normalization for inputs laid out as `[N, C, ...spatial]`: every
 * (sample, channel) slice is normalized over its spatial positions, followed by
 * an optional learnable per-channel affine transform. No running statistics are
 * tracked.
 */
export declare class InstanceNorm {
    /** Expected channel count (size of input dimension 1). */
    public numFeatures: number;
    /** Stabilizing constant added to the variance (runtime default `1e-5`). */
    public eps: number;
    /** Per-channel scale of shape `[numFeatures]`; present only when `affine` is enabled. */
    public weight?: Tensor;
    /** Per-channel shift of shape `[numFeatures]`; present only when `affine` is enabled. */
    public bias?: Tensor;
    /**
     * @param numFeatures Number of channels the input must have.
     * @param eps Added to the variance before taking the square root.
     * @param affine Whether to allocate learnable `weight`/`bias` (runtime default `true`).
     * @param device Passed through to the parameter tensors.
     * @param dtype Passed through to the parameter tensors.
     */
    constructor(numFeatures: number, eps?: number, affine?: boolean, device?: string, dtype?: dtype);
    /** Normalizes `input`; throws if its rank or channel count is invalid. */
    public forward(input: Tensor): Tensor;
}
|
|
75
|
+
/**
 * Group normalization for inputs laid out as `[N, C, ...spatial]`: the channels
 * are split into `numGroups` equal groups and each group of each sample is
 * normalized independently, followed by an optional learnable per-channel
 * affine transform.
 */
export declare class GroupNorm {
    /** Number of channel groups; `numChannels` must be divisible by it. */
    public numGroups: number;
    /** Expected channel count (size of input dimension 1). */
    public numChannels: number;
    /** Stabilizing constant added to the variance (runtime default `1e-5`). */
    public eps: number;
    /** Per-channel scale of shape `[numChannels]`; present only when `affine` is enabled. */
    public weight?: Tensor;
    /** Per-channel shift of shape `[numChannels]`; present only when `affine` is enabled. */
    public bias?: Tensor;
    /**
     * @param numGroups Number of groups to split the channels into.
     * @param numChannels Number of channels the input must have.
     * @param eps Added to the variance before taking the square root.
     * @param affine Whether to allocate learnable `weight`/`bias` (runtime default `true`).
     * @param device Passed through to the parameter tensors.
     * @param dtype Passed through to the parameter tensors.
     * @throws Error if `numChannels` is not divisible by `numGroups`.
     */
    constructor(numGroups: number, numChannels: number, eps?: number, affine?: boolean, device?: string, dtype?: dtype);
    /** Normalizes `input`; throws if its rank or channel count is invalid. */
    public forward(input: Tensor): Tensor;
}
|
|
67
84
|
export declare class LayerNorm {
|
|
68
85
|
weight?: Tensor;
|
|
69
86
|
bias?: Tensor;
|
|
@@ -106,6 +123,8 @@ export declare const nn: {
|
|
|
106
123
|
GRUCell: typeof GRUCell;
|
|
107
124
|
LSTMCell: typeof LSTMCell;
|
|
108
125
|
BatchNorm: typeof BatchNorm;
|
|
126
|
+
InstanceNorm: typeof InstanceNorm;
|
|
127
|
+
GroupNorm: typeof GroupNorm;
|
|
109
128
|
LayerNorm: typeof LayerNorm;
|
|
110
129
|
RMSNorm: typeof RMSNorm;
|
|
111
130
|
Embedding: typeof Embedding;
|
package/dist/nn.js
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.BatchNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Linear = void 0;
|
|
3
|
+
exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Linear = void 0;
|
|
4
4
|
exports.scaledDotProductAttention = scaledDotProductAttention;
|
|
5
5
|
const core_1 = require("./core");
|
|
6
6
|
function linearTransform(input, weight, bias) {
|
|
@@ -226,6 +226,110 @@ class BatchNorm {
|
|
|
226
226
|
}
|
|
227
227
|
}
|
|
228
228
|
// Expose BatchNorm as a named CommonJS export (it is also listed on the `nn` namespace object).
exports.BatchNorm = BatchNorm;
|
|
229
|
+
class InstanceNorm {
|
|
230
|
+
weight;
|
|
231
|
+
bias;
|
|
232
|
+
eps;
|
|
233
|
+
numFeatures;
|
|
234
|
+
constructor(numFeatures, eps = 1e-5, affine = true, device, dtype) {
|
|
235
|
+
this.numFeatures = numFeatures;
|
|
236
|
+
this.eps = eps;
|
|
237
|
+
if (affine) {
|
|
238
|
+
this.weight = core_1.Tensor.ones([numFeatures], { requiresGrad: true, device, dtype });
|
|
239
|
+
this.bias = core_1.Tensor.zeros([numFeatures], { requiresGrad: true, device, dtype });
|
|
240
|
+
}
|
|
241
|
+
}
|
|
242
|
+
forward(input) {
|
|
243
|
+
// Input should be at least 3D: [N, C, ...spatial dims]
|
|
244
|
+
if (input.shape.length < 3) {
|
|
245
|
+
throw new Error("InstanceNorm expects at least 3D input [N, C, ...spatial]");
|
|
246
|
+
}
|
|
247
|
+
if (input.shape[1] !== this.numFeatures) {
|
|
248
|
+
throw new Error(`Expected ${this.numFeatures} channels, got ${input.shape[1]}`);
|
|
249
|
+
}
|
|
250
|
+
// Normalize across spatial dimensions (all dims after channel dim)
|
|
251
|
+
const dims = [];
|
|
252
|
+
for (let i = 2; i < input.shape.length; i++) {
|
|
253
|
+
dims.push(i);
|
|
254
|
+
}
|
|
255
|
+
const mean = input.mean(dims, true);
|
|
256
|
+
const variance = input.sub(mean).pow(2).mean(dims, true);
|
|
257
|
+
let normalized = input.sub(mean).div(variance.add(this.eps).sqrt());
|
|
258
|
+
if (this.weight) {
|
|
259
|
+
// Reshape weight to [1, C, 1, 1, ...] for broadcasting
|
|
260
|
+
const weightShape = [1, this.numFeatures, ...Array(input.shape.length - 2).fill(1)];
|
|
261
|
+
const weightReshaped = this.weight.reshape(weightShape);
|
|
262
|
+
normalized = normalized.mul(weightReshaped);
|
|
263
|
+
}
|
|
264
|
+
if (this.bias) {
|
|
265
|
+
// Reshape bias to [1, C, 1, 1, ...] for broadcasting
|
|
266
|
+
const biasShape = [1, this.numFeatures, ...Array(input.shape.length - 2).fill(1)];
|
|
267
|
+
const biasReshaped = this.bias.reshape(biasShape);
|
|
268
|
+
normalized = normalized.add(biasReshaped);
|
|
269
|
+
}
|
|
270
|
+
return normalized;
|
|
271
|
+
}
|
|
272
|
+
}
|
|
273
|
+
// Expose InstanceNorm as a named CommonJS export (it is also listed on the `nn` namespace object).
exports.InstanceNorm = InstanceNorm;
|
|
274
|
+
class GroupNorm {
|
|
275
|
+
weight;
|
|
276
|
+
bias;
|
|
277
|
+
eps;
|
|
278
|
+
numGroups;
|
|
279
|
+
numChannels;
|
|
280
|
+
constructor(numGroups, numChannels, eps = 1e-5, affine = true, device, dtype) {
|
|
281
|
+
if (numChannels % numGroups !== 0) {
|
|
282
|
+
throw new Error(`num_channels (${numChannels}) must be divisible by num_groups (${numGroups})`);
|
|
283
|
+
}
|
|
284
|
+
this.numGroups = numGroups;
|
|
285
|
+
this.numChannels = numChannels;
|
|
286
|
+
this.eps = eps;
|
|
287
|
+
if (affine) {
|
|
288
|
+
this.weight = core_1.Tensor.ones([numChannels], { requiresGrad: true, device, dtype });
|
|
289
|
+
this.bias = core_1.Tensor.zeros([numChannels], { requiresGrad: true, device, dtype });
|
|
290
|
+
}
|
|
291
|
+
}
|
|
292
|
+
forward(input) {
|
|
293
|
+
// Input should be at least 3D: [N, C, ...spatial dims]
|
|
294
|
+
if (input.shape.length < 3) {
|
|
295
|
+
throw new Error("GroupNorm expects at least 3D input [N, C, ...spatial]");
|
|
296
|
+
}
|
|
297
|
+
if (input.shape[1] !== this.numChannels) {
|
|
298
|
+
throw new Error(`Expected ${this.numChannels} channels, got ${input.shape[1]}`);
|
|
299
|
+
}
|
|
300
|
+
const N = input.shape[0];
|
|
301
|
+
const C = input.shape[1];
|
|
302
|
+
const spatialDims = input.shape.slice(2);
|
|
303
|
+
const channelsPerGroup = C / this.numGroups;
|
|
304
|
+
// Reshape: [N, C, ...spatial] -> [N, G, C//G, ...spatial]
|
|
305
|
+
const reshapedInput = input.reshape([N, this.numGroups, channelsPerGroup, ...spatialDims]);
|
|
306
|
+
// Normalize across (C//G, ...spatial) dimensions for each group
|
|
307
|
+
// That's dims [2, 3, 4, ...] in the reshaped tensor
|
|
308
|
+
const dims = [];
|
|
309
|
+
for (let i = 2; i < reshapedInput.shape.length; i++) {
|
|
310
|
+
dims.push(i);
|
|
311
|
+
}
|
|
312
|
+
const mean = reshapedInput.mean(dims, true);
|
|
313
|
+
const variance = reshapedInput.sub(mean).pow(2).mean(dims, true);
|
|
314
|
+
let normalized = reshapedInput.sub(mean).div(variance.add(this.eps).sqrt());
|
|
315
|
+
// Reshape back: [N, G, C//G, ...spatial] -> [N, C, ...spatial]
|
|
316
|
+
normalized = normalized.reshape(input.shape);
|
|
317
|
+
if (this.weight) {
|
|
318
|
+
// Reshape weight to [1, C, 1, 1, ...] for broadcasting
|
|
319
|
+
const weightShape = [1, this.numChannels, ...Array(spatialDims.length).fill(1)];
|
|
320
|
+
const weightReshaped = this.weight.reshape(weightShape);
|
|
321
|
+
normalized = normalized.mul(weightReshaped);
|
|
322
|
+
}
|
|
323
|
+
if (this.bias) {
|
|
324
|
+
// Reshape bias to [1, C, 1, 1, ...] for broadcasting
|
|
325
|
+
const biasShape = [1, this.numChannels, ...Array(spatialDims.length).fill(1)];
|
|
326
|
+
const biasReshaped = this.bias.reshape(biasShape);
|
|
327
|
+
normalized = normalized.add(biasReshaped);
|
|
328
|
+
}
|
|
329
|
+
return normalized;
|
|
330
|
+
}
|
|
331
|
+
}
|
|
332
|
+
// Expose GroupNorm as a named CommonJS export (it is also listed on the `nn` namespace object).
exports.GroupNorm = GroupNorm;
|
|
229
333
|
class LayerNorm {
|
|
230
334
|
weight;
|
|
231
335
|
bias;
|
|
@@ -461,6 +565,8 @@ exports.nn = {
|
|
|
461
565
|
GRUCell,
|
|
462
566
|
LSTMCell,
|
|
463
567
|
BatchNorm,
|
|
568
|
+
InstanceNorm,
|
|
569
|
+
GroupNorm,
|
|
464
570
|
LayerNorm,
|
|
465
571
|
RMSNorm,
|
|
466
572
|
Embedding,
|