mini-jstorch 2.0.1 → 2.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/Docs/API.md ADDED
@@ -0,0 +1,277 @@
1
+ # API Reference
2
+
3
+ ## Model Container
4
+
5
+ ### Sequential
6
+
7
+ ```js
8
+ new Sequential(layers: Layer[])
9
+ ```
10
+
11
+ Container that chains layers sequentially.
12
+
13
+ **Methods:**
14
+ - forward(x) — Pass input through all layers
15
+ - backward(grad) — Backpropagate gradient through all layers
16
+ - parameters() — Returns [{param, grad}, ...] for all trainable parameters
17
+ - zeroGrad() — Zero all parameter gradients
18
+ - train() — Set all layers to training mode
19
+ - eval() — Set all layers to evaluation mode
20
+ - stateDict() — Get {layer_0.weight, layer_0.bias, ...}
21
+ - loadStateDict(dict) — Load weights from state dict object
22
+ - step(lr) — Apply SGD step to all layers directly
23
+
24
+ ---
25
+
26
+ ## Layers
27
+
28
+ ### Linear
29
+
30
+ ```js
31
+ new Linear(inFeatures: number, outFeatures: number)
32
+ ```
33
+
34
+ Fully connected layer. Weight shape: [inFeatures, outFeatures]. Bias shape: [1, outFeatures].
35
+
36
+ ### Conv2D (experimental)
37
+
38
+ ```js
39
+ new Conv2D(inChannels: number, outChannels: number, kernelSize: number, stride?: number, padding?: number)
40
+ ```
41
+
42
+ 2D convolution layer. Input shape: [batch, channels, height, width].
43
+
44
+ ### Flatten
45
+
46
+ ```js
47
+ new Flatten()
48
+ ```
49
+
50
+ Flattens multi-dimensional input per sample. Preserves batch dimension.
51
+
52
+ ---
53
+
54
+ ## Activations
55
+
56
+ All activations follow the same interface:
57
+ - forward(x) — x: [batch, features], returns: [batch, features]
58
+ - backward(grad) — grad: [batch, features], returns: [batch, features]
59
+
60
+ | Class | Formula |
61
+ |-------|---------|
62
+ | ReLU() | max(0, x) |
63
+ | Sigmoid() | 1 / (1 + exp(-x)) |
64
+ | Tanh() | tanh(x) |
65
+ | LeakyReLU(alpha?) | x > 0 ? x : alpha * x |
66
+ | GELU() | 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) |
67
+ | ELU(alpha?) | x > 0 ? x : alpha * (exp(x) - 1) |
68
+ | Mish() | x * tanh(ln(1 + exp(x))) |
69
+ | SiLU() | x * sigmoid(x) |
70
+ | Softmax(dim?) | exp(x - max) / sum(exp(x - max)) |
71
+
72
+ ---
73
+
74
+ ## Loss Functions
75
+
76
+ ### MSELoss
77
+
78
+ ```js
79
+ new MSELoss()
80
+ ```
81
+
82
+ - loss = mean((pred - target)^2)
83
+ - gradient = 2 * (pred - target) / batchSize
84
+
85
+ ### SoftmaxCrossEntropyLoss (recommended for classification)
86
+
87
+ ```js
88
+ new SoftmaxCrossEntropyLoss()
89
+ ```
90
+
91
+ Combines softmax + cross-entropy in a numerically stable way. Input: logits (not probabilities). Do NOT combine with a Softmax layer.
92
+
93
+ ### BCEWithLogitsLoss (recommended for binary classification)
94
+
95
+ ```js
96
+ new BCEWithLogitsLoss()
97
+ ```
98
+
99
+ Combines sigmoid + binary cross-entropy. Numerically stable. Input: logits (not probabilities). Do NOT combine with a Sigmoid layer.
100
+
101
+ ### CrossEntropyLoss (deprecated)
102
+
103
+ ```js
104
+ new CrossEntropyLoss()
105
+ ```
106
+
107
+ Use SoftmaxCrossEntropyLoss instead. Exists for backward compatibility.
108
+
109
+ ---
110
+
111
+ ## Optimizers
112
+
113
+ ### Adam (recommended)
114
+
115
+ ```js
116
+ new Adam(parameters, options)
117
+ // or
118
+ new Adam(parameters, lr, beta1, beta2, eps, maxGradNorm)
119
+ ```
120
+
121
+ Options object: { lr: 0.001, b1: 0.9, b2: 0.999, eps: 1e-8, max_grad_norm: 1.0 }
122
+
123
+ ### AdamW
124
+
125
+ ```js
126
+ new AdamW(parameters, options)
127
+ ```
128
+
129
+ Adam with decoupled weight decay. Options include weight_decay (default: 0.01).
130
+
131
+ ### SGD
132
+
133
+ ```js
134
+ new SGD(parameters, lr?, maxGradNorm?)
135
+ ```
136
+
137
+ ### Lion
138
+
139
+ ```js
140
+ new LION(parameters, options)
141
+ ```
142
+
143
+ Memory-efficient optimizer.
144
+
145
+ **All optimizers have:**
146
+ - step() — Update parameters using accumulated gradients
147
+ - zeroGrad() — Zero all parameter gradients
148
+
149
+ ---
150
+
151
+ ## Learning Rate Schedulers
152
+
153
+ ### StepLR
154
+
155
+ ```js
156
+ new StepLR(optimizer, stepSize, gamma)
157
+ ```
158
+
159
+ Multiplies LR by gamma every stepSize steps.
160
+
161
+ ### LambdaLR
162
+
163
+ ```js
164
+ new LambdaLR(optimizer, fn)
165
+ ```
166
+
167
+ Sets LR = baseLr * fn(epoch) on each step.
168
+
169
+ ### ReduceLROnPlateau
170
+
171
+ ```js
172
+ new ReduceLROnPlateau(optimizer, options)
173
+ ```
174
+
175
+ Reduces LR when loss stops improving. Options: patience, factor, min_lr, threshold, cooldown, verbose.
176
+
177
+ ---
178
+
179
+ ## Regularization
180
+
181
+ ### Dropout
182
+
183
+ ```js
184
+ new Dropout(p?)
185
+ ```
186
+
187
+ Randomly zeroes a fraction p of neurons during training. Default p = 0.5. Call layer.train() / layer.eval() to toggle.
188
+
189
+ ### BatchNorm2d (experimental)
190
+
191
+ ```js
192
+ new BatchNorm2d(numFeatures, eps?, momentum?, affine?)
193
+ ```
194
+
195
+ 2D batch normalization. Input shape: [batch, channels, height, width].
196
+
197
+ ---
198
+
199
+ ## Tokenizer
200
+
201
+ ```js
202
+ new Tokenizer(vocabSize?)
203
+ ```
204
+
205
+ | Method | Description |
206
+ |--------|-------------|
207
+ | fit(texts) | Build vocabulary from text array |
208
+ | transform(texts, maxLength?, padToMax?) | Convert texts to token indices |
209
+ | fitTransform(texts, maxLength?, padToMax?) | Fit + transform in one call |
210
+ | inverseTransform(tokens, skipPad?) | Convert token indices back to text |
211
+ | getVocabulary() | Get vocabulary as string array |
212
+ | getVocabSize() | Get vocabulary size |
213
+ | getWordCounts() | Get word frequency map |
214
+ | getMostCommon(n?) | Get top N most frequent words |
215
+
216
+ ---
217
+
218
+ ## Utilities
219
+
220
+ - zeros(rows, cols) — Matrix filled with 0
221
+ - ones(rows, cols) — Matrix filled with 1
222
+ - randomMatrix(rows, cols, scale?) — Random matrix (Xavier init if scale omitted)
223
+ - transpose(matrix) — Matrix transpose
224
+ - dot(a, b) — Matrix multiplication
225
+ - addMatrices(a, b) — Element-wise addition
226
+ - softmax(vector) — Softmax on 1D array
227
+ - crossEntropy(pred, target) — Cross-entropy loss (scalar)
228
+ - reshape(tensor, rows, cols) — Reshape to new dimensions
229
+ - flattenBatch(batch) — Flatten batch to 2D
230
+ - concat(a, b, axis) — Concatenate along axis 0 or 1
231
+ - stack(tensors) — Stack tensors
232
+
233
+ ---
234
+
235
+ ## Model Persistence
236
+
237
+ - saveModel(model) — Serialize model to JSON string
238
+ - loadModel(model, json) — Load weights from JSON into model
239
+
240
+ Supports all layer types. Validates layer types and shapes. Logs warnings for mismatches.
241
+
242
+ ---
243
+
244
+ ## Tensor (Advanced)
245
+
246
+ ```js
247
+ new Tensor(data, requiresGrad?)
248
+ ```
249
+
250
+ | Method | Description |
251
+ |--------|-------------|
252
+ | add(tensor) | Element-wise addition |
253
+ | mul(tensor) | Element-wise multiplication |
254
+ | matmul(tensor) | Matrix multiplication |
255
+ | transpose() | Transpose tensor |
256
+ | flatten() | Flatten to 1D array |
257
+ | shape() | Returns [rows, cols] |
258
+
259
+ Static: Tensor.zeros(r,c), Tensor.ones(r,c), Tensor.random(r,c,scale?)
260
+
261
+ ---
262
+
263
+ ## User-Friendly Utilities (fu_*)
264
+
265
+ - fu_tensor(data, requiresGrad?) — Create tensor from 2D array
266
+ - fu_add(a, b) — Element-wise add
267
+ - fu_mul(a, b) — Element-wise multiply
268
+ - fu_matmul(a, b) — Matrix multiply
269
+ - fu_sum(tensor) — Sum all elements
270
+ - fu_mean(tensor) — Mean of all elements
271
+ - fu_relu(tensor) — ReLU activation
272
+ - fu_sigmoid(tensor) — Sigmoid activation
273
+ - fu_tanh(tensor) — Tanh activation
274
+ - fu_softmax(tensor) — Softmax activation
275
+ - fu_flatten(tensor) — Flatten to 1D
276
+ - fu_reshape(tensor, rows, cols) — Reshape tensor
277
+ - fu_stack(tensors) — Stack tensors
package/README.md CHANGED
@@ -1,4 +1,4 @@
1
- ## Mini-JSTorch (MAJOR UPDATE)
1
+ ## Mini-JSTorch (v2.0.2)
2
2
 
3
3
  ---
4
4
 
@@ -7,15 +7,14 @@ It runs in Node.js and modern browsers, with a simple API inspired by PyTorch-st
7
7
 
8
8
  This project prioritizes `clarity`, `numerical correctness`, and `accessibility` over performance or large-scale production use.
9
9
 
10
- In this version `2.0.0`, we introduce:
11
- - **Fixed Linear layer cache** (critical bug fix for training)
12
- - **Fixed GELU gradient calculation**
13
- - **Fixed MSELoss gradient scaling**
14
- - **Optimized Softmax gradient** (O(n²) → O(n))
15
- - **Improved Tokenizer** with proper PAD/UNK separation
16
- - **Added Sequential.zeroGrad(), train(), eval(), stateDict() methods**
10
+ ### Changelog
17
11
 
18
- ---
12
+ **v2.0.2:**
13
+ - **Fixed critical training bug:** Optimizers (Adam, SGD, AdamW, Lion) now correctly update Linear and Conv2D layer weights
14
+ - **Fixed BatchNorm2d:** Inference mode no longer produces NaN for multi-channel inputs
15
+ - **Fixed ELU activation:** Backward pass now uses correct derivative formula
16
+ - **Fixed saveModel/loadModel:** Now correctly saves and restores all layer types including Conv2D and BatchNorm2d
17
+ - **Fixed BatchNorm2d gradient zeroing:** gradWeight/gradBias now correctly reset between batches
19
18
 
20
19
  **⚠️ BREAKING CHANGES in v2.0.0:**
21
20
  - Tokenizer API: `tokenizeBatch()` → `transform()`, `detokenizeBatch()` → `inverseTransform()`
@@ -80,7 +79,6 @@ In Browser/Website:
80
79
  async function train() {
81
80
  const statusEl = document.getElementById('status');
82
81
  const logEl = document.getElementById('log');
83
-
84
82
  try {
85
83
  const model = new Sequential([
86
84
  new Linear(2, 16), new Tanh(),
@@ -107,13 +105,13 @@ In Browser/Website:
107
105
  }
108
106
  }
109
107
 
110
- statusEl.textContent = 'Done';
108
+ statusEl.textContent = 'Done';
111
109
  const preds = model.forward(X);
112
110
  document.getElementById('res').innerHTML = `<h4>Results:</h4>` +
113
111
  X.map((input, i) => `[${input}] -> <b>${preds[i][0].toFixed(4)}</b> (Target: ${y[i][0]})`).join('<br>');
114
112
 
115
113
  } catch (e) {
116
- statusEl.textContent = 'Error: ' + e.message;
114
+ statusEl.textContent = 'Error: ' + e.message;
117
115
  }
118
116
  }
119
117
  train();
@@ -319,11 +317,16 @@ console.log(`\nAccuracy: ${(correct / X.length * 100).toFixed(2)}%`);
319
317
  # Save & Load Models
320
318
 
321
319
  ```javascript
322
- // WARN: Error/Bug may be expected for this time!
323
- import { saveModel, loadModel, Sequential } from "./src/jstorch.js";
320
+ import { saveModel, loadModel, Sequential } from "./src/jstorch";
324
321
 
322
+ // Save trained model
325
323
  const json = saveModel(model);
326
- const model2 = new Sequential([...]); // same architecture
324
+
325
+ // Create fresh model with same architecture and load weights
326
+ const model2 = new Sequential([
327
+ new Linear(2, 16), new ReLU(),
328
+ new Linear(16, 1)
329
+ ]);
327
330
  loadModel(model2, json);
328
331
  ```
329
332
 
@@ -360,7 +363,7 @@ node demo/<fileNameInDemo>.js
360
363
 
361
364
  MIT License
362
365
 
363
- Copyright (c) 2024
366
+ Copyright (c) 2024-2025
364
367
  rizal-editors
365
368
 
366
369
  ---
@@ -39,4 +39,4 @@ finalPred.forEach((p, i) => {
39
39
  });
40
40
  console.log(`\nAverage Error: ${(totalError / X.length).toFixed(2)}`);
41
41
  console.log(`Weight (slope): ${model.layers[0].W[0][0].toFixed(4)} (expected: 2.0)`);
42
- console.log(`Bias: ${model.layers[0].b[0].toFixed(4)} (expected: 0.0)`);
42
+ console.log(`Bias: ${model.layers[0].b[0][0].toFixed(4)} (expected: 0.0)`);
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "mini-jstorch",
3
- "version": "2.0.1",
3
+ "version": "2.0.2",
4
4
  "type": "module",
5
5
  "description": "A lightweight JavaScript neural network library for learning AI concepts and rapid Frontend experimentation. PyTorch-inspired, zero dependencies, perfect for educational use.",
6
6
  "main": "index.js",
@@ -15,8 +15,7 @@
15
15
  "tiny-ml",
16
16
  "mini-neural-network",
17
17
  "mini-ml-library",
18
- "mini-js-ml",
19
- "educational-ml"
18
+ "mini-js-ml"
20
19
  ],
21
20
  "author": "Rizal",
22
21
  "license": "MIT"
package/src/jstorch.js CHANGED
@@ -81,62 +81,26 @@ export function crossEntropy(pred,target){
81
81
 
82
82
  // ---------------------- USERS FRIENDLY UTILS (USE THIS FOR YOUR UTILS!) ----------------
83
83
  export function fu_tensor(data, requiresGrad = false) {
84
- if (!Array.isArray(data) || !Array.isArray(data[0])) {
85
- throw new Error("fu_tensor: Data must be 2D array");
86
- }
87
- const tensor = new Tensor(data);
88
- tensor.requiresGrad = requiresGrad;
89
- return tensor;
84
+ if (!Array.isArray(data) || !Array.isArray(data[0])){
85
+ throw new Error("fu_tensor: Data must be 2D array");
86
+ }
87
+ return new Tensor(data, requiresGrad);
90
88
  }
91
89
 
92
- // fu_add
93
- export function fu_add(a, b) {
94
- if (!(a instanceof Tensor) && !(b instanceof Tensor)) {
95
- throw new Error("fu_add: At least one operand must be Tensor");
96
- }
97
-
98
- if (!(a instanceof Tensor)) {
99
- a = fu_tensor(Array(b.shape()[0]).fill().map(() =>
100
- Array(b.shape()[1]).fill(a)
101
- ));
102
- }
103
-
104
- if (!(b instanceof Tensor)) {
105
- b = fu_tensor(Array(a.shape()[0]).fill().map(() =>
106
- Array(a.shape()[1]).fill(b)
107
- ));
108
- }
109
-
110
- if (a.shape()[0] !== b.shape()[0] || a.shape()[1] !== b.shape()[1]) {
111
- throw new Error(`fu_add: Shape mismatch ${a.shape()} vs ${b.shape()}`);
112
- }
113
-
114
- return new Tensor(a.data.map((r, i) => r.map((v, j) => v + b.data[i][j])));
90
+ // fu_add
91
+ export function fu_add(a, b){
92
+ if (!(a instanceof Tensor)) a = fu_tensor(a);
93
+ if (!(b instanceof Tensor)) b = fu_tensor(b);
94
+
95
+ return a.add(b);
115
96
  }
116
97
 
117
98
  // fu_mul
118
- export function fu_mul(a, b) {
119
- if (!(a instanceof Tensor) && !(b instanceof Tensor)) {
120
- throw new Error("fu_mul: At least one operand must be Tensor");
121
- }
122
-
123
- if (!(a instanceof Tensor)) {
124
- a = fu_tensor(Array(b.shape()[0]).fill().map(() =>
125
- Array(b.shape()[1]).fill(a)
126
- ));
127
- }
128
-
129
- if (!(b instanceof Tensor)) {
130
- b = fu_tensor(Array(a.shape()[0]).fill().map(() =>
131
- Array(a.shape()[1]).fill(b)
132
- ));
133
- }
134
-
135
- if (a.shape()[0] !== b.shape()[0] || a.shape()[1] !== b.shape()[1]) {
136
- throw new Error(`fu_mul: Shape mismatch ${a.shape()} vs ${b.shape()}`);
137
- }
138
-
139
- return new Tensor(a.data.map((r, i) => r.map((v, j) => v * b.data[i][j])));
99
+ export function fu_mul(a, b){
100
+ if (!(a instanceof Tensor)) a = fu_tensor(a);
101
+ if (!(b instanceof Tensor)) b = fu_tensor(b);
102
+
103
+ return a.mul(b);
140
104
  }
141
105
 
142
106
  // fu_matmul
@@ -144,11 +108,7 @@ export function fu_matmul(a, b) {
144
108
  if (!(a instanceof Tensor)) a = fu_tensor(a);
145
109
  if (!(b instanceof Tensor)) b = fu_tensor(b);
146
110
 
147
- if (a.shape()[1] !== b.shape()[0]) {
148
- throw new Error(`fu_matmul: Inner dimension mismatch ${a.shape()[1]} vs ${b.shape()[0]}`);
149
- }
150
-
151
- return new Tensor(dot(a.data, b.data));
111
+ return a.matmul(b);
152
112
  }
153
113
 
154
114
  // fu_sum
@@ -344,276 +304,144 @@ export class Tensor {
344
304
 
345
305
  // ---------------------- Layers ----------------------
346
306
  export class Linear {
347
- constructor(inputDim, outputDim){
348
- this.W = randomMatrix(inputDim, outputDim);
349
- this.b = Array(outputDim).fill(0);
350
- this.gradW = zeros(inputDim, outputDim);
351
- this.gradb = Array(outputDim).fill(0);
352
- this.x = null;
353
- this.originalShape = null;
354
-
355
- this._WFlat = null;
356
- this._bFlat = null;
307
+ constructor(inFeatures, outFeatures) {
308
+ this.inFeatures = inFeatures;
309
+ this.outFeatures = outFeatures;
310
+
311
+ // Weights: [inFeatures, outFeatures]
312
+ this.W = randomMatrix(inFeatures, outFeatures);
313
+ this.gradW = zeros(inFeatures, outFeatures);
314
+
315
+ // Bias: [1, outFeatures]
316
+ this.b = [Array(outFeatures).fill(0)];
317
+ this.gradB = [Array(outFeatures).fill(0)];
318
+
319
+ this.x = null; // cache input
357
320
  }
358
-
359
- _updateCache() {
360
- const rows = this.W.length;
361
- const cols = this.W[0].length;
362
- this._WFlat = new Float32Array(rows * cols);
363
- for (let i = 0; i < rows; i++){
364
- const offset = i * cols;
365
- const row = this.W[i];
366
- for (let j = 0; j < cols; j++){
367
- this._WFlat[offset + j] = row[j];
368
- }
369
- }
370
- this._bFlat = new Float32Array(this.b);
371
- }
372
321
 
373
- forward(x){
374
- this.originalShape = this._getShapeType(x);
375
-
376
- if (this.originalShape === '3d') {
377
- this.x = x.map(sample => sample[0]);
378
- } else {
379
- this.x = x;
380
- }
381
-
382
- this._updateCache();
383
-
384
- const m = this.x.length;
385
- const k = this.x[0].length;
386
- const n = this.W[0].length;
387
-
388
- if (!this._WFlat) {
389
- const rows = this.W.length;
390
- const cols = this.W[0].length;
391
- this._WFlat = new Float32Array(rows * cols);
392
- for (let i = 0; i < rows; i++) {
393
- const offset = i * cols;
394
- const row = this.W[i];
395
- for (let j = 0; j < cols; j++) {
396
- this._WFlat[offset + j] = row[j];
397
- }
322
+ forward(x) {
323
+ // x: [batch, inFeatures]
324
+ this.x = x;
325
+
326
+ const out = dot(x, this.W); // [batch, outFeatures]
327
+
328
+ // add bias
329
+ for (let i = 0; i < out.length; i++) {
330
+ for (let j = 0; j < this.outFeatures; j++) {
331
+ out[i][j] += this.b[0][j];
398
332
  }
399
- this._bFlat = new Float32Array(this.b);
400
333
  }
334
+
335
+ return out;
336
+ }
337
+
338
+ backward(grad) {
339
+ // grad: [batch, outFeatures]
340
+ const batchSize = grad.length;
401
341
 
402
- // Flatten input x to Float32Array
403
- const xFlat = new Float32Array(m * k);
404
- for (let i = 0; i < m; i++) {
405
- const row = this.x[i];
406
- const offset = i * k;
407
- for (let j = 0; j < k; j++) {
408
- xFlat[offset + j] = row[j];
342
+ // mutate gradW in place so the reference handed out by parameters() stays valid for the optimizer
343
+ const xT = transpose(this.x);
344
+ const computedGradW = dot(xT, grad);
345
+ for (let i=0; i<this.inFeatures; i++){
346
+ for (let j=0; j<this.outFeatures; j++){
347
+ this.gradW[i][j] = computedGradW[i][j]
409
348
  }
410
349
  }
411
350
 
412
- const outFlat = new Float32Array(m * n);
413
- for (let i = 0; i < m; i++) {
414
- const xOffset = i * k;
415
- for (let j = 0; j < n; j++) {
416
- let sum = 0;
417
- for (let l = 0; l < k; l++) {
418
- sum += xFlat[xOffset + l] * this._WFlat[l * n + j];
419
- }
420
- outFlat[i * n + j] = sum + this._bFlat[j];
351
+ // gradB = sum over batch
352
+ for (let j=0; j<this.outFeatures; j++){
353
+ let sum = 0;
354
+ for (let i=0; i <batchSize; i++){
355
+ sum+=grad[i][j]
421
356
  }
357
+ this.gradB[0][j] = sum;
422
358
  }
423
359
 
424
- const out = Array(m);
425
- for (let i = 0; i < m; i++) {
426
- const row = Array(n);
427
- const offset = i * n;
428
- for (let j = 0; j < n; j++) {
429
- row[j] = outFlat[offset + j];
430
- }
431
- out[i] = row;
432
- }
360
+ const WT = transpose(this.W);
361
+ const gradInput = dot(grad, WT);
433
362
 
434
- return out;
363
+ return gradInput;
435
364
  }
436
365
 
437
- backward(grad){
438
- const m = this.x.length;
439
- const k = this.W.length; // input dim
440
- const n = this.W[0].length; // output dim
441
-
442
- // Convert grad to Float32Array
443
- const gradFlat = new Float32Array(m * n);
444
- for (let i = 0; i < m; i++) {
445
- const row = grad[i];
446
- const offset = i * n;
447
- for (let j = 0; j < n; j++) {
448
- gradFlat[offset + j] = row[j];
449
- }
450
- }
451
-
452
- // Convert x to Float32Array
453
- const xFlat = new Float32Array(m * k);
454
- for (let i = 0; i < m; i++) {
455
- const row = this.x[i];
456
- const offset = i * k;
457
- for (let j = 0; j < k; j++) {
458
- xFlat[offset + j] = row[j];
459
- }
460
- }
461
-
462
- // Reset gradW
463
- for (let i = 0; i < this.gradW.length; i++) {
464
- for (let j = 0; j < this.gradW[0].length; j++) {
465
- this.gradW[i][j] = 0;
466
- }
467
- }
468
-
469
- // Compute gradW = x^T * grad
470
- for (let i = 0; i < k; i++) {
471
- for (let j = 0; j < n; j++) {
472
- let sum = 0;
473
- for (let batch = 0; batch < m; batch++) {
474
- sum += xFlat[batch * k + i] * gradFlat[batch * n + j];
475
- }
476
- this.gradW[i][j] = sum;
477
- }
478
- }
479
-
480
- // Compute gradb
481
- for (let j = 0; j < n; j++) {
482
- let sum = 0;
483
- for (let batch = 0; batch < m; batch++) {
484
- sum += gradFlat[batch * n + j];
485
- }
486
- this.gradb[j] = sum;
487
- }
488
-
489
- const gradInputFlat = new Float32Array(m * k);
490
- for (let i = 0; i < m; i++) {
491
- for (let j = 0; j < k; j++) {
492
- let sum = 0;
493
- for (let l = 0; l < n; l++) {
494
- sum += gradFlat[i * n + l] * this.W[j][l];
495
- }
496
- gradInputFlat[i * k + j] = sum;
497
- }
498
- }
499
-
500
- // Convert back to 2D array
501
- const gradInput = Array(m);
502
- for (let i = 0; i < m; i++) {
503
- const row = Array(k);
504
- const offset = i * k;
505
- for (let j = 0; j < k; j++) {
506
- row[j] = gradInputFlat[offset + j];
507
- }
508
- gradInput[i] = row;
509
- }
510
-
511
- if (this.originalShape === '3d') {
512
- return gradInput.map(row => [row]);
513
- }
514
- return gradInput;
515
- }
366
+ step(lr) {
367
+ for (let i = 0; i < this.inFeatures; i++) {
368
+ for (let j = 0; j < this.outFeatures; j++) {
369
+ this.W[i][j] -= lr * this.gradW[i][j];
370
+ }
371
+ }
516
372
 
517
- _getShapeType(x) {
518
- if (Array.isArray(x[0]) && Array.isArray(x[0][0]) && !Array.isArray(x[0][0][0])) {
519
- return '3d';
520
- } else if (Array.isArray(x[0]) && !Array.isArray(x[0][0])) {
521
- return '2d';
522
- } else {
523
- throw new Error(`Unsupported input shape for Linear layer`);
373
+ for (let j = 0; j < this.outFeatures; j++) {
374
+ this.b[0][j] -= lr * this.gradB[0][j];
524
375
  }
525
376
  }
526
377
 
527
- parameters(){
528
- return [
529
- {param: this.W, grad: this.gradW},
530
- {param: [this.b], grad: [this.gradb]}
531
- ];
378
+ parameters() {
379
+ return [
380
+ { param: this.W, grad: this.gradW },
381
+ { param: this.b, grad: this.gradB }
382
+ ];
532
383
  }
533
384
  }
534
385
 
535
386
  export class Flatten {
536
- constructor() {
537
- this.originalShape = null;
538
- }
539
-
540
- forward(x) {
541
- // Always convert to [batch, features] format
542
- this.originalShape = x.map(sample => this._getShape(sample));
543
-
544
- return x.map(sample => {
545
- const flat = this._flatten(sample);
546
- return flat; // Return as 1D array for [batch, features] compatibility
547
- });
548
- }
549
-
550
- backward(grad) {
551
- // grad is [batch, features], reshape back to original shape
552
- return grad.map((flatGrad, batchIdx) => {
553
- const shape = this.originalShape[batchIdx];
554
- return this._unflatten(flatGrad, shape);
555
- });
556
- }
557
-
558
- _getShape(sample) {
559
- if (Array.isArray(sample[0]) && Array.isArray(sample[0][0])) {
560
- return {
561
- type: '3d',
562
- dims: [sample.length, sample[0].length, sample[0][0].length]
563
- };
564
- } else if (Array.isArray(sample[0])) {
565
- return {
566
- type: '2d',
567
- dims: [sample.length, sample[0].length]
568
- };
569
- } else {
570
- return {
571
- type: '1d',
572
- dims: [sample.length]
573
- };
574
- }
575
- }
576
-
577
- _flatten(sample) {
578
- if (Array.isArray(sample[0]) && Array.isArray(sample[0][0])) {
579
- return sample.flat(2); // [channels, height, width] -> flat
580
- } else if (Array.isArray(sample[0])) {
581
- return sample.flat(); // [height, width] -> flat
582
- } else {
583
- return sample; // already flat
584
- }
585
- }
586
-
587
- _unflatten(flat, shape) {
588
- if (shape.type === '3d') {
589
- const [channels, height, width] = shape.dims;
590
- const result = [];
591
- let index = 0;
592
- for (let c = 0; c < channels; c++) {
593
- const channel = [];
594
- for (let h = 0; h < height; h++) {
595
- const row = [];
596
- for (let w = 0; w < width; w++) {
597
- row.push(flat[index++]);
598
- }
599
- channel.push(row);
600
- }
601
- result.push(channel);
602
- }
603
- return result;
604
- } else if (shape.type === '2d') {
605
- const [height, width] = shape.dims;
606
- const result = [];
607
- for (let h = 0; h < height; h++) {
608
- result.push(flat.slice(h * width, h * width + width));
609
- }
610
- return result;
611
- } else {
612
- return flat; // 1d
613
- }
614
- }
615
-
616
- parameters() { return []; }
387
+ constructor(){
388
+ this.originalShape = null;
389
+ }
390
+
391
+ forward(x){
392
+ // Save full shape per sample
393
+ this.originalShape = x.map(sample => this._getDims(sample));
394
+
395
+ return x.map(sample => this._flattenDeep(sample));
396
+ }
397
+
398
+ backward(grad){
399
+ return grad.map((flat, i) =>
400
+ this._reshape(flat, this.originalShape[i])
401
+ );
402
+ }
403
+
404
+ // Get dimensions recursively
405
+ _getDims(arr){
406
+ const dims = [];
407
+ let current = arr;
408
+ while (Array.isArray(current)){
409
+ dims.push(current.length);
410
+ current = current[0];
411
+ }
412
+ return dims;
413
+ }
414
+
415
+ // Flatten ANY depth
416
+ _flattenDeep(arr){
417
+ return arr.flat(Infinity);
418
+ }
419
+
420
+ // Reshape back using saved dims
421
+ _reshape(flat, dims){
422
+ let index = 0;
423
+
424
+ function build(dimIdx){
425
+ const size = dims[dimIdx];
426
+ const result = [];
427
+
428
+ if (dimIdx === dims.length - 1){
429
+ for (let i=0; i<size; i++){
430
+ result.push(flat[index++]);
431
+ }
432
+ } else {
433
+ for (let i=0; i<size; i++){
434
+ result.push(build(dimIdx + 1));
435
+ }
436
+ }
437
+
438
+ return result;
439
+ }
440
+
441
+ return build(0);
442
+ }
443
+
444
+ parameters() { return []; }
617
445
  }
618
446
 
619
447
  // ---------------------- Conv2D (BETA) ----------------------
@@ -632,24 +460,25 @@ export class Conv2D {
632
460
  );
633
461
  this.x = null;
634
462
 
635
- // Cache Float32Array for kernels
636
463
  this._WFlat = null;
637
- this._cacheKernels();
464
+ this._updateCache();
638
465
  }
639
466
 
640
- _cacheKernels() {
641
- this._WFlat = this.W.map(oc =>
467
+ _updateCache(){
468
+ this._WFlat = this.W.map(oc =>
642
469
  oc.map(ic => {
643
470
  const rows = ic.length;
644
471
  const cols = ic[0].length;
645
472
  const flat = new Float32Array(rows * cols);
646
- for (let i = 0; i < rows; i++) {
473
+
474
+ for (let i=0; i<rows; i++){
647
475
  const offset = i * cols;
648
476
  const row = ic[i];
649
- for (let j = 0; j < cols; j++) {
650
- flat[offset + j] = row[j];
477
+ for (let j=0; j<cols; j++){
478
+ flat[offset+j] = row[j];
651
479
  }
652
480
  }
481
+
653
482
  return flat;
654
483
  })
655
484
  );
@@ -697,6 +526,8 @@ export class Conv2D {
697
526
  }
698
527
 
699
528
  forward(batch) {
529
+ this._updateCache();
530
+
700
531
  this.x = batch;
701
532
  const kH = this.kernel;
702
533
  const kW = this.kernel;
@@ -727,43 +558,52 @@ export class Conv2D {
727
558
 
728
559
  backward(grad) {
729
560
  const batchSize = this.x.length;
730
- const gradW = this.gradW.map(oc => oc.map(ic => zeros(this.kernel, this.kernel)));
731
- const gradInput = this.x.map(sample =>
561
+ const gradInput = this.x.map(sample =>
732
562
  sample.map(chan => zeros(chan.length, chan[0].length))
733
563
  );
734
-
735
- for (let b = 0; b < batchSize; b++) {
736
- for (let oc = 0; oc < this.outC; oc++) {
737
- for (let ic = 0; ic < this.inC; ic++) {
564
+
565
+ // Zero existing gradW in place
566
+ for (let oc=0; oc<this.outC; oc++){
567
+ for (let ic=0; ic<this.inC; ic++){
568
+ for (let i=0; i<this.kernel; i++){
569
+ for (let j=0; j<this.kernel; j++){
570
+ this.gradW[oc][ic][i][j] = 0;
571
+ }
572
+ }
573
+ }
574
+ }
575
+
576
+ for (let b=0; b<batchSize; b++){
577
+ for (let oc=0; oc<this.outC; oc++){
578
+ for (let ic=0; ic<this.inC; ic++){
738
579
  const outGrad = grad[b][oc];
739
580
 
740
- // Compute gradW
741
- for (let i = 0; i < this.kernel; i++) {
742
- for (let j = 0; j < this.kernel; j++) {
581
+ // Accumulate gradW in place
582
+ for (let i=0; i<this.kernel; i++){
583
+ for (let j=0; j<this.kernel; j++){
743
584
  let sum = 0;
744
- for (let y = 0; y < outGrad.length; y++) {
745
- for (let x = 0; x < outGrad[0].length; x++) {
585
+ for (let y=0; y<outGrad.length; y++){
586
+ for (let x=0; x<outGrad[0].length; x++){
746
587
  const inY = y * this.stride + i;
747
588
  const inX = x * this.stride + j;
748
- if (inY < this.x[b][ic].length && inX < this.x[b][ic][0].length) {
749
- sum += this.x[b][ic][inY][inX] * outGrad[y][x];
589
+ if (inY<this.x[b][ic].length && inX < this.x[b][ic][0].length){
590
+ sum+=this.x[b][ic][inY][inX] * outGrad[y][x];
750
591
  }
751
592
  }
752
593
  }
753
- gradW[oc][ic][i][j] += sum;
594
+ this.gradW[oc][ic][i][j] += sum;
754
595
  }
755
596
  }
756
-
757
- // Compute gradInput
758
- for (let y = 0; y < outGrad.length; y++) {
759
- for (let x = 0; x < outGrad[0].length; x++) {
760
- for (let ki = 0; ki < this.kernel; ki++) {
761
- for (let kj = 0; kj < this.kernel; kj++) {
597
+
598
+ // Compute gradInput
599
+ for (let y=0; y<outGrad.length; y++){
600
+ for (let x=0; x<outGrad[0].length; x++){
601
+ for (let ki=0; ki<this.kernel; ki++){
602
+ for (let kj=0; kj<this.kernel; kj++){
762
603
  const inY = y * this.stride + ki;
763
604
  const inX = x * this.stride + kj;
764
- if (inY < gradInput[b][ic].length && inX < gradInput[b][ic][0].length) {
765
- gradInput[b][ic][inY][inX] +=
766
- this.W[oc][ic][ki][kj] * outGrad[y][x];
605
+ if (inY<gradInput[b][ic].length && inX < gradInput[b][ic][0].length){
606
+ gradInput[b][ic][inY][inX] += this.W[oc][ic][ki][kj] * outGrad[y][x];
767
607
  }
768
608
  }
769
609
  }
@@ -772,8 +612,7 @@ export class Conv2D {
772
612
  }
773
613
  }
774
614
  }
775
-
776
- this.gradW = gradW;
615
+
777
616
  return gradInput;
778
617
  }
779
618
 
@@ -781,7 +620,7 @@ export class Conv2D {
781
620
  return this.W.flatMap((w, oc) =>
782
621
  w.map((wc, ic) => ({
783
622
  param: wc,
784
- grad: this.gradW[oc][ic]
623
+ grad: this.gradW[oc][ic] // Reference stays valid — gradW is mutated in-place now
785
624
  }))
786
625
  );
787
626
  }
@@ -872,6 +711,14 @@ export class Sequential {
872
711
  return state;
873
712
  }
874
713
 
714
+ step(lr){
715
+ this.layers.forEach(layer => {
716
+ if(typeof layer.step === "function"){
717
+ layer.step(lr);
718
+ }
719
+ })
720
+ }
721
+
875
722
  /**
876
723
  * Load state dict
877
724
  */
@@ -895,43 +742,35 @@ export class Sequential {
895
742
  }
896
743
 
897
744
  // ---------------------- Activations ----------------------
898
- export class ReLU{
899
- constructor(){ this.mask = null; this.originalShape = null; }
900
-
901
- forward(x){
902
- this.originalShape = this._getShapeType(x);
903
-
904
- if (this.originalShape === '3d') {
905
- // Handle [batch, 1, features]
906
- this.mask = x.map(sample => sample[0].map(v => v > 0));
907
- return x.map(sample => [sample[0].map(v => Math.max(0, v))]);
908
- } else {
909
- // Handle [batch, features]
910
- this.mask = x.map(row => row.map(v => v > 0));
911
- return x.map(row => row.map(v => Math.max(0, v)));
912
- }
913
- }
914
-
915
- backward(grad){
916
- if (this.originalShape === '3d') {
917
- return grad.map((sample, i) =>
918
- [sample[0].map((v, j) => this.mask[i][j] ? v : 0)]
919
- );
920
- } else {
921
- return grad.map((row, i) =>
922
- row.map((v, j) => this.mask[i][j] ? v : 0)
923
- );
745
+ export class ReLU {
746
+ constructor() {
747
+ this.mask = null;
748
+ }
749
+
750
+ forward(x) {
751
+ this.mask = this._mapRecursive(x, v => v > 0);
752
+ return this._mapRecursive(x, v => Math.max(0, v));
753
+ }
754
+
755
+ backward(grad) {
756
+ return this._mapRecursiveWithMask(grad, this.mask, (g, m) => m ? g : 0);
757
+ }
758
+
759
+ // ===== helper =====
760
+ _mapRecursive(arr, fn) {
761
+ if (Array.isArray(arr)) {
762
+ return arr.map(v => this._mapRecursive(v, fn));
924
763
  }
764
+ return fn(arr);
925
765
  }
926
-
927
- _getShapeType(x) {
928
- if (Array.isArray(x[0]) && Array.isArray(x[0][0]) && !Array.isArray(x[0][0][0])) {
929
- return '3d';
930
- } else if (Array.isArray(x[0]) && !Array.isArray(x[0][0])) {
931
- return '2d';
932
- } else {
933
- throw new Error(`Unsupported input shape for ReLU`);
766
+
767
+ _mapRecursiveWithMask(arr, mask, fn) {
768
+ if (Array.isArray(arr)) {
769
+ return arr.map((v, i) =>
770
+ this._mapRecursiveWithMask(v, mask[i], fn)
771
+ );
934
772
  }
773
+ return fn(arr, mask);
935
774
  }
936
775
  }
937
776
 
@@ -1156,7 +995,10 @@ export class Tokenizer {
1156
995
  .replace(/\s+/g, ' ')
1157
996
  .trim()
1158
997
  .split(' ')
1159
- .filter(word => word.length > 0 && !/^[.!?;:,]+$/.test(word) || word.length > 1);
998
+ .filter(word =>
999
+ word.length > 0 &&
1000
+ (!/^[.!?;:,]+$/.test(word) || word.length > 1)
1001
+ );
1160
1002
  }
1161
1003
 
1162
1004
  /**
@@ -1812,23 +1654,25 @@ export class ReduceLROnPlateau {
1812
1654
 
1813
1655
  // ---------------------- ELU Activation ----------------------
1814
1656
  export class ELU {
1815
- constructor(alpha=1.0) {
1657
+ constructor(alpha=1.0){
1816
1658
  this.alpha = alpha;
1817
- this.out = null;
1659
+ this.x = null; // Cache input for correct derivative
1818
1660
  }
1819
-
1820
- forward(x) {
1821
- this.out = x.map(row =>
1822
- row.map(v => v > 0 ? v : this.alpha * (Math.exp(v) - 1))
1661
+
1662
+ forward(x){
1663
+ this.x=x; // Store original input for backward
1664
+ return x.map(row =>
1665
+ row.map(v=>v>0 ? v : this.alpha*(Math.exp(v) - 1))
1823
1666
  );
1824
- return this.out;
1825
1667
  }
1826
-
1827
- backward(grad) {
1828
- return grad.map((row, i) =>
1829
- row.map((v, j) =>
1830
- v * (this.out[i][j] > 0 ? 1 : this.alpha * Math.exp(this.out[i][j]))
1831
- )
1668
+
1669
+ backward(grad){
1670
+ return grad.map((row, i) =>
1671
+ row.map((v, j) => {
1672
+ const xVal = this.x[i][j];
1673
+ // d/dx ELU: 1 if x > 0, else alpha * exp(x)
1674
+ return v * (xVal > 0 ? 1 : this.alpha * Math.exp(xVal));
1675
+ })
1832
1676
  );
1833
1677
  }
1834
1678
  }
@@ -1910,10 +1754,10 @@ export class BatchNorm2d {
1910
1754
 
1911
1755
  // Parameters
1912
1756
  if (affine) {
1913
- this.weight = Array(numFeatures).fill(1);
1914
- this.bias = Array(numFeatures).fill(0);
1915
- this.gradWeight = Array(numFeatures).fill(0);
1916
- this.gradBias = Array(numFeatures).fill(0);
1757
+ this.weight = [Array(numFeatures).fill(1)];
1758
+ this.bias = [Array(numFeatures).fill(0)];
1759
+ this.gradWeight = [Array(numFeatures).fill(0)];
1760
+ this.gradBias = [Array(numFeatures).fill(0)];
1917
1761
  }
1918
1762
 
1919
1763
  // Running statistics
@@ -1993,7 +1837,7 @@ export class BatchNorm2d {
1993
1837
 
1994
1838
  // Apply affine transformation if enabled
1995
1839
  if (this.affine) {
1996
- channelOut[i][j] = channelOut[i][j] * this.weight[c] + this.bias[c];
1840
+ channelOut[i][j] = channelOut[i][j] * this.weight[0][c] + this.bias[0][c];
1997
1841
  }
1998
1842
  }
1999
1843
  }
@@ -2022,7 +1866,7 @@ export class BatchNorm2d {
2022
1866
 
2023
1867
  // Apply affine transformation if enabled
2024
1868
  if (this.affine) {
2025
- channelOut[i][j] = channelOut[i][j] * this.weight[c] + this.bias[c];
1869
+ channelOut[i][j] = channelOut[i][j] * this.weight[0][c] + this.bias[0][c];
2026
1870
  }
2027
1871
  }
2028
1872
  }
@@ -2052,8 +1896,10 @@ export class BatchNorm2d {
2052
1896
  );
2053
1897
 
2054
1898
  if (this.affine) {
2055
- this.gradWeight.fill(0);
2056
- this.gradBias.fill(0);
1899
+ for (let c=0; c<channels; c++){
1900
+ this.gradWeight[0][c] = 0;
1901
+ this.gradBias[0][c] = 0;
1902
+ }
2057
1903
  }
2058
1904
 
2059
1905
  for (let c = 0; c < channels; c++) {
@@ -2083,7 +1929,7 @@ export class BatchNorm2d {
2083
1929
  let grad = channelGrad[i][j];
2084
1930
 
2085
1931
  if (this.affine) {
2086
- grad *= this.weight[c];
1932
+ grad *= this.weight[0][c];
2087
1933
  }
2088
1934
 
2089
1935
  grad *= stdInv;
@@ -2093,8 +1939,8 @@ export class BatchNorm2d {
2093
1939
  }
2094
1940
 
2095
1941
  if (this.affine) {
2096
- this.gradWeight[c] = sumGradWeight / batchSize;
2097
- this.gradBias[c] = sumGradBias / batchSize;
1942
+ this.gradWeight[0][c] = sumGradWeight / batchSize;
1943
+ this.gradBias[0][c] = sumGradBias / batchSize;
2098
1944
  }
2099
1945
  }
2100
1946
 
@@ -2104,9 +1950,9 @@ export class BatchNorm2d {
2104
1950
  parameters() {
2105
1951
  if (!this.affine) return [];
2106
1952
  return [
2107
- { param: [this.weight], grad: [this.gradWeight] },
2108
- { param: [this.bias], grad: [this.gradBias] }
2109
- ];
1953
+ { param: this.weight, grad: this.gradWeight },
1954
+ { param: this.bias, grad: this.gradBias }
1955
+ ]
2110
1956
  }
2111
1957
 
2112
1958
  train() { this.training = true; }
@@ -2115,20 +1961,143 @@ export class BatchNorm2d {
2115
1961
 
2116
1962
  // ---------------------- Model Save/Load (BETA) ----------------------
2117
1963
  export function saveModel(model){
2118
- if(!(model instanceof Sequential)) throw new Error("saveModel supports only Sequential");
2119
- const weights=model.layers.map(layer=>({weights:layer.W||null,biases:layer.b||null}));
2120
- return JSON.stringify(weights);
2121
- /* Didn't expect this to work */
1964
+ if(!(model instanceof Sequential)){
1965
+ throw new Error("saveModel supports only Sequential models");
1966
+ }
1967
+
1968
+ const state = {
1969
+ version: "2.0.0",
1970
+ layers: model.layers.map((layer, idx) => {
1971
+ const params = layer.parameters ? layer.parameters() : [];
1972
+
1973
+ if (params.length === 0){
1974
+ return { type: layer.constructor.name, params: [] };
1975
+ }
1976
+
1977
+ return {
1978
+ type: layer.constructor.name,
1979
+ params: params.map(p => ({
1980
+ // Deep clone parameter data
1981
+ data: p.param.map(row =>
1982
+ Array.isArray(row) ? [...row] : row
1983
+ ),
1984
+ // Preserve shape metadata for validation
1985
+ shape: Array.isArray(p.param[0])
1986
+ ? [p.param.length, p.param[0].length]
1987
+ : [p.param.length]
1988
+ }))
1989
+ };
1990
+ })
1991
+ };
1992
+
1993
+ return JSON.stringify(state);
2122
1994
  }
2123
1995
 
2124
- export function loadModel(model,json){
2125
- if(!(model instanceof Sequential)) throw new Error("loadModel supports only Sequential");
2126
- const weights=JSON.parse(json);
2127
- model.layers.forEach((layer,i)=>{
2128
- if(layer.W && weights[i].weights) layer.W=weights[i].weights;
2129
- if(layer.b && weights[i].biases) layer.b=weights[i].biases;
2130
- });
2131
- /* Didn't expect this to work */
1996
+ export function loadModel(model, json){
1997
+ if (!(model instanceof Sequential)){
1998
+ throw new Error("loadModel supports only Sequential models");
1999
+ }
2000
+
2001
+ const state = JSON.parse(json);
2002
+
2003
+ // Validate structure
2004
+ if (!state.layers || !Array.isArray(state.layers)){
2005
+ throw new Error("loadModel: invalid save format - missing 'layers' array");
2006
+ }
2007
+
2008
+ if (state.layers.length !== model.layers.length){
2009
+ console.warn(
2010
+ `[JST WARN]: Layer count mismatch - saved ${state.layers.length},` +
2011
+ `current model has ${model.layers.length}. Loading what matches.`
2012
+ );
2013
+ }
2014
+
2015
+ let loadedCount = 0;
2016
+ let skippedCount = 0;
2017
+
2018
+ for(let i=0; i<Math.min(state.layers.length, model.layers.length); i++){
2019
+ const savedLayer = state.layers[i];
2020
+ const currentLayer = model.layers[i];
2021
+
2022
+ if(savedLayer.params.length === 0){
2023
+ // Layer with no trainable params - skip
2024
+ continue
2025
+ }
2026
+
2027
+ // Validate layer type
2028
+ if (savedLayer.type !== currentLayer.constructor.name){
2029
+ console.warn(
2030
+ `[JST WARN]: Layer ${i} type mismatch - ` +
2031
+ `saved: ${savedLayer.type}, current: ${currentLayer.constructor.name}. Skipping.`
2032
+ );
2033
+ skippedCount++;
2034
+ continue;
2035
+ }
2036
+
2037
+ // Get current layer parameters
2038
+ const currentParams = currentLayer.parameters ? currentLayer.parameters() : [];
2039
+
2040
+ if (currentParams.length !== savedLayer.params.length){
2041
+ console.warn(
2042
+ `[JST WARN]: Layer ${i} parameter count mismatch - ` +
2043
+ `saved: ${savedLayer.params.length}, current: ${currentParams.length}. Skipping.`
2044
+ );
2045
+ skippedCount++;
2046
+ continue;
2047
+ }
2048
+
2049
+ // Load parameters with shape validation
2050
+ for (let j=0; j<savedLayer.params.length; j++){
2051
+ const savedParam = savedLayer.params[j];
2052
+ const currentParam = currentParams[j].param;
2053
+
2054
+ // Validate shape
2055
+ const currentRows = currentParam.length;
2056
+ const currentCols = Array.isArray(currentParam[0])
2057
+ ? currentParam[0].length
2058
+ : 1;
2059
+
2060
+ const savedRows = savedParam.shape[0];
2061
+ const savedCols = savedParam.shape[1] || 1;
2062
+
2063
+ if (currentRows !== savedRows || currentCols !== savedCols){
2064
+ console.warn(
2065
+ `[JST WARN]: Layer ${i} param ${j} shape mismatch - ` +
2066
+ `saved: [${savedRows}, ${savedCols}],` +
2067
+ `current: [${currentRows}, ${currentCols}]. Skipping this parameter.`
2068
+ );
2069
+ continue
2070
+ }
2071
+
2072
+ // Copy parameter data
2073
+ if (Array.isArray(currentParam[0])){
2074
+ // 2D Parameter
2075
+ for (let r=0; r<currentRows; r++){
2076
+ for (let c=0; c<currentCols; c++){
2077
+ currentParam[r][c] = savedParam.data[r][c];
2078
+ }
2079
+ }
2080
+ } else {
2081
+ // 1D parameter
2082
+ for (let r=0; r<currentRows; r++){
2083
+ currentParam[r] = savedParam.data[r];
2084
+ }
2085
+ }
2086
+ }
2087
+
2088
+ // Invalidate any cached flat representations
2089
+ if (typeof currentLayer._updateCache === 'function'){
2090
+ currentLayer._updateCache();
2091
+ }
2092
+
2093
+ loadedCount++;
2094
+ }
2095
+
2096
+ console.log(
2097
+ `[JST]: Model loaded: ${loadedCount} layers restored, ${skippedCount} skipped.`
2098
+ );
2099
+
2100
+ return model;
2132
2101
  }
2133
2102
 
2134
2103
  // ---------------------- Advanced Utils ----------------------