metacontroller-pytorch 0.0.36__tar.gz → 0.0.37__tar.gz
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of metacontroller-pytorch might be problematic. Click here for more details.
- {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/PKG-INFO +1 -1
- {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/metacontroller/metacontroller.py +7 -2
- {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/pyproject.toml +1 -1
- {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/train_behavior_clone_babyai.py +10 -14
- {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/README.md +0 -0
- {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/gather_babyai_trajs.py +0 -0
- {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/metacontroller/metacontroller_with_binary_mapper.py +0 -0
- {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/metacontroller/transformer_with_resnet.py +0 -0
- {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/test_babyai_e2e.sh +0 -0
- {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/tests/test_metacontroller.py +0 -0
- {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/train_babyai.py +0 -0
{metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/metacontroller/metacontroller.py
RENAMED
|
@@ -408,6 +408,7 @@ class Transformer(Module):
|
|
|
408
408
|
meta_controller: Module | None = None,
|
|
409
409
|
cache: TransformerOutput | None = None,
|
|
410
410
|
discovery_phase = False,
|
|
411
|
+
force_behavior_cloning = False,
|
|
411
412
|
meta_controller_temperature = 1.,
|
|
412
413
|
return_raw_action_dist = False,
|
|
413
414
|
return_latents = False,
|
|
@@ -420,11 +421,15 @@ class Transformer(Module):
|
|
|
420
421
|
|
|
421
422
|
meta_controller = default(meta_controller, self.meta_controller)
|
|
422
423
|
|
|
424
|
+
if force_behavior_cloning:
|
|
425
|
+
assert not discovery_phase, 'discovery phase cannot be set to True if force behavioral cloning is set to True'
|
|
426
|
+
meta_controller = None
|
|
427
|
+
|
|
423
428
|
has_meta_controller = exists(meta_controller)
|
|
424
429
|
|
|
425
430
|
assert not (discovery_phase and not has_meta_controller), 'meta controller must be made available during discovery phase'
|
|
426
431
|
|
|
427
|
-
behavioral_cloning = not has_meta_controller and not return_raw_action_dist
|
|
432
|
+
behavioral_cloning = force_behavior_cloning or (not has_meta_controller and not return_raw_action_dist)
|
|
428
433
|
|
|
429
434
|
# by default, if meta controller is passed in, transformer is no grad
|
|
430
435
|
|
|
@@ -472,7 +477,7 @@ class Transformer(Module):
|
|
|
472
477
|
|
|
473
478
|
with meta_controller_context():
|
|
474
479
|
|
|
475
|
-
if exists(meta_controller):
|
|
480
|
+
if exists(meta_controller) and not behavioral_cloning:
|
|
476
481
|
control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature, episode_lens = episode_lens)
|
|
477
482
|
else:
|
|
478
483
|
control_signal, next_meta_hiddens = self.zero, None
|
{metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/train_behavior_clone_babyai.py
RENAMED
|
@@ -130,8 +130,8 @@ def train(
|
|
|
130
130
|
for epoch in range(cloning_epochs + discovery_epochs):
|
|
131
131
|
|
|
132
132
|
model.train()
|
|
133
|
-
|
|
134
|
-
|
|
133
|
+
from collections import defaultdict
|
|
134
|
+
total_losses = defaultdict(float)
|
|
135
135
|
|
|
136
136
|
progress_bar = tqdm(dataloader, desc = f"Epoch {epoch}", disable = not accelerator.is_local_main_process)
|
|
137
137
|
|
|
@@ -200,9 +200,9 @@ def train(
|
|
|
200
200
|
optim.zero_grad()
|
|
201
201
|
|
|
202
202
|
# log
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
203
|
+
|
|
204
|
+
for key, value in log.items():
|
|
205
|
+
total_losses[key] += value
|
|
206
206
|
|
|
207
207
|
accelerator.log({
|
|
208
208
|
**log,
|
|
@@ -210,15 +210,11 @@ def train(
|
|
|
210
210
|
"grad_norm": grad_norm.item()
|
|
211
211
|
})
|
|
212
212
|
|
|
213
|
-
progress_bar.set_postfix(
|
|
214
|
-
state_loss = state_loss.item(),
|
|
215
|
-
action_loss = action_loss.item()
|
|
216
|
-
)
|
|
217
|
-
|
|
218
|
-
avg_state_loss = total_state_loss / len(dataloader)
|
|
219
|
-
avg_action_loss = total_action_loss / len(dataloader)
|
|
213
|
+
progress_bar.set_postfix(**log)
|
|
220
214
|
|
|
221
|
-
|
|
215
|
+
avg_losses = {k: v / len(dataloader) for k, v in total_losses.items()}
|
|
216
|
+
avg_losses_str = ", ".join([f"{k}={v:.4f}" for k, v in avg_losses.items()])
|
|
217
|
+
accelerator.print(f"Epoch {epoch}: {avg_losses_str}")
|
|
222
218
|
|
|
223
219
|
# save weights
|
|
224
220
|
|
|
@@ -231,7 +227,7 @@ def train(
|
|
|
231
227
|
unwrapped_meta_controller = accelerator.unwrap_model(meta_controller)
|
|
232
228
|
unwrapped_meta_controller.save(meta_controller_checkpoint_path)
|
|
233
229
|
|
|
234
|
-
accelerator.print(f"Model saved to {checkpoint_path},
|
|
230
|
+
accelerator.print(f"Model saved to {checkpoint_path}, MetaController to {meta_controller_checkpoint_path}")
|
|
235
231
|
|
|
236
232
|
accelerator.end_training()
|
|
237
233
|
|
{metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/.github/workflows/python-publish.yml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.37}/tests/test_metacontroller.py
RENAMED
|
File without changes
|
|
File without changes
|