metacontroller-pytorch 0.0.31__py3-none-any.whl → 0.0.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -107,7 +107,6 @@ class MetaController(Module):
107
107
 
108
108
  self.switch_per_latent_dim = switch_per_latent_dim
109
109
 
110
-
111
110
  self.dim_latent = dim_latent
112
111
  self.switching_unit = GRU(dim_meta + dim_latent, dim_meta)
113
112
  self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)
@@ -147,6 +146,13 @@ class MetaController(Module):
147
146
  *self.action_proposer_mean_log_var.parameters()
148
147
  ]
149
148
 
149
+ def log_prob(
150
+ self,
151
+ action_dist,
152
+ sampled_latent_action
153
+ ):
154
+ return self.action_proposer_mean_log_var.log_prob(action_dist, sampled_latent_action)
155
+
150
156
  def forward(
151
157
  self,
152
158
  residual_stream,
@@ -276,6 +282,12 @@ class MetaController(Module):
276
282
 
277
283
  # main transformer, which is subsumed into the environment after behavioral cloning
278
284
 
285
+ Hiddens = namedtuple('Hiddens', (
286
+ 'lower_body',
287
+ 'meta_controller',
288
+ 'upper_body'
289
+ ))
290
+
279
291
  TransformerOutput = namedtuple('TransformerOutput', (
280
292
  'residual_stream_latent',
281
293
  'prev_hiddens'
@@ -441,4 +453,4 @@ class Transformer(Module):
441
453
  if return_one:
442
454
  return dist_params
443
455
 
444
- return dist_params, TransformerOutput(residual_stream, (next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
456
+ return dist_params, TransformerOutput(residual_stream, Hiddens(next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
@@ -50,6 +50,9 @@ def default(*args):
50
50
  def straight_through(src, tgt):
51
51
  return tgt + src - src.detach()
52
52
 
53
+ def log(t, eps = 1e-20):
54
+ return t.clamp_min(eps).log()
55
+
53
56
  # meta controller
54
57
 
55
58
  @save_load()
@@ -137,6 +140,23 @@ class MetaControllerWithBinaryMapper(Module):
137
140
  *self.proposer_to_binary_logits.parameters()
138
141
  ]
139
142
 
143
+ def log_prob(
144
+ self,
145
+ action_dist,
146
+ sampled_latent_action
147
+ ):
148
+ action_prob = action_dist.sigmoid()
149
+ probs = stack((action_prob, 1. - action_prob), dim = -1)
150
+ log_probs = log(probs)
151
+
152
+ indices = sampled_latent_action.argmax(dim = -1)
153
+ codes = self.binary_mapper.codes[indices].long()
154
+
155
+ codes = rearrange(codes, '... -> ... 1')
156
+ action_log_probs = log_probs.gather(-1, codes)
157
+
158
+ return rearrange(action_log_probs, '... 1 -> ...')
159
+
140
160
  def forward(
141
161
  self,
142
162
  residual_stream,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.31
3
+ Version: 0.0.32
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -0,0 +1,8 @@
1
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
+ metacontroller/metacontroller.py,sha256=somE9gX36c1d9hF2n8Qn4foRY8krHGodvrvulhkIGE8,15006
3
+ metacontroller/metacontroller_with_binary_mapper.py,sha256=CTGK8ruQ3TkioVUwFTHdrbfzubaeuhSdXHfHtaDcwMY,8813
4
+ metacontroller/metacontroller_with_resnet.py,sha256=YKHcazRZrrRParHRH-H_EPvT1-55LHKAs5pM6gwuT20,7394
5
+ metacontroller_pytorch-0.0.32.dist-info/METADATA,sha256=hr08iXm6Mb-rnDu2xPrr9YQ6cwTtX1F79MfBYt54Y94,4747
6
+ metacontroller_pytorch-0.0.32.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
+ metacontroller_pytorch-0.0.32.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ metacontroller_pytorch-0.0.32.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=lxWgeWFcXxSDm-ygd14DjyEOYIJIALcuLkoRAfEzNtc,14719
3
- metacontroller/metacontroller_with_binary_mapper.py,sha256=BrsQdkhlOyR2O5xAXTLC4p-uKOAbW7wET-lVU0qktws,8242
4
- metacontroller/metacontroller_with_resnet.py,sha256=YKHcazRZrrRParHRH-H_EPvT1-55LHKAs5pM6gwuT20,7394
5
- metacontroller_pytorch-0.0.31.dist-info/METADATA,sha256=mtOtYymI01jBMO7pyaAIJ166B5Mk3khH8CUwUMNLTKw,4747
6
- metacontroller_pytorch-0.0.31.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
- metacontroller_pytorch-0.0.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- metacontroller_pytorch-0.0.31.dist-info/RECORD,,