catniff 0.8.3 → 0.8.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.d.ts CHANGED
@@ -201,6 +201,7 @@ export declare class Tensor {
201
201
  triu(diagonal?: number): Tensor;
202
202
  tril(diagonal?: number): Tensor;
203
203
  maskedFill(mask: Tensor | TensorValue, value: number): Tensor;
204
+ multinomial(numSamples: number, replacement?: boolean): Tensor;
204
205
  static full(shape: number[], num: number, options?: TensorOptions): Tensor;
205
206
  static fullLike(tensor: Tensor, num: number, options?: TensorOptions): Tensor;
206
207
  static ones(shape?: number[], options?: TensorOptions): Tensor;
@@ -223,6 +224,7 @@ export declare class Tensor {
223
224
  zeroGrad?: boolean;
224
225
  }): void;
225
226
  val(): TensorValue;
227
+ toString(): string;
226
228
  detach(): Tensor;
227
229
  clone(): Tensor;
228
230
  replace(other: Tensor | TensorValue): Tensor;
package/dist/core.js CHANGED
@@ -949,14 +949,14 @@ class Tensor {
949
949
  // Copy if not contiguous
950
950
  const outputSize = this.numel;
951
951
  const outputShape = this.shape;
952
- let outputValue, outputStrides;
952
+ const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
953
+ const outputStrides = Tensor.getStrides(outputShape);
953
954
  if (this.isContiguous()) {
954
- outputValue = [...this.value];
955
- outputStrides = this.strides;
955
+ // Fast path: direct copy
956
+ outputValue.set(this.value.subarray(this.offset, this.offset + outputSize));
956
957
  }
957
958
  else {
958
- outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
959
- outputStrides = Tensor.getStrides(outputShape);
959
+ // Slow path: coordinate conversion
960
960
  for (let flatIndex = 0; flatIndex < outputSize; flatIndex++) {
961
961
  const coords = Tensor.indexToCoords(flatIndex, outputStrides);
962
962
  const originalIndex = Tensor.coordsToIndex(coords, this.strides);
@@ -1912,6 +1912,129 @@ class Tensor {
1912
1912
  mask = this.handleOther(mask);
1913
1913
  return this.mul(mask.logicalNot()).add(mask.mul(value));
1914
1914
  }
1915
+ // Multinomial sampling
1916
+ multinomial(numSamples, replacement = false) {
1917
+ // Validate input dimensions (1D or 2D only)
1918
+ if (this.shape.length === 0 || this.shape.length > 2) {
1919
+ throw new Error("multinomial only supports 1D or 2D probability tensors");
1920
+ }
1921
+ const is1D = this.shape.length === 1;
1922
+ const numDist = is1D ? 1 : this.shape[0];
1923
+ const numCategories = is1D ? this.shape[0] : this.shape[1];
1924
+ // Validate numSamples
1925
+ if (numSamples <= 0) {
1926
+ throw new Error("Number of samples must be positive");
1927
+ }
1928
+ if (!replacement && numSamples > numCategories) {
1929
+ throw new Error(`Cannot sample ${numSamples} without replacement from ${numCategories} categories`);
1930
+ }
1931
+ // Make contiguous copy of probabilities
1932
+ const probsSize = this.numel;
1933
+ const probs = new dtype_1.TypedArray[this.dtype](probsSize);
1934
+ if (this.isContiguous()) {
1935
+ // Fast path: direct copy
1936
+ probs.set(this.value.subarray(this.offset, this.offset + probsSize));
1937
+ }
1938
+ else {
1939
+ // Slow path: coordinate conversion
1940
+ const defaultStrides = Tensor.getStrides(this.shape);
1941
+ for (let i = 0; i < probsSize; i++) {
1942
+ const coords = Tensor.indexToCoords(i, defaultStrides);
1943
+ const idx = Tensor.coordsToIndex(coords, this.strides);
1944
+ probs[i] = this.value[idx + this.offset];
1945
+ }
1946
+ }
1947
+ // Output setup
1948
+ const outputShape = is1D ? [numSamples] : [numDist, numSamples];
1949
+ const outputValue = new Int32Array(numDist * numSamples);
1950
+ // Sample from each distribution
1951
+ for (let dist = 0; dist < numDist; dist++) {
1952
+ const offset = dist * numCategories;
1953
+ // Extract this distribution's probabilities
1954
+ const distProbs = probs.slice(offset, offset + numCategories);
1955
+ // Validate and normalize
1956
+ let sum = 0;
1957
+ for (let i = 0; i < numCategories; i++) {
1958
+ if (distProbs[i] < 0) {
1959
+ throw new Error("Probabilities cannot be negative");
1960
+ }
1961
+ sum += distProbs[i];
1962
+ }
1963
+ if (sum <= 0) {
1964
+ throw new Error("Probabilities must sum to a positive value");
1965
+ }
1966
+ // Normalize
1967
+ for (let i = 0; i < numCategories; i++) {
1968
+ distProbs[i] /= sum;
1969
+ }
1970
+ if (replacement) {
1971
+ // With replacement: use CDF for efficient sampling
1972
+ const cdf = new Array(numCategories);
1973
+ let cumSum = 0;
1974
+ for (let i = 0; i < numCategories; i++) {
1975
+ cumSum += distProbs[i];
1976
+ cdf[i] = cumSum;
1977
+ }
1978
+ cdf[numCategories - 1] = 1;
1979
+ for (let s = 0; s < numSamples; s++) {
1980
+ const r = Math.random();
1981
+ // Binary search for efficiency
1982
+ let left = 0;
1983
+ let right = numCategories - 1;
1984
+ while (left < right) {
1985
+ const mid = Math.floor((left + right) / 2);
1986
+ if (r <= cdf[mid]) {
1987
+ right = mid;
1988
+ }
1989
+ else {
1990
+ left = mid + 1;
1991
+ }
1992
+ }
1993
+ outputValue[dist * numSamples + s] = left;
1994
+ }
1995
+ }
1996
+ else {
1997
+ // Without replacement: weighted sampling without replacement
1998
+ const available = Array.from({ length: numCategories }, (_, i) => ({
1999
+ idx: i,
2000
+ prob: distProbs[i]
2001
+ }));
2002
+ for (let s = 0; s < numSamples; s++) {
2003
+ // Compute sum of remaining probabilities
2004
+ let remainingSum = 0;
2005
+ for (const item of available) {
2006
+ remainingSum += item.prob;
2007
+ }
2008
+ // Sample from remaining
2009
+ const r = Math.random() * remainingSum;
2010
+ let cumSum = 0;
2011
+ let selectedIdx = -1;
2012
+ for (let i = 0; i < available.length; i++) {
2013
+ cumSum += available[i].prob;
2014
+ if (r <= cumSum) {
2015
+ selectedIdx = i;
2016
+ break;
2017
+ }
2018
+ }
2019
+ // Handle floating point edge case
2020
+ if (selectedIdx === -1) {
2021
+ selectedIdx = available.length - 1;
2022
+ }
2023
+ // Store result and remove from available
2024
+ outputValue[dist * numSamples + s] = available[selectedIdx].idx;
2025
+ available.splice(selectedIdx, 1);
2026
+ }
2027
+ }
2028
+ }
2029
+ return new Tensor(outputValue, {
2030
+ shape: outputShape,
2031
+ strides: Tensor.getStrides(outputShape),
2032
+ offset: 0,
2033
+ numel: numDist * numSamples,
2034
+ device: this.device,
2035
+ dtype: "int32"
2036
+ });
2037
+ }
1915
2038
  // Utility to create a new tensor filled with a number
1916
2039
  static full(shape, num, options = {}) {
1917
2040
  if (shape.length === 0)
@@ -2094,6 +2217,7 @@ class Tensor {
2094
2217
  shape,
2095
2218
  offset: 0,
2096
2219
  numel: outputSize,
2220
+ dtype: "int32",
2097
2221
  ...options
2098
2222
  });
2099
2223
  }
@@ -2130,6 +2254,7 @@ class Tensor {
2130
2254
  shape: [n],
2131
2255
  offset: 0,
2132
2256
  numel: n,
2257
+ dtype: "int32",
2133
2258
  ...options
2134
2259
  });
2135
2260
  }
@@ -2271,6 +2396,53 @@ class Tensor {
2271
2396
  }
2272
2397
  return buildNested(this.value, this.shape, this.strides, this.offset);
2273
2398
  }
2399
+ // Returns the nicely Pytorch-like formatted string form
2400
+ toString() {
2401
+ const val = this.val();
2402
+ // Format a single number (integers get trailing dot)
2403
+ const formatNum = (n) => {
2404
+ if (Number.isInteger(n) && Math.abs(n) < 1e8) {
2405
+ return n.toFixed(0) + ".";
2406
+ }
2407
+ return n.toString();
2408
+ };
2409
+ // Handle scalar
2410
+ if (typeof val === "number") {
2411
+ return `tensor(${formatNum(val)})`;
2412
+ }
2413
+ // Collect all numbers to find max width for alignment
2414
+ const collectNumbers = (v) => {
2415
+ if (typeof v === "number")
2416
+ return [v];
2417
+ return v.flatMap(collectNumbers);
2418
+ };
2419
+ const allNumbers = collectNumbers(val);
2420
+ const maxWidth = Math.max(...allNumbers.map((n) => formatNum(n).length));
2421
+ const ndim = this.shape.length;
2422
+ const baseIndent = "tensor(".length; // 7
2423
+ const formatNested = (v, depth) => {
2424
+ if (typeof v === "number") {
2425
+ return formatNum(v).padStart(maxWidth);
2426
+ }
2427
+ const arr = v;
2428
+ // Innermost dimension: format as single line [x, y, z]
2429
+ if (arr.length > 0 && typeof arr[0] === "number") {
2430
+ const elements = arr.map((x) => formatNum(x).padStart(maxWidth));
2431
+ return `[${elements.join(", ")}]`;
2432
+ }
2433
+ // Number of blank lines between elements at this depth
2434
+ // Deeper = fewer blank lines (0 between rows, 1 between 2D blocks, etc.)
2435
+ const blankLines = Math.max(0, ndim - depth - 2);
2436
+ const separator = ",\n" + "\n".repeat(blankLines);
2437
+ const innerIndent = " ".repeat(baseIndent + depth + 1);
2438
+ const formatted = arr.map((item, i) => {
2439
+ const str = formatNested(item, depth + 1);
2440
+ return i === 0 ? "[" + str : innerIndent + str;
2441
+ });
2442
+ return formatted.join(separator) + "]";
2443
+ };
2444
+ return "tensor(" + formatNested(val, 0) + ")";
2445
+ }
2274
2446
  // Returns a view of the tensor with gradient turned off and detaches from autograd
2275
2447
  detach() {
2276
2448
  return new Tensor(this.value, {
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.8.3",
3
+ "version": "0.8.5",
4
4
  "description": "Torch-like deep learning framework for Javascript",
5
5
  "main": "index.js",
6
6
  "scripts": {