metacontroller-pytorch 0.0.15 → 0.0.41 (py3-none-any.whl)

metacontroller/__init__.py
@@ -1 +1 @@
1
- from metacontroller.metacontroller import MetaController
1
+ from metacontroller.metacontroller import MetaController, Transformer
metacontroller/metacontroller.py
@@ -6,7 +6,7 @@ from collections import namedtuple
6
6
  from loguru import logger
7
7
 
8
8
  import torch
9
- from torch import nn, cat, stack, tensor
9
+ from torch import nn, cat, stack, tensor, Tensor
10
10
  from torch.nn import Module, GRU, Linear, Identity
11
11
  import torch.nn.functional as F
12
12
 
@@ -18,7 +18,7 @@ from einops.layers.torch import Rearrange
18
18
 
19
19
  # external modules
20
20
 
21
- from x_transformers import Decoder
21
+ from x_transformers import Encoder, Decoder
22
22
  from x_mlps_pytorch import Feedforwards
23
23
  from x_evolution import EvoStrategy
24
24
 
@@ -26,6 +26,9 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout
26
26
 
27
27
  from assoc_scan import AssocScan
28
28
 
29
+ from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left
30
+ from torch_einops_utils.save_load import save_load
31
+
29
32
  # constants
30
33
 
31
34
  LinearNoBias = partial(Linear, bias = False)
@@ -55,42 +58,100 @@ def straight_through(src, tgt):
55
58
 
56
59
  MetaControllerOutput = namedtuple('MetaControllerOutput', (
57
60
  'prev_hiddens',
61
+ 'input_residual_stream',
58
62
  'action_dist',
59
63
  'actions',
60
- 'kl_loss'
64
+ 'switch_beta',
65
+ 'kl_loss',
66
+ 'switch_loss'
61
67
  ))
62
68
 
69
+ def z_score(t, eps = 1e-8):
70
+ return (t - t.mean()) / (t.std() + eps)
71
+
72
+ def policy_loss(
73
+ meta_controller,
74
+ state,
75
+ old_log_probs,
76
+ actions,
77
+ advantages,
78
+ mask,
79
+ episode_lens = None,
80
+ eps_clip = 0.2
81
+ ):
82
+ # get new log probs
83
+
84
+ action_dist = meta_controller.get_action_dist_for_internal_rl(state)
85
+ new_log_probs = meta_controller.log_prob(action_dist, actions)
86
+
87
+ # calculate ratio
88
+
89
+ ratio = (new_log_probs - old_log_probs).exp()
90
+
91
+ # align ratio and advantages
92
+
93
+ ratio, advantages = align_dims_left((ratio, advantages))
94
+
95
+ # ppo surrogate loss
96
+
97
+ surr1 = ratio * advantages
98
+ surr2 = ratio.clamp(1 - eps_clip, 1 + eps_clip) * advantages
99
+
100
+ losses = -torch.min(surr1, surr2)
101
+
102
+ # masking
103
+
104
+ if exists(episode_lens):
105
+ mask, episode_mask = align_dims_left((mask, lens_to_mask(episode_lens, losses.shape[1])))
106
+ mask = mask & episode_mask
107
+
108
+ return masked_mean(losses, mask)
109
+
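For orientation, here is a minimal standalone sketch of the PPO clipped surrogate that `policy_loss` above computes. The shapes are hypothetical, and the episode masking done via `lens_to_mask` / `masked_mean` in the actual function is reduced to a plain mean here.

```python
import torch

eps_clip = 0.2

# hypothetical shapes: 4 rollouts, 16 timesteps, 8 latent action dims
old_log_probs = torch.randn(4, 16, 8)
new_log_probs = torch.randn(4, 16, 8)
advantages = torch.randn(4, 16)

# probability ratio between the current and old policy
ratio = (new_log_probs - old_log_probs).exp()

# broadcast advantages over the latent dimension (what align_dims_left effectively does above)
advantages = advantages.unsqueeze(-1)

# clipped surrogate - take the more pessimistic of the two terms
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - eps_clip, 1 + eps_clip) * advantages

loss = -torch.min(surr1, surr2).mean()  # the function above uses masked_mean instead
```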
110
+ @save_load()
63
111
  class MetaController(Module):
64
112
  def __init__(
65
113
  self,
66
- dim_latent,
114
+ dim_model,
67
115
  *,
116
+ dim_meta_controller = 256,
117
+ dim_latent = 128,
68
118
  switch_per_latent_dim = True,
69
119
  decoder_expansion_factor = 2.,
70
120
  decoder_depth = 1,
71
121
  hypernetwork_low_rank = 16,
72
- assoc_scan_kwargs: dict = dict()
122
+ assoc_scan_kwargs: dict = dict(),
123
+ bidirectional_temporal_encoder_kwargs: dict = dict(
124
+ attn_dim_head = 32,
125
+ heads = 8
126
+ )
73
127
  ):
74
128
  super().__init__()
129
+ self.dim_model = dim_model
130
+ dim_meta = default(dim_meta_controller, dim_model)
131
+
132
+ # the linear that brings from model dimension
133
+
134
+ self.model_to_meta = Linear(dim_model, dim_meta)
75
135
 
76
- # there are two phases, the first (discovery ssl phase) uses acausal with some ssm i don't really believe in - let's just use a bidirectional GRU as placeholders
136
+ # there are two phases, the first (discovery ssl phase) uses acausal with some ssm i don't really believe in - let's just use bidirectional attention as placeholder
77
137
 
78
- self.bidirectional_temporal_compressor = GRU(dim_latent, dim_latent, bidirectional = True) # revisit naming
138
+ self.bidirectional_temporal_encoder = Encoder(dim = dim_meta, depth = 1, **bidirectional_temporal_encoder_kwargs)
79
139
 
80
- self.emitter = GRU(dim_latent * 2, dim_latent * 2)
81
- self.emitter_to_action_mean_log_var = Readout(dim_latent * 2, num_continuous = dim_latent)
140
+ self.emitter = GRU(dim_meta * 2, dim_meta * 2)
141
+ self.emitter_to_action_mean_log_var = Readout(dim_meta * 2, num_continuous = dim_latent)
82
142
 
83
143
  # internal rl phase substitutes the acausal + emitter with a causal ssm
84
144
 
85
- self.action_proposer = GRU(dim_latent, dim_latent)
86
- self.action_proposer_mean_log_var = Readout(dim_latent, num_continuous = dim_latent)
145
+ self.action_proposer = GRU(dim_meta, dim_meta)
146
+ self.action_proposer_mean_log_var = Readout(dim_meta, num_continuous = dim_latent)
87
147
 
88
148
  # switching unit
89
149
 
90
150
  self.switch_per_latent_dim = switch_per_latent_dim
91
151
 
92
- self.switching_unit = GRU(dim_latent, dim_latent)
93
- self.to_switching_unit_beta = nn.Linear(dim_latent, dim_latent if switch_per_latent_dim else 1, bias = False)
152
+ self.dim_latent = dim_latent
153
+ self.switching_unit = GRU(dim_meta + dim_latent, dim_meta)
154
+ self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)
94
155
 
95
156
  self.switch_gating = AssocScan(**assoc_scan_kwargs)
96
157
 
@@ -104,16 +165,26 @@ class MetaController(Module):
104
165
  dim_in = dim_latent,
105
166
  dim = dim_decoder_hidden,
106
167
  depth = decoder_depth,
107
- dim_out = 2 * hypernetwork_low_rank * dim_latent
168
+ dim_out = 2 * hypernetwork_low_rank * dim_model
108
169
  )
109
170
 
110
171
  self.to_hyper_network_weights = Rearrange('... (two d r) -> two ... d r', two = 2, r = hypernetwork_low_rank)
111
172
 
112
173
  self.register_buffer('zero', tensor(0.), persistent = False)
113
174
 
175
+ @property
176
+ def replay_buffer_field_dict(self):
177
+ return dict(
178
+ states = ('float', self.dim_model),
179
+ log_probs = ('float', self.dim_latent),
180
+ switch_betas = ('float', self.dim_latent if self.switch_per_latent_dim else 1),
181
+ latent_actions = ('float', self.dim_latent)
182
+ )
183
+
114
184
  def discovery_parameters(self):
115
185
  return [
116
- *self.bidirectional_temporal_compressor.parameters(),
186
+ *self.model_to_meta.parameters(),
187
+ *self.bidirectional_temporal_encoder.parameters(),
117
188
  *self.emitter.parameters(),
118
189
  *self.emitter_to_action_mean_log_var.parameters(),
119
190
  *self.decoder.parameters(),
@@ -126,54 +197,99 @@ class MetaController(Module):
126
197
  *self.action_proposer_mean_log_var.parameters()
127
198
  ]
128
199
 
200
+ def get_action_dist_for_internal_rl(
201
+ self,
202
+ residual_stream
203
+ ):
204
+ meta_embed = self.model_to_meta(residual_stream)
205
+
206
+ proposed_action_hidden, _ = self.action_proposer(meta_embed)
207
+
208
+ return self.action_proposer_mean_log_var(proposed_action_hidden)
209
+
210
+ def log_prob(
211
+ self,
212
+ action_dist,
213
+ sampled_latent_action
214
+ ):
215
+ return self.action_proposer_mean_log_var.log_prob(action_dist, sampled_latent_action)
216
+
129
217
  def forward(
130
218
  self,
131
219
  residual_stream,
132
220
  cache: MetaControllerOutput | None = None,
133
221
  discovery_phase = False,
134
- hard_switch = False,
135
- temperature = 1.
222
+ hard_switch = None,
223
+ temperature = 1.,
224
+ episode_lens: Tensor | None = None
136
225
  ):
226
+ device = residual_stream.device
137
227
 
138
228
  # destruct prev cache
139
229
 
140
- prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
230
+ prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens, prev_sampled_latent_action = cache.prev_hiddens if exists(cache) else ((None,) * 4)
141
231
 
142
232
  # getting proposed action for the two phases
143
233
 
144
234
  next_action_proposer_hidden = None
145
235
 
236
+ meta_embed = self.model_to_meta(residual_stream)
237
+
238
+ hard_switch = default(hard_switch, not discovery_phase) # think during internal RL phase, it needs to be a hard switch, then only the actions emitted during the switch is reinforced
239
+
146
240
  if discovery_phase:
147
241
  logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
148
242
 
149
- temporal_compressed, _ = self.bidirectional_temporal_compressor(residual_stream)
150
- temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
243
+ mask = maybe(lens_to_mask)(episode_lens, meta_embed.shape[1])
151
244
 
152
- proposed_action_hidden, _ = self.emitter(cat((temporal_compressed, residual_stream), dim = -1))
245
+ encoded_temporal = self.bidirectional_temporal_encoder(meta_embed, mask = mask)
246
+
247
+ proposed_action_hidden, _ = self.emitter(cat((encoded_temporal, meta_embed), dim = -1))
153
248
  readout = self.emitter_to_action_mean_log_var
154
249
 
155
250
  else: # else internal rl phase
156
251
 
157
- proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(residual_stream, prev_action_proposer_hidden)
252
+ proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(meta_embed, prev_action_proposer_hidden)
158
253
  readout = self.action_proposer_mean_log_var
159
254
 
160
255
  # sample from the gaussian as the action from the meta controller
161
256
 
162
257
  action_dist = readout(proposed_action_hidden)
163
258
 
164
- sampled_action = readout.sample(action_dist, temperature = temperature)
259
+ sampled_latent_action = readout.sample(action_dist, temperature = temperature)
165
260
 
166
261
  # switching unit timer
167
262
 
168
- batch, _, dim = sampled_action.shape
263
+ batch, seq_len, dim = sampled_latent_action.shape
264
+
265
+ # initialize prev sampled latent action to be zeros if not available (for first timestep and for discovery phase)
169
266
 
170
- switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(residual_stream, prev_switching_unit_gru_hidden)
267
+ if not exists(prev_sampled_latent_action):
268
+ prev_sampled_latent_action = torch.zeros(batch, 1, self.dim_latent, device = device)
269
+
270
+ if discovery_phase:
271
+ z_prev = cat((prev_sampled_latent_action, sampled_latent_action[:, :-1]), dim = 1)
272
+
273
+ else:
274
+ # else during inference, use the previous sampled latent action
275
+
276
+ assert seq_len == 1, f'inference RL phase must be done one token at a time'
277
+ z_prev = prev_sampled_latent_action
278
+
279
+ # switch input is previous latent action and the embedding
280
+
281
+ switch_input = torch.cat((meta_embed, z_prev), dim=-1)
282
+
283
+ switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(
284
+ switch_input,
285
+ prev_switching_unit_gru_hidden
286
+ )
171
287
 
172
288
  switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
173
289
 
174
290
  # need to encourage normal distribution
175
291
 
176
- kl_loss = self.zero
292
+ kl_loss = switch_loss = self.zero
177
293
 
178
294
  if discovery_phase:
179
295
  mean, log_var = action_dist.unbind(dim = -1)
@@ -188,6 +304,10 @@ class MetaController(Module):
188
304
  kl_loss = kl_loss * switch_beta
189
305
  kl_loss = kl_loss.sum(dim = -1).mean()
190
306
 
307
+ # encourage less switching
308
+
309
+ switch_loss = switch_beta.mean()
310
+
191
311
  # maybe hard switch, then use associative scan
192
312
 
193
313
  if hard_switch:
@@ -195,7 +315,7 @@ class MetaController(Module):
195
315
  switch_beta = straight_through(switch_beta, hard_switch_beta)
196
316
 
197
317
  forget = 1. - switch_beta
198
- gated_action = self.switch_gating(switch_beta, sampled_action * forget, prev = prev_switch_gated_hiddens)
318
+ gated_action = self.switch_gating(switch_beta, sampled_latent_action * forget, prev = prev_switch_gated_hiddens)
199
319
 
200
320
  next_switch_gated_action = gated_action[:, -1]
201
321
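As an aside, the hard switch above goes through the `straight_through` helper defined at the top of the file, so the forward value is the 0/1 threshold while gradients still flow through the soft sigmoid. A tiny standalone illustration with hypothetical shapes:

```python
import torch

def straight_through(src, tgt):
    # same helper as above: forward value equals tgt, gradient flows through src
    return tgt + src - src.detach()

logits = torch.randn(2, 5, requires_grad = True)
switch_beta = logits.sigmoid()

hard_switch_beta = (switch_beta > 0.5).float()   # non-differentiable threshold
switch_beta = straight_through(switch_beta, hard_switch_beta)

switch_beta.sum().backward()

print(switch_beta)   # exact 0 / 1 values in the forward pass
print(logits.grad)   # non-zero gradients, passed through the soft sigmoid
```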
 
@@ -208,27 +328,40 @@ class MetaController(Module):
208
328
 
209
329
  # generating the residual stream controlling signal
210
330
 
211
- control_signal = einsum(gated_action, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
212
-
213
- modified_residual_stream = residual_stream + control_signal
331
+ control_signal = einsum(residual_stream, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
214
332
 
215
333
  # returning
216
334
 
217
335
  next_hiddens = (
218
336
  next_action_proposer_hidden,
219
337
  next_switching_unit_gru_hidden,
220
- next_switch_gated_action
338
+ next_switch_gated_action,
339
+ sampled_latent_action[:, -1:]
221
340
  )
222
341
 
223
- return modified_residual_stream, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss)
342
+ # squeeze out the last dimension of switch_beta if single gate for all latent dimensions
343
+
344
+ if not self.switch_per_latent_dim:
345
+ switch_beta = rearrange(switch_beta, '... 1 -> ...')
346
+
347
+ return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, switch_beta, kl_loss, switch_loss)
348
+
349
+ MetaController.policy_loss = policy_loss
224
350
 
225
351
  # main transformer, which is subsumed into the environment after behavioral cloning
226
352
 
353
+ Hiddens = namedtuple('Hiddens', (
354
+ 'lower_body',
355
+ 'meta_controller',
356
+ 'upper_body'
357
+ ))
358
+
227
359
  TransformerOutput = namedtuple('TransformerOutput', (
228
360
  'residual_stream_latent',
229
361
  'prev_hiddens'
230
362
  ))
231
363
 
364
+ @save_load()
232
365
  class Transformer(Module):
233
366
  def __init__(
234
367
  self,
@@ -243,7 +376,7 @@ class Transformer(Module):
243
376
  super().__init__()
244
377
 
245
378
  if isinstance(lower_body, dict):
246
- lower_body = Decoder(dim = dim, **lower_body)
379
+ lower_body = Decoder(dim = dim, pre_norm_has_final_norm = False, **lower_body)
247
380
 
248
381
  if isinstance(upper_body, dict):
249
382
  upper_body = Decoder(dim = dim, **upper_body)
@@ -281,26 +414,38 @@ class Transformer(Module):
281
414
  def forward(
282
415
  self,
283
416
  state,
284
- action_ids,
417
+ actions: Tensor | None = None,
285
418
  meta_controller: Module | None = None,
286
419
  cache: TransformerOutput | None = None,
287
420
  discovery_phase = False,
421
+ force_behavior_cloning = False,
288
422
  meta_controller_temperature = 1.,
289
423
  return_raw_action_dist = False,
290
424
  return_latents = False,
291
425
  return_cache = False,
426
+ episode_lens: Tensor | None = None
292
427
  ):
428
+ device = state.device
429
+
430
+ # meta controller is either given or already given at init
431
+
293
432
  meta_controller = default(meta_controller, self.meta_controller)
294
433
 
295
- meta_controlling = exists(meta_controller)
434
+ if force_behavior_cloning:
435
+ assert not discovery_phase, 'discovery phase cannot be set to True if force behavioral cloning is set to True'
436
+ meta_controller = None
296
437
 
297
- behavioral_cloning = not meta_controlling and not return_raw_action_dist
438
+ has_meta_controller = exists(meta_controller)
439
+
440
+ assert not (discovery_phase and not has_meta_controller), 'meta controller must be made available during discovery phase'
441
+
442
+ behavioral_cloning = force_behavior_cloning or (not has_meta_controller and not return_raw_action_dist)
298
443
 
299
444
  # by default, if meta controller is passed in, transformer is no grad
300
445
 
301
- lower_transformer_context = nullcontext if not meta_controlling else torch.no_grad
302
- meta_controller_context = nullcontext if meta_controlling else torch.no_grad
303
- upper_transformer_context = nullcontext if (not meta_controlling or discovery_phase) else torch.no_grad
446
+ lower_transformer_context = nullcontext if not has_meta_controller else torch.no_grad
447
+ meta_controller_context = nullcontext if has_meta_controller else torch.no_grad
448
+ upper_transformer_context = nullcontext if (not has_meta_controller or discovery_phase) else torch.no_grad
304
449
 
305
450
  # handle cache
306
451
 
@@ -308,16 +453,31 @@ class Transformer(Module):
308
453
 
309
454
  # handle maybe behavioral cloning
310
455
 
311
- if behavioral_cloning:
456
+ if behavioral_cloning or discovery_phase: # during behavior cloning and discovery phase, the network is predicting / reconstructing the next token
457
+
458
+ assert exists(actions), f'`actions` cannot be empty when doing discovery or behavioral cloning'
459
+
312
460
  state, target_state = state[:, :-1], state[:, 1:]
313
- action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]
461
+ actions, target_actions = actions[:, :-1], actions[:, 1:]
462
+
463
+ if exists(episode_lens):
464
+ episode_lens = (episode_lens - 1).clamp(min = 0)
314
465
 
315
466
  # transformer lower body
316
467
 
317
468
  with lower_transformer_context():
318
469
 
319
470
  state_embed = self.state_embed(state)
320
- action_embed = self.action_embed(action_ids)
471
+
472
+ # handle no past action for first timestep
473
+
474
+ if exists(actions):
475
+ action_embed = self.action_embed(actions)
476
+ else:
477
+ action_embed = state_embed[:, 0:0] # empty action embed
478
+
479
+ if action_embed.shape[-2] == (state_embed.shape[-2] - 1):
480
+ action_embed = pad_at_dim(action_embed, (1, 0), dim = 1)
321
481
 
322
482
  embed = state_embed + action_embed
323
483
 
@@ -327,10 +487,12 @@ class Transformer(Module):
327
487
 
328
488
  with meta_controller_context():
329
489
 
330
- if exists(meta_controller):
331
- modified_residual_stream, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature)
490
+ if exists(meta_controller) and not behavioral_cloning:
491
+ control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature, episode_lens = episode_lens)
332
492
  else:
333
- modified_residual_stream, next_meta_hiddens = residual_stream, None
493
+ control_signal, next_meta_hiddens = self.zero, None
494
+
495
+ modified_residual_stream = residual_stream + control_signal
334
496
 
335
497
  # modified residual stream sent back to transformer upper body
336
498
 
@@ -345,13 +507,22 @@ class Transformer(Module):
345
507
  # maybe return behavior cloning loss
346
508
 
347
509
  if behavioral_cloning:
510
+
511
+ loss_mask = maybe(lens_to_mask)(episode_lens, state.shape[1])
512
+
348
513
  state_dist_params = self.state_readout(attended)
349
- state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)
514
+ state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state, mask = loss_mask)
350
515
 
351
- action_clone_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
516
+ action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions, mask = loss_mask)
352
517
 
353
518
  return state_clone_loss, action_clone_loss
354
519
 
520
+ elif discovery_phase:
521
+
522
+ action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)
523
+
524
+ return action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss
525
+
355
526
  # returning
356
527
 
357
528
  return_one = not (return_latents or return_cache)
@@ -359,4 +530,4 @@ class Transformer(Module):
359
530
  if return_one:
360
531
  return dist_params
361
532
 
362
- return dist_params, TransformerOutput(residual_stream, (next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
533
+ return dist_params, TransformerOutput(residual_stream, Hiddens(next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
metacontroller/metacontroller_with_binary_mapper.py (new file)
@@ -0,0 +1,315 @@
1
+ from __future__ import annotations
2
+ from contextlib import nullcontext
3
+
4
+ from functools import partial
5
+ from collections import namedtuple
6
+ from loguru import logger
7
+
8
+ import torch
9
+ from torch import nn, cat, stack, tensor, Tensor
10
+ from torch.nn import Module, GRU, Linear, Identity
11
+ import torch.nn.functional as F
12
+
13
+ # einops
14
+
15
+ import einx
16
+ from einops import einsum, rearrange, repeat, reduce
17
+ from einops.layers.torch import Rearrange
18
+
19
+ # external modules
20
+
21
+ from x_transformers import Encoder, Decoder
22
+ from x_mlps_pytorch import Feedforwards
23
+
24
+ from assoc_scan import AssocScan
25
+
26
+ from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, align_dims_left
27
+ from torch_einops_utils.save_load import save_load
28
+
29
+ from vector_quantize_pytorch import BinaryMapper
30
+
31
+ from metacontroller.metacontroller import MetaControllerOutput, policy_loss
32
+
33
+ # constants
34
+
35
+ LinearNoBias = partial(Linear, bias = False)
36
+
37
+ GRU = partial(GRU, batch_first = True)
38
+
39
+ # helper functions
40
+
41
+ def exists(v):
42
+ return v is not None
43
+
44
+ def default(*args):
45
+ for arg in args:
46
+ if exists(arg):
47
+ return arg
48
+ return None
49
+
50
+ def straight_through(src, tgt):
51
+ return tgt + src - src.detach()
52
+
53
+ def log(t, eps = 1e-20):
54
+ return t.clamp_min(eps).log()
55
+
56
+ # meta controller
57
+
58
+ @save_load()
59
+ class MetaControllerWithBinaryMapper(Module):
60
+ def __init__(
61
+ self,
62
+ dim_model,
63
+ *,
64
+ dim_meta_controller = 256,
65
+ dim_code_bits = 4,
66
+ switch_per_code = False,
67
+ decoder_expansion_factor = 2.,
68
+ decoder_depth = 1,
69
+ hypernetwork_low_rank = 16,
70
+ assoc_scan_kwargs: dict = dict(),
71
+ bidirectional_temporal_encoder_kwargs: dict = dict(
72
+ attn_dim_head = 32, heads = 8
73
+ ),
74
+ kl_loss_threshold = 0.
75
+ ):
76
+ super().__init__()
77
+ self.dim_model = dim_model
78
+ assert not switch_per_code, 'switch_per_code is not supported for binary mapper'
79
+
80
+ dim_meta = default(dim_meta_controller, dim_model)
81
+
82
+ self.model_to_meta = Linear(dim_model, dim_meta)
83
+
84
+ self.bidirectional_temporal_encoder = Encoder(dim = dim_meta, depth = 1, **bidirectional_temporal_encoder_kwargs)
85
+
86
+ self.emitter = GRU(dim_meta * 2, dim_meta * 2)
87
+ self.emitter_to_binary_logits = Linear(dim_meta * 2, dim_code_bits)
88
+
89
+ self.action_proposer = GRU(dim_meta, dim_meta)
90
+ self.proposer_to_binary_logits = Linear(dim_meta, dim_code_bits)
91
+
92
+ # binary mapper
93
+ # proposed in https://arxiv.org/abs/2510.17558 as a more stable alternative to VAE by François Fleuret
94
+
95
+ self.binary_mapper = BinaryMapper(
96
+ bits = dim_code_bits,
97
+ kl_loss_threshold = kl_loss_threshold
98
+ )
99
+
100
+ self.dim_code_bits = dim_code_bits
101
+ self.num_codes = self.binary_mapper.num_codes
102
+
103
+ # switching unit
104
+
105
+ self.switch_per_code = switch_per_code
106
+
107
+ self.switching_unit = GRU(dim_meta + self.num_codes, dim_meta)
108
+ self.to_switching_unit_beta = nn.Linear(dim_meta, self.num_codes if switch_per_code else 1, bias = False)
109
+
110
+ self.switch_gating = AssocScan(**assoc_scan_kwargs)
111
+
112
+ # decoder
113
+
114
+ assert hypernetwork_low_rank < self.num_codes
115
+
116
+ dim_decoder_hidden = int(self.num_codes * decoder_expansion_factor)
117
+
118
+ self.decoder = Feedforwards(
119
+ dim_in = self.num_codes,
120
+ dim = dim_decoder_hidden,
121
+ depth = decoder_depth,
122
+ dim_out = 2 * hypernetwork_low_rank * dim_model
123
+ )
124
+
125
+ self.to_hyper_network_weights = Rearrange('... (two d r) -> two ... d r', two = 2, r = hypernetwork_low_rank)
126
+
127
+ self.register_buffer('zero', tensor(0.), persistent = False)
128
+
129
+ @property
130
+ def replay_buffer_field_dict(self):
131
+ return dict(
132
+ states = ('float', self.dim_model),
133
+ log_probs = ('float', self.dim_code_bits),
134
+ switch_betas = ('float', self.num_codes if self.switch_per_code else 1),
135
+ latent_actions = ('float', self.num_codes)
136
+ )
137
+
138
+ def discovery_parameters(self):
139
+ return [
140
+ *self.model_to_meta.parameters(),
141
+ *self.bidirectional_temporal_encoder.parameters(),
142
+ *self.emitter.parameters(),
143
+ *self.emitter_to_binary_logits.parameters(),
144
+ *self.binary_mapper.parameters(),
145
+ *self.decoder.parameters(),
146
+ *self.switch_gating.parameters()
147
+ ]
148
+
149
+ def internal_rl_parameters(self):
150
+ return [
151
+ *self.action_proposer.parameters(),
152
+ *self.proposer_to_binary_logits.parameters()
153
+ ]
154
+
155
+ def get_action_dist_for_internal_rl(
156
+ self,
157
+ residual_stream
158
+ ):
159
+ meta_embed = self.model_to_meta(residual_stream)
160
+
161
+ proposed_action_hidden, _ = self.action_proposer(meta_embed)
162
+
163
+ return self.proposer_to_binary_logits(proposed_action_hidden)
164
+
165
+ def log_prob(
166
+ self,
167
+ action_dist,
168
+ sampled_latent_action
169
+ ):
170
+ log_probs = stack((
171
+ F.logsigmoid(action_dist),
172
+ F.logsigmoid(-action_dist)
173
+ ), dim = -1)
174
+
175
+ indices = sampled_latent_action.argmax(dim = -1)
176
+ codes = self.binary_mapper.codes[indices].long()
177
+
178
+ codes = rearrange(codes, '... -> ... 1')
179
+ action_log_probs = log_probs.gather(-1, codes)
180
+ action_log_probs = rearrange(action_log_probs, '... 1 -> ...')
181
+
182
+ return action_log_probs
183
+
184
+ def forward(
185
+ self,
186
+ residual_stream,
187
+ cache: MetaControllerOutput | None = None,
188
+ discovery_phase = False,
189
+ hard_switch = None,
190
+ temperature = 1.,
191
+ episode_lens: Tensor | None = None
192
+ ):
193
+ device = residual_stream.device
194
+
195
+ # destruct prev cache
196
+
197
+ prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens, prev_sampled_code = cache.prev_hiddens if exists(cache) else ((None,) * 4)
198
+
199
+ # getting proposed action for the two phases
200
+
201
+ next_action_proposer_hidden = None
202
+
203
+ meta_embed = self.model_to_meta(residual_stream)
204
+
205
+ hard_switch = default(hard_switch, not discovery_phase) # think during internal RL phase, it needs to be a hard switch, then only the actions emitted during the switch is reinforced
206
+
207
+ if discovery_phase:
208
+ mask = maybe(lens_to_mask)(episode_lens, meta_embed.shape[1])
209
+
210
+ encoded_temporal = self.bidirectional_temporal_encoder(meta_embed, mask = mask)
211
+
212
+ proposed_action_hidden, _ = self.emitter(cat((encoded_temporal, meta_embed), dim = -1))
213
+ to_logits = self.emitter_to_binary_logits
214
+
215
+ else: # else internal rl phase
216
+
217
+ proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(meta_embed, prev_action_proposer_hidden)
218
+ to_logits = self.proposer_to_binary_logits
219
+
220
+ # sample from the binary mapper
221
+
222
+ binary_logits = to_logits(proposed_action_hidden)
223
+
224
+ one_hot, kl_loss = self.binary_mapper(
225
+ binary_logits,
226
+ temperature = temperature,
227
+ reduce_aux_kl_loss = False
228
+ )
229
+
230
+ # bottled action is now the one-hot sparse codes (with straight-through)
231
+
232
+ sampled_codes = one_hot
233
+
234
+ # switching unit timer
235
+
236
+ batch, seq_len, dim = sampled_codes.shape
237
+
238
+ if not exists(prev_sampled_code):
239
+ prev_sampled_code = torch.zeros(batch, 1, self.num_codes, device = device)
240
+
241
+ if discovery_phase:
242
+ z_prev = cat((prev_sampled_code, sampled_codes[:, :-1]), dim = 1)
243
+ else:
244
+ assert seq_len == 1, f'inference RL phase must be done one token at a time'
245
+ z_prev = prev_sampled_code
246
+
247
+ switch_input = torch.cat((meta_embed, z_prev), dim=-1)
248
+
249
+ switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(
250
+ switch_input,
251
+ prev_switching_unit_gru_hidden
252
+ )
253
+
254
+ switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
255
+
256
+ # losses
257
+
258
+ switch_loss = self.zero
259
+
260
+ if discovery_phase:
261
+ # weight unreduced kl loss by switch gates
262
+
263
+ kl_loss, switch_beta = align_dims_left((kl_loss, switch_beta))
264
+
265
+ weighted_kl_loss = kl_loss * switch_beta
266
+ kl_loss = weighted_kl_loss.sum(dim = -1).mean()
267
+
268
+ # encourage less switching
269
+
270
+ switch_loss = switch_beta.mean()
271
+ else:
272
+ kl_loss = self.zero
273
+
274
+ # maybe hard switch, then use associative scan
275
+
276
+ if hard_switch:
277
+ hard_switch_beta = (switch_beta > 0.5).float()
278
+ switch_beta = straight_through(switch_beta, hard_switch_beta)
279
+
280
+ forget = 1. - switch_beta
281
+
282
+ # gated codes (or soft distribution)
283
+
284
+ gated_codes = self.switch_gating(switch_beta, sampled_codes * forget, prev = prev_switch_gated_hiddens)
285
+
286
+ next_switch_gated_codes = gated_codes[:, -1]
287
+
288
+ # decoder
289
+
290
+ decoder_out = self.decoder(gated_codes)
291
+
292
+ w1, w2 = self.to_hyper_network_weights(decoder_out)
293
+ hypernetwork_weight = einsum(w1, w2, '... i r, ... j r -> ... i j')
294
+
295
+ # generating the residual stream controlling signal
296
+
297
+ control_signal = einsum(residual_stream, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
298
+
299
+ # returning
300
+
301
+ next_hiddens = (
302
+ next_action_proposer_hidden,
303
+ next_switching_unit_gru_hidden,
304
+ next_switch_gated_codes,
305
+ sampled_codes[:, -1:]
306
+ )
307
+
308
+ # squeeze out the last dimension of switch_beta if single gate for all codes
309
+
310
+ if not self.switch_per_code:
311
+ switch_beta = rearrange(switch_beta, '... 1 -> ...')
312
+
313
+ return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, switch_beta, kl_loss, switch_loss)
314
+
315
+ MetaControllerWithBinaryMapper.policy_loss = policy_loss
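For reference, `log_prob` above gathers per-bit log-sigmoids according to the code table exposed by `BinaryMapper`. A standalone sketch of the underlying per-bit Bernoulli log-likelihood follows; the exact bit-to-code ordering is defined by `BinaryMapper`, so the mapping here is illustrative only.

```python
import torch
import torch.nn.functional as F

bits = 4
logits = torch.randn(2, 1, bits)   # (batch, seq, bits) binary logits

# a sampled bit pattern per position, values in {0, 1}
sampled_bits = torch.randint(0, 2, (2, 1, bits)).bool()

# independent Bernoulli log-likelihood per bit:
# log sigmoid(logit) where the bit is set, log sigmoid(-logit) where it is not
per_bit_log_probs = torch.where(
    sampled_bits,
    F.logsigmoid(logits),
    F.logsigmoid(-logits)
)

# joint log probability of the code, summed over its 4 bits
code_log_prob = per_bit_log_probs.sum(dim = -1)
```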
metacontroller/transformer_with_resnet.py (new file)
@@ -0,0 +1,194 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import nn, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ from einops import rearrange
7
+ from einops.layers.torch import Rearrange
8
+
9
+ from metacontroller.metacontroller import Transformer
10
+
11
+ from torch_einops_utils import pack_with_inverse
12
+
13
+ # resnet components
14
+
15
+ def exists(v):
16
+ return v is not None
17
+
18
+ class BasicBlock(Module):
19
+ expansion = 1
20
+
21
+ def __init__(
22
+ self,
23
+ dim,
24
+ dim_out,
25
+ stride = 1,
26
+ downsample: Module | None = None
27
+ ):
28
+ super().__init__()
29
+ self.conv1 = nn.Conv2d(dim, dim_out, 3, stride = stride, padding = 1, bias = False)
30
+ self.bn1 = nn.BatchNorm2d(dim_out)
31
+ self.relu = nn.ReLU(inplace = True)
32
+ self.conv2 = nn.Conv2d(dim_out, dim_out, 3, padding = 1, bias = False)
33
+ self.bn2 = nn.BatchNorm2d(dim_out)
34
+ self.downsample = downsample
35
+
36
+ def forward(self, x: Tensor) -> Tensor:
37
+ identity = x
38
+
39
+ out = self.conv1(x)
40
+ out = self.bn1(out)
41
+ out = self.relu(out)
42
+
43
+ out = self.conv2(out)
44
+ out = self.bn2(out)
45
+
46
+ if exists(self.downsample):
47
+ identity = self.downsample(x)
48
+
49
+ out += identity
50
+ return self.relu(out)
51
+
52
+ class Bottleneck(Module):
53
+ expansion = 4
54
+
55
+ def __init__(
56
+ self,
57
+ dim,
58
+ dim_out,
59
+ stride = 1,
60
+ downsample: Module | None = None
61
+ ):
62
+ super().__init__()
63
+ width = dim_out # simple resnet shortcut
64
+ self.conv1 = nn.Conv2d(dim, width, 1, bias = False)
65
+ self.bn1 = nn.BatchNorm2d(width)
66
+ self.conv2 = nn.Conv2d(width, width, 3, stride = stride, padding = 1, bias = False)
67
+ self.bn2 = nn.BatchNorm2d(width)
68
+ self.conv3 = nn.Conv2d(width, dim_out * self.expansion, 1, bias = False)
69
+ self.bn3 = nn.BatchNorm2d(dim_out * self.expansion)
70
+ self.relu = nn.ReLU(inplace = True)
71
+ self.downsample = downsample
72
+
73
+ def forward(self, x: Tensor) -> Tensor:
74
+ identity = x
75
+
76
+ out = self.conv1(x)
77
+ out = self.bn1(out)
78
+ out = self.relu(out)
79
+
80
+ out = self.conv2(out)
81
+ out = self.bn2(out)
82
+ out = self.relu(out)
83
+
84
+ out = self.conv3(out)
85
+ out = self.bn3(out)
86
+
87
+ if exists(self.downsample):
88
+ identity = self.downsample(x)
89
+
90
+ out += identity
91
+ return self.relu(out)
92
+
93
+ class ResNet(Module):
94
+ def __init__(
95
+ self,
96
+ block: type[BasicBlock | Bottleneck],
97
+ layers: list[int],
98
+ num_classes = 1000,
99
+ channels = 3
100
+ ):
101
+ super().__init__()
102
+ self.inplanes = 64
103
+
104
+ self.conv1 = nn.Conv2d(channels, 64, kernel_size = 7, stride = 2, padding = 3, bias = False)
105
+ self.bn1 = nn.BatchNorm2d(64)
106
+ self.relu = nn.ReLU(inplace = True)
107
+ self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
108
+
109
+ self.layer1 = self._make_layer(block, 64, layers[0])
110
+ self.layer2 = self._make_layer(block, 128, layers[1], stride = 2)
111
+ self.layer3 = self._make_layer(block, 256, layers[2], stride = 2)
112
+ self.layer4 = self._make_layer(block, 512, layers[3], stride = 2)
113
+
114
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
115
+ self.flatten = Rearrange('b c 1 1 -> b c')
116
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
117
+
118
+ def _make_layer(
119
+ self,
120
+ block: type[BasicBlock | Bottleneck],
121
+ planes: int,
122
+ blocks: int,
123
+ stride: int = 1
124
+ ) -> nn.Sequential:
125
+ downsample = None
126
+ if stride != 1 or self.inplanes != planes * block.expansion:
127
+ downsample = nn.Sequential(
128
+ nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride = stride, bias = False),
129
+ nn.BatchNorm2d(planes * block.expansion),
130
+ )
131
+
132
+ layers = []
133
+ layers.append(block(self.inplanes, planes, stride, downsample))
134
+ self.inplanes = planes * block.expansion
135
+ for _ in range(1, blocks):
136
+ layers.append(block(self.inplanes, planes))
137
+
138
+ return nn.Sequential(*layers)
139
+
140
+ def forward(self, x: Tensor) -> Tensor:
141
+ x = self.conv1(x)
142
+ x = self.bn1(x)
143
+ x = self.relu(x)
144
+ x = self.maxpool(x)
145
+
146
+ x = self.layer1(x)
147
+ x = self.layer2(x)
148
+ x = self.layer3(x)
149
+ x = self.layer4(x)
150
+
151
+ x = self.avgpool(x)
152
+ x = self.flatten(x)
153
+ x = self.fc(x)
154
+ return x
155
+
156
+ # resnet factory
157
+
158
+ def resnet18(num_classes: any = 1000):
159
+ return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
160
+
161
+ def resnet34(num_classes: any = 1000):
162
+ return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)
163
+
164
+ def resnet50(num_classes: any = 1000):
165
+ return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)
166
+
167
+ # transformer with resnet
168
+
169
+ class TransformerWithResnet(Transformer):
170
+ def __init__(
171
+ self,
172
+ *,
173
+ resnet_type = 'resnet18',
174
+ **kwargs
175
+ ):
176
+ super().__init__(**kwargs)
177
+ resnet_klass = resnet18
178
+ if resnet_type == 'resnet34':
179
+ resnet_klass = resnet34
180
+ elif resnet_type == 'resnet50':
181
+ resnet_klass = resnet50
182
+
183
+ self.resnet_dim = kwargs['state_embed_readout']['num_continuous']
184
+ self.visual_encoder = resnet_klass(num_classes = self.resnet_dim)
185
+
186
+ def visual_encode(self, x: Tensor) -> Tensor:
187
+ if x.shape[-1] == 3:
188
+ x = rearrange(x, '... h w c -> ... c h w')
189
+
190
+ x, inverse = pack_with_inverse(x, '* c h w')
191
+
192
+ h = self.visual_encoder(x)
193
+
194
+ return inverse(h, '* d')
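A brief usage sketch for the new `TransformerWithResnet`. The constructor kwargs mirror the `Transformer` example from the README below and the module path follows the wheel RECORD; the frame count and resolution are arbitrary.

```python
import torch
from metacontroller.transformer_with_resnet import TransformerWithResnet

model = TransformerWithResnet(
    resnet_type = 'resnet18',
    dim = 512,
    action_embed_readout = dict(num_discrete = 4),
    state_embed_readout = dict(num_continuous = 384),
    lower_body = dict(depth = 2),
    upper_body = dict(depth = 2)
)

# channels-last frames are rearranged internally and encoded to `num_continuous` dims
frames = torch.randn(2, 16, 64, 64, 3)   # (batch, time, height, width, channels)
state = model.visual_encode(frames)       # (2, 16, 384)

actions = torch.randint(0, 4, (2, 16))
state_loss, action_loss = model(state, actions)   # behavioral cloning, as in the README
```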
metacontroller_pytorch-0.0.41.dist-info/METADATA
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.15
3
+ Version: 0.0.41
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -39,7 +39,10 @@ Requires-Dist: discrete-continuous-embed-readout>=0.1.12
39
39
  Requires-Dist: einops>=0.8.1
40
40
  Requires-Dist: einx>=0.3.0
41
41
  Requires-Dist: loguru
42
+ Requires-Dist: memmap-replay-buffer>=0.0.25
43
+ Requires-Dist: torch-einops-utils>=0.0.19
42
44
  Requires-Dist: torch>=2.5
45
+ Requires-Dist: vector-quantize-pytorch>=1.27.20
43
46
  Requires-Dist: x-evolution>=0.1.23
44
47
  Requires-Dist: x-mlps-pytorch
45
48
  Requires-Dist: x-transformers
@@ -54,13 +57,110 @@ Description-Content-Type: text/markdown
54
57
 
55
58
  Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)
56
59
 
60
+ ## Install
61
+
62
+ ```shell
63
+ $ pip install metacontroller-pytorch
64
+ ```
65
+
66
+ ## Appreciation
67
+
68
+ - [Pranoy](https://github.com/pranoyr) for submitting a pull request for fixing the previous latent action not being included in the inputs to the switching unit
69
+
70
+ - [Diego Calanzone](https://github.com/ddidacus) for proposing testing on BabyAI gridworld task, and submitting the [pull request](https://github.com/lucidrains/metacontroller/pull/3) for behavior cloning and discovery phase training for it!
71
+
72
+ ## Usage
73
+
74
+ ```python
75
+ import torch
76
+ from metacontroller import Transformer, MetaController
77
+
78
+ # 1. initialize model
79
+
80
+ model = Transformer(
81
+ dim = 512,
82
+ action_embed_readout = dict(num_discrete = 4),
83
+ state_embed_readout = dict(num_continuous = 384),
84
+ lower_body = dict(depth = 2),
85
+ upper_body = dict(depth = 2)
86
+ )
87
+
88
+ state = torch.randn(2, 128, 384)
89
+ actions = torch.randint(0, 4, (2, 128))
90
+
91
+ # 2. behavioral cloning (BC)
92
+
93
+ state_loss, action_loss = model(state, actions)
94
+ (state_loss + action_loss).backward()
95
+
96
+ # 3. discovery phase
97
+
98
+ meta_controller = MetaController(
99
+ dim_model = 512,
100
+ dim_meta_controller = 256,
101
+ dim_latent = 128
102
+ )
103
+
104
+ action_recon_loss, kl_loss, switch_loss = model(
105
+ state,
106
+ actions,
107
+ meta_controller = meta_controller,
108
+ discovery_phase = True
109
+ )
110
+
111
+ (action_recon_loss + kl_loss + switch_loss).backward()
112
+
113
+ # 4. internal rl phase (GRPO)
114
+
115
+ # ... collect trajectories ...
116
+
117
+ logits, cache = model(
118
+ one_state,
119
+ past_action_id,
120
+ meta_controller = meta_controller,
121
+ return_cache = True
122
+ )
123
+
124
+ meta_output = cache.prev_hiddens.meta_controller
125
+ old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
126
+
127
+ # ... calculate advantages ...
128
+
129
+ loss = meta_controller.policy_loss(
130
+ group_states,
131
+ group_old_log_probs,
132
+ group_latent_actions,
133
+ group_advantages,
134
+ group_switch_betas
135
+ )
136
+
137
+ loss.backward()
138
+ ```
139
+
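One possible way to fill in the advantage step above (illustrative only, not prescribed by the package) is to z-score the rewards within each group of rollouts, GRPO-style, using the `z_score` helper defined in `metacontroller/metacontroller.py`:

```python
import torch
from metacontroller.metacontroller import z_score

# hypothetical scalar rewards for a group of rollouts sharing the same start state
group_rewards = torch.tensor([1.0, 0.0, 3.0, 0.5])

# group-relative advantages: rewards z-scored within the group
group_advantages = z_score(group_rewards)
```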
140
+ Or using [evolutionary strategies](https://arxiv.org/abs/2511.16652) for the last portion
141
+
142
+ ```python
143
+ # 5. evolve (ES over GRPO)
144
+
145
+ model.meta_controller = meta_controller
146
+
147
+ def environment_callable(model):
148
+ # return a fitness score
149
+ return 1.0
150
+
151
+ model.evolve(
152
+ num_generations = 10,
153
+ environment = environment_callable
154
+ )
155
+ ```
156
+
57
157
  ## Citations
58
158
 
59
159
  ```bibtex
60
160
  @misc{kobayashi2025emergenttemporalabstractionsautoregressive,
61
161
  title = {Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning},
62
162
  author = {Seijin Kobayashi and Yanick Schimpf and Maximilian Schlegel and Angelika Steger and Maciej Wolczyk and Johannes von Oswald and Nino Scherrer and Kaitlin Maile and Guillaume Lajoie and Blake A. Richards and Rif A. Saurous and James Manyika and Blaise Agüera y Arcas and Alexander Meulemans and João Sacramento},
63
- year={2025},
163
+ year = {2025},
64
164
  eprint = {2512.20605},
65
165
  archivePrefix = {arXiv},
66
166
  primaryClass = {cs.LG},
@@ -78,3 +178,29 @@ Implementation of the MetaController proposed in [Emergent temporal abstractions
78
178
  url = {https://api.semanticscholar.org/CorpusID:279464702}
79
179
  }
80
180
  ```
181
+
182
+ ```bibtex
183
+ @misc{hwang2025dynamicchunkingendtoendhierarchical,
184
+ title = {Dynamic Chunking for End-to-End Hierarchical Sequence Modeling},
185
+ author = {Sukjun Hwang and Brandon Wang and Albert Gu},
186
+ year = {2025},
187
+ eprint = {2507.07955},
188
+ archivePrefix = {arXiv},
189
+ primaryClass = {cs.LG},
190
+ url = {https://arxiv.org/abs/2507.07955},
191
+ }
192
+ ```
193
+
194
+ ```bibtex
195
+ @misc{fleuret2025freetransformer,
196
+ title = {The Free Transformer},
197
+ author = {François Fleuret},
198
+ year = {2025},
199
+ eprint = {2510.17558},
200
+ archivePrefix = {arXiv},
201
+ primaryClass = {cs.LG},
202
+ url = {https://arxiv.org/abs/2510.17558},
203
+ }
204
+ ```
205
+
206
+ *Life can only be understood backwards; but it must be lived forwards* - Søren Kierkegaard
metacontroller_pytorch-0.0.41.dist-info/RECORD (new file)
@@ -0,0 +1,8 @@
1
+ metacontroller/__init__.py,sha256=iSKbCDp3UrWhZg7SIJFYNjdVQU56u-vqZarE6qCSX74,70
2
+ metacontroller/metacontroller.py,sha256=bhgCqqM-dfysGrMtZYe2w87lRVkf8fETjxUCdjrnI8Q,17386
3
+ metacontroller/metacontroller_with_binary_mapper.py,sha256=Ce5-O95_pLuWNA3aZTlKrTGbc5cemb61tBtJBdSiLx4,9843
4
+ metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
5
+ metacontroller_pytorch-0.0.41.dist-info/METADATA,sha256=IvP-wC73xCnT8X1aul1IfcaC4fUwRq9Y2UB1h0JG5TI,6822
6
+ metacontroller_pytorch-0.0.41.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
+ metacontroller_pytorch-0.0.41.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ metacontroller_pytorch-0.0.41.dist-info/RECORD,,
metacontroller_pytorch-0.0.15.dist-info/RECORD (removed)
@@ -1,6 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=ug3xeMTZKApTF8oOPx9hWypeDjRflf1IJp8RiysXgTo,11618
3
- metacontroller_pytorch-0.0.15.dist-info/METADATA,sha256=9d39BpcuVeOVVSD66lCVHCK1GjrkeKzRtxKOPOc-7xQ,3736
4
- metacontroller_pytorch-0.0.15.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- metacontroller_pytorch-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- metacontroller_pytorch-0.0.15.dist-info/RECORD,,