metacontroller-pytorch 0.0.30__py3-none-any.whl → 0.0.32__py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
--- a/metacontroller/metacontroller.py
+++ b/metacontroller/metacontroller.py
@@ -61,6 +61,7 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
     'input_residual_stream',
     'action_dist',
     'actions',
+    'switch_beta',
     'kl_loss',
     'switch_loss'
 ))
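Note: MetaControllerOutput is a plain namedtuple, so inserting 'switch_beta' between 'actions' and 'kl_loss' shifts the position of every later field. Callers that unpack the cache positionally need updating; field-name access keeps working. A minimal sketch of the name-based access (tensor shapes are placeholders, not taken from the package):

    from collections import namedtuple
    import torch

    # mirrors the 0.0.32 field order above
    MetaControllerOutput = namedtuple('MetaControllerOutput', (
        'prev_hiddens', 'input_residual_stream', 'action_dist',
        'actions', 'switch_beta', 'kl_loss', 'switch_loss'
    ))

    out = MetaControllerOutput(
        prev_hiddens = None,
        input_residual_stream = torch.randn(1, 4, 512),
        action_dist = torch.randn(1, 4, 32, 2),
        actions = torch.randn(1, 4, 32),
        switch_beta = torch.rand(1, 4, 1),  # the newly exposed switch gate values
        kl_loss = torch.tensor(0.),
        switch_loss = torch.tensor(0.)
    )

    # name-based access is insensitive to field insertions; out[4] changed meaning in 0.0.32
    print(out.switch_beta.shape)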
@@ -106,7 +107,6 @@ class MetaController(Module):
 
         self.switch_per_latent_dim = switch_per_latent_dim
 
-
         self.dim_latent = dim_latent
         self.switching_unit = GRU(dim_meta + dim_latent, dim_meta)
         self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)
@@ -146,12 +146,19 @@ class MetaController(Module):
             *self.action_proposer_mean_log_var.parameters()
         ]
 
+    def log_prob(
+        self,
+        action_dist,
+        sampled_latent_action
+    ):
+        return self.action_proposer_mean_log_var.log_prob(action_dist, sampled_latent_action)
+
     def forward(
         self,
         residual_stream,
         cache: MetaControllerOutput | None = None,
         discovery_phase = False,
-        hard_switch = False,
+        hard_switch = None,
         temperature = 1.,
         episode_lens: Tensor | None = None
     ):
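The new log_prob method exposes, for a cached rollout, the log-density of the sampled latent actions under the stored action distribution, which is what a policy-gradient update needs. A hedged sketch of one way it might be wired into a REINFORCE-style loss; meta_controller, the cache fields, and advantages are assumptions for illustration, not this package's documented API:

    import torch

    def reinforce_loss(meta_controller, cache, advantages):
        # re-score the cached latent actions under the cached distribution parameters
        log_probs = meta_controller.log_prob(cache.action_dist, cache.actions)

        # REINFORCE: raise the log-probability of actions with positive advantage
        return -(log_probs * advantages).mean()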
@@ -167,6 +174,8 @@ class MetaController(Module):
 
         meta_embed = self.model_to_meta(residual_stream)
 
+        hard_switch = default(hard_switch, not discovery_phase) # during the internal RL phase it needs to be a hard switch, so that only the actions emitted during the switch are reinforced
+
         if discovery_phase:
             logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
 
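With hard_switch now defaulting to None, forward derives it from the phase: soft switching during discovery (gradients must flow through the gate) and hard switching during the internal RL phase, per the comment above. A small sketch of that resolution in isolation; the body of default follows the common first-non-None convention, which matches its use here but is not copied from the package:

    def default(*args):
        # return the first argument that is not None
        for arg in args:
            if arg is not None:
                return arg
        return None

    # discovery phase: resolves to a soft (differentiable) switch
    assert default(None, not True) is False

    # internal RL phase: resolves to a hard switch
    assert default(None, not False) is True

    # an explicit caller-supplied value always wins
    assert default(False, not False) is False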
@@ -269,10 +278,16 @@ class MetaController(Module):
             sampled_latent_action[:, -1:]
         )
 
-        return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, kl_loss, switch_loss)
+        return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, switch_beta, kl_loss, switch_loss)
 
 # main transformer, which is subsumed into the environment after behavioral cloning
 
+Hiddens = namedtuple('Hiddens', (
+    'lower_body',
+    'meta_controller',
+    'upper_body'
+))
+
 TransformerOutput = namedtuple('TransformerOutput', (
     'residual_stream_latent',
     'prev_hiddens'
@@ -438,4 +453,4 @@ class Transformer(Module):
         if return_one:
             return dist_params
 
-        return dist_params, TransformerOutput(residual_stream, (next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
+        return dist_params, TransformerOutput(residual_stream, Hiddens(next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
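Wrapping the three hidden-state streams in the new Hiddens namedtuple keeps the return value tuple-compatible while naming each stream. A small sketch (shapes are placeholders):

    from collections import namedtuple
    import torch

    Hiddens = namedtuple('Hiddens', ('lower_body', 'meta_controller', 'upper_body'))

    hiddens = Hiddens(
        lower_body = torch.zeros(1, 512),
        meta_controller = torch.zeros(1, 256),
        upper_body = torch.zeros(1, 512)
    )

    # existing positional unpacking keeps working
    lower, meta, upper = hiddens

    # and named access documents which stream is which
    assert hiddens.meta_controller is meta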
--- a/metacontroller/metacontroller_with_binary_mapper.py
+++ b/metacontroller/metacontroller_with_binary_mapper.py
@@ -28,6 +28,8 @@ from torch_einops_utils.save_load import save_load
 
 from vector_quantize_pytorch import BinaryMapper
 
+from metacontroller.metacontroller import MetaControllerOutput
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
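Importing MetaControllerOutput from metacontroller.metacontroller (and deleting the duplicate definition below) means both controller variants now return the same cache type. One consequence, sketched with a hypothetical helper that is not part of the package:

    from metacontroller.metacontroller import MetaControllerOutput

    def is_metacontroller_cache(cache) -> bool:
        # true for caches from MetaController and MetaControllerWithBinaryMapper alike
        return isinstance(cache, MetaControllerOutput)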
@@ -48,16 +50,10 @@ def default(*args):
 def straight_through(src, tgt):
     return tgt + src - src.detach()
 
-# meta controller
+def log(t, eps = 1e-20):
+    return t.clamp_min(eps).log()
 
-MetaControllerOutput = namedtuple('MetaControllerOutput', (
-    'prev_hiddens',
-    'input_residual_stream',
-    'action_dist',
-    'codes',
-    'kl_loss',
-    'switch_loss'
-))
+# meta controller
 
 @save_load()
 class MetaControllerWithBinaryMapper(Module):
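The new log helper clamps its input before taking the logarithm, so probabilities that reach exactly zero yield a large finite negative value rather than -inf, which would otherwise poison any downstream mean or gradient. A minimal demonstration:

    import torch

    def log(t, eps = 1e-20):
        # clamping avoids log(0) = -inf, keeping losses and gradients finite
        return t.clamp_min(eps).log()

    probs = torch.tensor([1., 0.5, 0.])

    print(probs.log())  # tensor([ 0.0000, -0.6931,    -inf])
    print(log(probs))   # tensor([  0.0000,  -0.6931, -46.0517])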
@@ -144,12 +140,29 @@ class MetaControllerWithBinaryMapper(Module):
             *self.proposer_to_binary_logits.parameters()
         ]
 
+    def log_prob(
+        self,
+        action_dist,
+        sampled_latent_action
+    ):
+        action_prob = action_dist.sigmoid()
+        probs = stack((action_prob, 1. - action_prob), dim = -1)
+        log_probs = log(probs)
+
+        indices = sampled_latent_action.argmax(dim = -1)
+        codes = self.binary_mapper.codes[indices].long()
+
+        codes = rearrange(codes, '... -> ... 1')
+        action_log_probs = log_probs.gather(-1, codes)
+
+        return rearrange(action_log_probs, '... 1 -> ...')
+
     def forward(
         self,
         residual_stream,
         cache: MetaControllerOutput | None = None,
         discovery_phase = False,
-        hard_switch = False,
+        hard_switch = None,
         temperature = 1.,
         episode_lens: Tensor | None = None
     ):
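This log_prob computes a per-bit Bernoulli log-probability: it stacks P(bit) and 1 - P(bit) along a trailing dimension, maps the sampled one-hot action back to its code bits, and gathers the matching log-probability for each bit. A self-contained toy sketch of the stack-and-gather pattern; the shapes and code table are invented for illustration:

    import torch

    # per-bit probabilities from a sigmoid over logits: batch of 2, 4 bits
    action_prob = torch.tensor([[0.9, 0.2, 0.7, 0.5],
                                [0.1, 0.8, 0.3, 0.6]])

    # index 0 selects action_prob, index 1 its complement,
    # matching the stack order in the method above
    probs = torch.stack((action_prob, 1. - action_prob), dim = -1)

    # toy code bits per sample (binary_mapper.codes supplies these in the package)
    codes = torch.tensor([[0, 1, 0, 1],
                          [1, 0, 1, 0]])

    log_probs = probs.clamp_min(1e-20).log()
    per_bit = log_probs.gather(-1, codes.unsqueeze(-1)).squeeze(-1)

    print(per_bit.shape)  # torch.Size([2, 4]): one log-probability per bit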
@@ -165,6 +178,8 @@ class MetaControllerWithBinaryMapper(Module):
 
         meta_embed = self.model_to_meta(residual_stream)
 
+        hard_switch = default(hard_switch, not discovery_phase) # during the internal RL phase it needs to be a hard switch, so that only the actions emitted during the switch are reinforced
+
         if discovery_phase:
             mask = maybe(lens_to_mask)(episode_lens, meta_embed.shape[1])
 
@@ -266,4 +281,4 @@ class MetaControllerWithBinaryMapper(Module):
             sampled_codes[:, -1:]
         )
 
-        return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, kl_loss, switch_loss)
+        return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, switch_beta, kl_loss, switch_loss)
--- a/metacontroller_pytorch-0.0.30.dist-info/METADATA
+++ b/metacontroller_pytorch-0.0.32.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.30
+Version: 0.0.32
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
--- /dev/null
+++ b/metacontroller_pytorch-0.0.32.dist-info/RECORD
@@ -0,0 +1,8 @@
+metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+metacontroller/metacontroller.py,sha256=somE9gX36c1d9hF2n8Qn4foRY8krHGodvrvulhkIGE8,15006
+metacontroller/metacontroller_with_binary_mapper.py,sha256=CTGK8ruQ3TkioVUwFTHdrbfzubaeuhSdXHfHtaDcwMY,8813
+metacontroller/metacontroller_with_resnet.py,sha256=YKHcazRZrrRParHRH-H_EPvT1-55LHKAs5pM6gwuT20,7394
+metacontroller_pytorch-0.0.32.dist-info/METADATA,sha256=hr08iXm6Mb-rnDu2xPrr9YQ6cwTtX1F79MfBYt54Y94,4747
+metacontroller_pytorch-0.0.32.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.32.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.32.dist-info/RECORD,,
--- a/metacontroller_pytorch-0.0.30.dist-info/RECORD
+++ /dev/null
@@ -1,8 +0,0 @@
-metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
-metacontroller/metacontroller.py,sha256=ydkL3gYW5WGXQdQOIJQ_gibJs74laIIx-v4DmcJHi7M,14497
-metacontroller/metacontroller_with_binary_mapper.py,sha256=OGal6dftRPBg_QT1LNDYejNGNlmh4MBvdM41FAQJp9Y,8153
-metacontroller/metacontroller_with_resnet.py,sha256=YKHcazRZrrRParHRH-H_EPvT1-55LHKAs5pM6gwuT20,7394
-metacontroller_pytorch-0.0.30.dist-info/METADATA,sha256=ghasc1GA0ZM-AZimY0FnGuRFsezVIcbI49V6TIOWeq4,4747
-metacontroller_pytorch-0.0.30.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-metacontroller_pytorch-0.0.30.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-metacontroller_pytorch-0.0.30.dist-info/RECORD,,