metacontroller-pytorch 0.0.29__py3-none-any.whl → 0.0.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -58,8 +58,10 @@ def straight_through(src, tgt):
58
58
 
59
59
  MetaControllerOutput = namedtuple('MetaControllerOutput', (
60
60
  'prev_hiddens',
61
+ 'input_residual_stream',
61
62
  'action_dist',
62
63
  'actions',
64
+ 'switch_beta',
63
65
  'kl_loss',
64
66
  'switch_loss'
65
67
  ))
@@ -150,7 +152,7 @@ class MetaController(Module):
150
152
  residual_stream,
151
153
  cache: MetaControllerOutput | None = None,
152
154
  discovery_phase = False,
153
- hard_switch = False,
155
+ hard_switch = None,
154
156
  temperature = 1.,
155
157
  episode_lens: Tensor | None = None
156
158
  ):
@@ -166,6 +168,8 @@ class MetaController(Module):
166
168
 
167
169
  meta_embed = self.model_to_meta(residual_stream)
168
170
 
171
+ hard_switch = default(hard_switch, not discovery_phase) # think during internal RL phase, it needs to be a hard switch, then only the actions emitted during the switch is reinforced
172
+
169
173
  if discovery_phase:
170
174
  logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
171
175
 
@@ -268,7 +272,7 @@ class MetaController(Module):
268
272
  sampled_latent_action[:, -1:]
269
273
  )
270
274
 
271
- return control_signal, MetaControllerOutput(next_hiddens, action_dist, sampled_latent_action, kl_loss, switch_loss)
275
+ return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, switch_beta, kl_loss, switch_loss)
272
276
 
273
277
  # main transformer, which is subsumed into the environment after behavioral cloning
274
278
 
@@ -28,6 +28,8 @@ from torch_einops_utils.save_load import save_load
28
28
 
29
29
  from vector_quantize_pytorch import BinaryMapper
30
30
 
31
+ from metacontroller.metacontroller import MetaControllerOutput
32
+
31
33
  # constants
32
34
 
33
35
  LinearNoBias = partial(Linear, bias = False)
@@ -50,14 +52,6 @@ def straight_through(src, tgt):
50
52
 
51
53
  # meta controller
52
54
 
53
- MetaControllerOutput = namedtuple('MetaControllerOutput', (
54
- 'prev_hiddens',
55
- 'action_dist',
56
- 'codes',
57
- 'kl_loss',
58
- 'switch_loss'
59
- ))
60
-
61
55
  @save_load()
62
56
  class MetaControllerWithBinaryMapper(Module):
63
57
  def __init__(
@@ -148,7 +142,7 @@ class MetaControllerWithBinaryMapper(Module):
148
142
  residual_stream,
149
143
  cache: MetaControllerOutput | None = None,
150
144
  discovery_phase = False,
151
- hard_switch = False,
145
+ hard_switch = None,
152
146
  temperature = 1.,
153
147
  episode_lens: Tensor | None = None
154
148
  ):
@@ -164,6 +158,8 @@ class MetaControllerWithBinaryMapper(Module):
164
158
 
165
159
  meta_embed = self.model_to_meta(residual_stream)
166
160
 
161
+ hard_switch = default(hard_switch, not discovery_phase) # think during internal RL phase, it needs to be a hard switch, then only the actions emitted during the switch is reinforced
162
+
167
163
  if discovery_phase:
168
164
  mask = maybe(lens_to_mask)(episode_lens, meta_embed.shape[1])
169
165
 
@@ -265,4 +261,4 @@ class MetaControllerWithBinaryMapper(Module):
265
261
  sampled_codes[:, -1:]
266
262
  )
267
263
 
268
- return control_signal, MetaControllerOutput(next_hiddens, binary_logits, sampled_codes, kl_loss, switch_loss)
264
+ return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, switch_beta, kl_loss, switch_loss)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.29
3
+ Version: 0.0.31
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -0,0 +1,8 @@
1
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
+ metacontroller/metacontroller.py,sha256=lxWgeWFcXxSDm-ygd14DjyEOYIJIALcuLkoRAfEzNtc,14719
3
+ metacontroller/metacontroller_with_binary_mapper.py,sha256=BrsQdkhlOyR2O5xAXTLC4p-uKOAbW7wET-lVU0qktws,8242
4
+ metacontroller/metacontroller_with_resnet.py,sha256=YKHcazRZrrRParHRH-H_EPvT1-55LHKAs5pM6gwuT20,7394
5
+ metacontroller_pytorch-0.0.31.dist-info/METADATA,sha256=mtOtYymI01jBMO7pyaAIJ166B5Mk3khH8CUwUMNLTKw,4747
6
+ metacontroller_pytorch-0.0.31.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
+ metacontroller_pytorch-0.0.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ metacontroller_pytorch-0.0.31.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=LWEq069EnBP3Sr6FTiDtz0cM5SFFT1zl35WkU6_kWGA,14451
3
- metacontroller/metacontroller_with_binary_mapper.py,sha256=uUFCSIRq20TdctRd7O20A_I2SiB9AgYS6z5iQMFqf2Q,8107
4
- metacontroller/metacontroller_with_resnet.py,sha256=YKHcazRZrrRParHRH-H_EPvT1-55LHKAs5pM6gwuT20,7394
5
- metacontroller_pytorch-0.0.29.dist-info/METADATA,sha256=8zeOj2sUZ-5V_qGXvzXoBH3lpCJqHgPfZq0-YllSrTs,4747
6
- metacontroller_pytorch-0.0.29.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
- metacontroller_pytorch-0.0.29.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- metacontroller_pytorch-0.0.29.dist-info/RECORD,,