mini-jstorch 1.4.5 → 1.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +22 -46
- package/demo/MakeModel.js +36 -0
- package/demo/fu_fun.js +72 -0
- package/demo/scheduler.js +69 -0
- package/index.js +1 -1
- package/package.json +1 -1
- package/src/jstorch.js +1237 -0
- package/src/MainEngine.js +0 -663
- package/tests/MakeModel.js +0 -38
- package/tests/scheduler.js +0 -23
package/src/jstorch.js
ADDED
@@ -0,0 +1,1237 @@

```js
/*!
 * Project: mini-jstorch
 * File: jstorch.js
 * Author: Rizal-editors
 * License: MIT
 * Copyright (C) 2025 Rizal-editors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
```
```js
// ---------------------- Utilities ----------------------
export function zeros(rows, cols) {
  return Array.from({length: rows}, () => Array(cols).fill(0));
}

export function ones(rows, cols) {
  return Array.from({length: rows}, () => Array(cols).fill(1));
}

export function randomMatrix(rows, cols, scale = null) {
  // Auto-scale based on layer size (Xavier init)
  if (scale === null) {
    scale = Math.sqrt(2.0 / (rows + cols));
  }

  return Array.from({length: rows}, () =>
    Array.from({length: cols}, () => (Math.random() * 2 - 1) * scale));
}

export function transpose(matrix) {
  return matrix[0].map((_, i) => matrix.map(row => row[i]));
}

export function addMatrices(a, b) {
  return a.map((row, i) =>
    row.map((v, j) => v + (b[i] && b[i][j] !== undefined ? b[i][j] : 0))
  );
}

export function dot(a, b) {
  const res = zeros(a.length, b[0].length);
  for (let i = 0; i < a.length; i++)
    for (let j = 0; j < b[0].length; j++)
      for (let k = 0; k < a[0].length; k++)
        res[i][j] += a[i][k] * b[k][j];
  return res;
}

export function softmax(x) {
  const m = Math.max(...x);
  const exps = x.map(v => Math.exp(v - m));
  const s = exps.reduce((a, b) => a + b, 0);
  return exps.map(v => v / s);
}

export function crossEntropy(pred, target) {
  const eps = 1e-12;
  return -target.reduce((sum, t, i) => sum + t * Math.log(pred[i] + eps), 0);
}
```
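A minimal sketch of how these primitives compose into a classification score. The `mini-jstorch` import path assumes the package entry re-exports these names, which this diff does not show:

```js
// Score two classes for one input row, then measure the loss against a one-hot target.
import { dot, softmax, crossEntropy } from "mini-jstorch";

const x = [[1.0, 2.0]];                   // [1, 2] input row
const W = [[0.5, -0.5], [0.25, 0.75]];    // [2, 2] weights
const logits = dot(x, W)[0];              // [2] raw scores: [1.0, 1.0]
const probs = softmax(logits);            // [0.5, 0.5], sums to 1
const loss = crossEntropy(probs, [1, 0]); // -log(0.5) ≈ 0.693
console.log(probs, loss);
```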
```js
// ---------------------- User-friendly utils ----------------------
export function fu_tensor(data, requiresGrad = false) {
  if (!Array.isArray(data) || !Array.isArray(data[0])) {
    throw new Error("fu_tensor: Data must be 2D array");
  }
  const tensor = new Tensor(data);
  tensor.requiresGrad = requiresGrad;
  return tensor;
}

// fu_add: element-wise addition; a scalar operand is broadcast to the other's shape
export function fu_add(a, b) {
  if (!(a instanceof Tensor) && !(b instanceof Tensor)) {
    throw new Error("fu_add: At least one operand must be Tensor");
  }

  if (!(a instanceof Tensor)) {
    a = fu_tensor(Array(b.shape()[0]).fill().map(() =>
      Array(b.shape()[1]).fill(a)
    ));
  }

  if (!(b instanceof Tensor)) {
    b = fu_tensor(Array(a.shape()[0]).fill().map(() =>
      Array(a.shape()[1]).fill(b)
    ));
  }

  if (a.shape()[0] !== b.shape()[0] || a.shape()[1] !== b.shape()[1]) {
    throw new Error(`fu_add: Shape mismatch ${a.shape()} vs ${b.shape()}`);
  }

  return new Tensor(a.data.map((r, i) => r.map((v, j) => v + b.data[i][j])));
}

// fu_mul: element-wise multiplication, with the same scalar broadcasting as fu_add
export function fu_mul(a, b) {
  if (!(a instanceof Tensor) && !(b instanceof Tensor)) {
    throw new Error("fu_mul: At least one operand must be Tensor");
  }

  if (!(a instanceof Tensor)) {
    a = fu_tensor(Array(b.shape()[0]).fill().map(() =>
      Array(b.shape()[1]).fill(a)
    ));
  }

  if (!(b instanceof Tensor)) {
    b = fu_tensor(Array(a.shape()[0]).fill().map(() =>
      Array(a.shape()[1]).fill(b)
    ));
  }

  if (a.shape()[0] !== b.shape()[0] || a.shape()[1] !== b.shape()[1]) {
    throw new Error(`fu_mul: Shape mismatch ${a.shape()} vs ${b.shape()}`);
  }

  return new Tensor(a.data.map((r, i) => r.map((v, j) => v * b.data[i][j])));
}

// fu_matmul: matrix product; raw 2D arrays are wrapped automatically
export function fu_matmul(a, b) {
  if (!(a instanceof Tensor)) a = fu_tensor(a);
  if (!(b instanceof Tensor)) b = fu_tensor(b);

  if (a.shape()[1] !== b.shape()[0]) {
    throw new Error(`fu_matmul: Inner dimension mismatch ${a.shape()[1]} vs ${b.shape()[0]}`);
  }

  return new Tensor(dot(a.data, b.data));
}

// fu_sum: total of all elements, returned as a 1x1 Tensor
export function fu_sum(tensor) {
  if (!(tensor instanceof Tensor)) tensor = fu_tensor(tensor);
  const total = tensor.data.flat().reduce((a, b) => a + b, 0);
  return new Tensor([[total]]);
}

// fu_mean: average of all elements, returned as a 1x1 Tensor
export function fu_mean(tensor) {
  if (!(tensor instanceof Tensor)) tensor = fu_tensor(tensor);
  const totalElements = tensor.shape()[0] * tensor.shape()[1];
  const sum = fu_sum(tensor).data[0][0];
  return new Tensor([[sum / totalElements]]);
}

// fu_relu
export function fu_relu(tensor) {
  if (!(tensor instanceof Tensor)) tensor = fu_tensor(tensor);
  return new Tensor(tensor.data.map(r => r.map(v => Math.max(0, v))));
}

// fu_sigmoid
export function fu_sigmoid(tensor) {
  if (!(tensor instanceof Tensor)) tensor = fu_tensor(tensor);
  const fn = v => 1 / (1 + Math.exp(-v));
  return new Tensor(tensor.data.map(r => r.map(fn)));
}

// fu_tanh
export function fu_tanh(tensor) {
  if (!(tensor instanceof Tensor)) tensor = fu_tensor(tensor);
  return new Tensor(tensor.data.map(r => r.map(v => Math.tanh(v))));
}

// fu_softmax: row-wise, with max-subtraction for numerical stability
export function fu_softmax(tensor) {
  if (!(tensor instanceof Tensor)) tensor = fu_tensor(tensor);
  const result = tensor.data.map(row => {
    const maxVal = Math.max(...row);
    const exps = row.map(v => Math.exp(v - maxVal));
    const sumExps = exps.reduce((a, b) => a + b, 0);
    return exps.map(v => v / sumExps);
  });
  return new Tensor(result);
}

// fu_flatten - Flatten tensor to 1D (kept as a one-row Tensor)
export function fu_flatten(tensor) {
  if (!(tensor instanceof Tensor)) tensor = fu_tensor(tensor);
  return new Tensor([tensor.data.flat()]);
}

// fu_reshape
export function fu_reshape(tensor, rows, cols) {
  if (!(tensor instanceof Tensor)) tensor = fu_tensor(tensor);
  const flat = tensor.data.flat();
  if (flat.length !== rows * cols) {
    throw new Error(`fu_reshape: Size mismatch ${flat.length} vs ${rows * cols}`);
  }

  const result = [];
  for (let i = 0; i < rows; i++) {
    result.push(flat.slice(i * cols, i * cols + cols));
  }
  return new Tensor(result);
}

// fu_stack: stack same-shape 2D tensors into a 3D array
export function fu_stack(tensors) {
  if (!tensors.every(t => t instanceof Tensor)) {
    throw new Error("fu_stack: All inputs must be Tensors");
  }

  const firstShape = tensors[0].shape();
  if (!tensors.every(t => t.shape()[0] === firstShape[0] && t.shape()[1] === firstShape[1])) {
    throw new Error("fu_stack: All tensors must have same shape");
  }

  const stacked = tensors.map(t => t.data);
  return new Tensor(stacked);
}
```
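A sketch of the `fu_*` convenience layer in use: scalars broadcast, raw arrays are wrapped, and shapes are validated (import path assumed, as above):

```js
import { fu_tensor, fu_add, fu_matmul, fu_relu, fu_mean } from "mini-jstorch";

const a = fu_tensor([[1, -2], [3, -4]]);
const b = fu_matmul(a, [[1, 0], [0, 1]]); // raw identity matrix is auto-wrapped
const c = fu_add(b, 1);                   // scalar 1 broadcast to [2, 2]
const d = fu_relu(c);                     // negatives clamped to 0
console.log(d.data);                      // [[2, 0], [4, 0]]
console.log(fu_mean(d).data);             // [[1.5]]
```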
```js
// ---------------------- Tensor ----------------------
export class Tensor {
  constructor(data){ this.data = data; this.grad = zeros(data.length, data[0].length); }
  shape(){ return [this.data.length, this.data[0].length]; }
  // Note: add/sub/mul/matmul/transpose/flatten return plain arrays, not Tensors
  add(t){ return t instanceof Tensor ? this.data.map((r, i) => r.map((v, j) => v + t.data[i][j])) : this.data.map(r => r.map(v => v + t)); }
  sub(t){ return t instanceof Tensor ? this.data.map((r, i) => r.map((v, j) => v - t.data[i][j])) : this.data.map(r => r.map(v => v - t)); }
  mul(t){ return t instanceof Tensor ? this.data.map((r, i) => r.map((v, j) => v * t.data[i][j])) : this.data.map(r => r.map(v => v * t)); }
  matmul(t){ if (t instanceof Tensor) return dot(this.data, t.data); else throw new Error("matmul requires Tensor"); }
  transpose(){ return transpose(this.data); }
  flatten(){ return this.data.flat(); }
  static zeros(r, c){ return new Tensor(zeros(r, c)); }
  static ones(r, c){ return new Tensor(ones(r, c)); }
  static random(r, c, scale = 0.1){ return new Tensor(randomMatrix(r, c, scale)); }
}
```
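A small sketch of the `Tensor` wrapper. Because the arithmetic methods return plain 2D arrays rather than new `Tensor` instances, chaining requires re-wrapping:

```js
import { Tensor } from "mini-jstorch"; // import path assumed

const t = Tensor.ones(2, 2);                // [[1, 1], [1, 1]]
const doubled = t.mul(2);                   // plain array: [[2, 2], [2, 2]]
const chained = new Tensor(doubled).add(1); // re-wrap before the next op
console.log(chained);                       // [[3, 3], [3, 3]]
console.log(t.shape());                     // [2, 2]
```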
```js
// ---------------------- Layers ----------------------
export class Linear {
  constructor(inputDim, outputDim){
    this.W = randomMatrix(inputDim, outputDim);
    this.b = Array(outputDim).fill(0);
    this.gradW = zeros(inputDim, outputDim);
    this.gradb = Array(outputDim).fill(0);
    this.x = null;
    this.originalShape = null; // Track input shape
  }

  forward(x){
    // Handle both [batch, features] and [batch, 1, features]
    this.originalShape = this._getShapeType(x);

    if (this.originalShape === '3d') {
      // Convert from [batch, 1, features] to [batch, features]
      this.x = x.map(sample => sample[0]);
    } else {
      // Already in [batch, features] format
      this.x = x;
    }

    const out = dot(this.x, this.W);
    return out.map((row, i) => row.map((v, j) => v + this.b[j]));
  }

  backward(grad){
    // Compute gradients
    for (let i = 0; i < this.W.length; i++) {
      for (let j = 0; j < this.W[0].length; j++) {
        this.gradW[i][j] = this.x.reduce((sum, row, k) => sum + row[i] * grad[k][j], 0);
      }
    }

    for (let j = 0; j < this.b.length; j++) {
      this.gradb[j] = grad.reduce((sum, row) => sum + row[j], 0);
    }

    const gradInput = zeros(this.x.length, this.W.length);
    for (let i = 0; i < this.x.length; i++) {
      for (let j = 0; j < this.W.length; j++) {
        for (let k = 0; k < this.W[0].length; k++) {
          gradInput[i][j] += grad[i][k] * this.W[j][k];
        }
      }
    }

    // Convert back to original shape if needed
    if (this.originalShape === '3d') {
      return gradInput.map(row => [row]); // Back to [batch, 1, features]
    }
    return gradInput;
  }

  _getShapeType(x) {
    if (Array.isArray(x[0]) && Array.isArray(x[0][0]) && !Array.isArray(x[0][0][0])) {
      return '3d'; // [batch, 1, features]
    } else if (Array.isArray(x[0]) && !Array.isArray(x[0][0])) {
      return '2d'; // [batch, features]
    } else {
      throw new Error(`Unsupported input shape for Linear layer`);
    }
  }

  parameters(){
    return [
      {param: this.W, grad: this.gradW},
      {param: [this.b], grad: [this.gradb]}
    ];
  }
}
```
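A shape-only sketch of one forward/backward pass through a single `Linear` layer (weights are random, so values vary run to run; import path assumed):

```js
import { Linear } from "mini-jstorch";

const layer = new Linear(3, 2);                     // 3 inputs -> 2 outputs
const out = layer.forward([[0.1, 0.2, 0.3]]);       // [1, 2]
const gradInput = layer.backward([[1, 1]]);         // upstream grad of ones
console.log(out.length, out[0].length);             // 1 2
console.log(gradInput.length, gradInput[0].length); // 1 3
console.log(layer.gradW.length, layer.gradW[0].length); // 3 2
```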
```js
export class Flatten {
  constructor() {
    this.originalShape = null;
  }

  forward(x) {
    // Always convert to [batch, features] format
    this.originalShape = x.map(sample => this._getShape(sample));

    return x.map(sample => {
      const flat = this._flatten(sample);
      return flat; // Return as 1D array for [batch, features] compatibility
    });
  }

  backward(grad) {
    // grad is [batch, features]; reshape back to the original per-sample shape
    return grad.map((flatGrad, batchIdx) => {
      const shape = this.originalShape[batchIdx];
      return this._unflatten(flatGrad, shape);
    });
  }

  _getShape(sample) {
    if (Array.isArray(sample[0]) && Array.isArray(sample[0][0])) {
      return {
        type: '3d',
        dims: [sample.length, sample[0].length, sample[0][0].length]
      };
    } else if (Array.isArray(sample[0])) {
      return {
        type: '2d',
        dims: [sample.length, sample[0].length]
      };
    } else {
      return {
        type: '1d',
        dims: [sample.length]
      };
    }
  }

  _flatten(sample) {
    if (Array.isArray(sample[0]) && Array.isArray(sample[0][0])) {
      return sample.flat(2); // [channels, height, width] -> flat
    } else if (Array.isArray(sample[0])) {
      return sample.flat(); // [height, width] -> flat
    } else {
      return sample; // already flat
    }
  }

  _unflatten(flat, shape) {
    if (shape.type === '3d') {
      const [channels, height, width] = shape.dims;
      const result = [];
      let index = 0;
      for (let c = 0; c < channels; c++) {
        const channel = [];
        for (let h = 0; h < height; h++) {
          const row = [];
          for (let w = 0; w < width; w++) {
            row.push(flat[index++]);
          }
          channel.push(row);
        }
        result.push(channel);
      }
      return result;
    } else if (shape.type === '2d') {
      const [height, width] = shape.dims;
      const result = [];
      for (let h = 0; h < height; h++) {
        result.push(flat.slice(h * width, h * width + width));
      }
      return result;
    } else {
      return flat; // 1d
    }
  }

  parameters() { return []; }
}
```
```js
// ---------------------- Conv2D ----------------------
export class Conv2D {
  constructor(inC, outC, kernel, stride = 1, padding = 0){
    this.inC = inC;
    this.outC = outC;
    this.kernel = kernel;
    this.stride = stride;
    this.padding = padding;
    this.W = Array(outC).fill().map(() =>
      Array(inC).fill().map(() => randomMatrix(kernel, kernel))
    );
    this.gradW = Array(outC).fill().map(() =>
      Array(inC).fill().map(() => zeros(kernel, kernel))
    );
    this.x = null;
  }

  pad2D(input, pad){
    // Input is a single channel [height, width]
    if (!input || !input.length) return input;

    const rows = input.length + 2 * pad;
    const cols = input[0].length + 2 * pad;
    const out = Array.from({length: rows}, () => Array(cols).fill(0));

    for (let i = 0; i < input.length; i++) {
      for (let j = 0; j < input[0].length; j++) {
        out[i + pad][j + pad] = input[i][j];
      }
    }
    return out;
  }

  conv2DSingle(input, kernel) {
    const rows = Math.floor((input.length - kernel.length) / this.stride) + 1;
    const cols = Math.floor((input[0].length - kernel[0].length) / this.stride) + 1;
    const out = zeros(rows, cols);

    for (let i = 0; i < rows; i++) {
      for (let j = 0; j < cols; j++) {
        let sum = 0;
        for (let ki = 0; ki < kernel.length; ki++) {
          for (let kj = 0; kj < kernel[0].length; kj++) {
            const inputRow = i * this.stride + ki;
            const inputCol = j * this.stride + kj;
            sum += input[inputRow][inputCol] * kernel[ki][kj];
          }
        }
        out[i][j] = sum;
      }
    }
    return out;
  }

  forward(batch) {
    this.x = batch;
    return batch.map(sample => {
      const channelsOut = [];
      for (let oc = 0; oc < this.outC; oc++) {
        let outChan = null;
        for (let ic = 0; ic < this.inC; ic++) {
          let inputChan = sample[ic];
          if (this.padding > 0) {
            inputChan = this.pad2D(inputChan, this.padding);
          }

          const conv = this.conv2DSingle(inputChan, this.W[oc][ic]);

          if (outChan === null) {
            outChan = conv;
          } else {
            outChan = addMatrices(outChan, conv);
          }
        }
        channelsOut.push(outChan);
      }
      return channelsOut;
    });
  }

  backward(grad) {
    const batchSize = this.x.length;
    const gradW = this.W.map(oc => oc.map(ic => zeros(this.kernel, this.kernel)));
    const gradInput = this.x.map(sample =>
      sample.map(chan => zeros(chan.length, chan[0].length))
    );

    for (let b = 0; b < batchSize; b++) {
      for (let oc = 0; oc < this.outC; oc++) {
        for (let ic = 0; ic < this.inC; ic++) {
          const outGrad = grad[b][oc];

          // Compute gradW
          for (let i = 0; i < this.kernel; i++) {
            for (let j = 0; j < this.kernel; j++) {
              let sum = 0;
              for (let y = 0; y < outGrad.length; y++) {
                for (let x = 0; x < outGrad[0].length; x++) {
                  const inY = y * this.stride + i;
                  const inX = x * this.stride + j;
                  if (inY < this.x[b][ic].length && inX < this.x[b][ic][0].length) {
                    sum += this.x[b][ic][inY][inX] * outGrad[y][x];
                  }
                }
              }
              gradW[oc][ic][i][j] += sum;
            }
          }

          // Compute gradInput
          for (let y = 0; y < outGrad.length; y++) {
            for (let x = 0; x < outGrad[0].length; x++) {
              for (let ki = 0; ki < this.kernel; ki++) {
                for (let kj = 0; kj < this.kernel; kj++) {
                  const inY = y * this.stride + ki;
                  const inX = x * this.stride + kj;
                  if (inY < gradInput[b][ic].length && inX < gradInput[b][ic][0].length) {
                    gradInput[b][ic][inY][inX] +=
                      this.W[oc][ic][ki][kj] * outGrad[y][x];
                  }
                }
              }
            }
          }
        }
      }
    }

    this.gradW = gradW;
    return gradInput;
  }

  parameters() {
    return this.W.flatMap((w, oc) =>
      w.map((wc, ic) => ({
        param: wc,
        grad: this.gradW[oc][ic]
      }))
    );
  }
}
```
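A shape sketch for `Conv2D`, assuming the import path as above: a 4x4 single-channel image through a 3x3 kernel (stride 1, no padding) yields a 2x2 feature map per output channel:

```js
import { Conv2D } from "mini-jstorch";

const conv = new Conv2D(1, 2, 3); // inC=1, outC=2, kernel=3
const img = Array.from({length: 4}, (_, i) =>
  Array.from({length: 4}, (_, j) => i + j)
);
const out = conv.forward([[img]]); // batch of 1 sample with 1 channel
// out: [batch=1][outC=2][2][2]
console.log(out[0].length, out[0][0].length, out[0][0][0].length); // 2 2 2
```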
```js
// ---------------------- Sequential ----------------------
export class Sequential {
  constructor(layers = []){ this.layers = layers; }
  forward(x){ return this.layers.reduce((acc, l) => l.forward(acc), x); }
  backward(grad){ return this.layers.reduceRight((g, l) => l.backward(g), grad); }
  parameters(){ return this.layers.flatMap(l => l.parameters ? l.parameters() : []); }
}
```
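An end-to-end sketch: a tiny MLP trained on XOR with MSE and SGD. It leans on classes defined later in this same file (`ReLU`, `MSELoss`, `SGD`); the import path is assumed as before:

```js
import { Sequential, Linear, ReLU, MSELoss, SGD } from "mini-jstorch";

const model = new Sequential([new Linear(2, 8), new ReLU(), new Linear(8, 1)]);
const loss = new MSELoss();
const opt = new SGD(model.parameters(), 0.1);

const X = [[0, 0], [0, 1], [1, 0], [1, 1]];
const Y = [[0], [1], [1], [0]];

for (let epoch = 0; epoch < 500; epoch++) {
  const pred = model.forward(X);
  const l = loss.forward(pred, Y);
  model.backward(loss.backward()); // propagate d(loss)/d(pred) back through the layers
  opt.step();                      // Linear.backward overwrites grads, so no zeroing needed
  if (epoch % 100 === 0) console.log(`epoch ${epoch}: loss=${l.toFixed(4)}`);
}
```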
```js
// ---------------------- Activations ----------------------
export class ReLU {
  constructor(){ this.mask = null; this.originalShape = null; }

  forward(x){
    this.originalShape = this._getShapeType(x);

    if (this.originalShape === '3d') {
      // Handle [batch, 1, features]
      this.mask = x.map(sample => sample[0].map(v => v > 0));
      return x.map(sample => [sample[0].map(v => Math.max(0, v))]);
    } else {
      // Handle [batch, features]
      this.mask = x.map(row => row.map(v => v > 0));
      return x.map(row => row.map(v => Math.max(0, v)));
    }
  }

  backward(grad){
    if (this.originalShape === '3d') {
      return grad.map((sample, i) =>
        [sample[0].map((v, j) => this.mask[i][j] ? v : 0)]
      );
    } else {
      return grad.map((row, i) =>
        row.map((v, j) => this.mask[i][j] ? v : 0)
      );
    }
  }

  _getShapeType(x) {
    if (Array.isArray(x[0]) && Array.isArray(x[0][0]) && !Array.isArray(x[0][0][0])) {
      return '3d';
    } else if (Array.isArray(x[0]) && !Array.isArray(x[0][0])) {
      return '2d';
    } else {
      throw new Error(`Unsupported input shape for ReLU`);
    }
  }
}

export class Sigmoid {
  constructor(){ this.out = null; }
  forward(x){ const fn = v => 1 / (1 + Math.exp(-v)); this.out = x.map(r => r.map(fn)); return this.out; }
  backward(grad){ return grad.map((r, i) => r.map((v, j) => v * this.out[i][j] * (1 - this.out[i][j]))); }
}

export class Tanh {
  constructor(){ this.out = null; }
  forward(x){ this.out = x.map(r => r.map(v => Math.tanh(v))); return this.out; }
  backward(grad){ return grad.map((r, i) => r.map((v, j) => v * (1 - this.out[i][j] ** 2))); }
}

export class LeakyReLU {
  constructor(alpha = 0.01){ this.alpha = alpha; this.out = null; }
  forward(x){ this.out = x.map(r => r.map(v => v > 0 ? v : v * this.alpha)); return this.out; }
  backward(grad){ return grad.map((r, i) => r.map((v, j) => v * (this.out[i][j] > 0 ? 1 : this.alpha))); }
}

export class GELU {
  constructor(){ this.out = null; }
  forward(x){
    const fn = v => 0.5 * v * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (v + 0.044715 * v ** 3)));
    this.out = x.map(r => r.map(fn));
    return this.out;
  }
  // NOTE: passes the gradient through unchanged (identity), not the exact GELU derivative
  backward(grad){ return grad.map(r => r.map(v => v)); }
}
```
```js
// ---------------------- Dropout ----------------------
export class Dropout {
  constructor(p = 0.5){ this.p = p; }
  // Zero each activation independently with probability p
  forward(x){ return x.map(r => r.map(v => (Math.random() >= this.p ? v : 0))); }
  // Backward scales by the expected keep rate rather than replaying the sampled mask
  backward(grad){ return grad.map(r => r.map(v => v * (1 - this.p))); }
}
```
```js
// ---------------------- Losses ----------------------
export class MSELoss {
  forward(pred, target){
    this.pred = pred;
    this.target = target;
    const losses = pred.map((row, i) =>
      row.reduce((sum, v, j) => sum + (v - target[i][j]) ** 2, 0) / row.length);
    return losses.reduce((a, b) => a + b, 0) / pred.length;
  }
  backward(){
    return this.pred.map((row, i) =>
      row.map((v, j) => 2 * (v - this.target[i][j]) / row.length));
  }
}

export class CrossEntropyLoss {
  forward(pred, target){
    this.pred = pred;
    this.target = target;
    const losses = pred.map((p, i) => crossEntropy(softmax(p), target[i]));
    return losses.reduce((a, b) => a + b, 0) / pred.length;
  }
  backward(){
    return this.pred.map((p, i) => {
      const s = softmax(p);
      return s.map((v, j) => (v - this.target[i][j]) / this.pred.length);
    });
  }
}
```
```js
// ---------------------- Optimizers ----------------------
export class Adam {
  constructor(params, lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-8, max_grad_norm = 1.0){
    // Handle both parameter styles: (params, lr) OR (params, {lr, ...})
    if (typeof lr === 'object') {
      // Options object provided
      const options = lr;
      this.lr = options.lr || 0.001;
      this.beta1 = options.b1 || options.beta1 || 0.9;
      this.beta2 = options.b2 || options.beta2 || 0.999;
      this.eps = options.eps || 1e-8;
      this.max_grad_norm = options.max_grad_norm || 1.0;
    } else {
      // Individual parameters provided
      this.lr = lr;
      this.beta1 = b1;
      this.beta2 = b2;
      this.eps = eps;
      this.max_grad_norm = max_grad_norm;
    }

    this.params = params;
    this.m = params.map(p => zeros(p.param.length, p.param[0].length || 1));
    this.v = params.map(p => zeros(p.param.length, p.param[0].length || 1));
    this.t = 0;
  }

  step(){
    this.t++;
    this.params.forEach((p, idx) => {
      // Calculate gradient norm for clipping
      let grad_norm_sq = 0;
      for (let i = 0; i < p.param.length; i++){
        for (let j = 0; j < (p.param[0].length || 1); j++){
          const grad_val = p.grad[i] && p.grad[i][j] !== undefined ? p.grad[i][j] : 0;
          grad_norm_sq += grad_val * grad_val;
        }
      }

      const grad_norm = Math.sqrt(grad_norm_sq);
      const clip_scale = grad_norm > this.max_grad_norm ? this.max_grad_norm / grad_norm : 1.0;

      // Update with clipped gradients
      for (let i = 0; i < p.param.length; i++){
        for (let j = 0; j < (p.param[0].length || 1); j++){
          if (p.grad[i] && p.grad[i][j] !== undefined){
            const g = p.grad[i][j] * clip_scale;
            this.m[idx][i][j] = this.beta1 * this.m[idx][i][j] + (1 - this.beta1) * g;
            this.v[idx][i][j] = this.beta2 * this.v[idx][i][j] + (1 - this.beta2) * g * g;
            const mHat = this.m[idx][i][j] / (1 - Math.pow(this.beta1, this.t));
            const vHat = this.v[idx][i][j] / (1 - Math.pow(this.beta2, this.t));
            p.param[i][j] -= this.lr * mHat / (Math.sqrt(vHat) + this.eps);
          }
        }
      }
    });
  }
}
```
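`Adam` accepts either positional arguments or an options object, as the constructor above shows; a brief usage sketch (import path assumed, and note the options form falls back to defaults via `||`, so an explicit `0` would be ignored):

```js
import { Linear, Adam } from "mini-jstorch";

const layer = new Linear(4, 2);
const optA = new Adam(layer.parameters(), 0.001);                  // positional form
const optB = new Adam(layer.parameters(), { lr: 0.001, b2: 0.999,  // options-object form
                                            max_grad_norm: 1.0 });
layer.forward([[1, 2, 3, 4]]);
layer.backward([[1, 1]]);
optA.step(); // applies the bias-corrected, norm-clipped update in place
```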
```js
export class SGD {
  constructor(params, lr = 0.01, max_grad_norm = 1.0) {
    this.params = params;
    this.lr = lr;
    this.max_grad_norm = max_grad_norm; // Gradient clipping threshold
  }

  step() {
    this.params.forEach(p => {
      // Calculate gradient norm
      let grad_norm_sq = 0;
      let total_params = 0;

      for (let i = 0; i < p.param.length; i++){
        const row = p.param[i];
        for (let j = 0; j < (row.length || 1); j++) {
          const grad_val = p.grad[i] && p.grad[i][j] !== undefined ? p.grad[i][j] : 0;
          grad_norm_sq += grad_val * grad_val;
          total_params++;
        }
      }

      const grad_norm = Math.sqrt(grad_norm_sq);

      // Apply gradient clipping if needed
      const clip_scale = grad_norm > this.max_grad_norm ? this.max_grad_norm / grad_norm : 1.0;

      // Update parameters with clipped gradients
      for (let i = 0; i < p.param.length; i++){
        const row = p.param[i];
        for (let j = 0; j < (row.length || 1); j++) {
          if (p.grad[i] && p.grad[i][j] !== undefined){
            p.param[i][j] -= this.lr * (p.grad[i][j] * clip_scale);
          }
        }
      }
    });
  }
}
```
```js
export class LION {
  constructor(params, options = {}) {
    this.params = params;

    const {
      lr = 0.0001,        // Lion typically uses a smaller LR
      beta1 = 0.9,        // First moment decay
      beta2 = 0.99,       // Second moment decay
      weight_decay = 0,   // L2 regularization
      eps = 1e-8          // Numerical stability
    } = options;

    this.lr = lr;
    this.beta1 = beta1;
    this.beta2 = beta2;
    this.weight_decay = weight_decay;
    this.eps = eps;

    // Initialize momentums
    this.m = params.map(p => zeros(p.param.length, p.param[0].length || 1));
    this.t = 0;
  }

  step() {
    this.t++;

    this.params.forEach((p, idx) => {
      for (let i = 0; i < p.param.length; i++) {
        for (let j = 0; j < (p.param[0].length || 1); j++) {
          if (p.grad[i] && p.grad[i][j] !== undefined) {
            const grad = p.grad[i][j];

            // Update momentum: m_t = β1 * m_{t-1} + (1 - β1) * g_t
            this.m[idx][i][j] = this.beta1 * this.m[idx][i][j] + (1 - this.beta1) * grad;

            // Lion update: param = param - η * sign(m_t + β2 * g_t)
            const update_term = this.m[idx][i][j] + this.beta2 * grad;

            // Get sign with an epsilon dead-zone for stability
            let sign_val;
            if (update_term > this.eps) sign_val = 1;
            else if (update_term < -this.eps) sign_val = -1;
            else sign_val = 0;

            let update = sign_val * this.lr;

            // Add weight decay if specified
            if (this.weight_decay > 0) {
              update += this.weight_decay * this.lr * p.param[i][j];
            }

            p.param[i][j] -= update;
          }
        }
      }
    });
  }

  zeroGrad() {
    this.params.forEach(p => {
      if (p.grad) {
        for (let i = 0; i < p.grad.length; i++) {
          for (let j = 0; j < p.grad[i].length; j++) {
            p.grad[i][j] = 0;
          }
        }
      }
    });
  }
}
```
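A usage sketch for `LION`. Note that, as the comments in `step()` indicate, this implementation applies `sign(m + beta2 * g)` with an epsilon dead-zone, which reads as a variant of the published Lion update rather than an exact reproduction; treat the semantics as this file defines them:

```js
import { Linear, LION } from "mini-jstorch"; // import path assumed

const layer = new Linear(4, 2);
const opt = new LION(layer.parameters(), { lr: 1e-4, weight_decay: 0.01 });

layer.forward([[0.5, -0.5, 1.0, 2.0]]);
layer.backward([[1, 0]]);
opt.step();     // each weight moves by ±lr (plus decay), regardless of gradient magnitude
opt.zeroGrad(); // LION ships its own gradient reset
```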
```js
// ---------------------- Learning Rate Schedulers ----------------------
export class StepLR {
  constructor(optimizer, step_size, gamma = 1.0) {
    this.optimizer = optimizer;
    this.step_size = step_size;
    this.gamma = gamma; // note: the default of 1.0 means no decay unless set explicitly
    this.last_epoch = 0;
    this.base_lr = optimizer.lr;
  }

  step() {
    this.last_epoch += 1;
    if (this.last_epoch % this.step_size === 0) {
      this.optimizer.lr *= this.gamma;
    }
  }

  get_lr() {
    return this.optimizer.lr;
  }
}

export class LambdaLR {
  constructor(optimizer, lr_lambda) {
    this.optimizer = optimizer;
    this.lr_lambda = lr_lambda;
    this.last_epoch = 0;
    this.base_lr = optimizer.lr;
  }

  step() {
    this.last_epoch += 1;
    this.optimizer.lr = this.base_lr * this.lr_lambda(this.last_epoch);
  }

  get_lr() {
    return this.optimizer.lr;
  }
}
```
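A scheduler sketch (import path assumed): `StepLR` multiplies the current LR by `gamma` every `step_size` epochs — pass `gamma` explicitly, since its default of 1.0 is a no-op — while `LambdaLR` rescales the *base* LR captured at construction:

```js
import { Linear, SGD, StepLR, LambdaLR } from "mini-jstorch";

const opt = new SGD(new Linear(2, 2).parameters(), 0.1);

const stepper = new StepLR(opt, 10, 0.5);
for (let epoch = 1; epoch <= 20; epoch++) stepper.step();
console.log(stepper.get_lr()); // 0.1 * 0.5^2 = 0.025

const lambda = new LambdaLR(opt, e => 1 / (1 + 0.1 * e));
lambda.step();
console.log(lambda.get_lr()); // base_lr (0.025 here) scaled by the lambda at epoch 1
```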
```js
// ---------------------- ReduceLROnPlateau Scheduler ----------------------
export class ReduceLROnPlateau {
  constructor(optimizer, options = {}) {
    this.optimizer = optimizer;

    // Destructure with defaults
    const {
      patience = 10,
      factor = 0.5,
      min_lr = 1e-6,
      threshold = 1e-4,
      cooldown = 0,
      verbose = false
    } = options;

    this.patience = patience;
    this.factor = factor;
    this.min_lr = min_lr;
    this.threshold = threshold;
    this.cooldown = cooldown;
    this.verbose = verbose;

    // State tracking
    this.bestLoss = Infinity;
    this.wait = 0;
    this.cooldown_counter = 0;
    this.num_reductions = 0;
  }

  step(loss) {
    // Handle cooldown
    if (this.cooldown_counter > 0) {
      this.cooldown_counter--;
      return;
    }

    // Check if this is a significant improvement (relative threshold)
    const improvement_needed = this.bestLoss * (1 - this.threshold);
    const is_better = loss < improvement_needed;

    if (is_better) {
      // Significant improvement - reset
      this.bestLoss = loss;
      this.wait = 0;
    } else {
      // No significant improvement
      this.wait += 1;
    }

    // Check if we've waited long enough
    if (this.wait >= this.patience) {
      this._reduce_lr();
      this.cooldown_counter = this.cooldown;
      this.wait = 0;
    }
  }

  _reduce_lr() {
    const old_lr = this.optimizer.lr;
    const new_lr = Math.max(old_lr * this.factor, this.min_lr);

    if (new_lr < old_lr) {
      this.optimizer.lr = new_lr;
      this.num_reductions++;

      if (this.verbose) {
        console.log(`ReduceLROnPlateau: reducing LR from ${old_lr} to ${new_lr}`);
      }
    }
  }

  get_last_lr() {
    return this.optimizer.lr;
  }

  reset() {
    this.bestLoss = Infinity;
    this.wait = 0;
    this.cooldown_counter = 0;
    this.num_reductions = 0;
  }
}
```
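`ReduceLROnPlateau` watches a loss you feed it each epoch; once `patience` consecutive epochs pass without a relative improvement of at least `threshold`, the LR is multiplied by `factor`. A minimal sketch (import path assumed):

```js
import { Linear, SGD, ReduceLROnPlateau } from "mini-jstorch";

const opt = new SGD(new Linear(2, 2).parameters(), 0.1);
const plateau = new ReduceLROnPlateau(opt, { patience: 3, factor: 0.5, verbose: true });

const losses = [1.0, 0.99999, 0.99998, 0.99997, 0.99996]; // effectively flat
for (const l of losses) plateau.step(l);
console.log(plateau.get_last_lr()); // 0.05: reduced once the patience budget is spent
```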
```js
// ---------------------- ELU Activation ----------------------
export class ELU {
  constructor(alpha = 1.0) {
    this.alpha = alpha;
    this.out = null;
  }

  forward(x) {
    this.out = x.map(row =>
      row.map(v => v > 0 ? v : this.alpha * (Math.exp(v) - 1))
    );
    return this.out;
  }

  backward(grad) {
    // d/dv ELU(v) = 1 for v > 0, alpha * e^v otherwise; since out = alpha * (e^v - 1)
    // on the negative branch, alpha * e^v can be recovered as out + alpha
    return grad.map((row, i) =>
      row.map((v, j) =>
        v * (this.out[i][j] > 0 ? 1 : this.out[i][j] + this.alpha)
      )
    );
  }
}

// ---------------------- Mish Activation ----------------------
export class Mish {
  constructor() {
    this.x = null;
  }

  forward(x) {
    this.x = x;
    return x.map(row =>
      row.map(v => {
        // Mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x))
        const softplus = Math.log(1 + Math.exp(v));
        return v * Math.tanh(softplus);
      })
    );
  }

  backward(grad) {
    return grad.map((row, i) =>
      row.map((v, j) => {
        const x_val = this.x[i][j];

        // Gradient of Mish:
        // δ = ω * (4(x+1) + 4e^{2x} + e^{3x} + e^x(4x+6)) / (2e^x + e^{2x} + 2)^2
        // where ω = sech²(softplus(x))
        const exp_x = Math.exp(x_val);
        const exp_2x = Math.exp(2 * x_val);
        const exp_3x = Math.exp(3 * x_val);
        const softplus = Math.log(1 + exp_x);

        const sech_softplus = 1 / Math.cosh(softplus);
        const numerator = 4 * (x_val + 1) + 4 * exp_2x + exp_3x + exp_x * (4 * x_val + 6);
        const denominator = Math.pow(2 * exp_x + exp_2x + 2, 2);

        const mish_grad = (sech_softplus * sech_softplus) * (numerator / denominator);
        return v * mish_grad;
      })
    );
  }
}

// ---------------------- SiLU Activation ----------------------
export class SiLU {
  constructor() {
    this.x = null;
  }

  forward(x) {
    this.x = x;
    return x.map(row =>
      row.map(v => v / (1 + Math.exp(-v))) // x * sigmoid(x)
    );
  }

  backward(grad) {
    return grad.map((row, i) =>
      row.map((v, j) => {
        const x_val = this.x[i][j];
        const sigmoid = 1 / (1 + Math.exp(-x_val));
        return v * (sigmoid * (1 + x_val * (1 - sigmoid)));
      })
    );
  }
}
```
```js
// ---------------------- BatchNorm2d ----------------------
export class BatchNorm2d {
  constructor(numFeatures, eps = 1e-5, momentum = 0.1, affine = true) {
    this.numFeatures = numFeatures;
    this.eps = eps;
    this.momentum = momentum;
    this.affine = affine;

    // Parameters
    if (affine) {
      this.weight = Array(numFeatures).fill(1);
      this.bias = Array(numFeatures).fill(0);
      this.gradWeight = Array(numFeatures).fill(0);
      this.gradBias = Array(numFeatures).fill(0);
    }

    // Running statistics
    this.runningMean = Array(numFeatures).fill(0);
    this.runningVar = Array(numFeatures).fill(1);

    // Training state
    this.training = true;
    this.x = null;
    this.xCentered = null;
    this.std = null;
  }

  forward(x) {
    // x shape: [batch, channels, height, width]
    this.x = x;
    const batchSize = x.length;
    const channels = x[0].length;

    if (this.training) {
      // Calculate mean per channel
      const means = Array(channels).fill(0);
      for (let b = 0; b < batchSize; b++) {
        for (let c = 0; c < channels; c++) {
          const channelData = x[b][c];
          let sum = 0;
          for (let i = 0; i < channelData.length; i++) {
            for (let j = 0; j < channelData[0].length; j++) {
              sum += channelData[i][j];
            }
          }
          means[c] += sum / (channelData.length * channelData[0].length);
        }
      }
      means.forEach((_, c) => means[c] /= batchSize);

      // Calculate variance per channel
      const variances = Array(channels).fill(0);
      for (let b = 0; b < batchSize; b++) {
        for (let c = 0; c < channels; c++) {
          const channelData = x[b][c];
          let sum = 0;
          for (let i = 0; i < channelData.length; i++) {
            for (let j = 0; j < channelData[0].length; j++) {
              sum += Math.pow(channelData[i][j] - means[c], 2);
            }
          }
          variances[c] += sum / (channelData.length * channelData[0].length);
        }
      }
      variances.forEach((_, c) => variances[c] /= batchSize);

      // Update running statistics
      for (let c = 0; c < channels; c++) {
        this.runningMean[c] = this.momentum * means[c] + (1 - this.momentum) * this.runningMean[c];
        this.runningVar[c] = this.momentum * variances[c] + (1 - this.momentum) * this.runningVar[c];
      }

      // Normalize
      this.xCentered = [];
      this.std = Array(channels).fill(0).map(() => []);

      const output = [];
      for (let b = 0; b < batchSize; b++) {
        const batchOut = [];
        for (let c = 0; c < channels; c++) {
          const channelData = x[b][c];
          const channelOut = zeros(channelData.length, channelData[0].length);
          const channelCentered = zeros(channelData.length, channelData[0].length);
          const channelStd = Math.sqrt(variances[c] + this.eps);
          this.std[c].push(channelStd);

          for (let i = 0; i < channelData.length; i++) {
            for (let j = 0; j < channelData[0].length; j++) {
              channelCentered[i][j] = channelData[i][j] - means[c];
              channelOut[i][j] = channelCentered[i][j] / channelStd;

              // Apply affine transformation if enabled
              if (this.affine) {
                channelOut[i][j] = channelOut[i][j] * this.weight[c] + this.bias[c];
              }
            }
          }

          batchOut.push(channelOut);
          if (b === 0) this.xCentered.push(channelCentered);
          else this.xCentered[c] = addMatrices(this.xCentered[c], channelCentered);
        }
        output.push(batchOut);
      }

      return output;
    } else {
      // Inference mode - use running statistics
      const output = [];
      for (let b = 0; b < batchSize; b++) {
        const batchOut = [];
        for (let c = 0; c < channels; c++) {
          const channelData = x[b][c];
          const channelOut = zeros(channelData.length, channelData[0].length);
          const channelStd = Math.sqrt(this.runningVar[c] + this.eps);

          for (let i = 0; i < channelData.length; i++) {
            for (let j = 0; j < channelData[0].length; j++) {
              channelOut[i][j] = (channelData[i][j] - this.runningMean[c]) / channelStd;

              // Apply affine transformation if enabled
              if (this.affine) {
                channelOut[i][j] = channelOut[i][j] * this.weight[c] + this.bias[c];
              }
            }
          }

          batchOut.push(channelOut);
        }
        output.push(batchOut);
      }

      return output;
    }
  }

  // Simplified backward: ignores the gradient terms that flow through the
  // batch mean and variance, propagating only the direct 1/std scaling
  backward(gradOutput) {
    if (!this.training) {
      throw new Error("Backward should only be called in training mode");
    }

    const batchSize = gradOutput.length;
    const channels = gradOutput[0].length;

    // Initialize gradients
    const gradInput = this.x.map(batch =>
      batch.map(channel =>
        zeros(channel.length, channel[0].length)
      )
    );

    if (this.affine) {
      this.gradWeight.fill(0);
      this.gradBias.fill(0);
    }

    for (let c = 0; c < channels; c++) {
      let sumGradWeight = 0;
      let sumGradBias = 0;

      for (let b = 0; b < batchSize; b++) {
        const channelGrad = gradOutput[b][c];
        const channelData = this.x[b][c];

        // Calculate gradients for bias and weight
        if (this.affine) {
          for (let i = 0; i < channelGrad.length; i++) {
            for (let j = 0; j < channelGrad[0].length; j++) {
              sumGradBias += channelGrad[i][j];
              sumGradWeight += channelGrad[i][j] * (this.xCentered[c][i][j] / this.std[c][b]);
            }
          }
        }

        // Calculate gradient for input
        const n = channelData.length * channelData[0].length;
        const stdInv = 1 / this.std[c][b];

        for (let i = 0; i < channelGrad.length; i++) {
          for (let j = 0; j < channelGrad[0].length; j++) {
            let grad = channelGrad[i][j];

            if (this.affine) {
              grad *= this.weight[c];
            }

            grad *= stdInv;
            gradInput[b][c][i][j] = grad;
          }
        }
      }

      if (this.affine) {
        this.gradWeight[c] = sumGradWeight / batchSize;
        this.gradBias[c] = sumGradBias / batchSize;
      }
    }

    return gradInput;
  }

  parameters() {
    if (!this.affine) return [];
    return [
      { param: [this.weight], grad: [this.gradWeight] },
      { param: [this.bias], grad: [this.gradBias] }
    ];
  }

  train() { this.training = true; }
  eval() { this.training = false; }
}
```
```js
// ---------------------- Model Save/Load ----------------------
export function saveModel(model){
  if (!(model instanceof Sequential)) throw new Error("saveModel supports only Sequential");
  // Layers without W/b (activations, Flatten, ...) serialize as nulls
  const weights = model.layers.map(layer => ({weights: layer.W || null, biases: layer.b || null}));
  return JSON.stringify(weights);
}

export function loadModel(model, json){
  if (!(model instanceof Sequential)) throw new Error("loadModel supports only Sequential");
  const weights = JSON.parse(json);
  model.layers.forEach((layer, i) => {
    if (layer.W && weights[i].weights) layer.W = weights[i].weights;
    if (layer.b && weights[i].biases) layer.b = weights[i].biases;
  });
}
```
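A round-trip sketch (import path assumed): serialize `Linear` weights and biases to JSON, then restore them into an identically shaped model. Only layers exposing `W`/`b` are captured:

```js
import { Sequential, Linear, saveModel, loadModel } from "mini-jstorch";

const src = new Sequential([new Linear(2, 4), new Linear(4, 1)]);
const json = saveModel(src);

const dst = new Sequential([new Linear(2, 4), new Linear(4, 1)]);
loadModel(dst, json);
console.log(JSON.stringify(dst.layers[0].W) === JSON.stringify(src.layers[0].W)); // true
```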
```js
// ---------------------- Advanced Utils ----------------------
export function flattenBatch(batch){ return batch.flat(2); }
export function stack(tensors){ return tensors.map(t => t.data); }
export function eye(n){ return Array.from({length: n}, (_, i) => Array.from({length: n}, (_, j) => i === j ? 1 : 0)); }
// concat along axis 0 (rows) or axis 1 (columns); other axes are unsupported
export function concat(a, b, axis = 0){ if (axis === 0) return [...a, ...b]; if (axis === 1) return a.map((row, i) => [...row, ...b[i]]); }
export function reshape(tensor, rows, cols) {
  let flat = tensor.data.flat(); // flatten first
  if (flat.length < rows * cols) throw new Error("reshape size mismatch");
  // note: any elements beyond rows*cols are silently dropped
  const out = Array.from({length: rows}, (_, i) =>
    flat.slice(i * cols, i * cols + cols)
  );
  return out;
}
```