SURE-tools 2.2.4.tar.gz → 2.2.10.tar.gz

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of SURE-tools might be problematic.

Files changed (30)
  1. {sure_tools-2.2.4 → sure_tools-2.2.10}/PKG-INFO +1 -1
  2. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/DensityFlow.py +36 -41
  3. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE_tools.egg-info/PKG-INFO +1 -1
  4. {sure_tools-2.2.4 → sure_tools-2.2.10}/setup.py +1 -1
  5. {sure_tools-2.2.4 → sure_tools-2.2.10}/LICENSE +0 -0
  6. {sure_tools-2.2.4 → sure_tools-2.2.10}/README.md +0 -0
  7. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/SURE.py +0 -0
  8. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/__init__.py +0 -0
  9. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/assembly/__init__.py +0 -0
  10. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/assembly/assembly.py +0 -0
  11. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/assembly/atlas.py +0 -0
  12. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/atac/__init__.py +0 -0
  13. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/atac/utils.py +0 -0
  14. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/codebook/__init__.py +0 -0
  15. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/codebook/codebook.py +0 -0
  16. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/flow/__init__.py +0 -0
  17. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/flow/flow_stats.py +0 -0
  18. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/flow/plot_quiver.py +0 -0
  19. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/perturb/__init__.py +0 -0
  20. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/perturb/perturb.py +0 -0
  21. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/utils/__init__.py +0 -0
  22. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/utils/custom_mlp.py +0 -0
  23. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/utils/queue.py +0 -0
  24. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/utils/utils.py +0 -0
  25. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE_tools.egg-info/SOURCES.txt +0 -0
  26. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE_tools.egg-info/dependency_links.txt +0 -0
  27. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE_tools.egg-info/entry_points.txt +0 -0
  28. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE_tools.egg-info/requires.txt +0 -0
  29. {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE_tools.egg-info/top_level.txt +0 -0
  30. {sure_tools-2.2.4 → sure_tools-2.2.10}/setup.cfg +0 -0

{sure_tools-2.2.4 → sure_tools-2.2.10}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: SURE-tools
-Version: 2.2.4
+Version: 2.2.10
 Summary: Succinct Representation of Single Cells
 Home-page: https://github.com/ZengFLab/SURE
 Author: Feng Zeng

{sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/DensityFlow.py
@@ -62,7 +62,7 @@ class DensityFlow(nn.Module):
         supervised_mode: bool = False,
         z_dim: int = 10,
         z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
-        loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
+        loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
         inverse_dispersion: float = 10.0,
         use_zeroinflate: bool = True,
         hidden_layers: list = [500],
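
Context for the new default: with loss_func='negbinomial', counts get an over-dispersed likelihood whose variance exceeds the Poisson variance by mean**2 / total_count, where total_count plays the role of the "inverse_dispersion" that later hunks make cell-specific. A minimal sketch of that relationship, using plain torch only (values are illustrative, nothing package-specific):

# Minimal sketch: negative binomial variance = mean + mean**2 / total_count,
# i.e. over-dispersed relative to Poisson(mean). Values are illustrative only.
import torch
from torch.distributions import NegativeBinomial

total_count = torch.tensor(10.0)   # plays the role of "inverse_dispersion" in this codebase
probs = torch.tensor(0.3)
nb = NegativeBinomial(total_count=total_count, probs=probs)

mean = nb.mean
assert torch.allclose(nb.variance, mean + mean**2 / total_count)
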
@@ -234,6 +234,16 @@ class DensityFlow(nn.Module):
             allow_broadcast=self.allow_broadcast,
             use_cuda=self.use_cuda,
         )
+        if self.loss_func == 'negbinomial':
+            self.decoder_total_count = MLP(
+                [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
+                activation=activate_fct,
+                output_activation=Exp,
+                post_layer_fct=post_layer_fct,
+                post_act_fct=post_act_fct,
+                allow_broadcast=self.allow_broadcast,
+                use_cuda=self.use_cuda,
+            )

         if self.latent_dist == 'studentt':
             self.codebook = MLP(
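
The block added above gives the negative-binomial model a dedicated decoder head that predicts a per-cell, per-gene total_count from the latent code, replacing the single global pyro.param("inverse_dispersion") used in 2.2.4 (commented out in the hunks below). The stand-in below uses a plain torch stack rather than SURE's MLP helper and only illustrates the shape and positivity contract; the sizes are assumptions:

# Stand-in for the new total_count head: maps (batch, latent_dim) -> (batch, input_size)
# with strictly positive outputs (the Exp output activation in the real code).
import torch
import torch.nn as nn

latent_dim, hidden, input_size = 10, 500, 2000   # illustrative sizes

decoder_total_count = nn.Sequential(
    nn.Linear(latent_dim, hidden),
    nn.ReLU(),
    nn.Linear(hidden, input_size),
)

zs = torch.randn(32, latent_dim)                 # a batch of latent codes
total_count = decoder_total_count(zs).exp()      # positive, per cell and per gene
assert total_count.shape == (32, input_size) and bool((total_count > 0).all())
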
@@ -314,9 +324,9 @@ class DensityFlow(nn.Module):
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)

-        if self.loss_func=='negbinomial':
-            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                     xs.new_ones(self.input_size), constraint=constraints.positive)
+        #if self.loss_func=='negbinomial':
+        #    total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
+        #                             xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -360,6 +370,7 @@ class DensityFlow(nn.Module):
             rate = theta * torch.sum(xs, dim=1, keepdim=True)

         if self.loss_func == 'negbinomial':
+            total_count = self.decoder_total_count(zs)
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
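
With this change the negative-binomial branch builds its likelihood from the per-cell total_count returned by decoder_total_count(zs) and probs=theta, optionally wrapped in zero inflation via the gate_logits from pyro.param("dropout_rate"). A self-contained sketch of that assembly, with random stand-ins for the decoder outputs (shapes and values are assumptions, not the package's tensors):

# Sketch of the likelihood assembled in this branch, with random stand-ins for
# decoder_total_count(zs), theta, and pyro.param("dropout_rate").
import torch
import pyro.distributions as dist

n_cells, n_genes = 32, 2000
total_count = torch.rand(n_cells, n_genes) * 10 + 1e-3          # per-cell, per-gene dispersion
theta = torch.softmax(torch.randn(n_cells, n_genes), dim=-1)    # success probabilities
gate_logits = torch.zeros(n_genes)                              # zero-inflation gate

base = dist.NegativeBinomial(total_count=total_count, probs=theta)
zinb = dist.ZeroInflatedDistribution(base, gate_logits=gate_logits).to_event(1)

xs = torch.poisson(torch.full((n_cells, n_genes), 2.0))         # fake counts, just to score
assert zinb.log_prob(xs).shape == (n_cells,)                    # one log-likelihood per cell
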
@@ -393,9 +404,9 @@ class DensityFlow(nn.Module):
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)

-        if self.loss_func=='negbinomial':
-            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                     xs.new_ones(self.input_size), constraint=constraints.positive)
+        #if self.loss_func=='negbinomial':
+        #    total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
+        #                             xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -444,6 +455,7 @@ class DensityFlow(nn.Module):
             rate = theta * torch.sum(xs, dim=1, keepdim=True)

         if self.loss_func == 'negbinomial':
+            total_count = self.decoder_total_count(zs)
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
@@ -477,9 +489,9 @@ class DensityFlow(nn.Module):
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)

-        if self.loss_func=='negbinomial':
-            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                     xs.new_ones(self.input_size), constraint=constraints.positive)
+        #if self.loss_func=='negbinomial':
+        #    total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
+        #                             xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -540,6 +552,7 @@ class DensityFlow(nn.Module):
             rate = theta * torch.sum(xs, dim=1, keepdim=True)

         if self.loss_func == 'negbinomial':
+            total_count = self.decoder_total_count(zs)
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
@@ -573,9 +586,9 @@ class DensityFlow(nn.Module):
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)

-        if self.loss_func=='negbinomial':
-            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                     xs.new_ones(self.input_size), constraint=constraints.positive)
+        #if self.loss_func=='negbinomial':
+        #    total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
+        #                             xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -646,6 +659,7 @@ class DensityFlow(nn.Module):
             rate = theta * torch.sum(xs, dim=1, keepdim=True)

         if self.loss_func == 'negbinomial':
+            total_count = self.decoder_total_count(zs)
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
@@ -824,9 +838,11 @@ class DensityFlow(nn.Module):

         # perturbation effect
         ps = np.ones_like(us_i)
-        dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
-
-        zs = zs + dzs0 + dzs
+        if np.sum(np.abs(ps-us_i))>=1:
+            dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
+            zs = zs + dzs0 + dzs
+        else:
+            zs = zs + dzs0

         if library_sizes is None:
             library_sizes = np.sum(xs, axis=1, keepdims=True)
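
The rewritten block only adds the perturbation response dzs when the proposed perturbation vector ps actually differs from the observed condition us_i; previously it was added unconditionally. A small numpy sketch of the guard, with get_response standing in for self.get_cell_response (the function name and shapes here are assumptions for illustration):

# Numpy sketch of the new guard: skip the extra response term when ps == us_i.
import numpy as np

def apply_perturbation(zs, dzs0, us_i, get_response):
    """get_response stands in for self.get_cell_response(...); not the package's API."""
    ps = np.ones_like(us_i)
    if np.sum(np.abs(ps - us_i)) >= 1:       # at least one entry actually changes
        dzs = get_response(ps)
        return zs + dzs0 + dzs
    return zs + dzs0                          # already in the perturbed state: no extra term

zs = np.zeros((4, 10)); dzs0 = np.full((4, 10), 0.1)
us_i = np.ones((4, 1))                        # already perturbed, so the guard skips dzs
out = apply_perturbation(zs, dzs0, us_i, get_response=lambda ps: np.full((4, 10), 0.5))
assert np.allclose(out, zs + dzs0)
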
@@ -905,33 +921,18 @@ class DensityFlow(nn.Module):
         R = np.concatenate(R)
         return R

-    def _count(self,concentrate, library_size=None):
+    def _count(self, concentrate, library_size=None):
         if self.loss_func == 'bernoulli':
             #counts = self.sigmoid(concentrate)
             counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
         else:
             rate = concentrate.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
-            if self.loss_func in ['poisson','multinomial']:
-                rate = theta * library_size
-                counts = dist.Poisson(rate=rate).to_event(1).mean
-            elif self.loss_func == 'negbinomial':
-                total_count = self.inverse_dispersion
-                counts = dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1).mean
-        return counts
-
-    def _count_sample(self,concentrate):
-        if self.loss_func == 'bernoulli':
-            logits = concentrate
-            counts = dist.Bernoulli(logits=logits).to_event(1).sample()
-        else:
-            counts = self._count(concentrate=concentrate)
-            counts = dist.Poisson(rate=counts).to_event(1).sample()
+            counts = theta * library_size
         return counts

     def get_counts(self, zs, library_sizes,
-                   batch_size: int = 1024,
-                   use_sampler: bool = False):
+                   batch_size: int = 1024):

         zs = convert_to_tensor(zs, device=self.get_device())
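
_count now returns the expected counts theta * library_size for every non-Bernoulli loss, dropping the per-loss mean computations and the _count_sample path (along with the matching use_sampler branch removed from the get_counts loop below). For the Poisson/multinomial branch this is the same value the old code computed, since the mean of Poisson(rate = theta * library_size) is exactly theta * library_size; for the negative-binomial branch it is a behavioural change, because 2.2.4 returned the NegativeBinomial mean parameterised by the global inverse dispersion. A small check of the Poisson equivalence:

# The old Poisson branch returned dist.Poisson(rate=theta * L).to_event(1).mean;
# that equals theta * L, which is all the simplified _count computes now.
import torch
import pyro.distributions as dist

theta = torch.softmax(torch.randn(8, 100), dim=-1)   # per-gene proportions per cell
L = torch.full((8, 1), 1000.0)                       # library sizes

poisson_mean = dist.Poisson(rate=theta * L).to_event(1).mean
assert torch.allclose(poisson_mean, theta * L)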
 
@@ -948,10 +949,7 @@ class DensityFlow(nn.Module):
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, L_batch, _ in dataloader:
                 concentrate = self._get_expression_response(Z_batch)
-                if use_sampler:
-                    counts = self._count_sample(concentrate)
-                else:
-                    counts = self._count(concentrate, L_batch)
+                counts = self._count(concentrate, L_batch)
                 E.append(tensor_to_numpy(counts))
                 pbar.update(1)

@@ -1096,9 +1094,6 @@ class DensityFlow(nn.Module):
                 # Update progress bar
                 pbar.set_postfix({'loss': str_loss})
                 pbar.update(1)
-
-        if self.loss_func == 'negbinomial':
-            self.inverse_dispersion = pyro.param("inverse_dispersion")

     @classmethod
     def save_model(cls, model, file_path, compression=False):
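
Because the dispersion is now produced by decoder_total_count inside the model, its weights are trained with the rest of the network and there is no standalone pyro.param("inverse_dispersion") to copy back after the training loop, which is why the block above was removed. A trivial, hedged check of that point (in a fresh process the store is empty anyway):

# No global "inverse_dispersion" parameter is registered in Pyro's param store anymore.
import pyro

assert "inverse_dispersion" not in pyro.get_param_store().keys()
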

{sure_tools-2.2.4 → sure_tools-2.2.10}/SURE_tools.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: SURE-tools
-Version: 2.2.4
+Version: 2.2.10
 Summary: Succinct Representation of Single Cells
 Home-page: https://github.com/ZengFLab/SURE
 Author: Feng Zeng

{sure_tools-2.2.4 → sure_tools-2.2.10}/setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:

 setup(
     name='SURE-tools',
-    version='2.2.4',
+    version='2.2.10',
     description='Succinct Representation of Single Cells',
     long_description=long_description,
     long_description_content_type="text/markdown",