metacontroller-pytorch 0.0.15__py3-none-any.whl → 0.0.41__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Potentially problematic release.
This version of metacontroller-pytorch might be problematic.
- metacontroller/__init__.py +1 -1
- metacontroller/metacontroller.py +219 -48
- metacontroller/metacontroller_with_binary_mapper.py +315 -0
- metacontroller/transformer_with_resnet.py +194 -0
- {metacontroller_pytorch-0.0.15.dist-info → metacontroller_pytorch-0.0.41.dist-info}/METADATA +128 -2
- metacontroller_pytorch-0.0.41.dist-info/RECORD +8 -0
- metacontroller_pytorch-0.0.15.dist-info/RECORD +0 -6
- {metacontroller_pytorch-0.0.15.dist-info → metacontroller_pytorch-0.0.41.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.15.dist-info → metacontroller_pytorch-0.0.41.dist-info}/licenses/LICENSE +0 -0
metacontroller/__init__.py
CHANGED
@@ -1 +1 @@
-from metacontroller.metacontroller import MetaController
+from metacontroller.metacontroller import MetaController, Transformer
metacontroller/metacontroller.py
CHANGED
@@ -6,7 +6,7 @@ from collections import namedtuple
 from loguru import logger

 import torch
-from torch import nn, cat, stack, tensor
+from torch import nn, cat, stack, tensor, Tensor
 from torch.nn import Module, GRU, Linear, Identity
 import torch.nn.functional as F

@@ -18,7 +18,7 @@ from einops.layers.torch import Rearrange

 # external modules

-from x_transformers import Decoder
+from x_transformers import Encoder, Decoder
 from x_mlps_pytorch import Feedforwards
 from x_evolution import EvoStrategy

@@ -26,6 +26,9 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout

 from assoc_scan import AssocScan

+from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left
+from torch_einops_utils.save_load import save_load
+
 # constants

 LinearNoBias = partial(Linear, bias = False)
@@ -55,42 +58,100 @@ def straight_through(src, tgt):

 MetaControllerOutput = namedtuple('MetaControllerOutput', (
     'prev_hiddens',
+    'input_residual_stream',
     'action_dist',
     'actions',
-    '
+    'switch_beta',
+    'kl_loss',
+    'switch_loss'
 ))

+def z_score(t, eps = 1e-8):
+    return (t - t.mean()) / (t.std() + eps)
+
+def policy_loss(
+    meta_controller,
+    state,
+    old_log_probs,
+    actions,
+    advantages,
+    mask,
+    episode_lens = None,
+    eps_clip = 0.2
+):
+    # get new log probs
+
+    action_dist = meta_controller.get_action_dist_for_internal_rl(state)
+    new_log_probs = meta_controller.log_prob(action_dist, actions)
+
+    # calculate ratio
+
+    ratio = (new_log_probs - old_log_probs).exp()
+
+    # align ratio and advantages
+
+    ratio, advantages = align_dims_left((ratio, advantages))
+
+    # ppo surrogate loss
+
+    surr1 = ratio * advantages
+    surr2 = ratio.clamp(1 - eps_clip, 1 + eps_clip) * advantages
+
+    losses = -torch.min(surr1, surr2)
+
+    # masking
+
+    if exists(episode_lens):
+        mask, episode_mask = align_dims_left((mask, lens_to_mask(episode_lens, losses.shape[1])))
+        mask = mask & episode_mask
+
+    return masked_mean(losses, mask)
+
+@save_load()
 class MetaController(Module):
     def __init__(
         self,
-
+        dim_model,
         *,
+        dim_meta_controller = 256,
+        dim_latent = 128,
         switch_per_latent_dim = True,
         decoder_expansion_factor = 2.,
         decoder_depth = 1,
         hypernetwork_low_rank = 16,
-        assoc_scan_kwargs: dict = dict()
+        assoc_scan_kwargs: dict = dict(),
+        bidirectional_temporal_encoder_kwargs: dict = dict(
+            attn_dim_head = 32,
+            heads = 8
+        )
     ):
         super().__init__()
+        self.dim_model = dim_model
+        dim_meta = default(dim_meta_controller, dim_model)
+
+        # the linear that brings from model dimension
+
+        self.model_to_meta = Linear(dim_model, dim_meta)

-        # there are two phases, the first (discovery ssl phase) uses acausal with some ssm i don't really believe in - let's just use
+        # there are two phases, the first (discovery ssl phase) uses acausal with some ssm i don't really believe in - let's just use bidirectional attention as placeholder

-        self.
+        self.bidirectional_temporal_encoder = Encoder(dim = dim_meta, depth = 1, **bidirectional_temporal_encoder_kwargs)

-        self.emitter = GRU(
-        self.emitter_to_action_mean_log_var = Readout(
+        self.emitter = GRU(dim_meta * 2, dim_meta * 2)
+        self.emitter_to_action_mean_log_var = Readout(dim_meta * 2, num_continuous = dim_latent)

         # internal rl phase substitutes the acausal + emitter with a causal ssm

-        self.action_proposer = GRU(
-        self.action_proposer_mean_log_var = Readout(
+        self.action_proposer = GRU(dim_meta, dim_meta)
+        self.action_proposer_mean_log_var = Readout(dim_meta, num_continuous = dim_latent)

         # switching unit

         self.switch_per_latent_dim = switch_per_latent_dim

-        self.
-        self.
+        self.dim_latent = dim_latent
+        self.switching_unit = GRU(dim_meta + dim_latent, dim_meta)
+        self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)

         self.switch_gating = AssocScan(**assoc_scan_kwargs)

@@ -104,16 +165,26 @@ class MetaController(Module):
             dim_in = dim_latent,
             dim = dim_decoder_hidden,
             depth = decoder_depth,
-            dim_out = 2 * hypernetwork_low_rank *
+            dim_out = 2 * hypernetwork_low_rank * dim_model
         )

         self.to_hyper_network_weights = Rearrange('... (two d r) -> two ... d r', two = 2, r = hypernetwork_low_rank)

         self.register_buffer('zero', tensor(0.), persistent = False)

+    @property
+    def replay_buffer_field_dict(self):
+        return dict(
+            states = ('float', self.dim_model),
+            log_probs = ('float', self.dim_latent),
+            switch_betas = ('float', self.dim_latent if self.switch_per_latent_dim else 1),
+            latent_actions = ('float', self.dim_latent)
+        )
+
     def discovery_parameters(self):
         return [
-            *self.
+            *self.model_to_meta.parameters(),
+            *self.bidirectional_temporal_encoder.parameters(),
             *self.emitter.parameters(),
             *self.emitter_to_action_mean_log_var.parameters(),
             *self.decoder.parameters(),
@@ -126,54 +197,99 @@ class MetaController(Module):
             *self.action_proposer_mean_log_var.parameters()
         ]

+    def get_action_dist_for_internal_rl(
+        self,
+        residual_stream
+    ):
+        meta_embed = self.model_to_meta(residual_stream)
+
+        proposed_action_hidden, _ = self.action_proposer(meta_embed)
+
+        return self.action_proposer_mean_log_var(proposed_action_hidden)
+
+    def log_prob(
+        self,
+        action_dist,
+        sampled_latent_action
+    ):
+        return self.action_proposer_mean_log_var.log_prob(action_dist, sampled_latent_action)
+
     def forward(
         self,
         residual_stream,
         cache: MetaControllerOutput | None = None,
         discovery_phase = False,
-        hard_switch =
-        temperature = 1
+        hard_switch = None,
+        temperature = 1.,
+        episode_lens: Tensor | None = None
     ):
+        device = residual_stream.device

         # destruct prev cache

-        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens = cache.prev_hiddens if exists(cache) else ((None,) *
+        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens, prev_sampled_latent_action = cache.prev_hiddens if exists(cache) else ((None,) * 4)

         # getting proposed action for the two phases

         next_action_proposer_hidden = None

+        meta_embed = self.model_to_meta(residual_stream)
+
+        hard_switch = default(hard_switch, not discovery_phase) # think during internal RL phase, it needs to be a hard switch, then only the actions emitted during the switch is reinforced
+
         if discovery_phase:
             logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')

-
-            temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
+            mask = maybe(lens_to_mask)(episode_lens, meta_embed.shape[1])

-
+            encoded_temporal = self.bidirectional_temporal_encoder(meta_embed, mask = mask)
+
+            proposed_action_hidden, _ = self.emitter(cat((encoded_temporal, meta_embed), dim = -1))
             readout = self.emitter_to_action_mean_log_var

         else: # else internal rl phase

-            proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(
+            proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(meta_embed, prev_action_proposer_hidden)
             readout = self.action_proposer_mean_log_var

         # sample from the gaussian as the action from the meta controller

         action_dist = readout(proposed_action_hidden)

-
+        sampled_latent_action = readout.sample(action_dist, temperature = temperature)

         # switching unit timer

-        batch,
+        batch, seq_len, dim = sampled_latent_action.shape
+
+        # initialize prev sampled latent action to be zeros if not available (for first timestep and for discovery phase)

-
+        if not exists(prev_sampled_latent_action):
+            prev_sampled_latent_action = torch.zeros(batch, 1, self.dim_latent, device = device)
+
+        if discovery_phase:
+            z_prev = cat((prev_sampled_latent_action, sampled_latent_action[:, :-1]), dim = 1)
+
+        else:
+            # else during inference, use the previous sampled latent action
+
+            assert seq_len == 1, f'inference RL phase must be done one token at a time'
+            z_prev = prev_sampled_latent_action
+
+        # switch input is previous latent action and the embedding
+
+        switch_input = torch.cat((meta_embed, z_prev), dim=-1)
+
+        switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(
+            switch_input,
+            prev_switching_unit_gru_hidden
+        )

         switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()

         # need to encourage normal distribution

-        kl_loss = self.zero
+        kl_loss = switch_loss = self.zero

         if discovery_phase:
             mean, log_var = action_dist.unbind(dim = -1)
@@ -188,6 +304,10 @@ class MetaController(Module):
             kl_loss = kl_loss * switch_beta
             kl_loss = kl_loss.sum(dim = -1).mean()

+            # encourage less switching
+
+            switch_loss = switch_beta.mean()
+
         # maybe hard switch, then use associative scan

         if hard_switch:
@@ -195,7 +315,7 @@ class MetaController(Module):
             switch_beta = straight_through(switch_beta, hard_switch_beta)

         forget = 1. - switch_beta
-        gated_action = self.switch_gating(switch_beta,
+        gated_action = self.switch_gating(switch_beta, sampled_latent_action * forget, prev = prev_switch_gated_hiddens)

         next_switch_gated_action = gated_action[:, -1]

@@ -208,27 +328,40 @@ class MetaController(Module):

         # generating the residual stream controlling signal

-        control_signal = einsum(
-
-        modified_residual_stream = residual_stream + control_signal
+        control_signal = einsum(residual_stream, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')

         # returning

         next_hiddens = (
             next_action_proposer_hidden,
             next_switching_unit_gru_hidden,
-            next_switch_gated_action
+            next_switch_gated_action,
+            sampled_latent_action[:, -1:]
         )

-
+        # squeeze out the last dimension of switch_beta if single gate for all latent dimensions
+
+        if not self.switch_per_latent_dim:
+            switch_beta = rearrange(switch_beta, '... 1 -> ...')
+
+        return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, switch_beta, kl_loss, switch_loss)
+
+MetaController.policy_loss = policy_loss

 # main transformer, which is subsumed into the environment after behavioral cloning

+Hiddens = namedtuple('Hiddens', (
+    'lower_body',
+    'meta_controller',
+    'upper_body'
+))
+
 TransformerOutput = namedtuple('TransformerOutput', (
     'residual_stream_latent',
     'prev_hiddens'
 ))

+@save_load()
 class Transformer(Module):
     def __init__(
         self,
@@ -243,7 +376,7 @@ class Transformer(Module):
         super().__init__()

         if isinstance(lower_body, dict):
-            lower_body = Decoder(dim = dim, **lower_body)
+            lower_body = Decoder(dim = dim, pre_norm_has_final_norm = False, **lower_body)

         if isinstance(upper_body, dict):
             upper_body = Decoder(dim = dim, **upper_body)
@@ -281,26 +414,38 @@ class Transformer(Module):
     def forward(
         self,
         state,
-
+        actions: Tensor | None = None,
         meta_controller: Module | None = None,
         cache: TransformerOutput | None = None,
         discovery_phase = False,
+        force_behavior_cloning = False,
         meta_controller_temperature = 1.,
         return_raw_action_dist = False,
         return_latents = False,
         return_cache = False,
+        episode_lens: Tensor | None = None
     ):
+        device = state.device
+
+        # meta controller is either given or already given at init
+
         meta_controller = default(meta_controller, self.meta_controller)

-
+        if force_behavior_cloning:
+            assert not discovery_phase, 'discovery phase cannot be set to True if force behavioral cloning is set to True'
+            meta_controller = None

-
+        has_meta_controller = exists(meta_controller)
+
+        assert not (discovery_phase and not has_meta_controller), 'meta controller must be made available during discovery phase'
+
+        behavioral_cloning = force_behavior_cloning or (not has_meta_controller and not return_raw_action_dist)

         # by default, if meta controller is passed in, transformer is no grad

-        lower_transformer_context = nullcontext if not
-        meta_controller_context = nullcontext if
-        upper_transformer_context = nullcontext if (not
+        lower_transformer_context = nullcontext if not has_meta_controller else torch.no_grad
+        meta_controller_context = nullcontext if has_meta_controller else torch.no_grad
+        upper_transformer_context = nullcontext if (not has_meta_controller or discovery_phase) else torch.no_grad

         # handle cache

@@ -308,16 +453,31 @@ class Transformer(Module):

         # handle maybe behavioral cloning

-        if behavioral_cloning:
+        if behavioral_cloning or discovery_phase: # during behavior cloning and discovery phase, the network is predicting / reconstructing the next token
+
+            assert exists(actions), f'`actions` cannot be empty when doing discovery or behavioral cloning'
+
             state, target_state = state[:, :-1], state[:, 1:]
-
+            actions, target_actions = actions[:, :-1], actions[:, 1:]
+
+            if exists(episode_lens):
+                episode_lens = (episode_lens - 1).clamp(min = 0)

         # transformer lower body

         with lower_transformer_context():

             state_embed = self.state_embed(state)
-
+
+            # handle no past action for first timestep
+
+            if exists(actions):
+                action_embed = self.action_embed(actions)
+            else:
+                action_embed = state_embed[:, 0:0] # empty action embed
+
+            if action_embed.shape[-2] == (state_embed.shape[-2] - 1):
+                action_embed = pad_at_dim(action_embed, (1, 0), dim = 1)

             embed = state_embed + action_embed

@@ -327,10 +487,12 @@ class Transformer(Module):

         with meta_controller_context():

-            if exists(meta_controller):
-
+            if exists(meta_controller) and not behavioral_cloning:
+                control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature, episode_lens = episode_lens)
             else:
-
+                control_signal, next_meta_hiddens = self.zero, None
+
+            modified_residual_stream = residual_stream + control_signal

         # modified residual stream sent back to transformer upper body

@@ -345,13 +507,22 @@ class Transformer(Module):
         # maybe return behavior cloning loss

         if behavioral_cloning:
+
+            loss_mask = maybe(lens_to_mask)(episode_lens, state.shape[1])
+
             state_dist_params = self.state_readout(attended)
-            state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)
+            state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state, mask = loss_mask)

-            action_clone_loss = self.action_readout.calculate_loss(dist_params,
+            action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions, mask = loss_mask)

             return state_clone_loss, action_clone_loss

+        elif discovery_phase:
+
+            action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)
+
+            return action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss
+
         # returning

         return_one = not (return_latents or return_cache)
@@ -359,4 +530,4 @@ class Transformer(Module):
         if return_one:
             return dist_params

-        return dist_params, TransformerOutput(residual_stream, (next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
+        return dist_params, TransformerOutput(residual_stream, Hiddens(next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
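For orientation, the `policy_loss` added above is a standard PPO-style clipped surrogate over the meta controller's latent actions. A minimal standalone sketch of just the clipping arithmetic (the tensors and shapes below are illustrative assumptions, not package API):

```python
import torch

eps_clip = 0.2

# illustrative per-timestep quantities for 4 rollouts of 8 steps each
old_log_probs = torch.randn(4, 8)
new_log_probs = old_log_probs + 0.1 * torch.randn(4, 8)
advantages = torch.randn(4, 8)

# probability ratio between the current and the behavior policy
ratio = (new_log_probs - old_log_probs).exp()

# unclipped and clipped surrogates; take the pessimistic one and negate for gradient descent
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - eps_clip, 1 + eps_clip) * advantages

loss = -torch.min(surr1, surr2).mean()
```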
metacontroller/metacontroller_with_binary_mapper.py
ADDED
@@ -0,0 +1,315 @@
+from __future__ import annotations
+from contextlib import nullcontext
+
+from functools import partial
+from collections import namedtuple
+from loguru import logger
+
+import torch
+from torch import nn, cat, stack, tensor, Tensor
+from torch.nn import Module, GRU, Linear, Identity
+import torch.nn.functional as F
+
+# einops
+
+import einx
+from einops import einsum, rearrange, repeat, reduce
+from einops.layers.torch import Rearrange
+
+# external modules
+
+from x_transformers import Encoder, Decoder
+from x_mlps_pytorch import Feedforwards
+
+from assoc_scan import AssocScan
+
+from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, align_dims_left
+from torch_einops_utils.save_load import save_load
+
+from vector_quantize_pytorch import BinaryMapper
+
+from metacontroller.metacontroller import MetaControllerOutput, policy_loss
+
+# constants
+
+LinearNoBias = partial(Linear, bias = False)
+
+GRU = partial(GRU, batch_first = True)
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def default(*args):
+    for arg in args:
+        if exists(arg):
+            return arg
+    return None
+
+def straight_through(src, tgt):
+    return tgt + src - src.detach()
+
+def log(t, eps = 1e-20):
+    return t.clamp_min(eps).log()
+
+# meta controller
+
+@save_load()
+class MetaControllerWithBinaryMapper(Module):
+    def __init__(
+        self,
+        dim_model,
+        *,
+        dim_meta_controller = 256,
+        dim_code_bits = 4,
+        switch_per_code = False,
+        decoder_expansion_factor = 2.,
+        decoder_depth = 1,
+        hypernetwork_low_rank = 16,
+        assoc_scan_kwargs: dict = dict(),
+        bidirectional_temporal_encoder_kwargs: dict = dict(
+            attn_dim_head = 32, heads = 8
+        ),
+        kl_loss_threshold = 0.
+    ):
+        super().__init__()
+        self.dim_model = dim_model
+        assert not switch_per_code, 'switch_per_code is not supported for binary mapper'
+
+        dim_meta = default(dim_meta_controller, dim_model)
+
+        self.model_to_meta = Linear(dim_model, dim_meta)
+
+        self.bidirectional_temporal_encoder = Encoder(dim = dim_meta, depth = 1, **bidirectional_temporal_encoder_kwargs)
+
+        self.emitter = GRU(dim_meta * 2, dim_meta * 2)
+        self.emitter_to_binary_logits = Linear(dim_meta * 2, dim_code_bits)
+
+        self.action_proposer = GRU(dim_meta, dim_meta)
+        self.proposer_to_binary_logits = Linear(dim_meta, dim_code_bits)
+
+        # binary mapper
+        # proposed in https://arxiv.org/abs/2510.17558 as a more stable alternative to VAE by François Fleuret
+
+        self.binary_mapper = BinaryMapper(
+            bits = dim_code_bits,
+            kl_loss_threshold = kl_loss_threshold
+        )
+
+        self.dim_code_bits = dim_code_bits
+        self.num_codes = self.binary_mapper.num_codes
+
+        # switching unit
+
+        self.switch_per_code = switch_per_code
+
+        self.switching_unit = GRU(dim_meta + self.num_codes, dim_meta)
+        self.to_switching_unit_beta = nn.Linear(dim_meta, self.num_codes if switch_per_code else 1, bias = False)
+
+        self.switch_gating = AssocScan(**assoc_scan_kwargs)
+
+        # decoder
+
+        assert hypernetwork_low_rank < self.num_codes
+
+        dim_decoder_hidden = int(self.num_codes * decoder_expansion_factor)
+
+        self.decoder = Feedforwards(
+            dim_in = self.num_codes,
+            dim = dim_decoder_hidden,
+            depth = decoder_depth,
+            dim_out = 2 * hypernetwork_low_rank * dim_model
+        )
+
+        self.to_hyper_network_weights = Rearrange('... (two d r) -> two ... d r', two = 2, r = hypernetwork_low_rank)
+
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
+    @property
+    def replay_buffer_field_dict(self):
+        return dict(
+            states = ('float', self.dim_model),
+            log_probs = ('float', self.dim_code_bits),
+            switch_betas = ('float', self.num_codes if self.switch_per_code else 1),
+            latent_actions = ('float', self.num_codes)
+        )
+
+    def discovery_parameters(self):
+        return [
+            *self.model_to_meta.parameters(),
+            *self.bidirectional_temporal_encoder.parameters(),
+            *self.emitter.parameters(),
+            *self.emitter_to_binary_logits.parameters(),
+            *self.binary_mapper.parameters(),
+            *self.decoder.parameters(),
+            *self.switch_gating.parameters()
+        ]
+
+    def internal_rl_parameters(self):
+        return [
+            *self.action_proposer.parameters(),
+            *self.proposer_to_binary_logits.parameters()
+        ]
+
+    def get_action_dist_for_internal_rl(
+        self,
+        residual_stream
+    ):
+        meta_embed = self.model_to_meta(residual_stream)
+
+        proposed_action_hidden, _ = self.action_proposer(meta_embed)
+
+        return self.proposer_to_binary_logits(proposed_action_hidden)
+
+    def log_prob(
+        self,
+        action_dist,
+        sampled_latent_action
+    ):
+        log_probs = stack((
+            F.logsigmoid(action_dist),
+            F.logsigmoid(-action_dist)
+        ), dim = -1)
+
+        indices = sampled_latent_action.argmax(dim = -1)
+        codes = self.binary_mapper.codes[indices].long()
+
+        codes = rearrange(codes, '... -> ... 1')
+        action_log_probs = log_probs.gather(-1, codes)
+        action_log_probs = rearrange(action_log_probs, '... 1 -> ...')
+
+        return action_log_probs
+
+    def forward(
+        self,
+        residual_stream,
+        cache: MetaControllerOutput | None = None,
+        discovery_phase = False,
+        hard_switch = None,
+        temperature = 1.,
+        episode_lens: Tensor | None = None
+    ):
+        device = residual_stream.device
+
+        # destruct prev cache
+
+        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens, prev_sampled_code = cache.prev_hiddens if exists(cache) else ((None,) * 4)
+
+        # getting proposed action for the two phases
+
+        next_action_proposer_hidden = None
+
+        meta_embed = self.model_to_meta(residual_stream)
+
+        hard_switch = default(hard_switch, not discovery_phase) # think during internal RL phase, it needs to be a hard switch, then only the actions emitted during the switch is reinforced
+
+        if discovery_phase:
+            mask = maybe(lens_to_mask)(episode_lens, meta_embed.shape[1])
+
+            encoded_temporal = self.bidirectional_temporal_encoder(meta_embed, mask = mask)
+
+            proposed_action_hidden, _ = self.emitter(cat((encoded_temporal, meta_embed), dim = -1))
+            to_logits = self.emitter_to_binary_logits
+
+        else: # else internal rl phase
+
+            proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(meta_embed, prev_action_proposer_hidden)
+            to_logits = self.proposer_to_binary_logits
+
+        # sample from the binary mapper
+
+        binary_logits = to_logits(proposed_action_hidden)
+
+        one_hot, kl_loss = self.binary_mapper(
+            binary_logits,
+            temperature = temperature,
+            reduce_aux_kl_loss = False
+        )
+
+        # bottled action is now the one-hot sparse codes (with straight-through)
+
+        sampled_codes = one_hot
+
+        # switching unit timer
+
+        batch, seq_len, dim = sampled_codes.shape
+
+        if not exists(prev_sampled_code):
+            prev_sampled_code = torch.zeros(batch, 1, self.num_codes, device = device)
+
+        if discovery_phase:
+            z_prev = cat((prev_sampled_code, sampled_codes[:, :-1]), dim = 1)
+        else:
+            assert seq_len == 1, f'inference RL phase must be done one token at a time'
+            z_prev = prev_sampled_code
+
+        switch_input = torch.cat((meta_embed, z_prev), dim=-1)
+
+        switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(
+            switch_input,
+            prev_switching_unit_gru_hidden
+        )
+
+        switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
+
+        # losses
+
+        switch_loss = self.zero
+
+        if discovery_phase:
+            # weight unreduced kl loss by switch gates
+
+            kl_loss, switch_beta = align_dims_left((kl_loss, switch_beta))
+
+            weighted_kl_loss = kl_loss * switch_beta
+            kl_loss = weighted_kl_loss.sum(dim = -1).mean()
+
+            # encourage less switching
+
+            switch_loss = switch_beta.mean()
+        else:
+            kl_loss = self.zero
+
+        # maybe hard switch, then use associative scan
+
+        if hard_switch:
+            hard_switch_beta = (switch_beta > 0.5).float()
+            switch_beta = straight_through(switch_beta, hard_switch_beta)
+
+        forget = 1. - switch_beta
+
+        # gated codes (or soft distribution)
+
+        gated_codes = self.switch_gating(switch_beta, sampled_codes * forget, prev = prev_switch_gated_hiddens)
+
+        next_switch_gated_codes = gated_codes[:, -1]
+
+        # decoder
+
+        decoder_out = self.decoder(gated_codes)
+
+        w1, w2 = self.to_hyper_network_weights(decoder_out)
+        hypernetwork_weight = einsum(w1, w2, '... i r, ... j r -> ... i j')
+
+        # generating the residual stream controlling signal
+
+        control_signal = einsum(residual_stream, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
+
+        # returning
+
+        next_hiddens = (
+            next_action_proposer_hidden,
+            next_switching_unit_gru_hidden,
+            next_switch_gated_codes,
+            sampled_codes[:, -1:]
+        )
+
+        # squeeze out the last dimension of switch_beta if single gate for all codes
+
+        if not self.switch_per_code:
+            switch_beta = rearrange(switch_beta, '... 1 -> ...')
+
+        return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, switch_beta, kl_loss, switch_loss)
+
+MetaControllerWithBinaryMapper.policy_loss = policy_loss
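The `log_prob` method above scores a sampled one-hot code by summing per-bit log-sigmoid terms gathered from the binary logits. A minimal standalone sketch of that gathering, assuming a locally built bit-pattern table and a particular bit/sign convention (the real module takes both from `BinaryMapper`):

```python
import torch
import torch.nn.functional as F

bits = 4
num_codes = 2 ** bits                                   # 16 codes for 4 bits

# bit pattern for every code, one row per code: (num_codes, bits); ordering is an assumption
codes = (torch.arange(num_codes)[:, None] >> torch.arange(bits)) & 1

binary_logits = torch.randn(2, 1, bits)                 # (batch, time, bits)
sampled_one_hot = F.one_hot(torch.randint(0, num_codes, (2, 1)), num_codes)

# log-probability of each bit taking either value (sign convention assumed)
log_p = torch.stack((F.logsigmoid(binary_logits), F.logsigmoid(-binary_logits)), dim = -1)

# look up the sampled code's bit pattern and gather the matching per-bit log probs
indices = sampled_one_hot.argmax(dim = -1)              # (batch, time)
picked_bits = codes[indices]                            # (batch, time, bits)
log_prob_per_bit = log_p.gather(-1, picked_bits[..., None]).squeeze(-1)

print(log_prob_per_bit.shape)                           # torch.Size([2, 1, 4])
```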
metacontroller/transformer_with_resnet.py
ADDED
@@ -0,0 +1,194 @@
+from __future__ import annotations
+
+import torch
+from torch import nn, Tensor
+from torch.nn import Module, ModuleList
+from einops import rearrange
+from einops.layers.torch import Rearrange
+
+from metacontroller.metacontroller import Transformer
+
+from torch_einops_utils import pack_with_inverse
+
+# resnet components
+
+def exists(v):
+    return v is not None
+
+class BasicBlock(Module):
+    expansion = 1
+
+    def __init__(
+        self,
+        dim,
+        dim_out,
+        stride = 1,
+        downsample: Module | None = None
+    ):
+        super().__init__()
+        self.conv1 = nn.Conv2d(dim, dim_out, 3, stride = stride, padding = 1, bias = False)
+        self.bn1 = nn.BatchNorm2d(dim_out)
+        self.relu = nn.ReLU(inplace = True)
+        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, padding = 1, bias = False)
+        self.bn2 = nn.BatchNorm2d(dim_out)
+        self.downsample = downsample
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if exists(self.downsample):
+            identity = self.downsample(x)
+
+        out += identity
+        return self.relu(out)
+
+class Bottleneck(Module):
+    expansion = 4
+
+    def __init__(
+        self,
+        dim,
+        dim_out,
+        stride = 1,
+        downsample: Module | None = None
+    ):
+        super().__init__()
+        width = dim_out # simple resnet shortcut
+        self.conv1 = nn.Conv2d(dim, width, 1, bias = False)
+        self.bn1 = nn.BatchNorm2d(width)
+        self.conv2 = nn.Conv2d(width, width, 3, stride = stride, padding = 1, bias = False)
+        self.bn2 = nn.BatchNorm2d(width)
+        self.conv3 = nn.Conv2d(width, dim_out * self.expansion, 1, bias = False)
+        self.bn3 = nn.BatchNorm2d(dim_out * self.expansion)
+        self.relu = nn.ReLU(inplace = True)
+        self.downsample = downsample
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if exists(self.downsample):
+            identity = self.downsample(x)
+
+        out += identity
+        return self.relu(out)
+
+class ResNet(Module):
+    def __init__(
+        self,
+        block: type[BasicBlock | Bottleneck],
+        layers: list[int],
+        num_classes = 1000,
+        channels = 3
+    ):
+        super().__init__()
+        self.inplanes = 64
+
+        self.conv1 = nn.Conv2d(channels, 64, kernel_size = 7, stride = 2, padding = 3, bias = False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace = True)
+        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
+
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride = 2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride = 2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride = 2)
+
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.flatten = Rearrange('b c 1 1 -> b c')
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(
+        self,
+        block: type[BasicBlock | Bottleneck],
+        planes: int,
+        blocks: int,
+        stride: int = 1
+    ) -> nn.Sequential:
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride = stride, bias = False),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = self.flatten(x)
+        x = self.fc(x)
+        return x
+
+# resnet factory
+
+def resnet18(num_classes: any = 1000):
+    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
+
+def resnet34(num_classes: any = 1000):
+    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)
+
+def resnet50(num_classes: any = 1000):
+    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)
+
+# transformer with resnet
+
+class TransformerWithResnet(Transformer):
+    def __init__(
+        self,
+        *,
+        resnet_type = 'resnet18',
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        resnet_klass = resnet18
+        if resnet_type == 'resnet34':
+            resnet_klass = resnet34
+        elif resnet_type == 'resnet50':
+            resnet_klass = resnet50
+
+        self.resnet_dim = kwargs['state_embed_readout']['num_continuous']
+        self.visual_encoder = resnet_klass(num_classes = self.resnet_dim)
+
+    def visual_encode(self, x: Tensor) -> Tensor:
+        if x.shape[-1] == 3:
+            x = rearrange(x, '... h w c -> ... c h w')
+
+        x, inverse = pack_with_inverse(x, '* c h w')
+
+        h = self.visual_encoder(x)
+
+        return inverse(h, '* d')
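A hypothetical usage sketch (not taken from the package README) of the wrapper above: encode channels-last image observations into continuous states before feeding them to the transformer. The constructor arguments mirror the `Transformer` example in the README below; the frame shape and 384-dim state size are illustrative assumptions.

```python
import torch
from metacontroller.transformer_with_resnet import TransformerWithResnet

model = TransformerWithResnet(
    resnet_type = 'resnet18',
    dim = 512,
    action_embed_readout = dict(num_discrete = 4),
    state_embed_readout = dict(num_continuous = 384),
    lower_body = dict(depth = 2),
    upper_body = dict(depth = 2)
)

frames = torch.randn(2, 16, 64, 64, 3)    # (batch, time, height, width, channels)

# visual_encode flattens the leading dims, runs the resnet, then restores them
states = model.visual_encode(frames)      # -> (2, 16, 384)
```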
{metacontroller_pytorch-0.0.15.dist-info → metacontroller_pytorch-0.0.41.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.15
+Version: 0.0.41
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -39,7 +39,10 @@ Requires-Dist: discrete-continuous-embed-readout>=0.1.12
 Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
 Requires-Dist: loguru
+Requires-Dist: memmap-replay-buffer>=0.0.25
+Requires-Dist: torch-einops-utils>=0.0.19
 Requires-Dist: torch>=2.5
+Requires-Dist: vector-quantize-pytorch>=1.27.20
 Requires-Dist: x-evolution>=0.1.23
 Requires-Dist: x-mlps-pytorch
 Requires-Dist: x-transformers
@@ -54,13 +57,110 @@ Description-Content-Type: text/markdown

 Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)

+## Install
+
+```shell
+$ pip install metacontroller-pytorch
+```
+
+## Appreciation
+
+- [Pranoy](https://github.com/pranoyr) for submitting a pull request for fixing the previous latent action not being included in the inputs to the switching unit
+
+- [Diego Calanzone](https://github.com/ddidacus) for proposing testing on BabyAI gridworld task, and submitting the [pull request](https://github.com/lucidrains/metacontroller/pull/3) for behavior cloning and discovery phase training for it!
+
+## Usage
+
+```python
+import torch
+from metacontroller import Transformer, MetaController
+
+# 1. initialize model
+
+model = Transformer(
+    dim = 512,
+    action_embed_readout = dict(num_discrete = 4),
+    state_embed_readout = dict(num_continuous = 384),
+    lower_body = dict(depth = 2),
+    upper_body = dict(depth = 2)
+)
+
+state = torch.randn(2, 128, 384)
+actions = torch.randint(0, 4, (2, 128))
+
+# 2. behavioral cloning (BC)
+
+state_loss, action_loss = model(state, actions)
+(state_loss + action_loss).backward()
+
+# 3. discovery phase
+
+meta_controller = MetaController(
+    dim_model = 512,
+    dim_meta_controller = 256,
+    dim_latent = 128
+)
+
+action_recon_loss, kl_loss, switch_loss = model(
+    state,
+    actions,
+    meta_controller = meta_controller,
+    discovery_phase = True
+)
+
+(action_recon_loss + kl_loss + switch_loss).backward()
+
+# 4. internal rl phase (GRPO)
+
+# ... collect trajectories ...
+
+logits, cache = model(
+    one_state,
+    past_action_id,
+    meta_controller = meta_controller,
+    return_cache = True
+)
+
+meta_output = cache.prev_hiddens.meta_controller
+old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
+
+# ... calculate advantages ...
+
+loss = meta_controller.policy_loss(
+    group_states,
+    group_old_log_probs,
+    group_latent_actions,
+    group_advantages,
+    group_switch_betas
+)
+
+loss.backward()
+```
+
+Or using [evolutionary strategies](https://arxiv.org/abs/2511.16652) for the last portion
+
+```python
+# 5. evolve (ES over GRPO)
+
+model.meta_controller = meta_controller
+
+def environment_callable(model):
+    # return a fitness score
+    return 1.0
+
+model.evolve(
+    num_generations = 10,
+    environment = environment_callable
+)
+```
+
 ## Citations

 ```bibtex
 @misc{kobayashi2025emergenttemporalabstractionsautoregressive,
     title = {Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning},
     author = {Seijin Kobayashi and Yanick Schimpf and Maximilian Schlegel and Angelika Steger and Maciej Wolczyk and Johannes von Oswald and Nino Scherrer and Kaitlin Maile and Guillaume Lajoie and Blake A. Richards and Rif A. Saurous and James Manyika and Blaise Agüera y Arcas and Alexander Meulemans and João Sacramento},
-    year={2025},
+    year = {2025},
     eprint = {2512.20605},
     archivePrefix = {arXiv},
     primaryClass = {cs.LG},
@@ -78,3 +178,29 @@ Implementation of the MetaController proposed in [Emergent temporal abstractions
     url = {https://api.semanticscholar.org/CorpusID:279464702}
 }
 ```
+
+```bibtex
+@misc{hwang2025dynamicchunkingendtoendhierarchical,
+    title = {Dynamic Chunking for End-to-End Hierarchical Sequence Modeling},
+    author = {Sukjun Hwang and Brandon Wang and Albert Gu},
+    year = {2025},
+    eprint = {2507.07955},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2507.07955},
+}
+```
+
+```bibtex
+@misc{fleuret2025freetransformer,
+    title = {The Free Transformer},
+    author = {François Fleuret},
+    year = {2025},
+    eprint = {2510.17558},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2510.17558},
+}
+```
+
+*Life can only be understood backwards; but it must be lived forwards* - Søren Kierkegaard
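The internal RL step in the README above leaves advantage computation to the user (`# ... calculate advantages ...`). One common choice for a GRPO-style step is group-normalized (z-scored) rewards; a minimal sketch using the `z_score` helper added in this version, with illustrative reward values (whether the package intends the helper for exactly this use is an assumption):

```python
import torch
from metacontroller.metacontroller import z_score

# rewards for a group of rollouts that share the same starting state
group_rewards = torch.tensor([1.0, 0.0, 0.5, 2.0])

# normalize within the group: zero mean, unit variance
group_advantages = z_score(group_rewards)
```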
metacontroller_pytorch-0.0.41.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+metacontroller/__init__.py,sha256=iSKbCDp3UrWhZg7SIJFYNjdVQU56u-vqZarE6qCSX74,70
+metacontroller/metacontroller.py,sha256=bhgCqqM-dfysGrMtZYe2w87lRVkf8fETjxUCdjrnI8Q,17386
+metacontroller/metacontroller_with_binary_mapper.py,sha256=Ce5-O95_pLuWNA3aZTlKrTGbc5cemb61tBtJBdSiLx4,9843
+metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
+metacontroller_pytorch-0.0.41.dist-info/METADATA,sha256=IvP-wC73xCnT8X1aul1IfcaC4fUwRq9Y2UB1h0JG5TI,6822
+metacontroller_pytorch-0.0.41.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.41.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.41.dist-info/RECORD,,
metacontroller_pytorch-0.0.15.dist-info/RECORD
DELETED
@@ -1,6 +0,0 @@
-metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
-metacontroller/metacontroller.py,sha256=ug3xeMTZKApTF8oOPx9hWypeDjRflf1IJp8RiysXgTo,11618
-metacontroller_pytorch-0.0.15.dist-info/METADATA,sha256=9d39BpcuVeOVVSD66lCVHCK1GjrkeKzRtxKOPOc-7xQ,3736
-metacontroller_pytorch-0.0.15.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-metacontroller_pytorch-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-metacontroller_pytorch-0.0.15.dist-info/RECORD,,
{metacontroller_pytorch-0.0.15.dist-info → metacontroller_pytorch-0.0.41.dist-info}/WHEEL
RENAMED
File without changes

{metacontroller_pytorch-0.0.15.dist-info → metacontroller_pytorch-0.0.41.dist-info}/licenses/LICENSE
RENAMED
File without changes