x-evolution 0.1.17__tar.gz → 0.1.19__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-evolution
3
- Version: 0.1.17
3
+ Version: 0.1.19
4
4
  Summary: x-evolution
5
5
  Project-URL: Homepage, https://pypi.org/project/x-evolution/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-evolution
@@ -153,3 +153,4 @@ $ accelerate launch train.py
153
153
  }
154
154
  ```
155
155
 
156
+ *Nothing makes sense except in the light of evolution* - Theodosius Dobzhansky
@@ -104,3 +104,4 @@ $ accelerate launch train.py
104
104
  }
105
105
  ```
106
106
 
107
+ *Nothing makes sense except in the light of evolution* - Theodosius Dobzhansky
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "x-evolution"
3
- version = "0.1.17"
3
+ version = "0.1.19"
4
4
  description = "x-evolution"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -25,20 +25,25 @@ class LunarEnvironment(Module):
25
25
 
26
26
  env = gym.make('LunarLander-v3', render_mode = 'rgb_array')
27
27
 
28
- rmtree(video_folder, ignore_errors = True)
28
+ self.env = env
29
+ self.max_steps = max_steps
30
+ self.repeats = repeats
31
+ self.video_folder = video_folder
32
+ self.render_every_eps = render_every_eps
33
+
34
+ def pre_main_callback(self):
35
+ # the `pre_main_callback` on the environment passed in is called before the start of the evolutionary strategies loop
36
+
37
+ rmtree(self.video_folder, ignore_errors = True)
29
38
 
30
- env = gym.wrappers.RecordVideo(
31
- env = env,
32
- video_folder = video_folder,
39
+ self.env = gym.wrappers.RecordVideo(
40
+ env = self.env,
41
+ video_folder = self.video_folder,
33
42
  name_prefix = 'recording',
34
- episode_trigger = lambda eps_num: (eps_num % render_every_eps) == 0,
43
+ episode_trigger = lambda eps_num: (eps_num % self.render_every_eps) == 0,
35
44
  disable_logger = True
36
45
  )
37
46
 
38
- self.env = env
39
- self.max_steps = max_steps
40
- self.repeats = repeats
41
-
42
47
  def forward(self, model):
43
48
 
44
49
  device = next(model.parameters()).device
@@ -96,20 +96,30 @@ class EvoStrategy(Module):
96
96
  reject_generation_fitnesses_if: Callable[[Tensor], bool] | None = None
97
97
  ):
98
98
  super().__init__()
99
+ self.verbose = verbose
99
100
 
100
101
  if not exists(accelerator):
101
102
  accelerator = Accelerator(cpu = cpu, **accelerate_kwargs)
102
103
 
103
104
  self.accelerate = accelerator
104
105
 
106
+ # environment - with optional init
107
+
108
+ self.environment = environment
109
+
110
+ if accelerator.is_main_process:
111
+ if hasattr(environment, 'pre_main_callback') and callable(environment.pre_main_callback):
112
+ self.print('pre_main_callback detected on environment passed in and is invoked')
113
+ environment.pre_main_callback()
114
+
115
+ # take care of model and parameters
116
+
105
117
  if isinstance(model, list):
106
118
  model = ModuleList(model)
107
119
 
108
120
  self.model = model
109
121
  self.noisable_model = Noisable(model, low_rank = noise_low_rank)
110
122
 
111
- self.environment = environment
112
-
113
123
  named_parameters_dict = dict(model.named_parameters())
114
124
 
115
125
  param_to_name_index = {param: name for name, param in named_parameters_dict.items()}
@@ -217,10 +227,6 @@ class EvoStrategy(Module):
217
227
 
218
228
  self.reject_generation_fitnesses_if = reject_generation_fitnesses_if
219
229
 
220
- # verbose
221
-
222
- self.verbose = verbose
223
-
224
230
  # checkpointing
225
231
 
226
232
  self.checkpoint_every = checkpoint_every
File without changes
File without changes
File without changes