fastMONAI 0.5.3__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastMONAI/__init__.py +1 -1
- fastMONAI/_modidx.py +224 -28
- fastMONAI/dataset_info.py +329 -47
- fastMONAI/external_data.py +1 -1
- fastMONAI/utils.py +394 -22
- fastMONAI/vision_all.py +3 -2
- fastMONAI/vision_augmentation.py +264 -28
- fastMONAI/vision_core.py +29 -132
- fastMONAI/vision_data.py +6 -6
- fastMONAI/vision_inference.py +35 -9
- fastMONAI/vision_metrics.py +420 -19
- fastMONAI/vision_patch.py +1259 -0
- fastMONAI/vision_plot.py +90 -1
- {fastmonai-0.5.3.dist-info → fastmonai-0.6.0.dist-info}/METADATA +5 -5
- fastmonai-0.6.0.dist-info/RECORD +21 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.6.0.dist-info}/WHEEL +1 -1
- fastmonai-0.5.3.dist-info/RECORD +0 -20
- {fastmonai-0.5.3.dist-info → fastmonai-0.6.0.dist-info}/entry_points.txt +0 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.6.0.dist-info}/top_level.txt +0 -0
fastMONAI/utils.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/07_utils.ipynb.
|
|
2
2
|
|
|
3
3
|
# %% auto 0
|
|
4
|
-
__all__ = ['store_variables', 'load_variables', '
|
|
4
|
+
__all__ = ['store_variables', 'load_variables', 'store_patch_variables', 'load_patch_variables', 'print_colab_gpu_info',
|
|
5
|
+
'ModelTrackingCallback', 'create_mlflow_callback', 'MLflowUIManager']
|
|
5
6
|
|
|
6
7
|
# %% ../nbs/07_utils.ipynb 1
|
|
7
8
|
import pickle
|
|
@@ -12,15 +13,16 @@ import mlflow.pytorch
|
|
|
12
13
|
import os
|
|
13
14
|
import tempfile
|
|
14
15
|
import json
|
|
16
|
+
from datetime import datetime
|
|
15
17
|
from fastai.callback.core import Callback
|
|
16
18
|
from fastcore.foundation import L
|
|
17
19
|
from typing import Any
|
|
18
20
|
|
|
19
21
|
# %% ../nbs/07_utils.ipynb 3
|
|
20
|
-
def store_variables(pkl_fn: str | Path, size: list,
|
|
22
|
+
def store_variables(pkl_fn: str | Path, size: list, apply_reorder: bool, target_spacing: int | list):
|
|
21
23
|
"""Save variable values in a pickle file."""
|
|
22
24
|
|
|
23
|
-
var_vals = [size,
|
|
25
|
+
var_vals = [size, apply_reorder, target_spacing]
|
|
24
26
|
|
|
25
27
|
with open(pkl_fn, 'wb') as f:
|
|
26
28
|
pickle.dump(var_vals, f)
|
|
@@ -38,7 +40,89 @@ def load_variables(pkl_fn: (str, Path)):
|
|
|
38
40
|
with open(pkl_fn, 'rb') as f:
|
|
39
41
|
return pickle.load(f)
|
|
40
42
|
|
|
41
|
-
# %% ../nbs/07_utils.ipynb
|
|
43
|
+
# %% ../nbs/07_utils.ipynb 6
|
|
44
|
+
def store_patch_variables(
|
|
45
|
+
pkl_fn: str | Path,
|
|
46
|
+
patch_size: list,
|
|
47
|
+
patch_overlap: int | float | list,
|
|
48
|
+
aggregation_mode: str,
|
|
49
|
+
apply_reorder: bool = False,
|
|
50
|
+
target_spacing: list = None,
|
|
51
|
+
sampler_type: str = 'uniform',
|
|
52
|
+
label_probabilities: dict = None,
|
|
53
|
+
samples_per_volume: int = 8,
|
|
54
|
+
queue_length: int = 300,
|
|
55
|
+
queue_num_workers: int = 4,
|
|
56
|
+
keep_largest_component: bool = False
|
|
57
|
+
):
|
|
58
|
+
"""Save patch-based training and inference configuration to a pickle file.
|
|
59
|
+
|
|
60
|
+
Args:
|
|
61
|
+
pkl_fn: Path to save the pickle file.
|
|
62
|
+
patch_size: Size of patches [x, y, z].
|
|
63
|
+
patch_overlap: Overlap for inference (int, float 0-1, or list).
|
|
64
|
+
aggregation_mode: GridAggregator mode ('crop', 'average', 'hann').
|
|
65
|
+
apply_reorder: Whether to reorder to canonical (RAS+) orientation.
|
|
66
|
+
target_spacing: Target voxel spacing [x, y, z].
|
|
67
|
+
sampler_type: Type of sampler used during training.
|
|
68
|
+
label_probabilities: Label probabilities for LabelSampler.
|
|
69
|
+
samples_per_volume: Number of patches extracted per volume during training.
|
|
70
|
+
queue_length: Maximum number of patches in queue buffer.
|
|
71
|
+
queue_num_workers: Number of workers for parallel patch extraction.
|
|
72
|
+
keep_largest_component: If True, keep only the largest connected component
|
|
73
|
+
in binary segmentation predictions during inference.
|
|
74
|
+
|
|
75
|
+
Example:
|
|
76
|
+
>>> store_patch_variables(
|
|
77
|
+
... 'patch_settings.pkl',
|
|
78
|
+
... patch_size=[96, 96, 96],
|
|
79
|
+
... patch_overlap=0.5,
|
|
80
|
+
... aggregation_mode='hann',
|
|
81
|
+
... apply_reorder=True,
|
|
82
|
+
... target_spacing=[1.0, 1.0, 1.0],
|
|
83
|
+
... samples_per_volume=16,
|
|
84
|
+
... keep_largest_component=True
|
|
85
|
+
... )
|
|
86
|
+
"""
|
|
87
|
+
config = {
|
|
88
|
+
'patch_size': patch_size,
|
|
89
|
+
'patch_overlap': patch_overlap,
|
|
90
|
+
'aggregation_mode': aggregation_mode,
|
|
91
|
+
'apply_reorder': apply_reorder,
|
|
92
|
+
'target_spacing': target_spacing,
|
|
93
|
+
'sampler_type': sampler_type,
|
|
94
|
+
'label_probabilities': label_probabilities,
|
|
95
|
+
'samples_per_volume': samples_per_volume,
|
|
96
|
+
'queue_length': queue_length,
|
|
97
|
+
'queue_num_workers': queue_num_workers,
|
|
98
|
+
'keep_largest_component': keep_largest_component
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
with open(pkl_fn, 'wb') as f:
|
|
102
|
+
pickle.dump(config, f)
|
|
103
|
+
|
|
104
|
+
# %% ../nbs/07_utils.ipynb 7
|
|
105
|
+
def load_patch_variables(pkl_fn: str | Path) -> dict:
|
|
106
|
+
"""Load patch-based training and inference configuration from a pickle file.
|
|
107
|
+
|
|
108
|
+
Args:
|
|
109
|
+
pkl_fn: Path to the pickle file.
|
|
110
|
+
|
|
111
|
+
Returns:
|
|
112
|
+
Dictionary with patch configuration including:
|
|
113
|
+
- patch_size, patch_overlap, aggregation_mode
|
|
114
|
+
- apply_reorder, target_spacing, sampler_type, label_probabilities
|
|
115
|
+
- samples_per_volume, queue_length, queue_num_workers
|
|
116
|
+
|
|
117
|
+
Example:
|
|
118
|
+
>>> config = load_patch_variables('patch_settings.pkl')
|
|
119
|
+
>>> from fastMONAI.vision_patch import PatchConfig
|
|
120
|
+
>>> patch_config = PatchConfig(**config)
|
|
121
|
+
"""
|
|
122
|
+
with open(pkl_fn, 'rb') as f:
|
|
123
|
+
return pickle.load(f)
|
|
124
|
+
|
|
125
|
+
# %% ../nbs/07_utils.ipynb 8
|
|
42
126
|
def print_colab_gpu_info():
|
|
43
127
|
"""Check if we have a GPU attached to the runtime."""
|
|
44
128
|
|
|
@@ -52,13 +136,148 @@ def print_colab_gpu_info():
|
|
|
52
136
|
if torch.cuda.is_available(): print('GPU attached.')
|
|
53
137
|
else: print(colab_gpu_msg)
|
|
54
138
|
|
|
55
|
-
# %% ../nbs/07_utils.ipynb
|
|
139
|
+
# %% ../nbs/07_utils.ipynb 9
|
|
140
|
+
def _detect_patch_workflow(dls) -> bool:
|
|
141
|
+
"""Detect if DataLoaders are patch-based (MedPatchDataLoaders).
|
|
142
|
+
|
|
143
|
+
Args:
|
|
144
|
+
dls: DataLoaders instance
|
|
145
|
+
|
|
146
|
+
Returns:
|
|
147
|
+
True if dls is a MedPatchDataLoaders instance
|
|
148
|
+
"""
|
|
149
|
+
return hasattr(dls, 'patch_config') or hasattr(dls, '_patch_config')
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _extract_size_from_transforms(tfms) -> list | None:
|
|
153
|
+
"""Extract target size from PadOrCrop transform if present.
|
|
154
|
+
|
|
155
|
+
Args:
|
|
156
|
+
tfms: List of transforms
|
|
157
|
+
|
|
158
|
+
Returns:
|
|
159
|
+
Target size as list, or None if not found
|
|
160
|
+
"""
|
|
161
|
+
if tfms is None:
|
|
162
|
+
return None
|
|
163
|
+
for tfm in tfms:
|
|
164
|
+
if hasattr(tfm, 'pad_or_crop') and hasattr(tfm.pad_or_crop, 'target_shape'):
|
|
165
|
+
return list(tfm.pad_or_crop.target_shape)
|
|
166
|
+
return None
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _extract_standard_config(learn) -> dict:
    """Collect tracking configuration for the standard MedDataBlock workflow.

    Args:
        learn: fastai Learner instance

    Returns:
        Dictionary with the keys `apply_reorder`, `target_spacing`, `size`,
        `item_tfms`, `batch_size` and `patch_config` (always None here)
    """
    # Imported locally, as in the rest of this module's helpers.
    from fastMONAI.vision_core import MedBase

    loaders = learn.dls

    # Item transforms come from the `after_item` pipeline when the loaders
    # expose a non-empty one; otherwise there are none to report.
    pipeline = getattr(loaders, 'after_item', None)
    transforms = list(pipeline.fs) if pipeline else []

    return {
        # Preprocessing settings are read from the MedBase class attributes.
        'apply_reorder': MedBase.apply_reorder,
        'target_spacing': MedBase.target_spacing,
        # Size is taken from a PadOrCrop transform, if one is present.
        'size': _extract_size_from_transforms(transforms),
        'item_tfms': transforms,
        'batch_size': loaders.bs,
        'patch_config': None,
    }
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def _extract_patch_config(learn) -> dict:
|
|
204
|
+
"""Extract config from MedPatchDataLoaders workflow.
|
|
205
|
+
|
|
206
|
+
Args:
|
|
207
|
+
learn: fastai Learner instance
|
|
208
|
+
|
|
209
|
+
Returns:
|
|
210
|
+
Dictionary with extracted configuration including patch-specific params
|
|
211
|
+
"""
|
|
212
|
+
dls = learn.dls
|
|
213
|
+
patch_config = getattr(dls, '_patch_config', None) or getattr(dls, 'patch_config', None)
|
|
214
|
+
|
|
215
|
+
config = {
|
|
216
|
+
'apply_reorder': getattr(dls, '_apply_reorder', patch_config.apply_reorder if patch_config else False),
|
|
217
|
+
'target_spacing': getattr(dls, '_target_spacing', patch_config.target_spacing if patch_config else None),
|
|
218
|
+
'size': patch_config.patch_size if patch_config else None,
|
|
219
|
+
'item_tfms': getattr(dls, '_pre_patch_tfms', []) or [],
|
|
220
|
+
'batch_size': dls.bs,
|
|
221
|
+
}
|
|
222
|
+
|
|
223
|
+
# Add patch-specific params for logging
|
|
224
|
+
if patch_config:
|
|
225
|
+
config['patch_config'] = {
|
|
226
|
+
'patch_size': patch_config.patch_size,
|
|
227
|
+
'patch_overlap': patch_config.patch_overlap,
|
|
228
|
+
'samples_per_volume': patch_config.samples_per_volume,
|
|
229
|
+
'sampler_type': patch_config.sampler_type,
|
|
230
|
+
'label_probabilities': str(patch_config.label_probabilities) if patch_config.label_probabilities else None,
|
|
231
|
+
'queue_length': patch_config.queue_length,
|
|
232
|
+
'aggregation_mode': patch_config.aggregation_mode,
|
|
233
|
+
'padding_mode': patch_config.padding_mode,
|
|
234
|
+
'keep_largest_component': patch_config.keep_largest_component,
|
|
235
|
+
}
|
|
236
|
+
else:
|
|
237
|
+
config['patch_config'] = None
|
|
238
|
+
|
|
239
|
+
return config
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def _extract_loss_name(learn) -> str:
|
|
243
|
+
"""Extract loss function name from Learner.
|
|
244
|
+
|
|
245
|
+
Args:
|
|
246
|
+
learn: fastai Learner instance
|
|
247
|
+
|
|
248
|
+
Returns:
|
|
249
|
+
Name of the loss function
|
|
250
|
+
"""
|
|
251
|
+
loss_func = learn.loss_func
|
|
252
|
+
# Handle CustomLoss wrapper
|
|
253
|
+
if hasattr(loss_func, 'loss_func'):
|
|
254
|
+
inner = loss_func.loss_func
|
|
255
|
+
return inner._get_name() if hasattr(inner, '_get_name') else inner.__class__.__name__
|
|
256
|
+
return loss_func._get_name() if hasattr(loss_func, '_get_name') else loss_func.__class__.__name__
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def _extract_model_name(learn) -> str:
|
|
260
|
+
"""Extract model architecture name from Learner.
|
|
261
|
+
|
|
262
|
+
Args:
|
|
263
|
+
learn: fastai Learner instance
|
|
264
|
+
|
|
265
|
+
Returns:
|
|
266
|
+
Name of the model architecture
|
|
267
|
+
"""
|
|
268
|
+
model = learn.model
|
|
269
|
+
return model._get_name() if hasattr(model, '_get_name') else model.__class__.__name__
|
|
270
|
+
|
|
271
|
+
# %% ../nbs/07_utils.ipynb 10
|
|
56
272
|
class ModelTrackingCallback(Callback):
|
|
57
273
|
"""
|
|
58
274
|
A FastAI callback for comprehensive MLflow experiment tracking.
|
|
59
275
|
|
|
60
276
|
This callback automatically logs hyperparameters, metrics, model artifacts,
|
|
61
|
-
and configuration to MLflow during training.
|
|
277
|
+
and configuration to MLflow during training. If a SaveModelCallback is present,
|
|
278
|
+
the best model checkpoint will also be logged as an artifact.
|
|
279
|
+
|
|
280
|
+
Supports auto-managed runs when created via `create_mlflow_callback()`.
|
|
62
281
|
"""
|
|
63
282
|
|
|
64
283
|
def __init__(
|
|
@@ -67,8 +286,14 @@ class ModelTrackingCallback(Callback):
|
|
|
67
286
|
loss_function: str,
|
|
68
287
|
item_tfms: list[Any],
|
|
69
288
|
size: list[int],
|
|
70
|
-
|
|
71
|
-
|
|
289
|
+
target_spacing: list[float],
|
|
290
|
+
apply_reorder: bool,
|
|
291
|
+
experiment_name: str = None,
|
|
292
|
+
run_name: str = None,
|
|
293
|
+
auto_start: bool = False,
|
|
294
|
+
patch_config: dict = None,
|
|
295
|
+
extra_params: dict = None,
|
|
296
|
+
extra_tags: dict = None
|
|
72
297
|
):
|
|
73
298
|
"""
|
|
74
299
|
Initialize the MLflow tracking callback.
|
|
@@ -76,16 +301,32 @@ class ModelTrackingCallback(Callback):
|
|
|
76
301
|
Args:
|
|
77
302
|
model_name: Name of the model architecture for registration
|
|
78
303
|
loss_function: Name of the loss function being used
|
|
304
|
+
item_tfms: List of item transforms
|
|
79
305
|
size: Model input dimensions
|
|
80
|
-
|
|
81
|
-
|
|
306
|
+
target_spacing: Resampling dimensions
|
|
307
|
+
apply_reorder: Whether reordering augmentation is applied
|
|
308
|
+
experiment_name: MLflow experiment name (used with auto_start)
|
|
309
|
+
run_name: MLflow run name (auto-generated if None)
|
|
310
|
+
auto_start: If True, auto-starts/stops MLflow run
|
|
311
|
+
patch_config: Patch configuration dict for logging (from MedPatchDataLoaders)
|
|
312
|
+
extra_params: Additional parameters to log
|
|
313
|
+
extra_tags: MLflow tags to set on the run
|
|
82
314
|
"""
|
|
83
315
|
self.model_name = model_name
|
|
84
316
|
self.loss_function = loss_function
|
|
85
317
|
self.item_tfms = item_tfms
|
|
86
318
|
self.size = size
|
|
87
|
-
self.
|
|
88
|
-
self.
|
|
319
|
+
self.target_spacing = target_spacing
|
|
320
|
+
self.apply_reorder = apply_reorder
|
|
321
|
+
|
|
322
|
+
# New auto-management fields
|
|
323
|
+
self.experiment_name = experiment_name
|
|
324
|
+
self.run_name = run_name
|
|
325
|
+
self.auto_start = auto_start
|
|
326
|
+
self.patch_config = patch_config
|
|
327
|
+
self.extra_params = extra_params or {}
|
|
328
|
+
self.extra_tags = extra_tags or {}
|
|
329
|
+
self._auto_started = False
|
|
89
330
|
|
|
90
331
|
self.config = self._build_config()
|
|
91
332
|
|
|
@@ -128,8 +369,8 @@ class ModelTrackingCallback(Callback):
|
|
|
128
369
|
"loss_function": self.loss_function,
|
|
129
370
|
"transform_details": transform_details,
|
|
130
371
|
"size": self.size,
|
|
131
|
-
"
|
|
132
|
-
"
|
|
372
|
+
"target_spacing": self.target_spacing,
|
|
373
|
+
"apply_reorder": self.apply_reorder,
|
|
133
374
|
}
|
|
134
375
|
|
|
135
376
|
def _extract_training_params(self) -> dict[str, Any]:
|
|
@@ -143,8 +384,8 @@ class ModelTrackingCallback(Callback):
|
|
|
143
384
|
|
|
144
385
|
params["loss_function"] = self.config["loss_function"]
|
|
145
386
|
params["size"] = self.config["size"]
|
|
146
|
-
params["
|
|
147
|
-
params["
|
|
387
|
+
params["target_spacing"] = self.config["target_spacing"]
|
|
388
|
+
params["apply_reorder"] = self.config["apply_reorder"]
|
|
148
389
|
|
|
149
390
|
params["transformations"] = json.dumps(
|
|
150
391
|
self.config["transform_details"],
|
|
@@ -152,6 +393,12 @@ class ModelTrackingCallback(Callback):
|
|
|
152
393
|
separators=(',', ': ')
|
|
153
394
|
)
|
|
154
395
|
|
|
396
|
+
# Add patch-specific params if present
|
|
397
|
+
if self.patch_config:
|
|
398
|
+
for key, value in self.patch_config.items():
|
|
399
|
+
if value is not None:
|
|
400
|
+
params[f"patch_{key}"] = value if isinstance(value, (int, float, str, bool)) else str(value)
|
|
401
|
+
|
|
155
402
|
return params
|
|
156
403
|
|
|
157
404
|
def _extract_epoch_metrics(self) -> dict[str, float]:
|
|
@@ -201,6 +448,23 @@ class ModelTrackingCallback(Callback):
|
|
|
201
448
|
if os.path.exists(weights_file):
|
|
202
449
|
mlflow.log_artifact(weights_file, "model")
|
|
203
450
|
|
|
451
|
+
# Auto-detect SaveModelCallback and log best model
|
|
452
|
+
from fastai.callback.tracker import SaveModelCallback
|
|
453
|
+
best_model_cb = None
|
|
454
|
+
for cb in self.learn.cbs:
|
|
455
|
+
if isinstance(cb, SaveModelCallback):
|
|
456
|
+
best_model_path = self.learn.path/self.learn.model_dir/f'{cb.fname}.pth'
|
|
457
|
+
if best_model_path.exists():
|
|
458
|
+
mlflow.log_artifact(str(best_model_path), "model")
|
|
459
|
+
print(f"Logged best model artifact: {cb.fname}.pth")
|
|
460
|
+
best_model_cb = cb
|
|
461
|
+
break
|
|
462
|
+
|
|
463
|
+
# Load best model weights before export if SaveModelCallback was used
|
|
464
|
+
if best_model_cb is not None:
|
|
465
|
+
self.learn.load(best_model_cb.fname)
|
|
466
|
+
print(f"Loaded best model weights ({best_model_cb.fname}) for learner export")
|
|
467
|
+
|
|
204
468
|
# Remove MLflow callbacks before exporting learner for inference
|
|
205
469
|
# This prevents the callback from being triggered during inference
|
|
206
470
|
original_cbs = self.learn.cbs.copy() # Save original callbacks
|
|
@@ -218,7 +482,7 @@ class ModelTrackingCallback(Callback):
|
|
|
218
482
|
self.learn.cbs = original_cbs
|
|
219
483
|
|
|
220
484
|
config_path = temp_dir / "inference_settings.pkl"
|
|
221
|
-
store_variables(config_path, self.size, self.
|
|
485
|
+
store_variables(config_path, self.size, self.apply_reorder, self.target_spacing)
|
|
222
486
|
mlflow.log_artifact(str(config_path), "config")
|
|
223
487
|
|
|
224
488
|
def _register_pytorch_model(self) -> None:
|
|
@@ -229,8 +493,22 @@ class ModelTrackingCallback(Callback):
|
|
|
229
493
|
)
|
|
230
494
|
|
|
231
495
|
def before_fit(self) -> None:
|
|
232
|
-
"""Log hyperparameters before training starts."""
|
|
496
|
+
"""Log hyperparameters before training starts. Auto-start run if configured."""
|
|
497
|
+
# Auto-start run if requested
|
|
498
|
+
if self.auto_start:
|
|
499
|
+
if self.experiment_name:
|
|
500
|
+
mlflow.set_experiment(self.experiment_name)
|
|
501
|
+
self.run_name = self.run_name or f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
|
502
|
+
mlflow.start_run(run_name=self.run_name)
|
|
503
|
+
self._auto_started = True
|
|
504
|
+
|
|
505
|
+
# Set tags
|
|
506
|
+
if self.extra_tags:
|
|
507
|
+
mlflow.set_tags(self.extra_tags)
|
|
508
|
+
|
|
509
|
+
# Log params
|
|
233
510
|
params = self._extract_training_params()
|
|
511
|
+
params.update(self.extra_params)
|
|
234
512
|
mlflow.log_params(params)
|
|
235
513
|
|
|
236
514
|
def after_epoch(self) -> None:
|
|
@@ -249,10 +527,104 @@ class ModelTrackingCallback(Callback):
|
|
|
249
527
|
self._save_model_artifacts(temp_path)
|
|
250
528
|
|
|
251
529
|
self._register_pytorch_model()
|
|
252
|
-
|
|
253
|
-
|
|
530
|
+
|
|
531
|
+
run_id = mlflow.active_run().info.run_id
|
|
532
|
+
print(f"MLflow run completed. Run ID: {run_id}")
|
|
533
|
+
|
|
534
|
+
# End run if auto-started
|
|
535
|
+
if self._auto_started:
|
|
536
|
+
mlflow.end_run()
|
|
537
|
+
self._auto_started = False
|
|
254
538
|
|
|
255
|
-
# %% ../nbs/07_utils.ipynb
|
|
539
|
+
# %% ../nbs/07_utils.ipynb 11
|
|
540
|
+
def create_mlflow_callback(
    learn,
    experiment_name: str = None,
    run_name: str = None,
    auto_start: bool = True,
    model_name: str = None,
    extra_params: dict = None,
    extra_tags: dict = None,
) -> ModelTrackingCallback:
    """Build a ModelTrackingCallback whose settings are read off the Learner.

    This factory spares the caller from passing size, transforms, loss
    function name, etc. by hand. Depending on whether the Learner's
    DataLoaders are detected as patch-based, the configuration is taken
    from the patch workflow or from the standard MedDataBlock workflow.

    Auto-extracts from Learner:
    - Preprocessing: apply_reorder, target_spacing, size/patch_size
    - Transforms: item_tfms or pre_patch_tfms
    - Training: loss_func, model architecture

    Args:
        learn: fastai Learner instance
        experiment_name: MLflow experiment name. If None, uses model name.
        run_name: MLflow run name, forwarded to the callback unchanged.
        auto_start: Forwarded to the callback; if True it starts/stops the
            MLflow run itself.
        model_name: Override auto-extracted model name for registration.
        extra_params: Additional parameters to log (e.g., {'dropout': 0.5}).
        extra_tags: MLflow tags to set on the run.

    Returns:
        ModelTrackingCallback ready to use with learn.fit()

    Raises:
        ValueError: If no input size could be extracted from the Learner.

    Example:
        >>> # Instead of constructing ModelTrackingCallback with six manual
        >>> # parameters inside `with mlflow.start_run(...)`:
        >>> callback = create_mlflow_callback(learn, experiment_name="Task02_Heart")
        >>> learn.fit_one_cycle(30, lr, cbs=[callback, save_best])
    """
    # Pick the extractor that matches the workflow, then run it.
    if _detect_patch_workflow(learn.dls):
        extractor = _extract_patch_config
    else:
        extractor = _extract_standard_config
    settings = extractor(learn)

    # An explicit (non-empty) model_name wins over the auto-detected one.
    resolved_model = model_name or _extract_model_name(learn)
    resolved_loss = _extract_loss_name(learn)

    # The experiment defaults to the model name.
    resolved_experiment = resolved_model if experiment_name is None else experiment_name

    # Without a size the callback cannot be configured automatically.
    if settings['size'] is None:
        raise ValueError(
            "Could not auto-extract 'size'. Either:\n"
            "1. Add PadOrCrop to item_tfms, or\n"
            "2. Use MedPatchDataLoaders with PatchConfig, or\n"
            "3. Use ModelTrackingCallback directly with manual params"
        )

    callback_kwargs = dict(
        model_name=resolved_model,
        loss_function=resolved_loss,
        item_tfms=settings['item_tfms'],
        size=settings['size'],
        target_spacing=settings['target_spacing'],
        apply_reorder=settings['apply_reorder'],
        experiment_name=resolved_experiment,
        run_name=run_name,
        auto_start=auto_start,
        patch_config=settings['patch_config'],
        extra_params=extra_params,
        extra_tags=extra_tags,
    )
    return ModelTrackingCallback(**callback_kwargs)
|
|
626
|
+
|
|
627
|
+
# %% ../nbs/07_utils.ipynb 13
|
|
256
628
|
import subprocess
|
|
257
629
|
import threading
|
|
258
630
|
import time
|
|
@@ -286,7 +658,7 @@ class MLflowUIManager:
|
|
|
286
658
|
try:
|
|
287
659
|
response = requests.get(f'http://localhost:{self.port}', timeout=2)
|
|
288
660
|
return response.status_code == 200
|
|
289
|
-
except:
|
|
661
|
+
except (requests.RequestException, ConnectionError, OSError):
|
|
290
662
|
return False
|
|
291
663
|
|
|
292
664
|
def find_available_port(self, start_port=5001):
|
fastMONAI/vision_all.py
CHANGED
|
@@ -2,10 +2,11 @@
|
|
|
2
2
|
from .vision_core import *
|
|
3
3
|
from .vision_data import *
|
|
4
4
|
from .vision_augmentation import *
|
|
5
|
-
from .vision_loss import *
|
|
5
|
+
from .vision_loss import *
|
|
6
6
|
from .vision_metrics import *
|
|
7
7
|
from .vision_inference import *
|
|
8
|
-
from .
|
|
8
|
+
from .vision_patch import *
|
|
9
|
+
from .utils import *
|
|
9
10
|
from .external_data import *
|
|
10
11
|
from .dataset_info import *
|
|
11
12
|
|