evolutionary-policy-optimization 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- evolutionary_policy_optimization/epo.py +23 -4
- {evolutionary_policy_optimization-0.1.2.dist-info → evolutionary_policy_optimization-0.1.4.dist-info}/METADATA +2 -2
- {evolutionary_policy_optimization-0.1.2.dist-info → evolutionary_policy_optimization-0.1.4.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.1.2.dist-info → evolutionary_policy_optimization-0.1.4.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.1.2.dist-info → evolutionary_policy_optimization-0.1.4.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -79,7 +79,7 @@ def interface_torch_numpy(fn, device):
     @wraps(fn)
     def decorated_fn(*args, **kwargs):

-        args, kwargs = tree_map(lambda t: t.cpu().numpy() if
+        args, kwargs = tree_map(lambda t: t.cpu().numpy() if is_tensor(t) else t, (args, kwargs))

         out = fn(*args, **kwargs)

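This hunk adjusts the tensor-to-numpy conversion inside `interface_torch_numpy`: every tensor leaf in the positional and keyword arguments is moved to CPU and converted to numpy, while non-tensor leaves pass through unchanged. A minimal sketch of the same `tree_map` pattern, assuming `tree_map` is imported from `torch.utils._pytree` as in typical PyTorch code (the helper name `to_numpy_leaves` is illustrative, not part of the package):

```python
import torch
from torch import is_tensor
from torch.utils._pytree import tree_map

def to_numpy_leaves(args, kwargs):
    # walk the nested (args, kwargs) structure, converting only tensor leaves
    return tree_map(lambda t: t.cpu().numpy() if is_tensor(t) else t, (args, kwargs))

args, kwargs = to_numpy_leaves((torch.randn(2), 'label'), dict(scale = 1.5))
print(type(args[0]).__name__)  # ndarray
```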
@@ -88,6 +88,16 @@ def interface_torch_numpy(fn, device):

     return decorated_fn

+def move_input_tensors_to_device(fn):
+
+    @wraps(fn)
+    def decorated_fn(self, *args, **kwargs):
+        args, kwargs = tree_map(lambda t: t.to(self.device) if is_tensor(t) else t, (args, kwargs))
+
+        return fn(self, *args, **kwargs)
+
+    return decorated_fn
+
 # tensor helpers

 def l2norm(t):
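The new `move_input_tensors_to_device` decorator is the companion of `interface_torch_numpy` in the other direction: tensor inputs are mapped onto the owning module's device before the wrapped method runs. A minimal usage sketch, where the `DeviceAware` class and its fixed `device` attribute are illustrative stand-ins for the Agent:

```python
import torch
from functools import wraps
from torch import is_tensor
from torch.utils._pytree import tree_map

def move_input_tensors_to_device(fn):
    # same pattern as the decorator added in this release
    @wraps(fn)
    def decorated_fn(self, *args, **kwargs):
        args, kwargs = tree_map(lambda t: t.to(self.device) if is_tensor(t) else t, (args, kwargs))
        return fn(self, *args, **kwargs)
    return decorated_fn

class DeviceAware:
    device = torch.device('cpu')  # illustrative stand-in for Agent.device

    @move_input_tensors_to_device
    def get_values(self, state):
        return state.device  # tensors arrive already on self.device

print(DeviceAware().get_values(torch.randn(4)))  # cpu
```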
@@ -797,6 +807,8 @@ class Agent(Module):

         self.unwrap_model = identity if not wrap_with_accelerate else self.accelerate.unwrap_model

+        dummy = tensor(0)
+
         if wrap_with_accelerate:
             (
                 self.actor,
@@ -816,11 +828,14 @@ class Agent(Module):
                 )
             )

-
+            if exists(self.critic_ema):
+                self.critic_ema.to(self.accelerate.device)

-
+            dummy = dummy.to(self.accelerate.device)

-
+        # device tracking
+
+        self.register_buffer('dummy', dummy)

     @property
     def device(self):
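The registered `dummy` buffer is a common device-tracking trick: buffers travel with the module through `.to()` and through accelerate preparation, so the `device` property can simply report where the buffer currently lives. A minimal sketch of the pattern outside the Agent class:

```python
import torch
from torch import nn, tensor

class DeviceTracked(nn.Module):
    def __init__(self):
        super().__init__()
        # a scalar buffer that follows the module to whatever device it is moved to
        self.register_buffer('dummy', tensor(0))

    @property
    def device(self):
        return self.dummy.device

module = DeviceTracked()
print(module.device)  # cpu

if torch.cuda.is_available():
    print(module.cuda().device)  # cuda:0
```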
@@ -870,6 +885,7 @@ class Agent(Module):
         if exists(pkg.get('latent_optim', None)):
             self.latent_optim.load_state_dict(pkg['latent_optim'])

+    @move_input_tensors_to_device
     def get_actor_actions(
         self,
         state,
@@ -884,6 +900,7 @@ class Agent(Module):
         if not exists(latent) and exists(latent_id):
             latent = maybe_unwrap(self.latent_gene_pool)(latent_id = latent_id)

+        print(self.device, state.device, next(self.actor.parameters()).device)
         logits = maybe_unwrap(self.actor)(state, latent)

         if not sample:
@@ -895,6 +912,7 @@ class Agent(Module):

         return actions, log_probs

+    @move_input_tensors_to_device
     def get_critic_values(
         self,
         state,
@@ -903,6 +921,7 @@ class Agent(Module):
         use_ema_if_available = False,
         use_unwrapped_model = False
     ):
+
         maybe_unwrap = identity if not use_unwrapped_model else self.unwrap_model

         if not exists(latent) and exists(latent_id):
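Taken together, decorating `get_actor_actions` and `get_critic_values` with `move_input_tensors_to_device` means callers can pass states living on any device (for example CPU tensors produced by an environment loop): they are moved onto the agent's device, as reported via the `dummy` buffer, before the underlying forward calls run.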
{evolutionary_policy_optimization-0.1.2.dist-info → evolutionary_policy_optimization-0.1.4.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.2
+Version: 0.1.4
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -53,7 +53,7 @@ Description-Content-Type: text/markdown

 <img width="450px" alt="fig1" src="https://github.com/user-attachments/assets/33bef569-e786-4f09-bdee-56bad7ea9e6d" />

-## Evolutionary Policy Optimization
+## Evolutionary Policy Optimization

 Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from Wang et al. of the Robotics Institute at Carnegie Mellon University

{evolutionary_policy_optimization-0.1.2.dist-info → evolutionary_policy_optimization-0.1.4.dist-info}/RECORD

@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=
+evolutionary_policy_optimization/epo.py,sha256=3PEmFE032fQz7mAfIoJzwmysbU6lnjCCngEMc9Jf46M,42180
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
+evolutionary_policy_optimization-0.1.4.dist-info/METADATA,sha256=81Q8wK2N7zlOsSGuqcXrftpd1TIYkrplSPBuldMj8uU,6330
+evolutionary_policy_optimization-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.4.dist-info/RECORD,,
File without changes