SURE-tools 2.2.28 (tar.gz) → 2.3.2 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic; see the registry's advisory page for this release for more details.

Files changed (31):
  1. {sure_tools-2.2.28 → sure_tools-2.3.2}/PKG-INFO +1 -1
  2. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/DensityFlow.py +15 -8
  3. sure_tools-2.3.2/SURE/PerturbE.py +1293 -0
  4. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/__init__.py +2 -1
  5. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/utils/custom_mlp.py +31 -0
  6. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE_tools.egg-info/PKG-INFO +1 -1
  7. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE_tools.egg-info/SOURCES.txt +1 -0
  8. {sure_tools-2.2.28 → sure_tools-2.3.2}/setup.py +1 -1
  9. {sure_tools-2.2.28 → sure_tools-2.3.2}/LICENSE +0 -0
  10. {sure_tools-2.2.28 → sure_tools-2.3.2}/README.md +0 -0
  11. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/SURE.py +0 -0
  12. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/assembly/__init__.py +0 -0
  13. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/assembly/assembly.py +0 -0
  14. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/assembly/atlas.py +0 -0
  15. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/atac/__init__.py +0 -0
  16. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/atac/utils.py +0 -0
  17. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/codebook/__init__.py +0 -0
  18. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/codebook/codebook.py +0 -0
  19. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/flow/__init__.py +0 -0
  20. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/flow/flow_stats.py +0 -0
  21. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/flow/plot_quiver.py +0 -0
  22. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/perturb/__init__.py +0 -0
  23. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/perturb/perturb.py +0 -0
  24. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/utils/__init__.py +0 -0
  25. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/utils/queue.py +0 -0
  26. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE/utils/utils.py +0 -0
  27. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE_tools.egg-info/dependency_links.txt +0 -0
  28. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE_tools.egg-info/entry_points.txt +0 -0
  29. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE_tools.egg-info/requires.txt +0 -0
  30. {sure_tools-2.2.28 → sure_tools-2.3.2}/SURE_tools.egg-info/top_level.txt +0 -0
  31. {sure_tools-2.2.28 → sure_tools-2.3.2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.2.28
3
+ Version: 2.3.2
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -396,7 +396,8 @@ class DensityFlow(nn.Module):
396
396
  else:
397
397
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
398
398
  elif self.loss_func == 'multinomial':
399
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
399
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
400
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
400
401
  elif self.loss_func == 'bernoulli':
401
402
  if self.use_zeroinflate:
402
403
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -480,7 +481,8 @@ class DensityFlow(nn.Module):
480
481
  else:
481
482
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
482
483
  elif self.loss_func == 'multinomial':
483
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
484
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
485
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
484
486
  elif self.loss_func == 'bernoulli':
485
487
  if self.use_zeroinflate:
486
488
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -576,7 +578,8 @@ class DensityFlow(nn.Module):
576
578
  else:
577
579
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
578
580
  elif self.loss_func == 'multinomial':
579
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
581
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
582
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
580
583
  elif self.loss_func == 'bernoulli':
581
584
  if self.use_zeroinflate:
582
585
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -682,7 +685,8 @@ class DensityFlow(nn.Module):
682
685
  else:
683
686
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
684
687
  elif self.loss_func == 'multinomial':
685
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
688
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
689
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
686
690
  elif self.loss_func == 'bernoulli':
687
691
  if self.use_zeroinflate:
688
692
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -946,6 +950,9 @@ class DensityFlow(nn.Module):
946
950
  if self.loss_func == 'bernoulli':
947
951
  #counts = self.sigmoid(concentrate)
948
952
  counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
953
+ elif self.loss_func == 'multinomial':
954
+ theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
955
+ counts = theta * library_size
949
956
  else:
950
957
  rate = concentrate.exp()
951
958
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
@@ -1340,7 +1347,7 @@ def main():
1340
1347
  cell_factor_size = 0 if us is None else us.shape[1]
1341
1348
 
1342
1349
  ###########################################
1343
- DensityFlow = DensityFlow(
1350
+ df = DensityFlow(
1344
1351
  input_size=input_size,
1345
1352
  cell_factor_size=cell_factor_size,
1346
1353
  inverse_dispersion=args.inverse_dispersion,
@@ -1359,7 +1366,7 @@ def main():
1359
1366
  dtype=dtype,
1360
1367
  )
1361
1368
 
1362
- DensityFlow.fit(xs, us=us,
1369
+ df.fit(xs, us=us,
1363
1370
  num_epochs=args.num_epochs,
1364
1371
  learning_rate=args.learning_rate,
1365
1372
  batch_size=args.batch_size,
@@ -1371,9 +1378,9 @@ def main():
1371
1378
 
1372
1379
  if args.save_model is not None:
1373
1380
  if args.save_model.endswith('gz'):
1374
- DensityFlow.save_model(DensityFlow, args.save_model, compression=True)
1381
+ DensityFlow.save_model(df, args.save_model, compression=True)
1375
1382
  else:
1376
- DensityFlow.save_model(DensityFlow, args.save_model)
1383
+ DensityFlow.save_model(df, args.save_model)
1377
1384
 
1378
1385
 
1379
1386