catniff 0.2.16 → 0.3.0

package/README.md CHANGED
@@ -76,18 +76,18 @@ console.log(X.grad.val(), Y.grad.val());
 
 Full documentation is available in [`./docs/documentation.md`](./docs/documentation.md).
 
-All available APIs are in [`./src/core.ts`](./src/core.ts) if you want to dig deeper.
+All available APIs are in [`./src/`](./src/) if you want to dig deeper.
 
 ## Todos
 
 * Bug fixes.
 * More tensor ops.
-* More detailed documentation.
 * GPU acceleration.
+* Option to load more backends.
 * Some general neural net APIs.
+* More detailed documentation.
 * Code refactoring.
 * Proper tests.
-* Option to load more backends.
 
 ## Copyrights and License
 
package/dist/core.d.ts CHANGED
@@ -158,4 +158,7 @@ export declare class Tensor {
     backward(): void;
     val(): TensorValue;
     withGrad(requiresGrad: boolean): Tensor;
+    detach(): Tensor;
+    clone(): Tensor;
+    replace(other: Tensor, allowShapeMismatch?: boolean): Tensor;
 }
package/dist/core.js CHANGED
@@ -1456,7 +1456,7 @@ class Tensor {
         }
         return buildNested(this.value, this.shape, this.strides);
     }
-    // Returns a copy of the tensor with gradient turned on/off and detaches from autograd
+    // Returns a view of the tensor with gradient turned on/off and detaches from autograd
     withGrad(requiresGrad) {
         return new Tensor(this.value, {
             shape: this.shape,
@@ -1464,5 +1464,34 @@ class Tensor {
             requiresGrad
         });
     }
+    // Returns a view of the tensor with gradient turned off and detaches from autograd
+    detach() {
+        return new Tensor(this.value, {
+            shape: this.shape,
+            strides: this.strides,
+            requiresGrad: false
+        });
+    }
+    // Returns a copy of the tensor (with new data allocation) and detaches from autograd
+    clone() {
+        return new Tensor(typeof this.value === "number" ? this.value : [...this.value], {
+            shape: this.shape,
+            strides: this.strides,
+            requiresGrad: this.requiresGrad
+        });
+    }
+    // Returns this tensor with value replaced with the value of another tensor
+    replace(other, allowShapeMismatch = false) {
+        // Verify shape
+        if (!allowShapeMismatch) {
+            for (let index = 0; index < this.shape.length; index++) {
+                if (this.shape[index] !== other.shape[index]) {
+                    throw new Error("Shape mismatch when trying to do tensor value replacement");
+                }
+            }
+        }
+        this.value = other.value;
+        return this;
+    }
 }
 exports.Tensor = Tensor;
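
The three new `Tensor` methods differ mainly in how they treat the underlying data: `detach()` returns a view over the same buffer with gradients turned off, `clone()` allocates a fresh copy of the data, and `replace()` swaps this tensor's value for another tensor's value in place. A minimal usage sketch follows; it assumes the `Tensor` constructor accepts a plain value and an options object with `requiresGrad` (the internal calls above only show the shape/strides form), so treat the construction line as hypothetical.

```js
const { Tensor } = require("catniff");

// Hypothetical construction: exact public constructor arguments are an assumption
const x = new Tensor(2, { requiresGrad: true });

const view = x.detach(); // shares x's data, requiresGrad forced to false, detached from autograd
const copy = x.clone();  // new data allocation, keeps requiresGrad, detached from autograd
x.replace(copy);         // overwrites x's value in place with copy's value (shapes must match)

console.log(view.val(), copy.val(), x.val());
```
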
package/dist/optim.d.ts ADDED
@@ -0,0 +1,23 @@
+import { Tensor } from "./core";
+export interface SGDOptions {
+    lr?: number;
+    momentum?: number;
+    dampening?: number;
+    weightDecay?: number;
+    nesterov?: boolean;
+}
+declare class SGD {
+    params: Tensor[];
+    momentumBuffers: Map<Tensor, Tensor>;
+    lr: number;
+    momentum: number;
+    dampening: number;
+    weightDecay: number;
+    nesterov: boolean;
+    constructor(params: Tensor[], options?: SGDOptions);
+    step(): void;
+}
+export declare class Optim {
+    static SGD: typeof SGD;
+}
+export {};
package/dist/optim.js ADDED
@@ -0,0 +1,61 @@
+"use strict";
+Object.defineProperty(exports, "__esModule", { value: true });
+exports.Optim = void 0;
+class SGD {
+    params;
+    momentumBuffers = new Map();
+    lr;
+    momentum;
+    dampening;
+    weightDecay;
+    nesterov;
+    constructor(params, options) {
+        this.params = params;
+        this.lr = options?.lr || 0.001;
+        this.momentum = options?.momentum || 0;
+        this.dampening = options?.dampening || 0;
+        this.weightDecay = options?.weightDecay || 0;
+        this.nesterov = options?.nesterov || false;
+    }
+    step() {
+        for (const param of this.params) {
+            if (typeof param.grad === "undefined") {
+                throw new Error("Can not apply SGD on empty grad");
+            }
+            let grad = param.grad.detach(), detachedParam = param.detach();
+            // Apply weight decay (L2 regularization)
+            if (this.weightDecay !== 0) {
+                grad = grad.add(detachedParam.mul(this.weightDecay));
+            }
+            // Apply momentum
+            if (this.momentum !== 0) {
+                let buf = this.momentumBuffers.get(param);
+                if (!buf) {
+                    // First time: initialize momentum buffer with current gradient
+                    buf = grad.clone();
+                    this.momentumBuffers.set(param, buf);
+                }
+                else {
+                    // Update momentum buffer: buf = momentum * buf + (1 - dampening) * grad
+                    buf = buf.mul(this.momentum).add(grad.mul(1 - this.dampening));
+                    this.momentumBuffers.set(param, buf);
+                }
+                if (this.nesterov) {
+                    // Nesterov momentum: grad = grad + momentum * buf
+                    grad = grad.add(buf.mul(this.momentum));
+                }
+                else {
+                    // Standard momentum: use momentum buffer as gradient
+                    grad = buf;
+                }
+            }
+            // Update parameter: param = param - lr * grad
+            const newParam = detachedParam.sub(grad.mul(this.lr));
+            param.replace(newParam);
+        }
+    }
+}
+class Optim {
+    static SGD = SGD;
+}
+exports.Optim = Optim;
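
Putting the new optimizer together with the tensor additions, here is a hedged sketch of a single gradient step. It mirrors the update the code above performs (`param = param - lr * grad`, with an optional momentum buffer `buf = momentum * buf + (1 - dampening) * grad`), but the tensor construction and the scalar arguments to `sub`/`mul` are assumptions, not documented API.

```js
const { Tensor, Optim } = require("catniff");

// Hypothetical scalar parameter; public constructor arguments are an assumption
const w = new Tensor(3, { requiresGrad: true });

// loss = (w - 1)^2, built from the same ops the optimizer itself uses
const diff = w.sub(1);
const loss = diff.mul(diff);
loss.backward();

// One SGD step with momentum; the parameter is updated in place via replace()
const opt = new Optim.SGD([w], { lr: 0.1, momentum: 0.9 });
opt.step();

console.log(w.val()); // first step: roughly 3 - 0.1 * 2 * (3 - 1) = 2.6
```
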
package/index.d.ts CHANGED
@@ -1 +1,2 @@
 export * from "./dist/core";
+export * from "./dist/optim";
package/index.js CHANGED
@@ -1,3 +1,4 @@
 module.exports = {
-    ...require("./dist/core")
+    ...require("./dist/core"),
+    ...require("./dist/optim")
 };
package/package.json CHANGED
@@ -1,6 +1,6 @@
 {
   "name": "catniff",
-  "version": "0.2.16",
+  "version": "0.3.0",
   "description": "A small Torch-like deep learning framework for Javascript with tensor and autograd support",
   "main": "index.js",
   "scripts": {