careamics-0.0.6-py3-none-any.whl → careamics-0.0.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the package's registry page for more details.

@@ -1,4 +1,4 @@
1
- from typing import Union
1
+ from typing import Union, Optional
2
2
 
3
3
  import numpy as np
4
4
  import torch
@@ -34,9 +34,6 @@ class Calibration:
34
34
  self._bins = num_bins
35
35
  self._bin_boundaries = None
36
36
 
37
- def logvar_to_std(self, logvar: np.ndarray) -> np.ndarray:
38
- return np.exp(logvar / 2)
39
-
40
37
  def compute_bin_boundaries(self, predict_std: np.ndarray) -> np.ndarray:
41
38
  """Compute the bin boundaries for `num_bins` bins and predicted std values."""
42
39
  min_std = np.min(predict_std)
@@ -104,65 +101,75 @@ class Calibration:
104
101
  )
105
102
  rmse_stderr = np.sqrt(stderr) if stderr is not None else None
106
103
 
107
- bin_var = np.mean((std_ch[bin_mask] ** 2))
104
+ bin_var = np.mean(std_ch[bin_mask] ** 2)
108
105
  stats_dict[ch_idx]["rmse"].append(bin_error)
109
106
  stats_dict[ch_idx]["rmse_err"].append(rmse_stderr)
110
107
  stats_dict[ch_idx]["rmv"].append(np.sqrt(bin_var))
111
108
  stats_dict[ch_idx]["bin_count"].append(bin_size)
109
+ self.stats_dict = stats_dict
112
110
  return stats_dict
113
111
 
112
def get_calibrated_factor_for_stdev(
    self,
    pred: Optional[np.ndarray] = None,
    pred_std: Optional[np.ndarray] = None,
    target: Optional[np.ndarray] = None,
    q_s: float = 0.00001,
    q_e: float = 0.99999,
) -> tuple[dict, dict[str, np.ndarray]]:
    """Calibrate the uncertainty by multiplying the predicted std with a scalar.

    For every channel a straight line ``rmse = scalar * rmv + offset`` is fitted
    with ``stats.linregress`` to the per-bin statistics produced by
    ``compute_stats``. Before fitting, the bins at both ends are dropped. The
    cut-offs come from ``get_first_index(bin_count, q_s)`` and
    ``get_last_index(bin_count, q_e)``.

    If ``pred``, ``pred_std`` and ``target`` are all provided, the statistics
    are (re)computed from them. Otherwise the statistics cached on the instance
    by a previous ``compute_stats`` call are reused.

    Parameters
    ----------
    pred : np.ndarray, optional
        Predicted image, forwarded to ``compute_stats``.
    pred_std : np.ndarray, optional
        Predicted std, forwarded to ``compute_stats``.
    target : np.ndarray, optional
        Target image, forwarded to ``compute_stats``.
    q_s : float, optional
        Start quantile, by default 0.00001.
    q_e : float, optional
        End quantile, by default 0.99999.

    Returns
    -------
    tuple[dict, dict[str, np.ndarray]]
        - Per-channel factors ``{ch_idx: {"scalar": slope, "offset": intercept}}``.
        - The same factors stacked into arrays by ``get_factors_array``.

    Raises
    ------
    ValueError
        If no statistics are cached and not all of ``pred``, ``pred_std`` and
        ``target`` are provided.
    """
    have_inputs = all(v is not None for v in (pred, pred_std, target))
    if not hasattr(self, "stats_dict"):
        print("No stats found. Computing stats...")
        if not have_inputs:
            raise ValueError("pred, pred_std, and target must be provided.")
    if have_inputs:
        # Explicitly supplied data always takes precedence over cached stats,
        # otherwise a second call with new data would silently reuse old stats.
        self.stats_dict = self.compute_stats(
            pred=pred, pred_std=pred_std, target=target
        )

    outputs = {}
    for ch_idx, ch_stats in self.stats_dict.items():
        count = ch_stats["bin_count"]
        first_idx = get_first_index(count, q_s)
        last_idx = get_last_index(count, q_e)
        rmv = ch_stats["rmv"]
        # Explicit stop index: ``seq[first:-last]`` yields an empty slice when
        # ``last == 0``; for ``last > 0`` this stop is identical to ``-last``.
        stop = max(len(rmv) - last_idx, 0)
        x = rmv[first_idx:stop]
        y = ch_stats["rmse"][first_idx:stop]
        slope, intercept, *_ = stats.linregress(x, y)
        outputs[ch_idx] = {"scalar": slope, "offset": intercept}
    factors = self.get_factors_array(factors_dict=outputs)
    return outputs, factors
158
+
159
def get_factors_array(
    self, factors_dict: Union[dict[int, dict], list[dict]]
) -> dict[str, np.ndarray]:
    """Stack per-channel calibration factors into broadcastable arrays.

    Parameters
    ----------
    factors_dict : dict[int, dict] or list[dict]
        Per-channel factors indexed ``0 .. n_channels - 1``, for example the
        first output of ``get_calibrated_factor_for_stdev``. Each entry must
        contain ``"scalar"``. ``"offset"`` is optional and defaults to 0.0.

    Returns
    -------
    dict[str, np.ndarray]
        ``{"scalar": ..., "offset": ...}``. Each array has shape
        ``(1, 1, 1, n_channels)``, so it broadcasts along the last axis
        (the channel axis of ``(n, h, w, c)`` arrays).
    """
    n_channels = len(factors_dict)
    # Index by position so both an int-keyed dict and a list are accepted.
    scalars = [factors_dict[ch]["scalar"] for ch in range(n_channels)]
    offsets = [factors_dict[ch].get("offset", 0.0) for ch in range(n_channels)]
    return {
        "scalar": np.array(scalars).reshape(1, 1, 1, -1),
        "offset": np.array(offsets).reshape(1, 1, 1, -1),
    }
161
168
 
162
169
 
163
170
  def plot_calibration(ax, calibration_stats):
164
- first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.001)
165
- last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.999)
171
+ first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.0001)
172
+ last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.9999)
166
173
  ax.plot(
167
174
  calibration_stats[0]["rmv"][first_idx:-last_idx],
168
175
  calibration_stats[0]["rmse"][first_idx:-last_idx],
@@ -170,15 +177,15 @@ def plot_calibration(ax, calibration_stats):
170
177
  label=r"$\hat{C}_0$: Ch1",
171
178
  )
172
179
 
173
- first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.001)
174
- last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.999)
180
+ first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.0001)
181
+ last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.9999)
175
182
  ax.plot(
176
183
  calibration_stats[1]["rmv"][first_idx:-last_idx],
177
184
  calibration_stats[1]["rmse"][first_idx:-last_idx],
178
185
  "o",
179
- label=r"$\hat{C}_1: : Ch2$",
186
+ label=r"$\hat{C}_1$: Ch2",
180
187
  )
181
-
188
+ # TODO add multichannel
182
189
  ax.set_xlabel("RMV")
183
190
  ax.set_ylabel("RMSE")
184
191
  ax.legend()
@@ -97,7 +97,8 @@ class LCMultiChDloader(MultiChDloader):
97
97
  ]
98
98
 
99
99
  self.N = len(t_list)
100
- self.set_img_sz(self._img_sz, self._grid_sz)
100
+ # TODO where tf is self._img_sz defined?
101
+ self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
101
102
  print(
102
103
  f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
103
104
  )
@@ -359,8 +359,8 @@ class MultiChDloader:
359
359
  self._noise_data = self._noise_data[
360
360
  t_list, h_start:h_end, w_start:w_end, :
361
361
  ].copy()
362
-
363
- self.set_img_sz(self._img_sz, self._grid_sz)
362
+ # TODO where tf is self._img_sz defined?
363
+ self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
364
364
  print(
365
365
  f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
366
366
  )
@@ -3,7 +3,7 @@ from enum import Enum
3
3
 
4
4
  class DataType(Enum):
5
5
  Elisa3DData = 0
6
- NicolaData = 1
6
+ HTLIF24Data = 1
7
7
  Pavia3SeqData = 2
8
8
  TavernaSox2GolgiV2 = 3
9
9
  Dao3ChannelWithInput = 4
@@ -7,23 +7,18 @@ It includes functions to:
7
7
  """
8
8
 
9
9
  import os
10
- from typing import List, Literal, Union
10
+ from typing import Optional
11
11
 
12
12
  import matplotlib
13
13
  import matplotlib.pyplot as plt
14
14
  import numpy as np
15
- from scipy import stats
16
15
  import torch
17
- from torch import nn
18
- from torch.utils.data import Dataset
19
16
  from matplotlib.gridspec import GridSpec
20
- from torch.utils.data import DataLoader
17
+ from torch.utils.data import DataLoader, Dataset, Subset
21
18
  from tqdm import tqdm
22
19
 
23
20
  from careamics.lightning import VAEModule
24
-
25
- from careamics.models.lvae.utils import ModelType
26
- from careamics.utils.metrics import scale_invariant_psnr, RunningPSNR
21
+ from careamics.utils.metrics import scale_invariant_psnr
27
22
 
28
23
 
29
24
  class TilingMode:
@@ -149,11 +144,10 @@ def plot_crops(
149
144
  tar,
150
145
  tar_hsnr,
151
146
  recon_img_list,
152
- calibration_stats,
147
+ calibration_stats=None,
153
148
  num_samples=2,
154
149
  baseline_preds=None,
155
150
  ):
156
- """ """
157
151
  if baseline_preds is None:
158
152
  baseline_preds = []
159
153
  if len(baseline_preds) > 0:
@@ -164,15 +158,13 @@ def plot_crops(
164
158
  )
165
159
  print("This happens when we want to predict the edges of the image.")
166
160
  return
161
+ color_ch_list = ["goldenrod", "cyan"]
162
+ color_pred = "red"
163
+ insetplot_xmax_value = 10000
164
+ insetplot_xmin_value = -1000
165
+ inset_min_labelsize = 10
166
+ inset_rect = [0.05, 0.05, 0.4, 0.2]
167
167
 
168
- # color_ch_list = ['goldenrod', 'cyan']
169
- # color_pred = 'red'
170
- # insetplot_xmax_value = 10000
171
- # insetplot_xmin_value = -1000
172
- # inset_min_labelsize = 10
173
- # inset_rect = [0.05, 0.05, 0.4, 0.2]
174
-
175
- # Set plot attributes
176
168
  img_sz = 3
177
169
  ncols = num_samples + len(baseline_preds) + 1 + 1 + 1 + 1 + 1 * (num_samples > 1)
178
170
  grid_factor = 5
@@ -191,7 +183,6 @@ def plot_crops(
191
183
  )
192
184
  params = {"mathtext.default": "regular"}
193
185
  plt.rcParams.update(params)
194
-
195
186
  # plot baselines
196
187
  for i in range(2, 2 + len(baseline_preds)):
197
188
  for col_idx in range(baseline_preds[0].shape[0]):
@@ -471,52 +462,17 @@ def plot_error(target, prediction, cmap=matplotlib.cm.coolwarm, ax=None, max_val
471
462
  plt.colorbar(img_err, ax=ax)
472
463
 
473
464
 
474
- # ------------------------------------------------------------------------------------------------
475
-
476
-
477
- def get_predictions(idx, val_dset, model, mmse_count=50, patch_size=256):
478
- """
479
- Given an index and a validation/test set, it returns the input, target and the reconstructed images for that index.
480
- """
481
- print(f"Predicting for {idx}")
482
- val_dset.set_img_sz(patch_size, 64)
483
-
484
- with torch.no_grad():
485
- # val_dset.enable_noise()
486
- inp, tar = val_dset[idx]
487
- # val_dset.disable_noise()
488
-
489
- inp = torch.Tensor(inp[None])
490
- tar = torch.Tensor(tar[None])
491
- inp = inp.cuda()
492
- x_normalized = model.normalize_input(inp)
493
- tar = tar.cuda()
494
- tar_normalized = model.normalize_target(tar)
495
-
496
- recon_img_list = []
497
- for _ in range(mmse_count):
498
- recon_normalized, td_data = model(x_normalized)
499
- rec_loss, imgs = model.get_reconstruction_loss(
500
- recon_normalized,
501
- x_normalized,
502
- tar_normalized,
503
- return_predicted_img=True,
504
- )
505
- imgs = model.unnormalize_target(imgs)
506
- recon_img_list.append(imgs.cpu().numpy()[0])
507
-
508
- recon_img_list = np.array(recon_img_list)
509
- return inp, tar, recon_img_list
465
+ # -------------------------------------------------------------------------------------
510
466
 
511
467
 
512
- def get_dset_predictions(
468
+ def get_predictions(
513
469
  model: VAEModule,
514
470
  dset: Dataset,
515
471
  batch_size: int,
516
- loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"],
472
+ tile_size: Optional[tuple[int, int]] = None,
517
473
  mmse_count: int = 1,
518
474
  num_workers: int = 4,
519
- ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[float]]:
475
+ ) -> tuple[dict, dict, dict]:
520
476
  """Get patch-wise predictions from a model for the entire dataset.
521
477
 
522
478
  Parameters
@@ -545,6 +501,55 @@ def get_dset_predictions(
545
501
  - losses: Reconstruction losses for the predictions.
546
502
  - psnr: PSNR values for the predictions.
547
503
  """
504
+ if hasattr(dset, "dsets"):
505
+ multifile_stitched_predictions = {}
506
+ multifile_stitched_stds = {}
507
+ for d in dset.dsets:
508
+ stitched_predictions, stitched_stds = get_single_file_mmse(
509
+ model=model,
510
+ dset=d,
511
+ batch_size=batch_size,
512
+ tile_size=tile_size,
513
+ mmse_count=mmse_count,
514
+ num_workers=num_workers,
515
+ )
516
+ # get filename without extension and path
517
+ filename = str(d._fpath).split("/")[-1].split(".")[0]
518
+ multifile_stitched_predictions[filename] = stitched_predictions
519
+ multifile_stitched_stds[filename] = stitched_stds
520
+ return (
521
+ multifile_stitched_predictions,
522
+ multifile_stitched_stds,
523
+ )
524
+ else:
525
+ stitched_predictions, stitched_stds = get_single_file_mmse(
526
+ model=model,
527
+ dset=dset,
528
+ batch_size=batch_size,
529
+ tile_size=tile_size,
530
+ mmse_count=mmse_count,
531
+ num_workers=num_workers,
532
+ )
533
+ # get filename without extension and path
534
+ filename = str(dset._fpath).split("/")[-1].split(".")[0]
535
+ return (
536
+ {filename: stitched_predictions},
537
+ {filename: stitched_stds},
538
+ )
539
+
540
+
541
+ def get_single_file_predictions(
542
+ model: VAEModule,
543
+ dset: Dataset,
544
+ batch_size: int,
545
+ tile_size: Optional[tuple[int, int]] = None,
546
+ grid_size: Optional[int] = None,
547
+ num_workers: int = 4,
548
+ ) -> tuple[np.ndarray, np.ndarray]:
549
+ """Get patch-wise predictions from a model for a single file dataset."""
550
+ if tile_size and grid_size:
551
+ dset.set_img_sz(tile_size, grid_size)
552
+
548
553
  dloader = DataLoader(
549
554
  dset,
550
555
  pin_memory=False,
@@ -552,43 +557,64 @@ def get_dset_predictions(
552
557
  shuffle=False,
553
558
  batch_size=batch_size,
554
559
  )
560
+ model.eval()
561
+ model.cuda()
562
+ tiles = []
563
+ logvar_arr = []
564
+ with torch.no_grad():
565
+ for batch in tqdm(dloader, desc="Predicting tiles"):
566
+ inp, tar = batch
567
+ inp = inp.cuda()
568
+ tar = tar.cuda()
569
+
570
+ # get model output
571
+ rec, _ = model(inp)
555
572
 
556
- gauss_likelihood = model.gaussian_likelihood
557
- nm_likelihood = model.noise_model_likelihood
573
+ # get reconstructed img
574
+ if model.model.predict_logvar is None:
575
+ rec_img = rec
576
+ logvar = torch.tensor([-1])
577
+ else:
578
+ rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
579
+ logvar_arr.append(logvar.cpu().numpy()) # Why do we need this ?
580
+
581
+ tiles.append(rec_img.cpu().numpy())
582
+
583
+ tile_samples = np.concatenate(tiles, axis=0)
584
+ return stitch_predictions_new(tile_samples, dset)
558
585
 
559
- predictions = []
560
- predictions_std = []
561
- losses = []
586
+
587
+ def get_single_file_mmse(
588
+ model: VAEModule,
589
+ dset: Dataset,
590
+ batch_size: int,
591
+ tile_size: Optional[tuple[int, int]] = None,
592
+ mmse_count: int = 1,
593
+ num_workers: int = 4,
594
+ ) -> tuple[np.ndarray, np.ndarray]:
595
+ """Get patch-wise predictions from a model for a single file dataset."""
596
+ dloader = DataLoader(
597
+ dset,
598
+ pin_memory=False,
599
+ num_workers=num_workers,
600
+ shuffle=False,
601
+ batch_size=batch_size,
602
+ )
603
+ if tile_size:
604
+ dset.set_img_sz(tile_size, tile_size[-1] // 2)
605
+ model.eval()
606
+ model.cuda()
607
+ tile_mmse = []
608
+ tile_stds = []
562
609
  logvar_arr = []
563
- num_channels = dset[0][1].shape[0]
564
- patch_psnr_channels = [RunningPSNR() for _ in range(num_channels)]
565
610
  with torch.no_grad():
566
- for batch in tqdm(dloader, desc="Predicting patches"):
611
+ for batch in tqdm(dloader, desc="Predicting tiles"):
567
612
  inp, tar = batch
568
613
  inp = inp.cuda()
569
614
  tar = tar.cuda()
570
615
 
571
616
  rec_img_list = []
572
- for mmse_idx in range(mmse_count):
573
-
574
- # TODO: case of HDN left for future refactoring
575
- # if model_type == ModelType.Denoiser:
576
- # assert model.denoise_channel in [
577
- # "Ch1",
578
- # "Ch2",
579
- # "input",
580
- # ], '"all" denoise channel not supported for evaluation. Pick one of "Ch1", "Ch2", "input"'
581
-
582
- # x_normalized_new, tar_new = model.get_new_input_target(
583
- # (inp, tar, *batch[2:])
584
- # )
585
- # rec, _ = model(x_normalized_new)
586
- # rec_loss, imgs = model.get_reconstruction_loss(
587
- # rec,
588
- # tar,
589
- # x_normalized_new,
590
- # return_predicted_img=True,
591
- # )
617
+ for _ in range(mmse_count):
592
618
 
593
619
  # get model output
594
620
  rec, _ = model(inp)
@@ -600,52 +626,21 @@ def get_dset_predictions(
600
626
  else:
601
627
  rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
602
628
  rec_img_list.append(rec_img.cpu().unsqueeze(0)) # add MMSE dim
603
- logvar_arr.append(logvar.cpu().numpy())
604
-
605
- # compute reconstruction loss
606
- # if loss_type == "musplit":
607
- # rec_loss = get_reconstruction_loss(
608
- # reconstruction=rec, target=tar, likelihood_obj=gauss_likelihood
609
- # )
610
- # elif loss_type == "denoisplit":
611
- # rec_loss = get_reconstruction_loss(
612
- # reconstruction=rec, target=tar, likelihood_obj=nm_likelihood
613
- # )
614
- # elif loss_type == "denoisplit_musplit":
615
- # rec_loss = reconstruction_loss_musplit_denoisplit(
616
- # predictions=rec,
617
- # targets=tar,
618
- # gaussian_likelihood=gauss_likelihood,
619
- # nm_likelihood=nm_likelihood,
620
- # nm_weight=model.loss_parameters.denoisplit_weight,
621
- # gaussian_weight=model.loss_parameters.musplit_weight,
622
- # )
623
- # rec_loss = {"loss": rec_loss} # hacky, but ok for now
624
-
625
- # # store rec loss values for first pred
626
- # if mmse_idx == 0:
627
- # try:
628
- # losses.append(rec_loss["loss"].cpu().numpy())
629
- # except:
630
- # losses.append(rec_loss["loss"])
631
-
632
- # update running PSNR
633
- # for i in range(num_channels):
634
- # patch_psnr_channels[i].update(rec_img[:, i], tar[:, i])
629
+ logvar_arr.append(logvar.cpu().numpy()) # Why do we need this ?
635
630
 
636
631
  # aggregate results
637
632
  samples = torch.cat(rec_img_list, dim=0)
638
633
  mmse_imgs = torch.mean(samples, dim=0) # avg over MMSE dim
639
- # mmse_std = torch.std(samples, dim=0)
640
- predictions.append(mmse_imgs.cpu().numpy())
641
- # predictions_std.append(mmse_std.cpu().numpy())
642
-
643
- # psnr = [x.get() for x in patch_psnr_channels]
644
- return np.concatenate(predictions, axis=0)
645
- # np.concatenate(predictions_std, axis=0),
646
- # np.concatenate(logvar_arr),
647
- # np.array(losses),
648
- # psnr, # TODO revisit !
634
+ std_imgs = torch.std(samples, dim=0) # std over MMSE dim
635
+
636
+ tile_mmse.append(mmse_imgs.cpu().numpy())
637
+ tile_stds.append(std_imgs.cpu().numpy())
638
+
639
+ tiles_arr = np.concatenate(tile_mmse, axis=0)
640
+ tile_stds = np.concatenate(tile_stds, axis=0)
641
+ stitched_predictions = stitch_predictions_new(tiles_arr, dset)
642
+ stitched_stds = stitch_predictions_new(tile_stds, dset)
643
+ return stitched_predictions, stitched_stds
649
644
 
650
645
 
651
646
  # ------------------------------------------------------------------------------------------
@@ -324,6 +324,8 @@ class NoiseModelLikelihood(LikelihoodModule):
324
324
  if self.data_mean.device != correct_device_tensor.device:
325
325
  self.data_mean = self.data_mean.to(correct_device_tensor.device)
326
326
  self.data_std = self.data_std.to(correct_device_tensor.device)
327
+ if correct_device_tensor.device != self.noiseModel.device:
328
+ self.noiseModel.to_device(correct_device_tensor.device)
327
329
 
328
330
  def get_mean_lv(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
329
331
  return x, None
@@ -6,7 +6,7 @@ and Artefact Removal, Prakash et al."
6
6
  """
7
7
 
8
8
  from collections.abc import Iterable
9
- from typing import Union
9
+ from typing import Optional, Union
10
10
 
11
11
  import numpy as np
12
12
  import torch
@@ -834,3 +834,15 @@ class LadderVAE(nn.Module):
834
834
  # TODO check if model_3D_depth is needed ?
835
835
  top_layer_shape = (n_imgs, mu_logvar, self._model_3D_depth, h, w)
836
836
  return top_layer_shape
837
+
838
def reset_for_inference(self, tile_size: Optional[tuple[int, int]] = None):
    """Reconfigure the model to predict on inputs of a different spatial size.

    Should be called if we want to predict for a different input/output size.
    Sets ``mode_pred`` and updates ``image_size`` together with the shapes
    expected by every bottom-up and top-down layer.

    Parameters
    ----------
    tile_size : tuple[int, int], optional
        New spatial size of the input/output. Defaults to the current
        ``self.image_size``.
    """
    self.mode_pred = True
    if tile_size is None:
        tile_size = self.image_size
    self.image_size = tile_size
    for i in range(self.n_layers):
        # Bottom-up layer ``i`` is expected to output ``tile_size // 2**(i + 1)``.
        # Materialise as a tuple: a bare generator expression is a one-shot
        # iterator and would be exhausted after its first use.
        self.bottom_up_layers[i].output_expected_shape = tuple(
            ts // 2 ** (i + 1) for ts in tile_size
        )
        # NOTE(review): every top-down layer gets the full tile size as its
        # latent shape — confirm this is intended.
        self.top_down_layers[i].latent_shape = tile_size