catniff 0.8.22 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +2 -0
- package/dist/core.js +141 -3
- package/dist/nn.d.ts +11 -0
- package/dist/nn.js +26 -1
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
@@ -70,6 +70,7 @@ export declare class Tensor {
     chunk(chunks: number, dim?: number): Tensor[];
     expand(newShape: number[]): Tensor;
     unfold(dim: number, size: number, step: number): Tensor;
+    pad(pad: number[], mode?: string, value?: number): Tensor;
     cat(other: Tensor | TensorValue, dim?: number): Tensor;
     stack(others: (Tensor | TensorValue)[], dim?: number): Tensor;
     squeeze(dims?: number[] | number): Tensor;
@@ -218,6 +219,7 @@ export declare class Tensor {
     mv(other: TensorValue | Tensor): Tensor;
     matmul(other: TensorValue | Tensor): Tensor;
     tensordot(other: TensorValue | Tensor, axes?: number | [number, number] | [number[], number[]]): Tensor;
+    conv2d(weight: Tensor | TensorValue, bias?: Tensor | TensorValue, stride?: number | [number, number], padding?: number | [number, number], dilation?: number | [number, number], groups?: number): Tensor;
     dropout(rate: number): Tensor;
     triu(diagonal?: number): Tensor;
     tril(diagonal?: number): Tensor;
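Both additions mirror the corresponding PyTorch Tensor methods. A minimal usage sketch of the new declarations (the "catniff" import path is an assumption, not taken from this diff; Tensor.uniform and the value/options constructor appear elsewhere in the package):

// Sketch only: assumes Tensor accepts nested arrays and that pad()
// defaults to "constant" mode, per the implementation further down.
import { Tensor } from "catniff";

const x = new Tensor([[1, 2], [3, 4]]);           // shape [2, 2]
const xp = x.pad([1, 1, 1, 1]);                   // zero-pad both dims -> [4, 4]

const img = Tensor.uniform([1, 1, 4, 4], -1, 1);  // NCHW input
const k = Tensor.uniform([1, 1, 3, 3], -1, 1);    // [Cout, Cin/groups, kH, kW]
const y = img.conv2d(k);                          // -> [1, 1, 2, 2]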
package/dist/core.js
CHANGED
@@ -328,8 +328,14 @@ class Tensor {
         }
         const reducedGrad = accumGrad.sum(axesToReduce, true);
         const squeezedGrad = reducedGrad.squeeze(axesToSqueeze);
+        // Enforce 0-offset contiguous grads and correct dtype
         if (typeof tensor.grad === "undefined") {
-            tensor.grad = squeezedGrad.cast(tensor.dtype);
+            let grad = squeezedGrad;
+            // Handle potentially contiguous tensors with non zero offset
+            if (grad.offset !== 0) {
+                grad = grad.clone();
+            }
+            tensor.grad = grad.contiguous().cast(tensor.dtype);
         }
         else {
             tensor.grad = tensor.grad.add(squeezedGrad.cast(tensor.dtype));
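The new branch guarantees that a freshly stored grad buffer never aliases another tensor's storage. A standalone sketch of the invariant being enforced (the GradLike interface is hypothetical, purely for illustration):

// Hypothetical sketch: grads must be 0-offset, contiguous, and match the
// owning tensor's dtype before they are stored.
interface GradLike {
    offset: number;
    clone(): GradLike;       // copies into a fresh buffer with offset 0
    contiguous(): GradLike;  // repacks values into row-major order
    cast(dtype: string): GradLike;
}

function normalizeGrad(g: GradLike, dtype: string): GradLike {
    // A view with a non-zero offset shares storage with its source, so
    // accumulating into it later could corrupt the source; clone defensively.
    if (g.offset !== 0) g = g.clone();
    return g.contiguous().cast(dtype);
}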
@@ -795,12 +801,12 @@ class Tensor {
         const outGrad = out.grad;
         const grad = Tensor.zerosLike(this);
         for (let i = 0; i < out.numel; i++) {
-            const coords = Tensor.indexToCoords(i, outGrad.strides);
+            const coords = Tensor.indexToCoords(i, Tensor.getStrides(outGrad.shape));
             const windowIdx = coords[dim];
             const withinWindow = coords[coords.length - 1];
             coords[dim] = windowIdx * step + withinWindow;
             coords.pop();
-            const sourceIdx = Tensor.coordsToIndex(coords, grad.strides);
+            const sourceIdx = Tensor.coordsToIndex(coords, Tensor.getStrides(grad.shape));
             grad.value[sourceIdx] += outGrad.value[i];
         }
         Tensor.addGrad(this, grad);
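The unfold backward previously decoded flat indices with the tensors' stored (physical) strides; the fix derives contiguous row-major strides from the shape, so the coordinates are logical regardless of how the underlying buffer is laid out. A small sketch of that stride arithmetic (my own illustration, assumed to match Tensor.getStrides):

// Row-major strides: stride[i] is the product of all dimensions after i.
function getStrides(shape: number[]): number[] {
    const strides = new Array(shape.length).fill(1);
    for (let i = shape.length - 2; i >= 0; i--) {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    return strides;
}

// shape [2, 3] -> strides [3, 1], so flat index 4 decodes to coords [1, 1].
// A transposed view of the same buffer carries strides [1, 3]; decoding a
// logical flat index with those physical strides yields wrong coordinates,
// which is the bug the two removed lines exhibited.
console.log(getStrides([2, 3])); // [3, 1]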
@@ -808,6 +814,70 @@ class Tensor {
         }
         return out;
     }
+    // Tensor padding
+    pad(pad, mode = "constant", value = 0) {
+        const original = this.clone().contiguous(); // This is needed for index padding to work
+        const outputShape = [...original.shape];
+        const paddingPerDim = [];
+        for (let i = 0; i < original.shape.length; i++) {
+            const left = pad[(original.shape.length - 1 - i) * 2] || 0;
+            const right = pad[(original.shape.length - 1 - i) * 2 + 1] || 0;
+            paddingPerDim[i] = { left, right };
+            outputShape[i] += left + right;
+        }
+        const outputSize = Tensor.shapeToSize(outputShape);
+        if (mode === "constant") {
+            const outputValue = new dtype_1.TypedArray[original.dtype](outputSize).fill(value);
+            const outputStrides = Tensor.getStrides(outputShape);
+            for (let index = 0; index < original.numel; index++) {
+                const coords = Tensor.indexToCoords(index, original.strides);
+                let paddedIndex = 0;
+                // Pad each coord
+                for (let j = 0; j < original.shape.length; j++) {
+                    const shiftedCoord = coords[j] + paddingPerDim[j].left;
+                    paddedIndex += shiftedCoord * outputStrides[j];
+                }
+                outputValue[paddedIndex] = original.value[index];
+            }
+            const out = new Tensor(outputValue, {
+                shape: outputShape,
+                strides: outputStrides,
+                offset: 0,
+                dtype: original.dtype,
+                device: original.device
+            });
+            if (original.requiresGrad) {
+                out.requiresGrad = true;
+                out.children.push(original);
+                out.gradFn = () => {
+                    const outGrad = out.grad;
+                    const gradValue = new dtype_1.TypedArray[original.dtype](original.numel);
+                    const gradStrides = Tensor.getStrides(original.shape);
+                    for (let index = 0; index < gradValue.length; index++) {
+                        const coords = Tensor.indexToCoords(index, gradStrides);
+                        let paddedIndex = 0;
+                        // Pad each coord
+                        for (let j = 0; j < original.shape.length; j++) {
+                            const shiftedCoord = coords[j] + paddingPerDim[j].left;
+                            paddedIndex += shiftedCoord * outputStrides[j];
+                        }
+                        gradValue[index] = outGrad.value[paddedIndex];
+                    }
+                    Tensor.addGrad(original, new Tensor(gradValue, {
+                        shape: original.shape,
+                        strides: gradStrides,
+                        offset: 0,
+                        dtype: original.dtype,
+                        device: original.device
+                    }));
+                };
+            }
+            return out;
+        }
+        else {
+            throw new Error(`Padding mode not supported: "${mode}"`);
+        }
+    }
     // Tensor concatentation
     cat(other, dim = 0) {
         other = this.handleOther(other);
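As in PyTorch's F.pad, the pad array is read from the last dimension backwards: [lastLeft, lastRight, secondLastLeft, secondLastRight, ...], which is what the (shape.length - 1 - i) * 2 indexing implements. A quick sketch (import path assumed):

import { Tensor } from "catniff";

const t = new Tensor([[1, 2], [3, 4]]);  // shape [2, 2]
const p = t.pad([1, 1], "constant", 9);  // only the last dim is padded
// p has shape [2, 4] with values:
// [[9, 1, 2, 9],
//  [9, 3, 4, 9]]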
@@ -2113,6 +2183,74 @@ class Tensor {
         ];
         return result2D.reshape(finalShape);
     }
+    // 2D convolution
+    conv2d(weight, bias, stride = 1, padding = 0, dilation = 1, groups = 1) {
+        weight = this.handleOther(weight);
+        const [sH, sW] = Array.isArray(stride) ? stride : [stride, stride];
+        const [pH, pW] = Array.isArray(padding) ? padding : [padding, padding];
+        const [dH, dW] = Array.isArray(dilation) ? dilation : [dilation, dilation];
+        const [N, Cin, H, W] = this.shape;
+        const [Cout, CinPerGroup, kH, kW] = weight.shape;
+        // Pad input
+        let x = (pH > 0 || pW > 0) ? this.pad([pW, pW, pH, pH]) : this;
+        const Hp = H + 2 * pH;
+        const Wp = W + 2 * pW;
+        const Hout = Math.floor((Hp - dH * (kH - 1) - 1) / sH + 1);
+        const Wout = Math.floor((Wp - dW * (kW - 1) - 1) / sW + 1);
+        // Unfold H with a window large enough to cover the dilated kernel extent,
+        // then slice every dH-th position to realise the dilation holes.
+        // x: [N, Cin, Hp, Wp]
+        //   -> unfold(2, dH*(kH-1)+1, sH)
+        //   -> [N, Cin, Hout, Wp, dH*(kH-1)+1]
+        //   -> slice step dH on last dim
+        //   -> [N, Cin, Hout, Wp, kH]
+        const dilKH = dH * (kH - 1) + 1;
+        x = x.unfold(2, dilKH, sH);
+        if (dH > 1)
+            x = x.slice([[0, N], [0, Cin], [0, Hout], [0, Wp], [0, dilKH, dH]]);
+        // Unfold W
+        // x: [N, Cin, Hout, Wp, kH]
+        //   -> unfold(3, dW*(kW-1)+1, sW)
+        //   -> [N, Cin, Hout, Wout, kH, dW*(kW-1)+1]
+        //   -> slice step dW on last dim
+        //   -> [N, Cin, Hout, Wout, kH, kW]
+        const dilKW = dW * (kW - 1) + 1;
+        x = x.unfold(3, dilKW, sW);
+        if (dW > 1)
+            x = x.slice([[0, N], [0, Cin], [0, Hout], [0, Wout], [0, kH], [0, dilKW, dW]]);
+        // Reshape patches to [N, Hout*Wout, Cin*kH*kW]
+        // permute [0,2,3,1,4,5] -> [N, Hout, Wout, Cin, kH, kW]
+        // then reshape merges the spatial and channel-kernel dims.
+        // reshape() forces contiguity internally so no explicit .contiguous() needed.
+        x = x.permute([0, 2, 3, 1, 4, 5]).reshape([N, Hout * Wout, Cin * kH * kW]);
+        // Matmul with weight
+        // weight: [Cout, CinPerGroup, kH, kW] -> [Cout, CinPerGroup*kH*kW]
+        const w = weight.reshape([Cout, CinPerGroup * kH * kW]);
+        let out;
+        if (groups === 1) {
+            // x: [N, Hout*Wout, Cin*kH*kW] @ w.t(): [Cin*kH*kW, Cout]
+            // -> [N, Hout*Wout, Cout]
+            out = x.matmul(w.t());
+        }
+        else {
+            // Each group handles Cin/groups input channels and Cout/groups output channels.
+            // chunk(groups, 2) splits the Cin*kH*kW axis into groups equal slices,
+            // each of size CinPerGroup*kH*kW — valid because reshape laid Cin outermost.
+            const patchChunks = x.chunk(groups, 2); // Tensor[groups], each [N, Hout*Wout, CinPerGroup*kH*kW]
+            const weightChunks = w.chunk(groups, 0); // Tensor[groups], each [Cout/groups, CinPerGroup*kH*kW]
+            const groupOuts = patchChunks.map((patch, i) => patch.matmul(weightChunks[i].t()) // [N, Hout*Wout, Cout/groups]
+            );
+            // Cat all group outputs along the channel axis (dim 2)
+            out = groupOuts.reduce((acc, g) => acc.cat(g, 2)); // [N, Hout*Wout, Cout]
+        }
+        // Restore [N, Cout, Hout, Wout]
+        out = out.permute([0, 2, 1]).reshape([N, Cout, Hout, Wout]);
+        // Bias
+        if (bias) {
+            out = out.add(this.handleOther(bias).reshape([1, Cout, 1, 1]));
+        }
+        return out;
+    }
     // Dropout
     dropout(rate) {
         if (!Tensor.training || rate === 0)
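The Hout/Wout expressions match the standard convolution output-size rule, floor((H + 2p - d(k - 1) - 1) / s) + 1. A worked check with example numbers (mine, not from the diff):

// H = 32, kernel 3, stride 2, padding 1, dilation 2:
const H = 32, kH = 3, sH = 2, pH = 1, dH = 2;
const Hp = H + 2 * pH;                                      // 34
const Hout = Math.floor((Hp - dH * (kH - 1) - 1) / sH + 1); // floor(15.5)
console.log(Hout); // 15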
package/dist/nn.d.ts
CHANGED
@@ -55,6 +55,16 @@ export declare class LSTMCell {
     constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
     forward(input: Tensor, hidden: Tensor, cell: Tensor): [Tensor, Tensor];
 }
+export declare class Conv2d {
+    weight: Tensor;
+    bias?: Tensor;
+    stride: number | [number, number];
+    padding: number | [number, number];
+    dilation: number | [number, number];
+    groups: number;
+    constructor(inChannels: number, outChannels: number, kernelSize: number, stride?: number | [number, number], padding?: number | [number, number], dilation?: number | [number, number], groups?: number, bias?: boolean, device?: string, dtype?: dtype);
+    forward(input: Tensor): Tensor;
+}
 export declare class BatchNorm {
     weight?: Tensor;
     bias?: Tensor;
@@ -127,6 +137,7 @@ export declare const nn: {
     RNNCell: typeof RNNCell;
     GRUCell: typeof GRUCell;
     LSTMCell: typeof LSTMCell;
+    Conv2d: typeof Conv2d;
     BatchNorm: typeof BatchNorm;
     InstanceNorm: typeof InstanceNorm;
     GroupNorm: typeof GroupNorm;
package/dist/nn.js
CHANGED
@@ -1,6 +1,6 @@
 "use strict";
 Object.defineProperty(exports, "__esModule", { value: true });
-exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Sequential = exports.Linear = void 0;
+exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.Conv2d = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Sequential = exports.Linear = void 0;
 const core_1 = require("./core");
 class Linear {
     weight;
@@ -136,6 +136,30 @@ class LSTMCell {
     }
 }
 exports.LSTMCell = LSTMCell;
+class Conv2d {
+    weight;
+    bias;
+    stride;
+    padding;
+    dilation;
+    groups;
+    constructor(inChannels, outChannels, kernelSize, stride = 1, padding = 0, dilation = 1, groups = 1, bias = true, device, dtype) {
+        this.stride = stride;
+        this.padding = padding;
+        this.dilation = dilation;
+        this.groups = groups;
+        const fanIn = (inChannels / groups) * kernelSize * kernelSize;
+        const bound = Math.sqrt(1 / fanIn);
+        this.weight = core_1.Tensor.uniform([outChannels, inChannels / groups, kernelSize, kernelSize], -bound, bound, { requiresGrad: true, device, dtype });
+        if (bias) {
+            this.bias = core_1.Tensor.uniform([outChannels], -bound, bound, { requiresGrad: true, device, dtype });
+        }
+    }
+    forward(input) {
+        return input.conv2d(this.weight, this.bias, this.stride, this.padding, this.dilation, this.groups);
+    }
+}
+exports.Conv2d = Conv2d;
 class BatchNorm {
     weight;
     bias;
@@ -436,6 +460,7 @@ exports.nn = {
     RNNCell,
     GRUCell,
     LSTMCell,
+    Conv2d,
     BatchNorm,
     InstanceNorm,
     GroupNorm,
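The init bound sqrt(1 / fanIn) with fanIn = (inChannels / groups) * kernelSize^2 matches PyTorch's default Conv2d initialization. A usage sketch of the new module (the "catniff" import path is an assumption; the constructor arguments follow the nn.d.ts declaration above):

import { Tensor, nn } from "catniff";

const conv = new nn.Conv2d(3, 16, 3, 1, 1);       // inChannels, outChannels, kernel, stride, padding
const x = Tensor.uniform([2, 3, 32, 32], -1, 1);  // NCHW batch
const y = conv.forward(x);                        // [2, 16, 32, 32]: padding 1 keeps 32x32 for k = 3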