SURE-tools 2.1.58__tar.gz → 2.1.59__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sure_tools-2.1.58 → sure_tools-2.1.59}/PKG-INFO +1 -1
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/PerturbFlow.py +20 -28
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.1.58 → sure_tools-2.1.59}/setup.py +1 -1
- {sure_tools-2.1.58 → sure_tools-2.1.59}/LICENSE +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/README.md +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/SURE.py +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/__init__.py +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/atac/utils.py +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/utils/queue.py +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE/utils/utils.py +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.1.58 → sure_tools-2.1.59}/setup.cfg +0 -0
|
@@ -62,7 +62,7 @@ class PerturbFlow(nn.Module):
|
|
|
62
62
|
supervised_mode: bool = False,
|
|
63
63
|
z_dim: int = 10,
|
|
64
64
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
|
|
65
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
65
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
|
|
66
66
|
inverse_dispersion: float = 10.0,
|
|
67
67
|
use_zeroinflate: bool = False,
|
|
68
68
|
hidden_layers: list = [300],
|
|
@@ -317,8 +317,6 @@ class PerturbFlow(nn.Module):
|
|
|
317
317
|
if self.loss_func=='negbinomial':
|
|
318
318
|
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
319
319
|
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
320
|
-
elif self.loss_func == 'multinomial':
|
|
321
|
-
total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
|
|
322
320
|
|
|
323
321
|
if self.use_zeroinflate:
|
|
324
322
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -357,8 +355,9 @@ class PerturbFlow(nn.Module):
|
|
|
357
355
|
log_theta = concentrate
|
|
358
356
|
else:
|
|
359
357
|
rate = concentrate.exp()
|
|
360
|
-
|
|
361
|
-
|
|
358
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
359
|
+
if self.loss_func == 'poisson':
|
|
360
|
+
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
362
361
|
|
|
363
362
|
if self.loss_func == 'negbinomial':
|
|
364
363
|
if self.use_zeroinflate:
|
|
@@ -371,7 +370,7 @@ class PerturbFlow(nn.Module):
|
|
|
371
370
|
else:
|
|
372
371
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
373
372
|
elif self.loss_func == 'multinomial':
|
|
374
|
-
pyro.sample('x', dist.Multinomial(total_count=
|
|
373
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
375
374
|
elif self.loss_func == 'bernoulli':
|
|
376
375
|
if self.use_zeroinflate:
|
|
377
376
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -397,8 +396,6 @@ class PerturbFlow(nn.Module):
|
|
|
397
396
|
if self.loss_func=='negbinomial':
|
|
398
397
|
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
399
398
|
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
400
|
-
elif self.loss_func == 'multinomial':
|
|
401
|
-
total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
|
|
402
399
|
|
|
403
400
|
if self.use_zeroinflate:
|
|
404
401
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -432,12 +429,6 @@ class PerturbFlow(nn.Module):
|
|
|
432
429
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
433
430
|
|
|
434
431
|
if self.cell_factor_size>0:
|
|
435
|
-
#zus = None
|
|
436
|
-
#for i in np.arange(self.cell_factor_size):
|
|
437
|
-
# if i==0:
|
|
438
|
-
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
439
|
-
# else:
|
|
440
|
-
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
441
432
|
zus = self._total_effects(zns, us)
|
|
442
433
|
zs = zns+zus
|
|
443
434
|
else:
|
|
@@ -448,8 +439,9 @@ class PerturbFlow(nn.Module):
|
|
|
448
439
|
log_theta = concentrate
|
|
449
440
|
else:
|
|
450
441
|
rate = concentrate.exp()
|
|
451
|
-
|
|
452
|
-
|
|
442
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
443
|
+
if self.loss_func == 'poisson':
|
|
444
|
+
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
453
445
|
|
|
454
446
|
if self.loss_func == 'negbinomial':
|
|
455
447
|
if self.use_zeroinflate:
|
|
@@ -462,7 +454,7 @@ class PerturbFlow(nn.Module):
|
|
|
462
454
|
else:
|
|
463
455
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
464
456
|
elif self.loss_func == 'multinomial':
|
|
465
|
-
pyro.sample('x', dist.Multinomial(total_count=
|
|
457
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
466
458
|
elif self.loss_func == 'bernoulli':
|
|
467
459
|
if self.use_zeroinflate:
|
|
468
460
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -488,8 +480,6 @@ class PerturbFlow(nn.Module):
|
|
|
488
480
|
if self.loss_func=='negbinomial':
|
|
489
481
|
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
490
482
|
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
491
|
-
elif self.loss_func == 'multinomial':
|
|
492
|
-
total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
|
|
493
483
|
|
|
494
484
|
if self.use_zeroinflate:
|
|
495
485
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -545,8 +535,9 @@ class PerturbFlow(nn.Module):
|
|
|
545
535
|
log_theta = concentrate
|
|
546
536
|
else:
|
|
547
537
|
rate = concentrate.exp()
|
|
548
|
-
|
|
549
|
-
|
|
538
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
539
|
+
if self.loss_func == 'poisson':
|
|
540
|
+
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
550
541
|
|
|
551
542
|
if self.loss_func == 'negbinomial':
|
|
552
543
|
if self.use_zeroinflate:
|
|
@@ -559,7 +550,7 @@ class PerturbFlow(nn.Module):
|
|
|
559
550
|
else:
|
|
560
551
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
561
552
|
elif self.loss_func == 'multinomial':
|
|
562
|
-
pyro.sample('x', dist.Multinomial(total_count=
|
|
553
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
563
554
|
elif self.loss_func == 'bernoulli':
|
|
564
555
|
if self.use_zeroinflate:
|
|
565
556
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -585,8 +576,6 @@ class PerturbFlow(nn.Module):
|
|
|
585
576
|
if self.loss_func=='negbinomial':
|
|
586
577
|
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
587
578
|
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
588
|
-
elif self.loss_func == 'multinomial':
|
|
589
|
-
total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
|
|
590
579
|
|
|
591
580
|
if self.use_zeroinflate:
|
|
592
581
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -652,8 +641,9 @@ class PerturbFlow(nn.Module):
|
|
|
652
641
|
log_theta = concentrate
|
|
653
642
|
else:
|
|
654
643
|
rate = concentrate.exp()
|
|
655
|
-
|
|
656
|
-
|
|
644
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
645
|
+
if self.loss_func == 'poisson':
|
|
646
|
+
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
657
647
|
|
|
658
648
|
if self.loss_func == 'negbinomial':
|
|
659
649
|
if self.use_zeroinflate:
|
|
@@ -666,7 +656,7 @@ class PerturbFlow(nn.Module):
|
|
|
666
656
|
else:
|
|
667
657
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
668
658
|
elif self.loss_func == 'multinomial':
|
|
669
|
-
pyro.sample('x', dist.Multinomial(total_count=
|
|
659
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
670
660
|
elif self.loss_func == 'bernoulli':
|
|
671
661
|
if self.use_zeroinflate:
|
|
672
662
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -895,6 +885,8 @@ class PerturbFlow(nn.Module):
|
|
|
895
885
|
counts = dist.NegativeBinomial(total_count=total_count, logits=concentrate).to_event(1).mean
|
|
896
886
|
elif self.loss_func == 'poisson':
|
|
897
887
|
rate = concentrate.exp()
|
|
888
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
889
|
+
rate = theta * library_size
|
|
898
890
|
counts = dist.Poisson(rate=rate).to_event(1).mean
|
|
899
891
|
elif self.loss_func == 'multinomial':
|
|
900
892
|
rate = concentrate.exp()
|
|
@@ -919,7 +911,7 @@ class PerturbFlow(nn.Module):
|
|
|
919
911
|
zs = convert_to_tensor(zs, device=self.get_device())
|
|
920
912
|
ls = zs
|
|
921
913
|
|
|
922
|
-
if self.loss_func
|
|
914
|
+
if self.loss_func in ['multinomial','poisson']:
|
|
923
915
|
assert library_sizes!=None, 'Library sizes are required for multinomial!'
|
|
924
916
|
|
|
925
917
|
if type(library_sizes) == list:
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|