catniff 0.8.2 → 0.8.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.js +1 -1
- package/dist/nn.d.ts +3 -1
- package/dist/nn.js +26 -1
- package/package.json +1 -1
package/dist/core.js CHANGED
package/dist/nn.d.ts CHANGED
@@ -70,6 +70,7 @@ export declare class Embedding {
     constructor(numEmbeddings: number, embeddingDim: number, device?: string, dtype?: dtype);
     forward(input: Tensor | TensorValue): Tensor;
 }
+export declare function scaledDotProductAttention(query: Tensor, key: Tensor, value: Tensor, attnMask?: Tensor, dropout?: number, isCausal?: boolean, scale?: number): Tensor;
 export declare class MultiheadAttention {
     qProjection: Linear;
     kProjection: Linear;
@@ -80,7 +81,7 @@ export declare class MultiheadAttention {
     headDim: number;
     dropout: number;
     constructor(embedDim: number, numHeads: number, dropout?: number, bias?: boolean, device?: string, dtype?: dtype);
-    forward(query: Tensor, key: Tensor, value: Tensor, needWeights?: boolean, attnMask?: Tensor, averageAttnWeights?: boolean): [Tensor, Tensor | undefined];
+    forward(query: Tensor, key: Tensor, value: Tensor, needWeights?: boolean, attnMask?: Tensor, averageAttnWeights?: boolean, isCausal?: boolean): [Tensor, Tensor | undefined];
 }
 export interface StateDict {
     [key: string]: any;
@@ -93,6 +94,7 @@ export declare const nn: {
     LayerNorm: typeof LayerNorm;
     RMSNorm: typeof RMSNorm;
     Embedding: typeof Embedding;
+    scaledDotProductAttention: typeof scaledDotProductAttention;
     MultiheadAttention: typeof MultiheadAttention;
     state: {
         getParameters(model: any, visited?: WeakSet<object>): Tensor[];
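A minimal usage sketch of the new 0.8.3 surface (not from the package docs): the call signatures come from the declarations above, but the root import path and Tensor.ones defaulting device/dtype when called with a shape alone are assumptions.

import { Tensor, nn } from "catniff"; // assumed root re-export of dist/core and dist/nn

// Batch-first shapes, matching MultiheadAttention.forward in dist/nn.js:
// query/key/value are [batchSize, seqLen, embedDim].
const q = Tensor.ones([2, 4, 16]); // assumes the options argument is optional
const k = Tensor.ones([2, 4, 16]);
const v = Tensor.ones([2, 4, 16]);

// New functional API: scaled dot-product attention with a causal mask
// (attnMask omitted, dropout 0, isCausal true, scale left at its default).
const out = nn.scaledDotProductAttention(q, k, v, undefined, 0, true);

// New trailing isCausal flag on MultiheadAttention.forward.
const mha = new nn.MultiheadAttention(16, 4);
const [attnOut, attnWeights] = mha.forward(q, q, q, true, undefined, true, true);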
package/dist/nn.js CHANGED
@@ -1,6 +1,7 @@
 "use strict";
 Object.defineProperty(exports, "__esModule", { value: true });
 exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Linear = void 0;
+exports.scaledDotProductAttention = scaledDotProductAttention;
 const core_1 = require("./core");
 function linearTransform(input, weight, bias) {
     let output = input.matmul(weight.t());
@@ -240,6 +241,25 @@ class Embedding {
     }
 }
 exports.Embedding = Embedding;
+function scaledDotProductAttention(query, key, value, attnMask, dropout = 0, isCausal = false, scale) {
+    const targetLen = query.shape[query.shape.length - 2];
+    const sourceLen = key.shape[key.shape.length - 2];
+    const dimSize = query.shape[query.shape.length - 1];
+    // Attention scores
+    let scores = query.matmul(key.transpose(-2, -1)).div(scale ?? Math.sqrt(dimSize));
+    // Set attention mask to causal mask if specified
+    if (isCausal) {
+        attnMask = core_1.Tensor.ones([targetLen, sourceLen], { device: query.device }).triu(1);
+    }
+    // Apply attention mask if specified
+    if (attnMask) {
+        scores = scores.maskedFill(attnMask, -Infinity);
+    }
+    // Calculate attention weights
+    let attnWeights = scores.softmax().dropout(dropout);
+    // Apply attention to values
+    return attnWeights.matmul(value);
+}
 class MultiheadAttention {
     qProjection;
     kProjection;
@@ -259,7 +279,7 @@ class MultiheadAttention {
         this.headDim = Math.floor(embedDim / numHeads);
         this.dropout = dropout;
     }
-    forward(query, key, value, needWeights = true, attnMask, averageAttnWeights = true) {
+    forward(query, key, value, needWeights = true, attnMask, averageAttnWeights = true, isCausal = false) {
         // Batch-first
         const [batchSize, targetLen, embedDim] = query.shape;
         const sourceLen = key.shape[1];
@@ -272,6 +292,10 @@ class MultiheadAttention {
         V = V.reshape([batchSize, sourceLen, this.numHeads, this.headDim]).transpose(1, 2);
         // Attention scores
         let scores = Q.matmul(K.transpose(-2, -1)).div(Math.sqrt(this.headDim));
+        // Set attention mask to causal mask if specified
+        if (isCausal) {
+            attnMask = core_1.Tensor.ones([targetLen, sourceLen], { device: this.qProjection.weight.device }).triu(1);
+        }
         // Apply attention mask if specified
         if (attnMask) {
             scores = scores.maskedFill(attnMask, -Infinity);
@@ -362,6 +386,7 @@ exports.nn = {
     LayerNorm,
     RMSNorm,
     Embedding,
+    scaledDotProductAttention,
     MultiheadAttention,
     state
 };
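For reference, the masking convention these hunks implement (read off the code above, not from separate documentation): truthy attnMask entries are filled with -Infinity before the softmax, so the mask marks positions to exclude, and isCausal simply builds that mask internally as the strict upper triangle via Tensor.ones([targetLen, sourceLen]).triu(1). A sketch of the equivalence, under the same import-path and Tensor.ones assumptions as the earlier example:

import { Tensor, nn } from "catniff";

const q = Tensor.ones([1, 3, 8]);
const k = Tensor.ones([1, 3, 8]);
const v = Tensor.ones([1, 3, 8]);

// Explicit mask: ones(...).triu(1) keeps the strict upper triangle,
//   [[0, 1, 1],
//    [0, 0, 1],
//    [0, 0, 0]]
// i.e. the "future" key positions that maskedFill will set to -Infinity.
const causalMask = Tensor.ones([3, 3]).triu(1);
const a = nn.scaledDotProductAttention(q, k, v, causalMask);

// Should be equivalent per the implementation above: the flag builds the same mask internally.
const b = nn.scaledDotProductAttention(q, k, v, undefined, 0, true);

Note that whenever isCausal is true, both scaledDotProductAttention and MultiheadAttention.forward overwrite any attnMask that was passed in, so the flag and an explicit mask should not be combined.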