metacontroller-pytorch 0.0.36__py3-none-any.whl → 0.0.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of metacontroller-pytorch might be problematic. Click here for more details.

@@ -408,6 +408,7 @@ class Transformer(Module):
408
408
  meta_controller: Module | None = None,
409
409
  cache: TransformerOutput | None = None,
410
410
  discovery_phase = False,
411
+ force_behavior_cloning = False,
411
412
  meta_controller_temperature = 1.,
412
413
  return_raw_action_dist = False,
413
414
  return_latents = False,
@@ -420,11 +421,15 @@ class Transformer(Module):
420
421
 
421
422
  meta_controller = default(meta_controller, self.meta_controller)
422
423
 
424
+ if force_behavior_cloning:
425
+ assert not discovery_phase, 'discovery phase cannot be set to True if force behavioral cloning is set to True'
426
+ meta_controller = None
427
+
423
428
  has_meta_controller = exists(meta_controller)
424
429
 
425
430
  assert not (discovery_phase and not has_meta_controller), 'meta controller must be made available during discovery phase'
426
431
 
427
- behavioral_cloning = not has_meta_controller and not return_raw_action_dist
432
+ behavioral_cloning = force_behavior_cloning or (not has_meta_controller and not return_raw_action_dist)
428
433
 
429
434
  # by default, if meta controller is passed in, transformer is no grad
430
435
 
@@ -472,7 +477,7 @@ class Transformer(Module):
472
477
 
473
478
  with meta_controller_context():
474
479
 
475
- if exists(meta_controller):
480
+ if exists(meta_controller) and not behavioral_cloning:
476
481
  control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature, episode_lens = episode_lens)
477
482
  else:
478
483
  control_signal, next_meta_hiddens = self.zero, None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.36
3
+ Version: 0.0.37
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -1,8 +1,8 @@
1
1
  metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=rVYzBJ8jQx9tfkZ3B9NdxKTI7dyBxXtTl4kwfizYuis,16728
2
+ metacontroller/metacontroller.py,sha256=3cVg4TkfD8bkuED0mcGcfAEjJujcJ9tf_qMB8ict12c,17017
3
3
  metacontroller/metacontroller_with_binary_mapper.py,sha256=odZs49ZWY7_FfEweYkD0moX7Vn0jGd91FjFTxzjLyr8,9480
4
4
  metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
5
- metacontroller_pytorch-0.0.36.dist-info/METADATA,sha256=eLKG8B0gSZyIMkaLvjYE8SvWVN387BuuNFoOC_6lmT4,4747
6
- metacontroller_pytorch-0.0.36.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
- metacontroller_pytorch-0.0.36.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- metacontroller_pytorch-0.0.36.dist-info/RECORD,,
5
+ metacontroller_pytorch-0.0.37.dist-info/METADATA,sha256=4mkDBWI-ma5TR38PpAIxEKj6VlVlQOYvJQAGVswQ3IQ,4747
6
+ metacontroller_pytorch-0.0.37.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
+ metacontroller_pytorch-0.0.37.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ metacontroller_pytorch-0.0.37.dist-info/RECORD,,