SURE-tools 2.1.54__py3-none-any.whl → 2.1.56__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. See the registry's advisory for this release for more details.

SURE/PerturbFlow.py CHANGED
@@ -244,7 +244,7 @@ class PerturbFlow(nn.Module):
244
244
  # allow_broadcast=self.allow_broadcast,
245
245
  # use_cuda=self.use_cuda,
246
246
  # )
247
- self.encoder_concentrate = self.decoder_concentrate
247
+ #self.encoder_concentrate = self.decoder_concentrate
248
248
  else:
249
249
  self.decoder_concentrate = MLP(
250
250
  [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
@@ -417,9 +417,9 @@ class PerturbFlow(nn.Module):
417
417
  alpha = self.encoder_n(zns)
418
418
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
419
419
 
420
- if self.loss_func == 'gamma-poisson':
421
- con_alpha,con_beta = self.encoder_concentrate(zns)
422
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
420
+ #if self.loss_func == 'gamma-poisson':
421
+ # con_alpha,con_beta = self.encoder_concentrate(zns)
422
+ # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
423
423
 
424
424
  def model2(self, xs, us=None):
425
425
  pyro.module('PerturbFlow', self)
@@ -521,9 +521,9 @@ class PerturbFlow(nn.Module):
521
521
  alpha = self.encoder_n(zns)
522
522
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
523
523
 
524
- if self.loss_func == 'gamma-poisson':
525
- con_alpha,con_beta = self.encoder_concentrate(zns)
526
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
524
+ #if self.loss_func == 'gamma-poisson':
525
+ # con_alpha,con_beta = self.encoder_concentrate(zns)
526
+ # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
527
527
 
528
528
  def model3(self, xs, ys, embeds=None):
529
529
  pyro.module('PerturbFlow', self)
@@ -631,9 +631,9 @@ class PerturbFlow(nn.Module):
631
631
  else:
632
632
  zns = embeds
633
633
 
634
- if self.loss_func == 'gamma-poisson':
635
- con_alpha,con_beta = self.encoder_concentrate(zns)
636
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
634
+ #if self.loss_func == 'gamma-poisson':
635
+ # con_alpha,con_beta = self.encoder_concentrate(zns)
636
+ # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
637
637
 
638
638
  def model4(self, xs, us, ys, embeds=None):
639
639
  pyro.module('PerturbFlow', self)
@@ -751,9 +751,9 @@ class PerturbFlow(nn.Module):
751
751
  else:
752
752
  zns = embeds
753
753
 
754
- if self.loss_func == 'gamma-poisson':
755
- con_alpha,con_beta = self.encoder_concentrate(zns)
756
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
754
+ #if self.loss_func == 'gamma-poisson':
755
+ # con_alpha,con_beta = self.encoder_concentrate(zns)
756
+ # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
757
757
 
758
758
  def _total_effects(self, zns, us):
759
759
  zus = None
@@ -933,7 +933,7 @@ class PerturbFlow(nn.Module):
933
933
 
934
934
  def _get_expression_response(self, delta_zs):
935
935
  if self.loss_func == 'gamma-poisson':
936
- alpha,beta = self.encoder_concentrate(delta_zs)
936
+ alpha,beta = self.decoder_concentrate(delta_zs)
937
937
  xs = dist.Gamma(alpha,beta).to_event(1).mean
938
938
  else:
939
939
  xs = self.decoder_concentrate(delta_zs)
@@ -960,7 +960,7 @@ class PerturbFlow(nn.Module):
960
960
  R = np.concatenate(R)
961
961
  return R
962
962
 
963
- def _count(self,concentrate):
963
+ def _count(self,concentrate, library_size=None):
964
964
  if self.loss_func == 'bernoulli':
965
965
  #counts = self.sigmoid(concentrate)
966
966
  counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
@@ -976,6 +976,11 @@ class PerturbFlow(nn.Module):
976
976
  counts = dist.Poisson(rate=rate).to_event(1).mean
977
977
  elif self.loss_func == 'gamma-poisson':
978
978
  counts = dist.Poisson(rate=concentrate).to_event(1).mean
979
+ elif self.loss_func == 'multinomial':
980
+ rate = concentrate.exp()
981
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
982
+ counts = dist.Multinomial(total_count=int(1e8), probs=theta).mean
983
+ counts = counts * library_size
979
984
  return counts
980
985
 
981
986
  def _count_sample(self,concentrate):
@@ -987,22 +992,35 @@ class PerturbFlow(nn.Module):
987
992
  counts = dist.Poisson(rate=counts).to_event(1).sample()
988
993
  return counts
989
994
 
990
- def get_counts(self, zs,
995
+ def get_counts(self, zs, library_sizes = None,
991
996
  batch_size: int = 1024,
992
997
  use_sampler: bool = False):
993
998
 
994
999
  zs = convert_to_tensor(zs, device=self.get_device())
995
- dataset = CustomDataset(zs)
1000
+ ls = zs
1001
+
1002
+ if self.loss_func == 'multinomial':
1003
+ assert library_sizes!=None, 'Library sizes are required for multinomial!'
1004
+
1005
+ if type(library_sizes) == list:
1006
+ library_sizes = np.array(library_sizes).view(-1,1)
1007
+ elif len(library_sizes.shape)==1:
1008
+ library_sizes = library_sizes.view(-1,1)
1009
+ ls = convert_to_tensor(library_sizes, device=self.get_device)
1010
+
1011
+ dataset = CustomDataset2(zs,ls)
996
1012
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
997
1013
 
998
1014
  E = []
999
1015
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
1000
- for Z_batch, _ in dataloader:
1016
+ for Z_batch, L_batch, _ in dataloader:
1017
+ if self.loss_func != 'multinomial':
1018
+ L_batch = None
1001
1019
  concentrate = self._get_expression_response(Z_batch)
1002
1020
  if use_sampler:
1003
1021
  counts = self._count_sample(concentrate)
1004
1022
  else:
1005
- counts = self._count(concentrate)
1023
+ counts = self._count(concentrate, L_batch)
1006
1024
  E.append(tensor_to_numpy(counts))
1007
1025
  pbar.update(1)
1008
1026
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.54
3
+ Version: 2.1.56
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,4 +1,4 @@
1
- SURE/PerturbFlow.py,sha256=hOVEsBrMAs7T5yi3LW7KV6hwPuwyjZtKG2wyMF6R08E,58614
1
+ SURE/PerturbFlow.py,sha256=CvnmX1QVo4UK4rkmFQd0RR9YHrNjL1GCHM1aj-BHVqM,59536
2
2
  SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
3
3
  SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
4
4
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
17
17
  SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
18
18
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
19
19
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
20
- sure_tools-2.1.54.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
- sure_tools-2.1.54.dist-info/METADATA,sha256=VDqYvGzqSz_HeBiPxGgwwl_i-uunVvd3t0MVIa4n6iI,2678
22
- sure_tools-2.1.54.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- sure_tools-2.1.54.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
- sure_tools-2.1.54.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
- sure_tools-2.1.54.dist-info/RECORD,,
20
+ sure_tools-2.1.56.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
+ sure_tools-2.1.56.dist-info/METADATA,sha256=kWBC-87jEjWE-JHxXGcrerFaxT9G5buo8zwZhkDxu9o,2678
22
+ sure_tools-2.1.56.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ sure_tools-2.1.56.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
+ sure_tools-2.1.56.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
+ sure_tools-2.1.56.dist-info/RECORD,,