catniff 0.8.23 → 0.9.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +2 -0
- package/dist/core.js +76 -9
- package/dist/nn.d.ts +11 -0
- package/dist/nn.js +26 -1
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
@@ -202,6 +202,7 @@ export declare class Tensor {
     gelu(approximate?: string): Tensor;
     maximum(other: TensorValue | Tensor): Tensor;
     minimum(other: TensorValue | Tensor): Tensor;
+    copysign(other: TensorValue | Tensor): Tensor;
     round(): Tensor;
     floor(): Tensor;
     ceil(): Tensor;
@@ -219,6 +220,7 @@ export declare class Tensor {
     mv(other: TensorValue | Tensor): Tensor;
     matmul(other: TensorValue | Tensor): Tensor;
     tensordot(other: TensorValue | Tensor, axes?: number | [number, number] | [number[], number[]]): Tensor;
+    conv2d(weight: Tensor | TensorValue, bias?: Tensor | TensorValue, stride?: number | [number, number], padding?: number | [number, number], dilation?: number | [number, number], groups?: number): Tensor;
     dropout(rate: number): Tensor;
     triu(diagonal?: number): Tensor;
     tril(diagonal?: number): Tensor;
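Taken together, the two new declarations read naturally from user code: copysign is element-wise and broadcasting like maximum/minimum, and conv2d expects NCHW input with an [out-channels, in-channels/groups, kH, kW] weight. A usage sketch against these signatures, assuming the package entry re-exports Tensor and that the constructor accepts nested-array TensorValues:

import { Tensor } from "catniff";

const x = Tensor.uniform([1, 3, 32, 32], -1, 1, { requiresGrad: true }); // NCHW input
const k = Tensor.uniform([8, 3, 3, 3], -1, 1, { requiresGrad: true });   // [Cout, Cin, kH, kW]
const y = x.conv2d(k, undefined, 1, 1);                 // stride 1, pad 1 -> [1, 8, 32, 32]
const s = new Tensor([1, -2, 3]).copysign([-1, 1, -0]); // [-1, 2, -3]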
package/dist/core.js
CHANGED
@@ -328,14 +328,9 @@ class Tensor {
         }
         const reducedGrad = accumGrad.sum(axesToReduce, true);
         const squeezedGrad = reducedGrad.squeeze(axesToSqueeze);
-        // Enforce 0-offset contiguous grads and correct dtype
         if (typeof tensor.grad === "undefined") {
-
-
-            if (grad.offset !== 0) {
-                grad = grad.clone();
-            }
-            tensor.grad = grad.contiguous().cast(tensor.dtype);
+            // Force default grad to have same shape and dtype as original tensor
+            tensor.grad = Tensor.zerosLike(tensor).add(squeezedGrad.cast(tensor.dtype));
         }
         else {
             tensor.grad = tensor.grad.add(squeezedGrad.cast(tensor.dtype));
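The rewritten branch collapses the old clone/contiguous/offset handling into one expression: adding the incoming gradient onto a zeros tensor allocates a fresh, zero-offset, contiguous buffer, and additionally pins the stored grad to the tensor's exact shape and dtype, which is what the new comment calls out. A standalone sketch of that invariant, assuming the package entry re-exports Tensor:

import { Tensor } from "catniff";

const t = Tensor.uniform([2, 3], -1, 1, { requiresGrad: false });
// A transposed view stands in for the kind of non-contiguous grad the old
// branch had to clone() and contiguous() by hand.
const g = Tensor.uniform([3, 2], -1, 1, { requiresGrad: false }).t();
const seeded = Tensor.zerosLike(t).add(g.cast(t.dtype));
// seeded has t's shape [2, 3] and t's dtype, with offset 0 and contiguous data.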
@@ -801,12 +796,12 @@ class Tensor {
         const outGrad = out.grad;
         const grad = Tensor.zerosLike(this);
         for (let i = 0; i < out.numel; i++) {
-            const coords = Tensor.indexToCoords(i,
+            const coords = Tensor.indexToCoords(i, Tensor.getStrides(outGrad.shape));
             const windowIdx = coords[dim];
             const withinWindow = coords[coords.length - 1];
             coords[dim] = windowIdx * step + withinWindow;
             coords.pop();
-            const sourceIdx = Tensor.coordsToIndex(coords,
+            const sourceIdx = Tensor.coordsToIndex(coords, Tensor.getStrides(grad.shape));
             grad.value[sourceIdx] += outGrad.value[i];
         }
         Tensor.addGrad(this, grad);
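The fix in this hunk is that both coordinate conversions now take strides computed from the tensor actually being indexed: linear indices into out.grad are decoded with strides of outGrad.shape, and the shifted coordinates are re-encoded with strides of grad.shape (the old second arguments are cut off in this diff view). A 1-D illustration of the scatter mapping for unfold(0, size 3, step 2) on a length-7 tensor, using the same static helpers; these are internal, so treat this as illustrative only:

import { Tensor } from "catniff";

const outStrides = Tensor.getStrides([3, 3]); // 3 windows of size 3
const inStrides = Tensor.getStrides([7]);
for (let i = 0; i < 9; i++) {
    const coords = Tensor.indexToCoords(i, outStrides); // [windowIdx, withinWindow]
    const src = Tensor.coordsToIndex([coords[0] * 2 + coords[1]], inStrides);
    // Output element i reads input position src in the forward pass, so the
    // backward loop above accumulates outGrad.value[i] into grad.value[src].
}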
@@ -1815,6 +1810,10 @@ class Tensor {
     minimum(other) {
         return this.elementWiseABDAG(other, (a, b) => Math.min(a, b), (self, other, outGrad) => outGrad.mul(self.lt(other).add(self.eq(other).mul(0.5))), (self, other, outGrad) => outGrad.mul(other.lt(self).add(other.eq(self).mul(0.5))));
     }
+    // Tensor element-wise copysign
+    copysign(other) {
+        return this.elementWiseABDAG(other, (a, b) => Math.abs(a) * (Object.is(b, -0) || b < 0 ? -1 : 1), (self, other, outGrad) => outGrad.mul(self.sign().mul(other.sign())), (self, other, outGrad) => new Tensor(0));
+    }
     // Tensor element-wise round
     round() {
         return this.elementWiseSelfDAG((a) => Math.round(a));
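The forward matches IEEE-754 copysign, |a| carrying b's sign, and the Object.is(b, -0) test makes a literal negative zero count as negative where a plain b < 0 would not. For gradients, the magnitude operand receives outGrad * sign(a) * sign(b), the derivative of |a| * sign(b) wherever a is nonzero, while the sign donor gets a constant zero (the convention PyTorch also uses). The negative-zero path in plain JavaScript:

console.log(Math.abs(7) * (Object.is(-0, -0) || -0 < 0 ? -1 : 1)); // -7: negative zero flips the sign
console.log(Math.abs(7) * (Object.is(0, -0) || 0 < 0 ? -1 : 1));   // 7: positive zero does not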
@@ -2183,6 +2182,74 @@ class Tensor {
         ];
         return result2D.reshape(finalShape);
     }
+    // 2D convolution
+    conv2d(weight, bias, stride = 1, padding = 0, dilation = 1, groups = 1) {
+        weight = this.handleOther(weight);
+        const [sH, sW] = Array.isArray(stride) ? stride : [stride, stride];
+        const [pH, pW] = Array.isArray(padding) ? padding : [padding, padding];
+        const [dH, dW] = Array.isArray(dilation) ? dilation : [dilation, dilation];
+        const [N, Cin, H, W] = this.shape;
+        const [Cout, CinPerGroup, kH, kW] = weight.shape;
+        // Pad input
+        let x = (pH > 0 || pW > 0) ? this.pad([pW, pW, pH, pH]) : this;
+        const Hp = H + 2 * pH;
+        const Wp = W + 2 * pW;
+        const Hout = Math.floor((Hp - dH * (kH - 1) - 1) / sH + 1);
+        const Wout = Math.floor((Wp - dW * (kW - 1) - 1) / sW + 1);
+        // Unfold H with a window large enough to cover the dilated kernel extent,
+        // then slice every dH-th position to realise the dilation holes.
+        // x: [N, Cin, Hp, Wp]
+        //   -> unfold(2, dH*(kH-1)+1, sH)
+        //   -> [N, Cin, Hout, Wp, dH*(kH-1)+1]
+        //   -> slice step dH on last dim
+        //   -> [N, Cin, Hout, Wp, kH]
+        const dilKH = dH * (kH - 1) + 1;
+        x = x.unfold(2, dilKH, sH);
+        if (dH > 1)
+            x = x.slice([[0, N], [0, Cin], [0, Hout], [0, Wp], [0, dilKH, dH]]);
+        // Unfold W
+        // x: [N, Cin, Hout, Wp, kH]
+        //   -> unfold(3, dW*(kW-1)+1, sW)
+        //   -> [N, Cin, Hout, Wout, kH, dW*(kW-1)+1]
+        //   -> slice step dW on last dim
+        //   -> [N, Cin, Hout, Wout, kH, kW]
+        const dilKW = dW * (kW - 1) + 1;
+        x = x.unfold(3, dilKW, sW);
+        if (dW > 1)
+            x = x.slice([[0, N], [0, Cin], [0, Hout], [0, Wout], [0, kH], [0, dilKW, dW]]);
+        // Reshape patches to [N, Hout*Wout, Cin*kH*kW]
+        // permute [0,2,3,1,4,5] -> [N, Hout, Wout, Cin, kH, kW]
+        // then reshape merges the spatial and channel-kernel dims.
+        // reshape() forces contiguity internally so no explicit .contiguous() needed.
+        x = x.permute([0, 2, 3, 1, 4, 5]).reshape([N, Hout * Wout, Cin * kH * kW]);
+        // Matmul with weight
+        // weight: [Cout, CinPerGroup, kH, kW] -> [Cout, CinPerGroup*kH*kW]
+        const w = weight.reshape([Cout, CinPerGroup * kH * kW]);
+        let out;
+        if (groups === 1) {
+            // x: [N, Hout*Wout, Cin*kH*kW] @ w.t(): [Cin*kH*kW, Cout]
+            //   -> [N, Hout*Wout, Cout]
+            out = x.matmul(w.t());
+        }
+        else {
+            // Each group handles Cin/groups input channels and Cout/groups output channels.
+            // chunk(groups, 2) splits the Cin*kH*kW axis into groups equal slices,
+            // each of size CinPerGroup*kH*kW - valid because reshape laid Cin outermost.
+            const patchChunks = x.chunk(groups, 2); // Tensor[groups], each [N, Hout*Wout, CinPerGroup*kH*kW]
+            const weightChunks = w.chunk(groups, 0); // Tensor[groups], each [Cout/groups, CinPerGroup*kH*kW]
+            const groupOuts = patchChunks.map((patch, i) => patch.matmul(weightChunks[i].t()) // [N, Hout*Wout, Cout/groups]
+            );
+            // Cat all group outputs along the channel axis (dim 2)
+            out = groupOuts.reduce((acc, g) => acc.cat(g, 2)); // [N, Hout*Wout, Cout]
+        }
+        // Restore [N, Cout, Hout, Wout]
+        out = out.permute([0, 2, 1]).reshape([N, Cout, Hout, Wout]);
+        // Bias
+        if (bias) {
+            out = out.add(this.handleOther(bias).reshape([1, Cout, 1, 1]));
+        }
+        return out;
+    }
     // Dropout
     dropout(rate) {
         if (!Tensor.training || rate === 0)
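As a shape check on the unfold-slice-matmul pipeline, here is the grouped path exercised as a depthwise convolution (groups === Cin, so CinPerGroup === 1); a sketch assuming the package entry re-exports Tensor:

import { Tensor } from "catniff";

const x = Tensor.uniform([1, 4, 8, 8], -1, 1, { requiresGrad: true }); // [N, Cin, H, W]
const w = Tensor.uniform([4, 1, 3, 3], -1, 1, { requiresGrad: true }); // [Cout, Cin/groups, kH, kW]
const y = x.conv2d(w, undefined, 1, 1, 1, 4);
// Hout = floor((8 + 2*1 - 1*(3-1) - 1) / 1 + 1) = 8, likewise Wout = 8,
// so y is [1, 4, 8, 8]. Internally each of the 4 groups matmuls a
// [1, 64, 9] patch chunk against its [9, 1] transposed weight chunk, and the
// results are concatenated back along the channel axis.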
package/dist/nn.d.ts
CHANGED
@@ -55,6 +55,16 @@ export declare class LSTMCell {
     constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
     forward(input: Tensor, hidden: Tensor, cell: Tensor): [Tensor, Tensor];
 }
+export declare class Conv2d {
+    weight: Tensor;
+    bias?: Tensor;
+    stride: number | [number, number];
+    padding: number | [number, number];
+    dilation: number | [number, number];
+    groups: number;
+    constructor(inChannels: number, outChannels: number, kernelSize: number, stride?: number | [number, number], padding?: number | [number, number], dilation?: number | [number, number], groups?: number, bias?: boolean, device?: string, dtype?: dtype);
+    forward(input: Tensor): Tensor;
+}
 export declare class BatchNorm {
     weight?: Tensor;
     bias?: Tensor;
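The module form mirrors the functional signature, storing the hyper-parameters on the instance and reusing them on every forward. Typed usage against this declaration, assuming nn and Tensor are re-exported from the package entry:

import { nn, Tensor } from "catniff";

const conv = new nn.Conv2d(3, 16, 3, 2, 1); // inChannels, outChannels, kernelSize, stride, padding
const img = Tensor.uniform([1, 3, 64, 64], -1, 1, { requiresGrad: false });
const out: Tensor = conv.forward(img);
// A 3x3 kernel at stride 2 with padding 1 halves each spatial dim:
// floor((64 + 2 - 2 - 1) / 2 + 1) = 32, so out is [1, 16, 32, 32].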
@@ -127,6 +137,7 @@ export declare const nn: {
     RNNCell: typeof RNNCell;
     GRUCell: typeof GRUCell;
     LSTMCell: typeof LSTMCell;
+    Conv2d: typeof Conv2d;
     BatchNorm: typeof BatchNorm;
     InstanceNorm: typeof InstanceNorm;
     GroupNorm: typeof GroupNorm;
package/dist/nn.js
CHANGED
@@ -1,6 +1,6 @@
 "use strict";
 Object.defineProperty(exports, "__esModule", { value: true });
-exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Sequential = exports.Linear = void 0;
+exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.GroupNorm = exports.InstanceNorm = exports.BatchNorm = exports.Conv2d = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Sequential = exports.Linear = void 0;
 const core_1 = require("./core");
 class Linear {
     weight;
@@ -136,6 +136,30 @@ class LSTMCell {
     }
 }
 exports.LSTMCell = LSTMCell;
+class Conv2d {
+    weight;
+    bias;
+    stride;
+    padding;
+    dilation;
+    groups;
+    constructor(inChannels, outChannels, kernelSize, stride = 1, padding = 0, dilation = 1, groups = 1, bias = true, device, dtype) {
+        this.stride = stride;
+        this.padding = padding;
+        this.dilation = dilation;
+        this.groups = groups;
+        const fanIn = (inChannels / groups) * kernelSize * kernelSize;
+        const bound = Math.sqrt(1 / fanIn);
+        this.weight = core_1.Tensor.uniform([outChannels, inChannels / groups, kernelSize, kernelSize], -bound, bound, { requiresGrad: true, device, dtype });
+        if (bias) {
+            this.bias = core_1.Tensor.uniform([outChannels], -bound, bound, { requiresGrad: true, device, dtype });
+        }
+    }
+    forward(input) {
+        return input.conv2d(this.weight, this.bias, this.stride, this.padding, this.dilation, this.groups);
+    }
+}
+exports.Conv2d = Conv2d;
 class BatchNorm {
     weight;
     bias;
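Initialization follows the uniform fan-in rule, U(-b, b) with b = 1/sqrt(fanIn) and fanIn = (inChannels/groups) * kernelSize^2, applied to the bias as well; this matches torch.nn.Conv2d's default (Kaiming uniform with a = sqrt(5) reduces to the same bound). A sketch against the dist build; the deep import paths into dist are an assumption:

import { Conv2d } from "catniff/dist/nn";
import { Tensor } from "catniff/dist/core";

// Depthwise layer: groups === inChannels, so fanIn = 1 * 3 * 3 = 9 and b = 1/3.
const dw = new Conv2d(32, 32, 3, 1, 1, 1, 32);
console.log(dw.weight.shape); // [32, 1, 3, 3]
const y = dw.forward(Tensor.uniform([2, 32, 16, 16], -1, 1, { requiresGrad: false }));
// y keeps the spatial dims: [2, 32, 16, 16]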
@@ -436,6 +460,7 @@ exports.nn = {
     RNNCell,
     GRUCell,
     LSTMCell,
+    Conv2d,
     BatchNorm,
     InstanceNorm,
     GroupNorm,