catniff 0.5.9 → 0.5.10

package/dist/core.d.ts CHANGED
@@ -38,7 +38,7 @@ export declare class Tensor {
     static elementWiseSelf(tA: Tensor, op: (tA: number) => number): Tensor;
     elementWiseABDAG(other: TensorValue | Tensor, op: (a: number, b: number) => number, thisGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor, otherGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor): Tensor;
     elementWiseSelfDAG(op: (a: number) => number, thisGrad?: (self: Tensor, outGrad: Tensor) => Tensor): Tensor;
-    static forceTensor(value: TensorValue | Tensor): Tensor;
+    handleOther(other: Tensor | TensorValue): Tensor;
     static addGrad(tensor: Tensor, accumGrad: Tensor): void;
     isContiguous(): boolean;
     contiguous(): Tensor;
@@ -174,7 +174,7 @@ export declare class Tensor {
     withGrad(requiresGrad: boolean): Tensor;
     detach(): Tensor;
     clone(): Tensor;
-    replace(other: Tensor, allowShapeMismatch?: boolean): Tensor;
+    replace(other: Tensor | TensorValue, allowShapeMismatch?: boolean): Tensor;
     static backends: Map<string, Backend>;
     to(device: string): Tensor;
     to_(device: string): Tensor;

package/dist/core.js CHANGED
@@ -183,7 +183,7 @@ class Tensor {
     }
     // Utility to do element-wise operation and build a dag node with another tensor
     elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) {
-        other = Tensor.forceTensor(other);
+        other = this.handleOther(other);
         const out = Tensor.elementWiseAB(this, other, op);
         if (this.requiresGrad) {
             out.requiresGrad = true;
@@ -225,11 +225,15 @@ class Tensor {
         }
         return out;
     }
-    // Utility to force an input value to be a tensor
-    static forceTensor(value) {
-        if (value instanceof Tensor)
-            return value;
-        return new Tensor(value);
+    // Utility to handle other tensor if an op needs a second operand
+    handleOther(other) {
+        if (other instanceof Tensor) {
+            if (this.device !== other.device) {
+                throw new Error("Can not operate on tensors that are not on the same device");
+            }
+            return other;
+        }
+        return new Tensor(other, { device: this.device });
     }
     // Utility to add to gradient of tensor
     static addGrad(tensor, accumGrad) {
@@ -1199,7 +1203,7 @@ class Tensor {
     }
     // 1D tensor dot product
     dot(other) {
-        other = Tensor.forceTensor(other);
+        other = this.handleOther(other);
         // Verify 1D shape
         if (this.shape.length !== 1 || other.shape.length !== 1) {
             throw new Error("Inputs are not 1D tensors");
@@ -1237,7 +1241,7 @@ class Tensor {
     }
     // Matrix multiplication
     mm(other) {
-        other = Tensor.forceTensor(other);
+        other = this.handleOther(other);
         // Verify 2D shape
         if (this.shape.length !== 2 || other.shape.length !== 2) {
             throw new Error("Inputs are not matrices");
@@ -1292,7 +1296,7 @@ class Tensor {
     }
     // Batched 3D tensor matmul
     bmm(other) {
-        other = Tensor.forceTensor(other);
+        other = this.handleOther(other);
         // Verify 3D shape
         if (this.shape.length !== 3 || other.shape.length !== 3 || this.shape[0] !== other.shape[0]) {
             throw new Error("Inputs are not 3D tensors with the same first dim size");
@@ -1350,7 +1354,7 @@ class Tensor {
     }
     // Convert right-side 1D tensor to a vector (nx1 tensor) to do matmul
     mv(other) {
-        other = Tensor.forceTensor(other);
+        other = this.handleOther(other);
         // Verify 2D shape
         if (this.shape.length !== 2 || other.shape.length !== 1) {
             throw new Error("Input is not a 2D and 1D tensor pair");
@@ -1359,7 +1363,7 @@ class Tensor {
     }
     // General matrix multiplication with different shapes
     matmul(other) {
-        other = Tensor.forceTensor(other);
+        other = this.handleOther(other);
         const isThis1D = this.shape.length === 1;
         const isOther1D = other.shape.length === 1;
         if (isThis1D && isOther1D) {
@@ -1692,6 +1696,7 @@ class Tensor {
     }
     // Returns this tensor with value replaced with the value of another tensor
     replace(other, allowShapeMismatch = false) {
+        other = this.handleOther(other);
        // Verify shape
        if (!allowShapeMismatch) {
            for (let index = 0; index < this.shape.length; index++) {
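
The net effect of swapping forceTensor for handleOther: every binary op now enforces device placement. A raw number or nested array operand is materialized on the calling tensor's device, while a Tensor operand on a different device throws. A minimal sketch of the new behavior, assuming the package's root CommonJS exports and a registered "gpu" backend (both assumptions, not shown in this diff):

// Sketch only: root exports and the "gpu" device name are assumptions;
// use whatever backend is actually registered in Tensor.backends.
const { Tensor } = require("catniff");

const a = new Tensor([1, 2, 3]);

// A plain value operand is wrapped on a's device automatically:
const b = a.add(5); // internally: new Tensor(5, { device: a.device })

// Mixed-device operands now throw instead of computing silently:
// a.to("gpu").add(a); // Error: Can not operate on tensors that are not on the same device

// replace() also routes through handleOther now, so raw values are accepted:
a.replace([4, 5, 6]);
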
package/dist/nn.d.ts CHANGED
@@ -55,7 +55,7 @@ declare class LayerNorm {
     eps: number;
     normalizedShape: number[];
     constructor(normalizedShape: number | number[], eps?: number, elementwiseAffine?: boolean, bias?: boolean, device?: string);
-    forward(input: Tensor | TensorValue): Tensor;
+    forward(input: Tensor): Tensor;
 }
 export interface StateDict {
     [key: string]: any;
@@ -68,6 +68,7 @@ export declare const nn: {
     LayerNorm: typeof LayerNorm;
     state: {
         getParameters(model: any, visited?: WeakSet<object>): Tensor[];
+        moveParameters(model: any, device: string): void;
         getStateDict(model: any, prefix?: string, visited?: WeakSet<object>): StateDict;
         loadStateDict(model: any, stateDict: StateDict, prefix?: string, visited?: WeakSet<object>): void;
     };

package/dist/nn.js CHANGED
@@ -20,7 +20,7 @@ class Linear {
         }
     }
     forward(input) {
-        input = core_1.Tensor.forceTensor(input);
+        input = this.weight.handleOther(input);
         return linearTransform(input, this.weight, this.bias);
     }
 }
@@ -49,8 +49,8 @@ class RNNCell {
         }
     }
     forward(input, hidden) {
-        input = core_1.Tensor.forceTensor(input);
-        hidden = core_1.Tensor.forceTensor(hidden);
+        input = this.weightIH.handleOther(input);
+        hidden = this.weightHH.handleOther(hidden);
         return rnnTransform(input, hidden, this.weightIH, this.weightHH, this.biasIH, this.biasHH).tanh();
     }
 }
@@ -85,8 +85,8 @@ class GRUCell {
         }
     }
     forward(input, hidden) {
-        input = core_1.Tensor.forceTensor(input);
-        hidden = core_1.Tensor.forceTensor(hidden);
+        input = this.weightIN.handleOther(input);
+        hidden = this.weightHN.handleOther(hidden);
         const r = rnnTransform(input, hidden, this.weightIR, this.weightHR, this.biasIR, this.biasHR).sigmoid();
         const z = rnnTransform(input, hidden, this.weightIZ, this.weightHZ, this.biasIZ, this.biasHZ).sigmoid();
         const n = linearTransform(input, this.weightIN, this.biasIN).add(r.mul(linearTransform(hidden, this.weightHN, this.biasHN))).tanh();
@@ -132,9 +132,9 @@ class LSTMCell {
         }
     }
     forward(input, hidden, cell) {
-        input = core_1.Tensor.forceTensor(input);
-        hidden = core_1.Tensor.forceTensor(hidden);
-        cell = core_1.Tensor.forceTensor(cell);
+        input = this.weightII.handleOther(input);
+        hidden = this.weightHI.handleOther(hidden);
+        cell = this.weightHI.handleOther(cell);
         const i = rnnTransform(input, hidden, this.weightII, this.weightHI, this.biasII, this.biasHI).sigmoid();
         const f = rnnTransform(input, hidden, this.weightIF, this.weightHF, this.biasIF, this.biasHF).sigmoid();
         const g = rnnTransform(input, hidden, this.weightIG, this.weightHG, this.biasIG, this.biasHG).tanh();
@@ -163,7 +163,6 @@ class LayerNorm {
         }
     }
     forward(input) {
-        input = core_1.Tensor.forceTensor(input);
        // Normalize over the specified dimensions
        const normalizedDims = this.normalizedShape.length;
        const startDim = input.shape.length - normalizedDims;
@@ -208,6 +207,12 @@ const state = {
         }
         return parameters;
     },
+    moveParameters(model, device) {
+        const params = state.getParameters(model);
+        for (const param of params) {
+            param.to_(device);
+        }
+    },
     getStateDict(model, prefix = "", visited = new WeakSet()) {
         if (visited.has(model))
             return {};
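
The new nn.state.moveParameters helper pairs with the device check in core.js: it collects a model's parameters with getParameters and moves each one in place via to_, so layer weights end up on the same device as their inputs. A hedged usage sketch (the root export shape and the "gpu" device name are assumptions; the LayerNorm constructor matches the declaration in nn.d.ts above):

// Sketch only: assumes nn is exported from the package root and "gpu"
// is a registered backend name.
const { nn } = require("catniff");

const model = { norm: new nn.LayerNorm(4) };

// Collects every parameter tensor reachable from model and calls to_(device):
nn.state.moveParameters(model, "gpu");
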
package/dist/optim.d.ts CHANGED
@@ -38,9 +38,27 @@ declare class Adam extends BaseOptimizer {
     constructor(params: Tensor[], options?: AdamOptions);
     step(): void;
 }
+export interface AdamWOptions {
+    lr?: number;
+    betas?: [number, number];
+    eps?: number;
+    weightDecay?: number;
+}
+declare class AdamW extends BaseOptimizer {
+    momentumBuffers: Map<Tensor, Tensor>;
+    velocityBuffers: Map<Tensor, Tensor>;
+    stepCount: number;
+    lr: number;
+    betas: [number, number];
+    eps: number;
+    weightDecay: number;
+    constructor(params: Tensor[], options?: AdamWOptions);
+    step(): void;
+}
 export declare class Optim {
     static BaseOptimizer: typeof BaseOptimizer;
     static SGD: typeof SGD;
     static Adam: typeof Adam;
+    static AdamW: typeof AdamW;
 }
 export {};

package/dist/optim.js CHANGED
@@ -126,9 +126,68 @@ class Adam extends BaseOptimizer {
         }
     }
 }
+class AdamW extends BaseOptimizer {
+    momentumBuffers = new Map(); // First moment (m_t)
+    velocityBuffers = new Map(); // Second moment (v_t)
+    stepCount = 0;
+    lr;
+    betas;
+    eps;
+    weightDecay;
+    constructor(params, options) {
+        super(params);
+        this.lr = options?.lr || 0.001;
+        this.betas = options?.betas || [0.9, 0.999];
+        this.eps = options?.eps || 1e-8;
+        this.weightDecay = options?.weightDecay || 0;
+    }
+    step() {
+        this.stepCount++;
+        const beta1 = this.betas[0];
+        const beta2 = this.betas[1];
+        // Bias correction factors
+        const biasCorrection1 = 1 - Math.pow(beta1, this.stepCount);
+        const biasCorrection2 = 1 - Math.pow(beta2, this.stepCount);
+        for (const param of this.params) {
+            if (!param.grad || !param.requiresGrad)
+                continue;
+            let grad = param.grad.detach(), detachedParam = param.detach();
+            // Apply weight decay (L2 regularization)
+            detachedParam = detachedParam.sub(detachedParam.mul(this.weightDecay).mul(this.lr));
+            // Get or initialize first moment buffer (momentum)
+            let momentumBuffer = this.momentumBuffers.get(param);
+            if (!momentumBuffer) {
+                momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                this.momentumBuffers.set(param, momentumBuffer);
+            }
+            // Get or initialize second moment buffer (velocity)
+            let velocityBuffer = this.velocityBuffers.get(param);
+            if (!velocityBuffer) {
+                velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                this.velocityBuffers.set(param, velocityBuffer);
+            }
+            // Update biased first moment estimate: m_t = β1 * m_{t-1} + (1 - β1) * g_t
+            momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
+            this.momentumBuffers.set(param, momentumBuffer);
+            // Update biased second moment estimate: v_t = β2 * v_{t-1} + (1 - β2) * g_t^2
+            velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
+            this.velocityBuffers.set(param, velocityBuffer);
+            // Compute bias-corrected first moment: m̂_t = m_t / (1 - β1^t)
+            const correctedMomentum = momentumBuffer.div(biasCorrection1);
+            // Compute bias-corrected second moment: v̂_t = v_t / (1 - β2^t)
+            const correctedVelocity = velocityBuffer.div(biasCorrection2);
+            // Update parameters: θ_t = θ_t - α * m̂_t / (√v̂_t + ε)
+            const denom = correctedVelocity.sqrt().add(this.eps);
+            const stepSize = correctedMomentum.div(denom).mul(this.lr);
+            const newParam = detachedParam.sub(stepSize);
+            param.replace(newParam);
+        }
+    }
+}
 class Optim {
     static BaseOptimizer = BaseOptimizer;
     static SGD = SGD;
     static Adam = Adam;
+    static AdamW = AdamW;
 }
 exports.Optim = Optim;
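
Note that this AdamW applies decoupled weight decay: the parameter is shrunk by lr * weightDecay directly, before the bias-corrected Adam update, rather than folding the decay into the gradient. A minimal driving sketch (sum() and backward() are assumptions based on the library's Torch-like API and are not shown in this diff; the AdamW constructor matches the declaration above):

// Sketch only: sum() and backward() are assumed APIs.
const { Tensor, Optim } = require("catniff");

const w = new Tensor([0.5, -0.5]).withGrad(true);
const opt = new Optim.AdamW([w], { lr: 1e-3, weightDecay: 0.01 });

const loss = w.mul(w).sum(); // scalar loss (sum() assumed)
loss.backward();             // populate w.grad (backward() assumed)
opt.step();                  // decoupled decay + bias-corrected Adam update
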
package/package.json CHANGED
@@ -1,6 +1,6 @@
 {
   "name": "catniff",
-  "version": "0.5.9",
+  "version": "0.5.10",
   "description": "A small Torch-like deep learning framework for Javascript",
   "main": "index.js",
   "scripts": {