homa 0.2.95__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- homa/core/__init__.py +0 -0
- homa/core/concerns/MovesNetworkToDevice.py +13 -0
- homa/core/concerns/__init__.py +1 -0
- homa/device.py +4 -0
- homa/ensemble/Ensemble.py +4 -2
- homa/ensemble/concerns/{ReportsSize.py → ReportsEnsembleSize.py} +3 -3
- homa/ensemble/concerns/ReportsLogits.py +3 -1
- homa/ensemble/concerns/SavesEnsembleModels.py +13 -0
- homa/ensemble/concerns/StoresModels.py +6 -9
- homa/ensemble/concerns/__init__.py +2 -1
- homa/ensemble/utils.py +9 -0
- homa/graph/GraphAttention.py +13 -0
- homa/graph/__init__.py +1 -0
- homa/graph/modules/GraphAttentionHeadModule.py +37 -0
- homa/graph/modules/MultiHeadGraphAttentionModule.py +22 -0
- homa/graph/modules/__init__.py +2 -0
- homa/loss/Loss.py +4 -1
- homa/rl/DQN.py +2 -0
- homa/rl/DRQN.py +5 -0
- homa/rl/DiversityIsAllYouNeed.py +96 -0
- homa/rl/SoftActorCritic.py +64 -0
- homa/rl/__init__.py +4 -0
- homa/rl/buffers/Buffer.py +11 -0
- homa/rl/buffers/DiversityIsAllYouNeedBuffer.py +50 -0
- homa/rl/buffers/ImageBuffer.py +5 -0
- homa/rl/buffers/SoftActorCriticBuffer.py +56 -0
- homa/rl/buffers/__init__.py +4 -0
- homa/rl/buffers/concerns/HasRecordAlternatives.py +12 -0
- homa/rl/buffers/concerns/ResetsCollection.py +9 -0
- homa/rl/buffers/concerns/__init__.py +2 -0
- homa/rl/diayn/Actor.py +54 -0
- homa/rl/diayn/Critic.py +41 -0
- homa/rl/diayn/Discriminator.py +45 -0
- homa/rl/diayn/__init__.py +3 -0
- homa/rl/diayn/modules/ContinuousActorModule.py +42 -0
- homa/rl/diayn/modules/CriticModule.py +28 -0
- homa/rl/diayn/modules/DiscriminatorModule.py +24 -0
- homa/rl/diayn/modules/__init__.py +3 -0
- homa/rl/sac/SoftActor.py +69 -0
- homa/rl/sac/SoftCritic.py +100 -0
- homa/rl/sac/__init__.py +2 -0
- homa/rl/sac/modules/DualSoftCriticModule.py +22 -0
- homa/rl/sac/modules/SoftActorModule.py +35 -0
- homa/rl/sac/modules/SoftCriticModule.py +30 -0
- homa/rl/sac/modules/__init__.py +3 -0
- homa/vision/Resnet.py +3 -3
- homa/vision/Swin.py +17 -5
- homa/vision/modules/SwinModule.py +17 -9
- {homa-0.2.95.dist-info → homa-0.3.11.dist-info}/METADATA +1 -1
- {homa-0.2.95.dist-info → homa-0.3.11.dist-info}/RECORD +53 -17
- homa/torch/__init__.py +0 -1
- homa/torch/helpers.py +0 -6
- {homa-0.2.95.dist-info → homa-0.3.11.dist-info}/WHEEL +0 -0
- {homa-0.2.95.dist-info → homa-0.3.11.dist-info}/entry_points.txt +0 -0
- {homa-0.2.95.dist-info → homa-0.3.11.dist-info}/top_level.txt +0 -0
homa/core/__init__.py
ADDED
File without changes
homa/core/concerns/MovesNetworkToDevice.py
ADDED
@@ -0,0 +1,13 @@
+from ...device import move
+
+
+class MovesNetworkToDevice:
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if not hasattr(self, "network"):
+            raise RuntimeError(
+                "MovesNetworkToDevice assumes the underlying class has a network property."
+            )
+
+        move(self.network)
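The mixin's check runs inside its own __init__, so a consuming class has to assign self.network before delegating up the MRO. A minimal sketch of the intended composition (the Classifier name and layer sizes are illustrative, not part of the package):

    import torch
    from homa.core.concerns import MovesNetworkToDevice


    class Classifier(MovesNetworkToDevice):
        def __init__(self):
            self.network = torch.nn.Linear(4, 2)  # must exist before the mixin check
            super().__init__()  # hasattr(self, "network") passes, then move(self.network) runs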
homa/core/concerns/__init__.py
ADDED
@@ -0,0 +1 @@
+from .MovesNetworkToDevice import MovesNetworkToDevice
homa/device.py
CHANGED
homa/ensemble/Ensemble.py
CHANGED
@@ -1,16 +1,18 @@
 from .concerns import (
-    ReportsSize,
+    ReportsEnsembleSize,
     StoresModels,
     ReportsClassificationMetrics,
     PredictsProbabilities,
+    SavesEnsembleModels,
 )


 class Ensemble(
-    ReportsSize,
+    ReportsEnsembleSize,
     ReportsClassificationMetrics,
     PredictsProbabilities,
     StoresModels,
+    SavesEnsembleModels,
 ):
     def __init__(self):
         super().__init__()
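With the new mixins wired in, the ensemble gains size reporting and model saving on top of recording. A hedged usage sketch (the model variable is any homa vision Model or torch module you have built; the import path assumes homa.ensemble re-exports Ensemble):

    from homa.ensemble import Ensemble

    ensemble = Ensemble()
    ensemble.record(model)  # stores model.__class__ plus its state_dict
    ensemble.push(model)    # alias for record
    print(ensemble.size)    # -> 2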
homa/ensemble/concerns/{ReportsSize.py → ReportsEnsembleSize.py}
RENAMED
@@ -1,11 +1,11 @@
-class ReportsSize:
+class ReportsEnsembleSize:
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

     @property
     def size(self):
-        return len(self.models)
+        return len(self.weights)

     @property
     def length(self):
-        return len(self.models)
+        return self.size
homa/ensemble/concerns/ReportsLogits.py
CHANGED
@@ -8,7 +8,9 @@ class ReportsLogits:
     def logits(self, x: torch.Tensor) -> torch.Tensor:
         batch_size = x.shape[0]
         logits = torch.zeros((batch_size, self.num_classes))
-        for model in self.models:
+        for factory, weight in zip(self.factories, self.weights):
+            model = factory(num_classes=self.num_classes)
+            model.load_state_dict(weight)
             logits += model(x)
         return logits

homa/ensemble/concerns/StoresModels.py
CHANGED
@@ -1,13 +1,14 @@
 import torch
-import io
-from ...device import device
+from typing import List, Type
+from collections import OrderedDict
 from ...vision import Model


 class StoresModels:
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.models = []
+        self.factories: List[Type[torch.nn.Module]] = []
+        self.weights: List[OrderedDict] = []

     def record(self, model: Model | torch.nn.Module):
         model_: torch.nn.Module | None = None
@@ -18,12 +19,8 @@ class StoresModels:
         else:
             raise TypeError("Wrong input to ensemble record")

-        buffer = io.BytesIO()
-
-        torch.save(model_.to("cpu"), buffer)
-        buffer.seek(0)
-        model_ = torch.load(buffer, map_location=device)
-        self.models.append(model_)
+        self.factories.append(model_.__class__)
+        self.weights.append(model_.state_dict())

     def push(self, *args, **kwargs):
         self.record(*args, **kwargs)
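Storing a constructor plus a state_dict, rather than a serialized module as before, is what lets ReportsLogits above rebuild each member on demand. A self-contained sketch of the same round trip (TinyNet is illustrative, not a package class):

    import torch


    class TinyNet(torch.nn.Module):
        def __init__(self, num_classes: int = 10):
            super().__init__()
            self.head = torch.nn.Linear(8, num_classes)

        def forward(self, x):
            return self.head(x)


    model = TinyNet()
    factories = [model.__class__]   # record(): store the class
    weights = [model.state_dict()]  # record(): store the parameters

    rebuilt = factories[0](num_classes=10)  # logits(): rebuild from the factory
    rebuilt.load_state_dict(weights[0])     # and restore the weights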
homa/ensemble/concerns/__init__.py
CHANGED
@@ -5,5 +5,6 @@ from .ReportsEnsembleAccuracy import ReportsEnsembleAccuracy
 from .ReportsEnsembleF1 import ReportsEnsembleF1
 from .ReportsEnsembleKappa import ReportsEnsembleKappa
 from .ReportsLogits import ReportsLogits
-from .ReportsSize import ReportsSize
+from .ReportsEnsembleSize import ReportsEnsembleSize
 from .StoresModels import StoresModels
+from .SavesEnsembleModels import SavesEnsembleModels
homa/ensemble/utils.py
ADDED
homa/graph/GraphAttention.py
ADDED
@@ -0,0 +1,13 @@
+import torch
+from .modules import GraphAttentionModule
+from ..core.concerns import MovesNetworkToDevice
+
+
+class GraphAttention(MovesNetworkToDevice):
+    def __init__(self, lr: float = 0.005, decay: float = 5e-4, dropout: float = 0.6):
+        super().__init__()
+        self.network = GraphAttentionModule()
+        self.optimizer = torch.optim.AdamW(
+            self.network.parameters(), lr=lr, weight_decay=decay
+        )
+        self.criterion = torch.nn.CrossEntropyLoss()
homa/graph/__init__.py
ADDED
@@ -0,0 +1 @@
+from .GraphAttention import GraphAttention
homa/graph/modules/GraphAttentionHeadModule.py
ADDED
@@ -0,0 +1,37 @@
+import torch
+
+
+class GraphAttentionHeadModule(torch.nn.Module):
+    def __init__(self, in_features: int, out_features: int, alpha=0.2):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.alpha = alpha
+
+        self.W = torch.nn.Linear(in_features, out_features, bias=False)
+        self.a_1 = torch.nn.Parameter(torch.randn(out_features, 1))
+        self.a_2 = torch.nn.Parameter(torch.randn(out_features, 1))
+
+        self.leaky_relu = torch.nn.LeakyReLU(self.alpha)
+        self.elu = torch.nn.ELU()
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        torch.nn.init.xavier_uniform_(self.W.weight, gain=1.414)
+        torch.nn.init.xavier_uniform_(self.a_1, gain=1.414)
+        torch.nn.init.xavier_uniform_(self.a_2, gain=1.414)
+
+    def forward(self, node_features, adj_matrix):
+        N = node_features.size(0)
+        h_prime = self.W(node_features)
+        s1 = torch.matmul(h_prime, self.a_1)
+        s2 = torch.matmul(h_prime, self.a_2)
+        e = s1 + s2.T
+        e = self.leaky_relu(e)
+        zero_vec = -9e15 * torch.ones_like(e)
+        attention_mask = torch.where(
+            adj_matrix > 0, e, zero_vec.to(node_features.device)
+        )
+        attention_weights = torch.softmax(attention_mask, dim=1)
+        h_new = torch.matmul(attention_weights, h_prime)
+        return self.elu(h_new)
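A hedged usage sketch for a single attention head (graph sizes are illustrative, and it assumes homa.graph.modules re-exports the class):

    import torch
    from homa.graph.modules import GraphAttentionHeadModule

    head = GraphAttentionHeadModule(in_features=16, out_features=8)
    node_features = torch.randn(5, 16)             # 5 nodes, 16 features each
    adj_matrix = (torch.rand(5, 5) > 0.5).float()  # random dense adjacency
    out = head(node_features, adj_matrix)          # -> shape (5, 8)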
homa/graph/modules/MultiHeadGraphAttentionModule.py
ADDED
@@ -0,0 +1,22 @@
+import torch
+from .GraphAttentionHeadModule import GraphAttentionHeadModule
+
+
+class MultiHeadGraphAttentionModule(torch.nn.Module):
+    def __init__(self, num_heads: int, in_features: int, out_features: int, alpha=0.2):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_out_features = out_features
+        self.heads = torch.nn.ModuleList(
+            [
+                GraphAttentionHeadModule(in_features, out_features, alpha=alpha)
+                for _ in range(num_heads)
+            ]
+        )
+
+    def forward(
+        self, node_features: torch.Tensor, adj_matrix: torch.Tensor
+    ) -> torch.Tensor:
+        outputs = [head(node_features, adj_matrix) for head in self.heads]
+        h_new_concat = torch.cat(outputs, dim=1)
+        return h_new_concat
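Heads are concatenated along the feature axis, so the module's output width is num_heads * out_features; continuing the sketch above:

    from homa.graph.modules import MultiHeadGraphAttentionModule

    mha = MultiHeadGraphAttentionModule(num_heads=4, in_features=16, out_features=8)
    out = mha(node_features, adj_matrix)  # -> shape (5, 32), i.e. 4 heads x 8 features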
homa/loss/Loss.py
CHANGED
homa/rl/DQN.py
ADDED
homa/rl/DRQN.py
ADDED
homa/rl/DiversityIsAllYouNeed.py
ADDED
@@ -0,0 +1,96 @@
+import torch
+from .diayn.Actor import Actor
+from .diayn.Critic import Critic
+from .diayn.Discriminator import Discriminator
+from .buffers import DiversityIsAllYouNeedBuffer, Buffer
+
+
+class DiversityIsAllYouNeed:
+    def __init__(
+        self,
+        state_dimension: int,
+        action_dimension: int,
+        hidden_dimension: int = 256,
+        num_skills: int = 10,
+        critic_decay: float = 0.0,
+        actor_decay: float = 0.0,
+        discriminator_decay: float = 0.0,
+        actor_lr: float = 0.0001,
+        critic_lr: float = 0.001,
+        discriminator_lr=0.001,
+        buffer_capacity: int = 1_000_000,
+        actor_epsilon: float = 1e-6,
+        gamma: float = 0.99,
+        min_std: float = -20.0,
+        max_std: float = 2.0,
+    ):
+        self.buffer: Buffer = DiversityIsAllYouNeedBuffer(capacity=buffer_capacity)
+        self.num_skills: int = num_skills
+        self.actor = Actor(
+            state_dimension=state_dimension,
+            action_dimension=action_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+            lr=actor_lr,
+            decay=actor_decay,
+            epsilon=actor_epsilon,
+            min_std=min_std,
+            max_std=max_std,
+        )
+        self.critic = Critic(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+            lr=critic_lr,
+            decay=critic_decay,
+            gamma=gamma,
+        )
+        self.discriminator = Discriminator(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+            lr=discriminator_lr,
+            decay=discriminator_decay,
+        )
+
+    def one_hot(self, indices, max_index) -> torch.Tensor:
+        one_hot = torch.zeros(indices.size(0), max_index)
+        one_hot.scatter_(1, indices.unsqueeze(1), 1)
+        return one_hot
+
+    def skill_index(self) -> torch.Tensor:
+        return torch.randint(0, self.num_skills, (1,))
+
+    def skill(self) -> torch.Tensor:
+        return self.one_hot(self.skill_index(), self.num_skills)
+
+    def advantages(
+        self,
+        states: torch.Tensor,
+        skills: torch.Tensor,
+        rewards: torch.Tensor,
+        terminations: torch.Tensor,
+        next_states: torch.Tensor,
+    ) -> torch.Tensor:
+        values = self.critic.values(states=states, skills=skills)
+        termination_mask = 1 - terminations
+        next_values = self.critic.values_(states=next_states, skills=skills)
+        update = self.gamma * next_values * termination_mask
+        return rewards + update - values
+
+    def train(self, skill: torch.Tensor):
+        data = self.buffer.all_tensor()
+        skill_indices = skill.repeat(data.states.size(0), 1).long()
+        skills_indices_one_hot = self.one_hot(skill_indices, self.num_skills)
+        self.discriminator.train(
+            states=data.states, skills_indices=skills_indices_one_hot
+        )
+        advantages = self.advantages(
+            states=data.states,
+            rewards=data.rewards,
+            terminations=data.terminations,
+            next_states=data.next_states,
+            skills=skills,
+        )
+        self.critic.train(advantages=advantages)
+        self.actor.train(advantages=advantages)
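The one_hot helper used throughout is a plain scatter over a zero tensor; its behavior, spelled out:

    import torch

    indices = torch.tensor([2, 0])
    one_hot = torch.zeros(indices.size(0), 4)
    one_hot.scatter_(1, indices.unsqueeze(1), 1)
    # tensor([[0., 0., 1., 0.],
    #         [1., 0., 0., 0.]])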
homa/rl/SoftActorCritic.py
ADDED
@@ -0,0 +1,64 @@
+from .sac import SoftActor, SoftCritic
+from .buffers import SoftActorCriticBuffer
+
+
+class SoftActorCritic:
+    def __init__(
+        self,
+        state_dimension: int,
+        action_dimension: int,
+        hidden_dimension: int = 256,
+        buffer_capacity: int = 1_000_000,
+        batch_size: int = 256,
+        actor_lr: float = 0.0002,
+        critic_lr: float = 0.0003,
+        actor_decay: float = 0.0,
+        critic_decay: float = 0.0,
+        tau: float = 0.005,
+        alpha: float = 0.2,
+        gamma: float = 0.99,
+        min_std: float = -20,
+        max_std: float = 2,
+        warmup: int = 10_000,
+    ):
+        self.batch_size: int = batch_size
+        self.warmup: int = warmup
+
+        self.actor = SoftActor(
+            state_dimension=state_dimension,
+            action_dimension=action_dimension,
+            hidden_dimension=hidden_dimension,
+            lr=actor_lr,
+            weight_decay=actor_decay,
+            alpha=alpha,
+            min_std=min_std,
+            max_std=max_std,
+        )
+        self.critic = SoftCritic(
+            state_dimension=state_dimension,
+            action_dimension=action_dimension,
+            hidden_dimension=hidden_dimension,
+            lr=critic_lr,
+            weight_decay=critic_decay,
+            tau=tau,
+            gamma=gamma,
+            alpha=alpha,
+        )
+        self.buffer = SoftActorCriticBuffer(capacity=buffer_capacity)
+
+    def train(self):
+        # don't train before warmup
+        if self.buffer.size < self.warmup:
+            return
+
+        data = self.buffer.sample_torch(self.batch_size)
+        self.critic.train(
+            states=data.states,
+            actions=data.actions,
+            rewards=data.rewards,
+            terminations=data.terminations,
+            next_states=data.next_states,
+            actor=self.actor,
+        )
+        self.actor.train(states=data.states, critic_network=self.critic.network)
+        self.critic.update()
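train() is a no-op until the replay buffer has passed the warmup threshold. A hedged driving loop (collected_transitions is a hypothetical iterable of (state, action, reward, next_state, termination, probability) tuples from an environment rollout):

    agent = SoftActorCritic(state_dimension=8, action_dimension=2)

    for transition in collected_transitions:  # hypothetical rollout source
        agent.buffer.record(*transition)
        agent.train()  # returns silently until buffer.size >= warmup (10_000 by default)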
homa/rl/__init__.py
ADDED
homa/rl/buffers/Buffer.py
ADDED
@@ -0,0 +1,11 @@
+from .concerns import ResetsCollection, HasRecordAlternatives
+
+
+class Buffer(ResetsCollection, HasRecordAlternatives):
+    def __init__(self, capacity: int):
+        self.capacity: int = capacity
+        self.reset()
+
+    @property
+    def size(self):
+        return len(self.collection)
homa/rl/buffers/DiversityIsAllYouNeedBuffer.py
ADDED
@@ -0,0 +1,50 @@
+import torch
+import numpy
+from types import SimpleNamespace
+from .Buffer import Buffer
+from .concerns import HasRecordAlternatives
+
+
+class DiversityIsAllYouNeedBuffer(Buffer, HasRecordAlternatives):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def all_tensor(self) -> SimpleNamespace:
+        return self.all(tensor=True)
+
+    def all(self, tensor: bool = False) -> SimpleNamespace:
+        states, actions, rewards, next_states, terminations, probabilities = zip(
+            *self.collection
+        )
+
+        if tensor:
+            states = torch.from_numpy(numpy.array(states))
+            actions = torch.from_numpy(numpy.array(actions))
+            rewards = torch.from_numpy(numpy.array(rewards))
+            next_states = torch.from_numpy(numpy.array(next_states))
+            terminations = torch.from_numpy(numpy.array(terminations))
+            probabilities = torch.from_numpy(numpy.array(probabilities))
+
+        return SimpleNamespace(
+            **{
+                "states": states,
+                "actions": actions,
+                "rewards": rewards,
+                "next_states": next_states,
+                "terminations": terminations,
+                "probabilities": probabilities,
+            }
+        )
+
+    def record(
+        self,
+        state: numpy.ndarray,
+        action: int,
+        reward: float,
+        next_state: numpy.ndarray,
+        termination: bool,
+        probability: numpy.ndarray,
+    ) -> None:
+        self.collection.append(
+            (state, action, reward, next_state, termination, probability)
+        )
homa/rl/buffers/SoftActorCriticBuffer.py
ADDED
@@ -0,0 +1,56 @@
+import numpy
+import random
+import torch
+from types import SimpleNamespace
+from .Buffer import Buffer
+
+
+class SoftActorCriticBuffer(Buffer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def record(
+        self,
+        state: numpy.ndarray,
+        action: int,
+        reward: float,
+        next_state: numpy.ndarray,
+        termination: float,
+        probability: numpy.ndarray,
+    ):
+        self.collection.append(
+            (state, action, reward, next_state, termination, probability)
+        )
+
+    def sample(self, k: int, as_tensor: bool = False):
+        batch = random.sample(self.collection, k)
+        states, actions, rewards, next_states, terminations, probabilities = zip(*batch)
+
+        states = numpy.array(states)
+        actions = numpy.array(actions)
+        rewards = numpy.array(rewards)
+        next_states = numpy.array(next_states)
+        terminations = numpy.array(terminations)
+        probabilities = numpy.array(probabilities)
+
+        if as_tensor:
+            states = torch.from_numpy(states).float()
+            actions = torch.from_numpy(actions).long()
+            rewards = torch.from_numpy(rewards).float()
+            next_states = torch.from_numpy(next_states).float()
+            terminations = torch.from_numpy(terminations).float()
+            probabilities = torch.from_numpy(probabilities).float()
+
+        return SimpleNamespace(
+            **{
+                "states": states,
+                "actions": actions,
+                "rewards": rewards,
+                "next_states": next_states,
+                "terminations": terminations,
+                "probabilities": probabilities,
+            }
+        )
+
+    def sample_torch(self, k: int):
+        return self.sample(k=k, as_tensor=True)
homa/rl/buffers/concerns/HasRecordAlternatives.py
ADDED
@@ -0,0 +1,12 @@
+class HasRecordAlternatives:
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def add(self, *args, **kwargs) -> None:
+        self.record(*args, **kwargs)
+
+    def push(self, *args, **kwargs) -> None:
+        self.record(*args, **kwargs)
+
+    def append(self, *args, **kwargs) -> None:
+        self.record(*args, **kwargs)
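HasRecordAlternatives makes add, push, and append interchangeable aliases for record. A quick sketch against the SAC buffer (this assumes ResetsCollection, whose diff is not shown here, initializes an empty self.collection, and that homa.rl.buffers exports the class):

    import numpy
    from homa.rl.buffers import SoftActorCriticBuffer

    buf = SoftActorCriticBuffer(capacity=100)
    s = numpy.zeros(4)
    buf.add(s, 0, 1.0, s, 0.0, numpy.zeros(2))     # all three delegate
    buf.push(s, 1, 0.5, s, 0.0, numpy.zeros(2))    # to record(...)
    buf.append(s, 0, 2.0, s, 1.0, numpy.zeros(2))
    assert buf.size == 3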
homa/rl/diayn/Actor.py
ADDED
@@ -0,0 +1,54 @@
+import torch
+from torch.distributions import Normal
+from .modules import ContinuousActorModule
+from ...core.concerns import MovesNetworkToDevice
+
+
+class Actor(MovesNetworkToDevice):
+    def __init__(
+        self,
+        state_dimension: int,
+        action_dimension: int,
+        num_skills: int,
+        hidden_dimension: int,
+        lr: float,
+        decay: float,
+        epsilon: float,
+        min_std: float,
+        max_std: float,
+    ):
+        self.epsilon: float = epsilon
+        self.network = ContinuousActorModule(
+            state_dimension=state_dimension,
+            action_dimension=action_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+            min_std=min_std,
+            max_std=max_std,
+        )
+        self.optimizer = torch.optim.AdamW(
+            self.network.parameters(), lr=lr, weight_decay=decay
+        )
+
+    def action(self, state: torch.Tensor, skill: torch.Tensor):
+        mean, std = self.network(state, skill)
+        std = std.exp()
+        distribution = Normal(mean, std)
+        raw_action = distribution.rsample()
+        action = torch.tanh(raw_action)
+        corrected_probabilities = torch.log(1.0 - action.pow(2) + self.epsilon)
+        probabilities = distribution.log_prob(raw_action) - corrected_probabilities
+        probabilities = probabilities.sum(dim=-1, keepdim=True)
+        return action, probabilities
+
+    def train(self, advantages: torch.Tensor, probabilities: torch.Tensor) -> float:
+        self.optimizer.zero_grad()
+        loss = self.loss(advantages=advantages, probabilities=probabilities)
+        loss.backward()
+        self.optimizer.step()
+        return loss.item()
+
+    def loss(
+        self, advantages: torch.Tensor, probabilities: torch.Tensor
+    ) -> torch.Tensor:
+        return -(probabilities * advantages.detach()).mean()
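The log-probability handling in action is the standard tanh change-of-variables correction used by soft actor-critic: for a squashed sample a = tanh(u) with u drawn from the reparameterized Gaussian,

    \log \pi(a \mid s) = \log \mathcal{N}(u; \mu, \sigma) - \sum_i \log\left(1 - \tanh(u_i)^2 + \epsilon\right)

where epsilon is the actor_epsilon argument, guarding the logarithm against underflow near |a| = 1.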
homa/rl/diayn/Critic.py
ADDED
@@ -0,0 +1,41 @@
+import torch
+from .modules import CriticModule
+from ...core.concerns import MovesNetworkToDevice
+
+
+class Critic(MovesNetworkToDevice):
+    def __init__(
+        self,
+        state_dimension: int,
+        hidden_dimension: int,
+        num_skills: int,
+        lr: float,
+        decay: float,
+        gamma: float,
+    ):
+        self.network = CriticModule(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+        )
+        self.optimizer = torch.optim.AdamW(
+            self.network.parameters(), lr=lr, weight_decay=decay
+        )
+        self.criterion = torch.nn.SmoothL1Loss()
+        self.gamma: float = gamma
+
+    def train(self, advantages: torch.Tensor):
+        self.optimizer.zero_grad()
+        loss = self.loss(advantages=advantages)
+        loss.backward()
+        self.optimizer.step()
+
+    def loss(self, advantages: torch.Tensor):
+        return advantages.pow(2).mean()
+
+    def values(self, states: torch.Tensor, skills: torch.Tensor):
+        return self.network(states, skills)
+
+    @torch.no_grad()
+    def values_(self, *args, **kwargs):
+        return self.values(*args, **kwargs)
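values_ is the no-grad twin of values, used for the bootstrapped targets in DiversityIsAllYouNeed.advantages. A small check of the distinction (shapes are illustrative and assume CriticModule's forward accepts positional states and skills):

    import torch
    from homa.rl.diayn.Critic import Critic

    critic = Critic(
        state_dimension=3, hidden_dimension=8, num_skills=2,
        lr=1e-3, decay=0.0, gamma=0.99,
    )
    states = torch.randn(4, 3)
    skills = torch.eye(2)[torch.tensor([0, 1, 0, 1])]  # one-hot skill rows

    assert critic.values(states, skills).requires_grad       # flows into the loss
    assert not critic.values_(states, skills).requires_grad  # frozen target values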