SURE-tools 2.1.78__tar.gz → 2.1.80__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools has been flagged by the registry as potentially problematic; consult the registry's advisory for this release for more details.

Files changed (30):
  1. {sure_tools-2.1.78 → sure_tools-2.1.80}/PKG-INFO +1 -1
  2. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/PerturbFlow.py +30 -8
  3. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/SURE.py +17 -13
  4. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE_tools.egg-info/PKG-INFO +1 -1
  5. {sure_tools-2.1.78 → sure_tools-2.1.80}/setup.py +1 -1
  6. {sure_tools-2.1.78 → sure_tools-2.1.80}/LICENSE +0 -0
  7. {sure_tools-2.1.78 → sure_tools-2.1.80}/README.md +0 -0
  8. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/__init__.py +0 -0
  9. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/assembly/__init__.py +0 -0
  10. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/assembly/assembly.py +0 -0
  11. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/assembly/atlas.py +0 -0
  12. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/atac/__init__.py +0 -0
  13. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/atac/utils.py +0 -0
  14. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/codebook/__init__.py +0 -0
  15. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/codebook/codebook.py +0 -0
  16. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/flow/__init__.py +0 -0
  17. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/flow/flow_stats.py +0 -0
  18. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/flow/plot_quiver.py +0 -0
  19. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/perturb/__init__.py +0 -0
  20. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/perturb/perturb.py +0 -0
  21. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/utils/__init__.py +0 -0
  22. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/utils/custom_mlp.py +0 -0
  23. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/utils/queue.py +0 -0
  24. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE/utils/utils.py +0 -0
  25. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE_tools.egg-info/SOURCES.txt +0 -0
  26. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE_tools.egg-info/dependency_links.txt +0 -0
  27. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE_tools.egg-info/entry_points.txt +0 -0
  28. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE_tools.egg-info/requires.txt +0 -0
  29. {sure_tools-2.1.78 → sure_tools-2.1.80}/SURE_tools.egg-info/top_level.txt +0 -0
  30. {sure_tools-2.1.78 → sure_tools-2.1.80}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.78
3
+ Version: 2.1.80
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -104,7 +104,6 @@ class PerturbFlow(nn.Module):
104
104
  #self.use_bias = not zero_bias
105
105
 
106
106
  self.codebook_weights = None
107
- self.total_count = None
108
107
 
109
108
  set_random_seed(seed)
110
109
  self.setup_networks()
@@ -318,7 +317,6 @@ class PerturbFlow(nn.Module):
318
317
  if self.loss_func=='negbinomial':
319
318
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
320
319
  xs.new_ones(self.input_size), constraint=constraints.positive)
321
- self.total_count = total_count
322
320
 
323
321
  if self.use_zeroinflate:
324
322
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -398,7 +396,6 @@ class PerturbFlow(nn.Module):
398
396
  if self.loss_func=='negbinomial':
399
397
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
400
398
  xs.new_ones(self.input_size), constraint=constraints.positive)
401
- self.total_count = total_count
402
399
 
403
400
  if self.use_zeroinflate:
404
401
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -483,7 +480,6 @@ class PerturbFlow(nn.Module):
483
480
  if self.loss_func=='negbinomial':
484
481
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
485
482
  xs.new_ones(self.input_size), constraint=constraints.positive)
486
- self.total_count = total_count
487
483
 
488
484
  if self.use_zeroinflate:
489
485
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -580,7 +576,6 @@ class PerturbFlow(nn.Module):
580
576
  if self.loss_func=='negbinomial':
581
577
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
582
578
  xs.new_ones(self.input_size), constraint=constraints.positive)
583
- self.total_count = total_count
584
579
 
585
580
  if self.use_zeroinflate:
586
581
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -812,6 +807,36 @@ class PerturbFlow(nn.Module):
812
807
  A = np.concatenate(A)
813
808
  return A
814
809
 
810
+ def predict(self, xs, us, perturbs_predict:list, perturbs_reference:list, library_sizes=None):
811
+ perturbs_reference = np.array(perturbs_reference)
812
+
813
+ # basal embedding
814
+ zs = self.get_basal_embedding(xs)
815
+ for pert in perturbs_predict:
816
+ pert_idx = np.where(perturbs_reference==pert)[0]
817
+ us_i = us[:,pert_idx].reshape(-1,1)
818
+
819
+ # factor effect of xs
820
+ dzs0 = self.get_cell_response(xs, factor_idx=pert_idx, perturb=us_i)
821
+
822
+ # perturbation effect
823
+ ps = np.ones_like(us_i)
824
+ dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
825
+
826
+ zs = zs + dzs0 + dzs
827
+
828
+ if library_sizes is None:
829
+ library_sizes = np.sum(xs, axis=1, keepdims=True)
830
+ elif type(library_sizes) == list:
831
+ library_sizes = np.array(library_sizes)
832
+ library_sizes = library_sizes.reshape(-1,1)
833
+ elif len(library_sizes.shape)==1:
834
+ library_sizes = library_sizes.reshape(-1,1)
835
+
836
+ counts = self.get_counts(zs, library_sizes=library_sizes)
837
+
838
+ return counts, zs
839
+
815
840
  def _cell_response(self, xs, factor_idx, perturb):
816
841
  #zns,_ = self.encoder_zn(xs)
817
842
  zns,_ = self._get_basal_embedding(xs)
@@ -1065,9 +1090,6 @@ class PerturbFlow(nn.Module):
1065
1090
  pbar.set_postfix({'loss': str_loss})
1066
1091
  pbar.update(1)
1067
1092
 
1068
- if self.loss_func == 'negbinomial':
1069
- self.total_count = pyro.param('inverse_dispersion')
1070
-
1071
1093
  @classmethod
1072
1094
  def save_model(cls, model, file_path, compression=False):
1073
1095
  """Save the model to the specified file path."""
@@ -369,12 +369,13 @@ class SURE(nn.Module):
369
369
 
370
370
  zs = zns
371
371
  concentrate = self.decoder_concentrate(zs)
372
- if self.loss_func == 'bernoulli':
372
+ if self.loss_func in ['bernoulli']:
373
373
  log_theta = concentrate
374
374
  else:
375
375
  rate = concentrate.exp()
376
- if self.loss_func != 'poisson':
377
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
376
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
377
+ if self.loss_func == 'poisson':
378
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
378
379
 
379
380
  if self.loss_func == 'negbinomial':
380
381
  if self.use_zeroinflate:
@@ -451,12 +452,13 @@ class SURE(nn.Module):
451
452
  zs = zns
452
453
 
453
454
  concentrate = self.decoder_concentrate(zs)
454
- if self.loss_func == 'bernoulli':
455
+ if self.loss_func in ['bernoulli']:
455
456
  log_theta = concentrate
456
457
  else:
457
458
  rate = concentrate.exp()
458
- if self.loss_func != 'poisson':
459
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
459
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
460
+ if self.loss_func == 'poisson':
461
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
460
462
 
461
463
  if self.loss_func == 'negbinomial':
462
464
  if self.use_zeroinflate:
@@ -545,12 +547,13 @@ class SURE(nn.Module):
545
547
  zs = zns
546
548
 
547
549
  concentrate = self.decoder_concentrate(zs)
548
- if self.loss_func == 'bernoulli':
550
+ if self.loss_func in ['bernoulli']:
549
551
  log_theta = concentrate
550
552
  else:
551
553
  rate = concentrate.exp()
552
- if self.loss_func != 'poisson':
553
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
554
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
555
+ if self.loss_func == 'poisson':
556
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
554
557
 
555
558
  if self.loss_func == 'negbinomial':
556
559
  if self.use_zeroinflate:
@@ -641,13 +644,14 @@ class SURE(nn.Module):
641
644
  zs = zns
642
645
 
643
646
  concentrate = self.decoder_concentrate(zs)
644
- if self.loss_func == 'bernoulli':
647
+ if self.loss_func in ['bernoulli']:
645
648
  log_theta = concentrate
646
649
  else:
647
650
  rate = concentrate.exp()
648
- if self.loss_func != 'poisson':
649
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
650
-
651
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
652
+ if self.loss_func == 'poisson':
653
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
654
+
651
655
  if self.loss_func == 'negbinomial':
652
656
  if self.use_zeroinflate:
653
657
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.78
3
+ Version: 2.1.80
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
5
5
 
6
6
  setup(
7
7
  name='SURE-tools',
8
- version='2.1.78',
8
+ version='2.1.80',
9
9
  description='Succinct Representation of Single Cells',
10
10
  long_description=long_description,
11
11
  long_description_content_type="text/markdown",
File without changes
File without changes
File without changes