metacontroller-pytorch 0.0.25__py3-none-any.whl → 0.0.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metacontroller/metacontroller.py +8 -5
- {metacontroller_pytorch-0.0.25.dist-info → metacontroller_pytorch-0.0.26.dist-info}/METADATA +1 -1
- metacontroller_pytorch-0.0.26.dist-info/RECORD +6 -0
- metacontroller_pytorch-0.0.25.dist-info/RECORD +0 -6
- {metacontroller_pytorch-0.0.25.dist-info → metacontroller_pytorch-0.0.26.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.25.dist-info → metacontroller_pytorch-0.0.26.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
|
@@ -6,7 +6,7 @@ from collections import namedtuple
|
|
|
6
6
|
from loguru import logger
|
|
7
7
|
|
|
8
8
|
import torch
|
|
9
|
-
from torch import nn, cat, stack, tensor
|
|
9
|
+
from torch import nn, cat, stack, tensor, Tensor
|
|
10
10
|
from torch.nn import Module, GRU, Linear, Identity
|
|
11
11
|
import torch.nn.functional as F
|
|
12
12
|
|
|
@@ -26,7 +26,7 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout
|
|
|
26
26
|
|
|
27
27
|
from assoc_scan import AssocScan
|
|
28
28
|
|
|
29
|
-
from torch_einops_utils import pad_at_dim, lens_to_mask
|
|
29
|
+
from torch_einops_utils import maybe, pad_at_dim, lens_to_mask
|
|
30
30
|
from torch_einops_utils.save_load import save_load
|
|
31
31
|
|
|
32
32
|
# constants
|
|
@@ -151,7 +151,8 @@ class MetaController(Module):
|
|
|
151
151
|
cache: MetaControllerOutput | None = None,
|
|
152
152
|
discovery_phase = False,
|
|
153
153
|
hard_switch = False,
|
|
154
|
-
temperature = 1
|
|
154
|
+
temperature = 1.,
|
|
155
|
+
episode_lens: Tensor | None = None
|
|
155
156
|
):
|
|
156
157
|
device = residual_stream.device
|
|
157
158
|
|
|
@@ -168,7 +169,9 @@ class MetaController(Module):
|
|
|
168
169
|
if discovery_phase:
|
|
169
170
|
logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
|
|
170
171
|
|
|
171
|
-
|
|
172
|
+
mask = maybe(lens_to_mask)(episode_lens, meta_embed.shape[1])
|
|
173
|
+
|
|
174
|
+
encoded_temporal = self.bidirectional_temporal_encoder(meta_embed, mask = mask)
|
|
172
175
|
|
|
173
176
|
proposed_action_hidden, _ = self.emitter(cat((encoded_temporal, meta_embed), dim = -1))
|
|
174
177
|
readout = self.emitter_to_action_mean_log_var
|
|
@@ -391,7 +394,7 @@ class Transformer(Module):
|
|
|
391
394
|
with meta_controller_context():
|
|
392
395
|
|
|
393
396
|
if exists(meta_controller):
|
|
394
|
-
control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature)
|
|
397
|
+
control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature, episode_lens = episode_lens)
|
|
395
398
|
else:
|
|
396
399
|
control_signal, next_meta_hiddens = self.zero, None
|
|
397
400
|
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
+
metacontroller/metacontroller.py,sha256=LWEq069EnBP3Sr6FTiDtz0cM5SFFT1zl35WkU6_kWGA,14451
|
|
3
|
+
metacontroller_pytorch-0.0.26.dist-info/METADATA,sha256=E00jJkfHS_wsEuh-a4iIo42fQQ1NhX7r-HuSWtyimUQ,4363
|
|
4
|
+
metacontroller_pytorch-0.0.26.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
+
metacontroller_pytorch-0.0.26.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
metacontroller_pytorch-0.0.26.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
-
metacontroller/metacontroller.py,sha256=icToDxXPknHG5C5hTzVaVOCibYbJ3aDLmZlaMc3Xge0,14275
|
|
3
|
-
metacontroller_pytorch-0.0.25.dist-info/METADATA,sha256=HItPrlXUrJhZ1ZmpVU8JNftpyazBvJ3GVlOJPWL8NKE,4363
|
|
4
|
-
metacontroller_pytorch-0.0.25.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
-
metacontroller_pytorch-0.0.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
metacontroller_pytorch-0.0.25.dist-info/RECORD,,
|
|
File without changes
|
{metacontroller_pytorch-0.0.25.dist-info → metacontroller_pytorch-0.0.26.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|