catniff 0.8.3 → 0.8.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +1 -0
- package/dist/core.js +130 -5
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
|
@@ -201,6 +201,7 @@ export declare class Tensor {
|
|
|
201
201
|
triu(diagonal?: number): Tensor;
|
|
202
202
|
tril(diagonal?: number): Tensor;
|
|
203
203
|
maskedFill(mask: Tensor | TensorValue, value: number): Tensor;
|
|
204
|
+
multinomial(numSamples: number, replacement?: boolean): Tensor;
|
|
204
205
|
static full(shape: number[], num: number, options?: TensorOptions): Tensor;
|
|
205
206
|
static fullLike(tensor: Tensor, num: number, options?: TensorOptions): Tensor;
|
|
206
207
|
static ones(shape?: number[], options?: TensorOptions): Tensor;
|
package/dist/core.js
CHANGED
|
@@ -949,14 +949,14 @@ class Tensor {
|
|
|
949
949
|
// Copy if not contiguous
|
|
950
950
|
const outputSize = this.numel;
|
|
951
951
|
const outputShape = this.shape;
|
|
952
|
-
|
|
952
|
+
const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
|
|
953
|
+
const outputStrides = Tensor.getStrides(outputShape);
|
|
953
954
|
if (this.isContiguous()) {
|
|
954
|
-
|
|
955
|
-
|
|
955
|
+
// Fast path: direct copy
|
|
956
|
+
outputValue.set(this.value.subarray(this.offset, this.offset + outputSize));
|
|
956
957
|
}
|
|
957
958
|
else {
|
|
958
|
-
|
|
959
|
-
outputStrides = Tensor.getStrides(outputShape);
|
|
959
|
+
// Slow path: coordinate conversion
|
|
960
960
|
for (let flatIndex = 0; flatIndex < outputSize; flatIndex++) {
|
|
961
961
|
const coords = Tensor.indexToCoords(flatIndex, outputStrides);
|
|
962
962
|
const originalIndex = Tensor.coordsToIndex(coords, this.strides);
|
|
@@ -1912,6 +1912,129 @@ class Tensor {
|
|
|
1912
1912
|
mask = this.handleOther(mask);
|
|
1913
1913
|
return this.mul(mask.logicalNot()).add(mask.mul(value));
|
|
1914
1914
|
}
|
|
1915
|
+
// Multinomial sampling
|
|
1916
|
+
multinomial(numSamples, replacement = false) {
|
|
1917
|
+
// Validate input dimensions (1D or 2D only)
|
|
1918
|
+
if (this.shape.length === 0 || this.shape.length > 2) {
|
|
1919
|
+
throw new Error("multinomial only supports 1D or 2D probability tensors");
|
|
1920
|
+
}
|
|
1921
|
+
const is1D = this.shape.length === 1;
|
|
1922
|
+
const numDist = is1D ? 1 : this.shape[0];
|
|
1923
|
+
const numCategories = is1D ? this.shape[0] : this.shape[1];
|
|
1924
|
+
// Validate numSamples
|
|
1925
|
+
if (numSamples <= 0) {
|
|
1926
|
+
throw new Error("Number of samples must be positive");
|
|
1927
|
+
}
|
|
1928
|
+
if (!replacement && numSamples > numCategories) {
|
|
1929
|
+
throw new Error(`Cannot sample ${numSamples} without replacement from ${numCategories} categories`);
|
|
1930
|
+
}
|
|
1931
|
+
// Make contiguous copy of probabilities
|
|
1932
|
+
const probsSize = this.numel;
|
|
1933
|
+
const probs = new dtype_1.TypedArray[this.dtype](probsSize);
|
|
1934
|
+
if (this.isContiguous()) {
|
|
1935
|
+
// Fast path: direct copy
|
|
1936
|
+
probs.set(this.value.subarray(this.offset, this.offset + probsSize));
|
|
1937
|
+
}
|
|
1938
|
+
else {
|
|
1939
|
+
// Slow path: coordinate conversion
|
|
1940
|
+
const defaultStrides = Tensor.getStrides(this.shape);
|
|
1941
|
+
for (let i = 0; i < probsSize; i++) {
|
|
1942
|
+
const coords = Tensor.indexToCoords(i, defaultStrides);
|
|
1943
|
+
const idx = Tensor.coordsToIndex(coords, this.strides);
|
|
1944
|
+
probs[i] = this.value[idx + this.offset];
|
|
1945
|
+
}
|
|
1946
|
+
}
|
|
1947
|
+
// Output setup
|
|
1948
|
+
const outputShape = is1D ? [numSamples] : [numDist, numSamples];
|
|
1949
|
+
const outputValue = new Int32Array(numDist * numSamples);
|
|
1950
|
+
// Sample from each distribution
|
|
1951
|
+
for (let dist = 0; dist < numDist; dist++) {
|
|
1952
|
+
const offset = dist * numCategories;
|
|
1953
|
+
// Extract this distribution's probabilities
|
|
1954
|
+
const distProbs = probs.slice(offset, offset + numCategories);
|
|
1955
|
+
// Validate and normalize
|
|
1956
|
+
let sum = 0;
|
|
1957
|
+
for (let i = 0; i < numCategories; i++) {
|
|
1958
|
+
if (distProbs[i] < 0) {
|
|
1959
|
+
throw new Error("Probabilities cannot be negative");
|
|
1960
|
+
}
|
|
1961
|
+
sum += distProbs[i];
|
|
1962
|
+
}
|
|
1963
|
+
if (sum <= 0) {
|
|
1964
|
+
throw new Error("Probabilities must sum to a positive value");
|
|
1965
|
+
}
|
|
1966
|
+
// Normalize
|
|
1967
|
+
for (let i = 0; i < numCategories; i++) {
|
|
1968
|
+
distProbs[i] /= sum;
|
|
1969
|
+
}
|
|
1970
|
+
if (replacement) {
|
|
1971
|
+
// With replacement: use CDF for efficient sampling
|
|
1972
|
+
const cdf = new Array(numCategories);
|
|
1973
|
+
let cumSum = 0;
|
|
1974
|
+
for (let i = 0; i < numCategories; i++) {
|
|
1975
|
+
cumSum += distProbs[i];
|
|
1976
|
+
cdf[i] = cumSum;
|
|
1977
|
+
}
|
|
1978
|
+
cdf[numCategories - 1] = 1;
|
|
1979
|
+
for (let s = 0; s < numSamples; s++) {
|
|
1980
|
+
const r = Math.random();
|
|
1981
|
+
// Binary search for efficiency
|
|
1982
|
+
let left = 0;
|
|
1983
|
+
let right = numCategories - 1;
|
|
1984
|
+
while (left < right) {
|
|
1985
|
+
const mid = Math.floor((left + right) / 2);
|
|
1986
|
+
if (r <= cdf[mid]) {
|
|
1987
|
+
right = mid;
|
|
1988
|
+
}
|
|
1989
|
+
else {
|
|
1990
|
+
left = mid + 1;
|
|
1991
|
+
}
|
|
1992
|
+
}
|
|
1993
|
+
outputValue[dist * numSamples + s] = left;
|
|
1994
|
+
}
|
|
1995
|
+
}
|
|
1996
|
+
else {
|
|
1997
|
+
// Without replacement: weighted sampling without replacement
|
|
1998
|
+
const available = Array.from({ length: numCategories }, (_, i) => ({
|
|
1999
|
+
idx: i,
|
|
2000
|
+
prob: distProbs[i]
|
|
2001
|
+
}));
|
|
2002
|
+
for (let s = 0; s < numSamples; s++) {
|
|
2003
|
+
// Compute sum of remaining probabilities
|
|
2004
|
+
let remainingSum = 0;
|
|
2005
|
+
for (const item of available) {
|
|
2006
|
+
remainingSum += item.prob;
|
|
2007
|
+
}
|
|
2008
|
+
// Sample from remaining
|
|
2009
|
+
const r = Math.random() * remainingSum;
|
|
2010
|
+
let cumSum = 0;
|
|
2011
|
+
let selectedIdx = -1;
|
|
2012
|
+
for (let i = 0; i < available.length; i++) {
|
|
2013
|
+
cumSum += available[i].prob;
|
|
2014
|
+
if (r <= cumSum) {
|
|
2015
|
+
selectedIdx = i;
|
|
2016
|
+
break;
|
|
2017
|
+
}
|
|
2018
|
+
}
|
|
2019
|
+
// Handle floating point edge case
|
|
2020
|
+
if (selectedIdx === -1) {
|
|
2021
|
+
selectedIdx = available.length - 1;
|
|
2022
|
+
}
|
|
2023
|
+
// Store result and remove from available
|
|
2024
|
+
outputValue[dist * numSamples + s] = available[selectedIdx].idx;
|
|
2025
|
+
available.splice(selectedIdx, 1);
|
|
2026
|
+
}
|
|
2027
|
+
}
|
|
2028
|
+
}
|
|
2029
|
+
return new Tensor(outputValue, {
|
|
2030
|
+
shape: outputShape,
|
|
2031
|
+
strides: Tensor.getStrides(outputShape),
|
|
2032
|
+
offset: 0,
|
|
2033
|
+
numel: numDist * numSamples,
|
|
2034
|
+
device: this.device,
|
|
2035
|
+
dtype: "int32"
|
|
2036
|
+
});
|
|
2037
|
+
}
|
|
1915
2038
|
// Utility to create a new tensor filled with a number
|
|
1916
2039
|
static full(shape, num, options = {}) {
|
|
1917
2040
|
if (shape.length === 0)
|
|
@@ -2094,6 +2217,7 @@ class Tensor {
|
|
|
2094
2217
|
shape,
|
|
2095
2218
|
offset: 0,
|
|
2096
2219
|
numel: outputSize,
|
|
2220
|
+
dtype: "int32",
|
|
2097
2221
|
...options
|
|
2098
2222
|
});
|
|
2099
2223
|
}
|
|
@@ -2130,6 +2254,7 @@ class Tensor {
|
|
|
2130
2254
|
shape: [n],
|
|
2131
2255
|
offset: 0,
|
|
2132
2256
|
numel: n,
|
|
2257
|
+
dtype: "int32",
|
|
2133
2258
|
...options
|
|
2134
2259
|
});
|
|
2135
2260
|
}
|