octopi 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- octopi/__init__.py +7 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +83 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +458 -0
- octopi/datasets/io.py +200 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +252 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +119 -0
- octopi/entry_points/create_slurm_submission.py +251 -0
- octopi/entry_points/groups.py +152 -0
- octopi/entry_points/run_create_targets.py +234 -0
- octopi/entry_points/run_evaluate.py +99 -0
- octopi/entry_points/run_extract_mb_picks.py +191 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +176 -0
- octopi/entry_points/run_optuna.py +161 -0
- octopi/entry_points/run_segment.py +154 -0
- octopi/entry_points/run_train.py +189 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +217 -0
- octopi/extract/membranebound_extract.py +263 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/main.py +33 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +72 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +224 -0
- octopi/processing/downloader.py +138 -0
- octopi/processing/downsample.py +125 -0
- octopi/processing/evaluate.py +302 -0
- octopi/processing/importers.py +116 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +244 -0
- octopi/pytorch/model_search_submitter.py +291 -0
- octopi/pytorch/segmentation.py +363 -0
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +465 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +215 -0
- octopi/utils/losses.py +86 -0
- octopi/utils/parsers.py +162 -0
- octopi/utils/progress.py +78 -0
- octopi/utils/stopping_criteria.py +143 -0
- octopi/utils/submit_slurm.py +95 -0
- octopi/utils/visualization_tools.py +290 -0
- octopi/workflows.py +262 -0
- octopi-1.4.0.dist-info/METADATA +119 -0
- octopi-1.4.0.dist-info/RECORD +65 -0
- octopi-1.4.0.dist-info/WHEEL +4 -0
- octopi-1.4.0.dist-info/entry_points.txt +3 -0
- octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
octopi/__init__.py
ADDED
File without changes
octopi/datasets/augment.py
ADDED
@@ -0,0 +1,83 @@
from monai.transforms import (
    Compose,
    RandFlipd,
    Orientationd,
    RandRotate90d,
    NormalizeIntensityd,
    EnsureChannelFirstd,
    RandCropByLabelClassesd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandAdjustContrastd,
    RandGaussianNoised,
    ScaleIntensityRanged,
    RandomOrder,
)

def get_transforms():
    """
    Returns non-random transforms.
    """
    return Compose([
        EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
        NormalizeIntensityd(keys="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS")
    ])

def get_random_transforms(input_dim, num_samples, Nclasses):
    """
    Input:
        input_dim: tuple of (nx, ny, nz)
        num_samples: int
        Nclasses: int

    Returns random transforms.

    For data with a missing wedge along the first axis (causing smearing in that direction),
    we avoid rotations that would move this artifact to other axes. We only rotate around
    the first axis (spatial_axes=[1, 2]) and avoid flipping along the first axis.
    """
    return Compose([
        RandCropByLabelClassesd(
            keys=["image", "label"],
            label_key="label",
            spatial_size=[input_dim[0], input_dim[1], input_dim[2]],
            num_classes=Nclasses,
            num_samples=num_samples
        ),
        # Only rotate around the first axis (keeping the missing wedge orientation consistent)
        RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2], max_k=3),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        RandomOrder([
            # Intensity augmentations are still appropriate
            RandScaleIntensityd(keys="image", prob=0.5, factors=(0.85, 1.15)),
            RandShiftIntensityd(keys="image", prob=0.5, offsets=(-0.15, 0.15)),
            RandAdjustContrastd(keys="image", prob=0.5, gamma=(0.85, 1.15)),
            RandGaussianNoised(keys="image", prob=0.5, mean=0.0, std=0.5),  # Reduced noise std
        ]),
    ])

# Augmentations to Explore in the Future:
# Intensity-based augmentations
# RandHistogramShiftd(keys="image", prob=0.5, num_control_points=(3, 5))
# RandGaussianSmoothd(keys="image", prob=0.5, sigma_x=(0.5, 1.5), sigma_y=(0.5, 1.5), sigma_z=(0.5, 1.5)),

# Geometric Transforms
# RandAffined(
#     keys=["image", "label"],
#     rotate_range=(0.1, 0.1, 0.1),  # Rotation angles (radians) for x, y, z axes
#     scale_range=(0.1, 0.1, 0.1),   # Scale range for isotropic/anisotropic scaling
#     prob=0.5,                      # Probability of applying the transform
#     padding_mode="border"          # Handle out-of-bounds values
# )

def get_predict_transforms():
    """
    Returns predict transforms.
    """
    return Compose([
        EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
        NormalizeIntensityd(keys="image")
    ])
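For orientation, here is a minimal, illustrative sketch (not part of the package) of how these three pipelines are meant to fit together: the deterministic transforms run once per volume, while the random transforms draw class-balanced crops with flips, rotations, and intensity jitter on every iteration. The array shapes, class count, and crop size below are made-up assumptions.

import numpy as np
from octopi.datasets import augment

# Synthetic stand-ins for a tomogram and its segmentation target, in the
# {"image", "label"} dictionary layout the pipelines expect.
sample = {
    "image": np.random.randn(64, 128, 128).astype(np.float32),
    "label": np.random.randint(0, 3, size=(64, 128, 128)).astype(np.int64),
}

# Deterministic preprocessing: add a channel axis, normalize intensity, reorient to RAS.
pre = augment.get_transforms()(sample)

# Per-iteration augmentation: class-balanced crops plus wedge-aware flips/rotations
# and intensity jitter. Returns a list of `num_samples` cropped dictionaries.
rand = augment.get_random_transforms(input_dim=(48, 48, 48), num_samples=2, Nclasses=3)
patches = rand(pre)
print(patches[0]["image"].shape)  # e.g. (1, 48, 48, 48)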
octopi/datasets/cached_datset.py
ADDED
@@ -0,0 +1,113 @@
from typing import List, Tuple, Callable, Optional, Dict, Any
from monai.transforms import Compose
from monai.data import CacheDataset
from octopi.datasets import io
from tqdm import tqdm
import numpy as np  # required for the np.random call in _fill_cache
import os, sys

class MultiConfigCacheDataset(CacheDataset):
    """
    A custom CacheDataset that loads data lazily from multiple sources
    with a consolidated loading and caching process.
    """

    def __init__(
        self,
        manager,
        run_ids: List[Tuple[str, str]],
        transform: Optional[Callable] = None,
        cache_rate: float = 1.0,
        num_workers: int = 0,
        progress: bool = True,
        copy_cache: bool = True,
        cache_num: int = sys.maxsize
    ):
        # Save reference to manager and run_ids
        self.manager = manager
        self.run_ids = run_ids
        self.progress = progress

        # Prepare empty data list first - don't load immediately
        self.data = []

        # Initialize the parent CacheDataset with an empty list
        # We'll override the _fill_cache method to handle loading and caching in one step
        super().__init__(
            data=[],  # Empty list - we'll load data in _fill_cache
            transform=transform,
            cache_rate=cache_rate,
            num_workers=num_workers,
            progress=False,  # We'll handle our own progress
            copy_cache=copy_cache,
            cache_num=cache_num
        )

    def _fill_cache(self):
        """
        Override the parent's _fill_cache method to combine loading and caching.
        """
        if self.progress:
            print("Loading and caching dataset...")

        # Load and process data in a single operation
        self.data = []
        iterator = tqdm(self.run_ids, desc="Loading dataset") if self.progress else self.run_ids

        for session_name, run_name in iterator:
            root = self.manager.roots[session_name]
            batch_data = io.load_training_data(
                root,
                [run_name],
                self.manager.voxel_size,
                self.manager.tomo_algorithm,
                self.manager.target_name,
                self.manager.target_session_id,
                self.manager.target_user_id,
                progress_update=False
            )

            self.data.extend(batch_data)

            # Process and cache this batch right away
            for i, item in enumerate(batch_data):
                if len(self._cache) < self.cache_num and self.cache_rate > 0.0:
                    if np.random.random() < self.cache_rate:
                        self._cache.append(self._transform(item))

        # Check max label value if needed
        if hasattr(self.manager, '_check_max_label_value'):
            self.manager._check_max_label_value(self.data)

        # Update the _data attribute to match the loaded data
        self._data = self.data

    def __len__(self):
        """
        Return the length of the dataset.
        """
        if not self.data:
            self._fill_cache()  # Load data if not loaded yet
        return len(self.data)

    def __getitem__(self, index):
        """
        Return the item at the given index.
        """
        if not self.data:
            self._fill_cache()  # Load data if not loaded yet

        # Use parent's logic for cached items
        if index < len(self._cache):
            return self._cache[index]

        # Otherwise transform on-the-fly
        return self._transform(self.data[index])

# TODO: Implement Single Config Cache Dataset
# class SingleConfigCacheDataset(CacheDataset):
#     def __init__(self,
#                  root: Any,
#                  run_ids: List[str],
#                  voxel_size: float,
#                  tomo_algorithm: str,
#                  target_name: str,
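As a rough usage sketch: MultiConfigCacheDataset is driven by a manager object exposing the attributes _fill_cache reads (roots, voxel_size, tomo_algorithm, and the target query fields) plus (config name, run name) pairs. The SimpleNamespace, config path, and run name below are hypothetical placeholders standing in for the real multi-config manager, not a working copick setup.

from types import SimpleNamespace
from octopi.datasets import augment, io
from octopi.datasets.cached_datset import MultiConfigCacheDataset

# Hypothetical manager exposing the attributes _fill_cache expects.
manager = SimpleNamespace(
    roots={"session_A": io.load_copick_config("session_A_config.json")},  # placeholder config path
    voxel_size=10,
    tomo_algorithm=["wbp"],
    target_name="targets",       # placeholder target segmentation name
    target_session_id=None,
    target_user_id=None,
)

dataset = MultiConfigCacheDataset(
    manager=manager,
    run_ids=[("session_A", "TS_001")],  # (config name, run name) pairs; run name is a placeholder
    transform=augment.get_transforms(),
    cache_rate=1.0,
)
print(len(dataset))  # first use triggers the consolidated load-and-cache step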
octopi/datasets/dataset.py
ADDED
@@ -0,0 +1,19 @@
from torch.utils.data import Dataset

class DynamicDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample

    def update_data(self, new_data):
        """Update the internal dataset with new data."""
        self.data = new_data
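A small self-contained sketch (toy dictionaries stand in for MONAI CacheDataset instances) of why this wrapper exists: update_data swaps the underlying data without rebuilding the surrounding DataLoader, which is how the generators below reload tomogram batches mid-training.

from torch.utils.data import DataLoader
from octopi.datasets.dataset import DynamicDataset

# Toy items standing in for cached {"image", "label"} volumes.
dynamic_ds = DynamicDataset(data=[{"image": i} for i in range(4)])
loader = DataLoader(dynamic_ds, batch_size=2, shuffle=False)

for batch in loader:
    print(batch["image"])  # batches drawn from the initial data

# At a reload point, swap in new data; the existing DataLoader keeps working.
dynamic_ds.update_data([{"image": 10 + i} for i in range(4)])
for batch in loader:
    print(batch["image"])  # batches now drawn from the replacement data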
octopi/datasets/generators.py
ADDED
@@ -0,0 +1,458 @@
from octopi.datasets import dataset, augment, cached_datset
from monai.data import DataLoader, SmartCacheDataset, CacheDataset, Dataset
from typing import List, Optional
from octopi.utils import io as io2
from octopi.datasets import io
import torch, os, random, gc
import multiprocess as mp

class TrainLoaderManager:

    def __init__(self,
                 config: str,
                 target_name: str,
                 target_session_id: str = None,
                 target_user_id: str = None,
                 voxel_size: float = 10,
                 tomo_algorithm: List[str] = ['wbp'],
                 tomo_batch_size: int = 15
                 ):

        # Read Copick Project
        self.config = config
        self.root = io.load_copick_config(config)

        # Copick Query for Target
        self.target_name = target_name
        self.target_session_id = target_session_id
        self.target_user_id = target_user_id

        # Copick Query For Input Tomogram
        self.voxel_size = voxel_size
        self.tomo_algorithm = tomo_algorithm

        self.reload_training_dataset = True
        self.reload_validation_dataset = True
        self.val_loader = None
        self.train_loader = None
        self.tomo_batch_size = tomo_batch_size

        # Initialize the input dimensions
        self.nx = None
        self.ny = None
        self.nz = None

    def get_available_runIDs(self):
        """
        Identify and return a list of run IDs that have segmentations available for the target.

        - Iterates through all runs in the project to check for segmentations that match
          the specified target name, session ID, and user ID.
        - Only includes runs that have at least one matching segmentation.

        Returns:
            available_runIDs (list): List of run IDs with available segmentations.
        """
        available_runIDs = []
        runIDs = [run.name for run in self.root.runs]
        for run in runIDs:
            run = self.root.get_run(run)
            seg = run.get_segmentations(name=self.target_name,
                                        session_id=self.target_session_id,
                                        user_id=self.target_user_id,
                                        voxel_size=float(self.voxel_size))
            if len(seg) > 0:
                available_runIDs.append(run.name)

        # If No Segmentations are Found, Inform the User
        if len(available_runIDs) == 0:
            print(
                f"[Error] No segmentations found for the target query:\n"
                f"TargetName: {self.target_name}, UserID: {self.target_user_id}, "
                f"SessionID: {self.target_session_id}\n"
                f"Please check the target name, user ID, and session ID.\n"
            )
            exit()

        return available_runIDs

    def get_data_splits(self,
                        trainRunIDs: str = None,
                        validateRunIDs: str = None,
                        train_ratio: float = 0.8,
                        val_ratio: float = 0.2,
                        test_ratio: float = 0.0,
                        create_test_dataset: bool = False):
        """
        Split the available data into training, validation, and testing sets based on input parameters.

        Args:
            trainRunIDs (str): Predefined list of run IDs for training. If provided, it overrides splitting logic.
            validateRunIDs (str): Predefined list of run IDs for validation. If provided with trainRunIDs, no splitting occurs.
            train_ratio (float): Proportion of available data to allocate to the training set.
            val_ratio (float): Proportion of available data to allocate to the validation set.
            test_ratio (float): Proportion of available data to allocate to the test set.
            create_test_dataset (bool): Whether to create a test dataset or leave it empty.

        Returns:
            myRunIDs (dict): Dictionary containing run IDs for training, validation, and testing.
        """

        # Option 1: Only TrainRunIDs are Provided, Split into Train, Validate and Test (Optional)
        if trainRunIDs is not None and validateRunIDs is None:
            trainRunIDs, validateRunIDs, testRunIDs = io.split_multiclass_dataset(
                trainRunIDs, train_ratio, val_ratio, test_ratio,
                return_test_dataset=create_test_dataset
            )
        # Option 2: TrainRunIDs and ValidateRunIDs are Provided, No Need to Split
        elif trainRunIDs is not None and validateRunIDs is not None:
            testRunIDs = None
        # Option 3: Use the Entire Copick Project, Split into Train, Validate and Test
        else:
            runIDs = self.get_available_runIDs()
            trainRunIDs, validateRunIDs, testRunIDs = io.split_multiclass_dataset(
                runIDs, train_ratio, val_ratio, test_ratio,
                return_test_dataset=create_test_dataset
            )

        # Get Class Info from the Training Dataset
        self._get_class_info(trainRunIDs)

        # Swap if Test Runs is Larger than Validation Runs
        if create_test_dataset and len(testRunIDs) > len(validateRunIDs):
            testRunIDs, validateRunIDs = validateRunIDs, testRunIDs

        # Determine if datasets fit entirely in memory based on the batch size
        # If the validation set is smaller than the batch size, avoid reloading
        if len(validateRunIDs) < self.tomo_batch_size:
            self.reload_validation_dataset = False

        # If the training set is smaller than the batch size, avoid reloading
        if len(trainRunIDs) < self.tomo_batch_size:
            self.reload_training_dataset = False

        # Store the split run IDs into a dictionary for easy access
        self.myRunIDs = {
            'train': trainRunIDs,
            'validate': validateRunIDs,
            'test': testRunIDs
        }

        print(f"Number of training samples: {len(trainRunIDs)}")
        print(f"Number of validation samples: {len(validateRunIDs)}")
        if testRunIDs is not None:
            print(f'Number of test samples: {len(testRunIDs)}')

        # Define separate batch sizes
        self.train_batch_size = min(len(self.myRunIDs['train']), self.tomo_batch_size)
        self.val_batch_size = min(len(self.myRunIDs['validate']), self.tomo_batch_size)

        # Initialize data iterators for training and validation
        self._initialize_val_iterators()
        self._initialize_train_iterators()

        return self.myRunIDs

    def _get_class_info(self, trainRunIDs):

        # Fetch a segmentation to determine class names and number of classes
        for runID in trainRunIDs:
            run = self.root.get_run(runID)
            seg = run.get_segmentations(name=self.target_name,
                                        session_id=self.target_session_id,
                                        user_id=self.target_user_id,
                                        voxel_size=float(self.voxel_size))
            if len(seg) == 0:
                continue

            # If Session ID or User ID are None, Set Them Based on the First Found Segmentation
            if self.target_session_id is None:
                self.target_session_id = seg[0].session_id
            if self.target_user_id is None:
                self.target_user_id = seg[0].user_id

            # Read Yaml Config to Get Number of Classes and Class Names
            target_config = io2.check_target_config_path(self)
            class_names = target_config['input']['labels']
            self.Nclasses = len(class_names) + 1
            self.class_names = [name for name, idx in sorted(class_names.items(), key=lambda x: x[1])]

            # We Only need to read One Segmentation to Get Class Info
            break

    def _get_padded_list(self, data_list, batch_size):
        # Calculate padding needed to make `data_list` a multiple of `batch_size`
        remainder = len(data_list) % batch_size
        if remainder > 0:
            # Number of additional items needed to make the length a multiple of batch size
            padding_needed = batch_size - remainder
            # Extend `data_list` with a random subset to achieve the padding
            data_list = data_list + random.sample(data_list, padding_needed)
        # Shuffle the full list
        random.shuffle(data_list)
        return data_list

    def _initialize_train_iterators(self):
        # Initialize padded train and validation data lists
        self.padded_train_list = self._get_padded_list(self.myRunIDs['train'], self.train_batch_size)

        # Create iterators
        self.train_data_iter = iter(self._get_data_batches(self.padded_train_list, self.train_batch_size))

    def _initialize_val_iterators(self):
        # Initialize padded train and validation data lists
        self.padded_val_list = self._get_padded_list(self.myRunIDs['validate'], self.val_batch_size)

        # Create iterators
        self.val_data_iter = iter(self._get_data_batches(self.padded_val_list, self.val_batch_size))

    def _get_data_batches(self, data_list, batch_size):
        # Generator that yields batches of specified size
        for i in range(0, len(data_list), batch_size):
            yield data_list[i:i + batch_size]

    def _extract_run_ids(self, data_iter_name, initialize_method):
        # Access the instance's data iterator by name
        data_iter = getattr(self, data_iter_name)
        try:
            # Attempt to get the next batch from the iterator
            runIDs = next(data_iter)
        except StopIteration:
            # Reinitialize the iterator if exhausted
            initialize_method()
            # Update the iterator reference after reinitialization
            data_iter = getattr(self, data_iter_name)
            runIDs = next(data_iter)
        # Update the instance attribute with the new iterator state
        setattr(self, data_iter_name, data_iter)
        return runIDs

    def create_train_dataloaders(
        self,
        crop_size: int = 96,
        num_samples: int = 64):

        train_batch_size = 1
        val_batch_size = 1

        # If reloads are disabled and loaders already exist, reuse them
        if self.reload_frequency < 0 and (self.train_loader is not None) and (self.val_loader is not None):
            return self.train_loader, self.val_loader

        # We Only Need to Reload the Training Dataset if the Total Number of Runs is larger than
        # the tomo batch size
        if self.train_loader is None:

            # Fetch the next batch of run IDs
            trainRunIDs = self._extract_run_ids('train_data_iter', self._initialize_train_iterators)
            train_files = io.load_training_data(self.root, trainRunIDs, self.voxel_size, self.tomo_algorithm,
                                                self.target_name, self.target_session_id, self.target_user_id,
                                                progress_update=False)
            self._check_max_label_value(train_files)

            # Create the cached dataset with non-random transforms
            train_ds = CacheDataset(data=train_files, transform=augment.get_transforms(), cache_rate=1.0)

            # Delete the training files to free memory
            train_files = None
            gc.collect()

            # Read (nx,ny,nz) and scale the crop size to make sure it isn't larger than nx.
            if self.nx is None: (self.nx, self.ny, self.nz) = train_ds[0]['image'].shape[1:]
            self.input_dim = io.get_input_dimensions(train_ds, crop_size)

            # Wrap the cached dataset to apply random transforms during iteration
            self.dynamic_train_dataset = dataset.DynamicDataset(
                data=train_ds,
                transform=augment.get_random_transforms(self.input_dim, num_samples, self.Nclasses)
            )

            # Define the number of processes for the DataLoader
            n_procs = min(mp.cpu_count(), 4)

            # DataLoader remains the same
            self.train_loader = DataLoader(
                self.dynamic_train_dataset,
                batch_size=train_batch_size,
                shuffle=False,
                num_workers=n_procs,
                pin_memory=torch.cuda.is_available(),
            )

        else:
            # Fetch the next batch of run IDs
            trainRunIDs = self._extract_run_ids('train_data_iter', self._initialize_train_iterators)
            train_files = io.load_training_data(self.root, trainRunIDs, self.voxel_size, self.tomo_algorithm,
                                                self.target_name, self.target_session_id, self.target_user_id,
                                                progress_update=False)
            self._check_max_label_value(train_files)

            train_ds = CacheDataset(data=train_files, transform=augment.get_transforms(), cache_rate=1.0)
            self.dynamic_train_dataset.update_data(train_ds)

        # We Only Need to Reload the Validation Dataset if the Total Number of Runs is larger than
        # the tomo batch size
        if self.val_loader is None:

            validateRunIDs = self._extract_run_ids('val_data_iter', self._initialize_val_iterators)
            val_files = io.load_training_data(self.root, validateRunIDs, self.voxel_size, self.tomo_algorithm,
                                              self.target_name, self.target_session_id, self.target_user_id,
                                              progress_update=False)
            self._check_max_label_value(val_files)

            # Create validation dataset
            val_ds = CacheDataset(data=val_files, transform=augment.get_transforms(), cache_rate=1.0)

            # Delete the validation files to free memory
            val_files = None
            gc.collect()

            # # Read (nx,ny,nz) and scale the crop size to make sure it isn't larger than nx.
            # if self.nx is None:
            #     (self.nx, self.ny, self.nz) = val_ds[0]['image'].shape[1:]

            # if crop_size > self.nx: self.input_dim = (self.nx, crop_size, crop_size)
            # else: self.input_dim = (crop_size, crop_size, crop_size)

            # Wrap the cached dataset to apply random transforms during iteration
            self.dynamic_validation_dataset = dataset.DynamicDataset(data=val_ds)

            dataset_size = len(self.dynamic_validation_dataset)
            n_procs = min(mp.cpu_count(), 8)

            # Create validation DataLoader
            self.val_loader = DataLoader(
                self.dynamic_validation_dataset,
                batch_size=val_batch_size,
                num_workers=n_procs,
                pin_memory=torch.cuda.is_available(),
                shuffle=False,  # Ensure the data order remains consistent
            )
        else:
            validateRunIDs = self._extract_run_ids('val_data_iter', self._initialize_val_iterators)
            val_files = io.load_training_data(self.root, validateRunIDs, self.voxel_size, self.tomo_algorithm,
                                              self.target_name, self.target_session_id, self.target_user_id,
                                              progress_update=False)
            self._check_max_label_value(val_files)

        return self.train_loader, self.val_loader

    def get_reload_frequency(self, num_epochs: int):
        """
        Automatically calculate the reload frequency for the dataset during training.

        Returns:
            int: Reload frequency (number of epochs between dataset reloads).
        """
        if not self.reload_training_dataset:
            # No need to reload if all tomograms fit in memory
            print("All training samples fit in memory. No reloading required.")
            self.reload_frequency = -1

        else:
            # Calculate the number of segments based on total training runs and batch size
            num_segments = (len(self.myRunIDs['train']) + self.tomo_batch_size - 1) // self.tomo_batch_size

            # Calculate reload frequency to distribute reloading evenly over epochs
            self.reload_frequency = max(num_epochs // num_segments, 1)

            print(f"\nReloading {self.tomo_batch_size} tomograms every {self.reload_frequency} epochs\n")

            # Warn if the number of epochs is insufficient for full dataset coverage
            if num_epochs < num_segments:
                print(
                    f"Warning: Chosen number of epochs ({num_epochs}) may not be sufficient "
                    f"to train over all training samples. Consider increasing the number of epochs "
                    f"to at least {num_segments}.\n"
                )

    def _check_max_label_value(self, train_files):
        max_label_value = max(file['label'].max() for file in train_files)
        if max_label_value > self.Nclasses:
            print(f"Warning: Maximum class label value {max_label_value} exceeds the number of classes {self.Nclasses}.")
            print("This may cause issues with the model's output layer.")
            print("Consider adjusting the number of classes or the label values in your data.\n")

    def get_dataloader_parameters(self):

        parameters = {
            'config': self.config,
            'target_name': self.target_name,
            'target_session_id': self.target_session_id,
            'target_user_id': self.target_user_id,
            'voxel_size': self.voxel_size,
            'tomo_algorithm': self.tomo_algorithm,
            'tomo_batch_size': self.tomo_batch_size,
            'reload_frequency': self.reload_frequency,
            'testRunIDs': self.myRunIDs['test'],
            'valRunIDs': self.myRunIDs['validate'],
            'trainRunIDs': self.myRunIDs['train'],
        }

        return parameters

class PredictLoaderManager:

    def __init__(self,
                 config: str,
                 voxel_size: float = 10,
                 tomo_algorithm: str = 'wbp',
                 tomo_batch_size: int = 15,  # Number of Tomograms to Load Per Sub-Epoch
                 Nclasses: int = 3):

        # Read Copick Project
        self.copick_config = config
        self.root = io.load_copick_config(config)

        # Copick Query For Input Tomogram
        self.voxel_size = voxel_size
        self.tomo_algorithm = tomo_algorithm

        self.Nclasses = Nclasses
        self.tomo_batch_size = tomo_batch_size

        # Initialize the input dimensions
        self.nx = None
        self.ny = None
        self.nz = None

    def create_predict_dataloader(
        self,
        voxel_spacing: float,
        tomo_algorithm: str,
        runIDs: str = None):

        # Default to predicting on every run in the project if no run IDs are provided
        if runIDs is None:
            runIDs = [run.name for run in self.root.runs]

        # Load the test data
        test_files = io.load_predict_data(self.root, runIDs, voxel_spacing, tomo_algorithm)

        # Create the cached dataset with non-random transforms
        test_ds = CacheDataset(data=test_files, transform=augment.get_predict_transforms())

        # Read (nx,ny,nz) for input tomograms.
        if self.nx is None:
            (self.nx, self.ny, self.nz) = test_ds[0]['image'].shape[1:]

        # Create the DataLoader
        test_loader = DataLoader(test_ds,
                                 batch_size=4,
                                 shuffle=False,
                                 num_workers=4,
                                 pin_memory=torch.cuda.is_available())
        return test_loader

    def get_dataloader_parameters(self):

        parameters = {
            'config': self.copick_config,
            'voxel_size': self.voxel_size,
            'tomo_algorithm': self.tomo_algorithm
        }

        return parameters
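To close the loop, a hedged sketch of how these two managers are intended to be driven, inferred from the methods above. The copick config path, target name, and the epoch-loop reload policy are assumptions for illustration; the actual training loop lives elsewhere in the package (e.g. octopi/pytorch/trainer.py) and may differ.

from octopi.datasets.generators import TrainLoaderManager, PredictLoaderManager

# Training side: split runs, decide how often to swap tomogram batches, build loaders.
manager = TrainLoaderManager(
    config="copick_config.json",   # placeholder copick project config
    target_name="targets",         # placeholder target segmentation name
    voxel_size=10,
    tomo_algorithm=["wbp"],
    tomo_batch_size=15,
)

num_epochs = 100
manager.get_data_splits(train_ratio=0.8, val_ratio=0.2)
manager.get_reload_frequency(num_epochs)   # sets reload_frequency; must run before create_train_dataloaders

train_loader, val_loader = manager.create_train_dataloaders(crop_size=96, num_samples=64)
for epoch in range(num_epochs):
    # One plausible reload policy: swap in the next tomogram batch every reload_frequency epochs.
    if manager.reload_frequency > 0 and epoch > 0 and epoch % manager.reload_frequency == 0:
        train_loader, val_loader = manager.create_train_dataloaders(crop_size=96, num_samples=64)
    # ... run the training and validation passes over train_loader / val_loader ...

# Inference side: build a prediction loader over every run in the project.
predictor = PredictLoaderManager(config="copick_config.json", voxel_size=10, tomo_algorithm="wbp")
test_loader = predictor.create_predict_dataloader(voxel_spacing=10, tomo_algorithm="wbp")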