SURE-tools 2.2.2__py3-none-any.whl → 2.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/DensityFlow.py +103 -74
- SURE/{PerturbFlow.py → PerturbE.py} +51 -110
- SURE/TranscriptomeDecoder.py +527 -0
- SURE/__init__.py +5 -1
- SURE/perturb/perturb.py +27 -1
- SURE/utils/custom_mlp.py +39 -2
- {sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/METADATA +1 -1
- {sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/RECORD +12 -11
- {sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/WHEEL +0 -0
- {sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.2.2.dist-info → sure_tools-2.4.3.dist-info}/top_level.txt +0 -0
SURE/DensityFlow.py
CHANGED
|
@@ -59,12 +59,13 @@ class DensityFlow(nn.Module):
|
|
|
59
59
|
input_size: int,
|
|
60
60
|
codebook_size: int = 200,
|
|
61
61
|
cell_factor_size: int = 0,
|
|
62
|
+
turn_off_cell_specific: bool = False,
|
|
62
63
|
supervised_mode: bool = False,
|
|
63
64
|
z_dim: int = 10,
|
|
64
65
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
65
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
66
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
|
|
66
67
|
inverse_dispersion: float = 10.0,
|
|
67
|
-
use_zeroinflate: bool =
|
|
68
|
+
use_zeroinflate: bool = False,
|
|
68
69
|
hidden_layers: list = [500],
|
|
69
70
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
70
71
|
nn_dropout: float = 0.1,
|
|
@@ -102,6 +103,7 @@ class DensityFlow(nn.Module):
|
|
|
102
103
|
else:
|
|
103
104
|
self.use_bias = [not zero_bias] * self.cell_factor_size
|
|
104
105
|
#self.use_bias = not zero_bias
|
|
106
|
+
self.turn_off_cell_specific = turn_off_cell_specific
|
|
105
107
|
|
|
106
108
|
self.codebook_weights = None
|
|
107
109
|
|
|
@@ -203,27 +205,51 @@ class DensityFlow(nn.Module):
|
|
|
203
205
|
self.cell_factor_effect = nn.ModuleList()
|
|
204
206
|
for i in np.arange(self.cell_factor_size):
|
|
205
207
|
if self.use_bias[i]:
|
|
206
|
-
self.
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
208
|
+
if self.turn_off_cell_specific:
|
|
209
|
+
self.cell_factor_effect.append(MLP(
|
|
210
|
+
[1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
211
|
+
activation=activate_fct,
|
|
212
|
+
output_activation=None,
|
|
213
|
+
post_layer_fct=post_layer_fct,
|
|
214
|
+
post_act_fct=post_act_fct,
|
|
215
|
+
allow_broadcast=self.allow_broadcast,
|
|
216
|
+
use_cuda=self.use_cuda,
|
|
217
|
+
)
|
|
218
|
+
)
|
|
219
|
+
else:
|
|
220
|
+
self.cell_factor_effect.append(MLP(
|
|
221
|
+
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
222
|
+
activation=activate_fct,
|
|
223
|
+
output_activation=None,
|
|
224
|
+
post_layer_fct=post_layer_fct,
|
|
225
|
+
post_act_fct=post_act_fct,
|
|
226
|
+
allow_broadcast=self.allow_broadcast,
|
|
227
|
+
use_cuda=self.use_cuda,
|
|
228
|
+
)
|
|
214
229
|
)
|
|
215
|
-
)
|
|
216
230
|
else:
|
|
217
|
-
self.
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
231
|
+
if self.turn_off_cell_specific:
|
|
232
|
+
self.cell_factor_effect.append(ZeroBiasMLP(
|
|
233
|
+
[1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
234
|
+
activation=activate_fct,
|
|
235
|
+
output_activation=None,
|
|
236
|
+
post_layer_fct=post_layer_fct,
|
|
237
|
+
post_act_fct=post_act_fct,
|
|
238
|
+
allow_broadcast=self.allow_broadcast,
|
|
239
|
+
use_cuda=self.use_cuda,
|
|
240
|
+
)
|
|
241
|
+
)
|
|
242
|
+
else:
|
|
243
|
+
self.cell_factor_effect.append(ZeroBiasMLP(
|
|
244
|
+
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
245
|
+
activation=activate_fct,
|
|
246
|
+
output_activation=None,
|
|
247
|
+
post_layer_fct=post_layer_fct,
|
|
248
|
+
post_act_fct=post_act_fct,
|
|
249
|
+
allow_broadcast=self.allow_broadcast,
|
|
250
|
+
use_cuda=self.use_cuda,
|
|
251
|
+
)
|
|
225
252
|
)
|
|
226
|
-
)
|
|
227
253
|
|
|
228
254
|
self.decoder_concentrate = MLP(
|
|
229
255
|
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
|
@@ -370,7 +396,8 @@ class DensityFlow(nn.Module):
|
|
|
370
396
|
else:
|
|
371
397
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
372
398
|
elif self.loss_func == 'multinomial':
|
|
373
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
399
|
+
#pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
400
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
374
401
|
elif self.loss_func == 'bernoulli':
|
|
375
402
|
if self.use_zeroinflate:
|
|
376
403
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -454,7 +481,8 @@ class DensityFlow(nn.Module):
|
|
|
454
481
|
else:
|
|
455
482
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
456
483
|
elif self.loss_func == 'multinomial':
|
|
457
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
484
|
+
#pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
485
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
458
486
|
elif self.loss_func == 'bernoulli':
|
|
459
487
|
if self.use_zeroinflate:
|
|
460
488
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -550,7 +578,8 @@ class DensityFlow(nn.Module):
|
|
|
550
578
|
else:
|
|
551
579
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
552
580
|
elif self.loss_func == 'multinomial':
|
|
553
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
581
|
+
#pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
582
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
554
583
|
elif self.loss_func == 'bernoulli':
|
|
555
584
|
if self.use_zeroinflate:
|
|
556
585
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -656,7 +685,8 @@ class DensityFlow(nn.Module):
|
|
|
656
685
|
else:
|
|
657
686
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
658
687
|
elif self.loss_func == 'multinomial':
|
|
659
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
688
|
+
#pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
689
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
660
690
|
elif self.loss_func == 'bernoulli':
|
|
661
691
|
if self.use_zeroinflate:
|
|
662
692
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -676,9 +706,17 @@ class DensityFlow(nn.Module):
|
|
|
676
706
|
zus = None
|
|
677
707
|
for i in np.arange(self.cell_factor_size):
|
|
678
708
|
if i==0:
|
|
679
|
-
|
|
709
|
+
#if self.turn_off_cell_specific:
|
|
710
|
+
# zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
|
|
711
|
+
#else:
|
|
712
|
+
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
713
|
+
zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
|
|
680
714
|
else:
|
|
681
|
-
|
|
715
|
+
#if self.turn_off_cell_specific:
|
|
716
|
+
# zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
|
|
717
|
+
#else:
|
|
718
|
+
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
719
|
+
zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
|
|
682
720
|
return zus
|
|
683
721
|
|
|
684
722
|
def _get_codebook_identity(self):
|
|
@@ -696,7 +734,7 @@ class DensityFlow(nn.Module):
|
|
|
696
734
|
"""
|
|
697
735
|
Return the mean part of metacell codebook
|
|
698
736
|
"""
|
|
699
|
-
cb = self.
|
|
737
|
+
cb = self._get_codebook()
|
|
700
738
|
cb = tensor_to_numpy(cb)
|
|
701
739
|
return cb
|
|
702
740
|
|
|
@@ -820,13 +858,15 @@ class DensityFlow(nn.Module):
|
|
|
820
858
|
us_i = us[:,pert_idx].reshape(-1,1)
|
|
821
859
|
|
|
822
860
|
# factor effect of xs
|
|
823
|
-
dzs0 = self.get_cell_response(
|
|
861
|
+
dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)
|
|
824
862
|
|
|
825
863
|
# perturbation effect
|
|
826
864
|
ps = np.ones_like(us_i)
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
865
|
+
if np.sum(np.abs(ps-us_i))>=1:
|
|
866
|
+
dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
|
|
867
|
+
zs = zs + dzs0 + dzs
|
|
868
|
+
else:
|
|
869
|
+
zs = zs + dzs0
|
|
830
870
|
|
|
831
871
|
if library_sizes is None:
|
|
832
872
|
library_sizes = np.sum(xs, axis=1, keepdims=True)
|
|
@@ -840,47 +880,48 @@ class DensityFlow(nn.Module):
|
|
|
840
880
|
|
|
841
881
|
return counts, zs
|
|
842
882
|
|
|
843
|
-
def _cell_response(self,
|
|
883
|
+
def _cell_response(self, zs, perturb_idx, perturb):
|
|
844
884
|
#zns,_ = self.encoder_zn(xs)
|
|
845
|
-
zns,_ = self._get_basal_embedding(xs)
|
|
885
|
+
#zns,_ = self._get_basal_embedding(xs)
|
|
886
|
+
zns = zs
|
|
846
887
|
if perturb.ndim==2:
|
|
847
|
-
|
|
888
|
+
if self.turn_off_cell_specific:
|
|
889
|
+
ms = self.cell_factor_effect[perturb_idx](perturb)
|
|
890
|
+
else:
|
|
891
|
+
ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
|
|
848
892
|
else:
|
|
849
|
-
|
|
893
|
+
if self.turn_off_cell_specific:
|
|
894
|
+
ms = self.cell_factor_effect[perturb_idx](perturb.reshape(-1,1))
|
|
895
|
+
else:
|
|
896
|
+
ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])
|
|
850
897
|
|
|
851
898
|
return ms
|
|
852
899
|
|
|
853
900
|
def get_cell_response(self,
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
901
|
+
zs,
|
|
902
|
+
perturb_idx,
|
|
903
|
+
perturb_us,
|
|
857
904
|
batch_size: int = 1024):
|
|
858
905
|
"""
|
|
859
906
|
Return cells' changes in the latent space induced by specific perturbation of a factor
|
|
860
907
|
|
|
861
908
|
"""
|
|
862
|
-
xs = self.preprocess(xs)
|
|
863
|
-
|
|
864
|
-
ps = convert_to_tensor(
|
|
865
|
-
dataset = CustomDataset2(
|
|
909
|
+
#xs = self.preprocess(xs)
|
|
910
|
+
zs = convert_to_tensor(zs, device=self.get_device())
|
|
911
|
+
ps = convert_to_tensor(perturb_us, device=self.get_device())
|
|
912
|
+
dataset = CustomDataset2(zs,ps)
|
|
866
913
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
867
914
|
|
|
868
915
|
Z = []
|
|
869
916
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
870
|
-
for
|
|
871
|
-
zns = self._cell_response(
|
|
917
|
+
for Z_batch, P_batch, _ in dataloader:
|
|
918
|
+
zns = self._cell_response(Z_batch, perturb_idx, P_batch)
|
|
872
919
|
Z.append(tensor_to_numpy(zns))
|
|
873
920
|
pbar.update(1)
|
|
874
921
|
|
|
875
922
|
Z = np.concatenate(Z)
|
|
876
923
|
return Z
|
|
877
924
|
|
|
878
|
-
def get_metacell_response(self, factor_idx, perturb):
|
|
879
|
-
zs = self._get_codebook()
|
|
880
|
-
ps = convert_to_tensor(perturb, device=self.get_device())
|
|
881
|
-
ms = self.cell_factor_effect[factor_idx]([zs,ps])
|
|
882
|
-
return tensor_to_numpy(ms)
|
|
883
|
-
|
|
884
925
|
def _get_expression_response(self, delta_zs):
|
|
885
926
|
return self.decoder_concentrate(delta_zs)
|
|
886
927
|
|
|
@@ -905,36 +946,28 @@ class DensityFlow(nn.Module):
|
|
|
905
946
|
R = np.concatenate(R)
|
|
906
947
|
return R
|
|
907
948
|
|
|
908
|
-
def _count(self,concentrate, library_size=None):
|
|
949
|
+
def _count(self, concentrate, library_size=None):
|
|
909
950
|
if self.loss_func == 'bernoulli':
|
|
910
951
|
#counts = self.sigmoid(concentrate)
|
|
911
952
|
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
953
|
+
elif self.loss_func == 'multinomial':
|
|
954
|
+
theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
|
|
955
|
+
counts = theta * library_size
|
|
912
956
|
else:
|
|
913
957
|
rate = concentrate.exp()
|
|
914
958
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
915
959
|
counts = theta * library_size
|
|
916
|
-
#counts = dist.Poisson(rate=rate).to_event(1).mean
|
|
917
|
-
return counts
|
|
918
|
-
|
|
919
|
-
def _count_sample(self,concentrate):
|
|
920
|
-
if self.loss_func == 'bernoulli':
|
|
921
|
-
logits = concentrate
|
|
922
|
-
counts = dist.Bernoulli(logits=logits).to_event(1).sample()
|
|
923
|
-
else:
|
|
924
|
-
counts = self._count(concentrate=concentrate)
|
|
925
|
-
counts = dist.Poisson(rate=counts).to_event(1).sample()
|
|
926
960
|
return counts
|
|
927
961
|
|
|
928
962
|
def get_counts(self, zs, library_sizes,
|
|
929
|
-
batch_size: int = 1024
|
|
930
|
-
use_sampler: bool = False):
|
|
963
|
+
batch_size: int = 1024):
|
|
931
964
|
|
|
932
965
|
zs = convert_to_tensor(zs, device=self.get_device())
|
|
933
966
|
|
|
934
967
|
if type(library_sizes) == list:
|
|
935
|
-
library_sizes = np.array(library_sizes).
|
|
968
|
+
library_sizes = np.array(library_sizes).reshape(-1,1)
|
|
936
969
|
elif len(library_sizes.shape)==1:
|
|
937
|
-
library_sizes = library_sizes.
|
|
970
|
+
library_sizes = library_sizes.reshape(-1,1)
|
|
938
971
|
ls = convert_to_tensor(library_sizes, device=self.get_device())
|
|
939
972
|
|
|
940
973
|
dataset = CustomDataset2(zs,ls)
|
|
@@ -944,10 +977,7 @@ class DensityFlow(nn.Module):
|
|
|
944
977
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
945
978
|
for Z_batch, L_batch, _ in dataloader:
|
|
946
979
|
concentrate = self._get_expression_response(Z_batch)
|
|
947
|
-
|
|
948
|
-
counts = self._count_sample(concentrate)
|
|
949
|
-
else:
|
|
950
|
-
counts = self._count(concentrate, L_batch)
|
|
980
|
+
counts = self._count(concentrate, L_batch)
|
|
951
981
|
E.append(tensor_to_numpy(counts))
|
|
952
982
|
pbar.update(1)
|
|
953
983
|
|
|
@@ -1317,7 +1347,7 @@ def main():
|
|
|
1317
1347
|
cell_factor_size = 0 if us is None else us.shape[1]
|
|
1318
1348
|
|
|
1319
1349
|
###########################################
|
|
1320
|
-
|
|
1350
|
+
df = DensityFlow(
|
|
1321
1351
|
input_size=input_size,
|
|
1322
1352
|
cell_factor_size=cell_factor_size,
|
|
1323
1353
|
inverse_dispersion=args.inverse_dispersion,
|
|
@@ -1336,7 +1366,7 @@ def main():
|
|
|
1336
1366
|
dtype=dtype,
|
|
1337
1367
|
)
|
|
1338
1368
|
|
|
1339
|
-
|
|
1369
|
+
df.fit(xs, us=us,
|
|
1340
1370
|
num_epochs=args.num_epochs,
|
|
1341
1371
|
learning_rate=args.learning_rate,
|
|
1342
1372
|
batch_size=args.batch_size,
|
|
@@ -1348,12 +1378,11 @@ def main():
|
|
|
1348
1378
|
|
|
1349
1379
|
if args.save_model is not None:
|
|
1350
1380
|
if args.save_model.endswith('gz'):
|
|
1351
|
-
DensityFlow.save_model(
|
|
1381
|
+
DensityFlow.save_model(df, args.save_model, compression=True)
|
|
1352
1382
|
else:
|
|
1353
|
-
DensityFlow.save_model(
|
|
1383
|
+
DensityFlow.save_model(df, args.save_model)
|
|
1354
1384
|
|
|
1355
1385
|
|
|
1356
1386
|
|
|
1357
1387
|
if __name__ == "__main__":
|
|
1358
|
-
|
|
1359
1388
|
main()
|