octopi-1.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. octopi/__init__.py +7 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +83 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +458 -0
  7. octopi/datasets/io.py +200 -0
  8. octopi/datasets/mixup.py +49 -0
  9. octopi/datasets/multi_config_generator.py +252 -0
  10. octopi/entry_points/__init__.py +0 -0
  11. octopi/entry_points/common.py +119 -0
  12. octopi/entry_points/create_slurm_submission.py +251 -0
  13. octopi/entry_points/groups.py +152 -0
  14. octopi/entry_points/run_create_targets.py +234 -0
  15. octopi/entry_points/run_evaluate.py +99 -0
  16. octopi/entry_points/run_extract_mb_picks.py +191 -0
  17. octopi/entry_points/run_extract_midpoint.py +143 -0
  18. octopi/entry_points/run_localize.py +176 -0
  19. octopi/entry_points/run_optuna.py +161 -0
  20. octopi/entry_points/run_segment.py +154 -0
  21. octopi/entry_points/run_train.py +189 -0
  22. octopi/extract/__init__.py +0 -0
  23. octopi/extract/localize.py +217 -0
  24. octopi/extract/membranebound_extract.py +263 -0
  25. octopi/extract/midpoint_extract.py +193 -0
  26. octopi/main.py +33 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +72 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +224 -0
  37. octopi/processing/downloader.py +138 -0
  38. octopi/processing/downsample.py +125 -0
  39. octopi/processing/evaluate.py +302 -0
  40. octopi/processing/importers.py +116 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/pytorch/__init__.py +0 -0
  43. octopi/pytorch/hyper_search.py +244 -0
  44. octopi/pytorch/model_search_submitter.py +291 -0
  45. octopi/pytorch/segmentation.py +363 -0
  46. octopi/pytorch/segmentation_multigpu.py +162 -0
  47. octopi/pytorch/trainer.py +465 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/utils/__init__.py +0 -0
  52. octopi/utils/config.py +57 -0
  53. octopi/utils/io.py +215 -0
  54. octopi/utils/losses.py +86 -0
  55. octopi/utils/parsers.py +162 -0
  56. octopi/utils/progress.py +78 -0
  57. octopi/utils/stopping_criteria.py +143 -0
  58. octopi/utils/submit_slurm.py +95 -0
  59. octopi/utils/visualization_tools.py +290 -0
  60. octopi/workflows.py +262 -0
  61. octopi-1.4.0.dist-info/METADATA +119 -0
  62. octopi-1.4.0.dist-info/RECORD +65 -0
  63. octopi-1.4.0.dist-info/WHEEL +4 -0
  64. octopi-1.4.0.dist-info/entry_points.txt +3 -0
  65. octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
octopi/__init__.py ADDED
@@ -0,0 +1,7 @@
+ __version__ = "1.4.0"
+
+ # Shared CLI context settings for all commands
+ cli_context = {
+     "show_default": True,
+     "help_option_names": ["-h", "--help"],  # allow both -h and --help
+ }
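These settings are presumably passed to the CLI framework as context settings. A minimal sketch of that wiring, assuming the commands are built with click (both keys are valid click context settings); the `hello` command and its option are hypothetical, used only for illustration:

    import click
    from octopi import cli_context

    # Hypothetical command showing how the shared context settings could be applied:
    # every command then shows defaults in --help and accepts -h as well as --help.
    @click.command(context_settings=cli_context)
    @click.option("--name", default="world", help="Who to greet.")
    def hello(name):
        click.echo(f"Hello, {name}!")

    if __name__ == "__main__":
        hello()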
octopi/datasets/__init__.py ADDED
File without changes
octopi/datasets/augment.py ADDED
@@ -0,0 +1,83 @@
+ from monai.transforms import (
+     Compose,
+     RandFlipd,
+     Orientationd,
+     RandRotate90d,
+     NormalizeIntensityd,
+     EnsureChannelFirstd,
+     RandCropByLabelClassesd,
+     RandScaleIntensityd,
+     RandShiftIntensityd,
+     RandAdjustContrastd,
+     RandGaussianNoised,
+     ScaleIntensityRanged,
+     RandomOrder,
+ )
+
+ def get_transforms():
+     """
+     Returns non-random transforms.
+     """
+     return Compose([
+         EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
+         NormalizeIntensityd(keys="image"),
+         Orientationd(keys=["image", "label"], axcodes="RAS")
+     ])
+
+ def get_random_transforms(input_dim, num_samples, Nclasses):
+     """
+     Input:
+         input_dim: tuple of (nx, ny, nz)
+         num_samples: int
+         Nclasses: int
+
+     Returns random transforms.
+
+     For data with a missing wedge along the first axis (causing smearing in that direction),
+     we avoid rotations that would move this artifact to other axes. We only rotate around
+     the first axis (spatial_axes=[1, 2]) and avoid flipping along the first axis.
+     """
+     return Compose([
+         RandCropByLabelClassesd(
+             keys=["image", "label"],
+             label_key="label",
+             spatial_size=[input_dim[0], input_dim[1], input_dim[2]],
+             num_classes=Nclasses,
+             num_samples=num_samples
+         ),
+         # Only rotate around the first axis (keeping the missing wedge orientation consistent)
+         RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2], max_k=3),
+         RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
+         RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
+         RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
+         RandomOrder([
+             # Intensity augmentations are still appropriate
+             RandScaleIntensityd(keys="image", prob=0.5, factors=(0.85, 1.15)),
+             RandShiftIntensityd(keys="image", prob=0.5, offsets=(-0.15, 0.15)),
+             RandAdjustContrastd(keys="image", prob=0.5, gamma=(0.85, 1.15)),
+             RandGaussianNoised(keys="image", prob=0.5, mean=0.0, std=0.5),  # Reduced noise std
+         ]),
+     ])
+
+ # Augmentations to Explore in the Future:
+ # Intensity-based augmentations
+ # RandHistogramShiftd(keys="image", prob=0.5, num_control_points=(3, 5))
+ # RandGaussianSmoothd(keys="image", prob=0.5, sigma_x=(0.5, 1.5), sigma_y=(0.5, 1.5), sigma_z=(0.5, 1.5)),
+
+ # Geometric Transforms
+ # RandAffined(
+ #     keys=["image", "label"],
+ #     rotate_range=(0.1, 0.1, 0.1),  # Rotation angles (radians) for x, y, z axes
+ #     scale_range=(0.1, 0.1, 0.1),   # Scale range for isotropic/anisotropic scaling
+ #     prob=0.5,                      # Probability of applying the transform
+ #     padding_mode="border"          # Handle out-of-bounds values
+ # )
+
+ def get_predict_transforms():
+     """
+     Returns predict transforms.
+     """
+     return Compose([
+         EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
+         NormalizeIntensityd(keys="image")
+     ])
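A minimal sketch of how these transform pipelines compose on a single sample. The arrays below are illustrative stand-ins for a tomogram patch and its label volume (in octopi the data come from copick); shapes, class count, and crop size are arbitrary assumptions:

    import numpy as np
    from octopi.datasets import augment

    # Illustrative 3D volume and label map with keys matching the dictionary transforms above
    sample = {
        "image": np.random.rand(64, 64, 64).astype(np.float32),
        "label": np.random.randint(0, 3, (64, 64, 64)).astype(np.int64),
    }

    # Deterministic preprocessing: add channel dimension, normalize intensity, reorient
    preprocessed = augment.get_transforms()(sample)

    # Random cropping plus augmentation; returns a list of `num_samples` crops per volume
    crops = augment.get_random_transforms(input_dim=(32, 32, 32), num_samples=4, Nclasses=3)(preprocessed)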
octopi/datasets/cached_datset.py ADDED
@@ -0,0 +1,113 @@
+ from typing import List, Tuple, Callable, Optional, Dict, Any
+ from monai.transforms import Compose
+ from monai.data import CacheDataset
+ from octopi.datasets import io
+ from tqdm import tqdm
+ import numpy as np
+ import os, sys
+
+ class MultiConfigCacheDataset(CacheDataset):
+     """
+     A custom CacheDataset that loads data lazily from multiple sources
+     with a consolidated loading and caching process.
+     """
+
+     def __init__(
+         self,
+         manager,
+         run_ids: List[Tuple[str, str]],
+         transform: Optional[Callable] = None,
+         cache_rate: float = 1.0,
+         num_workers: int = 0,
+         progress: bool = True,
+         copy_cache: bool = True,
+         cache_num: int = sys.maxsize
+     ):
+         # Save reference to manager and run_ids
+         self.manager = manager
+         self.run_ids = run_ids
+         self.progress = progress
+
+         # Prepare empty data list first - don't load immediately
+         self.data = []
+
+         # Initialize the parent CacheDataset with an empty list
+         # We'll override the _fill_cache method to handle loading and caching in one step
+         super().__init__(
+             data=[],  # Empty list - we'll load data in _fill_cache
+             transform=transform,
+             cache_rate=cache_rate,
+             num_workers=num_workers,
+             progress=False,  # We'll handle our own progress
+             copy_cache=copy_cache,
+             cache_num=cache_num
+         )
+
+     def _fill_cache(self):
+         """
+         Override the parent's _fill_cache method to combine loading and caching.
+         """
+         if self.progress:
+             print("Loading and caching dataset...")
+
+         # Load and process data in a single operation
+         self.data = []
+         iterator = tqdm(self.run_ids, desc="Loading dataset") if self.progress else self.run_ids
+
+         for session_name, run_name in iterator:
+             root = self.manager.roots[session_name]
+             batch_data = io.load_training_data(
+                 root,
+                 [run_name],
+                 self.manager.voxel_size,
+                 self.manager.tomo_algorithm,
+                 self.manager.target_name,
+                 self.manager.target_session_id,
+                 self.manager.target_user_id,
+                 progress_update=False
+             )
+
+             self.data.extend(batch_data)
+
+             # Process and cache this batch right away
+             for i, item in enumerate(batch_data):
+                 if len(self._cache) < self.cache_num and self.cache_rate > 0.0:
+                     if np.random.random() < self.cache_rate:
+                         self._cache.append(self._transform(item))
+
+         # Check max label value if needed
+         if hasattr(self.manager, '_check_max_label_value'):
+             self.manager._check_max_label_value(self.data)
+
+         # Update the _data attribute to match the loaded data
+         self._data = self.data
+
+     def __len__(self):
+         """
+         Return the length of the dataset.
+         """
+         if not self.data:
+             self._fill_cache()  # Load data if not loaded yet
+         return len(self.data)
+
+     def __getitem__(self, index):
+         """
+         Return the item at the given index.
+         """
+         if not self.data:
+             self._fill_cache()  # Load data if not loaded yet
+
+         # Use parent's logic for cached items
+         if index < len(self._cache):
+             return self._cache[index]
+
+         # Otherwise transform on-the-fly
+         return self._transform(self.data[index])
+
+ # TODO: Implement Single Config Cache Dataset
+ # class SingleConfigCacheDataset(CacheDataset):
+ #     def __init__(self,
+ #                  root: Any,
+ #                  run_ids: List[str],
+ #                  voxel_size: float,
+ #                  tomo_algorithm: str,
+ #                  target_name: str,
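The caching decision in `_fill_cache` is probabilistic: an item is pre-transformed and cached only while the cache has room and a uniform draw falls below `cache_rate`. A self-contained toy sketch of that pattern in isolation (toy integers and a trivial transform, not the copick-backed loading above; the `fill_cache` helper is hypothetical):

    import sys
    import numpy as np

    def fill_cache(items, transform, cache_rate=0.5, cache_num=sys.maxsize):
        """Toy version of the cache-on-load pattern used by MultiConfigCacheDataset."""
        cache = []
        for item in items:
            # Cache only while there is room, with probability `cache_rate`
            if len(cache) < cache_num and cache_rate > 0.0 and np.random.random() < cache_rate:
                cache.append(transform(item))
        return cache

    items = list(range(10))
    cache = fill_cache(items, transform=lambda x: x * 2, cache_rate=0.5)
    print(f"Cached {len(cache)} of {len(items)} items")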
octopi/datasets/dataset.py ADDED
@@ -0,0 +1,19 @@
+ from torch.utils.data import Dataset
+
+ class DynamicDataset(Dataset):
+     def __init__(self, data, transform=None):
+         self.data = data
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         sample = self.data[idx]
+         if self.transform:
+             sample = self.transform(sample)
+         return sample
+
+     def update_data(self, new_data):
+         """Update the internal dataset with new data."""
+         self.data = new_data
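A minimal sketch of how this wrapper is meant to be used: hold already-preprocessed samples, apply a random transform at iteration time, and swap in a fresh batch mid-training via `update_data` without rebuilding the DataLoader. The toy tensors and the `add_noise` transform below are illustrative stand-ins for cached tomogram patches and the MONAI augmentations used elsewhere in the package:

    import torch
    from torch.utils.data import DataLoader
    from octopi.datasets.dataset import DynamicDataset

    # Toy "cached" samples; in octopi these come from a MONAI CacheDataset
    data = [{"image": torch.rand(1, 32, 32, 32)} for _ in range(4)]

    def add_noise(sample):
        # Random transform applied on-the-fly at each __getitem__ call
        return {"image": sample["image"] + 0.01 * torch.randn_like(sample["image"])}

    ds = DynamicDataset(data, transform=add_noise)
    loader = DataLoader(ds, batch_size=1, shuffle=False)

    # Later (e.g. after a reload epoch) swap in freshly loaded data in place
    ds.update_data([{"image": torch.rand(1, 32, 32, 32)} for _ in range(6)])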
octopi/datasets/generators.py ADDED
@@ -0,0 +1,458 @@
+ from octopi.datasets import dataset, augment, cached_datset
+ from monai.data import DataLoader, SmartCacheDataset, CacheDataset, Dataset
+ from typing import List, Optional
+ from octopi.utils import io as io2
+ from octopi.datasets import io
+ import torch, os, random, gc
+ import multiprocess as mp
+
+ class TrainLoaderManager:
+
+     def __init__(self,
+                  config: str,
+                  target_name: str,
+                  target_session_id: str = None,
+                  target_user_id: str = None,
+                  voxel_size: float = 10,
+                  tomo_algorithm: List[str] = ['wbp'],
+                  tomo_batch_size: int = 15
+                  ):
+
+         # Read Copick Project
+         self.config = config
+         self.root = io.load_copick_config(config)
+
+         # Copick Query for Target
+         self.target_name = target_name
+         self.target_session_id = target_session_id
+         self.target_user_id = target_user_id
+
+         # Copick Query For Input Tomogram
+         self.voxel_size = voxel_size
+         self.tomo_algorithm = tomo_algorithm
+
+         self.reload_training_dataset = True
+         self.reload_validation_dataset = True
+         self.val_loader = None
+         self.train_loader = None
+         self.tomo_batch_size = tomo_batch_size
+
+         # Initialize the input dimensions
+         self.nx = None
+         self.ny = None
+         self.nz = None
+
+     def get_available_runIDs(self):
+         """
+         Identify and return a list of run IDs that have segmentations available for the target.
+
+         - Iterates through all runs in the project to check for segmentations that match
+           the specified target name, session ID, and user ID.
+         - Only includes runs that have at least one matching segmentation.
+
+         Returns:
+             available_runIDs (list): List of run IDs with available segmentations.
+         """
+         available_runIDs = []
+         runIDs = [run.name for run in self.root.runs]
+         for run in runIDs:
+             run = self.root.get_run(run)
+             seg = run.get_segmentations(name=self.target_name,
+                                         session_id=self.target_session_id,
+                                         user_id=self.target_user_id,
+                                         voxel_size=float(self.voxel_size))
+             if len(seg) > 0:
+                 available_runIDs.append(run.name)
+
+         # If No Segmentations are Found, Inform the User
+         if len(available_runIDs) == 0:
+             print(
+                 f"[Error] No segmentations found for the target query:\n"
+                 f"TargetName: {self.target_name}, UserID: {self.target_user_id}, "
+                 f"SessionID: {self.target_session_id}\n"
+                 f"Please check the target name, user ID, and session ID.\n"
+             )
+             exit()
+
+         return available_runIDs
+
+     def get_data_splits(self,
+                         trainRunIDs: str = None,
+                         validateRunIDs: str = None,
+                         train_ratio: float = 0.8,
+                         val_ratio: float = 0.2,
+                         test_ratio: float = 0.0,
+                         create_test_dataset: bool = False):
+         """
+         Split the available data into training, validation, and testing sets based on input parameters.
+
+         Args:
+             trainRunIDs (str): Predefined list of run IDs for training. If provided, it overrides splitting logic.
+             validateRunIDs (str): Predefined list of run IDs for validation. If provided with trainRunIDs, no splitting occurs.
+             train_ratio (float): Proportion of available data to allocate to the training set.
+             val_ratio (float): Proportion of available data to allocate to the validation set.
+             test_ratio (float): Proportion of available data to allocate to the test set.
+             create_test_dataset (bool): Whether to create a test dataset or leave it empty.
+
+         Returns:
+             myRunIDs (dict): Dictionary containing run IDs for training, validation, and testing.
+         """
+
+         # Option 1: Only TrainRunIDs are Provided, Split into Train, Validate and Test (Optional)
+         if trainRunIDs is not None and validateRunIDs is None:
+             trainRunIDs, validateRunIDs, testRunIDs = io.split_multiclass_dataset(
+                 trainRunIDs, train_ratio, val_ratio, test_ratio,
+                 return_test_dataset=create_test_dataset
+             )
+         # Option 2: TrainRunIDs and ValidateRunIDs are Provided, No Need to Split
+         elif trainRunIDs is not None and validateRunIDs is not None:
+             testRunIDs = None
+         # Option 3: Use the Entire Copick Project, Split into Train, Validate and Test
+         else:
+             runIDs = self.get_available_runIDs()
+             trainRunIDs, validateRunIDs, testRunIDs = io.split_multiclass_dataset(
+                 runIDs, train_ratio, val_ratio, test_ratio,
+                 return_test_dataset=create_test_dataset
+             )
+
+         # Get Class Info from the Training Dataset
+         self._get_class_info(trainRunIDs)
+
+         # Swap if Test Runs is Larger than Validation Runs
+         if create_test_dataset and len(testRunIDs) > len(validateRunIDs):
+             testRunIDs, validateRunIDs = validateRunIDs, testRunIDs
+
+         # Determine if datasets fit entirely in memory based on the batch size
+         # If the validation set is smaller than the batch size, avoid reloading
+         if len(validateRunIDs) < self.tomo_batch_size:
+             self.reload_validation_dataset = False
+
+         # If the training set is smaller than the batch size, avoid reloading
+         if len(trainRunIDs) < self.tomo_batch_size:
+             self.reload_training_dataset = False
+
+         # Store the split run IDs into a dictionary for easy access
+         self.myRunIDs = {
+             'train': trainRunIDs,
+             'validate': validateRunIDs,
+             'test': testRunIDs
+         }
+
+         print(f"Number of training samples: {len(trainRunIDs)}")
+         print(f"Number of validation samples: {len(validateRunIDs)}")
+         if testRunIDs is not None:
+             print(f'Number of test samples: {len(testRunIDs)}')
+
+         # Define separate batch sizes
+         self.train_batch_size = min(len(self.myRunIDs['train']), self.tomo_batch_size)
+         self.val_batch_size = min(len(self.myRunIDs['validate']), self.tomo_batch_size)
+
+         # Initialize data iterators for training and validation
+         self._initialize_val_iterators()
+         self._initialize_train_iterators()
+
+         return self.myRunIDs
+
+     def _get_class_info(self, trainRunIDs):
+
+         # Fetch a segmentation to determine class names and number of classes
+         for runID in trainRunIDs:
+             run = self.root.get_run(runID)
+             seg = run.get_segmentations(name=self.target_name,
+                                         session_id=self.target_session_id,
+                                         user_id=self.target_user_id,
+                                         voxel_size=float(self.voxel_size))
+             if len(seg) == 0:
+                 continue
+
+             # If Session ID or User ID are None, Set Them Based on the First Found Segmentation
+             if self.target_session_id is None:
+                 self.target_session_id = seg[0].session_id
+             if self.target_user_id is None:
+                 self.target_user_id = seg[0].user_id
+
+             # Read Yaml Config to Get Number of Classes and Class Names
+             target_config = io2.check_target_config_path(self)
+             class_names = target_config['input']['labels']
+             self.Nclasses = len(class_names) + 1
+             self.class_names = [name for name, idx in sorted(class_names.items(), key=lambda x: x[1])]
+
+             # We Only need to read One Segmentation to Get Class Info
+             break
+
+     def _get_padded_list(self, data_list, batch_size):
+         # Calculate padding needed to make `data_list` a multiple of `batch_size`
+         remainder = len(data_list) % batch_size
+         if remainder > 0:
+             # Number of additional items needed to make the length a multiple of batch size
+             padding_needed = batch_size - remainder
+             # Extend `data_list` with a random subset to achieve the padding
+             data_list = data_list + random.sample(data_list, padding_needed)
+         # Shuffle the full list
+         random.shuffle(data_list)
+         return data_list
+
+     def _initialize_train_iterators(self):
+         # Initialize padded train and validation data lists
+         self.padded_train_list = self._get_padded_list(self.myRunIDs['train'], self.train_batch_size)
+
+         # Create iterators
+         self.train_data_iter = iter(self._get_data_batches(self.padded_train_list, self.train_batch_size))
+
+     def _initialize_val_iterators(self):
+         # Initialize padded train and validation data lists
+         self.padded_val_list = self._get_padded_list(self.myRunIDs['validate'], self.val_batch_size)
+
+         # Create iterators
+         self.val_data_iter = iter(self._get_data_batches(self.padded_val_list, self.val_batch_size))
+
+     def _get_data_batches(self, data_list, batch_size):
+         # Generator that yields batches of specified size
+         for i in range(0, len(data_list), batch_size):
+             yield data_list[i:i + batch_size]
+
+     def _extract_run_ids(self, data_iter_name, initialize_method):
+         # Access the instance's data iterator by name
+         data_iter = getattr(self, data_iter_name)
+         try:
+             # Attempt to get the next batch from the iterator
+             runIDs = next(data_iter)
+         except StopIteration:
+             # Reinitialize the iterator if exhausted
+             initialize_method()
+             # Update the iterator reference after reinitialization
+             data_iter = getattr(self, data_iter_name)
+             runIDs = next(data_iter)
+         # Update the instance attribute with the new iterator state
+         setattr(self, data_iter_name, data_iter)
+         return runIDs
+
+     def create_train_dataloaders(
+         self,
+         crop_size: int = 96,
+         num_samples: int = 64):
+
+         train_batch_size = 1
+         val_batch_size = 1
+
+         # If reloads are disabled and loaders already exist, reuse them
+         if self.reload_frequency < 0 and (self.train_loader is not None) and (self.val_loader is not None):
+             return self.train_loader, self.val_loader
+
+         # We Only Need to Reload the Training Dataset if the Total Number of Runs is larger than
+         # the tomo batch size
+         if self.train_loader is None:
+
+             # Fetch the next batch of run IDs
+             trainRunIDs = self._extract_run_ids('train_data_iter', self._initialize_train_iterators)
+             train_files = io.load_training_data(self.root, trainRunIDs, self.voxel_size, self.tomo_algorithm,
+                                                 self.target_name, self.target_session_id, self.target_user_id,
+                                                 progress_update=False)
+             self._check_max_label_value(train_files)
+
+             # Create the cached dataset with non-random transforms
+             train_ds = CacheDataset(data=train_files, transform=augment.get_transforms(), cache_rate=1.0)
+
+             # Delete the training files to free memory
+             train_files = None
+             gc.collect()
+
+             # Read (nx, ny, nz) and scale the crop size to make sure it isn't larger than nx.
+             if self.nx is None: (self.nx, self.ny, self.nz) = train_ds[0]['image'].shape[1:]
+             self.input_dim = io.get_input_dimensions(train_ds, crop_size)
+
+             # Wrap the cached dataset to apply random transforms during iteration
+             self.dynamic_train_dataset = dataset.DynamicDataset(
+                 data=train_ds,
+                 transform=augment.get_random_transforms(self.input_dim, num_samples, self.Nclasses)
+             )
+
+             # Define the number of processes for the DataLoader
+             n_procs = min(mp.cpu_count(), 4)
+
+             # DataLoader remains the same
+             self.train_loader = DataLoader(
+                 self.dynamic_train_dataset,
+                 batch_size=train_batch_size,
+                 shuffle=False,
+                 num_workers=n_procs,
+                 pin_memory=torch.cuda.is_available(),
+             )
+
+         else:
+             # Fetch the next batch of run IDs
+             trainRunIDs = self._extract_run_ids('train_data_iter', self._initialize_train_iterators)
+             train_files = io.load_training_data(self.root, trainRunIDs, self.voxel_size, self.tomo_algorithm,
+                                                 self.target_name, self.target_session_id, self.target_user_id,
+                                                 progress_update=False)
+             self._check_max_label_value(train_files)
+
+             train_ds = CacheDataset(data=train_files, transform=augment.get_transforms(), cache_rate=1.0)
+             self.dynamic_train_dataset.update_data(train_ds)
+
+         # We Only Need to Reload the Validation Dataset if the Total Number of Runs is larger than
+         # the tomo batch size
+         if self.val_loader is None:
+
+             validateRunIDs = self._extract_run_ids('val_data_iter', self._initialize_val_iterators)
+             val_files = io.load_training_data(self.root, validateRunIDs, self.voxel_size, self.tomo_algorithm,
+                                               self.target_name, self.target_session_id, self.target_user_id,
+                                               progress_update=False)
+             self._check_max_label_value(val_files)
+
+             # Create validation dataset
+             val_ds = CacheDataset(data=val_files, transform=augment.get_transforms(), cache_rate=1.0)
+
+             # Delete the validation files to free memory
+             val_files = None
+             gc.collect()
+
+             # # Read (nx, ny, nz) and scale the crop size to make sure it isn't larger than nx.
+             # if self.nx is None:
+             #     (self.nx, self.ny, self.nz) = val_ds[0]['image'].shape[1:]
+
+             # if crop_size > self.nx: self.input_dim = (self.nx, crop_size, crop_size)
+             # else: self.input_dim = (crop_size, crop_size, crop_size)
+
+             # Wrap the cached dataset to apply random transforms during iteration
+             self.dynamic_validation_dataset = dataset.DynamicDataset(data=val_ds)
+
+             dataset_size = len(self.dynamic_validation_dataset)
+             n_procs = min(mp.cpu_count(), 8)
+
+             # Create validation DataLoader
+             self.val_loader = DataLoader(
+                 self.dynamic_validation_dataset,
+                 batch_size=val_batch_size,
+                 num_workers=n_procs,
+                 pin_memory=torch.cuda.is_available(),
+                 shuffle=False,  # Ensure the data order remains consistent
+             )
+         else:
+             validateRunIDs = self._extract_run_ids('val_data_iter', self._initialize_val_iterators)
+             val_files = io.load_training_data(self.root, validateRunIDs, self.voxel_size, self.tomo_algorithm,
+                                               self.target_name, self.target_session_id, self.target_user_id,
+                                               progress_update=False)
+             self._check_max_label_value(val_files)
+
+         return self.train_loader, self.val_loader
+
+     def get_reload_frequency(self, num_epochs: int):
+         """
+         Automatically calculate the reload frequency for the dataset during training.
+
+         Returns:
+             int: Reload frequency (number of epochs between dataset reloads).
+         """
+         if not self.reload_training_dataset:
+             # No need to reload if all tomograms fit in memory
+             print("All training samples fit in memory. No reloading required.")
+             self.reload_frequency = -1
+
+         else:
+             # Calculate the number of segments based on total training runs and batch size
+             num_segments = (len(self.myRunIDs['train']) + self.tomo_batch_size - 1) // self.tomo_batch_size
+
+             # Calculate reload frequency to distribute reloading evenly over epochs
+             self.reload_frequency = max(num_epochs // num_segments, 1)
+
+             print(f"\nReloading {self.tomo_batch_size} tomograms every {self.reload_frequency} epochs\n")
+
+             # Warn if the number of epochs is insufficient for full dataset coverage
+             if num_epochs < num_segments:
+                 print(
+                     f"Warning: Chosen number of epochs ({num_epochs}) may not be sufficient "
+                     f"to train over all training samples. Consider increasing the number of epochs "
+                     f"to at least {num_segments}.\n"
+                 )
+
+     def _check_max_label_value(self, train_files):
+         max_label_value = max(file['label'].max() for file in train_files)
+         if max_label_value > self.Nclasses:
+             print(f"Warning: Maximum class label value {max_label_value} exceeds the number of classes {self.Nclasses}.")
+             print("This may cause issues with the model's output layer.")
+             print("Consider adjusting the number of classes or the label values in your data.\n")
+
+     def get_dataloader_parameters(self):
+
+         parameters = {
+             'config': self.config,
+             'target_name': self.target_name,
+             'target_session_id': self.target_session_id,
+             'target_user_id': self.target_user_id,
+             'voxel_size': self.voxel_size,
+             'tomo_algorithm': self.tomo_algorithm,
+             'tomo_batch_size': self.tomo_batch_size,
+             'reload_frequency': self.reload_frequency,
+             'testRunIDs': self.myRunIDs['test'],
+             'valRunIDs': self.myRunIDs['validate'],
+             'trainRunIDs': self.myRunIDs['train'],
+         }
+
+         return parameters
+
+ class PredictLoaderManager:
+
+     def __init__(self,
+                  config: str,
+                  voxel_size: float = 10,
+                  tomo_algorithm: str = 'wbp',
+                  tomo_batch_size: int = 15,  # Number of Tomograms to Load Per Sub-Epoch
+                  Nclasses: int = 3):
+
+         # Read Copick Project
+         self.copick_config = config
+         self.root = io.load_copick_config(config)
+
+         # Copick Query For Input Tomogram
+         self.voxel_size = voxel_size
+         self.tomo_algorithm = tomo_algorithm
+
+         self.Nclasses = Nclasses
+         self.tomo_batch_size = tomo_batch_size
+
+         # Initialize the input dimensions
+         self.nx = None
+         self.ny = None
+         self.nz = None
+
+     def create_predict_dataloader(
+         self,
+         voxel_spacing: float,
+         tomo_algorithm: str,
+         runIDs: str = None):
+
+         # Default to all runs in the Copick project if no run IDs are provided
+         if runIDs is None:
+             runIDs = [run.name for run in self.root.runs]
+
+         # Load the test data
+         test_files = io.load_predict_data(self.root, runIDs, voxel_spacing, tomo_algorithm)
+
+         # Create the cached dataset with non-random transforms
+         test_ds = CacheDataset(data=test_files, transform=augment.get_predict_transforms())
+
+         # Read (nx, ny, nz) for input tomograms.
+         if self.nx is None:
+             (self.nx, self.ny, self.nz) = test_ds[0]['image'].shape[1:]
+
+         # Create the DataLoader
+         test_loader = DataLoader(test_ds,
+                                  batch_size=4,
+                                  shuffle=False,
+                                  num_workers=4,
+                                  pin_memory=torch.cuda.is_available())
+         return test_loader
+
+     def get_dataloader_parameters(self):
+
+         parameters = {
+             'config': self.copick_config,
+             'voxel_size': self.voxel_size,
+             'tomo_algorithm': self.tomo_algorithm
+         }
+
+         return parameters
+
+
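Taken together, TrainLoaderManager splits the runs, decides how often a fresh batch of tomograms must be pulled from disk, and rebuilds or refills the loaders on that schedule. A minimal sketch of the intended epoch loop; the config path, target name, and hyperparameters are placeholders, and the actual training/validation step is elided:

    from octopi.datasets.generators import TrainLoaderManager

    manager = TrainLoaderManager(
        config="copick_config.json",   # placeholder path to a copick project config
        target_name="targets",         # placeholder segmentation target name
        voxel_size=10,
        tomo_algorithm=["wbp"],
        tomo_batch_size=15,
    )

    manager.get_data_splits(train_ratio=0.8, val_ratio=0.2)
    num_epochs = 100
    manager.get_reload_frequency(num_epochs)

    train_loader, val_loader = manager.create_train_dataloaders(crop_size=96, num_samples=64)
    for epoch in range(num_epochs):
        # Periodically swap in the next batch of tomograms when the dataset does not
        # fit in memory (reload_frequency < 0 means the loaders are reused as-is)
        if manager.reload_frequency > 0 and epoch > 0 and epoch % manager.reload_frequency == 0:
            train_loader, val_loader = manager.create_train_dataloaders(crop_size=96, num_samples=64)
        # ... run training and validation steps over train_loader and val_loader ...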