catniff 0.6.15 → 0.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +1 -0
- package/dist/core.js +84 -4
- package/dist/lrscheduler.d.ts +13 -0
- package/dist/lrscheduler.js +31 -0
- package/dist/nn.d.ts +8 -9
- package/dist/nn.js +9 -1
- package/dist/optim.d.ts +9 -9
- package/dist/optim.js +11 -11
- package/index.d.ts +1 -0
- package/index.js +2 -1
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
@@ -61,6 +61,7 @@ export declare class Tensor {
     index(indices: Tensor | TensorValue): Tensor;
     slice(ranges: number[][]): Tensor;
     chunk(chunks: number, dim?: number): Tensor[];
+    cat(other: Tensor | TensorValue, dim?: number): Tensor;
     squeeze(dims?: number[] | number): Tensor;
     unsqueeze(dim: number): Tensor;
     static reduce(tensor: Tensor, dims: number[] | number | undefined, keepDims: boolean, config: {

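The new cat declaration concatenates two tensors along an existing dimension. A minimal usage sketch under stated assumptions (the "catniff" import path and direct construction from nested arrays are not shown in this diff):

    import { Tensor } from "catniff";

    const a = new Tensor([[1, 2], [3, 4]]);   // shape [2, 2]
    const b = new Tensor([[5, 6]]);           // shape [1, 2]

    const rows = a.cat(b);                    // dim defaults to 0 -> shape [3, 2]
    const cols = a.cat(a, 1);                 // concatenate along dim 1 -> shape [2, 4]
    const tail = a.cat(a, -1);                // negative dims count from the end, same as dim 1 here
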
package/dist/core.js
CHANGED
@@ -218,7 +218,6 @@ class Tensor {
         }
         if (out.requiresGrad) {
             out.gradFn = () => {
-                // Disable gradient collecting of gradients themselves
                 const outGrad = out.grad;
                 const selfWithGrad = Tensor.createGraph ? this : this.detach();
                 const otherWithGrad = Tensor.createGraph ? other : other.detach();
@@ -239,7 +238,6 @@ class Tensor {
         }
         if (out.requiresGrad) {
             out.gradFn = () => {
-                // Disable gradient collecting of gradients themselves
                 const outGrad = out.grad;
                 const selfWithGrad = Tensor.createGraph ? this : this.detach();
                 if (this.requiresGrad)
@@ -649,6 +647,90 @@ class Tensor {
         }
         return results;
     }
+    // Tensor concatentation
+    cat(other, dim = 0) {
+        other = this.handleOther(other);
+        // Handle scalars
+        if (typeof this.value === "number" || typeof other.value === "number") {
+            throw new Error("Can not concatenate scalars");
+        }
+        // Handle negative indices
+        if (dim < 0) {
+            dim += this.shape.length;
+        }
+        // If dimension out of bound, throw error
+        if (dim >= this.shape.length || dim < 0) {
+            throw new Error("Dimension does not exist to concatenate");
+        }
+        // If shape does not match, throw error
+        if (this.shape.length !== other.shape.length) {
+            throw new Error("Shape does not match to concatenate");
+        }
+        const outputShape = new Array(this.shape.length);
+        for (let currentDim = 0; currentDim < this.shape.length; currentDim++) {
+            if (currentDim === dim) {
+                outputShape[currentDim] = this.shape[currentDim] + other.shape[currentDim];
+            }
+            else if (this.shape[currentDim] !== other.shape[currentDim]) {
+                throw new Error("Shape does not match to concatenate");
+            }
+            else {
+                outputShape[currentDim] = this.shape[currentDim];
+            }
+        }
+        const outputSize = Tensor.shapeToSize(outputShape);
+        const outputStrides = Tensor.getStrides(outputShape);
+        const outputValue = new Array(outputSize);
+        for (let outIndex = 0; outIndex < outputSize; outIndex++) {
+            const coords = Tensor.indexToCoords(outIndex, outputStrides);
+            // Check which tensor this output position comes from
+            if (coords[dim] < this.shape[dim]) {
+                // Comes from this tensor
+                const srcIndex = Tensor.coordsToIndex(coords, this.strides);
+                outputValue[outIndex] = this.value[srcIndex + this.offset];
+            }
+            else {
+                // Comes from other tensor - adjust coordinate in concat dimension
+                const otherCoords = [...coords];
+                otherCoords[dim] -= this.shape[dim];
+                const srcIndex = Tensor.coordsToIndex(otherCoords, other.strides);
+                outputValue[outIndex] = other.value[srcIndex + other.offset];
+            }
+        }
+        const out = new Tensor(outputValue, {
+            shape: outputShape,
+            strides: outputStrides,
+            numel: outputSize
+        });
+        if (this.requiresGrad) {
+            out.requiresGrad = true;
+            out.children.push(this);
+        }
+        if (other.requiresGrad) {
+            out.requiresGrad = true;
+            out.children.push(other);
+        }
+        if (out.requiresGrad) {
+            out.gradFn = () => {
+                const outGrad = out.grad;
+                const thisRanges = new Array(this.shape.length);
+                const otherRanges = new Array(other.shape.length);
+                for (let currentDim = 0; currentDim < this.shape.length; currentDim++) {
+                    if (currentDim === dim) {
+                        thisRanges[currentDim] = [0, this.shape[currentDim], 1];
+                        otherRanges[currentDim] = [this.shape[currentDim], outputShape[currentDim], 1];
+                    }
+                    else {
+                        thisRanges[currentDim] = [];
+                        otherRanges[currentDim] = [];
+                    }
+                }
+                Tensor.addGrad(this, outGrad.slice(thisRanges));
+                Tensor.addGrad(other, outGrad.slice(otherRanges));
+            };
+        }
+        return out;
+    }
     // Tensor squeeze
     squeeze(dims) {
         if (typeof this.value === "number")
@@ -1338,7 +1420,6 @@ class Tensor {
         }
        if (out.requiresGrad) {
             out.gradFn = () => {
-                // Disable gradient collecting of gradients themselves
                 const outGrad = out.grad;
                 const selfWithGrad = Tensor.createGraph ? this : this.detach();
                 const otherWithGrad = Tensor.createGraph ? other : other.detach();
@@ -1396,7 +1477,6 @@ class Tensor {
         }
         if (out.requiresGrad) {
             out.gradFn = () => {
-                // Disable gradient collecting of gradients themselves
                 const outGrad = out.grad;
                 const selfWithGrad = Tensor.createGraph ? this : this.detach();
                 const otherWithGrad = Tensor.createGraph ? other : other.detach();

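The backward pass added above splits the upstream gradient back onto the two inputs by slicing along the concatenation dimension. An illustrative restatement of that gradFn logic as a standalone helper (the helper name is hypothetical and not part of the package; slice(ranges) is the existing Tensor method, where an empty range appears to mean "keep the whole dimension"):

    import { Tensor } from "catniff";

    // Hypothetical helper mirroring how the new gradFn routes out.grad back
    // to the two concatenated inputs.
    function splitCatGrad(outGrad: Tensor, aShape: number[], outShape: number[], dim: number): [Tensor, Tensor] {
        const aRanges: number[][] = [];
        const bRanges: number[][] = [];
        for (let d = 0; d < aShape.length; d++) {
            if (d === dim) {
                aRanges.push([0, aShape[d], 1]);           // first aShape[dim] slots came from the first input
                bRanges.push([aShape[d], outShape[d], 1]); // the remainder came from the second input
            }
            else {
                aRanges.push([]);                          // empty range: keep the whole dimension
                bRanges.push([]);
            }
        }
        return [outGrad.slice(aRanges), outGrad.slice(bRanges)];
    }
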
package/dist/lrscheduler.d.ts
ADDED
@@ -0,0 +1,13 @@
+import { BaseOptimizer } from "./optim";
+export declare class StepLR {
+    optimizer: BaseOptimizer;
+    stepSize: number;
+    gamma: number;
+    lastEpoch: number;
+    baseLR: number;
+    constructor(optimizer: BaseOptimizer, stepSize: number, gamma?: number, lastEpoch?: number);
+    step(epoch?: number): void;
+}
+export declare const LRScheduler: {
+    StepLR: typeof StepLR;
+};

package/dist/lrscheduler.js
ADDED
@@ -0,0 +1,31 @@
+"use strict";
+Object.defineProperty(exports, "__esModule", { value: true });
+exports.LRScheduler = exports.StepLR = void 0;
+class StepLR {
+    optimizer;
+    stepSize;
+    gamma;
+    lastEpoch;
+    baseLR;
+    constructor(optimizer, stepSize, gamma = 0.1, lastEpoch = -1) {
+        this.optimizer = optimizer;
+        this.stepSize = stepSize;
+        this.gamma = gamma;
+        this.lastEpoch = lastEpoch;
+        this.baseLR = this.optimizer.lr;
+    }
+    step(epoch) {
+        if (typeof epoch === "undefined") {
+            this.lastEpoch++;
+            epoch = this.lastEpoch;
+        }
+        else {
+            this.lastEpoch = epoch;
+        }
+        this.optimizer.lr = this.baseLR * this.gamma ** Math.floor(epoch / this.stepSize);
+    }
+}
+exports.StepLR = StepLR;
+exports.LRScheduler = {
+    StepLR
+};

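A minimal sketch of how the new scheduler drives an optimizer's learning rate. The "catniff" import path and the presence of LRScheduler on the package index are assumptions (index.js also changed in this release, but that hunk is not shown here):

    import { Tensor, Optim, LRScheduler } from "catniff";

    const params: Tensor[] = [];                                  // empty list, just to show the API
    const optimizer = new Optim.Adam(params, { lr: 0.01 });
    const scheduler = new LRScheduler.StepLR(optimizer, 10, 0.5); // decay lr by 0.5 every 10 epochs

    for (let epoch = 0; epoch < 30; epoch++) {
        // ... forward, backward, and optimizer update would go here ...
        scheduler.step();   // lr becomes 0.01 * 0.5 ** Math.floor(epoch / 10)
    }
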
package/dist/nn.d.ts
CHANGED
@@ -1,11 +1,11 @@
 import { Tensor, TensorValue } from "./core";
-declare class Linear {
+export declare class Linear {
     weight: Tensor;
     bias?: Tensor;
     constructor(inFeatures: number, outFeatures: number, bias?: boolean, device?: string);
     forward(input: Tensor | TensorValue): Tensor;
 }
-declare class RNNCell {
+export declare class RNNCell {
     weightIH: Tensor;
     weightHH: Tensor;
     biasIH?: Tensor;
@@ -13,7 +13,7 @@ declare class RNNCell {
     constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
     forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
 }
-declare class GRUCell {
+export declare class GRUCell {
     weightIR: Tensor;
     weightIZ: Tensor;
     weightIN: Tensor;
@@ -29,7 +29,7 @@ declare class GRUCell {
     constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
     forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
 }
-declare class LSTMCell {
+export declare class LSTMCell {
     weightII: Tensor;
     weightIF: Tensor;
     weightIG: Tensor;
@@ -49,7 +49,7 @@ declare class LSTMCell {
     constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
     forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue, cell: Tensor | TensorValue): [Tensor, Tensor];
 }
-declare class LayerNorm {
+export declare class LayerNorm {
     weight?: Tensor;
     bias?: Tensor;
     eps: number;
@@ -57,19 +57,19 @@ declare class LayerNorm {
     constructor(normalizedShape: number | number[], eps?: number, elementwiseAffine?: boolean, bias?: boolean, device?: string);
     forward(input: Tensor): Tensor;
 }
-declare class RMSNorm {
+export declare class RMSNorm {
     weight?: Tensor;
     eps: number;
     normalizedShape: number[];
     constructor(normalizedShape: number | number[], eps?: number, elementwiseAffine?: boolean, device?: string);
     forward(input: Tensor): Tensor;
 }
-declare class Embedding {
+export declare class Embedding {
     weight: Tensor;
     constructor(numEmbeddings: number, embeddingDim: number, device: string);
     forward(input: Tensor | TensorValue): Tensor;
 }
-declare class MultiheadAttention {
+export declare class MultiheadAttention {
     qProjection: Linear;
     kProjection: Linear;
     vProjection: Linear;
@@ -100,4 +100,3 @@ export declare const nn: {
         loadStateDict(model: any, stateDict: StateDict, prefix?: string, visited?: WeakSet<object>): void;
     };
 };
-export {};

package/dist/nn.js
CHANGED
@@ -1,6 +1,6 @@
 "use strict";
 Object.defineProperty(exports, "__esModule", { value: true });
-exports.nn = void 0;
+exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Linear = void 0;
 const core_1 = require("./core");
 function linearTransform(input, weight, bias) {
     let output = input.matmul(weight.t());
@@ -24,6 +24,7 @@ class Linear {
         return linearTransform(input, this.weight, this.bias);
     }
 }
+exports.Linear = Linear;
 function rnnTransform(input, hidden, inputWeight, hiddenWeight, inputBias, hiddenBias) {
     let output = input.matmul(inputWeight.t()).add(hidden.matmul(hiddenWeight.t()));
     if (inputBias) {
@@ -54,6 +55,7 @@ class RNNCell {
         return rnnTransform(input, hidden, this.weightIH, this.weightHH, this.biasIH, this.biasHH).tanh();
     }
 }
+exports.RNNCell = RNNCell;
 class GRUCell {
     weightIR;
     weightIZ;
@@ -93,6 +95,7 @@ class GRUCell {
         return (z.neg().add(1).mul(n).add(z.mul(hidden)));
     }
 }
+exports.GRUCell = GRUCell;
 class LSTMCell {
     weightII;
     weightIF;
@@ -144,6 +147,7 @@ class LSTMCell {
         return [h, c];
     }
 }
+exports.LSTMCell = LSTMCell;
 class LayerNorm {
     weight;
     bias;
@@ -188,6 +192,7 @@ class LayerNorm {
         return normalized;
     }
 }
+exports.LayerNorm = LayerNorm;
 class RMSNorm {
     weight;
     eps;
@@ -224,6 +229,7 @@ class RMSNorm {
         return normalized;
     }
 }
+exports.RMSNorm = RMSNorm;
 class Embedding {
     weight;
     constructor(numEmbeddings, embeddingDim, device) {
@@ -233,6 +239,7 @@ class Embedding {
         return this.weight.index(input);
     }
 }
+exports.Embedding = Embedding;
 class MultiheadAttention {
     qProjection;
     kProjection;
@@ -284,6 +291,7 @@ class MultiheadAttention {
         return [output, needWeights ? attnWeights : undefined];
     }
 }
+exports.MultiheadAttention = MultiheadAttention;
 const state = {
     getParameters(model, visited = new WeakSet()) {
         if (visited.has(model))

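With these classes now re-exported individually, they can be imported directly rather than only through the nn namespace object. A minimal sketch under stated assumptions (the "catniff" and "catniff/dist/nn" import paths and nested-array Tensor construction are not shown in this diff):

    import { Tensor } from "catniff";
    import { Linear, LayerNorm } from "catniff/dist/nn";

    const x = new Tensor([Array.from({ length: 64 }, () => Math.random())]); // shape [1, 64]
    const proj = new Linear(64, 32);          // weight shape [32, 64]
    const norm = new LayerNorm(32);

    const y = norm.forward(proj.forward(x));  // shape [1, 32]
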
package/dist/optim.d.ts
CHANGED
@@ -1,7 +1,11 @@
 import { Tensor } from "./core";
-
+export interface BaseOptimizerOptions {
+    lr?: number;
+}
+export declare abstract class BaseOptimizer {
     params: Tensor[];
-
+    lr: number;
+    constructor(params: Tensor[], options?: BaseOptimizerOptions);
     zeroGrad(): void;
 }
 export interface SGDOptions {
@@ -11,9 +15,8 @@ export interface SGDOptions {
     weightDecay?: number;
     nesterov?: boolean;
 }
-declare class SGD extends BaseOptimizer {
+export declare class SGD extends BaseOptimizer {
     momentumBuffers: Map<Tensor, Tensor>;
-    lr: number;
     momentum: number;
     dampening: number;
     weightDecay: number;
@@ -27,11 +30,10 @@ export interface AdamOptions {
     eps?: number;
     weightDecay?: number;
 }
-declare class Adam extends BaseOptimizer {
+export declare class Adam extends BaseOptimizer {
     momentumBuffers: Map<Tensor, Tensor>;
     velocityBuffers: Map<Tensor, Tensor>;
     stepCount: number;
-    lr: number;
     betas: [number, number];
     eps: number;
     weightDecay: number;
@@ -44,11 +46,10 @@ export interface AdamWOptions {
     eps?: number;
     weightDecay?: number;
 }
-declare class AdamW extends BaseOptimizer {
+export declare class AdamW extends BaseOptimizer {
     momentumBuffers: Map<Tensor, Tensor>;
     velocityBuffers: Map<Tensor, Tensor>;
     stepCount: number;
-    lr: number;
     betas: [number, number];
     eps: number;
     weightDecay: number;
@@ -61,4 +62,3 @@ export declare const Optim: {
     Adam: typeof Adam;
     AdamW: typeof AdamW;
 };
-export {};

package/dist/optim.js
CHANGED
@@ -1,11 +1,13 @@
 "use strict";
 Object.defineProperty(exports, "__esModule", { value: true });
-exports.Optim = void 0;
+exports.Optim = exports.AdamW = exports.Adam = exports.SGD = exports.BaseOptimizer = void 0;
 const core_1 = require("./core");
 class BaseOptimizer {
     params;
-
+    lr;
+    constructor(params, options) {
         this.params = params;
+        this.lr = options?.lr || 0.001;
     }
     zeroGrad() {
         for (let index = 0; index < this.params.length; index++) {
@@ -14,16 +16,15 @@ class BaseOptimizer {
         }
     }
 }
+exports.BaseOptimizer = BaseOptimizer;
 class SGD extends BaseOptimizer {
     momentumBuffers = new Map();
-    lr;
     momentum;
     dampening;
     weightDecay;
     nesterov;
     constructor(params, options) {
-        super(params);
-        this.lr = options?.lr || 0.001;
+        super(params, options);
         this.momentum = options?.momentum || 0;
         this.dampening = options?.dampening || 0;
         this.weightDecay = options?.weightDecay || 0;
@@ -66,17 +67,16 @@ class SGD extends BaseOptimizer {
         }
     }
 }
+exports.SGD = SGD;
 class Adam extends BaseOptimizer {
     momentumBuffers = new Map(); // First moment (m_t)
     velocityBuffers = new Map(); // Second moment (v_t)
     stepCount = 0;
-    lr;
     betas;
     eps;
     weightDecay;
     constructor(params, options) {
-        super(params);
-        this.lr = options?.lr || 0.001;
+        super(params, options);
         this.betas = options?.betas || [0.9, 0.999];
         this.eps = options?.eps || 1e-8;
         this.weightDecay = options?.weightDecay || 0;
@@ -126,17 +126,16 @@ class Adam extends BaseOptimizer {
         }
     }
 }
+exports.Adam = Adam;
 class AdamW extends BaseOptimizer {
     momentumBuffers = new Map(); // First moment (m_t)
     velocityBuffers = new Map(); // Second moment (v_t)
     stepCount = 0;
-    lr;
     betas;
     eps;
     weightDecay;
     constructor(params, options) {
-        super(params);
-        this.lr = options?.lr || 0.001;
+        super(params, options);
         this.betas = options?.betas || [0.9, 0.999];
         this.eps = options?.eps || 1e-8;
         this.weightDecay = options?.weightDecay || 0.01;
@@ -184,6 +183,7 @@ class AdamW extends BaseOptimizer {
         }
     }
 }
+exports.AdamW = AdamW;
 exports.Optim = {
     BaseOptimizer,
     SGD,

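Since the learning rate now lives on BaseOptimizer (read from options?.lr with a 0.001 default), a scheduler can read and write optimizer.lr without knowing the concrete optimizer class, and the optimizer classes are now importable by name. A minimal sketch, assuming the dist files are reachable via the subpaths shown (not confirmed by this diff):

    import { Tensor } from "catniff";
    import { SGD, BaseOptimizer } from "catniff/dist/optim";

    const params: Tensor[] = [];                       // empty list, just to show the API
    const sgd = new SGD(params, { momentum: 0.9 });    // lr omitted -> 0.001 from the BaseOptimizer constructor

    console.log(sgd.lr);                               // 0.001
    console.log(sgd instanceof BaseOptimizer);         // true -- usable as a generic optimizer, e.g. by StepLR
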
package/index.d.ts
CHANGED
package/index.js
CHANGED