octopi 1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of octopi might be problematic. Click here for more details.
- octopi/__init__.py +0 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +84 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +429 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +253 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +80 -0
- octopi/entry_points/create_slurm_submission.py +243 -0
- octopi/entry_points/run_create_targets.py +281 -0
- octopi/entry_points/run_evaluate.py +65 -0
- octopi/entry_points/run_extract_mb_picks.py +141 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +222 -0
- octopi/entry_points/run_optuna.py +139 -0
- octopi/entry_points/run_segment_predict.py +166 -0
- octopi/entry_points/run_train.py +201 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +254 -0
- octopi/extract/membranebound_extract.py +262 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/io.py +457 -0
- octopi/losses.py +86 -0
- octopi/main.py +101 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +62 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +106 -0
- octopi/processing/downsample.py +129 -0
- octopi/processing/evaluate.py +289 -0
- octopi/processing/importers.py +213 -0
- octopi/processing/my_metrics.py +26 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/processing/writers.py +102 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +243 -0
- octopi/pytorch/model_search_submitter.py +290 -0
- octopi/pytorch/segmentation.py +317 -0
- octopi/pytorch/trainer.py +438 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/stopping_criteria.py +143 -0
- octopi/submit_slurm.py +95 -0
- octopi/utils.py +238 -0
- octopi/visualization_tools.py +201 -0
- octopi-1.0.dist-info/LICENSE +41 -0
- octopi-1.0.dist-info/METADATA +209 -0
- octopi-1.0.dist-info/RECORD +59 -0
- octopi-1.0.dist-info/WHEEL +4 -0
- octopi-1.0.dist-info/entry_points.txt +4 -0
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
from octopi.datasets import generators, multi_config_generator
|
|
2
|
+
from monai.losses import DiceLoss, FocalLoss, TverskyLoss
|
|
3
|
+
from octopi.models import common as builder
|
|
4
|
+
from monai.metrics import ConfusionMatrixMetric
|
|
5
|
+
from octopi.entry_points import common
|
|
6
|
+
from octopi.pytorch import trainer
|
|
7
|
+
from octopi import io, utils
|
|
8
|
+
import torch, os, argparse
|
|
9
|
+
from typing import List, Optional, Tuple
|
|
10
|
+
import pprint
|
|
11
|
+
|
|
12
|
+
def train_model(
    copick_config_path: str,
    target_info: Tuple[str, str, str],
    tomo_algorithm: str = 'wbp',
    voxel_size: float = 10,
    trainRunIDs: Optional[List[str]] = None,
    validateRunIDs: Optional[List[str]] = None,
    model_config: Optional[dict] = None,
    model_weights: Optional[str] = None,
    model_save_path: str = 'results',
    num_tomo_crops: int = 16,
    tomo_batch_size: int = 15,
    lr: float = 1e-3,
    tversky_alpha: float = 0.5,
    num_epochs: int = 100,
    val_interval: int = 5,
    best_metric: str = 'avg_f1',
    data_split: str = '0.8'
    ):
    """
    Train a segmentation model and save its configuration and training results.

    Args:
        copick_config_path: Path to a CoPick config, or a dict of configs for
            multi-config training (any dict selects the multi-config loader).
        target_info: (name, user_id, session_id) of the target segmentation.
        tomo_algorithm: Tomogram reconstruction algorithm to load.
        voxel_size: Voxel size of the tomograms to load.
        trainRunIDs / validateRunIDs: Optional explicit run ID lists.
        model_config: Model configuration dict. Keys read here:
            'architecture', 'num_classes', 'dim_in'; the whole dict is also
            passed to the architecture's builder.
        model_weights: Optional state-dict file to initialise the model from.
        model_save_path: Directory for 'model_config.yaml' and 'results.json'.
        num_tomo_crops: Number of crops sampled per tomogram (my_num_samples).
        tomo_batch_size: Tomogram batch size for the data generator.
        lr: AdamW learning rate (weight decay fixed at 1e-4).
        tversky_alpha: Tversky alpha; beta is set to 1 - alpha.
        num_epochs: Maximum number of epochs.
        val_interval: Epochs between validation passes.
        best_metric: Metric name used by the trainer to pick the best model.
        data_split: Split specification parsed by utils.parse_data_split.

    Raises:
        ValueError: if model_config is not provided.
    """
    # Fail fast with a clear message instead of a TypeError on model_config[...].
    if model_config is None:
        raise ValueError("model_config is required (a dict with at least "
                         "'architecture', 'num_classes' and 'dim_in').")

    print(f'Training with {copick_config_path}\n')

    # A dict of configs selects multi-config training; otherwise a single config.
    if isinstance(copick_config_path, dict):
        loader_cls = multi_config_generator.MultiConfigTrainLoaderManager
    else:
        loader_cls = generators.TrainLoaderManager

    # Both managers take the same constructor arguments.
    data_generator = loader_cls(
        copick_config_path,
        target_info[0],
        target_session_id=target_info[2],
        target_user_id=target_info[1],
        tomo_algorithm=tomo_algorithm,
        voxel_size=voxel_size,
        Nclasses=model_config['num_classes'],
        tomo_batch_size=tomo_batch_size)

    # Get the data splits (no test set is created for training)
    ratios = utils.parse_data_split(data_split)
    data_generator.get_data_splits(
        trainRunIDs=trainRunIDs,
        validateRunIDs=validateRunIDs,
        train_ratio=ratios[0], val_ratio=ratios[1], test_ratio=ratios[2],
        create_test_dataset=False)

    # Get the reload frequency
    data_generator.get_reload_frequency(num_epochs)

    # MONAI loss/metrics: Tversky with beta = 1 - alpha
    alpha = tversky_alpha
    beta = 1 - alpha
    loss_function = TverskyLoss(include_background=True, to_onehot_y=True, softmax=True, alpha=alpha, beta=beta)
    metrics_function = ConfusionMatrixMetric(include_background=False, metric_name=["recall", 'precision', 'f1 score'], reduction="none")

    # Build the Model
    model_builder = builder.get_model(model_config['architecture'])
    model = model_builder.build_model(model_config)

    # Load the Model Weights if Provided
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if model_weights:
        state_dict = torch.load(model_weights, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    model.to(device)

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr, weight_decay=1e-4)

    # Create the trainer and run training
    model_trainer = trainer.ModelTrainer(model, device, loss_function, metrics_function, optimizer)
    results = model_trainer.train(
        data_generator, model_save_path, max_epochs=num_epochs,
        crop_size=model_config['dim_in'], my_num_samples=num_tomo_crops,
        val_interval=val_interval, best_metric=best_metric, verbose=True
    )

    # Save parameters and results
    parameters_save_name = os.path.join(model_save_path, "model_config.yaml")
    io.save_parameters_to_yaml(model_builder, model_trainer, data_generator, parameters_save_name)

    # TODO: Write Results to Zarr or Another File Format?
    results_save_name = os.path.join(model_save_path, "results.json")
    io.save_results_to_json(results, results_save_name)
|
|
105
|
+
|
|
106
|
+
def train_model_parser(parser_description, add_slurm: bool = False):
    """
    Build the training command-line parser and parse sys.argv with it.

    Args:
        parser_description: Text shown at the top of --help.
        add_slurm: When True, also register the SLURM submission options.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    arg_parser = argparse.ArgumentParser(description=parser_description,
                                         formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # --- Input data ---------------------------------------------------------
    inputs = arg_parser.add_argument_group("Input Arguments")
    common.add_config(inputs, single_config=False)
    inputs.add_argument("--target-info", type=utils.parse_target, default="targets,octopi,1",
                        help="Target information, e.g., 'name' or 'name,user_id,session_id'. Default is 'targets,octopi,1'.")
    inputs.add_argument("--tomo-alg", default='wbp', help="Tomogram algorithm used for training")
    run_id_flags = (
        ("--trainRunIDs", "List of training run IDs, e.g., run1,run2,run3"),
        ("--validateRunIDs", "List of validation run IDs, e.g., run4,run5,run6"),
    )
    for flag, flag_help in run_id_flags:
        inputs.add_argument(flag, type=utils.parse_list, help=flag_help)
    split_help = ("Data split ratios. Either a single value (e.g., '0.8' for 80/20/0 split) "
                  "or two comma-separated values (e.g., '0.7,0.1' for 70/10/20 split)")
    inputs.add_argument('--data-split', type=str, default='0.8', help=split_help)

    # --- Fine-tuning from an existing model ---------------------------------
    fine_tune = arg_parser.add_argument_group("Fine-Tuning Arguments")
    fine_tune.add_argument('--model-config', type=str,
                           help="Path to the model configuration file (typically used for fine-tuning)")
    fine_tune.add_argument('--model-weights', type=str,
                           help="Path to the model weights file (typically used for fine-tuning)")

    # --- Architecture and optimisation options (defined in entry_points.common) ---
    common.add_model_parameters(arg_parser.add_argument_group("UNet-Model Arguments"))
    common.add_train_parameters(arg_parser.add_argument_group("Training Arguments"))

    # --- Optional cluster submission options --------------------------------
    if add_slurm:
        common.add_slurm_parameters(arg_parser.add_argument_group("SLURM Arguments"), 'train', gpus = 1)

    return arg_parser.parse_args()
|
|
144
|
+
|
|
145
|
+
# Entry point with argparse
|
|
146
|
+
def cli():
    """
    Command-line entry point for training: parse arguments and call train_model.

    Results are written to the local --model-save-path directory.
    """
    args = train_model_parser("Train 3D CNN U-Net models")

    # More than one --config entry goes through utils.parse_copick_configs
    # (multi-config training); a single entry is passed through unchanged.
    copick_configs = (utils.parse_copick_configs(args.config)
                      if len(args.config) > 1 else args.config[0])

    # A model YAML (e.g. for fine-tuning) takes precedence over the CLI architecture flags.
    if not args.model_config:
        model_config = get_model_config(args.channels, args.strides, args.res_units, args.Nclass, args.dim_in)
    else:
        model_config = utils.load_yaml(args.model_config)

    train_kwargs = dict(
        copick_config_path=copick_configs,
        target_info=args.target_info,
        tomo_algorithm=args.tomo_alg,
        voxel_size=args.voxel_size,
        model_config=model_config,
        model_weights=args.model_weights,
        model_save_path=args.model_save_path,
        num_tomo_crops=args.num_tomo_crops,
        tomo_batch_size=args.tomo_batch_size,
        lr=args.lr,
        tversky_alpha=args.tversky_alpha,
        num_epochs=args.num_epochs,
        val_interval=args.val_interval,
        best_metric=args.best_metric,
        trainRunIDs=args.trainRunIDs,
        validateRunIDs=args.validateRunIDs,
        data_split=args.data_split,
    )
    train_model(**train_kwargs)
|
|
184
|
+
|
|
185
|
+
def get_model_config(channels, strides, res_units, Nclass, dim_in):
    """
    Return the default 'Unet' model-configuration dict built from CLI values.

    Used when no model configuration file is supplied. Dropout is fixed at 0.1.
    """
    return dict(
        architecture='Unet',
        channels=channels,
        strides=strides,
        num_res_units=res_units,
        num_classes=Nclass,
        dropout=0.1,
        dim_in=dim_in,
    )
|
|
199
|
+
|
|
200
|
+
# Run the training CLI when this module is executed as a script.
if __name__ == '__main__':
    cli()
|
|
File without changes
|
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
from skimage.morphology import binary_erosion, binary_dilation, ball
|
|
2
|
+
from scipy.cluster.hierarchy import fcluster, linkage
|
|
3
|
+
from skimage.segmentation import watershed
|
|
4
|
+
from typing import List, Optional, Tuple
|
|
5
|
+
from skimage.measure import regionprops
|
|
6
|
+
from scipy.spatial import distance
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from octopi import io
|
|
9
|
+
import scipy.ndimage as ndi
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
import numpy as np
|
|
12
|
+
import math
|
|
13
|
+
|
|
14
|
+
def processs_localization(run,
                          objects,
                          seg_info: Tuple[str, str, str],
                          method: str = 'com',
                          voxel_size: float = 10,
                          filter_size: Optional[int] = None,
                          radius_min_scale: float = 0.5,
                          radius_max_scale: float = 1.0,
                          pick_session_id: str = '1',
                          pick_user_id: str = 'monai'):
    """
    Localize particles in a run's segmentation and write them as picks.

    Args:
        run: CoPick run object providing the segmentation and picks containers.
        objects: Iterable of entries indexed as (name, segmentation_label, radius).
            The radius is divided by voxel_size, so it is presumably in the same
            units as voxel_size.
        seg_info: (name, user_id, session_id) of the segmentation to read.
        method: 'com' (connected-component centre of mass) or 'watershed'.
        voxel_size: Voxel size of the segmentation; also used to rescale picks.
        filter_size: Maxima filter size; required (positive) for 'watershed'.
        radius_min_scale / radius_max_scale: Fractions of the particle radius
            defining the accepted particle volume range.
        pick_session_id / pick_user_id: Identity under which picks are written.

    Raises:
        ValueError: if method is not 'watershed' or 'com'.
    """
    if method not in ['watershed', 'com']:
        raise ValueError(f"Invalid method '{method}'. Expected 'watershed' or 'com'.")

    seg = io.get_segmentation_array(run,
                                    voxel_size,
                                    seg_info[0],
                                    user_id=seg_info[1],
                                    session_id=seg_info[2],
                                    raise_error=False)

    # Optional fragment clean-up is currently disabled:
    # seg = preprocess_segmentation(seg, voxel_size, objects)

    # If no segmentation is found, there is nothing to localize.
    if seg is None:
        return

    for obj in objects:
        obj_name, obj_label, obj_radius = obj[0], obj[1], obj[2]

        # Particle radius bounds in voxels
        min_radius = obj_radius * radius_min_scale / voxel_size
        max_radius = obj_radius * radius_max_scale / voxel_size

        if method == 'watershed':
            points = extract_particle_centroids_via_watershed(seg, obj_label, filter_size, min_radius, max_radius)
        else:
            points = extract_particle_centroids_via_com(seg, obj_label, min_radius, max_radius)

        # Extractors may report "nothing found" as None or an empty list.
        points = np.asarray(points if points is not None else [], dtype=float)
        if points.ndim != 2 or points.shape[0] == 0:
            print(f"{run.name} didn't have any available picks for {obj_name}!")
            continue

        # Remove Picks that are too close to each other (currently disabled)
        # points = remove_repeated_picks(points, min_radius, pixelSize = voxel_size)

        # NOTE(review): axis order. The watershed path yields regionprops centroids
        # in array axis order, while extract_particle_centroids_via_com already
        # reverses the axes, so after this reversal the two methods end up in
        # opposite orders. Confirm which order picks.from_numpy expects.
        points = points[:, [2, 1, 0]]

        # Convert the picks from voxels back to physical units
        points *= voxel_size

        # Create the picks set, or reuse an existing one when creation fails.
        try:
            picks = run.new_picks(object_name=obj_name, session_id=pick_session_id, user_id=pick_user_id)
        except Exception:
            existing = run.get_picks(object_name=obj_name, session_id=pick_session_id, user_id=pick_user_id)
            if not existing:
                # No existing set to fall back on: surface the creation error.
                raise
            picks = existing[0]

        # One 4x4 identity transform per pick (no orientation information).
        orientations = np.tile(np.eye(4), (points.shape[0], 1, 1))
        picks.from_numpy(points, orientations)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def extract_particle_centroids_via_watershed(
    segmentation,
    segmentation_idx,
    maxima_filter_size,
    min_particle_radius,
    max_particle_radius):
    """
    Find particle centroids for one label using a distance-transform watershed.

    The label mask is opened with a radius-1 ball, local maxima of its distance
    transform seed a watershed, and regions whose voxel count lies within the
    volumes of spheres of the given radii are kept.

    Args:
        segmentation (np.ndarray): Multilabel segmentation array.
        segmentation_idx (int): The label to process.
        maxima_filter_size (int): Edge length of the cubic maximum filter (> 0).
        min_particle_radius (float): Minimum particle radius, in voxels.
        max_particle_radius (float): Maximum particle radius, in voxels.

    Returns:
        list of tuple: Region centroids in array axis order; empty if none.

    Raises:
        ValueError: if maxima_filter_size is None or not positive.
    """
    # Previously an AssertionError was built but never raised; fail fast instead.
    if maxima_filter_size is None or maxima_filter_size <= 0:
        raise ValueError(f"maxima_filter_size must be a positive integer, got {maxima_filter_size!r}")

    # Volume bounds (in voxels) from the given radii
    min_particle_size = (4 / 3) * np.pi * (min_particle_radius ** 3)
    max_particle_size = (4 / 3) * np.pi * (max_particle_radius ** 3)

    binary_mask = segmentation == segmentation_idx
    if not binary_mask.any():
        print(f"No segmentation with label {segmentation_idx} found.")
        return []

    # Morphological opening removes thin bridges and single-voxel noise
    struct_elem = ball(1)
    eroded = binary_erosion(binary_mask, struct_elem)
    dilated = np.asarray(binary_dilation(eroded, struct_elem), dtype=bool)
    if not dilated.any():
        return []

    # Local maxima of the distance map seed the watershed. Restrict the seeds to
    # the foreground: background voxels (distance 0) are otherwise all "maxima".
    dist_map = ndi.distance_transform_edt(dilated)
    local_max = dist_map == ndi.maximum_filter(dist_map, size=maxima_filter_size)
    local_max &= dilated

    markers, _ = ndi.label(local_max)
    watershed_labels = watershed(-dist_map, markers, mask=dilated)

    return [region.centroid for region in regionprops(watershed_labels)
            if min_particle_size <= region.area <= max_particle_size]
|
|
139
|
+
|
|
140
|
+
def extract_particle_centroids_via_com(
    segmentation,
    segmentation_idx,
    min_particle_radius,
    max_particle_radius
    ):
    """
    Find particle centres for one label as centres of mass of connected components.

    Components whose voxel count is strictly between the volumes of spheres
    with the given radii are kept.

    Args:
        segmentation (np.ndarray): Multilabel segmentation array.
        segmentation_idx (int): The label to process.
        min_particle_radius (float): Minimum particle radius, in voxels.
        max_particle_radius (float): Maximum particle radius, in voxels.

    Returns:
        list of tuple: One centre per kept component, with the array axes
        reversed (last axis first). Empty if no component qualifies.
    """
    # Volume bounds (in voxels) from the given radii
    min_particle_size = (4 / 3) * np.pi * (min_particle_radius ** 3)
    max_particle_size = (4 / 3) * np.pi * (max_particle_radius ** 3)

    # Connected components of the label
    mask = segmentation == segmentation_idx
    label_objs, _ = ndi.label(mask)

    # Component sizes; slot 0 counts background voxels and is never a particle.
    object_sizes = np.bincount(label_objs.ravel())
    valid_objects = np.flatnonzero((object_sizes > min_particle_size) & (object_sizes < max_particle_size))
    valid_objects = valid_objects[valid_objects != 0]
    if valid_objects.size == 0:
        return []

    # One labelled pass instead of a full-volume centre-of-mass per object.
    centers = ndi.center_of_mass(mask, label_objs, valid_objects)
    return [(com[2], com[1], com[0]) for com in centers]
|
|
178
|
+
|
|
179
|
+
def remove_repeated_picks(coordinates, distanceThreshold, pixelSize = 1):
    """
    Merge picks that lie close together into their mean position.

    Picks are grouped by complete-linkage hierarchical clustering on the first
    three columns (divided by pixelSize), cut at distanceThreshold. Each group
    is replaced by the mean of all of its columns.

    Args:
        coordinates (array-like): (N, D) array with D >= 3.
        distanceThreshold (float): Cluster cut distance, in units of coordinates / pixelSize.
        pixelSize (float): Divisor applied to coordinates before measuring distances.

    Returns:
        np.ndarray: (K, D) float array of merged picks (K <= N). Inputs with
        fewer than two rows are returned as a float copy.
    """
    coordinates = np.asarray(coordinates)

    # Clustering needs at least two observations.
    if coordinates.shape[0] < 2:
        return np.array(coordinates, dtype=float)

    # linkage expects a *condensed* distance vector; a square matrix would be
    # treated as N observation vectors in N dimensions.
    condensed = distance.pdist(coordinates[:, :3] / pixelSize)
    Z = linkage(condensed, method='complete')

    # Flat clusters (labelled 1..K) with a distance threshold
    clusters = fcluster(Z, t=distanceThreshold, criterion='distance')

    return np.array([coordinates[clusters == i].mean(axis=0)
                     for i in range(1, clusters.max() + 1)], dtype=float)
|
|
198
|
+
|
|
199
|
+
def preprocess_segmentation(segmentation, voxel_size, particle_info, min_size_scale: float = 0.3):
    """
    Remove tiny fragments that aren't real particles.

    For each (name, segment_id, radius) entry, the radius is converted to voxels
    via voxel_size. The reference volume is that of a sphere of half that
    radius. Connected components of the label (face connectivity) smaller than
    min_size_scale times the reference volume are set to 0. Labels not listed
    in particle_info are left untouched.

    Args:
        segmentation (np.ndarray): The multilabel segmentation array (not modified).
        voxel_size (float): Voxel size, in the same units as the radii.
        particle_info (list): List of tuples containing (name, segment_id, radius).
        min_size_scale (float): Fraction of the reference volume below which a
            component is removed (default 0.3, the previous hard-coded value).

    Returns:
        np.ndarray: Processed copy with small fragments removed.
    """
    processed_seg = segmentation.copy()

    # Reference volume (in voxels) per segment id
    reference_volume = {}
    for _name, segment_id, radius in particle_info:
        radius_vox = radius / voxel_size
        reference_volume[segment_id] = (4 / 3) * np.pi * ((radius_vox * 0.5) ** 3)

    for label in np.unique(segmentation):
        # Skip background and labels without particle information
        if label <= 0 or label not in reference_volume:
            continue

        # Size every face-connected component of this label
        # (same rule as skimage.morphology.remove_small_objects, connectivity=1).
        components, _ = ndi.label(segmentation == label)
        component_sizes = np.bincount(components.ravel())
        too_small = component_sizes < reference_volume[label] * min_size_scale
        too_small[0] = False  # slot 0 is background, not a component

        processed_seg[too_small[components]] = 0

    return processed_seg
|