metacontroller-pytorch 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metacontroller/metacontroller.py +9 -18
- {metacontroller_pytorch-0.0.21.dist-info → metacontroller_pytorch-0.0.23.dist-info}/METADATA +2 -1
- metacontroller_pytorch-0.0.23.dist-info/RECORD +6 -0
- metacontroller_pytorch-0.0.21.dist-info/RECORD +0 -6
- {metacontroller_pytorch-0.0.21.dist-info → metacontroller_pytorch-0.0.23.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.21.dist-info → metacontroller_pytorch-0.0.23.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
@@ -26,6 +26,8 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout
 
 from assoc_scan import AssocScan
 
+from torch_einops_utils import pad_at_dim
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
@@ -46,17 +48,6 @@ def default(*args):
             return arg
     return None
 
-def is_empty(t):
-    return t.numel() == 0
-
-def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
-    if pad == (0, 0):
-        return t
-
-    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
-    zeros = ((0, 0) * dims_from_right)
-    return F.pad(t, (*zeros, *pad), value = value)
-
 # tensor helpers
 
 def straight_through(src, tgt):
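For context on the hunk above: pad_at_dim is no longer defined locally and is instead imported from the new torch-einops-utils dependency (see the first hunk). A minimal sketch of the behavior the removed local helper implemented, assuming the imported version is a drop-in replacement; the example tensor at the end is illustrative only:

import torch
import torch.nn.functional as F

def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
    # no-op when no padding is requested
    if pad == (0, 0):
        return t

    # F.pad consumes pad widths starting from the last dimension,
    # so translate `dim` into a count of dimensions from the right
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

t = torch.randn(2, 3, 4)
assert pad_at_dim(t, (1, 0), dim = 1).shape == (2, 4, 4)  # one zero step prepended on dim 1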
@@ -333,7 +324,7 @@ class Transformer(Module):
     def forward(
         self,
         state,
-
+        actions: Tensor | None = None,
         meta_controller: Module | None = None,
         cache: TransformerOutput | None = None,
         discovery_phase = False,
@@ -363,10 +354,10 @@ class Transformer(Module):
         # handle maybe behavioral cloning
 
         if behavioral_cloning or (meta_controlling and discovery_phase):
-            assert
+            assert exists(actions), f'`actions` cannot be empty when doing discovery or behavioral cloning'
 
             state, target_state = state[:, :-1], state[:, 1:]
-
+            actions, target_actions = actions[:, :-1], actions[:, 1:]
 
         # transformer lower body
 
@@ -376,8 +367,8 @@ class Transformer(Module):
 
         # handle no past action for first timestep
 
-        if exists(
-            action_embed = self.action_embed(
+        if exists(actions):
+            action_embed = self.action_embed(actions)
         else:
             action_embed = state_embed[:, 0:0] # empty action embed
 
@@ -415,13 +406,13 @@ class Transformer(Module):
             state_dist_params = self.state_readout(attended)
             state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)
 
-            action_clone_loss = self.action_readout.calculate_loss(dist_params,
+            action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions)
 
            return state_clone_loss, action_clone_loss
 
        elif meta_controlling and discovery_phase:
 
-            action_recon_loss = self.action_readout.calculate_loss(dist_params,
+            action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)
 
            return action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss
 
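Taken together, these hunks make `actions` an explicit optional argument to `Transformer.forward` and require it whenever the behavioral-cloning or discovery losses are computed, shifting states and actions by one step for teacher forcing. A standalone sketch of that shift, assuming batch-first (batch, time, ...) tensors; `split_for_cloning` is a hypothetical name for illustration, not part of the package:

import torch

def split_for_cloning(state, actions):
    # mirrors the new assert: actions may no longer be omitted in this path
    assert actions is not None, '`actions` cannot be empty when doing discovery or behavioral cloning'

    # inputs at steps 0..T-2 predict targets at steps 1..T-1
    state, target_state = state[:, :-1], state[:, 1:]
    actions, target_actions = actions[:, :-1], actions[:, 1:]
    return state, target_state, actions, target_actions

state = torch.randn(2, 17, 512)          # (batch, time, state dim) - assumed shape
actions = torch.randint(0, 10, (2, 17))  # (batch, time) discrete ids - assumed shape

s, ts, a, ta = split_for_cloning(state, actions)
assert s.shape[1] == a.shape[1] == 16    # inputs and targets each lose one timestep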
{metacontroller_pytorch-0.0.21.dist-info → metacontroller_pytorch-0.0.23.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.21
+Version: 0.0.23
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -40,6 +40,7 @@ Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
 Requires-Dist: loguru
 Requires-Dist: memmap-replay-buffer>=0.0.1
+Requires-Dist: torch-einops-utils>=0.0.7
 Requires-Dist: torch>=2.5
 Requires-Dist: x-evolution>=0.1.23
 Requires-Dist: x-mlps-pytorch
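Beyond the version bump, the only metadata change is the new hard dependency on torch-einops-utils>=0.0.7, which pip resolves automatically on upgrade. A quick post-upgrade sanity check that the import path used by metacontroller.py resolves (assumes 0.0.23 is installed):

# run after `pip install metacontroller-pytorch==0.0.23`
from torch_einops_utils import pad_at_dim  # the import metacontroller.py now relies on
print(pad_at_dim)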
metacontroller_pytorch-0.0.23.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+metacontroller/metacontroller.py,sha256=mxXLeUTJ2AeYYUzpJQeE3TOrtZG_D45vfkb3jTNeEb8,13864
+metacontroller_pytorch-0.0.23.dist-info/METADATA,sha256=md2Ew7uonopeMK6YFkgfcXrs_3AoIQ_5gzu9Om3Fhuc,4361
+metacontroller_pytorch-0.0.23.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.23.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.23.dist-info/RECORD,,

metacontroller_pytorch-0.0.21.dist-info/RECORD
DELETED
@@ -1,6 +0,0 @@
-metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
-metacontroller/metacontroller.py,sha256=EP2N1Qtw4WTNthQrMz6bBT9rxTtMFikdOyYtcwSPdHM,14167
-metacontroller_pytorch-0.0.21.dist-info/METADATA,sha256=scUJVoSZ6Tl3RYNiNjK_wIeWVrpVLbQhya-XkCqdieQ,4320
-metacontroller_pytorch-0.0.21.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-metacontroller_pytorch-0.0.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-metacontroller_pytorch-0.0.21.dist-info/RECORD,,

{metacontroller_pytorch-0.0.21.dist-info → metacontroller_pytorch-0.0.23.dist-info}/WHEEL
RENAMED
File without changes
{metacontroller_pytorch-0.0.21.dist-info → metacontroller_pytorch-0.0.23.dist-info}/licenses/LICENSE
RENAMED
File without changes