catniff 0.8.10 → 0.8.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +1 -0
- package/dist/core.js +115 -31
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
|
@@ -206,6 +206,7 @@ export declare class Tensor {
|
|
|
206
206
|
bmm(other: TensorValue | Tensor): Tensor;
|
|
207
207
|
mv(other: TensorValue | Tensor): Tensor;
|
|
208
208
|
matmul(other: TensorValue | Tensor): Tensor;
|
|
209
|
+
/**
 * Generalized tensor dot product (contraction) of this tensor with `other`.
 *
 * @param other - Right-hand operand; a tensor or a raw tensor value.
 * @param axes - Which axes to contract. A single number `n` contracts the last
 *   `n` axes of this tensor with the first `n` axes of `other` (the
 *   implementation defaults to `2`). A pair `[a, b]` contracts axis/axes `a`
 *   of this tensor with axis/axes `b` of `other`; each side may be a single
 *   axis or a list of axes, and both sides must name the same number of axes.
 * @returns A tensor shaped as the non-contracted axes of this tensor followed
 *   by the non-contracted axes of `other`.
 */
tensordot(other: TensorValue | Tensor, axes?: number | [number, number] | [number[], number[]]): Tensor;
|
|
209
210
|
dropout(rate: number): Tensor;
|
|
210
211
|
triu(diagonal?: number): Tensor;
|
|
211
212
|
tril(diagonal?: number): Tensor;
|
package/dist/core.js
CHANGED
|
@@ -1688,23 +1688,22 @@ class Tensor {
|
|
|
1688
1688
|
if (matACols !== matBRows)
|
|
1689
1689
|
throw new Error("Invalid matrices shape for multiplication");
|
|
1690
1690
|
const matCDtype = Tensor.getResultDtype(this.dtype, other.dtype);
|
|
1691
|
-
const
|
|
1692
|
-
const matCStrides = Tensor.getStrides(matCShape);
|
|
1693
|
-
const matCSize = Tensor.shapeToSize(matCShape);
|
|
1691
|
+
const matCSize = matARows * matBCols;
|
|
1694
1692
|
const matC = new dtype_1.TypedArray[matCDtype](matCSize).fill(0);
|
|
1695
1693
|
for (let i = 0; i < matARows; i++) {
|
|
1696
|
-
|
|
1697
|
-
|
|
1698
|
-
|
|
1699
|
-
|
|
1700
|
-
|
|
1701
|
-
|
|
1694
|
+
const aRowOffset = i * matAStrides[0] + this.offset;
|
|
1695
|
+
const cRowOffset = i * matBCols;
|
|
1696
|
+
for (let k = 0; k < matACols; k++) {
|
|
1697
|
+
const aVal = matA[aRowOffset + k * matAStrides[1]];
|
|
1698
|
+
const bRowOffset = k * matBStrides[0] + other.offset;
|
|
1699
|
+
for (let j = 0; j < matBCols; j++) {
|
|
1700
|
+
matC[cRowOffset + j] += aVal * matB[bRowOffset + j * matBStrides[1]];
|
|
1702
1701
|
}
|
|
1703
1702
|
}
|
|
1704
1703
|
}
|
|
1705
1704
|
const out = new Tensor(matC, {
|
|
1706
|
-
shape:
|
|
1707
|
-
strides:
|
|
1705
|
+
shape: [matARows, matBCols],
|
|
1706
|
+
strides: [matBCols, 1],
|
|
1708
1707
|
offset: 0,
|
|
1709
1708
|
numel: matCSize,
|
|
1710
1709
|
device: this.device,
|
|
@@ -1751,25 +1750,28 @@ class Tensor {
|
|
|
1751
1750
|
if (batchACols !== batchBRows)
|
|
1752
1751
|
throw new Error("Invalid matrices shape for multiplication");
|
|
1753
1752
|
const batchCDtype = Tensor.getResultDtype(this.dtype, other.dtype);
|
|
1754
|
-
const
|
|
1755
|
-
const
|
|
1756
|
-
const batchCSize = Tensor.shapeToSize(batchCShape);
|
|
1753
|
+
const matrixSize = batchARows * batchBCols;
|
|
1754
|
+
const batchCSize = batchSize * matrixSize;
|
|
1757
1755
|
const batchC = new dtype_1.TypedArray[batchCDtype](batchCSize).fill(0);
|
|
1758
1756
|
for (let q = 0; q < batchSize; q++) {
|
|
1757
|
+
const aQOffset = q * batchAStrides[0] + this.offset;
|
|
1758
|
+
const bQOffset = q * batchBStrides[0] + other.offset;
|
|
1759
|
+
const cQOffset = q * matrixSize;
|
|
1759
1760
|
for (let i = 0; i < batchARows; i++) {
|
|
1760
|
-
|
|
1761
|
-
|
|
1762
|
-
|
|
1763
|
-
|
|
1764
|
-
|
|
1765
|
-
|
|
1761
|
+
const aRowOffset = aQOffset + i * batchAStrides[1];
|
|
1762
|
+
const cRowOffset = cQOffset + i * batchBCols;
|
|
1763
|
+
for (let k = 0; k < batchACols; k++) {
|
|
1764
|
+
const aVal = batchA[aRowOffset + k * batchAStrides[2]];
|
|
1765
|
+
const bRowOffset = bQOffset + k * batchBStrides[1];
|
|
1766
|
+
for (let j = 0; j < batchBCols; j++) {
|
|
1767
|
+
batchC[cRowOffset + j] += aVal * batchB[bRowOffset + j * batchBStrides[2]];
|
|
1766
1768
|
}
|
|
1767
1769
|
}
|
|
1768
1770
|
}
|
|
1769
1771
|
}
|
|
1770
1772
|
const out = new Tensor(batchC, {
|
|
1771
|
-
shape:
|
|
1772
|
-
strides:
|
|
1773
|
+
shape: [batchSize, batchARows, batchBCols],
|
|
1774
|
+
strides: [matrixSize, batchBCols, 1],
|
|
1773
1775
|
offset: 0,
|
|
1774
1776
|
numel: batchCSize,
|
|
1775
1777
|
device: this.device,
|
|
@@ -1858,18 +1860,26 @@ class Tensor {
|
|
|
1858
1860
|
const outputValue = new dtype_1.TypedArray[outputDtype](outputSize).fill(0);
|
|
1859
1861
|
const outputOffsetStrides = outputStrides.slice(0, -2);
|
|
1860
1862
|
// Loop through the leading (batch) dims and do matmul on the two inner-most (last) dims
|
|
1863
|
+
const outputRowStride = outputStrides[lastDim - 1];
|
|
1864
|
+
const outputColStride = outputStrides[lastDim];
|
|
1865
|
+
const selfRowStride = selfStrides[lastDim - 1];
|
|
1866
|
+
const selfColStride = selfStrides[lastDim];
|
|
1867
|
+
const otherRowStride = otherStrides[lastDim - 1];
|
|
1868
|
+
const otherColStride = otherStrides[lastDim];
|
|
1861
1869
|
for (let index = 0; index < offsetSize; index++) {
|
|
1862
1870
|
const coords = Tensor.indexToCoords(index, offsetStrides);
|
|
1863
|
-
const
|
|
1864
|
-
const
|
|
1865
|
-
const
|
|
1871
|
+
const outBatchOffset = Tensor.coordsToIndex(coords, outputOffsetStrides);
|
|
1872
|
+
const selfBatchOffset = Tensor.coordsToUnbroadcastedIndex(coords, selfOffsetShape, selfOffsetStrides) + this.offset;
|
|
1873
|
+
const otherBatchOffset = Tensor.coordsToUnbroadcastedIndex(coords, otherOffsetShape, otherOffsetStrides) + other.offset;
|
|
1866
1874
|
for (let i = 0; i < batchARows; i++) {
|
|
1867
|
-
|
|
1868
|
-
|
|
1869
|
-
|
|
1870
|
-
|
|
1871
|
-
|
|
1872
|
-
|
|
1875
|
+
const selfRowOffset = selfBatchOffset + i * selfRowStride;
|
|
1876
|
+
const outRowOffset = outBatchOffset + i * outputRowStride;
|
|
1877
|
+
for (let k = 0; k < batchACols; k++) {
|
|
1878
|
+
const aVal = batchA[selfRowOffset + k * selfColStride];
|
|
1879
|
+
const otherRowOffset = otherBatchOffset + k * otherRowStride;
|
|
1880
|
+
for (let j = 0; j < batchBCols; j++) {
|
|
1881
|
+
outputValue[outRowOffset + j * outputColStride] +=
|
|
1882
|
+
aVal * batchB[otherRowOffset + j * otherColStride];
|
|
1873
1883
|
}
|
|
1874
1884
|
}
|
|
1875
1885
|
}
|
|
@@ -1906,6 +1916,80 @@ class Tensor {
|
|
|
1906
1916
|
}
|
|
1907
1917
|
throw new Error(`Shapes [${this.shape}] and [${other.shape}] are not supported`);
|
|
1908
1918
|
}
|
|
1919
|
+
// General tensor dot product
|
|
1920
|
+
tensordot(other, axes = 2) {
|
|
1921
|
+
other = this.handleOther(other);
|
|
1922
|
+
let axesA, axesB;
|
|
1923
|
+
// If axes is a number
|
|
1924
|
+
if (typeof axes === "number") {
|
|
1925
|
+
axesA = new Array(axes);
|
|
1926
|
+
axesB = new Array(axes);
|
|
1927
|
+
for (let i = 0; i < axes; i++) {
|
|
1928
|
+
axesA[i] = this.shape.length - axes + i;
|
|
1929
|
+
axesB[i] = i;
|
|
1930
|
+
}
|
|
1931
|
+
}
|
|
1932
|
+
// If axes is a pair of numbers or a pair of axes
|
|
1933
|
+
else {
|
|
1934
|
+
// axes is [axesA, axesB]
|
|
1935
|
+
[axesA, axesB] = axes;
|
|
1936
|
+
// Convert single numbers to arrays
|
|
1937
|
+
if (typeof axesA === "number")
|
|
1938
|
+
axesA = [axesA];
|
|
1939
|
+
if (typeof axesB === "number")
|
|
1940
|
+
axesB = [axesB];
|
|
1941
|
+
}
|
|
1942
|
+
// Normalize axes
|
|
1943
|
+
axesA = Tensor.normalizeDims(axesA, this.shape.length);
|
|
1944
|
+
axesB = Tensor.normalizeDims(axesB, other.shape.length);
|
|
1945
|
+
// Validate axes
|
|
1946
|
+
if (axesA.length !== axesB.length) {
|
|
1947
|
+
throw new Error("Number of axes to contract must be the same for both tensors");
|
|
1948
|
+
}
|
|
1949
|
+
// Validate dimensions match
|
|
1950
|
+
for (let i = 0; i < axesA.length; i++) {
|
|
1951
|
+
if (this.shape[axesA[i]] !== other.shape[axesB[i]]) {
|
|
1952
|
+
throw new Error(`Dimension mismatch: a.shape[${axesA[i]}]=${this.shape[axesA[i]]} ` +
|
|
1953
|
+
`!= b.shape[${axesB[i]}]=${other.shape[axesB[i]]}`);
|
|
1954
|
+
}
|
|
1955
|
+
}
|
|
1956
|
+
// Identify free (non-contracted) axes
|
|
1957
|
+
const freeA = [];
|
|
1958
|
+
for (let i = 0; i < this.shape.length; i++) {
|
|
1959
|
+
if (!axesA.includes(i)) {
|
|
1960
|
+
freeA.push(i);
|
|
1961
|
+
}
|
|
1962
|
+
}
|
|
1963
|
+
const freeB = [];
|
|
1964
|
+
for (let i = 0; i < other.shape.length; i++) {
|
|
1965
|
+
if (!axesB.includes(i)) {
|
|
1966
|
+
freeB.push(i);
|
|
1967
|
+
}
|
|
1968
|
+
}
|
|
1969
|
+
// Permute a to move contracted axes to the end: [free_a, contract_a]
|
|
1970
|
+
const aPermuted = this.permute([...freeA, ...axesA]);
|
|
1971
|
+
// Permute b to move contracted axes to the beginning: [contract_b, free_b]
|
|
1972
|
+
const bPermuted = other.permute([...axesB, ...freeB]);
|
|
1973
|
+
// Reshape a to 2D matrix
|
|
1974
|
+
// Shape: [product of free_a dims, product of contract_a dims]
|
|
1975
|
+
const freeASize = freeA.reduce((prod, axis) => prod * this.shape[axis], 1);
|
|
1976
|
+
const contractASize = axesA.reduce((prod, axis) => prod * this.shape[axis], 1);
|
|
1977
|
+
const aReshaped = aPermuted.reshape([freeASize, contractASize]);
|
|
1978
|
+
// Reshape b to 2D matrix
|
|
1979
|
+
// Shape: [product of contract_b dims, product of free_b dims]
|
|
1980
|
+
const contractBSize = axesB.reduce((prod, axis) => prod * other.shape[axis], 1);
|
|
1981
|
+
const freeBSize = freeB.reduce((prod, axis) => prod * other.shape[axis], 1);
|
|
1982
|
+
const bReshaped = bPermuted.reshape([contractBSize, freeBSize]);
|
|
1983
|
+
// Use normal matmul, result shape: [freeASize, freeBSize]
|
|
1984
|
+
const result2D = aReshaped.matmul(bReshaped);
|
|
1985
|
+
// Reshape result back to proper n-dimensional form
|
|
1986
|
+
// Final shape: [free_a_dims..., free_b_dims...]
|
|
1987
|
+
const finalShape = [
|
|
1988
|
+
...freeA.map(i => this.shape[i]),
|
|
1989
|
+
...freeB.map(i => other.shape[i])
|
|
1990
|
+
];
|
|
1991
|
+
return result2D.reshape(finalShape);
|
|
1992
|
+
}
|
|
1909
1993
|
// Dropout
|
|
1910
1994
|
dropout(rate) {
|
|
1911
1995
|
if (!Tensor.training || rate === 0)
|