homa-0.0.19.tar.gz → homa-0.1.94.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. homa-0.1.94/PKG-INFO +75 -0
  2. homa-0.1.94/README.md +64 -0
  3. {homa-0.0.19 → homa-0.1.94}/pyproject.toml +1 -1
  4. {homa-0.0.19 → homa-0.1.94}/src/homa/activations/classes/APLU.py +14 -8
  5. {homa-0.0.19 → homa-0.1.94}/src/homa/activations/classes/GALU.py +6 -4
  6. {homa-0.0.19 → homa-0.1.94}/src/homa/activations/classes/MELU.py +8 -6
  7. {homa-0.0.19 → homa-0.1.94}/src/homa/activations/classes/PDELU.py +5 -3
  8. {homa-0.0.19 → homa-0.1.94}/src/homa/activations/classes/SReLU.py +8 -4
  9. {homa-0.0.19 → homa-0.1.94}/src/homa/activations/classes/SmallGALU.py +4 -2
  10. {homa-0.0.19 → homa-0.1.94}/src/homa/activations/classes/StochasticActivation.py +15 -1
  11. {homa-0.0.19 → homa-0.1.94}/src/homa/activations/classes/WideMELU.py +10 -8
  12. {homa-0.0.19 → homa-0.1.94}/src/homa/ensemble/Ensemble.py +2 -4
  13. homa-0.1.94/src/homa/ensemble/concerns/CalculatesMetricNecessities.py +24 -0
  14. {homa-0.0.19 → homa-0.1.94}/src/homa/ensemble/concerns/PredictsProbabilities.py +4 -0
  15. {homa-0.0.19 → homa-0.1.94}/src/homa/ensemble/concerns/ReportsEnsembleAccuracy.py +3 -2
  16. {homa-0.0.19 → homa-0.1.94}/src/homa/ensemble/concerns/ReportsLogits.py +4 -0
  17. {homa-0.0.19 → homa-0.1.94}/src/homa/ensemble/concerns/ReportsSize.py +2 -2
  18. homa-0.1.94/src/homa/ensemble/concerns/StoresModels.py +29 -0
  19. {homa-0.0.19 → homa-0.1.94}/src/homa/ensemble/concerns/__init__.py +1 -2
  20. homa-0.1.94/src/homa/loss/LogitNormLoss.py +12 -0
  21. homa-0.1.94/src/homa/loss/Loss.py +2 -0
  22. homa-0.1.94/src/homa/loss/__init__.py +2 -0
  23. homa-0.1.94/src/homa/torch/__init__.py +1 -0
  24. homa-0.1.94/src/homa/vision/ClassificationModel.py +5 -0
  25. homa-0.1.94/src/homa/vision/Resnet.py +13 -0
  26. homa-0.1.94/src/homa/vision/StochasticResnet.py +9 -0
  27. homa-0.1.94/src/homa/vision/StochasticSwin.py +9 -0
  28. homa-0.1.94/src/homa/vision/Swin.py +12 -0
  29. {homa-0.0.19 → homa-0.1.94}/src/homa/vision/__init__.py +1 -0
  30. homa-0.1.94/src/homa/vision/concerns/HasLabels.py +13 -0
  31. homa-0.1.94/src/homa/vision/concerns/HasLogits.py +12 -0
  32. homa-0.1.94/src/homa/vision/concerns/HasProbabilities.py +9 -0
  33. homa-0.1.94/src/homa/vision/concerns/ReportsAccuracy.py +27 -0
  34. homa-0.1.94/src/homa/vision/concerns/ReportsMetrics.py +6 -0
  35. {homa-0.0.19 → homa-0.1.94}/src/homa/vision/concerns/Trainable.py +5 -2
  36. homa-0.1.94/src/homa/vision/concerns/__init__.py +6 -0
  37. homa-0.1.94/src/homa/vision/modules/SwinModule.py +23 -0
  38. homa-0.1.94/src/homa/vision/modules/__init__.py +2 -0
  39. {homa-0.0.19 → homa-0.1.94}/src/homa/vision/utils.py +4 -0
  40. homa-0.1.94/src/homa.egg-info/PKG-INFO +75 -0
  41. {homa-0.0.19 → homa-0.1.94}/src/homa.egg-info/SOURCES.txt +14 -9
  42. homa-0.0.19/PKG-INFO +0 -21
  43. homa-0.0.19/README.md +0 -10
  44. homa-0.0.19/src/homa/ensemble/concerns/CalculatesMetricNecessities.py +0 -20
  45. homa-0.0.19/src/homa/ensemble/concerns/HasNetwork.py +0 -5
  46. homa-0.0.19/src/homa/ensemble/concerns/HasStateDicts.py +0 -8
  47. homa-0.0.19/src/homa/ensemble/concerns/RecordsStateDictionaries.py +0 -23
  48. homa-0.0.19/src/homa/torch/Module.py +0 -8
  49. homa-0.0.19/src/homa/torch/__init__.py +0 -2
  50. homa-0.0.19/src/homa/vision/Resnet.py +0 -12
  51. homa-0.0.19/src/homa/vision/StochasticResnet.py +0 -8
  52. homa-0.0.19/src/homa/vision/concerns/__init__.py +0 -1
  53. homa-0.0.19/src/homa/vision/modules/StochasticResnetModule.py +0 -9
  54. homa-0.0.19/src/homa/vision/modules/__init__.py +0 -2
  55. homa-0.0.19/src/homa.egg-info/PKG-INFO +0 -21
  56. homa-0.0.19/tests/test_ensemble.py +0 -28
  57. homa-0.0.19/tests/test_resnet.py +0 -21
  58. homa-0.0.19/tests/test_stochastic_resnet.py +0 -20
  59. {homa-0.0.19 → homa-0.1.94}/setup.cfg +0 -0
  60. {homa-0.0.19 → homa-0.1.94}/src/homa/__init__.py +0 -0
  61. {homa-0.0.19 → homa-0.1.94}/src/homa/activations/__init__.py +0 -0
  62. {homa-0.0.19 → homa-0.1.94}/src/homa/activations/classes/__init__.py +0 -0
  63. {homa-0.0.19 → homa-0.1.94}/src/homa/activations/utils.py +0 -0
  64. {homa-0.0.19 → homa-0.1.94}/src/homa/cli/HomaCommand.py +0 -0
  65. {homa-0.0.19 → homa-0.1.94}/src/homa/cli/namespaces/CacheNamespace.py +0 -0
  66. {homa-0.0.19 → homa-0.1.94}/src/homa/cli/namespaces/MakeNamespace.py +0 -0
  67. {homa-0.0.19 → homa-0.1.94}/src/homa/cli/namespaces/__init__.py +0 -0
  68. {homa-0.0.19 → homa-0.1.94}/src/homa/device.py +0 -0
  69. {homa-0.0.19 → homa-0.1.94}/src/homa/ensemble/__init__.py +0 -0
  70. {homa-0.0.19 → homa-0.1.94}/src/homa/ensemble/concerns/ReportsClassificationMetrics.py +1 -1
  71. {homa-0.0.19 → homa-0.1.94}/src/homa/ensemble/concerns/ReportsEnsembleF1.py +0 -0
  72. {homa-0.0.19 → homa-0.1.94}/src/homa/ensemble/concerns/ReportsEnsembleKappa.py +0 -0
  73. {homa-0.0.19 → homa-0.1.94}/src/homa/settings.py +0 -0
  74. {homa-0.0.19 → homa-0.1.94}/src/homa/torch/helpers.py +0 -0
  75. {homa-0.0.19 → homa-0.1.94}/src/homa/utils.py +0 -0
  76. {homa-0.0.19 → homa-0.1.94}/src/homa/vision/Model.py +0 -0
  77. {homa-0.0.19 → homa-0.1.94}/src/homa/vision/modules/ResnetModule.py +0 -0
  78. {homa-0.0.19 → homa-0.1.94}/src/homa.egg-info/dependency_links.txt +0 -0
  79. {homa-0.0.19 → homa-0.1.94}/src/homa.egg-info/entry_points.txt +0 -0
  80. {homa-0.0.19 → homa-0.1.94}/src/homa.egg-info/requires.txt +0 -0
  81. {homa-0.0.19 → homa-0.1.94}/src/homa.egg-info/top_level.txt +0 -0
homa-0.1.94/PKG-INFO ADDED
@@ -0,0 +1,75 @@
+ Metadata-Version: 2.4
+ Name: homa
+ Version: 0.1.94
+ Summary: A curated list of machine learning and deep learning helpers.
+ Author-email: Taha Shieenavaz <tahashieenavaz@gmail.com>
+ Requires-Python: >=3.7
+ Description-Content-Type: text/markdown
+ Requires-Dist: numpy
+ Requires-Dist: torch
+ Requires-Dist: fire
+
+ # Core
+
+ ### Device Management
+
+ ```py
+ from homa import cpu, mps, cuda, device
+
+ torch.tensor([1, 2, 3, 4, 5]).to(cpu())
+ torch.tensor([1, 2, 3, 4, 5]).to(cuda())
+ torch.tensor([1, 2, 3, 4, 5]).to(mps())
+ torch.tensor([1, 2, 3, 4, 5]).to(device())
+ ```
+
+ # Vision
+
+ ## Resnet
+
+ This is the standard ResNet50 module.
+
+ You can train the model with a `DataLoader` object.
+
+ ```py
+ from homa.vision import Resnet
+
+ model = Resnet(num_classes=10, lr=0.001)
+ for epoch in range(10):
+     model.train(train_dataloader)
+ ```
+
+ Similarly you can manually take care of decomposition of data from the `DataLoader`.
+
+ ```py
+ from homa.vision import Resnet
+
+ model = Resnet(num_classes=10, lr=0.001)
+ for epoch in range(10):
+     for x, y in train_dataloader:
+         model.train(x, y)
+ ```
+
+ ## StochasticResnet
+
+ This is a ResNet module whose activation functions are replaced from a pool of different activation functions randomly. Read more on the [(paper)](https://www.mdpi.com/1424-8220/22/16/6129).
+
+ You can train the model with a `DataLoader` object.
+
+ ```py
+ from homa.vision import StochasticResnet
+
+ model = StochasticResnet(num_classes=10, lr=0.001)
+ for epoch in range(10):
+     model.train(train_dataloader)
+ ```
+
+ Similarly you can manually take care of decomposition of data from the `DataLoader`.
+
+ ```py
+ from homa.vision import StochasticResnet
+
+ model = StochasticResnet(num_classes=10, lr=0.001)
+ for epoch in range(10):
+     for x, y in train_dataloader:
+         model.train(x, y)
+ ```
homa-0.1.94/README.md ADDED
@@ -0,0 +1,64 @@
+ # Core
+
+ ### Device Management
+
+ ```py
+ from homa import cpu, mps, cuda, device
+
+ torch.tensor([1, 2, 3, 4, 5]).to(cpu())
+ torch.tensor([1, 2, 3, 4, 5]).to(cuda())
+ torch.tensor([1, 2, 3, 4, 5]).to(mps())
+ torch.tensor([1, 2, 3, 4, 5]).to(device())
+ ```
+
+ # Vision
+
+ ## Resnet
+
+ This is the standard ResNet50 module.
+
+ You can train the model with a `DataLoader` object.
+
+ ```py
+ from homa.vision import Resnet
+
+ model = Resnet(num_classes=10, lr=0.001)
+ for epoch in range(10):
+     model.train(train_dataloader)
+ ```
+
+ Similarly you can manually take care of decomposition of data from the `DataLoader`.
+
+ ```py
+ from homa.vision import Resnet
+
+ model = Resnet(num_classes=10, lr=0.001)
+ for epoch in range(10):
+     for x, y in train_dataloader:
+         model.train(x, y)
+ ```
+
+ ## StochasticResnet
+
+ This is a ResNet module whose activation functions are replaced from a pool of different activation functions randomly. Read more on the [(paper)](https://www.mdpi.com/1424-8220/22/16/6129).
+
+ You can train the model with a `DataLoader` object.
+
+ ```py
+ from homa.vision import StochasticResnet
+
+ model = StochasticResnet(num_classes=10, lr=0.001)
+ for epoch in range(10):
+     model.train(train_dataloader)
+ ```
+
+ Similarly you can manually take care of decomposition of data from the `DataLoader`.
+
+ ```py
+ from homa.vision import StochasticResnet
+
+ model = StochasticResnet(num_classes=10, lr=0.001)
+ for epoch in range(10):
+     for x, y in train_dataloader:
+         model.train(x, y)
+ ```
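The device-management example above calls `torch.tensor` without importing torch. A self-contained version of the snippet, assuming the homa helpers return `torch.device` objects (with `device()` picking the best available backend):

```py
import torch

from homa import cpu, cuda, device, mps

# Assumed behavior: cpu()/cuda()/mps() name a specific backend,
# device() falls back to the best one available on this machine.
t = torch.tensor([1, 2, 3, 4, 5])
t = t.to(device())
print(t.device)
```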
{homa-0.0.19 → homa-0.1.94}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "homa"
- version = "0.0.19"
+ version = "0.1.94"
  description = "A curated list of machine learning and deep learning helpers."
  authors = [
      { name="Taha Shieenavaz", email="tahashieenavaz@gmail.com" },
{homa-0.0.19 → homa-0.1.94}/src/homa/activations/classes/APLU.py
@@ -1,4 +1,5 @@
  import torch
+ from ...device import get_device


  class APLU(torch.nn.Module):
@@ -12,6 +13,7 @@ class APLU(torch.nn.Module):
          self.psi = None
          self.mu = None
          self._num_channels = None
+         self.device = get_device()

      def _initialize_parameters(self, x):
          if x.ndim < 2:
@@ -25,18 +27,23 @@
          param_shape = [1] * x.ndim
          param_shape[1] = num_channels

-         self.alpha = torch.nn.Parameter(torch.zeros(param_shape))
-         self.beta = torch.nn.Parameter(torch.zeros(param_shape))
-         self.gamma = torch.nn.Parameter(torch.zeros(param_shape))
+         self.alpha = torch.nn.Parameter(torch.zeros(param_shape), device=device())
+         self.beta = torch.nn.Parameter(torch.zeros(param_shape), device=device())
+         self.gamma = torch.nn.Parameter(torch.zeros(param_shape), device=device())

-         self.xi = torch.nn.Parameter(self.max_input * torch.rand(param_shape))
-         self.psi = torch.nn.Parameter(self.max_input * torch.rand(param_shape))
-         self.mu = torch.nn.Parameter(self.max_input * torch.rand(param_shape))
+         self.xi = torch.nn.Parameter(self.max_input * torch.rand(param_shape)).to(
+             self.device
+         )
+         self.psi = torch.nn.Parameter(self.max_input * torch.rand(param_shape)).to(
+             self.device
+         )
+         self.mu = torch.nn.Parameter(self.max_input * torch.rand(param_shape)).to(
+             self.device
+         )

      def forward(self, x):
          if self.alpha is None:
              self._initialize_parameters(x)
-
          a = torch.relu(x)

          # following are called hinges
@@ -44,5 +51,4 @@
          c = self.beta * torch.relu(-x + self.psi)
          d = self.gamma * torch.relu(-x + self.mu)
          z = a + b + c + d
-
          return z
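A caveat on the pattern these hunks introduce: `torch.nn.Parameter` accepts only `data` and `requires_grad`, and calling `.to(...)` on a freshly built `Parameter` returns a plain tensor whenever an actual copy is made, so the attribute may silently stop being a registered module parameter. A minimal device-safe sketch (the helper name is hypothetical, not part of homa) allocates on the target device before wrapping:

```py
import torch

def make_param(shape, device, fill=0.0):
    # Hypothetical helper: building the tensor on the target device first
    # keeps the wrapped result a registered torch.nn.Parameter.
    return torch.nn.Parameter(torch.full(list(shape), fill, device=device))
```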
{homa-0.0.19 → homa-0.1.94}/src/homa/activations/classes/GALU.py
@@ -1,4 +1,5 @@
  import torch
+ from ...device import get_device


  class GALU(torch.nn.Module):
@@ -12,6 +13,7 @@ class GALU(torch.nn.Module):
          self.gamma = None
          self.delta = None
          self._num_channels = None
+         self.device = get_device()

      def _initialize_parameters(self, x):
          if x.ndim < 2:
@@ -23,10 +25,10 @@
          self._num_channels = num_channels
          param_shape = [1] * x.ndim
          param_shape[1] = num_channels
-         self.alpha = torch.nn.Parameter(torch.zeros(param_shape))
-         self.beta = torch.nn.Parameter(torch.zeros(param_shape))
-         self.gamma = torch.nn.Parameter(torch.zeros(param_shape))
-         self.delta = torch.nn.Parameter(torch.zeros(param_shape))
+         self.alpha = torch.nn.Parameter(torch.zeros(param_shape)).to(self.device)
+         self.beta = torch.nn.Parameter(torch.zeros(param_shape)).to(self.device)
+         self.gamma = torch.nn.Parameter(torch.zeros(param_shape)).to(self.device)
+         self.delta = torch.nn.Parameter(torch.zeros(param_shape)).to(self.device)

      def forward(self, x):
          if self.alpha is None:
{homa-0.0.19 → homa-0.1.94}/src/homa/activations/classes/MELU.py
@@ -1,4 +1,5 @@
  import torch
+ from ...device import get_device


  class MELU(torch.nn.Module):
@@ -12,6 +13,7 @@ class MELU(torch.nn.Module):
          self.xi = None
          self.psi = None
          self._initialized = False
+         self.device = get_device()

      def _initialize_parameters(self, X: torch.Tensor):
          if X.dim() != 4:
@@ -20,12 +22,12 @@
              )
          num_channels = X.shape[1]
          shape = (1, num_channels, 1, 1)
-         self.alpha = torch.nn.Parameter(torch.zeros(shape))
-         self.beta = torch.nn.Parameter(torch.zeros(shape))
-         self.gamma = torch.nn.Parameter(torch.zeros(shape))
-         self.delta = torch.nn.Parameter(torch.zeros(shape))
-         self.xi = torch.nn.Parameter(torch.zeros(shape))
-         self.psi = torch.nn.Parameter(torch.zeros(shape))
+         self.alpha = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
+         self.beta = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
+         self.gamma = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
+         self.delta = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
+         self.xi = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
+         self.psi = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
          self._initialized = True

      def forward(self, X: torch.Tensor) -> torch.Tensor:
{homa-0.0.19 → homa-0.1.94}/src/homa/activations/classes/PDELU.py
@@ -1,4 +1,5 @@
  import torch
+ from ...device import get_device


  class PDELU(torch.nn.Module):
@@ -12,6 +13,7 @@ class PDELU(torch.nn.Module):
          self._power_val = 1.0 / (1.0 - self.theta)
          self.alpha = torch.nn.UninitializedParameter()
          self._num_channels = None
+         self.device = get_device()

      def _initialize_parameters(self, x: torch.Tensor):
          if x.ndim < 2:
@@ -23,14 +25,14 @@
          self._num_channels = num_channels
          param_shape = [1] * x.ndim
          param_shape[1] = num_channels
-         init_tensor = torch.zeros(param_shape) + 0.1
-         self.alpha = torch.nn.Parameter(init_tensor)
+         init_tensor = torch.zeros(param_shape, device=self.device) + 0.1
+         self.alpha = torch.nn.Parameter(init_tensor).to(self.device)

      def forward(self, x: torch.Tensor):
          if self.alpha is None:
              self._initialize_parameters(x)

-         zero = torch.tensor(0.0, device=x.device, dtype=x.dtype)
+         zero = torch.tensor(0.0, device=self.device, dtype=x.dtype)
          positive_part = torch.relu(x)
          inner_term = torch.relu(1.0 + (1.0 - self.theta) * x)
          powered_term = torch.pow(inner_term, self._power_val)
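The hunk ends before the negative branch, but from `self._power_val = 1.0 / (1.0 - self.theta)` and the `positive_part`, `inner_term`, and `powered_term` pieces shown, the function being assembled is consistent with the parametric deformable ELU. A hedged reconstruction of the full activation:

```latex
f(x) =
\begin{cases}
  x, & x > 0 \\
  \alpha \left( \left[\, 1 + (1 - \theta)\, x \,\right]_{+}^{\,1/(1 - \theta)} - 1 \right), & x \le 0
\end{cases}
```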
{homa-0.0.19 → homa-0.1.94}/src/homa/activations/classes/SReLU.py
@@ -1,4 +1,5 @@
  import torch
+ from ...device import get_device


  class SReLU(torch.nn.Module):
@@ -18,6 +19,7 @@ class SReLU(torch.nn.Module):
          self.beta = torch.nn.UninitializedParameter()
          self.gamma = torch.nn.UninitializedParameter()
          self.delta = torch.nn.UninitializedParameter()
+         self.device = get_device()

      def _initialize_parameters(self, x: torch.Tensor):
          if isinstance(self.alpha, torch.nn.UninitializedParameter):
@@ -31,14 +33,16 @@
              param_shape[1] = num_channels
              self.alpha = torch.nn.Parameter(
                  torch.full(param_shape, self.alpha_init_val)
-             )
-             self.beta = torch.nn.Parameter(torch.full(param_shape, self.beta_init_val))
+             ).to(self.device)
+             self.beta = torch.nn.Parameter(
+                 torch.full(param_shape, self.beta_init_val)
+             ).to(self.device)
              self.gamma = torch.nn.Parameter(
                  torch.full(param_shape, self.gamma_init_val)
-             )
+             ).to(self.device)
              self.delta = torch.nn.Parameter(
                  torch.full(param_shape, self.delta_init_val)
-             )
+             ).to(self.device)

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          self._initialize_parameters(x)
{homa-0.0.19 → homa-0.1.94}/src/homa/activations/classes/SmallGALU.py
@@ -1,4 +1,5 @@
  import torch
+ from ...device import get_device


  class SmallGALU(torch.nn.Module):
@@ -10,6 +11,7 @@ class SmallGALU(torch.nn.Module):
          self.alpha = None
          self.beta = None
          self._num_channels = None
+         self.device = get_device()

      def _initialize_parameters(self, x):
          if x.ndim < 2:
@@ -21,8 +23,8 @@
          self._num_channels = num_channels
          param_shape = [1] * x.ndim
          param_shape[1] = num_channels
-         self.alpha = torch.nn.Parameter(torch.zeros(param_shape))
-         self.beta = torch.nn.Parameter(torch.zeros(param_shape))
+         self.alpha = torch.nn.Parameter(torch.zeros(param_shape)).to(self.device)
+         self.beta = torch.nn.Parameter(torch.zeros(param_shape)).to(self.device)

      def forward(self, x):
          if self.alpha is None:
{homa-0.0.19 → homa-0.1.94}/src/homa/activations/classes/StochasticActivation.py
@@ -13,7 +13,21 @@ from .SReLU import SReLU
  class StochasticActivation(torch.nn.Module):
      def __init__(self):
          super().__init__()
-         self.gate = random.choice([APLU, GALU, SmallGALU, MELU, WideMELU, PDELU, SReLU])
+         self.gate = random.choice(
+             [
+                 APLU,
+                 GALU,
+                 SmallGALU,
+                 MELU,
+                 WideMELU,
+                 PDELU,
+                 SReLU,
+                 torch.nn.ReLU,
+                 torch.nn.PReLU,
+                 torch.nn.LeakyReLU,
+                 torch.nn.ELU,
+             ]
+         )
          self.gate = self.gate()

      def forward(self, x: torch.Tensor) -> torch.Tensor:
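The widened pool now mixes homa's parametric activations with four stock torch modules. Because `random.choice` runs in `__init__`, each `StochasticActivation` draws a single gate at construction and reuses it on every forward pass; the randomness is per layer, not per call. A standalone illustration of that behavior:

```py
import random

import torch

pool = [torch.nn.ReLU, torch.nn.PReLU, torch.nn.LeakyReLU, torch.nn.ELU]

gate = random.choice(pool)()  # drawn once; every forward reuses this module
x = torch.tensor([-1.0, 0.5, 2.0])
print(type(gate).__name__, gate(x))
```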
{homa-0.0.19 → homa-0.1.94}/src/homa/activations/classes/WideMELU.py
@@ -1,4 +1,5 @@
  import torch
+ from ...device import get_device


  class WideMELU(torch.nn.Module):
@@ -14,6 +15,7 @@ class WideMELU(torch.nn.Module):
          self.theta = None
          self.lam = None
          self._initialized = False
+         self.device = get_device()

      def _initialize_parameters(self, X: torch.Tensor):
          if X.dim() != 4:
@@ -24,14 +26,14 @@
          num_channels = X.shape[1]
          shape = (1, num_channels, 1, 1)

-         self.alpha = torch.nn.Parameter(torch.zeros(shape))
-         self.beta = torch.nn.Parameter(torch.zeros(shape))
-         self.gamma = torch.nn.Parameter(torch.zeros(shape))
-         self.delta = torch.nn.Parameter(torch.zeros(shape))
-         self.xi = torch.nn.Parameter(torch.zeros(shape))
-         self.psi = torch.nn.Parameter(torch.zeros(shape))
-         self.theta = torch.nn.Parameter(torch.zeros(shape))
-         self.lam = torch.nn.Parameter(torch.zeros(shape))
+         self.alpha = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
+         self.beta = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
+         self.gamma = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
+         self.delta = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
+         self.xi = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
+         self.psi = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
+         self.theta = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
+         self.lam = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
          self._initialized = True

      def forward(self, X: torch.Tensor) -> torch.Tensor:
{homa-0.0.19 → homa-0.1.94}/src/homa/ensemble/Ensemble.py
@@ -1,8 +1,7 @@
  from .concerns import (
      ReportsSize,
-     RecordsStateDictionaries,
+     StoresModels,
      ReportsClassificationMetrics,
-     HasNetwork,
      PredictsProbabilities,
  )

@@ -10,9 +9,8 @@ from .concerns import (
  class Ensemble(
      ReportsSize,
      ReportsClassificationMetrics,
-     RecordsStateDictionaries,
      PredictsProbabilities,
-     HasNetwork,
+     StoresModels,
  ):
      def __init__(self):
          super().__init__()
homa-0.1.94/src/homa/ensemble/concerns/CalculatesMetricNecessities.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+ from ...device import get_device
+
+
+ class CalculatesMetricNecessities:
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     @torch.no_grad()
+     def metric_necessities(self, dataloader):
+         predictions, labels = [], []
+         device = get_device()
+         for x, y in dataloader:
+             x, y = x.to(device), y.to(device)
+             sum_logits = None
+             for model in self.models:
+                 model.to(device)
+                 model.eval()
+                 logits = model(x)
+                 sum_logits = logits if sum_logits is None else sum_logits + logits
+             batch_predictions = sum_logits.argmax(dim=1)
+             predictions.extend(batch_predictions.cpu().numpy())
+             labels.extend(y.cpu().numpy())
+         return predictions, labels
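The new concern sums raw logits over the stored models and takes the argmax of the sum; since dividing by the model count does not change the argmax, this is exactly an average-logit ensemble. A quick check of the equivalence:

```py
import torch

logits = [torch.randn(4, 3) for _ in range(5)]  # 5 hypothetical model outputs
stacked = torch.stack(logits)

# argmax of the summed logits equals argmax of the mean logits
assert torch.equal(stacked.sum(dim=0).argmax(dim=1),
                   stacked.mean(dim=0).argmax(dim=1))
```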
{homa-0.0.19 → homa-0.1.94}/src/homa/ensemble/concerns/PredictsProbabilities.py
@@ -9,3 +9,7 @@ class PredictsProbabilities(ReportsLogits):
      def predict(self, x: torch.Tensor) -> torch.Tensor:
          logits = self.logits(x)
          return torch.nn.functional.softmax(logits, dim=1)
+
+     @torch.no_grad()
+     def predict_(self, x: torch.Tensor) -> torch.Tensor:
+         return self.predict(x)
{homa-0.0.19 → homa-0.1.94}/src/homa/ensemble/concerns/ReportsEnsembleAccuracy.py
@@ -1,10 +1,11 @@
  from sklearn.metrics import accuracy_score as accuracy
+ from torch.utils.data import DataLoader


  class ReportsEnsembleAccuracy:
      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)

-     def accuracy(self) -> float:
-         predictions, labels = self.metric_necessities()
+     def accuracy(self, dataloader: DataLoader) -> float:
+         predictions, labels = self.metric_necessities(dataloader)
          return accuracy(labels, predictions)
{homa-0.0.19 → homa-0.1.94}/src/homa/ensemble/concerns/ReportsLogits.py
@@ -11,3 +11,7 @@ class ReportsLogits:
          for model in self.models:
              logits += model(x)
          return logits
+
+     @torch.no_grad()
+     def logits_(self, *args, **kwargs):
+         return self.logits(*args, **kwargs)
{homa-0.0.19 → homa-0.1.94}/src/homa/ensemble/concerns/ReportsSize.py
@@ -4,8 +4,8 @@ class ReportsSize:

      @property
      def size(self):
-         return len(self.state_dicts)
+         return len(self.models)

      @property
      def length(self):
-         return len(self.state_dicts)
+         return len(self.models)
homa-0.1.94/src/homa/ensemble/concerns/StoresModels.py ADDED
@@ -0,0 +1,29 @@
+ import torch
+ from copy import deepcopy
+ from typing import List
+ from ...vision import Model
+
+
+ class StoresModels:
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.models: List[torch.nn.Module] = []
+
+     def record(self, model: Model | torch.nn.Module):
+         model_: torch.nn.Module | None = None
+         if isinstance(model, Model):
+             model_ = deepcopy(model.network)
+         elif isinstance(model, torch.nn.Module):
+             model_ = deepcopy(model)
+         else:
+             raise TypeError("Wrong input to ensemble record")
+         self.models.append(model_)
+
+     def push(self, *args, **kwargs):
+         self.record(*args, **kwargs)
+
+     def append(self, *args, **kwargs):
+         self.record(*args, **kwargs)
+
+     def add(self, *args, **kwargs):
+         self.record(*args, **kwargs)
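`StoresModels` replaces the old state-dict bookkeeping (`RecordsStateDictionaries` and `HasNetwork`) with a list of deep-copied networks; `push`, `append`, and `add` are aliases of `record`. Note the PEP 604 union `Model | torch.nn.Module` needs Python 3.10+, while the metadata declares `Requires-Python: >=3.7`. A hedged usage sketch, assuming `Ensemble` is exported from `homa.ensemble`:

```py
from homa.ensemble import Ensemble
from homa.vision import Resnet

ensemble = Ensemble()
for _ in range(3):
    model = Resnet(num_classes=10, lr=0.001)
    # ... train the model here ...
    ensemble.record(model)  # deep-copies model.network into the pool

print(ensemble.size)  # 3, via ReportsSize
# metrics run over a DataLoader through CalculatesMetricNecessities, e.g.:
# print(ensemble.accuracy(test_dataloader))
```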
{homa-0.0.19 → homa-0.1.94}/src/homa/ensemble/concerns/__init__.py
@@ -1,10 +1,9 @@
  from .CalculatesMetricNecessities import CalculatesMetricNecessities
- from .HasNetwork import HasNetwork
  from .PredictsProbabilities import PredictsProbabilities
- from .RecordsStateDictionaries import RecordsStateDictionaries
  from .ReportsClassificationMetrics import ReportsClassificationMetrics
  from .ReportsEnsembleAccuracy import ReportsEnsembleAccuracy
  from .ReportsEnsembleF1 import ReportsEnsembleF1
  from .ReportsEnsembleKappa import ReportsEnsembleKappa
  from .ReportsLogits import ReportsLogits
  from .ReportsSize import ReportsSize
+ from .StoresModels import StoresModels
homa-0.1.94/src/homa/loss/LogitNormLoss.py ADDED
@@ -0,0 +1,12 @@
+ import torch
+ from .Loss import Loss
+
+
+ class LogitNormLoss(Loss):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def forward(self, logits, target):
+         norms = torch.norm(logits, p=2, dim=-1, keepdim=True) + 1e-7
+         normalized_logits = torch.div(logits, norms)
+         return torch.nn.functional.cross_entropy(normalized_logits, target)
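`LogitNormLoss` rescales each logit vector to (approximately) unit L2 norm before a standard cross-entropy, a common recipe for curbing overconfident logits. Because the `Loss` base class added below is a plain class rather than a `torch.nn.Module`, instances are not callable; a usage sketch invoking `forward` explicitly:

```py
import torch
from homa.loss import LogitNormLoss

criterion = LogitNormLoss()
logits = torch.randn(8, 10)  # batch of 8, 10 classes
target = torch.randint(0, 10, (8,))
loss = criterion.forward(logits, target)  # explicit call: Loss is not an nn.Module
```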
homa-0.1.94/src/homa/loss/Loss.py ADDED
@@ -0,0 +1,2 @@
+ class Loss:
+     pass
homa-0.1.94/src/homa/loss/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .LogitNormLoss import LogitNormLoss
+ from .Loss import Loss
homa-0.1.94/src/homa/torch/__init__.py ADDED
@@ -0,0 +1 @@
+ from .helpers import *
homa-0.1.94/src/homa/vision/ClassificationModel.py ADDED
@@ -0,0 +1,5 @@
+ from .Model import Model
+
+
+ class ClassificationModel(Model):
+     pass
homa-0.1.94/src/homa/vision/Resnet.py ADDED
@@ -0,0 +1,13 @@
+ import torch
+ from .modules import ResnetModule
+ from .ClassificationModel import ClassificationModel
+ from .concerns import Trainable, ReportsMetrics
+ from ..device import get_device
+
+
+ class Resnet(ClassificationModel, Trainable, ReportsMetrics):
+     def __init__(self, num_classes: int, lr: float = 0.001):
+         super().__init__()
+         self.network = ResnetModule(num_classes).to(get_device())
+         self.criterion = torch.nn.CrossEntropyLoss()
+         self.optimizer = torch.optim.SGD(self.network.parameters(), lr=lr, momentum=0.9)
homa-0.1.94/src/homa/vision/StochasticResnet.py ADDED
@@ -0,0 +1,9 @@
+ from .Resnet import Resnet
+ from .utils import replace_relu
+ from ..activations import StochasticActivation
+
+
+ class StochasticResnet(Resnet):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         replace_relu(self.network, StochasticActivation)
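`replace_relu` lives in `src/homa/vision/utils.py`, whose diff is not shown in full here. A hedged sketch of what such a helper typically does (the name `replace_relu_sketch` is hypothetical, not homa's implementation): it recursively swaps every `nn.ReLU` for a fresh instance of the replacement class, so with `StochasticActivation` each call site draws its own activation:

```py
import torch

def replace_relu_sketch(module: torch.nn.Module, Act) -> None:
    # Recursively replace every nn.ReLU child with a fresh Act() instance.
    for name, child in module.named_children():
        if isinstance(child, torch.nn.ReLU):
            setattr(module, name, Act())
        else:
            replace_relu_sketch(child, Act)
```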
homa-0.1.94/src/homa/vision/StochasticSwin.py ADDED
@@ -0,0 +1,9 @@
+ from .Swin import Swin
+ from .utils import replace_gelu
+ from ..activations import StochasticActivation
+
+
+ class StochasticSwin(Swin):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         replace_gelu(self.network.encoder, StochasticActivation)
homa-0.1.94/src/homa/vision/Swin.py ADDED
@@ -0,0 +1,12 @@
+ import torch
+ from .ClassificationModel import ClassificationModel
+ from .concerns import Trainable, ReportsMetrics
+ from .modules import SwinModule
+
+
+ class Swin(ClassificationModel, Trainable, ReportsMetrics):
+     def __init__(self, num_classes: int, lr: float = 0.0001):
+         super().__init__()
+         self.network = SwinModule(num_classes=num_classes)
+         self.optimizer = torch.optim.AdamW(self.network.parameters(), lr=lr)
+         self.criterion = torch.nn.CrossEntropyLoss()
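`Swin` mirrors `Resnet` but pairs `SwinModule` with `AdamW` at a smaller default learning rate (1e-4 versus SGD with momentum at 1e-3), and, unlike `Resnet`, does not move the network to a device in `__init__`. A usage sketch in the style of the README examples above (`train_dataloader` is assumed to exist):

```py
from homa.vision import Swin

model = Swin(num_classes=10, lr=0.0001)
for epoch in range(10):
    model.train(train_dataloader)
```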
{homa-0.0.19 → homa-0.1.94}/src/homa/vision/__init__.py
@@ -1,3 +1,4 @@
  from .Model import Model
  from .Resnet import Resnet
  from .StochasticResnet import StochasticResnet
+ from .Swin import Swin
homa-0.1.94/src/homa/vision/concerns/HasLabels.py ADDED
@@ -0,0 +1,13 @@
+ import torch
+
+
+ class HasLabels:
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def predict(self, x: torch.Tensor):
+         return torch.argmax(self.logits(x), dim=1)
+
+     @torch.no_grad()
+     def predict_(self, x: torch.Tensor):
+         return torch.argmax(self.logits(x), dim=1)
homa-0.1.94/src/homa/vision/concerns/HasLogits.py ADDED
@@ -0,0 +1,12 @@
+ import torch
+
+
+ class HasLogits:
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def logits(self, x: torch.Tensor) -> torch.Tensor:
+         return self.network(x)
+
+     def logits_(self, x: torch.Tensor) -> torch.Tensor:
+         return self.network(x)
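Across the new concerns, trailing-underscore methods (`predict_`, `logits_`) are inference variants of their plain counterparts wrapped in `@torch.no_grad()`; `HasLogits.logits_` above repeats the body without the decorator, so it still records gradients. A standalone demonstration of the convention:

```py
import torch

net = torch.nn.Linear(4, 3)  # stand-in for self.network

def logits(x):
    return net(x)

@torch.no_grad()  # the convention behind the trailing-underscore variants
def logits_(x):
    return logits(x)

x = torch.randn(2, 4)
print(logits(x).requires_grad)   # True: autograd graph is recorded
print(logits_(x).requires_grad)  # False: no graph under no_grad
```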