SURE-tools 2.1.57__py3-none-any.whl → 2.1.58__py3-none-any.whl
This diff shows the changes between the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
- SURE/PerturbFlow.py +27 -115
- {sure_tools-2.1.57.dist-info → sure_tools-2.1.58.dist-info}/METADATA +1 -1
- {sure_tools-2.1.57.dist-info → sure_tools-2.1.58.dist-info}/RECORD +7 -7
- {sure_tools-2.1.57.dist-info → sure_tools-2.1.58.dist-info}/WHEEL +0 -0
- {sure_tools-2.1.57.dist-info → sure_tools-2.1.58.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.1.57.dist-info → sure_tools-2.1.58.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.1.57.dist-info → sure_tools-2.1.58.dist-info}/top_level.txt +0 -0
SURE/PerturbFlow.py
CHANGED
|
@@ -62,7 +62,7 @@ class PerturbFlow(nn.Module):
|
|
|
62
62
|
supervised_mode: bool = False,
|
|
63
63
|
z_dim: int = 10,
|
|
64
64
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
|
|
65
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'
|
|
65
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
|
|
66
66
|
inverse_dispersion: float = 10.0,
|
|
67
67
|
use_zeroinflate: bool = False,
|
|
68
68
|
hidden_layers: list = [300],
|
|
@@ -225,28 +225,7 @@ class PerturbFlow(nn.Module):
|
|
|
225
225
|
)
|
|
226
226
|
)
|
|
227
227
|
|
|
228
|
-
|
|
229
|
-
self.decoder_concentrate = MLP(
|
|
230
|
-
[self.latent_dim] + self.decoder_hidden_layers + [[self.input_size,self.input_size]],
|
|
231
|
-
activation=activate_fct,
|
|
232
|
-
output_activation=[Exp,Exp],
|
|
233
|
-
post_layer_fct=post_layer_fct,
|
|
234
|
-
post_act_fct=post_act_fct,
|
|
235
|
-
allow_broadcast=self.allow_broadcast,
|
|
236
|
-
use_cuda=self.use_cuda,
|
|
237
|
-
)
|
|
238
|
-
#self.encoder_concentrate = MLP(
|
|
239
|
-
# [self.latent_dim] + self.decoder_hidden_layers + [[self.input_size,self.input_size]],
|
|
240
|
-
# activation=activate_fct,
|
|
241
|
-
# output_activation=[Exp,Exp],
|
|
242
|
-
# post_layer_fct=post_layer_fct,
|
|
243
|
-
# post_act_fct=post_act_fct,
|
|
244
|
-
# allow_broadcast=self.allow_broadcast,
|
|
245
|
-
# use_cuda=self.use_cuda,
|
|
246
|
-
# )
|
|
247
|
-
#self.encoder_concentrate = self.decoder_concentrate
|
|
248
|
-
else:
|
|
249
|
-
self.decoder_concentrate = MLP(
|
|
228
|
+
self.decoder_concentrate = MLP(
|
|
250
229
|
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
|
251
230
|
activation=activate_fct,
|
|
252
231
|
output_activation=None,
|
|
@@ -373,19 +352,13 @@ class PerturbFlow(nn.Module):
|
|
|
373
352
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
374
353
|
|
|
375
354
|
zs = zns
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
355
|
+
concentrate = self.decoder_concentrate(zs)
|
|
356
|
+
if self.loss_func in ['bernoulli','negbinomial']:
|
|
357
|
+
log_theta = concentrate
|
|
379
358
|
else:
|
|
380
|
-
|
|
381
|
-
if self.loss_func
|
|
382
|
-
|
|
383
|
-
elif self.loss_func == 'negbinomial':
|
|
384
|
-
log_theta = concentrate
|
|
385
|
-
else:
|
|
386
|
-
rate = concentrate.exp()
|
|
387
|
-
if self.loss_func != 'poisson':
|
|
388
|
-
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
359
|
+
rate = concentrate.exp()
|
|
360
|
+
if self.loss_func != 'poisson':
|
|
361
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
389
362
|
|
|
390
363
|
if self.loss_func == 'negbinomial':
|
|
391
364
|
if self.use_zeroinflate:
|
|
@@ -397,11 +370,6 @@ class PerturbFlow(nn.Module):
|
|
|
397
370
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
398
371
|
else:
|
|
399
372
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
400
|
-
elif self.loss_func == 'gamma-poisson':
|
|
401
|
-
if self.use_zeroinflate:
|
|
402
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
403
|
-
else:
|
|
404
|
-
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
405
373
|
elif self.loss_func == 'multinomial':
|
|
406
374
|
pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
|
|
407
375
|
elif self.loss_func == 'bernoulli':
|
|
@@ -418,10 +386,6 @@ class PerturbFlow(nn.Module):
|
|
|
418
386
|
|
|
419
387
|
alpha = self.encoder_n(zns)
|
|
420
388
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
421
|
-
|
|
422
|
-
#if self.loss_func == 'gamma-poisson':
|
|
423
|
-
# con_alpha,con_beta = self.encoder_concentrate(zns)
|
|
424
|
-
# rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
425
389
|
|
|
426
390
|
def model2(self, xs, us=None):
|
|
427
391
|
pyro.module('PerturbFlow', self)
|
|
@@ -479,19 +443,13 @@ class PerturbFlow(nn.Module):
|
|
|
479
443
|
else:
|
|
480
444
|
zs = zns
|
|
481
445
|
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
446
|
+
concentrate = self.decoder_concentrate(zs)
|
|
447
|
+
if self.loss_func in ['bernoulli','negbinomial']:
|
|
448
|
+
log_theta = concentrate
|
|
485
449
|
else:
|
|
486
|
-
|
|
487
|
-
if self.loss_func
|
|
488
|
-
|
|
489
|
-
elif self.loss_func == 'negbinomial':
|
|
490
|
-
log_theta = concentrate
|
|
491
|
-
else:
|
|
492
|
-
rate = concentrate.exp()
|
|
493
|
-
if self.loss_func != 'poisson':
|
|
494
|
-
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
450
|
+
rate = concentrate.exp()
|
|
451
|
+
if self.loss_func != 'poisson':
|
|
452
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
495
453
|
|
|
496
454
|
if self.loss_func == 'negbinomial':
|
|
497
455
|
if self.use_zeroinflate:
|
|
@@ -503,11 +461,6 @@ class PerturbFlow(nn.Module):
|
|
|
503
461
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
504
462
|
else:
|
|
505
463
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
506
|
-
elif self.loss_func == 'gamma-poisson':
|
|
507
|
-
if self.use_zeroinflate:
|
|
508
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
509
|
-
else:
|
|
510
|
-
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
511
464
|
elif self.loss_func == 'multinomial':
|
|
512
465
|
pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
|
|
513
466
|
elif self.loss_func == 'bernoulli':
|
|
@@ -524,10 +477,6 @@ class PerturbFlow(nn.Module):
|
|
|
524
477
|
|
|
525
478
|
alpha = self.encoder_n(zns)
|
|
526
479
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
527
|
-
|
|
528
|
-
#if self.loss_func == 'gamma-poisson':
|
|
529
|
-
# con_alpha,con_beta = self.encoder_concentrate(zns)
|
|
530
|
-
# rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
531
480
|
|
|
532
481
|
def model3(self, xs, ys, embeds=None):
|
|
533
482
|
pyro.module('PerturbFlow', self)
|
|
@@ -591,19 +540,13 @@ class PerturbFlow(nn.Module):
|
|
|
591
540
|
|
|
592
541
|
zs = zns
|
|
593
542
|
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
543
|
+
concentrate = self.decoder_concentrate(zs)
|
|
544
|
+
if self.loss_func in ['bernoulli','negbinomial']:
|
|
545
|
+
log_theta = concentrate
|
|
597
546
|
else:
|
|
598
|
-
|
|
599
|
-
if self.loss_func
|
|
600
|
-
|
|
601
|
-
elif self.loss_func == 'negbinomial':
|
|
602
|
-
log_theta = concentrate
|
|
603
|
-
else:
|
|
604
|
-
rate = concentrate.exp()
|
|
605
|
-
if self.loss_func != 'poisson':
|
|
606
|
-
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
547
|
+
rate = concentrate.exp()
|
|
548
|
+
if self.loss_func != 'poisson':
|
|
549
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
607
550
|
|
|
608
551
|
if self.loss_func == 'negbinomial':
|
|
609
552
|
if self.use_zeroinflate:
|
|
@@ -615,11 +558,6 @@ class PerturbFlow(nn.Module):
|
|
|
615
558
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
616
559
|
else:
|
|
617
560
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
618
|
-
elif self.loss_func == 'gamma-poisson':
|
|
619
|
-
if self.use_zeroinflate:
|
|
620
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
621
|
-
else:
|
|
622
|
-
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
623
561
|
elif self.loss_func == 'multinomial':
|
|
624
562
|
pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
|
|
625
563
|
elif self.loss_func == 'bernoulli':
|
|
@@ -636,10 +574,6 @@ class PerturbFlow(nn.Module):
|
|
|
636
574
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
637
575
|
else:
|
|
638
576
|
zns = embeds
|
|
639
|
-
|
|
640
|
-
#if self.loss_func == 'gamma-poisson':
|
|
641
|
-
# con_alpha,con_beta = self.encoder_concentrate(zns)
|
|
642
|
-
# rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
643
577
|
|
|
644
578
|
def model4(self, xs, us, ys, embeds=None):
|
|
645
579
|
pyro.module('PerturbFlow', self)
|
|
@@ -713,19 +647,13 @@ class PerturbFlow(nn.Module):
|
|
|
713
647
|
else:
|
|
714
648
|
zs = zns
|
|
715
649
|
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
650
|
+
concentrate = self.decoder_concentrate(zs)
|
|
651
|
+
if self.loss_func in ['bernoulli','negbinomial']:
|
|
652
|
+
log_theta = concentrate
|
|
719
653
|
else:
|
|
720
|
-
|
|
721
|
-
if self.loss_func
|
|
722
|
-
|
|
723
|
-
elif self.loss_func == 'negbinomial':
|
|
724
|
-
log_theta = concentrate
|
|
725
|
-
else:
|
|
726
|
-
rate = concentrate.exp()
|
|
727
|
-
if self.loss_func != 'poisson':
|
|
728
|
-
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
654
|
+
rate = concentrate.exp()
|
|
655
|
+
if self.loss_func != 'poisson':
|
|
656
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
729
657
|
|
|
730
658
|
if self.loss_func == 'negbinomial':
|
|
731
659
|
if self.use_zeroinflate:
|
|
@@ -737,11 +665,6 @@ class PerturbFlow(nn.Module):
|
|
|
737
665
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
738
666
|
else:
|
|
739
667
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
740
|
-
elif self.loss_func == 'gamma-poisson':
|
|
741
|
-
if self.use_zeroinflate:
|
|
742
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
743
|
-
else:
|
|
744
|
-
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
745
668
|
elif self.loss_func == 'multinomial':
|
|
746
669
|
pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
|
|
747
670
|
elif self.loss_func == 'bernoulli':
|
|
@@ -758,10 +681,6 @@ class PerturbFlow(nn.Module):
|
|
|
758
681
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
759
682
|
else:
|
|
760
683
|
zns = embeds
|
|
761
|
-
|
|
762
|
-
#if self.loss_func == 'gamma-poisson':
|
|
763
|
-
# con_alpha,con_beta = self.encoder_concentrate(zns)
|
|
764
|
-
# rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
765
684
|
|
|
766
685
|
def _total_effects(self, zns, us):
|
|
767
686
|
zus = None
|
|
@@ -940,12 +859,7 @@ class PerturbFlow(nn.Module):
|
|
|
940
859
|
return tensor_to_numpy(ms)
|
|
941
860
|
|
|
942
861
|
def _get_expression_response(self, delta_zs):
|
|
943
|
-
|
|
944
|
-
alpha,beta = self.decoder_concentrate(delta_zs)
|
|
945
|
-
xs = dist.Gamma(alpha,beta).to_event(1).mean
|
|
946
|
-
else:
|
|
947
|
-
xs = self.decoder_concentrate(delta_zs)
|
|
948
|
-
return xs
|
|
862
|
+
return self.decoder_concentrate(delta_zs)
|
|
949
863
|
|
|
950
864
|
def get_expression_response(self,
|
|
951
865
|
delta_zs,
|
|
@@ -982,8 +896,6 @@ class PerturbFlow(nn.Module):
|
|
|
982
896
|
elif self.loss_func == 'poisson':
|
|
983
897
|
rate = concentrate.exp()
|
|
984
898
|
counts = dist.Poisson(rate=rate).to_event(1).mean
|
|
985
|
-
elif self.loss_func == 'gamma-poisson':
|
|
986
|
-
counts = dist.Poisson(rate=concentrate).to_event(1).mean
|
|
987
899
|
elif self.loss_func == 'multinomial':
|
|
988
900
|
rate = concentrate.exp()
|
|
989
901
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
SURE/PerturbFlow.py,sha256=
|
|
1
|
+
SURE/PerturbFlow.py,sha256=7vflCQ8mtX0jzDe5lEIVxF4zwgWJIJ9aEZ6lD1duv0E,54985
|
|
2
2
|
SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
|
|
3
3
|
SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
|
|
4
4
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
|
|
|
17
17
|
SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
|
|
18
18
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
19
19
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
20
|
-
sure_tools-2.1.
|
|
21
|
-
sure_tools-2.1.
|
|
22
|
-
sure_tools-2.1.
|
|
23
|
-
sure_tools-2.1.
|
|
24
|
-
sure_tools-2.1.
|
|
25
|
-
sure_tools-2.1.
|
|
20
|
+
sure_tools-2.1.58.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
21
|
+
sure_tools-2.1.58.dist-info/METADATA,sha256=IKFJkaArfqXoAjczpEKYZworTX83okZYw7Kf8Bx430Y,2678
|
|
22
|
+
sure_tools-2.1.58.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
23
|
+
sure_tools-2.1.58.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
24
|
+
sure_tools-2.1.58.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
25
|
+
sure_tools-2.1.58.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|