metacontroller-pytorch 0.0.14__tar.gz → 0.0.15__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {metacontroller_pytorch-0.0.14 → metacontroller_pytorch-0.0.15}/PKG-INFO +1 -1
- {metacontroller_pytorch-0.0.14 → metacontroller_pytorch-0.0.15}/metacontroller/metacontroller.py +32 -14
- {metacontroller_pytorch-0.0.14 → metacontroller_pytorch-0.0.15}/pyproject.toml +1 -1
- metacontroller_pytorch-0.0.15/tests/test_metacontroller.py +57 -0
- metacontroller_pytorch-0.0.14/tests/test_metacontroller.py +0 -39
- {metacontroller_pytorch-0.0.14 → metacontroller_pytorch-0.0.15}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.14 → metacontroller_pytorch-0.0.15}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.14 → metacontroller_pytorch-0.0.15}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.14 → metacontroller_pytorch-0.0.15}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.14 → metacontroller_pytorch-0.0.15}/README.md +0 -0
- {metacontroller_pytorch-0.0.14 → metacontroller_pytorch-0.0.15}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.14 → metacontroller_pytorch-0.0.15}/metacontroller/__init__.py +0 -0
{metacontroller_pytorch-0.0.14 → metacontroller_pytorch-0.0.15}/metacontroller/metacontroller.py
RENAMED

@@ -22,7 +22,7 @@ from x_transformers import Decoder
 from x_mlps_pytorch import Feedforwards
 from x_evolution import EvoStrategy

-from discrete_continuous_embed_readout import Embed, Readout
+from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout

 from assoc_scan import AssocScan

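The import change above swaps the manual Embed / Readout pairing for the EmbedAndReadout helper. Judging only from the call sites later in this diff (not from that library's own documentation), it takes the model dimension plus either num_discrete or num_continuous and unpacks into an (embed, readout) pair of modules; a minimal sketch under that assumption, with sizes taken from the new test:

    from discrete_continuous_embed_readout import EmbedAndReadout

    # inferred from the constructor call sites below
    state_embed, state_readout = EmbedAndReadout(512, num_continuous = 384)
    action_embed, action_readout = EmbedAndReadout(512, num_discrete = 4)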
@@ -234,30 +234,25 @@ class Transformer(Module):
         self,
         dim,
         *,
-        embed: Embed | dict,
+        state_embed_readout: dict,
+        action_embed_readout: dict,
         lower_body: Decoder | dict,
         upper_body: Decoder | dict,
-        readout: Readout | dict,
         meta_controller: MetaController | None = None
     ):
         super().__init__()

-        if isinstance(embed, dict):
-            embed = Embed(dim = dim, **embed)
-
         if isinstance(lower_body, dict):
             lower_body = Decoder(dim = dim, **lower_body)

         if isinstance(upper_body, dict):
             upper_body = Decoder(dim = dim, **upper_body)

-        if isinstance(readout, dict):
-            readout = Readout(dim = dim, **readout)
+        self.state_embed, self.state_readout = EmbedAndReadout(dim, **state_embed_readout)
+        self.action_embed, self.action_readout = EmbedAndReadout(dim, **action_embed_readout)

-        self.embed = embed
         self.lower_body = lower_body
         self.upper_body = upper_body
-        self.readout = readout

         # meta controller

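With this change the Transformer no longer accepts embed / readout modules directly; it takes the two EmbedAndReadout config dicts instead. A minimal construction sketch, with argument values taken from the added test (nothing here goes beyond what the diff shows):

    from metacontroller.metacontroller import Transformer

    model = Transformer(
        dim = 512,
        state_embed_readout = dict(num_continuous = 384),  # 384-dim continuous states
        action_embed_readout = dict(num_discrete = 4),     # 4 discrete actions; use num_continuous for continuous actions
        lower_body = dict(depth = 2),
        upper_body = dict(depth = 2),
    )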
@@ -285,11 +280,13 @@

     def forward(
         self,
-        ids,
+        state,
+        action_ids,
         meta_controller: Module | None = None,
         cache: TransformerOutput | None = None,
         discovery_phase = False,
         meta_controller_temperature = 1.,
+        return_raw_action_dist = False,
         return_latents = False,
         return_cache = False,
     ):
@@ -297,21 +294,32 @@

         meta_controlling = exists(meta_controller)

+        behavioral_cloning = not meta_controlling and not return_raw_action_dist
+
         # by default, if meta controller is passed in, transformer is no grad

         lower_transformer_context = nullcontext if not meta_controlling else torch.no_grad
         meta_controller_context = nullcontext if meta_controlling else torch.no_grad
-        upper_transformer_context = nullcontext if meta_controlling
+        upper_transformer_context = nullcontext if (not meta_controlling or discovery_phase) else torch.no_grad

         # handle cache

         lower_transformer_hiddens, meta_hiddens, upper_transformer_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)

+        # handle maybe behavioral cloning
+
+        if behavioral_cloning:
+            state, target_state = state[:, :-1], state[:, 1:]
+            action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]
+
         # transformer lower body

         with lower_transformer_context():

-            embed = self.embed(ids)
+            state_embed = self.state_embed(state)
+            action_embed = self.action_embed(action_ids)
+
+            embed = state_embed + action_embed

             residual_stream, next_lower_hiddens = self.lower_body(embed, cache = lower_transformer_hiddens, return_hiddens = True)

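The new behavioral_cloning branch applies the usual next-step shift before embedding: inputs drop the last timestep, targets drop the first. A small shape illustration using the (1, 1024, 384) state and (1, 1024) action tensors from the added test:

    import torch

    state = torch.randn(1, 1024, 384)            # (batch, time, state dim)
    action_ids = torch.randint(0, 4, (1, 1024))  # (batch, time)

    state_in, target_state = state[:, :-1], state[:, 1:]                    # both (1, 1023, 384)
    action_in, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]    # both (1, 1023)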
@@ -332,7 +340,17 @@

         # head readout

-        dist_params = self.readout(attended)
+        dist_params = self.action_readout(attended)
+
+        # maybe return behavior cloning loss
+
+        if behavioral_cloning:
+            state_dist_params = self.state_readout(attended)
+            state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)
+
+            action_clone_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
+
+            return state_clone_loss, action_clone_loss

         # returning

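Together with the forward changes above, the model now has two call modes, both exercised by the new test: called without a meta controller it returns the pair of cloning losses, and called with one it returns action distribution parameters (plus the cache when requested). A usage sketch lifted from the test, reusing its model, state, actions and meta_controller objects:

    # behavioral cloning phase: no meta controller passed in
    state_clone_loss, action_clone_loss = model(state, actions)
    (state_clone_loss + 0.5 * action_clone_loss).backward()

    # discovery / internal RL phase: meta controller passed in
    logits, cache = model(
        state, actions,
        meta_controller = meta_controller,
        discovery_phase = True,
        return_cache = True
    )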
metacontroller_pytorch-0.0.15/tests/test_metacontroller.py
ADDED

@@ -0,0 +1,57 @@
+import pytest
+param = pytest.mark.parametrize
+
+import torch
+from metacontroller.metacontroller import Transformer, MetaController
+
+@param('action_discrete', (False, True))
+@param('discovery_phase', (False, True))
+@param('switch_per_latent_dim', (False, True))
+def test_metacontroller(
+    action_discrete,
+    discovery_phase,
+    switch_per_latent_dim
+):
+
+    state = torch.randn(1, 1024, 384)
+
+    if action_discrete:
+        actions = torch.randint(0, 4, (1, 1024))
+        action_embed_readout = dict(num_discrete = 4)
+        assert_shape = (4,)
+    else:
+        actions = torch.randn(1, 1024, 8)
+        action_embed_readout = dict(num_continuous = 8)
+        assert_shape = (8, 2)
+
+    # behavioral cloning pahse
+
+    model = Transformer(
+        dim = 512,
+        action_embed_readout = action_embed_readout,
+        state_embed_readout = dict(num_continuous = 384),
+        lower_body = dict(depth = 2,),
+        upper_body = dict(depth = 2,),
+    )
+
+    state_clone_loss, action_clone_loss = model(state, actions)
+    (state_clone_loss + 0.5 * action_clone_loss).backward()
+
+    # discovery and internal rl phase with meta controller
+
+    meta_controller = MetaController(
+        dim_latent = 512,
+        switch_per_latent_dim = switch_per_latent_dim
+    )
+
+    logits, cache = model(state, actions, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True)
+
+    assert logits.shape == (1, 1024, *assert_shape)
+
+    logits, cache = model(state, actions, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
+    logits, cache = model(state, actions, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
+
+    assert logits.shape == (1, 1, *assert_shape)
+
+    model.meta_controller = meta_controller
+    model.evolve(1, lambda _: 1., noise_population_size = 2)
metacontroller_pytorch-0.0.14/tests/test_metacontroller.py
REMOVED

@@ -1,39 +0,0 @@
-import pytest
-param = pytest.mark.parametrize
-
-import torch
-from metacontroller.metacontroller import Transformer, MetaController
-
-@param('discovery_phase', (False, True))
-@param('switch_per_latent_dim', (False, True))
-def test_metacontroller(
-    discovery_phase,
-    switch_per_latent_dim
-):
-
-    ids = torch.randint(0, 256, (1, 1024))
-
-    model = Transformer(
-        512,
-        embed = dict(num_discrete = 256),
-        lower_body = dict(depth = 2,),
-        upper_body = dict(depth = 2,),
-        readout = dict(num_discrete = 256)
-    )
-
-    meta_controller = MetaController(
-        512,
-        switch_per_latent_dim = switch_per_latent_dim
-    )
-
-    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True)
-
-    assert logits.shape == (1, 1024, 256)
-
-    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
-    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
-
-    assert logits.shape == (1, 1, 256)
-
-    model.meta_controller = meta_controller
-    model.evolve(1, lambda _: 1., noise_population_size = 2)
{metacontroller_pytorch-0.0.14 → metacontroller_pytorch-0.0.15}/.github/workflows/python-publish.yml
RENAMED
File without changes

{metacontroller_pytorch-0.0.14 → metacontroller_pytorch-0.0.15}/.github/workflows/test.yml
RENAMED
File without changes

{metacontroller_pytorch-0.0.14 → metacontroller_pytorch-0.0.15}/.gitignore
RENAMED
File without changes

{metacontroller_pytorch-0.0.14 → metacontroller_pytorch-0.0.15}/LICENSE
RENAMED
File without changes

{metacontroller_pytorch-0.0.14 → metacontroller_pytorch-0.0.15}/README.md
RENAMED
File without changes

{metacontroller_pytorch-0.0.14 → metacontroller_pytorch-0.0.15}/fig1.png
RENAMED
File without changes

{metacontroller_pytorch-0.0.14 → metacontroller_pytorch-0.0.15}/metacontroller/__init__.py
RENAMED
File without changes