mini-jstorch 1.7.0 → 1.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/Docs/About.md CHANGED
@@ -4,19 +4,19 @@
4
4
 
5
5
  ## General Information
6
6
 
7
- - **Project Name:** Mini-JSTorch
8
- - **Internal Name:** JST (JSTorch)
7
+ - **Project Name:** mini-jstorch
8
+ - **Internal Name:** JST (JST-orch)
9
9
 
10
10
  > Note:
11
- > Early versions of Mini-JSTorch do not strictly follow semantic versioning conventions
11
+ > Early versions of JST do not strictly follow semantic versioning conventions
12
12
  > (e.g. `0.0.1` for patches, `0.1.0` for minor releases, `1.0.0` for major releases).
13
13
  > This inconsistency reflects the early learning and experimental phase of the project.
14
14
 
15
15
  ---
16
16
 
17
- ## 1. Engine Architecture Limitations (JSTorch Core)
17
+ ## 1. Engine Architecture Limitations (JST Core)
18
18
 
19
- This section outlines the known structural weaknesses of the JSTorch engine.
19
+ This section outlines the known structural weaknesses of the JST engine.
20
20
  Although the architecture may appear complex, it is currently sensitive and tightly coupled.
21
21
 
22
22
  ### Identified Limitations
@@ -68,7 +68,7 @@ In short, `fu_` exists to ensure safety, clarity, and consistency for end users
68
68
 
69
69
  ## 3. SJK (Shortcut JST Keywords) Reference
70
70
 
71
- This section lists commonly used abbreviations and keywords within the Mini-JSTorch ecosystem.
71
+ This section lists commonly used abbreviations and keywords within the mini-jstorch ecosystem.
72
72
 
73
73
  **Format:** `"KEYWORD" : "Full Name / Meaning"`
74
74
 
package/README.md CHANGED
@@ -1,101 +1,367 @@
1
- # Mini-JSTorch
1
+ ## Mini-JSTorch
2
2
 
3
- A lightweight JavaScript neural network library for rapid frontend AI experimentation on low-resource devices, inspired by PyTorch.
4
3
 
5
- ## Overview
4
+ Mini-JSTorch is a lightweight, `dependency-free` JavaScript neural network library designed for `education`, `experimentation`, and `small-scale models`.
5
+ It runs in Node.js and modern browsers, with a simple API inspired by PyTorch-style workflows.
6
6
 
7
- Mini-JSTorch is a high-performance, minimalist JavaScript library for building neural networks. It runs efficiently in Frontend environments, including low-end devices. The library enables quick experimentation and learning in AI without compromising stability, accuracy, or training reliability.
7
+ This project prioritizes `clarity`, `numerical correctness`, and `accessibility` over performance or large-scale production use.
8
8
 
9
- This release, **version 1.7.0:** We introduce **Softmax Layer**, **Tokenizer**, **AdamW Optimizer**, and enhanced NLP capabilities.
9
+ In version `1.8.0`, we introduce **SoftmaxCrossEntropyLoss** and **BCEWithLogitsLoss**.
10
10
 
11
11
  ---
12
12
 
13
- ## New Features Highlights
13
+ # Overview
14
14
 
15
- - **Softmax Layer:** Professional classification output with proper gradient computation
16
- - **Tokenizer:** Lightweight text preprocessing for NLP tasks
17
- - **AdamW Optimizer:** Modern optimizer with decoupled weight decay
15
+ **Mini-JSTorch provides a minimal neural network engine implemented entirely in plain JavaScript.**
16
+
17
+ *It is intended for:*
18
+
19
+ - learning how neural networks work internally
20
+ - experimenting with small models
21
+ - running simple training loops in the browser
22
+ - environments where large frameworks are unnecessary or unavailable
23
+
24
+ `Mini-JSTorch is NOT a replacement for PyTorch, TensorFlow, or TensorFlow.js.`
25
+
26
+ `It is intentionally scoped to remain small, readable, and easy to debug.`
18
27
 
19
28
  ---
20
29
 
21
- ## Core Features
30
+ # Key Characteristics
22
31
 
23
- - **Layers:** Linear, Flatten, Conv2D
24
- - **Activations:** ReLU, Sigmoid, Tanh, LeakyReLU, GELU, Mish, SiLU, ELU
25
- - **Loss Functions:** MSELoss, CrossEntropyLoss
26
- - **Optimizers:** Adam, SGD, LION, **AdamW**
27
- - **Schedulers:** StepLR, LambdaLR, ReduceLROnPlateau
28
- - **Regularization:** Dropout, BatchNorm2D
29
- - **Utilities:** zeros, randomMatrix, softmax, crossEntropy, dot, addMatrices, reshape, stack, flatten, eye, concat
30
- - **Model Container:** Sequential (for stacking layers with forward/backward passes)
32
+ - Zero dependencies
33
+ - ESM-first (`type: module`)
34
+ - Works in Node.js and browser environments
35
+ - Explicit, manual forward and backward passes
36
+ - Focused on 2D training logic (`[batch][features]`)
37
+ - Designed for educational and experimental use
31
38
 
32
- # Others
39
+ ---
33
40
 
34
- - **Tokenizer**
35
- - **Softmax Layer**
41
+ # Browser Support
42
+
43
+ Mini-JSTorch can now be used directly in browsers:
44
+
45
+ - via ESM imports
46
+ - via CDN / `<script>` with a global `JST` object
47
+
48
+ This makes it suitable for:
49
+
50
+ - demos
51
+ - learning environments
52
+ - lightweight frontend experiments
53
+
54
+ Here is example code that builds and trains a simple model with Mini-JSTorch
55
+ in a browser page:
56
+
57
+ ```html
58
+ <!DOCTYPE html>
59
+ <html>
60
+ <body>
61
+ <div id="output">
62
+ <p>Status: <span id="status">Initializing...</span></p>
63
+ <div id="training-log"></div>
64
+ <div id="results" style="margin-top: 20px;"></div>
65
+ </div>
66
+
67
+ <script type="module">
68
+ import { Sequential, Linear, ReLU, MSELoss, Adam, StepLR, Tanh } from 'https://unpkg.com/mini-jstorch';
69
+
70
+ const statusEl = document.getElementById('status');
71
+ const trainingLogEl = document.getElementById('training-log');
72
+ const resultsEl = document.getElementById('results');
73
+
74
+ async function trainModel() {
75
+ try {
76
+ statusEl.textContent = 'Creating model...';
77
+
78
+ const model = new Sequential([
79
+ new Linear(2, 16),
80
+ new Tanh(),
81
+ new Linear(16, 8),
82
+ new ReLU(),
83
+ new Linear(8, 1)
84
+ ]);
85
+
86
+ const X = [[0,0], [0,1], [1,0], [1,1]];
87
+ const y = [[0], [1], [1], [0]];
88
+
89
+ const criterion = new MSELoss();
90
+ const optimizer = new Adam(model.parameters(), 0.1);
91
+ const scheduler = new StepLR(optimizer, 25, 0.5);
92
+
93
+ trainingLogEl.innerHTML = '<h4>Training Progress:</h4>';
94
+ const logList = document.createElement('ul');
95
+ trainingLogEl.appendChild(logList);
96
+
97
+ statusEl.textContent = 'Training...';
98
+
99
+ for (let epoch = 0; epoch < 1000; epoch++) {
100
+ const pred = model.forward(X);
101
+ const loss = criterion.forward(pred, y);
102
+ const grad = criterion.backward();
103
+ model.backward(grad);
104
+ optimizer.step();
105
+ scheduler.step();
106
+
107
+ if (epoch % 100 === 0) {
108
+ const logItem = document.createElement('li');
109
+ logItem.textContent = `Epoch ${epoch}: Loss = ${loss.toFixed(6)}`;
110
+ logList.appendChild(logItem);
111
+
112
+ // Update status every 100 epochs
113
+ statusEl.textContent = `Training... Epoch ${epoch}/1000 (Loss: ${loss.toFixed(6)})`;
114
+
115
+ await new Promise(resolve => setTimeout(resolve, 10));
116
+ }
117
+ }
118
+
119
+ statusEl.textContent = 'Training completed!';
120
+ statusEl.style.color = 'green';
121
+
122
+ resultsEl.innerHTML = '<h4>XOR Predictions:</h4>';
123
+ const resultsTable = document.createElement('table');
124
+ resultsTable.style.border = '1px solid #ccc';
125
+ resultsTable.style.borderCollapse = 'collapse';
126
+ resultsTable.style.width = '300px';
127
+
128
+ // Table header
129
+ const headerRow = document.createElement('tr');
130
+ ['Input A', 'Input B', 'Prediction', 'Target'].forEach(text => {
131
+ const th = document.createElement('th');
132
+ th.textContent = text;
133
+ th.style.border = '1px solid #ccc';
134
+ th.style.padding = '8px';
135
+ headerRow.appendChild(th);
136
+ });
137
+ resultsTable.appendChild(headerRow);
138
+
139
+ const predictions = model.forward(X);
140
+ predictions.forEach((pred, i) => {
141
+ const row = document.createElement('tr');
142
+
143
+ const cell1 = document.createElement('td');
144
+ cell1.textContent = X[i][0];
145
+ cell1.style.border = '1px solid #ccc';
146
+ cell1.style.padding = '8px';
147
+ cell1.style.textAlign = 'center';
148
+
149
+ const cell2 = document.createElement('td');
150
+ cell2.textContent = X[i][1];
151
+ cell2.style.border = '1px solid #ccc';
152
+ cell2.style.padding = '8px';
153
+ cell2.style.textAlign = 'center';
154
+
155
+ const cell3 = document.createElement('td');
156
+ cell3.textContent = pred[0].toFixed(4);
157
+ cell3.style.border = '1px solid #ccc';
158
+ cell3.style.padding = '8px';
159
+ cell3.style.textAlign = 'center';
160
+ cell3.style.fontWeight = 'bold';
161
+ cell3.style.color = Math.abs(pred[0] - y[i][0]) < 0.1 ? 'green' : 'red';
162
+
163
+ const cell4 = document.createElement('td');
164
+ cell4.textContent = y[i][0];
165
+ cell4.style.border = '1px solid #ccc';
166
+ cell4.style.padding = '8px';
167
+ cell4.style.textAlign = 'center';
168
+
169
+ row.appendChild(cell1);
170
+ row.appendChild(cell2);
171
+ row.appendChild(cell3);
172
+ row.appendChild(cell4);
173
+ resultsTable.appendChild(row);
174
+ });
175
+
176
+ resultsEl.appendChild(resultsTable);
177
+
178
+ const summary = document.createElement('div');
179
+ summary.style.marginTop = '20px';
180
+ summary.style.padding = '10px';
181
+ summary.style.backgroundColor = '#f0f0f0';
182
+ summary.style.borderRadius = '5px';
183
+ summary.innerHTML = `
184
+ <p><strong>Model Architecture:</strong> 2 → 16 → 8 → 1</p>
185
+ <p><strong>Activation:</strong> Tanh → ReLU</p>
186
+ <p><strong>Loss Function:</strong> MSE</p>
187
+ <p><strong>Optimizer:</strong> Adam (LR: 0.1)</p>
188
+ <p><strong>Epochs:</strong> 1000</p>
189
+ `;
190
+ resultsEl.appendChild(summary);
191
+
192
+ } catch (error) {
193
+ statusEl.textContent = `Error: ${error.message}`;
194
+ statusEl.style.color = 'red';
195
+ console.error(error);
196
+ }
197
+ }
198
+
199
+ trainModel();
200
+ </script>
201
+ </body>
202
+ </html>
203
+ ```
36
204
 
37
205
  ---
38
206
 
39
- ## Installation
207
+ # Core Features
208
+
209
+ # Layers
210
+
211
+ - Linear
212
+ - Flatten
213
+ - Conv2D (*experimental*)
214
+
215
+ # Activations
216
+
217
+ - ReLU
218
+ - Sigmoid
219
+ - Tanh
220
+ - LeakyReLU
221
+ - GELU
222
+ - Mish
223
+ - SiLU
224
+ - ELU
225
+
226
+ # Loss Functions
227
+
228
+ - MSELoss
229
+ - CrossEntropyLoss (*legacy*)
230
+ - SoftmaxCrossEntropyLoss (**recommended**)
231
+ - BCEWithLogitsLoss (**recommended**)
232
+
233
+ # Optimizers
234
+
235
+ - SGD
236
+ - Adam
237
+ - AdamW
238
+ - Lion
239
+
240
+ # Learning Rate Schedulers
241
+
242
+ - StepLR
243
+ - LambdaLR
244
+ - ReduceLROnPlateau
245
+
+ # Regularization
246
+ - Dropout (*basic*, *educational*)
247
+ - BatchNorm2D (*experimental*)
248
+
249
+ # Utilities
250
+
251
+ - zeros
252
+ - randomMatrix
253
+ - dot
254
+ - addMatrices
255
+ - reshape
256
+ - stack
257
+ - flatten
258
+ - concat
259
+ - softmax
260
+ - crossEntropy
261
+
262
+ # Model Container
263
+
264
+ - Sequential
265
+
266
+ ---
267
+
268
+ # Installation
40
269
 
41
270
  ```bash
42
271
  npm install mini-jstorch
43
- # Node.js v20+ recommended for best performance
44
272
  ```
273
+ Node.js v18+ or any modern browser with ES module support is recommended.
45
274
 
46
275
  ---
47
276
 
48
- ## Quick Start Example
277
+ # Quick Start (Recommended Loss)
278
+
279
+ # Multi-class Classification (SoftmaxCrossEntropy)
49
280
 
50
281
  ```javascript
51
- import { Sequential, Linear, ReLU, Sigmoid, CrossEntropyLoss, Adam, StepLR } from './src/jstorch.js';
282
+ import {
283
+ Sequential,
284
+ Linear,
285
+ ReLU,
286
+ SoftmaxCrossEntropyLoss,
287
+ Adam
288
+ } from "./src/jstorch.js";
52
289
 
53
- // Build model
54
290
  const model = new Sequential([
55
- new Linear(2,4),
291
+ new Linear(2, 4),
56
292
  new ReLU(),
57
- new Linear(4,2),
58
- new Sigmoid()
293
+ new Linear(4, 2) // logits output
59
294
  ]);
60
295
 
61
- // Sample XOR dataset
62
296
  const X = [
63
297
  [0,0], [0,1], [1,0], [1,1]
64
298
  ];
299
+
65
300
  const Y = [
66
301
  [1,0], [0,1], [0,1], [1,0]
67
302
  ];
68
303
 
69
- // Loss & optimizer
70
- const lossFn = new CrossEntropyLoss();
304
+ const lossFn = new SoftmaxCrossEntropyLoss();
71
305
  const optimizer = new Adam(model.parameters(), 0.1);
72
- const scheduler = new StepLR(optimizer, 20, 0.5); // Halve LR every 20 epochs
73
-
74
- // Training loop
75
- for (let epoch = 1; epoch <= 100; epoch++) {
76
- const pred = model.forward(X);
77
- const loss = lossFn.forward(pred, Y);
78
- const gradLoss = lossFn.backward();
79
- model.backward(gradLoss);
306
+
307
+ for (let epoch = 1; epoch <= 300; epoch++) {
308
+ const logits = model.forward(X);
309
+ const loss = lossFn.forward(logits, Y);
310
+ const grad = lossFn.backward();
311
+ model.backward(grad);
80
312
  optimizer.step();
81
- scheduler.step();
82
- if (epoch % 20 === 0) console.log(`Epoch ${epoch}, Loss: ${loss.toFixed(4)}, LR: ${optimizer.lr.toFixed(4)}`);
313
+
314
+ if (epoch % 50 === 0) {
315
+ console.log(`Epoch ${epoch}, Loss: ${loss.toFixed(4)}`);
316
+ }
83
317
  }
318
+ ```
319
+ Do not combine `SoftmaxCrossEntropyLoss` with a `Softmax` layer.
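Why this matters: `SoftmaxCrossEntropyLoss` already applies a numerically stable softmax internally, so the model should end in a plain `Linear` layer that outputs raw logits. At inference time you convert those logits to probabilities yourself. A minimal standalone sketch of that step (it mirrors the loss's internal softmax and is not part of the package API):

```javascript
// Stable softmax: subtract the row max before exponentiating so large
// logits cannot overflow Math.exp.
function softmaxRow(row) {
  const max = Math.max(...row);
  const exps = row.map(v => Math.exp(v - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map(v => v / sum);
}

// At inference time, apply softmax to the model's raw logits yourself.
const logits = [2.0, 1.0, 0.1];
const probs = softmaxRow(logits);
const predictedClass = probs.indexOf(Math.max(...probs));
console.log(predictedClass); // 0
```

Stacking a `Softmax` layer before this loss would apply softmax twice, which distorts the probabilities and flattens the gradients.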
320
+
321
+ # Binary Classification (BCEWithLogitsLoss)
84
322
 
85
- // Prediction
86
- const predTest = model.forward(X);
87
- predTest.forEach((p,i) => {
88
- const predictedClass = p.indexOf(Math.max(...p));
89
- console.log(`Input: ${X[i]}, Predicted class: ${predictedClass}, Raw output: ${p.map(v => v.toFixed(3))}`);
90
- });
323
+ ```javascript
324
+ import {
325
+ Sequential,
326
+ Linear,
327
+ ReLU,
328
+ BCEWithLogitsLoss,
329
+ Adam
330
+ } from "./src/jstorch.js";
331
+
332
+ const model = new Sequential([
333
+ new Linear(2, 4),
334
+ new ReLU(),
335
+ new Linear(4, 1) // logit
336
+ ]);
337
+
338
+ const X = [
339
+ [0,0], [0,1], [1,0], [1,1]
340
+ ];
341
+
342
+ const Y = [
343
+ [0], [1], [1], [0]
344
+ ];
345
+
346
+ const lossFn = new BCEWithLogitsLoss();
347
+ const optimizer = new Adam(model.parameters(), 0.1);
348
+
349
+ for (let epoch = 1; epoch <= 300; epoch++) {
350
+ const logits = model.forward(X);
351
+ const loss = lossFn.forward(logits, Y);
352
+ const grad = lossFn.backward();
353
+ model.backward(grad);
354
+ optimizer.step();
355
+ }
91
356
  ```
357
+ Do not combine `BCEWithLogitsLoss` with a `Sigmoid` layer.
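The reason is the same as for `SoftmaxCrossEntropyLoss`: `BCEWithLogitsLoss` applies the sigmoid internally during training, so the model outputs a raw logit. At inference time you squash the logit yourself. A minimal standalone sketch (not part of the package API):

```javascript
// Convert a raw logit to a probability at inference time.
function sigmoid(x) {
  return 1 / (1 + Math.exp(-x));
}

const logit = 1.5;                     // example raw model output
const probability = sigmoid(logit);    // in (0, 1)
const label = probability >= 0.5 ? 1 : 0;
console.log(label); // 1
```

Adding a `Sigmoid` layer before this loss would squash the output twice and defeat the numerically stable formulation inside the loss.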
92
358
 
93
359
  ---
94
360
 
95
- ## Save & Load Models
361
+ # Save & Load Models
96
362
 
97
363
  ```javascript
98
- import { saveModel, loadModel, Sequential } from '.jstorch.js';
364
+ import { saveModel, loadModel, Sequential } from "mini-jstorch";
99
365
 
100
366
  const json = saveModel(model);
101
367
  const model2 = new Sequential([...]); // same architecture
@@ -104,13 +370,12 @@ loadModel(model2, json);
104
370
 
105
371
  ---
106
372
 
107
- ## Demos & Testing
373
+ # Demos
108
374
 
109
- Check the `demo/` directory for ready-to-run demos:
110
- - **demo/MakeModel.js:** Build and run a simple neural network.
111
- - **demo/scheduler.js:** Experiment with learning rate schedulers.
112
- - **demo/fu_fun.js:** Test all user-friendly (fu or For Users/Friendly Users) functions
113
- - Add your own scripts for quick prototyping!
375
+ See the `demo/` directory for runnable examples:
376
+ - `demo/MakeModel.js` simple training loop
377
+ - `demo/scheduler.js` learning rate schedulers
378
+ - `demo/fu_fun.js` utility functions
114
379
 
115
380
  ```bash
116
381
  node demo/MakeModel.js
@@ -120,17 +385,30 @@ node demo/fu_fun.js
120
385
 
121
386
  ---
122
387
 
123
- ## Intended Use Cases
388
+ # Design Notes & Limitations
124
389
 
125
- - Rapid prototyping of neural networks in frontend.
126
- - Learning and teaching foundational neural network concepts.
127
- - Experimentation on low-end devices or mobile browsers.
128
- - Lightweight AI projects without GPU dependency.
390
+ - Training logic is 2D-first: `[batch][features]`
391
+ - Higher-dimensional data is reshaped internally by specific layers (e.g. Conv2D, Flatten)
392
+ - No automatic broadcasting or autograd graph
393
+ - Some components (Conv2D, BatchNorm2D, Dropout) are educational / experimental
394
+ - Not intended for large-scale or production ML workloads
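The 2D-first rule above means every dataset is a plain array of sample rows, one row per sample and one column per feature. A quick illustrative sketch of the expected shape (the `shape2D` helper is hypothetical, not a package export):

```javascript
// [batch][features]: 4 samples, 2 features each (the XOR inputs).
const X = [
  [0, 0],
  [0, 1],
  [1, 0],
  [1, 1],
];

// Hypothetical helper to read the shape of a 2D dataset.
function shape2D(m) {
  return [m.length, m[0].length];
}

console.log(shape2D(X)); // [4, 2]: batch size 4, 2 features per sample
```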
129
395
 
130
396
  ---
131
397
 
398
+ # Intended Use Cases
399
+
400
+ - Learning how neural networks work internally
401
+ - Teaching ML fundamentals
402
+ - Small experiments in Node.js or the browser
403
+ - Lightweight AI demos without GPU or large frameworks
404
+
405
+ ---
406
+
132
407
  # License
133
408
 
134
- `MIT License`
409
+ MIT License
410
+
411
+ Copyright (c) 2024
412
+ rizal-editors
135
413
 
136
- **Copyright (c) 2024 rizal-editors**
414
+ ---
package/index.js CHANGED
@@ -1,3 +1,10 @@
1
1
  // package root
2
- // * = all exports from src/jstorch.js
3
- export * from "./src/jstorch.js";
2
+
3
+ // provide JST in browser global scope
4
+ import * as JST from './src/jstorch.js';
5
+
6
+ if (typeof window !== 'undefined') {
7
+ window.JST = JST; // Global JST (JSTorch) object
8
+ }
9
+
10
+ export * from './src/jstorch.js';
package/package.json CHANGED
@@ -1,29 +1,19 @@
1
1
  {
2
2
  "name": "mini-jstorch",
3
- "version": "1.7.0",
3
+ "version": "1.8.0",
4
4
  "type": "module",
5
5
  "description": "A lightweight JavaScript neural network library for learning AI concepts and rapid Frontend experimentation. PyTorch-inspired, zero dependencies, perfect for educational use.",
6
6
  "main": "index.js",
7
7
  "keywords": [
8
- "neural-network",
9
- "javascript",
10
- "lightweight-torch",
11
- "lightweight",
12
- "small-torch",
8
+ "lightweight-ml",
13
9
  "javascript-torch",
14
- "ai-model",
15
- "jstorch",
16
- "pytorch",
17
- "front-end",
18
- "machine-learning",
10
+ "front-end-torch",
19
11
  "tiny-ml",
20
- "frontend-ai",
21
- "mini-neural-network"
12
+ "mini-neural-network",
13
+ "mini-ml-library",
14
+ "mini-js-ml",
15
+ "educational-ml"
22
16
  ],
23
17
  "author": "Rizal",
24
- "license": "MIT",
25
- "repository": {
26
- "type": "git",
27
- "url": "https://github.com/rizal-editors/mini-jstorch.git"
28
- }
18
+ "license": "MIT"
29
19
  }
package/src/jstorch.js CHANGED
@@ -1,9 +1,8 @@
1
1
  /*!
2
- * Project: mini-jstorch
3
2
  * File: jstorch.js
4
- * Author: Rizal-editors
3
+ * Author: rizal-editors
5
4
  * License: MIT
6
- * Copyright (C) 2025 Rizal-editors
5
+ * Copyright (C) 2025 rizal-editors
7
6
  *
8
7
  * Permission is hereby granted, free of charge, to any person obtaining a copy
9
8
  * of this software and associated documentation files (the "Software"), to deal
@@ -24,7 +23,13 @@
24
23
  * SOFTWARE.
25
24
  */
26
25
 
27
- // ---------------------- DONOT USE THESE (ENGINE INTERNALS) ----------------------
26
+ // --------------------------------------------------------------
27
+ // PLEASE READ THE README.md FILE BEFORE USING THIS PACKAGE!
28
+ // This package is designed to be used in a Node.js/browser environment.
29
+ // See the Documentation for more details.
30
+ // --------------------------------------------------------------
31
+
32
+ // ---------------------- DO NOT USE THESE (ENGINE INTERNALS): ERRORS/BUGS ARE EXPECTED ----------------------
28
33
  export function zeros(rows, cols) {
29
34
  return Array.from({length:rows},()=>Array(cols).fill(0));
30
35
  }
@@ -74,7 +79,7 @@ export function crossEntropy(pred,target){
74
79
  return -target.reduce((sum,t,i)=>sum+t*Math.log(pred[i]+eps),0);
75
80
  }
76
81
 
77
- // ---------------------- USERS FRIENDLY UTILS (USE THIS!) ----------------
82
+ // ---------------------- USER-FRIENDLY UTILS (USE THESE!) ----------------
78
83
  export function fu_tensor(data, requiresGrad = false) {
79
84
  if (!Array.isArray(data) || !Array.isArray(data[0])) {
80
85
  throw new Error("fu_tensor: Data must be 2D array");
@@ -745,6 +750,69 @@ export class Dropout{ constructor(p=0.5){ this.p=p; } forward(x){ return x.map(r
745
750
  export class MSELoss{ forward(pred,target){ this.pred=pred; this.target=target; const losses=pred.map((row,i)=>row.reduce((sum,v,j)=>sum+(v-target[i][j])**2,0)/row.length); return losses.reduce((a,b)=>a+b,0)/pred.length; } backward(){ return this.pred.map((row,i)=>row.map((v,j)=>2*(v-this.target[i][j])/row.length)); } }
746
751
  export class CrossEntropyLoss{ forward(pred,target){ this.pred=pred; this.target=target; const losses=pred.map((p,i)=>crossEntropy(softmax(p),target[i])); return losses.reduce((a,b)=>a+b,0)/pred.length; } backward(){ return this.pred.map((p,i)=>{ const s=softmax(p); return s.map((v,j)=>(v-this.target[i][j])/this.pred.length); }); } }
747
752
 
753
+ export class SoftmaxCrossEntropyLoss {
754
+ forward(logits, targets) {
755
+ this.targets = targets;
756
+ const batch = logits.length;
757
+
758
+ // stable softmax
759
+ this.probs = logits.map(row => {
760
+ const max = Math.max(...row);
761
+ const exps = row.map(v => Math.exp(v - max));
762
+ const sum = exps.reduce((a,b)=>a+b, 0);
763
+ return exps.map(v => v / sum);
764
+ });
765
+
766
+ let loss = 0;
767
+ for (let i = 0; i < batch; i++) {
768
+ for (let j = 0; j < this.probs[i].length; j++) {
769
+ if (targets[i][j] === 1) {
770
+ loss -= Math.log(this.probs[i][j] + 1e-12);
771
+ }
772
+ }
773
+ }
774
+
775
+ return loss / batch;
776
+ }
777
+
778
+ backward() {
779
+ const batch = this.targets.length;
780
+ return this.probs.map((row,i) =>
781
+ row.map((p,j) => (p - this.targets[i][j]) / batch)
782
+ );
783
+ }
784
+ }
785
+
786
+ export class BCEWithLogitsLoss {
787
+ forward(logits, targets) {
788
+ this.logits = logits;
789
+ this.targets = targets;
790
+ const batch = logits.length;
791
+ let loss = 0;
792
+
793
+ for (let i = 0; i < batch; i++) {
794
+ for (let j = 0; j < logits[i].length; j++) {
795
+ const x = logits[i][j];
796
+ const y = targets[i][j];
797
+ // stable BCE
798
+ loss += Math.max(x, 0) - x*y + Math.log(1 + Math.exp(-Math.abs(x)));
799
+ }
800
+ }
801
+
802
+ return loss / batch;
803
+ }
804
+
805
+ backward() {
806
+ const batch = this.logits.length;
807
+ return this.logits.map((row,i) =>
808
+ row.map((x,j) => {
809
+ const sigmoid = 1 / (1 + Math.exp(-x));
810
+ return (sigmoid - this.targets[i][j]) / batch;
811
+ })
812
+ );
813
+ }
814
+ }
815
+
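As a sanity check on the stable formulation in `BCEWithLogitsLoss.forward` above: `Math.max(x, 0) - x*y + Math.log(1 + Math.exp(-Math.abs(x)))` is algebraically identical to the naive `-(y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x)))`, but cannot overflow for large `|x|` (the naive form returns `Infinity` once `sigmoid(x)` rounds to 0 or 1). A standalone comparison, illustrative only:

```javascript
// Numerically stable per-element binary cross-entropy on a raw logit x.
function stableBCE(x, y) {
  return Math.max(x, 0) - x * y + Math.log(1 + Math.exp(-Math.abs(x)));
}

// Naive formulation via the sigmoid; breaks down for large |x|.
function naiveBCE(x, y) {
  const s = 1 / (1 + Math.exp(-x));
  return -(y * Math.log(s) + (1 - y) * Math.log(1 - s));
}

// The two agree wherever the naive form is finite.
for (const [x, y] of [[2.0, 1], [-3.0, 0], [0.5, 1]]) {
  console.log(Math.abs(stableBCE(x, y) - naiveBCE(x, y)) < 1e-9); // true
}
```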
748
816
  // ---------------------- Optimizers ----------------------
749
817
  export class Adam{
750
818
  constructor(params, lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-8, max_grad_norm = 1.0){