careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between the two versions.

Potentially problematic release: this version of careamics might be problematic.

Files changed (118)
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +239 -28
  3. careamics/cli/conf.py +19 -31
  4. careamics/cli/main.py +112 -12
  5. careamics/cli/utils.py +29 -0
  6. careamics/config/__init__.py +48 -24
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +50 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +42 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +35 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +109 -21
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +58 -198
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +8 -8
  25. careamics/config/loss_model.py +56 -0
  26. careamics/config/n2n_configuration.py +101 -0
  27. careamics/config/n2v_configuration.py +266 -0
  28. careamics/config/nm_model.py +24 -25
  29. careamics/config/support/__init__.py +7 -7
  30. careamics/config/support/supported_algorithms.py +0 -3
  31. careamics/config/support/supported_architectures.py +0 -4
  32. careamics/config/transformations/__init__.py +10 -4
  33. careamics/config/transformations/transform_model.py +3 -3
  34. careamics/config/transformations/transform_unions.py +42 -0
  35. careamics/config/validators/validator_utils.py +3 -3
  36. careamics/dataset/__init__.py +2 -2
  37. careamics/dataset/dataset_utils/__init__.py +3 -3
  38. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  39. careamics/dataset/dataset_utils/file_utils.py +9 -9
  40. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  41. careamics/dataset/dataset_utils/running_stats.py +22 -23
  42. careamics/dataset/in_memory_dataset.py +11 -12
  43. careamics/dataset/iterable_dataset.py +4 -4
  44. careamics/dataset/iterable_pred_dataset.py +2 -1
  45. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  46. careamics/dataset/patching/random_patching.py +11 -10
  47. careamics/dataset/patching/sequential_patching.py +26 -26
  48. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  49. careamics/dataset/tiling/__init__.py +2 -2
  50. careamics/dataset/tiling/collate_tiles.py +3 -3
  51. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  52. careamics/dataset/tiling/tiled_patching.py +11 -10
  53. careamics/file_io/__init__.py +5 -5
  54. careamics/file_io/read/__init__.py +1 -1
  55. careamics/file_io/read/get_func.py +2 -2
  56. careamics/file_io/write/__init__.py +2 -2
  57. careamics/lightning/__init__.py +5 -5
  58. careamics/lightning/callbacks/__init__.py +1 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  60. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  61. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  62. careamics/lightning/callbacks/progress_bar_callback.py +2 -2
  63. careamics/lightning/lightning_module.py +69 -34
  64. careamics/lightning/train_data_module.py +41 -27
  65. careamics/losses/__init__.py +3 -3
  66. careamics/losses/loss_factory.py +1 -85
  67. careamics/losses/lvae/losses.py +223 -164
  68. careamics/lvae_training/calibration.py +184 -0
  69. careamics/lvae_training/dataset/config.py +2 -2
  70. careamics/lvae_training/dataset/multich_dataset.py +11 -19
  71. careamics/lvae_training/dataset/multifile_dataset.py +3 -2
  72. careamics/lvae_training/dataset/types.py +15 -26
  73. careamics/lvae_training/dataset/utils/index_manager.py +4 -4
  74. careamics/lvae_training/eval_utils.py +125 -213
  75. careamics/model_io/__init__.py +1 -1
  76. careamics/model_io/bioimage/__init__.py +1 -1
  77. careamics/model_io/bioimage/_readme_factory.py +26 -34
  78. careamics/model_io/bioimage/cover_factory.py +171 -0
  79. careamics/model_io/bioimage/model_description.py +56 -34
  80. careamics/model_io/bmz_io.py +42 -42
  81. careamics/model_io/model_io_utils.py +9 -9
  82. careamics/models/layers.py +22 -20
  83. careamics/models/lvae/layers.py +348 -975
  84. careamics/models/lvae/likelihoods.py +10 -8
  85. careamics/models/lvae/lvae.py +214 -275
  86. careamics/models/lvae/noise_models.py +179 -112
  87. careamics/models/lvae/stochastic.py +393 -0
  88. careamics/models/lvae/utils.py +82 -73
  89. careamics/models/model_factory.py +2 -15
  90. careamics/models/unet.py +8 -8
  91. careamics/prediction_utils/__init__.py +1 -1
  92. careamics/prediction_utils/prediction_outputs.py +15 -15
  93. careamics/prediction_utils/stitch_prediction.py +6 -6
  94. careamics/transforms/__init__.py +5 -5
  95. careamics/transforms/compose.py +13 -13
  96. careamics/transforms/n2v_manipulate.py +3 -3
  97. careamics/transforms/pixel_manipulation.py +9 -9
  98. careamics/transforms/xy_random_rotate90.py +4 -4
  99. careamics/utils/__init__.py +5 -5
  100. careamics/utils/context.py +2 -1
  101. careamics/utils/lightning_utils.py +57 -0
  102. careamics/utils/logging.py +11 -10
  103. careamics/utils/serializers.py +2 -0
  104. careamics/utils/torch_utils.py +8 -8
  105. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
  106. careamics-0.0.6.dist-info/RECORD +176 -0
  107. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
  108. careamics/config/architectures/custom_model.py +0 -162
  109. careamics/config/architectures/register_model.py +0 -103
  110. careamics/config/configuration_model.py +0 -603
  111. careamics/config/fcn_algorithm_model.py +0 -152
  112. careamics/config/references/__init__.py +0 -45
  113. careamics/config/references/algorithm_descriptions.py +0 -132
  114. careamics/config/references/references.py +0 -39
  115. careamics/config/transformations/transform_union.py +0 -20
  116. careamics-0.0.4.2.dist-info/RECORD +0 -165
  117. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
  118. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
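
Note on import paths: the moves and renames above change module locations between 0.0.4.2 and 0.0.6 (for example, careamics/config/configuration_factory.py becomes configuration_factories.py, and the data and VAE algorithm models move under careamics/config/data/ and careamics/config/algorithms/). Below is a minimal sketch of how downstream module imports might need updating; it is inferred only from the file list above, and which symbols each module re-exports is an assumption to verify against careamics/config/__init__.py.

# careamics 0.0.4.2 (old layout):
# from careamics.config import configuration_factory
# from careamics.config import data_model, vae_algorithm_model

# careamics 0.0.6 (new layout), module paths taken from the file moves above:
from careamics.config import configuration_factories
from careamics.config.data import data_model
from careamics.config.algorithms import vae_algorithm_model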
@@ -1,58 +1,94 @@
 """
-Ladder VAE (LVAE) Model
+Ladder VAE (LVAE) Model.
 
-The current implementation is based on "Interpretable Unsupervised Diversity Denoising and Artefact Removal, Prakash et al."
+The current implementation is based on "Interpretable Unsupervised Diversity Denoising
+and Artefact Removal, Prakash et al."
 """
 
 from collections.abc import Iterable
-from typing import Dict, List, Tuple
+from typing import Union
 
 import numpy as np
 import torch
 import torch.nn as nn
 
-from careamics.config.architectures import register_model
-
 from ..activation import get_activation
 from .layers import (
     BottomUpDeterministicResBlock,
     BottomUpLayer,
+    GateLayer,
     TopDownDeterministicResBlock,
     TopDownLayer,
 )
-from .utils import Interpolate, ModelType, crop_img_tensor, pad_img_tensor
+from .utils import Interpolate, ModelType, crop_img_tensor
 
 
-@register_model("LVAE")
 class LadderVAE(nn.Module):
+    """
+    Constructor.
+
+    Parameters
+    ----------
+    input_shape : int
+        The size of the input image.
+    output_channels : int
+        The number of output channels.
+    multiscale_count : int
+        The number of scales for multiscale processing.
+    z_dims : list[int]
+        The dimensions of the latent space for each layer.
+    encoder_n_filters : int
+        The number of filters in the encoder.
+    decoder_n_filters : int
+        The number of filters in the decoder.
+    encoder_conv_strides : list[int]
+        The strides for the conv layers encoder.
+    decoder_conv_strides : list[int]
+        The strides for the conv layers decoder.
+    encoder_dropout : float
+        The dropout rate for the encoder.
+    decoder_dropout : float
+        The dropout rate for the decoder.
+    nonlinearity : str
+        The nonlinearity function to use.
+    predict_logvar : bool
+        Whether to predict the log variance.
+    analytical_kl : bool
+        Whether to use analytical KL divergence.
+
+    Raises
+    ------
+    NotImplementedError
+        If only 2D convolutions are supported.
+    """
 
     def __init__(
         self,
         input_shape: int,
         output_channels: int,
         multiscale_count: int,
-        z_dims: List[int],
+        z_dims: list[int],
         encoder_n_filters: int,
         decoder_n_filters: int,
+        encoder_conv_strides: list[int],
+        decoder_conv_strides: list[int],
        encoder_dropout: float,
        decoder_dropout: float,
        nonlinearity: str,
        predict_logvar: bool,
        analytical_kl: bool,
    ):
-        """
-        Constructor.
-
-        Parameters
-        ----------
-
-        """
         super().__init__()
 
         # -------------------------------------------------------
         # Customizable attributes
         self.image_size = input_shape
+        """Input image size. (Z, Y, X) or (Y, X) if the data is 2D."""
+        # TODO: we need to be careful with this since used to be an int.
+        # the tuple of shapes used to be `self.input_shape`.
         self.target_ch = output_channels
+        self.encoder_conv_strides = encoder_conv_strides
+        self.decoder_conv_strides = decoder_conv_strides
         self._multiscale_count = multiscale_count
         self.z_dims = z_dims
         self.encoder_n_filters = encoder_n_filters
@@ -80,7 +116,6 @@ class LadderVAE(nn.Module):
         self.merge_type = "residual"
         self.no_initial_downscaling = True
         self.skip_bottomk_buvalues = 0
-        self.non_stochastic_version = False
         self.stochastic_skip = True
         self.learn_top_prior = True
         self.res_block_type = "bacdbacd" # TODO remove !
@@ -91,9 +126,7 @@ class LadderVAE(nn.Module):
         self._enable_topdown_normalize_factor = True
 
         # Attributes that handle LC -> Hardcoded
-        self.enable_multiscale = (
-            self._multiscale_count is not None and self._multiscale_count > 1
-        )
+        self.enable_multiscale = self._multiscale_count > 1
         self.multiscale_retain_spatial_dims = True
         self.multiscale_lowres_separate_branch = False
         self.multiscale_decoder_retain_spatial_dims = (
@@ -102,14 +135,6 @@ class LadderVAE(nn.Module):
 
         # Derived attributes
         self.n_layers = len(self.z_dims)
-        self.encoder_no_padding_mode = (
-            self.encoder_res_block_skip_padding is True
-            and self.encoder_res_block_kernel > 1
-        )
-        self.decoder_no_padding_mode = (
-            self.decoder_res_block_skip_padding is True
-            and self.decoder_res_block_kernel > 1
-        )
 
         # Others...
         self._tethered_to_input = False
@@ -127,19 +152,41 @@ class LadderVAE(nn.Module):
 
         # -------------------------------------------------------
         # Data attributes
-        self.color_ch = 1
-        self.img_shape = (self.image_size, self.image_size)
+        self.color_ch = 1 # TODO for now we only support 1 channel
         self.normalized_input = True
         # -------------------------------------------------------
 
         # -------------------------------------------------------
         # Loss attributes
-        self._restricted_kl = False # HC
         # enabling reconstruction loss on mixed input
         self.mixed_rec_w = 0
         self.nbr_consistency_w = 0
 
         # -------------------------------------------------------
+        # 3D related stuff
+        self._mode_3D = len(self.image_size) == 3 # TODO refac
+        self._model_3D_depth = self.image_size[0] if self._mode_3D else 1
+        self._decoder_mode_3D = len(self.decoder_conv_strides) == 3
+        if self._mode_3D and not self._decoder_mode_3D:
+            assert self._model_3D_depth % 2 == 1, "3D model depth should be odd"
+        assert (
+            self._mode_3D is True or self._decoder_mode_3D is False
+        ), "Decoder cannot be 3D when encoder is 2D"
+        self._squish3d = self._mode_3D and not self._decoder_mode_3D
+        self._3D_squisher = (
+            None
+            if not self._squish3d
+            else nn.ModuleList(
+                [
+                    GateLayer(
+                        channels=self.encoder_n_filters,
+                        conv_strides=self.encoder_conv_strides,
+                    )
+                    for k in range(len(self.z_dims))
+                ]
+            )
+        )
+        # TODO: this bit is in the Ashesh's confusing-hacky style... Can we do better?
 
         # -------------------------------------------------------
         # # Training attributes
@@ -168,6 +215,11 @@ class LadderVAE(nn.Module):
         ### CREATE MODEL BLOCKS
         # First bottom-up layer: change num channels + downsample by factor 2
         # unless we want to prevent this
+        self.encoder_conv_op = getattr(nn, f"Conv{len(self.encoder_conv_strides)}d")
+        # TODO these should be defined for all layers here ?
+        self.decoder_conv_op = getattr(nn, f"Conv{len(self.decoder_conv_strides)}d")
+        # TODO: would be more readable to have a derived parameters to use like
+        # `conv_dims = len(self.encoder_conv_strides)` and then use `Conv{conv_dims}d`
         stride = 1 if self.no_initial_downscaling else 2
         self.first_bottom_up = self.create_first_bottom_up(stride)
 
@@ -191,7 +243,7 @@ class LadderVAE(nn.Module):
 
         # Output layer --> Project to target_ch many channels
         logvar_ch_needed = self.predict_logvar is not None
-        self.output_layer = self.parameter_net = nn.Conv2d(
+        self.output_layer = self.parameter_net = self.decoder_conv_op(
             self.decoder_n_filters,
             self.target_ch * (1 + logvar_ch_needed),
             kernel_size=3,
@@ -205,6 +257,7 @@ class LadderVAE(nn.Module):
         # PSNR computation on validation.
         # self.label1_psnr = RunningPSNR()
         # self.label2_psnr = RunningPSNR()
+        # TODO: did you add this?
 
         # msg =f'[{self.__class__.__name__}] Stoc:{not self.non_stochastic_version} RecMode:{self.reconstruction_mode} TethInput:{self._tethered_to_input}'
         # msg += f' TargetCh: {self.target_ch}'
@@ -217,7 +270,8 @@ class LadderVAE(nn.Module):
        num_res_blocks: int = 1,
    ) -> nn.Sequential:
        """
-        This method creates the first bottom-up block of the Encoder.
+        Method creates the first bottom-up block of the Encoder.
+
        Its role is to perform a first image compression step.
        It is composed by a sequence of nn.Conv2d + non-linearity +
        BottomUpDeterministicResBlock (1 or more, default is 1).
@@ -225,29 +279,30 @@ class LadderVAE(nn.Module):
        Parameters
        ----------
        init_stride: int
-            The stride used by the initial Conv2d block.
+            The stride used by the intial Conv2d block.
        num_res_blocks: int, optional
-            The number of BottomUpDeterministicResBlocks to include in the layer, default is 1.
+            The number of BottomUpDeterministicResBlocks, default is 1.
        """
+        # From what I got from Ashesh, Z should not be touched in any case.
        nonlin = get_activation(self.nonlin)
-        modules = [
-            nn.Conv2d(
-                in_channels=self.color_ch,
-                out_channels=self.encoder_n_filters,
-                kernel_size=self.encoder_res_block_kernel,
-                padding=(
-                    0
-                    if self.encoder_res_block_skip_padding
-                    else self.encoder_res_block_kernel // 2
-                ),
-                stride=init_stride,
+        conv_block = self.encoder_conv_op(
+            in_channels=self.color_ch,
+            out_channels=self.encoder_n_filters,
+            kernel_size=self.encoder_res_block_kernel,
+            padding=(
+                0
+                if self.encoder_res_block_skip_padding
+                else self.encoder_res_block_kernel // 2
            ),
-            nonlin,
-        ]
+            stride=init_stride,
+        )
+
+        modules = [conv_block, nonlin]
 
        for _ in range(num_res_blocks):
            modules.append(
                BottomUpDeterministicResBlock(
+                    conv_strides=self.encoder_conv_strides,
                    c_in=self.encoder_n_filters,
                    c_out=self.encoder_n_filters,
                    nonlin=nonlin,
@@ -255,7 +310,6 @@ class LadderVAE(nn.Module):
                    batchnorm=self.bottomup_batchnorm,
                    dropout=self.encoder_dropout,
                    res_block_type=self.res_block_type,
-                    skip_padding=self.encoder_res_block_skip_padding,
                    res_block_kernel=self.encoder_res_block_kernel,
                )
            )
@@ -264,7 +318,8 @@ class LadderVAE(nn.Module):
 
    def create_bottom_up_layers(self, lowres_separate_branch: bool) -> nn.ModuleList:
        """
-        This method creates the stack of bottom-up layers of the Encoder
+        Method creates the stack of bottom-up layers of the Encoder.
+
        that are used to generate the so-called `bu_values`.
 
        NOTE:
@@ -274,8 +329,9 @@ class LadderVAE(nn.Module):
        Parameters
        ----------
        lowres_separate_branch: bool
-            Whether the residual block(s) used for encoding the low-res input are shared (`False`) or
-            not (`True`) with the "same-size" residual block(s) in the `BottomUpLayer`'s primary flow.
+            Whether the residual block(s) used for encoding the low-res input are shared
+            (`False`) or not (`True`) with the "same-size" residual block(s) in the
+            `BottomUpLayer`'s primary flow.
        """
        multiscale_lowres_size_factor = 1
        nonlin = get_activation(self.nonlin)
@@ -294,11 +350,11 @@ class LadderVAE(nn.Module):
            # N.B. Only used if layer_enable_multiscale == True, so we updated it only in that case
            multiscale_lowres_size_factor *= 1 + int(layer_enable_multiscale)
 
-            output_expected_shape = (
-                (self.img_shape[0] // 2 ** (i + 1), self.img_shape[1] // 2 ** (i + 1))
-                if self._multiscale_count > 1
-                else None
-            )
+            # TODO: check correctness of this
+            if self._multiscale_count > 1:
+                output_expected_shape = (dim // 2 ** (i + 1) for dim in self.image_size)
+            else:
+                output_expected_shape = None
 
            # Add bottom-up deterministic layer at level i.
            # It's a sequence of residual blocks (BottomUpDeterministicResBlock), possibly with downsampling between them.
@@ -308,14 +364,14 @@ class LadderVAE(nn.Module):
                    n_filters=self.encoder_n_filters,
                    downsampling_steps=self.downsample[i],
                    nonlin=nonlin,
+                    conv_strides=self.encoder_conv_strides,
                    batchnorm=self.bottomup_batchnorm,
                    dropout=self.encoder_dropout,
                    res_block_type=self.res_block_type,
                    res_block_kernel=self.encoder_res_block_kernel,
-                    res_block_skip_padding=self.encoder_res_block_skip_padding,
                    gated=self.gated,
                    lowres_separate_branch=lowres_separate_branch,
-                    enable_multiscale=self.enable_multiscale, # shouldn't the arg be `layer_enable_multiscale` here?
+                    enable_multiscale=self.enable_multiscale, # TODO: shouldn't the arg be `layer_enable_multiscale` here?
                    multiscale_retain_spatial_dims=self.multiscale_retain_spatial_dims,
                    multiscale_lowres_size_factor=multiscale_lowres_size_factor,
                    decoder_retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims,
@@ -327,7 +383,8 @@ class LadderVAE(nn.Module):
 
    def create_top_down_layers(self) -> nn.ModuleList:
        """
-        This method creates the stack of top-down layers of the Decoder.
+        Method creates the stack of top-down layers of the Decoder.
+
        In these layer the `bu`_values` from the Encoder are merged with the `p_params` from the previous layer
        of the Decoder to get `q_params`. Then, a stochastic layer generates a sample from the latent distribution
        with parameters `q_params`. Finally, this sample is fed through a TopDownDeterministicResBlock to
@@ -346,8 +403,6 @@ class LadderVAE(nn.Module):
        When doing unconditional generation, bu_value is not available. Hence the
        merge layer is not used, and z is sampled directly from p_params.
 
-        Parameters
-        ----------
        """
        top_down_layers = nn.ModuleList([])
        nonlin = get_activation(self.nonlin)
@@ -356,7 +411,7 @@ class LadderVAE(nn.Module):
            # Check if this is the top layer
            is_top = i == self.n_layers - 1
 
-            if self._enable_topdown_normalize_factor:
+            if self._enable_topdown_normalize_factor: # TODO: What is this?
                normalize_latent_factor = (
                    1 / np.sqrt(2 * (1 + i)) if len(self.z_dims) > 4 else 1.0
                )
@@ -369,7 +424,8 @@ class LadderVAE(nn.Module):
                    n_res_blocks=self.decoder_blocks_per_layer,
                    n_filters=self.decoder_n_filters,
                    is_top_layer=is_top,
-                    downsampling_steps=self.downsample[i],
+                    conv_strides=self.decoder_conv_strides,
+                    upsampling_steps=self.downsample[i],
                    nonlin=nonlin,
                    merge_type=self.merge_type,
                    batchnorm=self.topdown_batchnorm,
@@ -379,17 +435,11 @@ class LadderVAE(nn.Module):
                    top_prior_param_shape=self.get_top_prior_param_shape(),
                    res_block_type=self.res_block_type,
                    res_block_kernel=self.decoder_res_block_kernel,
-                    res_block_skip_padding=self.decoder_res_block_skip_padding,
                    gated=self.gated,
                    analytical_kl=self.analytical_kl,
-                    restricted_kl=self._restricted_kl,
                    vanilla_latent_hw=self.get_latent_spatial_size(i),
-                    # in no_padding_mode, what gets passed from the encoder are not multiples of 2 and so merging operation does not work natively.
-                    bottomup_no_padding_mode=self.encoder_no_padding_mode,
-                    topdown_no_padding_mode=self.decoder_no_padding_mode,
                    retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims,
-                    non_stochastic_version=self.non_stochastic_version,
-                    input_image_shape=self.img_shape,
+                    input_image_shape=self.image_size,
                    normalize_latent_factor=normalize_latent_factor,
                    conv2d_bias=self.topdown_conv2d_bias,
                    stochastic_use_naive_exponential=self._stochastic_use_naive_exponential,
@@ -398,8 +448,10 @@ class LadderVAE(nn.Module):
        return top_down_layers
 
    def create_final_topdown_layer(self, upsample: bool) -> nn.Sequential:
-        """
-        This method creates the final top-down layer of the Decoder.
+        """Create the final top-down layer of the Decoder.
+
+        NOTE: In this layer, (optional) upsampling is performed by bilinear interpolation
+        instead of transposed convolution (like in other TD layers).
 
        Parameters
        ----------
@@ -419,69 +471,76 @@ class LadderVAE(nn.Module):
                c_in=self.decoder_n_filters,
                c_out=self.decoder_n_filters,
                nonlin=get_activation(self.nonlin),
+                conv_strides=self.decoder_conv_strides,
                batchnorm=self.topdown_batchnorm,
                dropout=self.decoder_dropout,
                res_block_type=self.res_block_type,
                res_block_kernel=self.decoder_res_block_kernel,
-                skip_padding=self.decoder_res_block_skip_padding,
                gated=self.gated,
                conv2d_bias=self.topdown_conv2d_bias,
            )
        )
        return nn.Sequential(*modules)
 
-    def _init_multires(
-        self, config=None
-    ) -> nn.ModuleList: # TODO config: ml_collections.ConfigDict refactor
+    def _init_multires(self, config=None) -> nn.ModuleList:
        """
-        This method defines the input block/branch to encode/compress low-res lateral inputs at different hierarchical levels
-        in the multiresolution approach (LC). The role of the input branches is similar to the one of the first bottom-up layer
-        in the primary flow of the Encoder, namely to compress the lateral input image to a degree that is compatible with the
-        one of the primary flow.
-
-        NOTE 1: Each input branch consists of a sequence of Conv2d + non-linearity + BottomUpDeterministicResBlock.
-        It is meaningful to observe that the `BottomUpDeterministicResBlock` shares the same model attributes with the blocks
-        in the primary flow of the Encoder (e.g., c_in, c_out, dropout, etc. etc.). Moreover, it does not perform downsampling.
-
-        NOTE 2: `_multiscale_count` attribute defines the total number of inputs to the bottom-up pass.
-        In other terms if we have the input patch and n_LC additional lateral inputs, we will have a total of (n_LC + 1) inputs.
+        Method defines the input block/branch to encode/compress low-res lateral inputs.
+
+        at different hierarchical levels
+        in the multiresolution approach (LC). The role of the input branches is similar
+        to the one of the first bottom-up layer in the primary flow of the Encoder,
+        namely to compress the lateral input image to a degree that is compatible with
+        the one of the primary flow.
+
+        NOTE 1: Each input branch consists of a sequence of Conv2d + non-linearity
+        + BottomUpDeterministicResBlock. It is meaningful to observe that the
+        `BottomUpDeterministicResBlock` shares the same model attributes with the blocks
+        in the primary flow of the Encoder (e.g., c_in, c_out, dropout, etc. etc.).
+        Moreover, it does not perform downsampling.
+
+        NOTE 2: `_multiscale_count` attribute defines the total number of inputs to the
+        bottom-up pass. In other terms if we have the input patch and n_LC additional
+        lateral inputs, we will have a total of (n_LC + 1) inputs.
        """
        stride = 1 if self.no_initial_downscaling else 2
        nonlin = get_activation(self.nonlin)
        if self._multiscale_count is None:
            self._multiscale_count = 1
 
-        msg = "Multiscale count({}) should not exceed the number of bottom up layers ({}) by more than 1"
-        msg = msg.format(self._multiscale_count, self.n_layers)
+        msg = (
+            f"Multiscale count ({self._multiscale_count}) should not exceed the number"
+            f"of bottom up layers ({self.n_layers}) by more than 1.\n"
+        )
        assert (
            self._multiscale_count <= 1 or self._multiscale_count <= 1 + self.n_layers
-        ), msg
+        ), msg # TODO how ?
 
        msg = (
-            "if multiscale is enabled, then we are just working with monocrome images."
+            "Multiscale approach only supports monocrome images. "
+            f"Found instead color_ch={self.color_ch}."
        )
-        assert self._multiscale_count == 1 or self.color_ch == 1, msg
+        # assert self._multiscale_count == 1 or self.color_ch == 1, msg
 
        lowres_first_bottom_ups = []
        for _ in range(1, self._multiscale_count):
            first_bottom_up = nn.Sequential(
-                nn.Conv2d(
+                self.encoder_conv_op(
                    in_channels=self.color_ch,
                    out_channels=self.encoder_n_filters,
                    kernel_size=5,
-                    padding=2,
+                    padding="same",
                    stride=stride,
                ),
                nonlin,
                BottomUpDeterministicResBlock(
                    c_in=self.encoder_n_filters,
                    c_out=self.encoder_n_filters,
+                    conv_strides=self.encoder_conv_strides,
                    nonlin=nonlin,
                    downsample=False,
                    batchnorm=self.bottomup_batchnorm,
                    dropout=self.encoder_dropout,
                    res_block_type=self.res_block_type,
-                    skip_padding=self.encoder_res_block_skip_padding,
                ),
            )
            lowres_first_bottom_ups.append(first_bottom_up)
@@ -493,10 +552,9 @@ class LadderVAE(nn.Module):
        )
 
    ### SET OF FORWARD-LIKE METHODS
-    def bottomup_pass(self, inp: torch.Tensor) -> List[torch.Tensor]:
-        """
-        Wrapper of _bottomup_pass().
-        """
+    def bottomup_pass(self, inp: torch.Tensor) -> list[torch.Tensor]:
+        """Wrapper of _bottomup_pass()."""
+        # TODO Remove wrapper
        return self._bottomup_pass(
            inp,
            self.first_bottom_up,
@@ -510,9 +568,10 @@ class LadderVAE(nn.Module):
        first_bottom_up: nn.Sequential,
        lowres_first_bottom_ups: nn.ModuleList,
        bottom_up_layers: nn.ModuleList,
-    ) -> List[torch.Tensor]:
+    ) -> list[torch.Tensor]:
        """
-        This method defines the forward pass through the LVAE Encoder, the so-called
+        Method defines the forward pass through the LVAE Encoder, the so-called.
+
        Bottom-Up pass.
 
        Parameters
@@ -541,7 +600,6 @@ class LadderVAE(nn.Module):
            lowres_x = None
            if self._multiscale_count > 1 and i + 1 < inp.shape[1]:
                lowres_x = lowres_first_bottom_ups[i](inp[:, i + 1 : i + 2])
-
            x, bu_value = bottom_up_layers[i](x, lowres_x=lowres_x)
            bu_values.append(bu_value)
@@ -549,41 +607,40 @@ class LadderVAE(nn.Module):
 
    def topdown_pass(
        self,
-        bu_values: torch.Tensor = None,
-        n_img_prior: torch.Tensor = None,
-        mode_layers: Iterable[int] = None,
-        constant_layers: Iterable[int] = None,
-        forced_latent: List[torch.Tensor] = None,
-        top_down_layers: nn.ModuleList = None,
-        final_top_down_layer: nn.Sequential = None,
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        bu_values: Union[torch.Tensor, None] = None,
+        n_img_prior: Union[torch.Tensor, None] = None,
+        constant_layers: Union[Iterable[int], None] = None,
+        forced_latent: Union[list[torch.Tensor], None] = None,
+        top_down_layers: Union[nn.ModuleList, None] = None,
+        final_top_down_layer: Union[nn.Sequential, None] = None,
+    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """
-        This method defines the forward pass through the LVAE Decoder, the so-called
+        Method defines the forward pass through the LVAE Decoder, the so-called.
+
        Top-Down pass.
 
        Parameters
        ----------
        bu_values: torch.Tensor, optional
-            Output of the bottom-up pass. It will have values from multiple layers of the ladder.
+            Output of the bottom-up pass. It will have values from multiple layers of
+            the ladder.
        n_img_prior: optional
-            When `bu_values` is `None`, `n_img_prior` indicates the number of images to generate
+            When `bu_values` is `None`, `n_img_prior` indicates the number of images to
+            generate
            from the prior (so bottom-up pass is not used at all here).
-        mode_layers: Iterable[int], optional
-            A sequence of indexes associated to the layers in which sampling is disabled and
-            the mode (mean value) is used instead. Set to `None` to avoid this behaviour.
        constant_layers: Iterable[int], optional
-            A sequence of indexes associated to the layers in which a single instance's z is
-            copied over the entire batch (bottom-up path is not used, so only prior is used here).
-            Set to `None` to avoid this behaviour.
-        forced_latent: List[torch.Tensor], optional
-            A list of tensors that are used as fixed latent variables (hence, sampling doesn't take
-            place in this case).
+            A sequence of indexes associated to the layers in which a single instance's
+            z is copied over the entire batch (bottom-up path is not used, so only prior
+            is used here). Set to `None` to avoid this behaviour.
+        forced_latent: list[torch.Tensor], optional
+            A list of tensors that are used as fixed latent variables (hence, sampling
+            doesn't take place in this case).
        top_down_layers: nn.ModuleList, optional
-            A list of top-down layers to use in the top-down pass. If `None`, the method uses the
-            default layers defined in the constructor.
+            A list of top-down layers to use in the top-down pass. If `None`, the method
+            uses the default layers defined in the constructor.
        final_top_down_layer: nn.Sequential, optional
-            The last top-down layer of the top-down pass. If `None`, the method uses the default
-            layers defined in the constructor.
+            The last top-down layer of the top-down pass. If `None`, the method uses the
+            default layers defined in the constructor.
        """
        if top_down_layers is None:
            top_down_layers = self.top_down_layers
@@ -591,11 +648,9 @@ class LadderVAE(nn.Module):
            final_top_down_layer = self.final_top_down
 
        # Default: no layer is sampled from the distribution's mode
-        if mode_layers is None:
-            mode_layers = []
        if constant_layers is None:
            constant_layers = []
-        prior_experiment = len(mode_layers) > 0 or len(constant_layers) > 0
+        prior_experiment = len(constant_layers) > 0
 
        # If the bottom-up inference values are not given, don't do
        # inference, sample from prior instead
@@ -608,11 +663,7 @@ class LadderVAE(nn.Module):
                "if and only if we're not doing inference"
            )
            raise RuntimeError(msg)
-        if (
-            inference_mode
-            and prior_experiment
-            and (self.non_stochastic_version is False)
-        ):
+        if inference_mode and prior_experiment:
            msg = (
                "Prior experiments (e.g. sampling from mode) are not"
                " compatible with inference mode"
@@ -621,34 +672,24 @@ class LadderVAE(nn.Module):
 
        # Sampled latent variables at each layer
        z = [None] * self.n_layers
-
        # KL divergence of each layer
        kl = [None] * self.n_layers
        # Kl divergence restricted, only for the LC enabled setup denoiSplit.
        kl_restricted = [None] * self.n_layers
-
        # mean from which z is sampled.
        q_mu = [None] * self.n_layers
        # log(var) from which z is sampled.
        q_lv = [None] * self.n_layers
-
        # Spatial map of KL divergence for each layer
        kl_spatial = [None] * self.n_layers
-
        debug_qvar_max = [None] * self.n_layers
-
        kl_channelwise = [None] * self.n_layers
-
        if forced_latent is None:
            forced_latent = [None] * self.n_layers
 
-        # log p(z) where z is the sample in the topdown pass
-        # logprob_p = 0.
-
        # Top-down inference/generation loop
-        out = out_pre_residual = None
+        out = None
        for i in reversed(range(self.n_layers)):
-
            # If available, get deterministic node from bottom-up inference
            try:
                bu_value = bu_values[i]
@@ -656,26 +697,23 @@ class LadderVAE(nn.Module):
                bu_value = None
 
            # Whether the current layer should be sampled from the mode
-            use_mode = i in mode_layers
            constant_out = i in constant_layers
 
            # Input for skip connection
-            skip_input = out # TODO or n? or both?
+            skip_input = out
 
            # Full top-down layer, including sampling and deterministic part
-            out, out_pre_residual, aux = top_down_layers[i](
+            out, aux = top_down_layers[i](
                input_=out,
                skip_connection_input=skip_input,
                inference_mode=inference_mode,
                bu_value=bu_value,
                n_img_prior=n_img_prior,
-                use_mode=use_mode,
                force_constant_output=constant_out,
                forced_latent=forced_latent[i],
                mode_pred=self.mode_pred,
                var_clip_max=self._var_clip_max,
            )
-
            # Save useful variables
            z[i] = aux["z"] # sampled variable at this layer (batch, ch, h, w)
            kl[i] = aux["kl_samplewise"] # (batch, )
@@ -708,8 +746,10 @@ class LadderVAE(nn.Module):
        }
        return out, data
 
-    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """
+        Forward pass through the LVAE model.
+
        Parameters
        ----------
        x: torch.Tensor
@@ -717,124 +757,40 @@ class LadderVAE(nn.Module):
        """
        img_size = x.size()[2:]
 
-        # Pad input to size equal to the closest power of 2
-        x_pad = self.pad_input(x)
-
        # Bottom-up inference: return list of length n_layers (bottom to top)
-        bu_values = self.bottomup_pass(x_pad)
+        bu_values = self.bottomup_pass(x)
        for i in range(0, self.skip_bottomk_buvalues):
            bu_values[i] = None
 
-        mode_layers = range(self.n_layers) if self.non_stochastic_version else None
+        if self._squish3d:
+            bu_values = [
+                torch.mean(self._3D_squisher[k](bu_value), dim=2)
+                for k, bu_value in enumerate(bu_values)
+            ]
 
        # Top-down inference/generation
-        out, td_data = self.topdown_pass(bu_values, mode_layers=mode_layers)
+        out, td_data = self.topdown_pass(bu_values)
 
        if out.shape[-1] > img_size[-1]:
            # Restore original image size
            out = crop_img_tensor(out, img_size)
 
        out = self.output_layer(out)
-        if self._tethered_to_input:
-            assert out.shape[1] == 1
-            ch2 = self.get_other_channel(out, x_pad)
-            out = torch.cat([out, ch2], dim=1)
 
        return out, td_data
 
-    ### SET OF UTILS METHODS
-    # def sample_prior(
-    #     self,
-    #     n_imgs,
-    #     mode_layers=None,
-    #     constant_layers=None
-    # ):
-
-    #     # Generate from prior
-    #     out, _ = self.topdown_pass(n_img_prior=n_imgs, mode_layers=mode_layers, constant_layers=constant_layers)
-    #     out = crop_img_tensor(out, self.img_shape)
-
-    #     # Log likelihood and other info (per data point)
-    #     _, likelihood_data = self.likelihood(out, None)
-
-    #     return likelihood_data['sample']
-
-    # ### ???
-    # def sample_from_q(self, x, masks=None):
-    #     """
-    #     This method performs the bottomup_pass() and samples from the
-    #     obtained distribution.
-    #     """
-    #     img_size = x.size()[2:]
-
-    #     # Pad input to make everything easier with conv strides
-    #     x_pad = self.pad_input(x)
-
-    #     # Bottom-up inference: return list of length n_layers (bottom to top)
-    #     bu_values = self.bottomup_pass(x_pad)
-    #     return self._sample_from_q(bu_values, masks=masks)
-    # ### ???
-
-    # def _sample_from_q(self, bu_values, top_down_layers=None, final_top_down_layer=None, masks=None):
-    #     if top_down_layers is None:
-    #         top_down_layers = self.top_down_layers
-    #     if final_top_down_layer is None:
-    #         final_top_down_layer = self.final_top_down
-    #     if masks is None:
-    #         masks = [None] * len(bu_values)
-
-    #     msg = "Multiscale is not supported as of now. You need the output from the previous layers to do this."
-    #     assert self.n_layers == 1, msg
-    #     samples = []
-    #     for i in reversed(range(self.n_layers)):
-    #         bu_value = bu_values[i]
-
-    #         # Note that the first argument can be set to None since we are just dealing with one level
-    #         sample = top_down_layers[i].sample_from_q(None, bu_value, var_clip_max=self._var_clip_max, mask=masks[i])
-    #         samples.append(sample)
-
-    #     return samples
-
-    def reset_for_different_output_size(self, output_size: int) -> None:
-        """Reset shape of output and latent tensors for different output size.
-
-        Used during evaluation to reset expected shapes of tensors when
-        input/output shape changes.
-        For instance, it is needed when the model was trained on, say, 64x64 sized
-        patches, but prediction is done on 128x128 patches.
-        """
-        for i in range(self.n_layers):
-            sz = output_size // 2 ** (1 + i)
-            self.bottom_up_layers[i].output_expected_shape = (sz, sz)
-            self.top_down_layers[i].latent_shape = (output_size, output_size)
-
-    def pad_input(self, x):
-        """
-        Pads input x so that its sizes are powers of 2
-        :param x:
-        :return: Padded tensor
-        """
-        size = self.get_padded_size(x.size())
-        x = pad_img_tensor(x, size)
-        return x
-
    ### SET OF GETTERS
    def get_padded_size(self, size):
        """
        Returns the smallest size (H, W) of the image with actual size given
        as input, such that H and W are powers of 2.
-        :param size: input size, tuple either (N, C, H, w) or (H, W)
+        :param size: input size, tuple either (N, C, H, W) or (H, W)
        :return: 2-tuple (H, W)
        """
        # Make size argument into (heigth, width)
-        if len(size) == 4:
-            size = size[2:]
-        if len(size) != 2:
-            msg = (
-                "input size must be either (N, C, H, W) or (H, W), but it "
-                f"has length {len(size)} (size={size})"
-            )
-            raise RuntimeError(msg)
+        # assert len(size) in [2, 4, 5] # TODO commented out cuz it's weird
+        # We're only interested in the Y,X dimensions
+        size = size[-2:]
 
        if self.multiscale_decoder_retain_spatial_dims is True:
            # In this case, we can go much more deeper and so this is not required
@@ -845,24 +801,21 @@ class LadderVAE(nn.Module):
            dwnsc = self.overall_downscale_factor
 
        # Output smallest powers of 2 that are larger than current sizes
-        padded_size = list(((s - 1) // dwnsc + 1) * dwnsc for s in size)
-
+        padded_size = [((s - 1) // dwnsc + 1) * dwnsc for s in size]
+        # TODO Needed for pad/crop odd sizes. Move to dataset?
        return padded_size
 
    def get_latent_spatial_size(self, level_idx: int):
-        """
-        level_idx: 0 is the bottommost layer, the highest resolution one.
-        """
+        """Level_idx: 0 is the bottommost layer, the highest resolution one."""
        actual_downsampling = level_idx + 1
        dwnsc = 2**actual_downsampling
-        sz = self.get_padded_size(self.img_shape)
+        sz = self.get_padded_size(self.image_size)
        h = sz[0] // dwnsc
        w = sz[1] // dwnsc
        assert h == w
        return h
 
    def get_top_prior_param_shape(self, n_imgs: int = 1):
-        # TODO num channels depends on random variable we're using
 
        # Compute the total downscaling performed in the Encoder
        if self.multiscale_decoder_retain_spatial_dims is False:
@@ -872,26 +825,12 @@ class LadderVAE(nn.Module):
            actual_downsampling = self.n_layers + 1 - self._multiscale_count
        dwnsc = 2**actual_downsampling
 
-        sz = self.get_padded_size(self.img_shape)
-        h = sz[0] // dwnsc
-        w = sz[1] // dwnsc
-        c = self.z_dims[-1] * 2 # mu and logvar
-        top_layer_shape = (n_imgs, c, h, w)
+        h = self.image_size[-2] // dwnsc
+        w = self.image_size[-1] // dwnsc
+        mu_logvar = self.z_dims[-1] * 2 # mu and logvar
+        top_layer_shape = (n_imgs, mu_logvar, h, w)
+        # TODO refactor!
+        if self._model_3D_depth > 1 and self._decoder_mode_3D is True:
+            # TODO check if model_3D_depth is needed ?
+            top_layer_shape = (n_imgs, mu_logvar, self._model_3D_depth, h, w)
        return top_layer_shape
-
-    def get_other_channel(self, ch1, input):
-        assert self.data_std["target"].squeeze().shape == (2,)
-        assert self.data_mean["target"].squeeze().shape == (2,)
-        assert self.target_ch == 2
-        ch1_un = (
-            ch1[:, :1] * self.data_std["target"][:, :1]
-            + self.data_mean["target"][:, :1]
-        )
-        input_un = input * self.data_std["input"] + self.data_mean["input"]
-        ch2_un = self._tethered_ch2_scalar * (
-            input_un - ch1_un * self._tethered_ch1_scalar
-        )
-        ch2 = (ch2_un - self.data_mean["target"][:, -1:]) / self.data_std["target"][
-            :, -1:
-        ]
-        return ch2
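
For orientation, here is a minimal usage sketch against the 0.0.6 constructor shown in the first hunk. All parameter values are illustrative, not documented defaults. Note two assumptions drawn from the diff itself: input_shape is still annotated int, but the new code computes len(self.image_size) and indexes it, so a (Y, X) tuple is passed; and predict_logvar is passed as None because the code tests `self.predict_logvar is not None` despite the bool annotation.

import torch

from careamics.models.lvae.lvae import LadderVAE

model = LadderVAE(
    input_shape=(64, 64),  # (Y, X); a 3-tuple (Z, Y, X) would trigger the 3D path
    output_channels=2,
    multiscale_count=1,  # 1 disables the lateral-contextualization branches
    z_dims=[128, 128, 128, 128],  # one latent size per ladder layer
    encoder_n_filters=64,
    decoder_n_filters=64,
    encoder_conv_strides=[2, 2],  # length 2 -> nn.Conv2d selected via getattr
    decoder_conv_strides=[2, 2],
    encoder_dropout=0.1,
    decoder_dropout=0.1,
    nonlinearity="ELU",  # assumed to be a name accepted by get_activation
    predict_logvar=None,
    analytical_kl=False,
)

x = torch.rand(1, 1, 64, 64)  # (batch, color_ch, Y, X); color_ch is hardcoded to 1
out, td_data = model(x)  # forward() returns the output and the top-down data dict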