SURE-tools: sure_tools-2.1.44-py3-none-any.whl → sure_tools-2.1.45-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. More details are linked from the original registry diff page.

SURE/PerturbFlow.py CHANGED
@@ -189,29 +189,10 @@ class PerturbFlow(nn.Module):
189
189
  use_cuda=self.use_cuda,
190
190
  )
191
191
 
192
- #self.encoder_zn = MLP(
193
- # [self.input_size] + hidden_sizes + [[latent_dim, latent_dim]],
194
- # activation=activate_fct,
195
- # output_activation=[None, Exp],
196
- # post_layer_fct=post_layer_fct,
197
- # post_act_fct=post_act_fct,
198
- # allow_broadcast=self.allow_broadcast,
199
- # use_cuda=self.use_cuda,
200
- #)
201
-
202
- self.encoder_zn_loc = MLP(
203
- [self.input_size] + hidden_sizes + [self.input_size * latent_dim],
204
- activation=activate_fct,
205
- output_activation=None,
206
- post_layer_fct=post_layer_fct,
207
- post_act_fct=post_act_fct,
208
- allow_broadcast=self.allow_broadcast,
209
- use_cuda=self.use_cuda,
210
- )
211
- self.encoder_zn_scale = MLP(
212
- [self.input_size] + hidden_sizes + [latent_dim],
192
+ self.encoder_zn = MLP(
193
+ [self.input_size] + hidden_sizes + [[latent_dim, latent_dim]],
213
194
  activation=activate_fct,
214
- output_activation=Exp,
195
+ output_activation=[None, Exp],
215
196
  post_layer_fct=post_layer_fct,
216
197
  post_act_fct=post_act_fct,
217
198
  allow_broadcast=self.allow_broadcast,
@@ -397,11 +378,7 @@ class PerturbFlow(nn.Module):
397
378
 
398
379
  def guide1(self, xs):
399
380
  with pyro.plate('data'):
400
- batch_size = xs.shape[0]
401
- #zn_loc, zn_scale = self.encoder_zn(xs)
402
- zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
403
- zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
404
- zn_loc = zn_loc_3d.sum(dim=1)
381
+ zn_loc, zn_scale = self.encoder_zn(xs)
405
382
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
406
383
 
407
384
  alpha = self.encoder_n(zns)
@@ -489,11 +466,7 @@ class PerturbFlow(nn.Module):
489
466
 
490
467
  def guide2(self, xs, us=None):
491
468
  with pyro.plate('data'):
492
- batch_size = xs.shape[0]
493
- #zn_loc, zn_scale = self.encoder_zn(xs)
494
- zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
495
- zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
496
- zn_loc = zn_loc_3d.sum(dim=1)
469
+ zn_loc, zn_scale = self.encoder_zn(xs)
497
470
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
498
471
 
499
472
  alpha = self.encoder_n(zns)
@@ -588,11 +561,7 @@ class PerturbFlow(nn.Module):
588
561
  def guide3(self, xs, ys, embeds=None):
589
562
  with pyro.plate('data'):
590
563
  if embeds is None:
591
- batch_size = xs.shape[0]
592
- #zn_loc, zn_scale = self.encoder_zn(xs)
593
- zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
594
- zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
595
- zn_loc = zn_loc_3d.sum(dim=1)
564
+ zn_loc, zn_scale = self.encoder_zn(xs)
596
565
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
597
566
 
598
567
  def model4(self, xs, us, ys, embeds=None):
@@ -694,11 +663,7 @@ class PerturbFlow(nn.Module):
694
663
  def guide4(self, xs, us, ys, embeds=None):
695
664
  with pyro.plate('data'):
696
665
  if embeds is None:
697
- batch_size = xs.shape[0]
698
- #zn_loc, zn_scale = self.encoder_zn(xs)
699
- zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
700
- zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
701
- zn_loc = zn_loc_3d.sum(dim=1)
666
+ zn_loc, zn_scale = self.encoder_zn(xs)
702
667
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
703
668
 
704
669
  def _total_effects(self, zns, us):
@@ -727,11 +692,7 @@ class PerturbFlow(nn.Module):
727
692
  return cb
728
693
 
729
694
  def _get_basal_embedding(self, xs):
730
- #zns, _ = self.encoder_zn(xs)
731
- batch_size = xs.shape[0]
732
- zns = self.encoder_zn_loc(xs)
733
- zns_3d = zns.view(batch_size, self.input_size, self.latent_dim)
734
- zns = zns_3d.sum(dim=1)
695
+ zns, _ = self.encoder_zn(xs)
735
696
  return zns
736
697
 
737
698
  def get_basal_embedding(self,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.44
3
+ Version: 2.1.45
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,4 +1,4 @@
1
- SURE/PerturbFlow.py,sha256=EbbEQBf9HQ6XeUw3g4DKFbe2IVCJRGmCwjQcHwCnQxU,54309
1
+ SURE/PerturbFlow.py,sha256=EjBvHUNmQeyF41H-rzg7fSrcKo9_ts2za-JiMRp4y9M,52394
2
2
  SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
3
3
  SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
4
4
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=QJUOfrXzdWSmoM0P3LH8oKEHttzCWqpDy2UF0F0dtN4,673
17
17
  SURE/utils/custom_mlp.py,sha256=rHnx9jEef02zfCUdbYVCmbuHcDdIBmRgt__wpdpZvYg,8104
18
18
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
19
19
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
20
- sure_tools-2.1.44.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
- sure_tools-2.1.44.dist-info/METADATA,sha256=GOqZOvGRfh--3bd3Bgpkiv7E05gfRoDNflkE3c4Mt6s,2651
22
- sure_tools-2.1.44.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- sure_tools-2.1.44.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
- sure_tools-2.1.44.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
- sure_tools-2.1.44.dist-info/RECORD,,
20
+ sure_tools-2.1.45.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
+ sure_tools-2.1.45.dist-info/METADATA,sha256=y0VumQ_zuHF_45Mo2MKKap2J_l4rGEzHiYLSxuq1isw,2651
22
+ sure_tools-2.1.45.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ sure_tools-2.1.45.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
+ sure_tools-2.1.45.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
+ sure_tools-2.1.45.dist-info/RECORD,,