fastMONAI 0.5.4__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fastMONAI/utils.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  # %% auto 0
4
4
  __all__ = ['store_variables', 'load_variables', 'store_patch_variables', 'load_patch_variables', 'print_colab_gpu_info',
5
- 'ModelTrackingCallback', 'MLflowUIManager']
5
+ 'ModelTrackingCallback', 'create_mlflow_callback', 'MLflowUIManager']
6
6
 
7
7
  # %% ../nbs/07_utils.ipynb 1
8
8
  import pickle
@@ -13,6 +13,7 @@ import mlflow.pytorch
13
13
  import os
14
14
  import tempfile
15
15
  import json
16
+ from datetime import datetime
16
17
  from fastai.callback.core import Callback
17
18
  from fastcore.foundation import L
18
19
  from typing import Any
@@ -136,12 +137,147 @@ def print_colab_gpu_info():
136
137
  else: print(colab_gpu_msg)
137
138
 
138
139
  # %% ../nbs/07_utils.ipynb 9
140
+ def _detect_patch_workflow(dls) -> bool:
141
+ """Detect if DataLoaders are patch-based (MedPatchDataLoaders).
142
+
143
+ Args:
144
+ dls: DataLoaders instance
145
+
146
+ Returns:
147
+ True if dls is a MedPatchDataLoaders instance
148
+ """
149
+ return hasattr(dls, 'patch_config') or hasattr(dls, '_patch_config')
150
+
151
+
152
+ def _extract_size_from_transforms(tfms) -> list | None:
153
+ """Extract target size from PadOrCrop transform if present.
154
+
155
+ Args:
156
+ tfms: List of transforms
157
+
158
+ Returns:
159
+ Target size as list, or None if not found
160
+ """
161
+ if tfms is None:
162
+ return None
163
+ for tfm in tfms:
164
+ if hasattr(tfm, 'pad_or_crop') and hasattr(tfm.pad_or_crop, 'target_shape'):
165
+ return list(tfm.pad_or_crop.target_shape)
166
+ return None
167
+
168
+
169
def _extract_standard_config(learn) -> dict:
    """Collect tracking configuration from a standard MedDataBlock workflow.

    Preprocessing flags are read from the ``MedBase`` class attributes, while
    item transforms come from the DataLoaders' ``after_item`` pipeline.

    Args:
        learn: fastai Learner instance.

    Returns:
        Dict with keys ``apply_reorder``, ``target_spacing``, ``size``,
        ``item_tfms``, ``batch_size`` and ``patch_config`` (always None here).
    """
    from fastMONAI.vision_core import MedBase

    dls = learn.dls

    # Item transforms live on the after_item pipeline, when present.
    pipeline = getattr(dls, 'after_item', None)
    item_tfms = list(pipeline.fs) if pipeline else []

    return {
        'apply_reorder': MedBase.apply_reorder,
        'target_spacing': MedBase.target_spacing,
        # Target size is derived from a PadOrCrop transform, if any.
        'size': _extract_size_from_transforms(item_tfms),
        'item_tfms': item_tfms,
        'batch_size': dls.bs,
        'patch_config': None,
    }
201
+
202
+
203
+ def _extract_patch_config(learn) -> dict:
204
+ """Extract config from MedPatchDataLoaders workflow.
205
+
206
+ Args:
207
+ learn: fastai Learner instance
208
+
209
+ Returns:
210
+ Dictionary with extracted configuration including patch-specific params
211
+ """
212
+ dls = learn.dls
213
+ patch_config = getattr(dls, '_patch_config', None) or getattr(dls, 'patch_config', None)
214
+
215
+ config = {
216
+ 'apply_reorder': getattr(dls, '_apply_reorder', patch_config.apply_reorder if patch_config else False),
217
+ 'target_spacing': getattr(dls, '_target_spacing', patch_config.target_spacing if patch_config else None),
218
+ 'size': patch_config.patch_size if patch_config else None,
219
+ 'item_tfms': getattr(dls, '_pre_patch_tfms', []) or [],
220
+ 'batch_size': dls.bs,
221
+ }
222
+
223
+ # Add patch-specific params for logging
224
+ if patch_config:
225
+ config['patch_config'] = {
226
+ 'patch_size': patch_config.patch_size,
227
+ 'patch_overlap': patch_config.patch_overlap,
228
+ 'samples_per_volume': patch_config.samples_per_volume,
229
+ 'sampler_type': patch_config.sampler_type,
230
+ 'label_probabilities': str(patch_config.label_probabilities) if patch_config.label_probabilities else None,
231
+ 'queue_length': patch_config.queue_length,
232
+ 'aggregation_mode': patch_config.aggregation_mode,
233
+ 'padding_mode': patch_config.padding_mode,
234
+ 'keep_largest_component': patch_config.keep_largest_component,
235
+ }
236
+ else:
237
+ config['patch_config'] = None
238
+
239
+ return config
240
+
241
+
242
+ def _extract_loss_name(learn) -> str:
243
+ """Extract loss function name from Learner.
244
+
245
+ Args:
246
+ learn: fastai Learner instance
247
+
248
+ Returns:
249
+ Name of the loss function
250
+ """
251
+ loss_func = learn.loss_func
252
+ # Handle CustomLoss wrapper
253
+ if hasattr(loss_func, 'loss_func'):
254
+ inner = loss_func.loss_func
255
+ return inner._get_name() if hasattr(inner, '_get_name') else inner.__class__.__name__
256
+ return loss_func._get_name() if hasattr(loss_func, '_get_name') else loss_func.__class__.__name__
257
+
258
+
259
+ def _extract_model_name(learn) -> str:
260
+ """Extract model architecture name from Learner.
261
+
262
+ Args:
263
+ learn: fastai Learner instance
264
+
265
+ Returns:
266
+ Name of the model architecture
267
+ """
268
+ model = learn.model
269
+ return model._get_name() if hasattr(model, '_get_name') else model.__class__.__name__
270
+
271
+ # %% ../nbs/07_utils.ipynb 10
139
272
  class ModelTrackingCallback(Callback):
140
273
  """
141
274
  A FastAI callback for comprehensive MLflow experiment tracking.
142
275
 
143
276
  This callback automatically logs hyperparameters, metrics, model artifacts,
144
- and configuration to MLflow during training.
277
+ and configuration to MLflow during training. If a SaveModelCallback is present,
278
+ the best model checkpoint will also be logged as an artifact.
279
+
280
+ Supports auto-managed runs when created via `create_mlflow_callback()`.
145
281
  """
146
282
 
147
283
  def __init__(
@@ -151,7 +287,13 @@ class ModelTrackingCallback(Callback):
151
287
  item_tfms: list[Any],
152
288
  size: list[int],
153
289
  target_spacing: list[float],
154
- apply_reorder: bool
290
+ apply_reorder: bool,
291
+ experiment_name: str = None,
292
+ run_name: str = None,
293
+ auto_start: bool = False,
294
+ patch_config: dict = None,
295
+ extra_params: dict = None,
296
+ extra_tags: dict = None
155
297
  ):
156
298
  """
157
299
  Initialize the MLflow tracking callback.
@@ -159,9 +301,16 @@ class ModelTrackingCallback(Callback):
159
301
  Args:
160
302
  model_name: Name of the model architecture for registration
161
303
  loss_function: Name of the loss function being used
304
+ item_tfms: List of item transforms
162
305
  size: Model input dimensions
163
306
  target_spacing: Resampling dimensions
164
307
  apply_reorder: Whether reordering augmentation is applied
308
+ experiment_name: MLflow experiment name (used with auto_start)
309
+ run_name: MLflow run name (auto-generated if None)
310
+ auto_start: If True, auto-starts/stops MLflow run
311
+ patch_config: Patch configuration dict for logging (from MedPatchDataLoaders)
312
+ extra_params: Additional parameters to log
313
+ extra_tags: MLflow tags to set on the run
165
314
  """
166
315
  self.model_name = model_name
167
316
  self.loss_function = loss_function
@@ -170,6 +319,15 @@ class ModelTrackingCallback(Callback):
170
319
  self.target_spacing = target_spacing
171
320
  self.apply_reorder = apply_reorder
172
321
 
322
+ # New auto-management fields
323
+ self.experiment_name = experiment_name
324
+ self.run_name = run_name
325
+ self.auto_start = auto_start
326
+ self.patch_config = patch_config
327
+ self.extra_params = extra_params or {}
328
+ self.extra_tags = extra_tags or {}
329
+ self._auto_started = False
330
+
173
331
  self.config = self._build_config()
174
332
 
175
333
  def extract_all_params(self, tfm):
@@ -235,6 +393,12 @@ class ModelTrackingCallback(Callback):
235
393
  separators=(',', ': ')
236
394
  )
237
395
 
396
+ # Add patch-specific params if present
397
+ if self.patch_config:
398
+ for key, value in self.patch_config.items():
399
+ if value is not None:
400
+ params[f"patch_{key}"] = value if isinstance(value, (int, float, str, bool)) else str(value)
401
+
238
402
  return params
239
403
 
240
404
  def _extract_epoch_metrics(self) -> dict[str, float]:
@@ -284,6 +448,23 @@ class ModelTrackingCallback(Callback):
284
448
  if os.path.exists(weights_file):
285
449
  mlflow.log_artifact(weights_file, "model")
286
450
 
451
+ # Auto-detect SaveModelCallback and log best model
452
+ from fastai.callback.tracker import SaveModelCallback
453
+ best_model_cb = None
454
+ for cb in self.learn.cbs:
455
+ if isinstance(cb, SaveModelCallback):
456
+ best_model_path = self.learn.path/self.learn.model_dir/f'{cb.fname}.pth'
457
+ if best_model_path.exists():
458
+ mlflow.log_artifact(str(best_model_path), "model")
459
+ print(f"Logged best model artifact: {cb.fname}.pth")
460
+ best_model_cb = cb
461
+ break
462
+
463
+ # Load best model weights before export if SaveModelCallback was used
464
+ if best_model_cb is not None:
465
+ self.learn.load(best_model_cb.fname)
466
+ print(f"Loaded best model weights ({best_model_cb.fname}) for learner export")
467
+
287
468
  # Remove MLflow callbacks before exporting learner for inference
288
469
  # This prevents the callback from being triggered during inference
289
470
  original_cbs = self.learn.cbs.copy() # Save original callbacks
@@ -312,8 +493,22 @@ class ModelTrackingCallback(Callback):
312
493
  )
313
494
 
314
495
  def before_fit(self) -> None:
315
- """Log hyperparameters before training starts."""
496
+ """Log hyperparameters before training starts. Auto-start run if configured."""
497
+ # Auto-start run if requested
498
+ if self.auto_start:
499
+ if self.experiment_name:
500
+ mlflow.set_experiment(self.experiment_name)
501
+ self.run_name = self.run_name or f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
502
+ mlflow.start_run(run_name=self.run_name)
503
+ self._auto_started = True
504
+
505
+ # Set tags
506
+ if self.extra_tags:
507
+ mlflow.set_tags(self.extra_tags)
508
+
509
+ # Log params
316
510
  params = self._extract_training_params()
511
+ params.update(self.extra_params)
317
512
  mlflow.log_params(params)
318
513
 
319
514
  def after_epoch(self) -> None:
@@ -332,10 +527,104 @@ class ModelTrackingCallback(Callback):
332
527
  self._save_model_artifacts(temp_path)
333
528
 
334
529
  self._register_pytorch_model()
335
-
336
- print(f"MLflow run completed. Run ID: {mlflow.active_run().info.run_id}")
530
+
531
+ run_id = mlflow.active_run().info.run_id
532
+ print(f"MLflow run completed. Run ID: {run_id}")
533
+
534
+ # End run if auto-started
535
+ if self._auto_started:
536
+ mlflow.end_run()
537
+ self._auto_started = False
337
538
 
338
- # %% ../nbs/07_utils.ipynb 10
539
+ # %% ../nbs/07_utils.ipynb 11
540
+ def create_mlflow_callback(
541
+ learn,
542
+ experiment_name: str = None,
543
+ run_name: str = None,
544
+ auto_start: bool = True,
545
+ model_name: str = None,
546
+ extra_params: dict = None,
547
+ extra_tags: dict = None,
548
+ ) -> ModelTrackingCallback:
549
+ """Create MLflow tracking callback with auto-extracted configuration.
550
+
551
+ This factory function automatically extracts configuration from the Learner,
552
+ eliminating the need to manually specify parameters like size, transforms,
553
+ loss function, etc.
554
+
555
+ Auto-extracts from Learner:
556
+ - Preprocessing: apply_reorder, target_spacing, size/patch_size
557
+ - Transforms: item_tfms or pre_patch_tfms
558
+ - Training: loss_func, model architecture
559
+
560
+ Args:
561
+ learn: fastai Learner instance
562
+ experiment_name: MLflow experiment name. If None, uses model name.
563
+ run_name: MLflow run name. If None, auto-generates with timestamp.
564
+ auto_start: If True, auto-starts/stops MLflow run in before_fit/after_fit.
565
+ model_name: Override auto-extracted model name for registration.
566
+ extra_params: Additional parameters to log (e.g., {'dropout': 0.5}).
567
+ extra_tags: MLflow tags to set on the run.
568
+
569
+ Returns:
570
+ ModelTrackingCallback ready to use with learn.fit()
571
+
572
+ Example:
573
+ >>> # Instead of this (6 manual params):
574
+ >>> # mlflow_callback = ModelTrackingCallback(
575
+ >>> # model_name=f"{task}_{model._get_name()}",
576
+ >>> # loss_function=loss_func.loss_func._get_name(),
577
+ >>> # item_tfms=item_tfms,
578
+ >>> # size=size,
579
+ >>> # target_spacing=target_spacing,
580
+ >>> # apply_reorder=True,
581
+ >>> # )
582
+ >>> # with mlflow.start_run(run_name="training"):
583
+ >>> # learn.fit_one_cycle(30, lr, cbs=[mlflow_callback])
584
+ >>>
585
+ >>> # Do this (zero manual params):
586
+ >>> callback = create_mlflow_callback(learn, experiment_name="Task02_Heart")
587
+ >>> learn.fit_one_cycle(30, lr, cbs=[callback, save_best])
588
+ """
589
+ # Detect workflow and extract config
590
+ if _detect_patch_workflow(learn.dls):
591
+ config = _extract_patch_config(learn)
592
+ else:
593
+ config = _extract_standard_config(learn)
594
+
595
+ # Extract model/loss info
596
+ _model_name = model_name or _extract_model_name(learn)
597
+ _loss_name = _extract_loss_name(learn)
598
+
599
+ # Set experiment name
600
+ if experiment_name is None:
601
+ experiment_name = _model_name
602
+
603
+ # Validate size was extracted
604
+ if config['size'] is None:
605
+ raise ValueError(
606
+ "Could not auto-extract 'size'. Either:\n"
607
+ "1. Add PadOrCrop to item_tfms, or\n"
608
+ "2. Use MedPatchDataLoaders with PatchConfig, or\n"
609
+ "3. Use ModelTrackingCallback directly with manual params"
610
+ )
611
+
612
+ return ModelTrackingCallback(
613
+ model_name=_model_name,
614
+ loss_function=_loss_name,
615
+ item_tfms=config['item_tfms'],
616
+ size=config['size'],
617
+ target_spacing=config['target_spacing'],
618
+ apply_reorder=config['apply_reorder'],
619
+ experiment_name=experiment_name,
620
+ run_name=run_name,
621
+ auto_start=auto_start,
622
+ patch_config=config['patch_config'],
623
+ extra_params=extra_params,
624
+ extra_tags=extra_tags,
625
+ )
626
+
627
+ # %% ../nbs/07_utils.ipynb 13
339
628
  import subprocess
340
629
  import threading
341
630
  import time