catniff 0.5.11 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.js CHANGED
@@ -6,6 +6,8 @@ class Tensor {
6
6
  value;
7
7
  shape;
8
8
  strides;
9
+ offset;
10
+ numel;
9
11
  grad;
10
12
  requiresGrad;
11
13
  gradFn;
@@ -13,14 +15,19 @@ class Tensor {
13
15
  device;
14
16
  static training = false;
15
17
  constructor(value, options = {}) {
18
+ // Storage
16
19
  this.value = Tensor.flatten(value);
20
+ // Tensor metadata
17
21
  this.shape = options.shape || Tensor.getShape(value);
18
22
  this.strides = options.strides || Tensor.getStrides(this.shape);
23
+ this.offset = options.offset || 0;
24
+ this.numel = options.numel || Tensor.shapeToSize(this.shape);
25
+ this.device = options.device || "cpu";
26
+ // Autograd data
19
27
  this.grad = options.grad;
20
28
  this.requiresGrad = options.requiresGrad ?? false;
21
29
  this.gradFn = options.gradFn || (() => { });
22
30
  this.children = options.children || [];
23
- this.device = options.device || "cpu";
24
31
  // Move to device in-place
25
32
  this.to_(this.device);
26
33
  }
@@ -96,7 +103,7 @@ class Tensor {
96
103
  newShape[index] = shapeA[index];
97
104
  }
98
105
  else {
99
- throw new Error(`Cannot broadcast shapes: ${shapeA} and ${shapeB}`);
106
+ throw new Error(`Can not broadcast shapes: ${shapeA} and ${shapeB}`);
100
107
  }
101
108
  }
102
109
  return newShape;
@@ -164,22 +171,28 @@ class Tensor {
164
171
  // Convert the coordinates to 1D index of flattened B with respect to B's shape
165
172
  const indexB = Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedBShape, paddedBStrides);
166
173
  // Calculate with op
167
- outputValue[i] = op(tA.value[indexA], tB.value[indexB]);
174
+ outputValue[i] = op(tA.value[indexA + tA.offset], tB.value[indexB + tB.offset]);
168
175
  }
169
176
  return new Tensor(outputValue, {
170
177
  shape: outputShape,
171
- strides: outputStrides
178
+ strides: outputStrides,
179
+ numel: outputSize
172
180
  });
173
181
  }
174
182
  // Utility for self-inflicting element-wise ops
175
183
  static elementWiseSelf(tA, op) {
176
184
  if (typeof tA.value === "number")
177
185
  return new Tensor(op(tA.value));
178
- const newValue = new Array(tA.value.length);
179
- for (let index = 0; index < tA.value.length; index++) {
180
- newValue[index] = op(tA.value[index]);
186
+ const outputShape = tA.shape;
187
+ const outputStrides = Tensor.getStrides(outputShape);
188
+ const outputSize = tA.numel;
189
+ const outputValue = new Array(outputSize);
190
+ for (let index = 0; index < outputSize; index++) {
191
+ const outputCoords = Tensor.indexToCoords(index, outputStrides);
192
+ const originalIndex = tA.offset + Tensor.coordsToIndex(outputCoords, tA.strides);
193
+ outputValue[index] = op(tA.value[originalIndex]);
181
194
  }
182
- return new Tensor(newValue, { shape: tA.shape, strides: tA.strides });
195
+ return new Tensor(outputValue, { shape: outputShape, strides: outputStrides, numel: tA.numel });
183
196
  }
184
197
  // Utility to do element-wise operation and build a dag node with another tensor
185
198
  elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) {
@@ -260,6 +273,19 @@ class Tensor {
260
273
  tensor.grad = tensor.grad.add(squeezedGrad);
261
274
  }
262
275
  }
276
+ static normalizeDims(dims, numDims) {
277
+ for (let index = 0; index < dims.length; index++) {
278
+ // Handle negative indices
279
+ if (dims[index] < 0) {
280
+ dims[index] += numDims;
281
+ }
282
+ // If dimension out of bound, throw error
283
+ if (dims[index] >= numDims || dims[index] < 0) {
284
+ throw new Error("Dimensions do not exist");
285
+ }
286
+ }
287
+ return dims;
288
+ }
263
289
  // Contiguity-related ops
264
290
  isContiguous() {
265
291
  const expectedStrides = Tensor.getStrides(this.shape);
@@ -281,14 +307,14 @@ class Tensor {
281
307
  if (this.isContiguous())
282
308
  return this;
283
309
  const outputStrides = Tensor.getStrides(this.shape);
284
- const outputSize = Tensor.shapeToSize(this.shape);
310
+ const outputSize = this.numel;
285
311
  const outputValue = new Array(outputSize);
286
312
  for (let index = 0; index < outputSize; index++) {
287
313
  const outputCoords = Tensor.indexToCoords(index, outputStrides);
288
314
  const originalIndex = Tensor.coordsToIndex(outputCoords, this.strides);
289
- outputValue[index] = this.value[originalIndex];
315
+ outputValue[index] = this.value[this.offset + originalIndex];
290
316
  }
291
- const out = new Tensor(outputValue, { shape: this.shape, strides: outputStrides });
317
+ const out = new Tensor(outputValue, { shape: this.shape, strides: outputStrides, numel: outputSize });
292
318
  // Gradient flow back to the original tensor
293
319
  if (this.requiresGrad) {
294
320
  out.requiresGrad = true;
@@ -299,15 +325,39 @@ class Tensor {
299
325
  }
300
326
  return out;
301
327
  }
328
+ view(newShape) {
329
+ // Verify shape size
330
+ const originalSize = this.numel;
331
+ const outputSize = Tensor.shapeToSize(newShape);
332
+ if (originalSize !== outputSize) {
333
+ throw new Error("Can not create view: incompatible sizes");
334
+ }
335
+ // Verify compatibility (only contiguity for now)
336
+ if (!this.isContiguous()) {
337
+ throw new Error("Can not create view: incompatible metadata");
338
+ }
339
+ const outputStrides = Tensor.getStrides(newShape);
340
+ const out = new Tensor(this.value, { shape: newShape, strides: outputStrides, numel: outputSize });
341
+ // Gradient reshaped and flow back to the original tensor
342
+ if (this.requiresGrad) {
343
+ out.requiresGrad = true;
344
+ out.children.push(this);
345
+ out.gradFn = () => {
346
+ Tensor.addGrad(this, out.grad.reshape(this.shape));
347
+ };
348
+ }
349
+ return out;
350
+ }
302
351
  reshape(newShape) {
303
352
  // Verify shape size
304
- const originalSize = Tensor.shapeToSize(this.shape);
353
+ const originalSize = this.numel;
305
354
  const outputSize = Tensor.shapeToSize(newShape);
306
355
  if (originalSize !== outputSize) {
307
- throw new Error("Cannot reshape: incompatible sizes");
356
+ throw new Error("Can not reshape: incompatible sizes");
308
357
  }
358
+ // Create new tensor with forced compatibility (only contiguity for now)
309
359
  const outputStrides = Tensor.getStrides(newShape);
310
- const out = new Tensor(this.contiguous().value, { shape: newShape, strides: outputStrides });
360
+ const out = new Tensor(this.contiguous().value, { shape: newShape, strides: outputStrides, numel: outputSize });
311
361
  // Gradient reshaped and flow back to the original tensor
312
362
  if (this.requiresGrad) {
313
363
  out.requiresGrad = true;
@@ -318,6 +368,234 @@ class Tensor {
318
368
  }
319
369
  return out;
320
370
  }
371
+ // Transpose
372
+ transpose(dim1, dim2) {
373
+ // Handle negative indices
374
+ if (dim1 < 0) {
375
+ dim1 += this.shape.length;
376
+ }
377
+ if (dim2 < 0) {
378
+ dim2 += this.shape.length;
379
+ }
380
+ // If dimension out of bound, throw error
381
+ if (dim1 >= this.shape.length || dim2 >= this.shape.length || dim1 < 0 || dim2 < 0) {
382
+ throw new Error("Dimensions do not exist to transpose");
383
+ }
384
+ // If same dimension, return view
385
+ if (dim1 === dim2)
386
+ return this;
387
+ // Create new shape and strides by swapping
388
+ const newShape = [...this.shape];
389
+ const newStrides = [...this.strides];
390
+ [newShape[dim1], newShape[dim2]] = [newShape[dim2], newShape[dim1]];
391
+ [newStrides[dim1], newStrides[dim2]] = [newStrides[dim2], newStrides[dim1]];
392
+ // Create new tensor with same data but swapped shape/strides
393
+ const out = new Tensor(this.value, {
394
+ shape: newShape,
395
+ strides: newStrides,
396
+ offset: this.offset,
397
+ numel: this.numel,
398
+ device: this.device
399
+ });
400
+ out.requiresGrad = this.requiresGrad;
401
+ // Handle gradient if needed
402
+ if (this.requiresGrad) {
403
+ out.children.push(this);
404
+ out.gradFn = () => {
405
+ Tensor.addGrad(this, out.grad.transpose(dim1, dim2));
406
+ };
407
+ }
408
+ return out;
409
+ }
410
+ swapaxes = this.transpose;
411
+ swapdims = this.transpose;
412
+ // Transpose 2D
413
+ t() {
414
+ // Verify matrix shape
415
+ if (this.shape.length !== 2) {
416
+ throw new Error("Input is not a matrix");
417
+ }
418
+ return this.transpose(0, 1);
419
+ }
420
+ // Permute
421
+ permute(dims) {
422
+ dims = Tensor.normalizeDims(dims, this.shape.length);
423
+ if (dims.length !== this.shape.length) {
424
+ throw new Error("Permutation must specify all dimensions");
425
+ }
426
+ // Compute new shape and strides
427
+ const newShape = new Array(dims.length);
428
+ const newStrides = new Array(dims.length);
429
+ for (let index = 0; index < dims.length; index++) {
430
+ const dim = dims[index];
431
+ newShape[index] = this.shape[dim];
432
+ newStrides[index] = this.strides[dim];
433
+ }
434
+ const out = new Tensor(this.value, {
435
+ shape: newShape,
436
+ strides: newStrides,
437
+ offset: this.offset,
438
+ numel: this.numel,
439
+ device: this.device
440
+ });
441
+ if (this.requiresGrad) {
442
+ out.requiresGrad = true;
443
+ out.children.push(this);
444
+ out.gradFn = () => {
445
+ // Compute inverse permutation
446
+ const inverseAxes = new Array(dims.length);
447
+ for (let i = 0; i < dims.length; i++) {
448
+ inverseAxes[dims[i]] = i;
449
+ }
450
+ // Permute gradient back to original order
451
+ const permutedGrad = out.grad.permute(inverseAxes);
452
+ Tensor.addGrad(this, permutedGrad);
453
+ };
454
+ }
455
+ return out;
456
+ }
457
+ // Utility for indexing with array of indices
458
+ indexWithArray(indices) {
459
+ if (typeof this.value === "number")
460
+ return this;
461
+ indices = Tensor.normalizeDims(indices, this.shape[0]);
462
+ // Init necessary stuff for indexing
463
+ const reducedShape = this.shape.slice(1);
464
+ const reducedStrides = this.strides.slice(1);
465
+ const elementsPerIndex = Tensor.shapeToSize(reducedShape);
466
+ // Init output data
467
+ const outputShape = [indices.length, ...reducedShape];
468
+ const outputSize = Tensor.shapeToSize(outputShape);
469
+ const outputValue = new Array(outputSize);
470
+ for (let i = 0; i < indices.length; i++) {
471
+ const sourceRowIndex = indices[i];
472
+ const targetStart = i * elementsPerIndex;
473
+ for (let j = 0; j < elementsPerIndex; j++) {
474
+ const fullCoords = Tensor.indexToCoords(j, reducedStrides);
475
+ fullCoords.unshift(sourceRowIndex);
476
+ const sourceIndex = Tensor.coordsToIndex(fullCoords, this.strides);
477
+ outputValue[targetStart + j] = this.value[this.offset + sourceIndex];
478
+ }
479
+ }
480
+ const out = new Tensor(outputValue, {
481
+ shape: outputShape,
482
+ numel: outputSize
483
+ });
484
+ // Handle gradient
485
+ if (this.requiresGrad) {
486
+ out.requiresGrad = true;
487
+ out.children.push(this);
488
+ out.gradFn = () => {
489
+ const outGrad = out.grad;
490
+ // Create zero gradient tensor with original shape
491
+ const grad = Tensor.zerosLike(this);
492
+ // Scatter gradients back to original positions
493
+ for (let i = 0; i < indices.length; i++) {
494
+ const originalRowIndex = indices[i];
495
+ const sourceStart = i * elementsPerIndex;
496
+ for (let j = 0; j < elementsPerIndex; j++) {
497
+ const fullCoords = Tensor.indexToCoords(j, reducedStrides);
498
+ fullCoords.unshift(originalRowIndex);
499
+ const targetIndex = Tensor.coordsToIndex(fullCoords, this.strides);
500
+ grad.value[targetIndex] += outGrad.value[sourceStart + j];
501
+ }
502
+ }
503
+ Tensor.addGrad(this, grad);
504
+ };
505
+ }
506
+ return out;
507
+ }
508
+ // Tensor indexing
509
+ index(indices) {
510
+ if (typeof indices === "number") {
511
+ return this.indexWithArray([indices]).squeeze(0);
512
+ }
513
+ else {
514
+ const tensorIndices = this.handleOther(indices).contiguous();
515
+ const originalShape = tensorIndices.shape;
516
+ const flatIndices = tensorIndices.value;
517
+ const result = this.indexWithArray(flatIndices);
518
+ // Reshape to preserve input shape
519
+ const outputShape = [...originalShape, ...this.shape.slice(1)];
520
+ return result.reshape(outputShape);
521
+ }
522
+ }
523
+ // Tensor slicing
524
+ slice(ranges) {
525
+ // Handle scalars
526
+ if (typeof this.value === "number")
527
+ return this;
528
+ const newShape = [];
529
+ const newStrides = [];
530
+ let newOffset = this.offset || 0;
531
+ // Pad ranges to match tensor dimensions
532
+ const paddedRanges = [...ranges];
533
+ while (paddedRanges.length < this.shape.length) {
534
+ paddedRanges.push([]);
535
+ }
536
+ for (let i = 0; i < this.shape.length; i++) {
537
+ const range = paddedRanges[i] || [];
538
+ const dimSize = this.shape[i];
539
+ const stride = this.strides[i];
540
+ // Default values
541
+ let start = range[0] ?? 0;
542
+ let end = range[1] ?? dimSize;
543
+ let step = range[2] ?? 1;
544
+ // Handle negative indices
545
+ if (start < 0)
546
+ start += dimSize;
547
+ if (end < 0)
548
+ end += dimSize;
549
+ // Clamp to valid range
550
+ start = Math.max(0, Math.min(start, dimSize));
551
+ end = Math.max(0, Math.min(end, dimSize));
552
+ // Calculate new dimension size
553
+ const newDimSize = step > 0
554
+ ? Math.max(0, Math.ceil((end - start) / step))
555
+ : Math.max(0, Math.ceil((start - end) / Math.abs(step)));
556
+ newShape.push(newDimSize);
557
+ newStrides.push(stride * step);
558
+ newOffset += start * stride;
559
+ }
560
+ const out = new Tensor(this.value, {
561
+ shape: newShape,
562
+ strides: newStrides,
563
+ offset: newOffset,
564
+ device: this.device
565
+ });
566
+ if (this.requiresGrad) {
567
+ out.requiresGrad = true;
568
+ out.children.push(this);
569
+ out.gradFn = () => {
570
+ // Create zero tensor of original shape
571
+ const grad = Tensor.zerosLike(this);
572
+ // Upstream grad
573
+ const outGrad = out.grad;
574
+ const totalElements = outGrad.numel;
575
+ for (let i = 0; i < totalElements; i++) {
576
+ // Convert flat index to coordinates in sliced tensor
577
+ const slicedCoords = Tensor.indexToCoords(i, outGrad.strides);
578
+ // Map back to original coordinates
579
+ const originalCoords = new Array(slicedCoords.length);
580
+ for (let dim = 0; dim < slicedCoords.length; dim++) {
581
+ const coord = slicedCoords[dim];
582
+ const range = paddedRanges[dim] || [];
583
+ const start = range[0] ?? 0;
584
+ const step = range[2] ?? 1;
585
+ const normalizedStart = start < 0 ? start + this.shape[dim] : start;
586
+ originalCoords[dim] = normalizedStart + coord * step;
587
+ }
588
+ // Get flat indices with offsets
589
+ const srcIndex = Tensor.coordsToIndex(slicedCoords, outGrad.strides) + outGrad.offset;
590
+ const targetIndex = Tensor.coordsToIndex(originalCoords, grad.strides) + grad.offset;
591
+ // Accumulate gradient
592
+ grad.value[targetIndex] += outGrad.value[srcIndex];
593
+ }
594
+ Tensor.addGrad(this, grad);
595
+ };
596
+ }
597
+ return out;
598
+ }
321
599
  // Tensor squeeze
322
600
  squeeze(dims) {
323
601
  if (typeof this.value === "number")
@@ -334,6 +612,7 @@ class Tensor {
334
612
  }
335
613
  }
336
614
  }
615
+ dims = Tensor.normalizeDims(dims, this.shape.length);
337
616
  // Remove size-1 dims only
338
617
  const outShape = [], outStrides = [];
339
618
  for (let index = 0; index < this.shape.length; index++) {
@@ -348,10 +627,11 @@ class Tensor {
348
627
  outStrides.push(stride);
349
628
  }
350
629
  }
351
- const outValue = outShape.length === 0 ? this.value[0] : this.value;
630
+ const outValue = outShape.length === 0 ? this.value[this.offset] : this.value;
352
631
  const out = new Tensor(outValue, {
353
632
  shape: outShape,
354
633
  strides: outStrides,
634
+ offset: this.offset,
355
635
  device: this.device
356
636
  });
357
637
  // Set up gradient if needed
@@ -370,6 +650,10 @@ class Tensor {
370
650
  }
371
651
  // Tensor unsqueeze - adds dimension of size 1 at specified position
372
652
  unsqueeze(dim) {
653
+ // Handle negative indices
654
+ if (dim < 0) {
655
+ dim += this.shape.length;
656
+ }
373
657
  let thisValue = this.value;
374
658
  if (typeof thisValue === "number") {
375
659
  thisValue = [thisValue];
@@ -389,7 +673,12 @@ class Tensor {
389
673
  newDimStride = this.strides[dim] * this.shape[dim];
390
674
  }
391
675
  newStrides.splice(dim, 0, newDimStride);
392
- const out = new Tensor(thisValue, { shape: newShape, strides: newStrides, device: this.device });
676
+ const out = new Tensor(thisValue, {
677
+ shape: newShape,
678
+ strides: newStrides,
679
+ offset: this.offset,
680
+ device: this.device
681
+ });
393
682
  // Set up gradient if needed
394
683
  if (this.requiresGrad) {
395
684
  out.requiresGrad = true;
@@ -400,325 +689,138 @@ class Tensor {
400
689
  }
401
690
  return out;
402
691
  }
403
- // Tensor sum reduction
404
- sum(dims, keepDims = false) {
405
- if (typeof this.value === "number")
406
- return this;
692
+ // Generic reduction operation handler
693
+ static reduce(tensor, dims, keepDims, config) {
694
+ if (typeof tensor.value === "number")
695
+ return tensor;
407
696
  if (typeof dims === "undefined") {
408
- dims = Array.from({ length: this.shape.length }, (_, index) => index);
697
+ dims = Array.from({ length: tensor.shape.length }, (_, index) => index);
409
698
  }
410
699
  if (Array.isArray(dims)) {
411
- // Sort in descending order
700
+ dims = Tensor.normalizeDims(dims, tensor.shape.length);
412
701
  const sortedDims = dims.sort((a, b) => b - a);
413
- let reducedThis = this;
702
+ let reducedThis = tensor;
414
703
  for (let i = 0; i < sortedDims.length; i++) {
415
- reducedThis = reducedThis.sum(sortedDims[i], true);
704
+ reducedThis = Tensor.reduce(reducedThis, sortedDims[i], true, config);
416
705
  }
417
706
  return keepDims ? reducedThis : reducedThis.squeeze(dims);
418
707
  }
419
- // Dims that are reduced now have size-1
420
- const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
708
+ const outputShape = tensor.shape.map((dim, i) => dims === i ? 1 : dim);
421
709
  const outputStrides = Tensor.getStrides(outputShape);
422
710
  const outputSize = Tensor.shapeToSize(outputShape);
423
- const outputValue = new Array(outputSize).fill(0);
424
- const originalSize = Tensor.shapeToSize(this.shape);
425
- // Gradient data
426
- let gradShape, gradStrides, gradValue = [];
427
- // Allocate gradient data only when needed
428
- if (this.requiresGrad) {
429
- gradShape = this.shape;
430
- gradStrides = this.strides;
431
- gradValue = new Array(originalSize).fill(0);
432
- }
433
- // Calculate new value after sum
434
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
435
- // Force 0 on reduced axes to collapse into size-1 dims
436
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
437
- outCoords[dims] = 0;
438
- // Convert output coordinates to flat index
439
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
440
- // Add into sum
441
- outputValue[outFlatIndex] += this.value[realFlatIndex];
442
- // Mark for gradient if needed
443
- if (this.requiresGrad) {
444
- gradValue[realFlatIndex] = 1;
711
+ const outputValue = new Array(outputSize).fill(config.identity);
712
+ const outputCounters = config.needsCounters ? new Array(outputSize).fill(0) : [];
713
+ const originalSize = tensor.numel;
714
+ const originalValue = tensor.value;
715
+ const linearStrides = Tensor.getStrides(tensor.shape);
716
+ // Forward pass
717
+ for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
718
+ // Convert linear index to coordinates using contiguous strides
719
+ const coords = Tensor.indexToCoords(flatIndex, linearStrides);
720
+ // Convert coordinates to actual strided index
721
+ const realFlatIndex = Tensor.coordsToIndex(coords, tensor.strides) + tensor.offset;
722
+ // Convert coords to reduced index
723
+ coords[dims] = 0;
724
+ const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides);
725
+ // Apply op
726
+ outputValue[outFlatIndex] = config.operation(outputValue[outFlatIndex], originalValue[realFlatIndex]);
727
+ // Count el if needed
728
+ if (config.needsCounters) {
729
+ outputCounters[outFlatIndex]++;
445
730
  }
446
731
  }
447
- const out = new Tensor(outputValue, {
448
- shape: outputShape,
449
- strides: outputStrides
450
- });
451
- // Set up gradient if needed
452
- if (this.requiresGrad) {
732
+ // Post-process if needed (e.g., divide by count for mean)
733
+ if (config.postProcess) {
734
+ config.postProcess({ values: outputValue, counters: outputCounters });
735
+ }
736
+ const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides, numel: outputSize });
737
+ // Gradient setup
738
+ if (tensor.requiresGrad) {
453
739
  out.requiresGrad = true;
454
- out.children.push(this);
740
+ out.children.push(tensor);
455
741
  out.gradFn = () => {
456
- const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
457
- Tensor.addGrad(this, out.grad.mul(localGrad));
742
+ let shareCounts = [];
743
+ if (config.needsShareCounts) {
744
+ shareCounts = new Array(outputSize).fill(0);
745
+ for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
746
+ // Convert linear index to coordinates using contiguous strides
747
+ const coords = Tensor.indexToCoords(flatIndex, linearStrides);
748
+ // Convert coordinates to actual strided index
749
+ const realFlatIndex = Tensor.coordsToIndex(coords, tensor.strides) + tensor.offset;
750
+ // Convert coords to reduced index
751
+ coords[dims] = 0;
752
+ const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides);
753
+ // We collect how many elements share the same max value first
754
+ shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
755
+ }
756
+ }
757
+ const gradValue = new Array(originalSize);
758
+ for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
759
+ // Convert linear index to coordinates using contiguous strides
760
+ const coords = Tensor.indexToCoords(flatIndex, linearStrides);
761
+ // Convert coordinates to actual strided index
762
+ const realFlatIndex = Tensor.coordsToIndex(coords, tensor.strides) + tensor.offset;
763
+ // Convert coords to reduced index
764
+ coords[dims] = 0;
765
+ const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides);
766
+ gradValue[flatIndex] = config.gradientFn({
767
+ outputValue,
768
+ originalValue: tensor.value,
769
+ counters: outputCounters,
770
+ shareCounts,
771
+ realIndex: realFlatIndex,
772
+ outIndex: outFlatIndex
773
+ });
774
+ }
775
+ const localGrad = new Tensor(gradValue, { shape: tensor.shape, numel: tensor.numel });
776
+ Tensor.addGrad(tensor, out.grad.mul(localGrad));
458
777
  };
459
778
  }
460
779
  return keepDims ? out : out.squeeze(dims);
461
780
  }
462
- // Tensor product reduction
781
+ // Simplified reduction operations
782
+ sum(dims, keepDims = false) {
783
+ return Tensor.reduce(this, dims, keepDims, {
784
+ identity: 0,
785
+ operation: (a, b) => a + b,
786
+ gradientFn: ({}) => 1
787
+ });
788
+ }
463
789
  prod(dims, keepDims = false) {
464
- if (typeof this.value === "number")
465
- return this;
466
- if (typeof dims === "undefined") {
467
- dims = Array.from({ length: this.shape.length }, (_, index) => index);
468
- }
469
- if (Array.isArray(dims)) {
470
- // Sort in descending order
471
- const sortedDims = dims.sort((a, b) => b - a);
472
- let reducedThis = this;
473
- for (let i = 0; i < sortedDims.length; i++) {
474
- reducedThis = reducedThis.prod(sortedDims[i], true);
475
- }
476
- return keepDims ? reducedThis : reducedThis.squeeze(dims);
477
- }
478
- // Dims that are reduced now have size-1
479
- const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
480
- const outputStrides = Tensor.getStrides(outputShape);
481
- const outputSize = Tensor.shapeToSize(outputShape);
482
- const outputValue = new Array(outputSize).fill(1);
483
- const originalSize = Tensor.shapeToSize(this.shape);
484
- // Calculate new value after multiplying
485
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
486
- // Force 0 on reduced axes to collapse into size-1 dims
487
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
488
- outCoords[dims] = 0;
489
- // Convert output coordinates to flat index
490
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
491
- // Multiply into product
492
- outputValue[outFlatIndex] *= this.value[realFlatIndex];
493
- }
494
- const out = new Tensor(outputValue, {
495
- shape: outputShape,
496
- strides: outputStrides
790
+ return Tensor.reduce(this, dims, keepDims, {
791
+ identity: 1,
792
+ operation: (a, b) => a * b,
793
+ gradientFn: ({ outputValue, originalValue, realIndex, outIndex }) => outputValue[outIndex] / originalValue[realIndex]
497
794
  });
498
- // Set up gradient if needed
499
- if (this.requiresGrad) {
500
- out.requiresGrad = true;
501
- out.children.push(this);
502
- out.gradFn = () => {
503
- const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
504
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
505
- // Force 0 on reduced axes to collapse into size-1 dims
506
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
507
- outCoords[dims] = 0;
508
- // Convert output coordinates to flat index
509
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
510
- // Grad is the product of other elements of the same axis, which is product of all els divided by the current value
511
- gradValue[realFlatIndex] = outputValue[outFlatIndex] / this.value[realFlatIndex];
512
- }
513
- const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
514
- Tensor.addGrad(this, out.grad.mul(localGrad));
515
- };
516
- }
517
- return keepDims ? out : out.squeeze(dims);
518
795
  }
519
- // Tensor mean reduction
520
796
  mean(dims, keepDims = false) {
521
- if (typeof this.value === "number")
522
- return this;
523
- if (typeof dims === "undefined") {
524
- dims = Array.from({ length: this.shape.length }, (_, index) => index);
525
- }
526
- if (Array.isArray(dims)) {
527
- // Sort in descending order
528
- const sortedDims = dims.sort((a, b) => b - a);
529
- let reducedThis = this;
530
- for (let i = 0; i < sortedDims.length; i++) {
531
- reducedThis = reducedThis.mean(sortedDims[i], true);
532
- }
533
- return keepDims ? reducedThis : reducedThis.squeeze(dims);
534
- }
535
- // Dims that are reduced now have size-1
536
- const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
537
- const outputStrides = Tensor.getStrides(outputShape);
538
- const outputSize = Tensor.shapeToSize(outputShape);
539
- const outputValue = new Array(outputSize).fill(0);
540
- const outputFeeders = new Array(outputSize).fill(0);
541
- const originalSize = Tensor.shapeToSize(this.shape);
542
- // Calculate sums and how many elements contribute to specific positions
543
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
544
- // Force 0 on reduced axes to collapse into size-1 dims
545
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
546
- outCoords[dims] = 0;
547
- // Convert output coordinates to flat index
548
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
549
- // Calculate sum and contributors to the sum
550
- outputValue[outFlatIndex] += this.value[realFlatIndex];
551
- outputFeeders[outFlatIndex]++;
552
- }
553
- // Calculate mean by dividing sum by the number of contributors to the position
554
- for (let index = 0; index < outputSize; index++) {
555
- outputValue[index] /= outputFeeders[index];
556
- }
557
- const out = new Tensor(outputValue, {
558
- shape: outputShape,
559
- strides: outputStrides
560
- });
561
- // Set up gradient if needed
562
- if (this.requiresGrad) {
563
- out.requiresGrad = true;
564
- out.children.push(this);
565
- out.gradFn = () => {
566
- const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
567
- // Calculate grad by assigning 1 divided by the number of contributors to the position
568
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
569
- // Force 0 on reduced axes to collapse into size-1 dims
570
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
571
- outCoords[dims] = 0;
572
- // Convert output coordinates to flat index
573
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
574
- // Mean = 1/n * (el1 + el2 + ... + eln) so grad = 1/n
575
- gradValue[realFlatIndex] = 1 / outputFeeders[outFlatIndex];
797
+ return Tensor.reduce(this, dims, keepDims, {
798
+ identity: 0,
799
+ operation: (a, b) => a + b,
800
+ needsCounters: true,
801
+ postProcess: ({ values, counters }) => {
802
+ for (let i = 0; i < values.length; i++) {
803
+ values[i] /= counters[i];
576
804
  }
577
- const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
578
- Tensor.addGrad(this, out.grad.mul(localGrad));
579
- };
580
- }
581
- return keepDims ? out : out.squeeze(dims);
805
+ },
806
+ gradientFn: ({ counters, outIndex }) => 1 / counters[outIndex]
807
+ });
582
808
  }
583
- // Tensor maximum reduction
584
809
  max(dims, keepDims = false) {
585
- if (typeof this.value === "number")
586
- return this;
587
- if (typeof dims === "undefined") {
588
- dims = Array.from({ length: this.shape.length }, (_, index) => index);
589
- }
590
- if (Array.isArray(dims)) {
591
- // Sort in descending order
592
- const sortedDims = dims.sort((a, b) => b - a);
593
- let reducedThis = this;
594
- for (let i = 0; i < sortedDims.length; i++) {
595
- reducedThis = reducedThis.max(sortedDims[i], true);
596
- }
597
- return keepDims ? reducedThis : reducedThis.squeeze(dims);
598
- }
599
- // Dims that are reduced now have size-1
600
- const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
601
- const outputStrides = Tensor.getStrides(outputShape);
602
- const outputSize = Tensor.shapeToSize(outputShape);
603
- const outputValue = new Array(outputSize).fill(-Infinity);
604
- const originalSize = Tensor.shapeToSize(this.shape);
605
- // Calculate maximum values of axes
606
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
607
- // Force 0 on reduced axes to collapse into size-1 dims
608
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
609
- outCoords[dims] = 0;
610
- // Convert output coordinates to flat index
611
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
612
- // Get max over time
613
- if (this.value[realFlatIndex] > outputValue[outFlatIndex]) {
614
- outputValue[outFlatIndex] = this.value[realFlatIndex];
615
- }
616
- }
617
- const out = new Tensor(outputValue, {
618
- shape: outputShape,
619
- strides: outputStrides
810
+ return Tensor.reduce(this, dims, keepDims, {
811
+ identity: -Infinity,
812
+ operation: (a, b) => Math.max(a, b),
813
+ needsShareCounts: true,
814
+ gradientFn: ({ outputValue, originalValue, shareCounts, realIndex, outIndex }) => outputValue[outIndex] === originalValue[realIndex] ? 1 / shareCounts[outIndex] : 0
620
815
  });
621
- // Set up gradient if needed
622
- if (this.requiresGrad) {
623
- out.requiresGrad = true;
624
- out.children.push(this);
625
- out.gradFn = () => {
626
- const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
627
- const shareCounts = new Array(outputSize).fill(0);
628
- const originalValue = this.value;
629
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
630
- // Force 0 on reduced axes to collapse into size-1 dims
631
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
632
- outCoords[dims] = 0;
633
- // Convert output coordinates to flat index
634
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
635
- // We collect how many elements share the same max value first
636
- shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
637
- }
638
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
639
- // Force 0 on reduced axes to collapse into size-1 dims
640
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
641
- outCoords[dims] = 0;
642
- // Convert output coordinates to flat index
643
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
644
- // Here we share the grad between the elements that share the same max value
645
- gradValue[realFlatIndex] = outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 / shareCounts[outFlatIndex] : 0;
646
- }
647
- const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
648
- Tensor.addGrad(this, out.grad.mul(localGrad));
649
- };
650
- }
651
- return keepDims ? out : out.squeeze(dims);
652
816
  }
653
- // Tensor minimum reduction
654
817
  min(dims, keepDims = false) {
655
- if (typeof this.value === "number")
656
- return this;
657
- if (typeof dims === "undefined") {
658
- dims = Array.from({ length: this.shape.length }, (_, index) => index);
659
- }
660
- if (Array.isArray(dims)) {
661
- // Sort in descending order
662
- const sortedDims = dims.sort((a, b) => b - a);
663
- let reducedThis = this;
664
- for (let i = 0; i < sortedDims.length; i++) {
665
- reducedThis = reducedThis.min(sortedDims[i], true);
666
- }
667
- return keepDims ? reducedThis : reducedThis.squeeze(dims);
668
- }
669
- // Dims that are reduced now have size-1
670
- const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
671
- const outputStrides = Tensor.getStrides(outputShape);
672
- const outputSize = Tensor.shapeToSize(outputShape);
673
- const outputValue = new Array(outputSize).fill(Infinity);
674
- const originalSize = Tensor.shapeToSize(this.shape);
675
- // Calculate minimum values of axes
676
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
677
- // Force 0 on reduced axes to collapse into size-1 dims
678
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
679
- outCoords[dims] = 0;
680
- // Convert output coordinates to flat index
681
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
682
- // Get min over time
683
- if (this.value[realFlatIndex] < outputValue[outFlatIndex]) {
684
- outputValue[outFlatIndex] = this.value[realFlatIndex];
685
- }
686
- }
687
- const out = new Tensor(outputValue, {
688
- shape: outputShape,
689
- strides: outputStrides
818
+ return Tensor.reduce(this, dims, keepDims, {
819
+ identity: Infinity,
820
+ operation: (a, b) => Math.min(a, b),
821
+ needsShareCounts: true,
822
+ gradientFn: ({ outputValue, originalValue, shareCounts, realIndex, outIndex }) => outputValue[outIndex] === originalValue[realIndex] ? 1 / shareCounts[outIndex] : 0
690
823
  });
691
- // Set up gradient if needed
692
- if (this.requiresGrad) {
693
- out.requiresGrad = true;
694
- out.children.push(this);
695
- out.gradFn = () => {
696
- const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
697
- const shareCounts = new Array(outputSize).fill(0);
698
- const originalValue = this.value;
699
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
700
- // Force 0 on reduced axes to collapse into size-1 dims
701
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
702
- outCoords[dims] = 0;
703
- // Convert output coordinates to flat index
704
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
705
- // We collect how many elements share the same min value first
706
- shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
707
- }
708
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
709
- // Force 0 on reduced axes to collapse into size-1 dims
710
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
711
- outCoords[dims] = 0;
712
- // Convert output coordinates to flat index
713
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
714
- // Here we share the grad between the elements that share the same min value
715
- gradValue[realFlatIndex] = outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 / shareCounts[outFlatIndex] : 0;
716
- }
717
- const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
718
- Tensor.addGrad(this, out.grad.mul(localGrad));
719
- };
720
- }
721
- return keepDims ? out : out.squeeze(dims);
722
824
  }
723
825
  // Tensor all condition reduction
724
826
  all(dims, keepDims = false) {
@@ -738,75 +840,18 @@ class Tensor {
738
840
  std(dims, keepDims = false) {
739
841
  return this.var(dims, keepDims).sqrt();
740
842
  }
741
- // Tensor product reduction
742
- softmax(dims) {
843
+ // Tensor softmax
844
+ softmax(dim = -1) {
743
845
  if (typeof this.value === "number")
744
846
  return this;
745
- if (typeof dims === "undefined") {
746
- dims = Array.from({ length: this.shape.length }, (_, index) => index);
747
- }
748
- if (Array.isArray(dims)) {
749
- // Sort in descending order
750
- const sortedDims = dims.sort((a, b) => b - a);
751
- let reducedThis = this;
752
- for (let i = 0; i < sortedDims.length; i++) {
753
- reducedThis = reducedThis.softmax(sortedDims[i]);
754
- }
755
- return reducedThis;
756
- }
757
- // Dims that are reduced now have size-1
758
- const expSumShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
759
- const expSumStrides = Tensor.getStrides(expSumShape);
760
- const expSumSize = Tensor.shapeToSize(expSumShape);
761
- const expSumValue = new Array(expSumSize).fill(0);
762
- const outputShape = this.shape;
763
- const outputStrides = this.strides;
764
- const outputSize = Tensor.shapeToSize(outputShape);
765
- const outputValue = new Array(outputSize);
766
- // Calculate sums of e^xi over axes
767
- for (let realFlatIndex = 0; realFlatIndex < outputSize; realFlatIndex++) {
768
- const coords = Tensor.indexToCoords(realFlatIndex, outputStrides);
769
- // Force 0 on reduced axes to collapse into size-1 dims
770
- const expSumCoords = coords.map((val, i) => dims === i ? 0 : val);
771
- // Convert exp sum coordinates to flat index
772
- const expSumFlatIndex = Tensor.coordsToIndex(expSumCoords, expSumStrides);
773
- // Add e^x to the sum cache
774
- expSumValue[expSumFlatIndex] += Math.exp(this.value[realFlatIndex]);
775
- }
776
- // Calculate e^xi / sum over axes
777
- for (let realFlatIndex = 0; realFlatIndex < outputSize; realFlatIndex++) {
778
- const coords = Tensor.indexToCoords(realFlatIndex, outputStrides);
779
- // Force 0 on reduced axes to collapse into size-1 dims
780
- const expSumCoords = coords.map((val, i) => dims === i ? 0 : val);
781
- // Convert exp sum coordinates to flat index
782
- const expSumFlatIndex = Tensor.coordsToIndex(expSumCoords, expSumStrides);
783
- // Calculate e^xi / sum
784
- outputValue[realFlatIndex] = Math.exp(this.value[realFlatIndex]) / expSumValue[expSumFlatIndex];
785
- }
786
- const out = new Tensor(outputValue, {
787
- shape: outputShape,
788
- strides: outputStrides
789
- });
790
- // Set up gradient if needed
791
- if (this.requiresGrad) {
792
- out.requiresGrad = true;
793
- out.children.push(this);
794
- out.gradFn = () => {
795
- const upstreamGrad = out.grad;
796
- const softmaxOutput = out.detach();
797
- // Compute element-wise product: ∂L/∂σᵢ × σᵢ
798
- const gradTimesOutput = upstreamGrad.mul(softmaxOutput);
799
- // Sum over softmax dimensions: Σᵢ(∂L/∂σᵢ × σᵢ)
800
- const sumGradOutput = gradTimesOutput.sum(dims, true); // keepDims=true for broadcasting
801
- // Apply softmax gradient formula:
802
- // ∂L/∂zⱼ = (∂L/∂σⱼ × σⱼ) - (σⱼ × Σᵢ(∂L/∂σᵢ × σᵢ))
803
- const term1 = upstreamGrad.mul(softmaxOutput); // ∂L/∂σⱼ × σⱼ
804
- const term2 = softmaxOutput.mul(sumGradOutput); // σⱼ × Σᵢ(∂L/∂σᵢ × σᵢ)
805
- const localGrad = term1.sub(term2);
806
- Tensor.addGrad(this, localGrad);
807
- };
808
- }
809
- return out;
847
+ // Handle negative indexing
848
+ if (dim < 0)
849
+ dim = this.shape.length + dim;
850
+ const maxVals = this.max(dim, true);
851
+ const shifted = this.sub(maxVals);
852
+ const expVals = shifted.exp();
853
+ const sumExp = expVals.sum(dim, true);
854
+ return expVals.div(sumExp);
810
855
  }
811
856
  // Tensor element-wise addition
812
857
  add(other) {
@@ -1139,76 +1184,6 @@ class Tensor {
1139
1184
  erfinv() {
1140
1185
  return this.elementWiseSelfDAG((a) => (0, utils_1.erfinv)(a), (self, outGrad) => outGrad.mul(self.erfinv().square().exp().mul(Math.sqrt(Math.PI) / 2)));
1141
1186
  }
1142
- // Transpose
1143
- transpose(dim1, dim2) {
1144
- // If dimension out of bound, throw error
1145
- if (dim1 >= this.shape.length || dim2 >= this.shape.length || dim1 < 0 || dim2 < 0) {
1146
- throw new Error("Dimensions do not exist to tranpose");
1147
- }
1148
- // If same dimension, return copy
1149
- if (dim1 === dim2) {
1150
- return new Tensor(this.value, { shape: this.shape, strides: this.strides });
1151
- }
1152
- // Create new shape and strides by swapping
1153
- const newShape = [...this.shape];
1154
- const newStrides = [...this.strides];
1155
- [newShape[dim1], newShape[dim2]] = [newShape[dim2], newShape[dim1]];
1156
- [newStrides[dim1], newStrides[dim2]] = [newStrides[dim2], newStrides[dim1]];
1157
- // Create new tensor with same data but swapped shape/strides
1158
- const out = new Tensor(this.value, { shape: newShape, strides: newStrides, device: this.device });
1159
- out.requiresGrad = this.requiresGrad;
1160
- // Handle gradient if needed
1161
- if (this.requiresGrad) {
1162
- out.children.push(this);
1163
- out.gradFn = () => {
1164
- Tensor.addGrad(this, out.grad.transpose(dim1, dim2));
1165
- };
1166
- }
1167
- return out;
1168
- }
1169
- swapaxes = this.transpose;
1170
- swapdims = this.transpose;
1171
- // Transpose 2D
1172
- t() {
1173
- // Verify matrix shape
1174
- if (this.shape.length !== 2) {
1175
- throw new Error("Input is not a matrix");
1176
- }
1177
- return this.transpose(0, 1);
1178
- }
1179
- // Permute
1180
- permute(dims) {
1181
- if (dims.length !== this.shape.length) {
1182
- throw new Error("Permutation must specify all dimensions");
1183
- }
1184
- // Compute new shape and strides
1185
- const newShape = new Array(dims.length);
1186
- const newStrides = new Array(dims.length);
1187
- for (let index = 0; index < dims.length; index++) {
1188
- const dim = dims[index];
1189
- newShape[index] = this.shape[dim];
1190
- newStrides[index] = this.strides[dim];
1191
- }
1192
- const out = new Tensor(this.value, {
1193
- shape: newShape,
1194
- strides: newStrides
1195
- });
1196
- if (this.requiresGrad) {
1197
- out.requiresGrad = true;
1198
- out.children.push(this);
1199
- out.gradFn = () => {
1200
- // Compute inverse permutation
1201
- const inverseAxes = new Array(dims.length);
1202
- for (let i = 0; i < dims.length; i++) {
1203
- inverseAxes[dims[i]] = i;
1204
- }
1205
- // Permute gradient back to original order
1206
- const permutedGrad = out.grad.permute(inverseAxes);
1207
- Tensor.addGrad(this, permutedGrad);
1208
- };
1209
- }
1210
- return out;
1211
- }
1212
1187
  // 1D tensor dot product
1213
1188
  dot(other) {
1214
1189
  other = this.handleOther(other);
@@ -1216,36 +1191,7 @@ class Tensor {
1216
1191
  if (this.shape.length !== 1 || other.shape.length !== 1) {
1217
1192
  throw new Error("Inputs are not 1D tensors");
1218
1193
  }
1219
- // Simple vector dot product
1220
- const vectLen = this.shape[0];
1221
- const vectA = this.value;
1222
- const vectB = other.value;
1223
- let sum = 0;
1224
- for (let index = 0; index < vectLen; index++) {
1225
- sum += vectA[index] * vectB[index];
1226
- }
1227
- const out = new Tensor(sum);
1228
- if (this.requiresGrad) {
1229
- out.requiresGrad = true;
1230
- out.children.push(this);
1231
- }
1232
- if (other.requiresGrad) {
1233
- out.requiresGrad = true;
1234
- out.children.push(other);
1235
- }
1236
- if (out.requiresGrad) {
1237
- out.gradFn = () => {
1238
- // Disable gradient collecting of gradients themselves
1239
- const outGrad = out.grad;
1240
- const selfNoGrad = this.detach();
1241
- const otherNoGrad = other.detach();
1242
- if (this.requiresGrad)
1243
- Tensor.addGrad(this, outGrad.mul(otherNoGrad));
1244
- if (other.requiresGrad)
1245
- Tensor.addGrad(other, outGrad.mul(selfNoGrad));
1246
- };
1247
- }
1248
- return out;
1194
+ return this.mul(other).sum();
1249
1195
  }
1250
1196
  // Matrix multiplication
1251
1197
  mm(other) {
@@ -1274,12 +1220,12 @@ class Tensor {
1274
1220
  for (let k = 0; k < matACols; k++) {
1275
1221
  // Tensor values are 1D arrays so we have to get real index using strides
1276
1222
  matC[i * matCStrides[0] + j * matCStrides[1]] +=
1277
- matA[i * matAStrides[0] + k * matAStrides[1]] *
1278
- matB[k * matBStrides[0] + j * matBStrides[1]];
1223
+ matA[i * matAStrides[0] + k * matAStrides[1] + this.offset] *
1224
+ matB[k * matBStrides[0] + j * matBStrides[1] + other.offset];
1279
1225
  }
1280
1226
  }
1281
1227
  }
1282
- const out = new Tensor(matC, { shape: matCShape, strides: matCStrides });
1228
+ const out = new Tensor(matC, { shape: matCShape, strides: matCStrides, numel: matCSize });
1283
1229
  if (this.requiresGrad) {
1284
1230
  out.requiresGrad = true;
1285
1231
  out.children.push(this);
@@ -1331,13 +1277,13 @@ class Tensor {
1331
1277
  for (let k = 0; k < batchACols; k++) {
1332
1278
  // Tensor values are 1D arrays so we have to get real index using strides
1333
1279
  batchC[q * batchCStrides[0] + i * batchCStrides[1] + j * batchCStrides[2]] +=
1334
- batchA[q * batchAStrides[0] + i * batchAStrides[1] + k * batchAStrides[2]] *
1335
- batchB[q * batchBStrides[0] + k * batchBStrides[1] + j * batchBStrides[2]];
1280
+ batchA[q * batchAStrides[0] + i * batchAStrides[1] + k * batchAStrides[2] + this.offset] *
1281
+ batchB[q * batchBStrides[0] + k * batchBStrides[1] + j * batchBStrides[2] + other.offset];
1336
1282
  }
1337
1283
  }
1338
1284
  }
1339
1285
  }
1340
- const out = new Tensor(batchC, { shape: batchCShape, strides: batchCStrides });
1286
+ const out = new Tensor(batchC, { shape: batchCShape, strides: batchCStrides, numel: batchCSize });
1341
1287
  if (this.requiresGrad) {
1342
1288
  out.requiresGrad = true;
1343
1289
  out.children.push(this);
@@ -1410,7 +1356,7 @@ class Tensor {
1410
1356
  const otherOffsetShape = otherShape.slice(0, -2);
1411
1357
  const selfOffsetStrides = selfStrides.slice(0, -2);
1412
1358
  const otherOffsetStrides = otherStrides.slice(0, -2);
1413
- // The output's offset data
1359
+ // Base offset data
1414
1360
  const offsetShape = Tensor.broadcastShapes(selfOffsetShape, otherOffsetShape);
1415
1361
  const offsetSize = Tensor.shapeToSize(offsetShape);
1416
1362
  const offsetStrides = Tensor.getStrides(offsetShape);
@@ -1419,10 +1365,11 @@ class Tensor {
1419
1365
  const outputStrides = Tensor.getStrides(outputShape);
1420
1366
  const outputSize = Tensor.shapeToSize(outputShape);
1421
1367
  const outputValue = new Array(outputSize).fill(0);
1368
+ const outputOffsetStrides = outputStrides.slice(0, -2);
1422
1369
  // Loop through outer dims and do matmul on two outer-most dims
1423
1370
  for (let index = 0; index < offsetSize; index++) {
1424
1371
  const coords = Tensor.indexToCoords(index, offsetStrides);
1425
- const offset = Tensor.coordsToIndex(coords, outputStrides.slice(0, -2));
1372
+ const offset = Tensor.coordsToIndex(coords, outputOffsetStrides);
1426
1373
  const selfOffset = Tensor.coordsToUnbroadcastedIndex(coords, selfOffsetShape, selfOffsetStrides);
1427
1374
  const otherOffset = Tensor.coordsToUnbroadcastedIndex(coords, otherOffsetShape, otherOffsetStrides);
1428
1375
  for (let i = 0; i < batchARows; i++) {
@@ -1431,12 +1378,12 @@ class Tensor {
1431
1378
  const outputIdx = offset + i * outputStrides[lastDim - 1] + j * outputStrides[lastDim];
1432
1379
  const selfIdx = selfOffset + i * selfStrides[lastDim - 1] + k * selfStrides[lastDim];
1433
1380
  const otherIdx = otherOffset + k * otherStrides[lastDim - 1] + j * otherStrides[lastDim];
1434
- outputValue[outputIdx] += batchA[selfIdx] * batchB[otherIdx];
1381
+ outputValue[outputIdx] += batchA[selfIdx + this.offset] * batchB[otherIdx + other.offset];
1435
1382
  }
1436
1383
  }
1437
1384
  }
1438
1385
  }
1439
- const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides });
1386
+ const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides, numel: outputSize });
1440
1387
  if (this.requiresGrad) {
1441
1388
  out.requiresGrad = true;
1442
1389
  out.children.push(this);
@@ -1452,9 +1399,9 @@ class Tensor {
1452
1399
  const selfNoGrad = self.detach();
1453
1400
  const otherNoGrad = other.detach();
1454
1401
  if (this.requiresGrad)
1455
- Tensor.addGrad(this, outGrad.matmul(otherNoGrad.transpose(lastDim - 1, lastDim)));
1402
+ Tensor.addGrad(this, outGrad.matmul(otherNoGrad.transpose(-2, -1)));
1456
1403
  if (other.requiresGrad)
1457
- Tensor.addGrad(other, selfNoGrad.transpose(lastDim - 1, lastDim).matmul(outGrad));
1404
+ Tensor.addGrad(other, selfNoGrad.transpose(-2, -1).matmul(outGrad));
1458
1405
  };
1459
1406
  }
1460
1407
  return out;
@@ -1476,15 +1423,15 @@ class Tensor {
1476
1423
  return new Tensor(num, options);
1477
1424
  const outputSize = Tensor.shapeToSize(shape);
1478
1425
  const outputValue = new Array(outputSize).fill(num);
1479
- return new Tensor(outputValue, { shape, ...options });
1426
+ return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1480
1427
  }
1481
1428
  // Utility to create a new tensor with shape of another tensor, filled with a number
1482
1429
  static fullLike(tensor, num, options = {}) {
1483
1430
  if (typeof tensor.value === "number")
1484
1431
  return new Tensor(num, options);
1485
- return new Tensor(new Array(tensor.value.length).fill(num), {
1432
+ return new Tensor(new Array(tensor.numel).fill(num), {
1486
1433
  shape: tensor.shape,
1487
- strides: tensor.strides,
1434
+ numel: tensor.numel,
1488
1435
  device: tensor.device,
1489
1436
  ...options
1490
1437
  });
@@ -1495,15 +1442,15 @@ class Tensor {
1495
1442
  return new Tensor(1, options);
1496
1443
  const outputSize = Tensor.shapeToSize(shape);
1497
1444
  const outputValue = new Array(outputSize).fill(1);
1498
- return new Tensor(outputValue, { shape, ...options });
1445
+ return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1499
1446
  }
1500
1447
  // Utility to create a new tensor with shape of another tensor, filled with 1
1501
1448
  static onesLike(tensor, options = {}) {
1502
1449
  if (typeof tensor.value === "number")
1503
1450
  return new Tensor(1, options);
1504
- return new Tensor(new Array(tensor.value.length).fill(1), {
1451
+ return new Tensor(new Array(tensor.numel).fill(1), {
1505
1452
  shape: tensor.shape,
1506
- strides: tensor.strides,
1453
+ numel: tensor.numel,
1507
1454
  device: tensor.device,
1508
1455
  ...options
1509
1456
  });
@@ -1514,15 +1461,15 @@ class Tensor {
1514
1461
  return new Tensor(0, options);
1515
1462
  const outputSize = Tensor.shapeToSize(shape);
1516
1463
  const outputValue = new Array(outputSize).fill(0);
1517
- return new Tensor(outputValue, { shape, ...options });
1464
+ return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1518
1465
  }
1519
1466
  // Utility to create a new tensor with shape of another tensor, filled with 0
1520
1467
  static zerosLike(tensor, options = {}) {
1521
1468
  if (typeof tensor.value === "number")
1522
1469
  return new Tensor(0, options);
1523
- return new Tensor(new Array(tensor.value.length).fill(0), {
1470
+ return new Tensor(new Array(tensor.numel).fill(0), {
1524
1471
  shape: tensor.shape,
1525
- strides: tensor.strides,
1472
+ numel: tensor.numel,
1526
1473
  device: tensor.device,
1527
1474
  ...options
1528
1475
  });
@@ -1536,19 +1483,19 @@ class Tensor {
1536
1483
  for (let index = 0; index < outputValue.length; index++) {
1537
1484
  outputValue[index] = (0, utils_1.randUniform)();
1538
1485
  }
1539
- return new Tensor(outputValue, { shape, ...options });
1486
+ return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1540
1487
  }
1541
1488
  // Utility to create a new tensor with shape of another tensor, filled with a random number with uniform distribution from 0 to 1
1542
1489
  static randLike(tensor, options = {}) {
1543
1490
  if (typeof tensor.value === "number")
1544
1491
  return new Tensor((0, utils_1.randUniform)(), options);
1545
- const outputValue = new Array(tensor.value.length);
1492
+ const outputValue = new Array(tensor.numel);
1546
1493
  for (let index = 0; index < outputValue.length; index++) {
1547
1494
  outputValue[index] = (0, utils_1.randUniform)();
1548
1495
  }
1549
1496
  return new Tensor(outputValue, {
1550
1497
  shape: tensor.shape,
1551
- strides: tensor.strides,
1498
+ numel: tensor.numel,
1552
1499
  device: tensor.device,
1553
1500
  ...options
1554
1501
  });
@@ -1562,19 +1509,19 @@ class Tensor {
1562
1509
  for (let index = 0; index < outputValue.length; index++) {
1563
1510
  outputValue[index] = (0, utils_1.randNormal)();
1564
1511
  }
1565
- return new Tensor(outputValue, { shape, ...options });
1512
+ return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1566
1513
  }
1567
1514
  // Utility to create a new tensor with shape of another tensor, filled with a random number with normal distribution of mean=0 and stddev=1
1568
1515
  static randnLike(tensor, options = {}) {
1569
1516
  if (typeof tensor.value === "number")
1570
1517
  return new Tensor((0, utils_1.randNormal)(), options);
1571
- const outputValue = new Array(tensor.value.length);
1518
+ const outputValue = new Array(tensor.numel);
1572
1519
  for (let index = 0; index < outputValue.length; index++) {
1573
1520
  outputValue[index] = (0, utils_1.randNormal)();
1574
1521
  }
1575
1522
  return new Tensor(outputValue, {
1576
1523
  shape: tensor.shape,
1577
- strides: tensor.strides,
1524
+ numel: tensor.numel,
1578
1525
  device: tensor.device,
1579
1526
  ...options
1580
1527
  });
@@ -1588,19 +1535,19 @@ class Tensor {
1588
1535
  for (let index = 0; index < outputValue.length; index++) {
1589
1536
  outputValue[index] = (0, utils_1.randInt)(low, high);
1590
1537
  }
1591
- return new Tensor(outputValue, { shape, ...options });
1538
+ return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1592
1539
  }
1593
1540
  // Utility to create a new tensor with shape of another tensor, filled with a random integer between low and high
1594
1541
  static randintLike(tensor, low, high, options = {}) {
1595
1542
  if (typeof tensor.value === "number")
1596
1543
  return new Tensor((0, utils_1.randInt)(low, high), options);
1597
- const outputValue = new Array(tensor.value.length);
1544
+ const outputValue = new Array(tensor.numel);
1598
1545
  for (let index = 0; index < outputValue.length; index++) {
1599
1546
  outputValue[index] = (0, utils_1.randInt)(low, high);
1600
1547
  }
1601
1548
  return new Tensor(outputValue, {
1602
1549
  shape: tensor.shape,
1603
- strides: tensor.strides,
1550
+ numel: tensor.numel,
1604
1551
  device: tensor.device,
1605
1552
  ...options
1606
1553
  });
@@ -1614,7 +1561,7 @@ class Tensor {
1614
1561
  for (let index = 0; index < outputValue.length; index++) {
1615
1562
  outputValue[index] = (0, utils_1.randNormal)(mean, stdDev);
1616
1563
  }
1617
- return new Tensor(outputValue, { shape, ...options });
1564
+ return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1618
1565
  }
1619
1566
  // Utility to create a new tensor filled with a random number with uniform distribution from low to high
1620
1567
  static uniform(shape, low, high, options = {}) {
@@ -1625,7 +1572,7 @@ class Tensor {
1625
1572
  for (let index = 0; index < outputValue.length; index++) {
1626
1573
  outputValue[index] = (0, utils_1.randUniform)(low, high);
1627
1574
  }
1628
- return new Tensor(outputValue, { shape, ...options });
1575
+ return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1629
1576
  }
1630
1577
  // Reverse-mode autodiff call
1631
1578
  backward(options = {}) {
@@ -1674,13 +1621,15 @@ class Tensor {
1674
1621
  }
1675
1622
  return result;
1676
1623
  }
1677
- return buildNested(this.value, this.shape, this.strides);
1624
+ return buildNested(this.value, this.shape, this.strides, this.offset);
1678
1625
  }
1679
1626
  // Returns a view of the tensor with gradient turned on/off and detaches from autograd
1680
1627
  withGrad(requiresGrad) {
1681
1628
  return new Tensor(this.value, {
1682
1629
  shape: this.shape,
1683
1630
  strides: this.strides,
1631
+ offset: this.offset,
1632
+ numel: this.numel,
1684
1633
  device: this.device,
1685
1634
  requiresGrad
1686
1635
  });
@@ -1690,6 +1639,8 @@ class Tensor {
1690
1639
  return new Tensor(this.value, {
1691
1640
  shape: this.shape,
1692
1641
  strides: this.strides,
1642
+ offset: this.offset,
1643
+ numel: this.numel,
1693
1644
  device: this.device,
1694
1645
  requiresGrad: false
1695
1646
  });
@@ -1699,6 +1650,8 @@ class Tensor {
1699
1650
  return new Tensor(typeof this.value === "number" ? this.value : [...this.value], {
1700
1651
  shape: this.shape,
1701
1652
  strides: this.strides,
1653
+ offset: this.offset,
1654
+ numel: this.numel,
1702
1655
  requiresGrad: this.requiresGrad
1703
1656
  });
1704
1657
  }