metacontroller-pytorch 0.0.8__tar.gz → 0.0.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.8
+Version: 0.0.10
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -1,4 +1,6 @@
 from __future__ import annotations
+from contextlib import nullcontext
+
 from functools import partial
 from collections import namedtuple
 
@@ -50,6 +52,13 @@ def straight_through(src, tgt):
 
 # meta controller
 
+MetaControllerOutput = namedtuple('MetaControllerOutput', (
+    'prev_hiddens',
+    'action_dist',
+    'actions',
+    'kl_loss'
+))
+
 class MetaController(Module):
     def __init__(
         self,
@@ -105,9 +114,9 @@ class MetaController(Module):
         return [
             *self.bidirectional_temporal_compressor.parameters(),
             *self.emitter.parameters(),
-            *self.emitter_to_action_mean_log_var.parameters()
+            *self.emitter_to_action_mean_log_var.parameters(),
             *self.decoder.parameters(),
-            *self.switch_gating
+            *self.switch_gating.parameters()
         ]
 
     def internal_rl_parameters(self):
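For context, a minimal sketch of how a parameter list like the one above is typically consumed; the method name discovery_parameters, the Adam optimizer, and the learning rate below are illustrative assumptions and are not taken from the package:

import torch

# hypothetical usage: build an optimizer over the discovery-phase parameter groups,
# assuming `meta_controller` is an already constructed MetaController and the list
# above is returned by a method named `discovery_parameters`
optim = torch.optim.Adam(meta_controller.discovery_parameters(), lr = 3e-4)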
@@ -119,10 +128,19 @@ class MetaController(Module):
     def forward(
         self,
         residual_stream,
+        cache: MetaControllerOutput | None = None,
         discovery_phase = False,
         hard_switch = False
     ):
 
+        # destruct prev cache
+
+        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+
+        # getting proposed action for the two phases
+
+        next_action_proposer_hidden = None
+
         if discovery_phase:
             temporal_compressed, _ = self.bidirectional_temporal_compressor(residual_stream)
             temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
@@ -131,7 +149,8 @@ class MetaController(Module):
             readout = self.emitter_to_action_mean_log_var
 
         else: # else internal rl phase
-            proposed_action_hidden, _ = self.action_proposer(residual_stream)
+
+            proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(residual_stream, prev_action_proposer_hidden)
             readout = self.action_proposer_mean_log_var
 
         # sample from the gaussian as the action from the meta controller
@@ -144,35 +163,37 @@ class MetaController(Module):
 
         batch, _, dim = sampled_action.shape
 
-        switching_unit_gru_out, switching_unit_gru_hidden = self.switching_unit(residual_stream)
+        switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(residual_stream, prev_switching_unit_gru_hidden)
 
         switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
 
         # need to encourage normal distribution
 
-        vae_kl_loss = self.zero
+        kl_loss = self.zero
 
         if discovery_phase:
             mean, log_var = action_dist.unbind(dim = -1)
 
-            vae_kl_loss = (0.5 * (
+            kl_loss = (0.5 * (
                 log_var.exp()
                 + mean.square()
                 - log_var
                 - 1.
             ))
 
-            vae_kl_loss = vae_kl_loss * switch_beta
-            vae_kl_loss = vae_kl_loss.sum(dim = -1).mean()
+            kl_loss = kl_loss * switch_beta
+            kl_loss = kl_loss.sum(dim = -1).mean()
 
         # maybe hard switch, then use associative scan
 
         if hard_switch:
-            hard_switch = (switch_beta > 0.5).float()
-            switch_beta = straight_through(switch_beta, hard_switch)
+            hard_switch_beta = (switch_beta > 0.5).float()
+            switch_beta = straight_through(switch_beta, hard_switch_beta)
 
         forget = 1. - switch_beta
-        gated_action = self.switch_gating(switch_beta, sampled_action * forget)
+        gated_action = self.switch_gating(switch_beta, sampled_action * forget, prev = prev_switch_gated_hiddens)
+
+        next_switch_gated_action = gated_action[:, -1]
 
         # decoder
 
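For reference, the bracketed expression above is the closed-form KL divergence between the emitted diagonal Gaussian and a standard normal, KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (sigma^2 + mu^2 - log sigma^2 - 1) per dimension, with sigma^2 = exp(log_var); it is then gated elementwise by switch_beta, summed over the action dimension, and averaged.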
@@ -187,10 +208,23 @@ class MetaController(Module):
 
         modified_residual_stream = residual_stream + control_signal
 
-        return modified_residual_stream, action_dist, sampled_action, vae_kl_loss
+        # returning
+
+        next_hiddens = (
+            next_action_proposer_hidden,
+            next_switching_unit_gru_hidden,
+            next_switch_gated_action
+        )
+
+        return modified_residual_stream, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss)
 
 # main transformer, which is subsumed into the environment after behavioral cloning
 
+TransformerOutput = namedtuple('TransformerOutput', (
+    'residual_stream_latent',
+    'prev_hiddens'
+))
+
 class Transformer(Module):
     def __init__(
         self,
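A minimal sketch of the cache round-trip that the hunks above introduce for MetaController.forward; meta_controller and chunks (an iterable of residual-stream tensors) are assumed to exist and do not appear in the package:

cache = None

for residual_stream in chunks:
    modified_stream, out = meta_controller(
        residual_stream,
        cache = cache,            # MetaControllerOutput from the previous call, or None on the first
        discovery_phase = False,  # internal rl phase, so the action proposer GRU hidden is carried forward
        hard_switch = False
    )

    cache = out  # namedtuple of (prev_hiddens, action_dist, actions, kl_loss)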
@@ -249,29 +283,61 @@ class Transformer(Module):
         self,
         ids,
         meta_controller: Module | None = None,
+        cache: TransformerOutput | None = None,
         discovery_phase = False,
-        return_latents = False
+        no_grad_transformer = None,
+        no_grad_meta_controller = None,
+        return_latents = False,
+        return_cache = False
     ):
         meta_controller = default(meta_controller, self.meta_controller)
 
-        embed = self.embed(ids)
+        meta_controlling = exists(meta_controller)
+
+        # by default, if meta controller is passed in, transformer is no grad
 
-        residual_stream = self.lower_body(embed)
+        no_grad_transformer = default(no_grad_transformer, meta_controlling)
+        no_grad_meta_controller = default(no_grad_meta_controller, no_grad_transformer) # by default, if transformer is eval no grad then meta controller is being learnt
+
+        transformer_context = torch.no_grad if no_grad_transformer else nullcontext
+        meta_controller_context = torch.no_grad if no_grad_meta_controller else nullcontext
+
+        # handle cache
+
+        lower_transformer_hiddens, meta_hiddens, upper_transformer_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+
+        # transformer lower body
+
+        with transformer_context():
+
+            embed = self.embed(ids)
+
+            residual_stream, next_lower_hiddens = self.lower_body(embed, cache = lower_transformer_hiddens, return_hiddens = True)
 
         # meta controller acts on residual stream here
 
-        if exists(meta_controller):
-            modified_residual_stream, action_dist, sampled_action, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
-        else:
-            modified_residual_stream, action_dist, sampled_action, vae_aux_loss = residual_stream, None, None, self.zero
+        with meta_controller_context():
+
+            if exists(meta_controller):
+                modified_residual_stream, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase)
+            else:
+                modified_residual_stream, next_meta_hiddens = residual_stream, None
+
+        # modified residual stream sent back to transformer upper body
+
+        with transformer_context():
+
+            attended, next_upper_hiddens = self.upper_body(modified_residual_stream, cache = upper_transformer_hiddens, return_hiddens = True)
+
+        # head readout
 
-        # modified residual stream sent back
+        dist_params = self.readout(attended)
 
-        attended = self.upper_body(modified_residual_stream)
+        # returning
 
-        dist_params = self.readout(attended)
+        return_one = not (return_latents or return_cache)
 
-        if not return_latents:
+        if return_one:
             return dist_params
 
-        return dist_params, latents
+        return dist_params, TransformerOutput(residual_stream, (next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
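A hedged sketch of the new gradient-context switches and cache on Transformer.forward above; model, meta_controller, and ids are assumed to be constructed elsewhere, and the explicit overrides shown are illustrative rather than usage prescribed by the package:

# default: when a meta controller is passed in, the transformer body runs under torch.no_grad
logits, cache = model(ids, meta_controller = meta_controller, return_cache = True)

# feeding the cache back in carries the lower body, meta controller, and upper body hidden states forward
logits, cache = model(ids, meta_controller = meta_controller, cache = cache, return_cache = True)

# the no-grad defaults can also be overridden explicitly, e.g. freezing both halves during evaluation
logits = model(
    ids,
    meta_controller = meta_controller,
    no_grad_transformer = True,
    no_grad_meta_controller = True
)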
@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.8"
+version = "0.0.10"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -26,9 +26,14 @@ def test_metacontroller(
         switch_per_latent_dim = switch_per_latent_dim
     )
 
-    logits = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase)
+    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True)
 
     assert logits.shape == (1, 1024, 256)
 
+    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
+    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
+
+    assert logits.shape == (1, 1, 256)
+
     model.meta_controller = meta_controller
     model.evolve(1, lambda _: 1., noise_population_size = 2)