SURE-tools 2.1.83__py3-none-any.whl → 2.2.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- SURE/{PerturbFlow.py → DensityFlow.py} +103 -83
- SURE/SURE.py +6 -6
- SURE/__init__.py +3 -3
- SURE/flow/flow_stats.py +12 -0
- SURE/perturb/perturb.py +27 -1
- SURE/utils/custom_mlp.py +4 -1
- {sure_tools-2.1.83.dist-info → sure_tools-2.2.23.dist-info}/METADATA +1 -1
- {sure_tools-2.1.83.dist-info → sure_tools-2.2.23.dist-info}/RECORD +12 -12
- {sure_tools-2.1.83.dist-info → sure_tools-2.2.23.dist-info}/WHEEL +0 -0
- {sure_tools-2.1.83.dist-info → sure_tools-2.2.23.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.1.83.dist-info → sure_tools-2.2.23.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.1.83.dist-info → sure_tools-2.2.23.dist-info}/top_level.txt +0 -0
|
@@ -54,18 +54,19 @@ def set_random_seed(seed):
|
|
|
54
54
|
# Set seed for Pyro
|
|
55
55
|
pyro.set_rng_seed(seed)
|
|
56
56
|
|
|
57
|
-
class
|
|
57
|
+
class DensityFlow(nn.Module):
|
|
58
58
|
def __init__(self,
|
|
59
59
|
input_size: int,
|
|
60
60
|
codebook_size: int = 200,
|
|
61
61
|
cell_factor_size: int = 0,
|
|
62
|
+
turn_off_cell_specific: bool = False,
|
|
62
63
|
supervised_mode: bool = False,
|
|
63
64
|
z_dim: int = 10,
|
|
64
65
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
65
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
66
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
|
|
66
67
|
inverse_dispersion: float = 10.0,
|
|
67
68
|
use_zeroinflate: bool = False,
|
|
68
|
-
hidden_layers: list = [
|
|
69
|
+
hidden_layers: list = [500],
|
|
69
70
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
70
71
|
nn_dropout: float = 0.1,
|
|
71
72
|
post_layer_fct: list = ['layernorm'],
|
|
@@ -102,6 +103,7 @@ class PerturbFlow(nn.Module):
|
|
|
102
103
|
else:
|
|
103
104
|
self.use_bias = [not zero_bias] * self.cell_factor_size
|
|
104
105
|
#self.use_bias = not zero_bias
|
|
106
|
+
self.turn_off_cell_specific = turn_off_cell_specific
|
|
105
107
|
|
|
106
108
|
self.codebook_weights = None
|
|
107
109
|
|
|
@@ -203,27 +205,51 @@ class PerturbFlow(nn.Module):
|
|
|
203
205
|
self.cell_factor_effect = nn.ModuleList()
|
|
204
206
|
for i in np.arange(self.cell_factor_size):
|
|
205
207
|
if self.use_bias[i]:
|
|
206
|
-
self.
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
208
|
+
if self.turn_off_cell_specific:
|
|
209
|
+
self.cell_factor_effect.append(MLP(
|
|
210
|
+
[1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
211
|
+
activation=activate_fct,
|
|
212
|
+
output_activation=None,
|
|
213
|
+
post_layer_fct=post_layer_fct,
|
|
214
|
+
post_act_fct=post_act_fct,
|
|
215
|
+
allow_broadcast=self.allow_broadcast,
|
|
216
|
+
use_cuda=self.use_cuda,
|
|
217
|
+
)
|
|
218
|
+
)
|
|
219
|
+
else:
|
|
220
|
+
self.cell_factor_effect.append(MLP(
|
|
221
|
+
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
222
|
+
activation=activate_fct,
|
|
223
|
+
output_activation=None,
|
|
224
|
+
post_layer_fct=post_layer_fct,
|
|
225
|
+
post_act_fct=post_act_fct,
|
|
226
|
+
allow_broadcast=self.allow_broadcast,
|
|
227
|
+
use_cuda=self.use_cuda,
|
|
228
|
+
)
|
|
214
229
|
)
|
|
215
|
-
)
|
|
216
230
|
else:
|
|
217
|
-
self.
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
231
|
+
if self.turn_off_cell_specific:
|
|
232
|
+
self.cell_factor_effect.append(ZeroBiasMLP(
|
|
233
|
+
[1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
234
|
+
activation=activate_fct,
|
|
235
|
+
output_activation=None,
|
|
236
|
+
post_layer_fct=post_layer_fct,
|
|
237
|
+
post_act_fct=post_act_fct,
|
|
238
|
+
allow_broadcast=self.allow_broadcast,
|
|
239
|
+
use_cuda=self.use_cuda,
|
|
240
|
+
)
|
|
241
|
+
)
|
|
242
|
+
else:
|
|
243
|
+
self.cell_factor_effect.append(ZeroBiasMLP(
|
|
244
|
+
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
245
|
+
activation=activate_fct,
|
|
246
|
+
output_activation=None,
|
|
247
|
+
post_layer_fct=post_layer_fct,
|
|
248
|
+
post_act_fct=post_act_fct,
|
|
249
|
+
allow_broadcast=self.allow_broadcast,
|
|
250
|
+
use_cuda=self.use_cuda,
|
|
251
|
+
)
|
|
225
252
|
)
|
|
226
|
-
)
|
|
227
253
|
|
|
228
254
|
self.decoder_concentrate = MLP(
|
|
229
255
|
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
|
@@ -308,7 +334,7 @@ class PerturbFlow(nn.Module):
|
|
|
308
334
|
return xs
|
|
309
335
|
|
|
310
336
|
def model1(self, xs):
|
|
311
|
-
pyro.module('
|
|
337
|
+
pyro.module('DensityFlow', self)
|
|
312
338
|
|
|
313
339
|
eps = torch.finfo(xs.dtype).eps
|
|
314
340
|
batch_size = xs.size(0)
|
|
@@ -387,7 +413,7 @@ class PerturbFlow(nn.Module):
|
|
|
387
413
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
388
414
|
|
|
389
415
|
def model2(self, xs, us=None):
|
|
390
|
-
pyro.module('
|
|
416
|
+
pyro.module('DensityFlow', self)
|
|
391
417
|
|
|
392
418
|
eps = torch.finfo(xs.dtype).eps
|
|
393
419
|
batch_size = xs.size(0)
|
|
@@ -429,7 +455,7 @@ class PerturbFlow(nn.Module):
|
|
|
429
455
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
430
456
|
|
|
431
457
|
if self.cell_factor_size>0:
|
|
432
|
-
zus = self._total_effects(
|
|
458
|
+
zus = self._total_effects(zns, us)
|
|
433
459
|
zs = zns+zus
|
|
434
460
|
else:
|
|
435
461
|
zs = zns
|
|
@@ -471,7 +497,7 @@ class PerturbFlow(nn.Module):
|
|
|
471
497
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
472
498
|
|
|
473
499
|
def model3(self, xs, ys, embeds=None):
|
|
474
|
-
pyro.module('
|
|
500
|
+
pyro.module('DensityFlow', self)
|
|
475
501
|
|
|
476
502
|
eps = torch.finfo(xs.dtype).eps
|
|
477
503
|
batch_size = xs.size(0)
|
|
@@ -567,7 +593,7 @@ class PerturbFlow(nn.Module):
|
|
|
567
593
|
zns = embeds
|
|
568
594
|
|
|
569
595
|
def model4(self, xs, us, ys, embeds=None):
|
|
570
|
-
pyro.module('
|
|
596
|
+
pyro.module('DensityFlow', self)
|
|
571
597
|
|
|
572
598
|
eps = torch.finfo(xs.dtype).eps
|
|
573
599
|
batch_size = xs.size(0)
|
|
@@ -631,7 +657,7 @@ class PerturbFlow(nn.Module):
|
|
|
631
657
|
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
632
658
|
# else:
|
|
633
659
|
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
634
|
-
zus = self._total_effects(
|
|
660
|
+
zus = self._total_effects(zns, us)
|
|
635
661
|
zs = zns+zus
|
|
636
662
|
else:
|
|
637
663
|
zs = zns
|
|
@@ -676,9 +702,17 @@ class PerturbFlow(nn.Module):
|
|
|
676
702
|
zus = None
|
|
677
703
|
for i in np.arange(self.cell_factor_size):
|
|
678
704
|
if i==0:
|
|
679
|
-
|
|
705
|
+
#if self.turn_off_cell_specific:
|
|
706
|
+
# zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
|
|
707
|
+
#else:
|
|
708
|
+
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
709
|
+
zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
|
|
680
710
|
else:
|
|
681
|
-
|
|
711
|
+
#if self.turn_off_cell_specific:
|
|
712
|
+
# zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
|
|
713
|
+
#else:
|
|
714
|
+
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
715
|
+
zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
|
|
682
716
|
return zus
|
|
683
717
|
|
|
684
718
|
def _get_codebook_identity(self):
|
|
@@ -696,7 +730,7 @@ class PerturbFlow(nn.Module):
|
|
|
696
730
|
"""
|
|
697
731
|
Return the mean part of metacell codebook
|
|
698
732
|
"""
|
|
699
|
-
cb = self.
|
|
733
|
+
cb = self._get_codebook()
|
|
700
734
|
cb = tensor_to_numpy(cb)
|
|
701
735
|
return cb
|
|
702
736
|
|
|
@@ -820,13 +854,15 @@ class PerturbFlow(nn.Module):
|
|
|
820
854
|
us_i = us[:,pert_idx].reshape(-1,1)
|
|
821
855
|
|
|
822
856
|
# factor effect of xs
|
|
823
|
-
dzs0 = self.get_cell_response(
|
|
857
|
+
dzs0 = self.get_cell_response(zs, factor_idx=pert_idx, perturb=us_i)
|
|
824
858
|
|
|
825
859
|
# perturbation effect
|
|
826
860
|
ps = np.ones_like(us_i)
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
861
|
+
if np.sum(np.abs(ps-us_i))>=1:
|
|
862
|
+
dzs = self.get_cell_response(zs, factor_idx=pert_idx, perturb=ps)
|
|
863
|
+
zs = zs + dzs0 + dzs
|
|
864
|
+
else:
|
|
865
|
+
zs = zs + dzs0
|
|
830
866
|
|
|
831
867
|
if library_sizes is None:
|
|
832
868
|
library_sizes = np.sum(xs, axis=1, keepdims=True)
|
|
@@ -840,49 +876,48 @@ class PerturbFlow(nn.Module):
|
|
|
840
876
|
|
|
841
877
|
return counts, zs
|
|
842
878
|
|
|
843
|
-
def _cell_response(self,
|
|
879
|
+
def _cell_response(self, zs, perturb_idx, perturb):
|
|
844
880
|
#zns,_ = self.encoder_zn(xs)
|
|
845
881
|
#zns,_ = self._get_basal_embedding(xs)
|
|
846
|
-
zns =
|
|
882
|
+
zns = zs
|
|
847
883
|
if perturb.ndim==2:
|
|
848
|
-
|
|
884
|
+
if self.turn_off_cell_specific:
|
|
885
|
+
ms = self.cell_factor_effect[perturb_idx](perturb)
|
|
886
|
+
else:
|
|
887
|
+
ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
|
|
849
888
|
else:
|
|
850
|
-
|
|
889
|
+
if self.turn_off_cell_specific:
|
|
890
|
+
ms = self.cell_factor_effect[perturb_idx](perturb.reshape(-1,1))
|
|
891
|
+
else:
|
|
892
|
+
ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])
|
|
851
893
|
|
|
852
894
|
return ms
|
|
853
895
|
|
|
854
896
|
def get_cell_response(self,
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
897
|
+
zs,
|
|
898
|
+
perturb_idx,
|
|
899
|
+
perturb_us,
|
|
858
900
|
batch_size: int = 1024):
|
|
859
901
|
"""
|
|
860
902
|
Return cells' changes in the latent space induced by specific perturbation of a factor
|
|
861
903
|
|
|
862
904
|
"""
|
|
863
|
-
xs = self.preprocess(xs)
|
|
864
|
-
|
|
865
|
-
ps = convert_to_tensor(
|
|
866
|
-
dataset = CustomDataset2(
|
|
905
|
+
#xs = self.preprocess(xs)
|
|
906
|
+
zs = convert_to_tensor(zs, device=self.get_device())
|
|
907
|
+
ps = convert_to_tensor(perturb_us, device=self.get_device())
|
|
908
|
+
dataset = CustomDataset2(zs,ps)
|
|
867
909
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
868
910
|
|
|
869
911
|
Z = []
|
|
870
912
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
871
|
-
for
|
|
872
|
-
zns = self._cell_response(
|
|
913
|
+
for Z_batch, P_batch, _ in dataloader:
|
|
914
|
+
zns = self._cell_response(Z_batch, perturb_idx, P_batch)
|
|
873
915
|
Z.append(tensor_to_numpy(zns))
|
|
874
916
|
pbar.update(1)
|
|
875
917
|
|
|
876
918
|
Z = np.concatenate(Z)
|
|
877
919
|
return Z
|
|
878
920
|
|
|
879
|
-
def get_metacell_response(self, factor_idx, perturb):
|
|
880
|
-
#zs = self._get_codebook()
|
|
881
|
-
zs = self._get_codebook_identity()
|
|
882
|
-
ps = convert_to_tensor(perturb, device=self.get_device())
|
|
883
|
-
ms = self.cell_factor_effect[factor_idx]([zs,ps])
|
|
884
|
-
return tensor_to_numpy(ms)
|
|
885
|
-
|
|
886
921
|
def _get_expression_response(self, delta_zs):
|
|
887
922
|
return self.decoder_concentrate(delta_zs)
|
|
888
923
|
|
|
@@ -907,7 +942,7 @@ class PerturbFlow(nn.Module):
|
|
|
907
942
|
R = np.concatenate(R)
|
|
908
943
|
return R
|
|
909
944
|
|
|
910
|
-
def _count(self,concentrate, library_size=None):
|
|
945
|
+
def _count(self, concentrate, library_size=None):
|
|
911
946
|
if self.loss_func == 'bernoulli':
|
|
912
947
|
#counts = self.sigmoid(concentrate)
|
|
913
948
|
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
@@ -915,28 +950,17 @@ class PerturbFlow(nn.Module):
|
|
|
915
950
|
rate = concentrate.exp()
|
|
916
951
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
917
952
|
counts = theta * library_size
|
|
918
|
-
#counts = dist.Poisson(rate=rate).to_event(1).mean
|
|
919
|
-
return counts
|
|
920
|
-
|
|
921
|
-
def _count_sample(self,concentrate):
|
|
922
|
-
if self.loss_func == 'bernoulli':
|
|
923
|
-
logits = concentrate
|
|
924
|
-
counts = dist.Bernoulli(logits=logits).to_event(1).sample()
|
|
925
|
-
else:
|
|
926
|
-
counts = self._count(concentrate=concentrate)
|
|
927
|
-
counts = dist.Poisson(rate=counts).to_event(1).sample()
|
|
928
953
|
return counts
|
|
929
954
|
|
|
930
955
|
def get_counts(self, zs, library_sizes,
|
|
931
|
-
batch_size: int = 1024
|
|
932
|
-
use_sampler: bool = False):
|
|
956
|
+
batch_size: int = 1024):
|
|
933
957
|
|
|
934
958
|
zs = convert_to_tensor(zs, device=self.get_device())
|
|
935
959
|
|
|
936
960
|
if type(library_sizes) == list:
|
|
937
|
-
library_sizes = np.array(library_sizes).
|
|
961
|
+
library_sizes = np.array(library_sizes).reshape(-1,1)
|
|
938
962
|
elif len(library_sizes.shape)==1:
|
|
939
|
-
library_sizes = library_sizes.
|
|
963
|
+
library_sizes = library_sizes.reshape(-1,1)
|
|
940
964
|
ls = convert_to_tensor(library_sizes, device=self.get_device())
|
|
941
965
|
|
|
942
966
|
dataset = CustomDataset2(zs,ls)
|
|
@@ -946,10 +970,7 @@ class PerturbFlow(nn.Module):
|
|
|
946
970
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
947
971
|
for Z_batch, L_batch, _ in dataloader:
|
|
948
972
|
concentrate = self._get_expression_response(Z_batch)
|
|
949
|
-
|
|
950
|
-
counts = self._count_sample(concentrate)
|
|
951
|
-
else:
|
|
952
|
-
counts = self._count(concentrate, L_batch)
|
|
973
|
+
counts = self._count(concentrate, L_batch)
|
|
953
974
|
E.append(tensor_to_numpy(counts))
|
|
954
975
|
pbar.update(1)
|
|
955
976
|
|
|
@@ -972,7 +993,7 @@ class PerturbFlow(nn.Module):
|
|
|
972
993
|
us = None,
|
|
973
994
|
ys = None,
|
|
974
995
|
zs = None,
|
|
975
|
-
num_epochs: int =
|
|
996
|
+
num_epochs: int = 500,
|
|
976
997
|
learning_rate: float = 0.0001,
|
|
977
998
|
batch_size: int = 256,
|
|
978
999
|
algo: Literal['adam','rmsprop','adamw'] = 'adam',
|
|
@@ -983,7 +1004,7 @@ class PerturbFlow(nn.Module):
|
|
|
983
1004
|
threshold: int = 0,
|
|
984
1005
|
use_jax: bool = True):
|
|
985
1006
|
"""
|
|
986
|
-
Train the
|
|
1007
|
+
Train the DensityFlow model.
|
|
987
1008
|
|
|
988
1009
|
Parameters
|
|
989
1010
|
----------
|
|
@@ -1009,7 +1030,7 @@ class PerturbFlow(nn.Module):
|
|
|
1009
1030
|
Parameter for optimization.
|
|
1010
1031
|
use_jax
|
|
1011
1032
|
If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
|
|
1012
|
-
the Python script or Jupyter notebook. It is OK if it is used when runing
|
|
1033
|
+
the Python script or Jupyter notebook. It is OK if it is used when runing DensityFlow in the shell command.
|
|
1013
1034
|
"""
|
|
1014
1035
|
xs = self.preprocess(xs, threshold=threshold)
|
|
1015
1036
|
xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
|
|
@@ -1127,12 +1148,12 @@ class PerturbFlow(nn.Module):
|
|
|
1127
1148
|
|
|
1128
1149
|
|
|
1129
1150
|
EXAMPLE_RUN = (
|
|
1130
|
-
"example run:
|
|
1151
|
+
"example run: DensityFlow --help"
|
|
1131
1152
|
)
|
|
1132
1153
|
|
|
1133
1154
|
def parse_args():
|
|
1134
1155
|
parser = argparse.ArgumentParser(
|
|
1135
|
-
description="
|
|
1156
|
+
description="DensityFlow\n{}".format(EXAMPLE_RUN))
|
|
1136
1157
|
|
|
1137
1158
|
parser.add_argument(
|
|
1138
1159
|
"--cuda", action="store_true", help="use GPU(s) to speed up training"
|
|
@@ -1319,7 +1340,7 @@ def main():
|
|
|
1319
1340
|
cell_factor_size = 0 if us is None else us.shape[1]
|
|
1320
1341
|
|
|
1321
1342
|
###########################################
|
|
1322
|
-
|
|
1343
|
+
DensityFlow = DensityFlow(
|
|
1323
1344
|
input_size=input_size,
|
|
1324
1345
|
cell_factor_size=cell_factor_size,
|
|
1325
1346
|
inverse_dispersion=args.inverse_dispersion,
|
|
@@ -1338,7 +1359,7 @@ def main():
|
|
|
1338
1359
|
dtype=dtype,
|
|
1339
1360
|
)
|
|
1340
1361
|
|
|
1341
|
-
|
|
1362
|
+
DensityFlow.fit(xs, us=us,
|
|
1342
1363
|
num_epochs=args.num_epochs,
|
|
1343
1364
|
learning_rate=args.learning_rate,
|
|
1344
1365
|
batch_size=args.batch_size,
|
|
@@ -1350,12 +1371,11 @@ def main():
|
|
|
1350
1371
|
|
|
1351
1372
|
if args.save_model is not None:
|
|
1352
1373
|
if args.save_model.endswith('gz'):
|
|
1353
|
-
|
|
1374
|
+
DensityFlow.save_model(DensityFlow, args.save_model, compression=True)
|
|
1354
1375
|
else:
|
|
1355
|
-
|
|
1376
|
+
DensityFlow.save_model(DensityFlow, args.save_model)
|
|
1356
1377
|
|
|
1357
1378
|
|
|
1358
1379
|
|
|
1359
1380
|
if __name__ == "__main__":
|
|
1360
|
-
|
|
1361
1381
|
main()
|
SURE/SURE.py
CHANGED
|
@@ -99,17 +99,17 @@ class SURE(nn.Module):
|
|
|
99
99
|
cell_factor_size: int = 0,
|
|
100
100
|
supervised_mode: bool = False,
|
|
101
101
|
z_dim: int = 10,
|
|
102
|
-
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = '
|
|
103
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
102
|
+
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
103
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
|
|
104
104
|
inverse_dispersion: float = 10.0,
|
|
105
105
|
use_zeroinflate: bool = True,
|
|
106
|
-
hidden_layers: list = [
|
|
106
|
+
hidden_layers: list = [500],
|
|
107
107
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
108
108
|
nn_dropout: float = 0.1,
|
|
109
109
|
post_layer_fct: list = ['layernorm'],
|
|
110
110
|
post_act_fct: list = None,
|
|
111
111
|
config_enum: str = 'parallel',
|
|
112
|
-
use_cuda: bool =
|
|
112
|
+
use_cuda: bool = True,
|
|
113
113
|
seed: int = 42,
|
|
114
114
|
dtype = torch.float32, # type: ignore
|
|
115
115
|
):
|
|
@@ -817,7 +817,7 @@ class SURE(nn.Module):
|
|
|
817
817
|
us = None,
|
|
818
818
|
ys = None,
|
|
819
819
|
zs = None,
|
|
820
|
-
num_epochs: int =
|
|
820
|
+
num_epochs: int = 500,
|
|
821
821
|
learning_rate: float = 0.0001,
|
|
822
822
|
batch_size: int = 256,
|
|
823
823
|
algo: Literal['adam','rmsprop','adamw'] = 'adam',
|
|
@@ -826,7 +826,7 @@ class SURE(nn.Module):
|
|
|
826
826
|
decay_rate: float = 0.9,
|
|
827
827
|
config_enum: str = 'parallel',
|
|
828
828
|
threshold: int = 0,
|
|
829
|
-
use_jax: bool =
|
|
829
|
+
use_jax: bool = True):
|
|
830
830
|
"""
|
|
831
831
|
Train the SURE model.
|
|
832
832
|
|
SURE/__init__.py
CHANGED
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
from .SURE import SURE
|
|
2
|
-
from .
|
|
2
|
+
from .DensityFlow import DensityFlow
|
|
3
3
|
|
|
4
4
|
from . import utils
|
|
5
5
|
from . import codebook
|
|
6
6
|
from . import SURE
|
|
7
|
-
from . import
|
|
7
|
+
from . import DensityFlow
|
|
8
8
|
from . import atac
|
|
9
9
|
from . import flow
|
|
10
10
|
from . import perturb
|
|
11
11
|
|
|
12
|
-
__all__ = ['SURE', '
|
|
12
|
+
__all__ = ['SURE', 'DensityFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
|
SURE/flow/flow_stats.py
CHANGED
|
@@ -41,6 +41,18 @@ class VectorFieldEval:
|
|
|
41
41
|
divergence[np.isnan(divergence)] = 0
|
|
42
42
|
|
|
43
43
|
return divergence
|
|
44
|
+
|
|
45
|
+
def movement_stats(self,vectors):
|
|
46
|
+
return calculate_movement_stats(vectors)
|
|
47
|
+
|
|
48
|
+
def direction_stats(self, vectors):
|
|
49
|
+
return calculate_direction_stats(vectors)
|
|
50
|
+
|
|
51
|
+
def movement_energy(self, vectors, masses=None):
|
|
52
|
+
return calculate_movement_energy(vectors, masses)
|
|
53
|
+
|
|
54
|
+
def movement_divergence(self, positions, vectors):
|
|
55
|
+
return calculate_movement_divergence(positions, vectors)
|
|
44
56
|
|
|
45
57
|
|
|
46
58
|
def calculate_movement_stats(vectors):
|
SURE/perturb/perturb.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import re
|
|
2
2
|
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
3
4
|
from numba import njit
|
|
4
5
|
from itertools import chain
|
|
5
6
|
from joblib import Parallel, delayed
|
|
@@ -8,6 +9,8 @@ from typing import Literal
|
|
|
8
9
|
class LabelMatrix:
|
|
9
10
|
def __init__(self):
|
|
10
11
|
self.labels_ = None
|
|
12
|
+
self.control_label = None
|
|
13
|
+
self.sep_pattern = None
|
|
11
14
|
|
|
12
15
|
def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
|
|
13
16
|
if speedup=='none':
|
|
@@ -24,8 +27,31 @@ class LabelMatrix:
|
|
|
24
27
|
mat = np.delete(mat, idx, axis=1)
|
|
25
28
|
self.labels_ = np.delete(self.labels_, idx)
|
|
26
29
|
|
|
30
|
+
self.control_label = control_label
|
|
31
|
+
self.sep_pattern=sep_pattern
|
|
32
|
+
|
|
27
33
|
return mat
|
|
28
|
-
|
|
34
|
+
|
|
35
|
+
def transform(self, labels, speedup: Literal['none','vectorize','parallel']='none'):
|
|
36
|
+
sep_pattern = self.sep_pattern
|
|
37
|
+
if speedup=='none':
|
|
38
|
+
mat, labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
39
|
+
elif speedup=='vectorize':
|
|
40
|
+
mat, labels_ = vectorized_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
41
|
+
elif speedup=='parallel':
|
|
42
|
+
mat, labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
43
|
+
|
|
44
|
+
mat_df = pd.DataFrame(mat, columns=labels_)
|
|
45
|
+
|
|
46
|
+
labels_valid = [x for x in labels_ if x in self.labels_]
|
|
47
|
+
mat_df = mat_df[labels_valid]
|
|
48
|
+
|
|
49
|
+
mat_valid = np.zeros([mat.shape[0], len(self.labels_)])
|
|
50
|
+
mat_valid_df = pd.DataFrame(mat_valid, columns=self.labels_)
|
|
51
|
+
mat_valid_df[labels_valid] = mat_df
|
|
52
|
+
|
|
53
|
+
return mat_valid_df.values
|
|
54
|
+
|
|
29
55
|
def inverse_transform(self, matrix):
|
|
30
56
|
return matrix_to_labels(matrix=matrix, unique_labels=self.labels_)
|
|
31
57
|
|
SURE/utils/custom_mlp.py
CHANGED
|
@@ -239,7 +239,10 @@ class ZeroBiasMLP(nn.Module):
|
|
|
239
239
|
def forward(self, x):
|
|
240
240
|
y = self.mlp(x)
|
|
241
241
|
mask = torch.zeros_like(y)
|
|
242
|
-
|
|
242
|
+
if len(y.shape)==2:
|
|
243
|
+
mask[x[1][:,0]>0,:] = 1
|
|
244
|
+
elif len(y.shape)==3:
|
|
245
|
+
mask[:,x[1][:,0]>0,:] = 1
|
|
243
246
|
return y*mask
|
|
244
247
|
|
|
245
248
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
SURE/
|
|
2
|
-
SURE/SURE.py,sha256=
|
|
3
|
-
SURE/__init__.py,sha256=
|
|
1
|
+
SURE/DensityFlow.py,sha256=p5Pt3KrsdF_NTLFx0p1cUPuXkIac6wQED1LsLJRG7mI,56124
|
|
2
|
+
SURE/SURE.py,sha256=MXs7iuvcj-lU4dJ_MwKegpL2Rqk2HB4eFfAgHRA3RtA,47744
|
|
3
|
+
SURE/__init__.py,sha256=NVp22RCHrhSwHNMomABC-eftoCYvt7vV1XOzim-UZHE,293
|
|
4
4
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
5
5
|
SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
|
|
6
6
|
SURE/assembly/atlas.py,sha256=ALjmVWutm_tOHTcT1aqOxmuCEQw-XzrtDoMCV_8oXLk,21794
|
|
@@ -9,17 +9,17 @@ SURE/atac/utils.py,sha256=m4NYwpy9O5T1pXTzgCOCcmlwrC6GTi-cQ5sm2wZu2O8,4354
|
|
|
9
9
|
SURE/codebook/__init__.py,sha256=2T5gjp8JIaBayrXAnOJYSebQHsWprOs87difpR1OPNw,243
|
|
10
10
|
SURE/codebook/codebook.py,sha256=ZlN6gRX9Gj2D2u3P5KeOsbZri0MoMAiJo9lNeL-MK-I,17117
|
|
11
11
|
SURE/flow/__init__.py,sha256=rsAjYsh1xVIrxBCuwOE0Q_6N5th1wBgjJceV0ABPG3c,183
|
|
12
|
-
SURE/flow/flow_stats.py,sha256=
|
|
12
|
+
SURE/flow/flow_stats.py,sha256=6SzNMT59WRFRP1nC6bvpBPF7BugWnkIS_DSlr4S-Ez0,11338
|
|
13
13
|
SURE/flow/plot_quiver.py,sha256=UbmuScUcgbQHeMmjKmgqxjrIjHhiHx0VWct16UMMwuE,8110
|
|
14
14
|
SURE/perturb/__init__.py,sha256=8TP1dSUhXiZzKpFebHZmm8XMMGbUz_OfQ10xu-6uPPY,43
|
|
15
|
-
SURE/perturb/perturb.py,sha256=
|
|
15
|
+
SURE/perturb/perturb.py,sha256=ey7cxsM1tO1MW4UaE_MLpLHK87CjvXzn2CBPtvv1VZ0,6116
|
|
16
16
|
SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
|
|
17
|
-
SURE/utils/custom_mlp.py,sha256=
|
|
17
|
+
SURE/utils/custom_mlp.py,sha256=HuNb7f8-6RFjsvfEu1XOuNpLrHZkGYHgf8TpJfPSNO0,9382
|
|
18
18
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
19
19
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
20
|
-
sure_tools-2.
|
|
21
|
-
sure_tools-2.
|
|
22
|
-
sure_tools-2.
|
|
23
|
-
sure_tools-2.
|
|
24
|
-
sure_tools-2.
|
|
25
|
-
sure_tools-2.
|
|
20
|
+
sure_tools-2.2.23.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
21
|
+
sure_tools-2.2.23.dist-info/METADATA,sha256=ckAOsGL19y8unUmL2zYK4yeTRGFyALbaN_3hM18u0tw,2678
|
|
22
|
+
sure_tools-2.2.23.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
23
|
+
sure_tools-2.2.23.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
24
|
+
sure_tools-2.2.23.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
25
|
+
sure_tools-2.2.23.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|