catniff 0.2.1 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (2)
  1. package/dist/core.js +10 -5
  2. package/package.json +1 -1
package/dist/core.js CHANGED
@@ -241,12 +241,14 @@ class Tensor {
  }
  }
  }
+ // Remove size-1 dims only
  const outShape = this.shape.filter((dim, i) => {
  const shouldSqueeze = dims.includes(i);
  if (shouldSqueeze && dim !== 1)
  throw new Error(`Can not squeeze dim with size ${dim}`);
  return !shouldSqueeze;
  });
+ // Remove strides of size-1 dims
  const outStrides = this.strides.filter((stride, i) => !dims.includes(i));
  const outValue = outShape.length === 0 ? this.value[0] : this.value;
  const out = new Tensor(outValue, {
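
The two added comments annotate the existing squeeze logic: size-1 dimensions listed in dims are dropped from the shape, and their strides are dropped in lockstep. A standalone sketch of that filtering on plain arrays (squeezeMeta, shape, strides, and dims here are hypothetical names for illustration, not catniff's API):

// Mirrors the filter logic in the hunk above, outside the Tensor class.
function squeezeMeta(shape, strides, dims) {
  const outShape = shape.filter((dim, i) => {
    const shouldSqueeze = dims.includes(i);
    if (shouldSqueeze && dim !== 1)
      throw new Error(`Can not squeeze dim with size ${dim}`);
    return !shouldSqueeze;
  });
  const outStrides = strides.filter((stride, i) => !dims.includes(i));
  return { outShape, outStrides };
}

// Example: a [1, 3, 1] view with strides [3, 1, 1] squeezed at dims 0 and 2
// keeps only the middle axis; squeezing dim 1 (size 3) would throw instead.
console.log(squeezeMeta([1, 3, 1], [3, 1, 1], [0, 2]));
// → { outShape: [ 3 ], outStrides: [ 1 ] }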
@@ -280,13 +282,13 @@ class Tensor {
  // New stride
  const newStrides = [...this.strides];
  let newDimStride;
- if (dim === 0) {
- // Inserting at front: use product of all original dimensions
- newDimStride = this.shape.reduce((a, b) => a * b, 1) || 1;
+ if (dim >= this.shape.length) {
+ // Inserting at the back: use 1
+ newDimStride = 1;
  }
  else {
- // Inserting elsewhere: use stride of previous dimension
- newDimStride = this.strides[dim - 1];
+ // Inserting before dim: use current stride * current shape
+ newDimStride = this.strides[dim] * this.shape[dim];
  }
  newStrides.splice(dim, 0, newDimStride);
  const out = new Tensor(this.value, { shape: newShape, strides: newStrides });
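
This hunk changes how unsqueeze picks the stride for the inserted size-1 dimension: instead of branching on dim === 0 and otherwise reusing the previous dimension's stride, it uses 1 when appending at the back and strides[dim] * shape[dim] otherwise, i.e. the stride of the axis the new dimension is inserted in front of, scaled by that axis's size. A rough sketch of the new rule on plain arrays (unsqueezeMeta and its parameters are hypothetical, not catniff's API):

// New stride rule from the hunk above, applied to bare shape/stride arrays.
function unsqueezeMeta(shape, strides, dim) {
  const newShape = [...shape];
  newShape.splice(dim, 0, 1);
  // 1 when appending at the back, otherwise stride * size of the axis that
  // the new size-1 dimension is inserted in front of.
  const newDimStride = dim >= shape.length ? 1 : strides[dim] * shape[dim];
  const newStrides = [...strides];
  newStrides.splice(dim, 0, newDimStride);
  return { newShape, newStrides };
}

// Example with a row-major [2, 3] tensor (strides [3, 1]):
console.log(unsqueezeMeta([2, 3], [3, 1], 0)); // shape [1,2,3], strides [6,3,1]
console.log(unsqueezeMeta([2, 3], [3, 1], 2)); // shape [2,3,1], strides [3,1,1]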
@@ -585,6 +587,7 @@ class Tensor {
  }
  if (out.requiresGrad) {
  out.gradFn = () => {
+ // Disable gradient collecting of gradients themselves
  const outGrad = out.grad.withGrad(false);
  const selfNoGrad = this.withGrad(false);
  const otherNoGrad = other.withGrad(false);
@@ -637,6 +640,7 @@ class Tensor {
  }
  if (out.requiresGrad) {
  out.gradFn = () => {
+ // Disable gradient collecting of gradients themselves
  const outGrad = out.grad.withGrad(false);
  const selfNoGrad = this.withGrad(false);
  const otherNoGrad = other.withGrad(false);
@@ -670,6 +674,7 @@ class Tensor {
  }
  if (out.requiresGrad) {
  out.gradFn = () => {
+ // Disable gradient collecting of gradients themselves
  const outGrad = out.grad.withGrad(false);
  const selfNoGrad = this.withGrad(false);
  const otherNoGrad = other.withGrad(false);
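
The three identical comment insertions document the same pattern in three gradFn closures: the incoming gradient and both operands are wrapped with withGrad(false) before any gradient math, so the tensor operations used to compute gradients are not themselves recorded for differentiation. A toy scalar illustration of why that matters (the Num class below is hypothetical and only mimics the pattern; it is not catniff's Tensor):

// Minimal scalar autograd node, for illustration only.
class Num {
  constructor(v, requiresGrad = true) {
    this.v = v;
    this.requiresGrad = requiresGrad;
    this.grad = null;
    this.gradFn = null;
  }
  withGrad(flag) { return new Num(this.v, flag); }
  mul(o) {
    const out = new Num(this.v * o.v, this.requiresGrad || o.requiresGrad);
    if (out.requiresGrad) {
      out.gradFn = () => {
        // Detach first: because these copies have requiresGrad = false, the
        // mul() calls below produce plain values with no gradFn of their own,
        // so a backward pass never differentiates the gradient math itself.
        const g = out.grad.withGrad(false);
        const a = this.withGrad(false);
        const b = o.withGrad(false);
        this.grad = g.mul(b); // d(out)/d(this) = o
        o.grad = g.mul(a);    // d(out)/d(o)    = this
      };
    }
    return out;
  }
}

// Usage: z = x * y, seed dz/dz = 1, then run the recorded gradFn.
const x = new Num(2), y = new Num(5);
const z = x.mul(y);
z.grad = new Num(1, false);
z.gradFn();
console.log(x.grad.v, y.grad.v); // 5 2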
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "catniff",
- "version": "0.2.1",
+ "version": "0.2.3",
  "description": "A cute autograd engine for Javascript",
  "main": "index.js",
  "scripts": {