metacontroller-pytorch 0.0.29__tar.gz → 0.0.31__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of metacontroller-pytorch might be problematic. Click here for more details.
- {metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/PKG-INFO +1 -1
- {metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/metacontroller/metacontroller.py +6 -2
- {metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/metacontroller/metacontroller_with_binary_mapper.py +6 -10
- {metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/pyproject.toml +1 -1
- {metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/tests/test_metacontroller.py +1 -1
- {metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/README.md +0 -0
- {metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/gather_babyai_trajs.py +0 -0
- {metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/metacontroller/metacontroller_with_resnet.py +0 -0
- {metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/test_babyai_e2e.sh +0 -0
- {metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/train_babyai.py +0 -0
- {metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/train_behavior_clone_babyai.py +0 -0
{metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/metacontroller/metacontroller.py
RENAMED
|
@@ -58,8 +58,10 @@ def straight_through(src, tgt):
|
|
|
58
58
|
|
|
59
59
|
MetaControllerOutput = namedtuple('MetaControllerOutput', (
|
|
60
60
|
'prev_hiddens',
|
|
61
|
+
'input_residual_stream',
|
|
61
62
|
'action_dist',
|
|
62
63
|
'actions',
|
|
64
|
+
'switch_beta',
|
|
63
65
|
'kl_loss',
|
|
64
66
|
'switch_loss'
|
|
65
67
|
))
|
|
@@ -150,7 +152,7 @@ class MetaController(Module):
|
|
|
150
152
|
residual_stream,
|
|
151
153
|
cache: MetaControllerOutput | None = None,
|
|
152
154
|
discovery_phase = False,
|
|
153
|
-
hard_switch =
|
|
155
|
+
hard_switch = None,
|
|
154
156
|
temperature = 1.,
|
|
155
157
|
episode_lens: Tensor | None = None
|
|
156
158
|
):
|
|
@@ -166,6 +168,8 @@ class MetaController(Module):
|
|
|
166
168
|
|
|
167
169
|
meta_embed = self.model_to_meta(residual_stream)
|
|
168
170
|
|
|
171
|
+
hard_switch = default(hard_switch, not discovery_phase) # I think that during the internal RL phase it needs to be a hard switch, so that only the actions emitted during the switch are reinforced
|
|
172
|
+
|
|
169
173
|
if discovery_phase:
|
|
170
174
|
logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
|
|
171
175
|
|
|
@@ -268,7 +272,7 @@ class MetaController(Module):
|
|
|
268
272
|
sampled_latent_action[:, -1:]
|
|
269
273
|
)
|
|
270
274
|
|
|
271
|
-
return control_signal, MetaControllerOutput(next_hiddens, action_dist, sampled_latent_action, kl_loss, switch_loss)
|
|
275
|
+
return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, switch_beta, kl_loss, switch_loss)
|
|
272
276
|
|
|
273
277
|
# main transformer, which is subsumed into the environment after behavioral cloning
|
|
274
278
|
|
|
@@ -28,6 +28,8 @@ from torch_einops_utils.save_load import save_load
|
|
|
28
28
|
|
|
29
29
|
from vector_quantize_pytorch import BinaryMapper
|
|
30
30
|
|
|
31
|
+
from metacontroller.metacontroller import MetaControllerOutput
|
|
32
|
+
|
|
31
33
|
# constants
|
|
32
34
|
|
|
33
35
|
LinearNoBias = partial(Linear, bias = False)
|
|
@@ -50,14 +52,6 @@ def straight_through(src, tgt):
|
|
|
50
52
|
|
|
51
53
|
# meta controller
|
|
52
54
|
|
|
53
|
-
MetaControllerOutput = namedtuple('MetaControllerOutput', (
|
|
54
|
-
'prev_hiddens',
|
|
55
|
-
'action_dist',
|
|
56
|
-
'codes',
|
|
57
|
-
'kl_loss',
|
|
58
|
-
'switch_loss'
|
|
59
|
-
))
|
|
60
|
-
|
|
61
55
|
@save_load()
|
|
62
56
|
class MetaControllerWithBinaryMapper(Module):
|
|
63
57
|
def __init__(
|
|
@@ -148,7 +142,7 @@ class MetaControllerWithBinaryMapper(Module):
|
|
|
148
142
|
residual_stream,
|
|
149
143
|
cache: MetaControllerOutput | None = None,
|
|
150
144
|
discovery_phase = False,
|
|
151
|
-
hard_switch =
|
|
145
|
+
hard_switch = None,
|
|
152
146
|
temperature = 1.,
|
|
153
147
|
episode_lens: Tensor | None = None
|
|
154
148
|
):
|
|
@@ -164,6 +158,8 @@ class MetaControllerWithBinaryMapper(Module):
|
|
|
164
158
|
|
|
165
159
|
meta_embed = self.model_to_meta(residual_stream)
|
|
166
160
|
|
|
161
|
+
hard_switch = default(hard_switch, not discovery_phase) # I think that during the internal RL phase it needs to be a hard switch, so that only the actions emitted during the switch are reinforced
|
|
162
|
+
|
|
167
163
|
if discovery_phase:
|
|
168
164
|
mask = maybe(lens_to_mask)(episode_lens, meta_embed.shape[1])
|
|
169
165
|
|
|
@@ -265,4 +261,4 @@ class MetaControllerWithBinaryMapper(Module):
|
|
|
265
261
|
sampled_codes[:, -1:]
|
|
266
262
|
)
|
|
267
263
|
|
|
268
|
-
return control_signal, MetaControllerOutput(next_hiddens, binary_logits, sampled_codes, kl_loss, switch_loss)
|
|
264
|
+
return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, switch_beta, kl_loss, switch_loss)
|
{metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/.github/workflows/python-publish.yml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{metacontroller_pytorch-0.0.29 → metacontroller_pytorch-0.0.31}/train_behavior_clone_babyai.py
RENAMED
|
File without changes
|