mini-jstorch 2.0.1 → 2.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Docs/API.md +277 -0
- package/README.md +19 -16
- package/demo/linear_regression.js +1 -1
- package/package.json +2 -3
- package/src/jstorch.js +380 -411
package/Docs/API.md
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
# API Reference
|
|
2
|
+
|
|
3
|
+
## Model Container
|
|
4
|
+
|
|
5
|
+
### Sequential
|
|
6
|
+
|
|
7
|
+
```js
|
|
8
|
+
new Sequential(layers: Layer[])
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
Container that chains layers sequentially.
|
|
12
|
+
|
|
13
|
+
**Methods:**
|
|
14
|
+
- forward(x) — Pass input through all layers
|
|
15
|
+
- backward(grad) — Backpropagate gradient through all layers
|
|
16
|
+
- parameters() — Returns [{param, grad}, ...] for all trainable parameters
|
|
17
|
+
- zeroGrad() — Zero all parameter gradients
|
|
18
|
+
- train() — Set all layers to training mode
|
|
19
|
+
- eval() — Set all layers to evaluation mode
|
|
20
|
+
- stateDict() — Get {layer_0.weight, layer_0.bias, ...}
|
|
21
|
+
- loadStateDict(dict) — Load weights from state dict object
|
|
22
|
+
- step(lr) — Apply SGD step to all layers directly
|
|
23
|
+
|
|
24
|
+
---
|
|
25
|
+
|
|
26
|
+
## Layers
|
|
27
|
+
|
|
28
|
+
### Linear
|
|
29
|
+
|
|
30
|
+
```js
|
|
31
|
+
new Linear(inFeatures: number, outFeatures: number)
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
Fully connected layer. Weight shape: [inFeatures, outFeatures]. Bias shape: [1, outFeatures].
|
|
35
|
+
|
|
36
|
+
### Conv2D (experimental)
|
|
37
|
+
|
|
38
|
+
```js
|
|
39
|
+
new Conv2D(inChannels: number, outChannels: number, kernelSize: number, stride?: number, padding?: number)
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
2D convolution layer. Input shape: [batch, channels, height, width].
|
|
43
|
+
|
|
44
|
+
### Flatten
|
|
45
|
+
|
|
46
|
+
```js
|
|
47
|
+
new Flatten()
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
Flattens multi-dimensional input per sample. Preserves batch dimension.
|
|
51
|
+
|
|
52
|
+
---
|
|
53
|
+
|
|
54
|
+
## Activations
|
|
55
|
+
|
|
56
|
+
All activations follow the same interface:
|
|
57
|
+
- forward(x) — x: [batch, features], returns: [batch, features]
|
|
58
|
+
- backward(grad) — grad: [batch, features], returns: [batch, features]
|
|
59
|
+
|
|
60
|
+
| Class | Formula |
|
|
61
|
+
|-------|---------|
|
|
62
|
+
| ReLU() | max(0, x) |
|
|
63
|
+
| Sigmoid() | 1 / (1 + exp(-x)) |
|
|
64
|
+
| Tanh() | tanh(x) |
|
|
65
|
+
| LeakyReLU(alpha?) | x > 0 ? x : alpha * x |
|
|
66
|
+
| GELU() | 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) |
|
|
67
|
+
| ELU(alpha?) | x > 0 ? x : alpha * (exp(x) - 1) |
|
|
68
|
+
| Mish() | x * tanh(ln(1 + exp(x))) |
|
|
69
|
+
| SiLU() | x * sigmoid(x) |
|
|
70
|
+
| Softmax(dim?) | exp(x - max) / sum(exp(x - max)) |
|
|
71
|
+
|
|
72
|
+
---
|
|
73
|
+
|
|
74
|
+
## Loss Functions
|
|
75
|
+
|
|
76
|
+
### MSELoss
|
|
77
|
+
|
|
78
|
+
```js
|
|
79
|
+
new MSELoss()
|
|
80
|
+
```
|
|
81
|
+
|
|
82
|
+
loss = mean((pred - target)^2)
|
|
83
|
+
gradient = 2 * (pred - target) / batchSize
|
|
84
|
+
|
|
85
|
+
### SoftmaxCrossEntropyLoss (recommended for classification)
|
|
86
|
+
|
|
87
|
+
```js
|
|
88
|
+
new SoftmaxCrossEntropyLoss()
|
|
89
|
+
```
|
|
90
|
+
|
|
91
|
+
Combines softmax + cross-entropy in a numerically stable way. Input: logits (not probabilities). Do NOT combine with a Softmax layer.
|
|
92
|
+
|
|
93
|
+
### BCEWithLogitsLoss (recommended for binary classification)
|
|
94
|
+
|
|
95
|
+
```js
|
|
96
|
+
new BCEWithLogitsLoss()
|
|
97
|
+
```
|
|
98
|
+
|
|
99
|
+
Combines sigmoid + binary cross-entropy. Numerically stable. Input: logits (not probabilities). Do NOT combine with a Sigmoid layer.
|
|
100
|
+
|
|
101
|
+
### CrossEntropyLoss (deprecated)
|
|
102
|
+
|
|
103
|
+
```js
|
|
104
|
+
new CrossEntropyLoss()
|
|
105
|
+
```
|
|
106
|
+
|
|
107
|
+
Use SoftmaxCrossEntropyLoss instead. Exists for backward compatibility.
|
|
108
|
+
|
|
109
|
+
---
|
|
110
|
+
|
|
111
|
+
## Optimizers
|
|
112
|
+
|
|
113
|
+
### Adam (recommended)
|
|
114
|
+
|
|
115
|
+
```js
|
|
116
|
+
new Adam(parameters, options)
|
|
117
|
+
// or
|
|
118
|
+
new Adam(parameters, lr, beta1, beta2, eps, maxGradNorm)
|
|
119
|
+
```
|
|
120
|
+
|
|
121
|
+
Options object: { lr: 0.001, b1: 0.9, b2: 0.999, eps: 1e-8, max_grad_norm: 1.0 }
|
|
122
|
+
|
|
123
|
+
### AdamW
|
|
124
|
+
|
|
125
|
+
```js
|
|
126
|
+
new AdamW(parameters, options)
|
|
127
|
+
```
|
|
128
|
+
|
|
129
|
+
Adam with decoupled weight decay. Options include weight_decay (default: 0.01).
|
|
130
|
+
|
|
131
|
+
### SGD
|
|
132
|
+
|
|
133
|
+
```js
|
|
134
|
+
new SGD(parameters, lr?, maxGradNorm?)
|
|
135
|
+
```
|
|
136
|
+
|
|
137
|
+
### Lion
|
|
138
|
+
|
|
139
|
+
```js
|
|
140
|
+
new LION(parameters, options)
|
|
141
|
+
```
|
|
142
|
+
|
|
143
|
+
Memory-efficient optimizer.
|
|
144
|
+
|
|
145
|
+
**All optimizers have:**
|
|
146
|
+
- step() — Update parameters using accumulated gradients
|
|
147
|
+
- zeroGrad() — Zero all parameter gradients
|
|
148
|
+
|
|
149
|
+
---
|
|
150
|
+
|
|
151
|
+
## Learning Rate Schedulers
|
|
152
|
+
|
|
153
|
+
### StepLR
|
|
154
|
+
|
|
155
|
+
```js
|
|
156
|
+
new StepLR(optimizer, stepSize, gamma)
|
|
157
|
+
```
|
|
158
|
+
|
|
159
|
+
Multiplies LR by gamma every stepSize steps.
|
|
160
|
+
|
|
161
|
+
### LambdaLR
|
|
162
|
+
|
|
163
|
+
```js
|
|
164
|
+
new LambdaLR(optimizer, fn)
|
|
165
|
+
```
|
|
166
|
+
|
|
167
|
+
Sets LR = baseLr * fn(epoch) on each step.
|
|
168
|
+
|
|
169
|
+
### ReduceLROnPlateau
|
|
170
|
+
|
|
171
|
+
```js
|
|
172
|
+
new ReduceLROnPlateau(optimizer, options)
|
|
173
|
+
```
|
|
174
|
+
|
|
175
|
+
Reduces LR when loss stops improving. Options: patience, factor, min_lr, threshold, cooldown, verbose.
|
|
176
|
+
|
|
177
|
+
---
|
|
178
|
+
|
|
179
|
+
## Regularization
|
|
180
|
+
|
|
181
|
+
### Dropout
|
|
182
|
+
|
|
183
|
+
```js
|
|
184
|
+
new Dropout(p?)
|
|
185
|
+
```
|
|
186
|
+
|
|
187
|
+
Randomly zeros p fraction of neurons during training. Default p = 0.5. Call layer.train() / layer.eval() to toggle.
|
|
188
|
+
|
|
189
|
+
### BatchNorm2d (experimental)
|
|
190
|
+
|
|
191
|
+
```js
|
|
192
|
+
new BatchNorm2d(numFeatures, eps?, momentum?, affine?)
|
|
193
|
+
```
|
|
194
|
+
|
|
195
|
+
2D batch normalization. Input shape: [batch, channels, height, width].
|
|
196
|
+
|
|
197
|
+
---
|
|
198
|
+
|
|
199
|
+
## Tokenizer
|
|
200
|
+
|
|
201
|
+
```js
|
|
202
|
+
new Tokenizer(vocabSize?)
|
|
203
|
+
```
|
|
204
|
+
|
|
205
|
+
| Method | Description |
|
|
206
|
+
|--------|-------------|
|
|
207
|
+
| fit(texts) | Build vocabulary from text array |
|
|
208
|
+
| transform(texts, maxLength?, padToMax?) | Convert texts to token indices |
|
|
209
|
+
| fitTransform(texts, maxLength?, padToMax?) | Fit + transform in one call |
|
|
210
|
+
| inverseTransform(tokens, skipPad?) | Convert token indices back to text |
|
|
211
|
+
| getVocabulary() | Get vocabulary as string array |
|
|
212
|
+
| getVocabSize() | Get vocabulary size |
|
|
213
|
+
| getWordCounts() | Get word frequency map |
|
|
214
|
+
| getMostCommon(n?) | Get top N most frequent words |
|
|
215
|
+
|
|
216
|
+
---
|
|
217
|
+
|
|
218
|
+
## Utilities
|
|
219
|
+
|
|
220
|
+
zeros(rows, cols) — Matrix filled with 0
|
|
221
|
+
ones(rows, cols) — Matrix filled with 1
|
|
222
|
+
randomMatrix(rows, cols, scale?) — Random matrix (Xavier init if scale omitted)
|
|
223
|
+
transpose(matrix) — Matrix transpose
|
|
224
|
+
dot(a, b) — Matrix multiplication
|
|
225
|
+
addMatrices(a, b) — Element-wise addition
|
|
226
|
+
softmax(vector) — Softmax on 1D array
|
|
227
|
+
crossEntropy(pred, target) — Cross-entropy loss (scalar)
|
|
228
|
+
reshape(tensor, rows, cols) — Reshape to new dimensions
|
|
229
|
+
flattenBatch(batch) — Flatten batch to 2D
|
|
230
|
+
concat(a, b, axis) — Concatenate along axis 0 or 1
|
|
231
|
+
stack(tensors) — Stack tensors
|
|
232
|
+
|
|
233
|
+
---
|
|
234
|
+
|
|
235
|
+
## Model Persistence
|
|
236
|
+
|
|
237
|
+
saveModel(model) — Serialize model to JSON string
|
|
238
|
+
loadModel(model, json) — Load weights from JSON into model
|
|
239
|
+
|
|
240
|
+
Supports all layer types. Validates layer types and shapes. Logs warnings for mismatches.
|
|
241
|
+
|
|
242
|
+
---
|
|
243
|
+
|
|
244
|
+
## Tensor (Advanced)
|
|
245
|
+
|
|
246
|
+
```js
|
|
247
|
+
new Tensor(data, requiresGrad?)
|
|
248
|
+
```
|
|
249
|
+
|
|
250
|
+
| Method | Description |
|
|
251
|
+
|--------|-------------|
|
|
252
|
+
| add(tensor) | Element-wise addition |
|
|
253
|
+
| mul(tensor) | Element-wise multiplication |
|
|
254
|
+
| matmul(tensor) | Matrix multiplication |
|
|
255
|
+
| transpose() | Transpose tensor |
|
|
256
|
+
| flatten() | Flatten to 1D array |
|
|
257
|
+
| shape() | Returns [rows, cols] |
|
|
258
|
+
|
|
259
|
+
Static: Tensor.zeros(r,c), Tensor.ones(r,c), Tensor.random(r,c,scale?)
|
|
260
|
+
|
|
261
|
+
---
|
|
262
|
+
|
|
263
|
+
## User-Friendly Utilities (fu_*)
|
|
264
|
+
|
|
265
|
+
fu_tensor(data, requiresGrad?) — Create tensor from 2D array
|
|
266
|
+
fu_add(a, b) — Element-wise add
|
|
267
|
+
fu_mul(a, b) — Element-wise multiply
|
|
268
|
+
fu_matmul(a, b) — Matrix multiply
|
|
269
|
+
fu_sum(tensor) — Sum all elements
|
|
270
|
+
fu_mean(tensor) — Mean of all elements
|
|
271
|
+
fu_relu(tensor) — ReLU activation
|
|
272
|
+
fu_sigmoid(tensor) — Sigmoid activation
|
|
273
|
+
fu_tanh(tensor) — Tanh activation
|
|
274
|
+
fu_softmax(tensor) — Softmax activation
|
|
275
|
+
fu_flatten(tensor) — Flatten to 1D
|
|
276
|
+
fu_reshape(tensor, rows, cols) — Reshape tensor
|
|
277
|
+
fu_stack(tensors) — Stack tensors
|
package/README.md
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
## Mini-JSTorch (
|
|
1
|
+
## Mini-JSTorch (v2.0.2)
|
|
2
2
|
|
|
3
3
|
---
|
|
4
4
|
|
|
@@ -7,15 +7,14 @@ It runs in Node.js and modern browsers, with a simple API inspired by PyTorch-st
|
|
|
7
7
|
|
|
8
8
|
This project prioritizes `clarity`, `numerical correctness`, and `accessibility` over performance or large-scale production use.
|
|
9
9
|
|
|
10
|
-
|
|
11
|
-
- **Fixed Linear layer cache** (critical bug fix for training)
|
|
12
|
-
- **Fixed GELU gradient calculation**
|
|
13
|
-
- **Fixed MSELoss gradient scaling**
|
|
14
|
-
- **Optimized Softmax gradient** (O(n²) → O(n))
|
|
15
|
-
- **Improved Tokenizer** with proper PAD/UNK separation
|
|
16
|
-
- **Added Sequential.zeroGrad(), train(), eval(), stateDict() methods**
|
|
10
|
+
## Changelog
|
|
17
11
|
|
|
18
|
-
|
|
12
|
+
**v2.0.2:**
|
|
13
|
+
- **Fixed critical training bug:** Optimizers (Adam, SGD, AdamW, Lion) now correctly update Linear and Conv2D layer weights
|
|
14
|
+
- **Fixed BatchNorm2d:** Inference mode no longer produces NaN for multi-channel inputs
|
|
15
|
+
- **Fixed ELU activation:** Backward pass now uses correct derivative formula
|
|
16
|
+
- **Fixed saveModel/loadModel:** Now correctly saves and restores all layer types including Conv2D and BatchNorm2d
|
|
17
|
+
- **Fixed BatchNorm2d gradient zeroing:** gradWeight/gradBias now correctly reset between batches
|
|
19
18
|
|
|
20
19
|
**⚠️ BREAKING CHANGES in v2.0.0:**
|
|
21
20
|
- Tokenizer API: `tokenizeBatch()` → `transform()`, `detokenizeBatch()` → `inverseTransform()`
|
|
@@ -80,7 +79,6 @@ In Browser/Website:
|
|
|
80
79
|
async function train() {
|
|
81
80
|
const statusEl = document.getElementById('status');
|
|
82
81
|
const logEl = document.getElementById('log');
|
|
83
|
-
|
|
84
82
|
try {
|
|
85
83
|
const model = new Sequential([
|
|
86
84
|
new Linear(2, 16), new Tanh(),
|
|
@@ -107,13 +105,13 @@ In Browser/Website:
|
|
|
107
105
|
}
|
|
108
106
|
}
|
|
109
107
|
|
|
110
|
-
statusEl.textContent = '
|
|
108
|
+
statusEl.textContent = 'Done';
|
|
111
109
|
const preds = model.forward(X);
|
|
112
110
|
document.getElementById('res').innerHTML = `<h4>Results:</h4>` +
|
|
113
111
|
X.map((input, i) => `[${input}] -> <b>${preds[i][0].toFixed(4)}</b> (Target: ${y[i][0]})`).join('<br>');
|
|
114
112
|
|
|
115
113
|
} catch (e) {
|
|
116
|
-
statusEl.textContent = '
|
|
114
|
+
statusEl.textContent = 'Error: ' + e.message;
|
|
117
115
|
}
|
|
118
116
|
}
|
|
119
117
|
train();
|
|
@@ -319,11 +317,16 @@ console.log(`\nAccuracy: ${(correct / X.length * 100).toFixed(2)}%`);
|
|
|
319
317
|
# Save & Load Models
|
|
320
318
|
|
|
321
319
|
```javascript
|
|
322
|
-
|
|
323
|
-
import { saveModel, loadModel, Sequential } from "./src/jstorch.js";
|
|
320
|
+
import { saveModel, loadModel, Sequential, Linear, ReLU } from "./src/jstorch.js";
|
|
324
321
|
|
|
322
|
+
// Save trained model
|
|
325
323
|
const json = saveModel(model);
|
|
326
|
-
|
|
324
|
+
|
|
325
|
+
// Create fresh model with same architecture and load weights
|
|
326
|
+
const model2 = new Sequential([
|
|
327
|
+
new Linear(2, 16), new ReLU(),
|
|
328
|
+
new Linear(16, 1)
|
|
329
|
+
]);
|
|
327
330
|
loadModel(model2, json);
|
|
328
331
|
```
|
|
329
332
|
|
|
@@ -360,7 +363,7 @@ node demo/<fileNameInDemo>.js
|
|
|
360
363
|
|
|
361
364
|
MIT License
|
|
362
365
|
|
|
363
|
-
Copyright (c) 2024
|
|
366
|
+
Copyright (c) 2024-2025
|
|
364
367
|
rizal-editors
|
|
365
368
|
|
|
366
369
|
---
|
|
@@ -39,4 +39,4 @@ finalPred.forEach((p, i) => {
|
|
|
39
39
|
});
|
|
40
40
|
console.log(`\nAverage Error: ${(totalError / X.length).toFixed(2)}`);
|
|
41
41
|
console.log(`Weight (slope): ${model.layers[0].W[0][0].toFixed(4)} (expected: 2.0)`);
|
|
42
|
-
console.log(`Bias: ${model.layers[0].b[0].toFixed(4)} (expected: 0.0)`);
|
|
42
|
+
console.log(`Bias: ${model.layers[0].b[0][0].toFixed(4)} (expected: 0.0)`);
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "mini-jstorch",
|
|
3
|
-
"version": "2.0.1",
|
|
3
|
+
"version": "2.0.3",
|
|
4
4
|
"type": "module",
|
|
5
5
|
"description": "A lightweight JavaScript neural network library for learning AI concepts and rapid Frontend experimentation. PyTorch-inspired, zero dependencies, perfect for educational use.",
|
|
6
6
|
"main": "index.js",
|
|
@@ -15,8 +15,7 @@
|
|
|
15
15
|
"tiny-ml",
|
|
16
16
|
"mini-neural-network",
|
|
17
17
|
"mini-ml-library",
|
|
18
|
-
"mini-js-ml",
|
|
19
|
-
"educational-ml"
|
|
18
|
+
"mini-js-ml"
|
|
20
19
|
],
|
|
21
20
|
"author": "Rizal",
|
|
22
21
|
"license": "MIT"
|
package/src/jstorch.js
CHANGED
|
@@ -81,62 +81,26 @@ export function crossEntropy(pred,target){
|
|
|
81
81
|
|
|
82
82
|
// ---------------------- USER-FRIENDLY UTILS (USE THIS FOR YOUR UTILS!) ----------------
|
|
83
83
|
export function fu_tensor(data, requiresGrad = false) {
|
|
84
|
-
if (!Array.isArray(data) || !Array.isArray(data[0]))
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
tensor.requiresGrad = requiresGrad;
|
|
89
|
-
return tensor;
|
|
84
|
+
if (!Array.isArray(data) || !Array.isArray(data[0])){
|
|
85
|
+
throw new Error("fu_tensor: Data must be 2D array");
|
|
86
|
+
}
|
|
87
|
+
return new Tensor(data, requiresGrad);
|
|
90
88
|
}
|
|
91
89
|
|
|
92
|
-
// fu_add
|
|
93
|
-
export function fu_add(a, b)
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
if (!(a instanceof Tensor)) {
|
|
99
|
-
a = fu_tensor(Array(b.shape()[0]).fill().map(() =>
|
|
100
|
-
Array(b.shape()[1]).fill(a)
|
|
101
|
-
));
|
|
102
|
-
}
|
|
103
|
-
|
|
104
|
-
if (!(b instanceof Tensor)) {
|
|
105
|
-
b = fu_tensor(Array(a.shape()[0]).fill().map(() =>
|
|
106
|
-
Array(a.shape()[1]).fill(b)
|
|
107
|
-
));
|
|
108
|
-
}
|
|
109
|
-
|
|
110
|
-
if (a.shape()[0] !== b.shape()[0] || a.shape()[1] !== b.shape()[1]) {
|
|
111
|
-
throw new Error(`fu_add: Shape mismatch ${a.shape()} vs ${b.shape()}`);
|
|
112
|
-
}
|
|
113
|
-
|
|
114
|
-
return new Tensor(a.data.map((r, i) => r.map((v, j) => v + b.data[i][j])));
|
|
90
|
+
// fu_add
|
|
91
|
+
export function fu_add(a, b){
|
|
92
|
+
if (!(a instanceof Tensor)) a = fu_tensor(a);
|
|
93
|
+
if (!(b instanceof Tensor)) b = fu_tensor(b);
|
|
94
|
+
|
|
95
|
+
return a.add(b);
|
|
115
96
|
}
|
|
116
97
|
|
|
117
98
|
// fu_mul
|
|
118
|
-
export function fu_mul(a, b)
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
if (!(a instanceof Tensor)) {
|
|
124
|
-
a = fu_tensor(Array(b.shape()[0]).fill().map(() =>
|
|
125
|
-
Array(b.shape()[1]).fill(a)
|
|
126
|
-
));
|
|
127
|
-
}
|
|
128
|
-
|
|
129
|
-
if (!(b instanceof Tensor)) {
|
|
130
|
-
b = fu_tensor(Array(a.shape()[0]).fill().map(() =>
|
|
131
|
-
Array(a.shape()[1]).fill(b)
|
|
132
|
-
));
|
|
133
|
-
}
|
|
134
|
-
|
|
135
|
-
if (a.shape()[0] !== b.shape()[0] || a.shape()[1] !== b.shape()[1]) {
|
|
136
|
-
throw new Error(`fu_mul: Shape mismatch ${a.shape()} vs ${b.shape()}`);
|
|
137
|
-
}
|
|
138
|
-
|
|
139
|
-
return new Tensor(a.data.map((r, i) => r.map((v, j) => v * b.data[i][j])));
|
|
99
|
+
export function fu_mul(a, b){
|
|
100
|
+
if (!(a instanceof Tensor)) a = fu_tensor(a);
|
|
101
|
+
if (!(b instanceof Tensor)) b = fu_tensor(b);
|
|
102
|
+
|
|
103
|
+
return a.mul(b);
|
|
140
104
|
}
|
|
141
105
|
|
|
142
106
|
// fu_matmul
|
|
@@ -144,11 +108,7 @@ export function fu_matmul(a, b) {
|
|
|
144
108
|
if (!(a instanceof Tensor)) a = fu_tensor(a);
|
|
145
109
|
if (!(b instanceof Tensor)) b = fu_tensor(b);
|
|
146
110
|
|
|
147
|
-
|
|
148
|
-
throw new Error(`fu_matmul: Inner dimension mismatch ${a.shape()[1]} vs ${b.shape()[0]}`);
|
|
149
|
-
}
|
|
150
|
-
|
|
151
|
-
return new Tensor(dot(a.data, b.data));
|
|
111
|
+
return a.matmul(b);
|
|
152
112
|
}
|
|
153
113
|
|
|
154
114
|
// fu_sum
|
|
@@ -344,276 +304,144 @@ export class Tensor {
|
|
|
344
304
|
|
|
345
305
|
// ---------------------- Layers ----------------------
|
|
346
306
|
export class Linear {
|
|
347
|
-
constructor(
|
|
348
|
-
this.
|
|
349
|
-
this.
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
this.
|
|
353
|
-
this.
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
this.
|
|
307
|
+
constructor(inFeatures, outFeatures) {
|
|
308
|
+
this.inFeatures = inFeatures;
|
|
309
|
+
this.outFeatures = outFeatures;
|
|
310
|
+
|
|
311
|
+
// Weights: [inFeatures, outFeatures]
|
|
312
|
+
this.W = randomMatrix(inFeatures, outFeatures);
|
|
313
|
+
this.gradW = zeros(inFeatures, outFeatures);
|
|
314
|
+
|
|
315
|
+
// Bias: [1, outFeatures]
|
|
316
|
+
this.b = [Array(outFeatures).fill(0)];
|
|
317
|
+
this.gradB = [Array(outFeatures).fill(0)];
|
|
318
|
+
|
|
319
|
+
this.x = null; // cache input
|
|
357
320
|
}
|
|
358
|
-
|
|
359
|
-
_updateCache() {
|
|
360
|
-
const rows = this.W.length;
|
|
361
|
-
const cols = this.W[0].length;
|
|
362
|
-
this._WFlat = new Float32Array(rows * cols);
|
|
363
|
-
for (let i = 0; i < rows; i++){
|
|
364
|
-
const offset = i * cols;
|
|
365
|
-
const row = this.W[i];
|
|
366
|
-
for (let j = 0; j < cols; j++){
|
|
367
|
-
this._WFlat[offset + j] = row[j];
|
|
368
|
-
}
|
|
369
|
-
}
|
|
370
|
-
this._bFlat = new Float32Array(this.b);
|
|
371
|
-
}
|
|
372
321
|
|
|
373
|
-
forward(x){
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
const m = this.x.length;
|
|
385
|
-
const k = this.x[0].length;
|
|
386
|
-
const n = this.W[0].length;
|
|
387
|
-
|
|
388
|
-
if (!this._WFlat) {
|
|
389
|
-
const rows = this.W.length;
|
|
390
|
-
const cols = this.W[0].length;
|
|
391
|
-
this._WFlat = new Float32Array(rows * cols);
|
|
392
|
-
for (let i = 0; i < rows; i++) {
|
|
393
|
-
const offset = i * cols;
|
|
394
|
-
const row = this.W[i];
|
|
395
|
-
for (let j = 0; j < cols; j++) {
|
|
396
|
-
this._WFlat[offset + j] = row[j];
|
|
397
|
-
}
|
|
322
|
+
forward(x) {
|
|
323
|
+
// x: [batch, inFeatures]
|
|
324
|
+
this.x = x;
|
|
325
|
+
|
|
326
|
+
const out = dot(x, this.W); // [batch, outFeatures]
|
|
327
|
+
|
|
328
|
+
// add bias
|
|
329
|
+
for (let i = 0; i < out.length; i++) {
|
|
330
|
+
for (let j = 0; j < this.outFeatures; j++) {
|
|
331
|
+
out[i][j] += this.b[0][j];
|
|
398
332
|
}
|
|
399
|
-
this._bFlat = new Float32Array(this.b);
|
|
400
333
|
}
|
|
334
|
+
|
|
335
|
+
return out;
|
|
336
|
+
}
|
|
337
|
+
|
|
338
|
+
backward(grad) {
|
|
339
|
+
// grad: [batch, outFeatures]
|
|
340
|
+
const batchSize = grad.length;
|
|
401
341
|
|
|
402
|
-
//
|
|
403
|
-
const
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
xFlat[offset + j] = row[j];
|
|
342
|
+
// Mutate gradW in place so the array reference returned by parameters() stays valid for optimizers
|
|
343
|
+
const xT = transpose(this.x);
|
|
344
|
+
const computedGradW = dot(xT, grad);
|
|
345
|
+
for (let i=0; i<this.inFeatures; i++){
|
|
346
|
+
for (let j=0; j<this.outFeatures; j++){
|
|
347
|
+
this.gradW[i][j] = computedGradW[i][j]
|
|
409
348
|
}
|
|
410
349
|
}
|
|
411
350
|
|
|
412
|
-
|
|
413
|
-
for (let
|
|
414
|
-
|
|
415
|
-
for (let
|
|
416
|
-
|
|
417
|
-
for (let l = 0; l < k; l++) {
|
|
418
|
-
sum += xFlat[xOffset + l] * this._WFlat[l * n + j];
|
|
419
|
-
}
|
|
420
|
-
outFlat[i * n + j] = sum + this._bFlat[j];
|
|
351
|
+
// gradB = sum over batch
|
|
352
|
+
for (let j=0; j<this.outFeatures; j++){
|
|
353
|
+
let sum = 0;
|
|
354
|
+
for (let i=0; i <batchSize; i++){
|
|
355
|
+
sum+=grad[i][j]
|
|
421
356
|
}
|
|
357
|
+
this.gradB[0][j] = sum;
|
|
422
358
|
}
|
|
423
359
|
|
|
424
|
-
const
|
|
425
|
-
|
|
426
|
-
const row = Array(n);
|
|
427
|
-
const offset = i * n;
|
|
428
|
-
for (let j = 0; j < n; j++) {
|
|
429
|
-
row[j] = outFlat[offset + j];
|
|
430
|
-
}
|
|
431
|
-
out[i] = row;
|
|
432
|
-
}
|
|
360
|
+
const WT = transpose(this.W);
|
|
361
|
+
const gradInput = dot(grad, WT);
|
|
433
362
|
|
|
434
|
-
return
|
|
363
|
+
return gradInput;
|
|
435
364
|
}
|
|
436
365
|
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
const gradFlat = new Float32Array(m * n);
|
|
444
|
-
for (let i = 0; i < m; i++) {
|
|
445
|
-
const row = grad[i];
|
|
446
|
-
const offset = i * n;
|
|
447
|
-
for (let j = 0; j < n; j++) {
|
|
448
|
-
gradFlat[offset + j] = row[j];
|
|
449
|
-
}
|
|
450
|
-
}
|
|
451
|
-
|
|
452
|
-
// Convert x to Float32Array
|
|
453
|
-
const xFlat = new Float32Array(m * k);
|
|
454
|
-
for (let i = 0; i < m; i++) {
|
|
455
|
-
const row = this.x[i];
|
|
456
|
-
const offset = i * k;
|
|
457
|
-
for (let j = 0; j < k; j++) {
|
|
458
|
-
xFlat[offset + j] = row[j];
|
|
459
|
-
}
|
|
460
|
-
}
|
|
461
|
-
|
|
462
|
-
// Reset gradW
|
|
463
|
-
for (let i = 0; i < this.gradW.length; i++) {
|
|
464
|
-
for (let j = 0; j < this.gradW[0].length; j++) {
|
|
465
|
-
this.gradW[i][j] = 0;
|
|
466
|
-
}
|
|
467
|
-
}
|
|
468
|
-
|
|
469
|
-
// Compute gradW = x^T * grad
|
|
470
|
-
for (let i = 0; i < k; i++) {
|
|
471
|
-
for (let j = 0; j < n; j++) {
|
|
472
|
-
let sum = 0;
|
|
473
|
-
for (let batch = 0; batch < m; batch++) {
|
|
474
|
-
sum += xFlat[batch * k + i] * gradFlat[batch * n + j];
|
|
475
|
-
}
|
|
476
|
-
this.gradW[i][j] = sum;
|
|
477
|
-
}
|
|
478
|
-
}
|
|
479
|
-
|
|
480
|
-
// Compute gradb
|
|
481
|
-
for (let j = 0; j < n; j++) {
|
|
482
|
-
let sum = 0;
|
|
483
|
-
for (let batch = 0; batch < m; batch++) {
|
|
484
|
-
sum += gradFlat[batch * n + j];
|
|
485
|
-
}
|
|
486
|
-
this.gradb[j] = sum;
|
|
487
|
-
}
|
|
488
|
-
|
|
489
|
-
const gradInputFlat = new Float32Array(m * k);
|
|
490
|
-
for (let i = 0; i < m; i++) {
|
|
491
|
-
for (let j = 0; j < k; j++) {
|
|
492
|
-
let sum = 0;
|
|
493
|
-
for (let l = 0; l < n; l++) {
|
|
494
|
-
sum += gradFlat[i * n + l] * this.W[j][l];
|
|
495
|
-
}
|
|
496
|
-
gradInputFlat[i * k + j] = sum;
|
|
497
|
-
}
|
|
498
|
-
}
|
|
499
|
-
|
|
500
|
-
// Convert back to 2D array
|
|
501
|
-
const gradInput = Array(m);
|
|
502
|
-
for (let i = 0; i < m; i++) {
|
|
503
|
-
const row = Array(k);
|
|
504
|
-
const offset = i * k;
|
|
505
|
-
for (let j = 0; j < k; j++) {
|
|
506
|
-
row[j] = gradInputFlat[offset + j];
|
|
507
|
-
}
|
|
508
|
-
gradInput[i] = row;
|
|
509
|
-
}
|
|
510
|
-
|
|
511
|
-
if (this.originalShape === '3d') {
|
|
512
|
-
return gradInput.map(row => [row]);
|
|
513
|
-
}
|
|
514
|
-
return gradInput;
|
|
515
|
-
}
|
|
366
|
+
step(lr) {
|
|
367
|
+
for (let i = 0; i < this.inFeatures; i++) {
|
|
368
|
+
for (let j = 0; j < this.outFeatures; j++) {
|
|
369
|
+
this.W[i][j] -= lr * this.gradW[i][j];
|
|
370
|
+
}
|
|
371
|
+
}
|
|
516
372
|
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
return '3d';
|
|
520
|
-
} else if (Array.isArray(x[0]) && !Array.isArray(x[0][0])) {
|
|
521
|
-
return '2d';
|
|
522
|
-
} else {
|
|
523
|
-
throw new Error(`Unsupported input shape for Linear layer`);
|
|
373
|
+
for (let j = 0; j < this.outFeatures; j++) {
|
|
374
|
+
this.b[0][j] -= lr * this.gradB[0][j];
|
|
524
375
|
}
|
|
525
376
|
}
|
|
526
377
|
|
|
527
|
-
parameters(){
|
|
528
|
-
return [
|
|
529
|
-
{param: this.W, grad: this.gradW},
|
|
530
|
-
{param:
|
|
531
|
-
];
|
|
378
|
+
parameters() {
|
|
379
|
+
return [
|
|
380
|
+
{ param: this.W, grad: this.gradW },
|
|
381
|
+
{ param: this.b, grad: this.gradB }
|
|
382
|
+
];
|
|
532
383
|
}
|
|
533
384
|
}
|
|
534
385
|
|
|
535
386
|
export class Flatten {
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
for (let h = 0; h < height; h++) {
|
|
595
|
-
const row = [];
|
|
596
|
-
for (let w = 0; w < width; w++) {
|
|
597
|
-
row.push(flat[index++]);
|
|
598
|
-
}
|
|
599
|
-
channel.push(row);
|
|
600
|
-
}
|
|
601
|
-
result.push(channel);
|
|
602
|
-
}
|
|
603
|
-
return result;
|
|
604
|
-
} else if (shape.type === '2d') {
|
|
605
|
-
const [height, width] = shape.dims;
|
|
606
|
-
const result = [];
|
|
607
|
-
for (let h = 0; h < height; h++) {
|
|
608
|
-
result.push(flat.slice(h * width, h * width + width));
|
|
609
|
-
}
|
|
610
|
-
return result;
|
|
611
|
-
} else {
|
|
612
|
-
return flat; // 1d
|
|
613
|
-
}
|
|
614
|
-
}
|
|
615
|
-
|
|
616
|
-
parameters() { return []; }
|
|
387
|
+
constructor(){
|
|
388
|
+
this.originalShape = null;
|
|
389
|
+
}
|
|
390
|
+
|
|
391
|
+
forward(x){
|
|
392
|
+
// Save full shape per sample
|
|
393
|
+
this.originalShape = x.map(sample => this._getDims(sample));
|
|
394
|
+
|
|
395
|
+
return x.map(sample => this._flattenDeep(sample));
|
|
396
|
+
}
|
|
397
|
+
|
|
398
|
+
backward(grad){
|
|
399
|
+
return grad.map((flat, i) =>
|
|
400
|
+
this._reshape(flat, this.originalShape[i])
|
|
401
|
+
);
|
|
402
|
+
}
|
|
403
|
+
|
|
404
|
+
// Get dimensions recursively
|
|
405
|
+
_getDims(arr){
|
|
406
|
+
const dims = [];
|
|
407
|
+
let current = arr;
|
|
408
|
+
while (Array.isArray(current)){
|
|
409
|
+
dims.push(current.length);
|
|
410
|
+
current = current[0];
|
|
411
|
+
}
|
|
412
|
+
return dims;
|
|
413
|
+
}
|
|
414
|
+
|
|
415
|
+
// Flatten ANY depth
|
|
416
|
+
_flattenDeep(arr){
|
|
417
|
+
return arr.flat(Infinity);
|
|
418
|
+
}
|
|
419
|
+
|
|
420
|
+
// Reshape back using saved dims
|
|
421
|
+
_reshape(flat, dims){
|
|
422
|
+
let index = 0;
|
|
423
|
+
|
|
424
|
+
function build(dimIdx){
|
|
425
|
+
const size = dims[dimIdx];
|
|
426
|
+
const result = [];
|
|
427
|
+
|
|
428
|
+
if (dimIdx === dims.length - 1){
|
|
429
|
+
for (let i=0; i<size; i++){
|
|
430
|
+
result.push(flat[index++]);
|
|
431
|
+
}
|
|
432
|
+
} else {
|
|
433
|
+
for (let i=0; i<size; i++){
|
|
434
|
+
result.push(build(dimIdx + 1));
|
|
435
|
+
}
|
|
436
|
+
}
|
|
437
|
+
|
|
438
|
+
return result;
|
|
439
|
+
}
|
|
440
|
+
|
|
441
|
+
return build(0);
|
|
442
|
+
}
|
|
443
|
+
|
|
444
|
+
parameters() { return []; }
|
|
617
445
|
}
|
|
618
446
|
|
|
619
447
|
// ---------------------- Conv2D (BETA) ----------------------
|
|
@@ -632,24 +460,25 @@ export class Conv2D {
|
|
|
632
460
|
);
|
|
633
461
|
this.x = null;
|
|
634
462
|
|
|
635
|
-
// Cache Float32Array untuk kernels
|
|
636
463
|
this._WFlat = null;
|
|
637
|
-
this.
|
|
464
|
+
this._updateCache();
|
|
638
465
|
}
|
|
639
466
|
|
|
640
|
-
|
|
641
|
-
this._WFlat = this.W.map(oc =>
|
|
467
|
+
_updateCache(){
|
|
468
|
+
this._WFlat = this.W.map(oc =>
|
|
642
469
|
oc.map(ic => {
|
|
643
470
|
const rows = ic.length;
|
|
644
471
|
const cols = ic[0].length;
|
|
645
472
|
const flat = new Float32Array(rows * cols);
|
|
646
|
-
|
|
473
|
+
|
|
474
|
+
for (let i=0; i<rows; i++){
|
|
647
475
|
const offset = i * cols;
|
|
648
476
|
const row = ic[i];
|
|
649
|
-
for (let j
|
|
650
|
-
flat[offset
|
|
477
|
+
for (let j=0; j<cols; j++){
|
|
478
|
+
flat[offset+j] = row[j];
|
|
651
479
|
}
|
|
652
480
|
}
|
|
481
|
+
|
|
653
482
|
return flat;
|
|
654
483
|
})
|
|
655
484
|
);
|
|
@@ -697,6 +526,8 @@ export class Conv2D {
|
|
|
697
526
|
}
|
|
698
527
|
|
|
699
528
|
forward(batch) {
|
|
529
|
+
this._updateCache();
|
|
530
|
+
|
|
700
531
|
this.x = batch;
|
|
701
532
|
const kH = this.kernel;
|
|
702
533
|
const kW = this.kernel;
|
|
@@ -727,43 +558,52 @@ export class Conv2D {
|
|
|
727
558
|
|
|
728
559
|
backward(grad) {
|
|
729
560
|
const batchSize = this.x.length;
|
|
730
|
-
const
|
|
731
|
-
const gradInput = this.x.map(sample =>
|
|
561
|
+
const gradInput = this.x.map(sample =>
|
|
732
562
|
sample.map(chan => zeros(chan.length, chan[0].length))
|
|
733
563
|
);
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
564
|
+
|
|
565
|
+
// Zero existing gradW in place
|
|
566
|
+
for (let oc=0; oc<this.outC; oc++){
|
|
567
|
+
for (let ic=0; ic<this.inC; ic++){
|
|
568
|
+
for (let i=0; i<this.kernel; i++){
|
|
569
|
+
for (let j=0; j<this.kernel; j++){
|
|
570
|
+
this.gradW[oc][ic][i][j] = 0;
|
|
571
|
+
}
|
|
572
|
+
}
|
|
573
|
+
}
|
|
574
|
+
}
|
|
575
|
+
|
|
576
|
+
for (let b=0; b<batchSize; b++){
|
|
577
|
+
for (let oc=0; oc<this.outC; oc++){
|
|
578
|
+
for (let ic=0; ic<this.inC; ic++){
|
|
738
579
|
const outGrad = grad[b][oc];
|
|
739
580
|
|
|
740
|
-
//
|
|
741
|
-
for (let i
|
|
742
|
-
for (let j
|
|
581
|
+
// Accumulate gradW in place
|
|
582
|
+
for (let i=0; i<this.kernel; i++){
|
|
583
|
+
for (let j=0; j<this.kernel; j++){
|
|
743
584
|
let sum = 0;
|
|
744
|
-
for (let y
|
|
745
|
-
for (let x
|
|
585
|
+
for (let y=0; y<outGrad.length; y++){
|
|
586
|
+
for (let x=0; x<outGrad[0].length; x++){
|
|
746
587
|
const inY = y * this.stride + i;
|
|
747
588
|
const inX = x * this.stride + j;
|
|
748
|
-
if (inY
|
|
749
|
-
sum
|
|
589
|
+
if (inY<this.x[b][ic].length && inX < this.x[b][ic][0].length){
|
|
590
|
+
sum+=this.x[b][ic][inY][inX] * outGrad[y][x];
|
|
750
591
|
}
|
|
751
592
|
}
|
|
752
593
|
}
|
|
753
|
-
gradW[oc][ic][i][j] += sum;
|
|
594
|
+
this.gradW[oc][ic][i][j] += sum;
|
|
754
595
|
}
|
|
755
596
|
}
|
|
756
|
-
|
|
757
|
-
// Compute gradInput
|
|
758
|
-
for (let y
|
|
759
|
-
for (let x
|
|
760
|
-
for (let ki
|
|
761
|
-
for (let kj
|
|
597
|
+
|
|
598
|
+
// Compute gradInput
|
|
599
|
+
for (let y=0; y<outGrad.length; y++){
|
|
600
|
+
for (let x=0; x<outGrad[0].length; x++){
|
|
601
|
+
for (let ki=0; ki<this.kernel; ki++){
|
|
602
|
+
for (let kj=0; kj<this.kernel; kj++){
|
|
762
603
|
const inY = y * this.stride + ki;
|
|
763
604
|
const inX = x * this.stride + kj;
|
|
764
|
-
if (inY
|
|
765
|
-
gradInput[b][ic][inY][inX] +=
|
|
766
|
-
this.W[oc][ic][ki][kj] * outGrad[y][x];
|
|
605
|
+
if (inY<gradInput[b][ic].length && inX < gradInput[b][ic][0].length){
|
|
606
|
+
gradInput[b][ic][inY][inX] += this.W[oc][ic][ki][kj] * outGrad[y][x];
|
|
767
607
|
}
|
|
768
608
|
}
|
|
769
609
|
}
|
|
@@ -772,8 +612,7 @@ export class Conv2D {
|
|
|
772
612
|
}
|
|
773
613
|
}
|
|
774
614
|
}
|
|
775
|
-
|
|
776
|
-
this.gradW = gradW;
|
|
615
|
+
|
|
777
616
|
return gradInput;
|
|
778
617
|
}
|
|
779
618
|
|
|
@@ -781,7 +620,7 @@ export class Conv2D {
|
|
|
781
620
|
return this.W.flatMap((w, oc) =>
|
|
782
621
|
w.map((wc, ic) => ({
|
|
783
622
|
param: wc,
|
|
784
|
-
grad: this.gradW[oc][ic]
|
|
623
|
+
grad: this.gradW[oc][ic] // Reference stays valid — gradW is mutated in-place now
|
|
785
624
|
}))
|
|
786
625
|
);
|
|
787
626
|
}
|
|
@@ -872,6 +711,14 @@ export class Sequential {
|
|
|
872
711
|
return state;
|
|
873
712
|
}
|
|
874
713
|
|
|
714
|
+
step(lr){
|
|
715
|
+
this.layers.forEach(layer => {
|
|
716
|
+
if(typeof layer.step === "function"){
|
|
717
|
+
layer.step(lr);
|
|
718
|
+
}
|
|
719
|
+
})
|
|
720
|
+
}
|
|
721
|
+
|
|
875
722
|
/**
|
|
876
723
|
* Load state dict
|
|
877
724
|
*/
|
|
@@ -895,43 +742,35 @@ export class Sequential {
|
|
|
895
742
|
}
|
|
896
743
|
|
|
897
744
|
// ---------------------- Activations ----------------------
|
|
898
|
-
export class ReLU{
|
|
899
|
-
constructor(){
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
if (this.originalShape === '3d') {
|
|
917
|
-
return grad.map((sample, i) =>
|
|
918
|
-
[sample[0].map((v, j) => this.mask[i][j] ? v : 0)]
|
|
919
|
-
);
|
|
920
|
-
} else {
|
|
921
|
-
return grad.map((row, i) =>
|
|
922
|
-
row.map((v, j) => this.mask[i][j] ? v : 0)
|
|
923
|
-
);
|
|
745
|
+
/**
 * Rectified Linear Unit activation: max(0, v) applied element-wise.
 *
 * Accepts numbers nested in arrays of any depth; the output has the same
 * nesting as the input.
 */
export class ReLU {
    constructor() {
        // Boolean structure mirroring the last forward input (true where v > 0).
        // Stays null until forward() has been called.
        this.mask = null;
    }

    /**
     * Apply the activation and remember which elements were positive.
     *
     * @param {number|Array} x - Input value(s), arbitrarily nested.
     * @returns {number|Array} Same nesting as `x`, each element max(0, v).
     */
    forward(x) {
        this.mask = this._mapRecursive(x, v => v > 0);
        return this._mapRecursive(x, v => Math.max(0, v));
    }

    /**
     * Route the upstream gradient through the positions that were positive
     * in the most recent forward pass; every other position receives 0.
     *
     * @param {number|Array} grad - Upstream gradient, same nesting as the forward input.
     * @returns {number|Array} Gradient with respect to the forward input.
     * @throws {Error} If called before forward(), or if `grad` does not have
     *   the same nesting as the forward input.
     */
    backward(grad) {
        if (this.mask === null) {
            // Without this guard the recursion would index into null and fail
            // with an unrelated-looking TypeError.
            throw new Error("ReLU.backward called before forward");
        }
        return this._mapRecursiveWithMask(grad, this.mask, (g, m) => m ? g : 0);
    }

    // ===== helper =====

    /**
     * Apply `fn` to every non-array leaf of `arr`, preserving the nesting.
     */
    _mapRecursive(arr, fn) {
        if (Array.isArray(arr)) {
            return arr.map(v => this._mapRecursive(v, fn));
        }
        return fn(arr);
    }

    /**
     * Walk `arr` and `mask` in lock-step and apply `fn(leaf, maskLeaf)`.
     * Both structures must have identical nesting.
     */
    _mapRecursiveWithMask(arr, mask, fn) {
        if (Array.isArray(arr)) {
            if (!Array.isArray(mask) || mask.length !== arr.length) {
                throw new Error("ReLU.backward: gradient shape does not match forward input shape");
            }
            return arr.map((v, i) =>
                this._mapRecursiveWithMask(v, mask[i], fn)
            );
        }
        if (Array.isArray(mask)) {
            // A scalar gradient against an array mask would otherwise be
            // passed through unchanged because arrays are truthy.
            throw new Error("ReLU.backward: gradient shape does not match forward input shape");
        }
        return fn(arr, mask);
    }
}
|
|
937
776
|
|
|
@@ -1156,7 +995,10 @@ export class Tokenizer {
|
|
|
1156
995
|
.replace(/\s+/g, ' ')
|
|
1157
996
|
.trim()
|
|
1158
997
|
.split(' ')
|
|
1159
|
-
.filter(word =>
|
|
998
|
+
.filter(word =>
|
|
999
|
+
word.length > 0 &&
|
|
1000
|
+
(!/^[.!?;:,]+$/.test(word) || word.length > 1)
|
|
1001
|
+
);
|
|
1160
1002
|
}
|
|
1161
1003
|
|
|
1162
1004
|
/**
|
|
@@ -1812,23 +1654,25 @@ export class ReduceLROnPlateau {
|
|
|
1812
1654
|
|
|
1813
1655
|
// ---------------------- ELU Activation ----------------------
|
|
1814
1656
|
/**
 * Exponential Linear Unit activation.
 *
 *   forward:  v                       if v > 0
 *             alpha * (exp(v) - 1)    otherwise
 *   backward: 1                       if v > 0
 *             alpha * exp(v)          otherwise
 *
 * Accepts numbers nested in arrays of any depth (the 2-D `[rows][cols]`
 * layout behaves exactly as before); the output keeps the input's nesting.
 */
export class ELU {
    /**
     * @param {number} [alpha=1.0] - Scale of the negative branch.
     */
    constructor(alpha=1.0){
        this.alpha = alpha;
        this.x = null; // Cache input for correct derivative
    }

    /**
     * Apply the activation element-wise.
     *
     * The input is cached by reference (not copied) for backward(), so it
     * must not be mutated between the two calls.
     *
     * @param {number|Array} x - Input value(s), arbitrarily nested.
     * @returns {number|Array} Activated values with the same nesting as `x`.
     */
    forward(x){
        this.x = x; // Store original input for backward
        return this._mapRecursive(x, v =>
            v > 0 ? v : this.alpha * (Math.exp(v) - 1)
        );
    }

    /**
     * Multiply the upstream gradient by the ELU derivative evaluated at the
     * input of the most recent forward pass.
     *
     * @param {number|Array} grad - Upstream gradient, same nesting as the forward input.
     * @returns {number|Array} Gradient with respect to the forward input.
     * @throws {Error} If called before forward(), or if `grad` does not have
     *   the same nesting as the forward input.
     */
    backward(grad){
        if (this.x === null){
            throw new Error("ELU.backward called before forward");
        }
        // d/dx ELU: 1 if x > 0, else alpha * exp(x)
        return this._mapRecursiveWithInput(grad, this.x, (g, xVal) =>
            g * (xVal > 0 ? 1 : this.alpha * Math.exp(xVal))
        );
    }

    // ===== helper =====

    /**
     * Apply `fn` to every non-array leaf of `arr`, preserving the nesting.
     */
    _mapRecursive(arr, fn){
        if (Array.isArray(arr)){
            return arr.map(v => this._mapRecursive(v, fn));
        }
        return fn(arr);
    }

    /**
     * Walk `arr` and `input` in lock-step and apply `fn(leaf, inputLeaf)`.
     * Both structures must have identical nesting.
     */
    _mapRecursiveWithInput(arr, input, fn){
        if (Array.isArray(arr)){
            if (!Array.isArray(input) || input.length !== arr.length){
                throw new Error("ELU.backward: gradient shape does not match forward input shape");
            }
            return arr.map((v, i) =>
                this._mapRecursiveWithInput(v, input[i], fn)
            );
        }
        if (Array.isArray(input)){
            throw new Error("ELU.backward: gradient shape does not match forward input shape");
        }
        return fn(arr, input);
    }
}
|
|
@@ -1910,10 +1754,10 @@ export class BatchNorm2d {
|
|
|
1910
1754
|
|
|
1911
1755
|
// Parameters
|
|
1912
1756
|
if (affine) {
|
|
1913
|
-
this.weight = Array(numFeatures).fill(1);
|
|
1914
|
-
|
|
1915
|
-
|
|
1916
|
-
|
|
1757
|
+
this.weight = [Array(numFeatures).fill(1)];
|
|
1758
|
+
this.bias = [Array(numFeatures).fill(0)];
|
|
1759
|
+
this.gradWeight = [Array(numFeatures).fill(0)];
|
|
1760
|
+
this.gradBias = [Array(numFeatures).fill(0)];
|
|
1917
1761
|
}
|
|
1918
1762
|
|
|
1919
1763
|
// Running statistics
|
|
@@ -1993,7 +1837,7 @@ export class BatchNorm2d {
|
|
|
1993
1837
|
|
|
1994
1838
|
// Apply affine transformation if enabled
|
|
1995
1839
|
if (this.affine) {
|
|
1996
|
-
|
|
1840
|
+
channelOut[i][j] = channelOut[i][j] * this.weight[0][c] + this.bias[0][c];
|
|
1997
1841
|
}
|
|
1998
1842
|
}
|
|
1999
1843
|
}
|
|
@@ -2022,7 +1866,7 @@ export class BatchNorm2d {
|
|
|
2022
1866
|
|
|
2023
1867
|
// Apply affine transformation if enabled
|
|
2024
1868
|
if (this.affine) {
|
|
2025
|
-
channelOut[i][j] = channelOut[i][j] * this.weight[c] + this.bias[c];
|
|
1869
|
+
channelOut[i][j] = channelOut[i][j] * this.weight[0][c] + this.bias[0][c];
|
|
2026
1870
|
}
|
|
2027
1871
|
}
|
|
2028
1872
|
}
|
|
@@ -2052,8 +1896,10 @@ export class BatchNorm2d {
|
|
|
2052
1896
|
);
|
|
2053
1897
|
|
|
2054
1898
|
if (this.affine) {
|
|
2055
|
-
|
|
2056
|
-
|
|
1899
|
+
for (let c=0; c<channels; c++){
|
|
1900
|
+
this.gradWeight[0][c] = 0;
|
|
1901
|
+
this.gradBias[0][c] = 0;
|
|
1902
|
+
}
|
|
2057
1903
|
}
|
|
2058
1904
|
|
|
2059
1905
|
for (let c = 0; c < channels; c++) {
|
|
@@ -2083,7 +1929,7 @@ export class BatchNorm2d {
|
|
|
2083
1929
|
let grad = channelGrad[i][j];
|
|
2084
1930
|
|
|
2085
1931
|
if (this.affine) {
|
|
2086
|
-
grad *= this.weight[c];
|
|
1932
|
+
grad *= this.weight[0][c];
|
|
2087
1933
|
}
|
|
2088
1934
|
|
|
2089
1935
|
grad *= stdInv;
|
|
@@ -2093,8 +1939,8 @@ export class BatchNorm2d {
|
|
|
2093
1939
|
}
|
|
2094
1940
|
|
|
2095
1941
|
if (this.affine) {
|
|
2096
|
-
this.gradWeight[c] = sumGradWeight / batchSize;
|
|
2097
|
-
this.gradBias[c] = sumGradBias / batchSize;
|
|
1942
|
+
this.gradWeight[0][c] = sumGradWeight / batchSize;
|
|
1943
|
+
this.gradBias[0][c] = sumGradBias / batchSize;
|
|
2098
1944
|
}
|
|
2099
1945
|
}
|
|
2100
1946
|
|
|
@@ -2104,9 +1950,9 @@ export class BatchNorm2d {
|
|
|
2104
1950
|
parameters() {
|
|
2105
1951
|
if (!this.affine) return [];
|
|
2106
1952
|
return [
|
|
2107
|
-
|
|
2108
|
-
|
|
2109
|
-
|
|
1953
|
+
{ param: this.weight, grad: this.gradWeight },
|
|
1954
|
+
{ param: this.bias, grad: this.gradBias }
|
|
1955
|
+
]
|
|
2110
1956
|
}
|
|
2111
1957
|
|
|
2112
1958
|
train() { this.training = true; }
|
|
@@ -2115,20 +1961,143 @@ export class BatchNorm2d {
|
|
|
2115
1961
|
|
|
2116
1962
|
// ---------------------- Model Save/Load (BETA) ----------------------
|
|
2117
1963
|
/**
 * Serialise the trainable parameters of a Sequential model to a JSON string.
 *
 * Output layout:
 *   { version: "2.0.0",
 *     layers: [ { type, params: [ { data, shape } ] } ] }
 * where `type` is the layer's constructor name, `data` is the parameter array
 * as returned by the layer's `parameters()`, and `shape` is
 * `[rows, cols]` for a parameter whose first element is an array, otherwise
 * `[length]`. Layers without a `parameters()` method, or with no parameters,
 * are stored with an empty `params` list.
 *
 * Note: JSON has no representation for NaN or Infinity; JSON.stringify writes
 * such values as `null`.
 *
 * @param {Sequential} model - Model to serialise.
 * @returns {string} JSON text accepted by loadModel.
 * @throws {Error} If `model` is not a Sequential instance.
 */
export function saveModel(model){
    if (!(model instanceof Sequential)){
        throw new Error("saveModel supports only Sequential models");
    }

    const layers = model.layers.map(layer => {
        const params = typeof layer.parameters === "function" ? layer.parameters() : [];

        return {
            type: layer.constructor.name,
            // The parameter arrays are serialised immediately below, so they
            // can be referenced directly; no defensive copy is needed.
            params: params.map(({ param }) => ({
                data: param,
                // Preserve shape metadata for validation
                shape: Array.isArray(param[0])
                    ? [param.length, param[0].length]
                    : [param.length]
            }))
        };
    });

    return JSON.stringify({ version: "2.0.0", layers });
}
|
|
2123
1995
|
|
|
2124
|
-
export function loadModel(model,json){
|
|
2125
|
-
if(!(model instanceof Sequential))
|
|
2126
|
-
|
|
2127
|
-
|
|
2128
|
-
|
|
2129
|
-
|
|
2130
|
-
|
|
2131
|
-
|
|
1996
|
+
/**
 * Restore the parameters of a Sequential model from a string produced by
 * saveModel. Values are copied into the model's existing parameter arrays, so
 * references handed out earlier by `parameters()` stay valid.
 *
 * Layers are matched by position. A layer is skipped (with a console warning)
 * when its saved entry is malformed, its type name differs, or its parameter
 * count differs. An individual parameter is skipped when its saved shape or
 * data layout does not match the live parameter. Saved layers with an empty
 * `params` list are ignored and counted neither as restored nor as skipped.
 *
 * A layer is reported as "restored" only when every one of its parameters was
 * copied; otherwise it is reported as "skipped". If the layer exposes an
 * `_updateCache()` method it is called after at least one parameter changed.
 *
 * @param {Sequential} model - Model whose parameters are overwritten.
 * @param {string} json - JSON text created by saveModel.
 * @returns {Sequential} The same `model` instance.
 * @throws {Error} If `model` is not a Sequential, or the parsed JSON has no
 *   `layers` array.
 * @throws {SyntaxError} If `json` is not valid JSON (raised by JSON.parse).
 */
export function loadModel(model, json){
    if (!(model instanceof Sequential)){
        throw new Error("loadModel supports only Sequential models");
    }

    const state = JSON.parse(json);

    // Validate structure (JSON.parse may legitimately return null or a primitive)
    if (!state || !Array.isArray(state.layers)){
        throw new Error("loadModel: invalid save format - missing 'layers' array");
    }

    if (state.layers.length !== model.layers.length){
        console.warn(
            `[JST WARN]: Layer count mismatch - saved ${state.layers.length}, ` +
            `current model has ${model.layers.length}. Loading what matches.`
        );
    }

    // True when `data` has exactly the layout of the live parameter, so the
    // copy loops below can never write `undefined` or a nested array by accident.
    const matchesLayout = (data, rows, cols, is2D) =>
        Array.isArray(data) &&
        data.length === rows &&
        data.every(row => is2D
            ? Array.isArray(row) && row.length === cols
            : !Array.isArray(row));

    let loadedCount = 0;
    let skippedCount = 0;
    const layerCount = Math.min(state.layers.length, model.layers.length);

    for (let i = 0; i < layerCount; i++){
        const savedLayer = state.layers[i];
        const currentLayer = model.layers[i];

        if (!savedLayer || !Array.isArray(savedLayer.params)){
            console.warn(
                `[JST WARN]: Layer ${i} has no 'params' array in the saved data. Skipping.`
            );
            skippedCount++;
            continue;
        }

        if (savedLayer.params.length === 0){
            // Layer with no trainable params - nothing to restore
            continue;
        }

        // Validate layer type
        if (savedLayer.type !== currentLayer.constructor.name){
            console.warn(
                `[JST WARN]: Layer ${i} type mismatch - ` +
                `saved: ${savedLayer.type}, current: ${currentLayer.constructor.name}. Skipping.`
            );
            skippedCount++;
            continue;
        }

        // Get current layer parameters
        const currentParams = typeof currentLayer.parameters === "function"
            ? currentLayer.parameters()
            : [];

        if (currentParams.length !== savedLayer.params.length){
            console.warn(
                `[JST WARN]: Layer ${i} parameter count mismatch - ` +
                `saved: ${savedLayer.params.length}, current: ${currentParams.length}. Skipping.`
            );
            skippedCount++;
            continue;
        }

        // Load parameters with shape validation
        let restoredParams = 0;

        for (let j = 0; j < savedLayer.params.length; j++){
            const savedParam = savedLayer.params[j] || {};
            const currentParam = currentParams[j].param;

            const is2D = Array.isArray(currentParam[0]);
            const currentRows = currentParam.length;
            const currentCols = is2D ? currentParam[0].length : 1;

            const savedShape = Array.isArray(savedParam.shape) ? savedParam.shape : [];
            const savedRows = savedShape[0];
            // A 1-D shape has no second entry and is treated as a single column.
            // Test the length rather than truthiness so a saved column count
            // of 0 is not turned into 1.
            const savedCols = savedShape.length > 1 ? savedShape[1] : 1;

            if (currentRows !== savedRows || currentCols !== savedCols){
                console.warn(
                    `[JST WARN]: Layer ${i} param ${j} shape mismatch - ` +
                    `saved: [${savedRows}, ${savedCols}], ` +
                    `current: [${currentRows}, ${currentCols}]. Skipping this parameter.`
                );
                continue;
            }

            if (!matchesLayout(savedParam.data, currentRows, currentCols, is2D)){
                console.warn(
                    `[JST WARN]: Layer ${i} param ${j} data does not match its declared shape. ` +
                    `Skipping this parameter.`
                );
                continue;
            }

            // Copy parameter data in place
            if (is2D){
                for (let r = 0; r < currentRows; r++){
                    for (let c = 0; c < currentCols; c++){
                        currentParam[r][c] = savedParam.data[r][c];
                    }
                }
            } else {
                for (let r = 0; r < currentRows; r++){
                    currentParam[r] = savedParam.data[r];
                }
            }

            restoredParams++;
        }

        // Invalidate any cached flat representations, but only if something changed
        if (restoredParams > 0 && typeof currentLayer._updateCache === 'function'){
            currentLayer._updateCache();
        }

        if (restoredParams === savedLayer.params.length){
            loadedCount++;
        } else {
            skippedCount++;
        }
    }

    console.log(
        `[JST]: Model loaded: ${loadedCount} layers restored, ${skippedCount} skipped.`
    );

    return model;
}
|
|
2133
2102
|
|
|
2134
2103
|
// ---------------------- Advanced Utils ----------------------
|