metacontroller-pytorch 0.0.36__py3-none-any.whl → 0.0.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of metacontroller-pytorch might be problematic.
- metacontroller/metacontroller.py +7 -2
- {metacontroller_pytorch-0.0.36.dist-info → metacontroller_pytorch-0.0.37.dist-info}/METADATA +1 -1
- {metacontroller_pytorch-0.0.36.dist-info → metacontroller_pytorch-0.0.37.dist-info}/RECORD +5 -5
- {metacontroller_pytorch-0.0.36.dist-info → metacontroller_pytorch-0.0.37.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.36.dist-info → metacontroller_pytorch-0.0.37.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
@@ -408,6 +408,7 @@ class Transformer(Module):
         meta_controller: Module | None = None,
         cache: TransformerOutput | None = None,
         discovery_phase = False,
+        force_behavior_cloning = False,
         meta_controller_temperature = 1.,
         return_raw_action_dist = False,
         return_latents = False,
@@ -420,11 +421,15 @@ class Transformer(Module):
 
         meta_controller = default(meta_controller, self.meta_controller)
 
+        if force_behavior_cloning:
+            assert not discovery_phase, 'discovery phase cannot be set to True if force behavioral cloning is set to True'
+            meta_controller = None
+
         has_meta_controller = exists(meta_controller)
 
         assert not (discovery_phase and not has_meta_controller), 'meta controller must be made available during discovery phase'
 
-        behavioral_cloning = not has_meta_controller and not return_raw_action_dist
+        behavioral_cloning = force_behavior_cloning or (not has_meta_controller and not return_raw_action_dist)
 
         # by default, if meta controller is passed in, transformer is no grad
 
@@ -472,7 +477,7 @@ class Transformer(Module):
 
         with meta_controller_context():
 
-            if exists(meta_controller):
+            if exists(meta_controller) and not behavioral_cloning:
                 control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature, episode_lens = episode_lens)
             else:
                 control_signal, next_meta_hiddens = self.zero, None
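The functional change in this release is a new force_behavior_cloning keyword on the transformer's forward pass. When set, it asserts that discovery_phase is off, discards any meta controller (whether passed in or attached as self.meta_controller), and forces the behavioral cloning branch, which the last hunk also guards against explicitly inside meta_controller_context(). The standalone sketch below re-derives that gating logic outside the package to show how the flag interacts with the other arguments; resolve_mode is a hypothetical helper written for illustration, not an API of the library.

# Standalone sketch of the new gating logic in this diff. resolve_mode is a
# hypothetical helper written for illustration, not an API of the package.

def resolve_mode(meta_controller = None, discovery_phase = False, force_behavior_cloning = False, return_raw_action_dist = False):
    # new in 0.0.37: forcing behavioral cloning is incompatible with the discovery
    # phase and discards any meta controller that was attached or passed in
    if force_behavior_cloning:
        assert not discovery_phase
        meta_controller = None

    has_meta_controller = meta_controller is not None
    assert not (discovery_phase and not has_meta_controller)

    # behavioral cloning is entered either explicitly (forced) or implicitly
    # (no meta controller and raw action distribution not requested)
    behavioral_cloning = force_behavior_cloning or (not has_meta_controller and not return_raw_action_dist)

    # mirrors the last hunk: the meta controller is only invoked when it exists
    # and behavioral cloning is off
    use_meta_controller = has_meta_controller and not behavioral_cloning
    return behavioral_cloning, use_meta_controller

# an attached meta controller is ignored once cloning is forced
assert resolve_mode(meta_controller = object(), force_behavior_cloning = True) == (True, False)
# without the flag, the same controller is used and cloning mode stays off
assert resolve_mode(meta_controller = object()) == (False, True)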
{metacontroller_pytorch-0.0.36.dist-info → metacontroller_pytorch-0.0.37.dist-info}/RECORD
CHANGED
@@ -1,8 +1,8 @@
 metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
-metacontroller/metacontroller.py,sha256=
+metacontroller/metacontroller.py,sha256=3cVg4TkfD8bkuED0mcGcfAEjJujcJ9tf_qMB8ict12c,17017
 metacontroller/metacontroller_with_binary_mapper.py,sha256=odZs49ZWY7_FfEweYkD0moX7Vn0jGd91FjFTxzjLyr8,9480
 metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
-metacontroller_pytorch-0.0.
-metacontroller_pytorch-0.0.
-metacontroller_pytorch-0.0.
-metacontroller_pytorch-0.0.
+metacontroller_pytorch-0.0.37.dist-info/METADATA,sha256=4mkDBWI-ma5TR38PpAIxEKj6VlVlQOYvJQAGVswQ3IQ,4747
+metacontroller_pytorch-0.0.37.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.37.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.37.dist-info/RECORD,,
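For reference, the entries in the RECORD diff above follow the wheel RECORD format: file path, then the SHA-256 digest encoded as urlsafe base64 with the trailing padding stripped, then the file size in bytes. A short sketch of how such an entry can be recomputed for a file extracted from the wheel (the path below is illustrative):

import base64
import hashlib

def record_hash(path):
    # wheel RECORD files store sha256 digests as urlsafe base64 without '=' padding
    with open(path, 'rb') as f:
        digest = hashlib.sha256(f.read()).digest()
    return 'sha256=' + base64.urlsafe_b64encode(digest).rstrip(b'=').decode()

# record_hash('metacontroller/metacontroller.py'), run against the file unpacked from
# the 0.0.37 wheel, should reproduce the sha256=3cVg4Tkf... entry above (17017 bytes)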
{metacontroller_pytorch-0.0.36.dist-info → metacontroller_pytorch-0.0.37.dist-info}/WHEEL
RENAMED
File without changes
{metacontroller_pytorch-0.0.36.dist-info → metacontroller_pytorch-0.0.37.dist-info}/licenses/LICENSE
RENAMED
File without changes