evolutionary-policy-optimization 0.1.2-py3-none-any.whl → 0.1.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +23 -4
 - {evolutionary_policy_optimization-0.1.2.dist-info → evolutionary_policy_optimization-0.1.4.dist-info}/METADATA +2 -2
 - {evolutionary_policy_optimization-0.1.2.dist-info → evolutionary_policy_optimization-0.1.4.dist-info}/RECORD +5 -5
 - {evolutionary_policy_optimization-0.1.2.dist-info → evolutionary_policy_optimization-0.1.4.dist-info}/WHEEL +0 -0
 - {evolutionary_policy_optimization-0.1.2.dist-info → evolutionary_policy_optimization-0.1.4.dist-info}/licenses/LICENSE +0 -0
 
evolutionary_policy_optimization/epo.py

@@ -79,7 +79,7 @@ def interface_torch_numpy(fn, device):
     @wraps(fn)
     def decorated_fn(*args, **kwargs):

-        args, kwargs = tree_map(lambda t: t.cpu().numpy() if 
+        args, kwargs = tree_map(lambda t: t.cpu().numpy() if is_tensor(t) else t, (args, kwargs))

         out = fn(*args, **kwargs)

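
The changed line belongs to a decorator that lets tensor-holding callers drive numpy-based code. A minimal sketch of that pattern follows; the reverse numpy-to-torch conversion of the outputs and the exact `tree_map` import are assumptions, since only the input-conversion line appears in this hunk.

```python
# Minimal sketch of the interface_torch_numpy pattern, assuming the decorator
# also converts numpy outputs back to torch tensors on `device`; only the
# input-conversion line (torch -> numpy) appears in the hunk above.

from functools import wraps

import numpy as np
import torch
from torch import is_tensor, from_numpy
from torch.utils._pytree import tree_map

def interface_torch_numpy(fn, device):

    @wraps(fn)
    def decorated_fn(*args, **kwargs):
        # torch -> numpy on the way in (the line changed in this release)
        args, kwargs = tree_map(lambda t: t.cpu().numpy() if is_tensor(t) else t, (args, kwargs))

        out = fn(*args, **kwargs)

        # numpy -> torch on the way out (assumed direction)
        return tree_map(lambda o: from_numpy(o).to(device) if isinstance(o, np.ndarray) else o, out)

    return decorated_fn

# usage: wrap a numpy-only function so tensor-holding callers can use it
step = interface_torch_numpy(lambda a: np.tanh(a), device = torch.device('cpu'))
print(step(torch.randn(3)))  # returns a torch tensor on cpu
```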
@@ -88,6 +88,16 @@ def interface_torch_numpy(fn, device):

     return decorated_fn

+def move_input_tensors_to_device(fn):
+
+    @wraps(fn)
+    def decorated_fn(self, *args, **kwargs):
+        args, kwargs = tree_map(lambda t: t.to(self.device) if is_tensor(t) else t, (args, kwargs))
+
+        return fn(self, *args, **kwargs)
+
+    return decorated_fn
+
 # tensor helpers

 def l2norm(t):
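
The new decorator moves every tensor argument of a bound method onto its owner's device before the call. A hypothetical usage sketch follows; `TinyPolicy` and its `device` property are illustrative stand-ins, not part of the package, which applies the decorator to `Agent` methods later in this diff.

```python
# Hypothetical usage of the decorator added above: any bound method whose
# owner exposes a `device` attribute gets its tensor arguments moved there.

from functools import wraps

import torch
from torch import nn, is_tensor
from torch.utils._pytree import tree_map

def move_input_tensors_to_device(fn):

    @wraps(fn)
    def decorated_fn(self, *args, **kwargs):
        args, kwargs = tree_map(lambda t: t.to(self.device) if is_tensor(t) else t, (args, kwargs))
        return fn(self, *args, **kwargs)

    return decorated_fn

class TinyPolicy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 2)

    @property
    def device(self):
        return next(self.parameters()).device

    @move_input_tensors_to_device
    def act(self, state):
        return self.proj(state).argmax(dim = -1)

policy = TinyPolicy()
if torch.cuda.is_available():
    policy = policy.cuda()

print(policy.act(torch.randn(4)))  # CPU input is moved to policy.device first
```

The benefit is that rollout code can hand over CPU tensors directly, without sprinkling `.to(device)` calls at every call site.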
@@ -797,6 +807,8 @@ class Agent(Module):

         self.unwrap_model = identity if not wrap_with_accelerate else self.accelerate.unwrap_model

+        dummy = tensor(0)
+
         if wrap_with_accelerate:
             (
                 self.actor,
@@ -816,11 +828,14 @@ class Agent(Module):
                 )
             )

-
+            if exists(self.critic_ema):
+                self.critic_ema.to(self.accelerate.device)

-
+            dummy = dummy.to(self.accelerate.device)

-
+        # device tracking
+
+        self.register_buffer('dummy', dummy)

     @property
     def device(self):
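
The `dummy` buffer is a common device-tracking trick: a registered scalar buffer travels with the module through `.to()`, `.cuda()`, or accelerate's device placement, so the module can report its own device. The body of the `device` property is not shown in this hunk; the sketch below assumes it simply reads the buffer's device.

```python
# Minimal sketch of the dummy-buffer device-tracking pattern, assuming the
# Agent.device property just returns the registered buffer's device.

import torch
from torch import nn, tensor

class DeviceTracked(nn.Module):
    def __init__(self):
        super().__init__()
        # scalar buffer: negligible cost, but follows the module across devices
        self.register_buffer('dummy', tensor(0))

    @property
    def device(self):
        return self.dummy.device

m = DeviceTracked()
print(m.device)             # cpu
if torch.cuda.is_available():
    print(m.cuda().device)  # cuda:0
```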
@@ -870,6 +885,7 @@ class Agent(Module):
         if exists(pkg.get('latent_optim', None)):
             self.latent_optim.load_state_dict(pkg['latent_optim'])

+    @move_input_tensors_to_device
     def get_actor_actions(
         self,
         state,
@@ -884,6 +900,7 @@ class Agent(Module):
         if not exists(latent) and exists(latent_id):
             latent = maybe_unwrap(self.latent_gene_pool)(latent_id = latent_id)

+        print(self.device, state.device, next(self.actor.parameters()).device)
         logits = maybe_unwrap(self.actor)(state, latent)

         if not sample:
@@ -895,6 +912,7 @@ class Agent(Module):

         return actions, log_probs

+    @move_input_tensors_to_device
     def get_critic_values(
         self,
         state,
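
With both rollout-facing methods decorated, states coming straight off a CPU environment no longer need an explicit move to the agent's device. A hedged usage sketch follows; the `Agent` import and the keyword names other than `state` (`latent_id`, `sample`) are inferred from identifiers visible in these hunks and may not match the full signatures in epo.py.

```python
# Hedged sketch: CPU-resident environment states can be passed straight to
# the decorated methods. Constructor arguments are elided; keyword names
# besides `state` are assumptions based on identifiers in this diff.

import torch
from evolutionary_policy_optimization import Agent  # assumed public export

def rollout_step(agent: Agent, state: torch.Tensor):
    # `state` may live on the CPU; @move_input_tensors_to_device relocates it
    # onto agent.device before the actor / critic forward passes
    actions, log_probs = agent.get_actor_actions(state, latent_id = 0, sample = True)
    values = agent.get_critic_values(state, latent_id = 0)
    return actions, log_probs, values
```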
@@ -903,6 +921,7 @@ class Agent(Module):
         use_ema_if_available = False,
         use_unwrapped_model = False
     ):
+
         maybe_unwrap = identity if not use_unwrapped_model else self.unwrap_model

         if not exists(latent) and exists(latent_id):
{evolutionary_policy_optimization-0.1.2.dist-info → evolutionary_policy_optimization-0.1.4.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.2
+Version: 0.1.4
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization

@@ -53,7 +53,7 @@ Description-Content-Type: text/markdown

 <img width="450px" alt="fig1" src="https://github.com/user-attachments/assets/33bef569-e786-4f09-bdee-56bad7ea9e6d" />

-## Evolutionary Policy Optimization 
+## Evolutionary Policy Optimization

 Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from Wang et al. of the Robotics Institute at Carnegie Mellon University

{evolutionary_policy_optimization-0.1.2.dist-info → evolutionary_policy_optimization-0.1.4.dist-info}/RECORD

@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=
+evolutionary_policy_optimization/epo.py,sha256=3PEmFE032fQz7mAfIoJzwmysbU6lnjCCngEMc9Jf46M,42180
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
+evolutionary_policy_optimization-0.1.4.dist-info/METADATA,sha256=81Q8wK2N7zlOsSGuqcXrftpd1TIYkrplSPBuldMj8uU,6330
+evolutionary_policy_optimization-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.4.dist-info/RECORD,,

{evolutionary_policy_optimization-0.1.2.dist-info → evolutionary_policy_optimization-0.1.4.dist-info}/WHEEL: file without changes

{evolutionary_policy_optimization-0.1.2.dist-info → evolutionary_policy_optimization-0.1.4.dist-info}/licenses/LICENSE: file without changes