catniff 0.8.20 → 0.8.21

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (2) hide show
  1. package/dist/core.js +21 -2
  2. package/package.json +1 -1
package/dist/core.js CHANGED
@@ -767,7 +767,7 @@ class Tensor {
767
767
  }
768
768
  // If dimension out of bound, throw error
769
769
  if (dim >= this.shape.length || dim < 0) {
770
- throw new Error("Dimension does not exist to apply softmax");
770
+ throw new Error("Dimension does not exist to apply unfold");
771
771
  }
772
772
  // Verify size and step
773
773
  if (size <= 0 || step <= 0)
@@ -781,13 +781,32 @@ class Tensor {
781
781
  const newStrides = [...this.strides, this.strides[dim]];
782
782
  newShape[dim] = outSize;
783
783
  newStrides[dim] = this.strides[dim] * step;
784
- return new Tensor(this.value, {
784
+ const out = new Tensor(this.value, {
785
785
  shape: newShape,
786
786
  strides: newStrides,
787
787
  offset: this.offset,
788
788
  dtype: this.dtype,
789
789
  device: this.device
790
790
  });
791
+ if (this.requiresGrad) {
792
+ out.requiresGrad = true;
793
+ out.children.push(this);
794
+ out.gradFn = () => {
795
+ const outGrad = out.grad;
796
+ const grad = Tensor.zerosLike(this);
797
+ for (let i = 0; i < out.numel; i++) {
798
+ const coords = Tensor.indexToCoords(i, newStrides);
799
+ const windowIdx = coords[dim];
800
+ const withinWindow = coords[coords.length - 1];
801
+ coords[dim] = windowIdx * step + withinWindow;
802
+ coords.pop();
803
+ const sourceIdx = Tensor.coordsToIndex(coords, this.strides);
804
+ grad.value[sourceIdx] += outGrad.value[i];
805
+ }
806
+ Tensor.addGrad(this, grad);
807
+ };
808
+ }
809
+ return out;
791
810
  }
792
811
  // Tensor concatenation
793
812
  cat(other, dim = 0) {
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.8.20",
3
+ "version": "0.8.21",
4
4
  "description": "Torch-like deep learning framework for Javascript",
5
5
  "main": "./dist/index.js",
6
6
  "scripts": {