SURE-tools 3.7.0__tar.gz → 3.7.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sure_tools-3.7.0 → sure_tools-3.7.2}/PKG-INFO +1 -1
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE/SURE_vae.py +32 -7
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-3.7.0 → sure_tools-3.7.2}/setup.py +1 -1
- {sure_tools-3.7.0 → sure_tools-3.7.2}/LICENSE +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/README.md +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE/SURE.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE/SUREMO.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE/SURE_nsf.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE/SURE_vanilla.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE/__init__.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE/atac/__init__.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE/atac/utils.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE/dist/__init__.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE/dist/negbinomial.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE/graph/__init__.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE/graph/graph_utils.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE/utils/__init__.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE/utils/label.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE/utils/queue.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE/utils/utils.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.2}/setup.cfg +0 -0
|
@@ -256,18 +256,32 @@ class SUREVAE(nn.Module):
|
|
|
256
256
|
)
|
|
257
257
|
|
|
258
258
|
if np.sum(self.condition_sizes)>0:
|
|
259
|
-
self.
|
|
259
|
+
self.decoder_condition_effects = nn.ModuleList()
|
|
260
260
|
for condition_size in self.condition_sizes:
|
|
261
|
-
self.
|
|
262
|
-
[condition_size + self.latent_dim] + self.decoder_hidden_layers + [self.latent_dim],
|
|
261
|
+
self.decoder_condition_effects.append(ZeroBiasMLP3(
|
|
262
|
+
[condition_size + self.latent_dim] + self.decoder_hidden_layers + [[self.latent_dim,self.latent_dim]],
|
|
263
263
|
activation=activate_fct,
|
|
264
|
-
output_activation=None,
|
|
264
|
+
output_activation=[None,Exp],
|
|
265
265
|
post_layer_fct=post_layer_fct,
|
|
266
266
|
post_act_fct=post_act_fct,
|
|
267
267
|
allow_broadcast=self.allow_broadcast,
|
|
268
268
|
use_cuda=self.use_cuda,
|
|
269
269
|
)
|
|
270
270
|
)
|
|
271
|
+
|
|
272
|
+
'''self.encoder_condition_effects = nn.ModuleList()
|
|
273
|
+
for condition_size in self.condition_sizes:
|
|
274
|
+
self.encoder_condition_effects.append(ZeroBiasMLP3(
|
|
275
|
+
[condition_size + self.latent_dim] + self.decoder_hidden_layers + [[self.latent_dim,self.latent_dim]],
|
|
276
|
+
activation=activate_fct,
|
|
277
|
+
output_activation=[None,Exp],
|
|
278
|
+
post_layer_fct=post_layer_fct,
|
|
279
|
+
post_act_fct=post_act_fct,
|
|
280
|
+
allow_broadcast=self.allow_broadcast,
|
|
281
|
+
use_cuda=self.use_cuda,
|
|
282
|
+
)
|
|
283
|
+
)'''
|
|
284
|
+
|
|
271
285
|
if self.covariate_size>0:
|
|
272
286
|
self.covariate_effect = ZeroBiasMLP2(
|
|
273
287
|
[self.covariate_size] + self.decoder_hidden_layers + [self.latent_dim],
|
|
@@ -409,7 +423,10 @@ class SUREVAE(nn.Module):
|
|
|
409
423
|
shift = 0
|
|
410
424
|
for i, condition_size in enumerate(self.condition_sizes):
|
|
411
425
|
cs_i = cs[:,shift:(shift+condition_size)]
|
|
412
|
-
|
|
426
|
+
l,s = self.decoder_condition_effects[i]([cs_i,zns])
|
|
427
|
+
zcs_i = pyro.sample(f'zcs_{i}', dist.Normal(l, s).to_event(1))
|
|
428
|
+
#zcs += self.condition_effects[i]([cs_i,zns])
|
|
429
|
+
zcs += zcs_i
|
|
413
430
|
shift += condition_size
|
|
414
431
|
else:
|
|
415
432
|
zcs = torch.zeros_like(zs)
|
|
@@ -458,6 +475,13 @@ class SUREVAE(nn.Module):
|
|
|
458
475
|
zn_loc, zn_scale = self.encoder_zn(xs)
|
|
459
476
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
460
477
|
|
|
478
|
+
shift = 0
|
|
479
|
+
for i, condition_size in enumerate(self.condition_sizes):
|
|
480
|
+
cs_i = cs[:,shift:(shift+condition_size)]
|
|
481
|
+
l,s = self.decoder_condition_effects[i]([cs_i,zns])
|
|
482
|
+
zcs_i = pyro.sample(f'zcs_{i}', dist.Normal(l, s).to_event(1))
|
|
483
|
+
shift += condition_size
|
|
484
|
+
|
|
461
485
|
#alpha = self.encoder_n(zns)
|
|
462
486
|
alpha = self.encoder_alpha(zns)
|
|
463
487
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
@@ -602,7 +626,7 @@ class SUREVAE(nn.Module):
|
|
|
602
626
|
C_batch = cs[idx].to(self.get_device())
|
|
603
627
|
|
|
604
628
|
z_basal = self._get_cell_embedding(X_batch)
|
|
605
|
-
dzs = self.
|
|
629
|
+
dzs,_ = self.decoder_condition_effects[i]([C_batch,z_basal])
|
|
606
630
|
|
|
607
631
|
A.append(tensor_to_numpy(dzs))
|
|
608
632
|
pbar.update(1)
|
|
@@ -646,7 +670,8 @@ class SUREVAE(nn.Module):
|
|
|
646
670
|
shift = 0
|
|
647
671
|
for i, condition_size in enumerate(self.condition_sizes):
|
|
648
672
|
C_batch_i = C_batch[:, shift:(shift+condition_size)]
|
|
649
|
-
|
|
673
|
+
zcs_i,_ = self.decoder_condition_effects[i]([C_batch_i,z_basal])
|
|
674
|
+
zcs += zcs_i
|
|
650
675
|
shift += condition_size
|
|
651
676
|
|
|
652
677
|
zfs = torch.zeros_like(z_basal)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|