catniff 0.6.14 → 0.7.0
- package/dist/lrscheduler.d.ts +13 -0
- package/dist/lrscheduler.js +31 -0
- package/dist/nn.d.ts +15 -8
- package/dist/nn.js +46 -1
- package/dist/optim.d.ts +9 -9
- package/dist/optim.js +11 -11
- package/index.d.ts +1 -0
- package/index.js +2 -1
- package/package.json +1 -1

package/dist/lrscheduler.d.ts
ADDED
@@ -0,0 +1,13 @@
+import { BaseOptimizer } from "./optim";
+export declare class StepLR {
+    optimizer: BaseOptimizer;
+    stepSize: number;
+    gamma: number;
+    lastEpoch: number;
+    baseLR: number;
+    constructor(optimizer: BaseOptimizer, stepSize: number, gamma?: number, lastEpoch?: number);
+    step(epoch?: number): void;
+}
+export declare const LRScheduler: {
+    StepLR: typeof StepLR;
+};

package/dist/lrscheduler.js
ADDED
@@ -0,0 +1,31 @@
+"use strict";
+Object.defineProperty(exports, "__esModule", { value: true });
+exports.LRScheduler = exports.StepLR = void 0;
+class StepLR {
+    optimizer;
+    stepSize;
+    gamma;
+    lastEpoch;
+    baseLR;
+    constructor(optimizer, stepSize, gamma = 0.1, lastEpoch = -1) {
+        this.optimizer = optimizer;
+        this.stepSize = stepSize;
+        this.gamma = gamma;
+        this.lastEpoch = lastEpoch;
+        this.baseLR = this.optimizer.lr;
+    }
+    step(epoch) {
+        if (typeof epoch === "undefined") {
+            this.lastEpoch++;
+            epoch = this.lastEpoch;
+        }
+        else {
+            this.lastEpoch = epoch;
+        }
+        this.optimizer.lr = this.baseLR * this.gamma ** Math.floor(epoch / this.stepSize);
+    }
+}
+exports.StepLR = StepLR;
+exports.LRScheduler = {
+    StepLR
+};
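
The new StepLR scheduler decays the wrapped optimizer's learning rate by a factor of gamma every stepSize epochs: lr = baseLR * gamma ** floor(epoch / stepSize). A minimal usage sketch; it assumes the package root re-exports Optim, LRScheduler, and Tensor (index.js also changed in this release), and the empty params array stands in for parameters gathered from a model:

    import { Optim, LRScheduler } from "catniff";
    import type { Tensor } from "catniff";

    const params: Tensor[] = [];                        // e.g. nn.state.getParameters(model)
    const opt = new Optim.Adam(params, { lr: 0.01 });
    const sched = new LRScheduler.StepLR(opt, 10, 0.5); // stepSize = 10, gamma = 0.5

    for (let epoch = 0; epoch < 30; epoch++) {
        // ...forward, backward, opt.step()...
        sched.step(); // after the step at epoch e: opt.lr === 0.01 * 0.5 ** Math.floor(e / 10)
    }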

package/dist/nn.d.ts
CHANGED
@@ -1,11 +1,11 @@
 import { Tensor, TensorValue } from "./core";
-declare class Linear {
+export declare class Linear {
     weight: Tensor;
     bias?: Tensor;
     constructor(inFeatures: number, outFeatures: number, bias?: boolean, device?: string);
     forward(input: Tensor | TensorValue): Tensor;
 }
-declare class RNNCell {
+export declare class RNNCell {
     weightIH: Tensor;
     weightHH: Tensor;
     biasIH?: Tensor;
@@ -13,7 +13,7 @@ declare class RNNCell {
     constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
     forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
 }
-declare class GRUCell {
+export declare class GRUCell {
     weightIR: Tensor;
     weightIZ: Tensor;
     weightIN: Tensor;
@@ -29,7 +29,7 @@ declare class GRUCell {
     constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
     forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
 }
-declare class LSTMCell {
+export declare class LSTMCell {
     weightII: Tensor;
     weightIF: Tensor;
     weightIG: Tensor;
@@ -49,7 +49,7 @@ declare class LSTMCell {
     constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
     forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue, cell: Tensor | TensorValue): [Tensor, Tensor];
 }
-declare class LayerNorm {
+export declare class LayerNorm {
     weight?: Tensor;
     bias?: Tensor;
     eps: number;
@@ -57,12 +57,19 @@ declare class LayerNorm {
     constructor(normalizedShape: number | number[], eps?: number, elementwiseAffine?: boolean, bias?: boolean, device?: string);
     forward(input: Tensor): Tensor;
 }
-declare class Embedding {
+export declare class RMSNorm {
+    weight?: Tensor;
+    eps: number;
+    normalizedShape: number[];
+    constructor(normalizedShape: number | number[], eps?: number, elementwiseAffine?: boolean, device?: string);
+    forward(input: Tensor): Tensor;
+}
+export declare class Embedding {
     weight: Tensor;
     constructor(numEmbeddings: number, embeddingDim: number, device: string);
     forward(input: Tensor | TensorValue): Tensor;
 }
-declare class MultiheadAttention {
+export declare class MultiheadAttention {
     qProjection: Linear;
     kProjection: Linear;
     vProjection: Linear;
@@ -83,6 +90,7 @@ export declare const nn: {
     GRUCell: typeof GRUCell;
     LSTMCell: typeof LSTMCell;
     LayerNorm: typeof LayerNorm;
+    RMSNorm: typeof RMSNorm;
     Embedding: typeof Embedding;
     MultiheadAttention: typeof MultiheadAttention;
     state: {
@@ -92,4 +100,3 @@ export declare const nn: {
         loadStateDict(model: any, stateDict: StateDict, prefix?: string, visited?: WeakSet<object>): void;
     };
 };
-export {};
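
The declaration changes make every layer class a named export usable as both a value and a type (previously they were module-private declare class statements, and the trailing export {} kept the file's named surface down to nn alone), and RMSNorm joins the declared API. A small sketch of what this enables; the deep-import path is an assumption:

    import { nn } from "catniff";                       // works in 0.6.x and 0.7.0
    import { Linear, RMSNorm } from "catniff/dist/nn";  // named imports, new in 0.7.0

    const layer: Linear = new nn.Linear(8, 4);          // Linear is now usable as a type annotation
    const norm = new RMSNorm(4);                        // eps, elementwiseAffine, device use defaults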

package/dist/nn.js
CHANGED
@@ -1,6 +1,6 @@
 "use strict";
 Object.defineProperty(exports, "__esModule", { value: true });
-exports.nn = void 0;
+exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Linear = void 0;
 const core_1 = require("./core");
 function linearTransform(input, weight, bias) {
     let output = input.matmul(weight.t());
@@ -24,6 +24,7 @@ class Linear {
         return linearTransform(input, this.weight, this.bias);
     }
 }
+exports.Linear = Linear;
 function rnnTransform(input, hidden, inputWeight, hiddenWeight, inputBias, hiddenBias) {
     let output = input.matmul(inputWeight.t()).add(hidden.matmul(hiddenWeight.t()));
     if (inputBias) {
@@ -54,6 +55,7 @@ class RNNCell {
         return rnnTransform(input, hidden, this.weightIH, this.weightHH, this.biasIH, this.biasHH).tanh();
     }
 }
+exports.RNNCell = RNNCell;
 class GRUCell {
     weightIR;
     weightIZ;
@@ -93,6 +95,7 @@ class GRUCell {
         return (z.neg().add(1).mul(n).add(z.mul(hidden)));
     }
 }
+exports.GRUCell = GRUCell;
 class LSTMCell {
     weightII;
     weightIF;
@@ -144,6 +147,7 @@ class LSTMCell {
         return [h, c];
     }
 }
+exports.LSTMCell = LSTMCell;
 class LayerNorm {
     weight;
     bias;
@@ -188,6 +192,44 @@ class LayerNorm {
         return normalized;
     }
 }
+exports.LayerNorm = LayerNorm;
+class RMSNorm {
+    weight;
+    eps;
+    normalizedShape;
+    constructor(normalizedShape, eps = 1e-5, elementwiseAffine = true, device) {
+        this.eps = eps;
+        this.normalizedShape = Array.isArray(normalizedShape) ? normalizedShape : [normalizedShape];
+        if (this.normalizedShape.length === 0) {
+            throw new Error("Normalized shape cannot be empty");
+        }
+        if (elementwiseAffine) {
+            this.weight = core_1.Tensor.ones(this.normalizedShape, { requiresGrad: true, device });
+        }
+    }
+    forward(input) {
+        // Normalize over the specified dimensions
+        const normalizedDims = this.normalizedShape.length;
+        const startDim = input.shape.length - normalizedDims;
+        if (startDim < 0) {
+            throw new Error("Input does not have enough dims to normalize");
+        }
+        const dims = [];
+        for (let i = 0; i < normalizedDims; i++) {
+            if (input.shape[startDim + i] !== this.normalizedShape[i]) {
+                throw new Error(`Shape mismatch at dim ${startDim + i}: expected ${this.normalizedShape[i]}, got ${input.shape[startDim + i]}`);
+            }
+            dims.push(startDim + i);
+        }
+        let rms = input.square().mean(dims, true).add(this.eps).sqrt();
+        let normalized = input.div(rms);
+        if (this.weight) {
+            normalized = normalized.mul(this.weight);
+        }
+        return normalized;
+    }
+}
+exports.RMSNorm = RMSNorm;
 class Embedding {
     weight;
     constructor(numEmbeddings, embeddingDim, device) {
@@ -197,6 +239,7 @@ class Embedding {
         return this.weight.index(input);
     }
 }
+exports.Embedding = Embedding;
 class MultiheadAttention {
     qProjection;
     kProjection;
@@ -248,6 +291,7 @@ class MultiheadAttention {
         return [output, needWeights ? attnWeights : undefined];
     }
 }
+exports.MultiheadAttention = MultiheadAttention;
 const state = {
     getParameters(model, visited = new WeakSet()) {
         if (visited.has(model))
@@ -316,6 +360,7 @@ exports.nn = {
     GRUCell,
     LSTMCell,
     LayerNorm,
+    RMSNorm,
     Embedding,
     MultiheadAttention,
     state
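
RMSNorm implements root-mean-square normalization over the trailing normalizedShape dimensions: y = x / sqrt(mean(x^2) + eps), optionally scaled elementwise by a learnable weight initialized to ones. Unlike LayerNorm it neither subtracts the mean nor carries a bias term. A minimal sketch, assuming Tensor is re-exported from the package root (Tensor.ones is the same call nn.js itself uses):

    import { nn, Tensor } from "catniff";

    const norm = new nn.RMSNorm([4]);   // normalize over a trailing dimension of size 4
    const x = Tensor.ones([2, 4]);      // shape [batch, features]
    const y = norm.forward(x);          // divides by sqrt(mean(x ** 2) + 1e-5), then scales by weight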

package/dist/optim.d.ts
CHANGED
@@ -1,7 +1,11 @@
 import { Tensor } from "./core";
-declare abstract class BaseOptimizer {
+export interface BaseOptimizerOptions {
+    lr?: number;
+}
+export declare abstract class BaseOptimizer {
     params: Tensor[];
-    constructor(params: Tensor[]);
+    lr: number;
+    constructor(params: Tensor[], options?: BaseOptimizerOptions);
     zeroGrad(): void;
 }
 export interface SGDOptions {
@@ -11,9 +15,8 @@ export interface SGDOptions {
     weightDecay?: number;
     nesterov?: boolean;
 }
-declare class SGD extends BaseOptimizer {
+export declare class SGD extends BaseOptimizer {
     momentumBuffers: Map<Tensor, Tensor>;
-    lr: number;
     momentum: number;
     dampening: number;
     weightDecay: number;
@@ -27,11 +30,10 @@ export interface AdamOptions {
     eps?: number;
     weightDecay?: number;
 }
-declare class Adam extends BaseOptimizer {
+export declare class Adam extends BaseOptimizer {
     momentumBuffers: Map<Tensor, Tensor>;
     velocityBuffers: Map<Tensor, Tensor>;
     stepCount: number;
-    lr: number;
     betas: [number, number];
     eps: number;
     weightDecay: number;
@@ -44,11 +46,10 @@ export interface AdamWOptions {
     eps?: number;
     weightDecay?: number;
 }
-declare class AdamW extends BaseOptimizer {
+export declare class AdamW extends BaseOptimizer {
     momentumBuffers: Map<Tensor, Tensor>;
     velocityBuffers: Map<Tensor, Tensor>;
     stepCount: number;
-    lr: number;
     betas: [number, number];
     eps: number;
     weightDecay: number;
@@ -61,4 +62,3 @@ export declare const Optim: {
     Adam: typeof Adam;
     AdamW: typeof AdamW;
 };
-export {};
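
The lr field and its option now live on the abstract base class via the new BaseOptimizerOptions interface instead of being redeclared by SGD, Adam, and AdamW individually; a uniform lr property on BaseOptimizer is exactly the contract StepLR depends on. A small typing sketch; the deep-import path and the decay helper are illustrative, not part of the package:

    import type { BaseOptimizer } from "catniff/dist/optim";

    // Any optimizer, built-in or custom, can be driven through the base type.
    function decay(opt: BaseOptimizer, factor: number): void {
        opt.lr *= factor;
    }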

package/dist/optim.js
CHANGED
@@ -1,11 +1,13 @@
 "use strict";
 Object.defineProperty(exports, "__esModule", { value: true });
-exports.Optim = void 0;
+exports.Optim = exports.AdamW = exports.Adam = exports.SGD = exports.BaseOptimizer = void 0;
 const core_1 = require("./core");
 class BaseOptimizer {
     params;
-    constructor(params) {
+    lr;
+    constructor(params, options) {
         this.params = params;
+        this.lr = options?.lr || 0.001;
     }
     zeroGrad() {
         for (let index = 0; index < this.params.length; index++) {
@@ -14,16 +16,15 @@ class BaseOptimizer {
         }
     }
 }
+exports.BaseOptimizer = BaseOptimizer;
 class SGD extends BaseOptimizer {
     momentumBuffers = new Map();
-    lr;
     momentum;
     dampening;
     weightDecay;
     nesterov;
     constructor(params, options) {
-        super(params);
-        this.lr = options?.lr || 0.001;
+        super(params, options);
         this.momentum = options?.momentum || 0;
         this.dampening = options?.dampening || 0;
         this.weightDecay = options?.weightDecay || 0;
@@ -66,17 +67,16 @@ class SGD extends BaseOptimizer {
         }
     }
 }
+exports.SGD = SGD;
 class Adam extends BaseOptimizer {
     momentumBuffers = new Map(); // First moment (m_t)
     velocityBuffers = new Map(); // Second moment (v_t)
     stepCount = 0;
-    lr;
     betas;
     eps;
     weightDecay;
     constructor(params, options) {
-        super(params);
-        this.lr = options?.lr || 0.001;
+        super(params, options);
         this.betas = options?.betas || [0.9, 0.999];
         this.eps = options?.eps || 1e-8;
         this.weightDecay = options?.weightDecay || 0;
@@ -126,17 +126,16 @@ class Adam extends BaseOptimizer {
         }
     }
 }
+exports.Adam = Adam;
 class AdamW extends BaseOptimizer {
     momentumBuffers = new Map(); // First moment (m_t)
     velocityBuffers = new Map(); // Second moment (v_t)
     stepCount = 0;
-    lr;
     betas;
     eps;
     weightDecay;
     constructor(params, options) {
-        super(params);
-        this.lr = options?.lr || 0.001;
+        super(params, options);
         this.betas = options?.betas || [0.9, 0.999];
         this.eps = options?.eps || 1e-8;
         this.weightDecay = options?.weightDecay || 0.01;
@@ -184,6 +183,7 @@ class AdamW extends BaseOptimizer {
         }
     }
 }
+exports.AdamW = AdamW;
 exports.Optim = {
     BaseOptimizer,
     SGD,
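
At runtime the duplicated this.lr = options?.lr || 0.001 lines collapse into the base constructor, and each optimizer class becomes a named CommonJS export alongside the Optim namespace object. One caveat carried over from 0.6.x: the || default means an explicit lr of 0 still falls back to 0.001. A quick sketch of the shared default (empty params arrays for brevity; root re-export assumed):

    import { Optim } from "catniff";

    const sgd = new Optim.SGD([], { momentum: 0.9 }); // lr defaults to 0.001 in BaseOptimizer
    const adamw = new Optim.AdamW([], { lr: 3e-4 });
    console.log(sgd.lr, adamw.lr);                    // 0.001 0.0003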
package/index.d.ts
CHANGED
package/index.js
CHANGED