metacontroller-pytorch 0.0.30__py3-none-any.whl → 0.0.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of metacontroller-pytorch might be problematic. Click here for more details.
- metacontroller/metacontroller.py +5 -2
- metacontroller/metacontroller_with_binary_mapper.py +6 -11
- {metacontroller_pytorch-0.0.30.dist-info → metacontroller_pytorch-0.0.31.dist-info}/METADATA +1 -1
- metacontroller_pytorch-0.0.31.dist-info/RECORD +8 -0
- metacontroller_pytorch-0.0.30.dist-info/RECORD +0 -8
- {metacontroller_pytorch-0.0.30.dist-info → metacontroller_pytorch-0.0.31.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.30.dist-info → metacontroller_pytorch-0.0.31.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
|
@@ -61,6 +61,7 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
|
|
|
61
61
|
'input_residual_stream',
|
|
62
62
|
'action_dist',
|
|
63
63
|
'actions',
|
|
64
|
+
'switch_beta',
|
|
64
65
|
'kl_loss',
|
|
65
66
|
'switch_loss'
|
|
66
67
|
))
|
|
@@ -151,7 +152,7 @@ class MetaController(Module):
|
|
|
151
152
|
residual_stream,
|
|
152
153
|
cache: MetaControllerOutput | None = None,
|
|
153
154
|
discovery_phase = False,
|
|
154
|
-
hard_switch =
|
|
155
|
+
hard_switch = None,
|
|
155
156
|
temperature = 1.,
|
|
156
157
|
episode_lens: Tensor | None = None
|
|
157
158
|
):
|
|
@@ -167,6 +168,8 @@ class MetaController(Module):
|
|
|
167
168
|
|
|
168
169
|
meta_embed = self.model_to_meta(residual_stream)
|
|
169
170
|
|
|
171
|
+
hard_switch = default(hard_switch, not discovery_phase) # presumably, during the internal RL phase this needs to be a hard switch, so that only the actions emitted during the switch are reinforced
|
|
172
|
+
|
|
170
173
|
if discovery_phase:
|
|
171
174
|
logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
|
|
172
175
|
|
|
@@ -269,7 +272,7 @@ class MetaController(Module):
|
|
|
269
272
|
sampled_latent_action[:, -1:]
|
|
270
273
|
)
|
|
271
274
|
|
|
272
|
-
return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, kl_loss, switch_loss)
|
|
275
|
+
return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, switch_beta, kl_loss, switch_loss)
|
|
273
276
|
|
|
274
277
|
# main transformer, which is subsumed into the environment after behavioral cloning
|
|
275
278
|
|
|
metacontroller/metacontroller_with_binary_mapper.py
CHANGED
|
@@ -28,6 +28,8 @@ from torch_einops_utils.save_load import save_load
|
|
|
28
28
|
|
|
29
29
|
from vector_quantize_pytorch import BinaryMapper
|
|
30
30
|
|
|
31
|
+
from metacontroller.metacontroller import MetaControllerOutput
|
|
32
|
+
|
|
31
33
|
# constants
|
|
32
34
|
|
|
33
35
|
LinearNoBias = partial(Linear, bias = False)
|
|
@@ -50,15 +52,6 @@ def straight_through(src, tgt):
|
|
|
50
52
|
|
|
51
53
|
# meta controller
|
|
52
54
|
|
|
53
|
-
MetaControllerOutput = namedtuple('MetaControllerOutput', (
|
|
54
|
-
'prev_hiddens',
|
|
55
|
-
'input_residual_stream',
|
|
56
|
-
'action_dist',
|
|
57
|
-
'codes',
|
|
58
|
-
'kl_loss',
|
|
59
|
-
'switch_loss'
|
|
60
|
-
))
|
|
61
|
-
|
|
62
55
|
@save_load()
|
|
63
56
|
class MetaControllerWithBinaryMapper(Module):
|
|
64
57
|
def __init__(
|
|
@@ -149,7 +142,7 @@ class MetaControllerWithBinaryMapper(Module):
|
|
|
149
142
|
residual_stream,
|
|
150
143
|
cache: MetaControllerOutput | None = None,
|
|
151
144
|
discovery_phase = False,
|
|
152
|
-
hard_switch =
|
|
145
|
+
hard_switch = None,
|
|
153
146
|
temperature = 1.,
|
|
154
147
|
episode_lens: Tensor | None = None
|
|
155
148
|
):
|
|
@@ -165,6 +158,8 @@ class MetaControllerWithBinaryMapper(Module):
|
|
|
165
158
|
|
|
166
159
|
meta_embed = self.model_to_meta(residual_stream)
|
|
167
160
|
|
|
161
|
+
hard_switch = default(hard_switch, not discovery_phase) # presumably, during the internal RL phase this needs to be a hard switch, so that only the actions emitted during the switch are reinforced
|
|
162
|
+
|
|
168
163
|
if discovery_phase:
|
|
169
164
|
mask = maybe(lens_to_mask)(episode_lens, meta_embed.shape[1])
|
|
170
165
|
|
|
@@ -266,4 +261,4 @@ class MetaControllerWithBinaryMapper(Module):
|
|
|
266
261
|
sampled_codes[:, -1:]
|
|
267
262
|
)
|
|
268
263
|
|
|
269
|
-
return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, kl_loss, switch_loss)
|
|
264
|
+
return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, switch_beta, kl_loss, switch_loss)
|
|
metacontroller_pytorch-0.0.31.dist-info/RECORD
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
+
metacontroller/metacontroller.py,sha256=lxWgeWFcXxSDm-ygd14DjyEOYIJIALcuLkoRAfEzNtc,14719
|
|
3
|
+
metacontroller/metacontroller_with_binary_mapper.py,sha256=BrsQdkhlOyR2O5xAXTLC4p-uKOAbW7wET-lVU0qktws,8242
|
|
4
|
+
metacontroller/metacontroller_with_resnet.py,sha256=YKHcazRZrrRParHRH-H_EPvT1-55LHKAs5pM6gwuT20,7394
|
|
5
|
+
metacontroller_pytorch-0.0.31.dist-info/METADATA,sha256=mtOtYymI01jBMO7pyaAIJ166B5Mk3khH8CUwUMNLTKw,4747
|
|
6
|
+
metacontroller_pytorch-0.0.31.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
7
|
+
metacontroller_pytorch-0.0.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
metacontroller_pytorch-0.0.31.dist-info/RECORD,,
|
|
metacontroller_pytorch-0.0.30.dist-info/RECORD
REMOVED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
-
metacontroller/metacontroller.py,sha256=ydkL3gYW5WGXQdQOIJQ_gibJs74laIIx-v4DmcJHi7M,14497
|
|
3
|
-
metacontroller/metacontroller_with_binary_mapper.py,sha256=OGal6dftRPBg_QT1LNDYejNGNlmh4MBvdM41FAQJp9Y,8153
|
|
4
|
-
metacontroller/metacontroller_with_resnet.py,sha256=YKHcazRZrrRParHRH-H_EPvT1-55LHKAs5pM6gwuT20,7394
|
|
5
|
-
metacontroller_pytorch-0.0.30.dist-info/METADATA,sha256=ghasc1GA0ZM-AZimY0FnGuRFsezVIcbI49V6TIOWeq4,4747
|
|
6
|
-
metacontroller_pytorch-0.0.30.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
7
|
-
metacontroller_pytorch-0.0.30.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
-
metacontroller_pytorch-0.0.30.dist-info/RECORD,,
|
|
File without changes
|
{metacontroller_pytorch-0.0.30.dist-info → metacontroller_pytorch-0.0.31.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|