SURE-tools 2.1.42__py3-none-any.whl → 2.1.44__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic; see the registry's advisory for this release for more details.

SURE/PerturbFlow.py CHANGED
@@ -73,7 +73,7 @@ class PerturbFlow(nn.Module):
73
73
  config_enum: str = 'parallel',
74
74
  use_cuda: bool = False,
75
75
  seed: int = 42,
76
- zero_bias: bool = True,
76
+ zero_bias: bool|list = True,
77
77
  dtype = torch.float32, # type: ignore
78
78
  ):
79
79
  super().__init__()
@@ -97,7 +97,11 @@ class PerturbFlow(nn.Module):
97
97
  self.post_layer_fct = post_layer_fct
98
98
  self.post_act_fct = post_act_fct
99
99
  self.hidden_layer_activation = hidden_layer_activation
100
- self.use_bias = not zero_bias
100
+ if type(zero_bias) == list:
101
+ self.use_bias = [not x for x in zero_bias]
102
+ else:
103
+ self.use_bias = [not zero_bias] * self.cell_factor_size
104
+ #self.use_bias = not zero_bias
101
105
 
102
106
  self.codebook_weights = None
103
107
 
@@ -185,10 +189,29 @@ class PerturbFlow(nn.Module):
185
189
  use_cuda=self.use_cuda,
186
190
  )
187
191
 
188
- self.encoder_zn = MLP(
189
- [self.input_size] + hidden_sizes + [[latent_dim, latent_dim]],
192
+ #self.encoder_zn = MLP(
193
+ # [self.input_size] + hidden_sizes + [[latent_dim, latent_dim]],
194
+ # activation=activate_fct,
195
+ # output_activation=[None, Exp],
196
+ # post_layer_fct=post_layer_fct,
197
+ # post_act_fct=post_act_fct,
198
+ # allow_broadcast=self.allow_broadcast,
199
+ # use_cuda=self.use_cuda,
200
+ #)
201
+
202
+ self.encoder_zn_loc = MLP(
203
+ [self.input_size] + hidden_sizes + [self.input_size * latent_dim],
204
+ activation=activate_fct,
205
+ output_activation=None,
206
+ post_layer_fct=post_layer_fct,
207
+ post_act_fct=post_act_fct,
208
+ allow_broadcast=self.allow_broadcast,
209
+ use_cuda=self.use_cuda,
210
+ )
211
+ self.encoder_zn_scale = MLP(
212
+ [self.input_size] + hidden_sizes + [latent_dim],
190
213
  activation=activate_fct,
191
- output_activation=[None, Exp],
214
+ output_activation=Exp,
192
215
  post_layer_fct=post_layer_fct,
193
216
  post_act_fct=post_act_fct,
194
217
  allow_broadcast=self.allow_broadcast,
@@ -198,7 +221,7 @@ class PerturbFlow(nn.Module):
198
221
  if self.cell_factor_size>0:
199
222
  self.cell_factor_effect = nn.ModuleList()
200
223
  for i in np.arange(self.cell_factor_size):
201
- if self.use_bias:
224
+ if self.use_bias[i]:
202
225
  self.cell_factor_effect.append(MLP(
203
226
  [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
204
227
  activation=activate_fct,
@@ -374,7 +397,11 @@ class PerturbFlow(nn.Module):
374
397
 
375
398
  def guide1(self, xs):
376
399
  with pyro.plate('data'):
377
- zn_loc, zn_scale = self.encoder_zn(xs)
400
+ batch_size = xs.shape[0]
401
+ #zn_loc, zn_scale = self.encoder_zn(xs)
402
+ zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
403
+ zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
404
+ zn_loc = zn_loc_3d.sum(dim=1)
378
405
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
379
406
 
380
407
  alpha = self.encoder_n(zns)
@@ -462,7 +489,11 @@ class PerturbFlow(nn.Module):
462
489
 
463
490
  def guide2(self, xs, us=None):
464
491
  with pyro.plate('data'):
465
- zn_loc, zn_scale = self.encoder_zn(xs)
492
+ batch_size = xs.shape[0]
493
+ #zn_loc, zn_scale = self.encoder_zn(xs)
494
+ zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
495
+ zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
496
+ zn_loc = zn_loc_3d.sum(dim=1)
466
497
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
467
498
 
468
499
  alpha = self.encoder_n(zns)
@@ -557,7 +588,11 @@ class PerturbFlow(nn.Module):
557
588
  def guide3(self, xs, ys, embeds=None):
558
589
  with pyro.plate('data'):
559
590
  if embeds is None:
560
- zn_loc, zn_scale = self.encoder_zn(xs)
591
+ batch_size = xs.shape[0]
592
+ #zn_loc, zn_scale = self.encoder_zn(xs)
593
+ zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
594
+ zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
595
+ zn_loc = zn_loc_3d.sum(dim=1)
561
596
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
562
597
 
563
598
  def model4(self, xs, us, ys, embeds=None):
@@ -659,7 +694,11 @@ class PerturbFlow(nn.Module):
659
694
  def guide4(self, xs, us, ys, embeds=None):
660
695
  with pyro.plate('data'):
661
696
  if embeds is None:
662
- zn_loc, zn_scale = self.encoder_zn(xs)
697
+ batch_size = xs.shape[0]
698
+ #zn_loc, zn_scale = self.encoder_zn(xs)
699
+ zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
700
+ zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
701
+ zn_loc = zn_loc_3d.sum(dim=1)
663
702
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
664
703
 
665
704
  def _total_effects(self, zns, us):
@@ -688,7 +727,11 @@ class PerturbFlow(nn.Module):
688
727
  return cb
689
728
 
690
729
  def _get_basal_embedding(self, xs):
691
- zns, _ = self.encoder_zn(xs)
730
+ #zns, _ = self.encoder_zn(xs)
731
+ batch_size = xs.shape[0]
732
+ zns = self.encoder_zn_loc(xs)
733
+ zns_3d = zns.view(batch_size, self.input_size, self.latent_dim)
734
+ zns = zns_3d.sum(dim=1)
692
735
  return zns
693
736
 
694
737
  def get_basal_embedding(self,
@@ -727,7 +770,8 @@ class PerturbFlow(nn.Module):
727
770
  if self.supervised_mode:
728
771
  alpha = self.encoder_n(xs)
729
772
  else:
730
- zns,_ = self.encoder_zn(xs)
773
+ #zns,_ = self.encoder_zn(xs)
774
+ zns = self._get_basal_embedding(xs)
731
775
  alpha = self.encoder_n(zns)
732
776
  return alpha
733
777
 
@@ -797,7 +841,8 @@ class PerturbFlow(nn.Module):
797
841
  return A
798
842
 
799
843
  def _cell_response(self, xs, factor_idx, perturb):
800
- zns,_ = self.encoder_zn(xs)
844
+ #zns,_ = self.encoder_zn(xs)
845
+ zns = self._get_basal_embedding(xs)
801
846
  if perturb.ndim==2:
802
847
  ms = self.cell_factor_effect[factor_idx]([zns, perturb])
803
848
  else:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.42
3
+ Version: 2.1.44
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,4 +1,4 @@
1
- SURE/PerturbFlow.py,sha256=4oz02P9kTRNJ897ITR1oOxKYGVw2sgvFKazz2FC-6f0,52118
1
+ SURE/PerturbFlow.py,sha256=EbbEQBf9HQ6XeUw3g4DKFbe2IVCJRGmCwjQcHwCnQxU,54309
2
2
  SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
3
3
  SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
4
4
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=QJUOfrXzdWSmoM0P3LH8oKEHttzCWqpDy2UF0F0dtN4,673
17
17
  SURE/utils/custom_mlp.py,sha256=rHnx9jEef02zfCUdbYVCmbuHcDdIBmRgt__wpdpZvYg,8104
18
18
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
19
19
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
20
- sure_tools-2.1.42.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
- sure_tools-2.1.42.dist-info/METADATA,sha256=hwV5P2BtOm4qQ-v0seI9yavlU5zNiHr0_jKovfBl-0E,2651
22
- sure_tools-2.1.42.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- sure_tools-2.1.42.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
- sure_tools-2.1.42.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
- sure_tools-2.1.42.dist-info/RECORD,,
20
+ sure_tools-2.1.44.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
+ sure_tools-2.1.44.dist-info/METADATA,sha256=GOqZOvGRfh--3bd3Bgpkiv7E05gfRoDNflkE3c4Mt6s,2651
22
+ sure_tools-2.1.44.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ sure_tools-2.1.44.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
+ sure_tools-2.1.44.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
+ sure_tools-2.1.44.dist-info/RECORD,,