catniff 0.5.9 → 0.5.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +4 -2
- package/dist/core.js +46 -33
- package/dist/nn.d.ts +2 -1
- package/dist/nn.js +14 -9
- package/dist/optim.d.ts +18 -0
- package/dist/optim.js +59 -0
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
@@ -38,7 +38,7 @@ export declare class Tensor {
     static elementWiseSelf(tA: Tensor, op: (tA: number) => number): Tensor;
     elementWiseABDAG(other: TensorValue | Tensor, op: (a: number, b: number) => number, thisGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor, otherGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor): Tensor;
     elementWiseSelfDAG(op: (a: number) => number, thisGrad?: (self: Tensor, outGrad: Tensor) => Tensor): Tensor;
-
+    handleOther(other: Tensor | TensorValue): Tensor;
     static addGrad(tensor: Tensor, accumGrad: Tensor): void;
     isContiguous(): boolean;
     contiguous(): Tensor;
@@ -50,6 +50,8 @@ export declare class Tensor {
     mean(dims?: number[] | number, keepDims?: boolean): Tensor;
     max(dims?: number[] | number, keepDims?: boolean): Tensor;
     min(dims?: number[] | number, keepDims?: boolean): Tensor;
+    all(dims?: number[] | number, keepDims?: boolean): Tensor;
+    any(dims?: number[] | number, keepDims?: boolean): Tensor;
     var(dims?: number[] | number, keepDims?: boolean): Tensor;
     std(dims?: number[] | number, keepDims?: boolean): Tensor;
     softmax(dims?: number[] | number): Tensor;
@@ -174,7 +176,7 @@ export declare class Tensor {
     withGrad(requiresGrad: boolean): Tensor;
     detach(): Tensor;
     clone(): Tensor;
-    replace(other: Tensor, allowShapeMismatch?: boolean): Tensor;
+    replace(other: Tensor | TensorValue, allowShapeMismatch?: boolean): Tensor;
     static backends: Map<string, Backend>;
     to(device: string): Tensor;
     to_(device: string): Tensor;
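The headline additions here are the `all`/`any` reductions, the now-public `handleOther` helper, and `replace` accepting a plain `TensorValue`. A minimal usage sketch (the `catniff` import path and the example values are assumptions, not taken from the package's docs):

import { Tensor } from "catniff";

const t = new Tensor([[1, 0], [1, 1]]);
const rowsAll = t.all(1); // 1 per row only where every element is non-zero -> [0, 1]
const rowsAny = t.any(1); // 1 per row where at least one element is non-zero -> [1, 1]

// replace() now also accepts a plain TensorValue, not just a Tensor
const p = new Tensor([1, 2, 3]);
p.replace([4, 5, 6]);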
package/dist/core.js
CHANGED
@@ -183,7 +183,7 @@ class Tensor {
     }
     // Utility to do element-wise operation and build a dag node with another tensor
     elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) {
-        other =
+        other = this.handleOther(other);
         const out = Tensor.elementWiseAB(this, other, op);
         if (this.requiresGrad) {
             out.requiresGrad = true;
@@ -225,11 +225,15 @@ class Tensor {
         }
         return out;
     }
-    // Utility to
-
-        if (
-
-
+    // Utility to handle other tensor if an op needs a second operand
+    handleOther(other) {
+        if (other instanceof Tensor) {
+            if (this.device !== other.device) {
+                throw new Error("Can not operate on tensors that are not on the same device");
+            }
+            return other;
+        }
+        return new Tensor(other, { device: this.device });
     }
     // Utility to add to gradient of tensor
     static addGrad(tensor, accumGrad) {
@@ -428,9 +432,9 @@ class Tensor {
         }
         // Calculate new value after sum
         for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
-            const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
             // Force 0 on reduced axes to collapse into size-1 dims
-            const outCoords =
+            const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+            outCoords[dims] = 0;
             // Convert output coordinates to flat index
             const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
             // Add into sum
@@ -479,9 +483,9 @@ class Tensor {
         const originalSize = Tensor.shapeToSize(this.shape);
         // Calculate new value after multiplying
         for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
-            const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
             // Force 0 on reduced axes to collapse into size-1 dims
-            const outCoords =
+            const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+            outCoords[dims] = 0;
             // Convert output coordinates to flat index
             const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
             // Multiply into product
@@ -498,9 +502,9 @@ class Tensor {
         out.gradFn = () => {
             const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
             for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
-                const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
                 // Force 0 on reduced axes to collapse into size-1 dims
-                const outCoords =
+                const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+                outCoords[dims] = 0;
                 // Convert output coordinates to flat index
                 const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
                 // Grad is the product of other elements of the same axis, which is product of all els divided by the current value
@@ -537,9 +541,9 @@ class Tensor {
         const originalSize = Tensor.shapeToSize(this.shape);
         // Calculate sums and how many elements contribute to specific positions
         for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
-            const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
             // Force 0 on reduced axes to collapse into size-1 dims
-            const outCoords =
+            const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+            outCoords[dims] = 0;
             // Convert output coordinates to flat index
             const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
             // Calculate sum and contributors to the sum
@@ -562,9 +566,9 @@ class Tensor {
             const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
             // Calculate grad by assigning 1 divided by the number of contributors to the position
             for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
-                const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
                 // Force 0 on reduced axes to collapse into size-1 dims
-                const outCoords =
+                const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+                outCoords[dims] = 0;
                 // Convert output coordinates to flat index
                 const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
                 // Mean = 1/n * (el1 + el2 + ... + eln) so grad = 1/n
@@ -600,9 +604,9 @@ class Tensor {
         const originalSize = Tensor.shapeToSize(this.shape);
         // Calculate maximum values of axes
         for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
-            const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
             // Force 0 on reduced axes to collapse into size-1 dims
-            const outCoords =
+            const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+            outCoords[dims] = 0;
             // Convert output coordinates to flat index
             const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
             // Get max over time
@@ -623,18 +627,18 @@ class Tensor {
             const shareCounts = new Array(outputSize).fill(0);
             const originalValue = this.value;
             for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
-                const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
                 // Force 0 on reduced axes to collapse into size-1 dims
-                const outCoords =
+                const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+                outCoords[dims] = 0;
                 // Convert output coordinates to flat index
                 const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
                 // We collect how many elements share the same max value first
                 shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
             }
             for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
-                const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
                 // Force 0 on reduced axes to collapse into size-1 dims
-                const outCoords =
+                const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+                outCoords[dims] = 0;
                 // Convert output coordinates to flat index
                 const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
                 // Here we share the grad between the elements that share the same max value
@@ -670,9 +674,9 @@ class Tensor {
         const originalSize = Tensor.shapeToSize(this.shape);
         // Calculate minimum values of axes
        for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
-            const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
             // Force 0 on reduced axes to collapse into size-1 dims
-            const outCoords =
+            const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+            outCoords[dims] = 0;
             // Convert output coordinates to flat index
             const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
             // Get min over time
@@ -693,18 +697,18 @@ class Tensor {
             const shareCounts = new Array(outputSize).fill(0);
             const originalValue = this.value;
             for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
-                const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
                 // Force 0 on reduced axes to collapse into size-1 dims
-                const outCoords =
+                const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+                outCoords[dims] = 0;
                 // Convert output coordinates to flat index
                 const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
                 // We collect how many elements share the same min value first
                 shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
             }
             for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
-                const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
                 // Force 0 on reduced axes to collapse into size-1 dims
-                const outCoords =
+                const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+                outCoords[dims] = 0;
                 // Convert output coordinates to flat index
                 const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
                 // Here we share the grad between the elements that share the same min value
@@ -716,6 +720,14 @@ class Tensor {
         }
         return keepDims ? out : out.squeeze(dims);
     }
+    // Tensor all condition reduction
+    all(dims, keepDims = false) {
+        return this.min(dims, keepDims).ne(0);
+    }
+    // Tensor any condition reduction
+    any(dims, keepDims = false) {
+        return this.max(dims, keepDims).ne(0);
+    }
     // Tensor variance reduction
     var(dims, keepDims = false) {
         const meanXSquared = this.square().mean(dims, keepDims);
@@ -1199,7 +1211,7 @@ class Tensor {
     }
     // 1D tensor dot product
     dot(other) {
-        other =
+        other = this.handleOther(other);
         // Verify 1D shape
         if (this.shape.length !== 1 || other.shape.length !== 1) {
             throw new Error("Inputs are not 1D tensors");
@@ -1237,7 +1249,7 @@ class Tensor {
     }
     // Matrix multiplication
     mm(other) {
-        other =
+        other = this.handleOther(other);
         // Verify 2D shape
         if (this.shape.length !== 2 || other.shape.length !== 2) {
             throw new Error("Inputs are not matrices");
@@ -1292,7 +1304,7 @@ class Tensor {
     }
     // Batched 3D tensor matmul
     bmm(other) {
-        other =
+        other = this.handleOther(other);
         // Verify 3D shape
         if (this.shape.length !== 3 || other.shape.length !== 3 || this.shape[0] !== other.shape[0]) {
             throw new Error("Inputs are not 3D tensors with the same first dim size");
@@ -1350,7 +1362,7 @@ class Tensor {
     }
     // Convert right-side 1D tensor to a vector (nx1 tensor) to do matmul
     mv(other) {
-        other =
+        other = this.handleOther(other);
         // Verify 2D shape
         if (this.shape.length !== 2 || other.shape.length !== 1) {
             throw new Error("Input is not a 2D and 1D tensor pair");
@@ -1359,7 +1371,7 @@ class Tensor {
     }
     // General matrix multiplication with different shapes
     matmul(other) {
-        other =
+        other = this.handleOther(other);
         const isThis1D = this.shape.length === 1;
         const isOther1D = other.shape.length === 1;
         if (isThis1D && isOther1D) {
@@ -1692,6 +1704,7 @@ class Tensor {
     }
     // Returns this tensor with value replaced with the value of another tensor
     replace(other, allowShapeMismatch = false) {
+        other = this.handleOther(other);
         // Verify shape
         if (!allowShapeMismatch) {
             for (let index = 0; index < this.shape.length; index++) {
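Implementation-wise, `handleOther` centralizes what each binary and matmul op previously did inline: promote a plain value to a `Tensor` on the caller's device, and reject mixed-device operands. A sketch of the observable behavior (the "gpu" device name is hypothetical; real names depend on what is registered in `Tensor.backends`):

import { Tensor } from "catniff";

const a = new Tensor([1, 2, 3]);
a.add(2); // a plain value is promoted to a Tensor on a's device

const b = new Tensor([4, 5, 6], { device: "gpu" }); // hypothetical backend name
// a.add(b);              // throws: operands must be on the same device
// a.add(b.to(a.device)); // fine: move b explicitly first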
package/dist/nn.d.ts
CHANGED
@@ -55,7 +55,7 @@ declare class LayerNorm {
     eps: number;
     normalizedShape: number[];
     constructor(normalizedShape: number | number[], eps?: number, elementwiseAffine?: boolean, bias?: boolean, device?: string);
-    forward(input: Tensor
+    forward(input: Tensor): Tensor;
 }
 export interface StateDict {
     [key: string]: any;
@@ -68,6 +68,7 @@ export declare const nn: {
     LayerNorm: typeof LayerNorm;
     state: {
         getParameters(model: any, visited?: WeakSet<object>): Tensor[];
+        moveParameters(model: any, device: string): void;
         getStateDict(model: any, prefix?: string, visited?: WeakSet<object>): StateDict;
         loadStateDict(model: any, stateDict: StateDict, prefix?: string, visited?: WeakSet<object>): void;
     };
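`LayerNorm.forward` is now a complete, Tensor-only signature; the implementation's `forceTensor` coercion is removed in the nn.js diff below, so callers must pass a real `Tensor`. A brief sketch (import path assumed):

import { Tensor, nn } from "catniff";

const ln = new nn.LayerNorm(4); // normalizedShape = [4]
const x = new Tensor([1, 2, 3, 4]);
const y = ln.forward(x); // input must already be a Tensor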
package/dist/nn.js
CHANGED
@@ -20,7 +20,7 @@ class Linear {
         }
     }
     forward(input) {
-        input =
+        input = this.weight.handleOther(input);
         return linearTransform(input, this.weight, this.bias);
     }
 }
@@ -49,8 +49,8 @@ class RNNCell {
         }
     }
     forward(input, hidden) {
-        input =
-        hidden =
+        input = this.weightIH.handleOther(input);
+        hidden = this.weightHH.handleOther(hidden);
         return rnnTransform(input, hidden, this.weightIH, this.weightHH, this.biasIH, this.biasHH).tanh();
     }
 }
@@ -85,8 +85,8 @@ class GRUCell {
         }
     }
     forward(input, hidden) {
-        input =
-        hidden =
+        input = this.weightIN.handleOther(input);
+        hidden = this.weightHN.handleOther(hidden);
         const r = rnnTransform(input, hidden, this.weightIR, this.weightHR, this.biasIR, this.biasHR).sigmoid();
         const z = rnnTransform(input, hidden, this.weightIZ, this.weightHZ, this.biasIZ, this.biasHZ).sigmoid();
         const n = linearTransform(input, this.weightIN, this.biasIN).add(r.mul(linearTransform(hidden, this.weightHN, this.biasHN))).tanh();
@@ -132,9 +132,9 @@ class LSTMCell {
         }
     }
     forward(input, hidden, cell) {
-        input =
-        hidden =
-        cell =
+        input = this.weightII.handleOther(input);
+        hidden = this.weightHI.handleOther(hidden);
+        cell = this.weightHI.handleOther(cell);
         const i = rnnTransform(input, hidden, this.weightII, this.weightHI, this.biasII, this.biasHI).sigmoid();
         const f = rnnTransform(input, hidden, this.weightIF, this.weightHF, this.biasIF, this.biasHF).sigmoid();
         const g = rnnTransform(input, hidden, this.weightIG, this.weightHG, this.biasIG, this.biasHG).tanh();
@@ -163,7 +163,6 @@ class LayerNorm {
         }
     }
     forward(input) {
-        input = core_1.Tensor.forceTensor(input);
         // Normalize over the specified dimensions
         const normalizedDims = this.normalizedShape.length;
         const startDim = input.shape.length - normalizedDims;
@@ -208,6 +207,12 @@ const state = {
        }
        return parameters;
    },
+    moveParameters(model, device) {
+        const params = state.getParameters(model);
+        for (const param of params) {
+            param.to_(device);
+        }
+    },
    getStateDict(model, prefix = "", visited = new WeakSet()) {
        if (visited.has(model))
            return {};
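`moveParameters` is a thin convenience over the existing helpers: it collects every parameter with `getParameters` and calls the in-place `to_` on each. A hedged sketch (import path, model shape, and device name are assumptions):

import { Tensor, nn } from "catniff";

// Any object tree whose leaves are Tensors works with the nn.state helpers
const model = { w: new Tensor([[0, 1], [1, 0]]).withGrad(true) };
nn.state.moveParameters(model, "gpu"); // in-place: each parameter is moved via to_("gpu")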
package/dist/optim.d.ts
CHANGED
@@ -38,9 +38,27 @@ declare class Adam extends BaseOptimizer {
     constructor(params: Tensor[], options?: AdamOptions);
     step(): void;
 }
+export interface AdamWOptions {
+    lr?: number;
+    betas?: [number, number];
+    eps?: number;
+    weightDecay?: number;
+}
+declare class AdamW extends BaseOptimizer {
+    momentumBuffers: Map<Tensor, Tensor>;
+    velocityBuffers: Map<Tensor, Tensor>;
+    stepCount: number;
+    lr: number;
+    betas: [number, number];
+    eps: number;
+    weightDecay: number;
+    constructor(params: Tensor[], options?: AdamWOptions);
+    step(): void;
+}
 export declare class Optim {
     static BaseOptimizer: typeof BaseOptimizer;
     static SGD: typeof SGD;
     static Adam: typeof Adam;
+    static AdamW: typeof AdamW;
 }
 export {};
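Every `AdamWOptions` field is optional; as the implementation below shows, the defaults are lr = 0.001, betas = [0.9, 0.999], eps = 1e-8, and weightDecay = 0. A construction sketch (import path assumed):

import { Tensor, Optim } from "catniff";

const w = new Tensor([0.5, -0.3]).withGrad(true);
const opt = new Optim.AdamW([w], { lr: 1e-3, weightDecay: 0.01 });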
package/dist/optim.js
CHANGED
@@ -126,9 +126,68 @@ class Adam extends BaseOptimizer {
         }
     }
 }
+class AdamW extends BaseOptimizer {
+    momentumBuffers = new Map(); // First moment (m_t)
+    velocityBuffers = new Map(); // Second moment (v_t)
+    stepCount = 0;
+    lr;
+    betas;
+    eps;
+    weightDecay;
+    constructor(params, options) {
+        super(params);
+        this.lr = options?.lr || 0.001;
+        this.betas = options?.betas || [0.9, 0.999];
+        this.eps = options?.eps || 1e-8;
+        this.weightDecay = options?.weightDecay || 0;
+    }
+    step() {
+        this.stepCount++;
+        const beta1 = this.betas[0];
+        const beta2 = this.betas[1];
+        // Bias correction factors
+        const biasCorrection1 = 1 - Math.pow(beta1, this.stepCount);
+        const biasCorrection2 = 1 - Math.pow(beta2, this.stepCount);
+        for (const param of this.params) {
+            if (!param.grad || !param.requiresGrad)
+                continue;
+            let grad = param.grad.detach(), detachedParam = param.detach();
+            // Apply weight decay (L2 regularization)
+            detachedParam = detachedParam.sub(detachedParam.mul(this.weightDecay).mul(this.lr));
+            // Get or initialize first moment buffer (momentum)
+            let momentumBuffer = this.momentumBuffers.get(param);
+            if (!momentumBuffer) {
+                momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                this.momentumBuffers.set(param, momentumBuffer);
+            }
+            // Get or initialize second moment buffer (velocity)
+            let velocityBuffer = this.velocityBuffers.get(param);
+            if (!velocityBuffer) {
+                velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                this.velocityBuffers.set(param, velocityBuffer);
+            }
+            // Update biased first moment estimate: m_t = β1 * m_{t-1} + (1 - β1) * g_t
+            momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
+            this.momentumBuffers.set(param, momentumBuffer);
+            // Update biased second moment estimate: v_t = β2 * v_{t-1} + (1 - β2) * g_t^2
+            velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
+            this.velocityBuffers.set(param, velocityBuffer);
+            // Compute bias-corrected first moment: m̂_t = m_t / (1 - β1^t)
+            const correctedMomentum = momentumBuffer.div(biasCorrection1);
+            // Compute bias-corrected second moment: v̂_t = v_t / (1 - β2^t)
+            const correctedVelocity = velocityBuffer.div(biasCorrection2);
+            // Update parameters: θ_t = θ_t - α * m̂_t / (√v̂_t + ε)
+            const denom = correctedVelocity.sqrt().add(this.eps);
+            const stepSize = correctedMomentum.div(denom).mul(this.lr);
+            const newParam = detachedParam.sub(stepSize);
+            param.replace(newParam);
+        }
+    }
+}
 class Optim {
     static BaseOptimizer = BaseOptimizer;
     static SGD = SGD;
     static Adam = Adam;
+    static AdamW = AdamW;
 }
 exports.Optim = Optim;
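Note the decoupled weight decay: `step()` shrinks the parameter by lr * weightDecay before applying the bias-corrected Adam update, rather than folding the decay into the gradient, which is what distinguishes AdamW from Adam with plain L2 regularization. A hedged end-to-end sketch (`backward()` is an assumption, inferred from the `gradFn`/`requiresGrad` autograd machinery in core.js rather than shown in this diff):

import { Tensor, Optim } from "catniff";

const w = new Tensor([0.5, -0.3]).withGrad(true);
const opt = new Optim.AdamW([w], { lr: 1e-3, weightDecay: 0.01 });

const loss = w.mul(w).sum(); // toy scalar loss
loss.backward();             // assumed autograd entry point
opt.step();                  // decay w, then apply the bias-corrected update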