metacontroller-pytorch 0.0.4__tar.gz → 0.0.5__tar.gz

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.4
3
+ Version: 0.0.5
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -43,6 +43,11 @@ def default(*args):
43
43
  return arg
44
44
  return None
45
45
 
46
+ # tensor helpers
47
+
48
+ def straight_through(src, tgt):
49
+ return tgt + src - src.detach()
50
+
46
51
  # meta controller
47
52
 
48
53
  class MetaController(Module):
@@ -116,7 +121,8 @@ class MetaController(Module):
116
121
  def forward(
117
122
  self,
118
123
  residual_stream,
119
- discovery_phase = False
124
+ discovery_phase = False,
125
+ hard_switch = False
120
126
  ):
121
127
 
122
128
  if discovery_phase:
@@ -136,6 +142,17 @@ class MetaController(Module):
136
142
 
137
143
  sampled_action = readout.sample(action_dist)
138
144
 
145
+ # switching unit timer
146
+
147
+ batch, _, dim = sampled_action.shape
148
+
149
+ switching_unit_gru_out, switching_unit_gru_hidden = self.switching_unit(residual_stream)
150
+
151
+ switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
152
+
153
+ action_intent_for_gating = rearrange(sampled_action, 'b n d -> (b d) n')
154
+ switch_beta = repeat(switch_beta, 'b n d -> (b r d) n', r = dim if not self.switch_per_latent_dim else 1)
155
+
139
156
  # need to encourage normal distribution
140
157
 
141
158
  vae_kl_loss = self.zero
@@ -148,18 +165,16 @@ class MetaController(Module):
148
165
  + mean.square()
149
166
  - log_var
150
167
  - 1.
151
- )).sum(dim = -1).mean()
168
+ )).sum(dim = -1)
152
169
 
153
- # switching unit timer
170
+ vae_kl_loss = vae_kl_loss * switch_beta
171
+ vae_kl_loss = vae_kl_loss.mean()
154
172
 
155
- batch, _, dim = sampled_action.shape
173
+ # maybe hard switch, then use associative scan
156
174
 
157
- switching_unit_gru_out, switching_unit_gru_hidden = self.switching_unit(residual_stream)
158
-
159
- switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
160
-
161
- action_intent_for_gating = rearrange(sampled_action, 'b n d -> (b d) n')
162
- switch_beta = repeat(switch_beta, 'b n d -> (b r d) n', r = dim if not self.switch_per_latent_dim else 1)
175
+ if hard_switch:
176
+ hard_switch = (switch_beta > 0.5).float()
177
+ switch_beta = straight_through(switch_beta, hard_switch)
163
178
 
164
179
  forget = 1. - switch_beta
165
180
  gated_action_intent = self.switch_gating(action_intent_for_gating * forget, switch_beta)
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "metacontroller-pytorch"
3
- version = "0.0.4"
3
+ version = "0.0.5"
4
4
  description = "Transformer Metacontroller"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }