x-evolution 0.1.25__py3-none-any.whl → 0.1.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_evolution/x_evolution.py +23 -15
- {x_evolution-0.1.25.dist-info → x_evolution-0.1.26.dist-info}/METADATA +2 -2
- x_evolution-0.1.26.dist-info/RECORD +6 -0
- x_evolution-0.1.25.dist-info/RECORD +0 -6
- {x_evolution-0.1.25.dist-info → x_evolution-0.1.26.dist-info}/WHEEL +0 -0
- {x_evolution-0.1.25.dist-info → x_evolution-0.1.26.dist-info}/licenses/LICENSE +0 -0
x_evolution/x_evolution.py
CHANGED
|
@@ -93,11 +93,16 @@ class EvoStrategy(Module):
|
|
|
93
93
|
verbose = True,
|
|
94
94
|
accelerator: Accelerator | None = None,
|
|
95
95
|
accelerate_kwargs: dict = dict(),
|
|
96
|
-
reject_generation_fitnesses_if: Callable[[Tensor], bool] | None = None
|
|
96
|
+
reject_generation_fitnesses_if: Callable[[Tensor], bool] | None = None,
|
|
97
|
+
vectorized = False,
|
|
98
|
+
vector_size: int | None = None
|
|
97
99
|
):
|
|
98
100
|
super().__init__()
|
|
99
101
|
self.verbose = verbose
|
|
100
102
|
|
|
103
|
+
self.vectorized = vectorized
|
|
104
|
+
self.vector_size = vector_size
|
|
105
|
+
|
|
101
106
|
if not exists(accelerator):
|
|
102
107
|
accelerator = Accelerator(cpu = cpu, **accelerate_kwargs)
|
|
103
108
|
|
|
@@ -475,24 +480,28 @@ class EvoStrategy(Module):
|
|
|
475
480
|
fitnesses.append([0., 0.] if self.mirror_sampling else 0.)
|
|
476
481
|
continue
|
|
477
482
|
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
# determine noise scale, which can be fixed or learned
|
|
483
|
+
def get_fitness(negate = False):
|
|
484
|
+
individual_param_seeds = with_seed(individual_seed.item())(randint)(0, MAX_SEED_VALUE, (self.num_params,))
|
|
485
|
+
noise_config = dict(zip(self.param_names_to_optimize, individual_param_seeds.tolist()))
|
|
483
486
|
|
|
484
|
-
|
|
487
|
+
noise_config_with_scale = dict()
|
|
488
|
+
for param_name, seed in noise_config.items():
|
|
489
|
+
noise_scale = self._get_noise_scale(param_name)
|
|
490
|
+
noise_config_with_scale[param_name] = (seed, noise_scale)
|
|
485
491
|
|
|
486
|
-
|
|
492
|
+
with model.temp_add_noise_(noise_config_with_scale, negate = negate):
|
|
493
|
+
fitness = with_seed(maybe_rollout_seed)(self.environment)(model)
|
|
487
494
|
|
|
488
|
-
|
|
495
|
+
if isinstance(fitness, Tensor) and fitness.numel() > 1:
|
|
496
|
+
fitness = fitness.mean().item()
|
|
497
|
+
elif isinstance(fitness, Tensor):
|
|
498
|
+
fitness = fitness.item()
|
|
489
499
|
|
|
490
|
-
|
|
500
|
+
return fitness
|
|
491
501
|
|
|
492
|
-
#
|
|
502
|
+
# evaluate
|
|
493
503
|
|
|
494
|
-
|
|
495
|
-
fitness = with_seed(maybe_rollout_seed)(rollout_for_fitness)()
|
|
504
|
+
fitness = get_fitness(negate = False)
|
|
496
505
|
|
|
497
506
|
if not self.mirror_sampling:
|
|
498
507
|
fitnesses.append(fitness)
|
|
@@ -500,8 +509,7 @@ class EvoStrategy(Module):
|
|
|
500
509
|
|
|
501
510
|
# handle mirror sampling
|
|
502
511
|
|
|
503
|
-
|
|
504
|
-
fitness_mirrored = with_seed(maybe_rollout_seed)(rollout_for_fitness)()
|
|
512
|
+
fitness_mirrored = get_fitness(negate = True)
|
|
505
513
|
|
|
506
514
|
fitnesses.append([fitness, fitness_mirrored])
|
|
507
515
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: x-evolution
|
|
3
|
-
Version: 0.1.25
|
|
3
|
+
Version: 0.1.26
|
|
4
4
|
Summary: x-evolution
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/x-evolution/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/x-evolution
|
|
@@ -38,7 +38,7 @@ Requires-Dist: accelerate
|
|
|
38
38
|
Requires-Dist: beartype
|
|
39
39
|
Requires-Dist: einops>=0.8.0
|
|
40
40
|
Requires-Dist: torch>=2.4
|
|
41
|
-
Requires-Dist: x-mlps-pytorch>=0.
|
|
41
|
+
Requires-Dist: x-mlps-pytorch>=0.2.0
|
|
42
42
|
Requires-Dist: x-transformers>=2.11.23
|
|
43
43
|
Provides-Extra: examples
|
|
44
44
|
Requires-Dist: gymnasium[box2d]>=1.0.0; extra == 'examples'
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
|
|
2
|
+
x_evolution/x_evolution.py,sha256=Jln3wpkIQIp7xwa3KMibh0kSuob1NIi7Aj7Miz8RJdY,19491
|
|
3
|
+
x_evolution-0.1.26.dist-info/METADATA,sha256=7zamSGrDtvOUQzpyYCkrt42FmWC4GPYpVSWUAO8a6OA,5853
|
|
4
|
+
x_evolution-0.1.26.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
+
x_evolution-0.1.26.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
x_evolution-0.1.26.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
|
|
2
|
-
x_evolution/x_evolution.py,sha256=lvN3ePqD6a5dW1gOv0d1I9yQ4rdv6OuIVvKvXa0yRBM,19126
|
|
3
|
-
x_evolution-0.1.25.dist-info/METADATA,sha256=1fAtssjj_t76rXwLu728z_ohlRZKQDhCm0oOl2eeIxA,5854
|
|
4
|
-
x_evolution-0.1.25.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
-
x_evolution-0.1.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
x_evolution-0.1.25.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|