catniff 0.5.8 → 0.5.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/backend.d.ts CHANGED
@@ -1,4 +1,5 @@
1
1
  import { Tensor } from "./core";
2
2
  /**
   * Device backend contract. Backends are stored in `Tensor.backends`
   * (a `Map<string, Backend>`) and looked up there by device name.
   */
  export interface Backend {
3
  /**
   * Added in 0.5.9. Called by `Tensor.to_` (and so, via `to_`, by the Tensor
   * constructor) with the tensor itself; it returns nothing, so any effect must
   * be applied to the passed tensor.
   * NOTE(review): `to_` selects this backend using `this.device`, not its
   * `device` argument. Presumably `create` sets up the tensor's storage on this
   * device — TODO confirm.
   */
+ create(tensor: Tensor): void;
3
4
  /**
   * Called by `Tensor.to(device)`. Since 0.5.9 its return value is what `to`
   * returns.
   * NOTE(review): 0.5.8 discarded this return value and returned `this`. A
   * backend that only mutated the tensor and returned nothing would now make
   * `to` return undefined. Confirm every implementation returns a Tensor.
   */
  transfer(tensor: Tensor): Tensor;
4
5
  }
package/dist/core.d.ts CHANGED
@@ -177,4 +177,5 @@ export declare class Tensor {
177
177
  replace(other: Tensor, allowShapeMismatch?: boolean): Tensor;
178
178
  static backends: Map<string, Backend>;
179
179
  to(device: string): Tensor;
180
+ to_(device: string): Tensor;
180
181
  }
package/dist/core.js CHANGED
@@ -21,13 +21,8 @@ class Tensor {
21
21
  this.gradFn = options.gradFn || (() => { });
22
22
  this.children = options.children || [];
23
23
  this.device = options.device || "cpu";
24
- // Move tensor to device
25
- if (this.device !== "cpu") {
26
- const backend = Tensor.backends.get(this.device);
27
- if (backend && backend.transfer) {
28
- backend.transfer(this);
29
- }
30
- }
24
+ // Move to device in-place
25
+ this.to_(this.device);
31
26
  }
32
27
  // Utility to flatten an nD array to be 1D
33
28
  static flatten(tensor) {
@@ -1712,9 +1707,21 @@ class Tensor {
1712
1707
  static backends = new Map();
1713
1708
  // Op to transfer tensor to another device
1714
1709
  /**
   * Transfers this tensor to `device` through the backend registered under that name.
   * Unlike `to_`, it returns the backend's result rather than `this`.
   * - `device === "cpu"` short-circuits and returns `this` unchanged.
   * - Otherwise it returns whatever `Tensor.backends.get(device).transfer(this)` returns.
   * @throws Error when no backend is registered for `device` or the backend has no `transfer`.
   * NOTE(review): the "cpu" short-circuit does not consult `this.device`. A tensor
   * already on a non-CPU device is therefore returned as-is instead of being moved
   * back to CPU. Confirm this is intended.
   */
  to(device) {
1710
+ if (device === "cpu")
1711
+ return this;
1715
1712
  const backend = Tensor.backends.get(device);
1716
1713
  if (backend && backend.transfer) {
1717
  // 0.5.8 called transfer and discarded its result (removed line below),
  // then returned `this`; 0.5.9 returns transfer's result instead.
- backend.transfer(this);
1714
+ return backend.transfer(this);
1715
+ }
1716
+ throw new Error(`No device found to transfer tensor to or a handler is not implemented for device.`);
1717
+ }
1718
+ // Op to transfer tensor to another device in-place
1719
+ to_(device) {
1720
+ if (device === "cpu")
1721
+ return this;
1722
+ const backend = Tensor.backends.get(this.device);
1723
+ if (backend && backend.create) {
1724
+ backend.create(this);
1718
1725
  return this;
1719
1726
  }
1720
1727
  throw new Error(`No device found to transfer tensor to or a handler is not implemented for device.`);
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.5.8",
3
+ "version": "0.5.9",
4
4
  "description": "A small Torch-like deep learning framework for Javascript",
5
5
  "main": "index.js",
6
6
  "scripts": {