metacontroller-pytorch 0.0.34__tar.gz → 0.0.36__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of metacontroller-pytorch might be problematic.
- {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/PKG-INFO +2 -2
- {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/metacontroller/metacontroller.py +19 -7
- {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/metacontroller/metacontroller_with_binary_mapper.py +9 -2
- {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/pyproject.toml +2 -2
- {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/tests/test_metacontroller.py +64 -21
- {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/train_behavior_clone_babyai.py +82 -25
- {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/README.md +0 -0
- {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/gather_babyai_trajs.py +0 -0
- {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/metacontroller/transformer_with_resnet.py +0 -0
- {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/test_babyai_e2e.sh +0 -0
- {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/train_babyai.py +0 -0
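Taken together, 0.0.36 raises the memmap-replay-buffer floor to 0.0.25, attaches a shared policy_loss method to both MetaController and MetaControllerWithBinaryMapper, squeezes the trailing switch_beta axis when a single switch gate is used, reworks the test to round-trip GRPO episodes through a ReplayBuffer, and extends the BabyAI behavior-cloning script with a meta-controller discovery phase, a second optimizer, and a separate checkpoint.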
{metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.34
+Version: 0.0.36
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -39,7 +39,7 @@ Requires-Dist: discrete-continuous-embed-readout>=0.1.12
 Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
 Requires-Dist: loguru
-Requires-Dist: memmap-replay-buffer>=0.0.
+Requires-Dist: memmap-replay-buffer>=0.0.25
 Requires-Dist: torch-einops-utils>=0.0.19
 Requires-Dist: torch>=2.5
 Requires-Dist: vector-quantize-pytorch>=1.27.20
{metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/metacontroller/metacontroller.py
RENAMED
@@ -329,8 +329,15 @@ class MetaController(Module):
             sampled_latent_action[:, -1:]
         )
 
+        # squeeze out the last dimension of switch_beta if single gate for all latent dimensions
+
+        if not self.switch_per_latent_dim:
+            switch_beta = rearrange(switch_beta, '... 1 -> ...')
+
         return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, switch_beta, kl_loss, switch_loss)
 
+MetaController.policy_loss = policy_loss
+
 # main transformer, which is subsumed into the environment after behavioral cloning
 
 Hiddens = namedtuple('Hiddens', (
@@ -409,17 +416,21 @@ class Transformer(Module):
     ):
         device = state.device
 
+        # meta controller is either given or already given at init
+
         meta_controller = default(meta_controller, self.meta_controller)
 
-
+        has_meta_controller = exists(meta_controller)
+
+        assert not (discovery_phase and not has_meta_controller), 'meta controller must be made available during discovery phase'
 
-        behavioral_cloning = not
+        behavioral_cloning = not has_meta_controller and not return_raw_action_dist
 
         # by default, if meta controller is passed in, transformer is no grad
 
-        lower_transformer_context = nullcontext if not
-        meta_controller_context = nullcontext if
-        upper_transformer_context = nullcontext if (not
+        lower_transformer_context = nullcontext if not has_meta_controller else torch.no_grad
+        meta_controller_context = nullcontext if has_meta_controller else torch.no_grad
+        upper_transformer_context = nullcontext if (not has_meta_controller or discovery_phase) else torch.no_grad
 
         # handle cache
 
@@ -427,7 +438,8 @@ class Transformer(Module):
 
         # handle maybe behavioral cloning
 
-        if behavioral_cloning or
+        if behavioral_cloning or discovery_phase: # during behavior cloning and discovery phase, the network is predicting / reconstructing the next token
+
             assert exists(actions), f'`actions` cannot be empty when doing discovery or behavioral cloning'
 
             state, target_state = state[:, :-1], state[:, 1:]
@@ -490,7 +502,7 @@ class Transformer(Module):
 
             return state_clone_loss, action_clone_loss
 
-        elif
+        elif discovery_phase:
 
             action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)
 
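The new `MetaController.policy_loss = policy_loss` line relies on a basic Python behavior: a plain function assigned on a class becomes a bound method on its instances, so one loss implementation can be shared by both controller variants. A minimal, self-contained sketch of the pattern; the toy classes and loss body below are illustrative stand-ins, not the package's actual code:

import torch
from torch import nn

# hypothetical shared objective; only the attachment pattern mirrors the diff
def policy_loss(self, log_probs, old_log_probs, advantages):
    # importance-weighted policy gradient, reduced over the batch
    ratio = (log_probs - old_log_probs).exp()
    return -(ratio * advantages).mean()

class ControllerA(nn.Module):
    pass

class ControllerB(nn.Module):
    pass

# a function assigned on the class becomes a bound method on instances
ControllerA.policy_loss = policy_loss
ControllerB.policy_loss = policy_loss

controller = ControllerA()
log_probs = torch.randn(4, requires_grad = True)
loss = controller.policy_loss(log_probs, log_probs.detach(), torch.randn(4))
loss.backward()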
{metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/metacontroller/metacontroller_with_binary_mapper.py
RENAMED
@@ -28,7 +28,7 @@ from torch_einops_utils.save_load import save_load
 
 from vector_quantize_pytorch import BinaryMapper
 
-from metacontroller.metacontroller import MetaControllerOutput
+from metacontroller.metacontroller import MetaControllerOutput, policy_loss
 
 # constants
 
@@ -170,7 +170,7 @@ class MetaControllerWithBinaryMapper(Module):
         action_log_probs = log_probs.gather(-1, codes)
         action_log_probs = rearrange(action_log_probs, '... 1 -> ...')
 
-        return action_log_probs
+        return action_log_probs
 
     def forward(
         self,
@@ -296,4 +296,11 @@ class MetaControllerWithBinaryMapper(Module):
             sampled_codes[:, -1:]
         )
 
+        # squeeze out the last dimension of switch_beta if single gate for all codes
+
+        if not self.switch_per_code:
+            switch_beta = rearrange(switch_beta, '... 1 -> ...')
+
         return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, switch_beta, kl_loss, switch_loss)
+
+MetaControllerWithBinaryMapper.policy_loss = policy_loss
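Both controller files gain the same shape fix: when a single switch gate covers all latent dimensions (or codes), `switch_beta` carries a trailing axis of size 1 that downstream consumers do not expect, so it is squeezed away with einops. A quick shape check, assuming (batch, time, gates) dimensions:

import torch
from einops import rearrange

switch_per_code = False
num_gates = 8 if switch_per_code else 1

# (batch, time, gates), with gates = 1 when a single gate covers all codes
switch_beta = torch.rand(2, 16, num_gates)

if not switch_per_code:
    # '... 1 -> ...' drops the trailing singleton axis
    switch_beta = rearrange(switch_beta, '... 1 -> ...')

print(switch_beta.shape)  # torch.Size([2, 16])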
{metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/pyproject.toml
RENAMED
@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.34"
+version = "0.0.36"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -29,7 +29,7 @@ dependencies = [
     "einx>=0.3.0",
     "einops>=0.8.1",
     "loguru",
-    "memmap-replay-buffer>=0.0.
+    "memmap-replay-buffer>=0.0.25",
     "torch>=2.5",
     "torch-einops-utils>=0.0.19",
     "vector-quantize-pytorch>=1.27.20",
{metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/tests/test_metacontroller.py
RENAMED
@@ -1,13 +1,17 @@
 import pytest
 param = pytest.mark.parametrize
 
+from shutil import rmtree
 from pathlib import Path
+from functools import partial
 
 import torch
 from torch import cat
 from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
 from metacontroller.metacontroller_with_binary_mapper import MetaControllerWithBinaryMapper
 
+from memmap_replay_buffer import ReplayBuffer
+
 from einops import rearrange
 
 # functions
@@ -65,6 +69,12 @@ def test_metacontroller(
             dim_latent = 128,
             switch_per_latent_dim = switch_per_latent_dim
         )
+
+        field_shapes = dict(
+            log_probs = ('float', 128),
+            switch_betas = ('float', 128 if switch_per_latent_dim else 1),
+            latent_actions = ('float', 128)
+        )
     else:
         meta_controller = MetaControllerWithBinaryMapper(
             dim_model = 512,
@@ -73,6 +83,12 @@ def test_metacontroller(
             dim_code_bits = 8, # 2 ** 8 = 256 codes
         )
 
+        field_shapes = dict(
+            log_probs = ('float', 8),
+            switch_betas = ('float', 8 if switch_per_latent_dim else 1),
+            latent_actions = ('float', 256)
+        )
+
     # discovery phase
 
     (action_recon_loss, kl_loss, switch_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True, episode_lens = episode_lens)
@@ -80,6 +96,23 @@ def test_metacontroller(
 
     # internal rl - done iteratively
 
+    # replay buffer
+
+    test_folder = './test-buffer-for-grpo'
+
+    replay_buffer = ReplayBuffer(
+        test_folder,
+        max_episodes = 3,
+        max_timesteps = 256,
+        fields = dict(
+            states = ('float', 512),
+            **field_shapes
+        ),
+        meta_fields = dict(
+            advantages = 'float'
+        )
+    )
+
     # simulate grpo
 
     all_episodes = []
@@ -116,11 +149,11 @@ def test_metacontroller(
 
         # accumulate across time for the episode data
 
-        all_episodes.append(
-
-
-
-
+        all_episodes.append((
+            cat(states, dim = 1),
+            cat(log_probs, dim = 1),
+            cat(switch_betas, dim = 1),
+            cat(latent_actions, dim = 1)
         ))
 
        all_rewards.append(torch.randn(1))
@@ -128,29 +161,37 @@ def test_metacontroller(
     # calculate advantages using z-score
 
     rewards = cat(all_rewards)
-
+    group_advantages = z_score(rewards)
 
-    assert
+    assert group_advantages.shape == (3,)
 
     # simulate a policy loss update over the entire group
 
-    group_states
-    group_log_probs = cat([e['log_probs'] for e in all_episodes], dim = 0)
-    group_latent_actions = cat([e['latent_actions'] for e in all_episodes], dim = 0)
-    group_switch_betas = cat([e['switch_betas'] for e in all_episodes], dim = 0)
+    group_states, group_log_probs, group_switch_betas, group_latent_actions = map(partial(cat, dim = 0), zip(*all_episodes))
 
-
-
-
-
-
-
-        advantages
-        group_switch_betas == 1.,
-        episode_lens = episode_lens[:1].repeat(3) if exists(episode_lens) else None
+    for states, log_probs, switch_betas, latent_actions, advantages in zip(group_states, group_log_probs, group_switch_betas, group_latent_actions, group_advantages):
+        replay_buffer.store_episode(
+            states = states,
+            log_probs = log_probs,
+            switch_betas = switch_betas,
+            latent_actions = latent_actions,
+            advantages = advantages
        )
 
-
+    dl = replay_buffer.dataloader(batch_size = 3)
+
+    batch = next(iter(dl))
+
+    loss = meta_controller.policy_loss(
+        batch['states'],
+        batch['log_probs'],
+        batch['latent_actions'],
+        batch['advantages'],
+        batch['switch_betas'] == 1.,
+        episode_lens = batch['_lens']
+    )
+
+    loss.backward()
 
     # evolutionary strategies over grpo
 
@@ -170,3 +211,5 @@ def test_metacontroller(
 
     Path('./meta_controller.pt').unlink()
     Path('./trained.pt').unlink()
+
+    rmtree(test_folder, ignore_errors = True)
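The reworked test now round-trips episodes through a memmap-backed ReplayBuffer before the policy update, with per-episode advantages computed as a z-score over the group's rewards, GRPO style. A standalone sketch of that advantage computation; this `z_score` is an assumed stand-in for the one imported from `metacontroller.metacontroller`:

import torch

def z_score(t, eps = 1e-5):
    # normalize to zero mean and unit variance across the group
    return (t - t.mean()) / (t.std() + eps)

rewards = torch.randn(3)             # one scalar reward per episode in the group
group_advantages = z_score(rewards)  # shape (3,), as the test asserts

assert group_advantages.shape == (3,)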
{metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/train_behavior_clone_babyai.py
RENAMED
@@ -25,29 +25,35 @@ from accelerate import Accelerator
 from memmap_replay_buffer import ReplayBuffer
 from einops import rearrange
 
-from metacontroller.metacontroller import Transformer
+from metacontroller.metacontroller import Transformer, MetaController
 from metacontroller.transformer_with_resnet import TransformerWithResnet
 
 import minigrid
 import gymnasium as gym
 
 def train(
-    input_dir
-    env_id
-    cloning_epochs
-    discovery_epochs
-    batch_size
-    lr
-
-
-
-
-
-
-
-
-
-
+    input_dir = "babyai-minibosslevel-trajectories",
+    env_id = "BabyAI-MiniBossLevel-v0",
+    cloning_epochs = 10,
+    discovery_epochs = 10,
+    batch_size = 32,
+    lr = 1e-4,
+    discovery_lr = 1e-4,
+    dim = 512,
+    depth = 2,
+    heads = 8,
+    dim_head = 64,
+    use_wandb = False,
+    wandb_project = "metacontroller-babyai-bc",
+    checkpoint_path = "transformer_bc.pt",
+    meta_controller_checkpoint_path = "meta_controller_discovery.pt",
+    state_loss_weight = 1.,
+    action_loss_weight = 1.,
+    discovery_action_recon_loss_weight = 1.,
+    discovery_kl_loss_weight = 1.,
+    discovery_switch_loss_weight = 1.,
+    max_grad_norm = 1.,
+    use_resnet = False
 ):
     # accelerator
 
@@ -96,6 +102,7 @@ def train(
     # transformer
 
     transformer_class = TransformerWithResnet if use_resnet else Transformer
+
     model = transformer_class(
         dim = dim,
         state_embed_readout = dict(num_continuous = state_dim),
@@ -104,23 +111,34 @@ def train(
         upper_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head)
     )
 
+    # meta controller
+
+    meta_controller = MetaController(dim)
+
     # optimizer
 
-
+    optim_model = Adam(model.parameters(), lr = lr)
+
+    optim_meta_controller = Adam(meta_controller.discovery_parameters(), lr = discovery_lr)
 
     # prepare
 
-    model,
+    model, optim_model, optim_meta_controller, dataloader = accelerator.prepare(model, optim_model, optim_meta_controller, dataloader)
 
     # training
+
     for epoch in range(cloning_epochs + discovery_epochs):
+
         model.train()
         total_state_loss = 0.
         total_action_loss = 0.
 
         progress_bar = tqdm(dataloader, desc = f"Epoch {epoch}", disable = not accelerator.is_local_main_process)
+
         is_discovering = (epoch >= cloning_epochs) # discovery phase is BC with metacontroller tuning
 
+        optim = optim_model if not is_discovering else optim_meta_controller
+
         for batch in progress_bar:
             # batch is a NamedTuple (e.g. MemoryMappedBatch)
             # state: (B, T, 7, 7, 3), action: (B, T)
@@ -130,18 +148,53 @@ def train(
             episode_lens = batch.get('_lens')
 
             # use resnet18 to embed visual observations
+
             if use_resnet:
                 states = model.visual_encode(states)
             else: # flatten state: (B, T, 7, 7, 3) -> (B, T, 147)
                 states = rearrange(states, 'b t ... -> b t (...)')
 
             with accelerator.accumulate(model):
-
-
+                losses = model(
+                    states,
+                    actions,
+                    episode_lens = episode_lens,
+                    discovery_phase = is_discovering,
+                    meta_controller = meta_controller if is_discovering else None
+                )
+
+                if is_discovering:
+                    action_recon_loss, kl_loss, switch_loss = losses
+
+                    loss = (
+                        action_recon_loss * discovery_action_recon_loss_weight +
+                        kl_loss * discovery_kl_loss_weight +
+                        switch_loss * discovery_switch_loss_weight
+                    )
+
+                    log = dict(
+                        action_recon_loss = action_recon_loss.item(),
+                        kl_loss = kl_loss.item(),
+                        switch_loss = switch_loss.item()
+                    )
+                else:
+                    state_loss, action_loss = losses
+
+                    loss = (
+                        state_loss * state_loss_weight +
+                        action_loss * action_loss_weight
+                    )
+
+                    log = dict(
+                        state_loss = state_loss.item(),
+                        action_loss = action_loss.item(),
+                    )
+
+                # backprop
 
                 accelerator.backward(loss)
 
-                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=
+                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm = max_grad_norm)
 
                 optim.step()
                 optim.zero_grad()
@@ -152,8 +205,7 @@ def train(
                 total_action_loss += action_loss.item()
 
                 accelerator.log({
-
-                    "action_loss": action_loss.item(),
+                    **log,
                     "total_loss": loss.item(),
                     "grad_norm": grad_norm.item()
                 })
@@ -172,9 +224,14 @@ def train(
 
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save(checkpoint_path)
-
+
+        unwrapped_meta_controller = accelerator.unwrap_model(meta_controller)
+        unwrapped_meta_controller.save(meta_controller_checkpoint_path)
+
+        accelerator.print(f"Model saved to {checkpoint_path}, MetaControler to {meta_controller_checkpoint_path}")
 
     accelerator.end_training()
 
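The training script now runs two phases from one loop: `cloning_epochs` of behavioral cloning that update the transformer, followed by `discovery_epochs` that freeze it and step only the meta controller's optimizer. A minimal sketch of that schedule, with toy linear modules standing in for the transformer and MetaController:

import torch
from torch import nn
from torch.optim import Adam

model = nn.Linear(4, 4)            # stand-in for the transformer
meta_controller = nn.Linear(4, 4)  # stand-in for the MetaController

cloning_epochs, discovery_epochs = 2, 2

optim_model = Adam(model.parameters(), lr = 1e-4)
optim_meta_controller = Adam(meta_controller.parameters(), lr = 1e-4)

for epoch in range(cloning_epochs + discovery_epochs):
    is_discovering = epoch >= cloning_epochs

    # pick the optimizer for the current phase, as the diff does
    optim = optim_meta_controller if is_discovering else optim_model

    x = torch.randn(8, 4)

    if is_discovering:
        # transformer output detached, mirroring the no-grad lower transformer
        out = meta_controller(model(x).detach())
    else:
        out = model(x)

    loss = out.pow(2).mean()

    loss.backward()
    optim.step()
    optim.zero_grad()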
{metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.36}/.github/workflows/python-publish.yml
RENAMED
File without changes

The remaining ten renamed files (.github/workflows/test.yml, .gitignore, LICENSE, README.md, fig1.png, gather_babyai_trajs.py, metacontroller/__init__.py, metacontroller/transformer_with_resnet.py, test_babyai_e2e.sh, train_babyai.py) likewise have no content changes.