catniff 0.2.16 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -3
- package/dist/core.d.ts +3 -0
- package/dist/core.js +30 -1
- package/dist/optim.d.ts +23 -0
- package/dist/optim.js +61 -0
- package/index.d.ts +1 -0
- package/index.js +2 -1
- package/package.json +1 -1
package/README.md
CHANGED
```diff
@@ -76,18 +76,18 @@ console.log(X.grad.val(), Y.grad.val());
 
 Full documentation is available in [`./docs/documentation.md`](./docs/documentation.md).
 
-All available APIs are in [`./src
+All available APIs are in [`./src/`](./src/) if you want to dig deeper.
 
 ## Todos
 
 * Bug fixes.
 * More tensor ops.
-* More detailed documentation.
 * GPU acceleration.
+* Option to load more backends.
 * Some general neural net APIs.
+* More detailed documentation.
 * Code refactoring.
 * Proper tests.
-* Option to load more backends.
 
 ## Copyrights and License
 
```
package/dist/core.d.ts
CHANGED
package/dist/core.js
CHANGED
```diff
@@ -1456,7 +1456,7 @@ class Tensor {
         }
         return buildNested(this.value, this.shape, this.strides);
     }
-    // Returns a
+    // Returns a view of the tensor with gradient turned on/off and detaches from autograd
    withGrad(requiresGrad) {
        return new Tensor(this.value, {
            shape: this.shape,
@@ -1464,5 +1464,34 @@ class Tensor {
            requiresGrad
        });
    }
+    // Returns a view of the tensor with gradient turned off and detaches from autograd
+    detach() {
+        return new Tensor(this.value, {
+            shape: this.shape,
+            strides: this.strides,
+            requiresGrad: false
+        });
+    }
+    // Returns a copy of the tensor (with new data allocation) and detaches from autograd
+    clone() {
+        return new Tensor(typeof this.value === "number" ? this.value : [...this.value], {
+            shape: this.shape,
+            strides: this.strides,
+            requiresGrad: this.requiresGrad
+        });
+    }
+    // Returns this tensor with value replaced with the value of another tensor
+    replace(other, allowShapeMismatch = false) {
+        // Verify shape
+        if (!allowShapeMismatch) {
+            for (let index = 0; index < this.shape.length; index++) {
+                if (this.shape[index] !== other.shape[index]) {
+                    throw new Error("Shape mismatch when trying to do tensor value replacement");
+                }
+            }
+        }
+        this.value = other.value;
+        return this;
+    }
 }
 exports.Tensor = Tensor;
```
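Taken together, the new `Tensor` methods cover the usual parameter-update workflow: `detach()` returns a gradient-free view over the same data, `clone()` allocates a fresh copy, and `replace()` swaps a tensor's value in place so existing references stay valid. A minimal usage sketch, assuming `Tensor` is exported from the package root and that the constructor accepts a plain array plus a `requiresGrad` option as in the README examples:

```js
const { Tensor } = require("catniff");

// Hypothetical parameter; array construction follows the README examples.
const w = new Tensor([1, 2, 3], { requiresGrad: true });

const view = w.detach(); // same underlying value, requiresGrad forced to false
const copy = w.clone();  // new data allocation ([...value]), keeps requiresGrad

// replace() mutates w itself, so anything holding a reference to w sees the
// new value; shapes must match unless allowShapeMismatch is passed.
w.replace(new Tensor([0.9, 1.8, 2.7]));
console.log(w.val()); // [0.9, 1.8, 2.7]
```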
package/dist/optim.d.ts
ADDED
```diff
@@ -0,0 +1,23 @@
+import { Tensor } from "./core";
+export interface SGDOptions {
+    lr?: number;
+    momentum?: number;
+    dampening?: number;
+    weightDecay?: number;
+    nesterov?: boolean;
+}
+declare class SGD {
+    params: Tensor[];
+    momentumBuffers: Map<Tensor, Tensor>;
+    lr: number;
+    momentum: number;
+    dampening: number;
+    weightDecay: number;
+    nesterov: boolean;
+    constructor(params: Tensor[], options?: SGDOptions);
+    step(): void;
+}
+export declare class Optim {
+    static SGD: typeof SGD;
+}
+export {};
```
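The declaration mirrors the usual SGD hyperparameters: learning rate, classical momentum with dampening, L2 weight decay, and a Nesterov switch. A hedged construction sketch, assuming `Optim` is re-exported from the package root (which is what the one-line `index.d.ts` change suggests):

```js
const { Tensor, Optim } = require("catniff");

// Hypothetical parameter list; any tensors with requiresGrad set would do.
const w = new Tensor([0.5, -0.5], { requiresGrad: true });
const optimizer = new Optim.SGD([w], { lr: 0.01, momentum: 0.9, nesterov: true });
```

Note that the defaults in `optim.js` below are applied with `||` rather than `??`, so an explicit `lr: 0` silently falls back to `0.001`.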
package/dist/optim.js
ADDED
```diff
@@ -0,0 +1,61 @@
+"use strict";
+Object.defineProperty(exports, "__esModule", { value: true });
+exports.Optim = void 0;
+class SGD {
+    params;
+    momentumBuffers = new Map();
+    lr;
+    momentum;
+    dampening;
+    weightDecay;
+    nesterov;
+    constructor(params, options) {
+        this.params = params;
+        this.lr = options?.lr || 0.001;
+        this.momentum = options?.momentum || 0;
+        this.dampening = options?.dampening || 0;
+        this.weightDecay = options?.weightDecay || 0;
+        this.nesterov = options?.nesterov || false;
+    }
+    step() {
+        for (const param of this.params) {
+            if (typeof param.grad === "undefined") {
+                throw new Error("Can not apply SGD on empty grad");
+            }
+            let grad = param.grad.detach(), detachedParam = param.detach();
+            // Apply weight decay (L2 regularization)
+            if (this.weightDecay !== 0) {
+                grad = grad.add(detachedParam.mul(this.weightDecay));
+            }
+            // Apply momentum
+            if (this.momentum !== 0) {
+                let buf = this.momentumBuffers.get(param);
+                if (!buf) {
+                    // First time: initialize momentum buffer with current gradient
+                    buf = grad.clone();
+                    this.momentumBuffers.set(param, buf);
+                }
+                else {
+                    // Update momentum buffer: buf = momentum * buf + (1 - dampening) * grad
+                    buf = buf.mul(this.momentum).add(grad.mul(1 - this.dampening));
+                    this.momentumBuffers.set(param, buf);
+                }
+                if (this.nesterov) {
+                    // Nesterov momentum: grad = grad + momentum * buf
+                    grad = grad.add(buf.mul(this.momentum));
+                }
+                else {
+                    // Standard momentum: use momentum buffer as gradient
+                    grad = buf;
+                }
+            }
+            // Update parameter: param = param - lr * grad
+            const newParam = detachedParam.sub(grad.mul(this.lr));
+            param.replace(newParam);
+        }
+    }
+}
+class Optim {
+    static SGD = SGD;
+}
+exports.Optim = Optim;
```
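`step()` follows the same update rule as PyTorch's `torch.optim.SGD`: add weight decay, fold the gradient into a dampened momentum buffer, optionally apply the Nesterov correction, then update each parameter as `param = param - lr * grad` via `replace()`. A training-loop sketch under the assumption that catniff exposes a PyTorch-style `backward()` call that populates `.grad` (the README's `X.grad` usage implies one; check the docs for the exact name):

```js
const { Tensor, Optim } = require("catniff");

// Toy objective: minimize (w - 3)^2 elementwise. backward() is an assumption.
const w = new Tensor([0], { requiresGrad: true });
const opt = new Optim.SGD([w], { lr: 0.1, momentum: 0.9 });

for (let i = 0; i < 100; i++) {
    const diff = w.sub(3);       // sub/mul are the same ops step() itself relies on
    const loss = diff.mul(diff);
    loss.backward();             // assumed API: fills w.grad
    opt.step();                  // throws if any param.grad is still undefined
    w.grad = undefined;          // manual reset between steps (assumed; no zeroGrad shown)
}
```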
package/index.d.ts
CHANGED
package/index.js
CHANGED