catniff 0.8.1 → 0.8.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.d.ts CHANGED
@@ -69,6 +69,8 @@ export declare class Tensor {
  cat(other: Tensor | TensorValue, dim?: number): Tensor;
  squeeze(dims?: number[] | number): Tensor;
  unsqueeze(dim: number): Tensor;
+ sort(dim?: number, descending?: boolean): Tensor;
+ topk(k: number, dim?: number, largest?: boolean): Tensor;
  static reduce(tensor: Tensor, dims: number[] | number | undefined, keepDims: boolean, config: {
      identity: number;
      operation: (accumulator: number, value: number) => number;
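
A quick usage sketch of the two new Tensor methods declared above. This is not from the package docs: it assumes Tensor is exported from the package root and that the constructor accepts a nested-array TensorValue.

import { Tensor } from "catniff"; // assumed root export

// sort() orders values along `dim` (default -1, the last dimension)
const t = new Tensor([[3, 1, 2], [9, 7, 8]]); // assumed TensorValue constructor
const ascending = t.sort();          // [[1, 2, 3], [7, 8, 9]]
const reversed = t.sort(-1, true);   // [[3, 2, 1], [9, 8, 7]]

// topk() keeps the first k entries of the sorted result along `dim` (largest first by default)
const top2 = t.topk(2);              // [[3, 2], [9, 8]]
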
package/dist/core.js CHANGED
@@ -208,13 +208,16 @@ class Tensor {
  const outputStrides = Tensor.getStrides(outputShape);
  const outputSize = Tensor.shapeToSize(outputShape);
  const outputValue = new dtype_1.TypedArray[outputDtype](outputSize);
+ // Check fast path conditions of two tensors
+ const aFastPath = tA.isContiguous() && tA.numel === outputSize;
+ const bFastPath = tB.isContiguous() && tB.numel === outputSize;
  for (let i = 0; i < outputSize; i++) {
      // Get coordinates from 1D index
-     const coordsOutput = Tensor.indexToCoords(i, outputStrides);
+     const coordsOutput = aFastPath && bFastPath ? [] : Tensor.indexToCoords(i, outputStrides);
      // Convert the coordinates to 1D index of flattened A with respect to A's shape
-     const indexA = Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedAShape, paddedAStrides);
+     const indexA = aFastPath ? i : Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedAShape, paddedAStrides);
      // Convert the coordinates to 1D index of flattened B with respect to B's shape
-     const indexB = Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedBShape, paddedBStrides);
+     const indexB = bFastPath ? i : Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedBShape, paddedBStrides);
      // Calculate with op
      outputValue[i] = op(tA.value[indexA + tA.offset], tB.value[indexB + tB.offset]);
  }
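
The fast path added above skips the coordinate round-trip whenever an operand is contiguous and already holds exactly as many elements as the output, because in that case its flat index equals the output's flat index. A standalone sketch of the same idea (all names here are illustrative, not the package's internals):

// Illustrative only: mirrors the aFastPath/bFastPath shortcut in the hunk above.
function elementwiseAdd(
    a: Float32Array, b: Float32Array, outSize: number,
    aContiguous: boolean, bContiguous: boolean,
    slowIndex: (flat: number, operand: "a" | "b") => number
): Float32Array {
    const out = new Float32Array(outSize);
    const aFast = aContiguous && a.length === outSize; // same layout as the output
    const bFast = bContiguous && b.length === outSize;
    for (let i = 0; i < outSize; i++) {
        // Fast path reuses the flat index; slow path maps through broadcast coordinates.
        const ia = aFast ? i : slowIndex(i, "a");
        const ib = bFast ? i : slowIndex(i, "b");
        out[i] = a[ia] + b[ib];
    }
    return out;
}
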
@@ -934,6 +937,114 @@ class Tensor {
  }
  return out;
  }
+ // Tensor sort
+ sort(dim = -1, descending = false) {
+     if (dim < 0) {
+         dim += this.shape.length;
+     }
+     // If dimension out of bound, throw error
+     if (dim >= this.shape.length || dim < 0) {
+         throw new Error("Dimension do not exist to sort");
+     }
+     // Copy if not contiguous
+     const outputSize = this.numel;
+     const outputShape = this.shape;
+     let outputValue, outputStrides;
+     if (this.isContiguous()) {
+         outputValue = [...this.value];
+         outputStrides = this.strides;
+     }
+     else {
+         outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
+         outputStrides = Tensor.getStrides(outputShape);
+         for (let flatIndex = 0; flatIndex < outputSize; flatIndex++) {
+             const coords = Tensor.indexToCoords(flatIndex, outputStrides);
+             const originalIndex = Tensor.coordsToIndex(coords, this.strides);
+             outputValue[flatIndex] = this.value[originalIndex + this.offset];
+         }
+     }
+     // Calculate dimensions for gather-scatter
+     const dimSize = outputShape[dim];
+     const outerSize = outputShape.slice(0, dim).reduce((a, b) => a * b, 1);
+     const innerSize = outputShape.slice(dim + 1).reduce((a, b) => a * b, 1);
+     // Store permutation indices for gradient
+     const permutation = new Array(outputSize);
+     // Sort each group independently
+     for (let outer = 0; outer < outerSize; outer++) {
+         for (let inner = 0; inner < innerSize; inner++) {
+             const group = [];
+             for (let i = 0; i < dimSize; i++) {
+                 const flatIdx = outer * (dimSize * innerSize) + i * innerSize + inner;
+                 group.push({
+                     value: outputValue[flatIdx],
+                     dimIdx: i
+                 });
+             }
+             // Sort this group by value
+             group.sort((a, b) => descending ? b.value - a.value : a.value - b.value);
+             // Scatter: write back sorted values and record permutation
+             for (let i = 0; i < dimSize; i++) {
+                 const flatIdx = outer * (dimSize * innerSize) + i * innerSize + inner;
+                 outputValue[flatIdx] = group[i].value;
+                 // Record where this element came from (for gradient)
+                 const originalFlatIdx = outer * (dimSize * innerSize) + group[i].dimIdx * innerSize + inner;
+                 permutation[flatIdx] = originalFlatIdx;
+             }
+         }
+     }
+     const out = new Tensor(outputValue, {
+         shape: outputShape,
+         strides: outputStrides,
+         offset: 0,
+         numel: outputSize,
+         device: this.device,
+         dtype: this.dtype
+     });
+     // Gradient setup
+     if (this.requiresGrad) {
+         out.requiresGrad = true;
+         out.children.push(this);
+         out.gradFn = () => {
+             const outGrad = out.grad;
+             // Scatter output gradients back to original positions
+             const inputGradValue = new dtype_1.TypedArray[this.dtype](outputSize);
+             for (let sortedIdx = 0; sortedIdx < outputSize; sortedIdx++) {
+                 const originalIdx = permutation[sortedIdx];
+                 inputGradValue[originalIdx] = outGrad.value[sortedIdx];
+             }
+             const inputGrad = new Tensor(inputGradValue, {
+                 shape: outputShape,
+                 strides: outputStrides,
+                 offset: 0,
+                 numel: outputSize,
+                 device: this.device,
+                 dtype: this.dtype
+             });
+             Tensor.addGrad(this, inputGrad);
+         };
+     }
+     return out;
+ }
+ // Top-k sampling
+ topk(k, dim = -1, largest = true) {
+     if (dim < 0) {
+         dim += this.shape.length;
+     }
+     // If dimension out of bound, throw error
+     if (dim >= this.shape.length || dim < 0) {
+         throw new Error("Dimension do not exist to get topk");
+     }
+     const dimRanges = new Array(this.shape.length);
+     for (let index = 0; index < dimRanges.length; index++) {
+         if (index === dim) {
+             dimRanges[index] = [0, k];
+         }
+         else {
+             dimRanges[index] = [];
+         }
+     }
+     return this.sort(dim, largest).slice(dimRanges);
+ }
  // Generic reduction operation handler
  static reduce(tensor, dims, keepDims, config) {
      if (tensor.shape.length === 0)
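
As the hunk above shows, sort() remembers the permutation it applied and its gradFn scatters incoming gradients back through it, while topk() is simply sort(dim, largest) followed by a slice of the first k positions along dim. A hedged usage sketch; the nested-array constructor and the way requiresGrad is switched on are assumptions, not taken from this diff:

import { Tensor } from "catniff"; // assumed root export

const x = new Tensor([3, 1, 2]);   // assumed TensorValue constructor
x.requiresGrad = true;             // the property exists per the diff; the setup API is assumed

const sorted = x.sort();           // values [1, 2, 3]; permutation [1, 2, 0] recorded internally
const top2 = x.topk(2);            // sort descending, then slice [0, 2] along dim -> [3, 2]

// During backprop, the gradient at sorted position i is written to x's position permutation[i],
// i.e. each element's gradient follows it back to where it originally sat.
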
package/dist/nn.d.ts CHANGED
@@ -70,6 +70,7 @@ export declare class Embedding {
  constructor(numEmbeddings: number, embeddingDim: number, device?: string, dtype?: dtype);
  forward(input: Tensor | TensorValue): Tensor;
  }
+ export declare function scaledDotProductAttention(query: Tensor, key: Tensor, value: Tensor, attnMask?: Tensor, dropout?: number, isCausal?: boolean, scale?: number): Tensor;
  export declare class MultiheadAttention {
  qProjection: Linear;
  kProjection: Linear;
@@ -80,7 +81,7 @@ export declare class MultiheadAttention {
  headDim: number;
  dropout: number;
  constructor(embedDim: number, numHeads: number, dropout?: number, bias?: boolean, device?: string, dtype?: dtype);
- forward(query: Tensor, key: Tensor, value: Tensor, needWeights?: boolean, attnMask?: Tensor, averageAttnWeights?: boolean): [Tensor, Tensor | undefined];
+ forward(query: Tensor, key: Tensor, value: Tensor, needWeights?: boolean, attnMask?: Tensor, averageAttnWeights?: boolean, isCausal?: boolean): [Tensor, Tensor | undefined];
  }
  export interface StateDict {
  [key: string]: any;
@@ -93,6 +94,7 @@ export declare const nn: {
  LayerNorm: typeof LayerNorm;
  RMSNorm: typeof RMSNorm;
  Embedding: typeof Embedding;
+ scaledDotProductAttention: typeof scaledDotProductAttention;
  MultiheadAttention: typeof MultiheadAttention;
  state: {
  getParameters(model: any, visited?: WeakSet<object>): Tensor[];
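
A sketch of calling the new functional attention entry point declared above, assuming Tensor and nn are re-exported from the package root and that Tensor accepts nested arrays; when scale is omitted, the scores are divided by sqrt(headDim):

import { Tensor, nn } from "catniff"; // assumed root exports

// Shapes follow the declaration: query [targetLen, headDim], key/value [sourceLen, headDim].
const q = new Tensor([[0.1, 0.2], [0.3, 0.4]]);
const k = new Tensor([[0.5, 0.6], [0.7, 0.8]]);
const v = new Tensor([[1.0, 0.0], [0.0, 1.0]]);

// No explicit mask, no dropout, causal masking enabled.
const out = nn.scaledDotProductAttention(q, k, v, undefined, 0, true);
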
package/dist/nn.js CHANGED
@@ -1,6 +1,7 @@
  "use strict";
  Object.defineProperty(exports, "__esModule", { value: true });
  exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Linear = void 0;
+ exports.scaledDotProductAttention = scaledDotProductAttention;
  const core_1 = require("./core");
  function linearTransform(input, weight, bias) {
  let output = input.matmul(weight.t());
@@ -240,6 +241,25 @@ class Embedding {
  }
  }
  exports.Embedding = Embedding;
+ function scaledDotProductAttention(query, key, value, attnMask, dropout = 0, isCausal = false, scale) {
+     const targetLen = query.shape[query.shape.length - 2];
+     const sourceLen = key.shape[key.shape.length - 2];
+     const dimSize = query.shape[query.shape.length - 1];
+     // Attention scores
+     let scores = query.matmul(key.transpose(-2, -1)).div(scale ?? Math.sqrt(dimSize));
+     // Set attention mask to causal mask if specified
+     if (isCausal) {
+         attnMask = core_1.Tensor.ones([targetLen, sourceLen], { device: query.device }).triu(1);
+     }
+     // Apply attention mask if specified
+     if (attnMask) {
+         scores = scores.maskedFill(attnMask, -Infinity);
+     }
+     // Calculate attention weights
+     let attnWeights = scores.softmax().dropout(dropout);
+     // Apply attention to values
+     return attnWeights.matmul(value);
+ }
  class MultiheadAttention {
  qProjection;
  kProjection;
@@ -259,7 +279,7 @@ class MultiheadAttention {
  this.headDim = Math.floor(embedDim / numHeads);
  this.dropout = dropout;
  }
- forward(query, key, value, needWeights = true, attnMask, averageAttnWeights = true) {
+ forward(query, key, value, needWeights = true, attnMask, averageAttnWeights = true, isCausal = false) {
  // Batch-first
  const [batchSize, targetLen, embedDim] = query.shape;
  const sourceLen = key.shape[1];
@@ -272,6 +292,10 @@ class MultiheadAttention {
  V = V.reshape([batchSize, sourceLen, this.numHeads, this.headDim]).transpose(1, 2);
  // Attention scores
  let scores = Q.matmul(K.transpose(-2, -1)).div(Math.sqrt(this.headDim));
+ // Set attention mask to causal mask if specified
+ if (isCausal) {
+     attnMask = core_1.Tensor.ones([targetLen, sourceLen], { device: this.qProjection.weight.device }).triu(1);
+ }
  // Apply attention mask if specified
  if (attnMask) {
  scores = scores.maskedFill(attnMask, -Infinity);
@@ -362,6 +386,7 @@ exports.nn = {
  LayerNorm,
  RMSNorm,
  Embedding,
+ scaledDotProductAttention,
  MultiheadAttention,
  state
  };
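
The same causal switch is threaded through MultiheadAttention.forward as a new trailing isCausal parameter. A usage sketch, assuming the class is reachable from the package root and fed batch-first inputs of shape [batch, seqLen, embedDim] as the destructuring in the hunk implies:

import { Tensor, nn } from "catniff"; // assumed root exports

const mha = new nn.MultiheadAttention(4, 2); // embedDim = 4, numHeads = 2

// Self-attention: the same tensor serves as query, key and value; shape [1, 3, 4].
const x = new Tensor([[[0.1, 0.2, 0.3, 0.4],
                       [0.5, 0.6, 0.7, 0.8],
                       [0.9, 1.0, 1.1, 1.2]]]); // assumed nested-array constructor

// needWeights = true, no explicit mask, averaged head weights, causal masking on.
const [attnOut, attnWeights] = mha.forward(x, x, x, true, undefined, true, true);
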
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "catniff",
- "version": "0.8.1",
+ "version": "0.8.3",
  "description": "Torch-like deep learning framework for Javascript",
  "main": "index.js",
  "scripts": {