SURE-tools 2.1.82.tar.gz → 2.1.83.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic; the original page linked to more details here (the link is not preserved in this text).
- {sure_tools-2.1.82 → sure_tools-2.1.83}/PKG-INFO +1 -1
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/PerturbFlow.py +11 -6
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.1.82 → sure_tools-2.1.83}/setup.py +1 -1
- {sure_tools-2.1.82 → sure_tools-2.1.83}/LICENSE +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/README.md +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/SURE.py +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/__init__.py +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/atac/utils.py +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/utils/queue.py +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE/utils/utils.py +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.1.82 → sure_tools-2.1.83}/setup.cfg +0 -0
|
@@ -204,7 +204,7 @@ class PerturbFlow(nn.Module):
204 204   for i in np.arange(self.cell_factor_size):
205 205   if self.use_bias[i]:
206 206   self.cell_factor_effect.append(MLP(
207     - [self.
    207 + [self.code_size+1] + self.decoder_hidden_layers + [self.latent_dim],
208 208   activation=activate_fct,
209 209   output_activation=None,
210 210   post_layer_fct=post_layer_fct,

@@ -215,7 +215,7 @@ class PerturbFlow(nn.Module):
215 215   )
216 216   else:
217 217   self.cell_factor_effect.append(ZeroBiasMLP(
218     - [self.
    218 + [self.code_size+1] + self.decoder_hidden_layers + [self.latent_dim],
219 219   activation=activate_fct,
220 220   output_activation=None,
221 221   post_layer_fct=post_layer_fct,

@@ -429,7 +429,7 @@ class PerturbFlow(nn.Module):
429 429   zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
430 430
431 431   if self.cell_factor_size>0:
432     - zus = self._total_effects(
    432 + zus = self._total_effects(ns, us)
433 433   zs = zns+zus
434 434   else:
435 435   zs = zns

@@ -631,7 +631,7 @@ class PerturbFlow(nn.Module):
631 631   # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
632 632   # else:
633 633   # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
634     - zus = self._total_effects(
    634 + zus = self._total_effects(ns, us)
635 635   zs = zns+zus
636 636   else:
637 637   zs = zns

@@ -681,6 +681,9 @@ class PerturbFlow(nn.Module):
681 681   zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
682 682   return zus
683 683
    684 + def _get_codebook_identity(self):
    685 + return torch.eye(self.code_size, **self.options)
    686 +
684 687   def _get_codebook(self):
685 688   I = torch.eye(self.code_size, **self.options)
686 689   if self.latent_dist=='studentt':

@@ -839,7 +842,8 @@ class PerturbFlow(nn.Module):
839 842
840 843   def _cell_response(self, xs, factor_idx, perturb):
841 844   #zns,_ = self.encoder_zn(xs)
842     - zns,_ = self._get_basal_embedding(xs)
    845 + #zns,_ = self._get_basal_embedding(xs)
    846 + zns = self._soft_assignments(xs)
843 847   if perturb.ndim==2:
844 848   ms = self.cell_factor_effect[factor_idx]([zns, perturb])
845 849   else:

@@ -873,7 +877,8 @@ class PerturbFlow(nn.Module):
873 877   return Z
874 878
875 879   def get_metacell_response(self, factor_idx, perturb):
876     - zs = self._get_codebook()
    880 + #zs = self._get_codebook()
    881 + zs = self._get_codebook_identity()
877 882   ps = convert_to_tensor(perturb, device=self.get_device())
878 883   ms = self.cell_factor_effect[factor_idx]([zs,ps])
879 884   return tensor_to_numpy(ms)

File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|