SURE-tools 4.0.0__py3-none-any.whl → 4.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- SURE/SURE_vae.py +3 -3
- {sure_tools-4.0.0.dist-info → sure_tools-4.0.1.dist-info}/METADATA +1 -1
- {sure_tools-4.0.0.dist-info → sure_tools-4.0.1.dist-info}/RECORD +6 -6
- {sure_tools-4.0.0.dist-info → sure_tools-4.0.1.dist-info}/WHEEL +0 -0
- {sure_tools-4.0.0.dist-info → sure_tools-4.0.1.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-4.0.0.dist-info → sure_tools-4.0.1.dist-info}/top_level.txt +0 -0
SURE/SURE_vae.py
CHANGED
|
@@ -294,7 +294,7 @@ class SUREVAE(nn.Module):
|
|
|
294
294
|
)
|
|
295
295
|
|
|
296
296
|
self.decoder_log_mu = MLP(
|
|
297
|
-
[self.latent_dim
|
|
297
|
+
[self.latent_dim] + self.decoder_hidden_layers + [self.input_dim],
|
|
298
298
|
activation=activate_fct,
|
|
299
299
|
output_activation=None,
|
|
300
300
|
post_layer_fct=post_layer_fct,
|
|
@@ -436,7 +436,7 @@ class SUREVAE(nn.Module):
|
|
|
436
436
|
else:
|
|
437
437
|
zfs = torch.zeros_like(zs)
|
|
438
438
|
|
|
439
|
-
log_mu = self.decoder_log_mu(
|
|
439
|
+
log_mu = self.decoder_log_mu(zs+zcs+zps+zfs)
|
|
440
440
|
if self.loss_func in ['bernoulli']:
|
|
441
441
|
log_theta = log_mu
|
|
442
442
|
elif self.loss_func in ['negbinomial']:
|
|
@@ -699,7 +699,7 @@ class SUREVAE(nn.Module):
|
|
|
699
699
|
|
|
700
700
|
zfs = torch.zeros_like(z_basal)
|
|
701
701
|
|
|
702
|
-
log_mu = self.decoder_log_mu(
|
|
702
|
+
log_mu = self.decoder_log_mu(z_basal+zcs+zps+zfs)
|
|
703
703
|
if self.loss_func == 'bernoulli':
|
|
704
704
|
counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
|
|
705
705
|
else:
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
SURE/SURE.py,sha256=8E39np6zhLbT1cp1xYOg5xLwzFHBupKIm1ydLxNKtqM,16654
|
|
2
2
|
SURE/SUREMO.py,sha256=hN0G0ZEBNQdmj0gGlBIy1wjKGKqAMUemDjQeICzvNUY,47644
|
|
3
3
|
SURE/SURE_nsf.py,sha256=VR6YgEiIfu7mRH0XLbovfjd2X3WFuom2J4_AgHrq2dM,49040
|
|
4
|
-
SURE/SURE_vae.py,sha256=
|
|
4
|
+
SURE/SURE_vae.py,sha256=b3xluu49kmuDlcpZnTAJc92SvQI8MPDN9ntBfnKTJOQ,45850
|
|
5
5
|
SURE/SURE_vanilla.py,sha256=I1RHHCpzk8ml1vMdH_gITOzAFXrHYUA7IAwbVSoxmBo,27327
|
|
6
6
|
SURE/__init__.py,sha256=eJN0vlGblWir1JHaoiQqbQHzS_C1PNYcA_ls7UviTqc,444
|
|
7
7
|
SURE/atac/__init__.py,sha256=3smP8IKHfwNCd1G_sZH3pKHXuLkLpFuLtjUTUSy7_As,34
|
|
@@ -15,8 +15,8 @@ SURE/utils/custom_mlp.py,sha256=XvviNUYeaZ5D1evqETyWLxgtL56mmFQR_pxsYpKw0yY,1170
|
|
|
15
15
|
SURE/utils/label.py,sha256=joKO1mSkjZXeLvSeC7GluQk4-_qgGgPqlwWixdcbKMQ,4648
|
|
16
16
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
17
17
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
18
|
-
sure_tools-4.0.
|
|
19
|
-
sure_tools-4.0.
|
|
20
|
-
sure_tools-4.0.
|
|
21
|
-
sure_tools-4.0.
|
|
22
|
-
sure_tools-4.0.
|
|
18
|
+
sure_tools-4.0.1.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
19
|
+
sure_tools-4.0.1.dist-info/METADATA,sha256=m4l4x8KnzdyfZquikfwrSqZXegaeRl6X9plr-fcuGrc,1661
|
|
20
|
+
sure_tools-4.0.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
21
|
+
sure_tools-4.0.1.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
22
|
+
sure_tools-4.0.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|