metacontroller-pytorch 0.0.43__py3-none-any.whl → 0.0.46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metacontroller/metacontroller.py +39 -4
- {metacontroller_pytorch-0.0.43.dist-info → metacontroller_pytorch-0.0.46.dist-info}/METADATA +1 -1
- {metacontroller_pytorch-0.0.43.dist-info → metacontroller_pytorch-0.0.46.dist-info}/RECORD +5 -5
- {metacontroller_pytorch-0.0.43.dist-info → metacontroller_pytorch-0.0.46.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.43.dist-info → metacontroller_pytorch-0.0.46.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
|
@@ -54,6 +54,19 @@ def default(*args):
|
|
|
54
54
|
def straight_through(src, tgt):
|
|
55
55
|
return tgt + src - src.detach()
|
|
56
56
|
|
|
57
|
+
# losses
|
|
58
|
+
|
|
59
|
+
BehavioralCloningLosses = namedtuple('BehavioralCloningLosses', (
|
|
60
|
+
'state',
|
|
61
|
+
'action'
|
|
62
|
+
))
|
|
63
|
+
|
|
64
|
+
DiscoveryLosses = namedtuple('DiscoveryLosses', (
|
|
65
|
+
'action_recon',
|
|
66
|
+
'kl',
|
|
67
|
+
'switch'
|
|
68
|
+
))
|
|
69
|
+
|
|
57
70
|
# meta controller
|
|
58
71
|
|
|
59
72
|
MetaControllerOutput = namedtuple('MetaControllerOutput', (
|
|
@@ -407,10 +420,19 @@ class Transformer(Module):
|
|
|
407
420
|
|
|
408
421
|
# meta controller
|
|
409
422
|
|
|
410
|
-
self.meta_controller = meta_controller
|
|
423
|
+
self.meta_controller = meta_controller
|
|
411
424
|
|
|
412
425
|
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
413
426
|
|
|
427
|
+
# ensure devices match
|
|
428
|
+
|
|
429
|
+
if exists(self.meta_controller): self._ensure_consistent_device(self.meta_controller)
|
|
430
|
+
|
|
431
|
+
def _ensure_consistent_device(self, network):
|
|
432
|
+
self.model_device = next(self.parameters()).device
|
|
433
|
+
if next(network.parameters()).device != self.model_device:
|
|
434
|
+
network.to(self.model_device)
|
|
435
|
+
|
|
414
436
|
def evolve(
|
|
415
437
|
self,
|
|
416
438
|
num_generations,
|
|
@@ -441,12 +463,15 @@ class Transformer(Module):
|
|
|
441
463
|
return_raw_action_dist = False,
|
|
442
464
|
return_latents = False,
|
|
443
465
|
return_cache = False,
|
|
444
|
-
episode_lens: Tensor | None = None
|
|
466
|
+
episode_lens: Tensor | None = None,
|
|
467
|
+
return_meta_controller_output = False
|
|
445
468
|
):
|
|
446
469
|
device = state.device
|
|
447
470
|
|
|
448
471
|
# meta controller is either given or already given at init
|
|
449
472
|
|
|
473
|
+
if exists(meta_controller): self._ensure_consistent_device(meta_controller)
|
|
474
|
+
|
|
450
475
|
meta_controller = default(meta_controller, self.meta_controller)
|
|
451
476
|
|
|
452
477
|
if force_behavior_cloning:
|
|
@@ -533,13 +558,23 @@ class Transformer(Module):
|
|
|
533
558
|
|
|
534
559
|
action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions, mask = loss_mask)
|
|
535
560
|
|
|
536
|
-
|
|
561
|
+
losses = BehavioralCloningLosses(state_clone_loss, action_clone_loss)
|
|
562
|
+
|
|
563
|
+
if not return_meta_controller_output:
|
|
564
|
+
return losses
|
|
565
|
+
|
|
566
|
+
return losses, next_meta_hiddens
|
|
537
567
|
|
|
538
568
|
elif discovery_phase:
|
|
539
569
|
|
|
540
570
|
action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)
|
|
541
571
|
|
|
542
|
-
|
|
572
|
+
losses = DiscoveryLosses(action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss)
|
|
573
|
+
|
|
574
|
+
if not return_meta_controller_output:
|
|
575
|
+
return losses
|
|
576
|
+
|
|
577
|
+
return losses, next_meta_hiddens
|
|
543
578
|
|
|
544
579
|
# returning
|
|
545
580
|
|
|
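{metacontroller_pytorch-0.0.43.dist-info → metacontroller_pytorch-0.0.46.dist-info}/RECORD
CHANGED
|
@@ -1,8 +1,8 @@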
|
|
|
1
1
|
metacontroller/__init__.py,sha256=iSKbCDp3UrWhZg7SIJFYNjdVQU56u-vqZarE6qCSX74,70
|
|
2
|
-
metacontroller/metacontroller.py,sha256=
|
|
2
|
+
metacontroller/metacontroller.py,sha256=lp5gd0x59UH7anrzRT8NfmckYJHALqVf8HInYqdgUoQ,18949
|
|
3
3
|
metacontroller/metacontroller_with_binary_mapper.py,sha256=QJG-o4ZuEC6BkYff6nlyhVcze8o1q8GtbtCP8MLPrvc,9927
|
|
4
4
|
metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
|
|
5
|
-
metacontroller_pytorch-0.0.
|
|
6
|
-
metacontroller_pytorch-0.0.
|
|
7
|
-
metacontroller_pytorch-0.0.
|
|
8
|
-
metacontroller_pytorch-0.0.
|
|
5
|
+
metacontroller_pytorch-0.0.46.dist-info/METADATA,sha256=NL9ib9WWJn0IDhYHENp3L8blfIDPGdrHqCnB7EyV3IA,6816
|
|
6
|
+
metacontroller_pytorch-0.0.46.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
7
|
+
metacontroller_pytorch-0.0.46.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
metacontroller_pytorch-0.0.46.dist-info/RECORD,,
|
|
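{metacontroller_pytorch-0.0.43.dist-info → metacontroller_pytorch-0.0.46.dist-info}/WHEEL
RENAMED
|
File without changes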
|
{metacontroller_pytorch-0.0.43.dist-info → metacontroller_pytorch-0.0.46.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|