catniff 0.5.9 → 0.5.10
- package/dist/core.d.ts +2 -2
- package/dist/core.js +16 -11
- package/dist/nn.d.ts +2 -1
- package/dist/nn.js +14 -9
- package/dist/optim.d.ts +18 -0
- package/dist/optim.js +59 -0
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
@@ -38,7 +38,7 @@ export declare class Tensor {
     static elementWiseSelf(tA: Tensor, op: (tA: number) => number): Tensor;
     elementWiseABDAG(other: TensorValue | Tensor, op: (a: number, b: number) => number, thisGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor, otherGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor): Tensor;
     elementWiseSelfDAG(op: (a: number) => number, thisGrad?: (self: Tensor, outGrad: Tensor) => Tensor): Tensor;
-
+    handleOther(other: Tensor | TensorValue): Tensor;
     static addGrad(tensor: Tensor, accumGrad: Tensor): void;
     isContiguous(): boolean;
     contiguous(): Tensor;
@@ -174,7 +174,7 @@ export declare class Tensor {
     withGrad(requiresGrad: boolean): Tensor;
     detach(): Tensor;
     clone(): Tensor;
-    replace(other: Tensor, allowShapeMismatch?: boolean): Tensor;
+    replace(other: Tensor | TensorValue, allowShapeMismatch?: boolean): Tensor;
     static backends: Map<string, Backend>;
     to(device: string): Tensor;
     to_(device: string): Tensor;
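Note: the practical effect of the widened declarations is that replace() now accepts a plain TensorValue (number or nested array) as well as a Tensor, and the second-operand handling is exposed as handleOther(). A minimal sketch, assuming the package's main entry re-exports Tensor:

    import { Tensor } from "catniff";

    const w = new Tensor([[1, 2], [3, 4]]);
    // 0.5.10: a raw value is accepted and wrapped internally via handleOther()
    w.replace([[0, 0], [0, 0]]);
    const five = w.handleOther(5); // a scalar Tensor on the same device as w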
package/dist/core.js
CHANGED
@@ -183,7 +183,7 @@ class Tensor {
     }
     // Utility to do element-wise operation and build a dag node with another tensor
     elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) {
-        other =
+        other = this.handleOther(other);
         const out = Tensor.elementWiseAB(this, other, op);
         if (this.requiresGrad) {
             out.requiresGrad = true;
@@ -225,11 +225,15 @@ class Tensor {
         }
         return out;
     }
-    // Utility to
-
-        if (
-
-
+    // Utility to handle other tensor if an op needs a second operand
+    handleOther(other) {
+        if (other instanceof Tensor) {
+            if (this.device !== other.device) {
+                throw new Error("Can not operate on tensors that are not on the same device");
+            }
+            return other;
+        }
+        return new Tensor(other, { device: this.device });
     }
     // Utility to add to gradient of tensor
     static addGrad(tensor, accumGrad) {
@@ -1199,7 +1203,7 @@ class Tensor {
     }
     // 1D tensor dot product
     dot(other) {
-        other =
+        other = this.handleOther(other);
         // Verify 1D shape
         if (this.shape.length !== 1 || other.shape.length !== 1) {
             throw new Error("Inputs are not 1D tensors");
@@ -1237,7 +1241,7 @@ class Tensor {
     }
     // Matrix multiplication
     mm(other) {
-        other =
+        other = this.handleOther(other);
         // Verify 2D shape
         if (this.shape.length !== 2 || other.shape.length !== 2) {
             throw new Error("Inputs are not matrices");
@@ -1292,7 +1296,7 @@ class Tensor {
     }
     // Batched 3D tensor matmul
     bmm(other) {
-        other =
+        other = this.handleOther(other);
         // Verify 3D shape
         if (this.shape.length !== 3 || other.shape.length !== 3 || this.shape[0] !== other.shape[0]) {
             throw new Error("Inputs are not 3D tensors with the same first dim size");
@@ -1350,7 +1354,7 @@ class Tensor {
     }
     // Convert right-side 1D tensor to a vector (nx1 tensor) to do matmul
     mv(other) {
-        other =
+        other = this.handleOther(other);
         // Verify 2D shape
         if (this.shape.length !== 2 || other.shape.length !== 1) {
             throw new Error("Input is not a 2D and 1D tensor pair");
@@ -1359,7 +1363,7 @@ class Tensor {
     }
     // General matrix multiplication with different shapes
     matmul(other) {
-        other =
+        other = this.handleOther(other);
         const isThis1D = this.shape.length === 1;
         const isOther1D = other.shape.length === 1;
         if (isThis1D && isOther1D) {
@@ -1692,6 +1696,7 @@ class Tensor {
     }
     // Returns this tensor with value replaced with the value of another tensor
     replace(other, allowShapeMismatch = false) {
+        other = this.handleOther(other);
         // Verify shape
         if (!allowShapeMismatch) {
             for (let index = 0; index < this.shape.length; index++) {
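Note: handleOther() centralizes what the binary ops (elementWiseABDAG, dot, mm, bmm, mv, matmul, replace) do with their second operand: a non-Tensor value is wrapped into a Tensor on the receiver's device, while a Tensor on a different device now throws. A hedged sketch (the import path and the "gpu" device name are assumptions, not taken from this diff):

    import { Tensor } from "catniff";

    const a = new Tensor([1, 2, 3]);

    // Raw arrays are fine: they are wrapped onto a's device before the op
    const d = a.dot([4, 5, 6]);

    // But mixing devices is rejected by handleOther():
    // const b = new Tensor([4, 5, 6], { device: "gpu" });
    // a.dot(b); // throws "Can not operate on tensors that are not on the same device"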
package/dist/nn.d.ts
CHANGED
@@ -55,7 +55,7 @@ declare class LayerNorm {
     eps: number;
     normalizedShape: number[];
     constructor(normalizedShape: number | number[], eps?: number, elementwiseAffine?: boolean, bias?: boolean, device?: string);
-    forward(input: Tensor
+    forward(input: Tensor): Tensor;
 }
 export interface StateDict {
     [key: string]: any;
@@ -68,6 +68,7 @@ export declare const nn: {
     LayerNorm: typeof LayerNorm;
     state: {
         getParameters(model: any, visited?: WeakSet<object>): Tensor[];
+        moveParameters(model: any, device: string): void;
         getStateDict(model: any, prefix?: string, visited?: WeakSet<object>): StateDict;
         loadStateDict(model: any, stateDict: StateDict, prefix?: string, visited?: WeakSet<object>): void;
     };
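Note: the LayerNorm change narrows forward() to take a Tensor (the old declaration is truncated in this view, but the matching nn.js hunk below removes the forceTensor coercion), so callers should wrap raw values themselves. A small sketch, assuming Tensor and nn are re-exported from the package root:

    import { Tensor, nn } from "catniff";

    const ln = new nn.LayerNorm(4);
    // 0.5.10: pass a Tensor explicitly; raw arrays are no longer coerced here
    const out = ln.forward(new Tensor([1, 2, 3, 4]));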
package/dist/nn.js
CHANGED
@@ -20,7 +20,7 @@ class Linear {
         }
     }
     forward(input) {
-        input =
+        input = this.weight.handleOther(input);
         return linearTransform(input, this.weight, this.bias);
     }
 }
@@ -49,8 +49,8 @@ class RNNCell {
         }
     }
     forward(input, hidden) {
-        input =
-        hidden =
+        input = this.weightIH.handleOther(input);
+        hidden = this.weightHH.handleOther(hidden);
         return rnnTransform(input, hidden, this.weightIH, this.weightHH, this.biasIH, this.biasHH).tanh();
     }
 }
@@ -85,8 +85,8 @@ class GRUCell {
         }
     }
     forward(input, hidden) {
-        input =
-        hidden =
+        input = this.weightIN.handleOther(input);
+        hidden = this.weightHN.handleOther(hidden);
         const r = rnnTransform(input, hidden, this.weightIR, this.weightHR, this.biasIR, this.biasHR).sigmoid();
         const z = rnnTransform(input, hidden, this.weightIZ, this.weightHZ, this.biasIZ, this.biasHZ).sigmoid();
         const n = linearTransform(input, this.weightIN, this.biasIN).add(r.mul(linearTransform(hidden, this.weightHN, this.biasHN))).tanh();
@@ -132,9 +132,9 @@ class LSTMCell {
         }
     }
     forward(input, hidden, cell) {
-        input =
-        hidden =
-        cell =
+        input = this.weightII.handleOther(input);
+        hidden = this.weightHI.handleOther(hidden);
+        cell = this.weightHI.handleOther(cell);
         const i = rnnTransform(input, hidden, this.weightII, this.weightHI, this.biasII, this.biasHI).sigmoid();
         const f = rnnTransform(input, hidden, this.weightIF, this.weightHF, this.biasIF, this.biasHF).sigmoid();
         const g = rnnTransform(input, hidden, this.weightIG, this.weightHG, this.biasIG, this.biasHG).tanh();
@@ -163,7 +163,6 @@ class LayerNorm {
         }
     }
     forward(input) {
-        input = core_1.Tensor.forceTensor(input);
         // Normalize over the specified dimensions
         const normalizedDims = this.normalizedShape.length;
         const startDim = input.shape.length - normalizedDims;
@@ -208,6 +207,12 @@ const state = {
         }
         return parameters;
     },
+    moveParameters(model, device) {
+        const params = state.getParameters(model);
+        for (const param of params) {
+            param.to_(device);
+        }
+    },
     getStateDict(model, prefix = "", visited = new WeakSet()) {
         if (visited.has(model))
             return {};
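Note: nn.state.moveParameters() is a convenience wrapper that collects parameters with getParameters() and calls to_() on each one in place. A hedged usage sketch (the Linear constructor arguments and the "gpu" device name are assumptions; use whatever backend name is registered in Tensor.backends):

    import { nn } from "catniff";

    const model = {
        fc1: new nn.Linear(8, 4), // assuming the usual (inFeatures, outFeatures) signature
        fc2: new nn.Linear(4, 1),
    };

    // Moves every parameter reachable from the model object, in place
    nn.state.moveParameters(model, "gpu");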
package/dist/optim.d.ts
CHANGED
@@ -38,9 +38,27 @@ declare class Adam extends BaseOptimizer {
     constructor(params: Tensor[], options?: AdamOptions);
     step(): void;
 }
+export interface AdamWOptions {
+    lr?: number;
+    betas?: [number, number];
+    eps?: number;
+    weightDecay?: number;
+}
+declare class AdamW extends BaseOptimizer {
+    momentumBuffers: Map<Tensor, Tensor>;
+    velocityBuffers: Map<Tensor, Tensor>;
+    stepCount: number;
+    lr: number;
+    betas: [number, number];
+    eps: number;
+    weightDecay: number;
+    constructor(params: Tensor[], options?: AdamWOptions);
+    step(): void;
+}
 export declare class Optim {
     static BaseOptimizer: typeof BaseOptimizer;
     static SGD: typeof SGD;
     static Adam: typeof Adam;
+    static AdamW: typeof AdamW;
 }
 export {};
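Note: AdamW mirrors the existing Adam surface, with a separate weightDecay option (defaulting to 0 per the implementation below). A construction sketch, assuming Tensor and Optim come from the package root:

    import { Tensor, Optim } from "catniff";

    const w = new Tensor([[0.1, 0.2], [0.3, 0.4]]).withGrad(true);
    const opt = new Optim.AdamW([w], {
        lr: 3e-4,              // default 0.001
        betas: [0.9, 0.999],   // default
        eps: 1e-8,             // default
        weightDecay: 0.01,     // default 0
    });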
package/dist/optim.js
CHANGED
@@ -126,9 +126,68 @@ class Adam extends BaseOptimizer {
         }
     }
 }
+class AdamW extends BaseOptimizer {
+    momentumBuffers = new Map(); // First moment (m_t)
+    velocityBuffers = new Map(); // Second moment (v_t)
+    stepCount = 0;
+    lr;
+    betas;
+    eps;
+    weightDecay;
+    constructor(params, options) {
+        super(params);
+        this.lr = options?.lr || 0.001;
+        this.betas = options?.betas || [0.9, 0.999];
+        this.eps = options?.eps || 1e-8;
+        this.weightDecay = options?.weightDecay || 0;
+    }
+    step() {
+        this.stepCount++;
+        const beta1 = this.betas[0];
+        const beta2 = this.betas[1];
+        // Bias correction factors
+        const biasCorrection1 = 1 - Math.pow(beta1, this.stepCount);
+        const biasCorrection2 = 1 - Math.pow(beta2, this.stepCount);
+        for (const param of this.params) {
+            if (!param.grad || !param.requiresGrad)
+                continue;
+            let grad = param.grad.detach(), detachedParam = param.detach();
+            // Apply weight decay (L2 regularization)
+            detachedParam = detachedParam.sub(detachedParam.mul(this.weightDecay).mul(this.lr));
+            // Get or initialize first moment buffer (momentum)
+            let momentumBuffer = this.momentumBuffers.get(param);
+            if (!momentumBuffer) {
+                momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                this.momentumBuffers.set(param, momentumBuffer);
+            }
+            // Get or initialize second moment buffer (velocity)
+            let velocityBuffer = this.velocityBuffers.get(param);
+            if (!velocityBuffer) {
+                velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                this.velocityBuffers.set(param, velocityBuffer);
+            }
+            // Update biased first moment estimate: m_t = β1 * m_{t-1} + (1 - β1) * g_t
+            momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
+            this.momentumBuffers.set(param, momentumBuffer);
+            // Update biased second moment estimate: v_t = β2 * v_{t-1} + (1 - β2) * g_t^2
+            velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
+            this.velocityBuffers.set(param, velocityBuffer);
+            // Compute bias-corrected first moment: m̂_t = m_t / (1 - β1^t)
+            const correctedMomentum = momentumBuffer.div(biasCorrection1);
+            // Compute bias-corrected second moment: v̂_t = v_t / (1 - β2^t)
+            const correctedVelocity = velocityBuffer.div(biasCorrection2);
+            // Update parameters: θ_t = θ_t - α * m̂_t / (√v̂_t + ε)
+            const denom = correctedVelocity.sqrt().add(this.eps);
+            const stepSize = correctedMomentum.div(denom).mul(this.lr);
+            const newParam = detachedParam.sub(stepSize);
+            param.replace(newParam);
+        }
+    }
+}
 class Optim {
     static BaseOptimizer = BaseOptimizer;
     static SGD = SGD;
     static Adam = Adam;
+    static AdamW = AdamW;
 }
 exports.Optim = Optim;
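Note: each step() call skips parameters without gradients, applies the weight-decay shrink to the detached parameter, runs the bias-corrected Adam update, and writes the result back through the (now TensorValue-friendly) replace(). A minimal sketch; how gradients get populated (the autograd entry point) is outside this diff and assumed:

    import { Tensor, Optim } from "catniff";

    const w = new Tensor([1, 2, 3]).withGrad(true);
    const opt = new Optim.AdamW([w], { lr: 0.01, weightDecay: 0.01 });

    // ...forward pass + backprop so that w.grad is populated (not shown in this diff)...
    opt.step(); // updates w in place via param.replace(newParam)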