evolutionary-policy-optimization 0.1.2-py3-none-any.whl → 0.1.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- evolutionary_policy_optimization/epo.py
+++ evolutionary_policy_optimization/epo.py
@@ -79,7 +79,7 @@ def interface_torch_numpy(fn, device):
     @wraps(fn)
     def decorated_fn(*args, **kwargs):
 
-        args, kwargs = tree_map(lambda t: t.cpu().numpy() if isinstance(t, Tensor) else t, (args, kwargs))
+        args, kwargs = tree_map(lambda t: t.cpu().numpy() if is_tensor(t) else t, (args, kwargs))
 
         out = fn(*args, **kwargs)
 
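Note that torch.is_tensor(t) is documented as equivalent to isinstance(t, Tensor), so the change above is a consistency cleanup rather than a behavioral one. For context, interface_torch_numpy wraps a NumPy-facing function so it can be called with torch tensors: tensor inputs are converted to NumPy arrays before the call, and presumably array outputs are converted back to tensors on the given device (only the input conversion is visible in this hunk). A minimal self-contained sketch of that pattern, assuming torch's internal pytree helper torch.utils._pytree.tree_map and a hypothetical numpy_add function, neither taken verbatim from the package:

from functools import wraps

import numpy as np
import torch
from torch import is_tensor, from_numpy
from torch.utils._pytree import tree_map

def interface_torch_numpy(fn, device):
    # tensors in -> numpy arrays for fn; numpy arrays out -> tensors on `device`
    @wraps(fn)
    def decorated_fn(*args, **kwargs):
        args, kwargs = tree_map(lambda t: t.cpu().numpy() if is_tensor(t) else t, (args, kwargs))
        out = fn(*args, **kwargs)
        return tree_map(lambda t: from_numpy(t).to(device) if isinstance(t, np.ndarray) else t, out)
    return decorated_fn

# hypothetical numpy-only function, purely for illustration
def numpy_add(a, b):
    return a + b

add = interface_torch_numpy(numpy_add, device = 'cpu')
print(add(torch.ones(3), torch.ones(3)))  # tensor([2., 2., 2.])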
@@ -88,6 +88,16 @@ def interface_torch_numpy(fn, device):
 
     return decorated_fn
 
+def move_input_tensors_to_device(fn):
+
+    @wraps(fn)
+    def decorated_fn(self, *args, **kwargs):
+        args, kwargs = tree_map(lambda t: t.to(self.device) if is_tensor(t) else t, (args, kwargs))
+
+        return fn(self, *args, **kwargs)
+
+    return decorated_fn
+
 # tensor helpers
 
 def l2norm(t):
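The newly added move_input_tensors_to_device is a method decorator: it assumes the wrapped function is a method of an object exposing a device attribute, and moves every tensor argument onto that device before dispatching. A minimal sketch of how such a decorator is used, with a hypothetical Toy module standing in for the package's Agent (the device property mirrors the dummy-buffer tracking shown further down):

from functools import wraps

import torch
from torch import nn, tensor, is_tensor
from torch.utils._pytree import tree_map

def move_input_tensors_to_device(fn):
    # move all tensor args/kwargs onto the module's device before calling the method
    @wraps(fn)
    def decorated_fn(self, *args, **kwargs):
        args, kwargs = tree_map(lambda t: t.to(self.device) if is_tensor(t) else t, (args, kwargs))
        return fn(self, *args, **kwargs)
    return decorated_fn

class Toy(nn.Module):  # hypothetical stand-in for Agent, illustration only
    def __init__(self):
        super().__init__()
        self.register_buffer('dummy', tensor(0))

    @property
    def device(self):
        return self.dummy.device

    @move_input_tensors_to_device
    def score(self, state):
        return state.sum()  # `state` has already been moved to self.device here

model = Toy()
print(model.score(torch.randn(4)))  # a cpu tensor is fine even if the module later moves to gpu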
@@ -797,6 +807,8 @@ class Agent(Module):
 
         self.unwrap_model = identity if not wrap_with_accelerate else self.accelerate.unwrap_model
 
+        dummy = tensor(0)
+
         if wrap_with_accelerate:
             (
                 self.actor,
@@ -816,11 +828,14 @@ class Agent(Module):
                 )
             )
 
-        # device tracking
+            if exists(self.critic_ema):
+                self.critic_ema.to(self.accelerate.device)
 
-        self.register_buffer('dummy', tensor(0, device = self.accelerate.device))
+            dummy = dummy.to(self.accelerate.device)
 
-        self.critic_ema.to(self.accelerate.device)
+        # device tracking
+
+        self.register_buffer('dummy', dummy)
 
     @property
     def device(self):
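Two things change in the constructor above: the dummy tensor is now created on the CPU first and only moved to self.accelerate.device inside the accelerate branch, and the critic EMA move is guarded by exists(self.critic_ema). The registered dummy buffer itself is a common trick for device tracking: buffers follow the module through .to() / .cuda() (and typically through accelerate's device placement), so a device property that reads the buffer's device keeps reporting wherever the module currently lives. A small sketch of just that behavior, independent of the Agent class:

import torch
from torch import nn, tensor

class DeviceTracked(nn.Module):
    def __init__(self):
        super().__init__()
        # a scalar buffer whose device always mirrors the module's
        self.register_buffer('dummy', tensor(0))

    @property
    def device(self):
        return self.dummy.device

m = DeviceTracked()
print(m.device)          # cpu
if torch.cuda.is_available():
    m = m.cuda()
    print(m.device)      # cuda:0 - the buffer moved with the module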
@@ -870,6 +885,7 @@ class Agent(Module):
         if exists(pkg.get('latent_optim', None)):
             self.latent_optim.load_state_dict(pkg['latent_optim'])
 
+    @move_input_tensors_to_device
     def get_actor_actions(
         self,
         state,
@@ -884,6 +900,7 @@ class Agent(Module):
         if not exists(latent) and exists(latent_id):
             latent = maybe_unwrap(self.latent_gene_pool)(latent_id = latent_id)
 
+        print(self.device, state.device, next(self.actor.parameters()).device)
         logits = maybe_unwrap(self.actor)(state, latent)
 
         if not sample:
@@ -895,6 +912,7 @@ class Agent(Module):
 
         return actions, log_probs
 
+    @move_input_tensors_to_device
     def get_critic_values(
         self,
         state,
@@ -903,6 +921,7 @@ class Agent(Module):
         use_ema_if_available = False,
         use_unwrapped_model = False
     ):
+
         maybe_unwrap = identity if not use_unwrapped_model else self.unwrap_model
 
         if not exists(latent) and exists(latent_id):
--- evolutionary_policy_optimization-0.1.2.dist-info/METADATA
+++ evolutionary_policy_optimization-0.1.4.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.2
+Version: 0.1.4
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -53,7 +53,7 @@ Description-Content-Type: text/markdown
 
 <img width="450px" alt="fig1" src="https://github.com/user-attachments/assets/33bef569-e786-4f09-bdee-56bad7ea9e6d" />
 
-## Evolutionary Policy Optimization (wip)
+## Evolutionary Policy Optimization
 
 Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from Wang et al. of the Robotics Institute at Carnegie Mellon University
 
--- evolutionary_policy_optimization-0.1.2.dist-info/RECORD
+++ evolutionary_policy_optimization-0.1.4.dist-info/RECORD
@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=VgiPvaG-Lib1JZRS3oNV7BSv-19YSWKK2yCMtqDqP-M,41682
+evolutionary_policy_optimization/epo.py,sha256=3PEmFE032fQz7mAfIoJzwmysbU6lnjCCngEMc9Jf46M,42180
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.2.dist-info/METADATA,sha256=kXkidoqrBxyMtQvh7GHr1wh1pexy2WDSDWN5laQSPhI,6336
-evolutionary_policy_optimization-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.1.2.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.1.2.dist-info/RECORD,,
+evolutionary_policy_optimization-0.1.4.dist-info/METADATA,sha256=81Q8wK2N7zlOsSGuqcXrftpd1TIYkrplSPBuldMj8uU,6330
+evolutionary_policy_optimization-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.4.dist-info/RECORD,,