metacontroller-pytorch 0.0.25__tar.gz → 0.0.27__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of metacontroller-pytorch might be problematic.
- {metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/PKG-INFO +2 -1
- {metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/metacontroller/metacontroller.py +8 -5
- metacontroller_pytorch-0.0.27/metacontroller/metacontroller_with_binary_mapper.py +266 -0
- {metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/pyproject.toml +2 -1
- {metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/tests/test_metacontroller.py +25 -13
- {metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/README.md +0 -0
- {metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/gather_babyai_trajs.py +0 -0
- {metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/test_babyai_e2e.sh +0 -0
- {metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/train_babyai.py +0 -0
- {metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/train_behavior_clone_babyai.py +0 -0
{metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/PKG-INFO
RENAMED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.25
+Version: 0.0.27
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller

@@ -42,6 +42,7 @@ Requires-Dist: loguru
 Requires-Dist: memmap-replay-buffer>=0.0.23
 Requires-Dist: torch-einops-utils>=0.0.16
 Requires-Dist: torch>=2.5
+Requires-Dist: vector-quantize-pytorch>=1.27.20
 Requires-Dist: x-evolution>=0.1.23
 Requires-Dist: x-mlps-pytorch
 Requires-Dist: x-transformers
{metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/metacontroller/metacontroller.py
RENAMED
@@ -6,7 +6,7 @@ from collections import namedtuple
 from loguru import logger
 
 import torch
-from torch import nn, cat, stack, tensor
+from torch import nn, cat, stack, tensor, Tensor
 from torch.nn import Module, GRU, Linear, Identity
 import torch.nn.functional as F
 
@@ -26,7 +26,7 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout
 
 from assoc_scan import AssocScan
 
-from torch_einops_utils import pad_at_dim, lens_to_mask
+from torch_einops_utils import maybe, pad_at_dim, lens_to_mask
 from torch_einops_utils.save_load import save_load
 
 # constants

@@ -151,7 +151,8 @@ class MetaController(Module):
         cache: MetaControllerOutput | None = None,
         discovery_phase = False,
         hard_switch = False,
-        temperature = 1
+        temperature = 1.,
+        episode_lens: Tensor | None = None
     ):
         device = residual_stream.device
 
@@ -168,7 +169,9 @@ class MetaController(Module):
         if discovery_phase:
             logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
 
-            encoded_temporal = self.bidirectional_temporal_encoder(meta_embed)
+            mask = maybe(lens_to_mask)(episode_lens, meta_embed.shape[1])
+
+            encoded_temporal = self.bidirectional_temporal_encoder(meta_embed, mask = mask)
 
             proposed_action_hidden, _ = self.emitter(cat((encoded_temporal, meta_embed), dim = -1))
             readout = self.emitter_to_action_mean_log_var

@@ -391,7 +394,7 @@ class Transformer(Module):
         with meta_controller_context():
 
             if exists(meta_controller):
-                control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature)
+                control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature, episode_lens = episode_lens)
             else:
                 control_signal, next_meta_hiddens = self.zero, None
 
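For context on the new episode_lens argument: during the discovery phase the bidirectional encoder now receives a boolean mask built from per-episode lengths, so padded timesteps are ignored when episodes in a batch have different lengths. The helpers come from torch-einops-utils; the stand-ins below are a minimal sketch of their assumed behavior, not the library's actual implementation.

import torch

# illustrative stand-ins for torch_einops_utils.maybe / lens_to_mask (assumed behavior)

def lens_to_mask(lens, max_len):
    # (batch,) lengths -> (batch, max_len) bool mask, True where timestep < length
    seq = torch.arange(max_len, device = lens.device)
    return seq[None, :] < lens[:, None]

def maybe(fn):
    # only call fn when the first argument exists, otherwise propagate None
    def inner(x, *args, **kwargs):
        if x is None:
            return None
        return fn(x, *args, **kwargs)
    return inner

episode_lens = torch.tensor([64, 128])
mask = maybe(lens_to_mask)(episode_lens, 128)   # (2, 128) bool, second row all True
no_mask = maybe(lens_to_mask)(None, 128)        # None, so the encoder runs unmasked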
metacontroller_pytorch-0.0.27/metacontroller/metacontroller_with_binary_mapper.py
ADDED

@@ -0,0 +1,266 @@
+from __future__ import annotations
+from contextlib import nullcontext
+
+from functools import partial
+from collections import namedtuple
+from loguru import logger
+
+import torch
+from torch import nn, cat, stack, tensor, Tensor
+from torch.nn import Module, GRU, Linear, Identity
+import torch.nn.functional as F
+
+# einops
+
+import einx
+from einops import einsum, rearrange, repeat, reduce
+from einops.layers.torch import Rearrange
+
+# external modules
+
+from x_transformers import Encoder, Decoder
+from x_mlps_pytorch import Feedforwards
+
+from assoc_scan import AssocScan
+
+from torch_einops_utils import maybe, pad_at_dim, lens_to_mask
+from torch_einops_utils.save_load import save_load
+
+from vector_quantize_pytorch import BinaryMapper
+
+# constants
+
+LinearNoBias = partial(Linear, bias = False)
+
+GRU = partial(GRU, batch_first = True)
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def default(*args):
+    for arg in args:
+        if exists(arg):
+            return arg
+    return None
+
+def straight_through(src, tgt):
+    return tgt + src - src.detach()
+
+# meta controller
+
+MetaControllerOutput = namedtuple('MetaControllerOutput', (
+    'prev_hiddens',
+    'action_dist',
+    'codes',
+    'kl_loss',
+    'switch_loss'
+))
+
+@save_load()
+class MetaControllerWithBinaryMapper(Module):
+    def __init__(
+        self,
+        dim_model,
+        *,
+        dim_meta_controller = 256,
+        dim_code_bits = 4,
+        switch_per_code = False,
+        decoder_expansion_factor = 2.,
+        decoder_depth = 1,
+        hypernetwork_low_rank = 16,
+        assoc_scan_kwargs: dict = dict(),
+        bidirectional_temporal_encoder_kwargs: dict = dict(
+            attn_dim_head = 32, heads = 8
+        ),
+        kl_loss_threshold = 0.
+    ):
+        super().__init__()
+        dim_meta = default(dim_meta_controller, dim_model)
+
+        self.model_to_meta = Linear(dim_model, dim_meta)
+
+        self.bidirectional_temporal_encoder = Encoder(dim = dim_meta, depth = 1, **bidirectional_temporal_encoder_kwargs)
+
+        self.emitter = GRU(dim_meta * 2, dim_meta * 2)
+        self.emitter_to_binary_logits = Linear(dim_meta * 2, dim_code_bits)
+
+        self.action_proposer = GRU(dim_meta, dim_meta)
+        self.proposer_to_binary_logits = Linear(dim_meta, dim_code_bits)
+
+        # binary mapper
+        # proposed in https://arxiv.org/abs/2510.17558 as a more stable alternative to VAE by François Fleuret
+
+        self.binary_mapper = BinaryMapper(
+            bits = dim_code_bits,
+            kl_loss_threshold = kl_loss_threshold
+        )
+
+        self.dim_code_bits = dim_code_bits
+        self.num_codes = self.binary_mapper.num_codes
+
+        # switching unit
+
+        self.switch_per_code = switch_per_code
+
+        self.switching_unit = GRU(dim_meta + self.num_codes, dim_meta)
+        self.to_switching_unit_beta = nn.Linear(dim_meta, self.num_codes if switch_per_code else 1, bias = False)
+
+        self.switch_gating = AssocScan(**assoc_scan_kwargs)
+
+        # decoder
+
+        assert hypernetwork_low_rank < self.num_codes
+
+        dim_decoder_hidden = int(self.num_codes * decoder_expansion_factor)
+
+        self.decoder = Feedforwards(
+            dim_in = self.num_codes,
+            dim = dim_decoder_hidden,
+            depth = decoder_depth,
+            dim_out = 2 * hypernetwork_low_rank * dim_model
+        )
+
+        self.to_hyper_network_weights = Rearrange('... (two d r) -> two ... d r', two = 2, r = hypernetwork_low_rank)
+
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
+    def discovery_parameters(self):
+        return [
+            *self.model_to_meta.parameters(),
+            *self.bidirectional_temporal_encoder.parameters(),
+            *self.emitter.parameters(),
+            *self.emitter_to_binary_logits.parameters(),
+            *self.binary_mapper.parameters(),
+            *self.decoder.parameters(),
+            *self.switch_gating.parameters()
+        ]
+
+    def internal_rl_parameters(self):
+        return [
+            *self.action_proposer.parameters(),
+            *self.proposer_to_binary_logits.parameters()
+        ]
+
+    def forward(
+        self,
+        residual_stream,
+        cache: MetaControllerOutput | None = None,
+        discovery_phase = False,
+        hard_switch = False,
+        temperature = 1.,
+        episode_lens: Tensor | None = None
+    ):
+        device = residual_stream.device
+
+        # destruct prev cache
+
+        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens, prev_sampled_code = cache.prev_hiddens if exists(cache) else ((None,) * 4)
+
+        # getting proposed action for the two phases
+
+        next_action_proposer_hidden = None
+
+        meta_embed = self.model_to_meta(residual_stream)
+
+        if discovery_phase:
+            mask = maybe(lens_to_mask)(episode_lens, meta_embed.shape[1])
+
+            encoded_temporal = self.bidirectional_temporal_encoder(meta_embed, mask = mask)
+
+            proposed_action_hidden, _ = self.emitter(cat((encoded_temporal, meta_embed), dim = -1))
+            to_logits = self.emitter_to_binary_logits
+
+        else: # else internal rl phase
+
+            proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(meta_embed, prev_action_proposer_hidden)
+            to_logits = self.proposer_to_binary_logits
+
+        # sample from the binary mapper
+
+        binary_logits = to_logits(proposed_action_hidden)
+
+        one_hot, kl_loss = self.binary_mapper(
+            binary_logits,
+            temperature = temperature,
+            reduce_aux_kl_loss = False
+        )
+
+        # bottled action is now the one-hot sparse codes (with straight-through)
+
+        sampled_codes = one_hot
+
+        # switching unit timer
+
+        batch, seq_len, dim = sampled_codes.shape
+
+        if not exists(prev_sampled_code):
+            prev_sampled_code = torch.zeros(batch, 1, self.num_codes, device = device)
+
+        if discovery_phase:
+            z_prev = cat((prev_sampled_code, sampled_codes[:, :-1]), dim = 1)
+        else:
+            assert seq_len == 1, f'inference RL phase must be done one token at a time'
+            z_prev = prev_sampled_code
+
+        switch_input = torch.cat((meta_embed, z_prev), dim=-1)
+
+        switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(
+            switch_input,
+            prev_switching_unit_gru_hidden
+        )
+
+        switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
+
+        # losses
+
+        switch_loss = self.zero
+
+        if discovery_phase:
+            # weight unreduced kl loss by switch gates
+
+            weighted_kl_loss = kl_loss * switch_beta
+            kl_loss = weighted_kl_loss.sum(dim = -1).mean()
+
+            # encourage less switching
+
+            switch_loss = switch_beta.mean()
+        else:
+            kl_loss = self.zero
+
+        # maybe hard switch, then use associative scan
+
+        if hard_switch:
+            hard_switch_beta = (switch_beta > 0.5).float()
+            switch_beta = straight_through(switch_beta, hard_switch_beta)
+
+        forget = 1. - switch_beta
+
+        # gated codes (or soft distribution)
+
+        gated_codes = self.switch_gating(switch_beta, sampled_codes * forget, prev = prev_switch_gated_hiddens)
+
+        next_switch_gated_codes = gated_codes[:, -1]
+
+        # decoder
+
+        decoder_out = self.decoder(gated_codes)
+
+        w1, w2 = self.to_hyper_network_weights(decoder_out)
+        hypernetwork_weight = einsum(w1, w2, '... i r, ... j r -> ... i j')
+
+        # generating the residual stream controlling signal
+
+        control_signal = einsum(residual_stream, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
+
+        # returning
+
+        next_hiddens = (
+            next_action_proposer_hidden,
+            next_switching_unit_gru_hidden,
+            next_switch_gated_codes,
+            sampled_codes[:, -1:]
+        )
+
+        return control_signal, MetaControllerOutput(next_hiddens, binary_logits, sampled_codes, kl_loss, switch_loss)
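A short usage sketch of the new variant, mirroring what the updated test exercises. Calling the controller directly (rather than through Transformer) is only to illustrate the interface; the batch size, sequence length, and the 0.1 / 0.2 loss weights below are illustrative, not prescribed by the package.

import torch
from metacontroller.metacontroller_with_binary_mapper import MetaControllerWithBinaryMapper

meta_controller = MetaControllerWithBinaryMapper(
    dim_model = 512,
    dim_meta_controller = 256,
    dim_code_bits = 8   # 2**8 = 256 codes
)

residual_stream = torch.randn(2, 128, 512)

# discovery phase: bidirectional encoder + emitter propose per-timestep codes,
# variable-length episodes are handled through episode_lens
control_signal, out = meta_controller(
    residual_stream,
    discovery_phase = True,
    episode_lens = torch.tensor([64, 128])
)

(out.kl_loss * 0.1 + out.switch_loss * 0.2).backward()

# internal rl phase: one timestep at a time, feeding the returned output back in as cache
control_signal, cache = meta_controller(residual_stream[:, :1], discovery_phase = False)
control_signal, cache = meta_controller(residual_stream[:, 1:2], cache = cache, discovery_phase = False)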
{metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/pyproject.toml
RENAMED

@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.25"
+version = "0.0.27"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

@@ -32,6 +32,7 @@ dependencies = [
     "memmap-replay-buffer>=0.0.23",
     "torch>=2.5",
     "torch-einops-utils>=0.0.16",
+    "vector-quantize-pytorch>=1.27.20",
     "x-evolution>=0.1.23",
     "x-mlps-pytorch",
     "x-transformers"
{metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/tests/test_metacontroller.py
RENAMED
@@ -5,27 +5,30 @@ from pathlib import Path
 
 import torch
 from metacontroller.metacontroller import Transformer, MetaController
+from metacontroller.metacontroller_with_binary_mapper import MetaControllerWithBinaryMapper
 
 from einops import rearrange
 
+@param('use_binary_mapper_variant', (False, True))
 @param('action_discrete', (False, True))
 @param('switch_per_latent_dim', (False, True))
 @param('variable_length', (False, True))
 def test_metacontroller(
+    use_binary_mapper_variant,
     action_discrete,
     switch_per_latent_dim,
     variable_length
 ):
 
-    state = torch.randn(
-    episode_lens = torch.tensor([
+    state = torch.randn(2, 128, 384)
+    episode_lens = torch.tensor([64, 64]) if variable_length else None
 
     if action_discrete:
-        actions = torch.randint(0, 4, (
+        actions = torch.randint(0, 4, (2, 128))
         action_embed_readout = dict(num_discrete = 4)
         assert_shape = (4,)
     else:
-        actions = torch.randn(
+        actions = torch.randn(2, 128, 8)
        action_embed_readout = dict(num_continuous = 8)
        assert_shape = (8, 2)
 
@@ -44,16 +47,24 @@ def test_metacontroller(
 
     # discovery and internal rl phase with meta controller
 
-    meta_controller = MetaController(
-        dim_model = 512,
-        dim_meta_controller = 256,
-        dim_latent = 128,
-        switch_per_latent_dim = switch_per_latent_dim
-    )
+    if not use_binary_mapper_variant:
+        meta_controller = MetaController(
+            dim_model = 512,
+            dim_meta_controller = 256,
+            dim_latent = 128,
+            switch_per_latent_dim = switch_per_latent_dim
+        )
+    else:
+        meta_controller = MetaControllerWithBinaryMapper(
+            dim_model = 512,
+            dim_meta_controller = 256,
+            switch_per_code = switch_per_latent_dim,
+            dim_code_bits = 8, # 2**8 = 256 codes
+        )
 
     # discovery phase
 
-    (action_recon_loss, kl_loss, switch_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True)
+    (action_recon_loss, kl_loss, switch_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True, episode_lens = episode_lens)
     (action_recon_loss + kl_loss * 0.1 + switch_loss * 0.2).backward()
 
     # internal rl - done iteratively

@@ -66,7 +77,7 @@ def test_metacontroller(
 
     logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, return_cache = True)
 
-    assert logits.shape == (
+    assert logits.shape == (2, 1, *assert_shape)
     past_action_id = model.action_readout.sample(logits)
 
     # evolutionary strategies over grpo

@@ -78,7 +89,8 @@ def test_metacontroller(
 
     meta_controller.save('./meta_controller.pt')
 
-    rehydrated_meta_controller = MetaController.init_and_load('./meta_controller.pt')
+    meta_controller_klass = meta_controller.__class__
+    rehydrated_meta_controller = meta_controller_klass.init_and_load('./meta_controller.pt')
 
     model.save('./trained.pt')
 
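The reload in the test now goes through meta_controller.__class__, so the same save / load round trip works for whichever controller variant was constructed. A minimal sketch of that pattern (the file path is illustrative):

# save whichever variant was trained, then rehydrate it through its own class
meta_controller.save('./meta_controller.pt')

klass = meta_controller.__class__   # MetaController or MetaControllerWithBinaryMapper
restored = klass.init_and_load('./meta_controller.pt')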
{metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/.github/workflows/python-publish.yml
RENAMED
File without changes

{metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/.github/workflows/test.yml
RENAMED
File without changes

{metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/.gitignore
RENAMED
File without changes

{metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/LICENSE
RENAMED
File without changes

{metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/README.md
RENAMED
File without changes

{metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/fig1.png
RENAMED
File without changes

{metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/gather_babyai_trajs.py
RENAMED
File without changes

{metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/metacontroller/__init__.py
RENAMED
File without changes

{metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/test_babyai_e2e.sh
RENAMED
File without changes

{metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/train_babyai.py
RENAMED
File without changes

{metacontroller_pytorch-0.0.25 → metacontroller_pytorch-0.0.27}/train_behavior_clone_babyai.py
RENAMED
File without changes