SURE-tools 2.1.87__py3-none-any.whl → 2.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/DensityFlow.py +1388 -0
- SURE/{PerturbFlow.py → PerturbE.py} +51 -122
- SURE/SURE.py +6 -6
- SURE/TranscriptomeDecoder.py +527 -0
- SURE/__init__.py +7 -3
- SURE/flow/flow_stats.py +12 -0
- SURE/perturb/perturb.py +27 -1
- SURE/utils/custom_mlp.py +39 -2
- {sure_tools-2.1.87.dist-info → sure_tools-2.4.3.dist-info}/METADATA +1 -1
- sure_tools-2.4.3.dist-info/RECORD +27 -0
- sure_tools-2.1.87.dist-info/RECORD +0 -25
- {sure_tools-2.1.87.dist-info → sure_tools-2.4.3.dist-info}/WHEEL +0 -0
- {sure_tools-2.1.87.dist-info → sure_tools-2.4.3.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.1.87.dist-info → sure_tools-2.4.3.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.1.87.dist-info → sure_tools-2.4.3.dist-info}/top_level.txt +0 -0
|
@@ -10,7 +10,7 @@ from torch.distributions.utils import logits_to_probs, probs_to_logits, clamp_pr
|
|
|
10
10
|
from torch.distributions import constraints
|
|
11
11
|
from torch.distributions.transforms import SoftmaxTransform
|
|
12
12
|
|
|
13
|
-
from .utils.custom_mlp import MLP, Exp,
|
|
13
|
+
from .utils.custom_mlp import MLP, Exp, ZeroBiasMLP2
|
|
14
14
|
from .utils.utils import CustomDataset, CustomDataset2, CustomDataset4, tensor_to_numpy, convert_to_tensor
|
|
15
15
|
|
|
16
16
|
|
|
@@ -54,7 +54,7 @@ def set_random_seed(seed):
|
|
|
54
54
|
# Set seed for Pyro
|
|
55
55
|
pyro.set_rng_seed(seed)
|
|
56
56
|
|
|
57
|
-
class
|
|
57
|
+
class PerturbE(nn.Module):
|
|
58
58
|
def __init__(self,
|
|
59
59
|
input_size: int,
|
|
60
60
|
codebook_size: int = 200,
|
|
@@ -62,10 +62,10 @@ class PerturbFlow(nn.Module):
|
|
|
62
62
|
supervised_mode: bool = False,
|
|
63
63
|
z_dim: int = 10,
|
|
64
64
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
65
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
65
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
|
|
66
66
|
inverse_dispersion: float = 10.0,
|
|
67
67
|
use_zeroinflate: bool = False,
|
|
68
|
-
hidden_layers: list = [
|
|
68
|
+
hidden_layers: list = [500],
|
|
69
69
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
70
70
|
nn_dropout: float = 0.1,
|
|
71
71
|
post_layer_fct: list = ['layernorm'],
|
|
@@ -73,8 +73,6 @@ class PerturbFlow(nn.Module):
|
|
|
73
73
|
config_enum: str = 'parallel',
|
|
74
74
|
use_cuda: bool = True,
|
|
75
75
|
seed: int = 42,
|
|
76
|
-
zero_bias: bool|list = True,
|
|
77
|
-
enumrate: bool = False,
|
|
78
76
|
dtype = torch.float32, # type: ignore
|
|
79
77
|
):
|
|
80
78
|
super().__init__()
|
|
@@ -98,12 +96,6 @@ class PerturbFlow(nn.Module):
|
|
|
98
96
|
self.post_layer_fct = post_layer_fct
|
|
99
97
|
self.post_act_fct = post_act_fct
|
|
100
98
|
self.hidden_layer_activation = hidden_layer_activation
|
|
101
|
-
if type(zero_bias) == list:
|
|
102
|
-
self.use_bias = [not x for x in zero_bias]
|
|
103
|
-
else:
|
|
104
|
-
self.use_bias = [not zero_bias] * self.cell_factor_size
|
|
105
|
-
#self.use_bias = not zero_bias
|
|
106
|
-
self.enumrate = enumrate
|
|
107
99
|
|
|
108
100
|
self.codebook_weights = None
|
|
109
101
|
|
|
@@ -202,29 +194,14 @@ class PerturbFlow(nn.Module):
|
|
|
202
194
|
)
|
|
203
195
|
|
|
204
196
|
if self.cell_factor_size>0:
|
|
205
|
-
self.cell_factor_effect =
|
|
206
|
-
|
|
207
|
-
if self.use_bias[i]:
|
|
208
|
-
self.cell_factor_effect.append(MLP(
|
|
209
|
-
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
197
|
+
self.cell_factor_effect = ZeroBiasMLP2(
|
|
198
|
+
[self.cell_factor_size] + self.decoder_hidden_layers + [self.latent_dim],
|
|
210
199
|
activation=activate_fct,
|
|
211
200
|
output_activation=None,
|
|
212
201
|
post_layer_fct=post_layer_fct,
|
|
213
202
|
post_act_fct=post_act_fct,
|
|
214
203
|
allow_broadcast=self.allow_broadcast,
|
|
215
204
|
use_cuda=self.use_cuda,
|
|
216
|
-
)
|
|
217
|
-
)
|
|
218
|
-
else:
|
|
219
|
-
self.cell_factor_effect.append(ZeroBiasMLP(
|
|
220
|
-
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
221
|
-
activation=activate_fct,
|
|
222
|
-
output_activation=None,
|
|
223
|
-
post_layer_fct=post_layer_fct,
|
|
224
|
-
post_act_fct=post_act_fct,
|
|
225
|
-
allow_broadcast=self.allow_broadcast,
|
|
226
|
-
use_cuda=self.use_cuda,
|
|
227
|
-
)
|
|
228
205
|
)
|
|
229
206
|
|
|
230
207
|
self.decoder_concentrate = MLP(
|
|
@@ -310,7 +287,7 @@ class PerturbFlow(nn.Module):
|
|
|
310
287
|
return xs
|
|
311
288
|
|
|
312
289
|
def model1(self, xs):
|
|
313
|
-
pyro.module('
|
|
290
|
+
pyro.module('PerturbE', self)
|
|
314
291
|
|
|
315
292
|
eps = torch.finfo(xs.dtype).eps
|
|
316
293
|
batch_size = xs.size(0)
|
|
@@ -372,7 +349,8 @@ class PerturbFlow(nn.Module):
|
|
|
372
349
|
else:
|
|
373
350
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
374
351
|
elif self.loss_func == 'multinomial':
|
|
375
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
352
|
+
#pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
353
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
376
354
|
elif self.loss_func == 'bernoulli':
|
|
377
355
|
if self.use_zeroinflate:
|
|
378
356
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -389,7 +367,7 @@ class PerturbFlow(nn.Module):
|
|
|
389
367
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
390
368
|
|
|
391
369
|
def model2(self, xs, us=None):
|
|
392
|
-
pyro.module('
|
|
370
|
+
pyro.module('PerturbE', self)
|
|
393
371
|
|
|
394
372
|
eps = torch.finfo(xs.dtype).eps
|
|
395
373
|
batch_size = xs.size(0)
|
|
@@ -431,12 +409,7 @@ class PerturbFlow(nn.Module):
|
|
|
431
409
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
432
410
|
|
|
433
411
|
if self.cell_factor_size>0:
|
|
434
|
-
|
|
435
|
-
idx = torch.argmax(ns, dim=1)
|
|
436
|
-
zn_loc = acs_loc[idx]
|
|
437
|
-
zus = self._total_effects(zn_loc, us)
|
|
438
|
-
else:
|
|
439
|
-
zus = self._total_effects(zns, us)
|
|
412
|
+
zus = self._perturb_effects(us)
|
|
440
413
|
zs = zns+zus
|
|
441
414
|
else:
|
|
442
415
|
zs = zns
|
|
@@ -461,7 +434,8 @@ class PerturbFlow(nn.Module):
|
|
|
461
434
|
else:
|
|
462
435
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
463
436
|
elif self.loss_func == 'multinomial':
|
|
464
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
437
|
+
#pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
438
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
465
439
|
elif self.loss_func == 'bernoulli':
|
|
466
440
|
if self.use_zeroinflate:
|
|
467
441
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -478,7 +452,7 @@ class PerturbFlow(nn.Module):
|
|
|
478
452
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
479
453
|
|
|
480
454
|
def model3(self, xs, ys, embeds=None):
|
|
481
|
-
pyro.module('
|
|
455
|
+
pyro.module('PerturbE', self)
|
|
482
456
|
|
|
483
457
|
eps = torch.finfo(xs.dtype).eps
|
|
484
458
|
batch_size = xs.size(0)
|
|
@@ -557,7 +531,8 @@ class PerturbFlow(nn.Module):
|
|
|
557
531
|
else:
|
|
558
532
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
559
533
|
elif self.loss_func == 'multinomial':
|
|
560
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
534
|
+
#pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
535
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
561
536
|
elif self.loss_func == 'bernoulli':
|
|
562
537
|
if self.use_zeroinflate:
|
|
563
538
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -574,7 +549,7 @@ class PerturbFlow(nn.Module):
|
|
|
574
549
|
zns = embeds
|
|
575
550
|
|
|
576
551
|
def model4(self, xs, us, ys, embeds=None):
|
|
577
|
-
pyro.module('
|
|
552
|
+
pyro.module('PerturbE', self)
|
|
578
553
|
|
|
579
554
|
eps = torch.finfo(xs.dtype).eps
|
|
580
555
|
batch_size = xs.size(0)
|
|
@@ -638,12 +613,7 @@ class PerturbFlow(nn.Module):
|
|
|
638
613
|
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
639
614
|
# else:
|
|
640
615
|
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
641
|
-
|
|
642
|
-
idx = torch.argmax(ns, dim=1)
|
|
643
|
-
zn_loc = acs_loc[idx]
|
|
644
|
-
zus = self._total_effects(zn_loc, us)
|
|
645
|
-
else:
|
|
646
|
-
zus = self._total_effects(zns, us)
|
|
616
|
+
zus = self._perturb_effects(us)
|
|
647
617
|
zs = zns+zus
|
|
648
618
|
else:
|
|
649
619
|
zs = zns
|
|
@@ -668,7 +638,8 @@ class PerturbFlow(nn.Module):
|
|
|
668
638
|
else:
|
|
669
639
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
670
640
|
elif self.loss_func == 'multinomial':
|
|
671
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
641
|
+
#pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
642
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
672
643
|
elif self.loss_func == 'bernoulli':
|
|
673
644
|
if self.use_zeroinflate:
|
|
674
645
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -684,13 +655,8 @@ class PerturbFlow(nn.Module):
|
|
|
684
655
|
else:
|
|
685
656
|
zns = embeds
|
|
686
657
|
|
|
687
|
-
def
|
|
688
|
-
zus =
|
|
689
|
-
for i in np.arange(self.cell_factor_size):
|
|
690
|
-
if i==0:
|
|
691
|
-
zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
692
|
-
else:
|
|
693
|
-
zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
658
|
+
def _perturb_effects(self, us):
|
|
659
|
+
zus = self._cell_response(us)
|
|
694
660
|
return zus
|
|
695
661
|
|
|
696
662
|
def _get_codebook_identity(self):
|
|
@@ -708,7 +674,7 @@ class PerturbFlow(nn.Module):
|
|
|
708
674
|
"""
|
|
709
675
|
Return the mean part of metacell codebook
|
|
710
676
|
"""
|
|
711
|
-
cb = self.
|
|
677
|
+
cb = self._get_codebook()
|
|
712
678
|
cb = tensor_to_numpy(cb)
|
|
713
679
|
return cb
|
|
714
680
|
|
|
@@ -822,23 +788,13 @@ class PerturbFlow(nn.Module):
|
|
|
822
788
|
A = np.concatenate(A)
|
|
823
789
|
return A
|
|
824
790
|
|
|
825
|
-
def predict(self, xs,
|
|
791
|
+
def predict(self, xs, perturbs_us, library_sizes=None):
|
|
826
792
|
perturbs_reference = np.array(perturbs_reference)
|
|
827
793
|
|
|
828
794
|
# basal embedding
|
|
829
795
|
zs = self.get_basal_embedding(xs)
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
us_i = us[:,pert_idx].reshape(-1,1)
|
|
833
|
-
|
|
834
|
-
# factor effect of xs
|
|
835
|
-
dzs0 = self.get_cell_response(xs, factor_idx=pert_idx, perturb=us_i)
|
|
836
|
-
|
|
837
|
-
# perturbation effect
|
|
838
|
-
ps = np.ones_like(us_i)
|
|
839
|
-
dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
|
|
840
|
-
|
|
841
|
-
zs = zs + dzs0 + dzs
|
|
796
|
+
dzs = self.get_cell_response(perturbs_us)
|
|
797
|
+
zs = zs + dzs
|
|
842
798
|
|
|
843
799
|
if library_sizes is None:
|
|
844
800
|
library_sizes = np.sum(xs, axis=1, keepdims=True)
|
|
@@ -852,47 +808,32 @@ class PerturbFlow(nn.Module):
|
|
|
852
808
|
|
|
853
809
|
return counts, zs
|
|
854
810
|
|
|
855
|
-
def _cell_response(self,
|
|
856
|
-
|
|
857
|
-
zns,_ = self._get_basal_embedding(xs)
|
|
858
|
-
if perturb.ndim==2:
|
|
859
|
-
ms = self.cell_factor_effect[factor_idx]([zns, perturb])
|
|
860
|
-
else:
|
|
861
|
-
ms = self.cell_factor_effect[factor_idx]([zns, perturb.reshape(-1,1)])
|
|
862
|
-
|
|
811
|
+
def _cell_response(self, perturb):
|
|
812
|
+
ms = self.cell_factor_effect(perturb)
|
|
863
813
|
return ms
|
|
864
814
|
|
|
865
815
|
def get_cell_response(self,
|
|
866
|
-
|
|
867
|
-
factor_idx,
|
|
868
|
-
perturb,
|
|
816
|
+
perturb_us,
|
|
869
817
|
batch_size: int = 1024):
|
|
870
818
|
"""
|
|
871
819
|
Return cells' changes in the latent space induced by specific perturbation of a factor
|
|
872
820
|
|
|
873
821
|
"""
|
|
874
|
-
xs = self.preprocess(xs)
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
dataset = CustomDataset2(xs,ps)
|
|
822
|
+
#xs = self.preprocess(xs)
|
|
823
|
+
ps = convert_to_tensor(perturb_us, device=self.get_device())
|
|
824
|
+
dataset = CustomDataset(ps)
|
|
878
825
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
879
826
|
|
|
880
827
|
Z = []
|
|
881
828
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
882
|
-
for
|
|
883
|
-
zns = self._cell_response(
|
|
829
|
+
for P_batch, _ in dataloader:
|
|
830
|
+
zns = self._cell_response(P_batch)
|
|
884
831
|
Z.append(tensor_to_numpy(zns))
|
|
885
832
|
pbar.update(1)
|
|
886
833
|
|
|
887
834
|
Z = np.concatenate(Z)
|
|
888
835
|
return Z
|
|
889
836
|
|
|
890
|
-
def get_metacell_response(self, factor_idx, perturb):
|
|
891
|
-
zs = self._get_codebook()
|
|
892
|
-
ps = convert_to_tensor(perturb, device=self.get_device())
|
|
893
|
-
ms = self.cell_factor_effect[factor_idx]([zs,ps])
|
|
894
|
-
return tensor_to_numpy(ms)
|
|
895
|
-
|
|
896
837
|
def _get_expression_response(self, delta_zs):
|
|
897
838
|
return self.decoder_concentrate(delta_zs)
|
|
898
839
|
|
|
@@ -917,36 +858,28 @@ class PerturbFlow(nn.Module):
|
|
|
917
858
|
R = np.concatenate(R)
|
|
918
859
|
return R
|
|
919
860
|
|
|
920
|
-
def _count(self,concentrate, library_size=None):
|
|
861
|
+
def _count(self, concentrate, library_size=None):
|
|
921
862
|
if self.loss_func == 'bernoulli':
|
|
922
863
|
#counts = self.sigmoid(concentrate)
|
|
923
864
|
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
865
|
+
elif self.loss_func == 'multinomial':
|
|
866
|
+
theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
|
|
867
|
+
counts = theta * library_size
|
|
924
868
|
else:
|
|
925
869
|
rate = concentrate.exp()
|
|
926
870
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
927
871
|
counts = theta * library_size
|
|
928
|
-
#counts = dist.Poisson(rate=rate).to_event(1).mean
|
|
929
|
-
return counts
|
|
930
|
-
|
|
931
|
-
def _count_sample(self,concentrate):
|
|
932
|
-
if self.loss_func == 'bernoulli':
|
|
933
|
-
logits = concentrate
|
|
934
|
-
counts = dist.Bernoulli(logits=logits).to_event(1).sample()
|
|
935
|
-
else:
|
|
936
|
-
counts = self._count(concentrate=concentrate)
|
|
937
|
-
counts = dist.Poisson(rate=counts).to_event(1).sample()
|
|
938
872
|
return counts
|
|
939
873
|
|
|
940
874
|
def get_counts(self, zs, library_sizes,
|
|
941
|
-
batch_size: int = 1024
|
|
942
|
-
use_sampler: bool = False):
|
|
875
|
+
batch_size: int = 1024):
|
|
943
876
|
|
|
944
877
|
zs = convert_to_tensor(zs, device=self.get_device())
|
|
945
878
|
|
|
946
879
|
if type(library_sizes) == list:
|
|
947
|
-
library_sizes = np.array(library_sizes).
|
|
880
|
+
library_sizes = np.array(library_sizes).reshape(-1,1)
|
|
948
881
|
elif len(library_sizes.shape)==1:
|
|
949
|
-
library_sizes = library_sizes.
|
|
882
|
+
library_sizes = library_sizes.reshape(-1,1)
|
|
950
883
|
ls = convert_to_tensor(library_sizes, device=self.get_device())
|
|
951
884
|
|
|
952
885
|
dataset = CustomDataset2(zs,ls)
|
|
@@ -956,10 +889,7 @@ class PerturbFlow(nn.Module):
|
|
|
956
889
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
957
890
|
for Z_batch, L_batch, _ in dataloader:
|
|
958
891
|
concentrate = self._get_expression_response(Z_batch)
|
|
959
|
-
|
|
960
|
-
counts = self._count_sample(concentrate)
|
|
961
|
-
else:
|
|
962
|
-
counts = self._count(concentrate, L_batch)
|
|
892
|
+
counts = self._count(concentrate, L_batch)
|
|
963
893
|
E.append(tensor_to_numpy(counts))
|
|
964
894
|
pbar.update(1)
|
|
965
895
|
|
|
@@ -982,7 +912,7 @@ class PerturbFlow(nn.Module):
|
|
|
982
912
|
us = None,
|
|
983
913
|
ys = None,
|
|
984
914
|
zs = None,
|
|
985
|
-
num_epochs: int =
|
|
915
|
+
num_epochs: int = 500,
|
|
986
916
|
learning_rate: float = 0.0001,
|
|
987
917
|
batch_size: int = 256,
|
|
988
918
|
algo: Literal['adam','rmsprop','adamw'] = 'adam',
|
|
@@ -993,7 +923,7 @@ class PerturbFlow(nn.Module):
|
|
|
993
923
|
threshold: int = 0,
|
|
994
924
|
use_jax: bool = True):
|
|
995
925
|
"""
|
|
996
|
-
Train the
|
|
926
|
+
Train the PerturbE model.
|
|
997
927
|
|
|
998
928
|
Parameters
|
|
999
929
|
----------
|
|
@@ -1019,7 +949,7 @@ class PerturbFlow(nn.Module):
|
|
|
1019
949
|
Parameter for optimization.
|
|
1020
950
|
use_jax
|
|
1021
951
|
If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
|
|
1022
|
-
the Python script or Jupyter notebook. It is OK if it is used when runing
|
|
952
|
+
the Python script or Jupyter notebook. It is OK if it is used when runing PerturbE in the shell command.
|
|
1023
953
|
"""
|
|
1024
954
|
xs = self.preprocess(xs, threshold=threshold)
|
|
1025
955
|
xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
|
|
@@ -1137,12 +1067,12 @@ class PerturbFlow(nn.Module):
|
|
|
1137
1067
|
|
|
1138
1068
|
|
|
1139
1069
|
EXAMPLE_RUN = (
|
|
1140
|
-
"example run:
|
|
1070
|
+
"example run: PerturbE --help"
|
|
1141
1071
|
)
|
|
1142
1072
|
|
|
1143
1073
|
def parse_args():
|
|
1144
1074
|
parser = argparse.ArgumentParser(
|
|
1145
|
-
description="
|
|
1075
|
+
description="PerturbE\n{}".format(EXAMPLE_RUN))
|
|
1146
1076
|
|
|
1147
1077
|
parser.add_argument(
|
|
1148
1078
|
"--cuda", action="store_true", help="use GPU(s) to speed up training"
|
|
@@ -1329,7 +1259,7 @@ def main():
|
|
|
1329
1259
|
cell_factor_size = 0 if us is None else us.shape[1]
|
|
1330
1260
|
|
|
1331
1261
|
###########################################
|
|
1332
|
-
|
|
1262
|
+
perturbe = PerturbE(
|
|
1333
1263
|
input_size=input_size,
|
|
1334
1264
|
cell_factor_size=cell_factor_size,
|
|
1335
1265
|
inverse_dispersion=args.inverse_dispersion,
|
|
@@ -1348,7 +1278,7 @@ def main():
|
|
|
1348
1278
|
dtype=dtype,
|
|
1349
1279
|
)
|
|
1350
1280
|
|
|
1351
|
-
|
|
1281
|
+
perturbe.fit(xs, us=us,
|
|
1352
1282
|
num_epochs=args.num_epochs,
|
|
1353
1283
|
learning_rate=args.learning_rate,
|
|
1354
1284
|
batch_size=args.batch_size,
|
|
@@ -1360,12 +1290,11 @@ def main():
|
|
|
1360
1290
|
|
|
1361
1291
|
if args.save_model is not None:
|
|
1362
1292
|
if args.save_model.endswith('gz'):
|
|
1363
|
-
|
|
1293
|
+
PerturbE.save_model(perturbe, args.save_model, compression=True)
|
|
1364
1294
|
else:
|
|
1365
|
-
|
|
1295
|
+
PerturbE.save_model(perturbe, args.save_model)
|
|
1366
1296
|
|
|
1367
1297
|
|
|
1368
1298
|
|
|
1369
1299
|
if __name__ == "__main__":
|
|
1370
|
-
|
|
1371
1300
|
main()
|
SURE/SURE.py
CHANGED
|
@@ -99,17 +99,17 @@ class SURE(nn.Module):
|
|
|
99
99
|
cell_factor_size: int = 0,
|
|
100
100
|
supervised_mode: bool = False,
|
|
101
101
|
z_dim: int = 10,
|
|
102
|
-
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = '
|
|
103
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
102
|
+
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
103
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
|
|
104
104
|
inverse_dispersion: float = 10.0,
|
|
105
105
|
use_zeroinflate: bool = True,
|
|
106
|
-
hidden_layers: list = [
|
|
106
|
+
hidden_layers: list = [500],
|
|
107
107
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
108
108
|
nn_dropout: float = 0.1,
|
|
109
109
|
post_layer_fct: list = ['layernorm'],
|
|
110
110
|
post_act_fct: list = None,
|
|
111
111
|
config_enum: str = 'parallel',
|
|
112
|
-
use_cuda: bool =
|
|
112
|
+
use_cuda: bool = True,
|
|
113
113
|
seed: int = 42,
|
|
114
114
|
dtype = torch.float32, # type: ignore
|
|
115
115
|
):
|
|
@@ -817,7 +817,7 @@ class SURE(nn.Module):
|
|
|
817
817
|
us = None,
|
|
818
818
|
ys = None,
|
|
819
819
|
zs = None,
|
|
820
|
-
num_epochs: int =
|
|
820
|
+
num_epochs: int = 500,
|
|
821
821
|
learning_rate: float = 0.0001,
|
|
822
822
|
batch_size: int = 256,
|
|
823
823
|
algo: Literal['adam','rmsprop','adamw'] = 'adam',
|
|
@@ -826,7 +826,7 @@ class SURE(nn.Module):
|
|
|
826
826
|
decay_rate: float = 0.9,
|
|
827
827
|
config_enum: str = 'parallel',
|
|
828
828
|
threshold: int = 0,
|
|
829
|
-
use_jax: bool =
|
|
829
|
+
use_jax: bool = True):
|
|
830
830
|
"""
|
|
831
831
|
Train the SURE model.
|
|
832
832
|
|