metacontroller-pytorch 0.0.41__py3-none-any.whl → 0.0.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -66,6 +66,13 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
66
66
  'switch_loss'
67
67
  ))
68
68
 
# Record returned by `extract_grpo_data`. Fields, in positional order:
#   state       - the meta controller output's `input_residual_stream`
#   action      - the meta controller output's `actions`
#   log_prob    - result of `meta_controller.log_prob(action_dist, action)`
#   switch_beta - the meta controller output's `switch_beta`
GRPOOutput = namedtuple('GRPOOutput', (
    'state',
    'action',
    'log_prob',
    'switch_beta',
))
75
+
def z_score(t, eps = 1e-8):
    """Standardize `t`: subtract its mean, then divide by its standard deviation.

    The statistics are taken over the whole of `t` (no axis is passed to
    `mean()` / `std()`).

    Args:
        t: object exposing `.mean()` and `.std()` and supporting arithmetic
            with their results - presumably a torch tensor, TODO confirm.
        eps: small constant added to the standard deviation before dividing.

    Returns:
        `(t - t.mean()) / (t.std() + eps)`.
    """
    # NOTE(review): if `t` is a torch tensor, the default `std()` is the
    # Bessel-corrected estimate, so a single-element input would give NaN -
    # confirm callers never pass a group of size one.
    centered = t - t.mean()

    # eps keeps the division finite when the standard deviation is zero
    return centered / (t.std() + eps)
71
78
 
@@ -107,6 +114,17 @@ def policy_loss(
107
114
 
108
115
  return masked_mean(losses, mask)
109
116
 
def extract_grpo_data(meta_controller, transformer_output):
    """Gather the fields of a `GRPOOutput` from a transformer output.

    Args:
        meta_controller: object exposing `log_prob(action_dist, action)`.
        transformer_output: object whose `prev_hiddens.meta_controller`
            attribute carries `input_residual_stream`, `actions`,
            `switch_beta` and `action_dist`.

    Returns:
        GRPOOutput: `(state, action, log_prob, switch_beta)`, where `state` is
        the input residual stream, `action` the stored actions, `log_prob` the
        value returned by `meta_controller.log_prob` for those actions, and
        `switch_beta` the stored switch beta.
    """
    meta_output = transformer_output.prev_hiddens.meta_controller

    action = meta_output.actions

    # score the stored actions under the action distribution stored with them
    log_prob = meta_controller.log_prob(meta_output.action_dist, action)

    return GRPOOutput(
        state = meta_output.input_residual_stream,
        action = action,
        log_prob = log_prob,
        switch_beta = meta_output.switch_beta,
    )
127
+
110
128
  @save_load()
111
129
  class MetaController(Module):
112
130
  def __init__(
@@ -273,7 +291,7 @@ class MetaController(Module):
273
291
  else:
274
292
  # else during inference, use the previous sampled latent action
275
293
 
276
- assert seq_len == 1, f'inference RL phase must be done one token at a time'
294
+ assert seq_len == 1, 'inference RL phase must be done one token at a time - if replaying for policy optimization, please use `get_action_dist_for_internal_rl`'
277
295
  z_prev = prev_sampled_latent_action
278
296
 
279
297
  # switch input is previous latent action and the embedding
@@ -241,7 +241,7 @@ class MetaControllerWithBinaryMapper(Module):
241
241
  if discovery_phase:
242
242
  z_prev = cat((prev_sampled_code, sampled_codes[:, :-1]), dim = 1)
243
243
  else:
244
- assert seq_len == 1, f'inference RL phase must be done one token at a time'
244
+ assert seq_len == 1, 'inference RL phase must be done one token at a time - if replaying for policy optimization, please use `get_action_dist_for_internal_rl`'
245
245
  z_prev = prev_sampled_code
246
246
 
247
247
  switch_input = torch.cat((meta_embed, z_prev), dim=-1)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.41
3
+ Version: 0.0.43
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -53,7 +53,7 @@ Description-Content-Type: text/markdown
53
53
 
54
54
  <img src="./fig1.png" width="400px"></img>
55
55
 
56
- ## metacontroller (wip)
56
+ ## metacontroller
57
57
 
58
58
  Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)
59
59
 
@@ -0,0 +1,8 @@
1
+ metacontroller/__init__.py,sha256=iSKbCDp3UrWhZg7SIJFYNjdVQU56u-vqZarE6qCSX74,70
2
+ metacontroller/metacontroller.py,sha256=3L5vPXFm7WbYuuWuCd9RpkMVC--mNpAlsJxpLwMeQ8M,17959
3
+ metacontroller/metacontroller_with_binary_mapper.py,sha256=QJG-o4ZuEC6BkYff6nlyhVcze8o1q8GtbtCP8MLPrvc,9927
4
+ metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
5
+ metacontroller_pytorch-0.0.43.dist-info/METADATA,sha256=y1CIKBvy4jTaqsC_FBfiA1FK7pQrqTasiwA96M83luc,6816
6
+ metacontroller_pytorch-0.0.43.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
+ metacontroller_pytorch-0.0.43.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ metacontroller_pytorch-0.0.43.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- metacontroller/__init__.py,sha256=iSKbCDp3UrWhZg7SIJFYNjdVQU56u-vqZarE6qCSX74,70
2
- metacontroller/metacontroller.py,sha256=bhgCqqM-dfysGrMtZYe2w87lRVkf8fETjxUCdjrnI8Q,17386
3
- metacontroller/metacontroller_with_binary_mapper.py,sha256=Ce5-O95_pLuWNA3aZTlKrTGbc5cemb61tBtJBdSiLx4,9843
4
- metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
5
- metacontroller_pytorch-0.0.41.dist-info/METADATA,sha256=IvP-wC73xCnT8X1aul1IfcaC4fUwRq9Y2UB1h0JG5TI,6822
6
- metacontroller_pytorch-0.0.41.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
- metacontroller_pytorch-0.0.41.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- metacontroller_pytorch-0.0.41.dist-info/RECORD,,