metacontroller-pytorch 0.0.41__tar.gz → 0.0.43__tar.gz
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/PKG-INFO +2 -2
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/README.md +1 -1
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/metacontroller/metacontroller.py +19 -1
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/metacontroller/metacontroller_with_binary_mapper.py +1 -1
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/pyproject.toml +1 -1
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/tests/test_metacontroller.py +16 -16
- metacontroller_pytorch-0.0.43/train_baby_evo_strat.py +213 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/train_babyai.py +21 -21
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/babyai_env.py +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/gather_babyai_trajs.py +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/metacontroller/transformer_with_resnet.py +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/test_babyai_e2e.sh +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/train_behavior_clone_babyai.py +0 -0
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.41
+Version: 0.0.43
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -53,7 +53,7 @@ Description-Content-Type: text/markdown
 
 <img src="./fig1.png" width="400px"></img>
 
-## metacontroller
+## metacontroller
 
 Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)
 
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/README.md
RENAMED
@@ -1,6 +1,6 @@
 <img src="./fig1.png" width="400px"></img>
 
-## metacontroller
+## metacontroller
 
 Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)
 
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/metacontroller/metacontroller.py
RENAMED
@@ -66,6 +66,13 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
     'switch_loss'
 ))
 
+GRPOOutput = namedtuple('GRPOOutput', (
+    'state',
+    'action',
+    'log_prob',
+    'switch_beta'
+))
+
 def z_score(t, eps = 1e-8):
     return (t - t.mean()) / (t.std() + eps)
 
@@ -107,6 +114,17 @@ def policy_loss(
 
     return masked_mean(losses, mask)
 
+def extract_grpo_data(meta_controller, transformer_output):
+    meta_output = transformer_output.prev_hiddens.meta_controller
+
+    state = meta_output.input_residual_stream
+    action = meta_output.actions
+    switch_beta = meta_output.switch_beta
+
+    log_prob = meta_controller.log_prob(meta_output.action_dist, action)
+
+    return GRPOOutput(state, action, log_prob, switch_beta)
+
 @save_load()
 class MetaController(Module):
     def __init__(
@@ -273,7 +291,7 @@ class MetaController(Module):
         else:
             # else during inference, use the previous sampled latent action
 
-            assert seq_len == 1,
+            assert seq_len == 1, 'inference RL phase must be done one token at a time - if replaying for policy optimization, please use `get_action_dist_for_internal_rl`'
             z_prev = prev_sampled_latent_action
 
         # switch input is previous latent action and the embedding
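Taken together, the new GRPOOutput namedtuple and extract_grpo_data helper replace the hand-rolled per-step bookkeeping in the callers updated below. A minimal sketch of the intended per-step usage, mirroring the updated test (model, meta_controller and the states tensor are assumed to be set up as in that test):

grpo_data_list = []
cache = None
past_action_id = None

for one_state in states.unbind(dim = 1):
    one_state = rearrange(one_state, 'b d -> b 1 d')

    # step the transformer one token at a time, threading the cache through
    logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, cache = cache, return_cache = True)

    past_action_id = model.action_readout.sample(logits)

    # read (state, action, log_prob, switch_beta) for this step off the returned cache
    grpo_data_list.append(extract_grpo_data(meta_controller, cache))

# GRPOOutput is a namedtuple, so zip(*...) transposes the per-step tuples into per-field sequences
step_states, actions, log_probs, switch_betas = zip(*grpo_data_list)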
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/metacontroller/metacontroller_with_binary_mapper.py
RENAMED
@@ -241,7 +241,7 @@ class MetaControllerWithBinaryMapper(Module):
         if discovery_phase:
             z_prev = cat((prev_sampled_code, sampled_codes[:, :-1]), dim = 1)
         else:
-            assert seq_len == 1,
+            assert seq_len == 1, 'inference RL phase must be done one token at a time - if replaying for policy optimization, please use `get_action_dist_for_internal_rl`'
             z_prev = prev_sampled_code
 
         switch_input = torch.cat((meta_embed, z_prev), dim=-1)
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/tests/test_metacontroller.py
RENAMED
@@ -7,7 +7,7 @@ from functools import partial
 
 import torch
 from torch import cat
-from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
+from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score, extract_grpo_data
 from metacontroller.metacontroller_with_binary_mapper import MetaControllerWithBinaryMapper
 
 from memmap_replay_buffer import ReplayBuffer
@@ -109,36 +109,29 @@ def test_metacontroller(
     cache = None
     past_action_id = None
 
-
-    log_probs = []
-    switch_betas = []
-    latent_actions = []
+    grpo_data_list = []
 
     for one_state in subset_state.unbind(dim = 1):
         one_state = rearrange(one_state, 'b d -> b 1 d')
 
-        logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, return_cache = True)
+        logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, cache = cache, return_cache = True)
 
         past_action_id = model.action_readout.sample(logits)
 
-        #
+        # extract grpo data and store
 
-
-
-        old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
-
-        states.append(meta_output.input_residual_stream)
-        log_probs.append(old_log_probs)
-        switch_betas.append(meta_output.switch_beta)
-        latent_actions.append(meta_output.actions)
+        grpo_data = extract_grpo_data(meta_controller, cache)
+        grpo_data_list.append(grpo_data)
 
     # accumulate across time for the episode data
 
+    states, actions, log_probs, switch_betas = zip(*grpo_data_list)
+
     all_episodes.append((
         cat(states, dim = 1),
         cat(log_probs, dim = 1),
         cat(switch_betas, dim = 1),
-        cat(latent_actions, dim = 1)
+        cat(actions, dim = 1)
     ))
 
     all_rewards.append(torch.randn(1))
@@ -153,6 +146,13 @@
     # simulate a policy loss update over the entire group
 
     group_states, group_log_probs, group_switch_betas, group_latent_actions = map(partial(cat, dim = 0), zip(*all_episodes))
+
+    # parallel verification
+
+    parallel_action_dist = meta_controller.get_action_dist_for_internal_rl(group_states)
+    parallel_log_probs = meta_controller.log_prob(parallel_action_dist, group_latent_actions)
+
+    assert torch.allclose(parallel_log_probs, group_log_probs, atol = 1e-5), 'parallel log probs do not match stored log probs'
 
     for states, log_probs, switch_betas, latent_actions, advantages in zip(group_states, group_log_probs, group_switch_betas, group_latent_actions, group_advantages):
         replay_buffer.store_episode(
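For context, z_score is still what turns the per-episode rewards gathered above into grouped advantages; a hypothetical sketch of that step (this diff does not show how the test actually builds group_advantages):

# hypothetical: GRPO-style advantages are the episode rewards z-scored across the group
group_rewards = torch.cat(all_rewards)       # (num_episodes,)
group_advantages = z_score(group_rewards)    # zero mean, unit variance across the group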
metacontroller_pytorch-0.0.43/train_baby_evo_strat.py
ADDED
@@ -0,0 +1,213 @@
+# /// script
+# dependencies = [
+#   "fire",
+#   "gymnasium",
+#   "gymnasium[other]",
+#   "metacontroller-pytorch",
+#   "minigrid",
+#   "tqdm",
+#   "x-evolution",
+#   "einops"
+# ]
+# ///
+
+from __future__ import annotations
+import fire
+from pathlib import Path
+from shutil import rmtree
+import numpy as np
+
+import torch
+from torch import nn, Tensor, tensor
+from torch.nn import Module
+from einops import rearrange
+
+from babyai_env import create_env
+from metacontroller.metacontroller import Transformer, MetaController
+
+# functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+# default fitness function
+
+def default_fitness_fn(
+    rewards: list[float],
+    states: list[any],
+    actions: list[any],
+    next_states: list[any],
+    infos: list[any]
+) -> float:
+    """
+    researchers can modify this function to engineer their own rewards and fitness scores
+    processing the entire episode at once for every noise vector of the population separately
+    """
+    return sum(rewards)
+
+# babyai environment for ES
+
+class BabyAIEnvironment(Module):
+    def __init__(
+        self,
+        env_id = 'BabyAI-BossLevel-v0',
+        video_folder = './recordings_babyai_es',
+        render_every_eps = 100,
+        max_steps = 500,
+        use_resnet = False,
+        fitness_fn = default_fitness_fn
+    ):
+        super().__init__()
+
+        self.env_id = env_id
+        self.video_folder = video_folder
+        self.render_every_eps = render_every_eps
+        self.max_steps = max_steps
+        self.use_resnet = use_resnet
+        self.fitness_fn = fitness_fn
+
+        # initial env creation for observation space etc. if needed
+        # but create_env is called inside pre_main_callback or reset
+        self.env = None
+
+    def pre_main_callback(self):
+        # clean up and initialize environment
+        rmtree(self.video_folder, ignore_errors = True)
+
+        self.env = create_env(
+            self.env_id,
+            render_mode = 'rgb_array',
+            video_folder = self.video_folder,
+            render_every_eps = self.render_every_eps
+        )
+
+    def forward(self, model):
+        device = next(model.parameters()).device
+
+        seed = torch.randint(0, int(1e6), ()).item()
+        state, _ = self.env.reset(seed = seed)
+
+        step = 0
+        cache = None
+        past_action_id = None
+
+        unwrapped_model = getattr(model, 'model', model)
+
+        episode_rewards = []
+        episode_states = []
+        episode_actions = []
+        episode_next_states = []
+        episode_infos = []
+
+        while step < self.max_steps:
+            image = state['image']
+            image_tensor = torch.from_numpy(image).float().to(device)
+
+            if self.use_resnet:
+                image_tensor = rearrange(image_tensor, 'h w c -> 1 1 h w c')
+                image_tensor = unwrapped_model.visual_encode(image_tensor)
+            else:
+                image_tensor = rearrange(image_tensor, 'h w c -> 1 1 (h w c)')
+
+            if torch.is_tensor(past_action_id):
+                past_action_id = past_action_id.long()
+
+            with torch.no_grad():
+                logits, cache = model(
+                    image_tensor,
+                    past_action_id,
+                    return_cache = True,
+                    return_raw_action_dist = True,
+                    cache = cache
+                )
+
+            action = unwrapped_model.action_readout.sample(logits)
+            past_action_id = action
+            action_id = action.squeeze()
+
+            next_state, reward, terminated, truncated, info = self.env.step(action_id.cpu().numpy().item())
+
+            episode_rewards.append(reward)
+            episode_states.append(state)
+            episode_actions.append(action_id)
+            episode_next_states.append(next_state)
+            episode_infos.append(info)
+
+            done = terminated or truncated
+            if done:
+                break
+
+            state = next_state
+            step += 1
+
+        return self.fitness_fn(
+            episode_rewards,
+            episode_states,
+            episode_actions,
+            episode_next_states,
+            episode_infos
+        )
+
+def main(
+    env_id = 'BabyAI-BossLevel-v0',
+    num_generations = 100,
+    max_steps = 500,
+    render_every_eps = 100,
+    video_folder = './recordings_babyai_es',
+    transformer_weights_path: str | None = None,
+    meta_controller_weights_path: str | None = None,
+    output_meta_controller_path = 'metacontroller_es_trained.pt',
+    use_resnet = False,
+    noise_population_size = 50,
+    noise_scale = 1e-2,
+    learning_rate = 1e-3,
+    fitness_fn = default_fitness_fn
+):
+    # load model
+
+    assert exists(transformer_weights_path), "Transformer weights must be provided"
+
+    # lazy import to avoid unnecessary dependencies if not used
+    from metacontroller.transformer_with_resnet import TransformerWithResnet as TransformerResnet
+    transformer_klass = TransformerResnet if use_resnet else Transformer
+
+    model = transformer_klass.init_and_load(transformer_weights_path, strict = False)
+    model.eval()
+
+    if exists(meta_controller_weights_path):
+        meta_controller = MetaController.init_and_load(meta_controller_weights_path, strict = False)
+        model.meta_controller = meta_controller
+
+    assert exists(model.meta_controller), "MetaController must be present for evolution"
+
+    # setup environment
+
+    babyai_env = BabyAIEnvironment(
+        env_id = env_id,
+        video_folder = video_folder,
+        render_every_eps = render_every_eps,
+        max_steps = max_steps,
+        use_resnet = use_resnet,
+        fitness_fn = fitness_fn
+    )
+
+    # evolve
+
+    model.evolve(
+        num_generations = num_generations,
+        environment = babyai_env,
+        noise_population_size = noise_population_size,
+        noise_scale = noise_scale,
+        learning_rate = learning_rate
+    )
+
+    # save
+
+    model.meta_controller.save(output_meta_controller_path)
+    print(f'MetaController weights saved to {output_meta_controller_path}')
+
+if __name__ == '__main__':
+    fire.Fire(main)
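Since default_fitness_fn simply sums the per-step rewards, a custom shaping function can be swapped in through the fitness_fn argument of main. A hypothetical example (the step penalty and its weight are illustrative, not part of the package):

def fitness_with_step_penalty(rewards, states, actions, next_states, infos):
    # total episode reward minus a small cost per step taken, to favor shorter solutions
    return sum(rewards) - 0.01 * len(rewards)

# hypothetical invocation, the weights path is illustrative
# main(transformer_weights_path = './transformer.pt', fitness_fn = fitness_with_step_penalty)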
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/train_babyai.py
RENAMED
@@ -26,7 +26,7 @@ from accelerate import Accelerator
 
 from babyai_env import create_env
 from memmap_replay_buffer import ReplayBuffer
-from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
+from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score, extract_grpo_data
 from metacontroller.transformer_with_resnet import TransformerWithResnet
 
 # research entry point
@@ -57,22 +57,23 @@ def default(v, d):
 # main
 
 def main(
-    env_name
-    num_episodes
-    max_timesteps
-    buffer_size
-    render_every_eps
-    video_folder
+    env_name = 'BabyAI-BossLevel-v0',
+    num_episodes = int(10e6),
+    max_timesteps = 500,
+    buffer_size = 5_000,
+    render_every_eps = 1_000,
+    video_folder = './recordings',
     seed: int | None = None,
     transformer_weights_path: str | None = None,
     meta_controller_weights_path: str | None = None,
-    output_meta_controller_path
-    use_resnet
-    lr
-
-
-
-
+    output_meta_controller_path = 'metacontroller_rl_trained.pt',
+    use_resnet = False,
+    lr = 1e-4,
+    batch_size = 16,
+    num_groups = 16,
+    max_grad_norm = 1.0,
+    use_wandb = False,
+    wandb_project = 'metacontroller-babyai-rl'
 ):
     # accelerator
 
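With defaults now supplied for everything except the optional weight paths and seed, the RL script can presumably be launched with only a transformer checkpoint, overriding anything else as needed; a hypothetical call (the path is illustrative, and Fire exposes the same arguments as CLI flags):

from train_babyai import main

main(
    transformer_weights_path = './transformer_behavior_cloned.pt',  # illustrative path
    use_wandb = True,   # logs to the default 'metacontroller-babyai-rl' project
    num_groups = 16,
    batch_size = 16
)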
@@ -195,13 +196,12 @@
 
             # GRPO collection
 
-
-            old_log_probs = unwrapped_meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
+            grpo_data = extract_grpo_data(unwrapped_meta_controller, cache)
 
-            states.append(
-            log_probs.append(
-            switch_betas.append(
-            latent_actions.append(
+            states.append(grpo_data.state)
+            log_probs.append(grpo_data.log_prob)
+            switch_betas.append(grpo_data.switch_beta)
+            latent_actions.append(grpo_data.action)
 
             next_state, reward, terminated, truncated, *_ = env.step(action.cpu().numpy())
 
@@ -264,7 +264,7 @@
         # learn
 
         if len(replay_buffer) >= buffer_size:
-            dl = replay_buffer.dataloader(batch_size =
+            dl = replay_buffer.dataloader(batch_size = batch_size)
             dl = accelerator.prepare(dl)
 
             meta_controller.train()
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/.github/workflows/python-publish.yml
RENAMED
File without changes
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/.github/workflows/test.yml
RENAMED
File without changes
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/.gitignore
RENAMED
File without changes
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/LICENSE
RENAMED
File without changes
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/babyai_env.py
RENAMED
File without changes
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/fig1.png
RENAMED
File without changes
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/gather_babyai_trajs.py
RENAMED
File without changes
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/metacontroller/__init__.py
RENAMED
File without changes
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/metacontroller/transformer_with_resnet.py
RENAMED
File without changes
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/test_babyai_e2e.sh
RENAMED
File without changes
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/train_behavior_clone_babyai.py
RENAMED
File without changes