mini-jstorch 2.0.3 → 2.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -1,4 +1,4 @@
1
- ## Mini-JSTorch (v2.0.2)
1
+ ## Mini-JSTorch (v2.0.4)
2
2
 
3
3
  ---
4
4
 
@@ -9,12 +9,10 @@ This project prioritizes `clarity`, `numerical correctness`, and `accessibility`
9
9
 
10
10
  ## Changelog
11
11
 
12
- **v2.0.2:**
13
- - **Fixed critical training bug:** Optimizers (Adam, SGD, AdamW, Lion) now correctly update Linear and Conv2D layer weights
14
- - **Fixed BatchNorm2d:** Inference mode no longer produces NaN for multi-channel inputs
15
- - **Fixed ELU activation:** Backward pass now uses correct derivative formula
16
- - **Fixed saveModel/loadModel:** Now correctly saves and restores all layer types including Conv2D and BatchNorm2d
17
- - **Fixed BatchNorm2d gradient zeroing:** gradWeight/gradBias now correctly reset between batches
12
+ **v2.0.4:**
13
+ - **Optimized saveModel/loadModel:** Complete rewrite with flat 1D serialization using shape metadata, single-cursor deserialization, and pre-allocated arrays. Reduces JSON size by ~40-50% for large models and improves load performance.
14
+ - **Fixed stateDict/loadStateDict:** Now uses `parameters()` for universal layer compatibility. Previously only saved `W` and `b` properties, silently dropping Conv2D and BatchNorm2d weights.
15
+
18
16
 
19
17
  **⚠️ BREAKING CHANGES in v2.0.0:**
20
18
  - Tokenizer API: `tokenizeBatch()` → `transform()`, `detokenizeBatch()` → `inverseTransform()`
@@ -74,7 +72,7 @@ In Browser/Website:
74
72
  <div id="res"></div>
75
73
 
76
74
  <script type="module">
77
- import { Sequential, Linear, ReLU, MSELoss, Adam, StepLR, Tanh } from 'https://unpkg.com/mini-jstorch@1.8.0/index.js';
75
+ import { Sequential, Linear, ReLU, MSELoss, Adam, StepLR, Tanh } from 'https://unpkg.com/mini-jstorch@2.0.4/index.js';
78
76
 
79
77
  async function train() {
80
78
  const statusEl = document.getElementById('status');
@@ -187,7 +185,7 @@ In Browser/Website:
187
185
 
188
186
  ## Node.js
189
187
  ```bash
190
- npm install mini-jstorch
188
+ npm install mini-jstorch@latest
191
189
  ```
192
190
  Node.js v18+ or any modern browser with ES module support is recommended.
193
191
 
@@ -214,13 +212,12 @@ import {
214
212
  const model = new Sequential([
215
213
  new Linear(2, 8),
216
214
  new ReLU(),
217
- new Linear(8, 2) // logits output
215
+ new Linear(8, 2)
218
216
  ]);
219
217
 
220
218
  const X = [
221
219
  [0,0], [0,1], [1,0], [1,1]
222
220
  ];
223
-
224
221
  const Y = [
225
222
  [1,0], [0,1], [0,1], [1,0]
226
223
  ];
@@ -234,14 +231,20 @@ for (let epoch = 1; epoch <= 300; epoch++) {
234
231
  const grad = lossFn.backward();
235
232
  model.backward(grad);
236
233
  optimizer.step();
237
-
238
- // Zero gradients for next iteration
239
234
  model.zeroGrad();
240
235
 
241
236
  if (epoch % 50 === 0) {
242
237
  console.log(`Epoch ${epoch}, Loss: ${loss.toFixed(6)}`);
243
238
  }
244
239
  }
240
+
241
+ console.log('\nResults:');
242
+ const logits = model.forward(X);
243
+ X.forEach((input, i) => {
244
+ const pred = logits[i][0] > logits[i][1] ? 0 : 1;
245
+ const target = Y[i][0] === 1 ? 0 : 1;
246
+ console.log(` [${input}] → class ${pred} (target: ${target}) ${pred === target ? 'TRUE' : 'FALSE'}`);
247
+ });
245
248
  ```
246
249
  `Important:` Do not combine `SoftmaxCrossEntropyLoss` with a `Softmax` layer.
247
250
 
@@ -280,7 +283,6 @@ for (let epoch = 1; epoch <= 300; epoch++) {
280
283
  optimizer.step();
281
284
  model.zeroGrad();
282
285
 
283
- // Print progress every 50 epochs
284
286
  if (epoch % 50 === 0) {
285
287
  const probs = logits.map(p => 1 / (1 + Math.exp(-p[0])));
286
288
  console.log(`Epoch ${epoch} | Loss: ${loss.toFixed(6)}`);
@@ -292,7 +294,6 @@ for (let epoch = 1; epoch <= 300; epoch++) {
292
294
  }
293
295
  }
294
296
 
295
- // Final evaluation
296
297
  console.log("\nTraining Complete\n");
297
298
  model.eval();
298
299
 
@@ -340,7 +341,7 @@ See the `demo/` directory for runnable examples!
340
341
  - `demo/scheduler.js`
341
342
  - `demo/xor_classification.js`
342
343
  - `demo/linear_regression.js`
343
-
344
+ - `demo/saveAndLoadModel.js`
344
345
 
345
346
  ```bash
346
347
  node demo/<fileNameInDemo>.js
@@ -0,0 +1,45 @@
1
+ // demo/saveAndLoadModel.js
2
+ // Trains a linear regression model (y = 3x + 5), saves it to JSON,
3
+ // loads it back into a fresh model, and verifies predictions match.
4
+
5
+ import {
6
+ Sequential, Linear,
7
+ MSELoss, Adam,
8
+ saveModel, loadModel
9
+ } from '../src/jstorch.js';
10
+
11
+ const X = [[1], [2], [3], [4], [5]];
12
+ const y = [[8], [11], [14], [17], [20]];
13
+
14
+ const model = new Sequential([new Linear(1, 1)]);
15
+ const criterion = new MSELoss();
16
+ const optimizer = new Adam(model.parameters(), 0.1);
17
+
18
+ // Train
19
+ for (let e = 1; e <= 500; e++) {
20
+ const pred = model.forward(X);
21
+ const loss = criterion.forward(pred, y);
22
+ model.backward(criterion.backward());
23
+ optimizer.step();
24
+ model.zeroGrad();
25
+ if (e % 100 === 0) console.log(`Epoch ${e} | Loss: ${loss.toFixed(6)}`);
26
+ }
27
+
28
+ // Original predictions
29
+ console.log('\nTrained model predictions:');
30
+ const original = model.forward(X);
31
+ X.forEach((x, i) => console.log(` x=${x[0]} → ${original[i][0].toFixed(4)} (target: ${y[i][0]})`));
32
+
33
+ // Save → Load
34
+ const json = saveModel(model);
35
+ const restored = new Sequential([new Linear(1, 1)]);
36
+ loadModel(restored, json);
37
+
38
+ // Loaded predictions
39
+ console.log('\nLoaded model predictions:');
40
+ const loaded = restored.forward(X);
41
+ X.forEach((x, i) => console.log(` x=${x[0]} → ${loaded[i][0].toFixed(4)} (target: ${y[i][0]})`));
42
+
43
+ // Verify
44
+ const ok = original.every((r, i) => Math.abs(r[0] - loaded[i][0]) < 1e-10);
45
+ console.log(`\n${ok ? 'Model saved and restored successfully.' : 'Mismatch after load.'}`);
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "mini-jstorch",
3
- "version": "2.0.3",
3
+ "version": "2.0.4",
4
4
  "type": "module",
5
5
  "description": "A lightweight JavaScript neural network library for learning AI concepts and rapid Frontend experimentation. PyTorch-inspired, zero dependencies, perfect for educational use.",
6
6
  "main": "index.js",
package/src/jstorch.js CHANGED
@@ -703,13 +703,25 @@ export class Sequential {
703
703
  * Get model state dict (weights and biases)
704
704
  */
705
705
  stateDict(){
706
- const state = {};
707
- this.layers.forEach((layer, idx) => {
708
- if (layer.W) state[`layer_${idx}.weight`] = layer.W;
709
- if (layer.b) state[`layer_${idx}.bias`] = layer.b;
710
- });
711
- return state;
712
- }
706
+ const state = {};
707
+ this.layers.forEach((layer, idx) => {
708
+ const params = layer.parameters ? layer.parameters() : [];
709
+ if (params.length === 0) return;
710
+
711
+ params.forEach((p, pIdx) => {
712
+ // Deep clone parameter data to prevent reference sharing
713
+ const paramData = p.param;
714
+ const is2D = Array.isArray(paramData[0]);
715
+
716
+ if (is2D){
717
+ state[`layer_${idx}.param_${pIdx}`] = paramData.map(row => [...row]);
718
+ } else {
719
+ state[`layer_${idx}.param_${pIdx}`] = [...paramData];
720
+ }
721
+ });
722
+ });
723
+ return state;
724
+ }
713
725
 
714
726
  step(lr){
715
727
  this.layers.forEach(layer => {
@@ -720,24 +732,70 @@ export class Sequential {
720
732
  }
721
733
 
722
734
  /**
723
- * Load state dict
735
+ * Load state dict into model.
736
+ * Mutates parameter in-place to preserve optimizer references.
724
737
  */
725
738
  loadStateDict(stateDict){
726
739
  this.layers.forEach((layer, idx) => {
727
- const weightKey = `layer_${idx}.weight`;
728
- const biasKey = `layer_${idx}.bias`;
729
-
730
- if (layer.W && stateDict[weightKey]){
731
- layer.W = stateDict[weightKey];
732
- // Invalidate cache
733
- if (layer._InvalidateCache) layer._InvalidateCache();
734
- }
735
- if (layer.b && stateDict[biasKey]){
736
- layer.b = stateDict[biasKey];
737
- if (layer._InvalidateCache) layer._InvalidateCache();
738
- }
739
- });
740
- return this;
740
+ const params = layer.parameters ? layer.parameters() : [];
741
+ if (params.length === 0) return;
742
+
743
+ params.forEach((p, pIdx) => {
744
+ const key = `layer_${idx}.param_${pIdx}`;
745
+ const savedData = stateDict[key];
746
+
747
+ if (!savedData){
748
+ console.warn(`[JST WARN]: stateDict missing key: ${key}`);
749
+ return;
750
+ }
751
+
752
+ const currentParam = p.param;
753
+ const is2D = Array.isArray(currentParam[0]);
754
+
755
+ // Validate shape
756
+ const savedRows = savedData.length;
757
+ const currentRows = currentParam.length;
758
+
759
+ if (savedRows !== currentRows){
760
+ console.warn(
761
+ `[JST WARN]: stateDict shape mismatch for ${key} - ` +
762
+ `saved rows: ${savedRows}, current rows: ${currentRows}. Skipping.`
763
+ );
764
+ return;
765
+ }
766
+
767
+ if (is2D){
768
+ const savedCols = savedData[0].length;
769
+ const currentCols = currentParam[0].length;
770
+
771
+ if (savedCols !== currentCols){
772
+ console.warn(
773
+ `[JST WARN]: stateDict shape mismatch for ${key} - ` +
774
+ `saved cols: ${savedCols}, current cols: ${currentCols}. Skipping.`
775
+ );
776
+ return;
777
+ }
778
+
779
+ // Mutate in-place to preserve optimizer References
780
+ for (let r=0; r<currentRows; r++){
781
+ for (let c=0; c<currentCols; c++){
782
+ currentParam[r][c] = savedData[r][c];
783
+ }
784
+ }
785
+ } else {
786
+ // 1D parameter
787
+ for (let r=0; r<currentRows; r++){
788
+ currentParam[r] = savedData[r];
789
+ }
790
+ }
791
+ });
792
+
793
+ // Invalidate cached flat representations
794
+ if(typeof layer._updateCache === 'function'){
795
+ layer._updateCache();
796
+ }
797
+ });
798
+ return this;
741
799
  }
742
800
  }
743
801
 
@@ -1225,7 +1283,10 @@ export class CrossEntropyLoss{
1225
1283
  return this._impl.forward(logits, targets);
1226
1284
  }
1227
1285
  backward(){
1228
- return this._impl.backward();
1286
+ if (!this._impl.probs){
1287
+ throw new Error("CrossEntropyLoss: backward() called before forward()");
1288
+ }
1289
+ return this._impl.backward();
1229
1290
  }
1230
1291
  }
1231
1292
 
@@ -1959,40 +2020,76 @@ export class BatchNorm2d {
1959
2020
  eval() { this.training = false; }
1960
2021
  }
1961
2022
 
1962
- // ---------------------- Model Save/Load (BETA) ----------------------
2023
+ // ---------------------- Model Save/Load (OPTIMIZED) ----------------------
2024
+ /**
2025
+ * Serialize model parameters to a compact flat JSON string.
2026
+ * Flattens all 2D parameter matrices into 1D arrays with shape metadata.
2027
+ */
1963
2028
  export function saveModel(model){
1964
- if(!(model instanceof Sequential)){
2029
+ if (!(model instanceof Sequential)){
1965
2030
  throw new Error("saveModel supports only Sequential models");
1966
2031
  }
1967
2032
 
1968
- const state = {
1969
- version: "2.0.0",
1970
- layers: model.layers.map((layer, idx) => {
1971
- const params = layer.parameters ? layer.parameters() : [];
2033
+ const layers = [];
2034
+
2035
+ for (let i=0; i<model.layers.length; i++){
2036
+ const layer = model.layers[i];
2037
+ const params = layer.parameters ? layer.parameters() : [];
2038
+ const serializedParams = [];
2039
+
2040
+ for (let j=0; j<params.length; j++){
2041
+ const param = params[j].param;
2042
+ const rows = param.length;
2043
+ const is2D = Array.isArray(param[0]);
2044
+ const cols = is2D ? param[0].length : 1;
1972
2045
 
1973
- if (params.length === 0){
1974
- return { type: layer.constructor.name, params: [] };
2046
+ let flat;
2047
+
2048
+ if (!is2D){
2049
+ // 1D parameter: native slice for bulk memory copy
2050
+ flat = param.slice();
2051
+ } else {
2052
+ // 2D parameter: flatten row by row
2053
+ const total = rows * cols;
2054
+ flat = new Array(total);
2055
+ let cursor = 0;
2056
+ for (let r=0; r<rows; r++){
2057
+ const row = param[r];
2058
+ for (let c=0; c<cols; c++){
2059
+ flat[cursor++] = row[c];
2060
+ }
2061
+ }
1975
2062
  }
1976
2063
 
1977
- return {
1978
- type: layer.constructor.name,
1979
- params: params.map(p => ({
1980
- // Deep clone parameter data
1981
- data: p.param.map(row =>
1982
- Array.isArray(row) ? [...row] : row
1983
- ),
1984
- // Preserve shape metadata for validation
1985
- shape: Array.isArray(p.param[0])
1986
- ? [p.param.length, p.param[0].length]
1987
- : [p.param.length]
1988
- }))
2064
+ serializedParams.push({
2065
+ s: [rows, cols],
2066
+ d: flat
2067
+ });
2068
+ }
2069
+
2070
+ // Extract running stats for BatchNorm2d layers
2071
+ let runningStats = null;
2072
+ if (typeof layer.runningMean !== 'undefined' && typeof layer.runningVar !== 'undefined'){
2073
+ runningStats = {
2074
+ mean: layer.runningMean.slice(),
2075
+ var: layer.runningVar.slice()
1989
2076
  };
1990
- })
1991
- };
2077
+ }
2078
+
2079
+ layers.push({
2080
+ t: layer.constructor.name,
2081
+ p: serializedParams,
2082
+ rs: runningStats
2083
+ });
2084
+ }
1992
2085
 
1993
- return JSON.stringify(state);
2086
+ return JSON.stringify({ ver: "2.0.4", layers: layers });
1994
2087
  }
1995
2088
 
2089
+ /**
2090
+ * Load serialize flat parameters into a model.
2091
+ * Uses pre-allocated arrays and single-cursor deserialization.
2092
+ */
1996
2093
  export function loadModel(model, json){
1997
2094
  if (!(model instanceof Sequential)){
1998
2095
  throw new Error("loadModel supports only Sequential models");
@@ -2000,92 +2097,108 @@ export function loadModel(model, json){
2000
2097
 
2001
2098
  const state = JSON.parse(json);
2002
2099
 
2003
- // Validate structure
2004
2100
  if (!state.layers || !Array.isArray(state.layers)){
2005
2101
  throw new Error("loadModel: invalid save format - missing 'layers' array");
2006
2102
  }
2007
2103
 
2104
+ const layerCount = Math.min(state.layers.length, model.layers.length);
2105
+
2008
2106
  if (state.layers.length !== model.layers.length){
2009
2107
  console.warn(
2010
2108
  `[JST WARN]: Layer count mismatch - saved ${state.layers.length},` +
2011
- `current model has ${model.layers.length}. Loading what matches.`
2109
+ `current model has ${model.layers.length}. Loading ${layerCount} layers.`
2012
2110
  );
2013
2111
  }
2014
2112
 
2015
2113
  let loadedCount = 0;
2016
2114
  let skippedCount = 0;
2017
2115
 
2018
- for(let i=0; i<Math.min(state.layers.length, model.layers.length); i++){
2116
+ for (let i=0; i<layerCount; i++){
2019
2117
  const savedLayer = state.layers[i];
2020
2118
  const currentLayer = model.layers[i];
2021
2119
 
2022
- if(savedLayer.params.length === 0){
2023
- // Layer with no trainable params - skip
2024
- continue
2120
+ if (!savedLayer.p || savedLayer.p.length === 0){
2121
+ // Still restore running stats even if no trainable params
2122
+ if (savedLayer.rs && typeof currentLayer.runningMean !== 'undefined'){
2123
+ for (let k=0; k<savedLayer.rs.mean.length; k++){
2124
+ currentLayer.runningMean[k] = savedLayer.rs.mean[k];
2125
+ currentLayer.runningVar[k] = savedLayer.rs.var[k];
2126
+ }
2127
+ }
2128
+ continue;
2025
2129
  }
2026
2130
 
2027
2131
  // Validate layer type
2028
- if (savedLayer.type !== currentLayer.constructor.name){
2132
+ if (savedLayer.t !== currentLayer.constructor.name){
2029
2133
  console.warn(
2030
2134
  `[JST WARN]: Layer ${i} type mismatch - ` +
2031
- `saved: ${savedLayer.type}, current: ${currentLayer.constructor.name}. Skipping.`
2135
+ `saved: ${savedLayer.t}, current: ${currentLayer.constructor.name}. Skipping.`
2032
2136
  );
2033
2137
  skippedCount++;
2034
2138
  continue;
2035
2139
  }
2036
2140
 
2037
- // Get current layer parameters
2038
2141
  const currentParams = currentLayer.parameters ? currentLayer.parameters() : [];
2039
2142
 
2040
- if (currentParams.length !== savedLayer.params.length){
2143
+ if (currentParams.length !== savedLayer.p.length){
2041
2144
  console.warn(
2042
2145
  `[JST WARN]: Layer ${i} parameter count mismatch - ` +
2043
- `saved: ${savedLayer.params.length}, current: ${currentParams.length}. Skipping.`
2146
+ `saved: ${savedLayer.p.length}, current: ${currentParams.length}. Skipping.`
2044
2147
  );
2045
2148
  skippedCount++;
2046
2149
  continue;
2047
2150
  }
2048
2151
 
2049
- // Load parameters wiht shape validation
2050
- for (let j=0; j<savedLayer.params.length; j++){
2051
- const savedParam = savedLayer.params[j];
2052
- const currentParam = currentParams[j].param;
2152
+ for (let j=0; j<savedLayer.p.length; j++){
2153
+ const savedParam = savedLayer.p[j];
2154
+ const savedRows = savedParam.s[0];
2155
+ const savedCols = savedParam.s[1];
2156
+ const flatData = savedParam.d;
2053
2157
 
2054
- // Validate shape
2158
+ const currentParam = currentParams[j].param;
2055
2159
  const currentRows = currentParam.length;
2056
2160
  const currentCols = Array.isArray(currentParam[0])
2057
2161
  ? currentParam[0].length
2058
2162
  : 1;
2059
2163
 
2060
- const savedRows = savedParam.shape[0];
2061
- const savedCols = savedParam.shape[1] || 1;
2062
-
2164
+ // Validate shape
2063
2165
  if (currentRows !== savedRows || currentCols !== savedCols){
2064
2166
  console.warn(
2065
2167
  `[JST WARN]: Layer ${i} param ${j} shape mismatch - ` +
2066
- `saved: [${savedRows}, ${savedCols}],` +
2067
- `current: [${currentRows}, ${currentCols}]. Skipping this parameter.`
2168
+ `saved: [${savedRows}, ${savedCols}], ` +
2169
+ `current: [${currentRows}, ${currentCols}]. Skipping.`
2068
2170
  );
2069
- continue
2171
+ continue;
2070
2172
  }
2071
2173
 
2072
- // Copy parameter data
2073
- if (Array.isArray(currentParam[0])){
2074
- // 2D Parameter
2075
- for (let r=0; r<currentRows; r++){
2076
- for (let c=0; c<currentCols; c++){
2077
- currentParam[r][c] = savedParam.data[r][c];
2078
- }
2174
+ // Single-cursor deserialization into pre-existing arrays
2175
+ let cursor =0;
2176
+
2177
+ if (savedCols === 1 && !Array.isArray(currentParam[0])){
2178
+ // 1D Parameter
2179
+ for (let r=0; r<savedRows; r++){
2180
+ currentParam[r] = flatData[cursor++];
2079
2181
  }
2080
2182
  } else {
2081
- // 1D parameter
2082
- for (let r=0; r<currentRows; r++){
2083
- currentParam[r] = savedParam.data[r];
2183
+ // 2D parameter: populate rows in-place
2184
+ for (let r=0; r<savedRows; r++){
2185
+ const row = currentParam[r];
2186
+ for (let c=0; c<savedCols; c++){
2187
+ row[c] = flatData[cursor++];
2188
+ }
2084
2189
  }
2085
2190
  }
2086
2191
  }
2087
2192
 
2088
- // Invalidate any cached flat representations
2193
+ // Restore BatchNorm2d running statistics
2194
+ if (savedLayer.rs && typeof currentLayer.runningMean !== 'undefined'){
2195
+ for (let k=0; k<savedLayer.rs.mean.length; k++){
2196
+ currentLayer.runningMean[k] = savedLayer.rs.mean[k];
2197
+ currentLayer.runningVar[k] = savedLayer.rs.var[k];
2198
+ }
2199
+ }
2200
+
2201
+ // Invalidate cached flat representations
2089
2202
  if (typeof currentLayer._updateCache === 'function'){
2090
2203
  currentLayer._updateCache();
2091
2204
  }
@@ -2094,7 +2207,7 @@ export function loadModel(model, json){
2094
2207
  }
2095
2208
 
2096
2209
  console.log(
2097
- `[JST]: Model loaded: ${loadedCount} layers restored, ${skippedCount} skipped.`
2210
+ `[JST] Model loaded: ${loadedCount} layers restored, ${skippedCount} skipped.`
2098
2211
  );
2099
2212
 
2100
2213
  return model;
File without changes