metacontroller-pytorch 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -150,8 +150,7 @@ class MetaController(Module):
150
150
 
151
151
  switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
152
152
 
153
- action_intent_for_gating = rearrange(sampled_action, 'b n d -> (b d) n')
154
- switch_beta = repeat(switch_beta, 'b n d -> (b r d) n', r = dim if not self.switch_per_latent_dim else 1)
153
+ # switch_beta = switch_beta.expand_as(sampled_action)
155
154
 
156
155
  # need to encourage normal distribution
157
156
 
@@ -165,10 +164,10 @@ class MetaController(Module):
165
164
  + mean.square()
166
165
  - log_var
167
166
  - 1.
168
- )).sum(dim = -1)
167
+ ))
169
168
 
170
169
  vae_kl_loss = vae_kl_loss * switch_beta
171
- vae_kl_loss = vae_kl_loss.mean()
170
+ vae_kl_loss = vae_kl_loss.sum(dim = -1).mean()
172
171
 
173
172
  # maybe hard switch, then use associative scan
174
173
 
@@ -177,20 +176,18 @@ class MetaController(Module):
177
176
  switch_beta = straight_through(switch_beta, hard_switch)
178
177
 
179
178
  forget = 1. - switch_beta
180
- gated_action_intent = self.switch_gating(action_intent_for_gating * forget, switch_beta)
181
-
182
- gated_action_intent = rearrange(gated_action_intent, '(b d) n -> b n d', b = batch)
179
+ gated_action = self.switch_gating(switch_beta, sampled_action * forget)
183
180
 
184
181
  # decoder
185
182
 
186
- decoder_out = self.decoder(gated_action_intent)
183
+ decoder_out = self.decoder(gated_action)
187
184
 
188
185
  w1, w2 = self.to_hyper_network_weights(decoder_out)
189
186
  hypernetwork_weight = einsum(w1, w2, '... i r, ... j r -> ... i j')
190
187
 
191
188
  # generating the residual stream controlling signal
192
189
 
193
- control_signal = einsum(gated_action_intent, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
190
+ control_signal = einsum(gated_action, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
194
191
 
195
192
  modified_residual_stream = residual_stream + control_signal
196
193
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.5
3
+ Version: 0.0.6
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -34,7 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
34
34
  Classifier: Programming Language :: Python :: 3.9
35
35
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.9
37
- Requires-Dist: assoc-scan
37
+ Requires-Dist: assoc-scan>=0.0.3
38
38
  Requires-Dist: discrete-continuous-embed-readout>=0.1.11
39
39
  Requires-Dist: einops>=0.8.1
40
40
  Requires-Dist: einx>=0.3.0
@@ -0,0 +1,6 @@
1
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
+ metacontroller/metacontroller.py,sha256=8AWkPlZWh2A1sRT6nMV0CGHuYhQ5pHEpd5bgyRmZelg,8316
3
+ metacontroller_pytorch-0.0.6.dist-info/METADATA,sha256=vXT_-n3bHgpddnS5axyyc-cADGNk4l2enJv4g4cTJ7A,3713
4
+ metacontroller_pytorch-0.0.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ metacontroller_pytorch-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ metacontroller_pytorch-0.0.6.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=-glgWcv6QQZ6wVAy6tK2Ye8QNXuWBiGmB1rs6DVWA1I,8573
3
- metacontroller_pytorch-0.0.5.dist-info/METADATA,sha256=CfXW_uO8B9gz31XkUO-2aVl4TN64iYicdZPtW7DzzHc,3706
4
- metacontroller_pytorch-0.0.5.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- metacontroller_pytorch-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- metacontroller_pytorch-0.0.5.dist-info/RECORD,,