catniff 0.2.0 → 0.2.1

package/dist/core.d.ts CHANGED
@@ -19,14 +19,14 @@ export declare class Tensor {
  static flatten(tensor: TensorValue): number[] | number;
  static getShape(tensor: TensorValue): number[];
  static getStrides(shape: number[]): number[];
- static padDims(shapeA: number[], shapeB: number[]): number[][];
+ static padShape(stridesA: number[], stridesB: number[], shapeA: number[], shapeB: number[]): number[][];
  static broadcastShapes(shapeA: number[], shapeB: number[]): number[];
  static indexToCoords(index: number, shape: number[], strides: number[]): number[];
  static coordsToIndex(coords: number[], shape: number[], strides: number[]): number;
  static elementWiseAB(tA: Tensor, tB: Tensor, op: (tA: number, tB: number) => number): Tensor;
  static elementWiseSelf(tA: Tensor, op: (tA: number) => number): Tensor;
- elementWiseABDAG(other: TensorValue | Tensor, op: (a: number, b: number) => number, thisGrad: (self: Tensor, other: Tensor, outGrad: Tensor) => void, otherGrad: (self: Tensor, other: Tensor, outGrad: Tensor) => void): Tensor;
- elementWiseSelfDAG(op: (a: number) => number, thisGrad: (self: Tensor, outGrad: Tensor) => void): Tensor;
+ elementWiseABDAG(other: TensorValue | Tensor, op: (a: number, b: number) => number, thisGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor, otherGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor): Tensor;
+ elementWiseSelfDAG(op: (a: number) => number, thisGrad?: (self: Tensor, outGrad: Tensor) => Tensor): Tensor;
  static forceTensor(value: TensorValue | Tensor): Tensor;
  static addGrad(tensor: Tensor, accumGrad: Tensor): void;
  squeeze(dims?: number[] | number): Tensor;
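
Note: two API-visible changes in these declarations: padDims is replaced by padShape, which aligns strides alongside shapes, and the gradient callbacks of the DAG helpers are now optional and return the local gradient Tensor rather than mutating state themselves. A minimal sketch of a custom op under the new contract (hedged: it assumes the constructor accepts nested arrays and a requiresGrad option, which the declarations suggest but do not show):

    const x = new Tensor([1, -2, 3], { requiresGrad: true });
    const cubed = x.elementWiseSelfDAG(
        (a) => a ** 3,
        // Return the local gradient; the engine accumulates it via Tensor.addGrad
        (self, outGrad) => outGrad.mul(self.pow(2).mul(3)) // d(a^3)/da = 3a^2
    );
    cubed.backward();
    console.log(x.grad.value); // expected [3, 12, 27]
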
package/dist/core.js CHANGED
@@ -53,16 +53,21 @@ class Tensor {
  }
  return strides;
  }
- // Left-pad dimensions for two shape to be of same length
- static padDims(shapeA, shapeB) {
- const newA = [...shapeA], newB = [...shapeB];
- while (newA.length < newB.length) {
- newA.unshift(1);
- }
- while (newA.length > newB.length) {
- newB.unshift(1);
- }
- return [newA, newB];
+ // Left-pad shape and strides for two shape to be of same length
+ static padShape(stridesA, stridesB, shapeA, shapeB) {
+ const newStrideA = [...stridesA], newStrideB = [...stridesB];
+ const newShapeA = [...shapeA], newShapeB = [...shapeB];
+ while (newStrideA.length < newStrideB.length) {
+ const newStride = newShapeA[0] * newStrideA[0];
+ newStrideA.unshift(newStride);
+ newShapeA.unshift(1);
+ }
+ while (newStrideA.length > newStrideB.length) {
+ const newStride = newShapeB[0] * newStrideB[0];
+ newStrideB.unshift(newStride);
+ newShapeB.unshift(1);
+ }
+ return [newStrideA, newStrideB, newShapeA, newShapeB];
  }
  // Broadcast shapes
  static broadcastShapes(shapeA, shapeB) {
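
For concreteness, a worked call of the new helper, read straight off the code above: a [2, 4] tensor with contiguous strides [4, 1] paired with a [4] tensor with strides [1]:

    // The shorter operand is left-padded with size-1 dims, and each new dim
    // gets a synthesized stride (the outer extent of the existing data)
    Tensor.padShape([4, 1], [1], [2, 4], [4]);
    // => [[4, 1], [4, 1], [2, 4], [1, 4]]
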
@@ -118,7 +123,7 @@ class Tensor {
  return Tensor.elementWiseSelf(tA, (a) => op(a, tB.value));
  }
  // Pad + broadcast shape
- const [paddedAShape, paddedBShape] = Tensor.padDims(tA.shape, tB.shape);
+ const [paddedAStrides, paddedBStrides, paddedAShape, paddedBShape] = Tensor.padShape(tA.strides, tB.strides, tA.shape, tB.shape);
  const outputShape = Tensor.broadcastShapes(paddedAShape, paddedBShape);
  // Get other output info
  const outputStrides = Tensor.getStrides(outputShape);
@@ -128,9 +133,9 @@ class Tensor {
  // Get coordinates from 1D index
  const coordsOutput = Tensor.indexToCoords(i, outputShape, outputStrides);
  // Convert the coordinates to 1D index of flattened A with respect to A's shape
- const indexA = Tensor.coordsToIndex(coordsOutput, paddedAShape, Tensor.getStrides(paddedAShape));
+ const indexA = Tensor.coordsToIndex(coordsOutput, paddedAShape, paddedAStrides);
  // Convert the coordinates to 1D index of flattened B with respect to B's shape
- const indexB = Tensor.coordsToIndex(coordsOutput, paddedBShape, Tensor.getStrides(paddedBShape));
+ const indexB = Tensor.coordsToIndex(coordsOutput, paddedBShape, paddedBStrides);
  // Calculate with op
  outputValue[i] = op(tA.value[indexA], tB.value[indexB]);
  }
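
The practical effect of indexing with paddedAStrides/paddedBStrides instead of strides recomputed from the shape: element-wise ops now read non-contiguous views (such as transposes) correctly. A sketch, assuming the constructor accepts nested arrays as TensorValue suggests:

    const m = new Tensor([[1, 2], [3, 4]]); // shape [2, 2], strides [2, 1]
    const mt = m.transpose(0, 1);           // same buffer, strides [1, 2], i.e. [[1, 3], [2, 4]]
    console.log(mt.add(m).value);           // expected [2, 5, 5, 8], i.e. [[2, 5], [5, 8]]
    // Recomputing contiguous strides for mt would have read it as [[1, 2], [3, 4]] again
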
@@ -143,10 +148,10 @@ class Tensor {
  static elementWiseSelf(tA, op) {
  if (typeof tA.value === "number")
  return new Tensor(op(tA.value));
- return new Tensor(tA.value.map(el => op(el)), { shape: tA.shape, strides: tA.strides });
+ return new Tensor(tA.value.map(el => op(el)), { shape: [...tA.shape], strides: [...tA.strides] });
  }
  // Utility to do element-wise operation and build a dag node with another tensor
- elementWiseABDAG(other, op, thisGrad, otherGrad) {
+ elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) {
  other = Tensor.forceTensor(other);
  const out = Tensor.elementWiseAB(this, other, op);
  if (this.requiresGrad) {
@@ -159,17 +164,20 @@ class Tensor {
  }
  if (out.requiresGrad) {
  out.gradFn = () => {
- const outGrad = out.grad;
+ // Disable gradient collecting of gradients themselves
+ const outGrad = out.grad.withGrad(false);
+ const selfNoGrad = this.withGrad(false);
+ const otherNoGrad = other.withGrad(false);
  if (this.requiresGrad)
- thisGrad(this, other, outGrad);
+ Tensor.addGrad(this, thisGrad(selfNoGrad, otherNoGrad, outGrad));
  if (other.requiresGrad)
- otherGrad(this, other, outGrad);
+ Tensor.addGrad(other, otherGrad(selfNoGrad, otherNoGrad, outGrad));
  };
  }
  return out;
  }
  // Utility to do self-inflicting element-wise operation and build a dag node
- elementWiseSelfDAG(op, thisGrad) {
+ elementWiseSelfDAG(op, thisGrad = () => new Tensor(0)) {
  const out = Tensor.elementWiseSelf(this, op);
  if (this.requiresGrad) {
  out.requiresGrad = true;
@@ -177,9 +185,11 @@ class Tensor {
  }
  if (out.requiresGrad) {
  out.gradFn = () => {
- const outGrad = out.grad;
+ // Disable gradient collecting of gradients themselves
+ const outGrad = out.grad.withGrad(false);
+ const selfNoGrad = this.withGrad(false);
  if (this.requiresGrad)
- thisGrad(this, outGrad);
+ Tensor.addGrad(this, thisGrad(selfNoGrad, outGrad));
  };
  }
  return out;
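
Note: under the reworked contract, gradient callbacks receive detached copies (withGrad(false)) of the operands and of out.grad, return the local gradient, and the engine performs the Tensor.addGrad accumulation itself. One observable consequence (sketch, same assumptions as above) is that backward passes no longer extend the autograd graph:

    const a = new Tensor([2, 3], { requiresGrad: true });
    const b = new Tensor([4, 5], { requiresGrad: true });
    a.mul(b).backward();
    console.log(a.grad.value);        // expected [4, 5]
    console.log(a.grad.requiresGrad); // expected false — gradients are computed detached
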
@@ -207,7 +217,6 @@ class Tensor {
  }
  }
  const reducedGrad = accumGrad.sum(axesToReduce, true);
- // console.log(accumGrad, new Tensor([[1,1,1]]));
  const squeezedGrad = reducedGrad.squeeze(axesToSqueeze);
  if (typeof tensor.grad === "undefined") {
  tensor.grad = squeezedGrad;
@@ -238,7 +247,7 @@ class Tensor {
  throw new Error(`Can not squeeze dim with size ${dim}`);
  return !shouldSqueeze;
  });
- const outStrides = Tensor.getStrides(outShape);
+ const outStrides = this.strides.filter((stride, i) => !dims.includes(i));
  const outValue = outShape.length === 0 ? this.value[0] : this.value;
  const out = new Tensor(outValue, {
  shape: outShape,
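
Note: squeeze now keeps the surviving dimensions' original strides instead of recomputing contiguous ones, which matters for non-contiguous views. A sketch:

    const m = new Tensor([[1, 2, 3], [4, 5, 6]]); // strides [3, 1]
    const v = m.transpose(0, 1).unsqueeze(0);     // shape [1, 3, 2], strides [6, 1, 3]
    console.log(v.squeeze(0).strides);
    // expected [1, 3] — the view's real strides; the old Tensor.getStrides([3, 2])
    // would have produced the contiguous [2, 1] and misread the buffer
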
@@ -249,7 +258,7 @@ class Tensor {
  out.requiresGrad = true;
  out.children.push(this);
  out.gradFn = () => {
- let restoredGrad = out.grad;
+ let restoredGrad = out.grad.withGrad(false);
  for (let i = dims.length - 1; i >= 0; i--) {
  restoredGrad = restoredGrad.unsqueeze(dims[i]);
  }
@@ -268,19 +277,25 @@ class Tensor {
  // Insert size-1 dimension at specified position
  const newShape = [...this.shape];
  newShape.splice(dim, 0, 1);
- // Insert appropriate stride for new dimension
+ // New stride
  const newStrides = [...this.strides];
- // For dimension of size 1, stride can be any value since we never step through it
- // Use the stride of the next dimension, or 1 if it's the last dimension
- const newStride = dim < this.strides.length ? this.strides[dim] : 1;
- newStrides.splice(dim, 0, newStride);
+ let newDimStride;
+ if (dim === 0) {
+ // Inserting at front: use product of all original dimensions
+ newDimStride = this.shape.reduce((a, b) => a * b, 1) || 1;
+ }
+ else {
+ // Inserting elsewhere: use stride of previous dimension
+ newDimStride = this.strides[dim - 1];
+ }
+ newStrides.splice(dim, 0, newDimStride);
  const out = new Tensor(this.value, { shape: newShape, strides: newStrides });
  // Set up gradient if needed
  if (this.requiresGrad) {
  out.requiresGrad = true;
  out.children.push(this);
  out.gradFn = () => {
- Tensor.addGrad(this, out.grad.squeeze(dim));
+ Tensor.addGrad(this, out.grad.withGrad(false).squeeze(dim));
  };
  }
  return out;
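
Worked through on a contiguous [2, 3] tensor (strides [3, 1]), the new stride rule gives:

    const t = new Tensor([[1, 2, 3], [4, 5, 6]]);
    t.unsqueeze(0); // dim 0: product of dims = 6 -> shape [1, 2, 3], strides [6, 3, 1]
    t.unsqueeze(1); // dim 1: previous stride 3  -> shape [2, 1, 3], strides [3, 3, 1]
    t.unsqueeze(2); // dim 2: previous stride 1  -> shape [2, 3, 1], strides [3, 1, 1]

A size-1 dim is never actually stepped through, so any stride is valid; the new rule just keeps the values consistent with a later squeeze of the same dim.
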
@@ -300,6 +315,12 @@ class Tensor {
  const outputSize = outputShape.reduce((a, b) => a * b, 1);
  const outputValue = new Array(outputSize).fill(0);
  const originalSize = this.shape.reduce((a, b) => a * b, 1);
+ let gradShape, gradStrides, gradValue = [];
+ if (this.requiresGrad) {
+ gradShape = [...this.shape];
+ gradStrides = [...this.strides];
+ gradValue = new Array(originalSize).fill(0);
+ }
  for (let index = 0; index < originalSize; index++) {
  const coords = Tensor.indexToCoords(index, this.shape, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
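
Note: these buffers record which stored elements take part in the reduction; the rest of the change (below) writes 1 into gradValue per contributing element and uses the result as a 0/1 mask multiplied with the upstream gradient, making sum differentiable. A sketch, assuming sum accepts a single dim the way squeeze's declaration does:

    const x = new Tensor([[1, 2], [3, 4]], { requiresGrad: true });
    x.sum(0).backward();       // forward: [4, 6]
    console.log(x.grad.value); // expected [1, 1, 1, 1] — every element contributed once
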
@@ -309,249 +330,200 @@ class Tensor {
  // Accumulate
  const realFlatIndex = coords.reduce((acc, val, i) => acc + val * this.strides[i], 0);
  outputValue[outFlatIndex] += this.value[realFlatIndex];
+ // Mark for gradient
+ if (this.requiresGrad) {
+ (gradValue)[realFlatIndex] = 1;
+ }
  }
  const out = new Tensor(outputValue, {
  shape: outputShape,
  strides: outputStrides
  });
+ // Set up gradient if needed
+ if (this.requiresGrad) {
+ out.requiresGrad = true;
+ out.children.push(this);
+ out.gradFn = () => {
+ const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
+ Tensor.addGrad(this, out.grad.withGrad(false).mul(localGrad));
+ };
+ }
  return keepDims ? out : out.squeeze(dims);
  }
  // Tensor element-wise addition
  add(other) {
- return this.elementWiseABDAG(other, (a, b) => a + b, (self, other, outGrad) => {
- Tensor.addGrad(self, outGrad);
- }, (self, other, outGrad) => {
- Tensor.addGrad(other, outGrad);
- });
+ return this.elementWiseABDAG(other, (a, b) => a + b, (self, other, outGrad) => outGrad, (self, other, outGrad) => outGrad);
  }
  // Tensor element-wise subtraction
  sub(other) {
- return this.elementWiseABDAG(other, (a, b) => a - b, (self, other, outGrad) => {
- Tensor.addGrad(self, outGrad);
- }, (self, other, outGrad) => {
- Tensor.addGrad(other, outGrad.neg());
- });
+ return this.elementWiseABDAG(other, (a, b) => a - b, (self, other, outGrad) => outGrad, (self, other, outGrad) => outGrad.neg());
  }
  // Tensor element-wise multiplication
  mul(other) {
- return this.elementWiseABDAG(other, (a, b) => a * b, (self, other, outGrad) => {
- Tensor.addGrad(self, outGrad.mul(other));
- }, (self, other, outGrad) => {
- Tensor.addGrad(other, outGrad.mul(self));
- });
+ return this.elementWiseABDAG(other, (a, b) => a * b, (self, other, outGrad) => outGrad.mul(other), (self, other, outGrad) => outGrad.mul(self));
  }
  // Tensor element-wise power
  pow(other) {
- return this.elementWiseABDAG(other, (a, b) => a ** b, (self, other, outGrad) => {
- Tensor.addGrad(self, outGrad.mul(other.mul(self.pow(other.sub(1)))));
- }, (self, other, outGrad) => {
- Tensor.addGrad(other, outGrad.mul(self.pow(other).mul(self.log())));
- });
+ return this.elementWiseABDAG(other, (a, b) => a ** b, (self, other, outGrad) => outGrad.mul(other.mul(self.pow(other.sub(1)))), (self, other, outGrad) => outGrad.mul(self.pow(other).mul(self.log())));
  }
  // Tensor element-wise division
  div(other) {
- return this.elementWiseABDAG(other, (a, b) => a / b, (self, other, outGrad) => {
- Tensor.addGrad(self, outGrad.div(other));
- }, (self, other, outGrad) => {
- Tensor.addGrad(other, outGrad.mul(self.neg().div(other.pow(2))));
- });
+ return this.elementWiseABDAG(other, (a, b) => a / b, (self, other, outGrad) => outGrad.div(other), (self, other, outGrad) => outGrad.mul(self.neg().div(other.pow(2))));
  }
  // Tensor element-wise greater or equal comparison
  ge(other) {
- return this.elementWiseABDAG(other, (a, b) => a >= b ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+ return this.elementWiseABDAG(other, (a, b) => a >= b ? 1 : 0);
  }
  // Tensor element-wise less or equal comparison
  le(other) {
- return this.elementWiseABDAG(other, (a, b) => a <= b ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+ return this.elementWiseABDAG(other, (a, b) => a <= b ? 1 : 0);
  }
  // Tensor element-wise greater-than comparison
  gt(other) {
- return this.elementWiseABDAG(other, (a, b) => a > b ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+ return this.elementWiseABDAG(other, (a, b) => a > b ? 1 : 0);
  }
  // Tensor element-wise less-than comparison
  lt(other) {
- return this.elementWiseABDAG(other, (a, b) => a < b ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+ return this.elementWiseABDAG(other, (a, b) => a < b ? 1 : 0);
  }
  // Tensor element-wise equality comparison
  eq(other) {
- return this.elementWiseABDAG(other, (a, b) => a === b ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+ return this.elementWiseABDAG(other, (a, b) => a === b ? 1 : 0);
  }
  // Tensor element-wise logical and
  logicalAnd(other) {
- return this.elementWiseABDAG(other, (a, b) => a === 1 && b === 1 ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+ return this.elementWiseABDAG(other, (a, b) => a === 1 && b === 1 ? 1 : 0);
  }
  // Tensor element-wise logical or
  logicalOr(other) {
- return this.elementWiseABDAG(other, (a, b) => a === 1 || b === 1 ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+ return this.elementWiseABDAG(other, (a, b) => a === 1 || b === 1 ? 1 : 0);
  }
  // Tensor element-wise logical xor
  logicalXor(other) {
- return this.elementWiseABDAG(other, (a, b) => (a === 1 || b === 1) && a !== b ? 1 : 0, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+ return this.elementWiseABDAG(other, (a, b) => (a === 1 || b === 1) && a !== b ? 1 : 0);
  }
  // Tensor element-wise logical not
  logicalNot() {
- return this.elementWiseSelfDAG((a) => a === 1 ? 0 : 1, (self, outGrad) => { });
+ return this.elementWiseSelfDAG((a) => a === 1 ? 0 : 1);
  }
  // Tensor element-wise bitwise and
  bitwiseAnd(other) {
- return this.elementWiseABDAG(other, (a, b) => a & b, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+ return this.elementWiseABDAG(other, (a, b) => a & b);
  }
  // Tensor element-wise bitwise or
  bitwiseOr(other) {
- return this.elementWiseABDAG(other, (a, b) => a | b, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+ return this.elementWiseABDAG(other, (a, b) => a | b);
  }
  // Tensor element-wise bitwise xor
  bitwiseXor(other) {
- return this.elementWiseABDAG(other, (a, b) => a ^ b, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+ return this.elementWiseABDAG(other, (a, b) => a ^ b);
  }
  // Tensor element-wise bitwise not
  bitwiseNot() {
- return this.elementWiseSelfDAG((a) => ~a, (self, outGrad) => { });
+ return this.elementWiseSelfDAG((a) => ~a);
  }
  // Tensor element-wise left shift
  bitwiseLeftShift(other) {
- return this.elementWiseABDAG(other, (a, b) => a << b, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+ return this.elementWiseABDAG(other, (a, b) => a << b);
  }
  // Tensor element-wise right shift
  bitwiseRightShift(other) {
- return this.elementWiseABDAG(other, (a, b) => a >> b, (self, other, outGrad) => { }, (self, other, outGrad) => { });
+ return this.elementWiseABDAG(other, (a, b) => a >> b);
  }
  // Tensor element-wise negation
  neg() {
- return this.elementWiseSelfDAG((a) => -a, (self, outGrad) => {
- Tensor.addGrad(self, outGrad.mul(-1));
- });
+ return this.elementWiseSelfDAG((a) => -a, (self, outGrad) => outGrad.mul(-1));
  }
  // Tensor element-wise absolute
  abs() {
- return this.elementWiseSelfDAG((a) => Math.abs(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.mul(self.sign()));
- });
+ return this.elementWiseSelfDAG((a) => Math.abs(a), (self, outGrad) => outGrad.mul(self.sign()));
  }
  // Tensor element-wise sign function
  sign() {
- return this.elementWiseSelfDAG((a) => Math.sign(a), (self, outGrad) => { });
+ return this.elementWiseSelfDAG((a) => Math.sign(a));
  }
  // Tensor element-wise sin
  sin() {
- return this.elementWiseSelfDAG((a) => Math.sin(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.mul(self.cos()));
- });
+ return this.elementWiseSelfDAG((a) => Math.sin(a), (self, outGrad) => outGrad.mul(self.cos()));
  }
  // Tensor element-wise cos
  cos() {
- return this.elementWiseSelfDAG((a) => Math.cos(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.mul(self.sin().neg()));
- });
+ return this.elementWiseSelfDAG((a) => Math.cos(a), (self, outGrad) => outGrad.mul(self.sin().neg()));
  }
  // Tensor element-wise tan
  tan() {
- return this.elementWiseSelfDAG((a) => Math.tan(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.mul(self.tan().pow(2).add(1)));
- });
+ return this.elementWiseSelfDAG((a) => Math.tan(a), (self, outGrad) => outGrad.mul(self.tan().pow(2).add(1)));
  }
  // Tensor element-wise asin
  asin() {
- return this.elementWiseSelfDAG((a) => Math.asin(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.div(self.pow(2).neg().add(1).sqrt()));
- });
+ return this.elementWiseSelfDAG((a) => Math.asin(a), (self, outGrad) => outGrad.div(self.pow(2).neg().add(1).sqrt()));
  }
  // Tensor element-wise acos
  acos() {
- return this.elementWiseSelfDAG((a) => Math.acos(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.div(self.pow(2).neg().add(1).sqrt()).neg());
- });
+ return this.elementWiseSelfDAG((a) => Math.acos(a), (self, outGrad) => outGrad.div(self.pow(2).neg().add(1).sqrt()).neg());
  }
  // Tensor element-wise atan
  atan() {
- return this.elementWiseSelfDAG((a) => Math.atan(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.div(self.pow(2).add(1)));
- });
+ return this.elementWiseSelfDAG((a) => Math.atan(a), (self, outGrad) => outGrad.div(self.pow(2).add(1)));
  }
  // Tensor element-wise sinh
  sinh() {
- return this.elementWiseSelfDAG((a) => Math.sinh(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.mul(self.cosh()));
- });
+ return this.elementWiseSelfDAG((a) => Math.sinh(a), (self, outGrad) => outGrad.mul(self.cosh()));
  }
  // Tensor element-wise cosh
  cosh() {
- return this.elementWiseSelfDAG((a) => Math.cosh(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.mul(self.sinh()));
- });
+ return this.elementWiseSelfDAG((a) => Math.cosh(a), (self, outGrad) => outGrad.mul(self.sinh()));
  }
  // Tensor element-wise asinh
  asinh() {
- return this.elementWiseSelfDAG((a) => Math.asinh(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.div(self.pow(2).add(1).sqrt()));
- });
+ return this.elementWiseSelfDAG((a) => Math.asinh(a), (self, outGrad) => outGrad.div(self.pow(2).add(1).sqrt()));
  }
  // Tensor element-wise acosh
  acosh() {
- return this.elementWiseSelfDAG((a) => Math.acosh(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.div(self.add(1).sqrt().mul(self.sub(1).sqrt())));
- });
+ return this.elementWiseSelfDAG((a) => Math.acosh(a), (self, outGrad) => outGrad.div(self.add(1).sqrt().mul(self.sub(1).sqrt())));
  }
  // Tensor element-wise atanh
  atanh() {
- return this.elementWiseSelfDAG((a) => Math.atanh(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.div(self.pow(2).neg().add(1)));
- });
+ return this.elementWiseSelfDAG((a) => Math.atanh(a), (self, outGrad) => outGrad.div(self.pow(2).neg().add(1)));
  }
  // Tensor element-wise square root
  sqrt() {
- return this.elementWiseSelfDAG((a) => Math.sqrt(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.div(self.sqrt().mul(2)));
- });
+ return this.elementWiseSelfDAG((a) => Math.sqrt(a), (self, outGrad) => outGrad.div(self.sqrt().mul(2)));
  }
  // Tensor element-wise e^x
  exp() {
- return this.elementWiseSelfDAG((a) => Math.exp(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.mul(self.exp()));
- });
+ return this.elementWiseSelfDAG((a) => Math.exp(a), (self, outGrad) => outGrad.mul(self.exp()));
  }
  // Tensor element-wise natural log
  log() {
- return this.elementWiseSelfDAG((a) => Math.log(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.div(self));
- });
+ return this.elementWiseSelfDAG((a) => Math.log(a), (self, outGrad) => outGrad.div(self));
  }
  // Tensor element-wise log2
  log2() {
- return this.elementWiseSelfDAG((a) => Math.log2(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.div(self.mul(Math.log(2))));
- });
+ return this.elementWiseSelfDAG((a) => Math.log2(a), (self, outGrad) => outGrad.div(self.mul(Math.log(2))));
  }
  // Tensor element-wise log10
  log10() {
- return this.elementWiseSelfDAG((a) => Math.log10(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.div(self.mul(Math.log(10))));
- });
+ return this.elementWiseSelfDAG((a) => Math.log10(a), (self, outGrad) => outGrad.div(self.mul(Math.log(10))));
  }
  // Tensor element-wise log(1+x)
  log1p() {
- return this.elementWiseSelfDAG((a) => Math.log1p(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.div(self.add(1)));
- });
+ return this.elementWiseSelfDAG((a) => Math.log1p(a), (self, outGrad) => outGrad.div(self.add(1)));
  }
  // Tensor element-wise relu
  relu() {
- return this.elementWiseSelfDAG((a) => Math.max(a, 0), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.mul(self.ge(0)));
- });
+ return this.elementWiseSelfDAG((a) => Math.max(a, 0), (self, outGrad) => outGrad.mul(self.ge(0)));
  }
  // Tensor element-wise sigmoid
  sigmoid() {
  return this.elementWiseSelfDAG((a) => 1 / (1 + Math.exp(-a)), (self, outGrad) => {
  const sig = self.sigmoid();
- Tensor.addGrad(self, outGrad.mul(sig).mul(sig.neg().add(1)));
+ return outGrad.mul(sig).mul(sig.neg().add(1));
  });
  }
  // Tensor element-wise tanh
  tanh() {
- return this.elementWiseSelfDAG((a) => Math.tanh(a), (self, outGrad) => {
- Tensor.addGrad(self, outGrad.mul(self.tanh().pow(2).neg().add(1)));
- });
+ return this.elementWiseSelfDAG((a) => Math.tanh(a), (self, outGrad) => outGrad.mul(self.tanh().pow(2).neg().add(1)));
  }
  // Transpose
  transpose(dim1, dim2) {
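
Note: the long run of rewrites above is mechanical: each op's gradient callback now returns its local gradient expression, and the non-differentiable ops (comparisons, logical and bitwise ops, sign) drop their empty callbacks entirely, falling back to the () => new Tensor(0) default. Sketch:

    const x = new Tensor([-1, 2], { requiresGrad: true });
    x.gt(0).backward();        // forward: [0, 1]
    console.log(x.grad.value); // expected: zeros — the defaulted callback returns new Tensor(0)
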
@@ -561,7 +533,7 @@ class Tensor {
  }
  // If same dimension, return copy
  if (dim1 === dim2) {
- return new Tensor(this.value, { shape: this.shape, strides: this.strides });
+ return new Tensor(this.value, { shape: [...this.shape], strides: [...this.strides] });
  }
  // Create new shape and strides by swapping
  const newShape = [...this.shape];
@@ -575,7 +547,7 @@ class Tensor {
  if (this.requiresGrad) {
  out.children.push(this);
  out.gradFn = () => {
- Tensor.addGrad(this, out.grad.transpose(dim1, dim2));
+ Tensor.addGrad(this, out.grad.withGrad(false).transpose(dim1, dim2));
  };
  }
  return out;
@@ -613,11 +585,13 @@ class Tensor {
  }
  if (out.requiresGrad) {
  out.gradFn = () => {
- const outGrad = out.grad;
+ const outGrad = out.grad.withGrad(false);
+ const selfNoGrad = this.withGrad(false);
+ const otherNoGrad = other.withGrad(false);
  if (this.requiresGrad)
- Tensor.addGrad(this, outGrad.mul(other));
+ Tensor.addGrad(this, outGrad.mul(otherNoGrad));
  if (other.requiresGrad)
- Tensor.addGrad(other, outGrad.mul(this));
+ Tensor.addGrad(other, outGrad.mul(selfNoGrad));
  };
  }
  return out;
@@ -663,15 +637,18 @@ class Tensor {
  }
  if (out.requiresGrad) {
  out.gradFn = () => {
- const outGrad = out.grad;
+ const outGrad = out.grad.withGrad(false);
+ const selfNoGrad = this.withGrad(false);
+ const otherNoGrad = other.withGrad(false);
  if (this.requiresGrad)
- Tensor.addGrad(this, outGrad.mm(other.t()));
+ Tensor.addGrad(this, outGrad.mm(otherNoGrad.t()));
  if (other.requiresGrad)
- Tensor.addGrad(other, this.t().mm(outGrad));
+ Tensor.addGrad(other, selfNoGrad.t().mm(outGrad));
  };
  }
  return out;
  }
+ // Convert right-side 1D tensor to a vector (nx1 tensor) to do matmul
  mv(other) {
  other = Tensor.forceTensor(other);
  // Verify 2D shape
@@ -679,8 +656,8 @@ class Tensor {
  throw new Error("Input is not a 2D and 1D tensor pair");
  }
  // MM with no grad
- const thisMat = new Tensor(this.value, { shape: this.shape, strides: this.strides });
- const otherMat = new Tensor(other.value, { shape: [other.shape[0], 1], strides: [1, 1] });
+ const thisMat = new Tensor(this.value, { shape: [...this.shape], strides: [...this.strides] });
+ const otherMat = new Tensor(other.value, { shape: [other.shape[0], 1], strides: [other.strides[0], 1] });
  const out = thisMat.mm(otherMat).squeeze(1);
  // Handle grad with original tensors
  if (this.requiresGrad) {
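
Note: otherMat previously hard-coded strides [1, 1]; it now carries the vector's real stride, so mv handles strided 1D views. A sketch, assuming the constructor accepts an explicit flat buffer with shape/strides the way the internal calls above do:

    // A shape-[2] view with stride 2 over a flat buffer: reads [1, 2]
    const v = new Tensor([1, 9, 2, 9], { shape: [2], strides: [2] });
    const eye = new Tensor([[1, 0], [0, 1]]);
    console.log(eye.mv(v).value); // expected [1, 2]; the old [1, 1] strides read [1, 9]
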
@@ -693,15 +670,18 @@ class Tensor {
  }
  if (out.requiresGrad) {
  out.gradFn = () => {
- const outGrad = out.grad;
+ const outGrad = out.grad.withGrad(false);
+ const selfNoGrad = this.withGrad(false);
+ const otherNoGrad = other.withGrad(false);
  if (this.requiresGrad)
- Tensor.addGrad(this, outGrad.unsqueeze(1).mm(other.unsqueeze(0)));
+ Tensor.addGrad(this, outGrad.unsqueeze(1).mm(otherNoGrad.unsqueeze(0)));
  if (other.requiresGrad)
- Tensor.addGrad(other, this.t().mv(outGrad));
+ Tensor.addGrad(other, selfNoGrad.t().mv(outGrad));
  };
  }
  return out;
  }
+ // General matrix multiplication with different shapes
  matmul(other) {
  other = Tensor.forceTensor(other);
  if (this.shape.length === 1 && other.shape.length === 1) {
@@ -723,7 +703,7 @@ class Tensor {
  static fullLike(tensor, num, options = {}) {
  if (typeof tensor.value === "number")
  return new Tensor(num, options);
- return new Tensor(tensor.value.map(el => num), { shape: tensor.shape, strides: tensor.strides, ...options });
+ return new Tensor(tensor.value.map(el => num), { shape: [...tensor.shape], strides: [...tensor.strides], ...options });
  }
  // Reverse-mode autodiff call
  backward() {
@@ -771,8 +751,8 @@ class Tensor {
  // Returns a copy of the tensor with gradient turned on/off
  withGrad(requiresGrad) {
  return new Tensor(this.value, {
- shape: this.shape,
- strides: this.strides,
+ shape: [...this.shape],
+ strides: [...this.strides],
  requiresGrad
  });
  }
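
Note: the recurring [...this.shape] / [...this.strides] edits throughout this file are defensive copies: tensors built from another tensor's metadata no longer alias its shape and strides arrays. Sketch:

    const a = new Tensor([[1, 2], [3, 4]]);
    const b = a.withGrad(false);
    b.shape[0] = 99;      // mutating the copy's metadata...
    console.log(a.shape); // expected [2, 2] — ...no longer corrupts the original
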
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "catniff",
- "version": "0.2.0",
+ "version": "0.2.1",
  "description": "A cute autograd engine for Javascript",
  "main": "index.js",
  "scripts": {