SURE-tools 2.1.56__py3-none-any.whl → 2.1.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
SURE/PerturbFlow.py CHANGED
@@ -62,7 +62,7 @@ class PerturbFlow(nn.Module):
62
62
  supervised_mode: bool = False,
63
63
  z_dim: int = 10,
64
64
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
65
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli','gamma-poisson'] = 'negbinomial',
65
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
66
66
  inverse_dispersion: float = 10.0,
67
67
  use_zeroinflate: bool = False,
68
68
  hidden_layers: list = [300],
@@ -225,28 +225,7 @@ class PerturbFlow(nn.Module):
225
225
  )
226
226
  )
227
227
 
228
- if self.loss_func == 'gamma-poisson':
229
- self.decoder_concentrate = MLP(
230
- [self.latent_dim] + self.decoder_hidden_layers + [[self.input_size,self.input_size]],
231
- activation=activate_fct,
232
- output_activation=[Exp,Exp],
233
- post_layer_fct=post_layer_fct,
234
- post_act_fct=post_act_fct,
235
- allow_broadcast=self.allow_broadcast,
236
- use_cuda=self.use_cuda,
237
- )
238
- #self.encoder_concentrate = MLP(
239
- # [self.latent_dim] + self.decoder_hidden_layers + [[self.input_size,self.input_size]],
240
- # activation=activate_fct,
241
- # output_activation=[Exp,Exp],
242
- # post_layer_fct=post_layer_fct,
243
- # post_act_fct=post_act_fct,
244
- # allow_broadcast=self.allow_broadcast,
245
- # use_cuda=self.use_cuda,
246
- # )
247
- #self.encoder_concentrate = self.decoder_concentrate
248
- else:
249
- self.decoder_concentrate = MLP(
228
+ self.decoder_concentrate = MLP(
250
229
  [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
251
230
  activation=activate_fct,
252
231
  output_activation=None,
@@ -373,33 +352,24 @@ class PerturbFlow(nn.Module):
373
352
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
374
353
 
375
354
  zs = zns
376
- if self.loss_func == 'gamma-poisson':
377
- con_alpha,con_beta = self.decoder_concentrate(zs)
378
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
355
+ concentrate = self.decoder_concentrate(zs)
356
+ if self.loss_func in ['bernoulli','negbinomial']:
357
+ log_theta = concentrate
379
358
  else:
380
- concentrate = self.decoder_concentrate(zs)
381
- if self.loss_func == 'bernoulli':
382
- log_theta = concentrate
383
- else:
384
- rate = concentrate.exp()
385
- if self.loss_func != 'poisson':
386
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
359
+ rate = concentrate.exp()
360
+ if self.loss_func != 'poisson':
361
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
387
362
 
388
363
  if self.loss_func == 'negbinomial':
389
364
  if self.use_zeroinflate:
390
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
365
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
391
366
  else:
392
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
367
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
393
368
  elif self.loss_func == 'poisson':
394
369
  if self.use_zeroinflate:
395
370
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
396
371
  else:
397
372
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
398
- elif self.loss_func == 'gamma-poisson':
399
- if self.use_zeroinflate:
400
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
401
- else:
402
- pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
403
373
  elif self.loss_func == 'multinomial':
404
374
  pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
405
375
  elif self.loss_func == 'bernoulli':
@@ -416,10 +386,6 @@ class PerturbFlow(nn.Module):
416
386
 
417
387
  alpha = self.encoder_n(zns)
418
388
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
419
-
420
- #if self.loss_func == 'gamma-poisson':
421
- # con_alpha,con_beta = self.encoder_concentrate(zns)
422
- # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
423
389
 
424
390
  def model2(self, xs, us=None):
425
391
  pyro.module('PerturbFlow', self)
@@ -477,33 +443,24 @@ class PerturbFlow(nn.Module):
477
443
  else:
478
444
  zs = zns
479
445
 
480
- if self.loss_func == 'gamma-poisson':
481
- con_alpha,con_beta = self.decoder_concentrate(zs)
482
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
446
+ concentrate = self.decoder_concentrate(zs)
447
+ if self.loss_func in ['bernoulli','negbinomial']:
448
+ log_theta = concentrate
483
449
  else:
484
- concentrate = self.decoder_concentrate(zs)
485
- if self.loss_func == 'bernoulli':
486
- log_theta = concentrate
487
- else:
488
- rate = concentrate.exp()
489
- if self.loss_func != 'poisson':
490
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
450
+ rate = concentrate.exp()
451
+ if self.loss_func != 'poisson':
452
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
491
453
 
492
454
  if self.loss_func == 'negbinomial':
493
455
  if self.use_zeroinflate:
494
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
456
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
495
457
  else:
496
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
458
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
497
459
  elif self.loss_func == 'poisson':
498
460
  if self.use_zeroinflate:
499
461
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
500
462
  else:
501
463
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
502
- elif self.loss_func == 'gamma-poisson':
503
- if self.use_zeroinflate:
504
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
505
- else:
506
- pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
507
464
  elif self.loss_func == 'multinomial':
508
465
  pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
509
466
  elif self.loss_func == 'bernoulli':
@@ -520,10 +477,6 @@ class PerturbFlow(nn.Module):
520
477
 
521
478
  alpha = self.encoder_n(zns)
522
479
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
523
-
524
- #if self.loss_func == 'gamma-poisson':
525
- # con_alpha,con_beta = self.encoder_concentrate(zns)
526
- # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
527
480
 
528
481
  def model3(self, xs, ys, embeds=None):
529
482
  pyro.module('PerturbFlow', self)
@@ -587,33 +540,24 @@ class PerturbFlow(nn.Module):
587
540
 
588
541
  zs = zns
589
542
 
590
- if self.loss_func == 'gamma-poisson':
591
- con_alpha,con_beta = self.decoder_concentrate(zs)
592
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
543
+ concentrate = self.decoder_concentrate(zs)
544
+ if self.loss_func in ['bernoulli','negbinomial']:
545
+ log_theta = concentrate
593
546
  else:
594
- concentrate = self.decoder_concentrate(zs)
595
- if self.loss_func == 'bernoulli':
596
- log_theta = concentrate
597
- else:
598
- rate = concentrate.exp()
599
- if self.loss_func != 'poisson':
600
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
547
+ rate = concentrate.exp()
548
+ if self.loss_func != 'poisson':
549
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
601
550
 
602
551
  if self.loss_func == 'negbinomial':
603
552
  if self.use_zeroinflate:
604
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
553
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
605
554
  else:
606
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
555
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
607
556
  elif self.loss_func == 'poisson':
608
557
  if self.use_zeroinflate:
609
558
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
610
559
  else:
611
560
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
612
- elif self.loss_func == 'gamma-poisson':
613
- if self.use_zeroinflate:
614
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
615
- else:
616
- pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
617
561
  elif self.loss_func == 'multinomial':
618
562
  pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
619
563
  elif self.loss_func == 'bernoulli':
@@ -630,10 +574,6 @@ class PerturbFlow(nn.Module):
630
574
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
631
575
  else:
632
576
  zns = embeds
633
-
634
- #if self.loss_func == 'gamma-poisson':
635
- # con_alpha,con_beta = self.encoder_concentrate(zns)
636
- # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
637
577
 
638
578
  def model4(self, xs, us, ys, embeds=None):
639
579
  pyro.module('PerturbFlow', self)
@@ -707,33 +647,24 @@ class PerturbFlow(nn.Module):
707
647
  else:
708
648
  zs = zns
709
649
 
710
- if self.loss_func == 'gamma-poisson':
711
- con_alpha,con_beta = self.decoder_concentrate(zs)
712
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
650
+ concentrate = self.decoder_concentrate(zs)
651
+ if self.loss_func in ['bernoulli','negbinomial']:
652
+ log_theta = concentrate
713
653
  else:
714
- concentrate = self.decoder_concentrate(zs)
715
- if self.loss_func == 'bernoulli':
716
- log_theta = concentrate
717
- else:
718
- rate = concentrate.exp()
719
- if self.loss_func != 'poisson':
720
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
654
+ rate = concentrate.exp()
655
+ if self.loss_func != 'poisson':
656
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
721
657
 
722
658
  if self.loss_func == 'negbinomial':
723
659
  if self.use_zeroinflate:
724
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
660
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
725
661
  else:
726
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
662
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
727
663
  elif self.loss_func == 'poisson':
728
664
  if self.use_zeroinflate:
729
665
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
730
666
  else:
731
667
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
732
- elif self.loss_func == 'gamma-poisson':
733
- if self.use_zeroinflate:
734
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
735
- else:
736
- pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
737
668
  elif self.loss_func == 'multinomial':
738
669
  pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
739
670
  elif self.loss_func == 'bernoulli':
@@ -750,10 +681,6 @@ class PerturbFlow(nn.Module):
750
681
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
751
682
  else:
752
683
  zns = embeds
753
-
754
- #if self.loss_func == 'gamma-poisson':
755
- # con_alpha,con_beta = self.encoder_concentrate(zns)
756
- # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
757
684
 
758
685
  def _total_effects(self, zns, us):
759
686
  zus = None
@@ -932,12 +859,7 @@ class PerturbFlow(nn.Module):
932
859
  return tensor_to_numpy(ms)
933
860
 
934
861
  def _get_expression_response(self, delta_zs):
935
- if self.loss_func == 'gamma-poisson':
936
- alpha,beta = self.decoder_concentrate(delta_zs)
937
- xs = dist.Gamma(alpha,beta).to_event(1).mean
938
- else:
939
- xs = self.decoder_concentrate(delta_zs)
940
- return xs
862
+ return self.decoder_concentrate(delta_zs)
941
863
 
942
864
  def get_expression_response(self,
943
865
  delta_zs,
@@ -966,16 +888,14 @@ class PerturbFlow(nn.Module):
966
888
  counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
967
889
  elif self.loss_func == 'negbinomial':
968
890
  #counts = concentrate.exp()
969
- rate = concentrate.exp()
970
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
891
+ #rate = concentrate.exp()
892
+ #theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
971
893
 
972
894
  total_count = pyro.param("inverse_dispersion")
973
- counts = dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1).mean
895
+ counts = dist.NegativeBinomial(total_count=total_count, logits=concentrate).to_event(1).mean
974
896
  elif self.loss_func == 'poisson':
975
897
  rate = concentrate.exp()
976
898
  counts = dist.Poisson(rate=rate).to_event(1).mean
977
- elif self.loss_func == 'gamma-poisson':
978
- counts = dist.Poisson(rate=concentrate).to_event(1).mean
979
899
  elif self.loss_func == 'multinomial':
980
900
  rate = concentrate.exp()
981
901
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.56
3
+ Version: 2.1.58
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,4 +1,4 @@
1
- SURE/PerturbFlow.py,sha256=CvnmX1QVo4UK4rkmFQd0RR9YHrNjL1GCHM1aj-BHVqM,59536
1
+ SURE/PerturbFlow.py,sha256=7vflCQ8mtX0jzDe5lEIVxF4zwgWJIJ9aEZ6lD1duv0E,54985
2
2
  SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
3
3
  SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
4
4
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
17
17
  SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
18
18
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
19
19
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
20
- sure_tools-2.1.56.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
- sure_tools-2.1.56.dist-info/METADATA,sha256=kWBC-87jEjWE-JHxXGcrerFaxT9G5buo8zwZhkDxu9o,2678
22
- sure_tools-2.1.56.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- sure_tools-2.1.56.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
- sure_tools-2.1.56.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
- sure_tools-2.1.56.dist-info/RECORD,,
20
+ sure_tools-2.1.58.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
+ sure_tools-2.1.58.dist-info/METADATA,sha256=IKFJkaArfqXoAjczpEKYZworTX83okZYw7Kf8Bx430Y,2678
22
+ sure_tools-2.1.58.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ sure_tools-2.1.58.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
+ sure_tools-2.1.58.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
+ sure_tools-2.1.58.dist-info/RECORD,,