metacontroller-pytorch 0.0.14__tar.gz → 0.0.15__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.14
+Version: 0.0.15
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -22,7 +22,7 @@ from x_transformers import Decoder
 from x_mlps_pytorch import Feedforwards
 from x_evolution import EvoStrategy

-from discrete_continuous_embed_readout import Embed, Readout
+from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout

 from assoc_scan import AssocScan

@@ -234,30 +234,25 @@ class Transformer(Module):
         self,
         dim,
         *,
-        embed: Embed | dict,
+        state_embed_readout: dict,
+        action_embed_readout: dict,
         lower_body: Decoder | dict,
         upper_body: Decoder | dict,
-        readout: Readout | dict,
         meta_controller: MetaController | None = None
     ):
         super().__init__()

-        if isinstance(embed, dict):
-            embed = Embed(dim = dim, **embed)
-
         if isinstance(lower_body, dict):
             lower_body = Decoder(dim = dim, **lower_body)

         if isinstance(upper_body, dict):
             upper_body = Decoder(dim = dim, **upper_body)

-        if isinstance(readout, dict):
-            readout = Readout(dim = dim, **readout)
+        self.state_embed, self.state_readout = EmbedAndReadout(dim, **state_embed_readout)
+        self.action_embed, self.action_readout = EmbedAndReadout(dim, **action_embed_readout)

-        self.embed = embed
         self.lower_body = lower_body
         self.upper_body = upper_body
-        self.readout = readout

         # meta controller

@@ -285,11 +280,13 @@ class Transformer(Module):

     def forward(
         self,
-        ids,
+        state,
+        action_ids,
         meta_controller: Module | None = None,
         cache: TransformerOutput | None = None,
         discovery_phase = False,
         meta_controller_temperature = 1.,
+        return_raw_action_dist = False,
         return_latents = False,
         return_cache = False,
     ):
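Taken together, the two hunks above change the public API: the constructor now accepts `state_embed_readout` and `action_embed_readout` config dicts (each expanded into an `EmbedAndReadout` pair) in place of the old `embed` / `readout` arguments, and `forward` now takes separate `state` and `action_ids` streams instead of a single `ids` tensor. A minimal construction sketch, with the dims and dict keys copied from the new test file added at the bottom of this diff rather than from any additional documentation:

from metacontroller.metacontroller import Transformer

model = Transformer(
    dim = 512,
    state_embed_readout = dict(num_continuous = 384),  # continuous state vectors of dim 384, as in the new test
    action_embed_readout = dict(num_discrete = 4),     # 4 discrete actions; the test also covers num_continuous actions
    lower_body = dict(depth = 2),
    upper_body = dict(depth = 2)
)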
@@ -297,21 +294,32 @@ class Transformer(Module):

         meta_controlling = exists(meta_controller)

+        behavioral_cloning = not meta_controlling and not return_raw_action_dist
+
         # by default, if meta controller is passed in, transformer is no grad

         lower_transformer_context = nullcontext if not meta_controlling else torch.no_grad
         meta_controller_context = nullcontext if meta_controlling else torch.no_grad
-        upper_transformer_context = nullcontext if meta_controlling and discovery_phase else torch.no_grad
+        upper_transformer_context = nullcontext if (not meta_controlling or discovery_phase) else torch.no_grad

         # handle cache

         lower_transformer_hiddens, meta_hiddens, upper_transformer_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)

+        # handle maybe behavioral cloning
+
+        if behavioral_cloning:
+            state, target_state = state[:, :-1], state[:, 1:]
+            action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]
+
         # transformer lower body

         with lower_transformer_context():

-            embed = self.embed(ids)
+            state_embed = self.state_embed(state)
+            action_embed = self.action_embed(action_ids)
+
+            embed = state_embed + action_embed

             residual_stream, next_lower_hiddens = self.lower_body(embed, cache = lower_transformer_hiddens, return_hiddens = True)

@@ -332,7 +340,17 @@ class Transformer(Module):

         # head readout

-        dist_params = self.readout(attended)
+        dist_params = self.action_readout(attended)
+
+        # maybe return behavior cloning loss
+
+        if behavioral_cloning:
+            state_dist_params = self.state_readout(attended)
+            state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)
+
+            action_clone_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
+
+            return state_clone_loss, action_clone_loss

         # returning

@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.14"
+version = "0.0.15"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
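Ahead of the new test below, a condensed sketch of the two phases it exercises, continuing the construction sketch shown earlier. Shapes, dict values, and the 0.5 loss weighting are taken from that test; note that 0.0.15 also flips the gradient contexts so the upper body receives gradients whenever no meta controller is passed, which is what lets both cloning losses backpropagate:

import torch
from metacontroller.metacontroller import MetaController

state = torch.randn(1, 1024, 384)            # (batch, time, state dim)
action_ids = torch.randint(0, 4, (1, 1024))  # (batch, time) discrete action ids

# behavioral cloning phase: no meta controller and return_raw_action_dist left False,
# so forward shifts inputs/targets by one timestep internally and returns the two cloning losses
state_clone_loss, action_clone_loss = model(state, action_ids)
(state_clone_loss + 0.5 * action_clone_loss).backward()   # 0.5 weighting is illustrative, copied from the new test

# discovery / internal RL phase: passing a meta controller makes forward return
# the action distribution parameters (and optionally the cache) instead
meta_controller = MetaController(dim_latent = 512, switch_per_latent_dim = False)
logits, cache = model(state, action_ids, meta_controller = meta_controller, discovery_phase = True, return_cache = True)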
@@ -0,0 +1,57 @@
+import pytest
+param = pytest.mark.parametrize
+
+import torch
+from metacontroller.metacontroller import Transformer, MetaController
+
+@param('action_discrete', (False, True))
+@param('discovery_phase', (False, True))
+@param('switch_per_latent_dim', (False, True))
+def test_metacontroller(
+    action_discrete,
+    discovery_phase,
+    switch_per_latent_dim
+):
+
+    state = torch.randn(1, 1024, 384)
+
+    if action_discrete:
+        actions = torch.randint(0, 4, (1, 1024))
+        action_embed_readout = dict(num_discrete = 4)
+        assert_shape = (4,)
+    else:
+        actions = torch.randn(1, 1024, 8)
+        action_embed_readout = dict(num_continuous = 8)
+        assert_shape = (8, 2)
+
+    # behavioral cloning phase
+
+    model = Transformer(
+        dim = 512,
+        action_embed_readout = action_embed_readout,
+        state_embed_readout = dict(num_continuous = 384),
+        lower_body = dict(depth = 2,),
+        upper_body = dict(depth = 2,),
+    )
+
+    state_clone_loss, action_clone_loss = model(state, actions)
+    (state_clone_loss + 0.5 * action_clone_loss).backward()
+
+    # discovery and internal rl phase with meta controller
+
+    meta_controller = MetaController(
+        dim_latent = 512,
+        switch_per_latent_dim = switch_per_latent_dim
+    )
+
+    logits, cache = model(state, actions, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True)
+
+    assert logits.shape == (1, 1024, *assert_shape)
+
+    logits, cache = model(state, actions, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
+    logits, cache = model(state, actions, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
+
+    assert logits.shape == (1, 1, *assert_shape)
+
+    model.meta_controller = meta_controller
+    model.evolve(1, lambda _: 1., noise_population_size = 2)
@@ -1,39 +0,0 @@
-import pytest
-param = pytest.mark.parametrize
-
-import torch
-from metacontroller.metacontroller import Transformer, MetaController
-
-@param('discovery_phase', (False, True))
-@param('switch_per_latent_dim', (False, True))
-def test_metacontroller(
-    discovery_phase,
-    switch_per_latent_dim
-):
-
-    ids = torch.randint(0, 256, (1, 1024))
-
-    model = Transformer(
-        512,
-        embed = dict(num_discrete = 256),
-        lower_body = dict(depth = 2,),
-        upper_body = dict(depth = 2,),
-        readout = dict(num_discrete = 256)
-    )
-
-    meta_controller = MetaController(
-        512,
-        switch_per_latent_dim = switch_per_latent_dim
-    )
-
-    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True)
-
-    assert logits.shape == (1, 1024, 256)
-
-    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
-    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
-
-    assert logits.shape == (1, 1, 256)
-
-    model.meta_controller = meta_controller
-    model.evolve(1, lambda _: 1., noise_population_size = 2)