metacontroller-pytorch 0.0.26__tar.gz → 0.0.27__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of metacontroller-pytorch might be problematic.

Files changed (16)
  1. {metacontroller_pytorch-0.0.26 → metacontroller_pytorch-0.0.27}/PKG-INFO +2 -1
  2. metacontroller_pytorch-0.0.27/metacontroller/metacontroller_with_binary_mapper.py +266 -0
  3. {metacontroller_pytorch-0.0.26 → metacontroller_pytorch-0.0.27}/pyproject.toml +2 -1
  4. {metacontroller_pytorch-0.0.26 → metacontroller_pytorch-0.0.27}/tests/test_metacontroller.py +24 -12
  5. {metacontroller_pytorch-0.0.26 → metacontroller_pytorch-0.0.27}/.github/workflows/python-publish.yml +0 -0
  6. {metacontroller_pytorch-0.0.26 → metacontroller_pytorch-0.0.27}/.github/workflows/test.yml +0 -0
  7. {metacontroller_pytorch-0.0.26 → metacontroller_pytorch-0.0.27}/.gitignore +0 -0
  8. {metacontroller_pytorch-0.0.26 → metacontroller_pytorch-0.0.27}/LICENSE +0 -0
  9. {metacontroller_pytorch-0.0.26 → metacontroller_pytorch-0.0.27}/README.md +0 -0
  10. {metacontroller_pytorch-0.0.26 → metacontroller_pytorch-0.0.27}/fig1.png +0 -0
  11. {metacontroller_pytorch-0.0.26 → metacontroller_pytorch-0.0.27}/gather_babyai_trajs.py +0 -0
  12. {metacontroller_pytorch-0.0.26 → metacontroller_pytorch-0.0.27}/metacontroller/__init__.py +0 -0
  13. {metacontroller_pytorch-0.0.26 → metacontroller_pytorch-0.0.27}/metacontroller/metacontroller.py +0 -0
  14. {metacontroller_pytorch-0.0.26 → metacontroller_pytorch-0.0.27}/test_babyai_e2e.sh +0 -0
  15. {metacontroller_pytorch-0.0.26 → metacontroller_pytorch-0.0.27}/train_babyai.py +0 -0
  16. {metacontroller_pytorch-0.0.26 → metacontroller_pytorch-0.0.27}/train_behavior_clone_babyai.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: metacontroller-pytorch
- Version: 0.0.26
+ Version: 0.0.27
  Summary: Transformer Metacontroller
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -42,6 +42,7 @@ Requires-Dist: loguru
  Requires-Dist: memmap-replay-buffer>=0.0.23
  Requires-Dist: torch-einops-utils>=0.0.16
  Requires-Dist: torch>=2.5
+ Requires-Dist: vector-quantize-pytorch>=1.27.20
  Requires-Dist: x-evolution>=0.1.23
  Requires-Dist: x-mlps-pytorch
  Requires-Dist: x-transformers
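
The new `vector-quantize-pytorch` requirement is pulled in for `BinaryMapper`, which the new module below uses as its discrete bottleneck. As a quick orientation, here is a minimal sketch of the call pattern the new file relies on; the tensor shapes are illustrative assumptions, while the constructor and call arguments (`bits`, `kl_loss_threshold`, `temperature`, `reduce_aux_kl_loss`, `num_codes`) are the ones that appear in the diff below.

```python
# minimal sketch of the BinaryMapper usage mirrored by the new module below;
# batch / sequence sizes are illustrative assumptions
import torch
from vector_quantize_pytorch import BinaryMapper

mapper = BinaryMapper(bits = 4, kl_loss_threshold = 0.)

logits = torch.randn(2, 128, 4)  # (batch, time, bits)

# one-hot codes over 2 ** bits entries, plus an unreduced auxiliary KL loss
one_hot, kl_loss = mapper(logits, temperature = 1., reduce_aux_kl_loss = False)

print(one_hot.shape)  # expected (2, 128, mapper.num_codes), i.e. 16 codes for 4 bits
```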
@@ -0,0 +1,266 @@
+ from __future__ import annotations
+ from contextlib import nullcontext
+
+ from functools import partial
+ from collections import namedtuple
+ from loguru import logger
+
+ import torch
+ from torch import nn, cat, stack, tensor, Tensor
+ from torch.nn import Module, GRU, Linear, Identity
+ import torch.nn.functional as F
+
+ # einops
+
+ import einx
+ from einops import einsum, rearrange, repeat, reduce
+ from einops.layers.torch import Rearrange
+
+ # external modules
+
+ from x_transformers import Encoder, Decoder
+ from x_mlps_pytorch import Feedforwards
+
+ from assoc_scan import AssocScan
+
+ from torch_einops_utils import maybe, pad_at_dim, lens_to_mask
+ from torch_einops_utils.save_load import save_load
+
+ from vector_quantize_pytorch import BinaryMapper
+
+ # constants
+
+ LinearNoBias = partial(Linear, bias = False)
+
+ GRU = partial(GRU, batch_first = True)
+
+ # helper functions
+
+ def exists(v):
+     return v is not None
+
+ def default(*args):
+     for arg in args:
+         if exists(arg):
+             return arg
+     return None
+
+ def straight_through(src, tgt):
+     return tgt + src - src.detach()
+
+ # meta controller
+
+ MetaControllerOutput = namedtuple('MetaControllerOutput', (
+     'prev_hiddens',
+     'action_dist',
+     'codes',
+     'kl_loss',
+     'switch_loss'
+ ))
+
+ @save_load()
+ class MetaControllerWithBinaryMapper(Module):
+     def __init__(
+         self,
+         dim_model,
+         *,
+         dim_meta_controller = 256,
+         dim_code_bits = 4,
+         switch_per_code = False,
+         decoder_expansion_factor = 2.,
+         decoder_depth = 1,
+         hypernetwork_low_rank = 16,
+         assoc_scan_kwargs: dict = dict(),
+         bidirectional_temporal_encoder_kwargs: dict = dict(
+             attn_dim_head = 32, heads = 8
+         ),
+         kl_loss_threshold = 0.
+     ):
+         super().__init__()
+         dim_meta = default(dim_meta_controller, dim_model)
+
+         self.model_to_meta = Linear(dim_model, dim_meta)
+
+         self.bidirectional_temporal_encoder = Encoder(dim = dim_meta, depth = 1, **bidirectional_temporal_encoder_kwargs)
+
+         self.emitter = GRU(dim_meta * 2, dim_meta * 2)
+         self.emitter_to_binary_logits = Linear(dim_meta * 2, dim_code_bits)
+
+         self.action_proposer = GRU(dim_meta, dim_meta)
+         self.proposer_to_binary_logits = Linear(dim_meta, dim_code_bits)
+
+         # binary mapper
+         # proposed by François Fleuret in https://arxiv.org/abs/2510.17558 as a more stable alternative to the VAE
+
+         self.binary_mapper = BinaryMapper(
+             bits = dim_code_bits,
+             kl_loss_threshold = kl_loss_threshold
+         )
+
+         self.dim_code_bits = dim_code_bits
+         self.num_codes = self.binary_mapper.num_codes
+
+         # switching unit
+
+         self.switch_per_code = switch_per_code
+
+         self.switching_unit = GRU(dim_meta + self.num_codes, dim_meta)
+         self.to_switching_unit_beta = nn.Linear(dim_meta, self.num_codes if switch_per_code else 1, bias = False)
+
+         self.switch_gating = AssocScan(**assoc_scan_kwargs)
+
+         # decoder
+
+         assert hypernetwork_low_rank < self.num_codes
+
+         dim_decoder_hidden = int(self.num_codes * decoder_expansion_factor)
+
+         self.decoder = Feedforwards(
+             dim_in = self.num_codes,
+             dim = dim_decoder_hidden,
+             depth = decoder_depth,
+             dim_out = 2 * hypernetwork_low_rank * dim_model
+         )
+
+         self.to_hyper_network_weights = Rearrange('... (two d r) -> two ... d r', two = 2, r = hypernetwork_low_rank)
+
+         self.register_buffer('zero', tensor(0.), persistent = False)
+
+     def discovery_parameters(self):
+         return [
+             *self.model_to_meta.parameters(),
+             *self.bidirectional_temporal_encoder.parameters(),
+             *self.emitter.parameters(),
+             *self.emitter_to_binary_logits.parameters(),
+             *self.binary_mapper.parameters(),
+             *self.decoder.parameters(),
+             *self.switch_gating.parameters()
+         ]
+
+     def internal_rl_parameters(self):
+         return [
+             *self.action_proposer.parameters(),
+             *self.proposer_to_binary_logits.parameters()
+         ]
+
+     def forward(
+         self,
+         residual_stream,
+         cache: MetaControllerOutput | None = None,
+         discovery_phase = False,
+         hard_switch = False,
+         temperature = 1.,
+         episode_lens: Tensor | None = None
+     ):
+         device = residual_stream.device
+
+         # destructure the previous cache
+
+         prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens, prev_sampled_code = cache.prev_hiddens if exists(cache) else ((None,) * 4)
+
+         # get the proposed action for the two phases
+
+         next_action_proposer_hidden = None
+
+         meta_embed = self.model_to_meta(residual_stream)
+
+         if discovery_phase:
+             mask = maybe(lens_to_mask)(episode_lens, meta_embed.shape[1])
+
+             encoded_temporal = self.bidirectional_temporal_encoder(meta_embed, mask = mask)
+
+             proposed_action_hidden, _ = self.emitter(cat((encoded_temporal, meta_embed), dim = -1))
+             to_logits = self.emitter_to_binary_logits
+
+         else: # else internal rl phase
+
+             proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(meta_embed, prev_action_proposer_hidden)
+             to_logits = self.proposer_to_binary_logits
+
+         # sample from the binary mapper
+
+         binary_logits = to_logits(proposed_action_hidden)
+
+         one_hot, kl_loss = self.binary_mapper(
+             binary_logits,
+             temperature = temperature,
+             reduce_aux_kl_loss = False
+         )
+
+         # bottled action is now the one-hot sparse codes (with straight-through)
+
+         sampled_codes = one_hot
+
+         # switching unit timer
+
+         batch, seq_len, dim = sampled_codes.shape
+
+         if not exists(prev_sampled_code):
+             prev_sampled_code = torch.zeros(batch, 1, self.num_codes, device = device)
+
+         if discovery_phase:
+             z_prev = cat((prev_sampled_code, sampled_codes[:, :-1]), dim = 1)
+         else:
+             assert seq_len == 1, 'inference RL phase must be done one token at a time'
+             z_prev = prev_sampled_code
+
+         switch_input = torch.cat((meta_embed, z_prev), dim = -1)
+
+         switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(
+             switch_input,
+             prev_switching_unit_gru_hidden
+         )
+
+         switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
+
+         # losses
+
+         switch_loss = self.zero
+
+         if discovery_phase:
+             # weight unreduced kl loss by switch gates
+
+             weighted_kl_loss = kl_loss * switch_beta
+             kl_loss = weighted_kl_loss.sum(dim = -1).mean()
+
+             # encourage less switching
+
+             switch_loss = switch_beta.mean()
+         else:
+             kl_loss = self.zero
+
+         # maybe hard switch, then use associative scan
+
+         if hard_switch:
+             hard_switch_beta = (switch_beta > 0.5).float()
+             switch_beta = straight_through(switch_beta, hard_switch_beta)
+
+         forget = 1. - switch_beta
+
+         # gated codes (or soft distribution)
+
+         gated_codes = self.switch_gating(switch_beta, sampled_codes * forget, prev = prev_switch_gated_hiddens)
+
+         next_switch_gated_codes = gated_codes[:, -1]
+
+         # decoder
+
+         decoder_out = self.decoder(gated_codes)
+
+         w1, w2 = self.to_hyper_network_weights(decoder_out)
+         hypernetwork_weight = einsum(w1, w2, '... i r, ... j r -> ... i j')
+
+         # generate the residual stream controlling signal
+
+         control_signal = einsum(residual_stream, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
+
+         # returning
+
+         next_hiddens = (
+             next_action_proposer_hidden,
+             next_switching_unit_gru_hidden,
+             next_switch_gated_codes,
+             sampled_codes[:, -1:]
+         )
+
+         return control_signal, MetaControllerOutput(next_hiddens, binary_logits, sampled_codes, kl_loss, switch_loss)
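
To make the two code paths of the new module concrete, below is a minimal usage sketch derived only from the constructor and `forward` signature above: one discovery pass over a full trajectory, then a cached one-step rollout for the internal RL phase. The sizes, and the assumption that the returned control signal gets folded back into the host transformer's residual stream, are illustrative rather than part of this release.

```python
import torch
from metacontroller.metacontroller_with_binary_mapper import MetaControllerWithBinaryMapper

meta = MetaControllerWithBinaryMapper(
    dim_model = 512,
    dim_meta_controller = 256,
    dim_code_bits = 8               # 2 ** 8 = 256 codes
)

residual_stream = torch.randn(2, 128, 512)  # (batch, time, dim_model), illustrative

# discovery phase: the whole trajectory is visible, so the bidirectional encoder +
# emitter path runs and kl_loss / switch_loss come back populated
control_signal, out = meta(residual_stream, discovery_phase = True)
aux_loss = out.kl_loss + out.switch_loss

# internal RL phase: strictly one timestep at a time, threading the returned
# namedtuple back in as the cache
cache = None
for t in range(residual_stream.shape[1]):
    control_signal, cache = meta(
        residual_stream[:, t:t + 1],
        cache = cache,
        hard_switch = True
    )
    # control_signal has the same shape as the residual stream slice and is
    # presumably added back into the host transformer's residual stream
```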
@@ -1,6 +1,6 @@
  [project]
  name = "metacontroller-pytorch"
- version = "0.0.26"
+ version = "0.0.27"
  description = "Transformer Metacontroller"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -32,6 +32,7 @@ dependencies = [
      "memmap-replay-buffer>=0.0.23",
      "torch>=2.5",
      "torch-einops-utils>=0.0.16",
+     "vector-quantize-pytorch>=1.27.20",
      "x-evolution>=0.1.23",
      "x-mlps-pytorch",
      "x-transformers"
@@ -5,27 +5,30 @@ from pathlib import Path
 
  import torch
  from metacontroller.metacontroller import Transformer, MetaController
+ from metacontroller.metacontroller_with_binary_mapper import MetaControllerWithBinaryMapper
 
  from einops import rearrange
 
+ @param('use_binary_mapper_variant', (False, True))
  @param('action_discrete', (False, True))
  @param('switch_per_latent_dim', (False, True))
  @param('variable_length', (False, True))
  def test_metacontroller(
+     use_binary_mapper_variant,
      action_discrete,
      switch_per_latent_dim,
      variable_length
  ):
 
-     state = torch.randn(1, 1024, 384)
-     episode_lens = torch.tensor([512]) if variable_length else None
+     state = torch.randn(2, 128, 384)
+     episode_lens = torch.tensor([64, 64]) if variable_length else None
 
      if action_discrete:
-         actions = torch.randint(0, 4, (1, 1024))
+         actions = torch.randint(0, 4, (2, 128))
          action_embed_readout = dict(num_discrete = 4)
          assert_shape = (4,)
      else:
-         actions = torch.randn(1, 1024, 8)
+         actions = torch.randn(2, 128, 8)
          action_embed_readout = dict(num_continuous = 8)
          assert_shape = (8, 2)
 
@@ -44,12 +47,20 @@ def test_metacontroller(
 
      # discovery and internal rl phase with meta controller
 
-     meta_controller = MetaController(
-         dim_model = 512,
-         dim_meta_controller = 256,
-         dim_latent = 128,
-         switch_per_latent_dim = switch_per_latent_dim
-     )
+     if not use_binary_mapper_variant:
+         meta_controller = MetaController(
+             dim_model = 512,
+             dim_meta_controller = 256,
+             dim_latent = 128,
+             switch_per_latent_dim = switch_per_latent_dim
+         )
+     else:
+         meta_controller = MetaControllerWithBinaryMapper(
+             dim_model = 512,
+             dim_meta_controller = 256,
+             switch_per_code = switch_per_latent_dim,
+             dim_code_bits = 8, # 2**8 = 256 codes
+         )
 
      # discovery phase
 
@@ -66,7 +77,7 @@
 
      logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, return_cache = True)
 
-     assert logits.shape == (1, 1, *assert_shape)
+     assert logits.shape == (2, 1, *assert_shape)
      past_action_id = model.action_readout.sample(logits)
 
      # evolutionary strategies over grpo
@@ -78,7 +89,8 @@
 
      meta_controller.save('./meta_controller.pt')
 
-     rehydrated_meta_controller = MetaController.init_and_load('./meta_controller.pt')
+     meta_controller_klass = meta_controller.__class__
+     rehydrated_meta_controller = meta_controller_klass.init_and_load('./meta_controller.pt')
 
      model.save('./trained.pt')
 
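
The save/load change above is what lets the same test body cover both variants: the class is recovered from the instance, and `init_and_load` is available on either one (for the new class via the `@save_load()` decorator visible earlier in this diff). A small sketch of that round trip, with an illustrative path:

```python
from metacontroller.metacontroller_with_binary_mapper import MetaControllerWithBinaryMapper

meta_controller = MetaControllerWithBinaryMapper(dim_model = 512, dim_code_bits = 8)
meta_controller.save('./meta_controller.pt')

# resolve the class from the instance, so the same code path works for
# MetaController and MetaControllerWithBinaryMapper alike
meta_controller_klass = meta_controller.__class__
rehydrated = meta_controller_klass.init_and_load('./meta_controller.pt')
```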