SURE-tools 2.1.33__py3-none-any.whl → 2.1.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools has been flagged as possibly problematic; see the registry's advisory page for this release for more details.
- SURE/PerturbFlow.py +23 -12
- SURE/SURE.py +10 -22
- SURE/flow/flow_stats.py +5 -16
- {sure_tools-2.1.33.dist-info → sure_tools-2.1.35.dist-info}/METADATA +1 -1
- {sure_tools-2.1.33.dist-info → sure_tools-2.1.35.dist-info}/RECORD +9 -9
- {sure_tools-2.1.33.dist-info → sure_tools-2.1.35.dist-info}/WHEEL +0 -0
- {sure_tools-2.1.33.dist-info → sure_tools-2.1.35.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.1.33.dist-info → sure_tools-2.1.35.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.1.33.dist-info → sure_tools-2.1.35.dist-info}/top_level.txt +0 -0
SURE/PerturbFlow.py
CHANGED
|
@@ -423,12 +423,13 @@ class PerturbFlow(nn.Module):
|
|
|
423
423
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
424
424
|
|
|
425
425
|
if self.cell_factor_size>0:
|
|
426
|
-
zus = None
|
|
427
|
-
for i in np.arange(self.cell_factor_size):
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
426
|
+
#zus = None
|
|
427
|
+
#for i in np.arange(self.cell_factor_size):
|
|
428
|
+
# if i==0:
|
|
429
|
+
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
430
|
+
# else:
|
|
431
|
+
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
432
|
+
zus = self._total_effects(zns, us)
|
|
432
433
|
zs = zns+zus
|
|
433
434
|
else:
|
|
434
435
|
zs = zns
|
|
@@ -618,12 +619,13 @@ class PerturbFlow(nn.Module):
|
|
|
618
619
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)
|
|
619
620
|
|
|
620
621
|
if self.cell_factor_size>0:
|
|
621
|
-
zus = None
|
|
622
|
-
for i in np.arange(self.cell_factor_size):
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
622
|
+
#zus = None
|
|
623
|
+
#for i in np.arange(self.cell_factor_size):
|
|
624
|
+
# if i==0:
|
|
625
|
+
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
626
|
+
# else:
|
|
627
|
+
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
628
|
+
zus = self._total_effects(zns, us)
|
|
627
629
|
zs = zns+zus
|
|
628
630
|
else:
|
|
629
631
|
zs = zns
|
|
@@ -660,6 +662,15 @@ class PerturbFlow(nn.Module):
|
|
|
660
662
|
zn_loc, zn_scale = self.encoder_zn(xs)
|
|
661
663
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
662
664
|
|
|
665
|
+
def _total_effects(self, zns, us):
|
|
666
|
+
zus = None
|
|
667
|
+
for i in np.arange(self.cell_factor_size):
|
|
668
|
+
if i==0:
|
|
669
|
+
zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
670
|
+
else:
|
|
671
|
+
zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
672
|
+
return zus
|
|
673
|
+
|
|
663
674
|
def _get_codebook(self):
|
|
664
675
|
I = torch.eye(self.code_size, **self.options)
|
|
665
676
|
if self.latent_dist=='studentt':
|
SURE/SURE.py
CHANGED
|
@@ -111,7 +111,6 @@ class SURE(nn.Module):
|
|
|
111
111
|
config_enum: str = 'parallel',
|
|
112
112
|
use_cuda: bool = False,
|
|
113
113
|
seed: int = 42,
|
|
114
|
-
zero_bias: bool = True,
|
|
115
114
|
dtype = torch.float32, # type: ignore
|
|
116
115
|
):
|
|
117
116
|
super().__init__()
|
|
@@ -135,7 +134,6 @@ class SURE(nn.Module):
|
|
|
135
134
|
self.post_layer_fct = post_layer_fct
|
|
136
135
|
self.post_act_fct = post_act_fct
|
|
137
136
|
self.hidden_layer_activation = hidden_layer_activation
|
|
138
|
-
self.use_bias = not zero_bias
|
|
139
137
|
|
|
140
138
|
self.codebook_weights = None
|
|
141
139
|
|
|
@@ -234,26 +232,16 @@ class SURE(nn.Module):
|
|
|
234
232
|
)
|
|
235
233
|
|
|
236
234
|
if self.cell_factor_size>0:
|
|
237
|
-
|
|
238
|
-
self.
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
else:
|
|
248
|
-
self.cell_factor_effect = ZeroBiasMLP(
|
|
249
|
-
[self.latent_dim + self.cell_factor_size] + self.decoder_hidden_layers + [self.latent_dim],
|
|
250
|
-
activation=activate_fct,
|
|
251
|
-
output_activation=None,
|
|
252
|
-
post_layer_fct=post_layer_fct,
|
|
253
|
-
post_act_fct=post_act_fct,
|
|
254
|
-
allow_broadcast=self.allow_broadcast,
|
|
255
|
-
use_cuda=self.use_cuda,
|
|
256
|
-
)
|
|
235
|
+
self.cell_factor_effect = MLP(
|
|
236
|
+
[self.latent_dim + self.cell_factor_size] + self.decoder_hidden_layers + [self.latent_dim],
|
|
237
|
+
activation=activate_fct,
|
|
238
|
+
output_activation=None,
|
|
239
|
+
post_layer_fct=post_layer_fct,
|
|
240
|
+
post_act_fct=post_act_fct,
|
|
241
|
+
allow_broadcast=self.allow_broadcast,
|
|
242
|
+
use_cuda=self.use_cuda,
|
|
243
|
+
)
|
|
244
|
+
|
|
257
245
|
|
|
258
246
|
self.decoder_concentrate = MLP(
|
|
259
247
|
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
SURE/flow/flow_stats.py
CHANGED
|
@@ -8,26 +8,15 @@ class VectorFieldEval:
|
|
|
8
8
|
def __init__(self):
|
|
9
9
|
pass
|
|
10
10
|
|
|
11
|
-
def directional_alignment(self, vectors
|
|
12
|
-
|
|
13
|
-
return weighted_alignment(vectors=vectors)
|
|
14
|
-
else:
|
|
15
|
-
return directional_alignment(vectors=vectors)
|
|
11
|
+
def directional_alignment(self, vectors):
|
|
12
|
+
return weighted_alignment(vectors=vectors)
|
|
16
13
|
|
|
17
|
-
def flow_coherence(self, vectors
|
|
18
|
-
|
|
19
|
-
return pca_based_coherence(vectors=vectors)
|
|
20
|
-
else:
|
|
21
|
-
return flow_coherence_index(vectors=vectors)
|
|
14
|
+
def flow_coherence(self, vectors):
|
|
15
|
+
return pca_based_coherence(vectors=vectors)
|
|
22
16
|
|
|
23
17
|
def momentum_flow_metric(self, vectors, masses=None):
|
|
24
18
|
return momentum_flow_metric(vectors=vectors, masses=masses)
|
|
25
|
-
|
|
26
|
-
def multi_scale_coherence(self, vectors, positoins, scale_factors=[1.0, 0.5, 0.1]):
|
|
27
|
-
return multi_scale_coherence(vectors=vectors, positions=positoins, scale_factors=scale_factors)
|
|
28
|
-
|
|
29
|
-
def vector_field_coherence_score(self, vectors, positions=None, weights=None):
|
|
30
|
-
return vector_field_coherence_score(vectors=vectors, positions=positions, weights=weights)
|
|
19
|
+
|
|
31
20
|
|
|
32
21
|
def calculate_movement_stats(vectors):
|
|
33
22
|
"""
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
SURE/PerturbFlow.py,sha256=
|
|
2
|
-
SURE/SURE.py,sha256=
|
|
1
|
+
SURE/PerturbFlow.py,sha256=4oz02P9kTRNJ897ITR1oOxKYGVw2sgvFKazz2FC-6f0,52118
|
|
2
|
+
SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
|
|
3
3
|
SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
|
|
4
4
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
5
5
|
SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
|
|
@@ -9,7 +9,7 @@ SURE/atac/utils.py,sha256=m4NYwpy9O5T1pXTzgCOCcmlwrC6GTi-cQ5sm2wZu2O8,4354
|
|
|
9
9
|
SURE/codebook/__init__.py,sha256=2T5gjp8JIaBayrXAnOJYSebQHsWprOs87difpR1OPNw,243
|
|
10
10
|
SURE/codebook/codebook.py,sha256=ZlN6gRX9Gj2D2u3P5KeOsbZri0MoMAiJo9lNeL-MK-I,17117
|
|
11
11
|
SURE/flow/__init__.py,sha256=rsAjYsh1xVIrxBCuwOE0Q_6N5th1wBgjJceV0ABPG3c,183
|
|
12
|
-
SURE/flow/flow_stats.py,sha256=
|
|
12
|
+
SURE/flow/flow_stats.py,sha256=cBBsPEDpWNMpbzlyQ3f0385RSrX6_5RCH2caOyi4ihM,9908
|
|
13
13
|
SURE/flow/plot_quiver.py,sha256=UbmuScUcgbQHeMmjKmgqxjrIjHhiHx0VWct16UMMwuE,8110
|
|
14
14
|
SURE/perturb/__init__.py,sha256=ouxShhbxZM4r5Gf7GmKiutrsmtyq7QL8rHjhgF0BU08,32
|
|
15
15
|
SURE/perturb/perturb.py,sha256=CqO3xPfNA3cG175tadDidKvGsTu_yKfJRRLn_93awKM,3303
|
|
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=QJUOfrXzdWSmoM0P3LH8oKEHttzCWqpDy2UF0F0dtN4,673
|
|
|
17
17
|
SURE/utils/custom_mlp.py,sha256=rHnx9jEef02zfCUdbYVCmbuHcDdIBmRgt__wpdpZvYg,8104
|
|
18
18
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
19
19
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
20
|
-
sure_tools-2.1.
|
|
21
|
-
sure_tools-2.1.
|
|
22
|
-
sure_tools-2.1.
|
|
23
|
-
sure_tools-2.1.
|
|
24
|
-
sure_tools-2.1.
|
|
25
|
-
sure_tools-2.1.
|
|
20
|
+
sure_tools-2.1.35.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
21
|
+
sure_tools-2.1.35.dist-info/METADATA,sha256=YFpehlaqDEuAK-0FPQA8AivqA4PJM1kWcEppjSk-p50,2651
|
|
22
|
+
sure_tools-2.1.35.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
23
|
+
sure_tools-2.1.35.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
24
|
+
sure_tools-2.1.35.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
25
|
+
sure_tools-2.1.35.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|