metacontroller-pytorch: metacontroller_pytorch-0.0.33-py3-none-any.whl → metacontroller_pytorch-0.0.34-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of metacontroller-pytorch might be problematic. Click here for more details.

@@ -26,7 +26,7 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout
26
26
 
27
27
  from assoc_scan import AssocScan
28
28
 
29
- from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left, pad_right_ndim_to
29
+ from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left
30
30
  from torch_einops_utils.save_load import save_load
31
31
 
32
32
  # constants
@@ -23,7 +23,7 @@ from x_mlps_pytorch import Feedforwards
23
23
 
24
24
  from assoc_scan import AssocScan
25
25
 
26
- from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left, pad_right_ndim_to
26
+ from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, align_dims_left
27
27
  from torch_einops_utils.save_load import save_load
28
28
 
29
29
  from vector_quantize_pytorch import BinaryMapper
@@ -143,22 +143,34 @@ class MetaControllerWithBinaryMapper(Module):
143
143
  *self.proposer_to_binary_logits.parameters()
144
144
  ]
145
145
 
146
+ def get_action_dist_for_internal_rl(
147
+ self,
148
+ residual_stream
149
+ ):
150
+ meta_embed = self.model_to_meta(residual_stream)
151
+
152
+ proposed_action_hidden, _ = self.action_proposer(meta_embed)
153
+
154
+ return self.proposer_to_binary_logits(proposed_action_hidden)
155
+
146
156
  def log_prob(
147
157
  self,
148
158
  action_dist,
149
159
  sampled_latent_action
150
160
  ):
151
- action_prob = action_dist.sigmoid()
152
- probs = stack((action_prob, 1. - action_prob), dim = -1)
153
- log_probs = log(probs)
161
+ log_probs = stack((
162
+ F.logsigmoid(action_dist),
163
+ F.logsigmoid(-action_dist)
164
+ ), dim = -1)
154
165
 
155
166
  indices = sampled_latent_action.argmax(dim = -1)
156
167
  codes = self.binary_mapper.codes[indices].long()
157
168
 
158
169
  codes = rearrange(codes, '... -> ... 1')
159
170
  action_log_probs = log_probs.gather(-1, codes)
171
+ action_log_probs = rearrange(action_log_probs, '... 1 -> ...')
160
172
 
161
- return rearrange(action_log_probs, '... 1 -> ...')
173
+ return action_log_probs.sum(dim = -1)
162
174
 
163
175
  def forward(
164
176
  self,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.33
3
+ Version: 0.0.34
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -0,0 +1,8 @@
1
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
+ metacontroller/metacontroller.py,sha256=Ii4Z2MuVMXJeWxAnZnRfSKrGv1t6_Y4R5BgvmtDrSW8,16203
3
+ metacontroller/metacontroller_with_binary_mapper.py,sha256=PgXK7uk--gvZPk1h4WLdCBEA7m9Ji12xixa2wqLmsLY,9234
4
+ metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
5
+ metacontroller_pytorch-0.0.34.dist-info/METADATA,sha256=xBFG34yRWTkcfJLFvC22jARLkR_W_6kcSsAV8r5UFWY,4747
6
+ metacontroller_pytorch-0.0.34.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
+ metacontroller_pytorch-0.0.34.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ metacontroller_pytorch-0.0.34.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=B9XHYgVBrcJkEWhUORz--D5AHjcLnvLRVY9SRqVbhdw,16222
3
- metacontroller/metacontroller_with_binary_mapper.py,sha256=7vGtenScxvDQhkeYUmNnTTbVTJAtIFUVqEoWVGZP2Is,8936
4
- metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
5
- metacontroller_pytorch-0.0.33.dist-info/METADATA,sha256=kWHDAnQeEueYWRzqPi5ouXjtSDIG2bnDqOfhznSUOoM,4747
6
- metacontroller_pytorch-0.0.33.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
- metacontroller_pytorch-0.0.33.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- metacontroller_pytorch-0.0.33.dist-info/RECORD,,