catniff 0.5.8 → 0.5.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/backend.d.ts +1 -0
- package/dist/core.d.ts +1 -0
- package/dist/core.js +15 -8
- package/package.json +1 -1
package/dist/backend.d.ts
CHANGED
package/dist/core.d.ts
CHANGED
package/dist/core.js
CHANGED
|
@@ -21,13 +21,8 @@ class Tensor {
|
|
|
21
21
|
this.gradFn = options.gradFn || (() => { });
|
|
22
22
|
this.children = options.children || [];
|
|
23
23
|
this.device = options.device || "cpu";
|
|
24
|
-
// Move
|
|
25
|
-
|
|
26
|
-
const backend = Tensor.backends.get(this.device);
|
|
27
|
-
if (backend && backend.transfer) {
|
|
28
|
-
backend.transfer(this);
|
|
29
|
-
}
|
|
30
|
-
}
|
|
24
|
+
// Move to device in-place
|
|
25
|
+
this.to_(this.device);
|
|
31
26
|
}
|
|
32
27
|
// Utility to flatten an nD array to be 1D
|
|
33
28
|
static flatten(tensor) {
|
|
@@ -1712,9 +1707,21 @@ class Tensor {
|
|
|
1712
1707
|
static backends = new Map();
|
|
1713
1708
|
// Op to transfer tensor to another device
|
|
1714
1709
|
to(device) {
|
|
1710
|
+
if (device === "cpu")
|
|
1711
|
+
return this;
|
|
1715
1712
|
const backend = Tensor.backends.get(device);
|
|
1716
1713
|
if (backend && backend.transfer) {
|
|
1717
|
-
backend.transfer(this);
|
|
1714
|
+
return backend.transfer(this);
|
|
1715
|
+
}
|
|
1716
|
+
throw new Error(`No device found to transfer tensor to or a handler is not implemented for device.`);
|
|
1717
|
+
}
|
|
1718
|
+
// Op to transfer tensor to another device in-place
|
|
1719
|
+
to_(device) {
|
|
1720
|
+
if (device === "cpu")
|
|
1721
|
+
return this;
|
|
1722
|
+
const backend = Tensor.backends.get(this.device);
|
|
1723
|
+
if (backend && backend.create) {
|
|
1724
|
+
backend.create(this);
|
|
1718
1725
|
return this;
|
|
1719
1726
|
}
|
|
1720
1727
|
throw new Error(`No device found to transfer tensor to or a handler is not implemented for device.`);
|