catniff 0.8.21 → 0.8.22
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +10 -0
- package/dist/core.js +168 -0
- package/dist/nn.d.ts +12 -8
- package/dist/nn.js +18 -147
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
|
@@ -1,5 +1,8 @@
|
|
|
1
1
|
import { Backend } from "./backend";
|
|
2
2
|
import { dtype, MemoryBuffer } from "./dtype";
|
|
3
|
+
/**
 * Something that `Tensor.sequential` / `nn.Sequential` can apply to a tensor:
 * either a plain function of one tensor, or an object (such as a layer)
 * exposing a `forward` function with that same shape.
 */
export type Callable =
    | ((input: Tensor) => Tensor)
    | {
        forward: (input: Tensor) => Tensor;
    };
|
|
3
6
|
export type TensorValue = number | ArrayLike<TensorValue>;
|
|
4
7
|
export interface TensorOptions {
|
|
5
8
|
shape?: number[];
|
|
@@ -220,6 +223,13 @@ export declare class Tensor {
|
|
|
220
223
|
tril(diagonal?: number): Tensor;
maskedFill(mask: Tensor | TensorValue, value: number): Tensor;
multinomial(numSamples: number, replacement?: boolean): Tensor;
/**
 * Functional linear projection: `this.matmul(weight^T)` (the last two dims of
 * `weight` are swapped), plus `bias` when one is given. Non-Tensor `weight` /
 * `bias` values are converted with `handleOther` first.
 */
linear(weight: Tensor | TensorValue, bias?: Tensor | TensorValue): Tensor;
/**
 * Passes this tensor through each entry of `callables` in order. Functions
 * are called directly; objects are applied via their `forward` method.
 */
sequential(callables: Callable[]): Tensor;
/**
 * Layer normalization over the trailing dims listed in `normalizedShape`,
 * which must match this tensor's trailing sizes (throws otherwise). Uses the
 * mean and the mean of squared deviations over those dims. Then applies an
 * optional elementwise `weight` scale and `bias` shift. `eps` defaults to 1e-5.
 */
layerNorm(normalizedShape: number[], weight?: Tensor | TensorValue, bias?: Tensor | TensorValue, eps?: number): Tensor;
/**
 * RMS normalization: divides by `sqrt(mean(x^2) + eps)` over the trailing dims
 * in `normalizedShape`, then multiplies by `weight` if given. No mean is
 * subtracted. `eps` defaults to 1e-5.
 */
rmsNorm(normalizedShape: number[], weight?: Tensor | TensorValue, eps?: number): Tensor;
/**
 * Instance normalization for inputs of rank >= 3, laid out as [N, C, ...spatial].
 * Statistics are taken over the spatial dims (2 onward). `weight`/`bias` are
 * reshaped to [1, C, 1, ...] before being applied. `eps` defaults to 1e-5.
 */
instanceNorm(weight?: Tensor | TensorValue, bias?: Tensor | TensorValue, eps?: number): Tensor;
/**
 * Group normalization for inputs of rank >= 3, laid out as [N, C, ...spatial].
 * Channels are split into `numGroups` groups, and each group is normalized
 * over its channels and the spatial dims. `weight`/`bias` are reshaped to
 * [1, C, 1, ...]. `eps` defaults to 1e-5.
 * NOTE(review): C is assumed to be divisible by `numGroups`.
 */
groupNorm(numGroups: number, weight?: Tensor | TensorValue, bias?: Tensor | TensorValue, eps?: number): Tensor;
/**
 * Scaled dot-product attention, with this tensor as the query:
 * `softmax(masked(this @ key^T / scale)) @ value`.
 *
 * - `scale` defaults to `sqrt(this.shape[-1])` and is used as a divisor.
 * - Scores selected by `attnMask` are filled with -Infinity before the softmax.
 * - `isCausal` (default false) uses an upper-triangular mask with diagonal
 *   offset 1.
 * - `dropout` (default 0) is applied to the attention weights.
 *
 * NOTE(review): PyTorch's `scale` argument multiplies the scores instead of
 * dividing them. Check callers ported from PyTorch.
 */
scaledDotProductAttention(key: Tensor | TensorValue, value: Tensor | TensorValue, attnMask?: Tensor, dropout?: number, isCausal?: boolean, scale?: number): Tensor;
static full(shape: number[], num: number, options?: TensorOptions): Tensor;
static fullLike(tensor: Tensor, num: number, options?: TensorOptions): Tensor;
static ones(shape?: number[], options?: TensorOptions): Tensor;
|
package/dist/core.js
CHANGED
|
@@ -2302,6 +2302,174 @@ class Tensor {
|
|
|
2302
2302
|
dtype: "int32"
|
|
2303
2303
|
});
|
|
2304
2304
|
}
|
|
2305
|
+
// Functional linear projection
|
|
2306
|
+
linear(weight, bias) {
|
|
2307
|
+
weight = this.handleOther(weight);
|
|
2308
|
+
let output = this.matmul(weight.transpose(-1, -2));
|
|
2309
|
+
if (bias) {
|
|
2310
|
+
bias = this.handleOther(bias);
|
|
2311
|
+
output = output.add(bias);
|
|
2312
|
+
}
|
|
2313
|
+
return output;
|
|
2314
|
+
}
|
|
2315
|
+
// Functional sequential chaining
|
|
2316
|
+
sequential(callables) {
|
|
2317
|
+
let res = this;
|
|
2318
|
+
for (let index = 0; index < callables.length; index++) {
|
|
2319
|
+
const callable = callables[index];
|
|
2320
|
+
if (typeof callable === "function") {
|
|
2321
|
+
res = callable(res);
|
|
2322
|
+
}
|
|
2323
|
+
else if (typeof callable === "object" && typeof callable.forward === "function") {
|
|
2324
|
+
res = callable.forward(res);
|
|
2325
|
+
}
|
|
2326
|
+
}
|
|
2327
|
+
return res;
|
|
2328
|
+
}
|
|
2329
|
+
// Functional layer norm
|
|
2330
|
+
layerNorm(normalizedShape, weight, bias, eps = 1e-05) {
|
|
2331
|
+
// Normalize over the specified dimensions
|
|
2332
|
+
const normalizedDims = normalizedShape.length;
|
|
2333
|
+
const startDim = this.shape.length - normalizedDims;
|
|
2334
|
+
if (startDim < 0) {
|
|
2335
|
+
throw new Error("Input does not have enough dims to normalize");
|
|
2336
|
+
}
|
|
2337
|
+
const dims = [];
|
|
2338
|
+
for (let i = 0; i < normalizedDims; i++) {
|
|
2339
|
+
if (this.shape[startDim + i] !== normalizedShape[i]) {
|
|
2340
|
+
throw new Error(`Shape mismatch at dim ${startDim + i}: expected ${normalizedShape[i]}, got ${this.shape[startDim + i]}`);
|
|
2341
|
+
}
|
|
2342
|
+
dims.push(startDim + i);
|
|
2343
|
+
}
|
|
2344
|
+
const mean = this.mean(dims, true);
|
|
2345
|
+
const centered = this.sub(mean);
|
|
2346
|
+
const variance = centered.pow(2).mean(dims, true);
|
|
2347
|
+
let normalized = centered.div(variance.add(eps).sqrt());
|
|
2348
|
+
if (weight) {
|
|
2349
|
+
normalized = normalized.mul(weight);
|
|
2350
|
+
}
|
|
2351
|
+
if (bias) {
|
|
2352
|
+
normalized = normalized.add(bias);
|
|
2353
|
+
}
|
|
2354
|
+
return normalized;
|
|
2355
|
+
}
|
|
2356
|
+
// Functional RMS norm
|
|
2357
|
+
rmsNorm(normalizedShape, weight, eps = 1e-5) {
|
|
2358
|
+
// Normalize over the specified dimensions
|
|
2359
|
+
const normalizedDims = normalizedShape.length;
|
|
2360
|
+
const startDim = this.shape.length - normalizedDims;
|
|
2361
|
+
if (startDim < 0) {
|
|
2362
|
+
throw new Error("Input does not have enough dims to normalize");
|
|
2363
|
+
}
|
|
2364
|
+
const dims = [];
|
|
2365
|
+
for (let i = 0; i < normalizedDims; i++) {
|
|
2366
|
+
if (this.shape[startDim + i] !== normalizedShape[i]) {
|
|
2367
|
+
throw new Error(`Shape mismatch at dim ${startDim + i}: expected ${normalizedShape[i]}, got ${this.shape[startDim + i]}`);
|
|
2368
|
+
}
|
|
2369
|
+
dims.push(startDim + i);
|
|
2370
|
+
}
|
|
2371
|
+
let rms = this.square().mean(dims, true).add(eps).sqrt();
|
|
2372
|
+
let normalized = this.div(rms);
|
|
2373
|
+
if (weight) {
|
|
2374
|
+
normalized = normalized.mul(weight);
|
|
2375
|
+
}
|
|
2376
|
+
return normalized;
|
|
2377
|
+
}
|
|
2378
|
+
// Functional instance norm
|
|
2379
|
+
instanceNorm(weight, bias, eps = 1e-5) {
|
|
2380
|
+
// Input should be at least 3D: [N, C, ...spatial dims]
|
|
2381
|
+
if (this.shape.length < 3) {
|
|
2382
|
+
throw new Error("InstanceNorm expects at least 3D input [N, C, ...spatial]");
|
|
2383
|
+
}
|
|
2384
|
+
// Normalize across spatial dimensions (all dims after channel dim)
|
|
2385
|
+
const dims = [];
|
|
2386
|
+
for (let i = 2; i < this.shape.length; i++) {
|
|
2387
|
+
dims.push(i);
|
|
2388
|
+
}
|
|
2389
|
+
const mean = this.mean(dims, true);
|
|
2390
|
+
const centered = this.sub(mean);
|
|
2391
|
+
const variance = centered.pow(2).mean(dims, true);
|
|
2392
|
+
let normalized = centered.div(variance.add(eps).sqrt());
|
|
2393
|
+
const numFeatures = this.shape[1];
|
|
2394
|
+
if (weight) {
|
|
2395
|
+
// Reshape weight to [1, C, 1, 1, ...] for broadcasting
|
|
2396
|
+
weight = this.handleOther(weight);
|
|
2397
|
+
const weightShape = [1, numFeatures, ...Array(this.shape.length - 2).fill(1)];
|
|
2398
|
+
const weightReshaped = weight.reshape(weightShape);
|
|
2399
|
+
normalized = normalized.mul(weightReshaped);
|
|
2400
|
+
}
|
|
2401
|
+
if (bias) {
|
|
2402
|
+
// Reshape bias to [1, C, 1, 1, ...] for broadcasting
|
|
2403
|
+
bias = this.handleOther(bias);
|
|
2404
|
+
const biasShape = [1, numFeatures, ...Array(this.shape.length - 2).fill(1)];
|
|
2405
|
+
const biasReshaped = bias.reshape(biasShape);
|
|
2406
|
+
normalized = normalized.add(biasReshaped);
|
|
2407
|
+
}
|
|
2408
|
+
return normalized;
|
|
2409
|
+
}
|
|
2410
|
+
// Functional group norm
|
|
2411
|
+
groupNorm(numGroups, weight, bias, eps = 1e-5) {
|
|
2412
|
+
// Input should be at least 3D: [N, C, ...spatial dims]
|
|
2413
|
+
if (this.shape.length < 3) {
|
|
2414
|
+
throw new Error("GroupNorm expects at least 3D input [N, C, ...spatial]");
|
|
2415
|
+
}
|
|
2416
|
+
const N = this.shape[0];
|
|
2417
|
+
const C = this.shape[1];
|
|
2418
|
+
const spatialDims = this.shape.slice(2);
|
|
2419
|
+
const channelsPerGroup = C / numGroups;
|
|
2420
|
+
// Reshape: [N, C, ...spatial] -> [N, G, C//G, ...spatial]
|
|
2421
|
+
const reshapedInput = this.reshape([N, numGroups, channelsPerGroup, ...spatialDims]);
|
|
2422
|
+
// Normalize across (C//G, ...spatial) dimensions for each group
|
|
2423
|
+
// That's dims [2, 3, 4, ...] in the reshaped tensor
|
|
2424
|
+
const dims = [];
|
|
2425
|
+
for (let i = 2; i < reshapedInput.shape.length; i++) {
|
|
2426
|
+
dims.push(i);
|
|
2427
|
+
}
|
|
2428
|
+
const mean = reshapedInput.mean(dims, true);
|
|
2429
|
+
const centered = reshapedInput.sub(mean);
|
|
2430
|
+
const variance = centered.pow(2).mean(dims, true);
|
|
2431
|
+
let normalized = centered.div(variance.add(eps).sqrt());
|
|
2432
|
+
// Reshape back: [N, G, C//G, ...spatial] -> [N, C, ...spatial]
|
|
2433
|
+
normalized = normalized.reshape(this.shape);
|
|
2434
|
+
const numChannels = this.shape[1];
|
|
2435
|
+
if (weight) {
|
|
2436
|
+
// Reshape weight to [1, C, 1, 1, ...] for broadcasting
|
|
2437
|
+
weight = this.handleOther(weight);
|
|
2438
|
+
const weightShape = [1, numChannels, ...Array(spatialDims.length).fill(1)];
|
|
2439
|
+
const weightReshaped = weight.reshape(weightShape);
|
|
2440
|
+
normalized = normalized.mul(weightReshaped);
|
|
2441
|
+
}
|
|
2442
|
+
if (bias) {
|
|
2443
|
+
// Reshape bias to [1, C, 1, 1, ...] for broadcasting
|
|
2444
|
+
bias = this.handleOther(bias);
|
|
2445
|
+
const biasShape = [1, numChannels, ...Array(spatialDims.length).fill(1)];
|
|
2446
|
+
const biasReshaped = bias.reshape(biasShape);
|
|
2447
|
+
normalized = normalized.add(biasReshaped);
|
|
2448
|
+
}
|
|
2449
|
+
return normalized;
|
|
2450
|
+
}
|
|
2451
|
+
// Functional scaled dot product attention
|
|
2452
|
+
scaledDotProductAttention(key, value, attnMask, dropout = 0, isCausal = false, scale) {
|
|
2453
|
+
key = this.handleOther(key);
|
|
2454
|
+
value = this.handleOther(value);
|
|
2455
|
+
const targetLen = this.shape[this.shape.length - 2];
|
|
2456
|
+
const sourceLen = key.shape[key.shape.length - 2];
|
|
2457
|
+
const dimSize = this.shape[this.shape.length - 1];
|
|
2458
|
+
// Attention scores
|
|
2459
|
+
let scores = this.matmul(key.transpose(-2, -1)).div(scale ?? Math.sqrt(dimSize));
|
|
2460
|
+
// Set attention mask to causal mask if specified
|
|
2461
|
+
if (isCausal) {
|
|
2462
|
+
attnMask = Tensor.ones([targetLen, sourceLen], { device: this.device }).triu(1);
|
|
2463
|
+
}
|
|
2464
|
+
// Apply attention mask if specified
|
|
2465
|
+
if (attnMask) {
|
|
2466
|
+
scores = scores.maskedFill(attnMask, -Infinity);
|
|
2467
|
+
}
|
|
2468
|
+
// Calculate attention weights
|
|
2469
|
+
let attnWeights = scores.softmax().dropout(dropout);
|
|
2470
|
+
// Apply attention to values
|
|
2471
|
+
return attnWeights.matmul(value);
|
|
2472
|
+
}
|
|
2305
2473
|
// Utility to create a new tensor filled with a number
|
|
2306
2474
|
static full(shape, num, options = {}) {
|
|
2307
2475
|
if (shape.length === 0)
|
package/dist/nn.d.ts
CHANGED
|
@@ -1,10 +1,15 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { Callable, Tensor } from "./core";
|
|
2
2
|
import { dtype } from "./dtype";
|
|
3
3
|
/**
 * Fully connected layer. `forward` applies `input.linear(weight, bias)`.
 */
export declare class Linear {
    weight: Tensor;
    /** Optional bias; presumably absent when constructed with `bias = false` (confirm). */
    bias?: Tensor;
    constructor(inFeatures: number, outFeatures: number, bias?: boolean, device?: string, dtype?: dtype);
    /** Projects `input` with this layer's weight and optional bias. */
    forward(input: Tensor): Tensor;
}
|
|
9
|
+
/**
 * Container that applies its callables in order. `forward` delegates to
 * `Tensor.sequential`.
 */
export declare class Sequential {
    /** Ordered functions / modules applied by `forward`. */
    callables: Callable[];
    constructor(callables: Callable[]);
    /** Runs `input` through every callable in order and returns the last output. */
    forward(input: Tensor): Tensor;
}
|
|
9
14
|
export declare class RNNCell {
|
|
10
15
|
weightIH: Tensor;
|
|
@@ -12,7 +17,7 @@ export declare class RNNCell {
|
|
|
12
17
|
biasIH?: Tensor;
biasHH?: Tensor;
constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
/**
 * One recurrent step: `tanh(input·W_ih^T + b_ih + hidden·W_hh^T + b_hh)`.
 * Returns the new hidden state.
 */
forward(input: Tensor, hidden: Tensor): Tensor;
|
|
16
21
|
}
|
|
17
22
|
export declare class GRUCell {
|
|
18
23
|
weightIR: Tensor;
|
|
@@ -28,7 +33,7 @@ export declare class GRUCell {
|
|
|
28
33
|
biasHZ?: Tensor;
biasHN?: Tensor;
constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
/**
 * One GRU step using reset gate r, update gate z and candidate n.
 * Returns the new hidden state `(1 - z) * n + z * hidden`.
 */
forward(input: Tensor, hidden: Tensor): Tensor;
|
|
32
37
|
}
|
|
33
38
|
export declare class LSTMCell {
|
|
34
39
|
weightII: Tensor;
|
|
@@ -48,7 +53,7 @@ export declare class LSTMCell {
|
|
|
48
53
|
biasHG?: Tensor;
biasHO?: Tensor;
constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
/**
 * One LSTM step, given the previous hidden and cell states.
 * NOTE(review): the returned pair is presumably `[hidden, cell]`; confirm
 * against the implementation.
 */
forward(input: Tensor, hidden: Tensor, cell: Tensor): [Tensor, Tensor];
|
|
52
57
|
}
|
|
53
58
|
export declare class BatchNorm {
|
|
54
59
|
weight?: Tensor;
|
|
@@ -99,9 +104,8 @@ export declare class RMSNorm {
|
|
|
99
104
|
/**
 * Embedding layer constructed from `numEmbeddings` and `embeddingDim`.
 * NOTE(review): `weight` is presumably the [numEmbeddings, embeddingDim]
 * lookup table; confirm.
 */
export declare class Embedding {
    weight: Tensor;
    constructor(numEmbeddings: number, embeddingDim: number, device?: string, dtype?: dtype);
    /** NOTE(review): `input` is presumably a tensor of integer indices into `weight`; confirm. */
    forward(input: Tensor): Tensor;
}
|
|
104
|
-
export declare function scaledDotProductAttention(query: Tensor, key: Tensor, value: Tensor, attnMask?: Tensor, dropout?: number, isCausal?: boolean, scale?: number): Tensor;
|
|
105
109
|
export declare class MultiheadAttention {
|
|
106
110
|
qProjection: Linear;
|
|
107
111
|
kProjection: Linear;
|
|
@@ -119,6 +123,7 @@ export interface StateDict {
|
|
|
119
123
|
}
|
|
120
124
|
export declare const nn: {
|
|
121
125
|
Linear: typeof Linear;
|
|
126
|
+
Sequential: typeof Sequential;
|
|
122
127
|
RNNCell: typeof RNNCell;
|
|
123
128
|
GRUCell: typeof GRUCell;
|
|
124
129
|
LSTMCell: typeof LSTMCell;
|
|
@@ -128,7 +133,6 @@ export declare const nn: {
|
|
|
128
133
|
LayerNorm: typeof LayerNorm;
|
|
129
134
|
RMSNorm: typeof RMSNorm;
|
|
130
135
|
Embedding: typeof Embedding;
|
|
131
|
-
scaledDotProductAttention: typeof scaledDotProductAttention;
|
|
132
136
|
MultiheadAttention: typeof MultiheadAttention;
|
|
133
137
|
state: {
|
|
134
138
|
getParameters(model: any, visited?: WeakSet<object>): Tensor[];
|
package/dist/nn.js
CHANGED
|
@@ -1,15 +1,7 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Linear = void 0;
|
|
4
|
-
exports.scaledDotProductAttention = scaledDotProductAttention;
|
|
3
|
+
exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Sequential = exports.Linear = void 0;
|
|
5
4
|
const core_1 = require("./core");
|
|
6
|
-
function linearTransform(input, weight, bias) {
|
|
7
|
-
let output = input.matmul(weight.t());
|
|
8
|
-
if (bias) {
|
|
9
|
-
output = output.add(bias);
|
|
10
|
-
}
|
|
11
|
-
return output;
|
|
12
|
-
}
|
|
13
5
|
class Linear {
|
|
14
6
|
weight;
|
|
15
7
|
bias;
|
|
@@ -21,20 +13,22 @@ class Linear {
|
|
|
21
13
|
}
|
|
22
14
|
}
|
|
23
15
|
forward(input) {
|
|
24
|
-
input
|
|
25
|
-
return linearTransform(input, this.weight, this.bias);
|
|
16
|
+
return input.linear(this.weight, this.bias);
|
|
26
17
|
}
|
|
27
18
|
}
|
|
28
19
|
exports.Linear = Linear;
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
20
|
+
class Sequential {
|
|
21
|
+
callables;
|
|
22
|
+
constructor(callables) {
|
|
23
|
+
this.callables = callables;
|
|
33
24
|
}
|
|
34
|
-
|
|
35
|
-
|
|
25
|
+
forward(input) {
|
|
26
|
+
return input.sequential(this.callables);
|
|
36
27
|
}
|
|
37
|
-
|
|
28
|
+
}
|
|
29
|
+
exports.Sequential = Sequential;
|
|
30
|
+
function rnnTransform(input, hidden, inputWeight, hiddenWeight, inputBias, hiddenBias) {
|
|
31
|
+
return input.linear(inputWeight, inputBias).add(hidden.linear(hiddenWeight, hiddenBias));
|
|
38
32
|
}
|
|
39
33
|
class RNNCell {
|
|
40
34
|
weightIH;
|
|
@@ -51,8 +45,6 @@ class RNNCell {
|
|
|
51
45
|
}
|
|
52
46
|
}
|
|
53
47
|
forward(input, hidden) {
|
|
54
|
-
input = this.weightIH.handleOther(input);
|
|
55
|
-
hidden = this.weightHH.handleOther(hidden);
|
|
56
48
|
return rnnTransform(input, hidden, this.weightIH, this.weightHH, this.biasIH, this.biasHH).tanh();
|
|
57
49
|
}
|
|
58
50
|
}
|
|
@@ -88,11 +80,9 @@ class GRUCell {
|
|
|
88
80
|
}
|
|
89
81
|
}
|
|
90
82
|
forward(input, hidden) {
|
|
91
|
-
input = this.weightIN.handleOther(input);
|
|
92
|
-
hidden = this.weightHN.handleOther(hidden);
|
|
93
83
|
const r = rnnTransform(input, hidden, this.weightIR, this.weightHR, this.biasIR, this.biasHR).sigmoid();
|
|
94
84
|
const z = rnnTransform(input, hidden, this.weightIZ, this.weightHZ, this.biasIZ, this.biasHZ).sigmoid();
|
|
95
|
-
const n =
|
|
85
|
+
const n = input.linear(this.weightIN, this.biasIN).add(r.mul(hidden.linear(this.weightHN, this.biasHN))).tanh();
|
|
96
86
|
return (z.neg().add(1).mul(n).add(z.mul(hidden)));
|
|
97
87
|
}
|
|
98
88
|
}
|
|
@@ -136,9 +126,6 @@ class LSTMCell {
|
|
|
136
126
|
}
|
|
137
127
|
}
|
|
138
128
|
forward(input, hidden, cell) {
|
|
139
|
-
input = this.weightII.handleOther(input);
|
|
140
|
-
hidden = this.weightHI.handleOther(hidden);
|
|
141
|
-
cell = this.weightHI.handleOther(cell);
|
|
142
129
|
const i = rnnTransform(input, hidden, this.weightII, this.weightHI, this.biasII, this.biasHI).sigmoid();
|
|
143
130
|
const f = rnnTransform(input, hidden, this.weightIF, this.weightHF, this.biasIF, this.biasHF).sigmoid();
|
|
144
131
|
const g = rnnTransform(input, hidden, this.weightIG, this.weightHG, this.biasIG, this.biasHG).tanh();
|
|
@@ -240,34 +227,10 @@ class InstanceNorm {
|
|
|
240
227
|
}
|
|
241
228
|
}
|
|
242
229
|
forward(input) {
|
|
243
|
-
// Input should be at least 3D: [N, C, ...spatial dims]
|
|
244
|
-
if (input.shape.length < 3) {
|
|
245
|
-
throw new Error("InstanceNorm expects at least 3D input [N, C, ...spatial]");
|
|
246
|
-
}
|
|
247
230
|
if (input.shape[1] !== this.numFeatures) {
|
|
248
231
|
throw new Error(`Expected ${this.numFeatures} channels, got ${input.shape[1]}`);
|
|
249
232
|
}
|
|
250
|
-
|
|
251
|
-
const dims = [];
|
|
252
|
-
for (let i = 2; i < input.shape.length; i++) {
|
|
253
|
-
dims.push(i);
|
|
254
|
-
}
|
|
255
|
-
const mean = input.mean(dims, true);
|
|
256
|
-
const variance = input.sub(mean).pow(2).mean(dims, true);
|
|
257
|
-
let normalized = input.sub(mean).div(variance.add(this.eps).sqrt());
|
|
258
|
-
if (this.weight) {
|
|
259
|
-
// Reshape weight to [1, C, 1, 1, ...] for broadcasting
|
|
260
|
-
const weightShape = [1, this.numFeatures, ...Array(input.shape.length - 2).fill(1)];
|
|
261
|
-
const weightReshaped = this.weight.reshape(weightShape);
|
|
262
|
-
normalized = normalized.mul(weightReshaped);
|
|
263
|
-
}
|
|
264
|
-
if (this.bias) {
|
|
265
|
-
// Reshape bias to [1, C, 1, 1, ...] for broadcasting
|
|
266
|
-
const biasShape = [1, this.numFeatures, ...Array(input.shape.length - 2).fill(1)];
|
|
267
|
-
const biasReshaped = this.bias.reshape(biasShape);
|
|
268
|
-
normalized = normalized.add(biasReshaped);
|
|
269
|
-
}
|
|
270
|
-
return normalized;
|
|
233
|
+
return input.instanceNorm(this.weight, this.bias, this.eps);
|
|
271
234
|
}
|
|
272
235
|
}
|
|
273
236
|
exports.InstanceNorm = InstanceNorm;
|
|
@@ -290,43 +253,10 @@ class GroupNorm {
|
|
|
290
253
|
}
|
|
291
254
|
}
|
|
292
255
|
forward(input) {
|
|
293
|
-
// Input should be at least 3D: [N, C, ...spatial dims]
|
|
294
|
-
if (input.shape.length < 3) {
|
|
295
|
-
throw new Error("GroupNorm expects at least 3D input [N, C, ...spatial]");
|
|
296
|
-
}
|
|
297
256
|
if (input.shape[1] !== this.numChannels) {
|
|
298
257
|
throw new Error(`Expected ${this.numChannels} channels, got ${input.shape[1]}`);
|
|
299
258
|
}
|
|
300
|
-
|
|
301
|
-
const C = input.shape[1];
|
|
302
|
-
const spatialDims = input.shape.slice(2);
|
|
303
|
-
const channelsPerGroup = C / this.numGroups;
|
|
304
|
-
// Reshape: [N, C, ...spatial] -> [N, G, C//G, ...spatial]
|
|
305
|
-
const reshapedInput = input.reshape([N, this.numGroups, channelsPerGroup, ...spatialDims]);
|
|
306
|
-
// Normalize across (C//G, ...spatial) dimensions for each group
|
|
307
|
-
// That's dims [2, 3, 4, ...] in the reshaped tensor
|
|
308
|
-
const dims = [];
|
|
309
|
-
for (let i = 2; i < reshapedInput.shape.length; i++) {
|
|
310
|
-
dims.push(i);
|
|
311
|
-
}
|
|
312
|
-
const mean = reshapedInput.mean(dims, true);
|
|
313
|
-
const variance = reshapedInput.sub(mean).pow(2).mean(dims, true);
|
|
314
|
-
let normalized = reshapedInput.sub(mean).div(variance.add(this.eps).sqrt());
|
|
315
|
-
// Reshape back: [N, G, C//G, ...spatial] -> [N, C, ...spatial]
|
|
316
|
-
normalized = normalized.reshape(input.shape);
|
|
317
|
-
if (this.weight) {
|
|
318
|
-
// Reshape weight to [1, C, 1, 1, ...] for broadcasting
|
|
319
|
-
const weightShape = [1, this.numChannels, ...Array(spatialDims.length).fill(1)];
|
|
320
|
-
const weightReshaped = this.weight.reshape(weightShape);
|
|
321
|
-
normalized = normalized.mul(weightReshaped);
|
|
322
|
-
}
|
|
323
|
-
if (this.bias) {
|
|
324
|
-
// Reshape bias to [1, C, 1, 1, ...] for broadcasting
|
|
325
|
-
const biasShape = [1, this.numChannels, ...Array(spatialDims.length).fill(1)];
|
|
326
|
-
const biasReshaped = this.bias.reshape(biasShape);
|
|
327
|
-
normalized = normalized.add(biasReshaped);
|
|
328
|
-
}
|
|
329
|
-
return normalized;
|
|
259
|
+
return input.groupNorm(this.numGroups, this.weight, this.bias, this.eps);
|
|
330
260
|
}
|
|
331
261
|
}
|
|
332
262
|
exports.GroupNorm = GroupNorm;
|
|
@@ -349,29 +279,7 @@ class LayerNorm {
|
|
|
349
279
|
}
|
|
350
280
|
}
|
|
351
281
|
forward(input) {
|
|
352
|
-
|
|
353
|
-
const normalizedDims = this.normalizedShape.length;
|
|
354
|
-
const startDim = input.shape.length - normalizedDims;
|
|
355
|
-
if (startDim < 0) {
|
|
356
|
-
throw new Error("Input does not have enough dims to normalize");
|
|
357
|
-
}
|
|
358
|
-
const dims = [];
|
|
359
|
-
for (let i = 0; i < normalizedDims; i++) {
|
|
360
|
-
if (input.shape[startDim + i] !== this.normalizedShape[i]) {
|
|
361
|
-
throw new Error(`Shape mismatch at dim ${startDim + i}: expected ${this.normalizedShape[i]}, got ${input.shape[startDim + i]}`);
|
|
362
|
-
}
|
|
363
|
-
dims.push(startDim + i);
|
|
364
|
-
}
|
|
365
|
-
const mean = input.mean(dims, true);
|
|
366
|
-
const variance = input.sub(mean).pow(2).mean(dims, true);
|
|
367
|
-
let normalized = input.sub(mean).div(variance.add(this.eps).sqrt());
|
|
368
|
-
if (this.weight) {
|
|
369
|
-
normalized = normalized.mul(this.weight);
|
|
370
|
-
}
|
|
371
|
-
if (this.bias) {
|
|
372
|
-
normalized = normalized.add(this.bias);
|
|
373
|
-
}
|
|
374
|
-
return normalized;
|
|
282
|
+
return input.layerNorm(this.normalizedShape, this.weight, this.bias, this.eps);
|
|
375
283
|
}
|
|
376
284
|
}
|
|
377
285
|
exports.LayerNorm = LayerNorm;
|
|
@@ -390,25 +298,7 @@ class RMSNorm {
|
|
|
390
298
|
}
|
|
391
299
|
}
|
|
392
300
|
forward(input) {
|
|
393
|
-
|
|
394
|
-
const normalizedDims = this.normalizedShape.length;
|
|
395
|
-
const startDim = input.shape.length - normalizedDims;
|
|
396
|
-
if (startDim < 0) {
|
|
397
|
-
throw new Error("Input does not have enough dims to normalize");
|
|
398
|
-
}
|
|
399
|
-
const dims = [];
|
|
400
|
-
for (let i = 0; i < normalizedDims; i++) {
|
|
401
|
-
if (input.shape[startDim + i] !== this.normalizedShape[i]) {
|
|
402
|
-
throw new Error(`Shape mismatch at dim ${startDim + i}: expected ${this.normalizedShape[i]}, got ${input.shape[startDim + i]}`);
|
|
403
|
-
}
|
|
404
|
-
dims.push(startDim + i);
|
|
405
|
-
}
|
|
406
|
-
let rms = input.square().mean(dims, true).add(this.eps).sqrt();
|
|
407
|
-
let normalized = input.div(rms);
|
|
408
|
-
if (this.weight) {
|
|
409
|
-
normalized = normalized.mul(this.weight);
|
|
410
|
-
}
|
|
411
|
-
return normalized;
|
|
301
|
+
return input.rmsNorm(this.normalizedShape, this.weight, this.eps);
|
|
412
302
|
}
|
|
413
303
|
}
|
|
414
304
|
exports.RMSNorm = RMSNorm;
|
|
@@ -422,25 +312,6 @@ class Embedding {
|
|
|
422
312
|
}
|
|
423
313
|
}
|
|
424
314
|
exports.Embedding = Embedding;
|
|
425
|
-
function scaledDotProductAttention(query, key, value, attnMask, dropout = 0, isCausal = false, scale) {
|
|
426
|
-
const targetLen = query.shape[query.shape.length - 2];
|
|
427
|
-
const sourceLen = key.shape[key.shape.length - 2];
|
|
428
|
-
const dimSize = query.shape[query.shape.length - 1];
|
|
429
|
-
// Attention scores
|
|
430
|
-
let scores = query.matmul(key.transpose(-2, -1)).div(scale ?? Math.sqrt(dimSize));
|
|
431
|
-
// Set attention mask to causal mask if specified
|
|
432
|
-
if (isCausal) {
|
|
433
|
-
attnMask = core_1.Tensor.ones([targetLen, sourceLen], { device: query.device }).triu(1);
|
|
434
|
-
}
|
|
435
|
-
// Apply attention mask if specified
|
|
436
|
-
if (attnMask) {
|
|
437
|
-
scores = scores.maskedFill(attnMask, -Infinity);
|
|
438
|
-
}
|
|
439
|
-
// Calculate attention weights
|
|
440
|
-
let attnWeights = scores.softmax().dropout(dropout);
|
|
441
|
-
// Apply attention to values
|
|
442
|
-
return attnWeights.matmul(value);
|
|
443
|
-
}
|
|
444
315
|
class MultiheadAttention {
|
|
445
316
|
qProjection;
|
|
446
317
|
kProjection;
|
|
@@ -561,6 +432,7 @@ const state = {
|
|
|
561
432
|
};
|
|
562
433
|
exports.nn = {
|
|
563
434
|
Linear,
|
|
435
|
+
Sequential,
|
|
564
436
|
RNNCell,
|
|
565
437
|
GRUCell,
|
|
566
438
|
LSTMCell,
|
|
@@ -570,7 +442,6 @@ exports.nn = {
|
|
|
570
442
|
LayerNorm,
|
|
571
443
|
RMSNorm,
|
|
572
444
|
Embedding,
|
|
573
|
-
scaledDotProductAttention,
|
|
574
445
|
MultiheadAttention,
|
|
575
446
|
state
|
|
576
447
|
};
|