homa 0.2.95__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
Files changed (64)
  1. homa/core/__init__.py +0 -0
  2. homa/core/concerns/MovesNetworkToDevice.py +13 -0
  3. homa/core/concerns/TracksTime.py +7 -0
  4. homa/core/concerns/__init__.py +2 -0
  5. homa/device.py +5 -0
  6. homa/ensemble/Ensemble.py +4 -2
  7. homa/ensemble/concerns/CalculatesMetricNecessities.py +2 -2
  8. homa/ensemble/concerns/PredictsProbabilities.py +2 -2
  9. homa/ensemble/concerns/ReportsClassificationMetrics.py +2 -1
  10. homa/ensemble/concerns/ReportsEnsembleAccuracy.py +2 -2
  11. homa/ensemble/concerns/ReportsEnsembleF1.py +2 -2
  12. homa/ensemble/concerns/ReportsEnsembleKappa.py +2 -2
  13. homa/ensemble/concerns/ReportsEnsembleSize.py +11 -0
  14. homa/ensemble/concerns/ReportsLogits.py +26 -5
  15. homa/ensemble/concerns/SavesEnsembleModels.py +13 -0
  16. homa/ensemble/concerns/StoresModels.py +8 -11
  17. homa/ensemble/concerns/__init__.py +2 -1
  18. homa/ensemble/utils.py +9 -0
  19. homa/graph/GraphAttention.py +13 -0
  20. homa/graph/__init__.py +1 -0
  21. homa/graph/modules/GraphAttentionHeadModule.py +37 -0
  22. homa/graph/modules/MultiHeadGraphAttentionModule.py +22 -0
  23. homa/graph/modules/__init__.py +2 -0
  24. homa/loss/Loss.py +4 -1
  25. homa/rl/DQN.py +2 -0
  26. homa/rl/DRQN.py +5 -0
  27. homa/rl/DiversityIsAllYouNeed.py +96 -0
  28. homa/rl/SoftActorCritic.py +67 -0
  29. homa/rl/__init__.py +4 -0
  30. homa/rl/buffers/Buffer.py +13 -0
  31. homa/rl/buffers/DiversityIsAllYouNeedBuffer.py +50 -0
  32. homa/rl/buffers/ImageBuffer.py +5 -0
  33. homa/rl/buffers/SoftActorCriticBuffer.py +64 -0
  34. homa/rl/buffers/__init__.py +4 -0
  35. homa/rl/buffers/concerns/HasRecordAlternatives.py +12 -0
  36. homa/rl/buffers/concerns/ResetsCollection.py +9 -0
  37. homa/rl/buffers/concerns/__init__.py +2 -0
  38. homa/rl/diayn/Actor.py +54 -0
  39. homa/rl/diayn/Critic.py +41 -0
  40. homa/rl/diayn/Discriminator.py +45 -0
  41. homa/rl/diayn/__init__.py +3 -0
  42. homa/rl/diayn/modules/ContinuousActorModule.py +42 -0
  43. homa/rl/diayn/modules/CriticModule.py +28 -0
  44. homa/rl/diayn/modules/DiscriminatorModule.py +24 -0
  45. homa/rl/diayn/modules/__init__.py +3 -0
  46. homa/rl/sac/SoftActor.py +70 -0
  47. homa/rl/sac/SoftCritic.py +98 -0
  48. homa/rl/sac/__init__.py +2 -0
  49. homa/rl/sac/modules/DualSoftCriticModule.py +22 -0
  50. homa/rl/sac/modules/SoftActorModule.py +35 -0
  51. homa/rl/sac/modules/SoftCriticModule.py +30 -0
  52. homa/rl/sac/modules/__init__.py +3 -0
  53. homa/rl/utils.py +7 -0
  54. homa/vision/Resnet.py +3 -3
  55. homa/vision/Swin.py +17 -5
  56. homa/vision/modules/SwinModule.py +17 -9
  57. {homa-0.2.95.dist-info → homa-0.3.2.dist-info}/METADATA +1 -1
  58. {homa-0.2.95.dist-info → homa-0.3.2.dist-info}/RECORD +61 -23
  59. homa/ensemble/concerns/ReportsSize.py +0 -11
  60. homa/torch/__init__.py +0 -1
  61. homa/torch/helpers.py +0 -6
  62. {homa-0.2.95.dist-info → homa-0.3.2.dist-info}/WHEEL +0 -0
  63. {homa-0.2.95.dist-info → homa-0.3.2.dist-info}/entry_points.txt +0 -0
  64. {homa-0.2.95.dist-info → homa-0.3.2.dist-info}/top_level.txt +0 -0
homa/rl/buffers/concerns/HasRecordAlternatives.py ADDED
@@ -0,0 +1,12 @@
+ class HasRecordAlternatives:
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def add(self, *args, **kwargs) -> None:
+         self.record(*args, **kwargs)
+
+     def push(self, *args, **kwargs) -> None:
+         self.record(*args, **kwargs)
+
+     def append(self, *args, **kwargs) -> None:
+         self.record(*args, **kwargs)
homa/rl/buffers/concerns/ResetsCollection.py ADDED
@@ -0,0 +1,9 @@
+ from collections import deque
+
+
+ class ResetsCollection:
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def reset(self):
+         self.collection = deque(maxlen=self.capacity)
homa/rl/buffers/concerns/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .HasRecordAlternatives import HasRecordAlternatives
+ from .ResetsCollection import ResetsCollection
homa/rl/diayn/Actor.py ADDED
@@ -0,0 +1,54 @@
+ import torch
+ from torch.distributions import Normal
+ from .modules import ContinuousActorModule
+ from ...core.concerns import MovesNetworkToDevice
+
+
+ class Actor(MovesNetworkToDevice):
+     def __init__(
+         self,
+         state_dimension: int,
+         action_dimension: int,
+         num_skills: int,
+         hidden_dimension: int,
+         lr: float,
+         decay: float,
+         epsilon: float,
+         min_std: float,
+         max_std: float,
+     ):
+         self.epsilon: float = epsilon
+         self.network = ContinuousActorModule(
+             state_dimension=state_dimension,
+             action_dimension=action_dimension,
+             hidden_dimension=hidden_dimension,
+             num_skills=num_skills,
+             min_std=min_std,
+             max_std=max_std,
+         )
+         self.optimizer = torch.optim.AdamW(
+             self.network.parameters(), lr=lr, weight_decay=decay
+         )
+
+     def action(self, state: torch.Tensor, skill: torch.Tensor):
+         mean, std = self.network(state, skill)
+         std = std.exp()
+         distribution = Normal(mean, std)
+         raw_action = distribution.rsample()
+         action = torch.tanh(raw_action)
+         corrected_probabilities = torch.log(1.0 - action.pow(2) + self.epsilon)
+         probabilities = distribution.log_prob(raw_action) - corrected_probabilities
+         probabilities = probabilities.sum(dim=-1, keepdim=True)
+         return action, probabilities
+
+     def train(self, advantages: torch.Tensor, probabilities: torch.Tensor) -> float:
+         self.optimizer.zero_grad()
+         loss = self.loss(advantages=advantages, probabilities=probabilities)
+         loss.backward()
+         self.optimizer.step()
+         return loss.item()
+
+     def loss(
+         self, advantages: torch.Tensor, probabilities: torch.Tensor
+     ) -> torch.Tensor:
+         return -(probabilities * advantages.detach()).mean()
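For reference, the correction applied inside Actor.action above is the standard change-of-variables term for a tanh-squashed Gaussian policy (epsilon only guards the logarithm numerically). With u drawn via rsample and a = tanh(u), the summed log-probability the method returns corresponds to

\log \pi(a \mid s, z) = \log \mathcal{N}\big(u;\, \mu(s, z),\, \sigma(s, z)\big) - \sum_i \log\!\left(1 - \tanh(u_i)^2 + \epsilon\right)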
homa/rl/diayn/Critic.py ADDED
@@ -0,0 +1,41 @@
+ import torch
+ from .modules import CriticModule
+ from ...core.concerns import MovesNetworkToDevice
+
+
+ class Critic(MovesNetworkToDevice):
+     def __init__(
+         self,
+         state_dimension: int,
+         hidden_dimension: int,
+         num_skills: int,
+         lr: float,
+         decay: float,
+         gamma: float,
+     ):
+         self.network = CriticModule(
+             state_dimension=state_dimension,
+             hidden_dimension=hidden_dimension,
+             num_skills=num_skills,
+         )
+         self.optimizer = torch.optim.AdamW(
+             self.network.parameters(), lr=lr, weight_decay=decay
+         )
+         self.criterion = torch.nn.SmoothL1Loss()
+         self.gamma: float = gamma
+
+     def train(self, advantages: torch.Tensor):
+         self.optimizer.zero_grad()
+         loss = self.loss(advantages=advantages)
+         loss.backward()
+         self.optimizer.step()
+
+     def loss(self, advantages: torch.Tensor):
+         return advantages.pow(2).mean()
+
+     def values(self, states: torch.Tensor, skills: torch.Tensor):
+         return self.network(states, skills)
+
+     @torch.no_grad()
+     def values_(self, *args, **kwargs):
+         return self.values(*args, **kwargs)
homa/rl/diayn/Discriminator.py ADDED
@@ -0,0 +1,45 @@
+ import torch
+ import numpy
+ from .modules import DiscriminatorModule
+ from ...core.concerns import MovesNetworkToDevice
+
+
+ class Discriminator(MovesNetworkToDevice):
+     def __init__(
+         self,
+         state_dimension: int,
+         hidden_dimension: int,
+         num_skills: int,
+         decay: float,
+         lr: float,
+     ):
+         self.num_skills: int = num_skills
+         self.network = DiscriminatorModule(
+             state_dimension=state_dimension,
+             hidden_dimension=hidden_dimension,
+             num_skills=num_skills,
+         )
+         self.optimizer = torch.optim.AdamW(
+             self.network.parameters(), lr=lr, weight_decay=decay
+         )
+         self.criterion = torch.nn.CrossEntropyLoss()
+
+     def loss(self, states: torch.Tensor, skills_indices: torch.Tensor):
+         logits = self.network(states)
+         return self.criterion(logits, skills_indices)
+
+     @torch.no_grad()
+     def reward(self, state: torch.Tensor, skill_index: torch.Tensor):
+         logits = self.network(state)
+         probabilities = torch.nn.functional.log_softmax(logits, dim=-1)
+         entropy = numpy.log(1.0 / self.num_skills)
+         if skill_index.dim() == 1:
+             skill_index = skill_index.unsqueeze(-1)
+         reward = probabilities.gather(1, skill_index.long()) - entropy
+         return reward.squeeze(-1)
+
+     def train(self, states: torch.Tensor, skills_indices: torch.Tensor):
+         self.optimizer.zero_grad()
+         loss = self.loss(states=states, skills_indices=skills_indices)
+         loss.backward()
+         self.optimizer.step()
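In DIAYN terms, Discriminator.reward above returns the usual skill pseudo-reward log q(z|s) - log p(z) with a uniform prior over skills. A minimal standalone sketch of the same computation, with illustrative tensor shapes that are not part of the package:

import torch

num_skills = 4
logits = torch.randn(8, num_skills)           # discriminator outputs for a batch of 8 states
skill = torch.randint(0, num_skills, (8, 1))  # skill index active for each state
log_q = torch.log_softmax(logits, dim=-1).gather(1, skill)
pseudo_reward = log_q - torch.log(torch.tensor(1.0 / num_skills))  # log q(z|s) - log p(z)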
homa/rl/diayn/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .Actor import Actor
+ from .Critic import Critic
+ from .Discriminator import Discriminator
homa/rl/diayn/modules/ContinuousActorModule.py ADDED
@@ -0,0 +1,42 @@
+ import torch
+
+
+ class ContinuousActorModule(torch.nn.Module):
+     def __init__(
+         self,
+         state_dimension: int,
+         action_dimension: int,
+         hidden_dimension: int,
+         num_skills: int,
+         min_std: float,
+         max_std: float,
+     ):
+         super().__init__()
+         self.state_dimension: int = state_dimension
+         self.action_dimension: int = action_dimension
+         self.num_skills: int = num_skills
+         self.hidden_dimension: int = hidden_dimension
+         self.input_dimension: int = self.state_dimension + self.num_skills
+         self.min_std: float = min_std
+         self.max_std: float = max_std
+
+         self.phi = torch.nn.Sequential(
+             torch.nn.Linear(self.input_dimension, self.hidden_dimension),
+             torch.nn.ReLU(),
+             torch.nn.Linear(self.hidden_dimension, self.hidden_dimension),
+             torch.nn.ReLU(),
+             torch.nn.Linear(self.hidden_dimension, self.hidden_dimension),
+         )
+
+         self.mu = torch.nn.Linear(self.hidden_dimension, self.action_dimension)
+         self.xi = torch.nn.Linear(self.hidden_dimension, self.action_dimension)
+
+     def forward(self, state: torch.Tensor, skill: torch.Tensor) -> torch.Tensor:
+         # flatten trailing dimensions so each batch row is a single state vector
+         state = state.view(state.size(0), -1)
+
+         psi = torch.cat([state, skill], dim=-1)
+         features = self.phi(psi)
+         mean = self.mu(features)
+         std = self.xi(features).clamp(self.min_std, self.max_std)
+         return mean, std
homa/rl/diayn/modules/CriticModule.py ADDED
@@ -0,0 +1,28 @@
+ import torch
+
+
+ class CriticModule(torch.nn.Module):
+     def __init__(
+         self,
+         state_dimension: int,
+         hidden_dimension: int,
+         num_skills: int,
+     ):
+         super().__init__()
+         self.state_dimension: int = state_dimension
+         self.num_skills: int = num_skills
+         self.hidden_dimension: int = hidden_dimension
+         self.input_dimension: int = self.state_dimension + self.num_skills
+
+         self.phi = torch.nn.Sequential(
+             torch.nn.Linear(self.input_dimension, self.hidden_dimension),
+             torch.nn.ReLU(),
+             torch.nn.Linear(self.hidden_dimension, self.hidden_dimension),
+             torch.nn.ReLU(),
+         )
+         self.fc = torch.nn.Linear(self.hidden_dimension, 1)
+
+     def forward(self, state: torch.Tensor, skill: torch.Tensor) -> torch.Tensor:
+         psi = torch.cat([state, skill], dim=-1)
+         features = self.phi(psi)
+         return self.fc(features).squeeze(-1)
homa/rl/diayn/modules/DiscriminatorModule.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+ from typing import Type
+
+
+ class DiscriminatorModule(torch.nn.Module):
+     def __init__(self, state_dimension: int, hidden_dimension: int, num_skills: int):
+         super().__init__()
+         self.state_dimension: int = state_dimension
+         self.hidden_dimension: int = hidden_dimension
+         self.num_skills: int = num_skills
+
+         self.phi: torch.nn.Sequential = torch.nn.Sequential(
+             torch.nn.Linear(self.state_dimension, self.hidden_dimension),
+             torch.nn.ReLU(),
+             torch.nn.Linear(self.hidden_dimension, self.hidden_dimension),
+             torch.nn.ReLU(),
+         )
+         self.fc: torch.nn.Linear = torch.nn.Linear(
+             self.hidden_dimension, self.num_skills
+         )
+
+     def forward(self, state: torch.Tensor) -> torch.Tensor:
+         features: torch.Tensor = self.phi(state)
+         return self.fc(features)
homa/rl/diayn/modules/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .DiscriminatorModule import DiscriminatorModule
+ from .CriticModule import CriticModule
+ from .ContinuousActorModule import ContinuousActorModule
homa/rl/sac/SoftActor.py ADDED
@@ -0,0 +1,70 @@
+ import torch
+ import numpy
+ from .SoftCritic import SoftCritic
+ from .modules import SoftActorModule
+ from ...core.concerns import MovesNetworkToDevice
+
+
+ class SoftActor(MovesNetworkToDevice):
+     def __init__(
+         self,
+         state_dimension: int,
+         hidden_dimension: int,
+         action_dimension: int,
+         lr: float,
+         weight_decay: float,
+         alpha: float,
+         min_std: float,
+         max_std: float,
+     ):
+         self.alpha: float = alpha
+
+         self.network = SoftActorModule(
+             state_dimension=state_dimension,
+             hidden_dimension=hidden_dimension,
+             action_dimension=action_dimension,
+             min_std=min_std,
+             max_std=max_std,
+         )
+         self.optimizer = torch.optim.AdamW(
+             self.network.parameters(), lr=lr, weight_decay=weight_decay
+         )
+
+     def train(self, states: torch.Tensor, critic: SoftCritic):
+         self.network.train()
+         self.optimizer.zero_grad()
+         loss = self.loss(states=states, critic=critic)
+         loss.backward()
+         self.optimizer.step()
+
+     def loss(self, states: torch.Tensor, critic: SoftCritic) -> torch.Tensor:
+         actions, probabilities = self.sample(states)
+         q_alpha, q_beta = critic.network(states, actions)
+         q = torch.min(q_alpha, q_beta)
+         return (self.alpha * probabilities - q).mean()
+
+     def process_state(self, state: numpy.ndarray | torch.Tensor) -> torch.Tensor:
+         if isinstance(state, numpy.ndarray):
+             state = torch.from_numpy(state).float()
+
+         if state.ndim < 2:
+             state = state.unsqueeze(0)
+
+         return state
+
+     def sample(self, state: numpy.ndarray | torch.Tensor):
+         state = self.process_state(state)
+
+         mean, std = self.network(state)
+         # exponentiate so the standard deviation is strictly positive
+         std = std.exp()
+
+         distribution = torch.distributions.Normal(mean, std)
+
+         pre_tanh = distribution.rsample()
+         action = torch.tanh(pre_tanh)
+
+         probabilities = distribution.log_prob(pre_tanh).sum(dim=1, keepdim=True)
+         probabilities -= torch.log(1 - action.pow(2) + 1e-6).sum(dim=1, keepdim=True)
+
+         return action, probabilities
homa/rl/sac/SoftCritic.py ADDED
@@ -0,0 +1,98 @@
+ import torch
+ from torch.nn.functional import mse_loss as mse
+ from .modules import DualSoftCriticModule
+ from .SoftActor import SoftActor
+ from ..utils import soft_update
+ from ...core.concerns import MovesNetworkToDevice
+
+
+ class SoftCritic(MovesNetworkToDevice):
+     def __init__(
+         self,
+         state_dimension: int,
+         hidden_dimension: int,
+         action_dimension: int,
+         lr: float,
+         weight_decay: float,
+         gamma: float,
+         alpha: float,
+     ):
+         self.gamma: float = gamma
+         self.alpha: float = alpha
+
+         self.network = DualSoftCriticModule(
+             state_dimension=state_dimension,
+             hidden_dimension=hidden_dimension,
+             action_dimension=action_dimension,
+         )
+         self.target = DualSoftCriticModule(
+             state_dimension=state_dimension,
+             hidden_dimension=hidden_dimension,
+             action_dimension=action_dimension,
+         )
+
+         # copy the source weights into the target network at initialization
+         self.target.load_state_dict(self.network.state_dict())
+
+         self.optimizer = torch.optim.AdamW(
+             self.network.parameters(), lr=lr, weight_decay=weight_decay
+         )
+
+     def train(
+         self,
+         states: torch.Tensor,
+         actions: torch.Tensor,
+         rewards: torch.Tensor,
+         terminations: torch.Tensor,
+         next_states: torch.Tensor,
+         actor: SoftActor,
+     ):
+         self.network.train()
+         self.optimizer.zero_grad()
+         loss = self.loss(
+             states=states,
+             actions=actions,
+             rewards=rewards,
+             terminations=terminations,
+             next_states=next_states,
+             actor=actor,
+         )
+         loss.backward()
+         self.optimizer.step()
+
+     def loss(
+         self,
+         states: torch.Tensor,
+         actions: torch.Tensor,
+         rewards: torch.Tensor,
+         terminations: torch.Tensor,
+         next_states: torch.Tensor,
+         actor: torch.nn.Module,
+     ):
+         q_alpha, q_beta = self.network(states, actions)
+         target = self.calculate_target(
+             rewards=rewards,
+             terminations=terminations,
+             next_states=next_states,
+             actor=actor,
+         )
+         return mse(q_alpha, target) + mse(q_beta, target)
+
+     @torch.no_grad()
+     def calculate_target(
+         self,
+         rewards: torch.Tensor,
+         terminations: torch.Tensor,
+         next_states: torch.Tensor,
+         actor: SoftActor,
+     ):
+         termination_mask = 1 - terminations
+         next_actions, next_probabilities = actor.sample(next_states)
+         q_alpha, q_beta = self.target(next_states, next_actions)
+         q = torch.min(q_alpha, q_beta)
+         entropy_q = q - self.alpha * next_probabilities
+         return rewards + self.gamma * termination_mask * entropy_q
+
+     def update(self, tau: float):
+         soft_update(network=self.network.alpha, target=self.target.alpha, tau=tau)
+         soft_update(network=self.network.beta, target=self.target.beta, tau=tau)
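For orientation, a minimal sketch of how SoftActor and SoftCritic could be wired together for a single SAC update; the dimensions, hyperparameters, and the random stand-in batch below are illustrative only (a real agent would draw the batch from the SoftActorCriticBuffer added in this release):

import torch
from homa.rl.sac import SoftActor, SoftCritic

state_dim, action_dim, batch = 8, 2, 32
actor = SoftActor(
    state_dimension=state_dim, hidden_dimension=256, action_dimension=action_dim,
    lr=3e-4, weight_decay=0.0, alpha=0.2, min_std=-20.0, max_std=2.0,
)
critic = SoftCritic(
    state_dimension=state_dim, hidden_dimension=256, action_dimension=action_dim,
    lr=3e-4, weight_decay=0.0, gamma=0.99, alpha=0.2,
)

# stand-in replay batch of transitions
states = torch.randn(batch, state_dim)
actions = torch.randn(batch, action_dim).tanh()
rewards = torch.randn(batch, 1)
terminations = torch.zeros(batch, 1)
next_states = torch.randn(batch, state_dim)

critic.train(states=states, actions=actions, rewards=rewards,
             terminations=terminations, next_states=next_states, actor=actor)
actor.train(states=states, critic=critic)
critic.update(tau=0.005)  # Polyak-average the online critic into the target critic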
homa/rl/sac/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .SoftActor import SoftActor
+ from .SoftCritic import SoftCritic
homa/rl/sac/modules/DualSoftCriticModule.py ADDED
@@ -0,0 +1,22 @@
+ import torch
+ from .SoftCriticModule import SoftCriticModule
+
+
+ class DualSoftCriticModule(torch.nn.Module):
+     def __init__(
+         self, state_dimension: int, hidden_dimension: int, action_dimension: int
+     ):
+         super().__init__()
+         self.alpha = SoftCriticModule(
+             state_dimension=state_dimension,
+             hidden_dimension=hidden_dimension,
+             action_dimension=action_dimension,
+         )
+         self.beta = SoftCriticModule(
+             state_dimension=state_dimension,
+             hidden_dimension=hidden_dimension,
+             action_dimension=action_dimension,
+         )
+
+     def forward(self, state: torch.Tensor, action: torch.Tensor):
+         return self.alpha(state, action), self.beta(state, action)
homa/rl/sac/modules/SoftActorModule.py ADDED
@@ -0,0 +1,35 @@
+ import torch
+
+
+ class SoftActorModule(torch.nn.Module):
+     def __init__(
+         self,
+         state_dimension: int,
+         hidden_dimension: int,
+         action_dimension: int,
+         min_std: float,
+         max_std: float,
+     ):
+         super().__init__()
+
+         self.state_dimension: int = state_dimension
+         self.hidden_dimension: int = hidden_dimension
+         self.action_dimension: int = action_dimension
+         self.min_std: float = float(min_std)
+         self.max_std: float = float(max_std)
+
+         self.phi = torch.nn.Sequential(
+             torch.nn.Linear(self.state_dimension, self.hidden_dimension),
+             torch.nn.ReLU(),
+             torch.nn.Linear(self.hidden_dimension, self.hidden_dimension),
+             torch.nn.ReLU(),
+         )
+         self.mu = torch.nn.Linear(self.hidden_dimension, self.action_dimension)
+         self.xi = torch.nn.Linear(self.hidden_dimension, self.action_dimension)
+
+     def forward(self, state: torch.Tensor):
+         features = self.phi(state)
+         mean = self.mu(features)
+         std = self.xi(features)
+         std = std.clamp(self.min_std, self.max_std)
+         return mean, std
homa/rl/sac/modules/SoftCriticModule.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+
+
+ class SoftCriticModule(torch.nn.Module):
+     def __init__(
+         self,
+         state_dimension: int,
+         hidden_dimension: int,
+         action_dimension: int,
+     ):
+         super().__init__()
+
+         self.state_dimension: int = state_dimension
+         self.action_dimension: int = action_dimension
+         self.hidden_dimension: int = hidden_dimension
+
+         self.phi = torch.nn.Sequential(
+             torch.nn.Linear(
+                 self.state_dimension + self.action_dimension, self.hidden_dimension
+             ),
+             torch.nn.ReLU(),
+             torch.nn.Linear(self.hidden_dimension, self.hidden_dimension),
+             torch.nn.ReLU(),
+         )
+         self.fc = torch.nn.Linear(self.hidden_dimension, 1)
+
+     def forward(self, state: torch.Tensor, action: torch.Tensor):
+         psi = torch.cat([state, action], dim=1)
+         features = self.phi(psi)
+         return self.fc(features)
homa/rl/sac/modules/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .SoftActorModule import SoftActorModule
+ from .SoftCriticModule import SoftCriticModule
+ from .DualSoftCriticModule import DualSoftCriticModule
homa/rl/utils.py ADDED
@@ -0,0 +1,7 @@
+ import torch
+
+
+ @torch.no_grad()
+ def soft_update(network: torch.nn.Module, target: torch.nn.Module, tau: float):
+     for s, t in zip(network.parameters(), target.parameters()):
+         t.data.copy_(tau * s.data + (1 - tau) * t.data)
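soft_update above is the usual Polyak averaging of target-network parameters; for each parameter pair it applies

\theta_{\text{target}} \leftarrow \tau\, \theta_{\text{online}} + (1 - \tau)\, \theta_{\text{target}}, \qquad 0 < \tau \ll 1

SoftCritic.update calls it once per critic head with the caller-supplied tau.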
homa/vision/Resnet.py CHANGED
@@ -2,12 +2,12 @@ import torch
  from .modules import ResnetModule
  from .Classifier import Classifier
  from .concerns import Trainable, ReportsMetrics
- from ..device import get_device
+ from ..core.concerns import MovesNetworkToDevice


- class Resnet(Classifier, Trainable, ReportsMetrics):
+ class Resnet(Classifier, Trainable, ReportsMetrics, MovesNetworkToDevice):
      def __init__(self, num_classes: int, lr: float = 0.001):
          super().__init__()
-         self.network = ResnetModule(num_classes).to(get_device())
+         self.network = ResnetModule(num_classes)
          self.criterion = torch.nn.CrossEntropyLoss()
          self.optimizer = torch.optim.SGD(self.network.parameters(), lr=lr, momentum=0.9)
homa/vision/Swin.py CHANGED
@@ -2,12 +2,24 @@ import torch
  from .Classifier import Classifier
  from .concerns import Trainable, ReportsMetrics
  from .modules import SwinModule
- from ..device import get_device
+ from ..core.concerns import MovesNetworkToDevice


- class Swin(Classifier, Trainable, ReportsMetrics):
-     def __init__(self, num_classes: int, lr: float = 0.0001):
+ class Swin(Classifier, Trainable, ReportsMetrics, MovesNetworkToDevice):
+     def __init__(
+         self,
+         num_classes: int,
+         lr: float = 0.0001,
+         decay: float = 0.0,
+         variant: str = "base",
+         weights="DEFAULT",
+     ):
          super().__init__()
-         self.network = SwinModule(num_classes=num_classes).to(get_device())
-         self.optimizer = torch.optim.AdamW(self.network.parameters(), lr=lr)
+         self.num_classes = num_classes
+         self.network = SwinModule(
+             num_classes=self.num_classes, variant=variant, weights=weights
+         )
+         self.optimizer = torch.optim.AdamW(
+             self.network.parameters(), lr=lr, weight_decay=decay
+         )
          self.criterion = torch.nn.CrossEntropyLoss()
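As a quick illustration of the widened constructor, a hedged usage sketch; the class count and hyperparameters are placeholders, and the import goes straight to the module path rather than assuming what homa.vision re-exports:

from homa.vision.Swin import Swin

# variant accepts "tiny", "small", or "base"; lr, decay, and weights keep their defaults if omitted
model = Swin(num_classes=10, lr=1e-4, decay=0.01, variant="base", weights="DEFAULT")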
homa/vision/modules/SwinModule.py CHANGED
@@ -1,21 +1,29 @@
  import torch
- from torchvision.models import swin_v2_b
+ from torchvision.models import swin_v2_b, swin_v2_s, swin_v2_t
  from torch.nn.init import kaiming_uniform_ as kaiming


  class SwinModule(torch.nn.Module):
-     def __init__(self, num_classes: int):
+     def __init__(self, num_classes: int, variant: str, weights):
          super().__init__()
-         self.num_classes = num_classes
-         self._create_encoder()
-         self._create_fc()
+         self._create_encoder(variant=variant, weights=weights)
+         self._create_fc(num_classes=num_classes)

-     def _create_encoder(self):
-         self.encoder = swin_v2_b(weights="DEFAULT")
+     def variant_instance(self, variant: str):
+         variant_map = {"tiny": swin_v2_t, "small": swin_v2_s, "base": swin_v2_b}
+         return variant_map.get(variant)
+
+     def _create_encoder(self, variant: str, weights):
+         if variant not in ["tiny", "small", "base"]:
+             raise ValueError(
+                 f"Swin variant needs to be one of [tiny, small, base]. Invalid {variant} was provided."
+             )
+         instance = self.variant_instance(variant)
+         self.encoder = instance(weights=weights)
          self.encoder.head = torch.nn.Identity()

-     def _create_fc(self):
-         self.fc = torch.nn.Linear(1024, self.num_classes)
+     def _create_fc(self, num_classes: int):
+         self.fc = torch.nn.Linear(1024, num_classes)
          kaiming(self.fc.weight, mode="fan_in", nonlinearity="relu")

      def forward(self, images: torch.Tensor):
{homa-0.2.95.dist-info → homa-0.3.2.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: homa
- Version: 0.2.95
+ Version: 0.3.2
  Summary: A curated list of machine learning and deep learning helpers.
  Author-email: Taha Shieenavaz <tahashieenavaz@gmail.com>
  Requires-Python: >=3.7