homa 0.1.1__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- homa/activations/APLU.py +49 -0
- homa/activations/ActivationFunction.py +6 -0
- homa/activations/AdaptiveActivationFunction.py +15 -0
- homa/activations/BaseDLReLU.py +34 -0
- homa/activations/CaLU.py +13 -0
- homa/activations/DLReLU.py +6 -0
- homa/activations/ERF.py +10 -0
- homa/activations/Elliot.py +10 -0
- homa/activations/ExpExpish.py +9 -0
- homa/activations/ExponentialDLReLU.py +6 -0
- homa/activations/ExponentialSwish.py +10 -0
- homa/activations/GCU.py +9 -0
- homa/activations/GaLU.py +11 -0
- homa/activations/GaussianReLU.py +50 -0
- homa/activations/GeneralizedSwish.py +10 -0
- homa/activations/Gish.py +11 -0
- homa/activations/LaLU.py +11 -0
- homa/activations/LogLogish.py +10 -0
- homa/activations/LogSigmoid.py +10 -0
- homa/activations/Logish.py +10 -0
- homa/activations/MeLU.py +11 -0
- homa/activations/MexicanReLU.py +49 -0
- homa/activations/MinSin.py +10 -0
- homa/activations/NReLU.py +12 -0
- homa/activations/NoisyReLU.py +6 -0
- homa/activations/PLogish.py +6 -0
- homa/activations/ParametricLogish.py +13 -0
- homa/activations/Phish.py +11 -0
- homa/activations/RReLU.py +16 -0
- homa/activations/RandomizedSlopedReLU.py +7 -0
- homa/activations/SGELU.py +12 -0
- homa/activations/SReLU.py +37 -0
- homa/activations/SelfArctan.py +9 -0
- homa/activations/ShiftedReLU.py +10 -0
- homa/activations/SigmoidDerivative.py +10 -0
- homa/activations/SineReLU.py +11 -0
- homa/activations/SlopedReLU.py +13 -0
- homa/activations/SmallGaLU.py +11 -0
- homa/activations/Smish.py +9 -0
- homa/activations/SoftsignRReLU.py +17 -0
- homa/activations/Suish.py +11 -0
- homa/activations/TBSReLU.py +13 -0
- homa/activations/TSReLU.py +10 -0
- homa/activations/TangentBipolarSigmoidReLU.py +6 -0
- homa/activations/TangentSigmoidReLU.py +6 -0
- homa/activations/TeLU.py +9 -0
- homa/activations/TripleStateSwish.py +15 -0
- homa/activations/WideMeLU.py +15 -0
- homa/activations/__init__.py +49 -2
- homa/activations/learnable/AOAF.py +16 -0
- homa/activations/learnable/AReLU.py +16 -0
- homa/activations/learnable/DPReLU.py +16 -0
- homa/activations/learnable/DualLine.py +18 -0
- homa/activations/learnable/FReLU.py +14 -0
- homa/activations/learnable/LeLeLU.py +14 -0
- homa/activations/learnable/PERU.py +16 -0
- homa/activations/learnable/PiLU.py +18 -0
- homa/activations/learnable/ShiLU.py +16 -0
- homa/activations/learnable/StarReLU.py +16 -0
- homa/activations/learnable/__init__.py +10 -0
- homa/activations/learnable/concerns/ChannelBased.py +36 -0
- homa/activations/learnable/concerns/__init__.py +1 -0
- homa/cli/Commands/Command.py +2 -0
- homa/cli/Commands/InitCommand.py +34 -0
- homa/cli/Commands/__init__.py +2 -0
- homa/cli/HomaCommand.py +4 -0
- homa/ensemble/Ensemble.py +2 -4
- homa/ensemble/concerns/CalculatesMetricNecessities.py +14 -10
- homa/ensemble/concerns/PredictsProbabilities.py +4 -0
- homa/ensemble/concerns/ReportsClassificationMetrics.py +1 -1
- homa/ensemble/concerns/ReportsEnsembleAccuracy.py +3 -2
- homa/ensemble/concerns/ReportsLogits.py +4 -0
- homa/ensemble/concerns/ReportsSize.py +2 -2
- homa/ensemble/concerns/StoresModels.py +29 -0
- homa/ensemble/concerns/__init__.py +1 -2
- homa/loss/LogitNormLoss.py +12 -0
- homa/loss/Loss.py +2 -0
- homa/loss/__init__.py +2 -0
- homa/torch/__init__.py +0 -1
- homa/vision/Classifier.py +5 -0
- homa/vision/Resnet.py +6 -5
- homa/vision/StochasticClassifier.py +29 -0
- homa/vision/StochasticSwin.py +11 -0
- homa/vision/Swin.py +13 -0
- homa/vision/__init__.py +3 -1
- homa/vision/concerns/HasLabels.py +13 -0
- homa/vision/concerns/HasLogits.py +12 -0
- homa/vision/concerns/HasProbabilities.py +9 -0
- homa/vision/concerns/ReportsAccuracy.py +27 -0
- homa/vision/concerns/ReportsMetrics.py +6 -0
- homa/vision/concerns/Trainable.py +5 -2
- homa/vision/concerns/__init__.py +5 -0
- homa/vision/modules/SwinModule.py +23 -0
- homa/vision/modules/__init__.py +1 -1
- homa/vision/utils.py +9 -18
- homa-0.2.9.dist-info/METADATA +75 -0
- homa-0.2.9.dist-info/RECORD +113 -0
- homa/activations/classes/APLU.py +0 -48
- homa/activations/classes/GALU.py +0 -51
- homa/activations/classes/MELU.py +0 -50
- homa/activations/classes/PDELU.py +0 -39
- homa/activations/classes/SReLU.py +0 -49
- homa/activations/classes/SmallGALU.py +0 -39
- homa/activations/classes/StochasticActivation.py +0 -20
- homa/activations/classes/WideMELU.py +0 -61
- homa/activations/classes/__init__.py +0 -8
- homa/activations/utils.py +0 -27
- homa/ensemble/concerns/HasNetwork.py +0 -5
- homa/ensemble/concerns/HasStateDicts.py +0 -8
- homa/ensemble/concerns/RecordsStateDictionaries.py +0 -23
- homa/torch/Module.py +0 -8
- homa/vision/StochasticResnet.py +0 -8
- homa/vision/modules/StochasticResnetModule.py +0 -9
- homa-0.1.1.dist-info/METADATA +0 -21
- homa-0.1.1.dist-info/RECORD +0 -51
- {homa-0.1.1.dist-info → homa-0.2.9.dist-info}/WHEEL +0 -0
- {homa-0.1.1.dist-info → homa-0.2.9.dist-info}/entry_points.txt +0 -0
- {homa-0.1.1.dist-info → homa-0.2.9.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .ActivationFunction import ActivationFunction
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class TBSReLU(ActivationFunction):
    """Activation ``f(x) = x * tanh(bs(x))``.

    ``bs`` is the bipolar sigmoid ``(1 - exp(-x)) / (1 + exp(-x))``, which is
    algebraically identical to ``tanh(x / 2)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Evaluate the bipolar sigmoid as tanh(x / 2). The literal ratio
        # (1 - exp(-x)) / (1 + exp(-x)) overflows to inf / inf = NaN once
        # exp(-x) leaves the dtype's range (x below about -88 in float32).
        bipolar_sigmoid = torch.tanh(x / 2)
        return x * torch.tanh(bipolar_sigmoid)
|
homa/activations/TeLU.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .ActivationFunction import ActivationFunction
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class TripleStateSwish(ActivationFunction):
    """Activation ``f(x) = x * s(x) * (s(x) + s(x - alpha) + s(x - beta))``.

    ``s`` is the logistic sigmoid.

    Args:
        alpha: Shift of the second sigmoid term.
        beta: Shift of the third sigmoid term.
        *args, **kwargs: Forwarded to ``ActivationFunction``.
    """

    def __init__(self, alpha: float = 20, beta: float = 40, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.beta = beta

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.sigmoid(z) equals 1 / (1 + exp(-z)) but stays finite in both the
        # forward and backward pass. The hand-written form overflows exp(-z) for
        # very negative z, and its gradient then becomes 0 * inf = NaN.
        a = torch.sigmoid(x)
        b = torch.sigmoid(x - self.alpha)
        c = torch.sigmoid(x - self.beta)
        return x * a * (a + b + c)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .MexicanReLU import MexicanReLU
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class WideMeLU(MexicanReLU):
    """``MexicanReLU`` preset built from seven fixed hat tuples.

    The tuples are passed unchanged to ``MexicanReLU``. Their meaning
    (presumably (center, width)) is defined there; confirm against MexicanReLU.
    """

    def __init__(self, channels: int | None = None, max_input: float = 1.0):
        wide = [(2.0, 2.0)]
        unit = [(1.0, 1.0), (3.0, 1.0)]
        narrow = [(center, 0.5) for center in (0.5, 1.5, 2.5, 3.5)]
        # Stored on the instance before the base initializer runs.
        self.hats = wide + unit + narrow
        super().__init__(self.hats, channels=channels, max_input=max_input)
|
homa/activations/__init__.py
CHANGED
|
@@ -1,2 +1,49 @@
|
|
|
1
|
-
from .
|
|
2
|
-
from .
|
|
1
|
+
from .ShiftedReLU import ShiftedReLU
|
|
2
|
+
from .PLogish import PLogish
|
|
3
|
+
from .ParametricLogish import ParametricLogish
|
|
4
|
+
from .ExpExpish import ExpExpish
|
|
5
|
+
from .GeneralizedSwish import GeneralizedSwish
|
|
6
|
+
from .TBSReLU import TBSReLU
|
|
7
|
+
from .NoisyReLU import NoisyReLU
|
|
8
|
+
from .ExponentialDLReLU import ExponentialDLReLU
|
|
9
|
+
from .SReLU import SReLU
|
|
10
|
+
from .TangentSigmoidReLU import TangentSigmoidReLU
|
|
11
|
+
from .Phish import Phish
|
|
12
|
+
from .WideMeLU import WideMeLU
|
|
13
|
+
from .SelfArctan import SelfArctan
|
|
14
|
+
from .LogSigmoid import LogSigmoid
|
|
15
|
+
from .SlopedReLU import SlopedReLU
|
|
16
|
+
from .SmallGaLU import SmallGaLU
|
|
17
|
+
from .MinSin import MinSin
|
|
18
|
+
from .LaLU import LaLU
|
|
19
|
+
from .MexicanReLU import MexicanReLU
|
|
20
|
+
from .APLU import APLU
|
|
21
|
+
from .ERF import ERF
|
|
22
|
+
from .TangentBipolarSigmoidReLU import TangentBipolarSigmoidReLU
|
|
23
|
+
from .BaseDLReLU import BaseDLReLU
|
|
24
|
+
from .Logish import Logish
|
|
25
|
+
from .TripleStateSwish import TripleStateSwish
|
|
26
|
+
from .ExponentialSwish import ExponentialSwish
|
|
27
|
+
from .TeLU import TeLU
|
|
28
|
+
from .Elliot import Elliot
|
|
29
|
+
from .MeLU import MeLU
|
|
30
|
+
from .GaussianReLU import GaussianReLU
|
|
31
|
+
from .ActivationFunction import ActivationFunction
|
|
32
|
+
from .RReLU import RReLU
|
|
33
|
+
from .Suish import Suish
|
|
34
|
+
from .SoftsignRReLU import SoftsignRReLU
|
|
35
|
+
from .Gish import Gish
|
|
36
|
+
from .NReLU import NReLU
|
|
37
|
+
from .LogLogish import LogLogish
|
|
38
|
+
from .SGELU import SGELU
|
|
39
|
+
from .GaLU import GaLU
|
|
40
|
+
from .TSReLU import TSReLU
|
|
41
|
+
from .SineReLU import SineReLU
|
|
42
|
+
from .DLReLU import DLReLU
|
|
43
|
+
from .CaLU import CaLU
|
|
44
|
+
from .RandomizedSlopedReLU import RandomizedSlopedReLU
|
|
45
|
+
from .GCU import GCU
|
|
46
|
+
from .SigmoidDerivative import SigmoidDerivative
|
|
47
|
+
from .Smish import Smish
|
|
48
|
+
from .AdaptiveActivationFunction import AdaptiveActivationFunction
|
|
49
|
+
from .learnable import *
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .concerns import ChannelBased
|
|
3
|
+
from ..AdaptiveActivationFunction import AdaptiveActivationFunction
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class AOAF(AdaptiveActivationFunction, ChannelBased):
    """Offset ReLU with a learnable per-channel scale ``a``.

    ``f(x) = relu(x - b * a) + c * a``. ``a`` holds one value per channel and
    is created on the first forward pass through ``ChannelBased.initialize``.
    ``b`` and ``c`` are fixed constants.

    Args:
        b: Multiplier of ``a`` in the input offset.
        c: Multiplier of ``a`` in the output offset.
    """

    def __init__(self, b: float = 0.17, c: float = 0.17):
        super().__init__()
        self.a = None
        self.b = b
        self.c = c

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.initialize(x, "a")
        # parameter_shape is the broadcasting helper ChannelBased provides.
        # It yields (1, C, 1, ..., 1) so ``a`` lines up with x's channel axis.
        a = self.a.view(self.parameter_shape(x))
        return torch.relu(x - self.b * a) + self.c * a
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from ..AdaptiveActivationFunction import AdaptiveActivationFunction
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class AReLU(AdaptiveActivationFunction):
    """ReLU variant with two learnable scalar slopes.

    Positive inputs are scaled by ``1 + sigmoid(b)``. Negative inputs are
    scaled by ``a`` clamped to ``[0.01, 0.99]``.
    """

    def __init__(self):
        super().__init__()
        # nn.Parameter tracks gradients by default.
        self.a = torch.nn.Parameter(torch.tensor(0.9))
        self.b = torch.nn.Parameter(torch.tensor(2.0))

    def forward(self, z):
        above_zero = torch.relu(z)
        below_zero = -torch.relu(-z)
        return (1 + torch.sigmoid(self.b)) * above_zero + torch.clamp(
            self.a, 0.01, 0.99
        ) * below_zero
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from ..AdaptiveActivationFunction import AdaptiveActivationFunction
|
|
3
|
+
from .concerns import ChannelBased
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DPReLU(AdaptiveActivationFunction, ChannelBased):
    """ReLU with separate learnable per-channel slopes for each sign.

    ``f(x) = a * x`` for ``x >= 0`` and ``b * x`` otherwise. ``a`` and ``b`` are
    requested with initial values 1 and 0.01 on the first forward pass.
    """

    def __init__(self):
        super().__init__()
        self.a = None
        self.b = None

    def forward(self, x: torch.Tensor):
        self.initialize(x, ["a", "b"], [1, 0.01])
        shape = self.parameter_shape(x)
        positive_slope = self.a.view(shape)
        negative_slope = self.b.view(shape)
        return torch.where(x >= 0, positive_slope * x, negative_slope * x)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from ..AdaptiveActivationFunction import AdaptiveActivationFunction
|
|
3
|
+
from .concerns import ChannelBased
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DualLine(AdaptiveActivationFunction, ChannelBased):
    """Two lines sharing a learnable per-channel offset ``m``.

    ``f(x) = a * x + m`` for ``x >= 0`` and ``b * x + m`` otherwise. The
    parameters are requested with initial values 1, 0.01 and -0.22.
    """

    def __init__(self):
        super().__init__()
        self.a = None
        self.b = None
        self.m = None

    def forward(self, x: torch.Tensor):
        self.initialize(x, ["a", "b", "m"], [1, 0.01, -0.22])
        shape = self.parameter_shape(x)
        a, b, m = (param.view(shape) for param in (self.a, self.b, self.m))
        return torch.where(x >= 0, a * x + m, b * x + m)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from ..AdaptiveActivationFunction import AdaptiveActivationFunction
|
|
3
|
+
from .concerns import ChannelBased
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class FReLU(AdaptiveActivationFunction, ChannelBased):
    """ReLU plus a learnable per-channel bias ``b``.

    ``f(x) = x + b`` for ``x >= 0`` and ``b`` otherwise.
    """

    def __init__(self):
        super().__init__()
        self.b = None

    def forward(self, x: torch.Tensor):
        self.initialize(x, "b")
        bias = self.b.view(self.parameter_shape(x))
        shifted = x + bias
        return torch.where(x >= 0, shifted, bias)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from ..AdaptiveActivationFunction import AdaptiveActivationFunction
|
|
3
|
+
from .concerns import ChannelBased
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class LeLeLU(AdaptiveActivationFunction, ChannelBased):
    """Leaky ReLU whose overall scale ``a`` is a learnable per-channel parameter.

    ``f(x) = a * x`` for ``x >= 0`` and ``0.01 * a * x`` otherwise.
    """

    def __init__(self):
        super().__init__()
        self.a = None

    def forward(self, x: torch.Tensor):
        self.initialize(x, "a")
        scale = self.a.view(self.parameter_shape(x))
        scaled = scale * x
        return torch.where(x >= 0, scaled, 0.01 * scale * x)
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .concerns import ChannelBased
|
|
3
|
+
from ..AdaptiveActivationFunction import AdaptiveActivationFunction
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class PERU(AdaptiveActivationFunction, ChannelBased):
    """Piecewise activation with learnable per-channel ``a`` and ``b``.

    ``f(x) = a * x`` for ``x >= 0`` and ``a * x * exp(b * x)`` otherwise.
    """

    def __init__(self):
        super().__init__()
        self.a = None
        self.b = None

    def forward(self, x: torch.Tensor):
        self.initialize(x, ["a", "b"])
        a = self.a.view(self.parameter_shape(x))
        b = self.b.view(self.parameter_shape(x))
        linear = a * x
        # torch.where evaluates both branches. Feeding only the non-positive
        # part of x into exp keeps the unselected branch finite for large
        # positive x. Otherwise exp overflows to inf and backprop through
        # where produces 0 * inf = NaN gradients. For x < 0 the value is
        # unchanged.
        negative_branch = linear * torch.exp(b * torch.clamp(x, max=0))
        return torch.where(x >= 0, linear, negative_branch)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from ..AdaptiveActivationFunction import AdaptiveActivationFunction
|
|
3
|
+
from .concerns import ChannelBased
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class PiLU(AdaptiveActivationFunction, ChannelBased):
    """Piecewise-linear activation with a learnable per-channel breakpoint ``c``.

    ``f(x) = a * (x - c) + c`` for ``x >= c`` and ``b * (x - c) + c`` otherwise.
    The code uses the expanded forms ``a * x + c * (1 - a)`` and
    ``b * x + c * (1 - b)``. The parameters are requested with initial values
    1, 0.01 and 1.
    """

    def __init__(self):
        super().__init__()
        self.a = None
        self.b = None
        self.c = None

    def forward(self, x: torch.Tensor):
        self.initialize(x, ["a", "b", "c"], [1, 0.01, 1])
        a = self.a.view(self.parameter_shape(x))
        b = self.b.view(self.parameter_shape(x))
        c = self.c.view(self.parameter_shape(x))
        return torch.where(x >= c, a * x + c * (1 - a), b * x + c * (1 - b))


# This class was first published under the name ``DualLine``, which clashes with
# learnable/DualLine.py. Keep that name importable so existing
# ``from .PiLU import DualLine`` statements keep working.
DualLine = PiLU
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from ..ActivationFunction import ActivationFunction
|
|
3
|
+
from .concerns import ChannelBased
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ShiLU(ActivationFunction, ChannelBased):
    """Scaled and shifted ReLU with per-channel parameters.

    ``f(x) = relu(x) * a + b``. ``a`` and ``b`` are created on the first
    forward pass.
    """

    def __init__(self):
        super().__init__()
        self.a = None
        self.b = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.initialize(x, ["a", "b"])
        shape = self.parameter_shape(x)
        scale, shift = self.a.view(shape), self.b.view(shape)
        return torch.relu(x) * scale + shift
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from ..ActivationFunction import ActivationFunction
|
|
3
|
+
from .concerns import ChannelBased
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class StarReLU(ActivationFunction, ChannelBased):
    """Squared ReLU with a per-channel scale and bias.

    ``f(x) = a * relu(x) ** 2 + b``. ``a`` and ``b`` are created on the first
    forward pass.
    """

    def __init__(self):
        super().__init__()
        self.a = None
        self.b = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.initialize(x, ["a", "b"])
        shape = self.parameter_shape(x)
        scale = self.a.view(shape)
        shift = self.b.view(shape)
        squared = torch.relu(x).pow(2)
        return scale * squared + shift
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
from .StarReLU import StarReLU
from .DualLine import DualLine
from .LeLeLU import LeLeLU
from .AReLU import AReLU
from .PERU import PERU
from .ShiLU import ShiLU
from .DPReLU import DPReLU

# PiLU.py exposes its class under the name ``DualLine``. Import it under its
# own name so it no longer overwrites the real DualLine imported above.
from .PiLU import DualLine as PiLU
from .FReLU import FReLU
from .AOAF import AOAF
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class ChannelBased:
|
|
6
|
+
def __init__(self, *args, **kwargs):
|
|
7
|
+
super().__init__(*args, **kwargs)
|
|
8
|
+
self._initialized = False
|
|
9
|
+
self.num_channels = None
|
|
10
|
+
|
|
11
|
+
def initialize(
|
|
12
|
+
self, x: torch.Tensor, attrs: List[str] | str, values: List[float] | float = []
|
|
13
|
+
):
|
|
14
|
+
if getattr(self, "_initialized", False):
|
|
15
|
+
return
|
|
16
|
+
|
|
17
|
+
if not isinstance(values, list):
|
|
18
|
+
values = [values]
|
|
19
|
+
|
|
20
|
+
if not isinstance(attrs, list):
|
|
21
|
+
attrs = [attrs]
|
|
22
|
+
|
|
23
|
+
self.num_channels = x.shape[1]
|
|
24
|
+
for index, attr in enumerate(attrs):
|
|
25
|
+
if index < len(values) and values[index] is not None:
|
|
26
|
+
default_value = float(values[index])
|
|
27
|
+
else:
|
|
28
|
+
default_value = 1.0
|
|
29
|
+
param = torch.nn.Parameter(torch.full((self.num_channels,), default_value))
|
|
30
|
+
setattr(self, attr, param)
|
|
31
|
+
self._initialized = True
|
|
32
|
+
|
|
33
|
+
def parameter_shape(self, x: torch.Tensor) -> tuple | None:
|
|
34
|
+
if hasattr(self, "num_channels"):
|
|
35
|
+
return (1, self.num_channels) + (1,) * (x.ndim - 2)
|
|
36
|
+
return None
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .ChannelBased import ChannelBased
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from .Command import Command
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class InitCommand(Command):
    """Regenerate ``__init__.py`` in the current working directory.

    For every sibling ``*.py`` file, each top-level class or function whose
    name does not start with ``_`` is re-exported with a line of the form
    ``from .<module> import <name>``.
    """

    @staticmethod
    def run():
        path = Path(".")
        init_file = path / "__init__.py"
        lines = []
        # Sort so the generated file is deterministic; iterdir() order is arbitrary.
        for file in sorted(path.iterdir()):
            if file.name == "__init__.py" or file.suffix != ".py":
                continue
            tree = ast.parse(file.read_text(encoding="utf-8"), filename=str(file))
            classes = [
                node.name
                for node in tree.body
                if isinstance(node, ast.ClassDef) and not node.name.startswith("_")
            ]
            functions = [
                node.name
                for node in tree.body
                if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
                and not node.name.startswith("_")
            ]
            if not (classes or functions):
                continue
            module = file.stem
            lines.extend(
                f"from .{module} import {name}\n" for name in (*classes, *functions)
            )
            print(f"Processed {file}: classes={classes}, functions={functions}")
        # Write once, after every module has parsed. A syntax error in one file
        # then no longer leaves a truncated __init__.py behind.
        init_file.write_text("".join(lines), encoding="utf-8")
|
homa/cli/HomaCommand.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import fire
|
|
2
2
|
from .namespaces import MakeNamespace, CacheNamespace
|
|
3
|
+
from .Commands import InitCommand
|
|
3
4
|
|
|
4
5
|
|
|
5
6
|
class HomaCommand:
|
|
@@ -7,6 +8,9 @@ class HomaCommand:
|
|
|
7
8
|
self.make = MakeNamespace()
|
|
8
9
|
self.cache = CacheNamespace()
|
|
9
10
|
|
|
11
|
+
def init(self):
    """CLI ``init`` subcommand: regenerate ``__init__.py`` via ``InitCommand.run``."""
    runner = InitCommand.run
    runner()
|
|
13
|
+
|
|
10
14
|
|
|
11
15
|
def main():
|
|
12
16
|
fire.Fire(HomaCommand)
|
homa/ensemble/Ensemble.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
from .concerns import (
|
|
2
2
|
ReportsSize,
|
|
3
|
-
|
|
3
|
+
StoresModels,
|
|
4
4
|
ReportsClassificationMetrics,
|
|
5
|
-
HasNetwork,
|
|
6
5
|
PredictsProbabilities,
|
|
7
6
|
)
|
|
8
7
|
|
|
@@ -10,9 +9,8 @@ from .concerns import (
|
|
|
10
9
|
class Ensemble(
|
|
11
10
|
ReportsSize,
|
|
12
11
|
ReportsClassificationMetrics,
|
|
13
|
-
RecordsStateDictionaries,
|
|
14
12
|
PredictsProbabilities,
|
|
15
|
-
|
|
13
|
+
StoresModels,
|
|
16
14
|
):
|
|
17
15
|
def __init__(self):
|
|
18
16
|
super().__init__()
|
|
@@ -1,20 +1,24 @@
|
|
|
1
1
|
import torch
|
|
2
|
+
from ...device import get_device
|
|
2
3
|
|
|
3
4
|
|
|
4
5
|
class CalculatesMetricNecessities:
    """Mixin that collects ensemble predictions and true labels for metrics.

    Expects the host class to provide ``self.models``, a list of callable
    models (presumably filled by ``StoresModels``; confirm with callers).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @torch.no_grad()
    def metric_necessities(self, dataloader):
        """Run every stored model over ``dataloader``.

        For each batch, the ensemble prediction is the argmax over dim 1 of
        the sum of all models' outputs.

        Side effect: every model is moved to ``get_device()`` and put in
        eval mode.

        Returns:
            ``(predictions, labels)``, two lists with one entry per sample.

        Raises:
            ValueError: If the ensemble contains no models.
        """
        models = list(self.models)
        if not models:
            raise ValueError("cannot compute metrics for an ensemble with no models")

        device = get_device()
        # Device placement and eval mode do not change between batches, so
        # set them once rather than once per batch.
        for model in models:
            model.to(device)
            model.eval()

        predictions, labels = [], []
        for x, y in dataloader:
            x = x.to(device)
            sum_logits = None
            for model in models:
                logits = model(x)
                sum_logits = logits if sum_logits is None else sum_logits + logits
            predictions.extend(sum_logits.argmax(dim=1).cpu().numpy())
            # Labels are only needed on the host, so skip the trip to the device.
            labels.extend(y.cpu().numpy())
        return predictions, labels
|
|
@@ -9,3 +9,7 @@ class PredictsProbabilities(ReportsLogits):
|
|
|
9
9
|
def predict(self, x: torch.Tensor) -> torch.Tensor:
    """Return class probabilities: softmax over dim 1 of ``self.logits(x)``."""
    return torch.softmax(self.logits(x), dim=1)
|
|
12
|
+
|
|
13
|
+
@torch.no_grad()
def predict_(self, x: torch.Tensor) -> torch.Tensor:
    """Gradient-free variant of ``predict`` (runs under ``torch.no_grad``)."""
    probabilities = self.predict(x)
    return probabilities
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from sklearn.metrics import accuracy_score as accuracy
|
|
2
|
+
from torch.utils.data import DataLoader
|
|
2
3
|
|
|
3
4
|
|
|
4
5
|
class ReportsEnsembleAccuracy:
    """Mixin that reports ensemble accuracy on a dataloader.

    The host class must provide ``metric_necessities(dataloader)``, returning
    ``(predictions, labels)``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def accuracy(self, dataloader: DataLoader) -> float:
        """Return sklearn's ``accuracy_score`` of the ensemble over ``dataloader``."""
        y_pred, y_true = self.metric_necessities(dataloader)
        # ``accuracy`` here is the module-level sklearn import, not this method.
        return accuracy(y_true, y_pred)
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from copy import deepcopy
|
|
3
|
+
from typing import List
|
|
4
|
+
from ...vision import Model
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class StoresModels:
    """Mixin that keeps independent deep copies of ensemble members in ``self.models``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.models: List[torch.nn.Module] = []

    def record(self, model: Model | torch.nn.Module):
        """Append a deep copy of ``model`` to the ensemble.

        A ``Model`` contributes a copy of its ``.network``. A plain
        ``torch.nn.Module`` is copied as-is. Any other type raises TypeError.
        """
        if isinstance(model, Model):
            source = model.network
        elif isinstance(model, torch.nn.Module):
            source = model
        else:
            raise TypeError("Wrong input to ensemble record")
        # Copy so later training of the original does not alter the ensemble.
        self.models.append(deepcopy(source))

    def push(self, *args, **kwargs):
        """Alias of :meth:`record`."""
        self.record(*args, **kwargs)

    def append(self, *args, **kwargs):
        """Alias of :meth:`record`."""
        self.record(*args, **kwargs)

    def add(self, *args, **kwargs):
        """Alias of :meth:`record`."""
        self.record(*args, **kwargs)
|
|
@@ -1,10 +1,9 @@
|
|
|
1
1
|
from .CalculatesMetricNecessities import CalculatesMetricNecessities
|
|
2
|
-
from .HasNetwork import HasNetwork
|
|
3
2
|
from .PredictsProbabilities import PredictsProbabilities
|
|
4
|
-
from .RecordsStateDictionaries import RecordsStateDictionaries
|
|
5
3
|
from .ReportsClassificationMetrics import ReportsClassificationMetrics
|
|
6
4
|
from .ReportsEnsembleAccuracy import ReportsEnsembleAccuracy
|
|
7
5
|
from .ReportsEnsembleF1 import ReportsEnsembleF1
|
|
8
6
|
from .ReportsEnsembleKappa import ReportsEnsembleKappa
|
|
9
7
|
from .ReportsLogits import ReportsLogits
|
|
10
8
|
from .ReportsSize import ReportsSize
|
|
9
|
+
from .StoresModels import StoresModels
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .Loss import Loss
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class LogitNormLoss(Loss):
    """Cross-entropy on L2-normalised logits (LogitNorm).

    Each sample's logit vector is divided by its L2 norm along the class
    dimension (plus ``1e-7`` to avoid division by zero). It is then divided by
    ``temperature`` before standard cross-entropy is applied.

    Args:
        temperature: Divisor applied after normalisation; must be positive.
            The default ``1.0`` reproduces plain "normalise then
            cross-entropy". The LogitNorm formulation uses a tunable
            temperature here.
        *args, **kwargs: Forwarded to ``Loss``.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """

    def __init__(self, *args, temperature: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.temperature = temperature

    def forward(self, logits, target):
        # cross_entropy reads classes from dim 1 for batched input ((N, C) or
        # (N, C, d1, ...)) and from dim 0 for a single unbatched (C,) vector.
        # Normalising over dim -1 is only correct in the (N, C) / (C,) cases.
        class_dim = 1 if logits.dim() > 1 else 0
        norms = torch.norm(logits, p=2, dim=class_dim, keepdim=True) + 1e-7
        normalized_logits = torch.div(logits, norms) / self.temperature
        return torch.nn.functional.cross_entropy(normalized_logits, target)
|
homa/loss/Loss.py
ADDED
homa/loss/__init__.py
ADDED
homa/torch/__init__.py
CHANGED
homa/vision/Resnet.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
from .modules import ResnetModule
|
|
3
|
-
from .
|
|
4
|
-
from .concerns import Trainable
|
|
3
|
+
from .Classifier import Classifier
|
|
4
|
+
from .concerns import Trainable, ReportsMetrics
|
|
5
|
+
from ..device import get_device
|
|
5
6
|
|
|
6
7
|
|
|
7
|
-
class Resnet(
|
|
8
|
-
def __init__(self, num_classes: int, lr: float):
|
|
8
|
+
class Resnet(Classifier, Trainable, ReportsMetrics):
|
|
9
|
+
def __init__(self, num_classes: int, lr: float = 0.001):
|
|
9
10
|
super().__init__()
|
|
10
|
-
self.network = ResnetModule(num_classes)
|
|
11
|
+
self.network = ResnetModule(num_classes).to(get_device())
|
|
11
12
|
self.criterion = torch.nn.CrossEntropyLoss()
|
|
12
13
|
self.optimizer = torch.optim.SGD(self.network.parameters(), lr=lr, momentum=0.9)
|