SURE-tools 2.1.52__py3-none-any.whl → 2.1.54__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic; see the registry's advisory page for this release for more details.

SURE/PerturbFlow.py CHANGED
@@ -62,7 +62,7 @@ class PerturbFlow(nn.Module):
62
62
  supervised_mode: bool = False,
63
63
  z_dim: int = 10,
64
64
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
65
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
65
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli','gamma-poisson'] = 'negbinomial',
66
66
  inverse_dispersion: float = 10.0,
67
67
  use_zeroinflate: bool = False,
68
68
  hidden_layers: list = [300],
@@ -225,15 +225,36 @@ class PerturbFlow(nn.Module):
225
225
  )
226
226
  )
227
227
 
228
- self.decoder_concentrate = MLP(
229
- [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
230
- activation=activate_fct,
231
- output_activation=None,
232
- post_layer_fct=post_layer_fct,
233
- post_act_fct=post_act_fct,
234
- allow_broadcast=self.allow_broadcast,
235
- use_cuda=self.use_cuda,
236
- )
228
+ if self.loss_func == 'gamma-poisson':
229
+ self.decoder_concentrate = MLP(
230
+ [self.latent_dim] + self.decoder_hidden_layers + [[self.input_size,self.input_size]],
231
+ activation=activate_fct,
232
+ output_activation=[Exp,Exp],
233
+ post_layer_fct=post_layer_fct,
234
+ post_act_fct=post_act_fct,
235
+ allow_broadcast=self.allow_broadcast,
236
+ use_cuda=self.use_cuda,
237
+ )
238
+ #self.encoder_concentrate = MLP(
239
+ # [self.latent_dim] + self.decoder_hidden_layers + [[self.input_size,self.input_size]],
240
+ # activation=activate_fct,
241
+ # output_activation=[Exp,Exp],
242
+ # post_layer_fct=post_layer_fct,
243
+ # post_act_fct=post_act_fct,
244
+ # allow_broadcast=self.allow_broadcast,
245
+ # use_cuda=self.use_cuda,
246
+ # )
247
+ self.encoder_concentrate = self.decoder_concentrate
248
+ else:
249
+ self.decoder_concentrate = MLP(
250
+ [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
251
+ activation=activate_fct,
252
+ output_activation=None,
253
+ post_layer_fct=post_layer_fct,
254
+ post_act_fct=post_act_fct,
255
+ allow_broadcast=self.allow_broadcast,
256
+ use_cuda=self.use_cuda,
257
+ )
237
258
 
238
259
  if self.latent_dist == 'studentt':
239
260
  self.codebook = MLP(
@@ -317,12 +338,14 @@ class PerturbFlow(nn.Module):
317
338
  if self.loss_func=='negbinomial':
318
339
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
319
340
  xs.new_ones(self.input_size), constraint=constraints.positive)
341
+ elif self.loss_func == 'multinomial':
342
+ total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
320
343
 
321
344
  if self.use_zeroinflate:
322
345
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
323
346
 
324
347
  acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)
325
-
348
+
326
349
  I = torch.eye(self.code_size)
327
350
  if self.latent_dist=='studentt':
328
351
  acs_dof,acs_loc = self.codebook(I)
@@ -350,13 +373,17 @@ class PerturbFlow(nn.Module):
350
373
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
351
374
 
352
375
  zs = zns
353
- concentrate = self.decoder_concentrate(zs)
354
- if self.loss_func == 'bernoulli':
355
- log_theta = concentrate
376
+ if self.loss_func == 'gamma-poisson':
377
+ con_alpha,con_beta = self.decoder_concentrate(zs)
378
+ rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
356
379
  else:
357
- rate = concentrate.exp()
358
- if self.loss_func != 'poisson':
359
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
380
+ concentrate = self.decoder_concentrate(zs)
381
+ if self.loss_func == 'bernoulli':
382
+ log_theta = concentrate
383
+ else:
384
+ rate = concentrate.exp()
385
+ if self.loss_func != 'poisson':
386
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
360
387
 
361
388
  if self.loss_func == 'negbinomial':
362
389
  if self.use_zeroinflate:
@@ -368,8 +395,13 @@ class PerturbFlow(nn.Module):
368
395
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
369
396
  else:
370
397
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
398
+ elif self.loss_func == 'gamma-poisson':
399
+ if self.use_zeroinflate:
400
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
401
+ else:
402
+ pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
371
403
  elif self.loss_func == 'multinomial':
372
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
404
+ pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
373
405
  elif self.loss_func == 'bernoulli':
374
406
  if self.use_zeroinflate:
375
407
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -384,6 +416,10 @@ class PerturbFlow(nn.Module):
384
416
 
385
417
  alpha = self.encoder_n(zns)
386
418
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
419
+
420
+ if self.loss_func == 'gamma-poisson':
421
+ con_alpha,con_beta = self.encoder_concentrate(zns)
422
+ rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
387
423
 
388
424
  def model2(self, xs, us=None):
389
425
  pyro.module('PerturbFlow', self)
@@ -395,6 +431,8 @@ class PerturbFlow(nn.Module):
395
431
  if self.loss_func=='negbinomial':
396
432
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
397
433
  xs.new_ones(self.input_size), constraint=constraints.positive)
434
+ elif self.loss_func == 'multinomial':
435
+ total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
398
436
 
399
437
  if self.use_zeroinflate:
400
438
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -439,13 +477,17 @@ class PerturbFlow(nn.Module):
439
477
  else:
440
478
  zs = zns
441
479
 
442
- concentrate = self.decoder_concentrate(zs)
443
- if self.loss_func == 'bernoulli':
444
- log_theta = concentrate
480
+ if self.loss_func == 'gamma-poisson':
481
+ con_alpha,con_beta = self.decoder_concentrate(zs)
482
+ rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
445
483
  else:
446
- rate = concentrate.exp()
447
- if self.loss_func != 'poisson':
448
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
484
+ concentrate = self.decoder_concentrate(zs)
485
+ if self.loss_func == 'bernoulli':
486
+ log_theta = concentrate
487
+ else:
488
+ rate = concentrate.exp()
489
+ if self.loss_func != 'poisson':
490
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
449
491
 
450
492
  if self.loss_func == 'negbinomial':
451
493
  if self.use_zeroinflate:
@@ -457,8 +499,13 @@ class PerturbFlow(nn.Module):
457
499
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
458
500
  else:
459
501
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
502
+ elif self.loss_func == 'gamma-poisson':
503
+ if self.use_zeroinflate:
504
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
505
+ else:
506
+ pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
460
507
  elif self.loss_func == 'multinomial':
461
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
508
+ pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
462
509
  elif self.loss_func == 'bernoulli':
463
510
  if self.use_zeroinflate:
464
511
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -473,6 +520,10 @@ class PerturbFlow(nn.Module):
473
520
 
474
521
  alpha = self.encoder_n(zns)
475
522
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
523
+
524
+ if self.loss_func == 'gamma-poisson':
525
+ con_alpha,con_beta = self.encoder_concentrate(zns)
526
+ rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
476
527
 
477
528
  def model3(self, xs, ys, embeds=None):
478
529
  pyro.module('PerturbFlow', self)
@@ -484,6 +535,8 @@ class PerturbFlow(nn.Module):
484
535
  if self.loss_func=='negbinomial':
485
536
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
486
537
  xs.new_ones(self.input_size), constraint=constraints.positive)
538
+ elif self.loss_func == 'multinomial':
539
+ total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
487
540
 
488
541
  if self.use_zeroinflate:
489
542
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -534,13 +587,17 @@ class PerturbFlow(nn.Module):
534
587
 
535
588
  zs = zns
536
589
 
537
- concentrate = self.decoder_concentrate(zs)
538
- if self.loss_func == 'bernoulli':
539
- log_theta = concentrate
590
+ if self.loss_func == 'gamma-poisson':
591
+ con_alpha,con_beta = self.decoder_concentrate(zs)
592
+ rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
540
593
  else:
541
- rate = concentrate.exp()
542
- if self.loss_func != 'poisson':
543
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
594
+ concentrate = self.decoder_concentrate(zs)
595
+ if self.loss_func == 'bernoulli':
596
+ log_theta = concentrate
597
+ else:
598
+ rate = concentrate.exp()
599
+ if self.loss_func != 'poisson':
600
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
544
601
 
545
602
  if self.loss_func == 'negbinomial':
546
603
  if self.use_zeroinflate:
@@ -552,8 +609,13 @@ class PerturbFlow(nn.Module):
552
609
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
553
610
  else:
554
611
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
612
+ elif self.loss_func == 'gamma-poisson':
613
+ if self.use_zeroinflate:
614
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
615
+ else:
616
+ pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
555
617
  elif self.loss_func == 'multinomial':
556
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
618
+ pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
557
619
  elif self.loss_func == 'bernoulli':
558
620
  if self.use_zeroinflate:
559
621
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -566,6 +628,12 @@ class PerturbFlow(nn.Module):
566
628
  #zn_loc, zn_scale = self.encoder_zn(xs)
567
629
  zn_loc, zn_scale = self._get_basal_embedding(xs)
568
630
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
631
+ else:
632
+ zns = embeds
633
+
634
+ if self.loss_func == 'gamma-poisson':
635
+ con_alpha,con_beta = self.encoder_concentrate(zns)
636
+ rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
569
637
 
570
638
  def model4(self, xs, us, ys, embeds=None):
571
639
  pyro.module('PerturbFlow', self)
@@ -577,6 +645,8 @@ class PerturbFlow(nn.Module):
577
645
  if self.loss_func=='negbinomial':
578
646
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
579
647
  xs.new_ones(self.input_size), constraint=constraints.positive)
648
+ elif self.loss_func == 'multinomial':
649
+ total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
580
650
 
581
651
  if self.use_zeroinflate:
582
652
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -637,13 +707,17 @@ class PerturbFlow(nn.Module):
637
707
  else:
638
708
  zs = zns
639
709
 
640
- concentrate = self.decoder_concentrate(zs)
641
- if self.loss_func == 'bernoulli':
642
- log_theta = concentrate
710
+ if self.loss_func == 'gamma-poisson':
711
+ con_alpha,con_beta = self.decoder_concentrate(zs)
712
+ rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
643
713
  else:
644
- rate = concentrate.exp()
645
- if self.loss_func != 'poisson':
646
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
714
+ concentrate = self.decoder_concentrate(zs)
715
+ if self.loss_func == 'bernoulli':
716
+ log_theta = concentrate
717
+ else:
718
+ rate = concentrate.exp()
719
+ if self.loss_func != 'poisson':
720
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
647
721
 
648
722
  if self.loss_func == 'negbinomial':
649
723
  if self.use_zeroinflate:
@@ -655,8 +729,13 @@ class PerturbFlow(nn.Module):
655
729
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
656
730
  else:
657
731
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
732
+ elif self.loss_func == 'gamma-poisson':
733
+ if self.use_zeroinflate:
734
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
735
+ else:
736
+ pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
658
737
  elif self.loss_func == 'multinomial':
659
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
738
+ pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
660
739
  elif self.loss_func == 'bernoulli':
661
740
  if self.use_zeroinflate:
662
741
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -669,6 +748,12 @@ class PerturbFlow(nn.Module):
669
748
  #zn_loc, zn_scale = self.encoder_zn(xs)
670
749
  zn_loc, zn_scale = self._get_basal_embedding(xs)
671
750
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
751
+ else:
752
+ zns = embeds
753
+
754
+ if self.loss_func == 'gamma-poisson':
755
+ con_alpha,con_beta = self.encoder_concentrate(zns)
756
+ rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
672
757
 
673
758
  def _total_effects(self, zns, us):
674
759
  zus = None
@@ -847,7 +932,12 @@ class PerturbFlow(nn.Module):
847
932
  return tensor_to_numpy(ms)
848
933
 
849
934
  def _get_expression_response(self, delta_zs):
850
- return self.decoder_concentrate(delta_zs)
935
+ if self.loss_func == 'gamma-poisson':
936
+ alpha,beta = self.encoder_concentrate(delta_zs)
937
+ xs = dist.Gamma(alpha,beta).to_event(1).mean
938
+ else:
939
+ xs = self.decoder_concentrate(delta_zs)
940
+ return xs
851
941
 
852
942
  def get_expression_response(self,
853
943
  delta_zs,
@@ -880,10 +970,12 @@ class PerturbFlow(nn.Module):
880
970
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
881
971
 
882
972
  total_count = pyro.param("inverse_dispersion")
883
- counts = dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1)
973
+ counts = dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1).mean
884
974
  elif self.loss_func == 'poisson':
885
975
  rate = concentrate.exp()
886
- counts = dist.Poisson(rate=rate).to_event(1)
976
+ counts = dist.Poisson(rate=rate).to_event(1).mean
977
+ elif self.loss_func == 'gamma-poisson':
978
+ counts = dist.Poisson(rate=concentrate).to_event(1).mean
887
979
  return counts
888
980
 
889
981
  def _count_sample(self,concentrate):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.52
3
+ Version: 2.1.54
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,4 +1,4 @@
1
- SURE/PerturbFlow.py,sha256=yt0kOW4buZKpJQ3Jn_8Zd2uEKUq29DZYkSzcgEP55EA,53211
1
+ SURE/PerturbFlow.py,sha256=hOVEsBrMAs7T5yi3LW7KV6hwPuwyjZtKG2wyMF6R08E,58614
2
2
  SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
3
3
  SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
4
4
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
17
17
  SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
18
18
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
19
19
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
20
- sure_tools-2.1.52.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
- sure_tools-2.1.52.dist-info/METADATA,sha256=ARO36IQ9aKV9Sp4F9AkLR6zzXSsPnMGguljP7XU95Mk,2678
22
- sure_tools-2.1.52.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- sure_tools-2.1.52.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
- sure_tools-2.1.52.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
- sure_tools-2.1.52.dist-info/RECORD,,
20
+ sure_tools-2.1.54.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
+ sure_tools-2.1.54.dist-info/METADATA,sha256=VDqYvGzqSz_HeBiPxGgwwl_i-uunVvd3t0MVIa4n6iI,2678
22
+ sure_tools-2.1.54.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ sure_tools-2.1.54.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
+ sure_tools-2.1.54.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
+ sure_tools-2.1.54.dist-info/RECORD,,