SURE-tools 2.1.53__tar.gz → 2.1.54__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools has been flagged as potentially problematic; see the registry's advisory page for this release for more details.
- {sure_tools-2.1.53 → sure_tools-2.1.54}/PKG-INFO +1 -1
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/PerturbFlow.py +121 -37
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.1.53 → sure_tools-2.1.54}/setup.py +1 -1
- {sure_tools-2.1.53 → sure_tools-2.1.54}/LICENSE +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/README.md +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/SURE.py +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/__init__.py +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/atac/utils.py +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/utils/queue.py +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE/utils/utils.py +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.1.53 → sure_tools-2.1.54}/setup.cfg +0 -0
|
@@ -62,7 +62,7 @@ class PerturbFlow(nn.Module):
|
|
|
62
62
|
supervised_mode: bool = False,
|
|
63
63
|
z_dim: int = 10,
|
|
64
64
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
|
|
65
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
|
|
65
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli','gamma-poisson'] = 'negbinomial',
|
|
66
66
|
inverse_dispersion: float = 10.0,
|
|
67
67
|
use_zeroinflate: bool = False,
|
|
68
68
|
hidden_layers: list = [300],
|
|
@@ -225,15 +225,36 @@ class PerturbFlow(nn.Module):
|
|
|
225
225
|
)
|
|
226
226
|
)
|
|
227
227
|
|
|
228
|
-
self.
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
228
|
+
if self.loss_func == 'gamma-poisson':
|
|
229
|
+
self.decoder_concentrate = MLP(
|
|
230
|
+
[self.latent_dim] + self.decoder_hidden_layers + [[self.input_size,self.input_size]],
|
|
231
|
+
activation=activate_fct,
|
|
232
|
+
output_activation=[Exp,Exp],
|
|
233
|
+
post_layer_fct=post_layer_fct,
|
|
234
|
+
post_act_fct=post_act_fct,
|
|
235
|
+
allow_broadcast=self.allow_broadcast,
|
|
236
|
+
use_cuda=self.use_cuda,
|
|
237
|
+
)
|
|
238
|
+
#self.encoder_concentrate = MLP(
|
|
239
|
+
# [self.latent_dim] + self.decoder_hidden_layers + [[self.input_size,self.input_size]],
|
|
240
|
+
# activation=activate_fct,
|
|
241
|
+
# output_activation=[Exp,Exp],
|
|
242
|
+
# post_layer_fct=post_layer_fct,
|
|
243
|
+
# post_act_fct=post_act_fct,
|
|
244
|
+
# allow_broadcast=self.allow_broadcast,
|
|
245
|
+
# use_cuda=self.use_cuda,
|
|
246
|
+
# )
|
|
247
|
+
self.encoder_concentrate = self.decoder_concentrate
|
|
248
|
+
else:
|
|
249
|
+
self.decoder_concentrate = MLP(
|
|
250
|
+
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
|
251
|
+
activation=activate_fct,
|
|
252
|
+
output_activation=None,
|
|
253
|
+
post_layer_fct=post_layer_fct,
|
|
254
|
+
post_act_fct=post_act_fct,
|
|
255
|
+
allow_broadcast=self.allow_broadcast,
|
|
256
|
+
use_cuda=self.use_cuda,
|
|
257
|
+
)
|
|
237
258
|
|
|
238
259
|
if self.latent_dist == 'studentt':
|
|
239
260
|
self.codebook = MLP(
|
|
@@ -352,13 +373,17 @@ class PerturbFlow(nn.Module):
|
|
|
352
373
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
353
374
|
|
|
354
375
|
zs = zns
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
376
|
+
if self.loss_func == 'gamma-poisson':
|
|
377
|
+
con_alpha,con_beta = self.decoder_concentrate(zs)
|
|
378
|
+
rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
358
379
|
else:
|
|
359
|
-
|
|
360
|
-
if self.loss_func
|
|
361
|
-
|
|
380
|
+
concentrate = self.decoder_concentrate(zs)
|
|
381
|
+
if self.loss_func == 'bernoulli':
|
|
382
|
+
log_theta = concentrate
|
|
383
|
+
else:
|
|
384
|
+
rate = concentrate.exp()
|
|
385
|
+
if self.loss_func != 'poisson':
|
|
386
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
362
387
|
|
|
363
388
|
if self.loss_func == 'negbinomial':
|
|
364
389
|
if self.use_zeroinflate:
|
|
@@ -370,6 +395,11 @@ class PerturbFlow(nn.Module):
|
|
|
370
395
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
371
396
|
else:
|
|
372
397
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
398
|
+
elif self.loss_func == 'gamma-poisson':
|
|
399
|
+
if self.use_zeroinflate:
|
|
400
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
401
|
+
else:
|
|
402
|
+
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
373
403
|
elif self.loss_func == 'multinomial':
|
|
374
404
|
pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
|
|
375
405
|
elif self.loss_func == 'bernoulli':
|
|
@@ -386,6 +416,10 @@ class PerturbFlow(nn.Module):
|
|
|
386
416
|
|
|
387
417
|
alpha = self.encoder_n(zns)
|
|
388
418
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
419
|
+
|
|
420
|
+
if self.loss_func == 'gamma-poisson':
|
|
421
|
+
con_alpha,con_beta = self.encoder_concentrate(zns)
|
|
422
|
+
rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
389
423
|
|
|
390
424
|
def model2(self, xs, us=None):
|
|
391
425
|
pyro.module('PerturbFlow', self)
|
|
@@ -443,13 +477,17 @@ class PerturbFlow(nn.Module):
|
|
|
443
477
|
else:
|
|
444
478
|
zs = zns
|
|
445
479
|
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
480
|
+
if self.loss_func == 'gamma-poisson':
|
|
481
|
+
con_alpha,con_beta = self.decoder_concentrate(zs)
|
|
482
|
+
rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
449
483
|
else:
|
|
450
|
-
|
|
451
|
-
if self.loss_func
|
|
452
|
-
|
|
484
|
+
concentrate = self.decoder_concentrate(zs)
|
|
485
|
+
if self.loss_func == 'bernoulli':
|
|
486
|
+
log_theta = concentrate
|
|
487
|
+
else:
|
|
488
|
+
rate = concentrate.exp()
|
|
489
|
+
if self.loss_func != 'poisson':
|
|
490
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
453
491
|
|
|
454
492
|
if self.loss_func == 'negbinomial':
|
|
455
493
|
if self.use_zeroinflate:
|
|
@@ -461,6 +499,11 @@ class PerturbFlow(nn.Module):
|
|
|
461
499
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
462
500
|
else:
|
|
463
501
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
502
|
+
elif self.loss_func == 'gamma-poisson':
|
|
503
|
+
if self.use_zeroinflate:
|
|
504
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
505
|
+
else:
|
|
506
|
+
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
464
507
|
elif self.loss_func == 'multinomial':
|
|
465
508
|
pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
|
|
466
509
|
elif self.loss_func == 'bernoulli':
|
|
@@ -477,6 +520,10 @@ class PerturbFlow(nn.Module):
|
|
|
477
520
|
|
|
478
521
|
alpha = self.encoder_n(zns)
|
|
479
522
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
523
|
+
|
|
524
|
+
if self.loss_func == 'gamma-poisson':
|
|
525
|
+
con_alpha,con_beta = self.encoder_concentrate(zns)
|
|
526
|
+
rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
480
527
|
|
|
481
528
|
def model3(self, xs, ys, embeds=None):
|
|
482
529
|
pyro.module('PerturbFlow', self)
|
|
@@ -540,13 +587,17 @@ class PerturbFlow(nn.Module):
|
|
|
540
587
|
|
|
541
588
|
zs = zns
|
|
542
589
|
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
590
|
+
if self.loss_func == 'gamma-poisson':
|
|
591
|
+
con_alpha,con_beta = self.decoder_concentrate(zs)
|
|
592
|
+
rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
546
593
|
else:
|
|
547
|
-
|
|
548
|
-
if self.loss_func
|
|
549
|
-
|
|
594
|
+
concentrate = self.decoder_concentrate(zs)
|
|
595
|
+
if self.loss_func == 'bernoulli':
|
|
596
|
+
log_theta = concentrate
|
|
597
|
+
else:
|
|
598
|
+
rate = concentrate.exp()
|
|
599
|
+
if self.loss_func != 'poisson':
|
|
600
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
550
601
|
|
|
551
602
|
if self.loss_func == 'negbinomial':
|
|
552
603
|
if self.use_zeroinflate:
|
|
@@ -558,6 +609,11 @@ class PerturbFlow(nn.Module):
|
|
|
558
609
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
559
610
|
else:
|
|
560
611
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
612
|
+
elif self.loss_func == 'gamma-poisson':
|
|
613
|
+
if self.use_zeroinflate:
|
|
614
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
615
|
+
else:
|
|
616
|
+
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
561
617
|
elif self.loss_func == 'multinomial':
|
|
562
618
|
pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
|
|
563
619
|
elif self.loss_func == 'bernoulli':
|
|
@@ -572,6 +628,12 @@ class PerturbFlow(nn.Module):
|
|
|
572
628
|
#zn_loc, zn_scale = self.encoder_zn(xs)
|
|
573
629
|
zn_loc, zn_scale = self._get_basal_embedding(xs)
|
|
574
630
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
631
|
+
else:
|
|
632
|
+
zns = embeds
|
|
633
|
+
|
|
634
|
+
if self.loss_func == 'gamma-poisson':
|
|
635
|
+
con_alpha,con_beta = self.encoder_concentrate(zns)
|
|
636
|
+
rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
575
637
|
|
|
576
638
|
def model4(self, xs, us, ys, embeds=None):
|
|
577
639
|
pyro.module('PerturbFlow', self)
|
|
@@ -645,13 +707,17 @@ class PerturbFlow(nn.Module):
|
|
|
645
707
|
else:
|
|
646
708
|
zs = zns
|
|
647
709
|
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
710
|
+
if self.loss_func == 'gamma-poisson':
|
|
711
|
+
con_alpha,con_beta = self.decoder_concentrate(zs)
|
|
712
|
+
rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
651
713
|
else:
|
|
652
|
-
|
|
653
|
-
if self.loss_func
|
|
654
|
-
|
|
714
|
+
concentrate = self.decoder_concentrate(zs)
|
|
715
|
+
if self.loss_func == 'bernoulli':
|
|
716
|
+
log_theta = concentrate
|
|
717
|
+
else:
|
|
718
|
+
rate = concentrate.exp()
|
|
719
|
+
if self.loss_func != 'poisson':
|
|
720
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
655
721
|
|
|
656
722
|
if self.loss_func == 'negbinomial':
|
|
657
723
|
if self.use_zeroinflate:
|
|
@@ -663,6 +729,11 @@ class PerturbFlow(nn.Module):
|
|
|
663
729
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
664
730
|
else:
|
|
665
731
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
732
|
+
elif self.loss_func == 'gamma-poisson':
|
|
733
|
+
if self.use_zeroinflate:
|
|
734
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
735
|
+
else:
|
|
736
|
+
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
666
737
|
elif self.loss_func == 'multinomial':
|
|
667
738
|
pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
|
|
668
739
|
elif self.loss_func == 'bernoulli':
|
|
@@ -677,6 +748,12 @@ class PerturbFlow(nn.Module):
|
|
|
677
748
|
#zn_loc, zn_scale = self.encoder_zn(xs)
|
|
678
749
|
zn_loc, zn_scale = self._get_basal_embedding(xs)
|
|
679
750
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
751
|
+
else:
|
|
752
|
+
zns = embeds
|
|
753
|
+
|
|
754
|
+
if self.loss_func == 'gamma-poisson':
|
|
755
|
+
con_alpha,con_beta = self.encoder_concentrate(zns)
|
|
756
|
+
rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
680
757
|
|
|
681
758
|
def _total_effects(self, zns, us):
|
|
682
759
|
zus = None
|
|
@@ -855,7 +932,12 @@ class PerturbFlow(nn.Module):
|
|
|
855
932
|
return tensor_to_numpy(ms)
|
|
856
933
|
|
|
857
934
|
def _get_expression_response(self, delta_zs):
|
|
858
|
-
|
|
935
|
+
if self.loss_func == 'gamma-poisson':
|
|
936
|
+
alpha,beta = self.encoder_concentrate(delta_zs)
|
|
937
|
+
xs = dist.Gamma(alpha,beta).to_event(1).mean
|
|
938
|
+
else:
|
|
939
|
+
xs = self.decoder_concentrate(delta_zs)
|
|
940
|
+
return xs
|
|
859
941
|
|
|
860
942
|
def get_expression_response(self,
|
|
861
943
|
delta_zs,
|
|
@@ -888,10 +970,12 @@ class PerturbFlow(nn.Module):
|
|
|
888
970
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
889
971
|
|
|
890
972
|
total_count = pyro.param("inverse_dispersion")
|
|
891
|
-
counts = dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1)
|
|
973
|
+
counts = dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1).mean
|
|
892
974
|
elif self.loss_func == 'poisson':
|
|
893
975
|
rate = concentrate.exp()
|
|
894
|
-
counts = dist.Poisson(rate=rate).to_event(1)
|
|
976
|
+
counts = dist.Poisson(rate=rate).to_event(1).mean
|
|
977
|
+
elif self.loss_func == 'gamma-poisson':
|
|
978
|
+
counts = dist.Poisson(rate=concentrate).to_event(1).mean
|
|
895
979
|
return counts
|
|
896
980
|
|
|
897
981
|
def _count_sample(self,concentrate):
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|