catniff 0.8.13 → 0.8.15
This diff shows the changes between publicly released versions of the package as published to a supported registry, and is provided for informational purposes only.
- package/README.md +2 -3
- package/dist/core.js +12 -12
- package/dist/lrscheduler.d.ts +4 -3
- package/dist/lrscheduler.js +8 -1
- package/dist/optim.d.ts +33 -15
- package/dist/optim.js +170 -133
- package/package.json +1 -1
package/README.md
CHANGED

```diff
@@ -143,9 +143,8 @@ All available APIs are in [`./src/`](./src/) if you want to dig deeper.
 
 * More general tensor ops.
 * More general neural net APIs.
-* GPU acceleration.
-*
-* Bug fixes.
+* GPU acceleration, possibly through WebGPU, Libtorch bindings, or CUDA.
+* Proper optimization.
 * More detailed documentation.
 * Code refactoring.
 * Proper tests.
```
package/dist/core.js
CHANGED

```diff
@@ -20,21 +20,21 @@ class Tensor {
     static createGraph = false;
     constructor(value, options = {}) {
         // Memory buffer
-        this.dtype = options.dtype
+        this.dtype = options.dtype ?? "float32";
         const flatValue = Tensor.flattenValue(value);
         const TypedArrayConstructor = dtype_1.TypedArray[this.dtype];
         this.value = flatValue instanceof TypedArrayConstructor ? flatValue : TypedArrayConstructor.from(flatValue);
         // Tensor metadata
-        this.shape = options.shape
-        this.strides = options.strides
-        this.offset = options.offset
-        this.numel = options.numel
-        this.device = options.device
+        this.shape = options.shape ?? Tensor.getShape(value);
+        this.strides = options.strides ?? Tensor.getStrides(this.shape);
+        this.offset = options.offset ?? 0;
+        this.numel = options.numel ?? Tensor.shapeToSize(this.shape);
+        this.device = options.device ?? "cpu";
         // Autograd data
         this.grad = options.grad;
         this.requiresGrad = options.requiresGrad ?? false;
-        this.gradFn = options.gradFn
-        this.children = options.children
+        this.gradFn = options.gradFn ?? (() => { });
+        this.children = options.children ?? [];
         // Move to device in-place
         this.to_(this.device);
     }
@@ -622,14 +622,14 @@ class Tensor {
         return this;
         const newShape = [];
         const newStrides = [];
-        let newOffset = this.offset
+        let newOffset = this.offset;
         // Pad ranges to match tensor dimensions
         const paddedRanges = [...ranges];
         while (paddedRanges.length < this.shape.length) {
             paddedRanges.push([]);
         }
         for (let i = 0; i < this.shape.length; i++) {
-            const range = paddedRanges[i]
+            const range = paddedRanges[i] ?? [];
             const dimSize = this.shape[i];
             const stride = this.strides[i];
             // Default values
@@ -675,7 +675,7 @@ class Tensor {
         const originalCoords = new Array(slicedCoords.length);
         for (let dim = 0; dim < slicedCoords.length; dim++) {
             const coord = slicedCoords[dim];
-            const range = paddedRanges[dim]
+            const range = paddedRanges[dim] ?? [];
             const start = range[0] ?? 0;
             const step = range[2] ?? 1;
             const normalizedStart = start < 0 ? start + this.shape[dim] : start;
@@ -2504,7 +2504,7 @@ class Tensor {
            visited.add(node);
            // Reset grad to zeros if specified
            if (zeroGrad) {
-                node.grad
+                delete node.grad;
            }
            for (let child of node.children)
                build(child);
```
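In 0.8.15 every constructor option now falls back through `??` to an explicit default instead of being assigned as-is. A minimal usage sketch of what that means for callers (the top-level import path is an assumption; the diff only shows `./dist/core.js`):

```js
// Hypothetical usage sketch; the require path is assumed, not shown in the diff.
const { Tensor } = require("catniff");

// No options passed: each field resolves through its new `??` fallback.
const t = new Tensor([[1, 2], [3, 4]]);
console.log(t.dtype);  // "float32" (options.dtype ?? "float32")
console.log(t.device); // "cpu"     (options.device ?? "cpu")
console.log(t.offset); // 0         (options.offset ?? 0)
// t.shape, t.strides and t.numel are likewise derived from the value via
// Tensor.getShape, Tensor.getStrides and Tensor.shapeToSize.
```

Since `??` only triggers on `null`/`undefined`, explicitly passed falsy options such as `offset: 0` are still honored.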
package/dist/lrscheduler.d.ts
CHANGED

```diff
@@ -1,11 +1,12 @@
-import {
+import { OptimizerWithLR } from "./optim";
 export declare class StepLR {
-    optimizer:
+    optimizer: OptimizerWithLR;
     stepSize: number;
     gamma: number;
     lastEpoch: number;
     baseLR: number;
-
+    baseGroupLRs: number[];
+    constructor(optimizer: OptimizerWithLR, stepSize: number, gamma?: number, lastEpoch?: number);
     step(epoch?: number): void;
 }
 export declare const LRScheduler: {
```
package/dist/lrscheduler.js
CHANGED

```diff
@@ -7,12 +7,14 @@ class StepLR {
     gamma;
     lastEpoch;
     baseLR;
+    baseGroupLRs;
     constructor(optimizer, stepSize, gamma = 0.1, lastEpoch = -1) {
         this.optimizer = optimizer;
         this.stepSize = stepSize;
         this.gamma = gamma;
         this.lastEpoch = lastEpoch;
-        this.baseLR =
+        this.baseLR = optimizer.lr;
+        this.baseGroupLRs = this.optimizer.paramGroups.map(paramGroup => paramGroup.lr ?? this.optimizer.lr);
     }
     step(epoch) {
         if (typeof epoch === "undefined") {
@@ -22,6 +24,11 @@ class StepLR {
         else {
             this.lastEpoch = epoch;
         }
+        // Update LR of each group
+        for (let index = 0; index < this.baseGroupLRs.length; index++) {
+            this.optimizer.paramGroups[index].lr = this.baseGroupLRs[index] * this.gamma ** Math.floor(epoch / this.stepSize);
+        }
+        // Update default LR
         this.optimizer.lr = this.baseLR * this.gamma ** Math.floor(epoch / this.stepSize);
     }
 }
```
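Taken together, the lrscheduler changes make `StepLR` capture a base LR per param group at construction (`baseGroupLRs`) and decay each group alongside the optimizer's default LR. A sketch of the resulting behavior, assuming `Tensor`, `SGD` and `StepLR` are reachable from the package root:

```js
// Hypothetical usage sketch; import path assumed.
const { Tensor, SGD, StepLR } = require("catniff");

const w = new Tensor([1, 2, 3], { requiresGrad: true });
const b = new Tensor([0], { requiresGrad: true });
const opt = new SGD([
    { params: [w], lr: 0.1 }, // group LR, captured into baseGroupLRs
    { params: [b] },          // no group LR: falls back to opt.lr (0.01)
], { lr: 0.01 });
const sched = new StepLR(opt, 10, 0.5); // stepSize = 10, gamma = 0.5

sched.step(10); // decay factor: 0.5 ** Math.floor(10 / 10) = 0.5
// opt.paramGroups[0].lr === 0.05, opt.paramGroups[1].lr === 0.005, opt.lr === 0.005
```

Because each decay is computed from the base LRs captured at construction rather than the current values, repeated `step` calls at the same epoch do not compound.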
package/dist/optim.d.ts
CHANGED

```diff
@@ -1,12 +1,15 @@
 import { Tensor } from "./core";
-export interface
-
+export interface BaseParamGroup {
+    params: Tensor[];
+    [key: string]: any;
 }
 export declare abstract class BaseOptimizer {
-
+    paramGroups: BaseParamGroup[];
+    constructor(params: Tensor[] | BaseParamGroup[]);
+    zeroGrad(del?: boolean): void;
+}
+export interface OptimizerWithLR extends BaseOptimizer {
     lr: number;
-    constructor(params: Tensor[], options?: BaseOptimizerOptions);
-    zeroGrad(): void;
 }
 export interface SGDOptions {
     lr?: number;
@@ -15,13 +18,18 @@ export interface SGDOptions {
     weightDecay?: number;
     nesterov?: boolean;
 }
+export interface SGDParamGroup extends SGDOptions {
+    params: Tensor[];
+}
 export declare class SGD extends BaseOptimizer {
-
+    paramGroups: SGDParamGroup[];
+    lr: number;
     momentum: number;
     dampening: number;
     weightDecay: number;
     nesterov: boolean;
-
+    momentumBuffers: Map<Tensor, Tensor>;
+    constructor(params: Tensor[] | SGDParamGroup[], options?: SGDOptions);
     step(): void;
 }
 export interface AdamOptions {
@@ -30,14 +38,19 @@ export interface AdamOptions {
     eps?: number;
     weightDecay?: number;
 }
+export interface AdamParamGroup extends AdamOptions {
+    params: Tensor[];
+}
 export declare class Adam extends BaseOptimizer {
-
-
-    stepCount: number;
+    paramGroups: AdamParamGroup[];
+    lr: number;
     betas: [number, number];
     eps: number;
     weightDecay: number;
-
+    momentumBuffers: Map<Tensor, Tensor>;
+    velocityBuffers: Map<Tensor, Tensor>;
+    stepCounts: Map<Tensor, number>;
+    constructor(params: Tensor[] | AdamParamGroup[], options?: AdamOptions);
     step(): void;
 }
 export interface AdamWOptions {
@@ -46,14 +59,19 @@ export interface AdamWOptions {
     eps?: number;
     weightDecay?: number;
 }
+export interface AdamWParamGroup extends AdamWOptions {
+    params: Tensor[];
+}
 export declare class AdamW extends BaseOptimizer {
-
-
-    stepCount: number;
+    paramGroups: AdamWParamGroup[];
+    lr: number;
     betas: [number, number];
     eps: number;
     weightDecay: number;
-
+    momentumBuffers: Map<Tensor, Tensor>;
+    velocityBuffers: Map<Tensor, Tensor>;
+    stepCounts: Map<Tensor, number>;
+    constructor(params: Tensor[] | AdamWParamGroup[], options?: AdamWOptions);
     step(): void;
 }
 export declare const Optim: {
```
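The new declarations formalize PyTorch-style param groups: every optimizer constructor accepts either a flat `Tensor[]` or an array of `*ParamGroup` objects whose per-group fields override the instance-wide defaults. A minimal sketch of the two call forms (tensor values are placeholders):

```js
// Hypothetical sketch; import path assumed.
const { Tensor, SGD, Adam } = require("catniff");
const w = new Tensor([1, 2], { requiresGrad: true });
const b = new Tensor([0], { requiresGrad: true });

// 1. Flat tensor array: BaseOptimizer wraps it as a single group,
//    i.e. paramGroups becomes [{ params: [w, b] }].
const opt1 = new SGD([w, b], { lr: 0.01 });

// 2. Explicit groups: step() reads each hyperparameter via
//    `paramGroup.lr ?? this.lr` and so on, so unset fields fall back
//    to the optimizer-wide values.
const opt2 = new Adam([
    { params: [w], lr: 3e-4 },
    { params: [b], weightDecay: 0 },
]);
```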
package/dist/optim.js
CHANGED

```diff
@@ -3,183 +3,220 @@ Object.defineProperty(exports, "__esModule", { value: true });
 exports.Optim = exports.AdamW = exports.Adam = exports.SGD = exports.BaseOptimizer = void 0;
 const core_1 = require("./core");
 class BaseOptimizer {
-
-
-
-
-
+    paramGroups;
+    constructor(params) {
+        if (params[0] instanceof core_1.Tensor) {
+            this.paramGroups = [{ params: params }];
+        }
+        else {
+            this.paramGroups = params;
+        }
     }
-    zeroGrad() {
-        for (let index = 0; index < this.
-        const
-        param
+    zeroGrad(del = true) {
+        for (let index = 0; index < this.paramGroups.length; index++) {
+            const paramGroup = this.paramGroups[index];
+            for (const param of paramGroup.params) {
+                if (del) {
+                    delete param.grad;
+                }
+                else {
+                    param.grad = core_1.Tensor.zerosLike(param);
+                }
+            }
         }
     }
 }
 exports.BaseOptimizer = BaseOptimizer;
 class SGD extends BaseOptimizer {
-
+    lr;
     momentum;
     dampening;
     weightDecay;
     nesterov;
+    momentumBuffers = new Map();
     constructor(params, options) {
-        super(params
-        this.
-        this.
-        this.
-        this.
+        super(params);
+        this.lr = options?.lr ?? 0.001;
+        this.momentum = options?.momentum ?? 0;
+        this.dampening = options?.dampening ?? 0;
+        this.weightDecay = options?.weightDecay ?? 0;
+        this.nesterov = options?.nesterov ?? false;
     }
     step() {
-        for (const
-
-
-
-
-
-
-
-
-
-        if (
-
-                buf = grad.clone();
-                this.momentumBuffers.set(param, buf);
-            }
-            else {
-                // Update momentum buffer: buf = momentum * buf + (1 - dampening) * grad
-                buf = buf.mul(this.momentum).add(grad.mul(1 - this.dampening));
-                this.momentumBuffers.set(param, buf);
-            }
-            if (this.nesterov) {
-                // Nesterov momentum: grad = grad + momentum * buf
-                grad = grad.add(buf.mul(this.momentum));
+        for (const paramGroup of this.paramGroups) {
+            const lr = paramGroup.lr ?? this.lr;
+            const momentum = paramGroup.momentum ?? this.momentum;
+            const dampening = paramGroup.dampening ?? this.dampening;
+            const weightDecay = paramGroup.weightDecay ?? this.weightDecay;
+            const nesterov = paramGroup.nesterov ?? this.nesterov;
+            for (const param of paramGroup.params) {
+                if (!param.grad || !param.requiresGrad)
+                    continue;
+                let grad = param.grad.detach(), detachedParam = param.detach();
+                // Apply weight decay (L2 regularization)
+                if (weightDecay !== 0) {
+                    grad = grad.add(detachedParam.mul(weightDecay));
                 }
-
-
-
+                // Apply momentum
+                if (momentum !== 0) {
+                    let buf = this.momentumBuffers.get(param);
+                    if (!buf) {
+                        // First time: initialize momentum buffer with current gradient
+                        buf = grad.clone();
+                        this.momentumBuffers.set(param, buf);
+                    }
+                    else {
+                        // Update momentum buffer: buf = momentum * buf + (1 - dampening) * grad
+                        buf = buf.mul(momentum).add(grad.mul(1 - dampening));
+                        this.momentumBuffers.set(param, buf);
+                    }
+                    if (nesterov) {
+                        // Nesterov momentum: grad = grad + momentum * buf
+                        grad = grad.add(buf.mul(momentum));
+                    }
+                    else {
+                        // Standard momentum: use momentum buffer as gradient
+                        grad = buf;
+                    }
                 }
+                // Update parameter: param = param - lr * grad
+                const newParam = detachedParam.sub(grad.mul(lr));
+                param.replace(newParam);
             }
-            // Update parameter: param = param - lr * grad
-            const newParam = detachedParam.sub(grad.mul(this.lr));
-            param.replace(newParam);
         }
     }
 }
 exports.SGD = SGD;
 class Adam extends BaseOptimizer {
-
-    velocityBuffers = new Map(); // Second moment (v_t)
-    stepCount = 0;
+    lr;
     betas;
     eps;
     weightDecay;
+    momentumBuffers = new Map(); // First moment (m_t)
+    velocityBuffers = new Map(); // Second moment (v_t)
+    stepCounts = new Map();
     constructor(params, options) {
-        super(params
-        this.
-        this.
-        this.
+        super(params);
+        this.lr = options?.lr ?? 0.001;
+        this.betas = options?.betas ?? [0.9, 0.999];
+        this.eps = options?.eps ?? 1e-8;
+        this.weightDecay = options?.weightDecay ?? 0;
     }
     step() {
-        this.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        for (const paramGroup of this.paramGroups) {
+            const lr = paramGroup.lr ?? this.lr;
+            const betas = paramGroup.betas ?? this.betas;
+            const eps = paramGroup.eps ?? this.eps;
+            const weightDecay = paramGroup.weightDecay ?? this.weightDecay;
+            for (const param of paramGroup.params) {
+                if (!param.grad || !param.requiresGrad)
+                    continue;
+                // Get current step for param, initialize if has not step before
+                const stepCount = (this.stepCounts.get(param) ?? 0) + 1;
+                this.stepCounts.set(param, stepCount);
+                // Bias correction factors
+                const [beta1, beta2] = betas;
+                const biasCorrection1 = 1 - Math.pow(beta1, stepCount);
+                const biasCorrection2 = 1 - Math.pow(beta2, stepCount);
+                let grad = param.grad.detach(), detachedParam = param.detach();
+                // Apply weight decay (L2 regularization)
+                if (weightDecay !== 0) {
+                    grad = grad.add(detachedParam.mul(weightDecay));
+                }
+                // Get or initialize first moment buffer (momentum)
+                let momentumBuffer = this.momentumBuffers.get(param);
+                if (!momentumBuffer) {
+                    momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                    this.momentumBuffers.set(param, momentumBuffer);
+                }
+                // Get or initialize second moment buffer (velocity)
+                let velocityBuffer = this.velocityBuffers.get(param);
+                if (!velocityBuffer) {
+                    velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                    this.velocityBuffers.set(param, velocityBuffer);
+                }
+                // Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
+                momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
                 this.momentumBuffers.set(param, momentumBuffer);
-
-
-        let velocityBuffer = this.velocityBuffers.get(param);
-        if (!velocityBuffer) {
-            velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                // Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
+                velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
                 this.velocityBuffers.set(param, velocityBuffer);
+                // Compute bias-corrected first moment: m_hat_t = m_t / (1 - beta1^t)
+                const correctedMomentum = momentumBuffer.div(biasCorrection1);
+                // Compute bias-corrected second moment: v_hat_t = v_t / (1 - beta2^t)
+                const correctedVelocity = velocityBuffer.div(biasCorrection2);
+                // Update parameters: theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
+                const denom = correctedVelocity.sqrt().add(eps);
+                const stepSize = correctedMomentum.div(denom).mul(lr);
+                const newParam = detachedParam.sub(stepSize);
+                param.replace(newParam);
             }
-        // Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
-        momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
-        this.momentumBuffers.set(param, momentumBuffer);
-        // Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
-        velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
-        this.velocityBuffers.set(param, velocityBuffer);
-        // Compute bias-corrected first moment: m_hat_t = m_t / (1 - beta1^t)
-        const correctedMomentum = momentumBuffer.div(biasCorrection1);
-        // Compute bias-corrected second moment: v_hat_t = v_t / (1 - beta2^t)
-        const correctedVelocity = velocityBuffer.div(biasCorrection2);
-        // Update parameters: theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
-        const denom = correctedVelocity.sqrt().add(this.eps);
-        const stepSize = correctedMomentum.div(denom).mul(this.lr);
-        const newParam = detachedParam.sub(stepSize);
-        param.replace(newParam);
         }
     }
 }
 exports.Adam = Adam;
 class AdamW extends BaseOptimizer {
-
-    velocityBuffers = new Map(); // Second moment (v_t)
-    stepCount = 0;
+    lr;
     betas;
     eps;
     weightDecay;
+    momentumBuffers = new Map(); // First moment (m_t)
+    velocityBuffers = new Map(); // Second moment (v_t)
+    stepCounts = new Map();
     constructor(params, options) {
-        super(params
-        this.
-        this.
-        this.
+        super(params);
+        this.lr = options?.lr ?? 0.001;
+        this.betas = options?.betas ?? [0.9, 0.999];
+        this.eps = options?.eps ?? 1e-8;
+        this.weightDecay = options?.weightDecay ?? 0.01;
     }
     step() {
-        this.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        for (const paramGroup of this.paramGroups) {
+            const lr = paramGroup.lr ?? this.lr;
+            const betas = paramGroup.betas ?? this.betas;
+            const eps = paramGroup.eps ?? this.eps;
+            const weightDecay = paramGroup.weightDecay ?? this.weightDecay;
+            for (const param of paramGroup.params) {
+                if (!param.grad || !param.requiresGrad)
+                    continue;
+                // Get current step for param, initialize if has not step before
+                const stepCount = (this.stepCounts.get(param) ?? 0) + 1;
+                this.stepCounts.set(param, stepCount);
+                // Bias correction factors
+                const [beta1, beta2] = betas;
+                const biasCorrection1 = 1 - Math.pow(beta1, stepCount);
+                const biasCorrection2 = 1 - Math.pow(beta2, stepCount);
+                let grad = param.grad.detach(), detachedParam = param.detach();
+                // Apply weight decay (L2 regularization)
+                detachedParam = detachedParam.sub(detachedParam.mul(weightDecay).mul(lr));
+                // Get or initialize first moment buffer (momentum)
+                let momentumBuffer = this.momentumBuffers.get(param);
+                if (!momentumBuffer) {
+                    momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                    this.momentumBuffers.set(param, momentumBuffer);
+                }
+                // Get or initialize second moment buffer (velocity)
+                let velocityBuffer = this.velocityBuffers.get(param);
+                if (!velocityBuffer) {
+                    velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                    this.velocityBuffers.set(param, velocityBuffer);
+                }
+                // Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
+                momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
                 this.momentumBuffers.set(param, momentumBuffer);
-
-
-        let velocityBuffer = this.velocityBuffers.get(param);
-        if (!velocityBuffer) {
-            velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                // Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
+                velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
                 this.velocityBuffers.set(param, velocityBuffer);
+                // Compute bias-corrected first moment: m_hat_t = m_t / (1 - beta1^t)
+                const correctedMomentum = momentumBuffer.div(biasCorrection1);
+                // Compute bias-corrected second moment: v_hat_t = v_t / (1 - beta2^t)
+                const correctedVelocity = velocityBuffer.div(biasCorrection2);
+                // Update parameters: theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
+                const denom = correctedVelocity.sqrt().add(eps);
+                const stepSize = correctedMomentum.div(denom).mul(lr);
+                const newParam = detachedParam.sub(stepSize);
+                param.replace(newParam);
             }
-        // Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
-        momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
-        this.momentumBuffers.set(param, momentumBuffer);
-        // Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
-        velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
-        this.velocityBuffers.set(param, velocityBuffer);
-        // Compute bias-corrected first moment: m_hat_t = m_t / (1 - beta1^t)
-        const correctedMomentum = momentumBuffer.div(biasCorrection1);
-        // Compute bias-corrected second moment: v_hat_t = v_t / (1 - beta2^t)
-        const correctedVelocity = velocityBuffer.div(biasCorrection2);
-        // Update parameters: theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
-        const denom = correctedVelocity.sqrt().add(this.eps);
-        const stepSize = correctedMomentum.div(denom).mul(this.lr);
-        const newParam = detachedParam.sub(stepSize);
-        param.replace(newParam);
         }
     }
 }
```