mini-jstorch 2.0.3 → 2.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +17 -16
- package/demo/saveAndLoadModel.js +45 -0
- package/package.json +1 -1
- package/src/jstorch.js +193 -80
- /package/src/Dummy/{msg → msg.txt} +0 -0
package/README.md
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
## Mini-JSTorch (v2.0.3)
|
|
1
|
+
## Mini-JSTorch (v2.0.4)
|
|
2
2
|
|
|
3
3
|
---
|
|
4
4
|
|
|
@@ -9,12 +9,10 @@ This project prioritizes `clarity`, `numerical correctness`, and `accessibility`
|
|
|
9
9
|
|
|
10
10
|
## Changelog
|
|
11
11
|
|
|
12
|
-
**v2.0.3:**
|
|
13
|
-
- **
|
|
14
|
-
- **Fixed
|
|
15
|
-
|
|
16
|
-
- **Fixed saveModel/loadModel:** Now correctly saves and restores all layer types including Conv2D and BatchNorm2d
|
|
17
|
-
- **Fixed BatchNorm2d gradient zeroing:** gradWeight/gradBias now correctly reset between batches
|
|
12
|
+
**v2.0.4:**
|
|
13
|
+
- **Optimized saveModel/loadModel:** Complete rewrite with flat 1D serialization using shape metadata, single-cursor deserialization, and pre-allocated arrays. Reduces JSON size by ~40-50% for large models and improves load performance.
|
|
14
|
+
- **Fixed stateDict/loadStateDict:** Now uses `parameters()` for universal layer compatibility. Previously only saved `W` and `b` properties, silently dropping Conv2D and BatchNorm2d weights.
|
|
15
|
+
|
|
18
16
|
|
|
19
17
|
**⚠️ BREAKING CHANGES in v2.0.0:**
|
|
20
18
|
- Tokenizer API: `tokenizeBatch()` → `transform()`, `detokenizeBatch()` → `inverseTransform()`
|
|
@@ -74,7 +72,7 @@ In Browser/Website:
|
|
|
74
72
|
<div id="res"></div>
|
|
75
73
|
|
|
76
74
|
<script type="module">
|
|
77
|
-
import { Sequential, Linear, ReLU, MSELoss, Adam, StepLR, Tanh } from 'https://unpkg.com/mini-jstorch@2.0.3/index.js';
|
|
75
|
+
import { Sequential, Linear, ReLU, MSELoss, Adam, StepLR, Tanh } from 'https://unpkg.com/mini-jstorch@2.0.4/index.js';
|
|
78
76
|
|
|
79
77
|
async function train() {
|
|
80
78
|
const statusEl = document.getElementById('status');
|
|
@@ -187,7 +185,7 @@ In Browser/Website:
|
|
|
187
185
|
|
|
188
186
|
## Node.js
|
|
189
187
|
```bash
|
|
190
|
-
npm install mini-jstorch
|
|
188
|
+
npm install mini-jstorch@latest
|
|
191
189
|
```
|
|
192
190
|
Node.js v18+ or any modern browser with ES module support is recommended.
|
|
193
191
|
|
|
@@ -214,13 +212,12 @@ import {
|
|
|
214
212
|
const model = new Sequential([
|
|
215
213
|
new Linear(2, 8),
|
|
216
214
|
new ReLU(),
|
|
217
|
-
new Linear(8, 2)
|
|
215
|
+
new Linear(8, 2)
|
|
218
216
|
]);
|
|
219
217
|
|
|
220
218
|
const X = [
|
|
221
219
|
[0,0], [0,1], [1,0], [1,1]
|
|
222
220
|
];
|
|
223
|
-
|
|
224
221
|
const Y = [
|
|
225
222
|
[1,0], [0,1], [0,1], [1,0]
|
|
226
223
|
];
|
|
@@ -234,14 +231,20 @@ for (let epoch = 1; epoch <= 300; epoch++) {
|
|
|
234
231
|
const grad = lossFn.backward();
|
|
235
232
|
model.backward(grad);
|
|
236
233
|
optimizer.step();
|
|
237
|
-
|
|
238
|
-
// Zero gradients for next iteration
|
|
239
234
|
model.zeroGrad();
|
|
240
235
|
|
|
241
236
|
if (epoch % 50 === 0) {
|
|
242
237
|
console.log(`Epoch ${epoch}, Loss: ${loss.toFixed(6)}`);
|
|
243
238
|
}
|
|
244
239
|
}
|
|
240
|
+
|
|
241
|
+
console.log('\nResults:');
|
|
242
|
+
const logits = model.forward(X);
|
|
243
|
+
X.forEach((input, i) => {
|
|
244
|
+
const pred = logits[i][0] > logits[i][1] ? 0 : 1;
|
|
245
|
+
const target = Y[i][0] === 1 ? 0 : 1;
|
|
246
|
+
console.log(` [${input}] → class ${pred} (target: ${target}) ${pred === target ? 'TRUE' : 'FALSE'}`);
|
|
247
|
+
});
|
|
245
248
|
```
|
|
246
249
|
`Important:` Do not combine `SoftmaxCrossEntropyLoss` with a `Softmax` layer.
|
|
247
250
|
|
|
@@ -280,7 +283,6 @@ for (let epoch = 1; epoch <= 300; epoch++) {
|
|
|
280
283
|
optimizer.step();
|
|
281
284
|
model.zeroGrad();
|
|
282
285
|
|
|
283
|
-
// Print progress every 50 epochs
|
|
284
286
|
if (epoch % 50 === 0) {
|
|
285
287
|
const probs = logits.map(p => 1 / (1 + Math.exp(-p[0])));
|
|
286
288
|
console.log(`Epoch ${epoch} | Loss: ${loss.toFixed(6)}`);
|
|
@@ -292,7 +294,6 @@ for (let epoch = 1; epoch <= 300; epoch++) {
|
|
|
292
294
|
}
|
|
293
295
|
}
|
|
294
296
|
|
|
295
|
-
// Final evaluation
|
|
296
297
|
console.log("\nTraining Complete\n");
|
|
297
298
|
model.eval();
|
|
298
299
|
|
|
@@ -340,7 +341,7 @@ See the `demo/` directory for runnable examples!
|
|
|
340
341
|
- `demo/scheduler.js`
|
|
341
342
|
- `demo/xor_classification.js`
|
|
342
343
|
- `demo/linear_regression.js`
|
|
343
|
-
|
|
344
|
+
- `demo/saveAndLoadModel.js`
|
|
344
345
|
|
|
345
346
|
```bash
|
|
346
347
|
node demo/<fileNameInDemo>.js
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
// demo/saveAndLoadModel.js
|
|
2
|
+
// Trains a linear regression model (y = 3x + 5), saves it to JSON,
|
|
3
|
+
// loads it back into a fresh model, and verifies predictions match.
|
|
4
|
+
|
|
5
|
+
import {
|
|
6
|
+
Sequential, Linear,
|
|
7
|
+
MSELoss, Adam,
|
|
8
|
+
saveModel, loadModel
|
|
9
|
+
} from '../src/jstorch.js';
|
|
10
|
+
|
|
11
|
+
const X = [[1], [2], [3], [4], [5]];
|
|
12
|
+
const y = [[8], [11], [14], [17], [20]];
|
|
13
|
+
|
|
14
|
+
const model = new Sequential([new Linear(1, 1)]);
|
|
15
|
+
const criterion = new MSELoss();
|
|
16
|
+
const optimizer = new Adam(model.parameters(), 0.1);
|
|
17
|
+
|
|
18
|
+
// Train
|
|
19
|
+
for (let e = 1; e <= 500; e++) {
|
|
20
|
+
const pred = model.forward(X);
|
|
21
|
+
const loss = criterion.forward(pred, y);
|
|
22
|
+
model.backward(criterion.backward());
|
|
23
|
+
optimizer.step();
|
|
24
|
+
model.zeroGrad();
|
|
25
|
+
if (e % 100 === 0) console.log(`Epoch ${e} | Loss: ${loss.toFixed(6)}`);
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
// Original predictions
|
|
29
|
+
console.log('\nTrained model predictions:');
|
|
30
|
+
const original = model.forward(X);
|
|
31
|
+
X.forEach((x, i) => console.log(` x=${x[0]} → ${original[i][0].toFixed(4)} (target: ${y[i][0]})`));
|
|
32
|
+
|
|
33
|
+
// Save → Load
|
|
34
|
+
const json = saveModel(model);
|
|
35
|
+
const restored = new Sequential([new Linear(1, 1)]);
|
|
36
|
+
loadModel(restored, json);
|
|
37
|
+
|
|
38
|
+
// Loaded predictions
|
|
39
|
+
console.log('\nLoaded model predictions:');
|
|
40
|
+
const loaded = restored.forward(X);
|
|
41
|
+
X.forEach((x, i) => console.log(` x=${x[0]} → ${loaded[i][0].toFixed(4)} (target: ${y[i][0]})`));
|
|
42
|
+
|
|
43
|
+
// Verify
|
|
44
|
+
const ok = original.every((r, i) => Math.abs(r[0] - loaded[i][0]) < 1e-10);
|
|
45
|
+
console.log(`\n${ok ? 'Model saved and restored successfully.' : 'Mismatch after load.'}`);
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "mini-jstorch",
|
|
3
|
-
"version": "2.0.3",
|
|
3
|
+
"version": "2.0.4",
|
|
4
4
|
"type": "module",
|
|
5
5
|
"description": "A lightweight JavaScript neural network library for learning AI concepts and rapid Frontend experimentation. PyTorch-inspired, zero dependencies, perfect for educational use.",
|
|
6
6
|
"main": "index.js",
|
package/src/jstorch.js
CHANGED
|
@@ -703,13 +703,25 @@ export class Sequential {
|
|
|
703
703
|
* Get model state dict (weights and biases)
|
|
704
704
|
*/
|
|
705
705
|
stateDict(){
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
706
|
+
const state = {};
|
|
707
|
+
this.layers.forEach((layer, idx) => {
|
|
708
|
+
const params = layer.parameters ? layer.parameters() : [];
|
|
709
|
+
if (params.length === 0) return;
|
|
710
|
+
|
|
711
|
+
params.forEach((p, pIdx) => {
|
|
712
|
+
// Deep clone parameter data to prevent reference sharing
|
|
713
|
+
const paramData = p.param;
|
|
714
|
+
const is2D = Array.isArray(paramData[0]);
|
|
715
|
+
|
|
716
|
+
if (is2D){
|
|
717
|
+
state[`layer_${idx}.param_${pIdx}`] = paramData.map(row => [...row]);
|
|
718
|
+
} else {
|
|
719
|
+
state[`layer_${idx}.param_${pIdx}`] = [...paramData];
|
|
720
|
+
}
|
|
721
|
+
});
|
|
722
|
+
});
|
|
723
|
+
return state;
|
|
724
|
+
}
|
|
713
725
|
|
|
714
726
|
step(lr){
|
|
715
727
|
this.layers.forEach(layer => {
|
|
@@ -720,24 +732,70 @@ export class Sequential {
|
|
|
720
732
|
}
|
|
721
733
|
|
|
722
734
|
/**
|
|
723
|
-
* Load state dict
|
|
735
|
+
* Load state dict into model.
|
|
736
|
+
* Mutates parameters in place to preserve optimizer references.
|
|
724
737
|
*/
|
|
725
738
|
loadStateDict(stateDict){
|
|
726
739
|
this.layers.forEach((layer, idx) => {
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
740
|
+
const params = layer.parameters ? layer.parameters() : [];
|
|
741
|
+
if (params.length === 0) return;
|
|
742
|
+
|
|
743
|
+
params.forEach((p, pIdx) => {
|
|
744
|
+
const key = `layer_${idx}.param_${pIdx}`;
|
|
745
|
+
const savedData = stateDict[key];
|
|
746
|
+
|
|
747
|
+
if (!savedData){
|
|
748
|
+
console.warn(`[JST WARN]: stateDict missing key: ${key}`);
|
|
749
|
+
return;
|
|
750
|
+
}
|
|
751
|
+
|
|
752
|
+
const currentParam = p.param;
|
|
753
|
+
const is2D = Array.isArray(currentParam[0]);
|
|
754
|
+
|
|
755
|
+
// Validate shape
|
|
756
|
+
const savedRows = savedData.length;
|
|
757
|
+
const currentRows = currentParam.length;
|
|
758
|
+
|
|
759
|
+
if (savedRows !== currentRows){
|
|
760
|
+
console.warn(
|
|
761
|
+
`[JST WARN]: stateDict shape mismatch for ${key} - ` +
|
|
762
|
+
`saved rows: ${savedRows}, current rows: ${currentRows}. Skipping.`
|
|
763
|
+
);
|
|
764
|
+
return;
|
|
765
|
+
}
|
|
766
|
+
|
|
767
|
+
if (is2D){
|
|
768
|
+
const savedCols = savedData[0].length;
|
|
769
|
+
const currentCols = currentParam[0].length;
|
|
770
|
+
|
|
771
|
+
if (savedCols !== currentCols){
|
|
772
|
+
console.warn(
|
|
773
|
+
`[JST WARN]: stateDict shape mismatch for ${key} - ` +
|
|
774
|
+
`saved cols: ${savedCols}, current cols: ${currentCols}. Skipping.`
|
|
775
|
+
);
|
|
776
|
+
return;
|
|
777
|
+
}
|
|
778
|
+
|
|
779
|
+
// Mutate in-place to preserve optimizer References
|
|
780
|
+
for (let r=0; r<currentRows; r++){
|
|
781
|
+
for (let c=0; c<currentCols; c++){
|
|
782
|
+
currentParam[r][c] = savedData[r][c];
|
|
783
|
+
}
|
|
784
|
+
}
|
|
785
|
+
} else {
|
|
786
|
+
// 1D parameter
|
|
787
|
+
for (let r=0; r<currentRows; r++){
|
|
788
|
+
currentParam[r] = savedData[r];
|
|
789
|
+
}
|
|
790
|
+
}
|
|
791
|
+
});
|
|
792
|
+
|
|
793
|
+
// Invalidate cached flat representations
|
|
794
|
+
if(typeof layer._updateCache === 'function'){
|
|
795
|
+
layer._updateCache();
|
|
796
|
+
}
|
|
797
|
+
});
|
|
798
|
+
return this;
|
|
741
799
|
}
|
|
742
800
|
}
|
|
743
801
|
|
|
@@ -1225,7 +1283,10 @@ export class CrossEntropyLoss{
|
|
|
1225
1283
|
return this._impl.forward(logits, targets);
|
|
1226
1284
|
}
|
|
1227
1285
|
backward(){
|
|
1228
|
-
|
|
1286
|
+
if (!this._impl.probs){
|
|
1287
|
+
throw new Error("CrossEntropyLoss: backward() called before forward()");
|
|
1288
|
+
}
|
|
1289
|
+
return this._impl.backward();
|
|
1229
1290
|
}
|
|
1230
1291
|
}
|
|
1231
1292
|
|
|
@@ -1959,40 +2020,76 @@ export class BatchNorm2d {
|
|
|
1959
2020
|
eval() { this.training = false; }
|
|
1960
2021
|
}
|
|
1961
2022
|
|
|
1962
|
-
// ---------------------- Model Save/Load (
|
|
2023
|
+
// ---------------------- Model Save/Load (OPTIMIZED) ----------------------
|
|
2024
|
+
/**
|
|
2025
|
+
* Serialize model parameters to a compact flat JSON string.
|
|
2026
|
+
* Flattens all 2D parameter matrices into 1D arrays with shape metadata.
|
|
2027
|
+
*/
|
|
1963
2028
|
export function saveModel(model){
|
|
1964
|
-
if(!(model instanceof Sequential)){
|
|
2029
|
+
if (!(model instanceof Sequential)){
|
|
1965
2030
|
throw new Error("saveModel supports only Sequential models");
|
|
1966
2031
|
}
|
|
1967
2032
|
|
|
1968
|
-
const
|
|
1969
|
-
|
|
1970
|
-
|
|
1971
|
-
|
|
2033
|
+
const layers = [];
|
|
2034
|
+
|
|
2035
|
+
for (let i=0; i<model.layers.length; i++){
|
|
2036
|
+
const layer = model.layers[i];
|
|
2037
|
+
const params = layer.parameters ? layer.parameters() : [];
|
|
2038
|
+
const serializedParams = [];
|
|
2039
|
+
|
|
2040
|
+
for (let j=0; j<params.length; j++){
|
|
2041
|
+
const param = params[j].param;
|
|
2042
|
+
const rows = param.length;
|
|
2043
|
+
const is2D = Array.isArray(param[0]);
|
|
2044
|
+
const cols = is2D ? param[0].length : 1;
|
|
1972
2045
|
|
|
1973
|
-
|
|
1974
|
-
|
|
2046
|
+
let flat;
|
|
2047
|
+
|
|
2048
|
+
if (!is2D){
|
|
2049
|
+
// 1D parameter: native slice for bulk memory copy
|
|
2050
|
+
flat = param.slice();
|
|
2051
|
+
} else {
|
|
2052
|
+
// 2D parameter: flatten row by row
|
|
2053
|
+
const total = rows * cols;
|
|
2054
|
+
flat = new Array(total);
|
|
2055
|
+
let cursor = 0;
|
|
2056
|
+
for (let r=0; r<rows; r++){
|
|
2057
|
+
const row = param[r];
|
|
2058
|
+
for (let c=0; c<cols; c++){
|
|
2059
|
+
flat[cursor++] = row[c];
|
|
2060
|
+
}
|
|
2061
|
+
}
|
|
1975
2062
|
}
|
|
1976
2063
|
|
|
1977
|
-
|
|
1978
|
-
|
|
1979
|
-
|
|
1980
|
-
|
|
1981
|
-
|
|
1982
|
-
|
|
1983
|
-
|
|
1984
|
-
|
|
1985
|
-
|
|
1986
|
-
|
|
1987
|
-
|
|
1988
|
-
|
|
2064
|
+
serializedParams.push({
|
|
2065
|
+
s: [rows, cols],
|
|
2066
|
+
d: flat
|
|
2067
|
+
});
|
|
2068
|
+
}
|
|
2069
|
+
|
|
2070
|
+
// Extract running stats for BatchNorm2d layers
|
|
2071
|
+
let runningStats = null;
|
|
2072
|
+
if (typeof layer.runningMean !== 'undefined' && typeof layer.runningVar !== 'undefined'){
|
|
2073
|
+
runningStats = {
|
|
2074
|
+
mean: layer.runningMean.slice(),
|
|
2075
|
+
var: layer.runningVar.slice()
|
|
1989
2076
|
};
|
|
1990
|
-
}
|
|
1991
|
-
|
|
2077
|
+
}
|
|
2078
|
+
|
|
2079
|
+
layers.push({
|
|
2080
|
+
t: layer.constructor.name,
|
|
2081
|
+
p: serializedParams,
|
|
2082
|
+
rs: runningStats
|
|
2083
|
+
});
|
|
2084
|
+
}
|
|
1992
2085
|
|
|
1993
|
-
return JSON.stringify(
|
|
2086
|
+
return JSON.stringify({ ver: "2.0.4", layers: layers });
|
|
1994
2087
|
}
|
|
1995
2088
|
|
|
2089
|
+
/**
|
|
2090
|
+
* Load serialized flat parameters into a model.
|
|
2091
|
+
* Uses pre-allocated arrays and single-cursor deserialization.
|
|
2092
|
+
*/
|
|
1996
2093
|
export function loadModel(model, json){
|
|
1997
2094
|
if (!(model instanceof Sequential)){
|
|
1998
2095
|
throw new Error("loadModel supports only Sequential models");
|
|
@@ -2000,92 +2097,108 @@ export function loadModel(model, json){
|
|
|
2000
2097
|
|
|
2001
2098
|
const state = JSON.parse(json);
|
|
2002
2099
|
|
|
2003
|
-
// Validate structure
|
|
2004
2100
|
if (!state.layers || !Array.isArray(state.layers)){
|
|
2005
2101
|
throw new Error("loadModel: invalid save format - missing 'layers' array");
|
|
2006
2102
|
}
|
|
2007
2103
|
|
|
2104
|
+
const layerCount = Math.min(state.layers.length, model.layers.length);
|
|
2105
|
+
|
|
2008
2106
|
if (state.layers.length !== model.layers.length){
|
|
2009
2107
|
console.warn(
|
|
2010
2108
|
`[JST WARN]: Layer count mismatch - saved ${state.layers.length},` +
|
|
2011
|
-
`current model has ${model.layers.length}. Loading
|
|
2109
|
+
`current model has ${model.layers.length}. Loading ${layerCount} layers.`
|
|
2012
2110
|
);
|
|
2013
2111
|
}
|
|
2014
2112
|
|
|
2015
2113
|
let loadedCount = 0;
|
|
2016
2114
|
let skippedCount = 0;
|
|
2017
2115
|
|
|
2018
|
-
for(let i=0; i<
|
|
2116
|
+
for (let i=0; i<layerCount; i++){
|
|
2019
2117
|
const savedLayer = state.layers[i];
|
|
2020
2118
|
const currentLayer = model.layers[i];
|
|
2021
2119
|
|
|
2022
|
-
if(savedLayer.
|
|
2023
|
-
//
|
|
2024
|
-
|
|
2120
|
+
if (!savedLayer.p || savedLayer.p.length === 0){
|
|
2121
|
+
// Still restore running stats even if no trainable params
|
|
2122
|
+
if (savedLayer.rs && typeof currentLayer.runningMean !== 'undefined'){
|
|
2123
|
+
for (let k=0; k<savedLayer.rs.mean.length; k++){
|
|
2124
|
+
currentLayer.runningMean[k] = savedLayer.rs.mean[k];
|
|
2125
|
+
currentLayer.runningVar[k] = savedLayer.rs.var[k];
|
|
2126
|
+
}
|
|
2127
|
+
}
|
|
2128
|
+
continue;
|
|
2025
2129
|
}
|
|
2026
2130
|
|
|
2027
2131
|
// Validate layer type
|
|
2028
|
-
if (savedLayer.
|
|
2132
|
+
if (savedLayer.t !== currentLayer.constructor.name){
|
|
2029
2133
|
console.warn(
|
|
2030
2134
|
`[JST WARN]: Layer ${i} type mismatch - ` +
|
|
2031
|
-
`saved: ${savedLayer.
|
|
2135
|
+
`saved: ${savedLayer.t}, current: ${currentLayer.constructor.name}. Skipping.`
|
|
2032
2136
|
);
|
|
2033
2137
|
skippedCount++;
|
|
2034
2138
|
continue;
|
|
2035
2139
|
}
|
|
2036
2140
|
|
|
2037
|
-
// Get current layer parameters
|
|
2038
2141
|
const currentParams = currentLayer.parameters ? currentLayer.parameters() : [];
|
|
2039
2142
|
|
|
2040
|
-
if (currentParams.length !== savedLayer.
|
|
2143
|
+
if (currentParams.length !== savedLayer.p.length){
|
|
2041
2144
|
console.warn(
|
|
2042
2145
|
`[JST WARN]: Layer ${i} parameter count mismatch - ` +
|
|
2043
|
-
`saved: ${savedLayer.
|
|
2146
|
+
`saved: ${savedLayer.p.length}, current: ${currentParams.length}. Skipping.`
|
|
2044
2147
|
);
|
|
2045
2148
|
skippedCount++;
|
|
2046
2149
|
continue;
|
|
2047
2150
|
}
|
|
2048
2151
|
|
|
2049
|
-
|
|
2050
|
-
|
|
2051
|
-
const
|
|
2052
|
-
const
|
|
2152
|
+
for (let j=0; j<savedLayer.p.length; j++){
|
|
2153
|
+
const savedParam = savedLayer.p[j];
|
|
2154
|
+
const savedRows = savedParam.s[0];
|
|
2155
|
+
const savedCols = savedParam.s[1];
|
|
2156
|
+
const flatData = savedParam.d;
|
|
2053
2157
|
|
|
2054
|
-
|
|
2158
|
+
const currentParam = currentParams[j].param;
|
|
2055
2159
|
const currentRows = currentParam.length;
|
|
2056
2160
|
const currentCols = Array.isArray(currentParam[0])
|
|
2057
2161
|
? currentParam[0].length
|
|
2058
2162
|
: 1;
|
|
2059
2163
|
|
|
2060
|
-
|
|
2061
|
-
const savedCols = savedParam.shape[1] || 1;
|
|
2062
|
-
|
|
2164
|
+
// Validate shape
|
|
2063
2165
|
if (currentRows !== savedRows || currentCols !== savedCols){
|
|
2064
2166
|
console.warn(
|
|
2065
2167
|
`[JST WARN]: Layer ${i} param ${j} shape mismatch - ` +
|
|
2066
|
-
`saved: [${savedRows}, ${savedCols}]
|
|
2067
|
-
`current: [${currentRows}, ${currentCols}]. Skipping
|
|
2168
|
+
`saved: [${savedRows}, ${savedCols}], ` +
|
|
2169
|
+
`current: [${currentRows}, ${currentCols}]. Skipping.`
|
|
2068
2170
|
);
|
|
2069
|
-
continue
|
|
2171
|
+
continue;
|
|
2070
2172
|
}
|
|
2071
2173
|
|
|
2072
|
-
//
|
|
2073
|
-
|
|
2074
|
-
|
|
2075
|
-
|
|
2076
|
-
|
|
2077
|
-
|
|
2078
|
-
|
|
2174
|
+
// Single-cursor deserialization into pre-existing arrays
|
|
2175
|
+
let cursor =0;
|
|
2176
|
+
|
|
2177
|
+
if (savedCols === 1 && !Array.isArray(currentParam[0])){
|
|
2178
|
+
// 1D Parameter
|
|
2179
|
+
for (let r=0; r<savedRows; r++){
|
|
2180
|
+
currentParam[r] = flatData[cursor++];
|
|
2079
2181
|
}
|
|
2080
2182
|
} else {
|
|
2081
|
-
//
|
|
2082
|
-
for (let r=0; r<
|
|
2083
|
-
|
|
2183
|
+
// 2D parameter: populate rows in-place
|
|
2184
|
+
for (let r=0; r<savedRows; r++){
|
|
2185
|
+
const row = currentParam[r];
|
|
2186
|
+
for (let c=0; c<savedCols; c++){
|
|
2187
|
+
row[c] = flatData[cursor++];
|
|
2188
|
+
}
|
|
2084
2189
|
}
|
|
2085
2190
|
}
|
|
2086
2191
|
}
|
|
2087
2192
|
|
|
2088
|
-
//
|
|
2193
|
+
// Restore BatchNorm2d running statistics
|
|
2194
|
+
if (savedLayer.rs && typeof currentLayer.runningMean !== 'undefined'){
|
|
2195
|
+
for (let k=0; k<savedLayer.rs.mean.length; k++){
|
|
2196
|
+
currentLayer.runningMean[k] = savedLayer.rs.mean[k];
|
|
2197
|
+
currentLayer.runningVar[k] = savedLayer.rs.var[k];
|
|
2198
|
+
}
|
|
2199
|
+
}
|
|
2200
|
+
|
|
2201
|
+
// Invalidate cached flat representations
|
|
2089
2202
|
if (typeof currentLayer._updateCache === 'function'){
|
|
2090
2203
|
currentLayer._updateCache();
|
|
2091
2204
|
}
|
|
@@ -2094,7 +2207,7 @@ export function loadModel(model, json){
|
|
|
2094
2207
|
}
|
|
2095
2208
|
|
|
2096
2209
|
console.log(
|
|
2097
|
-
`[JST]
|
|
2210
|
+
`[JST] Model loaded: ${loadedCount} layers restored, ${skippedCount} skipped.`
|
|
2098
2211
|
);
|
|
2099
2212
|
|
|
2100
2213
|
return model;
|
|
File without changes
|