homa 0.1.1__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. homa/activations/APLU.py +49 -0
  2. homa/activations/ActivationFunction.py +6 -0
  3. homa/activations/AdaptiveActivationFunction.py +15 -0
  4. homa/activations/BaseDLReLU.py +34 -0
  5. homa/activations/CaLU.py +13 -0
  6. homa/activations/DLReLU.py +6 -0
  7. homa/activations/ERF.py +10 -0
  8. homa/activations/Elliot.py +10 -0
  9. homa/activations/ExpExpish.py +9 -0
  10. homa/activations/ExponentialDLReLU.py +6 -0
  11. homa/activations/ExponentialSwish.py +10 -0
  12. homa/activations/GCU.py +9 -0
  13. homa/activations/GaLU.py +11 -0
  14. homa/activations/GaussianReLU.py +50 -0
  15. homa/activations/GeneralizedSwish.py +10 -0
  16. homa/activations/Gish.py +11 -0
  17. homa/activations/LaLU.py +11 -0
  18. homa/activations/LogLogish.py +10 -0
  19. homa/activations/LogSigmoid.py +10 -0
  20. homa/activations/Logish.py +10 -0
  21. homa/activations/MeLU.py +11 -0
  22. homa/activations/MexicanReLU.py +49 -0
  23. homa/activations/MinSin.py +10 -0
  24. homa/activations/NReLU.py +12 -0
  25. homa/activations/NoisyReLU.py +6 -0
  26. homa/activations/PLogish.py +6 -0
  27. homa/activations/ParametricLogish.py +13 -0
  28. homa/activations/Phish.py +11 -0
  29. homa/activations/RReLU.py +16 -0
  30. homa/activations/RandomizedSlopedReLU.py +7 -0
  31. homa/activations/SGELU.py +12 -0
  32. homa/activations/SReLU.py +37 -0
  33. homa/activations/SelfArctan.py +9 -0
  34. homa/activations/ShiftedReLU.py +10 -0
  35. homa/activations/SigmoidDerivative.py +10 -0
  36. homa/activations/SineReLU.py +11 -0
  37. homa/activations/SlopedReLU.py +13 -0
  38. homa/activations/SmallGaLU.py +11 -0
  39. homa/activations/Smish.py +9 -0
  40. homa/activations/SoftsignRReLU.py +17 -0
  41. homa/activations/Suish.py +11 -0
  42. homa/activations/TBSReLU.py +13 -0
  43. homa/activations/TSReLU.py +10 -0
  44. homa/activations/TangentBipolarSigmoidReLU.py +6 -0
  45. homa/activations/TangentSigmoidReLU.py +6 -0
  46. homa/activations/TeLU.py +9 -0
  47. homa/activations/TripleStateSwish.py +15 -0
  48. homa/activations/WideMeLU.py +15 -0
  49. homa/activations/__init__.py +49 -2
  50. homa/activations/learnable/AOAF.py +16 -0
  51. homa/activations/learnable/AReLU.py +16 -0
  52. homa/activations/learnable/DPReLU.py +16 -0
  53. homa/activations/learnable/DualLine.py +18 -0
  54. homa/activations/learnable/FReLU.py +14 -0
  55. homa/activations/learnable/LeLeLU.py +14 -0
  56. homa/activations/learnable/PERU.py +16 -0
  57. homa/activations/learnable/PiLU.py +18 -0
  58. homa/activations/learnable/ShiLU.py +16 -0
  59. homa/activations/learnable/StarReLU.py +16 -0
  60. homa/activations/learnable/__init__.py +10 -0
  61. homa/activations/learnable/concerns/ChannelBased.py +36 -0
  62. homa/activations/learnable/concerns/__init__.py +1 -0
  63. homa/cli/Commands/Command.py +2 -0
  64. homa/cli/Commands/InitCommand.py +34 -0
  65. homa/cli/Commands/__init__.py +2 -0
  66. homa/cli/HomaCommand.py +4 -0
  67. homa/ensemble/Ensemble.py +2 -4
  68. homa/ensemble/concerns/CalculatesMetricNecessities.py +14 -10
  69. homa/ensemble/concerns/PredictsProbabilities.py +4 -0
  70. homa/ensemble/concerns/ReportsClassificationMetrics.py +1 -1
  71. homa/ensemble/concerns/ReportsEnsembleAccuracy.py +3 -2
  72. homa/ensemble/concerns/ReportsLogits.py +4 -0
  73. homa/ensemble/concerns/ReportsSize.py +2 -2
  74. homa/ensemble/concerns/StoresModels.py +29 -0
  75. homa/ensemble/concerns/__init__.py +1 -2
  76. homa/loss/LogitNormLoss.py +12 -0
  77. homa/loss/Loss.py +2 -0
  78. homa/loss/__init__.py +2 -0
  79. homa/torch/__init__.py +0 -1
  80. homa/vision/Classifier.py +5 -0
  81. homa/vision/Resnet.py +6 -5
  82. homa/vision/StochasticClassifier.py +29 -0
  83. homa/vision/StochasticSwin.py +11 -0
  84. homa/vision/Swin.py +13 -0
  85. homa/vision/__init__.py +3 -1
  86. homa/vision/concerns/HasLabels.py +13 -0
  87. homa/vision/concerns/HasLogits.py +12 -0
  88. homa/vision/concerns/HasProbabilities.py +9 -0
  89. homa/vision/concerns/ReportsAccuracy.py +27 -0
  90. homa/vision/concerns/ReportsMetrics.py +6 -0
  91. homa/vision/concerns/Trainable.py +5 -2
  92. homa/vision/concerns/__init__.py +5 -0
  93. homa/vision/modules/SwinModule.py +23 -0
  94. homa/vision/modules/__init__.py +1 -1
  95. homa/vision/utils.py +9 -18
  96. homa-0.2.9.dist-info/METADATA +75 -0
  97. homa-0.2.9.dist-info/RECORD +113 -0
  98. homa/activations/classes/APLU.py +0 -48
  99. homa/activations/classes/GALU.py +0 -51
  100. homa/activations/classes/MELU.py +0 -50
  101. homa/activations/classes/PDELU.py +0 -39
  102. homa/activations/classes/SReLU.py +0 -49
  103. homa/activations/classes/SmallGALU.py +0 -39
  104. homa/activations/classes/StochasticActivation.py +0 -20
  105. homa/activations/classes/WideMELU.py +0 -61
  106. homa/activations/classes/__init__.py +0 -8
  107. homa/activations/utils.py +0 -27
  108. homa/ensemble/concerns/HasNetwork.py +0 -5
  109. homa/ensemble/concerns/HasStateDicts.py +0 -8
  110. homa/ensemble/concerns/RecordsStateDictionaries.py +0 -23
  111. homa/torch/Module.py +0 -8
  112. homa/vision/StochasticResnet.py +0 -8
  113. homa/vision/modules/StochasticResnetModule.py +0 -9
  114. homa-0.1.1.dist-info/METADATA +0 -21
  115. homa-0.1.1.dist-info/RECORD +0 -51
  116. {homa-0.1.1.dist-info → homa-0.2.9.dist-info}/WHEEL +0 -0
  117. {homa-0.1.1.dist-info → homa-0.2.9.dist-info}/entry_points.txt +0 -0
  118. {homa-0.1.1.dist-info → homa-0.2.9.dist-info}/top_level.txt +0 -0
homa/activations/TBSReLU.py ADDED
@@ -0,0 +1,13 @@
+ import torch
+ from .ActivationFunction import ActivationFunction
+
+
+ class TBSReLU(ActivationFunction):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         a = 1 - torch.exp(-x)
+         b = 1 + torch.exp(-x)
+         c = a / b
+         return x * torch.tanh(c)
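An illustrative standalone check of the forward pass above (not part of the package): the ratio (1 - e^-x) / (1 + e^-x) is the bipolar sigmoid, which equals tanh(x / 2), so TBSReLU computes x * tanh(tanh(x / 2)).

```python
import torch

# Sketch of the TBSReLU forward pass from the diff above; tbs_relu is a
# hypothetical helper name used only for this illustration.
def tbs_relu(x: torch.Tensor) -> torch.Tensor:
    bipolar_sigmoid = (1 - torch.exp(-x)) / (1 + torch.exp(-x))
    return x * torch.tanh(bipolar_sigmoid)

x = torch.linspace(-3, 3, 7)
assert torch.allclose(tbs_relu(x), x * torch.tanh(torch.tanh(x / 2)), atol=1e-6)
```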
homa/activations/TSReLU.py ADDED
@@ -0,0 +1,10 @@
+ import torch
+ from .ActivationFunction import ActivationFunction
+
+
+ class TSReLU(ActivationFunction):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return x * torch.tanh(torch.sigmoid(x))
homa/activations/TangentBipolarSigmoidReLU.py ADDED
@@ -0,0 +1,6 @@
+ from .TBSReLU import TBSReLU
+
+
+ class TangentBipolarSigmoidReLU(TBSReLU):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
homa/activations/TangentSigmoidReLU.py ADDED
@@ -0,0 +1,6 @@
+ from .TSReLU import TSReLU
+
+
+ class TangentSigmoidReLU(TSReLU):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
homa/activations/TeLU.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+
+
+ class TeLU(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x: torch.Tensor):
+         return x * torch.tanh(torch.exp(x))
homa/activations/TripleStateSwish.py ADDED
@@ -0,0 +1,15 @@
+ import torch
+ from .ActivationFunction import ActivationFunction
+
+
+ class TripleStateSwish(ActivationFunction):
+     def __init__(self, alpha: float = 20, beta: float = 40, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.alpha = alpha
+         self.beta = beta
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         a = 1 / (1 + torch.exp(-x))
+         b = 1 / (1 + torch.exp(-x + self.alpha))
+         c = 1 / (1 + torch.exp(-x + self.beta))
+         return x * a * (a + b + c)
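The three terms in the forward pass above are shifted sigmoids, since 1 / (1 + exp(-x + alpha)) equals sigmoid(x - alpha). An illustrative standalone sketch with the default shifts:

```python
import torch

# Sketch of TripleStateSwish with its defaults alpha=20, beta=40;
# triple_state_swish is a hypothetical helper name for this illustration.
def triple_state_swish(x, alpha=20.0, beta=40.0):
    a = torch.sigmoid(x)          # 1 / (1 + exp(-x))
    b = torch.sigmoid(x - alpha)  # 1 / (1 + exp(-x + alpha))
    c = torch.sigmoid(x - beta)   # 1 / (1 + exp(-x + beta))
    return x * a * (a + b + c)

print(triple_state_swish(torch.tensor([-1.0, 0.0, 1.0])))
```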
homa/activations/WideMeLU.py ADDED
@@ -0,0 +1,15 @@
+ from .MexicanReLU import MexicanReLU
+
+
+ class WideMeLU(MexicanReLU):
+     def __init__(self, channels: int | None = None, max_input: float = 1.0):
+         self.hats = [
+             (2.0, 2.0),
+             (1.0, 1.0),
+             (3.0, 1.0),
+             (0.5, 0.5),
+             (1.5, 0.5),
+             (2.5, 0.5),
+             (3.5, 0.5),
+         ]
+         super().__init__(self.hats, channels=channels, max_input=max_input)
homa/activations/__init__.py CHANGED
@@ -1,2 +1,49 @@
- from .classes import *
- from .utils import *
+ from .ShiftedReLU import ShiftedReLU
+ from .PLogish import PLogish
+ from .ParametricLogish import ParametricLogish
+ from .ExpExpish import ExpExpish
+ from .GeneralizedSwish import GeneralizedSwish
+ from .TBSReLU import TBSReLU
+ from .NoisyReLU import NoisyReLU
+ from .ExponentialDLReLU import ExponentialDLReLU
+ from .SReLU import SReLU
+ from .TangentSigmoidReLU import TangentSigmoidReLU
+ from .Phish import Phish
+ from .WideMeLU import WideMeLU
+ from .SelfArctan import SelfArctan
+ from .LogSigmoid import LogSigmoid
+ from .SlopedReLU import SlopedReLU
+ from .SmallGaLU import SmallGaLU
+ from .MinSin import MinSin
+ from .LaLU import LaLU
+ from .MexicanReLU import MexicanReLU
+ from .APLU import APLU
+ from .ERF import ERF
+ from .TangentBipolarSigmoidReLU import TangentBipolarSigmoidReLU
+ from .BaseDLReLU import BaseDLReLU
+ from .Logish import Logish
+ from .TripleStateSwish import TripleStateSwish
+ from .ExponentialSwish import ExponentialSwish
+ from .TeLU import TeLU
+ from .Elliot import Elliot
+ from .MeLU import MeLU
+ from .GaussianReLU import GaussianReLU
+ from .ActivationFunction import ActivationFunction
+ from .RReLU import RReLU
+ from .Suish import Suish
+ from .SoftsignRReLU import SoftsignRReLU
+ from .Gish import Gish
+ from .NReLU import NReLU
+ from .LogLogish import LogLogish
+ from .SGELU import SGELU
+ from .GaLU import GaLU
+ from .TSReLU import TSReLU
+ from .SineReLU import SineReLU
+ from .DLReLU import DLReLU
+ from .CaLU import CaLU
+ from .RandomizedSlopedReLU import RandomizedSlopedReLU
+ from .GCU import GCU
+ from .SigmoidDerivative import SigmoidDerivative
+ from .Smish import Smish
+ from .AdaptiveActivationFunction import AdaptiveActivationFunction
+ from .learnable import *
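With the flat re-exports above, each activation is importable directly from homa.activations. A usage sketch, assuming homa 0.2.9 is installed:

```python
import torch
from homa.activations import TeLU

# Drop an activation from the package into an ordinary torch model.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), TeLU())
print(model(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```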
homa/activations/learnable/AOAF.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ from .concerns import ChannelBased
+ from ..AdaptiveActivationFunction import AdaptiveActivationFunction
+
+
+ class AOAF(AdaptiveActivationFunction, ChannelBased):
+     def __init__(self, b: float = 0.17, c: float = 0.17):
+         super().__init__()
+         self.a = None
+         self.b = b
+         self.c = c
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         self.initialize(x, "a")
+         a = self.a.view(self.parameter_view(x))
+         return torch.relu(x - self.b * a) + self.c * a
homa/activations/learnable/AReLU.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ from ..AdaptiveActivationFunction import AdaptiveActivationFunction
+
+
+ class AReLU(AdaptiveActivationFunction):
+     def __init__(self):
+         super(AReLU, self).__init__()
+         self.a = torch.nn.Parameter(torch.tensor(0.9, requires_grad=True))
+         self.b = torch.nn.Parameter(torch.tensor(2.0, requires_grad=True))
+
+     def forward(self, z):
+         negative_slope = torch.clamp(self.a, 0.01, 0.99)
+         positive_slope = 1 + torch.sigmoid(self.b)
+         positive = positive_slope * torch.relu(z)
+         negative = negative_slope * (-torch.relu(-z))
+         return positive + negative
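An illustrative standalone evaluation of the AReLU forward above: clamping keeps the learned negative slope in (0.01, 0.99), and 1 + sigmoid(b) keeps the positive slope in (1, 2).

```python
import torch

# Sketch of the AReLU forward pass with its initial values a=0.9, b=2.0.
a, b = torch.tensor(0.9), torch.tensor(2.0)
z = torch.tensor([-2.0, -0.5, 0.5, 2.0])
negative_slope = torch.clamp(a, 0.01, 0.99)  # stays in (0.01, 0.99)
positive_slope = 1 + torch.sigmoid(b)        # stays in (1, 2)
print(positive_slope * torch.relu(z) + negative_slope * (-torch.relu(-z)))
```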
homa/activations/learnable/DPReLU.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ from ..AdaptiveActivationFunction import AdaptiveActivationFunction
+ from .concerns import ChannelBased
+
+
+ class DPReLU(AdaptiveActivationFunction, ChannelBased):
+     def __init__(self):
+         super().__init__()
+         self.a = None
+         self.b = None
+
+     def forward(self, x: torch.Tensor):
+         self.initialize(x, ["a", "b"], [1, 0.01])
+         a = self.a.view(self.parameter_shape(x))
+         b = self.b.view(self.parameter_shape(x))
+         return torch.where(x >= 0, a * x, b * x)
homa/activations/learnable/DualLine.py ADDED
@@ -0,0 +1,18 @@
+ import torch
+ from ..AdaptiveActivationFunction import AdaptiveActivationFunction
+ from .concerns import ChannelBased
+
+
+ class DualLine(AdaptiveActivationFunction, ChannelBased):
+     def __init__(self):
+         super().__init__()
+         self.a = None
+         self.b = None
+         self.m = None
+
+     def forward(self, x: torch.Tensor):
+         self.initialize(x, ["a", "b", "m"], [1, 0.01, -0.22])
+         a = self.a.view(self.parameter_shape(x))
+         b = self.b.view(self.parameter_shape(x))
+         m = self.m.view(self.parameter_shape(x))
+         return torch.where(x >= 0, a * x + m, b * x + m)
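An illustrative standalone evaluation of the DualLine rule above, with its per-channel defaults (a=1, b=0.01, m=-0.22) already broadcast to the parameter shape:

```python
import torch

# Sketch of the DualLine forward pass: two learned slopes plus a shared
# learned offset m, broadcast per channel as (1, C, 1, 1).
x = torch.randn(2, 3, 4, 4)
a = torch.full((1, 3, 1, 1), 1.0)
b = torch.full((1, 3, 1, 1), 0.01)
m = torch.full((1, 3, 1, 1), -0.22)
print(torch.where(x >= 0, a * x + m, b * x + m).shape)  # torch.Size([2, 3, 4, 4])
```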
homa/activations/learnable/FReLU.py ADDED
@@ -0,0 +1,14 @@
+ import torch
+ from ..AdaptiveActivationFunction import AdaptiveActivationFunction
+ from .concerns import ChannelBased
+
+
+ class FReLU(AdaptiveActivationFunction, ChannelBased):
+     def __init__(self):
+         super().__init__()
+         self.b = None
+
+     def forward(self, x: torch.Tensor):
+         self.initialize(x, "b")
+         b = self.b.view(self.parameter_shape(x))
+         return torch.where(x >= 0, x + b, b)
homa/activations/learnable/LeLeLU.py ADDED
@@ -0,0 +1,14 @@
+ import torch
+ from ..AdaptiveActivationFunction import AdaptiveActivationFunction
+ from .concerns import ChannelBased
+
+
+ class LeLeLU(AdaptiveActivationFunction, ChannelBased):
+     def __init__(self):
+         super().__init__()
+         self.a = None
+
+     def forward(self, x: torch.Tensor):
+         self.initialize(x, "a")
+         a = self.a.view(self.parameter_shape(x))
+         return torch.where(x >= 0, a * x, 0.01 * a * x)
homa/activations/learnable/PERU.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ from .concerns import ChannelBased
+ from ..AdaptiveActivationFunction import AdaptiveActivationFunction
+
+
+ class PERU(AdaptiveActivationFunction, ChannelBased):
+     def __init__(self):
+         super().__init__()
+         self.a = None
+         self.b = None
+
+     def forward(self, x: torch.Tensor):
+         self.initialize(x, ["a", "b"])
+         a = self.a.view(self.parameter_shape(x))
+         b = self.b.view(self.parameter_shape(x))
+         return torch.where(x >= 0, a * x, a * x * torch.exp(b * x))
homa/activations/learnable/PiLU.py ADDED
@@ -0,0 +1,18 @@
+ import torch
+ from ..AdaptiveActivationFunction import AdaptiveActivationFunction
+ from .concerns import ChannelBased
+
+
+ class DualLine(AdaptiveActivationFunction, ChannelBased):
+     def __init__(self):
+         super().__init__()
+         self.a = None
+         self.b = None
+         self.c = None
+
+     def forward(self, x: torch.Tensor):
+         self.initialize(x, ["a", "b", "c"], [1, 0.01, 1])
+         a = self.a.view(self.parameter_shape(x))
+         b = self.b.view(self.parameter_shape(x))
+         c = self.c.view(self.parameter_shape(x))
+         return torch.where(x >= c, a * x + c * (1 - a), b * x + c * (1 - b))
homa/activations/learnable/ShiLU.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ from ..ActivationFunction import ActivationFunction
+ from .concerns import ChannelBased
+
+
+ class ShiLU(ActivationFunction, ChannelBased):
+     def __init__(self):
+         super().__init__()
+         self.a = None
+         self.b = None
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         self.initialize(x, ["a", "b"])
+         a = self.a.view(self.parameter_shape(x))
+         b = self.b.view(self.parameter_shape(x))
+         return torch.relu(x) * a + b
homa/activations/learnable/StarReLU.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ from ..ActivationFunction import ActivationFunction
+ from .concerns import ChannelBased
+
+
+ class StarReLU(ActivationFunction, ChannelBased):
+     def __init__(self):
+         super().__init__()
+         self.a = None
+         self.b = None
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         self.initialize(x, ["a", "b"])
+         a = self.a.view(self.parameter_shape(x))
+         b = self.b.view(self.parameter_shape(x))
+         return a * torch.relu(x).pow(2) + b
homa/activations/learnable/__init__.py ADDED
@@ -0,0 +1,10 @@
+ from .StarReLU import StarReLU
+ from .DualLine import DualLine
+ from .LeLeLU import LeLeLU
+ from .AReLU import AReLU
+ from .PERU import PERU
+ from .ShiLU import ShiLU
+ from .DPReLU import DPReLU
+ from .PiLU import DualLine
+ from .FReLU import FReLU
+ from .AOAF import AOAF
homa/activations/learnable/concerns/ChannelBased.py ADDED
@@ -0,0 +1,36 @@
+ import torch
+ from typing import List
+
+
+ class ChannelBased:
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self._initialized = False
+         self.num_channels = None
+
+     def initialize(
+         self, x: torch.Tensor, attrs: List[str] | str, values: List[float] | float = []
+     ):
+         if getattr(self, "_initialized", False):
+             return
+
+         if not isinstance(values, list):
+             values = [values]
+
+         if not isinstance(attrs, list):
+             attrs = [attrs]
+
+         self.num_channels = x.shape[1]
+         for index, attr in enumerate(attrs):
+             if index < len(values) and values[index] is not None:
+                 default_value = float(values[index])
+             else:
+                 default_value = 1.0
+             param = torch.nn.Parameter(torch.full((self.num_channels,), default_value))
+             setattr(self, attr, param)
+         self._initialized = True
+
+     def parameter_shape(self, x: torch.Tensor) -> tuple | None:
+         if hasattr(self, "num_channels"):
+             return (1, self.num_channels) + (1,) * (x.ndim - 2)
+         return None
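The concern defers parameter creation to the first forward call, sizing one value per channel. An illustrative standalone sketch of the shape mechanics:

```python
import torch

# Sketch of ChannelBased's lazy per-channel parameters: a (C,) parameter
# viewed as (1, C, 1, ..., 1) so it broadcasts over batch and spatial dims.
x = torch.randn(4, 3, 8, 8)                      # (batch, channels, H, W)
num_channels = x.shape[1]
a = torch.nn.Parameter(torch.full((num_channels,), 1.0))
shape = (1, num_channels) + (1,) * (x.ndim - 2)  # -> (1, 3, 1, 1)
out = a.view(shape) * torch.relu(x)              # e.g. LeLeLU's positive branch
print(shape, out.shape)
```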
homa/activations/learnable/concerns/__init__.py ADDED
@@ -0,0 +1 @@
+ from .ChannelBased import ChannelBased
homa/cli/Commands/Command.py ADDED
@@ -0,0 +1,2 @@
+ class Command:
+     pass
homa/cli/Commands/InitCommand.py ADDED
@@ -0,0 +1,34 @@
+ import ast
+ from pathlib import Path
+ from .Command import Command
+
+
+ class InitCommand(Command):
+     @staticmethod
+     def run():
+         path = Path(".")
+         init_file = path / "__init__.py"
+         init_file.write_text("")
+         for file in path.iterdir():
+             if file.name == "__init__.py" or file.suffix != ".py":
+                 continue
+             tree = ast.parse(file.read_text())
+             classes = [
+                 node.name
+                 for node in tree.body
+                 if isinstance(node, ast.ClassDef) and not node.name.startswith("_")
+             ]
+             functions = [
+                 node.name
+                 for node in tree.body
+                 if isinstance(node, ast.FunctionDef) and not node.name.startswith("_")
+             ]
+             if not (classes or functions):
+                 continue
+             module = file.stem
+             lines = [
+                 f"from .{module} import {name}\n" for name in (*classes, *functions)
+             ]
+             with init_file.open("a") as f:
+                 f.writelines(lines)
+             print(f"Processed {file}: classes={classes}, functions={functions}")
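The command rebuilds __init__.py from the top-level public names it finds via ast. An illustrative sketch of the discovery step (the module source and the names Point/area are hypothetical):

```python
import ast

# Sketch of InitCommand's discovery step on one module's source text.
source = "class Point:\n    pass\n\ndef area(p):\n    return 0.0\n"
tree = ast.parse(source)
classes = [n.name for n in tree.body if isinstance(n, ast.ClassDef)]
functions = [n.name for n in tree.body if isinstance(n, ast.FunctionDef)]
print(classes, functions)  # ['Point'] ['area']
# The command would then append one re-export line per name to __init__.py.
```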
homa/cli/Commands/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .Command import Command
+ from .InitCommand import InitCommand
homa/cli/HomaCommand.py CHANGED
@@ -1,5 +1,6 @@
  import fire
  from .namespaces import MakeNamespace, CacheNamespace
+ from .Commands import InitCommand


  class HomaCommand:
@@ -7,6 +8,9 @@ class HomaCommand:
          self.make = MakeNamespace()
          self.cache = CacheNamespace()

+     def init(self):
+         InitCommand.run()
+

  def main():
      fire.Fire(HomaCommand)
homa/ensemble/Ensemble.py CHANGED
@@ -1,8 +1,7 @@
  from .concerns import (
      ReportsSize,
-     RecordsStateDictionaries,
+     StoresModels,
      ReportsClassificationMetrics,
-     HasNetwork,
      PredictsProbabilities,
  )

@@ -10,9 +9,8 @@ from .concerns import (
  class Ensemble(
      ReportsSize,
      ReportsClassificationMetrics,
-     RecordsStateDictionaries,
      PredictsProbabilities,
-     HasNetwork,
+     StoresModels,
  ):
      def __init__(self):
          super().__init__()
homa/ensemble/concerns/CalculatesMetricNecessities.py CHANGED
@@ -1,20 +1,24 @@
  import torch
+ from ...device import get_device


  class CalculatesMetricNecessities:
      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)

+     @torch.no_grad()
      def metric_necessities(self, dataloader):
-         all_predictions = []
-         all_labels = []
+         predictions, labels = [], []
+         device = get_device()
          for x, y in dataloader:
-             batch_logits_list = []
+             x, y = x.to(device), y.to(device)
+             sum_logits = None
              for model in self.models:
-                 batch_logits_list.append(model(x))
-             all_batch_logits = torch.stack(batch_logits_list)
-             avg_logits = torch.mean(all_batch_logits, dim=0)
-             _, preds = torch.max(avg_logits, 1)
-             all_predictions.extend(preds.cpu().numpy())
-             all_labels.extend(y.cpu().numpy())
-         return all_predictions, all_labels
+                 model.to(device)
+                 model.eval()
+                 logits = model(x)
+                 sum_logits = logits if sum_logits is None else sum_logits + logits
+             batch_predictions = sum_logits.argmax(dim=1)
+             predictions.extend(batch_predictions.cpu().numpy())
+             labels.extend(y.cpu().numpy())
+         return predictions, labels
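The rewrite switches from averaging stacked logits to a running sum, which picks the same class: dividing by the model count does not change the argmax. An illustrative check:

```python
import torch

# Summing logits across models is argmax-equivalent to averaging them.
logits_per_model = [torch.randn(4, 10) for _ in range(3)]
stacked = torch.stack(logits_per_model)
assert torch.equal(stacked.sum(dim=0).argmax(dim=1),
                   stacked.mean(dim=0).argmax(dim=1))
```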
homa/ensemble/concerns/PredictsProbabilities.py CHANGED
@@ -9,3 +9,7 @@ class PredictsProbabilities(ReportsLogits):
      def predict(self, x: torch.Tensor) -> torch.Tensor:
          logits = self.logits(x)
          return torch.nn.functional.softmax(logits, dim=1)
+
+     @torch.no_grad()
+     def predict_(self, x: torch.Tensor) -> torch.Tensor:
+         return self.predict(x)
homa/ensemble/concerns/ReportsClassificationMetrics.py CHANGED
@@ -6,8 +6,8 @@ from .CalculatesMetricNecessities import CalculatesMetricNecessities

  class ReportsClassificationMetrics(
      CalculatesMetricNecessities,
-     ReportsEnsembleF1,
      ReportsEnsembleAccuracy,
+     ReportsEnsembleF1,
      ReportsEnsembleKappa,
  ):
      pass
homa/ensemble/concerns/ReportsEnsembleAccuracy.py CHANGED
@@ -1,10 +1,11 @@
  from sklearn.metrics import accuracy_score as accuracy
+ from torch.utils.data import DataLoader


  class ReportsEnsembleAccuracy:
      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)

-     def accuracy(self) -> float:
-         predictions, labels = self.metric_necessities()
+     def accuracy(self, dataloader: DataLoader) -> float:
+         predictions, labels = self.metric_necessities(dataloader)
          return accuracy(labels, predictions)
homa/ensemble/concerns/ReportsLogits.py CHANGED
@@ -11,3 +11,7 @@ class ReportsLogits:
          for model in self.models:
              logits += model(x)
          return logits
+
+     @torch.no_grad()
+     def logits_(self, *args, **kwargs):
+         return self.logits(*args, **kwargs)
homa/ensemble/concerns/ReportsSize.py CHANGED
@@ -4,8 +4,8 @@ class ReportsSize:

      @property
      def size(self):
-         return len(self.state_dicts)
+         return len(self.models)

      @property
      def length(self):
-         return len(self.state_dicts)
+         return len(self.models)
homa/ensemble/concerns/StoresModels.py ADDED
@@ -0,0 +1,29 @@
+ import torch
+ from copy import deepcopy
+ from typing import List
+ from ...vision import Model
+
+
+ class StoresModels:
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.models: List[torch.nn.Module] = []
+
+     def record(self, model: Model | torch.nn.Module):
+         model_: torch.nn.Module | None = None
+         if isinstance(model, Model):
+             model_ = deepcopy(model.network)
+         elif isinstance(model, torch.nn.Module):
+             model_ = deepcopy(model)
+         else:
+             raise TypeError("Wrong input to ensemble record")
+         self.models.append(model_)
+
+     def push(self, *args, **kwargs):
+         self.record(*args, **kwargs)
+
+     def append(self, *args, **kwargs):
+         self.record(*args, **kwargs)
+
+     def add(self, *args, **kwargs):
+         self.record(*args, **kwargs)
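A usage sketch, assuming homa 0.2.9 exposes Ensemble from homa.ensemble; record (and its push/append/add aliases) deep-copies the module, so later training does not mutate stored members:

```python
import torch
from homa.ensemble import Ensemble

ensemble = Ensemble()
for _ in range(3):
    ensemble.add(torch.nn.Linear(4, 2))  # any torch.nn.Module is deep-copied
print(ensemble.size)  # 3
```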
homa/ensemble/concerns/__init__.py CHANGED
@@ -1,10 +1,9 @@
  from .CalculatesMetricNecessities import CalculatesMetricNecessities
- from .HasNetwork import HasNetwork
  from .PredictsProbabilities import PredictsProbabilities
- from .RecordsStateDictionaries import RecordsStateDictionaries
  from .ReportsClassificationMetrics import ReportsClassificationMetrics
  from .ReportsEnsembleAccuracy import ReportsEnsembleAccuracy
  from .ReportsEnsembleF1 import ReportsEnsembleF1
  from .ReportsEnsembleKappa import ReportsEnsembleKappa
  from .ReportsLogits import ReportsLogits
  from .ReportsSize import ReportsSize
+ from .StoresModels import StoresModels
homa/loss/LogitNormLoss.py ADDED
@@ -0,0 +1,12 @@
+ import torch
+ from .Loss import Loss
+
+
+ class LogitNormLoss(Loss):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def forward(self, logits, target):
+         norms = torch.norm(logits, p=2, dim=-1, keepdim=True) + 1e-7
+         normalized_logits = torch.div(logits, norms)
+         return torch.nn.functional.cross_entropy(normalized_logits, target)
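An illustrative standalone computation of the LogitNorm loss above: logits are L2-normalized per sample before standard cross-entropy, which bounds their magnitude.

```python
import torch

# Sketch of LogitNormLoss on a random batch.
logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))
norms = torch.norm(logits, p=2, dim=-1, keepdim=True) + 1e-7  # avoid div by zero
print(torch.nn.functional.cross_entropy(logits / norms, target).item())
```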
homa/loss/Loss.py ADDED
@@ -0,0 +1,2 @@
+ class Loss:
+     pass
homa/loss/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .LogitNormLoss import LogitNormLoss
+ from .Loss import Loss
homa/torch/__init__.py CHANGED
@@ -1,2 +1 @@
- from .Module import Module
  from .helpers import *
homa/vision/Classifier.py ADDED
@@ -0,0 +1,5 @@
+ from .Model import Model
+
+
+ class Classifier(Model):
+     pass
homa/vision/Resnet.py CHANGED
@@ -1,12 +1,13 @@
  import torch
  from .modules import ResnetModule
- from .Model import Model
- from .concerns import Trainable
+ from .Classifier import Classifier
+ from .concerns import Trainable, ReportsMetrics
+ from ..device import get_device


- class Resnet(Model, Trainable):
-     def __init__(self, num_classes: int, lr: float):
+ class Resnet(Classifier, Trainable, ReportsMetrics):
+     def __init__(self, num_classes: int, lr: float = 0.001):
          super().__init__()
-         self.network = ResnetModule(num_classes)
+         self.network = ResnetModule(num_classes).to(get_device())
          self.criterion = torch.nn.CrossEntropyLoss()
          self.optimizer = torch.optim.SGD(self.network.parameters(), lr=lr, momentum=0.9)
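A usage sketch, assuming homa 0.2.9 exports Resnet from homa.vision; the learning rate now defaults to 0.001 and the network is placed on the detected device at construction:

```python
from homa.vision import Resnet

model = Resnet(num_classes=10)  # lr defaults to 0.001
# model.network already lives on the device chosen by homa's get_device()
```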