metacontroller-pytorch 0.0.24__tar.gz → 0.0.26__tar.gz

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of metacontroller-pytorch might be problematic.

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: metacontroller-pytorch
- Version: 0.0.24
+ Version: 0.0.26
  Summary: Transformer Metacontroller
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -0,0 +1,209 @@
+ # /// script
+ # dependencies = [
+ #     "gymnasium",
+ #     "minigrid",
+ #     "tqdm",
+ #     "fire",
+ #     "memmap-replay-buffer>=0.0.23",
+ #     "loguru"
+ # ]
+ # ///
+
+ # taken with modifications from https://github.com/ddidacus/bot-minigrid-babyai/blob/main/tests/get_trajectories.py
+
+ import fire
+ import random
+ import multiprocessing
+ from loguru import logger
+ from pathlib import Path
+
+ import warnings
+ warnings.filterwarnings("ignore", category = UserWarning)
+
+ import numpy as np
+
+ from tqdm import tqdm
+ from concurrent.futures import ProcessPoolExecutor, wait, FIRST_COMPLETED
+
+ import torch
+ import minigrid
+ import gymnasium as gym
+ from minigrid.utils.baby_ai_bot import BabyAIBot
+ from minigrid.wrappers import FullyObsWrapper, SymbolicObsWrapper
+
+ from memmap_replay_buffer import ReplayBuffer
+
+ # helpers
+
+ def exists(val):
+     return val is not None
+
+ def sample(prob):
+     return random.random() < prob
+
+ # agent
+
+ class BabyAIBotEpsilonGreedy:
+     def __init__(self, env, random_action_prob = 0.):
+         self.expert = BabyAIBot(env)
+         self.random_action_prob = random_action_prob
+         self.num_actions = env.action_space.n
+         self.last_action = None
+
+     def __call__(self, state):
+         if sample(self.random_action_prob):
+             action = torch.randint(0, self.num_actions, ()).item()
+         else:
+             action = self.expert.replan(self.last_action)
+
+         self.last_action = action
+         return action
+
+ # functions
+
+ def collect_single_episode(env_id, seed, num_steps, random_action_prob, state_shape):
+     """
+     Collect a single episode of demonstrations.
+     Returns tuple of (episode_state, episode_action, success, episode_length)
+     """
+     if env_id not in gym.envs.registry:
+         minigrid.register_minigrid_envs()
+
+     env = gym.make(env_id, render_mode="rgb_array", highlight=False)
+     env = FullyObsWrapper(env.unwrapped)
+     env = SymbolicObsWrapper(env.unwrapped)
+
+     try:
+         state_obs, _ = env.reset(seed=seed)
+
+         episode_state = np.zeros((num_steps, *state_shape), dtype=np.float32)
+         episode_action = np.zeros(num_steps, dtype=np.float32)
+
+         expert = BabyAIBotEpsilonGreedy(env.unwrapped, random_action_prob = random_action_prob)
+
+         for _step in range(num_steps):
+             try:
+                 action = expert(state_obs)
+             except Exception:
+                 env.close()
+                 return None, None, False, 0
+
+             episode_state[_step] = state_obs["image"]
+             episode_action[_step] = action
+
+             state_obs, reward, terminated, truncated, info = env.step(action)
+
+             if terminated:
+                 env.close()
+                 return episode_state, episode_action, True, _step + 1
+
+         env.close()
+         return episode_state, episode_action, False, num_steps
+
+     except Exception:
+         env.close()
+         return None, None, False, 0
+
+ def collect_demonstrations(
+     env_id = "BabyAI-MiniBossLevel-v0",
+     num_seeds = 100,
+     num_episodes_per_seed = 100,
+     num_steps = 500,
+     random_action_prob = 0.05,
+     num_workers = None,
+     output_dir = "babyai-minibosslevel-trajectories"
+ ):
+     """
+     The BabyAI Bot should be able to solve all BabyAI environments,
+     allowing us therefore to generate demonstrations.
+     Parallelized version using ProcessPoolExecutor.
+     """
+
+     # Register minigrid envs if not already registered
+     if env_id not in gym.envs.registry:
+         minigrid.register_minigrid_envs()
+
+     # Determine state shape from environment
+     temp_env = gym.make(env_id)
+     temp_env = FullyObsWrapper(temp_env.unwrapped)
+     temp_env = SymbolicObsWrapper(temp_env.unwrapped)
+     state_shape = temp_env.observation_space['image'].shape
+     temp_env.close()
+
+     logger.info(f"Detected state shape: {state_shape} for env {env_id}")
+
+     if not exists(num_workers):
+         num_workers = multiprocessing.cpu_count()
+
+     total_episodes = num_seeds * num_episodes_per_seed
+
+     # Prepare seeds for all episodes
+     seeds = []
+     for count in range(num_seeds):
+         for it in range(num_episodes_per_seed):
+             seeds.append(count + 1)
+
+     successful = 0
+     progressbar = tqdm(total=total_episodes)
+
+     output_folder = Path(output_dir)
+
+     fields = {
+         'state': ('float', state_shape),
+         'action': ('float', ())
+     }
+
+     buffer = ReplayBuffer(
+         folder = output_folder,
+         max_episodes = total_episodes,
+         max_timesteps = num_steps,
+         fields = fields,
+         overwrite = True
+     )
+
+     # Parallel execution with bounded pending futures to avoid OOM
+     max_pending = num_workers * 4
+
+     with ProcessPoolExecutor(max_workers=num_workers) as executor:
+         seed_iter = iter(seeds)
+         futures = {}
+
+         # Initial batch of submissions
+         for _ in range(min(max_pending, len(seeds))):
+             seed = next(seed_iter, None)
+             if exists(seed):
+                 future = executor.submit(collect_single_episode, env_id, seed, num_steps, random_action_prob, state_shape)
+                 futures[future] = seed
+
+         # Process completed tasks and submit new ones
+         while futures:
+             # Wait for at least one future to complete
+             done, _ = wait(futures, return_when=FIRST_COMPLETED)
+
+             for future in done:
+                 del futures[future]
+                 episode_state, episode_action, success, episode_length = future.result()
+
+                 if success and exists(episode_state):
+                     buffer.store_episode(
+                         state = episode_state[:episode_length],
+                         action = episode_action[:episode_length]
+                     )
+                     successful += 1
+
+                 progressbar.update(1)
+                 progressbar.set_description(f"success rate = {successful}/{progressbar.n}")
+
+                 # Submit a new task to replace the completed one
+                 seed = next(seed_iter, None)
+                 if exists(seed):
+                     new_future = executor.submit(collect_single_episode, env_id, seed, num_steps, random_action_prob, state_shape)
+                     futures[new_future] = seed
+
+     buffer.flush()
+     progressbar.close()
+
+     logger.info(f"Saved {successful} trajectories to {output_dir}")
+
+ if __name__ == "__main__":
+     fire.Fire(collect_demonstrations)
@@ -6,7 +6,7 @@ from collections import namedtuple
  from loguru import logger

  import torch
- from torch import nn, cat, stack, tensor
+ from torch import nn, cat, stack, tensor, Tensor
  from torch.nn import Module, GRU, Linear, Identity
  import torch.nn.functional as F

@@ -26,7 +26,7 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout

  from assoc_scan import AssocScan

- from torch_einops_utils import pad_at_dim
+ from torch_einops_utils import maybe, pad_at_dim, lens_to_mask
  from torch_einops_utils.save_load import save_load

  # constants
@@ -151,7 +151,8 @@ class MetaController(Module):
          cache: MetaControllerOutput | None = None,
          discovery_phase = False,
          hard_switch = False,
-         temperature = 1.
+         temperature = 1.,
+         episode_lens: Tensor | None = None
      ):
          device = residual_stream.device

@@ -168,7 +169,9 @@ class MetaController(Module):
          if discovery_phase:
              logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')

-         encoded_temporal = self.bidirectional_temporal_encoder(meta_embed)
+         mask = maybe(lens_to_mask)(episode_lens, meta_embed.shape[1])
+
+         encoded_temporal = self.bidirectional_temporal_encoder(meta_embed, mask = mask)

          proposed_action_hidden, _ = self.emitter(cat((encoded_temporal, meta_embed), dim = -1))
          readout = self.emitter_to_action_mean_log_var
@@ -335,6 +338,7 @@ class Transformer(Module):
          return_raw_action_dist = False,
          return_latents = False,
          return_cache = False,
+         episode_lens: Tensor | None = None
      ):
          device = state.device

@@ -362,6 +366,9 @@
            state, target_state = state[:, :-1], state[:, 1:]
            actions, target_actions = actions[:, :-1], actions[:, 1:]

+             if exists(episode_lens):
+                 episode_lens = (episode_lens - 1).clamp(min = 0)
+
        # transformer lower body

        with lower_transformer_context():
@@ -387,7 +394,7 @@
        with meta_controller_context():

            if exists(meta_controller):
-                 control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature)
+                 control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature, episode_lens = episode_lens)
            else:
                control_signal, next_meta_hiddens = self.zero, None

@@ -406,10 +413,14 @@
        # maybe return behavior cloning loss

        if behavioral_cloning:
+             loss_mask = None
+             if exists(episode_lens):
+                 loss_mask = lens_to_mask(episode_lens, state.shape[1])
+
            state_dist_params = self.state_readout(attended)
-             state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)
+             state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state, mask = loss_mask)

-             action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions)
+             action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions, mask = loss_mask)

            return state_clone_loss, action_clone_loss

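For context on the masking changes above: maybe and lens_to_mask come from torch_einops_utils, and the new episode_lens argument is converted into a boolean mask over the padded time dimension (shifted down by one in the behavioral cloning path, since states and actions are offset by one step to form targets). A minimal sketch of the assumed semantics, using illustrative stand-ins rather than the actual library code:

import torch

# stand-in for torch_einops_utils.lens_to_mask (assumed semantics):
# lengths (b,) -> boolean mask (b, total_len), True where the position index is below the length
def lens_to_mask_sketch(lens, total_len):
    arange = torch.arange(total_len, device = lens.device)
    return arange < lens[:, None]

# stand-in for torch_einops_utils.maybe (assumed semantics): pass None through untouched
def maybe_sketch(fn):
    def inner(x, *args, **kwargs):
        return fn(x, *args, **kwargs) if x is not None else None
    return inner

# mirrors the new MetaController.forward logic: if episode_lens is None, the mask stays None
episode_lens = torch.tensor([3, 5])
mask = maybe_sketch(lens_to_mask_sketch)(episode_lens, 6)
# tensor([[ True,  True,  True, False, False, False],
#         [ True,  True,  True,  True,  True, False]])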
@@ -1,6 +1,6 @@
  [project]
  name = "metacontroller-pytorch"
- version = "0.0.24"
+ version = "0.0.26"
  description = "Transformer Metacontroller"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -0,0 +1,14 @@
+ #!/bin/bash
+ set -e
+
+ # 1. Gather trajectories
+ echo "Gathering trajectories..."
+ uv run gather_babyai_trajs.py --num_seeds 10 --num_episodes_per_seed 10 --output_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0
+
+ # 2. Behavioral cloning
+ echo "Training behavioral cloning model..."
+ uv run train_behavior_clone_babyai.py --epochs 1 --batch_size 16 --input_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0 --checkpoint_path end_to_end_model.pt
+
+ # 3. Inference rollouts
+ echo "Running inference rollouts..."
+ uv run train_babyai.py --weights_path end_to_end_model.pt --env_name BabyAI-MiniBossLevel-v0 --num_episodes 5 --buffer_size 100 --max_timesteps 100
@@ -10,12 +10,15 @@ from einops import rearrange

  @param('action_discrete', (False, True))
  @param('switch_per_latent_dim', (False, True))
+ @param('variable_length', (False, True))
  def test_metacontroller(
      action_discrete,
-     switch_per_latent_dim
+     switch_per_latent_dim,
+     variable_length
  ):

      state = torch.randn(1, 1024, 384)
+     episode_lens = torch.tensor([512]) if variable_length else None

      if action_discrete:
          actions = torch.randint(0, 4, (1, 1024))
@@ -36,7 +39,7 @@
          upper_body = dict(depth = 2,),
      )

-     state_clone_loss, action_clone_loss = model(state, actions)
+     state_clone_loss, action_clone_loss = model(state, actions, episode_lens = episode_lens)
      (state_clone_loss + 0.5 * action_clone_loss).backward()

      # discovery and internal rl phase with meta controller
@@ -50,7 +53,7 @@

      # discovery phase

-     (action_recon_loss, kl_loss, switch_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True)
+     (action_recon_loss, kl_loss, switch_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True, episode_lens = episode_lens)
      (action_recon_loss + kl_loss * 0.1 + switch_loss * 0.2).backward()

      # internal rl - done iteratively
@@ -13,13 +13,17 @@
  from fire import Fire
  from tqdm import tqdm
  from shutil import rmtree
+ from pathlib import Path

  import torch
+ from einops import rearrange

  import gymnasium as gym
  import minigrid
+ from minigrid.wrappers import FullyObsWrapper, SymbolicObsWrapper

  from memmap_replay_buffer import ReplayBuffer
+ from metacontroller.metacontroller import Transformer

  # functions

@@ -41,12 +45,15 @@ def main(
    buffer_size = 5_000,
    render_every_eps = 1_000,
    video_folder = './recordings',
-     seed = None
+     seed = None,
+     weights_path = None
  ):

      # environment

      env = gym.make(env_name, render_mode = 'rgb_array')
+     env = FullyObsWrapper(env.unwrapped)
+     env = SymbolicObsWrapper(env.unwrapped)

      rmtree(video_folder, ignore_errors = True)

@@ -58,6 +65,15 @@
        disable_logger = True
    )

+     # maybe load model
+
+     model = None
+     if exists(weights_path):
+         weights_path = Path(weights_path)
+         assert weights_path.exists(), f"weights not found at {weights_path}"
+         model = Transformer.init_and_load(str(weights_path), strict = False)
+         model.eval()
+

      # replay

@@ -79,10 +95,37 @@

          state, *_ = env.reset(seed = seed)

+         cache = None
+         past_action_id = None
+
          for _ in range(max_timesteps):

-             action = torch.randint(0, 7, ())
-             next_state, reward, terminated, truncated, *_ = env.step(action.numpy())
+             if exists(model):
+                 # preprocess state
+                 # assume state is a dict with 'image'
+                 image = state['image']
+                 image_tensor = torch.from_numpy(image).float()
+                 image_tensor = rearrange(image_tensor, 'h w c -> 1 1 (h w c)')
+
+                 if exists(past_action_id) and torch.is_tensor(past_action_id):
+                     past_action_id = past_action_id.long()
+
+                 with torch.no_grad():
+                     logits, cache = model(
+                         image_tensor,
+                         past_action_id,
+                         return_cache = True,
+                         return_raw_action_dist = True,
+                         cache = cache
+                     )
+
+                 action = model.action_readout.sample(logits)
+                 past_action_id = action
+                 action = action.squeeze()
+             else:
+                 action = torch.randint(0, 7, ())
+
+             next_state, reward, terminated, truncated, *_ = env.step(action.cpu().numpy())

              done = terminated or truncated

@@ -91,7 +134,7 @@

              state = next_state

-     # running
+     env.close()

  if __name__ == '__main__':
      Fire(main)
@@ -0,0 +1,175 @@
+ # /// script
+ # dependencies = [
+ #     "accelerate",
+ #     "fire",
+ #     "memmap-replay-buffer>=0.0.23",
+ #     "metacontroller-pytorch",
+ #     "torch",
+ #     "einops",
+ #     "tqdm",
+ #     "wandb",
+ #     "gymnasium",
+ #     "minigrid"
+ # ]
+ # ///
+
+ import fire
+ from tqdm import tqdm
+ from pathlib import Path
+
+ import torch
+ from torch.optim import Adam
+ from torch.utils.data import DataLoader
+
+ from accelerate import Accelerator
+ from memmap_replay_buffer import ReplayBuffer
+ from einops import rearrange
+
+ from metacontroller.metacontroller import Transformer
+
+ import minigrid
+ import gymnasium as gym
+
+ def train(
+     input_dir: str = "babyai-minibosslevel-trajectories",
+     env_id: str = "BabyAI-MiniBossLevel-v0",
+     epochs: int = 10,
+     batch_size: int = 32,
+     lr: float = 1e-4,
+     dim: int = 512,
+     depth: int = 2,
+     heads: int = 8,
+     dim_head: int = 64,
+     use_wandb: bool = False,
+     wandb_project: str = "metacontroller-babyai-bc",
+     checkpoint_path: str = "transformer_bc.pt",
+     state_loss_weight: float = 1.,
+     action_loss_weight: float = 1.
+ ):
+     # accelerator
+
+     accelerator = Accelerator(log_with = "wandb" if use_wandb else None)
+
+     if use_wandb:
+         accelerator.init_trackers(
+             wandb_project,
+             config = {
+                 "epochs": epochs,
+                 "batch_size": batch_size,
+                 "lr": lr,
+                 "dim": dim,
+                 "depth": depth,
+                 "heads": heads,
+                 "dim_head": dim_head,
+                 "env_id": env_id,
+                 "state_loss_weight": state_loss_weight,
+                 "action_loss_weight": action_loss_weight
+             }
+         )
+
+     # replay buffer and dataloader
+
+     input_path = Path(input_dir)
+     assert input_path.exists(), f"Input directory {input_dir} does not exist"
+
+     replay_buffer = ReplayBuffer.from_folder(input_path)
+     dataloader = replay_buffer.dataloader(batch_size = batch_size)
+
+     # state shape and action dimension
+     # state: (B, T, H, W, C) or (B, T, D)
+     state_shape = replay_buffer.shapes['state']
+     state_dim = int(torch.tensor(state_shape).prod().item())
+
+     # state shape and action dimension
+     # state: (B, T, H, W, C) or (B, T, D)
+     state_shape = replay_buffer.shapes['state']
+     state_dim = int(torch.tensor(state_shape).prod().item())
+
+     # deduce num_actions from the environment
+     minigrid.register_minigrid_envs()
+     temp_env = gym.make(env_id)
+     num_actions = int(temp_env.action_space.n)
+     temp_env.close()
+
+     accelerator.print(f"Detected state_dim: {state_dim}, num_actions: {num_actions} from env: {env_id}")
+
+     # transformer
+
+     model = Transformer(
+         dim = dim,
+         state_embed_readout = dict(num_continuous = state_dim),
+         action_embed_readout = dict(num_discrete = num_actions),
+         lower_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head),
+         upper_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head)
+     )
+
+     # optimizer
+
+     optim = Adam(model.parameters(), lr = lr)
+
+     # prepare
+
+     model, optim, dataloader = accelerator.prepare(model, optim, dataloader)
+
+     # training
+
+     for epoch in range(epochs):
+         model.train()
+         total_state_loss = 0.
+         total_action_loss = 0.
+
+         progress_bar = tqdm(dataloader, desc = f"Epoch {epoch}", disable = not accelerator.is_local_main_process)
+
+         for batch in progress_bar:
+             # batch is a NamedTuple (e.g. MemoryMappedBatch)
+             # state: (B, T, 7, 7, 3), action: (B, T)
+
+             states = batch['state'].float()
+             actions = batch['action'].long()
+             episode_lens = batch.get('_lens')
+
+             # flatten state: (B, T, 7, 7, 3) -> (B, T, 147)
+
+             states = rearrange(states, 'b t ... -> b t (...)')
+
+             with accelerator.accumulate(model):
+                 state_loss, action_loss = model(states, actions, episode_lens = episode_lens)
+                 loss = state_loss * state_loss_weight + action_loss * action_loss_weight
+
+                 accelerator.backward(loss)
+                 optim.step()
+                 optim.zero_grad()
+
+             # log
+
+             total_state_loss += state_loss.item()
+             total_action_loss += action_loss.item()
+
+             accelerator.log({
+                 "state_loss": state_loss.item(),
+                 "action_loss": action_loss.item(),
+                 "total_loss": loss.item()
+             })
+
+             progress_bar.set_postfix(
+                 state_loss = state_loss.item(),
+                 action_loss = action_loss.item()
+             )
+
+         avg_state_loss = total_state_loss / len(dataloader)
+         avg_action_loss = total_action_loss / len(dataloader)
+
+         accelerator.print(f"Epoch {epoch}: state_loss={avg_state_loss:.4f}, action_loss={avg_action_loss:.4f}")
+
+     # save weights
+
+     accelerator.wait_for_everyone()
+     if accelerator.is_main_process:
+         unwrapped_model = accelerator.unwrap_model(model)
+         unwrapped_model.save(checkpoint_path)
+         accelerator.print(f"Model saved to {checkpoint_path}")
+
+     accelerator.end_training()
+
+ if __name__ == "__main__":
+     fire.Fire(train)
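
Taken together, these changes add variable-length episode support end to end: the replay buffer exposes per-episode lengths (read as batch.get('_lens') in the training script above), and the Transformer masks its behavior cloning losses past each episode's end. A short usage sketch, assuming the constructor arguments used in train_behavior_clone_babyai.py (147 is the flattened 7x7x3 BabyAI symbolic observation, 7 the BabyAI action count):

import torch
from metacontroller.metacontroller import Transformer

# same constructor pattern as the behavioral cloning script above
model = Transformer(
    dim = 512,
    state_embed_readout = dict(num_continuous = 147),
    action_embed_readout = dict(num_discrete = 7),
    lower_body = dict(depth = 2, heads = 8, attn_dim_head = 64),
    upper_body = dict(depth = 2, heads = 8, attn_dim_head = 64)
)

# two episodes padded to 64 timesteps, with true lengths 40 and 64
states = torch.randn(2, 64, 147)
actions = torch.randint(0, 7, (2, 64))
episode_lens = torch.tensor([40, 64])

# with episode_lens given, positions past each episode's end are excluded from both losses
state_clone_loss, action_clone_loss = model(states, actions, episode_lens = episode_lens)
(state_clone_loss + action_clone_loss).backward()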