metacontroller-pytorch 0.0.35__tar.gz → 0.0.37__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of metacontroller-pytorch might be problematic.

Files changed (17)
  1. {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.37}/PKG-INFO +2 -2
  2. {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.37}/metacontroller/metacontroller.py +20 -8
  3. {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.37}/metacontroller/metacontroller_with_binary_mapper.py +4 -2
  4. {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.37}/pyproject.toml +2 -2
  5. {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.37}/tests/test_metacontroller.py +56 -10
  6. {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.37}/train_behavior_clone_babyai.py +91 -38
  7. {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.37}/.github/workflows/python-publish.yml +0 -0
  8. {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.37}/.github/workflows/test.yml +0 -0
  9. {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.37}/.gitignore +0 -0
  10. {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.37}/LICENSE +0 -0
  11. {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.37}/README.md +0 -0
  12. {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.37}/fig1.png +0 -0
  13. {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.37}/gather_babyai_trajs.py +0 -0
  14. {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.37}/metacontroller/__init__.py +0 -0
  15. {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.37}/metacontroller/transformer_with_resnet.py +0 -0
  16. {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.37}/test_babyai_e2e.sh +0 -0
  17. {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.37}/train_babyai.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.35
+Version: 0.0.37
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -39,7 +39,7 @@ Requires-Dist: discrete-continuous-embed-readout>=0.1.12
 Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
 Requires-Dist: loguru
-Requires-Dist: memmap-replay-buffer>=0.0.23
+Requires-Dist: memmap-replay-buffer>=0.0.25
 Requires-Dist: torch-einops-utils>=0.0.19
 Requires-Dist: torch>=2.5
 Requires-Dist: vector-quantize-pytorch>=1.27.20
metacontroller/metacontroller.py

@@ -336,6 +336,8 @@ class MetaController(Module):
 
         return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, switch_beta, kl_loss, switch_loss)
 
+MetaController.policy_loss = policy_loss
+
 # main transformer, which is subsumed into the environment after behavioral cloning
 
 Hiddens = namedtuple('Hiddens', (
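The added line attaches the free function `policy_loss` (defined earlier in metacontroller.py) as a method on `MetaController`, so callers can write `meta_controller.policy_loss(...)` with the module bound as `self`, which is how the updated test below invokes it. A minimal sketch of the pattern; the function body and argument list here are illustrative stand-ins, not the package's actual `policy_loss`:

    import torch
    from torch import nn

    # illustrative free function attached as a method after class definition;
    # a generic PPO-style clipped surrogate is used here only to make the pattern concrete
    def policy_loss(self, old_log_probs, new_log_probs, advantages, eps_clip = 0.2):
        ratio = (new_log_probs - old_log_probs).exp()
        clipped = ratio.clamp(1. - eps_clip, 1. + eps_clip)
        return -torch.minimum(ratio * advantages, clipped * advantages).mean()

    class MetaController(nn.Module):
        ...

    MetaController.policy_loss = policy_loss   # same assignment as in the diff

    # now callable as a bound method
    mc = MetaController()
    old = torch.randn(3, 16)
    new = old + 0.1 * torch.randn(3, 16)
    adv = torch.randn(3, 1)
    print(mc.policy_loss(old, new, adv))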
@@ -406,6 +408,7 @@ class Transformer(Module):
         meta_controller: Module | None = None,
         cache: TransformerOutput | None = None,
         discovery_phase = False,
+        force_behavior_cloning = False,
         meta_controller_temperature = 1.,
         return_raw_action_dist = False,
         return_latents = False,
@@ -414,17 +417,25 @@ class Transformer(Module):
     ):
         device = state.device
 
+        # meta controller is either given or already given at init
+
         meta_controller = default(meta_controller, self.meta_controller)
 
-        meta_controlling = exists(meta_controller)
+        if force_behavior_cloning:
+            assert not discovery_phase, 'discovery phase cannot be set to True if force behavioral cloning is set to True'
+            meta_controller = None
+
+        has_meta_controller = exists(meta_controller)
 
-        behavioral_cloning = not meta_controlling and not return_raw_action_dist
+        assert not (discovery_phase and not has_meta_controller), 'meta controller must be made available during discovery phase'
+
+        behavioral_cloning = force_behavior_cloning or (not has_meta_controller and not return_raw_action_dist)
 
         # by default, if meta controller is passed in, transformer is no grad
 
-        lower_transformer_context = nullcontext if not meta_controlling else torch.no_grad
-        meta_controller_context = nullcontext if meta_controlling else torch.no_grad
-        upper_transformer_context = nullcontext if (not meta_controlling or discovery_phase) else torch.no_grad
+        lower_transformer_context = nullcontext if not has_meta_controller else torch.no_grad
+        meta_controller_context = nullcontext if has_meta_controller else torch.no_grad
+        upper_transformer_context = nullcontext if (not has_meta_controller or discovery_phase) else torch.no_grad
 
         # handle cache
 
@@ -432,7 +443,8 @@ class Transformer(Module):
 
         # handle maybe behavioral cloning
 
-        if behavioral_cloning or (meta_controlling and discovery_phase):
+        if behavioral_cloning or discovery_phase: # during behavior cloning and discovery phase, the network is predicting / reconstructing the next token
+
             assert exists(actions), f'`actions` cannot be empty when doing discovery or behavioral cloning'
 
             state, target_state = state[:, :-1], state[:, 1:]
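The hunks above introduce `force_behavior_cloning`, which drops any attached meta controller for the current forward pass and routes into the behavior-cloning branch, while discovery still requires a meta controller. A standalone recap of just that gating logic, using a hypothetical helper name (`resolve_mode`) to restate the diff's booleans:

    # pure-Python restatement of the new flag logic; `resolve_mode` is not part of the package
    def resolve_mode(meta_controller, discovery_phase, force_behavior_cloning, return_raw_action_dist = False):
        if force_behavior_cloning:
            assert not discovery_phase, 'discovery phase cannot be set to True if force behavioral cloning is set to True'
            meta_controller = None

        has_meta_controller = meta_controller is not None
        assert not (discovery_phase and not has_meta_controller), 'meta controller must be made available during discovery phase'

        behavioral_cloning = force_behavior_cloning or (not has_meta_controller and not return_raw_action_dist)
        return behavioral_cloning, has_meta_controller

    # with a meta controller attached, forcing behavior cloning drops it for this forward pass
    assert resolve_mode(object(), False, True) == (True, False)
    # discovery phase keeps the meta controller and is not behavior cloning
    assert resolve_mode(object(), True, False) == (False, True)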
@@ -465,7 +477,7 @@ class Transformer(Module):
 
         with meta_controller_context():
 
-            if exists(meta_controller):
+            if exists(meta_controller) and not behavioral_cloning:
                 control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature, episode_lens = episode_lens)
             else:
                 control_signal, next_meta_hiddens = self.zero, None
@@ -495,7 +507,7 @@ class Transformer(Module):
 
             return state_clone_loss, action_clone_loss
 
-        elif meta_controlling and discovery_phase:
+        elif discovery_phase:
 
             action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)
metacontroller/metacontroller_with_binary_mapper.py

@@ -28,7 +28,7 @@ from torch_einops_utils.save_load import save_load
 
 from vector_quantize_pytorch import BinaryMapper
 
-from metacontroller.metacontroller import MetaControllerOutput
+from metacontroller.metacontroller import MetaControllerOutput, policy_loss
 
 # constants
 
@@ -170,7 +170,7 @@ class MetaControllerWithBinaryMapper(Module):
         action_log_probs = log_probs.gather(-1, codes)
         action_log_probs = rearrange(action_log_probs, '... 1 -> ...')
 
-        return action_log_probs.sum(dim = -1)
+        return action_log_probs
 
     def forward(
         self,
@@ -302,3 +302,5 @@ class MetaControllerWithBinaryMapper(Module):
         switch_beta = rearrange(switch_beta, '... 1 -> ...')
 
         return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, switch_beta, kl_loss, switch_loss)
+
+MetaControllerWithBinaryMapper.policy_loss = policy_loss
pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.35"
+version = "0.0.37"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -29,7 +29,7 @@ dependencies = [
     "einx>=0.3.0",
     "einops>=0.8.1",
     "loguru",
-    "memmap-replay-buffer>=0.0.23",
+    "memmap-replay-buffer>=0.0.25",
     "torch>=2.5",
     "torch-einops-utils>=0.0.19",
     "vector-quantize-pytorch>=1.27.20",
tests/test_metacontroller.py

@@ -1,6 +1,7 @@
 import pytest
 param = pytest.mark.parametrize
 
+from shutil import rmtree
 from pathlib import Path
 from functools import partial
 
@@ -9,6 +10,8 @@ from torch import cat
 from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
 from metacontroller.metacontroller_with_binary_mapper import MetaControllerWithBinaryMapper
 
+from memmap_replay_buffer import ReplayBuffer
+
 from einops import rearrange
 
 # functions
@@ -66,6 +69,12 @@ def test_metacontroller(
             dim_latent = 128,
             switch_per_latent_dim = switch_per_latent_dim
         )
+
+        field_shapes = dict(
+            log_probs = ('float', 128),
+            switch_betas = ('float', 128 if switch_per_latent_dim else 1),
+            latent_actions = ('float', 128)
+        )
     else:
         meta_controller = MetaControllerWithBinaryMapper(
             dim_model = 512,
@@ -74,6 +83,12 @@ def test_metacontroller(
             dim_code_bits = 8, # 2 ** 8 = 256 codes
         )
 
+        field_shapes = dict(
+            log_probs = ('float', 8),
+            switch_betas = ('float', 8 if switch_per_latent_dim else 1),
+            latent_actions = ('float', 256)
+        )
+
     # discovery phase
 
     (action_recon_loss, kl_loss, switch_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True, episode_lens = episode_lens)
@@ -81,6 +96,23 @@ def test_metacontroller(
 
     # internal rl - done iteratively
 
+    # replay buffer
+
+    test_folder = './test-buffer-for-grpo'
+
+    replay_buffer = ReplayBuffer(
+        test_folder,
+        max_episodes = 3,
+        max_timesteps = 256,
+        fields = dict(
+            states = ('float', 512),
+            **field_shapes
+        ),
+        meta_fields = dict(
+            advantages = 'float'
+        )
+    )
+
     # simulate grpo
 
 
@@ -129,22 +161,34 @@ def test_metacontroller(
     # calculate advantages using z-score
 
     rewards = cat(all_rewards)
-    advantages = z_score(rewards)
+    group_advantages = z_score(rewards)
 
-    assert advantages.shape == (3,)
+    assert group_advantages.shape == (3,)
 
     # simulate a policy loss update over the entire group
 
     group_states, group_log_probs, group_switch_betas, group_latent_actions = map(partial(cat, dim = 0), zip(*all_episodes))
 
-    loss = policy_loss(
-        meta_controller,
-        group_states,
-        group_log_probs,
-        group_latent_actions,
-        advantages,
-        group_switch_betas == 1.,
-        episode_lens = episode_lens[:1].repeat(3) if exists(episode_lens) else None
+    for states, log_probs, switch_betas, latent_actions, advantages in zip(group_states, group_log_probs, group_switch_betas, group_latent_actions, group_advantages):
+        replay_buffer.store_episode(
+            states = states,
+            log_probs = log_probs,
+            switch_betas = switch_betas,
+            latent_actions = latent_actions,
+            advantages = advantages
+        )
+
+    dl = replay_buffer.dataloader(batch_size = 3)
+
+    batch = next(iter(dl))
+
+    loss = meta_controller.policy_loss(
+        batch['states'],
+        batch['log_probs'],
+        batch['latent_actions'],
+        batch['advantages'],
+        batch['switch_betas'] == 1.,
+        episode_lens = batch['_lens']
     )
 
     loss.backward()
@@ -167,3 +211,5 @@ def test_metacontroller(
 
     Path('./meta_controller.pt').unlink()
     Path('./trained.pt').unlink()
+
+    rmtree(test_folder, ignore_errors = True)
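In the updated test, the simulated GRPO update no longer calls `policy_loss` directly on in-memory tensors; episodes are first written to a memory-mapped `ReplayBuffer`, then read back through its dataloader and fed to `meta_controller.policy_loss`. A condensed sketch of that round trip, with toy shapes, episode length, and folder name chosen purely for illustration (the constructor and method calls mirror the test, the field layout mirrors its `MetaController` branch with `switch_per_latent_dim = True`):

    import torch
    from memmap_replay_buffer import ReplayBuffer

    buffer = ReplayBuffer(
        './demo-grpo-buffer',                      # illustrative folder name
        max_episodes = 3,
        max_timesteps = 256,
        fields = dict(
            states = ('float', 512),
            log_probs = ('float', 128),
            switch_betas = ('float', 128),
            latent_actions = ('float', 128)
        ),
        meta_fields = dict(advantages = 'float')   # one scalar advantage per episode
    )

    for _ in range(3):
        T = 16                                     # toy episode length
        buffer.store_episode(
            states = torch.randn(T, 512),
            log_probs = torch.randn(T, 128),
            switch_betas = torch.ones(T, 128),
            latent_actions = torch.randn(T, 128),
            advantages = torch.tensor(0.5)
        )

    batch = next(iter(buffer.dataloader(batch_size = 3)))

    # the test then does, with a MetaController instance from the discovery phase:
    # loss = meta_controller.policy_loss(
    #     batch['states'], batch['log_probs'], batch['latent_actions'],
    #     batch['advantages'], batch['switch_betas'] == 1.,
    #     episode_lens = batch['_lens'])
    # loss.backward()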
train_behavior_clone_babyai.py

@@ -25,29 +25,35 @@ from accelerate import Accelerator
 from memmap_replay_buffer import ReplayBuffer
 from einops import rearrange
 
-from metacontroller.metacontroller import Transformer
+from metacontroller.metacontroller import Transformer, MetaController
 from metacontroller.transformer_with_resnet import TransformerWithResnet
 
 import minigrid
 import gymnasium as gym
 
 def train(
-    input_dir: str = "babyai-minibosslevel-trajectories",
-    env_id: str = "BabyAI-MiniBossLevel-v0",
-    cloning_epochs: int = 10,
-    discovery_epochs: int = 10,
-    batch_size: int = 32,
-    lr: float = 1e-4,
-    dim: int = 512,
-    depth: int = 2,
-    heads: int = 8,
-    dim_head: int = 64,
-    use_wandb: bool = False,
-    wandb_project: str = "metacontroller-babyai-bc",
-    checkpoint_path: str = "transformer_bc.pt",
-    state_loss_weight: float = 1.,
-    action_loss_weight: float = 1.,
-    use_resnet: bool = False
+    input_dir = "babyai-minibosslevel-trajectories",
+    env_id = "BabyAI-MiniBossLevel-v0",
+    cloning_epochs = 10,
+    discovery_epochs = 10,
+    batch_size = 32,
+    lr = 1e-4,
+    discovery_lr = 1e-4,
+    dim = 512,
+    depth = 2,
+    heads = 8,
+    dim_head = 64,
+    use_wandb = False,
+    wandb_project = "metacontroller-babyai-bc",
+    checkpoint_path = "transformer_bc.pt",
+    meta_controller_checkpoint_path = "meta_controller_discovery.pt",
+    state_loss_weight = 1.,
+    action_loss_weight = 1.,
+    discovery_action_recon_loss_weight = 1.,
+    discovery_kl_loss_weight = 1.,
+    discovery_switch_loss_weight = 1.,
+    max_grad_norm = 1.,
+    use_resnet = False
 ):
     # accelerator
 
@@ -96,6 +102,7 @@ def train(
     # transformer
 
     transformer_class = TransformerWithResnet if use_resnet else Transformer
+
    model = transformer_class(
        dim = dim,
        state_embed_readout = dict(num_continuous = state_dim),
@@ -104,23 +111,34 @@ def train(
         upper_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head)
     )
 
+    # meta controller
+
+    meta_controller = MetaController(dim)
+
     # optimizer
 
-    optim = Adam(model.parameters(), lr = lr)
+    optim_model = Adam(model.parameters(), lr = lr)
+
+    optim_meta_controller = Adam(meta_controller.discovery_parameters(), lr = discovery_lr)
 
     # prepare
 
-    model, optim, dataloader = accelerator.prepare(model, optim, dataloader)
+    model, optim_model, optim_meta_controller, dataloader = accelerator.prepare(model, optim_model, optim_meta_controller, dataloader)
 
     # training
+
     for epoch in range(cloning_epochs + discovery_epochs):
+
         model.train()
-        total_state_loss = 0.
-        total_action_loss = 0.
+        from collections import defaultdict
+        total_losses = defaultdict(float)
 
         progress_bar = tqdm(dataloader, desc = f"Epoch {epoch}", disable = not accelerator.is_local_main_process)
+
         is_discovering = (epoch >= cloning_epochs) # discovery phase is BC with metacontroller tuning
 
+        optim = optim_model if not is_discovering else optim_meta_controller
+
         for batch in progress_bar:
             # batch is a NamedTuple (e.g. MemoryMappedBatch)
             # state: (B, T, 7, 7, 3), action: (B, T)
@@ -130,51 +148,86 @@ def train(
             episode_lens = batch.get('_lens')
 
             # use resnet18 to embed visual observations
+
             if use_resnet:
                 states = model.visual_encode(states)
             else: # flatten state: (B, T, 7, 7, 3) -> (B, T, 147)
                 states = rearrange(states, 'b t ... -> b t (...)')
 
             with accelerator.accumulate(model):
-                state_loss, action_loss = model(states, actions, episode_lens = episode_lens, discovery_phase=is_discovering)
-                loss = state_loss * state_loss_weight + action_loss * action_loss_weight
+                losses = model(
+                    states,
+                    actions,
+                    episode_lens = episode_lens,
+                    discovery_phase = is_discovering,
+                    meta_controller = meta_controller if is_discovering else None
+                )
+
+                if is_discovering:
+                    action_recon_loss, kl_loss, switch_loss = losses
+
+                    loss = (
+                        action_recon_loss * discovery_action_recon_loss_weight +
+                        kl_loss * discovery_kl_loss_weight +
+                        switch_loss * discovery_switch_loss_weight
+                    )
+
+                    log = dict(
+                        action_recon_loss = action_recon_loss.item(),
+                        kl_loss = kl_loss.item(),
+                        switch_loss = switch_loss.item()
+                    )
+                else:
+                    state_loss, action_loss = losses
+
+                    loss = (
+                        state_loss * state_loss_weight +
+                        action_loss * action_loss_weight
+                    )
+
+                    log = dict(
+                        state_loss = state_loss.item(),
+                        action_loss = action_loss.item(),
+                    )
+
+                # backprop
 
                 accelerator.backward(loss)
 
-                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm = max_grad_norm)
 
                 optim.step()
                 optim.zero_grad()
 
             # log
-
-            total_state_loss += state_loss.item()
-            total_action_loss += action_loss.item()
+
+            for key, value in log.items():
+                total_losses[key] += value
 
             accelerator.log({
-                "state_loss": state_loss.item(),
-                "action_loss": action_loss.item(),
+                **log,
                 "total_loss": loss.item(),
                 "grad_norm": grad_norm.item()
             })
 
-            progress_bar.set_postfix(
-                state_loss = state_loss.item(),
-                action_loss = action_loss.item()
-            )
-
-        avg_state_loss = total_state_loss / len(dataloader)
-        avg_action_loss = total_action_loss / len(dataloader)
+            progress_bar.set_postfix(**log)
 
-        accelerator.print(f"Epoch {epoch}: state_loss={avg_state_loss:.4f}, action_loss={avg_action_loss:.4f}")
+        avg_losses = {k: v / len(dataloader) for k, v in total_losses.items()}
+        avg_losses_str = ", ".join([f"{k}={v:.4f}" for k, v in avg_losses.items()])
+        accelerator.print(f"Epoch {epoch}: {avg_losses_str}")
 
     # save weights
 
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+
         unwrapped_model = accelerator.unwrap_model(model)
         unwrapped_model.save(checkpoint_path)
-        accelerator.print(f"Model saved to {checkpoint_path}")
+
+        unwrapped_meta_controller = accelerator.unwrap_model(meta_controller)
+        unwrapped_meta_controller.save(meta_controller_checkpoint_path)
+
+        accelerator.print(f"Model saved to {checkpoint_path}, MetaController to {meta_controller_checkpoint_path}")
 
     accelerator.end_training()
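The rewritten training script runs behavior cloning for the first `cloning_epochs` epochs, then switches into the discovery phase, stepping a different optimizer (`meta_controller.discovery_parameters()` instead of `model.parameters()`) over the same dataloader. A stripped-down sketch of this phase-switching pattern, with hypothetical stand-in modules and losses:

    import torch
    from torch import nn
    from torch.optim import Adam

    # stand-ins for the transformer and the meta controller
    model = nn.Linear(8, 8)
    meta_controller = nn.Linear(8, 8)

    cloning_epochs, discovery_epochs = 2, 2

    optim_model = Adam(model.parameters(), lr = 1e-4)
    optim_meta_controller = Adam(meta_controller.parameters(), lr = 1e-4)

    for epoch in range(cloning_epochs + discovery_epochs):
        is_discovering = epoch >= cloning_epochs

        # pick which optimizer steps this epoch, mirroring the diff
        optim = optim_meta_controller if is_discovering else optim_model

        for _ in range(4):                     # stand-in for the dataloader
            x = torch.randn(16, 8)

            if is_discovering:
                # discovery: meta controller trained, base model frozen
                with torch.no_grad():
                    h = model(x)
                loss = meta_controller(h).pow(2).mean()
            else:
                # behavior cloning: base model trained
                loss = model(x).pow(2).mean()

            loss.backward()
            optim.step()
            optim.zero_grad()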