mini-jstorch 1.2.2 → 1.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,127 @@
+ // src/Dummy/debug_train.js
+ import {
+   Tensor,
+   Linear,
+   Sequential,
+   ReLU,
+   Sigmoid,
+   CrossEntropyLoss,
+   Adam,
+   MathOps
+ } from '../src/Dummy/exp.js';
+
+ // ---------------------- Simple Debug Data ----------------------
+ function generateSimpleData() {
+   // VERY simple data: 4 points in 2D
+   const X = [
+     [0, 0],
+     [0, 1],
+     [1, 0],
+     [1, 1]
+   ];
+
+   const y = [
+     [0], // AND operation
+     [0],
+     [0],
+     [1]
+   ];
+
+   return { X, y };
+ }
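+
+ // A harder variant for the same 2-2-1 model (sketch): XOR targets are not
+ // linearly separable, so they genuinely exercise the hidden layer.
+ // `generateXorData` is an illustrative addition, not part of the original script.
+ function generateXorData() {
+   const X = [[0, 0], [0, 1], [1, 0], [1, 1]];
+   const y = [[0], [1], [1], [0]];
+   return { X, y };
+ }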
+
+ // ---------------------- Debug Model ----------------------
+ function createDebugModel() {
+   return new Sequential([
+     new Linear(2, 2), // Small layer
+     new ReLU(),
+     new Linear(2, 1), // Output layer
+     new Sigmoid()
+   ]);
+ }
+
+ // ---------------------- Debug Training ----------------------
+ function debugTraining() {
+   console.log("šŸ” DEBUG TRAINING STARTED");
+   console.log("=========================");
+
+   const data = generateSimpleData();
+   const model = createDebugModel();
+   const lossFunction = new CrossEntropyLoss();
+   const parameters = model.parameters();
+   // Create the optimizer once, before the loop, so Adam's running moment
+   // estimates persist across steps instead of being re-initialized every iteration
+   const optimizer = new Adam(parameters, 0.1);
+
+   console.log("Model parameters:", parameters.length);
+   console.log("Data samples:", data.X.length);
+
+   // Single step debug
+   for (let step = 0; step < 10; step++) {
+     console.log(`\n--- Step ${step} ---`);
+
+     // Forward pass
+     const predictions = model.forward(data.X);
+     console.log("Predictions:", predictions.map(p => p[0].toFixed(3)));
+
+     const loss = lossFunction.forward(predictions, data.y);
+     console.log("Loss:", loss);
+
+     if (isNaN(loss)) {
+       console.log("āŒ NaN LOSS DETECTED!");
+       console.log("Predictions:", predictions);
+       console.log("Targets:", data.y);
+       break;
+     }
+
+     // Backward pass
+     const grad = lossFunction.backward();
+     console.log("Gradient:", grad.map(g => g[0].toFixed(3)));
+
+     model.backward(grad);
+
+     // Check gradients
+     console.log("Parameter gradients:");
+     parameters.forEach((param, idx) => {
+       if (Array.isArray(param.grad[0])) {
+         console.log(`  Param ${idx} grad:`, param.grad.map(row =>
+           row.map(v => v.toFixed(3))
+         ));
+       } else {
+         console.log(`  Param ${idx} grad:`, param.grad.map(v => v.toFixed(3)));
+       }
+     });
+
+     // Update weights (the optimizer is created once, above the loop)
+     optimizer.step();
+
+     // Reset gradients manually (nested arrays for weight matrices,
+     // flat arrays for biases)
+     parameters.forEach(param => {
+       if (Array.isArray(param.grad[0])) {
+         for (let i = 0; i < param.grad.length; i++) {
+           for (let j = 0; j < param.grad[0].length; j++) {
+             param.grad[i][j] = 0;
+           }
+         }
+       } else {
+         for (let i = 0; i < param.grad.length; i++) {
+           param.grad[i] = 0;
+         }
+       }
+     });
+
+     // Calculate accuracy
+     const accuracy = calculateAccuracy(predictions, data.y);
+     console.log("Accuracy:", (accuracy * 100).toFixed(1) + "%");
+   }
+ }
+
+ function calculateAccuracy(predictions, targets) {
+   let correct = 0;
+   for (let i = 0; i < predictions.length; i++) {
+     const predLabel = predictions[i][0] > 0.5 ? 1 : 0;
+     if (predLabel === targets[i][0]) correct++;
+   }
+   return correct / predictions.length;
+ }
+
+ // Run debug
+ debugTraining();
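+
+ // Sketch of a reusable gradient-reset helper equivalent to the manual loops
+ // in the training step above. It assumes each parameter's `grad` is either a
+ // nested array (matrix) or a flat array; mini-jstorch's own API may differ.
+ function zeroGradients(parameters) {
+   parameters.forEach(param => {
+     if (Array.isArray(param.grad[0])) {
+       param.grad.forEach(row => row.fill(0)); // zero each matrix row
+     } else {
+       param.grad.fill(0); // zero a flat bias gradient
+     }
+   });
+ }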
@@ -0,0 +1,570 @@
+ // mini-jstorch-test.js - Fully Synchronous Test Suite
+ // The LR schedulers and gradient-clipping helpers used below are assumed
+ // to be exported by testdummy.js alongside Tensor.
+ import MiniJSTorch, {
+   Tensor,
+   StepLR,
+   ExponentialLR,
+   CosineAnnealingLR,
+   clipGradNorm_,
+   clipGradValue_
+ } from '../testdummy.js';
+
+ // Initialize test framework
+ const torch = new MiniJSTorch();
+ let testsPassed = 0;
+ let testsFailed = 0;
+
+ function assert(condition, message) {
+   if (!condition) {
+     console.error(`āŒ FAIL: ${message}`);
+     testsFailed++;
+   } else {
+     console.log(`āœ… PASS: ${message}`);
+     testsPassed++;
+   }
+ }
+
+ function testSuite(name) {
+   console.log(`\n🧪 Testing ${name}...`);
+ }
+
+ function summarize() {
+   console.log(`\nšŸ“Š Test Summary:`);
+   console.log(`āœ… Passed: ${testsPassed}`);
+   console.log(`āŒ Failed: ${testsFailed}`);
+   console.log(`šŸ“ˆ Success Rate: ${((testsPassed / (testsPassed + testsFailed)) * 100).toFixed(1)}%`);
+ }
+
+ // ====================== TENSOR OPERATIONS TESTS ======================
+ testSuite("Tensor Operations");
+
+ function testTensorOperations() {
+   // Basic tensor creation
+   const t1 = new Tensor([1, 2, 3, 4], [2, 2]);
+   assert(t1.shape[0] === 2 && t1.shape[1] === 2, "Tensor shape creation");
+   assert(t1.data[0] === 1 && t1.data[3] === 4, "Tensor data initialization");
+
+   // Addition
+   const t2 = new Tensor([5, 6, 7, 8], [2, 2]);
+   const sum = t1.add(t2);
+   assert(sum.data[0] === 6 && sum.data[3] === 12, "Tensor addition");
+
+   // Subtraction
+   const sub = t1.sub(t2);
+   assert(sub.data[0] === -4 && sub.data[3] === -4, "Tensor subtraction");
+
+   // Multiplication (element-wise)
+   const mul = t1.mul(t2);
+   assert(mul.data[0] === 5 && mul.data[3] === 32, "Tensor multiplication");
+
+   // Division
+   const div = t2.div(t1);
+   assert(Math.abs(div.data[0] - 5) < 0.001, "Tensor division");
+
+   // Matrix multiplication: [[1,2],[3,4]] x [[5,6],[7,8]] = [[19,22],[43,50]],
+   // so the first output element is 1*5 + 2*7 = 19
+   const matmul = t1.matmul(t2);
+   assert(Math.abs(matmul.data[0] - 19) < 0.001, "Matrix multiplication");
+
+   // Activation functions
+   const relu = t1.relu();
+   assert(relu.data[0] === 1 && relu.data[1] === 2, "ReLU activation");
+
+   const sigmoid = new Tensor([-1, 0, 1]).sigmoid();
+   assert(sigmoid.data[0] < 0.5 && sigmoid.data[1] === 0.5 && sigmoid.data[2] > 0.5, "Sigmoid activation");
+
+   const tanh = new Tensor([-1, 0, 1]).tanh();
+   assert(tanh.data[0] < 0 && tanh.data[1] === 0 && tanh.data[2] > 0, "Tanh activation");
+
+   // Reshape
+   const reshaped = t1.reshape([4]);
+   assert(reshaped.shape[0] === 4 && reshaped.data[2] === 3, "Tensor reshape");
+
+   // Transpose: [[1,2],[3,4]] -> [[1,3],[2,4]]
+   const transposed = t1.transpose();
+   assert(transposed.shape[0] === 2 && transposed.shape[1] === 2, "Tensor transpose");
+   assert(transposed.data[1] === 3 && transposed.data[2] === 2, "Tensor transpose values");
+
+   // Sum and mean
+   const sumAll = t1.sum();
+   assert(sumAll.data[0] === 10, "Tensor sum all");
+
+   const meanAll = t1.mean();
+   assert(Math.abs(meanAll.data[0] - 2.5) < 0.1, "Tensor mean all");
+
+   // Memory management
+   const memoryStats = Tensor.memoryTracker.getStats();
+   assert(memoryStats.tensorCount > 0, "Memory tracking");
+ }
+
+ // ====================== LAYERS TESTS ======================
+ testSuite("Layers");
+
+ function testLayers() {
+   // Linear layer
+   const linear = torch.layers.linear(3, 2);
+   const input = new Tensor([1, 2, 3], [1, 3]);
+   const output = linear.forward(input);
+   assert(output.shape[0] === 1 && output.shape[1] === 2, "Linear layer forward pass");
+
+   // Conv2D layer
+   const conv = torch.layers.conv2d(3, 16, 3);
+   const convInput = new Tensor(new Float32Array(1 * 3 * 32 * 32), [1, 3, 32, 32]);
+   const convOutput = conv.forward(convInput);
+   assert(convOutput.shape[0] === 1 && convOutput.shape[1] === 16, "Conv2D layer forward pass");
+
+   // Activation layers
+   const reluLayer = torch.layers.relu();
+   const reluOutput = reluLayer.forward(new Tensor([-1, 0, 1]));
+   assert(reluOutput.data[0] === 0 && reluOutput.data[2] === 1, "ReLU layer");
+
+   const sigmoidLayer = torch.layers.sigmoid();
+   const sigmoidOutput = sigmoidLayer.forward(new Tensor([-1, 0, 1]));
+   assert(sigmoidOutput.data[0] < 0.5 && sigmoidOutput.data[1] === 0.5, "Sigmoid layer");
+
+   const tanhLayer = torch.layers.tanh();
+   const tanhOutput = tanhLayer.forward(new Tensor([-1, 0, 1]));
+   assert(tanhOutput.data[0] < 0 && tanhOutput.data[1] === 0, "Tanh layer");
+
+   // Dropout layer
+   // Note: this check is probabilistic; with p = 0.5 on four values there is
+   // a 1-in-16 chance that no element is dropped and the assert fails spuriously.
+   const dropout = torch.layers.dropout(0.5);
+   dropout.train();
+   const dropoutInput = new Tensor([1, 2, 3, 4]);
+   const dropoutOutput = dropout.forward(dropoutInput);
+   assert(dropoutOutput.data.some(val => val === 0), "Dropout layer in train mode");
+
+   dropout.eval();
+   const dropoutOutputEval = dropout.forward(dropoutInput);
+   assert(dropoutOutputEval.data[0] === 1 && dropoutOutputEval.data[3] === 4, "Dropout layer in eval mode");
+
+   // BatchNorm2D
+   const bn = torch.layers.batchNorm2d(16);
+   bn.train();
+   const bnInput = new Tensor(new Float32Array(2 * 16 * 8 * 8), [2, 16, 8, 8]);
+   const bnOutput = bn.forward(bnInput);
+   assert(bnOutput.shape[0] === 2 && bnOutput.shape[1] === 16, "BatchNorm2D forward pass");
+
+   // LSTM
+   const lstm = torch.layers.lstm(10, 20, 2);
+   const lstmInput = new Tensor(new Float32Array(5 * 3 * 10), [5, 3, 10]);
+   const lstmOutput = lstm.forward(lstmInput);
+   assert(lstmOutput.output.shape[0] === 5 && lstmOutput.output.shape[2] === 20, "LSTM forward pass");
+
+   // MultiHeadAttention
+   const attn = torch.layers.attention(64, 8);
+   const attnInput = new Tensor(new Float32Array(2 * 10 * 64), [2, 10, 64]);
+   const attnOutput = attn.forward(attnInput, attnInput, attnInput);
+   assert(attnOutput.shape[0] === 2 && attnOutput.shape[2] === 64, "MultiHeadAttention forward pass");
+
+   // Transformer
+   const transformer = torch.layers.transformer(64, 8, 2);
+   const transformerOutput = transformer.forward(attnInput);
+   assert(transformerOutput.shape[0] === 2 && transformerOutput.shape[2] === 64, "Transformer forward pass");
+
+   // LayerNorm
+   const layerNorm = torch.layers.layerNorm(64);
+   const normInput = new Tensor(new Float32Array(2 * 10 * 64), [2, 10, 64]);
+   const normOutput = layerNorm.forward(normInput);
+   assert(normOutput.shape[0] === 2 && normOutput.shape[2] === 64, "LayerNorm forward pass");
+ }
+
+ // ====================== LOSS FUNCTIONS TESTS ======================
+ testSuite("Loss Functions");
+
+ function testLossFunctions() {
+   // MSE Loss: diffs are [1, 0, 0, -1], so MSE = (1 + 0 + 0 + 1) / 4 = 0.5
+   const mse = torch.loss.mse();
+   const pred = new Tensor([1, 2, 3, 4]);
+   const target = new Tensor([0, 2, 3, 5]);
+   const mseLoss = mse.forward(pred, target);
+   assert(Math.abs(mseLoss.data[0] - 0.5) < 0.001, "MSE Loss calculation");
+
+   // CrossEntropy Loss
+   const ce = torch.loss.crossEntropy();
+   const cePred = new Tensor([0.1, 0.9, 0.1, 0.9], [2, 2]);
+   const ceTarget = new Tensor([0, 1, 1, 0], [2, 2]);
+   const ceLoss = ce.forward(cePred, ceTarget);
+   assert(ceLoss.data[0] > 0, "CrossEntropy Loss calculation");
+
+   // Huber Loss
+   const huber = torch.loss.huber(1.0);
+   const huberLoss = huber.forward(pred, target);
+   assert(huberLoss.data[0] > 0, "Huber Loss calculation");
+
+   // Triplet Loss: keep the negative within the margin of the anchor so the
+   // hinge term max(0, d(a,p) - d(a,n) + margin) stays strictly positive
+   const triplet = torch.loss.tripletLoss(1.0);
+   const anchor = new Tensor([1, 2]);
+   const positive = new Tensor([1.1, 2.1]);
+   const negative = new Tensor([1.5, 2.5]);
+   const tripletLoss = triplet.forward(anchor, positive, negative);
+   assert(tripletLoss.data[0] > 0, "Triplet Loss calculation");
+ }
+
+ // ====================== OPTIMIZERS TESTS ======================
+ testSuite("Optimizers");
+
+ function testOptimizers() {
+   // Create a simple model
+   const model = torch.nn.sequential([
+     torch.layers.linear(2, 3),
+     torch.layers.relu(),
+     torch.layers.linear(3, 1)
+   ]);
+
+   const params = model.getParameters();
+   // Seed non-zero gradients so each optimizer step actually moves the
+   // weights; freshly created parameters would otherwise have zero grads
+   params.forEach(p => {
+     p.grad = new Tensor(new Float32Array(p.data.length).fill(1), p.shape);
+   });
+
+   // SGD
+   const sgd = torch.optim.sgd(0.01);
+   const initialWeights = params[0].data[0];
+   sgd.step(params);
+   assert(params[0].data[0] !== initialWeights, "SGD parameter update");
+
+   // Adam
+   const adam = torch.optim.adam(0.001);
+   const adamInitialWeights = params[0].data[0];
+   adam.step(params);
+   assert(params[0].data[0] !== adamInitialWeights, "Adam parameter update");
+
+   // AdamW
+   const adamw = torch.optim.adamw(0.001, 0.9, 0.999, 0.01);
+   const adamwInitialWeights = params[0].data[0];
+   adamw.step(params);
+   assert(params[0].data[0] !== adamwInitialWeights, "AdamW parameter update");
+
+   // LAMB
+   const lamb = torch.optim.lamb(0.001);
+   const lambInitialWeights = params[0].data[0];
+   lamb.step(params);
+   assert(params[0].data[0] !== lambInitialWeights, "LAMB parameter update");
+
+   // RMSprop
+   const rmsprop = torch.optim.rmsprop(0.01);
+   const rmspropInitialWeights = params[0].data[0];
+   rmsprop.step(params);
+   assert(params[0].data[0] !== rmspropInitialWeights, "RMSprop parameter update");
+ }
+
+ // ====================== NEURAL NETWORK TESTS ======================
+ testSuite("Neural Network");
+
+ function testNeuralNetwork() {
+   // Sequential model
+   const model = torch.nn.sequential();
+   model.add(torch.layers.linear(4, 8));
+   model.add(torch.layers.relu());
+   model.add(torch.layers.dropout(0.2));
+   model.add(torch.layers.linear(8, 2));
+
+   const input = new Tensor([1, 2, 3, 4], [1, 4]);
+   const output = model.forward(input);
+   assert(output.shape[0] === 1 && output.shape[1] === 2, "Sequential model forward pass");
+
+   // Backward pass
+   const gradOutput = new Tensor([0.1, 0.2], [1, 2]);
+   const gradInput = model.backward(gradOutput);
+   assert(gradInput.shape[0] === 1 && gradInput.shape[1] === 4, "Sequential model backward pass");
+
+   // Get parameters
+   const params = model.getParameters();
+   assert(params.length > 0, "Sequential model parameter extraction");
+
+   // Test individual layer parameters (weight matrix + bias vector)
+   const linearLayer = torch.layers.linear(3, 2);
+   const linearParams = linearLayer.getParameters();
+   assert(linearParams.length === 2, "Linear layer parameter count");
+ }
+
+ // ====================== UTILS TESTS ======================
+ testSuite("Utils");
+
+ function testUtils() {
+   // DataLoader: 4 samples with batch size 2 -> 2 batches
+   const data = {
+     inputs: new Float32Array([1, 2, 3, 4, 5, 6, 7, 8]),
+     targets: new Float32Array([0, 1, 0, 1]),
+     inputShape: [2],
+     targetShape: [1]
+   };
+
+   const loader = torch.utils.dataLoader(data, 2, false);
+   const batches = [...loader];
+   assert(batches.length === 2, "DataLoader batch creation");
+   assert(batches[0].inputs.shape[0] === 2, "DataLoader batch size");
+
+   // One-hot encoding: labels [0, 2, 1] over 3 classes in row-major order,
+   // so row 1 encodes label 2 at flat index 1 * 3 + 2 = 5
+   const labels = [0, 2, 1];
+   const oneHot = torch.utils.oneHot(labels, 3);
+   assert(oneHot.shape[0] === 3 && oneHot.shape[1] === 3, "One-hot encoding shape");
+   assert(oneHot.data[0] === 1 && oneHot.data[5] === 1, "One-hot encoding values");
+
+   // Accuracy calculation: both rows argmax to class 0, matching the targets
+   const pred = new Tensor([0.9, 0.1, 0.7, 0.3], [2, 2]);
+   const target = new Tensor([1, 0, 1, 0], [2, 2]);
+   const accuracy = torch.utils.accuracy(pred, target);
+   assert(accuracy === 1.0, "Accuracy calculation");
+
+   // Benchmark
+   const simpleModel = torch.nn.sequential([
+     torch.layers.linear(10, 5),
+     torch.layers.relu()
+   ]);
+   const benchmarkInput = new Tensor(new Float32Array(10), [1, 10]);
+   const benchmark = torch.utils.benchmark(simpleModel, benchmarkInput, 10);
+   assert(benchmark.avgTime > 0, "Benchmark execution");
+   assert(benchmark.minTime > 0, "Benchmark min time");
+   assert(benchmark.maxTime > 0, "Benchmark max time");
+
+   // Profile function
+   const profileResult = torch.utils.profile(() => {
+     const x = new Tensor([1, 2, 3]);
+     return x.relu();
+   });
+   assert(profileResult !== undefined, "Profile function execution");
+ }
+
+ // ====================== ADVANCED FEATURES TESTS ======================
+ testSuite("Advanced Features");
+
+ function testAdvancedFeatures() {
+   // Quantization
+   const quant = torch.quant;
+   const tensor = new Tensor([0.1, 0.5, 0.9]);
+   const quantized = quant.quantize(tensor, 'int8');
+   assert(quantized.dtype === 'int8', "Tensor quantization");
+
+   const dequantized = quant.dequantize(quantized);
+   assert(dequantized.dtype === 'float32', "Tensor dequantization");
+
+   // Automatic Mixed Precision
+   const amp = torch.amp;
+   const ampModel = torch.nn.sequential([
+     torch.layers.linear(4, 2)
+   ]);
+   const ampInput = new Tensor([1, 2, 3, 4], [1, 4]);
+   const ampOutput = amp.forward(ampModel, ampInput);
+   assert(ampOutput.dtype === 'float32', "AMP forward pass");
+
+   // Visualization
+   const viz = torch.viz;
+   viz.init();
+   viz.logLoss(new Tensor([0.5]));
+   viz.logAccuracy(0.8);
+   viz.logLearningRate(0.001);
+   const metrics = viz.getMetrics();
+   assert(metrics.loss.length > 0, "Visualization metrics logging");
+
+   // Learning Rate Schedulers. JS has no keyword arguments, so the
+   // constructor arguments are positional: stepSize = 2, gamma = 0.1
+   const optimizer = torch.optim.sgd(0.1);
+   const scheduler = new StepLR(optimizer, 2, 0.1);
+
+   let lr = scheduler.step();
+   assert(lr === 0.1, "StepLR initial learning rate");
+
+   lr = scheduler.step();
+   assert(lr === 0.1, "StepLR before step");
+
+   lr = scheduler.step();
+   assert(lr === 0.01, "StepLR after step");
+
+   // ExponentialLR (gamma = 0.5); the optimizer's lr is 0.01 after StepLR,
+   // so one exponential step yields 0.005
+   const expScheduler = new ExponentialLR(optimizer, 0.5);
+   lr = expScheduler.step();
+   assert(lr === 0.005, "ExponentialLR learning rate");
+
+   // CosineAnnealingLR (T_max = 10)
+   const cosScheduler = new CosineAnnealingLR(optimizer, 10);
+   lr = cosScheduler.step();
+   assert(lr > 0 && lr < 0.1, "CosineAnnealingLR learning rate");
+
+   // Gradient Clipping: ||g|| = sqrt(0.5^2 + 1.5^2 + 0.8^2) = sqrt(3.14) ā‰ˆ 1.77,
+   // which exceeds the max norm of 1.0, so the gradients get rescaled
+   const params = [new Tensor([1, 2, 3], [3], true)];
+   params[0].grad = new Tensor([0.5, 1.5, 0.8]);
+
+   const norm = clipGradNorm_(params, 1.0);
+   assert(norm > 0, "Gradient norm calculation");
+
+   clipGradValue_(params, 1.0);
+   assert(params[0].grad.data[1] <= 1.0, "Gradient value clipping");
+
+   // Tensor map function (synchronous)
+   const mapTensor = new Tensor([1, 2, 3, 4]);
+   const mapped = mapTensor.map(x => x * 2);
+   assert(mapped.data[0] === 2 && mapped.data[3] === 8, "Tensor map function");
+ }
+
+ // ====================== INTEGRATION TESTS ======================
+ testSuite("Integration");
+
+ function testIntegration() {
+   // Complete training simulation
+   const model = torch.nn.sequential([
+     torch.layers.linear(4, 8),
+     torch.layers.relu(),
+     torch.layers.dropout(0.2),
+     torch.layers.linear(8, 2)
+   ]);
+
+   const optimizer = torch.optim.adam(0.01);
+   const lossFn = torch.loss.mse();
+   const scheduler = new StepLR(optimizer, 10, 0.1); // stepSize = 10, gamma = 0.1
+
+   // Create dummy data
+   const data = {
+     inputs: new Float32Array(100 * 4),
+     targets: new Float32Array(100 * 2),
+     inputShape: [4],
+     targetShape: [2]
+   };
+
+   // Fill with random data (inputs and targets have different lengths,
+   // so each array gets its own loop)
+   for (let i = 0; i < data.inputs.length; i++) {
+     data.inputs[i] = Math.random();
+   }
+   for (let i = 0; i < data.targets.length; i++) {
+     data.targets[i] = Math.random();
+   }
+
+   const loader = torch.utils.dataLoader(data, 10);
+
+   // Training loop
+   let initialLoss = Infinity;
+   let finalLoss = 0;
+
+   for (let epoch = 0; epoch < 3; epoch++) {
+     for (const batch of loader) {
+       // Forward
+       const output = model.forward(batch.inputs);
+       const loss = lossFn.forward(output, batch.targets);
+
+       // Record only the very first batch's loss as the baseline
+       if (epoch === 0 && initialLoss === Infinity) {
+         initialLoss = loss.data[0];
+       }
+
+       // Backward
+       const gradLoss = lossFn.backward();
+       model.backward(gradLoss);
+
+       // Update
+       optimizer.step(model.getParameters());
+       model.getParameters().forEach(p => p.zeroGrad());
+     }
+
+     scheduler.step();
+   }
+
+   // Check final loss
+   const finalBatch = [...loader][0];
+   const finalOutput = model.forward(finalBatch.inputs);
+   const finalLossTensor = lossFn.forward(finalOutput, finalBatch.targets);
+   finalLoss = finalLossTensor.data[0];
+
+   assert(finalLoss < initialLoss, "Training reduces loss");
+   console.log(`Initial loss: ${initialLoss.toFixed(4)}, Final loss: ${finalLoss.toFixed(4)}`);
+
+   // Model saving/loading simulation
+   const modelData = {
+     layers: model.layers.map(layer => ({
+       type: layer.constructor.name,
+       params: layer.getParameters().map(param => ({
+         data: Array.from(param.data),
+         shape: param.shape,
+         requiresGrad: param.requiresGrad
+       }))
+     }))
+   };
+
+   assert(modelData.layers.length > 0, "Model serialization");
+
+   // Test distributed initialization (sync version)
+   const distributed = torch.distributed;
+   distributed.init(1, 0); // Single process
+   assert(distributed.worldSize === 1, "Distributed initialization");
+ }
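+
+ // Sketch: restoring the serialized parameters above into a model with the
+ // same architecture. `loadModelData` is a hypothetical helper, not part of
+ // the published API; it assumes getParameters() returns tensors in a stable
+ // order and that each tensor's `data` is a Float32Array.
+ function loadModelData(model, modelData) {
+   model.layers.forEach((layer, i) => {
+     const savedParams = modelData.layers[i].params;
+     layer.getParameters().forEach((param, j) => {
+       param.data.set(savedParams[j].data); // copy the saved values back in
+     });
+   });
+ }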
+
+ // ====================== PERFORMANCE TESTS ======================
+ testSuite("Performance");
+
+ function testPerformance() {
+   // Large matrix multiplication performance
+   console.log("Testing large matrix multiplication...");
+   const size = 256;
+   const a = new Tensor(new Float32Array(size * size), [size, size]);
+   const b = new Tensor(new Float32Array(size * size), [size, size]);
+
+   const start = performance.now();
+   const c = a.matmul(b);
+   const end = performance.now();
+
+   const timeMs = end - start;
+   console.log(`Matrix multiplication (${size}x${size}): ${timeMs.toFixed(2)}ms`);
+   assert(timeMs < 1000, "Large matrix multiplication performance");
+
+   // Memory usage
+   const memoryStats = Tensor.memoryTracker.getStats();
+   console.log(`Memory usage: ${(memoryStats.totalMemory / 1024 / 1024).toFixed(2)}MB`);
+   assert(memoryStats.totalMemory < 100 * 1024 * 1024, "Memory usage within limits");
+
+   // Synchronous map test
+   console.log("Testing synchronous map...");
+   const largeTensor = new Tensor(new Float32Array(10000), [10000]);
+
+   const mapStart = performance.now();
+   const result = largeTensor.map(x => x * 2);
+   const mapEnd = performance.now();
+
+   const mapTime = mapEnd - mapStart;
+   console.log(`Synchronous map (10000 elements): ${mapTime.toFixed(2)}ms`);
+   assert(mapTime < 100, "Synchronous map performance");
+
+   // Convolution performance
+   console.log("Testing convolution performance...");
+   const conv = torch.layers.conv2d(3, 16, 3);
+   const convInput = new Tensor(new Float32Array(1 * 3 * 64 * 64), [1, 3, 64, 64]);
+
+   const convStart = performance.now();
+   const convOutput = conv.forward(convInput);
+   const convEnd = performance.now();
+
+   const convTime = convEnd - convStart;
+   console.log(`Convolution (1x3x64x64): ${convTime.toFixed(2)}ms`);
+   assert(convTime < 500, "Convolution performance");
+
+   // LSTM performance
+   console.log("Testing LSTM performance...");
+   const lstm = torch.layers.lstm(64, 128, 2);
+   const lstmInput = new Tensor(new Float32Array(10 * 4 * 64), [10, 4, 64]);
+
+   const lstmStart = performance.now();
+   const lstmOutput = lstm.forward(lstmInput);
+   const lstmEnd = performance.now();
+
+   const lstmTime = lstmEnd - lstmStart;
+   console.log(`LSTM (10x4x64): ${lstmTime.toFixed(2)}ms`);
+   assert(lstmTime < 1000, "LSTM performance");
+ }
+
+ // ====================== RUN ALL TESTS ======================
+ function runAllTests() {
+   console.log("šŸš€ Starting mini-jstorch test suite...\n");
+
+   try {
+     testTensorOperations();
+     testLayers();
+     testLossFunctions();
+     testOptimizers();
+     testNeuralNetwork();
+     testUtils();
+     testAdvancedFeatures();
+     testIntegration();
+     testPerformance();
+
+     summarize();
+
+     if (testsFailed === 0) {
+       console.log("\nšŸŽ‰ All tests passed! mini-jstorch is working correctly!");
+     } else {
+       console.log(`\nāš ļø ${testsFailed} test(s) failed. Please check the implementation.`);
+     }
+
+   } catch (error) {
+     console.error("\nšŸ’„ Test suite crashed:", error);
+     testsFailed++;
+     summarize();
+   }
+ }
+
+ // Run the tests
+ runAllTests();
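+
+ // To run the suite (assuming a Node.js version with ES-module support and
+ // "type": "module" set in package.json): node mini-jstorch-test.js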