catniff 0.8.23 → 0.9.0
- package/dist/core.d.ts +1 -0
- package/dist/core.js +70 -2
- package/dist/nn.d.ts +11 -0
- package/dist/nn.js +26 -1
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
@@ -219,6 +219,7 @@ export declare class Tensor {
     mv(other: TensorValue | Tensor): Tensor;
     matmul(other: TensorValue | Tensor): Tensor;
     tensordot(other: TensorValue | Tensor, axes?: number | [number, number] | [number[], number[]]): Tensor;
+    conv2d(weight: Tensor | TensorValue, bias?: Tensor | TensorValue, stride?: number | [number, number], padding?: number | [number, number], dilation?: number | [number, number], groups?: number): Tensor;
     dropout(rate: number): Tensor;
     triu(diagonal?: number): Tensor;
     tril(diagonal?: number): Tensor;
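Every parameter after weight is optional, and stride, padding, and dilation each accept either a scalar or an [h, w] pair. A minimal call sketch against this declaration; the import path is an assumption (dist/core exports Tensor, and the root entry presumably re-exports it), and Tensor.uniform is called the way nn.js calls it below:

    import { Tensor } from "catniff"; // assumed entry point

    // input [N, Cin, H, W], weight [Cout, Cin/groups, kH, kW]
    const x = Tensor.uniform([1, 3, 8, 8], -1, 1, { requiresGrad: false });
    const w = Tensor.uniform([4, 3, 3, 3], -1, 1, { requiresGrad: false });

    const a = x.conv2d(w);                       // defaults: stride 1, padding 0, dilation 1, groups 1
    const b = x.conv2d(w, undefined, 2, [1, 0]); // scalar stride, [h, w] padding pair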
package/dist/core.js
CHANGED
@@ -801,12 +801,12 @@ class Tensor {
         const outGrad = out.grad;
         const grad = Tensor.zerosLike(this);
         for (let i = 0; i < out.numel; i++) {
-            const coords = Tensor.indexToCoords(i,
+            const coords = Tensor.indexToCoords(i, Tensor.getStrides(outGrad.shape));
             const windowIdx = coords[dim];
             const withinWindow = coords[coords.length - 1];
             coords[dim] = windowIdx * step + withinWindow;
             coords.pop();
-            const sourceIdx = Tensor.coordsToIndex(coords,
+            const sourceIdx = Tensor.coordsToIndex(coords, Tensor.getStrides(grad.shape));
             grad.value[sourceIdx] += outGrad.value[i];
         }
         Tensor.addGrad(this, grad);
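The unfold backward now converts flat indices using strides derived from each tensor's own shape (the removed arguments are truncated in this view). For intuition, here is a minimal standalone sketch of what row-major helpers like these typically compute; this is an illustration, not necessarily catniff's exact code:

    // Row-major strides: stride[i] = product of all dims after i.
    function getStrides(shape: number[]): number[] {
        const strides = new Array(shape.length).fill(1);
        for (let i = shape.length - 2; i >= 0; i--)
            strides[i] = strides[i + 1] * shape[i + 1];
        return strides;
    }

    function indexToCoords(idx: number, strides: number[]): number[] {
        return strides.map((s) => { const c = Math.floor(idx / s); idx %= s; return c; });
    }

    function coordsToIndex(coords: number[], strides: number[]): number {
        return coords.reduce((acc, c, i) => acc + c * strides[i], 0);
    }

    // Example: shape [2, 3] has strides [3, 1]; flat index 4 <-> coords [1, 1].

The round trip only lands on the right element when the strides match the tensor being indexed, which is why outGrad.shape and grad.shape are each passed explicitly.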
@@ -2183,6 +2183,74 @@ class Tensor {
         ];
         return result2D.reshape(finalShape);
     }
+    // 2D convolution
+    conv2d(weight, bias, stride = 1, padding = 0, dilation = 1, groups = 1) {
+        weight = this.handleOther(weight);
+        const [sH, sW] = Array.isArray(stride) ? stride : [stride, stride];
+        const [pH, pW] = Array.isArray(padding) ? padding : [padding, padding];
+        const [dH, dW] = Array.isArray(dilation) ? dilation : [dilation, dilation];
+        const [N, Cin, H, W] = this.shape;
+        const [Cout, CinPerGroup, kH, kW] = weight.shape;
+        // Pad input
+        let x = (pH > 0 || pW > 0) ? this.pad([pW, pW, pH, pH]) : this;
+        const Hp = H + 2 * pH;
+        const Wp = W + 2 * pW;
+        const Hout = Math.floor((Hp - dH * (kH - 1) - 1) / sH + 1);
+        const Wout = Math.floor((Wp - dW * (kW - 1) - 1) / sW + 1);
+        // Unfold H with a window large enough to cover the dilated kernel extent,
+        // then slice every dH-th position to realise the dilation holes.
+        // x: [N, Cin, Hp, Wp]
+        // -> unfold(2, dH*(kH-1)+1, sH)
+        // -> [N, Cin, Hout, Wp, dH*(kH-1)+1]
+        // -> slice step dH on last dim
+        // -> [N, Cin, Hout, Wp, kH]
+        const dilKH = dH * (kH - 1) + 1;
+        x = x.unfold(2, dilKH, sH);
+        if (dH > 1)
+            x = x.slice([[0, N], [0, Cin], [0, Hout], [0, Wp], [0, dilKH, dH]]);
+        // Unfold W
+        // x: [N, Cin, Hout, Wp, kH]
+        // -> unfold(3, dW*(kW-1)+1, sW)
+        // -> [N, Cin, Hout, Wout, kH, dW*(kW-1)+1]
+        // -> slice step dW on last dim
+        // -> [N, Cin, Hout, Wout, kH, kW]
+        const dilKW = dW * (kW - 1) + 1;
+        x = x.unfold(3, dilKW, sW);
+        if (dW > 1)
+            x = x.slice([[0, N], [0, Cin], [0, Hout], [0, Wout], [0, kH], [0, dilKW, dW]]);
+        // Reshape patches to [N, Hout*Wout, Cin*kH*kW]
+        // permute [0,2,3,1,4,5] -> [N, Hout, Wout, Cin, kH, kW]
+        // then reshape merges the spatial and channel-kernel dims.
+        // reshape() forces contiguity internally so no explicit .contiguous() needed.
+        x = x.permute([0, 2, 3, 1, 4, 5]).reshape([N, Hout * Wout, Cin * kH * kW]);
+        // Matmul with weight
+        // weight: [Cout, CinPerGroup, kH, kW] -> [Cout, CinPerGroup*kH*kW]
+        const w = weight.reshape([Cout, CinPerGroup * kH * kW]);
+        let out;
+        if (groups === 1) {
+            // x: [N, Hout*Wout, Cin*kH*kW] @ w.t(): [Cin*kH*kW, Cout]
+            // -> [N, Hout*Wout, Cout]
+            out = x.matmul(w.t());
+        }
+        else {
+            // Each group handles Cin/groups input channels and Cout/groups output channels.
+            // chunk(groups, 2) splits the Cin*kH*kW axis into groups equal slices,
+            // each of size CinPerGroup*kH*kW — valid because reshape laid Cin outermost.
+            const patchChunks = x.chunk(groups, 2); // Tensor[groups], each [N, Hout*Wout, CinPerGroup*kH*kW]
+            const weightChunks = w.chunk(groups, 0); // Tensor[groups], each [Cout/groups, CinPerGroup*kH*kW]
+            const groupOuts = patchChunks.map((patch, i) => patch.matmul(weightChunks[i].t()) // [N, Hout*Wout, Cout/groups]
+            );
+            // Cat all group outputs along the channel axis (dim 2)
+            out = groupOuts.reduce((acc, g) => acc.cat(g, 2)); // [N, Hout*Wout, Cout]
+        }
+        // Restore [N, Cout, Hout, Wout]
+        out = out.permute([0, 2, 1]).reshape([N, Cout, Hout, Wout]);
+        // Bias
+        if (bias) {
+            out = out.add(this.handleOther(bias).reshape([1, Cout, 1, 1]));
+        }
+        return out;
+    }
     // Dropout
     dropout(rate) {
         if (!Tensor.training || rate === 0)
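The output extents follow the standard convolution formula Hout = floor((H + 2*pH - dH*(kH - 1) - 1) / sH + 1), computed here after explicit padding as Hp = H + 2*pH. A quick shape check, under the same assumed import as above:

    import { Tensor } from "catniff"; // assumed entry point

    const x = Tensor.uniform([1, 3, 32, 32], -1, 1, { requiresGrad: false });
    const w = Tensor.uniform([16, 3, 3, 3], -1, 1, { requiresGrad: false });

    // Hout = floor((32 + 2*1 - 1*(3 - 1) - 1) / 1 + 1) = 32, likewise Wout
    const y = x.conv2d(w, undefined, 1, 1, 1, 1);
    console.log(y.shape); // [1, 16, 32, 32]: 3x3, stride 1, padding 1 preserves H and W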
package/dist/nn.d.ts
CHANGED
@@ -55,6 +55,16 @@ export declare class LSTMCell {
     constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
     forward(input: Tensor, hidden: Tensor, cell: Tensor): [Tensor, Tensor];
 }
+export declare class Conv2d {
+    weight: Tensor;
+    bias?: Tensor;
+    stride: number | [number, number];
+    padding: number | [number, number];
+    dilation: number | [number, number];
+    groups: number;
+    constructor(inChannels: number, outChannels: number, kernelSize: number, stride?: number | [number, number], padding?: number | [number, number], dilation?: number | [number, number], groups?: number, bias?: boolean, device?: string, dtype?: dtype);
+    forward(input: Tensor): Tensor;
+}
 export declare class BatchNorm {
     weight?: Tensor;
     bias?: Tensor;
@@ -127,6 +137,7 @@ export declare const nn: {
     RNNCell: typeof RNNCell;
     GRUCell: typeof GRUCell;
     LSTMCell: typeof LSTMCell;
+    Conv2d: typeof Conv2d;
     BatchNorm: typeof BatchNorm;
     InstanceNorm: typeof InstanceNorm;
     GroupNorm: typeof GroupNorm;
package/dist/nn.js
CHANGED
@@ -1,6 +1,6 @@
 "use strict";
 Object.defineProperty(exports, "__esModule", { value: true });
-exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Sequential = exports.Linear = void 0;
+exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.Conv2d = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Sequential = exports.Linear = void 0;
 const core_1 = require("./core");
 class Linear {
     weight;
@@ -136,6 +136,30 @@ class LSTMCell {
     }
 }
 exports.LSTMCell = LSTMCell;
+class Conv2d {
+    weight;
+    bias;
+    stride;
+    padding;
+    dilation;
+    groups;
+    constructor(inChannels, outChannels, kernelSize, stride = 1, padding = 0, dilation = 1, groups = 1, bias = true, device, dtype) {
+        this.stride = stride;
+        this.padding = padding;
+        this.dilation = dilation;
+        this.groups = groups;
+        const fanIn = (inChannels / groups) * kernelSize * kernelSize;
+        const bound = Math.sqrt(1 / fanIn);
+        this.weight = core_1.Tensor.uniform([outChannels, inChannels / groups, kernelSize, kernelSize], -bound, bound, { requiresGrad: true, device, dtype });
+        if (bias) {
+            this.bias = core_1.Tensor.uniform([outChannels], -bound, bound, { requiresGrad: true, device, dtype });
+        }
+    }
+    forward(input) {
+        return input.conv2d(this.weight, this.bias, this.stride, this.padding, this.dilation, this.groups);
+    }
+}
+exports.Conv2d = Conv2d;
 class BatchNorm {
     weight;
     bias;
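The initialization draws weight and bias from U(-bound, bound) with bound = sqrt(1 / fanIn) and fanIn = (inChannels / groups) * kernelSize^2, the same bound PyTorch's Conv2d defaults work out to. A quick worked check of the arithmetic:

    // inChannels = 3, kernelSize = 3, groups = 1
    const fanIn = (3 / 1) * 3 * 3;      // 27
    const bound = Math.sqrt(1 / fanIn); // ~0.192
    // weight [outChannels, 3, 3, 3] and bias ~ U(-0.192, 0.192)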
@@ -436,6 +460,7 @@ exports.nn = {
     RNNCell,
     GRUCell,
     LSTMCell,
+    Conv2d,
     BatchNorm,
     InstanceNorm,
     GroupNorm,
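End to end, the new module is driven like its PyTorch namesake. A minimal sketch, again assuming the package root re-exports nn and Tensor:

    import { nn, Tensor } from "catniff"; // assumed entry point

    // Grouped conv: 8 in, 8 out, 3x3 kernel, stride 1, padding 1, dilation 1, groups 4
    const conv = new nn.Conv2d(8, 8, 3, 1, 1, 1, 4);
    const x = Tensor.uniform([2, 8, 16, 16], -1, 1, { requiresGrad: false });
    const y = conv.forward(x);
    console.log(y.shape); // [2, 8, 16, 16]; weight is [8, 2, 3, 3] since Cin/groups = 2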