homa 0.2.9__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. homa/activations/learnable/AOAF.py +1 -1
  2. homa/activations/learnable/AReLU.py +6 -3
  3. homa/activations/learnable/PiLU.py +1 -1
  4. homa/activations/learnable/__init__.py +2 -2
  5. homa/activations/learnable/concerns/ChannelBased.py +2 -0
  6. homa/core/__init__.py +0 -0
  7. homa/core/concerns/MovesNetworkToDevice.py +13 -0
  8. homa/core/concerns/TracksTime.py +7 -0
  9. homa/core/concerns/__init__.py +2 -0
  10. homa/device.py +5 -0
  11. homa/ensemble/Ensemble.py +4 -2
  12. homa/ensemble/concerns/CalculatesMetricNecessities.py +2 -2
  13. homa/ensemble/concerns/PredictsProbabilities.py +2 -2
  14. homa/ensemble/concerns/ReportsClassificationMetrics.py +2 -1
  15. homa/ensemble/concerns/ReportsEnsembleAccuracy.py +2 -2
  16. homa/ensemble/concerns/ReportsEnsembleF1.py +2 -2
  17. homa/ensemble/concerns/ReportsEnsembleKappa.py +2 -2
  18. homa/ensemble/concerns/ReportsEnsembleSize.py +11 -0
  19. homa/ensemble/concerns/ReportsLogits.py +26 -5
  20. homa/ensemble/concerns/SavesEnsembleModels.py +13 -0
  21. homa/ensemble/concerns/StoresModels.py +11 -8
  22. homa/ensemble/concerns/__init__.py +2 -1
  23. homa/ensemble/utils.py +9 -0
  24. homa/graph/GraphAttention.py +13 -0
  25. homa/graph/__init__.py +1 -0
  26. homa/graph/modules/GraphAttentionHeadModule.py +37 -0
  27. homa/graph/modules/MultiHeadGraphAttentionModule.py +22 -0
  28. homa/graph/modules/__init__.py +2 -0
  29. homa/loss/Loss.py +4 -1
  30. homa/rl/DQN.py +2 -0
  31. homa/rl/DRQN.py +5 -0
  32. homa/rl/DiversityIsAllYouNeed.py +97 -0
  33. homa/rl/SoftActorCritic.py +67 -0
  34. homa/rl/__init__.py +4 -0
  35. homa/rl/buffers/Buffer.py +13 -0
  36. homa/rl/buffers/DiversityIsAllYouNeedBuffer.py +50 -0
  37. homa/rl/buffers/ImageBuffer.py +5 -0
  38. homa/rl/buffers/SoftActorCriticBuffer.py +64 -0
  39. homa/rl/buffers/__init__.py +4 -0
  40. homa/rl/buffers/concerns/HasRecordAlternatives.py +12 -0
  41. homa/rl/buffers/concerns/ResetsCollection.py +9 -0
  42. homa/rl/buffers/concerns/__init__.py +2 -0
  43. homa/rl/diayn/Actor.py +54 -0
  44. homa/rl/diayn/Critic.py +41 -0
  45. homa/rl/diayn/Discriminator.py +45 -0
  46. homa/rl/diayn/__init__.py +3 -0
  47. homa/rl/diayn/modules/ContinuousActorModule.py +42 -0
  48. homa/rl/diayn/modules/CriticModule.py +28 -0
  49. homa/rl/diayn/modules/DiscriminatorModule.py +24 -0
  50. homa/rl/diayn/modules/__init__.py +3 -0
  51. homa/rl/sac/SoftActor.py +70 -0
  52. homa/rl/sac/SoftCritic.py +98 -0
  53. homa/rl/sac/__init__.py +2 -0
  54. homa/rl/sac/modules/DualSoftCriticModule.py +22 -0
  55. homa/rl/sac/modules/SoftActorModule.py +35 -0
  56. homa/rl/sac/modules/SoftCriticModule.py +30 -0
  57. homa/rl/sac/modules/__init__.py +3 -0
  58. homa/rl/utils.py +7 -0
  59. homa/vision/Resnet.py +3 -3
  60. homa/vision/Swin.py +17 -5
  61. homa/vision/modules/SwinModule.py +17 -9
  62. {homa-0.2.9.dist-info → homa-0.3.2.dist-info}/METADATA +1 -1
  63. {homa-0.2.9.dist-info → homa-0.3.2.dist-info}/RECORD +66 -28
  64. homa/ensemble/concerns/ReportsSize.py +0 -11
  65. homa/torch/__init__.py +0 -1
  66. homa/torch/helpers.py +0 -6
  67. {homa-0.2.9.dist-info → homa-0.3.2.dist-info}/WHEEL +0 -0
  68. {homa-0.2.9.dist-info → homa-0.3.2.dist-info}/entry_points.txt +0 -0
  69. {homa-0.2.9.dist-info → homa-0.3.2.dist-info}/top_level.txt +0 -0
homa/activations/learnable/AOAF.py CHANGED
@@ -12,5 +12,5 @@ class AOAF(AdaptiveActivationFunction, ChannelBased):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         self.initialize(x, "a")
-        a = self.a.view(self.parameter_view(x))
+        a = self.a.view(self.parameter_shape(x))
         return torch.relu(x - self.b * a) + self.c * a
homa/activations/learnable/AReLU.py CHANGED
@@ -1,5 +1,6 @@
 import torch
 from ..AdaptiveActivationFunction import AdaptiveActivationFunction
+from ...device import get_device
 
 
 class AReLU(AdaptiveActivationFunction):
@@ -7,10 +8,12 @@ class AReLU(AdaptiveActivationFunction):
         super(AReLU, self).__init__()
         self.a = torch.nn.Parameter(torch.tensor(0.9, requires_grad=True))
         self.b = torch.nn.Parameter(torch.tensor(2.0, requires_grad=True))
+        self.a.to(get_device())
+        self.b.to(get_device())
 
-    def forward(self, z):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         negative_slope = torch.clamp(self.a, 0.01, 0.99)
         positive_slope = 1 + torch.sigmoid(self.b)
-        positive = positive_slope * torch.relu(z)
-        negative = negative_slope * (-torch.relu(-z))
+        positive = positive_slope * torch.relu(x)
+        negative = negative_slope * (-torch.relu(-x))
         return positive + negative
homa/activations/learnable/PiLU.py CHANGED
@@ -3,7 +3,7 @@ from ..AdaptiveActivationFunction import AdaptiveActivationFunction
 from .concerns import ChannelBased
 
 
-class DualLine(AdaptiveActivationFunction, ChannelBased):
+class PiLU(AdaptiveActivationFunction, ChannelBased):
    def __init__(self):
        super().__init__()
        self.a = None
homa/activations/learnable/__init__.py CHANGED
@@ -1,10 +1,10 @@
-from .StarReLU import StarReLU
 from .DualLine import DualLine
 from .LeLeLU import LeLeLU
 from .AReLU import AReLU
 from .PERU import PERU
 from .ShiLU import ShiLU
+from .StarReLU import StarReLU
 from .DPReLU import DPReLU
-from .PiLU import DualLine
+from .PiLU import PiLU
 from .FReLU import FReLU
 from .AOAF import AOAF
homa/activations/learnable/concerns/ChannelBased.py CHANGED
@@ -21,12 +21,14 @@ class ChannelBased:
             attrs = [attrs]
 
         self.num_channels = x.shape[1]
+        device = x.device
         for index, attr in enumerate(attrs):
             if index < len(values) and values[index] is not None:
                 default_value = float(values[index])
             else:
                 default_value = 1.0
             param = torch.nn.Parameter(torch.full((self.num_channels,), default_value))
+            param = param.to(device)
             setattr(self, attr, param)
         self._initialized = True
 
homa/core/__init__.py ADDED
File without changes
homa/core/concerns/MovesNetworkToDevice.py ADDED
@@ -0,0 +1,13 @@
+from ...device import move
+
+
+class MovesNetworkToDevice:
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if not hasattr(self, "network"):
+            raise RuntimeError(
+                "MovesNetworkToDevice assumes the underlying class has a network property."
+            )
+
+        move(self.network)
homa/core/concerns/TracksTime.py ADDED
@@ -0,0 +1,7 @@
+class TracksTime:
+    def __init__(self):
+        super().__init__()
+        self.t = 0
+
+    def tick(self):
+        self.t += 1
homa/core/concerns/__init__.py ADDED
@@ -0,0 +1,2 @@
+from .MovesNetworkToDevice import MovesNetworkToDevice
+from .TracksTime import TracksTime
homa/device.py CHANGED
@@ -23,3 +23,8 @@ def mps():
 
 def device():
     return get_device()
+
+
+def move(*modules):
+    for module in modules:
+        module.to(get_device())
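The new move helper centralizes device placement; Module.to moves parameters in place, so callers can ignore the return value. A minimal usage sketch (the two Linear modules are hypothetical):

import torch
from homa.device import move

encoder = torch.nn.Linear(16, 32)
decoder = torch.nn.Linear(32, 16)
move(encoder, decoder)  # both modules now live on get_device()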
homa/ensemble/Ensemble.py CHANGED
@@ -1,16 +1,18 @@
 from .concerns import (
-    ReportsSize,
+    ReportsEnsembleSize,
     StoresModels,
     ReportsClassificationMetrics,
     PredictsProbabilities,
+    SavesEnsembleModels,
 )
 
 
 class Ensemble(
-    ReportsSize,
+    ReportsEnsembleSize,
     ReportsClassificationMetrics,
     PredictsProbabilities,
     StoresModels,
+    SavesEnsembleModels,
 ):
     def __init__(self):
         super().__init__()
homa/ensemble/concerns/CalculatesMetricNecessities.py CHANGED
@@ -3,8 +3,8 @@ from ...device import get_device
 
 
 class CalculatesMetricNecessities:
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self):
+        super().__init__()
 
     @torch.no_grad()
     def metric_necessities(self, dataloader):
homa/ensemble/concerns/PredictsProbabilities.py CHANGED
@@ -3,8 +3,8 @@ from .ReportsLogits import ReportsLogits
 
 
 class PredictsProbabilities(ReportsLogits):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self):
+        super().__init__()
 
     def predict(self, x: torch.Tensor) -> torch.Tensor:
         logits = self.logits(x)
homa/ensemble/concerns/ReportsClassificationMetrics.py CHANGED
@@ -10,4 +10,5 @@ class ReportsClassificationMetrics(
     ReportsEnsembleF1,
     ReportsEnsembleKappa,
 ):
-    pass
+    def __init__(self):
+        super().__init__()
homa/ensemble/concerns/ReportsEnsembleAccuracy.py CHANGED
@@ -3,8 +3,8 @@ from torch.utils.data import DataLoader
 
 
 class ReportsEnsembleAccuracy:
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self):
+        super().__init__()
 
     def accuracy(self, dataloader: DataLoader) -> float:
         predictions, labels = self.metric_necessities(dataloader)
homa/ensemble/concerns/ReportsEnsembleF1.py CHANGED
@@ -2,8 +2,8 @@ from sklearn.metrics import f1_score as f1
 
 
 class ReportsEnsembleF1:
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self):
+        super().__init__()
 
     def f1(self) -> float:
         predictions, labels = self.metric_necessities()
homa/ensemble/concerns/ReportsEnsembleKappa.py CHANGED
@@ -2,8 +2,8 @@ from sklearn.metrics import cohen_kappa_score as kappa
 
 
 class ReportsEnsembleKappa:
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self):
+        super().__init__()
 
     def accuracy(self) -> float:
         predictions, labels = self.metric_necessities()
homa/ensemble/concerns/ReportsEnsembleSize.py ADDED
@@ -0,0 +1,11 @@
+class ReportsEnsembleSize:
+    def __init__(self):
+        super().__init__()
+
+    @property
+    def size(self):
+        return len(self.weights)
+
+    @property
+    def length(self):
+        return self.size
homa/ensemble/concerns/ReportsLogits.py CHANGED
@@ -2,16 +2,37 @@ import torch
 
 
 class ReportsLogits:
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self):
+        super().__init__()
 
-    def logits(self, x: torch.Tensor) -> torch.Tensor:
-        batch_size = x.shape[0]
+    def logits_average(self, x: torch.Tensor) -> torch.Tensor:
+        return self.logits_sum(x) / len(self.factories)
+
+    def logits_sum(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size = x.size(0)
         logits = torch.zeros((batch_size, self.num_classes))
-        for model in self.models:
+        for factory, weight in zip(self.factories, self.weights):
+            model = factory(num_classes=self.num_classes)
+            model.load_state_dict(weight)
             logits += model(x)
         return logits
 
+    def check_aggregation_strategy(self, aggregation: str):
+        if aggregation not in ["mean", "average", "sum"]:
+            raise ValueError(
+                f"Ensemble aggregation strategy must be in [mean, average, sum], but found {aggregation}."
+            )
+
+    def logits(self, x: torch.Tensor, aggregation: str = "mean") -> torch.Tensor:
+        self.check_aggregation_strategy(aggregation=aggregation)
+        logits_handlers = {
+            "mean": self.logits_average,
+            "average": self.logits_average,
+            "sum": self.logits_sum,
+        }
+        handler = logits_handlers.get(aggregation)
+        return handler(x)
+
     @torch.no_grad()
     def logits_(self, *args, **kwargs):
         return self.logits(*args, **kwargs)
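logits now rebuilds every ensemble member from its stored factory and weights, then dispatches on the aggregation strategy ("mean"/"average" divide the summed logits by the member count). A hedged usage sketch, assuming models were recorded via Ensemble.record and that num_classes is set elsewhere on the ensemble:

ensemble = Ensemble()
ensemble.record(model_a)  # model_a, model_b: trained torch.nn.Module instances
ensemble.record(model_b)

x = torch.randn(8, 3, 224, 224)                     # hypothetical input batch
mean_logits = ensemble.logits(x)                    # default aggregation: "mean"
sum_logits = ensemble.logits(x, aggregation="sum")
# ensemble.logits(x, aggregation="median") would raise ValueError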
homa/ensemble/concerns/SavesEnsembleModels.py ADDED
@@ -0,0 +1,13 @@
+class SavesEnsembleModels:
+    def __init__(self):
+        super().__init__()
+
+    def save(self):
+        self.save_factories()
+        self.save_weights()
+
+    def save_factories(self):
+        pass
+
+    def save_weights(self):
+        pass
homa/ensemble/concerns/StoresModels.py CHANGED
@@ -1,23 +1,26 @@
 import torch
-from copy import deepcopy
-from typing import List
+from typing import List, Type
+from collections import OrderedDict
 from ...vision import Model
 
 
 class StoresModels:
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.models: List[torch.nn.Module] = []
+    def __init__(self):
+        super().__init__()
+        self.factories: List[Type[torch.nn.Module]] = []
+        self.weights: List[OrderedDict] = []
 
     def record(self, model: Model | torch.nn.Module):
         model_: torch.nn.Module | None = None
         if isinstance(model, Model):
-            model_ = deepcopy(model.network)
+            model_ = model.network
         elif isinstance(model, torch.nn.Module):
-            model_ = deepcopy(model)
+            model_ = model
         else:
             raise TypeError("Wrong input to ensemble record")
-        self.models.append(model_)
+
+        self.factories.append(model_.__class__)
+        self.weights.append(model_.state_dict())
 
     def push(self, *args, **kwargs):
         self.record(*args, **kwargs)
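record now stores each model's class and its state_dict instead of a deepcopied module. Worth noting: state_dict() returns references to the live parameter tensors, so weights recorded here keep changing if the source model trains further. A defensive variant (a sketch, not what the package does) would snapshot inside record:

from copy import deepcopy

# hypothetical variant of the tail of record(), snapshotting instead of aliasing:
self.factories.append(model_.__class__)
self.weights.append(deepcopy(model_.state_dict()))  # frozen copy, not a live view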
homa/ensemble/concerns/__init__.py CHANGED
@@ -5,5 +5,6 @@ from .ReportsEnsembleAccuracy import ReportsEnsembleAccuracy
 from .ReportsEnsembleF1 import ReportsEnsembleF1
 from .ReportsEnsembleKappa import ReportsEnsembleKappa
 from .ReportsLogits import ReportsLogits
-from .ReportsSize import ReportsSize
+from .ReportsEnsembleSize import ReportsEnsembleSize
 from .StoresModels import StoresModels
+from .SavesEnsembleModels import SavesEnsembleModels
homa/ensemble/utils.py ADDED
@@ -0,0 +1,9 @@
+import torch
+
+
+def get_model_device(model: torch.nn.Module):
+    try:
+        device = next(model.parameters()).device
+    except StopIteration:
+        device = torch.device("cpu")
+    return device
homa/graph/GraphAttention.py ADDED
@@ -0,0 +1,13 @@
+import torch
+from .modules import GraphAttentionModule
+from ..core.concerns import MovesNetworkToDevice
+
+
+class GraphAttention(MovesNetworkToDevice):
+    def __init__(self, lr: float = 0.005, decay: float = 5e-4, dropout: float = 0.6):
+        super().__init__()
+        self.network = GraphAttentionModule()
+        self.optimizer = torch.optim.AdamW(
+            self.network.parameters(), lr=lr, weight_decay=decay
+        )
+        self.criterion = torch.nn.CrossEntropyLoss()
homa/graph/__init__.py ADDED
@@ -0,0 +1 @@
+from .GraphAttention import GraphAttention
homa/graph/modules/GraphAttentionHeadModule.py ADDED
@@ -0,0 +1,37 @@
+import torch
+
+
+class GraphAttentionHeadModule(torch.nn.Module):
+    def __init__(self, in_features: int, out_features: int, alpha=0.2):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.alpha = alpha
+
+        self.W = torch.nn.Linear(in_features, out_features, bias=False)
+        self.a_1 = torch.nn.Parameter(torch.randn(out_features, 1))
+        self.a_2 = torch.nn.Parameter(torch.randn(out_features, 1))
+
+        self.leaky_relu = torch.nn.LeakyReLU(self.alpha)
+        self.elu = torch.nn.ELU()
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        torch.nn.init.xavier_uniform_(self.W.weight, gain=1.414)
+        torch.nn.init.xavier_uniform_(self.a_1, gain=1.414)
+        torch.nn.init.xavier_uniform_(self.a_2, gain=1.414)
+
+    def forward(self, node_features, adj_matrix):
+        N = node_features.size(0)
+        h_prime = self.W(node_features)
+        s1 = torch.matmul(h_prime, self.a_1)
+        s2 = torch.matmul(h_prime, self.a_2)
+        e = s1 + s2.T
+        e = self.leaky_relu(e)
+        zero_vec = -9e15 * torch.ones_like(e)
+        attention_mask = torch.where(
+            adj_matrix > 0, e, zero_vec.to(node_features.device)
+        )
+        attention_weights = torch.nn.functional.softmax(attention_mask, dim=1)
+        h_new = torch.matmul(attention_weights, h_prime)
+        return self.elu(h_new)
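The broadcast s1 + s2.T is the standard GAT factorization: with s1 = H'a_1 and s2 = H'a_2 both of shape (N, 1), the sum yields the full (N, N) score matrix e[i][j] = a_1·h'_i + a_2·h'_j, equivalent to a single concatenated attention vector a = [a_1; a_2] over [h'_i ‖ h'_j]. A quick shape check with toy dimensions:

import torch

h_prime = torch.randn(5, 8)                    # N=5 nodes, out_features=8
a_1, a_2 = torch.randn(8, 1), torch.randn(8, 1)
s1, s2 = h_prime @ a_1, h_prime @ a_2          # each of shape (5, 1)
e = s1 + s2.T                                  # (5, 5) pairwise scores
assert torch.allclose(e[1, 3], (h_prime[1] @ a_1 + h_prime[3] @ a_2).squeeze())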
homa/graph/modules/MultiHeadGraphAttentionModule.py ADDED
@@ -0,0 +1,22 @@
+import torch
+from .GraphAttentionHeadModule import GraphAttentionHeadModule
+
+
+class MultiHeadGraphAttentionModule(torch.nn.Module):
+    def __init__(self, num_heads: int, in_features: int, out_features: int, alpha=0.2):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_out_features = out_features
+        self.heads = torch.nn.ModuleList(
+            [
+                GraphAttentionHeadModule(in_features, out_features, alpha=alpha)
+                for _ in range(num_heads)
+            ]
+        )
+
+    def forward(
+        self, node_features: torch.Tensor, adj_matrix: torch.Tensor
+    ) -> torch.Tensor:
+        outputs = [head(node_features, adj_matrix) for head in self.heads]
+        h_new_concat = torch.cat(outputs, dim=1)
+        return h_new_concat
homa/graph/modules/__init__.py ADDED
@@ -0,0 +1,2 @@
+from .GraphAttentionHeadModule import GraphAttentionHeadModule
+from .MultiHeadGraphAttentionModule import MultiHeadGraphAttentionModule
homa/loss/Loss.py CHANGED
@@ -1,2 +1,5 @@
-class Loss:
+import torch
+
+
+class Loss(torch.nn.Module):
     pass
homa/rl/DQN.py ADDED
@@ -0,0 +1,2 @@
+class DQN:
+    pass
homa/rl/DRQN.py ADDED
@@ -0,0 +1,5 @@
+from .DQN import DQN
+
+
+class DRQN(DQN):
+    pass
homa/rl/DiversityIsAllYouNeed.py ADDED
@@ -0,0 +1,97 @@
+import torch
+from .diayn.Actor import Actor
+from .diayn.Critic import Critic
+from .diayn.Discriminator import Discriminator
+from .buffers import DiversityIsAllYouNeedBuffer, Buffer
+
+
+class DiversityIsAllYouNeed:
+    def __init__(
+        self,
+        state_dimension: int,
+        action_dimension: int,
+        hidden_dimension: int = 256,
+        num_skills: int = 10,
+        critic_decay: float = 0.0,
+        actor_decay: float = 0.0,
+        discriminator_decay: float = 0.0,
+        actor_lr: float = 0.0001,
+        critic_lr: float = 0.001,
+        discriminator_lr: float = 0.001,
+        buffer_capacity: int = 1_000_000,
+        actor_epsilon: float = 1e-6,
+        gamma: float = 0.99,
+        min_std: float = -20.0,
+        max_std: float = 2.0,
+    ):
+        self.buffer: Buffer = DiversityIsAllYouNeedBuffer(capacity=buffer_capacity)
+        self.num_skills: int = num_skills
+        self.gamma: float = gamma
+        self.actor = Actor(
+            state_dimension=state_dimension,
+            action_dimension=action_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+            lr=actor_lr,
+            decay=actor_decay,
+            epsilon=actor_epsilon,
+            min_std=min_std,
+            max_std=max_std,
+        )
+        self.critic = Critic(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+            lr=critic_lr,
+            decay=critic_decay,
+            gamma=gamma,
+        )
+        self.discriminator = Discriminator(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+            lr=discriminator_lr,
+            decay=discriminator_decay,
+        )
+
+    def one_hot(self, indices, max_index) -> torch.Tensor:
+        one_hot = torch.zeros(indices.size(0), max_index)
+        one_hot.scatter_(1, indices.unsqueeze(1), 1)
+        return one_hot
+
+    def skill_index(self) -> torch.Tensor:
+        return torch.randint(0, self.num_skills, (1,))
+
+    def skill(self) -> torch.Tensor:
+        return self.one_hot(self.skill_index(), self.num_skills)
+
+    def advantages(
+        self,
+        states: torch.Tensor,
+        skills: torch.Tensor,
+        rewards: torch.Tensor,
+        terminations: torch.Tensor,
+        next_states: torch.Tensor,
+    ) -> torch.Tensor:
+        values = self.critic.values(states=states, skills=skills)
+        termination_mask = 1 - terminations
+        next_values = self.critic.values_(states=next_states, skills=skills)
+        update = self.gamma * next_values * termination_mask
+        return rewards + update - values
+
+    def train(self, skill: torch.Tensor):
+        data = self.buffer.all_tensor()
+        skill_indices = skill.repeat(data.states.size(0), 1).long()
+        skills_indices_one_hot = self.one_hot(skill_indices, self.num_skills)
+        self.discriminator.train(
+            states=data.states, skills_indices=skills_indices_one_hot
+        )
+        advantages = self.advantages(
+            states=data.states,
+            rewards=data.rewards,
+            terminations=data.terminations,
+            next_states=data.next_states,
+            skills=skills_indices_one_hot,
+        )
+        self.critic.train(advantages=advantages)
+        self.actor.train(advantages=advantages)
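A hedged sketch of the intended loop (environment interaction is assumed; the buffer's record signature appears in DiversityIsAllYouNeedBuffer later in this diff):

agent = DiversityIsAllYouNeed(state_dimension=8, action_dimension=2)
skill = agent.skill()  # uniformly sampled one-hot skill, shape (1, num_skills)
# ... roll out an episode conditioned on skill, calling
# agent.buffer.record(state, action, reward, next_state, termination, probability)
agent.train(skill=skill)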
homa/rl/SoftActorCritic.py ADDED
@@ -0,0 +1,67 @@
+from .sac import SoftActor, SoftCritic
+from .buffers import SoftActorCriticBuffer
+from ..core.concerns import TracksTime
+
+
+class SoftActorCritic(TracksTime):
+    def __init__(
+        self,
+        state_dimension: int,
+        action_dimension: int,
+        hidden_dimension: int = 256,
+        buffer_capacity: int = 100_000,
+        batch_size: int = 256,
+        actor_lr: float = 0.0002,
+        critic_lr: float = 0.0003,
+        actor_decay: float = 0.0,
+        critic_decay: float = 0.0,
+        tau: float = 0.005,
+        alpha: float = 0.2,
+        gamma: float = 0.99,
+        min_std: float = -20,
+        max_std: float = 2,
+        warmup: int = 20_000,
+    ):
+        super().__init__()
+
+        self.batch_size: int = batch_size
+        self.warmup: int = warmup
+        self.tau: float = tau
+
+        self.actor = SoftActor(
+            state_dimension=state_dimension,
+            action_dimension=action_dimension,
+            hidden_dimension=hidden_dimension,
+            lr=actor_lr,
+            weight_decay=actor_decay,
+            alpha=alpha,
+            min_std=min_std,
+            max_std=max_std,
+        )
+        self.critic = SoftCritic(
+            state_dimension=state_dimension,
+            action_dimension=action_dimension,
+            hidden_dimension=hidden_dimension,
+            lr=critic_lr,
+            weight_decay=critic_decay,
+            gamma=gamma,
+            alpha=alpha,
+        )
+        self.buffer = SoftActorCriticBuffer(capacity=buffer_capacity)
+
+    def train(self):
+        # don't train before warmup
+        if self.buffer.size < self.warmup:
+            return
+
+        data = self.buffer.sample_torch(self.batch_size)
+        self.critic.train(
+            states=data.states,
+            actions=data.actions,
+            rewards=data.rewards,
+            terminations=data.terminations,
+            next_states=data.next_states,
+            actor=self.actor,
+        )
+        self.actor.train(states=data.states, critic=self.critic)
+        self.critic.update(tau=self.tau)
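train is a no-op until the buffer holds warmup transitions; after that, each call runs one critic step, one actor step, and a Polyak update of the target critics with rate tau. Usage sketch (environment stepping and the buffer's record call are assumed, since SoftActorCriticBuffer is not shown above):

agent = SoftActorCritic(state_dimension=8, action_dimension=2, warmup=1_000)
for step in range(100_000):
    # ... act in the environment and record the transition into agent.buffer
    agent.train()  # silently returns until agent.buffer.size >= warmup
    agent.tick()   # from TracksTime: increments agent.t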
homa/rl/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .DiversityIsAllYouNeed import DiversityIsAllYouNeed
+from .SoftActorCritic import SoftActorCritic
+from .DQN import DQN
+from .DRQN import DRQN
homa/rl/buffers/Buffer.py ADDED
@@ -0,0 +1,13 @@
+from collections import deque
+from typing import Type
+from .concerns import ResetsCollection, HasRecordAlternatives
+
+
+class Buffer(ResetsCollection, HasRecordAlternatives):
+    def __init__(self, capacity: int):
+        self.capacity: int = capacity
+        self.collection: Type[deque] = deque(maxlen=self.capacity)
+
+    @property
+    def size(self):
+        return len(self.collection)
homa/rl/buffers/DiversityIsAllYouNeedBuffer.py ADDED
@@ -0,0 +1,50 @@
+import torch
+import numpy
+from types import SimpleNamespace
+from .Buffer import Buffer
+from .concerns import HasRecordAlternatives
+
+
+class DiversityIsAllYouNeedBuffer(Buffer, HasRecordAlternatives):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def all_tensor(self) -> SimpleNamespace:
+        return self.all(tensor=True)
+
+    def all(self, tensor: bool = False) -> SimpleNamespace:
+        states, actions, rewards, next_states, terminations, probabilities = zip(
+            *self.collection
+        )
+
+        if tensor:
+            states = torch.from_numpy(numpy.array(states))
+            actions = torch.from_numpy(numpy.array(actions))
+            rewards = torch.from_numpy(numpy.array(rewards))
+            next_states = torch.from_numpy(numpy.array(next_states))
+            terminations = torch.from_numpy(numpy.array(terminations))
+            probabilities = torch.from_numpy(numpy.array(probabilities))
+
+        return SimpleNamespace(
+            **{
+                "states": states,
+                "actions": actions,
+                "rewards": rewards,
+                "next_states": next_states,
+                "terminations": terminations,
+                "probabilities": probabilities,
+            }
+        )
+
+    def record(
+        self,
+        state: numpy.ndarray,
+        action: int,
+        reward: float,
+        next_state: numpy.ndarray,
+        termination: bool,
+        probability: numpy.ndarray,
+    ) -> None:
+        self.collection.append(
+            (state, action, reward, next_state, termination, probability)
+        )
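Round-trip sketch for the buffer (shapes illustrative): record appends a plain tuple, and all_tensor stacks each field through numpy before converting to torch:

import numpy

buffer = DiversityIsAllYouNeedBuffer(capacity=1_000)
buffer.record(
    state=numpy.zeros(8),
    action=1,
    reward=0.5,
    next_state=numpy.ones(8),
    termination=False,
    probability=numpy.array([0.25, 0.75]),
)
data = buffer.all_tensor()
print(data.states.shape)  # torch.Size([1, 8])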