metacontroller-pytorch 0.0.16__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -57,7 +57,8 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
57
57
  'prev_hiddens',
58
58
  'action_dist',
59
59
  'actions',
60
- 'kl_loss'
60
+ 'kl_loss',
61
+ 'switch_loss'
61
62
  ))
62
63
 
63
64
  class MetaController(Module):
@@ -173,7 +174,7 @@ class MetaController(Module):
173
174
 
174
175
  # need to encourage normal distribution
175
176
 
176
- kl_loss = self.zero
177
+ kl_loss = switch_loss = self.zero
177
178
 
178
179
  if discovery_phase:
179
180
  mean, log_var = action_dist.unbind(dim = -1)
@@ -188,6 +189,10 @@ class MetaController(Module):
188
189
  kl_loss = kl_loss * switch_beta
189
190
  kl_loss = kl_loss.sum(dim = -1).mean()
190
191
 
192
+ # encourage less switching
193
+
194
+ switch_loss = switch_beta.mean()
195
+
191
196
  # maybe hard switch, then use associative scan
192
197
 
193
198
  if hard_switch:
@@ -220,7 +225,7 @@ class MetaController(Module):
220
225
  next_switch_gated_action
221
226
  )
222
227
 
223
- return modified_residual_stream, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss)
228
+ return modified_residual_stream, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss, switch_loss)
224
229
 
225
230
  # main transformer, which is subsumed into the environment after behavioral cloning
226
231
 
@@ -357,7 +362,7 @@ class Transformer(Module):
357
362
 
358
363
  action_recon_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
359
364
 
360
- return action_recon_loss, next_meta_hiddens.kl_loss
365
+ return action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss
361
366
 
362
367
  # returning
363
368
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.16
3
+ Version: 0.0.17
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -0,0 +1,6 @@
1
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
+ metacontroller/metacontroller.py,sha256=blxDztbtXyP3cNbjnM3fEw_KZdLFJp_l1Sub6-7zIKg,12041
3
+ metacontroller_pytorch-0.0.17.dist-info/METADATA,sha256=_8hYYTO_ME23kgZXqSfhA1XXAA8W877F-AL8amA7LKM,3741
4
+ metacontroller_pytorch-0.0.17.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ metacontroller_pytorch-0.0.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ metacontroller_pytorch-0.0.17.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=BT7GH8F9NkEIYLEueBkkZ8glQ3Oht1FRoV84SIaTWdQ,11878
3
- metacontroller_pytorch-0.0.16.dist-info/METADATA,sha256=eyECb3994X58zyExLnnffMl3pOoMlIb-WAUhepIt0r8,3741
4
- metacontroller_pytorch-0.0.16.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- metacontroller_pytorch-0.0.16.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- metacontroller_pytorch-0.0.16.dist-info/RECORD,,