SURE-tools 2.2.2__py3-none-any.whl → 2.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/DensityFlow.py +103 -74
- SURE/{PerturbFlow.py → PerturbE.py} +51 -110
- SURE/TranscriptomeDecoder.py +527 -0
- SURE/__init__.py +5 -1
- SURE/perturb/perturb.py +27 -1
- SURE/utils/custom_mlp.py +39 -2
- {sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/METADATA +1 -1
- {sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/RECORD +12 -11
- {sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/WHEEL +0 -0
- {sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/top_level.txt +0 -0
|
@@ -10,7 +10,7 @@ from torch.distributions.utils import logits_to_probs, probs_to_logits, clamp_pr
|
|
|
10
10
|
from torch.distributions import constraints
|
|
11
11
|
from torch.distributions.transforms import SoftmaxTransform
|
|
12
12
|
|
|
13
|
-
from .utils.custom_mlp import MLP, Exp,
|
|
13
|
+
from .utils.custom_mlp import MLP, Exp, ZeroBiasMLP2
|
|
14
14
|
from .utils.utils import CustomDataset, CustomDataset2, CustomDataset4, tensor_to_numpy, convert_to_tensor
|
|
15
15
|
|
|
16
16
|
|
|
@@ -54,7 +54,7 @@ def set_random_seed(seed):
|
|
|
54
54
|
# Set seed for Pyro
|
|
55
55
|
pyro.set_rng_seed(seed)
|
|
56
56
|
|
|
57
|
-
class
|
|
57
|
+
class PerturbE(nn.Module):
|
|
58
58
|
def __init__(self,
|
|
59
59
|
input_size: int,
|
|
60
60
|
codebook_size: int = 200,
|
|
@@ -62,10 +62,10 @@ class PerturbFlow(nn.Module):
|
|
|
62
62
|
supervised_mode: bool = False,
|
|
63
63
|
z_dim: int = 10,
|
|
64
64
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
65
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
65
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
|
|
66
66
|
inverse_dispersion: float = 10.0,
|
|
67
67
|
use_zeroinflate: bool = False,
|
|
68
|
-
hidden_layers: list = [
|
|
68
|
+
hidden_layers: list = [500],
|
|
69
69
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
70
70
|
nn_dropout: float = 0.1,
|
|
71
71
|
post_layer_fct: list = ['layernorm'],
|
|
@@ -73,7 +73,6 @@ class PerturbFlow(nn.Module):
|
|
|
73
73
|
config_enum: str = 'parallel',
|
|
74
74
|
use_cuda: bool = True,
|
|
75
75
|
seed: int = 42,
|
|
76
|
-
zero_bias: bool|list = True,
|
|
77
76
|
dtype = torch.float32, # type: ignore
|
|
78
77
|
):
|
|
79
78
|
super().__init__()
|
|
@@ -97,11 +96,6 @@ class PerturbFlow(nn.Module):
|
|
|
97
96
|
self.post_layer_fct = post_layer_fct
|
|
98
97
|
self.post_act_fct = post_act_fct
|
|
99
98
|
self.hidden_layer_activation = hidden_layer_activation
|
|
100
|
-
if type(zero_bias) == list:
|
|
101
|
-
self.use_bias = [not x for x in zero_bias]
|
|
102
|
-
else:
|
|
103
|
-
self.use_bias = [not zero_bias] * self.cell_factor_size
|
|
104
|
-
#self.use_bias = not zero_bias
|
|
105
99
|
|
|
106
100
|
self.codebook_weights = None
|
|
107
101
|
|
|
@@ -200,29 +194,14 @@ class PerturbFlow(nn.Module):
|
|
|
200
194
|
)
|
|
201
195
|
|
|
202
196
|
if self.cell_factor_size>0:
|
|
203
|
-
self.cell_factor_effect =
|
|
204
|
-
|
|
205
|
-
if self.use_bias[i]:
|
|
206
|
-
self.cell_factor_effect.append(MLP(
|
|
207
|
-
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
208
|
-
activation=activate_fct,
|
|
209
|
-
output_activation=None,
|
|
210
|
-
post_layer_fct=post_layer_fct,
|
|
211
|
-
post_act_fct=post_act_fct,
|
|
212
|
-
allow_broadcast=self.allow_broadcast,
|
|
213
|
-
use_cuda=self.use_cuda,
|
|
214
|
-
)
|
|
215
|
-
)
|
|
216
|
-
else:
|
|
217
|
-
self.cell_factor_effect.append(ZeroBiasMLP(
|
|
218
|
-
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
197
|
+
self.cell_factor_effect = ZeroBiasMLP2(
|
|
198
|
+
[self.cell_factor_size] + self.decoder_hidden_layers + [self.latent_dim],
|
|
219
199
|
activation=activate_fct,
|
|
220
200
|
output_activation=None,
|
|
221
201
|
post_layer_fct=post_layer_fct,
|
|
222
202
|
post_act_fct=post_act_fct,
|
|
223
203
|
allow_broadcast=self.allow_broadcast,
|
|
224
204
|
use_cuda=self.use_cuda,
|
|
225
|
-
)
|
|
226
205
|
)
|
|
227
206
|
|
|
228
207
|
self.decoder_concentrate = MLP(
|
|
@@ -308,7 +287,7 @@ class PerturbFlow(nn.Module):
|
|
|
308
287
|
return xs
|
|
309
288
|
|
|
310
289
|
def model1(self, xs):
|
|
311
|
-
pyro.module('
|
|
290
|
+
pyro.module('PerturbE', self)
|
|
312
291
|
|
|
313
292
|
eps = torch.finfo(xs.dtype).eps
|
|
314
293
|
batch_size = xs.size(0)
|
|
@@ -370,7 +349,8 @@ class PerturbFlow(nn.Module):
|
|
|
370
349
|
else:
|
|
371
350
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
372
351
|
elif self.loss_func == 'multinomial':
|
|
373
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
352
|
+
#pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
353
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
374
354
|
elif self.loss_func == 'bernoulli':
|
|
375
355
|
if self.use_zeroinflate:
|
|
376
356
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -387,7 +367,7 @@ class PerturbFlow(nn.Module):
|
|
|
387
367
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
388
368
|
|
|
389
369
|
def model2(self, xs, us=None):
|
|
390
|
-
pyro.module('
|
|
370
|
+
pyro.module('PerturbE', self)
|
|
391
371
|
|
|
392
372
|
eps = torch.finfo(xs.dtype).eps
|
|
393
373
|
batch_size = xs.size(0)
|
|
@@ -429,7 +409,7 @@ class PerturbFlow(nn.Module):
|
|
|
429
409
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
430
410
|
|
|
431
411
|
if self.cell_factor_size>0:
|
|
432
|
-
zus = self.
|
|
412
|
+
zus = self._perturb_effects(us)
|
|
433
413
|
zs = zns+zus
|
|
434
414
|
else:
|
|
435
415
|
zs = zns
|
|
@@ -454,7 +434,8 @@ class PerturbFlow(nn.Module):
|
|
|
454
434
|
else:
|
|
455
435
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
456
436
|
elif self.loss_func == 'multinomial':
|
|
457
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
437
|
+
#pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
438
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
458
439
|
elif self.loss_func == 'bernoulli':
|
|
459
440
|
if self.use_zeroinflate:
|
|
460
441
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -471,7 +452,7 @@ class PerturbFlow(nn.Module):
|
|
|
471
452
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
472
453
|
|
|
473
454
|
def model3(self, xs, ys, embeds=None):
|
|
474
|
-
pyro.module('
|
|
455
|
+
pyro.module('PerturbE', self)
|
|
475
456
|
|
|
476
457
|
eps = torch.finfo(xs.dtype).eps
|
|
477
458
|
batch_size = xs.size(0)
|
|
@@ -550,7 +531,8 @@ class PerturbFlow(nn.Module):
|
|
|
550
531
|
else:
|
|
551
532
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
552
533
|
elif self.loss_func == 'multinomial':
|
|
553
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
534
|
+
#pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
535
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
554
536
|
elif self.loss_func == 'bernoulli':
|
|
555
537
|
if self.use_zeroinflate:
|
|
556
538
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -567,7 +549,7 @@ class PerturbFlow(nn.Module):
|
|
|
567
549
|
zns = embeds
|
|
568
550
|
|
|
569
551
|
def model4(self, xs, us, ys, embeds=None):
|
|
570
|
-
pyro.module('
|
|
552
|
+
pyro.module('PerturbE', self)
|
|
571
553
|
|
|
572
554
|
eps = torch.finfo(xs.dtype).eps
|
|
573
555
|
batch_size = xs.size(0)
|
|
@@ -631,7 +613,7 @@ class PerturbFlow(nn.Module):
|
|
|
631
613
|
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
632
614
|
# else:
|
|
633
615
|
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
634
|
-
zus = self.
|
|
616
|
+
zus = self._perturb_effects(us)
|
|
635
617
|
zs = zns+zus
|
|
636
618
|
else:
|
|
637
619
|
zs = zns
|
|
@@ -656,7 +638,8 @@ class PerturbFlow(nn.Module):
|
|
|
656
638
|
else:
|
|
657
639
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
658
640
|
elif self.loss_func == 'multinomial':
|
|
659
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
641
|
+
#pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
642
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
660
643
|
elif self.loss_func == 'bernoulli':
|
|
661
644
|
if self.use_zeroinflate:
|
|
662
645
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -672,13 +655,8 @@ class PerturbFlow(nn.Module):
|
|
|
672
655
|
else:
|
|
673
656
|
zns = embeds
|
|
674
657
|
|
|
675
|
-
def
|
|
676
|
-
zus =
|
|
677
|
-
for i in np.arange(self.cell_factor_size):
|
|
678
|
-
if i==0:
|
|
679
|
-
zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
680
|
-
else:
|
|
681
|
-
zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
658
|
+
def _perturb_effects(self, us):
|
|
659
|
+
zus = self._cell_response(us)
|
|
682
660
|
return zus
|
|
683
661
|
|
|
684
662
|
def _get_codebook_identity(self):
|
|
@@ -696,7 +674,7 @@ class PerturbFlow(nn.Module):
|
|
|
696
674
|
"""
|
|
697
675
|
Return the mean part of metacell codebook
|
|
698
676
|
"""
|
|
699
|
-
cb = self.
|
|
677
|
+
cb = self._get_codebook()
|
|
700
678
|
cb = tensor_to_numpy(cb)
|
|
701
679
|
return cb
|
|
702
680
|
|
|
@@ -810,23 +788,13 @@ class PerturbFlow(nn.Module):
|
|
|
810
788
|
A = np.concatenate(A)
|
|
811
789
|
return A
|
|
812
790
|
|
|
813
|
-
def predict(self, xs,
|
|
791
|
+
def predict(self, xs, perturbs_us, library_sizes=None):
|
|
814
792
|
perturbs_reference = np.array(perturbs_reference)
|
|
815
793
|
|
|
816
794
|
# basal embedding
|
|
817
795
|
zs = self.get_basal_embedding(xs)
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
us_i = us[:,pert_idx].reshape(-1,1)
|
|
821
|
-
|
|
822
|
-
# factor effect of xs
|
|
823
|
-
dzs0 = self.get_cell_response(xs, factor_idx=pert_idx, perturb=us_i)
|
|
824
|
-
|
|
825
|
-
# perturbation effect
|
|
826
|
-
ps = np.ones_like(us_i)
|
|
827
|
-
dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
|
|
828
|
-
|
|
829
|
-
zs = zs + dzs0 + dzs
|
|
796
|
+
dzs = self.get_cell_response(perturbs_us)
|
|
797
|
+
zs = zs + dzs
|
|
830
798
|
|
|
831
799
|
if library_sizes is None:
|
|
832
800
|
library_sizes = np.sum(xs, axis=1, keepdims=True)
|
|
@@ -840,47 +808,32 @@ class PerturbFlow(nn.Module):
|
|
|
840
808
|
|
|
841
809
|
return counts, zs
|
|
842
810
|
|
|
843
|
-
def _cell_response(self,
|
|
844
|
-
|
|
845
|
-
zns,_ = self._get_basal_embedding(xs)
|
|
846
|
-
if perturb.ndim==2:
|
|
847
|
-
ms = self.cell_factor_effect[factor_idx]([zns, perturb])
|
|
848
|
-
else:
|
|
849
|
-
ms = self.cell_factor_effect[factor_idx]([zns, perturb.reshape(-1,1)])
|
|
850
|
-
|
|
811
|
+
def _cell_response(self, perturb):
|
|
812
|
+
ms = self.cell_factor_effect(perturb)
|
|
851
813
|
return ms
|
|
852
814
|
|
|
853
815
|
def get_cell_response(self,
|
|
854
|
-
|
|
855
|
-
factor_idx,
|
|
856
|
-
perturb,
|
|
816
|
+
perturb_us,
|
|
857
817
|
batch_size: int = 1024):
|
|
858
818
|
"""
|
|
859
819
|
Return cells' changes in the latent space induced by specific perturbation of a factor
|
|
860
820
|
|
|
861
821
|
"""
|
|
862
|
-
xs = self.preprocess(xs)
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
dataset = CustomDataset2(xs,ps)
|
|
822
|
+
#xs = self.preprocess(xs)
|
|
823
|
+
ps = convert_to_tensor(perturb_us, device=self.get_device())
|
|
824
|
+
dataset = CustomDataset(ps)
|
|
866
825
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
867
826
|
|
|
868
827
|
Z = []
|
|
869
828
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
870
|
-
for
|
|
871
|
-
zns = self._cell_response(
|
|
829
|
+
for P_batch, _ in dataloader:
|
|
830
|
+
zns = self._cell_response(P_batch)
|
|
872
831
|
Z.append(tensor_to_numpy(zns))
|
|
873
832
|
pbar.update(1)
|
|
874
833
|
|
|
875
834
|
Z = np.concatenate(Z)
|
|
876
835
|
return Z
|
|
877
836
|
|
|
878
|
-
def get_metacell_response(self, factor_idx, perturb):
|
|
879
|
-
zs = self._get_codebook()
|
|
880
|
-
ps = convert_to_tensor(perturb, device=self.get_device())
|
|
881
|
-
ms = self.cell_factor_effect[factor_idx]([zs,ps])
|
|
882
|
-
return tensor_to_numpy(ms)
|
|
883
|
-
|
|
884
837
|
def _get_expression_response(self, delta_zs):
|
|
885
838
|
return self.decoder_concentrate(delta_zs)
|
|
886
839
|
|
|
@@ -905,36 +858,28 @@ class PerturbFlow(nn.Module):
|
|
|
905
858
|
R = np.concatenate(R)
|
|
906
859
|
return R
|
|
907
860
|
|
|
908
|
-
def _count(self,concentrate, library_size=None):
|
|
861
|
+
def _count(self, concentrate, library_size=None):
|
|
909
862
|
if self.loss_func == 'bernoulli':
|
|
910
863
|
#counts = self.sigmoid(concentrate)
|
|
911
864
|
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
865
|
+
elif self.loss_func == 'multinomial':
|
|
866
|
+
theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
|
|
867
|
+
counts = theta * library_size
|
|
912
868
|
else:
|
|
913
869
|
rate = concentrate.exp()
|
|
914
870
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
915
871
|
counts = theta * library_size
|
|
916
|
-
#counts = dist.Poisson(rate=rate).to_event(1).mean
|
|
917
|
-
return counts
|
|
918
|
-
|
|
919
|
-
def _count_sample(self,concentrate):
|
|
920
|
-
if self.loss_func == 'bernoulli':
|
|
921
|
-
logits = concentrate
|
|
922
|
-
counts = dist.Bernoulli(logits=logits).to_event(1).sample()
|
|
923
|
-
else:
|
|
924
|
-
counts = self._count(concentrate=concentrate)
|
|
925
|
-
counts = dist.Poisson(rate=counts).to_event(1).sample()
|
|
926
872
|
return counts
|
|
927
873
|
|
|
928
874
|
def get_counts(self, zs, library_sizes,
|
|
929
|
-
batch_size: int = 1024
|
|
930
|
-
use_sampler: bool = False):
|
|
875
|
+
batch_size: int = 1024):
|
|
931
876
|
|
|
932
877
|
zs = convert_to_tensor(zs, device=self.get_device())
|
|
933
878
|
|
|
934
879
|
if type(library_sizes) == list:
|
|
935
|
-
library_sizes = np.array(library_sizes).
|
|
880
|
+
library_sizes = np.array(library_sizes).reshape(-1,1)
|
|
936
881
|
elif len(library_sizes.shape)==1:
|
|
937
|
-
library_sizes = library_sizes.
|
|
882
|
+
library_sizes = library_sizes.reshape(-1,1)
|
|
938
883
|
ls = convert_to_tensor(library_sizes, device=self.get_device())
|
|
939
884
|
|
|
940
885
|
dataset = CustomDataset2(zs,ls)
|
|
@@ -944,10 +889,7 @@ class PerturbFlow(nn.Module):
|
|
|
944
889
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
945
890
|
for Z_batch, L_batch, _ in dataloader:
|
|
946
891
|
concentrate = self._get_expression_response(Z_batch)
|
|
947
|
-
|
|
948
|
-
counts = self._count_sample(concentrate)
|
|
949
|
-
else:
|
|
950
|
-
counts = self._count(concentrate, L_batch)
|
|
892
|
+
counts = self._count(concentrate, L_batch)
|
|
951
893
|
E.append(tensor_to_numpy(counts))
|
|
952
894
|
pbar.update(1)
|
|
953
895
|
|
|
@@ -970,7 +912,7 @@ class PerturbFlow(nn.Module):
|
|
|
970
912
|
us = None,
|
|
971
913
|
ys = None,
|
|
972
914
|
zs = None,
|
|
973
|
-
num_epochs: int =
|
|
915
|
+
num_epochs: int = 500,
|
|
974
916
|
learning_rate: float = 0.0001,
|
|
975
917
|
batch_size: int = 256,
|
|
976
918
|
algo: Literal['adam','rmsprop','adamw'] = 'adam',
|
|
@@ -981,7 +923,7 @@ class PerturbFlow(nn.Module):
|
|
|
981
923
|
threshold: int = 0,
|
|
982
924
|
use_jax: bool = True):
|
|
983
925
|
"""
|
|
984
|
-
Train the
|
|
926
|
+
Train the PerturbE model.
|
|
985
927
|
|
|
986
928
|
Parameters
|
|
987
929
|
----------
|
|
@@ -1007,7 +949,7 @@ class PerturbFlow(nn.Module):
|
|
|
1007
949
|
Parameter for optimization.
|
|
1008
950
|
use_jax
|
|
1009
951
|
If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
|
|
1010
|
-
the Python script or Jupyter notebook. It is OK if it is used when runing
|
|
952
|
+
the Python script or Jupyter notebook. It is OK if it is used when runing PerturbE in the shell command.
|
|
1011
953
|
"""
|
|
1012
954
|
xs = self.preprocess(xs, threshold=threshold)
|
|
1013
955
|
xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
|
|
@@ -1125,12 +1067,12 @@ class PerturbFlow(nn.Module):
|
|
|
1125
1067
|
|
|
1126
1068
|
|
|
1127
1069
|
EXAMPLE_RUN = (
|
|
1128
|
-
"example run:
|
|
1070
|
+
"example run: PerturbE --help"
|
|
1129
1071
|
)
|
|
1130
1072
|
|
|
1131
1073
|
def parse_args():
|
|
1132
1074
|
parser = argparse.ArgumentParser(
|
|
1133
|
-
description="
|
|
1075
|
+
description="PerturbE\n{}".format(EXAMPLE_RUN))
|
|
1134
1076
|
|
|
1135
1077
|
parser.add_argument(
|
|
1136
1078
|
"--cuda", action="store_true", help="use GPU(s) to speed up training"
|
|
@@ -1317,7 +1259,7 @@ def main():
|
|
|
1317
1259
|
cell_factor_size = 0 if us is None else us.shape[1]
|
|
1318
1260
|
|
|
1319
1261
|
###########################################
|
|
1320
|
-
|
|
1262
|
+
perturbe = PerturbE(
|
|
1321
1263
|
input_size=input_size,
|
|
1322
1264
|
cell_factor_size=cell_factor_size,
|
|
1323
1265
|
inverse_dispersion=args.inverse_dispersion,
|
|
@@ -1336,7 +1278,7 @@ def main():
|
|
|
1336
1278
|
dtype=dtype,
|
|
1337
1279
|
)
|
|
1338
1280
|
|
|
1339
|
-
|
|
1281
|
+
perturbe.fit(xs, us=us,
|
|
1340
1282
|
num_epochs=args.num_epochs,
|
|
1341
1283
|
learning_rate=args.learning_rate,
|
|
1342
1284
|
batch_size=args.batch_size,
|
|
@@ -1348,12 +1290,11 @@ def main():
|
|
|
1348
1290
|
|
|
1349
1291
|
if args.save_model is not None:
|
|
1350
1292
|
if args.save_model.endswith('gz'):
|
|
1351
|
-
|
|
1293
|
+
PerturbE.save_model(perturbe, args.save_model, compression=True)
|
|
1352
1294
|
else:
|
|
1353
|
-
|
|
1295
|
+
PerturbE.save_model(perturbe, args.save_model)
|
|
1354
1296
|
|
|
1355
1297
|
|
|
1356
1298
|
|
|
1357
1299
|
if __name__ == "__main__":
|
|
1358
|
-
|
|
1359
1300
|
main()
|