homa 0.3.11__py3-none-any.whl → 0.3.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
homa/rl/buffers/Buffer.py CHANGED
@@ -1,10 +1,12 @@
+from collections import deque
+from typing import Type
 from .concerns import ResetsCollection, HasRecordAlternatives


 class Buffer(ResetsCollection, HasRecordAlternatives):
     def __init__(self, capacity: int):
         self.capacity: int = capacity
-        self.reset()
+        self.collection: Type[deque] = deque(maxlen=self.capacity)

     @property
     def size(self):
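Note on the hunk above: the constructor no longer relies on a `reset()` call to build storage; a `collections.deque` with `maxlen=capacity` evicts the oldest transition automatically once full. (As an aside, the `Type[deque]` annotation describes the class object rather than the instance actually stored; a plain `deque` hint would be more accurate.) A minimal sketch of the eviction behavior, with illustrative values:

    from collections import deque

    buffer = deque(maxlen=3)      # like deque(maxlen=self.capacity)
    for step in range(5):
        buffer.append(step)       # appending past maxlen drops the oldest item
    print(list(buffer))           # [2, 3, 4] -- items 0 and 1 were evicted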
homa/rl/buffers/SoftActorCriticBuffer.py CHANGED
@@ -35,7 +35,7 @@ class SoftActorCriticBuffer(Buffer):

         if as_tensor:
             states = torch.from_numpy(states).float()
-            actions = torch.from_numpy(actions).long()
+            actions = torch.from_numpy(actions).float()
             rewards = torch.from_numpy(rewards).float()
             next_states = torch.from_numpy(next_states).float()
             terminations = torch.from_numpy(terminations).float()
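The dtype fix above reflects SAC's continuous action space: actions are real-valued vectors from a tanh-squashed Gaussian, so casting the replay batch with `.long()` truncates them to integers. A quick illustration with hypothetical values:

    import numpy as np
    import torch

    actions = np.array([[0.42, -0.73]])    # continuous actions in [-1, 1]
    torch.from_numpy(actions).long()       # tensor([[0, 0]]) -- truncated
    torch.from_numpy(actions).float()      # tensor([[0.4200, -0.7300]]) -- preserved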
homa/rl/sac/SoftActor.py CHANGED
@@ -29,6 +29,7 @@ class SoftActor:
         )

     def train(self, states: torch.Tensor, critic_network: torch.nn.Module):
+        self.network.train()
         self.optimizer.zero_grad()
         loss = self.loss(states=states, critic_network=critic_network)
         loss.backward()
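The added `self.network.train()` ensures mode-dependent layers (dropout, batch normalization, if the module contains any) are in training mode during the update, which matters if the network was switched to `.eval()` for action selection elsewhere. A common pattern, sketched with hypothetical names:

    policy.eval()                  # inference mode while collecting experience
    with torch.no_grad():
        action = policy(state)

    policy.train()                 # training mode before the gradient step
    optimizer.zero_grad()
    loss = compute_loss(batch)     # hypothetical loss helper
    loss.backward()
    optimizer.step()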
@@ -64,6 +65,6 @@ class SoftActor:
         action = torch.tanh(pre_tanh)

         probabilities = distribution.log_prob(pre_tanh).sum(dim=1, keepdim=True)
-        correction = torch.log(1 - action.pow(2) + 1e-6).sum(dim=1, keepdim=True)
+        probabilities -= torch.log(1 - action.pow(2) + 1e-6).sum(dim=1, keepdim=True)

-        return action, probabilities - correction
+        return action, probabilities
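Both versions above compute the same quantity; the rewrite folds the tanh correction into `probabilities` in place instead of subtracting a separate `correction` tensor at return. The underlying change-of-variables identity for a tanh-squashed Gaussian (as in the SAC paper's appendix), where the `1e-6` in the code is a numerical-stability epsilon inside the log:

    \log \pi(a \mid s) = \log \mu(u \mid s) - \sum_{i=1}^{D} \log\left(1 - \tanh^{2}(u_i)\right), \qquad a = \tanh(u)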
homa/rl/sac/SoftCritic.py CHANGED
@@ -1,8 +1,8 @@
 import torch
 from torch.nn.functional import mse_loss as mse
-from typing import Type
 from .modules import DualSoftCriticModule
 from .SoftActor import SoftActor
+from ..utils import soft_update


 class SoftCritic:
@@ -31,6 +31,10 @@ class SoftCritic:
             hidden_dimension=hidden_dimension,
             action_dimension=action_dimension,
         )
+
+        # copy source to target when initiated
+        self.target.load_state_dict(self.network.state_dict())
+
         self.optimizer = torch.optim.AdamW(
             self.network.parameters(), lr=lr, weight_decay=weight_decay
         )
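Hard-copying the online weights into the target at construction makes the two critics start identical, so early TD targets come from the same function being trained rather than from an unrelated random initialization. An equivalent pattern often seen elsewhere (a sketch, not the package's code):

    import copy

    target = copy.deepcopy(network)      # same architecture and weights in one step
    for p in target.parameters():
        p.requires_grad_(False)          # target is updated by Polyak averaging, not SGD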
@@ -42,8 +46,9 @@ class SoftCritic:
         rewards: torch.Tensor,
         terminations: torch.Tensor,
         next_states: torch.Tensor,
-        actor: torch.nn.Module,
+        actor: SoftActor,
     ):
+        self.network.train()
         self.optimizer.zero_grad()
         loss = self.loss(
             states=states,
@@ -65,7 +70,7 @@ class SoftCritic:
         next_states: torch.Tensor,
         actor: torch.nn.Module,
     ):
-        q_alpha, q_beta = self.target(states, actions)
+        q_alpha, q_beta = self.network(states, actions)
         target = self.calculate_target(
             rewards=rewards,
             terminations=terminations,
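Evaluating the online `self.network` here rather than `self.target` is the substantive fix: the optimizer updates the online parameters, and a loss computed from the frozen target would carry no useful gradient. What this assembles is presumably the standard clipped double-Q regression, one MSE term per critic head:

    L(\theta_j) = \mathbb{E}\big[\big(Q_{\theta_j}(s, a) - y\big)^{2}\big], \qquad j \in \{\alpha, \beta\}

with the target y produced by calculate_target in the next hunk.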
@@ -82,19 +87,13 @@ class SoftCritic:
         next_states: torch.Tensor,
         actor: SoftActor,
     ):
+        termination_mask = 1 - terminations
         next_actions, next_probabilities = actor.sample(next_states)
         q_alpha, q_beta = self.target(next_states, next_actions)
         q = torch.min(q_alpha, q_beta)
-        termination_mask = 1 - terminations
-        entropy_q = q - self.alpha * next_probabilities * termination_mask
-        return rewards + self.gamma * entropy_q
-
-    def soft_update(
-        self, network: Type[torch.nn.Module], target: Type[torch.nn.Module]
-    ):
-        for s, t in zip(network.parameters(), target.parameters()):
-            t.data.copy_(self.tau * s.data + (1 - self.tau) * t.data)
+        entropy_q = q - self.alpha * next_probabilities
+        return rewards + self.gamma * termination_mask * entropy_q

     def update(self):
-        self.soft_update(self.network.alpha, self.target.alpha)
-        self.soft_update(self.network.beta, self.target.beta)
+        soft_update(network=self.network.alpha, target=self.target.alpha, tau=self.tau)
+        soft_update(network=self.network.beta, target=self.target.beta, tau=self.tau)
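The reordering in calculate_target changes semantics, not just style: previously the termination mask zeroed only the entropy term, so the bootstrapped min-Q still leaked across episode boundaries; now the mask gates the entire bootstrap. The result matches the standard SAC target:

    y = r + \gamma (1 - d)\left(\min\big(Q^{\mathrm{tgt}}_{\alpha}, Q^{\mathrm{tgt}}_{\beta}\big)(s', a') - \alpha \log \pi(a' \mid s')\right)

where d is the termination flag and a' is sampled from the current policy.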
homa/rl/sac/modules/SoftActorModule.py CHANGED
@@ -30,6 +30,6 @@ class SoftActorModule(torch.nn.Module):
     def forward(self, state: torch.Tensor):
         features = self.phi(state)
         mean = self.mu(features)
-        std = self.mu(features)
+        std = self.xi(features)
         std = std.clamp(self.min_std, self.max_std)
         return mean, std
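The fix above addresses an apparent copy-paste bug: both heads went through `self.mu`, making the returned std a second copy of the mean. With separate heads, a minimal Gaussian policy module looks roughly like this (an illustrative sketch, not homa's SoftActorModule; dimensions and clamp bounds hypothetical):

    import torch

    class GaussianHead(torch.nn.Module):
        def __init__(self, feature_dim: int, action_dim: int):
            super().__init__()
            self.mu = torch.nn.Linear(feature_dim, action_dim)   # mean head
            self.xi = torch.nn.Linear(feature_dim, action_dim)   # std head, separate weights

        def forward(self, features: torch.Tensor):
            mean = self.mu(features)
            std = self.xi(features).clamp(1e-3, 1.0)             # keep std positive and bounded
            return mean, std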
homa/rl/utils.py ADDED
@@ -0,0 +1,7 @@
+import torch
+
+
+@torch.no_grad()
+def soft_update(network: torch.nn.Module, target: torch.nn.Module, tau: float):
+    for s, t in zip(network.parameters(), target.parameters()):
+        t.data.copy_(tau * s.data + (1 - tau) * t.data)
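The new helper is Polyak averaging, target parameters moving as t ← τ·s + (1 − τ)·t, extracted from SoftCritic so any pair of modules can share it; the `@torch.no_grad()` decorator keeps the in-place copies out of autograd. Usage sketch with stand-in modules:

    import torch
    from homa.rl.utils import soft_update

    online = torch.nn.Linear(4, 2)
    target = torch.nn.Linear(4, 2)
    target.load_state_dict(online.state_dict())             # start identical

    # after each gradient step on `online`:
    soft_update(network=online, target=target, tau=0.005)   # target drifts slowly toward online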
homa-0.3.11.dist-info/METADATA → homa-0.3.15.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: homa
-Version: 0.3.11
+Version: 0.3.15
 Summary: A curated list of machine learning and deep learning helpers.
 Author-email: Taha Shieenavaz <tahashieenavaz@gmail.com>
 Requires-Python: >=3.7
homa-0.3.11.dist-info/RECORD → homa-0.3.15.dist-info/RECORD RENAMED
@@ -101,10 +101,11 @@ homa/rl/DRQN.py,sha256=zooojji9aeeubOP7cRPSHg31u2Assxk-qjXyGUWIO3A,49
 homa/rl/DiversityIsAllYouNeed.py,sha256=8yKzlVdLisForGyXqxaXUAWG_dozq7dNY8MBasCvniE,3322
 homa/rl/SoftActorCritic.py,sha256=N8EsiYbsLH-dpT2EmqdYFG9KvHNfO3JX8SG2LPTy94s,1962
 homa/rl/__init__.py,sha256=EaNDkIzLH1Oy0Wc0aAyyVs4HVMcZS1tdHDh631LKSXs,146
-homa/rl/buffers/Buffer.py,sha256=wOk8MH0Wf0cpvavpHIK2O7PrbGP6MwHTH5YFkq2Ints,288
+homa/rl/utils.py,sha256=IqbN5aDLwovocpPbxgywuetjz7GQwh9aJ4WFIOtLP3g,232
+homa/rl/buffers/Buffer.py,sha256=YCESh9tFxgWOLzGQj_IA0zLJoZWDmz6gCNu1iYsGp1s,388
 homa/rl/buffers/DiversityIsAllYouNeedBuffer.py,sha256=Nwcqs3Q10x6OKZ-zWug4IcBc6RR1TwEIybuFQOtmftA,1612
 homa/rl/buffers/ImageBuffer.py,sha256=HSmMt82hmkL3ooBYo7c6YUtTsMz9TAA8CvPh3y8z3yg,65
-homa/rl/buffers/SoftActorCriticBuffer.py,sha256=iDC2C5XFvONT3f7YX_gYXQJGU9wz2usvPOVGbQUd22M,1796
+homa/rl/buffers/SoftActorCriticBuffer.py,sha256=JQ9Y6KeeQS5naO_JPONiks-HYXw7hiZZAbqpoWDZlNI,1797
 homa/rl/buffers/__init__.py,sha256=h1AkCHs6isXbNtxpaZfLp6YudHj1KlnOvURE64vhRa4,190
 homa/rl/buffers/concerns/HasRecordAlternatives.py,sha256=D5aVlPZlnGm0GyGtikKb4wZqyO6zpyqR1IOETmAgLx4,362
 homa/rl/buffers/concerns/ResetsCollection.py,sha256=bZ8q4czYXo1jMtVCnnlG69OgiJ0AqSGY6CiKzJC6xtQ,215
@@ -117,11 +118,11 @@ homa/rl/diayn/modules/ContinuousActorModule.py,sha256=yeC117I5gkXZSidQhjwakjiY7G
 homa/rl/diayn/modules/CriticModule.py,sha256=OUenwCG0dG4PnK7Iq-jy7oCTv_Cn9s7bXRpro6Pvb40,956
 homa/rl/diayn/modules/DiscriminatorModule.py,sha256=D58dKBv4f6gtrpqMKLK8XAZpiMqKfS4sG6s3QcF8iGE,891
 homa/rl/diayn/modules/__init__.py,sha256=1Pgjr4FT5WG-AMh26NPEfbf5pK6I02B1x8HYsgyUCJ4,149
-homa/rl/sac/SoftActor.py,sha256=CxR58IFrZ6xlmBj_gq_abZfgdzlVD71c6wA6wQiVL2c,2142
-homa/rl/sac/SoftCritic.py,sha256=wFIunTgKGBy64Igu7zuvE2BvGz2e-DTplviLyq4tQ7M,3031
+homa/rl/sac/SoftActor.py,sha256=NSTqnv_BZzTqfgEEOIEtOgYV2_VycicIF0alD1O5Nk8,2162
+homa/rl/sac/SoftCritic.py,sha256=rOgPR8zRUtjEwF9W4q5nZQaGXFmf_9tmXqaRWzUkAm8,2980
 homa/rl/sac/__init__.py,sha256=8EIkOcVvxN94gGzcZoX2XTnvTsHqW6yBaZ2RdFwIveM,68
 homa/rl/sac/modules/DualSoftCriticModule.py,sha256=Ax28i7U-KnP4QJig-AeeCfpPYNvTT3DfvRMJI-f-TGY,749
-homa/rl/sac/modules/SoftActorModule.py,sha256=AiWnsWkmQONjOAWAp06eO-lLWEYNJDmx8FSjPKTcjI0,1152
+homa/rl/sac/modules/SoftActorModule.py,sha256=LQ4z7s8mE3wwb1JgxPs0QvnriZULK3_ULdhkt60Ffpw,1152
 homa/rl/sac/modules/SoftCriticModule.py,sha256=aOfhDZTB5og-BLTsmdBdIcRufygCJUas7P-ikBvWQ34,928
 homa/rl/sac/modules/__init__.py,sha256=h-22B5CAK1xhn75tolI5J5sQMxl--kOXbQ6r_JfHIOA,147
 homa/vision/Classifier.py,sha256=bAypqREQVuPamnc8hpbLCwmW9Uly3T1rvrlbMxXp1eA,61
@@ -142,8 +143,8 @@ homa/vision/concerns/__init__.py,sha256=mrw1YvN-GpQPvMwDF00KxnFkksPKo23RWM4KRioU
 homa/vision/modules/ResnetModule.py,sha256=eFudBnILD6OmgQtcW_CQQ8aZ62NEa4HyZ15-lobTtt0,712
 homa/vision/modules/SwinModule.py,sha256=3ZtUcfyJt0NMGmIlGpN35MIJG9QsgcLdFniZH7NxZQo,1227
 homa/vision/modules/__init__.py,sha256=zVMYB9IAO_xZylC1-N3p8ymHgEkAE2sBbuVz8K5Y1kk,74
-homa-0.3.11.dist-info/METADATA,sha256=SvSxNXB1IsX3N5IfhOsnWYtvhjpfzauJPanVH7i5cRs,1760
-homa-0.3.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-homa-0.3.11.dist-info/entry_points.txt,sha256=tJZzjs-f2QvFe3ES8Qta8IE5sAbeE8-cyZ_UtbgqG4s,51
-homa-0.3.11.dist-info/top_level.txt,sha256=tmOfy2tuaAwc3W5-i6j61_vYJsXgR4ivBWkhJ3ZtJDc,5
-homa-0.3.11.dist-info/RECORD,,
+homa-0.3.15.dist-info/METADATA,sha256=jo-jsI9A6KvK95nGmxrh_IExb1WeNp_3DWTZgMDSxcI,1760
+homa-0.3.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+homa-0.3.15.dist-info/entry_points.txt,sha256=tJZzjs-f2QvFe3ES8Qta8IE5sAbeE8-cyZ_UtbgqG4s,51
+homa-0.3.15.dist-info/top_level.txt,sha256=tmOfy2tuaAwc3W5-i6j61_vYJsXgR4ivBWkhJ3ZtJDc,5
+homa-0.3.15.dist-info/RECORD,,