SURE-tools 2.1.53__py3-none-any.whl → 2.1.55__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between those package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. Click here for more details.

SURE/PerturbFlow.py CHANGED
@@ -62,7 +62,7 @@ class PerturbFlow(nn.Module):
62
62
  supervised_mode: bool = False,
63
63
  z_dim: int = 10,
64
64
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
65
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
65
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli','gamma-poisson'] = 'negbinomial',
66
66
  inverse_dispersion: float = 10.0,
67
67
  use_zeroinflate: bool = False,
68
68
  hidden_layers: list = [300],
@@ -225,15 +225,36 @@ class PerturbFlow(nn.Module):
225
225
  )
226
226
  )
227
227
 
228
- self.decoder_concentrate = MLP(
229
- [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
230
- activation=activate_fct,
231
- output_activation=None,
232
- post_layer_fct=post_layer_fct,
233
- post_act_fct=post_act_fct,
234
- allow_broadcast=self.allow_broadcast,
235
- use_cuda=self.use_cuda,
236
- )
228
+ if self.loss_func == 'gamma-poisson':
229
+ self.decoder_concentrate = MLP(
230
+ [self.latent_dim] + self.decoder_hidden_layers + [[self.input_size,self.input_size]],
231
+ activation=activate_fct,
232
+ output_activation=[Exp,Exp],
233
+ post_layer_fct=post_layer_fct,
234
+ post_act_fct=post_act_fct,
235
+ allow_broadcast=self.allow_broadcast,
236
+ use_cuda=self.use_cuda,
237
+ )
238
+ #self.encoder_concentrate = MLP(
239
+ # [self.latent_dim] + self.decoder_hidden_layers + [[self.input_size,self.input_size]],
240
+ # activation=activate_fct,
241
+ # output_activation=[Exp,Exp],
242
+ # post_layer_fct=post_layer_fct,
243
+ # post_act_fct=post_act_fct,
244
+ # allow_broadcast=self.allow_broadcast,
245
+ # use_cuda=self.use_cuda,
246
+ # )
247
+ self.encoder_concentrate = self.decoder_concentrate
248
+ else:
249
+ self.decoder_concentrate = MLP(
250
+ [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
251
+ activation=activate_fct,
252
+ output_activation=None,
253
+ post_layer_fct=post_layer_fct,
254
+ post_act_fct=post_act_fct,
255
+ allow_broadcast=self.allow_broadcast,
256
+ use_cuda=self.use_cuda,
257
+ )
237
258
 
238
259
  if self.latent_dist == 'studentt':
239
260
  self.codebook = MLP(
@@ -352,13 +373,17 @@ class PerturbFlow(nn.Module):
352
373
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
353
374
 
354
375
  zs = zns
355
- concentrate = self.decoder_concentrate(zs)
356
- if self.loss_func == 'bernoulli':
357
- log_theta = concentrate
376
+ if self.loss_func == 'gamma-poisson':
377
+ con_alpha,con_beta = self.decoder_concentrate(zs)
378
+ rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
358
379
  else:
359
- rate = concentrate.exp()
360
- if self.loss_func != 'poisson':
361
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
380
+ concentrate = self.decoder_concentrate(zs)
381
+ if self.loss_func == 'bernoulli':
382
+ log_theta = concentrate
383
+ else:
384
+ rate = concentrate.exp()
385
+ if self.loss_func != 'poisson':
386
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
362
387
 
363
388
  if self.loss_func == 'negbinomial':
364
389
  if self.use_zeroinflate:
@@ -370,6 +395,11 @@ class PerturbFlow(nn.Module):
370
395
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
371
396
  else:
372
397
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
398
+ elif self.loss_func == 'gamma-poisson':
399
+ if self.use_zeroinflate:
400
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
401
+ else:
402
+ pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
373
403
  elif self.loss_func == 'multinomial':
374
404
  pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
375
405
  elif self.loss_func == 'bernoulli':
@@ -386,6 +416,10 @@ class PerturbFlow(nn.Module):
386
416
 
387
417
  alpha = self.encoder_n(zns)
388
418
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
419
+
420
+ if self.loss_func == 'gamma-poisson':
421
+ con_alpha,con_beta = self.encoder_concentrate(zns)
422
+ rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
389
423
 
390
424
  def model2(self, xs, us=None):
391
425
  pyro.module('PerturbFlow', self)
@@ -443,13 +477,17 @@ class PerturbFlow(nn.Module):
443
477
  else:
444
478
  zs = zns
445
479
 
446
- concentrate = self.decoder_concentrate(zs)
447
- if self.loss_func == 'bernoulli':
448
- log_theta = concentrate
480
+ if self.loss_func == 'gamma-poisson':
481
+ con_alpha,con_beta = self.decoder_concentrate(zs)
482
+ rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
449
483
  else:
450
- rate = concentrate.exp()
451
- if self.loss_func != 'poisson':
452
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
484
+ concentrate = self.decoder_concentrate(zs)
485
+ if self.loss_func == 'bernoulli':
486
+ log_theta = concentrate
487
+ else:
488
+ rate = concentrate.exp()
489
+ if self.loss_func != 'poisson':
490
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
453
491
 
454
492
  if self.loss_func == 'negbinomial':
455
493
  if self.use_zeroinflate:
@@ -461,6 +499,11 @@ class PerturbFlow(nn.Module):
461
499
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
462
500
  else:
463
501
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
502
+ elif self.loss_func == 'gamma-poisson':
503
+ if self.use_zeroinflate:
504
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
505
+ else:
506
+ pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
464
507
  elif self.loss_func == 'multinomial':
465
508
  pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
466
509
  elif self.loss_func == 'bernoulli':
@@ -477,6 +520,10 @@ class PerturbFlow(nn.Module):
477
520
 
478
521
  alpha = self.encoder_n(zns)
479
522
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
523
+
524
+ if self.loss_func == 'gamma-poisson':
525
+ con_alpha,con_beta = self.encoder_concentrate(zns)
526
+ rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
480
527
 
481
528
  def model3(self, xs, ys, embeds=None):
482
529
  pyro.module('PerturbFlow', self)
@@ -540,13 +587,17 @@ class PerturbFlow(nn.Module):
540
587
 
541
588
  zs = zns
542
589
 
543
- concentrate = self.decoder_concentrate(zs)
544
- if self.loss_func == 'bernoulli':
545
- log_theta = concentrate
590
+ if self.loss_func == 'gamma-poisson':
591
+ con_alpha,con_beta = self.decoder_concentrate(zs)
592
+ rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
546
593
  else:
547
- rate = concentrate.exp()
548
- if self.loss_func != 'poisson':
549
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
594
+ concentrate = self.decoder_concentrate(zs)
595
+ if self.loss_func == 'bernoulli':
596
+ log_theta = concentrate
597
+ else:
598
+ rate = concentrate.exp()
599
+ if self.loss_func != 'poisson':
600
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
550
601
 
551
602
  if self.loss_func == 'negbinomial':
552
603
  if self.use_zeroinflate:
@@ -558,6 +609,11 @@ class PerturbFlow(nn.Module):
558
609
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
559
610
  else:
560
611
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
612
+ elif self.loss_func == 'gamma-poisson':
613
+ if self.use_zeroinflate:
614
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
615
+ else:
616
+ pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
561
617
  elif self.loss_func == 'multinomial':
562
618
  pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
563
619
  elif self.loss_func == 'bernoulli':
@@ -572,6 +628,12 @@ class PerturbFlow(nn.Module):
572
628
  #zn_loc, zn_scale = self.encoder_zn(xs)
573
629
  zn_loc, zn_scale = self._get_basal_embedding(xs)
574
630
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
631
+ else:
632
+ zns = embeds
633
+
634
+ if self.loss_func == 'gamma-poisson':
635
+ con_alpha,con_beta = self.encoder_concentrate(zns)
636
+ rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
575
637
 
576
638
  def model4(self, xs, us, ys, embeds=None):
577
639
  pyro.module('PerturbFlow', self)
@@ -645,13 +707,17 @@ class PerturbFlow(nn.Module):
645
707
  else:
646
708
  zs = zns
647
709
 
648
- concentrate = self.decoder_concentrate(zs)
649
- if self.loss_func == 'bernoulli':
650
- log_theta = concentrate
710
+ if self.loss_func == 'gamma-poisson':
711
+ con_alpha,con_beta = self.decoder_concentrate(zs)
712
+ rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
651
713
  else:
652
- rate = concentrate.exp()
653
- if self.loss_func != 'poisson':
654
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
714
+ concentrate = self.decoder_concentrate(zs)
715
+ if self.loss_func == 'bernoulli':
716
+ log_theta = concentrate
717
+ else:
718
+ rate = concentrate.exp()
719
+ if self.loss_func != 'poisson':
720
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
655
721
 
656
722
  if self.loss_func == 'negbinomial':
657
723
  if self.use_zeroinflate:
@@ -663,6 +729,11 @@ class PerturbFlow(nn.Module):
663
729
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
664
730
  else:
665
731
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
732
+ elif self.loss_func == 'gamma-poisson':
733
+ if self.use_zeroinflate:
734
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
735
+ else:
736
+ pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
666
737
  elif self.loss_func == 'multinomial':
667
738
  pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
668
739
  elif self.loss_func == 'bernoulli':
@@ -677,6 +748,12 @@ class PerturbFlow(nn.Module):
677
748
  #zn_loc, zn_scale = self.encoder_zn(xs)
678
749
  zn_loc, zn_scale = self._get_basal_embedding(xs)
679
750
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
751
+ else:
752
+ zns = embeds
753
+
754
+ if self.loss_func == 'gamma-poisson':
755
+ con_alpha,con_beta = self.encoder_concentrate(zns)
756
+ rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
680
757
 
681
758
  def _total_effects(self, zns, us):
682
759
  zus = None
@@ -855,7 +932,12 @@ class PerturbFlow(nn.Module):
855
932
  return tensor_to_numpy(ms)
856
933
 
857
934
  def _get_expression_response(self, delta_zs):
858
- return self.decoder_concentrate(delta_zs)
935
+ if self.loss_func == 'gamma-poisson':
936
+ alpha,beta = self.encoder_concentrate(delta_zs)
937
+ xs = dist.Gamma(alpha,beta).to_event(1).mean
938
+ else:
939
+ xs = self.decoder_concentrate(delta_zs)
940
+ return xs
859
941
 
860
942
  def get_expression_response(self,
861
943
  delta_zs,
@@ -878,7 +960,7 @@ class PerturbFlow(nn.Module):
878
960
  R = np.concatenate(R)
879
961
  return R
880
962
 
881
- def _count(self,concentrate):
963
+ def _count(self,concentrate, library_size=None):
882
964
  if self.loss_func == 'bernoulli':
883
965
  #counts = self.sigmoid(concentrate)
884
966
  counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
@@ -888,10 +970,17 @@ class PerturbFlow(nn.Module):
888
970
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
889
971
 
890
972
  total_count = pyro.param("inverse_dispersion")
891
- counts = dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1)
973
+ counts = dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1).mean
892
974
  elif self.loss_func == 'poisson':
893
975
  rate = concentrate.exp()
894
- counts = dist.Poisson(rate=rate).to_event(1)
976
+ counts = dist.Poisson(rate=rate).to_event(1).mean
977
+ elif self.loss_func == 'gamma-poisson':
978
+ counts = dist.Poisson(rate=concentrate).to_event(1).mean
979
+ elif self.loss_func == 'multinomial':
980
+ rate = concentrate.exp()
981
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
982
+ counts = dist.Multinomial(total_count=int(1e8), probs=theta).mean
983
+ counts = counts * library_size
895
984
  return counts
896
985
 
897
986
  def _count_sample(self,concentrate):
@@ -903,22 +992,35 @@ class PerturbFlow(nn.Module):
903
992
  counts = dist.Poisson(rate=counts).to_event(1).sample()
904
993
  return counts
905
994
 
906
- def get_counts(self, zs,
995
+ def get_counts(self, zs, library_sizes = None,
907
996
  batch_size: int = 1024,
908
997
  use_sampler: bool = False):
909
998
 
910
999
  zs = convert_to_tensor(zs, device=self.get_device())
911
- dataset = CustomDataset(zs)
1000
+ ls = zs
1001
+
1002
+ if self.loss_func == 'multinomial':
1003
+ assert library_sizes!=None, 'Library sizes are required for multinomial!'
1004
+
1005
+ if type(library_sizes) == list:
1006
+ library_sizes = np.array(library_sizes).view(-1,1)
1007
+ elif len(library_sizes.shape)==1:
1008
+ library_sizes = library_sizes.view(-1,1)
1009
+ ls = convert_to_tensor(library_sizes, device=self.get_device)
1010
+
1011
+ dataset = CustomDataset2(zs,ls)
912
1012
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
913
1013
 
914
1014
  E = []
915
1015
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
916
- for Z_batch, _ in dataloader:
1016
+ for Z_batch, L_batch, _ in dataloader:
1017
+ if self.loss_func != 'multinomial':
1018
+ L_batch = None
917
1019
  concentrate = self._get_expression_response(Z_batch)
918
1020
  if use_sampler:
919
1021
  counts = self._count_sample(concentrate)
920
1022
  else:
921
- counts = self._count(concentrate)
1023
+ counts = self._count(concentrate, L_batch)
922
1024
  E.append(tensor_to_numpy(counts))
923
1025
  pbar.update(1)
924
1026
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.53
3
+ Version: 2.1.55
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,4 +1,4 @@
1
- SURE/PerturbFlow.py,sha256=JS0TguFFewNU6lwFLI0rtJsPUkDcHWFpN2USuBB1dL8,53827
1
+ SURE/PerturbFlow.py,sha256=0-hD4NFKd0zvh_kBOCeh9irAjJ5TuyD7djKJKDCZv6I,59523
2
2
  SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
3
3
  SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
4
4
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
17
17
  SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
18
18
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
19
19
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
20
- sure_tools-2.1.53.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
- sure_tools-2.1.53.dist-info/METADATA,sha256=wNhmVGxxzIeL38Nb2VXIAJC6zX_jK3SgFqnqCd56ajA,2678
22
- sure_tools-2.1.53.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- sure_tools-2.1.53.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
- sure_tools-2.1.53.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
- sure_tools-2.1.53.dist-info/RECORD,,
20
+ sure_tools-2.1.55.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
+ sure_tools-2.1.55.dist-info/METADATA,sha256=GmbQukuqLtfvGrGd0VCuzY5S396a2I-M08_hYkB9vB8,2678
22
+ sure_tools-2.1.55.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ sure_tools-2.1.55.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
+ sure_tools-2.1.55.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
+ sure_tools-2.1.55.dist-info/RECORD,,