homa 0.2.95__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- homa/core/__init__.py +0 -0
- homa/core/concerns/MovesNetworkToDevice.py +13 -0
- homa/core/concerns/TracksTime.py +7 -0
- homa/core/concerns/__init__.py +2 -0
- homa/device.py +5 -0
- homa/ensemble/Ensemble.py +4 -2
- homa/ensemble/concerns/CalculatesMetricNecessities.py +2 -2
- homa/ensemble/concerns/PredictsProbabilities.py +2 -2
- homa/ensemble/concerns/ReportsClassificationMetrics.py +2 -1
- homa/ensemble/concerns/ReportsEnsembleAccuracy.py +2 -2
- homa/ensemble/concerns/ReportsEnsembleF1.py +2 -2
- homa/ensemble/concerns/ReportsEnsembleKappa.py +2 -2
- homa/ensemble/concerns/ReportsEnsembleSize.py +11 -0
- homa/ensemble/concerns/ReportsLogits.py +26 -5
- homa/ensemble/concerns/SavesEnsembleModels.py +13 -0
- homa/ensemble/concerns/StoresModels.py +8 -11
- homa/ensemble/concerns/__init__.py +2 -1
- homa/ensemble/utils.py +9 -0
- homa/graph/GraphAttention.py +13 -0
- homa/graph/__init__.py +1 -0
- homa/graph/modules/GraphAttentionHeadModule.py +37 -0
- homa/graph/modules/MultiHeadGraphAttentionModule.py +22 -0
- homa/graph/modules/__init__.py +2 -0
- homa/loss/Loss.py +4 -1
- homa/rl/DQN.py +2 -0
- homa/rl/DRQN.py +5 -0
- homa/rl/DiversityIsAllYouNeed.py +96 -0
- homa/rl/SoftActorCritic.py +67 -0
- homa/rl/__init__.py +4 -0
- homa/rl/buffers/Buffer.py +13 -0
- homa/rl/buffers/DiversityIsAllYouNeedBuffer.py +50 -0
- homa/rl/buffers/ImageBuffer.py +5 -0
- homa/rl/buffers/SoftActorCriticBuffer.py +64 -0
- homa/rl/buffers/__init__.py +4 -0
- homa/rl/buffers/concerns/HasRecordAlternatives.py +12 -0
- homa/rl/buffers/concerns/ResetsCollection.py +9 -0
- homa/rl/buffers/concerns/__init__.py +2 -0
- homa/rl/diayn/Actor.py +54 -0
- homa/rl/diayn/Critic.py +41 -0
- homa/rl/diayn/Discriminator.py +45 -0
- homa/rl/diayn/__init__.py +3 -0
- homa/rl/diayn/modules/ContinuousActorModule.py +42 -0
- homa/rl/diayn/modules/CriticModule.py +28 -0
- homa/rl/diayn/modules/DiscriminatorModule.py +24 -0
- homa/rl/diayn/modules/__init__.py +3 -0
- homa/rl/sac/SoftActor.py +70 -0
- homa/rl/sac/SoftCritic.py +98 -0
- homa/rl/sac/__init__.py +2 -0
- homa/rl/sac/modules/DualSoftCriticModule.py +22 -0
- homa/rl/sac/modules/SoftActorModule.py +35 -0
- homa/rl/sac/modules/SoftCriticModule.py +30 -0
- homa/rl/sac/modules/__init__.py +3 -0
- homa/rl/utils.py +7 -0
- homa/vision/Resnet.py +3 -3
- homa/vision/Swin.py +17 -5
- homa/vision/modules/SwinModule.py +17 -9
- {homa-0.2.95.dist-info → homa-0.3.2.dist-info}/METADATA +1 -1
- {homa-0.2.95.dist-info → homa-0.3.2.dist-info}/RECORD +61 -23
- homa/ensemble/concerns/ReportsSize.py +0 -11
- homa/torch/__init__.py +0 -1
- homa/torch/helpers.py +0 -6
- {homa-0.2.95.dist-info → homa-0.3.2.dist-info}/WHEEL +0 -0
- {homa-0.2.95.dist-info → homa-0.3.2.dist-info}/entry_points.txt +0 -0
- {homa-0.2.95.dist-info → homa-0.3.2.dist-info}/top_level.txt +0 -0
homa/rl/buffers/concerns/HasRecordAlternatives.py ADDED
@@ -0,0 +1,12 @@
+class HasRecordAlternatives:
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def add(self, *args, **kwargs) -> None:
+        self.record(*args, **kwargs)
+
+    def push(self, *args, **kwargs) -> None:
+        self.record(*args, **kwargs)
+
+    def append(self, *args, **kwargs) -> None:
+        self.record(*args, **kwargs)

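The concern above simply aliases `add`, `push`, and `append` to whatever `record` method the host buffer defines. A minimal sketch of how a buffer might mix it in follows; `ListBuffer` is hypothetical, and only the import path is taken from the new module layout.

```python
from homa.rl.buffers.concerns import HasRecordAlternatives  # assumed export path


class ListBuffer(HasRecordAlternatives):
    """Hypothetical buffer: defines record(); the mixin supplies the aliases."""

    def __init__(self):
        super().__init__()
        self.items = []

    def record(self, transition) -> None:
        self.items.append(transition)


buffer = ListBuffer()
buffer.add((0, 1))     # add, push, and append all funnel into record()
buffer.push((1, 2))
buffer.append((2, 3))
assert len(buffer.items) == 3
```
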
homa/rl/diayn/Actor.py ADDED
@@ -0,0 +1,54 @@
+import torch
+from torch.distributions import Normal
+from .modules import ContinuousActorModule
+from ...core.concerns import MovesNetworkToDevice
+
+
+class Actor(MovesNetworkToDevice):
+    def __init__(
+        self,
+        state_dimension: int,
+        action_dimension: int,
+        num_skills: int,
+        hidden_dimension: int,
+        lr: float,
+        decay: float,
+        epsilon: float,
+        min_std: float,
+        max_std: float,
+    ):
+        self.epsilon: float = epsilon
+        self.network = ContinuousActorModule(
+            state_dimension=state_dimension,
+            action_dimension=action_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+            min_std=min_std,
+            max_std=max_std,
+        )
+        self.optimizer = torch.optim.AdamW(
+            self.network.parameters(), lr=lr, weight_decay=decay
+        )
+
+    def action(self, state: torch.Tensor, skill: torch.Tensor):
+        mean, std = self.network(state, skill)
+        std = std.exp()
+        distribution = Normal(mean, std)
+        raw_action = distribution.rsample()
+        action = torch.tanh(raw_action)
+        corrected_probabilities = torch.log(1.0 - action.pow(2) + self.epsilon)
+        probabilities = distribution.log_prob(raw_action) - corrected_probabilities
+        probabilities = probabilities.sum(dim=-1, keepdim=True)
+        return action, probabilities
+
+    def train(self, advantages: torch.Tensor, probabilities: torch.Tensor) -> float:
+        self.optimizer.zero_grad()
+        loss = self.loss(advantages=advantages, probabilities=probabilities)
+        loss.backward()
+        self.optimizer.step()
+        return loss.item()
+
+    def loss(
+        self, advantages: torch.Tensor, probabilities: torch.Tensor
+    ) -> torch.Tensor:
+        return -(probabilities * advantages.detach()).mean()

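A hedged usage sketch for the new DIAYN actor follows. It assumes `Actor` is re-exported from `homa.rl.diayn`, and every dimension and hyperparameter below is a placeholder; note that `min_std`/`max_std` bound the value that `action()` later exponentiates, so they effectively act as log-std limits.

```python
import torch
from homa.rl.diayn import Actor  # assumes Actor is re-exported here

actor = Actor(
    state_dimension=8,
    action_dimension=2,
    num_skills=4,
    hidden_dimension=64,
    lr=3e-4,
    decay=0.0,
    epsilon=1e-6,
    min_std=-20.0,
    max_std=2.0,
)

state = torch.randn(1, 8)                                           # one observation
skill = torch.nn.functional.one_hot(torch.tensor([1]), 4).float()   # one-hot skill z
action, log_prob = actor.action(state, skill)                       # tanh-squashed action and its log-probability
```
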
homa/rl/diayn/Critic.py ADDED
@@ -0,0 +1,41 @@
+import torch
+from .modules import CriticModule
+from ...core.concerns import MovesNetworkToDevice
+
+
+class Critic(MovesNetworkToDevice):
+    def __init__(
+        self,
+        state_dimension: int,
+        hidden_dimension: int,
+        num_skills: int,
+        lr: float,
+        decay: float,
+        gamma: float,
+    ):
+        self.network = CriticModule(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+        )
+        self.optimizer = torch.optim.AdamW(
+            self.network.parameters(), lr=lr, weight_decay=decay
+        )
+        self.criterion = torch.nn.SmoothL1Loss()
+        self.gamma: float = gamma
+
+    def train(self, advantages: torch.Tensor):
+        self.optimizer.zero_grad()
+        loss = self.loss(advantages=advantages)
+        loss.backward()
+        self.optimizer.step()
+
+    def loss(self, advantages: torch.Tensor):
+        return advantages.pow(2).mean()
+
+    def values(self, states: torch.Tensor, skills: torch.Tensor):
+        return self.network(states, skills)
+
+    @torch.no_grad()
+    def values_(self, *args, **kwargs):
+        return self.values(*args, **kwargs)

homa/rl/diayn/Discriminator.py ADDED
@@ -0,0 +1,45 @@
+import torch
+import numpy
+from .modules import DiscriminatorModule
+from ...core.concerns import MovesNetworkToDevice
+
+
+class Discriminator(MovesNetworkToDevice):
+    def __init__(
+        self,
+        state_dimension: int,
+        hidden_dimension: int,
+        num_skills: int,
+        decay: float,
+        lr: float,
+    ):
+        self.num_skills: int = num_skills
+        self.network = DiscriminatorModule(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+        )
+        self.optimizer = torch.optim.AdamW(
+            self.network.parameters(), lr=lr, weight_decay=decay
+        )
+        self.criterion = torch.nn.CrossEntropyLoss()
+
+    def loss(self, states: torch.Tensor, skills_indices: torch.Tensor):
+        logits = self.network(states)
+        return self.criterion(logits, skills_indices)
+
+    @torch.no_grad()
+    def reward(self, state: torch.Tensor, skill_index: torch.Tensor):
+        logits = self.network(state)
+        probabilities = torch.nn.functional.log_softmax(logits, dim=-1)
+        entropy = numpy.log(1.0 / self.num_skills)
+        if skill_index.dim() == 1:
+            skill_index = skill_index.unsqueeze(-1)
+        reward = probabilities.gather(1, skill_index.long()) - entropy
+        return reward.squeeze(-1)
+
+    def train(self, states: torch.Tensor, skills_indices: torch.Tensor):
+        self.optimizer.zero_grad()
+        loss = self.loss(states=states, skills_indices=skills_indices)
+        loss.backward()
+        self.optimizer.step()

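The discriminator's `reward` is the DIAYN-style intrinsic reward `log q(z|s) - log p(z)` with a uniform skill prior. A hedged sketch follows, assuming `Discriminator` is re-exported from `homa.rl.diayn` and using illustrative shapes only.

```python
import torch
from homa.rl.diayn import Discriminator  # assumes Discriminator is re-exported here

discriminator = Discriminator(
    state_dimension=8, hidden_dimension=64, num_skills=4, decay=0.0, lr=3e-4
)

states = torch.randn(5, 8)          # batch of visited states
skills = torch.randint(0, 4, (5,))  # index of the skill that produced each state

# Intrinsic reward log q(z|s) - log p(z): higher when a state identifies its skill.
rewards = discriminator.reward(states, skills)

# One gradient step on the skill-classification objective (cross-entropy).
discriminator.train(states=states, skills_indices=skills)
```
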
homa/rl/diayn/modules/ContinuousActorModule.py ADDED
@@ -0,0 +1,42 @@
+import torch
+
+
+class ContinuousActorModule(torch.nn.Module):
+    def __init__(
+        self,
+        state_dimension: int,
+        action_dimension: int,
+        hidden_dimension: int,
+        num_skills: int,
+        min_std: float,
+        max_std: float,
+    ):
+        super().__init__()
+        self.state_dimension: int = state_dimension
+        self.action_dimension: int = action_dimension
+        self.num_skills: int = num_skills
+        self.hidden_dimension: int = hidden_dimension
+        self.input_dimension: int = self.state_dimension + self.num_skills
+        self.min_std: float = min_std
+        self.max_std: float = max_std
+
+        self.phi = torch.nn.Sequential(
+            torch.nn.Linear(self.input_dimension, self.hidden_dimension),
+            torch.nn.ReLU(),
+            torch.nn.Linear(self.hidden_dimension, self.hidden_dimension),
+            torch.nn.ReLU(),
+            torch.nn.Linear(self.hidden_dimension, self.hidden_dimension),
+        )
+
+        self.mu = torch.nn.Linear(self.hidden_dimension, self.action_dimension)
+        self.xi = torch.nn.Linear(self.hidden_dimension, self.action_dimension)
+
+    def forward(self, state: torch.Tensor, skill: torch.Tensor) -> torch.Tensor:
+        # fix the size to be one state per batch
+        state = state.view(state.size(0), -1)
+
+        psi = torch.cat([state, skill], dim=-1)
+        features = self.phi(psi)
+        mean = self.mu(features)
+        std = self.xi(features).clamp(self.min_std, self.max_std)
+        return mean, std

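Both DIAYN modules condition on the skill by concatenating a one-hot skill vector onto the flattened state, which is why `input_dimension` is `state_dimension + num_skills`. A self-contained sketch of that conditioning (all shapes are illustrative):

```python
import torch

# One-hot skill concatenated onto the flattened state, mirroring the forward()
# of ContinuousActorModule/CriticModule.
state = torch.randn(4, 8)                                                                # batch of 4 states
skill = torch.nn.functional.one_hot(torch.tensor([0, 1, 2, 3]), num_classes=4).float()  # one skill per state
psi = torch.cat([state.view(state.size(0), -1), skill], dim=-1)
assert psi.shape == (4, 8 + 4)  # state_dimension + num_skills features per row
```
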
homa/rl/diayn/modules/CriticModule.py ADDED
@@ -0,0 +1,28 @@
+import torch
+
+
+class CriticModule(torch.nn.Module):
+    def __init__(
+        self,
+        state_dimension: int,
+        hidden_dimension: int,
+        num_skills: int,
+    ):
+        super().__init__()
+        self.state_dimension: int = state_dimension
+        self.num_skills: int = num_skills
+        self.hidden_dimension: int = hidden_dimension
+        self.input_dimension: int = self.state_dimension + self.num_skills
+
+        self.phi = torch.nn.Sequential(
+            torch.nn.Linear(self.input_dimension, self.hidden_dimension),
+            torch.nn.ReLU(),
+            torch.nn.Linear(self.hidden_dimension, self.hidden_dimension),
+            torch.nn.ReLU(),
+        )
+        self.fc = (torch.nn.Linear(self.hidden_dimension, 1),)
+
+    def forward(self, state: torch.Tensor, skill: torch.Tensor) -> torch.Tensor:
+        psi = torch.cat([state, skill], dim=-1)
+        features = self.phi(psi)
+        return self.fc(features).squeeze(-1)

homa/rl/diayn/modules/DiscriminatorModule.py ADDED
@@ -0,0 +1,24 @@
+import torch
+from typing import Type
+
+
+class DiscriminatorModule(torch.nn.Module):
+    def __init__(self, state_dimension: int, hidden_dimension: int, num_skills: int):
+        super().__init__()
+        self.state_dimension: int = state_dimension
+        self.hidden_dimension: int = hidden_dimension
+        self.num_skills: int = num_skills
+
+        self.phi: Type[torch.nn.Sequential] = torch.nn.Sequential(
+            torch.nn.Linear(self.state_dimension, self.hidden_dimension),
+            torch.nn.ReLU(),
+            torch.nn.Linear(self.hidden_dimension, self.hidden_dimension),
+            torch.nn.ReLU(),
+        )
+        self.fc: Type[torch.nn.Linear] = torch.nn.Linear(
+            self.hidden_dimension, self.num_skills
+        )
+
+    def forward(self, state: torch.Tensor) -> torch.Tensor:
+        features: torch.Tensor = self.phi(state)
+        return self.fc(features)

homa/rl/sac/SoftActor.py ADDED
@@ -0,0 +1,70 @@
+import torch
+import numpy
+from .SoftCritic import SoftCritic
+from .modules import SoftActorModule
+from ...core.concerns import MovesNetworkToDevice
+
+
+class SoftActor(MovesNetworkToDevice):
+    def __init__(
+        self,
+        state_dimension: int,
+        hidden_dimension: int,
+        action_dimension: int,
+        lr: float,
+        weight_decay: float,
+        alpha: float,
+        min_std: float,
+        max_std: float,
+    ):
+        self.alpha: float = alpha
+
+        self.network = SoftActorModule(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            action_dimension=action_dimension,
+            min_std=min_std,
+            max_std=max_std,
+        )
+        self.optimizer = torch.optim.AdamW(
+            self.network.parameters(), lr=lr, weight_decay=weight_decay
+        )
+
+    def train(self, states: torch.Tensor, critic: SoftCritic):
+        self.network.train()
+        self.optimizer.zero_grad()
+        loss = self.loss(states=states, critic=critic)
+        loss.backward()
+        self.optimizer.step()
+
+    def loss(self, states: torch.Tensor, critic: SoftCritic) -> torch.Tensor:
+        actions, probabilities = self.sample(states)
+        q_alpha, q_beta = critic.network(states, actions)
+        q = torch.min(q_alpha, q_beta)
+        return (self.alpha * probabilities - q).mean()
+
+    def process_state(self, state: numpy.ndarray | torch.Tensor) -> torch.Tensor:
+        if isinstance(state, numpy.ndarray):
+            state = torch.from_numpy(state).float()
+
+        if state.ndim < 2:
+            state = state.unsqueeze(0)
+
+        return state
+
+    def sample(self, state: numpy.ndarray | torch.Tensor):
+        state = self.process_state(state)
+
+        mean, std = self.network(state)
+        # following line prevents standard deviations to be negative
+        std = std.exp()
+
+        distribution = torch.distributions.Normal(mean, std)
+
+        pre_tanh = distribution.rsample()
+        action = torch.tanh(pre_tanh)
+
+        probabilities = distribution.log_prob(pre_tanh).sum(dim=1, keepdim=True)
+        probabilities -= torch.log(1 - action.pow(2) + 1e-6).sum(dim=1, keepdim=True)
+
+        return action, probabilities

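`sample()` draws a reparameterized Gaussian action and squashes it with `tanh`, so the log-probability carries the change-of-variables correction `log pi(a|s) = log N(u) - sum_i log(1 - tanh(u_i)^2 + eps)`. A self-contained illustration of that correction, with arbitrary values:

```python
import torch
from torch.distributions import Normal

# Standalone mirror of the correction used in SoftActor.sample().
mean, std = torch.zeros(1, 2), torch.ones(1, 2)
distribution = Normal(mean, std)

u = distribution.rsample()   # reparameterized pre-squash sample
a = torch.tanh(u)            # bounded action in (-1, 1)

log_prob = distribution.log_prob(u).sum(dim=1, keepdim=True)
log_prob -= torch.log(1 - a.pow(2) + 1e-6).sum(dim=1, keepdim=True)  # tanh correction term
```
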
homa/rl/sac/SoftCritic.py ADDED
@@ -0,0 +1,98 @@
+import torch
+from torch.nn.functional import mse_loss as mse
+from .modules import DualSoftCriticModule
+from .SoftActor import SoftActor
+from ..utils import soft_update
+from ...core.concerns import MovesNetworkToDevice
+
+
+class SoftCritic(MovesNetworkToDevice):
+    def __init__(
+        self,
+        state_dimension: int,
+        hidden_dimension: int,
+        action_dimension: int,
+        lr: float,
+        weight_decay: float,
+        gamma: float,
+        alpha: float,
+    ):
+        self.gamma: float = gamma
+        self.alpha: float = alpha
+
+        self.network = DualSoftCriticModule(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            action_dimension=action_dimension,
+        )
+        self.target = DualSoftCriticModule(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            action_dimension=action_dimension,
+        )
+
+        # copy source to target when initiated
+        self.target.load_state_dict(self.network.state_dict())
+
+        self.optimizer = torch.optim.AdamW(
+            self.network.parameters(), lr=lr, weight_decay=weight_decay
+        )
+
+    def train(
+        self,
+        states: torch.Tensor,
+        actions: torch.Tensor,
+        rewards: torch.Tensor,
+        terminations: torch.Tensor,
+        next_states: torch.Tensor,
+        actor: SoftActor,
+    ):
+        self.network.train()
+        self.optimizer.zero_grad()
+        loss = self.loss(
+            states=states,
+            actions=actions,
+            rewards=rewards,
+            terminations=terminations,
+            next_states=next_states,
+            actor=actor,
+        )
+        loss.backward()
+        self.optimizer.step()
+
+    def loss(
+        self,
+        states: torch.Tensor,
+        actions: torch.Tensor,
+        rewards: torch.Tensor,
+        terminations: torch.Tensor,
+        next_states: torch.Tensor,
+        actor: torch.nn.Module,
+    ):
+        q_alpha, q_beta = self.network(states, actions)
+        target = self.calculate_target(
+            rewards=rewards,
+            terminations=terminations,
+            next_states=next_states,
+            actor=actor,
+        )
+        return mse(q_alpha, target) + mse(q_beta, target)
+
+    @torch.no_grad()
+    def calculate_target(
+        self,
+        rewards: torch.Tensor,
+        terminations: torch.Tensor,
+        next_states: torch.Tensor,
+        actor: SoftActor,
+    ):
+        termination_mask = 1 - terminations
+        next_actions, next_probabilities = actor.sample(next_states)
+        q_alpha, q_beta = self.target(next_states, next_actions)
+        q = torch.min(q_alpha, q_beta)
+        entropy_q = q - self.alpha * next_probabilities
+        return rewards + self.gamma * termination_mask * entropy_q
+
+    def update(self, tau: float):
+        soft_update(network=self.network.alpha, target=self.target.alpha, tau=tau)
+        soft_update(network=self.network.beta, target=self.target.beta, tau=tau)

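`calculate_target` is the standard soft Bellman backup: `y = r + gamma * (1 - done) * (min(Q'_a, Q'_b) - alpha * log pi(a'|s'))`. A self-contained numerical mirror of that computation, with made-up values:

```python
import torch

# Standalone mirror of SoftCritic.calculate_target().
gamma, alpha = 0.99, 0.2
rewards = torch.tensor([[1.0]])
terminations = torch.tensor([[0.0]])
q_alpha = torch.tensor([[4.0]])         # target heads evaluated at (s', a')
q_beta = torch.tensor([[3.5]])
next_log_prob = torch.tensor([[-1.2]])  # log pi(a'|s') from the actor

q = torch.min(q_alpha, q_beta)                              # clipped double-Q: 3.5
entropy_q = q - alpha * next_log_prob                       # 3.5 + 0.24 = 3.74
target = rewards + gamma * (1 - terminations) * entropy_q   # 1 + 0.99 * 3.74 = 4.7026
```
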
homa/rl/sac/__init__.py ADDED

homa/rl/sac/modules/DualSoftCriticModule.py ADDED
@@ -0,0 +1,22 @@
+import torch
+from .SoftCriticModule import SoftCriticModule
+
+
+class DualSoftCriticModule(torch.nn.Module):
+    def __init__(
+        self, state_dimension: int, hidden_dimension: int, action_dimension: int
+    ):
+        super().__init__()
+        self.alpha = SoftCriticModule(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            action_dimension=action_dimension,
+        )
+        self.beta = SoftCriticModule(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            action_dimension=action_dimension,
+        )
+
+    def forward(self, state: torch.Tensor, action: torch.Tensor):
+        return self.alpha(state, action), self.beta(state, action)

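The dual module is the usual clipped double-Q pair: both heads score the same `(state, action)` batch and the consumer takes the element-wise minimum. A hedged sketch, assuming the module is importable from `homa.rl.sac.modules`:

```python
import torch
from homa.rl.sac.modules import DualSoftCriticModule  # assumed export path

critic = DualSoftCriticModule(state_dimension=3, hidden_dimension=32, action_dimension=1)
q_alpha, q_beta = critic(torch.randn(8, 3), torch.randn(8, 1))  # two independent Q heads
q = torch.min(q_alpha, q_beta)                                  # clipped double-Q estimate
```
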
homa/rl/sac/modules/SoftActorModule.py ADDED
@@ -0,0 +1,35 @@
+import torch
+
+
+class SoftActorModule(torch.nn.Module):
+    def __init__(
+        self,
+        state_dimension: int,
+        hidden_dimension: int,
+        action_dimension: int,
+        min_std: float,
+        max_std: float,
+    ):
+        super().__init__()
+
+        self.state_dimension: int = state_dimension
+        self.hidden_dimension: int = hidden_dimension
+        self.action_dimension: int = action_dimension
+        self.min_std: float = float(min_std)
+        self.max_std: float = float(max_std)
+
+        self.phi = torch.nn.Sequential(
+            torch.nn.Linear(self.state_dimension, self.hidden_dimension),
+            torch.nn.ReLU(),
+            torch.nn.Linear(self.hidden_dimension, self.hidden_dimension),
+            torch.nn.ReLU(),
+        )
+        self.mu = torch.nn.Linear(self.hidden_dimension, self.action_dimension)
+        self.xi = torch.nn.Linear(self.hidden_dimension, self.action_dimension)
+
+    def forward(self, state: torch.Tensor):
+        features = self.phi(state)
+        mean = self.mu(features)
+        std = self.xi(features)
+        std = std.clamp(self.min_std, self.max_std)
+        return mean, std

homa/rl/sac/modules/SoftCriticModule.py ADDED
@@ -0,0 +1,30 @@
+import torch
+
+
+class SoftCriticModule(torch.nn.Module):
+    def __init__(
+        self,
+        state_dimension: int,
+        hidden_dimension: int,
+        action_dimension: int,
+    ):
+        super().__init__()
+
+        self.state_dimension: int = state_dimension
+        self.action_dimension: int = action_dimension
+        self.hidden_dimension: int = hidden_dimension
+
+        self.phi = torch.nn.Sequential(
+            torch.nn.Linear(
+                self.state_dimension + self.action_dimension, self.hidden_dimension
+            ),
+            torch.nn.ReLU(),
+            torch.nn.Linear(self.hidden_dimension, self.hidden_dimension),
+            torch.nn.ReLU(),
+        )
+        self.fc = torch.nn.Linear(self.hidden_dimension, 1)
+
+    def forward(self, state: torch.Tensor, action: torch.Tensor):
+        psi = torch.cat([state, action], dim=1)
+        features = self.phi(psi)
+        return self.fc(features)

homa/rl/utils.py ADDED

homa/vision/Resnet.py CHANGED
@@ -2,12 +2,12 @@ import torch
 from .modules import ResnetModule
 from .Classifier import Classifier
 from .concerns import Trainable, ReportsMetrics
-from ..
+from ..core.concerns import MovesNetworkToDevice


-class Resnet(Classifier, Trainable, ReportsMetrics):
+class Resnet(Classifier, Trainable, ReportsMetrics, MovesNetworkToDevice):
     def __init__(self, num_classes: int, lr: float = 0.001):
         super().__init__()
-        self.network = ResnetModule(num_classes)
+        self.network = ResnetModule(num_classes)
         self.criterion = torch.nn.CrossEntropyLoss()
         self.optimizer = torch.optim.SGD(self.network.parameters(), lr=lr, momentum=0.9)

homa/vision/Swin.py CHANGED
@@ -2,12 +2,24 @@ import torch
 from .Classifier import Classifier
 from .concerns import Trainable, ReportsMetrics
 from .modules import SwinModule
-from ..
+from ..core.concerns import MovesNetworkToDevice


-class Swin(Classifier, Trainable, ReportsMetrics):
-    def __init__(
+class Swin(Classifier, Trainable, ReportsMetrics, MovesNetworkToDevice):
+    def __init__(
+        self,
+        num_classes: int,
+        lr: float = 0.0001,
+        decay: float = 0.0,
+        variant: str = "base",
+        weights="DEFAULT",
+    ):
         super().__init__()
-        self.
-        self.
+        self.num_classes = num_classes
+        self.network = SwinModule(
+            num_classes=self.num_classes, variant=variant, weights=weights
+        )
+        self.optimizer = torch.optim.AdamW(
+            self.network.parameters(), lr=lr, weight_decay=decay
+        )
         self.criterion = torch.nn.CrossEntropyLoss()

homa/vision/modules/SwinModule.py CHANGED
@@ -1,21 +1,29 @@
 import torch
-from torchvision.models import swin_v2_b
+from torchvision.models import swin_v2_b, swin_v2_s, swin_v2_t
 from torch.nn.init import kaiming_uniform_ as kaiming


 class SwinModule(torch.nn.Module):
-    def __init__(self, num_classes: int):
+    def __init__(self, num_classes: int, variant: str, weights):
         super().__init__()
-        self.
-        self.
-        self._create_fc()
+        self._create_encoder(variant=variant, weights=weights)
+        self._create_fc(num_classes=num_classes)

-    def
-
+    def variant_instance(self, variant: str):
+        variant_map = {"tiny": swin_v2_t, "small": swin_v2_s, "base": swin_v2_b}
+        return variant_map.get(variant)
+
+    def _create_encoder(self, variant: str, weights):
+        if variant not in ["tiny", "small", "base"]:
+            raise ValueError(
+                f"Swin variant needs to be one of [tiny, small, base]. Invalid {variant} was provided."
+            )
+        instance = self.variant_instnace(variant)
+        self.encoder = instance(weights=weights)
         self.encoder.head = torch.nn.Identity()

-    def _create_fc(self):
-        self.fc = torch.nn.Linear(1024,
+    def _create_fc(self, num_classes: int):
+        self.fc = torch.nn.Linear(1024, num_classes)
         kaiming(self.fc.weight, mode="fan_in", nonlinearity="relu")

     def forward(self, images: torch.Tensor):