SURE-tools 2.1.56.tar.gz → 2.1.64.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of SURE-tools might be problematic.
- {sure_tools-2.1.56 → sure_tools-2.1.64}/PKG-INFO +1 -1
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/PerturbFlow.py +44 -131
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.1.56 → sure_tools-2.1.64}/setup.py +1 -1
- {sure_tools-2.1.56 → sure_tools-2.1.64}/LICENSE +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/README.md +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/SURE.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/__init__.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/atac/utils.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/utils/queue.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE/utils/utils.py +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.1.56 → sure_tools-2.1.64}/setup.cfg +0 -0
@@ -62,7 +62,7 @@ class PerturbFlow(nn.Module):
                  supervised_mode: bool = False,
                  z_dim: int = 10,
                  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
-                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
                  inverse_dispersion: float = 10.0,
                  use_zeroinflate: bool = False,
                  hidden_layers: list = [300],
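For orientation, the `loss_func` options kept by this release correspond to the Pyro likelihood families used in the `pyro.sample('x', ...)` statements further down this diff. A purely illustrative lookup (this dict is not part of the package):

```python
import pyro.distributions as dist

# Illustrative mapping only: the four remaining loss_func options and the
# Pyro families they are scored with later in this file.
LIKELIHOODS = {
    'negbinomial': dist.NegativeBinomial,
    'poisson':     dist.Poisson,
    'multinomial': dist.Multinomial,
    'bernoulli':   dist.Bernoulli,
}
print(sorted(LIKELIHOODS))  # ['bernoulli', 'multinomial', 'negbinomial', 'poisson']
```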
@@ -225,28 +225,7 @@ class PerturbFlow(nn.Module):
         )
         )

-
-            self.decoder_concentrate = MLP(
-                [self.latent_dim] + self.decoder_hidden_layers + [[self.input_size,self.input_size]],
-                activation=activate_fct,
-                output_activation=[Exp,Exp],
-                post_layer_fct=post_layer_fct,
-                post_act_fct=post_act_fct,
-                allow_broadcast=self.allow_broadcast,
-                use_cuda=self.use_cuda,
-            )
-            #self.encoder_concentrate = MLP(
-            #    [self.latent_dim] + self.decoder_hidden_layers + [[self.input_size,self.input_size]],
-            #    activation=activate_fct,
-            #    output_activation=[Exp,Exp],
-            #    post_layer_fct=post_layer_fct,
-            #    post_act_fct=post_act_fct,
-            #    allow_broadcast=self.allow_broadcast,
-            #    use_cuda=self.use_cuda,
-            #)
-            #self.encoder_concentrate = self.decoder_concentrate
-        else:
-            self.decoder_concentrate = MLP(
+        self.decoder_concentrate = MLP(
             [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
             activation=activate_fct,
             output_activation=None,
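The branch that built a two-headed `[input_size, input_size]` decoder with `Exp` output activations is removed; a single decoder with one `input_size`-wide, unconstrained output head remains. A rough `torch.nn` sketch of the simplified shape (this is not the package's custom `MLP` class; the layer sizes and the ReLU stand-in are assumptions):

```python
import torch
import torch.nn as nn

# Assumed sizes for illustration only.
latent_dim, hidden_layers, input_size = 10, [300], 2000

widths = [latent_dim] + hidden_layers + [input_size]
layers = []
for i in range(len(widths) - 1):
    layers.append(nn.Linear(widths[i], widths[i + 1]))
    if i < len(widths) - 2:
        layers.append(nn.ReLU())              # stand-in for `activate_fct`
decoder_concentrate = nn.Sequential(*layers)  # no output activation, as in the diff

z = torch.randn(8, latent_dim)
concentrate = decoder_concentrate(z)          # shape: (8, input_size)
```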
@@ -338,8 +317,6 @@ class PerturbFlow(nn.Module):
         if self.loss_func=='negbinomial':
             total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                      xs.new_ones(self.input_size), constraint=constraints.positive)
-        elif self.loss_func == 'multinomial':
-            total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -373,17 +350,14 @@ class PerturbFlow(nn.Module):
                 zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))

             zs = zns
-
-
-
+            concentrate = self.decoder_concentrate(zs)
+            if self.loss_func in ['bernoulli']:
+                log_theta = concentrate
             else:
-
-
-
-
-                rate = concentrate.exp()
-                if self.loss_func != 'poisson':
-                    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                rate = concentrate.exp()
+                theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                if self.loss_func == 'poisson':
+                    rate = theta * torch.sum(xs, dim=1, keepdim=True)

             if self.loss_func == 'negbinomial':
                 if self.use_zeroinflate:
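The new likelihood plumbing normalizes the decoder output into a per-cell probability vector `theta` via the mean of a `DirichletMultinomial`, and for the Poisson loss rescales it by each cell's observed library size. A standalone sketch with toy tensors (shapes and values are made up, not taken from the model):

```python
import torch
import pyro.distributions as dist

concentrate = torch.randn(4, 6)                # stand-in for the decoder output
xs = torch.poisson(torch.rand(4, 6) * 5.0)     # toy count matrix (4 cells, 6 genes)

rate = concentrate.exp()
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
print(theta.sum(dim=1))                        # each row of theta sums to 1

# Poisson branch: scale the normalized profile by the observed library size.
poisson_rate = theta * torch.sum(xs, dim=1, keepdim=True)
print(dist.Poisson(rate=poisson_rate).to_event(1).log_prob(xs).shape)  # torch.Size([4])
```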
@@ -395,13 +369,8 @@ class PerturbFlow(nn.Module):
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
-            elif self.loss_func == 'gamma-poisson':
-                if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
-                else:
-                    pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-                pyro.sample('x', dist.Multinomial(total_count=
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
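Zero inflation is kept for the Poisson, negative-binomial and Bernoulli branches by wrapping the base distribution in `ZeroInflatedDistribution` with a per-gene gate. A minimal sketch with toy tensors (the `gate_logits` here is a plain tensor, not the model's `"dropout_rate"` parameter):

```python
import torch
import pyro.distributions as dist

rate = torch.rand(4, 6) * 3.0          # toy per-gene Poisson rates
gate_logits = torch.zeros(6)           # stand-in for the "dropout_rate" pyro.param

zi_poisson = dist.ZeroInflatedDistribution(
    dist.Poisson(rate=rate), gate_logits=gate_logits
).to_event(1)

xs = torch.poisson(rate)               # toy observed counts
print(zi_poisson.log_prob(xs).shape)   # torch.Size([4]), one log-likelihood per cell
```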
@@ -416,10 +385,6 @@ class PerturbFlow(nn.Module):

             alpha = self.encoder_n(zns)
             ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
-
-            #if self.loss_func == 'gamma-poisson':
-            #    con_alpha,con_beta = self.encoder_concentrate(zns)
-            #    rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))

     def model2(self, xs, us=None):
         pyro.module('PerturbFlow', self)
@@ -431,8 +396,6 @@ class PerturbFlow(nn.Module):
         if self.loss_func=='negbinomial':
             total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                      xs.new_ones(self.input_size), constraint=constraints.positive)
-        elif self.loss_func == 'multinomial':
-            total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -466,28 +429,19 @@ class PerturbFlow(nn.Module):
                 zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))

             if self.cell_factor_size>0:
-                #zus = None
-                #for i in np.arange(self.cell_factor_size):
-                #    if i==0:
-                #        zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                #    else:
-                #        zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
                 zus = self._total_effects(zns, us)
                 zs = zns+zus
             else:
                 zs = zns

-
-
-
+            concentrate = self.decoder_concentrate(zs)
+            if self.loss_func in ['bernoulli']:
+                log_theta = concentrate
             else:
-
-
-
-
-                rate = concentrate.exp()
-                if self.loss_func != 'poisson':
-                    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                rate = concentrate.exp()
+                theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                if self.loss_func == 'poisson':
+                    rate = theta * torch.sum(xs, dim=1, keepdim=True)

             if self.loss_func == 'negbinomial':
                 if self.use_zeroinflate:
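In `model2`, the commented-out per-factor loop is replaced by a single call to `self._total_effects(zns, us)` (defined later in the file). A generic sketch of that accumulation pattern with toy callables (these lambdas are not the package's `cell_factor_effect` modules):

```python
import torch

def total_effects(zns, us, factor_effects):
    """Sum additive per-factor effects, mirroring the loop the call replaces."""
    zus = None
    for i, effect in enumerate(factor_effects):
        contribution = effect(zns, us[:, i].reshape(-1, 1))
        zus = contribution if zus is None else zus + contribution
    return zus

zns = torch.randn(4, 10)                 # toy latent codes
us = torch.rand(4, 2)                    # toy per-cell factor values
factor_effects = [lambda z, u: u * 0.1, lambda z, u: -0.05 * u]
print(total_effects(zns, us, factor_effects).shape)  # torch.Size([4, 1])
```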
@@ -499,13 +453,8 @@ class PerturbFlow(nn.Module):
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
-            elif self.loss_func == 'gamma-poisson':
-                if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
-                else:
-                    pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-                pyro.sample('x', dist.Multinomial(total_count=
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -520,10 +469,6 @@ class PerturbFlow(nn.Module):

             alpha = self.encoder_n(zns)
             ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
-
-            #if self.loss_func == 'gamma-poisson':
-            #    con_alpha,con_beta = self.encoder_concentrate(zns)
-            #    rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))

     def model3(self, xs, ys, embeds=None):
         pyro.module('PerturbFlow', self)
@@ -535,8 +480,6 @@ class PerturbFlow(nn.Module):
         if self.loss_func=='negbinomial':
             total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                      xs.new_ones(self.input_size), constraint=constraints.positive)
-        elif self.loss_func == 'multinomial':
-            total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -587,17 +530,14 @@ class PerturbFlow(nn.Module):

             zs = zns

-
-
-
+            concentrate = self.decoder_concentrate(zs)
+            if self.loss_func in ['bernoulli']:
+                log_theta = concentrate
             else:
-
-
-
-
-                rate = concentrate.exp()
-                if self.loss_func != 'poisson':
-                    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                rate = concentrate.exp()
+                theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                if self.loss_func == 'poisson':
+                    rate = theta * torch.sum(xs, dim=1, keepdim=True)

             if self.loss_func == 'negbinomial':
                 if self.use_zeroinflate:
@@ -609,13 +549,8 @@ class PerturbFlow(nn.Module):
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
-            elif self.loss_func == 'gamma-poisson':
-                if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
-                else:
-                    pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-                pyro.sample('x', dist.Multinomial(total_count=
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -630,10 +565,6 @@ class PerturbFlow(nn.Module):
                 zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
             else:
                 zns = embeds
-
-            #if self.loss_func == 'gamma-poisson':
-            #    con_alpha,con_beta = self.encoder_concentrate(zns)
-            #    rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))

     def model4(self, xs, us, ys, embeds=None):
         pyro.module('PerturbFlow', self)
@@ -645,8 +576,6 @@ class PerturbFlow(nn.Module):
         if self.loss_func=='negbinomial':
             total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                      xs.new_ones(self.input_size), constraint=constraints.positive)
-        elif self.loss_func == 'multinomial':
-            total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -707,17 +636,14 @@ class PerturbFlow(nn.Module):
             else:
                 zs = zns

-
-
-
+            concentrate = self.decoder_concentrate(zs)
+            if self.loss_func in ['bernoulli']:
+                log_theta = concentrate
             else:
-
-
-
-
-                rate = concentrate.exp()
-                if self.loss_func != 'poisson':
-                    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                rate = concentrate.exp()
+                theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                if self.loss_func == 'poisson':
+                    rate = theta * torch.sum(xs, dim=1, keepdim=True)

             if self.loss_func == 'negbinomial':
                 if self.use_zeroinflate:
@@ -729,13 +655,8 @@ class PerturbFlow(nn.Module):
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
-            elif self.loss_func == 'gamma-poisson':
-                if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
-                else:
-                    pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-                pyro.sample('x', dist.Multinomial(total_count=
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -750,10 +671,6 @@ class PerturbFlow(nn.Module):
                 zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
             else:
                 zns = embeds
-
-            #if self.loss_func == 'gamma-poisson':
-            #    con_alpha,con_beta = self.encoder_concentrate(zns)
-            #    rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))

     def _total_effects(self, zns, us):
         zus = None
@@ -932,12 +849,7 @@ class PerturbFlow(nn.Module):
         return tensor_to_numpy(ms)

     def _get_expression_response(self, delta_zs):
-
-            alpha,beta = self.decoder_concentrate(delta_zs)
-            xs = dist.Gamma(alpha,beta).to_event(1).mean
-        else:
-            xs = self.decoder_concentrate(delta_zs)
-        return xs
+        return self.decoder_concentrate(delta_zs)

     def get_expression_response(self,
                                 delta_zs,
@@ -966,16 +878,16 @@ class PerturbFlow(nn.Module):
             counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
         elif self.loss_func == 'negbinomial':
             #counts = concentrate.exp()
-            rate = concentrate.exp()
-            theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+            #rate = concentrate.exp()
+            #theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean

             total_count = pyro.param("inverse_dispersion")
-            counts = dist.NegativeBinomial(total_count=total_count,
+            counts = dist.NegativeBinomial(total_count=total_count, logits=concentrate).to_event(1).mean
         elif self.loss_func == 'poisson':
             rate = concentrate.exp()
+            theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+            rate = theta * library_size
             counts = dist.Poisson(rate=rate).to_event(1).mean
-        elif self.loss_func == 'gamma-poisson':
-            counts = dist.Poisson(rate=concentrate).to_event(1).mean
         elif self.loss_func == 'multinomial':
             rate = concentrate.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
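In the counts reconstruction above, the negative-binomial branch now takes the decoder output directly as logits, while the Poisson branch goes through the normalized `theta` scaled by a supplied library size. A toy sketch of both mean-count computations (shapes, the `library_size` values, and the `total_count` stand-in are assumptions, not the trained parameters):

```python
import torch
import pyro.distributions as dist

concentrate = torch.randn(3, 5)                       # stand-in for decoder output
library_size = torch.tensor([[1.0e3], [2.0e3], [1.5e3]])
total_count = torch.full((5,), 10.0)                  # stand-in for "inverse_dispersion"

# negbinomial branch: mean of a NegativeBinomial parameterized directly by logits
nb_counts = dist.NegativeBinomial(total_count=total_count, logits=concentrate).to_event(1).mean

# poisson branch: normalize to theta, then scale by the per-cell library size
theta = dist.DirichletMultinomial(total_count=1, concentration=concentrate.exp()).mean
poisson_counts = dist.Poisson(rate=theta * library_size).to_event(1).mean

print(nb_counts.shape, poisson_counts.shape)          # torch.Size([3, 5]) torch.Size([3, 5])
```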
@@ -997,16 +909,17 @@ class PerturbFlow(nn.Module):
                              use_sampler: bool = False):

         zs = convert_to_tensor(zs, device=self.get_device())
-        ls = zs

-        if self.loss_func
-            assert library_sizes
+        if self.loss_func in ['multinomial','poisson']:
+            assert library_sizes is not None, 'Library sizes are required for multinomial!'

             if type(library_sizes) == list:
                 library_sizes = np.array(library_sizes).view(-1,1)
             elif len(library_sizes.shape)==1:
                 library_sizes = library_sizes.view(-1,1)
             ls = convert_to_tensor(library_sizes, device=self.get_device)
+        else:
+            ls = zs

         dataset = CustomDataset2(zs,ls)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
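The reworked branch only requires `library_sizes` for the multinomial and Poisson losses and falls back to `ls = zs` otherwise; list and 1-D inputs are reshaped into one column per cell. A hypothetical helper sketching that normalization (the function name and the use of `reshape` instead of the diff's `.view` are mine, not the package's):

```python
import numpy as np
import torch

def to_column_library_sizes(library_sizes):
    """Accept a list or a 1-D/2-D array and return an (n_cells, 1) float tensor."""
    if isinstance(library_sizes, list):
        library_sizes = np.asarray(library_sizes, dtype=float)
    if library_sizes.ndim == 1:
        library_sizes = library_sizes.reshape(-1, 1)
    return torch.as_tensor(library_sizes, dtype=torch.float32)

print(to_column_library_sizes([1000, 2000, 1500]).shape)  # torch.Size([3, 1])
```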