metacontroller-pytorch 0.0.44__tar.gz → 0.0.46__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/PKG-INFO +1 -1
- {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/metacontroller/metacontroller.py +27 -3
- {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/pyproject.toml +1 -1
- {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/train_babyai.py +68 -42
- {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/train_behavior_clone_babyai.py +7 -3
- {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/README.md +0 -0
- {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/babyai_env.py +0 -0
- {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/gather_babyai_trajs.py +0 -0
- {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/metacontroller/metacontroller_with_binary_mapper.py +0 -0
- {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/metacontroller/transformer_with_resnet.py +0 -0
- {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/tests/test_metacontroller.py +0 -0
- {metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/train_babyai_evo_strat.py +0 -0
{metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/metacontroller/metacontroller.py
RENAMED
@@ -54,6 +54,19 @@ def default(*args):
 def straight_through(src, tgt):
     return tgt + src - src.detach()

+# losses
+
+BehavioralCloningLosses = namedtuple('BehavioralCloningLosses', (
+    'state',
+    'action'
+))
+
+DiscoveryLosses = namedtuple('DiscoveryLosses', (
+    'action_recon',
+    'kl',
+    'switch'
+))
+
 # meta controller

 MetaControllerOutput = namedtuple('MetaControllerOutput', (
@@ -450,7 +463,8 @@ class Transformer(Module):
         return_raw_action_dist = False,
         return_latents = False,
         return_cache = False,
-        episode_lens: Tensor | None = None
+        episode_lens: Tensor | None = None,
+        return_meta_controller_output = False
     ):
         device = state.device

@@ -544,13 +558,23 @@ class Transformer(Module):

             action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions, mask = loss_mask)

-
+            losses = BehavioralCloningLosses(state_clone_loss, action_clone_loss)
+
+            if not return_meta_controller_output:
+                return losses
+
+            return losses, next_meta_hiddens

         elif discovery_phase:

             action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)

-
+            losses = DiscoveryLosses(action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss)
+
+            if not return_meta_controller_output:
+                return losses
+
+            return losses, next_meta_hiddens

         # returning

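The practical effect of the metacontroller.py changes is that the forward pass now returns its losses as a named tuple (BehavioralCloningLosses or DiscoveryLosses), optionally alongside the meta controller output when return_meta_controller_output = True. A minimal sketch of how a caller might reduce and log such a tuple; the loss weights and dummy values are illustrative assumptions, not values from the package:

from collections import namedtuple
import torch

# mirrors the namedtuple added in 0.0.46
DiscoveryLosses = namedtuple('DiscoveryLosses', ('action_recon', 'kl', 'switch'))

# stand-in values - in real use these come from Transformer.forward(..., discovery_phase = True)
losses = DiscoveryLosses(torch.tensor(1.2), torch.tensor(0.3), torch.tensor(0.05))

kl_weight, switch_weight = 0.1, 1.0   # hypothetical weights, for illustration only

total_loss = losses.action_recon + kl_weight * losses.kl + switch_weight * losses.switch

# named fields make logging explicit compared to a bare tuple
log = {name: value.item() for name, value in zip(losses._fields, losses)}
print(total_loss.item(), log)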
{metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/train_babyai.py
RENAMED
@@ -18,35 +18,46 @@ from shutil import rmtree
 from tqdm import tqdm

 import torch
-from torch import cat, tensor, stack
+from torch import cat, tensor, stack, Tensor
 from torch.optim import Adam

 from einops import rearrange

+from torch_einops_utils import pad_sequence
+
 from accelerate import Accelerator

 from babyai_env import create_env
+
 from memmap_replay_buffer import ReplayBuffer
+
 from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score, extract_grpo_data
 from metacontroller.transformer_with_resnet import TransformerWithResnet

 # research entry point

 def reward_shaping_fn(
-    cumulative_rewards:
-    all_rewards:
-    episode_lens:
-
+    cumulative_rewards: Tensor,                               # float(num_episodes,)
+    all_rewards: Tensor,                                      # float(num_episodes, max_timesteps)
+    episode_lens: Tensor,                                     # int(num_episodes,)
+    reject_threshold_cumulative_reward_variance: float = 0.
+) -> Tensor | None:
     """
     researchers can modify this function to engineer rewards
     or return None to reject the entire batch
-
-    cumulative_rewards: (num_episodes,)
-    all_rewards: (num_episodes, max_timesteps)
-    episode_lens: (num_episodes,)
     """
+
+    if cumulative_rewards.var() < reject_threshold_cumulative_reward_variance:
+        return None
+
     return cumulative_rewards

+def should_reject_group_based_on_switch_betas(
+    switch_betas: Tensor,
+    episode_lens: Tensor
+):
+    return switch_betas.sum().item() == 0.
+
 # helpers

 def exists(v):
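Since reward_shaping_fn is documented as the research entry point, a custom hook only has to respect the signature above and may return None to reject the whole group. A sketch of one possible variant, assuming the same tensor shapes; the function name and the length-based discount are purely illustrative:

import torch
from torch import Tensor

def length_discounted_reward_shaping_fn(
    cumulative_rewards: Tensor,     # float (num_episodes,)
    all_rewards: Tensor,            # float (num_episodes, max_timesteps)
    episode_lens: Tensor,           # int (num_episodes,)
    reject_threshold_cumulative_reward_variance: float = 0.
) -> Tensor | None:
    # reject the whole group when returns barely differ - no learning signal after z-scoring
    if cumulative_rewards.var() < reject_threshold_cumulative_reward_variance:
        return None

    # favor shorter successful episodes by discounting with the square root of episode length
    return cumulative_rewards / episode_lens.float().clamp(min = 1.).sqrt()

# toy check with made-up returns and lengths
print(length_discounted_reward_shaping_fn(
    torch.tensor([1.0, 0.0, 0.5]),
    torch.zeros(3, 10),
    torch.tensor([4, 10, 7])
))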
@@ -69,12 +80,14 @@ def main(
     meta_controller_weights_path: str | None = None,
     output_meta_controller_path = 'metacontroller_rl_trained.pt',
     use_resnet = False,
+    num_epochs = 3,
     lr = 1e-4,
     batch_size = 16,
     num_groups = 16,
     max_grad_norm = 1.0,
     use_wandb = False,
-    wandb_project = 'metacontroller-babyai-rl'
+    wandb_project = 'metacontroller-babyai-rl',
+    reject_threshold_cumulative_reward_variance = 0.
 ):
     # accelerator

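Note that with the default reject_threshold_cumulative_reward_variance = 0., no group is ever rejected on variance, since variance is non-negative and the check uses a strict less-than; the threshold only starts filtering once it is raised. A quick illustration (the 1e-6 threshold is an arbitrary example value):

import torch

identical_returns = torch.tensor([0.5, 0.5, 0.5, 0.5])

print(identical_returns.var() < 0.)     # tensor(False) -> group kept under the default threshold
print(identical_returns.var() < 1e-6)   # tensor(True)  -> group rejected once the threshold is raised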
@@ -234,25 +247,37 @@ def main(
         cumulative_rewards = stack(all_cumulative_rewards)
         episode_lens = tensor(all_episode_lens)

-        # pad step rewards
-
         max_len = max(all_episode_lens)
-        padded_step_rewards = torch.zeros(num_episodes, max_len)

-
-
+        # pad step rewards
+
+        padded_step_rewards = pad_sequence(all_step_rewards, dim = 0)

         # reward shaping hook

-        shaped_rewards = reward_shaping_fn(
+        shaped_rewards = reward_shaping_fn(
+            cumulative_rewards,
+            padded_step_rewards,
+            episode_lens,
+            reject_threshold_cumulative_reward_variance = reject_threshold_cumulative_reward_variance
+        )

         if not exists(shaped_rewards):
+            accelerator.print(f'group rejected - variance of {cumulative_rewards.var().item():.4f} is lower than threshold of {reject_threshold_cumulative_reward_variance}')
             continue

         group_advantages = z_score(shaped_rewards)

         group_states, group_log_probs, group_switch_betas, group_latent_actions = zip(*all_episodes)

+        # whether to reject group based on switch betas (as it determines the mask for learning)
+
+        padded_group_switch_betas, episode_lens = pad_sequence(group_switch_betas, dim = 0, return_lens = True)
+
+        if should_reject_group_based_on_switch_betas(padded_group_switch_betas, episode_lens):
+            accelerator.print(f'group rejected - switch betas for the entire group does not meet criteria for learning')
+            continue
+
         for states, log_probs, switch_betas, latent_actions, advantages in zip(group_states, group_log_probs, group_switch_betas, group_latent_actions, group_advantages):
             replay_buffer.store_episode(
                 states = states,
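Two of the steps above lend themselves to a small standalone check: z-scoring the shaped rewards into group advantages, and rejecting a group whose padded switch betas never fire. The sketch below uses torch.nn.utils.rnn.pad_sequence and a local z_score as stand-ins; torch_einops_utils.pad_sequence and the z_score imported from the package are assumed to behave analogously:

import torch
from torch.nn.utils.rnn import pad_sequence

def z_score(rewards, eps = 1e-6):
    # stand-in for the package's z_score, assumed to standardize over the group dimension
    return (rewards - rewards.mean()) / (rewards.std() + eps)

shaped_rewards = torch.tensor([0.9, 0.1, 0.4, 0.4])    # one shaped reward per episode in the group
group_advantages = z_score(shaped_rewards)
print(group_advantages, group_advantages.mean())       # mean ~0 by construction

# variable-length switch betas, 1. where the meta controller switched
group_switch_betas = [
    torch.tensor([0., 0., 1., 0.]),
    torch.tensor([0., 0.]),
    torch.tensor([1., 0., 0.]),
]

episode_lens = torch.tensor([len(b) for b in group_switch_betas])
padded = pad_sequence(group_switch_betas, batch_first = True)   # (num_episodes, max_len)

# mirrors should_reject_group_based_on_switch_betas: reject only if no switch fired anywhere
print(padded.shape, episode_lens, padded.sum().item() == 0.)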
@@ -265,42 +290,43 @@ def main(
         # learn

         if len(replay_buffer) >= buffer_size:
-            dl = replay_buffer.dataloader(batch_size = batch_size)
+            dl = replay_buffer.dataloader(batch_size = batch_size, shuffle = True)
             dl = accelerator.prepare(dl)

             meta_controller.train()

-
-
-
-
-
-
-
-
-
-
+            for epoch in range(num_epochs):
+                for batch in dl:
+                    loss = meta_controller.policy_loss(
+                        batch['states'],
+                        batch['log_probs'],
+                        batch['latent_actions'],
+                        batch['advantages'],
+                        batch['switch_betas'] == 1.,
+                        episode_lens = batch['_lens']
+                    )

-
+                    accelerator.backward(loss)

-
+                    grad_norm = accelerator.clip_grad_norm_(meta_controller.parameters(), max_grad_norm)

-
-
+                    optim.step()
+                    optim.zero_grad()

-
+                    pbar.set_postfix(
+                        epoch = epoch,
+                        loss = f'{loss.item():.4f}',
+                        grad_norm = f'{grad_norm.item():.4f}',
+                        reward = f'{cumulative_rewards.mean().item():.4f}'
+                    )

-
-
-
-
-
+                    accelerator.log({
+                        'loss': loss.item(),
+                        'grad_norm': grad_norm.item(),
+                        'reward': cumulative_rewards.mean().item()
+                    })

-
-                'loss': loss.item(),
-                'grad_norm': grad_norm.item(),
-                'reward': cumulative_rewards.mean().item()
-            })
+            meta_controller.eval()

         accelerator.print(f'loss: {loss.item():.4f}, grad_norm: {grad_norm.item():.4f}, reward: {cumulative_rewards.mean().item():.4f}')

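The rewritten learn loop logs the gradient norm returned by accelerator.clip_grad_norm_. The snippet below shows the same clip-then-log pattern with plain PyTorch; accelerator.clip_grad_norm_ is assumed to behave like torch.nn.utils.clip_grad_norm_ and return the pre-clipping total norm:

import torch
from torch import nn

# toy model and loss, only to produce some gradients
model = nn.Linear(8, 4)
loss = model(torch.randn(2, 8)).sum()
loss.backward()

max_grad_norm = 1.0
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

print(f'grad_norm: {grad_norm.item():.4f}')   # the value surfaced via pbar.set_postfix / accelerator.log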
{metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/train_behavior_clone_babyai.py
RENAMED
@@ -110,11 +110,13 @@ def train(

     # state shape and action dimension
     # state: (B, T, H, W, C) or (B, T, D)
+
     state_shape = replay_buffer.shapes['state']
     if use_resnet: state_dim = 256
     else: state_dim = int(torch.tensor(state_shape).prod().item())

     # deduce num_actions from the environment
+
     from babyai_env import create_env
     temp_env = create_env(env_id)
     num_actions = int(temp_env.action_space.n)
@@ -181,12 +183,13 @@ def train(


         with accelerator.accumulate(model):
-            losses = model(
+            losses, meta_controller_output = model(
                 states,
                 actions,
                 episode_lens = episode_lens,
                 discovery_phase = is_discovering,
-                force_behavior_cloning = not is_discovering
+                force_behavior_cloning = not is_discovering,
+                return_meta_controller_output = True
             )

             if is_discovering:
@@ -201,7 +204,8 @@ def train(
                 log = dict(
                     action_recon_loss = action_recon_loss.item(),
                     kl_loss = kl_loss.item(),
-                    switch_loss = switch_loss.item()
+                    switch_loss = switch_loss.item(),
+                    switch_density = meta_controller_output.switch_beta.mean().item()
                 )
             else:
                 state_loss, action_loss = losses
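The extra switch_density metric added to the discovery-phase log is just the mean of the meta controller's switch gate over the batch. Assuming switch_beta is a (batch, time) tensor of 0/1 switch decisions (its exact shape is not shown in this diff), the metric reads as the fraction of timesteps on which a switch occurred:

import torch

switch_beta = torch.tensor([
    [0., 0., 1., 0., 0.],
    [1., 0., 0., 0., 1.],
])

switch_density = switch_beta.mean().item()
print(switch_density)   # 0.3 -> a switch on 3 of 10 timesteps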
{metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/.github/workflows/python-publish.yml
RENAMED
File without changes
{metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/.github/workflows/test.yml
RENAMED
File without changes
{metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/.gitignore
RENAMED
File without changes
{metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/LICENSE
RENAMED
File without changes
{metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/README.md
RENAMED
File without changes
{metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/babyai_env.py
RENAMED
File without changes
{metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/fig1.png
RENAMED
File without changes
{metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/gather_babyai_trajs.py
RENAMED
File without changes
{metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/metacontroller/__init__.py
RENAMED
File without changes
{metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/metacontroller/metacontroller_with_binary_mapper.py
RENAMED
File without changes
{metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/metacontroller/transformer_with_resnet.py
RENAMED
File without changes
{metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/tests/test_metacontroller.py
RENAMED
File without changes
{metacontroller_pytorch-0.0.44 → metacontroller_pytorch-0.0.46}/train_babyai_evo_strat.py
RENAMED
File without changes