x-evolution 0.1.25__py3-none-any.whl → 0.1.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_evolution/x_evolution.py +39 -20
- {x_evolution-0.1.25.dist-info → x_evolution-0.1.27.dist-info}/METADATA +2 -2
- x_evolution-0.1.27.dist-info/RECORD +6 -0
- x_evolution-0.1.25.dist-info/RECORD +0 -6
- {x_evolution-0.1.25.dist-info → x_evolution-0.1.27.dist-info}/WHEEL +0 -0
- {x_evolution-0.1.25.dist-info → x_evolution-0.1.27.dist-info}/licenses/LICENSE +0 -0
x_evolution/x_evolution.py
CHANGED
|
@@ -11,6 +11,7 @@ from torch.nn import Module, ModuleList, Parameter, ParameterList
|
|
|
11
11
|
from torch.optim import SGD, Adam, Optimizer
|
|
12
12
|
from torch.optim.lr_scheduler import LRScheduler
|
|
13
13
|
|
|
14
|
+
import torch.distributed as dist
|
|
14
15
|
import torch.nn.functional as F
|
|
15
16
|
|
|
16
17
|
from beartype import beartype
|
|
@@ -93,11 +94,17 @@ class EvoStrategy(Module):
|
|
|
93
94
|
verbose = True,
|
|
94
95
|
accelerator: Accelerator | None = None,
|
|
95
96
|
accelerate_kwargs: dict = dict(),
|
|
96
|
-
reject_generation_fitnesses_if: Callable[[Tensor], bool] | None = None
|
|
97
|
+
reject_generation_fitnesses_if: Callable[[Tensor], bool] | None = None,
|
|
98
|
+
vectorized = False,
|
|
99
|
+
vector_size: int | None = None,
|
|
100
|
+
sync_on_init = True
|
|
97
101
|
):
|
|
98
102
|
super().__init__()
|
|
99
103
|
self.verbose = verbose
|
|
100
104
|
|
|
105
|
+
self.vectorized = vectorized
|
|
106
|
+
self.vector_size = vector_size
|
|
107
|
+
|
|
101
108
|
if not exists(accelerator):
|
|
102
109
|
accelerator = Accelerator(cpu = cpu, **accelerate_kwargs)
|
|
103
110
|
|
|
@@ -122,12 +129,10 @@ class EvoStrategy(Module):
|
|
|
122
129
|
self.model = model
|
|
123
130
|
self.noisable_model = Noisable(model, low_rank = noise_low_rank)
|
|
124
131
|
|
|
125
|
-
#
|
|
126
|
-
|
|
127
|
-
wrapped_model = accelerator.prepare(model)
|
|
132
|
+
# maybe sync model params and buffers
|
|
128
133
|
|
|
129
|
-
|
|
130
|
-
|
|
134
|
+
if sync_on_init:
|
|
135
|
+
self.sync_model_params_and_buffers_()
|
|
131
136
|
|
|
132
137
|
# get param dictionary
|
|
133
138
|
|
|
@@ -249,6 +254,17 @@ class EvoStrategy(Module):
|
|
|
249
254
|
def device(self):
|
|
250
255
|
return self.accelerate.device
|
|
251
256
|
|
|
257
|
+
@torch.no_grad()
|
|
258
|
+
def sync_model_params_and_buffers_(self):
|
|
259
|
+
if not self.accelerate.num_processes > 1:
|
|
260
|
+
return
|
|
261
|
+
|
|
262
|
+
for param in self.model.parameters():
|
|
263
|
+
dist.broadcast(param, src = 0)
|
|
264
|
+
|
|
265
|
+
for buffer in self.model.buffers():
|
|
266
|
+
dist.broadcast(buffer, src = 0)
|
|
267
|
+
|
|
252
268
|
def print(self, *args, **kwargs):
|
|
253
269
|
if not self.verbose:
|
|
254
270
|
return
|
|
@@ -475,24 +491,28 @@ class EvoStrategy(Module):
|
|
|
475
491
|
fitnesses.append([0., 0.] if self.mirror_sampling else 0.)
|
|
476
492
|
continue
|
|
477
493
|
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
# determine noise scale, which can be fixed or learned
|
|
494
|
+
def get_fitness(negate = False):
|
|
495
|
+
individual_param_seeds = with_seed(individual_seed.item())(randint)(0, MAX_SEED_VALUE, (self.num_params,))
|
|
496
|
+
noise_config = dict(zip(self.param_names_to_optimize, individual_param_seeds.tolist()))
|
|
483
497
|
|
|
484
|
-
|
|
498
|
+
noise_config_with_scale = dict()
|
|
499
|
+
for param_name, seed in noise_config.items():
|
|
500
|
+
noise_scale = self._get_noise_scale(param_name)
|
|
501
|
+
noise_config_with_scale[param_name] = (seed, noise_scale)
|
|
485
502
|
|
|
486
|
-
|
|
503
|
+
with model.temp_add_noise_(noise_config_with_scale, negate = negate):
|
|
504
|
+
fitness = with_seed(maybe_rollout_seed)(self.environment)(model)
|
|
487
505
|
|
|
488
|
-
|
|
506
|
+
if isinstance(fitness, Tensor) and fitness.numel() > 1:
|
|
507
|
+
fitness = fitness.mean().item()
|
|
508
|
+
elif isinstance(fitness, Tensor):
|
|
509
|
+
fitness = fitness.item()
|
|
489
510
|
|
|
490
|
-
|
|
511
|
+
return fitness
|
|
491
512
|
|
|
492
|
-
#
|
|
513
|
+
# evaluate
|
|
493
514
|
|
|
494
|
-
|
|
495
|
-
fitness = with_seed(maybe_rollout_seed)(rollout_for_fitness)()
|
|
515
|
+
fitness = get_fitness(negate = False)
|
|
496
516
|
|
|
497
517
|
if not self.mirror_sampling:
|
|
498
518
|
fitnesses.append(fitness)
|
|
@@ -500,8 +520,7 @@ class EvoStrategy(Module):
|
|
|
500
520
|
|
|
501
521
|
# handle mirror sampling
|
|
502
522
|
|
|
503
|
-
|
|
504
|
-
fitness_mirrored = with_seed(maybe_rollout_seed)(rollout_for_fitness)()
|
|
523
|
+
fitness_mirrored = get_fitness(negate = True)
|
|
505
524
|
|
|
506
525
|
fitnesses.append([fitness, fitness_mirrored])
|
|
507
526
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: x-evolution
|
|
3
|
-
Version: 0.1.25
|
|
3
|
+
Version: 0.1.27
|
|
4
4
|
Summary: x-evolution
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/x-evolution/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/x-evolution
|
|
@@ -38,7 +38,7 @@ Requires-Dist: accelerate
|
|
|
38
38
|
Requires-Dist: beartype
|
|
39
39
|
Requires-Dist: einops>=0.8.0
|
|
40
40
|
Requires-Dist: torch>=2.4
|
|
41
|
-
Requires-Dist: x-mlps-pytorch>=0.
|
|
41
|
+
Requires-Dist: x-mlps-pytorch>=0.2.0
|
|
42
42
|
Requires-Dist: x-transformers>=2.11.23
|
|
43
43
|
Provides-Extra: examples
|
|
44
44
|
Requires-Dist: gymnasium[box2d]>=1.0.0; extra == 'examples'
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
|
|
2
|
+
x_evolution/x_evolution.py,sha256=-G5qXGMjwVzdkxIDR6xL_YGium4KfKC0cnlY76Upy0o,19799
|
|
3
|
+
x_evolution-0.1.27.dist-info/METADATA,sha256=hj0MUpIGVWoOY5wHsoy_ZF_cx7s48_HZicd4IgNUFEo,5853
|
|
4
|
+
x_evolution-0.1.27.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
+
x_evolution-0.1.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
x_evolution-0.1.27.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
|
|
2
|
-
x_evolution/x_evolution.py,sha256=lvN3ePqD6a5dW1gOv0d1I9yQ4rdv6OuIVvKvXa0yRBM,19126
|
|
3
|
-
x_evolution-0.1.25.dist-info/METADATA,sha256=1fAtssjj_t76rXwLu728z_ohlRZKQDhCm0oOl2eeIxA,5854
|
|
4
|
-
x_evolution-0.1.25.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
-
x_evolution-0.1.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
x_evolution-0.1.25.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|