catniff 0.5.8 → 0.5.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/backend.d.ts CHANGED
@@ -1,4 +1,5 @@
  import { Tensor } from "./core";
  export interface Backend {
+     create(tensor: Tensor): void;
      transfer(tensor: Tensor): Tensor;
  }
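
The `Backend` interface now has two hooks: `transfer`, which returns a tensor moved to the device, and the new `create`, which sets up a tensor's storage on the device in place (it backs the new `to_` op in core.js below). As a rough sketch, a backend registration could look like the following — assuming `Tensor` and `Backend` are exported from the package root, and with "mydevice" as a purely hypothetical device name:

import { Tensor, Backend } from "catniff";

const myBackend: Backend = {
    // Called via to_(): initialize device-side storage for the tensor in place.
    create(tensor: Tensor): void {
        // A real backend would allocate/copy the tensor's data into device memory here.
    },
    // Called via to(): return a tensor that lives on the device.
    transfer(tensor: Tensor): Tensor {
        return tensor.clone(); // Placeholder; a real backend returns a device-resident tensor.
    },
};

Tensor.backends.set("mydevice", myBackend); // Tensor.backends is the public registry (see core.d.ts).
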
package/dist/core.d.ts CHANGED
@@ -38,7 +38,7 @@ export declare class Tensor {
      static elementWiseSelf(tA: Tensor, op: (tA: number) => number): Tensor;
      elementWiseABDAG(other: TensorValue | Tensor, op: (a: number, b: number) => number, thisGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor, otherGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor): Tensor;
      elementWiseSelfDAG(op: (a: number) => number, thisGrad?: (self: Tensor, outGrad: Tensor) => Tensor): Tensor;
-     static forceTensor(value: TensorValue | Tensor): Tensor;
+     handleOther(other: Tensor | TensorValue): Tensor;
      static addGrad(tensor: Tensor, accumGrad: Tensor): void;
      isContiguous(): boolean;
      contiguous(): Tensor;
@@ -174,7 +174,8 @@ export declare class Tensor {
      withGrad(requiresGrad: boolean): Tensor;
      detach(): Tensor;
      clone(): Tensor;
-     replace(other: Tensor, allowShapeMismatch?: boolean): Tensor;
+     replace(other: Tensor | TensorValue, allowShapeMismatch?: boolean): Tensor;
      static backends: Map<string, Backend>;
      to(device: string): Tensor;
+     to_(device: string): Tensor;
  }
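
The split between the two transfer ops: `to` is out-of-place and now returns the result of `backend.transfer` (or `this` when the target is "cpu"), while the new `to_` converts the tensor in place via `backend.create`; the constructor and `nn.state.moveParameters` route through `to_`. A usage sketch, with the device name hypothetical:

const a = new Tensor([[1, 2], [3, 4]]);               // defaults to "cpu"
const b = a.to("mydevice");                           // out-of-place: returns backend.transfer(a)
const c = new Tensor([5, 6], { device: "mydevice" }); // constructor now moves in place via to_()
const d = a.to("cpu");                                // short-circuits and returns `a` itself
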
package/dist/core.js CHANGED
@@ -21,13 +21,8 @@ class Tensor {
          this.gradFn = options.gradFn || (() => { });
          this.children = options.children || [];
          this.device = options.device || "cpu";
-         // Move tensor to device
-         if (this.device !== "cpu") {
-             const backend = Tensor.backends.get(this.device);
-             if (backend && backend.transfer) {
-                 backend.transfer(this);
-             }
-         }
+         // Move to device in-place
+         this.to_(this.device);
      }
      // Utility to flatten an nD array to be 1D
      static flatten(tensor) {
@@ -188,7 +183,7 @@ class Tensor {
      }
      // Utility to do element-wise operation and build a dag node with another tensor
      elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) {
-         other = Tensor.forceTensor(other);
+         other = this.handleOther(other);
          const out = Tensor.elementWiseAB(this, other, op);
          if (this.requiresGrad) {
              out.requiresGrad = true;
@@ -230,11 +225,15 @@ class Tensor {
          }
          return out;
      }
-     // Utility to force an input value to be a tensor
-     static forceTensor(value) {
-         if (value instanceof Tensor)
-             return value;
-         return new Tensor(value);
+     // Utility to handle other tensor if an op needs a second operand
+     handleOther(other) {
+         if (other instanceof Tensor) {
+             if (this.device !== other.device) {
+                 throw new Error("Can not operate on tensors that are not on the same device");
+             }
+             return other;
+         }
+         return new Tensor(other, { device: this.device });
      }
      // Utility to add to gradient of tensor
      static addGrad(tensor, accumGrad) {
@@ -1204,7 +1203,7 @@ class Tensor {
      }
      // 1D tensor dot product
      dot(other) {
-         other = Tensor.forceTensor(other);
+         other = this.handleOther(other);
          // Verify 1D shape
          if (this.shape.length !== 1 || other.shape.length !== 1) {
              throw new Error("Inputs are not 1D tensors");
@@ -1242,7 +1241,7 @@ class Tensor {
      }
      // Matrix multiplication
      mm(other) {
-         other = Tensor.forceTensor(other);
+         other = this.handleOther(other);
          // Verify 2D shape
          if (this.shape.length !== 2 || other.shape.length !== 2) {
              throw new Error("Inputs are not matrices");
@@ -1297,7 +1296,7 @@ class Tensor {
      }
      // Batched 3D tensor matmul
      bmm(other) {
-         other = Tensor.forceTensor(other);
+         other = this.handleOther(other);
          // Verify 3D shape
          if (this.shape.length !== 3 || other.shape.length !== 3 || this.shape[0] !== other.shape[0]) {
              throw new Error("Inputs are not 3D tensors with the same first dim size");
@@ -1355,7 +1354,7 @@ class Tensor {
      }
      // Convert right-side 1D tensor to a vector (nx1 tensor) to do matmul
      mv(other) {
-         other = Tensor.forceTensor(other);
+         other = this.handleOther(other);
          // Verify 2D shape
          if (this.shape.length !== 2 || other.shape.length !== 1) {
              throw new Error("Input is not a 2D and 1D tensor pair");
@@ -1364,7 +1363,7 @@ class Tensor {
      }
      // General matrix multiplication with different shapes
      matmul(other) {
-         other = Tensor.forceTensor(other);
+         other = this.handleOther(other);
          const isThis1D = this.shape.length === 1;
          const isOther1D = other.shape.length === 1;
          if (isThis1D && isOther1D) {
@@ -1697,6 +1696,7 @@ class Tensor {
      }
      // Returns this tensor with value replaced with the value of another tensor
      replace(other, allowShapeMismatch = false) {
+         other = this.handleOther(other);
          // Verify shape
          if (!allowShapeMismatch) {
              for (let index = 0; index < this.shape.length; index++) {
@@ -1712,9 +1712,21 @@ class Tensor {
      static backends = new Map();
      // Op to transfer tensor to another device
      to(device) {
+         if (device === "cpu")
+             return this;
          const backend = Tensor.backends.get(device);
          if (backend && backend.transfer) {
-             backend.transfer(this);
+             return backend.transfer(this);
+         }
+         throw new Error(`No device found to transfer tensor to or a handler is not implemented for device.`);
+     }
+     // Op to transfer tensor to another device in-place
+     to_(device) {
+         if (device === "cpu")
+             return this;
+         const backend = Tensor.backends.get(this.device);
+         if (backend && backend.create) {
+             backend.create(this);
              return this;
          }
          throw new Error(`No device found to transfer tensor to or a handler is not implemented for device.`);
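
The net effect of replacing static `forceTensor` with the instance method `handleOther` is that binary ops become device-aware: plain numbers and arrays are promoted to tensors on the calling tensor's device, while two tensors on different devices now throw instead of silently operating on mismatched storage. A behavior sketch (device name hypothetical; element-wise ops like `add` go through `elementWiseABDAG` shown above):

const x = new Tensor([1, 2, 3], { device: "mydevice" });
x.add(2);                        // wrapped as new Tensor(2, { device: "mydevice" })
const y = new Tensor([4, 5, 6]); // lives on "cpu"
x.add(y);                        // throws: tensors are not on the same device
x.add(y.to("mydevice"));         // transfer first, then operate
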
package/dist/nn.d.ts CHANGED
@@ -55,7 +55,7 @@ declare class LayerNorm {
      eps: number;
      normalizedShape: number[];
      constructor(normalizedShape: number | number[], eps?: number, elementwiseAffine?: boolean, bias?: boolean, device?: string);
-     forward(input: Tensor | TensorValue): Tensor;
+     forward(input: Tensor): Tensor;
  }
  export interface StateDict {
      [key: string]: any;
@@ -68,6 +68,7 @@ export declare const nn: {
      LayerNorm: typeof LayerNorm;
      state: {
          getParameters(model: any, visited?: WeakSet<object>): Tensor[];
+         moveParameters(model: any, device: string): void;
          getStateDict(model: any, prefix?: string, visited?: WeakSet<object>): StateDict;
          loadStateDict(model: any, stateDict: StateDict, prefix?: string, visited?: WeakSet<object>): void;
      };
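
Note that `LayerNorm.forward` is narrowed to accept only a `Tensor`: raw arrays are no longer coerced (the matching `forceTensor` call is removed from nn.js below), so callers must convert first. A minimal sketch:

const ln = new nn.LayerNorm(8);
ln.forward(new Tensor(new Array(8).fill(0))); // OK: Tensor input
// ln.forward([0, 0, 0, 0, 0, 0, 0, 0]);      // no longer accepted
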
package/dist/nn.js CHANGED
@@ -20,7 +20,7 @@ class Linear {
          }
      }
      forward(input) {
-         input = core_1.Tensor.forceTensor(input);
+         input = this.weight.handleOther(input);
          return linearTransform(input, this.weight, this.bias);
      }
  }
@@ -49,8 +49,8 @@ class RNNCell {
          }
      }
      forward(input, hidden) {
-         input = core_1.Tensor.forceTensor(input);
-         hidden = core_1.Tensor.forceTensor(hidden);
+         input = this.weightIH.handleOther(input);
+         hidden = this.weightHH.handleOther(hidden);
          return rnnTransform(input, hidden, this.weightIH, this.weightHH, this.biasIH, this.biasHH).tanh();
      }
  }
@@ -85,8 +85,8 @@ class GRUCell {
          }
      }
      forward(input, hidden) {
-         input = core_1.Tensor.forceTensor(input);
-         hidden = core_1.Tensor.forceTensor(hidden);
+         input = this.weightIN.handleOther(input);
+         hidden = this.weightHN.handleOther(hidden);
          const r = rnnTransform(input, hidden, this.weightIR, this.weightHR, this.biasIR, this.biasHR).sigmoid();
          const z = rnnTransform(input, hidden, this.weightIZ, this.weightHZ, this.biasIZ, this.biasHZ).sigmoid();
          const n = linearTransform(input, this.weightIN, this.biasIN).add(r.mul(linearTransform(hidden, this.weightHN, this.biasHN))).tanh();
@@ -132,9 +132,9 @@ class LSTMCell {
          }
      }
      forward(input, hidden, cell) {
-         input = core_1.Tensor.forceTensor(input);
-         hidden = core_1.Tensor.forceTensor(hidden);
-         cell = core_1.Tensor.forceTensor(cell);
+         input = this.weightII.handleOther(input);
+         hidden = this.weightHI.handleOther(hidden);
+         cell = this.weightHI.handleOther(cell);
          const i = rnnTransform(input, hidden, this.weightII, this.weightHI, this.biasII, this.biasHI).sigmoid();
          const f = rnnTransform(input, hidden, this.weightIF, this.weightHF, this.biasIF, this.biasHF).sigmoid();
          const g = rnnTransform(input, hidden, this.weightIG, this.weightHG, this.biasIG, this.biasHG).tanh();
@@ -163,7 +163,6 @@ class LayerNorm {
          }
      }
      forward(input) {
-         input = core_1.Tensor.forceTensor(input);
          // Normalize over the specified dimensions
          const normalizedDims = this.normalizedShape.length;
          const startDim = input.shape.length - normalizedDims;
@@ -208,6 +207,12 @@ const state = {
          }
          return parameters;
      },
+     moveParameters(model, device) {
+         const params = state.getParameters(model);
+         for (const param of params) {
+             param.to_(device);
+         }
+     },
      getStateDict(model, prefix = "", visited = new WeakSet()) {
          if (visited.has(model))
              return {};
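
`moveParameters` is a thin loop over `getParameters` that calls the in-place `to_` on each parameter tensor, so a whole model can be relocated without rebuilding it. A sketch — the `Linear` constructor arguments and device name here are assumptions, not part of this diff:

const model = { encoder: new nn.Linear(4, 2), head: new nn.Linear(2, 1) };
nn.state.moveParameters(model, "mydevice"); // converts every parameter in place
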
package/dist/optim.d.ts CHANGED
@@ -38,9 +38,27 @@ declare class Adam extends BaseOptimizer {
      constructor(params: Tensor[], options?: AdamOptions);
      step(): void;
  }
+ export interface AdamWOptions {
+     lr?: number;
+     betas?: [number, number];
+     eps?: number;
+     weightDecay?: number;
+ }
+ declare class AdamW extends BaseOptimizer {
+     momentumBuffers: Map<Tensor, Tensor>;
+     velocityBuffers: Map<Tensor, Tensor>;
+     stepCount: number;
+     lr: number;
+     betas: [number, number];
+     eps: number;
+     weightDecay: number;
+     constructor(params: Tensor[], options?: AdamWOptions);
+     step(): void;
+ }
  export declare class Optim {
      static BaseOptimizer: typeof BaseOptimizer;
      static SGD: typeof SGD;
      static Adam: typeof Adam;
+     static AdamW: typeof AdamW;
  }
  export {};
package/dist/optim.js CHANGED
@@ -126,9 +126,68 @@ class Adam extends BaseOptimizer {
          }
      }
  }
+ class AdamW extends BaseOptimizer {
+     momentumBuffers = new Map(); // First moment (m_t)
+     velocityBuffers = new Map(); // Second moment (v_t)
+     stepCount = 0;
+     lr;
+     betas;
+     eps;
+     weightDecay;
+     constructor(params, options) {
+         super(params);
+         this.lr = options?.lr || 0.001;
+         this.betas = options?.betas || [0.9, 0.999];
+         this.eps = options?.eps || 1e-8;
+         this.weightDecay = options?.weightDecay || 0;
+     }
+     step() {
+         this.stepCount++;
+         const beta1 = this.betas[0];
+         const beta2 = this.betas[1];
+         // Bias correction factors
+         const biasCorrection1 = 1 - Math.pow(beta1, this.stepCount);
+         const biasCorrection2 = 1 - Math.pow(beta2, this.stepCount);
+         for (const param of this.params) {
+             if (!param.grad || !param.requiresGrad)
+                 continue;
+             let grad = param.grad.detach(), detachedParam = param.detach();
+             // Apply weight decay (L2 regularization)
+             detachedParam = detachedParam.sub(detachedParam.mul(this.weightDecay).mul(this.lr));
+             // Get or initialize first moment buffer (momentum)
+             let momentumBuffer = this.momentumBuffers.get(param);
+             if (!momentumBuffer) {
+                 momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                 this.momentumBuffers.set(param, momentumBuffer);
+             }
+             // Get or initialize second moment buffer (velocity)
+             let velocityBuffer = this.velocityBuffers.get(param);
+             if (!velocityBuffer) {
+                 velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                 this.velocityBuffers.set(param, velocityBuffer);
+             }
+             // Update biased first moment estimate: m_t = β1 * m_{t-1} + (1 - β1) * g_t
+             momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
+             this.momentumBuffers.set(param, momentumBuffer);
+             // Update biased second moment estimate: v_t = β2 * v_{t-1} + (1 - β2) * g_t^2
+             velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
+             this.velocityBuffers.set(param, velocityBuffer);
+             // Compute bias-corrected first moment: m̂_t = m_t / (1 - β1^t)
+             const correctedMomentum = momentumBuffer.div(biasCorrection1);
+             // Compute bias-corrected second moment: v̂_t = v_t / (1 - β2^t)
+             const correctedVelocity = velocityBuffer.div(biasCorrection2);
+             // Update parameters: θ_t = θ_t - α * m̂_t / (√v̂_t + ε)
+             const denom = correctedVelocity.sqrt().add(this.eps);
+             const stepSize = correctedMomentum.div(denom).mul(this.lr);
+             const newParam = detachedParam.sub(stepSize);
+             param.replace(newParam);
+         }
+     }
+ }
  class Optim {
      static BaseOptimizer = BaseOptimizer;
      static SGD = SGD;
      static Adam = Adam;
+     static AdamW = AdamW;
  }
  exports.Optim = Optim;
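
One detail worth flagging: despite the inline "L2 regularization" comment, the update shrinks the parameter directly by lr * weightDecay before the Adam step (θ ← θ − α·λ·θ) and leaves the gradient untouched, which is AdamW's decoupled weight decay rather than L2 regularization applied to the gradient. Construction mirrors `Adam`; a usage sketch, where `loss.backward()` stands in for an autograd entry point not shown in this diff:

const params = nn.state.getParameters(model);
const opt = new Optim.AdamW(params, { lr: 1e-3, betas: [0.9, 0.999], eps: 1e-8, weightDecay: 0.01 });
// loss.backward();
opt.step(); // applies the decoupled-decay update implemented above
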
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
    "name": "catniff",
-   "version": "0.5.8",
+   "version": "0.5.10",
    "description": "A small Torch-like deep learning framework for Javascript",
    "main": "index.js",
    "scripts": {