mini-jstorch 1.5.0 → 1.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/Docs/About.md ADDED
@@ -0,0 +1,83 @@
1
+ # Mini-JSTorch — Technical Information
2
+
3
+ ---
4
+
5
+ ## General Information
6
+
7
+ - **Project Name:** Mini-JSTorch
8
+ - **Internal Name:** JST (JSTorch)
9
+
10
+ > Note:
11
+ > Early versions of Mini-JSTorch do not strictly follow semantic versioning conventions
12
+ > (e.g. `0.0.1` for patches, `0.1.0` for minor releases, `1.0.0` for major releases).
13
+ > This inconsistency reflects the early learning and experimental phase of the project.
14
+
15
+ ---
16
+
17
+ ## 1. Engine Architecture Limitations (JSTorch Core)
18
+
19
+ This section outlines the known structural weaknesses of the JSTorch engine.
20
+ Despite its apparent complexity, the architecture is currently fragile and tightly coupled.
21
+
22
+ ### Identified Limitations
23
+
24
+ - **High dependency on Utilities**
25
+ Every core class depends directly on the Utilities module, which is defined at the top of the `jstorch.js` file. This creates strong coupling across the engine.
26
+
27
+ - **Limited Tensor dimensionality**
28
+ Tensor implementations currently support only two dimensions.
29
+ Extending support to higher-dimensional tensors would require significant architectural changes due to the existing complexity (see the sketch after this list).
30
+
31
+ - **Uneven class complexity**
32
+ New or recently modified classes often become significantly more complex than others, leading to inconsistency in maintainability and internal design balance.
33
+
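+ The user-facing `fu_tensor` helper enforces this two-dimensional constraint at construction time. A minimal sketch (the package-root import path is an assumption):
+ 
+ ```javascript
+ import { fu_tensor } from "mini-jstorch";
+ 
+ const ok = fu_tensor([[1, 2], [3, 4]]); // 2-D data: supported
+ console.log(ok.data);                   // [[1, 2], [3, 4]]
+ 
+ // Anything that is not a 2-D array is rejected up front:
+ // fu_tensor([1, 2, 3]); // throws "fu_tensor: Data must be 2D array"
+ ```
+ 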
34
+ ---
35
+
36
+ ## 2. Rationale Behind the `fu_` Utilities
37
+
38
+ This section explains why the `fu_` utilities were introduced despite the existence of internal Utilities.
39
+
40
+ ### Issues with Internal Utilities
41
+
42
+ - The Utilities defined at the beginning of `jstorch.js` are **internal engine helpers**, not intended for direct user interaction.
43
+
44
+ - These Utilities are heavily reused across multiple core classes.
45
+ Any modification to a utility function may trigger **cascading (domino) errors** throughout the engine due to tight dependencies.
46
+
47
+ - Some utility functions intentionally diverge from standard or expected formulas.
48
+ For example:
49
+ - Expected formula:
50
+ `Param1 - Param2 * Param3`
51
+ - Internal Utilities implementation:
52
+ `Param1 - Param2 * Param3 + Param4`
53
+
54
+ This behavior exists because internal Utilities are optimized for class-level computations, not for user-facing correctness or predictability.
55
+
56
+ ### Purpose of `fu_` Utilities
57
+
58
+ The `fu_` utilities were designed to improve the **user experience** by providing:
59
+
60
+ - Predictable and correct computational behavior
61
+ - User-friendly and stable helper functions
62
+ - Isolation from internal engine changes
63
+ - Reduced risk of incorrect outputs and dependency-based cascading errors
64
+
65
+ In short, `fu_` exists to ensure safety, clarity, and consistency for end users of Mini-JSTorch.
66
+
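+ As a concrete illustration, the `fu_` helpers validate inputs and shapes before computing, so mistakes fail fast with a clear error instead of cascading through the engine. A minimal sketch, assuming the package-root import path:
+ 
+ ```javascript
+ import { fu_tensor, fu_add, fu_matmul } from "mini-jstorch";
+ 
+ const a = fu_tensor([[1, 2]]);
+ const b = fu_tensor([[3, 4]]);
+ console.log(fu_add(a, b).data);    // [[4, 6]]
+ 
+ const w = fu_tensor([[3], [4]]);
+ console.log(fu_matmul(a, w).data); // [[11]]  (1*3 + 2*4)
+ ```
+ 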
67
+ ---
68
+
69
+ ## 3. SJK (Shortcut JST Keywords) Reference
70
+
71
+ This section lists commonly used abbreviations and keywords within the Mini-JSTorch ecosystem.
72
+
73
+ **Format:** `"KEYWORD" : "Full Name / Meaning"`
74
+
75
+ - `"JST"` : JSTorch
76
+ - `"fu"` : For User / User-Friendly
77
+ - `"fun"` : Function
78
+ - `"Dummy"` : Experimental
79
+ - `"Exp"` : Restricted experimental entity
80
+ - `"msg"` : Message, comment, warning, announcement
81
+ - `"donot"` : Do not / Don't
82
+
83
+ ---
@@ -0,0 +1,129 @@
1
+ # Project File Structure
2
+
3
+ This document describes the directory and file structure of the **mini-JSTorch** package.
4
+ It provides an overview of how the project is organized and the purpose of each major component.
5
+
6
+ ---
7
+
8
+ ## Repository Overview
9
+
10
+ ```text
11
+ mini-jstorch/
12
+ ├── demo/
13
+ │   ├── fu_fun.js
14
+ │   ├── MakeModel.js
15
+ │   └── scheduler.js
16
+ ├── Docs/
17
+ │   ├── About.md
18
+ │   └── Structure.md
19
+ ├── src/
20
+ │   ├── jstorch.js
21
+ │   └── Dummy/
22
+ │       └── msg/
23
+ ├── index.js
24
+ ├── package.json
25
+ └── README.md
26
+ ```
27
+
28
+ ---
29
+
30
+ ## Directory Descriptions
31
+
32
+ `/demo`
33
+
34
+ - Contains demonstration and testing files.
35
+
36
+ - Used for unit testing, quick system checks, and example usage
37
+ - Intended for users who prefer practical examples over reading full API documentation
38
+ - Allows testing features without writing extensive manual code
39
+
40
+ `/Docs`
41
+
42
+ - Contains detailed documentation related to the mini-JSTorch package.
43
+
44
+ - Provides deeper explanations of internal design and usage
45
+ - Intended for contributors and advanced users
46
+
47
+ `/src`
48
+
49
+ - Contains the source code of the JSTorch engine.
50
+
51
+ - Houses all core logic and internal implementations
52
+ - Modifications in this directory directly affect engine behavior
53
+
54
+ `/src/Dummy`
55
+
56
+ - Experimental and restricted directory.
57
+
58
+ - Used for experimental purposes and future development
59
+ - Files inside this directory may be unstable or incomplete
60
+ - Not intended for public or production use
61
+
62
+ `/src/Dummy/msg`
63
+
64
+ - Contains warning or message files.
65
+
66
+ - Indicates that files within the `Dummy` directory are restricted
67
+ - Serves as a notification mechanism for experimental or future-update-related content
68
+
69
+ ---
70
+
71
+ ## File Descriptions
72
+
73
+ `/demo/fu_fun.js`
74
+
75
+ - Purpose: Tests all user-facing (`fu_`) functions
76
+ - Notes: Focuses on friendly and predictable helper utilities
77
+
78
+ `/demo/MakeModel.js`
79
+
80
+ - Purpose: Demonstrates creation of a simple model
81
+ - Notes: Uses the `StepLR` scheduler as part of the example workflow
82
+
83
+ `/demo/scheduler.js`
84
+
85
+ - Purpose: Tests scheduler-related functionality
86
+ - Notes: Intended to validate learning rate scheduling behavior
87
+
88
+ `/Docs/About.md`
89
+
90
+ - Purpose: Contains additional information about the mini-JSTorch package
91
+ - Notes: May include background, design decisions, or non-API-related explanations
92
+
93
+ `/Docs/Structure.md`
94
+
95
+ - Purpose: Documents the repository file and folder structure
96
+ - Notes: The file you are currently reading
97
+
98
+ `/src/jstorch.js`
99
+
100
+ - Purpose: Core engine implementation
101
+
102
+ - Notes:
103
+
104
+ - Contains all JSTorch engine logic and functions
105
+ - Central file of the entire package
106
+ - Changes here have wide-ranging effects
107
+
108
+ `index.js`
109
+
110
+ - Purpose: Package entry point
111
+ - Notes: Exposes public APIs and connects internal modules
112
+
113
+ `package.json`
114
+
115
+ - Purpose: Project configuration and metadata
116
+ - Notes: Defines dependencies, scripts, and package information
117
+
118
+ `README.md`
119
+
120
+ - Purpose: Main documentation entry
121
+ - Notes: Provides overview, installation instructions, and basic usage
122
+
123
+ **Notes**
124
+
125
+ - Experimental files may change or be restricted without notice
126
+ - Users are encouraged to rely on public APIs and documented utilities
127
+ - Internal structures are subject to refactoring as the project evolves
128
+
129
+ ---
package/README.md CHANGED
@@ -4,17 +4,17 @@ A lightweight JavaScript neural network library for rapid frontend AI experiment
4
4
 
5
5
  ## Overview
6
6
 
7
- Mini-JSTorch is a high-performance, minimalist JavaScript library for building neural networks. It runs efficiently in both frontend and backend environments, including low-end devices. The library enables quick experimentation and learning in AI without compromising stability, accuracy, or training reliability.
7
+ Mini-JSTorch is a high-performance, minimalist JavaScript library for building neural networks. It runs efficiently in frontend environments, including on low-end devices. The library enables quick experimentation and learning in AI without compromising stability, accuracy, or training reliability.
8
8
 
9
- This release, **version 1.5.0:** Adds user-friendly tensor functions, Flatten layer, and improved Conv2D operations and Modify Some Class
10
- For Architecture Compability.
9
+ This release, **version 1.7.0**, introduces a **Softmax layer**, a **Tokenizer**, the **AdamW optimizer**, and enhanced NLP capabilities.
11
10
 
12
11
  ---
13
12
 
14
13
  ## New Features Highlights
15
14
 
16
- - **User-Friendly Tensor API:** New `fu_` functions (`fu_tensor`, `fu_add`, `fu_matmul`, etc.) with automatic validation and shape checking
17
- - **Flatten Layer:** Essential for connecting CNN architectures to dense layers.
15
+ - **Softmax Layer:** Professional classification output with proper gradient computation
16
+ - **Tokenizer:** Lightweight text preprocessing for NLP tasks
17
+ - **AdamW Optimizer:** Modern optimizer with decoupled weight decay
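+ 
+ A short sketch of the three additions (token ids are illustrative; the package-root import path is an assumption):
+ 
+ ```javascript
+ import { Tokenizer, Softmax, AdamW } from "mini-jstorch";
+ 
+ // Tokenizer: fit a vocabulary, then tokenize with padding
+ const tok = new Tokenizer(50);
+ tok.fit(["the cat sat", "the dog sat"]);
+ console.log(tok.tokenizeBatch(["the cat"], 4)); // e.g. [[1, 3, 0, 0]]
+ 
+ // Softmax: each row is normalized into probabilities
+ const soft = new Softmax();
+ console.log(soft.forward([[1, 2, 3]])); // one row summing to 1
+ 
+ // AdamW: decoupled weight decay, same param objects as Adam/SGD
+ const param = { param: [[1, 2], [3, 4]], grad: [[0.1, 0.1], [0.1, 0.1]] };
+ const opt = new AdamW([param], { lr: 0.001, weight_decay: 0.01 });
+ opt.step();
+ ```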
18
18
 
19
19
  ---
20
20
 
@@ -23,12 +23,17 @@ For Architecture Compability.
23
23
  - **Layers:** Linear, Flatten, Conv2D
24
24
  - **Activations:** ReLU, Sigmoid, Tanh, LeakyReLU, GELU, Mish, SiLU, ELU
25
25
  - **Loss Functions:** MSELoss, CrossEntropyLoss
26
- - **Optimizers:** Adam, SGD
27
- - **Schedulers:** StepLR, LambdaLR
26
+ - **Optimizers:** Adam, SGD, LION, **AdamW**
27
+ - **Schedulers:** StepLR, LambdaLR, ReduceLROnPlateau
28
28
  - **Regularization:** Dropout, BatchNorm2D
29
29
  - **Utilities:** zeros, randomMatrix, softmax, crossEntropy, dot, addMatrices, reshape, stack, flatten, eye, concat
30
30
  - **Model Container:** Sequential (for stacking layers with forward/backward passes)
31
31
 
32
+ ### Others
33
+
34
+ - **Tokenizer**
35
+ - **Softmax Layer**
36
+
32
37
  ---
33
38
 
34
39
  ## Installation
@@ -43,7 +48,7 @@ npm install mini-jstorch
43
48
  ## Quick Start Example
44
49
 
45
50
  ```javascript
46
- import { Sequential, Linear, ReLU, Sigmoid, CrossEntropyLoss, Adam, StepLR } from './jstorch.js';
51
+ import { Sequential, Linear, ReLU, Sigmoid, CrossEntropyLoss, Adam, StepLR } from './src/jstorch.js';
47
52
 
48
53
  // Build model
49
54
  const model = new Sequential([
@@ -101,21 +106,23 @@ loadModel(model2, json);
101
106
 
102
107
  ## Demos & Testing
103
108
 
104
- Check the `tests/` directory for ready-to-run demos:
105
- - **tests/MakeModel.js:** Build and run a simple neural network.
106
- - **tests/scheduler.js:** Experiment with learning rate schedulers.
109
+ Check the `demo/` directory for ready-to-run demos:
110
+ - **demo/MakeModel.js:** Build and run a simple neural network.
111
+ - **demo/scheduler.js:** Experiment with learning rate schedulers.
112
+ - **demo/fu_fun.js:** Test all user-friendly (`fu_`, i.e. "For User") functions.
107
113
  - Add your own scripts for quick prototyping!
108
114
 
109
115
  ```bash
110
- node tests/MakeModel.js
111
- node tests/scheduler.js
116
+ node demo/MakeModel.js
117
+ node demo/scheduler.js
118
+ node demo/fu_fun.js
112
119
  ```
113
120
 
114
121
  ---
115
122
 
116
123
  ## Intended Use Cases
117
124
 
118
- - Rapid prototyping of neural networks in frontend and backend.
125
+ - Rapid prototyping of neural networks in the frontend.
119
126
  - Learning and teaching foundational neural network concepts.
120
127
  - Experimentation on low-end devices or mobile browsers.
121
128
  - Lightweight AI projects without GPU dependency.
@@ -126,22 +133,4 @@ node tests/scheduler.js
126
133
 
127
134
  `MIT License`
128
135
 
129
- **Copyright (c) 2025 rizal-editors**
130
-
131
- Permission is hereby granted, free of charge, to any person obtaining a copy
132
- of this software and associated documentation files (the "Software"), to deal
133
- in the Software without restriction, including without limitation the rights
134
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
135
- copies of the Software, and to permit persons to whom the Software is
136
- furnished to do so, subject to the following conditions:
137
-
138
- The above copyright notice and this permission notice shall be included in all
139
- copies or substantial portions of the Software.
140
-
141
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
142
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
143
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
144
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
145
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
146
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
147
- SOFTWARE.
136
+ **Copyright (c) 2024 rizal-editors**
@@ -5,68 +5,68 @@ import {
5
5
  } from '../src/jstorch.js';
6
6
 
7
7
  function testAllFuFunctions() {
8
- console.log("🧪 TESTING ALL FU_FUNCTIONS\n");
8
+ console.log("TESTING ALL FU_FUNCTIONS\n");
9
9
 
10
10
  // Test 1: fu_tensor
11
11
  console.log("1. fu_tensor");
12
12
  const t1 = fu_tensor([[1, 2], [3, 4]]);
13
- console.log("", t1.data);
13
+ console.log("", t1.data);
14
14
 
15
15
  // Test 2: fu_add
16
16
  console.log("\n2. fu_add");
17
17
  const a = fu_tensor([[1, 2]]);
18
18
  const b = fu_tensor([[3, 4]]);
19
19
  const c = fu_add(a, b);
20
- console.log("", a.data, "+", b.data, "=", c.data);
20
+ console.log("", a.data, "+", b.data, "=", c.data);
21
21
 
22
22
  // Test 3: fu_mul
23
23
  console.log("\n3. fu_mul");
24
24
  const d = fu_mul(a, b);
25
- console.log("", a.data, "*", b.data, "=", d.data);
25
+ console.log("", a.data, "*", b.data, "=", d.data);
26
26
 
27
27
  // Test 4: fu_matmul
28
28
  console.log("\n4. fu_matmul");
29
29
  const e = fu_tensor([[1, 2]]);
30
30
  const f = fu_tensor([[3], [4]]);
31
31
  const g = fu_matmul(e, f);
32
- console.log("matmul =", g.data);
32
+ console.log("matmul =", g.data);
33
33
 
34
34
  // Test 5: fu_sum & fu_mean
35
35
  console.log("\n5. fu_sum & fu_mean");
36
36
  const h = fu_tensor([[1, 2], [3, 4]]);
37
37
  const sum = fu_sum(h);
38
38
  const mean = fu_mean(h);
39
- console.log("sum =", sum.data, "mean =", mean.data);
39
+ console.log("sum =", sum.data, "mean =", mean.data);
40
40
 
41
41
  // Test 6: fu_relu
42
42
  console.log("\n6. fu_relu");
43
43
  const i = fu_tensor([[-1, 0], [1, 2]]);
44
44
  const relu = fu_relu(i);
45
- console.log("relu =", relu.data);
45
+ console.log("relu =", relu.data);
46
46
 
47
47
  // Test 7: fu_sigmoid
48
48
  console.log("\n7. fu_sigmoid");
49
49
  const sigmoid = fu_sigmoid(i);
50
- console.log("sigmoid =", sigmoid.data);
50
+ console.log("sigmoid =", sigmoid.data);
51
51
 
52
52
  // Test 8: fu_tanh
53
53
  console.log("\n8. fu_tanh");
54
54
  const tanh = fu_tanh(i);
55
- console.log("tanh =", tanh.data);
55
+ console.log("tanh =", tanh.data);
56
56
 
57
57
  // Test 9: fu_softmax
58
58
  console.log("\n9. fu_softmax");
59
59
  const j = fu_tensor([[1, 2, 3]]);
60
60
  const softmax = fu_softmax(j);
61
- console.log("softmax =", softmax.data);
61
+ console.log("softmax =", softmax.data);
62
62
 
63
63
  // Test 10: fu_flatten & fu_reshape
64
64
  console.log("\n10. fu_flatten & fu_reshape");
65
65
  const k = fu_tensor([[1, 2], [3, 4]]);
66
66
  const flat = fu_flatten(k);
67
67
  const reshaped = fu_reshape(flat, 1, 4);
68
- console.log("flatten =", flat.data);
69
- console.log("reshape =", reshaped.data);
68
+ console.log("flatten =", flat.data);
69
+ console.log("reshape =", reshaped.data);
70
70
  }
71
71
 
72
72
  testAllFuFunctions();
@@ -0,0 +1,69 @@
1
+ // Example: Test ALL learning rate schedulers in mini-jstorch with mini-jstorch optimizers
2
+
3
+ import { SGD, StepLR, LambdaLR, ReduceLROnPlateau, Tensor } from "../src/jstorch.js";
4
+
5
+ const param = { param: [[1, 2], [3, 4]], grad: [[0, 0], [0, 0]] };
6
+ const optimizer = new SGD([param], 0.1);
7
+
8
+ // --- Test StepLR ---
9
+ console.log("Testing StepLR...");
10
+ const stepScheduler = new StepLR(optimizer, 3, 0.5);
11
+ for (let epoch = 1; epoch <= 10; epoch++) {
12
+ stepScheduler.step();
13
+ console.log(`Epoch ${epoch}: LR = ${optimizer.lr.toFixed(4)}`);
14
+ }
15
+
16
+ // --- Test LambdaLR ---
17
+ console.log("\nTesting LambdaLR...");
18
+ optimizer.lr = 0.1; // Reset LR
19
+ const lambdaScheduler = new LambdaLR(optimizer, epoch => 1.0 / (1 + epoch));
20
+ for (let epoch = 1; epoch <= 5; epoch++) {
21
+ lambdaScheduler.step();
22
+ console.log(`Epoch ${epoch}: LR = ${optimizer.lr.toFixed(4)}`);
23
+ }
24
+
25
+ // --- Test ReduceLROnPlateau ---
26
+ console.log("\nTesting ReduceLROnPlateau...");
27
+ optimizer.lr = 0.1; // Reset LR
28
+ const plateauScheduler = new ReduceLROnPlateau(optimizer, {
29
+ patience: 2,
30
+ factor: 0.5,
31
+ min_lr: 0.01,
32
+ verbose: true
33
+ });
34
+
35
+ // Simulate training with plateauing loss
36
+ const losses = [0.9, 0.8, 0.7, 0.69, 0.68, 0.68, 0.68, 0.67, 0.67, 0.67];
37
+ console.log("Simulated training with plateauing loss:");
38
+ for (let epoch = 0; epoch < losses.length; epoch++) {
39
+ plateauScheduler.step(losses[epoch]);
40
+ console.log(`Epoch ${epoch + 1}: Loss = ${losses[epoch].toFixed(3)}, LR = ${optimizer.lr.toFixed(4)}, Wait = ${plateauScheduler.wait}`);
41
+ }
42
+
43
+ // --- Test ReduceLROnPlateau with Cooldown ---
44
+ console.log("\nTesting ReduceLROnPlateau with Cooldown...");
45
+ optimizer.lr = 0.1; // Reset LR
46
+ const plateauWithCooldown = new ReduceLROnPlateau(optimizer, {
47
+ patience: 2,
48
+ factor: 0.5,
49
+ min_lr: 0.01,
50
+ cooldown: 2,
51
+ verbose: true
52
+ });
53
+
54
+ // Simulate training with multiple plateaus
55
+ const losses2 = [0.9, 0.9, 0.9, 0.9, 0.8, 0.8, 0.8, 0.8, 0.7, 0.7];
56
+ console.log("Simulated training with cooldown:");
57
+ for (let epoch = 0; epoch < losses2.length; epoch++) {
58
+ plateauWithCooldown.step(losses2[epoch]);
59
+ console.log(`Epoch ${epoch + 1}: Loss = ${losses2[epoch].toFixed(3)}, LR = ${optimizer.lr.toFixed(4)}, Wait = ${plateauWithCooldown.wait}, Cooldown = ${plateauWithCooldown.cooldown_counter}`);
60
+ }
61
+
62
+ // --- Summary ---
63
+ console.log("\nSCHEDULER SUMMARY:");
64
+ console.log(`StepLR: ${stepScheduler.last_epoch} epochs processed`);
65
+ console.log(`LambdaLR: ${lambdaScheduler.last_epoch} epochs processed`);
66
+ console.log(`ReduceLROnPlateau: ${plateauScheduler.num_reductions} LR reductions`);
67
+ console.log(`ReduceLROnPlateau with Cooldown: ${plateauWithCooldown.num_reductions} LR reductions`);
68
+
69
+ console.log("\nAll schedulers tested successfully!");
package/index.js CHANGED
@@ -1,2 +1,3 @@
1
1
  // package root
2
+ // * = all exports from src/jstorch.js
2
3
  export * from "./src/jstorch.js";
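+ // For example (assuming the published package name), consumers can write:
+ //   import { Sequential, Linear, AdamW } from "mini-jstorch";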
package/package.json CHANGED
@@ -1,22 +1,24 @@
1
1
  {
2
2
  "name": "mini-jstorch",
3
- "version": "1.5.0",
3
+ "version": "1.7.0",
4
4
  "type": "module",
5
- "description": "A lightweight JavaScript neural network library for rapid frontend AI experimentation on low-resource devices Inspired by PyTorch.",
5
+ "description": "A lightweight JavaScript neural network library for learning AI concepts and rapid Frontend experimentation. PyTorch-inspired, zero dependencies, perfect for educational use.",
6
6
  "main": "index.js",
7
7
  "keywords": [
8
8
  "neural-network",
9
9
  "javascript",
10
10
  "lightweight-torch",
11
11
  "lightweight",
12
- "small",
12
+ "small-torch",
13
13
  "javascript-torch",
14
- "ai",
14
+ "ai-model",
15
15
  "jstorch",
16
16
  "pytorch",
17
17
  "front-end",
18
18
  "machine-learning",
19
- "mini"
19
+ "tiny-ml",
20
+ "frontend-ai",
21
+ "mini-neural-network"
20
22
  ],
21
23
  "author": "Rizal",
22
24
  "license": "MIT",
package/src/jstorch.js CHANGED
@@ -1,9 +1,9 @@
1
1
  /*!
2
2
  * Project: mini-jstorch
3
- * File: MainEngine.js
4
- * Author: M. Rizal H. (Actual Author Name)
3
+ * File: jstorch.js
4
+ * Author: Rizal-editors
5
5
  * License: MIT
6
- * Copyright (C) 2025 M. Rizal H.
6
+ * Copyright (C) 2025 Rizal-editors
7
7
  *
8
8
  * Permission is hereby granted, free of charge, to any person obtaining a copy
9
9
  * of this software and associated documentation files (the "Software"), to deal
@@ -24,8 +24,7 @@
24
24
  * SOFTWARE.
25
25
  */
26
26
 
27
-
28
- // ---------------------- Utilities ----------------------
27
+ // ---------------------- DONOT USE THESE (ENGINE INTERNALS) ----------------------
29
28
  export function zeros(rows, cols) {
30
29
  return Array.from({length:rows},()=>Array(cols).fill(0));
31
30
  }
@@ -34,10 +33,14 @@ export function ones(rows, cols) {
34
33
  return Array.from({length:rows},()=>Array(cols).fill(1));
35
34
  }
36
35
 
37
- export function randomMatrix(rows, cols, scale=0.1){
38
- return Array.from({length:rows},()=>
39
- Array.from({length:cols},()=> (Math.random()*2-1)*scale)
40
- );
36
+ export function randomMatrix(rows, cols, scale=null){
37
+ // Auto-scale based on layer size (Xavier init)
38
+ if (scale === null){
39
+ scale = Math.sqrt(2.0 / (rows + cols));
40
+ }
41
+
42
+ return Array.from({length: rows}, () =>
43
+ Array.from({length: cols}, () => (Math.random() * 2 - 1) * scale));
41
44
  }
42
45
 
43
46
  export function transpose(matrix){
@@ -71,7 +74,7 @@ export function crossEntropy(pred,target){
71
74
  return -target.reduce((sum,t,i)=>sum+t*Math.log(pred[i]+eps),0);
72
75
  }
73
76
 
74
- // ---------------------- USERS FRIENDLY UTILS ----------------
77
+ // ---------------------- USER-FRIENDLY UTILS (USE THESE!) ----------------
75
78
  export function fu_tensor(data, requiresGrad = false) {
76
79
  if (!Array.isArray(data) || !Array.isArray(data[0])) {
77
80
  throw new Error("fu_tensor: Data must be 2D array");
@@ -242,35 +245,75 @@ export class Tensor {
242
245
 
243
246
  // ---------------------- Layers ----------------------
244
247
  export class Linear {
245
- constructor(inputDim,outputDim){
246
- this.W=randomMatrix(inputDim,outputDim);
247
- this.b=Array(outputDim).fill(0);
248
- this.gradW=zeros(inputDim,outputDim);
249
- this.gradb=Array(outputDim).fill(0);
250
- this.x=null;
248
+ constructor(inputDim, outputDim){
249
+ this.W = randomMatrix(inputDim, outputDim);
250
+ this.b = Array(outputDim).fill(0);
251
+ this.gradW = zeros(inputDim, outputDim);
252
+ this.gradb = Array(outputDim).fill(0);
253
+ this.x = null;
254
+ this.originalShape = null; // Track input shape
251
255
  }
252
256
 
253
257
  forward(x){
254
- this.x=x;
255
- const out=dot(x,this.W);
256
- return out.map((row,i)=>row.map((v,j)=>v+this.b[j]));
258
+ // Handle both [batch, features] and [batch, 1, features]
259
+ this.originalShape = this._getShapeType(x);
260
+
261
+ if (this.originalShape === '3d') {
262
+ // Convert from [batch, 1, features] to [batch, features]
263
+ this.x = x.map(sample => sample[0]);
264
+ } else {
265
+ // Already in [batch, features] format
266
+ this.x = x;
267
+ }
268
+
269
+ const out = dot(this.x, this.W);
270
+ return out.map((row, i) => row.map((v, j) => v + this.b[j]));
257
271
  }
258
272
 
259
273
  backward(grad){
260
- for(let i=0;i<this.W.length;i++) for(let j=0;j<this.W[0].length;j++)
261
- this.gradW[i][j]=this.x.reduce((sum,row,k)=>sum+row[i]*grad[k][j],0);
262
- for(let j=0;j<this.b.length;j++)
263
- this.gradb[j]=grad.reduce((sum,row)=>sum+row[j],0);
264
-
265
- const gradInput=zeros(this.x.length,this.W.length);
266
- for(let i=0;i<this.x.length;i++)
267
- for(let j=0;j<this.W.length;j++)
268
- for(let k=0;k<this.W[0].length;k++)
269
- gradInput[i][j]+=grad[i][k]*this.W[j][k];
274
+ // Compute gradients
275
+ for(let i = 0; i < this.W.length; i++) {
276
+ for(let j = 0; j < this.W[0].length; j++) {
277
+ this.gradW[i][j] = this.x.reduce((sum, row, k) => sum + row[i] * grad[k][j], 0);
278
+ }
279
+ }
280
+
281
+ for(let j = 0; j < this.b.length; j++) {
282
+ this.gradb[j] = grad.reduce((sum, row) => sum + row[j], 0);
283
+ }
284
+
285
+ const gradInput = zeros(this.x.length, this.W.length);
286
+ for(let i = 0; i < this.x.length; i++) {
287
+ for(let j = 0; j < this.W.length; j++) {
288
+ for(let k = 0; k < this.W[0].length; k++) {
289
+ gradInput[i][j] += grad[i][k] * this.W[j][k];
290
+ }
291
+ }
292
+ }
293
+
294
+ // Convert back to original shape if needed
295
+ if (this.originalShape === '3d') {
296
+ return gradInput.map(row => [row]); // Back to [batch, 1, features]
297
+ }
270
298
  return gradInput;
271
299
  }
272
300
 
273
- parameters(){ return [ {param:this.W,grad:this.gradW}, {param:[this.b],grad:[this.gradb]} ]; }
301
+ _getShapeType(x) {
302
+ if (Array.isArray(x[0]) && Array.isArray(x[0][0]) && !Array.isArray(x[0][0][0])) {
303
+ return '3d'; // [batch, 1, features]
304
+ } else if (Array.isArray(x[0]) && !Array.isArray(x[0][0])) {
305
+ return '2d'; // [batch, features]
306
+ } else {
307
+ throw new Error(`Unsupported input shape for Linear layer`);
308
+ }
309
+ }
310
+
311
+ parameters(){
312
+ return [
313
+ {param: this.W, grad: this.gradW},
314
+ {param: [this.b], grad: [this.gradb]}
315
+ ];
316
+ }
274
317
  }
275
318
 
276
319
  export class Flatten {
@@ -509,38 +552,187 @@ export class Sequential {
509
552
 
510
553
  // ---------------------- Activations ----------------------
511
554
  export class ReLU{
512
- constructor(){ this.out=null; }
555
+ constructor(){ this.mask = null; this.originalShape = null; }
513
556
 
514
557
  forward(x){
515
- // Handle both [batch, features] and [batch, channels, height, width]
516
- if (Array.isArray(x[0]) && Array.isArray(x[0][0]) && Array.isArray(x[0][0][0])) {
517
- // [batch, channels, height, width]
518
- this.out = x.map(batch =>
519
- batch.flatMap(channel =>
520
- channel.flatMap(row =>
521
- row.map(v => Math.max(0, v))
522
- )
523
- )
524
- );
558
+ this.originalShape = this._getShapeType(x);
559
+
560
+ if (this.originalShape === '3d') {
561
+ // Handle [batch, 1, features]
562
+ this.mask = x.map(sample => sample[0].map(v => v > 0));
563
+ return x.map(sample => [sample[0].map(v => Math.max(0, v))]);
525
564
  } else {
526
- // [batch, features] - existing behavior
527
- this.out = x.map(r => r.map(v => Math.max(0, v)));
565
+ // Handle [batch, features]
566
+ this.mask = x.map(row => row.map(v => v > 0));
567
+ return x.map(row => row.map(v => Math.max(0, v)));
528
568
  }
529
- return this.out;
530
569
  }
531
570
 
532
571
  backward(grad){
533
- // Gradient shape must match forward output shape
534
- if (Array.isArray(grad[0]) && Array.isArray(grad[0][0])) {
535
- // Standard [batch, features]
536
- return grad.map((r, i) => r.map((v, j) => v * (this.out[i][j] > 0 ? 1 : 0)));
572
+ if (this.originalShape === '3d') {
573
+ return grad.map((sample, i) =>
574
+ [sample[0].map((v, j) => this.mask[i][j] ? v : 0)]
575
+ );
537
576
  } else {
538
- // return as is
539
- return grad;
577
+ return grad.map((row, i) =>
578
+ row.map((v, j) => this.mask[i][j] ? v : 0)
579
+ );
540
580
  }
541
- }
581
+ }
582
+
583
+ _getShapeType(x) {
584
+ if (Array.isArray(x[0]) && Array.isArray(x[0][0]) && !Array.isArray(x[0][0][0])) {
585
+ return '3d';
586
+ } else if (Array.isArray(x[0]) && !Array.isArray(x[0][0])) {
587
+ return '2d';
588
+ } else {
589
+ throw new Error(`Unsupported input shape for ReLU`);
590
+ }
591
+ }
592
+ }
593
+
594
+ // ---------------------- Softmax ----------------------
595
+ export class Softmax {
596
+ constructor(dim = -1) {
597
+ this.dim = dim;
598
+ this.output = null;
599
+ this.input = null;
600
+ }
601
+
602
+ forward(x) {
603
+ this.input = x;
604
+
605
+ // x: [batch_size, num_classes]
606
+ this.output = x.map(row => {
607
+ const maxVal = Math.max(...row);
608
+ const exps = row.map(v => Math.exp(v - maxVal));
609
+ const sumExps = exps.reduce((a, b) => a + b, 0);
610
+ return exps.map(v => v / sumExps);
611
+ });
612
+ return this.output;
613
+ }
614
+
615
+ backward(grad) {
616
+ // grad: [batch_size, num_classes] - gradient from next layer
617
+ const batchSize = grad.length;
618
+ const numClasses = grad[0].length;
619
+
620
+ const gradInput = zeros(batchSize, numClasses);
621
+
622
+ for (let i = 0; i < batchSize; i++) {
623
+ const s = this.output[i]; // Softmax output for this sample
624
+ const gradOut = grad[i]; // Gradient from loss
625
+
626
+ // Compute Jacobian matrix: J_ij = s_i * (δ_ij - s_j)
627
+ for (let j = 0; j < numClasses; j++) {
628
+ let sum = 0;
629
+ for (let k = 0; k < numClasses; k++) {
630
+ // J[j][k] = s[j] * ((j === k ? 1 : 0) - s[k])
631
+ const jacobian = s[j] * ((j === k ? 1 : 0) - s[k]);
632
+ sum += jacobian * gradOut[k];
633
+ }
634
+ gradInput[i][j] = sum;
635
+ }
636
+ }
637
+
638
+ return gradInput;
639
+ }
640
+
641
+ parameters() {
642
+ return []; // Softmax has no trainable parameters
643
+ }
542
644
  }
543
645
 
646
+ // ---------------------- Tokenizer ----------------------
647
+ export class Tokenizer {
648
+ constructor(vocabSize = 2000){
649
+ this.vocabSize = vocabSize;
650
+ this.wordToIndex = new Map();
651
+ this.indexToWord = new Map();
652
+ this.fitted = false;
653
+ }
654
+
655
+ fit(texts){
656
+ const wordCounts = new Map();
657
+
658
+ // Count word frequencies from all texts
659
+ texts.forEach(text => {
660
+ const words = this._preprocess(text);
661
+ words.forEach(word => {
662
+ wordCounts.set(word, (wordCounts.get(word) || 0) + 1);
663
+ });
664
+ });
665
+
666
+ // Sort by frequency and take top words
667
+ const sortedWords = [...wordCounts.entries()]
668
+ .sort((a, b) => b[1] - a[1]) // Sort by frequency, descending
669
+ .slice(0, this.vocabSize - 1); // Reserve 1 slot for <UNK>
670
+
671
+ // Build vocabulary
672
+ this.wordToIndex.clear();
673
+ this.indexToWord.clear();
674
+
675
+ // Add unk token
676
+ this.wordToIndex.set('<UNK>', 0);
677
+ this.indexToWord.set(0, '<UNK>');
678
+
679
+ // Add most frequent words
680
+ sortedWords.forEach(([word], index) =>{
681
+ this.wordToIndex.set(word, index + 1);
682
+ this.indexToWord.set(index + 1, word);
683
+ })
684
+
685
+ this.fitted = true;
686
+ return this;
687
+ }
688
+
689
+ tokenize(text){
690
+ if (!this.fitted) throw new Error("Tokenizer not fitted. Call fit() first.");
691
+
692
+ const words = this._preprocess(text);
693
+ return words.map(word => this.wordToIndex.get(word) || 0);
694
+ }
695
+
696
+ tokenizeBatch(texts, maxLength=null){
697
+ if (!this.fitted) throw new Error("Tokenizer not fitted. Call fit() first.");
698
+
699
+ return texts.map(text => {
700
+ const tokens = this.tokenize(text);
701
+
702
+ if (maxLength !== null){
703
+ // Pad or truncate to maxLength
704
+ if (tokens.length > maxLength){
705
+ return tokens.slice(0, maxLength);
706
+ } else {
707
+ return [...tokens, ...Array(maxLength - tokens.length).fill(0)];
708
+ }
709
+ }
710
+
711
+ return tokens;
712
+ })
713
+ }
714
+
715
+ detokenize(tokens){
716
+ return tokens.map(token => this.indexToWord.get(token) || '<UNK>').join(' ');
717
+ }
718
+
719
+ detokenizeBatch(tokenBatches){
720
+ return tokenBatches.map(tokens => this.detokenize(tokens));
721
+ }
722
+
723
+ getVocabSize(){
724
+ return this.wordToIndex.size;
725
+ }
726
+
727
+ _preprocess(text) {
728
+ return text.toLowerCase()
729
+ .replace(/[^\w\s]/g, ' ') // Remove punctuation
730
+ .split(/\s+/) // Split by whitespace
731
+ .filter(word => word.length > 0); // Remove empty strings
732
+ }
733
+ }
734
+
735
+ // I'm too lazy to break lines here, so everything stays in one line
544
736
  export class Sigmoid{ constructor(){ this.out=null; } forward(x){ const fn=v=>1/(1+Math.exp(-v)); this.out=x.map(r=>r.map(fn)); return this.out; } backward(grad){ return grad.map((r,i)=>r.map((v,j)=>v*this.out[i][j]*(1-this.out[i][j]))); } }
545
737
  export class Tanh{ constructor(){ this.out=null; } forward(x){ this.out=x.map(r=>r.map(v=>Math.tanh(v))); return this.out; } backward(grad){ return grad.map((r,i)=>r.map((v,j)=>v*(1-this.out[i][j]**2))); } }
546
738
  export class LeakyReLU{ constructor(alpha=0.01){ this.alpha=alpha; this.out=null; } forward(x){ this.out=x.map(r=>r.map(v=>v>0?v:v*this.alpha)); return this.out; } backward(grad){ return grad.map((r,i)=>r.map((v,j)=>v*(this.out[i][j]>0?1:this.alpha))); } }
@@ -555,24 +747,236 @@ export class CrossEntropyLoss{ forward(pred,target){ this.pred=pred; this.target
555
747
 
556
748
  // ---------------------- Optimizers ----------------------
557
749
  export class Adam{
558
- constructor(params,lr=0.001,b1=0.9,b2=0.999,eps=1e-8){
559
- this.params=params; this.lr=lr; this.beta1=b1; this.beta2=b2; this.eps=eps;
560
- this.m=params.map(p=>zeros(p.param.length,p.param[0].length||1));
561
- this.v=params.map(p=>zeros(p.param.length,p.param[0].length||1));
562
- this.t=0;
750
+ constructor(params, lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-8, max_grad_norm = 1.0){
751
+ // Handle both parameter styles: (params, lr) OR (params, {lr, ...})
752
+ if (typeof lr === 'object') {
753
+ // Options object provided
754
+ const options = lr;
755
+ this.lr = options.lr || 0.001;
756
+ this.beta1 = options.b1 || options.beta1 || 0.9;
757
+ this.beta2 = options.b2 || options.beta2 || 0.999;
758
+ this.eps = options.eps || 1e-8;
759
+ this.max_grad_norm = options.max_grad_norm || 1.0;
760
+ } else {
761
+ // Individual parameters provided
762
+ this.lr = lr;
763
+ this.beta1 = b1;
764
+ this.beta2 = b2;
765
+ this.eps = eps;
766
+ this.max_grad_norm = max_grad_norm;
767
+ }
768
+
769
+ this.params = params;
770
+ this.m = params.map(p => zeros(p.param.length, p.param[0].length || 1));
771
+ this.v = params.map(p => zeros(p.param.length, p.param[0].length || 1));
772
+ this.t = 0;
773
+
563
774
  }
775
+
564
776
  step(){
565
777
  this.t++;
566
- this.params.forEach((p,idx)=>{
567
- for(let i=0;i<p.param.length;i++)
568
- for(let j=0;j<(p.param[0].length||1);j++){
569
- const g=p.grad[i][j];
570
- this.m[idx][i][j]=this.beta1*this.m[idx][i][j]+(1-this.beta1)*g;
571
- this.v[idx][i][j]=this.beta2*this.v[idx][i][j]+(1-this.beta2)*g*g;
572
- const mHat=this.m[idx][i][j]/(1-Math.pow(this.beta1,this.t));
573
- const vHat=this.v[idx][i][j]/(1-Math.pow(this.beta2,this.t));
574
- p.param[i][j]-=this.lr*mHat/(Math.sqrt(vHat)+this.eps);
778
+ this.params.forEach((p, idx) => {
779
+ // Calculate gradient norm for clipping
780
+ let grad_norm_sq = 0;
781
+ for (let i = 0; i < p.param.length; i++){
782
+ for (let j = 0; j < (p.param[0].length || 1); j++){
783
+ const grad_val = p.grad[i] && p.grad[i][j] !== undefined ? p.grad[i][j] : 0;
784
+ grad_norm_sq += grad_val * grad_val;
785
+ }
786
+ }
787
+
788
+ const grad_norm = Math.sqrt(grad_norm_sq);
789
+ const clip_scale = grad_norm > this.max_grad_norm ? this.max_grad_norm / grad_norm : 1.0;
790
+
791
+ // Update with clipped gradients
792
+ for (let i = 0; i < p.param.length; i++){
793
+ for(let j = 0; j < (p.param[0].length || 1); j++){
794
+ if (p.grad[i] && p.grad[i][j] !== undefined){
795
+ const g = p.grad[i][j] * clip_scale;
796
+ this.m[idx][i][j] = this.beta1 * this.m[idx][i][j] + (1 - this.beta1) * g;
797
+ this.v[idx][i][j] = this.beta2 * this.v[idx][i][j] + (1 - this.beta2) * g * g;
798
+ const mHat = this.m[idx][i][j] / (1 - Math.pow(this.beta1, this.t));
799
+ const vHat = this.v[idx][i][j] / (1 - Math.pow(this.beta2, this.t));
800
+ p.param[i][j] -= this.lr * mHat / (Math.sqrt(vHat) + this.eps);
801
+ }
802
+ }
803
+ }
804
+ });
805
+ }
806
+ }
807
+
808
+ // ---------------------- AdamW Optimizer ----------------------
809
+ export class AdamW {
810
+ constructor(params, options = {}) {
811
+ const {
812
+ lr = 0.001,
813
+ beta1 = 0.9,
814
+ beta2 = 0.999,
815
+ eps = 1e-8,
816
+ weight_decay = 0.01,
817
+ max_grad_norm = 1.0
818
+ } = options;
819
+
820
+ this.params = params;
821
+ this.lr = lr;
822
+ this.beta1 = beta1;
823
+ this.beta2 = beta2;
824
+ this.eps = eps;
825
+ this.weight_decay = weight_decay;
826
+ this.max_grad_norm = max_grad_norm;
827
+
828
+ this.m = params.map(p => zeros(p.param.length, p.param[0].length || 1));
829
+ this.v = params.map(p => zeros(p.param.length, p.param[0].length || 1));
830
+ this.t = 0;
831
+ }
832
+
833
+ step() {
834
+ this.t++;
835
+ this.params.forEach((p, idx) => {
836
+ // Gradient clipping (same as Adam)
837
+ let grad_norm_sq = 0;
838
+ for (let i = 0; i < p.param.length; i++) {
839
+ for (let j = 0; j < (p.param[0].length || 1); j++) {
840
+ const grad_val = p.grad[i] && p.grad[i][j] !== undefined ? p.grad[i][j] : 0;
841
+ grad_norm_sq += grad_val * grad_val;
842
+ }
843
+ }
844
+ const grad_norm = Math.sqrt(grad_norm_sq);
845
+ const clip_scale = grad_norm > this.max_grad_norm ? this.max_grad_norm / grad_norm : 1.0;
846
+
847
+ // AdamW update: weight decay applied separately
848
+ for (let i = 0; i < p.param.length; i++) {
849
+ for (let j = 0; j < (p.param[0].length || 1); j++) {
850
+ if (p.grad[i] && p.grad[i][j] !== undefined) {
851
+ const g = p.grad[i][j] * clip_scale;
852
+
853
+ // Adam moments
854
+ this.m[idx][i][j] = this.beta1 * this.m[idx][i][j] + (1 - this.beta1) * g;
855
+ this.v[idx][i][j] = this.beta2 * this.v[idx][i][j] + (1 - this.beta2) * g * g;
856
+
857
+ const mHat = this.m[idx][i][j] / (1 - Math.pow(this.beta1, this.t));
858
+ const vHat = this.v[idx][i][j] / (1 - Math.pow(this.beta2, this.t));
859
+
860
+ // AdamW key difference: weight decay applied to weights, not gradients
861
+ p.param[i][j] -= this.lr * (
862
+ mHat / (Math.sqrt(vHat) + this.eps) +
863
+ this.weight_decay * p.param[i][j] // Decoupled weight decay
864
+ );
865
+ }
866
+ }
867
+ }
868
+ });
869
+ }
870
+ }
871
+
872
+ export class SGD{
873
+ constructor(params, lr = 0.01, max_grad_norm = 1.0) {
874
+ this.params = params;
875
+ this.lr = lr;
876
+ this.max_grad_norm = max_grad_norm; // Gradient Clipping
877
+ }
878
+
879
+ step() {
880
+ this.params.forEach(p => {
881
+ // Calculate gradient norm
882
+ let grad_norm_sq = 0;
883
+ let total_params = 0;
884
+
885
+ for (let i = 0; i < p.param.length; i++){
886
+ const row = p.param[i];
887
+ for (let j = 0; j < (row.length || 1); j++) {
888
+ const grad_val = p.grad[i] && p.grad[i][j] !== undefined ? p.grad[i][j] : 0;
889
+ grad_norm_sq += grad_val * grad_val;
890
+ total_params++;
891
+ }
892
+ }
893
+
894
+ const grad_norm = Math.sqrt(grad_norm_sq);
895
+
896
+ // Apply gradient clipping if needed
897
+ const clip_scale = grad_norm > this.max_grad_norm ? this.max_grad_norm / grad_norm : 1.0;
898
+
899
+ // Update parameters with clipped gradients
900
+ for (let i = 0; i < p.param.length; i++){
901
+ const row = p.param[i];
902
+ for (let j = 0; j < (row.length || 1); j++) {
903
+ if (p.grad[i] && p.grad[i][j] !== undefined){
904
+ p.param[i][j] -= this.lr * (p.grad[i][j] * clip_scale);
905
+ }
906
+ }
907
+ }
908
+ });
909
+ }
910
+ }
911
+
912
+
913
+ export class LION {
914
+ constructor(params, options = {}) {
915
+ this.params = params;
916
+
917
+ const {
918
+ lr = 0.0001, // Lion typically uses a smaller LR
919
+ beta1 = 0.9, // First moment decay
920
+ beta2 = 0.99, // Second moment decay
921
+ weight_decay = 0, // L2 regularization
922
+ eps = 1e-8 // Numerical stability
923
+ } = options;
924
+
925
+ this.lr = lr;
926
+ this.beta1 = beta1;
927
+ this.beta2 = beta2;
928
+ this.weight_decay = weight_decay;
929
+ this.eps = eps;
930
+
931
+ // Initialize momentums
932
+ this.m = params.map(p => zeros(p.param.length, p.param[0].length || 1));
933
+ this.t = 0;
934
+ }
935
+
936
+ step() {
937
+ this.t++;
938
+
939
+ this.params.forEach((p, idx) => {
940
+ for (let i = 0; i < p.param.length; i++) {
941
+ for (let j = 0; j < (p.param[0].length || 1); j++) {
942
+ if (p.grad[i] && p.grad[i][j] !== undefined) {
943
+ const grad = p.grad[i][j];
944
+
945
+ // Update momentum: m_t = β1 * m_{t-1} + (1 - β1) * g_t
946
+ this.m[idx][i][j] = this.beta1 * this.m[idx][i][j] + (1 - this.beta1) * grad;
947
+
948
+ // Lion update: param = param - η * sign(m_t + β2 * g_t)
949
+ const update_term = this.m[idx][i][j] + this.beta2 * grad;
950
+
951
+ // Get sign with epsilon for stability
952
+ let sign_val;
953
+ if (update_term > this.eps) sign_val = 1;
954
+ else if (update_term < -this.eps) sign_val = -1;
955
+ else sign_val = 0;
956
+
957
+ let update = sign_val * this.lr;
958
+
959
+ // Add weight decay if specified
960
+ if (this.weight_decay > 0) {
961
+ update += this.weight_decay * this.lr * p.param[i][j];
962
+ }
963
+
964
+ p.param[i][j] -= update;
965
+ }
966
+ }
967
+ }
968
+ });
969
+ }
970
+
971
+ zeroGrad() {
972
+ this.params.forEach(p => {
973
+ if (p.grad) {
974
+ for (let i = 0; i < p.grad.length; i++) {
975
+ for (let j = 0; j < p.grad[i].length; j++) {
976
+ p.grad[i][j] = 0;
977
+ }
575
978
  }
979
+ }
576
980
  });
577
981
  }
578
982
  }
@@ -619,6 +1023,89 @@ export class LambdaLR {
619
1023
  }
620
1024
  }
621
1025
 
1026
+ // ---------------------- ReduceLROnPlateau Scheduler ----------------------
1027
+ export class ReduceLROnPlateau {
1028
+ constructor(optimizer, options = {}) {
1029
+ this.optimizer = optimizer;
1030
+
1031
+ // Destructure with defaults
1032
+ const {
1033
+ patience = 10,
1034
+ factor = 0.5,
1035
+ min_lr = 1e-6,
1036
+ threshold = 1e-4,
1037
+ cooldown = 0,
1038
+ verbose = false
1039
+ } = options;
1040
+
1041
+ this.patience = patience;
1042
+ this.factor = factor;
1043
+ this.min_lr = min_lr;
1044
+ this.threshold = threshold;
1045
+ this.cooldown = cooldown;
1046
+ this.verbose = verbose;
1047
+
1048
+ // State tracking
1049
+ this.bestLoss = Infinity;
1050
+ this.wait = 0;
1051
+ this.cooldown_counter = 0;
1052
+ this.num_reductions = 0;
1053
+ }
1054
+
1055
+ step(loss) {
1056
+ // Handle cooldown
1057
+ if (this.cooldown_counter > 0) {
1058
+ this.cooldown_counter--;
1059
+ return;
1060
+ }
1061
+
1062
+ // Check if this is significant improvement (relative threshold)
1063
+ const improvement_needed = this.bestLoss * (1 - this.threshold);
1064
+ const is_better = loss < improvement_needed;
1065
+
1066
+ if (is_better) {
1067
+ // Significant improvement - reset
1068
+ this.bestLoss = loss;
1069
+ this.wait = 0;
1070
+ } else {
1071
+ // No significant improvement
1072
+ this.wait += 1;
1073
+ }
1074
+
1075
+ // Check if we've waited long enough
1076
+ if (this.wait >= this.patience) {
1077
+ this._reduce_lr();
1078
+ this.cooldown_counter = this.cooldown;
1079
+ this.wait = 0;
1080
+ }
1081
+ }
1082
+
1083
+ _reduce_lr() {
1084
+ const old_lr = this.optimizer.lr;
1085
+ const new_lr = Math.max(old_lr * this.factor, this.min_lr);
1086
+
1087
+ if (new_lr < old_lr) {
1088
+ this.optimizer.lr = new_lr;
1089
+ this.num_reductions++;
1090
+
1091
+ if (this.verbose) {
1092
+ console.log(`ReduceLROnPlateau: reducing LR from ${old_lr} to ${new_lr}`);
1093
+ }
1094
+ }
1095
+ }
1096
+
1097
+ get_last_lr() {
1098
+ return this.optimizer.lr;
1099
+ }
1100
+
1101
+ reset() {
1102
+ this.bestLoss = Infinity;
1103
+ this.wait = 0;
1104
+ this.cooldown_counter = 0;
1105
+ this.num_reductions = 0;
1106
+ }
1107
+ }
1108
+
622
1109
  // ---------------------- ELU Activation ----------------------
623
1110
  export class ELU {
624
1111
  constructor(alpha=1.0) {
@@ -708,7 +1195,6 @@ export class SiLU {
708
1195
  }
709
1196
  }
710
1197
 
711
- export class SGD{ constructor(params,lr=0.01){ this.params=params; this.lr=lr; } step(){ this.params.forEach(p=>{ for(let i=0;i<p.param.length;i++) for(let j=0;j<(p.param[0].length||1);j++) p.param[i][j]-=this.lr*p.grad[i][j]; }); } }
712
1198
 
713
1199
  // ---------------------- BatchNorm2D ----------------------
714
1200
  export class BatchNorm2d {
@@ -1,23 +0,0 @@
1
- // Example: Test learning rate schedulers (StepLR and LambdaLR) with mini-jstorch optimizers
2
-
3
- import { SGD, StepLR, LambdaLR, Tensor } from "../src/jstorch.js";
4
-
5
- const param = { param: [[1, 2], [3, 4]], grad: [[0, 0], [0, 0]] };
6
- const optimizer = new SGD([param], 0.1);
7
-
8
- // --- Test StepLR ---
9
- console.log("Testing StepLR...");
10
- const stepScheduler = new StepLR(optimizer, 3, 0.5);
11
- for (let epoch = 1; epoch <= 10; epoch++) {
12
- stepScheduler.step();
13
- console.log(`Epoch ${epoch}: LR = ${optimizer.lr.toFixed(4)}`);
14
- }
15
-
16
- // --- Test LambdaLR ---
17
- console.log("\nTesting LambdaLR...");
18
- optimizer.lr = 0.1; // Reset LR
19
- const lambdaScheduler = new LambdaLR(optimizer, epoch => 1.0 / (1 + epoch));
20
- for (let epoch = 1; epoch <= 5; epoch++) {
21
- lambdaScheduler.step();
22
- console.log(`Epoch ${epoch}: LR = ${optimizer.lr.toFixed(4)}`);
23
- }
File without changes