metacontroller-pytorch: 0.0.6 (py3-none-any.whl) → 0.0.9 (py3-none-any.whl)

This diff shows the changes between publicly released versions of the package as they appear in their public registries, and is provided for informational purposes only.
metacontroller/metacontroller.py

@@ -1,4 +1,6 @@
  from __future__ import annotations
+ from contextlib import nullcontext
+
  from functools import partial
  from collections import namedtuple
  
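The new `nullcontext` import backs the conditional gradient contexts introduced in `Transformer.forward` further down. A minimal standalone sketch of the pattern (illustration only, not code from the package):

    import torch
    from contextlib import nullcontext

    param = torch.nn.Parameter(torch.randn(3))

    no_grad = True

    # torch.no_grad when gradients should be suppressed, otherwise a no-op context
    context = torch.no_grad if no_grad else nullcontext

    with context():
        out = param.sum()

    print(out.requires_grad)  # False - no graph was built under torch.no_grad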
@@ -113,9 +115,7 @@ class MetaController(Module):
      def internal_rl_parameters(self):
          return [
              *self.action_proposer.parameters(),
-             *self.action_proposer_mean_log_var.parameters(),
-             *self.decoder.parameters(),
-             *self.switch_gating
+             *self.action_proposer_mean_log_var.parameters()
          ]
  
      def forward(
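`internal_rl_parameters` now returns only the action proposer and its mean/log-variance head; the decoder and switch gating weights no longer belong to the internally-learned set that `evolve` hands to `EvoStrategy` below. A hypothetical use, assuming a constructed `MetaController` named `meta_controller`:

    # everything outside this subset stays frozen during the internal RL phase
    params = list(meta_controller.internal_rl_parameters())
    num_params = sum(p.numel() for p in params)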
@@ -150,8 +150,6 @@ class MetaController(Module):
  
          switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
  
-         # switch_beta = switch_beta.expand_as(sampled_action)
-
          # need to encourage normal distribution
  
          vae_kl_loss = self.zero
@@ -233,13 +231,15 @@ class Transformer(Module):
  
      def evolve(
          self,
+         num_generations,
          environment,
          **kwargs
      ):
-         assert exists(self.meta_controller), '`meta_controller` must be defined on init for evolutionary strategies to be straightforwardly applied'
+         assert exists(self.meta_controller), '`meta_controller` must be passed in or defined on init for evolutionary strategies to be straightforwardly applied'
  
          evo_strat = EvoStrategy(
              self,
+             num_generations = num_generations,
              environment = environment,
              params_to_optimize = self.meta_controller.internal_rl_parameters(),
              **kwargs
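`num_generations` is now a required positional argument of `evolve`, forwarded to `EvoStrategy`. A hypothetical call, assuming `model` is a `Transformer` constructed with a `meta_controller` and `environment` matches whatever interface `EvoStrategy` expects:

    model.evolve(
        100,                        # num_generations
        environment = environment,
        # any remaining EvoStrategy options pass through **kwargs
    )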
@@ -252,26 +252,50 @@ class Transformer(Module):
          ids,
          meta_controller: Module | None = None,
          discovery_phase = False,
-         return_latents = False
+         return_latents = False,
+         no_grad_transformer = None,
+         no_grad_meta_controller = None
      ):
          meta_controller = default(meta_controller, self.meta_controller)
  
-         embed = self.embed(ids)
+         meta_controlling = exists(meta_controller)
+
+         # by default, if meta controller is passed in, transformer is no grad
+
+         no_grad_transformer = default(no_grad_transformer, meta_controlling)
+         no_grad_meta_controller = default(no_grad_meta_controller, no_grad_transformer) # by default, if transformer is eval no grad then meta controller is being learnt
+
+         transformer_context = torch.no_grad if no_grad_transformer else nullcontext
+         meta_controller_context = torch.no_grad if no_grad_meta_controller else nullcontext
  
-         residual_stream = self.lower_body(embed)
+         # transformer lower body
+
+         with transformer_context():
+
+             embed = self.embed(ids)
+
+             residual_stream = self.lower_body(embed)
  
          # meta controller acts on residual stream here
  
-         if exists(meta_controller):
-             modified_residual_stream, action_dist, sampled_action, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
-         else:
-             modified_residual_stream, action_dist, sampled_action, vae_aux_loss = residual_stream, None, None, self.zero
+         with meta_controller_context():
+
+             if exists(meta_controller):
+                 modified_residual_stream, action_dist, sampled_action, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
+             else:
+                 modified_residual_stream, action_dist, sampled_action, vae_aux_loss = residual_stream, None, None, self.zero
+
+         # modified residual stream sent back to transformer upper body
+
+         with transformer_context():
+
+             attended = self.upper_body(modified_residual_stream)
  
-         # modified residual stream sent back
+         # head readout
  
-         attended = self.upper_body(modified_residual_stream)
+         dist_params = self.readout(attended)
  
-         dist_params = self.readout(attended)
+         # returning
  
          if not return_latents:
              return dist_params
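As the defaults cascade in this version, passing a meta controller puts both the transformer body and the meta controller under `torch.no_grad`, which suits the gradient-free evolutionary phase; either flag can be overridden per call. A hypothetical invocation, assuming `model` is a constructed `Transformer` and `ids` a batch of token indices:

    # default: meta controller present, so the forward pass builds no graph
    dist_params = model(ids, meta_controller = meta_controller)

    # hypothetical override for a gradient-based phase, e.g. learning the VAE
    # during discovery: transformer stays frozen, meta controller keeps gradients
    dist_params = model(
        ids,
        meta_controller = meta_controller,
        discovery_phase = True,
        no_grad_meta_controller = False
    )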
metacontroller_pytorch-0.0.9.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: metacontroller-pytorch
- Version: 0.0.6
+ Version: 0.0.9
  Summary: Transformer Metacontroller
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
metacontroller_pytorch-0.0.9.dist-info/RECORD

@@ -0,0 +1,6 @@
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+ metacontroller/metacontroller.py,sha256=V2Nb7ByGj310CalTzho-grwNsoHMp55oN5spkedJihc,9189
+ metacontroller_pytorch-0.0.9.dist-info/METADATA,sha256=BA4AHlFW8DsD_NPXNv8N8rmRPISZNTkcjvGautB7xJA,3713
+ metacontroller_pytorch-0.0.9.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ metacontroller_pytorch-0.0.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ metacontroller_pytorch-0.0.9.dist-info/RECORD,,
metacontroller_pytorch-0.0.6.dist-info/RECORD

@@ -1,6 +0,0 @@
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
- metacontroller/metacontroller.py,sha256=8AWkPlZWh2A1sRT6nMV0CGHuYhQ5pHEpd5bgyRmZelg,8316
- metacontroller_pytorch-0.0.6.dist-info/METADATA,sha256=vXT_-n3bHgpddnS5axyyc-cADGNk4l2enJv4g4cTJ7A,3713
- metacontroller_pytorch-0.0.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- metacontroller_pytorch-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- metacontroller_pytorch-0.0.6.dist-info/RECORD,,