catniff 0.7.4 → 0.8.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +14 -0
- package/dist/core.d.ts +16 -12
- package/dist/core.js +368 -126
- package/dist/dtype.d.ts +5 -0
- package/dist/dtype.js +25 -0
- package/dist/nn.d.ts +9 -8
- package/dist/nn.js +50 -50
- package/index.d.ts +1 -0
- package/index.js +2 -1
- package/package.json +1 -1
package/dist/core.js
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
3
|
exports.Tensor = void 0;
|
|
4
|
+
const dtype_1 = require("./dtype");
|
|
4
5
|
const utils_1 = require("./utils");
|
|
5
6
|
class Tensor {
|
|
6
7
|
value;
|
|
@@ -13,12 +14,16 @@ class Tensor {
|
|
|
13
14
|
gradFn;
|
|
14
15
|
children;
|
|
15
16
|
device;
|
|
17
|
+
dtype;
|
|
16
18
|
static training = false;
|
|
17
19
|
static noGrad = false;
|
|
18
20
|
static createGraph = false;
|
|
19
21
|
constructor(value, options = {}) {
|
|
20
|
-
//
|
|
21
|
-
this.
|
|
22
|
+
// Memory buffer
|
|
23
|
+
this.dtype = options.dtype || "float32";
|
|
24
|
+
const flatValue = Tensor.flattenValue(value);
|
|
25
|
+
const TypedArrayConstructor = dtype_1.TypedArray[this.dtype];
|
|
26
|
+
this.value = flatValue instanceof TypedArrayConstructor ? flatValue : TypedArrayConstructor.from(flatValue);
|
|
22
27
|
// Tensor metadata
|
|
23
28
|
this.shape = options.shape || Tensor.getShape(value);
|
|
24
29
|
this.strides = options.strides || Tensor.getStrides(this.shape);
|
|
@@ -34,31 +39,34 @@ class Tensor {
|
|
|
34
39
|
this.to_(this.device);
|
|
35
40
|
}
|
|
36
41
|
// Utility to flatten an nD array to be 1D
|
|
37
|
-
static flattenValue(
|
|
42
|
+
static flattenValue(tensorValue) {
|
|
38
43
|
// Handle scalar tensors
|
|
39
|
-
if (typeof
|
|
40
|
-
return
|
|
44
|
+
if (typeof tensorValue === "number")
|
|
45
|
+
return [tensorValue];
|
|
41
46
|
// If value is already 1D, we just need to return the value ('s reference)
|
|
42
|
-
if (typeof
|
|
43
|
-
return
|
|
47
|
+
if (typeof tensorValue[0] === "number")
|
|
48
|
+
return tensorValue;
|
|
44
49
|
// Or else recursively traverse through the nD array to flatten
|
|
45
50
|
const result = [];
|
|
46
51
|
function traverse(arr) {
|
|
47
52
|
if (typeof arr === "number") {
|
|
48
53
|
result.push(arr);
|
|
54
|
+
// Assume if we can index a value, it is an ArrayLike
|
|
49
55
|
}
|
|
50
|
-
else if (
|
|
51
|
-
arr.
|
|
56
|
+
else if (typeof arr[0] !== "undefined") {
|
|
57
|
+
for (let index = 0; index < arr.length; index++) {
|
|
58
|
+
traverse(arr[index]);
|
|
59
|
+
}
|
|
52
60
|
}
|
|
53
61
|
}
|
|
54
|
-
traverse(
|
|
62
|
+
traverse(tensorValue);
|
|
55
63
|
return result;
|
|
56
64
|
}
|
|
57
65
|
// Utility to get shape from tensor *value*
|
|
58
|
-
static getShape(
|
|
66
|
+
static getShape(tensorValue) {
|
|
59
67
|
const shape = [];
|
|
60
|
-
let subA =
|
|
61
|
-
while (
|
|
68
|
+
let subA = tensorValue;
|
|
69
|
+
while (typeof subA !== "number") {
|
|
62
70
|
shape.push(subA.length);
|
|
63
71
|
subA = subA[0];
|
|
64
72
|
}
|
|
@@ -146,17 +154,52 @@ class Tensor {
|
|
|
146
154
|
}
|
|
147
155
|
return prod;
|
|
148
156
|
}
|
|
149
|
-
|
|
157
|
+
// Utility to get best possible result type if type conflicts happen:
|
|
158
|
+
static getResultDtype(type1, type2) {
|
|
159
|
+
if (type1 === type2)
|
|
160
|
+
return type1;
|
|
161
|
+
const type1Ranking = dtype_1.dtypeHiearchy[type1];
|
|
162
|
+
const type2Ranking = dtype_1.dtypeHiearchy[type2];
|
|
163
|
+
if (type1Ranking > type2Ranking) {
|
|
164
|
+
return type1;
|
|
165
|
+
}
|
|
166
|
+
return type2;
|
|
167
|
+
}
|
|
168
|
+
// Utility to handle other tensor if an op needs a second operand
|
|
169
|
+
handleOther(other) {
|
|
170
|
+
if (other instanceof Tensor) {
|
|
171
|
+
if (this.device !== other.device) {
|
|
172
|
+
throw new Error("Can not operate on tensors that are not on the same device");
|
|
173
|
+
}
|
|
174
|
+
return other;
|
|
175
|
+
}
|
|
176
|
+
return new Tensor(other, {
|
|
177
|
+
offset: 0,
|
|
178
|
+
device: this.device,
|
|
179
|
+
dtype: this.dtype
|
|
180
|
+
});
|
|
181
|
+
}
|
|
150
182
|
// Utility for binary (two operators involved) element-wise ops
|
|
151
183
|
static elementWiseAB(tA, tB, op) {
|
|
152
|
-
|
|
153
|
-
|
|
184
|
+
const outputDtype = Tensor.getResultDtype(tA.dtype, tB.dtype);
|
|
185
|
+
// Both are scalars
|
|
186
|
+
if (tA.shape.length === 0 && tB.shape.length === 0) {
|
|
187
|
+
return new Tensor(op(tA.value[0], tB.value[0]), {
|
|
188
|
+
shape: [],
|
|
189
|
+
strides: [],
|
|
190
|
+
offset: 0,
|
|
191
|
+
numel: 1,
|
|
192
|
+
device: tA.device,
|
|
193
|
+
dtype: outputDtype
|
|
194
|
+
});
|
|
154
195
|
}
|
|
155
|
-
|
|
156
|
-
|
|
196
|
+
// First tensor is scalar
|
|
197
|
+
if (tA.shape.length === 0) {
|
|
198
|
+
return Tensor.elementWiseSelf(tB.cast(outputDtype), (a) => op(a, tA.value[0]));
|
|
157
199
|
}
|
|
158
|
-
|
|
159
|
-
|
|
200
|
+
// Second tensor is scalar
|
|
201
|
+
if (tB.shape.length === 0) {
|
|
202
|
+
return Tensor.elementWiseSelf(tA.cast(outputDtype), (a) => op(a, tB.value[0]));
|
|
160
203
|
}
|
|
161
204
|
// Pad + broadcast shape
|
|
162
205
|
const [paddedAStrides, paddedBStrides, paddedAShape, paddedBShape] = Tensor.padShape(tA.strides, tB.strides, tA.shape, tB.shape);
|
|
@@ -164,7 +207,7 @@ class Tensor {
|
|
|
164
207
|
// Get other output info
|
|
165
208
|
const outputStrides = Tensor.getStrides(outputShape);
|
|
166
209
|
const outputSize = Tensor.shapeToSize(outputShape);
|
|
167
|
-
const outputValue = new
|
|
210
|
+
const outputValue = new dtype_1.TypedArray[outputDtype](outputSize);
|
|
168
211
|
for (let i = 0; i < outputSize; i++) {
|
|
169
212
|
// Get coordinates from 1D index
|
|
170
213
|
const coordsOutput = Tensor.indexToCoords(i, outputStrides);
|
|
@@ -178,18 +221,29 @@ class Tensor {
|
|
|
178
221
|
return new Tensor(outputValue, {
|
|
179
222
|
shape: outputShape,
|
|
180
223
|
strides: outputStrides,
|
|
181
|
-
|
|
224
|
+
offset: 0,
|
|
225
|
+
numel: outputSize,
|
|
226
|
+
device: tA.device,
|
|
227
|
+
dtype: outputDtype
|
|
182
228
|
});
|
|
183
229
|
}
|
|
184
230
|
// Utility for self-inflicting element-wise ops
|
|
185
231
|
static elementWiseSelf(tA, op) {
|
|
186
|
-
|
|
187
|
-
|
|
232
|
+
// Handle scalar case
|
|
233
|
+
if (tA.shape.length === 0)
|
|
234
|
+
return new Tensor(op(tA.value[0]), {
|
|
235
|
+
shape: [],
|
|
236
|
+
strides: [],
|
|
237
|
+
offset: 0,
|
|
238
|
+
numel: 1,
|
|
239
|
+
device: tA.device,
|
|
240
|
+
dtype: tA.dtype
|
|
241
|
+
});
|
|
188
242
|
const contiguous = tA.isContiguous();
|
|
189
243
|
const outputShape = tA.shape;
|
|
190
244
|
const outputStrides = contiguous ? tA.strides : Tensor.getStrides(outputShape);
|
|
191
245
|
const outputSize = tA.numel;
|
|
192
|
-
const outputValue = new
|
|
246
|
+
const outputValue = new dtype_1.TypedArray[tA.dtype](outputSize);
|
|
193
247
|
if (contiguous) {
|
|
194
248
|
for (let index = 0; index < outputSize; index++) {
|
|
195
249
|
outputValue[index] = op(tA.value[index + tA.offset]);
|
|
@@ -202,7 +256,14 @@ class Tensor {
|
|
|
202
256
|
outputValue[index] = op(tA.value[originalIndex]);
|
|
203
257
|
}
|
|
204
258
|
}
|
|
205
|
-
return new Tensor(outputValue, {
|
|
259
|
+
return new Tensor(outputValue, {
|
|
260
|
+
shape: outputShape,
|
|
261
|
+
strides: outputStrides,
|
|
262
|
+
offset: 0,
|
|
263
|
+
numel: tA.numel,
|
|
264
|
+
device: tA.device,
|
|
265
|
+
dtype: tA.dtype
|
|
266
|
+
});
|
|
206
267
|
}
|
|
207
268
|
// Utility to do element-wise operation and build a dag node with another tensor
|
|
208
269
|
elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) {
|
|
@@ -246,16 +307,6 @@ class Tensor {
|
|
|
246
307
|
}
|
|
247
308
|
return out;
|
|
248
309
|
}
|
|
249
|
-
// Utility to handle other tensor if an op needs a second operand
|
|
250
|
-
handleOther(other) {
|
|
251
|
-
if (other instanceof Tensor) {
|
|
252
|
-
if (this.device !== other.device) {
|
|
253
|
-
throw new Error("Can not operate on tensors that are not on the same device");
|
|
254
|
-
}
|
|
255
|
-
return other;
|
|
256
|
-
}
|
|
257
|
-
return new Tensor(other, { device: this.device });
|
|
258
|
-
}
|
|
259
310
|
// Utility to add to gradient of tensor
|
|
260
311
|
static addGrad(tensor, accumGrad) {
|
|
261
312
|
const axesToSqueeze = [];
|
|
@@ -278,7 +329,7 @@ class Tensor {
|
|
|
278
329
|
tensor.grad = squeezedGrad;
|
|
279
330
|
}
|
|
280
331
|
else {
|
|
281
|
-
tensor.grad = tensor.grad.add(squeezedGrad);
|
|
332
|
+
tensor.grad = tensor.grad.add(squeezedGrad.cast(tensor.dtype));
|
|
282
333
|
}
|
|
283
334
|
}
|
|
284
335
|
static normalizeDims(dims, numDims) {
|
|
@@ -306,20 +357,27 @@ class Tensor {
|
|
|
306
357
|
}
|
|
307
358
|
contiguous() {
|
|
308
359
|
// Check if scalar
|
|
309
|
-
if (
|
|
360
|
+
if (this.shape.length === 0)
|
|
310
361
|
return this;
|
|
311
362
|
// Check if already contiguous
|
|
312
363
|
if (this.isContiguous())
|
|
313
364
|
return this;
|
|
314
365
|
const outputStrides = Tensor.getStrides(this.shape);
|
|
315
366
|
const outputSize = this.numel;
|
|
316
|
-
const outputValue = new
|
|
367
|
+
const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
|
|
317
368
|
for (let index = 0; index < outputSize; index++) {
|
|
318
369
|
const outputCoords = Tensor.indexToCoords(index, outputStrides);
|
|
319
370
|
const originalIndex = Tensor.coordsToIndex(outputCoords, this.strides);
|
|
320
371
|
outputValue[index] = this.value[this.offset + originalIndex];
|
|
321
372
|
}
|
|
322
|
-
const out = new Tensor(outputValue, {
|
|
373
|
+
const out = new Tensor(outputValue, {
|
|
374
|
+
shape: this.shape,
|
|
375
|
+
strides: outputStrides,
|
|
376
|
+
offset: 0,
|
|
377
|
+
numel: outputSize,
|
|
378
|
+
device: this.device,
|
|
379
|
+
dtype: this.dtype
|
|
380
|
+
});
|
|
323
381
|
// Gradient flow back to the original tensor
|
|
324
382
|
if (this.requiresGrad) {
|
|
325
383
|
out.requiresGrad = true;
|
|
@@ -334,7 +392,7 @@ class Tensor {
|
|
|
334
392
|
// Verify shape size
|
|
335
393
|
const originalSize = this.numel;
|
|
336
394
|
const outputSize = Tensor.shapeToSize(newShape);
|
|
337
|
-
if (originalSize !== outputSize
|
|
395
|
+
if (originalSize !== outputSize) {
|
|
338
396
|
throw new Error("Can not create view: incompatible sizes");
|
|
339
397
|
}
|
|
340
398
|
// Verify compatibility (only contiguity for now)
|
|
@@ -347,7 +405,8 @@ class Tensor {
|
|
|
347
405
|
strides: outputStrides,
|
|
348
406
|
offset: this.offset,
|
|
349
407
|
numel: outputSize,
|
|
350
|
-
device: this.device
|
|
408
|
+
device: this.device,
|
|
409
|
+
dtype: this.dtype
|
|
351
410
|
});
|
|
352
411
|
// Gradient reshaped and flow back to the original tensor
|
|
353
412
|
if (this.requiresGrad) {
|
|
@@ -423,7 +482,8 @@ class Tensor {
|
|
|
423
482
|
strides: newStrides,
|
|
424
483
|
offset: this.offset,
|
|
425
484
|
numel: this.numel,
|
|
426
|
-
device: this.device
|
|
485
|
+
device: this.device,
|
|
486
|
+
dtype: this.dtype
|
|
427
487
|
});
|
|
428
488
|
out.requiresGrad = this.requiresGrad;
|
|
429
489
|
// Handle gradient if needed
|
|
@@ -464,7 +524,8 @@ class Tensor {
|
|
|
464
524
|
strides: newStrides,
|
|
465
525
|
offset: this.offset,
|
|
466
526
|
numel: this.numel,
|
|
467
|
-
device: this.device
|
|
527
|
+
device: this.device,
|
|
528
|
+
dtype: this.dtype
|
|
468
529
|
});
|
|
469
530
|
if (this.requiresGrad) {
|
|
470
531
|
out.requiresGrad = true;
|
|
@@ -484,7 +545,7 @@ class Tensor {
|
|
|
484
545
|
}
|
|
485
546
|
// Utility for indexing with array of indices
|
|
486
547
|
indexWithArray(indices) {
|
|
487
|
-
if (
|
|
548
|
+
if (this.shape.length === 0)
|
|
488
549
|
return this;
|
|
489
550
|
indices = Tensor.normalizeDims(indices, this.shape[0]);
|
|
490
551
|
// Init necessary stuff for indexing
|
|
@@ -494,7 +555,7 @@ class Tensor {
|
|
|
494
555
|
// Init output data
|
|
495
556
|
const outputShape = [indices.length, ...reducedShape];
|
|
496
557
|
const outputSize = Tensor.shapeToSize(outputShape);
|
|
497
|
-
const outputValue = new
|
|
558
|
+
const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
|
|
498
559
|
for (let i = 0; i < indices.length; i++) {
|
|
499
560
|
const sourceRowIndex = indices[i];
|
|
500
561
|
const targetStart = i * elementsPerIndex;
|
|
@@ -507,7 +568,10 @@ class Tensor {
|
|
|
507
568
|
}
|
|
508
569
|
const out = new Tensor(outputValue, {
|
|
509
570
|
shape: outputShape,
|
|
510
|
-
|
|
571
|
+
offset: 0,
|
|
572
|
+
numel: outputSize,
|
|
573
|
+
device: this.device,
|
|
574
|
+
dtype: this.dtype
|
|
511
575
|
});
|
|
512
576
|
// Handle gradient
|
|
513
577
|
if (this.requiresGrad) {
|
|
@@ -536,13 +600,13 @@ class Tensor {
|
|
|
536
600
|
// Tensor indexing
|
|
537
601
|
index(indices) {
|
|
538
602
|
const tensorIndices = this.handleOther(indices).clone();
|
|
539
|
-
if (
|
|
540
|
-
return this.indexWithArray([tensorIndices.value]).squeeze(0);
|
|
603
|
+
if (tensorIndices.shape.length === 0) {
|
|
604
|
+
return this.indexWithArray([tensorIndices.value[0]]).squeeze(0);
|
|
541
605
|
}
|
|
542
606
|
else {
|
|
543
607
|
const originalShape = tensorIndices.shape;
|
|
544
608
|
const flatIndices = tensorIndices.value;
|
|
545
|
-
const result = this.indexWithArray(flatIndices);
|
|
609
|
+
const result = this.indexWithArray(Array.from(flatIndices));
|
|
546
610
|
// Reshape to preserve input shape
|
|
547
611
|
const outputShape = [...originalShape, ...this.shape.slice(1)];
|
|
548
612
|
return result.reshape(outputShape);
|
|
@@ -551,7 +615,7 @@ class Tensor {
|
|
|
551
615
|
// Tensor slicing
|
|
552
616
|
slice(ranges) {
|
|
553
617
|
// Handle scalars
|
|
554
|
-
if (
|
|
618
|
+
if (this.shape.length === 0)
|
|
555
619
|
return this;
|
|
556
620
|
const newShape = [];
|
|
557
621
|
const newStrides = [];
|
|
@@ -589,7 +653,8 @@ class Tensor {
|
|
|
589
653
|
shape: newShape,
|
|
590
654
|
strides: newStrides,
|
|
591
655
|
offset: newOffset,
|
|
592
|
-
device: this.device
|
|
656
|
+
device: this.device,
|
|
657
|
+
dtype: this.dtype
|
|
593
658
|
});
|
|
594
659
|
if (this.requiresGrad) {
|
|
595
660
|
out.requiresGrad = true;
|
|
@@ -651,7 +716,7 @@ class Tensor {
|
|
|
651
716
|
expand(newShape) {
|
|
652
717
|
// Handle scalars
|
|
653
718
|
let self = this;
|
|
654
|
-
if (
|
|
719
|
+
if (this.shape.length === 0) {
|
|
655
720
|
self = self.unsqueeze(0);
|
|
656
721
|
}
|
|
657
722
|
// Pad shapes to same length
|
|
@@ -675,8 +740,8 @@ class Tensor {
|
|
|
675
740
|
shape: targetShape,
|
|
676
741
|
strides: newStrides,
|
|
677
742
|
offset: self.offset,
|
|
678
|
-
|
|
679
|
-
|
|
743
|
+
device: self.device,
|
|
744
|
+
dtype: self.dtype
|
|
680
745
|
});
|
|
681
746
|
if (self.requiresGrad) {
|
|
682
747
|
out.requiresGrad = true;
|
|
@@ -691,7 +756,7 @@ class Tensor {
|
|
|
691
756
|
cat(other, dim = 0) {
|
|
692
757
|
other = this.handleOther(other);
|
|
693
758
|
// Handle scalars
|
|
694
|
-
if (
|
|
759
|
+
if (this.shape.length === 0 || other.shape.length === 0) {
|
|
695
760
|
throw new Error("Can not concatenate scalars");
|
|
696
761
|
}
|
|
697
762
|
// Handle negative indices
|
|
@@ -720,7 +785,8 @@ class Tensor {
|
|
|
720
785
|
}
|
|
721
786
|
const outputSize = Tensor.shapeToSize(outputShape);
|
|
722
787
|
const outputStrides = Tensor.getStrides(outputShape);
|
|
723
|
-
const
|
|
788
|
+
const outputDtype = Tensor.getResultDtype(this.dtype, other.dtype);
|
|
789
|
+
const outputValue = new dtype_1.TypedArray[outputDtype](outputSize);
|
|
724
790
|
for (let outIndex = 0; outIndex < outputSize; outIndex++) {
|
|
725
791
|
const coords = Tensor.indexToCoords(outIndex, outputStrides);
|
|
726
792
|
// Check which tensor this output position comes from
|
|
@@ -740,7 +806,10 @@ class Tensor {
|
|
|
740
806
|
const out = new Tensor(outputValue, {
|
|
741
807
|
shape: outputShape,
|
|
742
808
|
strides: outputStrides,
|
|
743
|
-
|
|
809
|
+
offset: 0,
|
|
810
|
+
numel: outputSize,
|
|
811
|
+
device: this.device,
|
|
812
|
+
dtype: this.dtype
|
|
744
813
|
});
|
|
745
814
|
if (this.requiresGrad) {
|
|
746
815
|
out.requiresGrad = true;
|
|
@@ -773,7 +842,7 @@ class Tensor {
|
|
|
773
842
|
}
|
|
774
843
|
// Tensor squeeze
|
|
775
844
|
squeeze(dims) {
|
|
776
|
-
if (
|
|
845
|
+
if (this.shape.length === 0)
|
|
777
846
|
return this;
|
|
778
847
|
if (typeof dims === "number") {
|
|
779
848
|
dims = [dims];
|
|
@@ -807,7 +876,9 @@ class Tensor {
|
|
|
807
876
|
shape: outShape,
|
|
808
877
|
strides: outStrides,
|
|
809
878
|
offset: this.offset,
|
|
810
|
-
|
|
879
|
+
numel: this.numel,
|
|
880
|
+
device: this.device,
|
|
881
|
+
dtype: this.dtype
|
|
811
882
|
});
|
|
812
883
|
// Set up gradient if needed
|
|
813
884
|
if (this.requiresGrad) {
|
|
@@ -830,9 +901,6 @@ class Tensor {
|
|
|
830
901
|
dim += this.shape.length;
|
|
831
902
|
}
|
|
832
903
|
let thisValue = this.value;
|
|
833
|
-
if (typeof thisValue === "number") {
|
|
834
|
-
thisValue = [thisValue];
|
|
835
|
-
}
|
|
836
904
|
// Insert size-1 dimension at specified position
|
|
837
905
|
const newShape = [...this.shape];
|
|
838
906
|
newShape.splice(dim, 0, 1);
|
|
@@ -852,7 +920,9 @@ class Tensor {
|
|
|
852
920
|
shape: newShape,
|
|
853
921
|
strides: newStrides,
|
|
854
922
|
offset: this.offset,
|
|
855
|
-
|
|
923
|
+
numel: this.numel,
|
|
924
|
+
device: this.device,
|
|
925
|
+
dtype: this.dtype
|
|
856
926
|
});
|
|
857
927
|
// Set up gradient if needed
|
|
858
928
|
if (this.requiresGrad) {
|
|
@@ -866,10 +936,13 @@ class Tensor {
|
|
|
866
936
|
}
|
|
867
937
|
// Generic reduction operation handler
|
|
868
938
|
static reduce(tensor, dims, keepDims, config) {
|
|
869
|
-
if (
|
|
939
|
+
if (tensor.shape.length === 0)
|
|
870
940
|
return tensor;
|
|
871
941
|
if (typeof dims === "undefined") {
|
|
872
|
-
dims = Array
|
|
942
|
+
dims = new Array(tensor.shape.length);
|
|
943
|
+
for (let index = 0; index < dims.length; index++) {
|
|
944
|
+
dims[index] = index;
|
|
945
|
+
}
|
|
873
946
|
}
|
|
874
947
|
if (Array.isArray(dims)) {
|
|
875
948
|
dims = Tensor.normalizeDims(dims, tensor.shape.length);
|
|
@@ -880,11 +953,11 @@ class Tensor {
|
|
|
880
953
|
}
|
|
881
954
|
return keepDims ? reducedThis : reducedThis.squeeze(dims);
|
|
882
955
|
}
|
|
956
|
+
const dimSize = tensor.shape[dims];
|
|
883
957
|
const outputShape = tensor.shape.map((dim, i) => dims === i ? 1 : dim);
|
|
884
958
|
const outputStrides = Tensor.getStrides(outputShape);
|
|
885
|
-
const outputSize =
|
|
886
|
-
const outputValue = new
|
|
887
|
-
const outputCounters = config.needsCounters ? new Array(outputSize).fill(0) : [];
|
|
959
|
+
const outputSize = tensor.numel / dimSize;
|
|
960
|
+
const outputValue = new dtype_1.TypedArray[tensor.dtype](outputSize).fill(config.identity);
|
|
888
961
|
const originalSize = tensor.numel;
|
|
889
962
|
const originalValue = tensor.value;
|
|
890
963
|
const linearStrides = Tensor.getStrides(tensor.shape);
|
|
@@ -899,24 +972,27 @@ class Tensor {
|
|
|
899
972
|
const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides);
|
|
900
973
|
// Apply op
|
|
901
974
|
outputValue[outFlatIndex] = config.operation(outputValue[outFlatIndex], originalValue[realFlatIndex]);
|
|
902
|
-
// Count el if needed
|
|
903
|
-
if (config.needsCounters) {
|
|
904
|
-
outputCounters[outFlatIndex]++;
|
|
905
|
-
}
|
|
906
975
|
}
|
|
907
976
|
// Post-process if needed (e.g., divide by count for mean)
|
|
908
977
|
if (config.postProcess) {
|
|
909
|
-
config.postProcess({ values: outputValue,
|
|
978
|
+
config.postProcess({ values: outputValue, dimSize });
|
|
910
979
|
}
|
|
911
|
-
const out = new Tensor(outputValue, {
|
|
980
|
+
const out = new Tensor(outputValue, {
|
|
981
|
+
shape: outputShape,
|
|
982
|
+
strides: outputStrides,
|
|
983
|
+
offset: 0,
|
|
984
|
+
numel: outputSize,
|
|
985
|
+
device: tensor.device,
|
|
986
|
+
dtype: tensor.dtype
|
|
987
|
+
});
|
|
912
988
|
// Gradient setup
|
|
913
989
|
if (tensor.requiresGrad) {
|
|
914
990
|
out.requiresGrad = true;
|
|
915
991
|
out.children.push(tensor);
|
|
916
992
|
out.gradFn = () => {
|
|
917
|
-
let shareCounts = [];
|
|
993
|
+
let shareCounts = new dtype_1.TypedArray[tensor.dtype]();
|
|
918
994
|
if (config.needsShareCounts) {
|
|
919
|
-
shareCounts = new
|
|
995
|
+
shareCounts = new dtype_1.TypedArray[tensor.dtype](outputSize).fill(0);
|
|
920
996
|
for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
|
|
921
997
|
// Convert linear index to coordinates using contiguous strides
|
|
922
998
|
const coords = Tensor.indexToCoords(flatIndex, linearStrides);
|
|
@@ -929,7 +1005,7 @@ class Tensor {
|
|
|
929
1005
|
shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
|
|
930
1006
|
}
|
|
931
1007
|
}
|
|
932
|
-
const gradValue = new
|
|
1008
|
+
const gradValue = new dtype_1.TypedArray[tensor.dtype](originalSize);
|
|
933
1009
|
for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
|
|
934
1010
|
// Convert linear index to coordinates using contiguous strides
|
|
935
1011
|
const coords = Tensor.indexToCoords(flatIndex, linearStrides);
|
|
@@ -941,13 +1017,19 @@ class Tensor {
|
|
|
941
1017
|
gradValue[flatIndex] = config.gradientFn({
|
|
942
1018
|
outputValue,
|
|
943
1019
|
originalValue: tensor.value,
|
|
944
|
-
|
|
1020
|
+
dimSize,
|
|
945
1021
|
shareCounts,
|
|
946
1022
|
realIndex: realFlatIndex,
|
|
947
1023
|
outIndex: outFlatIndex
|
|
948
1024
|
});
|
|
949
1025
|
}
|
|
950
|
-
const localGrad = new Tensor(gradValue, {
|
|
1026
|
+
const localGrad = new Tensor(gradValue, {
|
|
1027
|
+
shape: tensor.shape,
|
|
1028
|
+
offset: 0,
|
|
1029
|
+
numel: tensor.numel,
|
|
1030
|
+
device: tensor.device,
|
|
1031
|
+
dtype: tensor.dtype
|
|
1032
|
+
});
|
|
951
1033
|
Tensor.addGrad(tensor, out.grad.mul(localGrad));
|
|
952
1034
|
};
|
|
953
1035
|
}
|
|
@@ -972,13 +1054,12 @@ class Tensor {
|
|
|
972
1054
|
return Tensor.reduce(this, dims, keepDims, {
|
|
973
1055
|
identity: 0,
|
|
974
1056
|
operation: (a, b) => a + b,
|
|
975
|
-
|
|
976
|
-
postProcess: ({ values, counters }) => {
|
|
1057
|
+
postProcess: ({ values, dimSize }) => {
|
|
977
1058
|
for (let i = 0; i < values.length; i++) {
|
|
978
|
-
values[i] /=
|
|
1059
|
+
values[i] /= dimSize;
|
|
979
1060
|
}
|
|
980
1061
|
},
|
|
981
|
-
gradientFn: ({
|
|
1062
|
+
gradientFn: ({ dimSize }) => 1 / dimSize
|
|
982
1063
|
});
|
|
983
1064
|
}
|
|
984
1065
|
max(dims, keepDims = false) {
|
|
@@ -1017,7 +1098,7 @@ class Tensor {
|
|
|
1017
1098
|
}
|
|
1018
1099
|
// Tensor softmax
|
|
1019
1100
|
softmax(dim = -1) {
|
|
1020
|
-
if (
|
|
1101
|
+
if (this.shape.length === 0)
|
|
1021
1102
|
return this;
|
|
1022
1103
|
// Handle negative indexing
|
|
1023
1104
|
if (dim < 0) {
|
|
@@ -1035,7 +1116,7 @@ class Tensor {
|
|
|
1035
1116
|
}
|
|
1036
1117
|
// Tensor softmin
|
|
1037
1118
|
softmin(dim = -1) {
|
|
1038
|
-
if (
|
|
1119
|
+
if (this.shape.length === 0)
|
|
1039
1120
|
return this;
|
|
1040
1121
|
// Handle negative indexing
|
|
1041
1122
|
if (dim < 0) {
|
|
@@ -1435,10 +1516,11 @@ class Tensor {
|
|
|
1435
1516
|
const matBCols = other.shape[1];
|
|
1436
1517
|
if (matACols !== matBRows)
|
|
1437
1518
|
throw new Error("Invalid matrices shape for multiplication");
|
|
1519
|
+
const matCDtype = Tensor.getResultDtype(this.dtype, other.dtype);
|
|
1438
1520
|
const matCShape = [matARows, matBCols];
|
|
1439
1521
|
const matCStrides = Tensor.getStrides(matCShape);
|
|
1440
1522
|
const matCSize = Tensor.shapeToSize(matCShape);
|
|
1441
|
-
const matC = new
|
|
1523
|
+
const matC = new dtype_1.TypedArray[matCDtype](matCSize).fill(0);
|
|
1442
1524
|
for (let i = 0; i < matARows; i++) {
|
|
1443
1525
|
for (let j = 0; j < matBCols; j++) {
|
|
1444
1526
|
for (let k = 0; k < matACols; k++) {
|
|
@@ -1449,7 +1531,14 @@ class Tensor {
|
|
|
1449
1531
|
}
|
|
1450
1532
|
}
|
|
1451
1533
|
}
|
|
1452
|
-
const out = new Tensor(matC, {
|
|
1534
|
+
const out = new Tensor(matC, {
|
|
1535
|
+
shape: matCShape,
|
|
1536
|
+
strides: matCStrides,
|
|
1537
|
+
offset: 0,
|
|
1538
|
+
numel: matCSize,
|
|
1539
|
+
device: this.device,
|
|
1540
|
+
dtype: matCDtype
|
|
1541
|
+
});
|
|
1453
1542
|
if (this.requiresGrad) {
|
|
1454
1543
|
out.requiresGrad = true;
|
|
1455
1544
|
out.children.push(this);
|
|
@@ -1490,10 +1579,11 @@ class Tensor {
|
|
|
1490
1579
|
const batchBCols = other.shape[2];
|
|
1491
1580
|
if (batchACols !== batchBRows)
|
|
1492
1581
|
throw new Error("Invalid matrices shape for multiplication");
|
|
1582
|
+
const batchCDtype = Tensor.getResultDtype(this.dtype, other.dtype);
|
|
1493
1583
|
const batchCShape = [batchSize, batchARows, batchBCols];
|
|
1494
1584
|
const batchCStrides = Tensor.getStrides(batchCShape);
|
|
1495
1585
|
const batchCSize = Tensor.shapeToSize(batchCShape);
|
|
1496
|
-
const batchC = new
|
|
1586
|
+
const batchC = new dtype_1.TypedArray[batchCDtype](batchCSize).fill(0);
|
|
1497
1587
|
for (let q = 0; q < batchSize; q++) {
|
|
1498
1588
|
for (let i = 0; i < batchARows; i++) {
|
|
1499
1589
|
for (let j = 0; j < batchBCols; j++) {
|
|
@@ -1506,7 +1596,14 @@ class Tensor {
|
|
|
1506
1596
|
}
|
|
1507
1597
|
}
|
|
1508
1598
|
}
|
|
1509
|
-
const out = new Tensor(batchC, {
|
|
1599
|
+
const out = new Tensor(batchC, {
|
|
1600
|
+
shape: batchCShape,
|
|
1601
|
+
strides: batchCStrides,
|
|
1602
|
+
offset: 0,
|
|
1603
|
+
numel: batchCSize,
|
|
1604
|
+
device: this.device,
|
|
1605
|
+
dtype: batchCDtype
|
|
1606
|
+
});
|
|
1510
1607
|
if (this.requiresGrad) {
|
|
1511
1608
|
out.requiresGrad = true;
|
|
1512
1609
|
out.children.push(this);
|
|
@@ -1583,10 +1680,11 @@ class Tensor {
|
|
|
1583
1680
|
const offsetSize = Tensor.shapeToSize(offsetShape);
|
|
1584
1681
|
const offsetStrides = Tensor.getStrides(offsetShape);
|
|
1585
1682
|
// Output shape, strides, size, value
|
|
1683
|
+
const outputDtype = Tensor.getResultDtype(this.dtype, other.dtype);
|
|
1586
1684
|
const outputShape = [...offsetShape, batchARows, batchBCols];
|
|
1587
1685
|
const outputStrides = Tensor.getStrides(outputShape);
|
|
1588
1686
|
const outputSize = Tensor.shapeToSize(outputShape);
|
|
1589
|
-
const outputValue = new
|
|
1687
|
+
const outputValue = new dtype_1.TypedArray[outputDtype](outputSize).fill(0);
|
|
1590
1688
|
const outputOffsetStrides = outputStrides.slice(0, -2);
|
|
1591
1689
|
// Loop through outer dims and do matmul on two outer-most dims
|
|
1592
1690
|
for (let index = 0; index < offsetSize; index++) {
|
|
@@ -1605,7 +1703,14 @@ class Tensor {
|
|
|
1605
1703
|
}
|
|
1606
1704
|
}
|
|
1607
1705
|
}
|
|
1608
|
-
const out = new Tensor(outputValue, {
|
|
1706
|
+
const out = new Tensor(outputValue, {
|
|
1707
|
+
shape: outputShape,
|
|
1708
|
+
strides: outputStrides,
|
|
1709
|
+
offset: 0,
|
|
1710
|
+
numel: outputSize,
|
|
1711
|
+
device: this.device,
|
|
1712
|
+
dtype: outputDtype
|
|
1713
|
+
});
|
|
1609
1714
|
if (this.requiresGrad) {
|
|
1610
1715
|
out.requiresGrad = true;
|
|
1611
1716
|
out.children.push(this);
|
|
@@ -1647,7 +1752,7 @@ class Tensor {
|
|
|
1647
1752
|
const maskShape = this.shape.slice(-2);
|
|
1648
1753
|
const maskStrides = Tensor.getStrides(maskShape);
|
|
1649
1754
|
const maskSize = Tensor.shapeToSize(maskShape);
|
|
1650
|
-
const maskValue = new
|
|
1755
|
+
const maskValue = new dtype_1.TypedArray[this.dtype](maskSize).fill(1);
|
|
1651
1756
|
const [rows, cols] = maskShape;
|
|
1652
1757
|
for (let i = 0; i < rows; i++) {
|
|
1653
1758
|
const maxJ = Math.min(cols, i + diagonal);
|
|
@@ -1658,8 +1763,10 @@ class Tensor {
|
|
|
1658
1763
|
const mask = new Tensor(maskValue, {
|
|
1659
1764
|
shape: maskShape,
|
|
1660
1765
|
strides: maskStrides,
|
|
1766
|
+
offset: 0,
|
|
1661
1767
|
numel: maskSize,
|
|
1662
|
-
device: this.device
|
|
1768
|
+
device: this.device,
|
|
1769
|
+
dtype: this.dtype
|
|
1663
1770
|
});
|
|
1664
1771
|
return this.mul(mask);
|
|
1665
1772
|
}
|
|
@@ -1671,7 +1778,7 @@ class Tensor {
|
|
|
1671
1778
|
const maskShape = this.shape.slice(-2);
|
|
1672
1779
|
const maskStrides = Tensor.getStrides(maskShape);
|
|
1673
1780
|
const maskSize = Tensor.shapeToSize(maskShape);
|
|
1674
|
-
const maskValue = new
|
|
1781
|
+
const maskValue = new dtype_1.TypedArray[this.dtype](maskSize).fill(0);
|
|
1675
1782
|
const [rows, cols] = maskShape;
|
|
1676
1783
|
for (let i = 0; i < rows; i++) {
|
|
1677
1784
|
const maxJ = Math.min(cols, i + diagonal + 1);
|
|
@@ -1682,8 +1789,10 @@ class Tensor {
|
|
|
1682
1789
|
const mask = new Tensor(maskValue, {
|
|
1683
1790
|
shape: maskShape,
|
|
1684
1791
|
strides: maskStrides,
|
|
1792
|
+
offset: 0,
|
|
1685
1793
|
numel: maskSize,
|
|
1686
|
-
device: this.device
|
|
1794
|
+
device: this.device,
|
|
1795
|
+
dtype: this.dtype
|
|
1687
1796
|
});
|
|
1688
1797
|
return this.mul(mask);
|
|
1689
1798
|
}
|
|
@@ -1698,16 +1807,28 @@ class Tensor {
|
|
|
1698
1807
|
return new Tensor(num, options);
|
|
1699
1808
|
const outputSize = Tensor.shapeToSize(shape);
|
|
1700
1809
|
const outputValue = new Array(outputSize).fill(num);
|
|
1701
|
-
return new Tensor(outputValue, {
|
|
1810
|
+
return new Tensor(outputValue, {
|
|
1811
|
+
shape,
|
|
1812
|
+
offset: 0,
|
|
1813
|
+
numel: outputSize,
|
|
1814
|
+
...options
|
|
1815
|
+
});
|
|
1702
1816
|
}
|
|
1703
1817
|
// Utility to create a new tensor with shape of another tensor, filled with a number
|
|
1704
1818
|
static fullLike(tensor, num, options = {}) {
|
|
1705
|
-
if (
|
|
1706
|
-
return new Tensor(num,
|
|
1819
|
+
if (tensor.shape.length === 0)
|
|
1820
|
+
return new Tensor(num, {
|
|
1821
|
+
offset: 0,
|
|
1822
|
+
device: tensor.device,
|
|
1823
|
+
dtype: tensor.dtype,
|
|
1824
|
+
...options
|
|
1825
|
+
});
|
|
1707
1826
|
return new Tensor(new Array(tensor.numel).fill(num), {
|
|
1708
1827
|
shape: tensor.shape,
|
|
1828
|
+
offset: 0,
|
|
1709
1829
|
numel: tensor.numel,
|
|
1710
1830
|
device: tensor.device,
|
|
1831
|
+
dtype: tensor.dtype,
|
|
1711
1832
|
...options
|
|
1712
1833
|
});
|
|
1713
1834
|
}
|
|
@@ -1717,16 +1838,28 @@ class Tensor {
|
|
|
1717
1838
|
return new Tensor(1, options);
|
|
1718
1839
|
const outputSize = Tensor.shapeToSize(shape);
|
|
1719
1840
|
const outputValue = new Array(outputSize).fill(1);
|
|
1720
|
-
return new Tensor(outputValue, {
|
|
1841
|
+
return new Tensor(outputValue, {
|
|
1842
|
+
shape,
|
|
1843
|
+
offset: 0,
|
|
1844
|
+
numel: outputSize,
|
|
1845
|
+
...options
|
|
1846
|
+
});
|
|
1721
1847
|
}
|
|
1722
1848
|
// Utility to create a new tensor with shape of another tensor, filled with 1
|
|
1723
1849
|
static onesLike(tensor, options = {}) {
|
|
1724
|
-
if (
|
|
1725
|
-
return new Tensor(1,
|
|
1850
|
+
if (tensor.shape.length === 0)
|
|
1851
|
+
return new Tensor(1, {
|
|
1852
|
+
offset: 0,
|
|
1853
|
+
device: tensor.device,
|
|
1854
|
+
dtype: tensor.dtype,
|
|
1855
|
+
...options
|
|
1856
|
+
});
|
|
1726
1857
|
return new Tensor(new Array(tensor.numel).fill(1), {
|
|
1727
1858
|
shape: tensor.shape,
|
|
1859
|
+
offset: 0,
|
|
1728
1860
|
numel: tensor.numel,
|
|
1729
1861
|
device: tensor.device,
|
|
1862
|
+
dtype: tensor.dtype,
|
|
1730
1863
|
...options
|
|
1731
1864
|
});
|
|
1732
1865
|
}
|
|
@@ -1736,16 +1869,28 @@ class Tensor {
|
|
|
1736
1869
|
return new Tensor(0, options);
|
|
1737
1870
|
const outputSize = Tensor.shapeToSize(shape);
|
|
1738
1871
|
const outputValue = new Array(outputSize).fill(0);
|
|
1739
|
-
return new Tensor(outputValue, {
|
|
1872
|
+
return new Tensor(outputValue, {
|
|
1873
|
+
shape,
|
|
1874
|
+
offset: 0,
|
|
1875
|
+
numel: outputSize,
|
|
1876
|
+
...options
|
|
1877
|
+
});
|
|
1740
1878
|
}
|
|
1741
1879
|
// Utility to create a new tensor with shape of another tensor, filled with 0
|
|
1742
1880
|
static zerosLike(tensor, options = {}) {
|
|
1743
|
-
if (
|
|
1744
|
-
return new Tensor(0,
|
|
1881
|
+
if (tensor.shape.length === 0)
|
|
1882
|
+
return new Tensor(0, {
|
|
1883
|
+
offset: 0,
|
|
1884
|
+
device: tensor.device,
|
|
1885
|
+
dtype: tensor.dtype,
|
|
1886
|
+
...options
|
|
1887
|
+
});
|
|
1745
1888
|
return new Tensor(new Array(tensor.numel).fill(0), {
|
|
1746
1889
|
shape: tensor.shape,
|
|
1890
|
+
offset: 0,
|
|
1747
1891
|
numel: tensor.numel,
|
|
1748
1892
|
device: tensor.device,
|
|
1893
|
+
dtype: tensor.dtype,
|
|
1749
1894
|
...options
|
|
1750
1895
|
});
|
|
1751
1896
|
}
|
|
@@ -1758,20 +1903,32 @@ class Tensor {
|
|
|
1758
1903
|
for (let index = 0; index < outputValue.length; index++) {
|
|
1759
1904
|
outputValue[index] = (0, utils_1.randUniform)();
|
|
1760
1905
|
}
|
|
1761
|
-
return new Tensor(outputValue, {
|
|
1906
|
+
return new Tensor(outputValue, {
|
|
1907
|
+
shape,
|
|
1908
|
+
offset: 0,
|
|
1909
|
+
numel: outputSize,
|
|
1910
|
+
...options
|
|
1911
|
+
});
|
|
1762
1912
|
}
|
|
1763
1913
|
// Utility to create a new tensor with shape of another tensor, filled with a random number with uniform distribution from 0 to 1
|
|
1764
1914
|
static randLike(tensor, options = {}) {
|
|
1765
|
-
if (
|
|
1766
|
-
return new Tensor((0, utils_1.randUniform)(),
|
|
1915
|
+
if (tensor.shape.length === 0)
|
|
1916
|
+
return new Tensor((0, utils_1.randUniform)(), {
|
|
1917
|
+
offset: 0,
|
|
1918
|
+
device: tensor.device,
|
|
1919
|
+
dtype: tensor.dtype,
|
|
1920
|
+
...options
|
|
1921
|
+
});
|
|
1767
1922
|
const outputValue = new Array(tensor.numel);
|
|
1768
1923
|
for (let index = 0; index < outputValue.length; index++) {
|
|
1769
1924
|
outputValue[index] = (0, utils_1.randUniform)();
|
|
1770
1925
|
}
|
|
1771
1926
|
return new Tensor(outputValue, {
|
|
1772
1927
|
shape: tensor.shape,
|
|
1928
|
+
offset: 0,
|
|
1773
1929
|
numel: tensor.numel,
|
|
1774
1930
|
device: tensor.device,
|
|
1931
|
+
dtype: tensor.dtype,
|
|
1775
1932
|
...options
|
|
1776
1933
|
});
|
|
1777
1934
|
}
|
|
@@ -1784,20 +1941,32 @@ class Tensor {
|
|
|
1784
1941
|
for (let index = 0; index < outputValue.length; index++) {
|
|
1785
1942
|
outputValue[index] = (0, utils_1.randNormal)();
|
|
1786
1943
|
}
|
|
1787
|
-
return new Tensor(outputValue, {
|
|
1944
|
+
return new Tensor(outputValue, {
|
|
1945
|
+
shape,
|
|
1946
|
+
offset: 0,
|
|
1947
|
+
numel: outputSize,
|
|
1948
|
+
...options
|
|
1949
|
+
});
|
|
1788
1950
|
}
|
|
1789
1951
|
// Utility to create a new tensor with shape of another tensor, filled with a random number with normal distribution of mean=0 and stddev=1
|
|
1790
1952
|
static randnLike(tensor, options = {}) {
|
|
1791
|
-
if (
|
|
1792
|
-
return new Tensor((0, utils_1.randNormal)(),
|
|
1953
|
+
if (tensor.shape.length === 0)
|
|
1954
|
+
return new Tensor((0, utils_1.randNormal)(), {
|
|
1955
|
+
offset: 0,
|
|
1956
|
+
device: tensor.device,
|
|
1957
|
+
dtype: tensor.dtype,
|
|
1958
|
+
...options
|
|
1959
|
+
});
|
|
1793
1960
|
const outputValue = new Array(tensor.numel);
|
|
1794
1961
|
for (let index = 0; index < outputValue.length; index++) {
|
|
1795
1962
|
outputValue[index] = (0, utils_1.randNormal)();
|
|
1796
1963
|
}
|
|
1797
1964
|
return new Tensor(outputValue, {
|
|
1798
1965
|
shape: tensor.shape,
|
|
1966
|
+
offset: 0,
|
|
1799
1967
|
numel: tensor.numel,
|
|
1800
1968
|
device: tensor.device,
|
|
1969
|
+
dtype: tensor.dtype,
|
|
1801
1970
|
...options
|
|
1802
1971
|
});
|
|
1803
1972
|
}
|
|
@@ -1810,20 +1979,32 @@ class Tensor {
|
|
|
1810
1979
|
for (let index = 0; index < outputValue.length; index++) {
|
|
1811
1980
|
outputValue[index] = (0, utils_1.randInt)(low, high);
|
|
1812
1981
|
}
|
|
1813
|
-
return new Tensor(outputValue, {
|
|
1982
|
+
return new Tensor(outputValue, {
|
|
1983
|
+
shape,
|
|
1984
|
+
offset: 0,
|
|
1985
|
+
numel: outputSize,
|
|
1986
|
+
...options
|
|
1987
|
+
});
|
|
1814
1988
|
}
|
|
1815
1989
|
// Utility to create a new tensor with shape of another tensor, filled with a random integer between low and high
|
|
1816
1990
|
static randintLike(tensor, low, high, options = {}) {
|
|
1817
|
-
if (
|
|
1818
|
-
return new Tensor((0, utils_1.randInt)(low, high),
|
|
1991
|
+
if (tensor.shape.length === 0)
|
|
1992
|
+
return new Tensor((0, utils_1.randInt)(low, high), {
|
|
1993
|
+
offset: 0,
|
|
1994
|
+
device: tensor.device,
|
|
1995
|
+
dtype: tensor.dtype,
|
|
1996
|
+
...options
|
|
1997
|
+
});
|
|
1819
1998
|
const outputValue = new Array(tensor.numel);
|
|
1820
1999
|
for (let index = 0; index < outputValue.length; index++) {
|
|
1821
2000
|
outputValue[index] = (0, utils_1.randInt)(low, high);
|
|
1822
2001
|
}
|
|
1823
2002
|
return new Tensor(outputValue, {
|
|
1824
2003
|
shape: tensor.shape,
|
|
2004
|
+
offset: 0,
|
|
1825
2005
|
numel: tensor.numel,
|
|
1826
2006
|
device: tensor.device,
|
|
2007
|
+
dtype: tensor.dtype,
|
|
1827
2008
|
...options
|
|
1828
2009
|
});
|
|
1829
2010
|
}
|
|
@@ -1834,7 +2015,12 @@ class Tensor {
|
|
|
1834
2015
|
outputValue[i] = i;
|
|
1835
2016
|
}
|
|
1836
2017
|
(0, utils_1.fyShuffle)(outputValue);
|
|
1837
|
-
return new Tensor(outputValue, {
|
|
2018
|
+
return new Tensor(outputValue, {
|
|
2019
|
+
shape: [n],
|
|
2020
|
+
offset: 0,
|
|
2021
|
+
numel: n,
|
|
2022
|
+
...options
|
|
2023
|
+
});
|
|
1838
2024
|
}
|
|
1839
2025
|
// Utility to create a new tensor filled with a random number with normal distribution of custom mean and stddev
|
|
1840
2026
|
static normal(shape, mean, stdDev, options = {}) {
|
|
@@ -1845,7 +2031,12 @@ class Tensor {
|
|
|
1845
2031
|
for (let index = 0; index < outputValue.length; index++) {
|
|
1846
2032
|
outputValue[index] = (0, utils_1.randNormal)(mean, stdDev);
|
|
1847
2033
|
}
|
|
1848
|
-
return new Tensor(outputValue, {
|
|
2034
|
+
return new Tensor(outputValue, {
|
|
2035
|
+
shape,
|
|
2036
|
+
offset: 0,
|
|
2037
|
+
numel: outputSize,
|
|
2038
|
+
...options
|
|
2039
|
+
});
|
|
1849
2040
|
}
|
|
1850
2041
|
// Utility to create a new tensor filled with a random number with uniform distribution from low to high
|
|
1851
2042
|
static uniform(shape, low, high, options = {}) {
|
|
@@ -1856,7 +2047,12 @@ class Tensor {
|
|
|
1856
2047
|
for (let index = 0; index < outputValue.length; index++) {
|
|
1857
2048
|
outputValue[index] = (0, utils_1.randUniform)(low, high);
|
|
1858
2049
|
}
|
|
1859
|
-
return new Tensor(outputValue, {
|
|
2050
|
+
return new Tensor(outputValue, {
|
|
2051
|
+
shape,
|
|
2052
|
+
offset: 0,
|
|
2053
|
+
numel: outputSize,
|
|
2054
|
+
...options
|
|
2055
|
+
});
|
|
1860
2056
|
}
|
|
1861
2057
|
// Utility to create an 1D tensor from a range incrementing with "step"
|
|
1862
2058
|
static arange(start, stop, step = 1, options = {}) {
|
|
@@ -1870,7 +2066,12 @@ class Tensor {
|
|
|
1870
2066
|
for (let index = 0; index < outputValue.length; index++) {
|
|
1871
2067
|
outputValue[index] = start + step * index;
|
|
1872
2068
|
}
|
|
1873
|
-
return new Tensor(outputValue, {
|
|
2069
|
+
return new Tensor(outputValue, {
|
|
2070
|
+
shape: outputShape,
|
|
2071
|
+
offset: 0,
|
|
2072
|
+
numel: outputSize,
|
|
2073
|
+
...options
|
|
2074
|
+
});
|
|
1874
2075
|
}
|
|
1875
2076
|
// Utility to create an 1D tensor from a range evenly spaced out with a given amount of steps
|
|
1876
2077
|
static linspace(start, stop, steps, options = {}) {
|
|
@@ -1886,7 +2087,12 @@ class Tensor {
|
|
|
1886
2087
|
}
|
|
1887
2088
|
// Ensure we hit the endpoint exactly (avoids floating point errors)
|
|
1888
2089
|
outputValue[steps - 1] = stop;
|
|
1889
|
-
return new Tensor(outputValue, {
|
|
2090
|
+
return new Tensor(outputValue, {
|
|
2091
|
+
shape: [steps],
|
|
2092
|
+
offset: 0,
|
|
2093
|
+
numel: steps,
|
|
2094
|
+
...options
|
|
2095
|
+
});
|
|
1890
2096
|
}
|
|
1891
2097
|
// Utility to create a 2D tensor with its main diagonal filled with 1s and others with 0s
|
|
1892
2098
|
static eye(n, m = n, options = {}) {
|
|
@@ -1897,7 +2103,13 @@ class Tensor {
|
|
|
1897
2103
|
for (let i = 0; i < Math.min(n, m); i++) {
|
|
1898
2104
|
outputValue[i * outputStrides[0] + i * outputStrides[1]] = 1;
|
|
1899
2105
|
}
|
|
1900
|
-
return new Tensor(outputValue, {
|
|
2106
|
+
return new Tensor(outputValue, {
|
|
2107
|
+
shape: outputShape,
|
|
2108
|
+
offset: 0,
|
|
2109
|
+
strides: outputStrides,
|
|
2110
|
+
numel: outputSize,
|
|
2111
|
+
...options
|
|
2112
|
+
});
|
|
1901
2113
|
}
|
|
1902
2114
|
// Reverse-mode autodiff call
|
|
1903
2115
|
backward(options = {}) {
|
|
@@ -1928,8 +2140,8 @@ class Tensor {
|
|
|
1928
2140
|
}
|
|
1929
2141
|
// Returns the raw number/nD array form of tensor
|
|
1930
2142
|
val() {
|
|
1931
|
-
if (
|
|
1932
|
-
return this.value;
|
|
2143
|
+
if (this.shape.length === 0)
|
|
2144
|
+
return this.value[0];
|
|
1933
2145
|
function buildNested(data, shape, strides, baseIndex = 0, dim = 0) {
|
|
1934
2146
|
if (dim === shape.length - 1) {
|
|
1935
2147
|
// Last dimension: extract elements using actual stride
|
|
@@ -1956,20 +2168,28 @@ class Tensor {
|
|
|
1956
2168
|
offset: this.offset,
|
|
1957
2169
|
numel: this.numel,
|
|
1958
2170
|
device: this.device,
|
|
2171
|
+
dtype: this.dtype,
|
|
1959
2172
|
requiresGrad: false
|
|
1960
2173
|
});
|
|
1961
2174
|
}
|
|
1962
2175
|
// Returns a copy of the tensor (with new data allocation) and keeps grad connection
|
|
1963
2176
|
clone() {
|
|
1964
2177
|
let out;
|
|
1965
|
-
if (
|
|
1966
|
-
out = new Tensor(this.value
|
|
2178
|
+
if (this.shape.length === 0) {
|
|
2179
|
+
out = new Tensor(this.value, {
|
|
2180
|
+
shape: [],
|
|
2181
|
+
strides: [],
|
|
2182
|
+
offset: 0,
|
|
2183
|
+
numel: 1,
|
|
2184
|
+
device: this.device,
|
|
2185
|
+
dtype: this.dtype
|
|
2186
|
+
});
|
|
1967
2187
|
}
|
|
1968
2188
|
else {
|
|
1969
2189
|
const contiguous = this.isContiguous();
|
|
1970
2190
|
const outputStrides = contiguous ? this.strides : Tensor.getStrides(this.shape);
|
|
1971
2191
|
const outputSize = this.numel;
|
|
1972
|
-
const outputValue = new
|
|
2192
|
+
const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
|
|
1973
2193
|
if (contiguous) {
|
|
1974
2194
|
for (let index = 0; index < outputSize; index++) {
|
|
1975
2195
|
outputValue[index] = this.value[this.offset + index];
|
|
@@ -1982,7 +2202,14 @@ class Tensor {
|
|
|
1982
2202
|
outputValue[index] = this.value[this.offset + originalIndex];
|
|
1983
2203
|
}
|
|
1984
2204
|
}
|
|
1985
|
-
out = new Tensor(outputValue, {
|
|
2205
|
+
out = new Tensor(outputValue, {
|
|
2206
|
+
shape: this.shape,
|
|
2207
|
+
strides: outputStrides,
|
|
2208
|
+
offset: 0,
|
|
2209
|
+
numel: outputSize,
|
|
2210
|
+
device: this.device,
|
|
2211
|
+
dtype: this.dtype
|
|
2212
|
+
});
|
|
1986
2213
|
}
|
|
1987
2214
|
if (this.requiresGrad) {
|
|
1988
2215
|
out.requiresGrad = true;
|
|
@@ -2009,8 +2236,23 @@ class Tensor {
|
|
|
2009
2236
|
this.value = other.value;
|
|
2010
2237
|
this.strides = other.strides;
|
|
2011
2238
|
this.offset = other.offset;
|
|
2239
|
+
this.device = other.device;
|
|
2240
|
+
this.dtype = other.dtype;
|
|
2012
2241
|
return this;
|
|
2013
2242
|
}
|
|
2243
|
+
// Op to return a new tensor casted to another dtype
|
|
2244
|
+
cast(dtype) {
|
|
2245
|
+
if (this.dtype === dtype)
|
|
2246
|
+
return this;
|
|
2247
|
+
return new Tensor(this.value, {
|
|
2248
|
+
shape: this.shape,
|
|
2249
|
+
strides: this.strides,
|
|
2250
|
+
offset: this.offset,
|
|
2251
|
+
numel: this.numel,
|
|
2252
|
+
device: this.device,
|
|
2253
|
+
dtype: dtype
|
|
2254
|
+
});
|
|
2255
|
+
}
|
|
2014
2256
|
// Holds all available backends
|
|
2015
2257
|
static backends = new Map();
|
|
2016
2258
|
// Op to transfer tensor to another device
|