SURE-tools 2.1.56__tar.gz → 2.1.58__tar.gz
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- {sure_tools-2.1.56 → sure_tools-2.1.58}/PKG-INFO +1 -1
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/PerturbFlow.py +38 -118
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.1.56 → sure_tools-2.1.58}/setup.py +1 -1
- {sure_tools-2.1.56 → sure_tools-2.1.58}/LICENSE +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/README.md +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/SURE.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/__init__.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/atac/utils.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/utils/queue.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE/utils/utils.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.58}/setup.cfg +0 -0
|
@@ -62,7 +62,7 @@ class PerturbFlow(nn.Module):
|
|
|
62
62
|
supervised_mode: bool = False,
|
|
63
63
|
z_dim: int = 10,
|
|
64
64
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
|
|
65
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'
|
|
65
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
|
|
66
66
|
inverse_dispersion: float = 10.0,
|
|
67
67
|
use_zeroinflate: bool = False,
|
|
68
68
|
hidden_layers: list = [300],
|
|
@@ -225,28 +225,7 @@ class PerturbFlow(nn.Module):
|
|
|
225
225
|
)
|
|
226
226
|
)
|
|
227
227
|
|
|
228
|
-
|
|
229
|
-
self.decoder_concentrate = MLP(
|
|
230
|
-
[self.latent_dim] + self.decoder_hidden_layers + [[self.input_size,self.input_size]],
|
|
231
|
-
activation=activate_fct,
|
|
232
|
-
output_activation=[Exp,Exp],
|
|
233
|
-
post_layer_fct=post_layer_fct,
|
|
234
|
-
post_act_fct=post_act_fct,
|
|
235
|
-
allow_broadcast=self.allow_broadcast,
|
|
236
|
-
use_cuda=self.use_cuda,
|
|
237
|
-
)
|
|
238
|
-
#self.encoder_concentrate = MLP(
|
|
239
|
-
# [self.latent_dim] + self.decoder_hidden_layers + [[self.input_size,self.input_size]],
|
|
240
|
-
# activation=activate_fct,
|
|
241
|
-
# output_activation=[Exp,Exp],
|
|
242
|
-
# post_layer_fct=post_layer_fct,
|
|
243
|
-
# post_act_fct=post_act_fct,
|
|
244
|
-
# allow_broadcast=self.allow_broadcast,
|
|
245
|
-
# use_cuda=self.use_cuda,
|
|
246
|
-
# )
|
|
247
|
-
#self.encoder_concentrate = self.decoder_concentrate
|
|
248
|
-
else:
|
|
249
|
-
self.decoder_concentrate = MLP(
|
|
228
|
+
self.decoder_concentrate = MLP(
|
|
250
229
|
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
|
251
230
|
activation=activate_fct,
|
|
252
231
|
output_activation=None,
|
|
@@ -373,33 +352,24 @@ class PerturbFlow(nn.Module):
|
|
|
373
352
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
374
353
|
|
|
375
354
|
zs = zns
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
355
|
+
concentrate = self.decoder_concentrate(zs)
|
|
356
|
+
if self.loss_func in ['bernoulli','negbinomial']:
|
|
357
|
+
log_theta = concentrate
|
|
379
358
|
else:
|
|
380
|
-
|
|
381
|
-
if self.loss_func
|
|
382
|
-
|
|
383
|
-
else:
|
|
384
|
-
rate = concentrate.exp()
|
|
385
|
-
if self.loss_func != 'poisson':
|
|
386
|
-
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
359
|
+
rate = concentrate.exp()
|
|
360
|
+
if self.loss_func != 'poisson':
|
|
361
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
387
362
|
|
|
388
363
|
if self.loss_func == 'negbinomial':
|
|
389
364
|
if self.use_zeroinflate:
|
|
390
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
365
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
391
366
|
else:
|
|
392
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
367
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
|
|
393
368
|
elif self.loss_func == 'poisson':
|
|
394
369
|
if self.use_zeroinflate:
|
|
395
370
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
396
371
|
else:
|
|
397
372
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
398
|
-
elif self.loss_func == 'gamma-poisson':
|
|
399
|
-
if self.use_zeroinflate:
|
|
400
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
401
|
-
else:
|
|
402
|
-
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
403
373
|
elif self.loss_func == 'multinomial':
|
|
404
374
|
pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
|
|
405
375
|
elif self.loss_func == 'bernoulli':
|
|
@@ -416,10 +386,6 @@ class PerturbFlow(nn.Module):
|
|
|
416
386
|
|
|
417
387
|
alpha = self.encoder_n(zns)
|
|
418
388
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
419
|
-
|
|
420
|
-
#if self.loss_func == 'gamma-poisson':
|
|
421
|
-
# con_alpha,con_beta = self.encoder_concentrate(zns)
|
|
422
|
-
# rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
423
389
|
|
|
424
390
|
def model2(self, xs, us=None):
|
|
425
391
|
pyro.module('PerturbFlow', self)
|
|
@@ -477,33 +443,24 @@ class PerturbFlow(nn.Module):
|
|
|
477
443
|
else:
|
|
478
444
|
zs = zns
|
|
479
445
|
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
446
|
+
concentrate = self.decoder_concentrate(zs)
|
|
447
|
+
if self.loss_func in ['bernoulli','negbinomial']:
|
|
448
|
+
log_theta = concentrate
|
|
483
449
|
else:
|
|
484
|
-
|
|
485
|
-
if self.loss_func
|
|
486
|
-
|
|
487
|
-
else:
|
|
488
|
-
rate = concentrate.exp()
|
|
489
|
-
if self.loss_func != 'poisson':
|
|
490
|
-
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
450
|
+
rate = concentrate.exp()
|
|
451
|
+
if self.loss_func != 'poisson':
|
|
452
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
491
453
|
|
|
492
454
|
if self.loss_func == 'negbinomial':
|
|
493
455
|
if self.use_zeroinflate:
|
|
494
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
456
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
495
457
|
else:
|
|
496
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
458
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
|
|
497
459
|
elif self.loss_func == 'poisson':
|
|
498
460
|
if self.use_zeroinflate:
|
|
499
461
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
500
462
|
else:
|
|
501
463
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
502
|
-
elif self.loss_func == 'gamma-poisson':
|
|
503
|
-
if self.use_zeroinflate:
|
|
504
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
505
|
-
else:
|
|
506
|
-
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
507
464
|
elif self.loss_func == 'multinomial':
|
|
508
465
|
pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
|
|
509
466
|
elif self.loss_func == 'bernoulli':
|
|
@@ -520,10 +477,6 @@ class PerturbFlow(nn.Module):
|
|
|
520
477
|
|
|
521
478
|
alpha = self.encoder_n(zns)
|
|
522
479
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
523
|
-
|
|
524
|
-
#if self.loss_func == 'gamma-poisson':
|
|
525
|
-
# con_alpha,con_beta = self.encoder_concentrate(zns)
|
|
526
|
-
# rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
527
480
|
|
|
528
481
|
def model3(self, xs, ys, embeds=None):
|
|
529
482
|
pyro.module('PerturbFlow', self)
|
|
@@ -587,33 +540,24 @@ class PerturbFlow(nn.Module):
|
|
|
587
540
|
|
|
588
541
|
zs = zns
|
|
589
542
|
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
543
|
+
concentrate = self.decoder_concentrate(zs)
|
|
544
|
+
if self.loss_func in ['bernoulli','negbinomial']:
|
|
545
|
+
log_theta = concentrate
|
|
593
546
|
else:
|
|
594
|
-
|
|
595
|
-
if self.loss_func
|
|
596
|
-
|
|
597
|
-
else:
|
|
598
|
-
rate = concentrate.exp()
|
|
599
|
-
if self.loss_func != 'poisson':
|
|
600
|
-
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
547
|
+
rate = concentrate.exp()
|
|
548
|
+
if self.loss_func != 'poisson':
|
|
549
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
601
550
|
|
|
602
551
|
if self.loss_func == 'negbinomial':
|
|
603
552
|
if self.use_zeroinflate:
|
|
604
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
553
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
605
554
|
else:
|
|
606
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
555
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
|
|
607
556
|
elif self.loss_func == 'poisson':
|
|
608
557
|
if self.use_zeroinflate:
|
|
609
558
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
610
559
|
else:
|
|
611
560
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
612
|
-
elif self.loss_func == 'gamma-poisson':
|
|
613
|
-
if self.use_zeroinflate:
|
|
614
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
615
|
-
else:
|
|
616
|
-
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
617
561
|
elif self.loss_func == 'multinomial':
|
|
618
562
|
pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
|
|
619
563
|
elif self.loss_func == 'bernoulli':
|
|
@@ -630,10 +574,6 @@ class PerturbFlow(nn.Module):
|
|
|
630
574
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
631
575
|
else:
|
|
632
576
|
zns = embeds
|
|
633
|
-
|
|
634
|
-
#if self.loss_func == 'gamma-poisson':
|
|
635
|
-
# con_alpha,con_beta = self.encoder_concentrate(zns)
|
|
636
|
-
# rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
637
577
|
|
|
638
578
|
def model4(self, xs, us, ys, embeds=None):
|
|
639
579
|
pyro.module('PerturbFlow', self)
|
|
@@ -707,33 +647,24 @@ class PerturbFlow(nn.Module):
|
|
|
707
647
|
else:
|
|
708
648
|
zs = zns
|
|
709
649
|
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
650
|
+
concentrate = self.decoder_concentrate(zs)
|
|
651
|
+
if self.loss_func in ['bernoulli','negbinomial']:
|
|
652
|
+
log_theta = concentrate
|
|
713
653
|
else:
|
|
714
|
-
|
|
715
|
-
if self.loss_func
|
|
716
|
-
|
|
717
|
-
else:
|
|
718
|
-
rate = concentrate.exp()
|
|
719
|
-
if self.loss_func != 'poisson':
|
|
720
|
-
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
654
|
+
rate = concentrate.exp()
|
|
655
|
+
if self.loss_func != 'poisson':
|
|
656
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
721
657
|
|
|
722
658
|
if self.loss_func == 'negbinomial':
|
|
723
659
|
if self.use_zeroinflate:
|
|
724
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
660
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
725
661
|
else:
|
|
726
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
662
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
|
|
727
663
|
elif self.loss_func == 'poisson':
|
|
728
664
|
if self.use_zeroinflate:
|
|
729
665
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
730
666
|
else:
|
|
731
667
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
732
|
-
elif self.loss_func == 'gamma-poisson':
|
|
733
|
-
if self.use_zeroinflate:
|
|
734
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
735
|
-
else:
|
|
736
|
-
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
737
668
|
elif self.loss_func == 'multinomial':
|
|
738
669
|
pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
|
|
739
670
|
elif self.loss_func == 'bernoulli':
|
|
@@ -750,10 +681,6 @@ class PerturbFlow(nn.Module):
|
|
|
750
681
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
751
682
|
else:
|
|
752
683
|
zns = embeds
|
|
753
|
-
|
|
754
|
-
#if self.loss_func == 'gamma-poisson':
|
|
755
|
-
# con_alpha,con_beta = self.encoder_concentrate(zns)
|
|
756
|
-
# rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
757
684
|
|
|
758
685
|
def _total_effects(self, zns, us):
|
|
759
686
|
zus = None
|
|
@@ -932,12 +859,7 @@ class PerturbFlow(nn.Module):
|
|
|
932
859
|
return tensor_to_numpy(ms)
|
|
933
860
|
|
|
934
861
|
def _get_expression_response(self, delta_zs):
|
|
935
|
-
|
|
936
|
-
alpha,beta = self.decoder_concentrate(delta_zs)
|
|
937
|
-
xs = dist.Gamma(alpha,beta).to_event(1).mean
|
|
938
|
-
else:
|
|
939
|
-
xs = self.decoder_concentrate(delta_zs)
|
|
940
|
-
return xs
|
|
862
|
+
return self.decoder_concentrate(delta_zs)
|
|
941
863
|
|
|
942
864
|
def get_expression_response(self,
|
|
943
865
|
delta_zs,
|
|
@@ -966,16 +888,14 @@ class PerturbFlow(nn.Module):
|
|
|
966
888
|
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
967
889
|
elif self.loss_func == 'negbinomial':
|
|
968
890
|
#counts = concentrate.exp()
|
|
969
|
-
rate = concentrate.exp()
|
|
970
|
-
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
891
|
+
#rate = concentrate.exp()
|
|
892
|
+
#theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
971
893
|
|
|
972
894
|
total_count = pyro.param("inverse_dispersion")
|
|
973
|
-
counts = dist.NegativeBinomial(total_count=total_count,
|
|
895
|
+
counts = dist.NegativeBinomial(total_count=total_count, logits=concentrate).to_event(1).mean
|
|
974
896
|
elif self.loss_func == 'poisson':
|
|
975
897
|
rate = concentrate.exp()
|
|
976
898
|
counts = dist.Poisson(rate=rate).to_event(1).mean
|
|
977
|
-
elif self.loss_func == 'gamma-poisson':
|
|
978
|
-
counts = dist.Poisson(rate=concentrate).to_event(1).mean
|
|
979
899
|
elif self.loss_func == 'multinomial':
|
|
980
900
|
rate = concentrate.exp()
|
|
981
901
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|