catniff 0.4.2 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -9,22 +9,6 @@ Install through npm:
9
9
  npm install catniff
10
10
  ```
11
11
 
12
- ## Example
13
-
14
- Here is a little demo of a quadratic function:
15
- ```js
16
- const { Tensor } = require("catniff");
17
-
18
- const x = new Tensor(2, { requiresGrad: true });
19
- const L = x.pow(2).add(x); // x^2 + x
20
-
21
- L.backward();
22
-
23
- console.log(x.grad.val()); // 5
24
- ```
25
-
26
- View all examples in [`./examples`](./examples).
27
-
28
12
  ## Tensors
29
13
 
30
14
  Tensors in Catniff can be created by passing in a number or an nD array, and there are built-in methods that can be used to perform tensor arithmetic:
@@ -92,7 +76,7 @@ optim.step();
92
76
  console.log("Updated weight:", w.data); // Should move toward 3.0
93
77
  ```
94
78
 
95
- And it can still do much more, check out the docs mentioned below for more information.
79
+ And it can still do much more; check out the docs and examples below for more information.
96
80
 
97
81
  ## Documentation
98
82
 
@@ -100,12 +84,19 @@ Full documentation is available in [`./docs/documentation.md`](./docs/documentat
100
84
 
101
85
  All available APIs are in [`./src/`](./src/) if you want to dig deeper.
102
86
 
87
+ ## Examples
88
+
89
+ * [Simple neural net for XOR calculation](./examples/xornet.js).
90
+ * [Tensors](./examples/tensors.js).
91
+ * [Optimizer](./examples/optim.js).
92
+ * [Simple quadratic equation](./examples/quadratic.js).
93
+
103
94
  ## Todos
104
95
 
105
96
  * Bug fixes.
106
97
  * More tensor ops.
107
98
  * GPU acceleration.
108
- * Some general neural net APIs.
99
+ * More general neural net APIs.
109
100
  * More detailed documentation.
110
101
  * Code refactoring.
111
102
  * Proper tests.
package/dist/core.js CHANGED
@@ -20,6 +20,13 @@ class Tensor {
20
20
  this.gradFn = options.gradFn || (() => { });
21
21
  this.children = options.children || [];
22
22
  this.device = options.device || "cpu";
23
+ // Move tensor to device
24
+ if (this.device !== "cpu") {
25
+ const backend = Tensor.backends.get(this.device);
26
+ if (backend && backend.transfer) {
27
+ backend.transfer(this);
28
+ }
29
+ }
23
30
  }
24
31
  // Utility to flatten an nD array to be 1D
25
32
  static flatten(tensor) {
@@ -1527,6 +1534,7 @@ class Tensor {
1527
1534
  return new Tensor(typeof this.value === "number" ? this.value : [...this.value], {
1528
1535
  shape: this.shape,
1529
1536
  strides: this.strides,
1537
+ device: this.device,
1530
1538
  requiresGrad: this.requiresGrad
1531
1539
  });
1532
1540
  }
@@ -1549,7 +1557,8 @@ class Tensor {
1549
1557
  to(device) {
1550
1558
  const backend = Tensor.backends.get(device);
1551
1559
  if (backend && backend.transfer) {
1552
- return backend.transfer(this);
1560
+ backend.transfer(this);
1561
+ return this;
1553
1562
  }
1554
1563
  throw new Error(`No device found to transfer tensor to or a handler is not implemented for device.`);
1555
1564
  }
package/dist/nn.d.ts ADDED
@@ -0,0 +1,23 @@
1
import { Tensor, TensorValue } from "./core";
/**
 * Fully connected layer. `forward` computes `input @ weight^T` and then adds
 * `bias` when one was created.
 */
declare class Linear {
    /** Shape [outFeatures, inFeatures]; created with `requiresGrad: true`. */
    weight: Tensor;
    /** Shape [outFeatures]; only present when constructed with `bias` true. */
    bias?: Tensor;
    /**
     * Parameters are created via `Tensor.uniform` with bounds ±1/sqrt(inFeatures).
     *
     * @param inFeatures  Expected size of the last dimension of `input`.
     * @param outFeatures Number of output features.
     * @param bias        Whether to create `bias` (defaults to true).
     * @param device      Device passed to the parameter tensors.
     */
    constructor(inFeatures: number, outFeatures: number, bias?: boolean, device?: string);
    /**
     * @param input Tensor, or a raw value converted with `Tensor.forceTensor`.
     * @returns `input @ weight^T (+ bias)`.
     */
    forward(input: Tensor | TensorValue): Tensor;
}
/**
 * Single recurrent step computing
 * `tanh(input @ weightIH^T + hidden @ weightHH^T + biasIH + biasHH)`
 * (bias terms only when they were created).
 */
declare class RNNCell {
    /** Shape [hiddenSize, inputSize]. */
    weightIH: Tensor;
    /** Shape [hiddenSize, hiddenSize]. */
    weightHH: Tensor;
    /** Shape [hiddenSize]; only present when constructed with `bias` true. */
    biasIH?: Tensor;
    /** Shape [hiddenSize]; only present when constructed with `bias` true. */
    biasHH?: Tensor;
    /**
     * Parameters are created via `Tensor.uniform` with bounds ±1/sqrt(hiddenSize)
     * and `requiresGrad: true`.
     *
     * @param inputSize  Expected size of the last dimension of `input`.
     * @param hiddenSize Size of the hidden state.
     * @param bias       Whether to create the two bias vectors (defaults to true).
     * @param device     Device passed to the parameter tensors.
     */
    constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
    /**
     * @param input  Current input; raw values are converted with `Tensor.forceTensor`.
     * @param hidden Previous hidden state; raw values are converted likewise.
     * @returns The next hidden state.
     */
    forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
}
export declare const nn: {
    Linear: typeof Linear;
    RNNCell: typeof RNNCell;
    state: {
        /**
         * Recursively collects Tensor instances reachable through the own
         * enumerable properties of `model` (nested objects and arrays included).
         *
         * `model` is typed `unknown` rather than `any`: any value may be passed,
         * and no property access is implied on the caller's side.
         *
         * @param model   Object graph to scan (e.g. a layer or a model object).
         * @param visited Objects already walked, used to stop on cycles; normally omitted.
         * @returns The tensors found.
         */
        getParameters(model: unknown, visited?: WeakSet<object>): Tensor[];
    };
};
export {};
package/dist/nn.js ADDED
@@ -0,0 +1,74 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.nn = void 0;
4
+ const core_1 = require("./core");
5
/**
 * Fully connected (affine) layer.
 *
 * `weight` has shape [outFeatures, inFeatures] and the optional `bias` has
 * shape [outFeatures]. Both are created through `Tensor.uniform` with bounds
 * ±1/sqrt(inFeatures) and `requiresGrad: true`.
 */
class Linear {
    weight;
    bias;
    /**
     * @param inFeatures  Expected size of the last dimension of the input.
     * @param outFeatures Number of output features.
     * @param bias        When truthy (default), a bias vector is created.
     * @param device      Device forwarded to the parameter tensors.
     */
    constructor(inFeatures, outFeatures, bias = true, device) {
        const limit = 1 / Math.sqrt(inFeatures);
        // Each parameter gets its own freshly built options object.
        const makeParam = (shape) => core_1.Tensor.uniform(shape, -limit, limit, { requiresGrad: true, device });
        this.weight = makeParam([outFeatures, inFeatures]);
        if (bias) {
            this.bias = makeParam([outFeatures]);
        }
    }
    /**
     * Computes `input @ weight^T`, adding `bias` when present.
     *
     * @param input Tensor, or a raw value converted with `Tensor.forceTensor`.
     * @returns The layer output.
     */
    forward(input) {
        const x = core_1.Tensor.forceTensor(input);
        const projected = x.matmul(this.weight.t());
        return this.bias ? projected.add(this.bias) : projected;
    }
}
24
/**
 * Single recurrent step:
 * `tanh(input @ weightIH^T + hidden @ weightHH^T + biasIH + biasHH)`.
 *
 * Weights have shapes [hiddenSize, inputSize] and [hiddenSize, hiddenSize];
 * the optional biases have shape [hiddenSize]. All are created through
 * `Tensor.uniform` with bounds ±1/sqrt(hiddenSize) and `requiresGrad: true`.
 */
class RNNCell {
    weightIH;
    weightHH;
    biasIH;
    biasHH;
    /**
     * @param inputSize  Expected size of the last dimension of the input.
     * @param hiddenSize Size of the hidden state.
     * @param bias       When truthy (default), both bias vectors are created.
     * @param device     Device forwarded to the parameter tensors.
     */
    constructor(inputSize, hiddenSize, bias = true, device) {
        const limit = 1 / Math.sqrt(hiddenSize);
        const makeParam = (shape) => core_1.Tensor.uniform(shape, -limit, limit, { requiresGrad: true, device });
        // Creation order (IH, HH, then biases) is kept fixed.
        this.weightIH = makeParam([hiddenSize, inputSize]);
        this.weightHH = makeParam([hiddenSize, hiddenSize]);
        if (bias) {
            this.biasIH = makeParam([hiddenSize]);
            this.biasHH = makeParam([hiddenSize]);
        }
    }
    /**
     * @param input  Current input; raw values go through `Tensor.forceTensor`.
     * @param hidden Previous hidden state; raw values go through `Tensor.forceTensor`.
     * @returns The next hidden state.
     */
    forward(input, hidden) {
        const x = core_1.Tensor.forceTensor(input);
        const h = core_1.Tensor.forceTensor(hidden);
        const fromInput = x.matmul(this.weightIH.t());
        const fromHidden = h.matmul(this.weightHH.t());
        const linear = fromInput.add(fromHidden);
        const preActivation = this.biasIH && this.biasHH
            ? linear.add(this.biasIH).add(this.biasHH)
            : linear;
        return preActivation.tanh();
    }
}
49
/**
 * Helpers for inspecting model objects.
 */
const state = {
    /**
     * Recursively collects every Tensor reachable from `model` through own
     * enumerable properties (nested objects and arrays included).
     *
     * Every object, Tensors included, is visited at most once. A tensor shared
     * between layers (e.g. tied weights) is therefore returned once instead of
     * once per reference, which would make an optimizer update it repeatedly.
     * Reference cycles also terminate. Function-valued properties are not
     * descended into.
     *
     * @param model   Object graph to scan. A Tensor passed directly is returned
     *                as the sole parameter; `null` and primitives yield [].
     * @param visited Objects already walked; callers normally omit it.
     * @returns Tensors in first-encountered order, without duplicates.
     */
    getParameters(model, visited = new WeakSet()) {
        // WeakSet only accepts objects; null and primitives have no parameters.
        const walkable = (typeof model === "object" && model !== null) || typeof model === "function";
        if (!walkable || visited.has(model)) {
            return [];
        }
        visited.add(model);
        // Tensors are leaves: record them (above) so shared ones are deduplicated,
        // but never descend into their internals (grad, children, ...).
        if (model instanceof core_1.Tensor) {
            return [model];
        }
        const parameters = [];
        for (const key in model) {
            // Borrow hasOwnProperty from the prototype so Object.create(null)
            // dictionaries and objects shadowing it are handled.
            if (!Object.prototype.hasOwnProperty.call(model, key))
                continue;
            const value = model[key];
            if (typeof value === "object" && value !== null) {
                parameters.push(...state.getParameters(value, visited));
            }
        }
        return parameters;
    }
};
70
+ exports.nn = {
71
+ Linear,
72
+ RNNCell,
73
+ state
74
+ };
package/dist/optim.js CHANGED
@@ -20,9 +20,8 @@ class SGD {
20
20
  }
21
21
  step() {
22
22
  for (const param of this.params) {
23
- if (!param.grad) {
24
- throw new Error("Can not apply SGD on empty grad");
25
- }
23
+ if (!param.grad || !param.requiresGrad)
24
+ continue;
26
25
  let grad = param.grad.detach(), detachedParam = param.detach();
27
26
  // Apply weight decay (L2 regularization)
28
27
  if (this.weightDecay !== 0) {
@@ -80,9 +79,8 @@ class Adam {
80
79
  const biasCorrection1 = 1 - Math.pow(beta1, this.stepCount);
81
80
  const biasCorrection2 = 1 - Math.pow(beta2, this.stepCount);
82
81
  for (const param of this.params) {
83
- if (!param.grad) {
84
- throw new Error("Can not apply Adam on empty grad");
85
- }
82
+ if (!param.grad || !param.requiresGrad)
83
+ continue;
86
84
  let grad = param.grad.detach(), detachedParam = param.detach();
87
85
  // Apply weight decay (L2 regularization)
88
86
  if (this.weightDecay !== 0) {
package/index.d.ts CHANGED
@@ -1,2 +1,3 @@
1
1
  export * from "./dist/core";
2
2
  export * from "./dist/optim";
3
+ export * from "./dist/nn";
package/index.js CHANGED
@@ -1,4 +1,5 @@
1
1
  module.exports = {
2
2
  ...require("./dist/core"),
3
- ...require("./dist/optim")
3
+ ...require("./dist/optim"),
4
+ ...require("./dist/nn")
4
5
  };
package/package.json CHANGED
@@ -1,7 +1,7 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.4.2",
4
- "description": "A small Torch-like deep learning framework for Javascript with tensor and autograd support",
3
+ "version": "0.5.1",
4
+ "description": "A small Torch-like deep learning framework for Javascript",
5
5
  "main": "index.js",
6
6
  "scripts": {
7
7
  "test": "echo \"Error: no test specified\" && exit 1"