metacontroller-pytorch 0.0.9__tar.gz → 0.0.12__tar.gz

This diff compares publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: metacontroller-pytorch
- Version: 0.0.9
+ Version: 0.0.12
  Summary: Transformer Metacontroller
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -35,9 +35,10 @@ Classifier: Programming Language :: Python :: 3.9
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9
  Requires-Dist: assoc-scan>=0.0.3
- Requires-Dist: discrete-continuous-embed-readout>=0.1.11
+ Requires-Dist: discrete-continuous-embed-readout>=0.1.12
  Requires-Dist: einops>=0.8.1
  Requires-Dist: einx>=0.3.0
+ Requires-Dist: loguru
  Requires-Dist: torch>=2.5
  Requires-Dist: x-evolution>=0.1.23
  Requires-Dist: x-mlps-pytorch
@@ -3,6 +3,7 @@ from contextlib import nullcontext

  from functools import partial
  from collections import namedtuple
+ from loguru import logger

  import torch
  from torch import nn, cat, stack, tensor
@@ -52,6 +53,13 @@ def straight_through(src, tgt):

  # meta controller

+ MetaControllerOutput = namedtuple('MetaControllerOutput', (
+     'prev_hiddens',
+     'action_dist',
+     'actions',
+     'kl_loss'
+ ))
+
  class MetaController(Module):
      def __init__(
          self,
@@ -107,9 +115,9 @@ class MetaController(Module):
          return [
              *self.bidirectional_temporal_compressor.parameters(),
              *self.emitter.parameters(),
-             *self.emitter_to_action_mean_log_var.parameters()
+             *self.emitter_to_action_mean_log_var.parameters(),
              *self.decoder.parameters(),
-             *self.switch_gating
+             *self.switch_gating.parameters()
          ]

      def internal_rl_parameters(self):
@@ -121,11 +129,23 @@ class MetaController(Module):
      def forward(
          self,
          residual_stream,
+         cache: MetaControllerOutput | None = None,
          discovery_phase = False,
-         hard_switch = False
+         hard_switch = False,
+         temperature = 1.
      ):

+         # destruct prev cache
+
+         prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+
+         # getting proposed action for the two phases
+
+         next_action_proposer_hidden = None
+
          if discovery_phase:
+             logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
+
              temporal_compressed, _ = self.bidirectional_temporal_compressor(residual_stream)
              temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)

@@ -133,48 +153,51 @@ class MetaController(Module):
              readout = self.emitter_to_action_mean_log_var

          else: # else internal rl phase
-             proposed_action_hidden, _ = self.action_proposer(residual_stream)
+
+             proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(residual_stream, prev_action_proposer_hidden)
              readout = self.action_proposer_mean_log_var

          # sample from the gaussian as the action from the meta controller

          action_dist = readout(proposed_action_hidden)

-         sampled_action = readout.sample(action_dist)
+         sampled_action = readout.sample(action_dist, temperature = temperature)

          # switching unit timer

          batch, _, dim = sampled_action.shape

-         switching_unit_gru_out, switching_unit_gru_hidden = self.switching_unit(residual_stream)
+         switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(residual_stream, prev_switching_unit_gru_hidden)

          switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()

          # need to encourage normal distribution

-         vae_kl_loss = self.zero
+         kl_loss = self.zero

          if discovery_phase:
              mean, log_var = action_dist.unbind(dim = -1)

-             vae_kl_loss = (0.5 * (
+             kl_loss = (0.5 * (
                  log_var.exp()
                  + mean.square()
                  - log_var
                  - 1.
              ))

-             vae_kl_loss = vae_kl_loss * switch_beta
-             vae_kl_loss = vae_kl_loss.sum(dim = -1).mean()
+             kl_loss = kl_loss * switch_beta
+             kl_loss = kl_loss.sum(dim = -1).mean()

          # maybe hard switch, then use associative scan

          if hard_switch:
-             hard_switch = (switch_beta > 0.5).float()
-             switch_beta = straight_through(switch_beta, hard_switch)
+             hard_switch_beta = (switch_beta > 0.5).float()
+             switch_beta = straight_through(switch_beta, hard_switch_beta)

          forget = 1. - switch_beta
-         gated_action = self.switch_gating(switch_beta, sampled_action * forget)
+         gated_action = self.switch_gating(switch_beta, sampled_action * forget, prev = prev_switch_gated_hiddens)
+
+         next_switch_gated_action = gated_action[:, -1]

          # decoder

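For reference (not part of the diff), the renamed kl_loss in the discovery phase is the standard closed-form KL divergence between the emitted diagonal Gaussian action distribution and a unit Gaussian, computed per latent dimension:

    \mathrm{KL}\left(\mathcal{N}(\mu, \sigma^2)\,\middle\|\,\mathcal{N}(0, 1)\right) = \tfrac{1}{2}\left(\sigma^2 + \mu^2 - \log\sigma^2 - 1\right)

With log_var standing in for log sigma^2, this is exactly the 0.5 * (log_var.exp() + mean.square() - log_var - 1.) term above, which is then weighted by the switching gate beta before the sum and mean reduction.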
@@ -189,10 +212,23 @@ class MetaController(Module):

          modified_residual_stream = residual_stream + control_signal

-         return modified_residual_stream, action_dist, sampled_action, vae_kl_loss
+         # returning
+
+         next_hiddens = (
+             next_action_proposer_hidden,
+             next_switching_unit_gru_hidden,
+             next_switch_gated_action
+         )
+
+         return modified_residual_stream, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss)

  # main transformer, which is subsumed into the environment after behavioral cloning

+ TransformerOutput = namedtuple('TransformerOutput', (
+     'residual_stream_latent',
+     'prev_hiddens'
+ ))
+
  class Transformer(Module):
      def __init__(
          self,
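A minimal usage sketch (not part of the diff) of the MetaControllerOutput cache round-trip introduced above, assuming a constructed meta_controller and a residual_stream tensor of shape (batch, seq, dim); field names follow the namedtuple in the diff:

    # first call in the internal RL phase, no cache yet
    modified_stream, out = meta_controller(residual_stream)

    # out.prev_hiddens holds the (action proposer, switching unit GRU, switch-gated) hiddens;
    # passing the whole MetaControllerOutput back as cache carries them into the next call
    modified_stream, out = meta_controller(residual_stream, cache = out)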
@@ -251,10 +287,13 @@ class Transformer(Module):
          self,
          ids,
          meta_controller: Module | None = None,
+         cache: TransformerOutput | None = None,
          discovery_phase = False,
-         return_latents = False,
          no_grad_transformer = None,
-         no_grad_meta_controller = None
+         no_grad_meta_controller = None,
+         meta_controller_temperature = 1.,
+         return_latents = False,
+         return_cache = False,
      ):
          meta_controller = default(meta_controller, self.meta_controller)
@@ -268,28 +307,32 @@ class Transformer(Module):
          transformer_context = torch.no_grad if no_grad_transformer else nullcontext
          meta_controller_context = torch.no_grad if no_grad_meta_controller else nullcontext

+         # handle cache
+
+         lower_transformer_hiddens, meta_hiddens, upper_transformer_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+
          # transformer lower body

          with transformer_context():

              embed = self.embed(ids)

-             residual_stream = self.lower_body(embed)
+             residual_stream, next_lower_hiddens = self.lower_body(embed, cache = lower_transformer_hiddens, return_hiddens = True)

          # meta controller acts on residual stream here

          with meta_controller_context():

              if exists(meta_controller):
-                 modified_residual_stream, action_dist, sampled_action, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
+                 modified_residual_stream, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature)
              else:
-                 modified_residual_stream, action_dist, sampled_action, vae_aux_loss = residual_stream, None, None, self.zero
+                 modified_residual_stream, next_meta_hiddens = residual_stream, None

          # modified residual stream sent back to transformer upper body

          with transformer_context():

-             attended = self.upper_body(modified_residual_stream)
+             attended, next_upper_hiddens = self.upper_body(modified_residual_stream, cache = upper_transformer_hiddens, return_hiddens = True)

          # head readout

@@ -297,7 +340,9 @@ class Transformer(Module):

          # returning

-         if not return_latents:
+         return_one = not (return_latents or return_cache)
+
+         if return_one:
              return dist_params

-         return dist_params, latents
+         return dist_params, TransformerOutput(residual_stream, (next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
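A minimal sketch (not part of the diff) of the new return convention, assuming model and ids already exist: when either return_latents or return_cache is set, the second return value is now the TransformerOutput namedtuple rather than a bare latents tensor.

    dist_params, out = model(ids, return_cache = True)

    latent = out.residual_stream_latent    # residual stream from the lower body, before meta controller modification
    hiddens = out.prev_hiddens             # (lower body, meta controller, upper body) hiddens

    # feeding the output back in as cache continues from the stored hiddens,
    # as the updated test below does
    dist_params, out = model(ids, cache = out, return_cache = True)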
@@ -1,6 +1,6 @@
  [project]
  name = "metacontroller-pytorch"
- version = "0.0.9"
+ version = "0.0.12"
  description = "Transformer Metacontroller"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -25,9 +25,10 @@ classifiers=[

  dependencies = [
      "assoc-scan>=0.0.3",
+     "discrete-continuous-embed-readout>=0.1.12",
      "einx>=0.3.0",
      "einops>=0.8.1",
-     "discrete-continuous-embed-readout>=0.1.11",
+     "loguru",
      "torch>=2.5",
      "x-evolution>=0.1.23",
      "x-mlps-pytorch",
@@ -26,9 +26,14 @@ def test_metacontroller(
          switch_per_latent_dim = switch_per_latent_dim
      )

-     logits = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase)
+     logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True)

      assert logits.shape == (1, 1024, 256)

+     logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
+     logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
+
+     assert logits.shape == (1, 1, 256)
+
      model.meta_controller = meta_controller
      model.evolve(1, lambda _: 1., noise_population_size = 2)