metacontroller-pytorch 0.0.8__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,6 @@
1
1
  from __future__ import annotations
2
+ from contextlib import nullcontext
3
+
2
4
  from functools import partial
3
5
  from collections import namedtuple
4
6
 
@@ -250,26 +252,50 @@ class Transformer(Module):
250
252
  ids,
251
253
  meta_controller: Module | None = None,
252
254
  discovery_phase = False,
253
- return_latents = False
255
+ return_latents = False,
256
+ no_grad_transformer = None,
257
+ no_grad_meta_controller = None
254
258
  ):
255
259
  meta_controller = default(meta_controller, self.meta_controller)
256
260
 
257
- embed = self.embed(ids)
261
+ meta_controlling = exists(meta_controller)
262
+
263
+ # by default, if a meta controller is in use (passed in, or falling back to self.meta_controller), the transformer runs under no grad
264
+
265
+ no_grad_transformer = default(no_grad_transformer, meta_controlling)
266
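+ no_grad_meta_controller = default(no_grad_meta_controller, no_grad_transformer) # by default, follows the transformer's no grad setting - NOTE(review): the earlier comment said the meta controller is being learnt when the transformer is no grad, which would require `not no_grad_transformer` here; confirm intent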
267
+
268
+ transformer_context = torch.no_grad if no_grad_transformer else nullcontext
269
+ meta_controller_context = torch.no_grad if no_grad_meta_controller else nullcontext
270
+
271
+ # transformer lower body
272
+
273
+ with transformer_context():
258
274
 
259
- residual_stream = self.lower_body(embed)
275
+ embed = self.embed(ids)
276
+
277
+ residual_stream = self.lower_body(embed)
260
278
 
261
279
  # meta controller acts on residual stream here
262
280
 
263
- if exists(meta_controller):
264
- modified_residual_stream, action_dist, sampled_action, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
265
- else:
266
- modified_residual_stream, action_dist, sampled_action, vae_aux_loss = residual_stream, None, None, self.zero
281
+ with meta_controller_context():
282
+
283
+ if exists(meta_controller):
284
+ modified_residual_stream, action_dist, sampled_action, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
285
+ else:
286
+ modified_residual_stream, action_dist, sampled_action, vae_aux_loss = residual_stream, None, None, self.zero
287
+
288
+ # modified residual stream sent back to transformer upper body
289
+
290
+ with transformer_context():
291
+
292
+ attended = self.upper_body(modified_residual_stream)
267
293
 
268
- # modified residual stream sent back
294
+ # head readout
269
295
 
270
- attended = self.upper_body(modified_residual_stream)
296
+ dist_params = self.readout(attended)
271
297
 
272
- dist_params = self.readout(attended)
298
+ # returning
273
299
 
274
300
  if not return_latents:
275
301
  return dist_params
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.8
3
+ Version: 0.0.9
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -0,0 +1,6 @@
1
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
+ metacontroller/metacontroller.py,sha256=V2Nb7ByGj310CalTzho-grwNsoHMp55oN5spkedJihc,9189
3
+ metacontroller_pytorch-0.0.9.dist-info/METADATA,sha256=BA4AHlFW8DsD_NPXNv8N8rmRPISZNTkcjvGautB7xJA,3713
4
+ metacontroller_pytorch-0.0.9.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ metacontroller_pytorch-0.0.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ metacontroller_pytorch-0.0.9.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=cbo0F861KIcGIhJ1j-Js6Qwfm_-8nm5Sm0LJiRO7hl0,8265
3
- metacontroller_pytorch-0.0.8.dist-info/METADATA,sha256=a7aUiVugnv5PJ-AZqnCyEczWsmuUS30s-3DsBKuThNQ,3713
4
- metacontroller_pytorch-0.0.8.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- metacontroller_pytorch-0.0.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- metacontroller_pytorch-0.0.8.dist-info/RECORD,,