catniff 0.8.2 → 0.8.3

package/dist/core.js CHANGED
@@ -636,7 +636,7 @@ class Tensor {
  let start = range[0] ?? 0;
  let end = range[1] ?? dimSize;
  let step = range[2] ?? 1;
- // Handle negative indicesoutGrad
+ // Handle negative indices
  if (start < 0)
  start += dimSize;
  if (end < 0)
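For context, the fixed comment sits in Tensor's range-slicing logic, which normalizes Python-style negative indices against the dimension size before slicing. A minimal standalone sketch of that normalization (hypothetical helper, not part of catniff's API):

// Normalize a [start, end, step] range against a dimension of size dimSize,
// mirroring the logic in the hunk above. Hypothetical helper for illustration.
function normalizeRange(range: number[], dimSize: number): [number, number, number] {
    let start = range[0] ?? 0;
    let end = range[1] ?? dimSize;
    const step = range[2] ?? 1;
    if (start < 0) start += dimSize; // e.g. start = -2 on a size-5 dim becomes 3
    if (end < 0) end += dimSize;
    return [start, end, step];
}

// normalizeRange([-2], 5) -> [3, 5, 1]: the last two elements of a size-5 dim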
package/dist/nn.d.ts CHANGED
@@ -70,6 +70,7 @@ export declare class Embedding {
  constructor(numEmbeddings: number, embeddingDim: number, device?: string, dtype?: dtype);
  forward(input: Tensor | TensorValue): Tensor;
  }
+ export declare function scaledDotProductAttention(query: Tensor, key: Tensor, value: Tensor, attnMask?: Tensor, dropout?: number, isCausal?: boolean, scale?: number): Tensor;
  export declare class MultiheadAttention {
  qProjection: Linear;
  kProjection: Linear;
@@ -80,7 +81,7 @@ export declare class MultiheadAttention {
  headDim: number;
  dropout: number;
  constructor(embedDim: number, numHeads: number, dropout?: number, bias?: boolean, device?: string, dtype?: dtype);
- forward(query: Tensor, key: Tensor, value: Tensor, needWeights?: boolean, attnMask?: Tensor, averageAttnWeights?: boolean): [Tensor, Tensor | undefined];
+ forward(query: Tensor, key: Tensor, value: Tensor, needWeights?: boolean, attnMask?: Tensor, averageAttnWeights?: boolean, isCausal?: boolean): [Tensor, Tensor | undefined];
  }
  export interface StateDict {
  [key: string]: any;
@@ -93,6 +94,7 @@ export declare const nn: {
  LayerNorm: typeof LayerNorm;
  RMSNorm: typeof RMSNorm;
  Embedding: typeof Embedding;
+ scaledDotProductAttention: typeof scaledDotProductAttention;
  MultiheadAttention: typeof MultiheadAttention;
  state: {
  getParameters(model: any, visited?: WeakSet<object>): Tensor[];
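In short, the type-level change is a new functional attention entry point plus a trailing isCausal flag on MultiheadAttention.forward. A hedged usage sketch against these declarations (assuming the package root re-exports Tensor and nn, and that Tensor.ones accepts a bare shape; neither is shown in this diff):

import { Tensor, nn } from "catniff"; // assumed root re-exports

// Attention is computed over the last two axes, so 2-D inputs suffice here.
const q = Tensor.ones([4, 8]); // [targetLen, headDim]
const k = Tensor.ones([4, 8]); // [sourceLen, headDim]
const v = Tensor.ones([4, 8]);

// attnMask defaults to undefined and dropout to 0; isCausal = true builds the
// causal mask internally. When scale is omitted, the implementation below
// falls back to sqrt(headDim) as the divisor for the scores.
const out = nn.scaledDotProductAttention(q, k, v, undefined, 0, true);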
package/dist/nn.js CHANGED
@@ -1,6 +1,7 @@
  "use strict";
  Object.defineProperty(exports, "__esModule", { value: true });
  exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Linear = void 0;
+ exports.scaledDotProductAttention = scaledDotProductAttention;
  const core_1 = require("./core");
  function linearTransform(input, weight, bias) {
  let output = input.matmul(weight.t());
@@ -240,6 +241,25 @@ class Embedding {
  }
  }
  exports.Embedding = Embedding;
+ function scaledDotProductAttention(query, key, value, attnMask, dropout = 0, isCausal = false, scale) {
+ const targetLen = query.shape[query.shape.length - 2];
+ const sourceLen = key.shape[key.shape.length - 2];
+ const dimSize = query.shape[query.shape.length - 1];
+ // Attention scores
+ let scores = query.matmul(key.transpose(-2, -1)).div(scale ?? Math.sqrt(dimSize));
+ // Set attention mask to causal mask if specified
+ if (isCausal) {
+ attnMask = core_1.Tensor.ones([targetLen, sourceLen], { device: query.device }).triu(1);
+ }
+ // Apply attention mask if specified
+ if (attnMask) {
+ scores = scores.maskedFill(attnMask, -Infinity);
+ }
+ // Calculate attention weights
+ let attnWeights = scores.softmax().dropout(dropout);
+ // Apply attention to values
+ return attnWeights.matmul(value);
+ }
  class MultiheadAttention {
  qProjection;
  kProjection;
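Note the masking convention the new function uses: nonzero mask entries mark positions to block, and maskedFill sends them to -Infinity before softmax. The causal path builds Tensor.ones([targetLen, sourceLen]).triu(1), i.e. the strict upper triangle. A small sketch of what that mask looks like (plain arrays, for illustration only):

// Equivalent of Tensor.ones([t, s]).triu(1) as nested arrays: entry [i][j]
// is 1 exactly when j > i, i.e. key position j lies in query position i's future.
function causalMask(t: number, s: number): number[][] {
    return Array.from({ length: t }, (_, i) =>
        Array.from({ length: s }, (_, j) => (j > i ? 1 : 0))
    );
}

// causalMask(4, 4):
// [[0, 1, 1, 1],
//  [0, 0, 1, 1],
//  [0, 0, 0, 1],
//  [0, 0, 0, 0]]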
@@ -259,7 +279,7 @@ class MultiheadAttention {
  this.headDim = Math.floor(embedDim / numHeads);
  this.dropout = dropout;
  }
- forward(query, key, value, needWeights = true, attnMask, averageAttnWeights = true) {
+ forward(query, key, value, needWeights = true, attnMask, averageAttnWeights = true, isCausal = false) {
  // Batch-first
  const [batchSize, targetLen, embedDim] = query.shape;
  const sourceLen = key.shape[1];
@@ -272,6 +292,10 @@ class MultiheadAttention {
  V = V.reshape([batchSize, sourceLen, this.numHeads, this.headDim]).transpose(1, 2);
  // Attention scores
  let scores = Q.matmul(K.transpose(-2, -1)).div(Math.sqrt(this.headDim));
+ // Set attention mask to causal mask if specified
+ if (isCausal) {
+ attnMask = core_1.Tensor.ones([targetLen, sourceLen], { device: this.qProjection.weight.device }).triu(1);
+ }
  // Apply attention mask if specified
  if (attnMask) {
  scores = scores.maskedFill(attnMask, -Infinity);
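The same causal-mask construction is threaded through the class API via the new trailing parameter. A hedged sketch of calling it (constructor and forward signatures come from the .d.ts hunks above; the root import is assumed):

import { Tensor, nn } from "catniff"; // assumed root re-exports

const mha = new nn.MultiheadAttention(64, 8);  // embedDim = 64, numHeads = 8
const x = Tensor.ones([2, 10, 64]);            // batch-first: [batch, seq, embedDim]

// Self-attention with causal masking; the earlier defaults are spelled out so
// the new trailing isCausal flag can be reached.
const [output, weights] = mha.forward(x, x, x, true, undefined, true, true);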
@@ -362,6 +386,7 @@ exports.nn = {
  LayerNorm,
  RMSNorm,
  Embedding,
+ scaledDotProductAttention,
  MultiheadAttention,
  state
  };
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "catniff",
- "version": "0.8.2",
+ "version": "0.8.3",
  "description": "Torch-like deep learning framework for Javascript",
  "main": "index.js",
  "scripts": {