evolutionary-policy-optimization 0.0.46__py3-none-any.whl → 0.0.48__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

evolutionary_policy_optimization/epo.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 from functools import partial, wraps
 from pathlib import Path
 from collections import namedtuple
+from random import randrange
 
 import torch
 from torch import nn, cat, stack, is_tensor, tensor
@@ -356,9 +357,10 @@ class LatentGenePool(Module):
         frac_natural_selected = 0.25, # number of least fit genes to remove from the pool
         frac_elitism = 0.1, # frac of population to preserve from being noised
         frac_migrate = 0.1, # frac of population, excluding elites, that migrate between islands randomly. will use a designated set migration pattern (since for some reason using random it seems to be worse for me)
+        migrate_every = 100, # how many steps before a migration between islands
         mutation_strength = 1., # factor to multiply to gaussian noise as mutation to latents
         should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
-        default_should_run_ga_gamma = 1.5
+        default_should_run_ga_gamma = 1.5,
     ):
         super().__init__()
 
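
The new `migrate_every` hyperparameter moves migration scheduling inside the pool itself, instead of requiring the training loop to decide when to migrate. A hypothetical construction sketch; `num_latents`, `dim_latent` and `num_islands` follow the project README, and the values are illustrative only:

    from evolutionary_policy_optimization import LatentGenePool

    pool = LatentGenePool(
        num_latents = 128,    # population size
        dim_latent = 32,
        num_islands = 4,      # migration is only meaningful with more than one island
        migrate_every = 100   # new: auto-migrate every 100 genetic algorithm steps
    )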
@@ -408,6 +410,13 @@ class LatentGenePool(Module):
 
         self.should_run_genetic_algorithm = should_run_genetic_algorithm
 
+        self.can_migrate = num_islands > 1
+        self.migrate_every = migrate_every
+        self.register_buffer('step', tensor(0))
+
+    def advance_step_(self):
+        self.step.add_(1)
+
     def firefly_step(
         self,
         fitness,
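
The step counter is registered as a buffer rather than kept as a plain Python integer, so it is saved and restored with the module's `state_dict()` and follows the module across devices. A minimal standalone sketch of the same pattern:

    import torch
    from torch import nn, tensor

    class Counter(nn.Module):
        def __init__(self):
            super().__init__()
            # a buffer is not a parameter (no gradients), but it is included
            # in state_dict() and moves with the module on .to(device)
            self.register_buffer('step', tensor(0))

        def advance_step_(self):
            # trailing underscore follows the pytorch convention for in-place ops
            self.step.add_(1)

    counter = Counter()
    counter.advance_step_()
    assert counter.state_dict()['step'].item() == 1  # persists through checkpoints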
@@ -459,7 +468,7 @@ class LatentGenePool(Module):
         self,
         fitness, # Float['p'],
         inplace = True,
-        migrate = False # trigger a migration in the setting of multiple islands, the loop outside will need to have some `migrate_every` hyperparameter
+        migrate = None # trigger a migration in the setting of multiple islands, the loop outside will need to have some `migrate_every` hyperparameter
     ):
         device = self.latents.device
 
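
With this change `migrate` becomes a tri-state flag: `None` defers to the internal step counter, while an explicit boolean still overrides it. Hypothetical calls against a constructed pool (the method name follows the README's `genetic_algorithm_step`; `fitness` is the per-latent fitness tensor):

    pool.genetic_algorithm_step(fitness)                   # None: migrate when the step count is a multiple of migrate_every
    pool.genetic_algorithm_step(fitness, migrate = True)   # force a migration (still requires more than one island)
    pool.genetic_algorithm_step(fitness, migrate = False)  # suppress migration on this step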
@@ -546,8 +555,9 @@ class LatentGenePool(Module):
 
         # 6. maybe migration
 
+        migrate = self.can_migrate and default(migrate, divisible_by(self.step.item(), self.migrate_every))
+
         if migrate:
-            assert self.num_islands > 1
             randperm = torch.randn(genes.shape[:-1], device = device).argsort(dim = -1)
 
             migrate_mask = randperm < self.num_migrate
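
For reference, `default` and `divisible_by` are presumably the codebase's usual one-line helpers, reproduced here as an assumption:

    def exists(v):
        return v is not None

    def default(v, d):
        # falls back to d only when v is None, so an explicit migrate = False still wins
        return v if exists(v) else d

    def divisible_by(num, den):
        return (num % den) == 0

Since the buffer starts at 0 and `advance_step_()` runs only after this check, the very first call also triggers a migration.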
@@ -580,6 +590,8 @@ class LatentGenePool(Module):
 
         self.latents.copy_(genes)
 
+        self.advance_step_()
+
     def forward(
         self,
         *args,
@@ -989,7 +1001,8 @@ class EPO(Module):
     @torch.no_grad()
     def forward(
         self,
-        env
+        env,
+        fix_seed_across_latents = True
     ) -> MemoriesAndCumulativeRewards:
 
         self.agent.eval()
@@ -1002,12 +1015,22 @@ class EPO(Module):
 
         for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
 
+            # maybe fix seed for environment across all latents
+
+            env_reset_kwargs = dict()
+
+            if fix_seed_across_latents:
+                seed = randrange(int(1e6))
+                env_reset_kwargs = dict(seed = seed)
+
+            # for each latent (on a single machine for now)
+
             for latent_id in tqdm(range(self.num_latents), desc = 'latent'):
                 time = 0
 
                 # initial state
 
-                state = env.reset()
+                state = env.reset(**env_reset_kwargs)
 
                 # get latent from pool
 
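
Fixing the seed across latents means every latent is rolled out under the same environment randomness for a given episode, so differences in cumulative reward reflect the latents rather than environment luck. A minimal sketch of the pattern, assuming a gymnasium-style environment whose `reset` accepts a `seed` keyword (the function and argument names here are illustrative):

    from random import randrange

    def evaluate_latents_fairly(env, latent_ids, fix_seed_across_latents = True):
        env_reset_kwargs = dict()

        if fix_seed_across_latents:
            # one seed drawn per episode, shared by every latent
            env_reset_kwargs = dict(seed = randrange(int(1e6)))

        for latent_id in latent_ids:
            state = env.reset(**env_reset_kwargs)  # identical initial conditions per latent
            # ... roll out the policy conditioned on latent_id ...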

evolutionary_policy_optimization/mock_env.py

@@ -25,7 +25,8 @@ class Env(Module):
         return self.dummy.device
 
     def reset(
-        self
+        self,
+        seed
     ):
         state = randn(self.state_shape, device = self.device)
         return state
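
The mock environment accepts `seed` only so the call signature matches: the state is still drawn from an unseeded `randn`, and because `seed` has no default, calling `reset()` without one (the `fix_seed_across_latents = False` path) would raise a TypeError against this mock. A sketch of a variant that tolerates a missing seed and actually honors a provided one via `torch.Generator` (an assumption, not the shipped code):

    import torch
    from torch import randn

    def reset(self, seed = None):
        generator = None

        if seed is not None:
            # seed a local generator rather than mutating global RNG state
            generator = torch.Generator(device = self.device).manual_seed(seed)

        state = randn(self.state_shape, generator = generator, device = self.device)
        return state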

evolutionary_policy_optimization-0.0.48.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.46
+Version: 0.0.48
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization

evolutionary_policy_optimization-0.0.48.dist-info/RECORD (added)

@@ -0,0 +1,8 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/epo.py,sha256=FkliOiKdmUvKuwFqb1_A-ddahnOqjTR8Djx_I6UZAlU,32625
+evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
+evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
+evolutionary_policy_optimization-0.0.48.dist-info/METADATA,sha256=GpuUVs0VO2ydhU3X4-A_cA_xNmvdtYaAM8tb_VKneBo,6213
+evolutionary_policy_optimization-0.0.48.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.48.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.48.dist-info/RECORD,,

evolutionary_policy_optimization-0.0.46.dist-info/RECORD (removed)

@@ -1,8 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/epo.py,sha256=SAhWgRY8uPQEKFg1_nz1mvh8A6S_sHwnDykhd0F5xEI,31853
-evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
-evolutionary_policy_optimization/mock_env.py,sha256=6AIc4mwL_C6JkAxwESJgCLxXHMzCAu2FcffVg3HkSm0,920
-evolutionary_policy_optimization-0.0.46.dist-info/METADATA,sha256=xP2kdKo52-X4Z5XXTPpW0M_NFI0spuigeL7fvqFlsRM,6213
-evolutionary_policy_optimization-0.0.46.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.46.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.46.dist-info/RECORD,,