metacontroller-pytorch 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -43,6 +43,11 @@ def default(*args):
43
43
  return arg
44
44
  return None
45
45
 
46
+ # tensor helpers
47
+
48
+ def straight_through(src, tgt):
49
+ return tgt + src - src.detach()
50
+
46
51
  # meta controller
47
52
 
48
53
  class MetaController(Module):
@@ -116,7 +121,8 @@ class MetaController(Module):
116
121
  def forward(
117
122
  self,
118
123
  residual_stream,
119
- discovery_phase = False
124
+ discovery_phase = False,
125
+ hard_switch = False
120
126
  ):
121
127
 
122
128
  if discovery_phase:
@@ -136,6 +142,16 @@ class MetaController(Module):
136
142
 
137
143
  sampled_action = readout.sample(action_dist)
138
144
 
145
+ # switching unit timer
146
+
147
+ batch, _, dim = sampled_action.shape
148
+
149
+ switching_unit_gru_out, switching_unit_gru_hidden = self.switching_unit(residual_stream)
150
+
151
+ switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
152
+
153
+ # switch_beta = switch_beta.expand_as(sampled_action)
154
+
139
155
  # need to encourage normal distribution
140
156
 
141
157
  vae_kl_loss = self.zero
@@ -148,34 +164,30 @@ class MetaController(Module):
148
164
  + mean.square()
149
165
  - log_var
150
166
  - 1.
151
- )).sum(dim = -1).mean()
167
+ ))
152
168
 
153
- # switching unit timer
169
+ vae_kl_loss = vae_kl_loss * switch_beta
170
+ vae_kl_loss = vae_kl_loss.sum(dim = -1).mean()
154
171
 
155
- batch, _, dim = sampled_action.shape
172
+ # maybe hard switch, then use associative scan
156
173
 
157
- switching_unit_gru_out, switching_unit_gru_hidden = self.switching_unit(residual_stream)
158
-
159
- switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
160
-
161
- action_intent_for_gating = rearrange(sampled_action, 'b n d -> (b d) n')
162
- switch_beta = repeat(switch_beta, 'b n d -> (b r d) n', r = dim if not self.switch_per_latent_dim else 1)
174
+ if hard_switch:
175
+ hard_switch = (switch_beta > 0.5).float()
176
+ switch_beta = straight_through(switch_beta, hard_switch)
163
177
 
164
178
  forget = 1. - switch_beta
165
- gated_action_intent = self.switch_gating(action_intent_for_gating * forget, switch_beta)
166
-
167
- gated_action_intent = rearrange(gated_action_intent, '(b d) n -> b n d', b = batch)
179
+ gated_action = self.switch_gating(switch_beta, sampled_action * forget)
168
180
 
169
181
  # decoder
170
182
 
171
- decoder_out = self.decoder(gated_action_intent)
183
+ decoder_out = self.decoder(gated_action)
172
184
 
173
185
  w1, w2 = self.to_hyper_network_weights(decoder_out)
174
186
  hypernetwork_weight = einsum(w1, w2, '... i r, ... j r -> ... i j')
175
187
 
176
188
  # generating the residual stream controlling signal
177
189
 
178
- control_signal = einsum(gated_action_intent, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
190
+ control_signal = einsum(gated_action, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
179
191
 
180
192
  modified_residual_stream = residual_stream + control_signal
181
193
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.4
3
+ Version: 0.0.6
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -34,7 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
34
34
  Classifier: Programming Language :: Python :: 3.9
35
35
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.9
37
- Requires-Dist: assoc-scan
37
+ Requires-Dist: assoc-scan>=0.0.3
38
38
  Requires-Dist: discrete-continuous-embed-readout>=0.1.11
39
39
  Requires-Dist: einops>=0.8.1
40
40
  Requires-Dist: einx>=0.3.0
@@ -0,0 +1,6 @@
1
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
+ metacontroller/metacontroller.py,sha256=8AWkPlZWh2A1sRT6nMV0CGHuYhQ5pHEpd5bgyRmZelg,8316
3
+ metacontroller_pytorch-0.0.6.dist-info/METADATA,sha256=vXT_-n3bHgpddnS5axyyc-cADGNk4l2enJv4g4cTJ7A,3713
4
+ metacontroller_pytorch-0.0.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ metacontroller_pytorch-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ metacontroller_pytorch-0.0.6.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=fGbNUdxTYGgHBdINJziUINDYmANkNLy0yiTIt4dycKM,8162
3
- metacontroller_pytorch-0.0.4.dist-info/METADATA,sha256=wfeiKctuqzj_NlWq2Xg5hbgjs6bzMmgL-VdTCzgceS8,3706
4
- metacontroller_pytorch-0.0.4.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- metacontroller_pytorch-0.0.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- metacontroller_pytorch-0.0.4.dist-info/RECORD,,