SURE-tools 2.1.76__py3-none-any.whl → 2.1.78__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. Click here for more details.

SURE/PerturbFlow.py CHANGED
@@ -881,25 +881,11 @@ class PerturbFlow(nn.Module):
881
881
  if self.loss_func == 'bernoulli':
882
882
  #counts = self.sigmoid(concentrate)
883
883
  counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
884
- elif self.loss_func == 'negbinomial':
885
- rate = concentrate.exp()
886
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
887
-
888
- total_count = self.total_count
889
- #total_count = pyro.param("inverse_dispersion")
890
- #store = pyro.get_param_store()
891
- #total_count = store['inverse_dispersion']
892
- counts = dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1).mean
893
- elif self.loss_func == 'poisson':
884
+ else:
894
885
  rate = concentrate.exp()
895
886
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
896
887
  counts = theta * library_size
897
888
  #counts = dist.Poisson(rate=rate).to_event(1).mean
898
- elif self.loss_func == 'multinomial':
899
- rate = concentrate.exp()
900
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
901
- #counts = dist.Multinomial(total_count=int(1e8), probs=theta).mean
902
- counts = theta * library_size
903
889
  return counts
904
890
 
905
891
  def _count_sample(self,concentrate):
@@ -911,22 +897,17 @@ class PerturbFlow(nn.Module):
911
897
  counts = dist.Poisson(rate=counts).to_event(1).sample()
912
898
  return counts
913
899
 
914
- def get_counts(self, zs, library_sizes = None,
900
+ def get_counts(self, zs, library_sizes,
915
901
  batch_size: int = 1024,
916
902
  use_sampler: bool = False):
917
903
 
918
904
  zs = convert_to_tensor(zs, device=self.get_device())
919
905
 
920
- if self.loss_func in ['multinomial','poisson']:
921
- assert library_sizes is not None, 'Library sizes are required for multinomial!'
922
-
923
- if type(library_sizes) == list:
924
- library_sizes = np.array(library_sizes).view(-1,1)
925
- elif len(library_sizes.shape)==1:
926
- library_sizes = library_sizes.view(-1,1)
927
- ls = convert_to_tensor(library_sizes, device=self.get_device())
928
- else:
929
- ls = zs
906
+ if type(library_sizes) == list:
907
+ library_sizes = np.array(library_sizes).view(-1,1)
908
+ elif len(library_sizes.shape)==1:
909
+ library_sizes = library_sizes.view(-1,1)
910
+ ls = convert_to_tensor(library_sizes, device=self.get_device())
930
911
 
931
912
  dataset = CustomDataset2(zs,ls)
932
913
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
@@ -1084,7 +1065,8 @@ class PerturbFlow(nn.Module):
1084
1065
  pbar.set_postfix({'loss': str_loss})
1085
1066
  pbar.update(1)
1086
1067
 
1087
- self.total_count = pyro.param('inverse_dispersion')
1068
+ if self.loss_func == 'negbinomial':
1069
+ self.total_count = pyro.param('inverse_dispersion')
1088
1070
 
1089
1071
  @classmethod
1090
1072
  def save_model(cls, model, file_path, compression=False):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.76
3
+ Version: 2.1.78
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,4 +1,4 @@
1
- SURE/PerturbFlow.py,sha256=sltfC-Cb5bFI_EJTMLqjksZiSRCHQvP89Q1NFId2fBg,54669
1
+ SURE/PerturbFlow.py,sha256=BbpN3CSA-uSvIwVzHP1Vq7fuUxG3VC-x1KClCGa1euk,53695
2
2
  SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
3
3
  SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
4
4
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
17
17
  SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
18
18
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
19
19
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
20
- sure_tools-2.1.76.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
- sure_tools-2.1.76.dist-info/METADATA,sha256=h-xwe5IHtIZrIg8jlHXOxm2z9irVDcg1eFYnIoxr374,2678
22
- sure_tools-2.1.76.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- sure_tools-2.1.76.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
- sure_tools-2.1.76.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
- sure_tools-2.1.76.dist-info/RECORD,,
20
+ sure_tools-2.1.78.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
+ sure_tools-2.1.78.dist-info/METADATA,sha256=39CFgVJlh26Bd1i8qkQcK5jQOiDGDz_JzoNKKVS3S5Q,2678
22
+ sure_tools-2.1.78.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ sure_tools-2.1.78.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
+ sure_tools-2.1.78.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
+ sure_tools-2.1.78.dist-info/RECORD,,