catniff 0.8.6 → 0.8.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +2 -0
- package/dist/core.js +36 -7
- package/dist/nn.d.ts +19 -0
- package/dist/nn.js +107 -1
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
|
@@ -99,6 +99,8 @@ export declare class Tensor {
|
|
|
99
99
|
std(dims?: number[] | number, keepDims?: boolean): Tensor;
|
|
100
100
|
softmax(dim?: number): Tensor;
|
|
101
101
|
softmin(dim?: number): Tensor;
|
|
102
|
+
logsumexp(dim?: number): Tensor;
|
|
103
|
+
logSoftmax(dim?: number): Tensor;
|
|
102
104
|
add(other: TensorValue | Tensor): Tensor;
|
|
103
105
|
sub(other: TensorValue | Tensor): Tensor;
|
|
104
106
|
subtract: (other: TensorValue | Tensor) => Tensor;
|
package/dist/core.js
CHANGED
|
@@ -700,7 +700,7 @@ class Tensor {
|
|
|
700
700
|
}
|
|
701
701
|
// If dimension out of bound, throw error
|
|
702
702
|
if (dim >= this.shape.length || dim < 0) {
|
|
703
|
-
throw new Error("Dimension
|
|
703
|
+
throw new Error("Dimension does not exist to chunk");
|
|
704
704
|
}
|
|
705
705
|
const sliceOpt = new Array(this.shape.length);
|
|
706
706
|
for (let index = 0; index < sliceOpt.length; index++) {
|
|
@@ -944,7 +944,7 @@ class Tensor {
|
|
|
944
944
|
}
|
|
945
945
|
// If dimension out of bound, throw error
|
|
946
946
|
if (dim >= this.shape.length || dim < 0) {
|
|
947
|
-
throw new Error("Dimension
|
|
947
|
+
throw new Error("Dimension does not exist to sort");
|
|
948
948
|
}
|
|
949
949
|
// Copy if not contiguous
|
|
950
950
|
const outputSize = this.numel;
|
|
@@ -1032,7 +1032,7 @@ class Tensor {
|
|
|
1032
1032
|
}
|
|
1033
1033
|
// If dimension out of bound, throw error
|
|
1034
1034
|
if (dim >= this.shape.length || dim < 0) {
|
|
1035
|
-
throw new Error("Dimension
|
|
1035
|
+
throw new Error("Dimension does not exist to get topk");
|
|
1036
1036
|
}
|
|
1037
1037
|
const dimRanges = new Array(this.shape.length);
|
|
1038
1038
|
for (let index = 0; index < dimRanges.length; index++) {
|
|
@@ -1207,7 +1207,7 @@ class Tensor {
|
|
|
1207
1207
|
std(dims, keepDims = false) {
|
|
1208
1208
|
return this.var(dims, keepDims).sqrt();
|
|
1209
1209
|
}
|
|
1210
|
-
// Tensor softmax
|
|
1210
|
+
// Tensor (stable) softmax
|
|
1211
1211
|
softmax(dim = -1) {
|
|
1212
1212
|
if (this.shape.length === 0)
|
|
1213
1213
|
return this;
|
|
@@ -1217,7 +1217,7 @@ class Tensor {
|
|
|
1217
1217
|
}
|
|
1218
1218
|
// If dimension out of bound, throw error
|
|
1219
1219
|
if (dim >= this.shape.length || dim < 0) {
|
|
1220
|
-
throw new Error("Dimension
|
|
1220
|
+
throw new Error("Dimension does not exist to apply softmax");
|
|
1221
1221
|
}
|
|
1222
1222
|
const maxVals = this.max(dim, true);
|
|
1223
1223
|
const shifted = this.sub(maxVals);
|
|
@@ -1225,7 +1225,7 @@ class Tensor {
|
|
|
1225
1225
|
const sumExp = expVals.sum(dim, true);
|
|
1226
1226
|
return expVals.div(sumExp);
|
|
1227
1227
|
}
|
|
1228
|
-
// Tensor softmin
|
|
1228
|
+
// Tensor (stable) softmin
|
|
1229
1229
|
softmin(dim = -1) {
|
|
1230
1230
|
if (this.shape.length === 0)
|
|
1231
1231
|
return this;
|
|
@@ -1235,7 +1235,7 @@ class Tensor {
|
|
|
1235
1235
|
}
|
|
1236
1236
|
// If dimension out of bound, throw error
|
|
1237
1237
|
if (dim >= this.shape.length || dim < 0) {
|
|
1238
|
-
throw new Error("Dimension
|
|
1238
|
+
throw new Error("Dimension does not exist to apply softmin");
|
|
1239
1239
|
}
|
|
1240
1240
|
const maxVals = this.max(dim, true);
|
|
1241
1241
|
const shifted = maxVals.sub(this);
|
|
@@ -1243,6 +1243,35 @@ class Tensor {
|
|
|
1243
1243
|
const sumExp = expVals.sum(dim, true);
|
|
1244
1244
|
return expVals.div(sumExp);
|
|
1245
1245
|
}
|
|
1246
|
+
// Tensor (stable) logsumexp
|
|
1247
|
+
logsumexp(dim = -1) {
|
|
1248
|
+
if (this.shape.length === 0)
|
|
1249
|
+
return this;
|
|
1250
|
+
// Handle negative indexing
|
|
1251
|
+
if (dim < 0) {
|
|
1252
|
+
dim += this.shape.length;
|
|
1253
|
+
}
|
|
1254
|
+
// If dimension out of bound, throw error
|
|
1255
|
+
if (dim >= this.shape.length || dim < 0) {
|
|
1256
|
+
throw new Error("Dimension does not exist to apply logsumexp");
|
|
1257
|
+
}
|
|
1258
|
+
const max = this.max(dim, true);
|
|
1259
|
+
return max.add(this.sub(max).exp().sum(dim, true).log());
|
|
1260
|
+
}
|
|
1261
|
+
// Tensor (stable) logSoftmax
|
|
1262
|
+
logSoftmax(dim = -1) {
|
|
1263
|
+
if (this.shape.length === 0)
|
|
1264
|
+
return this;
|
|
1265
|
+
// Handle negative indexing
|
|
1266
|
+
if (dim < 0) {
|
|
1267
|
+
dim += this.shape.length;
|
|
1268
|
+
}
|
|
1269
|
+
// If dimension out of bound, throw error
|
|
1270
|
+
if (dim >= this.shape.length || dim < 0) {
|
|
1271
|
+
throw new Error("Dimension does not exist to apply logsumexp");
|
|
1272
|
+
}
|
|
1273
|
+
return this.sub(this.logsumexp(dim));
|
|
1274
|
+
}
|
|
1246
1275
|
// Tensor element-wise addition
|
|
1247
1276
|
add(other) {
|
|
1248
1277
|
return this.elementWiseABDAG(other, (a, b) => a + b, (self, other, outGrad) => outGrad, (self, other, outGrad) => outGrad);
|
package/dist/nn.d.ts
CHANGED
|
@@ -64,6 +64,23 @@ export declare class BatchNorm {
|
|
|
64
64
|
constructor(numFeatures: number, eps?: number, momentum?: number, affine?: boolean, trackRunningStats?: boolean, device?: string, dtype?: dtype);
|
|
65
65
|
forward(input: Tensor): Tensor;
|
|
66
66
|
}
|
|
67
|
+
export declare class InstanceNorm {
|
|
68
|
+
weight?: Tensor;
|
|
69
|
+
bias?: Tensor;
|
|
70
|
+
eps: number;
|
|
71
|
+
numFeatures: number;
|
|
72
|
+
constructor(numFeatures: number, eps?: number, affine?: boolean, device?: string, dtype?: dtype);
|
|
73
|
+
forward(input: Tensor): Tensor;
|
|
74
|
+
}
|
|
75
|
+
export declare class GroupNorm {
|
|
76
|
+
weight?: Tensor;
|
|
77
|
+
bias?: Tensor;
|
|
78
|
+
eps: number;
|
|
79
|
+
numGroups: number;
|
|
80
|
+
numChannels: number;
|
|
81
|
+
constructor(numGroups: number, numChannels: number, eps?: number, affine?: boolean, device?: string, dtype?: dtype);
|
|
82
|
+
forward(input: Tensor): Tensor;
|
|
83
|
+
}
|
|
67
84
|
export declare class LayerNorm {
|
|
68
85
|
weight?: Tensor;
|
|
69
86
|
bias?: Tensor;
|
|
@@ -106,6 +123,8 @@ export declare const nn: {
|
|
|
106
123
|
GRUCell: typeof GRUCell;
|
|
107
124
|
LSTMCell: typeof LSTMCell;
|
|
108
125
|
BatchNorm: typeof BatchNorm;
|
|
126
|
+
InstanceNorm: typeof InstanceNorm;
|
|
127
|
+
GroupNorm: typeof GroupNorm;
|
|
109
128
|
LayerNorm: typeof LayerNorm;
|
|
110
129
|
RMSNorm: typeof RMSNorm;
|
|
111
130
|
Embedding: typeof Embedding;
|
package/dist/nn.js
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.BatchNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Linear = void 0;
|
|
3
|
+
exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Linear = void 0;
|
|
4
4
|
exports.scaledDotProductAttention = scaledDotProductAttention;
|
|
5
5
|
const core_1 = require("./core");
|
|
6
6
|
function linearTransform(input, weight, bias) {
|
|
@@ -226,6 +226,110 @@ class BatchNorm {
|
|
|
226
226
|
}
|
|
227
227
|
}
|
|
228
228
|
exports.BatchNorm = BatchNorm;
|
|
229
|
+
class InstanceNorm {
|
|
230
|
+
weight;
|
|
231
|
+
bias;
|
|
232
|
+
eps;
|
|
233
|
+
numFeatures;
|
|
234
|
+
constructor(numFeatures, eps = 1e-5, affine = true, device, dtype) {
|
|
235
|
+
this.numFeatures = numFeatures;
|
|
236
|
+
this.eps = eps;
|
|
237
|
+
if (affine) {
|
|
238
|
+
this.weight = core_1.Tensor.ones([numFeatures], { requiresGrad: true, device, dtype });
|
|
239
|
+
this.bias = core_1.Tensor.zeros([numFeatures], { requiresGrad: true, device, dtype });
|
|
240
|
+
}
|
|
241
|
+
}
|
|
242
|
+
forward(input) {
|
|
243
|
+
// Input should be at least 3D: [N, C, ...spatial dims]
|
|
244
|
+
if (input.shape.length < 3) {
|
|
245
|
+
throw new Error("InstanceNorm expects at least 3D input [N, C, ...spatial]");
|
|
246
|
+
}
|
|
247
|
+
if (input.shape[1] !== this.numFeatures) {
|
|
248
|
+
throw new Error(`Expected ${this.numFeatures} channels, got ${input.shape[1]}`);
|
|
249
|
+
}
|
|
250
|
+
// Normalize across spatial dimensions (all dims after channel dim)
|
|
251
|
+
const dims = [];
|
|
252
|
+
for (let i = 2; i < input.shape.length; i++) {
|
|
253
|
+
dims.push(i);
|
|
254
|
+
}
|
|
255
|
+
const mean = input.mean(dims, true);
|
|
256
|
+
const variance = input.sub(mean).pow(2).mean(dims, true);
|
|
257
|
+
let normalized = input.sub(mean).div(variance.add(this.eps).sqrt());
|
|
258
|
+
if (this.weight) {
|
|
259
|
+
// Reshape weight to [1, C, 1, 1, ...] for broadcasting
|
|
260
|
+
const weightShape = [1, this.numFeatures, ...Array(input.shape.length - 2).fill(1)];
|
|
261
|
+
const weightReshaped = this.weight.reshape(weightShape);
|
|
262
|
+
normalized = normalized.mul(weightReshaped);
|
|
263
|
+
}
|
|
264
|
+
if (this.bias) {
|
|
265
|
+
// Reshape bias to [1, C, 1, 1, ...] for broadcasting
|
|
266
|
+
const biasShape = [1, this.numFeatures, ...Array(input.shape.length - 2).fill(1)];
|
|
267
|
+
const biasReshaped = this.bias.reshape(biasShape);
|
|
268
|
+
normalized = normalized.add(biasReshaped);
|
|
269
|
+
}
|
|
270
|
+
return normalized;
|
|
271
|
+
}
|
|
272
|
+
}
|
|
273
|
+
exports.InstanceNorm = InstanceNorm;
|
|
274
|
+
class GroupNorm {
|
|
275
|
+
weight;
|
|
276
|
+
bias;
|
|
277
|
+
eps;
|
|
278
|
+
numGroups;
|
|
279
|
+
numChannels;
|
|
280
|
+
constructor(numGroups, numChannels, eps = 1e-5, affine = true, device, dtype) {
|
|
281
|
+
if (numChannels % numGroups !== 0) {
|
|
282
|
+
throw new Error(`num_channels (${numChannels}) must be divisible by num_groups (${numGroups})`);
|
|
283
|
+
}
|
|
284
|
+
this.numGroups = numGroups;
|
|
285
|
+
this.numChannels = numChannels;
|
|
286
|
+
this.eps = eps;
|
|
287
|
+
if (affine) {
|
|
288
|
+
this.weight = core_1.Tensor.ones([numChannels], { requiresGrad: true, device, dtype });
|
|
289
|
+
this.bias = core_1.Tensor.zeros([numChannels], { requiresGrad: true, device, dtype });
|
|
290
|
+
}
|
|
291
|
+
}
|
|
292
|
+
forward(input) {
|
|
293
|
+
// Input should be at least 3D: [N, C, ...spatial dims]
|
|
294
|
+
if (input.shape.length < 3) {
|
|
295
|
+
throw new Error("GroupNorm expects at least 3D input [N, C, ...spatial]");
|
|
296
|
+
}
|
|
297
|
+
if (input.shape[1] !== this.numChannels) {
|
|
298
|
+
throw new Error(`Expected ${this.numChannels} channels, got ${input.shape[1]}`);
|
|
299
|
+
}
|
|
300
|
+
const N = input.shape[0];
|
|
301
|
+
const C = input.shape[1];
|
|
302
|
+
const spatialDims = input.shape.slice(2);
|
|
303
|
+
const channelsPerGroup = C / this.numGroups;
|
|
304
|
+
// Reshape: [N, C, ...spatial] -> [N, G, C//G, ...spatial]
|
|
305
|
+
const reshapedInput = input.reshape([N, this.numGroups, channelsPerGroup, ...spatialDims]);
|
|
306
|
+
// Normalize across (C//G, ...spatial) dimensions for each group
|
|
307
|
+
// That's dims [2, 3, 4, ...] in the reshaped tensor
|
|
308
|
+
const dims = [];
|
|
309
|
+
for (let i = 2; i < reshapedInput.shape.length; i++) {
|
|
310
|
+
dims.push(i);
|
|
311
|
+
}
|
|
312
|
+
const mean = reshapedInput.mean(dims, true);
|
|
313
|
+
const variance = reshapedInput.sub(mean).pow(2).mean(dims, true);
|
|
314
|
+
let normalized = reshapedInput.sub(mean).div(variance.add(this.eps).sqrt());
|
|
315
|
+
// Reshape back: [N, G, C//G, ...spatial] -> [N, C, ...spatial]
|
|
316
|
+
normalized = normalized.reshape(input.shape);
|
|
317
|
+
if (this.weight) {
|
|
318
|
+
// Reshape weight to [1, C, 1, 1, ...] for broadcasting
|
|
319
|
+
const weightShape = [1, this.numChannels, ...Array(spatialDims.length).fill(1)];
|
|
320
|
+
const weightReshaped = this.weight.reshape(weightShape);
|
|
321
|
+
normalized = normalized.mul(weightReshaped);
|
|
322
|
+
}
|
|
323
|
+
if (this.bias) {
|
|
324
|
+
// Reshape bias to [1, C, 1, 1, ...] for broadcasting
|
|
325
|
+
const biasShape = [1, this.numChannels, ...Array(spatialDims.length).fill(1)];
|
|
326
|
+
const biasReshaped = this.bias.reshape(biasShape);
|
|
327
|
+
normalized = normalized.add(biasReshaped);
|
|
328
|
+
}
|
|
329
|
+
return normalized;
|
|
330
|
+
}
|
|
331
|
+
}
|
|
332
|
+
exports.GroupNorm = GroupNorm;
|
|
229
333
|
class LayerNorm {
|
|
230
334
|
weight;
|
|
231
335
|
bias;
|
|
@@ -461,6 +565,8 @@ exports.nn = {
|
|
|
461
565
|
GRUCell,
|
|
462
566
|
LSTMCell,
|
|
463
567
|
BatchNorm,
|
|
568
|
+
InstanceNorm,
|
|
569
|
+
GroupNorm,
|
|
464
570
|
LayerNorm,
|
|
465
571
|
RMSNorm,
|
|
466
572
|
Embedding,
|