catniff 0.1.10 → 0.2.0
- package/README.md +39 -36
- package/dist/core.d.ts +88 -0
- package/dist/core.js +780 -0
- package/index.js +1 -2
- package/package.json +1 -1
- package/dist/autograd.d.ts +0 -112
- package/dist/autograd.js +0 -547
- package/dist/tensor.d.ts +0 -62
- package/dist/tensor.js +0 -336
package/dist/core.js
ADDED
@@ -0,0 +1,780 @@

```js
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.Tensor = void 0;
class Tensor {
    value;
    shape;
    strides;
    grad;
    requiresGrad;
    gradFn;
    children;
    constructor(value, options = {}) {
        this.value = Tensor.flatten(value);
        this.shape = options.shape || Tensor.getShape(value);
        this.strides = options.strides || Tensor.getStrides(this.shape);
        this.grad = options.grad;
        this.requiresGrad = options.requiresGrad ?? false;
        this.gradFn = options.gradFn || (() => { });
        this.children = options.children || [];
    }
    // Utility to flatten an nD array to be 1D
    static flatten(tensor) {
        if (typeof tensor === "number")
            return tensor;
        const result = [];
        function traverse(arr) {
            if (typeof arr === "number") {
                result.push(arr);
            }
            else if (Array.isArray(arr)) {
                arr.forEach(traverse);
            }
        }
        traverse(tensor);
        return result;
    }
    // Utility to get shape from a tensor *value*
    static getShape(tensor) {
        const shape = [];
        let subA = tensor;
        while (Array.isArray(subA)) {
            shape.push(subA.length);
            subA = subA[0];
        }
        return shape;
    }
    // Utility to get strides from shape
    static getStrides(shape) {
        const strides = new Array(shape.length);
        strides[strides.length - 1] = 1;
        for (let i = strides.length - 2; i >= 0; i--) {
            strides[i] = strides[i + 1] * shape[i + 1];
        }
        return strides;
    }
    // Left-pad dimensions so the two shapes have the same length
    static padDims(shapeA, shapeB) {
        const newA = [...shapeA], newB = [...shapeB];
        while (newA.length < newB.length) {
            newA.unshift(1);
        }
        while (newA.length > newB.length) {
            newB.unshift(1);
        }
        return [newA, newB];
    }
    // Broadcast shapes
    static broadcastShapes(shapeA, shapeB) {
        const newShape = new Array(shapeA.length);
        for (let index = 0; index < shapeA.length; index++) {
            if (shapeA[index] === 1) {
                newShape[index] = shapeB[index];
            }
            else if (shapeB[index] === 1) {
                newShape[index] = shapeA[index];
            }
            else if (shapeA[index] === shapeB[index]) {
                newShape[index] = shapeA[index];
            }
            else {
                throw new Error(`Cannot broadcast shapes: ${shapeA} and ${shapeB}`);
            }
        }
        return newShape;
    }
    // Convert flat index to array of coordinates
    static indexToCoords(index, shape, strides) {
        const coords = new Array(shape.length);
        let remaining = index;
        // Sort dimensions by stride (largest first) for correct decomposition
        const sortedDims = shape.map((_, i) => i).sort((a, b) => strides[b] - strides[a]);
        for (const dim of sortedDims) {
            coords[dim] = Math.floor(remaining / strides[dim]);
            remaining %= strides[dim];
        }
        return coords;
    }
    // Convert array of coordinates to *unbroadcasted* flat index
    static coordsToIndex(coords, shape, strides) {
        let index = 0;
        for (let i = 0; i < coords.length; i++) {
            const coord = coords[i];
            // Handle broadcasting: a size-1 dim always reads its only element
            const actualCoord = shape[i] === 1 ? 0 : coord;
            index += actualCoord * strides[i];
        }
        return index;
    }
    // Utility for binary (two-operand) element-wise ops
    static elementWiseAB(tA, tB, op) {
        if (typeof tA.value === "number" && typeof tB.value === "number") {
            return new Tensor(op(tA.value, tB.value));
        }
        if (typeof tA.value === "number") {
            // The scalar is the left operand; keep the order for non-commutative ops
            return Tensor.elementWiseSelf(tB, (b) => op(tA.value, b));
        }
        if (typeof tB.value === "number") {
            return Tensor.elementWiseSelf(tA, (a) => op(a, tB.value));
        }
        // Pad + broadcast shape
        const [paddedAShape, paddedBShape] = Tensor.padDims(tA.shape, tB.shape);
        const outputShape = Tensor.broadcastShapes(paddedAShape, paddedBShape);
        // Get other output info
        const outputStrides = Tensor.getStrides(outputShape);
        const outputSize = outputShape.reduce((a, b) => a * b, 1);
        const outputValue = new Array(outputSize);
        for (let i = 0; i < outputSize; i++) {
            // Get coordinates from 1D index
            const coordsOutput = Tensor.indexToCoords(i, outputShape, outputStrides);
            // Convert the coordinates to a 1D index into flattened A, with respect to A's shape
            const indexA = Tensor.coordsToIndex(coordsOutput, paddedAShape, Tensor.getStrides(paddedAShape));
            // Convert the coordinates to a 1D index into flattened B, with respect to B's shape
            const indexB = Tensor.coordsToIndex(coordsOutput, paddedBShape, Tensor.getStrides(paddedBShape));
            // Calculate with op
            outputValue[i] = op(tA.value[indexA], tB.value[indexB]);
        }
        return new Tensor(outputValue, {
            shape: outputShape,
            strides: outputStrides
        });
    }
    // Utility for unary (single-operand) element-wise ops
    static elementWiseSelf(tA, op) {
        if (typeof tA.value === "number")
            return new Tensor(op(tA.value));
        return new Tensor(tA.value.map(el => op(el)), { shape: tA.shape, strides: tA.strides });
    }
    // Utility to do an element-wise operation and build a DAG node with another tensor
    elementWiseABDAG(other, op, thisGrad, otherGrad) {
        other = Tensor.forceTensor(other);
        const out = Tensor.elementWiseAB(this, other, op);
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
        }
        if (other.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(other);
        }
        if (out.requiresGrad) {
            out.gradFn = () => {
                const outGrad = out.grad;
                if (this.requiresGrad)
                    thisGrad(this, other, outGrad);
                if (other.requiresGrad)
                    otherGrad(this, other, outGrad);
            };
        }
        return out;
    }
    // Utility to do a unary element-wise operation and build a DAG node
    elementWiseSelfDAG(op, thisGrad) {
        const out = Tensor.elementWiseSelf(this, op);
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
        }
        if (out.requiresGrad) {
            out.gradFn = () => {
                const outGrad = out.grad;
                if (this.requiresGrad)
                    thisGrad(this, outGrad);
            };
        }
        return out;
    }
    // Utility to force an input value to be a tensor
    static forceTensor(value) {
        if (value instanceof Tensor)
            return value;
        return new Tensor(value);
    }
    // Utility to accumulate a gradient into a tensor, reducing broadcasted axes
    static addGrad(tensor, accumGrad) {
        const axesToSqueeze = [];
        const axesToReduce = [];
        const shape = tensor.shape;
        const gradShape = accumGrad.shape;
        const paddedDims = gradShape.length - shape.length;
        for (let i = 0; i < paddedDims; i++) {
            axesToReduce.push(i);
            axesToSqueeze.push(i);
        }
        for (let i = 0; i < shape.length; i++) {
            if (shape[i] === 1 && gradShape[i + paddedDims] > 1) {
                axesToReduce.push(i + paddedDims);
            }
        }
        const reducedGrad = accumGrad.sum(axesToReduce, true);
        const squeezedGrad = reducedGrad.squeeze(axesToSqueeze);
        if (typeof tensor.grad === "undefined") {
            tensor.grad = squeezedGrad;
        }
        else {
            tensor.grad = tensor.grad.add(squeezedGrad);
        }
    }
```
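The machinery above gives NumPy-style broadcasting: `padDims` left-pads the shorter shape with 1s, `broadcastShapes` merges the padded shapes (a size-1 dimension stretches to match), and `indexToCoords`/`coordsToIndex` map each output element back to a source element in each operand, while `addGrad` sums gradients back over the broadcasted axes. A minimal sketch of the resulting semantics, assuming Node.js and a local require of the compiled file (the path is illustrative):

```js
const { Tensor } = require("./dist/core.js"); // path assumed

// [2,3] + [3]: the row vector broadcasts across both rows
const a = new Tensor([[1, 2, 3], [4, 5, 6]]);
const b = new Tensor([10, 20, 30]);
console.log(a.add(b).val()); // [[11, 22, 33], [14, 25, 36]]

// [2,1] * [1,3]: both operands stretch up to [2,3]
const c = new Tensor([[1], [2]]);
const d = new Tensor([[10, 20, 30]]);
console.log(c.mul(d).val()); // [[10, 20, 30], [20, 40, 60]]
```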
```js
// package/dist/core.js (continued)
    // Tensor squeeze
    squeeze(dims) {
        if (typeof this.value === "number")
            return new Tensor(this.value);
        if (typeof dims === "number") {
            dims = [dims];
        }
        if (typeof dims === "undefined") {
            const shape = this.shape;
            dims = [];
            for (let index = 0; index < shape.length; index++) {
                if (shape[index] === 1) {
                    dims.push(index);
                }
            }
        }
        const outShape = this.shape.filter((dim, i) => {
            const shouldSqueeze = dims.includes(i);
            if (shouldSqueeze && dim !== 1)
                throw new Error(`Can not squeeze dim with size ${dim}`);
            return !shouldSqueeze;
        });
        const outStrides = Tensor.getStrides(outShape);
        const outValue = outShape.length === 0 ? this.value[0] : this.value;
        const out = new Tensor(outValue, {
            shape: outShape,
            strides: outStrides
        });
        // Set up gradient if needed
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                let restoredGrad = out.grad;
                for (let i = dims.length - 1; i >= 0; i--) {
                    restoredGrad = restoredGrad.unsqueeze(dims[i]);
                }
                Tensor.addGrad(this, restoredGrad);
            };
        }
        return out;
    }
    // Tensor unsqueeze - adds dimension of size 1 at specified position
    unsqueeze(dim) {
        if (typeof this.value === "number")
            return new Tensor([this.value]);
        if (dim < 0 || dim > this.shape.length) {
            throw new Error(`Invalid dimension ${dim} for unsqueeze`);
        }
        // Insert size-1 dimension at specified position
        const newShape = [...this.shape];
        newShape.splice(dim, 0, 1);
        // Insert appropriate stride for new dimension
        const newStrides = [...this.strides];
        // For dimension of size 1, stride can be any value since we never step through it
        // Use the stride of the next dimension, or 1 if it's the last dimension
        const newStride = dim < this.strides.length ? this.strides[dim] : 1;
        newStrides.splice(dim, 0, newStride);
        const out = new Tensor(this.value, { shape: newShape, strides: newStrides });
        // Set up gradient if needed
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                Tensor.addGrad(this, out.grad.squeeze(dim));
            };
        }
        return out;
    }
    // Tensor sum reduction
    sum(dims, keepDims = false) {
        if (typeof this.value === "number")
            return new Tensor(this.value);
        if (typeof dims === "number") {
            dims = [dims];
        }
        if (typeof dims === "undefined") {
            dims = Array.from({ length: this.shape.length }, (_, index) => index);
        }
        const outputShape = this.shape.map((dim, i) => dims.includes(i) ? 1 : dim);
        const outputStrides = Tensor.getStrides(outputShape);
        const outputSize = outputShape.reduce((a, b) => a * b, 1);
        const outputValue = new Array(outputSize).fill(0);
        const originalSize = this.shape.reduce((a, b) => a * b, 1);
        for (let index = 0; index < originalSize; index++) {
            const coords = Tensor.indexToCoords(index, this.shape, this.strides);
            // Force 0 on reduced axes to collapse into size-1 dims
            const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
            // Convert output coordinates to flat index
            const outFlatIndex = outCoords.reduce((acc, val, i) => acc + val * outputStrides[i], 0);
            // Accumulate
            const realFlatIndex = coords.reduce((acc, val, i) => acc + val * this.strides[i], 0);
            outputValue[outFlatIndex] += this.value[realFlatIndex];
        }
        const out = new Tensor(outputValue, {
            shape: outputShape,
            strides: outputStrides
        });
        return keepDims ? out : out.squeeze(dims);
    }
    // Tensor element-wise addition
    add(other) {
        return this.elementWiseABDAG(other, (a, b) => a + b, (self, other, outGrad) => {
            Tensor.addGrad(self, outGrad);
        }, (self, other, outGrad) => {
            Tensor.addGrad(other, outGrad);
        });
    }
    // Tensor element-wise subtraction
    sub(other) {
        return this.elementWiseABDAG(other, (a, b) => a - b, (self, other, outGrad) => {
            Tensor.addGrad(self, outGrad);
        }, (self, other, outGrad) => {
            Tensor.addGrad(other, outGrad.neg());
        });
    }
    // Tensor element-wise multiplication
    mul(other) {
        return this.elementWiseABDAG(other, (a, b) => a * b, (self, other, outGrad) => {
            Tensor.addGrad(self, outGrad.mul(other));
        }, (self, other, outGrad) => {
            Tensor.addGrad(other, outGrad.mul(self));
        });
    }
    // Tensor element-wise power
    pow(other) {
        return this.elementWiseABDAG(other, (a, b) => a ** b, (self, other, outGrad) => {
            Tensor.addGrad(self, outGrad.mul(other.mul(self.pow(other.sub(1)))));
        }, (self, other, outGrad) => {
            Tensor.addGrad(other, outGrad.mul(self.pow(other).mul(self.log())));
        });
    }
    // Tensor element-wise division
    div(other) {
        return this.elementWiseABDAG(other, (a, b) => a / b, (self, other, outGrad) => {
            Tensor.addGrad(self, outGrad.div(other));
        }, (self, other, outGrad) => {
            Tensor.addGrad(other, outGrad.mul(self.neg().div(other.pow(2))));
        });
    }
    // Tensor element-wise greater-or-equal comparison
    ge(other) {
        return this.elementWiseABDAG(other, (a, b) => a >= b ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
    }
    // Tensor element-wise less-or-equal comparison
    le(other) {
        return this.elementWiseABDAG(other, (a, b) => a <= b ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
    }
    // Tensor element-wise greater-than comparison
    gt(other) {
        return this.elementWiseABDAG(other, (a, b) => a > b ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
    }
    // Tensor element-wise less-than comparison
    lt(other) {
        return this.elementWiseABDAG(other, (a, b) => a < b ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
    }
    // Tensor element-wise equality comparison
    eq(other) {
        return this.elementWiseABDAG(other, (a, b) => a === b ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
    }
    // Tensor element-wise logical and
    logicalAnd(other) {
        return this.elementWiseABDAG(other, (a, b) => a === 1 && b === 1 ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
    }
    // Tensor element-wise logical or
    logicalOr(other) {
        return this.elementWiseABDAG(other, (a, b) => a === 1 || b === 1 ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
    }
    // Tensor element-wise logical xor
    logicalXor(other) {
        return this.elementWiseABDAG(other, (a, b) => (a === 1 || b === 1) && a !== b ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
    }
    // Tensor element-wise logical not
    logicalNot() {
        return this.elementWiseSelfDAG((a) => a === 1 ? 0 : 1, (self, outGrad) => { });
    }
    // Tensor element-wise bitwise and
    bitwiseAnd(other) {
        return this.elementWiseABDAG(other, (a, b) => a & b, (self, other, outGrad) => { }, (self, other, outGrad) => { });
    }
    // Tensor element-wise bitwise or
    bitwiseOr(other) {
        return this.elementWiseABDAG(other, (a, b) => a | b, (self, other, outGrad) => { }, (self, other, outGrad) => { });
    }
    // Tensor element-wise bitwise xor
    bitwiseXor(other) {
        return this.elementWiseABDAG(other, (a, b) => a ^ b, (self, other, outGrad) => { }, (self, other, outGrad) => { });
    }
    // Tensor element-wise bitwise not
    bitwiseNot() {
        return this.elementWiseSelfDAG((a) => ~a, (self, outGrad) => { });
    }
    // Tensor element-wise left shift
    bitwiseLeftShift(other) {
        return this.elementWiseABDAG(other, (a, b) => a << b, (self, other, outGrad) => { }, (self, other, outGrad) => { });
    }
    // Tensor element-wise right shift
    bitwiseRightShift(other) {
        return this.elementWiseABDAG(other, (a, b) => a >> b, (self, other, outGrad) => { }, (self, other, outGrad) => { });
    }
    // Tensor element-wise negation
    neg() {
        return this.elementWiseSelfDAG((a) => -a, (self, outGrad) => {
            Tensor.addGrad(self, outGrad.mul(-1));
        });
    }
    // Tensor element-wise absolute value
    abs() {
        return this.elementWiseSelfDAG((a) => Math.abs(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.mul(self.sign()));
        });
    }
    // Tensor element-wise sign function
    sign() {
        return this.elementWiseSelfDAG((a) => Math.sign(a), (self, outGrad) => { });
    }
    // Tensor element-wise sin
    sin() {
        return this.elementWiseSelfDAG((a) => Math.sin(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.mul(self.cos()));
        });
    }
    // Tensor element-wise cos
    cos() {
        return this.elementWiseSelfDAG((a) => Math.cos(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.mul(self.sin().neg()));
        });
    }
    // Tensor element-wise tan
    tan() {
        return this.elementWiseSelfDAG((a) => Math.tan(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.mul(self.tan().pow(2).add(1)));
        });
    }
    // Tensor element-wise asin
    asin() {
        return this.elementWiseSelfDAG((a) => Math.asin(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.div(self.pow(2).neg().add(1).sqrt()));
        });
    }
    // Tensor element-wise acos
    acos() {
        return this.elementWiseSelfDAG((a) => Math.acos(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.div(self.pow(2).neg().add(1).sqrt()).neg());
        });
    }
    // Tensor element-wise atan
    atan() {
        return this.elementWiseSelfDAG((a) => Math.atan(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.div(self.pow(2).add(1)));
        });
    }
    // Tensor element-wise sinh
    sinh() {
        return this.elementWiseSelfDAG((a) => Math.sinh(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.mul(self.cosh()));
        });
    }
    // Tensor element-wise cosh
    cosh() {
        return this.elementWiseSelfDAG((a) => Math.cosh(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.mul(self.sinh()));
        });
    }
    // Tensor element-wise asinh
    asinh() {
        return this.elementWiseSelfDAG((a) => Math.asinh(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.div(self.pow(2).add(1).sqrt()));
        });
    }
    // Tensor element-wise acosh
    acosh() {
        return this.elementWiseSelfDAG((a) => Math.acosh(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.div(self.add(1).sqrt().mul(self.sub(1).sqrt())));
        });
    }
    // Tensor element-wise atanh
    atanh() {
        return this.elementWiseSelfDAG((a) => Math.atanh(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.div(self.pow(2).neg().add(1)));
        });
    }
    // Tensor element-wise square root
    sqrt() {
        return this.elementWiseSelfDAG((a) => Math.sqrt(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.div(self.sqrt().mul(2)));
        });
    }
    // Tensor element-wise e^x
    exp() {
        return this.elementWiseSelfDAG((a) => Math.exp(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.mul(self.exp()));
        });
    }
    // Tensor element-wise natural log
    log() {
        return this.elementWiseSelfDAG((a) => Math.log(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.div(self));
        });
    }
    // Tensor element-wise log2
    log2() {
        return this.elementWiseSelfDAG((a) => Math.log2(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.div(self.mul(Math.log(2))));
        });
    }
    // Tensor element-wise log10
    log10() {
        return this.elementWiseSelfDAG((a) => Math.log10(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.div(self.mul(Math.log(10))));
        });
    }
    // Tensor element-wise log(1+x)
    log1p() {
        return this.elementWiseSelfDAG((a) => Math.log1p(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.div(self.add(1)));
        });
    }
    // Tensor element-wise relu
    relu() {
        return this.elementWiseSelfDAG((a) => Math.max(a, 0), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.mul(self.ge(0)));
        });
    }
    // Tensor element-wise sigmoid
    sigmoid() {
        return this.elementWiseSelfDAG((a) => 1 / (1 + Math.exp(-a)), (self, outGrad) => {
            const sig = self.sigmoid();
            Tensor.addGrad(self, outGrad.mul(sig).mul(sig.neg().add(1)));
        });
    }
    // Tensor element-wise tanh
    tanh() {
        return this.elementWiseSelfDAG((a) => Math.tanh(a), (self, outGrad) => {
            Tensor.addGrad(self, outGrad.mul(self.tanh().pow(2).neg().add(1)));
        });
    }
    // Transpose
    transpose(dim1, dim2) {
        // If dimension out of bounds, throw error
        if (dim1 >= this.shape.length || dim2 >= this.shape.length || dim1 < 0 || dim2 < 0) {
            throw new Error("Dimensions do not exist to transpose");
        }
        // If same dimension, return a copy (keeping the grad chain intact)
        if (dim1 === dim2) {
            const copy = new Tensor(this.value, { shape: this.shape, strides: this.strides });
            if (this.requiresGrad) {
                copy.requiresGrad = true;
                copy.children.push(this);
                copy.gradFn = () => Tensor.addGrad(this, copy.grad);
            }
            return copy;
        }
        // Create new shape and strides by swapping
        const newShape = [...this.shape];
        const newStrides = [...this.strides];
        [newShape[dim1], newShape[dim2]] = [newShape[dim2], newShape[dim1]];
        [newStrides[dim1], newStrides[dim2]] = [newStrides[dim2], newStrides[dim1]];
        // Create new tensor with same data but swapped shape/strides
        const out = new Tensor(this.value, { shape: newShape, strides: newStrides });
        out.requiresGrad = this.requiresGrad;
        // Handle gradient if needed
        if (this.requiresGrad) {
            out.children.push(this);
            out.gradFn = () => {
                Tensor.addGrad(this, out.grad.transpose(dim1, dim2));
            };
        }
        return out;
    }
    // Transpose 2D
    t() {
        // Verify matrix shape
        if (this.shape.length !== 2) {
            throw new Error("Input is not a matrix");
        }
        return this.transpose(0, 1);
    }
    // 1D tensor dot product
    dot(other) {
        other = Tensor.forceTensor(other);
        // Verify 1D shape
        if (this.shape.length !== 1 || other.shape.length !== 1) {
            throw new Error("Inputs are not 1D tensors");
        }
        const vectLen = this.shape[0];
        const vectA = this.value;
        const vectB = other.value;
        let sum = 0;
        for (let index = 0; index < vectLen; index++) {
            sum += vectA[index] * vectB[index];
        }
        const out = new Tensor(sum);
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
        }
        if (other.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(other);
        }
        if (out.requiresGrad) {
            out.gradFn = () => {
                const outGrad = out.grad;
                if (this.requiresGrad)
                    Tensor.addGrad(this, outGrad.mul(other));
                if (other.requiresGrad)
                    Tensor.addGrad(other, outGrad.mul(this));
            };
        }
        return out;
    }
    // Matrix multiplication
    mm(other) {
        other = Tensor.forceTensor(other);
        // Verify 2D shape
        if (this.shape.length !== 2 || other.shape.length !== 2) {
            throw new Error("Inputs are not matrices");
        }
        const matA = this.value;
        const matB = other.value;
        const matAStrides = this.strides;
        const matBStrides = other.strides;
        const matARows = this.shape[0];
        const matACols = this.shape[1];
        const matBRows = other.shape[0];
        const matBCols = other.shape[1];
        if (matACols !== matBRows)
            throw new Error("Invalid matrix shapes for multiplication");
        const matCShape = [matARows, matBCols];
        const matCStrides = Tensor.getStrides(matCShape);
        const matCSize = matCShape.reduce((a, b) => a * b, 1);
        const matC = new Array(matCSize).fill(0);
        for (let i = 0; i < matARows; i++) {
            for (let j = 0; j < matBCols; j++) {
                for (let k = 0; k < matACols; k++) {
                    matC[i * matCStrides[0] + j * matCStrides[1]] +=
                        matA[i * matAStrides[0] + k * matAStrides[1]] *
                            matB[k * matBStrides[0] + j * matBStrides[1]];
                }
            }
        }
        const out = new Tensor(matC, { shape: matCShape, strides: matCStrides });
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
        }
        if (other.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(other);
        }
        if (out.requiresGrad) {
            out.gradFn = () => {
                const outGrad = out.grad;
                if (this.requiresGrad)
                    Tensor.addGrad(this, outGrad.mm(other.t()));
                if (other.requiresGrad)
                    Tensor.addGrad(other, this.t().mm(outGrad));
            };
        }
        return out;
    }
    // Matrix-vector multiplication
    mv(other) {
        other = Tensor.forceTensor(other);
        // Verify shapes: a 2D matrix and a 1D vector
        if (this.shape.length !== 2 || other.shape.length !== 1) {
            throw new Error("Input is not a 2D and 1D tensor pair");
        }
        // MM with no grad
        const thisMat = new Tensor(this.value, { shape: this.shape, strides: this.strides });
        const otherMat = new Tensor(other.value, { shape: [other.shape[0], 1], strides: [1, 1] });
        const out = thisMat.mm(otherMat).squeeze(1);
        // Handle grad with the original tensors
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
        }
        if (other.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(other);
        }
        if (out.requiresGrad) {
            out.gradFn = () => {
                const outGrad = out.grad;
                if (this.requiresGrad)
                    Tensor.addGrad(this, outGrad.unsqueeze(1).mm(other.unsqueeze(0)));
                if (other.requiresGrad)
                    Tensor.addGrad(other, this.t().mv(outGrad));
            };
        }
        return out;
    }
    // General matmul: dispatch on rank
    matmul(other) {
        other = Tensor.forceTensor(other);
        if (this.shape.length === 1 && other.shape.length === 1) {
            return this.dot(other);
        }
        else if (this.shape.length === 1 && other.shape.length === 2) {
            return this.unsqueeze(0).mm(other).squeeze(0);
        }
        else if (this.shape.length === 2 && other.shape.length === 1) {
            return this.mv(other);
        }
        else if (this.shape.length === 2 && other.shape.length === 2) {
            return this.mm(other);
        }
        // Batched matmul is not supported yet
        throw new Error(`Shapes [${this.shape}] and [${other.shape}] are not supported`);
    }
    // Utility to create a new tensor with the shape of another tensor, filled with a number
    static fullLike(tensor, num, options = {}) {
        if (typeof tensor.value === "number")
            return new Tensor(num, options);
        return new Tensor(tensor.value.map(el => num), { shape: tensor.shape, strides: tensor.strides, ...options });
    }
```
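`matmul` dispatches purely on rank: 1D×1D falls through to `dot`, 1D×2D unsqueezes to a row matrix for `mm`, 2D×1D uses `mv`, 2D×2D uses `mm`, and anything else throws. A short sketch of the four supported cases, under the same assumptions as the previous example:

```js
const { Tensor } = require("./dist/core.js"); // path assumed

const m = new Tensor([[1, 2], [3, 4]]);
const v = new Tensor([5, 6]);

console.log(m.matmul(m).val()); // 2D @ 2D -> mm:  [[7, 10], [15, 22]]
console.log(m.matmul(v).val()); // 2D @ 1D -> mv:  [17, 39]
console.log(v.matmul(m).val()); // 1D @ 2D -> mm:  [23, 34]
console.log(v.matmul(v).val()); // 1D @ 1D -> dot: 61
```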
```js
// package/dist/core.js (continued)
    // Reverse-mode autodiff call
    backward() {
        // Build topological order
        const topo = [];
        const visited = new Set();
        function build(node) {
            if (!visited.has(node) && node.requiresGrad) {
                visited.add(node);
                node.grad = Tensor.fullLike(node, 0);
                for (let child of node.children)
                    build(child);
                topo.push(node);
            }
        }
        build(this);
        // Feed backward to calculate gradients
        this.grad = Tensor.fullLike(this, 1);
        for (let index = topo.length - 1; index > -1; index--) {
            topo[index].gradFn();
        }
    }
    // Returns the number/nD-array form of the tensor
    val() {
        if (typeof this.value === "number")
            return this.value;
        function buildNested(data, shape, strides, baseIndex = 0, dim = 0) {
            if (dim === shape.length - 1) {
                // Last dimension: extract elements using the actual stride
                const result = [];
                for (let i = 0; i < shape[dim]; i++) {
                    result.push(data[baseIndex + i * strides[dim]]);
                }
                return result;
            }
            // Recursive case: build nested structure
            const result = [];
            for (let i = 0; i < shape[dim]; i++) {
                result.push(buildNested(data, shape, strides, baseIndex + i * strides[dim], dim + 1));
            }
            return result;
        }
        return buildNested(this.value, this.shape, this.strides);
    }
    // Returns a copy of the tensor with gradient tracking turned on/off
    withGrad(requiresGrad) {
        return new Tensor(this.value, {
            shape: this.shape,
            strides: this.strides,
            requiresGrad
        });
    }
}
exports.Tensor = Tensor;
```
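`backward()` walks the graph depth-first to build a topological order, zero-initializes `grad` on every node that requires one, seeds the output's own gradient with ones via `fullLike`, then fires each `gradFn` in reverse order. An end-to-end sketch, again assuming a local require of the compiled file; note that `sum` above does not register a `gradFn`, so this example backpropagates from an unreduced output, which `backward()` supports by seeding with ones of the output's shape:

```js
const { Tensor } = require("./dist/core.js"); // path assumed

// y = relu(x @ w), differentiated with respect to both operands
const x = new Tensor([[1, -2], [3, 4]], { requiresGrad: true });
const w = new Tensor([[0.5], [1.0]], { requiresGrad: true });

const y = x.mm(w).relu(); // x @ w = [[-1.5], [5.5]]; relu zeroes the first row
y.backward();             // seeds dL/dy = ones([2, 1])

console.log(y.val());       // [[0], [5.5]]
console.log(x.grad.val());  // [[0, 0], [0.5, 1]]  (second row gets w's values)
console.log(w.grad.val());  // [[3], [4]]          (contribution from the active row of x)
```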