metacontroller-pytorch 0.0.28__tar.gz → 0.0.30__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (18)
  1. {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.30}/PKG-INFO +5 -1
  2. {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.30}/README.md +4 -0
  3. {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.30}/gather_babyai_trajs.py +43 -2
  4. {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.30}/metacontroller/metacontroller.py +2 -1
  5. {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.30}/metacontroller/metacontroller_with_binary_mapper.py +2 -1
  6. metacontroller_pytorch-0.0.30/metacontroller/metacontroller_with_resnet.py +250 -0
  7. {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.30}/pyproject.toml +1 -1
  8. metacontroller_pytorch-0.0.30/test_babyai_e2e.sh +14 -0
  9. {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.30}/tests/test_metacontroller.py +1 -1
  10. {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.30}/train_behavior_clone_babyai.py +25 -18
  11. metacontroller_pytorch-0.0.28/test_babyai_e2e.sh +0 -14
  12. {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.30}/.github/workflows/python-publish.yml +0 -0
  13. {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.30}/.github/workflows/test.yml +0 -0
  14. {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.30}/.gitignore +0 -0
  15. {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.30}/LICENSE +0 -0
  16. {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.30}/fig1.png +0 -0
  17. {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.30}/metacontroller/__init__.py +0 -0
  18. {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.30}/train_babyai.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.28
3
+ Version: 0.0.30
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -67,6 +67,8 @@ $ pip install metacontroller-pytorch
67
67
 
68
68
  - [Pranoy](https://github.com/pranoyr) for submitting a pull request for fixing the previous latent action not being included in the inputs to the switching unit
69
69
 
70
+ - [Diego Calanzone](https://github.com/ddidacus) for proposing testing on BabyAI gridworld task, and submitting the [pull request](https://github.com/lucidrains/metacontroller/pull/3) for behavior cloning and discovery phase training for it!
71
+
70
72
  ## Citations
71
73
 
72
74
  ```bibtex
@@ -103,3 +105,5 @@ $ pip install metacontroller-pytorch
103
105
  url = {https://arxiv.org/abs/2510.17558},
104
106
  }
105
107
  ```
108
+
109
+ *Life can only be understood backwards; but it must be lived forwards* - Søren Kierkegaard
@@ -14,6 +14,8 @@ $ pip install metacontroller-pytorch
14
14
 
15
15
  - [Pranoy](https://github.com/pranoyr) for submitting a pull request for fixing the previous latent action not being included in the inputs to the switching unit
16
16
 
17
+ - [Diego Calanzone](https://github.com/ddidacus) for proposing testing on BabyAI gridworld task, and submitting the [pull request](https://github.com/lucidrains/metacontroller/pull/3) for behavior cloning and discovery phase training for it!
18
+
17
19
  ## Citations
18
20
 
19
21
  ```bibtex
@@ -50,3 +52,5 @@ $ pip install metacontroller-pytorch
50
52
  url = {https://arxiv.org/abs/2510.17558},
51
53
  }
52
54
  ```
55
+
56
+ *Life can only be understood backwards; but it must be lived forwards* - Søren Kierkegaard
@@ -31,8 +31,13 @@ import gymnasium as gym
31
31
  from minigrid.utils.baby_ai_bot import BabyAIBot
32
32
  from minigrid.wrappers import FullyObsWrapper, SymbolicObsWrapper
33
33
 
34
+ from gymnasium import spaces
35
+ from gymnasium.core import ObservationWrapper
36
+ from minigrid.core.constants import OBJECT_TO_IDX
37
+
34
38
  from memmap_replay_buffer import ReplayBuffer
35
39
 
40
+
36
41
  # helpers
37
42
 
38
43
  def exists(val):
@@ -41,6 +46,40 @@ def exists(val):
41
46
  def sample(prob):
42
47
  return random.random() < prob
43
48
 
49
+ # wrapper, necessarily modified to allow for both rgb obs (policy) and symbolic obs (bot)
50
+
51
+ class RGBImgPartialObsWrapper(ObservationWrapper):
52
+ """
53
+ Wrapper to use partially observable RGB image as observation.
54
+ This can be used to have the agent to solve the gridworld in pixel space.
55
+ """
56
+ def __init__(self, env, tile_size=1):
57
+ super().__init__(env)
58
+
59
+ # Rendering attributes for observations
60
+ self.tile_size = tile_size
61
+
62
+ symbolic_image_space = self.observation_space["image"]
63
+
64
+ obs_shape = env.observation_space.spaces["image"].shape
65
+ new_image_space = spaces.Box(
66
+ low=0,
67
+ high=255,
68
+ shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
69
+ dtype="uint8",
70
+ )
71
+
72
+ self.observation_space = spaces.Dict(
73
+ {**self.observation_space.spaces, "image": symbolic_image_space, "rgb_image": new_image_space}
74
+ )
75
+
76
+ def observation(self, obs):
77
+ rgb_img_partial = self.unwrapped.get_frame(
78
+ tile_size=self.tile_size, agent_pov=True
79
+ )
80
+
81
+ return {**obs, "rgb_image": rgb_img_partial}
82
+
44
83
  # agent
45
84
 
46
85
  class BabyAIBotEpsilonGreedy:
@@ -72,6 +111,7 @@ def collect_single_episode(env_id, seed, num_steps, random_action_prob, state_sh
72
111
  env = gym.make(env_id, render_mode="rgb_array", highlight=False)
73
112
  env = FullyObsWrapper(env.unwrapped)
74
113
  env = SymbolicObsWrapper(env.unwrapped)
114
+ env = RGBImgPartialObsWrapper(env.unwrapped)
75
115
 
76
116
  try:
77
117
  state_obs, _ = env.reset(seed=seed)
@@ -88,7 +128,7 @@ def collect_single_episode(env_id, seed, num_steps, random_action_prob, state_sh
88
128
  env.close()
89
129
  return None, None, False, 0
90
130
 
91
- episode_state[_step] = state_obs["image"]
131
+ episode_state[_step] = state_obs["rgb_image"] / 255. # normalized to 0 to 1
92
132
  episode_action[_step] = action
93
133
 
94
134
  state_obs, reward, terminated, truncated, info = env.step(action)
@@ -127,7 +167,8 @@ def collect_demonstrations(
127
167
  temp_env = gym.make(env_id)
128
168
  temp_env = FullyObsWrapper(temp_env.unwrapped)
129
169
  temp_env = SymbolicObsWrapper(temp_env.unwrapped)
130
- state_shape = temp_env.observation_space['image'].shape
170
+ temp_env = RGBImgPartialObsWrapper(temp_env.unwrapped)
171
+ state_shape = temp_env.observation_space['rgb_image'].shape
131
172
  temp_env.close()
132
173
 
133
174
  logger.info(f"Detected state shape: {state_shape} for env {env_id}")
@@ -58,6 +58,7 @@ def straight_through(src, tgt):
58
58
 
59
59
  MetaControllerOutput = namedtuple('MetaControllerOutput', (
60
60
  'prev_hiddens',
61
+ 'input_residual_stream',
61
62
  'action_dist',
62
63
  'actions',
63
64
  'kl_loss',
@@ -268,7 +269,7 @@ class MetaController(Module):
268
269
  sampled_latent_action[:, -1:]
269
270
  )
270
271
 
271
- return control_signal, MetaControllerOutput(next_hiddens, action_dist, sampled_latent_action, kl_loss, switch_loss)
272
+ return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, kl_loss, switch_loss)
272
273
 
273
274
  # main transformer, which is subsumed into the environment after behavioral cloning
274
275
 
@@ -52,6 +52,7 @@ def straight_through(src, tgt):
52
52
 
53
53
  MetaControllerOutput = namedtuple('MetaControllerOutput', (
54
54
  'prev_hiddens',
55
+ 'input_residual_stream',
55
56
  'action_dist',
56
57
  'codes',
57
58
  'kl_loss',
@@ -265,4 +266,4 @@ class MetaControllerWithBinaryMapper(Module):
265
266
  sampled_codes[:, -1:]
266
267
  )
267
268
 
268
- return control_signal, MetaControllerOutput(next_hiddens, binary_logits, sampled_codes, kl_loss, switch_loss)
269
+ return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, kl_loss, switch_loss)
@@ -0,0 +1,250 @@
1
+ from typing import Any, List, Type, Union, Optional
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ from torch import nn
6
+ from einops import rearrange
7
+ from metacontroller.metacontroller import Transformer
8
+
9
+ class TransformerWithResnetEncoder(Transformer):
10
+ def __init__(self, **kwargs):
11
+ super().__init__(**kwargs)
12
+ self.resnet_dim = kwargs["state_embed_readout"]["num_continuous"]
13
+ self.visual_encoder = resnet18(out_dim=self.resnet_dim)
14
+
15
+ def visual_encode(self, x: torch.Tensor) -> torch.Tensor:
16
+ b, t = x.shape[:2]
17
+ x = rearrange(x, 'b t h w c -> (b t) c h w')
18
+ h = self.visual_encoder(x)
19
+ h = rearrange(h, '(b t) d -> b t d', b=b, t=t, d = self.resnet_dim)
20
+ return h
21
+
22
+ # resnet components taken from https://github.com/Lornatang/ResNet-PyTorch
23
+
24
+ class _BasicBlock(nn.Module):
25
+ expansion: int = 1
26
+
27
+ def __init__(
28
+ self,
29
+ in_channels: int,
30
+ out_channels: int,
31
+ stride: int,
32
+ downsample: Optional[nn.Module] = None,
33
+ groups: int = 1,
34
+ base_channels: int = 64,
35
+ ) -> None:
36
+ super(_BasicBlock, self).__init__()
37
+ self.stride = stride
38
+ self.downsample = downsample
39
+ self.groups = groups
40
+ self.base_channels = base_channels
41
+
42
+ self.conv1 = nn.Conv2d(in_channels, out_channels, (3, 3), (stride, stride), (1, 1), bias=False)
43
+ self.bn1 = nn.BatchNorm2d(out_channels)
44
+ self.relu = nn.ReLU(True)
45
+ self.conv2 = nn.Conv2d(out_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False)
46
+ self.bn2 = nn.BatchNorm2d(out_channels)
47
+
48
+ def forward(self, x: Tensor) -> Tensor:
49
+ identity = x
50
+
51
+ out = self.conv1(x)
52
+ out = self.bn1(out)
53
+ out = self.relu(out)
54
+
55
+ out = self.conv2(out)
56
+ out = self.bn2(out)
57
+
58
+ if self.downsample is not None:
59
+ identity = self.downsample(x)
60
+
61
+ out = torch.add(out, identity)
62
+ out = self.relu(out)
63
+
64
+ return out
65
+
66
+
67
+ class _Bottleneck(nn.Module):
68
+ expansion: int = 4
69
+
70
+ def __init__(
71
+ self,
72
+ in_channels: int,
73
+ out_channels: int,
74
+ stride: int,
75
+ downsample: Optional[nn.Module] = None,
76
+ groups: int = 1,
77
+ base_channels: int = 64,
78
+ ) -> None:
79
+ super(_Bottleneck, self).__init__()
80
+ self.stride = stride
81
+ self.downsample = downsample
82
+ self.groups = groups
83
+ self.base_channels = base_channels
84
+
85
+ channels = int(out_channels * (base_channels / 64.0)) * groups
86
+
87
+ self.conv1 = nn.Conv2d(in_channels, channels, (1, 1), (1, 1), (0, 0), bias=False)
88
+ self.bn1 = nn.BatchNorm2d(channels)
89
+ self.conv2 = nn.Conv2d(channels, channels, (3, 3), (stride, stride), (1, 1), groups=groups, bias=False)
90
+ self.bn2 = nn.BatchNorm2d(channels)
91
+ self.conv3 = nn.Conv2d(channels, int(out_channels * self.expansion), (1, 1), (1, 1), (0, 0), bias=False)
92
+ self.bn3 = nn.BatchNorm2d(int(out_channels * self.expansion))
93
+ self.relu = nn.ReLU(True)
94
+
95
+ def forward(self, x: Tensor) -> Tensor:
96
+ identity = x
97
+
98
+ out = self.conv1(x)
99
+ out = self.bn1(out)
100
+ out = self.relu(out)
101
+
102
+ out = self.conv2(out)
103
+ out = self.bn2(out)
104
+ out = self.relu(out)
105
+
106
+ out = self.conv3(out)
107
+ out = self.bn3(out)
108
+
109
+ if self.downsample is not None:
110
+ identity = self.downsample(x)
111
+
112
+ out = torch.add(out, identity)
113
+ out = self.relu(out)
114
+
115
+ return out
116
+
117
+
118
+ class ResNet(nn.Module):
119
+
120
+ def __init__(
121
+ self,
122
+ arch_cfg: List[int],
123
+ block: Type[Union[_BasicBlock, _Bottleneck]],
124
+ groups: int = 1,
125
+ channels_per_group: int = 64,
126
+ out_dim: int = 1000,
127
+ ) -> None:
128
+ super(ResNet, self).__init__()
129
+ self.in_channels = 64
130
+ self.dilation = 1
131
+ self.groups = groups
132
+ self.base_channels = channels_per_group
133
+
134
+ self.conv1 = nn.Conv2d(3, self.in_channels, (7, 7), (2, 2), (3, 3), bias=False)
135
+ self.bn1 = nn.BatchNorm2d(self.in_channels)
136
+ self.relu = nn.ReLU(True)
137
+ self.maxpool = nn.MaxPool2d((3, 3), (2, 2), (1, 1))
138
+
139
+ self.layer1 = self._make_layer(arch_cfg[0], block, 64, 1)
140
+ self.layer2 = self._make_layer(arch_cfg[1], block, 128, 2)
141
+ self.layer3 = self._make_layer(arch_cfg[2], block, 256, 2)
142
+ self.layer4 = self._make_layer(arch_cfg[3], block, 512, 2)
143
+
144
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
145
+
146
+ self.fc = nn.Linear(512 * block.expansion, out_dim)
147
+
148
+ # Initialize neural network weights
149
+ self._initialize_weights()
150
+
151
+ def _make_layer(
152
+ self,
153
+ repeat_times: int,
154
+ block: Type[Union[_BasicBlock, _Bottleneck]],
155
+ channels: int,
156
+ stride: int = 1,
157
+ ) -> nn.Sequential:
158
+ downsample = None
159
+
160
+ if stride != 1 or self.in_channels != channels * block.expansion:
161
+ downsample = nn.Sequential(
162
+ nn.Conv2d(self.in_channels, channels * block.expansion, (1, 1), (stride, stride), (0, 0), bias=False),
163
+ nn.BatchNorm2d(channels * block.expansion),
164
+ )
165
+
166
+ layers = [
167
+ block(
168
+ self.in_channels,
169
+ channels,
170
+ stride,
171
+ downsample,
172
+ self.groups,
173
+ self.base_channels
174
+ )
175
+ ]
176
+ self.in_channels = channels * block.expansion
177
+ for _ in range(1, repeat_times):
178
+ layers.append(
179
+ block(
180
+ self.in_channels,
181
+ channels,
182
+ 1,
183
+ None,
184
+ self.groups,
185
+ self.base_channels,
186
+ )
187
+ )
188
+
189
+ return nn.Sequential(*layers)
190
+
191
+ def forward(self, x: Tensor) -> Tensor:
192
+ out = self._forward_impl(x)
193
+
194
+ return out
195
+
196
+ # Support torch.script function
197
+ def _forward_impl(self, x: Tensor) -> Tensor:
198
+ out = self.conv1(x)
199
+ out = self.bn1(out)
200
+ out = self.relu(out)
201
+ out = self.maxpool(out)
202
+
203
+ out = self.layer1(out)
204
+ out = self.layer2(out)
205
+ out = self.layer3(out)
206
+ out = self.layer4(out)
207
+
208
+ out = self.avgpool(out)
209
+ out = torch.flatten(out, 1)
210
+ out = self.fc(out)
211
+
212
+ return out
213
+
214
+ def _initialize_weights(self) -> None:
215
+ for module in self.modules():
216
+ if isinstance(module, nn.Conv2d):
217
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
218
+ elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
219
+ nn.init.constant_(module.weight, 1)
220
+ nn.init.constant_(module.bias, 0)
221
+
222
+
223
+ def resnet18(**kwargs: Any) -> ResNet:
224
+ model = ResNet([2, 2, 2, 2], _BasicBlock, **kwargs)
225
+
226
+ return model
227
+
228
+
229
+ def resnet34(**kwargs: Any) -> ResNet:
230
+ model = ResNet([3, 4, 6, 3], _BasicBlock, **kwargs)
231
+
232
+ return model
233
+
234
+
235
+ def resnet50(**kwargs: Any) -> ResNet:
236
+ model = ResNet([3, 4, 6, 3], _Bottleneck, **kwargs)
237
+
238
+ return model
239
+
240
+
241
+ def resnet101(**kwargs: Any) -> ResNet:
242
+ model = ResNet([3, 4, 23, 3], _Bottleneck, **kwargs)
243
+
244
+ return model
245
+
246
+
247
+ def resnet152(**kwargs: Any) -> ResNet:
248
+ model = ResNet([3, 8, 36, 3], _Bottleneck, **kwargs)
249
+
250
+ return model
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "metacontroller-pytorch"
3
- version = "0.0.28"
3
+ version = "0.0.30"
4
4
  description = "Transformer Metacontroller"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -0,0 +1,14 @@
1
+ #!/bin/bash
2
+ set -e
3
+
4
+ # 1. Gather trajectories
5
+ echo "Gathering trajectories..."
6
+ uv run gather_babyai_trajs.py --num_seeds 1000 --num_episodes_per_seed 100 --output_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0
7
+
8
+ # 2. Behavioral cloning
9
+ echo "Training behavioral cloning model..."
10
+ uv run train_behavior_clone_babyai.py --cloning_epochs 10 --discovery_epochs 10 --batch_size 256 --input_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0 --checkpoint_path end_to_end_model.pt --use_resnet
11
+
12
+ # 3. Inference rollouts
13
+ echo "Running inference rollouts..."
14
+ uv run train_babyai.py --weights_path end_to_end_model.pt --env_name BabyAI-MiniBossLevel-v0 --num_episodes 5 --buffer_size 100 --max_timesteps 100
@@ -59,7 +59,7 @@ def test_metacontroller(
59
59
  dim_model = 512,
60
60
  dim_meta_controller = 256,
61
61
  switch_per_code = switch_per_latent_dim,
62
- dim_code_bits = 8, # 2**8 = 256 codes
62
+ dim_code_bits = 8, # 2 ** 8 = 256 codes
63
63
  )
64
64
 
65
65
  # discovery phase
@@ -26,6 +26,7 @@ from memmap_replay_buffer import ReplayBuffer
26
26
  from einops import rearrange
27
27
 
28
28
  from metacontroller.metacontroller import Transformer
29
+ from metacontroller.metacontroller_with_resnet import TransformerWithResnetEncoder
29
30
 
30
31
  import minigrid
31
32
  import gymnasium as gym
@@ -33,7 +34,8 @@ import gymnasium as gym
33
34
  def train(
34
35
  input_dir: str = "babyai-minibosslevel-trajectories",
35
36
  env_id: str = "BabyAI-MiniBossLevel-v0",
36
- epochs: int = 10,
37
+ cloning_epochs: int = 10,
38
+ discovery_epochs: int = 10,
37
39
  batch_size: int = 32,
38
40
  lr: float = 1e-4,
39
41
  dim: int = 512,
@@ -44,7 +46,8 @@ def train(
44
46
  wandb_project: str = "metacontroller-babyai-bc",
45
47
  checkpoint_path: str = "transformer_bc.pt",
46
48
  state_loss_weight: float = 1.,
47
- action_loss_weight: float = 1.
49
+ action_loss_weight: float = 1.,
50
+ use_resnet: bool = False
48
51
  ):
49
52
  # accelerator
50
53
 
@@ -54,7 +57,8 @@ def train(
54
57
  accelerator.init_trackers(
55
58
  wandb_project,
56
59
  config = {
57
- "epochs": epochs,
60
+ "cloning_epochs": cloning_epochs,
61
+ "discovery_epochs": discovery_epochs,
58
62
  "batch_size": batch_size,
59
63
  "lr": lr,
60
64
  "dim": dim,
@@ -78,12 +82,8 @@ def train(
78
82
  # state shape and action dimension
79
83
  # state: (B, T, H, W, C) or (B, T, D)
80
84
  state_shape = replay_buffer.shapes['state']
81
- state_dim = int(torch.tensor(state_shape).prod().item())
82
-
83
- # state shape and action dimension
84
- # state: (B, T, H, W, C) or (B, T, D)
85
- state_shape = replay_buffer.shapes['state']
86
- state_dim = int(torch.tensor(state_shape).prod().item())
85
+ if use_resnet: state_dim = 256
86
+ else: state_dim = int(torch.tensor(state_shape).prod().item())
87
87
 
88
88
  # deduce num_actions from the environment
89
89
  minigrid.register_minigrid_envs()
@@ -94,8 +94,9 @@ def train(
94
94
  accelerator.print(f"Detected state_dim: {state_dim}, num_actions: {num_actions} from env: {env_id}")
95
95
 
96
96
  # transformer
97
-
98
- model = Transformer(
97
+
98
+ transformer_class = TransformerWithResnetEncoder if use_resnet else Transformer
99
+ model = transformer_class(
99
100
  dim = dim,
100
101
  state_embed_readout = dict(num_continuous = state_dim),
101
102
  action_embed_readout = dict(num_discrete = num_actions),
@@ -112,13 +113,13 @@ def train(
112
113
  model, optim, dataloader = accelerator.prepare(model, optim, dataloader)
113
114
 
114
115
  # training
115
-
116
- for epoch in range(epochs):
116
+ for epoch in range(cloning_epochs + discovery_epochs):
117
117
  model.train()
118
118
  total_state_loss = 0.
119
119
  total_action_loss = 0.
120
120
 
121
121
  progress_bar = tqdm(dataloader, desc = f"Epoch {epoch}", disable = not accelerator.is_local_main_process)
122
+ is_discovering = (epoch >= cloning_epochs) # discovery phase is BC with metacontroller tuning
122
123
 
123
124
  for batch in progress_bar:
124
125
  # batch is a NamedTuple (e.g. MemoryMappedBatch)
@@ -128,15 +129,20 @@ def train(
128
129
  actions = batch['action'].long()
129
130
  episode_lens = batch.get('_lens')
130
131
 
131
- # flatten state: (B, T, 7, 7, 3) -> (B, T, 147)
132
-
133
- states = rearrange(states, 'b t ... -> b t (...)')
132
+ # use resnet18 to embed visual observations
133
+ if use_resnet:
134
+ states = model.visual_encode(states)
135
+ else: # flatten state: (B, T, 7, 7, 3) -> (B, T, 147)
136
+ states = rearrange(states, 'b t ... -> b t (...)')
134
137
 
135
138
  with accelerator.accumulate(model):
136
- state_loss, action_loss = model(states, actions, episode_lens = episode_lens)
139
+ state_loss, action_loss = model(states, actions, episode_lens = episode_lens, discovery_phase=is_discovering)
137
140
  loss = state_loss * state_loss_weight + action_loss * action_loss_weight
138
141
 
139
142
  accelerator.backward(loss)
143
+
144
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
145
+
140
146
  optim.step()
141
147
  optim.zero_grad()
142
148
 
@@ -148,7 +154,8 @@ def train(
148
154
  accelerator.log({
149
155
  "state_loss": state_loss.item(),
150
156
  "action_loss": action_loss.item(),
151
- "total_loss": loss.item()
157
+ "total_loss": loss.item(),
158
+ "grad_norm": grad_norm.item()
152
159
  })
153
160
 
154
161
  progress_bar.set_postfix(
@@ -1,14 +0,0 @@
1
- #!/bin/bash
2
- set -e
3
-
4
- # 1. Gather trajectories
5
- echo "Gathering trajectories..."
6
- uv run gather_babyai_trajs.py --num_seeds 10 --num_episodes_per_seed 10 --output_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0
7
-
8
- # 2. Behavioral cloning
9
- echo "Training behavioral cloning model..."
10
- uv run train_behavior_clone_babyai.py --epochs 1 --batch_size 16 --input_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0 --checkpoint_path end_to_end_model.pt
11
-
12
- # 3. Inference rollouts
13
- echo "Running inference rollouts..."
14
- uv run train_babyai.py --weights_path end_to_end_model.pt --env_name BabyAI-MiniBossLevel-v0 --num_episodes 5 --buffer_size 100 --max_timesteps 100