octopi 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- octopi/__init__.py +7 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +83 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +458 -0
- octopi/datasets/io.py +200 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +252 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +119 -0
- octopi/entry_points/create_slurm_submission.py +251 -0
- octopi/entry_points/groups.py +152 -0
- octopi/entry_points/run_create_targets.py +234 -0
- octopi/entry_points/run_evaluate.py +99 -0
- octopi/entry_points/run_extract_mb_picks.py +191 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +176 -0
- octopi/entry_points/run_optuna.py +161 -0
- octopi/entry_points/run_segment.py +154 -0
- octopi/entry_points/run_train.py +189 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +217 -0
- octopi/extract/membranebound_extract.py +263 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/main.py +33 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +72 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +224 -0
- octopi/processing/downloader.py +138 -0
- octopi/processing/downsample.py +125 -0
- octopi/processing/evaluate.py +302 -0
- octopi/processing/importers.py +116 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +244 -0
- octopi/pytorch/model_search_submitter.py +291 -0
- octopi/pytorch/segmentation.py +363 -0
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +465 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +215 -0
- octopi/utils/losses.py +86 -0
- octopi/utils/parsers.py +162 -0
- octopi/utils/progress.py +78 -0
- octopi/utils/stopping_criteria.py +143 -0
- octopi/utils/submit_slurm.py +95 -0
- octopi/utils/visualization_tools.py +290 -0
- octopi/workflows.py +262 -0
- octopi-1.4.0.dist-info/METADATA +119 -0
- octopi-1.4.0.dist-info/RECORD +65 -0
- octopi-1.4.0.dist-info/WHEEL +4 -0
- octopi-1.4.0.dist-info/entry_points.txt +3 -0
- octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
from octopi.processing.segmentation_from_picks import from_picks
|
|
2
|
+
from copick_utils.io import readers, writers
|
|
3
|
+
import zarr, os, yaml, copick
|
|
4
|
+
from octopi.utils import io
|
|
5
|
+
from typing import List
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
def print_target_summary(train_targets: dict, target_segmentation_name: str, maxval: int):
    """
    Print a human-readable overview of the label layout of a target volume.

    Args:
        train_targets (dict): Mapping of object name to its target info. Each
            entry is read for the keys 'label', 'is_particle_target' and 'radius'.
        target_segmentation_name (str): Name of the segmentation, echoed in the header.
        maxval (int): Largest label value; ``maxval + 1`` is printed as the
            suggested ``--num-classes`` value.
    """
    rule = "=" * 60

    # Header
    print("\n" + rule)
    print("TARGET VOLUME SUMMARY")
    print(rule)
    print(f"Segmentation name: {target_segmentation_name}")
    n_classes = len(train_targets) + 1  # +1 accounts for the background class
    print(f"Total classes: {n_classes} (including background)")
    print("\nLabel Index → Object Name (Type):")
    print(f" {0:3d} → background")

    def _by_label(entry):
        # entry is a (name, info) pair; order rows by the numeric label
        return entry[1]['label']

    # One row per object, in ascending label order
    for obj_name, meta in sorted(train_targets.items(), key=_by_label):
        if meta['is_particle_target']:
            kind = "particle"
        else:
            kind = "segmentation"

        # The radius is only shown when it is set to a truthy value
        if meta['radius']:
            radius_txt = f", radius={meta['radius']:.1f}Å"
        else:
            radius_txt = ""

        print(f" {meta['label']:3d} → {obj_name} ({kind}{radius_txt})")

    # Footer with the training hint
    print(rule)
    print(f"💡 Use --num-classes {maxval + 1} when training with this target")
    print(rule + "\n")
|
|
31
|
+
|
|
32
|
+
def generate_targets(
    config,
    train_targets: dict,
    voxel_size: float = 10,
    tomo_algorithm: str = 'wbp',
    radius_scale: float = 0.8,
    target_segmentation_name: str = 'targets',
    target_user_name: str = 'octopi',
    target_session_id: str = '1',
    run_ids: List[str] = None,
):
    """
    Generate segmentation targets from picks and segmentations in a CoPick project.

    For every run, an empty uint8 volume with the shape of the run's tomogram is
    created. Existing segmentations listed in `train_targets` are merged into it,
    picks are painted into it via `from_picks`, and non-empty volumes are written
    back to the run as a new segmentation. The parameters are then saved through
    `save_parameters` and a summary is printed.

    Args:
        config: CoPick configuration passed to `copick.from_file`.
        train_targets (dict): Mapping of object name to target info. Each entry is
            read for 'is_particle_target', 'user_id', 'session_id' and 'label';
            particle targets are additionally read for 'radius'.
        voxel_size (float): Voxel spacing used to look up tomograms and
            segmentations and passed on to `from_picks` and the writer.
        tomo_algorithm (str): Tomogram type requested from the voxel spacing; only
            its shape is used.
        radius_scale (float): Factor multiplied with each particle radius before
            painting the picks.
        target_segmentation_name (str): Name of the written segmentation.
        target_user_name (str): User name the written segmentation is stored under.
        target_session_id (str): Session ID of the written segmentation
            ('1' is used when None is passed).
        run_ids (List[str], optional): Runs to process. When None, every run that
            has the requested voxel spacing is used.
    """

    root = copick.from_file(config)

    # Default session ID to 1 if not provided
    if target_session_id is None:
        target_session_id = '1'

    # Announce the objects that will be turned into targets
    print('🔄 Creating Targets for the following objects:', ', '.join(train_targets.keys()))

    # Get Target Names
    target_names = list(train_targets.keys())

    # If runIDs are not provided, load all runs that have the requested voxel spacing.
    # A single pass over the runs builds both the usable and the skipped list.
    if run_ids is None:
        run_ids, skipped_run_ids = [], []
        for candidate in root.runs:
            if candidate.get_voxel_spacing(voxel_size) is not None:
                run_ids.append(candidate.name)
            else:
                skipped_run_ids.append(candidate.name)

        if skipped_run_ids:
            print(f"⚠️ Warning: skipping runs with no voxel spacing {voxel_size}: {skipped_run_ids}")

    # Iterate Over All Runs
    maxval = -1
    for runID in tqdm(run_ids):

        # Get Run
        numPicks = 0
        run = root.get_run(runID)

        # Get Target Shape
        vs = run.get_voxel_spacing(voxel_size)
        if vs is None:
            print(f"⚠️ Warning: skipping run {runID} with no voxel spacing {voxel_size}")
            continue
        tomo = vs.get_tomogram(tomo_algorithm)
        if tomo is None:
            print(f"⚠️ Warning: skipping run {runID} with no tomogram {tomo_algorithm}")
            continue

        # Initialize Target Volume with the shape of the array stored under key '0'
        loc = tomo.zarr()
        shape = zarr.open(loc)['0'].shape
        target = np.zeros(shape, dtype=np.uint8)

        # Generate Targets
        # Applicable segmentations
        query_seg = []
        for target_name in target_names:
            if not train_targets[target_name]["is_particle_target"]:
                query_seg += run.get_segmentations(
                    name=target_name,
                    user_id=train_targets[target_name]["user_id"],
                    session_id=train_targets[target_name]["session_id"],
                    voxel_size=voxel_size
                )

        # Add Segmentations to Target
        for seg in query_seg:
            classLabel = train_targets[seg.name]['label']
            segvol = seg.numpy()
            # Set all non-zero values to the class label
            segvol[segvol > 0] = classLabel
            # Where volumes overlap, the larger label wins
            target = np.maximum(target, segvol)

        # Applicable picks
        query = []
        for target_name in target_names:
            if train_targets[target_name]["is_particle_target"]:
                query += run.get_picks(
                    object_name=target_name,
                    user_id=train_targets[target_name]["user_id"],
                    session_id=train_targets[target_name]["session_id"],
                )

        # Filter out empty picks
        query = [pick for pick in query if pick.points is not None]

        # Add Picks to Target
        for pick in query:
            numPicks += len(pick.points)
            target = from_picks(pick,
                                target,
                                train_targets[pick.pickable_object_name]['radius'] * radius_scale,
                                train_targets[pick.pickable_object_name]['label'],
                                voxel_size
                                )

        # Convert to a Python int once: a NumPy uint8 scalar could wrap around in
        # the later `maxval + 1`, and this avoids repeated full-volume reductions.
        run_max = int(target.max())

        # Write Segmentation for non-empty targets
        if run_max > 0:
            tqdm.write(f'📝 Annotating {numPicks} picks in {runID}...')
            writers.segmentation(run, target, target_user_name,
                                 name=target_segmentation_name, session_id=target_session_id,
                                 voxel_size=voxel_size)
        if run_max > maxval:
            maxval = run_max

    print('✅ Creation of targets complete!')

    # Save Parameters
    overlay_root = io.remove_prefix(root.config.overlay_root)
    basepath = os.path.join(overlay_root, 'logs')
    os.makedirs(basepath, exist_ok=True)
    labels = {name: info['label'] for name, info in train_targets.items()}
    args = {
        "config": config,
        "train_targets": train_targets,
        "radius_scale": radius_scale,
        "tomo_algorithm": tomo_algorithm,
        "target_name": target_segmentation_name,
        "target_user_name": target_user_name,
        "target_session_id": target_session_id,
        "voxel_size": voxel_size,
        "labels": labels,
    }
    target_query = f'{target_user_name}_{target_session_id}_{target_segmentation_name}'
    print(f'💾 Saving parameters to {basepath}/targets-{target_query}.yaml')
    save_parameters(args, basepath, target_query)

    # Print Target Summary
    print_target_summary(train_targets, target_segmentation_name, maxval)
|
|
176
|
+
|
|
177
|
+
def save_parameters(args, basepath: str, target_query: str):
    """
    Save the target-generation parameters to a YAML file with 'input' and
    'parameters' subgroups.

    The file is written to
    ``<basepath>/targets-<target_user_name>_<target_session_id>_<target_name>.yaml``.
    If a file already exists at that path it is read first as a sanity check:
    content that parses to something other than a mapping is rejected, while an
    empty or malformed file is tolerated. The previous content is not merged by
    this function; only the new entry is handed to `io.save_parameters_yaml`.

    Args:
        args: Dictionary read for the keys 'config', 'train_targets',
            'radius_scale', 'tomo_algorithm', 'voxel_size', 'target_user_name',
            'target_session_id' and 'target_name'. Every entry of
            'train_targets' is read for 'label', 'user_id' and 'session_id'.
        basepath: Directory in which the YAML file is saved.
        target_query: Query string for target identification. Currently unused;
            the file name is rebuilt from `args`.

    Raises:
        ValueError: If the existing file holds valid YAML that is not a dictionary.
    """
    # Prepare input group
    provenance_keys = ('user_id', 'session_id')
    input_group = {
        "config": args['config'],
        "labels": {name: info['label'] for name, info in args['train_targets'].items()},
        "targets": {
            name: {key: info[key] for key in provenance_keys}
            for name, info in args['train_targets'].items()
        },
    }

    # Organize parameters into subgroups
    new_entry = {
        "input": input_group,
        "parameters": {
            "radius_scale": args["radius_scale"],
            "tomogram_algorithm": args["tomo_algorithm"],
            "voxel_size": args["voxel_size"],
        },
    }

    output_path = os.path.join(
        basepath,
        f'targets-{args["target_user_name"]}_{args["target_session_id"]}_{args["target_name"]}.yaml')

    # Validate a pre-existing file before it is handed to the writer
    if os.path.exists(output_path):
        with open(output_path, 'r') as f:
            try:
                existing_data = yaml.safe_load(f)
            except yaml.YAMLError:
                existing_data = None  # Treat as empty if the file is malformed
        if existing_data is not None and not isinstance(existing_data, dict):
            raise ValueError("Existing YAML data is not a dictionary. Cannot update.")

    # Save to the YAML file
    io.save_parameters_yaml(new_entry, output_path)
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import rich_click as click
|
|
2
|
+
|
|
3
|
+
def from_dataportal(
    config,
    datasetID,
    overlay_path,
    source_type,
    target_type,
    input_voxel_size = 10,
    output_voxel_size = None):
    """
    Download and process tomograms from the CZI Dataportal.

    For every run of the project, the first tomogram of type `source_type` at
    `input_voxel_size` is loaded, optionally Fourier-rescaled, and written back
    to the run as `target_type`. Runs without the requested voxel spacing or
    without a matching tomogram are reported and skipped.

    Args:
        config (str): Path to the copick configuration file. When given, it takes
            precedence over `datasetID` / `overlay_path`.
        datasetID (int): ID of the dataset to download (used when `config` is None).
        overlay_path (str): Overlay root of the project created from `datasetID`.
        source_type (str): Name of the tomogram type in the dataportal
        target_type (str): Name to use for the tomogram locally
        input_voxel_size (float): Original voxel size of the tomograms
        output_voxel_size (float, optional): Desired voxel size for downsampling.
            Only applied when it is larger than `input_voxel_size`; otherwise the
            tomograms are saved at `input_voxel_size`.

    Raises:
        ValueError: If neither `config` nor both `datasetID` and `overlay_path`
            are provided.
    """

    from octopi.processing.downsample import FourierRescale
    from octopi.utils.progress import _progress, print_summary
    from copick_utils.io import writers
    import copick

    # Either load an existing configuration file or create one from datasetID and overlay_path
    if config is not None:
        root = copick.from_file(config)
    elif datasetID is not None and overlay_path is not None:
        root = copick.from_czcdp_datasets(
            [datasetID], overlay_root=overlay_path,
            output_path='config.json', overlay_fs_args={'auto_mkdir': True}
        )
    else:
        raise ValueError('Either config or datasetID and overlay_path must be provided')

    # If we want to save the tomograms at a different voxel size, we need to rescale the tomograms
    rescale = None
    if output_voxel_size is not None and output_voxel_size > input_voxel_size:
        rescale = FourierRescale(input_voxel_size, output_voxel_size)
    else:
        # No downsampling requested (or not a coarser spacing): keep the input spacing
        output_voxel_size = None

    # Print Parameter Summary
    print_summary(
        "Download Tomograms",
        datasetID=datasetID, overlay_path=overlay_path,
        config=config, source_type=source_type, target_type=target_type,
        input_voxel_size=input_voxel_size, output_voxel_size=output_voxel_size,
    )

    # Main Loop
    n_downloaded = 0  # runs for which a tomogram was actually written
    for run in _progress(root.runs):

        # Check if voxel spacing is available
        vs = run.get_voxel_spacing(input_voxel_size)
        if vs is None:
            print(f'No Voxel-Spacing Available for RunID: {run.name}, Voxel-Size: {input_voxel_size}')
            continue

        # Check if base reconstruction method is available
        # (an empty result is reported the same way as a missing one)
        avail_tomos = vs.get_tomograms(source_type)
        if not avail_tomos:
            print(f'No Tomograms Available for RunID: {run.name}, Voxel-Size: {input_voxel_size}, Tomo-Type: {source_type}')
            continue

        # Download the tomogram
        vol = avail_tomos[0].numpy()

        # Save at the input spacing, or rescale first when downsampling was requested
        if rescale is None:
            writers.tomogram(run, vol, input_voxel_size, target_type)
        else:
            vol = rescale.run(vol)
            writers.tomogram(run, vol, output_voxel_size, target_type)
        n_downloaded += 1

    print(f'✅ Download Complete!\nDownloaded {n_downloaded} runs')
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@click.command('download')
# Voxel Settings
@click.option('-ovs', '--output-voxel-size', type=float, default=None,
              help="Desired output voxel size for downsampling (optional)")
@click.option('-ivs', '--input-voxel-size', type=float, default=10,
              help="Original voxel size of the tomograms")
# Tomogram Settings
@click.option('-t', '--target-type', type=str, default='denoised',
              help="Local tomogram type name to save in your Copick project.")
@click.option('-s', '--source-type', type=str, default='wbp-denoised-ctfdeconv',
              help="Name of the tomogram type as labeled on the CryoET Data Portal")
# Input Arguments
@click.option('-o', '--overlay', type=click.Path(), default=None,
              help="Path to the overlay directory (required with datasetID)")
@click.option('-ds', '--datasetID', type=int, default=None,
              help="Dataset ID from CZI Dataportal (alternative to config)")
@click.option('-c', '--config', type=click.Path(exists=True), default=None,
              help="Path to the copick configuration file (alternative to datasetID)")
def cli(config, datasetid, overlay, source_type, target_type,
        input_voxel_size, output_voxel_size):
    """
    Download and (optionally) downsample tomograms from the CryoET-DataPortal.

    This command fetches reconstructed tomograms from publicly available datasets and saves them
    to your local copick project. Downsampling is performed via Fourier cropping to preserve
    high-frequency information while reducing file size.

    Two modes of operation:

    \b
    1. Download tomograms and downsample to a new voxel size:
       octopi download -c config.json --input-voxel-size 10 --output-voxel-size 20

    \b
    2. Create new project from the CryoET Data Portal:
       octopi download --datasetID 10301 --overlay ./my_project --input-voxel-size 10

    The downloaded tomograms will be stored in your copick project structure with the specified
    voxel spacing and tomogram type.
    """
    # NOTE: the docstring above is the command's help text; the `datasetid`
    # parameter is the lower-cased name of the `--datasetID` option.

    from_dataportal(
        config=config,
        datasetID=datasetid,
        overlay_path=overlay,
        source_type=source_type,
        target_type=target_type,
        input_voxel_size=input_voxel_size,
        output_voxel_size=output_voxel_size
    )


if __name__ == "__main__":
    cli()
|
|
138
|
+
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
class FourierRescale:
    """Rescale 3D volumes to a coarser voxel size by cropping their centered Fourier spectrum."""

    def __init__(self, input_voxel_size, output_voxel_size):
        """
        Store the voxel sizes and select the compute device.

        Parameters:
            input_voxel_size (int, float or tuple): Physical spacing of the input voxels (d, h, w),
                or a single number that is applied to all three dimensions.
            output_voxel_size (int, float or tuple): Desired physical spacing of the output voxels
                (d, h, w), or a single number that is applied to all three dimensions.
                Must be greater than or equal to input_voxel_size in every dimension.

        Raises:
            ValueError: If any output voxel size is smaller than the matching input voxel size.
        """
        # Expand scalars to one value per spatial dimension.
        if isinstance(input_voxel_size, (int, float)):
            input_voxel_size = (input_voxel_size,) * 3
        if isinstance(output_voxel_size, (int, float)):
            output_voxel_size = (output_voxel_size,) * 3

        self.input_voxel_size = input_voxel_size
        self.output_voxel_size = output_voxel_size

        # Only downsampling (or identity) is supported: reject any finer output spacing.
        for in_vs, out_vs in zip(input_voxel_size, output_voxel_size):
            if out_vs < in_vs:
                raise ValueError("Output voxel size must be greater than or equal to the input voxel size.")

        # Use the GPU if available, otherwise the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def run(self, volume):
        """
        Rescale a 3D volume (or, on GPU, a 4D batch of volumes) using Fourier cropping.

        A numpy array input yields a numpy array; any other input is treated as a
        torch tensor and a (CPU) torch tensor is returned.

        Raises:
            AssertionError: If a 4D batch is passed while the device is the CPU.
        """
        # Remember the input flavour so the result can be handed back in the same one.
        came_from_numpy = isinstance(volume, np.ndarray)
        if came_from_numpy:
            volume = torch.from_numpy(volume)

        is_batch = volume.dim() == 4

        # Batches are only accepted when computing on the GPU.
        if is_batch and self.device.type == 'cpu':
            raise AssertionError("Batched volumes are not allowed on CPU. Please provide a single volume.")

        if is_batch:
            output = self.batched_rescale(volume)
        else:
            output = self.single_rescale(volume)

        # Bring the result back to host memory when the work ran on the GPU.
        if self.device == torch.device('cuda'):
            output = output.cpu()
            torch.cuda.empty_cache()

        return output.numpy() if came_from_numpy else output

    def batched_rescale(self, volume: torch.Tensor):
        """
        Move the volume(s) to the device, FFT over the last three axes, crop the
        centered spectrum, and return the real part of the inverse FFT.

        An input that is not 4D temporarily gets a leading batch axis, which is
        removed again before returning.
        """
        spatial_axes = (-3, -2, -1)

        volume = volume.to(self.device)
        added_batch_axis = volume.dim() != 4
        if added_batch_axis:
            volume = volume.unsqueeze(0)

        # Forward transform, with the zero-frequency term shifted to the center.
        spectrum = torch.fft.fftshift(
            torch.fft.fftn(volume, dim=spatial_axes, norm='ortho'),
            dim=spatial_axes,
        )

        d0, h0, w0, n_d, n_h, n_w = self.calculate_cropping(volume)
        spectrum = spectrum[..., d0:d0 + n_d, h0:h0 + n_h, w0:w0 + n_w]

        # Undo the shift and transform back; only the real part is kept.
        spectrum = torch.fft.ifftshift(spectrum, dim=spatial_axes)
        rescaled = torch.fft.ifftn(spectrum, dim=spatial_axes, norm='ortho').real

        if added_batch_axis:
            rescaled = rescaled.squeeze(0)

        return rescaled

    def single_rescale(self, volume: torch.Tensor) -> torch.Tensor:
        """Rescale a single volume; delegates to `batched_rescale`."""
        return self.batched_rescale(volume)

    def calculate_cropping(self, volume: torch.Tensor):
        """
        Calculate the crop window in the shifted spectrum from the voxel sizes.

        Returns:
            tuple: (start_d, start_h, start_w, new_depth, new_height, new_width)
        """
        in_depth, in_height, in_width = volume.shape[-3:]

        starts = []
        sizes = []
        for axis, n_in in enumerate((in_depth, in_height, in_width)):
            # Physical extent along this axis, expressed in output voxels.
            extent = n_in * self.input_voxel_size[axis]
            n_out = int(round(extent / self.output_voxel_size[axis]))

            # Ensure the new dimension is even.
            n_out -= n_out % 2

            # Center the crop window on the input axis.
            starts.append((n_in - n_out) // 2)
            sizes.append(n_out)

        return (*starts, *sizes)
|