metacontroller-pytorch 0.0.43 (py3-none-any wheel) → 0.0.44 (py3-none-any wheel)

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their public registries.

Potentially problematic release.


This version of metacontroller-pytorch might be problematic. (The original page linked to further details here; that link is not preserved in this text.)

@@ -407,10 +407,19 @@ class Transformer(Module):
407
407
 
408
408
  # meta controller
409
409
 
410
- self.meta_controller = meta_controller
410
+ self.meta_controller = meta_controller
411
411
 
412
412
  self.register_buffer('zero', tensor(0.), persistent = False)
413
413
 
414
+ # ensure devices match
415
+
416
+ if exists(self.meta_controller): self._ensure_consistent_device(self.meta_controller)
417
+
418
+ def _ensure_consistent_device(self, network):
419
+ self.model_device = next(self.parameters()).device
420
+ if next(network.parameters()).device != self.model_device:
421
+ network.to(self.model_device)
422
+
414
423
  def evolve(
415
424
  self,
416
425
  num_generations,
@@ -447,6 +456,8 @@ class Transformer(Module):
447
456
 
448
457
  # meta controller is either given or already given at init
449
458
 
459
+ if exists(meta_controller): self._ensure_consistent_device(meta_controller)
460
+
450
461
  meta_controller = default(meta_controller, self.meta_controller)
451
462
 
452
463
  if force_behavior_cloning:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.43
3
+ Version: 0.0.44
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -1,8 +1,8 @@
1
1
  metacontroller/__init__.py,sha256=iSKbCDp3UrWhZg7SIJFYNjdVQU56u-vqZarE6qCSX74,70
2
- metacontroller/metacontroller.py,sha256=3L5vPXFm7WbYuuWuCd9RpkMVC--mNpAlsJxpLwMeQ8M,17959
2
+ metacontroller/metacontroller.py,sha256=hnIMuGi56bPGN9nm8J4IEqPfvRc4VvKBqDwJXwqUpyg,18399
3
3
  metacontroller/metacontroller_with_binary_mapper.py,sha256=QJG-o4ZuEC6BkYff6nlyhVcze8o1q8GtbtCP8MLPrvc,9927
4
4
  metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
5
- metacontroller_pytorch-0.0.43.dist-info/METADATA,sha256=y1CIKBvy4jTaqsC_FBfiA1FK7pQrqTasiwA96M83luc,6816
6
- metacontroller_pytorch-0.0.43.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
- metacontroller_pytorch-0.0.43.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- metacontroller_pytorch-0.0.43.dist-info/RECORD,,
5
+ metacontroller_pytorch-0.0.44.dist-info/METADATA,sha256=0EcZBbV06qKgBlSWgtrHfQizqCQE5AUZXLlhDXQA3xo,6816
6
+ metacontroller_pytorch-0.0.44.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
+ metacontroller_pytorch-0.0.44.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ metacontroller_pytorch-0.0.44.dist-info/RECORD,,