evolutionary-policy-optimization 0.0.57__py3-none-any.whl → 0.0.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -127,15 +127,9 @@ def calc_generalized_advantage_estimate(
127
127
  delta = rewards + gamma * values_next * masks - values
128
128
  gates = gamma * lam * masks
129
129
 
130
- gates, delta = gates[..., :, None], delta[..., :, None]
131
-
132
130
  scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
133
131
 
134
- gae = scan(gates, delta)
135
-
136
- gae = gae[..., :, 0]
137
-
138
- return gae
132
+ return scan(gates, delta)
139
133
 
140
134
  # evolution related functions
141
135
 
@@ -856,10 +850,12 @@ class Agent(Module):
856
850
  dones
857
851
  ) = memories
858
852
 
853
+ masks = 1. - dones.float()
854
+
859
855
  advantages = self.calc_gae(
860
856
  rewards[:-1],
861
857
  values,
862
- dones[:-1],
858
+ masks[:-1],
863
859
  )
864
860
 
865
861
  valid_episode = episode_ids >= 0
@@ -956,7 +952,10 @@ def actor_loss(
956
952
  advantages, # Float[b]
957
953
  eps_clip = 0.2,
958
954
  entropy_weight = .01,
955
+ eps = 1e-5
959
956
  ):
957
+ batch = logits.shape[0]
958
+
960
959
  log_probs = gather_log_prob(logits, actions)
961
960
 
962
961
  ratio = (log_probs - old_log_probs).exp()
@@ -965,6 +964,8 @@ def actor_loss(
965
964
 
966
965
  clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
967
966
 
967
+ advantages = F.layer_norm(advantages, (batch,), eps = eps)
968
+
968
969
  actor_loss = -torch.min(clipped_ratio * advantages, ratio * advantages)
969
970
 
970
971
  # add entropy loss for exploration
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: evolutionary-policy-optimization
3
- Version: 0.0.57
3
+ Version: 0.0.60
4
4
  Summary: EPO - Pytorch
5
5
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
6
6
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -35,7 +35,7 @@ Classifier: Programming Language :: Python :: 3.8
35
35
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.9
37
37
  Requires-Dist: adam-atan2-pytorch
38
- Requires-Dist: assoc-scan
38
+ Requires-Dist: assoc-scan>=0.0.2
39
39
  Requires-Dist: einops>=0.8.1
40
40
  Requires-Dist: einx>=0.3.0
41
41
  Requires-Dist: ema-pytorch>=0.7.7
@@ -1,9 +1,9 @@
1
1
  evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
2
2
  evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
3
- evolutionary_policy_optimization/epo.py,sha256=qPj5kRsISY1I6WjCc-ejpuiwOSxtPsSdMABmchXJ3s0,35252
3
+ evolutionary_policy_optimization/epo.py,sha256=pD4j_oP7Cg8vSVQE34oJMcXcYN4oOi2_TtOAI5YbZDQ,35298
4
4
  evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
5
5
  evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
6
- evolutionary_policy_optimization-0.0.57.dist-info/METADATA,sha256=WBHRK98s_lzWbqG4ouq620ayykPF9SHUz3HdvsRUywc,6213
7
- evolutionary_policy_optimization-0.0.57.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- evolutionary_policy_optimization-0.0.57.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- evolutionary_policy_optimization-0.0.57.dist-info/RECORD,,
6
+ evolutionary_policy_optimization-0.0.60.dist-info/METADATA,sha256=vfiTPTi00-ZPqXgeVn4yEzCcD8p2k61aDmp5YX99Uww,6220
7
+ evolutionary_policy_optimization-0.0.60.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ evolutionary_policy_optimization-0.0.60.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ evolutionary_policy_optimization-0.0.60.dist-info/RECORD,,