metacontroller-pytorch 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -289,8 +289,6 @@ class Transformer(Module):
289
289
  meta_controller: Module | None = None,
290
290
  cache: TransformerOutput | None = None,
291
291
  discovery_phase = False,
292
- no_grad_transformer = None,
293
- no_grad_meta_controller = None,
294
292
  meta_controller_temperature = 1.,
295
293
  return_latents = False,
296
294
  return_cache = False,
@@ -301,11 +299,9 @@ class Transformer(Module):
301
299
 
302
300
  # by default, if meta controller is passed in, transformer is no grad
303
301
 
304
- no_grad_transformer = default(no_grad_transformer, meta_controlling)
305
- no_grad_meta_controller = default(no_grad_meta_controller, no_grad_transformer) # by default, if transformer is eval no grad then meta controller is being learnt
306
-
307
- transformer_context = torch.no_grad if no_grad_transformer else nullcontext
308
- meta_controller_context = torch.no_grad if no_grad_meta_controller else nullcontext
302
+ lower_transformer_context = nullcontext if not meta_controlling else torch.no_grad
303
+ meta_controller_context = nullcontext if meta_controlling else torch.no_grad
304
+ upper_transformer_context = nullcontext if meta_controlling and discovery_phase else torch.no_grad
309
305
 
310
306
  # handle cache
311
307
 
@@ -313,7 +309,7 @@ class Transformer(Module):
313
309
 
314
310
  # transformer lower body
315
311
 
316
- with transformer_context():
312
+ with lower_transformer_context():
317
313
 
318
314
  embed = self.embed(ids)
319
315
 
@@ -330,7 +326,7 @@ class Transformer(Module):
330
326
 
331
327
  # modified residual stream sent back to transformer upper body
332
328
 
333
- with transformer_context():
329
+ with upper_transformer_context():
334
330
 
335
331
  attended, next_upper_hiddens = self.upper_body(modified_residual_stream, cache = upper_transformer_hiddens, return_hiddens = True)
336
332
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.12
3
+ Version: 0.0.14
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -0,0 +1,6 @@
1
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
+ metacontroller/metacontroller.py,sha256=rQJLyXMHHNCZm0iohWqAkcMpYSi8b1Z0dgB-8AJVqqo,10751
3
+ metacontroller_pytorch-0.0.14.dist-info/METADATA,sha256=-CP3Ak1NPaqTpyF4tTgwn-T47Pv2OiPzPFxecwGe3Ng,3736
4
+ metacontroller_pytorch-0.0.14.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ metacontroller_pytorch-0.0.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ metacontroller_pytorch-0.0.14.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=11Du7hg4Fj3epfOY42xa9IZO9b0bhquBddenWliwIMY,10956
3
- metacontroller_pytorch-0.0.12.dist-info/METADATA,sha256=fKGk-rrrTXv1ueGzehcMV2k2G5n_Tmyk-VpXNo1gcFE,3736
4
- metacontroller_pytorch-0.0.12.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- metacontroller_pytorch-0.0.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- metacontroller_pytorch-0.0.12.dist-info/RECORD,,