x-evolution 0.1.24__py3-none-any.whl → 0.1.26__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -93,11 +93,16 @@ class EvoStrategy(Module):
93
93
  verbose = True,
94
94
  accelerator: Accelerator | None = None,
95
95
  accelerate_kwargs: dict = dict(),
96
- reject_generation_fitnesses_if: Callable[[Tensor], bool] | None = None
96
+ reject_generation_fitnesses_if: Callable[[Tensor], bool] | None = None,
97
+ vectorized = False,
98
+ vector_size: int | None = None
97
99
  ):
98
100
  super().__init__()
99
101
  self.verbose = verbose
100
102
 
103
+ self.vectorized = vectorized
104
+ self.vector_size = vector_size
105
+
101
106
  if not exists(accelerator):
102
107
  accelerator = Accelerator(cpu = cpu, **accelerate_kwargs)
103
108
 
@@ -475,24 +480,28 @@ class EvoStrategy(Module):
475
480
  fitnesses.append([0., 0.] if self.mirror_sampling else 0.)
476
481
  continue
477
482
 
478
- individual_param_seeds = with_seed(individual_seed)(randint)(0, MAX_SEED_VALUE, (self.num_params,))
479
-
480
- noise_config = dict(zip(self.param_names_to_optimize, individual_param_seeds.tolist()))
481
-
482
- # determine noise scale, which can be fixed or learned
483
+ def get_fitness(negate = False):
484
+ individual_param_seeds = with_seed(individual_seed.item())(randint)(0, MAX_SEED_VALUE, (self.num_params,))
485
+ noise_config = dict(zip(self.param_names_to_optimize, individual_param_seeds.tolist()))
483
486
 
484
- noise_config_with_scale = dict()
487
+ noise_config_with_scale = dict()
488
+ for param_name, seed in noise_config.items():
489
+ noise_scale = self._get_noise_scale(param_name)
490
+ noise_config_with_scale[param_name] = (seed, noise_scale)
485
491
 
486
- for param_name, seed in noise_config.items():
492
+ with model.temp_add_noise_(noise_config_with_scale, negate = negate):
493
+ fitness = with_seed(maybe_rollout_seed)(self.environment)(model)
487
494
 
488
- noise_scale = self._get_noise_scale(param_name)
495
+ if isinstance(fitness, Tensor) and fitness.numel() > 1:
496
+ fitness = fitness.mean().item()
497
+ elif isinstance(fitness, Tensor):
498
+ fitness = fitness.item()
489
499
 
490
- noise_config_with_scale[param_name] = (seed, noise_scale)
500
+ return fitness
491
501
 
492
- # maybe roll out with a fixed seed
502
+ # evaluate
493
503
 
494
- with model.temp_add_noise_(noise_config_with_scale):
495
- fitness = with_seed(maybe_rollout_seed)(rollout_for_fitness)()
504
+ fitness = get_fitness(negate = False)
496
505
 
497
506
  if not self.mirror_sampling:
498
507
  fitnesses.append(fitness)
@@ -500,8 +509,7 @@ class EvoStrategy(Module):
500
509
 
501
510
  # handle mirror sampling
502
511
 
503
- with model.temp_add_noise_(noise_config_with_scale, negate = True):
504
- fitness_mirrored = with_seed(maybe_rollout_seed)(rollout_for_fitness)()
512
+ fitness_mirrored = get_fitness(negate = True)
505
513
 
506
514
  fitnesses.append([fitness, fitness_mirrored])
507
515
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-evolution
3
- Version: 0.1.24
3
+ Version: 0.1.26
4
4
  Summary: x-evolution
5
5
  Project-URL: Homepage, https://pypi.org/project/x-evolution/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-evolution
@@ -38,7 +38,7 @@ Requires-Dist: accelerate
38
38
  Requires-Dist: beartype
39
39
  Requires-Dist: einops>=0.8.0
40
40
  Requires-Dist: torch>=2.4
41
- Requires-Dist: x-mlps-pytorch>=0.1.31
41
+ Requires-Dist: x-mlps-pytorch>=0.2.0
42
42
  Requires-Dist: x-transformers>=2.11.23
43
43
  Provides-Extra: examples
44
44
  Requires-Dist: gymnasium[box2d]>=1.0.0; extra == 'examples'
@@ -0,0 +1,6 @@
1
+ x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
2
+ x_evolution/x_evolution.py,sha256=Jln3wpkIQIp7xwa3KMibh0kSuob1NIi7Aj7Miz8RJdY,19491
3
+ x_evolution-0.1.26.dist-info/METADATA,sha256=7zamSGrDtvOUQzpyYCkrt42FmWC4GPYpVSWUAO8a6OA,5853
4
+ x_evolution-0.1.26.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ x_evolution-0.1.26.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ x_evolution-0.1.26.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
2
- x_evolution/x_evolution.py,sha256=lvN3ePqD6a5dW1gOv0d1I9yQ4rdv6OuIVvKvXa0yRBM,19126
3
- x_evolution-0.1.24.dist-info/METADATA,sha256=V811hLDPjhaqPwy-9Y0w8GihtH11-OWtlmqHBwDfro8,5854
4
- x_evolution-0.1.24.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- x_evolution-0.1.24.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- x_evolution-0.1.24.dist-info/RECORD,,