SURE-tools 2.1.34.tar.gz → 2.1.36.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- {sure_tools-2.1.34 → sure_tools-2.1.36}/PKG-INFO +1 -1
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/PerturbFlow.py +23 -12
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/SURE.py +10 -22
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/perturb/perturb.py +5 -1
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.1.34 → sure_tools-2.1.36}/setup.py +1 -1
- {sure_tools-2.1.34 → sure_tools-2.1.36}/LICENSE +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/README.md +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/__init__.py +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/atac/utils.py +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/utils/queue.py +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE/utils/utils.py +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.1.34 → sure_tools-2.1.36}/setup.cfg +0 -0
|
@@ -423,12 +423,13 @@ class PerturbFlow(nn.Module):
|
|
|
423
423
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
424
424
|
|
|
425
425
|
if self.cell_factor_size>0:
|
|
426
|
-
zus = None
|
|
427
|
-
for i in np.arange(self.cell_factor_size):
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
426
|
+
#zus = None
|
|
427
|
+
#for i in np.arange(self.cell_factor_size):
|
|
428
|
+
# if i==0:
|
|
429
|
+
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
430
|
+
# else:
|
|
431
|
+
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
432
|
+
zus = self._total_effects(zns, us)
|
|
432
433
|
zs = zns+zus
|
|
433
434
|
else:
|
|
434
435
|
zs = zns
|
|
@@ -618,12 +619,13 @@ class PerturbFlow(nn.Module):
|
|
|
618
619
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)
|
|
619
620
|
|
|
620
621
|
if self.cell_factor_size>0:
|
|
621
|
-
zus = None
|
|
622
|
-
for i in np.arange(self.cell_factor_size):
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
622
|
+
#zus = None
|
|
623
|
+
#for i in np.arange(self.cell_factor_size):
|
|
624
|
+
# if i==0:
|
|
625
|
+
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
626
|
+
# else:
|
|
627
|
+
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
628
|
+
zus = self._total_effects(zns, us)
|
|
627
629
|
zs = zns+zus
|
|
628
630
|
else:
|
|
629
631
|
zs = zns
|
|
@@ -660,6 +662,15 @@ class PerturbFlow(nn.Module):
|
|
|
660
662
|
zn_loc, zn_scale = self.encoder_zn(xs)
|
|
661
663
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
662
664
|
|
|
665
|
+
def _total_effects(self, zns, us):
|
|
666
|
+
zus = None
|
|
667
|
+
for i in np.arange(self.cell_factor_size):
|
|
668
|
+
if i==0:
|
|
669
|
+
zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
670
|
+
else:
|
|
671
|
+
zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
672
|
+
return zus
|
|
673
|
+
|
|
663
674
|
def _get_codebook(self):
|
|
664
675
|
I = torch.eye(self.code_size, **self.options)
|
|
665
676
|
if self.latent_dist=='studentt':
|
|
@@ -111,7 +111,6 @@ class SURE(nn.Module):
|
|
|
111
111
|
config_enum: str = 'parallel',
|
|
112
112
|
use_cuda: bool = False,
|
|
113
113
|
seed: int = 42,
|
|
114
|
-
zero_bias: bool = True,
|
|
115
114
|
dtype = torch.float32, # type: ignore
|
|
116
115
|
):
|
|
117
116
|
super().__init__()
|
|
@@ -135,7 +134,6 @@ class SURE(nn.Module):
|
|
|
135
134
|
self.post_layer_fct = post_layer_fct
|
|
136
135
|
self.post_act_fct = post_act_fct
|
|
137
136
|
self.hidden_layer_activation = hidden_layer_activation
|
|
138
|
-
self.use_bias = not zero_bias
|
|
139
137
|
|
|
140
138
|
self.codebook_weights = None
|
|
141
139
|
|
|
@@ -234,26 +232,16 @@ class SURE(nn.Module):
|
|
|
234
232
|
)
|
|
235
233
|
|
|
236
234
|
if self.cell_factor_size>0:
|
|
237
|
-
|
|
238
|
-
self.
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
else:
|
|
248
|
-
self.cell_factor_effect = ZeroBiasMLP(
|
|
249
|
-
[self.latent_dim + self.cell_factor_size] + self.decoder_hidden_layers + [self.latent_dim],
|
|
250
|
-
activation=activate_fct,
|
|
251
|
-
output_activation=None,
|
|
252
|
-
post_layer_fct=post_layer_fct,
|
|
253
|
-
post_act_fct=post_act_fct,
|
|
254
|
-
allow_broadcast=self.allow_broadcast,
|
|
255
|
-
use_cuda=self.use_cuda,
|
|
256
|
-
)
|
|
235
|
+
self.cell_factor_effect = MLP(
|
|
236
|
+
[self.latent_dim + self.cell_factor_size] + self.decoder_hidden_layers + [self.latent_dim],
|
|
237
|
+
activation=activate_fct,
|
|
238
|
+
output_activation=None,
|
|
239
|
+
post_layer_fct=post_layer_fct,
|
|
240
|
+
post_act_fct=post_act_fct,
|
|
241
|
+
allow_broadcast=self.allow_broadcast,
|
|
242
|
+
use_cuda=self.use_cuda,
|
|
243
|
+
)
|
|
244
|
+
|
|
257
245
|
|
|
258
246
|
self.decoder_concentrate = MLP(
|
|
259
247
|
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
|
@@ -8,7 +8,7 @@ class LabelMatrix:
|
|
|
8
8
|
def __init__(self):
|
|
9
9
|
self.labels_ = None
|
|
10
10
|
|
|
11
|
-
def fit_transform(self, labels, sep_pattern=r'[;_\-\s]', speedup: Literal['none','vectorize','parallel']='none'):
|
|
11
|
+
def fit_transform(self, labels, control_label, sep_pattern=r'[;_\-\s]', speedup: Literal['none','vectorize','parallel']='none'):
|
|
12
12
|
if speedup=='none':
|
|
13
13
|
mat, self.labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
14
14
|
elif speedup=='vectorize':
|
|
@@ -17,6 +17,10 @@ class LabelMatrix:
|
|
|
17
17
|
mat, self.labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
18
18
|
|
|
19
19
|
self.labels_ = np.array(self.labels_)
|
|
20
|
+
|
|
21
|
+
idx = np.where(self.labels_==control_label)[0]
|
|
22
|
+
mat = np.delete(mat, idx, axis=1)
|
|
23
|
+
self.labels_ = np.delete(self.labels_, idx)
|
|
20
24
|
return mat
|
|
21
25
|
|
|
22
26
|
def inverse_transform(self, matrix):
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|