x-evolution 0.1.19__tar.gz → 0.1.23__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_evolution-0.1.19 → x_evolution-0.1.23}/PKG-INFO +1 -1
- {x_evolution-0.1.19 → x_evolution-0.1.23}/pyproject.toml +1 -1
- {x_evolution-0.1.19 → x_evolution-0.1.23}/x_evolution/x_evolution.py +11 -0
- {x_evolution-0.1.19 → x_evolution-0.1.23}/.github/workflows/python-publish.yml +0 -0
- {x_evolution-0.1.19 → x_evolution-0.1.23}/.github/workflows/test.yml +0 -0
- {x_evolution-0.1.19 → x_evolution-0.1.23}/.gitignore +0 -0
- {x_evolution-0.1.19 → x_evolution-0.1.23}/LICENSE +0 -0
- {x_evolution-0.1.19 → x_evolution-0.1.23}/README.md +0 -0
- {x_evolution-0.1.19 → x_evolution-0.1.23}/tests/test_evolution.py +0 -0
- {x_evolution-0.1.19 → x_evolution-0.1.23}/train_lunar.py +0 -0
- {x_evolution-0.1.19 → x_evolution-0.1.23}/train_mnist.py +0 -0
- {x_evolution-0.1.19 → x_evolution-0.1.23}/train_xor.py +0 -0
- {x_evolution-0.1.19 → x_evolution-0.1.23}/x_evolution/__init__.py +0 -0
|
@@ -112,6 +112,8 @@ class EvoStrategy(Module):
|
|
|
112
112
|
self.print('pre_main_callback detected on environment passed in and is invoked')
|
|
113
113
|
environment.pre_main_callback()
|
|
114
114
|
|
|
115
|
+
accelerator.wait_for_everyone()
|
|
116
|
+
|
|
115
117
|
# take care of model and parameters
|
|
116
118
|
|
|
117
119
|
if isinstance(model, list):
|
|
@@ -120,6 +122,15 @@ class EvoStrategy(Module):
|
|
|
120
122
|
self.model = model
|
|
121
123
|
self.noisable_model = Noisable(model, low_rank = noise_low_rank)
|
|
122
124
|
|
|
125
|
+
# use prepare and run through environment once to sync params
|
|
126
|
+
|
|
127
|
+
wrapped_model = accelerator.prepare(model)
|
|
128
|
+
|
|
129
|
+
with torch.no_grad():
|
|
130
|
+
environment(wrapped_model)
|
|
131
|
+
|
|
132
|
+
# get param dictionary
|
|
133
|
+
|
|
123
134
|
named_parameters_dict = dict(model.named_parameters())
|
|
124
135
|
|
|
125
136
|
param_to_name_index = {param: name for name, param in named_parameters_dict.items()}
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|