catniff 0.8.13 → 0.8.15

package/README.md CHANGED
@@ -143,9 +143,8 @@ All available APIs are in [`./src/`](./src/) if you want to dig deeper.
 
 * More general tensor ops.
 * More general neural net APIs.
-* GPU acceleration.
-* Comprehensive caching.
-* Bug fixes.
+* GPU acceleration, possibly through WebGPU, Libtorch bindings, or CUDA.
+* Proper optimization.
 * More detailed documentation.
 * Code refactoring.
 * Proper tests.
package/dist/core.js CHANGED
@@ -20,21 +20,21 @@ class Tensor {
     static createGraph = false;
     constructor(value, options = {}) {
         // Memory buffer
-        this.dtype = options.dtype || "float32";
+        this.dtype = options.dtype ?? "float32";
         const flatValue = Tensor.flattenValue(value);
         const TypedArrayConstructor = dtype_1.TypedArray[this.dtype];
         this.value = flatValue instanceof TypedArrayConstructor ? flatValue : TypedArrayConstructor.from(flatValue);
         // Tensor metadata
-        this.shape = options.shape || Tensor.getShape(value);
-        this.strides = options.strides || Tensor.getStrides(this.shape);
-        this.offset = options.offset || 0;
-        this.numel = options.numel || Tensor.shapeToSize(this.shape);
-        this.device = options.device || "cpu";
+        this.shape = options.shape ?? Tensor.getShape(value);
+        this.strides = options.strides ?? Tensor.getStrides(this.shape);
+        this.offset = options.offset ?? 0;
+        this.numel = options.numel ?? Tensor.shapeToSize(this.shape);
+        this.device = options.device ?? "cpu";
         // Autograd data
         this.grad = options.grad;
         this.requiresGrad = options.requiresGrad ?? false;
-        this.gradFn = options.gradFn || (() => { });
-        this.children = options.children || [];
+        this.gradFn = options.gradFn ?? (() => { });
+        this.children = options.children ?? [];
         // Move to device in-place
         this.to_(this.device);
     }
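The `||` → `??` switch in the constructor is a behavioral fix rather than a style change: `||` falls back on *any* falsy value, so an explicit `offset: 0` or an empty shape was silently replaced by the default, while `??` only falls back on `null` or `undefined`. A minimal standalone sketch of the difference (plain JavaScript, no catniff APIs involved):

```js
// `||` treats every falsy value (0, "", false, NaN) as "missing";
// `??` only treats null and undefined as "missing".
const options = { offset: 0 };

console.log(options.offset || 64); // 64 -- the explicit 0 is discarded
console.log(options.offset ?? 64); // 0  -- the explicit 0 is kept
```

The same reasoning explains the `let newOffset = this.offset || 0;` → `let newOffset = this.offset;` cleanup further down: once the constructor guarantees `offset` via `??`, the defensive `|| 0` is dead code.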
@@ -622,14 +622,14 @@ class Tensor {
             return this;
         const newShape = [];
         const newStrides = [];
-        let newOffset = this.offset || 0;
+        let newOffset = this.offset;
         // Pad ranges to match tensor dimensions
         const paddedRanges = [...ranges];
         while (paddedRanges.length < this.shape.length) {
            paddedRanges.push([]);
        }
        for (let i = 0; i < this.shape.length; i++) {
-            const range = paddedRanges[i] || [];
+            const range = paddedRanges[i] ?? [];
            const dimSize = this.shape[i];
            const stride = this.strides[i];
            // Default values
@@ -675,7 +675,7 @@ class Tensor {
        const originalCoords = new Array(slicedCoords.length);
        for (let dim = 0; dim < slicedCoords.length; dim++) {
            const coord = slicedCoords[dim];
-            const range = paddedRanges[dim] || [];
+            const range = paddedRanges[dim] ?? [];
            const start = range[0] ?? 0;
            const step = range[2] ?? 1;
            const normalizedStart = start < 0 ? start + this.shape[dim] : start;
@@ -2504,7 +2504,7 @@ class Tensor {
            visited.add(node);
            // Reset grad to zeros if specified
            if (zeroGrad) {
-                node.grad = Tensor.zerosLike(node);
+                delete node.grad;
            }
            for (let child of node.children)
                build(child);
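One subtlety: the unchanged context comment above the changed line ("Reset grad to zeros if specified") now understates the behavior. When `backward` is asked to zero gradients, it no longer allocates a zero tensor per node but deletes the `grad` property outright, so a cleared gradient reads as `undefined` and the `!param.grad` guards in the optimizers below skip the parameter until the next backward pass repopulates it. A rough sketch of the observable difference, assuming `t` is a tensor that has already accumulated a gradient:

```js
// 0.8.13 behavior: clearing allocated an all-zero tensor, so
// `if (!t.grad) continue;` never skipped a cleared parameter.
t.grad = Tensor.zerosLike(t);

// 0.8.15 behavior: clearing removes the property entirely.
delete t.grad; // t.grad === undefined, optimizers skip it
```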
@@ -1,11 +1,12 @@
-import { BaseOptimizer } from "./optim";
+import { OptimizerWithLR } from "./optim";
 export declare class StepLR {
-    optimizer: BaseOptimizer;
+    optimizer: OptimizerWithLR;
     stepSize: number;
     gamma: number;
     lastEpoch: number;
     baseLR: number;
-    constructor(optimizer: BaseOptimizer, stepSize: number, gamma?: number, lastEpoch?: number);
+    baseGroupLRs: number[];
+    constructor(optimizer: OptimizerWithLR, stepSize: number, gamma?: number, lastEpoch?: number);
     step(epoch?: number): void;
 }
 export declare const LRScheduler: {
@@ -7,12 +7,14 @@ class StepLR {
     gamma;
     lastEpoch;
     baseLR;
+    baseGroupLRs;
     constructor(optimizer, stepSize, gamma = 0.1, lastEpoch = -1) {
         this.optimizer = optimizer;
         this.stepSize = stepSize;
         this.gamma = gamma;
         this.lastEpoch = lastEpoch;
-        this.baseLR = this.optimizer.lr;
+        this.baseLR = optimizer.lr;
+        this.baseGroupLRs = this.optimizer.paramGroups.map(paramGroup => paramGroup.lr ?? this.optimizer.lr);
     }
     step(epoch) {
         if (typeof epoch === "undefined") {
@@ -22,6 +24,11 @@ class StepLR {
         else {
             this.lastEpoch = epoch;
         }
+        // Update LR of each group
+        for (let index = 0; index < this.baseGroupLRs.length; index++) {
+            this.optimizer.paramGroups[index].lr = this.baseGroupLRs[index] * this.gamma ** Math.floor(epoch / this.stepSize);
+        }
+        // Update default LR
         this.optimizer.lr = this.baseLR * this.gamma ** Math.floor(epoch / this.stepSize);
     }
 }
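Together with the new `baseGroupLRs` snapshot taken at construction, `step()` now applies the same decay rule to every param group that it previously applied only to the optimizer-wide `lr`: `lr = baseLR * gamma ** Math.floor(epoch / stepSize)`. A quick worked example of that schedule (pure arithmetic, no catniff imports):

```js
// StepLR decay rule: multiply the base LR by gamma once per stepSize epochs.
const baseLR = 0.1, gamma = 0.1, stepSize = 30;
const lrAt = (epoch) => baseLR * gamma ** Math.floor(epoch / stepSize);

console.log(lrAt(0));  // 0.1    (epochs 0-29)
console.log(lrAt(30)); // ~0.01  (epochs 30-59)
console.log(lrAt(60)); // ~0.001 (epochs 60-89)
```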
package/dist/optim.d.ts CHANGED
@@ -1,12 +1,15 @@
 import { Tensor } from "./core";
-export interface BaseOptimizerOptions {
-    lr?: number;
+export interface BaseParamGroup {
+    params: Tensor[];
+    [key: string]: any;
 }
 export declare abstract class BaseOptimizer {
-    params: Tensor[];
+    paramGroups: BaseParamGroup[];
+    constructor(params: Tensor[] | BaseParamGroup[]);
+    zeroGrad(del?: boolean): void;
+}
+export interface OptimizerWithLR extends BaseOptimizer {
     lr: number;
-    constructor(params: Tensor[], options?: BaseOptimizerOptions);
-    zeroGrad(): void;
 }
 export interface SGDOptions {
     lr?: number;
@@ -15,13 +18,18 @@ export interface SGDOptions {
     weightDecay?: number;
     nesterov?: boolean;
 }
+export interface SGDParamGroup extends SGDOptions {
+    params: Tensor[];
+}
 export declare class SGD extends BaseOptimizer {
-    momentumBuffers: Map<Tensor, Tensor>;
+    paramGroups: SGDParamGroup[];
+    lr: number;
     momentum: number;
     dampening: number;
     weightDecay: number;
     nesterov: boolean;
-    constructor(params: Tensor[], options?: SGDOptions);
+    momentumBuffers: Map<Tensor, Tensor>;
+    constructor(params: Tensor[] | SGDParamGroup[], options?: SGDOptions);
     step(): void;
 }
 export interface AdamOptions {
@@ -30,14 +38,19 @@ export interface AdamOptions {
     eps?: number;
     weightDecay?: number;
 }
+export interface AdamParamGroup extends AdamOptions {
+    params: Tensor[];
+}
 export declare class Adam extends BaseOptimizer {
-    momentumBuffers: Map<Tensor, Tensor>;
-    velocityBuffers: Map<Tensor, Tensor>;
-    stepCount: number;
+    paramGroups: AdamParamGroup[];
+    lr: number;
     betas: [number, number];
     eps: number;
     weightDecay: number;
-    constructor(params: Tensor[], options?: AdamOptions);
+    momentumBuffers: Map<Tensor, Tensor>;
+    velocityBuffers: Map<Tensor, Tensor>;
+    stepCounts: Map<Tensor, number>;
+    constructor(params: Tensor[] | AdamParamGroup[], options?: AdamOptions);
     step(): void;
 }
 export interface AdamWOptions {
@@ -46,14 +59,19 @@ export interface AdamWOptions {
     eps?: number;
     weightDecay?: number;
 }
+export interface AdamWParamGroup extends AdamWOptions {
+    params: Tensor[];
+}
 export declare class AdamW extends BaseOptimizer {
-    momentumBuffers: Map<Tensor, Tensor>;
-    velocityBuffers: Map<Tensor, Tensor>;
-    stepCount: number;
+    paramGroups: AdamWParamGroup[];
+    lr: number;
     betas: [number, number];
     eps: number;
     weightDecay: number;
-    constructor(params: Tensor[], options?: AdamWOptions);
+    momentumBuffers: Map<Tensor, Tensor>;
+    velocityBuffers: Map<Tensor, Tensor>;
+    stepCounts: Map<Tensor, number>;
+    constructor(params: Tensor[] | AdamWParamGroup[], options?: AdamWOptions);
     step(): void;
 }
 export declare const Optim: {
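These declarations carry the headline API change of the release: every optimizer now accepts either a flat `Tensor[]` or PyTorch-style param groups, where each group bundles its own `params` with optional per-group overrides of that optimizer's options. A hedged usage sketch (the import path assumes the package root re-exports `Tensor` and `SGD`, as `"main": "./dist/index.js"` suggests; the tensors are placeholders):

```js
const { Tensor, SGD } = require("catniff");

const backbone = new Tensor([[0.1, 0.2], [0.3, 0.4]], { requiresGrad: true });
const head = new Tensor([0.5, 0.6], { requiresGrad: true });

// Flat form, unchanged from 0.8.13: one set of options for everything.
const flat = new SGD([backbone, head], { lr: 0.01 });

// Param-group form, new in this release: the head overrides lr and
// momentum, while the backbone falls back to the constructor defaults.
const grouped = new SGD([
    { params: [backbone] },
    { params: [head], lr: 0.1, momentum: 0.9 },
], { lr: 0.01 });
```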
package/dist/optim.js CHANGED
@@ -3,183 +3,220 @@ Object.defineProperty(exports, "__esModule", { value: true });
 exports.Optim = exports.AdamW = exports.Adam = exports.SGD = exports.BaseOptimizer = void 0;
 const core_1 = require("./core");
 class BaseOptimizer {
-    params;
-    lr;
-    constructor(params, options) {
-        this.params = params;
-        this.lr = options?.lr || 0.001;
+    paramGroups;
+    constructor(params) {
+        if (params[0] instanceof core_1.Tensor) {
+            this.paramGroups = [{ params: params }];
+        }
+        else {
+            this.paramGroups = params;
+        }
     }
-    zeroGrad() {
-        for (let index = 0; index < this.params.length; index++) {
-            const param = this.params[index];
-            param.grad = core_1.Tensor.zerosLike(param);
+    zeroGrad(del = true) {
+        for (let index = 0; index < this.paramGroups.length; index++) {
+            const paramGroup = this.paramGroups[index];
+            for (const param of paramGroup.params) {
+                if (del) {
+                    delete param.grad;
+                }
+                else {
+                    param.grad = core_1.Tensor.zerosLike(param);
+                }
+            }
         }
     }
 }
 exports.BaseOptimizer = BaseOptimizer;
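`zeroGrad` now walks every group and, by default (`del = true`), clears gradients by deleting them, mirroring the `delete node.grad` change in core.js above; passing `false` opts back into the old zero-fill behavior. A brief sketch, assuming `opt` is any optimizer built as in the grouped SGD example above:

```js
opt.zeroGrad();      // default: delete .grad, no allocation
opt.zeroGrad(false); // pre-0.8.15 behavior: .grad = zeros tensor
```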
 class SGD extends BaseOptimizer {
-    momentumBuffers = new Map();
+    lr;
     momentum;
     dampening;
     weightDecay;
     nesterov;
+    momentumBuffers = new Map();
     constructor(params, options) {
-        super(params, options);
-        this.momentum = options?.momentum || 0;
-        this.dampening = options?.dampening || 0;
-        this.weightDecay = options?.weightDecay || 0;
-        this.nesterov = options?.nesterov || false;
+        super(params);
+        this.lr = options?.lr ?? 0.001;
+        this.momentum = options?.momentum ?? 0;
+        this.dampening = options?.dampening ?? 0;
+        this.weightDecay = options?.weightDecay ?? 0;
+        this.nesterov = options?.nesterov ?? false;
     }
     step() {
-        for (const param of this.params) {
-            if (!param.grad || !param.requiresGrad)
-                continue;
-            let grad = param.grad.detach(), detachedParam = param.detach();
-            // Apply weight decay (L2 regularization)
-            if (this.weightDecay !== 0) {
-                grad = grad.add(detachedParam.mul(this.weightDecay));
-            }
-            // Apply momentum
-            if (this.momentum !== 0) {
-                let buf = this.momentumBuffers.get(param);
-                if (!buf) {
-                    // First time: initialize momentum buffer with current gradient
-                    buf = grad.clone();
-                    this.momentumBuffers.set(param, buf);
-                }
-                else {
-                    // Update momentum buffer: buf = momentum * buf + (1 - dampening) * grad
-                    buf = buf.mul(this.momentum).add(grad.mul(1 - this.dampening));
-                    this.momentumBuffers.set(param, buf);
-                }
-                if (this.nesterov) {
-                    // Nesterov momentum: grad = grad + momentum * buf
-                    grad = grad.add(buf.mul(this.momentum));
+        for (const paramGroup of this.paramGroups) {
+            const lr = paramGroup.lr ?? this.lr;
+            const momentum = paramGroup.momentum ?? this.momentum;
+            const dampening = paramGroup.dampening ?? this.dampening;
+            const weightDecay = paramGroup.weightDecay ?? this.weightDecay;
+            const nesterov = paramGroup.nesterov ?? this.nesterov;
+            for (const param of paramGroup.params) {
+                if (!param.grad || !param.requiresGrad)
+                    continue;
+                let grad = param.grad.detach(), detachedParam = param.detach();
+                // Apply weight decay (L2 regularization)
+                if (weightDecay !== 0) {
+                    grad = grad.add(detachedParam.mul(weightDecay));
                 }
-                else {
-                    // Standard momentum: use momentum buffer as gradient
-                    grad = buf;
+                // Apply momentum
+                if (momentum !== 0) {
+                    let buf = this.momentumBuffers.get(param);
+                    if (!buf) {
+                        // First time: initialize momentum buffer with current gradient
+                        buf = grad.clone();
+                        this.momentumBuffers.set(param, buf);
+                    }
+                    else {
+                        // Update momentum buffer: buf = momentum * buf + (1 - dampening) * grad
+                        buf = buf.mul(momentum).add(grad.mul(1 - dampening));
+                        this.momentumBuffers.set(param, buf);
+                    }
+                    if (nesterov) {
+                        // Nesterov momentum: grad = grad + momentum * buf
+                        grad = grad.add(buf.mul(momentum));
+                    }
+                    else {
+                        // Standard momentum: use momentum buffer as gradient
+                        grad = buf;
+                    }
                 }
+                // Update parameter: param = param - lr * grad
+                const newParam = detachedParam.sub(grad.mul(lr));
+                param.replace(newParam);
             }
-            // Update parameter: param = param - lr * grad
-            const newParam = detachedParam.sub(grad.mul(this.lr));
-            param.replace(newParam);
         }
     }
 }
 exports.SGD = SGD;
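Note the resolution pattern inside `step()`: every hyperparameter is computed per group as `paramGroup.x ?? this.x`, so a group overrides only what it explicitly sets, and because the fallback uses `??`, an intentional override of `0` (e.g. disabling momentum for one group) is honored rather than discarded. A tiny standalone sketch of that lookup:

```js
// Group value wins whenever it is present, even if it is 0 or false.
const resolve = (group, defaults, key) => group[key] ?? defaults[key];

const defaults = { lr: 0.01, momentum: 0.9 };
console.log(resolve({ momentum: 0 }, defaults, "momentum")); // 0    (override kept)
console.log(resolve({}, defaults, "lr"));                    // 0.01 (fallback)
```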
 class Adam extends BaseOptimizer {
-    momentumBuffers = new Map(); // First moment (m_t)
-    velocityBuffers = new Map(); // Second moment (v_t)
-    stepCount = 0;
+    lr;
     betas;
     eps;
     weightDecay;
+    momentumBuffers = new Map(); // First moment (m_t)
+    velocityBuffers = new Map(); // Second moment (v_t)
+    stepCounts = new Map();
     constructor(params, options) {
-        super(params, options);
-        this.betas = options?.betas || [0.9, 0.999];
-        this.eps = options?.eps || 1e-8;
-        this.weightDecay = options?.weightDecay || 0;
+        super(params);
+        this.lr = options?.lr ?? 0.001;
+        this.betas = options?.betas ?? [0.9, 0.999];
+        this.eps = options?.eps ?? 1e-8;
+        this.weightDecay = options?.weightDecay ?? 0;
     }
     step() {
-        this.stepCount++;
-        const beta1 = this.betas[0];
-        const beta2 = this.betas[1];
-        // Bias correction factors
-        const biasCorrection1 = 1 - Math.pow(beta1, this.stepCount);
-        const biasCorrection2 = 1 - Math.pow(beta2, this.stepCount);
-        for (const param of this.params) {
-            if (!param.grad || !param.requiresGrad)
-                continue;
-            let grad = param.grad.detach(), detachedParam = param.detach();
-            // Apply weight decay (L2 regularization)
-            if (this.weightDecay !== 0) {
-                grad = grad.add(detachedParam.mul(this.weightDecay));
-            }
-            // Get or initialize first moment buffer (momentum)
-            let momentumBuffer = this.momentumBuffers.get(param);
-            if (!momentumBuffer) {
-                momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+        for (const paramGroup of this.paramGroups) {
+            const lr = paramGroup.lr ?? this.lr;
+            const betas = paramGroup.betas ?? this.betas;
+            const eps = paramGroup.eps ?? this.eps;
+            const weightDecay = paramGroup.weightDecay ?? this.weightDecay;
+            for (const param of paramGroup.params) {
+                if (!param.grad || !param.requiresGrad)
+                    continue;
+                // Get current step for param, initialize if has not step before
+                const stepCount = (this.stepCounts.get(param) ?? 0) + 1;
+                this.stepCounts.set(param, stepCount);
+                // Bias correction factors
+                const [beta1, beta2] = betas;
+                const biasCorrection1 = 1 - Math.pow(beta1, stepCount);
+                const biasCorrection2 = 1 - Math.pow(beta2, stepCount);
+                let grad = param.grad.detach(), detachedParam = param.detach();
+                // Apply weight decay (L2 regularization)
+                if (weightDecay !== 0) {
+                    grad = grad.add(detachedParam.mul(weightDecay));
+                }
+                // Get or initialize first moment buffer (momentum)
+                let momentumBuffer = this.momentumBuffers.get(param);
+                if (!momentumBuffer) {
+                    momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                    this.momentumBuffers.set(param, momentumBuffer);
+                }
+                // Get or initialize second moment buffer (velocity)
+                let velocityBuffer = this.velocityBuffers.get(param);
+                if (!velocityBuffer) {
+                    velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                    this.velocityBuffers.set(param, velocityBuffer);
+                }
+                // Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
+                momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
                 this.momentumBuffers.set(param, momentumBuffer);
-            }
-            // Get or initialize second moment buffer (velocity)
-            let velocityBuffer = this.velocityBuffers.get(param);
-            if (!velocityBuffer) {
-                velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                // Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
+                velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
                 this.velocityBuffers.set(param, velocityBuffer);
+                // Compute bias-corrected first moment: m_hat_t = m_t / (1 - beta1^t)
+                const correctedMomentum = momentumBuffer.div(biasCorrection1);
+                // Compute bias-corrected second moment: v_hat_t = v_t / (1 - beta2^t)
+                const correctedVelocity = velocityBuffer.div(biasCorrection2);
+                // Update parameters: theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
+                const denom = correctedVelocity.sqrt().add(eps);
+                const stepSize = correctedMomentum.div(denom).mul(lr);
+                const newParam = detachedParam.sub(stepSize);
+                param.replace(newParam);
             }
-            // Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
-            momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
-            this.momentumBuffers.set(param, momentumBuffer);
-            // Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
-            velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
-            this.velocityBuffers.set(param, velocityBuffer);
-            // Compute bias-corrected first moment: m_hat_t = m_t / (1 - beta1^t)
-            const correctedMomentum = momentumBuffer.div(biasCorrection1);
-            // Compute bias-corrected second moment: v_hat_t = v_t / (1 - beta2^t)
-            const correctedVelocity = velocityBuffer.div(biasCorrection2);
-            // Update parameters: theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
-            const denom = correctedVelocity.sqrt().add(this.eps);
-            const stepSize = correctedMomentum.div(denom).mul(this.lr);
-            const newParam = detachedParam.sub(stepSize);
-            param.replace(newParam);
         }
     }
 }
 exports.Adam = Adam;
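The other substantive change in Adam (and AdamW below) is replacing the shared `stepCount` with a per-parameter `stepCounts` map. That matters for the bias corrections `1 - beta ** t`: under a shared counter, a parameter whose gradient was absent for the first N steps got its first update with the global `t`, which weakens the warm-up boost that bias correction is supposed to provide. A worked comparison (pure arithmetic):

```js
const beta1 = 0.9;

// Per-param counter: a parameter's first update always uses t = 1,
// so m_hat = m / (1 - 0.9) = 10 * m -- the intended warm-up boost.
console.log(1 - beta1 ** 1);   // ~0.1

// Shared counter after 100 optimizer steps: a late-starting parameter's
// first update used t = 100, so m_hat ~= m and its early steps shrink.
console.log(1 - beta1 ** 100); // ~0.99997
```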
 class AdamW extends BaseOptimizer {
-    momentumBuffers = new Map(); // First moment (m_t)
-    velocityBuffers = new Map(); // Second moment (v_t)
-    stepCount = 0;
+    lr;
     betas;
     eps;
     weightDecay;
+    momentumBuffers = new Map(); // First moment (m_t)
+    velocityBuffers = new Map(); // Second moment (v_t)
+    stepCounts = new Map();
     constructor(params, options) {
-        super(params, options);
-        this.betas = options?.betas || [0.9, 0.999];
-        this.eps = options?.eps || 1e-8;
-        this.weightDecay = options?.weightDecay || 0.01;
+        super(params);
+        this.lr = options?.lr ?? 0.001;
+        this.betas = options?.betas ?? [0.9, 0.999];
+        this.eps = options?.eps ?? 1e-8;
+        this.weightDecay = options?.weightDecay ?? 0.01;
     }
     step() {
-        this.stepCount++;
-        const beta1 = this.betas[0];
-        const beta2 = this.betas[1];
-        // Bias correction factors
-        const biasCorrection1 = 1 - Math.pow(beta1, this.stepCount);
-        const biasCorrection2 = 1 - Math.pow(beta2, this.stepCount);
-        for (const param of this.params) {
-            if (!param.grad || !param.requiresGrad)
-                continue;
-            let grad = param.grad.detach(), detachedParam = param.detach();
-            // Apply weight decay (L2 regularization)
-            detachedParam = detachedParam.sub(detachedParam.mul(this.weightDecay).mul(this.lr));
-            // Get or initialize first moment buffer (momentum)
-            let momentumBuffer = this.momentumBuffers.get(param);
-            if (!momentumBuffer) {
-                momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+        for (const paramGroup of this.paramGroups) {
+            const lr = paramGroup.lr ?? this.lr;
+            const betas = paramGroup.betas ?? this.betas;
+            const eps = paramGroup.eps ?? this.eps;
+            const weightDecay = paramGroup.weightDecay ?? this.weightDecay;
+            for (const param of paramGroup.params) {
+                if (!param.grad || !param.requiresGrad)
+                    continue;
+                // Get current step for param, initialize if has not step before
+                const stepCount = (this.stepCounts.get(param) ?? 0) + 1;
+                this.stepCounts.set(param, stepCount);
+                // Bias correction factors
+                const [beta1, beta2] = betas;
+                const biasCorrection1 = 1 - Math.pow(beta1, stepCount);
+                const biasCorrection2 = 1 - Math.pow(beta2, stepCount);
+                let grad = param.grad.detach(), detachedParam = param.detach();
+                // Apply weight decay (L2 regularization)
+                detachedParam = detachedParam.sub(detachedParam.mul(weightDecay).mul(lr));
+                // Get or initialize first moment buffer (momentum)
+                let momentumBuffer = this.momentumBuffers.get(param);
+                if (!momentumBuffer) {
+                    momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                    this.momentumBuffers.set(param, momentumBuffer);
+                }
+                // Get or initialize second moment buffer (velocity)
+                let velocityBuffer = this.velocityBuffers.get(param);
+                if (!velocityBuffer) {
+                    velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                    this.velocityBuffers.set(param, velocityBuffer);
+                }
+                // Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
+                momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
                 this.momentumBuffers.set(param, momentumBuffer);
-            }
-            // Get or initialize second moment buffer (velocity)
-            let velocityBuffer = this.velocityBuffers.get(param);
-            if (!velocityBuffer) {
-                velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                // Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
+                velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
                 this.velocityBuffers.set(param, velocityBuffer);
+                // Compute bias-corrected first moment: m_hat_t = m_t / (1 - beta1^t)
+                const correctedMomentum = momentumBuffer.div(biasCorrection1);
+                // Compute bias-corrected second moment: v_hat_t = v_t / (1 - beta2^t)
+                const correctedVelocity = velocityBuffer.div(biasCorrection2);
+                // Update parameters: theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
+                const denom = correctedVelocity.sqrt().add(eps);
+                const stepSize = correctedMomentum.div(denom).mul(lr);
+                const newParam = detachedParam.sub(stepSize);
+                param.replace(newParam);
             }
-            // Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
-            momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
-            this.momentumBuffers.set(param, momentumBuffer);
-            // Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
-            velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
-            this.velocityBuffers.set(param, velocityBuffer);
-            // Compute bias-corrected first moment: m_hat_t = m_t / (1 - beta1^t)
-            const correctedMomentum = momentumBuffer.div(biasCorrection1);
-            // Compute bias-corrected second moment: v_hat_t = v_t / (1 - beta2^t)
-            const correctedVelocity = velocityBuffer.div(biasCorrection2);
-            // Update parameters: theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
-            const denom = correctedVelocity.sqrt().add(this.eps);
-            const stepSize = correctedMomentum.div(denom).mul(this.lr);
-            const newParam = detachedParam.sub(stepSize);
-            param.replace(newParam);
         }
     }
 }
package/package.json CHANGED
@@ -1,6 +1,6 @@
 {
   "name": "catniff",
-  "version": "0.8.13",
+  "version": "0.8.15",
  "description": "Torch-like deep learning framework for Javascript",
  "main": "./dist/index.js",
  "scripts": {