homa 0.1.91__tar.gz → 0.1.94__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {homa-0.1.91 → homa-0.1.94}/PKG-INFO +1 -1
- {homa-0.1.91 → homa-0.1.94}/pyproject.toml +1 -1
- {homa-0.1.91 → homa-0.1.94}/src/homa/activations/classes/APLU.py +14 -8
- {homa-0.1.91 → homa-0.1.94}/src/homa/activations/classes/GALU.py +6 -4
- {homa-0.1.91 → homa-0.1.94}/src/homa/activations/classes/MELU.py +8 -6
- {homa-0.1.91 → homa-0.1.94}/src/homa/activations/classes/PDELU.py +5 -3
- {homa-0.1.91 → homa-0.1.94}/src/homa/activations/classes/SReLU.py +8 -4
- {homa-0.1.91 → homa-0.1.94}/src/homa/activations/classes/SmallGALU.py +4 -2
- {homa-0.1.91 → homa-0.1.94}/src/homa/activations/classes/WideMELU.py +10 -8
- {homa-0.1.91 → homa-0.1.94}/src/homa/ensemble/concerns/CalculatesMetricNecessities.py +2 -2
- homa-0.1.94/src/homa/loss/LogitNormLoss.py +12 -0
- homa-0.1.94/src/homa/loss/Loss.py +2 -0
- homa-0.1.94/src/homa/loss/__init__.py +2 -0
- homa-0.1.94/src/homa/vision/StochasticResnet.py +9 -0
- homa-0.1.94/src/homa/vision/StochasticSwin.py +9 -0
- homa-0.1.94/src/homa/vision/Swin.py +12 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/vision/__init__.py +1 -0
- homa-0.1.94/src/homa/vision/modules/SwinModule.py +23 -0
- homa-0.1.94/src/homa/vision/modules/__init__.py +2 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa.egg-info/PKG-INFO +1 -1
- {homa-0.1.91 → homa-0.1.94}/src/homa.egg-info/SOURCES.txt +6 -1
- homa-0.1.91/src/homa/vision/StochasticResnet.py +0 -8
- homa-0.1.91/src/homa/vision/modules/StochasticResnetModule.py +0 -9
- homa-0.1.91/src/homa/vision/modules/__init__.py +0 -2
- {homa-0.1.91 → homa-0.1.94}/README.md +0 -0
- {homa-0.1.91 → homa-0.1.94}/setup.cfg +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/__init__.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/activations/__init__.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/activations/classes/StochasticActivation.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/activations/classes/__init__.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/activations/utils.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/cli/HomaCommand.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/cli/namespaces/CacheNamespace.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/cli/namespaces/MakeNamespace.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/cli/namespaces/__init__.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/device.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/ensemble/Ensemble.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/ensemble/__init__.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/ensemble/concerns/PredictsProbabilities.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/ensemble/concerns/ReportsClassificationMetrics.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/ensemble/concerns/ReportsEnsembleAccuracy.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/ensemble/concerns/ReportsEnsembleF1.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/ensemble/concerns/ReportsEnsembleKappa.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/ensemble/concerns/ReportsLogits.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/ensemble/concerns/ReportsSize.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/ensemble/concerns/StoresModels.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/ensemble/concerns/__init__.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/settings.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/torch/__init__.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/torch/helpers.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/utils.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/vision/ClassificationModel.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/vision/Model.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/vision/Resnet.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/vision/concerns/HasLabels.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/vision/concerns/HasLogits.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/vision/concerns/HasProbabilities.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/vision/concerns/ReportsAccuracy.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/vision/concerns/ReportsMetrics.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/vision/concerns/Trainable.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/vision/concerns/__init__.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/vision/modules/ResnetModule.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa/vision/utils.py +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa.egg-info/dependency_links.txt +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa.egg-info/entry_points.txt +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa.egg-info/requires.txt +0 -0
- {homa-0.1.91 → homa-0.1.94}/src/homa.egg-info/top_level.txt +0 -0
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "homa"
|
|
7
|
-
version = "0.1.91"
|
|
7
|
+
version = "0.1.94"
|
|
8
8
|
description = "A curated list of machine learning and deep learning helpers."
|
|
9
9
|
authors = [
|
|
10
10
|
{ name="Taha Shieenavaz", email="tahashieenavaz@gmail.com" },
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import torch
|
|
2
|
+
from ...device import get_device
|
|
2
3
|
|
|
3
4
|
|
|
4
5
|
class APLU(torch.nn.Module):
|
|
@@ -12,6 +13,7 @@ class APLU(torch.nn.Module):
|
|
|
12
13
|
self.psi = None
|
|
13
14
|
self.mu = None
|
|
14
15
|
self._num_channels = None
|
|
16
|
+
self.device = get_device()
|
|
15
17
|
|
|
16
18
|
def _initialize_parameters(self, x):
|
|
17
19
|
if x.ndim < 2:
|
|
@@ -25,18 +27,23 @@ class APLU(torch.nn.Module):
|
|
|
25
27
|
param_shape = [1] * x.ndim
|
|
26
28
|
param_shape[1] = num_channels
|
|
27
29
|
|
|
28
|
-
self.alpha = torch.nn.Parameter(torch.zeros(param_shape))
|
|
29
|
-
self.beta = torch.nn.Parameter(torch.zeros(param_shape))
|
|
30
|
-
self.gamma = torch.nn.Parameter(torch.zeros(param_shape))
|
|
30
|
+
self.alpha = torch.nn.Parameter(torch.zeros(param_shape), device=device())
|
|
31
|
+
self.beta = torch.nn.Parameter(torch.zeros(param_shape), device=device())
|
|
32
|
+
self.gamma = torch.nn.Parameter(torch.zeros(param_shape), device=device())
|
|
31
33
|
|
|
32
|
-
self.xi = torch.nn.Parameter(self.max_input * torch.rand(param_shape))
|
|
33
|
-
|
|
34
|
-
|
|
34
|
+
self.xi = torch.nn.Parameter(self.max_input * torch.rand(param_shape)).to(
|
|
35
|
+
self.device
|
|
36
|
+
)
|
|
37
|
+
self.psi = torch.nn.Parameter(self.max_input * torch.rand(param_shape)).to(
|
|
38
|
+
self.device
|
|
39
|
+
)
|
|
40
|
+
self.mu = torch.nn.Parameter(self.max_input * torch.rand(param_shape)).to(
|
|
41
|
+
self.device
|
|
42
|
+
)
|
|
35
43
|
|
|
36
44
|
def forward(self, x):
|
|
37
45
|
if self.alpha is None:
|
|
38
46
|
self._initialize_parameters(x)
|
|
39
|
-
|
|
40
47
|
a = torch.relu(x)
|
|
41
48
|
|
|
42
49
|
# following are called hinges
|
|
@@ -44,5 +51,4 @@ class APLU(torch.nn.Module):
|
|
|
44
51
|
c = self.beta * torch.relu(-x + self.psi)
|
|
45
52
|
d = self.gamma * torch.relu(-x + self.mu)
|
|
46
53
|
z = a + b + c + d
|
|
47
|
-
|
|
48
54
|
return z
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import torch
|
|
2
|
+
from ...device import get_device
|
|
2
3
|
|
|
3
4
|
|
|
4
5
|
class GALU(torch.nn.Module):
|
|
@@ -12,6 +13,7 @@ class GALU(torch.nn.Module):
|
|
|
12
13
|
self.gamma = None
|
|
13
14
|
self.delta = None
|
|
14
15
|
self._num_channels = None
|
|
16
|
+
self.device = get_device()
|
|
15
17
|
|
|
16
18
|
def _initialize_parameters(self, x):
|
|
17
19
|
if x.ndim < 2:
|
|
@@ -23,10 +25,10 @@ class GALU(torch.nn.Module):
|
|
|
23
25
|
self._num_channels = num_channels
|
|
24
26
|
param_shape = [1] * x.ndim
|
|
25
27
|
param_shape[1] = num_channels
|
|
26
|
-
self.alpha = torch.nn.Parameter(torch.zeros(param_shape))
|
|
27
|
-
self.beta = torch.nn.Parameter(torch.zeros(param_shape))
|
|
28
|
-
self.gamma = torch.nn.Parameter(torch.zeros(param_shape))
|
|
29
|
-
self.delta = torch.nn.Parameter(torch.zeros(param_shape))
|
|
28
|
+
self.alpha = torch.nn.Parameter(torch.zeros(param_shape)).to(self.device)
|
|
29
|
+
self.beta = torch.nn.Parameter(torch.zeros(param_shape)).to(self.device)
|
|
30
|
+
self.gamma = torch.nn.Parameter(torch.zeros(param_shape)).to(self.device)
|
|
31
|
+
self.delta = torch.nn.Parameter(torch.zeros(param_shape)).to(self.device)
|
|
30
32
|
|
|
31
33
|
def forward(self, x):
|
|
32
34
|
if self.alpha is None:
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import torch
|
|
2
|
+
from ...device import get_device
|
|
2
3
|
|
|
3
4
|
|
|
4
5
|
class MELU(torch.nn.Module):
|
|
@@ -12,6 +13,7 @@ class MELU(torch.nn.Module):
|
|
|
12
13
|
self.xi = None
|
|
13
14
|
self.psi = None
|
|
14
15
|
self._initialized = False
|
|
16
|
+
self.device = get_device()
|
|
15
17
|
|
|
16
18
|
def _initialize_parameters(self, X: torch.Tensor):
|
|
17
19
|
if X.dim() != 4:
|
|
@@ -20,12 +22,12 @@ class MELU(torch.nn.Module):
|
|
|
20
22
|
)
|
|
21
23
|
num_channels = X.shape[1]
|
|
22
24
|
shape = (1, num_channels, 1, 1)
|
|
23
|
-
self.alpha = torch.nn.Parameter(torch.zeros(shape))
|
|
24
|
-
self.beta = torch.nn.Parameter(torch.zeros(shape))
|
|
25
|
-
self.gamma = torch.nn.Parameter(torch.zeros(shape))
|
|
26
|
-
self.delta = torch.nn.Parameter(torch.zeros(shape))
|
|
27
|
-
self.xi = torch.nn.Parameter(torch.zeros(shape))
|
|
28
|
-
self.psi = torch.nn.Parameter(torch.zeros(shape))
|
|
25
|
+
self.alpha = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
26
|
+
self.beta = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
27
|
+
self.gamma = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
28
|
+
self.delta = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
29
|
+
self.xi = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
30
|
+
self.psi = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
29
31
|
self._initialized = True
|
|
30
32
|
|
|
31
33
|
def forward(self, X: torch.Tensor) -> torch.Tensor:
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import torch
|
|
2
|
+
from ...device import get_device
|
|
2
3
|
|
|
3
4
|
|
|
4
5
|
class PDELU(torch.nn.Module):
|
|
@@ -12,6 +13,7 @@ class PDELU(torch.nn.Module):
|
|
|
12
13
|
self._power_val = 1.0 / (1.0 - self.theta)
|
|
13
14
|
self.alpha = torch.nn.UninitializedParameter()
|
|
14
15
|
self._num_channels = None
|
|
16
|
+
self.device = get_device()
|
|
15
17
|
|
|
16
18
|
def _initialize_parameters(self, x: torch.Tensor):
|
|
17
19
|
if x.ndim < 2:
|
|
@@ -23,14 +25,14 @@ class PDELU(torch.nn.Module):
|
|
|
23
25
|
self._num_channels = num_channels
|
|
24
26
|
param_shape = [1] * x.ndim
|
|
25
27
|
param_shape[1] = num_channels
|
|
26
|
-
init_tensor = torch.zeros(param_shape) + 0.1
|
|
27
|
-
self.alpha = torch.nn.Parameter(init_tensor)
|
|
28
|
+
init_tensor = torch.zeros(param_shape, device=self.device) + 0.1
|
|
29
|
+
self.alpha = torch.nn.Parameter(init_tensor).to(self.device)
|
|
28
30
|
|
|
29
31
|
def forward(self, x: torch.Tensor):
|
|
30
32
|
if self.alpha is None:
|
|
31
33
|
self._initialize_parameters(x)
|
|
32
34
|
|
|
33
|
-
zero = torch.tensor(0.0, device=
|
|
35
|
+
zero = torch.tensor(0.0, device=self.device, dtype=x.dtype)
|
|
34
36
|
positive_part = torch.relu(x)
|
|
35
37
|
inner_term = torch.relu(1.0 + (1.0 - self.theta) * x)
|
|
36
38
|
powered_term = torch.pow(inner_term, self._power_val)
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import torch
|
|
2
|
+
from ...device import get_device
|
|
2
3
|
|
|
3
4
|
|
|
4
5
|
class SReLU(torch.nn.Module):
|
|
@@ -18,6 +19,7 @@ class SReLU(torch.nn.Module):
|
|
|
18
19
|
self.beta = torch.nn.UninitializedParameter()
|
|
19
20
|
self.gamma = torch.nn.UninitializedParameter()
|
|
20
21
|
self.delta = torch.nn.UninitializedParameter()
|
|
22
|
+
self.device = get_device()
|
|
21
23
|
|
|
22
24
|
def _initialize_parameters(self, x: torch.Tensor):
|
|
23
25
|
if isinstance(self.alpha, torch.nn.UninitializedParameter):
|
|
@@ -31,14 +33,16 @@ class SReLU(torch.nn.Module):
|
|
|
31
33
|
param_shape[1] = num_channels
|
|
32
34
|
self.alpha = torch.nn.Parameter(
|
|
33
35
|
torch.full(param_shape, self.alpha_init_val)
|
|
34
|
-
)
|
|
35
|
-
self.beta = torch.nn.Parameter(
|
|
36
|
+
).to(self.device)
|
|
37
|
+
self.beta = torch.nn.Parameter(
|
|
38
|
+
torch.full(param_shape, self.beta_init_val)
|
|
39
|
+
).to(self.device)
|
|
36
40
|
self.gamma = torch.nn.Parameter(
|
|
37
41
|
torch.full(param_shape, self.gamma_init_val)
|
|
38
|
-
)
|
|
42
|
+
).to(self.device)
|
|
39
43
|
self.delta = torch.nn.Parameter(
|
|
40
44
|
torch.full(param_shape, self.delta_init_val)
|
|
41
|
-
)
|
|
45
|
+
).to(self.device)
|
|
42
46
|
|
|
43
47
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
44
48
|
self._initialize_parameters(x)
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import torch
|
|
2
|
+
from ...device import get_device
|
|
2
3
|
|
|
3
4
|
|
|
4
5
|
class SmallGALU(torch.nn.Module):
|
|
@@ -10,6 +11,7 @@ class SmallGALU(torch.nn.Module):
|
|
|
10
11
|
self.alpha = None
|
|
11
12
|
self.beta = None
|
|
12
13
|
self._num_channels = None
|
|
14
|
+
self.device = get_device()
|
|
13
15
|
|
|
14
16
|
def _initialize_parameters(self, x):
|
|
15
17
|
if x.ndim < 2:
|
|
@@ -21,8 +23,8 @@ class SmallGALU(torch.nn.Module):
|
|
|
21
23
|
self._num_channels = num_channels
|
|
22
24
|
param_shape = [1] * x.ndim
|
|
23
25
|
param_shape[1] = num_channels
|
|
24
|
-
self.alpha = torch.nn.Parameter(torch.zeros(param_shape))
|
|
25
|
-
self.beta = torch.nn.Parameter(torch.zeros(param_shape))
|
|
26
|
+
self.alpha = torch.nn.Parameter(torch.zeros(param_shape)).to(self.device)
|
|
27
|
+
self.beta = torch.nn.Parameter(torch.zeros(param_shape)).to(self.device)
|
|
26
28
|
|
|
27
29
|
def forward(self, x):
|
|
28
30
|
if self.alpha is None:
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import torch
|
|
2
|
+
from ...device import get_device
|
|
2
3
|
|
|
3
4
|
|
|
4
5
|
class WideMELU(torch.nn.Module):
|
|
@@ -14,6 +15,7 @@ class WideMELU(torch.nn.Module):
|
|
|
14
15
|
self.theta = None
|
|
15
16
|
self.lam = None
|
|
16
17
|
self._initialized = False
|
|
18
|
+
self.device = get_device()
|
|
17
19
|
|
|
18
20
|
def _initialize_parameters(self, X: torch.Tensor):
|
|
19
21
|
if X.dim() != 4:
|
|
@@ -24,14 +26,14 @@ class WideMELU(torch.nn.Module):
|
|
|
24
26
|
num_channels = X.shape[1]
|
|
25
27
|
shape = (1, num_channels, 1, 1)
|
|
26
28
|
|
|
27
|
-
self.alpha = torch.nn.Parameter(torch.zeros(shape))
|
|
28
|
-
self.beta = torch.nn.Parameter(torch.zeros(shape))
|
|
29
|
-
self.gamma = torch.nn.Parameter(torch.zeros(shape))
|
|
30
|
-
self.delta = torch.nn.Parameter(torch.zeros(shape))
|
|
31
|
-
self.xi = torch.nn.Parameter(torch.zeros(shape))
|
|
32
|
-
self.psi = torch.nn.Parameter(torch.zeros(shape))
|
|
33
|
-
self.theta = torch.nn.Parameter(torch.zeros(shape))
|
|
34
|
-
self.lam = torch.nn.Parameter(torch.zeros(shape))
|
|
29
|
+
self.alpha = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
30
|
+
self.beta = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
31
|
+
self.gamma = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
32
|
+
self.delta = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
33
|
+
self.xi = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
34
|
+
self.psi = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
35
|
+
self.theta = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
36
|
+
self.lam = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
35
37
|
self._initialized = True
|
|
36
38
|
|
|
37
39
|
def forward(self, X: torch.Tensor) -> torch.Tensor:
|
|
@@ -18,7 +18,7 @@ class CalculatesMetricNecessities:
|
|
|
18
18
|
model.eval()
|
|
19
19
|
logits = model(x)
|
|
20
20
|
sum_logits = logits if sum_logits is None else sum_logits + logits
|
|
21
|
-
|
|
22
|
-
predictions.extend(
|
|
21
|
+
batch_predictions = sum_logits.argmax(dim=1)
|
|
22
|
+
predictions.extend(batch_predictions.cpu().numpy())
|
|
23
23
|
labels.extend(y.cpu().numpy())
|
|
24
24
|
return predictions, labels
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .Loss import Loss
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class LogitNormLoss(Loss):
|
|
6
|
+
def __init__(self, *args, **kwargs):
|
|
7
|
+
super().__init__(*args, **kwargs)
|
|
8
|
+
|
|
9
|
+
def forward(self, logits, target):
|
|
10
|
+
norms = torch.norm(logits, p=2, dim=-1, keepdim=True) + 1e-7
|
|
11
|
+
normalized_logits = torch.div(logits, norms)
|
|
12
|
+
return torch.nn.functional.cross_entropy(normalized_logits, target)
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from .Resnet import Resnet
|
|
2
|
+
from .utils import replace_relu
|
|
3
|
+
from ..activations import StochasticActivation
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class StochasticResnet(Resnet):
|
|
7
|
+
def __init__(self, *args, **kwargs):
|
|
8
|
+
super().__init__(*args, **kwargs)
|
|
9
|
+
replace_relu(self.network, StochasticActivation)
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from .Swin import Swin
|
|
2
|
+
from .utils import replace_gelu
|
|
3
|
+
from ..activations import StochasticActivation
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class StochasticSwin(Swin):
|
|
7
|
+
def __init__(self, *args, **kwargs):
|
|
8
|
+
super().__init__(*args, **kwargs)
|
|
9
|
+
replace_gelu(self.network.encoder, StochasticActivation)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .ClassificationModel import ClassificationModel
|
|
3
|
+
from .concerns import Trainable, ReportsMetrics
|
|
4
|
+
from .modules import SwinModule
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Swin(ClassificationModel, Trainable, ReportsMetrics):
|
|
8
|
+
def __init__(self, num_classes: int, lr: float = 0.0001):
|
|
9
|
+
super().__init__()
|
|
10
|
+
self.network = SwinModule(num_classes=num_classes)
|
|
11
|
+
self.optimizer = torch.optim.AdamW(self.network.parameters(), lr=lr)
|
|
12
|
+
self.criterion = torch.nn.CrossEntropyLoss()
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torchvision.models import swin_v2_b
|
|
3
|
+
from torch.nn.init import kaiming_uniform_ as kaiming
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SwinModule(torch.nn.Module):
|
|
7
|
+
def __init__(self, num_classes: int):
|
|
8
|
+
super().__init__()
|
|
9
|
+
self.num_classes = num_classes
|
|
10
|
+
self._create_encoder()
|
|
11
|
+
self._create_fc()
|
|
12
|
+
|
|
13
|
+
def _create_encoder(self):
|
|
14
|
+
self.encoder = swin_v2_b(weights="DEFAULT")
|
|
15
|
+
self.encoder.head = torch.nn.Identity()
|
|
16
|
+
|
|
17
|
+
def _create_fc(self):
|
|
18
|
+
self.fc = torch.nn.Linear(1024, self.num_classes)
|
|
19
|
+
kaiming(self.fc.weight, mode="fan_in", nonlinearity="relu")
|
|
20
|
+
|
|
21
|
+
def forward(self, images: torch.Tensor):
|
|
22
|
+
features = self.encoder(images)
|
|
23
|
+
return self.fc(features)
|
|
@@ -37,12 +37,17 @@ src/homa/ensemble/concerns/ReportsLogits.py
|
|
|
37
37
|
src/homa/ensemble/concerns/ReportsSize.py
|
|
38
38
|
src/homa/ensemble/concerns/StoresModels.py
|
|
39
39
|
src/homa/ensemble/concerns/__init__.py
|
|
40
|
+
src/homa/loss/LogitNormLoss.py
|
|
41
|
+
src/homa/loss/Loss.py
|
|
42
|
+
src/homa/loss/__init__.py
|
|
40
43
|
src/homa/torch/__init__.py
|
|
41
44
|
src/homa/torch/helpers.py
|
|
42
45
|
src/homa/vision/ClassificationModel.py
|
|
43
46
|
src/homa/vision/Model.py
|
|
44
47
|
src/homa/vision/Resnet.py
|
|
45
48
|
src/homa/vision/StochasticResnet.py
|
|
49
|
+
src/homa/vision/StochasticSwin.py
|
|
50
|
+
src/homa/vision/Swin.py
|
|
46
51
|
src/homa/vision/__init__.py
|
|
47
52
|
src/homa/vision/utils.py
|
|
48
53
|
src/homa/vision/concerns/HasLabels.py
|
|
@@ -53,5 +58,5 @@ src/homa/vision/concerns/ReportsMetrics.py
|
|
|
53
58
|
src/homa/vision/concerns/Trainable.py
|
|
54
59
|
src/homa/vision/concerns/__init__.py
|
|
55
60
|
src/homa/vision/modules/ResnetModule.py
|
|
56
|
-
src/homa/vision/modules/StochasticResnetModule.py
|
|
61
|
+
src/homa/vision/modules/SwinModule.py
|
|
57
62
|
src/homa/vision/modules/__init__.py
|
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
from .ResnetModule import ResnetModule
|
|
2
|
-
from ..utils import replace_relu
|
|
3
|
-
from ...activations import StochasticActivation
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
class StochasticResnetModule(ResnetModule):
|
|
7
|
-
def __init__(self, *args, **kwargs):
|
|
8
|
-
super().__init__(*args, **kwargs)
|
|
9
|
-
replace_relu(self, StochasticActivation)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|