metacontroller-pytorch 0.0.44__tar.gz → 0.0.46__tar.gz

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published in their public registries.
Files changed (18)
  1. {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/PKG-INFO +1 -1
  2. {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/metacontroller/metacontroller.py +27 -3
  3. {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/pyproject.toml +1 -1
  4. {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/train_babyai.py +68 -42
  5. {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/train_behavior_clone_babyai.py +7 -3
  6. {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/.github/workflows/python-publish.yml +0 -0
  7. {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/.github/workflows/test.yml +0 -0
  8. {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/.gitignore +0 -0
  9. {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/LICENSE +0 -0
  10. {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/README.md +0 -0
  11. {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/babyai_env.py +0 -0
  12. {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/fig1.png +0 -0
  13. {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/gather_babyai_trajs.py +0 -0
  14. {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/metacontroller/__init__.py +0 -0
  15. {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/metacontroller/metacontroller_with_binary_mapper.py +0 -0
  16. {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/metacontroller/transformer_with_resnet.py +0 -0
  17. {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/tests/test_metacontroller.py +0 -0
  18. {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/train_babyai_evo_strat.py +0 -0
--- metacontroller_pytorch-0.0.44/PKG-INFO
+++ metacontroller_pytorch-0.0.46/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.44
+Version: 0.0.46
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
--- metacontroller_pytorch-0.0.44/metacontroller/metacontroller.py
+++ metacontroller_pytorch-0.0.46/metacontroller/metacontroller.py
@@ -54,6 +54,19 @@ def default(*args):
 def straight_through(src, tgt):
     return tgt + src - src.detach()
 
+# losses
+
+BehavioralCloningLosses = namedtuple('BehavioralCloningLosses', (
+    'state',
+    'action'
+))
+
+DiscoveryLosses = namedtuple('DiscoveryLosses', (
+    'action_recon',
+    'kl',
+    'switch'
+))
+
 # meta controller
 
 MetaControllerOutput = namedtuple('MetaControllerOutput', (
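The two loss containers added above are plain namedtuples, so callers that unpacked the old bare-tuple returns positionally keep working, while new code can read the terms by name. A minimal standalone sketch (the tensor values are placeholders, not outputs of the package):

from collections import namedtuple
import torch

# same field layout as the DiscoveryLosses added in this release
DiscoveryLosses = namedtuple('DiscoveryLosses', ('action_recon', 'kl', 'switch'))

losses = DiscoveryLosses(torch.tensor(0.5), torch.tensor(0.1), torch.tensor(0.2))

# read the loss terms by name ...
total_loss = losses.action_recon + losses.kl + losses.switch

# ... or unpack positionally, exactly as pre-0.0.46 callers did with the bare tuple
action_recon_loss, kl_loss, switch_loss = losses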
@@ -450,7 +463,8 @@ class Transformer(Module):
         return_raw_action_dist = False,
         return_latents = False,
         return_cache = False,
-        episode_lens: Tensor | None = None
+        episode_lens: Tensor | None = None,
+        return_meta_controller_output = False
     ):
         device = state.device
 
@@ -544,13 +558,23 @@ class Transformer(Module):
 
             action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions, mask = loss_mask)
 
-            return state_clone_loss, action_clone_loss
+            losses = BehavioralCloningLosses(state_clone_loss, action_clone_loss)
+
+            if not return_meta_controller_output:
+                return losses
+
+            return losses, next_meta_hiddens
 
         elif discovery_phase:
 
             action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)
 
-            return action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss
+            losses = DiscoveryLosses(action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss)
+
+            if not return_meta_controller_output:
+                return losses
+
+            return losses, next_meta_hiddens
 
         # returning
 
--- metacontroller_pytorch-0.0.44/pyproject.toml
+++ metacontroller_pytorch-0.0.46/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.44"
+version = "0.0.46"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
--- metacontroller_pytorch-0.0.44/train_babyai.py
+++ metacontroller_pytorch-0.0.46/train_babyai.py
@@ -18,35 +18,46 @@ from shutil import rmtree
 from tqdm import tqdm
 
 import torch
-from torch import cat, tensor, stack
+from torch import cat, tensor, stack, Tensor
 from torch.optim import Adam
 
 from einops import rearrange
 
+from torch_einops_utils import pad_sequence
+
 from accelerate import Accelerator
 
 from babyai_env import create_env
+
 from memmap_replay_buffer import ReplayBuffer
+
 from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score, extract_grpo_data
 from metacontroller.transformer_with_resnet import TransformerWithResnet
 
 # research entry point
 
 def reward_shaping_fn(
-    cumulative_rewards: torch.Tensor,
-    all_rewards: torch.Tensor,
-    episode_lens: torch.Tensor
-) -> torch.Tensor | None:
+    cumulative_rewards: Tensor,   # float(num_episodes,)
+    all_rewards: Tensor,          # float(num_episodes, max_timesteps)
+    episode_lens: Tensor,         # int(num_episodes,)
+    reject_threshold_cumulative_reward_variance: float = 0.
+) -> Tensor | None:
     """
     researchers can modify this function to engineer rewards
     or return None to reject the entire batch
-
-    cumulative_rewards: (num_episodes,)
-    all_rewards: (num_episodes, max_timesteps)
-    episode_lens: (num_episodes,)
     """
+
+    if cumulative_rewards.var() < reject_threshold_cumulative_reward_variance:
+        return None
+
     return cumulative_rewards
 
+def should_reject_group_based_on_switch_betas(
+    switch_betas: Tensor,
+    episode_lens: Tensor
+):
+    return switch_betas.sum().item() == 0.
+
 # helpers
 
 def exists(v):
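The docstring invites researchers to swap in their own reward_shaping_fn. Below is a hedged sketch of one possible variant that keeps the default low-variance group rejection and additionally normalizes each return by episode length; the normalization and the function name are illustrative, not part of the package:

import torch
from torch import Tensor

def length_normalized_reward_shaping_fn(
    cumulative_rewards: Tensor,   # float (num_episodes,)
    all_rewards: Tensor,          # float (num_episodes, max_timesteps)
    episode_lens: Tensor,         # int (num_episodes,)
    reject_threshold_cumulative_reward_variance: float = 0.
) -> Tensor | None:
    # same group rejection rule as the default implementation above
    if cumulative_rewards.var() < reject_threshold_cumulative_reward_variance:
        return None

    # illustrative shaping: favor episodes that reach the same return in fewer steps
    return cumulative_rewards / episode_lens.clamp(min = 1).float()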
@@ -69,12 +80,14 @@ def main(
     meta_controller_weights_path: str | None = None,
     output_meta_controller_path = 'metacontroller_rl_trained.pt',
     use_resnet = False,
+    num_epochs = 3,
     lr = 1e-4,
     batch_size = 16,
     num_groups = 16,
     max_grad_norm = 1.0,
     use_wandb = False,
-    wandb_project = 'metacontroller-babyai-rl'
+    wandb_project = 'metacontroller-babyai-rl',
+    reject_threshold_cumulative_reward_variance = 0.
 ):
     # accelerator
 
@@ -234,25 +247,37 @@ def main(
         cumulative_rewards = stack(all_cumulative_rewards)
         episode_lens = tensor(all_episode_lens)
 
-        # pad step rewards
-
         max_len = max(all_episode_lens)
-        padded_step_rewards = torch.zeros(num_episodes, max_len)
 
-        for i, (rewards, length) in enumerate(zip(all_step_rewards, all_episode_lens)):
-            padded_step_rewards[i, :length] = rewards
+        # pad step rewards
+
+        padded_step_rewards = pad_sequence(all_step_rewards, dim = 0)
 
         # reward shaping hook
 
-        shaped_rewards = reward_shaping_fn(cumulative_rewards, padded_step_rewards, episode_lens)
+        shaped_rewards = reward_shaping_fn(
+            cumulative_rewards,
+            padded_step_rewards,
+            episode_lens,
+            reject_threshold_cumulative_reward_variance = reject_threshold_cumulative_reward_variance
+        )
 
         if not exists(shaped_rewards):
+            accelerator.print(f'group rejected - variance of {cumulative_rewards.var().item():.4f} is lower than threshold of {reject_threshold_cumulative_reward_variance}')
             continue
 
         group_advantages = z_score(shaped_rewards)
 
         group_states, group_log_probs, group_switch_betas, group_latent_actions = zip(*all_episodes)
 
+        # whether to reject group based on switch betas (as it determines the mask for learning)
+
+        padded_group_switch_betas, episode_lens = pad_sequence(group_switch_betas, dim = 0, return_lens = True)
+
+        if should_reject_group_based_on_switch_betas(padded_group_switch_betas, episode_lens):
+            accelerator.print(f'group rejected - switch betas for the entire group does not meet criteria for learning')
+            continue
+
         for states, log_probs, switch_betas, latent_actions, advantages in zip(group_states, group_log_probs, group_switch_betas, group_latent_actions, group_advantages):
             replay_buffer.store_episode(
                 states = states,
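pad_sequence from torch_einops_utils replaces the manual zero-padding loop removed above. Assuming it behaves like that removed code, i.e. stacks variable-length 1-D tensors into a zero-padded (num_episodes, max_len) tensor and, with return_lens = True, also hands back the lengths, a plain-torch stand-in would look roughly like the sketch below (pad_and_stack is a hypothetical helper, not the library function):

import torch
from torch import Tensor

def pad_and_stack(seqs: list[Tensor]) -> tuple[Tensor, Tensor]:
    # mirrors the removed torch.zeros + slice-assign loop
    lens = torch.tensor([len(seq) for seq in seqs])
    padded = torch.zeros(len(seqs), int(lens.max()), dtype = seqs[0].dtype)

    for i, (seq, length) in enumerate(zip(seqs, lens.tolist())):
        padded[i, :length] = seq

    return padded, lens

# usage on toy per-step rewards of different episode lengths
padded, lens = pad_and_stack([torch.ones(3), torch.ones(5)])
assert padded.shape == (2, 5) and lens.tolist() == [3, 5]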
@@ -265,42 +290,43 @@ def main(
         # learn
 
         if len(replay_buffer) >= buffer_size:
-            dl = replay_buffer.dataloader(batch_size = batch_size)
+            dl = replay_buffer.dataloader(batch_size = batch_size, shuffle = True)
             dl = accelerator.prepare(dl)
 
             meta_controller.train()
 
-            batch = next(iter(dl))
-
-            loss = meta_controller.policy_loss(
-                batch['states'],
-                batch['log_probs'],
-                batch['latent_actions'],
-                batch['advantages'],
-                batch['switch_betas'] == 1.,
-                episode_lens = batch['_lens']
-            )
+            for epoch in range(num_epochs):
+                for batch in dl:
+                    loss = meta_controller.policy_loss(
+                        batch['states'],
+                        batch['log_probs'],
+                        batch['latent_actions'],
+                        batch['advantages'],
+                        batch['switch_betas'] == 1.,
+                        episode_lens = batch['_lens']
+                    )
 
-            accelerator.backward(loss)
+                    accelerator.backward(loss)
 
-            grad_norm = accelerator.clip_grad_norm_(meta_controller.parameters(), max_grad_norm)
+                    grad_norm = accelerator.clip_grad_norm_(meta_controller.parameters(), max_grad_norm)
 
-            optim.step()
-            optim.zero_grad()
+                    optim.step()
+                    optim.zero_grad()
 
-            meta_controller.eval()
+                    pbar.set_postfix(
+                        epoch = epoch,
+                        loss = f'{loss.item():.4f}',
+                        grad_norm = f'{grad_norm.item():.4f}',
+                        reward = f'{cumulative_rewards.mean().item():.4f}'
+                    )
 
-            pbar.set_postfix(
-                loss = f'{loss.item():.4f}',
-                grad_norm = f'{grad_norm.item():.4f}',
-                reward = f'{cumulative_rewards.mean().item():.4f}'
-            )
+                    accelerator.log({
+                        'loss': loss.item(),
+                        'grad_norm': grad_norm.item(),
+                        'reward': cumulative_rewards.mean().item()
+                    })
 
-            accelerator.log({
-                'loss': loss.item(),
-                'grad_norm': grad_norm.item(),
-                'reward': cumulative_rewards.mean().item()
-            })
+            meta_controller.eval()
 
             accelerator.print(f'loss: {loss.item():.4f}, grad_norm: {grad_norm.item():.4f}, reward: {cumulative_rewards.mean().item():.4f}')
 
--- metacontroller_pytorch-0.0.44/train_behavior_clone_babyai.py
+++ metacontroller_pytorch-0.0.46/train_behavior_clone_babyai.py
@@ -110,11 +110,13 @@ def train(
 
     # state shape and action dimension
     # state: (B, T, H, W, C) or (B, T, D)
+
    state_shape = replay_buffer.shapes['state']
    if use_resnet: state_dim = 256
    else: state_dim = int(torch.tensor(state_shape).prod().item())
 
    # deduce num_actions from the environment
+
    from babyai_env import create_env
    temp_env = create_env(env_id)
    num_actions = int(temp_env.action_space.n)
@@ -181,12 +183,13 @@ def train(
 
 
         with accelerator.accumulate(model):
-            losses = model(
+            losses, meta_controller_output = model(
                 states,
                 actions,
                 episode_lens = episode_lens,
                 discovery_phase = is_discovering,
-                force_behavior_cloning = not is_discovering
+                force_behavior_cloning = not is_discovering,
+                return_meta_controller_output = True
             )
 
             if is_discovering:
@@ -201,7 +204,8 @@ def train(
                 log = dict(
                     action_recon_loss = action_recon_loss.item(),
                     kl_loss = kl_loss.item(),
-                    switch_loss = switch_loss.item()
+                    switch_loss = switch_loss.item(),
+                    switch_density = meta_controller_output.switch_beta.mean().item()
                 )
             else:
                 state_loss, action_loss = losses
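The new switch_density entry is the reason the model call above now requests the meta controller output. Assuming switch_beta is a per-timestep gate in [0, 1] (the RL script masks on switch_betas == 1.), its mean is simply the fraction of timesteps on which the metacontroller switched; a toy sketch with made-up gate values:

import torch

# hypothetical per-timestep switch gates for a batch of 2 episodes, 4 steps each
switch_beta = torch.tensor([
    [1., 0., 0., 1.],
    [0., 0., 1., 0.]
])

switch_density = switch_beta.mean().item()   # 3 switches / 8 timesteps = 0.375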