metacontroller-pytorch 0.0.44 → 0.0.46 (py3-none-any wheels)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of metacontroller-pytorch might be problematic; see the registry's advisory page for this release for more details.

@@ -54,6 +54,19 @@ def default(*args):
54
54
  def straight_through(src, tgt):
55
55
  return tgt + src - src.detach()
56
56
 
57
+ # losses
58
+
59
+ BehavioralCloningLosses = namedtuple('BehavioralCloningLosses', (
60
+ 'state',
61
+ 'action'
62
+ ))
63
+
64
+ DiscoveryLosses = namedtuple('DiscoveryLosses', (
65
+ 'action_recon',
66
+ 'kl',
67
+ 'switch'
68
+ ))
69
+
57
70
  # meta controller
58
71
 
59
72
  MetaControllerOutput = namedtuple('MetaControllerOutput', (
@@ -450,7 +463,8 @@ class Transformer(Module):
450
463
  return_raw_action_dist = False,
451
464
  return_latents = False,
452
465
  return_cache = False,
453
- episode_lens: Tensor | None = None
466
+ episode_lens: Tensor | None = None,
467
+ return_meta_controller_output = False
454
468
  ):
455
469
  device = state.device
456
470
 
@@ -544,13 +558,23 @@ class Transformer(Module):
544
558
 
545
559
  action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions, mask = loss_mask)
546
560
 
547
- return state_clone_loss, action_clone_loss
561
+ losses = BehavioralCloningLosses(state_clone_loss, action_clone_loss)
562
+
563
+ if not return_meta_controller_output:
564
+ return losses
565
+
566
+ return losses, next_meta_hiddens
548
567
 
549
568
  elif discovery_phase:
550
569
 
551
570
  action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)
552
571
 
553
- return action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss
572
+ losses = DiscoveryLosses(action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss)
573
+
574
+ if not return_meta_controller_output:
575
+ return losses
576
+
577
+ return losses, next_meta_hiddens
554
578
 
555
579
  # returning
556
580
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.44
3
+ Version: 0.0.46
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -1,8 +1,8 @@
1
1
  metacontroller/__init__.py,sha256=iSKbCDp3UrWhZg7SIJFYNjdVQU56u-vqZarE6qCSX74,70
2
- metacontroller/metacontroller.py,sha256=hnIMuGi56bPGN9nm8J4IEqPfvRc4VvKBqDwJXwqUpyg,18399
2
+ metacontroller/metacontroller.py,sha256=lp5gd0x59UH7anrzRT8NfmckYJHALqVf8HInYqdgUoQ,18949
3
3
  metacontroller/metacontroller_with_binary_mapper.py,sha256=QJG-o4ZuEC6BkYff6nlyhVcze8o1q8GtbtCP8MLPrvc,9927
4
4
  metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
5
- metacontroller_pytorch-0.0.44.dist-info/METADATA,sha256=0EcZBbV06qKgBlSWgtrHfQizqCQE5AUZXLlhDXQA3xo,6816
6
- metacontroller_pytorch-0.0.44.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
- metacontroller_pytorch-0.0.44.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- metacontroller_pytorch-0.0.44.dist-info/RECORD,,
5
+ metacontroller_pytorch-0.0.46.dist-info/METADATA,sha256=NL9ib9WWJn0IDhYHENp3L8blfIDPGdrHqCnB7EyV3IA,6816
6
+ metacontroller_pytorch-0.0.46.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
+ metacontroller_pytorch-0.0.46.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ metacontroller_pytorch-0.0.46.dist-info/RECORD,,