catniff 0.8.2 → 0.8.4

package/dist/core.d.ts CHANGED
@@ -201,6 +201,7 @@ export declare class Tensor {
     triu(diagonal?: number): Tensor;
     tril(diagonal?: number): Tensor;
     maskedFill(mask: Tensor | TensorValue, value: number): Tensor;
+    multinomial(numSamples: number, replacement?: boolean): Tensor;
     static full(shape: number[], num: number, options?: TensorOptions): Tensor;
     static fullLike(tensor: Tensor, num: number, options?: TensorOptions): Tensor;
     static ones(shape?: number[], options?: TensorOptions): Tensor;
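
Based on the declaration added above, `multinomial` draws integer category indices from a tensor whose values are treated as (unnormalized) probability weights, with `replacement` defaulting to false. A minimal TypeScript usage sketch — constructing a Tensor from a plain array and importing from the package root are assumptions not confirmed by this diff:

import { Tensor } from "catniff";

// Hypothetical construction: assumes the Tensor constructor accepts a plain array.
const probs = new Tensor([0.1, 0.2, 0.7]);

// Five draws with replacement: indices in 0..2, weighted by probs.
const a = probs.multinomial(5, true);

// Without replacement (the default), numSamples must not exceed the
// number of categories, so at most 3 draws are possible here.
const b = probs.multinomial(2);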
package/dist/core.js CHANGED
@@ -636,7 +636,7 @@ class Tensor {
         let start = range[0] ?? 0;
         let end = range[1] ?? dimSize;
         let step = range[2] ?? 1;
-        // Handle negative indicesoutGrad
+        // Handle negative indices
         if (start < 0)
             start += dimSize;
         if (end < 0)
@@ -949,14 +949,14 @@ class Tensor {
         // Copy if not contiguous
         const outputSize = this.numel;
         const outputShape = this.shape;
-        let outputValue, outputStrides;
+        const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
+        const outputStrides = Tensor.getStrides(outputShape);
         if (this.isContiguous()) {
-            outputValue = [...this.value];
-            outputStrides = this.strides;
+            // Fast path: direct copy
+            outputValue.set(this.value.subarray(this.offset, this.offset + outputSize));
         }
         else {
-            outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
-            outputStrides = Tensor.getStrides(outputShape);
+            // Slow path: coordinate conversion
             for (let flatIndex = 0; flatIndex < outputSize; flatIndex++) {
                 const coords = Tensor.indexToCoords(flatIndex, outputStrides);
                 const originalIndex = Tensor.coordsToIndex(coords, this.strides);
@@ -1912,6 +1912,129 @@ class Tensor {
         mask = this.handleOther(mask);
         return this.mul(mask.logicalNot()).add(mask.mul(value));
     }
+    // Multinomial sampling
+    multinomial(numSamples, replacement = false) {
+        // Validate input dimensions (1D or 2D only)
+        if (this.shape.length === 0 || this.shape.length > 2) {
+            throw new Error("multinomial only supports 1D or 2D probability tensors");
+        }
+        const is1D = this.shape.length === 1;
+        const numDist = is1D ? 1 : this.shape[0];
+        const numCategories = is1D ? this.shape[0] : this.shape[1];
+        // Validate numSamples
+        if (numSamples <= 0) {
+            throw new Error("Number of samples must be positive");
+        }
+        if (!replacement && numSamples > numCategories) {
+            throw new Error(`Cannot sample ${numSamples} without replacement from ${numCategories} categories`);
+        }
+        // Make contiguous copy of probabilities
+        const probsSize = this.numel;
+        const probs = new dtype_1.TypedArray[this.dtype](probsSize);
+        if (this.isContiguous()) {
+            // Fast path: direct copy
+            probs.set(this.value.subarray(this.offset, this.offset + probsSize));
+        }
+        else {
+            // Slow path: coordinate conversion
+            const defaultStrides = Tensor.getStrides(this.shape);
+            for (let i = 0; i < probsSize; i++) {
+                const coords = Tensor.indexToCoords(i, defaultStrides);
+                const idx = Tensor.coordsToIndex(coords, this.strides);
+                probs[i] = this.value[idx + this.offset];
+            }
+        }
+        // Output setup
+        const outputShape = is1D ? [numSamples] : [numDist, numSamples];
+        const outputValue = new Int32Array(numDist * numSamples);
+        // Sample from each distribution
+        for (let dist = 0; dist < numDist; dist++) {
+            const offset = dist * numCategories;
+            // Extract this distribution's probabilities
+            const distProbs = probs.slice(offset, offset + numCategories);
+            // Validate and normalize
+            let sum = 0;
+            for (let i = 0; i < numCategories; i++) {
+                if (distProbs[i] < 0) {
+                    throw new Error("Probabilities cannot be negative");
+                }
+                sum += distProbs[i];
+            }
+            if (sum <= 0) {
+                throw new Error("Probabilities must sum to a positive value");
+            }
+            // Normalize
+            for (let i = 0; i < numCategories; i++) {
+                distProbs[i] /= sum;
+            }
+            if (replacement) {
+                // With replacement: use CDF for efficient sampling
+                const cdf = new Array(numCategories);
+                let cumSum = 0;
+                for (let i = 0; i < numCategories; i++) {
+                    cumSum += distProbs[i];
+                    cdf[i] = cumSum;
+                }
+                cdf[numCategories - 1] = 1;
+                for (let s = 0; s < numSamples; s++) {
+                    const r = Math.random();
+                    // Binary search for efficiency
+                    let left = 0;
+                    let right = numCategories - 1;
+                    while (left < right) {
+                        const mid = Math.floor((left + right) / 2);
+                        if (r <= cdf[mid]) {
+                            right = mid;
+                        }
+                        else {
+                            left = mid + 1;
+                        }
+                    }
+                    outputValue[dist * numSamples + s] = left;
+                }
+            }
+            else {
+                // Without replacement: weighted sampling without replacement
+                const available = Array.from({ length: numCategories }, (_, i) => ({
+                    idx: i,
+                    prob: distProbs[i]
+                }));
+                for (let s = 0; s < numSamples; s++) {
+                    // Compute sum of remaining probabilities
+                    let remainingSum = 0;
+                    for (const item of available) {
+                        remainingSum += item.prob;
+                    }
+                    // Sample from remaining
+                    const r = Math.random() * remainingSum;
+                    let cumSum = 0;
+                    let selectedIdx = -1;
+                    for (let i = 0; i < available.length; i++) {
+                        cumSum += available[i].prob;
+                        if (r <= cumSum) {
+                            selectedIdx = i;
+                            break;
+                        }
+                    }
+                    // Handle floating point edge case
+                    if (selectedIdx === -1) {
+                        selectedIdx = available.length - 1;
+                    }
+                    // Store result and remove from available
+                    outputValue[dist * numSamples + s] = available[selectedIdx].idx;
+                    available.splice(selectedIdx, 1);
+                }
+            }
+        }
+        return new Tensor(outputValue, {
+            shape: outputShape,
+            strides: Tensor.getStrides(outputShape),
+            offset: 0,
+            numel: numDist * numSamples,
+            device: this.device,
+            dtype: "int32"
+        });
+    }
     // Utility to create a new tensor filled with a number
     static full(shape, num, options = {}) {
         if (shape.length === 0)
@@ -2094,6 +2217,7 @@ class Tensor {
             shape,
             offset: 0,
             numel: outputSize,
+            dtype: "int32",
             ...options
         });
     }
@@ -2130,6 +2254,7 @@ class Tensor {
             shape: [n],
             offset: 0,
             numel: n,
+            dtype: "int32",
             ...options
         });
     }
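
The with-replacement branch of `multinomial` above is inverse-CDF sampling: normalize the weights, accumulate them into a cumulative distribution, then binary-search that CDF for each uniform draw. A self-contained TypeScript sketch of the same idea (a hypothetical standalone helper, not part of catniff):

// Inverse-CDF sampling with replacement, mirroring the branch above.
function sampleWithReplacement(probs: number[], numSamples: number): number[] {
    const n = probs.length;
    const total = probs.reduce((acc, p) => acc + p, 0);
    // Build the cumulative distribution once, in O(n).
    const cdf = new Array<number>(n);
    let cum = 0;
    for (let i = 0; i < n; i++) {
        cum += probs[i] / total;
        cdf[i] = cum;
    }
    cdf[n - 1] = 1; // guard against floating-point drift
    const out: number[] = [];
    for (let s = 0; s < numSamples; s++) {
        const r = Math.random();
        // Binary search for the first index with r <= cdf[index]: O(log n) per draw.
        let left = 0;
        let right = n - 1;
        while (left < right) {
            const mid = (left + right) >> 1;
            if (r <= cdf[mid]) right = mid;
            else left = mid + 1;
        }
        out.push(left);
    }
    return out;
}

// e.g. sampleWithReplacement([1, 1, 2], 4) picks index 2 with probability 1/2 per draw.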
package/dist/nn.d.ts CHANGED
@@ -70,6 +70,7 @@ export declare class Embedding {
     constructor(numEmbeddings: number, embeddingDim: number, device?: string, dtype?: dtype);
     forward(input: Tensor | TensorValue): Tensor;
 }
+export declare function scaledDotProductAttention(query: Tensor, key: Tensor, value: Tensor, attnMask?: Tensor, dropout?: number, isCausal?: boolean, scale?: number): Tensor;
 export declare class MultiheadAttention {
     qProjection: Linear;
     kProjection: Linear;
@@ -80,7 +81,7 @@ export declare class MultiheadAttention {
     headDim: number;
     dropout: number;
     constructor(embedDim: number, numHeads: number, dropout?: number, bias?: boolean, device?: string, dtype?: dtype);
-    forward(query: Tensor, key: Tensor, value: Tensor, needWeights?: boolean, attnMask?: Tensor, averageAttnWeights?: boolean): [Tensor, Tensor | undefined];
+    forward(query: Tensor, key: Tensor, value: Tensor, needWeights?: boolean, attnMask?: Tensor, averageAttnWeights?: boolean, isCausal?: boolean): [Tensor, Tensor | undefined];
 }
 export interface StateDict {
     [key: string]: any;
@@ -93,6 +94,7 @@ export declare const nn: {
     LayerNorm: typeof LayerNorm;
     RMSNorm: typeof RMSNorm;
     Embedding: typeof Embedding;
+    scaledDotProductAttention: typeof scaledDotProductAttention;
     MultiheadAttention: typeof MultiheadAttention;
     state: {
         getParameters(model: any, visited?: WeakSet<object>): Tensor[];
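
A usage sketch for the new functional API, following the signature declared above (query, key, value, attnMask?, dropout?, isCausal?, scale?). Importing `nn` and `Tensor` from the package root is an assumption; `Tensor.ones` is used only to produce correctly-shaped placeholder inputs:

import { nn, Tensor } from "catniff"; // root re-export assumed, see nn.d.ts

// Shapes [batch, heads, seqLen, headDim]; any [..., L, E] layout should work,
// since scores are computed via matmul against key.transpose(-2, -1).
const q = Tensor.ones([2, 4, 8, 16]);
const k = Tensor.ones([2, 4, 8, 16]);
const v = Tensor.ones([2, 4, 8, 16]);

// Causal attention with 10% dropout; the score divisor defaults to
// sqrt(headDim) when scale is omitted.
const out = nn.scaledDotProductAttention(q, k, v, undefined, 0.1, true);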
package/dist/nn.js CHANGED
@@ -1,6 +1,7 @@
 "use strict";
 Object.defineProperty(exports, "__esModule", { value: true });
 exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Linear = void 0;
+exports.scaledDotProductAttention = scaledDotProductAttention;
 const core_1 = require("./core");
 function linearTransform(input, weight, bias) {
     let output = input.matmul(weight.t());
@@ -240,6 +241,25 @@ class Embedding {
     }
 }
 exports.Embedding = Embedding;
+function scaledDotProductAttention(query, key, value, attnMask, dropout = 0, isCausal = false, scale) {
+    const targetLen = query.shape[query.shape.length - 2];
+    const sourceLen = key.shape[key.shape.length - 2];
+    const dimSize = query.shape[query.shape.length - 1];
+    // Attention scores
+    let scores = query.matmul(key.transpose(-2, -1)).div(scale ?? Math.sqrt(dimSize));
+    // Set attention mask to causal mask if specified
+    if (isCausal) {
+        attnMask = core_1.Tensor.ones([targetLen, sourceLen], { device: query.device }).triu(1);
+    }
+    // Apply attention mask if specified
+    if (attnMask) {
+        scores = scores.maskedFill(attnMask, -Infinity);
+    }
+    // Calculate attention weights
+    let attnWeights = scores.softmax().dropout(dropout);
+    // Apply attention to values
+    return attnWeights.matmul(value);
+}
 class MultiheadAttention {
     qProjection;
     kProjection;
@@ -259,7 +279,7 @@ class MultiheadAttention {
         this.headDim = Math.floor(embedDim / numHeads);
         this.dropout = dropout;
     }
-    forward(query, key, value, needWeights = true, attnMask, averageAttnWeights = true) {
+    forward(query, key, value, needWeights = true, attnMask, averageAttnWeights = true, isCausal = false) {
         // Batch-first
         const [batchSize, targetLen, embedDim] = query.shape;
         const sourceLen = key.shape[1];
@@ -272,6 +292,10 @@ class MultiheadAttention {
         V = V.reshape([batchSize, sourceLen, this.numHeads, this.headDim]).transpose(1, 2);
         // Attention scores
         let scores = Q.matmul(K.transpose(-2, -1)).div(Math.sqrt(this.headDim));
+        // Set attention mask to causal mask if specified
+        if (isCausal) {
+            attnMask = core_1.Tensor.ones([targetLen, sourceLen], { device: this.qProjection.weight.device }).triu(1);
+        }
         // Apply attention mask if specified
         if (attnMask) {
             scores = scores.maskedFill(attnMask, -Infinity);
@@ -362,6 +386,7 @@ exports.nn = {
     LayerNorm,
     RMSNorm,
     Embedding,
+    scaledDotProductAttention,
     MultiheadAttention,
     state
 };
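
And a sketch of the new `isCausal` flag on `MultiheadAttention.forward`, per the updated declaration in nn.d.ts (batch-first inputs, returning `[output, attnWeights | undefined]`). The root import is again an assumption:

import { nn, Tensor } from "catniff"; // root re-export assumed

// embedDim 32 split across 4 heads (headDim 8); dropout/bias use defaults.
const mha = new nn.MultiheadAttention(32, 4);

// Batch-first self-attention input: [batchSize, seqLen, embedDim].
const x = Tensor.ones([2, 10, 32]);

// isCausal = true builds the upper-triangular mask internally,
// so no attnMask needs to be passed.
const [output, weights] = mha.forward(x, x, x, true, undefined, true, true);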
package/package.json CHANGED
@@ -1,6 +1,6 @@
 {
   "name": "catniff",
-  "version": "0.8.2",
+  "version": "0.8.4",
   "description": "Torch-like deep learning framework for Javascript",
   "main": "index.js",
   "scripts": {