metacontroller-pytorch 0.0.23__tar.gz → 0.0.25__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: metacontroller-pytorch
- Version: 0.0.23
+ Version: 0.0.25
  Summary: Transformer Metacontroller
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -39,8 +39,8 @@ Requires-Dist: discrete-continuous-embed-readout>=0.1.12
  Requires-Dist: einops>=0.8.1
  Requires-Dist: einx>=0.3.0
  Requires-Dist: loguru
- Requires-Dist: memmap-replay-buffer>=0.0.1
- Requires-Dist: torch-einops-utils>=0.0.7
+ Requires-Dist: memmap-replay-buffer>=0.0.23
+ Requires-Dist: torch-einops-utils>=0.0.16
  Requires-Dist: torch>=2.5
  Requires-Dist: x-evolution>=0.1.23
  Requires-Dist: x-mlps-pytorch
@@ -0,0 +1,209 @@
+ # /// script
+ # dependencies = [
+ #   "gymnasium",
+ #   "minigrid",
+ #   "tqdm",
+ #   "fire",
+ #   "memmap-replay-buffer>=0.0.23",
+ #   "loguru"
+ # ]
+ # ///
+
+ # taken with modifications from https://github.com/ddidacus/bot-minigrid-babyai/blob/main/tests/get_trajectories.py
+
+ import fire
+ import random
+ import multiprocessing
+ from loguru import logger
+ from pathlib import Path
+
+ import warnings
+ warnings.filterwarnings("ignore", category = UserWarning)
+
+ import numpy as np
+
+ from tqdm import tqdm
+ from concurrent.futures import ProcessPoolExecutor, wait, FIRST_COMPLETED
+
+ import torch
+ import minigrid
+ import gymnasium as gym
+ from minigrid.utils.baby_ai_bot import BabyAIBot
+ from minigrid.wrappers import FullyObsWrapper, SymbolicObsWrapper
+
+ from memmap_replay_buffer import ReplayBuffer
+
+ # helpers
+
+ def exists(val):
+     return val is not None
+
+ def sample(prob):
+     return random.random() < prob
+
+ # agent
+
+ class BabyAIBotEpsilonGreedy:
+     def __init__(self, env, random_action_prob = 0.):
+         self.expert = BabyAIBot(env)
+         self.random_action_prob = random_action_prob
+         self.num_actions = env.action_space.n
+         self.last_action = None
+
+     def __call__(self, state):
+         if sample(self.random_action_prob):
+             action = torch.randint(0, self.num_actions, ()).item()
+         else:
+             action = self.expert.replan(self.last_action)
+
+         self.last_action = action
+         return action
+
+ # functions
+
+ def collect_single_episode(env_id, seed, num_steps, random_action_prob, state_shape):
+     """
+     Collect a single episode of demonstrations.
+     Returns a tuple of (episode_state, episode_action, success, episode_length)
+     """
+     if env_id not in gym.envs.registry:
+         minigrid.register_minigrid_envs()
+
+     env = gym.make(env_id, render_mode="rgb_array", highlight=False)
+     env = FullyObsWrapper(env.unwrapped)
+     env = SymbolicObsWrapper(env.unwrapped)
+
+     try:
+         state_obs, _ = env.reset(seed=seed)
+
+         episode_state = np.zeros((num_steps, *state_shape), dtype=np.float32)
+         episode_action = np.zeros(num_steps, dtype=np.float32)
+
+         expert = BabyAIBotEpsilonGreedy(env.unwrapped, random_action_prob = random_action_prob)
+
+         for _step in range(num_steps):
+             try:
+                 action = expert(state_obs)
+             except Exception:
+                 env.close()
+                 return None, None, False, 0
+
+             episode_state[_step] = state_obs["image"]
+             episode_action[_step] = action
+
+             state_obs, reward, terminated, truncated, info = env.step(action)
+
+             if terminated:
+                 env.close()
+                 return episode_state, episode_action, True, _step + 1
+
+         env.close()
+         return episode_state, episode_action, False, num_steps
+
+     except Exception:
+         env.close()
+         return None, None, False, 0
+
+ def collect_demonstrations(
+     env_id = "BabyAI-MiniBossLevel-v0",
+     num_seeds = 100,
+     num_episodes_per_seed = 100,
+     num_steps = 500,
+     random_action_prob = 0.05,
+     num_workers = None,
+     output_dir = "babyai-minibosslevel-trajectories"
+ ):
+     """
+     The BabyAI Bot should be able to solve all BabyAI environments,
+     allowing us to generate demonstrations.
+     Parallelized version using ProcessPoolExecutor.
+     """
+
+     # Register minigrid envs if not already registered
+     if env_id not in gym.envs.registry:
+         minigrid.register_minigrid_envs()
+
+     # Determine state shape from environment
+     temp_env = gym.make(env_id)
+     temp_env = FullyObsWrapper(temp_env.unwrapped)
+     temp_env = SymbolicObsWrapper(temp_env.unwrapped)
+     state_shape = temp_env.observation_space['image'].shape
+     temp_env.close()
+
+     logger.info(f"Detected state shape: {state_shape} for env {env_id}")
+
+     if not exists(num_workers):
+         num_workers = multiprocessing.cpu_count()
+
+     total_episodes = num_seeds * num_episodes_per_seed
+
+     # Prepare seeds for all episodes
+     seeds = []
+     for count in range(num_seeds):
+         for it in range(num_episodes_per_seed):
+             seeds.append(count + 1)
+
+     successful = 0
+     progressbar = tqdm(total=total_episodes)
+
+     output_folder = Path(output_dir)
+
+     fields = {
+         'state': ('float', state_shape),
+         'action': ('float', ())
+     }
+
+     buffer = ReplayBuffer(
+         folder = output_folder,
+         max_episodes = total_episodes,
+         max_timesteps = num_steps,
+         fields = fields,
+         overwrite = True
+     )
+
+     # Parallel execution with bounded pending futures to avoid OOM
+     max_pending = num_workers * 4
+
+     with ProcessPoolExecutor(max_workers=num_workers) as executor:
+         seed_iter = iter(seeds)
+         futures = {}
+
+         # Initial batch of submissions
+         for _ in range(min(max_pending, len(seeds))):
+             seed = next(seed_iter, None)
+             if exists(seed):
+                 future = executor.submit(collect_single_episode, env_id, seed, num_steps, random_action_prob, state_shape)
+                 futures[future] = seed
+
+         # Process completed tasks and submit new ones
+         while futures:
+             # Wait for at least one future to complete
+             done, _ = wait(futures, return_when=FIRST_COMPLETED)
+
+             for future in done:
+                 del futures[future]
+                 episode_state, episode_action, success, episode_length = future.result()
+
+                 if success and exists(episode_state):
+                     buffer.store_episode(
+                         state = episode_state[:episode_length],
+                         action = episode_action[:episode_length]
+                     )
+                     successful += 1
+
+                 progressbar.update(1)
+                 progressbar.set_description(f"success rate = {successful}/{progressbar.n}")
+
+                 # Submit a new task to replace the completed one
+                 seed = next(seed_iter, None)
+                 if exists(seed):
+                     new_future = executor.submit(collect_single_episode, env_id, seed, num_steps, random_action_prob, state_shape)
+                     futures[new_future] = seed
+
+     buffer.flush()
+     progressbar.close()
+
+     logger.info(f"Saved {successful} trajectories to {output_dir}")
+
+ if __name__ == "__main__":
+     fire.Fire(collect_demonstrations)
@@ -26,7 +26,8 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout
  
  from assoc_scan import AssocScan
  
- from torch_einops_utils import pad_at_dim
+ from torch_einops_utils import pad_at_dim, lens_to_mask
+ from torch_einops_utils.save_load import save_load
  
  # constants
  
@@ -63,6 +64,7 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
      'switch_loss'
  ))
  
+ @save_load()
  class MetaController(Module):
      def __init__(
          self,
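
The `@save_load()` decorator comes from `torch_einops_utils.save_load` and, as the tests later in this diff show, adds a `save` method and an `init_and_load` classmethod to the decorated module. A minimal sketch of that contract on a toy module (the toy class and its dimensions are hypothetical, and the detail that the decorator records init arguments so `init_and_load` can re-instantiate the class is inferred from the test usage, not confirmed):

```python
from torch import nn
from torch_einops_utils.save_load import save_load

@save_load()
class ToyNet(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

net = ToyNet(8)
net.save('./toy.pt')                          # instance method added by the decorator
restored = ToyNet.init_and_load('./toy.pt')   # presumably rebuilds ToyNet(8), then loads the weights
```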
@@ -272,6 +274,7 @@ TransformerOutput = namedtuple('TransformerOutput', (
      'prev_hiddens'
  ))
  
+ @save_load()
  class Transformer(Module):
      def __init__(
          self,
@@ -332,6 +335,7 @@ class Transformer(Module):
          return_raw_action_dist = False,
          return_latents = False,
          return_cache = False,
+         episode_lens: Tensor | None = None
      ):
          device = state.device
  
@@ -359,6 +363,9 @@ class Transformer(Module):
          state, target_state = state[:, :-1], state[:, 1:]
          actions, target_actions = actions[:, :-1], actions[:, 1:]
  
+         if exists(episode_lens):
+             episode_lens = (episode_lens - 1).clamp(min = 0)
+
          # transformer lower body
  
          with lower_transformer_context():
@@ -403,10 +410,14 @@ class Transformer(Module):
          # maybe return behavior cloning loss
  
          if behavioral_cloning:
+             loss_mask = None
+             if exists(episode_lens):
+                 loss_mask = lens_to_mask(episode_lens, state.shape[1])
+
              state_dist_params = self.state_readout(attended)
-             state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)
+             state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state, mask = loss_mask)
  
-             action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions)
+             action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions, mask = loss_mask)
  
              return state_clone_loss, action_clone_loss
  
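Two details in this hunk are easy to miss: `episode_lens` is decremented by one because states and actions were shifted by one step to form next-token targets, and `lens_to_mask` restricts the cloning losses to each episode's valid prefix. A minimal sketch of the assumed `lens_to_mask` semantics (the function below is a hypothetical stand-in; the real helper ships with `torch-einops-utils`):

```python
import torch

def lens_to_mask_sketch(lens, max_len):
    # True at positions that fall inside each sequence's valid length
    return torch.arange(max_len, device = lens.device) < lens[:, None]

lens = torch.tensor([3, 5])
print(lens_to_mask_sketch(lens, 5))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])
```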
@@ -1,6 +1,6 @@
  [project]
  name = "metacontroller-pytorch"
- version = "0.0.23"
+ version = "0.0.25"
  description = "Transformer Metacontroller"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -29,9 +29,9 @@ dependencies = [
      "einx>=0.3.0",
      "einops>=0.8.1",
      "loguru",
-     "memmap-replay-buffer>=0.0.1",
+     "memmap-replay-buffer>=0.0.23",
      "torch>=2.5",
-     "torch-einops-utils>=0.0.7",
+     "torch-einops-utils>=0.0.16",
      "x-evolution>=0.1.23",
      "x-mlps-pytorch",
      "x-transformers"
@@ -0,0 +1,14 @@
+ #!/bin/bash
+ set -e
+
+ # 1. Gather trajectories
+ echo "Gathering trajectories..."
+ uv run gather_babyai_trajs.py --num_seeds 10 --num_episodes_per_seed 10 --output_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0
+
+ # 2. Behavioral cloning
+ echo "Training behavioral cloning model..."
+ uv run train_behavior_clone_babyai.py --epochs 1 --batch_size 16 --input_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0 --checkpoint_path end_to_end_model.pt
+
+ # 3. Inference rollouts
+ echo "Running inference rollouts..."
+ uv run train_babyai.py --weights_path end_to_end_model.pt --env_name BabyAI-MiniBossLevel-v0 --num_episodes 5 --buffer_size 100 --max_timesteps 100
@@ -1,6 +1,8 @@
  import pytest
  param = pytest.mark.parametrize
  
+ from pathlib import Path
+
  import torch
  from metacontroller.metacontroller import Transformer, MetaController
  
@@ -8,12 +10,15 @@ from einops import rearrange
  
  @param('action_discrete', (False, True))
  @param('switch_per_latent_dim', (False, True))
+ @param('variable_length', (False, True))
  def test_metacontroller(
      action_discrete,
-     switch_per_latent_dim
+     switch_per_latent_dim,
+     variable_length
  ):
  
      state = torch.randn(1, 1024, 384)
+     episode_lens = torch.tensor([512]) if variable_length else None
  
      if action_discrete:
          actions = torch.randint(0, 4, (1, 1024))
@@ -34,7 +39,7 @@ def test_metacontroller(
          upper_body = dict(depth = 2,),
      )
  
-     state_clone_loss, action_clone_loss = model(state, actions)
+     state_clone_loss, action_clone_loss = model(state, actions, episode_lens = episode_lens)
      (state_clone_loss + 0.5 * action_clone_loss).backward()
  
      # discovery and internal rl phase with meta controller
@@ -68,3 +73,16 @@ def test_metacontroller(
  
      model.meta_controller = meta_controller
      model.evolve(1, lambda _: 1., noise_population_size = 2)
+
+     # saving and loading
+
+     meta_controller.save('./meta_controller.pt')
+
+     rehydrated_meta_controller = MetaController.init_and_load('./meta_controller.pt')
+
+     model.save('./trained.pt')
+
+     rehydrated_model = Transformer.init_and_load('./trained.pt', strict = False)
+
+     Path('./meta_controller.pt').unlink()
+     Path('./trained.pt').unlink()
@@ -3,7 +3,7 @@
  #   "fire",
  #   "gymnasium",
  #   "gymnasium[other]",
- #   "memmap-replay-buffer>=0.0.10",
+ #   "memmap-replay-buffer>=0.0.12",
  #   "metacontroller-pytorch",
  #   "minigrid",
  #   "tqdm"
@@ -13,13 +13,17 @@
  from fire import Fire
  from tqdm import tqdm
  from shutil import rmtree
+ from pathlib import Path
  
  import torch
+ from einops import rearrange
  
  import gymnasium as gym
  import minigrid
+ from minigrid.wrappers import FullyObsWrapper, SymbolicObsWrapper
  
  from memmap_replay_buffer import ReplayBuffer
+ from metacontroller.metacontroller import Transformer
  
  # functions
  
@@ -41,12 +45,15 @@ def main(
      buffer_size = 5_000,
      render_every_eps = 1_000,
      video_folder = './recordings',
-     seed = None
+     seed = None,
+     weights_path = None
  ):
  
      # environment
  
      env = gym.make(env_name, render_mode = 'rgb_array')
+     env = FullyObsWrapper(env.unwrapped)
+     env = SymbolicObsWrapper(env.unwrapped)
  
      rmtree(video_folder, ignore_errors = True)
  
@@ -58,6 +65,15 @@ def main(
          disable_logger = True
      )
  
+     # maybe load model
+
+     model = None
+     if exists(weights_path):
+         weights_path = Path(weights_path)
+         assert weights_path.exists(), f"weights not found at {weights_path}"
+         model = Transformer.init_and_load(str(weights_path), strict = False)
+         model.eval()
+
      # replay
      replay_buffer = ReplayBuffer(
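
The `strict = False` above matters because a checkpoint saved during behavioral cloning will not contain weights for components attached afterwards, such as the meta controller. Assuming the loader ultimately delegates to PyTorch's `load_state_dict`, non-strict loading reports missing or unexpected keys instead of raising:

```python
import torch
from torch import nn

net = nn.Linear(4, 4)

# simulate a checkpoint that is missing a key
ckpt = {k: v for k, v in net.state_dict().items() if k != 'bias'}

result = net.load_state_dict(ckpt, strict = False)
print(result.missing_keys)  # ['bias']
```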
@@ -79,10 +95,37 @@ def main(
  
          state, *_ = env.reset(seed = seed)
  
+         cache = None
+         past_action_id = None
+
          for _ in range(max_timesteps):
  
-             action = torch.randint(0, 7, ())
-             next_state, reward, terminated, truncated, *_ = env.step(action.numpy())
+             if exists(model):
+                 # preprocess state
+                 # assume state is a dict with 'image'
+                 image = state['image']
+                 image_tensor = torch.from_numpy(image).float()
+                 image_tensor = rearrange(image_tensor, 'h w c -> 1 1 (h w c)')
+
+                 if exists(past_action_id) and torch.is_tensor(past_action_id):
+                     past_action_id = past_action_id.long()
+
+                 with torch.no_grad():
+                     logits, cache = model(
+                         image_tensor,
+                         past_action_id,
+                         return_cache = True,
+                         return_raw_action_dist = True,
+                         cache = cache
+                     )
+
+                 action = model.action_readout.sample(logits)
+                 past_action_id = action
+                 action = action.squeeze()
+             else:
+                 action = torch.randint(0, 7, ())
+
+             next_state, reward, terminated, truncated, *_ = env.step(action.cpu().numpy())
  
              done = terminated or truncated
  
@@ -91,7 +134,7 @@ def main(
  
              state = next_state
  
-     # running
+     env.close()
  
  if __name__ == '__main__':
      Fire(main)
@@ -0,0 +1,170 @@
+ # /// script
+ # dependencies = [
+ #   "accelerate",
+ #   "fire",
+ #   "memmap-replay-buffer>=0.0.23",
+ #   "metacontroller-pytorch",
+ #   "torch",
+ #   "einops",
+ #   "tqdm",
+ #   "wandb",
+ #   "gymnasium",
+ #   "minigrid"
+ # ]
+ # ///
+
+ import fire
+ from tqdm import tqdm
+ from pathlib import Path
+
+ import torch
+ from torch.optim import Adam
+ from torch.utils.data import DataLoader
+
+ from accelerate import Accelerator
+ from memmap_replay_buffer import ReplayBuffer
+ from einops import rearrange
+
+ from metacontroller.metacontroller import Transformer
+
+ import minigrid
+ import gymnasium as gym
+
+ def train(
+     input_dir: str = "babyai-minibosslevel-trajectories",
+     env_id: str = "BabyAI-MiniBossLevel-v0",
+     epochs: int = 10,
+     batch_size: int = 32,
+     lr: float = 1e-4,
+     dim: int = 512,
+     depth: int = 2,
+     heads: int = 8,
+     dim_head: int = 64,
+     use_wandb: bool = False,
+     wandb_project: str = "metacontroller-babyai-bc",
+     checkpoint_path: str = "transformer_bc.pt",
+     state_loss_weight: float = 1.,
+     action_loss_weight: float = 1.
+ ):
+     # accelerator
+
+     accelerator = Accelerator(log_with = "wandb" if use_wandb else None)
+
+     if use_wandb:
+         accelerator.init_trackers(
+             wandb_project,
+             config = {
+                 "epochs": epochs,
+                 "batch_size": batch_size,
+                 "lr": lr,
+                 "dim": dim,
+                 "depth": depth,
+                 "heads": heads,
+                 "dim_head": dim_head,
+                 "env_id": env_id,
+                 "state_loss_weight": state_loss_weight,
+                 "action_loss_weight": action_loss_weight
+             }
+         )
+
+     # replay buffer and dataloader
+
+     input_path = Path(input_dir)
+     assert input_path.exists(), f"Input directory {input_dir} does not exist"
+
+     replay_buffer = ReplayBuffer.from_folder(input_path)
+     dataloader = replay_buffer.dataloader(batch_size = batch_size)
+
+     # state shape and action dimension
+     # state: (B, T, H, W, C) or (B, T, D)
+     state_shape = replay_buffer.shapes['state']
+     state_dim = int(torch.tensor(state_shape).prod().item())
+
+     # deduce num_actions from the environment
+     minigrid.register_minigrid_envs()
+     temp_env = gym.make(env_id)
+     num_actions = int(temp_env.action_space.n)
+     temp_env.close()
+
+     accelerator.print(f"Detected state_dim: {state_dim}, num_actions: {num_actions} from env: {env_id}")
+
+     # transformer
+
+     model = Transformer(
+         dim = dim,
+         state_embed_readout = dict(num_continuous = state_dim),
+         action_embed_readout = dict(num_discrete = num_actions),
+         lower_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head),
+         upper_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head)
+     )
+
+     # optimizer
+
+     optim = Adam(model.parameters(), lr = lr)
+
+     # prepare
+
+     model, optim, dataloader = accelerator.prepare(model, optim, dataloader)
+
+     # training
+
+     for epoch in range(epochs):
+         model.train()
+         total_state_loss = 0.
+         total_action_loss = 0.
+
+         progress_bar = tqdm(dataloader, desc = f"Epoch {epoch}", disable = not accelerator.is_local_main_process)
+
+         for batch in progress_bar:
+             # batch is dict-like, mapping field names to tensors
+             # state: (B, T, 7, 7, 3), action: (B, T)
+
+             states = batch['state'].float()
+             actions = batch['action'].long()
+             episode_lens = batch.get('_lens')
+
+             # flatten state: (B, T, 7, 7, 3) -> (B, T, 147)
+
+             states = rearrange(states, 'b t ... -> b t (...)')
+
+             with accelerator.accumulate(model):
+                 state_loss, action_loss = model(states, actions, episode_lens = episode_lens)
+                 loss = state_loss * state_loss_weight + action_loss * action_loss_weight
+
+                 accelerator.backward(loss)
+                 optim.step()
+                 optim.zero_grad()
+
+             # log
+
+             total_state_loss += state_loss.item()
+             total_action_loss += action_loss.item()
+
+             accelerator.log({
+                 "state_loss": state_loss.item(),
+                 "action_loss": action_loss.item(),
+                 "total_loss": loss.item()
+             })
+
+             progress_bar.set_postfix(
+                 state_loss = state_loss.item(),
+                 action_loss = action_loss.item()
+             )
+
+         avg_state_loss = total_state_loss / len(dataloader)
+         avg_action_loss = total_action_loss / len(dataloader)
+
+         accelerator.print(f"Epoch {epoch}: state_loss={avg_state_loss:.4f}, action_loss={avg_action_loss:.4f}")
+
+     # save weights
+
+     accelerator.wait_for_everyone()
+     if accelerator.is_main_process:
+         unwrapped_model = accelerator.unwrap_model(model)
+         unwrapped_model.save(checkpoint_path)
+         accelerator.print(f"Model saved to {checkpoint_path}")
+
+     accelerator.end_training()
+
+ if __name__ == "__main__":
+     fire.Fire(train)