catniff 0.5.0 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.js +10 -1
- package/dist/nn.d.ts +10 -1
- package/dist/nn.js +30 -10
- package/package.json +1 -1
package/dist/core.js
CHANGED
|
@@ -20,6 +20,13 @@ class Tensor {
|
|
|
20
20
|
this.gradFn = options.gradFn || (() => { });
|
|
21
21
|
this.children = options.children || [];
|
|
22
22
|
this.device = options.device || "cpu";
|
|
23
|
+
// Move tensor to device
|
|
24
|
+
if (this.device !== "cpu") {
|
|
25
|
+
const backend = Tensor.backends.get(this.device);
|
|
26
|
+
if (backend && backend.transfer) {
|
|
27
|
+
backend.transfer(this);
|
|
28
|
+
}
|
|
29
|
+
}
|
|
23
30
|
}
|
|
24
31
|
// Utility to flatten an nD array to be 1D
|
|
25
32
|
static flatten(tensor) {
|
|
@@ -1527,6 +1534,7 @@ class Tensor {
|
|
|
1527
1534
|
return new Tensor(typeof this.value === "number" ? this.value : [...this.value], {
|
|
1528
1535
|
shape: this.shape,
|
|
1529
1536
|
strides: this.strides,
|
|
1537
|
+
device: this.device,
|
|
1530
1538
|
requiresGrad: this.requiresGrad
|
|
1531
1539
|
});
|
|
1532
1540
|
}
|
|
@@ -1549,7 +1557,8 @@ class Tensor {
|
|
|
1549
1557
|
to(device) {
|
|
1550
1558
|
const backend = Tensor.backends.get(device);
|
|
1551
1559
|
if (backend && backend.transfer) {
|
|
1552
|
-
|
|
1560
|
+
backend.transfer(this);
|
|
1561
|
+
return this;
|
|
1553
1562
|
}
|
|
1554
1563
|
throw new Error(`No device found to transfer tensor to or a handler is not implemented for device.`);
|
|
1555
1564
|
}
|
package/dist/nn.d.ts
CHANGED
|
@@ -2,11 +2,20 @@ import { Tensor, TensorValue } from "./core";
|
|
|
2
2
|
declare class Linear {
|
|
3
3
|
weight: Tensor;
|
|
4
4
|
bias?: Tensor;
|
|
5
|
-
constructor(inFeatures: number, outFeatures: number, bias?: boolean, customInit?: (shape: number[]) => Tensor);
|
|
5
|
+
constructor(inFeatures: number, outFeatures: number, bias?: boolean, device?: string);
|
|
6
6
|
forward(input: Tensor | TensorValue): Tensor;
|
|
7
7
|
}
|
|
8
|
+
declare class RNNCell {
|
|
9
|
+
weightIH: Tensor;
|
|
10
|
+
weightHH: Tensor;
|
|
11
|
+
biasIH?: Tensor;
|
|
12
|
+
biasHH?: Tensor;
|
|
13
|
+
constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
|
|
14
|
+
forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
|
|
15
|
+
}
|
|
8
16
|
export declare const nn: {
|
|
9
17
|
Linear: typeof Linear;
|
|
18
|
+
RNNCell: typeof RNNCell;
|
|
10
19
|
state: {
|
|
11
20
|
getParameters(model: any, visited?: WeakSet<object>): Tensor[];
|
|
12
21
|
};
|
package/dist/nn.js
CHANGED
|
@@ -5,17 +5,11 @@ const core_1 = require("./core");
|
|
|
5
5
|
class Linear {
|
|
6
6
|
weight;
|
|
7
7
|
bias;
|
|
8
|
-
constructor(inFeatures, outFeatures, bias = true, customInit) {
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
return core_1.Tensor.uniform(shape, -bound, bound, { requiresGrad: true });
|
|
12
|
-
};
|
|
13
|
-
if (customInit) {
|
|
14
|
-
initFunc = customInit;
|
|
15
|
-
}
|
|
16
|
-
this.weight = initFunc([outFeatures, inFeatures]);
|
|
8
|
+
constructor(inFeatures, outFeatures, bias = true, device) {
|
|
9
|
+
const bound = 1 / Math.sqrt(inFeatures);
|
|
10
|
+
this.weight = core_1.Tensor.uniform([outFeatures, inFeatures], -bound, bound, { requiresGrad: true, device });
|
|
17
11
|
if (bias) {
|
|
18
|
-
this.bias = initFunc([outFeatures]);
|
|
12
|
+
this.bias = core_1.Tensor.uniform([outFeatures], -bound, bound, { requiresGrad: true, device });
|
|
19
13
|
}
|
|
20
14
|
}
|
|
21
15
|
forward(input) {
|
|
@@ -27,6 +21,31 @@ class Linear {
|
|
|
27
21
|
return output;
|
|
28
22
|
}
|
|
29
23
|
}
|
|
24
|
+
class RNNCell {
|
|
25
|
+
weightIH;
|
|
26
|
+
weightHH;
|
|
27
|
+
biasIH;
|
|
28
|
+
biasHH;
|
|
29
|
+
constructor(inputSize, hiddenSize, bias = true, device) {
|
|
30
|
+
const bound = 1 / Math.sqrt(hiddenSize);
|
|
31
|
+
this.weightIH = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
|
|
32
|
+
this.weightHH = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
|
|
33
|
+
if (bias) {
|
|
34
|
+
this.biasIH = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
|
|
35
|
+
this.biasHH = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
|
|
36
|
+
}
|
|
37
|
+
}
|
|
38
|
+
forward(input, hidden) {
|
|
39
|
+
input = core_1.Tensor.forceTensor(input);
|
|
40
|
+
hidden = core_1.Tensor.forceTensor(hidden);
|
|
41
|
+
let output = input.matmul(this.weightIH.t())
|
|
42
|
+
.add(hidden.matmul(this.weightHH.t()));
|
|
43
|
+
if (this.biasIH && this.biasHH) {
|
|
44
|
+
output = output.add(this.biasIH).add(this.biasHH);
|
|
45
|
+
}
|
|
46
|
+
return output.tanh();
|
|
47
|
+
}
|
|
48
|
+
}
|
|
30
49
|
const state = {
|
|
31
50
|
getParameters(model, visited = new WeakSet()) {
|
|
32
51
|
if (visited.has(model)) {
|
|
@@ -50,5 +69,6 @@ const state = {
|
|
|
50
69
|
};
|
|
51
70
|
exports.nn = {
|
|
52
71
|
Linear,
|
|
72
|
+
RNNCell,
|
|
53
73
|
state
|
|
54
74
|
};
|