@dniskav/neuron 0.2.3 → 0.2.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +28 -3
- package/dist/index.d.mts +281 -22
- package/dist/index.d.ts +281 -22
- package/dist/index.js +1507 -104
- package/dist/index.mjs +1490 -103
- package/package.json +7 -3
package/dist/index.d.ts
CHANGED
|
@@ -25,6 +25,7 @@ interface Optimizer {
|
|
|
25
25
|
step(weight: number, gradient: number, lr: number): number;
|
|
26
26
|
}
|
|
27
27
|
type OptimizerFactory = () => Optimizer;
|
|
28
|
+
declare const defaultOptimizer: OptimizerFactory;
|
|
28
29
|
declare class SGD implements Optimizer {
|
|
29
30
|
step(weight: number, gradient: number, lr: number): number;
|
|
30
31
|
}
|
|
@@ -34,6 +35,13 @@ declare class Momentum implements Optimizer {
|
|
|
34
35
|
constructor(beta?: number);
|
|
35
36
|
step(weight: number, gradient: number, lr: number): number;
|
|
36
37
|
}
|
|
38
|
+
declare class ClipOptimizer implements Optimizer {
|
|
39
|
+
readonly inner: Optimizer;
|
|
40
|
+
readonly clipValue: number;
|
|
41
|
+
constructor(inner: Optimizer, clipValue: number);
|
|
42
|
+
step(weight: number, gradient: number, lr: number): number;
|
|
43
|
+
}
|
|
44
|
+
declare function ClippedOptimizerFactory(innerFactory: OptimizerFactory, clipValue: number): OptimizerFactory;
|
|
37
45
|
declare class Adam implements Optimizer {
|
|
38
46
|
readonly beta1: number;
|
|
39
47
|
readonly beta2: number;
|
|
@@ -66,24 +74,35 @@ declare class Network {
|
|
|
66
74
|
hiddenLayer: Layer;
|
|
67
75
|
outputLayer: Layer;
|
|
68
76
|
constructor(nInputs: number, nHidden: number, nOutputs: number);
|
|
69
|
-
predict(inputs: number[]): number;
|
|
77
|
+
predict(inputs: number[]): number[];
|
|
70
78
|
train(inputs: number[], target: number, lr: number): number;
|
|
79
|
+
getWeights(): number[];
|
|
80
|
+
setWeights(weights: number[]): void;
|
|
71
81
|
}
|
|
72
82
|
|
|
73
83
|
interface NetworkNOptions {
|
|
74
84
|
activations?: Activation[];
|
|
75
85
|
optimizer?: OptimizerFactory;
|
|
86
|
+
residual?: boolean | ((layerIndex: number) => boolean);
|
|
87
|
+
dropoutRate?: number;
|
|
76
88
|
}
|
|
77
89
|
declare class NetworkN {
|
|
78
90
|
readonly structure: number[];
|
|
79
91
|
layers: Layer[];
|
|
92
|
+
private _dropouts;
|
|
93
|
+
private _residual;
|
|
80
94
|
constructor(structure: number[], options?: NetworkNOptions);
|
|
81
|
-
predict(inputs: number[]): number[];
|
|
95
|
+
predict(inputs: number[], training?: boolean): number[];
|
|
82
96
|
train(inputs: number[], targets: number[], lr: number): number;
|
|
83
97
|
trainWithDeltas(inputs: number[], outputDeltas: number[], lr: number): void;
|
|
98
|
+
getWeights(): number[];
|
|
99
|
+
setWeights(weights: number[]): void;
|
|
100
|
+
private _shouldResidual;
|
|
101
|
+
private _forwardAll;
|
|
102
|
+
private _backpropLayers;
|
|
84
103
|
}
|
|
85
104
|
|
|
86
|
-
declare class Gate {
|
|
105
|
+
declare class Gate$1 {
|
|
87
106
|
W: number[][];
|
|
88
107
|
b: number[];
|
|
89
108
|
constructor(inputSize: number, hSize: number, initBias?: number);
|
|
@@ -94,12 +113,13 @@ declare class LSTMLayer {
|
|
|
94
113
|
readonly hSize: number;
|
|
95
114
|
h: number[];
|
|
96
115
|
c: number[];
|
|
97
|
-
forgetGate: Gate;
|
|
98
|
-
inputGate: Gate;
|
|
99
|
-
cellGate: Gate;
|
|
100
|
-
outputGate: Gate;
|
|
116
|
+
forgetGate: Gate$1;
|
|
117
|
+
inputGate: Gate$1;
|
|
118
|
+
cellGate: Gate$1;
|
|
119
|
+
outputGate: Gate$1;
|
|
120
|
+
private _optimizers;
|
|
101
121
|
private _traj;
|
|
102
|
-
constructor(inputSize: number, hiddenSize: number);
|
|
122
|
+
constructor(inputSize: number, hiddenSize: number, optimizerFactory?: OptimizerFactory);
|
|
103
123
|
reset(): void;
|
|
104
124
|
predict(inputs: number[]): number[];
|
|
105
125
|
backprop(dh_seq: number[][], lr: number): void;
|
|
@@ -122,6 +142,8 @@ declare class LSTMLayer {
|
|
|
122
142
|
};
|
|
123
143
|
};
|
|
124
144
|
setWeights(data: ReturnType<LSTMLayer["getWeights"]>): void;
|
|
145
|
+
getWeightsFlat(): number[];
|
|
146
|
+
setWeightsFlat(weights: number[]): void;
|
|
125
147
|
}
|
|
126
148
|
|
|
127
149
|
interface NetworkLSTMOptions {
|
|
@@ -163,6 +185,8 @@ declare class NetworkLSTM {
|
|
|
163
185
|
}[][];
|
|
164
186
|
};
|
|
165
187
|
setWeights(data: ReturnType<NetworkLSTM["getWeights"]>): void;
|
|
188
|
+
getWeightsFlat(): number[];
|
|
189
|
+
setWeightsFlat(weights: number[]): void;
|
|
166
190
|
}
|
|
167
191
|
|
|
168
192
|
declare function matMul(A: number[][], B: number[][]): number[][];
|
|
@@ -174,38 +198,56 @@ declare class WeightMatrix {
|
|
|
174
198
|
private opts;
|
|
175
199
|
constructor(rows: number, cols: number);
|
|
176
200
|
update(dW: number[][], lr: number, clipValue?: number): void;
|
|
201
|
+
getWeights(): number[];
|
|
202
|
+
setWeights(weights: number[]): void;
|
|
203
|
+
}
|
|
204
|
+
declare class BiasVector {
|
|
205
|
+
values: number[];
|
|
206
|
+
private opts;
|
|
207
|
+
constructor(size: number);
|
|
208
|
+
update(grad: number[], lr: number): void;
|
|
209
|
+
getWeights(): number[];
|
|
210
|
+
setWeights(weights: number[]): void;
|
|
177
211
|
}
|
|
178
212
|
declare class EmbeddingMatrix {
|
|
179
213
|
W: number[][];
|
|
180
214
|
constructor(vocabSize: number, d_model: number);
|
|
181
215
|
get(idx: number): number[];
|
|
182
216
|
update(idx: number, grad: number[], lr: number): void;
|
|
217
|
+
getWeights(): number[];
|
|
218
|
+
setWeights(weights: number[]): void;
|
|
183
219
|
}
|
|
184
220
|
|
|
185
221
|
declare class AttentionHead {
|
|
186
222
|
readonly d_k: number;
|
|
187
223
|
readonly d_v: number;
|
|
224
|
+
readonly causal: boolean;
|
|
188
225
|
Wq: WeightMatrix;
|
|
189
226
|
Wk: WeightMatrix;
|
|
190
227
|
Wv: WeightMatrix;
|
|
191
228
|
private cache;
|
|
192
|
-
constructor(d_model: number, d_k: number, d_v: number);
|
|
229
|
+
constructor(d_model: number, d_k: number, d_v: number, causal?: boolean);
|
|
193
230
|
predict(X: number[][]): number[][];
|
|
194
231
|
backward(dOut: number[][], lr: number): number[][];
|
|
195
232
|
getAttentionWeights(): number[][] | null;
|
|
233
|
+
getWeights(): number[];
|
|
234
|
+
setWeights(weights: number[]): void;
|
|
196
235
|
}
|
|
197
236
|
|
|
198
237
|
declare class MultiHeadAttention {
|
|
199
238
|
readonly nHeads: number;
|
|
200
239
|
readonly d_model: number;
|
|
201
240
|
readonly d_k: number;
|
|
241
|
+
readonly causal: boolean;
|
|
202
242
|
heads: AttentionHead[];
|
|
203
243
|
Wo: WeightMatrix;
|
|
204
244
|
private _concat;
|
|
205
|
-
constructor(d_model: number, nHeads: number);
|
|
245
|
+
constructor(d_model: number, nHeads: number, causal?: boolean);
|
|
206
246
|
predict(X: number[][]): number[][];
|
|
207
247
|
backward(dOut: number[][], lr: number): number[][];
|
|
208
248
|
getAttentionWeights(): (number[][] | null)[];
|
|
249
|
+
getWeights(): number[];
|
|
250
|
+
setWeights(weights: number[]): void;
|
|
209
251
|
}
|
|
210
252
|
|
|
211
253
|
declare class LayerNorm {
|
|
@@ -217,12 +259,15 @@ declare class LayerNorm {
|
|
|
217
259
|
resetCache(seqLen: number): void;
|
|
218
260
|
predictOne(x: number[], pos: number): number[];
|
|
219
261
|
backwardOne(dOut: number[], pos: number, lr: number): number[];
|
|
262
|
+
getWeights(): number[];
|
|
263
|
+
setWeights(weights: number[]): void;
|
|
220
264
|
}
|
|
221
265
|
|
|
222
266
|
interface TransformerBlockOptions {
|
|
223
267
|
d_model: number;
|
|
224
268
|
nHeads: number;
|
|
225
269
|
d_ff: number;
|
|
270
|
+
causal?: boolean;
|
|
226
271
|
}
|
|
227
272
|
declare class TransformerBlock {
|
|
228
273
|
readonly d_model: number;
|
|
@@ -232,20 +277,20 @@ declare class TransformerBlock {
|
|
|
232
277
|
norm2: LayerNorm;
|
|
233
278
|
ff1: WeightMatrix;
|
|
234
279
|
ff2: WeightMatrix;
|
|
235
|
-
b1:
|
|
236
|
-
b2:
|
|
237
|
-
private b1Opts;
|
|
238
|
-
private b2Opts;
|
|
280
|
+
b1: BiasVector;
|
|
281
|
+
b2: BiasVector;
|
|
239
282
|
private _X;
|
|
240
283
|
private _attnOut;
|
|
241
284
|
private _h1;
|
|
242
285
|
private _ff1Pre;
|
|
243
286
|
private _ff1Out;
|
|
244
287
|
private _ff2Out;
|
|
245
|
-
constructor({ d_model, nHeads, d_ff }: TransformerBlockOptions);
|
|
288
|
+
constructor({ d_model, nHeads, d_ff, causal }: TransformerBlockOptions);
|
|
246
289
|
predict(X: number[][]): number[][];
|
|
247
290
|
backward(dOut: number[][], lr: number): number[][];
|
|
248
291
|
getAttentionWeights(): (number[][] | null)[];
|
|
292
|
+
getWeights(): number[];
|
|
293
|
+
setWeights(weights: number[]): void;
|
|
249
294
|
}
|
|
250
295
|
|
|
251
296
|
interface NetworkTransformerOptions {
|
|
@@ -265,12 +310,13 @@ declare class NetworkTransformer {
|
|
|
265
310
|
posEmb: EmbeddingMatrix;
|
|
266
311
|
blocks: TransformerBlock[];
|
|
267
312
|
outputProj: WeightMatrix;
|
|
268
|
-
outputBias:
|
|
269
|
-
private outBiasOpts;
|
|
313
|
+
outputBias: BiasVector;
|
|
270
314
|
constructor(seqLen: number, options?: NetworkTransformerOptions);
|
|
271
315
|
predict(tokens: number[]): number[];
|
|
272
316
|
train(tokens: number[], targets: number[], lr: number, mask?: boolean[]): number;
|
|
273
317
|
getAttentionWeights(): (number[][] | null)[][];
|
|
318
|
+
getWeights(): number[];
|
|
319
|
+
setWeights(weights: number[]): void;
|
|
274
320
|
private _forward;
|
|
275
321
|
}
|
|
276
322
|
|
|
@@ -280,6 +326,7 @@ interface NetworkTransformerRLOptions {
|
|
|
280
326
|
d_ff?: number;
|
|
281
327
|
nBlocks?: number;
|
|
282
328
|
nActions?: number;
|
|
329
|
+
pooling?: 'avg' | 'max' | 'last' | 'weighted';
|
|
283
330
|
}
|
|
284
331
|
declare class NetworkTransformerRL {
|
|
285
332
|
readonly seqLen: number;
|
|
@@ -289,14 +336,17 @@ declare class NetworkTransformerRL {
|
|
|
289
336
|
inputProj: WeightMatrix;
|
|
290
337
|
blocks: TransformerBlock[];
|
|
291
338
|
outputProj: WeightMatrix;
|
|
292
|
-
outputBias:
|
|
293
|
-
private outBiasOpts;
|
|
339
|
+
outputBias: BiasVector;
|
|
294
340
|
private _projected;
|
|
341
|
+
private _pooling;
|
|
342
|
+
private _argmax;
|
|
295
343
|
constructor(seqLen: number, inputDim: number, options?: NetworkTransformerRLOptions);
|
|
296
344
|
predict(sequence: number[][]): number[];
|
|
297
345
|
train(sequence: number[][], target: number[], lr: number): number;
|
|
298
346
|
getAttentionWeights(): (number[][] | null)[][];
|
|
299
|
-
|
|
347
|
+
getWeightsFlat(): number[];
|
|
348
|
+
setWeightsFlat(weights: number[]): void;
|
|
349
|
+
getWeightsStructured(): {
|
|
300
350
|
inputProj: number[][];
|
|
301
351
|
blocks: {
|
|
302
352
|
attn: {
|
|
@@ -323,9 +373,18 @@ declare class NetworkTransformerRL {
|
|
|
323
373
|
outputProj: number[][];
|
|
324
374
|
outputBias: number[];
|
|
325
375
|
};
|
|
326
|
-
|
|
376
|
+
setWeightsStructured(data: ReturnType<NetworkTransformerRL['getWeightsStructured']>): void;
|
|
377
|
+
getWeights(): number[];
|
|
378
|
+
setWeights(weights: number[]): void;
|
|
327
379
|
private _forward;
|
|
328
380
|
private _pool;
|
|
381
|
+
private _poolAvg;
|
|
382
|
+
private _poolMax;
|
|
383
|
+
private _poolLast;
|
|
384
|
+
private _poolWeighted;
|
|
385
|
+
/** Returns the current pooling type for inspection. */
|
|
386
|
+
getPoolingType(): string;
|
|
387
|
+
private _distributePoolGradient;
|
|
329
388
|
}
|
|
330
389
|
|
|
331
390
|
declare function mse(predicted: number[], actual: number[]): number;
|
|
@@ -334,4 +393,204 @@ declare function mseDelta(predicted: number, actual: number): number;
|
|
|
334
393
|
declare function crossEntropyDelta(predicted: number, actual: number): number;
|
|
335
394
|
declare function crossEntropyDeltaRaw(predicted: number, actual: number): number;
|
|
336
395
|
|
|
337
|
-
|
|
396
|
+
declare class Dropout {
|
|
397
|
+
readonly rate: number;
|
|
398
|
+
private _mask;
|
|
399
|
+
constructor(rate: number);
|
|
400
|
+
forward(x: number[], training?: boolean): number[];
|
|
401
|
+
backward(dOut: number[]): number[];
|
|
402
|
+
resetMask(): void;
|
|
403
|
+
getWeights(): number[];
|
|
404
|
+
setWeights(_weights: number[]): void;
|
|
405
|
+
}
|
|
406
|
+
|
|
407
|
+
declare class Gate {
|
|
408
|
+
W: number[][];
|
|
409
|
+
b: number[];
|
|
410
|
+
constructor(inputSize: number, hSize: number, initBias?: number);
|
|
411
|
+
linear(combined: number[]): number[];
|
|
412
|
+
}
|
|
413
|
+
declare class GRULayer {
|
|
414
|
+
readonly inputSize: number;
|
|
415
|
+
readonly hSize: number;
|
|
416
|
+
h: number[];
|
|
417
|
+
resetGate: Gate;
|
|
418
|
+
updateGate: Gate;
|
|
419
|
+
newGate: Gate;
|
|
420
|
+
private _optimizers;
|
|
421
|
+
private _traj;
|
|
422
|
+
constructor(inputSize: number, hiddenSize: number, optimizerFactory?: OptimizerFactory);
|
|
423
|
+
reset(): void;
|
|
424
|
+
predict(inputs: number[]): number[];
|
|
425
|
+
backprop(dh_seq: number[][], lr: number): void;
|
|
426
|
+
getWeightsFlat(): number[];
|
|
427
|
+
setWeightsFlat(weights: number[]): void;
|
|
428
|
+
getWeights(): {
|
|
429
|
+
resetGate: {
|
|
430
|
+
W: number[][];
|
|
431
|
+
b: number[];
|
|
432
|
+
};
|
|
433
|
+
updateGate: {
|
|
434
|
+
W: number[][];
|
|
435
|
+
b: number[];
|
|
436
|
+
};
|
|
437
|
+
newGate: {
|
|
438
|
+
W: number[][];
|
|
439
|
+
b: number[];
|
|
440
|
+
};
|
|
441
|
+
};
|
|
442
|
+
setWeights(data: ReturnType<GRULayer["getWeights"]>): void;
|
|
443
|
+
}
|
|
444
|
+
|
|
445
|
+
declare class BatchNorm {
|
|
446
|
+
readonly dim: number;
|
|
447
|
+
readonly momentum: number;
|
|
448
|
+
gamma: number[];
|
|
449
|
+
beta: number[];
|
|
450
|
+
runningMean: number[];
|
|
451
|
+
runningVar: number[];
|
|
452
|
+
private _xNorm;
|
|
453
|
+
private _std;
|
|
454
|
+
constructor(dim: number, momentum?: number);
|
|
455
|
+
forward(x: number[]): number[];
|
|
456
|
+
backward(dOut: number[]): number[];
|
|
457
|
+
trainParams(dOut: number[], lr: number): void;
|
|
458
|
+
getWeights(): number[];
|
|
459
|
+
setWeights(weights: number[]): void;
|
|
460
|
+
}
|
|
461
|
+
|
|
462
|
+
declare class Conv1D {
|
|
463
|
+
readonly inputLength: number;
|
|
464
|
+
readonly kernelSize: number;
|
|
465
|
+
readonly filters: number;
|
|
466
|
+
readonly stride: number;
|
|
467
|
+
readonly padding: 'valid' | 'same';
|
|
468
|
+
readonly inputChannels: number;
|
|
469
|
+
kernels: number[][][];
|
|
470
|
+
biases: number[];
|
|
471
|
+
private _kOpts;
|
|
472
|
+
private _bOpts;
|
|
473
|
+
private _input;
|
|
474
|
+
private _paddedInput;
|
|
475
|
+
constructor(inputLength: number, kernelSize: number, filters: number, stride?: number, padding?: 'valid' | 'same', optimizerFactory?: OptimizerFactory, inputChannels?: number);
|
|
476
|
+
forward(input: number[] | number[][]): number[][];
|
|
477
|
+
backward(dOut: number[][], lr?: number): number[][];
|
|
478
|
+
getOutputLength(): number;
|
|
479
|
+
getWeights(): number[];
|
|
480
|
+
setWeights(weights: number[]): void;
|
|
481
|
+
private _normalizeInput;
|
|
482
|
+
}
|
|
483
|
+
|
|
484
|
+
interface DataPair {
|
|
485
|
+
inputs: number[][];
|
|
486
|
+
targets: number[][];
|
|
487
|
+
}
|
|
488
|
+
declare class DataLoader {
|
|
489
|
+
readonly data: DataPair;
|
|
490
|
+
readonly batchSize: number;
|
|
491
|
+
private _indices;
|
|
492
|
+
private _trainIndices;
|
|
493
|
+
private _valIndices;
|
|
494
|
+
private _pos;
|
|
495
|
+
private _validationSplit;
|
|
496
|
+
constructor(data: DataPair, batchSize?: number, validationSplit?: number);
|
|
497
|
+
shuffle(): void;
|
|
498
|
+
hasNext(): boolean;
|
|
499
|
+
next(): DataPair;
|
|
500
|
+
reset(): void;
|
|
501
|
+
get length(): number;
|
|
502
|
+
getValidationData(): DataPair;
|
|
503
|
+
get validationLength(): number;
|
|
504
|
+
static sequences(data: number[][], seqLen: number, validationSplit?: number): DataLoader;
|
|
505
|
+
}
|
|
506
|
+
|
|
507
|
+
interface TrainMetrics {
|
|
508
|
+
accuracy: number;
|
|
509
|
+
precision: number;
|
|
510
|
+
recall: number;
|
|
511
|
+
f1: number;
|
|
512
|
+
}
|
|
513
|
+
interface TrainerOptions {
|
|
514
|
+
epochs?: number;
|
|
515
|
+
lr?: number;
|
|
516
|
+
lrDecay?: number;
|
|
517
|
+
verbose?: boolean;
|
|
518
|
+
weightDecay?: number;
|
|
519
|
+
earlyStopping?: {
|
|
520
|
+
patience: number;
|
|
521
|
+
minDelta: number;
|
|
522
|
+
};
|
|
523
|
+
computeMetrics?: boolean;
|
|
524
|
+
clipValue?: number;
|
|
525
|
+
}
|
|
526
|
+
interface TrainDataset {
|
|
527
|
+
inputs: number[][];
|
|
528
|
+
targets: number[][];
|
|
529
|
+
}
|
|
530
|
+
interface TrainableNetwork {
|
|
531
|
+
train(inputs: number[], targets: number[], lr: number): number;
|
|
532
|
+
}
|
|
533
|
+
/** Extended interface for networks that support weight access and prediction.
|
|
534
|
+
* Required for weightDecay, earlyStopping, and computeMetrics features. */
|
|
535
|
+
interface TrainableNetworkWithWeights extends TrainableNetwork {
|
|
536
|
+
predict(inputs: number[]): number[];
|
|
537
|
+
getWeights(): number[];
|
|
538
|
+
setWeights(weights: number[]): void;
|
|
539
|
+
}
|
|
540
|
+
declare class Trainer {
|
|
541
|
+
readonly network: TrainableNetwork;
|
|
542
|
+
readonly epochs: number;
|
|
543
|
+
readonly lrInitial: number;
|
|
544
|
+
readonly lrDecay: number;
|
|
545
|
+
readonly verbose: boolean;
|
|
546
|
+
readonly weightDecay: number;
|
|
547
|
+
readonly clipValue: number;
|
|
548
|
+
private _history;
|
|
549
|
+
private _validationData?;
|
|
550
|
+
private _earlyStopping?;
|
|
551
|
+
private _bestLoss;
|
|
552
|
+
private _patienceCounter;
|
|
553
|
+
private _stopReason;
|
|
554
|
+
private _computeMetrics;
|
|
555
|
+
private _metrics;
|
|
556
|
+
constructor(network: TrainableNetwork, options?: TrainerOptions);
|
|
557
|
+
setValidationData(dataset: DataPair): void;
|
|
558
|
+
getBestLoss(): number;
|
|
559
|
+
getStopReason(): string;
|
|
560
|
+
getMetrics(): TrainMetrics[];
|
|
561
|
+
train(dataset: TrainDataset): number[];
|
|
562
|
+
getHistory(): number[];
|
|
563
|
+
/** Type guard: does this network support getWeights/setWeights/predict? */
|
|
564
|
+
private _hasWeights;
|
|
565
|
+
/** Mean squared error on a dataset (used for validation loss). */
|
|
566
|
+
private _computeLoss;
|
|
567
|
+
/** Heuristic: are targets classification-style (one-hot or single-class)? */
|
|
568
|
+
private _isClassification;
|
|
569
|
+
/** Compute classification metrics from predictions vs targets. */
|
|
570
|
+
private _computeMetricsArray;
|
|
571
|
+
}
|
|
572
|
+
|
|
573
|
+
declare class LRScheduler {
|
|
574
|
+
stepDecay(lr: number, epoch: number, dropRate: number, epochsDrop: number): number;
|
|
575
|
+
exponentialDecay(lr: number, epoch: number, decayRate: number): number;
|
|
576
|
+
plateauDecay(lr: number, currentLoss: number, history: number[], patience: number, factor: number): number;
|
|
577
|
+
cosineAnnealing(lr: number, epoch: number, maxEpochs: number, minLr?: number): number;
|
|
578
|
+
}
|
|
579
|
+
|
|
580
|
+
interface Serializable {
|
|
581
|
+
getWeights(): number[];
|
|
582
|
+
setWeights(weights: number[]): void;
|
|
583
|
+
}
|
|
584
|
+
declare class ModelSaver {
|
|
585
|
+
static toJSON(model: Serializable): string;
|
|
586
|
+
static fromJSON(model: Serializable, json: string): void;
|
|
587
|
+
static saveToFile(model: Serializable, path: string, writeFn: (path: string, data: string) => void): void;
|
|
588
|
+
static loadFromFile(model: Serializable, path: string, readFn: (path: string) => string): void;
|
|
589
|
+
}
|
|
590
|
+
|
|
591
|
+
declare function validateArray(arr: unknown, expectedLength: number, methodName: string): asserts arr is number[];
|
|
592
|
+
declare function validateArrayMinLength(arr: unknown, minLength: number, methodName: string): asserts arr is number[];
|
|
593
|
+
declare function validate2DArray(arr: unknown, expectedRows: number, expectedCols: number, methodName: string): asserts arr is number[][];
|
|
594
|
+
declare function validateNumber(value: unknown, methodName: string): asserts value is number;
|
|
595
|
+
|
|
596
|
+
export { type Activation, Adam, AttentionHead, BatchNorm, BiasVector, ClipOptimizer, ClippedOptimizerFactory, Conv1D, DataLoader, type DataPair, Dropout, EmbeddingMatrix, GRULayer, LRScheduler, LSTMLayer, Layer, LayerNorm, ModelSaver, Momentum, MultiHeadAttention, Network, NetworkLSTM, type NetworkLSTMOptions, NetworkN, type NetworkNOptions, NetworkTransformer, type NetworkTransformerOptions, NetworkTransformerRL, type NetworkTransformerRLOptions, Neuron, NeuronN, type Optimizer, type OptimizerFactory, SGD, type Serializable, type TrainDataset, type TrainMetrics, type TrainableNetwork, type TrainableNetworkWithWeights, Trainer, type TrainerOptions, TransformerBlock, type TransformerBlockOptions, WeightMatrix, crossEntropy, crossEntropyDelta, crossEntropyDeltaRaw, defaultOptimizer, elu, leakyRelu, linear, makeElu, makeLeakyRelu, matMul, mse, mseDelta, relu, sigmoid, softmax, softmaxBackward, tanh, transpose, validate2DArray, validateArray, validateArrayMinLength, validateNumber };
|