catniff 0.2.1 → 0.2.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.js +10 -5
- package/package.json +1 -1
package/dist/core.js
CHANGED
|
@@ -241,12 +241,14 @@ class Tensor {
|
|
|
241
241
|
}
|
|
242
242
|
}
|
|
243
243
|
}
|
|
244
|
+
// Remove size-1 dims only
|
|
244
245
|
const outShape = this.shape.filter((dim, i) => {
|
|
245
246
|
const shouldSqueeze = dims.includes(i);
|
|
246
247
|
if (shouldSqueeze && dim !== 1)
|
|
247
248
|
throw new Error(`Can not squeeze dim with size ${dim}`);
|
|
248
249
|
return !shouldSqueeze;
|
|
249
250
|
});
|
|
251
|
+
// Remove strides of size-1 dims
|
|
250
252
|
const outStrides = this.strides.filter((stride, i) => !dims.includes(i));
|
|
251
253
|
const outValue = outShape.length === 0 ? this.value[0] : this.value;
|
|
252
254
|
const out = new Tensor(outValue, {
|
|
@@ -280,13 +282,13 @@ class Tensor {
|
|
|
280
282
|
// New stride
|
|
281
283
|
const newStrides = [...this.strides];
|
|
282
284
|
let newDimStride;
|
|
283
|
-
if (dim
|
|
284
|
-
// Inserting at
|
|
285
|
-
newDimStride =
|
|
285
|
+
if (dim >= this.shape.length) {
|
|
286
|
+
// Inserting at the back: use 1
|
|
287
|
+
newDimStride = 1;
|
|
286
288
|
}
|
|
287
289
|
else {
|
|
288
|
-
// Inserting
|
|
289
|
-
newDimStride = this.strides[dim
|
|
290
|
+
// Inserting before dim: use current stride * current shape
|
|
291
|
+
newDimStride = this.strides[dim] * this.shape[dim];
|
|
290
292
|
}
|
|
291
293
|
newStrides.splice(dim, 0, newDimStride);
|
|
292
294
|
const out = new Tensor(this.value, { shape: newShape, strides: newStrides });
|
|
@@ -585,6 +587,7 @@ class Tensor {
|
|
|
585
587
|
}
|
|
586
588
|
if (out.requiresGrad) {
|
|
587
589
|
out.gradFn = () => {
|
|
590
|
+
// Disable gradient collecting of gradients themselves
|
|
588
591
|
const outGrad = out.grad.withGrad(false);
|
|
589
592
|
const selfNoGrad = this.withGrad(false);
|
|
590
593
|
const otherNoGrad = other.withGrad(false);
|
|
@@ -637,6 +640,7 @@ class Tensor {
|
|
|
637
640
|
}
|
|
638
641
|
if (out.requiresGrad) {
|
|
639
642
|
out.gradFn = () => {
|
|
643
|
+
// Disable gradient collecting of gradients themselves
|
|
640
644
|
const outGrad = out.grad.withGrad(false);
|
|
641
645
|
const selfNoGrad = this.withGrad(false);
|
|
642
646
|
const otherNoGrad = other.withGrad(false);
|
|
@@ -670,6 +674,7 @@ class Tensor {
|
|
|
670
674
|
}
|
|
671
675
|
if (out.requiresGrad) {
|
|
672
676
|
out.gradFn = () => {
|
|
677
|
+
// Disable gradient collecting of gradients themselves
|
|
673
678
|
const outGrad = out.grad.withGrad(false);
|
|
674
679
|
const selfNoGrad = this.withGrad(false);
|
|
675
680
|
const otherNoGrad = other.withGrad(false);
|