fastMONAI 0.5.3 → 0.5.4 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastMONAI/__init__.py +1 -1
- fastMONAI/_modidx.py +171 -27
- fastMONAI/dataset_info.py +190 -45
- fastMONAI/external_data.py +1 -1
- fastMONAI/utils.py +101 -18
- fastMONAI/vision_all.py +3 -2
- fastMONAI/vision_augmentation.py +133 -29
- fastMONAI/vision_core.py +29 -132
- fastMONAI/vision_data.py +6 -6
- fastMONAI/vision_inference.py +35 -9
- fastMONAI/vision_metrics.py +420 -19
- fastMONAI/vision_patch.py +1125 -0
- fastMONAI/vision_plot.py +1 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.5.4.dist-info}/METADATA +5 -5
- fastmonai-0.5.4.dist-info/RECORD +21 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.5.4.dist-info}/WHEEL +1 -1
- fastmonai-0.5.3.dist-info/RECORD +0 -20
- {fastmonai-0.5.3.dist-info → fastmonai-0.5.4.dist-info}/entry_points.txt +0 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.5.4.dist-info}/top_level.txt +0 -0
fastMONAI/dataset_info.py
CHANGED
@@ -5,6 +5,7 @@ __all__ = ['MedDataset', 'get_class_weights']

 # %% ../nbs/08_dataset_info.ipynb 2
 from .vision_core import *
+from .vision_plot import find_max_slice

 from sklearn.utils.class_weight import compute_class_weight
 from concurrent.futures import ThreadPoolExecutor
@@ -12,29 +13,35 @@ import pandas as pd
 import numpy as np
 import torch
 import glob
+import matplotlib.pyplot as plt

-# %% ../nbs/08_dataset_info.ipynb
+# %% ../nbs/08_dataset_info.ipynb 3
 class MedDataset:
     """A class to extract and present information about the dataset."""

-    def __init__(self,
-
-                 max_workers:
+    def __init__(self, dataframe=None, image_col:str=None, mask_col:str="mask_path",
+                 path=None, img_list=None, postfix:str='', apply_reorder:bool=False,
+                 dtype:(MedImage, MedMask)=MedImage, max_workers:int=1):
         """Constructs MedDataset object.

         Args:
-
-
-
-
-
-
+            dataframe: DataFrame containing image paths.
+            image_col: Column name for image paths (used for visualization).
+            mask_col: Column name for mask/label paths when using dataframe mode.
+            path: Directory path containing images.
+            img_list: List of image file paths to analyze.
+            postfix: File postfix filter when using path mode.
+            apply_reorder: Whether to reorder images to RAS+ orientation.
+            dtype: MedImage for images or MedMask for segmentation masks.
+            max_workers: Number of parallel workers for processing.
         """
-
+        self.input_df = dataframe
+        self.image_col = image_col
+        self.mask_col = mask_col
         self.path = path
-        self.postfix = postfix
         self.img_list = img_list
-        self.
+        self.postfix = postfix
+        self.apply_reorder = apply_reorder
         self.dtype = dtype
         self.max_workers = max_workers
         self.df = self._create_data_frame()
@@ -42,69 +49,207 @@ class MedDataset:

     def _create_data_frame(self):
         """Private method that returns a dataframe with information about the dataset."""

-
-
-
+        # Handle img_list (simple list of paths)
+        if self.img_list is not None:
+            file_list = self.img_list
+
+        # Handle path-based initialization
+        elif self.path:
+            file_list = glob.glob(f'{self.path}/*{self.postfix}*')
+            if not file_list:
+                print('Could not find images. Check the image path')
+                return pd.DataFrame()
+
+        # Handle dataframe-based initialization
+        elif self.input_df is not None and self.mask_col in self.input_df.columns:
+            file_list = self.input_df[self.mask_col].tolist()
+
+        else:
+            print('Error: Must provide path, img_list, or dataframe with mask_col')
+            return pd.DataFrame()

+        # Process images to extract metadata
         with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
-            data_info_dict = list(executor.map(self._get_data_info,
+            data_info_dict = list(executor.map(self._get_data_info, file_list))

         df = pd.DataFrame(data_info_dict)
-
-        if df.orientation.nunique() > 1:
-
-
+
+        if len(df) > 0 and df.orientation.nunique() > 1 and not self.apply_reorder:
+            raise ValueError(
+                'Mixed orientations detected in dataset. '
+                'Please recreate MedDataset with apply_reorder=True to get correct resample values: '
+                'MedDataset(..., apply_reorder=True)'
+            )

         return df

     def summary(self):
         """Summary DataFrame of the dataset with example path for similar data."""
-
+
         columns = ['dim_0', 'dim_1', 'dim_2', 'voxel_0', 'voxel_1', 'voxel_2', 'orientation']
-
+
         return self.df.groupby(columns, as_index=False).agg(
             example_path=('path', 'min'), total=('path', 'size')
         ).sort_values('total', ascending=False)

-    def
-        """
-
-
-
+    def get_suggestion(self):
+        """Returns suggested preprocessing parameters as a dictionary.
+
+        Returns:
+            dict: {'target_spacing': [voxel_0, voxel_1, voxel_2], 'apply_reorder': bool}
+        """
+        target_spacing = [float(self.df.voxel_0.mode()[0]), float(self.df.voxel_1.mode()[0]), float(self.df.voxel_2.mode()[0])]
+        return {'target_spacing': target_spacing, 'apply_reorder': self.apply_reorder}

     def _get_data_info(self, fn: str):
         """Private method to collect information about an image file."""
-
+        try:
+            _, o, _ = med_img_reader(fn, apply_reorder=self.apply_reorder, only_tensor=False, dtype=self.dtype)

-
-
-
+            info_dict = {'path': fn, 'dim_0': o.shape[1], 'dim_1': o.shape[2], 'dim_2': o.shape[3],
+                         'voxel_0': round(o.spacing[0], 4), 'voxel_1': round(o.spacing[1], 4), 'voxel_2': round(o.spacing[2], 4),
+                         'orientation': f'{"".join(o.orientation)}+'}

-
-
-
-            info_dict.update(mask_labels_dict)
+            if self.dtype is MedMask:
+                # Calculate voxel volume in mm³
+                voxel_volume = o.spacing[0] * o.spacing[1] * o.spacing[2]

-
+                # Get voxel counts for each label
+                mask_labels_dict = o.count_labels()

-
-
-
-
+                # Calculate volumes for each label > 0 (skip background)
+                for key, voxel_count in mask_labels_dict.items():
+                    label_int = int(key)
+                    if label_int > 0 and voxel_count > 0:  # Skip background (label 0)
+                        volume_mm3 = voxel_count * voxel_volume
+                        info_dict[f'label_{label_int}_volume_mm3'] = round(volume_mm3, 4)

-
+            return info_dict
+
+        except Exception as e:
+            print(f"Warning: Failed to process {fn}: {e}")
+            return {'path': fn, 'error': str(e)}
+
+    def calculate_target_size(self, target_spacing: list = None) -> list:
+        """Calculate the target image size for the dataset.
+
+        Args:
+            target_spacing: If provided, calculates size after resampling to this spacing.
+                If None, returns original dimensions.
+
+        Returns:
+            list: [dim_0, dim_1, dim_2] largest dimensions in dataset.
+        """
+        if target_spacing is not None:
             org_voxels = self.df[["voxel_0", "voxel_1", 'voxel_2']].values
             org_dims = self.df[["dim_0", "dim_1", 'dim_2']].values

-            ratio = org_voxels/
+            ratio = org_voxels/target_spacing
             new_dims = (org_dims * ratio).T
-
-
+            # Use floor() to match TorchIO's Resample dimension calculation
+            dims = [float(np.floor(new_dims[0].max())), float(np.floor(new_dims[1].max())), float(np.floor(new_dims[2].max()))]
         else:
             dims = [float(self.df.dim_0.max()), float(self.df.dim_1.max()), float(self.df.dim_2.max())]

         return dims

+    def get_volume_summary(self):
+        """Returns DataFrame with volume statistics for each label.
+
+        Returns:
+            DataFrame with columns: label, count, mean_mm3, median_mm3, min_mm3, max_mm3
+            Returns None if no volume columns found (dtype was not MedMask).
+        """
+        volume_cols = [col for col in self.df.columns if col.endswith('_volume_mm3')]
+
+        if not volume_cols:
+            return None
+
+        summary_data = []
+        for col in volume_cols:
+            non_zero = self.df[self.df[col] > 0][col]
+            if len(non_zero) > 0:
+                summary_data.append({
+                    'label': col.replace('_volume_mm3', ''),
+                    'count': len(non_zero),
+                    'mean_mm3': non_zero.mean(),
+                    'median_mm3': non_zero.median(),
+                    'min_mm3': non_zero.min(),
+                    'max_mm3': non_zero.max()
+                })
+
+        return pd.DataFrame(summary_data) if summary_data else None
+
+    def _visualize_single_case(self, img_path, mask_path, case_id, anatomical_plane=2, cmap='hot', figsize=(12, 5)):
+        """Helper method to visualize a single case."""
+        try:
+            # Create MedImage and MedMask with current preprocessing settings
+            suggestion = self.get_suggestion()
+            MedBase.item_preprocessing(target_spacing=suggestion['target_spacing'], apply_reorder=suggestion['apply_reorder'])
+
+            img = MedImage.create(img_path)
+            mask = MedMask.create(mask_path)
+
+            # Find optimal slice using explicit function
+            mask_data = mask.numpy()[0]  # Remove channel dimension
+            optimal_slice = find_max_slice(mask_data, anatomical_plane)
+
+            # Create subplot
+            fig, axes = plt.subplots(1, 2, figsize=figsize)
+
+            # Show image
+            img.show(ctx=axes[0], anatomical_plane=anatomical_plane, slice_index=optimal_slice)
+            axes[0].set_title(f"{case_id} - Image (slice {optimal_slice})")
+
+            # Show overlay
+            img.show(ctx=axes[1], anatomical_plane=anatomical_plane, slice_index=optimal_slice)
+            mask.show(ctx=axes[1], anatomical_plane=anatomical_plane, slice_index=optimal_slice,
+                      alpha=0.3, cmap=cmap)
+            axes[1].set_title(f"{case_id} - Overlay (slice {optimal_slice})")
+
+            # Adjust spacing to bring plots closer
+            plt.subplots_adjust(wspace=0.1)
+            plt.tight_layout()
+            plt.show()
+
+        except Exception as e:
+            print(f"Failed to visualize case {case_id}: {e}")
+
+    def visualize_cases(self, n_cases=4, anatomical_plane=2, cmap='hot', figsize=(12, 5)):
+        """Visualize cases from the dataset.
+
+        Args:
+            n_cases: Number of cases to show.
+            anatomical_plane: 0=sagittal, 1=coronal, 2=axial
+            cmap: Colormap for mask overlay
+            figsize: Figure size for each case
+        """
+        if self.input_df is None:
+            print("Error: No dataframe provided. Cannot visualize cases.")
+            return
+
+        if self.image_col is None:
+            print("Error: No image_col specified. Cannot visualize cases.")
+            return
+
+        # Check if required columns exist
+        if self.image_col not in self.input_df.columns:
+            print(f"Error: Column '{self.image_col}' not found in dataframe.")
+            return
+
+        if self.mask_col not in self.input_df.columns:
+            print(f"Error: Column '{self.mask_col}' not found in dataframe.")
+            return
+
+        for idx in range(min(n_cases, len(self.input_df))):
+            row = self.input_df.iloc[idx]
+            case_id = row.get('case_id', f'Case_{idx}')  # Fallback if no case_id
+            img_path = row[self.image_col]
+            mask_path = row[self.mask_col]
+
+            self._visualize_single_case(img_path, mask_path, case_id, anatomical_plane, cmap, figsize)
+            print("-" * 60)
+
 # %% ../nbs/08_dataset_info.ipynb 5
 def get_class_weights(labels: (np.array, list), class_weight: str = 'balanced') -> torch.Tensor:
     """Calculates and returns the class weights.
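For orientation, the reworked `MedDataset` API can be exercised roughly as below. This is a minimal sketch based only on the signatures visible in the diff above; the dataframe contents and file paths are placeholders, and the imports assume the usual `fastMONAI.vision_all` re-exports.

```python
# Hypothetical usage sketch of the 0.5.4 MedDataset dataframe mode.
# Paths and column values below are illustrative, not real data.
import pandas as pd
from fastMONAI.vision_all import MedDataset, MedMask

df = pd.DataFrame({
    'case_id': ['case_001', 'case_002'],
    'image_path': ['data/case_001/t1.nii.gz', 'data/case_002/t1.nii.gz'],
    'mask_path': ['data/case_001/seg.nii.gz', 'data/case_002/seg.nii.gz'],
})

# dtype=MedMask adds per-label volume columns; apply_reorder=True reorders to RAS+.
med_dataset = MedDataset(dataframe=df, image_col='image_path', mask_col='mask_path',
                         dtype=MedMask, apply_reorder=True, max_workers=2)

print(med_dataset.summary())                      # shape/spacing/orientation groups
suggestion = med_dataset.get_suggestion()         # {'target_spacing': [...], 'apply_reorder': ...}
size = med_dataset.calculate_target_size(suggestion['target_spacing'])
print(med_dataset.get_volume_summary())           # per-label volume stats (MedMask only)
med_dataset.visualize_cases(n_cases=2, anatomical_plane=2)
```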
fastMONAI/external_data.py
CHANGED
@@ -94,7 +94,7 @@ def download_ixi_data(path: (str, Path) = '../data') -> Path:
         if len(list(img_path.iterdir())) >= 581: # 581 imgs in the IXI dataset
             is_extracted = True
             print(f"Images already downloaded and extracted to {img_path}")
-    except:
+    except (FileNotFoundError, StopIteration, OSError):
         is_extracted = False

     if not is_extracted:
fastMONAI/utils.py
CHANGED
@@ -1,7 +1,8 @@
 # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/07_utils.ipynb.

 # %% auto 0
-__all__ = ['store_variables', 'load_variables', '
+__all__ = ['store_variables', 'load_variables', 'store_patch_variables', 'load_patch_variables', 'print_colab_gpu_info',
+           'ModelTrackingCallback', 'MLflowUIManager']

 # %% ../nbs/07_utils.ipynb 1
 import pickle
@@ -17,10 +18,10 @@ from fastcore.foundation import L
 from typing import Any

 # %% ../nbs/07_utils.ipynb 3
-def store_variables(pkl_fn: str | Path, size: list,
+def store_variables(pkl_fn: str | Path, size: list, apply_reorder: bool, target_spacing: int | list):
     """Save variable values in a pickle file."""

-    var_vals = [size,
+    var_vals = [size, apply_reorder, target_spacing]

     with open(pkl_fn, 'wb') as f:
         pickle.dump(var_vals, f)
@@ -38,7 +39,89 @@ def load_variables(pkl_fn: (str, Path)):
     with open(pkl_fn, 'rb') as f:
         return pickle.load(f)

-# %% ../nbs/07_utils.ipynb
+# %% ../nbs/07_utils.ipynb 6
+def store_patch_variables(
+    pkl_fn: str | Path,
+    patch_size: list,
+    patch_overlap: int | float | list,
+    aggregation_mode: str,
+    apply_reorder: bool = False,
+    target_spacing: list = None,
+    sampler_type: str = 'uniform',
+    label_probabilities: dict = None,
+    samples_per_volume: int = 8,
+    queue_length: int = 300,
+    queue_num_workers: int = 4,
+    keep_largest_component: bool = False
+):
+    """Save patch-based training and inference configuration to a pickle file.
+
+    Args:
+        pkl_fn: Path to save the pickle file.
+        patch_size: Size of patches [x, y, z].
+        patch_overlap: Overlap for inference (int, float 0-1, or list).
+        aggregation_mode: GridAggregator mode ('crop', 'average', 'hann').
+        apply_reorder: Whether to reorder to canonical (RAS+) orientation.
+        target_spacing: Target voxel spacing [x, y, z].
+        sampler_type: Type of sampler used during training.
+        label_probabilities: Label probabilities for LabelSampler.
+        samples_per_volume: Number of patches extracted per volume during training.
+        queue_length: Maximum number of patches in queue buffer.
+        queue_num_workers: Number of workers for parallel patch extraction.
+        keep_largest_component: If True, keep only the largest connected component
+            in binary segmentation predictions during inference.
+
+    Example:
+        >>> store_patch_variables(
+        ...     'patch_settings.pkl',
+        ...     patch_size=[96, 96, 96],
+        ...     patch_overlap=0.5,
+        ...     aggregation_mode='hann',
+        ...     apply_reorder=True,
+        ...     target_spacing=[1.0, 1.0, 1.0],
+        ...     samples_per_volume=16,
+        ...     keep_largest_component=True
+        ... )
+    """
+    config = {
+        'patch_size': patch_size,
+        'patch_overlap': patch_overlap,
+        'aggregation_mode': aggregation_mode,
+        'apply_reorder': apply_reorder,
+        'target_spacing': target_spacing,
+        'sampler_type': sampler_type,
+        'label_probabilities': label_probabilities,
+        'samples_per_volume': samples_per_volume,
+        'queue_length': queue_length,
+        'queue_num_workers': queue_num_workers,
+        'keep_largest_component': keep_largest_component
+    }
+
+    with open(pkl_fn, 'wb') as f:
+        pickle.dump(config, f)
+
+# %% ../nbs/07_utils.ipynb 7
+def load_patch_variables(pkl_fn: str | Path) -> dict:
+    """Load patch-based training and inference configuration from a pickle file.
+
+    Args:
+        pkl_fn: Path to the pickle file.
+
+    Returns:
+        Dictionary with patch configuration including:
+        - patch_size, patch_overlap, aggregation_mode
+        - apply_reorder, target_spacing, sampler_type, label_probabilities
+        - samples_per_volume, queue_length, queue_num_workers
+
+    Example:
+        >>> config = load_patch_variables('patch_settings.pkl')
+        >>> from fastMONAI.vision_patch import PatchConfig
+        >>> patch_config = PatchConfig(**config)
+    """
+    with open(pkl_fn, 'rb') as f:
+        return pickle.load(f)
+
+# %% ../nbs/07_utils.ipynb 8
 def print_colab_gpu_info():
     """Check if we have a GPU attached to the runtime."""

@@ -52,7 +135,7 @@ def print_colab_gpu_info():
     if torch.cuda.is_available(): print('GPU attached.')
     else: print(colab_gpu_msg)

-# %% ../nbs/07_utils.ipynb
+# %% ../nbs/07_utils.ipynb 9
 class ModelTrackingCallback(Callback):
     """
     A FastAI callback for comprehensive MLflow experiment tracking.
@@ -67,8 +150,8 @@ class ModelTrackingCallback(Callback):
         loss_function: str,
         item_tfms: list[Any],
         size: list[int],
-
-
+        target_spacing: list[float],
+        apply_reorder: bool
     ):
         """
         Initialize the MLflow tracking callback.
@@ -77,15 +160,15 @@ class ModelTrackingCallback(Callback):
             model_name: Name of the model architecture for registration
            loss_function: Name of the loss function being used
            size: Model input dimensions
-
-
+            target_spacing: Resampling dimensions
+            apply_reorder: Whether reordering augmentation is applied
         """
         self.model_name = model_name
         self.loss_function = loss_function
         self.item_tfms = item_tfms
         self.size = size
-        self.
-        self.
+        self.target_spacing = target_spacing
+        self.apply_reorder = apply_reorder

         self.config = self._build_config()

@@ -128,8 +211,8 @@ class ModelTrackingCallback(Callback):
             "loss_function": self.loss_function,
             "transform_details": transform_details,
             "size": self.size,
-            "
-            "
+            "target_spacing": self.target_spacing,
+            "apply_reorder": self.apply_reorder,
         }

     def _extract_training_params(self) -> dict[str, Any]:
@@ -143,8 +226,8 @@ class ModelTrackingCallback(Callback):

         params["loss_function"] = self.config["loss_function"]
         params["size"] = self.config["size"]
-        params["
-        params["
+        params["target_spacing"] = self.config["target_spacing"]
+        params["apply_reorder"] = self.config["apply_reorder"]

         params["transformations"] = json.dumps(
             self.config["transform_details"],
@@ -218,7 +301,7 @@ class ModelTrackingCallback(Callback):
             self.learn.cbs = original_cbs

             config_path = temp_dir / "inference_settings.pkl"
-            store_variables(config_path, self.size, self.
+            store_variables(config_path, self.size, self.apply_reorder, self.target_spacing)
             mlflow.log_artifact(str(config_path), "config")

     def _register_pytorch_model(self) -> None:
@@ -252,7 +335,7 @@ class ModelTrackingCallback(Callback):

         print(f"MLflow run completed. Run ID: {mlflow.active_run().info.run_id}")

-# %% ../nbs/07_utils.ipynb
+# %% ../nbs/07_utils.ipynb 10
 import subprocess
 import threading
 import time
@@ -286,7 +369,7 @@ class MLflowUIManager:
         try:
             response = requests.get(f'http://localhost:{self.port}', timeout=2)
             return response.status_code == 200
-        except:
+        except (requests.RequestException, ConnectionError, OSError):
             return False

     def find_available_port(self, start_port=5001):
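The new `store_patch_variables`/`load_patch_variables` pair simply pickles the patch settings as a plain dictionary, and `store_variables` now takes `(pkl_fn, size, apply_reorder, target_spacing)`. A short round-trip sketch with illustrative values, mirroring the docstring example above:

```python
# Hypothetical round trip for the patch configuration helpers; the file name
# and parameter values are illustrative.
from fastMONAI.utils import store_patch_variables, load_patch_variables

store_patch_variables(
    'patch_settings.pkl',
    patch_size=[96, 96, 96],
    patch_overlap=0.5,
    aggregation_mode='hann',
    apply_reorder=True,
    target_spacing=[1.0, 1.0, 1.0],
    samples_per_volume=16,
    keep_largest_component=True,
)

config = load_patch_variables('patch_settings.pkl')   # plain dict of the saved settings
print(config['patch_size'], config['aggregation_mode'])
```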
fastMONAI/vision_all.py
CHANGED
@@ -2,10 +2,11 @@
 from .vision_core import *
 from .vision_data import *
 from .vision_augmentation import *
-from .vision_loss import *
+from .vision_loss import *
 from .vision_metrics import *
 from .vision_inference import *
-from .
+from .vision_patch import *
+from .utils import *
 from .external_data import *
 from .dataset_info import *
