homa 0.1.1__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- homa/activations/APLU.py +49 -0
- homa/activations/ActivationFunction.py +6 -0
- homa/activations/AdaptiveActivationFunction.py +15 -0
- homa/activations/BaseDLReLU.py +34 -0
- homa/activations/CaLU.py +13 -0
- homa/activations/DLReLU.py +6 -0
- homa/activations/ERF.py +10 -0
- homa/activations/Elliot.py +10 -0
- homa/activations/ExpExpish.py +9 -0
- homa/activations/ExponentialDLReLU.py +6 -0
- homa/activations/ExponentialSwish.py +10 -0
- homa/activations/GCU.py +9 -0
- homa/activations/GaLU.py +11 -0
- homa/activations/GaussianReLU.py +50 -0
- homa/activations/GeneralizedSwish.py +10 -0
- homa/activations/Gish.py +11 -0
- homa/activations/LaLU.py +11 -0
- homa/activations/LogLogish.py +10 -0
- homa/activations/LogSigmoid.py +10 -0
- homa/activations/Logish.py +10 -0
- homa/activations/MeLU.py +11 -0
- homa/activations/MexicanReLU.py +49 -0
- homa/activations/MinSin.py +10 -0
- homa/activations/NReLU.py +12 -0
- homa/activations/NoisyReLU.py +6 -0
- homa/activations/PLogish.py +6 -0
- homa/activations/ParametricLogish.py +13 -0
- homa/activations/Phish.py +11 -0
- homa/activations/RReLU.py +16 -0
- homa/activations/RandomizedSlopedReLU.py +7 -0
- homa/activations/SGELU.py +12 -0
- homa/activations/SReLU.py +37 -0
- homa/activations/SelfArctan.py +9 -0
- homa/activations/ShiftedReLU.py +10 -0
- homa/activations/SigmoidDerivative.py +10 -0
- homa/activations/SineReLU.py +11 -0
- homa/activations/SlopedReLU.py +13 -0
- homa/activations/SmallGaLU.py +11 -0
- homa/activations/Smish.py +9 -0
- homa/activations/SoftsignRReLU.py +17 -0
- homa/activations/Suish.py +11 -0
- homa/activations/TBSReLU.py +13 -0
- homa/activations/TSReLU.py +10 -0
- homa/activations/TangentBipolarSigmoidReLU.py +6 -0
- homa/activations/TangentSigmoidReLU.py +6 -0
- homa/activations/TeLU.py +9 -0
- homa/activations/TripleStateSwish.py +15 -0
- homa/activations/WideMeLU.py +15 -0
- homa/activations/__init__.py +49 -2
- homa/activations/learnable/AOAF.py +16 -0
- homa/activations/learnable/AReLU.py +16 -0
- homa/activations/learnable/DPReLU.py +16 -0
- homa/activations/learnable/DualLine.py +18 -0
- homa/activations/learnable/FReLU.py +14 -0
- homa/activations/learnable/LeLeLU.py +14 -0
- homa/activations/learnable/PERU.py +16 -0
- homa/activations/learnable/PiLU.py +18 -0
- homa/activations/learnable/ShiLU.py +16 -0
- homa/activations/learnable/StarReLU.py +16 -0
- homa/activations/learnable/__init__.py +10 -0
- homa/activations/learnable/concerns/ChannelBased.py +36 -0
- homa/activations/learnable/concerns/__init__.py +1 -0
- homa/cli/Commands/Command.py +2 -0
- homa/cli/Commands/InitCommand.py +34 -0
- homa/cli/Commands/__init__.py +2 -0
- homa/cli/HomaCommand.py +4 -0
- homa/ensemble/Ensemble.py +2 -4
- homa/ensemble/concerns/CalculatesMetricNecessities.py +14 -10
- homa/ensemble/concerns/PredictsProbabilities.py +4 -0
- homa/ensemble/concerns/ReportsClassificationMetrics.py +1 -1
- homa/ensemble/concerns/ReportsEnsembleAccuracy.py +3 -2
- homa/ensemble/concerns/ReportsLogits.py +4 -0
- homa/ensemble/concerns/ReportsSize.py +2 -2
- homa/ensemble/concerns/StoresModels.py +29 -0
- homa/ensemble/concerns/__init__.py +1 -2
- homa/loss/LogitNormLoss.py +12 -0
- homa/loss/Loss.py +2 -0
- homa/loss/__init__.py +2 -0
- homa/torch/__init__.py +0 -1
- homa/vision/Classifier.py +5 -0
- homa/vision/Resnet.py +6 -5
- homa/vision/StochasticClassifier.py +29 -0
- homa/vision/StochasticSwin.py +11 -0
- homa/vision/Swin.py +13 -0
- homa/vision/__init__.py +3 -1
- homa/vision/concerns/HasLabels.py +13 -0
- homa/vision/concerns/HasLogits.py +12 -0
- homa/vision/concerns/HasProbabilities.py +9 -0
- homa/vision/concerns/ReportsAccuracy.py +27 -0
- homa/vision/concerns/ReportsMetrics.py +6 -0
- homa/vision/concerns/Trainable.py +5 -2
- homa/vision/concerns/__init__.py +5 -0
- homa/vision/modules/SwinModule.py +23 -0
- homa/vision/modules/__init__.py +1 -1
- homa/vision/utils.py +9 -18
- homa-0.2.9.dist-info/METADATA +75 -0
- homa-0.2.9.dist-info/RECORD +113 -0
- homa/activations/classes/APLU.py +0 -48
- homa/activations/classes/GALU.py +0 -51
- homa/activations/classes/MELU.py +0 -50
- homa/activations/classes/PDELU.py +0 -39
- homa/activations/classes/SReLU.py +0 -49
- homa/activations/classes/SmallGALU.py +0 -39
- homa/activations/classes/StochasticActivation.py +0 -20
- homa/activations/classes/WideMELU.py +0 -61
- homa/activations/classes/__init__.py +0 -8
- homa/activations/utils.py +0 -27
- homa/ensemble/concerns/HasNetwork.py +0 -5
- homa/ensemble/concerns/HasStateDicts.py +0 -8
- homa/ensemble/concerns/RecordsStateDictionaries.py +0 -23
- homa/torch/Module.py +0 -8
- homa/vision/StochasticResnet.py +0 -8
- homa/vision/modules/StochasticResnetModule.py +0 -9
- homa-0.1.1.dist-info/METADATA +0 -21
- homa-0.1.1.dist-info/RECORD +0 -51
- {homa-0.1.1.dist-info → homa-0.2.9.dist-info}/WHEEL +0 -0
- {homa-0.1.1.dist-info → homa-0.2.9.dist-info}/entry_points.txt +0 -0
- {homa-0.1.1.dist-info → homa-0.2.9.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from ..activations import (
|
|
2
|
+
AOAF,
|
|
3
|
+
AReLU,
|
|
4
|
+
DPReLU,
|
|
5
|
+
DualLine,
|
|
6
|
+
FReLU,
|
|
7
|
+
LeLeLU,
|
|
8
|
+
PERU,
|
|
9
|
+
PiLU,
|
|
10
|
+
ShiLU,
|
|
11
|
+
StarReLU,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class StochasticClassifier:
    """Mixin that owns the pool of learnable activation classes used for random swaps.

    Subclasses (e.g. ``StochasticSwin``) read ``self._activation_pool`` and
    instantiate randomly chosen entries as replacement activation modules.
    """

    # Shared, immutable catalogue; each instance receives its own list copy below.
    _ACTIVATION_POOL = (
        AOAF,
        AReLU,
        DPReLU,
        DualLine,
        FReLU,
        LeLeLU,
        PERU,
        PiLU,
        ShiLU,
        StarReLU,
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._activation_pool = list(self._ACTIVATION_POOL)
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .Swin import Swin
|
|
3
|
+
from .StochasticClassifier import StochasticClassifier
|
|
4
|
+
from .utils import replace_activations
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class StochasticSwin(Swin, StochasticClassifier):
    """Swin classifier whose GELU/ReLU activations are swapped for random pool members.

    Every ``torch.nn.GELU`` and ``torch.nn.ReLU`` submodule of ``self.network`` is
    replaced by a fresh instance of a class drawn from ``self._activation_pool``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        replace_activations(self.network, torch.nn.GELU, self._activation_pool)
        replace_activations(self.network, torch.nn.ReLU, self._activation_pool)
        self._adopt_new_activations()

    def _adopt_new_activations(self):
        """Move swapped-in activations to the network's device and register them for training."""
        # Swin.__init__ already moved the network to its device and built the optimizer
        # before the swap, so the new activation modules start on the CPU and any
        # parameters they own would otherwise never be updated.
        device = next(self.network.parameters()).device
        self.network.to(device)
        tracked = {
            id(param) for group in self.optimizer.param_groups for param in group["params"]
        }
        untracked = [param for param in self.network.parameters() if id(param) not in tracked]
        # NOTE(review): parameters an activation creates lazily on its first forward pass
        # are still not seen here — confirm against the learnable activation classes.
        if untracked:
            self.optimizer.add_param_group({"params": untracked})
|
homa/vision/Swin.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .Classifier import Classifier
|
|
3
|
+
from .concerns import Trainable, ReportsMetrics
|
|
4
|
+
from .modules import SwinModule
|
|
5
|
+
from ..device import get_device
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Swin(Classifier, Trainable, ReportsMetrics):
    """Image classifier wrapping ``SwinModule`` with an AdamW optimizer and cross-entropy loss.

    Args:
        num_classes: Number of outputs of the classification head.
        lr: Learning rate handed to ``torch.optim.AdamW``.
    """

    def __init__(self, num_classes: int, lr: float = 0.0001):
        super().__init__()
        module = SwinModule(num_classes=num_classes)
        self.network = module.to(get_device())
        self.optimizer = torch.optim.AdamW(params=self.network.parameters(), lr=lr)
        self.criterion = torch.nn.CrossEntropyLoss()
|
homa/vision/__init__.py
CHANGED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class HasLabels:
    """Mixin that converts ``self.logits`` outputs into hard class-index predictions."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def predict(self, x: torch.Tensor):
        """Return the arg-max index along dimension 1 of ``self.logits(x)``."""
        scores = self.logits(x)
        return scores.argmax(dim=1)

    @torch.no_grad()
    def predict_(self, x: torch.Tensor):
        """Like ``predict`` but evaluated under ``torch.no_grad()``."""
        scores = self.logits(x)
        return scores.argmax(dim=1)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class HasLogits:
    """Mixin exposing the raw outputs of ``self.network`` as logits."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def logits(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``self.network(x)`` in the caller's current autograd mode."""
        return self.network(x)

    @torch.no_grad()
    def logits_(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``self.network(x)`` without building an autograd graph.

        The trailing underscore marks the inference-only variant, matching the
        ``@torch.no_grad()`` convention used by ``HasLabels.predict_``.
        """
        return self.network(x)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from torch import Tensor, no_grad
|
|
2
|
+
from torch.utils.data.dataloader import DataLoader
|
|
3
|
+
from ...device import get_device
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ReportsAccuracy:
    """Mixin that reports classification accuracy.

    Relies on the host class providing ``self.network`` (a ``torch.nn.Module``)
    and ``self.predict_`` (label predictions for a batch).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def accuracy_tensors(self, x: Tensor, y: Tensor) -> float:
        """Return the fraction of samples whose predicted label equals ``y``.

        Returns 0.0 for an empty ``y``, matching ``accuracy_dataloader``.
        """
        if y.numel() == 0:
            return 0.0
        predictions = self.predict_(x)
        # Compare on the predictions' device so CPU labels work with a GPU model.
        y = y.to(predictions.device)
        return (predictions == y).float().mean().item()

    def accuracy_dataloader(self, dataloader: DataLoader) -> float:
        """Return overall accuracy across every ``(x, y)`` batch of ``dataloader``.

        Returns 0.0 when the loader yields no labels.
        """
        correct, total = 0, 0
        device = get_device()
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            predictions = self.predict_(x)
            correct += (predictions == y).sum().item()
            total += y.numel()
        return correct / total if total > 0 else 0.0

    def accuracy(self, x: "Tensor | DataLoader", y: "Tensor | None" = None) -> float:
        """Evaluate accuracy on a ``DataLoader`` or on an ``(x, y)`` tensor pair.

        The network is put in eval mode for the measurement and then returned to
        whatever train/eval mode it was in before the call.

        Raises:
            ValueError: if ``x`` is a tensor and ``y`` is not given.
        """
        is_loader = isinstance(x, DataLoader)
        if not is_loader and y is None:
            raise ValueError("y is required when x is a tensor")
        was_training = self.network.training
        self.network.eval()
        try:
            if is_loader:
                return self.accuracy_dataloader(x)
            return self.accuracy_tensors(x, y)
        finally:
            self.network.train(was_training)
|
|
@@ -1,9 +1,12 @@
|
|
|
1
1
|
from torch import Tensor
|
|
2
2
|
from torch.utils.data.dataloader import DataLoader
|
|
3
|
+
from .HasLogits import HasLogits
|
|
4
|
+
from .HasProbabilities import HasProbabilities
|
|
5
|
+
from .HasLabels import HasLabels
|
|
3
6
|
from ...device import get_device
|
|
4
7
|
|
|
5
8
|
|
|
6
|
-
class Trainable:
|
|
9
|
+
class Trainable(HasLogits, HasProbabilities, HasLabels):
|
|
7
10
|
def __init__(self, *args, **kwargs):
|
|
8
11
|
super().__init__(*args, **kwargs)
|
|
9
12
|
|
|
@@ -16,7 +19,7 @@ class Trainable:
|
|
|
16
19
|
def train_tensors(self, x: Tensor, y: Tensor):
|
|
17
20
|
self.network.train()
|
|
18
21
|
self.optimizer.zero_grad()
|
|
19
|
-
loss = self.criterion(
|
|
22
|
+
loss = self.criterion(self.network(x).float(), y)
|
|
20
23
|
loss.backward()
|
|
21
24
|
self.optimizer.step()
|
|
22
25
|
|
homa/vision/concerns/__init__.py
CHANGED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torchvision.models import swin_v2_b
|
|
3
|
+
from torch.nn.init import kaiming_uniform_ as kaiming
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SwinModule(torch.nn.Module):
    """Swin-V2-B encoder followed by a freshly initialised linear classification head.

    Args:
        num_classes: Number of outputs of the classification head.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        self._create_encoder()
        self._create_fc()

    def _create_encoder(self):
        """Load torchvision's Swin-V2-B with its default weights and drop its head."""
        self.encoder = swin_v2_b(weights="DEFAULT")
        # Record the backbone's feature width before discarding the original head, so
        # the new head matches it instead of relying on a hard-coded size.
        self.feature_dim = self.encoder.head.in_features
        self.encoder.head = torch.nn.Identity()

    def _create_fc(self):
        """Create the classification head and Kaiming-uniform initialise its weights."""
        self.fc = torch.nn.Linear(self.feature_dim, self.num_classes)
        kaiming(self.fc.weight, mode="fan_in", nonlinearity="relu")

    def forward(self, images: torch.Tensor):
        """Return class logits for a batch of images."""
        features = self.encoder(images)
        return self.fc(features)
|
homa/vision/modules/__init__.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
1
|
from .ResnetModule import ResnetModule
|
|
2
|
-
from .
|
|
2
|
+
from .SwinModule import SwinModule
|
homa/vision/utils.py
CHANGED
|
@@ -1,21 +1,12 @@
|
|
|
1
1
|
import torch
|
|
2
|
+
import random
|
|
2
3
|
|
|
3
4
|
|
|
4
|
-
def
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
for name, child in list(parent.named_children()):
|
|
13
|
-
for needle in find:
|
|
14
|
-
if isinstance(child, needle):
|
|
15
|
-
setattr(parent, name, replacement())
|
|
16
|
-
replaced += 1
|
|
17
|
-
return replaced
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
def replace_relu(model: torch.nn.Module, replacement: torch.nn.Module):
|
|
21
|
-
return replace_modules(model, torch.nn.ReLU, replacement)
|
|
5
|
+
def replace_activations(
    module,
    needle: "type[torch.nn.Module] | tuple[type[torch.nn.Module], ...]",
    candidates: list,
    rng=None,
):
    """Recursively replace every submodule of type ``needle`` with a random candidate.

    Args:
        module: Root ``torch.nn.Module``; it is modified in place.
        needle: Module class, or tuple of classes (anything ``isinstance`` accepts),
            whose instances are replaced.
        candidates: Sequence of zero-argument callables (typically module classes);
            one is chosen at random per match and called to build the replacement.
        rng: Optional ``random.Random`` for reproducible choices; the global
            ``random`` state is used when omitted.

    Returns:
        ``module`` itself. Replaced submodules are not searched further.
    """
    choose = random.choice if rng is None else rng.choice
    # Snapshot the children so swapping an entry never happens mid-iteration.
    for name, child in list(module.named_children()):
        if isinstance(child, needle):
            setattr(module, name, choose(candidates)())
        else:
            replace_activations(child, needle, candidates, rng)
    return module
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: homa
|
|
3
|
+
Version: 0.2.9
|
|
4
|
+
Summary: A curated list of machine learning and deep learning helpers.
|
|
5
|
+
Author-email: Taha Shieenavaz <tahashieenavaz@gmail.com>
|
|
6
|
+
Requires-Python: >=3.7
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
Requires-Dist: numpy
|
|
9
|
+
Requires-Dist: torch
|
|
10
|
+
Requires-Dist: fire
|
|
11
|
+
|
|
12
|
+
# Core
|
|
13
|
+
|
|
14
|
+
### Device Management
|
|
15
|
+
|
|
16
|
+
```py
|
|
17
|
+
import torch
from homa import cpu, mps, cuda, device
|
|
18
|
+
|
|
19
|
+
torch.tensor([1, 2, 3, 4, 5]).to(cpu())
|
|
20
|
+
torch.tensor([1, 2, 3, 4, 5]).to(cuda())
|
|
21
|
+
torch.tensor([1, 2, 3, 4, 5]).to(mps())
|
|
22
|
+
torch.tensor([1, 2, 3, 4, 5]).to(device())
|
|
23
|
+
```
|
|
24
|
+
|
|
25
|
+
# Vision
|
|
26
|
+
|
|
27
|
+
## Resnet
|
|
28
|
+
|
|
29
|
+
This is the standard ResNet50 module.
|
|
30
|
+
|
|
31
|
+
You can train the model with a `DataLoader` object.
|
|
32
|
+
|
|
33
|
+
```py
|
|
34
|
+
from homa.vision import Resnet
|
|
35
|
+
|
|
36
|
+
model = Resnet(num_classes=10, lr=0.001)
|
|
37
|
+
for epoch in range(10):
|
|
38
|
+
model.train(train_dataloader)
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
Alternatively, you can iterate over the `DataLoader` yourself and pass each `(x, y)` batch to `train`.
|
|
42
|
+
|
|
43
|
+
```py
|
|
44
|
+
from homa.vision import Resnet
|
|
45
|
+
|
|
46
|
+
model = Resnet(num_classes=10, lr=0.001)
|
|
47
|
+
for epoch in range(10):
|
|
48
|
+
for x, y in train_dataloader:
|
|
49
|
+
model.train(x, y)
|
|
50
|
+
```
|
|
51
|
+
|
|
52
|
+
## StochasticResnet
|
|
53
|
+
|
|
54
|
+
This is a ResNet module whose activation functions are randomly replaced with functions drawn from a pool of different activations. Read more in the [paper](https://www.mdpi.com/1424-8220/22/16/6129).
|
|
55
|
+
|
|
56
|
+
You can train the model with a `DataLoader` object.
|
|
57
|
+
|
|
58
|
+
```py
|
|
59
|
+
from homa.vision import StochasticResnet
|
|
60
|
+
|
|
61
|
+
model = StochasticResnet(num_classes=10, lr=0.001)
|
|
62
|
+
for epoch in range(10):
|
|
63
|
+
model.train(train_dataloader)
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
Alternatively, you can iterate over the `DataLoader` yourself and pass each `(x, y)` batch to `train`.
|
|
67
|
+
|
|
68
|
+
```py
|
|
69
|
+
from homa.vision import StochasticResnet
|
|
70
|
+
|
|
71
|
+
model = StochasticResnet(num_classes=10, lr=0.001)
|
|
72
|
+
for epoch in range(10):
|
|
73
|
+
for x, y in train_dataloader:
|
|
74
|
+
model.train(x, y)
|
|
75
|
+
```
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
homa/__init__.py,sha256=NBYFKizG8UASiz5HLsEBqzXNGlWr78xm4sLr5hxKvjU,46
|
|
2
|
+
homa/device.py,sha256=9kKXfpYfnEk2cFQWPfcJrVloHgC_SSbP4I8IRY9TYk4,343
|
|
3
|
+
homa/settings.py,sha256=CPZDPvs1380O7SY7FcSKol8kBVFVVYFgSJl3YEyJuZ0,263
|
|
4
|
+
homa/utils.py,sha256=dPp6TItJwWxBqxmkMzUuCtX_BzdPT-kMOZyXRGVMCbQ,70
|
|
5
|
+
homa/activations/APLU.py,sha256=cUf6LUjY8TewXe_V1avO_7IcOtY66Hd6Dyk_1K4R3Ms,1555
|
|
6
|
+
homa/activations/ActivationFunction.py,sha256=XUw7Pa5E-CPG6rPL8Us_pDH7xCZqY0c2P9xtnJMyX44,141
|
|
7
|
+
homa/activations/AdaptiveActivationFunction.py,sha256=p_bqAq7527UOhVm47kdUtgdC1DlApxgiLOA4ZPBFdCE,386
|
|
8
|
+
homa/activations/BaseDLReLU.py,sha256=iRmDhhbFaO8N9G8u5M01s8-y-09t7poP96oA6uQkVq8,1186
|
|
9
|
+
homa/activations/CaLU.py,sha256=n0drKwp4GstHql69p4S58KeVctdaQ1B5oK_AIoI_okk,331
|
|
10
|
+
homa/activations/DLReLU.py,sha256=Q8l2zpR5q_tSgfgbz90uDXbXMbBT3b_7BWKw6JbtpQE,191
|
|
11
|
+
homa/activations/ERF.py,sha256=tDgHbo7UNFU93XPlcQCBRRxPMksr-FOE19mlsqfzmU8,252
|
|
12
|
+
homa/activations/Elliot.py,sha256=RDxERH9vFh6FYwtZXKHMDmLVG2ia1UfOoW18Gm2_8hM,298
|
|
13
|
+
homa/activations/ExpExpish.py,sha256=iq_uOmmV9EIz2eKowEy7SCeW-OMgGcEeMcivTnPc-Y0,202
|
|
14
|
+
homa/activations/ExponentialDLReLU.py,sha256=aVtah3c4sokB-aSdbVa5F_uq06IyXHwovnHtXlKGYlw,199
|
|
15
|
+
homa/activations/ExponentialSwish.py,sha256=nJtGu1TRHa2GSQ35w2MN0HEWzFogVvA9R2pGEkFvJX4,266
|
|
16
|
+
homa/activations/GCU.py,sha256=hXwty6WPovnhPGAxQDd4bIixujdoMOORN-77imVri7s,199
|
|
17
|
+
homa/activations/GaLU.py,sha256=5QHnHsUsLAy28s-LTxtwRN-t1hO1tg9xtWmkzE1T7Ck,308
|
|
18
|
+
homa/activations/GaussianReLU.py,sha256=ufNeVnod6dxkPLmdd9ye-xt0SIWap2dehX14_YxSZVM,2051
|
|
19
|
+
homa/activations/GeneralizedSwish.py,sha256=zv6CX83cOTVnN0yoCIXIvIgkjXLnmm_T_LsvyoN7lOY,236
|
|
20
|
+
homa/activations/Gish.py,sha256=Xohk5tTmeGTmQ4PXtHF5sPBDikmoNTjdEJzy2KPDmOI,249
|
|
21
|
+
homa/activations/LaLU.py,sha256=UiulXzSTmnoU_Gp8qKigFoL6efonqbldUlsBBlm9mB8,356
|
|
22
|
+
homa/activations/LogLogish.py,sha256=lfVRNhnDGbYYakTsUmePmr5azkzz_NQwEy6NvSSD-Do,205
|
|
23
|
+
homa/activations/LogSigmoid.py,sha256=PUvr84dRRd6L-VZ_9UeWAN9lhUFr2Otj8VrAIQ3eOEM,239
|
|
24
|
+
homa/activations/Logish.py,sha256=CnL-10b76C2EaDm56N4n2GYCaYJUKl_k7H82UBJI5to,257
|
|
25
|
+
homa/activations/MeLU.py,sha256=f13h2AAQCwp9soR3RWbMAA4Bl38oqRdBAsdzh6Bf4k8,321
|
|
26
|
+
homa/activations/MexicanReLU.py,sha256=vfDa1lWI-PgY4ztDY34aeBMaJ2rOyAYt5ifZBG0DS0c,1946
|
|
27
|
+
homa/activations/MinSin.py,sha256=JzQsmuffRAGGcD40nlz2ZnOGhQMEU0JYBIeFHIC1qcE,250
|
|
28
|
+
homa/activations/NReLU.py,sha256=mX4B2OXw28M8zyd6RpkaSoOCGZuB3FaksO4oFyr3YD8,314
|
|
29
|
+
homa/activations/NoisyReLU.py,sha256=2YFkOS_h8EijvCugYQTfq_gQg5uNEkcypcm0iDgEHIg,134
|
|
30
|
+
homa/activations/PLogish.py,sha256=ia_V0xewAS6mmX3G8JNQTxWl66ea8P3MuRLQkEPv-I8,165
|
|
31
|
+
homa/activations/ParametricLogish.py,sha256=grCGG61xTDytA4iOK3kS70V9m2bYoiSmuilJV3U8vIw,360
|
|
32
|
+
homa/activations/Phish.py,sha256=CLAV1fLHRAq-GxBut-_FsSYJRMlk5sOFVlcXs3G3w9c,280
|
|
33
|
+
homa/activations/RReLU.py,sha256=ILpkmoWk8WatXusrPqSLu15xMWQALwRQVZjhzwmw1PM,476
|
|
34
|
+
homa/activations/RandomizedSlopedReLU.py,sha256=O20XX3vRRmkERxwhLSNgue-fn0qSRoF7rlIN1LSWlyI,169
|
|
35
|
+
homa/activations/SGELU.py,sha256=AaNmXRoFQ68Xsgt4sSWMZxnSCTR5OD5ZEuqxxg1mvfg,358
|
|
36
|
+
homa/activations/SReLU.py,sha256=xyChK3G2HPpM7C8icQNfMzrOm142boDLY31n9yXqPtg,1472
|
|
37
|
+
homa/activations/SelfArctan.py,sha256=Sq3yWGXjxdP32J-rSZ38BQ5S_XErr5H1ZyPsMF1VKfI,193
|
|
38
|
+
homa/activations/ShiftedReLU.py,sha256=JVsf2F6C13PRICjnVOSVEsx9IoQ9rcM2TFn55DguZQs,229
|
|
39
|
+
homa/activations/SigmoidDerivative.py,sha256=4PPT-QX4MW9ySKU4Qv9K-y--lxlqFxvKVviC2S3e6Z0,274
|
|
40
|
+
homa/activations/SineReLU.py,sha256=gzYF1ZEZAFYmUuABWJf18LIer1oPAS38i_5NLcIhP-I,357
|
|
41
|
+
homa/activations/SlopedReLU.py,sha256=j6YfM4msg6It-ANbpMzEaXkiHvgEdhFNFbB6NkY6KpE,421
|
|
42
|
+
homa/activations/SmallGaLU.py,sha256=ERrK-g3QMZTNFDzUyiSLAovymEpV5h1x1696CN5K4Zg,289
|
|
43
|
+
homa/activations/Smish.py,sha256=hsr5FS4KywsCmsuFUKP-4pKoXkJK0hhRVDleq_CFGX0,198
|
|
44
|
+
homa/activations/SoftsignRReLU.py,sha256=bBSjYDLUVKxXPyaJExYXndEO3oORnP3M6NKoU-hiCCQ,564
|
|
45
|
+
homa/activations/Suish.py,sha256=I459CV24NV1JlLbko4oHUOh98fxoLaM-2SH71pVMcwA,279
|
|
46
|
+
homa/activations/TBSReLU.py,sha256=ZfYY_M6msDimJAOHr1HyrG1HHnWiJ7hnZ5hWjCFPecU,320
|
|
47
|
+
homa/activations/TSReLU.py,sha256=gbU0Q7zhf3X6oWvKUSry6sVdRhuaxIQt8keFH3WsxV8,256
|
|
48
|
+
homa/activations/TangentBipolarSigmoidReLU.py,sha256=YtrFHkFbEbx7aeIpIRc9TLCxhpveUFSgAnvTKaKLZ4E,156
|
|
49
|
+
homa/activations/TangentSigmoidReLU.py,sha256=C47UK6ADWsG2ueaZe9FUt-sPBzeuBLkiNjpkDZOCYGc,146
|
|
50
|
+
homa/activations/TeLU.py,sha256=qU5x0EskjQs6d5rCtbL91C6cMAm8vjDnjQNMX0LcEt8,180
|
|
51
|
+
homa/activations/TripleStateSwish.py,sha256=UG5BGY29wUEJaryClB2rDM90s0jt5vMJF9Kv-5M4Rgo,507
|
|
52
|
+
homa/activations/WideMeLU.py,sha256=ieJjTjnK9JJtApPFGpmTynu3G8YlyH5jw6qnhkJkStI,421
|
|
53
|
+
homa/activations/__init__.py,sha256=2GHNqrOp6WoLAtFFJcSj6j4GP-W8-YAYRZGX9vZbcmU,1659
|
|
54
|
+
homa/activations/learnable/AOAF.py,sha256=KYtQtpLiupdyoumqNmz0kMTgRK66sSYiuLnpbr2H7Mw,509
|
|
55
|
+
homa/activations/learnable/AReLU.py,sha256=-6kQ0mDGq3p9Xlg74waMa8xsTDALCtkE6pwx7DrTDeI,610
|
|
56
|
+
homa/activations/learnable/DPReLU.py,sha256=xQhYTJ0-mfRGdld950xoTh8c9O08WIY50K0FjPtVVFs,507
|
|
57
|
+
homa/activations/learnable/DualLine.py,sha256=cgqyE7dVqXflT8ulCuOyKQQa09FYSj8vJkeVUEOaeIU,600
|
|
58
|
+
homa/activations/learnable/FReLU.py,sha256=qQ8GjjWWGeoE6qW9tw49mZPs29app0QK1AFOuMc5ASU,413
|
|
59
|
+
homa/activations/learnable/LeLeLU.py,sha256=ya2m60QRcpVlTwMejJTgMTxM3RRHF0RgNe72_EdD1-U,425
|
|
60
|
+
homa/activations/learnable/PERU.py,sha256=y2OxRLIA1HTUnFyRHs0zgLhLMJhQz9Q4F6QrqBSkQ00,513
|
|
61
|
+
homa/activations/learnable/PiLU.py,sha256=p5FmWGJWlZEdLGVXmiXKg0rTxCVO-qn9bQIVcyAaa8U,616
|
|
62
|
+
homa/activations/learnable/ShiLU.py,sha256=35VC1pCAWMaxHKWYBeXd2DrXn1tepvQaT7a-KwoNdHY,475
|
|
63
|
+
homa/activations/learnable/StarReLU.py,sha256=hrscp-A0HnIvebFPLGr86K5Uf_U--EWtpNDqdNgonA0,485
|
|
64
|
+
homa/activations/learnable/__init__.py,sha256=fcfm-GHEe4AQzEz9mXrWfSLkcgWaTg91ccByx7LxfX4,264
|
|
65
|
+
homa/activations/learnable/concerns/ChannelBased.py,sha256=uK6FdC9mJRWSoXinjM8r5GJCZNWWxst7NMt8P6rnhKg,1143
|
|
66
|
+
homa/activations/learnable/concerns/__init__.py,sha256=CubRRYQEQMAK2-igsYKD8tcyesPOYoZYF_IlHzRZXi4,39
|
|
67
|
+
homa/cli/HomaCommand.py,sha256=w-Dg6dFpoXbQx2tvWSLdND2pdhqB2cPSORyi4MfY8XY,307
|
|
68
|
+
homa/cli/Commands/Command.py,sha256=DnmsEwpaxdQaLjzyYBO7qtIQTLwYzyhJS31YazA1IHg,24
|
|
69
|
+
homa/cli/Commands/InitCommand.py,sha256=3whh2mWLuevXpUyRpDEMbo_KNeAIdO2aLMFnC2nz_0c,1159
|
|
70
|
+
homa/cli/Commands/__init__.py,sha256=PYKkcG06R5LqLnp2x8otuimzRpL4oMbziL3xEMkCffc,66
|
|
71
|
+
homa/cli/namespaces/CacheNamespace.py,sha256=QXGljzj287stzTx0y_MXnqvCgPLqd7WjSPop2WDe14E,784
|
|
72
|
+
homa/cli/namespaces/MakeNamespace.py,sha256=5G6LHk3lDkXROz7uq4jYE0DyO_V7JvnhJ33IFCiqYro,590
|
|
73
|
+
homa/cli/namespaces/__init__.py,sha256=zAKUGPH4wcacxfH5Qvidp-uOuHdfzhan6kvVI6eMKA8,84
|
|
74
|
+
homa/ensemble/Ensemble.py,sha256=GNkXEV7Nli8lHSTQ3qTTCTeSBwST1PLZS5wxpKpeC5U,290
|
|
75
|
+
homa/ensemble/__init__.py,sha256=1pk2W-NbgfDFh9WLKZVLUk2E3PTjVZ5Bap9dQEnrs9o,31
|
|
76
|
+
homa/ensemble/concerns/CalculatesMetricNecessities.py,sha256=QccROg_FOp_X2T_lZDg8p1DMZhPYdO-7aEdnebRXMsY,825
|
|
77
|
+
homa/ensemble/concerns/PredictsProbabilities.py,sha256=7rmI66DzE7-QGoJgZEk-9fu5YQvJW-4ZnMn_dWEEhqU,440
|
|
78
|
+
homa/ensemble/concerns/ReportsClassificationMetrics.py,sha256=bg__cdCKp2U1H9qN1aOJH4BoX98oIvt8XaPDGApJhSM,395
|
|
79
|
+
homa/ensemble/concerns/ReportsEnsembleAccuracy.py,sha256=AX5X3VGOm7DfdonW0N7FFgUwEr7wnsojRSVEULEii7c,380
|
|
80
|
+
homa/ensemble/concerns/ReportsEnsembleF1.py,sha256=hdtdCQrWaFJNUn1KP9cAmi_q_EA4FYnpkBMlYLjzRZg,296
|
|
81
|
+
homa/ensemble/concerns/ReportsEnsembleKappa.py,sha256=ZRbtrFCTD84EDql6ZL1xeWtTLFxpO5Y5tQaUlR6_0jw,300
|
|
82
|
+
homa/ensemble/concerns/ReportsLogits.py,sha256=vTGuC9NR4rno3Mkbm0MhL8f7YopuCErGyjIorxamKTM,461
|
|
83
|
+
homa/ensemble/concerns/ReportsSize.py,sha256=S7lo_Wu6rDnuqyAcv6AI6jspaBhcpfsirpp9RVD8c20,238
|
|
84
|
+
homa/ensemble/concerns/StoresModels.py,sha256=PNoaoAOx4v8rercxXHmf7zqVIPGYM4APzIHHEb3RwT0,850
|
|
85
|
+
homa/ensemble/concerns/__init__.py,sha256=X0F_b2Jsv0XpiNhYwJsl-dfPsBOdEeW53LQPE4xQD0w,479
|
|
86
|
+
homa/loss/LogitNormLoss.py,sha256=LJMzRA1WoJ7aDYTV-FYGhgo8DMkcpv7e8_74qiJ4zT8,386
|
|
87
|
+
homa/loss/Loss.py,sha256=COUr_idShYgAP8xKCxcaXbyUyAoJg7IOON0ARTQykmQ,21
|
|
88
|
+
homa/loss/__init__.py,sha256=4mPVzme2_-M64bgBu1cANIfBFAL0voa5I71-ceMr_qk,64
|
|
89
|
+
homa/torch/__init__.py,sha256=HTxCVaw1TLgpHMH8guB3hHYQ80cX6_fSEoPT_hz2Y8w,23
|
|
90
|
+
homa/torch/helpers.py,sha256=CLbTCXRrroM0n4PfM-K_xFavs4dCZJEu_L7hdgb1DCI,134
|
|
91
|
+
homa/vision/Classifier.py,sha256=bAypqREQVuPamnc8hpbLCwmW9Uly3T1rvrlbMxXp1eA,61
|
|
92
|
+
homa/vision/Model.py,sha256=JIeVpHJwirHfsDfYYbLsu0kt7bGf4nhMQGIOagUDKw4,22
|
|
93
|
+
homa/vision/Resnet.py,sha256=Uitf58bEzIKkZd-F4FTvJ8nmhoFHlzZjJTvBPXEt2Iw,513
|
|
94
|
+
homa/vision/StochasticClassifier.py,sha256=6-o0TaH4iWXiPFefL7DOdLr3ZrTnjnJ9PIgQLlygN8w,497
|
|
95
|
+
homa/vision/StochasticSwin.py,sha256=FggzfaVYrP4fnjAFcdMpDozwQHc7CQhl3iRw78oQh0o,425
|
|
96
|
+
homa/vision/Swin.py,sha256=W3XbfUTrjaIhMH8fI_whPP6XO9fVA2R34LlGfQ1hoyo,508
|
|
97
|
+
homa/vision/__init__.py,sha256=w5OkcmdU6Ik5wHIJzeV1Z2UElQtvCsUZks1Q-xViSVg,153
|
|
98
|
+
homa/vision/utils.py,sha256=WB2b7eMDaf6UO3SuS7cB6IJk-9NRQesLavuzWUZRZyg,389
|
|
99
|
+
homa/vision/concerns/HasLabels.py,sha256=fM6nHLeQaEaWDlV6R8NQ5hgOSiwspPxOIwj-nvYXbP0,321
|
|
100
|
+
homa/vision/concerns/HasLogits.py,sha256=oStX4NCV7zwxI7Vj23M8wQSlY1xoSmAYJ_6cBNJpVCk,290
|
|
101
|
+
homa/vision/concerns/HasProbabilities.py,sha256=m1_ObS2BNYO-WVCNVMiHXzC3XAsyb88_0N4BWVDwCw0,221
|
|
102
|
+
homa/vision/concerns/ReportsAccuracy.py,sha256=DD0YTr5i8JMllIJTQn88Dn711yjZ2uiecaTi7WqpOEw,986
|
|
103
|
+
homa/vision/concerns/ReportsMetrics.py,sha256=93Hw_JBUbwfkrJNJA1xFSQ4cqRwzbSv4nPU524PGF6I,169
|
|
104
|
+
homa/vision/concerns/Trainable.py,sha256=SRCW3XpG9_DQgubyqhALlYDHwAWNzVVFjshUv1ecuEQ,988
|
|
105
|
+
homa/vision/concerns/__init__.py,sha256=mrw1YvN-GpQPvMwDF00KxnFkksPKo23RWM4KRioURsg,234
|
|
106
|
+
homa/vision/modules/ResnetModule.py,sha256=eFudBnILD6OmgQtcW_CQQ8aZ62NEa4HyZ15-lobTtt0,712
|
|
107
|
+
homa/vision/modules/SwinModule.py,sha256=h7wq1YdKoN6-7C3FVFA0bpkAET_30002iTRbjZxziFQ,714
|
|
108
|
+
homa/vision/modules/__init__.py,sha256=zVMYB9IAO_xZylC1-N3p8ymHgEkAE2sBbuVz8K5Y1kk,74
|
|
109
|
+
homa-0.2.9.dist-info/METADATA,sha256=uqaBYePnoJwrTwJRFB47fx_vh073hlynKWA7JAU0hDs,1759
|
|
110
|
+
homa-0.2.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
111
|
+
homa-0.2.9.dist-info/entry_points.txt,sha256=tJZzjs-f2QvFe3ES8Qta8IE5sAbeE8-cyZ_UtbgqG4s,51
|
|
112
|
+
homa-0.2.9.dist-info/top_level.txt,sha256=tmOfy2tuaAwc3W5-i6j61_vYJsXgR4ivBWkhJ3ZtJDc,5
|
|
113
|
+
homa-0.2.9.dist-info/RECORD,,
|
homa/activations/classes/APLU.py
DELETED
|
@@ -1,48 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
class APLU(torch.nn.Module):
    """Adaptive piecewise linear unit with three learnable hinges per channel.

    ``f(x) = relu(x) + alpha * relu(xi - x) + beta * relu(psi - x) + gamma * relu(mu - x)``

    Parameters are created lazily on the first forward pass, once the channel
    count ``x.shape[1]`` is known, on the same device as the input.

    Args:
        max_input: Hinge positions ``xi``, ``psi`` and ``mu`` are initialised
            uniformly in ``[0, max_input)``.
    """

    def __init__(self, max_input: float = 1.0):
        super(APLU, self).__init__()
        self.max_input = max_input
        self.alpha = None
        self.beta = None
        self.gamma = None
        self.xi = None
        self.psi = None
        self.mu = None
        self._num_channels = None

    def _initialize_parameters(self, x):
        """Create per-channel parameters shaped ``[1, C, 1, ...]`` to broadcast over ``x``.

        Raises:
            ValueError: if ``x`` has fewer than 2 dimensions.
        """
        if x.ndim < 2:
            raise ValueError(
                f"Input tensor must have at least 2 dimensions (N, C), but got shape {x.shape}"
            )

        num_channels = x.shape[1]
        self._num_channels = num_channels

        param_shape = [1] * x.ndim
        param_shape[1] = num_channels

        # Build parameters next to the input: defaults (CPU, float32) break CUDA inputs
        # and silently mismatch float64/float16 inputs.
        device = x.device
        dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()

        self.alpha = torch.nn.Parameter(torch.zeros(param_shape, device=device, dtype=dtype))
        self.beta = torch.nn.Parameter(torch.zeros(param_shape, device=device, dtype=dtype))
        self.gamma = torch.nn.Parameter(torch.zeros(param_shape, device=device, dtype=dtype))

        self.xi = torch.nn.Parameter(
            self.max_input * torch.rand(param_shape, device=device, dtype=dtype)
        )
        self.psi = torch.nn.Parameter(
            self.max_input * torch.rand(param_shape, device=device, dtype=dtype)
        )
        self.mu = torch.nn.Parameter(
            self.max_input * torch.rand(param_shape, device=device, dtype=dtype)
        )

    def forward(self, x):
        if self.alpha is None:
            self._initialize_parameters(x)

        a = torch.relu(x)

        # following are called hinges
        b = self.alpha * torch.relu(-x + self.xi)
        c = self.beta * torch.relu(-x + self.psi)
        d = self.gamma * torch.relu(-x + self.mu)
        z = a + b + c + d

        return z
|
homa/activations/classes/GALU.py
DELETED
|
@@ -1,51 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
class GALU(torch.nn.Module):
    """Piecewise-linear GALU activation with four learnable per-channel coefficients.

    On ``x_norm = x / max_input`` it computes::

        relu(x_norm) + alpha * min(x_norm, 0)
        + beta  * (relu(1 - |x_norm - 1|)     + min(|x_norm - 3| - 1, 0))
        + gamma * (relu(0.5 - |x_norm - 0.5|) + min(|x_norm - 1.5| - 0.5, 0))
        + delta * (relu(0.5 - |x_norm - 2.5|) + min(|x_norm - 3.5| - 0.5, 0))

    and rescales the result by ``max_input``. Coefficients start at zero and are
    created lazily on the first forward pass, on the input's device.

    Raises:
        ValueError: if ``max_input`` is not positive.
    """

    def __init__(self, max_input: float = 1.0):
        super(GALU, self).__init__()
        if max_input <= 0:
            raise ValueError("max_input must be positive.")
        self.max_input = max_input
        self.alpha = None
        self.beta = None
        self.gamma = None
        self.delta = None
        self._num_channels = None

    def _initialize_parameters(self, x):
        """Create zero-initialised coefficients shaped ``[1, C, 1, ...]`` for ``x``.

        Raises:
            ValueError: if ``x`` has fewer than 2 dimensions.
        """
        if x.ndim < 2:
            raise ValueError(
                f"Input tensor must have at least 2 dimensions (N, C), but got shape {x.shape}"
            )

        num_channels = x.shape[1]
        self._num_channels = num_channels
        param_shape = [1] * x.ndim
        param_shape[1] = num_channels
        # Match the input's device/dtype; CPU float32 defaults break CUDA inputs.
        device = x.device
        dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        self.alpha = torch.nn.Parameter(torch.zeros(param_shape, device=device, dtype=dtype))
        self.beta = torch.nn.Parameter(torch.zeros(param_shape, device=device, dtype=dtype))
        self.gamma = torch.nn.Parameter(torch.zeros(param_shape, device=device, dtype=dtype))
        self.delta = torch.nn.Parameter(torch.zeros(param_shape, device=device, dtype=dtype))

    def forward(self, x):
        if self.alpha is None:
            self._initialize_parameters(x)

        zero = torch.tensor(0.0, device=x.device, dtype=x.dtype)
        x_norm = x / self.max_input
        part_prelu = torch.relu(x_norm) + self.alpha * torch.min(x_norm, zero)
        part_beta = self.beta * (
            torch.relu(1.0 - torch.abs(x_norm - 1.0))
            + torch.min(torch.abs(x_norm - 3.0) - 1.0, zero)
        )
        part_gamma = self.gamma * (
            torch.relu(0.5 - torch.abs(x_norm - 0.5))
            + torch.min(torch.abs(x_norm - 1.5) - 0.5, zero)
        )
        part_delta = self.delta * (
            torch.relu(0.5 - torch.abs(x_norm - 2.5))
            + torch.min(torch.abs(x_norm - 3.5) - 0.5, zero)
        )
        z = part_prelu + part_beta + part_gamma + part_delta
        return z * self.max_input
|
homa/activations/classes/MELU.py
DELETED
|
@@ -1,50 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
class MELU(torch.nn.Module):
    """MELU activation: a PReLU-like base plus five learnable circular bumps.

    Each bump is ``coef * sqrt(relu(r**2 - dist_sq))``. Here ``dist_sq`` is the
    squared distance of ``(x_norm, y)`` from a fixed centre, and ``y`` is
    ``x_norm`` rolled by one channel. Expects 4D input ``(B, C, H, W)``.
    Parameters start at zero and are created lazily on the first forward pass, on
    the input's device.

    Args:
        maxInput: Inputs are divided by this value before the activation and the
            output is multiplied back by it.
    """

    def __init__(self, maxInput: float = 1.0):
        super().__init__()
        self.maxInput = float(maxInput)
        self.alpha = None
        self.beta = None
        self.gamma = None
        self.delta = None
        self.xi = None
        self.psi = None
        self._initialized = False

    def _initialize_parameters(self, X: torch.Tensor):
        """Create zero-initialised per-channel parameters of shape ``(1, C, 1, 1)``.

        Raises:
            ValueError: if ``X`` is not 4-dimensional.
        """
        if X.dim() != 4:
            raise ValueError(
                f"Expected 4D input (B, C, H, W), but got {X.dim()}D input."
            )
        num_channels = X.shape[1]
        shape = (1, num_channels, 1, 1)
        # Match the input's device/dtype; CPU float32 defaults break CUDA inputs.
        device = X.device
        dtype = X.dtype if X.is_floating_point() else torch.get_default_dtype()
        self.alpha = torch.nn.Parameter(torch.zeros(shape, device=device, dtype=dtype))
        self.beta = torch.nn.Parameter(torch.zeros(shape, device=device, dtype=dtype))
        self.gamma = torch.nn.Parameter(torch.zeros(shape, device=device, dtype=dtype))
        self.delta = torch.nn.Parameter(torch.zeros(shape, device=device, dtype=dtype))
        self.xi = torch.nn.Parameter(torch.zeros(shape, device=device, dtype=dtype))
        self.psi = torch.nn.Parameter(torch.zeros(shape, device=device, dtype=dtype))
        self._initialized = True

    @staticmethod
    def _sqrt_relu(u: torch.Tensor) -> torch.Tensor:
        """Return ``sqrt(relu(u))`` with a zero (finite) gradient wherever ``u <= 0``."""
        # torch.sqrt has an infinite derivative at 0, and relu(u) is exactly 0 on most of
        # the domain, so the naive form back-propagates 0 * inf = NaN into the input.
        # Substituting a harmless value in the masked region keeps the backward finite.
        positive = u > 0
        safe = torch.where(positive, u, torch.ones_like(u))
        return torch.where(positive, torch.sqrt(safe), torch.zeros_like(u))

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        if not self._initialized:
            self._initialize_parameters(X)
        X_norm = X / self.maxInput
        Y = torch.roll(X_norm, shifts=-1, dims=1)
        term1 = torch.relu(X_norm)
        term2 = self.alpha * torch.clamp(X_norm, max=0)
        dist_sq_beta = (X_norm - 2) ** 2 + (Y - 2) ** 2
        dist_sq_gamma = (X_norm - 1) ** 2 + (Y - 1) ** 2
        dist_sq_delta = (X_norm - 1) ** 2 + (Y - 3) ** 2
        dist_sq_xi = (X_norm - 3) ** 2 + (Y - 1) ** 2
        dist_sq_psi = (X_norm - 3) ** 2 + (Y - 3) ** 2
        term3 = self.beta * self._sqrt_relu(2 - dist_sq_beta)
        term4 = self.gamma * self._sqrt_relu(1 - dist_sq_gamma)
        term5 = self.delta * self._sqrt_relu(1 - dist_sq_delta)
        term6 = self.xi * self._sqrt_relu(1 - dist_sq_xi)
        term7 = self.psi * self._sqrt_relu(1 - dist_sq_psi)
        Z_norm = term1 + term2 + term3 + term4 + term5 + term6 + term7
        Z = Z_norm * self.maxInput
        return Z
|
|
@@ -1,39 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
class PDELU(torch.nn.Module):
    """Parametric deformable ELU with a learnable per-channel ``alpha``.

    ``f(x) = relu(x) + alpha * min((relu(1 + (1 - theta) * x)) ** (1 / (1 - theta)) - 1, 0)``

    ``alpha`` is created lazily on the first forward pass, initialised to 0.1, on
    the input's device.

    Raises:
        ValueError: if ``theta`` is 1.0 (the exponent would divide by zero).
    """

    def __init__(self, theta: float = 0.5):
        super(PDELU, self).__init__()
        if theta == 1.0:
            raise ValueError(
                "theta cannot be 1.0, as it would cause a division by zero."
            )
        self.theta = theta
        self._power_val = 1.0 / (1.0 - self.theta)
        self.alpha = torch.nn.UninitializedParameter()
        self._num_channels = None

    def _initialize_parameters(self, x: torch.Tensor):
        """Materialise ``alpha`` with shape ``[1, C, 1, ...]`` filled with 0.1.

        Raises:
            ValueError: if ``x`` has fewer than 2 dimensions.
        """
        if x.ndim < 2:
            raise ValueError(
                f"Input tensor must have at least 2 dimensions (N, C), but got shape {x.shape}"
            )

        num_channels = x.shape[1]
        self._num_channels = num_channels
        param_shape = [1] * x.ndim
        param_shape[1] = num_channels
        # Match the input's device/dtype; CPU float32 defaults break CUDA inputs.
        dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        init_tensor = torch.full(param_shape, 0.1, device=x.device, dtype=dtype)
        self.alpha = torch.nn.Parameter(init_tensor)

    def forward(self, x: torch.Tensor):
        # alpha starts as an UninitializedParameter (never None), so test for that
        # type; the former `is None` check meant alpha was never initialised.
        if isinstance(self.alpha, torch.nn.UninitializedParameter):
            self._initialize_parameters(x)

        zero = torch.tensor(0.0, device=x.device, dtype=x.dtype)
        positive_part = torch.relu(x)
        inner_term = torch.relu(1.0 + (1.0 - self.theta) * x)
        powered_term = torch.pow(inner_term, self._power_val)
        subtracted_term = powered_term - 1.0
        negative_part = self.alpha * torch.min(subtracted_term, zero)
        return positive_part + negative_part
|
|
@@ -1,49 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
class SReLU(torch.nn.Module):
    """S-shaped ReLU with learnable per-channel thresholds and slopes.

    The output is:

    - identity between ``beta`` and ``delta``;
    - ``beta + alpha * (x - beta)`` below ``beta``;
    - ``delta + gamma * (x - delta)`` above ``delta``.

    Parameters are created lazily on the first forward pass from the ``*_init``
    values, on the input's device.
    """

    def __init__(
        self,
        alpha_init: float = 0.0,
        beta_init: float = 0.0,
        gamma_init: float = 1.0,
        delta_init: float = 1.0,
    ):
        super().__init__()
        self.alpha_init_val = alpha_init
        self.beta_init_val = beta_init
        self.gamma_init_val = gamma_init
        self.delta_init_val = delta_init
        self.alpha = torch.nn.UninitializedParameter()
        self.beta = torch.nn.UninitializedParameter()
        self.gamma = torch.nn.UninitializedParameter()
        self.delta = torch.nn.UninitializedParameter()

    def _initialize_parameters(self, x: torch.Tensor):
        """Materialise the parameters as ``[1, C, 1, ...]`` tensors on first use.

        Raises:
            ValueError: if ``x`` has fewer than 2 dimensions.
        """
        if isinstance(self.alpha, torch.nn.UninitializedParameter):
            if x.dim() < 2:
                raise ValueError(
                    f"Input tensor must have at least 2 dimensions (N, C), but got {x.dim()}"
                )

            num_channels = x.shape[1]
            param_shape = [1] * x.dim()
            param_shape[1] = num_channels
            # Match the input's device/dtype; CPU float32 defaults break CUDA inputs.
            device = x.device
            dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
            self.alpha = torch.nn.Parameter(
                torch.full(param_shape, self.alpha_init_val, device=device, dtype=dtype)
            )
            self.beta = torch.nn.Parameter(
                torch.full(param_shape, self.beta_init_val, device=device, dtype=dtype)
            )
            self.gamma = torch.nn.Parameter(
                torch.full(param_shape, self.gamma_init_val, device=device, dtype=dtype)
            )
            self.delta = torch.nn.Parameter(
                torch.full(param_shape, self.delta_init_val, device=device, dtype=dtype)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self._initialize_parameters(x)
        start = self.beta + self.alpha * (x - self.beta)
        finish = self.delta + self.gamma * (x - self.delta)
        out = torch.where(x < self.beta, start, x)
        out = torch.where(x > self.delta, finish, out)
        return out
|