SURE-tools 2.1.54__py3-none-any.whl → 2.1.55__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. See the registry's advisory for this release for more details.

SURE/PerturbFlow.py CHANGED
@@ -960,7 +960,7 @@ class PerturbFlow(nn.Module):
960
960
  R = np.concatenate(R)
961
961
  return R
962
962
 
963
- def _count(self,concentrate):
963
+ def _count(self,concentrate, library_size=None):
964
964
  if self.loss_func == 'bernoulli':
965
965
  #counts = self.sigmoid(concentrate)
966
966
  counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
@@ -976,6 +976,11 @@ class PerturbFlow(nn.Module):
976
976
  counts = dist.Poisson(rate=rate).to_event(1).mean
977
977
  elif self.loss_func == 'gamma-poisson':
978
978
  counts = dist.Poisson(rate=concentrate).to_event(1).mean
979
+ elif self.loss_func == 'multinomial':
980
+ rate = concentrate.exp()
981
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
982
+ counts = dist.Multinomial(total_count=int(1e8), probs=theta).mean
983
+ counts = counts * library_size
979
984
  return counts
980
985
 
981
986
  def _count_sample(self,concentrate):
@@ -987,22 +992,35 @@ class PerturbFlow(nn.Module):
987
992
  counts = dist.Poisson(rate=counts).to_event(1).sample()
988
993
  return counts
989
994
 
990
- def get_counts(self, zs,
995
+ def get_counts(self, zs, library_sizes = None,
991
996
  batch_size: int = 1024,
992
997
  use_sampler: bool = False):
993
998
 
994
999
  zs = convert_to_tensor(zs, device=self.get_device())
995
- dataset = CustomDataset(zs)
1000
+ ls = zs
1001
+
1002
+ if self.loss_func == 'multinomial':
1003
+ assert library_sizes!=None, 'Library sizes are required for multinomial!'
1004
+
1005
+ if type(library_sizes) == list:
1006
+ library_sizes = np.array(library_sizes).view(-1,1)
1007
+ elif len(library_sizes.shape)==1:
1008
+ library_sizes = library_sizes.view(-1,1)
1009
+ ls = convert_to_tensor(library_sizes, device=self.get_device)
1010
+
1011
+ dataset = CustomDataset2(zs,ls)
996
1012
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
997
1013
 
998
1014
  E = []
999
1015
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
1000
- for Z_batch, _ in dataloader:
1016
+ for Z_batch, L_batch, _ in dataloader:
1017
+ if self.loss_func != 'multinomial':
1018
+ L_batch = None
1001
1019
  concentrate = self._get_expression_response(Z_batch)
1002
1020
  if use_sampler:
1003
1021
  counts = self._count_sample(concentrate)
1004
1022
  else:
1005
- counts = self._count(concentrate)
1023
+ counts = self._count(concentrate, L_batch)
1006
1024
  E.append(tensor_to_numpy(counts))
1007
1025
  pbar.update(1)
1008
1026
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.54
3
+ Version: 2.1.55
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,4 +1,4 @@
1
- SURE/PerturbFlow.py,sha256=hOVEsBrMAs7T5yi3LW7KV6hwPuwyjZtKG2wyMF6R08E,58614
1
+ SURE/PerturbFlow.py,sha256=0-hD4NFKd0zvh_kBOCeh9irAjJ5TuyD7djKJKDCZv6I,59523
2
2
  SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
3
3
  SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
4
4
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
17
17
  SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
18
18
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
19
19
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
20
- sure_tools-2.1.54.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
- sure_tools-2.1.54.dist-info/METADATA,sha256=VDqYvGzqSz_HeBiPxGgwwl_i-uunVvd3t0MVIa4n6iI,2678
22
- sure_tools-2.1.54.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- sure_tools-2.1.54.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
- sure_tools-2.1.54.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
- sure_tools-2.1.54.dist-info/RECORD,,
20
+ sure_tools-2.1.55.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
+ sure_tools-2.1.55.dist-info/METADATA,sha256=GmbQukuqLtfvGrGd0VCuzY5S396a2I-M08_hYkB9vB8,2678
22
+ sure_tools-2.1.55.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ sure_tools-2.1.55.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
+ sure_tools-2.1.55.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
+ sure_tools-2.1.55.dist-info/RECORD,,