catniff 0.8.20 → 0.8.21
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.js +21 -2
- package/package.json +1 -1
package/dist/core.js
CHANGED
|
@@ -767,7 +767,7 @@ class Tensor {
|
|
|
767
767
|
}
|
|
768
768
|
// If dimension out of bound, throw error
|
|
769
769
|
if (dim >= this.shape.length || dim < 0) {
|
|
770
|
-
throw new Error("Dimension does not exist to apply
|
|
770
|
+
throw new Error("Dimension does not exist to apply unfold");
|
|
771
771
|
}
|
|
772
772
|
// Verify size and step
|
|
773
773
|
if (size <= 0 || step <= 0)
|
|
@@ -781,13 +781,32 @@ class Tensor {
|
|
|
781
781
|
const newStrides = [...this.strides, this.strides[dim]];
|
|
782
782
|
newShape[dim] = outSize;
|
|
783
783
|
newStrides[dim] = this.strides[dim] * step;
|
|
784
|
-
|
|
784
|
+
const out = new Tensor(this.value, {
|
|
785
785
|
shape: newShape,
|
|
786
786
|
strides: newStrides,
|
|
787
787
|
offset: this.offset,
|
|
788
788
|
dtype: this.dtype,
|
|
789
789
|
device: this.device
|
|
790
790
|
});
|
|
791
|
+
if (this.requiresGrad) {
|
|
792
|
+
out.requiresGrad = true;
|
|
793
|
+
out.children.push(this);
|
|
794
|
+
out.gradFn = () => {
|
|
795
|
+
const outGrad = out.grad;
|
|
796
|
+
const grad = Tensor.zerosLike(this);
|
|
797
|
+
for (let i = 0; i < out.numel; i++) {
|
|
798
|
+
const coords = Tensor.indexToCoords(i, newStrides);
|
|
799
|
+
const windowIdx = coords[dim];
|
|
800
|
+
const withinWindow = coords[coords.length - 1];
|
|
801
|
+
coords[dim] = windowIdx * step + withinWindow;
|
|
802
|
+
coords.pop();
|
|
803
|
+
const sourceIdx = Tensor.coordsToIndex(coords, this.strides);
|
|
804
|
+
grad.value[sourceIdx] += outGrad.value[i];
|
|
805
|
+
}
|
|
806
|
+
Tensor.addGrad(this, grad);
|
|
807
|
+
};
|
|
808
|
+
}
|
|
809
|
+
return out;
|
|
791
810
|
}
|
|
792
811
|
// Tensor concatenation
|
|
793
812
|
cat(other, dim = 0) {
|