fastMONAI-0.5.3-py3-none-any.whl → fastMONAI-0.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fastMONAI/dataset_info.py CHANGED
@@ -1,10 +1,11 @@
  # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/08_dataset_info.ipynb.

  # %% auto 0
- __all__ = ['MedDataset', 'get_class_weights']
+ __all__ = ['MedDataset', 'suggest_patch_size', 'get_class_weights']

  # %% ../nbs/08_dataset_info.ipynb 2
  from .vision_core import *
+ from .vision_plot import find_max_slice

  from sklearn.utils.class_weight import compute_class_weight
  from concurrent.futures import ThreadPoolExecutor
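
With 0.6.0, `suggest_patch_size` joins the module's public exports, so all three names below should be importable (a quick sanity check, not part of the released diff):

    from fastMONAI.dataset_info import MedDataset, suggest_patch_size, get_class_weights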
@@ -12,29 +13,37 @@ import pandas as pd
  import numpy as np
  import torch
  import glob
+ import matplotlib.pyplot as plt
+
+ # %% ../nbs/08_dataset_info.ipynb 3
+ import warnings

- # %% ../nbs/08_dataset_info.ipynb 4
  class MedDataset:
      """A class to extract and present information about the dataset."""

-     def __init__(self, path=None, postfix: str = '', img_list: list = None,
-                  reorder: bool = False, dtype: (MedImage, MedMask) = MedImage,
-                  max_workers: int = 1):
+     def __init__(self, dataframe=None, image_col:str=None, mask_col:str="mask_path",
+                  path=None, img_list=None, postfix:str='', apply_reorder:bool=True,
+                  dtype:(MedImage, MedMask)=MedImage, max_workers:int=1):
          """Constructs MedDataset object.

          Args:
-             path (str, optional): Path to the image folder.
-             postfix (str, optional): Specify the file type if there are different files in the folder.
-             img_list (List[str], optional): Alternatively, pass in a list with image paths.
-             reorder (bool, optional): Whether to reorder the data to be closest to canonical (RAS+) orientation.
-             dtype (Union[MedImage, MedMask], optional): Load data as datatype. Default is MedImage.
-             max_workers (int, optional): The number of worker threads. Default is 1.
+             dataframe: DataFrame containing image paths.
+             image_col: Column name for image paths (used for visualization).
+             mask_col: Column name for mask/label paths when using dataframe mode.
+             path: Directory path containing images.
+             img_list: List of image file paths to analyze.
+             postfix: File postfix filter when using path mode.
+             apply_reorder: Whether to reorder images to RAS+ orientation.
+             dtype: MedImage for images or MedMask for segmentation masks.
+             max_workers: Number of parallel workers for processing.
          """
-
+         self.input_df = dataframe
+         self.image_col = image_col
+         self.mask_col = mask_col
          self.path = path
-         self.postfix = postfix
          self.img_list = img_list
-         self.reorder = reorder
+         self.postfix = postfix
+         self.apply_reorder = apply_reorder
          self.dtype = dtype
          self.max_workers = max_workers
          self.df = self._create_data_frame()
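
The rewritten constructor keeps the 0.5.3 folder and file-list modes and adds a DataFrame mode. A minimal sketch of all three, assuming `MedMask` is importable from `vision_core` (as the module's own star import suggests) and with illustrative paths and column names:

    import pandas as pd
    from fastMONAI.dataset_info import MedDataset
    from fastMONAI.vision_core import MedMask

    # Mode 1: scan a folder, filtering by postfix (as in 0.5.3, but note
    # the reorder -> apply_reorder rename and its new default of True)
    ds_path = MedDataset(path='data/masks', postfix='.nii.gz', dtype=MedMask)

    # Mode 2: pass an explicit list of files
    ds_list = MedDataset(img_list=['data/masks/case_0.nii.gz',
                                   'data/masks/case_1.nii.gz'], dtype=MedMask)

    # Mode 3 (new in 0.6.0): pass a DataFrame; mask_col drives the analysis,
    # while image_col is only consulted by visualize_cases()
    df = pd.DataFrame({'image_path': ['data/images/case_0.nii.gz'],
                       'mask_path': ['data/masks/case_0.nii.gz']})
    ds_df = MedDataset(dataframe=df, image_col='image_path',
                       mask_col='mask_path', dtype=MedMask)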
@@ -42,70 +51,343 @@ class MedDataset:
      def _create_data_frame(self):
          """Private method that returns a dataframe with information about the dataset."""

-         if self.path:
-             self.img_list = glob.glob(f'{self.path}/*{self.postfix}*')
-             if not self.img_list: print('Could not find images. Check the image path')
+         # Handle img_list (simple list of paths)
+         if self.img_list is not None:
+             file_list = self.img_list
+
+         # Handle path-based initialization
+         elif self.path:
+             file_list = glob.glob(f'{self.path}/*{self.postfix}*')
+             if not file_list:
+                 print('Could not find images. Check the image path')
+                 return pd.DataFrame()

+         # Handle dataframe-based initialization
+         elif self.input_df is not None and self.mask_col in self.input_df.columns:
+             file_list = self.input_df[self.mask_col].tolist()
+
+         else:
+             print('Error: Must provide path, img_list, or dataframe with mask_col')
+             return pd.DataFrame()
+
+         # Process images to extract metadata
          with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
-             data_info_dict = list(executor.map(self._get_data_info, self.img_list))
+             data_info_dict = list(executor.map(self._get_data_info, file_list))

          df = pd.DataFrame(data_info_dict)
-
-         if df.orientation.nunique() > 1:
-             print('The volumes in this dataset have different orientations. '
-                   'Recommended to pass in the argument reorder=True when creating a MedDataset object for this dataset')
+
+         if len(df) > 0 and df.orientation.nunique() > 1 and not self.apply_reorder:
+             raise ValueError(
+                 'Mixed orientations detected in dataset. '
+                 'Please recreate MedDataset with apply_reorder=True to get correct resample values: '
+                 'MedDataset(..., apply_reorder=True)'
+             )

          return df

      def summary(self):
          """Summary DataFrame of the dataset with example path for similar data."""
-
+
          columns = ['dim_0', 'dim_1', 'dim_2', 'voxel_0', 'voxel_1', 'voxel_2', 'orientation']
-
+
          return self.df.groupby(columns, as_index=False).agg(
              example_path=('path', 'min'), total=('path', 'size')
          ).sort_values('total', ascending=False)

-     def suggestion(self):
-         """Voxel value that appears most often in dim_0, dim_1 and dim_2, and whether the data should be reoriented."""
-
-         resample = [float(self.df.voxel_0.mode()[0]), float(self.df.voxel_1.mode()[0]), float(self.df.voxel_2.mode()[0])]
-         return resample, self.reorder
+     def get_suggestion(self, include_patch_size: bool = False):
+         """Returns suggested preprocessing parameters as a dictionary.
+
+         The returned target_spacing is derived from the mode (most common value)
+         of voxel spacings in the dataset.
+
+         Note:
+             apply_reorder is NOT included in the return value because it is not
+             data-derived. Access dataset.apply_reorder directly if needed.
+
+         Args:
+             include_patch_size: If True, includes suggested patch_size for
+                 patch-based training. Requires vision_patch module.
+
+         Returns:
+             dict: {'target_spacing': [voxel_0, voxel_1, voxel_2]}
+                 If include_patch_size=True, also includes 'patch_size': [dim_0, dim_1, dim_2]
+         """
+         target_spacing = [float(self.df.voxel_0.mode()[0]), float(self.df.voxel_1.mode()[0]), float(self.df.voxel_2.mode()[0])]
+         result = {'target_spacing': target_spacing}
+
+         if include_patch_size:
+             result['patch_size'] = suggest_patch_size(self)
+
+         return result

      def _get_data_info(self, fn: str):
          """Private method to collect information about an image file."""
-         _, o, _ = med_img_reader(fn, reorder=self.reorder, only_tensor=False, dtype=self.dtype)
+         try:
+             _, o, _ = med_img_reader(fn, apply_reorder=self.apply_reorder, only_tensor=False, dtype=self.dtype)

-         info_dict = {'path': fn, 'dim_0': o.shape[1], 'dim_1': o.shape[2], 'dim_2': o.shape[3],
-                      'voxel_0': round(o.spacing[0], 4), 'voxel_1': round(o.spacing[1], 4), 'voxel_2': round(o.spacing[2], 4),
-                      'orientation': f'{"".join(o.orientation)}+'}
+             info_dict = {'path': fn, 'dim_0': o.shape[1], 'dim_1': o.shape[2], 'dim_2': o.shape[3],
+                          'voxel_0': round(o.spacing[0], 4), 'voxel_1': round(o.spacing[1], 4), 'voxel_2': round(o.spacing[2], 4),
+                          'orientation': f'{"".join(o.orientation)}+'}

-         if self.dtype is MedMask:
-             mask_labels_dict = o.count_labels()
-             mask_labels_dict = {f'voxel_count_{int(key)}': val for key, val in mask_labels_dict.items()}
-             info_dict.update(mask_labels_dict)
+             if self.dtype is MedMask:
+                 # Calculate voxel volume in mm³
+                 voxel_volume = o.spacing[0] * o.spacing[1] * o.spacing[2]

-         return info_dict
+                 # Get voxel counts for each label
+                 mask_labels_dict = o.count_labels()

-     def get_largest_img_size(self, resample: list = None) -> list:
-         """Get the largest image size in the dataset."""
-
-         dims = None
+                 # Calculate volumes for each label > 0 (skip background)
+                 for key, voxel_count in mask_labels_dict.items():
+                     label_int = int(key)
+                     if label_int > 0 and voxel_count > 0:  # Skip background (label 0)
+                         volume_mm3 = voxel_count * voxel_volume
+                         info_dict[f'label_{label_int}_volume_mm3'] = round(volume_mm3, 4)

-         if resample is not None:
+             return info_dict
+
+         except Exception as e:
+             print(f"Warning: Failed to process {fn}: {e}")
+             return {'path': fn, 'error': str(e)}
+
+     def calculate_target_size(self, target_spacing: list = None) -> list:
+         """Calculate the target image size for the dataset.
+
+         .. deprecated::
+             Use `get_size_statistics(target_spacing)['max']` instead for consistency
+             with other size statistics methods.
+
+         Args:
+             target_spacing: If provided, calculates size after resampling to this spacing.
+                 If None, returns original dimensions.
+
+         Returns:
+             list: [dim_0, dim_1, dim_2] largest dimensions in dataset.
+         """
+         warnings.warn(
+             "calculate_target_size() is deprecated. "
+             "Use get_size_statistics(target_spacing)['max'] instead.",
+             DeprecationWarning,
+             stacklevel=2
+         )
+         if target_spacing is not None:
              org_voxels = self.df[["voxel_0", "voxel_1", 'voxel_2']].values
              org_dims = self.df[["dim_0", "dim_1", 'dim_2']].values

-             ratio = org_voxels/resample
+             ratio = org_voxels/target_spacing
              new_dims = (org_dims * ratio).T
-             dims = [float(new_dims[0].max().round()), float(new_dims[1].max().round()), float(new_dims[2].max().round())]
-
+             # Use floor() to match TorchIO's Resample dimension calculation
+             dims = [float(np.floor(new_dims[0].max())), float(np.floor(new_dims[1].max())), float(np.floor(new_dims[2].max()))]
          else:
              dims = [float(self.df.dim_0.max()), float(self.df.dim_1.max()), float(self.df.dim_2.max())]

          return dims

- # %% ../nbs/08_dataset_info.ipynb 5
+     def get_size_statistics(self, target_spacing: list = None) -> dict:
+         """Calculate comprehensive size statistics for the dataset.
+
+         Args:
+             target_spacing: If provided, calculates statistics after
+                 simulating resampling to this spacing.
+
+         Returns:
+             dict with keys: 'median', 'min', 'max', 'std', 'percentile_10', 'percentile_90'
+                 Each value is a list [dim_0, dim_1, dim_2].
+         """
+         if len(self.df) == 0:
+             raise ValueError("Dataset is empty - cannot calculate statistics")
+
+         if target_spacing is not None:
+             # Simulate resampled dimensions
+             org_voxels = self.df[["voxel_0", "voxel_1", "voxel_2"]].values
+             org_dims = self.df[["dim_0", "dim_1", "dim_2"]].values
+             ratio = org_voxels / np.array(target_spacing)
+             dims = np.floor(org_dims * ratio)
+         else:
+             dims = self.df[["dim_0", "dim_1", "dim_2"]].values
+
+         return {
+             'median': [float(np.median(dims[:, i])) for i in range(3)],
+             'min': [float(np.min(dims[:, i])) for i in range(3)],
+             'max': [float(np.max(dims[:, i])) for i in range(3)],
+             'std': [float(np.std(dims[:, i])) for i in range(3)],
+             'percentile_10': [float(np.percentile(dims[:, i], 10)) for i in range(3)],
+             'percentile_90': [float(np.percentile(dims[:, i], 90)) for i in range(3)],
+         }
+
+     def get_volume_summary(self):
+         """Returns DataFrame with volume statistics for each label.
+
+         Returns:
+             DataFrame with columns: label, count, mean_mm3, median_mm3, min_mm3, max_mm3
+                 Returns None if no volume columns found (dtype was not MedMask).
+         """
+         volume_cols = [col for col in self.df.columns if col.endswith('_volume_mm3')]
+
+         if not volume_cols:
+             return None
+
+         summary_data = []
+         for col in volume_cols:
+             non_zero = self.df[self.df[col] > 0][col]
+             if len(non_zero) > 0:
+                 summary_data.append({
+                     'label': col.replace('_volume_mm3', ''),
+                     'count': len(non_zero),
+                     'mean_mm3': non_zero.mean(),
+                     'median_mm3': non_zero.median(),
+                     'min_mm3': non_zero.min(),
+                     'max_mm3': non_zero.max()
+                 })
+
+         return pd.DataFrame(summary_data) if summary_data else None
+
+     def _visualize_single_case(self, img_path, mask_path, case_id, anatomical_plane=2, cmap='hot', figsize=(12, 5)):
+         """Helper method to visualize a single case."""
+         try:
+             # Create MedImage and MedMask with current preprocessing settings
+             suggestion = self.get_suggestion()
+             MedBase.item_preprocessing(target_spacing=suggestion['target_spacing'], apply_reorder=self.apply_reorder)
+
+             img = MedImage.create(img_path)
+             mask = MedMask.create(mask_path)
+
+             # Find optimal slice using explicit function
+             mask_data = mask.numpy()[0]  # Remove channel dimension
+             optimal_slice = find_max_slice(mask_data, anatomical_plane)
+
+             # Create subplot
+             fig, axes = plt.subplots(1, 2, figsize=figsize)
+
+             # Show image
+             img.show(ctx=axes[0], anatomical_plane=anatomical_plane, slice_index=optimal_slice)
+             axes[0].set_title(f"{case_id} - Image (slice {optimal_slice})")
+
+             # Show overlay
+             img.show(ctx=axes[1], anatomical_plane=anatomical_plane, slice_index=optimal_slice)
+             mask.show(ctx=axes[1], anatomical_plane=anatomical_plane, slice_index=optimal_slice,
+                       alpha=0.3, cmap=cmap)
+             axes[1].set_title(f"{case_id} - Overlay (slice {optimal_slice})")
+
+             # Adjust spacing to bring plots closer
+             plt.subplots_adjust(wspace=0.1)
+             plt.tight_layout()
+             plt.show()
+
+         except Exception as e:
+             print(f"Failed to visualize case {case_id}: {e}")
+
+     def visualize_cases(self, n_cases=4, anatomical_plane=2, cmap='hot', figsize=(12, 5)):
+         """Visualize cases from the dataset.
+
+         Args:
+             n_cases: Number of cases to show.
+             anatomical_plane: 0=sagittal, 1=coronal, 2=axial
+             cmap: Colormap for mask overlay
+             figsize: Figure size for each case
+         """
+         if self.input_df is None:
+             print("Error: No dataframe provided. Cannot visualize cases.")
+             return
+
+         if self.image_col is None:
+             print("Error: No image_col specified. Cannot visualize cases.")
+             return
+
+         # Check if required columns exist
+         if self.image_col not in self.input_df.columns:
+             print(f"Error: Column '{self.image_col}' not found in dataframe.")
+             return
+
+         if self.mask_col not in self.input_df.columns:
+             print(f"Error: Column '{self.mask_col}' not found in dataframe.")
+             return
+
+         for idx in range(min(n_cases, len(self.input_df))):
+             row = self.input_df.iloc[idx]
+             case_id = row.get('case_id', f'Case_{idx}')  # Fallback if no case_id
+             img_path = row[self.image_col]
+             mask_path = row[self.mask_col]
+
+             self._visualize_single_case(img_path, mask_path, case_id, anatomical_plane, cmap, figsize)
+             print("-" * 60)
+
+ # %% ../nbs/08_dataset_info.ipynb 4
+ def suggest_patch_size(
+     dataset: MedDataset,
+     target_spacing: list = None,
+     min_patch_size: list = None,
+     max_patch_size: list = None,
+     divisor: int = 16
+ ) -> list:
+     """Suggest optimal patch size based on median image dimensions.
+
+     Algorithm:
+         1. Use median shape for robustness to outliers
+         2. Round down to nearest multiple of divisor (16 for 4+ UNet pooling layers)
+         3. Clamp to [min_patch_size, max_patch_size]
+
+     Args:
+         dataset: MedDataset instance with analyzed images.
+         target_spacing: Target voxel spacing [x, y, z]. If None, uses
+             dataset.get_suggestion()['target_spacing'].
+         min_patch_size: Minimum per dimension. Default [32, 32, 32].
+         max_patch_size: Maximum per dimension. Default [256, 256, 256].
+         divisor: Ensure divisibility (default 16 for UNet compatibility).
+
+     Returns:
+         list: [patch_dim_0, patch_dim_1, patch_dim_2]
+
+     Example:
+         >>> from fastMONAI.dataset_info import MedDataset
+         >>> dataset = MedDataset(dataframe=df, mask_col='mask_path', dtype=MedMask)
+         >>>
+         >>> # Use recommended spacing
+         >>> patch_size = suggest_patch_size(dataset)
+         >>>
+         >>> # Use custom spacing
+         >>> patch_size = suggest_patch_size(dataset, target_spacing=[1.0, 1.0, 2.0])
+     """
+     # Defaults
+     min_patch_size = min_patch_size or [32, 32, 32]
+     max_patch_size = max_patch_size or [256, 256, 256]
+
+     # Use explicit spacing or get from dataset suggestion
+     if target_spacing is None:
+         suggestion = dataset.get_suggestion()
+         target_spacing = suggestion['target_spacing']
+
+     # Get size statistics (resampled to target_spacing)
+     stats = dataset.get_size_statistics(target_spacing)
+     median_shape = stats['median']
+
+     # Handle single-image edge case
+     if len(dataset.df) == 1:
+         warnings.warn("Single image dataset - using image dimensions directly")
+
+     # Step 1: Round down to nearest divisor
+     def round_to_divisor(val, div):
+         """Round down to nearest multiple of divisor."""
+         return max(div, int(val // div) * div)
+
+     patch_size = [round_to_divisor(dim, divisor) for dim in median_shape]
+
+     # Step 2: Clamp to bounds
+     patch_size = [
+         max(min_p, min(max_p, p))
+         for p, min_p, max_p in zip(patch_size, min_patch_size, max_patch_size)
+     ]
+
+     # Edge case: image smaller than suggested patch
+     for i, (p, median_dim) in enumerate(zip(patch_size, median_shape)):
+         if median_dim < p:
+             warnings.warn(
+                 f"Median dimension {i} ({median_dim:.0f}) smaller than suggested "
+                 f"patch_size ({p}). Images will require padding."
+             )
+
+     return patch_size
+
+ # %% ../nbs/08_dataset_info.ipynb 6
  def get_class_weights(labels: (np.array, list), class_weight: str = 'balanced') -> torch.Tensor:
      """Calculates and returns the class weights.

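Taken together, these methods replace the 0.5.3 `suggestion()` and `get_largest_img_size()` APIs. A migration sketch (paths illustrative; note the behavioral change from round() to floor() when simulating resampled dimensions):

    ds = MedDataset(path='data/masks', dtype=MedMask)  # apply_reorder defaults to True

    # 0.5.3: resample, reorder = ds.suggestion()
    suggestion = ds.get_suggestion(include_patch_size=True)
    target_spacing = suggestion['target_spacing']  # mode of the voxel spacings
    patch_size = suggestion['patch_size']          # divisible by 16, clamped to bounds

    # 0.5.3: dims = ds.get_largest_img_size(resample=target_spacing)
    stats = ds.get_size_statistics(target_spacing)
    largest_size = stats['max']    # floor()ed, so possibly one voxel smaller
    median_size = stats['median']  # than the round()ed 0.5.3 values

    # Per-label volume statistics (label_*_volume_mm3) exist when dtype=MedMask
    print(ds.get_volume_summary())
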
@@ -94,7 +94,7 @@ def download_ixi_data(path: (str, Path) = '../data') -> Path:
          if len(list(img_path.iterdir())) >= 581:  # 581 imgs in the IXI dataset
              is_extracted = True
              print(f"Images already downloaded and extracted to {img_path}")
-     except:
+     except (FileNotFoundError, StopIteration, OSError):
          is_extracted = False

      if not is_extracted:
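
The bare `except:` in 0.5.3 also caught `KeyboardInterrupt` and `SystemExit`, so an interrupted run could silently fall through to `is_extracted = False`. The narrowed pattern in isolation (the path is illustrative):

    from pathlib import Path

    img_path = Path('../data/IXI/imgs')  # illustrative location
    try:
        # iterdir() raises an OSError subclass (e.g. FileNotFoundError)
        # if the directory does not exist yet
        is_extracted = len(list(img_path.iterdir())) >= 581
    except (FileNotFoundError, StopIteration, OSError):
        is_extracted = False  # fall through to the download branch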