SURE-tools 2.1.57__py3-none-any.whl → 2.1.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
SURE/PerturbFlow.py CHANGED
@@ -62,7 +62,7 @@ class PerturbFlow(nn.Module):
62
62
  supervised_mode: bool = False,
63
63
  z_dim: int = 10,
64
64
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
65
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli','gamma-poisson'] = 'negbinomial',
65
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
66
66
  inverse_dispersion: float = 10.0,
67
67
  use_zeroinflate: bool = False,
68
68
  hidden_layers: list = [300],
@@ -225,28 +225,7 @@ class PerturbFlow(nn.Module):
225
225
  )
226
226
  )
227
227
 
228
- if self.loss_func == 'gamma-poisson':
229
- self.decoder_concentrate = MLP(
230
- [self.latent_dim] + self.decoder_hidden_layers + [[self.input_size,self.input_size]],
231
- activation=activate_fct,
232
- output_activation=[Exp,Exp],
233
- post_layer_fct=post_layer_fct,
234
- post_act_fct=post_act_fct,
235
- allow_broadcast=self.allow_broadcast,
236
- use_cuda=self.use_cuda,
237
- )
238
- #self.encoder_concentrate = MLP(
239
- # [self.latent_dim] + self.decoder_hidden_layers + [[self.input_size,self.input_size]],
240
- # activation=activate_fct,
241
- # output_activation=[Exp,Exp],
242
- # post_layer_fct=post_layer_fct,
243
- # post_act_fct=post_act_fct,
244
- # allow_broadcast=self.allow_broadcast,
245
- # use_cuda=self.use_cuda,
246
- # )
247
- #self.encoder_concentrate = self.decoder_concentrate
248
- else:
249
- self.decoder_concentrate = MLP(
228
+ self.decoder_concentrate = MLP(
250
229
  [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
251
230
  activation=activate_fct,
252
231
  output_activation=None,
@@ -373,19 +352,13 @@ class PerturbFlow(nn.Module):
373
352
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
374
353
 
375
354
  zs = zns
376
- if self.loss_func == 'gamma-poisson':
377
- con_alpha,con_beta = self.decoder_concentrate(zs)
378
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
355
+ concentrate = self.decoder_concentrate(zs)
356
+ if self.loss_func in ['bernoulli','negbinomial']:
357
+ log_theta = concentrate
379
358
  else:
380
- concentrate = self.decoder_concentrate(zs)
381
- if self.loss_func == 'bernoulli':
382
- log_theta = concentrate
383
- elif self.loss_func == 'negbinomial':
384
- log_theta = concentrate
385
- else:
386
- rate = concentrate.exp()
387
- if self.loss_func != 'poisson':
388
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
359
+ rate = concentrate.exp()
360
+ if self.loss_func != 'poisson':
361
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
389
362
 
390
363
  if self.loss_func == 'negbinomial':
391
364
  if self.use_zeroinflate:
@@ -397,11 +370,6 @@ class PerturbFlow(nn.Module):
397
370
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
398
371
  else:
399
372
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
400
- elif self.loss_func == 'gamma-poisson':
401
- if self.use_zeroinflate:
402
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
403
- else:
404
- pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
405
373
  elif self.loss_func == 'multinomial':
406
374
  pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
407
375
  elif self.loss_func == 'bernoulli':
@@ -418,10 +386,6 @@ class PerturbFlow(nn.Module):
418
386
 
419
387
  alpha = self.encoder_n(zns)
420
388
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
421
-
422
- #if self.loss_func == 'gamma-poisson':
423
- # con_alpha,con_beta = self.encoder_concentrate(zns)
424
- # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
425
389
 
426
390
  def model2(self, xs, us=None):
427
391
  pyro.module('PerturbFlow', self)
@@ -479,19 +443,13 @@ class PerturbFlow(nn.Module):
479
443
  else:
480
444
  zs = zns
481
445
 
482
- if self.loss_func == 'gamma-poisson':
483
- con_alpha,con_beta = self.decoder_concentrate(zs)
484
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
446
+ concentrate = self.decoder_concentrate(zs)
447
+ if self.loss_func in ['bernoulli','negbinomial']:
448
+ log_theta = concentrate
485
449
  else:
486
- concentrate = self.decoder_concentrate(zs)
487
- if self.loss_func == 'bernoulli':
488
- log_theta = concentrate
489
- elif self.loss_func == 'negbinomial':
490
- log_theta = concentrate
491
- else:
492
- rate = concentrate.exp()
493
- if self.loss_func != 'poisson':
494
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
450
+ rate = concentrate.exp()
451
+ if self.loss_func != 'poisson':
452
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
495
453
 
496
454
  if self.loss_func == 'negbinomial':
497
455
  if self.use_zeroinflate:
@@ -503,11 +461,6 @@ class PerturbFlow(nn.Module):
503
461
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
504
462
  else:
505
463
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
506
- elif self.loss_func == 'gamma-poisson':
507
- if self.use_zeroinflate:
508
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
509
- else:
510
- pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
511
464
  elif self.loss_func == 'multinomial':
512
465
  pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
513
466
  elif self.loss_func == 'bernoulli':
@@ -524,10 +477,6 @@ class PerturbFlow(nn.Module):
524
477
 
525
478
  alpha = self.encoder_n(zns)
526
479
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
527
-
528
- #if self.loss_func == 'gamma-poisson':
529
- # con_alpha,con_beta = self.encoder_concentrate(zns)
530
- # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
531
480
 
532
481
  def model3(self, xs, ys, embeds=None):
533
482
  pyro.module('PerturbFlow', self)
@@ -591,19 +540,13 @@ class PerturbFlow(nn.Module):
591
540
 
592
541
  zs = zns
593
542
 
594
- if self.loss_func == 'gamma-poisson':
595
- con_alpha,con_beta = self.decoder_concentrate(zs)
596
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
543
+ concentrate = self.decoder_concentrate(zs)
544
+ if self.loss_func in ['bernoulli','negbinomial']:
545
+ log_theta = concentrate
597
546
  else:
598
- concentrate = self.decoder_concentrate(zs)
599
- if self.loss_func == 'bernoulli':
600
- log_theta = concentrate
601
- elif self.loss_func == 'negbinomial':
602
- log_theta = concentrate
603
- else:
604
- rate = concentrate.exp()
605
- if self.loss_func != 'poisson':
606
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
547
+ rate = concentrate.exp()
548
+ if self.loss_func != 'poisson':
549
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
607
550
 
608
551
  if self.loss_func == 'negbinomial':
609
552
  if self.use_zeroinflate:
@@ -615,11 +558,6 @@ class PerturbFlow(nn.Module):
615
558
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
616
559
  else:
617
560
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
618
- elif self.loss_func == 'gamma-poisson':
619
- if self.use_zeroinflate:
620
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
621
- else:
622
- pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
623
561
  elif self.loss_func == 'multinomial':
624
562
  pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
625
563
  elif self.loss_func == 'bernoulli':
@@ -636,10 +574,6 @@ class PerturbFlow(nn.Module):
636
574
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
637
575
  else:
638
576
  zns = embeds
639
-
640
- #if self.loss_func == 'gamma-poisson':
641
- # con_alpha,con_beta = self.encoder_concentrate(zns)
642
- # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
643
577
 
644
578
  def model4(self, xs, us, ys, embeds=None):
645
579
  pyro.module('PerturbFlow', self)
@@ -713,19 +647,13 @@ class PerturbFlow(nn.Module):
713
647
  else:
714
648
  zs = zns
715
649
 
716
- if self.loss_func == 'gamma-poisson':
717
- con_alpha,con_beta = self.decoder_concentrate(zs)
718
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
650
+ concentrate = self.decoder_concentrate(zs)
651
+ if self.loss_func in ['bernoulli','negbinomial']:
652
+ log_theta = concentrate
719
653
  else:
720
- concentrate = self.decoder_concentrate(zs)
721
- if self.loss_func == 'bernoulli':
722
- log_theta = concentrate
723
- elif self.loss_func == 'negbinomial':
724
- log_theta = concentrate
725
- else:
726
- rate = concentrate.exp()
727
- if self.loss_func != 'poisson':
728
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
654
+ rate = concentrate.exp()
655
+ if self.loss_func != 'poisson':
656
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
729
657
 
730
658
  if self.loss_func == 'negbinomial':
731
659
  if self.use_zeroinflate:
@@ -737,11 +665,6 @@ class PerturbFlow(nn.Module):
737
665
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
738
666
  else:
739
667
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
740
- elif self.loss_func == 'gamma-poisson':
741
- if self.use_zeroinflate:
742
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
743
- else:
744
- pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
745
668
  elif self.loss_func == 'multinomial':
746
669
  pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
747
670
  elif self.loss_func == 'bernoulli':
@@ -758,10 +681,6 @@ class PerturbFlow(nn.Module):
758
681
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
759
682
  else:
760
683
  zns = embeds
761
-
762
- #if self.loss_func == 'gamma-poisson':
763
- # con_alpha,con_beta = self.encoder_concentrate(zns)
764
- # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
765
684
 
766
685
  def _total_effects(self, zns, us):
767
686
  zus = None
@@ -940,12 +859,7 @@ class PerturbFlow(nn.Module):
940
859
  return tensor_to_numpy(ms)
941
860
 
942
861
  def _get_expression_response(self, delta_zs):
943
- if self.loss_func == 'gamma-poisson':
944
- alpha,beta = self.decoder_concentrate(delta_zs)
945
- xs = dist.Gamma(alpha,beta).to_event(1).mean
946
- else:
947
- xs = self.decoder_concentrate(delta_zs)
948
- return xs
862
+ return self.decoder_concentrate(delta_zs)
949
863
 
950
864
  def get_expression_response(self,
951
865
  delta_zs,
@@ -982,8 +896,6 @@ class PerturbFlow(nn.Module):
982
896
  elif self.loss_func == 'poisson':
983
897
  rate = concentrate.exp()
984
898
  counts = dist.Poisson(rate=rate).to_event(1).mean
985
- elif self.loss_func == 'gamma-poisson':
986
- counts = dist.Poisson(rate=concentrate).to_event(1).mean
987
899
  elif self.loss_func == 'multinomial':
988
900
  rate = concentrate.exp()
989
901
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.57
3
+ Version: 2.1.58
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,4 +1,4 @@
1
- SURE/PerturbFlow.py,sha256=emkyhDc99eTJQNkMdsHCp6VPg6468CRkc8lRHyA4P4o,59977
1
+ SURE/PerturbFlow.py,sha256=7vflCQ8mtX0jzDe5lEIVxF4zwgWJIJ9aEZ6lD1duv0E,54985
2
2
  SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
3
3
  SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
4
4
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
17
17
  SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
18
18
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
19
19
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
20
- sure_tools-2.1.57.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
- sure_tools-2.1.57.dist-info/METADATA,sha256=Y1npoz3fb9597vOLCGFK4__9N85QgzBnX_zUra5E1Fg,2678
22
- sure_tools-2.1.57.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- sure_tools-2.1.57.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
- sure_tools-2.1.57.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
- sure_tools-2.1.57.dist-info/RECORD,,
20
+ sure_tools-2.1.58.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
+ sure_tools-2.1.58.dist-info/METADATA,sha256=IKFJkaArfqXoAjczpEKYZworTX83okZYw7Kf8Bx430Y,2678
22
+ sure_tools-2.1.58.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ sure_tools-2.1.58.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
+ sure_tools-2.1.58.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
+ sure_tools-2.1.58.dist-info/RECORD,,