metacontroller-pytorch 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metacontroller/metacontroller.py +5 -9
- {metacontroller_pytorch-0.0.12.dist-info → metacontroller_pytorch-0.0.14.dist-info}/METADATA +1 -1
- metacontroller_pytorch-0.0.14.dist-info/RECORD +6 -0
- metacontroller_pytorch-0.0.12.dist-info/RECORD +0 -6
- {metacontroller_pytorch-0.0.12.dist-info → metacontroller_pytorch-0.0.14.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.12.dist-info → metacontroller_pytorch-0.0.14.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
|
@@ -289,8 +289,6 @@ class Transformer(Module):
|
|
|
289
289
|
meta_controller: Module | None = None,
|
|
290
290
|
cache: TransformerOutput | None = None,
|
|
291
291
|
discovery_phase = False,
|
|
292
|
-
no_grad_transformer = None,
|
|
293
|
-
no_grad_meta_controller = None,
|
|
294
292
|
meta_controller_temperature = 1.,
|
|
295
293
|
return_latents = False,
|
|
296
294
|
return_cache = False,
|
|
@@ -301,11 +299,9 @@ class Transformer(Module):
|
|
|
301
299
|
|
|
302
300
|
# by default, if meta controller is passed in, transformer is no grad
|
|
303
301
|
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
transformer_context = torch.no_grad if no_grad_transformer else nullcontext
|
|
308
|
-
meta_controller_context = torch.no_grad if no_grad_meta_controller else nullcontext
|
|
302
|
+
lower_transformer_context = nullcontext if not meta_controlling else torch.no_grad
|
|
303
|
+
meta_controller_context = nullcontext if meta_controlling else torch.no_grad
|
|
304
|
+
upper_transformer_context = nullcontext if meta_controlling and discovery_phase else torch.no_grad
|
|
309
305
|
|
|
310
306
|
# handle cache
|
|
311
307
|
|
|
@@ -313,7 +309,7 @@ class Transformer(Module):
|
|
|
313
309
|
|
|
314
310
|
# transformer lower body
|
|
315
311
|
|
|
316
|
-
with
|
|
312
|
+
with lower_transformer_context():
|
|
317
313
|
|
|
318
314
|
embed = self.embed(ids)
|
|
319
315
|
|
|
@@ -330,7 +326,7 @@ class Transformer(Module):
|
|
|
330
326
|
|
|
331
327
|
# modified residual stream sent back to transformer upper body
|
|
332
328
|
|
|
333
|
-
with
|
|
329
|
+
with upper_transformer_context():
|
|
334
330
|
|
|
335
331
|
attended, next_upper_hiddens = self.upper_body(modified_residual_stream, cache = upper_transformer_hiddens, return_hiddens = True)
|
|
336
332
|
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
+
metacontroller/metacontroller.py,sha256=rQJLyXMHHNCZm0iohWqAkcMpYSi8b1Z0dgB-8AJVqqo,10751
|
|
3
|
+
metacontroller_pytorch-0.0.14.dist-info/METADATA,sha256=-CP3Ak1NPaqTpyF4tTgwn-T47Pv2OiPzPFxecwGe3Ng,3736
|
|
4
|
+
metacontroller_pytorch-0.0.14.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
+
metacontroller_pytorch-0.0.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
metacontroller_pytorch-0.0.14.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
-
metacontroller/metacontroller.py,sha256=11Du7hg4Fj3epfOY42xa9IZO9b0bhquBddenWliwIMY,10956
|
|
3
|
-
metacontroller_pytorch-0.0.12.dist-info/METADATA,sha256=fKGk-rrrTXv1ueGzehcMV2k2G5n_Tmyk-VpXNo1gcFE,3736
|
|
4
|
-
metacontroller_pytorch-0.0.12.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
-
metacontroller_pytorch-0.0.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
metacontroller_pytorch-0.0.12.dist-info/RECORD,,
|
|
File without changes
|
{metacontroller_pytorch-0.0.12.dist-info → metacontroller_pytorch-0.0.14.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|