evolutionary_policy_optimization-0.0.67-py3-none-any.whl → evolutionary_policy_optimization-0.0.69-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
evolutionary_policy_optimization/epo.py

@@ -55,6 +55,9 @@ def xnor(x, y):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def to_device(inp, device):
+    return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)
+
 # tensor helpers
 
 def l2norm(t):
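Note: the new to_device helper relies on tree_map and is_tensor already being imported elsewhere in epo.py (the imports sit outside this hunk). A minimal, self-contained sketch of how it behaves on a nested container, assuming tree_map comes from torch.utils._pytree:

import torch
from torch import is_tensor
from torch.utils._pytree import tree_map

def to_device(inp, device):
    # move every tensor leaf of a nested structure, leave non-tensor leaves untouched
    return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)

batch = dict(obs = torch.randn(4, 8), step = torch.tensor(3), note = 'left as-is')
batch = to_device(batch, 'cpu')  # swap 'cpu' for 'cuda' when a GPU is present
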
@@ -678,6 +681,8 @@ class Agent(Module):
         actor_lr = 1e-4,
         critic_lr = 1e-4,
         latent_lr = 1e-5,
+        actor_weight_decay = 1e-3,
+        critic_weight_decay = 1e-3,
         diversity_aux_loss_weight = 0.,
         use_critic_ema = True,
         critic_ema_beta = 0.99,
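The two new constructor arguments are forwarded to the actor and critic optimizers in the next hunk. A rough sketch of the effect, using torch.optim.AdamW as a stand-in for whatever optim_klass the Agent is actually configured with:

import torch
from torch import nn

actor, critic = nn.Linear(16, 4), nn.Linear(16, 1)  # placeholders for the real networks

actor_optim = torch.optim.AdamW(actor.parameters(), lr = 1e-4, weight_decay = 1e-3)
critic_optim = torch.optim.AdamW(critic.parameters(), lr = 1e-4, weight_decay = 1e-3)
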
@@ -734,8 +739,8 @@ class Agent(Module):
 
         # optimizers
 
-        self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, **actor_optim_kwargs)
-        self.critic_optim = optim_klass(critic.parameters(), lr = critic_lr, **critic_optim_kwargs)
+        self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, weight_decay = actor_weight_decay, **actor_optim_kwargs)
+        self.critic_optim = optim_klass(critic.parameters(), lr = critic_lr, weight_decay = critic_weight_decay, **critic_optim_kwargs)
 
         self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if exists(latent_gene_pool) and not latent_gene_pool.frozen_latents else None
 
@@ -744,6 +749,12 @@ class Agent(Module):
         self.has_diversity_loss = diversity_aux_loss_weight > 0.
         self.diversity_aux_loss_weight = diversity_aux_loss_weight
 
+        self.register_buffer('dummy', tensor(0))
+
+    @property
+    def device(self):
+        return self.dummy.device
+
     def save(self, path, overwrite = False):
         path = Path(path)
 
@@ -838,7 +849,10 @@ class Agent(Module):
         self,
         memories_and_cumulative_rewards: MemoriesAndCumulativeRewards,
         epochs = 2
+
     ):
+        memories_and_cumulative_rewards = to_device(memories_and_cumulative_rewards, self.device)
+
         memories, rewards_per_latent_episode = memories_and_cumulative_rewards
 
         # stack memories
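The dummy buffer registered two hunks above gives the Agent a device property that tracks wherever the module is moved, and the learning step now uses it to pull the cpu-collected rollout data back onto that device before training. An illustrative sketch of the pattern, with made-up field names standing in for the real Memory namedtuples:

import torch
from torch import nn, tensor, is_tensor
from torch.utils._pytree import tree_map
from collections import namedtuple

class DeviceAwareAgent(nn.Module):
    def __init__(self):
        super().__init__()
        # buffers follow .to()/.cuda(), so a zero-dim dummy reliably reports the module's device
        self.register_buffer('dummy', tensor(0))

    @property
    def device(self):
        return self.dummy.device

    def learn(self, memories):
        # move the cpu-collected rollout onto whatever device the weights live on
        return tree_map(lambda t: t.to(self.device) if is_tensor(t) else t, memories)

Rollout = namedtuple('Rollout', ['states', 'rewards'])  # hypothetical stand-in

agent = DeviceAwareAgent()
moved = agent.learn(Rollout(states = torch.randn(128, 8), rewards = torch.randn(128)))
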
@@ -1217,6 +1231,8 @@ class EPO(Module):
                         terminated
                     )
 
+                    memory = Memory(*tuple(t.cpu() for t in memory))
+
                     memories.append(memory)
 
                     time += 1
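During rollout each per-step Memory is now rebuilt on the cpu before being stored, presumably so long episodes do not accumulate GPU tensors; the bootstrap value in the following hunk gets the same treatment. A small sketch of the move, with illustrative field names rather than the real Memory fields:

import torch
from collections import namedtuple

Memory = namedtuple('Memory', ['state', 'action', 'reward'])  # field names are illustrative

step = Memory(state = torch.randn(8), action = torch.tensor(1), reward = torch.tensor(0.5))
step = Memory(*tuple(t.cpu() for t in step))  # same rebuild-on-cpu move as in the diff

memories = []
memories.append(step)
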
@@ -1228,7 +1244,7 @@ class EPO(Module):
 
                 memory_for_gae = memory._replace(
                     episode_id = invalid_episode,
-                    value = next_value,
+                    value = next_value.cpu(),
                     done = tensor(True)
                 )
 
evolutionary_policy_optimization-0.0.67.dist-info/METADATA → evolutionary_policy_optimization-0.0.69.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.67
+Version: 0.0.69
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization-0.0.67.dist-info/RECORD → evolutionary_policy_optimization-0.0.69.dist-info/RECORD

@@ -1,9 +1,9 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/epo.py,sha256=7WOGl22WknudsNLSZ18AWwW7rPt5ITMAtByetUwLp7M,37654
+evolutionary_policy_optimization/epo.py,sha256=e83tghTNXfCW0zhhb4nIjvfbzDvzWRxgTlm3vKJd4rM,38189
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=Bv9ONFRbma8wpjUurc9aCk19A6ceiWitRnS3nwrIR64,1339
-evolutionary_policy_optimization-0.0.67.dist-info/METADATA,sha256=S1biwayyDA4vTOXknMU5KeWtJTFvxePHZZ0OZzuaNms,6220
-evolutionary_policy_optimization-0.0.67.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.67.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.67.dist-info/RECORD,,
+evolutionary_policy_optimization-0.0.69.dist-info/METADATA,sha256=UZEaCY5lfTRMkuyEQs5PLA1AZzSOcsRzXey9kgdd9i0,6220
+evolutionary_policy_optimization-0.0.69.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.69.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.69.dist-info/RECORD,,