evolutionary-policy-optimization 0.0.68__py3-none-any.whl → 0.0.70__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- evolutionary_policy_optimization/epo.py
+++ evolutionary_policy_optimization/epo.py
@@ -416,6 +416,8 @@ class LatentGenePool(Module):
         self.num_natural_selected = int(frac_natural_selected * latents_per_island)
         self.num_tournament_participants = int(frac_tournaments * self.num_natural_selected)
 
+        assert self.num_tournament_participants >= 2
+
         self.crossover_random = crossover_random
 
         self.mutation_strength = mutation_strength
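The assertion added above exists because both selection counts are truncated with `int(...)`, so a small pool or small fractions can silently yield fewer than two tournament participants, leaving nothing to compare during tournament selection. A rough sketch of the arithmetic, using hypothetical fraction values (the real defaults live in `LatentGenePool.__init__` and are not part of this diff):

# hypothetical values mirroring the two context lines above the new assert
latents_per_island = 8
frac_natural_selected = 0.25
frac_tournaments = 0.25

num_natural_selected = int(frac_natural_selected * latents_per_island)        # 2
num_tournament_participants = int(frac_tournaments * num_natural_selected)    # int(0.5) -> 0

assert num_tournament_participants >= 2   # now fails fast at construction instead of misbehaving later

With these example numbers the constructor now raises immediately; a larger pool or larger fractions (the updated README below moves to num_latents = 16 and frac_natural_selected = 0.5, which with the same hypothetical frac_tournaments gives int(0.25 * 8) = 2) keeps the participant count at two or more.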
@@ -681,6 +683,8 @@ class Agent(Module):
         actor_lr = 1e-4,
         critic_lr = 1e-4,
         latent_lr = 1e-5,
+        actor_weight_decay = 1e-3,
+        critic_weight_decay = 1e-3,
         diversity_aux_loss_weight = 0.,
         use_critic_ema = True,
         critic_ema_beta = 0.99,
@@ -737,8 +741,8 @@ class Agent(Module):
 
         # optimizers
 
-        self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, **actor_optim_kwargs)
-        self.critic_optim = optim_klass(critic.parameters(), lr = critic_lr, **critic_optim_kwargs)
+        self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, weight_decay = actor_weight_decay, **actor_optim_kwargs)
+        self.critic_optim = optim_klass(critic.parameters(), lr = critic_lr, weight_decay = critic_weight_decay, **critic_optim_kwargs)
 
         self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if exists(latent_gene_pool) and not latent_gene_pool.frozen_latents else None
 
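The two new weight decay values are passed as keywords straight into whatever `optim_klass` the Agent was constructed with, so any optimizer class that accepts a `weight_decay` keyword (the torch built-ins such as `torch.optim.AdamW` do) picks them up. A minimal sketch of the resulting call, assuming `optim_klass` is `AdamW` and using a placeholder module for the actor:

import torch
from torch.optim import AdamW

actor = torch.nn.Linear(512, 5)   # stand-in network; only its parameters matter here

# mirrors what Agent.__init__ now builds, under the AdamW assumption
actor_optim = AdamW(actor.parameters(), lr = 1e-4, weight_decay = 1e-3)

Note that supplying `weight_decay` a second time through `actor_optim_kwargs` would now raise a `TypeError` for a duplicated keyword, so the explicit arguments are the single place to set it.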
@@ -843,7 +847,7 @@ class Agent(Module):
 
         return self.latent_gene_pool.genetic_algorithm_step(fitnesses)
 
-    def forward(
+    def learn_from(
         self,
         memories_and_cumulative_rewards: MemoriesAndCumulativeRewards,
         epochs = 2
@@ -851,11 +855,13 @@ class Agent(Module):
     ):
         memories_and_cumulative_rewards = to_device(memories_and_cumulative_rewards, self.device)
 
-        memories, rewards_per_latent_episode = memories_and_cumulative_rewards
+        memories_list, rewards_per_latent_episode = memories_and_cumulative_rewards
 
         # stack memories
 
-        memories = map(stack, zip(*memories))
+        memories = map(stack, zip(*memories_list))
+
+        memories_list.clear()
 
         maybe_barrier()
 
@@ -977,7 +983,6 @@ class Agent(Module):
         # apply evolution
 
         if self.has_latent_genes:
-
             self.latent_gene_pool.genetic_algorithm_step(fitness_scores)
 
         # reinforcement learning related - ppo
@@ -1157,9 +1162,10 @@ class EPO(Module):
                 yield latent_id, episode_id, maybe_seed
 
     @torch.no_grad()
-    def forward(
+    def gather_experience_from(
         self,
         env,
+        memories: list[Memory] | None = None,
         fix_environ_across_latents = None
     ) -> MemoriesAndCumulativeRewards:
 
@@ -1169,7 +1175,8 @@ class EPO(Module):
 
         invalid_episode = tensor(-1) # will use `episode_id` value of `-1` for the `next_value`, needed for not discarding last reward for generalized advantage estimate
 
-        memories: list[Memory] = []
+        if not exists(memories):
+            memories = []
 
         rewards_per_latent_episode = torch.zeros((self.num_latents, self.episodes_per_latent))
 
@@ -1252,3 +1259,18 @@ class EPO(Module):
             memories = memories,
             cumulative_rewards = rewards_per_latent_episode
         )
+
+    def forward(
+        self,
+        agent: Agent,
+        env,
+        num_learning_cycles
+    ):
+
+        for _ in tqdm(range(num_learning_cycles), desc = 'learning cycle'):
+
+            memories = self.gather_experience_from(env)
+
+            agent.learn_from(memories)
+
+        print(f'training complete')
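The new `EPO.forward` above simply composes the two renamed entry points into a training loop, so calling the module is shorthand for the sequence callers previously wrote by hand. Roughly, assuming `agent`, `epo`, and `env` exist as in the README:

# the one-liner this release enables
epo(agent, env, num_learning_cycles = 5)

# approximately what it expands to, minus the tqdm progress bar
for _ in range(5):
    memories = epo.gather_experience_from(env)
    agent.learn_from(memories)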
--- evolutionary_policy_optimization-0.0.68.dist-info/METADATA
+++ evolutionary_policy_optimization-0.0.70.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.68
+Version: 0.0.70
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -114,11 +114,14 @@ from evolutionary_policy_optimization import (
 
 agent = create_agent(
     dim_state = 512,
-    num_latents = 8,
+    num_latents = 16,
     dim_latent = 32,
     actor_num_actions = 5,
     actor_dim_hiddens = (256, 128),
-    critic_dim_hiddens = (256, 128, 64)
+    critic_dim_hiddens = (256, 128, 64),
+    latent_gene_pool_kwargs = dict(
+        frac_natural_selected = 0.5
+    )
 )
 
 epo = EPO(
@@ -130,9 +133,7 @@ epo = EPO(
 
 env = Env((512,))
 
-memories = epo(env)
-
-agent(memories)
+epo(agent, env, num_learning_cycles = 5)
 
 # saving and loading
 
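For code written against 0.0.68, the two removed README lines map directly onto the renamed methods, since `Agent.forward` became `learn_from` and the old `EPO.forward` became `gather_experience_from` (see the epo.py hunks above). A migration sketch:

# 0.0.68
memories = epo(env)
agent(memories)

# 0.0.70 - explicit methods...
memories = epo.gather_experience_from(env)
agent.learn_from(memories)

# ...or the new combined loop shown in the README
epo(agent, env, num_learning_cycles = 1)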
--- evolutionary_policy_optimization-0.0.68.dist-info/RECORD
+++ evolutionary_policy_optimization-0.0.70.dist-info/RECORD
@@ -1,9 +1,9 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/epo.py,sha256=xhE_kHas54xGsgOese9SQEvyK7NKZqEuK3AiVhm0y7Q,38047
+evolutionary_policy_optimization/epo.py,sha256=MmsqMwytVqBkb1f2piUygueOfn--Icb817P4bDcfPks,38683
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=Bv9ONFRbma8wpjUurc9aCk19A6ceiWitRnS3nwrIR64,1339
-evolutionary_policy_optimization-0.0.68.dist-info/METADATA,sha256=hOOKOrrPQtQmK3zN1z5nkGJEoaQLyXUzs9ArsEKn1DE,6220
-evolutionary_policy_optimization-0.0.68.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.68.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.68.dist-info/RECORD,,
+evolutionary_policy_optimization-0.0.70.dist-info/METADATA,sha256=fX_hR3dCjKUQ3VZtT-sUQy0qT8sF-nDPyP0QDAGHf60,6304
+evolutionary_policy_optimization-0.0.70.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.70.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.70.dist-info/RECORD,,