metacontroller-pytorch 0.0.43__py3-none-any.whl → 0.0.46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -54,6 +54,19 @@ def default(*args):
54
54
  def straight_through(src, tgt):
55
55
  return tgt + src - src.detach()
56
56
 
57
+ # losses
58
+
59
+ BehavioralCloningLosses = namedtuple('BehavioralCloningLosses', (
60
+ 'state',
61
+ 'action'
62
+ ))
63
+
64
+ DiscoveryLosses = namedtuple('DiscoveryLosses', (
65
+ 'action_recon',
66
+ 'kl',
67
+ 'switch'
68
+ ))
69
+
57
70
  # meta controller
58
71
 
59
72
  MetaControllerOutput = namedtuple('MetaControllerOutput', (
@@ -407,10 +420,19 @@ class Transformer(Module):
407
420
 
408
421
  # meta controller
409
422
 
410
- self.meta_controller = meta_controller
423
+ self.meta_controller = meta_controller
411
424
 
412
425
  self.register_buffer('zero', tensor(0.), persistent = False)
413
426
 
427
+ # ensure devices match
428
+
429
+ if exists(self.meta_controller): self._ensure_consistent_device(self.meta_controller)
430
+
431
+ def _ensure_consistent_device(self, network):
432
+ self.model_device = next(self.parameters()).device
433
+ if next(network.parameters()).device != self.model_device:
434
+ network.to(self.model_device)
435
+
414
436
  def evolve(
415
437
  self,
416
438
  num_generations,
@@ -441,12 +463,15 @@ class Transformer(Module):
441
463
  return_raw_action_dist = False,
442
464
  return_latents = False,
443
465
  return_cache = False,
444
- episode_lens: Tensor | None = None
466
+ episode_lens: Tensor | None = None,
467
+ return_meta_controller_output = False
445
468
  ):
446
469
  device = state.device
447
470
 
448
471
  # meta controller is either given or already given at init
449
472
 
473
+ if exists(meta_controller): self._ensure_consistent_device(meta_controller)
474
+
450
475
  meta_controller = default(meta_controller, self.meta_controller)
451
476
 
452
477
  if force_behavior_cloning:
@@ -533,13 +558,23 @@ class Transformer(Module):
533
558
 
534
559
  action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions, mask = loss_mask)
535
560
 
536
- return state_clone_loss, action_clone_loss
561
+ losses = BehavioralCloningLosses(state_clone_loss, action_clone_loss)
562
+
563
+ if not return_meta_controller_output:
564
+ return losses
565
+
566
+ return losses, next_meta_hiddens
537
567
 
538
568
  elif discovery_phase:
539
569
 
540
570
  action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)
541
571
 
542
- return action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss
572
+ losses = DiscoveryLosses(action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss)
573
+
574
+ if not return_meta_controller_output:
575
+ return losses
576
+
577
+ return losses, next_meta_hiddens
543
578
 
544
579
  # returning
545
580
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.43
3
+ Version: 0.0.46
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -1,8 +1,8 @@
1
1
  metacontroller/__init__.py,sha256=iSKbCDp3UrWhZg7SIJFYNjdVQU56u-vqZarE6qCSX74,70
2
- metacontroller/metacontroller.py,sha256=3L5vPXFm7WbYuuWuCd9RpkMVC--mNpAlsJxpLwMeQ8M,17959
2
+ metacontroller/metacontroller.py,sha256=lp5gd0x59UH7anrzRT8NfmckYJHALqVf8HInYqdgUoQ,18949
3
3
  metacontroller/metacontroller_with_binary_mapper.py,sha256=QJG-o4ZuEC6BkYff6nlyhVcze8o1q8GtbtCP8MLPrvc,9927
4
4
  metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
5
- metacontroller_pytorch-0.0.43.dist-info/METADATA,sha256=y1CIKBvy4jTaqsC_FBfiA1FK7pQrqTasiwA96M83luc,6816
6
- metacontroller_pytorch-0.0.43.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
- metacontroller_pytorch-0.0.43.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- metacontroller_pytorch-0.0.43.dist-info/RECORD,,
5
+ metacontroller_pytorch-0.0.46.dist-info/METADATA,sha256=NL9ib9WWJn0IDhYHENp3L8blfIDPGdrHqCnB7EyV3IA,6816
6
+ metacontroller_pytorch-0.0.46.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
+ metacontroller_pytorch-0.0.46.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ metacontroller_pytorch-0.0.46.dist-info/RECORD,,