metacontroller-pytorch 0.0.16__tar.gz → 0.0.17__tar.gz
This diff shows the differences between two publicly available versions of a package, each of which has been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.17}/PKG-INFO +1 -1
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.17}/metacontroller/metacontroller.py +9 -4
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.17}/pyproject.toml +1 -1
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.17}/tests/test_metacontroller.py +2 -2
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.17}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.17}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.17}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.17}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.17}/README.md +0 -0
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.17}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.17}/metacontroller/__init__.py +0 -0
{metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.17}/metacontroller/metacontroller.py
RENAMED
|
@@ -57,7 +57,8 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
|
|
|
57
57
|
'prev_hiddens',
|
|
58
58
|
'action_dist',
|
|
59
59
|
'actions',
|
|
60
|
-
'kl_loss'
|
|
60
|
+
'kl_loss',
|
|
61
|
+
'switch_loss'
|
|
61
62
|
))
|
|
62
63
|
|
|
63
64
|
class MetaController(Module):
|
|
@@ -173,7 +174,7 @@ class MetaController(Module):
|
|
|
173
174
|
|
|
174
175
|
# need to encourage normal distribution
|
|
175
176
|
|
|
176
|
-
kl_loss = self.zero
|
|
177
|
+
kl_loss = switch_loss = self.zero
|
|
177
178
|
|
|
178
179
|
if discovery_phase:
|
|
179
180
|
mean, log_var = action_dist.unbind(dim = -1)
|
|
@@ -188,6 +189,10 @@ class MetaController(Module):
|
|
|
188
189
|
kl_loss = kl_loss * switch_beta
|
|
189
190
|
kl_loss = kl_loss.sum(dim = -1).mean()
|
|
190
191
|
|
|
192
|
+
# encourage less switching
|
|
193
|
+
|
|
194
|
+
switch_loss = switch_beta.mean()
|
|
195
|
+
|
|
191
196
|
# maybe hard switch, then use associative scan
|
|
192
197
|
|
|
193
198
|
if hard_switch:
|
|
@@ -220,7 +225,7 @@ class MetaController(Module):
|
|
|
220
225
|
next_switch_gated_action
|
|
221
226
|
)
|
|
222
227
|
|
|
223
|
-
return modified_residual_stream, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss)
|
|
228
|
+
return modified_residual_stream, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss, switch_loss)
|
|
224
229
|
|
|
225
230
|
# main transformer, which is subsumed into the environment after behavioral cloning
|
|
226
231
|
|
|
@@ -357,7 +362,7 @@ class Transformer(Module):
|
|
|
357
362
|
|
|
358
363
|
action_recon_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
|
|
359
364
|
|
|
360
|
-
return action_recon_loss, next_meta_hiddens.kl_loss
|
|
365
|
+
return action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss
|
|
361
366
|
|
|
362
367
|
# returning
|
|
363
368
|
|
{metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.17}/tests/test_metacontroller.py
RENAMED
|
@@ -44,8 +44,8 @@ def test_metacontroller(
|
|
|
44
44
|
|
|
45
45
|
# discovery phase
|
|
46
46
|
|
|
47
|
-
(action_recon_loss, kl_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True)
|
|
48
|
-
(action_recon_loss + kl_loss * 0.1).backward()
|
|
47
|
+
(action_recon_loss, kl_loss, switch_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True)
|
|
48
|
+
(action_recon_loss + kl_loss * 0.1 + switch_loss * 0.2).backward()
|
|
49
49
|
|
|
50
50
|
# internal rl
|
|
51
51
|
|
{metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.17}/.github/workflows/python-publish.yml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|