catniff 0.8.14 → 0.8.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -143,9 +143,8 @@ All available APIs are in [`./src/`](./src/) if you want to dig deeper.
 
  * More general tensor ops.
  * More general neural net APIs.
- * GPU acceleration.
- * Comprehensive caching.
- * Bug fixes.
+ * GPU acceleration, possibly through WebGPU, Libtorch bindings, or CUDA.
+ * Proper optimization.
  * More detailed documentation.
  * Code refactoring.
  * Proper tests.
package/dist/core.js CHANGED
@@ -2504,7 +2504,7 @@ class Tensor {
         visited.add(node);
         // Reset grad to zeros if specified
         if (zeroGrad) {
-            node.grad = Tensor.zerosLike(node);
+            delete node.grad;
         }
         for (let child of node.children)
             build(child);
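This one-line change swaps zero-filling for deletion when `backward` is asked to reset gradients: instead of allocating a fresh zero tensor per node, the `grad` property is removed and treated as implicit zeros until the next accumulation (the unchanged context comment still reads "Reset grad to zeros if specified"). A minimal standalone sketch of the difference, using a simplified node type rather than catniff's actual `Tensor`:

```ts
// Simplified sketch; `SimpleNode` is illustrative, not catniff's Tensor class.
interface SimpleNode {
    grad?: number[];
    children: SimpleNode[];
}

// Pre-0.8.16 behavior: overwrite grad with a same-shaped zero tensor.
function resetGradOld(node: SimpleNode): void {
    node.grad = node.grad?.map(() => 0);
}

// 0.8.16 behavior: drop grad entirely so the allocation can be freed;
// downstream code must treat a missing grad as zeros.
function resetGradNew(node: SimpleNode): void {
    delete node.grad;
}
```

The two hunks that follow (their file header is missing from the extracted diff) are the LR scheduler's type declarations and compiled implementation, which gain LinearLR and CosineAnnealingLR alongside StepLR.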
@@ -5,9 +5,33 @@ export declare class StepLR {
     gamma: number;
     lastEpoch: number;
     baseLR: number;
+    baseGroupLRs: number[];
     constructor(optimizer: OptimizerWithLR, stepSize: number, gamma?: number, lastEpoch?: number);
-    step(epoch?: number): void;
+    step(): void;
+}
+export declare class LinearLR {
+    optimizer: OptimizerWithLR;
+    startFactor: number;
+    endFactor: number;
+    totalIters: number;
+    lastEpoch: number;
+    baseLR: number;
+    baseGroupLRs: number[];
+    constructor(optimizer: OptimizerWithLR, startFactor?: number, endFactor?: number, totalIters?: number, lastEpoch?: number);
+    step(): void;
+}
+export declare class CosineAnnealingLR {
+    optimizer: OptimizerWithLR;
+    TMax: number;
+    etaMin: number;
+    lastEpoch: number;
+    baseLR: number;
+    baseGroupLRs: number[];
+    constructor(optimizer: OptimizerWithLR, TMax: number, etaMin?: number, lastEpoch?: number);
+    step(): void;
 }
 export declare const LRScheduler: {
     StepLR: typeof StepLR;
+    LinearLR: typeof LinearLR;
+    CosineAnnealingLR: typeof CosineAnnealingLR;
 };
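Going by these declarations, the new schedulers follow the existing StepLR pattern: construct with an optimizer, then call `step()` once per epoch (the optional `epoch` argument to `step` is gone). A hypothetical usage sketch; the `catniff` import style and `params` are placeholders, only the scheduler and optimizer shapes come from this diff:

```ts
import { Tensor, Optim, LRScheduler } from "catniff"; // import style assumed

declare const params: Tensor[]; // model parameters, defined elsewhere

const opt = new Optim.SGD(params, { lr: 0.1 });
// Anneal from lr = 0.1 down to etaMin = 0.001 over TMax = 100 epochs.
const sched = new LRScheduler.CosineAnnealingLR(opt, 100, 0.001);

for (let epoch = 0; epoch < 100; epoch++) {
    // ...forward, backward, opt.step(), opt.zeroGrad() per batch...
    sched.step(); // updates opt.lr and every param group's lr
}
```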
@@ -1,31 +1,91 @@
 "use strict";
 Object.defineProperty(exports, "__esModule", { value: true });
-exports.LRScheduler = exports.StepLR = void 0;
+exports.LRScheduler = exports.CosineAnnealingLR = exports.LinearLR = exports.StepLR = void 0;
 class StepLR {
     optimizer;
     stepSize;
     gamma;
     lastEpoch;
     baseLR;
+    baseGroupLRs;
     constructor(optimizer, stepSize, gamma = 0.1, lastEpoch = -1) {
         this.optimizer = optimizer;
         this.stepSize = stepSize;
         this.gamma = gamma;
         this.lastEpoch = lastEpoch;
-        this.baseLR = this.optimizer.lr;
+        this.baseLR = optimizer.lr;
+        this.baseGroupLRs = this.optimizer.paramGroups.map(paramGroup => paramGroup.lr ?? this.optimizer.lr);
     }
-    step(epoch) {
-        if (typeof epoch === "undefined") {
-            this.lastEpoch++;
-            epoch = this.lastEpoch;
+    step() {
+        this.lastEpoch++;
+        // Update LR of each group
+        for (let index = 0; index < this.baseGroupLRs.length; index++) {
+            this.optimizer.paramGroups[index].lr = this.baseGroupLRs[index] * this.gamma ** Math.floor(this.lastEpoch / this.stepSize);
         }
-        else {
-            this.lastEpoch = epoch;
-        }
-        this.optimizer.lr = this.baseLR * this.gamma ** Math.floor(epoch / this.stepSize);
+        // Update default LR
+        this.optimizer.lr = this.baseLR * this.gamma ** Math.floor(this.lastEpoch / this.stepSize);
     }
 }
 exports.StepLR = StepLR;
+class LinearLR {
+    optimizer;
+    startFactor;
+    endFactor;
+    totalIters;
+    lastEpoch;
+    baseLR;
+    baseGroupLRs;
+    constructor(optimizer, startFactor = 0.3333333333333333, endFactor = 1, totalIters = 5, lastEpoch = -1) {
+        this.optimizer = optimizer;
+        this.startFactor = startFactor;
+        this.endFactor = endFactor;
+        this.totalIters = totalIters;
+        this.lastEpoch = lastEpoch;
+        this.baseLR = optimizer.lr;
+        this.baseGroupLRs = this.optimizer.paramGroups.map(paramGroup => paramGroup.lr ?? this.optimizer.lr);
+    }
+    step() {
+        this.lastEpoch++;
+        // Clamp under total allowed iterations
+        const t = Math.min(this.lastEpoch, this.totalIters);
+        // Precalculate factor
+        const factor = this.startFactor + (t / this.totalIters) * (this.endFactor - this.startFactor);
+        // Update LR of each group
+        for (let index = 0; index < this.baseGroupLRs.length; index++) {
+            this.optimizer.paramGroups[index].lr = this.baseGroupLRs[index] * factor;
+        }
+        // Update default LR
+        this.optimizer.lr = this.baseLR * factor;
+    }
+}
+exports.LinearLR = LinearLR;
+class CosineAnnealingLR {
+    optimizer;
+    TMax;
+    etaMin;
+    lastEpoch;
+    baseLR;
+    baseGroupLRs;
+    constructor(optimizer, TMax, etaMin = 0, lastEpoch = -1) {
+        this.optimizer = optimizer;
+        this.TMax = TMax;
+        this.etaMin = etaMin;
+        this.lastEpoch = lastEpoch;
+        this.baseLR = optimizer.lr;
+        this.baseGroupLRs = this.optimizer.paramGroups.map(paramGroup => paramGroup.lr ?? this.optimizer.lr);
+    }
+    step() {
+        this.lastEpoch++;
+        const cosine = (1 + Math.cos((this.lastEpoch * Math.PI) / this.TMax)) / 2;
+        for (let index = 0; index < this.baseGroupLRs.length; index++) {
+            this.optimizer.paramGroups[index].lr = this.etaMin + (this.baseGroupLRs[index] - this.etaMin) * cosine;
+        }
+        this.optimizer.lr = this.etaMin + (this.baseLR - this.etaMin) * cosine;
+    }
+}
+exports.CosineAnnealingLR = CosineAnnealingLR;
 exports.LRScheduler = {
-    StepLR
+    StepLR,
+    LinearLR,
+    CosineAnnealingLR
 };
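For reference, the three schedules above reduce to closed-form functions of the epoch counter `t` (`lastEpoch` after the increment). A standalone sketch restating them, with `base` standing in for the stored `baseLR` or `baseGroupLRs` entry:

```ts
// StepLR: decay by gamma every stepSize epochs.
const stepLR = (base: number, gamma: number, stepSize: number, t: number): number =>
    base * gamma ** Math.floor(t / stepSize);

// LinearLR: interpolate the multiplier from startFactor to endFactor,
// clamped once t reaches totalIters.
const linearLR = (base: number, startFactor: number, endFactor: number,
                  totalIters: number, t: number): number =>
    base * (startFactor + (Math.min(t, totalIters) / totalIters) * (endFactor - startFactor));

// CosineAnnealingLR: cosine curve from base down to etaMin over TMax epochs.
const cosineLR = (base: number, etaMin: number, TMax: number, t: number): number =>
    etaMin + (base - etaMin) * (1 + Math.cos((t * Math.PI) / TMax)) / 2;

console.log(stepLR(0.1, 0.1, 30, 65));       // ~0.001 (two decays by epoch 65)
console.log(linearLR(0.1, 1 / 3, 1, 5, 10)); // ~0.1 (factor clamped at totalIters)
console.log(cosineLR(0.1, 0, 100, 50));      // ~0.05 (halfway point of the curve)
```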
package/dist/optim.d.ts CHANGED
@@ -1,11 +1,15 @@
 import { Tensor } from "./core";
-export interface BaseOptimizerOptions {
-    lr?: number;
+export interface BaseParamGroup {
+    params: Tensor[];
+    [key: string]: any;
 }
 export declare abstract class BaseOptimizer {
-    params: Tensor[];
-    constructor(params: Tensor[], options?: BaseOptimizerOptions);
-    zeroGrad(): void;
+    paramGroups: BaseParamGroup[];
+    constructor(params: Tensor[] | BaseParamGroup[]);
+    zeroGrad(del?: boolean): void;
+}
+export interface OptimizerWithLR extends BaseOptimizer {
+    lr: number;
 }
 export interface SGDOptions {
     lr?: number;
@@ -14,17 +18,18 @@ export interface SGDOptions {
     weightDecay?: number;
     nesterov?: boolean;
 }
-export interface OptimizerWithLR extends BaseOptimizer {
-    lr: number;
+export interface SGDParamGroup extends SGDOptions {
+    params: Tensor[];
 }
 export declare class SGD extends BaseOptimizer {
+    paramGroups: SGDParamGroup[];
     lr: number;
-    momentumBuffers: Map<Tensor, Tensor>;
     momentum: number;
     dampening: number;
     weightDecay: number;
     nesterov: boolean;
-    constructor(params: Tensor[], options?: SGDOptions);
+    momentumBuffers: Map<Tensor, Tensor>;
+    constructor(params: Tensor[] | SGDParamGroup[], options?: SGDOptions);
     step(): void;
 }
 export interface AdamOptions {
@@ -33,15 +38,19 @@ export interface AdamOptions {
     eps?: number;
     weightDecay?: number;
 }
+export interface AdamParamGroup extends AdamOptions {
+    params: Tensor[];
+}
 export declare class Adam extends BaseOptimizer {
+    paramGroups: AdamParamGroup[];
     lr: number;
-    momentumBuffers: Map<Tensor, Tensor>;
-    velocityBuffers: Map<Tensor, Tensor>;
-    stepCount: number;
     betas: [number, number];
     eps: number;
     weightDecay: number;
-    constructor(params: Tensor[], options?: AdamOptions);
+    momentumBuffers: Map<Tensor, Tensor>;
+    velocityBuffers: Map<Tensor, Tensor>;
+    stepCounts: Map<Tensor, number>;
+    constructor(params: Tensor[] | AdamParamGroup[], options?: AdamOptions);
     step(): void;
 }
 export interface AdamWOptions {
@@ -50,15 +59,19 @@ export interface AdamWOptions {
     eps?: number;
     weightDecay?: number;
 }
+export interface AdamWParamGroup extends AdamWOptions {
+    params: Tensor[];
+}
 export declare class AdamW extends BaseOptimizer {
+    paramGroups: AdamWParamGroup[];
     lr: number;
-    momentumBuffers: Map<Tensor, Tensor>;
-    velocityBuffers: Map<Tensor, Tensor>;
-    stepCount: number;
     betas: [number, number];
     eps: number;
     weightDecay: number;
-    constructor(params: Tensor[], options?: AdamWOptions);
+    momentumBuffers: Map<Tensor, Tensor>;
+    velocityBuffers: Map<Tensor, Tensor>;
+    stepCounts: Map<Tensor, number>;
+    constructor(params: Tensor[] | AdamWParamGroup[], options?: AdamWOptions);
     step(): void;
 }
 export declare const Optim: {
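The practical effect of this API change is PyTorch-style parameter groups: each constructor still accepts a flat `Tensor[]`, but it can now also take groups with per-group hyperparameter overrides, and `zeroGrad` defaults to deleting gradients rather than zero-filling them. A hypothetical sketch; the parameter tensors and import style are placeholders, only the optimizer signatures come from this diff:

```ts
import { Tensor, Optim } from "catniff"; // import style assumed

declare const backboneParams: Tensor[]; // placeholders for real model params
declare const headParams: Tensor[];

// Missing group fields fall back to the optimizer-level options.
const opt = new Optim.AdamW(
    [
        { params: backboneParams, lr: 1e-5 },   // slow fine-tuning
        { params: headParams, weightDecay: 0 }, // inherits lr = 1e-3
    ],
    { lr: 1e-3, weightDecay: 0.01 },
);

opt.step();
opt.zeroGrad();      // new default: delete param.grad
opt.zeroGrad(false); // old behavior: zero-filled grads
```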
package/dist/optim.js CHANGED
@@ -3,27 +3,39 @@ Object.defineProperty(exports, "__esModule", { value: true });
 exports.Optim = exports.AdamW = exports.Adam = exports.SGD = exports.BaseOptimizer = void 0;
 const core_1 = require("./core");
 class BaseOptimizer {
-    params;
-    constructor(params, options) {
-        this.params = params;
+    paramGroups;
+    constructor(params) {
+        if (params[0] instanceof core_1.Tensor) {
+            this.paramGroups = [{ params: params }];
+        }
+        else {
+            this.paramGroups = params;
+        }
     }
-    zeroGrad() {
-        for (let index = 0; index < this.params.length; index++) {
-            const param = this.params[index];
-            param.grad = core_1.Tensor.zerosLike(param);
+    zeroGrad(del = true) {
+        for (let index = 0; index < this.paramGroups.length; index++) {
+            const paramGroup = this.paramGroups[index];
+            for (const param of paramGroup.params) {
+                if (del) {
+                    delete param.grad;
+                }
+                else {
+                    param.grad = core_1.Tensor.zerosLike(param);
+                }
+            }
         }
     }
 }
 exports.BaseOptimizer = BaseOptimizer;
 class SGD extends BaseOptimizer {
     lr;
-    momentumBuffers = new Map();
     momentum;
     dampening;
     weightDecay;
     nesterov;
+    momentumBuffers = new Map();
     constructor(params, options) {
-        super(params, options);
+        super(params);
         this.lr = options?.lr ?? 0.001;
         this.momentum = options?.momentum ?? 0;
         this.dampening = options?.dampening ?? 0;
@@ -31,159 +43,180 @@ class SGD extends BaseOptimizer {
         this.nesterov = options?.nesterov ?? false;
     }
     step() {
-        for (const param of this.params) {
-            if (!param.grad || !param.requiresGrad)
-                continue;
-            let grad = param.grad.detach(), detachedParam = param.detach();
-            // Apply weight decay (L2 regularization)
-            if (this.weightDecay !== 0) {
-                grad = grad.add(detachedParam.mul(this.weightDecay));
-            }
-            // Apply momentum
-            if (this.momentum !== 0) {
-                let buf = this.momentumBuffers.get(param);
-                if (!buf) {
-                    // First time: initialize momentum buffer with current gradient
-                    buf = grad.clone();
-                    this.momentumBuffers.set(param, buf);
-                }
-                else {
-                    // Update momentum buffer: buf = momentum * buf + (1 - dampening) * grad
-                    buf = buf.mul(this.momentum).add(grad.mul(1 - this.dampening));
-                    this.momentumBuffers.set(param, buf);
+        for (const paramGroup of this.paramGroups) {
+            const lr = paramGroup.lr ?? this.lr;
+            const momentum = paramGroup.momentum ?? this.momentum;
+            const dampening = paramGroup.dampening ?? this.dampening;
+            const weightDecay = paramGroup.weightDecay ?? this.weightDecay;
+            const nesterov = paramGroup.nesterov ?? this.nesterov;
+            for (const param of paramGroup.params) {
+                if (!param.grad || !param.requiresGrad)
+                    continue;
+                let grad = param.grad.detach(), detachedParam = param.detach();
+                // Apply weight decay (L2 regularization)
+                if (weightDecay !== 0) {
+                    grad = grad.add(detachedParam.mul(weightDecay));
                 }
-                if (this.nesterov) {
-                    // Nesterov momentum: grad = grad + momentum * buf
-                    grad = grad.add(buf.mul(this.momentum));
-                }
-                else {
-                    // Standard momentum: use momentum buffer as gradient
-                    grad = buf;
+                // Apply momentum
+                if (momentum !== 0) {
+                    let buf = this.momentumBuffers.get(param);
+                    if (!buf) {
+                        // First time: initialize momentum buffer with current gradient
+                        buf = grad.clone();
+                        this.momentumBuffers.set(param, buf);
+                    }
+                    else {
+                        // Update momentum buffer: buf = momentum * buf + (1 - dampening) * grad
+                        buf = buf.mul(momentum).add(grad.mul(1 - dampening));
+                        this.momentumBuffers.set(param, buf);
+                    }
+                    if (nesterov) {
+                        // Nesterov momentum: grad = grad + momentum * buf
+                        grad = grad.add(buf.mul(momentum));
+                    }
+                    else {
+                        // Standard momentum: use momentum buffer as gradient
+                        grad = buf;
+                    }
                 }
+                // Update parameter: param = param - lr * grad
+                const newParam = detachedParam.sub(grad.mul(lr));
+                param.replace(newParam);
             }
-            // Update parameter: param = param - lr * grad
-            const newParam = detachedParam.sub(grad.mul(this.lr));
-            param.replace(newParam);
         }
     }
 }
 exports.SGD = SGD;
 class Adam extends BaseOptimizer {
     lr;
-    momentumBuffers = new Map(); // First moment (m_t)
-    velocityBuffers = new Map(); // Second moment (v_t)
-    stepCount = 0;
     betas;
     eps;
     weightDecay;
+    momentumBuffers = new Map(); // First moment (m_t)
+    velocityBuffers = new Map(); // Second moment (v_t)
+    stepCounts = new Map();
     constructor(params, options) {
-        super(params, options);
+        super(params);
         this.lr = options?.lr ?? 0.001;
         this.betas = options?.betas ?? [0.9, 0.999];
         this.eps = options?.eps ?? 1e-8;
         this.weightDecay = options?.weightDecay ?? 0;
     }
     step() {
-        this.stepCount++;
-        const beta1 = this.betas[0];
-        const beta2 = this.betas[1];
-        // Bias correction factors
-        const biasCorrection1 = 1 - Math.pow(beta1, this.stepCount);
-        const biasCorrection2 = 1 - Math.pow(beta2, this.stepCount);
-        for (const param of this.params) {
-            if (!param.grad || !param.requiresGrad)
-                continue;
-            let grad = param.grad.detach(), detachedParam = param.detach();
-            // Apply weight decay (L2 regularization)
-            if (this.weightDecay !== 0) {
-                grad = grad.add(detachedParam.mul(this.weightDecay));
-            }
-            // Get or initialize first moment buffer (momentum)
-            let momentumBuffer = this.momentumBuffers.get(param);
-            if (!momentumBuffer) {
-                momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+        for (const paramGroup of this.paramGroups) {
+            const lr = paramGroup.lr ?? this.lr;
+            const betas = paramGroup.betas ?? this.betas;
+            const eps = paramGroup.eps ?? this.eps;
+            const weightDecay = paramGroup.weightDecay ?? this.weightDecay;
+            for (const param of paramGroup.params) {
+                if (!param.grad || !param.requiresGrad)
+                    continue;
+                // Get current step for param, initialize if has not step before
+                const stepCount = (this.stepCounts.get(param) ?? 0) + 1;
+                this.stepCounts.set(param, stepCount);
+                // Bias correction factors
+                const [beta1, beta2] = betas;
+                const biasCorrection1 = 1 - Math.pow(beta1, stepCount);
+                const biasCorrection2 = 1 - Math.pow(beta2, stepCount);
+                let grad = param.grad.detach(), detachedParam = param.detach();
+                // Apply weight decay (L2 regularization)
+                if (weightDecay !== 0) {
+                    grad = grad.add(detachedParam.mul(weightDecay));
+                }
+                // Get or initialize first moment buffer (momentum)
+                let momentumBuffer = this.momentumBuffers.get(param);
+                if (!momentumBuffer) {
+                    momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                    this.momentumBuffers.set(param, momentumBuffer);
+                }
+                // Get or initialize second moment buffer (velocity)
+                let velocityBuffer = this.velocityBuffers.get(param);
+                if (!velocityBuffer) {
+                    velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                    this.velocityBuffers.set(param, velocityBuffer);
+                }
+                // Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
+                momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
                 this.momentumBuffers.set(param, momentumBuffer);
-            }
-            // Get or initialize second moment buffer (velocity)
-            let velocityBuffer = this.velocityBuffers.get(param);
-            if (!velocityBuffer) {
-                velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                // Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
+                velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
                 this.velocityBuffers.set(param, velocityBuffer);
+                // Compute bias-corrected first moment: m_hat_t = m_t / (1 - beta1^t)
+                const correctedMomentum = momentumBuffer.div(biasCorrection1);
+                // Compute bias-corrected second moment: v_hat_t = v_t / (1 - beta2^t)
+                const correctedVelocity = velocityBuffer.div(biasCorrection2);
+                // Update parameters: theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
+                const denom = correctedVelocity.sqrt().add(eps);
+                const stepSize = correctedMomentum.div(denom).mul(lr);
+                const newParam = detachedParam.sub(stepSize);
+                param.replace(newParam);
             }
-            // Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
-            momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
-            this.momentumBuffers.set(param, momentumBuffer);
-            // Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
-            velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
-            this.velocityBuffers.set(param, velocityBuffer);
-            // Compute bias-corrected first moment: m_hat_t = m_t / (1 - beta1^t)
-            const correctedMomentum = momentumBuffer.div(biasCorrection1);
-            // Compute bias-corrected second moment: v_hat_t = v_t / (1 - beta2^t)
-            const correctedVelocity = velocityBuffer.div(biasCorrection2);
-            // Update parameters: theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
-            const denom = correctedVelocity.sqrt().add(this.eps);
-            const stepSize = correctedMomentum.div(denom).mul(this.lr);
-            const newParam = detachedParam.sub(stepSize);
-            param.replace(newParam);
         }
     }
 }
 exports.Adam = Adam;
 class AdamW extends BaseOptimizer {
     lr;
-    momentumBuffers = new Map(); // First moment (m_t)
-    velocityBuffers = new Map(); // Second moment (v_t)
-    stepCount = 0;
     betas;
     eps;
     weightDecay;
+    momentumBuffers = new Map(); // First moment (m_t)
+    velocityBuffers = new Map(); // Second moment (v_t)
+    stepCounts = new Map();
     constructor(params, options) {
-        super(params, options);
+        super(params);
         this.lr = options?.lr ?? 0.001;
         this.betas = options?.betas ?? [0.9, 0.999];
         this.eps = options?.eps ?? 1e-8;
         this.weightDecay = options?.weightDecay ?? 0.01;
     }
     step() {
-        this.stepCount++;
-        const beta1 = this.betas[0];
-        const beta2 = this.betas[1];
-        // Bias correction factors
-        const biasCorrection1 = 1 - Math.pow(beta1, this.stepCount);
-        const biasCorrection2 = 1 - Math.pow(beta2, this.stepCount);
-        for (const param of this.params) {
-            if (!param.grad || !param.requiresGrad)
-                continue;
-            let grad = param.grad.detach(), detachedParam = param.detach();
-            // Apply weight decay (L2 regularization)
-            detachedParam = detachedParam.sub(detachedParam.mul(this.weightDecay).mul(this.lr));
-            // Get or initialize first moment buffer (momentum)
-            let momentumBuffer = this.momentumBuffers.get(param);
-            if (!momentumBuffer) {
-                momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+        for (const paramGroup of this.paramGroups) {
+            const lr = paramGroup.lr ?? this.lr;
+            const betas = paramGroup.betas ?? this.betas;
+            const eps = paramGroup.eps ?? this.eps;
+            const weightDecay = paramGroup.weightDecay ?? this.weightDecay;
+            for (const param of paramGroup.params) {
+                if (!param.grad || !param.requiresGrad)
+                    continue;
+                // Get current step for param, initialize if has not step before
+                const stepCount = (this.stepCounts.get(param) ?? 0) + 1;
+                this.stepCounts.set(param, stepCount);
+                // Bias correction factors
+                const [beta1, beta2] = betas;
+                const biasCorrection1 = 1 - Math.pow(beta1, stepCount);
+                const biasCorrection2 = 1 - Math.pow(beta2, stepCount);
+                let grad = param.grad.detach(), detachedParam = param.detach();
+                // Apply weight decay (L2 regularization)
+                detachedParam = detachedParam.sub(detachedParam.mul(weightDecay).mul(lr));
+                // Get or initialize first moment buffer (momentum)
+                let momentumBuffer = this.momentumBuffers.get(param);
+                if (!momentumBuffer) {
+                    momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                    this.momentumBuffers.set(param, momentumBuffer);
+                }
+                // Get or initialize second moment buffer (velocity)
+                let velocityBuffer = this.velocityBuffers.get(param);
+                if (!velocityBuffer) {
+                    velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                    this.velocityBuffers.set(param, velocityBuffer);
+                }
+                // Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
+                momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
                 this.momentumBuffers.set(param, momentumBuffer);
-            }
-            // Get or initialize second moment buffer (velocity)
-            let velocityBuffer = this.velocityBuffers.get(param);
-            if (!velocityBuffer) {
-                velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                // Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
+                velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
                 this.velocityBuffers.set(param, velocityBuffer);
+                // Compute bias-corrected first moment: m_hat_t = m_t / (1 - beta1^t)
+                const correctedMomentum = momentumBuffer.div(biasCorrection1);
+                // Compute bias-corrected second moment: v_hat_t = v_t / (1 - beta2^t)
+                const correctedVelocity = velocityBuffer.div(biasCorrection2);
+                // Update parameters: theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
+                const denom = correctedVelocity.sqrt().add(eps);
+                const stepSize = correctedMomentum.div(denom).mul(lr);
+                const newParam = detachedParam.sub(stepSize);
+                param.replace(newParam);
             }
-            // Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
-            momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
-            this.momentumBuffers.set(param, momentumBuffer);
-            // Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
-            velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
-            this.velocityBuffers.set(param, velocityBuffer);
-            // Compute bias-corrected first moment: m_hat_t = m_t / (1 - beta1^t)
-            const correctedMomentum = momentumBuffer.div(biasCorrection1);
-            // Compute bias-corrected second moment: v_hat_t = v_t / (1 - beta2^t)
-            const correctedVelocity = velocityBuffer.div(biasCorrection2);
-            // Update parameters: theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
-            const denom = correctedVelocity.sqrt().add(this.eps);
-            const stepSize = correctedMomentum.div(denom).mul(this.lr);
-            const newParam = detachedParam.sub(stepSize);
-            param.replace(newParam);
         }
     }
 }
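Two behavioral changes fall out of this rewrite: hyperparameters are resolved per group with `??` fallbacks (as in the AdamW sketch earlier), and the single shared `stepCount` has become a per-parameter `stepCounts` map, so Adam/AdamW bias correction advances independently for each tensor and a parameter skipped in some iterations (no grad) is not over-corrected. A minimal sketch of the latter, with string keys standing in for `Tensor` keys:

```ts
// Each parameter advances its own bias-correction clock.
const stepCounts = new Map<string, number>();

function biasCorrections(paramId: string, beta1: number, beta2: number): [number, number] {
    const t = (stepCounts.get(paramId) ?? 0) + 1;
    stepCounts.set(paramId, t);
    // 1 - beta^t grows toward 1 as this parameter accumulates steps.
    return [1 - Math.pow(beta1, t), 1 - Math.pow(beta2, t)];
}

biasCorrections("w1", 0.9, 0.999); // t = 1 -> [0.1, 0.001]
biasCorrections("w1", 0.9, 0.999); // t = 2 -> [0.19, ~0.002]
biasCorrections("w2", 0.9, 0.999); // t = 1 for w2, independent of w1
```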
package/package.json CHANGED
@@ -1,6 +1,6 @@
 {
   "name": "catniff",
-  "version": "0.8.14",
+  "version": "0.8.16",
   "description": "Torch-like deep learning framework for Javascript",
   "main": "./dist/index.js",
   "scripts": {