x-evolution 0.1.25__py3-none-any.whl → 0.1.27__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -11,6 +11,7 @@ from torch.nn import Module, ModuleList, Parameter, ParameterList
11
11
  from torch.optim import SGD, Adam, Optimizer
12
12
  from torch.optim.lr_scheduler import LRScheduler
13
13
 
14
+ import torch.distributed as dist
14
15
  import torch.nn.functional as F
15
16
 
16
17
  from beartype import beartype
@@ -93,11 +94,17 @@ class EvoStrategy(Module):
93
94
  verbose = True,
94
95
  accelerator: Accelerator | None = None,
95
96
  accelerate_kwargs: dict = dict(),
96
- reject_generation_fitnesses_if: Callable[[Tensor], bool] | None = None
97
+ reject_generation_fitnesses_if: Callable[[Tensor], bool] | None = None,
98
+ vectorized = False,
99
+ vector_size: int | None = None,
100
+ sync_on_init = True
97
101
  ):
98
102
  super().__init__()
99
103
  self.verbose = verbose
100
104
 
105
+ self.vectorized = vectorized
106
+ self.vector_size = vector_size
107
+
101
108
  if not exists(accelerator):
102
109
  accelerator = Accelerator(cpu = cpu, **accelerate_kwargs)
103
110
 
@@ -122,12 +129,10 @@ class EvoStrategy(Module):
122
129
  self.model = model
123
130
  self.noisable_model = Noisable(model, low_rank = noise_low_rank)
124
131
 
125
- # use prepare and run through environment once to sync params
126
-
127
- wrapped_model = accelerator.prepare(model)
132
+ # maybe sync model params and buffers
128
133
 
129
- with torch.no_grad():
130
- environment(wrapped_model)
134
+ if sync_on_init:
135
+ self.sync_model_params_and_buffers_()
131
136
 
132
137
  # get param dictionary
133
138
 
@@ -249,6 +254,17 @@ class EvoStrategy(Module):
249
254
  def device(self):
250
255
  return self.accelerate.device
251
256
 
257
+ @torch.no_grad()
258
+ def sync_model_params_and_buffers_(self):
259
+ if not self.accelerate.num_processes > 1:
260
+ return
261
+
262
+ for param in self.model.parameters():
263
+ dist.broadcast(param, src = 0)
264
+
265
+ for buffer in self.model.buffers():
266
+ dist.broadcast(buffer, src = 0)
267
+
252
268
  def print(self, *args, **kwargs):
253
269
  if not self.verbose:
254
270
  return
@@ -475,24 +491,28 @@ class EvoStrategy(Module):
475
491
  fitnesses.append([0., 0.] if self.mirror_sampling else 0.)
476
492
  continue
477
493
 
478
- individual_param_seeds = with_seed(individual_seed)(randint)(0, MAX_SEED_VALUE, (self.num_params,))
479
-
480
- noise_config = dict(zip(self.param_names_to_optimize, individual_param_seeds.tolist()))
481
-
482
- # determine noise scale, which can be fixed or learned
494
+ def get_fitness(negate = False):
495
+ individual_param_seeds = with_seed(individual_seed.item())(randint)(0, MAX_SEED_VALUE, (self.num_params,))
496
+ noise_config = dict(zip(self.param_names_to_optimize, individual_param_seeds.tolist()))
483
497
 
484
- noise_config_with_scale = dict()
498
+ noise_config_with_scale = dict()
499
+ for param_name, seed in noise_config.items():
500
+ noise_scale = self._get_noise_scale(param_name)
501
+ noise_config_with_scale[param_name] = (seed, noise_scale)
485
502
 
486
- for param_name, seed in noise_config.items():
503
+ with model.temp_add_noise_(noise_config_with_scale, negate = negate):
504
+ fitness = with_seed(maybe_rollout_seed)(self.environment)(model)
487
505
 
488
- noise_scale = self._get_noise_scale(param_name)
506
+ if isinstance(fitness, Tensor) and fitness.numel() > 1:
507
+ fitness = fitness.mean().item()
508
+ elif isinstance(fitness, Tensor):
509
+ fitness = fitness.item()
489
510
 
490
- noise_config_with_scale[param_name] = (seed, noise_scale)
511
+ return fitness
491
512
 
492
- # maybe roll out with a fixed seed
513
+ # evaluate
493
514
 
494
- with model.temp_add_noise_(noise_config_with_scale):
495
- fitness = with_seed(maybe_rollout_seed)(rollout_for_fitness)()
515
+ fitness = get_fitness(negate = False)
496
516
 
497
517
  if not self.mirror_sampling:
498
518
  fitnesses.append(fitness)
@@ -500,8 +520,7 @@ class EvoStrategy(Module):
500
520
 
501
521
  # handle mirror sampling
502
522
 
503
- with model.temp_add_noise_(noise_config_with_scale, negate = True):
504
- fitness_mirrored = with_seed(maybe_rollout_seed)(rollout_for_fitness)()
523
+ fitness_mirrored = get_fitness(negate = True)
505
524
 
506
525
  fitnesses.append([fitness, fitness_mirrored])
507
526
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-evolution
3
- Version: 0.1.25
3
+ Version: 0.1.27
4
4
  Summary: x-evolution
5
5
  Project-URL: Homepage, https://pypi.org/project/x-evolution/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-evolution
@@ -38,7 +38,7 @@ Requires-Dist: accelerate
38
38
  Requires-Dist: beartype
39
39
  Requires-Dist: einops>=0.8.0
40
40
  Requires-Dist: torch>=2.4
41
- Requires-Dist: x-mlps-pytorch>=0.1.31
41
+ Requires-Dist: x-mlps-pytorch>=0.2.0
42
42
  Requires-Dist: x-transformers>=2.11.23
43
43
  Provides-Extra: examples
44
44
  Requires-Dist: gymnasium[box2d]>=1.0.0; extra == 'examples'
@@ -0,0 +1,6 @@
1
+ x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
2
+ x_evolution/x_evolution.py,sha256=-G5qXGMjwVzdkxIDR6xL_YGium4KfKC0cnlY76Upy0o,19799
3
+ x_evolution-0.1.27.dist-info/METADATA,sha256=hj0MUpIGVWoOY5wHsoy_ZF_cx7s48_HZicd4IgNUFEo,5853
4
+ x_evolution-0.1.27.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ x_evolution-0.1.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ x_evolution-0.1.27.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
2
- x_evolution/x_evolution.py,sha256=lvN3ePqD6a5dW1gOv0d1I9yQ4rdv6OuIVvKvXa0yRBM,19126
3
- x_evolution-0.1.25.dist-info/METADATA,sha256=1fAtssjj_t76rXwLu728z_ohlRZKQDhCm0oOl2eeIxA,5854
4
- x_evolution-0.1.25.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- x_evolution-0.1.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- x_evolution-0.1.25.dist-info/RECORD,,