catniff 0.2.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +3 -3
- package/dist/core.js +131 -151
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
@@ -19,14 +19,14 @@ export declare class Tensor {
     static flatten(tensor: TensorValue): number[] | number;
     static getShape(tensor: TensorValue): number[];
     static getStrides(shape: number[]): number[];
-    static
+    static padShape(stridesA: number[], stridesB: number[], shapeA: number[], shapeB: number[]): number[][];
     static broadcastShapes(shapeA: number[], shapeB: number[]): number[];
     static indexToCoords(index: number, shape: number[], strides: number[]): number[];
     static coordsToIndex(coords: number[], shape: number[], strides: number[]): number;
     static elementWiseAB(tA: Tensor, tB: Tensor, op: (tA: number, tB: number) => number): Tensor;
     static elementWiseSelf(tA: Tensor, op: (tA: number) => number): Tensor;
-    elementWiseABDAG(other: TensorValue | Tensor, op: (a: number, b: number) => number, thisGrad
-    elementWiseSelfDAG(op: (a: number) => number, thisGrad
+    elementWiseABDAG(other: TensorValue | Tensor, op: (a: number, b: number) => number, thisGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor, otherGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor): Tensor;
+    elementWiseSelfDAG(op: (a: number) => number, thisGrad?: (self: Tensor, outGrad: Tensor) => Tensor): Tensor;
     static forceTensor(value: TensorValue | Tensor): Tensor;
     static addGrad(tensor: Tensor, accumGrad: Tensor): void;
     squeeze(dims?: number[] | number): Tensor;
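These declaration changes are the public surface of the release: padShape now takes strides and shapes for both operands, and the DAG helpers accept optional gradient callbacks that return tensors. A minimal usage sketch, assuming Tensor is the package's main export and that shape/strides are readable array properties (illustrative only, not taken from the package docs):

    const { Tensor } = require("catniff");

    const a = new Tensor([[1, 2], [3, 4]]);   // shape [2, 2], strides [2, 1]
    const b = new Tensor([5, 6]);             // shape [2],    strides [1]

    // Pad the lower-rank operand so both descriptions have the same rank.
    const [stridesA, stridesB, shapeA, shapeB] =
        Tensor.padShape(a.strides, b.strides, a.shape, b.shape);
    // shapeB is now [1, 2]; shapeA is an untouched copy of [2, 2].

    // Comparison helpers can omit the gradient callbacks entirely in 0.2.1.
    const mask = a.ge(b);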
package/dist/core.js
CHANGED
@@ -53,16 +53,21 @@ class Tensor {
         }
         return strides;
     }
-    // Left-pad
-    static
-    const
-
-
-
-
-
-    }
-
+    // Left-pad shape and strides for two shape to be of same length
+    static padShape(stridesA, stridesB, shapeA, shapeB) {
+        const newStrideA = [...stridesA], newStrideB = [...stridesB];
+        const newShapeA = [...shapeA], newShapeB = [...shapeB];
+        while (newStrideA.length < newStrideB.length) {
+            const newStride = newShapeA[0] * newStrideA[0];
+            newStrideA.unshift(newStride);
+            newShapeA.unshift(1);
+        }
+        while (newStrideA.length > newStrideB.length) {
+            const newStride = newShapeB[0] * newStrideB[0];
+            newStrideB.unshift(newStride);
+            newShapeB.unshift(1);
+        }
+        return [newStrideA, newStrideB, newShapeA, newShapeB];
     }
     // Broadcast shapes
     static broadcastShapes(shapeA, shapeB) {
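A quick worked example of what this helper produces (a sketch, not from the package docs; it assumes Tensor.getStrides computes contiguous row-major strides, as used elsewhere in this file):

    const shapeA = [2, 3], shapeB = [3];
    const stridesA = Tensor.getStrides(shapeA);   // [3, 1]
    const stridesB = Tensor.getStrides(shapeB);   // [1]
    const [psA, psB, pshA, pshB] = Tensor.padShape(stridesA, stridesB, shapeA, shapeB);
    // pshB === [1, 3] and psB === [3, 1]: the rank-1 operand is left-padded with a
    // size-1 dimension whose stride is shapeB[0] * stridesB[0] = 3, so the same
    // coordinates can index both operands after broadcasting.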
@@ -118,7 +123,7 @@ class Tensor {
             return Tensor.elementWiseSelf(tA, (a) => op(a, tB.value));
         }
         // Pad + broadcast shape
-        const [paddedAShape, paddedBShape] = Tensor.
+        const [paddedAStrides, paddedBStrides, paddedAShape, paddedBShape] = Tensor.padShape(tA.strides, tB.strides, tA.shape, tB.shape);
         const outputShape = Tensor.broadcastShapes(paddedAShape, paddedBShape);
         // Get other output info
         const outputStrides = Tensor.getStrides(outputShape);
@@ -128,9 +133,9 @@ class Tensor {
             // Get coordinates from 1D index
             const coordsOutput = Tensor.indexToCoords(i, outputShape, outputStrides);
             // Convert the coordinates to 1D index of flattened A with respect to A's shape
-            const indexA = Tensor.coordsToIndex(coordsOutput, paddedAShape,
+            const indexA = Tensor.coordsToIndex(coordsOutput, paddedAShape, paddedAStrides);
             // Convert the coordinates to 1D index of flattened B with respect to B's shape
-            const indexB = Tensor.coordsToIndex(coordsOutput, paddedBShape,
+            const indexB = Tensor.coordsToIndex(coordsOutput, paddedBShape, paddedBStrides);
             // Calculate with op
             outputValue[i] = op(tA.value[indexA], tB.value[indexB]);
         }
@@ -143,10 +148,10 @@ class Tensor {
     static elementWiseSelf(tA, op) {
         if (typeof tA.value === "number")
             return new Tensor(op(tA.value));
-        return new Tensor(tA.value.map(el => op(el)), { shape: tA.shape, strides: tA.strides });
+        return new Tensor(tA.value.map(el => op(el)), { shape: [...tA.shape], strides: [...tA.strides] });
     }
     // Utility to do element-wise operation and build a dag node with another tensor
-    elementWiseABDAG(other, op, thisGrad, otherGrad) {
+    elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) {
         other = Tensor.forceTensor(other);
         const out = Tensor.elementWiseAB(this, other, op);
         if (this.requiresGrad) {
@@ -159,17 +164,20 @@ class Tensor {
         }
         if (out.requiresGrad) {
             out.gradFn = () => {
-
+                // Disable gradient collecting of gradients themselves
+                const outGrad = out.grad.withGrad(false);
+                const selfNoGrad = this.withGrad(false);
+                const otherNoGrad = other.withGrad(false);
                 if (this.requiresGrad)
-
+                    Tensor.addGrad(this, thisGrad(selfNoGrad, otherNoGrad, outGrad));
                 if (other.requiresGrad)
-                    otherGrad(
+                    Tensor.addGrad(other, otherGrad(selfNoGrad, otherNoGrad, outGrad));
             };
         }
         return out;
     }
     // Utility to do self-inflicting element-wise operation and build a dag node
-    elementWiseSelfDAG(op, thisGrad) {
+    elementWiseSelfDAG(op, thisGrad = () => new Tensor(0)) {
         const out = Tensor.elementWiseSelf(this, op);
         if (this.requiresGrad) {
             out.requiresGrad = true;
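With this change the gradient callbacks return the local gradient as a tensor instead of calling Tensor.addGrad themselves, and they receive withGrad(false) copies so the backward computation is not itself recorded into the graph. A hedged sketch of a custom op built on this contract (the helper name and constant are illustrative):

    // A custom "scaled add" op: out = a + k * b, using the new callback style.
    function scaledAdd(a, b, k) {
        return a.elementWiseABDAG(
            b,
            (x, y) => x + k * y,
            (self, other, outGrad) => outGrad,          // d(out)/d(a) = 1
            (self, other, outGrad) => outGrad.mul(k)    // d(out)/d(b) = k
        );
    }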
@@ -177,9 +185,11 @@ class Tensor {
         }
         if (out.requiresGrad) {
             out.gradFn = () => {
-
+                // Disable gradient collecting of gradients themselves
+                const outGrad = out.grad.withGrad(false);
+                const selfNoGrad = this.withGrad(false);
                 if (this.requiresGrad)
-
+                    Tensor.addGrad(this, thisGrad(selfNoGrad, outGrad));
             };
         }
         return out;
@@ -207,7 +217,6 @@ class Tensor {
             }
         }
         const reducedGrad = accumGrad.sum(axesToReduce, true);
-        // console.log(accumGrad, new Tensor([[1,1,1]]));
         const squeezedGrad = reducedGrad.squeeze(axesToSqueeze);
         if (typeof tensor.grad === "undefined") {
             tensor.grad = squeezedGrad;
@@ -238,7 +247,7 @@ class Tensor {
                 throw new Error(`Can not squeeze dim with size ${dim}`);
             return !shouldSqueeze;
         });
-        const outStrides =
+        const outStrides = this.strides.filter((stride, i) => !dims.includes(i));
        const outValue = outShape.length === 0 ? this.value[0] : this.value;
         const out = new Tensor(outValue, {
             shape: outShape,
@@ -249,7 +258,7 @@ class Tensor {
             out.requiresGrad = true;
             out.children.push(this);
             out.gradFn = () => {
-                let restoredGrad = out.grad;
+                let restoredGrad = out.grad.withGrad(false);
                 for (let i = dims.length - 1; i >= 0; i--) {
                     restoredGrad = restoredGrad.unsqueeze(dims[i]);
                 }
@@ -268,19 +277,25 @@ class Tensor {
         // Insert size-1 dimension at specified position
         const newShape = [...this.shape];
         newShape.splice(dim, 0, 1);
-        //
+        // New stride
         const newStrides = [...this.strides];
-
-
-
-
+        let newDimStride;
+        if (dim === 0) {
+            // Inserting at front: use product of all original dimensions
+            newDimStride = this.shape.reduce((a, b) => a * b, 1) || 1;
+        }
+        else {
+            // Inserting elsewhere: use stride of previous dimension
+            newDimStride = this.strides[dim - 1];
+        }
+        newStrides.splice(dim, 0, newDimStride);
         const out = new Tensor(this.value, { shape: newShape, strides: newStrides });
         // Set up gradient if needed
         if (this.requiresGrad) {
             out.requiresGrad = true;
             out.children.push(this);
             out.gradFn = () => {
-                Tensor.addGrad(this, out.grad.squeeze(dim));
+                Tensor.addGrad(this, out.grad.withGrad(false).squeeze(dim));
             };
         }
         return out;
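A short worked example of the new stride rule in unsqueeze (a sketch; it assumes the contiguous strides produced by Tensor.getStrides):

    // t has shape [2, 3] and strides [3, 1].
    const t = new Tensor([[1, 2, 3], [4, 5, 6]]);
    // unsqueeze(0): inserted stride = 2 * 3 = 6       -> shape [1, 2, 3], strides [6, 3, 1]
    // unsqueeze(1): inserted stride = strides[0] = 3  -> shape [2, 1, 3], strides [3, 3, 1]
    // Either way the size-1 dimension never changes which flat element a coordinate maps to.
    const u = t.unsqueeze(1);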
@@ -300,6 +315,12 @@ class Tensor {
         const outputSize = outputShape.reduce((a, b) => a * b, 1);
         const outputValue = new Array(outputSize).fill(0);
         const originalSize = this.shape.reduce((a, b) => a * b, 1);
+        let gradShape, gradStrides, gradValue = [];
+        if (this.requiresGrad) {
+            gradShape = [...this.shape];
+            gradStrides = [...this.strides];
+            gradValue = new Array(originalSize).fill(0);
+        }
         for (let index = 0; index < originalSize; index++) {
             const coords = Tensor.indexToCoords(index, this.shape, this.strides);
             // Force 0 on reduced axes to collapse into size-1 dims
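Together with the next hunk, sum() now records which input elements fed the reduction and replays the upstream gradient through a mask of ones, i.e. d(sum)/dx_i = 1 broadcast over the reduced axes. A hedged sketch of the forward call (dims as an array and keepDims as a boolean, matching the sum call in addGrad above; passing requiresGrad as a constructor option is assumed, as withGrad does below):

    const x = new Tensor([[1, 2], [3, 4]], { requiresGrad: true });
    const s = x.sum([1], true);   // shape [2, 1]; its backward multiplies s.grad
                                  // by a ones-mask shaped like x.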
@@ -309,249 +330,200 @@ class Tensor {
             // Accumulate
             const realFlatIndex = coords.reduce((acc, val, i) => acc + val * this.strides[i], 0);
             outputValue[outFlatIndex] += this.value[realFlatIndex];
+            // Mark for gradient
+            if (this.requiresGrad) {
+                (gradValue)[realFlatIndex] = 1;
+            }
         }
         const out = new Tensor(outputValue, {
             shape: outputShape,
             strides: outputStrides
         });
+        // Set up gradient if needed
+        if (this.requiresGrad) {
+            out.requiresGrad = true;
+            out.children.push(this);
+            out.gradFn = () => {
+                const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
+                Tensor.addGrad(this, out.grad.withGrad(false).mul(localGrad));
+            };
+        }
         return keepDims ? out : out.squeeze(dims);
     }
     // Tensor element-wise addition
     add(other) {
-        return this.elementWiseABDAG(other, (a, b) => a + b, (self, other, outGrad) => {
-            Tensor.addGrad(self, outGrad);
-        }, (self, other, outGrad) => {
-            Tensor.addGrad(other, outGrad);
-        });
+        return this.elementWiseABDAG(other, (a, b) => a + b, (self, other, outGrad) => outGrad, (self, other, outGrad) => outGrad);
     }
     // Tensor element-wise subtraction
     sub(other) {
-        return this.elementWiseABDAG(other, (a, b) => a - b, (self, other, outGrad) => {
-            Tensor.addGrad(self, outGrad);
-        }, (self, other, outGrad) => {
-            Tensor.addGrad(other, outGrad.neg());
-        });
+        return this.elementWiseABDAG(other, (a, b) => a - b, (self, other, outGrad) => outGrad, (self, other, outGrad) => outGrad.neg());
     }
     // Tensor element-wise multiplication
     mul(other) {
-        return this.elementWiseABDAG(other, (a, b) => a * b, (self, other, outGrad) => {
-            Tensor.addGrad(self, outGrad.mul(other));
-        }, (self, other, outGrad) => {
-            Tensor.addGrad(other, outGrad.mul(self));
-        });
+        return this.elementWiseABDAG(other, (a, b) => a * b, (self, other, outGrad) => outGrad.mul(other), (self, other, outGrad) => outGrad.mul(self));
     }
     // Tensor element-wise power
     pow(other) {
-        return this.elementWiseABDAG(other, (a, b) => a ** b, (self, other, outGrad) => {
-            Tensor.addGrad(self, outGrad.mul(other.mul(self.pow(other.sub(1)))));
-        }, (self, other, outGrad) => {
-            Tensor.addGrad(other, outGrad.mul(self.pow(other).mul(self.log())));
-        });
+        return this.elementWiseABDAG(other, (a, b) => a ** b, (self, other, outGrad) => outGrad.mul(other.mul(self.pow(other.sub(1)))), (self, other, outGrad) => outGrad.mul(self.pow(other).mul(self.log())));
     }
     // Tensor element-wise division
     div(other) {
-        return this.elementWiseABDAG(other, (a, b) => a / b, (self, other, outGrad) => {
-            Tensor.addGrad(self, outGrad.div(other));
-        }, (self, other, outGrad) => {
-            Tensor.addGrad(other, outGrad.mul(self.neg().div(other.pow(2))));
-        });
+        return this.elementWiseABDAG(other, (a, b) => a / b, (self, other, outGrad) => outGrad.div(other), (self, other, outGrad) => outGrad.mul(self.neg().div(other.pow(2))));
     }
     // Tensor element-wise greater or equal comparison
     ge(other) {
-        return this.elementWiseABDAG(other, (a, b) => a >= b ? 1 : 0
+        return this.elementWiseABDAG(other, (a, b) => a >= b ? 1 : 0);
     }
     // Tensor element-wise less or equal comparison
     le(other) {
-        return this.elementWiseABDAG(other, (a, b) => a <= b ? 1 : 0
+        return this.elementWiseABDAG(other, (a, b) => a <= b ? 1 : 0);
     }
     // Tensor element-wise greater-than comparison
     gt(other) {
-        return this.elementWiseABDAG(other, (a, b) => a > b ? 1 : 0
+        return this.elementWiseABDAG(other, (a, b) => a > b ? 1 : 0);
     }
     // Tensor element-wise less-than comparison
     lt(other) {
-        return this.elementWiseABDAG(other, (a, b) => a < b ? 1 : 0
+        return this.elementWiseABDAG(other, (a, b) => a < b ? 1 : 0);
     }
     // Tensor element-wise equality comparison
     eq(other) {
-        return this.elementWiseABDAG(other, (a, b) => a === b ? 1 : 0
+        return this.elementWiseABDAG(other, (a, b) => a === b ? 1 : 0);
     }
     // Tensor element-wise logical and
     logicalAnd(other) {
-        return this.elementWiseABDAG(other, (a, b) => a === 1 && b === 1 ? 1 : 0
+        return this.elementWiseABDAG(other, (a, b) => a === 1 && b === 1 ? 1 : 0);
     }
     // Tensor element-wise logical or
     logicalOr(other) {
-        return this.elementWiseABDAG(other, (a, b) => a === 1 || b === 1 ? 1 : 0
+        return this.elementWiseABDAG(other, (a, b) => a === 1 || b === 1 ? 1 : 0);
     }
     // Tensor element-wise logical xor
     logicalXor(other) {
-        return this.elementWiseABDAG(other, (a, b) => (a === 1 || b === 1) && a !== b ? 1 : 0
+        return this.elementWiseABDAG(other, (a, b) => (a === 1 || b === 1) && a !== b ? 1 : 0);
     }
     // Tensor element-wise logical not
     logicalNot() {
-        return this.elementWiseSelfDAG((a) => a === 1 ? 0 : 1
+        return this.elementWiseSelfDAG((a) => a === 1 ? 0 : 1);
     }
     // Tensor element-wise bitwise and
     bitwiseAnd(other) {
-        return this.elementWiseABDAG(other, (a, b) => a & b
+        return this.elementWiseABDAG(other, (a, b) => a & b);
     }
     // Tensor element-wise bitwise or
     bitwiseOr(other) {
-        return this.elementWiseABDAG(other, (a, b) => a | b
+        return this.elementWiseABDAG(other, (a, b) => a | b);
     }
     // Tensor element-wise bitwise xor
     bitwiseXor(other) {
-        return this.elementWiseABDAG(other, (a, b) => a ^ b
+        return this.elementWiseABDAG(other, (a, b) => a ^ b);
     }
     // Tensor element-wise bitwise not
     bitwiseNot() {
-        return this.elementWiseSelfDAG((a) => ~a
+        return this.elementWiseSelfDAG((a) => ~a);
     }
     // Tensor element-wise left shift
     bitwiseLeftShift(other) {
-        return this.elementWiseABDAG(other, (a, b) => a << b
+        return this.elementWiseABDAG(other, (a, b) => a << b);
     }
     // Tensor element-wise right shift
     bitwiseRightShift(other) {
-        return this.elementWiseABDAG(other, (a, b) => a >> b
+        return this.elementWiseABDAG(other, (a, b) => a >> b);
     }
     // Tensor element-wise negation
     neg() {
-        return this.elementWiseSelfDAG((a) => -a, (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.mul(-1));
-        });
+        return this.elementWiseSelfDAG((a) => -a, (self, outGrad) => outGrad.mul(-1));
     }
     // Tensor element-wise absolute
     abs() {
-        return this.elementWiseSelfDAG((a) => Math.abs(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.mul(self.sign()));
-        });
+        return this.elementWiseSelfDAG((a) => Math.abs(a), (self, outGrad) => outGrad.mul(self.sign()));
     }
     // Tensor element-wise sign function
     sign() {
-        return this.elementWiseSelfDAG((a) => Math.sign(a)
+        return this.elementWiseSelfDAG((a) => Math.sign(a));
     }
     // Tensor element-wise sin
     sin() {
-        return this.elementWiseSelfDAG((a) => Math.sin(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.mul(self.cos()));
-        });
+        return this.elementWiseSelfDAG((a) => Math.sin(a), (self, outGrad) => outGrad.mul(self.cos()));
     }
     // Tensor element-wise cos
     cos() {
-        return this.elementWiseSelfDAG((a) => Math.cos(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.mul(self.sin().neg()));
-        });
+        return this.elementWiseSelfDAG((a) => Math.cos(a), (self, outGrad) => outGrad.mul(self.sin().neg()));
     }
     // Tensor element-wise tan
     tan() {
-        return this.elementWiseSelfDAG((a) => Math.tan(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.mul(self.tan().pow(2).add(1)));
-        });
+        return this.elementWiseSelfDAG((a) => Math.tan(a), (self, outGrad) => outGrad.mul(self.tan().pow(2).add(1)));
     }
     // Tensor element-wise asin
     asin() {
-        return this.elementWiseSelfDAG((a) => Math.asin(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.div(self.pow(2).neg().add(1).sqrt()));
-        });
+        return this.elementWiseSelfDAG((a) => Math.asin(a), (self, outGrad) => outGrad.div(self.pow(2).neg().add(1).sqrt()));
     }
     // Tensor element-wise acos
     acos() {
-        return this.elementWiseSelfDAG((a) => Math.acos(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.div(self.pow(2).neg().add(1).sqrt()).neg());
-        });
+        return this.elementWiseSelfDAG((a) => Math.acos(a), (self, outGrad) => outGrad.div(self.pow(2).neg().add(1).sqrt()).neg());
     }
     // Tensor element-wise atan
     atan() {
-        return this.elementWiseSelfDAG((a) => Math.atan(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.div(self.pow(2).add(1)));
-        });
+        return this.elementWiseSelfDAG((a) => Math.atan(a), (self, outGrad) => outGrad.div(self.pow(2).add(1)));
     }
     // Tensor element-wise sinh
     sinh() {
-        return this.elementWiseSelfDAG((a) => Math.sinh(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.mul(self.cosh()));
-        });
+        return this.elementWiseSelfDAG((a) => Math.sinh(a), (self, outGrad) => outGrad.mul(self.cosh()));
     }
     // Tensor element-wise cosh
     cosh() {
-        return this.elementWiseSelfDAG((a) => Math.cosh(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.mul(self.sinh()));
-        });
+        return this.elementWiseSelfDAG((a) => Math.cosh(a), (self, outGrad) => outGrad.mul(self.sinh()));
     }
     // Tensor element-wise asinh
     asinh() {
-        return this.elementWiseSelfDAG((a) => Math.asinh(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.div(self.pow(2).add(1).sqrt()));
-        });
+        return this.elementWiseSelfDAG((a) => Math.asinh(a), (self, outGrad) => outGrad.div(self.pow(2).add(1).sqrt()));
     }
     // Tensor element-wise acosh
     acosh() {
-        return this.elementWiseSelfDAG((a) => Math.acosh(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.div(self.add(1).sqrt().mul(self.sub(1).sqrt())));
-        });
+        return this.elementWiseSelfDAG((a) => Math.acosh(a), (self, outGrad) => outGrad.div(self.add(1).sqrt().mul(self.sub(1).sqrt())));
     }
     // Tensor element-wise atanh
     atanh() {
-        return this.elementWiseSelfDAG((a) => Math.atanh(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.div(self.pow(2).neg().add(1)));
-        });
+        return this.elementWiseSelfDAG((a) => Math.atanh(a), (self, outGrad) => outGrad.div(self.pow(2).neg().add(1)));
     }
     // Tensor element-wise square root
     sqrt() {
-        return this.elementWiseSelfDAG((a) => Math.sqrt(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.div(self.sqrt().mul(2)));
-        });
+        return this.elementWiseSelfDAG((a) => Math.sqrt(a), (self, outGrad) => outGrad.div(self.sqrt().mul(2)));
     }
     // Tensor element-wise e^x
     exp() {
-        return this.elementWiseSelfDAG((a) => Math.exp(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.mul(self.exp()));
-        });
+        return this.elementWiseSelfDAG((a) => Math.exp(a), (self, outGrad) => outGrad.mul(self.exp()));
     }
     // Tensor element-wise natural log
     log() {
-        return this.elementWiseSelfDAG((a) => Math.log(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.div(self));
-        });
+        return this.elementWiseSelfDAG((a) => Math.log(a), (self, outGrad) => outGrad.div(self));
     }
     // Tensor element-wise log2
     log2() {
-        return this.elementWiseSelfDAG((a) => Math.log2(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.div(self.mul(Math.log(2))));
-        });
+        return this.elementWiseSelfDAG((a) => Math.log2(a), (self, outGrad) => outGrad.div(self.mul(Math.log(2))));
     }
     // Tensor element-wise log10
     log10() {
-        return this.elementWiseSelfDAG((a) => Math.log10(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.div(self.mul(Math.log(10))));
-        });
+        return this.elementWiseSelfDAG((a) => Math.log10(a), (self, outGrad) => outGrad.div(self.mul(Math.log(10))));
     }
     // Tensor element-wise log(1+x)
     log1p() {
-        return this.elementWiseSelfDAG((a) => Math.log1p(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.div(self.add(1)));
-        });
+        return this.elementWiseSelfDAG((a) => Math.log1p(a), (self, outGrad) => outGrad.div(self.add(1)));
     }
     // Tensor element-wise relu
     relu() {
-        return this.elementWiseSelfDAG((a) => Math.max(a, 0), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.mul(self.ge(0)));
-        });
+        return this.elementWiseSelfDAG((a) => Math.max(a, 0), (self, outGrad) => outGrad.mul(self.ge(0)));
     }
     // Tensor element-wise sigmoid
     sigmoid() {
         return this.elementWiseSelfDAG((a) => 1 / (1 + Math.exp(-a)), (self, outGrad) => {
             const sig = self.sigmoid();
-
+            return outGrad.mul(sig).mul(sig.neg().add(1));
         });
     }
     // Tensor element-wise tanh
     tanh() {
-        return this.elementWiseSelfDAG((a) => Math.tanh(a), (self, outGrad) => {
-            Tensor.addGrad(self, outGrad.mul(self.tanh().pow(2).neg().add(1)));
-        });
+        return this.elementWiseSelfDAG((a) => Math.tanh(a), (self, outGrad) => outGrad.mul(self.tanh().pow(2).neg().add(1)));
     }
     // Transpose
     transpose(dim1, dim2) {
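The net effect of this large hunk: the reduction gradient is mask-based, every element-wise op now passes a pure callback that returns its local gradient, and pure masks (comparisons, logical and bitwise ops) register no gradient at all. A hedged sketch of how that composes, using only ops shown above:

    const x = new Tensor([-1, 0.5, 2], { requiresGrad: true });
    const y = x.relu();       // forward: [0, 0.5, 2]
    const mask = x.ge(0);     // [0, 1, 1]; comparison results carry no gradient callback
    // relu's backward is exactly outGrad.mul(self.ge(0)), i.e. the upstream gradient
    // gated by that mask.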
@@ -561,7 +533,7 @@ class Tensor {
         }
         // If same dimension, return copy
         if (dim1 === dim2) {
-            return new Tensor(this.value, { shape: this.shape, strides: this.strides });
+            return new Tensor(this.value, { shape: [...this.shape], strides: [...this.strides] });
         }
         // Create new shape and strides by swapping
         const newShape = [...this.shape];
@@ -575,7 +547,7 @@ class Tensor {
         if (this.requiresGrad) {
             out.children.push(this);
             out.gradFn = () => {
-                Tensor.addGrad(this, out.grad.transpose(dim1, dim2));
+                Tensor.addGrad(this, out.grad.withGrad(false).transpose(dim1, dim2));
             };
         }
         return out;
@@ -613,11 +585,13 @@ class Tensor {
         }
         if (out.requiresGrad) {
             out.gradFn = () => {
-                const outGrad = out.grad;
+                const outGrad = out.grad.withGrad(false);
+                const selfNoGrad = this.withGrad(false);
+                const otherNoGrad = other.withGrad(false);
                 if (this.requiresGrad)
-                    Tensor.addGrad(this, outGrad.mul(
+                    Tensor.addGrad(this, outGrad.mul(otherNoGrad));
                 if (other.requiresGrad)
-                    Tensor.addGrad(other, outGrad.mul(
+                    Tensor.addGrad(other, outGrad.mul(selfNoGrad));
             };
         }
         return out;
@@ -663,15 +637,18 @@ class Tensor {
         }
         if (out.requiresGrad) {
             out.gradFn = () => {
-                const outGrad = out.grad;
+                const outGrad = out.grad.withGrad(false);
+                const selfNoGrad = this.withGrad(false);
+                const otherNoGrad = other.withGrad(false);
                 if (this.requiresGrad)
-                    Tensor.addGrad(this, outGrad.mm(
+                    Tensor.addGrad(this, outGrad.mm(otherNoGrad.t()));
                 if (other.requiresGrad)
-                    Tensor.addGrad(other,
+                    Tensor.addGrad(other, selfNoGrad.t().mm(outGrad));
             };
         }
         return out;
     }
+    // Convert right-side 1D tensor to a vector (nx1 tensor) to do matmul
     mv(other) {
         other = Tensor.forceTensor(other);
         // Verify 2D shape
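This is the standard matrix-multiplication backward, now computed on withGrad(false) copies: for C = A.mm(B), A receives outGrad.mm(B.t()) and B receives A.t().mm(outGrad). A hedged usage sketch:

    const A = new Tensor([[1, 2], [3, 4]], { requiresGrad: true });
    const B = new Tensor([[5, 6], [7, 8]]);
    const C = A.mm(B);   // 2x2 result; when C's gradient is propagated,
                         // A.grad accumulates outGrad.mm(B.t()) without
                         // building new graph nodes along the way.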
@@ -679,8 +656,8 @@ class Tensor {
             throw new Error("Input is not a 2D and 1D tensor pair");
         }
         // MM with no grad
-        const thisMat = new Tensor(this.value, { shape: this.shape, strides: this.strides });
-        const otherMat = new Tensor(other.value, { shape: [other.shape[0], 1], strides: [
+        const thisMat = new Tensor(this.value, { shape: [...this.shape], strides: [...this.strides] });
+        const otherMat = new Tensor(other.value, { shape: [other.shape[0], 1], strides: [other.strides[0], 1] });
         const out = thisMat.mm(otherMat).squeeze(1);
         // Handle grad with original tensors
         if (this.requiresGrad) {
@@ -693,15 +670,18 @@ class Tensor {
         }
         if (out.requiresGrad) {
             out.gradFn = () => {
-                const outGrad = out.grad;
+                const outGrad = out.grad.withGrad(false);
+                const selfNoGrad = this.withGrad(false);
+                const otherNoGrad = other.withGrad(false);
                 if (this.requiresGrad)
-                    Tensor.addGrad(this, outGrad.unsqueeze(1).mm(
+                    Tensor.addGrad(this, outGrad.unsqueeze(1).mm(otherNoGrad.unsqueeze(0)));
                 if (other.requiresGrad)
-                    Tensor.addGrad(other,
+                    Tensor.addGrad(other, selfNoGrad.t().mv(outGrad));
             };
         }
         return out;
     }
+    // General matrix multiplication with different shapes
     matmul(other) {
         other = Tensor.forceTensor(other);
         if (this.shape.length === 1 && other.shape.length === 1) {
@@ -723,7 +703,7 @@ class Tensor {
     static fullLike(tensor, num, options = {}) {
         if (typeof tensor.value === "number")
             return new Tensor(num, options);
-        return new Tensor(tensor.value.map(el => num), { shape: tensor.shape, strides: tensor.strides, ...options });
+        return new Tensor(tensor.value.map(el => num), { shape: [...tensor.shape], strides: [...tensor.strides], ...options });
     }
     // Reverse-mode autodiff call
     backward() {
@@ -771,8 +751,8 @@ class Tensor {
     // Returns a copy of the tensor with gradient turned on/off
     withGrad(requiresGrad) {
         return new Tensor(this.value, {
-            shape: this.shape,
-            strides: this.strides,
+            shape: [...this.shape],
+            strides: [...this.strides],
             requiresGrad
         });
     }
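Several hunks in this release (elementWiseSelf, transpose, mv, fullLike, withGrad) switch from sharing shape/strides arrays to spreading them into fresh copies. A hedged sketch of the aliasing this guards against, assuming shape is a plain array property as the code above reads it (the mutation is purely illustrative):

    const t = new Tensor([[1, 2, 3], [4, 5, 6]], { requiresGrad: true });
    const frozen = t.withGrad(false);
    // In 0.2.0 frozen.shape and t.shape were the same array, so an in-place edit
    // on one silently changed the other; in 0.2.1 each tensor owns its own copy.
    frozen.shape.reverse();   // hypothetical in-place edit; t.shape stays [2, 3]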