SURE-tools 2.1.44__tar.gz → 2.1.45__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. See the registry's advisory page for this release for more details.
- {sure_tools-2.1.44 → sure_tools-2.1.45}/PKG-INFO +1 -1
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/PerturbFlow.py +8 -47
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.1.44 → sure_tools-2.1.45}/setup.py +1 -1
- {sure_tools-2.1.44 → sure_tools-2.1.45}/LICENSE +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/README.md +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/SURE.py +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/__init__.py +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/atac/utils.py +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/utils/queue.py +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE/utils/utils.py +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.1.44 → sure_tools-2.1.45}/setup.cfg +0 -0
|
@@ -189,29 +189,10 @@ class PerturbFlow(nn.Module):
|
|
|
189
189
|
use_cuda=self.use_cuda,
|
|
190
190
|
)
|
|
191
191
|
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
# activation=activate_fct,
|
|
195
|
-
# output_activation=[None, Exp],
|
|
196
|
-
# post_layer_fct=post_layer_fct,
|
|
197
|
-
# post_act_fct=post_act_fct,
|
|
198
|
-
# allow_broadcast=self.allow_broadcast,
|
|
199
|
-
# use_cuda=self.use_cuda,
|
|
200
|
-
#)
|
|
201
|
-
|
|
202
|
-
self.encoder_zn_loc = MLP(
|
|
203
|
-
[self.input_size] + hidden_sizes + [self.input_size * latent_dim],
|
|
204
|
-
activation=activate_fct,
|
|
205
|
-
output_activation=None,
|
|
206
|
-
post_layer_fct=post_layer_fct,
|
|
207
|
-
post_act_fct=post_act_fct,
|
|
208
|
-
allow_broadcast=self.allow_broadcast,
|
|
209
|
-
use_cuda=self.use_cuda,
|
|
210
|
-
)
|
|
211
|
-
self.encoder_zn_scale = MLP(
|
|
212
|
-
[self.input_size] + hidden_sizes + [latent_dim],
|
|
192
|
+
self.encoder_zn = MLP(
|
|
193
|
+
[self.input_size] + hidden_sizes + [[latent_dim, latent_dim]],
|
|
213
194
|
activation=activate_fct,
|
|
214
|
-
output_activation=Exp,
|
|
195
|
+
output_activation=[None, Exp],
|
|
215
196
|
post_layer_fct=post_layer_fct,
|
|
216
197
|
post_act_fct=post_act_fct,
|
|
217
198
|
allow_broadcast=self.allow_broadcast,
|
|
@@ -397,11 +378,7 @@ class PerturbFlow(nn.Module):
|
|
|
397
378
|
|
|
398
379
|
def guide1(self, xs):
|
|
399
380
|
with pyro.plate('data'):
|
|
400
|
-
|
|
401
|
-
#zn_loc, zn_scale = self.encoder_zn(xs)
|
|
402
|
-
zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
|
|
403
|
-
zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
|
|
404
|
-
zn_loc = zn_loc_3d.sum(dim=1)
|
|
381
|
+
zn_loc, zn_scale = self.encoder_zn(xs)
|
|
405
382
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
406
383
|
|
|
407
384
|
alpha = self.encoder_n(zns)
|
|
@@ -489,11 +466,7 @@ class PerturbFlow(nn.Module):
|
|
|
489
466
|
|
|
490
467
|
def guide2(self, xs, us=None):
|
|
491
468
|
with pyro.plate('data'):
|
|
492
|
-
|
|
493
|
-
#zn_loc, zn_scale = self.encoder_zn(xs)
|
|
494
|
-
zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
|
|
495
|
-
zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
|
|
496
|
-
zn_loc = zn_loc_3d.sum(dim=1)
|
|
469
|
+
zn_loc, zn_scale = self.encoder_zn(xs)
|
|
497
470
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
498
471
|
|
|
499
472
|
alpha = self.encoder_n(zns)
|
|
@@ -588,11 +561,7 @@ class PerturbFlow(nn.Module):
|
|
|
588
561
|
def guide3(self, xs, ys, embeds=None):
|
|
589
562
|
with pyro.plate('data'):
|
|
590
563
|
if embeds is None:
|
|
591
|
-
|
|
592
|
-
#zn_loc, zn_scale = self.encoder_zn(xs)
|
|
593
|
-
zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
|
|
594
|
-
zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
|
|
595
|
-
zn_loc = zn_loc_3d.sum(dim=1)
|
|
564
|
+
zn_loc, zn_scale = self.encoder_zn(xs)
|
|
596
565
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
597
566
|
|
|
598
567
|
def model4(self, xs, us, ys, embeds=None):
|
|
@@ -694,11 +663,7 @@ class PerturbFlow(nn.Module):
|
|
|
694
663
|
def guide4(self, xs, us, ys, embeds=None):
|
|
695
664
|
with pyro.plate('data'):
|
|
696
665
|
if embeds is None:
|
|
697
|
-
|
|
698
|
-
#zn_loc, zn_scale = self.encoder_zn(xs)
|
|
699
|
-
zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
|
|
700
|
-
zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
|
|
701
|
-
zn_loc = zn_loc_3d.sum(dim=1)
|
|
666
|
+
zn_loc, zn_scale = self.encoder_zn(xs)
|
|
702
667
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
703
668
|
|
|
704
669
|
def _total_effects(self, zns, us):
|
|
@@ -727,11 +692,7 @@ class PerturbFlow(nn.Module):
|
|
|
727
692
|
return cb
|
|
728
693
|
|
|
729
694
|
def _get_basal_embedding(self, xs):
|
|
730
|
-
|
|
731
|
-
batch_size = xs.shape[0]
|
|
732
|
-
zns = self.encoder_zn_loc(xs)
|
|
733
|
-
zns_3d = zns.view(batch_size, self.input_size, self.latent_dim)
|
|
734
|
-
zns = zns_3d.sum(dim=1)
|
|
695
|
+
zns, _ = self.encoder_zn(xs)
|
|
735
696
|
return zns
|
|
736
697
|
|
|
737
698
|
def get_basal_embedding(self,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|