catniff 0.8.22 → 0.9.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.d.ts CHANGED
@@ -70,6 +70,7 @@ export declare class Tensor {
     chunk(chunks: number, dim?: number): Tensor[];
     expand(newShape: number[]): Tensor;
     unfold(dim: number, size: number, step: number): Tensor;
+    pad(pad: number[], mode?: string, value?: number): Tensor;
     cat(other: Tensor | TensorValue, dim?: number): Tensor;
     stack(others: (Tensor | TensorValue)[], dim?: number): Tensor;
     squeeze(dims?: number[] | number): Tensor;
@@ -218,6 +219,7 @@ export declare class Tensor {
     mv(other: TensorValue | Tensor): Tensor;
     matmul(other: TensorValue | Tensor): Tensor;
     tensordot(other: TensorValue | Tensor, axes?: number | [number, number] | [number[], number[]]): Tensor;
+    conv2d(weight: Tensor | TensorValue, bias?: Tensor | TensorValue, stride?: number | [number, number], padding?: number | [number, number], dilation?: number | [number, number], groups?: number): Tensor;
    dropout(rate: number): Tensor;
    triu(diagonal?: number): Tensor;
    tril(diagonal?: number): Tensor;
package/dist/core.js CHANGED
@@ -328,8 +328,14 @@ class Tensor {
         }
         const reducedGrad = accumGrad.sum(axesToReduce, true);
         const squeezedGrad = reducedGrad.squeeze(axesToSqueeze);
+        // Enforce 0-offset contiguous grads and correct dtype
         if (typeof tensor.grad === "undefined") {
-            tensor.grad = squeezedGrad;
+            let grad = squeezedGrad;
+            // Handle potentially contiguous tensors with non zero offset
+            if (grad.offset !== 0) {
+                grad = grad.clone();
+            }
+            tensor.grad = grad.contiguous().cast(tensor.dtype);
         }
         else {
             tensor.grad = tensor.grad.add(squeezedGrad.cast(tensor.dtype));
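With this hunk, `Tensor.addGrad` materializes the first gradient it stores: views with a non-zero storage offset are cloned, then made contiguous and cast to the owner's dtype, so `tensor.grad` always owns a dense row-major buffer. A minimal sketch of the view-vs-materialized distinction this enforces, assuming the package root re-exports `Tensor` and that the constructor accepts nested arrays (both assumptions; the method names are taken from this diff):

import { Tensor } from "catniff";

// t() returns a stride-permuted view over the same buffer; contiguous()
// copies it into fresh row-major storage, which is the form addGrad now
// guarantees for every stored gradient.
const g = new Tensor([[1, 2], [3, 4]]).t();
const dense = g.contiguous(); // owns its data: offset 0, row-major strides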
@@ -795,12 +801,12 @@ class Tensor {
             const outGrad = out.grad;
             const grad = Tensor.zerosLike(this);
             for (let i = 0; i < out.numel; i++) {
-                const coords = Tensor.indexToCoords(i, newStrides);
+                const coords = Tensor.indexToCoords(i, Tensor.getStrides(outGrad.shape));
                 const windowIdx = coords[dim];
                 const withinWindow = coords[coords.length - 1];
                 coords[dim] = windowIdx * step + withinWindow;
                 coords.pop();
-                const sourceIdx = Tensor.coordsToIndex(coords, this.strides);
+                const sourceIdx = Tensor.coordsToIndex(coords, Tensor.getStrides(grad.shape));
                 grad.value[sourceIdx] += outGrad.value[i];
             }
             Tensor.addGrad(this, grad);
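This hunk fixes `unfold`'s backward pass: `outGrad.value` and `grad.value` are dense buffers, so the coordinate arithmetic must use contiguous strides derived from the shapes rather than the view strides (`newStrides`, `this.strides`). A shape-level sketch of what the loop walks over, assuming the torch-style unfold semantics this file implements and a root re-export of `Tensor`:

import { Tensor } from "catniff";

const t = new Tensor([1, 2, 3, 4, 5]); // shape [5]
const win = t.unfold(0, 2, 1);         // shape [4, 2]: [1,2], [2,3], [3,4], [4,5]
// Overlapping windows mean one input element feeds several outputs, so the
// backward accumulates: grad[i] sums one contribution per window containing i.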
@@ -808,6 +814,70 @@ class Tensor {
         }
         return out;
     }
+    // Tensor padding
+    pad(pad, mode = "constant", value = 0) {
+        const original = this.clone().contiguous(); // This is needed for index padding to work
+        const outputShape = [...original.shape];
+        const paddingPerDim = [];
+        for (let i = 0; i < original.shape.length; i++) {
+            const left = pad[(original.shape.length - 1 - i) * 2] || 0;
+            const right = pad[(original.shape.length - 1 - i) * 2 + 1] || 0;
+            paddingPerDim[i] = { left, right };
+            outputShape[i] += left + right;
+        }
+        const outputSize = Tensor.shapeToSize(outputShape);
+        if (mode === "constant") {
+            const outputValue = new dtype_1.TypedArray[original.dtype](outputSize).fill(value);
+            const outputStrides = Tensor.getStrides(outputShape);
+            for (let index = 0; index < original.numel; index++) {
+                const coords = Tensor.indexToCoords(index, original.strides);
+                let paddedIndex = 0;
+                // Pad each coord
+                for (let j = 0; j < original.shape.length; j++) {
+                    const shiftedCoord = coords[j] + paddingPerDim[j].left;
+                    paddedIndex += shiftedCoord * outputStrides[j];
+                }
+                outputValue[paddedIndex] = original.value[index];
+            }
+            const out = new Tensor(outputValue, {
+                shape: outputShape,
+                strides: outputStrides,
+                offset: 0,
+                dtype: original.dtype,
+                device: original.device
+            });
+            if (original.requiresGrad) {
+                out.requiresGrad = true;
+                out.children.push(original);
+                out.gradFn = () => {
+                    const outGrad = out.grad;
+                    const gradValue = new dtype_1.TypedArray[original.dtype](original.numel);
+                    const gradStrides = Tensor.getStrides(original.shape);
+                    for (let index = 0; index < gradValue.length; index++) {
+                        const coords = Tensor.indexToCoords(index, gradStrides);
+                        let paddedIndex = 0;
+                        // Pad each coord
+                        for (let j = 0; j < original.shape.length; j++) {
+                            const shiftedCoord = coords[j] + paddingPerDim[j].left;
+                            paddedIndex += shiftedCoord * outputStrides[j];
+                        }
+                        gradValue[index] = outGrad.value[paddedIndex];
+                    }
+                    Tensor.addGrad(original, new Tensor(gradValue, {
+                        shape: original.shape,
+                        strides: gradStrides,
+                        offset: 0,
+                        dtype: original.dtype,
+                        device: original.device
+                    }));
+                };
+            }
+            return out;
+        }
+        else {
+            throw new Error(`Padding mode not supported: "${mode}"`);
+        }
+    }
     // Tensor concatentation
     cat(other, dim = 0) {
         other = this.handleOther(other);
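The `pad` argument follows the torch convention: pairs of (left, right) amounts ordered from the last dimension inward, as the `(shape.length - 1 - i) * 2` indexing above shows, and only `"constant"` mode is implemented. Dimensions not covered by the array default to zero padding via the `|| 0` fallback. A small usage sketch (assuming the package root re-exports `Tensor` and the constructor accepts nested arrays):

import { Tensor } from "catniff";

const t = new Tensor([[1, 2], [3, 4]]);       // shape [2, 2]
const p = t.pad([1, 1]);                      // last dim only: shape [2, 4]
// p rows: [0, 1, 2, 0] and [0, 3, 4, 0]
const q = t.pad([1, 1, 1, 1], "constant", 9); // shape [4, 4], border of 9s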
@@ -2113,6 +2183,74 @@ class Tensor {
         ];
         return result2D.reshape(finalShape);
     }
+    // 2D convolution
+    conv2d(weight, bias, stride = 1, padding = 0, dilation = 1, groups = 1) {
+        weight = this.handleOther(weight);
+        const [sH, sW] = Array.isArray(stride) ? stride : [stride, stride];
+        const [pH, pW] = Array.isArray(padding) ? padding : [padding, padding];
+        const [dH, dW] = Array.isArray(dilation) ? dilation : [dilation, dilation];
+        const [N, Cin, H, W] = this.shape;
+        const [Cout, CinPerGroup, kH, kW] = weight.shape;
+        // Pad input
+        let x = (pH > 0 || pW > 0) ? this.pad([pW, pW, pH, pH]) : this;
+        const Hp = H + 2 * pH;
+        const Wp = W + 2 * pW;
+        const Hout = Math.floor((Hp - dH * (kH - 1) - 1) / sH + 1);
+        const Wout = Math.floor((Wp - dW * (kW - 1) - 1) / sW + 1);
+        // Unfold H with a window large enough to cover the dilated kernel extent,
+        // then slice every dH-th position to realise the dilation holes.
+        // x: [N, Cin, Hp, Wp]
+        //   -> unfold(2, dH*(kH-1)+1, sH)
+        //   -> [N, Cin, Hout, Wp, dH*(kH-1)+1]
+        //   -> slice step dH on last dim
+        //   -> [N, Cin, Hout, Wp, kH]
+        const dilKH = dH * (kH - 1) + 1;
+        x = x.unfold(2, dilKH, sH);
+        if (dH > 1)
+            x = x.slice([[0, N], [0, Cin], [0, Hout], [0, Wp], [0, dilKH, dH]]);
+        // Unfold W
+        // x: [N, Cin, Hout, Wp, kH]
+        //   -> unfold(3, dW*(kW-1)+1, sW)
+        //   -> [N, Cin, Hout, Wout, kH, dW*(kW-1)+1]
+        //   -> slice step dW on last dim
+        //   -> [N, Cin, Hout, Wout, kH, kW]
+        const dilKW = dW * (kW - 1) + 1;
+        x = x.unfold(3, dilKW, sW);
+        if (dW > 1)
+            x = x.slice([[0, N], [0, Cin], [0, Hout], [0, Wout], [0, kH], [0, dilKW, dW]]);
+        // Reshape patches to [N, Hout*Wout, Cin*kH*kW]
+        // permute [0,2,3,1,4,5] -> [N, Hout, Wout, Cin, kH, kW]
+        // then reshape merges the spatial and channel-kernel dims.
+        // reshape() forces contiguity internally so no explicit .contiguous() needed.
+        x = x.permute([0, 2, 3, 1, 4, 5]).reshape([N, Hout * Wout, Cin * kH * kW]);
+        // Matmul with weight
+        // weight: [Cout, CinPerGroup, kH, kW] -> [Cout, CinPerGroup*kH*kW]
+        const w = weight.reshape([Cout, CinPerGroup * kH * kW]);
+        let out;
+        if (groups === 1) {
+            // x: [N, Hout*Wout, Cin*kH*kW] @ w.t(): [Cin*kH*kW, Cout]
+            // -> [N, Hout*Wout, Cout]
+            out = x.matmul(w.t());
+        }
+        else {
+            // Each group handles Cin/groups input channels and Cout/groups output channels.
+            // chunk(groups, 2) splits the Cin*kH*kW axis into groups equal slices,
+            // each of size CinPerGroup*kH*kW — valid because reshape laid Cin outermost.
+            const patchChunks = x.chunk(groups, 2); // Tensor[groups], each [N, Hout*Wout, CinPerGroup*kH*kW]
+            const weightChunks = w.chunk(groups, 0); // Tensor[groups], each [Cout/groups, CinPerGroup*kH*kW]
+            const groupOuts = patchChunks.map((patch, i) => patch.matmul(weightChunks[i].t()) // [N, Hout*Wout, Cout/groups]
+            );
+            // Cat all group outputs along the channel axis (dim 2)
+            out = groupOuts.reduce((acc, g) => acc.cat(g, 2)); // [N, Hout*Wout, Cout]
+        }
+        // Restore [N, Cout, Hout, Wout]
+        out = out.permute([0, 2, 1]).reshape([N, Cout, Hout, Wout]);
+        // Bias
+        if (bias) {
+            out = out.add(this.handleOther(bias).reshape([1, Cout, 1, 1]));
+        }
+        return out;
+    }
     // Dropout
     dropout(rate) {
         if (!Tensor.training || rate === 0)
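`conv2d` lowers the convolution to im2col: two `unfold` calls build the patch tensor, a (grouped) `matmul` applies the filters, and the output spatial size follows the usual formula `Hout = floor((H + 2*pH - dH*(kH - 1) - 1) / sH + 1)`, likewise for width. A usage sketch with hypothetical shapes, assuming the package root re-exports `Tensor` and that `Tensor.uniform`'s options argument is optional, as its call sites in nn.js below suggest:

import { Tensor } from "catniff";

const x = Tensor.uniform([1, 3, 8, 8], -1, 1);  // N=1, Cin=3, 8x8 image
const w = Tensor.uniform([16, 3, 3, 3], -1, 1); // 16 filters, 3x3 kernels
const y = x.conv2d(w, undefined, 1, 1);         // stride 1, padding 1
// Hout = floor((8 + 2*1 - 1*(3 - 1) - 1) / 1 + 1) = 8 -> y.shape: [1, 16, 8, 8]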
package/dist/nn.d.ts CHANGED
@@ -55,6 +55,16 @@ export declare class LSTMCell {
     constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
     forward(input: Tensor, hidden: Tensor, cell: Tensor): [Tensor, Tensor];
 }
+export declare class Conv2d {
+    weight: Tensor;
+    bias?: Tensor;
+    stride: number | [number, number];
+    padding: number | [number, number];
+    dilation: number | [number, number];
+    groups: number;
+    constructor(inChannels: number, outChannels: number, kernelSize: number, stride?: number | [number, number], padding?: number | [number, number], dilation?: number | [number, number], groups?: number, bias?: boolean, device?: string, dtype?: dtype);
+    forward(input: Tensor): Tensor;
+}
 export declare class BatchNorm {
     weight?: Tensor;
     bias?: Tensor;
@@ -127,6 +137,7 @@ export declare const nn: {
     RNNCell: typeof RNNCell;
     GRUCell: typeof GRUCell;
     LSTMCell: typeof LSTMCell;
+    Conv2d: typeof Conv2d;
     BatchNorm: typeof BatchNorm;
     InstanceNorm: typeof InstanceNorm;
     GroupNorm: typeof GroupNorm;
package/dist/nn.js CHANGED
@@ -1,6 +1,6 @@
 "use strict";
 Object.defineProperty(exports, "__esModule", { value: true });
-exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Sequential = exports.Linear = void 0;
+exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.Conv2d = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Sequential = exports.Linear = void 0;
 const core_1 = require("./core");
 class Linear {
     weight;
@@ -136,6 +136,30 @@ class LSTMCell {
     }
 }
 exports.LSTMCell = LSTMCell;
+class Conv2d {
+    weight;
+    bias;
+    stride;
+    padding;
+    dilation;
+    groups;
+    constructor(inChannels, outChannels, kernelSize, stride = 1, padding = 0, dilation = 1, groups = 1, bias = true, device, dtype) {
+        this.stride = stride;
+        this.padding = padding;
+        this.dilation = dilation;
+        this.groups = groups;
+        const fanIn = (inChannels / groups) * kernelSize * kernelSize;
+        const bound = Math.sqrt(1 / fanIn);
+        this.weight = core_1.Tensor.uniform([outChannels, inChannels / groups, kernelSize, kernelSize], -bound, bound, { requiresGrad: true, device, dtype });
+        if (bias) {
+            this.bias = core_1.Tensor.uniform([outChannels], -bound, bound, { requiresGrad: true, device, dtype });
+        }
+    }
+    forward(input) {
+        return input.conv2d(this.weight, this.bias, this.stride, this.padding, this.dilation, this.groups);
+    }
+}
+exports.Conv2d = Conv2d;
 class BatchNorm {
     weight;
     bias;
@@ -436,6 +460,7 @@ exports.nn = {
     RNNCell,
     GRUCell,
     LSTMCell,
+    Conv2d,
     BatchNorm,
     InstanceNorm,
     GroupNorm,
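`nn.Conv2d` is a thin module wrapper over the tensor-level `conv2d`. Weight and bias are drawn from U(-b, b) with b = sqrt(1 / fanIn) and fanIn = (inChannels / groups) * kernelSize^2, the same bound PyTorch's Conv2d uses by default. A usage sketch (assuming the package root re-exports `nn` and `Tensor`):

import { nn, Tensor } from "catniff";

const conv = new nn.Conv2d(3, 16, 3, 1, 1);      // in 3, out 16, 3x3, stride 1, pad 1
const x = Tensor.uniform([4, 3, 32, 32], -1, 1); // a batch of 4 RGB images
const y = conv.forward(x);                       // shape [4, 16, 32, 32]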
package/package.json CHANGED
@@ -1,6 +1,6 @@
 {
   "name": "catniff",
-  "version": "0.8.22",
+  "version": "0.9.0",
   "description": "Torch-like deep learning framework for Javascript",
   "main": "./dist/index.js",
   "scripts": {