catniff 0.6.4 → 0.6.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -76,9 +76,9 @@ optim.step();
76
76
  console.log("Updated weight:", w.data); // Should move toward 3.0
77
77
  ```
78
78
 
79
- ## Neural networks
79
+ ## Neural networks & Deep learning
80
80
 
81
- There are built-in neural network constructs in Catniff as well:
81
+ Catniff also has built-in neural network constructs, ranging from simple prebuilt nn layers:
82
82
  ```js
83
83
  const { Tensor, nn } = require("catniff");
84
84
 
@@ -102,6 +102,24 @@ gruCell.forward(b, c);
102
102
  lstmCell.forward(b, c, c);
103
103
  ```
104
104
 
105
+ to more advanced constructs such as normalization, embedding, and attention:
106
+ ```js
107
+ // 1. Embedding: tokens -> vectors
108
+ const embedding = new nn.Embedding(100, 64);
109
+ const tokens = new Tensor([[1, 5, 23], [8, 2, 15]]);
110
+ const embedded = embedding.forward(tokens);
111
+
112
+ // 2. Self-Attention
113
+ const attention = new nn.MultiheadAttention(64, 8, 0.1);
114
+ const [output, weights] = attention.forward(embedded, embedded, embedded);
115
+
116
+ // 3. Layer Normalization
117
+ const layerNorm = new nn.LayerNorm(64);
118
+ const normalized = layerNorm.forward(output);
119
+
120
+ console.log(normalized.val());
121
+ ```
122
+
105
123
  And it can still do much more — check out the docs and examples below for more information.
106
124
 
107
125
  ## Documentation
package/dist/core.d.ts CHANGED
@@ -184,6 +184,7 @@ export declare class Tensor {
184
184
  dropout(rate: number): Tensor;
185
185
  triu(diagonal?: number): Tensor;
186
186
  tril(diagonal?: number): Tensor;
187
/**
 * Produces a tensor equal to this one, except at positions selected by `mask`,
 * which are set to `value`.
 *
 * @param mask - Selection mask; may be a Tensor or a raw TensorValue (the
 *   implementation converts it with `handleOther`). Expected to hold 0/1
 *   entries, where 1 means "fill this position".
 * @param value - Scalar written into every selected position.
 * @returns The filled tensor.
 */
+ maskedFill(mask: Tensor | TensorValue, value: number): Tensor;
187
188
  static full(shape: readonly number[], num: number, options?: TensorOptions): Tensor;
188
189
  static fullLike(tensor: Tensor, num: number, options?: TensorOptions): Tensor;
189
190
  static ones(shape?: readonly number[], options?: TensorOptions): Tensor;
package/dist/core.js CHANGED
@@ -1465,6 +1465,11 @@ class Tensor {
1465
1465
  });
1466
1466
  return this.mul(mask);
1467
1467
  }
1468
+ // Fill specific positions of this tensor with a value through a mask
1469
+ maskedFill(mask, value) {
1470
+ mask = this.handleOther(mask);
1471
+ return this.mul(mask.logicalNot()).add(mask.mul(value));
1472
+ }
1468
1473
  // Utility to create a new tensor filled with a number
1469
1474
  static full(shape, num, options = {}) {
1470
1475
  if (shape.length === 0)
package/dist/nn.d.ts CHANGED
@@ -62,6 +62,18 @@ declare class Embedding {
62
62
  constructor(numEmbeddings: number, embeddingDim: number, device: string);
63
63
  forward(input: Tensor | TensorValue): Tensor;
64
64
  }
65
/**
 * Batch-first multi-head attention layer.
 *
 * query, key and value are projected by separate Linear layers, split into
 * `numHeads` heads of `headDim` features, combined with scaled dot-product
 * attention per head, then merged back and passed through `oProjection`.
 */
declare class MultiheadAttention {
    /** Input and output feature size of the layer. */
    embedDim: number;
    /** Number of attention heads. */
    numHeads: number;
    /** Features per head (embedDim divided by numHeads). */
    headDim: number;
    /** Dropout rate applied to the attention weights. */
    dropout: number;
    /** Projection applied to the query input. */
    qProjection: Linear;
    /** Projection applied to the key input. */
    kProjection: Linear;
    /** Projection applied to the value input. */
    vProjection: Linear;
    /** Projection applied to the merged heads to produce the output. */
    oProjection: Linear;
    /**
     * @param embedDim - Feature size of query/key/value and of the output.
     * @param numHeads - Number of attention heads.
     * @param dropout - Dropout rate on the attention weights (defaults to 0).
     * @param bias - Whether the projections use a bias term (defaults to true).
     * @param device - Forwarded to each Linear projection.
     */
    constructor(embedDim: number, numHeads: number, dropout?: number, bias?: boolean, device?: string);
    /**
     * Runs attention with batch-first inputs.
     *
     * @param query - Shape (batchSize, targetLen, embedDim).
     * @param key - Shape (batchSize, sourceLen, embedDim).
     * @param value - Shape (batchSize, sourceLen, embedDim).
     * @param needWeights - When false, the returned weights are `undefined` (defaults to true).
     * @param attnMask - Optional mask; positions where it is set are masked out of the scores before softmax.
     * @param averageAttnWeights - Average the returned weights over the head dimension (defaults to true).
     * @returns `[output, weights]`, where output has shape (batchSize, targetLen, embedDim).
     */
    forward(query: Tensor, key: Tensor, value: Tensor, needWeights?: boolean, attnMask?: Tensor, averageAttnWeights?: boolean): [Tensor, Tensor | undefined];
}
65
77
  export interface StateDict {
66
78
  [key: string]: any;
67
79
  }
@@ -72,6 +84,7 @@ export declare const nn: {
72
84
  LSTMCell: typeof LSTMCell;
73
85
  LayerNorm: typeof LayerNorm;
74
86
  Embedding: typeof Embedding;
87
+ MultiheadAttention: typeof MultiheadAttention;
75
88
  state: {
76
89
  getParameters(model: any, visited?: WeakSet<object>): Tensor[];
77
90
  moveParameters(model: any, device: string): void;
package/dist/nn.js CHANGED
@@ -197,6 +197,57 @@ class Embedding {
197
197
  return this.weight.index(input);
198
198
  }
199
199
  }
200
/**
 * Multi-head scaled dot-product attention for batch-first inputs.
 *
 * query is (batchSize, targetLen, embedDim); key/value are
 * (batchSize, sourceLen, embedDim). Each input is projected by its own Linear
 * layer and split into `numHeads` heads of `headDim` features. The heads are
 * attended independently, merged back to embedDim, and passed through
 * `oProjection`.
 */
class MultiheadAttention {
    qProjection;
    kProjection;
    vProjection;
    oProjection;
    embedDim;
    numHeads;
    headDim;
    dropout;
    /**
     * @param embedDim - Input/output feature size; must be divisible by `numHeads`.
     * @param numHeads - Number of attention heads (positive integer).
     * @param dropout - Dropout rate applied to the attention weights (default 0).
     * @param bias - Whether the four projections use a bias term (default true).
     * @param device - Forwarded to each `nn.Linear` projection.
     * @throws Error if `numHeads` is not a positive integer or does not divide `embedDim`.
     */
    constructor(embedDim, numHeads, dropout = 0, bias = true, device) {
        // Fail fast: forward() reshapes embedDim into numHeads * headDim, which
        // only works when the split is exact.
        if (!Number.isInteger(numHeads) || numHeads <= 0) {
            throw new Error(`MultiheadAttention: numHeads must be a positive integer, got ${numHeads}`);
        }
        if (embedDim % numHeads !== 0) {
            throw new Error(`MultiheadAttention: embedDim (${embedDim}) must be divisible by numHeads (${numHeads})`);
        }
        this.qProjection = new exports.nn.Linear(embedDim, embedDim, bias, device);
        this.kProjection = new exports.nn.Linear(embedDim, embedDim, bias, device);
        this.vProjection = new exports.nn.Linear(embedDim, embedDim, bias, device);
        this.oProjection = new exports.nn.Linear(embedDim, embedDim, bias, device);
        this.embedDim = embedDim;
        this.numHeads = numHeads;
        this.headDim = embedDim / numHeads;
        this.dropout = dropout;
    }
    /**
     * @param query - (batchSize, targetLen, embedDim) tensor.
     * @param key - (batchSize, sourceLen, embedDim) tensor.
     * @param value - (batchSize, sourceLen, embedDim) tensor.
     * @param needWeights - When false, the returned weights are `undefined` (default true).
     * @param attnMask - Optional mask; positions where it is set are excluded from
     *   attention. Presumably broadcastable to (batchSize, numHeads, targetLen, sourceLen);
     *   TODO confirm. A fully masked row yields uniform weights over sourceLen.
     * @param averageAttnWeights - Average the returned weights over heads (default true).
     * @returns `[output, weights]`; output is (batchSize, targetLen, embedDim).
     */
    forward(query, key, value, needWeights = true, attnMask, averageAttnWeights = true) {
        const [batchSize, targetLen, embedDim] = query.shape;
        const sourceLen = key.shape[1];
        const numHeads = this.numHeads;
        const headDim = this.headDim;
        // Project, then split heads: (batchSize, numHeads, targetLen/sourceLen, headDim)
        const Q = this.qProjection.forward(query)
            .reshape([batchSize, targetLen, numHeads, headDim])
            .transpose(1, 2);
        const K = this.kProjection.forward(key)
            .reshape([batchSize, sourceLen, numHeads, headDim])
            .transpose(1, 2);
        const V = this.vProjection.forward(value)
            .reshape([batchSize, sourceLen, numHeads, headDim])
            .transpose(1, 2);
        // Scaled dot-product scores: (batchSize, numHeads, targetLen, sourceLen)
        let scores = Q.matmul(K.transpose(-2, -1)).div(Math.sqrt(headDim));
        if (attnMask) {
            // Use a large finite negative rather than -Infinity. Multiplication-based
            // masking turns 0 * -Infinity into NaN at unmasked positions, whereas
            // -1e9 still underflows to 0 after softmax.
            const maskedScore = -1e9;
            scores = scores.maskedFill(attnMask, maskedScore);
        }
        // NOTE(review): assumes softmax() defaults to the last (sourceLen) axis — confirm.
        const attnWeights = scores.softmax().dropout(this.dropout);
        // (batchSize, numHeads, targetLen, headDim) -> (batchSize, targetLen, embedDim)
        const merged = attnWeights.matmul(V).transpose(1, 2).reshape([batchSize, targetLen, embedDim]);
        const output = this.oProjection.forward(merged);
        if (!needWeights) {
            // The caller discards the weights, so skip the head average entirely.
            return [output, undefined];
        }
        return [output, averageAttnWeights ? attnWeights.mean(1) : attnWeights];
    }
}
200
251
  const state = {
201
252
  getParameters(model, visited = new WeakSet()) {
202
253
  if (visited.has(model))
@@ -266,5 +317,6 @@ exports.nn = {
266
317
  LSTMCell,
267
318
  LayerNorm,
268
319
  Embedding,
320
+ MultiheadAttention,
269
321
  state
270
322
  };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.6.4",
3
+ "version": "0.6.5",
4
4
  "description": "A small Torch-like deep learning framework for Javascript",
5
5
  "main": "index.js",
6
6
  "scripts": {