SURE-tools 2.2.15__tar.gz → 2.2.20__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sure_tools-2.2.15 → sure_tools-2.2.20}/PKG-INFO +1 -1
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/DensityFlow.py +66 -27
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.2.15 → sure_tools-2.2.20}/setup.py +1 -1
- {sure_tools-2.2.15 → sure_tools-2.2.20}/LICENSE +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/README.md +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/SURE.py +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/__init__.py +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/atac/utils.py +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/utils/queue.py +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE/utils/utils.py +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.2.15 → sure_tools-2.2.20}/setup.cfg +0 -0
|
@@ -59,12 +59,13 @@ class DensityFlow(nn.Module):
|
|
|
59
59
|
input_size: int,
|
|
60
60
|
codebook_size: int = 200,
|
|
61
61
|
cell_factor_size: int = 0,
|
|
62
|
+
turn_off_cell_specific: bool = False,
|
|
62
63
|
supervised_mode: bool = False,
|
|
63
64
|
z_dim: int = 10,
|
|
64
65
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
65
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
66
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
|
|
66
67
|
inverse_dispersion: float = 10.0,
|
|
67
|
-
use_zeroinflate: bool =
|
|
68
|
+
use_zeroinflate: bool = False,
|
|
68
69
|
hidden_layers: list = [500],
|
|
69
70
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
70
71
|
nn_dropout: float = 0.1,
|
|
@@ -102,6 +103,7 @@ class DensityFlow(nn.Module):
|
|
|
102
103
|
else:
|
|
103
104
|
self.use_bias = [not zero_bias] * self.cell_factor_size
|
|
104
105
|
#self.use_bias = not zero_bias
|
|
106
|
+
self.turn_off_cell_specific = turn_off_cell_specific
|
|
105
107
|
|
|
106
108
|
self.codebook_weights = None
|
|
107
109
|
|
|
@@ -203,27 +205,51 @@ class DensityFlow(nn.Module):
|
|
|
203
205
|
self.cell_factor_effect = nn.ModuleList()
|
|
204
206
|
for i in np.arange(self.cell_factor_size):
|
|
205
207
|
if self.use_bias[i]:
|
|
206
|
-
self.
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
208
|
+
if self.turn_off_cell_specific:
|
|
209
|
+
self.cell_factor_effect.append(MLP(
|
|
210
|
+
[1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
211
|
+
activation=activate_fct,
|
|
212
|
+
output_activation=None,
|
|
213
|
+
post_layer_fct=post_layer_fct,
|
|
214
|
+
post_act_fct=post_act_fct,
|
|
215
|
+
allow_broadcast=self.allow_broadcast,
|
|
216
|
+
use_cuda=self.use_cuda,
|
|
217
|
+
)
|
|
218
|
+
)
|
|
219
|
+
else:
|
|
220
|
+
self.cell_factor_effect.append(MLP(
|
|
221
|
+
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
222
|
+
activation=activate_fct,
|
|
223
|
+
output_activation=None,
|
|
224
|
+
post_layer_fct=post_layer_fct,
|
|
225
|
+
post_act_fct=post_act_fct,
|
|
226
|
+
allow_broadcast=self.allow_broadcast,
|
|
227
|
+
use_cuda=self.use_cuda,
|
|
228
|
+
)
|
|
214
229
|
)
|
|
215
|
-
)
|
|
216
230
|
else:
|
|
217
|
-
self.
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
231
|
+
if self.turn_off_cell_specific:
|
|
232
|
+
self.cell_factor_effect.append(ZeroBiasMLP(
|
|
233
|
+
[1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
234
|
+
activation=activate_fct,
|
|
235
|
+
output_activation=None,
|
|
236
|
+
post_layer_fct=post_layer_fct,
|
|
237
|
+
post_act_fct=post_act_fct,
|
|
238
|
+
allow_broadcast=self.allow_broadcast,
|
|
239
|
+
use_cuda=self.use_cuda,
|
|
240
|
+
)
|
|
241
|
+
)
|
|
242
|
+
else:
|
|
243
|
+
self.cell_factor_effect.append(ZeroBiasMLP(
|
|
244
|
+
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
245
|
+
activation=activate_fct,
|
|
246
|
+
output_activation=None,
|
|
247
|
+
post_layer_fct=post_layer_fct,
|
|
248
|
+
post_act_fct=post_act_fct,
|
|
249
|
+
allow_broadcast=self.allow_broadcast,
|
|
250
|
+
use_cuda=self.use_cuda,
|
|
251
|
+
)
|
|
225
252
|
)
|
|
226
|
-
)
|
|
227
253
|
|
|
228
254
|
self.decoder_concentrate = MLP(
|
|
229
255
|
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
|
@@ -676,9 +702,17 @@ class DensityFlow(nn.Module):
|
|
|
676
702
|
zus = None
|
|
677
703
|
for i in np.arange(self.cell_factor_size):
|
|
678
704
|
if i==0:
|
|
679
|
-
|
|
705
|
+
#if self.turn_off_cell_specific:
|
|
706
|
+
# zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
|
|
707
|
+
#else:
|
|
708
|
+
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
709
|
+
zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
|
|
680
710
|
else:
|
|
681
|
-
|
|
711
|
+
#if self.turn_off_cell_specific:
|
|
712
|
+
# zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
|
|
713
|
+
#else:
|
|
714
|
+
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
715
|
+
zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
|
|
682
716
|
return zus
|
|
683
717
|
|
|
684
718
|
def _get_codebook_identity(self):
|
|
@@ -820,12 +854,12 @@ class DensityFlow(nn.Module):
|
|
|
820
854
|
us_i = us[:,pert_idx].reshape(-1,1)
|
|
821
855
|
|
|
822
856
|
# factor effect of xs
|
|
823
|
-
dzs0 = self.get_cell_response(
|
|
857
|
+
dzs0 = self.get_cell_response(zs, factor_idx=pert_idx, perturb=us_i)
|
|
824
858
|
|
|
825
859
|
# perturbation effect
|
|
826
860
|
ps = np.ones_like(us_i)
|
|
827
861
|
if np.sum(np.abs(ps-us_i))>=1:
|
|
828
|
-
dzs = self.get_cell_response(
|
|
862
|
+
dzs = self.get_cell_response(zs, factor_idx=pert_idx, perturb=ps)
|
|
829
863
|
zs = zs + dzs0 + dzs
|
|
830
864
|
else:
|
|
831
865
|
zs = zs + dzs0
|
|
@@ -847,9 +881,15 @@ class DensityFlow(nn.Module):
|
|
|
847
881
|
#zns,_ = self._get_basal_embedding(xs)
|
|
848
882
|
zns = zs
|
|
849
883
|
if perturb.ndim==2:
|
|
850
|
-
|
|
884
|
+
if self.turn_off_cell_specific:
|
|
885
|
+
ms = self.cell_factor_effect[perturb_idx](perturb)
|
|
886
|
+
else:
|
|
887
|
+
ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
|
|
851
888
|
else:
|
|
852
|
-
|
|
889
|
+
if self.turn_off_cell_specific:
|
|
890
|
+
ms = self.cell_factor_effect[perturb_idx](perturb.reshape(-1,1))
|
|
891
|
+
else:
|
|
892
|
+
ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])
|
|
853
893
|
|
|
854
894
|
return ms
|
|
855
895
|
|
|
@@ -1338,5 +1378,4 @@ def main():
|
|
|
1338
1378
|
|
|
1339
1379
|
|
|
1340
1380
|
if __name__ == "__main__":
|
|
1341
|
-
|
|
1342
1381
|
main()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|