catniff 0.5.4 → 0.5.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +1 -0
- package/dist/core.js +30 -3
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
|
@@ -125,6 +125,7 @@ export declare class Tensor {
|
|
|
125
125
|
softsign(): Tensor;
|
|
126
126
|
silu(): Tensor;
|
|
127
127
|
mish(): Tensor;
|
|
128
|
+
/**
 * Element-wise GELU.
 * @param approximate "none" (the default) for the exact erf-based form, or "tanh" for the tanh approximation.
 * @throws Error if `approximate` is any other value.
 */
gelu(approximate?: string): Tensor;
|
|
128
129
|
maximum(other: TensorValue | Tensor): Tensor;
|
|
129
130
|
minimum(other: TensorValue | Tensor): Tensor;
|
|
130
131
|
round(): Tensor;
|
package/dist/core.js
CHANGED
|
@@ -1000,6 +1000,34 @@ class Tensor {
|
|
|
1000
1000
|
return outGrad.mul(derivative);
|
|
1001
1001
|
});
|
|
1002
1002
|
}
|
|
1003
|
+
// Tensor element-wise gelu
|
|
1004
|
+
gelu(approximate = "none") {
|
|
1005
|
+
if (approximate === "none") {
|
|
1006
|
+
return this.elementWiseSelfDAG((a) => 0.5 * a * (1 + (0, utils_1.erf)(a / Math.sqrt(2))), (self, outGrad) => {
|
|
1007
|
+
const sqrt2 = Math.sqrt(2);
|
|
1008
|
+
const sqrt2OverPi = Math.sqrt(2 / Math.PI);
|
|
1009
|
+
const xOverSqrt2 = self.div(sqrt2);
|
|
1010
|
+
const erfVal = xOverSqrt2.erf();
|
|
1011
|
+
const phi = xOverSqrt2.square().neg().exp().div(sqrt2OverPi);
|
|
1012
|
+
const derivative = erfVal.add(1).mul(0.5).add(self.mul(phi));
|
|
1013
|
+
return outGrad.mul(derivative);
|
|
1014
|
+
});
|
|
1015
|
+
}
|
|
1016
|
+
else if (approximate === "tanh") {
|
|
1017
|
+
return this.elementWiseSelfDAG((a) => 0.5 * a * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (a + 0.044715 * a * a * a))), (self, outGrad) => {
|
|
1018
|
+
const sqrt2OverPi = Math.sqrt(2 / Math.PI);
|
|
1019
|
+
const c = 0.044715;
|
|
1020
|
+
const tanhArg = self.add(self.pow(3).mul(c)).mul(sqrt2OverPi);
|
|
1021
|
+
const tanhVal = tanhArg.tanh();
|
|
1022
|
+
const sechSquared = tanhVal.square().neg().add(1);
|
|
1023
|
+
const term1 = tanhVal.add(1).mul(0.5);
|
|
1024
|
+
const term2 = self.mul(sechSquared).mul(sqrt2OverPi).mul(self.square().mul(c * 3).add(1)).mul(0.5);
|
|
1025
|
+
const derivative = term1.add(term2);
|
|
1026
|
+
return outGrad.mul(derivative);
|
|
1027
|
+
});
|
|
1028
|
+
}
|
|
1029
|
+
throw new Error("Specified approximation does not exist");
|
|
1030
|
+
}
|
|
1003
1031
|
// Tensor element-wise maximum
|
|
1004
1032
|
maximum(other) {
|
|
1005
1033
|
return this.elementWiseABDAG(other, (a, b) => Math.max(a, b), (self, other, outGrad) => outGrad.mul(self.gt(other).add(self.eq(other).mul(0.5))), (self, other, outGrad) => outGrad.mul(other.gt(self).add(other.eq(self).mul(0.5))));
|
|
@@ -1260,9 +1288,8 @@ class Tensor {
|
|
|
1260
1288
|
else if (this.shape.length === 2 && other.shape.length === 2) {
|
|
1261
1289
|
return this.mm(other);
|
|
1262
1290
|
}
|
|
1263
|
-
else if ((
|
|
1264
|
-
(
|
|
1265
|
-
(other.shape.length > 2 && this.shape.length > 2)) {
|
|
1291
|
+
else if ((this.shape.length > 0 && other.shape.length >= 2) ||
|
|
1292
|
+
(this.shape.length >= 2 && other.shape.length > 0)) {
|
|
1266
1293
|
// Append/prepend dims if needed
|
|
1267
1294
|
const self = isThis1D ? this.unsqueeze(0) : this;
|
|
1268
1295
|
other = isOther1D ? other.unsqueeze(1) : other;
|