catniff 0.8.14 → 0.8.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +2 -3
- package/dist/core.js +1 -1
- package/dist/lrscheduler.d.ts +25 -1
- package/dist/lrscheduler.js +71 -11
- package/dist/optim.d.ts +30 -17
- package/dist/optim.js +154 -121
- package/package.json +1 -1
package/README.md
CHANGED
|
@@ -143,9 +143,8 @@ All available APIs are in [`./src/`](./src/) if you want to dig deeper.
|
|
|
143
143
|
|
|
144
144
|
* More general tensor ops.
|
|
145
145
|
* More general neural net APIs.
|
|
146
|
-
* GPU acceleration.
|
|
147
|
-
*
|
|
148
|
-
* Bug fixes.
|
|
146
|
+
* GPU acceleration, possibly through WebGPU, Libtorch bindings, or CUDA.
|
|
147
|
+
* Proper optimization.
|
|
149
148
|
* More detailed documentation.
|
|
150
149
|
* Code refactoring.
|
|
151
150
|
* Proper tests.
|
package/dist/core.js
CHANGED
package/dist/lrscheduler.d.ts
CHANGED
|
@@ -5,9 +5,33 @@ export declare class StepLR {
|
|
|
5
5
|
gamma: number;
|
|
6
6
|
lastEpoch: number;
|
|
7
7
|
baseLR: number;
|
|
8
|
+
baseGroupLRs: number[];
|
|
8
9
|
constructor(optimizer: OptimizerWithLR, stepSize: number, gamma?: number, lastEpoch?: number);
|
|
9
|
-
step(
|
|
10
|
+
step(): void;
|
|
11
|
+
}
|
|
12
|
+
export declare class LinearLR {
|
|
13
|
+
optimizer: OptimizerWithLR;
|
|
14
|
+
startFactor: number;
|
|
15
|
+
endFactor: number;
|
|
16
|
+
totalIters: number;
|
|
17
|
+
lastEpoch: number;
|
|
18
|
+
baseLR: number;
|
|
19
|
+
baseGroupLRs: number[];
|
|
20
|
+
constructor(optimizer: OptimizerWithLR, startFactor?: number, endFactor?: number, totalIters?: number, lastEpoch?: number);
|
|
21
|
+
step(): void;
|
|
22
|
+
}
|
|
23
|
+
export declare class CosineAnnealingLR {
|
|
24
|
+
optimizer: OptimizerWithLR;
|
|
25
|
+
TMax: number;
|
|
26
|
+
etaMin: number;
|
|
27
|
+
lastEpoch: number;
|
|
28
|
+
baseLR: number;
|
|
29
|
+
baseGroupLRs: number[];
|
|
30
|
+
constructor(optimizer: OptimizerWithLR, TMax: number, etaMin?: number, lastEpoch?: number);
|
|
31
|
+
step(): void;
|
|
10
32
|
}
|
|
11
33
|
export declare const LRScheduler: {
|
|
12
34
|
StepLR: typeof StepLR;
|
|
35
|
+
LinearLR: typeof LinearLR;
|
|
36
|
+
CosineAnnealingLR: typeof CosineAnnealingLR;
|
|
13
37
|
};
|
package/dist/lrscheduler.js
CHANGED
|
@@ -1,31 +1,91 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.LRScheduler = exports.StepLR = void 0;
|
|
3
|
+
exports.LRScheduler = exports.CosineAnnealingLR = exports.LinearLR = exports.StepLR = void 0;
|
|
4
4
|
class StepLR {
|
|
5
5
|
optimizer;
|
|
6
6
|
stepSize;
|
|
7
7
|
gamma;
|
|
8
8
|
lastEpoch;
|
|
9
9
|
baseLR;
|
|
10
|
+
baseGroupLRs;
|
|
10
11
|
constructor(optimizer, stepSize, gamma = 0.1, lastEpoch = -1) {
|
|
11
12
|
this.optimizer = optimizer;
|
|
12
13
|
this.stepSize = stepSize;
|
|
13
14
|
this.gamma = gamma;
|
|
14
15
|
this.lastEpoch = lastEpoch;
|
|
15
|
-
this.baseLR =
|
|
16
|
+
this.baseLR = optimizer.lr;
|
|
17
|
+
this.baseGroupLRs = this.optimizer.paramGroups.map(paramGroup => paramGroup.lr ?? this.optimizer.lr);
|
|
16
18
|
}
|
|
17
|
-
step(
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
19
|
+
step() {
|
|
20
|
+
this.lastEpoch++;
|
|
21
|
+
// Update LR of each group
|
|
22
|
+
for (let index = 0; index < this.baseGroupLRs.length; index++) {
|
|
23
|
+
this.optimizer.paramGroups[index].lr = this.baseGroupLRs[index] * this.gamma ** Math.floor(this.lastEpoch / this.stepSize);
|
|
21
24
|
}
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
}
|
|
25
|
-
this.optimizer.lr = this.baseLR * this.gamma ** Math.floor(epoch / this.stepSize);
|
|
25
|
+
// Update default LR
|
|
26
|
+
this.optimizer.lr = this.baseLR * this.gamma ** Math.floor(this.lastEpoch / this.stepSize);
|
|
26
27
|
}
|
|
27
28
|
}
|
|
28
29
|
exports.StepLR = StepLR;
|
|
30
|
+
class LinearLR {
|
|
31
|
+
optimizer;
|
|
32
|
+
startFactor;
|
|
33
|
+
endFactor;
|
|
34
|
+
totalIters;
|
|
35
|
+
lastEpoch;
|
|
36
|
+
baseLR;
|
|
37
|
+
baseGroupLRs;
|
|
38
|
+
constructor(optimizer, startFactor = 0.3333333333333333, endFactor = 1, totalIters = 5, lastEpoch = -1) {
|
|
39
|
+
this.optimizer = optimizer;
|
|
40
|
+
this.startFactor = startFactor;
|
|
41
|
+
this.endFactor = endFactor;
|
|
42
|
+
this.totalIters = totalIters;
|
|
43
|
+
this.lastEpoch = lastEpoch;
|
|
44
|
+
this.baseLR = optimizer.lr;
|
|
45
|
+
this.baseGroupLRs = this.optimizer.paramGroups.map(paramGroup => paramGroup.lr ?? this.optimizer.lr);
|
|
46
|
+
}
|
|
47
|
+
step() {
|
|
48
|
+
this.lastEpoch++;
|
|
49
|
+
// Clamp under total allowed iterations
|
|
50
|
+
const t = Math.min(this.lastEpoch, this.totalIters);
|
|
51
|
+
// Precalculate factor
|
|
52
|
+
const factor = this.startFactor + (t / this.totalIters) * (this.endFactor - this.startFactor);
|
|
53
|
+
// Update LR of each group
|
|
54
|
+
for (let index = 0; index < this.baseGroupLRs.length; index++) {
|
|
55
|
+
this.optimizer.paramGroups[index].lr = this.baseGroupLRs[index] * factor;
|
|
56
|
+
}
|
|
57
|
+
// Update default LR
|
|
58
|
+
this.optimizer.lr = this.baseLR * factor;
|
|
59
|
+
}
|
|
60
|
+
}
|
|
61
|
+
exports.LinearLR = LinearLR;
|
|
62
|
+
class CosineAnnealingLR {
|
|
63
|
+
optimizer;
|
|
64
|
+
TMax;
|
|
65
|
+
etaMin;
|
|
66
|
+
lastEpoch;
|
|
67
|
+
baseLR;
|
|
68
|
+
baseGroupLRs;
|
|
69
|
+
constructor(optimizer, TMax, etaMin = 0, lastEpoch = -1) {
|
|
70
|
+
this.optimizer = optimizer;
|
|
71
|
+
this.TMax = TMax;
|
|
72
|
+
this.etaMin = etaMin;
|
|
73
|
+
this.lastEpoch = lastEpoch;
|
|
74
|
+
this.baseLR = optimizer.lr;
|
|
75
|
+
this.baseGroupLRs = this.optimizer.paramGroups.map(paramGroup => paramGroup.lr ?? this.optimizer.lr);
|
|
76
|
+
}
|
|
77
|
+
step() {
|
|
78
|
+
this.lastEpoch++;
|
|
79
|
+
const cosine = (1 + Math.cos((this.lastEpoch * Math.PI) / this.TMax)) / 2;
|
|
80
|
+
for (let index = 0; index < this.baseGroupLRs.length; index++) {
|
|
81
|
+
this.optimizer.paramGroups[index].lr = this.etaMin + (this.baseGroupLRs[index] - this.etaMin) * cosine;
|
|
82
|
+
}
|
|
83
|
+
this.optimizer.lr = this.etaMin + (this.baseLR - this.etaMin) * cosine;
|
|
84
|
+
}
|
|
85
|
+
}
|
|
86
|
+
exports.CosineAnnealingLR = CosineAnnealingLR;
|
|
29
87
|
exports.LRScheduler = {
|
|
30
|
-
StepLR
|
|
88
|
+
StepLR,
|
|
89
|
+
LinearLR,
|
|
90
|
+
CosineAnnealingLR
|
|
31
91
|
};
|
package/dist/optim.d.ts
CHANGED
|
@@ -1,11 +1,15 @@
|
|
|
1
1
|
import { Tensor } from "./core";
|
|
2
|
-
export interface
|
|
3
|
-
|
|
2
|
+
export interface BaseParamGroup {
|
|
3
|
+
params: Tensor[];
|
|
4
|
+
[key: string]: any;
|
|
4
5
|
}
|
|
5
6
|
export declare abstract class BaseOptimizer {
|
|
6
|
-
|
|
7
|
-
constructor(params: Tensor[]
|
|
8
|
-
zeroGrad(): void;
|
|
7
|
+
paramGroups: BaseParamGroup[];
|
|
8
|
+
constructor(params: Tensor[] | BaseParamGroup[]);
|
|
9
|
+
zeroGrad(del?: boolean): void;
|
|
10
|
+
}
|
|
11
|
+
export interface OptimizerWithLR extends BaseOptimizer {
|
|
12
|
+
lr: number;
|
|
9
13
|
}
|
|
10
14
|
export interface SGDOptions {
|
|
11
15
|
lr?: number;
|
|
@@ -14,17 +18,18 @@ export interface SGDOptions {
|
|
|
14
18
|
weightDecay?: number;
|
|
15
19
|
nesterov?: boolean;
|
|
16
20
|
}
|
|
17
|
-
export interface
|
|
18
|
-
|
|
21
|
+
export interface SGDParamGroup extends SGDOptions {
|
|
22
|
+
params: Tensor[];
|
|
19
23
|
}
|
|
20
24
|
export declare class SGD extends BaseOptimizer {
|
|
25
|
+
paramGroups: SGDParamGroup[];
|
|
21
26
|
lr: number;
|
|
22
|
-
momentumBuffers: Map<Tensor, Tensor>;
|
|
23
27
|
momentum: number;
|
|
24
28
|
dampening: number;
|
|
25
29
|
weightDecay: number;
|
|
26
30
|
nesterov: boolean;
|
|
27
|
-
|
|
31
|
+
momentumBuffers: Map<Tensor, Tensor>;
|
|
32
|
+
constructor(params: Tensor[] | SGDParamGroup[], options?: SGDOptions);
|
|
28
33
|
step(): void;
|
|
29
34
|
}
|
|
30
35
|
export interface AdamOptions {
|
|
@@ -33,15 +38,19 @@ export interface AdamOptions {
|
|
|
33
38
|
eps?: number;
|
|
34
39
|
weightDecay?: number;
|
|
35
40
|
}
|
|
41
|
+
export interface AdamParamGroup extends AdamOptions {
|
|
42
|
+
params: Tensor[];
|
|
43
|
+
}
|
|
36
44
|
export declare class Adam extends BaseOptimizer {
|
|
45
|
+
paramGroups: AdamParamGroup[];
|
|
37
46
|
lr: number;
|
|
38
|
-
momentumBuffers: Map<Tensor, Tensor>;
|
|
39
|
-
velocityBuffers: Map<Tensor, Tensor>;
|
|
40
|
-
stepCount: number;
|
|
41
47
|
betas: [number, number];
|
|
42
48
|
eps: number;
|
|
43
49
|
weightDecay: number;
|
|
44
|
-
|
|
50
|
+
momentumBuffers: Map<Tensor, Tensor>;
|
|
51
|
+
velocityBuffers: Map<Tensor, Tensor>;
|
|
52
|
+
stepCounts: Map<Tensor, number>;
|
|
53
|
+
constructor(params: Tensor[] | AdamParamGroup[], options?: AdamOptions);
|
|
45
54
|
step(): void;
|
|
46
55
|
}
|
|
47
56
|
export interface AdamWOptions {
|
|
@@ -50,15 +59,19 @@ export interface AdamWOptions {
|
|
|
50
59
|
eps?: number;
|
|
51
60
|
weightDecay?: number;
|
|
52
61
|
}
|
|
62
|
+
export interface AdamWParamGroup extends AdamWOptions {
|
|
63
|
+
params: Tensor[];
|
|
64
|
+
}
|
|
53
65
|
export declare class AdamW extends BaseOptimizer {
|
|
66
|
+
paramGroups: AdamWParamGroup[];
|
|
54
67
|
lr: number;
|
|
55
|
-
momentumBuffers: Map<Tensor, Tensor>;
|
|
56
|
-
velocityBuffers: Map<Tensor, Tensor>;
|
|
57
|
-
stepCount: number;
|
|
58
68
|
betas: [number, number];
|
|
59
69
|
eps: number;
|
|
60
70
|
weightDecay: number;
|
|
61
|
-
|
|
71
|
+
momentumBuffers: Map<Tensor, Tensor>;
|
|
72
|
+
velocityBuffers: Map<Tensor, Tensor>;
|
|
73
|
+
stepCounts: Map<Tensor, number>;
|
|
74
|
+
constructor(params: Tensor[] | AdamWParamGroup[], options?: AdamWOptions);
|
|
62
75
|
step(): void;
|
|
63
76
|
}
|
|
64
77
|
export declare const Optim: {
|
package/dist/optim.js
CHANGED
|
@@ -3,27 +3,39 @@ Object.defineProperty(exports, "__esModule", { value: true });
|
|
|
3
3
|
exports.Optim = exports.AdamW = exports.Adam = exports.SGD = exports.BaseOptimizer = void 0;
|
|
4
4
|
const core_1 = require("./core");
|
|
5
5
|
class BaseOptimizer {
|
|
6
|
-
|
|
7
|
-
constructor(params
|
|
8
|
-
|
|
6
|
+
paramGroups;
|
|
7
|
+
constructor(params) {
|
|
8
|
+
if (params[0] instanceof core_1.Tensor) {
|
|
9
|
+
this.paramGroups = [{ params: params }];
|
|
10
|
+
}
|
|
11
|
+
else {
|
|
12
|
+
this.paramGroups = params;
|
|
13
|
+
}
|
|
9
14
|
}
|
|
10
|
-
zeroGrad() {
|
|
11
|
-
for (let index = 0; index < this.
|
|
12
|
-
const
|
|
13
|
-
param
|
|
15
|
+
zeroGrad(del = true) {
|
|
16
|
+
for (let index = 0; index < this.paramGroups.length; index++) {
|
|
17
|
+
const paramGroup = this.paramGroups[index];
|
|
18
|
+
for (const param of paramGroup.params) {
|
|
19
|
+
if (del) {
|
|
20
|
+
delete param.grad;
|
|
21
|
+
}
|
|
22
|
+
else {
|
|
23
|
+
param.grad = core_1.Tensor.zerosLike(param);
|
|
24
|
+
}
|
|
25
|
+
}
|
|
14
26
|
}
|
|
15
27
|
}
|
|
16
28
|
}
|
|
17
29
|
exports.BaseOptimizer = BaseOptimizer;
|
|
18
30
|
class SGD extends BaseOptimizer {
|
|
19
31
|
lr;
|
|
20
|
-
momentumBuffers = new Map();
|
|
21
32
|
momentum;
|
|
22
33
|
dampening;
|
|
23
34
|
weightDecay;
|
|
24
35
|
nesterov;
|
|
36
|
+
momentumBuffers = new Map();
|
|
25
37
|
constructor(params, options) {
|
|
26
|
-
super(params
|
|
38
|
+
super(params);
|
|
27
39
|
this.lr = options?.lr ?? 0.001;
|
|
28
40
|
this.momentum = options?.momentum ?? 0;
|
|
29
41
|
this.dampening = options?.dampening ?? 0;
|
|
@@ -31,159 +43,180 @@ class SGD extends BaseOptimizer {
|
|
|
31
43
|
this.nesterov = options?.nesterov ?? false;
|
|
32
44
|
}
|
|
33
45
|
step() {
|
|
34
|
-
for (const
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
if (
|
|
46
|
-
|
|
47
|
-
buf = grad.clone();
|
|
48
|
-
this.momentumBuffers.set(param, buf);
|
|
49
|
-
}
|
|
50
|
-
else {
|
|
51
|
-
// Update momentum buffer: buf = momentum * buf + (1 - dampening) * grad
|
|
52
|
-
buf = buf.mul(this.momentum).add(grad.mul(1 - this.dampening));
|
|
53
|
-
this.momentumBuffers.set(param, buf);
|
|
46
|
+
for (const paramGroup of this.paramGroups) {
|
|
47
|
+
const lr = paramGroup.lr ?? this.lr;
|
|
48
|
+
const momentum = paramGroup.momentum ?? this.momentum;
|
|
49
|
+
const dampening = paramGroup.dampening ?? this.dampening;
|
|
50
|
+
const weightDecay = paramGroup.weightDecay ?? this.weightDecay;
|
|
51
|
+
const nesterov = paramGroup.nesterov ?? this.nesterov;
|
|
52
|
+
for (const param of paramGroup.params) {
|
|
53
|
+
if (!param.grad || !param.requiresGrad)
|
|
54
|
+
continue;
|
|
55
|
+
let grad = param.grad.detach(), detachedParam = param.detach();
|
|
56
|
+
// Apply weight decay (L2 regularization)
|
|
57
|
+
if (weightDecay !== 0) {
|
|
58
|
+
grad = grad.add(detachedParam.mul(weightDecay));
|
|
54
59
|
}
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
60
|
+
// Apply momentum
|
|
61
|
+
if (momentum !== 0) {
|
|
62
|
+
let buf = this.momentumBuffers.get(param);
|
|
63
|
+
if (!buf) {
|
|
64
|
+
// First time: initialize momentum buffer with current gradient
|
|
65
|
+
buf = grad.clone();
|
|
66
|
+
this.momentumBuffers.set(param, buf);
|
|
67
|
+
}
|
|
68
|
+
else {
|
|
69
|
+
// Update momentum buffer: buf = momentum * buf + (1 - dampening) * grad
|
|
70
|
+
buf = buf.mul(momentum).add(grad.mul(1 - dampening));
|
|
71
|
+
this.momentumBuffers.set(param, buf);
|
|
72
|
+
}
|
|
73
|
+
if (nesterov) {
|
|
74
|
+
// Nesterov momentum: grad = grad + momentum * buf
|
|
75
|
+
grad = grad.add(buf.mul(momentum));
|
|
76
|
+
}
|
|
77
|
+
else {
|
|
78
|
+
// Standard momentum: use momentum buffer as gradient
|
|
79
|
+
grad = buf;
|
|
80
|
+
}
|
|
62
81
|
}
|
|
82
|
+
// Update parameter: param = param - lr * grad
|
|
83
|
+
const newParam = detachedParam.sub(grad.mul(lr));
|
|
84
|
+
param.replace(newParam);
|
|
63
85
|
}
|
|
64
|
-
// Update parameter: param = param - lr * grad
|
|
65
|
-
const newParam = detachedParam.sub(grad.mul(this.lr));
|
|
66
|
-
param.replace(newParam);
|
|
67
86
|
}
|
|
68
87
|
}
|
|
69
88
|
}
|
|
70
89
|
exports.SGD = SGD;
|
|
71
90
|
class Adam extends BaseOptimizer {
|
|
72
91
|
lr;
|
|
73
|
-
momentumBuffers = new Map(); // First moment (m_t)
|
|
74
|
-
velocityBuffers = new Map(); // Second moment (v_t)
|
|
75
|
-
stepCount = 0;
|
|
76
92
|
betas;
|
|
77
93
|
eps;
|
|
78
94
|
weightDecay;
|
|
95
|
+
momentumBuffers = new Map(); // First moment (m_t)
|
|
96
|
+
velocityBuffers = new Map(); // Second moment (v_t)
|
|
97
|
+
stepCounts = new Map();
|
|
79
98
|
constructor(params, options) {
|
|
80
|
-
super(params
|
|
99
|
+
super(params);
|
|
81
100
|
this.lr = options?.lr ?? 0.001;
|
|
82
101
|
this.betas = options?.betas ?? [0.9, 0.999];
|
|
83
102
|
this.eps = options?.eps ?? 1e-8;
|
|
84
103
|
this.weightDecay = options?.weightDecay ?? 0;
|
|
85
104
|
}
|
|
86
105
|
step() {
|
|
87
|
-
this.
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
106
|
+
for (const paramGroup of this.paramGroups) {
|
|
107
|
+
const lr = paramGroup.lr ?? this.lr;
|
|
108
|
+
const betas = paramGroup.betas ?? this.betas;
|
|
109
|
+
const eps = paramGroup.eps ?? this.eps;
|
|
110
|
+
const weightDecay = paramGroup.weightDecay ?? this.weightDecay;
|
|
111
|
+
for (const param of paramGroup.params) {
|
|
112
|
+
if (!param.grad || !param.requiresGrad)
|
|
113
|
+
continue;
|
|
114
|
+
// Get current step for param, initialize if has not step before
|
|
115
|
+
const stepCount = (this.stepCounts.get(param) ?? 0) + 1;
|
|
116
|
+
this.stepCounts.set(param, stepCount);
|
|
117
|
+
// Bias correction factors
|
|
118
|
+
const [beta1, beta2] = betas;
|
|
119
|
+
const biasCorrection1 = 1 - Math.pow(beta1, stepCount);
|
|
120
|
+
const biasCorrection2 = 1 - Math.pow(beta2, stepCount);
|
|
121
|
+
let grad = param.grad.detach(), detachedParam = param.detach();
|
|
122
|
+
// Apply weight decay (L2 regularization)
|
|
123
|
+
if (weightDecay !== 0) {
|
|
124
|
+
grad = grad.add(detachedParam.mul(weightDecay));
|
|
125
|
+
}
|
|
126
|
+
// Get or initialize first moment buffer (momentum)
|
|
127
|
+
let momentumBuffer = this.momentumBuffers.get(param);
|
|
128
|
+
if (!momentumBuffer) {
|
|
129
|
+
momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
|
|
130
|
+
this.momentumBuffers.set(param, momentumBuffer);
|
|
131
|
+
}
|
|
132
|
+
// Get or initialize second moment buffer (velocity)
|
|
133
|
+
let velocityBuffer = this.velocityBuffers.get(param);
|
|
134
|
+
if (!velocityBuffer) {
|
|
135
|
+
velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
|
|
136
|
+
this.velocityBuffers.set(param, velocityBuffer);
|
|
137
|
+
}
|
|
138
|
+
// Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
|
|
139
|
+
momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
|
|
105
140
|
this.momentumBuffers.set(param, momentumBuffer);
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
let velocityBuffer = this.velocityBuffers.get(param);
|
|
109
|
-
if (!velocityBuffer) {
|
|
110
|
-
velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
|
|
141
|
+
// Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
|
|
142
|
+
velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
|
|
111
143
|
this.velocityBuffers.set(param, velocityBuffer);
|
|
144
|
+
// Compute bias-corrected first moment: m_hat_t = m_t / (1 - beta1^t)
|
|
145
|
+
const correctedMomentum = momentumBuffer.div(biasCorrection1);
|
|
146
|
+
// Compute bias-corrected second moment: v_hat_t = v_t / (1 - beta2^t)
|
|
147
|
+
const correctedVelocity = velocityBuffer.div(biasCorrection2);
|
|
148
|
+
// Update parameters: theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
|
|
149
|
+
const denom = correctedVelocity.sqrt().add(eps);
|
|
150
|
+
const stepSize = correctedMomentum.div(denom).mul(lr);
|
|
151
|
+
const newParam = detachedParam.sub(stepSize);
|
|
152
|
+
param.replace(newParam);
|
|
112
153
|
}
|
|
113
|
-
// Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
|
|
114
|
-
momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
|
|
115
|
-
this.momentumBuffers.set(param, momentumBuffer);
|
|
116
|
-
// Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
|
|
117
|
-
velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
|
|
118
|
-
this.velocityBuffers.set(param, velocityBuffer);
|
|
119
|
-
// Compute bias-corrected first moment: m_hat_t = m_t / (1 - beta1^t)
|
|
120
|
-
const correctedMomentum = momentumBuffer.div(biasCorrection1);
|
|
121
|
-
// Compute bias-corrected second moment: v_hat_t = v_t / (1 - beta2^t)
|
|
122
|
-
const correctedVelocity = velocityBuffer.div(biasCorrection2);
|
|
123
|
-
// Update parameters: theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
|
|
124
|
-
const denom = correctedVelocity.sqrt().add(this.eps);
|
|
125
|
-
const stepSize = correctedMomentum.div(denom).mul(this.lr);
|
|
126
|
-
const newParam = detachedParam.sub(stepSize);
|
|
127
|
-
param.replace(newParam);
|
|
128
154
|
}
|
|
129
155
|
}
|
|
130
156
|
}
|
|
131
157
|
exports.Adam = Adam;
|
|
132
158
|
class AdamW extends BaseOptimizer {
|
|
133
159
|
lr;
|
|
134
|
-
momentumBuffers = new Map(); // First moment (m_t)
|
|
135
|
-
velocityBuffers = new Map(); // Second moment (v_t)
|
|
136
|
-
stepCount = 0;
|
|
137
160
|
betas;
|
|
138
161
|
eps;
|
|
139
162
|
weightDecay;
|
|
163
|
+
momentumBuffers = new Map(); // First moment (m_t)
|
|
164
|
+
velocityBuffers = new Map(); // Second moment (v_t)
|
|
165
|
+
stepCounts = new Map();
|
|
140
166
|
constructor(params, options) {
|
|
141
|
-
super(params
|
|
167
|
+
super(params);
|
|
142
168
|
this.lr = options?.lr ?? 0.001;
|
|
143
169
|
this.betas = options?.betas ?? [0.9, 0.999];
|
|
144
170
|
this.eps = options?.eps ?? 1e-8;
|
|
145
171
|
this.weightDecay = options?.weightDecay ?? 0.01;
|
|
146
172
|
}
|
|
147
173
|
step() {
|
|
148
|
-
this.
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
174
|
+
for (const paramGroup of this.paramGroups) {
|
|
175
|
+
const lr = paramGroup.lr ?? this.lr;
|
|
176
|
+
const betas = paramGroup.betas ?? this.betas;
|
|
177
|
+
const eps = paramGroup.eps ?? this.eps;
|
|
178
|
+
const weightDecay = paramGroup.weightDecay ?? this.weightDecay;
|
|
179
|
+
for (const param of paramGroup.params) {
|
|
180
|
+
if (!param.grad || !param.requiresGrad)
|
|
181
|
+
continue;
|
|
182
|
+
// Get current step for param, initialize if has not step before
|
|
183
|
+
const stepCount = (this.stepCounts.get(param) ?? 0) + 1;
|
|
184
|
+
this.stepCounts.set(param, stepCount);
|
|
185
|
+
// Bias correction factors
|
|
186
|
+
const [beta1, beta2] = betas;
|
|
187
|
+
const biasCorrection1 = 1 - Math.pow(beta1, stepCount);
|
|
188
|
+
const biasCorrection2 = 1 - Math.pow(beta2, stepCount);
|
|
189
|
+
let grad = param.grad.detach(), detachedParam = param.detach();
|
|
190
|
+
// Apply weight decay (L2 regularization)
|
|
191
|
+
detachedParam = detachedParam.sub(detachedParam.mul(weightDecay).mul(lr));
|
|
192
|
+
// Get or initialize first moment buffer (momentum)
|
|
193
|
+
let momentumBuffer = this.momentumBuffers.get(param);
|
|
194
|
+
if (!momentumBuffer) {
|
|
195
|
+
momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
|
|
196
|
+
this.momentumBuffers.set(param, momentumBuffer);
|
|
197
|
+
}
|
|
198
|
+
// Get or initialize second moment buffer (velocity)
|
|
199
|
+
let velocityBuffer = this.velocityBuffers.get(param);
|
|
200
|
+
if (!velocityBuffer) {
|
|
201
|
+
velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
|
|
202
|
+
this.velocityBuffers.set(param, velocityBuffer);
|
|
203
|
+
}
|
|
204
|
+
// Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
|
|
205
|
+
momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
|
|
164
206
|
this.momentumBuffers.set(param, momentumBuffer);
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
let velocityBuffer = this.velocityBuffers.get(param);
|
|
168
|
-
if (!velocityBuffer) {
|
|
169
|
-
velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
|
|
207
|
+
// Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
|
|
208
|
+
velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
|
|
170
209
|
this.velocityBuffers.set(param, velocityBuffer);
|
|
210
|
+
// Compute bias-corrected first moment: m_hat_t = m_t / (1 - beta1^t)
|
|
211
|
+
const correctedMomentum = momentumBuffer.div(biasCorrection1);
|
|
212
|
+
// Compute bias-corrected second moment: v_hat_t = v_t / (1 - beta2^t)
|
|
213
|
+
const correctedVelocity = velocityBuffer.div(biasCorrection2);
|
|
214
|
+
// Update parameters: theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
|
|
215
|
+
const denom = correctedVelocity.sqrt().add(eps);
|
|
216
|
+
const stepSize = correctedMomentum.div(denom).mul(lr);
|
|
217
|
+
const newParam = detachedParam.sub(stepSize);
|
|
218
|
+
param.replace(newParam);
|
|
171
219
|
}
|
|
172
|
-
// Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
|
|
173
|
-
momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
|
|
174
|
-
this.momentumBuffers.set(param, momentumBuffer);
|
|
175
|
-
// Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
|
|
176
|
-
velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
|
|
177
|
-
this.velocityBuffers.set(param, velocityBuffer);
|
|
178
|
-
// Compute bias-corrected first moment: m_hat_t = m_t / (1 - beta1^t)
|
|
179
|
-
const correctedMomentum = momentumBuffer.div(biasCorrection1);
|
|
180
|
-
// Compute bias-corrected second moment: v_hat_t = v_t / (1 - beta2^t)
|
|
181
|
-
const correctedVelocity = velocityBuffer.div(biasCorrection2);
|
|
182
|
-
// Update parameters: theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
|
|
183
|
-
const denom = correctedVelocity.sqrt().add(this.eps);
|
|
184
|
-
const stepSize = correctedMomentum.div(denom).mul(this.lr);
|
|
185
|
-
const newParam = detachedParam.sub(stepSize);
|
|
186
|
-
param.replace(newParam);
|
|
187
220
|
}
|
|
188
221
|
}
|
|
189
222
|
}
|