catniff 0.5.3 → 0.5.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +6 -1
- package/dist/core.js +78 -11
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
|
@@ -18,6 +18,7 @@ export declare class Tensor {
|
|
|
18
18
|
gradFn: Function;
|
|
19
19
|
children: Tensor[];
|
|
20
20
|
device: string;
|
|
21
|
+
static training: boolean;
|
|
21
22
|
constructor(value: TensorValue, options?: TensorOptions);
|
|
22
23
|
static flatten(tensor: TensorValue): number[] | number;
|
|
23
24
|
static getShape(tensor: TensorValue): readonly number[];
|
|
@@ -124,6 +125,7 @@ export declare class Tensor {
|
|
|
124
125
|
softsign(): Tensor;
|
|
125
126
|
silu(): Tensor;
|
|
126
127
|
mish(): Tensor;
|
|
128
|
+
gelu(approximate?: string): Tensor;
|
|
127
129
|
maximum(other: TensorValue | Tensor): Tensor;
|
|
128
130
|
minimum(other: TensorValue | Tensor): Tensor;
|
|
129
131
|
round(): Tensor;
|
|
@@ -146,6 +148,7 @@ export declare class Tensor {
|
|
|
146
148
|
bmm(other: TensorValue | Tensor): Tensor;
|
|
147
149
|
mv(other: TensorValue | Tensor): Tensor;
|
|
148
150
|
matmul(other: TensorValue | Tensor): Tensor;
|
|
151
|
+
dropout(rate: number): Tensor;
|
|
149
152
|
static full(shape: number[], num: number, options?: TensorOptions): Tensor;
|
|
150
153
|
static fullLike(tensor: Tensor, num: number, options?: TensorOptions): Tensor;
|
|
151
154
|
static ones(shape?: number[], options?: TensorOptions): Tensor;
|
|
@@ -160,7 +163,9 @@ export declare class Tensor {
|
|
|
160
163
|
static randintLike(tensor: Tensor, low: number, high: number, options?: TensorOptions): Tensor;
|
|
161
164
|
static normal(shape: number[], mean: number, stdDev: number, options?: TensorOptions): Tensor;
|
|
162
165
|
static uniform(shape: number[], low: number, high: number, options?: TensorOptions): Tensor;
|
|
163
|
-
backward(
|
|
166
|
+
backward(options?: {
|
|
167
|
+
zeroGrad?: boolean;
|
|
168
|
+
}): void;
|
|
164
169
|
val(): TensorValue;
|
|
165
170
|
withGrad(requiresGrad: boolean): Tensor;
|
|
166
171
|
detach(): Tensor;
|
package/dist/core.js
CHANGED
|
@@ -11,6 +11,7 @@ class Tensor {
|
|
|
11
11
|
gradFn;
|
|
12
12
|
children;
|
|
13
13
|
device;
|
|
14
|
+
static training = false;
|
|
14
15
|
constructor(value, options = {}) {
|
|
15
16
|
this.value = Tensor.flatten(value);
|
|
16
17
|
this.shape = options.shape || Tensor.getShape(value);
|
|
@@ -999,6 +1000,34 @@ class Tensor {
|
|
|
999
1000
|
return outGrad.mul(derivative);
|
|
1000
1001
|
});
|
|
1001
1002
|
}
|
|
1003
|
+
// Tensor element-wise gelu
|
|
1004
|
+
gelu(approximate = "none") {
|
|
1005
|
+
if (approximate === "none") {
|
|
1006
|
+
return this.elementWiseSelfDAG((a) => 0.5 * a * (1 + (0, utils_1.erf)(a / Math.sqrt(2))), (self, outGrad) => {
|
|
1007
|
+
const sqrt2 = Math.sqrt(2);
|
|
1008
|
+
const sqrt2OverPi = Math.sqrt(2 / Math.PI);
|
|
1009
|
+
const xOverSqrt2 = self.div(sqrt2);
|
|
1010
|
+
const erfVal = xOverSqrt2.erf();
|
|
1011
|
+
const phi = xOverSqrt2.square().neg().exp().div(sqrt2OverPi);
|
|
1012
|
+
const derivative = erfVal.add(1).mul(0.5).add(self.mul(phi));
|
|
1013
|
+
return outGrad.mul(derivative);
|
|
1014
|
+
});
|
|
1015
|
+
}
|
|
1016
|
+
else if (approximate === "tanh") {
|
|
1017
|
+
return this.elementWiseSelfDAG((a) => 0.5 * a * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (a + 0.044715 * a * a * a))), (self, outGrad) => {
|
|
1018
|
+
const sqrt2OverPi = Math.sqrt(2 / Math.PI);
|
|
1019
|
+
const c = 0.044715;
|
|
1020
|
+
const tanhArg = self.add(self.pow(3).mul(c)).mul(sqrt2OverPi);
|
|
1021
|
+
const tanhVal = tanhArg.tanh();
|
|
1022
|
+
const sechSquared = tanhVal.square().neg().add(1);
|
|
1023
|
+
const term1 = tanhVal.add(1).mul(0.5);
|
|
1024
|
+
const term2 = self.mul(sechSquared).mul(sqrt2OverPi).mul(self.square().mul(c * 3).add(1)).mul(0.5);
|
|
1025
|
+
const derivative = term1.add(term2);
|
|
1026
|
+
return outGrad.mul(derivative);
|
|
1027
|
+
});
|
|
1028
|
+
}
|
|
1029
|
+
throw new Error("Specified approximation does not exist");
|
|
1030
|
+
}
|
|
1002
1031
|
// Tensor element-wise maximum
|
|
1003
1032
|
maximum(other) {
|
|
1004
1033
|
return this.elementWiseABDAG(other, (a, b) => Math.max(a, b), (self, other, outGrad) => outGrad.mul(self.gt(other).add(self.eq(other).mul(0.5))), (self, other, outGrad) => outGrad.mul(other.gt(self).add(other.eq(self).mul(0.5))));
|
|
@@ -1259,9 +1288,8 @@ class Tensor {
|
|
|
1259
1288
|
else if (this.shape.length === 2 && other.shape.length === 2) {
|
|
1260
1289
|
return this.mm(other);
|
|
1261
1290
|
}
|
|
1262
|
-
else if ((
|
|
1263
|
-
(
|
|
1264
|
-
(other.shape.length > 2 && this.shape.length > 2)) {
|
|
1291
|
+
else if ((this.shape.length > 0 && other.shape.length >= 2) ||
|
|
1292
|
+
(this.shape.length >= 2 && other.shape.length > 0)) {
|
|
1265
1293
|
// Append/prepend dims if needed
|
|
1266
1294
|
const self = isThis1D ? this.unsqueeze(0) : this;
|
|
1267
1295
|
other = isOther1D ? other.unsqueeze(1) : other;
|
|
@@ -1335,6 +1363,15 @@ class Tensor {
|
|
|
1335
1363
|
}
|
|
1336
1364
|
throw new Error(`Shapes [${this.shape}] and [${other.shape}] are not supported`);
|
|
1337
1365
|
}
|
|
1366
|
+
// Dropout
|
|
1367
|
+
dropout(rate) {
|
|
1368
|
+
if (!Tensor.training || rate === 0)
|
|
1369
|
+
return this;
|
|
1370
|
+
const keepRate = 1 - rate;
|
|
1371
|
+
const uniform = Tensor.randLike(this);
|
|
1372
|
+
const mask = uniform.lt(keepRate);
|
|
1373
|
+
return this.mul(mask).div(keepRate);
|
|
1374
|
+
}
|
|
1338
1375
|
// Utility to create a new tensor filled with a number
|
|
1339
1376
|
static full(shape, num, options = {}) {
|
|
1340
1377
|
if (shape.length === 0)
|
|
@@ -1347,7 +1384,12 @@ class Tensor {
|
|
|
1347
1384
|
static fullLike(tensor, num, options = {}) {
|
|
1348
1385
|
if (typeof tensor.value === "number")
|
|
1349
1386
|
return new Tensor(num, options);
|
|
1350
|
-
return new Tensor(new Array(tensor.value.length).fill(num), {
|
|
1387
|
+
return new Tensor(new Array(tensor.value.length).fill(num), {
|
|
1388
|
+
shape: tensor.shape,
|
|
1389
|
+
strides: tensor.strides,
|
|
1390
|
+
device: tensor.device,
|
|
1391
|
+
...options
|
|
1392
|
+
});
|
|
1351
1393
|
}
|
|
1352
1394
|
// Utility to create a new tensor filled with 1
|
|
1353
1395
|
static ones(shape, options = {}) {
|
|
@@ -1361,7 +1403,12 @@ class Tensor {
|
|
|
1361
1403
|
static onesLike(tensor, options = {}) {
|
|
1362
1404
|
if (typeof tensor.value === "number")
|
|
1363
1405
|
return new Tensor(1, options);
|
|
1364
|
-
return new Tensor(new Array(tensor.value.length).fill(1), {
|
|
1406
|
+
return new Tensor(new Array(tensor.value.length).fill(1), {
|
|
1407
|
+
shape: tensor.shape,
|
|
1408
|
+
strides: tensor.strides,
|
|
1409
|
+
device: tensor.device,
|
|
1410
|
+
...options
|
|
1411
|
+
});
|
|
1365
1412
|
}
|
|
1366
1413
|
// Utility to create a new tensor filled with 0
|
|
1367
1414
|
static zeros(shape, options = {}) {
|
|
@@ -1375,7 +1422,12 @@ class Tensor {
|
|
|
1375
1422
|
static zerosLike(tensor, options = {}) {
|
|
1376
1423
|
if (typeof tensor.value === "number")
|
|
1377
1424
|
return new Tensor(0, options);
|
|
1378
|
-
return new Tensor(new Array(tensor.value.length).fill(0), {
|
|
1425
|
+
return new Tensor(new Array(tensor.value.length).fill(0), {
|
|
1426
|
+
shape: tensor.shape,
|
|
1427
|
+
strides: tensor.strides,
|
|
1428
|
+
device: tensor.device,
|
|
1429
|
+
...options
|
|
1430
|
+
});
|
|
1379
1431
|
}
|
|
1380
1432
|
// Utility to create a new tensor filled with a random number with uniform distribution from 0 to 1
|
|
1381
1433
|
static rand(shape, options = {}) {
|
|
@@ -1397,7 +1449,10 @@ class Tensor {
|
|
|
1397
1449
|
outputValue[index] = (0, utils_1.randUniform)();
|
|
1398
1450
|
}
|
|
1399
1451
|
return new Tensor(outputValue, {
|
|
1400
|
-
shape: tensor.shape,
|
|
1452
|
+
shape: tensor.shape,
|
|
1453
|
+
strides: tensor.strides,
|
|
1454
|
+
device: tensor.device,
|
|
1455
|
+
...options
|
|
1401
1456
|
});
|
|
1402
1457
|
}
|
|
1403
1458
|
// Utility to create a new tensor filled with a random number with normal distribution of mean=0 and stddev=1
|
|
@@ -1420,7 +1475,10 @@ class Tensor {
|
|
|
1420
1475
|
outputValue[index] = (0, utils_1.randNormal)();
|
|
1421
1476
|
}
|
|
1422
1477
|
return new Tensor(outputValue, {
|
|
1423
|
-
shape: tensor.shape,
|
|
1478
|
+
shape: tensor.shape,
|
|
1479
|
+
strides: tensor.strides,
|
|
1480
|
+
device: tensor.device,
|
|
1481
|
+
...options
|
|
1424
1482
|
});
|
|
1425
1483
|
}
|
|
1426
1484
|
// Utility to create a new tensor filled with a random integer between low and high
|
|
@@ -1443,7 +1501,10 @@ class Tensor {
|
|
|
1443
1501
|
outputValue[index] = (0, utils_1.randInt)(low, high);
|
|
1444
1502
|
}
|
|
1445
1503
|
return new Tensor(outputValue, {
|
|
1446
|
-
shape: tensor.shape,
|
|
1504
|
+
shape: tensor.shape,
|
|
1505
|
+
strides: tensor.strides,
|
|
1506
|
+
device: tensor.device,
|
|
1507
|
+
...options
|
|
1447
1508
|
});
|
|
1448
1509
|
}
|
|
1449
1510
|
// Utility to create a new tensor filled with a random number with normal distribution of custom mean and stddev
|
|
@@ -1469,14 +1530,20 @@ class Tensor {
|
|
|
1469
1530
|
return new Tensor(outputValue, { shape, ...options });
|
|
1470
1531
|
}
|
|
1471
1532
|
// Reverse-mode autodiff call
|
|
1472
|
-
backward() {
|
|
1533
|
+
backward(options = {}) {
|
|
1534
|
+
// Init
|
|
1535
|
+
const zeroGrad = options.zeroGrad ?? true;
|
|
1473
1536
|
// Build topological order
|
|
1474
1537
|
const topo = [];
|
|
1475
1538
|
const visited = new Set();
|
|
1476
1539
|
function build(node) {
|
|
1540
|
+
// Only collects unvisited node and node that requires gradient
|
|
1477
1541
|
if (!visited.has(node) && node.requiresGrad) {
|
|
1478
1542
|
visited.add(node);
|
|
1479
|
-
|
|
1543
|
+
// Reset grad to zeros if specified
|
|
1544
|
+
if (zeroGrad) {
|
|
1545
|
+
node.grad = Tensor.zerosLike(node);
|
|
1546
|
+
}
|
|
1480
1547
|
for (let child of node.children)
|
|
1481
1548
|
build(child);
|
|
1482
1549
|
topo.push(node);
|