catniff 0.8.20 → 0.8.22

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.d.ts CHANGED
@@ -1,5 +1,8 @@
1
1
  import { Backend } from "./backend";
2
2
  import { dtype, MemoryBuffer } from "./dtype";
3
+ export type Callable = ((input: Tensor) => Tensor) | {
4
+ forward: (input: Tensor) => Tensor;
5
+ };
3
6
  export type TensorValue = number | ArrayLike<TensorValue>;
4
7
  export interface TensorOptions {
5
8
  shape?: number[];
@@ -220,6 +223,13 @@ export declare class Tensor {
220
223
  tril(diagonal?: number): Tensor;
221
224
  maskedFill(mask: Tensor | TensorValue, value: number): Tensor;
222
225
  multinomial(numSamples: number, replacement?: boolean): Tensor;
226
+ linear(weight: Tensor | TensorValue, bias?: Tensor | TensorValue): Tensor;
227
+ sequential(callables: Callable[]): Tensor;
228
+ layerNorm(normalizedShape: number[], weight?: Tensor | TensorValue, bias?: Tensor | TensorValue, eps?: number): Tensor;
229
+ rmsNorm(normalizedShape: number[], weight?: Tensor | TensorValue, eps?: number): Tensor;
230
+ instanceNorm(weight?: Tensor | TensorValue, bias?: Tensor | TensorValue, eps?: number): Tensor;
231
+ groupNorm(numGroups: number, weight?: Tensor | TensorValue, bias?: Tensor | TensorValue, eps?: number): Tensor;
232
+ scaledDotProductAttention(key: Tensor | TensorValue, value: Tensor | TensorValue, attnMask?: Tensor, dropout?: number, isCausal?: boolean, scale?: number): Tensor;
223
233
  static full(shape: number[], num: number, options?: TensorOptions): Tensor;
224
234
  static fullLike(tensor: Tensor, num: number, options?: TensorOptions): Tensor;
225
235
  static ones(shape?: number[], options?: TensorOptions): Tensor;
package/dist/core.js CHANGED
@@ -767,7 +767,7 @@ class Tensor {
767
767
  }
768
768
  // If dimension out of bound, throw error
769
769
  if (dim >= this.shape.length || dim < 0) {
770
- throw new Error("Dimension does not exist to apply softmax");
770
+ throw new Error("Dimension does not exist to apply unfold");
771
771
  }
772
772
  // Verify size and step
773
773
  if (size <= 0 || step <= 0)
@@ -781,13 +781,32 @@ class Tensor {
781
781
  const newStrides = [...this.strides, this.strides[dim]];
782
782
  newShape[dim] = outSize;
783
783
  newStrides[dim] = this.strides[dim] * step;
784
- return new Tensor(this.value, {
784
+ const out = new Tensor(this.value, {
785
785
  shape: newShape,
786
786
  strides: newStrides,
787
787
  offset: this.offset,
788
788
  dtype: this.dtype,
789
789
  device: this.device
790
790
  });
791
+ if (this.requiresGrad) {
792
+ out.requiresGrad = true;
793
+ out.children.push(this);
794
+ out.gradFn = () => {
795
+ const outGrad = out.grad;
796
+ const grad = Tensor.zerosLike(this);
797
+ for (let i = 0; i < out.numel; i++) {
798
+ const coords = Tensor.indexToCoords(i, newStrides);
799
+ const windowIdx = coords[dim];
800
+ const withinWindow = coords[coords.length - 1];
801
+ coords[dim] = windowIdx * step + withinWindow;
802
+ coords.pop();
803
+ const sourceIdx = Tensor.coordsToIndex(coords, this.strides);
804
+ grad.value[sourceIdx] += outGrad.value[i];
805
+ }
806
+ Tensor.addGrad(this, grad);
807
+ };
808
+ }
809
+ return out;
791
810
  }
792
811
  // Tensor concatentation
793
812
  cat(other, dim = 0) {
@@ -2283,6 +2302,174 @@ class Tensor {
2283
2302
  dtype: "int32"
2284
2303
  });
2285
2304
  }
2305
+ // Functional linear projection
2306
+ linear(weight, bias) {
2307
+ weight = this.handleOther(weight);
2308
+ let output = this.matmul(weight.transpose(-1, -2));
2309
+ if (bias) {
2310
+ bias = this.handleOther(bias);
2311
+ output = output.add(bias);
2312
+ }
2313
+ return output;
2314
+ }
2315
+ // Functional sequential chaining
2316
+ sequential(callables) {
2317
+ let res = this;
2318
+ for (let index = 0; index < callables.length; index++) {
2319
+ const callable = callables[index];
2320
+ if (typeof callable === "function") {
2321
+ res = callable(res);
2322
+ }
2323
+ else if (typeof callable === "object" && typeof callable.forward === "function") {
2324
+ res = callable.forward(res);
2325
+ }
2326
+ }
2327
+ return res;
2328
+ }
2329
+ // Functional layer norm
2330
+ layerNorm(normalizedShape, weight, bias, eps = 1e-05) {
2331
+ // Normalize over the specified dimensions
2332
+ const normalizedDims = normalizedShape.length;
2333
+ const startDim = this.shape.length - normalizedDims;
2334
+ if (startDim < 0) {
2335
+ throw new Error("Input does not have enough dims to normalize");
2336
+ }
2337
+ const dims = [];
2338
+ for (let i = 0; i < normalizedDims; i++) {
2339
+ if (this.shape[startDim + i] !== normalizedShape[i]) {
2340
+ throw new Error(`Shape mismatch at dim ${startDim + i}: expected ${normalizedShape[i]}, got ${this.shape[startDim + i]}`);
2341
+ }
2342
+ dims.push(startDim + i);
2343
+ }
2344
+ const mean = this.mean(dims, true);
2345
+ const centered = this.sub(mean);
2346
+ const variance = centered.pow(2).mean(dims, true);
2347
+ let normalized = centered.div(variance.add(eps).sqrt());
2348
+ if (weight) {
2349
+ normalized = normalized.mul(weight);
2350
+ }
2351
+ if (bias) {
2352
+ normalized = normalized.add(bias);
2353
+ }
2354
+ return normalized;
2355
+ }
2356
+ // Functional RMS norm
2357
+ rmsNorm(normalizedShape, weight, eps = 1e-5) {
2358
+ // Normalize over the specified dimensions
2359
+ const normalizedDims = normalizedShape.length;
2360
+ const startDim = this.shape.length - normalizedDims;
2361
+ if (startDim < 0) {
2362
+ throw new Error("Input does not have enough dims to normalize");
2363
+ }
2364
+ const dims = [];
2365
+ for (let i = 0; i < normalizedDims; i++) {
2366
+ if (this.shape[startDim + i] !== normalizedShape[i]) {
2367
+ throw new Error(`Shape mismatch at dim ${startDim + i}: expected ${normalizedShape[i]}, got ${this.shape[startDim + i]}`);
2368
+ }
2369
+ dims.push(startDim + i);
2370
+ }
2371
+ let rms = this.square().mean(dims, true).add(eps).sqrt();
2372
+ let normalized = this.div(rms);
2373
+ if (weight) {
2374
+ normalized = normalized.mul(weight);
2375
+ }
2376
+ return normalized;
2377
+ }
2378
+ // Functional instance norm
2379
+ instanceNorm(weight, bias, eps = 1e-5) {
2380
+ // Input should be at least 3D: [N, C, ...spatial dims]
2381
+ if (this.shape.length < 3) {
2382
+ throw new Error("InstanceNorm expects at least 3D input [N, C, ...spatial]");
2383
+ }
2384
+ // Normalize across spatial dimensions (all dims after channel dim)
2385
+ const dims = [];
2386
+ for (let i = 2; i < this.shape.length; i++) {
2387
+ dims.push(i);
2388
+ }
2389
+ const mean = this.mean(dims, true);
2390
+ const centered = this.sub(mean);
2391
+ const variance = centered.pow(2).mean(dims, true);
2392
+ let normalized = centered.div(variance.add(eps).sqrt());
2393
+ const numFeatures = this.shape[1];
2394
+ if (weight) {
2395
+ // Reshape weight to [1, C, 1, 1, ...] for broadcasting
2396
+ weight = this.handleOther(weight);
2397
+ const weightShape = [1, numFeatures, ...Array(this.shape.length - 2).fill(1)];
2398
+ const weightReshaped = weight.reshape(weightShape);
2399
+ normalized = normalized.mul(weightReshaped);
2400
+ }
2401
+ if (bias) {
2402
+ // Reshape bias to [1, C, 1, 1, ...] for broadcasting
2403
+ bias = this.handleOther(bias);
2404
+ const biasShape = [1, numFeatures, ...Array(this.shape.length - 2).fill(1)];
2405
+ const biasReshaped = bias.reshape(biasShape);
2406
+ normalized = normalized.add(biasReshaped);
2407
+ }
2408
+ return normalized;
2409
+ }
2410
+ // Functional group norm
2411
+ groupNorm(numGroups, weight, bias, eps = 1e-5) {
2412
+ // Input should be at least 3D: [N, C, ...spatial dims]
2413
+ if (this.shape.length < 3) {
2414
+ throw new Error("GroupNorm expects at least 3D input [N, C, ...spatial]");
2415
+ }
2416
+ const N = this.shape[0];
2417
+ const C = this.shape[1];
2418
+ const spatialDims = this.shape.slice(2);
2419
+ const channelsPerGroup = C / numGroups;
2420
+ // Reshape: [N, C, ...spatial] -> [N, G, C//G, ...spatial]
2421
+ const reshapedInput = this.reshape([N, numGroups, channelsPerGroup, ...spatialDims]);
2422
+ // Normalize across (C//G, ...spatial) dimensions for each group
2423
+ // That's dims [2, 3, 4, ...] in the reshaped tensor
2424
+ const dims = [];
2425
+ for (let i = 2; i < reshapedInput.shape.length; i++) {
2426
+ dims.push(i);
2427
+ }
2428
+ const mean = reshapedInput.mean(dims, true);
2429
+ const centered = reshapedInput.sub(mean);
2430
+ const variance = centered.pow(2).mean(dims, true);
2431
+ let normalized = centered.div(variance.add(eps).sqrt());
2432
+ // Reshape back: [N, G, C//G, ...spatial] -> [N, C, ...spatial]
2433
+ normalized = normalized.reshape(this.shape);
2434
+ const numChannels = this.shape[1];
2435
+ if (weight) {
2436
+ // Reshape weight to [1, C, 1, 1, ...] for broadcasting
2437
+ weight = this.handleOther(weight);
2438
+ const weightShape = [1, numChannels, ...Array(spatialDims.length).fill(1)];
2439
+ const weightReshaped = weight.reshape(weightShape);
2440
+ normalized = normalized.mul(weightReshaped);
2441
+ }
2442
+ if (bias) {
2443
+ // Reshape bias to [1, C, 1, 1, ...] for broadcasting
2444
+ bias = this.handleOther(bias);
2445
+ const biasShape = [1, numChannels, ...Array(spatialDims.length).fill(1)];
2446
+ const biasReshaped = bias.reshape(biasShape);
2447
+ normalized = normalized.add(biasReshaped);
2448
+ }
2449
+ return normalized;
2450
+ }
2451
+ // Functional scaled dot product attention
2452
+ scaledDotProductAttention(key, value, attnMask, dropout = 0, isCausal = false, scale) {
2453
+ key = this.handleOther(key);
2454
+ value = this.handleOther(value);
2455
+ const targetLen = this.shape[this.shape.length - 2];
2456
+ const sourceLen = key.shape[key.shape.length - 2];
2457
+ const dimSize = this.shape[this.shape.length - 1];
2458
+ // Attention scores
2459
+ let scores = this.matmul(key.transpose(-2, -1)).div(scale ?? Math.sqrt(dimSize));
2460
+ // Set attention mask to causal mask if specified
2461
+ if (isCausal) {
2462
+ attnMask = Tensor.ones([targetLen, sourceLen], { device: this.device }).triu(1);
2463
+ }
2464
+ // Apply attention mask if specified
2465
+ if (attnMask) {
2466
+ scores = scores.maskedFill(attnMask, -Infinity);
2467
+ }
2468
+ // Calculate attention weights
2469
+ let attnWeights = scores.softmax().dropout(dropout);
2470
+ // Apply attention to values
2471
+ return attnWeights.matmul(value);
2472
+ }
2286
2473
  // Utility to create a new tensor filled with a number
2287
2474
  static full(shape, num, options = {}) {
2288
2475
  if (shape.length === 0)
package/dist/nn.d.ts CHANGED
@@ -1,10 +1,15 @@
1
- import { Tensor, TensorValue } from "./core";
1
+ import { Callable, Tensor } from "./core";
2
2
  import { dtype } from "./dtype";
3
3
  export declare class Linear {
4
4
  weight: Tensor;
5
5
  bias?: Tensor;
6
6
  constructor(inFeatures: number, outFeatures: number, bias?: boolean, device?: string, dtype?: dtype);
7
- forward(input: Tensor | TensorValue): Tensor;
7
+ forward(input: Tensor): Tensor;
8
+ }
9
+ export declare class Sequential {
10
+ callables: Callable[];
11
+ constructor(callables: Callable[]);
12
+ forward(input: Tensor): Tensor;
8
13
  }
9
14
  export declare class RNNCell {
10
15
  weightIH: Tensor;
@@ -12,7 +17,7 @@ export declare class RNNCell {
12
17
  biasIH?: Tensor;
13
18
  biasHH?: Tensor;
14
19
  constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
15
- forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
20
+ forward(input: Tensor, hidden: Tensor): Tensor;
16
21
  }
17
22
  export declare class GRUCell {
18
23
  weightIR: Tensor;
@@ -28,7 +33,7 @@ export declare class GRUCell {
28
33
  biasHZ?: Tensor;
29
34
  biasHN?: Tensor;
30
35
  constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
31
- forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
36
+ forward(input: Tensor, hidden: Tensor): Tensor;
32
37
  }
33
38
  export declare class LSTMCell {
34
39
  weightII: Tensor;
@@ -48,7 +53,7 @@ export declare class LSTMCell {
48
53
  biasHG?: Tensor;
49
54
  biasHO?: Tensor;
50
55
  constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
51
- forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue, cell: Tensor | TensorValue): [Tensor, Tensor];
56
+ forward(input: Tensor, hidden: Tensor, cell: Tensor): [Tensor, Tensor];
52
57
  }
53
58
  export declare class BatchNorm {
54
59
  weight?: Tensor;
@@ -99,9 +104,8 @@ export declare class RMSNorm {
99
104
  export declare class Embedding {
100
105
  weight: Tensor;
101
106
  constructor(numEmbeddings: number, embeddingDim: number, device?: string, dtype?: dtype);
102
- forward(input: Tensor | TensorValue): Tensor;
107
+ forward(input: Tensor): Tensor;
103
108
  }
104
- export declare function scaledDotProductAttention(query: Tensor, key: Tensor, value: Tensor, attnMask?: Tensor, dropout?: number, isCausal?: boolean, scale?: number): Tensor;
105
109
  export declare class MultiheadAttention {
106
110
  qProjection: Linear;
107
111
  kProjection: Linear;
@@ -119,6 +123,7 @@ export interface StateDict {
119
123
  }
120
124
  export declare const nn: {
121
125
  Linear: typeof Linear;
126
+ Sequential: typeof Sequential;
122
127
  RNNCell: typeof RNNCell;
123
128
  GRUCell: typeof GRUCell;
124
129
  LSTMCell: typeof LSTMCell;
@@ -128,7 +133,6 @@ export declare const nn: {
128
133
  LayerNorm: typeof LayerNorm;
129
134
  RMSNorm: typeof RMSNorm;
130
135
  Embedding: typeof Embedding;
131
- scaledDotProductAttention: typeof scaledDotProductAttention;
132
136
  MultiheadAttention: typeof MultiheadAttention;
133
137
  state: {
134
138
  getParameters(model: any, visited?: WeakSet<object>): Tensor[];
package/dist/nn.js CHANGED
@@ -1,15 +1,7 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Linear = void 0;
4
- exports.scaledDotProductAttention = scaledDotProductAttention;
3
+ exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Sequential = exports.Linear = void 0;
5
4
  const core_1 = require("./core");
6
- function linearTransform(input, weight, bias) {
7
- let output = input.matmul(weight.t());
8
- if (bias) {
9
- output = output.add(bias);
10
- }
11
- return output;
12
- }
13
5
  class Linear {
14
6
  weight;
15
7
  bias;
@@ -21,20 +13,22 @@ class Linear {
21
13
  }
22
14
  }
23
15
  forward(input) {
24
- input = this.weight.handleOther(input);
25
- return linearTransform(input, this.weight, this.bias);
16
+ return input.linear(this.weight, this.bias);
26
17
  }
27
18
  }
28
19
  exports.Linear = Linear;
29
- function rnnTransform(input, hidden, inputWeight, hiddenWeight, inputBias, hiddenBias) {
30
- let output = input.matmul(inputWeight.t()).add(hidden.matmul(hiddenWeight.t()));
31
- if (inputBias) {
32
- output = output.add(inputBias);
20
+ class Sequential {
21
+ callables;
22
+ constructor(callables) {
23
+ this.callables = callables;
33
24
  }
34
- if (hiddenBias) {
35
- output = output.add(hiddenBias);
25
+ forward(input) {
26
+ return input.sequential(this.callables);
36
27
  }
37
- return output;
28
+ }
29
+ exports.Sequential = Sequential;
30
+ function rnnTransform(input, hidden, inputWeight, hiddenWeight, inputBias, hiddenBias) {
31
+ return input.linear(inputWeight, inputBias).add(hidden.linear(hiddenWeight, hiddenBias));
38
32
  }
39
33
  class RNNCell {
40
34
  weightIH;
@@ -51,8 +45,6 @@ class RNNCell {
51
45
  }
52
46
  }
53
47
  forward(input, hidden) {
54
- input = this.weightIH.handleOther(input);
55
- hidden = this.weightHH.handleOther(hidden);
56
48
  return rnnTransform(input, hidden, this.weightIH, this.weightHH, this.biasIH, this.biasHH).tanh();
57
49
  }
58
50
  }
@@ -88,11 +80,9 @@ class GRUCell {
88
80
  }
89
81
  }
90
82
  forward(input, hidden) {
91
- input = this.weightIN.handleOther(input);
92
- hidden = this.weightHN.handleOther(hidden);
93
83
  const r = rnnTransform(input, hidden, this.weightIR, this.weightHR, this.biasIR, this.biasHR).sigmoid();
94
84
  const z = rnnTransform(input, hidden, this.weightIZ, this.weightHZ, this.biasIZ, this.biasHZ).sigmoid();
95
- const n = linearTransform(input, this.weightIN, this.biasIN).add(r.mul(linearTransform(hidden, this.weightHN, this.biasHN))).tanh();
85
+ const n = input.linear(this.weightIN, this.biasIN).add(r.mul(hidden.linear(this.weightHN, this.biasHN))).tanh();
96
86
  return (z.neg().add(1).mul(n).add(z.mul(hidden)));
97
87
  }
98
88
  }
@@ -136,9 +126,6 @@ class LSTMCell {
136
126
  }
137
127
  }
138
128
  forward(input, hidden, cell) {
139
- input = this.weightII.handleOther(input);
140
- hidden = this.weightHI.handleOther(hidden);
141
- cell = this.weightHI.handleOther(cell);
142
129
  const i = rnnTransform(input, hidden, this.weightII, this.weightHI, this.biasII, this.biasHI).sigmoid();
143
130
  const f = rnnTransform(input, hidden, this.weightIF, this.weightHF, this.biasIF, this.biasHF).sigmoid();
144
131
  const g = rnnTransform(input, hidden, this.weightIG, this.weightHG, this.biasIG, this.biasHG).tanh();
@@ -240,34 +227,10 @@ class InstanceNorm {
240
227
  }
241
228
  }
242
229
  forward(input) {
243
- // Input should be at least 3D: [N, C, ...spatial dims]
244
- if (input.shape.length < 3) {
245
- throw new Error("InstanceNorm expects at least 3D input [N, C, ...spatial]");
246
- }
247
230
  if (input.shape[1] !== this.numFeatures) {
248
231
  throw new Error(`Expected ${this.numFeatures} channels, got ${input.shape[1]}`);
249
232
  }
250
- // Normalize across spatial dimensions (all dims after channel dim)
251
- const dims = [];
252
- for (let i = 2; i < input.shape.length; i++) {
253
- dims.push(i);
254
- }
255
- const mean = input.mean(dims, true);
256
- const variance = input.sub(mean).pow(2).mean(dims, true);
257
- let normalized = input.sub(mean).div(variance.add(this.eps).sqrt());
258
- if (this.weight) {
259
- // Reshape weight to [1, C, 1, 1, ...] for broadcasting
260
- const weightShape = [1, this.numFeatures, ...Array(input.shape.length - 2).fill(1)];
261
- const weightReshaped = this.weight.reshape(weightShape);
262
- normalized = normalized.mul(weightReshaped);
263
- }
264
- if (this.bias) {
265
- // Reshape bias to [1, C, 1, 1, ...] for broadcasting
266
- const biasShape = [1, this.numFeatures, ...Array(input.shape.length - 2).fill(1)];
267
- const biasReshaped = this.bias.reshape(biasShape);
268
- normalized = normalized.add(biasReshaped);
269
- }
270
- return normalized;
233
+ return input.instanceNorm(this.weight, this.bias, this.eps);
271
234
  }
272
235
  }
273
236
  exports.InstanceNorm = InstanceNorm;
@@ -290,43 +253,10 @@ class GroupNorm {
290
253
  }
291
254
  }
292
255
  forward(input) {
293
- // Input should be at least 3D: [N, C, ...spatial dims]
294
- if (input.shape.length < 3) {
295
- throw new Error("GroupNorm expects at least 3D input [N, C, ...spatial]");
296
- }
297
256
  if (input.shape[1] !== this.numChannels) {
298
257
  throw new Error(`Expected ${this.numChannels} channels, got ${input.shape[1]}`);
299
258
  }
300
- const N = input.shape[0];
301
- const C = input.shape[1];
302
- const spatialDims = input.shape.slice(2);
303
- const channelsPerGroup = C / this.numGroups;
304
- // Reshape: [N, C, ...spatial] -> [N, G, C//G, ...spatial]
305
- const reshapedInput = input.reshape([N, this.numGroups, channelsPerGroup, ...spatialDims]);
306
- // Normalize across (C//G, ...spatial) dimensions for each group
307
- // That's dims [2, 3, 4, ...] in the reshaped tensor
308
- const dims = [];
309
- for (let i = 2; i < reshapedInput.shape.length; i++) {
310
- dims.push(i);
311
- }
312
- const mean = reshapedInput.mean(dims, true);
313
- const variance = reshapedInput.sub(mean).pow(2).mean(dims, true);
314
- let normalized = reshapedInput.sub(mean).div(variance.add(this.eps).sqrt());
315
- // Reshape back: [N, G, C//G, ...spatial] -> [N, C, ...spatial]
316
- normalized = normalized.reshape(input.shape);
317
- if (this.weight) {
318
- // Reshape weight to [1, C, 1, 1, ...] for broadcasting
319
- const weightShape = [1, this.numChannels, ...Array(spatialDims.length).fill(1)];
320
- const weightReshaped = this.weight.reshape(weightShape);
321
- normalized = normalized.mul(weightReshaped);
322
- }
323
- if (this.bias) {
324
- // Reshape bias to [1, C, 1, 1, ...] for broadcasting
325
- const biasShape = [1, this.numChannels, ...Array(spatialDims.length).fill(1)];
326
- const biasReshaped = this.bias.reshape(biasShape);
327
- normalized = normalized.add(biasReshaped);
328
- }
329
- return normalized;
259
+ return input.groupNorm(this.numGroups, this.weight, this.bias, this.eps);
330
260
  }
331
261
  }
332
262
  exports.GroupNorm = GroupNorm;
@@ -349,29 +279,7 @@ class LayerNorm {
349
279
  }
350
280
  }
351
281
  forward(input) {
352
- // Normalize over the specified dimensions
353
- const normalizedDims = this.normalizedShape.length;
354
- const startDim = input.shape.length - normalizedDims;
355
- if (startDim < 0) {
356
- throw new Error("Input does not have enough dims to normalize");
357
- }
358
- const dims = [];
359
- for (let i = 0; i < normalizedDims; i++) {
360
- if (input.shape[startDim + i] !== this.normalizedShape[i]) {
361
- throw new Error(`Shape mismatch at dim ${startDim + i}: expected ${this.normalizedShape[i]}, got ${input.shape[startDim + i]}`);
362
- }
363
- dims.push(startDim + i);
364
- }
365
- const mean = input.mean(dims, true);
366
- const variance = input.sub(mean).pow(2).mean(dims, true);
367
- let normalized = input.sub(mean).div(variance.add(this.eps).sqrt());
368
- if (this.weight) {
369
- normalized = normalized.mul(this.weight);
370
- }
371
- if (this.bias) {
372
- normalized = normalized.add(this.bias);
373
- }
374
- return normalized;
282
+ return input.layerNorm(this.normalizedShape, this.weight, this.bias, this.eps);
375
283
  }
376
284
  }
377
285
  exports.LayerNorm = LayerNorm;
@@ -390,25 +298,7 @@ class RMSNorm {
390
298
  }
391
299
  }
392
300
  forward(input) {
393
- // Normalize over the specified dimensions
394
- const normalizedDims = this.normalizedShape.length;
395
- const startDim = input.shape.length - normalizedDims;
396
- if (startDim < 0) {
397
- throw new Error("Input does not have enough dims to normalize");
398
- }
399
- const dims = [];
400
- for (let i = 0; i < normalizedDims; i++) {
401
- if (input.shape[startDim + i] !== this.normalizedShape[i]) {
402
- throw new Error(`Shape mismatch at dim ${startDim + i}: expected ${this.normalizedShape[i]}, got ${input.shape[startDim + i]}`);
403
- }
404
- dims.push(startDim + i);
405
- }
406
- let rms = input.square().mean(dims, true).add(this.eps).sqrt();
407
- let normalized = input.div(rms);
408
- if (this.weight) {
409
- normalized = normalized.mul(this.weight);
410
- }
411
- return normalized;
301
+ return input.rmsNorm(this.normalizedShape, this.weight, this.eps);
412
302
  }
413
303
  }
414
304
  exports.RMSNorm = RMSNorm;
@@ -422,25 +312,6 @@ class Embedding {
422
312
  }
423
313
  }
424
314
  exports.Embedding = Embedding;
425
- function scaledDotProductAttention(query, key, value, attnMask, dropout = 0, isCausal = false, scale) {
426
- const targetLen = query.shape[query.shape.length - 2];
427
- const sourceLen = key.shape[key.shape.length - 2];
428
- const dimSize = query.shape[query.shape.length - 1];
429
- // Attention scores
430
- let scores = query.matmul(key.transpose(-2, -1)).div(scale ?? Math.sqrt(dimSize));
431
- // Set attention mask to causal mask if specified
432
- if (isCausal) {
433
- attnMask = core_1.Tensor.ones([targetLen, sourceLen], { device: query.device }).triu(1);
434
- }
435
- // Apply attention mask if specified
436
- if (attnMask) {
437
- scores = scores.maskedFill(attnMask, -Infinity);
438
- }
439
- // Calculate attention weights
440
- let attnWeights = scores.softmax().dropout(dropout);
441
- // Apply attention to values
442
- return attnWeights.matmul(value);
443
- }
444
315
  class MultiheadAttention {
445
316
  qProjection;
446
317
  kProjection;
@@ -561,6 +432,7 @@ const state = {
561
432
  };
562
433
  exports.nn = {
563
434
  Linear,
435
+ Sequential,
564
436
  RNNCell,
565
437
  GRUCell,
566
438
  LSTMCell,
@@ -570,7 +442,6 @@ exports.nn = {
570
442
  LayerNorm,
571
443
  RMSNorm,
572
444
  Embedding,
573
- scaledDotProductAttention,
574
445
  MultiheadAttention,
575
446
  state
576
447
  };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.8.20",
3
+ "version": "0.8.22",
4
4
  "description": "Torch-like deep learning framework for Javascript",
5
5
  "main": "./dist/index.js",
6
6
  "scripts": {