metacontroller-pytorch 0.0.30__py3-none-any.whl → 0.0.31__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -61,6 +61,7 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
61
61
  'input_residual_stream',
62
62
  'action_dist',
63
63
  'actions',
64
+ 'switch_beta',
64
65
  'kl_loss',
65
66
  'switch_loss'
66
67
  ))
@@ -151,7 +152,7 @@ class MetaController(Module):
151
152
  residual_stream,
152
153
  cache: MetaControllerOutput | None = None,
153
154
  discovery_phase = False,
154
- hard_switch = False,
155
+ hard_switch = None,
155
156
  temperature = 1.,
156
157
  episode_lens: Tensor | None = None
157
158
  ):
@@ -167,6 +168,8 @@ class MetaController(Module):
167
168
 
168
169
  meta_embed = self.model_to_meta(residual_stream)
169
170
 
171
+ hard_switch = default(hard_switch, not discovery_phase) # I think that during the internal RL phase it needs to be a hard switch, so that only the actions emitted during the switch are reinforced
172
+
170
173
  if discovery_phase:
171
174
  logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
172
175
 
@@ -269,7 +272,7 @@ class MetaController(Module):
269
272
  sampled_latent_action[:, -1:]
270
273
  )
271
274
 
272
- return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, kl_loss, switch_loss)
275
+ return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, switch_beta, kl_loss, switch_loss)
273
276
 
274
277
  # main transformer, which is subsumed into the environment after behavioral cloning
275
278
 
@@ -28,6 +28,8 @@ from torch_einops_utils.save_load import save_load
28
28
 
29
29
  from vector_quantize_pytorch import BinaryMapper
30
30
 
31
+ from metacontroller.metacontroller import MetaControllerOutput
32
+
31
33
  # constants
32
34
 
33
35
  LinearNoBias = partial(Linear, bias = False)
@@ -50,15 +52,6 @@ def straight_through(src, tgt):
50
52
 
51
53
  # meta controller
52
54
 
53
- MetaControllerOutput = namedtuple('MetaControllerOutput', (
54
- 'prev_hiddens',
55
- 'input_residual_stream',
56
- 'action_dist',
57
- 'codes',
58
- 'kl_loss',
59
- 'switch_loss'
60
- ))
61
-
62
55
  @save_load()
63
56
  class MetaControllerWithBinaryMapper(Module):
64
57
  def __init__(
@@ -149,7 +142,7 @@ class MetaControllerWithBinaryMapper(Module):
149
142
  residual_stream,
150
143
  cache: MetaControllerOutput | None = None,
151
144
  discovery_phase = False,
152
- hard_switch = False,
145
+ hard_switch = None,
153
146
  temperature = 1.,
154
147
  episode_lens: Tensor | None = None
155
148
  ):
@@ -165,6 +158,8 @@ class MetaControllerWithBinaryMapper(Module):
165
158
 
166
159
  meta_embed = self.model_to_meta(residual_stream)
167
160
 
161
+ hard_switch = default(hard_switch, not discovery_phase) # I think that during the internal RL phase it needs to be a hard switch, so that only the actions emitted during the switch are reinforced
162
+
168
163
  if discovery_phase:
169
164
  mask = maybe(lens_to_mask)(episode_lens, meta_embed.shape[1])
170
165
 
@@ -266,4 +261,4 @@ class MetaControllerWithBinaryMapper(Module):
266
261
  sampled_codes[:, -1:]
267
262
  )
268
263
 
269
- return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, kl_loss, switch_loss)
264
+ return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, switch_beta, kl_loss, switch_loss)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.30
3
+ Version: 0.0.31
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -0,0 +1,8 @@
1
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
+ metacontroller/metacontroller.py,sha256=lxWgeWFcXxSDm-ygd14DjyEOYIJIALcuLkoRAfEzNtc,14719
3
+ metacontroller/metacontroller_with_binary_mapper.py,sha256=BrsQdkhlOyR2O5xAXTLC4p-uKOAbW7wET-lVU0qktws,8242
4
+ metacontroller/metacontroller_with_resnet.py,sha256=YKHcazRZrrRParHRH-H_EPvT1-55LHKAs5pM6gwuT20,7394
5
+ metacontroller_pytorch-0.0.31.dist-info/METADATA,sha256=mtOtYymI01jBMO7pyaAIJ166B5Mk3khH8CUwUMNLTKw,4747
6
+ metacontroller_pytorch-0.0.31.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
+ metacontroller_pytorch-0.0.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ metacontroller_pytorch-0.0.31.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=ydkL3gYW5WGXQdQOIJQ_gibJs74laIIx-v4DmcJHi7M,14497
3
- metacontroller/metacontroller_with_binary_mapper.py,sha256=OGal6dftRPBg_QT1LNDYejNGNlmh4MBvdM41FAQJp9Y,8153
4
- metacontroller/metacontroller_with_resnet.py,sha256=YKHcazRZrrRParHRH-H_EPvT1-55LHKAs5pM6gwuT20,7394
5
- metacontroller_pytorch-0.0.30.dist-info/METADATA,sha256=ghasc1GA0ZM-AZimY0FnGuRFsezVIcbI49V6TIOWeq4,4747
6
- metacontroller_pytorch-0.0.30.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
- metacontroller_pytorch-0.0.30.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- metacontroller_pytorch-0.0.30.dist-info/RECORD,,