SURE-tools 2.1.83__py3-none-any.whl → 2.1.85__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic; see the package's registry page for more details.

SURE/PerturbFlow.py CHANGED
@@ -74,6 +74,7 @@ class PerturbFlow(nn.Module):
74
74
  use_cuda: bool = True,
75
75
  seed: int = 42,
76
76
  zero_bias: bool|list = True,
77
+ enumrate: bool = False,
77
78
  dtype = torch.float32, # type: ignore
78
79
  ):
79
80
  super().__init__()
@@ -102,6 +103,7 @@ class PerturbFlow(nn.Module):
102
103
  else:
103
104
  self.use_bias = [not zero_bias] * self.cell_factor_size
104
105
  #self.use_bias = not zero_bias
106
+ self.enumerate = enumerate
105
107
 
106
108
  self.codebook_weights = None
107
109
 
@@ -204,7 +206,7 @@ class PerturbFlow(nn.Module):
204
206
  for i in np.arange(self.cell_factor_size):
205
207
  if self.use_bias[i]:
206
208
  self.cell_factor_effect.append(MLP(
207
- [self.code_size+1] + self.decoder_hidden_layers + [self.latent_dim],
209
+ [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
208
210
  activation=activate_fct,
209
211
  output_activation=None,
210
212
  post_layer_fct=post_layer_fct,
@@ -215,7 +217,7 @@ class PerturbFlow(nn.Module):
215
217
  )
216
218
  else:
217
219
  self.cell_factor_effect.append(ZeroBiasMLP(
218
- [self.code_size+1] + self.decoder_hidden_layers + [self.latent_dim],
220
+ [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
219
221
  activation=activate_fct,
220
222
  output_activation=None,
221
223
  post_layer_fct=post_layer_fct,
@@ -429,7 +431,10 @@ class PerturbFlow(nn.Module):
429
431
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
430
432
 
431
433
  if self.cell_factor_size>0:
432
- zus = self._total_effects(ns, us)
434
+ if self.enumerate:
435
+ zus = self._total_effects(zn_loc, us)
436
+ else:
437
+ zus = self._total_effects(zns, us)
433
438
  zs = zns+zus
434
439
  else:
435
440
  zs = zns
@@ -631,7 +636,10 @@ class PerturbFlow(nn.Module):
631
636
  # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
632
637
  # else:
633
638
  # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
634
- zus = self._total_effects(ns, us)
639
+ if self.enumerate:
640
+ zus = self._total_effects(zn_loc, us)
641
+ else:
642
+ zus = self._total_effects(zns, us)
635
643
  zs = zns+zus
636
644
  else:
637
645
  zs = zns
@@ -842,8 +850,7 @@ class PerturbFlow(nn.Module):
842
850
 
843
851
  def _cell_response(self, xs, factor_idx, perturb):
844
852
  #zns,_ = self.encoder_zn(xs)
845
- #zns,_ = self._get_basal_embedding(xs)
846
- zns = self._soft_assignments(xs)
853
+ zns,_ = self._get_basal_embedding(xs)
847
854
  if perturb.ndim==2:
848
855
  ms = self.cell_factor_effect[factor_idx]([zns, perturb])
849
856
  else:
@@ -877,8 +884,7 @@ class PerturbFlow(nn.Module):
877
884
  return Z
878
885
 
879
886
  def get_metacell_response(self, factor_idx, perturb):
880
- #zs = self._get_codebook()
881
- zs = self._get_codebook_identity()
887
+ zs = self._get_codebook()
882
888
  ps = convert_to_tensor(perturb, device=self.get_device())
883
889
  ms = self.cell_factor_effect[factor_idx]([zs,ps])
884
890
  return tensor_to_numpy(ms)
SURE/utils/custom_mlp.py CHANGED
@@ -239,7 +239,10 @@ class ZeroBiasMLP(nn.Module):
239
239
  def forward(self, x):
240
240
  y = self.mlp(x)
241
241
  mask = torch.zeros_like(y)
242
- mask[x[1][:,0]>0,:] = 1
242
+ if len(y.shape)==2:
243
+ mask[x[1][:,0]>0,:] = 1
244
+ elif len(y.shape)==3:
245
+ mask[:,x[1][:,0]>0,:] = 1
243
246
  return y*mask
244
247
 
245
248
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.83
3
+ Version: 2.1.85
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,4 +1,4 @@
1
- SURE/PerturbFlow.py,sha256=5HzS8oB06iSR3JM5AalGfYi-quxbjkTZeTypjih-VBI,54759
1
+ SURE/PerturbFlow.py,sha256=hu4kD6PphVqAG8GlITk44trwVfE6jz72RhQlTYRbi-o,54991
2
2
  SURE/SURE.py,sha256=g8EhovBxjfpbVJA0AkmVkQ_ZW_JFc8TtkTCg8FCybV4,47750
3
3
  SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
4
4
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
@@ -14,12 +14,12 @@ SURE/flow/plot_quiver.py,sha256=UbmuScUcgbQHeMmjKmgqxjrIjHhiHx0VWct16UMMwuE,8110
14
14
  SURE/perturb/__init__.py,sha256=8TP1dSUhXiZzKpFebHZmm8XMMGbUz_OfQ10xu-6uPPY,43
15
15
  SURE/perturb/perturb.py,sha256=1iSsCePcwkA2CyM1nCdq_G8gogUNjhMH0BfhhvhpJQk,5037
16
16
  SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
17
- SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
17
+ SURE/utils/custom_mlp.py,sha256=HuNb7f8-6RFjsvfEu1XOuNpLrHZkGYHgf8TpJfPSNO0,9382
18
18
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
19
19
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
20
- sure_tools-2.1.83.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
- sure_tools-2.1.83.dist-info/METADATA,sha256=H-q3GA7c-UxJp8C3OfR-f7YpSkhqaSQD3oZ_qcg9OJo,2678
22
- sure_tools-2.1.83.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- sure_tools-2.1.83.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
- sure_tools-2.1.83.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
- sure_tools-2.1.83.dist-info/RECORD,,
20
+ sure_tools-2.1.85.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
+ sure_tools-2.1.85.dist-info/METADATA,sha256=I2sMctindGA7kNf_TAdn64Oq6GL_QcGDuKLV6B9zQKY,2678
22
+ sure_tools-2.1.85.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ sure_tools-2.1.85.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
+ sure_tools-2.1.85.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
+ sure_tools-2.1.85.dist-info/RECORD,,