waveorder 2.2.1__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. waveorder/_version.py +16 -3
  2. waveorder/acq/__init__.py +0 -0
  3. waveorder/acq/acq_functions.py +166 -0
  4. waveorder/assets/HSV_legend.png +0 -0
  5. waveorder/assets/JCh_legend.png +0 -0
  6. waveorder/assets/waveorder_plugin_logo.png +0 -0
  7. waveorder/calib/Calibration.py +1512 -0
  8. waveorder/calib/Optimization.py +470 -0
  9. waveorder/calib/__init__.py +0 -0
  10. waveorder/calib/calibration_workers.py +464 -0
  11. waveorder/cli/apply_inverse_models.py +328 -0
  12. waveorder/cli/apply_inverse_transfer_function.py +379 -0
  13. waveorder/cli/compute_transfer_function.py +432 -0
  14. waveorder/cli/gui_widget.py +58 -0
  15. waveorder/cli/main.py +39 -0
  16. waveorder/cli/monitor.py +163 -0
  17. waveorder/cli/option_eat_all.py +47 -0
  18. waveorder/cli/parsing.py +122 -0
  19. waveorder/cli/printing.py +16 -0
  20. waveorder/cli/reconstruct.py +67 -0
  21. waveorder/cli/settings.py +187 -0
  22. waveorder/cli/utils.py +175 -0
  23. waveorder/filter.py +1 -2
  24. waveorder/focus.py +136 -25
  25. waveorder/io/__init__.py +0 -0
  26. waveorder/io/_reader.py +61 -0
  27. waveorder/io/core_functions.py +272 -0
  28. waveorder/io/metadata_reader.py +195 -0
  29. waveorder/io/utils.py +175 -0
  30. waveorder/io/visualization.py +160 -0
  31. waveorder/models/inplane_oriented_thick_pol3d_vector.py +3 -3
  32. waveorder/models/isotropic_fluorescent_thick_3d.py +92 -0
  33. waveorder/models/isotropic_fluorescent_thin_3d.py +331 -0
  34. waveorder/models/isotropic_thin_3d.py +73 -72
  35. waveorder/models/phase_thick_3d.py +103 -4
  36. waveorder/napari.yaml +36 -0
  37. waveorder/plugin/__init__.py +9 -0
  38. waveorder/plugin/gui.py +1094 -0
  39. waveorder/plugin/gui.ui +1440 -0
  40. waveorder/plugin/job_manager.py +42 -0
  41. waveorder/plugin/main_widget.py +1605 -0
  42. waveorder/plugin/tab_recon.py +3294 -0
  43. waveorder/scripts/__init__.py +0 -0
  44. waveorder/scripts/launch_napari.py +13 -0
  45. waveorder/scripts/repeat-cal-acq-rec.py +147 -0
  46. waveorder/scripts/repeat-calibration.py +31 -0
  47. waveorder/scripts/samples.py +85 -0
  48. waveorder/scripts/simulate_zarr_acq.py +204 -0
  49. waveorder/util.py +1 -1
  50. waveorder/visuals/napari_visuals.py +1 -1
  51. waveorder-3.0.0.dist-info/METADATA +350 -0
  52. waveorder-3.0.0.dist-info/RECORD +69 -0
  53. {waveorder-2.2.1.dist-info → waveorder-3.0.0.dist-info}/WHEEL +1 -1
  54. waveorder-3.0.0.dist-info/entry_points.txt +5 -0
  55. {waveorder-2.2.1.dist-info → waveorder-3.0.0.dist-info}/licenses/LICENSE +13 -1
  56. waveorder-2.2.1.dist-info/METADATA +0 -188
  57. waveorder-2.2.1.dist-info/RECORD +0 -27
  58. {waveorder-2.2.1.dist-info → waveorder-3.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,47 @@
1
+ import click
2
+
3
+
4
# Copied directly from https://stackoverflow.com/a/48394004
# Enables `-i ./input.zarr/*/*/*`
class OptionEatAll(click.Option):
    """A click Option that greedily consumes the following arguments.

    With the default ``save_other_options=True`` it stops consuming as
    soon as the next token starts with a known option prefix; otherwise
    it eats everything remaining on the command line. The collected
    values are handed to click as a single tuple.
    """

    def __init__(self, *args, **kwargs):
        # Pop the custom kwargs before delegating to click.Option, which
        # rejects unknown keyword arguments.
        self.save_other_options = kwargs.pop("save_other_options", True)
        nargs = kwargs.pop("nargs", -1)
        assert nargs == -1, "nargs, if set, must be -1 not {}".format(nargs)
        super(OptionEatAll, self).__init__(*args, **kwargs)
        # Filled in by add_to_parser once the underlying parser is known.
        self._previous_parser_process = None
        self._eat_all_parser = None

    def add_to_parser(self, parser, ctx):
        def parser_process(value, state):
            # method to hook to the parser.process
            done = False
            value = [value]
            if self.save_other_options:
                # grab everything up to the next option
                while state.rargs and not done:
                    for prefix in self._eat_all_parser.prefixes:
                        if state.rargs[0].startswith(prefix):
                            done = True
                    if not done:
                        value.append(state.rargs.pop(0))
            else:
                # grab everything remaining
                value += state.rargs
                state.rargs[:] = []
            value = tuple(value)

            # call the actual process
            self._previous_parser_process(value, state)

        retval = super(OptionEatAll, self).add_to_parser(parser, ctx)
        # Monkey-patch the parser registered for one of our option names so
        # that the greedy parser_process above runs instead of the default.
        # NOTE(review): relies on click's private _long_opt/_short_opt
        # attributes -- re-verify on click version upgrades.
        for name in self.opts:
            our_parser = parser._long_opt.get(name) or parser._short_opt.get(
                name
            )
            if our_parser:
                self._eat_all_parser = our_parser
                self._previous_parser_process = our_parser.process
                our_parser.process = parser_process
                break
        return retval
@@ -0,0 +1,122 @@
1
from pathlib import Path
from typing import Callable, Optional

import click
import torch.multiprocessing as mp
from iohub.ngff import Plate, open_ome_zarr
from natsort import natsorted

from waveorder.cli.option_eat_all import OptionEatAll
10
+
11
+
12
def _validate_and_process_paths(
    ctx: click.Context, opt: click.Option, value: str
) -> list[Path]:
    """Sort and validate input paths, expanding plates into position lists.

    Each path is opened to verify it is a readable OME-Zarr store. A path
    that points at an HCS plate is replaced by the paths of all positions
    it contains; plain position paths are kept as-is.
    """
    processed_paths = []
    for path in [Path(path) for path in natsorted(value)]:
        with open_ome_zarr(path, mode="r") as dataset:
            if isinstance(dataset, Plate):
                # Expand a plate into its individual position paths.
                # (The previous implementation popped the *last* list
                # element here and mutated the list while iterating, which
                # dropped the wrong path when several inputs were given.)
                for position_name, _ in dataset.positions():
                    processed_paths.append(path / position_name)
            else:
                processed_paths.append(path)
    return processed_paths
25
+
26
+
27
+ def _str_to_path(ctx: click.Context, opt: click.Option, value: str) -> Path:
28
+ return Path(value)
29
+
30
+
31
def input_position_dirpaths() -> Callable:
    """Decorator factory adding the ``-i/--input-position-dirpaths`` option.

    The option greedily consumes all following arguments (so shell
    wildcards like ``input.zarr/*/*/*`` work), then validates and expands
    them into a sorted list of position paths.
    """

    def decorator(f: Callable) -> Callable:
        option = click.option(
            "--input-position-dirpaths",
            "-i",
            cls=OptionEatAll,
            type=tuple,
            required=True,
            callback=_validate_and_process_paths,
            help="List of paths to input positions, each with the same TCZYX shape. Supports wildcards e.g. 'input.zarr/*/*/*'.",
        )
        return option(f)

    return decorator
44
+
45
+
46
def config_filepath() -> Callable:
    """Decorator factory adding the ``-c/--config-filepath`` option."""

    def decorator(f: Callable) -> Callable:
        option = click.option(
            "--config-filepath",
            "-c",
            required=True,
            type=click.Path(exists=True, file_okay=True, dir_okay=False),
            callback=_str_to_path,
            help="Path to YAML configuration file.",
        )
        return option(f)

    return decorator
58
+
59
+
60
def transfer_function_dirpath() -> Callable:
    """Decorator factory adding the ``-t/--transfer-function-dirpath`` option."""

    def decorator(f: Callable) -> Callable:
        option = click.option(
            "--transfer-function-dirpath",
            "-t",
            required=True,
            type=click.Path(exists=False),
            callback=_str_to_path,
            help="Path to transfer function .zarr.",
        )
        return option(f)

    return decorator
72
+
73
+
74
def output_dirpath() -> Callable:
    """Decorator factory adding the ``-o/--output-dirpath`` option."""

    def decorator(f: Callable) -> Callable:
        option = click.option(
            "--output-dirpath",
            "-o",
            required=True,
            type=click.Path(exists=False),
            callback=_str_to_path,
            help="Path to output directory.",
        )
        return option(f)

    return decorator
86
+
87
+
88
# TODO: this setting will have to be collected from SLURM?
def processes_option(default: Optional[int] = None) -> Callable:
    """Decorator factory adding the ``-j/--num_processes`` option.

    Parameters
    ----------
    default : Optional[int]
        Default number of processes; falls back to ``mp.cpu_count()``
        when None. (Annotation fixed from the incorrect ``int = None``.)
    """

    def check_processes_option(ctx, param, value):
        # Reject values the machine cannot honor.
        max_processes = mp.cpu_count()
        if value < 1:
            raise click.BadParameter("Number of processes must be >= 1")
        if value > max_processes:
            raise click.BadParameter(
                f"Maximum number of processes is {max_processes}"
            )
        return value

    def decorator(f: Callable) -> Callable:
        return click.option(
            "--num_processes",
            "-j",
            default=default or mp.cpu_count(),
            type=int,
            help="Number of processes to run in parallel.",
            callback=check_processes_option,
        )(f)

    return decorator
109
+
110
+
111
def unique_id() -> Callable:
    """Decorator factory adding the optional ``-uid/--unique-id`` option."""

    def decorator(f: Callable) -> Callable:
        option = click.option(
            "--unique-id",
            "-uid",
            default="",
            required=False,
            type=str,
            help="Unique ID.",
        )
        return option(f)

    return decorator
@@ -0,0 +1,16 @@
1
+ import click
2
+ import yaml
3
+
4
+ from waveorder.cli.settings import MyBaseModel
5
+
6
+
7
def echo_settings(settings: MyBaseModel):
    """Pretty-print a settings model to stdout as YAML."""
    dumped = yaml.dump(
        settings.model_dump(), default_flow_style=False, sort_keys=False
    )
    click.echo(dumped)
13
+
14
+
15
def echo_headline(headline):
    """Print *headline* to stdout in green."""
    styled = click.style(headline, fg="green")
    click.echo(styled)
@@ -0,0 +1,67 @@
1
+ from pathlib import Path
2
+
3
+ import click
4
+
5
+ from waveorder.cli.apply_inverse_transfer_function import (
6
+ apply_inverse_transfer_function_cli,
7
+ )
8
+ from waveorder.cli.compute_transfer_function import (
9
+ compute_transfer_function_cli,
10
+ )
11
+ from waveorder.cli.parsing import (
12
+ config_filepath,
13
+ input_position_dirpaths,
14
+ output_dirpath,
15
+ processes_option,
16
+ unique_id,
17
+ )
18
+
19
+
20
@click.command("reconstruct")
@input_position_dirpaths()
@config_filepath()
@output_dirpath()
@processes_option(default=1)
@unique_id()
def _reconstruct_cli(
    input_position_dirpaths,
    config_filepath,
    output_dirpath,
    num_processes,
    unique_id,
):
    """
    Reconstruct a dataset using a configuration file. This is a
    convenience function for a `compute-tf` call followed by a `apply-inv-tf`
    call.

    Calculates the transfer function based on the shape of the first position
    in the list `input-position-dirpaths`, then applies that transfer function
    to all positions in the list `input-position-dirpaths`, so all positions
    must have the same TCZYX shape.

    See /examples for example configuration files.

    >> waveorder reconstruct -i ./input.zarr/*/*/* -c ./examples/birefringence.yml -o ./output.zarr
    """
    # NOTE(review): `unique_id` is accepted but unused in this body --
    # presumably kept for CLI parity with other commands; confirm.

    # Handle transfer function path
    # The transfer function is written next to the output store, named
    # after the config file so different configs do not clobber each other.
    transfer_function_path = output_dirpath.parent / Path(
        "transfer_function_" + config_filepath.stem + ".zarr"
    )

    # Compute transfer function
    # Only the first position is used here; all positions are assumed to
    # share its TCZYX shape (see docstring).
    compute_transfer_function_cli(
        input_position_dirpaths[0],
        config_filepath,
        transfer_function_path,
    )

    # Apply inverse transfer function
    apply_inverse_transfer_function_cli(
        input_position_dirpaths,
        transfer_function_path,
        config_filepath,
        output_dirpath,
        num_processes,
    )
@@ -0,0 +1,187 @@
1
+ import os
2
+ import warnings
3
+ from pathlib import Path
4
+ from typing import List, Literal, Optional, Union
5
+
6
+ from pydantic import (
7
+ BaseModel,
8
+ ConfigDict,
9
+ Extra,
10
+ NonNegativeFloat,
11
+ NonNegativeInt,
12
+ PositiveFloat,
13
+ field_validator,
14
+ model_validator,
15
+ )
16
+
17
+ # This file defines the configuration settings for the CLI.
18
+
19
+ # Example settings files in `/docs/examples/settings/` are automatically generated
20
+ # by the tests in `/tests/cli_tests/test_settings.py` - `test_generate_example_settings`.
21
+
22
+ # To keep the example settings up to date, run `pytest` locally when this file changes.
23
+
24
+
25
+ # All settings classes inherit from MyBaseModel, which forbids extra parameters to guard against typos
26
class MyBaseModel(BaseModel):
    """Base for all settings models; forbids unknown fields to catch typos."""

    # extra="forbid" makes pydantic raise on unrecognized config keys.
    model_config = ConfigDict(extra="forbid")
28
+
29
+
30
+ # Bottom level settings
31
class WavelengthIllumination(MyBaseModel):
    """Mixin carrying the illumination wavelength."""

    # Same length units as the pixel sizes (the unit-consistency
    # validators elsewhere compare lengths directly).
    wavelength_illumination: PositiveFloat = 0.532
33
+
34
+
35
class BirefringenceTransferFunctionSettings(MyBaseModel):
    """Transfer-function settings for birefringence reconstructions."""

    swing: float = 0.1

    @field_validator("swing")
    @classmethod
    def swing_range(cls, v):
        # Swing must lie strictly inside the open interval (0, 1).
        if not 0 < v < 1.0:
            raise ValueError(f"swing = {v} should be between 0 and 1.")
        return v
44
+
45
+
46
class BirefringenceApplyInverseSettings(WavelengthIllumination):
    """Inverse-reconstruction settings for birefringence.

    ``background_path`` may point to a directory holding background
    correction data; the empty-string default disables the correction.
    """

    background_path: Union[str, Path] = ""
    remove_estimated_background: bool = False
    flip_orientation: bool = False
    rotate_orientation: bool = False

    @field_validator("background_path")
    @classmethod
    def check_background_path(cls, v):
        # Empty string means "no background correction" -- skip the check.
        if v == "":
            return v

        # Normalize to str (accepts Path input); the original
        # r"{}".format(v) was an equivalent no-op formatting.
        raw_dir = str(v)
        if not os.path.isdir(raw_dir):
            # Fixed grammar in the user-facing message ("a existing").
            raise ValueError(f"{v} is not an existing directory")
        return raw_dir
63
+
64
+
65
class FourierTransferFunctionSettings(MyBaseModel):
    """Optical/sampling parameters shared by Fourier transfer functions.

    All lengths must use consistent units; a validator below warns when
    the yx and z pixel sizes look mismatched.
    """

    # presumably 6.5 (sensor pixel) / 20 (magnification) -- confirm
    yx_pixel_size: PositiveFloat = 6.5 / 20
    z_pixel_size: PositiveFloat = 2.0
    z_padding: NonNegativeInt = 0
    # NOTE(review): "auto" handling lives in the reconstruction code,
    # not visible here -- confirm its semantics there.
    z_focus_offset: Union[float, Literal["auto"]] = 0
    index_of_refraction_media: PositiveFloat = 1.3
    numerical_aperture_detection: PositiveFloat = 1.2

    @model_validator(mode="after")
    def validate_numerical_aperture_detection(self):
        # Detection NA cannot exceed the medium's refractive index.
        if self.numerical_aperture_detection > self.index_of_refraction_media:
            raise ValueError(
                f"numerical_aperture_detection = {self.numerical_aperture_detection} must be less than or equal to index_of_refraction_media = {self.index_of_refraction_media}"
            )
        return self

    @model_validator(mode="after")
    def warn_unit_consistency(self):
        # A yx/z pixel-size ratio beyond 20x either way suggests the two
        # values were given in different units.
        ratio = self.yx_pixel_size / self.z_pixel_size
        if ratio < 1.0 / 20 or ratio > 20:
            warnings.warn(
                f"yx_pixel_size ({self.yx_pixel_size}) / z_pixel_size ({self.z_pixel_size}) = {ratio}. Did you use consistent units?",
                UserWarning,
            )
        return self
90
+
91
+
92
class FourierApplyInverseSettings(MyBaseModel):
    """Regularization settings shared by Fourier-space inverse solvers."""

    reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov"
    regularization_strength: NonNegativeFloat = 1e-3
    # TV_* parameters presumably apply only when
    # reconstruction_algorithm == "TV" -- confirm in the apply code.
    TV_rho_strength: PositiveFloat = 1e-3
    TV_iterations: NonNegativeInt = 1
97
+
98
+
99
class PhaseTransferFunctionSettings(
    FourierTransferFunctionSettings,
    WavelengthIllumination,
):
    """Transfer-function settings specific to phase reconstructions."""

    numerical_aperture_illumination: NonNegativeFloat = 0.5
    invert_phase_contrast: bool = False

    @model_validator(mode="after")
    def validate_numerical_aperture_illumination(self):
        # Illumination NA, like detection NA, is bounded by the medium's
        # refractive index (inherited field).
        if (
            self.numerical_aperture_illumination
            > self.index_of_refraction_media
        ):
            raise ValueError(
                f"numerical_aperture_illumination = {self.numerical_aperture_illumination} must be less than or equal to index_of_refraction_media = {self.index_of_refraction_media}"
            )
        return self
116
+
117
+
118
class FluorescenceTransferFunctionSettings(FourierTransferFunctionSettings):
    """Transfer-function settings specific to fluorescence deconvolution."""

    wavelength_emission: PositiveFloat = 0.507
    confocal_pinhole_diameter: Optional[PositiveFloat] = None

    @model_validator(mode="after")
    def warn_unit_consistency(self):
        # A pixel-size/wavelength ratio far from 1 usually means the two
        # values were given in different units (e.g. um vs nm).
        ratio = self.yx_pixel_size / self.wavelength_emission
        if ratio < 1.0 / 20 or ratio > 20:
            # Fixed the message label: the value printed is
            # wavelength_emission, but it was labeled wavelength_illumination.
            warnings.warn(
                f"yx_pixel_size ({self.yx_pixel_size}) / wavelength_emission ({self.wavelength_emission}) = {ratio}. Did you use consistent units?",
                UserWarning,
            )
        return self
131
+
132
+
133
+ # Second level settings
134
class BirefringenceSettings(MyBaseModel):
    """Settings pair (transfer function + inverse) for birefringence."""

    transfer_function: BirefringenceTransferFunctionSettings = (
        BirefringenceTransferFunctionSettings()
    )
    apply_inverse: BirefringenceApplyInverseSettings = (
        BirefringenceApplyInverseSettings()
    )
141
+
142
+
143
class PhaseSettings(MyBaseModel):
    """Settings pair (transfer function + inverse) for phase."""

    transfer_function: PhaseTransferFunctionSettings = (
        PhaseTransferFunctionSettings()
    )
    apply_inverse: FourierApplyInverseSettings = FourierApplyInverseSettings()
148
+
149
+
150
class FluorescenceSettings(MyBaseModel):
    """Settings pair (transfer function + inverse) for fluorescence."""

    transfer_function: FluorescenceTransferFunctionSettings = (
        FluorescenceTransferFunctionSettings()
    )
    apply_inverse: FourierApplyInverseSettings = FourierApplyInverseSettings()
155
+
156
+
157
+ # Top level settings
158
class ReconstructionSettings(MyBaseModel):
    """Top-level CLI reconstruction configuration.

    Two mutually exclusive modes: label-free ("birefringence" and/or
    "phase") or "fluorescence" -- enforced by the validator below.
    """

    # Default matches a 4-state polarization acquisition (State0..State3).
    input_channel_names: List[str] = [f"State{i}" for i in range(4)]
    time_indices: Union[
        NonNegativeInt, List[NonNegativeInt], Literal["all"]
    ] = "all"
    reconstruction_dimension: Literal[2, 3] = 3
    birefringence: Optional[BirefringenceSettings] = None
    phase: Optional[PhaseSettings] = None
    fluorescence: Optional[FluorescenceSettings] = None

    @model_validator(mode="after")
    def validate_reconstruction_types(self):
        # Fluorescence is mutually exclusive with the label-free modes.
        if (
            self.birefringence or self.phase
        ) and self.fluorescence is not None:
            raise ValueError(
                '"fluorescence" cannot be present alongside "birefringence" or "phase". Please use one configuration file for a "fluorescence" reconstruction and another configuration file for a "birefringence" and/or "phase" reconstructions.'
            )
        num_channel_names = len(self.input_channel_names)
        if self.birefringence is None:
            # At least one reconstruction type must be requested.
            if self.phase is None and self.fluorescence is None:
                raise ValueError(
                    "Provide settings for either birefringence, phase, birefringence + phase, or fluorescence."
                )
            # Without birefringence, only a single input channel is read.
            if num_channel_names != 1:
                raise ValueError(
                    f"{num_channel_names} channels names provided. Please provide a single channel for fluorescence/phase reconstructions."
                )

        return self
waveorder/cli/utils.py ADDED
@@ -0,0 +1,175 @@
1
+ from pathlib import Path
2
+ from typing import Tuple
3
+
4
+ import click
5
+ import numpy as np
6
+ import torch
7
+ from iohub.ngff import Position, open_ome_zarr
8
+ from iohub.ngff.models import TransformationMeta
9
+ from numpy.typing import DTypeLike
10
+
11
+
12
def generate_valid_position_key(index: int) -> tuple[str, str, str]:
    """Generate a valid HCS position key for single-position stores.

    Args:
        index: Position index (0-based)

    Returns:
        Tuple of (row, column, field) with alphanumeric characters only
    """
    # Rows advance one letter every 10 positions; columns count 1..10.
    row_letter = chr(ord("A") + index // 10)
    column_number = index % 10 + 1
    # Field is always "0" for single positions.
    return (row_letter, str(column_number), "0")
25
+
26
+
27
def is_single_position_store(position_path: Path) -> bool:
    """Check if a position path is from a single-position store (not HCS plate).

    Args:
        position_path: Path to the position directory

    Returns:
        True if it's a single-position store, False if it's part of an HCS plate
    """
    try:
        # Try to open as HCS plate 3 levels up (plate/row/column/fov layout)
        open_ome_zarr(position_path.parent.parent.parent, mode="r")
        return False  # Successfully opened as plate
    except (RuntimeError, FileNotFoundError):
        # NOTE(review): assumes iohub raises RuntimeError for non-plate
        # stores -- re-verify on iohub upgrades.
        return True  # Not a plate structure
42
+
43
+
44
def create_empty_hcs_zarr(
    store_path: Path,
    position_keys: list[Tuple[str]],
    shape: Tuple[int],
    chunks: Tuple[int],
    scale: Tuple[float],
    channel_names: list[str],
    dtype: DTypeLike,
    plate_metadata: dict = None,
) -> None:
    """If the plate does not exist, create an empty zarr plate.

    If the plate exists, append positions and channels if they are not
    already in the plate.

    Parameters
    ----------
    store_path : Path
        hcs plate path
    position_keys : list[Tuple[str]]
        Position keys, will append if not present in the plate.
        e.g. [("A", "1", "0"), ("A", "1", "1")]
    shape : Tuple[int]
    chunks : Tuple[int]
    scale : Tuple[float]
    channel_names : list[str]
        Channel names, will append if not present in metadata.
    dtype : DTypeLike
    plate_metadata : dict, optional
        Extra attributes merged into the plate's .zattrs; defaults to an
        empty dict. (None sentinel replaces the original mutable default
        argument, which was shared across calls.)
    """
    if plate_metadata is None:
        plate_metadata = {}

    # Create plate (mode="a" appends when the store already exists)
    output_plate = open_ome_zarr(
        str(store_path), layout="hcs", mode="a", channel_names=channel_names
    )

    # Pass metadata
    output_plate.zattrs.update(plate_metadata)

    # Create positions
    for position_key in position_keys:
        position_key_string = "/".join(position_key)
        # Check if position is already in the store, if not create it
        if position_key_string not in output_plate.zgroup:
            position = output_plate.create_position(*position_key)

            _ = position.create_zeros(
                name="0",
                shape=shape,
                chunks=chunks,
                dtype=dtype,
                transform=[TransformationMeta(type="scale", scale=scale)],
            )
        else:
            position = output_plate[position_key_string]

        # Check if channel_names are already in the store, if not append them
        for channel_name in channel_names:
            # Read channel names directly from metadata to avoid race conditions
            metadata_channel_names = [
                channel.label for channel in position.metadata.omero.channels
            ]
            if channel_name not in metadata_channel_names:
                position.append_channel(channel_name, resize_arrays=True)
108
+
109
+
110
def apply_inverse_to_zyx_and_save(
    func,
    position: Position,
    output_path: Path,
    input_channel_indices: list[int],
    output_channel_indices: list[int],
    t_idx: int = 0,
    **kwargs,
) -> None:
    """Load a zyx array from a Position object, apply a transformation and save the result to file.

    Parameters
    ----------
    func : callable
        Reconstruction function applied to the loaded CZYX tensor; extra
        ``kwargs`` are forwarded to it.
    position : Position
        Input iohub position read at ``t_idx``.
    output_path : Path
        Existing output zarr store written at the same time index.
    input_channel_indices : list[int]
        Channel indices read from ``position``.
    output_channel_indices : list[int]
        Channel indices written in the output store.
    t_idx : int
        Time index to process.
    """
    click.echo(f"Reconstructing t={t_idx}")

    # Load data
    czyx_uint16_numpy = position.data.oindex[t_idx, input_channel_indices]

    # Check if all values in czyx_uint16_numpy are not zeros or Nan;
    # empty timepoints are skipped instead of reconstructed.
    if _check_nan_n_zeros(czyx_uint16_numpy):
        click.echo(
            f"All values at t={t_idx} are zero or Nan, skipping reconstruction."
        )
        return

    # convert to np.int32 (torch doesn't accept np.uint16), then convert to tensor float32
    czyx_data = torch.tensor(np.int32(czyx_uint16_numpy), dtype=torch.float32)

    # Apply transformation
    reconstruction_czyx = func(czyx_data, **kwargs)

    # Write to file
    with open_ome_zarr(output_path, mode="r+") as output_dataset:
        output_dataset[0].oindex[
            t_idx, output_channel_indices
        ] = reconstruction_czyx
    click.echo(f"Finished Writing.. t={t_idx}")
145
+
146
+
147
def estimate_resources(shape, settings, num_processes):
    """Estimate (num_cpus, GB of RAM per CPU) for a reconstruction job."""
    T, C, Z, Y, X = shape

    # bytes_per_float32 / bytes_per_gb
    gb_per_element = 4 / 2**30
    voxel_resource_multiplier = 4
    fourier_resource_multiplier = 32
    input_memory = Z * Y * X * gb_per_element

    # Accumulate the footprint of each requested reconstruction type.
    gb_ram_per_cpu = 0
    if settings.birefringence is not None:
        gb_ram_per_cpu += voxel_resource_multiplier * input_memory
    if settings.phase is not None:
        gb_ram_per_cpu += fourier_resource_multiplier * input_memory
    if settings.fluorescence is not None:
        gb_ram_per_cpu += fourier_resource_multiplier * input_memory

    ram_multiplier = 1
    # At least 1 GB per CPU, rounded up to an integer.
    gb_ram_per_cpu = np.ceil(
        np.max([1, ram_multiplier * gb_ram_per_cpu])
    ).astype(int)
    # Cap parallelism at 32 CPUs.
    num_cpus = np.min([32, num_processes])

    return num_cpus, gb_ram_per_cpu
169
+
170
+
171
+ def _check_nan_n_zeros(input_array):
172
+ """
173
+ Checks if data are all zeros or nan
174
+ """
175
+ return np.all(np.isnan(input_array)) or np.all(input_array == 0)
waveorder/filter.py CHANGED
@@ -74,7 +74,6 @@ def apply_filter_bank(
74
74
  )
75
75
 
76
76
  num_input_channels, num_output_channels = io_filter_bank.shape[:2]
77
- spatial_dims = io_filter_bank.shape[2:]
78
77
 
79
78
  # Pad input_array until each dimension is divisible by transfer_function
80
79
  pad_sizes = [
@@ -98,7 +97,7 @@ def apply_filter_bank(
98
97
  # Further optimization is likely with a combination of
99
98
  # torch.baddbmm, torch.pixel_shuffle, torch.pixel_unshuffle.
100
99
  padded_output_spectrum = torch.zeros(
101
- (num_output_channels,) + spatial_dims,
100
+ (num_output_channels,) + padded_input_spectrum.shape[1:],
102
101
  dtype=padded_input_spectrum.dtype,
103
102
  device=padded_input_spectrum.device,
104
103
  )