catniff 0.7.0 → 0.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.d.ts CHANGED
@@ -61,6 +61,7 @@ export declare class Tensor {
  index(indices: Tensor | TensorValue): Tensor;
  slice(ranges: number[][]): Tensor;
  chunk(chunks: number, dim?: number): Tensor[];
+ cat(other: Tensor | TensorValue, dim?: number): Tensor;
  squeeze(dims?: number[] | number): Tensor;
  unsqueeze(dim: number): Tensor;
  static reduce(tensor: Tensor, dims: number[] | number | undefined, keepDims: boolean, config: {
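
The only API-visible change in this release is the new cat() method declared above. A minimal usage sketch follows; it is hypothetical in that it assumes Tensor is exported from the package entry point and accepts a nested array directly, neither of which is shown in this diff:

// Hypothetical usage sketch; the export and constructor forms are assumptions.
const { Tensor } = require("catniff");

const a = new Tensor([[1, 2], [3, 4]]); // shape [2, 2]
const b = new Tensor([[5, 6]]);         // shape [1, 2]

const rows = a.cat(b, 0);           // concatenate along dim 0 -> shape [3, 2]
const cols = a.cat([[7], [8]], -1); // negative dims count from the end -> shape [2, 3]
// Concatenating scalars is not allowed; the implementation throws in that case.
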
package/dist/core.js CHANGED
@@ -218,7 +218,6 @@ class Tensor {
  }
  if (out.requiresGrad) {
  out.gradFn = () => {
- // Disable gradient collecting of gradients themselves
  const outGrad = out.grad;
  const selfWithGrad = Tensor.createGraph ? this : this.detach();
  const otherWithGrad = Tensor.createGraph ? other : other.detach();
@@ -239,7 +238,6 @@ class Tensor {
  }
  if (out.requiresGrad) {
  out.gradFn = () => {
- // Disable gradient collecting of gradients themselves
  const outGrad = out.grad;
  const selfWithGrad = Tensor.createGraph ? this : this.detach();
  if (this.requiresGrad)
@@ -649,6 +647,90 @@ class Tensor {
  }
  return results;
  }
+ // Tensor concatenation
+ cat(other, dim = 0) {
+ other = this.handleOther(other);
+ // Handle scalars
+ if (typeof this.value === "number" || typeof other.value === "number") {
+ throw new Error("Can not concatenate scalars");
+ }
+ // Handle negative indices
+ if (dim < 0) {
+ dim += this.shape.length;
+ }
+ // If dimension out of bound, throw error
+ if (dim >= this.shape.length || dim < 0) {
+ throw new Error("Dimension does not exist to concatenate");
+ }
+ // If shape does not match, throw error
+ if (this.shape.length !== other.shape.length) {
+ throw new Error("Shape does not match to concatenate");
+ }
+ const outputShape = new Array(this.shape.length);
+ for (let currentDim = 0; currentDim < this.shape.length; currentDim++) {
+ if (currentDim === dim) {
+ outputShape[currentDim] = this.shape[currentDim] + other.shape[currentDim];
+ }
+ else if (this.shape[currentDim] !== other.shape[currentDim]) {
+ throw new Error("Shape does not match to concatenate");
+ }
+ else {
+ outputShape[currentDim] = this.shape[currentDim];
+ }
+ }
+ const outputSize = Tensor.shapeToSize(outputShape);
+ const outputStrides = Tensor.getStrides(outputShape);
+ const outputValue = new Array(outputSize);
+ for (let outIndex = 0; outIndex < outputSize; outIndex++) {
+ const coords = Tensor.indexToCoords(outIndex, outputStrides);
+ // Check which tensor this output position comes from
+ if (coords[dim] < this.shape[dim]) {
+ // Comes from this tensor
+ const srcIndex = Tensor.coordsToIndex(coords, this.strides);
+ outputValue[outIndex] = this.value[srcIndex + this.offset];
+ }
+ else {
+ // Comes from other tensor - adjust coordinate in concat dimension
+ const otherCoords = [...coords];
+ otherCoords[dim] -= this.shape[dim];
+ const srcIndex = Tensor.coordsToIndex(otherCoords, other.strides);
+ outputValue[outIndex] = other.value[srcIndex + other.offset];
+ }
+ }
+ const out = new Tensor(outputValue, {
+ shape: outputShape,
+ strides: outputStrides,
+ numel: outputSize
+ });
+ if (this.requiresGrad) {
+ out.requiresGrad = true;
+ out.children.push(this);
+ }
+ if (other.requiresGrad) {
+ out.requiresGrad = true;
+ out.children.push(other);
+ }
+ if (out.requiresGrad) {
+ out.gradFn = () => {
+ const outGrad = out.grad;
+ const thisRanges = new Array(this.shape.length);
+ const otherRanges = new Array(other.shape.length);
+ for (let currentDim = 0; currentDim < this.shape.length; currentDim++) {
+ if (currentDim === dim) {
+ thisRanges[currentDim] = [0, this.shape[currentDim], 1];
+ otherRanges[currentDim] = [this.shape[currentDim], outputShape[currentDim], 1];
+ }
+ else {
+ thisRanges[currentDim] = [];
+ otherRanges[currentDim] = [];
+ }
+ }
+ Tensor.addGrad(this, outGrad.slice(thisRanges));
+ Tensor.addGrad(other, outGrad.slice(otherRanges));
+ };
+ }
+ return out;
+ }
  // Tensor squeeze
  squeeze(dims) {
  if (typeof this.value === "number")
@@ -1338,7 +1420,6 @@ class Tensor {
  }
  if (out.requiresGrad) {
  out.gradFn = () => {
- // Disable gradient collecting of gradients themselves
  const outGrad = out.grad;
  const selfWithGrad = Tensor.createGraph ? this : this.detach();
  const otherWithGrad = Tensor.createGraph ? other : other.detach();
@@ -1396,7 +1477,6 @@ class Tensor {
  }
  if (out.requiresGrad) {
  out.gradFn = () => {
- // Disable gradient collecting of gradients themselves
  const outGrad = out.grad;
  const selfWithGrad = Tensor.createGraph ? this : this.detach();
  const otherWithGrad = Tensor.createGraph ? other : other.detach();
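
In the backward pass, the gradFn added for cat() in the large hunk above splits the upstream gradient with the existing slice(ranges) method, where each per-dimension range is [start, end, step] and an empty array takes the whole dimension. The sketch below only re-derives the ranges that the gradFn builds, with shapes chosen purely for illustration:

// Illustrative only: the slice ranges cat()'s gradFn builds for inputs of shape
// [2, 3] and [4, 3] concatenated along dim 0 (upstream gradient shape [6, 3]).
const dim = 0;
const thisShape = [2, 3];
const outShape = [6, 3];
const thisRanges = [];
const otherRanges = [];
for (let d = 0; d < thisShape.length; d++) {
    if (d === dim) {
        thisRanges.push([0, thisShape[d], 1]);           // rows 0..1 of the output grad
        otherRanges.push([thisShape[d], outShape[d], 1]); // rows 2..5 of the output grad
    } else {
        thisRanges.push([]);                              // untouched dims: full extent
        otherRanges.push([]);
    }
}
// thisRanges -> [[0, 2, 1], []], otherRanges -> [[2, 6, 1], []]
// The gradFn then accumulates outGrad.slice(thisRanges) into the first input's grad
// and outGrad.slice(otherRanges) into the second's via Tensor.addGrad.
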
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "catniff",
- "version": "0.7.0",
+ "version": "0.7.1",
  "description": "A small Torch-like deep learning framework for Javascript",
  "main": "index.js",
  "scripts": {