catniff 0.7.2 → 0.7.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.d.ts CHANGED
@@ -61,6 +61,7 @@ export declare class Tensor {
61
61
  index(indices: Tensor | TensorValue): Tensor;
62
62
  slice(ranges: number[][]): Tensor;
63
63
  chunk(chunks: number, dim?: number): Tensor[];
64
+ expand(newShape: number[]): Tensor; // Returns a tensor of shape newShape built on the same underlying value; only size-1 dimensions may change size, any other size mismatch throws an Error.
64
65
  cat(other: Tensor | TensorValue, dim?: number): Tensor;
65
66
  squeeze(dims?: number[] | number): Tensor;
66
67
  unsqueeze(dim: number): Tensor;
package/dist/core.js CHANGED
@@ -647,6 +647,46 @@ class Tensor {
647
647
  }
648
648
  return results;
649
649
  }
650
+ // Tensor expansion: returns a new Tensor with shape newShape that is constructed from this tensor's value (no element copy is made here); each size-1 dimension that changes size gets stride 0. Throws an Error when a dimension whose size is not 1 would have to change.
651
+ expand(newShape) {
652
+ // Handle scalars: a plain-number value is passed through unsqueeze(0) first (presumably yielding shape [1] - confirm) so the shape/stride logic below has a dimension to work with
653
+ let self = this;
654
+ if (typeof this.value === "number") {
655
+ self = self.unsqueeze(0);
656
+ }
657
+ // Pad shapes to same length by prepending leading entries: size 1 for shapes, stride 0 for the old strides
658
+ const ndim = Math.max(self.shape.length, newShape.length);
659
+ const oldShape = [...Array(ndim - self.shape.length).fill(1), ...self.shape];
660
+ const oldStrides = [...Array(ndim - self.strides.length).fill(0), ...self.strides];
661
+ const targetShape = [...Array(ndim - newShape.length).fill(1), ...newShape]; // a newShape shorter than this tensor's rank is also left-padded with 1s, so the result never has fewer dimensions than the source
662
+ const newStrides = new Array(ndim);
663
+ for (let i = 0; i < ndim; i++) { // choose the stride of every dimension of the result
664
+ if (oldShape[i] === targetShape[i]) { // size unchanged: keep the existing stride
665
+ newStrides[i] = oldStrides[i];
666
+ }
667
+ else if (oldShape[i] === 1) { // size-1 dimension changing to the target size
668
+ newStrides[i] = 0; // stride 0, presumably so every index along this dimension reads the same element - confirm against the indexing code
669
+ }
670
+ else {
671
+ throw new Error(`Cannot expand dimension of size ${oldShape[i]} to ${targetShape[i]}`);
672
+ }
673
+ }
674
+ const out = new Tensor(self.value, { // same value and offset as the source; only shape, strides and numel are replaced
675
+ shape: targetShape,
676
+ strides: newStrides,
677
+ offset: self.offset,
678
+ numel: Tensor.shapeToSize(targetShape),
679
+ device: self.device
680
+ });
681
+ if (self.requiresGrad) { // gradient bookkeeping is only set up when the source requires gradients
682
+ out.requiresGrad = true;
683
+ out.children.push(self);
684
+ out.gradFn = () => {
685
+ Tensor.addGrad(self, out.grad); // NOTE(review): out.grad has the expanded shape - confirm Tensor.addGrad reduces it back to self's shape
686
+ };
687
+ }
688
+ return out;
689
+ }
650
690
  // Tensor concatenation
651
691
  cat(other, dim = 0) {
652
692
  other = this.handleOther(other);
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.7.2",
3
+ "version": "0.7.3",
4
4
  "description": "A small Torch-like deep learning framework for Javascript",
5
5
  "main": "index.js",
6
6
  "scripts": {