waveorder 2.2.1b0__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. waveorder/_version.py +16 -3
  2. waveorder/acq/__init__.py +0 -0
  3. waveorder/acq/acq_functions.py +166 -0
  4. waveorder/assets/HSV_legend.png +0 -0
  5. waveorder/assets/JCh_legend.png +0 -0
  6. waveorder/assets/waveorder_plugin_logo.png +0 -0
  7. waveorder/calib/Calibration.py +1512 -0
  8. waveorder/calib/Optimization.py +470 -0
  9. waveorder/calib/__init__.py +0 -0
  10. waveorder/calib/calibration_workers.py +464 -0
  11. waveorder/cli/apply_inverse_models.py +328 -0
  12. waveorder/cli/apply_inverse_transfer_function.py +379 -0
  13. waveorder/cli/compute_transfer_function.py +432 -0
  14. waveorder/cli/gui_widget.py +58 -0
  15. waveorder/cli/main.py +39 -0
  16. waveorder/cli/monitor.py +163 -0
  17. waveorder/cli/option_eat_all.py +47 -0
  18. waveorder/cli/parsing.py +122 -0
  19. waveorder/cli/printing.py +16 -0
  20. waveorder/cli/reconstruct.py +67 -0
  21. waveorder/cli/settings.py +187 -0
  22. waveorder/cli/utils.py +175 -0
  23. waveorder/filter.py +1 -2
  24. waveorder/focus.py +136 -25
  25. waveorder/io/__init__.py +0 -0
  26. waveorder/io/_reader.py +61 -0
  27. waveorder/io/core_functions.py +272 -0
  28. waveorder/io/metadata_reader.py +195 -0
  29. waveorder/io/utils.py +175 -0
  30. waveorder/io/visualization.py +160 -0
  31. waveorder/models/inplane_oriented_thick_pol3d_vector.py +3 -3
  32. waveorder/models/isotropic_fluorescent_thick_3d.py +92 -0
  33. waveorder/models/isotropic_fluorescent_thin_3d.py +331 -0
  34. waveorder/models/isotropic_thin_3d.py +73 -72
  35. waveorder/models/phase_thick_3d.py +103 -4
  36. waveorder/napari.yaml +36 -0
  37. waveorder/plugin/__init__.py +9 -0
  38. waveorder/plugin/gui.py +1094 -0
  39. waveorder/plugin/gui.ui +1440 -0
  40. waveorder/plugin/job_manager.py +42 -0
  41. waveorder/plugin/main_widget.py +1605 -0
  42. waveorder/plugin/tab_recon.py +3294 -0
  43. waveorder/scripts/__init__.py +0 -0
  44. waveorder/scripts/launch_napari.py +13 -0
  45. waveorder/scripts/repeat-cal-acq-rec.py +147 -0
  46. waveorder/scripts/repeat-calibration.py +31 -0
  47. waveorder/scripts/samples.py +85 -0
  48. waveorder/scripts/simulate_zarr_acq.py +204 -0
  49. waveorder/util.py +1 -1
  50. waveorder/visuals/napari_visuals.py +1 -1
  51. waveorder-3.0.0.dist-info/METADATA +350 -0
  52. waveorder-3.0.0.dist-info/RECORD +69 -0
  53. {waveorder-2.2.1b0.dist-info → waveorder-3.0.0.dist-info}/WHEEL +1 -1
  54. waveorder-3.0.0.dist-info/entry_points.txt +5 -0
  55. {waveorder-2.2.1b0.dist-info → waveorder-3.0.0.dist-info/licenses}/LICENSE +13 -1
  56. waveorder-2.2.1b0.dist-info/METADATA +0 -187
  57. waveorder-2.2.1b0.dist-info/RECORD +0 -27
  58. {waveorder-2.2.1b0.dist-info → waveorder-3.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,195 @@
1
+ import json
2
+ import os
3
+
4
+ from natsort import natsorted
5
+
6
+
7
def load_json(path):
    """Read the JSON document stored at ``path`` and return the decoded object."""
    with open(path, "r") as json_file:
        return json.load(json_file)
12
+
13
+
14
def get_last_metadata_file(path):
    """Return the full path of the last ``calibration_metadata*`` file in ``path``.

    "Last" means last in natural sort order (so ``..._10`` sorts after
    ``..._9``), not most recently modified. Raises IndexError when the folder
    contains no matching file.
    """
    candidates = [
        name
        for name in os.listdir(path)
        if name.startswith("calibration_metadata")
    ]
    last_name = natsorted(candidates)[-1]
    return os.path.join(path, last_name)
23
+
24
+
25
class MetadataReader:
    """
    Calibration metadata reader class. Helps load metadata from different metadata formats and naming conventions.

    Values are looked up under a top-level ``"Calibration"`` section of the
    JSON document and, when absent there, under a ``"Summary"`` section that
    uses different key names (e.g. ``"Swing (waves)"`` vs
    ``"Swing (fraction)"``).
    """

    def __init__(self, path: str):
        """
        Load the JSON metadata at ``path`` and populate the public attributes.

        Parameters
        ----------
        path: full path to calibration metadata

        Raises
        ------
        KeyError
            If a required value (swing, black level, extinction ratio, channel
            names, calibration scheme, ...) is missing from both sections.
        ValueError
            If the calibration scheme is neither "4-State" nor "5-State".
        """
        self.metadata_path = path
        self.json_metadata = load_json(self.metadata_path)

        self.Timestamp = self.get_summary_calibration_attr("Timestamp")
        self.waveorder_version = self.get_summary_calibration_attr(
            "waveorder version"
        )
        # Must be set before the LC/swing getters: they call get_cal_states(),
        # which reads self.Calibration_scheme.
        self.Calibration_scheme = self.get_calibration_scheme()
        self.Swing = self.get_swing()
        self.Wavelength = self.get_summary_calibration_attr("Wavelength (nm)")
        self.Black_level = self.get_black_level()
        self.Extinction_ratio = self.get_extinction_ratio()
        self.Channel_names = self.get_channel_names()
        self.LCA_retardance = self.get_lc_retardance("LCA")
        self.LCB_retardance = self.get_lc_retardance("LCB")
        self.LCA_voltage = self.get_lc_voltage("LCA")
        self.LCB_voltage = self.get_lc_voltage("LCB")
        self.Swing_measured = self.get_swing_measured()
        self.Notes = self.get_notes()

    def _get_with_fallback(self, calibration_key, summary_key):
        """Return ``Calibration[calibration_key]``, else ``Summary[summary_key]``.

        Raises KeyError when the value is missing from both sections.
        """
        try:
            return self.json_metadata["Calibration"][calibration_key]
        except KeyError:
            return self.json_metadata["Summary"][summary_key]

    def get_summary_calibration_attr(self, attr):
        """Return ``attr`` from "Summary", else from "Calibration", else None."""
        for section in ("Summary", "Calibration"):
            try:
                return self.json_metadata[section][attr]
            except KeyError:
                pass
        return None

    def get_cal_states(self):
        """
        Return the calibration state labels for ``self.Calibration_scheme``.

        The first entry is always the extinction state ("ext"). A new list is
        returned on every call, so callers may modify it.

        Raises
        ------
        ValueError
            If the scheme is neither "4-State" nor "5-State".
        """
        scheme = self.Calibration_scheme
        if scheme == "4-State":
            return ["ext", "0", "60", "120"]
        if scheme == "5-State":
            return ["ext", "0", "45", "90", "135"]
        raise ValueError(
            f"Unsupported calibration scheme {scheme!r}; "
            "expected '4-State' or '5-State'."
        )

    def get_lc_retardance(self, lc):
        """
        Return the retardance of one LC for every calibration state.

        Parameters
        ----------
        lc: 'LCA' or 'LCB'

        Returns
        -------
        list or None
            One value per state, in get_cal_states() order. Returns None if
            ``lc`` is not 'LCA'/'LCB' and the "Calibration" layout lacks it.

        Raises
        ------
        KeyError
            If the value is missing from both the "Calibration" and the
            "Summary" layout.
        """
        states = self.get_cal_states()
        try:
            lc_retardance = self.json_metadata["Calibration"]["LC retardance"]
            return [lc_retardance[f"{lc}_{state}"] for state in states]
        except KeyError:
            pass

        # Older "Summary" layout: one [LCA, LCB] pair per state, keyed by
        # "[LCA_<state>, LCB_<state>]", with the extinction state spelled "Ext".
        if lc == "LCA":
            pair_index = 0
        elif lc == "LCB":
            pair_index = 1
        else:
            return None
        states[0] = "Ext"
        summary = self.json_metadata["Summary"]
        return [
            summary[f"[LCA_{state}, LCB_{state}]"][pair_index]
            for state in states
        ]

    def get_lc_voltage(self, lc):
        """
        Return the voltage of one LC for every calibration state.

        Parameters
        ----------
        lc: 'LCA' or 'LCB'

        Returns
        -------
        list or None
            One voltage per state, in get_cal_states() order. Returns None when
            there is no "Calibration" section, or it has no or an empty
            "LC voltage" entry.
        """
        states = self.get_cal_states()

        calibration = self.json_metadata.get("Calibration") or {}
        # Use .get: older or partial metadata may lack "LC voltage" entirely.
        lc_voltage = calibration.get("LC voltage")
        if not lc_voltage:
            return None
        return [lc_voltage[f"{lc}_{state}"] for state in states]

    def get_swing(self):
        """Return "Swing (waves)" from Calibration, else "Swing (fraction)" from Summary."""
        return self._get_with_fallback("Swing (waves)", "Swing (fraction)")

    def get_swing_measured(self):
        """Return the measured swing of each non-extinction state, in state order."""
        states = self.get_cal_states()[1:]
        try:
            return [
                self.json_metadata["Calibration"][f"Swing_{state}"]
                for state in states
            ]
        except KeyError:
            return [
                self.json_metadata["Summary"][f"Swing{state}"]
                for state in states
            ]

    def get_calibration_scheme(self):
        """Return "Calibration scheme" from Calibration, else "Acquired Using" from Summary."""
        return self._get_with_fallback("Calibration scheme", "Acquired Using")

    def get_black_level(self):
        """Return "Black level" from Calibration, else "BlackLevel" from Summary."""
        return self._get_with_fallback("Black level", "BlackLevel")

    def get_extinction_ratio(self):
        """Return "Extinction ratio" from Calibration, else "Extinction Ratio" from Summary."""
        return self._get_with_fallback("Extinction ratio", "Extinction Ratio")

    def get_channel_names(self):
        """Return "Channel names" from Calibration, else "ChNames" from Summary."""
        return self._get_with_fallback("Channel names", "ChNames")

    def get_notes(self):
        """Return the top-level "Notes" entry, or None if absent."""
        return self.json_metadata.get("Notes")
waveorder/io/utils.py ADDED
@@ -0,0 +1,175 @@
1
+ import os
2
+ import textwrap
3
+ from pathlib import Path
4
+
5
+ import psutil
6
+ import torch
7
+ import yaml
8
+ from iohub import open_ome_zarr
9
+
10
+ from waveorder.cli.settings import MyBaseModel
11
+
12
+
13
def add_index_to_path(path: Path):
    """Takes a path to a file or folder and appends the smallest index that does
    not already exist in that folder.

    For example:
    './output.txt' -> './output_0.txt' if './output_0.txt' does not exist.
    './output.txt' -> './output_2.txt' if './output_0.txt' and './output_1.txt' already exist.

    Parameters
    ----------
    path: Path or str
        Base path to add index to. A str is converted to a Path.

    Returns
    -------
    Path
        ``<parent>/<stem>_<index><suffix>`` for the smallest non-negative
        index whose path does not exist yet.
    """
    # Accept str as well as Path; .stem/.suffix/.parent need a Path.
    path = Path(path)

    index = 0
    candidate = path.parent / f"{path.stem}_{index}{path.suffix}"
    while candidate.exists():
        index += 1
        candidate = path.parent / f"{path.stem}_{index}{path.suffix}"

    return candidate
38
+
39
+
40
def load_background(background_path):
    """Load the stored background image as a float32 torch tensor.

    Opens ``<background_path>/background.zarr/0/0/0`` with ``open_ome_zarr``.
    It reads array ``"0"`` at index ``[0, :, 0]``: the first index of axis 0,
    every entry of axis 1 and the first index of axis 2.
    NOTE(review): assumes a TCZYX array, so the result is the CYX background
    at t=0, z=0. Confirm against the writer.
    """
    position_path = os.path.join(
        background_path, "background.zarr", "0", "0", "0"
    )
    with open_ome_zarr(position_path) as position:
        background_cyx = position["0"][0, :, 0]
    return torch.tensor(background_cyx, dtype=torch.float32)
46
+
47
+
48
class MockEmitter:
    """No-op object that accepts ``emit(value)`` calls and ignores them.

    Presumably a stand-in for a GUI signal when running without one (verify
    against callers).
    """

    def emit(self, value):
        """Discard ``value``; always returns None."""
        return None
51
+
52
+
53
def ram_message():
    """
    Determine if the system's RAM capacity is sufficient for running reconstruction.

    Total physical memory is read from ``psutil.virtual_memory().total`` and
    expressed in units of 2**30 bytes. It is reported as a warning when it is
    below 32.

    Returns
    -------
    tuple[bool, str]
        ``(is_warning, message)``. When warning, the message is word-wrapped
        with ``textwrap.wrap`` and joined with " \\n".
    """
    bytes_per_gb = 2**30
    total_gb = psutil.virtual_memory().total / bytes_per_gb
    is_warning = total_gb < 32

    if not is_warning:
        return (is_warning, f"{total_gb:.1f} GB of RAM is available.")

    advice = (
        f"waveorder reconstructions often require more than the {total_gb:.1f} "
        "GB of RAM that this computer is equipped with. We recommend starting with reconstructions of small "
        "volumes ~1000 x 1000 x 10 and working up to larger volumes while monitoring your RAM usage with "
        "Task Manager or htop."
    )
    return (is_warning, " \n".join(textwrap.wrap(advice)))
79
+
80
+
81
def model_to_yaml(model: MyBaseModel, yaml_path: Path) -> None:
    """
    Save a model's dictionary representation to a YAML file.

    Parameters
    ----------
    model : MyBaseModel
        The model object to convert to YAML. It must provide ``model_dump()``
        (pydantic v2) or, failing that, ``dict()`` (pydantic v1).
    yaml_path : Path
        The path to the output YAML file (a str is accepted). The file is
        overwritten if it exists.

    Raises
    ------
    TypeError
        If the `model` object has neither a `model_dump()` nor a `dict()` method.

    Notes
    -----
    Top-level fields whose value is None are removed before writing; nested
    None values are kept. Keys are written in model order (``sort_keys=False``).

    Examples
    --------
    # >>> from my_model import MyModel
    # >>> model = MyModel()
    # >>> model_to_yaml(model, 'model.yaml')

    """
    yaml_path = Path(yaml_path)

    # Validate the method that is actually called: pydantic v2 provides
    # ``model_dump``. Fall back to v1's ``dict`` only if ``model_dump`` is absent.
    dump = getattr(model, "model_dump", None)
    if not callable(dump):
        dump = getattr(model, "dict", None)
    if not callable(dump):
        raise TypeError(
            "The 'model' object does not have a 'model_dump()' or 'dict()' method."
        )

    model_dict = dump()

    # Remove None-valued top-level fields
    clean_model_dict = {
        key: value for key, value in model_dict.items() if value is not None
    }

    with open(yaml_path, "w") as f:
        yaml.dump(
            clean_model_dict, f, default_flow_style=False, sort_keys=False
        )
126
+
127
+
128
def yaml_to_model(yaml_path: Path, model):
    """
    Load model settings from a YAML file and create a model instance.

    Parameters
    ----------
    yaml_path : Path
        The path to the YAML file containing the model settings (a str is
        accepted).
    model : class
        The model class (or other callable) used to create an instance with
        the loaded settings, called as ``model(**settings)``.

    Returns
    -------
    object
        An instance of the model class with the loaded settings.

    Raises
    ------
    TypeError
        If the provided model is not callable, or the YAML file does not
        contain a mapping (e.g. it is empty or holds a list).
    FileNotFoundError
        If the YAML file specified by `yaml_path` does not exist.

    Notes
    -----
    This function loads model settings from a YAML file using `yaml.safe_load()`.
    It then creates an instance of the provided `model` class using the loaded settings.

    Examples
    --------
    # >>> from my_model import MyModel
    # >>> model = yaml_to_model('model.yaml', MyModel)

    """
    yaml_path = Path(yaml_path)

    # Check ``model`` itself: every Python object has a callable ``__init__``,
    # so inspecting that attribute cannot reject anything.
    if not callable(model):
        raise TypeError(
            "The provided model must be a class with a callable constructor."
        )

    try:
        with open(yaml_path, "r") as file:
            raw_settings = yaml.safe_load(file)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"The YAML file '{yaml_path}' does not exist."
        ) from err

    # safe_load returns None for an empty file; ``**`` needs a mapping.
    if not isinstance(raw_settings, dict):
        raise TypeError(
            f"The YAML file '{yaml_path}' must contain a mapping of settings, "
            f"not {type(raw_settings).__name__}."
        )

    return model(**raw_settings)
@@ -0,0 +1,160 @@
1
+ from typing import Literal, Union
2
+
3
+ import numpy as np
4
+ from colorspacious import cspace_convert
5
+ from matplotlib.colors import hsv_to_rgb
6
+ from skimage.color import hsv2rgb
7
+ from skimage.exposure import rescale_intensity
8
+
9
+
10
def ret_ori_overlay(
    czyx,
    ret_min: float = 1,
    ret_max: Union[float, Literal["auto"]] = 10,
    cmap: Literal["JCh", "HSV"] = "HSV",
):
    """
    Creates an overlay of retardance and orientation with two different colormap options.
    "HSV" maps orientation to hue and retardance to value with maximum saturation.
    "JCh" is a similar colormap but is perceptually uniform.

    Parameters
    ----------
    czyx: (nd-array) czyx[0] is retardance in nanometers, czyx[1] is orientation in radians [0, pi],
        czyx.shape = (2, ...)

    ret_min: (float) minimum displayed retardance. Typically a noise floor.
        Only used by "JCh", where pixels below it get zero chroma.
    ret_max: (float or "auto") maximum displayed retardance. Typically used to adjust contrast limits.
        "auto" uses the 99.99th percentile of the retardance. Retardance is
        clipped to [0, ret_max]. "JCh" scales lightness by ret_max; "HSV"
        scales value by the maximum of the clipped retardance.

    cmap: (str) 'JCh' or 'HSV'

    Returns
    -------
    overlay (nd-array) RGB image with shape (3, ...): uint8 in [0, 255] for
        "JCh", floats from ``hsv_to_rgb`` for "HSV".

    Raises
    ------
    ValueError
        If ``czyx.shape[0] != 2`` or ``cmap`` is not "JCh" or "HSV".
    """
    if czyx.shape[0] != 2:
        raise ValueError(
            f"Input must have shape (2, ...) instead of ({czyx.shape[0]}, ...)"
        )
    if cmap not in ("JCh", "HSV"):
        raise ValueError(f"Colormap {cmap} not understood")

    retardance = czyx[0]
    orientation = czyx[1]

    if ret_max == "auto":
        ret_max = np.percentile(np.ravel(retardance), 99.99)

    # Prepare input arrays
    ret_ = np.clip(retardance, 0, ret_max)  # clip and copy
    # Convert 180 degree range into 360 to match periodicity of hue.
    ori_ = orientation * 360 / np.pi

    if cmap == "JCh":
        J_MAX = 65
        C_MAX = 60

        # ret_max can be 0 (e.g. "auto" on all-zero data); avoid 0/0 -> NaN.
        if ret_max == 0:
            J = np.zeros_like(ori_)
        else:
            J = (ret_ / ret_max) * J_MAX
        C = np.ones_like(J) * C_MAX
        C[ret_ < ret_min] = 0
        h = ori_

        JCh = np.stack((J, C, h), axis=-1)
        JCh_rgb = cspace_convert(JCh, "JCh", "sRGB255")

        # Conversion can leave the sRGB gamut; clamp before casting to uint8.
        overlay_final = np.clip(JCh_rgb, 0, 255).astype(np.uint8)
    else:  # cmap == "HSV"
        ret_peak = np.max(ret_)
        # All-zero retardance would otherwise give 0/0 -> NaN pixels.
        if ret_peak == 0:
            value = np.zeros_like(ori_)
        else:
            value = ret_ / ret_peak
        I_hsv = np.moveaxis(
            np.stack(
                [
                    ori_ / 360,
                    np.ones_like(ori_),
                    value,
                ]
            ),
            source=0,
            destination=-1,
        )
        overlay_final = hsv_to_rgb(I_hsv)

    # Move the trailing RGB axis to the front.
    return np.moveaxis(
        overlay_final, source=-1, destination=0
    )  # .shape = (3, ...)
88
+
89
+
90
def ret_ori_phase_overlay(
    czyx, max_val_V: float = 1.0, max_val_S: float = 1.0
):
    """
    Creates an overlay of retardance, orientation, and phase.

    Builds an HSV image with orientation as hue, retardance as saturation and
    phase as value, then converts it to RGB with ``hsv2rgb``.

    Parameters
    ----------
    czyx : numpy.ndarray
        Stack of shape (3, ...): czyx[0] is the retardance image, czyx[1] is
        the orientation image (range from 0 to pi), czyx[2] is the phase image.
    max_val_V : float
        Divisor for the min-max rescaled phase (value channel). Values below 1
        brighten the phase by 1/max_val_V.
    max_val_S : float
        Divisor for the min-max rescaled retardance (saturation channel).
        Values below 1 strengthen the saturation by 1/max_val_S.

    Returns
    -------
    overlay (nd-array) RGB image with shape (3, ...)

    Raises
    ------
    ValueError
        If ``czyx.shape[0] != 3``.

    Notes
    -----
    No clipping is applied after the division, so ``max_val_* < 1`` can push
    saturation or value above 1.
    """
    if czyx.shape[0] != 3:
        raise ValueError(
            f"Input must have shape (3, ...) instead of ({czyx.shape[0]}, ...)"
        )

    retardance, orientation, phase = czyx[0], czyx[1], czyx[2]

    def _unit_range(image):
        # Min-max rescale of the image to [0, 1].
        return rescale_intensity(
            image,
            in_range=(np.min(image), np.max(image)),
            out_range=(0, 1),
        )

    hue = orientation / np.pi
    saturation = _unit_range(retardance) / max_val_S
    value = _unit_range(phase) / max_val_V

    hsv = np.stack((hue, saturation, value), axis=0)
    return hsv2rgb(hsv, channel_axis=0)
@@ -65,9 +65,9 @@ def calculate_transfer_function(
65
65
  print("Z factor:", z_factor)
66
66
 
67
67
  tf_calculation_shape = (
68
- zyx_shape[0] * z_factor * fourier_oversample_factor,
69
- int(np.ceil(zyx_shape[1] * yx_factor * fourier_oversample_factor)),
70
- int(np.ceil(zyx_shape[2] * yx_factor * fourier_oversample_factor)),
68
+ int(zyx_shape[0] * z_factor / fourier_oversample_factor),
69
+ int(np.ceil(zyx_shape[1] * yx_factor / fourier_oversample_factor)),
70
+ int(np.ceil(zyx_shape[2] * yx_factor / fourier_oversample_factor)),
71
71
  )
72
72
 
73
73
  (
@@ -31,7 +31,40 @@ def calculate_transfer_function(
31
31
  z_padding: int,
32
32
  index_of_refraction_media: float,
33
33
  numerical_aperture_detection: float,
34
+ confocal_pinhole_diameter: float | None = None,
34
35
  ) -> Tensor:
36
+ """Calculate the optical transfer function for fluorescence imaging.
37
+
38
+ Supports both widefield and confocal microscopy modes. When
39
+ confocal_pinhole_diameter is None, computes widefield OTF. When specified,
40
+ computes confocal OTF by multiplying excitation and detection PSFs, where
41
+ the detection PSF is downweighted by the pinhole aperture function.
42
+
43
+ Parameters
44
+ ----------
45
+ zyx_shape : tuple[int, int, int]
46
+ Shape of the 3D volume
47
+ yx_pixel_size : float
48
+ Pixel size in YX plane
49
+ z_pixel_size : float
50
+ Pixel size in Z dimension
51
+ wavelength_emission : float
52
+ Emission wavelength
53
+ z_padding : int
54
+ Padding for axial dimension
55
+ index_of_refraction_media : float
56
+ Refractive index of imaging medium
57
+ numerical_aperture_detection : float
58
+ Numerical aperture of detection objective
59
+ confocal_pinhole_diameter : float | None, optional
60
+ Diameter of confocal pinhole in image space (demagnified). If None,
61
+ computes widefield OTF. If specified, computes confocal OTF.
62
+
63
+ Returns
64
+ -------
65
+ Tensor
66
+ 3D optical transfer function
67
+ """
35
68
  transverse_nyquist = sampling.transverse_nyquist(
36
69
  wavelength_emission,
37
70
  numerical_aperture_detection, # ill = det for fluorescence
@@ -43,6 +76,11 @@ def calculate_transfer_function(
43
76
  index_of_refraction_media,
44
77
  )
45
78
 
79
+ # For confocal, double the Nyquist range (half the sampling requirement)
80
+ if confocal_pinhole_diameter is not None:
81
+ transverse_nyquist = transverse_nyquist / 2
82
+ axial_nyquist = axial_nyquist / 2
83
+
46
84
  yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
47
85
  z_factor = int(np.ceil(z_pixel_size / axial_nyquist))
48
86
 
@@ -58,6 +96,7 @@ def calculate_transfer_function(
58
96
  z_padding,
59
97
  index_of_refraction_media,
60
98
  numerical_aperture_detection,
99
+ confocal_pinhole_diameter,
61
100
  )
62
101
  zyx_out_shape = (zyx_shape[0] + 2 * z_padding,) + zyx_shape[1:]
63
102
  return sampling.nd_fourier_central_cuboid(
@@ -65,6 +104,35 @@ def calculate_transfer_function(
65
104
  )
66
105
 
67
106
 
107
+ def _calculate_pinhole_aperture_otf(
108
+ radial_frequencies: Tensor,
109
+ pinhole_diameter: float,
110
+ ) -> Tensor:
111
+ """Calculate the pinhole aperture OTF for confocal microscopy.
112
+
113
+ The pinhole acts as a spatial filter in the image plane. A smaller pinhole
114
+ (approaching a point) gives a broader OTF (approaching flat/ones).
115
+ A larger pinhole gives a narrower OTF (approaching a delta function).
116
+
117
+ Parameters
118
+ ----------
119
+ radial_frequencies : Tensor
120
+ Radial spatial frequencies (units of 1/length)
121
+ pinhole_diameter : float
122
+ Diameter (not radius) of the confocal pinhole (units of length, matching
123
+ radial_frequencies)
124
+
125
+ Returns
126
+ -------
127
+ Tensor
128
+ Pinhole aperture OTF (jinc^2 function)
129
+ """
130
+ argument = pinhole_diameter * radial_frequencies
131
+ j1_values = torch.special.bessel_j1(np.pi * argument)
132
+ jinc = torch.where(argument > 1e-10, j1_values / (2 * argument), 0.5)
133
+ return jinc**2
134
+
135
+
68
136
  def _calculate_wrap_unsafe_transfer_function(
69
137
  zyx_shape: tuple[int, int, int],
70
138
  yx_pixel_size: float,
@@ -73,6 +141,7 @@ def _calculate_wrap_unsafe_transfer_function(
73
141
  z_padding: int,
74
142
  index_of_refraction_media: float,
75
143
  numerical_aperture_detection: float,
144
+ confocal_pinhole_diameter: float | None = None,
76
145
  ) -> Tensor:
77
146
  radial_frequencies = util.generate_radial_frequencies(
78
147
  zyx_shape[1:], yx_pixel_size
@@ -102,6 +171,29 @@ def _calculate_wrap_unsafe_transfer_function(
102
171
  optical_transfer_function = torch.fft.fftn(
103
172
  point_spread_function, dim=(0, 1, 2)
104
173
  )
174
+
175
+ # Confocal: multiply excitation PSF with detection PSF (downweighted by pinhole)
176
+ if confocal_pinhole_diameter is not None:
177
+ pinhole_otf_2d = _calculate_pinhole_aperture_otf(
178
+ radial_frequencies, confocal_pinhole_diameter
179
+ )
180
+ # Detection OTF is downweighted by pinhole
181
+ otf_detection = optical_transfer_function * pinhole_otf_2d[None, :, :]
182
+
183
+ # Convert to PSFs
184
+ psf_excitation = torch.abs(
185
+ torch.fft.ifftn(optical_transfer_function, dim=(0, 1, 2))
186
+ )
187
+ psf_detection = torch.abs(
188
+ torch.fft.ifftn(otf_detection, dim=(0, 1, 2))
189
+ )
190
+
191
+ # Confocal PSF = excitation PSF * detection PSF (in real space)
192
+ psf_confocal = psf_excitation * psf_detection
193
+
194
+ # Convert back to OTF
195
+ optical_transfer_function = torch.fft.fftn(psf_confocal, dim=(0, 1, 2))
196
+
105
197
  optical_transfer_function /= torch.max(
106
198
  torch.abs(optical_transfer_function)
107
199
  ) # normalize