metacontroller-pytorch 0.0.30-py3-none-any.whl → 0.0.32-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of metacontroller-pytorch might be problematic.
- metacontroller/metacontroller.py +19 -4
- metacontroller/metacontroller_with_binary_mapper.py +26 -11
- {metacontroller_pytorch-0.0.30.dist-info → metacontroller_pytorch-0.0.32.dist-info}/METADATA +1 -1
- metacontroller_pytorch-0.0.32.dist-info/RECORD +8 -0
- metacontroller_pytorch-0.0.30.dist-info/RECORD +0 -8
- {metacontroller_pytorch-0.0.30.dist-info → metacontroller_pytorch-0.0.32.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.30.dist-info → metacontroller_pytorch-0.0.32.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
@@ -61,6 +61,7 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
     'input_residual_stream',
     'action_dist',
     'actions',
+    'switch_beta',
     'kl_loss',
     'switch_loss'
 ))
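The new `switch_beta` field lands between `actions` and `kl_loss`, so any consumer that unpacks `MetaControllerOutput` positionally has to be updated, while attribute access is unaffected. A minimal sketch (the leading `prev_hiddens` field name is an assumption, borrowed from the namedtuple the binary-mapper module used to define):

from collections import namedtuple

# fields as of 0.0.32; 'prev_hiddens' is an assumed name for the field above the hunk
MetaControllerOutput = namedtuple('MetaControllerOutput', (
    'prev_hiddens', 'input_residual_stream', 'action_dist',
    'actions', 'switch_beta', 'kl_loss', 'switch_loss'
))

out = MetaControllerOutput(None, None, None, None, 0.5, 0.0, 0.0)
assert out.switch_beta == 0.5       # attribute access is insertion-safe
*_, kl_loss, switch_loss = out      # positional unpacking now sees 7 fields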
@@ -106,7 +107,6 @@ class MetaController(Module):

         self.switch_per_latent_dim = switch_per_latent_dim

-
         self.dim_latent = dim_latent
         self.switching_unit = GRU(dim_meta + dim_latent, dim_meta)
         self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)
@@ -146,12 +146,19 @@ class MetaController(Module):
             *self.action_proposer_mean_log_var.parameters()
         ]

+    def log_prob(
+        self,
+        action_dist,
+        sampled_latent_action
+    ):
+        return self.action_proposer_mean_log_var.log_prob(action_dist, sampled_latent_action)
+
     def forward(
         self,
         residual_stream,
         cache: MetaControllerOutput | None = None,
         discovery_phase = False,
-        hard_switch =
+        hard_switch = None,
         temperature = 1.,
         episode_lens: Tensor | None = None
     ):
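The new `log_prob` simply delegates to the action proposer; what that submodule computes is not shown in this diff. If, as the name `action_proposer_mean_log_var` suggests, `action_dist` packs the mean and log-variance of a diagonal Gaussian, the delegated call would reduce to the standard density, sketched here with hypothetical tensor shapes:

import math
import torch

# hypothetical stand-in for action_proposer_mean_log_var.log_prob,
# assuming a diagonal Gaussian parameterized by (mean, log_var)
def gaussian_log_prob(mean, log_var, sample):
    return -0.5 * ((sample - mean) ** 2 / log_var.exp() + log_var + math.log(2 * math.pi))

mean = torch.zeros(2, 16)
log_var = torch.zeros(2, 16)
sample = torch.randn(2, 16)
print(gaussian_log_prob(mean, log_var, sample).shape)  # torch.Size([2, 16]), per-dimension log probs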
@@ -167,6 +174,8 @@ class MetaController(Module):

         meta_embed = self.model_to_meta(residual_stream)

+        hard_switch = default(hard_switch, not discovery_phase) # think during internal RL phase, it needs to be a hard switch, then only the actions emitted during the switch is reinforced
+
         if discovery_phase:
             logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
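With `hard_switch = None` at the signature and `default(...)` in the body, inference-time calls (no discovery) now fall back to a hard switch unless the caller overrides it. A later hunk header in this diff shows the package's own `default(*args)` is variadic; a two-argument sketch of the same idiom:

def exists(v):
    return v is not None

# minimal two-argument sketch; the package's own helper is variadic (default(*args))
def default(v, d):
    return v if exists(v) else d

discovery_phase = False
hard_switch = default(None, not discovery_phase)
assert hard_switch is True   # the internal RL phase gets a hard switch by default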
@@ -269,10 +278,16 @@ class MetaController(Module):
             sampled_latent_action[:, -1:]
         )

-        return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, kl_loss, switch_loss)
+        return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, switch_beta, kl_loss, switch_loss)

 # main transformer, which is subsumed into the environment after behavioral cloning

+Hiddens = namedtuple('Hiddens', (
+    'lower_body',
+    'meta_controller',
+    'upper_body'
+))
+
 TransformerOutput = namedtuple('TransformerOutput', (
     'residual_stream_latent',
     'prev_hiddens'
@@ -438,4 +453,4 @@ class Transformer(Module):
         if return_one:
             return dist_params

-        return dist_params, TransformerOutput(residual_stream, (next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
+        return dist_params, TransformerOutput(residual_stream, Hiddens(next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
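Since namedtuples subclass tuple, swapping the bare 3-tuple for `Hiddens` is backward compatible for positional indexing while naming each hidden-state slot; a self-contained check with placeholder values:

from collections import namedtuple

Hiddens = namedtuple('Hiddens', ('lower_body', 'meta_controller', 'upper_body'))
TransformerOutput = namedtuple('TransformerOutput', ('residual_stream_latent', 'prev_hiddens'))

out = TransformerOutput('latents', Hiddens('lower', 'meta', 'upper'))

# named access and the old positional access agree
assert out.prev_hiddens.meta_controller == out[1][1]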
metacontroller/metacontroller_with_binary_mapper.py
CHANGED
@@ -28,6 +28,8 @@ from torch_einops_utils.save_load import save_load

 from vector_quantize_pytorch import BinaryMapper

+from metacontroller.metacontroller import MetaControllerOutput
+
 # constants

 LinearNoBias = partial(Linear, bias = False)
@@ -48,16 +50,10 @@ def default(*args):
 def straight_through(src, tgt):
     return tgt + src - src.detach()

-
+def log(t, eps = 1e-20):
+    return t.clamp_min(eps).log()

-MetaControllerOutput = namedtuple('MetaControllerOutput', (
-    'prev_hiddens',
-    'input_residual_stream',
-    'action_dist',
-    'codes',
-    'kl_loss',
-    'switch_loss'
-))
+# meta controller

 @save_load()
 class MetaControllerWithBinaryMapper(Module):
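The new `log` helper clamps its input before taking the logarithm, so exact zeros (which a saturated sigmoid can produce) give a large finite negative value instead of -inf; a quick check:

import torch

def log(t, eps = 1e-20):
    return t.clamp_min(eps).log()

probs = torch.tensor([0., 0.5, 1.])
print(torch.log(probs))  # tensor([   -inf, -0.6931,  0.0000])
print(log(probs))        # tensor([-46.0517, -0.6931,  0.0000]) -- finite everywhere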
@@ -144,12 +140,29 @@ class MetaControllerWithBinaryMapper(Module):
             *self.proposer_to_binary_logits.parameters()
         ]

+    def log_prob(
+        self,
+        action_dist,
+        sampled_latent_action
+    ):
+        action_prob = action_dist.sigmoid()
+        probs = stack((action_prob, 1. - action_prob), dim = -1)
+        log_probs = log(probs)
+
+        indices = sampled_latent_action.argmax(dim = -1)
+        codes = self.binary_mapper.codes[indices].long()
+
+        codes = rearrange(codes, '... -> ... 1')
+        action_log_probs = log_probs.gather(-1, codes)
+
+        return rearrange(action_log_probs, '... 1 -> ...')
+
     def forward(
         self,
         residual_stream,
         cache: MetaControllerOutput | None = None,
         discovery_phase = False,
-        hard_switch =
+        hard_switch = None,
         temperature = 1.,
         episode_lens: Tensor | None = None
     ):
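The binary-mapper `log_prob` treats each latent bit as a Bernoulli: it stacks p and 1 - p along a trailing axis, looks up the sampled code's bit pattern in the mapper's code table, and gathers the per-bit log probability. A standalone re-creation of the gather mechanics (the 0/1 layout of `binary_mapper.codes` is an assumption about BinaryMapper internals):

import torch

def log(t, eps = 1e-20):
    return t.clamp_min(eps).log()

codes_table = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])    # assumed (num_codes, bits) table

logits = torch.randn(2, 2)                                      # per-bit logits, batch of 2
p = logits.sigmoid()
log_probs = log(torch.stack((p, 1. - p), dim = -1))             # (batch, bits, 2)

one_hot = torch.nn.functional.one_hot(torch.tensor([1, 3]), 4)  # sampled codes as one-hot
bits = codes_table[one_hot.argmax(dim = -1)].unsqueeze(-1)      # (batch, bits, 1)

per_bit_log_prob = log_probs.gather(-1, bits).squeeze(-1)       # (batch, bits)
print(per_bit_log_prob.shape)  # torch.Size([2, 2])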
@@ -165,6 +178,8 @@ class MetaControllerWithBinaryMapper(Module):

         meta_embed = self.model_to_meta(residual_stream)

+        hard_switch = default(hard_switch, not discovery_phase) # think during internal RL phase, it needs to be a hard switch, then only the actions emitted during the switch is reinforced
+
         if discovery_phase:
             mask = maybe(lens_to_mask)(episode_lens, meta_embed.shape[1])
@@ -266,4 +281,4 @@ class MetaControllerWithBinaryMapper(Module):
             sampled_codes[:, -1:]
         )

-        return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, kl_loss, switch_loss)
+        return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, switch_beta, kl_loss, switch_loss)
metacontroller_pytorch-0.0.32.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+metacontroller/metacontroller.py,sha256=somE9gX36c1d9hF2n8Qn4foRY8krHGodvrvulhkIGE8,15006
+metacontroller/metacontroller_with_binary_mapper.py,sha256=CTGK8ruQ3TkioVUwFTHdrbfzubaeuhSdXHfHtaDcwMY,8813
+metacontroller/metacontroller_with_resnet.py,sha256=YKHcazRZrrRParHRH-H_EPvT1-55LHKAs5pM6gwuT20,7394
+metacontroller_pytorch-0.0.32.dist-info/METADATA,sha256=hr08iXm6Mb-rnDu2xPrr9YQ6cwTtX1F79MfBYt54Y94,4747
+metacontroller_pytorch-0.0.32.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.32.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.32.dist-info/RECORD,,
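Each RECORD row is `path,sha256=<digest>,size`, where the digest is the urlsafe base64 encoding of the file's SHA-256 with the `=` padding stripped (the RECORD file itself carries empty hash and size fields, per the wheel spec). A sketch that recomputes a row's digest from an unpacked wheel:

import base64
import hashlib

def record_hash(path):
    with open(path, 'rb') as f:
        digest = hashlib.sha256(f.read()).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b'=').decode()

# compare against the metacontroller/metacontroller.py row above
# (assumes the 0.0.32 wheel is unpacked in the current directory)
print(record_hash('metacontroller/metacontroller.py') == 'somE9gX36c1d9hF2n8Qn4foRY8krHGodvrvulhkIGE8')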
metacontroller_pytorch-0.0.30.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
-metacontroller/metacontroller.py,sha256=ydkL3gYW5WGXQdQOIJQ_gibJs74laIIx-v4DmcJHi7M,14497
-metacontroller/metacontroller_with_binary_mapper.py,sha256=OGal6dftRPBg_QT1LNDYejNGNlmh4MBvdM41FAQJp9Y,8153
-metacontroller/metacontroller_with_resnet.py,sha256=YKHcazRZrrRParHRH-H_EPvT1-55LHKAs5pM6gwuT20,7394
-metacontroller_pytorch-0.0.30.dist-info/METADATA,sha256=ghasc1GA0ZM-AZimY0FnGuRFsezVIcbI49V6TIOWeq4,4747
-metacontroller_pytorch-0.0.30.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-metacontroller_pytorch-0.0.30.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-metacontroller_pytorch-0.0.30.dist-info/RECORD,,
{metacontroller_pytorch-0.0.30.dist-info → metacontroller_pytorch-0.0.32.dist-info}/WHEEL
RENAMED
File without changes
{metacontroller_pytorch-0.0.30.dist-info → metacontroller_pytorch-0.0.32.dist-info}/licenses/LICENSE
RENAMED
File without changes