SURE-tools 2.1.81__py3-none-any.whl → 2.1.83__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
SURE/PerturbFlow.py CHANGED
@@ -204,7 +204,7 @@ class PerturbFlow(nn.Module):
204
204
  for i in np.arange(self.cell_factor_size):
205
205
  if self.use_bias[i]:
206
206
  self.cell_factor_effect.append(MLP(
207
- [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
207
+ [self.code_size+1] + self.decoder_hidden_layers + [self.latent_dim],
208
208
  activation=activate_fct,
209
209
  output_activation=None,
210
210
  post_layer_fct=post_layer_fct,
@@ -215,7 +215,7 @@ class PerturbFlow(nn.Module):
215
215
  )
216
216
  else:
217
217
  self.cell_factor_effect.append(ZeroBiasMLP(
218
- [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
218
+ [self.code_size+1] + self.decoder_hidden_layers + [self.latent_dim],
219
219
  activation=activate_fct,
220
220
  output_activation=None,
221
221
  post_layer_fct=post_layer_fct,
@@ -429,7 +429,7 @@ class PerturbFlow(nn.Module):
429
429
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
430
430
 
431
431
  if self.cell_factor_size>0:
432
- zus = self._total_effects(zns, us)
432
+ zus = self._total_effects(ns, us)
433
433
  zs = zns+zus
434
434
  else:
435
435
  zs = zns
@@ -631,7 +631,7 @@ class PerturbFlow(nn.Module):
631
631
  # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
632
632
  # else:
633
633
  # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
634
- zus = self._total_effects(zns, us)
634
+ zus = self._total_effects(ns, us)
635
635
  zs = zns+zus
636
636
  else:
637
637
  zs = zns
@@ -681,6 +681,9 @@ class PerturbFlow(nn.Module):
681
681
  zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
682
682
  return zus
683
683
 
684
+ def _get_codebook_identity(self):
685
+ return torch.eye(self.code_size, **self.options)
686
+
684
687
  def _get_codebook(self):
685
688
  I = torch.eye(self.code_size, **self.options)
686
689
  if self.latent_dist=='studentt':
@@ -839,7 +842,8 @@ class PerturbFlow(nn.Module):
839
842
 
840
843
  def _cell_response(self, xs, factor_idx, perturb):
841
844
  #zns,_ = self.encoder_zn(xs)
842
- zns,_ = self._get_basal_embedding(xs)
845
+ #zns,_ = self._get_basal_embedding(xs)
846
+ zns = self._soft_assignments(xs)
843
847
  if perturb.ndim==2:
844
848
  ms = self.cell_factor_effect[factor_idx]([zns, perturb])
845
849
  else:
@@ -873,7 +877,8 @@ class PerturbFlow(nn.Module):
873
877
  return Z
874
878
 
875
879
  def get_metacell_response(self, factor_idx, perturb):
876
- zs = self._get_codebook()
880
+ #zs = self._get_codebook()
881
+ zs = self._get_codebook_identity()
877
882
  ps = convert_to_tensor(perturb, device=self.get_device())
878
883
  ms = self.cell_factor_effect[factor_idx]([zs,ps])
879
884
  return tensor_to_numpy(ms)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.81
3
+ Version: 2.1.83
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,4 +1,4 @@
1
- SURE/PerturbFlow.py,sha256=Hi7oF9S_eImo2Q-Oxu0mW461DHOG0O05TIE_pUW3uJo,54577
1
+ SURE/PerturbFlow.py,sha256=5HzS8oB06iSR3JM5AalGfYi-quxbjkTZeTypjih-VBI,54759
2
2
  SURE/SURE.py,sha256=g8EhovBxjfpbVJA0AkmVkQ_ZW_JFc8TtkTCg8FCybV4,47750
3
3
  SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
4
4
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
17
17
  SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
18
18
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
19
19
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
20
- sure_tools-2.1.81.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
- sure_tools-2.1.81.dist-info/METADATA,sha256=uEiJyslzIPEZ_4HfoTR9JLtBuIfmAPH-HIu2Ar2dKg0,2678
22
- sure_tools-2.1.81.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- sure_tools-2.1.81.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
- sure_tools-2.1.81.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
- sure_tools-2.1.81.dist-info/RECORD,,
20
+ sure_tools-2.1.83.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
+ sure_tools-2.1.83.dist-info/METADATA,sha256=H-q3GA7c-UxJp8C3OfR-f7YpSkhqaSQD3oZ_qcg9OJo,2678
22
+ sure_tools-2.1.83.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ sure_tools-2.1.83.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
+ sure_tools-2.1.83.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
+ sure_tools-2.1.83.dist-info/RECORD,,