SURE-tools 3.7.0__tar.gz → 3.7.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sure_tools-3.7.0 → sure_tools-3.7.1}/PKG-INFO +1 -1
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE/SURE_vae.py +33 -7
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-3.7.0 → sure_tools-3.7.1}/setup.py +1 -1
- {sure_tools-3.7.0 → sure_tools-3.7.1}/LICENSE +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/README.md +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE/SURE.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE/SUREMO.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE/SURE_nsf.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE/SURE_vanilla.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE/__init__.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE/atac/__init__.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE/atac/utils.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE/dist/__init__.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE/dist/negbinomial.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE/graph/__init__.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE/graph/graph_utils.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE/utils/__init__.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE/utils/label.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE/utils/queue.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE/utils/utils.py +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-3.7.0 → sure_tools-3.7.1}/setup.cfg +0 -0
|
@@ -256,18 +256,33 @@ class SUREVAE(nn.Module):
|
|
|
256
256
|
)
|
|
257
257
|
|
|
258
258
|
if np.sum(self.condition_sizes)>0:
|
|
259
|
-
self.
|
|
259
|
+
self.decoder_condition_effects = nn.ModuleList()
|
|
260
260
|
for condition_size in self.condition_sizes:
|
|
261
|
-
self.
|
|
262
|
-
[condition_size + self.latent_dim] + self.decoder_hidden_layers + [self.latent_dim],
|
|
261
|
+
self.decoder_condition_effects.append(ZeroBiasMLP3(
|
|
262
|
+
[condition_size + self.latent_dim] + self.decoder_hidden_layers + [[self.latent_dim,self.latent_dim]],
|
|
263
263
|
activation=activate_fct,
|
|
264
|
-
output_activation=None,
|
|
264
|
+
output_activation=[None,Exp],
|
|
265
265
|
post_layer_fct=post_layer_fct,
|
|
266
266
|
post_act_fct=post_act_fct,
|
|
267
267
|
allow_broadcast=self.allow_broadcast,
|
|
268
268
|
use_cuda=self.use_cuda,
|
|
269
269
|
)
|
|
270
270
|
)
|
|
271
|
+
|
|
272
|
+
self.encoder_condition_effects = self.decoder_condition_effects
|
|
273
|
+
'''self.encoder_condition_effects = nn.ModuleList()
|
|
274
|
+
for condition_size in self.condition_sizes:
|
|
275
|
+
self.encoder_condition_effects.append(ZeroBiasMLP3(
|
|
276
|
+
[condition_size + self.latent_dim] + self.decoder_hidden_layers + [[self.latent_dim,self.latent_dim]],
|
|
277
|
+
activation=activate_fct,
|
|
278
|
+
output_activation=[None,Exp],
|
|
279
|
+
post_layer_fct=post_layer_fct,
|
|
280
|
+
post_act_fct=post_act_fct,
|
|
281
|
+
allow_broadcast=self.allow_broadcast,
|
|
282
|
+
use_cuda=self.use_cuda,
|
|
283
|
+
)
|
|
284
|
+
)'''
|
|
285
|
+
|
|
271
286
|
if self.covariate_size>0:
|
|
272
287
|
self.covariate_effect = ZeroBiasMLP2(
|
|
273
288
|
[self.covariate_size] + self.decoder_hidden_layers + [self.latent_dim],
|
|
@@ -409,7 +424,10 @@ class SUREVAE(nn.Module):
|
|
|
409
424
|
shift = 0
|
|
410
425
|
for i, condition_size in enumerate(self.condition_sizes):
|
|
411
426
|
cs_i = cs[:,shift:(shift+condition_size)]
|
|
412
|
-
|
|
427
|
+
l,s = self.decoder_condition_effects[i]([cs_i,zns])
|
|
428
|
+
zcs_i = pyro.sample(f'zcs_{i}', dist.Normal(l, s).to_event(1))
|
|
429
|
+
#zcs += self.condition_effects[i]([cs_i,zns])
|
|
430
|
+
zcs += zcs_i
|
|
413
431
|
shift += condition_size
|
|
414
432
|
else:
|
|
415
433
|
zcs = torch.zeros_like(zs)
|
|
@@ -458,6 +476,13 @@ class SUREVAE(nn.Module):
|
|
|
458
476
|
zn_loc, zn_scale = self.encoder_zn(xs)
|
|
459
477
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
460
478
|
|
|
479
|
+
shift = 0
|
|
480
|
+
for i, condition_size in enumerate(self.condition_sizes):
|
|
481
|
+
cs_i = cs[:,shift:(shift+condition_size)]
|
|
482
|
+
l,s = self.encoder_condition_effects[i]([cs_i,zns])
|
|
483
|
+
zcs_i = pyro.sample(f'zcs_{i}', dist.Normal(l, s).to_event(1))
|
|
484
|
+
shift += condition_size
|
|
485
|
+
|
|
461
486
|
#alpha = self.encoder_n(zns)
|
|
462
487
|
alpha = self.encoder_alpha(zns)
|
|
463
488
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
@@ -602,7 +627,7 @@ class SUREVAE(nn.Module):
|
|
|
602
627
|
C_batch = cs[idx].to(self.get_device())
|
|
603
628
|
|
|
604
629
|
z_basal = self._get_cell_embedding(X_batch)
|
|
605
|
-
dzs = self.
|
|
630
|
+
dzs,_ = self.encoder_condition_effects[i]([C_batch,z_basal])
|
|
606
631
|
|
|
607
632
|
A.append(tensor_to_numpy(dzs))
|
|
608
633
|
pbar.update(1)
|
|
@@ -646,7 +671,8 @@ class SUREVAE(nn.Module):
|
|
|
646
671
|
shift = 0
|
|
647
672
|
for i, condition_size in enumerate(self.condition_sizes):
|
|
648
673
|
C_batch_i = C_batch[:, shift:(shift+condition_size)]
|
|
649
|
-
|
|
674
|
+
zcs_i,_ = self.encoder_condition_effects[i]([C_batch_i,z_basal])
|
|
675
|
+
zcs += zcs_i
|
|
650
676
|
shift += condition_size
|
|
651
677
|
|
|
652
678
|
zfs = torch.zeros_like(z_basal)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|