catniff 0.5.0 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.js CHANGED
@@ -20,6 +20,13 @@ class Tensor {
20
20
  this.gradFn = options.gradFn || (() => { });
21
21
  this.children = options.children || [];
22
22
  this.device = options.device || "cpu";
23
+ // Move tensor to device
24
+ if (this.device !== "cpu") {
25
+ const backend = Tensor.backends.get(this.device);
26
+ if (backend && backend.transfer) {
27
+ backend.transfer(this);
28
+ }
29
+ }
23
30
  }
24
31
  // Utility to flatten an nD array to be 1D
25
32
  static flatten(tensor) {
@@ -1527,6 +1534,7 @@ class Tensor {
1527
1534
  return new Tensor(typeof this.value === "number" ? this.value : [...this.value], {
1528
1535
  shape: this.shape,
1529
1536
  strides: this.strides,
1537
+ device: this.device,
1530
1538
  requiresGrad: this.requiresGrad
1531
1539
  });
1532
1540
  }
@@ -1549,7 +1557,8 @@ class Tensor {
1549
1557
  to(device) {
1550
1558
  const backend = Tensor.backends.get(device);
1551
1559
  if (backend && backend.transfer) {
1552
- return backend.transfer(this);
1560
+ backend.transfer(this);
1561
+ return this;
1553
1562
  }
1554
1563
  throw new Error(`No device found to transfer tensor to or a handler is not implemented for device.`);
1555
1564
  }
package/dist/nn.d.ts CHANGED
@@ -2,11 +2,20 @@ import { Tensor, TensorValue } from "./core";
/**
 * Fully connected layer.
 *
 * `weight` has shape `[outFeatures, inFeatures]`. `bias` exists only when the
 * `bias` constructor flag is true and has shape `[outFeatures]`. Both are
 * initialised uniformly in ±1/sqrt(inFeatures) with `requiresGrad: true`.
 */
declare class Linear {
    weight: Tensor;
    bias?: Tensor;
    /**
     * @param inFeatures  number of input features (second dimension of `weight`)
     * @param outFeatures number of output features (first dimension of `weight`)
     * @param bias        whether to allocate a learnable bias; defaults to true
     * @param device      device name the parameters are created on
     */
    constructor(inFeatures: number, outFeatures: number, bias?: boolean, device?: string);
    /** Apply the layer; accepts either a Tensor or a raw TensorValue. */
    forward(input: Tensor | TensorValue): Tensor;
}
/**
 * Single-step RNN cell with a tanh nonlinearity.
 *
 * Parameter shapes:
 * - `weightIH` is `[hiddenSize, inputSize]`.
 * - `weightHH` is `[hiddenSize, hiddenSize]`.
 * - `biasIH` and `biasHH` (present only when `bias` is true) are `[hiddenSize]`.
 *
 * All parameters are initialised uniformly in ±1/sqrt(hiddenSize) with
 * `requiresGrad: true`.
 */
declare class RNNCell {
    weightIH: Tensor;
    weightHH: Tensor;
    biasIH?: Tensor;
    biasHH?: Tensor;
    /**
     * @param inputSize  number of features per input step
     * @param hiddenSize number of features in the hidden state
     * @param bias       whether to allocate `biasIH` and `biasHH`; defaults to true
     * @param device     device name the parameters are created on
     */
    constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
    /**
     * Compute the next hidden state:
     * `tanh(input · weightIHᵀ + hidden · weightHHᵀ + biasIH + biasHH)`.
     */
    forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
}
8
16
  export declare const nn: {
9
17
  Linear: typeof Linear;
18
+ RNNCell: typeof RNNCell;
10
19
  state: {
11
20
  getParameters(model: any, visited?: WeakSet<object>): Tensor[];
12
21
  };
package/dist/nn.js CHANGED
@@ -5,17 +5,11 @@ const core_1 = require("./core");
5
5
  class Linear {
6
6
  weight;
7
7
  bias;
8
- constructor(inFeatures, outFeatures, bias = true, customInit) {
9
- let initFunc = (shape) => {
10
- const bound = 1 / Math.sqrt(inFeatures);
11
- return core_1.Tensor.uniform(shape, -bound, bound, { requiresGrad: true });
12
- };
13
- if (customInit) {
14
- initFunc = customInit;
15
- }
16
- this.weight = initFunc([outFeatures, inFeatures]);
8
+ constructor(inFeatures, outFeatures, bias = true, device) {
9
+ const bound = 1 / Math.sqrt(inFeatures);
10
+ this.weight = core_1.Tensor.uniform([outFeatures, inFeatures], -bound, bound, { requiresGrad: true, device });
17
11
  if (bias) {
18
- this.bias = initFunc([outFeatures]);
12
+ this.bias = core_1.Tensor.uniform([outFeatures], -bound, bound, { requiresGrad: true, device });
19
13
  }
20
14
  }
21
15
  forward(input) {
@@ -27,6 +21,31 @@ class Linear {
27
21
  return output;
28
22
  }
29
23
  }
/**
 * Single-step RNN cell with a tanh nonlinearity:
 *   h' = tanh(input · weightIHᵀ + hidden · weightHHᵀ + biasIH + biasHH)
 */
class RNNCell {
    weightIH;
    weightHH;
    biasIH;
    biasHH;
    /**
     * Every parameter is drawn uniformly from [-1/sqrt(hiddenSize), 1/sqrt(hiddenSize)]
     * and created with requiresGrad: true.
     *
     * @param inputSize  number of features per input step (second dimension of weightIH)
     * @param hiddenSize number of features in the hidden state
     * @param bias       when true (default), allocate biasIH and biasHH of shape [hiddenSize]
     * @param device     device name forwarded to Tensor.uniform
     */
    constructor(inputSize, hiddenSize, bias = true, device) {
        const bound = 1 / Math.sqrt(hiddenSize);
        this.weightIH = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
        this.weightHH = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
        if (bias) {
            this.biasIH = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
            this.biasHH = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
        }
    }
    /**
     * Compute the next hidden state.
     *
     * @param input  Tensor or raw value (converted with Tensor.forceTensor). Its last
     *               dimension is expected to equal inputSize.
     * @param hidden Previous hidden state. When omitted (undefined or null) it is
     *               treated as all zeros, which is equivalent to dropping the
     *               hidden · weightHHᵀ term.
     * @returns the new hidden state after tanh
     */
    forward(input, hidden) {
        input = core_1.Tensor.forceTensor(input);
        let preActivation = input.matmul(this.weightIH.t());
        // A zero hidden state contributes nothing through weightHH, so skip the matmul.
        // Use a null check rather than truthiness so a scalar 0 TensorValue still
        // follows the normal path.
        if (hidden !== undefined && hidden !== null) {
            hidden = core_1.Tensor.forceTensor(hidden);
            preActivation = preActivation.add(hidden.matmul(this.weightHH.t()));
        }
        if (this.biasIH && this.biasHH) {
            preActivation = preActivation.add(this.biasIH).add(this.biasHH);
        }
        return preActivation.tanh();
    }
}
30
49
  const state = {
31
50
  getParameters(model, visited = new WeakSet()) {
32
51
  if (visited.has(model)) {
@@ -50,5 +69,6 @@ const state = {
50
69
  };
51
70
  exports.nn = {
52
71
  Linear,
72
+ RNNCell,
53
73
  state
54
74
  };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.5.0",
3
+ "version": "0.5.1",
4
4
  "description": "A small Torch-like deep learning framework for Javascript",
5
5
  "main": "index.js",
6
6
  "scripts": {