catniff 0.8.4 → 0.8.6
- package/dist/core.d.ts +1 -0
- package/dist/core.js +47 -0
- package/dist/nn.d.ts +15 -0
- package/dist/nn.js +79 -1
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
package/dist/core.js
CHANGED
@@ -2396,6 +2396,53 @@ class Tensor {
         }
         return buildNested(this.value, this.shape, this.strides, this.offset);
     }
+    // Returns the nicely Pytorch-like formatted string form
+    toString() {
+        const val = this.val();
+        // Format a single number (integers get trailing dot)
+        const formatNum = (n) => {
+            if (Number.isInteger(n) && Math.abs(n) < 1e8) {
+                return n.toFixed(0) + ".";
+            }
+            return n.toString();
+        };
+        // Handle scalar
+        if (typeof val === "number") {
+            return `tensor(${formatNum(val)})`;
+        }
+        // Collect all numbers to find max width for alignment
+        const collectNumbers = (v) => {
+            if (typeof v === "number")
+                return [v];
+            return v.flatMap(collectNumbers);
+        };
+        const allNumbers = collectNumbers(val);
+        const maxWidth = Math.max(...allNumbers.map((n) => formatNum(n).length));
+        const ndim = this.shape.length;
+        const baseIndent = "tensor(".length; // 7
+        const formatNested = (v, depth) => {
+            if (typeof v === "number") {
+                return formatNum(v).padStart(maxWidth);
+            }
+            const arr = v;
+            // Innermost dimension: format as single line [x, y, z]
+            if (arr.length > 0 && typeof arr[0] === "number") {
+                const elements = arr.map((x) => formatNum(x).padStart(maxWidth));
+                return `[${elements.join(", ")}]`;
+            }
+            // Number of blank lines between elements at this depth
+            // Deeper = fewer blank lines (0 between rows, 1 between 2D blocks, etc.)
+            const blankLines = Math.max(0, ndim - depth - 2);
+            const separator = ",\n" + "\n".repeat(blankLines);
+            const innerIndent = " ".repeat(baseIndent + depth + 1);
+            const formatted = arr.map((item, i) => {
+                const str = formatNested(item, depth + 1);
+                return i === 0 ? "[" + str : innerIndent + str;
+            });
+            return formatted.join(separator) + "]";
+        };
+        return "tensor(" + formatNested(val, 0) + ")";
+    }
     // Returns a view of the tensor with gradient turned off and detaches from autograd
     detach() {
         return new Tensor(this.value, {
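For reference, here is what the new toString() prints. A minimal sketch, assuming Tensor is exported from the package root and its constructor accepts a nested array as its first argument (as the new Tensor(this.value, ...) call in detach() suggests); the output trace is derived directly from the formatting logic above:

const { Tensor } = require("catniff");

// 2x3 tensor: integers get a trailing dot, and every element is
// right-padded to the widest formatted entry ("40." -> maxWidth 3)
const t = new Tensor([[1, 2.5, 3], [40, 5, 6]]);
console.log(t.toString());
// tensor([[ 1., 2.5,  3.],
//         [40.,  5.,  6.]])

Rows sit on consecutive lines because blankLines is 0 for a 2-D tensor; a 3-D tensor would separate its 2-D blocks with one blank line, mirroring PyTorch's layout.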
package/dist/nn.d.ts
CHANGED
@@ -50,6 +50,20 @@ export declare class LSTMCell {
     constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
     forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue, cell: Tensor | TensorValue): [Tensor, Tensor];
 }
+export declare class BatchNorm {
+    weight?: Tensor;
+    bias?: Tensor;
+    runningMean?: Tensor;
+    runningVar?: Tensor;
+    eps: number;
+    momentum: number;
+    numFeatures: number;
+    affine: boolean;
+    trackRunningStats: boolean;
+    numBatchesTracked: number;
+    constructor(numFeatures: number, eps?: number, momentum?: number, affine?: boolean, trackRunningStats?: boolean, device?: string, dtype?: dtype);
+    forward(input: Tensor): Tensor;
+}
 export declare class LayerNorm {
     weight?: Tensor;
     bias?: Tensor;
@@ -91,6 +105,7 @@ export declare const nn: {
     RNNCell: typeof RNNCell;
     GRUCell: typeof GRUCell;
     LSTMCell: typeof LSTMCell;
+    BatchNorm: typeof BatchNorm;
     LayerNorm: typeof LayerNorm;
     RMSNorm: typeof RMSNorm;
     Embedding: typeof Embedding;
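The declaration pins down the new public surface: a positional constructor in the style of PyTorch's BatchNorm. A usage sketch, assuming nn is re-exported from the package root; the default values shown are taken from the nn.js implementation below:

const { nn } = require("catniff");

// (numFeatures, eps?, momentum?, affine?, trackRunningStats?, device?, dtype?)
const bn = new nn.BatchNorm(64);                              // eps 1e-5, momentum 0.1
const noStats = new nn.BatchNorm(64, 1e-5, 0.1, true, false); // always use batch statistics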
package/dist/nn.js
CHANGED
@@ -1,6 +1,6 @@
 "use strict";
 Object.defineProperty(exports, "__esModule", { value: true });
-exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Linear = void 0;
+exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.BatchNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Linear = void 0;
 exports.scaledDotProductAttention = scaledDotProductAttention;
 const core_1 = require("./core");
 function linearTransform(input, weight, bias) {
@@ -149,6 +149,83 @@ class LSTMCell {
     }
 }
 exports.LSTMCell = LSTMCell;
+class BatchNorm {
+    weight;
+    bias;
+    runningMean;
+    runningVar;
+    eps;
+    momentum;
+    numFeatures;
+    affine;
+    trackRunningStats;
+    numBatchesTracked;
+    constructor(numFeatures, eps = 1e-5, momentum = 0.1, affine = true, trackRunningStats = true, device, dtype) {
+        this.numFeatures = numFeatures;
+        this.eps = eps;
+        this.momentum = momentum;
+        this.affine = affine;
+        this.trackRunningStats = trackRunningStats;
+        this.numBatchesTracked = 0;
+        if (this.affine) {
+            this.weight = core_1.Tensor.ones([numFeatures], { requiresGrad: true, device, dtype });
+            this.bias = core_1.Tensor.zeros([numFeatures], { requiresGrad: true, device, dtype });
+        }
+        if (this.trackRunningStats) {
+            this.runningMean = core_1.Tensor.zeros([numFeatures], { requiresGrad: false, device, dtype });
+            this.runningVar = core_1.Tensor.ones([numFeatures], { requiresGrad: false, device, dtype });
+        }
+    }
+    forward(input) {
+        // Input shape: (N, C, ...) where C = numFeatures
+        // Normalize over batch dimension and spatial dimensions (if any)
+        if (input.shape.length < 2) {
+            throw new Error("Input must have at least 2 dimensions (batch, features)");
+        }
+        if (input.shape[1] !== this.numFeatures) {
+            throw new Error(`Expected ${this.numFeatures} features, got ${input.shape[1]}`);
+        }
+        let mean;
+        let variance;
+        if (core_1.Tensor.training || !this.trackRunningStats) {
+            // Training or trackRunningStats disabled - calculate mean and variance from scratch
+            // Calculate mean and variance over batch and spatial dimensions
+            // Keep only the channel dimension
+            const dims = [0, ...Array.from({ length: input.shape.length - 2 }, (_, i) => i + 2)];
+            mean = input.mean(dims, true);
+            variance = input.sub(mean).pow(2).mean(dims, true);
+            // Update running statistics if enabled and in training mode
+            if (this.trackRunningStats && core_1.Tensor.training) {
+                const exponentialAverageFactor = this.momentum;
+                this.runningMean = this.runningMean
+                    .mul(1 - exponentialAverageFactor)
+                    .add(mean.squeeze().mul(exponentialAverageFactor));
+                // Use unbiased variance for running estimate
+                const n = input.shape.reduce((acc, val, idx) => idx === 1 ? acc : acc * val, 1);
+                const unbiasingFactor = n / (n - 1);
+                this.runningVar = this.runningVar
+                    .mul(1 - exponentialAverageFactor)
+                    .add(variance.squeeze().mul(exponentialAverageFactor * unbiasingFactor));
+                this.numBatchesTracked++;
+            }
+        }
+        else {
+            // Inference with trackRunningStats enabled - use running statistics
+            mean = this.runningMean.reshape([1, this.numFeatures, ...Array(input.shape.length - 2).fill(1)]);
+            variance = this.runningVar.reshape([1, this.numFeatures, ...Array(input.shape.length - 2).fill(1)]);
+        }
+        // Normalize
+        let normalized = input.sub(mean).div(variance.add(this.eps).sqrt());
+        // Apply affine transformation
+        if (this.affine) {
+            const weightReshaped = this.weight.reshape([1, this.numFeatures, ...Array(input.shape.length - 2).fill(1)]);
+            const biasReshaped = this.bias.reshape([1, this.numFeatures, ...Array(input.shape.length - 2).fill(1)]);
+            normalized = normalized.mul(weightReshaped).add(biasReshaped);
+        }
+        return normalized;
+    }
+}
+exports.BatchNorm = BatchNorm;
 class LayerNorm {
     weight;
     bias;
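Two details of forward() above are easy to miss. The train/eval switch hangs off the global Tensor.training flag rather than a per-module flag, and the running variance is updated with Bessel's correction (factor n/(n-1), where n is the element count per channel) while the batch itself is normalized with the biased variance, matching PyTorch's convention. A behavioral sketch under stated assumptions: Tensor and nn importable from the package root, and a Tensor.randn factory, neither of which appears in this diff:

const { Tensor, nn } = require("catniff"); // import path is an assumption

const bn = new nn.BatchNorm(3);
const x = Tensor.randn([8, 3, 4, 4]); // hypothetical factory: (N, C, H, W)

// Training: normalize with batch statistics, then fold them into the
// running estimates with the default momentum 0.1:
//   runningMean = 0.9 * runningMean + 0.1 * batchMean
//   runningVar  = 0.9 * runningVar  + 0.1 * batchVar * n / (n - 1)
// Here n = 8 * 4 * 4 = 128 per channel, so the unbiasing factor is 128/127.
Tensor.training = true;
let y = bn.forward(x);

// Inference: runningMean/runningVar are reshaped to (1, 3, 1, 1) and
// used in place of batch statistics.
Tensor.training = false;
y = bn.forward(x);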
@@ -383,6 +460,7 @@ exports.nn = {
     RNNCell,
     GRUCell,
     LSTMCell,
+    BatchNorm,
     LayerNorm,
     RMSNorm,
     Embedding,