fastMONAI 0.5.3__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fastMONAI/utils.py CHANGED
@@ -1,7 +1,8 @@
1
1
  # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/07_utils.ipynb.
2
2
 
3
3
  # %% auto 0
4
- __all__ = ['store_variables', 'load_variables', 'print_colab_gpu_info', 'ModelTrackingCallback', 'MLflowUIManager']
4
+ __all__ = ['store_variables', 'load_variables', 'store_patch_variables', 'load_patch_variables', 'print_colab_gpu_info',
5
+ 'ModelTrackingCallback', 'create_mlflow_callback', 'MLflowUIManager']
5
6
 
6
7
  # %% ../nbs/07_utils.ipynb 1
7
8
  import pickle
@@ -12,15 +13,16 @@ import mlflow.pytorch
12
13
  import os
13
14
  import tempfile
14
15
  import json
16
+ from datetime import datetime
15
17
  from fastai.callback.core import Callback
16
18
  from fastcore.foundation import L
17
19
  from typing import Any
18
20
 
19
21
  # %% ../nbs/07_utils.ipynb 3
20
- def store_variables(pkl_fn: str | Path, size: list, reorder: bool, resample: int | list):
22
+ def store_variables(pkl_fn: str | Path, size: list, apply_reorder: bool, target_spacing: int | list):
21
23
  """Save variable values in a pickle file."""
22
24
 
23
- var_vals = [size, reorder, resample]
25
+ var_vals = [size, apply_reorder, target_spacing]
24
26
 
25
27
  with open(pkl_fn, 'wb') as f:
26
28
  pickle.dump(var_vals, f)
@@ -38,7 +40,89 @@ def load_variables(pkl_fn: (str, Path)):
38
40
  with open(pkl_fn, 'rb') as f:
39
41
  return pickle.load(f)
40
42
 
41
- # %% ../nbs/07_utils.ipynb 5
43
+ # %% ../nbs/07_utils.ipynb 6
44
+ def store_patch_variables(
45
+ pkl_fn: str | Path,
46
+ patch_size: list,
47
+ patch_overlap: int | float | list,
48
+ aggregation_mode: str,
49
+ apply_reorder: bool = False,
50
+ target_spacing: list = None,
51
+ sampler_type: str = 'uniform',
52
+ label_probabilities: dict = None,
53
+ samples_per_volume: int = 8,
54
+ queue_length: int = 300,
55
+ queue_num_workers: int = 4,
56
+ keep_largest_component: bool = False
57
+ ):
58
+ """Save patch-based training and inference configuration to a pickle file.
59
+
60
+ Args:
61
+ pkl_fn: Path to save the pickle file.
62
+ patch_size: Size of patches [x, y, z].
63
+ patch_overlap: Overlap for inference (int, float 0-1, or list).
64
+ aggregation_mode: GridAggregator mode ('crop', 'average', 'hann').
65
+ apply_reorder: Whether to reorder to canonical (RAS+) orientation.
66
+ target_spacing: Target voxel spacing [x, y, z].
67
+ sampler_type: Type of sampler used during training.
68
+ label_probabilities: Label probabilities for LabelSampler.
69
+ samples_per_volume: Number of patches extracted per volume during training.
70
+ queue_length: Maximum number of patches in queue buffer.
71
+ queue_num_workers: Number of workers for parallel patch extraction.
72
+ keep_largest_component: If True, keep only the largest connected component
73
+ in binary segmentation predictions during inference.
74
+
75
+ Example:
76
+ >>> store_patch_variables(
77
+ ... 'patch_settings.pkl',
78
+ ... patch_size=[96, 96, 96],
79
+ ... patch_overlap=0.5,
80
+ ... aggregation_mode='hann',
81
+ ... apply_reorder=True,
82
+ ... target_spacing=[1.0, 1.0, 1.0],
83
+ ... samples_per_volume=16,
84
+ ... keep_largest_component=True
85
+ ... )
86
+ """
87
+ config = {
88
+ 'patch_size': patch_size,
89
+ 'patch_overlap': patch_overlap,
90
+ 'aggregation_mode': aggregation_mode,
91
+ 'apply_reorder': apply_reorder,
92
+ 'target_spacing': target_spacing,
93
+ 'sampler_type': sampler_type,
94
+ 'label_probabilities': label_probabilities,
95
+ 'samples_per_volume': samples_per_volume,
96
+ 'queue_length': queue_length,
97
+ 'queue_num_workers': queue_num_workers,
98
+ 'keep_largest_component': keep_largest_component
99
+ }
100
+
101
+ with open(pkl_fn, 'wb') as f:
102
+ pickle.dump(config, f)
103
+
104
+ # %% ../nbs/07_utils.ipynb 7
105
+ def load_patch_variables(pkl_fn: str | Path) -> dict:
106
+ """Load patch-based training and inference configuration from a pickle file.
107
+
108
+ Args:
109
+ pkl_fn: Path to the pickle file.
110
+
111
+ Returns:
112
+ Dictionary with patch configuration including:
113
+ - patch_size, patch_overlap, aggregation_mode
114
+ - apply_reorder, target_spacing, sampler_type, label_probabilities
115
+ - samples_per_volume, queue_length, queue_num_workers
116
+
117
+ Example:
118
+ >>> config = load_patch_variables('patch_settings.pkl')
119
+ >>> from fastMONAI.vision_patch import PatchConfig
120
+ >>> patch_config = PatchConfig(**config)
121
+ """
122
+ with open(pkl_fn, 'rb') as f:
123
+ return pickle.load(f)
124
+
125
+ # %% ../nbs/07_utils.ipynb 8
42
126
  def print_colab_gpu_info():
43
127
  """Check if we have a GPU attached to the runtime."""
44
128
 
@@ -52,13 +136,148 @@ def print_colab_gpu_info():
52
136
  if torch.cuda.is_available(): print('GPU attached.')
53
137
  else: print(colab_gpu_msg)
54
138
 
55
- # %% ../nbs/07_utils.ipynb 6
139
+ # %% ../nbs/07_utils.ipynb 9
140
+ def _detect_patch_workflow(dls) -> bool:
141
+ """Detect if DataLoaders are patch-based (MedPatchDataLoaders).
142
+
143
+ Args:
144
+ dls: DataLoaders instance
145
+
146
+ Returns:
147
+ True if dls is a MedPatchDataLoaders instance
148
+ """
149
+ return hasattr(dls, 'patch_config') or hasattr(dls, '_patch_config')
150
+
151
+
152
+ def _extract_size_from_transforms(tfms) -> list | None:
153
+ """Extract target size from PadOrCrop transform if present.
154
+
155
+ Args:
156
+ tfms: List of transforms
157
+
158
+ Returns:
159
+ Target size as list, or None if not found
160
+ """
161
+ if tfms is None:
162
+ return None
163
+ for tfm in tfms:
164
+ if hasattr(tfm, 'pad_or_crop') and hasattr(tfm.pad_or_crop, 'target_shape'):
165
+ return list(tfm.pad_or_crop.target_shape)
166
+ return None
167
+
168
+
169
def _extract_standard_config(learn) -> dict:
    """Collect tracking configuration for a standard (non-patch) MedDataBlock workflow.

    Preprocessing flags come from the `MedBase` class attributes
    (``apply_reorder``, ``target_spacing``); item transforms come from the
    DataLoaders' ``after_item`` pipeline, and the input size is taken from a
    PadOrCrop transform among them (None if there is none).

    Args:
        learn: fastai Learner instance

    Returns:
        Dict with keys ``apply_reorder``, ``target_spacing``, ``size``,
        ``item_tfms``, ``batch_size`` and ``patch_config`` (always None here).
    """
    from fastMONAI.vision_core import MedBase

    dls = learn.dls
    after_item = getattr(dls, 'after_item', None)
    # An absent or falsy pipeline means there are no item transforms to report.
    item_tfms = list(after_item.fs) if after_item else []

    return {
        'apply_reorder': MedBase.apply_reorder,
        'target_spacing': MedBase.target_spacing,
        'size': _extract_size_from_transforms(item_tfms),
        'item_tfms': item_tfms,
        'batch_size': dls.bs,
        'patch_config': None,
    }
201
+
202
+
203
+ def _extract_patch_config(learn) -> dict:
204
+ """Extract config from MedPatchDataLoaders workflow.
205
+
206
+ Args:
207
+ learn: fastai Learner instance
208
+
209
+ Returns:
210
+ Dictionary with extracted configuration including patch-specific params
211
+ """
212
+ dls = learn.dls
213
+ patch_config = getattr(dls, '_patch_config', None) or getattr(dls, 'patch_config', None)
214
+
215
+ config = {
216
+ 'apply_reorder': getattr(dls, '_apply_reorder', patch_config.apply_reorder if patch_config else False),
217
+ 'target_spacing': getattr(dls, '_target_spacing', patch_config.target_spacing if patch_config else None),
218
+ 'size': patch_config.patch_size if patch_config else None,
219
+ 'item_tfms': getattr(dls, '_pre_patch_tfms', []) or [],
220
+ 'batch_size': dls.bs,
221
+ }
222
+
223
+ # Add patch-specific params for logging
224
+ if patch_config:
225
+ config['patch_config'] = {
226
+ 'patch_size': patch_config.patch_size,
227
+ 'patch_overlap': patch_config.patch_overlap,
228
+ 'samples_per_volume': patch_config.samples_per_volume,
229
+ 'sampler_type': patch_config.sampler_type,
230
+ 'label_probabilities': str(patch_config.label_probabilities) if patch_config.label_probabilities else None,
231
+ 'queue_length': patch_config.queue_length,
232
+ 'aggregation_mode': patch_config.aggregation_mode,
233
+ 'padding_mode': patch_config.padding_mode,
234
+ 'keep_largest_component': patch_config.keep_largest_component,
235
+ }
236
+ else:
237
+ config['patch_config'] = None
238
+
239
+ return config
240
+
241
+
242
+ def _extract_loss_name(learn) -> str:
243
+ """Extract loss function name from Learner.
244
+
245
+ Args:
246
+ learn: fastai Learner instance
247
+
248
+ Returns:
249
+ Name of the loss function
250
+ """
251
+ loss_func = learn.loss_func
252
+ # Handle CustomLoss wrapper
253
+ if hasattr(loss_func, 'loss_func'):
254
+ inner = loss_func.loss_func
255
+ return inner._get_name() if hasattr(inner, '_get_name') else inner.__class__.__name__
256
+ return loss_func._get_name() if hasattr(loss_func, '_get_name') else loss_func.__class__.__name__
257
+
258
+
259
+ def _extract_model_name(learn) -> str:
260
+ """Extract model architecture name from Learner.
261
+
262
+ Args:
263
+ learn: fastai Learner instance
264
+
265
+ Returns:
266
+ Name of the model architecture
267
+ """
268
+ model = learn.model
269
+ return model._get_name() if hasattr(model, '_get_name') else model.__class__.__name__
270
+
271
+ # %% ../nbs/07_utils.ipynb 10
56
272
  class ModelTrackingCallback(Callback):
57
273
  """
58
274
  A FastAI callback for comprehensive MLflow experiment tracking.
59
275
 
60
276
  This callback automatically logs hyperparameters, metrics, model artifacts,
61
- and configuration to MLflow during training.
277
+ and configuration to MLflow during training. If a SaveModelCallback is present,
278
+ the best model checkpoint will also be logged as an artifact.
279
+
280
+ Supports auto-managed runs when created via `create_mlflow_callback()`.
62
281
  """
63
282
 
64
283
  def __init__(
@@ -67,8 +286,14 @@ class ModelTrackingCallback(Callback):
67
286
  loss_function: str,
68
287
  item_tfms: list[Any],
69
288
  size: list[int],
70
- resample: list[float],
71
- reorder: bool
289
+ target_spacing: list[float],
290
+ apply_reorder: bool,
291
+ experiment_name: str = None,
292
+ run_name: str = None,
293
+ auto_start: bool = False,
294
+ patch_config: dict = None,
295
+ extra_params: dict = None,
296
+ extra_tags: dict = None
72
297
  ):
73
298
  """
74
299
  Initialize the MLflow tracking callback.
@@ -76,16 +301,32 @@ class ModelTrackingCallback(Callback):
76
301
  Args:
77
302
  model_name: Name of the model architecture for registration
78
303
  loss_function: Name of the loss function being used
304
+ item_tfms: List of item transforms
79
305
  size: Model input dimensions
80
- resample: Resampling dimensions
81
- reorder: Whether reordering augmentation is applied
306
+ target_spacing: Resampling dimensions
307
+ apply_reorder: Whether reordering augmentation is applied
308
+ experiment_name: MLflow experiment name (used with auto_start)
309
+ run_name: MLflow run name (auto-generated if None)
310
+ auto_start: If True, auto-starts/stops MLflow run
311
+ patch_config: Patch configuration dict for logging (from MedPatchDataLoaders)
312
+ extra_params: Additional parameters to log
313
+ extra_tags: MLflow tags to set on the run
82
314
  """
83
315
  self.model_name = model_name
84
316
  self.loss_function = loss_function
85
317
  self.item_tfms = item_tfms
86
318
  self.size = size
87
- self.resample = resample
88
- self.reorder = reorder
319
+ self.target_spacing = target_spacing
320
+ self.apply_reorder = apply_reorder
321
+
322
+ # New auto-management fields
323
+ self.experiment_name = experiment_name
324
+ self.run_name = run_name
325
+ self.auto_start = auto_start
326
+ self.patch_config = patch_config
327
+ self.extra_params = extra_params or {}
328
+ self.extra_tags = extra_tags or {}
329
+ self._auto_started = False
89
330
 
90
331
  self.config = self._build_config()
91
332
 
@@ -128,8 +369,8 @@ class ModelTrackingCallback(Callback):
128
369
  "loss_function": self.loss_function,
129
370
  "transform_details": transform_details,
130
371
  "size": self.size,
131
- "resample": self.resample,
132
- "reorder": self.reorder,
372
+ "target_spacing": self.target_spacing,
373
+ "apply_reorder": self.apply_reorder,
133
374
  }
134
375
 
135
376
  def _extract_training_params(self) -> dict[str, Any]:
@@ -143,8 +384,8 @@ class ModelTrackingCallback(Callback):
143
384
 
144
385
  params["loss_function"] = self.config["loss_function"]
145
386
  params["size"] = self.config["size"]
146
- params["resample"] = self.config["resample"]
147
- params["reorder"] = self.config["reorder"]
387
+ params["target_spacing"] = self.config["target_spacing"]
388
+ params["apply_reorder"] = self.config["apply_reorder"]
148
389
 
149
390
  params["transformations"] = json.dumps(
150
391
  self.config["transform_details"],
@@ -152,6 +393,12 @@ class ModelTrackingCallback(Callback):
152
393
  separators=(',', ': ')
153
394
  )
154
395
 
396
+ # Add patch-specific params if present
397
+ if self.patch_config:
398
+ for key, value in self.patch_config.items():
399
+ if value is not None:
400
+ params[f"patch_{key}"] = value if isinstance(value, (int, float, str, bool)) else str(value)
401
+
155
402
  return params
156
403
 
157
404
  def _extract_epoch_metrics(self) -> dict[str, float]:
@@ -201,6 +448,23 @@ class ModelTrackingCallback(Callback):
201
448
  if os.path.exists(weights_file):
202
449
  mlflow.log_artifact(weights_file, "model")
203
450
 
451
+ # Auto-detect SaveModelCallback and log best model
452
+ from fastai.callback.tracker import SaveModelCallback
453
+ best_model_cb = None
454
+ for cb in self.learn.cbs:
455
+ if isinstance(cb, SaveModelCallback):
456
+ best_model_path = self.learn.path/self.learn.model_dir/f'{cb.fname}.pth'
457
+ if best_model_path.exists():
458
+ mlflow.log_artifact(str(best_model_path), "model")
459
+ print(f"Logged best model artifact: {cb.fname}.pth")
460
+ best_model_cb = cb
461
+ break
462
+
463
+ # Load best model weights before export if SaveModelCallback was used
464
+ if best_model_cb is not None:
465
+ self.learn.load(best_model_cb.fname)
466
+ print(f"Loaded best model weights ({best_model_cb.fname}) for learner export")
467
+
204
468
  # Remove MLflow callbacks before exporting learner for inference
205
469
  # This prevents the callback from being triggered during inference
206
470
  original_cbs = self.learn.cbs.copy() # Save original callbacks
@@ -218,7 +482,7 @@ class ModelTrackingCallback(Callback):
218
482
  self.learn.cbs = original_cbs
219
483
 
220
484
  config_path = temp_dir / "inference_settings.pkl"
221
- store_variables(config_path, self.size, self.reorder, self.resample)
485
+ store_variables(config_path, self.size, self.apply_reorder, self.target_spacing)
222
486
  mlflow.log_artifact(str(config_path), "config")
223
487
 
224
488
  def _register_pytorch_model(self) -> None:
@@ -229,8 +493,22 @@ class ModelTrackingCallback(Callback):
229
493
  )
230
494
 
231
495
  def before_fit(self) -> None:
232
- """Log hyperparameters before training starts."""
496
+ """Log hyperparameters before training starts. Auto-start run if configured."""
497
+ # Auto-start run if requested
498
+ if self.auto_start:
499
+ if self.experiment_name:
500
+ mlflow.set_experiment(self.experiment_name)
501
+ self.run_name = self.run_name or f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
502
+ mlflow.start_run(run_name=self.run_name)
503
+ self._auto_started = True
504
+
505
+ # Set tags
506
+ if self.extra_tags:
507
+ mlflow.set_tags(self.extra_tags)
508
+
509
+ # Log params
233
510
  params = self._extract_training_params()
511
+ params.update(self.extra_params)
234
512
  mlflow.log_params(params)
235
513
 
236
514
  def after_epoch(self) -> None:
@@ -249,10 +527,104 @@ class ModelTrackingCallback(Callback):
249
527
  self._save_model_artifacts(temp_path)
250
528
 
251
529
  self._register_pytorch_model()
252
-
253
- print(f"MLflow run completed. Run ID: {mlflow.active_run().info.run_id}")
530
+
531
+ run_id = mlflow.active_run().info.run_id
532
+ print(f"MLflow run completed. Run ID: {run_id}")
533
+
534
+ # End run if auto-started
535
+ if self._auto_started:
536
+ mlflow.end_run()
537
+ self._auto_started = False
254
538
 
255
- # %% ../nbs/07_utils.ipynb 7
539
+ # %% ../nbs/07_utils.ipynb 11
540
def create_mlflow_callback(
    learn,
    experiment_name: str = None,
    run_name: str = None,
    auto_start: bool = True,
    model_name: str = None,
    extra_params: dict = None,
    extra_tags: dict = None,
) -> ModelTrackingCallback:
    """Build a ModelTrackingCallback whose settings are read from a Learner.

    The workflow type is detected from ``learn.dls``: patch-based DataLoaders
    contribute their patch configuration, while standard MedDataBlock
    DataLoaders contribute their item transforms and PadOrCrop size. The model
    and loss names are read from the Learner as well.

    Auto-extracts from Learner:
        - Preprocessing: apply_reorder, target_spacing, size/patch_size
        - Transforms: item_tfms or pre_patch_tfms
        - Training: loss_func, model architecture

    Args:
        learn: fastai Learner instance
        experiment_name: MLflow experiment name. Defaults to the model name.
        run_name: MLflow run name. If None, a timestamped name is generated.
        auto_start: If True, the callback starts/stops its own MLflow run.
        model_name: Overrides the auto-extracted model name for registration.
        extra_params: Additional parameters to log (e.g., {'dropout': 0.5}).
        extra_tags: MLflow tags to set on the run.

    Returns:
        ModelTrackingCallback ready to use with learn.fit()

    Raises:
        ValueError: If no input size could be determined from the DataLoaders.

    Example:
        >>> callback = create_mlflow_callback(learn, experiment_name="Task02_Heart")
        >>> learn.fit_one_cycle(30, lr, cbs=[callback, save_best])
    """
    is_patch = _detect_patch_workflow(learn.dls)
    extractor = _extract_patch_config if is_patch else _extract_standard_config
    cfg = extractor(learn)

    resolved_model_name = model_name or _extract_model_name(learn)
    loss_name = _extract_loss_name(learn)

    resolved_experiment = resolved_model_name if experiment_name is None else experiment_name

    # Without a size the inference settings artifact would be incomplete.
    if cfg['size'] is None:
        raise ValueError(
            "Could not auto-extract 'size'. Either:\n"
            "1. Add PadOrCrop to item_tfms, or\n"
            "2. Use MedPatchDataLoaders with PatchConfig, or\n"
            "3. Use ModelTrackingCallback directly with manual params"
        )

    callback_kwargs = dict(
        model_name=resolved_model_name,
        loss_function=loss_name,
        item_tfms=cfg['item_tfms'],
        size=cfg['size'],
        target_spacing=cfg['target_spacing'],
        apply_reorder=cfg['apply_reorder'],
        experiment_name=resolved_experiment,
        run_name=run_name,
        auto_start=auto_start,
        patch_config=cfg['patch_config'],
        extra_params=extra_params,
        extra_tags=extra_tags,
    )
    return ModelTrackingCallback(**callback_kwargs)
626
+
627
+ # %% ../nbs/07_utils.ipynb 13
256
628
  import subprocess
257
629
  import threading
258
630
  import time
@@ -286,7 +658,7 @@ class MLflowUIManager:
286
658
  try:
287
659
  response = requests.get(f'http://localhost:{self.port}', timeout=2)
288
660
  return response.status_code == 200
289
- except:
661
+ except (requests.RequestException, ConnectionError, OSError):
290
662
  return False
291
663
 
292
664
  def find_available_port(self, start_port=5001):
fastMONAI/vision_all.py CHANGED
@@ -2,10 +2,11 @@
2
2
  from .vision_core import *
3
3
  from .vision_data import *
4
4
  from .vision_augmentation import *
5
- from .vision_loss import *
5
+ from .vision_loss import *
6
6
  from .vision_metrics import *
7
7
  from .vision_inference import *
8
- from .utils import *
8
+ from .vision_patch import *
9
+ from .utils import *
9
10
  from .external_data import *
10
11
  from .dataset_info import *
11
12