metacontroller-pytorch 0.0.1__tar.gz → 0.0.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/PKG-INFO +1 -1
- {metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/metacontroller/metacontroller.py +12 -4
- {metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/pyproject.toml +1 -1
- {metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/tests/test_metacontroller.py +7 -2
- {metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/README.md +0 -0
- {metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/metacontroller/__init__.py +0 -0
{metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/metacontroller/metacontroller.py
RENAMED
@@ -49,6 +49,7 @@ class MetaController(Module):
         self,
         dim_latent,
         *,
+        switch_per_latent_dim = True,
         decoder_expansion_factor = 2.,
         decoder_depth = 1,
         hypernetwork_low_rank = 16,
@@ -70,8 +71,10 @@ class MetaController(Module):
 
         # switching unit
 
+        self.switch_per_latent_dim = switch_per_latent_dim
+
         self.switching_unit = GRU(dim_latent, dim_latent)
-        self.to_switching_unit_beta = nn.Linear(dim_latent, 1, bias = False)
+        self.to_switching_unit_beta = nn.Linear(dim_latent, dim_latent if switch_per_latent_dim else 1, bias = False)
 
         self.switch_gating = AssocScan(**assoc_scan_kwargs)
 
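The hunk above makes the switch gate width configurable: 0.0.1 always produced one gate per timestep, while the new default emits an independent gate per latent dimension. A minimal shape sketch of the two settings (batch, sequence length and dim_latent here are assumed example values, not taken from the package):

    # shape sketch for the to_switching_unit_beta change above; sizes are assumptions
    import torch
    from torch import nn

    dim_latent = 512
    gru_out = torch.randn(2, 16, dim_latent)                           # (batch, seq, dim_latent)

    to_beta_shared  = nn.Linear(dim_latent, 1, bias = False)           # 0.0.1 behaviour
    to_beta_per_dim = nn.Linear(dim_latent, dim_latent, bias = False)  # 0.0.3 default

    print(to_beta_shared(gru_out).sigmoid().shape)   # torch.Size([2, 16, 1])
    print(to_beta_per_dim(gru_out).sigmoid().shape)  # torch.Size([2, 16, 512])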
@@ -154,7 +157,7 @@ class MetaController(Module):
         switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
 
         action_intent_for_gating = rearrange(sampled_action_intents, 'b n d -> (b d) n')
-        switch_beta = repeat(switch_beta, 'b n 1 -> (b d) n', d = dim)
+        switch_beta = repeat(switch_beta, 'b n d -> (b r d) n', r = dim if not self.switch_per_latent_dim else 1)
 
         forget = 1. - switch_beta
         gated_action_intent = self.switch_gating(action_intent_for_gating * forget, switch_beta)
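The new repeat pattern keeps both settings on the same (b d) n layout as the rearranged action intents: with per-latent-dim gates, r = 1 makes it a pure reshape, while a single shared gate (last dim of 1) is broadcast across the latent dimension with r = dim. A quick einops shape check under assumed sizes:

    # einops shape check for the repeat above; b, n, dim are assumed example sizes
    import torch
    from einops import repeat

    b, n, dim = 2, 16, 512

    beta_per_dim = torch.rand(b, n, dim)  # switch_per_latent_dim = True
    beta_shared  = torch.rand(b, n, 1)    # switch_per_latent_dim = False

    print(repeat(beta_per_dim, 'b n d -> (b r d) n', r = 1).shape)    # torch.Size([1024, 16])
    print(repeat(beta_shared,  'b n d -> (b r d) n', r = dim).shape)  # torch.Size([1024, 16])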
@@ -212,6 +215,8 @@ class Transformer(Module):
 
         self.meta_controller = meta_controller
 
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
     def evolve(
         self,
         environment,
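Registering zero as a non-persistent buffer gives the optional-meta-controller branch further down a loss stand-in that follows the module across device and dtype moves without entering the state dict. A minimal sketch of the idiom:

    # minimal sketch of the non-persistent buffer idiom used above
    import torch
    from torch import nn, tensor

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            # persistent = False keeps it out of state_dict, but .to() / .cuda() still move it
            self.register_buffer('zero', tensor(0.), persistent = False)

    m = Toy()
    print(m.zero.device)             # cpu (becomes cuda:0 after m.cuda(), no code changes)
    print('zero' in m.state_dict())  # False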
@@ -235,7 +240,7 @@ class Transformer(Module):
         discovery_phase = False,
         return_latents = False
     ):
-        meta_controller = default(meta_controller, self.meta_controller
+        meta_controller = default(meta_controller, self.meta_controller)
 
         embed = self.embed(ids)
 
@@ -243,7 +248,10 @@ class Transformer(Module):
 
         # meta controller acts on residual stream here
 
-        modified_residual_stream, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
+        if exists(meta_controller):
+            modified_residual_stream, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
+        else:
+            modified_residual_stream, vae_aux_loss = residual_stream, self.zero
 
         # modified residual stream sent back
 
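With this guard the meta controller becomes optional at call time: when neither the argument nor self.meta_controller is set, the residual stream passes through unchanged and the aux loss falls back to the registered zero. A sketch of the fallback, assuming exists and default are the usual one-line helpers defined in this file:

    # fallback sketch; exists/default are assumed helpers matching their use in the hunks
    import torch

    def exists(v):
        return v is not None

    def default(v, d):
        return v if exists(v) else d

    residual_stream = torch.randn(1, 1024, 512)
    zero = torch.tensor(0.)                # stands in for the registered buffer

    meta_controller = default(None, None)  # no override passed, none attached

    if exists(meta_controller):
        modified_residual_stream, vae_aux_loss = meta_controller(residual_stream)
    else:
        modified_residual_stream, vae_aux_loss = residual_stream, zero

    print(vae_aux_loss)  # tensor(0.)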
{metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/tests/test_metacontroller.py
RENAMED
@@ -5,8 +5,10 @@ import torch
 from metacontroller.metacontroller import Transformer, MetaController
 
 @param('discovery_phase', (False, True))
+@param('switch_per_latent_dim', (False, True))
 def test_metacontroller(
-    discovery_phase
+    discovery_phase,
+    switch_per_latent_dim
 ):
 
     ids = torch.randint(0, 256, (1, 1024))
@@ -19,7 +21,10 @@ def test_metacontroller(
         readout = dict(num_discrete = 256)
     )
 
-    meta_controller = MetaController(512)
+    meta_controller = MetaController(
+        512,
+        switch_per_latent_dim = switch_per_latent_dim
+    )
 
     logits = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase)
 
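The added parameter expands the test matrix to four cases. Mirroring the parametrization, both variants can be built directly (512 matches the dim_latent used in the test; True is the constructor default per the first hunk):

    # both variants, as exercised by the parametrized test
    from metacontroller.metacontroller import MetaController

    mc_per_dim = MetaController(512)                                 # default: True
    mc_shared  = MetaController(512, switch_per_latent_dim = False)  # single shared gate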
{metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/.github/workflows/python-publish.yml
RENAMED
File without changes

{metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/.github/workflows/test.yml
RENAMED
File without changes

{metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/.gitignore
RENAMED
File without changes

{metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/LICENSE
RENAMED
File without changes

{metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/README.md
RENAMED
File without changes

{metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/fig1.png
RENAMED
File without changes

{metacontroller_pytorch-0.0.1 → metacontroller_pytorch-0.0.3}/metacontroller/__init__.py
RENAMED
File without changes