metacontroller-pytorch 0.0.34__py3-none-any.whl → 0.0.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -329,6 +329,11 @@ class MetaController(Module):
329
329
  sampled_latent_action[:, -1:]
330
330
  )
331
331
 
332
+ # squeeze out the last dimension of switch_beta if single gate for all latent dimensions
333
+
334
+ if not self.switch_per_latent_dim:
335
+ switch_beta = rearrange(switch_beta, '... 1 -> ...')
336
+
332
337
  return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, switch_beta, kl_loss, switch_loss)
333
338
 
334
339
  # main transformer, which is subsumed into the environment after behavioral cloning
@@ -296,4 +296,9 @@ class MetaControllerWithBinaryMapper(Module):
296
296
  sampled_codes[:, -1:]
297
297
  )
298
298
 
299
+ # squeeze out the last dimension of switch_beta if single gate for all codes
300
+
301
+ if not self.switch_per_code:
302
+ switch_beta = rearrange(switch_beta, '... 1 -> ...')
303
+
299
304
  return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, switch_beta, kl_loss, switch_loss)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.34
3
+ Version: 0.0.35
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -0,0 +1,8 @@
1
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
+ metacontroller/metacontroller.py,sha256=sj-cHpYm9NHZEBKbLQaf4MtZCv2lcBI2cAyj5Y9bAgc,16410
3
+ metacontroller/metacontroller_with_binary_mapper.py,sha256=9mMKMp3zVQzjbJvoC1dBRibarHHgjnOf1tRyeY1VvAM,9423
4
+ metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
5
+ metacontroller_pytorch-0.0.35.dist-info/METADATA,sha256=eV2Y0yW-iY2_I0gPyCA8OqChqVWFwh3GkJZFzQcZ2a0,4747
6
+ metacontroller_pytorch-0.0.35.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
+ metacontroller_pytorch-0.0.35.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ metacontroller_pytorch-0.0.35.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=Ii4Z2MuVMXJeWxAnZnRfSKrGv1t6_Y4R5BgvmtDrSW8,16203
3
- metacontroller/metacontroller_with_binary_mapper.py,sha256=PgXK7uk--gvZPk1h4WLdCBEA7m9Ji12xixa2wqLmsLY,9234
4
- metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
5
- metacontroller_pytorch-0.0.34.dist-info/METADATA,sha256=xBFG34yRWTkcfJLFvC22jARLkR_W_6kcSsAV8r5UFWY,4747
6
- metacontroller_pytorch-0.0.34.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
- metacontroller_pytorch-0.0.34.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- metacontroller_pytorch-0.0.34.dist-info/RECORD,,