catniff 0.7.4 → 0.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +14 -0
- package/dist/core.d.ts +16 -11
- package/dist/core.js +362 -115
- package/dist/dtype.d.ts +5 -0
- package/dist/dtype.js +25 -0
- package/dist/nn.d.ts +9 -8
- package/dist/nn.js +50 -50
- package/index.d.ts +1 -0
- package/index.js +2 -1
- package/package.json +1 -1
package/dist/core.js
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
3
|
exports.Tensor = void 0;
|
|
4
|
+
const dtype_1 = require("./dtype");
|
|
4
5
|
const utils_1 = require("./utils");
|
|
5
6
|
class Tensor {
|
|
6
7
|
value;
|
|
@@ -13,12 +14,16 @@ class Tensor {
|
|
|
13
14
|
gradFn;
|
|
14
15
|
children;
|
|
15
16
|
device;
|
|
17
|
+
dtype;
|
|
16
18
|
static training = false;
|
|
17
19
|
static noGrad = false;
|
|
18
20
|
static createGraph = false;
|
|
19
21
|
constructor(value, options = {}) {
|
|
20
|
-
//
|
|
21
|
-
this.
|
|
22
|
+
// Memory buffer
|
|
23
|
+
this.dtype = options.dtype || "float32";
|
|
24
|
+
const flatValue = Tensor.flattenValue(value);
|
|
25
|
+
const TypedArrayConstructor = dtype_1.TypedArray[this.dtype];
|
|
26
|
+
this.value = flatValue instanceof TypedArrayConstructor ? flatValue : TypedArrayConstructor.from(flatValue);
|
|
22
27
|
// Tensor metadata
|
|
23
28
|
this.shape = options.shape || Tensor.getShape(value);
|
|
24
29
|
this.strides = options.strides || Tensor.getStrides(this.shape);
|
|
@@ -34,31 +39,34 @@ class Tensor {
|
|
|
34
39
|
this.to_(this.device);
|
|
35
40
|
}
|
|
36
41
|
// Utility to flatten an nD array to be 1D
|
|
37
|
-
static flattenValue(
|
|
42
|
+
static flattenValue(tensorValue) {
|
|
38
43
|
// Handle scalar tensors
|
|
39
|
-
if (typeof
|
|
40
|
-
return
|
|
44
|
+
if (typeof tensorValue === "number")
|
|
45
|
+
return [tensorValue];
|
|
41
46
|
// If value is already 1D, we just need to return the value ('s reference)
|
|
42
|
-
if (typeof
|
|
43
|
-
return
|
|
47
|
+
if (typeof tensorValue[0] === "number")
|
|
48
|
+
return tensorValue;
|
|
44
49
|
// Or else recursively traverse through the nD array to flatten
|
|
45
50
|
const result = [];
|
|
46
51
|
function traverse(arr) {
|
|
47
52
|
if (typeof arr === "number") {
|
|
48
53
|
result.push(arr);
|
|
54
|
+
// Assume if we can index a value, it is an ArrayLike
|
|
49
55
|
}
|
|
50
|
-
else if (
|
|
51
|
-
arr.
|
|
56
|
+
else if (typeof arr[0] !== "undefined") {
|
|
57
|
+
for (let index = 0; index < arr.length; index++) {
|
|
58
|
+
traverse(arr[index]);
|
|
59
|
+
}
|
|
52
60
|
}
|
|
53
61
|
}
|
|
54
|
-
traverse(
|
|
62
|
+
traverse(tensorValue);
|
|
55
63
|
return result;
|
|
56
64
|
}
|
|
57
65
|
// Utility to get shape from tensor *value*
|
|
58
|
-
static getShape(
|
|
66
|
+
static getShape(tensorValue) {
|
|
59
67
|
const shape = [];
|
|
60
|
-
let subA =
|
|
61
|
-
while (
|
|
68
|
+
let subA = tensorValue;
|
|
69
|
+
while (typeof subA !== "number") {
|
|
62
70
|
shape.push(subA.length);
|
|
63
71
|
subA = subA[0];
|
|
64
72
|
}
|
|
@@ -146,17 +154,52 @@ class Tensor {
|
|
|
146
154
|
}
|
|
147
155
|
return prod;
|
|
148
156
|
}
|
|
149
|
-
|
|
157
|
+
// Utility to get best possible result type if type conflicts happen:
|
|
158
|
+
static getResultDtype(type1, type2) {
|
|
159
|
+
if (type1 === type2)
|
|
160
|
+
return type1;
|
|
161
|
+
const type1Ranking = dtype_1.dtypeHiearchy[type1];
|
|
162
|
+
const type2Ranking = dtype_1.dtypeHiearchy[type2];
|
|
163
|
+
if (type1Ranking > type2Ranking) {
|
|
164
|
+
return type1;
|
|
165
|
+
}
|
|
166
|
+
return type2;
|
|
167
|
+
}
|
|
168
|
+
// Utility to handle other tensor if an op needs a second operand
|
|
169
|
+
handleOther(other) {
|
|
170
|
+
if (other instanceof Tensor) {
|
|
171
|
+
if (this.device !== other.device) {
|
|
172
|
+
throw new Error("Can not operate on tensors that are not on the same device");
|
|
173
|
+
}
|
|
174
|
+
return other;
|
|
175
|
+
}
|
|
176
|
+
return new Tensor(other, {
|
|
177
|
+
offset: 0,
|
|
178
|
+
device: this.device,
|
|
179
|
+
dtype: this.dtype
|
|
180
|
+
});
|
|
181
|
+
}
|
|
150
182
|
// Utility for binary (two operators involved) element-wise ops
|
|
151
183
|
static elementWiseAB(tA, tB, op) {
|
|
152
|
-
|
|
153
|
-
|
|
184
|
+
const outputDtype = Tensor.getResultDtype(tA.dtype, tB.dtype);
|
|
185
|
+
// Both are scalars
|
|
186
|
+
if (tA.shape.length === 0 && tB.shape.length === 0) {
|
|
187
|
+
return new Tensor(op(tA.value[0], tB.value[0]), {
|
|
188
|
+
shape: [],
|
|
189
|
+
strides: [],
|
|
190
|
+
offset: 0,
|
|
191
|
+
numel: 1,
|
|
192
|
+
device: tA.device,
|
|
193
|
+
dtype: outputDtype
|
|
194
|
+
});
|
|
154
195
|
}
|
|
155
|
-
|
|
156
|
-
|
|
196
|
+
// First tensor is scalar
|
|
197
|
+
if (tA.shape.length === 0) {
|
|
198
|
+
return Tensor.elementWiseSelf(tB.cast(outputDtype), (a) => op(a, tA.value[0]));
|
|
157
199
|
}
|
|
158
|
-
|
|
159
|
-
|
|
200
|
+
// Second tensor is scalar
|
|
201
|
+
if (tB.shape.length === 0) {
|
|
202
|
+
return Tensor.elementWiseSelf(tA.cast(outputDtype), (a) => op(a, tB.value[0]));
|
|
160
203
|
}
|
|
161
204
|
// Pad + broadcast shape
|
|
162
205
|
const [paddedAStrides, paddedBStrides, paddedAShape, paddedBShape] = Tensor.padShape(tA.strides, tB.strides, tA.shape, tB.shape);
|
|
@@ -164,7 +207,7 @@ class Tensor {
|
|
|
164
207
|
// Get other output info
|
|
165
208
|
const outputStrides = Tensor.getStrides(outputShape);
|
|
166
209
|
const outputSize = Tensor.shapeToSize(outputShape);
|
|
167
|
-
const outputValue = new
|
|
210
|
+
const outputValue = new dtype_1.TypedArray[outputDtype](outputSize);
|
|
168
211
|
for (let i = 0; i < outputSize; i++) {
|
|
169
212
|
// Get coordinates from 1D index
|
|
170
213
|
const coordsOutput = Tensor.indexToCoords(i, outputStrides);
|
|
@@ -178,18 +221,29 @@ class Tensor {
|
|
|
178
221
|
return new Tensor(outputValue, {
|
|
179
222
|
shape: outputShape,
|
|
180
223
|
strides: outputStrides,
|
|
181
|
-
|
|
224
|
+
offset: 0,
|
|
225
|
+
numel: outputSize,
|
|
226
|
+
device: tA.device,
|
|
227
|
+
dtype: outputDtype
|
|
182
228
|
});
|
|
183
229
|
}
|
|
184
230
|
// Utility for self-inflicting element-wise ops
|
|
185
231
|
static elementWiseSelf(tA, op) {
|
|
186
|
-
|
|
187
|
-
|
|
232
|
+
// Handle scalar case
|
|
233
|
+
if (tA.shape.length === 0)
|
|
234
|
+
return new Tensor(op(tA.value[0]), {
|
|
235
|
+
shape: [],
|
|
236
|
+
strides: [],
|
|
237
|
+
offset: 0,
|
|
238
|
+
numel: 1,
|
|
239
|
+
device: tA.device,
|
|
240
|
+
dtype: tA.dtype
|
|
241
|
+
});
|
|
188
242
|
const contiguous = tA.isContiguous();
|
|
189
243
|
const outputShape = tA.shape;
|
|
190
244
|
const outputStrides = contiguous ? tA.strides : Tensor.getStrides(outputShape);
|
|
191
245
|
const outputSize = tA.numel;
|
|
192
|
-
const outputValue = new
|
|
246
|
+
const outputValue = new dtype_1.TypedArray[tA.dtype](outputSize);
|
|
193
247
|
if (contiguous) {
|
|
194
248
|
for (let index = 0; index < outputSize; index++) {
|
|
195
249
|
outputValue[index] = op(tA.value[index + tA.offset]);
|
|
@@ -202,7 +256,14 @@ class Tensor {
|
|
|
202
256
|
outputValue[index] = op(tA.value[originalIndex]);
|
|
203
257
|
}
|
|
204
258
|
}
|
|
205
|
-
return new Tensor(outputValue, {
|
|
259
|
+
return new Tensor(outputValue, {
|
|
260
|
+
shape: outputShape,
|
|
261
|
+
strides: outputStrides,
|
|
262
|
+
offset: 0,
|
|
263
|
+
numel: tA.numel,
|
|
264
|
+
device: tA.device,
|
|
265
|
+
dtype: tA.dtype
|
|
266
|
+
});
|
|
206
267
|
}
|
|
207
268
|
// Utility to do element-wise operation and build a dag node with another tensor
|
|
208
269
|
elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) {
|
|
@@ -246,16 +307,6 @@ class Tensor {
|
|
|
246
307
|
}
|
|
247
308
|
return out;
|
|
248
309
|
}
|
|
249
|
-
// Utility to handle other tensor if an op needs a second operand
|
|
250
|
-
handleOther(other) {
|
|
251
|
-
if (other instanceof Tensor) {
|
|
252
|
-
if (this.device !== other.device) {
|
|
253
|
-
throw new Error("Can not operate on tensors that are not on the same device");
|
|
254
|
-
}
|
|
255
|
-
return other;
|
|
256
|
-
}
|
|
257
|
-
return new Tensor(other, { device: this.device });
|
|
258
|
-
}
|
|
259
310
|
// Utility to add to gradient of tensor
|
|
260
311
|
static addGrad(tensor, accumGrad) {
|
|
261
312
|
const axesToSqueeze = [];
|
|
@@ -278,7 +329,7 @@ class Tensor {
|
|
|
278
329
|
tensor.grad = squeezedGrad;
|
|
279
330
|
}
|
|
280
331
|
else {
|
|
281
|
-
tensor.grad = tensor.grad.add(squeezedGrad);
|
|
332
|
+
tensor.grad = tensor.grad.add(squeezedGrad.cast(tensor.dtype));
|
|
282
333
|
}
|
|
283
334
|
}
|
|
284
335
|
static normalizeDims(dims, numDims) {
|
|
@@ -306,20 +357,27 @@ class Tensor {
|
|
|
306
357
|
}
|
|
307
358
|
contiguous() {
|
|
308
359
|
// Check if scalar
|
|
309
|
-
if (
|
|
360
|
+
if (this.shape.length === 0)
|
|
310
361
|
return this;
|
|
311
362
|
// Check if already contiguous
|
|
312
363
|
if (this.isContiguous())
|
|
313
364
|
return this;
|
|
314
365
|
const outputStrides = Tensor.getStrides(this.shape);
|
|
315
366
|
const outputSize = this.numel;
|
|
316
|
-
const outputValue = new
|
|
367
|
+
const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
|
|
317
368
|
for (let index = 0; index < outputSize; index++) {
|
|
318
369
|
const outputCoords = Tensor.indexToCoords(index, outputStrides);
|
|
319
370
|
const originalIndex = Tensor.coordsToIndex(outputCoords, this.strides);
|
|
320
371
|
outputValue[index] = this.value[this.offset + originalIndex];
|
|
321
372
|
}
|
|
322
|
-
const out = new Tensor(outputValue, {
|
|
373
|
+
const out = new Tensor(outputValue, {
|
|
374
|
+
shape: this.shape,
|
|
375
|
+
strides: outputStrides,
|
|
376
|
+
offset: 0,
|
|
377
|
+
numel: outputSize,
|
|
378
|
+
device: this.device,
|
|
379
|
+
dtype: this.dtype
|
|
380
|
+
});
|
|
323
381
|
// Gradient flow back to the original tensor
|
|
324
382
|
if (this.requiresGrad) {
|
|
325
383
|
out.requiresGrad = true;
|
|
@@ -334,7 +392,7 @@ class Tensor {
|
|
|
334
392
|
// Verify shape size
|
|
335
393
|
const originalSize = this.numel;
|
|
336
394
|
const outputSize = Tensor.shapeToSize(newShape);
|
|
337
|
-
if (originalSize !== outputSize
|
|
395
|
+
if (originalSize !== outputSize) {
|
|
338
396
|
throw new Error("Can not create view: incompatible sizes");
|
|
339
397
|
}
|
|
340
398
|
// Verify compatibility (only contiguity for now)
|
|
@@ -347,7 +405,8 @@ class Tensor {
|
|
|
347
405
|
strides: outputStrides,
|
|
348
406
|
offset: this.offset,
|
|
349
407
|
numel: outputSize,
|
|
350
|
-
device: this.device
|
|
408
|
+
device: this.device,
|
|
409
|
+
dtype: this.dtype
|
|
351
410
|
});
|
|
352
411
|
// Gradient reshaped and flow back to the original tensor
|
|
353
412
|
if (this.requiresGrad) {
|
|
@@ -423,7 +482,8 @@ class Tensor {
|
|
|
423
482
|
strides: newStrides,
|
|
424
483
|
offset: this.offset,
|
|
425
484
|
numel: this.numel,
|
|
426
|
-
device: this.device
|
|
485
|
+
device: this.device,
|
|
486
|
+
dtype: this.dtype
|
|
427
487
|
});
|
|
428
488
|
out.requiresGrad = this.requiresGrad;
|
|
429
489
|
// Handle gradient if needed
|
|
@@ -464,7 +524,8 @@ class Tensor {
|
|
|
464
524
|
strides: newStrides,
|
|
465
525
|
offset: this.offset,
|
|
466
526
|
numel: this.numel,
|
|
467
|
-
device: this.device
|
|
527
|
+
device: this.device,
|
|
528
|
+
dtype: this.dtype
|
|
468
529
|
});
|
|
469
530
|
if (this.requiresGrad) {
|
|
470
531
|
out.requiresGrad = true;
|
|
@@ -484,7 +545,7 @@ class Tensor {
|
|
|
484
545
|
}
|
|
485
546
|
// Utility for indexing with array of indices
|
|
486
547
|
indexWithArray(indices) {
|
|
487
|
-
if (
|
|
548
|
+
if (this.shape.length === 0)
|
|
488
549
|
return this;
|
|
489
550
|
indices = Tensor.normalizeDims(indices, this.shape[0]);
|
|
490
551
|
// Init necessary stuff for indexing
|
|
@@ -494,7 +555,7 @@ class Tensor {
|
|
|
494
555
|
// Init output data
|
|
495
556
|
const outputShape = [indices.length, ...reducedShape];
|
|
496
557
|
const outputSize = Tensor.shapeToSize(outputShape);
|
|
497
|
-
const outputValue = new
|
|
558
|
+
const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
|
|
498
559
|
for (let i = 0; i < indices.length; i++) {
|
|
499
560
|
const sourceRowIndex = indices[i];
|
|
500
561
|
const targetStart = i * elementsPerIndex;
|
|
@@ -507,7 +568,10 @@ class Tensor {
|
|
|
507
568
|
}
|
|
508
569
|
const out = new Tensor(outputValue, {
|
|
509
570
|
shape: outputShape,
|
|
510
|
-
|
|
571
|
+
offset: 0,
|
|
572
|
+
numel: outputSize,
|
|
573
|
+
device: this.device,
|
|
574
|
+
dtype: this.dtype
|
|
511
575
|
});
|
|
512
576
|
// Handle gradient
|
|
513
577
|
if (this.requiresGrad) {
|
|
@@ -536,13 +600,13 @@ class Tensor {
|
|
|
536
600
|
// Tensor indexing
|
|
537
601
|
index(indices) {
|
|
538
602
|
const tensorIndices = this.handleOther(indices).clone();
|
|
539
|
-
if (
|
|
540
|
-
return this.indexWithArray([tensorIndices.value]).squeeze(0);
|
|
603
|
+
if (tensorIndices.shape.length === 0) {
|
|
604
|
+
return this.indexWithArray([tensorIndices.value[0]]).squeeze(0);
|
|
541
605
|
}
|
|
542
606
|
else {
|
|
543
607
|
const originalShape = tensorIndices.shape;
|
|
544
608
|
const flatIndices = tensorIndices.value;
|
|
545
|
-
const result = this.indexWithArray(flatIndices);
|
|
609
|
+
const result = this.indexWithArray(Array.from(flatIndices));
|
|
546
610
|
// Reshape to preserve input shape
|
|
547
611
|
const outputShape = [...originalShape, ...this.shape.slice(1)];
|
|
548
612
|
return result.reshape(outputShape);
|
|
@@ -551,7 +615,7 @@ class Tensor {
|
|
|
551
615
|
// Tensor slicing
|
|
552
616
|
slice(ranges) {
|
|
553
617
|
// Handle scalars
|
|
554
|
-
if (
|
|
618
|
+
if (this.shape.length === 0)
|
|
555
619
|
return this;
|
|
556
620
|
const newShape = [];
|
|
557
621
|
const newStrides = [];
|
|
@@ -589,7 +653,8 @@ class Tensor {
|
|
|
589
653
|
shape: newShape,
|
|
590
654
|
strides: newStrides,
|
|
591
655
|
offset: newOffset,
|
|
592
|
-
device: this.device
|
|
656
|
+
device: this.device,
|
|
657
|
+
dtype: this.dtype
|
|
593
658
|
});
|
|
594
659
|
if (this.requiresGrad) {
|
|
595
660
|
out.requiresGrad = true;
|
|
@@ -651,7 +716,7 @@ class Tensor {
|
|
|
651
716
|
expand(newShape) {
|
|
652
717
|
// Handle scalars
|
|
653
718
|
let self = this;
|
|
654
|
-
if (
|
|
719
|
+
if (this.shape.length === 0) {
|
|
655
720
|
self = self.unsqueeze(0);
|
|
656
721
|
}
|
|
657
722
|
// Pad shapes to same length
|
|
@@ -675,8 +740,8 @@ class Tensor {
|
|
|
675
740
|
shape: targetShape,
|
|
676
741
|
strides: newStrides,
|
|
677
742
|
offset: self.offset,
|
|
678
|
-
|
|
679
|
-
|
|
743
|
+
device: self.device,
|
|
744
|
+
dtype: self.dtype
|
|
680
745
|
});
|
|
681
746
|
if (self.requiresGrad) {
|
|
682
747
|
out.requiresGrad = true;
|
|
@@ -691,7 +756,7 @@ class Tensor {
|
|
|
691
756
|
cat(other, dim = 0) {
|
|
692
757
|
other = this.handleOther(other);
|
|
693
758
|
// Handle scalars
|
|
694
|
-
if (
|
|
759
|
+
if (this.shape.length === 0 || other.shape.length === 0) {
|
|
695
760
|
throw new Error("Can not concatenate scalars");
|
|
696
761
|
}
|
|
697
762
|
// Handle negative indices
|
|
@@ -720,7 +785,8 @@ class Tensor {
|
|
|
720
785
|
}
|
|
721
786
|
const outputSize = Tensor.shapeToSize(outputShape);
|
|
722
787
|
const outputStrides = Tensor.getStrides(outputShape);
|
|
723
|
-
const
|
|
788
|
+
const outputDtype = Tensor.getResultDtype(this.dtype, other.dtype);
|
|
789
|
+
const outputValue = new dtype_1.TypedArray[outputDtype](outputSize);
|
|
724
790
|
for (let outIndex = 0; outIndex < outputSize; outIndex++) {
|
|
725
791
|
const coords = Tensor.indexToCoords(outIndex, outputStrides);
|
|
726
792
|
// Check which tensor this output position comes from
|
|
@@ -740,7 +806,10 @@ class Tensor {
|
|
|
740
806
|
const out = new Tensor(outputValue, {
|
|
741
807
|
shape: outputShape,
|
|
742
808
|
strides: outputStrides,
|
|
743
|
-
|
|
809
|
+
offset: 0,
|
|
810
|
+
numel: outputSize,
|
|
811
|
+
device: this.device,
|
|
812
|
+
dtype: this.dtype
|
|
744
813
|
});
|
|
745
814
|
if (this.requiresGrad) {
|
|
746
815
|
out.requiresGrad = true;
|
|
@@ -773,7 +842,7 @@ class Tensor {
|
|
|
773
842
|
}
|
|
774
843
|
// Tensor squeeze
|
|
775
844
|
squeeze(dims) {
|
|
776
|
-
if (
|
|
845
|
+
if (this.shape.length === 0)
|
|
777
846
|
return this;
|
|
778
847
|
if (typeof dims === "number") {
|
|
779
848
|
dims = [dims];
|
|
@@ -807,7 +876,9 @@ class Tensor {
|
|
|
807
876
|
shape: outShape,
|
|
808
877
|
strides: outStrides,
|
|
809
878
|
offset: this.offset,
|
|
810
|
-
|
|
879
|
+
numel: this.numel,
|
|
880
|
+
device: this.device,
|
|
881
|
+
dtype: this.dtype
|
|
811
882
|
});
|
|
812
883
|
// Set up gradient if needed
|
|
813
884
|
if (this.requiresGrad) {
|
|
@@ -830,9 +901,6 @@ class Tensor {
|
|
|
830
901
|
dim += this.shape.length;
|
|
831
902
|
}
|
|
832
903
|
let thisValue = this.value;
|
|
833
|
-
if (typeof thisValue === "number") {
|
|
834
|
-
thisValue = [thisValue];
|
|
835
|
-
}
|
|
836
904
|
// Insert size-1 dimension at specified position
|
|
837
905
|
const newShape = [...this.shape];
|
|
838
906
|
newShape.splice(dim, 0, 1);
|
|
@@ -852,7 +920,9 @@ class Tensor {
|
|
|
852
920
|
shape: newShape,
|
|
853
921
|
strides: newStrides,
|
|
854
922
|
offset: this.offset,
|
|
855
|
-
|
|
923
|
+
numel: this.numel,
|
|
924
|
+
device: this.device,
|
|
925
|
+
dtype: this.dtype
|
|
856
926
|
});
|
|
857
927
|
// Set up gradient if needed
|
|
858
928
|
if (this.requiresGrad) {
|
|
@@ -866,10 +936,13 @@ class Tensor {
|
|
|
866
936
|
}
|
|
867
937
|
// Generic reduction operation handler
|
|
868
938
|
static reduce(tensor, dims, keepDims, config) {
|
|
869
|
-
if (
|
|
939
|
+
if (tensor.shape.length === 0)
|
|
870
940
|
return tensor;
|
|
871
941
|
if (typeof dims === "undefined") {
|
|
872
|
-
dims = Array
|
|
942
|
+
dims = new Array(tensor.shape.length);
|
|
943
|
+
for (let index = 0; index < dims.length; index++) {
|
|
944
|
+
dims[index] = index;
|
|
945
|
+
}
|
|
873
946
|
}
|
|
874
947
|
if (Array.isArray(dims)) {
|
|
875
948
|
dims = Tensor.normalizeDims(dims, tensor.shape.length);
|
|
@@ -883,8 +956,8 @@ class Tensor {
|
|
|
883
956
|
const outputShape = tensor.shape.map((dim, i) => dims === i ? 1 : dim);
|
|
884
957
|
const outputStrides = Tensor.getStrides(outputShape);
|
|
885
958
|
const outputSize = Tensor.shapeToSize(outputShape);
|
|
886
|
-
const outputValue = new
|
|
887
|
-
const outputCounters = config.needsCounters ? new
|
|
959
|
+
const outputValue = new dtype_1.TypedArray[tensor.dtype](outputSize).fill(config.identity);
|
|
960
|
+
const outputCounters = config.needsCounters ? new dtype_1.TypedArray[tensor.dtype](outputSize).fill(0) : new dtype_1.TypedArray[tensor.dtype]();
|
|
888
961
|
const originalSize = tensor.numel;
|
|
889
962
|
const originalValue = tensor.value;
|
|
890
963
|
const linearStrides = Tensor.getStrides(tensor.shape);
|
|
@@ -908,15 +981,22 @@ class Tensor {
|
|
|
908
981
|
if (config.postProcess) {
|
|
909
982
|
config.postProcess({ values: outputValue, counters: outputCounters });
|
|
910
983
|
}
|
|
911
|
-
const out = new Tensor(outputValue, {
|
|
984
|
+
const out = new Tensor(outputValue, {
|
|
985
|
+
shape: outputShape,
|
|
986
|
+
strides: outputStrides,
|
|
987
|
+
offset: 0,
|
|
988
|
+
numel: outputSize,
|
|
989
|
+
device: tensor.device,
|
|
990
|
+
dtype: tensor.dtype
|
|
991
|
+
});
|
|
912
992
|
// Gradient setup
|
|
913
993
|
if (tensor.requiresGrad) {
|
|
914
994
|
out.requiresGrad = true;
|
|
915
995
|
out.children.push(tensor);
|
|
916
996
|
out.gradFn = () => {
|
|
917
|
-
let shareCounts = [];
|
|
997
|
+
let shareCounts = new dtype_1.TypedArray[tensor.dtype]();
|
|
918
998
|
if (config.needsShareCounts) {
|
|
919
|
-
shareCounts = new
|
|
999
|
+
shareCounts = new dtype_1.TypedArray[tensor.dtype](outputSize).fill(0);
|
|
920
1000
|
for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
|
|
921
1001
|
// Convert linear index to coordinates using contiguous strides
|
|
922
1002
|
const coords = Tensor.indexToCoords(flatIndex, linearStrides);
|
|
@@ -929,7 +1009,7 @@ class Tensor {
|
|
|
929
1009
|
shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
|
|
930
1010
|
}
|
|
931
1011
|
}
|
|
932
|
-
const gradValue = new
|
|
1012
|
+
const gradValue = new dtype_1.TypedArray[tensor.dtype](originalSize);
|
|
933
1013
|
for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
|
|
934
1014
|
// Convert linear index to coordinates using contiguous strides
|
|
935
1015
|
const coords = Tensor.indexToCoords(flatIndex, linearStrides);
|
|
@@ -947,7 +1027,13 @@ class Tensor {
|
|
|
947
1027
|
outIndex: outFlatIndex
|
|
948
1028
|
});
|
|
949
1029
|
}
|
|
950
|
-
const localGrad = new Tensor(gradValue, {
|
|
1030
|
+
const localGrad = new Tensor(gradValue, {
|
|
1031
|
+
shape: tensor.shape,
|
|
1032
|
+
offset: 0,
|
|
1033
|
+
numel: tensor.numel,
|
|
1034
|
+
device: tensor.device,
|
|
1035
|
+
dtype: tensor.dtype
|
|
1036
|
+
});
|
|
951
1037
|
Tensor.addGrad(tensor, out.grad.mul(localGrad));
|
|
952
1038
|
};
|
|
953
1039
|
}
|
|
@@ -1017,7 +1103,7 @@ class Tensor {
|
|
|
1017
1103
|
}
|
|
1018
1104
|
// Tensor softmax
|
|
1019
1105
|
softmax(dim = -1) {
|
|
1020
|
-
if (
|
|
1106
|
+
if (this.shape.length === 0)
|
|
1021
1107
|
return this;
|
|
1022
1108
|
// Handle negative indexing
|
|
1023
1109
|
if (dim < 0) {
|
|
@@ -1035,7 +1121,7 @@ class Tensor {
|
|
|
1035
1121
|
}
|
|
1036
1122
|
// Tensor softmin
|
|
1037
1123
|
softmin(dim = -1) {
|
|
1038
|
-
if (
|
|
1124
|
+
if (this.shape.length === 0)
|
|
1039
1125
|
return this;
|
|
1040
1126
|
// Handle negative indexing
|
|
1041
1127
|
if (dim < 0) {
|
|
@@ -1435,10 +1521,11 @@ class Tensor {
|
|
|
1435
1521
|
const matBCols = other.shape[1];
|
|
1436
1522
|
if (matACols !== matBRows)
|
|
1437
1523
|
throw new Error("Invalid matrices shape for multiplication");
|
|
1524
|
+
const matCDtype = Tensor.getResultDtype(this.dtype, other.dtype);
|
|
1438
1525
|
const matCShape = [matARows, matBCols];
|
|
1439
1526
|
const matCStrides = Tensor.getStrides(matCShape);
|
|
1440
1527
|
const matCSize = Tensor.shapeToSize(matCShape);
|
|
1441
|
-
const matC = new
|
|
1528
|
+
const matC = new dtype_1.TypedArray[matCDtype](matCSize).fill(0);
|
|
1442
1529
|
for (let i = 0; i < matARows; i++) {
|
|
1443
1530
|
for (let j = 0; j < matBCols; j++) {
|
|
1444
1531
|
for (let k = 0; k < matACols; k++) {
|
|
@@ -1449,7 +1536,14 @@ class Tensor {
|
|
|
1449
1536
|
}
|
|
1450
1537
|
}
|
|
1451
1538
|
}
|
|
1452
|
-
const out = new Tensor(matC, {
|
|
1539
|
+
const out = new Tensor(matC, {
|
|
1540
|
+
shape: matCShape,
|
|
1541
|
+
strides: matCStrides,
|
|
1542
|
+
offset: 0,
|
|
1543
|
+
numel: matCSize,
|
|
1544
|
+
device: this.device,
|
|
1545
|
+
dtype: matCDtype
|
|
1546
|
+
});
|
|
1453
1547
|
if (this.requiresGrad) {
|
|
1454
1548
|
out.requiresGrad = true;
|
|
1455
1549
|
out.children.push(this);
|
|
@@ -1490,10 +1584,11 @@ class Tensor {
|
|
|
1490
1584
|
const batchBCols = other.shape[2];
|
|
1491
1585
|
if (batchACols !== batchBRows)
|
|
1492
1586
|
throw new Error("Invalid matrices shape for multiplication");
|
|
1587
|
+
const batchCDtype = Tensor.getResultDtype(this.dtype, other.dtype);
|
|
1493
1588
|
const batchCShape = [batchSize, batchARows, batchBCols];
|
|
1494
1589
|
const batchCStrides = Tensor.getStrides(batchCShape);
|
|
1495
1590
|
const batchCSize = Tensor.shapeToSize(batchCShape);
|
|
1496
|
-
const batchC = new
|
|
1591
|
+
const batchC = new dtype_1.TypedArray[batchCDtype](batchCSize).fill(0);
|
|
1497
1592
|
for (let q = 0; q < batchSize; q++) {
|
|
1498
1593
|
for (let i = 0; i < batchARows; i++) {
|
|
1499
1594
|
for (let j = 0; j < batchBCols; j++) {
|
|
@@ -1506,7 +1601,14 @@ class Tensor {
|
|
|
1506
1601
|
}
|
|
1507
1602
|
}
|
|
1508
1603
|
}
|
|
1509
|
-
const out = new Tensor(batchC, {
|
|
1604
|
+
const out = new Tensor(batchC, {
|
|
1605
|
+
shape: batchCShape,
|
|
1606
|
+
strides: batchCStrides,
|
|
1607
|
+
offset: 0,
|
|
1608
|
+
numel: batchCSize,
|
|
1609
|
+
device: this.device,
|
|
1610
|
+
dtype: batchCDtype
|
|
1611
|
+
});
|
|
1510
1612
|
if (this.requiresGrad) {
|
|
1511
1613
|
out.requiresGrad = true;
|
|
1512
1614
|
out.children.push(this);
|
|
@@ -1583,10 +1685,11 @@ class Tensor {
|
|
|
1583
1685
|
const offsetSize = Tensor.shapeToSize(offsetShape);
|
|
1584
1686
|
const offsetStrides = Tensor.getStrides(offsetShape);
|
|
1585
1687
|
// Output shape, strides, size, value
|
|
1688
|
+
const outputDtype = Tensor.getResultDtype(this.dtype, other.dtype);
|
|
1586
1689
|
const outputShape = [...offsetShape, batchARows, batchBCols];
|
|
1587
1690
|
const outputStrides = Tensor.getStrides(outputShape);
|
|
1588
1691
|
const outputSize = Tensor.shapeToSize(outputShape);
|
|
1589
|
-
const outputValue = new
|
|
1692
|
+
const outputValue = new dtype_1.TypedArray[outputDtype](outputSize).fill(0);
|
|
1590
1693
|
const outputOffsetStrides = outputStrides.slice(0, -2);
|
|
1591
1694
|
// Loop through outer dims and do matmul on two outer-most dims
|
|
1592
1695
|
for (let index = 0; index < offsetSize; index++) {
|
|
@@ -1605,7 +1708,14 @@ class Tensor {
|
|
|
1605
1708
|
}
|
|
1606
1709
|
}
|
|
1607
1710
|
}
|
|
1608
|
-
const out = new Tensor(outputValue, {
|
|
1711
|
+
const out = new Tensor(outputValue, {
|
|
1712
|
+
shape: outputShape,
|
|
1713
|
+
strides: outputStrides,
|
|
1714
|
+
offset: 0,
|
|
1715
|
+
numel: outputSize,
|
|
1716
|
+
device: this.device,
|
|
1717
|
+
dtype: outputDtype
|
|
1718
|
+
});
|
|
1609
1719
|
if (this.requiresGrad) {
|
|
1610
1720
|
out.requiresGrad = true;
|
|
1611
1721
|
out.children.push(this);
|
|
@@ -1647,7 +1757,7 @@ class Tensor {
|
|
|
1647
1757
|
const maskShape = this.shape.slice(-2);
|
|
1648
1758
|
const maskStrides = Tensor.getStrides(maskShape);
|
|
1649
1759
|
const maskSize = Tensor.shapeToSize(maskShape);
|
|
1650
|
-
const maskValue = new
|
|
1760
|
+
const maskValue = new dtype_1.TypedArray[this.dtype](maskSize).fill(1);
|
|
1651
1761
|
const [rows, cols] = maskShape;
|
|
1652
1762
|
for (let i = 0; i < rows; i++) {
|
|
1653
1763
|
const maxJ = Math.min(cols, i + diagonal);
|
|
@@ -1658,8 +1768,10 @@ class Tensor {
|
|
|
1658
1768
|
const mask = new Tensor(maskValue, {
|
|
1659
1769
|
shape: maskShape,
|
|
1660
1770
|
strides: maskStrides,
|
|
1771
|
+
offset: 0,
|
|
1661
1772
|
numel: maskSize,
|
|
1662
|
-
device: this.device
|
|
1773
|
+
device: this.device,
|
|
1774
|
+
dtype: this.dtype
|
|
1663
1775
|
});
|
|
1664
1776
|
return this.mul(mask);
|
|
1665
1777
|
}
|
|
@@ -1671,7 +1783,7 @@ class Tensor {
|
|
|
1671
1783
|
const maskShape = this.shape.slice(-2);
|
|
1672
1784
|
const maskStrides = Tensor.getStrides(maskShape);
|
|
1673
1785
|
const maskSize = Tensor.shapeToSize(maskShape);
|
|
1674
|
-
const maskValue = new
|
|
1786
|
+
const maskValue = new dtype_1.TypedArray[this.dtype](maskSize).fill(0);
|
|
1675
1787
|
const [rows, cols] = maskShape;
|
|
1676
1788
|
for (let i = 0; i < rows; i++) {
|
|
1677
1789
|
const maxJ = Math.min(cols, i + diagonal + 1);
|
|
@@ -1682,8 +1794,10 @@ class Tensor {
|
|
|
1682
1794
|
const mask = new Tensor(maskValue, {
|
|
1683
1795
|
shape: maskShape,
|
|
1684
1796
|
strides: maskStrides,
|
|
1797
|
+
offset: 0,
|
|
1685
1798
|
numel: maskSize,
|
|
1686
|
-
device: this.device
|
|
1799
|
+
device: this.device,
|
|
1800
|
+
dtype: this.dtype
|
|
1687
1801
|
});
|
|
1688
1802
|
return this.mul(mask);
|
|
1689
1803
|
}
|
|
@@ -1698,16 +1812,28 @@ class Tensor {
|
|
|
1698
1812
|
return new Tensor(num, options);
|
|
1699
1813
|
const outputSize = Tensor.shapeToSize(shape);
|
|
1700
1814
|
const outputValue = new Array(outputSize).fill(num);
|
|
1701
|
-
return new Tensor(outputValue, {
|
|
1815
|
+
return new Tensor(outputValue, {
|
|
1816
|
+
shape,
|
|
1817
|
+
offset: 0,
|
|
1818
|
+
numel: outputSize,
|
|
1819
|
+
...options
|
|
1820
|
+
});
|
|
1702
1821
|
}
|
|
1703
1822
|
// Utility to create a new tensor with shape of another tensor, filled with a number
|
|
1704
1823
|
static fullLike(tensor, num, options = {}) {
|
|
1705
|
-
if (
|
|
1706
|
-
return new Tensor(num,
|
|
1824
|
+
if (tensor.shape.length === 0)
|
|
1825
|
+
return new Tensor(num, {
|
|
1826
|
+
offset: 0,
|
|
1827
|
+
device: tensor.device,
|
|
1828
|
+
dtype: tensor.dtype,
|
|
1829
|
+
...options
|
|
1830
|
+
});
|
|
1707
1831
|
return new Tensor(new Array(tensor.numel).fill(num), {
|
|
1708
1832
|
shape: tensor.shape,
|
|
1833
|
+
offset: 0,
|
|
1709
1834
|
numel: tensor.numel,
|
|
1710
1835
|
device: tensor.device,
|
|
1836
|
+
dtype: tensor.dtype,
|
|
1711
1837
|
...options
|
|
1712
1838
|
});
|
|
1713
1839
|
}
|
|
@@ -1717,16 +1843,28 @@ class Tensor {
|
|
|
1717
1843
|
return new Tensor(1, options);
|
|
1718
1844
|
const outputSize = Tensor.shapeToSize(shape);
|
|
1719
1845
|
const outputValue = new Array(outputSize).fill(1);
|
|
1720
|
-
return new Tensor(outputValue, {
|
|
1846
|
+
return new Tensor(outputValue, {
|
|
1847
|
+
shape,
|
|
1848
|
+
offset: 0,
|
|
1849
|
+
numel: outputSize,
|
|
1850
|
+
...options
|
|
1851
|
+
});
|
|
1721
1852
|
}
|
|
1722
1853
|
// Utility to create a new tensor with shape of another tensor, filled with 1
|
|
1723
1854
|
static onesLike(tensor, options = {}) {
|
|
1724
|
-
if (
|
|
1725
|
-
return new Tensor(1,
|
|
1855
|
+
if (tensor.shape.length === 0)
|
|
1856
|
+
return new Tensor(1, {
|
|
1857
|
+
offset: 0,
|
|
1858
|
+
device: tensor.device,
|
|
1859
|
+
dtype: tensor.dtype,
|
|
1860
|
+
...options
|
|
1861
|
+
});
|
|
1726
1862
|
return new Tensor(new Array(tensor.numel).fill(1), {
|
|
1727
1863
|
shape: tensor.shape,
|
|
1864
|
+
offset: 0,
|
|
1728
1865
|
numel: tensor.numel,
|
|
1729
1866
|
device: tensor.device,
|
|
1867
|
+
dtype: tensor.dtype,
|
|
1730
1868
|
...options
|
|
1731
1869
|
});
|
|
1732
1870
|
}
|
|
@@ -1736,16 +1874,28 @@ class Tensor {
|
|
|
1736
1874
|
return new Tensor(0, options);
|
|
1737
1875
|
const outputSize = Tensor.shapeToSize(shape);
|
|
1738
1876
|
const outputValue = new Array(outputSize).fill(0);
|
|
1739
|
-
return new Tensor(outputValue, {
|
|
1877
|
+
return new Tensor(outputValue, {
|
|
1878
|
+
shape,
|
|
1879
|
+
offset: 0,
|
|
1880
|
+
numel: outputSize,
|
|
1881
|
+
...options
|
|
1882
|
+
});
|
|
1740
1883
|
}
|
|
1741
1884
|
// Utility to create a new tensor with shape of another tensor, filled with 0
|
|
1742
1885
|
static zerosLike(tensor, options = {}) {
|
|
1743
|
-
if (
|
|
1744
|
-
return new Tensor(0,
|
|
1886
|
+
if (tensor.shape.length === 0)
|
|
1887
|
+
return new Tensor(0, {
|
|
1888
|
+
offset: 0,
|
|
1889
|
+
device: tensor.device,
|
|
1890
|
+
dtype: tensor.dtype,
|
|
1891
|
+
...options
|
|
1892
|
+
});
|
|
1745
1893
|
return new Tensor(new Array(tensor.numel).fill(0), {
|
|
1746
1894
|
shape: tensor.shape,
|
|
1895
|
+
offset: 0,
|
|
1747
1896
|
numel: tensor.numel,
|
|
1748
1897
|
device: tensor.device,
|
|
1898
|
+
dtype: tensor.dtype,
|
|
1749
1899
|
...options
|
|
1750
1900
|
});
|
|
1751
1901
|
}
|
|
@@ -1758,20 +1908,32 @@ class Tensor {
|
|
|
1758
1908
|
for (let index = 0; index < outputValue.length; index++) {
|
|
1759
1909
|
outputValue[index] = (0, utils_1.randUniform)();
|
|
1760
1910
|
}
|
|
1761
|
-
return new Tensor(outputValue, {
|
|
1911
|
+
return new Tensor(outputValue, {
|
|
1912
|
+
shape,
|
|
1913
|
+
offset: 0,
|
|
1914
|
+
numel: outputSize,
|
|
1915
|
+
...options
|
|
1916
|
+
});
|
|
1762
1917
|
}
|
|
1763
1918
|
// Utility to create a new tensor with shape of another tensor, filled with a random number with uniform distribution from 0 to 1
|
|
1764
1919
|
static randLike(tensor, options = {}) {
|
|
1765
|
-
if (
|
|
1766
|
-
return new Tensor((0, utils_1.randUniform)(),
|
|
1920
|
+
if (tensor.shape.length === 0)
|
|
1921
|
+
return new Tensor((0, utils_1.randUniform)(), {
|
|
1922
|
+
offset: 0,
|
|
1923
|
+
device: tensor.device,
|
|
1924
|
+
dtype: tensor.dtype,
|
|
1925
|
+
...options
|
|
1926
|
+
});
|
|
1767
1927
|
const outputValue = new Array(tensor.numel);
|
|
1768
1928
|
for (let index = 0; index < outputValue.length; index++) {
|
|
1769
1929
|
outputValue[index] = (0, utils_1.randUniform)();
|
|
1770
1930
|
}
|
|
1771
1931
|
return new Tensor(outputValue, {
|
|
1772
1932
|
shape: tensor.shape,
|
|
1933
|
+
offset: 0,
|
|
1773
1934
|
numel: tensor.numel,
|
|
1774
1935
|
device: tensor.device,
|
|
1936
|
+
dtype: tensor.dtype,
|
|
1775
1937
|
...options
|
|
1776
1938
|
});
|
|
1777
1939
|
}
|
|
@@ -1784,20 +1946,32 @@ class Tensor {
|
|
|
1784
1946
|
for (let index = 0; index < outputValue.length; index++) {
|
|
1785
1947
|
outputValue[index] = (0, utils_1.randNormal)();
|
|
1786
1948
|
}
|
|
1787
|
-
return new Tensor(outputValue, {
|
|
1949
|
+
return new Tensor(outputValue, {
|
|
1950
|
+
shape,
|
|
1951
|
+
offset: 0,
|
|
1952
|
+
numel: outputSize,
|
|
1953
|
+
...options
|
|
1954
|
+
});
|
|
1788
1955
|
}
|
|
1789
1956
|
// Utility to create a new tensor with shape of another tensor, filled with a random number with normal distribution of mean=0 and stddev=1
|
|
1790
1957
|
static randnLike(tensor, options = {}) {
|
|
1791
|
-
if (
|
|
1792
|
-
return new Tensor((0, utils_1.randNormal)(),
|
|
1958
|
+
if (tensor.shape.length === 0)
|
|
1959
|
+
return new Tensor((0, utils_1.randNormal)(), {
|
|
1960
|
+
offset: 0,
|
|
1961
|
+
device: tensor.device,
|
|
1962
|
+
dtype: tensor.dtype,
|
|
1963
|
+
...options
|
|
1964
|
+
});
|
|
1793
1965
|
const outputValue = new Array(tensor.numel);
|
|
1794
1966
|
for (let index = 0; index < outputValue.length; index++) {
|
|
1795
1967
|
outputValue[index] = (0, utils_1.randNormal)();
|
|
1796
1968
|
}
|
|
1797
1969
|
return new Tensor(outputValue, {
|
|
1798
1970
|
shape: tensor.shape,
|
|
1971
|
+
offset: 0,
|
|
1799
1972
|
numel: tensor.numel,
|
|
1800
1973
|
device: tensor.device,
|
|
1974
|
+
dtype: tensor.dtype,
|
|
1801
1975
|
...options
|
|
1802
1976
|
});
|
|
1803
1977
|
}
|
|
@@ -1810,20 +1984,32 @@ class Tensor {
|
|
|
1810
1984
|
for (let index = 0; index < outputValue.length; index++) {
|
|
1811
1985
|
outputValue[index] = (0, utils_1.randInt)(low, high);
|
|
1812
1986
|
}
|
|
1813
|
-
return new Tensor(outputValue, {
|
|
1987
|
+
return new Tensor(outputValue, {
|
|
1988
|
+
shape,
|
|
1989
|
+
offset: 0,
|
|
1990
|
+
numel: outputSize,
|
|
1991
|
+
...options
|
|
1992
|
+
});
|
|
1814
1993
|
}
|
|
1815
1994
|
// Utility to create a new tensor with shape of another tensor, filled with a random integer between low and high
|
|
1816
1995
|
static randintLike(tensor, low, high, options = {}) {
|
|
1817
|
-
if (
|
|
1818
|
-
return new Tensor((0, utils_1.randInt)(low, high),
|
|
1996
|
+
if (tensor.shape.length === 0)
|
|
1997
|
+
return new Tensor((0, utils_1.randInt)(low, high), {
|
|
1998
|
+
offset: 0,
|
|
1999
|
+
device: tensor.device,
|
|
2000
|
+
dtype: tensor.dtype,
|
|
2001
|
+
...options
|
|
2002
|
+
});
|
|
1819
2003
|
const outputValue = new Array(tensor.numel);
|
|
1820
2004
|
for (let index = 0; index < outputValue.length; index++) {
|
|
1821
2005
|
outputValue[index] = (0, utils_1.randInt)(low, high);
|
|
1822
2006
|
}
|
|
1823
2007
|
return new Tensor(outputValue, {
|
|
1824
2008
|
shape: tensor.shape,
|
|
2009
|
+
offset: 0,
|
|
1825
2010
|
numel: tensor.numel,
|
|
1826
2011
|
device: tensor.device,
|
|
2012
|
+
dtype: tensor.dtype,
|
|
1827
2013
|
...options
|
|
1828
2014
|
});
|
|
1829
2015
|
}
|
|
@@ -1834,7 +2020,12 @@ class Tensor {
|
|
|
1834
2020
|
outputValue[i] = i;
|
|
1835
2021
|
}
|
|
1836
2022
|
(0, utils_1.fyShuffle)(outputValue);
|
|
1837
|
-
return new Tensor(outputValue, {
|
|
2023
|
+
return new Tensor(outputValue, {
|
|
2024
|
+
shape: [n],
|
|
2025
|
+
offset: 0,
|
|
2026
|
+
numel: n,
|
|
2027
|
+
...options
|
|
2028
|
+
});
|
|
1838
2029
|
}
|
|
1839
2030
|
// Utility to create a new tensor filled with a random number with normal distribution of custom mean and stddev
|
|
1840
2031
|
static normal(shape, mean, stdDev, options = {}) {
|
|
@@ -1845,7 +2036,12 @@ class Tensor {
|
|
|
1845
2036
|
for (let index = 0; index < outputValue.length; index++) {
|
|
1846
2037
|
outputValue[index] = (0, utils_1.randNormal)(mean, stdDev);
|
|
1847
2038
|
}
|
|
1848
|
-
return new Tensor(outputValue, {
|
|
2039
|
+
return new Tensor(outputValue, {
|
|
2040
|
+
shape,
|
|
2041
|
+
offset: 0,
|
|
2042
|
+
numel: outputSize,
|
|
2043
|
+
...options
|
|
2044
|
+
});
|
|
1849
2045
|
}
|
|
1850
2046
|
// Utility to create a new tensor filled with a random number with uniform distribution from low to high
|
|
1851
2047
|
static uniform(shape, low, high, options = {}) {
|
|
@@ -1856,7 +2052,12 @@ class Tensor {
|
|
|
1856
2052
|
for (let index = 0; index < outputValue.length; index++) {
|
|
1857
2053
|
outputValue[index] = (0, utils_1.randUniform)(low, high);
|
|
1858
2054
|
}
|
|
1859
|
-
return new Tensor(outputValue, {
|
|
2055
|
+
return new Tensor(outputValue, {
|
|
2056
|
+
shape,
|
|
2057
|
+
offset: 0,
|
|
2058
|
+
numel: outputSize,
|
|
2059
|
+
...options
|
|
2060
|
+
});
|
|
1860
2061
|
}
|
|
1861
2062
|
// Utility to create an 1D tensor from a range incrementing with "step"
|
|
1862
2063
|
static arange(start, stop, step = 1, options = {}) {
|
|
@@ -1870,7 +2071,12 @@ class Tensor {
|
|
|
1870
2071
|
for (let index = 0; index < outputValue.length; index++) {
|
|
1871
2072
|
outputValue[index] = start + step * index;
|
|
1872
2073
|
}
|
|
1873
|
-
return new Tensor(outputValue, {
|
|
2074
|
+
return new Tensor(outputValue, {
|
|
2075
|
+
shape: outputShape,
|
|
2076
|
+
offset: 0,
|
|
2077
|
+
numel: outputSize,
|
|
2078
|
+
...options
|
|
2079
|
+
});
|
|
1874
2080
|
}
|
|
1875
2081
|
// Utility to create an 1D tensor from a range evenly spaced out with a given amount of steps
|
|
1876
2082
|
static linspace(start, stop, steps, options = {}) {
|
|
@@ -1886,7 +2092,12 @@ class Tensor {
|
|
|
1886
2092
|
}
|
|
1887
2093
|
// Ensure we hit the endpoint exactly (avoids floating point errors)
|
|
1888
2094
|
outputValue[steps - 1] = stop;
|
|
1889
|
-
return new Tensor(outputValue, {
|
|
2095
|
+
return new Tensor(outputValue, {
|
|
2096
|
+
shape: [steps],
|
|
2097
|
+
offset: 0,
|
|
2098
|
+
numel: steps,
|
|
2099
|
+
...options
|
|
2100
|
+
});
|
|
1890
2101
|
}
|
|
1891
2102
|
// Utility to create a 2D tensor with its main diagonal filled with 1s and others with 0s
|
|
1892
2103
|
static eye(n, m = n, options = {}) {
|
|
@@ -1897,7 +2108,13 @@ class Tensor {
|
|
|
1897
2108
|
for (let i = 0; i < Math.min(n, m); i++) {
|
|
1898
2109
|
outputValue[i * outputStrides[0] + i * outputStrides[1]] = 1;
|
|
1899
2110
|
}
|
|
1900
|
-
return new Tensor(outputValue, {
|
|
2111
|
+
return new Tensor(outputValue, {
|
|
2112
|
+
shape: outputShape,
|
|
2113
|
+
offset: 0,
|
|
2114
|
+
strides: outputStrides,
|
|
2115
|
+
numel: outputSize,
|
|
2116
|
+
...options
|
|
2117
|
+
});
|
|
1901
2118
|
}
|
|
1902
2119
|
// Reverse-mode autodiff call
|
|
1903
2120
|
backward(options = {}) {
|
|
@@ -1928,8 +2145,8 @@ class Tensor {
|
|
|
1928
2145
|
}
|
|
1929
2146
|
// Returns the raw number/nD array form of tensor
|
|
1930
2147
|
val() {
|
|
1931
|
-
if (
|
|
1932
|
-
return this.value;
|
|
2148
|
+
if (this.shape.length === 0)
|
|
2149
|
+
return this.value[0];
|
|
1933
2150
|
function buildNested(data, shape, strides, baseIndex = 0, dim = 0) {
|
|
1934
2151
|
if (dim === shape.length - 1) {
|
|
1935
2152
|
// Last dimension: extract elements using actual stride
|
|
@@ -1956,20 +2173,28 @@ class Tensor {
|
|
|
1956
2173
|
offset: this.offset,
|
|
1957
2174
|
numel: this.numel,
|
|
1958
2175
|
device: this.device,
|
|
2176
|
+
dtype: this.dtype,
|
|
1959
2177
|
requiresGrad: false
|
|
1960
2178
|
});
|
|
1961
2179
|
}
|
|
1962
2180
|
// Returns a copy of the tensor (with new data allocation) and keeps grad connection
|
|
1963
2181
|
clone() {
|
|
1964
2182
|
let out;
|
|
1965
|
-
if (
|
|
1966
|
-
out = new Tensor(this.value
|
|
2183
|
+
if (this.shape.length === 0) {
|
|
2184
|
+
out = new Tensor(this.value, {
|
|
2185
|
+
shape: [],
|
|
2186
|
+
strides: [],
|
|
2187
|
+
offset: 0,
|
|
2188
|
+
numel: 1,
|
|
2189
|
+
device: this.device,
|
|
2190
|
+
dtype: this.dtype
|
|
2191
|
+
});
|
|
1967
2192
|
}
|
|
1968
2193
|
else {
|
|
1969
2194
|
const contiguous = this.isContiguous();
|
|
1970
2195
|
const outputStrides = contiguous ? this.strides : Tensor.getStrides(this.shape);
|
|
1971
2196
|
const outputSize = this.numel;
|
|
1972
|
-
const outputValue = new
|
|
2197
|
+
const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
|
|
1973
2198
|
if (contiguous) {
|
|
1974
2199
|
for (let index = 0; index < outputSize; index++) {
|
|
1975
2200
|
outputValue[index] = this.value[this.offset + index];
|
|
@@ -1982,7 +2207,14 @@ class Tensor {
|
|
|
1982
2207
|
outputValue[index] = this.value[this.offset + originalIndex];
|
|
1983
2208
|
}
|
|
1984
2209
|
}
|
|
1985
|
-
out = new Tensor(outputValue, {
|
|
2210
|
+
out = new Tensor(outputValue, {
|
|
2211
|
+
shape: this.shape,
|
|
2212
|
+
strides: outputStrides,
|
|
2213
|
+
offset: 0,
|
|
2214
|
+
numel: outputSize,
|
|
2215
|
+
device: this.device,
|
|
2216
|
+
dtype: this.dtype
|
|
2217
|
+
});
|
|
1986
2218
|
}
|
|
1987
2219
|
if (this.requiresGrad) {
|
|
1988
2220
|
out.requiresGrad = true;
|
|
@@ -2009,8 +2241,23 @@ class Tensor {
|
|
|
2009
2241
|
this.value = other.value;
|
|
2010
2242
|
this.strides = other.strides;
|
|
2011
2243
|
this.offset = other.offset;
|
|
2244
|
+
this.device = other.device;
|
|
2245
|
+
this.dtype = other.dtype;
|
|
2012
2246
|
return this;
|
|
2013
2247
|
}
|
|
2248
|
+
// Op to return a new tensor casted to another dtype
|
|
2249
|
+
cast(dtype) {
|
|
2250
|
+
if (this.dtype === dtype)
|
|
2251
|
+
return this;
|
|
2252
|
+
return new Tensor(this.value, {
|
|
2253
|
+
shape: this.shape,
|
|
2254
|
+
strides: this.strides,
|
|
2255
|
+
offset: this.offset,
|
|
2256
|
+
numel: this.numel,
|
|
2257
|
+
device: this.device,
|
|
2258
|
+
dtype: dtype
|
|
2259
|
+
});
|
|
2260
|
+
}
|
|
2014
2261
|
// Holds all available backends
|
|
2015
2262
|
static backends = new Map();
|
|
2016
2263
|
// Op to transfer tensor to another device
|