metacontroller-pytorch 0.0.31__tar.gz → 0.0.32__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of metacontroller-pytorch might be problematic. Click here for more details.
- {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/PKG-INFO +1 -1
- {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/metacontroller/metacontroller.py +14 -2
- {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/metacontroller/metacontroller_with_binary_mapper.py +20 -0
- {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/pyproject.toml +1 -1
- {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/tests/test_metacontroller.py +6 -0
- {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/README.md +0 -0
- {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/gather_babyai_trajs.py +0 -0
- {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/metacontroller/metacontroller_with_resnet.py +0 -0
- {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/test_babyai_e2e.sh +0 -0
- {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/train_babyai.py +0 -0
- {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/train_behavior_clone_babyai.py +0 -0
{metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/metacontroller/metacontroller.py
RENAMED
|
@@ -107,7 +107,6 @@ class MetaController(Module):
|
|
|
107
107
|
|
|
108
108
|
self.switch_per_latent_dim = switch_per_latent_dim
|
|
109
109
|
|
|
110
|
-
|
|
111
110
|
self.dim_latent = dim_latent
|
|
112
111
|
self.switching_unit = GRU(dim_meta + dim_latent, dim_meta)
|
|
113
112
|
self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)
|
|
@@ -147,6 +146,13 @@ class MetaController(Module):
|
|
|
147
146
|
*self.action_proposer_mean_log_var.parameters()
|
|
148
147
|
]
|
|
149
148
|
|
|
149
|
+
def log_prob(
|
|
150
|
+
self,
|
|
151
|
+
action_dist,
|
|
152
|
+
sampled_latent_action
|
|
153
|
+
):
|
|
154
|
+
return self.action_proposer_mean_log_var.log_prob(action_dist, sampled_latent_action)
|
|
155
|
+
|
|
150
156
|
def forward(
|
|
151
157
|
self,
|
|
152
158
|
residual_stream,
|
|
@@ -276,6 +282,12 @@ class MetaController(Module):
|
|
|
276
282
|
|
|
277
283
|
# main transformer, which is subsumed into the environment after behavioral cloning
|
|
278
284
|
|
|
285
|
+
Hiddens = namedtuple('Hiddens', (
|
|
286
|
+
'lower_body',
|
|
287
|
+
'meta_controller',
|
|
288
|
+
'upper_body'
|
|
289
|
+
))
|
|
290
|
+
|
|
279
291
|
TransformerOutput = namedtuple('TransformerOutput', (
|
|
280
292
|
'residual_stream_latent',
|
|
281
293
|
'prev_hiddens'
|
|
@@ -441,4 +453,4 @@ class Transformer(Module):
|
|
|
441
453
|
if return_one:
|
|
442
454
|
return dist_params
|
|
443
455
|
|
|
444
|
-
return dist_params, TransformerOutput(residual_stream, (next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
|
|
456
|
+
return dist_params, TransformerOutput(residual_stream, Hiddens(next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
|
|
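{metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/metacontroller/metacontroller_with_binary_mapper.py
RENAMED
|
@@ -50,6 +50,9 @@ def default(*args):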
|
|
|
50
50
|
def straight_through(src, tgt):
|
|
51
51
|
return tgt + src - src.detach()
|
|
52
52
|
|
|
53
|
+
def log(t, eps = 1e-20):
|
|
54
|
+
return t.clamp_min(eps).log()
|
|
55
|
+
|
|
53
56
|
# meta controller
|
|
54
57
|
|
|
55
58
|
@save_load()
|
|
@@ -137,6 +140,23 @@ class MetaControllerWithBinaryMapper(Module):
|
|
|
137
140
|
*self.proposer_to_binary_logits.parameters()
|
|
138
141
|
]
|
|
139
142
|
|
|
143
|
+
def log_prob(
|
|
144
|
+
self,
|
|
145
|
+
action_dist,
|
|
146
|
+
sampled_latent_action
|
|
147
|
+
):
|
|
148
|
+
action_prob = action_dist.sigmoid()
|
|
149
|
+
probs = stack((action_prob, 1. - action_prob), dim = -1)
|
|
150
|
+
log_probs = log(probs)
|
|
151
|
+
|
|
152
|
+
indices = sampled_latent_action.argmax(dim = -1)
|
|
153
|
+
codes = self.binary_mapper.codes[indices].long()
|
|
154
|
+
|
|
155
|
+
codes = rearrange(codes, '... -> ... 1')
|
|
156
|
+
action_log_probs = log_probs.gather(-1, codes)
|
|
157
|
+
|
|
158
|
+
return rearrange(action_log_probs, '... 1 -> ...')
|
|
159
|
+
|
|
140
160
|
def forward(
|
|
141
161
|
self,
|
|
142
162
|
residual_stream,
|
{metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/tests/test_metacontroller.py
RENAMED
|
@@ -80,6 +80,12 @@ def test_metacontroller(
|
|
|
80
80
|
assert logits.shape == (2, 1, *assert_shape)
|
|
81
81
|
past_action_id = model.action_readout.sample(logits)
|
|
82
82
|
|
|
83
|
+
# get log prob from meta controller latent actions
|
|
84
|
+
|
|
85
|
+
meta_controller_hidden = cache.prev_hiddens.meta_controller
|
|
86
|
+
|
|
87
|
+
old_log_probs = meta_controller.log_prob(meta_controller_hidden.action_dist, meta_controller_hidden.actions)
|
|
88
|
+
|
|
83
89
|
# evolutionary strategies over grpo
|
|
84
90
|
|
|
85
91
|
model.meta_controller = meta_controller
|
{metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/.github/workflows/python-publish.yml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.32}/train_behavior_clone_babyai.py
RENAMED
|
File without changes
|