octopi-1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of octopi has been flagged as potentially problematic.

Files changed (59)
  1. octopi/__init__.py +0 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +84 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +429 -0
  7. octopi/datasets/mixup.py +49 -0
  8. octopi/datasets/multi_config_generator.py +253 -0
  9. octopi/entry_points/__init__.py +0 -0
  10. octopi/entry_points/common.py +80 -0
  11. octopi/entry_points/create_slurm_submission.py +243 -0
  12. octopi/entry_points/run_create_targets.py +281 -0
  13. octopi/entry_points/run_evaluate.py +65 -0
  14. octopi/entry_points/run_extract_mb_picks.py +141 -0
  15. octopi/entry_points/run_extract_midpoint.py +143 -0
  16. octopi/entry_points/run_localize.py +222 -0
  17. octopi/entry_points/run_optuna.py +139 -0
  18. octopi/entry_points/run_segment_predict.py +166 -0
  19. octopi/entry_points/run_train.py +201 -0
  20. octopi/extract/__init__.py +0 -0
  21. octopi/extract/localize.py +254 -0
  22. octopi/extract/membranebound_extract.py +262 -0
  23. octopi/extract/midpoint_extract.py +193 -0
  24. octopi/io.py +457 -0
  25. octopi/losses.py +86 -0
  26. octopi/main.py +101 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +62 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +106 -0
  37. octopi/processing/downsample.py +129 -0
  38. octopi/processing/evaluate.py +289 -0
  39. octopi/processing/importers.py +213 -0
  40. octopi/processing/my_metrics.py +26 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/processing/writers.py +102 -0
  43. octopi/pytorch/__init__.py +0 -0
  44. octopi/pytorch/hyper_search.py +243 -0
  45. octopi/pytorch/model_search_submitter.py +290 -0
  46. octopi/pytorch/segmentation.py +317 -0
  47. octopi/pytorch/trainer.py +438 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/stopping_criteria.py +143 -0
  52. octopi/submit_slurm.py +95 -0
  53. octopi/utils.py +238 -0
  54. octopi/visualization_tools.py +201 -0
  55. octopi-1.0.dist-info/LICENSE +41 -0
  56. octopi-1.0.dist-info/METADATA +209 -0
  57. octopi-1.0.dist-info/RECORD +59 -0
  58. octopi-1.0.dist-info/WHEEL +4 -0
  59. octopi-1.0.dist-info/entry_points.txt +4 -0
octopi/__init__.py ADDED
File without changes
octopi/datasets/__init__.py ADDED
File without changes
octopi/datasets/augment.py ADDED
@@ -0,0 +1,84 @@
+ from monai.transforms import (
+     Compose,
+     RandFlipd,
+     Orientationd,
+     RandRotate90d,
+     NormalizeIntensityd,
+     EnsureChannelFirstd,
+     RandCropByLabelClassesd,
+     RandScaleIntensityd,
+     RandShiftIntensityd,
+     RandAdjustContrastd,
+     RandGaussianNoised,
+     ScaleIntensityRanged,
+     RandomOrder,
+ )
+
+ def get_transforms():
+     """
+     Returns non-random transforms.
+     """
+     return Compose([
+         EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
+         NormalizeIntensityd(keys="image"),
+         Orientationd(keys=["image", "label"], axcodes="RAS")
+     ])
+
+ def get_random_transforms(input_dim, num_samples, Nclasses):
+     """
+     Input:
+         input_dim: tuple of (nx, ny, nz)
+         num_samples: int
+         Nclasses: int
+
+     Returns random transforms.
+
+     For data with a missing wedge along the first axis (causing smearing in that direction),
+     we avoid rotations that would move this artifact to other axes. We only rotate around
+     the first axis (spatial_axes=[1, 2]) and avoid flipping along the first axis.
+     """
+     return Compose([
+         RandCropByLabelClassesd(
+             keys=["image", "label"],
+             label_key="label",
+             spatial_size=[input_dim[0], input_dim[1], input_dim[2]],
+             num_classes=Nclasses,
+             num_samples=num_samples
+         ),
+         # Only rotate around the first axis (keeping the missing wedge orientation consistent)
+         RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2], max_k=3),
+         # Avoid flipping along the first axis (where the missing wedge is)
+         RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),  # Removed
+         RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
+         RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
+         RandomOrder([
+             # Intensity augmentations are still appropriate
+             RandScaleIntensityd(keys="image", prob=0.5, factors=(0.85, 1.15)),
+             RandShiftIntensityd(keys="image", prob=0.5, offsets=(-0.15, 0.15)),
+             RandAdjustContrastd(keys="image", prob=0.5, gamma=(0.85, 1.15)),
+             RandGaussianNoised(keys="image", prob=0.5, mean=0.0, std=0.5),  # Reduced noise std
+         ]),
+     ])
+
+ # Augmentations to Explore in the Future:
+ # Intensity-based augmentations
+ # RandHistogramShiftd(keys="image", prob=0.5, num_control_points=(3, 5))
+ # RandGaussianSmoothd(keys="image", prob=0.5, sigma_x=(0.5, 1.5), sigma_y=(0.5, 1.5), sigma_z=(0.5, 1.5)),
+
+ # Geometric Transforms
+ # RandAffined(
+ #     keys=["image", "label"],
+ #     rotate_range=(0.1, 0.1, 0.1),  # Rotation angles (radians) for x, y, z axes
+ #     scale_range=(0.1, 0.1, 0.1),   # Scale range for isotropic/anisotropic scaling
+ #     prob=0.5,                      # Probability of applying the transform
+ #     padding_mode="border"          # Handle out-of-bounds values
+ # )
+
+ def get_predict_transforms():
+     """
+     Returns predict transforms.
+     """
+     return Compose([
+         EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
+         NormalizeIntensityd(keys="image")
+     ])
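
As a quick illustration of how these two builders fit together: the deterministic transforms prepare each tomogram/label pair once, and the random transforms are applied per iteration to produce cropped, augmented patches. The sketch below runs octopi's own transform pipeline on a synthetic volume; the array shapes, class count, and patch size are illustrative assumptions, not values taken from the package.

import numpy as np
from octopi.datasets import augment

# Synthetic (image, label) pair standing in for a loaded tomogram (assumed shapes).
sample = {
    "image": np.random.rand(64, 64, 64).astype(np.float32),
    "label": np.random.randint(0, 3, size=(64, 64, 64)).astype(np.int64),
}

# Deterministic preprocessing: add a channel dimension, normalize intensity, reorient.
prepared = augment.get_transforms()(sample)

# Per-iteration augmentation: crop patches by class, then random flips / intensity jitter.
rand_tf = augment.get_random_transforms(input_dim=(48, 48, 48), num_samples=4, Nclasses=3)
patches = rand_tf(prepared)            # list of 4 cropped, augmented patch dicts
print(len(patches), patches[0]["image"].shape)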
octopi/datasets/cached_datset.py ADDED
@@ -0,0 +1,113 @@
+ from typing import List, Tuple, Callable, Optional, Dict, Any
+ from monai.transforms import Compose
+ from monai.data import CacheDataset
+ from octopi import io
+ from tqdm import tqdm
+ import numpy as np  # needed for np.random.random() in _fill_cache below
+ import os, sys
+
+ class MultiConfigCacheDataset(CacheDataset):
+     """
+     A custom CacheDataset that loads data lazily from multiple sources
+     with a consolidated loading and caching process.
+     """
+
+     def __init__(
+         self,
+         manager,
+         run_ids: List[Tuple[str, str]],
+         transform: Optional[Callable] = None,
+         cache_rate: float = 1.0,
+         num_workers: int = 0,
+         progress: bool = True,
+         copy_cache: bool = True,
+         cache_num: int = sys.maxsize
+     ):
+         # Save reference to manager and run_ids
+         self.manager = manager
+         self.run_ids = run_ids
+         self.progress = progress
+
+         # Prepare an empty data list first - don't load immediately
+         self.data = []
+
+         # Initialize the parent CacheDataset with an empty list.
+         # We override the _fill_cache method to handle loading and caching in one step.
+         super().__init__(
+             data=[],  # Empty list - we'll load data in _fill_cache
+             transform=transform,
+             cache_rate=cache_rate,
+             num_workers=num_workers,
+             progress=False,  # We'll handle our own progress
+             copy_cache=copy_cache,
+             cache_num=cache_num
+         )
+
+     def _fill_cache(self):
+         """
+         Override the parent's _fill_cache method to combine loading and caching.
+         """
+         if self.progress:
+             print("Loading and caching dataset...")
+
+         # Load and process data in a single operation
+         self.data = []
+         iterator = tqdm(self.run_ids, desc="Loading dataset") if self.progress else self.run_ids
+
+         for session_name, run_name in iterator:
+             root = self.manager.roots[session_name]
+             batch_data = io.load_training_data(
+                 root,
+                 [run_name],
+                 self.manager.voxel_size,
+                 self.manager.tomo_algorithm,
+                 self.manager.target_name,
+                 self.manager.target_session_id,
+                 self.manager.target_user_id,
+                 progress_update=False
+             )
+
+             self.data.extend(batch_data)
+
+             # Process and cache this batch right away
+             for i, item in enumerate(batch_data):
+                 if len(self._cache) < self.cache_num and self.cache_rate > 0.0:
+                     if np.random.random() < self.cache_rate:
+                         self._cache.append(self._transform(item))
+
+         # Check max label value if needed
+         if hasattr(self.manager, '_check_max_label_value'):
+             self.manager._check_max_label_value(self.data)
+
+         # Update the _data attribute to match the loaded data
+         self._data = self.data
+
+     def __len__(self):
+         """
+         Return the length of the dataset.
+         """
+         if not self.data:
+             self._fill_cache()  # Load data if not loaded yet
+         return len(self.data)
+
+     def __getitem__(self, index):
+         """
+         Return the item at the given index.
+         """
+         if not self.data:
+             self._fill_cache()  # Load data if not loaded yet
+
+         # Use parent's logic for cached items
+         if index < len(self._cache):
+             return self._cache[index]
+
+         # Otherwise transform on-the-fly
+         return self._transform(self.data[index])
+
+ # TODO: Implement Single Config Cache Dataset
+ # class SingleConfigCacheDataset(CacheDataset):
+ #     def __init__(self,
+ #                  root: Any,
+ #                  run_ids: List[str],
+ #                  voxel_size: float,
+ #                  tomo_algorithm: str,
+ #                  target_name: str,
+ # target_name: str,
octopi/datasets/dataset.py ADDED
@@ -0,0 +1,19 @@
+ from torch.utils.data import Dataset
+
+ class DynamicDataset(Dataset):
+     def __init__(self, data, transform=None):
+         self.data = data
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         sample = self.data[idx]
+         if self.transform:
+             sample = self.transform(sample)
+         return sample
+
+     def update_data(self, new_data):
+         """Update the internal dataset with new data."""
+         self.data = new_data
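
DynamicDataset's only addition over a plain torch Dataset is update_data, which lets the training loop swap in a freshly loaded batch of tomograms without rebuilding the surrounding DataLoader. A small illustrative example with dummy data (not taken from the package):

from octopi.datasets.dataset import DynamicDataset

ds = DynamicDataset(data=[{"image": 0}, {"image": 1}])
print(len(ds), ds[1])        # 2 {'image': 1}

# Later, replace the backing data in place; a DataLoader wrapping `ds` picks up the
# new samples when its workers are recreated (the default, non-persistent behavior).
ds.update_data([{"image": 10}, {"image": 11}, {"image": 12}])
print(len(ds), ds[0])        # 3 {'image': 10}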
octopi/datasets/generators.py ADDED
@@ -0,0 +1,429 @@
+ from octopi.datasets import dataset, augment, cached_datset
+ from monai.data import DataLoader, SmartCacheDataset, CacheDataset, Dataset
+ from typing import List, Optional
+ from octopi import io
+ import torch, os, random, gc
+ import multiprocess as mp
+
+ class TrainLoaderManager:
+
+     def __init__(self,
+                  config: str,
+                  target_name: str,
+                  target_session_id: str = None,
+                  target_user_id: str = None,
+                  voxel_size: float = 10,
+                  tomo_algorithm: List[str] = ['wbp'],
+                  tomo_batch_size: int = 15,  # Number of Tomograms to Load Per Sub-Epoch
+                  Nclasses: int = 3):         # Number of Objects + Background
+
+         # Read Copick Project
+         self.config = config
+         self.root = io.load_copick_config(config)
+
+         # Copick Query for Target
+         self.target_name = target_name
+         self.target_session_id = target_session_id
+         self.target_user_id = target_user_id
+
+         # Copick Query For Input Tomogram
+         self.voxel_size = voxel_size
+         self.tomo_algorithm = tomo_algorithm
+
+         self.Nclasses = Nclasses
+         self.tomo_batch_size = tomo_batch_size
+
+         self.reload_training_dataset = True
+         self.reload_validation_dataset = True
+         self.val_loader = None
+         self.train_loader = None
+
+         # Initialize the input dimensions
+         self.nx = None
+         self.ny = None
+         self.nz = None
+
+     def get_available_runIDs(self):
+         """
+         Identify and return a list of run IDs that have segmentations available for the target.
+
+         - Iterates through all runs in the project to check for segmentations that match
+           the specified target name, session ID, and user ID.
+         - Only includes runs that have at least one matching segmentation.
+
+         Returns:
+             available_runIDs (list): List of run IDs with available segmentations.
+         """
+         available_runIDs = []
+         runIDs = [run.name for run in self.root.runs]
+         for run in runIDs:
+             run = self.root.get_run(run)
+             seg = run.get_segmentations(name=self.target_name,
+                                         session_id=self.target_session_id,
+                                         user_id=self.target_user_id,
+                                         voxel_size=float(self.voxel_size))
+             if len(seg) > 0:
+                 available_runIDs.append(run.name)
+
+         # If No Segmentations are Found, Inform the User
+         if len(available_runIDs) == 0:
+             print(
+                 f"[Error] No segmentations found for the target query:\n"
+                 f"TargetName: {self.target_name}, UserID: {self.target_user_id}, "
+                 f"SessionID: {self.target_session_id}\n"
+                 f"Please check the target name, user ID, and session ID.\n"
+             )
+             exit()
+
+         return available_runIDs
+
+     def get_data_splits(self,
+                         trainRunIDs: str = None,
+                         validateRunIDs: str = None,
+                         train_ratio: float = 0.8,
+                         val_ratio: float = 0.1,
+                         test_ratio: float = 0.1,
+                         create_test_dataset: bool = True):
+         """
+         Split the available data into training, validation, and testing sets based on input parameters.
+
+         Args:
+             trainRunIDs (str): Predefined list of run IDs for training. If provided, it overrides the splitting logic.
+             validateRunIDs (str): Predefined list of run IDs for validation. If provided with trainRunIDs, no splitting occurs.
+             train_ratio (float): Proportion of available data to allocate to the training set.
+             val_ratio (float): Proportion of available data to allocate to the validation set.
+             test_ratio (float): Proportion of available data to allocate to the test set.
+             create_test_dataset (bool): Whether to create a test dataset or leave it empty.
+
+         Returns:
+             myRunIDs (dict): Dictionary containing run IDs for training, validation, and testing.
+         """
+
+         # Option 1: Only TrainRunIDs are Provided, Split into Train, Validate and Test (Optional)
+         if trainRunIDs is not None and validateRunIDs is None:
+             trainRunIDs, validateRunIDs, testRunIDs = io.split_multiclass_dataset(
+                 trainRunIDs, train_ratio, val_ratio, test_ratio,
+                 return_test_dataset=create_test_dataset
+             )
+         # Option 2: TrainRunIDs and ValidateRunIDs are Provided, No Need to Split
+         elif trainRunIDs is not None and validateRunIDs is not None:
+             testRunIDs = None
+         # Option 3: Use the Entire Copick Project, Split into Train, Validate and Test
+         else:
+             runIDs = self.get_available_runIDs()
+             trainRunIDs, validateRunIDs, testRunIDs = io.split_multiclass_dataset(
+                 runIDs, train_ratio, val_ratio, test_ratio,
+                 return_test_dataset=create_test_dataset
+             )
+
+         # Swap if Test Runs is Larger than Validation Runs
+         if create_test_dataset and len(testRunIDs) > len(validateRunIDs):
+             testRunIDs, validateRunIDs = validateRunIDs, testRunIDs
+
+         # Determine if datasets fit entirely in memory based on the batch size
+         # If the validation set is smaller than the batch size, avoid reloading
+         if len(validateRunIDs) < self.tomo_batch_size:
+             self.reload_validation_dataset = False
+
+         # If the training set is smaller than the batch size, avoid reloading
+         if len(trainRunIDs) < self.tomo_batch_size:
+             self.reload_training_dataset = False
+
+         # Store the split run IDs into a dictionary for easy access
+         self.myRunIDs = {
+             'train': trainRunIDs,
+             'validate': validateRunIDs,
+             'test': testRunIDs
+         }
+
+         print(f"Number of training samples: {len(trainRunIDs)}")
+         print(f"Number of validation samples: {len(validateRunIDs)}")
+         if testRunIDs is not None:
+             print(f'Number of test samples: {len(testRunIDs)}')
+
+         # Define separate batch sizes
+         self.train_batch_size = min(len(self.myRunIDs['train']), self.tomo_batch_size)
+         self.val_batch_size = min(len(self.myRunIDs['validate']), self.tomo_batch_size)
+
+         # Initialize data iterators for training and validation
+         self._initialize_val_iterators()
+         self._initialize_train_iterators()
+
+         return self.myRunIDs
+
+     def _get_padded_list(self, data_list, batch_size):
+         # Calculate padding needed to make `data_list` a multiple of `batch_size`
+         remainder = len(data_list) % batch_size
+         if remainder > 0:
+             # Number of additional items needed to make the length a multiple of batch size
+             padding_needed = batch_size - remainder
+             # Extend `data_list` with a random subset to achieve the padding
+             data_list = data_list + random.sample(data_list, padding_needed)
+         # Shuffle the full list
+         random.shuffle(data_list)
+         return data_list
+
+     def _initialize_train_iterators(self):
+         # Initialize the padded training data list
+         self.padded_train_list = self._get_padded_list(self.myRunIDs['train'], self.train_batch_size)
+
+         # Create iterators
+         self.train_data_iter = iter(self._get_data_batches(self.padded_train_list, self.train_batch_size))
+
+     def _initialize_val_iterators(self):
+         # Initialize the padded validation data list
+         self.padded_val_list = self._get_padded_list(self.myRunIDs['validate'], self.val_batch_size)
+
+         # Create iterators
+         self.val_data_iter = iter(self._get_data_batches(self.padded_val_list, self.val_batch_size))
+
+     def _get_data_batches(self, data_list, batch_size):
+         # Generator that yields batches of the specified size
+         for i in range(0, len(data_list), batch_size):
+             yield data_list[i:i + batch_size]
+
+     def _extract_run_ids(self, data_iter_name, initialize_method):
+         # Access the instance's data iterator by name
+         data_iter = getattr(self, data_iter_name)
+         try:
+             # Attempt to get the next batch from the iterator
+             runIDs = next(data_iter)
+         except StopIteration:
+             # Reinitialize the iterator if exhausted
+             initialize_method()
+             # Update the iterator reference after reinitialization
+             data_iter = getattr(self, data_iter_name)
+             runIDs = next(data_iter)
+         # Update the instance attribute with the new iterator state
+         setattr(self, data_iter_name, data_iter)
+         return runIDs
+
+     def create_train_dataloaders(
+         self,
+         crop_size: int = 96,
+         num_samples: int = 64):
+
+         train_batch_size = 1
+         val_batch_size = 1
+
+         # If reloads are disabled and loaders already exist, reuse them
+         if self.reload_frequency < 0 and (self.train_loader is not None) and (self.val_loader is not None):
+             return self.train_loader, self.val_loader
+
+         # We only need to reload the training dataset if the total number of runs is larger than
+         # the tomo batch size
+         if self.train_loader is None:
+
+             # Fetch the next batch of run IDs
+             trainRunIDs = self._extract_run_ids('train_data_iter', self._initialize_train_iterators)
+             train_files = io.load_training_data(self.root, trainRunIDs, self.voxel_size, self.tomo_algorithm,
+                                                 self.target_name, self.target_session_id, self.target_user_id,
+                                                 progress_update=False)
+             self._check_max_label_value(train_files)
+
+             # Create the cached dataset with non-random transforms
+             train_ds = CacheDataset(data=train_files, transform=augment.get_transforms(), cache_rate=1.0)
+
+             # Delete the training files to free memory
+             train_files = None
+             gc.collect()
+
+             # Read (nx, ny, nz) and scale the crop size to make sure it isn't larger than nx
+             if self.nx is None: (self.nx, self.ny, self.nz) = train_ds[0]['image'].shape[1:]
+             self.input_dim = io.get_input_dimensions(train_ds, crop_size)
+
+             # Wrap the cached dataset to apply random transforms during iteration
+             self.dynamic_train_dataset = dataset.DynamicDataset(
+                 data=train_ds,
+                 transform=augment.get_random_transforms(self.input_dim, num_samples, self.Nclasses)
+             )
+
+             # Define the number of processes for the DataLoader
+             n_procs = min(mp.cpu_count(), 4)
+
+             # DataLoader remains the same
+             self.train_loader = DataLoader(
+                 self.dynamic_train_dataset,
+                 batch_size=train_batch_size,
+                 shuffle=False,
+                 num_workers=n_procs,
+                 pin_memory=torch.cuda.is_available(),
+             )
+
+         else:
+             # Fetch the next batch of run IDs
+             trainRunIDs = self._extract_run_ids('train_data_iter', self._initialize_train_iterators)
+             train_files = io.load_training_data(self.root, trainRunIDs, self.voxel_size, self.tomo_algorithm,
+                                                 self.target_name, self.target_session_id, self.target_user_id,
+                                                 progress_update=False)
+             self._check_max_label_value(train_files)
+
+             train_ds = CacheDataset(data=train_files, transform=augment.get_transforms(), cache_rate=1.0)
+             self.dynamic_train_dataset.update_data(train_ds)
+
+         # We only need to reload the validation dataset if the total number of runs is larger than
+         # the tomo batch size
+         if self.val_loader is None:
+
+             validateRunIDs = self._extract_run_ids('val_data_iter', self._initialize_val_iterators)
+             val_files = io.load_training_data(self.root, validateRunIDs, self.voxel_size, self.tomo_algorithm,
+                                               self.target_name, self.target_session_id, self.target_user_id,
+                                               progress_update=False)
+             self._check_max_label_value(val_files)
+
+             # Create validation dataset
+             val_ds = CacheDataset(data=val_files, transform=augment.get_transforms(), cache_rate=1.0)
+
+             # Delete the validation files to free memory
+             val_files = None
+             gc.collect()
+
+             # # Read (nx, ny, nz) and scale the crop size to make sure it isn't larger than nx
+             # if self.nx is None:
+             #     (self.nx, self.ny, self.nz) = val_ds[0]['image'].shape[1:]
+
+             # if crop_size > self.nx: self.input_dim = (self.nx, crop_size, crop_size)
+             # else: self.input_dim = (crop_size, crop_size, crop_size)
+
+             # Wrap the cached dataset to apply random transforms during iteration
+             self.dynamic_validation_dataset = dataset.DynamicDataset(data=val_ds)
+
+             dataset_size = len(self.dynamic_validation_dataset)
+             n_procs = min(mp.cpu_count(), 8)
+
+             # Create validation DataLoader
+             self.val_loader = DataLoader(
+                 self.dynamic_validation_dataset,
+                 batch_size=val_batch_size,
+                 num_workers=n_procs,
+                 pin_memory=torch.cuda.is_available(),
+                 shuffle=False,  # Ensure the data order remains consistent
+             )
+         else:
+             validateRunIDs = self._extract_run_ids('val_data_iter', self._initialize_val_iterators)
+             val_files = io.load_training_data(self.root, validateRunIDs, self.voxel_size, self.tomo_algorithm,
+                                               self.target_name, self.target_session_id, self.target_user_id,
+                                               progress_update=False)
+             self._check_max_label_value(val_files)
+
+         return self.train_loader, self.val_loader
+
+     def get_reload_frequency(self, num_epochs: int):
+         """
+         Automatically calculate the reload frequency for the dataset during training.
+
+         Returns:
+             int: Reload frequency (number of epochs between dataset reloads).
+         """
+         if not self.reload_training_dataset:
+             # No need to reload if all tomograms fit in memory
+             print("All training samples fit in memory. No reloading required.")
+             self.reload_frequency = -1
+
+         else:
+             # Calculate the number of segments based on total training runs and batch size
+             num_segments = (len(self.myRunIDs['train']) + self.tomo_batch_size - 1) // self.tomo_batch_size
+
+             # Calculate reload frequency to distribute reloading evenly over epochs
+             self.reload_frequency = max(num_epochs // num_segments, 1)
+
+             print(f"\nReloading {self.tomo_batch_size} tomograms every {self.reload_frequency} epochs\n")
+
+             # Warn if the number of epochs is insufficient for full dataset coverage
+             if num_epochs < num_segments:
+                 print(
+                     f"Warning: Chosen number of epochs ({num_epochs}) may not be sufficient "
+                     f"to train over all training samples. Consider increasing the number of epochs "
+                     f"to at least {num_segments}.\n"
+                 )
+
+     def _check_max_label_value(self, train_files):
+         max_label_value = max(file['label'].max() for file in train_files)
+         if max_label_value > self.Nclasses:
+             print(f"Warning: Maximum class label value {max_label_value} exceeds the number of classes {self.Nclasses}.")
+             print("This may cause issues with the model's output layer.")
+             print("Consider adjusting the number of classes or the label values in your data.\n")
+
+     def get_dataloader_parameters(self):
+
+         parameters = {
+             'config': self.config,
+             'target_name': self.target_name,
+             'target_session_id': self.target_session_id,
+             'target_user_id': self.target_user_id,
+             'voxel_size': self.voxel_size,
+             'tomo_algorithm': self.tomo_algorithm,
+             'tomo_batch_size': self.tomo_batch_size,
+             'reload_frequency': self.reload_frequency,
+             'testRunIDs': self.myRunIDs['test'],
+             'valRunIDs': self.myRunIDs['validate'],
+             'trainRunIDs': self.myRunIDs['train'],
+         }
+
+         return parameters
+
+ class PredictLoaderManager:
+
+     def __init__(self,
+                  config: str,
+                  voxel_size: float = 10,
+                  tomo_algorithm: str = 'wbp',
+                  tomo_batch_size: int = 15,  # Number of Tomograms to Load Per Sub-Epoch
+                  Nclasses: int = 3):
+
+         # Read Copick Project
+         self.copick_config = config
+         self.root = io.load_copick_config(config)
+
+         # Copick Query For Input Tomogram
+         self.voxel_size = voxel_size
+         self.tomo_algorithm = tomo_algorithm
+
+         self.Nclasses = Nclasses
+         self.tomo_batch_size = tomo_batch_size
+
+         # Initialize the input dimensions
+         self.nx = None
+         self.ny = None
+         self.nz = None
+
+     def create_predict_dataloader(
+         self,
+         voxel_spacing: float,
+         tomo_algorithm: str,
+         runIDs: str = None):
+
+         # Use all runs in the project if no run IDs are provided
+         if runIDs is None:
+             runIDs = [run.name for run in self.root.runs]
+
+         # Load the test data
+         test_files = io.load_predict_data(self.root, runIDs, voxel_spacing, tomo_algorithm)
+
+         # Create the cached dataset with non-random transforms
+         test_ds = CacheDataset(data=test_files, transform=augment.get_predict_transforms())
+
+         # Read (nx, ny, nz) for the input tomograms
+         if self.nx is None:
+             (self.nx, self.ny, self.nz) = test_ds[0]['image'].shape[1:]
+
+         # Create the DataLoader
+         test_loader = DataLoader(test_ds,
+                                  batch_size=4,
+                                  shuffle=False,
+                                  num_workers=4,
+                                  pin_memory=torch.cuda.is_available())
+         return test_loader
+
+     def get_dataloader_parameters(self):
+
+         parameters = {
+             'config': self.copick_config,
+             'voxel_size': self.voxel_size,
+             'tomo_algorithm': self.tomo_algorithm
+         }
+
+         return parameters
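
Putting TrainLoaderManager together, a training script would roughly follow the sequence sketched below: split the runs, compute the reload frequency, then rebuild the loaders every few epochs so the next batch of tomograms is swapped in. The copick config path, target name, and epoch count are placeholders, and the actual training/validation step (which in this package lives in octopi/pytorch/trainer.py) is elided.

from octopi.datasets.generators import TrainLoaderManager

manager = TrainLoaderManager(
    config="copick_config.json",   # placeholder path to a copick project config
    target_name="targets",         # placeholder segmentation target name
    voxel_size=10,
    tomo_algorithm=["wbp"],
    tomo_batch_size=15,
    Nclasses=3,
)

# Split runs into train / validate / test and set up the batch iterators.
manager.get_data_splits(train_ratio=0.8, val_ratio=0.1, test_ratio=0.1)

num_epochs = 100
manager.get_reload_frequency(num_epochs)   # sets manager.reload_frequency (-1 = no reloads)

train_loader, val_loader = manager.create_train_dataloaders(crop_size=96, num_samples=64)
for epoch in range(num_epochs):
    # ... train and validate with train_loader / val_loader ...
    # Periodically swap in the next batch of tomograms when reloading is enabled.
    if manager.reload_frequency > 0 and (epoch + 1) % manager.reload_frequency == 0:
        train_loader, val_loader = manager.create_train_dataloaders(crop_size=96, num_samples=64)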