metacontroller-pytorch 0.0.42__py3-none-any.whl → 0.0.44__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metacontroller/metacontroller.py +13 -2
- metacontroller/metacontroller_with_binary_mapper.py +1 -1
- {metacontroller_pytorch-0.0.42.dist-info → metacontroller_pytorch-0.0.44.dist-info}/METADATA +1 -1
- metacontroller_pytorch-0.0.44.dist-info/RECORD +8 -0
- metacontroller_pytorch-0.0.42.dist-info/RECORD +0 -8
- {metacontroller_pytorch-0.0.42.dist-info → metacontroller_pytorch-0.0.44.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.42.dist-info → metacontroller_pytorch-0.0.44.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
|
@@ -291,7 +291,7 @@ class MetaController(Module):
|
|
|
291
291
|
else:
|
|
292
292
|
# else during inference, use the previous sampled latent action
|
|
293
293
|
|
|
294
|
-
assert seq_len == 1,
|
|
294
|
+
assert seq_len == 1, 'inference RL phase must be done one token at a time - if replaying for policy optimization, please use `get_action_dist_for_internal_rl`'
|
|
295
295
|
z_prev = prev_sampled_latent_action
|
|
296
296
|
|
|
297
297
|
# switch input is previous latent action and the embedding
|
|
@@ -407,10 +407,19 @@ class Transformer(Module):
|
|
|
407
407
|
|
|
408
408
|
# meta controller
|
|
409
409
|
|
|
410
|
-
self.meta_controller = meta_controller
|
|
410
|
+
self.meta_controller = meta_controller
|
|
411
411
|
|
|
412
412
|
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
413
413
|
|
|
414
|
+
# ensure devices match
|
|
415
|
+
|
|
416
|
+
if exists(self.meta_controller): self._ensure_consistent_device(self.meta_controller)
|
|
417
|
+
|
|
418
|
+
def _ensure_consistent_device(self, network):
|
|
419
|
+
self.model_device = next(self.parameters()).device
|
|
420
|
+
if next(network.parameters()).device != self.model_device:
|
|
421
|
+
network.to(self.model_device)
|
|
422
|
+
|
|
414
423
|
def evolve(
|
|
415
424
|
self,
|
|
416
425
|
num_generations,
|
|
@@ -447,6 +456,8 @@ class Transformer(Module):
|
|
|
447
456
|
|
|
448
457
|
# meta controller is either given or already given at init
|
|
449
458
|
|
|
459
|
+
if exists(meta_controller): self._ensure_consistent_device(meta_controller)
|
|
460
|
+
|
|
450
461
|
meta_controller = default(meta_controller, self.meta_controller)
|
|
451
462
|
|
|
452
463
|
if force_behavior_cloning:
|
|
@@ -241,7 +241,7 @@ class MetaControllerWithBinaryMapper(Module):
|
|
|
241
241
|
if discovery_phase:
|
|
242
242
|
z_prev = cat((prev_sampled_code, sampled_codes[:, :-1]), dim = 1)
|
|
243
243
|
else:
|
|
244
|
-
assert seq_len == 1,
|
|
244
|
+
assert seq_len == 1, 'inference RL phase must be done one token at a time - if replaying for policy optimization, please use `get_action_dist_for_internal_rl`'
|
|
245
245
|
z_prev = prev_sampled_code
|
|
246
246
|
|
|
247
247
|
switch_input = torch.cat((meta_embed, z_prev), dim=-1)
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
metacontroller/__init__.py,sha256=iSKbCDp3UrWhZg7SIJFYNjdVQU56u-vqZarE6qCSX74,70
|
|
2
|
+
metacontroller/metacontroller.py,sha256=hnIMuGi56bPGN9nm8J4IEqPfvRc4VvKBqDwJXwqUpyg,18399
|
|
3
|
+
metacontroller/metacontroller_with_binary_mapper.py,sha256=QJG-o4ZuEC6BkYff6nlyhVcze8o1q8GtbtCP8MLPrvc,9927
|
|
4
|
+
metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
|
|
5
|
+
metacontroller_pytorch-0.0.44.dist-info/METADATA,sha256=0EcZBbV06qKgBlSWgtrHfQizqCQE5AUZXLlhDXQA3xo,6816
|
|
6
|
+
metacontroller_pytorch-0.0.44.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
7
|
+
metacontroller_pytorch-0.0.44.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
metacontroller_pytorch-0.0.44.dist-info/RECORD,,
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
metacontroller/__init__.py,sha256=iSKbCDp3UrWhZg7SIJFYNjdVQU56u-vqZarE6qCSX74,70
|
|
2
|
-
metacontroller/metacontroller.py,sha256=hOzMIeBwNZhrzpt6tnLahxuHJ4pPQ7JlEGBOxYHI_88,17875
|
|
3
|
-
metacontroller/metacontroller_with_binary_mapper.py,sha256=Ce5-O95_pLuWNA3aZTlKrTGbc5cemb61tBtJBdSiLx4,9843
|
|
4
|
-
metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
|
|
5
|
-
metacontroller_pytorch-0.0.42.dist-info/METADATA,sha256=f9KRrtFWHgZrx5HZYBGNrtfXrcfSOeZlRFfx7VYMOd0,6816
|
|
6
|
-
metacontroller_pytorch-0.0.42.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
7
|
-
metacontroller_pytorch-0.0.42.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
-
metacontroller_pytorch-0.0.42.dist-info/RECORD,,
|
|
File without changes
|
{metacontroller_pytorch-0.0.42.dist-info → metacontroller_pytorch-0.0.44.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|