careamics 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/careamist.py +14 -11
- careamics/config/__init__.py +7 -3
- careamics/config/architectures/__init__.py +2 -2
- careamics/config/architectures/architecture_model.py +1 -1
- careamics/config/architectures/custom_model.py +11 -8
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/configuration_factory.py +11 -3
- careamics/config/configuration_model.py +7 -3
- careamics/config/data_model.py +33 -8
- careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +28 -43
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/support/supported_activations.py +1 -0
- careamics/config/support/supported_algorithms.py +17 -4
- careamics/config/support/supported_architectures.py +8 -11
- careamics/config/support/supported_losses.py +3 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/file_io/read/tiff.py +1 -1
- careamics/lightning/__init__.py +3 -2
- careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
- careamics/lightning/lightning_module.py +365 -9
- careamics/lightning/predict_data_module.py +2 -2
- careamics/lightning/train_data_module.py +2 -2
- careamics/losses/__init__.py +11 -1
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/{losses.py → fcn/losses.py} +1 -1
- careamics/losses/loss_factory.py +112 -6
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/{data_utils.py → dataset/data_utils.py} +277 -194
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/{data_modules.py → dataset/vae_dataset.py} +306 -472
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +6 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +2 -2
- careamics/model_io/bmz_io.py +19 -6
- careamics/model_io/model_io_utils.py +16 -4
- careamics/models/__init__.py +1 -3
- careamics/models/activation.py +2 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +21 -21
- careamics/models/lvae/likelihoods.py +180 -128
- careamics/models/lvae/lvae.py +52 -136
- careamics/models/lvae/noise_models.py +318 -186
- careamics/models/lvae/utils.py +2 -2
- careamics/models/model_factory.py +22 -7
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/stitch_prediction.py +16 -2
- careamics/transforms/pixel_manipulation.py +1 -1
- careamics/utils/metrics.py +74 -1
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/METADATA +2 -2
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/RECORD +63 -49
- careamics/config/architectures/vae_model.py +0 -42
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/WHEEL +0 -0
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +0 -0
careamics/lvae_training/train_lvae.py
CHANGED

@@ -1,5 +1,5 @@
 """
-This script is meant to load data,
+This script is meant to load data, initialize the model, and provide the logic for training it.
 """
 
 import glob

@@ -20,8 +20,11 @@ from torch.utils.data import DataLoader
 sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
 print(sys.path)
 
-from careamics.lvae_training.data_modules import
-
+from careamics.lvae_training.dataset.data_modules import (
+    LCMultiChDloader,
+    MultiChDloader,
+)
+from careamics.lvae_training.dataset.data_utils import DataSplitType
 from careamics.lvae_training.lightning_module import LadderVAELight
 from careamics.lvae_training.train_utils import *
careamics/model_io/bioimage/bioimage_utils.py
CHANGED

@@ -25,7 +25,7 @@ def get_unzip_path(zip_path: Union[Path, str]) -> Path:
 def create_env_text(pytorch_version: str) -> str:
     """Create environment yaml content for the bioimage model.
 
-    This installs an
+    This installs an environment with the specified pytorch version and the latest
     changes to careamics.
 
     Parameters
careamics/model_io/bioimage/model_description.py
CHANGED

@@ -204,7 +204,7 @@ def create_model_description(
     config : Configuration
         CAREamics configuration.
     name : str
-        Name
+        Name of the model.
     general_description : str
        General description of the model.
    authors : List[Author]

@@ -252,7 +252,7 @@ def create_model_description(
 
     # weights description
     architecture_descr = ArchitectureFromLibraryDescr(
-        import_from="careamics.models",
+        import_from="careamics.models.unet",
        callable=f"{config.algorithm_config.model.architecture}",
        kwargs=config.algorithm_config.model.model_dump(),
    )
careamics/model_io/bmz_io.py
CHANGED
@@ -12,7 +12,7 @@ from torch import __version__, load, save
 
 from careamics.config import Configuration, load_configuration, save_configuration
 from careamics.config.support import SupportedArchitecture
-from careamics.lightning.lightning_module import
+from careamics.lightning.lightning_module import FCNModule, VAEModule
 
 from .bioimage import (
     create_env_text,

@@ -22,7 +22,9 @@ from .bioimage import (
 )
 
 
-def _export_state_dict(model: CAREamicsModule, path: Union[Path, str]) -> Path:
+def _export_state_dict(
+    model: Union[FCNModule, VAEModule], path: Union[Path, str]
+) -> Path:
     """
     Export the model state dictionary to a file.
 

@@ -52,7 +54,9 @@ def _export_state_dict(model: CAREamicsModule, path: Union[Path, str]) -> Path:
     return path
 
 
-def _load_state_dict(model: CAREamicsModule, path: Union[Path, str]) -> None:
+def _load_state_dict(
+    model: Union[FCNModule, VAEModule], path: Union[Path, str]
+) -> None:
     """
     Load a model from a state dictionary.
 

@@ -74,7 +78,7 @@ def _load_state_dict(model: CAREamicsModule, path: Union[Path, str]) -> None:
 
 # TODO break down in subfunctions
 def export_to_bmz(
-    model:
+    model: Union[FCNModule, VAEModule],
     config: Configuration,
     path_to_archive: Union[Path, str],
     model_name: str,

@@ -187,7 +191,9 @@ def export_to_bmz(
     save_bioimageio_package(model_description, output_path=path_to_archive)
 
 
-def load_from_bmz(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
+def load_from_bmz(
+    path: Union[Path, str]
+) -> Tuple[Union[FCNModule, VAEModule], Configuration]:
     """Load a model from a BioImage Model Zoo archive.
 
     Parameters

@@ -225,7 +231,14 @@ def load_from_bmz(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
     config = load_configuration(config_path)
 
     # create careamics lightning module
-    model
+    if config.algorithm_config.model.architecture == SupportedArchitecture.UNET:
+        model = FCNModule(algorithm_config=config.algorithm_config)
+    elif config.algorithm_config.model.architecture == SupportedArchitecture.LVAE:
+        model = VAEModule(algorithm_config=config.algorithm_config)
+    else:
+        raise ValueError(
+            f"Unsupported architecture {config.algorithm_config.model.architecture}"
+        )  # TODO ugly ?
 
     # load model state dictionary
     _load_state_dict(model, weights_path)
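With this change, `load_from_bmz` instantiates the lightning module that matches the architecture recorded in the archived configuration: `FCNModule` for UNet, `VAEModule` for LVAE. A minimal usage sketch (the archive path is illustrative, not taken from the diff):

    from careamics.model_io.bmz_io import load_from_bmz

    # The returned module type depends on the architecture stored in the
    # BMZ archive's configuration: UNet -> FCNModule, LVAE -> VAEModule.
    model, config = load_from_bmz("path/to/model.zip")
    print(type(model).__name__, config.algorithm_config.model.architecture)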
careamics/model_io/model_io_utils.py
CHANGED

@@ -6,12 +6,14 @@ from typing import Tuple, Union
 import torch
 
 from careamics.config import Configuration
-from careamics.lightning.lightning_module import
+from careamics.lightning.lightning_module import FCNModule, VAEModule
 from careamics.model_io.bmz_io import load_from_bmz
 from careamics.utils import check_path_exists
 
 
-def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
+def load_pretrained(
+    path: Union[Path, str]
+) -> Tuple[Union[FCNModule, VAEModule], Configuration]:
     """
     Load a pretrained model from a checkpoint or a BioImage Model Zoo model.
 

@@ -44,7 +46,9 @@ def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
     )
 
 
-def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
+def _load_checkpoint(
+    path: Union[Path, str]
+) -> Tuple[Union[FCNModule, VAEModule], Configuration]:
     """
     Load a model from a checkpoint and return both model and configuration.
 

@@ -78,6 +82,14 @@ def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
             f"checkpoint: {checkpoint.keys()}"
         ) from e
 
-    model
+    if cfg_dict["algorithm_config"]["model"]["architecture"] == "UNet":
+        model = FCNModule.load_from_checkpoint(path)
+    elif cfg_dict["algorithm_config"]["model"]["architecture"] == "LVAE":
+        model = VAEModule.load_from_checkpoint(path)
+    else:
+        raise ValueError(
+            "Invalid model architecture: "
+            f"{cfg_dict['algorithm_config']['model']['architecture']}"
+        )
 
     return model, Configuration(**cfg_dict)
careamics/models/__init__.py
CHANGED
careamics/models/activation.py
CHANGED
@@ -23,6 +23,8 @@ def get_activation(activation: Union[SupportedActivation, str]) -> Callable:
     """
     if activation == SupportedActivation.RELU:
         return nn.ReLU()
+    elif activation == SupportedActivation.ELU:
+        return nn.ELU()
     elif activation == SupportedActivation.LEAKYRELU:
         return nn.LeakyReLU()
     elif activation == SupportedActivation.TANH:
careamics/models/lvae/layers.py
CHANGED
@@ -117,7 +117,7 @@ class ResidualBlock(nn.Module):
                 bias=conv2d_bias,
             )
             modules.append(conv)
-            modules.append(nonlin
+            modules.append(nonlin)
             if batchnorm:
                 modules.append(nn.BatchNorm2d(channels))
             if dropout is not None:

@@ -126,7 +126,7 @@ class ResidualBlock(nn.Module):
         for i in range(2):
             if batchnorm:
                 modules.append(nn.BatchNorm2d(channels))
-            modules.append(nonlin
+            modules.append(nonlin)
             conv = nn.Conv2d(
                 channels,
                 channels,

@@ -142,7 +142,7 @@ class ResidualBlock(nn.Module):
         for i in range(2):
             if batchnorm:
                 modules.append(nn.BatchNorm2d(channels))
-            modules.append(nonlin
+            modules.append(nonlin)
             conv = nn.Conv2d(
                 channels,
                 channels,

@@ -189,7 +189,7 @@ class GateLayer2d(nn.Module):
         assert kernel_size % 2 == 1
         pad = kernel_size // 2
         self.conv = nn.Conv2d(channels, 2 * channels, kernel_size, padding=pad)
-        self.nonlin = nonlin
+        self.nonlin = nonlin
 
     def forward(self, x):
         x = self.conv(x)

@@ -255,7 +255,7 @@ class ResBlockWithResampling(nn.Module):
     resample: bool, optional
         Whether to perform resampling in the first convolutional layer.
         If `False`, the first convolutional layer just maps the input to a tensor with
-        `inner_channels` channels through 1x1 convolution.
+        `inner_channels` channels through 1x1 convolution. Default is `False`.
     res_block_kernel: Union[int, Iterable[int]], optional
         The kernel size used in the convolutions of the residual block.
         It can be either a single integer or a pair of integers defining the squared kernel.

@@ -837,7 +837,7 @@ class TopDownLayer(nn.Module):
     - In inference mode, parameters of q(z_i|z_i+1) are obtained from the inference path,
     by merging outcomes of bottom-up and top-down passes. The exception is the top layer,
     in which the parameters of q(z_L|x) are set as the output of the topmost bottom-up layer.
-    - On the contrary in
+    - On the contrary in predicition/generative mode, parameters of q(z_i|z_i+1) can be obtained
     once again by merging bottom-up and top-down outputs (CONDITIONAL GENERATION), or it is
     possible to directly sample from the prior p(z_i|z_i+1) (UNCONDITIONAL GENERATION).
 

@@ -899,7 +899,7 @@ class TopDownLayer(nn.Module):
         The number of downsampling steps that has to be done in this layer (typically 1).
         Default is `False`.
     nonlin: Callable, optional
-        The non-linearity function used in the block (e.g., `nn.ReLU`).
+        The non-linearity function used in the block (e.g., `nn.ReLU`). Default is `None`.
     merge_type: Literal["linear", "residual", "residual_ungated"], optional
         The type of merge done in the layer. It can be chosen between "linear", "residual",
         and "residual_ungated". Check the `MergeLayer` class docstring for more information

@@ -931,7 +931,7 @@ class TopDownLayer(nn.Module):
         Whether to set the top prior as learnable.
         If this is set to `False`, in the top-most layer the prior will be N(0,1).
         Otherwise, we will still have a normal distribution whose parameters will be learnt.
-
+        Default is `False`.
     top_prior_param_shape: Iterable[int], optional
         The size of the tensor which expresses the mean and the variance
         of the prior for the top most layer. Default is `None`.

@@ -1102,7 +1102,7 @@ class TopDownLayer(nn.Module):
         The tensor defining the parameters /mu_q and /sigma_q computed during the bottom-up deterministic pass
         at the correspondent hierarchical layer.
     var_clip_max: float, optional
-        The maximum value reachable by the log-variance of the latent
+        The maximum value reachable by the log-variance of the latent distribution.
         Values exceeding this threshold are clipped. Default is `None`.
     mask: Union[None, torch.Tensor], optional
         A tensor that is used to mask the sampled latent tensor. Default is `None`.

@@ -1171,7 +1171,7 @@ class TopDownLayer(nn.Module):
         at the correspondent hierarchical layer.
     """
     if bu_value.shape[-2:] != p_params.shape[-2:]:
-        assert self.bottomup_no_padding_mode is True
+        assert self.bottomup_no_padding_mode is True  # TODO WTF ?
         if self.topdown_no_padding_mode is False:
             assert bu_value.shape[-1] > p_params.shape[-1]
             bu_value = F.center_crop(bu_value, p_params.shape[-2:])

@@ -1218,7 +1218,7 @@ class TopDownLayer(nn.Module):
         A pre-defined latent tensor. If it is not `None`, than it is used as the actual latent tensor and,
         hence, sampling does not happen. Default is `None`.
     use_mode: bool, optional
-
+        Whether the latent tensor should be set as the latent distribution mode.
         In the case of Gaussian, the mode coincides with the mean of the distribution.
         Default is `False`.
     force_constant_output: bool, optional

@@ -1230,7 +1230,7 @@ class TopDownLayer(nn.Module):
     use_uncond_mode: bool, optional
         Whether to use the uncoditional distribution p(z) to sample latents in prediction mode.
     var_clip_max: float
-        The maximum value reachable by the log-variance of the latent
+        The maximum value reachable by the log-variance of the latent distribution.
         Values exceeding this threshold are clipped.
     """
     # Check consistency of arguments

@@ -1241,7 +1241,7 @@ class TopDownLayer(nn.Module):
     p_params = self.get_p_params(input_, n_img_prior)
 
     # Get the parameters for the latent distribution to sample from
-    if inference_mode:
+    if inference_mode:  # TODO What's this ?
         if self.is_top_layer:
             q_params = bu_value
             if mode_pred is False:

@@ -1466,7 +1466,7 @@ class NormalStochasticBlock2d(nn.Module):
         A pre-defined latent tensor. If it is not `None`, than it is used as the actual latent tensor and,
         hence, sampling does not happen.
     use_mode: bool
-
+        Whether the latent tensor should be set as the latent distribution mode.
         In the case of Gaussian, the mode coincides with the mean of the distribution.
     mode_pred: bool
         Whether the model is prediction mode.

@@ -1501,7 +1501,7 @@ class NormalStochasticBlock2d(nn.Module):
     q_params: torch.Tensor
         The input tensor to be processed.
     var_clip_max: float
-        The maximum value reachable by the log-variance of the latent
+        The maximum value reachable by the log-variance of the latent distribution.
         Values exceeding this threshold are clipped.
     """
     _, _, q = self.process_q_params(q_params, var_clip_max)

@@ -1600,7 +1600,7 @@ class NormalStochasticBlock2d(nn.Module):
     p_params: torch.Tensor
         The input tensor to be processed.
     var_clip_max: float
-        The maximum value reachable by the log-variance of the latent
+        The maximum value reachable by the log-variance of the latent distribution.
         Values exceeding this threshold are clipped.
     """
     if self.transform_p_params:

@@ -1636,7 +1636,7 @@ class NormalStochasticBlock2d(nn.Module):
     p_params: torch.Tensor
         The input tensor to be processed.
     var_clip_max: float
-        The maximum value reachable by the log-variance of the latent
+        The maximum value reachable by the log-variance of the latent distribution.
         Values exceeding this threshold are clipped.
     """
     q_params = self.conv_in_q(q_params)

@@ -1681,7 +1681,7 @@ class NormalStochasticBlock2d(nn.Module):
         A pre-defined latent tensor. If it is not `None`, than it is used as the actual latent
         tensor and, hence, sampling does not happen. Default is `None`.
     use_mode: bool, optional
-
+        Whether the latent tensor should be set as the latent distribution mode.
         In the case of Gaussian, the mode coincides with the mean of the distribution.
         Default is `False`.
     force_constant_output: bool, optional

@@ -1697,7 +1697,7 @@ class NormalStochasticBlock2d(nn.Module):
         Whether to use the uncoditional distribution p(z) to sample latents in prediction mode.
         Default is `False`.
     var_clip_max: float, optional
-        The maximum value reachable by the log-variance of the latent
+        The maximum value reachable by the log-variance of the latent distribution.
         Values exceeding this threshold are clipped. Default is `None`.
     """
     debug_qvar_max = 0

@@ -1928,7 +1928,7 @@ class NonStochasticBlock2d(nn.Module):
         A pre-defined latent tensor. If it is not `None`, than it is used as the actual latent
         tensor and, hence, sampling does not happen. Default is `None`.
     use_mode: bool, optional
-
+        Whether the latent tensor should be set as the latent distribution mode.
         In the case of Gaussian, the mode coincides with the mean of the distribution.
         Default is `False`.
     force_constant_output: bool, optional

@@ -1944,7 +1944,7 @@ class NonStochasticBlock2d(nn.Module):
         Whether to use the uncoditional distribution p(z) to sample latents in prediction mode.
         Default is `False`.
     var_clip_max: float, optional
-        The maximum value reachable by the log-variance of the latent
+        The maximum value reachable by the log-variance of the latent distribution.
         Values exceeding this threshold are clipped. Default is `None`.
     """
     debug_qvar_max = 0
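Several of the docstring fixes above complete the description of `var_clip_max`, the ceiling applied to the latent log-variance. A standalone illustration of the documented behavior (a sketch, not careamics code):

    from typing import Optional

    import torch

    def clip_logvar(logvar: torch.Tensor, var_clip_max: Optional[float]) -> torch.Tensor:
        # Values exceeding the threshold are clipped, as the docstrings
        # describe; a None threshold leaves the log-variance untouched.
        if var_clip_max is None:
            return logvar
        return torch.clamp(logvar, max=var_clip_max)

    print(clip_logvar(torch.tensor([0.5, 3.0]), 1.0))  # tensor([0.5000, 1.0000])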