catniff 0.7.3 → 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.js CHANGED
@@ -1,6 +1,7 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
3
  exports.Tensor = void 0;
4
+ const dtype_1 = require("./dtype");
4
5
  const utils_1 = require("./utils");
5
6
  class Tensor {
6
7
  value;
@@ -13,12 +14,16 @@ class Tensor {
13
14
  gradFn;
14
15
  children;
15
16
  device;
17
+ dtype;
16
18
  static training = false;
17
19
  static noGrad = false;
18
20
  static createGraph = false;
19
21
  constructor(value, options = {}) {
20
- // Storage
21
- this.value = Tensor.flattenValue(value);
22
+ // Memory buffer
23
+ this.dtype = options.dtype || "float32";
24
+ const flatValue = Tensor.flattenValue(value);
25
+ const TypedArrayConstructor = dtype_1.TypedArray[this.dtype];
26
+ this.value = flatValue instanceof TypedArrayConstructor ? flatValue : TypedArrayConstructor.from(flatValue);
22
27
  // Tensor metadata
23
28
  this.shape = options.shape || Tensor.getShape(value);
24
29
  this.strides = options.strides || Tensor.getStrides(this.shape);
@@ -34,31 +39,34 @@ class Tensor {
34
39
  this.to_(this.device);
35
40
  }
36
41
  // Utility to flatten an nD array to be 1D
37
- static flattenValue(tensor) {
42
+ static flattenValue(tensorValue) {
38
43
  // Handle scalar tensors
39
- if (typeof tensor === "number")
40
- return tensor;
44
+ if (typeof tensorValue === "number")
45
+ return [tensorValue];
41
46
  // If value is already 1D, we just need to return the value ('s reference)
42
- if (typeof tensor[0] === "number")
43
- return tensor;
47
+ if (typeof tensorValue[0] === "number")
48
+ return tensorValue;
44
49
  // Or else recursively traverse through the nD array to flatten
45
50
  const result = [];
46
51
  function traverse(arr) {
47
52
  if (typeof arr === "number") {
48
53
  result.push(arr);
54
+ // Assume if we can index a value, it is an ArrayLike
49
55
  }
50
- else if (Array.isArray(arr)) {
51
- arr.forEach(traverse);
56
+ else if (typeof arr[0] !== "undefined") {
57
+ for (let index = 0; index < arr.length; index++) {
58
+ traverse(arr[index]);
59
+ }
52
60
  }
53
61
  }
54
- traverse(tensor);
62
+ traverse(tensorValue);
55
63
  return result;
56
64
  }
57
65
  // Utility to get shape from tensor *value*
58
- static getShape(tensor) {
66
+ static getShape(tensorValue) {
59
67
  const shape = [];
60
- let subA = tensor;
61
- while (Array.isArray(subA)) {
68
+ let subA = tensorValue;
69
+ while (typeof subA !== "number") {
62
70
  shape.push(subA.length);
63
71
  subA = subA[0];
64
72
  }
@@ -146,17 +154,52 @@ class Tensor {
146
154
  }
147
155
  return prod;
148
156
  }
149
- ;
157
+ // Utility to get best possible result type if type conflicts happen:
158
+ static getResultDtype(type1, type2) {
159
+ if (type1 === type2)
160
+ return type1;
161
+ const type1Ranking = dtype_1.dtypeHiearchy[type1];
162
+ const type2Ranking = dtype_1.dtypeHiearchy[type2];
163
+ if (type1Ranking > type2Ranking) {
164
+ return type1;
165
+ }
166
+ return type2;
167
+ }
168
+ // Utility to handle other tensor if an op needs a second operand
169
+ handleOther(other) {
170
+ if (other instanceof Tensor) {
171
+ if (this.device !== other.device) {
172
+ throw new Error("Can not operate on tensors that are not on the same device");
173
+ }
174
+ return other;
175
+ }
176
+ return new Tensor(other, {
177
+ offset: 0,
178
+ device: this.device,
179
+ dtype: this.dtype
180
+ });
181
+ }
150
182
  // Utility for binary (two operators involved) element-wise ops
151
183
  static elementWiseAB(tA, tB, op) {
152
- if (typeof tA.value === "number" && typeof tB.value === "number") {
153
- return new Tensor(op(tA.value, tB.value));
184
+ const outputDtype = Tensor.getResultDtype(tA.dtype, tB.dtype);
185
+ // Both are scalars
186
+ if (tA.shape.length === 0 && tB.shape.length === 0) {
187
+ return new Tensor(op(tA.value[0], tB.value[0]), {
188
+ shape: [],
189
+ strides: [],
190
+ offset: 0,
191
+ numel: 1,
192
+ device: tA.device,
193
+ dtype: outputDtype
194
+ });
154
195
  }
155
- if (typeof tA.value === "number") {
156
- return Tensor.elementWiseSelf(tB, (a) => op(a, tA.value));
196
+ // First tensor is scalar
197
+ if (tA.shape.length === 0) {
198
+ return Tensor.elementWiseSelf(tB.cast(outputDtype), (a) => op(a, tA.value[0]));
157
199
  }
158
- if (typeof tB.value === "number") {
159
- return Tensor.elementWiseSelf(tA, (a) => op(a, tB.value));
200
+ // Second tensor is scalar
201
+ if (tB.shape.length === 0) {
202
+ return Tensor.elementWiseSelf(tA.cast(outputDtype), (a) => op(a, tB.value[0]));
160
203
  }
161
204
  // Pad + broadcast shape
162
205
  const [paddedAStrides, paddedBStrides, paddedAShape, paddedBShape] = Tensor.padShape(tA.strides, tB.strides, tA.shape, tB.shape);
@@ -164,7 +207,7 @@ class Tensor {
164
207
  // Get other output info
165
208
  const outputStrides = Tensor.getStrides(outputShape);
166
209
  const outputSize = Tensor.shapeToSize(outputShape);
167
- const outputValue = new Array(outputSize);
210
+ const outputValue = new dtype_1.TypedArray[outputDtype](outputSize);
168
211
  for (let i = 0; i < outputSize; i++) {
169
212
  // Get coordinates from 1D index
170
213
  const coordsOutput = Tensor.indexToCoords(i, outputStrides);
@@ -178,18 +221,29 @@ class Tensor {
178
221
  return new Tensor(outputValue, {
179
222
  shape: outputShape,
180
223
  strides: outputStrides,
181
- numel: outputSize
224
+ offset: 0,
225
+ numel: outputSize,
226
+ device: tA.device,
227
+ dtype: outputDtype
182
228
  });
183
229
  }
184
230
  // Utility for self-inflicting element-wise ops
185
231
  static elementWiseSelf(tA, op) {
186
- if (typeof tA.value === "number")
187
- return new Tensor(op(tA.value));
232
+ // Handle scalar case
233
+ if (tA.shape.length === 0)
234
+ return new Tensor(op(tA.value[0]), {
235
+ shape: [],
236
+ strides: [],
237
+ offset: 0,
238
+ numel: 1,
239
+ device: tA.device,
240
+ dtype: tA.dtype
241
+ });
188
242
  const contiguous = tA.isContiguous();
189
243
  const outputShape = tA.shape;
190
244
  const outputStrides = contiguous ? tA.strides : Tensor.getStrides(outputShape);
191
245
  const outputSize = tA.numel;
192
- const outputValue = new Array(outputSize);
246
+ const outputValue = new dtype_1.TypedArray[tA.dtype](outputSize);
193
247
  if (contiguous) {
194
248
  for (let index = 0; index < outputSize; index++) {
195
249
  outputValue[index] = op(tA.value[index + tA.offset]);
@@ -202,7 +256,14 @@ class Tensor {
202
256
  outputValue[index] = op(tA.value[originalIndex]);
203
257
  }
204
258
  }
205
- return new Tensor(outputValue, { shape: outputShape, strides: outputStrides, numel: tA.numel });
259
+ return new Tensor(outputValue, {
260
+ shape: outputShape,
261
+ strides: outputStrides,
262
+ offset: 0,
263
+ numel: tA.numel,
264
+ device: tA.device,
265
+ dtype: tA.dtype
266
+ });
206
267
  }
207
268
  // Utility to do element-wise operation and build a dag node with another tensor
208
269
  elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) {
@@ -246,16 +307,6 @@ class Tensor {
246
307
  }
247
308
  return out;
248
309
  }
249
- // Utility to handle other tensor if an op needs a second operand
250
- handleOther(other) {
251
- if (other instanceof Tensor) {
252
- if (this.device !== other.device) {
253
- throw new Error("Can not operate on tensors that are not on the same device");
254
- }
255
- return other;
256
- }
257
- return new Tensor(other, { device: this.device });
258
- }
259
310
  // Utility to add to gradient of tensor
260
311
  static addGrad(tensor, accumGrad) {
261
312
  const axesToSqueeze = [];
@@ -278,7 +329,7 @@ class Tensor {
278
329
  tensor.grad = squeezedGrad;
279
330
  }
280
331
  else {
281
- tensor.grad = tensor.grad.add(squeezedGrad);
332
+ tensor.grad = tensor.grad.add(squeezedGrad.cast(tensor.dtype));
282
333
  }
283
334
  }
284
335
  static normalizeDims(dims, numDims) {
@@ -306,20 +357,27 @@ class Tensor {
306
357
  }
307
358
  contiguous() {
308
359
  // Check if scalar
309
- if (typeof this.value === "number")
360
+ if (this.shape.length === 0)
310
361
  return this;
311
362
  // Check if already contiguous
312
363
  if (this.isContiguous())
313
364
  return this;
314
365
  const outputStrides = Tensor.getStrides(this.shape);
315
366
  const outputSize = this.numel;
316
- const outputValue = new Array(outputSize);
367
+ const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
317
368
  for (let index = 0; index < outputSize; index++) {
318
369
  const outputCoords = Tensor.indexToCoords(index, outputStrides);
319
370
  const originalIndex = Tensor.coordsToIndex(outputCoords, this.strides);
320
371
  outputValue[index] = this.value[this.offset + originalIndex];
321
372
  }
322
- const out = new Tensor(outputValue, { shape: this.shape, strides: outputStrides, numel: outputSize });
373
+ const out = new Tensor(outputValue, {
374
+ shape: this.shape,
375
+ strides: outputStrides,
376
+ offset: 0,
377
+ numel: outputSize,
378
+ device: this.device,
379
+ dtype: this.dtype
380
+ });
323
381
  // Gradient flow back to the original tensor
324
382
  if (this.requiresGrad) {
325
383
  out.requiresGrad = true;
@@ -334,7 +392,7 @@ class Tensor {
334
392
  // Verify shape size
335
393
  const originalSize = this.numel;
336
394
  const outputSize = Tensor.shapeToSize(newShape);
337
- if (originalSize !== outputSize || typeof this.value === "number") {
395
+ if (originalSize !== outputSize) {
338
396
  throw new Error("Can not create view: incompatible sizes");
339
397
  }
340
398
  // Verify compatibility (only contiguity for now)
@@ -347,7 +405,8 @@ class Tensor {
347
405
  strides: outputStrides,
348
406
  offset: this.offset,
349
407
  numel: outputSize,
350
- device: this.device
408
+ device: this.device,
409
+ dtype: this.dtype
351
410
  });
352
411
  // Gradient reshaped and flow back to the original tensor
353
412
  if (this.requiresGrad) {
@@ -423,7 +482,8 @@ class Tensor {
423
482
  strides: newStrides,
424
483
  offset: this.offset,
425
484
  numel: this.numel,
426
- device: this.device
485
+ device: this.device,
486
+ dtype: this.dtype
427
487
  });
428
488
  out.requiresGrad = this.requiresGrad;
429
489
  // Handle gradient if needed
@@ -464,7 +524,8 @@ class Tensor {
464
524
  strides: newStrides,
465
525
  offset: this.offset,
466
526
  numel: this.numel,
467
- device: this.device
527
+ device: this.device,
528
+ dtype: this.dtype
468
529
  });
469
530
  if (this.requiresGrad) {
470
531
  out.requiresGrad = true;
@@ -484,7 +545,7 @@ class Tensor {
484
545
  }
485
546
  // Utility for indexing with array of indices
486
547
  indexWithArray(indices) {
487
- if (typeof this.value === "number")
548
+ if (this.shape.length === 0)
488
549
  return this;
489
550
  indices = Tensor.normalizeDims(indices, this.shape[0]);
490
551
  // Init necessary stuff for indexing
@@ -494,7 +555,7 @@ class Tensor {
494
555
  // Init output data
495
556
  const outputShape = [indices.length, ...reducedShape];
496
557
  const outputSize = Tensor.shapeToSize(outputShape);
497
- const outputValue = new Array(outputSize);
558
+ const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
498
559
  for (let i = 0; i < indices.length; i++) {
499
560
  const sourceRowIndex = indices[i];
500
561
  const targetStart = i * elementsPerIndex;
@@ -507,7 +568,10 @@ class Tensor {
507
568
  }
508
569
  const out = new Tensor(outputValue, {
509
570
  shape: outputShape,
510
- numel: outputSize
571
+ offset: 0,
572
+ numel: outputSize,
573
+ device: this.device,
574
+ dtype: this.dtype
511
575
  });
512
576
  // Handle gradient
513
577
  if (this.requiresGrad) {
@@ -536,13 +600,13 @@ class Tensor {
536
600
  // Tensor indexing
537
601
  index(indices) {
538
602
  const tensorIndices = this.handleOther(indices).clone();
539
- if (typeof tensorIndices.value === "number") {
540
- return this.indexWithArray([tensorIndices.value]).squeeze(0);
603
+ if (tensorIndices.shape.length === 0) {
604
+ return this.indexWithArray([tensorIndices.value[0]]).squeeze(0);
541
605
  }
542
606
  else {
543
607
  const originalShape = tensorIndices.shape;
544
608
  const flatIndices = tensorIndices.value;
545
- const result = this.indexWithArray(flatIndices);
609
+ const result = this.indexWithArray(Array.from(flatIndices));
546
610
  // Reshape to preserve input shape
547
611
  const outputShape = [...originalShape, ...this.shape.slice(1)];
548
612
  return result.reshape(outputShape);
@@ -551,7 +615,7 @@ class Tensor {
551
615
  // Tensor slicing
552
616
  slice(ranges) {
553
617
  // Handle scalars
554
- if (typeof this.value === "number")
618
+ if (this.shape.length === 0)
555
619
  return this;
556
620
  const newShape = [];
557
621
  const newStrides = [];
@@ -589,7 +653,8 @@ class Tensor {
589
653
  shape: newShape,
590
654
  strides: newStrides,
591
655
  offset: newOffset,
592
- device: this.device
656
+ device: this.device,
657
+ dtype: this.dtype
593
658
  });
594
659
  if (this.requiresGrad) {
595
660
  out.requiresGrad = true;
@@ -651,7 +716,7 @@ class Tensor {
651
716
  expand(newShape) {
652
717
  // Handle scalars
653
718
  let self = this;
654
- if (typeof this.value === "number") {
719
+ if (this.shape.length === 0) {
655
720
  self = self.unsqueeze(0);
656
721
  }
657
722
  // Pad shapes to same length
@@ -675,8 +740,8 @@ class Tensor {
675
740
  shape: targetShape,
676
741
  strides: newStrides,
677
742
  offset: self.offset,
678
- numel: Tensor.shapeToSize(targetShape),
679
- device: self.device
743
+ device: self.device,
744
+ dtype: self.dtype
680
745
  });
681
746
  if (self.requiresGrad) {
682
747
  out.requiresGrad = true;
@@ -691,7 +756,7 @@ class Tensor {
691
756
  cat(other, dim = 0) {
692
757
  other = this.handleOther(other);
693
758
  // Handle scalars
694
- if (typeof this.value === "number" || typeof other.value === "number") {
759
+ if (this.shape.length === 0 || other.shape.length === 0) {
695
760
  throw new Error("Can not concatenate scalars");
696
761
  }
697
762
  // Handle negative indices
@@ -720,7 +785,8 @@ class Tensor {
720
785
  }
721
786
  const outputSize = Tensor.shapeToSize(outputShape);
722
787
  const outputStrides = Tensor.getStrides(outputShape);
723
- const outputValue = new Array(outputSize);
788
+ const outputDtype = Tensor.getResultDtype(this.dtype, other.dtype);
789
+ const outputValue = new dtype_1.TypedArray[outputDtype](outputSize);
724
790
  for (let outIndex = 0; outIndex < outputSize; outIndex++) {
725
791
  const coords = Tensor.indexToCoords(outIndex, outputStrides);
726
792
  // Check which tensor this output position comes from
@@ -740,7 +806,10 @@ class Tensor {
740
806
  const out = new Tensor(outputValue, {
741
807
  shape: outputShape,
742
808
  strides: outputStrides,
743
- numel: outputSize
809
+ offset: 0,
810
+ numel: outputSize,
811
+ device: this.device,
812
+ dtype: this.dtype
744
813
  });
745
814
  if (this.requiresGrad) {
746
815
  out.requiresGrad = true;
@@ -773,7 +842,7 @@ class Tensor {
773
842
  }
774
843
  // Tensor squeeze
775
844
  squeeze(dims) {
776
- if (typeof this.value === "number")
845
+ if (this.shape.length === 0)
777
846
  return this;
778
847
  if (typeof dims === "number") {
779
848
  dims = [dims];
@@ -807,7 +876,9 @@ class Tensor {
807
876
  shape: outShape,
808
877
  strides: outStrides,
809
878
  offset: this.offset,
810
- device: this.device
879
+ numel: this.numel,
880
+ device: this.device,
881
+ dtype: this.dtype
811
882
  });
812
883
  // Set up gradient if needed
813
884
  if (this.requiresGrad) {
@@ -830,9 +901,6 @@ class Tensor {
830
901
  dim += this.shape.length;
831
902
  }
832
903
  let thisValue = this.value;
833
- if (typeof thisValue === "number") {
834
- thisValue = [thisValue];
835
- }
836
904
  // Insert size-1 dimension at specified position
837
905
  const newShape = [...this.shape];
838
906
  newShape.splice(dim, 0, 1);
@@ -852,7 +920,9 @@ class Tensor {
852
920
  shape: newShape,
853
921
  strides: newStrides,
854
922
  offset: this.offset,
855
- device: this.device
923
+ numel: this.numel,
924
+ device: this.device,
925
+ dtype: this.dtype
856
926
  });
857
927
  // Set up gradient if needed
858
928
  if (this.requiresGrad) {
@@ -866,10 +936,13 @@ class Tensor {
866
936
  }
867
937
  // Generic reduction operation handler
868
938
  static reduce(tensor, dims, keepDims, config) {
869
- if (typeof tensor.value === "number")
939
+ if (tensor.shape.length === 0)
870
940
  return tensor;
871
941
  if (typeof dims === "undefined") {
872
- dims = Array.from({ length: tensor.shape.length }, (_, index) => index);
942
+ dims = new Array(tensor.shape.length);
943
+ for (let index = 0; index < dims.length; index++) {
944
+ dims[index] = index;
945
+ }
873
946
  }
874
947
  if (Array.isArray(dims)) {
875
948
  dims = Tensor.normalizeDims(dims, tensor.shape.length);
@@ -883,8 +956,8 @@ class Tensor {
883
956
  const outputShape = tensor.shape.map((dim, i) => dims === i ? 1 : dim);
884
957
  const outputStrides = Tensor.getStrides(outputShape);
885
958
  const outputSize = Tensor.shapeToSize(outputShape);
886
- const outputValue = new Array(outputSize).fill(config.identity);
887
- const outputCounters = config.needsCounters ? new Array(outputSize).fill(0) : [];
959
+ const outputValue = new dtype_1.TypedArray[tensor.dtype](outputSize).fill(config.identity);
960
+ const outputCounters = config.needsCounters ? new dtype_1.TypedArray[tensor.dtype](outputSize).fill(0) : new dtype_1.TypedArray[tensor.dtype]();
888
961
  const originalSize = tensor.numel;
889
962
  const originalValue = tensor.value;
890
963
  const linearStrides = Tensor.getStrides(tensor.shape);
@@ -908,15 +981,22 @@ class Tensor {
908
981
  if (config.postProcess) {
909
982
  config.postProcess({ values: outputValue, counters: outputCounters });
910
983
  }
911
- const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides, numel: outputSize });
984
+ const out = new Tensor(outputValue, {
985
+ shape: outputShape,
986
+ strides: outputStrides,
987
+ offset: 0,
988
+ numel: outputSize,
989
+ device: tensor.device,
990
+ dtype: tensor.dtype
991
+ });
912
992
  // Gradient setup
913
993
  if (tensor.requiresGrad) {
914
994
  out.requiresGrad = true;
915
995
  out.children.push(tensor);
916
996
  out.gradFn = () => {
917
- let shareCounts = [];
997
+ let shareCounts = new dtype_1.TypedArray[tensor.dtype]();
918
998
  if (config.needsShareCounts) {
919
- shareCounts = new Array(outputSize).fill(0);
999
+ shareCounts = new dtype_1.TypedArray[tensor.dtype](outputSize).fill(0);
920
1000
  for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
921
1001
  // Convert linear index to coordinates using contiguous strides
922
1002
  const coords = Tensor.indexToCoords(flatIndex, linearStrides);
@@ -929,7 +1009,7 @@ class Tensor {
929
1009
  shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
930
1010
  }
931
1011
  }
932
- const gradValue = new Array(originalSize);
1012
+ const gradValue = new dtype_1.TypedArray[tensor.dtype](originalSize);
933
1013
  for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
934
1014
  // Convert linear index to coordinates using contiguous strides
935
1015
  const coords = Tensor.indexToCoords(flatIndex, linearStrides);
@@ -947,7 +1027,13 @@ class Tensor {
947
1027
  outIndex: outFlatIndex
948
1028
  });
949
1029
  }
950
- const localGrad = new Tensor(gradValue, { shape: tensor.shape, numel: tensor.numel });
1030
+ const localGrad = new Tensor(gradValue, {
1031
+ shape: tensor.shape,
1032
+ offset: 0,
1033
+ numel: tensor.numel,
1034
+ device: tensor.device,
1035
+ dtype: tensor.dtype
1036
+ });
951
1037
  Tensor.addGrad(tensor, out.grad.mul(localGrad));
952
1038
  };
953
1039
  }
@@ -1017,7 +1103,7 @@ class Tensor {
1017
1103
  }
1018
1104
  // Tensor softmax
1019
1105
  softmax(dim = -1) {
1020
- if (typeof this.value === "number")
1106
+ if (this.shape.length === 0)
1021
1107
  return this;
1022
1108
  // Handle negative indexing
1023
1109
  if (dim < 0) {
@@ -1035,7 +1121,7 @@ class Tensor {
1035
1121
  }
1036
1122
  // Tensor softmin
1037
1123
  softmin(dim = -1) {
1038
- if (typeof this.value === "number")
1124
+ if (this.shape.length === 0)
1039
1125
  return this;
1040
1126
  // Handle negative indexing
1041
1127
  if (dim < 0) {
@@ -1435,10 +1521,11 @@ class Tensor {
1435
1521
  const matBCols = other.shape[1];
1436
1522
  if (matACols !== matBRows)
1437
1523
  throw new Error("Invalid matrices shape for multiplication");
1524
+ const matCDtype = Tensor.getResultDtype(this.dtype, other.dtype);
1438
1525
  const matCShape = [matARows, matBCols];
1439
1526
  const matCStrides = Tensor.getStrides(matCShape);
1440
1527
  const matCSize = Tensor.shapeToSize(matCShape);
1441
- const matC = new Array(matCSize).fill(0);
1528
+ const matC = new dtype_1.TypedArray[matCDtype](matCSize).fill(0);
1442
1529
  for (let i = 0; i < matARows; i++) {
1443
1530
  for (let j = 0; j < matBCols; j++) {
1444
1531
  for (let k = 0; k < matACols; k++) {
@@ -1449,7 +1536,14 @@ class Tensor {
1449
1536
  }
1450
1537
  }
1451
1538
  }
1452
- const out = new Tensor(matC, { shape: matCShape, strides: matCStrides, numel: matCSize });
1539
+ const out = new Tensor(matC, {
1540
+ shape: matCShape,
1541
+ strides: matCStrides,
1542
+ offset: 0,
1543
+ numel: matCSize,
1544
+ device: this.device,
1545
+ dtype: matCDtype
1546
+ });
1453
1547
  if (this.requiresGrad) {
1454
1548
  out.requiresGrad = true;
1455
1549
  out.children.push(this);
@@ -1490,10 +1584,11 @@ class Tensor {
1490
1584
  const batchBCols = other.shape[2];
1491
1585
  if (batchACols !== batchBRows)
1492
1586
  throw new Error("Invalid matrices shape for multiplication");
1587
+ const batchCDtype = Tensor.getResultDtype(this.dtype, other.dtype);
1493
1588
  const batchCShape = [batchSize, batchARows, batchBCols];
1494
1589
  const batchCStrides = Tensor.getStrides(batchCShape);
1495
1590
  const batchCSize = Tensor.shapeToSize(batchCShape);
1496
- const batchC = new Array(batchCSize).fill(0);
1591
+ const batchC = new dtype_1.TypedArray[batchCDtype](batchCSize).fill(0);
1497
1592
  for (let q = 0; q < batchSize; q++) {
1498
1593
  for (let i = 0; i < batchARows; i++) {
1499
1594
  for (let j = 0; j < batchBCols; j++) {
@@ -1506,7 +1601,14 @@ class Tensor {
1506
1601
  }
1507
1602
  }
1508
1603
  }
1509
- const out = new Tensor(batchC, { shape: batchCShape, strides: batchCStrides, numel: batchCSize });
1604
+ const out = new Tensor(batchC, {
1605
+ shape: batchCShape,
1606
+ strides: batchCStrides,
1607
+ offset: 0,
1608
+ numel: batchCSize,
1609
+ device: this.device,
1610
+ dtype: batchCDtype
1611
+ });
1510
1612
  if (this.requiresGrad) {
1511
1613
  out.requiresGrad = true;
1512
1614
  out.children.push(this);
@@ -1583,10 +1685,11 @@ class Tensor {
1583
1685
  const offsetSize = Tensor.shapeToSize(offsetShape);
1584
1686
  const offsetStrides = Tensor.getStrides(offsetShape);
1585
1687
  // Output shape, strides, size, value
1688
+ const outputDtype = Tensor.getResultDtype(this.dtype, other.dtype);
1586
1689
  const outputShape = [...offsetShape, batchARows, batchBCols];
1587
1690
  const outputStrides = Tensor.getStrides(outputShape);
1588
1691
  const outputSize = Tensor.shapeToSize(outputShape);
1589
- const outputValue = new Array(outputSize).fill(0);
1692
+ const outputValue = new dtype_1.TypedArray[outputDtype](outputSize).fill(0);
1590
1693
  const outputOffsetStrides = outputStrides.slice(0, -2);
1591
1694
  // Loop through outer dims and do matmul on two outer-most dims
1592
1695
  for (let index = 0; index < offsetSize; index++) {
@@ -1605,7 +1708,14 @@ class Tensor {
1605
1708
  }
1606
1709
  }
1607
1710
  }
1608
- const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides, numel: outputSize });
1711
+ const out = new Tensor(outputValue, {
1712
+ shape: outputShape,
1713
+ strides: outputStrides,
1714
+ offset: 0,
1715
+ numel: outputSize,
1716
+ device: this.device,
1717
+ dtype: outputDtype
1718
+ });
1609
1719
  if (this.requiresGrad) {
1610
1720
  out.requiresGrad = true;
1611
1721
  out.children.push(this);
@@ -1647,7 +1757,7 @@ class Tensor {
1647
1757
  const maskShape = this.shape.slice(-2);
1648
1758
  const maskStrides = Tensor.getStrides(maskShape);
1649
1759
  const maskSize = Tensor.shapeToSize(maskShape);
1650
- const maskValue = new Array(maskSize).fill(1);
1760
+ const maskValue = new dtype_1.TypedArray[this.dtype](maskSize).fill(1);
1651
1761
  const [rows, cols] = maskShape;
1652
1762
  for (let i = 0; i < rows; i++) {
1653
1763
  const maxJ = Math.min(cols, i + diagonal);
@@ -1658,8 +1768,10 @@ class Tensor {
1658
1768
  const mask = new Tensor(maskValue, {
1659
1769
  shape: maskShape,
1660
1770
  strides: maskStrides,
1771
+ offset: 0,
1661
1772
  numel: maskSize,
1662
- device: this.device
1773
+ device: this.device,
1774
+ dtype: this.dtype
1663
1775
  });
1664
1776
  return this.mul(mask);
1665
1777
  }
@@ -1671,7 +1783,7 @@ class Tensor {
1671
1783
  const maskShape = this.shape.slice(-2);
1672
1784
  const maskStrides = Tensor.getStrides(maskShape);
1673
1785
  const maskSize = Tensor.shapeToSize(maskShape);
1674
- const maskValue = new Array(maskSize).fill(0);
1786
+ const maskValue = new dtype_1.TypedArray[this.dtype](maskSize).fill(0);
1675
1787
  const [rows, cols] = maskShape;
1676
1788
  for (let i = 0; i < rows; i++) {
1677
1789
  const maxJ = Math.min(cols, i + diagonal + 1);
@@ -1682,8 +1794,10 @@ class Tensor {
1682
1794
  const mask = new Tensor(maskValue, {
1683
1795
  shape: maskShape,
1684
1796
  strides: maskStrides,
1797
+ offset: 0,
1685
1798
  numel: maskSize,
1686
- device: this.device
1799
+ device: this.device,
1800
+ dtype: this.dtype
1687
1801
  });
1688
1802
  return this.mul(mask);
1689
1803
  }
@@ -1698,16 +1812,28 @@ class Tensor {
1698
1812
  return new Tensor(num, options);
1699
1813
  const outputSize = Tensor.shapeToSize(shape);
1700
1814
  const outputValue = new Array(outputSize).fill(num);
1701
- return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1815
+ return new Tensor(outputValue, {
1816
+ shape,
1817
+ offset: 0,
1818
+ numel: outputSize,
1819
+ ...options
1820
+ });
1702
1821
  }
1703
1822
  // Utility to create a new tensor with shape of another tensor, filled with a number
1704
1823
  static fullLike(tensor, num, options = {}) {
1705
- if (typeof tensor.value === "number")
1706
- return new Tensor(num, options);
1824
+ if (tensor.shape.length === 0)
1825
+ return new Tensor(num, {
1826
+ offset: 0,
1827
+ device: tensor.device,
1828
+ dtype: tensor.dtype,
1829
+ ...options
1830
+ });
1707
1831
  return new Tensor(new Array(tensor.numel).fill(num), {
1708
1832
  shape: tensor.shape,
1833
+ offset: 0,
1709
1834
  numel: tensor.numel,
1710
1835
  device: tensor.device,
1836
+ dtype: tensor.dtype,
1711
1837
  ...options
1712
1838
  });
1713
1839
  }
@@ -1717,16 +1843,28 @@ class Tensor {
1717
1843
  return new Tensor(1, options);
1718
1844
  const outputSize = Tensor.shapeToSize(shape);
1719
1845
  const outputValue = new Array(outputSize).fill(1);
1720
- return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1846
+ return new Tensor(outputValue, {
1847
+ shape,
1848
+ offset: 0,
1849
+ numel: outputSize,
1850
+ ...options
1851
+ });
1721
1852
  }
1722
1853
  // Utility to create a new tensor with shape of another tensor, filled with 1
1723
1854
  static onesLike(tensor, options = {}) {
1724
- if (typeof tensor.value === "number")
1725
- return new Tensor(1, options);
1855
+ if (tensor.shape.length === 0)
1856
+ return new Tensor(1, {
1857
+ offset: 0,
1858
+ device: tensor.device,
1859
+ dtype: tensor.dtype,
1860
+ ...options
1861
+ });
1726
1862
  return new Tensor(new Array(tensor.numel).fill(1), {
1727
1863
  shape: tensor.shape,
1864
+ offset: 0,
1728
1865
  numel: tensor.numel,
1729
1866
  device: tensor.device,
1867
+ dtype: tensor.dtype,
1730
1868
  ...options
1731
1869
  });
1732
1870
  }
@@ -1736,16 +1874,28 @@ class Tensor {
1736
1874
  return new Tensor(0, options);
1737
1875
  const outputSize = Tensor.shapeToSize(shape);
1738
1876
  const outputValue = new Array(outputSize).fill(0);
1739
- return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1877
+ return new Tensor(outputValue, {
1878
+ shape,
1879
+ offset: 0,
1880
+ numel: outputSize,
1881
+ ...options
1882
+ });
1740
1883
  }
1741
1884
  // Utility to create a new tensor with shape of another tensor, filled with 0
1742
1885
  static zerosLike(tensor, options = {}) {
1743
- if (typeof tensor.value === "number")
1744
- return new Tensor(0, options);
1886
+ if (tensor.shape.length === 0)
1887
+ return new Tensor(0, {
1888
+ offset: 0,
1889
+ device: tensor.device,
1890
+ dtype: tensor.dtype,
1891
+ ...options
1892
+ });
1745
1893
  return new Tensor(new Array(tensor.numel).fill(0), {
1746
1894
  shape: tensor.shape,
1895
+ offset: 0,
1747
1896
  numel: tensor.numel,
1748
1897
  device: tensor.device,
1898
+ dtype: tensor.dtype,
1749
1899
  ...options
1750
1900
  });
1751
1901
  }
@@ -1758,20 +1908,32 @@ class Tensor {
1758
1908
  for (let index = 0; index < outputValue.length; index++) {
1759
1909
  outputValue[index] = (0, utils_1.randUniform)();
1760
1910
  }
1761
- return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1911
+ return new Tensor(outputValue, {
1912
+ shape,
1913
+ offset: 0,
1914
+ numel: outputSize,
1915
+ ...options
1916
+ });
1762
1917
  }
1763
1918
  // Utility to create a new tensor with shape of another tensor, filled with a random number with uniform distribution from 0 to 1
1764
1919
  static randLike(tensor, options = {}) {
1765
- if (typeof tensor.value === "number")
1766
- return new Tensor((0, utils_1.randUniform)(), options);
1920
+ if (tensor.shape.length === 0)
1921
+ return new Tensor((0, utils_1.randUniform)(), {
1922
+ offset: 0,
1923
+ device: tensor.device,
1924
+ dtype: tensor.dtype,
1925
+ ...options
1926
+ });
1767
1927
  const outputValue = new Array(tensor.numel);
1768
1928
  for (let index = 0; index < outputValue.length; index++) {
1769
1929
  outputValue[index] = (0, utils_1.randUniform)();
1770
1930
  }
1771
1931
  return new Tensor(outputValue, {
1772
1932
  shape: tensor.shape,
1933
+ offset: 0,
1773
1934
  numel: tensor.numel,
1774
1935
  device: tensor.device,
1936
+ dtype: tensor.dtype,
1775
1937
  ...options
1776
1938
  });
1777
1939
  }
@@ -1784,20 +1946,32 @@ class Tensor {
1784
1946
  for (let index = 0; index < outputValue.length; index++) {
1785
1947
  outputValue[index] = (0, utils_1.randNormal)();
1786
1948
  }
1787
- return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1949
+ return new Tensor(outputValue, {
1950
+ shape,
1951
+ offset: 0,
1952
+ numel: outputSize,
1953
+ ...options
1954
+ });
1788
1955
  }
1789
1956
  // Utility to create a new tensor with shape of another tensor, filled with a random number with normal distribution of mean=0 and stddev=1
1790
1957
  static randnLike(tensor, options = {}) {
1791
- if (typeof tensor.value === "number")
1792
- return new Tensor((0, utils_1.randNormal)(), options);
1958
+ if (tensor.shape.length === 0)
1959
+ return new Tensor((0, utils_1.randNormal)(), {
1960
+ offset: 0,
1961
+ device: tensor.device,
1962
+ dtype: tensor.dtype,
1963
+ ...options
1964
+ });
1793
1965
  const outputValue = new Array(tensor.numel);
1794
1966
  for (let index = 0; index < outputValue.length; index++) {
1795
1967
  outputValue[index] = (0, utils_1.randNormal)();
1796
1968
  }
1797
1969
  return new Tensor(outputValue, {
1798
1970
  shape: tensor.shape,
1971
+ offset: 0,
1799
1972
  numel: tensor.numel,
1800
1973
  device: tensor.device,
1974
+ dtype: tensor.dtype,
1801
1975
  ...options
1802
1976
  });
1803
1977
  }
@@ -1810,20 +1984,32 @@ class Tensor {
1810
1984
  for (let index = 0; index < outputValue.length; index++) {
1811
1985
  outputValue[index] = (0, utils_1.randInt)(low, high);
1812
1986
  }
1813
- return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1987
+ return new Tensor(outputValue, {
1988
+ shape,
1989
+ offset: 0,
1990
+ numel: outputSize,
1991
+ ...options
1992
+ });
1814
1993
  }
1815
1994
  // Utility to create a new tensor with shape of another tensor, filled with a random integer between low and high
1816
1995
  static randintLike(tensor, low, high, options = {}) {
1817
- if (typeof tensor.value === "number")
1818
- return new Tensor((0, utils_1.randInt)(low, high), options);
1996
+ if (tensor.shape.length === 0)
1997
+ return new Tensor((0, utils_1.randInt)(low, high), {
1998
+ offset: 0,
1999
+ device: tensor.device,
2000
+ dtype: tensor.dtype,
2001
+ ...options
2002
+ });
1819
2003
  const outputValue = new Array(tensor.numel);
1820
2004
  for (let index = 0; index < outputValue.length; index++) {
1821
2005
  outputValue[index] = (0, utils_1.randInt)(low, high);
1822
2006
  }
1823
2007
  return new Tensor(outputValue, {
1824
2008
  shape: tensor.shape,
2009
+ offset: 0,
1825
2010
  numel: tensor.numel,
1826
2011
  device: tensor.device,
2012
+ dtype: tensor.dtype,
1827
2013
  ...options
1828
2014
  });
1829
2015
  }
@@ -1834,7 +2020,12 @@ class Tensor {
1834
2020
  outputValue[i] = i;
1835
2021
  }
1836
2022
  (0, utils_1.fyShuffle)(outputValue);
1837
- return new Tensor(outputValue, { shape: [n], numel: n, ...options });
2023
+ return new Tensor(outputValue, {
2024
+ shape: [n],
2025
+ offset: 0,
2026
+ numel: n,
2027
+ ...options
2028
+ });
1838
2029
  }
1839
2030
  // Utility to create a new tensor filled with a random number with normal distribution of custom mean and stddev
1840
2031
  static normal(shape, mean, stdDev, options = {}) {
@@ -1845,7 +2036,12 @@ class Tensor {
1845
2036
  for (let index = 0; index < outputValue.length; index++) {
1846
2037
  outputValue[index] = (0, utils_1.randNormal)(mean, stdDev);
1847
2038
  }
1848
- return new Tensor(outputValue, { shape, numel: outputSize, ...options });
2039
+ return new Tensor(outputValue, {
2040
+ shape,
2041
+ offset: 0,
2042
+ numel: outputSize,
2043
+ ...options
2044
+ });
1849
2045
  }
1850
2046
  // Utility to create a new tensor filled with a random number with uniform distribution from low to high
1851
2047
  static uniform(shape, low, high, options = {}) {
@@ -1856,7 +2052,12 @@ class Tensor {
1856
2052
  for (let index = 0; index < outputValue.length; index++) {
1857
2053
  outputValue[index] = (0, utils_1.randUniform)(low, high);
1858
2054
  }
1859
- return new Tensor(outputValue, { shape, numel: outputSize, ...options });
2055
+ return new Tensor(outputValue, {
2056
+ shape,
2057
+ offset: 0,
2058
+ numel: outputSize,
2059
+ ...options
2060
+ });
1860
2061
  }
1861
2062
  // Utility to create an 1D tensor from a range incrementing with "step"
1862
2063
  static arange(start, stop, step = 1, options = {}) {
@@ -1870,7 +2071,12 @@ class Tensor {
1870
2071
  for (let index = 0; index < outputValue.length; index++) {
1871
2072
  outputValue[index] = start + step * index;
1872
2073
  }
1873
- return new Tensor(outputValue, { shape: outputShape, numel: outputSize, ...options });
2074
+ return new Tensor(outputValue, {
2075
+ shape: outputShape,
2076
+ offset: 0,
2077
+ numel: outputSize,
2078
+ ...options
2079
+ });
1874
2080
  }
1875
2081
  // Utility to create an 1D tensor from a range evenly spaced out with a given amount of steps
1876
2082
  static linspace(start, stop, steps, options = {}) {
@@ -1886,7 +2092,12 @@ class Tensor {
1886
2092
  }
1887
2093
  // Ensure we hit the endpoint exactly (avoids floating point errors)
1888
2094
  outputValue[steps - 1] = stop;
1889
- return new Tensor(outputValue, { shape: [steps], numel: steps, ...options });
2095
+ return new Tensor(outputValue, {
2096
+ shape: [steps],
2097
+ offset: 0,
2098
+ numel: steps,
2099
+ ...options
2100
+ });
1890
2101
  }
1891
2102
  // Utility to create a 2D tensor with its main diagonal filled with 1s and others with 0s
1892
2103
  static eye(n, m = n, options = {}) {
@@ -1897,7 +2108,13 @@ class Tensor {
1897
2108
  for (let i = 0; i < Math.min(n, m); i++) {
1898
2109
  outputValue[i * outputStrides[0] + i * outputStrides[1]] = 1;
1899
2110
  }
1900
- return new Tensor(outputValue, { shape: outputShape, strides: outputStrides, numel: outputSize, ...options });
2111
+ return new Tensor(outputValue, {
2112
+ shape: outputShape,
2113
+ offset: 0,
2114
+ strides: outputStrides,
2115
+ numel: outputSize,
2116
+ ...options
2117
+ });
1901
2118
  }
1902
2119
  // Reverse-mode autodiff call
1903
2120
  backward(options = {}) {
@@ -1928,8 +2145,8 @@ class Tensor {
1928
2145
  }
1929
2146
  // Returns the raw number/nD array form of tensor
1930
2147
  val() {
1931
- if (typeof this.value === "number")
1932
- return this.value;
2148
+ if (this.shape.length === 0)
2149
+ return this.value[0];
1933
2150
  function buildNested(data, shape, strides, baseIndex = 0, dim = 0) {
1934
2151
  if (dim === shape.length - 1) {
1935
2152
  // Last dimension: extract elements using actual stride
@@ -1956,20 +2173,28 @@ class Tensor {
1956
2173
  offset: this.offset,
1957
2174
  numel: this.numel,
1958
2175
  device: this.device,
2176
+ dtype: this.dtype,
1959
2177
  requiresGrad: false
1960
2178
  });
1961
2179
  }
1962
2180
  // Returns a copy of the tensor (with new data allocation) and keeps grad connection
1963
2181
  clone() {
1964
2182
  let out;
1965
- if (typeof this.value === "number") {
1966
- out = new Tensor(this.value);
2183
+ if (this.shape.length === 0) {
2184
+ out = new Tensor(this.value, {
2185
+ shape: [],
2186
+ strides: [],
2187
+ offset: 0,
2188
+ numel: 1,
2189
+ device: this.device,
2190
+ dtype: this.dtype
2191
+ });
1967
2192
  }
1968
2193
  else {
1969
2194
  const contiguous = this.isContiguous();
1970
2195
  const outputStrides = contiguous ? this.strides : Tensor.getStrides(this.shape);
1971
2196
  const outputSize = this.numel;
1972
- const outputValue = new Array(outputSize);
2197
+ const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
1973
2198
  if (contiguous) {
1974
2199
  for (let index = 0; index < outputSize; index++) {
1975
2200
  outputValue[index] = this.value[this.offset + index];
@@ -1982,7 +2207,14 @@ class Tensor {
1982
2207
  outputValue[index] = this.value[this.offset + originalIndex];
1983
2208
  }
1984
2209
  }
1985
- out = new Tensor(outputValue, { shape: this.shape, strides: outputStrides, numel: outputSize });
2210
+ out = new Tensor(outputValue, {
2211
+ shape: this.shape,
2212
+ strides: outputStrides,
2213
+ offset: 0,
2214
+ numel: outputSize,
2215
+ device: this.device,
2216
+ dtype: this.dtype
2217
+ });
1986
2218
  }
1987
2219
  if (this.requiresGrad) {
1988
2220
  out.requiresGrad = true;
@@ -1994,19 +2226,38 @@ class Tensor {
1994
2226
  return out;
1995
2227
  }
1996
2228
  // Returns this tensor with value replaced with the value of another tensor
1997
- replace(other, allowShapeMismatch = false) {
2229
+ replace(other) {
1998
2230
  other = this.handleOther(other);
1999
2231
  // Verify shape
2000
- if (!allowShapeMismatch) {
2001
- for (let index = 0; index < this.shape.length; index++) {
2002
- if (this.shape[index] !== other.shape[index]) {
2003
- throw new Error("Shape mismatch when trying to do tensor value replacement");
2004
- }
2232
+ if (this.shape.length !== other.shape.length) {
2233
+ throw new Error("Shape mismatch when trying to do tensor value replacement");
2234
+ }
2235
+ for (let index = 0; index < this.shape.length; index++) {
2236
+ if (this.shape[index] !== other.shape[index]) {
2237
+ throw new Error("Shape mismatch when trying to do tensor value replacement");
2005
2238
  }
2006
2239
  }
2240
+ // Reassign values
2007
2241
  this.value = other.value;
2242
+ this.strides = other.strides;
2243
+ this.offset = other.offset;
2244
+ this.device = other.device;
2245
+ this.dtype = other.dtype;
2008
2246
  return this;
2009
2247
  }
2248
+ // Op to return a new tensor casted to another dtype
2249
+ cast(dtype) {
2250
+ if (this.dtype === dtype)
2251
+ return this;
2252
+ return new Tensor(this.value, {
2253
+ shape: this.shape,
2254
+ strides: this.strides,
2255
+ offset: this.offset,
2256
+ numel: this.numel,
2257
+ device: this.device,
2258
+ dtype: dtype
2259
+ });
2260
+ }
2010
2261
  // Holds all available backends
2011
2262
  static backends = new Map();
2012
2263
  // Op to transfer tensor to another device