catniff 0.8.3 → 0.8.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +2 -0
- package/dist/core.js +177 -5
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
|
@@ -201,6 +201,7 @@ export declare class Tensor {
|
|
|
201
201
|
/**
 * Triangular-part helper; its implementation is not part of this diff.
 * NOTE(review): presumably mirrors PyTorch `triu` (keeps elements on/above the
 * `diagonal`-th diagonal, zeroing the rest) — confirm against core.js.
 */
triu(diagonal?: number): Tensor;
|
|
202
202
|
/**
 * Triangular-part helper; its implementation is not part of this diff.
 * NOTE(review): presumably mirrors PyTorch `tril` (keeps elements on/below the
 * `diagonal`-th diagonal, zeroing the rest) — confirm against core.js.
 */
tril(diagonal?: number): Tensor;
|
|
203
203
|
/**
 * Returns a tensor in which positions selected by `mask` hold `value` and the other
 * positions keep this tensor's elements. `mask` may be a Tensor or a raw TensorValue.
 *
 * NOTE(review): core.js implements this arithmetically as
 * `this.mul(mask.logicalNot()).add(mask.mul(value))`. Assuming `mul` is elementwise and
 * `logicalNot` yields 0/1, non-finite numbers leak through the multiplications:
 * `value = -Infinity` gives `0 * -Infinity = NaN` at every UNmasked position, and a
 * non-finite element of this tensor gives NaN at masked positions. Avoid non-finite
 * values here until this becomes a select-style implementation.
 */
maskedFill(mask: Tensor | TensorValue, value: number): Tensor;
|
|
204
|
+
/**
 * Samples category indices from the probability weights held in this tensor.
 *
 * A 1D tensor of length C is one distribution and yields shape `[numSamples]`; a 2D
 * tensor of shape `[D, C]` is D independent distributions and yields `[D, numSamples]`.
 * Each row is normalized by its own sum, so it need not sum to 1.
 *
 * @param numSamples Number of indices to draw per distribution; must be positive.
 * @param replacement When false (the default), a category is drawn at most once per row,
 *   so `numSamples` may not exceed the number of categories.
 * @returns An `int32` tensor of category indices on the same device as this tensor.
 * @throws Error for 0D or >2D input, non-positive `numSamples`, negative probabilities,
 *   or a row whose probabilities do not sum to a positive value.
 */
multinomial(numSamples: number, replacement?: boolean): Tensor;
|
|
204
205
|
/**
 * Creates a new tensor of the given `shape` filled with `num`.
 * @param options Extra tensor options (NOTE(review): which fields apply is defined by
 *   TensorOptions, not visible here).
 */
static full(shape: number[], num: number, options?: TensorOptions): Tensor;
|
|
205
206
|
/**
 * NOTE(review): implementation not in this diff; presumably creates a tensor with the
 * same shape as `tensor`, filled with `num` — confirm against core.js.
 */
static fullLike(tensor: Tensor, num: number, options?: TensorOptions): Tensor;
|
|
206
207
|
/**
 * NOTE(review): implementation not in this diff; presumably creates a tensor of `shape`
 * filled with 1 (behavior when `shape` is omitted is unconfirmed).
 */
static ones(shape?: number[], options?: TensorOptions): Tensor;
|
|
@@ -223,6 +224,7 @@ export declare class Tensor {
|
|
|
223
224
|
zeroGrad?: boolean;
|
|
224
225
|
}): void;
|
|
225
226
|
/**
 * Returns the tensor's contents as a plain TensorValue, built from the underlying
 * buffer using this tensor's shape, strides and offset. `toString` treats a `number`
 * result as a scalar; otherwise the result is presumably nested arrays (confirm).
 */
val(): TensorValue;
|
|
227
|
+
/**
 * Returns a PyTorch-style string such as `tensor([[1., 2.],\n        [3., 4.]])`.
 * Integers below 1e8 in magnitude get a trailing dot, all elements are right-aligned
 * to a common width, and higher-rank blocks are separated by blank lines.
 */
toString(): string;
|
|
226
228
|
/**
 * Returns a new Tensor that shares this tensor's underlying value buffer (a view)
 * with gradient tracking turned off, detached from autograd.
 */
detach(): Tensor;
|
|
227
229
|
/**
 * NOTE(review): implementation not in this diff; presumably returns a copy that does
 * not share the value buffer with this tensor — confirm against core.js.
 */
clone(): Tensor;
|
|
228
230
|
/**
 * NOTE(review): implementation not in this diff; presumably replaces this tensor's
 * contents with `other` (Tensor or raw TensorValue) — confirm semantics and whether
 * it mutates in place.
 */
replace(other: Tensor | TensorValue): Tensor;
|
package/dist/core.js
CHANGED
|
@@ -949,14 +949,14 @@ class Tensor {
|
|
|
949
949
|
// Copy if not contiguous
|
|
950
950
|
const outputSize = this.numel;
|
|
951
951
|
const outputShape = this.shape;
|
|
952
|
-
|
|
952
|
+
const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
|
|
953
|
+
const outputStrides = Tensor.getStrides(outputShape);
|
|
953
954
|
if (this.isContiguous()) {
|
|
954
|
-
|
|
955
|
-
|
|
955
|
+
// Fast path: direct copy
|
|
956
|
+
outputValue.set(this.value.subarray(this.offset, this.offset + outputSize));
|
|
956
957
|
}
|
|
957
958
|
else {
|
|
958
|
-
|
|
959
|
-
outputStrides = Tensor.getStrides(outputShape);
|
|
959
|
+
// Slow path: coordinate conversion
|
|
960
960
|
for (let flatIndex = 0; flatIndex < outputSize; flatIndex++) {
|
|
961
961
|
const coords = Tensor.indexToCoords(flatIndex, outputStrides);
|
|
962
962
|
const originalIndex = Tensor.coordsToIndex(coords, this.strides);
|
|
@@ -1912,6 +1912,129 @@ class Tensor {
|
|
|
1912
1912
|
mask = this.handleOther(mask);
|
|
1913
1913
|
return this.mul(mask.logicalNot()).add(mask.mul(value));
|
|
1914
1914
|
}
|
|
1915
|
+
// Multinomial sampling
|
|
1916
|
+
multinomial(numSamples, replacement = false) {
|
|
1917
|
+
// Validate input dimensions (1D or 2D only)
|
|
1918
|
+
if (this.shape.length === 0 || this.shape.length > 2) {
|
|
1919
|
+
throw new Error("multinomial only supports 1D or 2D probability tensors");
|
|
1920
|
+
}
|
|
1921
|
+
const is1D = this.shape.length === 1;
|
|
1922
|
+
const numDist = is1D ? 1 : this.shape[0];
|
|
1923
|
+
const numCategories = is1D ? this.shape[0] : this.shape[1];
|
|
1924
|
+
// Validate numSamples
|
|
1925
|
+
if (numSamples <= 0) {
|
|
1926
|
+
throw new Error("Number of samples must be positive");
|
|
1927
|
+
}
|
|
1928
|
+
if (!replacement && numSamples > numCategories) {
|
|
1929
|
+
throw new Error(`Cannot sample ${numSamples} without replacement from ${numCategories} categories`);
|
|
1930
|
+
}
|
|
1931
|
+
// Make contiguous copy of probabilities
|
|
1932
|
+
const probsSize = this.numel;
|
|
1933
|
+
const probs = new dtype_1.TypedArray[this.dtype](probsSize);
|
|
1934
|
+
if (this.isContiguous()) {
|
|
1935
|
+
// Fast path: direct copy
|
|
1936
|
+
probs.set(this.value.subarray(this.offset, this.offset + probsSize));
|
|
1937
|
+
}
|
|
1938
|
+
else {
|
|
1939
|
+
// Slow path: coordinate conversion
|
|
1940
|
+
const defaultStrides = Tensor.getStrides(this.shape);
|
|
1941
|
+
for (let i = 0; i < probsSize; i++) {
|
|
1942
|
+
const coords = Tensor.indexToCoords(i, defaultStrides);
|
|
1943
|
+
const idx = Tensor.coordsToIndex(coords, this.strides);
|
|
1944
|
+
probs[i] = this.value[idx + this.offset];
|
|
1945
|
+
}
|
|
1946
|
+
}
|
|
1947
|
+
// Output setup
|
|
1948
|
+
const outputShape = is1D ? [numSamples] : [numDist, numSamples];
|
|
1949
|
+
const outputValue = new Int32Array(numDist * numSamples);
|
|
1950
|
+
// Sample from each distribution
|
|
1951
|
+
for (let dist = 0; dist < numDist; dist++) {
|
|
1952
|
+
const offset = dist * numCategories;
|
|
1953
|
+
// Extract this distribution's probabilities
|
|
1954
|
+
const distProbs = probs.slice(offset, offset + numCategories);
|
|
1955
|
+
// Validate and normalize
|
|
1956
|
+
let sum = 0;
|
|
1957
|
+
for (let i = 0; i < numCategories; i++) {
|
|
1958
|
+
if (distProbs[i] < 0) {
|
|
1959
|
+
throw new Error("Probabilities cannot be negative");
|
|
1960
|
+
}
|
|
1961
|
+
sum += distProbs[i];
|
|
1962
|
+
}
|
|
1963
|
+
if (sum <= 0) {
|
|
1964
|
+
throw new Error("Probabilities must sum to a positive value");
|
|
1965
|
+
}
|
|
1966
|
+
// Normalize
|
|
1967
|
+
for (let i = 0; i < numCategories; i++) {
|
|
1968
|
+
distProbs[i] /= sum;
|
|
1969
|
+
}
|
|
1970
|
+
if (replacement) {
|
|
1971
|
+
// With replacement: use CDF for efficient sampling
|
|
1972
|
+
const cdf = new Array(numCategories);
|
|
1973
|
+
let cumSum = 0;
|
|
1974
|
+
for (let i = 0; i < numCategories; i++) {
|
|
1975
|
+
cumSum += distProbs[i];
|
|
1976
|
+
cdf[i] = cumSum;
|
|
1977
|
+
}
|
|
1978
|
+
cdf[numCategories - 1] = 1;
|
|
1979
|
+
for (let s = 0; s < numSamples; s++) {
|
|
1980
|
+
const r = Math.random();
|
|
1981
|
+
// Binary search for efficiency
|
|
1982
|
+
let left = 0;
|
|
1983
|
+
let right = numCategories - 1;
|
|
1984
|
+
while (left < right) {
|
|
1985
|
+
const mid = Math.floor((left + right) / 2);
|
|
1986
|
+
if (r <= cdf[mid]) {
|
|
1987
|
+
right = mid;
|
|
1988
|
+
}
|
|
1989
|
+
else {
|
|
1990
|
+
left = mid + 1;
|
|
1991
|
+
}
|
|
1992
|
+
}
|
|
1993
|
+
outputValue[dist * numSamples + s] = left;
|
|
1994
|
+
}
|
|
1995
|
+
}
|
|
1996
|
+
else {
|
|
1997
|
+
// Without replacement: weighted sampling without replacement
|
|
1998
|
+
const available = Array.from({ length: numCategories }, (_, i) => ({
|
|
1999
|
+
idx: i,
|
|
2000
|
+
prob: distProbs[i]
|
|
2001
|
+
}));
|
|
2002
|
+
for (let s = 0; s < numSamples; s++) {
|
|
2003
|
+
// Compute sum of remaining probabilities
|
|
2004
|
+
let remainingSum = 0;
|
|
2005
|
+
for (const item of available) {
|
|
2006
|
+
remainingSum += item.prob;
|
|
2007
|
+
}
|
|
2008
|
+
// Sample from remaining
|
|
2009
|
+
const r = Math.random() * remainingSum;
|
|
2010
|
+
let cumSum = 0;
|
|
2011
|
+
let selectedIdx = -1;
|
|
2012
|
+
for (let i = 0; i < available.length; i++) {
|
|
2013
|
+
cumSum += available[i].prob;
|
|
2014
|
+
if (r <= cumSum) {
|
|
2015
|
+
selectedIdx = i;
|
|
2016
|
+
break;
|
|
2017
|
+
}
|
|
2018
|
+
}
|
|
2019
|
+
// Handle floating point edge case
|
|
2020
|
+
if (selectedIdx === -1) {
|
|
2021
|
+
selectedIdx = available.length - 1;
|
|
2022
|
+
}
|
|
2023
|
+
// Store result and remove from available
|
|
2024
|
+
outputValue[dist * numSamples + s] = available[selectedIdx].idx;
|
|
2025
|
+
available.splice(selectedIdx, 1);
|
|
2026
|
+
}
|
|
2027
|
+
}
|
|
2028
|
+
}
|
|
2029
|
+
return new Tensor(outputValue, {
|
|
2030
|
+
shape: outputShape,
|
|
2031
|
+
strides: Tensor.getStrides(outputShape),
|
|
2032
|
+
offset: 0,
|
|
2033
|
+
numel: numDist * numSamples,
|
|
2034
|
+
device: this.device,
|
|
2035
|
+
dtype: "int32"
|
|
2036
|
+
});
|
|
2037
|
+
}
|
|
1915
2038
|
// Utility to create a new tensor filled with a number
|
|
1916
2039
|
static full(shape, num, options = {}) {
|
|
1917
2040
|
if (shape.length === 0)
|
|
@@ -2094,6 +2217,7 @@ class Tensor {
|
|
|
2094
2217
|
shape,
|
|
2095
2218
|
offset: 0,
|
|
2096
2219
|
numel: outputSize,
|
|
2220
|
+
dtype: "int32",
|
|
2097
2221
|
...options
|
|
2098
2222
|
});
|
|
2099
2223
|
}
|
|
@@ -2130,6 +2254,7 @@ class Tensor {
|
|
|
2130
2254
|
shape: [n],
|
|
2131
2255
|
offset: 0,
|
|
2132
2256
|
numel: n,
|
|
2257
|
+
dtype: "int32",
|
|
2133
2258
|
...options
|
|
2134
2259
|
});
|
|
2135
2260
|
}
|
|
@@ -2271,6 +2396,53 @@ class Tensor {
|
|
|
2271
2396
|
}
|
|
2272
2397
|
return buildNested(this.value, this.shape, this.strides, this.offset);
|
|
2273
2398
|
}
|
|
2399
|
+
// Returns the nicely Pytorch-like formatted string form
|
|
2400
|
+
toString() {
|
|
2401
|
+
const val = this.val();
|
|
2402
|
+
// Format a single number (integers get trailing dot)
|
|
2403
|
+
const formatNum = (n) => {
|
|
2404
|
+
if (Number.isInteger(n) && Math.abs(n) < 1e8) {
|
|
2405
|
+
return n.toFixed(0) + ".";
|
|
2406
|
+
}
|
|
2407
|
+
return n.toString();
|
|
2408
|
+
};
|
|
2409
|
+
// Handle scalar
|
|
2410
|
+
if (typeof val === "number") {
|
|
2411
|
+
return `tensor(${formatNum(val)})`;
|
|
2412
|
+
}
|
|
2413
|
+
// Collect all numbers to find max width for alignment
|
|
2414
|
+
const collectNumbers = (v) => {
|
|
2415
|
+
if (typeof v === "number")
|
|
2416
|
+
return [v];
|
|
2417
|
+
return v.flatMap(collectNumbers);
|
|
2418
|
+
};
|
|
2419
|
+
const allNumbers = collectNumbers(val);
|
|
2420
|
+
const maxWidth = Math.max(...allNumbers.map((n) => formatNum(n).length));
|
|
2421
|
+
const ndim = this.shape.length;
|
|
2422
|
+
const baseIndent = "tensor(".length; // 7
|
|
2423
|
+
const formatNested = (v, depth) => {
|
|
2424
|
+
if (typeof v === "number") {
|
|
2425
|
+
return formatNum(v).padStart(maxWidth);
|
|
2426
|
+
}
|
|
2427
|
+
const arr = v;
|
|
2428
|
+
// Innermost dimension: format as single line [x, y, z]
|
|
2429
|
+
if (arr.length > 0 && typeof arr[0] === "number") {
|
|
2430
|
+
const elements = arr.map((x) => formatNum(x).padStart(maxWidth));
|
|
2431
|
+
return `[${elements.join(", ")}]`;
|
|
2432
|
+
}
|
|
2433
|
+
// Number of blank lines between elements at this depth
|
|
2434
|
+
// Deeper = fewer blank lines (0 between rows, 1 between 2D blocks, etc.)
|
|
2435
|
+
const blankLines = Math.max(0, ndim - depth - 2);
|
|
2436
|
+
const separator = ",\n" + "\n".repeat(blankLines);
|
|
2437
|
+
const innerIndent = " ".repeat(baseIndent + depth + 1);
|
|
2438
|
+
const formatted = arr.map((item, i) => {
|
|
2439
|
+
const str = formatNested(item, depth + 1);
|
|
2440
|
+
return i === 0 ? "[" + str : innerIndent + str;
|
|
2441
|
+
});
|
|
2442
|
+
return formatted.join(separator) + "]";
|
|
2443
|
+
};
|
|
2444
|
+
return "tensor(" + formatNested(val, 0) + ")";
|
|
2445
|
+
}
|
|
2274
2446
|
// Returns a view of the tensor with gradient turned off and detaches from autograd
|
|
2275
2447
|
detach() {
|
|
2276
2448
|
return new Tensor(this.value, {
|