metacontroller-pytorch 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff shows the content of publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release.

@@ -6,7 +6,7 @@ from collections import namedtuple
  from loguru import logger
 
  import torch
- from torch import nn, cat, stack, tensor
+ from torch import nn, cat, stack, tensor, Tensor
  from torch.nn import Module, GRU, Linear, Identity
  import torch.nn.functional as F
 
@@ -26,7 +26,7 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout
 
  from assoc_scan import AssocScan
 
- from torch_einops_utils import pad_at_dim, lens_to_mask
+ from torch_einops_utils import maybe, pad_at_dim, lens_to_mask
  from torch_einops_utils.save_load import save_load
 
  # constants
@@ -151,7 +151,8 @@ class MetaController(Module):
  cache: MetaControllerOutput | None = None,
  discovery_phase = False,
  hard_switch = False,
- temperature = 1.
+ temperature = 1.,
+ episode_lens: Tensor | None = None
  ):
  device = residual_stream.device
 
@@ -168,7 +169,9 @@
  if discovery_phase:
  logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
 
- encoded_temporal = self.bidirectional_temporal_encoder(meta_embed)
+ mask = maybe(lens_to_mask)(episode_lens, meta_embed.shape[1])
+
+ encoded_temporal = self.bidirectional_temporal_encoder(meta_embed, mask = mask)
 
  proposed_action_hidden, _ = self.emitter(cat((encoded_temporal, meta_embed), dim = -1))
  readout = self.emitter_to_action_mean_log_var
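A note on the new masking logic above: maybe(lens_to_mask) only builds a padding mask when episode_lens is actually passed, so fixed-length batches behave exactly as before. Below is a rough, self-contained sketch of what these two torch-einops-utils helpers are assumed to do (behavior inferred from the call site in this diff, not copied from that library):

import torch

def lens_to_mask(lens, max_len):
    # True for timesteps inside each episode, False for right padding
    positions = torch.arange(max_len, device = lens.device)
    return positions < lens[:, None]

def maybe(fn):
    # apply fn only if the first argument exists, otherwise propagate None
    def inner(first, *args, **kwargs):
        return fn(first, *args, **kwargs) if first is not None else None
    return inner

episode_lens = torch.tensor([3, 5])
mask = maybe(lens_to_mask)(episode_lens, 5)   # shape (2, 5); row 0 is masked past index 2
no_mask = maybe(lens_to_mask)(None, 5)        # None, so the encoder sees no mask at all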
@@ -391,7 +394,7 @@ class Transformer(Module):
  with meta_controller_context():
 
  if exists(meta_controller):
- control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature)
+ control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature, episode_lens = episode_lens)
  else:
  control_signal, next_meta_hiddens = self.zero, None
 
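The Transformer forward simply threads episode_lens through to the meta controller. A minimal sketch of preparing such a batch (shapes are illustrative; the rest of the Transformer's forward signature is not shown in this diff):

import torch
from torch.nn.utils.rnn import pad_sequence

# two episodes of different length, right-padded into one batch
episodes = [torch.randn(12, 512), torch.randn(7, 512)]
residual_stream = pad_sequence(episodes, batch_first = True)              # (2, 12, 512)
episode_lens = torch.tensor([episode.shape[0] for episode in episodes])   # tensor([12, 7])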
@@ -0,0 +1,266 @@
+ from __future__ import annotations
+ from contextlib import nullcontext
+
+ from functools import partial
+ from collections import namedtuple
+ from loguru import logger
+
+ import torch
+ from torch import nn, cat, stack, tensor, Tensor
+ from torch.nn import Module, GRU, Linear, Identity
+ import torch.nn.functional as F
+
+ # einops
+
+ import einx
+ from einops import einsum, rearrange, repeat, reduce
+ from einops.layers.torch import Rearrange
+
+ # external modules
+
+ from x_transformers import Encoder, Decoder
+ from x_mlps_pytorch import Feedforwards
+
+ from assoc_scan import AssocScan
+
+ from torch_einops_utils import maybe, pad_at_dim, lens_to_mask
+ from torch_einops_utils.save_load import save_load
+
+ from vector_quantize_pytorch import BinaryMapper
+
+ # constants
+
+ LinearNoBias = partial(Linear, bias = False)
+
+ GRU = partial(GRU, batch_first = True)
+
+ # helper functions
+
+ def exists(v):
+     return v is not None
+
+ def default(*args):
+     for arg in args:
+         if exists(arg):
+             return arg
+     return None
+
+ def straight_through(src, tgt):
+     return tgt + src - src.detach()
+
+ # meta controller
+
+ MetaControllerOutput = namedtuple('MetaControllerOutput', (
+     'prev_hiddens',
+     'action_dist',
+     'codes',
+     'kl_loss',
+     'switch_loss'
+ ))
+
+ @save_load()
+ class MetaControllerWithBinaryMapper(Module):
+     def __init__(
+         self,
+         dim_model,
+         *,
+         dim_meta_controller = 256,
+         dim_code_bits = 4,
+         switch_per_code = False,
+         decoder_expansion_factor = 2.,
+         decoder_depth = 1,
+         hypernetwork_low_rank = 16,
+         assoc_scan_kwargs: dict = dict(),
+         bidirectional_temporal_encoder_kwargs: dict = dict(
+             attn_dim_head = 32, heads = 8
+         ),
+         kl_loss_threshold = 0.
+     ):
+         super().__init__()
+         dim_meta = default(dim_meta_controller, dim_model)
+
+         self.model_to_meta = Linear(dim_model, dim_meta)
+
+         self.bidirectional_temporal_encoder = Encoder(dim = dim_meta, depth = 1, **bidirectional_temporal_encoder_kwargs)
+
+         self.emitter = GRU(dim_meta * 2, dim_meta * 2)
+         self.emitter_to_binary_logits = Linear(dim_meta * 2, dim_code_bits)
+
+         self.action_proposer = GRU(dim_meta, dim_meta)
+         self.proposer_to_binary_logits = Linear(dim_meta, dim_code_bits)
+
+         # binary mapper
+         # proposed in https://arxiv.org/abs/2510.17558 as a more stable alternative to VAE by François Fleuret
+
+         self.binary_mapper = BinaryMapper(
+             bits = dim_code_bits,
+             kl_loss_threshold = kl_loss_threshold
+         )
+
+         self.dim_code_bits = dim_code_bits
+         self.num_codes = self.binary_mapper.num_codes
+
+         # switching unit
+
+         self.switch_per_code = switch_per_code
+
+         self.switching_unit = GRU(dim_meta + self.num_codes, dim_meta)
+         self.to_switching_unit_beta = nn.Linear(dim_meta, self.num_codes if switch_per_code else 1, bias = False)
+
+         self.switch_gating = AssocScan(**assoc_scan_kwargs)
+
+         # decoder
+
+         assert hypernetwork_low_rank < self.num_codes
+
+         dim_decoder_hidden = int(self.num_codes * decoder_expansion_factor)
+
+         self.decoder = Feedforwards(
+             dim_in = self.num_codes,
+             dim = dim_decoder_hidden,
+             depth = decoder_depth,
+             dim_out = 2 * hypernetwork_low_rank * dim_model
+         )
+
+         self.to_hyper_network_weights = Rearrange('... (two d r) -> two ... d r', two = 2, r = hypernetwork_low_rank)
+
+         self.register_buffer('zero', tensor(0.), persistent = False)
+
+     def discovery_parameters(self):
+         return [
+             *self.model_to_meta.parameters(),
+             *self.bidirectional_temporal_encoder.parameters(),
+             *self.emitter.parameters(),
+             *self.emitter_to_binary_logits.parameters(),
+             *self.binary_mapper.parameters(),
+             *self.decoder.parameters(),
+             *self.switch_gating.parameters()
+         ]
+
+     def internal_rl_parameters(self):
+         return [
+             *self.action_proposer.parameters(),
+             *self.proposer_to_binary_logits.parameters()
+         ]
+
+     def forward(
+         self,
+         residual_stream,
+         cache: MetaControllerOutput | None = None,
+         discovery_phase = False,
+         hard_switch = False,
+         temperature = 1.,
+         episode_lens: Tensor | None = None
+     ):
+         device = residual_stream.device
+
+         # destruct prev cache
+
+         prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens, prev_sampled_code = cache.prev_hiddens if exists(cache) else ((None,) * 4)
+
+         # getting proposed action for the two phases
+
+         next_action_proposer_hidden = None
+
+         meta_embed = self.model_to_meta(residual_stream)
+
+         if discovery_phase:
+             mask = maybe(lens_to_mask)(episode_lens, meta_embed.shape[1])
+
+             encoded_temporal = self.bidirectional_temporal_encoder(meta_embed, mask = mask)
+
+             proposed_action_hidden, _ = self.emitter(cat((encoded_temporal, meta_embed), dim = -1))
+             to_logits = self.emitter_to_binary_logits
+
+         else: # else internal rl phase
+
+             proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(meta_embed, prev_action_proposer_hidden)
+             to_logits = self.proposer_to_binary_logits
+
+         # sample from the binary mapper
+
+         binary_logits = to_logits(proposed_action_hidden)
+
+         one_hot, kl_loss = self.binary_mapper(
+             binary_logits,
+             temperature = temperature,
+             reduce_aux_kl_loss = False
+         )
+
+         # bottled action is now the one-hot sparse codes (with straight-through)
+
+         sampled_codes = one_hot
+
+         # switching unit timer
+
+         batch, seq_len, dim = sampled_codes.shape
+
+         if not exists(prev_sampled_code):
+             prev_sampled_code = torch.zeros(batch, 1, self.num_codes, device = device)
+
+         if discovery_phase:
+             z_prev = cat((prev_sampled_code, sampled_codes[:, :-1]), dim = 1)
+         else:
+             assert seq_len == 1, f'inference RL phase must be done one token at a time'
+             z_prev = prev_sampled_code
+
+         switch_input = torch.cat((meta_embed, z_prev), dim=-1)
+
+         switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(
+             switch_input,
+             prev_switching_unit_gru_hidden
+         )
+
+         switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
+
+         # losses
+
+         switch_loss = self.zero
+
+         if discovery_phase:
+             # weight unreduced kl loss by switch gates
+
+             weighted_kl_loss = kl_loss * switch_beta
+             kl_loss = weighted_kl_loss.sum(dim = -1).mean()
+
+             # encourage less switching
+
+             switch_loss = switch_beta.mean()
+         else:
+             kl_loss = self.zero
+
+         # maybe hard switch, then use associative scan
+
+         if hard_switch:
+             hard_switch_beta = (switch_beta > 0.5).float()
+             switch_beta = straight_through(switch_beta, hard_switch_beta)
+
+         forget = 1. - switch_beta
+
+         # gated codes (or soft distribution)
+
+         gated_codes = self.switch_gating(switch_beta, sampled_codes * forget, prev = prev_switch_gated_hiddens)
+
+         next_switch_gated_codes = gated_codes[:, -1]
+
+         # decoder
+
+         decoder_out = self.decoder(gated_codes)
+
+         w1, w2 = self.to_hyper_network_weights(decoder_out)
+         hypernetwork_weight = einsum(w1, w2, '... i r, ... j r -> ... i j')
+
+         # generating the residual stream controlling signal
+
+         control_signal = einsum(residual_stream, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
+
+         # returning
+
+         next_hiddens = (
+             next_action_proposer_hidden,
+             next_switching_unit_gru_hidden,
+             next_switch_gated_codes,
+             sampled_codes[:, -1:]
+         )
+
+         return control_signal, MetaControllerOutput(next_hiddens, binary_logits, sampled_codes, kl_loss, switch_loss)
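The new metacontroller/metacontroller_with_binary_mapper.py mirrors the existing MetaController but replaces its continuous bottleneck with discrete BinaryMapper codes; the decoder then turns the gated one-hot code into a low-rank hypernetwork (two rank-r factors per timestep) that produces the control signal for the residual stream. A hedged usage sketch, with constructor and forward arguments taken from the code above (the package and its dependencies must be installed; shapes and values are illustrative, not a definitive recipe):

import torch
from metacontroller.metacontroller_with_binary_mapper import MetaControllerWithBinaryMapper

controller = MetaControllerWithBinaryMapper(
    dim_model = 512,           # width of the transformer residual stream being controlled
    dim_code_bits = 4,         # binary mapper over 4 bits of code
    hypernetwork_low_rank = 8  # must stay below num_codes per the assert in __init__
)

residual_stream = torch.randn(2, 12, 512)
episode_lens = torch.tensor([12, 7])

# discovery phase: bidirectional encoding over whole (masked) episodes
control_signal, out = controller(residual_stream, discovery_phase = True, episode_lens = episode_lens)

# internal RL phase: one timestep at a time, feeding the returned output back in as cache
step_signal, out = controller(residual_stream[:, :1], cache = out)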
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: metacontroller-pytorch
- Version: 0.0.25
+ Version: 0.0.27
  Summary: Transformer Metacontroller
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -42,6 +42,7 @@ Requires-Dist: loguru
  Requires-Dist: memmap-replay-buffer>=0.0.23
  Requires-Dist: torch-einops-utils>=0.0.16
  Requires-Dist: torch>=2.5
+ Requires-Dist: vector-quantize-pytorch>=1.27.20
  Requires-Dist: x-evolution>=0.1.23
  Requires-Dist: x-mlps-pytorch
  Requires-Dist: x-transformers
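The new vector-quantize-pytorch requirement supplies the BinaryMapper used by the module above. A small sketch of how that module exercises it, with the constructor and call signature taken verbatim from the diff (not independently checked against that library):

import torch
from vector_quantize_pytorch import BinaryMapper

mapper = BinaryMapper(bits = 4, kl_loss_threshold = 0.)    # as constructed in __init__ above
logits = torch.randn(2, 10, 4)                             # (batch, seq, bits) binary logits

one_hot, kl_loss = mapper(logits, temperature = 1., reduce_aux_kl_loss = False)
print(one_hot.shape, mapper.num_codes)                     # one-hot codes and the code count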
@@ -0,0 +1,7 @@
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+ metacontroller/metacontroller.py,sha256=LWEq069EnBP3Sr6FTiDtz0cM5SFFT1zl35WkU6_kWGA,14451
+ metacontroller/metacontroller_with_binary_mapper.py,sha256=GCvyF-5XILiexKQKu26h8NroTyeS7ksS1Q02mN5EGVw,8014
+ metacontroller_pytorch-0.0.27.dist-info/METADATA,sha256=X_cwahEVbf7nS7c7QEi1t-kBqezkjycP4fSFKL1D-rk,4411
+ metacontroller_pytorch-0.0.27.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ metacontroller_pytorch-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ metacontroller_pytorch-0.0.27.dist-info/RECORD,,
@@ -1,6 +0,0 @@
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
- metacontroller/metacontroller.py,sha256=icToDxXPknHG5C5hTzVaVOCibYbJ3aDLmZlaMc3Xge0,14275
- metacontroller_pytorch-0.0.25.dist-info/METADATA,sha256=HItPrlXUrJhZ1ZmpVU8JNftpyazBvJ3GVlOJPWL8NKE,4363
- metacontroller_pytorch-0.0.25.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- metacontroller_pytorch-0.0.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- metacontroller_pytorch-0.0.25.dist-info/RECORD,,