metacontroller-pytorch 0.0.12__tar.gz → 0.0.15__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: metacontroller-pytorch
- Version: 0.0.12
+ Version: 0.0.15
  Summary: Transformer Metacontroller
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -22,7 +22,7 @@ from x_transformers import Decoder
  from x_mlps_pytorch import Feedforwards
  from x_evolution import EvoStrategy

- from discrete_continuous_embed_readout import Embed, Readout
+ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout

  from assoc_scan import AssocScan

@@ -234,30 +234,25 @@ class Transformer(Module):
          self,
          dim,
          *,
-         embed: Embed | dict,
+         state_embed_readout: dict,
+         action_embed_readout: dict,
          lower_body: Decoder | dict,
          upper_body: Decoder | dict,
-         readout: Readout | dict,
          meta_controller: MetaController | None = None
      ):
          super().__init__()

-         if isinstance(embed, dict):
-             embed = Embed(dim = dim, **embed)
-
          if isinstance(lower_body, dict):
              lower_body = Decoder(dim = dim, **lower_body)

          if isinstance(upper_body, dict):
              upper_body = Decoder(dim = dim, **upper_body)

-         if isinstance(readout, dict):
-             readout = Readout(dim = dim, **readout)
+         self.state_embed, self.state_readout = EmbedAndReadout(dim, **state_embed_readout)
+         self.action_embed, self.action_readout = EmbedAndReadout(dim, **action_embed_readout)

-         self.embed = embed
          self.lower_body = lower_body
          self.upper_body = upper_body
-         self.readout = readout

          # meta controller

@@ -285,13 +280,13 @@ class Transformer(Module):

      def forward(
          self,
-         ids,
+         state,
+         action_ids,
          meta_controller: Module | None = None,
          cache: TransformerOutput | None = None,
          discovery_phase = False,
-         no_grad_transformer = None,
-         no_grad_meta_controller = None,
          meta_controller_temperature = 1.,
+         return_raw_action_dist = False,
          return_latents = False,
          return_cache = False,
      ):
@@ -299,23 +294,32 @@ class Transformer(Module):

          meta_controlling = exists(meta_controller)

-         # by default, if meta controller is passed in, transformer is no grad
+         behavioral_cloning = not meta_controlling and not return_raw_action_dist

-         no_grad_transformer = default(no_grad_transformer, meta_controlling)
-         no_grad_meta_controller = default(no_grad_meta_controller, no_grad_transformer) # by default, if transformer is eval no grad then meta controller is being learnt
+         # by default, if meta controller is passed in, transformer is no grad

-         transformer_context = torch.no_grad if no_grad_transformer else nullcontext
-         meta_controller_context = torch.no_grad if no_grad_meta_controller else nullcontext
+         lower_transformer_context = nullcontext if not meta_controlling else torch.no_grad
+         meta_controller_context = nullcontext if meta_controlling else torch.no_grad
+         upper_transformer_context = nullcontext if (not meta_controlling or discovery_phase) else torch.no_grad

          # handle cache

          lower_transformer_hiddens, meta_hiddens, upper_transformer_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)

+         # handle maybe behavioral cloning
+
+         if behavioral_cloning:
+             state, target_state = state[:, :-1], state[:, 1:]
+             action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]
+
          # transformer lower body

-         with transformer_context():
+         with lower_transformer_context():

-             embed = self.embed(ids)
+             state_embed = self.state_embed(state)
+             action_embed = self.action_embed(action_ids)
+
+             embed = state_embed + action_embed

              residual_stream, next_lower_hiddens = self.lower_body(embed, cache = lower_transformer_hiddens, return_hiddens = True)

@@ -330,13 +334,23 @@ class Transformer(Module):

          # modified residual stream sent back to transformer upper body

-         with transformer_context():
+         with upper_transformer_context():

              attended, next_upper_hiddens = self.upper_body(modified_residual_stream, cache = upper_transformer_hiddens, return_hiddens = True)

          # head readout

-         dist_params = self.readout(attended)
+         dist_params = self.action_readout(attended)
+
+         # maybe return behavior cloning loss
+
+         if behavioral_cloning:
+             state_dist_params = self.state_readout(attended)
+             state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)
+
+             action_clone_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
+
+             return state_clone_loss, action_clone_loss

          # returning

@@ -1,6 +1,6 @@
  [project]
  name = "metacontroller-pytorch"
- version = "0.0.12"
+ version = "0.0.15"
  description = "Transformer Metacontroller"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
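A minimal sketch of the new 0.0.15 call pattern, mirroring the test added below; the dimensions and embed/readout config values are illustrative, not required by the package:

    # construct with separate state / action embed-readout configs (new in 0.0.15)
    import torch
    from metacontroller.metacontroller import Transformer

    model = Transformer(
        dim = 512,
        state_embed_readout = dict(num_continuous = 384),   # continuous state vectors
        action_embed_readout = dict(num_discrete = 4),      # 4 discrete actions
        lower_body = dict(depth = 2),
        upper_body = dict(depth = 2),
    )

    state = torch.randn(1, 1024, 384)           # (batch, time, state dim)
    actions = torch.randint(0, 4, (1, 1024))    # (batch, time) discrete action ids

    # with no meta controller passed, forward defaults to behavioral cloning
    # and returns a (state loss, action loss) pair
    state_clone_loss, action_clone_loss = model(state, actions)
    (state_clone_loss + 0.5 * action_clone_loss).backward()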
@@ -0,0 +1,57 @@
+ import pytest
+ param = pytest.mark.parametrize
+
+ import torch
+ from metacontroller.metacontroller import Transformer, MetaController
+
+ @param('action_discrete', (False, True))
+ @param('discovery_phase', (False, True))
+ @param('switch_per_latent_dim', (False, True))
+ def test_metacontroller(
+     action_discrete,
+     discovery_phase,
+     switch_per_latent_dim
+ ):
+
+     state = torch.randn(1, 1024, 384)
+
+     if action_discrete:
+         actions = torch.randint(0, 4, (1, 1024))
+         action_embed_readout = dict(num_discrete = 4)
+         assert_shape = (4,)
+     else:
+         actions = torch.randn(1, 1024, 8)
+         action_embed_readout = dict(num_continuous = 8)
+         assert_shape = (8, 2)
+
+     # behavioral cloning phase
+
+     model = Transformer(
+         dim = 512,
+         action_embed_readout = action_embed_readout,
+         state_embed_readout = dict(num_continuous = 384),
+         lower_body = dict(depth = 2,),
+         upper_body = dict(depth = 2,),
+     )
+
+     state_clone_loss, action_clone_loss = model(state, actions)
+     (state_clone_loss + 0.5 * action_clone_loss).backward()
+
+     # discovery and internal rl phase with meta controller
+
+     meta_controller = MetaController(
+         dim_latent = 512,
+         switch_per_latent_dim = switch_per_latent_dim
+     )
+
+     logits, cache = model(state, actions, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True)
+
+     assert logits.shape == (1, 1024, *assert_shape)
+
+     logits, cache = model(state, actions, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
+     logits, cache = model(state, actions, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
+
+     assert logits.shape == (1, 1, *assert_shape)
+
+     model.meta_controller = meta_controller
+     model.evolve(1, lambda _: 1., noise_population_size = 2)
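A note on the gradient contexts introduced in the forward rewrite above: the lower body is frozen whenever a meta controller is passed, the meta controller only gets gradients when it is passed, and the upper body stays trainable except in the pure internal-RL case. A small self-contained reading of that logic (an interpretation of the diff, not package API):

    from contextlib import nullcontext
    import torch

    def contexts(meta_controlling, discovery_phase):
        # mirrors the three context selections in the new forward
        lower = nullcontext if not meta_controlling else torch.no_grad
        meta = nullcontext if meta_controlling else torch.no_grad
        upper = nullcontext if (not meta_controlling or discovery_phase) else torch.no_grad
        return lower, meta, upper

    # behavioral cloning / plain use: transformer trains, meta controller does not
    assert contexts(False, False) == (nullcontext, torch.no_grad, nullcontext)

    # discovery phase with a meta controller: lower body frozen, meta controller and upper body train
    assert contexts(True, True) == (torch.no_grad, nullcontext, nullcontext)

    # internal RL phase: only the meta controller gets gradients
    assert contexts(True, False) == (torch.no_grad, nullcontext, torch.no_grad)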
@@ -1,39 +0,0 @@
- import pytest
- param = pytest.mark.parametrize
-
- import torch
- from metacontroller.metacontroller import Transformer, MetaController
-
- @param('discovery_phase', (False, True))
- @param('switch_per_latent_dim', (False, True))
- def test_metacontroller(
-     discovery_phase,
-     switch_per_latent_dim
- ):
-
-     ids = torch.randint(0, 256, (1, 1024))
-
-     model = Transformer(
-         512,
-         embed = dict(num_discrete = 256),
-         lower_body = dict(depth = 2,),
-         upper_body = dict(depth = 2,),
-         readout = dict(num_discrete = 256)
-     )
-
-     meta_controller = MetaController(
-         512,
-         switch_per_latent_dim = switch_per_latent_dim
-     )
-
-     logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True)
-
-     assert logits.shape == (1, 1024, 256)
-
-     logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
-     logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
-
-     assert logits.shape == (1, 1, 256)
-
-     model.meta_controller = meta_controller
-     model.evolve(1, lambda _: 1., noise_population_size = 2)