catniff 0.7.4 → 0.8.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.js CHANGED
@@ -1,6 +1,7 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
3
  exports.Tensor = void 0;
4
+ const dtype_1 = require("./dtype");
4
5
  const utils_1 = require("./utils");
5
6
  class Tensor {
6
7
  value;
@@ -13,12 +14,16 @@ class Tensor {
13
14
  gradFn;
14
15
  children;
15
16
  device;
17
+ dtype;
16
18
  static training = false;
17
19
  static noGrad = false;
18
20
  static createGraph = false;
19
21
  constructor(value, options = {}) {
20
- // Storage
21
- this.value = Tensor.flattenValue(value);
22
+ // Memory buffer
23
+ this.dtype = options.dtype || "float32";
24
+ const flatValue = Tensor.flattenValue(value);
25
+ const TypedArrayConstructor = dtype_1.TypedArray[this.dtype];
26
+ this.value = flatValue instanceof TypedArrayConstructor ? flatValue : TypedArrayConstructor.from(flatValue);
22
27
  // Tensor metadata
23
28
  this.shape = options.shape || Tensor.getShape(value);
24
29
  this.strides = options.strides || Tensor.getStrides(this.shape);
@@ -34,31 +39,34 @@ class Tensor {
34
39
  this.to_(this.device);
35
40
  }
36
41
  // Utility to flatten an nD array to be 1D
37
- static flattenValue(tensor) {
42
+ static flattenValue(tensorValue) {
38
43
  // Handle scalar tensors
39
- if (typeof tensor === "number")
40
- return tensor;
44
+ if (typeof tensorValue === "number")
45
+ return [tensorValue];
41
46
  // If the value is already 1D, return it as-is (the same reference, no copy)
42
- if (typeof tensor[0] === "number")
43
- return tensor;
47
+ if (typeof tensorValue[0] === "number")
48
+ return tensorValue;
44
49
  // Or else recursively traverse through the nD array to flatten
45
50
  const result = [];
46
51
  function traverse(arr) {
47
52
  if (typeof arr === "number") {
48
53
  result.push(arr);
54
+ // Assume if we can index a value, it is an ArrayLike
49
55
  }
50
- else if (Array.isArray(arr)) {
51
- arr.forEach(traverse);
56
+ else if (typeof arr[0] !== "undefined") {
57
+ for (let index = 0; index < arr.length; index++) {
58
+ traverse(arr[index]);
59
+ }
52
60
  }
53
61
  }
54
- traverse(tensor);
62
+ traverse(tensorValue);
55
63
  return result;
56
64
  }
57
65
  // Utility to get shape from tensor *value*
58
- static getShape(tensor) {
66
+ static getShape(tensorValue) {
59
67
  const shape = [];
60
- let subA = tensor;
61
- while (Array.isArray(subA)) {
68
+ let subA = tensorValue;
69
+ while (typeof subA !== "number") {
62
70
  shape.push(subA.length);
63
71
  subA = subA[0];
64
72
  }
@@ -146,17 +154,52 @@ class Tensor {
146
154
  }
147
155
  return prod;
148
156
  }
149
- ;
157
+ // Utility to pick the result dtype of two operands: the same dtype if they match, otherwise the higher-ranked one
158
+ static getResultDtype(type1, type2) {
159
+ if (type1 === type2)
160
+ return type1;
161
+ const type1Ranking = dtype_1.dtypeHiearchy[type1];
162
+ const type2Ranking = dtype_1.dtypeHiearchy[type2];
163
+ if (type1Ranking > type2Ranking) {
164
+ return type1;
165
+ }
166
+ return type2;
167
+ }
168
+ // Utility to handle other tensor if an op needs a second operand
169
+ handleOther(other) {
170
+ if (other instanceof Tensor) {
171
+ if (this.device !== other.device) {
172
+ throw new Error("Can not operate on tensors that are not on the same device");
173
+ }
174
+ return other;
175
+ }
176
+ return new Tensor(other, {
177
+ offset: 0,
178
+ device: this.device,
179
+ dtype: this.dtype
180
+ });
181
+ }
150
182
  // Utility for binary (two operands involved) element-wise ops
151
183
  static elementWiseAB(tA, tB, op) {
152
- if (typeof tA.value === "number" && typeof tB.value === "number") {
153
- return new Tensor(op(tA.value, tB.value));
184
+ const outputDtype = Tensor.getResultDtype(tA.dtype, tB.dtype);
185
+ // Both are scalars
186
+ if (tA.shape.length === 0 && tB.shape.length === 0) {
187
+ return new Tensor(op(tA.value[0], tB.value[0]), {
188
+ shape: [],
189
+ strides: [],
190
+ offset: 0,
191
+ numel: 1,
192
+ device: tA.device,
193
+ dtype: outputDtype
194
+ });
154
195
  }
155
- if (typeof tA.value === "number") {
156
- return Tensor.elementWiseSelf(tB, (a) => op(a, tA.value));
196
+ // First tensor is scalar
197
+ if (tA.shape.length === 0) {
198
+ return Tensor.elementWiseSelf(tB.cast(outputDtype), (a) => op(a, tA.value[0]));
157
199
  }
158
- if (typeof tB.value === "number") {
159
- return Tensor.elementWiseSelf(tA, (a) => op(a, tB.value));
200
+ // Second tensor is scalar
201
+ if (tB.shape.length === 0) {
202
+ return Tensor.elementWiseSelf(tA.cast(outputDtype), (a) => op(a, tB.value[0]));
160
203
  }
161
204
  // Pad + broadcast shape
162
205
  const [paddedAStrides, paddedBStrides, paddedAShape, paddedBShape] = Tensor.padShape(tA.strides, tB.strides, tA.shape, tB.shape);
@@ -164,7 +207,7 @@ class Tensor {
164
207
  // Get other output info
165
208
  const outputStrides = Tensor.getStrides(outputShape);
166
209
  const outputSize = Tensor.shapeToSize(outputShape);
167
- const outputValue = new Array(outputSize);
210
+ const outputValue = new dtype_1.TypedArray[outputDtype](outputSize);
168
211
  for (let i = 0; i < outputSize; i++) {
169
212
  // Get coordinates from 1D index
170
213
  const coordsOutput = Tensor.indexToCoords(i, outputStrides);
@@ -178,18 +221,29 @@ class Tensor {
178
221
  return new Tensor(outputValue, {
179
222
  shape: outputShape,
180
223
  strides: outputStrides,
181
- numel: outputSize
224
+ offset: 0,
225
+ numel: outputSize,
226
+ device: tA.device,
227
+ dtype: outputDtype
182
228
  });
183
229
  }
184
230
  // Utility for unary (single operand) element-wise ops
185
231
  static elementWiseSelf(tA, op) {
186
- if (typeof tA.value === "number")
187
- return new Tensor(op(tA.value));
232
+ // Handle scalar case
233
+ if (tA.shape.length === 0)
234
+ return new Tensor(op(tA.value[0]), {
235
+ shape: [],
236
+ strides: [],
237
+ offset: 0,
238
+ numel: 1,
239
+ device: tA.device,
240
+ dtype: tA.dtype
241
+ });
188
242
  const contiguous = tA.isContiguous();
189
243
  const outputShape = tA.shape;
190
244
  const outputStrides = contiguous ? tA.strides : Tensor.getStrides(outputShape);
191
245
  const outputSize = tA.numel;
192
- const outputValue = new Array(outputSize);
246
+ const outputValue = new dtype_1.TypedArray[tA.dtype](outputSize);
193
247
  if (contiguous) {
194
248
  for (let index = 0; index < outputSize; index++) {
195
249
  outputValue[index] = op(tA.value[index + tA.offset]);
@@ -202,7 +256,14 @@ class Tensor {
202
256
  outputValue[index] = op(tA.value[originalIndex]);
203
257
  }
204
258
  }
205
- return new Tensor(outputValue, { shape: outputShape, strides: outputStrides, numel: tA.numel });
259
+ return new Tensor(outputValue, {
260
+ shape: outputShape,
261
+ strides: outputStrides,
262
+ offset: 0,
263
+ numel: tA.numel,
264
+ device: tA.device,
265
+ dtype: tA.dtype
266
+ });
206
267
  }
207
268
  // Utility to do element-wise operation and build a dag node with another tensor
208
269
  elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) {
@@ -246,16 +307,6 @@ class Tensor {
246
307
  }
247
308
  return out;
248
309
  }
249
- // Utility to handle other tensor if an op needs a second operand
250
- handleOther(other) {
251
- if (other instanceof Tensor) {
252
- if (this.device !== other.device) {
253
- throw new Error("Can not operate on tensors that are not on the same device");
254
- }
255
- return other;
256
- }
257
- return new Tensor(other, { device: this.device });
258
- }
259
310
  // Utility to add to gradient of tensor
260
311
  static addGrad(tensor, accumGrad) {
261
312
  const axesToSqueeze = [];
@@ -278,7 +329,7 @@ class Tensor {
278
329
  tensor.grad = squeezedGrad;
279
330
  }
280
331
  else {
281
- tensor.grad = tensor.grad.add(squeezedGrad);
332
+ tensor.grad = tensor.grad.add(squeezedGrad.cast(tensor.dtype));
282
333
  }
283
334
  }
284
335
  static normalizeDims(dims, numDims) {
@@ -306,20 +357,27 @@ class Tensor {
306
357
  }
307
358
  contiguous() {
308
359
  // Check if scalar
309
- if (typeof this.value === "number")
360
+ if (this.shape.length === 0)
310
361
  return this;
311
362
  // Check if already contiguous
312
363
  if (this.isContiguous())
313
364
  return this;
314
365
  const outputStrides = Tensor.getStrides(this.shape);
315
366
  const outputSize = this.numel;
316
- const outputValue = new Array(outputSize);
367
+ const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
317
368
  for (let index = 0; index < outputSize; index++) {
318
369
  const outputCoords = Tensor.indexToCoords(index, outputStrides);
319
370
  const originalIndex = Tensor.coordsToIndex(outputCoords, this.strides);
320
371
  outputValue[index] = this.value[this.offset + originalIndex];
321
372
  }
322
- const out = new Tensor(outputValue, { shape: this.shape, strides: outputStrides, numel: outputSize });
373
+ const out = new Tensor(outputValue, {
374
+ shape: this.shape,
375
+ strides: outputStrides,
376
+ offset: 0,
377
+ numel: outputSize,
378
+ device: this.device,
379
+ dtype: this.dtype
380
+ });
323
381
  // Gradient flow back to the original tensor
324
382
  if (this.requiresGrad) {
325
383
  out.requiresGrad = true;
@@ -334,7 +392,7 @@ class Tensor {
334
392
  // Verify shape size
335
393
  const originalSize = this.numel;
336
394
  const outputSize = Tensor.shapeToSize(newShape);
337
- if (originalSize !== outputSize || typeof this.value === "number") {
395
+ if (originalSize !== outputSize) {
338
396
  throw new Error("Can not create view: incompatible sizes");
339
397
  }
340
398
  // Verify compatibility (only contiguity for now)
@@ -347,7 +405,8 @@ class Tensor {
347
405
  strides: outputStrides,
348
406
  offset: this.offset,
349
407
  numel: outputSize,
350
- device: this.device
408
+ device: this.device,
409
+ dtype: this.dtype
351
410
  });
352
411
  // Gradient reshaped and flow back to the original tensor
353
412
  if (this.requiresGrad) {
@@ -423,7 +482,8 @@ class Tensor {
423
482
  strides: newStrides,
424
483
  offset: this.offset,
425
484
  numel: this.numel,
426
- device: this.device
485
+ device: this.device,
486
+ dtype: this.dtype
427
487
  });
428
488
  out.requiresGrad = this.requiresGrad;
429
489
  // Handle gradient if needed
@@ -464,7 +524,8 @@ class Tensor {
464
524
  strides: newStrides,
465
525
  offset: this.offset,
466
526
  numel: this.numel,
467
- device: this.device
527
+ device: this.device,
528
+ dtype: this.dtype
468
529
  });
469
530
  if (this.requiresGrad) {
470
531
  out.requiresGrad = true;
@@ -484,7 +545,7 @@ class Tensor {
484
545
  }
485
546
  // Utility for indexing with array of indices
486
547
  indexWithArray(indices) {
487
- if (typeof this.value === "number")
548
+ if (this.shape.length === 0)
488
549
  return this;
489
550
  indices = Tensor.normalizeDims(indices, this.shape[0]);
490
551
  // Init necessary stuff for indexing
@@ -494,7 +555,7 @@ class Tensor {
494
555
  // Init output data
495
556
  const outputShape = [indices.length, ...reducedShape];
496
557
  const outputSize = Tensor.shapeToSize(outputShape);
497
- const outputValue = new Array(outputSize);
558
+ const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
498
559
  for (let i = 0; i < indices.length; i++) {
499
560
  const sourceRowIndex = indices[i];
500
561
  const targetStart = i * elementsPerIndex;
@@ -507,7 +568,10 @@ class Tensor {
507
568
  }
508
569
  const out = new Tensor(outputValue, {
509
570
  shape: outputShape,
510
- numel: outputSize
571
+ offset: 0,
572
+ numel: outputSize,
573
+ device: this.device,
574
+ dtype: this.dtype
511
575
  });
512
576
  // Handle gradient
513
577
  if (this.requiresGrad) {
@@ -536,13 +600,13 @@ class Tensor {
536
600
  // Tensor indexing
537
601
  index(indices) {
538
602
  const tensorIndices = this.handleOther(indices).clone();
539
- if (typeof tensorIndices.value === "number") {
540
- return this.indexWithArray([tensorIndices.value]).squeeze(0);
603
+ if (tensorIndices.shape.length === 0) {
604
+ return this.indexWithArray([tensorIndices.value[0]]).squeeze(0);
541
605
  }
542
606
  else {
543
607
  const originalShape = tensorIndices.shape;
544
608
  const flatIndices = tensorIndices.value;
545
- const result = this.indexWithArray(flatIndices);
609
+ const result = this.indexWithArray(Array.from(flatIndices));
546
610
  // Reshape to preserve input shape
547
611
  const outputShape = [...originalShape, ...this.shape.slice(1)];
548
612
  return result.reshape(outputShape);
@@ -551,7 +615,7 @@ class Tensor {
551
615
  // Tensor slicing
552
616
  slice(ranges) {
553
617
  // Handle scalars
554
- if (typeof this.value === "number")
618
+ if (this.shape.length === 0)
555
619
  return this;
556
620
  const newShape = [];
557
621
  const newStrides = [];
@@ -589,7 +653,8 @@ class Tensor {
589
653
  shape: newShape,
590
654
  strides: newStrides,
591
655
  offset: newOffset,
592
- device: this.device
656
+ device: this.device,
657
+ dtype: this.dtype
593
658
  });
594
659
  if (this.requiresGrad) {
595
660
  out.requiresGrad = true;
@@ -651,7 +716,7 @@ class Tensor {
651
716
  expand(newShape) {
652
717
  // Handle scalars
653
718
  let self = this;
654
- if (typeof this.value === "number") {
719
+ if (this.shape.length === 0) {
655
720
  self = self.unsqueeze(0);
656
721
  }
657
722
  // Pad shapes to same length
@@ -675,8 +740,8 @@ class Tensor {
675
740
  shape: targetShape,
676
741
  strides: newStrides,
677
742
  offset: self.offset,
678
- numel: Tensor.shapeToSize(targetShape),
679
- device: self.device
743
+ device: self.device,
744
+ dtype: self.dtype
680
745
  });
681
746
  if (self.requiresGrad) {
682
747
  out.requiresGrad = true;
@@ -691,7 +756,7 @@ class Tensor {
691
756
  cat(other, dim = 0) {
692
757
  other = this.handleOther(other);
693
758
  // Handle scalars
694
- if (typeof this.value === "number" || typeof other.value === "number") {
759
+ if (this.shape.length === 0 || other.shape.length === 0) {
695
760
  throw new Error("Can not concatenate scalars");
696
761
  }
697
762
  // Handle negative indices
@@ -720,7 +785,8 @@ class Tensor {
720
785
  }
721
786
  const outputSize = Tensor.shapeToSize(outputShape);
722
787
  const outputStrides = Tensor.getStrides(outputShape);
723
- const outputValue = new Array(outputSize);
788
+ const outputDtype = Tensor.getResultDtype(this.dtype, other.dtype);
789
+ const outputValue = new dtype_1.TypedArray[outputDtype](outputSize);
724
790
  for (let outIndex = 0; outIndex < outputSize; outIndex++) {
725
791
  const coords = Tensor.indexToCoords(outIndex, outputStrides);
726
792
  // Check which tensor this output position comes from
@@ -740,7 +806,10 @@ class Tensor {
740
806
  const out = new Tensor(outputValue, {
741
807
  shape: outputShape,
742
808
  strides: outputStrides,
743
- numel: outputSize
809
+ offset: 0,
810
+ numel: outputSize,
811
+ device: this.device,
812
+ dtype: this.dtype
744
813
  });
745
814
  if (this.requiresGrad) {
746
815
  out.requiresGrad = true;
@@ -773,7 +842,7 @@ class Tensor {
773
842
  }
774
843
  // Tensor squeeze
775
844
  squeeze(dims) {
776
- if (typeof this.value === "number")
845
+ if (this.shape.length === 0)
777
846
  return this;
778
847
  if (typeof dims === "number") {
779
848
  dims = [dims];
@@ -807,7 +876,9 @@ class Tensor {
807
876
  shape: outShape,
808
877
  strides: outStrides,
809
878
  offset: this.offset,
810
- device: this.device
879
+ numel: this.numel,
880
+ device: this.device,
881
+ dtype: this.dtype
811
882
  });
812
883
  // Set up gradient if needed
813
884
  if (this.requiresGrad) {
@@ -830,9 +901,6 @@ class Tensor {
830
901
  dim += this.shape.length;
831
902
  }
832
903
  let thisValue = this.value;
833
- if (typeof thisValue === "number") {
834
- thisValue = [thisValue];
835
- }
836
904
  // Insert size-1 dimension at specified position
837
905
  const newShape = [...this.shape];
838
906
  newShape.splice(dim, 0, 1);
@@ -852,7 +920,9 @@ class Tensor {
852
920
  shape: newShape,
853
921
  strides: newStrides,
854
922
  offset: this.offset,
855
- device: this.device
923
+ numel: this.numel,
924
+ device: this.device,
925
+ dtype: this.dtype
856
926
  });
857
927
  // Set up gradient if needed
858
928
  if (this.requiresGrad) {
@@ -866,10 +936,13 @@ class Tensor {
866
936
  }
867
937
  // Generic reduction operation handler
868
938
  static reduce(tensor, dims, keepDims, config) {
869
- if (typeof tensor.value === "number")
939
+ if (tensor.shape.length === 0)
870
940
  return tensor;
871
941
  if (typeof dims === "undefined") {
872
- dims = Array.from({ length: tensor.shape.length }, (_, index) => index);
942
+ dims = new Array(tensor.shape.length);
943
+ for (let index = 0; index < dims.length; index++) {
944
+ dims[index] = index;
945
+ }
873
946
  }
874
947
  if (Array.isArray(dims)) {
875
948
  dims = Tensor.normalizeDims(dims, tensor.shape.length);
@@ -880,11 +953,11 @@ class Tensor {
880
953
  }
881
954
  return keepDims ? reducedThis : reducedThis.squeeze(dims);
882
955
  }
956
+ const dimSize = tensor.shape[dims];
883
957
  const outputShape = tensor.shape.map((dim, i) => dims === i ? 1 : dim);
884
958
  const outputStrides = Tensor.getStrides(outputShape);
885
- const outputSize = Tensor.shapeToSize(outputShape);
886
- const outputValue = new Array(outputSize).fill(config.identity);
887
- const outputCounters = config.needsCounters ? new Array(outputSize).fill(0) : [];
959
+ const outputSize = tensor.numel / dimSize;
960
+ const outputValue = new dtype_1.TypedArray[tensor.dtype](outputSize).fill(config.identity);
888
961
  const originalSize = tensor.numel;
889
962
  const originalValue = tensor.value;
890
963
  const linearStrides = Tensor.getStrides(tensor.shape);
@@ -899,24 +972,27 @@ class Tensor {
899
972
  const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides);
900
973
  // Apply op
901
974
  outputValue[outFlatIndex] = config.operation(outputValue[outFlatIndex], originalValue[realFlatIndex]);
902
- // Count el if needed
903
- if (config.needsCounters) {
904
- outputCounters[outFlatIndex]++;
905
- }
906
975
  }
907
976
  // Post-process if needed (e.g., divide by count for mean)
908
977
  if (config.postProcess) {
909
- config.postProcess({ values: outputValue, counters: outputCounters });
978
+ config.postProcess({ values: outputValue, dimSize });
910
979
  }
911
- const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides, numel: outputSize });
980
+ const out = new Tensor(outputValue, {
981
+ shape: outputShape,
982
+ strides: outputStrides,
983
+ offset: 0,
984
+ numel: outputSize,
985
+ device: tensor.device,
986
+ dtype: tensor.dtype
987
+ });
912
988
  // Gradient setup
913
989
  if (tensor.requiresGrad) {
914
990
  out.requiresGrad = true;
915
991
  out.children.push(tensor);
916
992
  out.gradFn = () => {
917
- let shareCounts = [];
993
+ let shareCounts = new dtype_1.TypedArray[tensor.dtype]();
918
994
  if (config.needsShareCounts) {
919
- shareCounts = new Array(outputSize).fill(0);
995
+ shareCounts = new dtype_1.TypedArray[tensor.dtype](outputSize).fill(0);
920
996
  for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
921
997
  // Convert linear index to coordinates using contiguous strides
922
998
  const coords = Tensor.indexToCoords(flatIndex, linearStrides);
@@ -929,7 +1005,7 @@ class Tensor {
929
1005
  shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
930
1006
  }
931
1007
  }
932
- const gradValue = new Array(originalSize);
1008
+ const gradValue = new dtype_1.TypedArray[tensor.dtype](originalSize);
933
1009
  for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
934
1010
  // Convert linear index to coordinates using contiguous strides
935
1011
  const coords = Tensor.indexToCoords(flatIndex, linearStrides);
@@ -941,13 +1017,19 @@ class Tensor {
941
1017
  gradValue[flatIndex] = config.gradientFn({
942
1018
  outputValue,
943
1019
  originalValue: tensor.value,
944
- counters: outputCounters,
1020
+ dimSize,
945
1021
  shareCounts,
946
1022
  realIndex: realFlatIndex,
947
1023
  outIndex: outFlatIndex
948
1024
  });
949
1025
  }
950
- const localGrad = new Tensor(gradValue, { shape: tensor.shape, numel: tensor.numel });
1026
+ const localGrad = new Tensor(gradValue, {
1027
+ shape: tensor.shape,
1028
+ offset: 0,
1029
+ numel: tensor.numel,
1030
+ device: tensor.device,
1031
+ dtype: tensor.dtype
1032
+ });
951
1033
  Tensor.addGrad(tensor, out.grad.mul(localGrad));
952
1034
  };
953
1035
  }
@@ -972,13 +1054,12 @@ class Tensor {
972
1054
  return Tensor.reduce(this, dims, keepDims, {
973
1055
  identity: 0,
974
1056
  operation: (a, b) => a + b,
975
- needsCounters: true,
976
- postProcess: ({ values, counters }) => {
1057
+ postProcess: ({ values, dimSize }) => {
977
1058
  for (let i = 0; i < values.length; i++) {
978
- values[i] /= counters[i];
1059
+ values[i] /= dimSize;
979
1060
  }
980
1061
  },
981
- gradientFn: ({ counters, outIndex }) => 1 / counters[outIndex]
1062
+ gradientFn: ({ dimSize }) => 1 / dimSize
982
1063
  });
983
1064
  }
984
1065
  max(dims, keepDims = false) {
@@ -1017,7 +1098,7 @@ class Tensor {
1017
1098
  }
1018
1099
  // Tensor softmax
1019
1100
  softmax(dim = -1) {
1020
- if (typeof this.value === "number")
1101
+ if (this.shape.length === 0)
1021
1102
  return this;
1022
1103
  // Handle negative indexing
1023
1104
  if (dim < 0) {
@@ -1035,7 +1116,7 @@ class Tensor {
1035
1116
  }
1036
1117
  // Tensor softmin
1037
1118
  softmin(dim = -1) {
1038
- if (typeof this.value === "number")
1119
+ if (this.shape.length === 0)
1039
1120
  return this;
1040
1121
  // Handle negative indexing
1041
1122
  if (dim < 0) {
@@ -1435,10 +1516,11 @@ class Tensor {
1435
1516
  const matBCols = other.shape[1];
1436
1517
  if (matACols !== matBRows)
1437
1518
  throw new Error("Invalid matrices shape for multiplication");
1519
+ const matCDtype = Tensor.getResultDtype(this.dtype, other.dtype);
1438
1520
  const matCShape = [matARows, matBCols];
1439
1521
  const matCStrides = Tensor.getStrides(matCShape);
1440
1522
  const matCSize = Tensor.shapeToSize(matCShape);
1441
- const matC = new Array(matCSize).fill(0);
1523
+ const matC = new dtype_1.TypedArray[matCDtype](matCSize).fill(0);
1442
1524
  for (let i = 0; i < matARows; i++) {
1443
1525
  for (let j = 0; j < matBCols; j++) {
1444
1526
  for (let k = 0; k < matACols; k++) {
@@ -1449,7 +1531,14 @@ class Tensor {
1449
1531
  }
1450
1532
  }
1451
1533
  }
1452
- const out = new Tensor(matC, { shape: matCShape, strides: matCStrides, numel: matCSize });
1534
+ const out = new Tensor(matC, {
1535
+ shape: matCShape,
1536
+ strides: matCStrides,
1537
+ offset: 0,
1538
+ numel: matCSize,
1539
+ device: this.device,
1540
+ dtype: matCDtype
1541
+ });
1453
1542
  if (this.requiresGrad) {
1454
1543
  out.requiresGrad = true;
1455
1544
  out.children.push(this);
@@ -1490,10 +1579,11 @@ class Tensor {
1490
1579
  const batchBCols = other.shape[2];
1491
1580
  if (batchACols !== batchBRows)
1492
1581
  throw new Error("Invalid matrices shape for multiplication");
1582
+ const batchCDtype = Tensor.getResultDtype(this.dtype, other.dtype);
1493
1583
  const batchCShape = [batchSize, batchARows, batchBCols];
1494
1584
  const batchCStrides = Tensor.getStrides(batchCShape);
1495
1585
  const batchCSize = Tensor.shapeToSize(batchCShape);
1496
- const batchC = new Array(batchCSize).fill(0);
1586
+ const batchC = new dtype_1.TypedArray[batchCDtype](batchCSize).fill(0);
1497
1587
  for (let q = 0; q < batchSize; q++) {
1498
1588
  for (let i = 0; i < batchARows; i++) {
1499
1589
  for (let j = 0; j < batchBCols; j++) {
@@ -1506,7 +1596,14 @@ class Tensor {
1506
1596
  }
1507
1597
  }
1508
1598
  }
1509
- const out = new Tensor(batchC, { shape: batchCShape, strides: batchCStrides, numel: batchCSize });
1599
+ const out = new Tensor(batchC, {
1600
+ shape: batchCShape,
1601
+ strides: batchCStrides,
1602
+ offset: 0,
1603
+ numel: batchCSize,
1604
+ device: this.device,
1605
+ dtype: batchCDtype
1606
+ });
1510
1607
  if (this.requiresGrad) {
1511
1608
  out.requiresGrad = true;
1512
1609
  out.children.push(this);
@@ -1583,10 +1680,11 @@ class Tensor {
1583
1680
  const offsetSize = Tensor.shapeToSize(offsetShape);
1584
1681
  const offsetStrides = Tensor.getStrides(offsetShape);
1585
1682
  // Output shape, strides, size, value
1683
+ const outputDtype = Tensor.getResultDtype(this.dtype, other.dtype);
1586
1684
  const outputShape = [...offsetShape, batchARows, batchBCols];
1587
1685
  const outputStrides = Tensor.getStrides(outputShape);
1588
1686
  const outputSize = Tensor.shapeToSize(outputShape);
1589
- const outputValue = new Array(outputSize).fill(0);
1687
+ const outputValue = new dtype_1.TypedArray[outputDtype](outputSize).fill(0);
1590
1688
  const outputOffsetStrides = outputStrides.slice(0, -2);
1591
1689
  // Loop through outer dims and do matmul on two outer-most dims
1592
1690
  for (let index = 0; index < offsetSize; index++) {
@@ -1605,7 +1703,14 @@ class Tensor {
1605
1703
  }
1606
1704
  }
1607
1705
  }
1608
- const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides, numel: outputSize });
1706
+ const out = new Tensor(outputValue, {
1707
+ shape: outputShape,
1708
+ strides: outputStrides,
1709
+ offset: 0,
1710
+ numel: outputSize,
1711
+ device: this.device,
1712
+ dtype: outputDtype
1713
+ });
1609
1714
  if (this.requiresGrad) {
1610
1715
  out.requiresGrad = true;
1611
1716
  out.children.push(this);
@@ -1647,7 +1752,7 @@ class Tensor {
1647
1752
  const maskShape = this.shape.slice(-2);
1648
1753
  const maskStrides = Tensor.getStrides(maskShape);
1649
1754
  const maskSize = Tensor.shapeToSize(maskShape);
1650
- const maskValue = new Array(maskSize).fill(1);
1755
+ const maskValue = new dtype_1.TypedArray[this.dtype](maskSize).fill(1);
1651
1756
  const [rows, cols] = maskShape;
1652
1757
  for (let i = 0; i < rows; i++) {
1653
1758
  const maxJ = Math.min(cols, i + diagonal);
@@ -1658,8 +1763,10 @@ class Tensor {
1658
1763
  const mask = new Tensor(maskValue, {
1659
1764
  shape: maskShape,
1660
1765
  strides: maskStrides,
1766
+ offset: 0,
1661
1767
  numel: maskSize,
1662
- device: this.device
1768
+ device: this.device,
1769
+ dtype: this.dtype
1663
1770
  });
1664
1771
  return this.mul(mask);
1665
1772
  }
@@ -1671,7 +1778,7 @@ class Tensor {
1671
1778
  const maskShape = this.shape.slice(-2);
1672
1779
  const maskStrides = Tensor.getStrides(maskShape);
1673
1780
  const maskSize = Tensor.shapeToSize(maskShape);
1674
- const maskValue = new Array(maskSize).fill(0);
1781
+ const maskValue = new dtype_1.TypedArray[this.dtype](maskSize).fill(0);
1675
1782
  const [rows, cols] = maskShape;
1676
1783
  for (let i = 0; i < rows; i++) {
1677
1784
  const maxJ = Math.min(cols, i + diagonal + 1);
@@ -1682,8 +1789,10 @@ class Tensor {
1682
1789
  const mask = new Tensor(maskValue, {
1683
1790
  shape: maskShape,
1684
1791
  strides: maskStrides,
1792
+ offset: 0,
1685
1793
  numel: maskSize,
1686
- device: this.device
1794
+ device: this.device,
1795
+ dtype: this.dtype
1687
1796
  });
1688
1797
  return this.mul(mask);
1689
1798
  }
@@ -1698,16 +1807,28 @@ class Tensor {
1698
1807
  return new Tensor(num, options);
1699
1808
  const outputSize = Tensor.shapeToSize(shape);
1700
1809
  const outputValue = new Array(outputSize).fill(num);
1701
- return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1810
+ return new Tensor(outputValue, {
1811
+ shape,
1812
+ offset: 0,
1813
+ numel: outputSize,
1814
+ ...options
1815
+ });
1702
1816
  }
1703
1817
  // Utility to create a new tensor with shape of another tensor, filled with a number
1704
1818
  static fullLike(tensor, num, options = {}) {
1705
- if (typeof tensor.value === "number")
1706
- return new Tensor(num, options);
1819
+ if (tensor.shape.length === 0)
1820
+ return new Tensor(num, {
1821
+ offset: 0,
1822
+ device: tensor.device,
1823
+ dtype: tensor.dtype,
1824
+ ...options
1825
+ });
1707
1826
  return new Tensor(new Array(tensor.numel).fill(num), {
1708
1827
  shape: tensor.shape,
1828
+ offset: 0,
1709
1829
  numel: tensor.numel,
1710
1830
  device: tensor.device,
1831
+ dtype: tensor.dtype,
1711
1832
  ...options
1712
1833
  });
1713
1834
  }
@@ -1717,16 +1838,28 @@ class Tensor {
1717
1838
  return new Tensor(1, options);
1718
1839
  const outputSize = Tensor.shapeToSize(shape);
1719
1840
  const outputValue = new Array(outputSize).fill(1);
1720
- return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1841
+ return new Tensor(outputValue, {
1842
+ shape,
1843
+ offset: 0,
1844
+ numel: outputSize,
1845
+ ...options
1846
+ });
1721
1847
  }
1722
1848
  // Utility to create a new tensor with shape of another tensor, filled with 1
1723
1849
  static onesLike(tensor, options = {}) {
1724
- if (typeof tensor.value === "number")
1725
- return new Tensor(1, options);
1850
+ if (tensor.shape.length === 0)
1851
+ return new Tensor(1, {
1852
+ offset: 0,
1853
+ device: tensor.device,
1854
+ dtype: tensor.dtype,
1855
+ ...options
1856
+ });
1726
1857
  return new Tensor(new Array(tensor.numel).fill(1), {
1727
1858
  shape: tensor.shape,
1859
+ offset: 0,
1728
1860
  numel: tensor.numel,
1729
1861
  device: tensor.device,
1862
+ dtype: tensor.dtype,
1730
1863
  ...options
1731
1864
  });
1732
1865
  }
@@ -1736,16 +1869,28 @@ class Tensor {
1736
1869
  return new Tensor(0, options);
1737
1870
  const outputSize = Tensor.shapeToSize(shape);
1738
1871
  const outputValue = new Array(outputSize).fill(0);
1739
- return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1872
+ return new Tensor(outputValue, {
1873
+ shape,
1874
+ offset: 0,
1875
+ numel: outputSize,
1876
+ ...options
1877
+ });
1740
1878
  }
1741
1879
  // Utility to create a new tensor with shape of another tensor, filled with 0
1742
1880
  static zerosLike(tensor, options = {}) {
1743
- if (typeof tensor.value === "number")
1744
- return new Tensor(0, options);
1881
+ if (tensor.shape.length === 0)
1882
+ return new Tensor(0, {
1883
+ offset: 0,
1884
+ device: tensor.device,
1885
+ dtype: tensor.dtype,
1886
+ ...options
1887
+ });
1745
1888
  return new Tensor(new Array(tensor.numel).fill(0), {
1746
1889
  shape: tensor.shape,
1890
+ offset: 0,
1747
1891
  numel: tensor.numel,
1748
1892
  device: tensor.device,
1893
+ dtype: tensor.dtype,
1749
1894
  ...options
1750
1895
  });
1751
1896
  }
@@ -1758,20 +1903,32 @@ class Tensor {
1758
1903
  for (let index = 0; index < outputValue.length; index++) {
1759
1904
  outputValue[index] = (0, utils_1.randUniform)();
1760
1905
  }
1761
- return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1906
+ return new Tensor(outputValue, {
1907
+ shape,
1908
+ offset: 0,
1909
+ numel: outputSize,
1910
+ ...options
1911
+ });
1762
1912
  }
1763
1913
  // Utility to create a new tensor with shape of another tensor, filled with a random number with uniform distribution from 0 to 1
1764
1914
  static randLike(tensor, options = {}) {
1765
- if (typeof tensor.value === "number")
1766
- return new Tensor((0, utils_1.randUniform)(), options);
1915
+ if (tensor.shape.length === 0)
1916
+ return new Tensor((0, utils_1.randUniform)(), {
1917
+ offset: 0,
1918
+ device: tensor.device,
1919
+ dtype: tensor.dtype,
1920
+ ...options
1921
+ });
1767
1922
  const outputValue = new Array(tensor.numel);
1768
1923
  for (let index = 0; index < outputValue.length; index++) {
1769
1924
  outputValue[index] = (0, utils_1.randUniform)();
1770
1925
  }
1771
1926
  return new Tensor(outputValue, {
1772
1927
  shape: tensor.shape,
1928
+ offset: 0,
1773
1929
  numel: tensor.numel,
1774
1930
  device: tensor.device,
1931
+ dtype: tensor.dtype,
1775
1932
  ...options
1776
1933
  });
1777
1934
  }
@@ -1784,20 +1941,32 @@ class Tensor {
1784
1941
  for (let index = 0; index < outputValue.length; index++) {
1785
1942
  outputValue[index] = (0, utils_1.randNormal)();
1786
1943
  }
1787
- return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1944
+ return new Tensor(outputValue, {
1945
+ shape,
1946
+ offset: 0,
1947
+ numel: outputSize,
1948
+ ...options
1949
+ });
1788
1950
  }
1789
1951
  // Utility to create a new tensor with shape of another tensor, filled with a random number with normal distribution of mean=0 and stddev=1
1790
1952
  static randnLike(tensor, options = {}) {
1791
- if (typeof tensor.value === "number")
1792
- return new Tensor((0, utils_1.randNormal)(), options);
1953
+ if (tensor.shape.length === 0)
1954
+ return new Tensor((0, utils_1.randNormal)(), {
1955
+ offset: 0,
1956
+ device: tensor.device,
1957
+ dtype: tensor.dtype,
1958
+ ...options
1959
+ });
1793
1960
  const outputValue = new Array(tensor.numel);
1794
1961
  for (let index = 0; index < outputValue.length; index++) {
1795
1962
  outputValue[index] = (0, utils_1.randNormal)();
1796
1963
  }
1797
1964
  return new Tensor(outputValue, {
1798
1965
  shape: tensor.shape,
1966
+ offset: 0,
1799
1967
  numel: tensor.numel,
1800
1968
  device: tensor.device,
1969
+ dtype: tensor.dtype,
1801
1970
  ...options
1802
1971
  });
1803
1972
  }
@@ -1810,20 +1979,32 @@ class Tensor {
1810
1979
  for (let index = 0; index < outputValue.length; index++) {
1811
1980
  outputValue[index] = (0, utils_1.randInt)(low, high);
1812
1981
  }
1813
- return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1982
+ return new Tensor(outputValue, {
1983
+ shape,
1984
+ offset: 0,
1985
+ numel: outputSize,
1986
+ ...options
1987
+ });
1814
1988
  }
1815
1989
  // Utility to create a new tensor with shape of another tensor, filled with a random integer between low and high
1816
1990
  static randintLike(tensor, low, high, options = {}) {
1817
- if (typeof tensor.value === "number")
1818
- return new Tensor((0, utils_1.randInt)(low, high), options);
1991
+ if (tensor.shape.length === 0)
1992
+ return new Tensor((0, utils_1.randInt)(low, high), {
1993
+ offset: 0,
1994
+ device: tensor.device,
1995
+ dtype: tensor.dtype,
1996
+ ...options
1997
+ });
1819
1998
  const outputValue = new Array(tensor.numel);
1820
1999
  for (let index = 0; index < outputValue.length; index++) {
1821
2000
  outputValue[index] = (0, utils_1.randInt)(low, high);
1822
2001
  }
1823
2002
  return new Tensor(outputValue, {
1824
2003
  shape: tensor.shape,
2004
+ offset: 0,
1825
2005
  numel: tensor.numel,
1826
2006
  device: tensor.device,
2007
+ dtype: tensor.dtype,
1827
2008
  ...options
1828
2009
  });
1829
2010
  }
@@ -1834,7 +2015,12 @@ class Tensor {
1834
2015
  outputValue[i] = i;
1835
2016
  }
1836
2017
  (0, utils_1.fyShuffle)(outputValue);
1837
- return new Tensor(outputValue, { shape: [n], numel: n, ...options });
2018
+ return new Tensor(outputValue, {
2019
+ shape: [n],
2020
+ offset: 0,
2021
+ numel: n,
2022
+ ...options
2023
+ });
1838
2024
  }
1839
2025
  // Utility to create a new tensor filled with a random number with normal distribution of custom mean and stddev
1840
2026
  static normal(shape, mean, stdDev, options = {}) {
@@ -1845,7 +2031,12 @@ class Tensor {
1845
2031
  for (let index = 0; index < outputValue.length; index++) {
1846
2032
  outputValue[index] = (0, utils_1.randNormal)(mean, stdDev);
1847
2033
  }
1848
- return new Tensor(outputValue, { shape, numel: outputSize, ...options });
2034
+ return new Tensor(outputValue, {
2035
+ shape,
2036
+ offset: 0,
2037
+ numel: outputSize,
2038
+ ...options
2039
+ });
1849
2040
  }
1850
2041
  // Utility to create a new tensor filled with a random number with uniform distribution from low to high
1851
2042
  static uniform(shape, low, high, options = {}) {
@@ -1856,7 +2047,12 @@ class Tensor {
1856
2047
  for (let index = 0; index < outputValue.length; index++) {
1857
2048
  outputValue[index] = (0, utils_1.randUniform)(low, high);
1858
2049
  }
1859
- return new Tensor(outputValue, { shape, numel: outputSize, ...options });
2050
+ return new Tensor(outputValue, {
2051
+ shape,
2052
+ offset: 0,
2053
+ numel: outputSize,
2054
+ ...options
2055
+ });
1860
2056
  }
1861
2057
  // Utility to create an 1D tensor from a range incrementing with "step"
1862
2058
  static arange(start, stop, step = 1, options = {}) {
@@ -1870,7 +2066,12 @@ class Tensor {
1870
2066
  for (let index = 0; index < outputValue.length; index++) {
1871
2067
  outputValue[index] = start + step * index;
1872
2068
  }
1873
- return new Tensor(outputValue, { shape: outputShape, numel: outputSize, ...options });
2069
+ return new Tensor(outputValue, {
2070
+ shape: outputShape,
2071
+ offset: 0,
2072
+ numel: outputSize,
2073
+ ...options
2074
+ });
1874
2075
  }
1875
2076
  // Utility to create an 1D tensor from a range evenly spaced out with a given amount of steps
1876
2077
  static linspace(start, stop, steps, options = {}) {
@@ -1886,7 +2087,12 @@ class Tensor {
1886
2087
  }
1887
2088
  // Ensure we hit the endpoint exactly (avoids floating point errors)
1888
2089
  outputValue[steps - 1] = stop;
1889
- return new Tensor(outputValue, { shape: [steps], numel: steps, ...options });
2090
+ return new Tensor(outputValue, {
2091
+ shape: [steps],
2092
+ offset: 0,
2093
+ numel: steps,
2094
+ ...options
2095
+ });
1890
2096
  }
1891
2097
  // Utility to create a 2D tensor with its main diagonal filled with 1s and others with 0s
1892
2098
  static eye(n, m = n, options = {}) {
@@ -1897,7 +2103,13 @@ class Tensor {
1897
2103
  for (let i = 0; i < Math.min(n, m); i++) {
1898
2104
  outputValue[i * outputStrides[0] + i * outputStrides[1]] = 1;
1899
2105
  }
1900
- return new Tensor(outputValue, { shape: outputShape, strides: outputStrides, numel: outputSize, ...options });
2106
+ return new Tensor(outputValue, {
2107
+ shape: outputShape,
2108
+ offset: 0,
2109
+ strides: outputStrides,
2110
+ numel: outputSize,
2111
+ ...options
2112
+ });
1901
2113
  }
1902
2114
  // Reverse-mode autodiff call
1903
2115
  backward(options = {}) {
@@ -1928,8 +2140,8 @@ class Tensor {
1928
2140
  }
1929
2141
  // Returns the raw number/nD array form of tensor
1930
2142
  val() {
1931
- if (typeof this.value === "number")
1932
- return this.value;
2143
+ if (this.shape.length === 0)
2144
+ return this.value[0];
1933
2145
  function buildNested(data, shape, strides, baseIndex = 0, dim = 0) {
1934
2146
  if (dim === shape.length - 1) {
1935
2147
  // Last dimension: extract elements using actual stride
@@ -1956,20 +2168,28 @@ class Tensor {
1956
2168
  offset: this.offset,
1957
2169
  numel: this.numel,
1958
2170
  device: this.device,
2171
+ dtype: this.dtype,
1959
2172
  requiresGrad: false
1960
2173
  });
1961
2174
  }
1962
2175
  // Returns a copy of the tensor (with new data allocation) and keeps grad connection
1963
2176
  clone() {
1964
2177
  let out;
1965
- if (typeof this.value === "number") {
1966
- out = new Tensor(this.value);
2178
+ if (this.shape.length === 0) {
2179
+ out = new Tensor(this.value, {
2180
+ shape: [],
2181
+ strides: [],
2182
+ offset: 0,
2183
+ numel: 1,
2184
+ device: this.device,
2185
+ dtype: this.dtype
2186
+ });
1967
2187
  }
1968
2188
  else {
1969
2189
  const contiguous = this.isContiguous();
1970
2190
  const outputStrides = contiguous ? this.strides : Tensor.getStrides(this.shape);
1971
2191
  const outputSize = this.numel;
1972
- const outputValue = new Array(outputSize);
2192
+ const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
1973
2193
  if (contiguous) {
1974
2194
  for (let index = 0; index < outputSize; index++) {
1975
2195
  outputValue[index] = this.value[this.offset + index];
@@ -1982,7 +2202,14 @@ class Tensor {
1982
2202
  outputValue[index] = this.value[this.offset + originalIndex];
1983
2203
  }
1984
2204
  }
1985
- out = new Tensor(outputValue, { shape: this.shape, strides: outputStrides, numel: outputSize });
2205
+ out = new Tensor(outputValue, {
2206
+ shape: this.shape,
2207
+ strides: outputStrides,
2208
+ offset: 0,
2209
+ numel: outputSize,
2210
+ device: this.device,
2211
+ dtype: this.dtype
2212
+ });
1986
2213
  }
1987
2214
  if (this.requiresGrad) {
1988
2215
  out.requiresGrad = true;
@@ -2009,8 +2236,23 @@ class Tensor {
2009
2236
  this.value = other.value;
2010
2237
  this.strides = other.strides;
2011
2238
  this.offset = other.offset;
2239
+ this.device = other.device;
2240
+ this.dtype = other.dtype;
2012
2241
  return this;
2013
2242
  }
2243
+ // Op to return a new tensor casted to another dtype
2244
+ cast(dtype) {
2245
+ if (this.dtype === dtype)
2246
+ return this;
2247
+ return new Tensor(this.value, {
2248
+ shape: this.shape,
2249
+ strides: this.strides,
2250
+ offset: this.offset,
2251
+ numel: this.numel,
2252
+ device: this.device,
2253
+ dtype: dtype
2254
+ });
2255
+ }
2014
2256
  // Holds all available backends
2015
2257
  static backends = new Map();
2016
2258
  // Op to transfer tensor to another device