metacontroller-pytorch 0.0.12__tar.gz → 0.0.15__tar.gz
This diff shows the changes between two publicly released versions of the package, as they appear in the public registry, and is provided for informational purposes only.
- {metacontroller_pytorch-0.0.12 → metacontroller_pytorch-0.0.15}/PKG-INFO +1 -1
- {metacontroller_pytorch-0.0.12 → metacontroller_pytorch-0.0.15}/metacontroller/metacontroller.py +36 -22
- {metacontroller_pytorch-0.0.12 → metacontroller_pytorch-0.0.15}/pyproject.toml +1 -1
- metacontroller_pytorch-0.0.15/tests/test_metacontroller.py +57 -0
- metacontroller_pytorch-0.0.12/tests/test_metacontroller.py +0 -39
- {metacontroller_pytorch-0.0.12 → metacontroller_pytorch-0.0.15}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.12 → metacontroller_pytorch-0.0.15}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.12 → metacontroller_pytorch-0.0.15}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.12 → metacontroller_pytorch-0.0.15}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.12 → metacontroller_pytorch-0.0.15}/README.md +0 -0
- {metacontroller_pytorch-0.0.12 → metacontroller_pytorch-0.0.15}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.12 → metacontroller_pytorch-0.0.15}/metacontroller/__init__.py +0 -0
{metacontroller_pytorch-0.0.12 → metacontroller_pytorch-0.0.15}/metacontroller/metacontroller.py
RENAMED
@@ -22,7 +22,7 @@ from x_transformers import Decoder
 from x_mlps_pytorch import Feedforwards
 from x_evolution import EvoStrategy

-from discrete_continuous_embed_readout import Embed, Readout
+from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout

 from assoc_scan import AssocScan

@@ -234,30 +234,25 @@ class Transformer(Module):
         self,
         dim,
         *,
-        embed: Embed | dict,
+        state_embed_readout: dict,
+        action_embed_readout: dict,
         lower_body: Decoder | dict,
         upper_body: Decoder | dict,
-        readout: Readout | dict,
         meta_controller: MetaController | None = None
     ):
         super().__init__()

-        if isinstance(embed, dict):
-            embed = Embed(dim = dim, **embed)
-
         if isinstance(lower_body, dict):
             lower_body = Decoder(dim = dim, **lower_body)

         if isinstance(upper_body, dict):
             upper_body = Decoder(dim = dim, **upper_body)

-        if isinstance(readout, dict):
-            readout = Readout(dim = dim, **readout)
+        self.state_embed, self.state_readout = EmbedAndReadout(dim, **state_embed_readout)
+        self.action_embed, self.action_readout = EmbedAndReadout(dim, **action_embed_readout)

-        self.embed = embed
         self.lower_body = lower_body
         self.upper_body = upper_body
-        self.readout = readout

         # meta controller

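In short, the Transformer constructor no longer takes separate embed / readout arguments; it now takes a state_embed_readout and an action_embed_readout config dict and builds paired embedding and readout modules for states and actions through EmbedAndReadout. A minimal construction sketch, with the dims and depths borrowed from the updated test further below (illustrative values, not defaults):

    model = Transformer(
        dim = 512,
        state_embed_readout = dict(num_continuous = 384),  # continuous state vectors of dim 384
        action_embed_readout = dict(num_discrete = 4),      # or dict(num_continuous = ...) for continuous actions
        lower_body = dict(depth = 2),
        upper_body = dict(depth = 2),
    )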
@@ -285,13 +280,13 @@

     def forward(
         self,
-        ids,
+        state,
+        action_ids,
         meta_controller: Module | None = None,
         cache: TransformerOutput | None = None,
         discovery_phase = False,
-        no_grad_transformer = None,
-        no_grad_meta_controller = None,
         meta_controller_temperature = 1.,
+        return_raw_action_dist = False,
         return_latents = False,
         return_cache = False,
     ):
@@ -299,23 +294,32 @@

         meta_controlling = exists(meta_controller)

-
+        behavioral_cloning = not meta_controlling and not return_raw_action_dist

-
-        no_grad_meta_controller = default(no_grad_meta_controller, no_grad_transformer) # by default, if transformer is eval no grad then meta controller is being learnt
+        # by default, if meta controller is passed in, transformer is no grad

-
-        meta_controller_context =
+        lower_transformer_context = nullcontext if not meta_controlling else torch.no_grad
+        meta_controller_context = nullcontext if meta_controlling else torch.no_grad
+        upper_transformer_context = nullcontext if (not meta_controlling or discovery_phase) else torch.no_grad

         # handle cache

         lower_transformer_hiddens, meta_hiddens, upper_transformer_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)

+        # handle maybe behavioral cloning
+
+        if behavioral_cloning:
+            state, target_state = state[:, :-1], state[:, 1:]
+            action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]
+
         # transformer lower body

-        with
+        with lower_transformer_context():

-            embed = self.embed(ids)
+            state_embed = self.state_embed(state)
+            action_embed = self.action_embed(action_ids)
+
+            embed = state_embed + action_embed

             residual_stream, next_lower_hiddens = self.lower_body(embed, cache = lower_transformer_hiddens, return_hiddens = True)

@@ -330,13 +334,23 @@

         # modified residual stream sent back to transformer upper body

-        with
+        with upper_transformer_context():

             attended, next_upper_hiddens = self.upper_body(modified_residual_stream, cache = upper_transformer_hiddens, return_hiddens = True)

         # head readout

-        dist_params = self.readout(attended)
+        dist_params = self.action_readout(attended)
+
+        # maybe return behavior cloning loss
+
+        if behavioral_cloning:
+            state_dist_params = self.state_readout(attended)
+            state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)
+
+            action_clone_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
+
+            return state_clone_loss, action_clone_loss

         # returning

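Taken together, forward now has two modes, mirrored by the new test below: called with just states and action ids (no meta controller) it acts as a behavioral cloning step and returns a state cloning loss and an action cloning loss; called with a MetaController it returns the action readout's distribution parameters (plus the cache when return_cache = True). A rough usage sketch, assuming model, state, actions and meta_controller are set up as in that test:

    # behavioral cloning phase - no meta controller passed in
    state_clone_loss, action_clone_loss = model(state, actions)
    (state_clone_loss + 0.5 * action_clone_loss).backward()

    # discovery / internal RL phase - meta controller passed in
    logits, cache = model(state, actions, meta_controller = meta_controller, return_cache = True)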
metacontroller_pytorch-0.0.15/tests/test_metacontroller.py
ADDED
@@ -0,0 +1,57 @@
+import pytest
+param = pytest.mark.parametrize
+
+import torch
+from metacontroller.metacontroller import Transformer, MetaController
+
+@param('action_discrete', (False, True))
+@param('discovery_phase', (False, True))
+@param('switch_per_latent_dim', (False, True))
+def test_metacontroller(
+    action_discrete,
+    discovery_phase,
+    switch_per_latent_dim
+):
+
+    state = torch.randn(1, 1024, 384)
+
+    if action_discrete:
+        actions = torch.randint(0, 4, (1, 1024))
+        action_embed_readout = dict(num_discrete = 4)
+        assert_shape = (4,)
+    else:
+        actions = torch.randn(1, 1024, 8)
+        action_embed_readout = dict(num_continuous = 8)
+        assert_shape = (8, 2)
+
+    # behavioral cloning pahse
+
+    model = Transformer(
+        dim = 512,
+        action_embed_readout = action_embed_readout,
+        state_embed_readout = dict(num_continuous = 384),
+        lower_body = dict(depth = 2,),
+        upper_body = dict(depth = 2,),
+    )
+
+    state_clone_loss, action_clone_loss = model(state, actions)
+    (state_clone_loss + 0.5 * action_clone_loss).backward()
+
+    # discovery and internal rl phase with meta controller
+
+    meta_controller = MetaController(
+        dim_latent = 512,
+        switch_per_latent_dim = switch_per_latent_dim
+    )
+
+    logits, cache = model(state, actions, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True)
+
+    assert logits.shape == (1, 1024, *assert_shape)
+
+    logits, cache = model(state, actions, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
+    logits, cache = model(state, actions, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
+
+    assert logits.shape == (1, 1, *assert_shape)
+
+    model.meta_controller = meta_controller
+    model.evolve(1, lambda _: 1., noise_population_size = 2)
metacontroller_pytorch-0.0.12/tests/test_metacontroller.py
DELETED
@@ -1,39 +0,0 @@
-import pytest
-param = pytest.mark.parametrize
-
-import torch
-from metacontroller.metacontroller import Transformer, MetaController
-
-@param('discovery_phase', (False, True))
-@param('switch_per_latent_dim', (False, True))
-def test_metacontroller(
-    discovery_phase,
-    switch_per_latent_dim
-):
-
-    ids = torch.randint(0, 256, (1, 1024))
-
-    model = Transformer(
-        512,
-        embed = dict(num_discrete = 256),
-        lower_body = dict(depth = 2,),
-        upper_body = dict(depth = 2,),
-        readout = dict(num_discrete = 256)
-    )
-
-    meta_controller = MetaController(
-        512,
-        switch_per_latent_dim = switch_per_latent_dim
-    )
-
-    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True)
-
-    assert logits.shape == (1, 1024, 256)
-
-    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
-    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
-
-    assert logits.shape == (1, 1, 256)
-
-    model.meta_controller = meta_controller
-    model.evolve(1, lambda _: 1., noise_population_size = 2)
{metacontroller_pytorch-0.0.12 → metacontroller_pytorch-0.0.15}/.github/workflows/python-publish.yml
RENAMED
File without changes

{metacontroller_pytorch-0.0.12 → metacontroller_pytorch-0.0.15}/.github/workflows/test.yml
RENAMED
File without changes

{metacontroller_pytorch-0.0.12 → metacontroller_pytorch-0.0.15}/.gitignore
RENAMED
File without changes

{metacontroller_pytorch-0.0.12 → metacontroller_pytorch-0.0.15}/LICENSE
RENAMED
File without changes

{metacontroller_pytorch-0.0.12 → metacontroller_pytorch-0.0.15}/README.md
RENAMED
File without changes

{metacontroller_pytorch-0.0.12 → metacontroller_pytorch-0.0.15}/fig1.png
RENAMED
File without changes

{metacontroller_pytorch-0.0.12 → metacontroller_pytorch-0.0.15}/metacontroller/__init__.py
RENAMED
File without changes