SURE-tools 2.1.77__py3-none-any.whl → 2.1.79__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. See the registry's advisory page for this release for more details.

SURE/PerturbFlow.py CHANGED
@@ -104,7 +104,6 @@ class PerturbFlow(nn.Module):
104
104
  #self.use_bias = not zero_bias
105
105
 
106
106
  self.codebook_weights = None
107
- self.total_count = None
108
107
 
109
108
  set_random_seed(seed)
110
109
  self.setup_networks()
@@ -318,7 +317,6 @@ class PerturbFlow(nn.Module):
318
317
  if self.loss_func=='negbinomial':
319
318
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
320
319
  xs.new_ones(self.input_size), constraint=constraints.positive)
321
- self.total_count = total_count
322
320
 
323
321
  if self.use_zeroinflate:
324
322
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -398,7 +396,6 @@ class PerturbFlow(nn.Module):
398
396
  if self.loss_func=='negbinomial':
399
397
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
400
398
  xs.new_ones(self.input_size), constraint=constraints.positive)
401
- self.total_count = total_count
402
399
 
403
400
  if self.use_zeroinflate:
404
401
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -483,7 +480,6 @@ class PerturbFlow(nn.Module):
483
480
  if self.loss_func=='negbinomial':
484
481
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
485
482
  xs.new_ones(self.input_size), constraint=constraints.positive)
486
- self.total_count = total_count
487
483
 
488
484
  if self.use_zeroinflate:
489
485
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -580,7 +576,6 @@ class PerturbFlow(nn.Module):
580
576
  if self.loss_func=='negbinomial':
581
577
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
582
578
  xs.new_ones(self.input_size), constraint=constraints.positive)
583
- self.total_count = total_count
584
579
 
585
580
  if self.use_zeroinflate:
586
581
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -881,25 +876,11 @@ class PerturbFlow(nn.Module):
881
876
  if self.loss_func == 'bernoulli':
882
877
  #counts = self.sigmoid(concentrate)
883
878
  counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
884
- elif self.loss_func == 'negbinomial':
885
- rate = concentrate.exp()
886
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
887
-
888
- total_count = self.total_count
889
- #total_count = pyro.param("inverse_dispersion")
890
- #store = pyro.get_param_store()
891
- #total_count = store['inverse_dispersion']
892
- counts = dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1).mean
893
- elif self.loss_func == 'poisson':
879
+ else:
894
880
  rate = concentrate.exp()
895
881
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
896
882
  counts = theta * library_size
897
883
  #counts = dist.Poisson(rate=rate).to_event(1).mean
898
- elif self.loss_func == 'multinomial':
899
- rate = concentrate.exp()
900
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
901
- #counts = dist.Multinomial(total_count=int(1e8), probs=theta).mean
902
- counts = theta * library_size
903
884
  return counts
904
885
 
905
886
  def _count_sample(self,concentrate):
@@ -911,22 +892,17 @@ class PerturbFlow(nn.Module):
911
892
  counts = dist.Poisson(rate=counts).to_event(1).sample()
912
893
  return counts
913
894
 
914
- def get_counts(self, zs, library_sizes = None,
895
+ def get_counts(self, zs, library_sizes,
915
896
  batch_size: int = 1024,
916
897
  use_sampler: bool = False):
917
898
 
918
899
  zs = convert_to_tensor(zs, device=self.get_device())
919
900
 
920
- if self.loss_func in ['multinomial','poisson']:
921
- assert library_sizes is not None, 'Library sizes are required for multinomial!'
922
-
923
- if type(library_sizes) == list:
924
- library_sizes = np.array(library_sizes).view(-1,1)
925
- elif len(library_sizes.shape)==1:
926
- library_sizes = library_sizes.view(-1,1)
927
- ls = convert_to_tensor(library_sizes, device=self.get_device())
928
- else:
929
- ls = zs
901
+ if type(library_sizes) == list:
902
+ library_sizes = np.array(library_sizes).view(-1,1)
903
+ elif len(library_sizes.shape)==1:
904
+ library_sizes = library_sizes.view(-1,1)
905
+ ls = convert_to_tensor(library_sizes, device=self.get_device())
930
906
 
931
907
  dataset = CustomDataset2(zs,ls)
932
908
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
@@ -1084,9 +1060,6 @@ class PerturbFlow(nn.Module):
1084
1060
  pbar.set_postfix({'loss': str_loss})
1085
1061
  pbar.update(1)
1086
1062
 
1087
- if self.loss_func == 'negbinomial':
1088
- self.total_count = pyro.param('inverse_dispersion')
1089
-
1090
1063
  @classmethod
1091
1064
  def save_model(cls, model, file_path, compression=False):
1092
1065
  """Save the model to the specified file path."""
SURE/SURE.py CHANGED
@@ -369,12 +369,13 @@ class SURE(nn.Module):
369
369
 
370
370
  zs = zns
371
371
  concentrate = self.decoder_concentrate(zs)
372
- if self.loss_func == 'bernoulli':
372
+ if self.loss_func in ['bernoulli']:
373
373
  log_theta = concentrate
374
374
  else:
375
375
  rate = concentrate.exp()
376
- if self.loss_func != 'poisson':
377
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
376
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
377
+ if self.loss_func == 'poisson':
378
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
378
379
 
379
380
  if self.loss_func == 'negbinomial':
380
381
  if self.use_zeroinflate:
@@ -451,12 +452,13 @@ class SURE(nn.Module):
451
452
  zs = zns
452
453
 
453
454
  concentrate = self.decoder_concentrate(zs)
454
- if self.loss_func == 'bernoulli':
455
+ if self.loss_func in ['bernoulli']:
455
456
  log_theta = concentrate
456
457
  else:
457
458
  rate = concentrate.exp()
458
- if self.loss_func != 'poisson':
459
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
459
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
460
+ if self.loss_func == 'poisson':
461
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
460
462
 
461
463
  if self.loss_func == 'negbinomial':
462
464
  if self.use_zeroinflate:
@@ -545,12 +547,13 @@ class SURE(nn.Module):
545
547
  zs = zns
546
548
 
547
549
  concentrate = self.decoder_concentrate(zs)
548
- if self.loss_func == 'bernoulli':
550
+ if self.loss_func in ['bernoulli']:
549
551
  log_theta = concentrate
550
552
  else:
551
553
  rate = concentrate.exp()
552
- if self.loss_func != 'poisson':
553
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
554
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
555
+ if self.loss_func == 'poisson':
556
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
554
557
 
555
558
  if self.loss_func == 'negbinomial':
556
559
  if self.use_zeroinflate:
@@ -641,13 +644,14 @@ class SURE(nn.Module):
641
644
  zs = zns
642
645
 
643
646
  concentrate = self.decoder_concentrate(zs)
644
- if self.loss_func == 'bernoulli':
647
+ if self.loss_func in ['bernoulli']:
645
648
  log_theta = concentrate
646
649
  else:
647
650
  rate = concentrate.exp()
648
- if self.loss_func != 'poisson':
649
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
650
-
651
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
652
+ if self.loss_func == 'poisson':
653
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
654
+
651
655
  if self.loss_func == 'negbinomial':
652
656
  if self.use_zeroinflate:
653
657
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.77
3
+ Version: 2.1.79
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,5 +1,5 @@
1
- SURE/PerturbFlow.py,sha256=lRBSHIvJiAeyOyRDKukQsHpGxMU9q_jNEYKYcyyQtIc,54717
2
- SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
1
+ SURE/PerturbFlow.py,sha256=jvN2CD2L7YX6EiqP0Zw--zwegwHDiGkQYL2v_So5YYA,53382
2
+ SURE/SURE.py,sha256=g8EhovBxjfpbVJA0AkmVkQ_ZW_JFc8TtkTCg8FCybV4,47750
3
3
  SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
4
4
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
5
5
  SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
17
17
  SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
18
18
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
19
19
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
20
- sure_tools-2.1.77.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
- sure_tools-2.1.77.dist-info/METADATA,sha256=R2y9rOLrZXegdYLXXdqUMiIc0wYi-w1cjFO198SnkZo,2678
22
- sure_tools-2.1.77.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- sure_tools-2.1.77.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
- sure_tools-2.1.77.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
- sure_tools-2.1.77.dist-info/RECORD,,
20
+ sure_tools-2.1.79.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
+ sure_tools-2.1.79.dist-info/METADATA,sha256=B8i3DryW1NcFrOvzEROG5YDlZG14RHa_q2cVbNVUwmg,2678
22
+ sure_tools-2.1.79.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ sure_tools-2.1.79.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
+ sure_tools-2.1.79.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
+ sure_tools-2.1.79.dist-info/RECORD,,