metacontroller-pytorch 0.0.43__py3-none-any.whl → 0.0.44__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metacontroller/metacontroller.py +12 -1
- {metacontroller_pytorch-0.0.43.dist-info → metacontroller_pytorch-0.0.44.dist-info}/METADATA +1 -1
- {metacontroller_pytorch-0.0.43.dist-info → metacontroller_pytorch-0.0.44.dist-info}/RECORD +5 -5
- {metacontroller_pytorch-0.0.43.dist-info → metacontroller_pytorch-0.0.44.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.43.dist-info → metacontroller_pytorch-0.0.44.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
|
@@ -407,10 +407,19 @@ class Transformer(Module):
|
|
|
407
407
|
|
|
408
408
|
# meta controller
|
|
409
409
|
|
|
410
|
-
self.meta_controller = meta_controller
|
|
410
|
+
self.meta_controller = meta_controller
|
|
411
411
|
|
|
412
412
|
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
413
413
|
|
|
414
|
+
# ensure devices match
|
|
415
|
+
|
|
416
|
+
if exists(self.meta_controller): self._ensure_consistent_device(self.meta_controller)
|
|
417
|
+
|
|
418
|
+
def _ensure_consistent_device(self, network):
|
|
419
|
+
self.model_device = next(self.parameters()).device
|
|
420
|
+
if next(network.parameters()).device != self.model_device:
|
|
421
|
+
network.to(self.model_device)
|
|
422
|
+
|
|
414
423
|
def evolve(
|
|
415
424
|
self,
|
|
416
425
|
num_generations,
|
|
@@ -447,6 +456,8 @@ class Transformer(Module):
|
|
|
447
456
|
|
|
448
457
|
# meta controller is either given or already given at init
|
|
449
458
|
|
|
459
|
+
if exists(meta_controller): self._ensure_consistent_device(meta_controller)
|
|
460
|
+
|
|
450
461
|
meta_controller = default(meta_controller, self.meta_controller)
|
|
451
462
|
|
|
452
463
|
if force_behavior_cloning:
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
metacontroller/__init__.py,sha256=iSKbCDp3UrWhZg7SIJFYNjdVQU56u-vqZarE6qCSX74,70
|
|
2
|
-
metacontroller/metacontroller.py,sha256=
|
|
2
|
+
metacontroller/metacontroller.py,sha256=hnIMuGi56bPGN9nm8J4IEqPfvRc4VvKBqDwJXwqUpyg,18399
|
|
3
3
|
metacontroller/metacontroller_with_binary_mapper.py,sha256=QJG-o4ZuEC6BkYff6nlyhVcze8o1q8GtbtCP8MLPrvc,9927
|
|
4
4
|
metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
|
|
5
|
-
metacontroller_pytorch-0.0.
|
|
6
|
-
metacontroller_pytorch-0.0.
|
|
7
|
-
metacontroller_pytorch-0.0.
|
|
8
|
-
metacontroller_pytorch-0.0.
|
|
5
|
+
metacontroller_pytorch-0.0.44.dist-info/METADATA,sha256=0EcZBbV06qKgBlSWgtrHfQizqCQE5AUZXLlhDXQA3xo,6816
|
|
6
|
+
metacontroller_pytorch-0.0.44.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
7
|
+
metacontroller_pytorch-0.0.44.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
metacontroller_pytorch-0.0.44.dist-info/RECORD,,
|
|
File without changes
|
{metacontroller_pytorch-0.0.43.dist-info → metacontroller_pytorch-0.0.44.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|