catniff 0.7.2 → 0.7.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +1 -0
- package/dist/core.js +40 -0
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
|
@@ -61,6 +61,7 @@ export declare class Tensor {
|
|
|
61
61
|
index(indices: Tensor | TensorValue): Tensor;
|
|
62
62
|
slice(ranges: number[][]): Tensor;
|
|
63
63
|
chunk(chunks: number, dim?: number): Tensor[];
|
|
64
|
+
/**
 * Expands this tensor to `newShape` by re-describing its strides.
 *
 * The current shape and `newShape` are aligned on their trailing dimensions;
 * the shorter of the two is left-padded with size-1 dimensions. A dimension
 * whose size already matches keeps its stride, and a size-1 dimension is
 * expanded by giving it stride 0.
 *
 * @param newShape Target shape.
 * @returns A new `Tensor` constructed from the same `value`, with the expanded
 *   shape and strides.
 * @throws Error if a dimension whose size is not 1 differs from the
 *   corresponding target size.
 */
expand(newShape: number[]): Tensor;
|
|
64
65
|
cat(other: Tensor | TensorValue, dim?: number): Tensor;
|
|
65
66
|
squeeze(dims?: number[] | number): Tensor;
|
|
66
67
|
unsqueeze(dim: number): Tensor;
|
package/dist/core.js
CHANGED
|
@@ -647,6 +647,46 @@ class Tensor {
|
|
|
647
647
|
}
|
|
648
648
|
return results;
|
|
649
649
|
}
|
|
650
|
+
// Tensor expansion
|
|
651
|
+
/**
 * Expands this tensor to `newShape` without touching the element data: the
 * result is built from the same `value` and `offset`, with new shape/strides.
 *
 * Shapes are aligned on their trailing dimensions (the shorter one is
 * left-padded with 1s). Per dimension:
 *   - equal sizes      -> the existing stride is kept;
 *   - source size is 1 -> stride becomes 0;
 *   - anything else    -> an Error is thrown.
 *
 * If this tensor requires grad, the result is linked to it as a child and its
 * `gradFn` forwards `out.grad` to `Tensor.addGrad`.
 *
 * NOTE(review): target sizes are not validated. When the source dimension is
 * 1, any target value (0, negative, non-integer) is accepted as-is.
 *
 * @param newShape Target shape (array of dimension sizes).
 * @returns A new Tensor with the expanded shape.
 * @throws Error when a dimension of size != 1 does not match its target size.
 */
expand(newShape) {
|
|
652
|
+
// Handle scalars
|
|
653
|
+
let self = this;
|
|
654
|
+
// A plain-number value is first turned into a tensor with one dimension via unsqueeze(0).
if (typeof this.value === "number") {
|
|
655
|
+
self = self.unsqueeze(0);
|
|
656
|
+
}
|
|
657
|
+
// Pad shapes to same length
|
|
658
|
+
const ndim = Math.max(self.shape.length, newShape.length);
|
|
659
|
+
// Missing leading dimensions of the source are treated as size 1 ...
const oldShape = [...Array(ndim - self.shape.length).fill(1), ...self.shape];
|
|
660
|
+
// ... and get stride 0.
const oldStrides = [...Array(ndim - self.strides.length).fill(0), ...self.strides];
|
|
661
|
+
// If newShape has fewer dimensions than the source, it is left-padded with 1s as well.
const targetShape = [...Array(ndim - newShape.length).fill(1), ...newShape];
|
|
662
|
+
const newStrides = new Array(ndim);
|
|
663
|
+
for (let i = 0; i < ndim; i++) {
|
|
664
|
+
// Size unchanged: reuse the existing stride.
if (oldShape[i] === targetShape[i]) {
|
|
665
|
+
newStrides[i] = oldStrides[i];
|
|
666
|
+
}
|
|
667
|
+
// Size-1 dimension being expanded: stride 0, so every index along it resolves to the same position.
else if (oldShape[i] === 1) {
|
|
668
|
+
newStrides[i] = 0;
|
|
669
|
+
}
|
|
670
|
+
// Any other size mismatch cannot be expressed through strides alone.
else {
|
|
671
|
+
throw new Error(`Cannot expand dimension of size ${oldShape[i]} to ${targetShape[i]}`);
|
|
672
|
+
}
|
|
673
|
+
}
|
|
674
|
+
// The same `value` and `offset` are handed to the constructor; only shape, strides and numel differ.
const out = new Tensor(self.value, {
|
|
675
|
+
shape: targetShape,
|
|
676
|
+
strides: newStrides,
|
|
677
|
+
offset: self.offset,
|
|
678
|
+
numel: Tensor.shapeToSize(targetShape),
|
|
679
|
+
device: self.device
|
|
680
|
+
});
|
|
681
|
+
if (self.requiresGrad) {
|
|
682
|
+
out.requiresGrad = true;
|
|
683
|
+
out.children.push(self);
|
|
684
|
+
out.gradFn = () => {
|
|
685
|
+
// NOTE(review): out.grad has the expanded shape; presumably Tensor.addGrad reduces it
// back to self's shape - confirm against addGrad, which is not visible here.
Tensor.addGrad(self, out.grad);
|
|
686
|
+
};
|
|
687
|
+
}
|
|
688
|
+
return out;
|
|
689
|
+
}
|
|
650
690
|
// Tensor concatenation
|
|
651
691
|
cat(other, dim = 0) {
|
|
652
692
|
other = this.handleOther(other);
|