metacontroller-pytorch 0.0.8__tar.gz → 0.0.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.10}/PKG-INFO +1 -1
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.10}/metacontroller/metacontroller.py +90 -24
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.10}/pyproject.toml +1 -1
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.10}/tests/test_metacontroller.py +6 -1
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.10}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.10}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.10}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.10}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.10}/README.md +0 -0
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.10}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.10}/metacontroller/__init__.py +0 -0
{metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.10}/metacontroller/metacontroller.py
RENAMED
@@ -1,4 +1,6 @@
 from __future__ import annotations
+from contextlib import nullcontext
+
 from functools import partial
 from collections import namedtuple
 
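The new `nullcontext` import supports the conditional no-grad pattern introduced further down in `Transformer.forward`. A minimal sketch of that pattern (the `run` helper and `model` here are illustrative, not from the package):

```python
import torch
from contextlib import nullcontext

def run(model, x, no_grad = True):
    # pick the context manager class, instantiate it at the with-site,
    # mirroring `torch.no_grad if no_grad_transformer else nullcontext` below
    context = torch.no_grad if no_grad else nullcontext
    with context():
        return model(x)

out = run(torch.nn.Linear(4, 4), torch.randn(2, 4))
assert not out.requires_grad  # computed under torch.no_grad
```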
@@ -50,6 +52,13 @@ def straight_through(src, tgt):
 
 # meta controller
 
+MetaControllerOutput = namedtuple('MetaControllerOutput', (
+    'prev_hiddens',
+    'action_dist',
+    'actions',
+    'kl_loss'
+))
+
 class MetaController(Module):
     def __init__(
         self,
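`MetaControllerOutput` doubles as the cache type: its `prev_hiddens` field is fed back in on the next call. A toy illustration of that namedtuple-as-cache pattern (all names here are hypothetical):

```python
from collections import namedtuple

StepOutput = namedtuple('StepOutput', ('prev_hiddens', 'value'))

def step(x, cache = None):
    # resume from the previous call's hidden state, if any
    hidden = cache.prev_hiddens if cache is not None else 0
    hidden = hidden + x
    return StepOutput(hidden, hidden * 2)

out = step(1)
out = step(2, cache = out)  # thread the output back in as the cache
assert out.prev_hiddens == 3
```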
@@ -105,9 +114,9 @@ class MetaController(Module):
         return [
             *self.bidirectional_temporal_compressor.parameters(),
             *self.emitter.parameters(),
-            *self.emitter_to_action_mean_log_var.parameters()
+            *self.emitter_to_action_mean_log_var.parameters(),
             *self.decoder.parameters(),
-            *self.switch_gating
+            *self.switch_gating.parameters()
         ]
 
     def internal_rl_parameters(self):
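The two changed lines fix what looks like a latent bug in 0.0.8: a missing comma between the starred `parameters()` entries, and unpacking `self.switch_gating` itself rather than its parameters. Either spelling would fail at runtime, since `nn.Module` does not define `__iter__`:

```python
import torch.nn as nn

gate = nn.Linear(4, 4)

# `[*gate]` raises TypeError: nn.Module is not iterable,
# which is why the fix switches to `*gate.parameters()`
params = [*gate.parameters()]
assert all(p.requires_grad for p in params)
```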
@@ -119,10 +128,19 @@ class MetaController(Module):
     def forward(
         self,
         residual_stream,
+        cache: MetaControllerOutput | None = None,
         discovery_phase = False,
         hard_switch = False
     ):
 
+        # destruct prev cache
+
+        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+
+        # getting proposed action for the two phases
+
+        next_action_proposer_hidden = None
+
         if discovery_phase:
             temporal_compressed, _ = self.bidirectional_temporal_compressor(residual_stream)
             temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
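The `(None,) * 3` fallback works because PyTorch RNNs accept `None` as the initial hidden state. A sketch of the resume pattern used for `prev_action_proposer_hidden` (the `gru` here is a stand-in, not the package's actual module):

```python
import torch
import torch.nn as nn

gru = nn.GRU(8, 8, batch_first = True)
x = torch.randn(1, 4, 8)

# first call: no cache yet, None is treated as a zero hidden state
out, hidden = gru(x, None)

# later calls: feed only the new timestep plus the carried hidden state
out, hidden = gru(x[:, -1:], hidden)
```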
@@ -131,7 +149,8 @@ class MetaController(Module):
             readout = self.emitter_to_action_mean_log_var
 
         else: # else internal rl phase
-
+
+            proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(residual_stream, prev_action_proposer_hidden)
             readout = self.action_proposer_mean_log_var
 
         # sample from the gaussian as the action from the meta controller
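Both phases end with a `readout` producing a mean and log-variance, from which the action is sampled (per the comment that follows in the source). The sampling line itself is outside this hunk; presumably it is the usual reparameterization trick, sketched here under that assumption:

```python
import torch

mean_log_var = torch.randn(1, 4, 2)    # stand-in for readout(...)
mean, log_var = mean_log_var.unbind(dim = -1)

# reparameterized sample: differentiable w.r.t. mean and log_var
eps = torch.randn_like(mean)
sampled_action = mean + (0.5 * log_var).exp() * eps
```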
@@ -144,35 +163,37 @@
 
         batch, _, dim = sampled_action.shape
 
-        switching_unit_gru_out, _ = self.switching_unit(residual_stream)
+        switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(residual_stream, prev_switching_unit_gru_hidden)
 
         switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
 
         # need to encourage normal distribution
 
-
+        kl_loss = self.zero
 
         if discovery_phase:
             mean, log_var = action_dist.unbind(dim = -1)
 
-
+            kl_loss = (0.5 * (
                 log_var.exp()
                 + mean.square()
                 - log_var
                 - 1.
             ))
 
-
-
+            kl_loss = kl_loss * switch_beta
+            kl_loss = kl_loss.sum(dim = -1).mean()
 
         # maybe hard switch, then use associative scan
 
         if hard_switch:
-
-            switch_beta = straight_through(switch_beta, (switch_beta > 0.5).float())
+            hard_switch_beta = (switch_beta > 0.5).float()
+            switch_beta = straight_through(switch_beta, hard_switch_beta)
 
         forget = 1. - switch_beta
-        gated_action = self.switch_gating(switch_beta, sampled_action * forget)
+        gated_action = self.switch_gating(switch_beta, sampled_action * forget, prev = prev_switch_gated_hiddens)
+
+        next_switch_gated_action = gated_action[:, -1]
 
         # decoder
 
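The added `kl_loss` lines are the closed-form KL divergence KL(N(mean, var) || N(0, 1)) = 0.5 * (var + mean^2 - log var - 1), now weighted per position by `switch_beta`. A quick check of the closed form against torch.distributions:

```python
import torch
from torch.distributions import Normal, kl_divergence

mean, log_var = torch.randn(3), torch.randn(3)

# closed form used in the diff
closed_form = 0.5 * (log_var.exp() + mean.square() - log_var - 1.)

# reference: KL(N(mean, sigma) || N(0, 1)) with sigma = exp(0.5 * log_var)
reference = kl_divergence(Normal(mean, (0.5 * log_var).exp()), Normal(0., 1.))

assert torch.allclose(closed_form, reference, atol = 1e-6)
```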
@@ -187,10 +208,23 @@
 
         modified_residual_stream = residual_stream + control_signal
 
-
+        # returning
+
+        next_hiddens = (
+            next_action_proposer_hidden,
+            next_switching_unit_gru_hidden,
+            next_switch_gated_action
+        )
+
+        return modified_residual_stream, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss)
 
 # main transformer, which is subsumed into the environment after behavioral cloning
 
+TransformerOutput = namedtuple('TransformerOutput', (
+    'residual_stream_latent',
+    'prev_hiddens'
+))
+
 class Transformer(Module):
     def __init__(
         self,
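The reworked `hard_switch` branch above makes the thresholding explicit before passing it through `straight_through` (a function defined near the top of the file whose body is not shown in this diff). A sketch of the usual straight-through estimator, under the assumption that this is what it implements:

```python
import torch

def straight_through(src, tgt):
    # forward pass takes tgt's value, backward pass routes gradients to src
    return src + (tgt - src).detach()

beta = torch.rand(3, requires_grad = True)
hard_beta = (beta > 0.5).float()

out = straight_through(beta, hard_beta)
out.sum().backward()

assert torch.allclose(out, hard_beta)            # hard values on the forward
assert torch.allclose(beta.grad, torch.ones(3))  # identity gradient on the backward
```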
@@ -249,29 +283,61 @@
         self,
         ids,
         meta_controller: Module | None = None,
+        cache: TransformerOutput | None = None,
         discovery_phase = False,
-
+        no_grad_transformer = None,
+        no_grad_meta_controller = None,
+        return_latents = False,
+        return_cache = False
     ):
         meta_controller = default(meta_controller, self.meta_controller)
 
-
+        meta_controlling = exists(meta_controller)
+
+        # by default, if meta controller is passed in, transformer is no grad
 
-
+        no_grad_transformer = default(no_grad_transformer, meta_controlling)
+        no_grad_meta_controller = default(no_grad_meta_controller, no_grad_transformer) # by default, if transformer is eval no grad then meta controller is being learnt
+
+        transformer_context = torch.no_grad if no_grad_transformer else nullcontext
+        meta_controller_context = torch.no_grad if no_grad_meta_controller else nullcontext
+
+        # handle cache
+
+        lower_transformer_hiddens, meta_hiddens, upper_transformer_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+
+        # transformer lower body
+
+        with transformer_context():
+
+            embed = self.embed(ids)
+
+            residual_stream, next_lower_hiddens = self.lower_body(embed, cache = lower_transformer_hiddens, return_hiddens = True)
 
         # meta controller acts on residual stream here
 
-
-
-
-
+        with meta_controller_context():
+
+            if exists(meta_controller):
+                modified_residual_stream, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase)
+            else:
+                modified_residual_stream, next_meta_hiddens = residual_stream, None
+
+        # modified residual stream sent back to transformer upper body
+
+        with transformer_context():
+
+            attended, next_upper_hiddens = self.upper_body(modified_residual_stream, cache = upper_transformer_hiddens, return_hiddens = True)
+
+        # head readout
 
-
+        dist_params = self.readout(attended)
 
-
+        # returning
 
-
+        return_one = not (return_latents or return_cache)
 
-        if
+        if return_one:
             return dist_params
 
-        return dist_params,
+        return dist_params, TransformerOutput(residual_stream, (next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
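With the `exists`/`default` helpers, `no_grad_transformer = default(no_grad_transformer, meta_controlling)` freezes the transformer exactly when a meta controller is driving it. These helpers are not part of this diff; they follow the usual lucidrains convention, presumably defined as:

```python
def exists(v):
    # a value 'exists' if it is not None
    return v is not None

def default(v, d):
    # fall back to a default when the value is absent
    return v if exists(v) else d
```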
{metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.10}/tests/test_metacontroller.py
RENAMED

@@ -26,9 +26,14 @@ def test_metacontroller(
         switch_per_latent_dim = switch_per_latent_dim
     )
 
-    logits = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase)
+    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True)
 
     assert logits.shape == (1, 1024, 256)
 
+    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
+    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
+
+    assert logits.shape == (1, 1, 256)
+
     model.meta_controller = meta_controller
     model.evolve(1, lambda _: 1., noise_population_size = 2)
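The updated test threads the returned cache back through successive calls. A hypothetical decode loop built on the same API (the `generate` helper is illustrative, not part of the package):

```python
import torch

def generate(model, prompt_ids, steps = 2):
    # prime the cache on the full prompt, then decode one token at a time
    logits, cache = model(prompt_ids, return_cache = True)

    for _ in range(steps):
        next_id = logits[:, -1].argmax(dim = -1, keepdim = True)
        logits, cache = model(next_id, return_cache = True, cache = cache)

    return logits
```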