SURE-tools 2.1.59__py3-none-any.whl → 2.1.60__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. Click here for more details.

SURE/PerturbFlow.py CHANGED
@@ -351,7 +351,7 @@ class PerturbFlow(nn.Module):
351
351
 
352
352
  zs = zns
353
353
  concentrate = self.decoder_concentrate(zs)
354
- if self.loss_func in ['bernoulli','negbinomial']:
354
+ if self.loss_func in ['bernoulli']:
355
355
  log_theta = concentrate
356
356
  else:
357
357
  rate = concentrate.exp()
@@ -361,9 +361,9 @@ class PerturbFlow(nn.Module):
361
361
 
362
362
  if self.loss_func == 'negbinomial':
363
363
  if self.use_zeroinflate:
364
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
364
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
365
365
  else:
366
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
366
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
367
367
  elif self.loss_func == 'poisson':
368
368
  if self.use_zeroinflate:
369
369
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
@@ -435,7 +435,7 @@ class PerturbFlow(nn.Module):
435
435
  zs = zns
436
436
 
437
437
  concentrate = self.decoder_concentrate(zs)
438
- if self.loss_func in ['bernoulli','negbinomial']:
438
+ if self.loss_func in ['bernoulli']:
439
439
  log_theta = concentrate
440
440
  else:
441
441
  rate = concentrate.exp()
@@ -445,9 +445,9 @@ class PerturbFlow(nn.Module):
445
445
 
446
446
  if self.loss_func == 'negbinomial':
447
447
  if self.use_zeroinflate:
448
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
448
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
449
449
  else:
450
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
450
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
451
451
  elif self.loss_func == 'poisson':
452
452
  if self.use_zeroinflate:
453
453
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
@@ -531,7 +531,7 @@ class PerturbFlow(nn.Module):
531
531
  zs = zns
532
532
 
533
533
  concentrate = self.decoder_concentrate(zs)
534
- if self.loss_func in ['bernoulli','negbinomial']:
534
+ if self.loss_func in ['bernoulli']:
535
535
  log_theta = concentrate
536
536
  else:
537
537
  rate = concentrate.exp()
@@ -541,9 +541,9 @@ class PerturbFlow(nn.Module):
541
541
 
542
542
  if self.loss_func == 'negbinomial':
543
543
  if self.use_zeroinflate:
544
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
544
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
545
545
  else:
546
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
546
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
547
547
  elif self.loss_func == 'poisson':
548
548
  if self.use_zeroinflate:
549
549
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
@@ -637,7 +637,7 @@ class PerturbFlow(nn.Module):
637
637
  zs = zns
638
638
 
639
639
  concentrate = self.decoder_concentrate(zs)
640
- if self.loss_func in ['bernoulli','negbinomial']:
640
+ if self.loss_func in ['bernoulli']:
641
641
  log_theta = concentrate
642
642
  else:
643
643
  rate = concentrate.exp()
@@ -647,9 +647,9 @@ class PerturbFlow(nn.Module):
647
647
 
648
648
  if self.loss_func == 'negbinomial':
649
649
  if self.use_zeroinflate:
650
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
650
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
651
651
  else:
652
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
652
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
653
653
  elif self.loss_func == 'poisson':
654
654
  if self.use_zeroinflate:
655
655
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.59
3
+ Version: 2.1.60
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,4 +1,4 @@
1
- SURE/PerturbFlow.py,sha256=LoWyN1gcgCPtGsa69wyBCA2oMYcqmK2Eq3__Xy0u-aQ,54454
1
+ SURE/PerturbFlow.py,sha256=CtM6FG6s6nn0VAG0-Ssi6yntcNxoWQaolXBH-JX-zjM,54358
2
2
  SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
3
3
  SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
4
4
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
17
17
  SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
18
18
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
19
19
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
20
- sure_tools-2.1.59.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
- sure_tools-2.1.59.dist-info/METADATA,sha256=6U4dnYcyfpZX3GOpU2cBG4EctQSK03XHxrCqAdiqIUg,2678
22
- sure_tools-2.1.59.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- sure_tools-2.1.59.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
- sure_tools-2.1.59.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
- sure_tools-2.1.59.dist-info/RECORD,,
20
+ sure_tools-2.1.60.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
+ sure_tools-2.1.60.dist-info/METADATA,sha256=g48MZci8NumHGcOhtuANmkpEfB19Et_EkxfDtFgFtB0,2678
22
+ sure_tools-2.1.60.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ sure_tools-2.1.60.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
+ sure_tools-2.1.60.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
+ sure_tools-2.1.60.dist-info/RECORD,,