catniff 0.5.8 → 0.5.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/backend.d.ts +1 -0
- package/dist/core.d.ts +3 -2
- package/dist/core.js +31 -19
- package/dist/nn.d.ts +2 -1
- package/dist/nn.js +14 -9
- package/dist/optim.d.ts +18 -0
- package/dist/optim.js +59 -0
- package/package.json +1 -1
package/dist/backend.d.ts
CHANGED
package/dist/core.d.ts
CHANGED

@@ -38,7 +38,7 @@ export declare class Tensor {
     static elementWiseSelf(tA: Tensor, op: (tA: number) => number): Tensor;
     elementWiseABDAG(other: TensorValue | Tensor, op: (a: number, b: number) => number, thisGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor, otherGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor): Tensor;
     elementWiseSelfDAG(op: (a: number) => number, thisGrad?: (self: Tensor, outGrad: Tensor) => Tensor): Tensor;
-
+    handleOther(other: Tensor | TensorValue): Tensor;
     static addGrad(tensor: Tensor, accumGrad: Tensor): void;
     isContiguous(): boolean;
     contiguous(): Tensor;

@@ -174,7 +174,8 @@ export declare class Tensor {
     withGrad(requiresGrad: boolean): Tensor;
     detach(): Tensor;
     clone(): Tensor;
-    replace(other: Tensor, allowShapeMismatch?: boolean): Tensor;
+    replace(other: Tensor | TensorValue, allowShapeMismatch?: boolean): Tensor;
     static backends: Map<string, Backend>;
     to(device: string): Tensor;
+    to_(device: string): Tensor;
 }
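Taken together, the core.d.ts changes surface three things on `Tensor`: a new `handleOther` helper, a `replace` that now also accepts a raw `TensorValue`, and an in-place transfer op `to_` alongside `to`. A minimal usage sketch of the widened surface (assuming `Tensor` is re-exported from the package root; the concrete values are illustrative):

    import { Tensor } from "catniff";

    const a = new Tensor([1, 2, 3]);
    a.replace([4, 5, 6]);  // 0.5.10: raw TensorValue operands are accepted and wrapped
    const b = a.to("cpu"); // out-of-place transfer, returns a tensor on the target device
    a.to_("cpu");          // new in-place transfer, returns `a` itself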
package/dist/core.js
CHANGED

@@ -21,13 +21,8 @@ class Tensor {
         this.gradFn = options.gradFn || (() => { });
         this.children = options.children || [];
         this.device = options.device || "cpu";
-        // Move
-
-            const backend = Tensor.backends.get(this.device);
-            if (backend && backend.transfer) {
-                backend.transfer(this);
-            }
-        }
+        // Move to device in-place
+        this.to_(this.device);
     }
     // Utility to flatten an nD array to be 1D
     static flatten(tensor) {

@@ -188,7 +183,7 @@ class Tensor {
     }
     // Utility to do element-wise operation and build a dag node with another tensor
     elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) {
-        other = Tensor.forceTensor(other);
+        other = this.handleOther(other);
         const out = Tensor.elementWiseAB(this, other, op);
         if (this.requiresGrad) {
             out.requiresGrad = true;

@@ -230,11 +225,15 @@ class Tensor {
         }
         return out;
     }
-    // Utility to
-
-    if (
-
-
+    // Utility to handle other tensor if an op needs a second operand
+    handleOther(other) {
+        if (other instanceof Tensor) {
+            if (this.device !== other.device) {
+                throw new Error("Can not operate on tensors that are not on the same device");
+            }
+            return other;
+        }
+        return new Tensor(other, { device: this.device });
     }
     // Utility to add to gradient of tensor
     static addGrad(tensor, accumGrad) {

@@ -1204,7 +1203,7 @@ class Tensor {
     }
     // 1D tensor dot product
     dot(other) {
-        other = Tensor.forceTensor(other);
+        other = this.handleOther(other);
         // Verify 1D shape
         if (this.shape.length !== 1 || other.shape.length !== 1) {
             throw new Error("Inputs are not 1D tensors");

@@ -1242,7 +1241,7 @@ class Tensor {
     }
     // Matrix multiplication
     mm(other) {
-        other = Tensor.forceTensor(other);
+        other = this.handleOther(other);
         // Verify 2D shape
         if (this.shape.length !== 2 || other.shape.length !== 2) {
             throw new Error("Inputs are not matrices");

@@ -1297,7 +1296,7 @@ class Tensor {
     }
     // Batched 3D tensor matmul
     bmm(other) {
-        other = Tensor.forceTensor(other);
+        other = this.handleOther(other);
         // Verify 3D shape
         if (this.shape.length !== 3 || other.shape.length !== 3 || this.shape[0] !== other.shape[0]) {
             throw new Error("Inputs are not 3D tensors with the same first dim size");

@@ -1355,7 +1354,7 @@ class Tensor {
     }
     // Convert right-side 1D tensor to a vector (nx1 tensor) to do matmul
     mv(other) {
-        other = Tensor.forceTensor(other);
+        other = this.handleOther(other);
         // Verify 2D shape
         if (this.shape.length !== 2 || other.shape.length !== 1) {
             throw new Error("Input is not a 2D and 1D tensor pair");

@@ -1364,7 +1363,7 @@ class Tensor {
     }
     // General matrix multiplication with different shapes
     matmul(other) {
-        other = Tensor.forceTensor(other);
+        other = this.handleOther(other);
         const isThis1D = this.shape.length === 1;
         const isOther1D = other.shape.length === 1;
         if (isThis1D && isOther1D) {

@@ -1697,6 +1696,7 @@ class Tensor {
     }
     // Returns this tensor with value replaced with the value of another tensor
     replace(other, allowShapeMismatch = false) {
+        other = this.handleOther(other);
         // Verify shape
         if (!allowShapeMismatch) {
             for (let index = 0; index < this.shape.length; index++) {

@@ -1712,9 +1712,21 @@ class Tensor {
     static backends = new Map();
     // Op to transfer tensor to another device
     to(device) {
+        if (device === "cpu")
+            return this;
         const backend = Tensor.backends.get(device);
         if (backend && backend.transfer) {
-            backend.transfer(this);
+            return backend.transfer(this);
+        }
+        throw new Error(`No device found to transfer tensor to or a handler is not implemented for device.`);
+    }
+    // Op to transfer tensor to another device in-place
+    to_(device) {
+        if (device === "cpu")
+            return this;
+        const backend = Tensor.backends.get(this.device);
+        if (backend && backend.create) {
+            backend.create(this);
             return this;
         }
         throw new Error(`No device found to transfer tensor to or a handler is not implemented for device.`);
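The core.js side shows the semantics behind those declarations: every binary op now funnels its second operand through `handleOther`, which wraps a raw value into a new `Tensor` on the caller's device but refuses to mix tensors that live on different devices, and the constructor now reuses `to_` for its initial device placement. A behavioral sketch (the `"webgpu"` device name is illustrative and only works if such a backend has been registered in `Tensor.backends`):

    import { Tensor } from "catniff";

    const a = new Tensor([[1, 2], [3, 4]]);
    a.mm([[5, 6], [7, 8]]); // OK: the raw array is wrapped on a's device ("cpu")

    // With a hypothetical registered backend:
    // const g = a.to("webgpu");
    // a.mm(g); // throws: "Can not operate on tensors that are not on the same device"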
package/dist/nn.d.ts
CHANGED

@@ -55,7 +55,7 @@ declare class LayerNorm {
    eps: number;
    normalizedShape: number[];
    constructor(normalizedShape: number | number[], eps?: number, elementwiseAffine?: boolean, bias?: boolean, device?: string);
-    forward(input: Tensor | TensorValue): Tensor;
+    forward(input: Tensor): Tensor;
 }
 export interface StateDict {
     [key: string]: any;

@@ -68,6 +68,7 @@ export declare const nn: {
     LayerNorm: typeof LayerNorm;
     state: {
         getParameters(model: any, visited?: WeakSet<object>): Tensor[];
+        moveParameters(model: any, device: string): void;
         getStateDict(model: any, prefix?: string, visited?: WeakSet<object>): StateDict;
         loadStateDict(model: any, stateDict: StateDict, prefix?: string, visited?: WeakSet<object>): void;
     };
package/dist/nn.js
CHANGED

@@ -20,7 +20,7 @@ class Linear {
         }
     }
     forward(input) {
-        input = core_1.Tensor.forceTensor(input);
+        input = this.weight.handleOther(input);
         return linearTransform(input, this.weight, this.bias);
     }
 }

@@ -49,8 +49,8 @@ class RNNCell {
         }
     }
     forward(input, hidden) {
-        input = core_1.Tensor.forceTensor(input);
-        hidden = core_1.Tensor.forceTensor(hidden);
+        input = this.weightIH.handleOther(input);
+        hidden = this.weightHH.handleOther(hidden);
         return rnnTransform(input, hidden, this.weightIH, this.weightHH, this.biasIH, this.biasHH).tanh();
     }
 }

@@ -85,8 +85,8 @@ class GRUCell {
         }
     }
     forward(input, hidden) {
-        input = core_1.Tensor.forceTensor(input);
-        hidden = core_1.Tensor.forceTensor(hidden);
+        input = this.weightIN.handleOther(input);
+        hidden = this.weightHN.handleOther(hidden);
         const r = rnnTransform(input, hidden, this.weightIR, this.weightHR, this.biasIR, this.biasHR).sigmoid();
         const z = rnnTransform(input, hidden, this.weightIZ, this.weightHZ, this.biasIZ, this.biasHZ).sigmoid();
         const n = linearTransform(input, this.weightIN, this.biasIN).add(r.mul(linearTransform(hidden, this.weightHN, this.biasHN))).tanh();

@@ -132,9 +132,9 @@ class LSTMCell {
         }
     }
     forward(input, hidden, cell) {
-        input = core_1.Tensor.forceTensor(input);
-        hidden = core_1.Tensor.forceTensor(hidden);
-        cell = core_1.Tensor.forceTensor(cell);
+        input = this.weightII.handleOther(input);
+        hidden = this.weightHI.handleOther(hidden);
+        cell = this.weightHI.handleOther(cell);
         const i = rnnTransform(input, hidden, this.weightII, this.weightHI, this.biasII, this.biasHI).sigmoid();
         const f = rnnTransform(input, hidden, this.weightIF, this.weightHF, this.biasIF, this.biasHF).sigmoid();
         const g = rnnTransform(input, hidden, this.weightIG, this.weightHG, this.biasIG, this.biasHG).tanh();

@@ -163,7 +163,6 @@ class LayerNorm {
         }
     }
     forward(input) {
-        input = core_1.Tensor.forceTensor(input);
         // Normalize over the specified dimensions
         const normalizedDims = this.normalizedShape.length;
         const startDim = input.shape.length - normalizedDims;

@@ -208,6 +207,12 @@ const state = {
         }
         return parameters;
     },
+    moveParameters(model, device) {
+        const params = state.getParameters(model);
+        for (const param of params) {
+            param.to_(device);
+        }
+    },
     getStateDict(model, prefix = "", visited = new WeakSet()) {
         if (visited.has(model))
             return {};
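The new `nn.state.moveParameters` is a thin loop over `getParameters` that calls the in-place `to_` on each parameter, so a whole model can be migrated without rebuilding its layers. A usage sketch (the model-as-plain-object layout is how `getParameters` walks models; the `nn.Linear` constructor arguments and layer sizes are assumptions for illustration):

    import { nn } from "catniff";

    const model = {
        l1: new nn.Linear(4, 8),
        l2: new nn.Linear(8, 1),
    };
    nn.state.moveParameters(model, "cpu"); // moves every collected parameter in place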
package/dist/optim.d.ts
CHANGED

@@ -38,9 +38,27 @@ declare class Adam extends BaseOptimizer {
     constructor(params: Tensor[], options?: AdamOptions);
     step(): void;
 }
+export interface AdamWOptions {
+    lr?: number;
+    betas?: [number, number];
+    eps?: number;
+    weightDecay?: number;
+}
+declare class AdamW extends BaseOptimizer {
+    momentumBuffers: Map<Tensor, Tensor>;
+    velocityBuffers: Map<Tensor, Tensor>;
+    stepCount: number;
+    lr: number;
+    betas: [number, number];
+    eps: number;
+    weightDecay: number;
+    constructor(params: Tensor[], options?: AdamWOptions);
+    step(): void;
+}
 export declare class Optim {
     static BaseOptimizer: typeof BaseOptimizer;
     static SGD: typeof SGD;
     static Adam: typeof Adam;
+    static AdamW: typeof AdamW;
 }
 export {};
package/dist/optim.js
CHANGED

@@ -126,9 +126,68 @@ class Adam extends BaseOptimizer {
         }
     }
 }
+class AdamW extends BaseOptimizer {
+    momentumBuffers = new Map(); // First moment (m_t)
+    velocityBuffers = new Map(); // Second moment (v_t)
+    stepCount = 0;
+    lr;
+    betas;
+    eps;
+    weightDecay;
+    constructor(params, options) {
+        super(params);
+        this.lr = options?.lr || 0.001;
+        this.betas = options?.betas || [0.9, 0.999];
+        this.eps = options?.eps || 1e-8;
+        this.weightDecay = options?.weightDecay || 0;
+    }
+    step() {
+        this.stepCount++;
+        const beta1 = this.betas[0];
+        const beta2 = this.betas[1];
+        // Bias correction factors
+        const biasCorrection1 = 1 - Math.pow(beta1, this.stepCount);
+        const biasCorrection2 = 1 - Math.pow(beta2, this.stepCount);
+        for (const param of this.params) {
+            if (!param.grad || !param.requiresGrad)
+                continue;
+            let grad = param.grad.detach(), detachedParam = param.detach();
+            // Apply weight decay (L2 regularization)
+            detachedParam = detachedParam.sub(detachedParam.mul(this.weightDecay).mul(this.lr));
+            // Get or initialize first moment buffer (momentum)
+            let momentumBuffer = this.momentumBuffers.get(param);
+            if (!momentumBuffer) {
+                momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                this.momentumBuffers.set(param, momentumBuffer);
+            }
+            // Get or initialize second moment buffer (velocity)
+            let velocityBuffer = this.velocityBuffers.get(param);
+            if (!velocityBuffer) {
+                velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
+                this.velocityBuffers.set(param, velocityBuffer);
+            }
+            // Update biased first moment estimate: m_t = β1 * m_{t-1} + (1 - β1) * g_t
+            momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
+            this.momentumBuffers.set(param, momentumBuffer);
+            // Update biased second moment estimate: v_t = β2 * v_{t-1} + (1 - β2) * g_t^2
+            velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
+            this.velocityBuffers.set(param, velocityBuffer);
+            // Compute bias-corrected first moment: m̂_t = m_t / (1 - β1^t)
+            const correctedMomentum = momentumBuffer.div(biasCorrection1);
+            // Compute bias-corrected second moment: v̂_t = v_t / (1 - β2^t)
+            const correctedVelocity = velocityBuffer.div(biasCorrection2);
+            // Update parameters: θ_t = θ_t - α * m̂_t / (√v̂_t + ε)
+            const denom = correctedVelocity.sqrt().add(this.eps);
+            const stepSize = correctedMomentum.div(denom).mul(this.lr);
+            const newParam = detachedParam.sub(stepSize);
+            param.replace(newParam);
+        }
+    }
+}
 class Optim {
     static BaseOptimizer = BaseOptimizer;
     static SGD = SGD;
     static Adam = Adam;
+    static AdamW = AdamW;
 }
 exports.Optim = Optim;
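The new AdamW mirrors Adam's moment buffers but applies weight decay directly to the detached parameter before the gradient-based update, i.e. decay is decoupled from the adaptive step rather than folded into the gradient. A training-loop sketch (the model, the loss computation, and the call that populates `param.grad` are assumed and elided; the option values shown are the constructor's defaults except `weightDecay`):

    import { nn, Optim } from "catniff";

    const model = { layer: new nn.Linear(4, 2) }; // layer sizes illustrative
    const opt = new Optim.AdamW(nn.state.getParameters(model), {
        lr: 0.001,
        betas: [0.9, 0.999],
        eps: 1e-8,
        weightDecay: 0.01, // default is 0; nonzero enables decoupled decay
    });

    // ...compute a loss and backpropagate so each param.grad is populated...
    opt.step(); // one AdamW update over all registered parameters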