catniff 0.7.2 → 0.7.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.d.ts CHANGED
@@ -61,6 +61,7 @@ export declare class Tensor {
61
61
  index(indices: Tensor | TensorValue): Tensor;
62
62
  slice(ranges: number[][]): Tensor;
63
63
  chunk(chunks: number, dim?: number): Tensor[];
64
+ expand(newShape: number[]): Tensor;
64
65
  cat(other: Tensor | TensorValue, dim?: number): Tensor;
65
66
  squeeze(dims?: number[] | number): Tensor;
66
67
  unsqueeze(dim: number): Tensor;
@@ -219,7 +220,7 @@ export declare class Tensor {
219
220
  val(): TensorValue;
220
221
  detach(): Tensor;
221
222
  clone(): Tensor;
222
- replace(other: Tensor | TensorValue, allowShapeMismatch?: boolean): Tensor;
223
+ replace(other: Tensor | TensorValue): Tensor;
223
224
  static backends: Map<string, Backend>;
224
225
  to(device: string): Tensor;
225
226
  to_(device: string): Tensor;
package/dist/core.js CHANGED
@@ -647,6 +647,46 @@ class Tensor {
647
647
  }
648
648
  return results;
649
649
  }
650
+ // Tensor expansion
651
+ expand(newShape) {
652
+ // Handle scalars
653
+ let self = this;
654
+ if (typeof this.value === "number") {
655
+ self = self.unsqueeze(0);
656
+ }
657
+ // Pad shapes to same length
658
+ const ndim = Math.max(self.shape.length, newShape.length);
659
+ const oldShape = [...Array(ndim - self.shape.length).fill(1), ...self.shape];
660
+ const oldStrides = [...Array(ndim - self.strides.length).fill(0), ...self.strides];
661
+ const targetShape = [...Array(ndim - newShape.length).fill(1), ...newShape];
662
+ const newStrides = new Array(ndim);
663
+ for (let i = 0; i < ndim; i++) {
664
+ if (oldShape[i] === targetShape[i]) {
665
+ newStrides[i] = oldStrides[i];
666
+ }
667
+ else if (oldShape[i] === 1) {
668
+ newStrides[i] = 0;
669
+ }
670
+ else {
671
+ throw new Error(`Cannot expand dimension of size ${oldShape[i]} to ${targetShape[i]}`);
672
+ }
673
+ }
674
+ const out = new Tensor(self.value, {
675
+ shape: targetShape,
676
+ strides: newStrides,
677
+ offset: self.offset,
678
+ numel: Tensor.shapeToSize(targetShape),
679
+ device: self.device
680
+ });
681
+ if (self.requiresGrad) {
682
+ out.requiresGrad = true;
683
+ out.children.push(self);
684
+ out.gradFn = () => {
685
+ Tensor.addGrad(self, out.grad);
686
+ };
687
+ }
688
+ return out;
689
+ }
650
690
  // Tensor concatentation
651
691
  cat(other, dim = 0) {
652
692
  other = this.handleOther(other);
@@ -1954,17 +1994,21 @@ class Tensor {
1954
1994
  return out;
1955
1995
  }
1956
1996
  // Returns this tensor with value replaced with the value of another tensor
1957
- replace(other, allowShapeMismatch = false) {
1997
+ replace(other) {
1958
1998
  other = this.handleOther(other);
1959
1999
  // Verify shape
1960
- if (!allowShapeMismatch) {
1961
- for (let index = 0; index < this.shape.length; index++) {
1962
- if (this.shape[index] !== other.shape[index]) {
1963
- throw new Error("Shape mismatch when trying to do tensor value replacement");
1964
- }
2000
+ if (this.shape.length !== other.shape.length) {
2001
+ throw new Error("Shape mismatch when trying to do tensor value replacement");
2002
+ }
2003
+ for (let index = 0; index < this.shape.length; index++) {
2004
+ if (this.shape[index] !== other.shape[index]) {
2005
+ throw new Error("Shape mismatch when trying to do tensor value replacement");
1965
2006
  }
1966
2007
  }
2008
+ // Reassign values
1967
2009
  this.value = other.value;
2010
+ this.strides = other.strides;
2011
+ this.offset = other.offset;
1968
2012
  return this;
1969
2013
  }
1970
2014
  // Holds all available backends
package/package.json CHANGED
@@ -1,7 +1,7 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.7.2",
4
- "description": "A small Torch-like deep learning framework for Javascript",
3
+ "version": "0.7.4",
4
+ "description": "Torch-like deep learning framework for Javascript",
5
5
  "main": "index.js",
6
6
  "scripts": {
7
7
  "test": "echo \"Error: no test specified\" && exit 1"