homa 0.2.95__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- homa/core/__init__.py +0 -0
- homa/core/concerns/MovesNetworkToDevice.py +13 -0
- homa/core/concerns/__init__.py +1 -0
- homa/device.py +4 -0
- homa/ensemble/Ensemble.py +4 -2
- homa/ensemble/concerns/{ReportsSize.py → ReportsEnsembleSize.py} +3 -3
- homa/ensemble/concerns/ReportsLogits.py +3 -1
- homa/ensemble/concerns/SavesEnsembleModels.py +13 -0
- homa/ensemble/concerns/StoresModels.py +6 -9
- homa/ensemble/concerns/__init__.py +2 -1
- homa/ensemble/utils.py +9 -0
- homa/graph/GraphAttention.py +13 -0
- homa/graph/__init__.py +1 -0
- homa/graph/modules/GraphAttentionHeadModule.py +37 -0
- homa/graph/modules/MultiHeadGraphAttentionModule.py +22 -0
- homa/graph/modules/__init__.py +2 -0
- homa/loss/Loss.py +4 -1
- homa/rl/DQN.py +2 -0
- homa/rl/DRQN.py +5 -0
- homa/rl/DiversityIsAllYouNeed.py +96 -0
- homa/rl/SoftActorCritic.py +64 -0
- homa/rl/__init__.py +4 -0
- homa/rl/buffers/Buffer.py +11 -0
- homa/rl/buffers/DiversityIsAllYouNeedBuffer.py +50 -0
- homa/rl/buffers/ImageBuffer.py +5 -0
- homa/rl/buffers/SoftActorCriticBuffer.py +56 -0
- homa/rl/buffers/__init__.py +4 -0
- homa/rl/buffers/concerns/HasRecordAlternatives.py +12 -0
- homa/rl/buffers/concerns/ResetsCollection.py +9 -0
- homa/rl/buffers/concerns/__init__.py +2 -0
- homa/rl/diayn/Actor.py +54 -0
- homa/rl/diayn/Critic.py +41 -0
- homa/rl/diayn/Discriminator.py +45 -0
- homa/rl/diayn/__init__.py +3 -0
- homa/rl/diayn/modules/ContinuousActorModule.py +42 -0
- homa/rl/diayn/modules/CriticModule.py +28 -0
- homa/rl/diayn/modules/DiscriminatorModule.py +24 -0
- homa/rl/diayn/modules/__init__.py +3 -0
- homa/rl/sac/SoftActor.py +69 -0
- homa/rl/sac/SoftCritic.py +100 -0
- homa/rl/sac/__init__.py +2 -0
- homa/rl/sac/modules/DualSoftCriticModule.py +22 -0
- homa/rl/sac/modules/SoftActorModule.py +35 -0
- homa/rl/sac/modules/SoftCriticModule.py +30 -0
- homa/rl/sac/modules/__init__.py +3 -0
- homa/vision/Resnet.py +3 -3
- homa/vision/Swin.py +17 -5
- homa/vision/modules/SwinModule.py +17 -9
- {homa-0.2.95.dist-info → homa-0.3.11.dist-info}/METADATA +1 -1
- {homa-0.2.95.dist-info → homa-0.3.11.dist-info}/RECORD +53 -17
- homa/torch/__init__.py +0 -1
- homa/torch/helpers.py +0 -6
- {homa-0.2.95.dist-info → homa-0.3.11.dist-info}/WHEEL +0 -0
- {homa-0.2.95.dist-info → homa-0.3.11.dist-info}/entry_points.txt +0 -0
- {homa-0.2.95.dist-info → homa-0.3.11.dist-info}/top_level.txt +0 -0
homa/rl/diayn/Discriminator.py
ADDED

@@ -0,0 +1,45 @@
+import torch
+import numpy
+from .modules import DiscriminatorModule
+from ...core.concerns import MovesNetworkToDevice
+
+
+class Discriminator(MovesNetworkToDevice):
+    def __init__(
+        self,
+        state_dimension: int,
+        hidden_dimension: int,
+        num_skills: int,
+        decay: float,
+        lr: float,
+    ):
+        self.num_skills: int = num_skills
+        self.network = DiscriminatorModule(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+        )
+        self.optimizer = torch.optim.AdamW(
+            self.network.parameters(), lr=lr, weight_decay=decay
+        )
+        self.criterion = torch.nn.CrossEntropyLoss()
+
+    def loss(self, states: torch.Tensor, skills_indices: torch.Tensor):
+        logits = self.network(states)
+        return self.criterion(logits, skills_indices)
+
+    @torch.no_grad()
+    def reward(self, state: torch.Tensor, skill_index: torch.Tensor):
+        logits = self.network(state)
+        probabilities = torch.nn.functional.log_softmax(logits, dim=-1)
+        entropy = numpy.log(1.0 / self.num_skills)
+        if skill_index.dim() == 1:
+            skill_index = skill_index.unsqueeze(-1)
+        reward = probabilities.gather(1, skill_index.long()) - entropy
+        return reward.squeeze(-1)
+
+    def train(self, states: torch.Tensor, skills_indices: torch.Tensor):
+        self.optimizer.zero_grad()
+        loss = self.loss(states=states, skills_indices=skills_indices)
+        loss.backward()
+        self.optimizer.step()
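For context on the hunk above: reward computes a DIAYN-style pseudo-reward, r(s, z) = log q(z | s) - log p(z), where q is the discriminator's softmax over skills and p(z) is a uniform prior, so log p(z) = log(1 / num_skills). A minimal standalone sketch of the same computation, using made-up tensors rather than anything from the package:

import math
import torch

num_skills = 4
logits = torch.randn(2, num_skills)                       # discriminator logits for 2 states
skills = torch.tensor([1, 3])                             # skill index sampled for each state
log_q = torch.nn.functional.log_softmax(logits, dim=-1)   # log q(z | s)
log_p = math.log(1.0 / num_skills)                        # uniform prior log p(z)
reward = log_q.gather(1, skills.unsqueeze(-1)).squeeze(-1) - log_p
print(reward.shape)  # torch.Size([2]); larger when the skill is easy to infer from the state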
homa/rl/diayn/modules/ContinuousActorModule.py
ADDED

@@ -0,0 +1,42 @@
+import torch
+
+
+class ContinuousActorModule(torch.nn.Module):
+    def __init__(
+        self,
+        state_dimension: int,
+        action_dimension: int,
+        hidden_dimension: int,
+        num_skills: int,
+        min_std: float,
+        max_std: float,
+    ):
+        super().__init__()
+        self.state_dimension: int = state_dimension
+        self.action_dimension: int = action_dimension
+        self.num_skills: int = num_skills
+        self.hidden_dimension: int = hidden_dimension
+        self.input_dimension: int = self.state_dimension + self.num_skills
+        self.min_std: float = min_std
+        self.max_std: float = max_std
+
+        self.phi = torch.nn.Sequential(
+            torch.nn.Linear(self.input_dimension, self.hidden_dimension),
+            torch.nn.ReLU(),
+            torch.nn.Linear(self.hidden_dimension, self.hidden_dimension),
+            torch.nn.ReLU(),
+            torch.nn.Linear(self.hidden_dimension, self.hidden_dimension),
+        )
+
+        self.mu = torch.nn.Linear(self.hidden_dimension, self.action_dimension)
+        self.xi = torch.nn.Linear(self.hidden_dimension, self.action_dimension)
+
+    def forward(self, state: torch.Tensor, skill: torch.Tensor) -> torch.Tensor:
+        # fix the size to be one state per batch
+        state = state.view(state.size(0), -1)
+
+        psi = torch.cat([state, skill], dim=-1)
+        features = self.phi(psi)
+        mean = self.mu(features)
+        std = self.xi(features).clamp(self.min_std, self.max_std)
+        return mean, std
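A reading of the module above (an inference, not stated in the diff): input_dimension = state_dimension + num_skills suggests the skill is passed as a one-hot (or probability) vector concatenated to the flattened state before the shared trunk. A small illustrative sketch with made-up shapes:

import torch

state = torch.randn(4, 8)                                                        # batch of 4 flattened states
skill = torch.nn.functional.one_hot(torch.tensor([2, 0, 1, 2]), num_classes=3).float()
actor_input = torch.cat([state, skill], dim=-1)                                  # shape [4, 8 + 3]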
homa/rl/diayn/modules/CriticModule.py
ADDED

@@ -0,0 +1,28 @@
+import torch
+
+
+class CriticModule(torch.nn.Module):
+    def __init__(
+        self,
+        state_dimension: int,
+        hidden_dimension: int,
+        num_skills: int,
+    ):
+        super().__init__()
+        self.state_dimension: int = state_dimension
+        self.num_skills: int = num_skills
+        self.hidden_dimension: int = hidden_dimension
+        self.input_dimension: int = self.state_dimension + self.num_skills
+
+        self.phi = torch.nn.Sequential(
+            torch.nn.Linear(self.input_dimension, self.hidden_dimension),
+            torch.nn.ReLU(),
+            torch.nn.Linear(self.hidden_dimension, self.hidden_dimension),
+            torch.nn.ReLU(),
+        )
+        self.fc = (torch.nn.Linear(self.hidden_dimension, 1),)
+
+    def forward(self, state: torch.Tensor, skill: torch.Tensor) -> torch.Tensor:
+        psi = torch.cat([state, skill], dim=-1)
+        features = self.phi(psi)
+        return self.fc(features).squeeze(-1)
homa/rl/diayn/modules/DiscriminatorModule.py
ADDED

@@ -0,0 +1,24 @@
+import torch
+from typing import Type
+
+
+class DiscriminatorModule(torch.nn.Module):
+    def __init__(self, state_dimension: int, hidden_dimension: int, num_skills: int):
+        super().__init__()
+        self.state_dimension: int = state_dimension
+        self.hidden_dimension: int = hidden_dimension
+        self.num_skills: int = num_skills
+
+        self.phi: Type[torch.nn.Sequential] = torch.nn.Sequential(
+            torch.nn.Linear(self.state_dimension, self.hidden_dimension),
+            torch.nn.ReLU(),
+            torch.nn.Linear(self.hidden_dimension, self.hidden_dimension),
+            torch.nn.ReLU(),
+        )
+        self.fc: Type[torch.nn.Linear] = torch.nn.Linear(
+            self.hidden_dimension, self.num_skills
+        )
+
+    def forward(self, state: torch.Tensor) -> torch.Tensor:
+        features: torch.Tensor = self.phi(state)
+        return self.fc(features)
homa/rl/sac/SoftActor.py
ADDED

@@ -0,0 +1,69 @@
+import torch
+import numpy
+from .modules import SoftActorModule
+
+
+class SoftActor:
+    def __init__(
+        self,
+        state_dimension: int,
+        hidden_dimension: int,
+        action_dimension: int,
+        lr: float,
+        weight_decay: float,
+        alpha: float,
+        min_std: float,
+        max_std: float,
+    ):
+        self.alpha: float = alpha
+
+        self.network = SoftActorModule(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            action_dimension=action_dimension,
+            min_std=min_std,
+            max_std=max_std,
+        )
+        self.optimizer = torch.optim.AdamW(
+            self.network.parameters(), lr=lr, weight_decay=weight_decay
+        )
+
+    def train(self, states: torch.Tensor, critic_network: torch.nn.Module):
+        self.optimizer.zero_grad()
+        loss = self.loss(states=states, critic_network=critic_network)
+        loss.backward()
+        self.optimizer.step()
+
+    def loss(
+        self, states: torch.Tensor, critic_network: torch.nn.Module
+    ) -> torch.Tensor:
+        actions, probabilities = self.sample(states)
+        q_alpha, q_beta = critic_network(states, actions)
+        q = torch.min(q_alpha, q_beta)
+        return (self.alpha * probabilities - q).mean()
+
+    def process_state(self, state: numpy.ndarray | torch.Tensor) -> torch.Tensor:
+        if isinstance(state, numpy.ndarray):
+            state = torch.from_numpy(state)
+
+        if state.ndim < 2:
+            state = state.unsqueeze(0)
+
+        return state
+
+    def sample(self, state: numpy.ndarray | torch.Tensor):
+        state = self.process_state(state)
+
+        mean, std = self.network(state)
+        # following line prevents standard deviations to be negative
+        std = std.exp()
+
+        distribution = torch.distributions.Normal(mean, std)
+
+        pre_tanh = distribution.rsample()
+        action = torch.tanh(pre_tanh)
+
+        probabilities = distribution.log_prob(pre_tanh).sum(dim=1, keepdim=True)
+        correction = torch.log(1 - action.pow(2) + 1e-6).sum(dim=1, keepdim=True)
+
+        return action, probabilities - correction
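For context: sample draws a pre-squash action u ~ Normal(mean, std), returns a = tanh(u), and subtracts the usual tanh change-of-variables term, so the second return value is log N(u; mean, std) - sum log(1 - tanh(u)^2 + 1e-6) per sample. A hypothetical usage sketch (argument values are illustrative, and the import path only follows the file layout shown in this diff):

import numpy
from homa.rl.sac.SoftActor import SoftActor

actor = SoftActor(
    state_dimension=8, hidden_dimension=64, action_dimension=2,
    lr=3e-4, weight_decay=0.0, alpha=0.2, min_std=-20.0, max_std=2.0,
)
state = numpy.zeros((1, 8), dtype=numpy.float32)
action, log_prob = actor.sample(state)   # action in (-1, 1), log_prob has shape [1, 1]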
homa/rl/sac/SoftCritic.py
ADDED

@@ -0,0 +1,100 @@
+import torch
+from torch.nn.functional import mse_loss as mse
+from typing import Type
+from .modules import DualSoftCriticModule
+from .SoftActor import SoftActor
+
+
+class SoftCritic:
+    def __init__(
+        self,
+        state_dimension: int,
+        hidden_dimension: int,
+        action_dimension: int,
+        lr: float,
+        weight_decay: float,
+        tau: float,
+        gamma: float,
+        alpha: float,
+    ):
+        self.tau: float = tau
+        self.gamma: float = gamma
+        self.alpha: float = alpha
+
+        self.network = DualSoftCriticModule(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            action_dimension=action_dimension,
+        )
+        self.target = DualSoftCriticModule(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            action_dimension=action_dimension,
+        )
+        self.optimizer = torch.optim.AdamW(
+            self.network.parameters(), lr=lr, weight_decay=weight_decay
+        )
+
+    def train(
+        self,
+        states: torch.Tensor,
+        actions: torch.Tensor,
+        rewards: torch.Tensor,
+        terminations: torch.Tensor,
+        next_states: torch.Tensor,
+        actor: torch.nn.Module,
+    ):
+        self.optimizer.zero_grad()
+        loss = self.loss(
+            states=states,
+            actions=actions,
+            rewards=rewards,
+            terminations=terminations,
+            next_states=next_states,
+            actor=actor,
+        )
+        loss.backward()
+        self.optimizer.step()
+
+    def loss(
+        self,
+        states: torch.Tensor,
+        actions: torch.Tensor,
+        rewards: torch.Tensor,
+        terminations: torch.Tensor,
+        next_states: torch.Tensor,
+        actor: torch.nn.Module,
+    ):
+        q_alpha, q_beta = self.target(states, actions)
+        target = self.calculate_target(
+            rewards=rewards,
+            terminations=terminations,
+            next_states=next_states,
+            actor=actor,
+        )
+        return mse(q_alpha, target) + mse(q_beta, target)
+
+    @torch.no_grad()
+    def calculate_target(
+        self,
+        rewards: torch.Tensor,
+        terminations: torch.Tensor,
+        next_states: torch.Tensor,
+        actor: SoftActor,
+    ):
+        next_actions, next_probabilities = actor.sample(next_states)
+        q_alpha, q_beta = self.target(next_states, next_actions)
+        q = torch.min(q_alpha, q_beta)
+        termination_mask = 1 - terminations
+        entropy_q = q - self.alpha * next_probabilities * termination_mask
+        return rewards + self.gamma * entropy_q
+
+    def soft_update(
+        self, network: Type[torch.nn.Module], target: Type[torch.nn.Module]
+    ):
+        for s, t in zip(network.parameters(), target.parameters()):
+            t.data.copy_(self.tau * s.data + (1 - self.tau) * t.data)
+
+    def update(self):
+        self.soft_update(self.network.alpha, self.target.alpha)
+        self.soft_update(self.network.beta, self.target.beta)
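For context: as written in the hunk, calculate_target forms the soft Bellman target rewards + gamma * (min(Q_alpha', Q_beta') - alpha * log pi(a'|s') * (1 - terminations)) using the target critic pair, and soft_update / update apply Polyak averaging, theta_target <- tau * theta + (1 - tau) * theta_target, to each of the two critics. A minimal sketch of that averaging step on plain modules (hypothetical, not from the package):

import torch

tau = 0.005
online = torch.nn.Linear(4, 1)
target = torch.nn.Linear(4, 1)
with torch.no_grad():
    for s, t in zip(online.parameters(), target.parameters()):
        t.copy_(tau * s + (1 - tau) * t)   # target slowly tracks the online critic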
homa/rl/sac/__init__.py
ADDED

homa/rl/sac/modules/DualSoftCriticModule.py
ADDED

@@ -0,0 +1,22 @@
+import torch
+from .SoftCriticModule import SoftCriticModule
+
+
+class DualSoftCriticModule(torch.nn.Module):
+    def __init__(
+        self, state_dimension: int, hidden_dimension: int, action_dimension: int
+    ):
+        super().__init__()
+        self.alpha = SoftCriticModule(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            action_dimension=action_dimension,
+        )
+        self.beta = SoftCriticModule(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            action_dimension=action_dimension,
+        )
+
+    def forward(self, state: torch.Tensor, action: torch.Tensor):
+        return self.alpha(state, action), self.beta(state, action)
homa/rl/sac/modules/SoftActorModule.py
ADDED

@@ -0,0 +1,35 @@
+import torch
+
+
+class SoftActorModule(torch.nn.Module):
+    def __init__(
+        self,
+        state_dimension: int,
+        hidden_dimension: int,
+        action_dimension: int,
+        min_std: float,
+        max_std: float,
+    ):
+        super().__init__()
+
+        self.state_dimension: int = state_dimension
+        self.hidden_dimension: int = hidden_dimension
+        self.action_dimension: int = action_dimension
+        self.min_std: float = float(min_std)
+        self.max_std: float = float(max_std)
+
+        self.phi = torch.nn.Sequential(
+            torch.nn.Linear(self.state_dimension, self.hidden_dimension),
+            torch.nn.ReLU(),
+            torch.nn.Linear(self.hidden_dimension, self.hidden_dimension),
+            torch.nn.ReLU(),
+        )
+        self.mu = torch.nn.Linear(self.hidden_dimension, self.action_dimension)
+        self.xi = torch.nn.Linear(self.hidden_dimension, self.action_dimension)
+
+    def forward(self, state: torch.Tensor):
+        features = self.phi(state)
+        mean = self.mu(features)
+        std = self.mu(features)
+        std = std.clamp(self.min_std, self.max_std)
+        return mean, std
homa/rl/sac/modules/SoftCriticModule.py
ADDED

@@ -0,0 +1,30 @@
+import torch
+
+
+class SoftCriticModule(torch.nn.Module):
+    def __init__(
+        self,
+        state_dimension: int,
+        hidden_dimension: int,
+        action_dimension: int,
+    ):
+        super().__init__()
+
+        self.state_dimension: int = state_dimension
+        self.action_dimension: int = action_dimension
+        self.hidden_dimension: int = hidden_dimension
+
+        self.phi = torch.nn.Sequential(
+            torch.nn.Linear(
+                self.state_dimension + self.action_dimension, self.hidden_dimension
+            ),
+            torch.nn.ReLU(),
+            torch.nn.Linear(self.hidden_dimension, self.hidden_dimension),
+            torch.nn.ReLU(),
+        )
+        self.fc = torch.nn.Linear(self.hidden_dimension, 1)
+
+    def forward(self, state: torch.Tensor, action: torch.Tensor):
+        psi = torch.cat([state, action], dim=1)
+        features = self.phi(psi)
+        return self.fc(features)
homa/vision/Resnet.py
CHANGED

@@ -2,12 +2,12 @@ import torch
 from .modules import ResnetModule
 from .Classifier import Classifier
 from .concerns import Trainable, ReportsMetrics
-from ..
+from ..core.concerns import MovesNetworkToDevice
 
 
-class Resnet(Classifier, Trainable, ReportsMetrics):
+class Resnet(Classifier, Trainable, ReportsMetrics, MovesNetworkToDevice):
     def __init__(self, num_classes: int, lr: float = 0.001):
         super().__init__()
-        self.network = ResnetModule(num_classes)
+        self.network = ResnetModule(num_classes)
         self.criterion = torch.nn.CrossEntropyLoss()
         self.optimizer = torch.optim.SGD(self.network.parameters(), lr=lr, momentum=0.9)
homa/vision/Swin.py
CHANGED

@@ -2,12 +2,24 @@ import torch
 from .Classifier import Classifier
 from .concerns import Trainable, ReportsMetrics
 from .modules import SwinModule
-from ..
+from ..core.concerns import MovesNetworkToDevice
 
 
-class Swin(Classifier, Trainable, ReportsMetrics):
-    def __init__(
+class Swin(Classifier, Trainable, ReportsMetrics, MovesNetworkToDevice):
+    def __init__(
+        self,
+        num_classes: int,
+        lr: float = 0.0001,
+        decay: float = 0.0,
+        variant: str = "base",
+        weights="DEFAULT",
+    ):
         super().__init__()
-        self.
-        self.
+        self.num_classes = num_classes
+        self.network = SwinModule(
+            num_classes=self.num_classes, variant=variant, weights=weights
+        )
+        self.optimizer = torch.optim.AdamW(
+            self.network.parameters(), lr=lr, weight_decay=decay
+        )
         self.criterion = torch.nn.CrossEntropyLoss()
homa/vision/modules/SwinModule.py
CHANGED

@@ -1,21 +1,29 @@
 import torch
-from torchvision.models import swin_v2_b
+from torchvision.models import swin_v2_b, swin_v2_s, swin_v2_t
 from torch.nn.init import kaiming_uniform_ as kaiming
 
 
 class SwinModule(torch.nn.Module):
-    def __init__(self, num_classes: int):
+    def __init__(self, num_classes: int, variant: str, weights):
         super().__init__()
-        self.
-        self.
-        self._create_fc()
+        self._create_encoder(variant=variant, weights=weights)
+        self._create_fc(num_classes=num_classes)
 
-    def
-
+    def variant_instance(self, variant: str):
+        variant_map = {"tiny": swin_v2_t, "small": swin_v2_s, "base": swin_v2_b}
+        return variant_map.get(variant)
+
+    def _create_encoder(self, variant: str, weights):
+        if variant not in ["tiny", "small", "base"]:
+            raise ValueError(
+                f"Swin variant needs to be one of [tiny, small, base]. Invalid {variant} was provided."
+            )
+        instance = self.variant_instnace(variant)
+        self.encoder = instance(weights=weights)
         self.encoder.head = torch.nn.Identity()
 
-    def _create_fc(self):
-        self.fc = torch.nn.Linear(1024,
+    def _create_fc(self, num_classes: int):
+        self.fc = torch.nn.Linear(1024, num_classes)
         kaiming(self.fc.weight, mode="fan_in", nonlinearity="relu")
 
     def forward(self, images: torch.Tensor):
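For reference, the variant map introduced above resolves each name to a torchvision Swin V2 constructor; a small standalone illustration (hypothetical, separate from the package's own helper):

from torchvision.models import swin_v2_b, swin_v2_s, swin_v2_t

variant_map = {"tiny": swin_v2_t, "small": swin_v2_s, "base": swin_v2_b}
encoder = variant_map["base"](weights="DEFAULT")   # swin_v2_b with its default pretrained weights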

{homa-0.2.95.dist-info → homa-0.3.11.dist-info}/RECORD
CHANGED

@@ -1,5 +1,5 @@
 homa/__init__.py,sha256=NBYFKizG8UASiz5HLsEBqzXNGlWr78xm4sLr5hxKvjU,46
-homa/device.py,sha256=
+homa/device.py,sha256=dpKI-ah_kPgNfFH_ism8YXHndEndGngBrTVnuZZ2J2I,408
 homa/settings.py,sha256=CPZDPvs1380O7SY7FcSKol8kBVFVVYFgSJl3YEyJuZ0,263
 homa/utils.py,sha256=dPp6TItJwWxBqxmkMzUuCtX_BzdPT-kMOZyXRGVMCbQ,70
 homa/activations/APLU.py,sha256=cUf6LUjY8TewXe_V1avO_7IcOtY66Hd6Dyk_1K4R3Ms,1555
@@ -71,29 +71,65 @@ homa/cli/Commands/__init__.py,sha256=PYKkcG06R5LqLnp2x8otuimzRpL4oMbziL3xEMkCffc
 homa/cli/namespaces/CacheNamespace.py,sha256=QXGljzj287stzTx0y_MXnqvCgPLqd7WjSPop2WDe14E,784
 homa/cli/namespaces/MakeNamespace.py,sha256=5G6LHk3lDkXROz7uq4jYE0DyO_V7JvnhJ33IFCiqYro,590
 homa/cli/namespaces/__init__.py,sha256=zAKUGPH4wcacxfH5Qvidp-uOuHdfzhan6kvVI6eMKA8,84
-homa/
+homa/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+homa/core/concerns/MovesNetworkToDevice.py,sha256=OPMvO7scsM6NNy_fM0cJdkRdoVc-b2j6l4bz88cBif0,348
+homa/core/concerns/__init__.py,sha256=O9OXMIMYrkIgp11lAyEv-OgT3Wq0IvNdDVZr2bOmpQU,55
+homa/ensemble/Ensemble.py,sha256=mrqwbEm8OtiBmEgKuO6RzO1V8v80vrQFIJ4WHl8Yqgk,356
 homa/ensemble/__init__.py,sha256=1pk2W-NbgfDFh9WLKZVLUk2E3PTjVZ5Bap9dQEnrs9o,31
+homa/ensemble/utils.py,sha256=nn6eAgGW7ZafjjOVJWzGUWE0XYeyJAOMNEHm-lHxd6A,200
 homa/ensemble/concerns/CalculatesMetricNecessities.py,sha256=QccROg_FOp_X2T_lZDg8p1DMZhPYdO-7aEdnebRXMsY,825
 homa/ensemble/concerns/PredictsProbabilities.py,sha256=7rmI66DzE7-QGoJgZEk-9fu5YQvJW-4ZnMn_dWEEhqU,440
 homa/ensemble/concerns/ReportsClassificationMetrics.py,sha256=bg__cdCKp2U1H9qN1aOJH4BoX98oIvt8XaPDGApJhSM,395
 homa/ensemble/concerns/ReportsEnsembleAccuracy.py,sha256=AX5X3VGOm7DfdonW0N7FFgUwEr7wnsojRSVEULEii7c,380
 homa/ensemble/concerns/ReportsEnsembleF1.py,sha256=hdtdCQrWaFJNUn1KP9cAmi_q_EA4FYnpkBMlYLjzRZg,296
 homa/ensemble/concerns/ReportsEnsembleKappa.py,sha256=ZRbtrFCTD84EDql6ZL1xeWtTLFxpO5Y5tQaUlR6_0jw,300
-homa/ensemble/concerns/
-homa/ensemble/concerns/
-homa/ensemble/concerns/
-homa/ensemble/concerns/
+homa/ensemble/concerns/ReportsEnsembleSize.py,sha256=eIweQHpLcfGnNLwiMuTho-9rDgxV0xXGHPTOaEOABzw,240
+homa/ensemble/concerns/ReportsLogits.py,sha256=sJZGJwTISZo2DFmJbI5zqhrt7CblNi09iGn1zaEk-ro,593
+homa/ensemble/concerns/SavesEnsembleModels.py,sha256=VIXT9wJ8FiCspIvI2-F4WPa6mBBe9SWvMLFyad3TgRg,275
+homa/ensemble/concerns/StoresModels.py,sha256=uAYbdUtadZsAJ9-Fj4jJFLWC23qfiXKo1mBm6-PZkN4,963
+homa/ensemble/concerns/__init__.py,sha256=IF5mHIgzCuCpA2EmpkctbjAr0kYW4P96v7RffK2V_iQ,548
+homa/graph/GraphAttention.py,sha256=oPXuc1s-3BXwGkHuomEIxnOcZSRBbL8b8fO0432RdDo,478
+homa/graph/__init__.py,sha256=NCtMUB-awe9UvkwDYqWXxTAZ1RW-AwSW1DD9X_kFkD0,43
+homa/graph/modules/GraphAttentionHeadModule.py,sha256=R47ScMnOgpRLNR9encaqbM8tFiYfb2UA2X18f55NMek,1397
+homa/graph/modules/MultiHeadGraphAttentionModule.py,sha256=tmxCGLxIVlvn_mvnPsqT8zrSCH_UVEIMLkR9VHky670,792
+homa/graph/modules/__init__.py,sha256=R-wuNFJvRZ8U-6v7GNGrigxPfh1BlupVZo-MPd0HiR8,136
 homa/loss/LogitNormLoss.py,sha256=LJMzRA1WoJ7aDYTV-FYGhgo8DMkcpv7e8_74qiJ4zT8,386
-homa/loss/Loss.py,sha256=
+homa/loss/Loss.py,sha256=OROjusRHg4F3PA92SjU1utCgS1D_5KqELlcVFhvQOoU,53
 homa/loss/__init__.py,sha256=4mPVzme2_-M64bgBu1cANIfBFAL0voa5I71-ceMr_qk,64
-homa/
-homa/
+homa/rl/DQN.py,sha256=PaNq9Z1K87IQ7Y7mhiJ1CE4TofgV7c7m1py8qT09vE4,20
+homa/rl/DRQN.py,sha256=zooojji9aeeubOP7cRPSHg31u2Assxk-qjXyGUWIO3A,49
+homa/rl/DiversityIsAllYouNeed.py,sha256=8yKzlVdLisForGyXqxaXUAWG_dozq7dNY8MBasCvniE,3322
+homa/rl/SoftActorCritic.py,sha256=N8EsiYbsLH-dpT2EmqdYFG9KvHNfO3JX8SG2LPTy94s,1962
+homa/rl/__init__.py,sha256=EaNDkIzLH1Oy0Wc0aAyyVs4HVMcZS1tdHDh631LKSXs,146
+homa/rl/buffers/Buffer.py,sha256=wOk8MH0Wf0cpvavpHIK2O7PrbGP6MwHTH5YFkq2Ints,288
+homa/rl/buffers/DiversityIsAllYouNeedBuffer.py,sha256=Nwcqs3Q10x6OKZ-zWug4IcBc6RR1TwEIybuFQOtmftA,1612
+homa/rl/buffers/ImageBuffer.py,sha256=HSmMt82hmkL3ooBYo7c6YUtTsMz9TAA8CvPh3y8z3yg,65
+homa/rl/buffers/SoftActorCriticBuffer.py,sha256=iDC2C5XFvONT3f7YX_gYXQJGU9wz2usvPOVGbQUd22M,1796
+homa/rl/buffers/__init__.py,sha256=h1AkCHs6isXbNtxpaZfLp6YudHj1KlnOvURE64vhRa4,190
+homa/rl/buffers/concerns/HasRecordAlternatives.py,sha256=D5aVlPZlnGm0GyGtikKb4wZqyO6zpyqR1IOETmAgLx4,362
+homa/rl/buffers/concerns/ResetsCollection.py,sha256=bZ8q4czYXo1jMtVCnnlG69OgiJ0AqSGY6CiKzJC6xtQ,215
+homa/rl/buffers/concerns/__init__.py,sha256=g9EKH503NhO0clJhxRMFD-upSw1nkzjKLCxH4SVE-wk,104
+homa/rl/diayn/Actor.py,sha256=SYh1gKQ6DgKFYYPq0BEV10B1QVNQw6bDk08GmdwazNc,1868
+homa/rl/diayn/Critic.py,sha256=ML2FQj6dH8gJakDHlDQRCOChUA2z4pPDM52zCrp_6xk,1188
+homa/rl/diayn/Discriminator.py,sha256=m2faov_tZned7Tcogci5X_prHmncqyqPuPrm3xWZWIo,1566
+homa/rl/diayn/__init__.py,sha256=HV0LWJ-FbTPNf3kBH_GFWoxUGFwvsi6SrHRsz7QRYVQ,93
+homa/rl/diayn/modules/ContinuousActorModule.py,sha256=yeC117I5gkXZSidQhjwakjiY7Gi8ycZQeGDq8uzKlDI,1522
+homa/rl/diayn/modules/CriticModule.py,sha256=OUenwCG0dG4PnK7Iq-jy7oCTv_Cn9s7bXRpro6Pvb40,956
+homa/rl/diayn/modules/DiscriminatorModule.py,sha256=D58dKBv4f6gtrpqMKLK8XAZpiMqKfS4sG6s3QcF8iGE,891
+homa/rl/diayn/modules/__init__.py,sha256=1Pgjr4FT5WG-AMh26NPEfbf5pK6I02B1x8HYsgyUCJ4,149
+homa/rl/sac/SoftActor.py,sha256=CxR58IFrZ6xlmBj_gq_abZfgdzlVD71c6wA6wQiVL2c,2142
+homa/rl/sac/SoftCritic.py,sha256=wFIunTgKGBy64Igu7zuvE2BvGz2e-DTplviLyq4tQ7M,3031
+homa/rl/sac/__init__.py,sha256=8EIkOcVvxN94gGzcZoX2XTnvTsHqW6yBaZ2RdFwIveM,68
+homa/rl/sac/modules/DualSoftCriticModule.py,sha256=Ax28i7U-KnP4QJig-AeeCfpPYNvTT3DfvRMJI-f-TGY,749
+homa/rl/sac/modules/SoftActorModule.py,sha256=AiWnsWkmQONjOAWAp06eO-lLWEYNJDmx8FSjPKTcjI0,1152
+homa/rl/sac/modules/SoftCriticModule.py,sha256=aOfhDZTB5og-BLTsmdBdIcRufygCJUas7P-ikBvWQ34,928
+homa/rl/sac/modules/__init__.py,sha256=h-22B5CAK1xhn75tolI5J5sQMxl--kOXbQ6r_JfHIOA,147
 homa/vision/Classifier.py,sha256=bAypqREQVuPamnc8hpbLCwmW9Uly3T1rvrlbMxXp1eA,61
 homa/vision/Model.py,sha256=JIeVpHJwirHfsDfYYbLsu0kt7bGf4nhMQGIOagUDKw4,22
-homa/vision/Resnet.py,sha256=
+homa/vision/Resnet.py,sha256=BuDMMcu8J_mVlEHaMDche2mVl-SApT80OKmoDA4eAPQ,535
 homa/vision/StochasticClassifier.py,sha256=6-o0TaH4iWXiPFefL7DOdLr3ZrTnjnJ9PIgQLlygN8w,497
 homa/vision/StochasticSwin.py,sha256=FggzfaVYrP4fnjAFcdMpDozwQHc7CQhl3iRw78oQh0o,425
-homa/vision/Swin.py,sha256=
+homa/vision/Swin.py,sha256=8sNm8S3uzyTIhu6msp4hUV0dKIcTBid_EBNr7H_iK20,789
 homa/vision/__init__.py,sha256=w5OkcmdU6Ik5wHIJzeV1Z2UElQtvCsUZks1Q-xViSVg,153
 homa/vision/utils.py,sha256=WB2b7eMDaf6UO3SuS7cB6IJk-9NRQesLavuzWUZRZyg,389
 homa/vision/concerns/HasLabels.py,sha256=fM6nHLeQaEaWDlV6R8NQ5hgOSiwspPxOIwj-nvYXbP0,321
@@ -104,10 +140,10 @@ homa/vision/concerns/ReportsMetrics.py,sha256=93Hw_JBUbwfkrJNJA1xFSQ4cqRwzbSv4nP
 homa/vision/concerns/Trainable.py,sha256=SRCW3XpG9_DQgubyqhALlYDHwAWNzVVFjshUv1ecuEQ,988
 homa/vision/concerns/__init__.py,sha256=mrw1YvN-GpQPvMwDF00KxnFkksPKo23RWM4KRioURsg,234
 homa/vision/modules/ResnetModule.py,sha256=eFudBnILD6OmgQtcW_CQQ8aZ62NEa4HyZ15-lobTtt0,712
-homa/vision/modules/SwinModule.py,sha256=
+homa/vision/modules/SwinModule.py,sha256=3ZtUcfyJt0NMGmIlGpN35MIJG9QsgcLdFniZH7NxZQo,1227
 homa/vision/modules/__init__.py,sha256=zVMYB9IAO_xZylC1-N3p8ymHgEkAE2sBbuVz8K5Y1kk,74
-homa-0.
-homa-0.
-homa-0.
-homa-0.
-homa-0.
+homa-0.3.11.dist-info/METADATA,sha256=SvSxNXB1IsX3N5IfhOsnWYtvhjpfzauJPanVH7i5cRs,1760
+homa-0.3.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+homa-0.3.11.dist-info/entry_points.txt,sha256=tJZzjs-f2QvFe3ES8Qta8IE5sAbeE8-cyZ_UtbgqG4s,51
+homa-0.3.11.dist-info/top_level.txt,sha256=tmOfy2tuaAwc3W5-i6j61_vYJsXgR4ivBWkhJ3ZtJDc,5
+homa-0.3.11.dist-info/RECORD,,
homa/torch/__init__.py
DELETED

@@ -1 +0,0 @@
-from .helpers import *
homa/torch/helpers.py
DELETED

{homa-0.2.95.dist-info → homa-0.3.11.dist-info}/WHEEL
File without changes

{homa-0.2.95.dist-info → homa-0.3.11.dist-info}/entry_points.txt
File without changes

{homa-0.2.95.dist-info → homa-0.3.11.dist-info}/top_level.txt
File without changes