SURE-tools 2.1.58__py3-none-any.whl → 2.1.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/PerturbFlow.py +32 -40
- {sure_tools-2.1.58.dist-info → sure_tools-2.1.60.dist-info}/METADATA +1 -1
- {sure_tools-2.1.58.dist-info → sure_tools-2.1.60.dist-info}/RECORD +7 -7
- {sure_tools-2.1.58.dist-info → sure_tools-2.1.60.dist-info}/WHEEL +0 -0
- {sure_tools-2.1.58.dist-info → sure_tools-2.1.60.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.1.58.dist-info → sure_tools-2.1.60.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.1.58.dist-info → sure_tools-2.1.60.dist-info}/top_level.txt +0 -0
SURE/PerturbFlow.py
CHANGED
|
@@ -62,7 +62,7 @@ class PerturbFlow(nn.Module):
|
|
|
62
62
|
supervised_mode: bool = False,
|
|
63
63
|
z_dim: int = 10,
|
|
64
64
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
|
|
65
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
65
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
|
|
66
66
|
inverse_dispersion: float = 10.0,
|
|
67
67
|
use_zeroinflate: bool = False,
|
|
68
68
|
hidden_layers: list = [300],
|
|
@@ -317,8 +317,6 @@ class PerturbFlow(nn.Module):
|
|
|
317
317
|
if self.loss_func=='negbinomial':
|
|
318
318
|
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
319
319
|
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
320
|
-
elif self.loss_func == 'multinomial':
|
|
321
|
-
total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
|
|
322
320
|
|
|
323
321
|
if self.use_zeroinflate:
|
|
324
322
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -353,25 +351,26 @@ class PerturbFlow(nn.Module):
|
|
|
353
351
|
|
|
354
352
|
zs = zns
|
|
355
353
|
concentrate = self.decoder_concentrate(zs)
|
|
356
|
-
if self.loss_func in ['bernoulli'
|
|
354
|
+
if self.loss_func in ['bernoulli']:
|
|
357
355
|
log_theta = concentrate
|
|
358
356
|
else:
|
|
359
357
|
rate = concentrate.exp()
|
|
360
|
-
|
|
361
|
-
|
|
358
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
359
|
+
if self.loss_func == 'poisson':
|
|
360
|
+
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
362
361
|
|
|
363
362
|
if self.loss_func == 'negbinomial':
|
|
364
363
|
if self.use_zeroinflate:
|
|
365
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
364
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
366
365
|
else:
|
|
367
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
366
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
|
|
368
367
|
elif self.loss_func == 'poisson':
|
|
369
368
|
if self.use_zeroinflate:
|
|
370
369
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
371
370
|
else:
|
|
372
371
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
373
372
|
elif self.loss_func == 'multinomial':
|
|
374
|
-
pyro.sample('x', dist.Multinomial(total_count=
|
|
373
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
375
374
|
elif self.loss_func == 'bernoulli':
|
|
376
375
|
if self.use_zeroinflate:
|
|
377
376
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -397,8 +396,6 @@ class PerturbFlow(nn.Module):
|
|
|
397
396
|
if self.loss_func=='negbinomial':
|
|
398
397
|
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
399
398
|
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
400
|
-
elif self.loss_func == 'multinomial':
|
|
401
|
-
total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
|
|
402
399
|
|
|
403
400
|
if self.use_zeroinflate:
|
|
404
401
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -432,37 +429,32 @@ class PerturbFlow(nn.Module):
|
|
|
432
429
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
433
430
|
|
|
434
431
|
if self.cell_factor_size>0:
|
|
435
|
-
#zus = None
|
|
436
|
-
#for i in np.arange(self.cell_factor_size):
|
|
437
|
-
# if i==0:
|
|
438
|
-
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
439
|
-
# else:
|
|
440
|
-
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
441
432
|
zus = self._total_effects(zns, us)
|
|
442
433
|
zs = zns+zus
|
|
443
434
|
else:
|
|
444
435
|
zs = zns
|
|
445
436
|
|
|
446
437
|
concentrate = self.decoder_concentrate(zs)
|
|
447
|
-
if self.loss_func in ['bernoulli'
|
|
438
|
+
if self.loss_func in ['bernoulli']:
|
|
448
439
|
log_theta = concentrate
|
|
449
440
|
else:
|
|
450
441
|
rate = concentrate.exp()
|
|
451
|
-
|
|
452
|
-
|
|
442
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
443
|
+
if self.loss_func == 'poisson':
|
|
444
|
+
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
453
445
|
|
|
454
446
|
if self.loss_func == 'negbinomial':
|
|
455
447
|
if self.use_zeroinflate:
|
|
456
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
448
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
457
449
|
else:
|
|
458
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
450
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
|
|
459
451
|
elif self.loss_func == 'poisson':
|
|
460
452
|
if self.use_zeroinflate:
|
|
461
453
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
462
454
|
else:
|
|
463
455
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
464
456
|
elif self.loss_func == 'multinomial':
|
|
465
|
-
pyro.sample('x', dist.Multinomial(total_count=
|
|
457
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
466
458
|
elif self.loss_func == 'bernoulli':
|
|
467
459
|
if self.use_zeroinflate:
|
|
468
460
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -488,8 +480,6 @@ class PerturbFlow(nn.Module):
|
|
|
488
480
|
if self.loss_func=='negbinomial':
|
|
489
481
|
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
490
482
|
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
491
|
-
elif self.loss_func == 'multinomial':
|
|
492
|
-
total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
|
|
493
483
|
|
|
494
484
|
if self.use_zeroinflate:
|
|
495
485
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -541,25 +531,26 @@ class PerturbFlow(nn.Module):
|
|
|
541
531
|
zs = zns
|
|
542
532
|
|
|
543
533
|
concentrate = self.decoder_concentrate(zs)
|
|
544
|
-
if self.loss_func in ['bernoulli'
|
|
534
|
+
if self.loss_func in ['bernoulli']:
|
|
545
535
|
log_theta = concentrate
|
|
546
536
|
else:
|
|
547
537
|
rate = concentrate.exp()
|
|
548
|
-
|
|
549
|
-
|
|
538
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
539
|
+
if self.loss_func == 'poisson':
|
|
540
|
+
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
550
541
|
|
|
551
542
|
if self.loss_func == 'negbinomial':
|
|
552
543
|
if self.use_zeroinflate:
|
|
553
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
544
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
554
545
|
else:
|
|
555
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
546
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
|
|
556
547
|
elif self.loss_func == 'poisson':
|
|
557
548
|
if self.use_zeroinflate:
|
|
558
549
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
559
550
|
else:
|
|
560
551
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
561
552
|
elif self.loss_func == 'multinomial':
|
|
562
|
-
pyro.sample('x', dist.Multinomial(total_count=
|
|
553
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
563
554
|
elif self.loss_func == 'bernoulli':
|
|
564
555
|
if self.use_zeroinflate:
|
|
565
556
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -585,8 +576,6 @@ class PerturbFlow(nn.Module):
|
|
|
585
576
|
if self.loss_func=='negbinomial':
|
|
586
577
|
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
587
578
|
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
588
|
-
elif self.loss_func == 'multinomial':
|
|
589
|
-
total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
|
|
590
579
|
|
|
591
580
|
if self.use_zeroinflate:
|
|
592
581
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -648,25 +637,26 @@ class PerturbFlow(nn.Module):
|
|
|
648
637
|
zs = zns
|
|
649
638
|
|
|
650
639
|
concentrate = self.decoder_concentrate(zs)
|
|
651
|
-
if self.loss_func in ['bernoulli'
|
|
640
|
+
if self.loss_func in ['bernoulli']:
|
|
652
641
|
log_theta = concentrate
|
|
653
642
|
else:
|
|
654
643
|
rate = concentrate.exp()
|
|
655
|
-
|
|
656
|
-
|
|
644
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
645
|
+
if self.loss_func == 'poisson':
|
|
646
|
+
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
657
647
|
|
|
658
648
|
if self.loss_func == 'negbinomial':
|
|
659
649
|
if self.use_zeroinflate:
|
|
660
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
650
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
661
651
|
else:
|
|
662
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
652
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
|
|
663
653
|
elif self.loss_func == 'poisson':
|
|
664
654
|
if self.use_zeroinflate:
|
|
665
655
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
666
656
|
else:
|
|
667
657
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
668
658
|
elif self.loss_func == 'multinomial':
|
|
669
|
-
pyro.sample('x', dist.Multinomial(total_count=
|
|
659
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
670
660
|
elif self.loss_func == 'bernoulli':
|
|
671
661
|
if self.use_zeroinflate:
|
|
672
662
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -895,6 +885,8 @@ class PerturbFlow(nn.Module):
|
|
|
895
885
|
counts = dist.NegativeBinomial(total_count=total_count, logits=concentrate).to_event(1).mean
|
|
896
886
|
elif self.loss_func == 'poisson':
|
|
897
887
|
rate = concentrate.exp()
|
|
888
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
889
|
+
rate = theta * library_size
|
|
898
890
|
counts = dist.Poisson(rate=rate).to_event(1).mean
|
|
899
891
|
elif self.loss_func == 'multinomial':
|
|
900
892
|
rate = concentrate.exp()
|
|
@@ -919,7 +911,7 @@ class PerturbFlow(nn.Module):
|
|
|
919
911
|
zs = convert_to_tensor(zs, device=self.get_device())
|
|
920
912
|
ls = zs
|
|
921
913
|
|
|
922
|
-
if self.loss_func
|
|
914
|
+
if self.loss_func in ['multinomial','poisson']:
|
|
923
915
|
assert library_sizes!=None, 'Library sizes are required for multinomial!'
|
|
924
916
|
|
|
925
917
|
if type(library_sizes) == list:
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
SURE/PerturbFlow.py,sha256=
|
|
1
|
+
SURE/PerturbFlow.py,sha256=CtM6FG6s6nn0VAG0-Ssi6yntcNxoWQaolXBH-JX-zjM,54358
|
|
2
2
|
SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
|
|
3
3
|
SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
|
|
4
4
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
|
|
|
17
17
|
SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
|
|
18
18
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
19
19
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
20
|
-
sure_tools-2.1.
|
|
21
|
-
sure_tools-2.1.
|
|
22
|
-
sure_tools-2.1.
|
|
23
|
-
sure_tools-2.1.
|
|
24
|
-
sure_tools-2.1.
|
|
25
|
-
sure_tools-2.1.
|
|
20
|
+
sure_tools-2.1.60.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
21
|
+
sure_tools-2.1.60.dist-info/METADATA,sha256=g48MZci8NumHGcOhtuANmkpEfB19Et_EkxfDtFgFtB0,2678
|
|
22
|
+
sure_tools-2.1.60.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
23
|
+
sure_tools-2.1.60.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
24
|
+
sure_tools-2.1.60.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
25
|
+
sure_tools-2.1.60.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|