homa 0.2.9__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- homa/activations/learnable/AOAF.py +1 -1
- homa/activations/learnable/AReLU.py +6 -3
- homa/activations/learnable/PiLU.py +1 -1
- homa/activations/learnable/__init__.py +2 -2
- homa/activations/learnable/concerns/ChannelBased.py +2 -0
- homa/core/__init__.py +0 -0
- homa/core/concerns/MovesNetworkToDevice.py +13 -0
- homa/core/concerns/TracksTime.py +7 -0
- homa/core/concerns/__init__.py +2 -0
- homa/device.py +5 -0
- homa/ensemble/Ensemble.py +4 -2
- homa/ensemble/concerns/CalculatesMetricNecessities.py +2 -2
- homa/ensemble/concerns/PredictsProbabilities.py +2 -2
- homa/ensemble/concerns/ReportsClassificationMetrics.py +2 -1
- homa/ensemble/concerns/ReportsEnsembleAccuracy.py +2 -2
- homa/ensemble/concerns/ReportsEnsembleF1.py +2 -2
- homa/ensemble/concerns/ReportsEnsembleKappa.py +2 -2
- homa/ensemble/concerns/ReportsEnsembleSize.py +11 -0
- homa/ensemble/concerns/ReportsLogits.py +26 -5
- homa/ensemble/concerns/SavesEnsembleModels.py +13 -0
- homa/ensemble/concerns/StoresModels.py +11 -8
- homa/ensemble/concerns/__init__.py +2 -1
- homa/ensemble/utils.py +9 -0
- homa/graph/GraphAttention.py +13 -0
- homa/graph/__init__.py +1 -0
- homa/graph/modules/GraphAttentionHeadModule.py +37 -0
- homa/graph/modules/MultiHeadGraphAttentionModule.py +22 -0
- homa/graph/modules/__init__.py +2 -0
- homa/loss/Loss.py +4 -1
- homa/rl/DQN.py +2 -0
- homa/rl/DRQN.py +5 -0
- homa/rl/DiversityIsAllYouNeed.py +96 -0
- homa/rl/SoftActorCritic.py +67 -0
- homa/rl/__init__.py +4 -0
- homa/rl/buffers/Buffer.py +13 -0
- homa/rl/buffers/DiversityIsAllYouNeedBuffer.py +50 -0
- homa/rl/buffers/ImageBuffer.py +5 -0
- homa/rl/buffers/SoftActorCriticBuffer.py +64 -0
- homa/rl/buffers/__init__.py +4 -0
- homa/rl/buffers/concerns/HasRecordAlternatives.py +12 -0
- homa/rl/buffers/concerns/ResetsCollection.py +9 -0
- homa/rl/buffers/concerns/__init__.py +2 -0
- homa/rl/diayn/Actor.py +54 -0
- homa/rl/diayn/Critic.py +41 -0
- homa/rl/diayn/Discriminator.py +45 -0
- homa/rl/diayn/__init__.py +3 -0
- homa/rl/diayn/modules/ContinuousActorModule.py +42 -0
- homa/rl/diayn/modules/CriticModule.py +28 -0
- homa/rl/diayn/modules/DiscriminatorModule.py +24 -0
- homa/rl/diayn/modules/__init__.py +3 -0
- homa/rl/sac/SoftActor.py +70 -0
- homa/rl/sac/SoftCritic.py +98 -0
- homa/rl/sac/__init__.py +2 -0
- homa/rl/sac/modules/DualSoftCriticModule.py +22 -0
- homa/rl/sac/modules/SoftActorModule.py +35 -0
- homa/rl/sac/modules/SoftCriticModule.py +30 -0
- homa/rl/sac/modules/__init__.py +3 -0
- homa/rl/utils.py +7 -0
- homa/vision/Resnet.py +3 -3
- homa/vision/Swin.py +17 -5
- homa/vision/modules/SwinModule.py +17 -9
- {homa-0.2.9.dist-info → homa-0.3.2.dist-info}/METADATA +1 -1
- {homa-0.2.9.dist-info → homa-0.3.2.dist-info}/RECORD +66 -28
- homa/ensemble/concerns/ReportsSize.py +0 -11
- homa/torch/__init__.py +0 -1
- homa/torch/helpers.py +0 -6
- {homa-0.2.9.dist-info → homa-0.3.2.dist-info}/WHEEL +0 -0
- {homa-0.2.9.dist-info → homa-0.3.2.dist-info}/entry_points.txt +0 -0
- {homa-0.2.9.dist-info → homa-0.3.2.dist-info}/top_level.txt +0 -0
homa/activations/learnable/AOAF.py CHANGED
@@ -12,5 +12,5 @@ class AOAF(AdaptiveActivationFunction, ChannelBased):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         self.initialize(x, "a")
-        a = self.a.view(self.
+        a = self.a.view(self.parameter_shape(x))
         return torch.relu(x - self.b * a) + self.c * a
homa/activations/learnable/AReLU.py CHANGED
@@ -1,5 +1,6 @@
 import torch
 from ..AdaptiveActivationFunction import AdaptiveActivationFunction
+from ...device import get_device
 
 
 class AReLU(AdaptiveActivationFunction):
@@ -7,10 +8,12 @@ class AReLU(AdaptiveActivationFunction):
         super(AReLU, self).__init__()
         self.a = torch.nn.Parameter(torch.tensor(0.9, requires_grad=True))
         self.b = torch.nn.Parameter(torch.tensor(2.0, requires_grad=True))
+        self.a.to(get_device())
+        self.b.to(get_device())
 
-    def forward(self,
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         negative_slope = torch.clamp(self.a, 0.01, 0.99)
         positive_slope = 1 + torch.sigmoid(self.b)
-        positive = positive_slope * torch.relu(
-        negative = negative_slope * (-torch.relu(-
+        positive = positive_slope * torch.relu(x)
+        negative = negative_slope * (-torch.relu(-x))
         return positive + negative
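Note on the device handling added here: `Tensor.to()` is not in-place, so `self.a.to(get_device())` on its own returns a moved copy and discards it. A minimal sketch of the distinction, in plain PyTorch (not homa code):

```python
import torch

# Tensor.to() returns a new tensor; calling it on a Parameter without
# rebinding leaves the module unchanged. Moving the whole module is the
# idiomatic way to relocate its registered parameters in place.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

module = torch.nn.Module()
module.a = torch.nn.Parameter(torch.tensor(0.9))
module.a.to(device)   # no effect: the returned tensor is discarded
module.to(device)     # moves every registered parameter in place
print(module.a.device)
```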
homa/activations/learnable/PiLU.py CHANGED
@@ -3,7 +3,7 @@ from ..AdaptiveActivationFunction import AdaptiveActivationFunction
 from .concerns import ChannelBased
 
 
-class
+class PiLU(AdaptiveActivationFunction, ChannelBased):
     def __init__(self):
         super().__init__()
         self.a = None
homa/activations/learnable/__init__.py CHANGED
@@ -1,10 +1,10 @@
-from .StarReLU import StarReLU
 from .DualLine import DualLine
 from .LeLeLU import LeLeLU
 from .AReLU import AReLU
 from .PERU import PERU
 from .ShiLU import ShiLU
+from .StarReLU import StarReLU
 from .DPReLU import DPReLU
-from .PiLU import
+from .PiLU import PiLU
 from .FReLU import FReLU
 from .AOAF import AOAF
homa/activations/learnable/concerns/ChannelBased.py CHANGED
@@ -21,12 +21,14 @@ class ChannelBased:
             attrs = [attrs]
 
         self.num_channels = x.shape[1]
+        device = x.device
         for index, attr in enumerate(attrs):
             if index < len(values) and values[index] is not None:
                 default_value = float(values[index])
             else:
                 default_value = 1.0
             param = torch.nn.Parameter(torch.full((self.num_channels,), default_value))
+            param = param.to(device)
             setattr(self, attr, param)
         self._initialized = True
 
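For context on `parameter_shape(x)` (defined in the ChannelBased concern but not shown in this diff): a hypothetical standalone sketch of the lazy per-channel pattern, assuming `parameter_shape` returns a broadcastable shape such as `(1, C, 1, 1)` for image batches:

```python
import torch

# Hypothetical sketch of the ChannelBased idea, not the homa source:
# one learnable value per channel, lazily allocated on the input's device
# and reshaped so it broadcasts over (N, C, H, W) batches.
class ChannelScaledReLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = None

    def initialize(self, x: torch.Tensor) -> None:
        if self.a is None:
            num_channels = x.shape[1]
            self.a = torch.nn.Parameter(
                torch.full((num_channels,), 1.0, device=x.device)
            )

    def parameter_shape(self, x: torch.Tensor):
        # (1, C, 1, 1) for images; broadcasts over batch and spatial dims
        return (1, x.shape[1]) + (1,) * (x.dim() - 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.initialize(x)
        return torch.relu(x) * self.a.view(self.parameter_shape(x))

y = ChannelScaledReLU()(torch.randn(2, 3, 8, 8))  # per-channel scaling
```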
homa/core/__init__.py ADDED
File without changes
homa/core/concerns/MovesNetworkToDevice.py ADDED
@@ -0,0 +1,13 @@
+from ...device import move
+
+
+class MovesNetworkToDevice:
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if not hasattr(self, "network"):
+            raise RuntimeError(
+                "MovesNetworkToDevice assumes the underlying class has a network property."
+            )
+
+        move(self.network)
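The mixin checks `hasattr(self, "network")` inside its own `__init__`, so a consuming class has to bind `network` before that `__init__` executes. A hypothetical sketch of the ordering constraint (the `move` helper below is a stand-in for `homa.device.move`, which is not shown in this diff):

```python
import torch

# Stand-in for homa.device.move (assumption: it relocates a module).
def move(network: torch.nn.Module) -> torch.nn.Module:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return network.to(device)

class MovesNetworkToDevice:
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not hasattr(self, "network"):
            raise RuntimeError("network must be set before this __init__ runs")
        move(self.network)

class Classifier(MovesNetworkToDevice):
    def __init__(self):
        self.network = torch.nn.Linear(4, 2)  # assign first...
        super().__init__()                    # ...then let the mixin move it

Classifier()  # works; swapping the two lines above would raise RuntimeError
```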
homa/device.py CHANGED
homa/ensemble/Ensemble.py CHANGED
@@ -1,16 +1,18 @@
 from .concerns import (
-
+    ReportsEnsembleSize,
     StoresModels,
     ReportsClassificationMetrics,
     PredictsProbabilities,
+    SavesEnsembleModels,
 )
 
 
 class Ensemble(
-
+    ReportsEnsembleSize,
     ReportsClassificationMetrics,
     PredictsProbabilities,
     StoresModels,
+    SavesEnsembleModels,
 ):
     def __init__(self):
         super().__init__()
homa/ensemble/concerns/PredictsProbabilities.py CHANGED
@@ -3,8 +3,8 @@ from .ReportsLogits import ReportsLogits
 
 
 class PredictsProbabilities(ReportsLogits):
-    def __init__(self
-        super().__init__(
+    def __init__(self):
+        super().__init__()
 
     def predict(self, x: torch.Tensor) -> torch.Tensor:
         logits = self.logits(x)
homa/ensemble/concerns/ReportsEnsembleAccuracy.py CHANGED
@@ -3,8 +3,8 @@ from torch.utils.data import DataLoader
 
 
 class ReportsEnsembleAccuracy:
-    def __init__(self
-        super().__init__(
+    def __init__(self):
+        super().__init__()
 
     def accuracy(self, dataloader: DataLoader) -> float:
         predictions, labels = self.metric_necessities(dataloader)
homa/ensemble/concerns/ReportsEnsembleF1.py CHANGED
@@ -2,8 +2,8 @@ from sklearn.metrics import f1_score as f1
 
 
 class ReportsEnsembleF1:
-    def __init__(self
-        super().__init__(
+    def __init__(self):
+        super().__init__()
 
     def f1(self) -> float:
         predictions, labels = self.metric_necessities()
homa/ensemble/concerns/ReportsEnsembleKappa.py CHANGED
@@ -2,8 +2,8 @@ from sklearn.metrics import cohen_kappa_score as kappa
 
 
 class ReportsEnsembleKappa:
-    def __init__(self
-        super().__init__(
+    def __init__(self):
+        super().__init__()
 
     def accuracy(self) -> float:
         predictions, labels = self.metric_necessities()
homa/ensemble/concerns/ReportsLogits.py CHANGED
@@ -2,16 +2,37 @@ import torch
 
 
 class ReportsLogits:
-    def __init__(self
-        super().__init__(
+    def __init__(self):
+        super().__init__()
 
-    def
-
+    def logits_average(self, x: torch.Tensor) -> torch.Tensor:
+        return self.logits_sim(x) / len(self.factories)
+
+    def logits_sum(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size = x.size(0)
         logits = torch.zeros((batch_size, self.num_classes))
-        for
+        for factory, weight in zip(self.factories, self.weights):
+            model = factory(num_classes=self.num_classes)
+            model.load_state_dict(weight)
             logits += model(x)
         return logits
 
+    def check_aggregation_strategy(self, aggregation: str):
+        if aggregation not in ["mean", "average", "sum"]:
+            raise ValueError(
+                f"Ensemble aggregation strategy must be in [mean, average, sum], but found {aggregation}."
+            )
+
+    def logits(self, x: torch.Tensor, aggregation: str = "mean") -> torch.Tensor:
+        self.check_aggregation_strategy(aggregation=aggregation)
+        logits_handlers = {
+            "mean": self.logits_average,
+            "average": self.logits_average,
+            "sum": self.logits_sum,
+        }
+        handler = logits_handlers.get(aggregation)
+        return handler(x)
+
     @torch.no_grad()
     def logits_(self, *args, **kwargs):
         return self.logits(*args, **kwargs)
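The new `logits()` entry point dispatches on a strategy string; note that `logits_average` calls `self.logits_sim`, which looks like a typo for `logits_sum` as shipped. A self-contained sketch of the same dispatch pattern (toy models, not the homa API):

```python
import torch

# Standalone sketch of the string-strategy dispatch in ReportsLogits.logits():
# aggregation names map onto bound methods, with validation up front.
class TinyEnsemble:
    def __init__(self, models, num_classes: int):
        self.models = list(models)
        self.num_classes = num_classes

    def logits_sum(self, x: torch.Tensor) -> torch.Tensor:
        logits = torch.zeros((x.size(0), self.num_classes))
        for model in self.models:
            logits += model(x)
        return logits

    def logits_average(self, x: torch.Tensor) -> torch.Tensor:
        return self.logits_sum(x) / len(self.models)

    def logits(self, x: torch.Tensor, aggregation: str = "mean") -> torch.Tensor:
        handlers = {"mean": self.logits_average,
                    "average": self.logits_average,
                    "sum": self.logits_sum}
        if aggregation not in handlers:
            raise ValueError(f"unknown aggregation: {aggregation}")
        return handlers[aggregation](x)

ensemble = TinyEnsemble([torch.nn.Linear(4, 3) for _ in range(5)], num_classes=3)
out = ensemble.logits(torch.randn(2, 4), aggregation="mean")  # shape (2, 3)
```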
homa/ensemble/concerns/StoresModels.py CHANGED
@@ -1,23 +1,26 @@
 import torch
-from
-from
+from typing import List, Type
+from collections import OrderedDict
 from ...vision import Model
 
 
 class StoresModels:
-    def __init__(self
-        super().__init__(
-        self.
+    def __init__(self):
+        super().__init__()
+        self.factories: List[Type[torch.nn.Module]] = []
+        self.weights: List[OrderedDict] = []
 
     def record(self, model: Model | torch.nn.Module):
         model_: torch.nn.Module | None = None
         if isinstance(model, Model):
-            model_ =
+            model_ = model.network
         elif isinstance(model, torch.nn.Module):
-            model_ =
+            model_ = model
         else:
             raise TypeError("Wrong input to ensemble record")
-
+
+        self.factories.append(model_.__class__)
+        self.weights.append(model_.state_dict())
 
     def push(self, *args, **kwargs):
         self.record(*args, **kwargs)
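`StoresModels` now keeps `(class, state_dict)` pairs instead of live modules, and `ReportsLogits` re-instantiates each model per call. A minimal sketch of that round trip, assuming the factory accepts a `num_classes` keyword as the diff implies:

```python
import torch

# Sketch of the record/rebuild round trip implied by StoresModels: the
# ensemble stores the class plus a weight snapshot, then rebuilds on demand.
class TinyNet(torch.nn.Module):
    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.fc = torch.nn.Linear(4, num_classes)

    def forward(self, x):
        return self.fc(x)

factories, weights = [], []
trained = TinyNet()
factories.append(trained.__class__)      # what record() stores
weights.append(trained.state_dict())

model = factories[0](num_classes=3)      # what logits_sum() rebuilds
model.load_state_dict(weights[0])
```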
homa/ensemble/concerns/__init__.py CHANGED
@@ -5,5 +5,6 @@ from .ReportsEnsembleAccuracy import ReportsEnsembleAccuracy
 from .ReportsEnsembleF1 import ReportsEnsembleF1
 from .ReportsEnsembleKappa import ReportsEnsembleKappa
 from .ReportsLogits import ReportsLogits
-from .
+from .ReportsEnsembleSize import ReportsEnsembleSize
 from .StoresModels import StoresModels
+from .SavesEnsembleModels import SavesEnsembleModels
homa/ensemble/utils.py ADDED
homa/graph/GraphAttention.py ADDED
@@ -0,0 +1,13 @@
+import torch
+from .modules import GraphAttentionModule
+from ..core.concerns import MovesNetworkToDevice
+
+
+class GraphAttention(MovesNetworkToDevice):
+    def __init__(self, lr: float = 0.005, decay: float = 5e-4, dropout: float = 0.6):
+        super().__init__()
+        self.network = GraphAttentionModule()
+        self.optimizer = torch.nn.AdamW(
+            self.network.parameters(), lr=lr, weight_decay=decay
+        )
+        self.criterion = torch.nn.CrossEntropyLoss()
homa/graph/__init__.py ADDED
@@ -0,0 +1 @@
+from .GraphAttention import GraphAttention
homa/graph/modules/GraphAttentionHeadModule.py ADDED
@@ -0,0 +1,37 @@
+import torch
+
+
+class GraphAttentionHeadModule(torch.nn.Module):
+    def __init__(self, in_features: int, out_features: int, alpha=0.2):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.alpha = alpha
+
+        self.W = torch.nn.Linear(in_features, out_features, bias=False)
+        self.a_1 = torch.nn.Parameter(torch.randn(out_features, 1))
+        self.a_2 = torch.nn.Parameter(torch.randn(out_features, 1))
+
+        self.leaky_relu = torch.nn.LeakyReLU(self.alpha)
+        self.elu = torch.nn.ELU()
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        torch.nn.init.xavier_uniform_(self.W.weight, gain=1.414)
+        torch.nn.init.xavier_uniform_(self.a_1, gain=1.414)
+        torch.nn.init.xavier_uniform_(self.a_2, gain=1.414)
+
+    def forward(self, node_features, adj_matrix):
+        N = node_features.size(0)
+        h_prime = self.W(node_features)
+        s1 = torch.matmul(h_prime, self.a_1)
+        s2 = torch.matmul(h_prime, self.a_2)
+        e = s1 + s2.T
+        e = self.leaky_relu(e)
+        zero_vec = -9e15 * torch.ones_like(e)
+        attention_mask = torch.where(
+            adj_matrix > 0, e, zero_vec.to(node_features.device)
+        )
+        attention_weights = F.softmax(attention_mask, dim=1)
+        h_new = torch.matmul(attention_weights, h_prime)
+        return self.elu(h_new)
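This module calls `F.softmax` but never imports `torch.nn.functional as F`, so it would fail at `forward` time as shipped. A toy sketch of the additive attention-score trick with the import in place:

```python
import torch
import torch.nn.functional as F  # the module above needs this import for F.softmax

# s1 + s2.T materializes all pairwise scores e_ij = a1.h_i + a2.h_j at once.
N, D = 4, 8
h_prime = torch.randn(N, D)
a_1, a_2 = torch.randn(D, 1), torch.randn(D, 1)
e = torch.matmul(h_prime, a_1) + torch.matmul(h_prime, a_2).T  # (N, N)
adj = (torch.rand(N, N) > 0.5).float()
masked = torch.where(adj > 0, F.leaky_relu(e, 0.2), torch.full_like(e, -9e15))
attention = F.softmax(masked, dim=1)  # each row sums to 1 over neighbors
```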
homa/graph/modules/MultiHeadGraphAttentionModule.py ADDED
@@ -0,0 +1,22 @@
+import torch
+from .GraphAttentionHeadModule import GraphAttentionHeadModule
+
+
+class MultiHeadGraphAttentionModule(torch.nn.Module):
+    def __init__(self, num_heads: int, in_features: int, out_features: int, alpha=0.2):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_out_features = out_features
+        self.heads = torch.nn.ModuleList(
+            [
+                GraphAttentionHeadModule(in_features, out_features, alpha=alpha)
+                for _ in range(num_heads)
+            ]
+        )
+
+    def forward(
+        self, node_features: torch.Tensor, adj_matrix: torch.Tensor
+    ) -> torch.Tensor:
+        outputs = [head(node_features, adj_matrix) for head in self.heads]
+        h_new_concat = torch.cat(outputs, dim=1)
+        return h_new_concat
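Since head outputs are concatenated on the feature axis, the module's output width is `num_heads * out_features` rather than `out_features`; a quick check:

```python
import torch

# 3 heads, N=5 nodes, out_features=8 each: concatenation widens the features.
outputs = [torch.randn(5, 8) for _ in range(3)]
h_new_concat = torch.cat(outputs, dim=1)
print(h_new_concat.shape)  # torch.Size([5, 24])
```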
homa/loss/Loss.py CHANGED
homa/rl/DQN.py ADDED
homa/rl/DRQN.py ADDED
homa/rl/DiversityIsAllYouNeed.py ADDED
@@ -0,0 +1,96 @@
+import torch
+from .diayn.Actor import Actor
+from .diayn.Critic import Critic
+from .diayn.Discriminator import Discriminator
+from .buffers import DiversityIsAllYouNeedBuffer, Buffer
+
+
+class DiversityIsAllYouNeed:
+    def __init__(
+        self,
+        state_dimension: int,
+        action_dimension: int,
+        hidden_dimension: int = 256,
+        num_skills: int = 10,
+        critic_decay: float = 0.0,
+        actor_decay: float = 0.0,
+        discriminator_decay: float = 0.0,
+        actor_lr: float = 0.0001,
+        critic_lr: float = 0.001,
+        discriminator_lr=0.001,
+        buffer_capacity: int = 1_000_000,
+        actor_epsilon: float = 1e-6,
+        gamma: float = 0.99,
+        min_std: float = -20.0,
+        max_std: float = 2.0,
+    ):
+        self.buffer: Buffer = DiversityIsAllYouNeedBuffer(capacity=buffer_capacity)
+        self.num_skills: int = num_skills
+        self.actor = Actor(
+            state_dimension=state_dimension,
+            action_dimension=action_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+            lr=actor_lr,
+            decay=actor_decay,
+            epsilon=actor_epsilon,
+            min_std=min_std,
+            max_std=max_std,
+        )
+        self.critic = Critic(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+            lr=critic_lr,
+            decay=critic_decay,
+            gamma=gamma,
+        )
+        self.discriminator = Discriminator(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+            lr=discriminator_lr,
+            decay=discriminator_decay,
+        )
+
+    def one_hot(self, indices, max_index) -> torch.Tensor:
+        one_hot = torch.zeros(indices.size(0), max_index)
+        one_hot.scatter_(1, indices.unsqueeze(1), 1)
+        return one_hot
+
+    def skill_index(self) -> torch.Tensor:
+        return torch.randint(0, self.num_skills, (1,))
+
+    def skill(self) -> torch.Tensor:
+        return self.one_hot(self.skill_index(), self.num_skills)
+
+    def advantages(
+        self,
+        states: torch.Tensor,
+        skills: torch.Tensor,
+        rewards: torch.Tensor,
+        terminations: torch.Tensor,
+        next_states: torch.Tensor,
+    ) -> torch.Tensor:
+        values = self.critic.values(states=states, skills=skills)
+        termination_mask = 1 - terminations
+        next_values = self.critic.values_(states=next_states, skills=skills)
+        update = self.gamma * next_values * termination_mask
+        return rewards + update - values
+
+    def train(self, skill: torch.Tensor):
+        data = self.buffer.all_tensor()
+        skill_indices = skill.repeat(data.states.size(0), 1).long()
+        skills_indices_one_hot = self.one_hot(skill_indices, self.num_skills)
+        self.discriminator.train(
+            states=data.states, skills_indices=skills_indices_one_hot
+        )
+        advantages = self.advantages(
+            states=data.states,
+            rewards=data.rewards,
+            terminations=data.terminations,
+            next_states=data.next_states,
+            skills=skills,
+        )
+        self.critic.train(advantages=advantages)
+        self.actor.train(advantages=advantages)
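As shipped, `advantages` reads `self.gamma` (which `__init__` never assigns; gamma is only forwarded to the Critic) and `train` passes a `skills` name it never binds, so both would raise at runtime. The `one_hot` helper itself is sound; a standalone check of its `scatter_` mechanics:

```python
import torch

# A (B,) index tensor becomes a (B, num_skills) one-hot matrix via scatter_.
num_skills = 10
indices = torch.randint(0, num_skills, (1,))      # what skill_index() returns
one_hot = torch.zeros(indices.size(0), num_skills)
one_hot.scatter_(1, indices.unsqueeze(1), 1)      # what one_hot() does
assert one_hot.sum() == 1 and one_hot[0, indices[0]] == 1
```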
homa/rl/SoftActorCritic.py ADDED
@@ -0,0 +1,67 @@
+from .sac import SoftActor, SoftCritic
+from .buffers import SoftActorCriticBuffer
+from ..core.concerns import TracksTime
+
+
+class SoftActorCritic(TracksTime):
+    def __init__(
+        self,
+        state_dimension: int,
+        action_dimension: int,
+        hidden_dimension: int = 256,
+        buffer_capacity: int = 100_000,
+        batch_size: int = 256,
+        actor_lr: float = 0.0002,
+        critic_lr: float = 0.0003,
+        actor_decay: float = 0.0,
+        critic_decay: float = 0.0,
+        tau: float = 0.005,
+        alpha: float = 0.2,
+        gamma: float = 0.99,
+        min_std: float = -20,
+        max_std: float = 2,
+        warmup: int = 20_000,
+    ):
+        super().__init__()
+
+        self.batch_size: int = batch_size
+        self.warmup: int = warmup
+        self.tau: float = tau
+
+        self.actor = SoftActor(
+            state_dimension=state_dimension,
+            action_dimension=action_dimension,
+            hidden_dimension=hidden_dimension,
+            lr=actor_lr,
+            weight_decay=actor_decay,
+            alpha=alpha,
+            min_std=min_std,
+            max_std=max_std,
+        )
+        self.critic = SoftCritic(
+            state_dimension=state_dimension,
+            action_dimension=action_dimension,
+            hidden_dimension=hidden_dimension,
+            lr=critic_lr,
+            weight_decay=critic_decay,
+            gamma=gamma,
+            alpha=alpha,
+        )
+        self.buffer = SoftActorCriticBuffer(capacity=buffer_capacity)
+
+    def train(self):
+        # don't train before warmup
+        if self.buffer.size < self.warmup:
+            return
+
+        data = self.buffer.sample_torch(self.batch_size)
+        self.critic.train(
+            states=data.states,
+            actions=data.actions,
+            rewards=data.rewards,
+            terminations=data.terminations,
+            next_states=data.next_states,
+            actor=self.actor,
+        )
+        self.actor.train(states=data.states, critic=self.critic)
+        self.critic.update(tau=self.tau)
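`SoftCritic.update(tau=...)` is not shown in this diff; in standard SAC it is a Polyak (soft) update that moves the target critic toward the online critic. A minimal sketch under that assumption:

```python
import torch

# Minimal sketch of a Polyak/soft update, the usual meaning of update(tau=...):
# target parameters track the online network with an exponential moving average.
def soft_update(target: torch.nn.Module, online: torch.nn.Module, tau: float) -> None:
    with torch.no_grad():
        for t_param, o_param in zip(target.parameters(), online.parameters()):
            t_param.mul_(1.0 - tau).add_(tau * o_param)

online = torch.nn.Linear(4, 4)
target = torch.nn.Linear(4, 4)
target.load_state_dict(online.state_dict())  # start identical
soft_update(target, online, tau=0.005)
```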
homa/rl/__init__.py ADDED
homa/rl/buffers/Buffer.py ADDED
@@ -0,0 +1,13 @@
+from collections import deque
+from typing import Type
+from .concerns import ResetsCollection, HasRecordAlternatives
+
+
+class Buffer(ResetsCollection, HasRecordAlternatives):
+    def __init__(self, capacity: int):
+        self.capacity: int = capacity
+        self.collection: Type[deque] = deque(maxlen=self.capacity)
+
+    @property
+    def size(self):
+        return len(self.collection)
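The capacity semantics come entirely from `deque(maxlen=...)`: once full, appends silently evict the oldest transition. (As an aside, the `Type[deque]` annotation on `collection` describes a class, while the value is an instance; plain `deque` would be accurate.) A quick check:

```python
from collections import deque

# A bounded deque acts as a ring buffer: old items are dropped on overflow.
collection = deque(maxlen=3)
for transition in range(5):
    collection.append(transition)
print(list(collection), len(collection))  # [2, 3, 4] 3
```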
homa/rl/buffers/DiversityIsAllYouNeedBuffer.py ADDED
@@ -0,0 +1,50 @@
+import torch
+import numpy
+from types import SimpleNamespace
+from .Buffer import Buffer
+from .concerns import HasRecordAlternatives
+
+
+class DiversityIsAllYouNeedBuffer(Buffer, HasRecordAlternatives):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def all_tensor(self) -> SimpleNamespace:
+        return self.all(tensor=True)
+
+    def all(self, tensor: bool = False) -> SimpleNamespace:
+        states, actions, rewards, next_states, terminations, probabilities = zip(
+            *self.collection
+        )
+
+        if tensor:
+            states = torch.from_numpy(numpy.array(states))
+            actions = torch.from_numpy(numpy.array(actions))
+            rewards = torch.from_numpy(numpy.array(rewards))
+            next_states = torch.from_numpy(numpy.array(next_states))
+            terminations = torch.from_numpy(numpy.array(terminations))
+            probabilities = torch.from_numpy(numpy.array(probabilities))
+
+        return SimpleNamespace(
+            **{
+                "states": states,
+                "actions": actions,
+                "rewards": rewards,
+                "next_states": next_states,
+                "terminations": terminations,
+                "probabilities": probabilities,
+            }
+        )
+
+    def record(
+        self,
+        state: numpy.ndarray,
+        action: int,
+        reward: float,
+        next_state: numpy.ndarray,
+        termination: bool,
+        probability: numpy.ndarray,
+    ) -> None:
+        self.collection.append(
+            (state, action, reward, next_state, termination, probability)
+        )
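The `all()` method relies on `zip(*self.collection)` to transpose a list of per-step tuples into one tuple per field before stacking. A standalone sketch with toy transitions:

```python
import numpy
import torch
from types import SimpleNamespace

# zip(*collection) turns N per-step tuples into one tuple per field, which
# numpy.array then stacks into a batch axis.
collection = [
    (numpy.zeros(4), 1, 0.5, numpy.ones(4), False, numpy.full(2, 0.5)),
    (numpy.ones(4), 0, 1.0, numpy.zeros(4), True, numpy.full(2, 0.5)),
]
states, actions, rewards, next_states, terminations, probabilities = zip(*collection)
batch = SimpleNamespace(
    states=torch.from_numpy(numpy.array(states)),              # shape (2, 4)
    terminations=torch.from_numpy(numpy.array(terminations)),  # bool tensor
)
```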