catniff 0.8.10 → 0.8.12

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
package/dist/core.d.ts CHANGED
@@ -206,6 +206,7 @@ export declare class Tensor {
206
206
  bmm(other: TensorValue | Tensor): Tensor;
207
207
  mv(other: TensorValue | Tensor): Tensor;
208
208
  matmul(other: TensorValue | Tensor): Tensor;
209
+ tensordot(other: TensorValue | Tensor, axes?: number | [number, number] | [number[], number[]]): Tensor;
209
210
  dropout(rate: number): Tensor;
210
211
  triu(diagonal?: number): Tensor;
211
212
  tril(diagonal?: number): Tensor;
package/dist/core.js CHANGED
@@ -1688,23 +1688,22 @@ class Tensor {
1688
1688
  if (matACols !== matBRows)
1689
1689
  throw new Error("Invalid matrices shape for multiplication");
1690
1690
  const matCDtype = Tensor.getResultDtype(this.dtype, other.dtype);
1691
- const matCShape = [matARows, matBCols];
1692
- const matCStrides = Tensor.getStrides(matCShape);
1693
- const matCSize = Tensor.shapeToSize(matCShape);
1691
+ const matCSize = matARows * matBCols;
1694
1692
  const matC = new dtype_1.TypedArray[matCDtype](matCSize).fill(0);
1695
1693
  for (let i = 0; i < matARows; i++) {
1696
- for (let j = 0; j < matBCols; j++) {
1697
- for (let k = 0; k < matACols; k++) {
1698
- // Tensor values are 1D arrays so we have to get real index using strides
1699
- matC[i * matCStrides[0] + j * matCStrides[1]] +=
1700
- matA[i * matAStrides[0] + k * matAStrides[1] + this.offset] *
1701
- matB[k * matBStrides[0] + j * matBStrides[1] + other.offset];
1694
+ const aRowOffset = i * matAStrides[0] + this.offset;
1695
+ const cRowOffset = i * matBCols;
1696
+ for (let k = 0; k < matACols; k++) {
1697
+ const aVal = matA[aRowOffset + k * matAStrides[1]];
1698
+ const bRowOffset = k * matBStrides[0] + other.offset;
1699
+ for (let j = 0; j < matBCols; j++) {
1700
+ matC[cRowOffset + j] += aVal * matB[bRowOffset + j * matBStrides[1]];
1702
1701
  }
1703
1702
  }
1704
1703
  }
1705
1704
  const out = new Tensor(matC, {
1706
- shape: matCShape,
1707
- strides: matCStrides,
1705
+ shape: [matARows, matBCols],
1706
+ strides: [matBCols, 1],
1708
1707
  offset: 0,
1709
1708
  numel: matCSize,
1710
1709
  device: this.device,
@@ -1751,25 +1750,28 @@ class Tensor {
1751
1750
  if (batchACols !== batchBRows)
1752
1751
  throw new Error("Invalid matrices shape for multiplication");
1753
1752
  const batchCDtype = Tensor.getResultDtype(this.dtype, other.dtype);
1754
- const batchCShape = [batchSize, batchARows, batchBCols];
1755
- const batchCStrides = Tensor.getStrides(batchCShape);
1756
- const batchCSize = Tensor.shapeToSize(batchCShape);
1753
+ const matrixSize = batchARows * batchBCols;
1754
+ const batchCSize = batchSize * matrixSize;
1757
1755
  const batchC = new dtype_1.TypedArray[batchCDtype](batchCSize).fill(0);
1758
1756
  for (let q = 0; q < batchSize; q++) {
1757
+ const aQOffset = q * batchAStrides[0] + this.offset;
1758
+ const bQOffset = q * batchBStrides[0] + other.offset;
1759
+ const cQOffset = q * matrixSize;
1759
1760
  for (let i = 0; i < batchARows; i++) {
1760
- for (let j = 0; j < batchBCols; j++) {
1761
- for (let k = 0; k < batchACols; k++) {
1762
- // Tensor values are 1D arrays so we have to get real index using strides
1763
- batchC[q * batchCStrides[0] + i * batchCStrides[1] + j * batchCStrides[2]] +=
1764
- batchA[q * batchAStrides[0] + i * batchAStrides[1] + k * batchAStrides[2] + this.offset] *
1765
- batchB[q * batchBStrides[0] + k * batchBStrides[1] + j * batchBStrides[2] + other.offset];
1761
+ const aRowOffset = aQOffset + i * batchAStrides[1];
1762
+ const cRowOffset = cQOffset + i * batchBCols;
1763
+ for (let k = 0; k < batchACols; k++) {
1764
+ const aVal = batchA[aRowOffset + k * batchAStrides[2]];
1765
+ const bRowOffset = bQOffset + k * batchBStrides[1];
1766
+ for (let j = 0; j < batchBCols; j++) {
1767
+ batchC[cRowOffset + j] += aVal * batchB[bRowOffset + j * batchBStrides[2]];
1766
1768
  }
1767
1769
  }
1768
1770
  }
1769
1771
  }
1770
1772
  const out = new Tensor(batchC, {
1771
- shape: batchCShape,
1772
- strides: batchCStrides,
1773
+ shape: [batchSize, batchARows, batchBCols],
1774
+ strides: [matrixSize, batchBCols, 1],
1773
1775
  offset: 0,
1774
1776
  numel: batchCSize,
1775
1777
  device: this.device,
@@ -1858,18 +1860,26 @@ class Tensor {
1858
1860
  const outputValue = new dtype_1.TypedArray[outputDtype](outputSize).fill(0);
1859
1861
  const outputOffsetStrides = outputStrides.slice(0, -2);
1860
1862
  // Loop through outer dims and do matmul on two outer-most dims
1863
+ const outputRowStride = outputStrides[lastDim - 1];
1864
+ const outputColStride = outputStrides[lastDim];
1865
+ const selfRowStride = selfStrides[lastDim - 1];
1866
+ const selfColStride = selfStrides[lastDim];
1867
+ const otherRowStride = otherStrides[lastDim - 1];
1868
+ const otherColStride = otherStrides[lastDim];
1861
1869
  for (let index = 0; index < offsetSize; index++) {
1862
1870
  const coords = Tensor.indexToCoords(index, offsetStrides);
1863
- const offset = Tensor.coordsToIndex(coords, outputOffsetStrides);
1864
- const selfOffset = Tensor.coordsToUnbroadcastedIndex(coords, selfOffsetShape, selfOffsetStrides);
1865
- const otherOffset = Tensor.coordsToUnbroadcastedIndex(coords, otherOffsetShape, otherOffsetStrides);
1871
+ const outBatchOffset = Tensor.coordsToIndex(coords, outputOffsetStrides);
1872
+ const selfBatchOffset = Tensor.coordsToUnbroadcastedIndex(coords, selfOffsetShape, selfOffsetStrides) + this.offset;
1873
+ const otherBatchOffset = Tensor.coordsToUnbroadcastedIndex(coords, otherOffsetShape, otherOffsetStrides) + other.offset;
1866
1874
  for (let i = 0; i < batchARows; i++) {
1867
- for (let j = 0; j < batchBCols; j++) {
1868
- for (let k = 0; k < batchACols; k++) {
1869
- const outputIdx = offset + i * outputStrides[lastDim - 1] + j * outputStrides[lastDim];
1870
- const selfIdx = selfOffset + i * selfStrides[lastDim - 1] + k * selfStrides[lastDim];
1871
- const otherIdx = otherOffset + k * otherStrides[lastDim - 1] + j * otherStrides[lastDim];
1872
- outputValue[outputIdx] += batchA[selfIdx + this.offset] * batchB[otherIdx + other.offset];
1875
+ const selfRowOffset = selfBatchOffset + i * selfRowStride;
1876
+ const outRowOffset = outBatchOffset + i * outputRowStride;
1877
+ for (let k = 0; k < batchACols; k++) {
1878
+ const aVal = batchA[selfRowOffset + k * selfColStride];
1879
+ const otherRowOffset = otherBatchOffset + k * otherRowStride;
1880
+ for (let j = 0; j < batchBCols; j++) {
1881
+ outputValue[outRowOffset + j * outputColStride] +=
1882
+ aVal * batchB[otherRowOffset + j * otherColStride];
1873
1883
  }
1874
1884
  }
1875
1885
  }
@@ -1906,6 +1916,80 @@ class Tensor {
1906
1916
  }
1907
1917
  throw new Error(`Shapes [${this.shape}] and [${other.shape}] are not supported`);
1908
1918
  }
1919
+ // General tensor dot product
1920
+ tensordot(other, axes = 2) {
1921
+ other = this.handleOther(other);
1922
+ let axesA, axesB;
1923
+ // If axes is a number
1924
+ if (typeof axes === "number") {
1925
+ axesA = new Array(axes);
1926
+ axesB = new Array(axes);
1927
+ for (let i = 0; i < axes; i++) {
1928
+ axesA[i] = this.shape.length - axes + i;
1929
+ axesB[i] = i;
1930
+ }
1931
+ }
1932
+ // If axes is a pair of numbers or a pair of axes
1933
+ else {
1934
+ // axes is [axesA, axesB]
1935
+ [axesA, axesB] = axes;
1936
+ // Convert single numbers to arrays
1937
+ if (typeof axesA === "number")
1938
+ axesA = [axesA];
1939
+ if (typeof axesB === "number")
1940
+ axesB = [axesB];
1941
+ }
1942
+ // Normalize axes
1943
+ axesA = Tensor.normalizeDims(axesA, this.shape.length);
1944
+ axesB = Tensor.normalizeDims(axesB, other.shape.length);
1945
+ // Validate axes
1946
+ if (axesA.length !== axesB.length) {
1947
+ throw new Error("Number of axes to contract must be the same for both tensors");
1948
+ }
1949
+ // Validate dimensions match
1950
+ for (let i = 0; i < axesA.length; i++) {
1951
+ if (this.shape[axesA[i]] !== other.shape[axesB[i]]) {
1952
+ throw new Error(`Dimension mismatch: a.shape[${axesA[i]}]=${this.shape[axesA[i]]} ` +
1953
+ `!= b.shape[${axesB[i]}]=${other.shape[axesB[i]]}`);
1954
+ }
1955
+ }
1956
+ // Identify free (non-contracted) axes
1957
+ const freeA = [];
1958
+ for (let i = 0; i < this.shape.length; i++) {
1959
+ if (!axesA.includes(i)) {
1960
+ freeA.push(i);
1961
+ }
1962
+ }
1963
+ const freeB = [];
1964
+ for (let i = 0; i < other.shape.length; i++) {
1965
+ if (!axesB.includes(i)) {
1966
+ freeB.push(i);
1967
+ }
1968
+ }
1969
+ // Permute a to move contracted axes to the end: [free_a, contract_a]
1970
+ const aPermuted = this.permute([...freeA, ...axesA]);
1971
+ // Permute b to move contracted axes to the beginning: [contract_b, free_b]
1972
+ const bPermuted = other.permute([...axesB, ...freeB]);
1973
+ // Reshape a to 2D matrix
1974
+ // Shape: [product of free_a dims, product of contract_a dims]
1975
+ const freeASize = freeA.reduce((prod, axis) => prod * this.shape[axis], 1);
1976
+ const contractASize = axesA.reduce((prod, axis) => prod * this.shape[axis], 1);
1977
+ const aReshaped = aPermuted.reshape([freeASize, contractASize]);
1978
+ // Reshape b to 2D matrix
1979
+ // Shape: [product of contract_b dims, product of free_b dims]
1980
+ const contractBSize = axesB.reduce((prod, axis) => prod * other.shape[axis], 1);
1981
+ const freeBSize = freeB.reduce((prod, axis) => prod * other.shape[axis], 1);
1982
+ const bReshaped = bPermuted.reshape([contractBSize, freeBSize]);
1983
+ // Use normal matmul, result shape: [freeASize, freeBSize]
1984
+ const result2D = aReshaped.matmul(bReshaped);
1985
+ // Reshape result back to proper n-dimensional form
1986
+ // Final shape: [free_a_dims..., free_b_dims...]
1987
+ const finalShape = [
1988
+ ...freeA.map(i => this.shape[i]),
1989
+ ...freeB.map(i => other.shape[i])
1990
+ ];
1991
+ return result2D.reshape(finalShape);
1992
+ }
1909
1993
  // Dropout
1910
1994
  dropout(rate) {
1911
1995
  if (!Tensor.training || rate === 0)
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.8.10",
3
+ "version": "0.8.12",
4
4
  "description": "Torch-like deep learning framework for Javascript",
5
5
  "main": "index.js",
6
6
  "scripts": {