catniff 0.8.22 → 0.8.23

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.d.ts CHANGED
@@ -70,6 +70,7 @@ export declare class Tensor {
70
70
  chunk(chunks: number, dim?: number): Tensor[];
71
71
  expand(newShape: number[]): Tensor;
72
72
  unfold(dim: number, size: number, step: number): Tensor;
73
+ pad(pad: number[], mode?: string, value?: number): Tensor;
73
74
  cat(other: Tensor | TensorValue, dim?: number): Tensor;
74
75
  stack(others: (Tensor | TensorValue)[], dim?: number): Tensor;
75
76
  squeeze(dims?: number[] | number): Tensor;
package/dist/core.js CHANGED
@@ -328,8 +328,14 @@ class Tensor {
328
328
  }
329
329
  const reducedGrad = accumGrad.sum(axesToReduce, true);
330
330
  const squeezedGrad = reducedGrad.squeeze(axesToSqueeze);
331
+ // Enforce 0-offset contiguous grads and correct dtype
331
332
  if (typeof tensor.grad === "undefined") {
332
- tensor.grad = squeezedGrad;
333
+ let grad = squeezedGrad;
334
+ // Handle potentially contiguous tensors with non zero offset
335
+ if (grad.offset !== 0) {
336
+ grad = grad.clone();
337
+ }
338
+ tensor.grad = grad.contiguous().cast(tensor.dtype);
333
339
  }
334
340
  else {
335
341
  tensor.grad = tensor.grad.add(squeezedGrad.cast(tensor.dtype));
@@ -808,6 +814,70 @@ class Tensor {
808
814
  }
809
815
  return out;
810
816
  }
817
+ // Tensor padding
818
+ pad(pad, mode = "constant", value = 0) {
819
+ const original = this.clone().contiguous(); // This is needed for index padding to work
820
+ const outputShape = [...original.shape];
821
+ const paddingPerDim = [];
822
+ for (let i = 0; i < original.shape.length; i++) {
823
+ const left = pad[(original.shape.length - 1 - i) * 2] || 0;
824
+ const right = pad[(original.shape.length - 1 - i) * 2 + 1] || 0;
825
+ paddingPerDim[i] = { left, right };
826
+ outputShape[i] += left + right;
827
+ }
828
+ const outputSize = Tensor.shapeToSize(outputShape);
829
+ if (mode === "constant") {
830
+ const outputValue = new dtype_1.TypedArray[original.dtype](outputSize).fill(value);
831
+ const outputStrides = Tensor.getStrides(outputShape);
832
+ for (let index = 0; index < original.numel; index++) {
833
+ const coords = Tensor.indexToCoords(index, original.strides);
834
+ let paddedIndex = 0;
835
+ // Pad each coord
836
+ for (let j = 0; j < original.shape.length; j++) {
837
+ const shiftedCoord = coords[j] + paddingPerDim[j].left;
838
+ paddedIndex += shiftedCoord * outputStrides[j];
839
+ }
840
+ outputValue[paddedIndex] = original.value[index];
841
+ }
842
+ const out = new Tensor(outputValue, {
843
+ shape: outputShape,
844
+ strides: outputStrides,
845
+ offset: 0,
846
+ dtype: original.dtype,
847
+ device: original.device
848
+ });
849
+ if (original.requiresGrad) {
850
+ out.requiresGrad = true;
851
+ out.children.push(original);
852
+ out.gradFn = () => {
853
+ const outGrad = out.grad;
854
+ const gradValue = new dtype_1.TypedArray[original.dtype](original.numel);
855
+ const gradStrides = Tensor.getStrides(original.shape);
856
+ for (let index = 0; index < gradValue.length; index++) {
857
+ const coords = Tensor.indexToCoords(index, gradStrides);
858
+ let paddedIndex = 0;
859
+ // Pad each coord
860
+ for (let j = 0; j < original.shape.length; j++) {
861
+ const shiftedCoord = coords[j] + paddingPerDim[j].left;
862
+ paddedIndex += shiftedCoord * outputStrides[j];
863
+ }
864
+ gradValue[index] = outGrad.value[paddedIndex];
865
+ }
866
+ Tensor.addGrad(original, new Tensor(gradValue, {
867
+ shape: original.shape,
868
+ strides: gradStrides,
869
+ offset: 0,
870
+ dtype: original.dtype,
871
+ device: original.device
872
+ }));
873
+ };
874
+ }
875
+ return out;
876
+ }
877
+ else {
878
+ throw new Error(`Padding mode not supported: "${mode}"`);
879
+ }
880
+ }
811
881
  // Tensor concatentation
812
882
  cat(other, dim = 0) {
813
883
  other = this.handleOther(other);
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.8.22",
3
+ "version": "0.8.23",
4
4
  "description": "Torch-like deep learning framework for Javascript",
5
5
  "main": "./dist/index.js",
6
6
  "scripts": {