catniff 0.3.0 → 0.3.1
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- package/dist/optim.d.ts +19 -0
- package/dist/optim.js +65 -1
- package/package.json +1 -1
package/dist/optim.d.ts
CHANGED
@@ -17,7 +17,26 @@ declare class SGD {
     constructor(params: Tensor[], options?: SGDOptions);
     step(): void;
 }
+export interface AdamOptions {
+    lr?: number;
+    betas?: [number, number];
+    eps?: number;
+    weightDecay?: number;
+}
+declare class Adam {
+    params: Tensor[];
+    momentumBuffers: Map<Tensor, Tensor>;
+    velocityBuffers: Map<Tensor, Tensor>;
+    stepCount: number;
+    lr: number;
+    betas: [number, number];
+    eps: number;
+    weightDecay: number;
+    constructor(params: Tensor[], options?: AdamOptions);
+    step(): void;
+}
 export declare class Optim {
     static SGD: typeof SGD;
+    static Adam: typeof Adam;
 }
 export {};
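
The new declarations mirror SGD's surface: construct with a parameter array plus an options object, then call step() once gradients are populated. A minimal usage sketch against these types follows; the Tensor constructor call and the way w.grad gets filled in are assumptions about catniff's core API, not shown in this diff:

    import { Tensor, Optim } from "catniff"; // assumes the package root re-exports both

    // Hypothetical parameter tensor; the exact constructor signature is an assumption.
    const w = new Tensor([0.5, -0.3], { requiresGrad: true });

    const adam = new Optim.Adam([w], {
        lr: 1e-3,            // defaults to 0.001
        betas: [0.9, 0.999], // defaults to [0.9, 0.999]
        eps: 1e-8,
        weightDecay: 0,
    });

    // After a backward pass has populated w.grad, apply one update in place:
    adam.step();
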
package/dist/optim.js
CHANGED
@@ -1,6 +1,7 @@
 "use strict";
 Object.defineProperty(exports, "__esModule", { value: true });
 exports.Optim = void 0;
+const core_1 = require("./core");
 class SGD {
     params;
     momentumBuffers = new Map();
@@ -19,7 +20,7 @@ class SGD {
     }
     step() {
         for (const param of this.params) {
-            if (
+            if (!param.grad) {
                 throw new Error("Can not apply SGD on empty grad");
             }
             let grad = param.grad.detach(), detachedParam = param.detach();
@@ -55,7 +56,70 @@ class SGD {
         }
     }
 }
+class Adam {
+    params;
+    momentumBuffers = new Map(); // First moment (m_t)
+    velocityBuffers = new Map(); // Second moment (v_t)
+    stepCount = 0;
+    lr;
+    betas;
+    eps;
+    weightDecay;
+    constructor(params, options) {
+        this.params = params;
+        this.lr = options?.lr || 0.001;
+        this.betas = options?.betas || [0.9, 0.999];
+        this.eps = options?.eps || 1e-8;
+        this.weightDecay = options?.weightDecay || 0;
+    }
+    step() {
+        this.stepCount++;
+        const beta1 = this.betas[0];
+        const beta2 = this.betas[1];
+        // Bias correction factors
+        const biasCorrection1 = 1 - Math.pow(beta1, this.stepCount);
+        const biasCorrection2 = 1 - Math.pow(beta2, this.stepCount);
+        for (const param of this.params) {
+            if (!param.grad) {
+                throw new Error("Can not apply Adam on empty grad");
+            }
+            let grad = param.grad.detach(), detachedParam = param.detach();
+            // Apply weight decay (L2 regularization)
+            if (this.weightDecay !== 0) {
+                grad = grad.add(detachedParam.mul(this.weightDecay));
+            }
+            // Get or initialize first moment buffer (momentum)
+            let momentumBuffer = this.momentumBuffers.get(param);
+            if (!momentumBuffer) {
+                momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                this.momentumBuffers.set(param, momentumBuffer);
+            }
+            // Get or initialize second moment buffer (velocity)
+            let velocityBuffer = this.velocityBuffers.get(param);
+            if (!velocityBuffer) {
+                velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                this.velocityBuffers.set(param, velocityBuffer);
+            }
+            // Update biased first moment estimate: m_t = β1 * m_{t-1} + (1 - β1) * g_t
+            momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
+            this.momentumBuffers.set(param, momentumBuffer);
+            // Update biased second moment estimate: v_t = β2 * v_{t-1} + (1 - β2) * g_t^2
+            velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
+            this.velocityBuffers.set(param, velocityBuffer);
+            // Compute bias-corrected first moment: m̂_t = m_t / (1 - β1^t)
+            const correctedMomentum = momentumBuffer.div(biasCorrection1);
+            // Compute bias-corrected second moment: v̂_t = v_t / (1 - β2^t)
+            const correctedVelocity = velocityBuffer.div(biasCorrection2);
+            // Update parameters: θ_t = θ_{t-1} - α * m̂_t / (√v̂_t + ε)
+            const denom = correctedVelocity.sqrt().add(this.eps);
+            const stepSize = correctedMomentum.div(denom).mul(this.lr);
+            const newParam = detachedParam.sub(stepSize);
+            param.replace(newParam);
+        }
+    }
+}
 class Optim {
     static SGD = SGD;
+    static Adam = Adam;
 }
 exports.Optim = Optim;
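
One property of this update worth noting: because both moment buffers start at zero, the bias correction exactly cancels the (1 - β) factors on the first step, so the very first update has magnitude ≈ lr regardless of gradient scale. A quick scalar check of the same arithmetic as step() above, in plain numbers rather than catniff tensors:

    const [beta1, beta2] = [0.9, 0.999];
    const lr = 0.001, eps = 1e-8;
    const g = 2.0;                             // gradient at step 1

    const m = (1 - beta1) * g;                 // m_1 = 0.2   (buffer starts at 0)
    const v = (1 - beta2) * g * g;             // v_1 = 0.004
    const mHat = m / (1 - beta1 ** 1);         // 0.2 / 0.1     = 2.0 (= g)
    const vHat = v / (1 - beta2 ** 1);         // 0.004 / 0.001 = 4.0 (= g^2)
    const update = lr * mHat / (Math.sqrt(vHat) + eps); // ≈ 0.001 = lr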