catniff 0.7.0 → 0.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +1 -0
- package/dist/core.js +84 -4
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
|
@@ -61,6 +61,7 @@ export declare class Tensor {
|
|
|
61
61
|
index(indices: Tensor | TensorValue): Tensor;
|
|
62
62
|
slice(ranges: number[][]): Tensor;
|
|
63
63
|
chunk(chunks: number, dim?: number): Tensor[];
|
|
64
|
+
/**
 * Concatenates `other` onto this tensor along dimension `dim`, returning a new tensor.
 *
 * @param other - Tensor (or raw value, converted internally) to append after this tensor.
 * @param dim - Dimension to concatenate along; defaults to 0, negative values count from the last dimension.
 * @returns A tensor whose size along `dim` is the sum of both operands' sizes.
 * @throws Error if either operand is a scalar, `dim` is out of range, or the shapes
 *   differ in rank or in any dimension other than `dim`.
 */
cat(other: Tensor | TensorValue, dim?: number): Tensor;
|
|
64
65
|
squeeze(dims?: number[] | number): Tensor;
|
|
65
66
|
unsqueeze(dim: number): Tensor;
|
|
66
67
|
static reduce(tensor: Tensor, dims: number[] | number | undefined, keepDims: boolean, config: {
|
package/dist/core.js
CHANGED
|
@@ -218,7 +218,6 @@ class Tensor {
|
|
|
218
218
|
}
|
|
219
219
|
if (out.requiresGrad) {
|
|
220
220
|
out.gradFn = () => {
|
|
221
|
-
// Disable gradient collecting of gradients themselves
|
|
222
221
|
const outGrad = out.grad;
|
|
223
222
|
const selfWithGrad = Tensor.createGraph ? this : this.detach();
|
|
224
223
|
const otherWithGrad = Tensor.createGraph ? other : other.detach();
|
|
@@ -239,7 +238,6 @@ class Tensor {
|
|
|
239
238
|
}
|
|
240
239
|
if (out.requiresGrad) {
|
|
241
240
|
out.gradFn = () => {
|
|
242
|
-
// Disable gradient collecting of gradients themselves
|
|
243
241
|
const outGrad = out.grad;
|
|
244
242
|
const selfWithGrad = Tensor.createGraph ? this : this.detach();
|
|
245
243
|
if (this.requiresGrad)
|
|
@@ -649,6 +647,90 @@ class Tensor {
|
|
|
649
647
|
}
|
|
650
648
|
return results;
|
|
651
649
|
}
|
|
650
|
+
// Tensor concatentation
|
|
651
|
+
cat(other, dim = 0) {
|
|
652
|
+
other = this.handleOther(other);
|
|
653
|
+
// Handle scalars
|
|
654
|
+
if (typeof this.value === "number" || typeof other.value === "number") {
|
|
655
|
+
throw new Error("Can not concatenate scalars");
|
|
656
|
+
}
|
|
657
|
+
// Handle negative indices
|
|
658
|
+
if (dim < 0) {
|
|
659
|
+
dim += this.shape.length;
|
|
660
|
+
}
|
|
661
|
+
// If dimension out of bound, throw error
|
|
662
|
+
if (dim >= this.shape.length || dim < 0) {
|
|
663
|
+
throw new Error("Dimension does not exist to concatenate");
|
|
664
|
+
}
|
|
665
|
+
// If shape does not match, throw error
|
|
666
|
+
if (this.shape.length !== other.shape.length) {
|
|
667
|
+
throw new Error("Shape does not match to concatenate");
|
|
668
|
+
}
|
|
669
|
+
const outputShape = new Array(this.shape.length);
|
|
670
|
+
for (let currentDim = 0; currentDim < this.shape.length; currentDim++) {
|
|
671
|
+
if (currentDim === dim) {
|
|
672
|
+
outputShape[currentDim] = this.shape[currentDim] + other.shape[currentDim];
|
|
673
|
+
}
|
|
674
|
+
else if (this.shape[currentDim] !== other.shape[currentDim]) {
|
|
675
|
+
throw new Error("Shape does not match to concatenate");
|
|
676
|
+
}
|
|
677
|
+
else {
|
|
678
|
+
outputShape[currentDim] = this.shape[currentDim];
|
|
679
|
+
}
|
|
680
|
+
}
|
|
681
|
+
const outputSize = Tensor.shapeToSize(outputShape);
|
|
682
|
+
const outputStrides = Tensor.getStrides(outputShape);
|
|
683
|
+
const outputValue = new Array(outputSize);
|
|
684
|
+
for (let outIndex = 0; outIndex < outputSize; outIndex++) {
|
|
685
|
+
const coords = Tensor.indexToCoords(outIndex, outputStrides);
|
|
686
|
+
// Check which tensor this output position comes from
|
|
687
|
+
if (coords[dim] < this.shape[dim]) {
|
|
688
|
+
// Comes from this tensor
|
|
689
|
+
const srcIndex = Tensor.coordsToIndex(coords, this.strides);
|
|
690
|
+
outputValue[outIndex] = this.value[srcIndex + this.offset];
|
|
691
|
+
}
|
|
692
|
+
else {
|
|
693
|
+
// Comes from other tensor - adjust coordinate in concat dimension
|
|
694
|
+
const otherCoords = [...coords];
|
|
695
|
+
otherCoords[dim] -= this.shape[dim];
|
|
696
|
+
const srcIndex = Tensor.coordsToIndex(otherCoords, other.strides);
|
|
697
|
+
outputValue[outIndex] = other.value[srcIndex + other.offset];
|
|
698
|
+
}
|
|
699
|
+
}
|
|
700
|
+
const out = new Tensor(outputValue, {
|
|
701
|
+
shape: outputShape,
|
|
702
|
+
strides: outputStrides,
|
|
703
|
+
numel: outputSize
|
|
704
|
+
});
|
|
705
|
+
if (this.requiresGrad) {
|
|
706
|
+
out.requiresGrad = true;
|
|
707
|
+
out.children.push(this);
|
|
708
|
+
}
|
|
709
|
+
if (other.requiresGrad) {
|
|
710
|
+
out.requiresGrad = true;
|
|
711
|
+
out.children.push(other);
|
|
712
|
+
}
|
|
713
|
+
if (out.requiresGrad) {
|
|
714
|
+
out.gradFn = () => {
|
|
715
|
+
const outGrad = out.grad;
|
|
716
|
+
const thisRanges = new Array(this.shape.length);
|
|
717
|
+
const otherRanges = new Array(other.shape.length);
|
|
718
|
+
for (let currentDim = 0; currentDim < this.shape.length; currentDim++) {
|
|
719
|
+
if (currentDim === dim) {
|
|
720
|
+
thisRanges[currentDim] = [0, this.shape[currentDim], 1];
|
|
721
|
+
otherRanges[currentDim] = [this.shape[currentDim], outputShape[currentDim], 1];
|
|
722
|
+
}
|
|
723
|
+
else {
|
|
724
|
+
thisRanges[currentDim] = [];
|
|
725
|
+
otherRanges[currentDim] = [];
|
|
726
|
+
}
|
|
727
|
+
}
|
|
728
|
+
Tensor.addGrad(this, outGrad.slice(thisRanges));
|
|
729
|
+
Tensor.addGrad(other, outGrad.slice(otherRanges));
|
|
730
|
+
};
|
|
731
|
+
}
|
|
732
|
+
return out;
|
|
733
|
+
}
|
|
652
734
|
// Tensor squeeze
|
|
653
735
|
squeeze(dims) {
|
|
654
736
|
if (typeof this.value === "number")
|
|
@@ -1338,7 +1420,6 @@ class Tensor {
|
|
|
1338
1420
|
}
|
|
1339
1421
|
if (out.requiresGrad) {
|
|
1340
1422
|
out.gradFn = () => {
|
|
1341
|
-
// Disable gradient collecting of gradients themselves
|
|
1342
1423
|
const outGrad = out.grad;
|
|
1343
1424
|
const selfWithGrad = Tensor.createGraph ? this : this.detach();
|
|
1344
1425
|
const otherWithGrad = Tensor.createGraph ? other : other.detach();
|
|
@@ -1396,7 +1477,6 @@ class Tensor {
|
|
|
1396
1477
|
}
|
|
1397
1478
|
if (out.requiresGrad) {
|
|
1398
1479
|
out.gradFn = () => {
|
|
1399
|
-
// Disable gradient collecting of gradients themselves
|
|
1400
1480
|
const outGrad = out.grad;
|
|
1401
1481
|
const selfWithGrad = Tensor.createGraph ? this : this.detach();
|
|
1402
1482
|
const otherWithGrad = Tensor.createGraph ? other : other.detach();
|