catniff 0.8.23 → 0.9.0

This diff shows the changes between two publicly available versions of the package, each of which has been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
package/dist/core.d.ts CHANGED
@@ -219,6 +219,7 @@ export declare class Tensor {
219
219
  mv(other: TensorValue | Tensor): Tensor;
220
220
  matmul(other: TensorValue | Tensor): Tensor;
221
221
  tensordot(other: TensorValue | Tensor, axes?: number | [number, number] | [number[], number[]]): Tensor;
222
+ conv2d(weight: Tensor | TensorValue, bias?: Tensor | TensorValue, stride?: number | [number, number], padding?: number | [number, number], dilation?: number | [number, number], groups?: number): Tensor;
222
223
  dropout(rate: number): Tensor;
223
224
  triu(diagonal?: number): Tensor;
224
225
  tril(diagonal?: number): Tensor;
package/dist/core.js CHANGED
@@ -801,12 +801,12 @@ class Tensor {
801
801
  const outGrad = out.grad;
802
802
  const grad = Tensor.zerosLike(this);
803
803
  for (let i = 0; i < out.numel; i++) {
804
- const coords = Tensor.indexToCoords(i, newStrides);
804
+ const coords = Tensor.indexToCoords(i, Tensor.getStrides(outGrad.shape));
805
805
  const windowIdx = coords[dim];
806
806
  const withinWindow = coords[coords.length - 1];
807
807
  coords[dim] = windowIdx * step + withinWindow;
808
808
  coords.pop();
809
- const sourceIdx = Tensor.coordsToIndex(coords, this.strides);
809
+ const sourceIdx = Tensor.coordsToIndex(coords, Tensor.getStrides(grad.shape));
810
810
  grad.value[sourceIdx] += outGrad.value[i];
811
811
  }
812
812
  Tensor.addGrad(this, grad);
@@ -2183,6 +2183,74 @@ class Tensor {
2183
2183
  ];
2184
2184
  return result2D.reshape(finalShape);
2185
2185
  }
2186
+ // 2D convolution
2187
+ conv2d(weight, bias, stride = 1, padding = 0, dilation = 1, groups = 1) {
2188
+ weight = this.handleOther(weight);
2189
+ const [sH, sW] = Array.isArray(stride) ? stride : [stride, stride];
2190
+ const [pH, pW] = Array.isArray(padding) ? padding : [padding, padding];
2191
+ const [dH, dW] = Array.isArray(dilation) ? dilation : [dilation, dilation];
2192
+ const [N, Cin, H, W] = this.shape;
2193
+ const [Cout, CinPerGroup, kH, kW] = weight.shape;
2194
+ // Pad input
2195
+ let x = (pH > 0 || pW > 0) ? this.pad([pW, pW, pH, pH]) : this;
2196
+ const Hp = H + 2 * pH;
2197
+ const Wp = W + 2 * pW;
2198
+ const Hout = Math.floor((Hp - dH * (kH - 1) - 1) / sH + 1);
2199
+ const Wout = Math.floor((Wp - dW * (kW - 1) - 1) / sW + 1);
2200
+ // Unfold H with a window large enough to cover the dilated kernel extent,
2201
+ // then slice every dH-th position to realise the dilation holes.
2202
+ // x: [N, Cin, Hp, Wp]
2203
+ // -> unfold(2, dH*(kH-1)+1, sH)
2204
+ // -> [N, Cin, Hout, Wp, dH*(kH-1)+1]
2205
+ // -> slice step dH on last dim
2206
+ // -> [N, Cin, Hout, Wp, kH]
2207
+ const dilKH = dH * (kH - 1) + 1;
2208
+ x = x.unfold(2, dilKH, sH);
2209
+ if (dH > 1)
2210
+ x = x.slice([[0, N], [0, Cin], [0, Hout], [0, Wp], [0, dilKH, dH]]);
2211
+ // Unfold W
2212
+ // x: [N, Cin, Hout, Wp, kH]
2213
+ // -> unfold(3, dW*(kW-1)+1, sW)
2214
+ // -> [N, Cin, Hout, Wout, kH, dW*(kW-1)+1]
2215
+ // -> slice step dW on last dim
2216
+ // -> [N, Cin, Hout, Wout, kH, kW]
2217
+ const dilKW = dW * (kW - 1) + 1;
2218
+ x = x.unfold(3, dilKW, sW);
2219
+ if (dW > 1)
2220
+ x = x.slice([[0, N], [0, Cin], [0, Hout], [0, Wout], [0, kH], [0, dilKW, dW]]);
2221
+ // Reshape patches to [N, Hout*Wout, Cin*kH*kW]
2222
+ // permute [0,2,3,1,4,5] -> [N, Hout, Wout, Cin, kH, kW]
2223
+ // then reshape merges the spatial and channel-kernel dims.
2224
+ // reshape() forces contiguity internally so no explicit .contiguous() needed.
2225
+ x = x.permute([0, 2, 3, 1, 4, 5]).reshape([N, Hout * Wout, Cin * kH * kW]);
2226
+ // Matmul with weight
2227
+ // weight: [Cout, CinPerGroup, kH, kW] -> [Cout, CinPerGroup*kH*kW]
2228
+ const w = weight.reshape([Cout, CinPerGroup * kH * kW]);
2229
+ let out;
2230
+ if (groups === 1) {
2231
+ // x: [N, Hout*Wout, Cin*kH*kW] @ w.t(): [Cin*kH*kW, Cout]
2232
+ // -> [N, Hout*Wout, Cout]
2233
+ out = x.matmul(w.t());
2234
+ }
2235
+ else {
2236
+ // Each group handles Cin/groups input channels and Cout/groups output channels.
2237
+ // chunk(groups, 2) splits the Cin*kH*kW axis into groups equal slices,
2238
+ // each of size CinPerGroup*kH*kW — valid because reshape laid Cin outermost.
2239
+ const patchChunks = x.chunk(groups, 2); // Tensor[groups], each [N, Hout*Wout, CinPerGroup*kH*kW]
2240
+ const weightChunks = w.chunk(groups, 0); // Tensor[groups], each [Cout/groups, CinPerGroup*kH*kW]
2241
+ const groupOuts = patchChunks.map((patch, i) => patch.matmul(weightChunks[i].t()) // [N, Hout*Wout, Cout/groups]
2242
+ );
2243
+ // Cat all group outputs along the channel axis (dim 2)
2244
+ out = groupOuts.reduce((acc, g) => acc.cat(g, 2)); // [N, Hout*Wout, Cout]
2245
+ }
2246
+ // Restore [N, Cout, Hout, Wout]
2247
+ out = out.permute([0, 2, 1]).reshape([N, Cout, Hout, Wout]);
2248
+ // Bias
2249
+ if (bias) {
2250
+ out = out.add(this.handleOther(bias).reshape([1, Cout, 1, 1]));
2251
+ }
2252
+ return out;
2253
+ }
2186
2254
  // Dropout
2187
2255
  dropout(rate) {
2188
2256
  if (!Tensor.training || rate === 0)
package/dist/nn.d.ts CHANGED
@@ -55,6 +55,16 @@ export declare class LSTMCell {
55
55
  constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
56
56
  forward(input: Tensor, hidden: Tensor, cell: Tensor): [Tensor, Tensor];
57
57
  }
58
+ export declare class Conv2d {
59
+ weight: Tensor;
60
+ bias?: Tensor;
61
+ stride: number | [number, number];
62
+ padding: number | [number, number];
63
+ dilation: number | [number, number];
64
+ groups: number;
65
+ constructor(inChannels: number, outChannels: number, kernelSize: number, stride?: number | [number, number], padding?: number | [number, number], dilation?: number | [number, number], groups?: number, bias?: boolean, device?: string, dtype?: dtype);
66
+ forward(input: Tensor): Tensor;
67
+ }
58
68
  export declare class BatchNorm {
59
69
  weight?: Tensor;
60
70
  bias?: Tensor;
@@ -127,6 +137,7 @@ export declare const nn: {
127
137
  RNNCell: typeof RNNCell;
128
138
  GRUCell: typeof GRUCell;
129
139
  LSTMCell: typeof LSTMCell;
140
+ Conv2d: typeof Conv2d;
130
141
  BatchNorm: typeof BatchNorm;
131
142
  InstanceNorm: typeof InstanceNorm;
132
143
  GroupNorm: typeof GroupNorm;
package/dist/nn.js CHANGED
@@ -1,6 +1,6 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Sequential = exports.Linear = void 0;
3
+ exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.Conv2d = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Sequential = exports.Linear = void 0;
4
4
  const core_1 = require("./core");
5
5
  class Linear {
6
6
  weight;
@@ -136,6 +136,30 @@ class LSTMCell {
136
136
  }
137
137
  }
138
138
  exports.LSTMCell = LSTMCell;
139
+ class Conv2d {
140
+ weight;
141
+ bias;
142
+ stride;
143
+ padding;
144
+ dilation;
145
+ groups;
146
+ constructor(inChannels, outChannels, kernelSize, stride = 1, padding = 0, dilation = 1, groups = 1, bias = true, device, dtype) {
147
+ this.stride = stride;
148
+ this.padding = padding;
149
+ this.dilation = dilation;
150
+ this.groups = groups;
151
+ const fanIn = (inChannels / groups) * kernelSize * kernelSize;
152
+ const bound = Math.sqrt(1 / fanIn);
153
+ this.weight = core_1.Tensor.uniform([outChannels, inChannels / groups, kernelSize, kernelSize], -bound, bound, { requiresGrad: true, device, dtype });
154
+ if (bias) {
155
+ this.bias = core_1.Tensor.uniform([outChannels], -bound, bound, { requiresGrad: true, device, dtype });
156
+ }
157
+ }
158
+ forward(input) {
159
+ return input.conv2d(this.weight, this.bias, this.stride, this.padding, this.dilation, this.groups);
160
+ }
161
+ }
162
+ exports.Conv2d = Conv2d;
139
163
  class BatchNorm {
140
164
  weight;
141
165
  bias;
@@ -436,6 +460,7 @@ exports.nn = {
436
460
  RNNCell,
437
461
  GRUCell,
438
462
  LSTMCell,
463
+ Conv2d,
439
464
  BatchNorm,
440
465
  InstanceNorm,
441
466
  GroupNorm,
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.8.23",
3
+ "version": "0.9.0",
4
4
  "description": "Torch-like deep learning framework for Javascript",
5
5
  "main": "./dist/index.js",
6
6
  "scripts": {