catniff 0.1.9 → 0.2.0

package/dist/core.js ADDED
@@ -0,0 +1,780 @@
+ "use strict";
+ Object.defineProperty(exports, "__esModule", { value: true });
+ exports.Tensor = void 0;
+ class Tensor {
+     value;
+     shape;
+     strides;
+     grad;
+     requiresGrad;
+     gradFn;
+     children;
+     constructor(value, options = {}) {
+         this.value = Tensor.flatten(value);
+         this.shape = options.shape || Tensor.getShape(value);
+         this.strides = options.strides || Tensor.getStrides(this.shape);
+         this.grad = options.grad;
+         this.requiresGrad = options.requiresGrad ?? false;
+         this.gradFn = options.gradFn || (() => { });
+         this.children = options.children || [];
+     }
+     // Utility to flatten an nD array to be 1D
+     static flatten(tensor) {
+         if (typeof tensor === "number")
+             return tensor;
+         const result = [];
+         function traverse(arr) {
+             if (typeof arr === "number") {
+                 result.push(arr);
+             }
+             else if (Array.isArray(arr)) {
+                 arr.forEach(traverse);
+             }
+         }
+         traverse(tensor);
+         return result;
+     }
+     // Utility to get shape from tensor *value*
+     static getShape(tensor) {
+         const shape = [];
+         let subA = tensor;
+         while (Array.isArray(subA)) {
+             shape.push(subA.length);
+             subA = subA[0];
+         }
+         return shape;
+     }
+     // Utility to get strides from shape
+     static getStrides(shape) {
+         const strides = new Array(shape.length);
+         strides[strides.length - 1] = 1;
+         for (let i = strides.length - 2; i >= 0; i--) {
+             strides[i] = strides[i + 1] * shape[i + 1];
+         }
+         return strides;
+     }
+     // Left-pad dimensions for two shapes to be of the same length
+     static padDims(shapeA, shapeB) {
+         const newA = [...shapeA], newB = [...shapeB];
+         while (newA.length < newB.length) {
+             newA.unshift(1);
+         }
+         while (newA.length > newB.length) {
+             newB.unshift(1);
+         }
+         return [newA, newB];
+     }
+     // Broadcast shapes
+     static broadcastShapes(shapeA, shapeB) {
+         const newShape = new Array(shapeA.length);
+         for (let index = 0; index < shapeA.length; index++) {
+             if (shapeA[index] === 1) {
+                 newShape[index] = shapeB[index];
+             }
+             else if (shapeB[index] === 1) {
+                 newShape[index] = shapeA[index];
+             }
+             else if (shapeA[index] === shapeB[index]) {
+                 newShape[index] = shapeA[index];
+             }
+             else {
+                 throw new Error(`Cannot broadcast shapes: ${shapeA} and ${shapeB}`);
+             }
+         }
+         return newShape;
+     }
+     // Convert flat index to array of coordinates
+     static indexToCoords(index, shape, strides) {
+         const coords = new Array(shape.length);
+         let remaining = index;
+         // Sort dimensions by stride (largest first) for correct decomposition
+         const sortedDims = shape.map((_, i) => i).sort((a, b) => strides[b] - strides[a]);
+         for (const dim of sortedDims) {
+             coords[dim] = Math.floor(remaining / strides[dim]);
+             remaining %= strides[dim];
+         }
+         return coords;
+     }
+     // Convert array of coordinates to *unbroadcasted* flat index
+     static coordsToIndex(coords, shape, strides) {
+         let index = 0;
+         for (let i = 0; i < coords.length; i++) {
+             const coord = coords[i];
+             // Handle broadcasting
+             const actualCoord = shape[i] === 1 ? 0 : coord;
+             index += actualCoord * strides[i];
+         }
+         return index;
+     }
+     // Utility for binary (two-operand) element-wise ops
+     static elementWiseAB(tA, tB, op) {
+         if (typeof tA.value === "number" && typeof tB.value === "number") {
+             return new Tensor(op(tA.value, tB.value));
+         }
+         if (typeof tA.value === "number") {
+             // Keep operand order: tA is the left-hand operand (matters for sub, div, pow, ...)
+             return Tensor.elementWiseSelf(tB, (b) => op(tA.value, b));
+         }
+         if (typeof tB.value === "number") {
+             return Tensor.elementWiseSelf(tA, (a) => op(a, tB.value));
+         }
+         // Pad + broadcast shape
+         const [paddedAShape, paddedBShape] = Tensor.padDims(tA.shape, tB.shape);
+         const outputShape = Tensor.broadcastShapes(paddedAShape, paddedBShape);
+         // Get other output info
+         const outputStrides = Tensor.getStrides(outputShape);
+         const outputSize = outputShape.reduce((a, b) => a * b, 1);
+         const outputValue = new Array(outputSize);
+         for (let i = 0; i < outputSize; i++) {
+             // Get coordinates from 1D index
+             const coordsOutput = Tensor.indexToCoords(i, outputShape, outputStrides);
+             // Convert the coordinates to a 1D index of flattened A with respect to A's shape
+             const indexA = Tensor.coordsToIndex(coordsOutput, paddedAShape, Tensor.getStrides(paddedAShape));
+             // Convert the coordinates to a 1D index of flattened B with respect to B's shape
+             const indexB = Tensor.coordsToIndex(coordsOutput, paddedBShape, Tensor.getStrides(paddedBShape));
+             // Calculate with op
+             outputValue[i] = op(tA.value[indexA], tB.value[indexB]);
+         }
+         return new Tensor(outputValue, {
+             shape: outputShape,
+             strides: outputStrides
+         });
+     }
+     // Utility for unary (single-operand) element-wise ops
+     static elementWiseSelf(tA, op) {
+         if (typeof tA.value === "number")
+             return new Tensor(op(tA.value));
+         return new Tensor(tA.value.map(el => op(el)), { shape: tA.shape, strides: tA.strides });
+     }
+     // Utility to do an element-wise operation and build a DAG node with another tensor
+     elementWiseABDAG(other, op, thisGrad, otherGrad) {
+         other = Tensor.forceTensor(other);
+         const out = Tensor.elementWiseAB(this, other, op);
+         if (this.requiresGrad) {
+             out.requiresGrad = true;
+             out.children.push(this);
+         }
+         if (other.requiresGrad) {
+             out.requiresGrad = true;
+             out.children.push(other);
+         }
+         if (out.requiresGrad) {
+             out.gradFn = () => {
+                 const outGrad = out.grad;
+                 if (this.requiresGrad)
+                     thisGrad(this, other, outGrad);
+                 if (other.requiresGrad)
+                     otherGrad(this, other, outGrad);
+             };
+         }
+         return out;
+     }
+     // Utility to do a unary element-wise operation and build a DAG node
+     elementWiseSelfDAG(op, thisGrad) {
+         const out = Tensor.elementWiseSelf(this, op);
+         if (this.requiresGrad) {
+             out.requiresGrad = true;
+             out.children.push(this);
+         }
+         if (out.requiresGrad) {
+             out.gradFn = () => {
+                 const outGrad = out.grad;
+                 if (this.requiresGrad)
+                     thisGrad(this, outGrad);
+             };
+         }
+         return out;
+     }
+     // Utility to force an input value to be a tensor
+     static forceTensor(value) {
+         if (value instanceof Tensor)
+             return value;
+         return new Tensor(value);
+     }
+     // Utility to accumulate into the gradient of a tensor, reducing broadcasted axes
+     static addGrad(tensor, accumGrad) {
+         const axesToSqueeze = [];
+         const axesToReduce = [];
+         const shape = tensor.shape;
+         const gradShape = accumGrad.shape;
+         const paddedDims = gradShape.length - shape.length;
+         for (let i = 0; i < paddedDims; i++) {
+             axesToReduce.push(i);
+             axesToSqueeze.push(i);
+         }
+         for (let i = 0; i < shape.length; i++) {
+             if (shape[i] === 1 && gradShape[i + paddedDims] > 1) {
+                 axesToReduce.push(i + paddedDims);
+             }
+         }
+         const reducedGrad = accumGrad.sum(axesToReduce, true);
+         const squeezedGrad = reducedGrad.squeeze(axesToSqueeze);
+         if (typeof tensor.grad === "undefined") {
+             tensor.grad = squeezedGrad;
+         }
+         else {
+             tensor.grad = tensor.grad.add(squeezedGrad);
+         }
+     }
+     // Tensor squeeze
+     squeeze(dims) {
+         if (typeof this.value === "number")
+             return new Tensor(this.value);
+         if (typeof dims === "number") {
+             dims = [dims];
+         }
+         if (typeof dims === "undefined") {
+             const shape = this.shape;
+             dims = [];
+             for (let index = 0; index < shape.length; index++) {
+                 if (shape[index] === 1) {
+                     dims.push(index);
+                 }
+             }
+         }
+         const outShape = this.shape.filter((dim, i) => {
+             const shouldSqueeze = dims.includes(i);
+             if (shouldSqueeze && dim !== 1)
+                 throw new Error(`Cannot squeeze dim with size ${dim}`);
+             return !shouldSqueeze;
+         });
+         const outStrides = Tensor.getStrides(outShape);
+         const outValue = outShape.length === 0 ? this.value[0] : this.value;
+         const out = new Tensor(outValue, {
+             shape: outShape,
+             strides: outStrides
+         });
+         // Set up gradient if needed
+         if (this.requiresGrad) {
+             out.requiresGrad = true;
+             out.children.push(this);
+             out.gradFn = () => {
+                 let restoredGrad = out.grad;
+                 for (let i = dims.length - 1; i >= 0; i--) {
+                     restoredGrad = restoredGrad.unsqueeze(dims[i]);
+                 }
+                 Tensor.addGrad(this, restoredGrad);
+             };
+         }
+         return out;
+     }
+     // Tensor unsqueeze - adds a dimension of size 1 at the specified position
+     unsqueeze(dim) {
+         if (typeof this.value === "number")
+             return new Tensor([this.value]);
+         if (dim < 0 || dim > this.shape.length) {
+             throw new Error(`Invalid dimension ${dim} for unsqueeze`);
+         }
+         // Insert size-1 dimension at specified position
+         const newShape = [...this.shape];
+         newShape.splice(dim, 0, 1);
+         // Insert appropriate stride for new dimension
+         const newStrides = [...this.strides];
+         // For a dimension of size 1, the stride can be any value since we never step through it
+         // Use the stride of the next dimension, or 1 if it's the last dimension
+         const newStride = dim < this.strides.length ? this.strides[dim] : 1;
+         newStrides.splice(dim, 0, newStride);
+         const out = new Tensor(this.value, { shape: newShape, strides: newStrides });
+         // Set up gradient if needed
+         if (this.requiresGrad) {
+             out.requiresGrad = true;
+             out.children.push(this);
+             out.gradFn = () => {
+                 Tensor.addGrad(this, out.grad.squeeze(dim));
+             };
+         }
+         return out;
+     }
+     // Tensor sum reduction
+     sum(dims, keepDims = false) {
+         if (typeof this.value === "number")
+             return new Tensor(this.value);
+         if (typeof dims === "number") {
+             dims = [dims];
+         }
+         if (typeof dims === "undefined") {
+             dims = Array.from({ length: this.shape.length }, (_, index) => index);
+         }
+         const outputShape = this.shape.map((dim, i) => dims.includes(i) ? 1 : dim);
+         const outputStrides = Tensor.getStrides(outputShape);
+         const outputSize = outputShape.reduce((a, b) => a * b, 1);
+         const outputValue = new Array(outputSize).fill(0);
+         const originalSize = this.shape.reduce((a, b) => a * b, 1);
+         for (let index = 0; index < originalSize; index++) {
+             const coords = Tensor.indexToCoords(index, this.shape, this.strides);
+             // Force 0 on reduced axes to collapse into size-1 dims
+             const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
+             // Convert output coordinates to flat index
+             const outFlatIndex = outCoords.reduce((acc, val, i) => acc + val * outputStrides[i], 0);
+             // Accumulate
+             const realFlatIndex = coords.reduce((acc, val, i) => acc + val * this.strides[i], 0);
+             outputValue[outFlatIndex] += this.value[realFlatIndex];
+         }
+         const out = new Tensor(outputValue, {
+             shape: outputShape,
+             strides: outputStrides
+         });
+         return keepDims ? out : out.squeeze(dims);
+     }
+     // Tensor element-wise addition
+     add(other) {
+         return this.elementWiseABDAG(other, (a, b) => a + b, (self, other, outGrad) => {
+             Tensor.addGrad(self, outGrad);
+         }, (self, other, outGrad) => {
+             Tensor.addGrad(other, outGrad);
+         });
+     }
+     // Tensor element-wise subtraction
+     sub(other) {
+         return this.elementWiseABDAG(other, (a, b) => a - b, (self, other, outGrad) => {
+             Tensor.addGrad(self, outGrad);
+         }, (self, other, outGrad) => {
+             Tensor.addGrad(other, outGrad.neg());
+         });
+     }
+     // Tensor element-wise multiplication
+     mul(other) {
+         return this.elementWiseABDAG(other, (a, b) => a * b, (self, other, outGrad) => {
+             Tensor.addGrad(self, outGrad.mul(other));
+         }, (self, other, outGrad) => {
+             Tensor.addGrad(other, outGrad.mul(self));
+         });
+     }
+     // Tensor element-wise power
+     pow(other) {
+         return this.elementWiseABDAG(other, (a, b) => a ** b, (self, other, outGrad) => {
+             Tensor.addGrad(self, outGrad.mul(other.mul(self.pow(other.sub(1)))));
+         }, (self, other, outGrad) => {
+             Tensor.addGrad(other, outGrad.mul(self.pow(other).mul(self.log())));
+         });
+     }
+     // Tensor element-wise division
+     div(other) {
+         return this.elementWiseABDAG(other, (a, b) => a / b, (self, other, outGrad) => {
+             Tensor.addGrad(self, outGrad.div(other));
+         }, (self, other, outGrad) => {
+             Tensor.addGrad(other, outGrad.mul(self.neg().div(other.pow(2))));
+         });
+     }
+     // Tensor element-wise greater-or-equal comparison
+     ge(other) {
+         return this.elementWiseABDAG(other, (a, b) => a >= b ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+     }
+     // Tensor element-wise less-or-equal comparison
+     le(other) {
+         return this.elementWiseABDAG(other, (a, b) => a <= b ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+     }
+     // Tensor element-wise greater-than comparison
+     gt(other) {
+         return this.elementWiseABDAG(other, (a, b) => a > b ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+     }
+     // Tensor element-wise less-than comparison
+     lt(other) {
+         return this.elementWiseABDAG(other, (a, b) => a < b ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+     }
+     // Tensor element-wise equality comparison
+     eq(other) {
+         return this.elementWiseABDAG(other, (a, b) => a === b ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+     }
+     // Tensor element-wise logical AND
+     logicalAnd(other) {
+         return this.elementWiseABDAG(other, (a, b) => a === 1 && b === 1 ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+     }
+     // Tensor element-wise logical OR
+     logicalOr(other) {
+         return this.elementWiseABDAG(other, (a, b) => a === 1 || b === 1 ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+     }
+     // Tensor element-wise logical XOR
+     logicalXor(other) {
+         return this.elementWiseABDAG(other, (a, b) => (a === 1 || b === 1) && a !== b ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+     }
+     // Tensor element-wise logical NOT
+     logicalNot() {
+         return this.elementWiseSelfDAG((a) => a === 1 ? 0 : 1, (self, outGrad) => { });
+     }
+     // Tensor element-wise bitwise AND
+     bitwiseAnd(other) {
+         return this.elementWiseABDAG(other, (a, b) => a & b, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+     }
+     // Tensor element-wise bitwise OR
+     bitwiseOr(other) {
+         return this.elementWiseABDAG(other, (a, b) => a | b, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+     }
+     // Tensor element-wise bitwise XOR
+     bitwiseXor(other) {
+         return this.elementWiseABDAG(other, (a, b) => a ^ b, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+     }
+     // Tensor element-wise bitwise NOT
+     bitwiseNot() {
+         return this.elementWiseSelfDAG((a) => ~a, (self, outGrad) => { });
+     }
+     // Tensor element-wise left shift
+     bitwiseLeftShift(other) {
+         return this.elementWiseABDAG(other, (a, b) => a << b, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+     }
+     // Tensor element-wise right shift
+     bitwiseRightShift(other) {
+         return this.elementWiseABDAG(other, (a, b) => a >> b, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+     }
+     // Tensor element-wise negation
+     neg() {
+         return this.elementWiseSelfDAG((a) => -a, (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.mul(-1));
+         });
+     }
+     // Tensor element-wise absolute value
+     abs() {
+         return this.elementWiseSelfDAG((a) => Math.abs(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.mul(self.sign()));
+         });
+     }
+     // Tensor element-wise sign function
+     sign() {
+         return this.elementWiseSelfDAG((a) => Math.sign(a), (self, outGrad) => { });
+     }
+     // Tensor element-wise sin
+     sin() {
+         return this.elementWiseSelfDAG((a) => Math.sin(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.mul(self.cos()));
+         });
+     }
+     // Tensor element-wise cos
+     cos() {
+         return this.elementWiseSelfDAG((a) => Math.cos(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.mul(self.sin().neg()));
+         });
+     }
+     // Tensor element-wise tan
+     tan() {
+         return this.elementWiseSelfDAG((a) => Math.tan(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.mul(self.tan().pow(2).add(1)));
+         });
+     }
+     // Tensor element-wise asin
+     asin() {
+         return this.elementWiseSelfDAG((a) => Math.asin(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.div(self.pow(2).neg().add(1).sqrt()));
+         });
+     }
+     // Tensor element-wise acos
+     acos() {
+         return this.elementWiseSelfDAG((a) => Math.acos(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.div(self.pow(2).neg().add(1).sqrt()).neg());
+         });
+     }
+     // Tensor element-wise atan
+     atan() {
+         return this.elementWiseSelfDAG((a) => Math.atan(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.div(self.pow(2).add(1)));
+         });
+     }
+     // Tensor element-wise sinh
+     sinh() {
+         return this.elementWiseSelfDAG((a) => Math.sinh(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.mul(self.cosh()));
+         });
+     }
+     // Tensor element-wise cosh
+     cosh() {
+         return this.elementWiseSelfDAG((a) => Math.cosh(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.mul(self.sinh()));
+         });
+     }
+     // Tensor element-wise asinh
+     asinh() {
+         return this.elementWiseSelfDAG((a) => Math.asinh(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.div(self.pow(2).add(1).sqrt()));
+         });
+     }
+     // Tensor element-wise acosh
+     acosh() {
+         return this.elementWiseSelfDAG((a) => Math.acosh(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.div(self.add(1).sqrt().mul(self.sub(1).sqrt())));
+         });
+     }
+     // Tensor element-wise atanh
+     atanh() {
+         return this.elementWiseSelfDAG((a) => Math.atanh(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.div(self.pow(2).neg().add(1)));
+         });
+     }
+     // Tensor element-wise square root
+     sqrt() {
+         return this.elementWiseSelfDAG((a) => Math.sqrt(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.div(self.sqrt().mul(2)));
+         });
+     }
+     // Tensor element-wise e^x
+     exp() {
+         return this.elementWiseSelfDAG((a) => Math.exp(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.mul(self.exp()));
+         });
+     }
+     // Tensor element-wise natural log
+     log() {
+         return this.elementWiseSelfDAG((a) => Math.log(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.div(self));
+         });
+     }
+     // Tensor element-wise log2
+     log2() {
+         return this.elementWiseSelfDAG((a) => Math.log2(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.div(self.mul(Math.log(2))));
+         });
+     }
+     // Tensor element-wise log10
+     log10() {
+         return this.elementWiseSelfDAG((a) => Math.log10(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.div(self.mul(Math.log(10))));
+         });
+     }
+     // Tensor element-wise log(1+x)
+     log1p() {
+         return this.elementWiseSelfDAG((a) => Math.log1p(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.div(self.add(1)));
+         });
+     }
+     // Tensor element-wise ReLU
+     relu() {
+         return this.elementWiseSelfDAG((a) => Math.max(a, 0), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.mul(self.ge(0)));
+         });
+     }
+     // Tensor element-wise sigmoid
+     sigmoid() {
+         return this.elementWiseSelfDAG((a) => 1 / (1 + Math.exp(-a)), (self, outGrad) => {
+             const sig = self.sigmoid();
+             Tensor.addGrad(self, outGrad.mul(sig).mul(sig.neg().add(1)));
+         });
+     }
+     // Tensor element-wise tanh
+     tanh() {
+         return this.elementWiseSelfDAG((a) => Math.tanh(a), (self, outGrad) => {
+             Tensor.addGrad(self, outGrad.mul(self.tanh().pow(2).neg().add(1)));
+         });
+     }
+     // Transpose
+     transpose(dim1, dim2) {
+         // If dimension out of bounds, throw error
+         if (dim1 >= this.shape.length || dim2 >= this.shape.length || dim1 < 0 || dim2 < 0) {
+             throw new Error("Dimensions do not exist to transpose");
+         }
+         // Create new shape and strides by swapping (a no-op when dim1 === dim2,
+         // which still returns a gradient-tracking copy)
+         const newShape = [...this.shape];
+         const newStrides = [...this.strides];
+         [newShape[dim1], newShape[dim2]] = [newShape[dim2], newShape[dim1]];
+         [newStrides[dim1], newStrides[dim2]] = [newStrides[dim2], newStrides[dim1]];
+         // Create new tensor with same data but swapped shape/strides
+         const out = new Tensor(this.value, { shape: newShape, strides: newStrides });
+         out.requiresGrad = this.requiresGrad;
+         // Handle gradient if needed
+         if (this.requiresGrad) {
+             out.children.push(this);
+             out.gradFn = () => {
+                 Tensor.addGrad(this, out.grad.transpose(dim1, dim2));
+             };
+         }
+         return out;
+     }
+     // Transpose 2D
+     t() {
+         // Verify matrix shape
+         if (this.shape.length !== 2) {
+             throw new Error("Input is not a matrix");
+         }
+         return this.transpose(0, 1);
+     }
+     // 1D tensor dot product
+     dot(other) {
+         other = Tensor.forceTensor(other);
+         // Verify 1D shape
+         if (this.shape.length !== 1 || other.shape.length !== 1) {
+             throw new Error("Inputs are not 1D tensors");
+         }
+         const vectLen = this.shape[0];
+         const vectA = this.value;
+         const vectB = other.value;
+         let sum = 0;
+         for (let index = 0; index < vectLen; index++) {
+             sum += vectA[index] * vectB[index];
+         }
+         const out = new Tensor(sum);
+         if (this.requiresGrad) {
+             out.requiresGrad = true;
+             out.children.push(this);
+         }
+         if (other.requiresGrad) {
+             out.requiresGrad = true;
+             out.children.push(other);
+         }
+         if (out.requiresGrad) {
+             out.gradFn = () => {
+                 const outGrad = out.grad;
+                 if (this.requiresGrad)
+                     Tensor.addGrad(this, outGrad.mul(other));
+                 if (other.requiresGrad)
+                     Tensor.addGrad(other, outGrad.mul(this));
+             };
+         }
+         return out;
+     }
+     // Matrix multiplication
+     mm(other) {
+         other = Tensor.forceTensor(other);
+         // Verify 2D shape
+         if (this.shape.length !== 2 || other.shape.length !== 2) {
+             throw new Error("Inputs are not matrices");
+         }
+         const matA = this.value;
+         const matB = other.value;
+         const matAStrides = this.strides;
+         const matBStrides = other.strides;
+         const matARows = this.shape[0];
+         const matACols = this.shape[1];
+         const matBRows = other.shape[0];
+         const matBCols = other.shape[1];
+         if (matACols !== matBRows)
+             throw new Error("Invalid matrix shapes for multiplication");
+         const matCShape = [matARows, matBCols];
+         const matCStrides = Tensor.getStrides(matCShape);
+         const matCSize = matCShape.reduce((a, b) => a * b, 1);
+         const matC = new Array(matCSize).fill(0);
+         for (let i = 0; i < matARows; i++) {
+             for (let j = 0; j < matBCols; j++) {
+                 for (let k = 0; k < matACols; k++) {
+                     matC[i * matCStrides[0] + j * matCStrides[1]] +=
+                         matA[i * matAStrides[0] + k * matAStrides[1]] *
+                         matB[k * matBStrides[0] + j * matBStrides[1]];
+                 }
+             }
+         }
+         const out = new Tensor(matC, { shape: matCShape, strides: matCStrides });
+         if (this.requiresGrad) {
+             out.requiresGrad = true;
+             out.children.push(this);
+         }
+         if (other.requiresGrad) {
+             out.requiresGrad = true;
+             out.children.push(other);
+         }
+         if (out.requiresGrad) {
+             out.gradFn = () => {
+                 const outGrad = out.grad;
+                 if (this.requiresGrad)
+                     Tensor.addGrad(this, outGrad.mm(other.t()));
+                 if (other.requiresGrad)
+                     Tensor.addGrad(other, this.t().mm(outGrad));
+             };
+         }
+         return out;
+     }
+     // Matrix-vector multiplication
+     mv(other) {
+         other = Tensor.forceTensor(other);
+         // Verify 2D + 1D shape pair
+         if (this.shape.length !== 2 || other.shape.length !== 1) {
+             throw new Error("Input is not a 2D and 1D tensor pair");
+         }
+         // MM with no grad
+         const thisMat = new Tensor(this.value, { shape: this.shape, strides: this.strides });
+         const otherMat = new Tensor(other.value, { shape: [other.shape[0], 1], strides: [1, 1] });
+         const out = thisMat.mm(otherMat).squeeze(1);
+         // Handle grad with original tensors
+         if (this.requiresGrad) {
+             out.requiresGrad = true;
+             out.children.push(this);
+         }
+         if (other.requiresGrad) {
+             out.requiresGrad = true;
+             out.children.push(other);
+         }
+         if (out.requiresGrad) {
+             out.gradFn = () => {
+                 const outGrad = out.grad;
+                 if (this.requiresGrad)
+                     Tensor.addGrad(this, outGrad.unsqueeze(1).mm(other.unsqueeze(0)));
+                 if (other.requiresGrad)
+                     Tensor.addGrad(other, this.t().mv(outGrad));
+             };
+         }
+         return out;
+     }
+     // Dispatch matmul based on operand dimensionality
+     matmul(other) {
+         other = Tensor.forceTensor(other);
+         if (this.shape.length === 1 && other.shape.length === 1) {
+             return this.dot(other);
+         }
+         else if (this.shape.length === 1 && other.shape.length === 2) {
+             return this.unsqueeze(0).mm(other).squeeze(0);
+         }
+         else if (this.shape.length === 2 && other.shape.length === 1) {
+             return this.mv(other);
+         }
+         else if (this.shape.length === 2 && other.shape.length === 2) {
+             return this.mm(other);
+         }
+         // Too lazy for batched matmul
+         throw new Error(`Shapes [${this.shape}] and [${other.shape}] are not supported`);
+     }
+     // Utility to create a new tensor with the shape of another tensor, filled with a number
+     static fullLike(tensor, num, options = {}) {
+         if (typeof tensor.value === "number")
+             return new Tensor(num, options);
+         return new Tensor(tensor.value.map(() => num), { shape: tensor.shape, strides: tensor.strides, ...options });
+     }
+     // Reverse-mode autodiff call
+     backward() {
+         // Build topological order
+         const topo = [];
+         const visited = new Set();
+         function build(node) {
+             if (!visited.has(node) && node.requiresGrad) {
+                 visited.add(node);
+                 node.grad = Tensor.fullLike(node, 0);
+                 for (let child of node.children)
+                     build(child);
+                 topo.push(node);
+             }
+         }
+         build(this);
+         // Feed backward to calculate gradients
+         this.grad = Tensor.fullLike(this, 1);
+         for (let index = topo.length - 1; index > -1; index--) {
+             topo[index].gradFn();
+         }
+     }
+     // Returns the number/nD array form of the tensor
+     val() {
+         if (typeof this.value === "number")
+             return this.value;
+         function buildNested(data, shape, strides, baseIndex = 0, dim = 0) {
+             if (dim === shape.length - 1) {
+                 // Last dimension: extract elements using actual stride
+                 const result = [];
+                 for (let i = 0; i < shape[dim]; i++) {
+                     result.push(data[baseIndex + i * strides[dim]]);
+                 }
+                 return result;
+             }
+             // Recursive case: build nested structure
+             const result = [];
+             for (let i = 0; i < shape[dim]; i++) {
+                 result.push(buildNested(data, shape, strides, baseIndex + i * strides[dim], dim + 1));
+             }
+             return result;
+         }
+         return buildNested(this.value, this.shape, this.strides);
+     }
+     // Returns a copy of the tensor with gradient turned on/off
+     withGrad(requiresGrad) {
+         return new Tensor(this.value, {
+             shape: this.shape,
+             strides: this.strides,
+             requiresGrad
+         });
+     }
+ }
+ exports.Tensor = Tensor;
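
Editor's note: the snippet below is a minimal usage sketch of the new autograd core added in this release, not part of the package itself. The require path is illustrative (the published entry point may differ), and the printed values follow from the implementation above.

const { Tensor } = require("catniff"); // illustrative import; adjust to the actual entry point

// f(x, W) = tanh(x . W), with gradient tracking on both inputs
const x = new Tensor([1, 2], { requiresGrad: true });
const W = new Tensor([[0.5], [-0.25]], { requiresGrad: true });

const y = x.matmul(W).tanh(); // 1D x 2D dispatches through unsqueeze/mm/squeeze
y.backward();                 // reverse-mode autodiff from y

console.log(y.val());      // [ 0 ]              since tanh(1*0.5 + 2*(-0.25)) = tanh(0)
console.log(x.grad.val()); // [ 0.5, -0.25 ]     dy/dx = (1 - tanh^2) * W^T
console.log(W.grad.val()); // [ [ 1 ], [ 2 ] ]   dy/dW = x^T * (1 - tanh^2)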