homa 0.3.13__py3-none-any.whl → 0.3.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
homa/rl/sac/SoftActor.py CHANGED
@@ -29,6 +29,7 @@ class SoftActor:
29
29
  )
30
30
 
31
31
  def train(self, states: torch.Tensor, critic_network: torch.nn.Module):
32
+ self.network.train()
32
33
  self.optimizer.zero_grad()
33
34
  loss = self.loss(states=states, critic_network=critic_network)
34
35
  loss.backward()
@@ -64,6 +65,6 @@ class SoftActor:
64
65
  action = torch.tanh(pre_tanh)
65
66
 
66
67
  probabilities = distribution.log_prob(pre_tanh).sum(dim=1, keepdim=True)
67
- correction = torch.log(1 - action.pow(2) + 1e-6).sum(dim=1, keepdim=True)
68
+ probabilities -= torch.log(1 - action.pow(2) + 1e-6).sum(dim=1, keepdim=True)
68
69
 
69
- return action, probabilities - correction
70
+ return action, probabilities
homa/rl/sac/SoftCritic.py CHANGED
@@ -48,6 +48,7 @@ class SoftCritic:
48
48
  next_states: torch.Tensor,
49
49
  actor: SoftActor,
50
50
  ):
51
+ self.network.train()
51
52
  self.optimizer.zero_grad()
52
53
  loss = self.loss(
53
54
  states=states,
@@ -94,5 +95,5 @@ class SoftCritic:
94
95
  return rewards + self.gamma * termination_mask * entropy_q
95
96
 
96
97
  def update(self):
97
- soft_update(network=self.network.alpha, target=self.target.alpha)
98
- soft_update(network=self.network.beta, target=self.target.beta)
98
+ soft_update(network=self.network.alpha, target=self.target.alpha, tau=self.tau)
99
+ soft_update(network=self.network.beta, target=self.target.beta, tau=self.tau)
homa/rl/utils.py CHANGED
@@ -2,6 +2,6 @@ import torch
2
2
 
3
3
 
4
4
  @torch.no_grad()
5
- def soft_update(self, network: torch.nn.Module, target: torch.nn.Module):
5
+ def soft_update(network: torch.nn.Module, target: torch.nn.Module, tau: float):
6
6
  for s, t in zip(network.parameters(), target.parameters()):
7
- t.data.copy_(self.tau * s.data + (1 - self.tau) * t.data)
7
+ t.data.copy_(tau * s.data + (1 - tau) * t.data)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: homa
3
- Version: 0.3.13
3
+ Version: 0.3.15
4
4
  Summary: A curated list of machine learning and deep learning helpers.
5
5
  Author-email: Taha Shieenavaz <tahashieenavaz@gmail.com>
6
6
  Requires-Python: >=3.7
@@ -101,7 +101,7 @@ homa/rl/DRQN.py,sha256=zooojji9aeeubOP7cRPSHg31u2Assxk-qjXyGUWIO3A,49
101
101
  homa/rl/DiversityIsAllYouNeed.py,sha256=8yKzlVdLisForGyXqxaXUAWG_dozq7dNY8MBasCvniE,3322
102
102
  homa/rl/SoftActorCritic.py,sha256=N8EsiYbsLH-dpT2EmqdYFG9KvHNfO3JX8SG2LPTy94s,1962
103
103
  homa/rl/__init__.py,sha256=EaNDkIzLH1Oy0Wc0aAyyVs4HVMcZS1tdHDh631LKSXs,146
104
- homa/rl/utils.py,sha256=ySNGwWFBgN3Phg7Xn99CIJVHRWw3AmU9KcY-GP4ZQBc,236
104
+ homa/rl/utils.py,sha256=IqbN5aDLwovocpPbxgywuetjz7GQwh9aJ4WFIOtLP3g,232
105
105
  homa/rl/buffers/Buffer.py,sha256=YCESh9tFxgWOLzGQj_IA0zLJoZWDmz6gCNu1iYsGp1s,388
106
106
  homa/rl/buffers/DiversityIsAllYouNeedBuffer.py,sha256=Nwcqs3Q10x6OKZ-zWug4IcBc6RR1TwEIybuFQOtmftA,1612
107
107
  homa/rl/buffers/ImageBuffer.py,sha256=HSmMt82hmkL3ooBYo7c6YUtTsMz9TAA8CvPh3y8z3yg,65
@@ -118,8 +118,8 @@ homa/rl/diayn/modules/ContinuousActorModule.py,sha256=yeC117I5gkXZSidQhjwakjiY7G
118
118
  homa/rl/diayn/modules/CriticModule.py,sha256=OUenwCG0dG4PnK7Iq-jy7oCTv_Cn9s7bXRpro6Pvb40,956
119
119
  homa/rl/diayn/modules/DiscriminatorModule.py,sha256=D58dKBv4f6gtrpqMKLK8XAZpiMqKfS4sG6s3QcF8iGE,891
120
120
  homa/rl/diayn/modules/__init__.py,sha256=1Pgjr4FT5WG-AMh26NPEfbf5pK6I02B1x8HYsgyUCJ4,149
121
- homa/rl/sac/SoftActor.py,sha256=CxR58IFrZ6xlmBj_gq_abZfgdzlVD71c6wA6wQiVL2c,2142
122
- homa/rl/sac/SoftCritic.py,sha256=EOX1vpH7YVwDuy-RdFgEpIKGioE7si0awoAYrHMTv4g,2923
121
+ homa/rl/sac/SoftActor.py,sha256=NSTqnv_BZzTqfgEEOIEtOgYV2_VycicIF0alD1O5Nk8,2162
122
+ homa/rl/sac/SoftCritic.py,sha256=rOgPR8zRUtjEwF9W4q5nZQaGXFmf_9tmXqaRWzUkAm8,2980
123
123
  homa/rl/sac/__init__.py,sha256=8EIkOcVvxN94gGzcZoX2XTnvTsHqW6yBaZ2RdFwIveM,68
124
124
  homa/rl/sac/modules/DualSoftCriticModule.py,sha256=Ax28i7U-KnP4QJig-AeeCfpPYNvTT3DfvRMJI-f-TGY,749
125
125
  homa/rl/sac/modules/SoftActorModule.py,sha256=LQ4z7s8mE3wwb1JgxPs0QvnriZULK3_ULdhkt60Ffpw,1152
@@ -143,8 +143,8 @@ homa/vision/concerns/__init__.py,sha256=mrw1YvN-GpQPvMwDF00KxnFkksPKo23RWM4KRioU
143
143
  homa/vision/modules/ResnetModule.py,sha256=eFudBnILD6OmgQtcW_CQQ8aZ62NEa4HyZ15-lobTtt0,712
144
144
  homa/vision/modules/SwinModule.py,sha256=3ZtUcfyJt0NMGmIlGpN35MIJG9QsgcLdFniZH7NxZQo,1227
145
145
  homa/vision/modules/__init__.py,sha256=zVMYB9IAO_xZylC1-N3p8ymHgEkAE2sBbuVz8K5Y1kk,74
146
- homa-0.3.13.dist-info/METADATA,sha256=1oENQJVWWN-hpkF440JOK0IiOI-i-Vbx9zcl7y5_pcs,1760
147
- homa-0.3.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
148
- homa-0.3.13.dist-info/entry_points.txt,sha256=tJZzjs-f2QvFe3ES8Qta8IE5sAbeE8-cyZ_UtbgqG4s,51
149
- homa-0.3.13.dist-info/top_level.txt,sha256=tmOfy2tuaAwc3W5-i6j61_vYJsXgR4ivBWkhJ3ZtJDc,5
150
- homa-0.3.13.dist-info/RECORD,,
146
+ homa-0.3.15.dist-info/METADATA,sha256=jo-jsI9A6KvK95nGmxrh_IExb1WeNp_3DWTZgMDSxcI,1760
147
+ homa-0.3.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
148
+ homa-0.3.15.dist-info/entry_points.txt,sha256=tJZzjs-f2QvFe3ES8Qta8IE5sAbeE8-cyZ_UtbgqG4s,51
149
+ homa-0.3.15.dist-info/top_level.txt,sha256=tmOfy2tuaAwc3W5-i6j61_vYJsXgR4ivBWkhJ3ZtJDc,5
150
+ homa-0.3.15.dist-info/RECORD,,
File without changes