SURE-tools 2.1.77__py3-none-any.whl → 2.1.79__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/PerturbFlow.py +7 -34
- SURE/SURE.py +17 -13
- {sure_tools-2.1.77.dist-info → sure_tools-2.1.79.dist-info}/METADATA +1 -1
- {sure_tools-2.1.77.dist-info → sure_tools-2.1.79.dist-info}/RECORD +8 -8
- {sure_tools-2.1.77.dist-info → sure_tools-2.1.79.dist-info}/WHEEL +0 -0
- {sure_tools-2.1.77.dist-info → sure_tools-2.1.79.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.1.77.dist-info → sure_tools-2.1.79.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.1.77.dist-info → sure_tools-2.1.79.dist-info}/top_level.txt +0 -0
SURE/PerturbFlow.py
CHANGED
|
@@ -104,7 +104,6 @@ class PerturbFlow(nn.Module):
|
|
|
104
104
|
#self.use_bias = not zero_bias
|
|
105
105
|
|
|
106
106
|
self.codebook_weights = None
|
|
107
|
-
self.total_count = None
|
|
108
107
|
|
|
109
108
|
set_random_seed(seed)
|
|
110
109
|
self.setup_networks()
|
|
@@ -318,7 +317,6 @@ class PerturbFlow(nn.Module):
|
|
|
318
317
|
if self.loss_func=='negbinomial':
|
|
319
318
|
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
320
319
|
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
321
|
-
self.total_count = total_count
|
|
322
320
|
|
|
323
321
|
if self.use_zeroinflate:
|
|
324
322
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -398,7 +396,6 @@ class PerturbFlow(nn.Module):
|
|
|
398
396
|
if self.loss_func=='negbinomial':
|
|
399
397
|
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
400
398
|
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
401
|
-
self.total_count = total_count
|
|
402
399
|
|
|
403
400
|
if self.use_zeroinflate:
|
|
404
401
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -483,7 +480,6 @@ class PerturbFlow(nn.Module):
|
|
|
483
480
|
if self.loss_func=='negbinomial':
|
|
484
481
|
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
485
482
|
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
486
|
-
self.total_count = total_count
|
|
487
483
|
|
|
488
484
|
if self.use_zeroinflate:
|
|
489
485
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -580,7 +576,6 @@ class PerturbFlow(nn.Module):
|
|
|
580
576
|
if self.loss_func=='negbinomial':
|
|
581
577
|
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
582
578
|
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
583
|
-
self.total_count = total_count
|
|
584
579
|
|
|
585
580
|
if self.use_zeroinflate:
|
|
586
581
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -881,25 +876,11 @@ class PerturbFlow(nn.Module):
|
|
|
881
876
|
if self.loss_func == 'bernoulli':
|
|
882
877
|
#counts = self.sigmoid(concentrate)
|
|
883
878
|
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
884
|
-
|
|
885
|
-
rate = concentrate.exp()
|
|
886
|
-
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
887
|
-
|
|
888
|
-
total_count = self.total_count
|
|
889
|
-
#total_count = pyro.param("inverse_dispersion")
|
|
890
|
-
#store = pyro.get_param_store()
|
|
891
|
-
#total_count = store['inverse_dispersion']
|
|
892
|
-
counts = dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1).mean
|
|
893
|
-
elif self.loss_func == 'poisson':
|
|
879
|
+
else:
|
|
894
880
|
rate = concentrate.exp()
|
|
895
881
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
896
882
|
counts = theta * library_size
|
|
897
883
|
#counts = dist.Poisson(rate=rate).to_event(1).mean
|
|
898
|
-
elif self.loss_func == 'multinomial':
|
|
899
|
-
rate = concentrate.exp()
|
|
900
|
-
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
901
|
-
#counts = dist.Multinomial(total_count=int(1e8), probs=theta).mean
|
|
902
|
-
counts = theta * library_size
|
|
903
884
|
return counts
|
|
904
885
|
|
|
905
886
|
def _count_sample(self,concentrate):
|
|
@@ -911,22 +892,17 @@ class PerturbFlow(nn.Module):
|
|
|
911
892
|
counts = dist.Poisson(rate=counts).to_event(1).sample()
|
|
912
893
|
return counts
|
|
913
894
|
|
|
914
|
-
def get_counts(self, zs, library_sizes
|
|
895
|
+
def get_counts(self, zs, library_sizes,
|
|
915
896
|
batch_size: int = 1024,
|
|
916
897
|
use_sampler: bool = False):
|
|
917
898
|
|
|
918
899
|
zs = convert_to_tensor(zs, device=self.get_device())
|
|
919
900
|
|
|
920
|
-
if
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
elif len(library_sizes.shape)==1:
|
|
926
|
-
library_sizes = library_sizes.view(-1,1)
|
|
927
|
-
ls = convert_to_tensor(library_sizes, device=self.get_device())
|
|
928
|
-
else:
|
|
929
|
-
ls = zs
|
|
901
|
+
if type(library_sizes) == list:
|
|
902
|
+
library_sizes = np.array(library_sizes).view(-1,1)
|
|
903
|
+
elif len(library_sizes.shape)==1:
|
|
904
|
+
library_sizes = library_sizes.view(-1,1)
|
|
905
|
+
ls = convert_to_tensor(library_sizes, device=self.get_device())
|
|
930
906
|
|
|
931
907
|
dataset = CustomDataset2(zs,ls)
|
|
932
908
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
@@ -1084,9 +1060,6 @@ class PerturbFlow(nn.Module):
|
|
|
1084
1060
|
pbar.set_postfix({'loss': str_loss})
|
|
1085
1061
|
pbar.update(1)
|
|
1086
1062
|
|
|
1087
|
-
if self.loss_func == 'negbinomial':
|
|
1088
|
-
self.total_count = pyro.param('inverse_dispersion')
|
|
1089
|
-
|
|
1090
1063
|
@classmethod
|
|
1091
1064
|
def save_model(cls, model, file_path, compression=False):
|
|
1092
1065
|
"""Save the model to the specified file path."""
|
SURE/SURE.py
CHANGED
|
@@ -369,12 +369,13 @@ class SURE(nn.Module):
|
|
|
369
369
|
|
|
370
370
|
zs = zns
|
|
371
371
|
concentrate = self.decoder_concentrate(zs)
|
|
372
|
-
if self.loss_func
|
|
372
|
+
if self.loss_func in ['bernoulli']:
|
|
373
373
|
log_theta = concentrate
|
|
374
374
|
else:
|
|
375
375
|
rate = concentrate.exp()
|
|
376
|
-
|
|
377
|
-
|
|
376
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
377
|
+
if self.loss_func == 'poisson':
|
|
378
|
+
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
378
379
|
|
|
379
380
|
if self.loss_func == 'negbinomial':
|
|
380
381
|
if self.use_zeroinflate:
|
|
@@ -451,12 +452,13 @@ class SURE(nn.Module):
|
|
|
451
452
|
zs = zns
|
|
452
453
|
|
|
453
454
|
concentrate = self.decoder_concentrate(zs)
|
|
454
|
-
if self.loss_func
|
|
455
|
+
if self.loss_func in ['bernoulli']:
|
|
455
456
|
log_theta = concentrate
|
|
456
457
|
else:
|
|
457
458
|
rate = concentrate.exp()
|
|
458
|
-
|
|
459
|
-
|
|
459
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
460
|
+
if self.loss_func == 'poisson':
|
|
461
|
+
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
460
462
|
|
|
461
463
|
if self.loss_func == 'negbinomial':
|
|
462
464
|
if self.use_zeroinflate:
|
|
@@ -545,12 +547,13 @@ class SURE(nn.Module):
|
|
|
545
547
|
zs = zns
|
|
546
548
|
|
|
547
549
|
concentrate = self.decoder_concentrate(zs)
|
|
548
|
-
if self.loss_func
|
|
550
|
+
if self.loss_func in ['bernoulli']:
|
|
549
551
|
log_theta = concentrate
|
|
550
552
|
else:
|
|
551
553
|
rate = concentrate.exp()
|
|
552
|
-
|
|
553
|
-
|
|
554
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
555
|
+
if self.loss_func == 'poisson':
|
|
556
|
+
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
554
557
|
|
|
555
558
|
if self.loss_func == 'negbinomial':
|
|
556
559
|
if self.use_zeroinflate:
|
|
@@ -641,13 +644,14 @@ class SURE(nn.Module):
|
|
|
641
644
|
zs = zns
|
|
642
645
|
|
|
643
646
|
concentrate = self.decoder_concentrate(zs)
|
|
644
|
-
if self.loss_func
|
|
647
|
+
if self.loss_func in ['bernoulli']:
|
|
645
648
|
log_theta = concentrate
|
|
646
649
|
else:
|
|
647
650
|
rate = concentrate.exp()
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
652
|
+
if self.loss_func == 'poisson':
|
|
653
|
+
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
654
|
+
|
|
651
655
|
if self.loss_func == 'negbinomial':
|
|
652
656
|
if self.use_zeroinflate:
|
|
653
657
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
SURE/PerturbFlow.py,sha256=
|
|
2
|
-
SURE/SURE.py,sha256=
|
|
1
|
+
SURE/PerturbFlow.py,sha256=jvN2CD2L7YX6EiqP0Zw--zwegwHDiGkQYL2v_So5YYA,53382
|
|
2
|
+
SURE/SURE.py,sha256=g8EhovBxjfpbVJA0AkmVkQ_ZW_JFc8TtkTCg8FCybV4,47750
|
|
3
3
|
SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
|
|
4
4
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
5
5
|
SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
|
|
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
|
|
|
17
17
|
SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
|
|
18
18
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
19
19
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
20
|
-
sure_tools-2.1.
|
|
21
|
-
sure_tools-2.1.
|
|
22
|
-
sure_tools-2.1.
|
|
23
|
-
sure_tools-2.1.
|
|
24
|
-
sure_tools-2.1.
|
|
25
|
-
sure_tools-2.1.
|
|
20
|
+
sure_tools-2.1.79.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
21
|
+
sure_tools-2.1.79.dist-info/METADATA,sha256=B8i3DryW1NcFrOvzEROG5YDlZG14RHa_q2cVbNVUwmg,2678
|
|
22
|
+
sure_tools-2.1.79.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
23
|
+
sure_tools-2.1.79.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
24
|
+
sure_tools-2.1.79.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
25
|
+
sure_tools-2.1.79.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|