catniff 0.8.21 → 0.8.23

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.d.ts CHANGED
@@ -1,5 +1,8 @@
1
1
  import { Backend } from "./backend";
2
2
  import { dtype, MemoryBuffer } from "./dtype";
3
/**
 * Anything `Tensor.sequential` can apply to a tensor: either a plain
 * tensor-to-tensor function, or an object (e.g. a layer) exposing `forward`.
 */
export type Callable =
    | ((input: Tensor) => Tensor)
    | { forward: (input: Tensor) => Tensor };
3
6
  export type TensorValue = number | ArrayLike<TensorValue>;
4
7
  export interface TensorOptions {
5
8
  shape?: number[];
@@ -67,6 +70,7 @@ export declare class Tensor {
67
70
  chunk(chunks: number, dim?: number): Tensor[];
68
71
  expand(newShape: number[]): Tensor;
69
72
  unfold(dim: number, size: number, step: number): Tensor;
73
+ pad(pad: number[], mode?: string, value?: number): Tensor;
70
74
  cat(other: Tensor | TensorValue, dim?: number): Tensor;
71
75
  stack(others: (Tensor | TensorValue)[], dim?: number): Tensor;
72
76
  squeeze(dims?: number[] | number): Tensor;
@@ -220,6 +224,13 @@ export declare class Tensor {
220
224
  tril(diagonal?: number): Tensor;
221
225
  maskedFill(mask: Tensor | TensorValue, value: number): Tensor;
222
226
  multinomial(numSamples: number, replacement?: boolean): Tensor;
227
+ linear(weight: Tensor | TensorValue, bias?: Tensor | TensorValue): Tensor;
228
+ sequential(callables: Callable[]): Tensor;
229
+ layerNorm(normalizedShape: number[], weight?: Tensor | TensorValue, bias?: Tensor | TensorValue, eps?: number): Tensor;
230
+ rmsNorm(normalizedShape: number[], weight?: Tensor | TensorValue, eps?: number): Tensor;
231
+ instanceNorm(weight?: Tensor | TensorValue, bias?: Tensor | TensorValue, eps?: number): Tensor;
232
+ groupNorm(numGroups: number, weight?: Tensor | TensorValue, bias?: Tensor | TensorValue, eps?: number): Tensor;
233
+ scaledDotProductAttention(key: Tensor | TensorValue, value: Tensor | TensorValue, attnMask?: Tensor, dropout?: number, isCausal?: boolean, scale?: number): Tensor;
223
234
  static full(shape: number[], num: number, options?: TensorOptions): Tensor;
224
235
  static fullLike(tensor: Tensor, num: number, options?: TensorOptions): Tensor;
225
236
  static ones(shape?: number[], options?: TensorOptions): Tensor;
package/dist/core.js CHANGED
@@ -328,8 +328,14 @@ class Tensor {
328
328
  }
329
329
  const reducedGrad = accumGrad.sum(axesToReduce, true);
330
330
  const squeezedGrad = reducedGrad.squeeze(axesToSqueeze);
331
+ // Enforce 0-offset contiguous grads and correct dtype
331
332
  if (typeof tensor.grad === "undefined") {
332
- tensor.grad = squeezedGrad;
333
+ let grad = squeezedGrad;
334
+ // Handle potentially contiguous tensors with non zero offset
335
+ if (grad.offset !== 0) {
336
+ grad = grad.clone();
337
+ }
338
+ tensor.grad = grad.contiguous().cast(tensor.dtype);
333
339
  }
334
340
  else {
335
341
  tensor.grad = tensor.grad.add(squeezedGrad.cast(tensor.dtype));
@@ -808,6 +814,70 @@ class Tensor {
808
814
  }
809
815
  return out;
810
816
  }
817
+ // Tensor padding
818
+ pad(pad, mode = "constant", value = 0) {
819
+ const original = this.clone().contiguous(); // This is needed for index padding to work
820
+ const outputShape = [...original.shape];
821
+ const paddingPerDim = [];
822
+ for (let i = 0; i < original.shape.length; i++) {
823
+ const left = pad[(original.shape.length - 1 - i) * 2] || 0;
824
+ const right = pad[(original.shape.length - 1 - i) * 2 + 1] || 0;
825
+ paddingPerDim[i] = { left, right };
826
+ outputShape[i] += left + right;
827
+ }
828
+ const outputSize = Tensor.shapeToSize(outputShape);
829
+ if (mode === "constant") {
830
+ const outputValue = new dtype_1.TypedArray[original.dtype](outputSize).fill(value);
831
+ const outputStrides = Tensor.getStrides(outputShape);
832
+ for (let index = 0; index < original.numel; index++) {
833
+ const coords = Tensor.indexToCoords(index, original.strides);
834
+ let paddedIndex = 0;
835
+ // Pad each coord
836
+ for (let j = 0; j < original.shape.length; j++) {
837
+ const shiftedCoord = coords[j] + paddingPerDim[j].left;
838
+ paddedIndex += shiftedCoord * outputStrides[j];
839
+ }
840
+ outputValue[paddedIndex] = original.value[index];
841
+ }
842
+ const out = new Tensor(outputValue, {
843
+ shape: outputShape,
844
+ strides: outputStrides,
845
+ offset: 0,
846
+ dtype: original.dtype,
847
+ device: original.device
848
+ });
849
+ if (original.requiresGrad) {
850
+ out.requiresGrad = true;
851
+ out.children.push(original);
852
+ out.gradFn = () => {
853
+ const outGrad = out.grad;
854
+ const gradValue = new dtype_1.TypedArray[original.dtype](original.numel);
855
+ const gradStrides = Tensor.getStrides(original.shape);
856
+ for (let index = 0; index < gradValue.length; index++) {
857
+ const coords = Tensor.indexToCoords(index, gradStrides);
858
+ let paddedIndex = 0;
859
+ // Pad each coord
860
+ for (let j = 0; j < original.shape.length; j++) {
861
+ const shiftedCoord = coords[j] + paddingPerDim[j].left;
862
+ paddedIndex += shiftedCoord * outputStrides[j];
863
+ }
864
+ gradValue[index] = outGrad.value[paddedIndex];
865
+ }
866
+ Tensor.addGrad(original, new Tensor(gradValue, {
867
+ shape: original.shape,
868
+ strides: gradStrides,
869
+ offset: 0,
870
+ dtype: original.dtype,
871
+ device: original.device
872
+ }));
873
+ };
874
+ }
875
+ return out;
876
+ }
877
+ else {
878
+ throw new Error(`Padding mode not supported: "${mode}"`);
879
+ }
880
+ }
811
881
  // Tensor concatenation
812
882
  cat(other, dim = 0) {
813
883
  other = this.handleOther(other);
@@ -2302,6 +2372,174 @@ class Tensor {
2302
2372
  dtype: "int32"
2303
2373
  });
2304
2374
  }
2375
+ // Functional linear projection
2376
+ linear(weight, bias) {
2377
+ weight = this.handleOther(weight);
2378
+ let output = this.matmul(weight.transpose(-1, -2));
2379
+ if (bias) {
2380
+ bias = this.handleOther(bias);
2381
+ output = output.add(bias);
2382
+ }
2383
+ return output;
2384
+ }
2385
+ // Functional sequential chaining
2386
+ sequential(callables) {
2387
+ let res = this;
2388
+ for (let index = 0; index < callables.length; index++) {
2389
+ const callable = callables[index];
2390
+ if (typeof callable === "function") {
2391
+ res = callable(res);
2392
+ }
2393
+ else if (typeof callable === "object" && typeof callable.forward === "function") {
2394
+ res = callable.forward(res);
2395
+ }
2396
+ }
2397
+ return res;
2398
+ }
2399
+ // Functional layer norm
2400
+ layerNorm(normalizedShape, weight, bias, eps = 1e-05) {
2401
+ // Normalize over the specified dimensions
2402
+ const normalizedDims = normalizedShape.length;
2403
+ const startDim = this.shape.length - normalizedDims;
2404
+ if (startDim < 0) {
2405
+ throw new Error("Input does not have enough dims to normalize");
2406
+ }
2407
+ const dims = [];
2408
+ for (let i = 0; i < normalizedDims; i++) {
2409
+ if (this.shape[startDim + i] !== normalizedShape[i]) {
2410
+ throw new Error(`Shape mismatch at dim ${startDim + i}: expected ${normalizedShape[i]}, got ${this.shape[startDim + i]}`);
2411
+ }
2412
+ dims.push(startDim + i);
2413
+ }
2414
+ const mean = this.mean(dims, true);
2415
+ const centered = this.sub(mean);
2416
+ const variance = centered.pow(2).mean(dims, true);
2417
+ let normalized = centered.div(variance.add(eps).sqrt());
2418
+ if (weight) {
2419
+ normalized = normalized.mul(weight);
2420
+ }
2421
+ if (bias) {
2422
+ normalized = normalized.add(bias);
2423
+ }
2424
+ return normalized;
2425
+ }
2426
+ // Functional RMS norm
2427
+ rmsNorm(normalizedShape, weight, eps = 1e-5) {
2428
+ // Normalize over the specified dimensions
2429
+ const normalizedDims = normalizedShape.length;
2430
+ const startDim = this.shape.length - normalizedDims;
2431
+ if (startDim < 0) {
2432
+ throw new Error("Input does not have enough dims to normalize");
2433
+ }
2434
+ const dims = [];
2435
+ for (let i = 0; i < normalizedDims; i++) {
2436
+ if (this.shape[startDim + i] !== normalizedShape[i]) {
2437
+ throw new Error(`Shape mismatch at dim ${startDim + i}: expected ${normalizedShape[i]}, got ${this.shape[startDim + i]}`);
2438
+ }
2439
+ dims.push(startDim + i);
2440
+ }
2441
+ let rms = this.square().mean(dims, true).add(eps).sqrt();
2442
+ let normalized = this.div(rms);
2443
+ if (weight) {
2444
+ normalized = normalized.mul(weight);
2445
+ }
2446
+ return normalized;
2447
+ }
2448
+ // Functional instance norm
2449
+ instanceNorm(weight, bias, eps = 1e-5) {
2450
+ // Input should be at least 3D: [N, C, ...spatial dims]
2451
+ if (this.shape.length < 3) {
2452
+ throw new Error("InstanceNorm expects at least 3D input [N, C, ...spatial]");
2453
+ }
2454
+ // Normalize across spatial dimensions (all dims after channel dim)
2455
+ const dims = [];
2456
+ for (let i = 2; i < this.shape.length; i++) {
2457
+ dims.push(i);
2458
+ }
2459
+ const mean = this.mean(dims, true);
2460
+ const centered = this.sub(mean);
2461
+ const variance = centered.pow(2).mean(dims, true);
2462
+ let normalized = centered.div(variance.add(eps).sqrt());
2463
+ const numFeatures = this.shape[1];
2464
+ if (weight) {
2465
+ // Reshape weight to [1, C, 1, 1, ...] for broadcasting
2466
+ weight = this.handleOther(weight);
2467
+ const weightShape = [1, numFeatures, ...Array(this.shape.length - 2).fill(1)];
2468
+ const weightReshaped = weight.reshape(weightShape);
2469
+ normalized = normalized.mul(weightReshaped);
2470
+ }
2471
+ if (bias) {
2472
+ // Reshape bias to [1, C, 1, 1, ...] for broadcasting
2473
+ bias = this.handleOther(bias);
2474
+ const biasShape = [1, numFeatures, ...Array(this.shape.length - 2).fill(1)];
2475
+ const biasReshaped = bias.reshape(biasShape);
2476
+ normalized = normalized.add(biasReshaped);
2477
+ }
2478
+ return normalized;
2479
+ }
2480
+ // Functional group norm
2481
+ groupNorm(numGroups, weight, bias, eps = 1e-5) {
2482
+ // Input should be at least 3D: [N, C, ...spatial dims]
2483
+ if (this.shape.length < 3) {
2484
+ throw new Error("GroupNorm expects at least 3D input [N, C, ...spatial]");
2485
+ }
2486
+ const N = this.shape[0];
2487
+ const C = this.shape[1];
2488
+ const spatialDims = this.shape.slice(2);
2489
+ const channelsPerGroup = C / numGroups;
2490
+ // Reshape: [N, C, ...spatial] -> [N, G, C//G, ...spatial]
2491
+ const reshapedInput = this.reshape([N, numGroups, channelsPerGroup, ...spatialDims]);
2492
+ // Normalize across (C//G, ...spatial) dimensions for each group
2493
+ // That's dims [2, 3, 4, ...] in the reshaped tensor
2494
+ const dims = [];
2495
+ for (let i = 2; i < reshapedInput.shape.length; i++) {
2496
+ dims.push(i);
2497
+ }
2498
+ const mean = reshapedInput.mean(dims, true);
2499
+ const centered = reshapedInput.sub(mean);
2500
+ const variance = centered.pow(2).mean(dims, true);
2501
+ let normalized = centered.div(variance.add(eps).sqrt());
2502
+ // Reshape back: [N, G, C//G, ...spatial] -> [N, C, ...spatial]
2503
+ normalized = normalized.reshape(this.shape);
2504
+ const numChannels = this.shape[1];
2505
+ if (weight) {
2506
+ // Reshape weight to [1, C, 1, 1, ...] for broadcasting
2507
+ weight = this.handleOther(weight);
2508
+ const weightShape = [1, numChannels, ...Array(spatialDims.length).fill(1)];
2509
+ const weightReshaped = weight.reshape(weightShape);
2510
+ normalized = normalized.mul(weightReshaped);
2511
+ }
2512
+ if (bias) {
2513
+ // Reshape bias to [1, C, 1, 1, ...] for broadcasting
2514
+ bias = this.handleOther(bias);
2515
+ const biasShape = [1, numChannels, ...Array(spatialDims.length).fill(1)];
2516
+ const biasReshaped = bias.reshape(biasShape);
2517
+ normalized = normalized.add(biasReshaped);
2518
+ }
2519
+ return normalized;
2520
+ }
2521
+ // Functional scaled dot product attention
2522
+ scaledDotProductAttention(key, value, attnMask, dropout = 0, isCausal = false, scale) {
2523
+ key = this.handleOther(key);
2524
+ value = this.handleOther(value);
2525
+ const targetLen = this.shape[this.shape.length - 2];
2526
+ const sourceLen = key.shape[key.shape.length - 2];
2527
+ const dimSize = this.shape[this.shape.length - 1];
2528
+ // Attention scores
2529
+ let scores = this.matmul(key.transpose(-2, -1)).div(scale ?? Math.sqrt(dimSize));
2530
+ // Set attention mask to causal mask if specified
2531
+ if (isCausal) {
2532
+ attnMask = Tensor.ones([targetLen, sourceLen], { device: this.device }).triu(1);
2533
+ }
2534
+ // Apply attention mask if specified
2535
+ if (attnMask) {
2536
+ scores = scores.maskedFill(attnMask, -Infinity);
2537
+ }
2538
+ // Calculate attention weights
2539
+ let attnWeights = scores.softmax().dropout(dropout);
2540
+ // Apply attention to values
2541
+ return attnWeights.matmul(value);
2542
+ }
2305
2543
  // Utility to create a new tensor filled with a number
2306
2544
  static full(shape, num, options = {}) {
2307
2545
  if (shape.length === 0)
package/dist/nn.d.ts CHANGED
@@ -1,10 +1,15 @@
1
- import { Tensor, TensorValue } from "./core";
1
+ import { Callable, Tensor } from "./core";
2
2
  import { dtype } from "./dtype";
3
3
  export declare class Linear {
4
4
  weight: Tensor;
5
5
  bias?: Tensor;
6
6
  constructor(inFeatures: number, outFeatures: number, bias?: boolean, device?: string, dtype?: dtype);
7
- forward(input: Tensor | TensorValue): Tensor;
7
+ forward(input: Tensor): Tensor;
8
+ }
9
+ export declare class Sequential {
10
+ callables: Callable[];
11
+ constructor(callables: Callable[]);
12
+ forward(input: Tensor): Tensor;
8
13
  }
9
14
  export declare class RNNCell {
10
15
  weightIH: Tensor;
@@ -12,7 +17,7 @@ export declare class RNNCell {
12
17
  biasIH?: Tensor;
13
18
  biasHH?: Tensor;
14
19
  constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
15
- forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
20
+ forward(input: Tensor, hidden: Tensor): Tensor;
16
21
  }
17
22
  export declare class GRUCell {
18
23
  weightIR: Tensor;
@@ -28,7 +33,7 @@ export declare class GRUCell {
28
33
  biasHZ?: Tensor;
29
34
  biasHN?: Tensor;
30
35
  constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
31
- forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
36
+ forward(input: Tensor, hidden: Tensor): Tensor;
32
37
  }
33
38
  export declare class LSTMCell {
34
39
  weightII: Tensor;
@@ -48,7 +53,7 @@ export declare class LSTMCell {
48
53
  biasHG?: Tensor;
49
54
  biasHO?: Tensor;
50
55
  constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
51
- forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue, cell: Tensor | TensorValue): [Tensor, Tensor];
56
+ forward(input: Tensor, hidden: Tensor, cell: Tensor): [Tensor, Tensor];
52
57
  }
53
58
  export declare class BatchNorm {
54
59
  weight?: Tensor;
@@ -99,9 +104,8 @@ export declare class RMSNorm {
99
104
  export declare class Embedding {
100
105
  weight: Tensor;
101
106
  constructor(numEmbeddings: number, embeddingDim: number, device?: string, dtype?: dtype);
102
- forward(input: Tensor | TensorValue): Tensor;
107
+ forward(input: Tensor): Tensor;
103
108
  }
104
- export declare function scaledDotProductAttention(query: Tensor, key: Tensor, value: Tensor, attnMask?: Tensor, dropout?: number, isCausal?: boolean, scale?: number): Tensor;
105
109
  export declare class MultiheadAttention {
106
110
  qProjection: Linear;
107
111
  kProjection: Linear;
@@ -119,6 +123,7 @@ export interface StateDict {
119
123
  }
120
124
  export declare const nn: {
121
125
  Linear: typeof Linear;
126
+ Sequential: typeof Sequential;
122
127
  RNNCell: typeof RNNCell;
123
128
  GRUCell: typeof GRUCell;
124
129
  LSTMCell: typeof LSTMCell;
@@ -128,7 +133,6 @@ export declare const nn: {
128
133
  LayerNorm: typeof LayerNorm;
129
134
  RMSNorm: typeof RMSNorm;
130
135
  Embedding: typeof Embedding;
131
- scaledDotProductAttention: typeof scaledDotProductAttention;
132
136
  MultiheadAttention: typeof MultiheadAttention;
133
137
  state: {
134
138
  getParameters(model: any, visited?: WeakSet<object>): Tensor[];
package/dist/nn.js CHANGED
@@ -1,15 +1,7 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Linear = void 0;
4
- exports.scaledDotProductAttention = scaledDotProductAttention;
3
+ exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Sequential = exports.Linear = void 0;
5
4
  const core_1 = require("./core");
6
- function linearTransform(input, weight, bias) {
7
- let output = input.matmul(weight.t());
8
- if (bias) {
9
- output = output.add(bias);
10
- }
11
- return output;
12
- }
13
5
  class Linear {
14
6
  weight;
15
7
  bias;
@@ -21,20 +13,22 @@ class Linear {
21
13
  }
22
14
  }
23
15
  forward(input) {
24
- input = this.weight.handleOther(input);
25
- return linearTransform(input, this.weight, this.bias);
16
+ return input.linear(this.weight, this.bias);
26
17
  }
27
18
  }
28
19
  exports.Linear = Linear;
29
- function rnnTransform(input, hidden, inputWeight, hiddenWeight, inputBias, hiddenBias) {
30
- let output = input.matmul(inputWeight.t()).add(hidden.matmul(hiddenWeight.t()));
31
- if (inputBias) {
32
- output = output.add(inputBias);
20
+ class Sequential {
21
+ callables;
22
+ constructor(callables) {
23
+ this.callables = callables;
33
24
  }
34
- if (hiddenBias) {
35
- output = output.add(hiddenBias);
25
+ forward(input) {
26
+ return input.sequential(this.callables);
36
27
  }
37
- return output;
28
+ }
29
+ exports.Sequential = Sequential;
30
+ function rnnTransform(input, hidden, inputWeight, hiddenWeight, inputBias, hiddenBias) {
31
+ return input.linear(inputWeight, inputBias).add(hidden.linear(hiddenWeight, hiddenBias));
38
32
  }
39
33
  class RNNCell {
40
34
  weightIH;
@@ -51,8 +45,6 @@ class RNNCell {
51
45
  }
52
46
  }
53
47
  forward(input, hidden) {
54
- input = this.weightIH.handleOther(input);
55
- hidden = this.weightHH.handleOther(hidden);
56
48
  return rnnTransform(input, hidden, this.weightIH, this.weightHH, this.biasIH, this.biasHH).tanh();
57
49
  }
58
50
  }
@@ -88,11 +80,9 @@ class GRUCell {
88
80
  }
89
81
  }
90
82
  forward(input, hidden) {
91
- input = this.weightIN.handleOther(input);
92
- hidden = this.weightHN.handleOther(hidden);
93
83
  const r = rnnTransform(input, hidden, this.weightIR, this.weightHR, this.biasIR, this.biasHR).sigmoid();
94
84
  const z = rnnTransform(input, hidden, this.weightIZ, this.weightHZ, this.biasIZ, this.biasHZ).sigmoid();
95
- const n = linearTransform(input, this.weightIN, this.biasIN).add(r.mul(linearTransform(hidden, this.weightHN, this.biasHN))).tanh();
85
+ const n = input.linear(this.weightIN, this.biasIN).add(r.mul(hidden.linear(this.weightHN, this.biasHN))).tanh();
96
86
  return (z.neg().add(1).mul(n).add(z.mul(hidden)));
97
87
  }
98
88
  }
@@ -136,9 +126,6 @@ class LSTMCell {
136
126
  }
137
127
  }
138
128
  forward(input, hidden, cell) {
139
- input = this.weightII.handleOther(input);
140
- hidden = this.weightHI.handleOther(hidden);
141
- cell = this.weightHI.handleOther(cell);
142
129
  const i = rnnTransform(input, hidden, this.weightII, this.weightHI, this.biasII, this.biasHI).sigmoid();
143
130
  const f = rnnTransform(input, hidden, this.weightIF, this.weightHF, this.biasIF, this.biasHF).sigmoid();
144
131
  const g = rnnTransform(input, hidden, this.weightIG, this.weightHG, this.biasIG, this.biasHG).tanh();
@@ -240,34 +227,10 @@ class InstanceNorm {
240
227
  }
241
228
  }
242
229
  forward(input) {
243
- // Input should be at least 3D: [N, C, ...spatial dims]
244
- if (input.shape.length < 3) {
245
- throw new Error("InstanceNorm expects at least 3D input [N, C, ...spatial]");
246
- }
247
230
  if (input.shape[1] !== this.numFeatures) {
248
231
  throw new Error(`Expected ${this.numFeatures} channels, got ${input.shape[1]}`);
249
232
  }
250
- // Normalize across spatial dimensions (all dims after channel dim)
251
- const dims = [];
252
- for (let i = 2; i < input.shape.length; i++) {
253
- dims.push(i);
254
- }
255
- const mean = input.mean(dims, true);
256
- const variance = input.sub(mean).pow(2).mean(dims, true);
257
- let normalized = input.sub(mean).div(variance.add(this.eps).sqrt());
258
- if (this.weight) {
259
- // Reshape weight to [1, C, 1, 1, ...] for broadcasting
260
- const weightShape = [1, this.numFeatures, ...Array(input.shape.length - 2).fill(1)];
261
- const weightReshaped = this.weight.reshape(weightShape);
262
- normalized = normalized.mul(weightReshaped);
263
- }
264
- if (this.bias) {
265
- // Reshape bias to [1, C, 1, 1, ...] for broadcasting
266
- const biasShape = [1, this.numFeatures, ...Array(input.shape.length - 2).fill(1)];
267
- const biasReshaped = this.bias.reshape(biasShape);
268
- normalized = normalized.add(biasReshaped);
269
- }
270
- return normalized;
233
+ return input.instanceNorm(this.weight, this.bias, this.eps);
271
234
  }
272
235
  }
273
236
  exports.InstanceNorm = InstanceNorm;
@@ -290,43 +253,10 @@ class GroupNorm {
290
253
  }
291
254
  }
292
255
  forward(input) {
293
- // Input should be at least 3D: [N, C, ...spatial dims]
294
- if (input.shape.length < 3) {
295
- throw new Error("GroupNorm expects at least 3D input [N, C, ...spatial]");
296
- }
297
256
  if (input.shape[1] !== this.numChannels) {
298
257
  throw new Error(`Expected ${this.numChannels} channels, got ${input.shape[1]}`);
299
258
  }
300
- const N = input.shape[0];
301
- const C = input.shape[1];
302
- const spatialDims = input.shape.slice(2);
303
- const channelsPerGroup = C / this.numGroups;
304
- // Reshape: [N, C, ...spatial] -> [N, G, C//G, ...spatial]
305
- const reshapedInput = input.reshape([N, this.numGroups, channelsPerGroup, ...spatialDims]);
306
- // Normalize across (C//G, ...spatial) dimensions for each group
307
- // That's dims [2, 3, 4, ...] in the reshaped tensor
308
- const dims = [];
309
- for (let i = 2; i < reshapedInput.shape.length; i++) {
310
- dims.push(i);
311
- }
312
- const mean = reshapedInput.mean(dims, true);
313
- const variance = reshapedInput.sub(mean).pow(2).mean(dims, true);
314
- let normalized = reshapedInput.sub(mean).div(variance.add(this.eps).sqrt());
315
- // Reshape back: [N, G, C//G, ...spatial] -> [N, C, ...spatial]
316
- normalized = normalized.reshape(input.shape);
317
- if (this.weight) {
318
- // Reshape weight to [1, C, 1, 1, ...] for broadcasting
319
- const weightShape = [1, this.numChannels, ...Array(spatialDims.length).fill(1)];
320
- const weightReshaped = this.weight.reshape(weightShape);
321
- normalized = normalized.mul(weightReshaped);
322
- }
323
- if (this.bias) {
324
- // Reshape bias to [1, C, 1, 1, ...] for broadcasting
325
- const biasShape = [1, this.numChannels, ...Array(spatialDims.length).fill(1)];
326
- const biasReshaped = this.bias.reshape(biasShape);
327
- normalized = normalized.add(biasReshaped);
328
- }
329
- return normalized;
259
+ return input.groupNorm(this.numGroups, this.weight, this.bias, this.eps);
330
260
  }
331
261
  }
332
262
  exports.GroupNorm = GroupNorm;
@@ -349,29 +279,7 @@ class LayerNorm {
349
279
  }
350
280
  }
351
281
  forward(input) {
352
- // Normalize over the specified dimensions
353
- const normalizedDims = this.normalizedShape.length;
354
- const startDim = input.shape.length - normalizedDims;
355
- if (startDim < 0) {
356
- throw new Error("Input does not have enough dims to normalize");
357
- }
358
- const dims = [];
359
- for (let i = 0; i < normalizedDims; i++) {
360
- if (input.shape[startDim + i] !== this.normalizedShape[i]) {
361
- throw new Error(`Shape mismatch at dim ${startDim + i}: expected ${this.normalizedShape[i]}, got ${input.shape[startDim + i]}`);
362
- }
363
- dims.push(startDim + i);
364
- }
365
- const mean = input.mean(dims, true);
366
- const variance = input.sub(mean).pow(2).mean(dims, true);
367
- let normalized = input.sub(mean).div(variance.add(this.eps).sqrt());
368
- if (this.weight) {
369
- normalized = normalized.mul(this.weight);
370
- }
371
- if (this.bias) {
372
- normalized = normalized.add(this.bias);
373
- }
374
- return normalized;
282
+ return input.layerNorm(this.normalizedShape, this.weight, this.bias, this.eps);
375
283
  }
376
284
  }
377
285
  exports.LayerNorm = LayerNorm;
@@ -390,25 +298,7 @@ class RMSNorm {
390
298
  }
391
299
  }
392
300
  forward(input) {
393
- // Normalize over the specified dimensions
394
- const normalizedDims = this.normalizedShape.length;
395
- const startDim = input.shape.length - normalizedDims;
396
- if (startDim < 0) {
397
- throw new Error("Input does not have enough dims to normalize");
398
- }
399
- const dims = [];
400
- for (let i = 0; i < normalizedDims; i++) {
401
- if (input.shape[startDim + i] !== this.normalizedShape[i]) {
402
- throw new Error(`Shape mismatch at dim ${startDim + i}: expected ${this.normalizedShape[i]}, got ${input.shape[startDim + i]}`);
403
- }
404
- dims.push(startDim + i);
405
- }
406
- let rms = input.square().mean(dims, true).add(this.eps).sqrt();
407
- let normalized = input.div(rms);
408
- if (this.weight) {
409
- normalized = normalized.mul(this.weight);
410
- }
411
- return normalized;
301
+ return input.rmsNorm(this.normalizedShape, this.weight, this.eps);
412
302
  }
413
303
  }
414
304
  exports.RMSNorm = RMSNorm;
@@ -422,25 +312,6 @@ class Embedding {
422
312
  }
423
313
  }
424
314
  exports.Embedding = Embedding;
425
- function scaledDotProductAttention(query, key, value, attnMask, dropout = 0, isCausal = false, scale) {
426
- const targetLen = query.shape[query.shape.length - 2];
427
- const sourceLen = key.shape[key.shape.length - 2];
428
- const dimSize = query.shape[query.shape.length - 1];
429
- // Attention scores
430
- let scores = query.matmul(key.transpose(-2, -1)).div(scale ?? Math.sqrt(dimSize));
431
- // Set attention mask to causal mask if specified
432
- if (isCausal) {
433
- attnMask = core_1.Tensor.ones([targetLen, sourceLen], { device: query.device }).triu(1);
434
- }
435
- // Apply attention mask if specified
436
- if (attnMask) {
437
- scores = scores.maskedFill(attnMask, -Infinity);
438
- }
439
- // Calculate attention weights
440
- let attnWeights = scores.softmax().dropout(dropout);
441
- // Apply attention to values
442
- return attnWeights.matmul(value);
443
- }
444
315
  class MultiheadAttention {
445
316
  qProjection;
446
317
  kProjection;
@@ -561,6 +432,7 @@ const state = {
561
432
  };
562
433
  exports.nn = {
563
434
  Linear,
435
+ Sequential,
564
436
  RNNCell,
565
437
  GRUCell,
566
438
  LSTMCell,
@@ -570,7 +442,6 @@ exports.nn = {
570
442
  LayerNorm,
571
443
  RMSNorm,
572
444
  Embedding,
573
- scaledDotProductAttention,
574
445
  MultiheadAttention,
575
446
  state
576
447
  };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.8.21",
3
+ "version": "0.8.23",
4
4
  "description": "Torch-like deep learning framework for Javascript",
5
5
  "main": "./dist/index.js",
6
6
  "scripts": {