catniff 0.7.2 → 0.7.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +2 -1
- package/dist/core.js +50 -6
- package/package.json +2 -2
package/dist/core.d.ts
CHANGED
|
@@ -61,6 +61,7 @@ export declare class Tensor {
|
|
|
61
61
|
index(indices: Tensor | TensorValue): Tensor;
|
|
62
62
|
slice(ranges: number[][]): Tensor;
|
|
63
63
|
chunk(chunks: number, dim?: number): Tensor[];
|
|
64
|
+
expand(newShape: number[]): Tensor;
|
|
64
65
|
cat(other: Tensor | TensorValue, dim?: number): Tensor;
|
|
65
66
|
squeeze(dims?: number[] | number): Tensor;
|
|
66
67
|
unsqueeze(dim: number): Tensor;
|
|
@@ -219,7 +220,7 @@ export declare class Tensor {
|
|
|
219
220
|
val(): TensorValue;
|
|
220
221
|
detach(): Tensor;
|
|
221
222
|
clone(): Tensor;
|
|
222
|
-
replace(other: Tensor | TensorValue
|
|
223
|
+
replace(other: Tensor | TensorValue): Tensor;
|
|
223
224
|
static backends: Map<string, Backend>;
|
|
224
225
|
to(device: string): Tensor;
|
|
225
226
|
to_(device: string): Tensor;
|
package/dist/core.js
CHANGED
|
@@ -647,6 +647,46 @@ class Tensor {
|
|
|
647
647
|
}
|
|
648
648
|
return results;
|
|
649
649
|
}
|
|
650
|
+
// Tensor expansion
|
|
651
|
+
expand(newShape) {
|
|
652
|
+
// Handle scalars
|
|
653
|
+
let self = this;
|
|
654
|
+
if (typeof this.value === "number") {
|
|
655
|
+
self = self.unsqueeze(0);
|
|
656
|
+
}
|
|
657
|
+
// Pad shapes to same length
|
|
658
|
+
const ndim = Math.max(self.shape.length, newShape.length);
|
|
659
|
+
const oldShape = [...Array(ndim - self.shape.length).fill(1), ...self.shape];
|
|
660
|
+
const oldStrides = [...Array(ndim - self.strides.length).fill(0), ...self.strides];
|
|
661
|
+
const targetShape = [...Array(ndim - newShape.length).fill(1), ...newShape];
|
|
662
|
+
const newStrides = new Array(ndim);
|
|
663
|
+
for (let i = 0; i < ndim; i++) {
|
|
664
|
+
if (oldShape[i] === targetShape[i]) {
|
|
665
|
+
newStrides[i] = oldStrides[i];
|
|
666
|
+
}
|
|
667
|
+
else if (oldShape[i] === 1) {
|
|
668
|
+
newStrides[i] = 0;
|
|
669
|
+
}
|
|
670
|
+
else {
|
|
671
|
+
throw new Error(`Cannot expand dimension of size ${oldShape[i]} to ${targetShape[i]}`);
|
|
672
|
+
}
|
|
673
|
+
}
|
|
674
|
+
const out = new Tensor(self.value, {
|
|
675
|
+
shape: targetShape,
|
|
676
|
+
strides: newStrides,
|
|
677
|
+
offset: self.offset,
|
|
678
|
+
numel: Tensor.shapeToSize(targetShape),
|
|
679
|
+
device: self.device
|
|
680
|
+
});
|
|
681
|
+
if (self.requiresGrad) {
|
|
682
|
+
out.requiresGrad = true;
|
|
683
|
+
out.children.push(self);
|
|
684
|
+
out.gradFn = () => {
|
|
685
|
+
Tensor.addGrad(self, out.grad);
|
|
686
|
+
};
|
|
687
|
+
}
|
|
688
|
+
return out;
|
|
689
|
+
}
|
|
650
690
|
// Tensor concatentation
|
|
651
691
|
cat(other, dim = 0) {
|
|
652
692
|
other = this.handleOther(other);
|
|
@@ -1954,17 +1994,21 @@ class Tensor {
|
|
|
1954
1994
|
return out;
|
|
1955
1995
|
}
|
|
1956
1996
|
// Returns this tensor with value replaced with the value of another tensor
|
|
1957
|
-
replace(other
|
|
1997
|
+
replace(other) {
|
|
1958
1998
|
other = this.handleOther(other);
|
|
1959
1999
|
// Verify shape
|
|
1960
|
-
if (
|
|
1961
|
-
|
|
1962
|
-
|
|
1963
|
-
|
|
1964
|
-
|
|
2000
|
+
if (this.shape.length !== other.shape.length) {
|
|
2001
|
+
throw new Error("Shape mismatch when trying to do tensor value replacement");
|
|
2002
|
+
}
|
|
2003
|
+
for (let index = 0; index < this.shape.length; index++) {
|
|
2004
|
+
if (this.shape[index] !== other.shape[index]) {
|
|
2005
|
+
throw new Error("Shape mismatch when trying to do tensor value replacement");
|
|
1965
2006
|
}
|
|
1966
2007
|
}
|
|
2008
|
+
// Reassign values
|
|
1967
2009
|
this.value = other.value;
|
|
2010
|
+
this.strides = other.strides;
|
|
2011
|
+
this.offset = other.offset;
|
|
1968
2012
|
return this;
|
|
1969
2013
|
}
|
|
1970
2014
|
// Holds all available backends
|
package/package.json
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "catniff",
|
|
3
|
-
"version": "0.7.
|
|
4
|
-
"description": "
|
|
3
|
+
"version": "0.7.4",
|
|
4
|
+
"description": "Torch-like deep learning framework for Javascript",
|
|
5
5
|
"main": "index.js",
|
|
6
6
|
"scripts": {
|
|
7
7
|
"test": "echo \"Error: no test specified\" && exit 1"
|