metacontroller-pytorch 0.0.5__tar.gz → 0.0.8__tar.gz

This diff shows the content changes between publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the differences between the package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.5
3
+ Version: 0.0.8
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -34,7 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
34
34
  Classifier: Programming Language :: Python :: 3.9
35
35
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.9
37
- Requires-Dist: assoc-scan
37
+ Requires-Dist: assoc-scan>=0.0.3
38
38
  Requires-Dist: discrete-continuous-embed-readout>=0.1.11
39
39
  Requires-Dist: einops>=0.8.1
40
40
  Requires-Dist: einx>=0.3.0
@@ -113,9 +113,7 @@ class MetaController(Module):
113
113
  def internal_rl_parameters(self):
114
114
  return [
115
115
  *self.action_proposer.parameters(),
116
- *self.action_proposer_mean_log_var.parameters(),
117
- *self.decoder.parameters(),
118
- *self.switch_gating
116
+ *self.action_proposer_mean_log_var.parameters()
119
117
  ]
120
118
 
121
119
  def forward(
@@ -150,9 +148,6 @@ class MetaController(Module):
150
148
 
151
149
  switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
152
150
 
153
- action_intent_for_gating = rearrange(sampled_action, 'b n d -> (b d) n')
154
- switch_beta = repeat(switch_beta, 'b n d -> (b r d) n', r = dim if not self.switch_per_latent_dim else 1)
155
-
156
151
  # need to encourage normal distribution
157
152
 
158
153
  vae_kl_loss = self.zero
@@ -165,10 +160,10 @@ class MetaController(Module):
165
160
  + mean.square()
166
161
  - log_var
167
162
  - 1.
168
- )).sum(dim = -1)
163
+ ))
169
164
 
170
165
  vae_kl_loss = vae_kl_loss * switch_beta
171
- vae_kl_loss = vae_kl_loss.mean()
166
+ vae_kl_loss = vae_kl_loss.sum(dim = -1).mean()
172
167
 
173
168
  # maybe hard switch, then use associative scan
174
169
 
@@ -177,20 +172,18 @@ class MetaController(Module):
177
172
  switch_beta = straight_through(switch_beta, hard_switch)
178
173
 
179
174
  forget = 1. - switch_beta
180
- gated_action_intent = self.switch_gating(action_intent_for_gating * forget, switch_beta)
181
-
182
- gated_action_intent = rearrange(gated_action_intent, '(b d) n -> b n d', b = batch)
175
+ gated_action = self.switch_gating(switch_beta, sampled_action * forget)
183
176
 
184
177
  # decoder
185
178
 
186
- decoder_out = self.decoder(gated_action_intent)
179
+ decoder_out = self.decoder(gated_action)
187
180
 
188
181
  w1, w2 = self.to_hyper_network_weights(decoder_out)
189
182
  hypernetwork_weight = einsum(w1, w2, '... i r, ... j r -> ... i j')
190
183
 
191
184
  # generating the residual stream controlling signal
192
185
 
193
- control_signal = einsum(gated_action_intent, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
186
+ control_signal = einsum(gated_action, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
194
187
 
195
188
  modified_residual_stream = residual_stream + control_signal
196
189
 
@@ -236,13 +229,15 @@ class Transformer(Module):
236
229
 
237
230
  def evolve(
238
231
  self,
232
+ num_generations,
239
233
  environment,
240
234
  **kwargs
241
235
  ):
242
- assert exists(self.meta_controller), '`meta_controller` must be defined on init for evolutionary strategies to be straightforwardly applied'
236
+ assert exists(self.meta_controller), '`meta_controller` must be passed in or defined on init for evolutionary strategies to be straightforwardly applied'
243
237
 
244
238
  evo_strat = EvoStrategy(
245
239
  self,
240
+ num_generations = num_generations,
246
241
  environment = environment,
247
242
  params_to_optimize = self.meta_controller.internal_rl_parameters(),
248
243
  **kwargs
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "metacontroller-pytorch"
3
- version = "0.0.5"
3
+ version = "0.0.8"
4
4
  description = "Transformer Metacontroller"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -24,7 +24,7 @@ classifiers=[
24
24
  ]
25
25
 
26
26
  dependencies = [
27
- "assoc-scan",
27
+ "assoc-scan>=0.0.3",
28
28
  "einx>=0.3.0",
29
29
  "einops>=0.8.1",
30
30
  "discrete-continuous-embed-readout>=0.1.11",
@@ -29,3 +29,6 @@ def test_metacontroller(
29
29
  logits = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase)
30
30
 
31
31
  assert logits.shape == (1, 1024, 256)
32
+
33
+ model.meta_controller = meta_controller
34
+ model.evolve(1, lambda _: 1., noise_population_size = 2)