metacontroller-pytorch 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metacontroller/metacontroller.py +16 -4
- {metacontroller_pytorch-0.0.15.dist-info → metacontroller_pytorch-0.0.17.dist-info}/METADATA +2 -2
- metacontroller_pytorch-0.0.17.dist-info/RECORD +6 -0
- metacontroller_pytorch-0.0.15.dist-info/RECORD +0 -6
- {metacontroller_pytorch-0.0.15.dist-info → metacontroller_pytorch-0.0.17.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.15.dist-info → metacontroller_pytorch-0.0.17.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
|
@@ -57,7 +57,8 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
|
|
|
57
57
|
'prev_hiddens',
|
|
58
58
|
'action_dist',
|
|
59
59
|
'actions',
|
|
60
|
-
'kl_loss'
|
|
60
|
+
'kl_loss',
|
|
61
|
+
'switch_loss'
|
|
61
62
|
))
|
|
62
63
|
|
|
63
64
|
class MetaController(Module):
|
|
@@ -173,7 +174,7 @@ class MetaController(Module):
|
|
|
173
174
|
|
|
174
175
|
# need to encourage normal distribution
|
|
175
176
|
|
|
176
|
-
kl_loss = self.zero
|
|
177
|
+
kl_loss = switch_loss = self.zero
|
|
177
178
|
|
|
178
179
|
if discovery_phase:
|
|
179
180
|
mean, log_var = action_dist.unbind(dim = -1)
|
|
@@ -188,6 +189,10 @@ class MetaController(Module):
|
|
|
188
189
|
kl_loss = kl_loss * switch_beta
|
|
189
190
|
kl_loss = kl_loss.sum(dim = -1).mean()
|
|
190
191
|
|
|
192
|
+
# encourage less switching
|
|
193
|
+
|
|
194
|
+
switch_loss = switch_beta.mean()
|
|
195
|
+
|
|
191
196
|
# maybe hard switch, then use associative scan
|
|
192
197
|
|
|
193
198
|
if hard_switch:
|
|
@@ -220,7 +225,7 @@ class MetaController(Module):
|
|
|
220
225
|
next_switch_gated_action
|
|
221
226
|
)
|
|
222
227
|
|
|
223
|
-
return modified_residual_stream, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss)
|
|
228
|
+
return modified_residual_stream, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss, switch_loss)
|
|
224
229
|
|
|
225
230
|
# main transformer, which is subsumed into the environment after behavioral cloning
|
|
226
231
|
|
|
@@ -308,7 +313,8 @@ class Transformer(Module):
|
|
|
308
313
|
|
|
309
314
|
# handle maybe behavioral cloning
|
|
310
315
|
|
|
311
|
-
if behavioral_cloning:
|
|
316
|
+
if behavioral_cloning or (meta_controlling and discovery_phase):
|
|
317
|
+
|
|
312
318
|
state, target_state = state[:, :-1], state[:, 1:]
|
|
313
319
|
action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]
|
|
314
320
|
|
|
@@ -352,6 +358,12 @@ class Transformer(Module):
|
|
|
352
358
|
|
|
353
359
|
return state_clone_loss, action_clone_loss
|
|
354
360
|
|
|
361
|
+
elif meta_controlling and discovery_phase:
|
|
362
|
+
|
|
363
|
+
action_recon_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
|
|
364
|
+
|
|
365
|
+
return action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss
|
|
366
|
+
|
|
355
367
|
# returning
|
|
356
368
|
|
|
357
369
|
return_one = not (return_latents or return_cache)
|
{metacontroller_pytorch-0.0.15.dist-info → metacontroller_pytorch-0.0.17.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: metacontroller-pytorch
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.17
|
|
4
4
|
Summary: Transformer Metacontroller
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/metacontroller/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/metacontroller
|
|
@@ -60,7 +60,7 @@ Implementation of the MetaController proposed in [Emergent temporal abstractions
|
|
|
60
60
|
@misc{kobayashi2025emergenttemporalabstractionsautoregressive,
|
|
61
61
|
title = {Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning},
|
|
62
62
|
author = {Seijin Kobayashi and Yanick Schimpf and Maximilian Schlegel and Angelika Steger and Maciej Wolczyk and Johannes von Oswald and Nino Scherrer and Kaitlin Maile and Guillaume Lajoie and Blake A. Richards and Rif A. Saurous and James Manyika and Blaise Agüera y Arcas and Alexander Meulemans and João Sacramento},
|
|
63
|
-
year={2025},
|
|
63
|
+
year = {2025},
|
|
64
64
|
eprint = {2512.20605},
|
|
65
65
|
archivePrefix = {arXiv},
|
|
66
66
|
primaryClass = {cs.LG},
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
+
metacontroller/metacontroller.py,sha256=blxDztbtXyP3cNbjnM3fEw_KZdLFJp_l1Sub6-7zIKg,12041
|
|
3
|
+
metacontroller_pytorch-0.0.17.dist-info/METADATA,sha256=_8hYYTO_ME23kgZXqSfhA1XXAA8W877F-AL8amA7LKM,3741
|
|
4
|
+
metacontroller_pytorch-0.0.17.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
+
metacontroller_pytorch-0.0.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
metacontroller_pytorch-0.0.17.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
-
metacontroller/metacontroller.py,sha256=ug3xeMTZKApTF8oOPx9hWypeDjRflf1IJp8RiysXgTo,11618
|
|
3
|
-
metacontroller_pytorch-0.0.15.dist-info/METADATA,sha256=9d39BpcuVeOVVSD66lCVHCK1GjrkeKzRtxKOPOc-7xQ,3736
|
|
4
|
-
metacontroller_pytorch-0.0.15.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
-
metacontroller_pytorch-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
metacontroller_pytorch-0.0.15.dist-info/RECORD,,
|
|
File without changes
|
{metacontroller_pytorch-0.0.15.dist-info → metacontroller_pytorch-0.0.17.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|