catniff 0.5.9 → 0.5.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.d.ts CHANGED
@@ -38,7 +38,7 @@ export declare class Tensor {
  static elementWiseSelf(tA: Tensor, op: (tA: number) => number): Tensor;
  elementWiseABDAG(other: TensorValue | Tensor, op: (a: number, b: number) => number, thisGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor, otherGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor): Tensor;
  elementWiseSelfDAG(op: (a: number) => number, thisGrad?: (self: Tensor, outGrad: Tensor) => Tensor): Tensor;
- static forceTensor(value: TensorValue | Tensor): Tensor;
+ handleOther(other: Tensor | TensorValue): Tensor;
  static addGrad(tensor: Tensor, accumGrad: Tensor): void;
  isContiguous(): boolean;
  contiguous(): Tensor;
@@ -50,6 +50,8 @@ export declare class Tensor {
  mean(dims?: number[] | number, keepDims?: boolean): Tensor;
  max(dims?: number[] | number, keepDims?: boolean): Tensor;
  min(dims?: number[] | number, keepDims?: boolean): Tensor;
+ all(dims?: number[] | number, keepDims?: boolean): Tensor;
+ any(dims?: number[] | number, keepDims?: boolean): Tensor;
  var(dims?: number[] | number, keepDims?: boolean): Tensor;
  std(dims?: number[] | number, keepDims?: boolean): Tensor;
  softmax(dims?: number[] | number): Tensor;
@@ -174,7 +176,7 @@ export declare class Tensor {
  withGrad(requiresGrad: boolean): Tensor;
  detach(): Tensor;
  clone(): Tensor;
- replace(other: Tensor, allowShapeMismatch?: boolean): Tensor;
+ replace(other: Tensor | TensorValue, allowShapeMismatch?: boolean): Tensor;
  static backends: Map<string, Backend>;
  to(device: string): Tensor;
  to_(device: string): Tensor;
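The declaration changes add two boolean reductions (`all`, `any`) and loosen `replace` to accept plain values. A minimal consumer-side sketch of the new surface; the import path is an assumption, since this diff only shows the type declarations:

```js
// Consumer-side sketch of the 0.5.11 additions (import path assumed).
const { Tensor } = require("catniff");

const t = new Tensor([[1, 0], [3, 4]]);

// all() reduces via min(...).ne(0), any() via max(...).ne(0) (see core.js
// below), so both yield 0/1 tensors, optionally per-axis with keepDims:
const everyEntryNonzero = t.all();    // scalar 0 here: one entry is zero
const rowHasNonzero = t.any(1, true); // shape [2, 1]: each row has a nonzero

// replace() now also accepts a plain TensorValue, coerced via handleOther():
t.replace([[5, 6], [7, 8]]);
```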
package/dist/core.js CHANGED
@@ -183,7 +183,7 @@ class Tensor {
  }
  // Utility to do element-wise operation and build a dag node with another tensor
  elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) {
- other = Tensor.forceTensor(other);
+ other = this.handleOther(other);
  const out = Tensor.elementWiseAB(this, other, op);
  if (this.requiresGrad) {
  out.requiresGrad = true;
@@ -225,11 +225,15 @@ class Tensor {
  }
  return out;
  }
- // Utility to force an input value to be a tensor
- static forceTensor(value) {
- if (value instanceof Tensor)
- return value;
- return new Tensor(value);
+ // Utility to handle other tensor if an op needs a second operand
+ handleOther(other) {
+ if (other instanceof Tensor) {
+ if (this.device !== other.device) {
+ throw new Error("Can not operate on tensors that are not on the same device");
+ }
+ return other;
+ }
+ return new Tensor(other, { device: this.device });
  }
  // Utility to add to gradient of tensor
  static addGrad(tensor, accumGrad) {
@@ -428,9 +432,9 @@ class Tensor {
  }
  // Calculate new value after sum
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
- const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims === i ? 0 : val);
+ const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+ outCoords[dims] = 0;
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
  // Add into sum
@@ -479,9 +483,9 @@ class Tensor {
  const originalSize = Tensor.shapeToSize(this.shape);
  // Calculate new value after multiplying
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
- const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims === i ? 0 : val);
+ const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+ outCoords[dims] = 0;
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
  // Multiply into product
@@ -498,9 +502,9 @@ class Tensor {
  out.gradFn = () => {
  const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
- const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims === i ? 0 : val);
+ const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+ outCoords[dims] = 0;
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
  // Grad is the product of other elements of the same axis, which is product of all els divided by the current value
@@ -537,9 +541,9 @@ class Tensor {
  const originalSize = Tensor.shapeToSize(this.shape);
  // Calculate sums and how many elements contribute to specific positions
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
- const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims === i ? 0 : val);
+ const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+ outCoords[dims] = 0;
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
  // Calculate sum and contributors to the sum
@@ -562,9 +566,9 @@ class Tensor {
  const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
  // Calculate grad by assigning 1 divided by the number of contributors to the position
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
- const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims === i ? 0 : val);
+ const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+ outCoords[dims] = 0;
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
  // Mean = 1/n * (el1 + el2 + ... + eln) so grad = 1/n
@@ -600,9 +604,9 @@ class Tensor {
  const originalSize = Tensor.shapeToSize(this.shape);
  // Calculate maximum values of axes
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
- const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims === i ? 0 : val);
+ const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+ outCoords[dims] = 0;
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
  // Get max over time
@@ -623,18 +627,18 @@ class Tensor {
  const shareCounts = new Array(outputSize).fill(0);
  const originalValue = this.value;
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
- const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims === i ? 0 : val);
+ const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+ outCoords[dims] = 0;
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
  // We collect how many elements share the same max value first
  shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
  }
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
- const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims === i ? 0 : val);
+ const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+ outCoords[dims] = 0;
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
  // Here we share the grad between the elements that share the same max value
@@ -670,9 +674,9 @@ class Tensor {
  const originalSize = Tensor.shapeToSize(this.shape);
  // Calculate minimum values of axes
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
- const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims === i ? 0 : val);
+ const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+ outCoords[dims] = 0;
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
  // Get min over time
@@ -693,18 +697,18 @@ class Tensor {
  const shareCounts = new Array(outputSize).fill(0);
  const originalValue = this.value;
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
- const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims === i ? 0 : val);
+ const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+ outCoords[dims] = 0;
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
  // We collect how many elements share the same min value first
  shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
  }
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
- const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims === i ? 0 : val);
+ const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
+ outCoords[dims] = 0;
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
  // Here we share the grad between the elements that share the same min value
@@ -716,6 +720,14 @@ class Tensor {
  }
  return keepDims ? out : out.squeeze(dims);
  }
+ // Tensor all condition reduction
+ all(dims, keepDims = false) {
+ return this.min(dims, keepDims).ne(0);
+ }
+ // Tensor any condition reduction
+ any(dims, keepDims = false) {
+ return this.max(dims, keepDims).ne(0);
+ }
  // Tensor variance reduction
  var(dims, keepDims = false) {
  const meanXSquared = this.square().mean(dims, keepDims);
@@ -1199,7 +1211,7 @@ class Tensor {
  }
  // 1D tensor dot product
  dot(other) {
- other = Tensor.forceTensor(other);
+ other = this.handleOther(other);
  // Verify 1D shape
  if (this.shape.length !== 1 || other.shape.length !== 1) {
  throw new Error("Inputs are not 1D tensors");
@@ -1237,7 +1249,7 @@ class Tensor {
  }
  // Matrix multiplication
  mm(other) {
- other = Tensor.forceTensor(other);
+ other = this.handleOther(other);
  // Verify 2D shape
  if (this.shape.length !== 2 || other.shape.length !== 2) {
  throw new Error("Inputs are not matrices");
@@ -1292,7 +1304,7 @@ class Tensor {
  }
  // Batched 3D tensor matmul
  bmm(other) {
- other = Tensor.forceTensor(other);
+ other = this.handleOther(other);
  // Verify 3D shape
  if (this.shape.length !== 3 || other.shape.length !== 3 || this.shape[0] !== other.shape[0]) {
  throw new Error("Inputs are not 3D tensors with the same first dim size");
@@ -1350,7 +1362,7 @@ class Tensor {
  }
  // Convert right-side 1D tensor to a vector (nx1 tensor) to do matmul
  mv(other) {
- other = Tensor.forceTensor(other);
+ other = this.handleOther(other);
  // Verify 2D shape
  if (this.shape.length !== 2 || other.shape.length !== 1) {
  throw new Error("Input is not a 2D and 1D tensor pair");
@@ -1359,7 +1371,7 @@ class Tensor {
  }
  // General matrix multiplication with different shapes
  matmul(other) {
- other = Tensor.forceTensor(other);
+ other = this.handleOther(other);
  const isThis1D = this.shape.length === 1;
  const isOther1D = other.shape.length === 1;
  if (isThis1D && isOther1D) {
@@ -1692,6 +1704,7 @@ class Tensor {
  }
  // Returns this tensor with value replaced with the value of another tensor
  replace(other, allowShapeMismatch = false) {
+ other = this.handleOther(other);
  // Verify shape
  if (!allowShapeMismatch) {
  for (let index = 0; index < this.shape.length; index++) {
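The behavioral difference between the old `forceTensor` and the new `handleOther` is the device check: plain values are still auto-wrapped in a `Tensor`, but now onto the calling tensor's device, and mixing devices throws instead of silently operating. A sketch of that behavior; the device names are illustrative, and it assumes binary ops such as `add` route through `elementWiseABDAG` as shown above:

```js
// Behavior sketch of the forceTensor -> handleOther change.
const a = new Tensor([1, 2, 3], { device: "cpu" });

// Plain values are wrapped onto a's device:
a.add([4, 5, 6]); // ok: rhs becomes new Tensor([4, 5, 6], { device: "cpu" })

// Tensors on different devices are now rejected:
const b = new Tensor([4, 5, 6], { device: "gpu" }); // hypothetical device name
a.add(b); // throws: "Can not operate on tensors that are not on the same device"
```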
package/dist/nn.d.ts CHANGED
@@ -55,7 +55,7 @@ declare class LayerNorm {
  eps: number;
  normalizedShape: number[];
  constructor(normalizedShape: number | number[], eps?: number, elementwiseAffine?: boolean, bias?: boolean, device?: string);
- forward(input: Tensor | TensorValue): Tensor;
+ forward(input: Tensor): Tensor;
  }
  export interface StateDict {
  [key: string]: any;
@@ -68,6 +68,7 @@ export declare const nn: {
  LayerNorm: typeof LayerNorm;
  state: {
  getParameters(model: any, visited?: WeakSet<object>): Tensor[];
+ moveParameters(model: any, device: string): void;
  getStateDict(model: any, prefix?: string, visited?: WeakSet<object>): StateDict;
  loadStateDict(model: any, stateDict: StateDict, prefix?: string, visited?: WeakSet<object>): void;
  };
package/dist/nn.js CHANGED
@@ -20,7 +20,7 @@ class Linear {
  }
  }
  forward(input) {
- input = core_1.Tensor.forceTensor(input);
+ input = this.weight.handleOther(input);
  return linearTransform(input, this.weight, this.bias);
  }
  }
@@ -49,8 +49,8 @@ class RNNCell {
  }
  }
  forward(input, hidden) {
- input = core_1.Tensor.forceTensor(input);
- hidden = core_1.Tensor.forceTensor(hidden);
+ input = this.weightIH.handleOther(input);
+ hidden = this.weightHH.handleOther(hidden);
  return rnnTransform(input, hidden, this.weightIH, this.weightHH, this.biasIH, this.biasHH).tanh();
  }
  }
@@ -85,8 +85,8 @@ class GRUCell {
  }
  }
  forward(input, hidden) {
- input = core_1.Tensor.forceTensor(input);
- hidden = core_1.Tensor.forceTensor(hidden);
+ input = this.weightIN.handleOther(input);
+ hidden = this.weightHN.handleOther(hidden);
  const r = rnnTransform(input, hidden, this.weightIR, this.weightHR, this.biasIR, this.biasHR).sigmoid();
  const z = rnnTransform(input, hidden, this.weightIZ, this.weightHZ, this.biasIZ, this.biasHZ).sigmoid();
  const n = linearTransform(input, this.weightIN, this.biasIN).add(r.mul(linearTransform(hidden, this.weightHN, this.biasHN))).tanh();
@@ -132,9 +132,9 @@ class LSTMCell {
  }
  }
  forward(input, hidden, cell) {
- input = core_1.Tensor.forceTensor(input);
- hidden = core_1.Tensor.forceTensor(hidden);
- cell = core_1.Tensor.forceTensor(cell);
+ input = this.weightII.handleOther(input);
+ hidden = this.weightHI.handleOther(hidden);
+ cell = this.weightHI.handleOther(cell);
  const i = rnnTransform(input, hidden, this.weightII, this.weightHI, this.biasII, this.biasHI).sigmoid();
  const f = rnnTransform(input, hidden, this.weightIF, this.weightHF, this.biasIF, this.biasHF).sigmoid();
  const g = rnnTransform(input, hidden, this.weightIG, this.weightHG, this.biasIG, this.biasHG).tanh();
@@ -163,7 +163,6 @@ class LayerNorm {
  }
  }
  forward(input) {
- input = core_1.Tensor.forceTensor(input);
  // Normalize over the specified dimensions
  const normalizedDims = this.normalizedShape.length;
  const startDim = input.shape.length - normalizedDims;
@@ -208,6 +207,12 @@ const state = {
  }
  return parameters;
  },
+ moveParameters(model, device) {
+ const params = state.getParameters(model);
+ for (const param of params) {
+ param.to_(device);
+ }
+ },
  getStateDict(model, prefix = "", visited = new WeakSet()) {
  if (visited.has(model))
  return {};
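`moveParameters` is a thin loop over `getParameters` that calls the in-place `to_` on each collected tensor. A usage sketch; the `Linear` constructor arguments, the model shape, and the device string are assumptions for illustration:

```js
// Usage sketch for the new nn.state.moveParameters helper.
const { nn } = require("catniff");

const model = {
    fc1: new nn.Linear(4, 8), // constructor args assumed
    fc2: new nn.Linear(8, 1),
};

// Collects tensors via getParameters(model), then moves each in place
// with to_(device):
nn.state.moveParameters(model, "cpu");
```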
package/dist/optim.d.ts CHANGED
@@ -38,9 +38,27 @@ declare class Adam extends BaseOptimizer {
  constructor(params: Tensor[], options?: AdamOptions);
  step(): void;
  }
+ export interface AdamWOptions {
+ lr?: number;
+ betas?: [number, number];
+ eps?: number;
+ weightDecay?: number;
+ }
+ declare class AdamW extends BaseOptimizer {
+ momentumBuffers: Map<Tensor, Tensor>;
+ velocityBuffers: Map<Tensor, Tensor>;
+ stepCount: number;
+ lr: number;
+ betas: [number, number];
+ eps: number;
+ weightDecay: number;
+ constructor(params: Tensor[], options?: AdamWOptions);
+ step(): void;
+ }
  export declare class Optim {
  static BaseOptimizer: typeof BaseOptimizer;
  static SGD: typeof SGD;
  static Adam: typeof Adam;
+ static AdamW: typeof AdamW;
  }
  export {};
package/dist/optim.js CHANGED
@@ -126,9 +126,68 @@ class Adam extends BaseOptimizer {
  }
  }
  }
+ class AdamW extends BaseOptimizer {
+ momentumBuffers = new Map(); // First moment (m_t)
+ velocityBuffers = new Map(); // Second moment (v_t)
+ stepCount = 0;
+ lr;
+ betas;
+ eps;
+ weightDecay;
+ constructor(params, options) {
+ super(params);
+ this.lr = options?.lr || 0.001;
+ this.betas = options?.betas || [0.9, 0.999];
+ this.eps = options?.eps || 1e-8;
+ this.weightDecay = options?.weightDecay || 0;
+ }
+ step() {
+ this.stepCount++;
+ const beta1 = this.betas[0];
+ const beta2 = this.betas[1];
+ // Bias correction factors
+ const biasCorrection1 = 1 - Math.pow(beta1, this.stepCount);
+ const biasCorrection2 = 1 - Math.pow(beta2, this.stepCount);
+ for (const param of this.params) {
+ if (!param.grad || !param.requiresGrad)
+ continue;
+ let grad = param.grad.detach(), detachedParam = param.detach();
+ // Apply weight decay (L2 regularization)
+ detachedParam = detachedParam.sub(detachedParam.mul(this.weightDecay).mul(this.lr));
+ // Get or initialize first moment buffer (momentum)
+ let momentumBuffer = this.momentumBuffers.get(param);
+ if (!momentumBuffer) {
+ momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+ this.momentumBuffers.set(param, momentumBuffer);
+ }
+ // Get or initialize second moment buffer (velocity)
+ let velocityBuffer = this.velocityBuffers.get(param);
+ if (!velocityBuffer) {
+ velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+ this.velocityBuffers.set(param, velocityBuffer);
+ }
+ // Update biased first moment estimate: m_t = β1 * m_{t-1} + (1 - β1) * g_t
+ momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
+ this.momentumBuffers.set(param, momentumBuffer);
+ // Update biased second moment estimate: v_t = β2 * v_{t-1} + (1 - β2) * g_t^2
+ velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
+ this.velocityBuffers.set(param, velocityBuffer);
+ // Compute bias-corrected first moment: m̂_t = m_t / (1 - β1^t)
+ const correctedMomentum = momentumBuffer.div(biasCorrection1);
+ // Compute bias-corrected second moment: v̂_t = v_t / (1 - β2^t)
+ const correctedVelocity = velocityBuffer.div(biasCorrection2);
+ // Update parameters: θ_t = θ_t - α * m̂_t / (√v̂_t + ε)
+ const denom = correctedVelocity.sqrt().add(this.eps);
+ const stepSize = correctedMomentum.div(denom).mul(this.lr);
+ const newParam = detachedParam.sub(stepSize);
+ param.replace(newParam);
+ }
+ }
+ }
  class Optim {
  static BaseOptimizer = BaseOptimizer;
  static SGD = SGD;
  static Adam = Adam;
+ static AdamW = AdamW;
  }
  exports.Optim = Optim;
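The new `AdamW` reuses Adam's buffer bookkeeping but applies weight decay to the detached parameter itself before the update, rather than folding it into the gradient, which is the decoupled-decay idea behind AdamW. A usage sketch; the option names come from the diff, while `backward()` and the root re-exports are assumptions based on the package's Torch-like API:

```js
// Usage sketch for the new AdamW optimizer (imports and backward() assumed).
const { Tensor, Optim } = require("catniff");

const w = new Tensor([1.0, 2.0, 3.0]).withGrad(true);
const opt = new Optim.AdamW([w], { lr: 1e-3, weightDecay: 0.01 });

const loss = w.mul(w).sum(); // toy scalar objective
loss.backward();             // assumed autograd entry point
opt.step();                  // decayed, bias-corrected update, committed via param.replace()
```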
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "catniff",
- "version": "0.5.9",
+ "version": "0.5.11",
  "description": "A small Torch-like deep learning framework for Javascript",
  "main": "index.js",
  "scripts": {