SURE-tools 2.1.58__py3-none-any.whl → 2.1.60__py3-none-any.whl

This diff shows the differences between the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. Click here for more details.

SURE/PerturbFlow.py CHANGED
@@ -62,7 +62,7 @@ class PerturbFlow(nn.Module):
62
62
  supervised_mode: bool = False,
63
63
  z_dim: int = 10,
64
64
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
65
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
65
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
66
66
  inverse_dispersion: float = 10.0,
67
67
  use_zeroinflate: bool = False,
68
68
  hidden_layers: list = [300],
@@ -317,8 +317,6 @@ class PerturbFlow(nn.Module):
317
317
  if self.loss_func=='negbinomial':
318
318
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
319
319
  xs.new_ones(self.input_size), constraint=constraints.positive)
320
- elif self.loss_func == 'multinomial':
321
- total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
322
320
 
323
321
  if self.use_zeroinflate:
324
322
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -353,25 +351,26 @@ class PerturbFlow(nn.Module):
353
351
 
354
352
  zs = zns
355
353
  concentrate = self.decoder_concentrate(zs)
356
- if self.loss_func in ['bernoulli','negbinomial']:
354
+ if self.loss_func in ['bernoulli']:
357
355
  log_theta = concentrate
358
356
  else:
359
357
  rate = concentrate.exp()
360
- if self.loss_func != 'poisson':
361
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
358
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
359
+ if self.loss_func == 'poisson':
360
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
362
361
 
363
362
  if self.loss_func == 'negbinomial':
364
363
  if self.use_zeroinflate:
365
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
364
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
366
365
  else:
367
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
366
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
368
367
  elif self.loss_func == 'poisson':
369
368
  if self.use_zeroinflate:
370
369
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
371
370
  else:
372
371
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
373
372
  elif self.loss_func == 'multinomial':
374
- pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
373
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
375
374
  elif self.loss_func == 'bernoulli':
376
375
  if self.use_zeroinflate:
377
376
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -397,8 +396,6 @@ class PerturbFlow(nn.Module):
397
396
  if self.loss_func=='negbinomial':
398
397
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
399
398
  xs.new_ones(self.input_size), constraint=constraints.positive)
400
- elif self.loss_func == 'multinomial':
401
- total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
402
399
 
403
400
  if self.use_zeroinflate:
404
401
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -432,37 +429,32 @@ class PerturbFlow(nn.Module):
432
429
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
433
430
 
434
431
  if self.cell_factor_size>0:
435
- #zus = None
436
- #for i in np.arange(self.cell_factor_size):
437
- # if i==0:
438
- # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
439
- # else:
440
- # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
441
432
  zus = self._total_effects(zns, us)
442
433
  zs = zns+zus
443
434
  else:
444
435
  zs = zns
445
436
 
446
437
  concentrate = self.decoder_concentrate(zs)
447
- if self.loss_func in ['bernoulli','negbinomial']:
438
+ if self.loss_func in ['bernoulli']:
448
439
  log_theta = concentrate
449
440
  else:
450
441
  rate = concentrate.exp()
451
- if self.loss_func != 'poisson':
452
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
442
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
443
+ if self.loss_func == 'poisson':
444
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
453
445
 
454
446
  if self.loss_func == 'negbinomial':
455
447
  if self.use_zeroinflate:
456
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
448
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
457
449
  else:
458
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
450
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
459
451
  elif self.loss_func == 'poisson':
460
452
  if self.use_zeroinflate:
461
453
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
462
454
  else:
463
455
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
464
456
  elif self.loss_func == 'multinomial':
465
- pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
457
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
466
458
  elif self.loss_func == 'bernoulli':
467
459
  if self.use_zeroinflate:
468
460
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -488,8 +480,6 @@ class PerturbFlow(nn.Module):
488
480
  if self.loss_func=='negbinomial':
489
481
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
490
482
  xs.new_ones(self.input_size), constraint=constraints.positive)
491
- elif self.loss_func == 'multinomial':
492
- total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
493
483
 
494
484
  if self.use_zeroinflate:
495
485
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -541,25 +531,26 @@ class PerturbFlow(nn.Module):
541
531
  zs = zns
542
532
 
543
533
  concentrate = self.decoder_concentrate(zs)
544
- if self.loss_func in ['bernoulli','negbinomial']:
534
+ if self.loss_func in ['bernoulli']:
545
535
  log_theta = concentrate
546
536
  else:
547
537
  rate = concentrate.exp()
548
- if self.loss_func != 'poisson':
549
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
538
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
539
+ if self.loss_func == 'poisson':
540
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
550
541
 
551
542
  if self.loss_func == 'negbinomial':
552
543
  if self.use_zeroinflate:
553
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
544
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
554
545
  else:
555
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
546
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
556
547
  elif self.loss_func == 'poisson':
557
548
  if self.use_zeroinflate:
558
549
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
559
550
  else:
560
551
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
561
552
  elif self.loss_func == 'multinomial':
562
- pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
553
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
563
554
  elif self.loss_func == 'bernoulli':
564
555
  if self.use_zeroinflate:
565
556
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -585,8 +576,6 @@ class PerturbFlow(nn.Module):
585
576
  if self.loss_func=='negbinomial':
586
577
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
587
578
  xs.new_ones(self.input_size), constraint=constraints.positive)
588
- elif self.loss_func == 'multinomial':
589
- total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
590
579
 
591
580
  if self.use_zeroinflate:
592
581
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -648,25 +637,26 @@ class PerturbFlow(nn.Module):
648
637
  zs = zns
649
638
 
650
639
  concentrate = self.decoder_concentrate(zs)
651
- if self.loss_func in ['bernoulli','negbinomial']:
640
+ if self.loss_func in ['bernoulli']:
652
641
  log_theta = concentrate
653
642
  else:
654
643
  rate = concentrate.exp()
655
- if self.loss_func != 'poisson':
656
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
644
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
645
+ if self.loss_func == 'poisson':
646
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
657
647
 
658
648
  if self.loss_func == 'negbinomial':
659
649
  if self.use_zeroinflate:
660
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
650
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
661
651
  else:
662
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
652
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
663
653
  elif self.loss_func == 'poisson':
664
654
  if self.use_zeroinflate:
665
655
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
666
656
  else:
667
657
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
668
658
  elif self.loss_func == 'multinomial':
669
- pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
659
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
670
660
  elif self.loss_func == 'bernoulli':
671
661
  if self.use_zeroinflate:
672
662
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -895,6 +885,8 @@ class PerturbFlow(nn.Module):
895
885
  counts = dist.NegativeBinomial(total_count=total_count, logits=concentrate).to_event(1).mean
896
886
  elif self.loss_func == 'poisson':
897
887
  rate = concentrate.exp()
888
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
889
+ rate = theta * library_size
898
890
  counts = dist.Poisson(rate=rate).to_event(1).mean
899
891
  elif self.loss_func == 'multinomial':
900
892
  rate = concentrate.exp()
@@ -919,7 +911,7 @@ class PerturbFlow(nn.Module):
919
911
  zs = convert_to_tensor(zs, device=self.get_device())
920
912
  ls = zs
921
913
 
922
- if self.loss_func == 'multinomial':
914
+ if self.loss_func in ['multinomial','poisson']:
923
915
  assert library_sizes!=None, 'Library sizes are required for multinomial!'
924
916
 
925
917
  if type(library_sizes) == list:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.58
3
+ Version: 2.1.60
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,4 +1,4 @@
1
- SURE/PerturbFlow.py,sha256=7vflCQ8mtX0jzDe5lEIVxF4zwgWJIJ9aEZ6lD1duv0E,54985
1
+ SURE/PerturbFlow.py,sha256=CtM6FG6s6nn0VAG0-Ssi6yntcNxoWQaolXBH-JX-zjM,54358
2
2
  SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
3
3
  SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
4
4
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
17
17
  SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
18
18
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
19
19
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
20
- sure_tools-2.1.58.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
- sure_tools-2.1.58.dist-info/METADATA,sha256=IKFJkaArfqXoAjczpEKYZworTX83okZYw7Kf8Bx430Y,2678
22
- sure_tools-2.1.58.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- sure_tools-2.1.58.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
- sure_tools-2.1.58.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
- sure_tools-2.1.58.dist-info/RECORD,,
20
+ sure_tools-2.1.60.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
+ sure_tools-2.1.60.dist-info/METADATA,sha256=g48MZci8NumHGcOhtuANmkpEfB19Et_EkxfDtFgFtB0,2678
22
+ sure_tools-2.1.60.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ sure_tools-2.1.60.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
+ sure_tools-2.1.60.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
+ sure_tools-2.1.60.dist-info/RECORD,,