catniff 0.5.4 → 0.5.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.d.ts CHANGED
@@ -125,6 +125,7 @@ export declare class Tensor {
     softsign(): Tensor;
     silu(): Tensor;
     mish(): Tensor;
+    gelu(approximate?: string): Tensor;
     maximum(other: TensorValue | Tensor): Tensor;
     minimum(other: TensorValue | Tensor): Tensor;
     round(): Tensor;
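
The new declaration exposes a GELU activation on Tensor with an optional approximation mode. A minimal usage sketch, assuming Tensor is re-exported from the package entry point and can be constructed from a plain array (neither detail is shown in this diff):

    import { Tensor } from "catniff";

    // Hypothetical construction; the Tensor constructor is not part of this diff.
    const x = new Tensor([-2, -1, 0, 1, 2]);

    const exact = x.gelu();        // erf-based GELU (default, approximate = "none")
    const approx = x.gelu("tanh"); // tanh approximation
    // Any other string throws: "Specified approximation does not exist"
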
package/dist/core.js CHANGED
@@ -1000,6 +1000,34 @@ class Tensor {
             return outGrad.mul(derivative);
         });
     }
+    // Tensor element-wise gelu
+    gelu(approximate = "none") {
+        if (approximate === "none") {
+            return this.elementWiseSelfDAG((a) => 0.5 * a * (1 + (0, utils_1.erf)(a / Math.sqrt(2))), (self, outGrad) => {
+                const sqrt2 = Math.sqrt(2);
+                const sqrt2OverPi = Math.sqrt(2 / Math.PI);
+                const xOverSqrt2 = self.div(sqrt2);
+                const erfVal = xOverSqrt2.erf();
+                const phi = xOverSqrt2.square().neg().exp().div(sqrt2OverPi);
+                const derivative = erfVal.add(1).mul(0.5).add(self.mul(phi));
+                return outGrad.mul(derivative);
+            });
+        }
+        else if (approximate === "tanh") {
+            return this.elementWiseSelfDAG((a) => 0.5 * a * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (a + 0.044715 * a * a * a))), (self, outGrad) => {
+                const sqrt2OverPi = Math.sqrt(2 / Math.PI);
+                const c = 0.044715;
+                const tanhArg = self.add(self.pow(3).mul(c)).mul(sqrt2OverPi);
+                const tanhVal = tanhArg.tanh();
+                const sechSquared = tanhVal.square().neg().add(1);
+                const term1 = tanhVal.add(1).mul(0.5);
+                const term2 = self.mul(sechSquared).mul(sqrt2OverPi).mul(self.square().mul(c * 3).add(1)).mul(0.5);
+                const derivative = term1.add(term2);
+                return outGrad.mul(derivative);
+            });
+        }
+        throw new Error("Specified approximation does not exist");
+    }
     // Tensor element-wise maximum
     maximum(other) {
         return this.elementWiseABDAG(other, (a, b) => Math.max(a, b), (self, other, outGrad) => outGrad.mul(self.gt(other).add(self.eq(other).mul(0.5))), (self, other, outGrad) => outGrad.mul(other.gt(self).add(other.eq(self).mul(0.5))));
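
For reference, the forward passes above follow the standard GELU definitions: the exact form 0.5 * x * (1 + erf(x / sqrt(2))) and the tanh approximation 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))). A scalar sketch of the two formulas (standalone functions for illustration, not part of the package; the erf parameter stands in for the package's internal erf helper):

    // Exact GELU, given some error-function implementation.
    const geluExact = (x: number, erf: (v: number) => number): number =>
        0.5 * x * (1 + erf(x / Math.SQRT2));

    // Tanh approximation of GELU.
    const geluTanh = (x: number): number =>
        0.5 * x * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (x + 0.044715 * x ** 3)));
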
@@ -1260,9 +1288,8 @@ class Tensor {
         else if (this.shape.length === 2 && other.shape.length === 2) {
             return this.mm(other);
         }
-        else if ((isThis1D && other.shape.length > 2) ||
-            (isOther1D && this.shape.length > 2) ||
-            (other.shape.length > 2 && this.shape.length > 2)) {
+        else if ((this.shape.length > 0 && other.shape.length >= 2) ||
+            (this.shape.length >= 2 && other.shape.length > 0)) {
             // Append/prepend dims if needed
             const self = isThis1D ? this.unsqueeze(0) : this;
             other = isOther1D ? other.unsqueeze(1) : other;
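
The rewritten condition widens the branch that handles broadcasted matrix products: it now fires whenever one operand is at least 2-D and the other is at least 1-D, instead of requiring a rank above 2 on at least one side. In particular, a 2-D by N-D (N > 2) product, which previously did not match this branch, now reaches the broadcasting path. A sketch of the predicate in terms of ranks (hypothetical helper, for illustration only):

    // Mirrors the new condition, with ranks standing in for shape.length.
    const takesBroadcastBranch = (thisRank: number, otherRank: number): boolean =>
        (thisRank > 0 && otherRank >= 2) || (thisRank >= 2 && otherRank > 0);

    takesBroadcastBranch(2, 3); // true, 2-D x 3-D, not matched by the old condition
    takesBroadcastBranch(1, 3); // true, 1-D x 3-D, matched before as well
    takesBroadcastBranch(3, 3); // true, N-D x N-D, matched before as well
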
package/package.json CHANGED
@@ -1,6 +1,6 @@
 {
   "name": "catniff",
-  "version": "0.5.4",
+  "version": "0.5.5",
   "description": "A small Torch-like deep learning framework for Javascript",
   "main": "index.js",
   "scripts": {