metacontroller-pytorch 0.0.6__tar.gz → 0.0.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {metacontroller_pytorch-0.0.6 → metacontroller_pytorch-0.0.8}/PKG-INFO +1 -1
- {metacontroller_pytorch-0.0.6 → metacontroller_pytorch-0.0.8}/metacontroller/metacontroller.py +4 -6
- {metacontroller_pytorch-0.0.6 → metacontroller_pytorch-0.0.8}/pyproject.toml +1 -1
- {metacontroller_pytorch-0.0.6 → metacontroller_pytorch-0.0.8}/tests/test_metacontroller.py +3 -0
- {metacontroller_pytorch-0.0.6 → metacontroller_pytorch-0.0.8}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.6 → metacontroller_pytorch-0.0.8}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.6 → metacontroller_pytorch-0.0.8}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.6 → metacontroller_pytorch-0.0.8}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.6 → metacontroller_pytorch-0.0.8}/README.md +0 -0
- {metacontroller_pytorch-0.0.6 → metacontroller_pytorch-0.0.8}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.6 → metacontroller_pytorch-0.0.8}/metacontroller/__init__.py +0 -0
{metacontroller_pytorch-0.0.6 → metacontroller_pytorch-0.0.8}/metacontroller/metacontroller.py
RENAMED
|
@@ -113,9 +113,7 @@ class MetaController(Module):
|
|
|
113
113
|
def internal_rl_parameters(self):
|
|
114
114
|
return [
|
|
115
115
|
*self.action_proposer.parameters(),
|
|
116
|
-
*self.action_proposer_mean_log_var.parameters()
|
|
117
|
-
*self.decoder.parameters(),
|
|
118
|
-
*self.switch_gating
|
|
116
|
+
*self.action_proposer_mean_log_var.parameters()
|
|
119
117
|
]
|
|
120
118
|
|
|
121
119
|
def forward(
|
|
@@ -150,8 +148,6 @@ class MetaController(Module):
|
|
|
150
148
|
|
|
151
149
|
switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
|
|
152
150
|
|
|
153
|
-
# switch_beta = switch_beta.expand_as(sampled_action)
|
|
154
|
-
|
|
155
151
|
# need to encourage normal distribution
|
|
156
152
|
|
|
157
153
|
vae_kl_loss = self.zero
|
|
@@ -233,13 +229,15 @@ class Transformer(Module):
|
|
|
233
229
|
|
|
234
230
|
def evolve(
|
|
235
231
|
self,
|
|
232
|
+
num_generations,
|
|
236
233
|
environment,
|
|
237
234
|
**kwargs
|
|
238
235
|
):
|
|
239
|
-
assert exists(self.meta_controller), '`meta_controller` must be defined on init for evolutionary strategies to be straightforwardly applied'
|
|
236
|
+
assert exists(self.meta_controller), '`meta_controller` must be passed in or defined on init for evolutionary strategies to be straightforwardly applied'
|
|
240
237
|
|
|
241
238
|
evo_strat = EvoStrategy(
|
|
242
239
|
self,
|
|
240
|
+
num_generations = num_generations,
|
|
243
241
|
environment = environment,
|
|
244
242
|
params_to_optimize = self.meta_controller.internal_rl_parameters(),
|
|
245
243
|
**kwargs
|
|
@@ -29,3 +29,6 @@ def test_metacontroller(
|
|
|
29
29
|
logits = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase)
|
|
30
30
|
|
|
31
31
|
assert logits.shape == (1, 1024, 256)
|
|
32
|
+
|
|
33
|
+
model.meta_controller = meta_controller
|
|
34
|
+
model.evolve(1, lambda _: 1., noise_population_size = 2)
|
{metacontroller_pytorch-0.0.6 → metacontroller_pytorch-0.0.8}/.github/workflows/python-publish.yml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|